Merge branch 'main' into neilalexander/purgeroom

This commit is contained in:
Neil Alexander 2022-09-01 09:26:14 +01:00 committed by GitHub
commit 4c3810a76b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
45 changed files with 391 additions and 1605 deletions

View file

@ -1,10 +0,0 @@
# Application Service
This component interfaces with external [Application
Services](https://matrix.org/docs/spec/application_service/unstable.html).
This includes any HTTP endpoints that application services call, as well as talking
to any HTTP endpoints that application services provide themselves.
## Consumers
This component consumes and filters events from the Roomserver Kafka stream, passing on any necessary events to subscribing application services.

View file

@ -18,7 +18,6 @@ import (
"context" "context"
"crypto/tls" "crypto/tls"
"net/http" "net/http"
"sync"
"time" "time"
"github.com/gorilla/mux" "github.com/gorilla/mux"
@ -28,9 +27,6 @@ import (
"github.com/matrix-org/dendrite/appservice/consumers" "github.com/matrix-org/dendrite/appservice/consumers"
"github.com/matrix-org/dendrite/appservice/inthttp" "github.com/matrix-org/dendrite/appservice/inthttp"
"github.com/matrix-org/dendrite/appservice/query" "github.com/matrix-org/dendrite/appservice/query"
"github.com/matrix-org/dendrite/appservice/storage"
"github.com/matrix-org/dendrite/appservice/types"
"github.com/matrix-org/dendrite/appservice/workers"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
@ -59,57 +55,40 @@ func NewInternalAPI(
Proxy: http.ProxyFromEnvironment, Proxy: http.ProxyFromEnvironment,
}, },
} }
js, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) // Create appserivce query API with an HTTP client that will be used for all
// outbound and inbound requests (inbound only for the internal API)
appserviceQueryAPI := &query.AppServiceQueryAPI{
HTTPClient: client,
Cfg: &base.Cfg.AppServiceAPI,
}
// Create a connection to the appservice postgres DB if len(base.Cfg.Derived.ApplicationServices) == 0 {
appserviceDB, err := storage.NewDatabase(base, &base.Cfg.AppServiceAPI.Database) return appserviceQueryAPI
if err != nil {
logrus.WithError(err).Panicf("failed to connect to appservice db")
} }
// Wrap application services in a type that relates the application service and // Wrap application services in a type that relates the application service and
// a sync.Cond object that can be used to notify workers when there are new // a sync.Cond object that can be used to notify workers when there are new
// events to be sent out. // events to be sent out.
workerStates := make([]types.ApplicationServiceWorkerState, len(base.Cfg.Derived.ApplicationServices)) for _, appservice := range base.Cfg.Derived.ApplicationServices {
for i, appservice := range base.Cfg.Derived.ApplicationServices {
m := sync.Mutex{}
ws := types.ApplicationServiceWorkerState{
AppService: appservice,
Cond: sync.NewCond(&m),
}
workerStates[i] = ws
// Create bot account for this AS if it doesn't already exist // Create bot account for this AS if it doesn't already exist
if err = generateAppServiceAccount(userAPI, appservice); err != nil { if err := generateAppServiceAccount(userAPI, appservice); err != nil {
logrus.WithFields(logrus.Fields{ logrus.WithFields(logrus.Fields{
"appservice": appservice.ID, "appservice": appservice.ID,
}).WithError(err).Panicf("failed to generate bot account for appservice") }).WithError(err).Panicf("failed to generate bot account for appservice")
} }
} }
// Create appserivce query API with an HTTP client that will be used for all
// outbound and inbound requests (inbound only for the internal API)
appserviceQueryAPI := &query.AppServiceQueryAPI{
HTTPClient: client,
Cfg: base.Cfg,
}
// Only consume if we actually have ASes to track, else we'll just chew cycles needlessly. // Only consume if we actually have ASes to track, else we'll just chew cycles needlessly.
// We can't add ASes at runtime so this is safe to do. // We can't add ASes at runtime so this is safe to do.
if len(workerStates) > 0 { js, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream)
consumer := consumers.NewOutputRoomEventConsumer( consumer := consumers.NewOutputRoomEventConsumer(
base.ProcessContext, base.Cfg, js, appserviceDB, base.ProcessContext, &base.Cfg.AppServiceAPI,
rsAPI, workerStates, client, js, rsAPI,
) )
if err := consumer.Start(); err != nil { if err := consumer.Start(); err != nil {
logrus.WithError(err).Panicf("failed to start appservice roomserver consumer") logrus.WithError(err).Panicf("failed to start appservice roomserver consumer")
}
} }
// Create application service transaction workers
if err := workers.SetupTransactionWorkers(client, appserviceDB, workerStates); err != nil {
logrus.WithError(err).Panicf("failed to start app service transaction workers")
}
return appserviceQueryAPI return appserviceQueryAPI
} }

View file

@ -15,14 +15,18 @@
package consumers package consumers
import ( import (
"bytes"
"context" "context"
"encoding/json" "encoding/json"
"fmt"
"math"
"net/http"
"net/url"
"time"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go" "github.com/nats-io/nats.go"
"github.com/matrix-org/dendrite/appservice/storage"
"github.com/matrix-org/dendrite/appservice/types"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/jetstream"
@ -33,177 +37,192 @@ import (
// OutputRoomEventConsumer consumes events that originated in the room server. // OutputRoomEventConsumer consumes events that originated in the room server.
type OutputRoomEventConsumer struct { type OutputRoomEventConsumer struct {
ctx context.Context ctx context.Context
jetstream nats.JetStreamContext cfg *config.AppServiceAPI
durable string client *http.Client
topic string jetstream nats.JetStreamContext
asDB storage.Database topic string
rsAPI api.AppserviceRoomserverAPI rsAPI api.AppserviceRoomserverAPI
serverName string }
workerStates []types.ApplicationServiceWorkerState
type appserviceState struct {
*config.ApplicationService
backoff int
} }
// NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call // NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call
// Start() to begin consuming from room servers. // Start() to begin consuming from room servers.
func NewOutputRoomEventConsumer( func NewOutputRoomEventConsumer(
process *process.ProcessContext, process *process.ProcessContext,
cfg *config.Dendrite, cfg *config.AppServiceAPI,
client *http.Client,
js nats.JetStreamContext, js nats.JetStreamContext,
appserviceDB storage.Database,
rsAPI api.AppserviceRoomserverAPI, rsAPI api.AppserviceRoomserverAPI,
workerStates []types.ApplicationServiceWorkerState,
) *OutputRoomEventConsumer { ) *OutputRoomEventConsumer {
return &OutputRoomEventConsumer{ return &OutputRoomEventConsumer{
ctx: process.Context(), ctx: process.Context(),
jetstream: js, cfg: cfg,
durable: cfg.Global.JetStream.Durable("AppserviceRoomserverConsumer"), client: client,
topic: cfg.Global.JetStream.Prefixed(jetstream.OutputRoomEvent), jetstream: js,
asDB: appserviceDB, topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputRoomEvent),
rsAPI: rsAPI, rsAPI: rsAPI,
serverName: string(cfg.Global.ServerName),
workerStates: workerStates,
} }
} }
// Start consuming from room servers // Start consuming from room servers
func (s *OutputRoomEventConsumer) Start() error { func (s *OutputRoomEventConsumer) Start() error {
return jetstream.JetStreamConsumer( for _, as := range s.cfg.Derived.ApplicationServices {
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, appsvc := as
nats.DeliverAll(), nats.ManualAck(), state := &appserviceState{
) ApplicationService: &appsvc,
}
token := jetstream.Tokenise(as.ID)
if err := jetstream.JetStreamConsumer(
s.ctx, s.jetstream, s.topic,
s.cfg.Matrix.JetStream.Durable("Appservice_"+token),
50, // maximum number of events to send in a single transaction
func(ctx context.Context, msgs []*nats.Msg) bool {
return s.onMessage(ctx, state, msgs)
},
nats.DeliverNew(), nats.ManualAck(),
); err != nil {
return fmt.Errorf("failed to create %q consumer: %w", token, err)
}
}
return nil
} }
// onMessage is called when the appservice component receives a new event from // onMessage is called when the appservice component receives a new event from
// the room server output log. // the room server output log.
func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { func (s *OutputRoomEventConsumer) onMessage(
// Parse out the event JSON ctx context.Context, state *appserviceState, msgs []*nats.Msg,
var output api.OutputEvent ) bool {
if err := json.Unmarshal(msg.Data, &output); err != nil { log.WithField("appservice", state.ID).Tracef("Appservice worker received %d message(s) from roomserver", len(msgs))
// If the message was invalid, log it and move on to the next message in the stream events := make([]*gomatrixserverlib.HeaderedEvent, 0, len(msgs))
log.WithError(err).Errorf("roomserver output log: message parse failure") for _, msg := range msgs {
return true // Parse out the event JSON
} var output api.OutputEvent
if err := json.Unmarshal(msg.Data, &output); err != nil {
log.WithFields(log.Fields{ // If the message was invalid, log it and move on to the next message in the stream
"type": output.Type, log.WithField("appservice", state.ID).WithError(err).Errorf("Appservice failed to parse message, ignoring")
}).Debug("Got a message in OutputRoomEventConsumer") continue
events := []*gomatrixserverlib.HeaderedEvent{}
if output.Type == api.OutputTypeNewRoomEvent && output.NewRoomEvent != nil {
newEventID := output.NewRoomEvent.Event.EventID()
events = append(events, output.NewRoomEvent.Event)
if len(output.NewRoomEvent.AddsStateEventIDs) > 0 {
eventsReq := &api.QueryEventsByIDRequest{
EventIDs: make([]string, 0, len(output.NewRoomEvent.AddsStateEventIDs)),
}
eventsRes := &api.QueryEventsByIDResponse{}
for _, eventID := range output.NewRoomEvent.AddsStateEventIDs {
if eventID != newEventID {
eventsReq.EventIDs = append(eventsReq.EventIDs, eventID)
}
}
if len(eventsReq.EventIDs) > 0 {
if err := s.rsAPI.QueryEventsByID(s.ctx, eventsReq, eventsRes); err != nil {
log.WithError(err).Errorf("s.rsAPI.QueryEventsByID failed")
return false
}
events = append(events, eventsRes.Events...)
}
} }
} else if output.Type == api.OutputTypeNewInviteEvent && output.NewInviteEvent != nil { switch output.Type {
events = append(events, output.NewInviteEvent.Event) case api.OutputTypeNewRoomEvent:
} else { if output.NewRoomEvent == nil || !s.appserviceIsInterestedInEvent(ctx, output.NewRoomEvent.Event, state.ApplicationService) {
log.WithFields(log.Fields{ continue
"type": output.Type, }
}).Debug("appservice OutputRoomEventConsumer ignoring event", string(msg.Data)) events = append(events, output.NewRoomEvent.Event)
if len(output.NewRoomEvent.AddsStateEventIDs) > 0 {
newEventID := output.NewRoomEvent.Event.EventID()
eventsReq := &api.QueryEventsByIDRequest{
EventIDs: make([]string, 0, len(output.NewRoomEvent.AddsStateEventIDs)),
}
eventsRes := &api.QueryEventsByIDResponse{}
for _, eventID := range output.NewRoomEvent.AddsStateEventIDs {
if eventID != newEventID {
eventsReq.EventIDs = append(eventsReq.EventIDs, eventID)
}
}
if len(eventsReq.EventIDs) > 0 {
if err := s.rsAPI.QueryEventsByID(s.ctx, eventsReq, eventsRes); err != nil {
log.WithError(err).Errorf("s.rsAPI.QueryEventsByID failed")
return false
}
events = append(events, eventsRes.Events...)
}
}
case api.OutputTypeNewInviteEvent:
if output.NewInviteEvent == nil {
continue
}
events = append(events, output.NewInviteEvent.Event)
default:
continue
}
}
// If there are no events selected for sending then we should
// ack the messages so that we don't get sent them again in the
// future.
if len(events) == 0 {
return true return true
} }
// Send event to any relevant application services // Send event to any relevant application services. If we hit
if err := s.filterRoomserverEvents(context.TODO(), events); err != nil { // an error here, return false, so that we negatively ack.
log.WithError(err).Errorf("roomserver output log: filter error") log.WithField("appservice", state.ID).Debugf("Appservice worker sending %d events(s) from roomserver", len(events))
return true return s.sendEvents(ctx, state, events) == nil
}
return true
} }
// filterRoomserverEvents takes in events and decides whether any of them need // sendEvents passes events to the appservice by using the transactions
// to be passed on to an external application service. It does this by checking // endpoint. It will block for the backoff period if necessary.
// each namespace of each registered application service, and if there is a func (s *OutputRoomEventConsumer) sendEvents(
// match, adds the event to the queue for events to be sent to a particular ctx context.Context, state *appserviceState,
// application service.
func (s *OutputRoomEventConsumer) filterRoomserverEvents(
ctx context.Context,
events []*gomatrixserverlib.HeaderedEvent, events []*gomatrixserverlib.HeaderedEvent,
) error { ) error {
for _, ws := range s.workerStates { // Create the transaction body.
for _, event := range events { transaction, err := json.Marshal(
// Check if this event is interesting to this application service gomatrixserverlib.ApplicationServiceTransaction{
if s.appserviceIsInterestedInEvent(ctx, event, ws.AppService) { Events: gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatAll),
// Queue this event to be sent off to the application service },
if err := s.asDB.StoreEvent(ctx, ws.AppService.ID, event); err != nil { )
log.WithError(err).Warn("failed to insert incoming event into appservices database") if err != nil {
return err return err
} else {
// Tell our worker to send out new messages by updating remaining message
// count and waking them up with a broadcast
ws.NotifyNewEvents()
}
}
}
} }
// TODO: We should probably be more intelligent and pick something not
// in the control of the event. A NATS timestamp header or something maybe.
txnID := events[0].Event.OriginServerTS()
// Send the transaction to the appservice.
// https://matrix.org/docs/spec/application_service/r0.1.2#put-matrix-app-v1-transactions-txnid
address := fmt.Sprintf("%s/transactions/%d?access_token=%s", state.URL, txnID, url.QueryEscape(state.HSToken))
req, err := http.NewRequestWithContext(ctx, "PUT", address, bytes.NewBuffer(transaction))
if err != nil {
return err
}
req.Header.Set("Content-Type", "application/json")
resp, err := s.client.Do(req)
if err != nil {
return state.backoffAndPause(err)
}
// If the response was fine then we can clear any backoffs in place and
// report that everything was OK. Otherwise, back off for a while.
switch resp.StatusCode {
case http.StatusOK:
state.backoff = 0
default:
return state.backoffAndPause(fmt.Errorf("received HTTP status code %d from appservice", resp.StatusCode))
}
return nil return nil
} }
// appserviceJoinedAtEvent returns a boolean depending on whether a given // backoff pauses the calling goroutine for a 2^some backoff exponent seconds
// appservice has membership at the time a given event was created. func (s *appserviceState) backoffAndPause(err error) error {
func (s *OutputRoomEventConsumer) appserviceJoinedAtEvent(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, appservice config.ApplicationService) bool { if s.backoff < 6 {
// TODO: This is only checking the current room state, not the state at s.backoff++
// the event in question. Pretty sure this is what Synapse does too, but
// until we have a lighter way of checking the state before the event that
// doesn't involve state res, then this is probably OK.
membershipReq := &api.QueryMembershipsForRoomRequest{
RoomID: event.RoomID(),
JoinedOnly: true,
} }
membershipRes := &api.QueryMembershipsForRoomResponse{} duration := time.Second * time.Duration(math.Pow(2, float64(s.backoff)))
log.WithField("appservice", s.ID).WithError(err).Errorf("Unable to send transaction to appservice, backing off for %s", duration.String())
// XXX: This could potentially race if the state for the event is not known yet time.Sleep(duration)
// e.g. the event came over federation but we do not have the full state persisted. return err
if err := s.rsAPI.QueryMembershipsForRoom(ctx, membershipReq, membershipRes); err == nil {
for _, ev := range membershipRes.JoinEvents {
var membership gomatrixserverlib.MemberContent
if err = json.Unmarshal(ev.Content, &membership); err != nil || ev.StateKey == nil {
continue
}
if appservice.IsInterestedInUserID(*ev.StateKey) {
return true
}
}
} else {
log.WithFields(log.Fields{
"room_id": event.RoomID(),
}).WithError(err).Errorf("Unable to get membership for room")
}
return false
} }
// appserviceIsInterestedInEvent returns a boolean depending on whether a given // appserviceIsInterestedInEvent returns a boolean depending on whether a given
// event falls within one of a given application service's namespaces. // event falls within one of a given application service's namespaces.
// //
// TODO: This should be cached, see https://github.com/matrix-org/dendrite/issues/1682 // TODO: This should be cached, see https://github.com/matrix-org/dendrite/issues/1682
func (s *OutputRoomEventConsumer) appserviceIsInterestedInEvent(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, appservice config.ApplicationService) bool { func (s *OutputRoomEventConsumer) appserviceIsInterestedInEvent(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, appservice *config.ApplicationService) bool {
// No reason to queue events if they'll never be sent to the application switch {
// service case appservice.URL == "":
if appservice.URL == "" {
return false return false
} case appservice.IsInterestedInUserID(event.Sender()):
return true
// Check Room ID and Sender of the event case appservice.IsInterestedInRoomID(event.RoomID()):
if appservice.IsInterestedInUserID(event.Sender()) ||
appservice.IsInterestedInRoomID(event.RoomID()) {
return true return true
} }
@ -224,10 +243,52 @@ func (s *OutputRoomEventConsumer) appserviceIsInterestedInEvent(ctx context.Cont
} }
} else { } else {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"room_id": event.RoomID(), "appservice": appservice.ID,
"room_id": event.RoomID(),
}).WithError(err).Errorf("Unable to get aliases for room") }).WithError(err).Errorf("Unable to get aliases for room")
} }
// Check if any of the members in the room match the appservice // Check if any of the members in the room match the appservice
return s.appserviceJoinedAtEvent(ctx, event, appservice) return s.appserviceJoinedAtEvent(ctx, event, appservice)
} }
// appserviceJoinedAtEvent returns a boolean depending on whether a given
// appservice has membership at the time a given event was created.
func (s *OutputRoomEventConsumer) appserviceJoinedAtEvent(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, appservice *config.ApplicationService) bool {
// TODO: This is only checking the current room state, not the state at
// the event in question. Pretty sure this is what Synapse does too, but
// until we have a lighter way of checking the state before the event that
// doesn't involve state res, then this is probably OK.
membershipReq := &api.QueryMembershipsForRoomRequest{
RoomID: event.RoomID(),
JoinedOnly: true,
}
membershipRes := &api.QueryMembershipsForRoomResponse{}
// XXX: This could potentially race if the state for the event is not known yet
// e.g. the event came over federation but we do not have the full state persisted.
if err := s.rsAPI.QueryMembershipsForRoom(ctx, membershipReq, membershipRes); err == nil {
for _, ev := range membershipRes.JoinEvents {
switch {
case ev.StateKey == nil:
continue
case ev.Type != gomatrixserverlib.MRoomMember:
continue
}
var membership gomatrixserverlib.MemberContent
err = json.Unmarshal(ev.Content, &membership)
switch {
case err != nil:
continue
case membership.Membership == gomatrixserverlib.Join:
return true
}
}
} else {
log.WithFields(log.Fields{
"appservice": appservice.ID,
"room_id": event.RoomID(),
}).WithError(err).Errorf("Unable to get membership for room")
}
return false
}

View file

@ -33,7 +33,7 @@ const userIDExistsPath = "/users/"
// AppServiceQueryAPI is an implementation of api.AppServiceQueryAPI // AppServiceQueryAPI is an implementation of api.AppServiceQueryAPI
type AppServiceQueryAPI struct { type AppServiceQueryAPI struct {
HTTPClient *http.Client HTTPClient *http.Client
Cfg *config.Dendrite Cfg *config.AppServiceAPI
} }
// RoomAliasExists performs a request to '/room/{roomAlias}' on all known // RoomAliasExists performs a request to '/room/{roomAlias}' on all known

View file

@ -1,30 +0,0 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package storage
import (
"context"
"github.com/matrix-org/gomatrixserverlib"
)
type Database interface {
StoreEvent(ctx context.Context, appServiceID string, event *gomatrixserverlib.HeaderedEvent) error
GetEventsWithAppServiceID(ctx context.Context, appServiceID string, limit int) (int, int, []gomatrixserverlib.HeaderedEvent, bool, error)
CountEventsWithAppServiceID(ctx context.Context, appServiceID string) (int, error)
UpdateTxnIDForEvents(ctx context.Context, appserviceID string, maxID, txnID int) error
RemoveEventsBeforeAndIncludingID(ctx context.Context, appserviceID string, eventTableID int) error
GetLatestTxnID(ctx context.Context) (int, error)
}

View file

@ -1,256 +0,0 @@
// Copyright 2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"encoding/json"
"time"
"github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus"
)
const appserviceEventsSchema = `
-- Stores events to be sent to application services
CREATE TABLE IF NOT EXISTS appservice_events (
-- An auto-incrementing id unique to each event in the table
id BIGSERIAL NOT NULL PRIMARY KEY,
-- The ID of the application service the event will be sent to
as_id TEXT NOT NULL,
-- JSON representation of the event
headered_event_json TEXT NOT NULL,
-- The ID of the transaction that this event is a part of
txn_id BIGINT NOT NULL
);
CREATE INDEX IF NOT EXISTS appservice_events_as_id ON appservice_events(as_id);
`
const selectEventsByApplicationServiceIDSQL = "" +
"SELECT id, headered_event_json, txn_id " +
"FROM appservice_events WHERE as_id = $1 ORDER BY txn_id DESC, id ASC"
const countEventsByApplicationServiceIDSQL = "" +
"SELECT COUNT(id) FROM appservice_events WHERE as_id = $1"
const insertEventSQL = "" +
"INSERT INTO appservice_events(as_id, headered_event_json, txn_id) " +
"VALUES ($1, $2, $3)"
const updateTxnIDForEventsSQL = "" +
"UPDATE appservice_events SET txn_id = $1 WHERE as_id = $2 AND id <= $3"
const deleteEventsBeforeAndIncludingIDSQL = "" +
"DELETE FROM appservice_events WHERE as_id = $1 AND id <= $2"
const (
// A transaction ID number that no transaction should ever have. Used for
// checking again the default value.
invalidTxnID = -2
)
type eventsStatements struct {
selectEventsByApplicationServiceIDStmt *sql.Stmt
countEventsByApplicationServiceIDStmt *sql.Stmt
insertEventStmt *sql.Stmt
updateTxnIDForEventsStmt *sql.Stmt
deleteEventsBeforeAndIncludingIDStmt *sql.Stmt
}
func (s *eventsStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(appserviceEventsSchema)
if err != nil {
return
}
if s.selectEventsByApplicationServiceIDStmt, err = db.Prepare(selectEventsByApplicationServiceIDSQL); err != nil {
return
}
if s.countEventsByApplicationServiceIDStmt, err = db.Prepare(countEventsByApplicationServiceIDSQL); err != nil {
return
}
if s.insertEventStmt, err = db.Prepare(insertEventSQL); err != nil {
return
}
if s.updateTxnIDForEventsStmt, err = db.Prepare(updateTxnIDForEventsSQL); err != nil {
return
}
if s.deleteEventsBeforeAndIncludingIDStmt, err = db.Prepare(deleteEventsBeforeAndIncludingIDSQL); err != nil {
return
}
return
}
// selectEventsByApplicationServiceID takes in an application service ID and
// returns a slice of events that need to be sent to that application service,
// as well as an int later used to remove these same events from the database
// once successfully sent to an application service.
func (s *eventsStatements) selectEventsByApplicationServiceID(
ctx context.Context,
applicationServiceID string,
limit int,
) (
txnID, maxID int,
events []gomatrixserverlib.HeaderedEvent,
eventsRemaining bool,
err error,
) {
defer func() {
if err != nil {
log.WithFields(log.Fields{
"appservice": applicationServiceID,
}).WithError(err).Fatalf("appservice unable to select new events to send")
}
}()
// Retrieve events from the database. Unsuccessfully sent events first
eventRows, err := s.selectEventsByApplicationServiceIDStmt.QueryContext(ctx, applicationServiceID)
if err != nil {
return
}
defer checkNamedErr(eventRows.Close, &err)
events, maxID, txnID, eventsRemaining, err = retrieveEvents(eventRows, limit)
if err != nil {
return
}
return
}
// checkNamedErr calls fn and overwrite err if it was nil and fn returned non-nil
func checkNamedErr(fn func() error, err *error) {
if e := fn(); e != nil && *err == nil {
*err = e
}
}
func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.HeaderedEvent, maxID, txnID int, eventsRemaining bool, err error) {
// Get current time for use in calculating event age
nowMilli := time.Now().UnixNano() / int64(time.Millisecond)
// Iterate through each row and store event contents
// If txn_id changes dramatically, we've switched from collecting old events to
// new ones. Send back those events first.
lastTxnID := invalidTxnID
for eventsProcessed := 0; eventRows.Next(); {
var event gomatrixserverlib.HeaderedEvent
var eventJSON []byte
var id int
err = eventRows.Scan(
&id,
&eventJSON,
&txnID,
)
if err != nil {
return nil, 0, 0, false, err
}
// Unmarshal eventJSON
if err = json.Unmarshal(eventJSON, &event); err != nil {
return nil, 0, 0, false, err
}
// If txnID has changed on this event from the previous event, then we've
// reached the end of a transaction's events. Return only those events.
if lastTxnID > invalidTxnID && lastTxnID != txnID {
return events, maxID, lastTxnID, true, nil
}
lastTxnID = txnID
// Limit events that aren't part of an old transaction
if txnID == -1 {
// Return if we've hit the limit
if eventsProcessed++; eventsProcessed > limit {
return events, maxID, lastTxnID, true, nil
}
}
if id > maxID {
maxID = id
}
// Portion of the event that is unsigned due to rapid change
// TODO: Consider removing age as not many app services use it
if err = event.SetUnsignedField("age", nowMilli-int64(event.OriginServerTS())); err != nil {
return nil, 0, 0, false, err
}
events = append(events, event)
}
return
}
// countEventsByApplicationServiceID inserts an event mapped to its corresponding application service
// IDs into the db.
func (s *eventsStatements) countEventsByApplicationServiceID(
ctx context.Context,
appServiceID string,
) (int, error) {
var count int
err := s.countEventsByApplicationServiceIDStmt.QueryRowContext(ctx, appServiceID).Scan(&count)
if err != nil && err != sql.ErrNoRows {
return 0, err
}
return count, nil
}
// insertEvent inserts an event mapped to its corresponding application service
// IDs into the db.
func (s *eventsStatements) insertEvent(
ctx context.Context,
appServiceID string,
event *gomatrixserverlib.HeaderedEvent,
) (err error) {
// Convert event to JSON before inserting
eventJSON, err := json.Marshal(event)
if err != nil {
return err
}
_, err = s.insertEventStmt.ExecContext(
ctx,
appServiceID,
eventJSON,
-1, // No transaction ID yet
)
return
}
// updateTxnIDForEvents sets the transactionID for a collection of events. Done
// before sending them to an AppService. Referenced before sending to make sure
// we aren't constructing multiple transactions with the same events.
func (s *eventsStatements) updateTxnIDForEvents(
ctx context.Context,
appserviceID string,
maxID, txnID int,
) (err error) {
_, err = s.updateTxnIDForEventsStmt.ExecContext(ctx, txnID, appserviceID, maxID)
return
}
// deleteEventsBeforeAndIncludingID removes events matching given IDs from the database.
func (s *eventsStatements) deleteEventsBeforeAndIncludingID(
ctx context.Context,
appserviceID string,
eventTableID int,
) (err error) {
_, err = s.deleteEventsBeforeAndIncludingIDStmt.ExecContext(ctx, appserviceID, eventTableID)
return
}

View file

@ -1,115 +0,0 @@
// Copyright 2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
// Import postgres database driver
_ "github.com/lib/pq"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
)
// Database stores events intended to be later sent to application services
type Database struct {
events eventsStatements
txnID txnStatements
db *sql.DB
writer sqlutil.Writer
}
// NewDatabase opens a new database
func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*Database, error) {
var result Database
var err error
if result.db, result.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()); err != nil {
return nil, err
}
if err = result.prepare(); err != nil {
return nil, err
}
return &result, nil
}
func (d *Database) prepare() error {
if err := d.events.prepare(d.db); err != nil {
return err
}
return d.txnID.prepare(d.db)
}
// StoreEvent takes in a gomatrixserverlib.HeaderedEvent and stores it in the database
// for a transaction worker to pull and later send to an application service.
func (d *Database) StoreEvent(
ctx context.Context,
appServiceID string,
event *gomatrixserverlib.HeaderedEvent,
) error {
return d.events.insertEvent(ctx, appServiceID, event)
}
// GetEventsWithAppServiceID returns a slice of events and their IDs intended to
// be sent to an application service given its ID.
func (d *Database) GetEventsWithAppServiceID(
ctx context.Context,
appServiceID string,
limit int,
) (int, int, []gomatrixserverlib.HeaderedEvent, bool, error) {
return d.events.selectEventsByApplicationServiceID(ctx, appServiceID, limit)
}
// CountEventsWithAppServiceID returns the number of events destined for an
// application service given its ID.
func (d *Database) CountEventsWithAppServiceID(
ctx context.Context,
appServiceID string,
) (int, error) {
return d.events.countEventsByApplicationServiceID(ctx, appServiceID)
}
// UpdateTxnIDForEvents takes in an application service ID and a
// and stores them in the DB, unless the pair already exists, in
// which case it updates them.
func (d *Database) UpdateTxnIDForEvents(
ctx context.Context,
appserviceID string,
maxID, txnID int,
) error {
return d.events.updateTxnIDForEvents(ctx, appserviceID, maxID, txnID)
}
// RemoveEventsBeforeAndIncludingID removes all events from the database that
// are less than or equal to a given maximum ID. IDs here are implemented as a
// serial, thus this should always delete events in chronological order.
func (d *Database) RemoveEventsBeforeAndIncludingID(
ctx context.Context,
appserviceID string,
eventTableID int,
) error {
return d.events.deleteEventsBeforeAndIncludingID(ctx, appserviceID, eventTableID)
}
// GetLatestTxnID returns the latest available transaction id
func (d *Database) GetLatestTxnID(
ctx context.Context,
) (int, error) {
return d.txnID.selectTxnID(ctx)
}

View file

@ -1,53 +0,0 @@
// Copyright 2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
)
const txnIDSchema = `
-- Keeps a count of the current transaction ID
CREATE SEQUENCE IF NOT EXISTS txn_id_counter START 1;
`
const selectTxnIDSQL = "SELECT nextval('txn_id_counter')"
type txnStatements struct {
selectTxnIDStmt *sql.Stmt
}
func (s *txnStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(txnIDSchema)
if err != nil {
return
}
if s.selectTxnIDStmt, err = db.Prepare(selectTxnIDSQL); err != nil {
return
}
return
}
// selectTxnID selects the latest ascending transaction ID
func (s *txnStatements) selectTxnID(
ctx context.Context,
) (txnID int, err error) {
err = s.selectTxnIDStmt.QueryRowContext(ctx).Scan(&txnID)
return
}

View file

@ -1,267 +0,0 @@
// Copyright 2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"encoding/json"
"time"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus"
)
const appserviceEventsSchema = `
-- Stores events to be sent to application services
CREATE TABLE IF NOT EXISTS appservice_events (
-- An auto-incrementing id unique to each event in the table
id INTEGER PRIMARY KEY AUTOINCREMENT,
-- The ID of the application service the event will be sent to
as_id TEXT NOT NULL,
-- JSON representation of the event
headered_event_json TEXT NOT NULL,
-- The ID of the transaction that this event is a part of
txn_id INTEGER NOT NULL
);
CREATE INDEX IF NOT EXISTS appservice_events_as_id ON appservice_events(as_id);
`
const selectEventsByApplicationServiceIDSQL = "" +
"SELECT id, headered_event_json, txn_id " +
"FROM appservice_events WHERE as_id = $1 ORDER BY txn_id DESC, id ASC"
const countEventsByApplicationServiceIDSQL = "" +
"SELECT COUNT(id) FROM appservice_events WHERE as_id = $1"
const insertEventSQL = "" +
"INSERT INTO appservice_events(as_id, headered_event_json, txn_id) " +
"VALUES ($1, $2, $3)"
const updateTxnIDForEventsSQL = "" +
"UPDATE appservice_events SET txn_id = $1 WHERE as_id = $2 AND id <= $3"
const deleteEventsBeforeAndIncludingIDSQL = "" +
"DELETE FROM appservice_events WHERE as_id = $1 AND id <= $2"
const (
// A transaction ID number that no transaction should ever have. Used for
// checking again the default value.
invalidTxnID = -2
)
type eventsStatements struct {
db *sql.DB
writer sqlutil.Writer
selectEventsByApplicationServiceIDStmt *sql.Stmt
countEventsByApplicationServiceIDStmt *sql.Stmt
insertEventStmt *sql.Stmt
updateTxnIDForEventsStmt *sql.Stmt
deleteEventsBeforeAndIncludingIDStmt *sql.Stmt
}
func (s *eventsStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
s.db = db
s.writer = writer
_, err = db.Exec(appserviceEventsSchema)
if err != nil {
return
}
if s.selectEventsByApplicationServiceIDStmt, err = db.Prepare(selectEventsByApplicationServiceIDSQL); err != nil {
return
}
if s.countEventsByApplicationServiceIDStmt, err = db.Prepare(countEventsByApplicationServiceIDSQL); err != nil {
return
}
if s.insertEventStmt, err = db.Prepare(insertEventSQL); err != nil {
return
}
if s.updateTxnIDForEventsStmt, err = db.Prepare(updateTxnIDForEventsSQL); err != nil {
return
}
if s.deleteEventsBeforeAndIncludingIDStmt, err = db.Prepare(deleteEventsBeforeAndIncludingIDSQL); err != nil {
return
}
return
}
// selectEventsByApplicationServiceID takes in an application service ID and
// returns a slice of events that need to be sent to that application service,
// as well as an int later used to remove these same events from the database
// once successfully sent to an application service.
func (s *eventsStatements) selectEventsByApplicationServiceID(
ctx context.Context,
applicationServiceID string,
limit int,
) (
txnID, maxID int,
events []gomatrixserverlib.HeaderedEvent,
eventsRemaining bool,
err error,
) {
defer func() {
if err != nil {
log.WithFields(log.Fields{
"appservice": applicationServiceID,
}).WithError(err).Fatalf("appservice unable to select new events to send")
}
}()
// Retrieve events from the database. Unsuccessfully sent events first
eventRows, err := s.selectEventsByApplicationServiceIDStmt.QueryContext(ctx, applicationServiceID)
if err != nil {
return
}
defer checkNamedErr(eventRows.Close, &err)
events, maxID, txnID, eventsRemaining, err = retrieveEvents(eventRows, limit)
if err != nil {
return
}
return
}
// checkNamedErr calls fn and overwrite err if it was nil and fn returned non-nil
func checkNamedErr(fn func() error, err *error) {
if e := fn(); e != nil && *err == nil {
*err = e
}
}
func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.HeaderedEvent, maxID, txnID int, eventsRemaining bool, err error) {
// Get current time for use in calculating event age
nowMilli := time.Now().UnixNano() / int64(time.Millisecond)
// Iterate through each row and store event contents
// If txn_id changes dramatically, we've switched from collecting old events to
// new ones. Send back those events first.
lastTxnID := invalidTxnID
for eventsProcessed := 0; eventRows.Next(); {
var event gomatrixserverlib.HeaderedEvent
var eventJSON []byte
var id int
err = eventRows.Scan(
&id,
&eventJSON,
&txnID,
)
if err != nil {
return nil, 0, 0, false, err
}
// Unmarshal eventJSON
if err = json.Unmarshal(eventJSON, &event); err != nil {
return nil, 0, 0, false, err
}
// If txnID has changed on this event from the previous event, then we've
// reached the end of a transaction's events. Return only those events.
if lastTxnID > invalidTxnID && lastTxnID != txnID {
return events, maxID, lastTxnID, true, nil
}
lastTxnID = txnID
// Limit events that aren't part of an old transaction
if txnID == -1 {
// Return if we've hit the limit
if eventsProcessed++; eventsProcessed > limit {
return events, maxID, lastTxnID, true, nil
}
}
if id > maxID {
maxID = id
}
// Portion of the event that is unsigned due to rapid change
// TODO: Consider removing age as not many app services use it
if err = event.SetUnsignedField("age", nowMilli-int64(event.OriginServerTS())); err != nil {
return nil, 0, 0, false, err
}
events = append(events, event)
}
return
}
// countEventsByApplicationServiceID inserts an event mapped to its corresponding application service
// IDs into the db.
func (s *eventsStatements) countEventsByApplicationServiceID(
ctx context.Context,
appServiceID string,
) (int, error) {
var count int
err := s.countEventsByApplicationServiceIDStmt.QueryRowContext(ctx, appServiceID).Scan(&count)
if err != nil && err != sql.ErrNoRows {
return 0, err
}
return count, nil
}
// insertEvent inserts an event mapped to its corresponding application service
// IDs into the db.
func (s *eventsStatements) insertEvent(
ctx context.Context,
appServiceID string,
event *gomatrixserverlib.HeaderedEvent,
) (err error) {
// Convert event to JSON before inserting
eventJSON, err := json.Marshal(event)
if err != nil {
return err
}
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
_, err := s.insertEventStmt.ExecContext(
ctx,
appServiceID,
eventJSON,
-1, // No transaction ID yet
)
return err
})
}
// updateTxnIDForEvents sets the transactionID for a collection of events. Done
// before sending them to an AppService. Referenced before sending to make sure
// we aren't constructing multiple transactions with the same events.
func (s *eventsStatements) updateTxnIDForEvents(
ctx context.Context,
appserviceID string,
maxID, txnID int,
) (err error) {
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
_, err := s.updateTxnIDForEventsStmt.ExecContext(ctx, txnID, appserviceID, maxID)
return err
})
}
// deleteEventsBeforeAndIncludingID removes events matching given IDs from the database.
func (s *eventsStatements) deleteEventsBeforeAndIncludingID(
ctx context.Context,
appserviceID string,
eventTableID int,
) (err error) {
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
_, err := s.deleteEventsBeforeAndIncludingIDStmt.ExecContext(ctx, appserviceID, eventTableID)
return err
})
}

View file

@ -1,114 +0,0 @@
// Copyright 2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
// Import SQLite database driver
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
)
// Database stores events intended to be later sent to application services
type Database struct {
events eventsStatements
txnID txnStatements
db *sql.DB
writer sqlutil.Writer
}
// NewDatabase opens a new database
func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*Database, error) {
var result Database
var err error
if result.db, result.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()); err != nil {
return nil, err
}
if err = result.prepare(); err != nil {
return nil, err
}
return &result, nil
}
func (d *Database) prepare() error {
if err := d.events.prepare(d.db, d.writer); err != nil {
return err
}
return d.txnID.prepare(d.db, d.writer)
}
// StoreEvent takes in a gomatrixserverlib.HeaderedEvent and stores it in the database
// for a transaction worker to pull and later send to an application service.
func (d *Database) StoreEvent(
ctx context.Context,
appServiceID string,
event *gomatrixserverlib.HeaderedEvent,
) error {
return d.events.insertEvent(ctx, appServiceID, event)
}
// GetEventsWithAppServiceID returns a slice of events and their IDs intended to
// be sent to an application service given its ID.
func (d *Database) GetEventsWithAppServiceID(
ctx context.Context,
appServiceID string,
limit int,
) (int, int, []gomatrixserverlib.HeaderedEvent, bool, error) {
return d.events.selectEventsByApplicationServiceID(ctx, appServiceID, limit)
}
// CountEventsWithAppServiceID returns the number of events destined for an
// application service given its ID.
func (d *Database) CountEventsWithAppServiceID(
ctx context.Context,
appServiceID string,
) (int, error) {
return d.events.countEventsByApplicationServiceID(ctx, appServiceID)
}
// UpdateTxnIDForEvents takes in an application service ID and a
// and stores them in the DB, unless the pair already exists, in
// which case it updates them.
func (d *Database) UpdateTxnIDForEvents(
ctx context.Context,
appserviceID string,
maxID, txnID int,
) error {
return d.events.updateTxnIDForEvents(ctx, appserviceID, maxID, txnID)
}
// RemoveEventsBeforeAndIncludingID removes all events from the database that
// are less than or equal to a given maximum ID. IDs here are implemented as a
// serial, thus this should always delete events in chronological order.
func (d *Database) RemoveEventsBeforeAndIncludingID(
ctx context.Context,
appserviceID string,
eventTableID int,
) error {
return d.events.deleteEventsBeforeAndIncludingID(ctx, appserviceID, eventTableID)
}
// GetLatestTxnID returns the latest available transaction id
func (d *Database) GetLatestTxnID(
ctx context.Context,
) (int, error) {
return d.txnID.selectTxnID(ctx)
}

View file

@ -1,82 +0,0 @@
// Copyright 2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/internal/sqlutil"
)
const txnIDSchema = `
-- Keeps a count of the current transaction ID
CREATE TABLE IF NOT EXISTS appservice_counters (
name TEXT PRIMARY KEY NOT NULL,
last_id INTEGER DEFAULT 1
);
INSERT OR IGNORE INTO appservice_counters (name, last_id) VALUES('txn_id', 1);
`
const selectTxnIDSQL = `
SELECT last_id FROM appservice_counters WHERE name='txn_id'
`
const updateTxnIDSQL = `
UPDATE appservice_counters SET last_id=last_id+1 WHERE name='txn_id'
`
type txnStatements struct {
db *sql.DB
writer sqlutil.Writer
selectTxnIDStmt *sql.Stmt
updateTxnIDStmt *sql.Stmt
}
func (s *txnStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
s.db = db
s.writer = writer
_, err = db.Exec(txnIDSchema)
if err != nil {
return
}
if s.selectTxnIDStmt, err = db.Prepare(selectTxnIDSQL); err != nil {
return
}
if s.updateTxnIDStmt, err = db.Prepare(updateTxnIDSQL); err != nil {
return
}
return
}
// selectTxnID selects the latest ascending transaction ID
func (s *txnStatements) selectTxnID(
ctx context.Context,
) (txnID int, err error) {
err = s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
err := s.selectTxnIDStmt.QueryRowContext(ctx).Scan(&txnID)
if err != nil {
return err
}
_, err = s.updateTxnIDStmt.ExecContext(ctx)
return err
})
return
}

View file

@ -1,40 +0,0 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//go:build !wasm
// +build !wasm
package storage
import (
"fmt"
"github.com/matrix-org/dendrite/appservice/storage/postgres"
"github.com/matrix-org/dendrite/appservice/storage/sqlite3"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
)
// NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme)
// and sets DB connection parameters
func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (Database, error) {
switch {
case dbProperties.ConnectionString.IsSQLite():
return sqlite3.NewDatabase(base, dbProperties)
case dbProperties.ConnectionString.IsPostgres():
return postgres.NewDatabase(base, dbProperties)
default:
return nil, fmt.Errorf("unexpected database type")
}
}

View file

@ -1,34 +0,0 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package storage
import (
"fmt"
"github.com/matrix-org/dendrite/appservice/storage/sqlite3"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
)
func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (Database, error) {
switch {
case dbProperties.ConnectionString.IsSQLite():
return sqlite3.NewDatabase(base, dbProperties)
case dbProperties.ConnectionString.IsPostgres():
return nil, fmt.Errorf("can't use Postgres implementation")
default:
return nil, fmt.Errorf("unexpected database type")
}
}

View file

@ -1,64 +0,0 @@
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package types
import (
"sync"
"github.com/matrix-org/dendrite/setup/config"
)
const (
// AppServiceDeviceID is the AS dummy device ID
AppServiceDeviceID = "AS_Device"
)
// ApplicationServiceWorkerState is a type that couples an application service,
// a lockable condition as well as some other state variables, allowing the
// roomserver to notify appservice workers when there are events ready to send
// externally to application services.
type ApplicationServiceWorkerState struct {
AppService config.ApplicationService
Cond *sync.Cond
// Events ready to be sent
EventsReady bool
// Backoff exponent (2^x secs). Max 6, aka 64s.
Backoff int
}
// NotifyNewEvents wakes up all waiting goroutines, notifying that events remain
// in the event queue for this application service worker.
func (a *ApplicationServiceWorkerState) NotifyNewEvents() {
a.Cond.L.Lock()
a.EventsReady = true
a.Cond.Broadcast()
a.Cond.L.Unlock()
}
// FinishEventProcessing marks all events of this worker as being sent to the
// application service.
func (a *ApplicationServiceWorkerState) FinishEventProcessing() {
a.Cond.L.Lock()
a.EventsReady = false
a.Cond.L.Unlock()
}
// WaitForNewEvents causes the calling goroutine to wait on the worker state's
// condition for a broadcast or similar wakeup, if there are no events ready.
func (a *ApplicationServiceWorkerState) WaitForNewEvents() {
a.Cond.L.Lock()
if !a.EventsReady {
a.Cond.Wait()
}
a.Cond.L.Unlock()
}

View file

@ -1,236 +0,0 @@
// Copyright 2018 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package workers
import (
"bytes"
"context"
"encoding/json"
"fmt"
"math"
"net/http"
"net/url"
"time"
"github.com/matrix-org/dendrite/appservice/storage"
"github.com/matrix-org/dendrite/appservice/types"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus"
)
var (
// Maximum size of events sent in each transaction.
transactionBatchSize = 50
)
// SetupTransactionWorkers spawns a separate goroutine for each application
// service. Each of these "workers" handle taking all events intended for their
// app service, batch them up into a single transaction (up to a max transaction
// size), then send that off to the AS's /transactions/{txnID} endpoint. It also
// handles exponentially backing off in case the AS isn't currently available.
func SetupTransactionWorkers(
client *http.Client,
appserviceDB storage.Database,
workerStates []types.ApplicationServiceWorkerState,
) error {
// Create a worker that handles transmitting events to a single homeserver
for _, workerState := range workerStates {
// Don't create a worker if this AS doesn't want to receive events
if workerState.AppService.URL != "" {
go worker(client, appserviceDB, workerState)
}
}
return nil
}
// worker is a goroutine that sends any queued events to the application service
// it is given.
func worker(client *http.Client, db storage.Database, ws types.ApplicationServiceWorkerState) {
log.WithFields(log.Fields{
"appservice": ws.AppService.ID,
}).Info("Starting application service")
ctx := context.Background()
// Initial check for any leftover events to send from last time
eventCount, err := db.CountEventsWithAppServiceID(ctx, ws.AppService.ID)
if err != nil {
log.WithFields(log.Fields{
"appservice": ws.AppService.ID,
}).WithError(err).Fatal("appservice worker unable to read queued events from DB")
return
}
if eventCount > 0 {
ws.NotifyNewEvents()
}
// Loop forever and keep waiting for more events to send
for {
// Wait for more events if we've sent all the events in the database
ws.WaitForNewEvents()
// Batch events up into a transaction
transactionJSON, txnID, maxEventID, eventsRemaining, err := createTransaction(ctx, db, ws.AppService.ID)
if err != nil {
log.WithFields(log.Fields{
"appservice": ws.AppService.ID,
}).WithError(err).Fatal("appservice worker unable to create transaction")
return
}
// Send the events off to the application service
// Backoff if the application service does not respond
err = send(client, ws.AppService, txnID, transactionJSON)
if err != nil {
log.WithFields(log.Fields{
"appservice": ws.AppService.ID,
}).WithError(err).Error("unable to send event")
// Backoff
backoff(&ws, err)
continue
}
// We sent successfully, hooray!
ws.Backoff = 0
// Transactions have a maximum event size, so there may still be some events
// left over to send. Keep sending until none are left
if !eventsRemaining {
ws.FinishEventProcessing()
}
// Remove sent events from the DB
err = db.RemoveEventsBeforeAndIncludingID(ctx, ws.AppService.ID, maxEventID)
if err != nil {
log.WithFields(log.Fields{
"appservice": ws.AppService.ID,
}).WithError(err).Fatal("unable to remove appservice events from the database")
return
}
}
}
// backoff pauses the calling goroutine for a 2^some backoff exponent seconds
func backoff(ws *types.ApplicationServiceWorkerState, err error) {
// Calculate how long to backoff for
backoffDuration := time.Duration(math.Pow(2, float64(ws.Backoff)))
backoffSeconds := time.Second * backoffDuration
log.WithFields(log.Fields{
"appservice": ws.AppService.ID,
}).WithError(err).Warnf("unable to send transactions successfully, backing off for %ds",
backoffDuration)
ws.Backoff++
if ws.Backoff > 6 {
ws.Backoff = 6
}
// Backoff
time.Sleep(backoffSeconds)
}
// createTransaction takes in a slice of AS events, stores them in an AS
// transaction, and JSON-encodes the results.
func createTransaction(
ctx context.Context,
db storage.Database,
appserviceID string,
) (
transactionJSON []byte,
txnID, maxID int,
eventsRemaining bool,
err error,
) {
// Retrieve the latest events from the DB (will return old events if they weren't successfully sent)
txnID, maxID, events, eventsRemaining, err := db.GetEventsWithAppServiceID(ctx, appserviceID, transactionBatchSize)
if err != nil {
log.WithFields(log.Fields{
"appservice": appserviceID,
}).WithError(err).Fatalf("appservice worker unable to read queued events from DB")
return
}
// Check if these events do not already have a transaction ID
if txnID == -1 {
// If not, grab next available ID from the DB
txnID, err = db.GetLatestTxnID(ctx)
if err != nil {
return nil, 0, 0, false, err
}
// Mark new events with current transactionID
if err = db.UpdateTxnIDForEvents(ctx, appserviceID, maxID, txnID); err != nil {
return nil, 0, 0, false, err
}
}
var ev []*gomatrixserverlib.HeaderedEvent
for i := range events {
ev = append(ev, &events[i])
}
// Create a transaction and store the events inside
transaction := gomatrixserverlib.ApplicationServiceTransaction{
Events: gomatrixserverlib.HeaderedToClientEvents(ev, gomatrixserverlib.FormatAll),
}
transactionJSON, err = json.Marshal(transaction)
if err != nil {
return
}
return
}
// send sends events to an application service. Returns an error if an OK was not
// received back from the application service or the request timed out.
func send(
client *http.Client,
appservice config.ApplicationService,
txnID int,
transaction []byte,
) (err error) {
// PUT a transaction to our AS
// https://matrix.org/docs/spec/application_service/r0.1.2#put-matrix-app-v1-transactions-txnid
address := fmt.Sprintf("%s/transactions/%d?access_token=%s", appservice.URL, txnID, url.QueryEscape(appservice.HSToken))
req, err := http.NewRequest("PUT", address, bytes.NewBuffer(transaction))
if err != nil {
return err
}
req.Header.Set("Content-Type", "application/json")
resp, err := client.Do(req)
if err != nil {
return err
}
defer checkNamedErr(resp.Body.Close, &err)
// Check the AS received the events correctly
if resp.StatusCode != http.StatusOK {
// TODO: Handle non-200 error codes from application services
return fmt.Errorf("non-OK status code %d returned from AS", resp.StatusCode)
}
return nil
}
// checkNamedErr calls fn and overwrite err if it was nil and fn returned non-nil
func checkNamedErr(fn func() error, err *error) {
if e := fn(); e != nil && *err == nil {
*err = e
}
}

View file

@ -255,7 +255,6 @@ func (m *DendriteMonolith) Start() {
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-roomserver.db", m.StorageDirectory, prefix)) cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-roomserver.db", m.StorageDirectory, prefix))
cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-keyserver.db", m.StorageDirectory, prefix)) cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-keyserver.db", m.StorageDirectory, prefix))
cfg.FederationAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-federationsender.db", m.StorageDirectory, prefix)) cfg.FederationAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-federationsender.db", m.StorageDirectory, prefix))
cfg.AppServiceAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-appservice.db", m.StorageDirectory, prefix))
cfg.MediaAPI.BasePath = config.Path(fmt.Sprintf("%s/media", m.CacheDirectory)) cfg.MediaAPI.BasePath = config.Path(fmt.Sprintf("%s/media", m.CacheDirectory))
cfg.MediaAPI.AbsBasePath = config.Path(fmt.Sprintf("%s/media", m.CacheDirectory)) cfg.MediaAPI.AbsBasePath = config.Path(fmt.Sprintf("%s/media", m.CacheDirectory))
cfg.MSCs.MSCs = []string{"msc2836", "msc2946"} cfg.MSCs.MSCs = []string{"msc2836", "msc2946"}

View file

@ -94,7 +94,6 @@ func (m *DendriteMonolith) Start() {
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-roomserver.db", m.StorageDirectory)) cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-roomserver.db", m.StorageDirectory))
cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-keyserver.db", m.StorageDirectory)) cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-keyserver.db", m.StorageDirectory))
cfg.FederationAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-federationsender.db", m.StorageDirectory)) cfg.FederationAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-federationsender.db", m.StorageDirectory))
cfg.AppServiceAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-appservice.db", m.StorageDirectory))
cfg.MediaAPI.BasePath = config.Path(fmt.Sprintf("%s/tmp", m.StorageDirectory)) cfg.MediaAPI.BasePath = config.Path(fmt.Sprintf("%s/tmp", m.StorageDirectory))
cfg.MediaAPI.AbsBasePath = config.Path(fmt.Sprintf("%s/tmp", m.StorageDirectory)) cfg.MediaAPI.AbsBasePath = config.Path(fmt.Sprintf("%s/tmp", m.StorageDirectory))
cfg.ClientAPI.RegistrationDisabled = false cfg.ClientAPI.RegistrationDisabled = false

View file

@ -24,6 +24,7 @@ import (
"net" "net"
"net/http" "net/http"
"os" "os"
"strings"
"time" "time"
"github.com/gorilla/mux" "github.com/gorilla/mux"
@ -42,6 +43,7 @@ import (
"github.com/matrix-org/dendrite/setup" "github.com/matrix-org/dendrite/setup"
"github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/userapi" "github.com/matrix-org/dendrite/userapi"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -70,31 +72,83 @@ func main() {
var pk ed25519.PublicKey var pk ed25519.PublicKey
var sk ed25519.PrivateKey var sk ed25519.PrivateKey
keyfile := *instanceName + ".key" // iterate through the cli args and check if the config flag was set
if _, err := os.Stat(keyfile); os.IsNotExist(err) { configFlagSet := false
if pk, sk, err = ed25519.GenerateKey(nil); err != nil { for _, arg := range os.Args {
panic(err) if arg == "--config" || arg == "-config" {
configFlagSet = true
break
} }
if err = os.WriteFile(keyfile, sk, 0644); err != nil {
panic(err)
}
} else if err == nil {
if sk, err = os.ReadFile(keyfile); err != nil {
panic(err)
}
if len(sk) != ed25519.PrivateKeySize {
panic("the private key is not long enough")
}
pk = sk.Public().(ed25519.PublicKey)
} }
cfg := &config.Dendrite{}
// use custom config if config flag is set
if configFlagSet {
cfg = setup.ParseFlags(true)
sk = cfg.Global.PrivateKey
} else {
keyfile := *instanceName + ".pem"
if _, err := os.Stat(keyfile); os.IsNotExist(err) {
oldkeyfile := *instanceName + ".key"
if _, err = os.Stat(oldkeyfile); os.IsNotExist(err) {
if err = test.NewMatrixKey(keyfile); err != nil {
panic("failed to generate a new PEM key: " + err.Error())
}
if _, sk, err = config.LoadMatrixKey(keyfile, os.ReadFile); err != nil {
panic("failed to load PEM key: " + err.Error())
}
} else {
if sk, err = os.ReadFile(oldkeyfile); err != nil {
panic("failed to read the old private key: " + err.Error())
}
if len(sk) != ed25519.PrivateKeySize {
panic("the private key is not long enough")
}
if err := test.SaveMatrixKey(keyfile, sk); err != nil {
panic("failed to convert the private key to PEM format: " + err.Error())
}
}
} else {
var err error
if _, sk, err = config.LoadMatrixKey(keyfile, os.ReadFile); err != nil {
panic("failed to load PEM key: " + err.Error())
}
}
cfg.Defaults(true)
cfg.Global.PrivateKey = sk
cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", *instanceName))
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", *instanceName))
cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", *instanceName))
cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", *instanceName))
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName))
cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-keyserver.db", *instanceName))
cfg.FederationAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-federationapi.db", *instanceName))
cfg.MSCs.MSCs = []string{"msc2836", "msc2946"}
cfg.MSCs.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mscs.db", *instanceName))
cfg.ClientAPI.RegistrationDisabled = false
cfg.ClientAPI.OpenRegistrationWithoutVerificationEnabled = true
if err := cfg.Derive(); err != nil {
panic(err)
}
}
pk = sk.Public().(ed25519.PublicKey)
cfg.Global.ServerName = gomatrixserverlib.ServerName(hex.EncodeToString(pk))
cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID)
base := base.NewBaseDendrite(cfg, "Monolith")
defer base.Close() // nolint: errcheck
pRouter := pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk, false) pRouter := pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk, false)
pQUIC := pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), pRouter, []string{"matrix"}) pQUIC := pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), pRouter, []string{"matrix"})
pMulticast := pineconeMulticast.NewMulticast(logrus.WithField("pinecone", "multicast"), pRouter) pMulticast := pineconeMulticast.NewMulticast(logrus.WithField("pinecone", "multicast"), pRouter)
pManager := pineconeConnections.NewConnectionManager(pRouter, nil) pManager := pineconeConnections.NewConnectionManager(pRouter, nil)
pMulticast.Start() pMulticast.Start()
if instancePeer != nil && *instancePeer != "" { if instancePeer != nil && *instancePeer != "" {
pManager.AddPeer(*instancePeer) for _, peer := range strings.Split(*instancePeer, ",") {
pManager.AddPeer(strings.Trim(peer, " \t\r\n"))
}
} }
go func() { go func() {
@ -125,29 +179,6 @@ func main() {
} }
}() }()
cfg := &config.Dendrite{}
cfg.Defaults(true)
cfg.Global.ServerName = gomatrixserverlib.ServerName(hex.EncodeToString(pk))
cfg.Global.PrivateKey = sk
cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID)
cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", *instanceName))
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", *instanceName))
cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", *instanceName))
cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", *instanceName))
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName))
cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-keyserver.db", *instanceName))
cfg.FederationAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-federationapi.db", *instanceName))
cfg.AppServiceAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-appservice.db", *instanceName))
cfg.MSCs.MSCs = []string{"msc2836", "msc2946"}
cfg.ClientAPI.RegistrationDisabled = false
cfg.ClientAPI.OpenRegistrationWithoutVerificationEnabled = true
if err := cfg.Derive(); err != nil {
panic(err)
}
base := base.NewBaseDendrite(cfg, "Monolith")
defer base.Close() // nolint: errcheck
federation := conn.CreateFederationClient(base, pQUIC) federation := conn.CreateFederationClient(base, pQUIC)
serverKeyAPI := &signing.YggdrasilKeys{} serverKeyAPI := &signing.YggdrasilKeys{}

View file

@ -86,7 +86,6 @@ func main() {
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName)) cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName))
cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-keyserver.db", *instanceName)) cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-keyserver.db", *instanceName))
cfg.FederationAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-federationapi.db", *instanceName)) cfg.FederationAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-federationapi.db", *instanceName))
cfg.AppServiceAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-appservice.db", *instanceName))
cfg.MSCs.MSCs = []string{"msc2836"} cfg.MSCs.MSCs = []string{"msc2836"}
cfg.MSCs.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mscs.db", *instanceName)) cfg.MSCs.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mscs.db", *instanceName))
cfg.ClientAPI.RegistrationDisabled = false cfg.ClientAPI.RegistrationDisabled = false

View file

@ -24,7 +24,6 @@ func main() {
cfg.Global.ServerName = gomatrixserverlib.ServerName(*serverName) cfg.Global.ServerName = gomatrixserverlib.ServerName(*serverName)
} }
if *dbURI != "" { if *dbURI != "" {
cfg.AppServiceAPI.Database.ConnectionString = config.DataSource(*dbURI)
cfg.FederationAPI.Database.ConnectionString = config.DataSource(*dbURI) cfg.FederationAPI.Database.ConnectionString = config.DataSource(*dbURI)
cfg.KeyServer.Database.ConnectionString = config.DataSource(*dbURI) cfg.KeyServer.Database.ConnectionString = config.DataSource(*dbURI)
cfg.MSCs.Database.ConnectionString = config.DataSource(*dbURI) cfg.MSCs.Database.ConnectionString = config.DataSource(*dbURI)

View file

@ -132,13 +132,6 @@ app_service_api:
listen: http://[::]:7777 # The listen address for incoming API requests listen: http://[::]:7777 # The listen address for incoming API requests
connect: http://app_service_api:7777 # The connect address for other components to use connect: http://app_service_api:7777 # The connect address for other components to use
# Database configuration for this component.
database:
connection_string: postgresql://username:password@hostname/dendrite_appservice?sslmode=disable
max_open_conns: 10
max_idle_conns: 2
conn_max_lifetime: -1
# Disable the validation of TLS certificates of appservices. This is # Disable the validation of TLS certificates of appservices. This is
# not recommended in production since it may allow appservice traffic # not recommended in production since it may allow appservice traffic
# to be sent to an insecure endpoint. # to be sent to an insecure endpoint.

View file

@ -67,14 +67,15 @@ func NewKeyChangeConsumer(
// Start consuming from key servers // Start consuming from key servers
func (t *KeyChangeConsumer) Start() error { func (t *KeyChangeConsumer) Start() error {
return jetstream.JetStreamConsumer( return jetstream.JetStreamConsumer(
t.ctx, t.jetstream, t.topic, t.durable, t.onMessage, t.ctx, t.jetstream, t.topic, t.durable, 1,
nats.DeliverAll(), nats.ManualAck(), t.onMessage, nats.DeliverAll(), nats.ManualAck(),
) )
} }
// onMessage is called in response to a message received on the // onMessage is called in response to a message received on the
// key change events topic from the key server. // key change events topic from the key server.
func (t *KeyChangeConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { func (t *KeyChangeConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
var m api.DeviceMessage var m api.DeviceMessage
if err := json.Unmarshal(msg.Data, &m); err != nil { if err := json.Unmarshal(msg.Data, &m); err != nil {
logrus.WithError(err).Errorf("failed to read device message from key change topic") logrus.WithError(err).Errorf("failed to read device message from key change topic")

View file

@ -69,14 +69,15 @@ func (t *OutputPresenceConsumer) Start() error {
return nil return nil
} }
return jetstream.JetStreamConsumer( return jetstream.JetStreamConsumer(
t.ctx, t.jetstream, t.topic, t.durable, t.onMessage, t.ctx, t.jetstream, t.topic, t.durable, 1, t.onMessage,
nats.DeliverAll(), nats.ManualAck(), nats.HeadersOnly(), nats.DeliverAll(), nats.ManualAck(), nats.HeadersOnly(),
) )
} }
// onMessage is called in response to a message received on the presence // onMessage is called in response to a message received on the presence
// events topic from the client api. // events topic from the client api.
func (t *OutputPresenceConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { func (t *OutputPresenceConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
// only send presence events which originated from us // only send presence events which originated from us
userID := msg.Header.Get(jetstream.UserID) userID := msg.Header.Get(jetstream.UserID)
_, serverName, err := gomatrixserverlib.SplitID('@', userID) _, serverName, err := gomatrixserverlib.SplitID('@', userID)

View file

@ -65,14 +65,15 @@ func NewOutputReceiptConsumer(
// Start consuming from the clientapi // Start consuming from the clientapi
func (t *OutputReceiptConsumer) Start() error { func (t *OutputReceiptConsumer) Start() error {
return jetstream.JetStreamConsumer( return jetstream.JetStreamConsumer(
t.ctx, t.jetstream, t.topic, t.durable, t.onMessage, t.ctx, t.jetstream, t.topic, t.durable, 1, t.onMessage,
nats.DeliverAll(), nats.ManualAck(), nats.HeadersOnly(), nats.DeliverAll(), nats.ManualAck(), nats.HeadersOnly(),
) )
} }
// onMessage is called in response to a message received on the receipt // onMessage is called in response to a message received on the receipt
// events topic from the client api. // events topic from the client api.
func (t *OutputReceiptConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { func (t *OutputReceiptConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
receipt := syncTypes.OutputReceiptEvent{ receipt := syncTypes.OutputReceiptEvent{
UserID: msg.Header.Get(jetstream.UserID), UserID: msg.Header.Get(jetstream.UserID),
RoomID: msg.Header.Get(jetstream.RoomID), RoomID: msg.Header.Get(jetstream.RoomID),

View file

@ -69,8 +69,8 @@ func NewOutputRoomEventConsumer(
// Start consuming from room servers // Start consuming from room servers
func (s *OutputRoomEventConsumer) Start() error { func (s *OutputRoomEventConsumer) Start() error {
return jetstream.JetStreamConsumer( return jetstream.JetStreamConsumer(
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, s.ctx, s.jetstream, s.topic, s.durable, 1,
nats.DeliverAll(), nats.ManualAck(), s.onMessage, nats.DeliverAll(), nats.ManualAck(),
) )
} }
@ -78,7 +78,8 @@ func (s *OutputRoomEventConsumer) Start() error {
// It is unsafe to call this with messages for the same room in multiple gorountines // It is unsafe to call this with messages for the same room in multiple gorountines
// because updates it will likely fail with a types.EventIDMismatchError when it // because updates it will likely fail with a types.EventIDMismatchError when it
// realises that it cannot update the room state using the deltas. // realises that it cannot update the room state using the deltas.
func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
// Parse out the event JSON // Parse out the event JSON
var output api.OutputEvent var output api.OutputEvent
if err := json.Unmarshal(msg.Data, &output); err != nil { if err := json.Unmarshal(msg.Data, &output); err != nil {

View file

@ -63,14 +63,15 @@ func NewOutputSendToDeviceConsumer(
// Start consuming from the client api // Start consuming from the client api
func (t *OutputSendToDeviceConsumer) Start() error { func (t *OutputSendToDeviceConsumer) Start() error {
return jetstream.JetStreamConsumer( return jetstream.JetStreamConsumer(
t.ctx, t.jetstream, t.topic, t.durable, t.onMessage, t.ctx, t.jetstream, t.topic, t.durable, 1,
nats.DeliverAll(), nats.ManualAck(), t.onMessage, nats.DeliverAll(), nats.ManualAck(),
) )
} }
// onMessage is called in response to a message received on the // onMessage is called in response to a message received on the
// send-to-device events topic from the client api. // send-to-device events topic from the client api.
func (t *OutputSendToDeviceConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { func (t *OutputSendToDeviceConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
// only send send-to-device events which originated from us // only send send-to-device events which originated from us
sender := msg.Header.Get("sender") sender := msg.Header.Get("sender")
_, originServerName, err := gomatrixserverlib.SplitID('@', sender) _, originServerName, err := gomatrixserverlib.SplitID('@', sender)

View file

@ -62,14 +62,15 @@ func NewOutputTypingConsumer(
// Start consuming from the clientapi // Start consuming from the clientapi
func (t *OutputTypingConsumer) Start() error { func (t *OutputTypingConsumer) Start() error {
return jetstream.JetStreamConsumer( return jetstream.JetStreamConsumer(
t.ctx, t.jetstream, t.topic, t.durable, t.onMessage, t.ctx, t.jetstream, t.topic, t.durable, 1, t.onMessage,
nats.DeliverAll(), nats.ManualAck(), nats.HeadersOnly(), nats.DeliverAll(), nats.ManualAck(), nats.HeadersOnly(),
) )
} }
// onMessage is called in response to a message received on the typing // onMessage is called in response to a message received on the typing
// events topic from the client api. // events topic from the client api.
func (t *OutputTypingConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { func (t *OutputTypingConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
// Extract the typing event from msg. // Extract the typing event from msg.
roomID := msg.Header.Get(jetstream.RoomID) roomID := msg.Header.Get(jetstream.RoomID)
userID := msg.Header.Get(jetstream.UserID) userID := msg.Header.Get(jetstream.UserID)

View file

@ -55,14 +55,15 @@ func NewDeviceListUpdateConsumer(
// Start consuming from key servers // Start consuming from key servers
func (t *DeviceListUpdateConsumer) Start() error { func (t *DeviceListUpdateConsumer) Start() error {
return jetstream.JetStreamConsumer( return jetstream.JetStreamConsumer(
t.ctx, t.jetstream, t.topic, t.durable, t.onMessage, t.ctx, t.jetstream, t.topic, t.durable, 1,
nats.DeliverAll(), nats.ManualAck(), t.onMessage, nats.DeliverAll(), nats.ManualAck(),
) )
} }
// onMessage is called in response to a message received on the // onMessage is called in response to a message received on the
// key change events topic from the key server. // key change events topic from the key server.
func (t *DeviceListUpdateConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { func (t *DeviceListUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
var m gomatrixserverlib.DeviceListUpdateEvent var m gomatrixserverlib.DeviceListUpdateEvent
if err := json.Unmarshal(msg.Data, &m); err != nil { if err := json.Unmarshal(msg.Data, &m); err != nil {
logrus.WithError(err).Errorf("Failed to read from device list update input topic") logrus.WithError(err).Errorf("Failed to read from device list update input topic")

View file

@ -224,12 +224,7 @@ func loadConfig(
} }
privateKeyPath := absPath(basePath, c.Global.PrivateKeyPath) privateKeyPath := absPath(basePath, c.Global.PrivateKeyPath)
privateKeyData, err := readFile(privateKeyPath) if c.Global.KeyID, c.Global.PrivateKey, err = LoadMatrixKey(privateKeyPath, readFile); err != nil {
if err != nil {
return nil, err
}
if c.Global.KeyID, c.Global.PrivateKey, err = readKeyPEM(privateKeyPath, privateKeyData, true); err != nil {
return nil, err return nil, err
} }
@ -265,6 +260,14 @@ func loadConfig(
return &c, nil return &c, nil
} }
func LoadMatrixKey(privateKeyPath string, readFile func(string) ([]byte, error)) (gomatrixserverlib.KeyID, ed25519.PrivateKey, error) {
privateKeyData, err := readFile(privateKeyPath)
if err != nil {
return "", nil, err
}
return readKeyPEM(privateKeyPath, privateKeyData, true)
}
// Derive generates data that is derived from various values provided in // Derive generates data that is derived from various values provided in
// the config file. // the config file.
func (config *Dendrite) Derive() error { func (config *Dendrite) Derive() error {

View file

@ -31,8 +31,6 @@ type AppServiceAPI struct {
InternalAPI InternalAPIOptions `yaml:"internal_api"` InternalAPI InternalAPIOptions `yaml:"internal_api"`
Database DatabaseOptions `yaml:"database"`
// DisableTLSValidation disables the validation of X.509 TLS certs // DisableTLSValidation disables the validation of X.509 TLS certs
// on appservice endpoints. This is not recommended in production! // on appservice endpoints. This is not recommended in production!
DisableTLSValidation bool `yaml:"disable_tls_validation"` DisableTLSValidation bool `yaml:"disable_tls_validation"`
@ -43,16 +41,9 @@ type AppServiceAPI struct {
func (c *AppServiceAPI) Defaults(generate bool) { func (c *AppServiceAPI) Defaults(generate bool) {
c.InternalAPI.Listen = "http://localhost:7777" c.InternalAPI.Listen = "http://localhost:7777"
c.InternalAPI.Connect = "http://localhost:7777" c.InternalAPI.Connect = "http://localhost:7777"
c.Database.Defaults(5)
if generate {
c.Database.ConnectionString = "file:appservice.db"
}
} }
func (c *AppServiceAPI) Verify(configErrs *ConfigErrors, isMonolith bool) { func (c *AppServiceAPI) Verify(configErrs *ConfigErrors, isMonolith bool) {
if c.Matrix.DatabaseOptions.ConnectionString == "" {
checkNotEmpty(configErrs, "app_service_api.database.connection_string", string(c.Database.ConnectionString))
}
if isMonolith { // polylith required configs below if isMonolith { // polylith required configs below
return return
} }

View file

@ -9,9 +9,16 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
// JetStreamConsumer starts a durable consumer on the given subject with the
// given durable name. The function will be called when one or more messages
// is available, up to the maximum batch size specified. If the batch is set to
// 1 then messages will be delivered one at a time. If the function is called,
// the messages array is guaranteed to be at least 1 in size. Any provided NATS
// options will be passed through to the pull subscriber creation. The consumer
// will continue to run until the context expires, at which point it will stop.
func JetStreamConsumer( func JetStreamConsumer(
ctx context.Context, js nats.JetStreamContext, subj, durable string, ctx context.Context, js nats.JetStreamContext, subj, durable string, batch int,
f func(ctx context.Context, msg *nats.Msg) bool, f func(ctx context.Context, msgs []*nats.Msg) bool,
opts ...nats.SubOpt, opts ...nats.SubOpt,
) error { ) error {
defer func() { defer func() {
@ -50,7 +57,7 @@ func JetStreamConsumer(
// enforce its own deadline (roughly 5 seconds by default). Therefore // enforce its own deadline (roughly 5 seconds by default). Therefore
// it is our responsibility to check whether our context expired or // it is our responsibility to check whether our context expired or
// not when a context error is returned. Footguns. Footguns everywhere. // not when a context error is returned. Footguns. Footguns everywhere.
msgs, err := sub.Fetch(1, nats.Context(ctx)) msgs, err := sub.Fetch(batch, nats.Context(ctx))
if err != nil { if err != nil {
if err == context.Canceled || err == context.DeadlineExceeded { if err == context.Canceled || err == context.DeadlineExceeded {
// Work out whether it was the JetStream context that expired // Work out whether it was the JetStream context that expired
@ -74,21 +81,26 @@ func JetStreamConsumer(
if len(msgs) < 1 { if len(msgs) < 1 {
continue continue
} }
msg := msgs[0] for _, msg := range msgs {
if err = msg.InProgress(nats.Context(ctx)); err != nil { if err = msg.InProgress(nats.Context(ctx)); err != nil {
logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.InProgress: %w", err)) logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.InProgress: %w", err))
sentry.CaptureException(err)
continue
}
if f(ctx, msg) {
if err = msg.AckSync(nats.Context(ctx)); err != nil {
logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.AckSync: %w", err))
sentry.CaptureException(err) sentry.CaptureException(err)
continue
}
}
if f(ctx, msgs) {
for _, msg := range msgs {
if err = msg.AckSync(nats.Context(ctx)); err != nil {
logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.AckSync: %w", err))
sentry.CaptureException(err)
}
} }
} else { } else {
if err = msg.Nak(nats.Context(ctx)); err != nil { for _, msg := range msgs {
logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.Nak: %w", err)) if err = msg.Nak(nats.Context(ctx)); err != nil {
sentry.CaptureException(err) logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.Nak: %w", err))
sentry.CaptureException(err)
}
} }
} }
} }

View file

@ -183,6 +183,7 @@ func setupNATS(process *process.ProcessContext, cfg *config.JetStream, nc *natsc
OutputReceiptEvent: {"SyncAPIEDUServerReceiptConsumer", "FederationAPIEDUServerConsumer"}, OutputReceiptEvent: {"SyncAPIEDUServerReceiptConsumer", "FederationAPIEDUServerConsumer"},
OutputSendToDeviceEvent: {"SyncAPIEDUServerSendToDeviceConsumer", "FederationAPIEDUServerConsumer"}, OutputSendToDeviceEvent: {"SyncAPIEDUServerSendToDeviceConsumer", "FederationAPIEDUServerConsumer"},
OutputTypingEvent: {"SyncAPIEDUServerTypingConsumer", "FederationAPIEDUServerConsumer"}, OutputTypingEvent: {"SyncAPIEDUServerTypingConsumer", "FederationAPIEDUServerConsumer"},
OutputRoomEvent: {"AppserviceRoomserverConsumer"},
} { } {
streamName := cfg.Matrix.JetStream.Prefixed(stream) streamName := cfg.Matrix.JetStream.Prefixed(stream)
for _, consumer := range consumers { for _, consumer := range consumers {

View file

@ -75,15 +75,16 @@ func NewOutputClientDataConsumer(
// Start consuming from room servers // Start consuming from room servers
func (s *OutputClientDataConsumer) Start() error { func (s *OutputClientDataConsumer) Start() error {
return jetstream.JetStreamConsumer( return jetstream.JetStreamConsumer(
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, s.ctx, s.jetstream, s.topic, s.durable, 1,
nats.DeliverAll(), nats.ManualAck(), s.onMessage, nats.DeliverAll(), nats.ManualAck(),
) )
} }
// onMessage is called when the sync server receives a new event from the client API server output log. // onMessage is called when the sync server receives a new event from the client API server output log.
// It is not safe for this function to be called from multiple goroutines, or else the // It is not safe for this function to be called from multiple goroutines, or else the
// sync stream position may race and be incorrectly calculated. // sync stream position may race and be incorrectly calculated.
func (s *OutputClientDataConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { func (s *OutputClientDataConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
// Parse out the event JSON // Parse out the event JSON
userID := msg.Header.Get(jetstream.UserID) userID := msg.Header.Get(jetstream.UserID)
var output eventutil.AccountData var output eventutil.AccountData

View file

@ -75,12 +75,13 @@ func NewOutputKeyChangeEventConsumer(
// Start consuming from the key server // Start consuming from the key server
func (s *OutputKeyChangeEventConsumer) Start() error { func (s *OutputKeyChangeEventConsumer) Start() error {
return jetstream.JetStreamConsumer( return jetstream.JetStreamConsumer(
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, s.ctx, s.jetstream, s.topic, s.durable, 1,
nats.DeliverAll(), nats.ManualAck(), s.onMessage, nats.DeliverAll(), nats.ManualAck(),
) )
} }
func (s *OutputKeyChangeEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { func (s *OutputKeyChangeEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
var m api.DeviceMessage var m api.DeviceMessage
if err := json.Unmarshal(msg.Data, &m); err != nil { if err := json.Unmarshal(msg.Data, &m); err != nil {
logrus.WithError(err).Errorf("failed to read device message from key change topic") logrus.WithError(err).Errorf("failed to read device message from key change topic")

View file

@ -128,12 +128,13 @@ func (s *PresenceConsumer) Start() error {
return nil return nil
} }
return jetstream.JetStreamConsumer( return jetstream.JetStreamConsumer(
s.ctx, s.jetstream, s.presenceTopic, s.durable, s.onMessage, s.ctx, s.jetstream, s.presenceTopic, s.durable, 1, s.onMessage,
nats.DeliverAll(), nats.ManualAck(), nats.HeadersOnly(), nats.DeliverAll(), nats.ManualAck(), nats.HeadersOnly(),
) )
} }
func (s *PresenceConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { func (s *PresenceConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
userID := msg.Header.Get(jetstream.UserID) userID := msg.Header.Get(jetstream.UserID)
presence := msg.Header.Get("presence") presence := msg.Header.Get("presence")
timestamp := msg.Header.Get("last_active_ts") timestamp := msg.Header.Get("last_active_ts")

View file

@ -74,12 +74,13 @@ func NewOutputReceiptEventConsumer(
// Start consuming receipts events. // Start consuming receipts events.
func (s *OutputReceiptEventConsumer) Start() error { func (s *OutputReceiptEventConsumer) Start() error {
return jetstream.JetStreamConsumer( return jetstream.JetStreamConsumer(
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, s.ctx, s.jetstream, s.topic, s.durable, 1,
nats.DeliverAll(), nats.ManualAck(), s.onMessage, nats.DeliverAll(), nats.ManualAck(),
) )
} }
func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
output := types.OutputReceiptEvent{ output := types.OutputReceiptEvent{
UserID: msg.Header.Get(jetstream.UserID), UserID: msg.Header.Get(jetstream.UserID),
RoomID: msg.Header.Get(jetstream.RoomID), RoomID: msg.Header.Get(jetstream.RoomID),

View file

@ -80,15 +80,16 @@ func NewOutputRoomEventConsumer(
// Start consuming from room servers // Start consuming from room servers
func (s *OutputRoomEventConsumer) Start() error { func (s *OutputRoomEventConsumer) Start() error {
return jetstream.JetStreamConsumer( return jetstream.JetStreamConsumer(
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, s.ctx, s.jetstream, s.topic, s.durable, 1,
nats.DeliverAll(), nats.ManualAck(), s.onMessage, nats.DeliverAll(), nats.ManualAck(),
) )
} }
// onMessage is called when the sync server receives a new event from the room server output log. // onMessage is called when the sync server receives a new event from the room server output log.
// It is not safe for this function to be called from multiple goroutines, or else the // It is not safe for this function to be called from multiple goroutines, or else the
// sync stream position may race and be incorrectly calculated. // sync stream position may race and be incorrectly calculated.
func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
// Parse out the event JSON // Parse out the event JSON
var err error var err error
var output api.OutputEvent var output api.OutputEvent

View file

@ -68,12 +68,13 @@ func NewOutputSendToDeviceEventConsumer(
// Start consuming send-to-device events. // Start consuming send-to-device events.
func (s *OutputSendToDeviceEventConsumer) Start() error { func (s *OutputSendToDeviceEventConsumer) Start() error {
return jetstream.JetStreamConsumer( return jetstream.JetStreamConsumer(
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, s.ctx, s.jetstream, s.topic, s.durable, 1,
nats.DeliverAll(), nats.ManualAck(), s.onMessage, nats.DeliverAll(), nats.ManualAck(),
) )
} }
func (s *OutputSendToDeviceEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { func (s *OutputSendToDeviceEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
userID := msg.Header.Get(jetstream.UserID) userID := msg.Header.Get(jetstream.UserID)
_, domain, err := gomatrixserverlib.SplitID('@', userID) _, domain, err := gomatrixserverlib.SplitID('@', userID)
if err != nil { if err != nil {

View file

@ -64,12 +64,13 @@ func NewOutputTypingEventConsumer(
// Start consuming typing events. // Start consuming typing events.
func (s *OutputTypingEventConsumer) Start() error { func (s *OutputTypingEventConsumer) Start() error {
return jetstream.JetStreamConsumer( return jetstream.JetStreamConsumer(
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, s.ctx, s.jetstream, s.topic, s.durable, 1,
nats.DeliverAll(), nats.ManualAck(), s.onMessage, nats.DeliverAll(), nats.ManualAck(),
) )
} }
func (s *OutputTypingEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { func (s *OutputTypingEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
roomID := msg.Header.Get(jetstream.RoomID) roomID := msg.Header.Get(jetstream.RoomID)
userID := msg.Header.Get(jetstream.UserID) userID := msg.Header.Get(jetstream.UserID)
typing, err := strconv.ParseBool(msg.Header.Get("typing")) typing, err := strconv.ParseBool(msg.Header.Get("typing"))

View file

@ -67,8 +67,8 @@ func NewOutputNotificationDataConsumer(
// Start starts consumption. // Start starts consumption.
func (s *OutputNotificationDataConsumer) Start() error { func (s *OutputNotificationDataConsumer) Start() error {
return jetstream.JetStreamConsumer( return jetstream.JetStreamConsumer(
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, s.ctx, s.jetstream, s.topic, s.durable, 1,
nats.DeliverAll(), nats.ManualAck(), s.onMessage, nats.DeliverAll(), nats.ManualAck(),
) )
} }
@ -76,7 +76,8 @@ func (s *OutputNotificationDataConsumer) Start() error {
// the push server. It is not safe for this function to be called from // the push server. It is not safe for this function to be called from
// multiple goroutines, or else the sync stream position may race and // multiple goroutines, or else the sync stream position may race and
// be incorrectly calculated. // be incorrectly calculated.
func (s *OutputNotificationDataConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { func (s *OutputNotificationDataConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
userID := string(msg.Header.Get(jetstream.UserID)) userID := string(msg.Header.Get(jetstream.UserID))
// Parse out the event JSON // Parse out the event JSON

View file

@ -15,6 +15,7 @@
package test package test
import ( import (
"crypto/ed25519"
"crypto/rand" "crypto/rand"
"crypto/rsa" "crypto/rsa"
"crypto/x509" "crypto/x509"
@ -44,6 +45,10 @@ func NewMatrixKey(matrixKeyPath string) (err error) {
if err != nil { if err != nil {
return err return err
} }
return SaveMatrixKey(matrixKeyPath, data[3:])
}
func SaveMatrixKey(matrixKeyPath string, data ed25519.PrivateKey) error {
keyOut, err := os.OpenFile(matrixKeyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) keyOut, err := os.OpenFile(matrixKeyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil { if err != nil {
return err return err
@ -62,7 +67,7 @@ func NewMatrixKey(matrixKeyPath string) (err error) {
Headers: map[string]string{ Headers: map[string]string{
"Key-ID": fmt.Sprintf("ed25519:%s", keyID[:6]), "Key-ID": fmt.Sprintf("ed25519:%s", keyID[:6]),
}, },
Bytes: data[3:], Bytes: data,
}) })
return err return err
} }

View file

@ -57,7 +57,6 @@ func CreateBaseDendrite(t *testing.T, dbType test.DBType) (*base.BaseDendrite, f
return base.NewBaseDendrite(&cfg, "Test", base.DisableMetrics), func() { return base.NewBaseDendrite(&cfg, "Test", base.DisableMetrics), func() {
// cleanup db files. This risks getting out of sync as we add more database strings :( // cleanup db files. This risks getting out of sync as we add more database strings :(
dbFiles := []config.DataSource{ dbFiles := []config.DataSource{
cfg.AppServiceAPI.Database.ConnectionString,
cfg.FederationAPI.Database.ConnectionString, cfg.FederationAPI.Database.ConnectionString,
cfg.KeyServer.Database.ConnectionString, cfg.KeyServer.Database.ConnectionString,
cfg.MSCs.Database.ConnectionString, cfg.MSCs.Database.ConnectionString,

View file

@ -56,15 +56,16 @@ func NewOutputReadUpdateConsumer(
func (s *OutputReadUpdateConsumer) Start() error { func (s *OutputReadUpdateConsumer) Start() error {
if err := jetstream.JetStreamConsumer( if err := jetstream.JetStreamConsumer(
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, s.ctx, s.jetstream, s.topic, s.durable, 1,
nats.DeliverAll(), nats.ManualAck(), s.onMessage, nats.DeliverAll(), nats.ManualAck(),
); err != nil { ); err != nil {
return err return err
} }
return nil return nil
} }
func (s *OutputReadUpdateConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { func (s *OutputReadUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
var read types.ReadUpdate var read types.ReadUpdate
if err := json.Unmarshal(msg.Data, &read); err != nil { if err := json.Unmarshal(msg.Data, &read); err != nil {
log.WithError(err).Error("userapi clientapi consumer: message parse failure") log.WithError(err).Error("userapi clientapi consumer: message parse failure")

View file

@ -65,15 +65,16 @@ func NewOutputStreamEventConsumer(
func (s *OutputStreamEventConsumer) Start() error { func (s *OutputStreamEventConsumer) Start() error {
if err := jetstream.JetStreamConsumer( if err := jetstream.JetStreamConsumer(
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, s.ctx, s.jetstream, s.topic, s.durable, 1,
nats.DeliverAll(), nats.ManualAck(), s.onMessage, nats.DeliverAll(), nats.ManualAck(),
); err != nil { ); err != nil {
return err return err
} }
return nil return nil
} }
func (s *OutputStreamEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { func (s *OutputStreamEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
var output types.StreamedEvent var output types.StreamedEvent
output.Event = &gomatrixserverlib.HeaderedEvent{} output.Event = &gomatrixserverlib.HeaderedEvent{}
if err := json.Unmarshal(msg.Data, &output); err != nil { if err := json.Unmarshal(msg.Data, &output); err != nil {

View file

@ -28,7 +28,6 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"github.com/matrix-org/dendrite/appservice/types"
"github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/internal/pushrules" "github.com/matrix-org/dendrite/internal/pushrules"
@ -454,7 +453,7 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe
// Create a dummy device for AS user // Create a dummy device for AS user
dev := api.Device{ dev := api.Device{
// Use AS dummy device ID // Use AS dummy device ID
ID: types.AppServiceDeviceID, ID: "AS_Device",
// AS dummy device has AS's token. // AS dummy device has AS's token.
AccessToken: token, AccessToken: token,
AppserviceID: appService.ID, AppserviceID: appService.ID,