diff --git a/.github/workflows/dendrite.yml b/.github/workflows/dendrite.yml index 3044b0b9d..be3c7c173 100644 --- a/.github/workflows/dendrite.yml +++ b/.github/workflows/dendrite.yml @@ -7,6 +7,7 @@ on: pull_request: release: types: [published] + workflow_dispatch: concurrency: group: ${{ github.workflow }}-${{ github.ref }} @@ -375,6 +376,8 @@ jobs: # Build initial Dendrite image - run: docker build -t complement-dendrite -f build/scripts/Complement${{ matrix.postgres }}.Dockerfile . working-directory: dendrite + env: + DOCKER_BUILDKIT: 1 # Run Complement - run: | diff --git a/CHANGES.md b/CHANGES.md index 36a1a6311..0f57bffcb 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,5 +1,30 @@ # Changelog +## Dendrite 0.9.5 (2022-08-25) + +### Fixes + +* The roomserver will now correctly unreject previously rejected events if necessary when reprocessing +* The handling of event soft-failure has been improved on the roomserver input by no longer applying rejection rules and still calculating state before the event if possible +* The federation `/state` and `/state_ids` endpoints should now return the correct error code when the state isn't known instead of returning a HTTP 500 +* The federation `/event` should now return outlier events correctly instead of returning a HTTP 500 +* A bug in the federation backoff allowing zero intervals has been corrected +* The `create-account` utility will no longer error if the homeserver URL ends in a trailing slash +* A regression in `/sync` introduced in 0.9.4 should be fixed + +## Dendrite 0.9.4 (2022-08-19) + +### Fixes + +* A bug in the roomserver around handling rejected outliers has been fixed +* Backfilled events will now use the correct history visibility where possible +* The device list updater backoff has been fixed, which should reduce the number of outbound HTTP requests and `Failed to query device keys for some users` log entries for dead servers +* The `/sync` endpoint will no longer incorrectly return room entries for retired invites which could cause some rooms to show up in the client "Historical" section +* The `/createRoom` endpoint will now correctly populate `is_direct` in invite membership events, which may help clients to classify direct messages correctly +* The `create-account` tool will now log an error if the shared secret is not set in the Dendrite config +* A couple of minor bugs have been fixed in the membership lazy-loading +* Queued EDUs in the federation API are now cached properly + ## Dendrite 0.9.3 (2022-08-15) ### Important diff --git a/appservice/README.md b/appservice/README.md deleted file mode 100644 index d75557448..000000000 --- a/appservice/README.md +++ /dev/null @@ -1,10 +0,0 @@ -# Application Service - -This component interfaces with external [Application -Services](https://matrix.org/docs/spec/application_service/unstable.html). -This includes any HTTP endpoints that application services call, as well as talking -to any HTTP endpoints that application services provide themselves. - -## Consumers - -This component consumes and filters events from the Roomserver Kafka stream, passing on any necessary events to subscribing application services. \ No newline at end of file diff --git a/appservice/appservice.go b/appservice/appservice.go index 8fe1b2fc4..9000adb1d 100644 --- a/appservice/appservice.go +++ b/appservice/appservice.go @@ -18,7 +18,6 @@ import ( "context" "crypto/tls" "net/http" - "sync" "time" "github.com/gorilla/mux" @@ -28,9 +27,6 @@ import ( "github.com/matrix-org/dendrite/appservice/consumers" "github.com/matrix-org/dendrite/appservice/inthttp" "github.com/matrix-org/dendrite/appservice/query" - "github.com/matrix-org/dendrite/appservice/storage" - "github.com/matrix-org/dendrite/appservice/types" - "github.com/matrix-org/dendrite/appservice/workers" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" @@ -59,57 +55,40 @@ func NewInternalAPI( Proxy: http.ProxyFromEnvironment, }, } - js, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) + // Create appserivce query API with an HTTP client that will be used for all + // outbound and inbound requests (inbound only for the internal API) + appserviceQueryAPI := &query.AppServiceQueryAPI{ + HTTPClient: client, + Cfg: &base.Cfg.AppServiceAPI, + } - // Create a connection to the appservice postgres DB - appserviceDB, err := storage.NewDatabase(base, &base.Cfg.AppServiceAPI.Database) - if err != nil { - logrus.WithError(err).Panicf("failed to connect to appservice db") + if len(base.Cfg.Derived.ApplicationServices) == 0 { + return appserviceQueryAPI } // Wrap application services in a type that relates the application service and // a sync.Cond object that can be used to notify workers when there are new // events to be sent out. - workerStates := make([]types.ApplicationServiceWorkerState, len(base.Cfg.Derived.ApplicationServices)) - for i, appservice := range base.Cfg.Derived.ApplicationServices { - m := sync.Mutex{} - ws := types.ApplicationServiceWorkerState{ - AppService: appservice, - Cond: sync.NewCond(&m), - } - workerStates[i] = ws - + for _, appservice := range base.Cfg.Derived.ApplicationServices { // Create bot account for this AS if it doesn't already exist - if err = generateAppServiceAccount(userAPI, appservice); err != nil { + if err := generateAppServiceAccount(userAPI, appservice); err != nil { logrus.WithFields(logrus.Fields{ "appservice": appservice.ID, }).WithError(err).Panicf("failed to generate bot account for appservice") } } - // Create appserivce query API with an HTTP client that will be used for all - // outbound and inbound requests (inbound only for the internal API) - appserviceQueryAPI := &query.AppServiceQueryAPI{ - HTTPClient: client, - Cfg: base.Cfg, - } - // Only consume if we actually have ASes to track, else we'll just chew cycles needlessly. // We can't add ASes at runtime so this is safe to do. - if len(workerStates) > 0 { - consumer := consumers.NewOutputRoomEventConsumer( - base.ProcessContext, base.Cfg, js, appserviceDB, - rsAPI, workerStates, - ) - if err := consumer.Start(); err != nil { - logrus.WithError(err).Panicf("failed to start appservice roomserver consumer") - } + js, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) + consumer := consumers.NewOutputRoomEventConsumer( + base.ProcessContext, &base.Cfg.AppServiceAPI, + client, js, rsAPI, + ) + if err := consumer.Start(); err != nil { + logrus.WithError(err).Panicf("failed to start appservice roomserver consumer") } - // Create application service transaction workers - if err := workers.SetupTransactionWorkers(client, appserviceDB, workerStates); err != nil { - logrus.WithError(err).Panicf("failed to start app service transaction workers") - } return appserviceQueryAPI } diff --git a/appservice/consumers/roomserver.go b/appservice/consumers/roomserver.go index 37d4ef9c2..a30944e75 100644 --- a/appservice/consumers/roomserver.go +++ b/appservice/consumers/roomserver.go @@ -15,193 +15,214 @@ package consumers import ( + "bytes" "context" "encoding/json" + "fmt" + "math" + "net/http" + "net/url" + "time" + + "github.com/matrix-org/gomatrixserverlib" + "github.com/nats-io/nats.go" - "github.com/matrix-org/dendrite/appservice/storage" - "github.com/matrix-org/dendrite/appservice/types" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" - "github.com/matrix-org/gomatrixserverlib" - "github.com/nats-io/nats.go" log "github.com/sirupsen/logrus" ) // OutputRoomEventConsumer consumes events that originated in the room server. type OutputRoomEventConsumer struct { - ctx context.Context - jetstream nats.JetStreamContext - durable string - topic string - asDB storage.Database - rsAPI api.AppserviceRoomserverAPI - serverName string - workerStates []types.ApplicationServiceWorkerState + ctx context.Context + cfg *config.AppServiceAPI + client *http.Client + jetstream nats.JetStreamContext + topic string + rsAPI api.AppserviceRoomserverAPI +} + +type appserviceState struct { + *config.ApplicationService + backoff int } // NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call // Start() to begin consuming from room servers. func NewOutputRoomEventConsumer( process *process.ProcessContext, - cfg *config.Dendrite, + cfg *config.AppServiceAPI, + client *http.Client, js nats.JetStreamContext, - appserviceDB storage.Database, rsAPI api.AppserviceRoomserverAPI, - workerStates []types.ApplicationServiceWorkerState, ) *OutputRoomEventConsumer { return &OutputRoomEventConsumer{ - ctx: process.Context(), - jetstream: js, - durable: cfg.Global.JetStream.Durable("AppserviceRoomserverConsumer"), - topic: cfg.Global.JetStream.Prefixed(jetstream.OutputRoomEvent), - asDB: appserviceDB, - rsAPI: rsAPI, - serverName: string(cfg.Global.ServerName), - workerStates: workerStates, + ctx: process.Context(), + cfg: cfg, + client: client, + jetstream: js, + topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputRoomEvent), + rsAPI: rsAPI, } } // Start consuming from room servers func (s *OutputRoomEventConsumer) Start() error { - return jetstream.JetStreamConsumer( - s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, - nats.DeliverAll(), nats.ManualAck(), - ) + for _, as := range s.cfg.Derived.ApplicationServices { + appsvc := as + state := &appserviceState{ + ApplicationService: &appsvc, + } + token := jetstream.Tokenise(as.ID) + if err := jetstream.JetStreamConsumer( + s.ctx, s.jetstream, s.topic, + s.cfg.Matrix.JetStream.Durable("Appservice_"+token), + 50, // maximum number of events to send in a single transaction + func(ctx context.Context, msgs []*nats.Msg) bool { + return s.onMessage(ctx, state, msgs) + }, + nats.DeliverNew(), nats.ManualAck(), + ); err != nil { + return fmt.Errorf("failed to create %q consumer: %w", token, err) + } + } + return nil } // onMessage is called when the appservice component receives a new event from // the room server output log. -func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { - // Parse out the event JSON - var output api.OutputEvent - if err := json.Unmarshal(msg.Data, &output); err != nil { - // If the message was invalid, log it and move on to the next message in the stream - log.WithError(err).Errorf("roomserver output log: message parse failure") - return true - } - - log.WithFields(log.Fields{ - "type": output.Type, - }).Debug("Got a message in OutputRoomEventConsumer") - - events := []*gomatrixserverlib.HeaderedEvent{} - if output.Type == api.OutputTypeNewRoomEvent && output.NewRoomEvent != nil { - newEventID := output.NewRoomEvent.Event.EventID() - events = append(events, output.NewRoomEvent.Event) - if len(output.NewRoomEvent.AddsStateEventIDs) > 0 { - eventsReq := &api.QueryEventsByIDRequest{ - EventIDs: make([]string, 0, len(output.NewRoomEvent.AddsStateEventIDs)), - } - eventsRes := &api.QueryEventsByIDResponse{} - for _, eventID := range output.NewRoomEvent.AddsStateEventIDs { - if eventID != newEventID { - eventsReq.EventIDs = append(eventsReq.EventIDs, eventID) - } - } - if len(eventsReq.EventIDs) > 0 { - if err := s.rsAPI.QueryEventsByID(s.ctx, eventsReq, eventsRes); err != nil { - return false - } - events = append(events, eventsRes.Events...) - } +func (s *OutputRoomEventConsumer) onMessage( + ctx context.Context, state *appserviceState, msgs []*nats.Msg, +) bool { + log.WithField("appservice", state.ID).Tracef("Appservice worker received %d message(s) from roomserver", len(msgs)) + events := make([]*gomatrixserverlib.HeaderedEvent, 0, len(msgs)) + for _, msg := range msgs { + // Parse out the event JSON + var output api.OutputEvent + if err := json.Unmarshal(msg.Data, &output); err != nil { + // If the message was invalid, log it and move on to the next message in the stream + log.WithField("appservice", state.ID).WithError(err).Errorf("Appservice failed to parse message, ignoring") + continue } - } else if output.Type == api.OutputTypeNewInviteEvent && output.NewInviteEvent != nil { - events = append(events, output.NewInviteEvent.Event) - } else { - log.WithFields(log.Fields{ - "type": output.Type, - }).Debug("appservice OutputRoomEventConsumer ignoring event", string(msg.Data)) + switch output.Type { + case api.OutputTypeNewRoomEvent: + if output.NewRoomEvent == nil || !s.appserviceIsInterestedInEvent(ctx, output.NewRoomEvent.Event, state.ApplicationService) { + continue + } + events = append(events, output.NewRoomEvent.Event) + if len(output.NewRoomEvent.AddsStateEventIDs) > 0 { + newEventID := output.NewRoomEvent.Event.EventID() + eventsReq := &api.QueryEventsByIDRequest{ + EventIDs: make([]string, 0, len(output.NewRoomEvent.AddsStateEventIDs)), + } + eventsRes := &api.QueryEventsByIDResponse{} + for _, eventID := range output.NewRoomEvent.AddsStateEventIDs { + if eventID != newEventID { + eventsReq.EventIDs = append(eventsReq.EventIDs, eventID) + } + } + if len(eventsReq.EventIDs) > 0 { + if err := s.rsAPI.QueryEventsByID(s.ctx, eventsReq, eventsRes); err != nil { + log.WithError(err).Errorf("s.rsAPI.QueryEventsByID failed") + return false + } + events = append(events, eventsRes.Events...) + } + } + + case api.OutputTypeNewInviteEvent: + if output.NewInviteEvent == nil { + continue + } + events = append(events, output.NewInviteEvent.Event) + + default: + continue + } + } + + // If there are no events selected for sending then we should + // ack the messages so that we don't get sent them again in the + // future. + if len(events) == 0 { return true } - // Send event to any relevant application services - if err := s.filterRoomserverEvents(context.TODO(), events); err != nil { - log.WithError(err).Errorf("roomserver output log: filter error") - return true - } - - return true + // Send event to any relevant application services. If we hit + // an error here, return false, so that we negatively ack. + log.WithField("appservice", state.ID).Debugf("Appservice worker sending %d events(s) from roomserver", len(events)) + return s.sendEvents(ctx, state, events) == nil } -// filterRoomserverEvents takes in events and decides whether any of them need -// to be passed on to an external application service. It does this by checking -// each namespace of each registered application service, and if there is a -// match, adds the event to the queue for events to be sent to a particular -// application service. -func (s *OutputRoomEventConsumer) filterRoomserverEvents( - ctx context.Context, +// sendEvents passes events to the appservice by using the transactions +// endpoint. It will block for the backoff period if necessary. +func (s *OutputRoomEventConsumer) sendEvents( + ctx context.Context, state *appserviceState, events []*gomatrixserverlib.HeaderedEvent, ) error { - for _, ws := range s.workerStates { - for _, event := range events { - // Check if this event is interesting to this application service - if s.appserviceIsInterestedInEvent(ctx, event, ws.AppService) { - // Queue this event to be sent off to the application service - if err := s.asDB.StoreEvent(ctx, ws.AppService.ID, event); err != nil { - log.WithError(err).Warn("failed to insert incoming event into appservices database") - return err - } else { - // Tell our worker to send out new messages by updating remaining message - // count and waking them up with a broadcast - ws.NotifyNewEvents() - } - } - } + // Create the transaction body. + transaction, err := json.Marshal( + gomatrixserverlib.ApplicationServiceTransaction{ + Events: gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatAll), + }, + ) + if err != nil { + return err } + // TODO: We should probably be more intelligent and pick something not + // in the control of the event. A NATS timestamp header or something maybe. + txnID := events[0].Event.OriginServerTS() + + // Send the transaction to the appservice. + // https://matrix.org/docs/spec/application_service/r0.1.2#put-matrix-app-v1-transactions-txnid + address := fmt.Sprintf("%s/transactions/%d?access_token=%s", state.URL, txnID, url.QueryEscape(state.HSToken)) + req, err := http.NewRequestWithContext(ctx, "PUT", address, bytes.NewBuffer(transaction)) + if err != nil { + return err + } + req.Header.Set("Content-Type", "application/json") + resp, err := s.client.Do(req) + if err != nil { + return state.backoffAndPause(err) + } + + // If the response was fine then we can clear any backoffs in place and + // report that everything was OK. Otherwise, back off for a while. + switch resp.StatusCode { + case http.StatusOK: + state.backoff = 0 + default: + return state.backoffAndPause(fmt.Errorf("received HTTP status code %d from appservice", resp.StatusCode)) + } return nil } -// appserviceJoinedAtEvent returns a boolean depending on whether a given -// appservice has membership at the time a given event was created. -func (s *OutputRoomEventConsumer) appserviceJoinedAtEvent(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, appservice config.ApplicationService) bool { - // TODO: This is only checking the current room state, not the state at - // the event in question. Pretty sure this is what Synapse does too, but - // until we have a lighter way of checking the state before the event that - // doesn't involve state res, then this is probably OK. - membershipReq := &api.QueryMembershipsForRoomRequest{ - RoomID: event.RoomID(), - JoinedOnly: true, +// backoff pauses the calling goroutine for a 2^some backoff exponent seconds +func (s *appserviceState) backoffAndPause(err error) error { + if s.backoff < 6 { + s.backoff++ } - membershipRes := &api.QueryMembershipsForRoomResponse{} - - // XXX: This could potentially race if the state for the event is not known yet - // e.g. the event came over federation but we do not have the full state persisted. - if err := s.rsAPI.QueryMembershipsForRoom(ctx, membershipReq, membershipRes); err == nil { - for _, ev := range membershipRes.JoinEvents { - var membership gomatrixserverlib.MemberContent - if err = json.Unmarshal(ev.Content, &membership); err != nil || ev.StateKey == nil { - continue - } - if appservice.IsInterestedInUserID(*ev.StateKey) { - return true - } - } - } else { - log.WithFields(log.Fields{ - "room_id": event.RoomID(), - }).WithError(err).Errorf("Unable to get membership for room") - } - return false + duration := time.Second * time.Duration(math.Pow(2, float64(s.backoff))) + log.WithField("appservice", s.ID).WithError(err).Errorf("Unable to send transaction to appservice, backing off for %s", duration.String()) + time.Sleep(duration) + return err } // appserviceIsInterestedInEvent returns a boolean depending on whether a given // event falls within one of a given application service's namespaces. // // TODO: This should be cached, see https://github.com/matrix-org/dendrite/issues/1682 -func (s *OutputRoomEventConsumer) appserviceIsInterestedInEvent(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, appservice config.ApplicationService) bool { - // No reason to queue events if they'll never be sent to the application - // service - if appservice.URL == "" { +func (s *OutputRoomEventConsumer) appserviceIsInterestedInEvent(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, appservice *config.ApplicationService) bool { + switch { + case appservice.URL == "": return false - } - - // Check Room ID and Sender of the event - if appservice.IsInterestedInUserID(event.Sender()) || - appservice.IsInterestedInRoomID(event.RoomID()) { + case appservice.IsInterestedInUserID(event.Sender()): + return true + case appservice.IsInterestedInRoomID(event.RoomID()): return true } @@ -222,10 +243,52 @@ func (s *OutputRoomEventConsumer) appserviceIsInterestedInEvent(ctx context.Cont } } else { log.WithFields(log.Fields{ - "room_id": event.RoomID(), + "appservice": appservice.ID, + "room_id": event.RoomID(), }).WithError(err).Errorf("Unable to get aliases for room") } // Check if any of the members in the room match the appservice return s.appserviceJoinedAtEvent(ctx, event, appservice) } + +// appserviceJoinedAtEvent returns a boolean depending on whether a given +// appservice has membership at the time a given event was created. +func (s *OutputRoomEventConsumer) appserviceJoinedAtEvent(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, appservice *config.ApplicationService) bool { + // TODO: This is only checking the current room state, not the state at + // the event in question. Pretty sure this is what Synapse does too, but + // until we have a lighter way of checking the state before the event that + // doesn't involve state res, then this is probably OK. + membershipReq := &api.QueryMembershipsForRoomRequest{ + RoomID: event.RoomID(), + JoinedOnly: true, + } + membershipRes := &api.QueryMembershipsForRoomResponse{} + + // XXX: This could potentially race if the state for the event is not known yet + // e.g. the event came over federation but we do not have the full state persisted. + if err := s.rsAPI.QueryMembershipsForRoom(ctx, membershipReq, membershipRes); err == nil { + for _, ev := range membershipRes.JoinEvents { + switch { + case ev.StateKey == nil: + continue + case ev.Type != gomatrixserverlib.MRoomMember: + continue + } + var membership gomatrixserverlib.MemberContent + err = json.Unmarshal(ev.Content, &membership) + switch { + case err != nil: + continue + case membership.Membership == gomatrixserverlib.Join: + return true + } + } + } else { + log.WithFields(log.Fields{ + "appservice": appservice.ID, + "room_id": event.RoomID(), + }).WithError(err).Errorf("Unable to get membership for room") + } + return false +} diff --git a/appservice/query/query.go b/appservice/query/query.go index dacd3caa8..53b34cb18 100644 --- a/appservice/query/query.go +++ b/appservice/query/query.go @@ -33,7 +33,7 @@ const userIDExistsPath = "/users/" // AppServiceQueryAPI is an implementation of api.AppServiceQueryAPI type AppServiceQueryAPI struct { HTTPClient *http.Client - Cfg *config.Dendrite + Cfg *config.AppServiceAPI } // RoomAliasExists performs a request to '/room/{roomAlias}' on all known diff --git a/appservice/storage/interface.go b/appservice/storage/interface.go deleted file mode 100644 index 25d35af6c..000000000 --- a/appservice/storage/interface.go +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package storage - -import ( - "context" - - "github.com/matrix-org/gomatrixserverlib" -) - -type Database interface { - StoreEvent(ctx context.Context, appServiceID string, event *gomatrixserverlib.HeaderedEvent) error - GetEventsWithAppServiceID(ctx context.Context, appServiceID string, limit int) (int, int, []gomatrixserverlib.HeaderedEvent, bool, error) - CountEventsWithAppServiceID(ctx context.Context, appServiceID string) (int, error) - UpdateTxnIDForEvents(ctx context.Context, appserviceID string, maxID, txnID int) error - RemoveEventsBeforeAndIncludingID(ctx context.Context, appserviceID string, eventTableID int) error - GetLatestTxnID(ctx context.Context) (int, error) -} diff --git a/appservice/storage/postgres/appservice_events_table.go b/appservice/storage/postgres/appservice_events_table.go deleted file mode 100644 index a95be6b8a..000000000 --- a/appservice/storage/postgres/appservice_events_table.go +++ /dev/null @@ -1,256 +0,0 @@ -// Copyright 2018 New Vector Ltd -// Copyright 2019-2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package postgres - -import ( - "context" - "database/sql" - "encoding/json" - "time" - - "github.com/matrix-org/gomatrixserverlib" - log "github.com/sirupsen/logrus" -) - -const appserviceEventsSchema = ` --- Stores events to be sent to application services -CREATE TABLE IF NOT EXISTS appservice_events ( - -- An auto-incrementing id unique to each event in the table - id BIGSERIAL NOT NULL PRIMARY KEY, - -- The ID of the application service the event will be sent to - as_id TEXT NOT NULL, - -- JSON representation of the event - headered_event_json TEXT NOT NULL, - -- The ID of the transaction that this event is a part of - txn_id BIGINT NOT NULL -); - -CREATE INDEX IF NOT EXISTS appservice_events_as_id ON appservice_events(as_id); -` - -const selectEventsByApplicationServiceIDSQL = "" + - "SELECT id, headered_event_json, txn_id " + - "FROM appservice_events WHERE as_id = $1 ORDER BY txn_id DESC, id ASC" - -const countEventsByApplicationServiceIDSQL = "" + - "SELECT COUNT(id) FROM appservice_events WHERE as_id = $1" - -const insertEventSQL = "" + - "INSERT INTO appservice_events(as_id, headered_event_json, txn_id) " + - "VALUES ($1, $2, $3)" - -const updateTxnIDForEventsSQL = "" + - "UPDATE appservice_events SET txn_id = $1 WHERE as_id = $2 AND id <= $3" - -const deleteEventsBeforeAndIncludingIDSQL = "" + - "DELETE FROM appservice_events WHERE as_id = $1 AND id <= $2" - -const ( - // A transaction ID number that no transaction should ever have. Used for - // checking again the default value. - invalidTxnID = -2 -) - -type eventsStatements struct { - selectEventsByApplicationServiceIDStmt *sql.Stmt - countEventsByApplicationServiceIDStmt *sql.Stmt - insertEventStmt *sql.Stmt - updateTxnIDForEventsStmt *sql.Stmt - deleteEventsBeforeAndIncludingIDStmt *sql.Stmt -} - -func (s *eventsStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(appserviceEventsSchema) - if err != nil { - return - } - - if s.selectEventsByApplicationServiceIDStmt, err = db.Prepare(selectEventsByApplicationServiceIDSQL); err != nil { - return - } - if s.countEventsByApplicationServiceIDStmt, err = db.Prepare(countEventsByApplicationServiceIDSQL); err != nil { - return - } - if s.insertEventStmt, err = db.Prepare(insertEventSQL); err != nil { - return - } - if s.updateTxnIDForEventsStmt, err = db.Prepare(updateTxnIDForEventsSQL); err != nil { - return - } - if s.deleteEventsBeforeAndIncludingIDStmt, err = db.Prepare(deleteEventsBeforeAndIncludingIDSQL); err != nil { - return - } - - return -} - -// selectEventsByApplicationServiceID takes in an application service ID and -// returns a slice of events that need to be sent to that application service, -// as well as an int later used to remove these same events from the database -// once successfully sent to an application service. -func (s *eventsStatements) selectEventsByApplicationServiceID( - ctx context.Context, - applicationServiceID string, - limit int, -) ( - txnID, maxID int, - events []gomatrixserverlib.HeaderedEvent, - eventsRemaining bool, - err error, -) { - defer func() { - if err != nil { - log.WithFields(log.Fields{ - "appservice": applicationServiceID, - }).WithError(err).Fatalf("appservice unable to select new events to send") - } - }() - // Retrieve events from the database. Unsuccessfully sent events first - eventRows, err := s.selectEventsByApplicationServiceIDStmt.QueryContext(ctx, applicationServiceID) - if err != nil { - return - } - defer checkNamedErr(eventRows.Close, &err) - events, maxID, txnID, eventsRemaining, err = retrieveEvents(eventRows, limit) - if err != nil { - return - } - - return -} - -// checkNamedErr calls fn and overwrite err if it was nil and fn returned non-nil -func checkNamedErr(fn func() error, err *error) { - if e := fn(); e != nil && *err == nil { - *err = e - } -} - -func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.HeaderedEvent, maxID, txnID int, eventsRemaining bool, err error) { - // Get current time for use in calculating event age - nowMilli := time.Now().UnixNano() / int64(time.Millisecond) - - // Iterate through each row and store event contents - // If txn_id changes dramatically, we've switched from collecting old events to - // new ones. Send back those events first. - lastTxnID := invalidTxnID - for eventsProcessed := 0; eventRows.Next(); { - var event gomatrixserverlib.HeaderedEvent - var eventJSON []byte - var id int - err = eventRows.Scan( - &id, - &eventJSON, - &txnID, - ) - if err != nil { - return nil, 0, 0, false, err - } - - // Unmarshal eventJSON - if err = json.Unmarshal(eventJSON, &event); err != nil { - return nil, 0, 0, false, err - } - - // If txnID has changed on this event from the previous event, then we've - // reached the end of a transaction's events. Return only those events. - if lastTxnID > invalidTxnID && lastTxnID != txnID { - return events, maxID, lastTxnID, true, nil - } - lastTxnID = txnID - - // Limit events that aren't part of an old transaction - if txnID == -1 { - // Return if we've hit the limit - if eventsProcessed++; eventsProcessed > limit { - return events, maxID, lastTxnID, true, nil - } - } - - if id > maxID { - maxID = id - } - - // Portion of the event that is unsigned due to rapid change - // TODO: Consider removing age as not many app services use it - if err = event.SetUnsignedField("age", nowMilli-int64(event.OriginServerTS())); err != nil { - return nil, 0, 0, false, err - } - - events = append(events, event) - } - - return -} - -// countEventsByApplicationServiceID inserts an event mapped to its corresponding application service -// IDs into the db. -func (s *eventsStatements) countEventsByApplicationServiceID( - ctx context.Context, - appServiceID string, -) (int, error) { - var count int - err := s.countEventsByApplicationServiceIDStmt.QueryRowContext(ctx, appServiceID).Scan(&count) - if err != nil && err != sql.ErrNoRows { - return 0, err - } - - return count, nil -} - -// insertEvent inserts an event mapped to its corresponding application service -// IDs into the db. -func (s *eventsStatements) insertEvent( - ctx context.Context, - appServiceID string, - event *gomatrixserverlib.HeaderedEvent, -) (err error) { - // Convert event to JSON before inserting - eventJSON, err := json.Marshal(event) - if err != nil { - return err - } - - _, err = s.insertEventStmt.ExecContext( - ctx, - appServiceID, - eventJSON, - -1, // No transaction ID yet - ) - return -} - -// updateTxnIDForEvents sets the transactionID for a collection of events. Done -// before sending them to an AppService. Referenced before sending to make sure -// we aren't constructing multiple transactions with the same events. -func (s *eventsStatements) updateTxnIDForEvents( - ctx context.Context, - appserviceID string, - maxID, txnID int, -) (err error) { - _, err = s.updateTxnIDForEventsStmt.ExecContext(ctx, txnID, appserviceID, maxID) - return -} - -// deleteEventsBeforeAndIncludingID removes events matching given IDs from the database. -func (s *eventsStatements) deleteEventsBeforeAndIncludingID( - ctx context.Context, - appserviceID string, - eventTableID int, -) (err error) { - _, err = s.deleteEventsBeforeAndIncludingIDStmt.ExecContext(ctx, appserviceID, eventTableID) - return -} diff --git a/appservice/storage/postgres/storage.go b/appservice/storage/postgres/storage.go deleted file mode 100644 index a4c04b2cc..000000000 --- a/appservice/storage/postgres/storage.go +++ /dev/null @@ -1,115 +0,0 @@ -// Copyright 2018 New Vector Ltd -// Copyright 2019-2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package postgres - -import ( - "context" - "database/sql" - - // Import postgres database driver - _ "github.com/lib/pq" - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/gomatrixserverlib" -) - -// Database stores events intended to be later sent to application services -type Database struct { - events eventsStatements - txnID txnStatements - db *sql.DB - writer sqlutil.Writer -} - -// NewDatabase opens a new database -func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*Database, error) { - var result Database - var err error - if result.db, result.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()); err != nil { - return nil, err - } - if err = result.prepare(); err != nil { - return nil, err - } - return &result, nil -} - -func (d *Database) prepare() error { - if err := d.events.prepare(d.db); err != nil { - return err - } - - return d.txnID.prepare(d.db) -} - -// StoreEvent takes in a gomatrixserverlib.HeaderedEvent and stores it in the database -// for a transaction worker to pull and later send to an application service. -func (d *Database) StoreEvent( - ctx context.Context, - appServiceID string, - event *gomatrixserverlib.HeaderedEvent, -) error { - return d.events.insertEvent(ctx, appServiceID, event) -} - -// GetEventsWithAppServiceID returns a slice of events and their IDs intended to -// be sent to an application service given its ID. -func (d *Database) GetEventsWithAppServiceID( - ctx context.Context, - appServiceID string, - limit int, -) (int, int, []gomatrixserverlib.HeaderedEvent, bool, error) { - return d.events.selectEventsByApplicationServiceID(ctx, appServiceID, limit) -} - -// CountEventsWithAppServiceID returns the number of events destined for an -// application service given its ID. -func (d *Database) CountEventsWithAppServiceID( - ctx context.Context, - appServiceID string, -) (int, error) { - return d.events.countEventsByApplicationServiceID(ctx, appServiceID) -} - -// UpdateTxnIDForEvents takes in an application service ID and a -// and stores them in the DB, unless the pair already exists, in -// which case it updates them. -func (d *Database) UpdateTxnIDForEvents( - ctx context.Context, - appserviceID string, - maxID, txnID int, -) error { - return d.events.updateTxnIDForEvents(ctx, appserviceID, maxID, txnID) -} - -// RemoveEventsBeforeAndIncludingID removes all events from the database that -// are less than or equal to a given maximum ID. IDs here are implemented as a -// serial, thus this should always delete events in chronological order. -func (d *Database) RemoveEventsBeforeAndIncludingID( - ctx context.Context, - appserviceID string, - eventTableID int, -) error { - return d.events.deleteEventsBeforeAndIncludingID(ctx, appserviceID, eventTableID) -} - -// GetLatestTxnID returns the latest available transaction id -func (d *Database) GetLatestTxnID( - ctx context.Context, -) (int, error) { - return d.txnID.selectTxnID(ctx) -} diff --git a/appservice/storage/postgres/txn_id_counter_table.go b/appservice/storage/postgres/txn_id_counter_table.go deleted file mode 100644 index a96a0e360..000000000 --- a/appservice/storage/postgres/txn_id_counter_table.go +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright 2018 New Vector Ltd -// Copyright 2019-2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package postgres - -import ( - "context" - "database/sql" -) - -const txnIDSchema = ` --- Keeps a count of the current transaction ID -CREATE SEQUENCE IF NOT EXISTS txn_id_counter START 1; -` - -const selectTxnIDSQL = "SELECT nextval('txn_id_counter')" - -type txnStatements struct { - selectTxnIDStmt *sql.Stmt -} - -func (s *txnStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(txnIDSchema) - if err != nil { - return - } - - if s.selectTxnIDStmt, err = db.Prepare(selectTxnIDSQL); err != nil { - return - } - - return -} - -// selectTxnID selects the latest ascending transaction ID -func (s *txnStatements) selectTxnID( - ctx context.Context, -) (txnID int, err error) { - err = s.selectTxnIDStmt.QueryRowContext(ctx).Scan(&txnID) - return -} diff --git a/appservice/storage/sqlite3/appservice_events_table.go b/appservice/storage/sqlite3/appservice_events_table.go deleted file mode 100644 index 34b4859ea..000000000 --- a/appservice/storage/sqlite3/appservice_events_table.go +++ /dev/null @@ -1,267 +0,0 @@ -// Copyright 2018 New Vector Ltd -// Copyright 2019-2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sqlite3 - -import ( - "context" - "database/sql" - "encoding/json" - "time" - - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/gomatrixserverlib" - log "github.com/sirupsen/logrus" -) - -const appserviceEventsSchema = ` --- Stores events to be sent to application services -CREATE TABLE IF NOT EXISTS appservice_events ( - -- An auto-incrementing id unique to each event in the table - id INTEGER PRIMARY KEY AUTOINCREMENT, - -- The ID of the application service the event will be sent to - as_id TEXT NOT NULL, - -- JSON representation of the event - headered_event_json TEXT NOT NULL, - -- The ID of the transaction that this event is a part of - txn_id INTEGER NOT NULL -); - -CREATE INDEX IF NOT EXISTS appservice_events_as_id ON appservice_events(as_id); -` - -const selectEventsByApplicationServiceIDSQL = "" + - "SELECT id, headered_event_json, txn_id " + - "FROM appservice_events WHERE as_id = $1 ORDER BY txn_id DESC, id ASC" - -const countEventsByApplicationServiceIDSQL = "" + - "SELECT COUNT(id) FROM appservice_events WHERE as_id = $1" - -const insertEventSQL = "" + - "INSERT INTO appservice_events(as_id, headered_event_json, txn_id) " + - "VALUES ($1, $2, $3)" - -const updateTxnIDForEventsSQL = "" + - "UPDATE appservice_events SET txn_id = $1 WHERE as_id = $2 AND id <= $3" - -const deleteEventsBeforeAndIncludingIDSQL = "" + - "DELETE FROM appservice_events WHERE as_id = $1 AND id <= $2" - -const ( - // A transaction ID number that no transaction should ever have. Used for - // checking again the default value. - invalidTxnID = -2 -) - -type eventsStatements struct { - db *sql.DB - writer sqlutil.Writer - selectEventsByApplicationServiceIDStmt *sql.Stmt - countEventsByApplicationServiceIDStmt *sql.Stmt - insertEventStmt *sql.Stmt - updateTxnIDForEventsStmt *sql.Stmt - deleteEventsBeforeAndIncludingIDStmt *sql.Stmt -} - -func (s *eventsStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) { - s.db = db - s.writer = writer - _, err = db.Exec(appserviceEventsSchema) - if err != nil { - return - } - - if s.selectEventsByApplicationServiceIDStmt, err = db.Prepare(selectEventsByApplicationServiceIDSQL); err != nil { - return - } - if s.countEventsByApplicationServiceIDStmt, err = db.Prepare(countEventsByApplicationServiceIDSQL); err != nil { - return - } - if s.insertEventStmt, err = db.Prepare(insertEventSQL); err != nil { - return - } - if s.updateTxnIDForEventsStmt, err = db.Prepare(updateTxnIDForEventsSQL); err != nil { - return - } - if s.deleteEventsBeforeAndIncludingIDStmt, err = db.Prepare(deleteEventsBeforeAndIncludingIDSQL); err != nil { - return - } - - return -} - -// selectEventsByApplicationServiceID takes in an application service ID and -// returns a slice of events that need to be sent to that application service, -// as well as an int later used to remove these same events from the database -// once successfully sent to an application service. -func (s *eventsStatements) selectEventsByApplicationServiceID( - ctx context.Context, - applicationServiceID string, - limit int, -) ( - txnID, maxID int, - events []gomatrixserverlib.HeaderedEvent, - eventsRemaining bool, - err error, -) { - defer func() { - if err != nil { - log.WithFields(log.Fields{ - "appservice": applicationServiceID, - }).WithError(err).Fatalf("appservice unable to select new events to send") - } - }() - // Retrieve events from the database. Unsuccessfully sent events first - eventRows, err := s.selectEventsByApplicationServiceIDStmt.QueryContext(ctx, applicationServiceID) - if err != nil { - return - } - defer checkNamedErr(eventRows.Close, &err) - events, maxID, txnID, eventsRemaining, err = retrieveEvents(eventRows, limit) - if err != nil { - return - } - - return -} - -// checkNamedErr calls fn and overwrite err if it was nil and fn returned non-nil -func checkNamedErr(fn func() error, err *error) { - if e := fn(); e != nil && *err == nil { - *err = e - } -} - -func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.HeaderedEvent, maxID, txnID int, eventsRemaining bool, err error) { - // Get current time for use in calculating event age - nowMilli := time.Now().UnixNano() / int64(time.Millisecond) - - // Iterate through each row and store event contents - // If txn_id changes dramatically, we've switched from collecting old events to - // new ones. Send back those events first. - lastTxnID := invalidTxnID - for eventsProcessed := 0; eventRows.Next(); { - var event gomatrixserverlib.HeaderedEvent - var eventJSON []byte - var id int - err = eventRows.Scan( - &id, - &eventJSON, - &txnID, - ) - if err != nil { - return nil, 0, 0, false, err - } - - // Unmarshal eventJSON - if err = json.Unmarshal(eventJSON, &event); err != nil { - return nil, 0, 0, false, err - } - - // If txnID has changed on this event from the previous event, then we've - // reached the end of a transaction's events. Return only those events. - if lastTxnID > invalidTxnID && lastTxnID != txnID { - return events, maxID, lastTxnID, true, nil - } - lastTxnID = txnID - - // Limit events that aren't part of an old transaction - if txnID == -1 { - // Return if we've hit the limit - if eventsProcessed++; eventsProcessed > limit { - return events, maxID, lastTxnID, true, nil - } - } - - if id > maxID { - maxID = id - } - - // Portion of the event that is unsigned due to rapid change - // TODO: Consider removing age as not many app services use it - if err = event.SetUnsignedField("age", nowMilli-int64(event.OriginServerTS())); err != nil { - return nil, 0, 0, false, err - } - - events = append(events, event) - } - - return -} - -// countEventsByApplicationServiceID inserts an event mapped to its corresponding application service -// IDs into the db. -func (s *eventsStatements) countEventsByApplicationServiceID( - ctx context.Context, - appServiceID string, -) (int, error) { - var count int - err := s.countEventsByApplicationServiceIDStmt.QueryRowContext(ctx, appServiceID).Scan(&count) - if err != nil && err != sql.ErrNoRows { - return 0, err - } - - return count, nil -} - -// insertEvent inserts an event mapped to its corresponding application service -// IDs into the db. -func (s *eventsStatements) insertEvent( - ctx context.Context, - appServiceID string, - event *gomatrixserverlib.HeaderedEvent, -) (err error) { - // Convert event to JSON before inserting - eventJSON, err := json.Marshal(event) - if err != nil { - return err - } - - return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { - _, err := s.insertEventStmt.ExecContext( - ctx, - appServiceID, - eventJSON, - -1, // No transaction ID yet - ) - return err - }) -} - -// updateTxnIDForEvents sets the transactionID for a collection of events. Done -// before sending them to an AppService. Referenced before sending to make sure -// we aren't constructing multiple transactions with the same events. -func (s *eventsStatements) updateTxnIDForEvents( - ctx context.Context, - appserviceID string, - maxID, txnID int, -) (err error) { - return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { - _, err := s.updateTxnIDForEventsStmt.ExecContext(ctx, txnID, appserviceID, maxID) - return err - }) -} - -// deleteEventsBeforeAndIncludingID removes events matching given IDs from the database. -func (s *eventsStatements) deleteEventsBeforeAndIncludingID( - ctx context.Context, - appserviceID string, - eventTableID int, -) (err error) { - return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { - _, err := s.deleteEventsBeforeAndIncludingIDStmt.ExecContext(ctx, appserviceID, eventTableID) - return err - }) -} diff --git a/appservice/storage/sqlite3/storage.go b/appservice/storage/sqlite3/storage.go deleted file mode 100644 index ad62b3628..000000000 --- a/appservice/storage/sqlite3/storage.go +++ /dev/null @@ -1,114 +0,0 @@ -// Copyright 2018 New Vector Ltd -// Copyright 2019-2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sqlite3 - -import ( - "context" - "database/sql" - - // Import SQLite database driver - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/gomatrixserverlib" -) - -// Database stores events intended to be later sent to application services -type Database struct { - events eventsStatements - txnID txnStatements - db *sql.DB - writer sqlutil.Writer -} - -// NewDatabase opens a new database -func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*Database, error) { - var result Database - var err error - if result.db, result.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()); err != nil { - return nil, err - } - if err = result.prepare(); err != nil { - return nil, err - } - return &result, nil -} - -func (d *Database) prepare() error { - if err := d.events.prepare(d.db, d.writer); err != nil { - return err - } - - return d.txnID.prepare(d.db, d.writer) -} - -// StoreEvent takes in a gomatrixserverlib.HeaderedEvent and stores it in the database -// for a transaction worker to pull and later send to an application service. -func (d *Database) StoreEvent( - ctx context.Context, - appServiceID string, - event *gomatrixserverlib.HeaderedEvent, -) error { - return d.events.insertEvent(ctx, appServiceID, event) -} - -// GetEventsWithAppServiceID returns a slice of events and their IDs intended to -// be sent to an application service given its ID. -func (d *Database) GetEventsWithAppServiceID( - ctx context.Context, - appServiceID string, - limit int, -) (int, int, []gomatrixserverlib.HeaderedEvent, bool, error) { - return d.events.selectEventsByApplicationServiceID(ctx, appServiceID, limit) -} - -// CountEventsWithAppServiceID returns the number of events destined for an -// application service given its ID. -func (d *Database) CountEventsWithAppServiceID( - ctx context.Context, - appServiceID string, -) (int, error) { - return d.events.countEventsByApplicationServiceID(ctx, appServiceID) -} - -// UpdateTxnIDForEvents takes in an application service ID and a -// and stores them in the DB, unless the pair already exists, in -// which case it updates them. -func (d *Database) UpdateTxnIDForEvents( - ctx context.Context, - appserviceID string, - maxID, txnID int, -) error { - return d.events.updateTxnIDForEvents(ctx, appserviceID, maxID, txnID) -} - -// RemoveEventsBeforeAndIncludingID removes all events from the database that -// are less than or equal to a given maximum ID. IDs here are implemented as a -// serial, thus this should always delete events in chronological order. -func (d *Database) RemoveEventsBeforeAndIncludingID( - ctx context.Context, - appserviceID string, - eventTableID int, -) error { - return d.events.deleteEventsBeforeAndIncludingID(ctx, appserviceID, eventTableID) -} - -// GetLatestTxnID returns the latest available transaction id -func (d *Database) GetLatestTxnID( - ctx context.Context, -) (int, error) { - return d.txnID.selectTxnID(ctx) -} diff --git a/appservice/storage/sqlite3/txn_id_counter_table.go b/appservice/storage/sqlite3/txn_id_counter_table.go deleted file mode 100644 index f2e902f98..000000000 --- a/appservice/storage/sqlite3/txn_id_counter_table.go +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright 2018 New Vector Ltd -// Copyright 2019-2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sqlite3 - -import ( - "context" - "database/sql" - - "github.com/matrix-org/dendrite/internal/sqlutil" -) - -const txnIDSchema = ` --- Keeps a count of the current transaction ID -CREATE TABLE IF NOT EXISTS appservice_counters ( - name TEXT PRIMARY KEY NOT NULL, - last_id INTEGER DEFAULT 1 -); -INSERT OR IGNORE INTO appservice_counters (name, last_id) VALUES('txn_id', 1); -` - -const selectTxnIDSQL = ` - SELECT last_id FROM appservice_counters WHERE name='txn_id' -` - -const updateTxnIDSQL = ` - UPDATE appservice_counters SET last_id=last_id+1 WHERE name='txn_id' -` - -type txnStatements struct { - db *sql.DB - writer sqlutil.Writer - selectTxnIDStmt *sql.Stmt - updateTxnIDStmt *sql.Stmt -} - -func (s *txnStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) { - s.db = db - s.writer = writer - _, err = db.Exec(txnIDSchema) - if err != nil { - return - } - - if s.selectTxnIDStmt, err = db.Prepare(selectTxnIDSQL); err != nil { - return - } - - if s.updateTxnIDStmt, err = db.Prepare(updateTxnIDSQL); err != nil { - return - } - - return -} - -// selectTxnID selects the latest ascending transaction ID -func (s *txnStatements) selectTxnID( - ctx context.Context, -) (txnID int, err error) { - err = s.writer.Do(s.db, nil, func(txn *sql.Tx) error { - err := s.selectTxnIDStmt.QueryRowContext(ctx).Scan(&txnID) - if err != nil { - return err - } - - _, err = s.updateTxnIDStmt.ExecContext(ctx) - return err - }) - return -} diff --git a/appservice/storage/storage.go b/appservice/storage/storage.go deleted file mode 100644 index 89d5e0cc2..000000000 --- a/appservice/storage/storage.go +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//go:build !wasm -// +build !wasm - -package storage - -import ( - "fmt" - - "github.com/matrix-org/dendrite/appservice/storage/postgres" - "github.com/matrix-org/dendrite/appservice/storage/sqlite3" - "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" -) - -// NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme) -// and sets DB connection parameters -func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (Database, error) { - switch { - case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(base, dbProperties) - case dbProperties.ConnectionString.IsPostgres(): - return postgres.NewDatabase(base, dbProperties) - default: - return nil, fmt.Errorf("unexpected database type") - } -} diff --git a/appservice/storage/storage_wasm.go b/appservice/storage/storage_wasm.go deleted file mode 100644 index 230254598..000000000 --- a/appservice/storage/storage_wasm.go +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package storage - -import ( - "fmt" - - "github.com/matrix-org/dendrite/appservice/storage/sqlite3" - "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" -) - -func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (Database, error) { - switch { - case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(base, dbProperties) - case dbProperties.ConnectionString.IsPostgres(): - return nil, fmt.Errorf("can't use Postgres implementation") - default: - return nil, fmt.Errorf("unexpected database type") - } -} diff --git a/appservice/types/types.go b/appservice/types/types.go deleted file mode 100644 index 098face62..000000000 --- a/appservice/types/types.go +++ /dev/null @@ -1,64 +0,0 @@ -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package types - -import ( - "sync" - - "github.com/matrix-org/dendrite/setup/config" -) - -const ( - // AppServiceDeviceID is the AS dummy device ID - AppServiceDeviceID = "AS_Device" -) - -// ApplicationServiceWorkerState is a type that couples an application service, -// a lockable condition as well as some other state variables, allowing the -// roomserver to notify appservice workers when there are events ready to send -// externally to application services. -type ApplicationServiceWorkerState struct { - AppService config.ApplicationService - Cond *sync.Cond - // Events ready to be sent - EventsReady bool - // Backoff exponent (2^x secs). Max 6, aka 64s. - Backoff int -} - -// NotifyNewEvents wakes up all waiting goroutines, notifying that events remain -// in the event queue for this application service worker. -func (a *ApplicationServiceWorkerState) NotifyNewEvents() { - a.Cond.L.Lock() - a.EventsReady = true - a.Cond.Broadcast() - a.Cond.L.Unlock() -} - -// FinishEventProcessing marks all events of this worker as being sent to the -// application service. -func (a *ApplicationServiceWorkerState) FinishEventProcessing() { - a.Cond.L.Lock() - a.EventsReady = false - a.Cond.L.Unlock() -} - -// WaitForNewEvents causes the calling goroutine to wait on the worker state's -// condition for a broadcast or similar wakeup, if there are no events ready. -func (a *ApplicationServiceWorkerState) WaitForNewEvents() { - a.Cond.L.Lock() - if !a.EventsReady { - a.Cond.Wait() - } - a.Cond.L.Unlock() -} diff --git a/appservice/workers/transaction_scheduler.go b/appservice/workers/transaction_scheduler.go deleted file mode 100644 index 4dab00bd7..000000000 --- a/appservice/workers/transaction_scheduler.go +++ /dev/null @@ -1,236 +0,0 @@ -// Copyright 2018 Vector Creations Ltd -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package workers - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "math" - "net/http" - "net/url" - "time" - - "github.com/matrix-org/dendrite/appservice/storage" - "github.com/matrix-org/dendrite/appservice/types" - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/gomatrixserverlib" - log "github.com/sirupsen/logrus" -) - -var ( - // Maximum size of events sent in each transaction. - transactionBatchSize = 50 -) - -// SetupTransactionWorkers spawns a separate goroutine for each application -// service. Each of these "workers" handle taking all events intended for their -// app service, batch them up into a single transaction (up to a max transaction -// size), then send that off to the AS's /transactions/{txnID} endpoint. It also -// handles exponentially backing off in case the AS isn't currently available. -func SetupTransactionWorkers( - client *http.Client, - appserviceDB storage.Database, - workerStates []types.ApplicationServiceWorkerState, -) error { - // Create a worker that handles transmitting events to a single homeserver - for _, workerState := range workerStates { - // Don't create a worker if this AS doesn't want to receive events - if workerState.AppService.URL != "" { - go worker(client, appserviceDB, workerState) - } - } - return nil -} - -// worker is a goroutine that sends any queued events to the application service -// it is given. -func worker(client *http.Client, db storage.Database, ws types.ApplicationServiceWorkerState) { - log.WithFields(log.Fields{ - "appservice": ws.AppService.ID, - }).Info("Starting application service") - ctx := context.Background() - - // Initial check for any leftover events to send from last time - eventCount, err := db.CountEventsWithAppServiceID(ctx, ws.AppService.ID) - if err != nil { - log.WithFields(log.Fields{ - "appservice": ws.AppService.ID, - }).WithError(err).Fatal("appservice worker unable to read queued events from DB") - return - } - if eventCount > 0 { - ws.NotifyNewEvents() - } - - // Loop forever and keep waiting for more events to send - for { - // Wait for more events if we've sent all the events in the database - ws.WaitForNewEvents() - - // Batch events up into a transaction - transactionJSON, txnID, maxEventID, eventsRemaining, err := createTransaction(ctx, db, ws.AppService.ID) - if err != nil { - log.WithFields(log.Fields{ - "appservice": ws.AppService.ID, - }).WithError(err).Fatal("appservice worker unable to create transaction") - - return - } - - // Send the events off to the application service - // Backoff if the application service does not respond - err = send(client, ws.AppService, txnID, transactionJSON) - if err != nil { - log.WithFields(log.Fields{ - "appservice": ws.AppService.ID, - }).WithError(err).Error("unable to send event") - // Backoff - backoff(&ws, err) - continue - } - - // We sent successfully, hooray! - ws.Backoff = 0 - - // Transactions have a maximum event size, so there may still be some events - // left over to send. Keep sending until none are left - if !eventsRemaining { - ws.FinishEventProcessing() - } - - // Remove sent events from the DB - err = db.RemoveEventsBeforeAndIncludingID(ctx, ws.AppService.ID, maxEventID) - if err != nil { - log.WithFields(log.Fields{ - "appservice": ws.AppService.ID, - }).WithError(err).Fatal("unable to remove appservice events from the database") - return - } - } -} - -// backoff pauses the calling goroutine for a 2^some backoff exponent seconds -func backoff(ws *types.ApplicationServiceWorkerState, err error) { - // Calculate how long to backoff for - backoffDuration := time.Duration(math.Pow(2, float64(ws.Backoff))) - backoffSeconds := time.Second * backoffDuration - - log.WithFields(log.Fields{ - "appservice": ws.AppService.ID, - }).WithError(err).Warnf("unable to send transactions successfully, backing off for %ds", - backoffDuration) - - ws.Backoff++ - if ws.Backoff > 6 { - ws.Backoff = 6 - } - - // Backoff - time.Sleep(backoffSeconds) -} - -// createTransaction takes in a slice of AS events, stores them in an AS -// transaction, and JSON-encodes the results. -func createTransaction( - ctx context.Context, - db storage.Database, - appserviceID string, -) ( - transactionJSON []byte, - txnID, maxID int, - eventsRemaining bool, - err error, -) { - // Retrieve the latest events from the DB (will return old events if they weren't successfully sent) - txnID, maxID, events, eventsRemaining, err := db.GetEventsWithAppServiceID(ctx, appserviceID, transactionBatchSize) - if err != nil { - log.WithFields(log.Fields{ - "appservice": appserviceID, - }).WithError(err).Fatalf("appservice worker unable to read queued events from DB") - - return - } - - // Check if these events do not already have a transaction ID - if txnID == -1 { - // If not, grab next available ID from the DB - txnID, err = db.GetLatestTxnID(ctx) - if err != nil { - return nil, 0, 0, false, err - } - - // Mark new events with current transactionID - if err = db.UpdateTxnIDForEvents(ctx, appserviceID, maxID, txnID); err != nil { - return nil, 0, 0, false, err - } - } - - var ev []*gomatrixserverlib.HeaderedEvent - for i := range events { - ev = append(ev, &events[i]) - } - - // Create a transaction and store the events inside - transaction := gomatrixserverlib.ApplicationServiceTransaction{ - Events: gomatrixserverlib.HeaderedToClientEvents(ev, gomatrixserverlib.FormatAll), - } - - transactionJSON, err = json.Marshal(transaction) - if err != nil { - return - } - - return -} - -// send sends events to an application service. Returns an error if an OK was not -// received back from the application service or the request timed out. -func send( - client *http.Client, - appservice config.ApplicationService, - txnID int, - transaction []byte, -) (err error) { - // PUT a transaction to our AS - // https://matrix.org/docs/spec/application_service/r0.1.2#put-matrix-app-v1-transactions-txnid - address := fmt.Sprintf("%s/transactions/%d?access_token=%s", appservice.URL, txnID, url.QueryEscape(appservice.HSToken)) - req, err := http.NewRequest("PUT", address, bytes.NewBuffer(transaction)) - if err != nil { - return err - } - req.Header.Set("Content-Type", "application/json") - resp, err := client.Do(req) - if err != nil { - return err - } - defer checkNamedErr(resp.Body.Close, &err) - - // Check the AS received the events correctly - if resp.StatusCode != http.StatusOK { - // TODO: Handle non-200 error codes from application services - return fmt.Errorf("non-OK status code %d returned from AS", resp.StatusCode) - } - - return nil -} - -// checkNamedErr calls fn and overwrite err if it was nil and fn returned non-nil -func checkNamedErr(fn func() error, err *error) { - if e := fn(); e != nil && *err == nil { - *err = e - } -} diff --git a/build/gobind-pinecone/monolith.go b/build/gobind-pinecone/monolith.go index f3895ae23..60e817212 100644 --- a/build/gobind-pinecone/monolith.go +++ b/build/gobind-pinecone/monolith.go @@ -255,7 +255,6 @@ func (m *DendriteMonolith) Start() { cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-roomserver.db", m.StorageDirectory, prefix)) cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-keyserver.db", m.StorageDirectory, prefix)) cfg.FederationAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-federationsender.db", m.StorageDirectory, prefix)) - cfg.AppServiceAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-appservice.db", m.StorageDirectory, prefix)) cfg.MediaAPI.BasePath = config.Path(fmt.Sprintf("%s/media", m.CacheDirectory)) cfg.MediaAPI.AbsBasePath = config.Path(fmt.Sprintf("%s/media", m.CacheDirectory)) cfg.MSCs.MSCs = []string{"msc2836", "msc2946"} diff --git a/build/gobind-yggdrasil/monolith.go b/build/gobind-yggdrasil/monolith.go index 99b180c81..4432a3a45 100644 --- a/build/gobind-yggdrasil/monolith.go +++ b/build/gobind-yggdrasil/monolith.go @@ -94,7 +94,6 @@ func (m *DendriteMonolith) Start() { cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-roomserver.db", m.StorageDirectory)) cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-keyserver.db", m.StorageDirectory)) cfg.FederationAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-federationsender.db", m.StorageDirectory)) - cfg.AppServiceAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-appservice.db", m.StorageDirectory)) cfg.MediaAPI.BasePath = config.Path(fmt.Sprintf("%s/tmp", m.StorageDirectory)) cfg.MediaAPI.AbsBasePath = config.Path(fmt.Sprintf("%s/tmp", m.StorageDirectory)) cfg.ClientAPI.RegistrationDisabled = false diff --git a/build/scripts/Complement.Dockerfile b/build/scripts/Complement.Dockerfile index 56877051b..14b28498b 100644 --- a/build/scripts/Complement.Dockerfile +++ b/build/scripts/Complement.Dockerfile @@ -1,3 +1,5 @@ +#syntax=docker/dockerfile:1.2 + FROM golang:1.18-stretch as build RUN apt-get update && apt-get install -y sqlite3 WORKDIR /build @@ -8,14 +10,12 @@ RUN mkdir /dendrite # Utilise Docker caching when downloading dependencies, this stops us needlessly # downloading dependencies every time. -COPY go.mod . -COPY go.sum . -RUN go mod download - -COPY . . -RUN go build -o /dendrite ./cmd/dendrite-monolith-server -RUN go build -o /dendrite ./cmd/generate-keys -RUN go build -o /dendrite ./cmd/generate-config +RUN --mount=target=. \ + --mount=type=cache,target=/go/pkg/mod \ + --mount=type=cache,target=/root/.cache/go-build \ + go build -o /dendrite ./cmd/generate-config && \ + go build -o /dendrite ./cmd/generate-keys && \ + go build -o /dendrite ./cmd/dendrite-monolith-server WORKDIR /dendrite RUN ./generate-keys --private-key matrix_key.pem @@ -26,7 +26,7 @@ EXPOSE 8008 8448 # At runtime, generate TLS cert based on the CA now mounted at /ca # At runtime, replace the SERVER_NAME with what we are told -CMD ./generate-keys --server $SERVER_NAME --tls-cert server.crt --tls-key server.key --tls-authority-cert /complement/ca/ca.crt --tls-authority-key /complement/ca/ca.key && \ +CMD ./generate-keys -keysize 1024 --server $SERVER_NAME --tls-cert server.crt --tls-key server.key --tls-authority-cert /complement/ca/ca.crt --tls-authority-key /complement/ca/ca.key && \ ./generate-config -server $SERVER_NAME --ci > dendrite.yaml && \ cp /complement/ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates && \ - ./dendrite-monolith-server --really-enable-open-registration --tls-cert server.crt --tls-key server.key --config dendrite.yaml -api=${API:-0} + exec ./dendrite-monolith-server --really-enable-open-registration --tls-cert server.crt --tls-key server.key --config dendrite.yaml -api=${API:-0} diff --git a/build/scripts/ComplementLocal.Dockerfile b/build/scripts/ComplementLocal.Dockerfile index 704359a28..3a019fc20 100644 --- a/build/scripts/ComplementLocal.Dockerfile +++ b/build/scripts/ComplementLocal.Dockerfile @@ -1,3 +1,5 @@ +#syntax=docker/dockerfile:1.2 + # A local development Complement dockerfile, to be used with host mounts # /cache -> Contains the entire dendrite code at Dockerfile build time. Builds binaries but only keeps the generate-* ones. Pre-compilation saves time. # /dendrite -> Host-mounted sources @@ -9,11 +11,10 @@ FROM golang:1.18-stretch RUN apt-get update && apt-get install -y sqlite3 -WORKDIR /runtime - ENV SERVER_NAME=localhost EXPOSE 8008 8448 +WORKDIR /runtime # This script compiles Dendrite for us. RUN echo '\ #!/bin/bash -eux \n\ @@ -29,25 +30,23 @@ RUN echo '\ RUN echo '\ #!/bin/bash -eu \n\ ./generate-keys --private-key matrix_key.pem \n\ - ./generate-keys --server $SERVER_NAME --tls-cert server.crt --tls-key server.key --tls-authority-cert /complement/ca/ca.crt --tls-authority-key /complement/ca/ca.key \n\ + ./generate-keys -keysize 1024 --server $SERVER_NAME --tls-cert server.crt --tls-key server.key --tls-authority-cert /complement/ca/ca.crt --tls-authority-key /complement/ca/ca.key \n\ ./generate-config -server $SERVER_NAME --ci > dendrite.yaml \n\ cp /complement/ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates \n\ - ./dendrite-monolith-server --really-enable-open-registration --tls-cert server.crt --tls-key server.key --config dendrite.yaml \n\ + exec ./dendrite-monolith-server --really-enable-open-registration --tls-cert server.crt --tls-key server.key --config dendrite.yaml \n\ ' > run.sh && chmod +x run.sh WORKDIR /cache -# Pre-download deps; we don't need to do this if the GOPATH is mounted. -COPY go.mod . -COPY go.sum . -RUN go mod download - # Build the monolith in /cache - we won't actually use this but will rely on build artifacts to speed # up the real compilation. Build the generate-* binaries in the true /runtime locations. # If the generate-* source is changed, this dockerfile needs re-running. -COPY . . -RUN go build ./cmd/dendrite-monolith-server && go build -o /runtime ./cmd/generate-keys && go build -o /runtime ./cmd/generate-config +RUN --mount=target=. \ + --mount=type=cache,target=/go/pkg/mod \ + --mount=type=cache,target=/root/.cache/go-build \ + go build -o /runtime ./cmd/generate-config && \ + go build -o /runtime ./cmd/generate-keys WORKDIR /runtime -CMD /runtime/compile.sh && /runtime/run.sh +CMD /runtime/compile.sh && exec /runtime/run.sh diff --git a/build/scripts/ComplementPostgres.Dockerfile b/build/scripts/ComplementPostgres.Dockerfile index a8b4fbb1d..699540120 100644 --- a/build/scripts/ComplementPostgres.Dockerfile +++ b/build/scripts/ComplementPostgres.Dockerfile @@ -1,3 +1,5 @@ +#syntax=docker/dockerfile:1.2 + FROM golang:1.18-stretch as build RUN apt-get update && apt-get install -y postgresql WORKDIR /build @@ -26,14 +28,12 @@ RUN mkdir /dendrite # Utilise Docker caching when downloading dependencies, this stops us needlessly # downloading dependencies every time. -COPY go.mod . -COPY go.sum . -RUN go mod download - -COPY . . -RUN go build -o /dendrite ./cmd/dendrite-monolith-server -RUN go build -o /dendrite ./cmd/generate-keys -RUN go build -o /dendrite ./cmd/generate-config +RUN --mount=target=. \ + --mount=type=cache,target=/go/pkg/mod \ + --mount=type=cache,target=/root/.cache/go-build \ + go build -o /dendrite ./cmd/generate-config && \ + go build -o /dendrite ./cmd/generate-keys && \ + go build -o /dendrite ./cmd/dendrite-monolith-server WORKDIR /dendrite RUN ./generate-keys --private-key matrix_key.pem @@ -45,10 +45,10 @@ EXPOSE 8008 8448 # At runtime, generate TLS cert based on the CA now mounted at /ca # At runtime, replace the SERVER_NAME with what we are told -CMD /build/run_postgres.sh && ./generate-keys --server $SERVER_NAME --tls-cert server.crt --tls-key server.key --tls-authority-cert /complement/ca/ca.crt --tls-authority-key /complement/ca/ca.key && \ +CMD /build/run_postgres.sh && ./generate-keys --keysize 1024 --server $SERVER_NAME --tls-cert server.crt --tls-key server.key --tls-authority-cert /complement/ca/ca.crt --tls-authority-key /complement/ca/ca.key && \ ./generate-config -server $SERVER_NAME --ci > dendrite.yaml && \ # Replace the connection string with a single postgres DB, using user/db = 'postgres' and no password, bump max_conns sed -i "s%connection_string:.*$%connection_string: postgresql://postgres@localhost/postgres?sslmode=disable%g" dendrite.yaml && \ sed -i 's/max_open_conns:.*$/max_open_conns: 100/g' dendrite.yaml && \ cp /complement/ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates && \ - ./dendrite-monolith-server --really-enable-open-registration --tls-cert server.crt --tls-key server.key --config dendrite.yaml -api=${API:-0} \ No newline at end of file + exec ./dendrite-monolith-server --really-enable-open-registration --tls-cert server.crt --tls-key server.key --config dendrite.yaml -api=${API:-0} \ No newline at end of file diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index af0329a48..0bda1e488 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -276,19 +276,19 @@ type recaptchaResponse struct { } // validateUsername returns an error response if the username is invalid -func validateUsername(username string) *util.JSONResponse { +func validateUsername(localpart string, domain gomatrixserverlib.ServerName) *util.JSONResponse { // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161 - if len(username) > maxUsernameLength { + if id := fmt.Sprintf("@%s:%s", localpart, domain); len(id) > maxUsernameLength { return &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(fmt.Sprintf("'username' >%d characters", maxUsernameLength)), + JSON: jsonerror.BadJSON(fmt.Sprintf("%q exceeds the maximum length of %d characters", id, maxUsernameLength)), } - } else if !validUsernameRegex.MatchString(username) { + } else if !validUsernameRegex.MatchString(localpart) { return &util.JSONResponse{ Code: http.StatusBadRequest, JSON: jsonerror.InvalidUsername("Username can only contain characters a-z, 0-9, or '_-./='"), } - } else if username[0] == '_' { // Regex checks its not a zero length string + } else if localpart[0] == '_' { // Regex checks its not a zero length string return &util.JSONResponse{ Code: http.StatusBadRequest, JSON: jsonerror.InvalidUsername("Username cannot start with a '_'"), @@ -298,13 +298,13 @@ func validateUsername(username string) *util.JSONResponse { } // validateApplicationServiceUsername returns an error response if the username is invalid for an application service -func validateApplicationServiceUsername(username string) *util.JSONResponse { - if len(username) > maxUsernameLength { +func validateApplicationServiceUsername(localpart string, domain gomatrixserverlib.ServerName) *util.JSONResponse { + if id := fmt.Sprintf("@%s:%s", localpart, domain); len(id) > maxUsernameLength { return &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(fmt.Sprintf("'username' >%d characters", maxUsernameLength)), + JSON: jsonerror.BadJSON(fmt.Sprintf("%q exceeds the maximum length of %d characters", id, maxUsernameLength)), } - } else if !validUsernameRegex.MatchString(username) { + } else if !validUsernameRegex.MatchString(localpart) { return &util.JSONResponse{ Code: http.StatusBadRequest, JSON: jsonerror.InvalidUsername("Username can only contain characters a-z, 0-9, or '_-./='"), @@ -523,7 +523,7 @@ func validateApplicationService( } // Check username application service is trying to register is valid - if err := validateApplicationServiceUsername(username); err != nil { + if err := validateApplicationServiceUsername(username, cfg.Matrix.ServerName); err != nil { return "", err } @@ -604,7 +604,7 @@ func Register( case r.Type == authtypes.LoginTypeApplicationService && accessTokenErr == nil: // Spec-compliant case (the access_token is specified and the login type // is correctly set, so it's an appservice registration) - if resErr := validateApplicationServiceUsername(r.Username); resErr != nil { + if resErr := validateApplicationServiceUsername(r.Username, cfg.Matrix.ServerName); resErr != nil { return *resErr } case accessTokenErr == nil: @@ -617,7 +617,7 @@ func Register( default: // Spec-compliant case (neither the access_token nor the login type are // specified, so it's a normal user registration) - if resErr := validateUsername(r.Username); resErr != nil { + if resErr := validateUsername(r.Username, cfg.Matrix.ServerName); resErr != nil { return *resErr } } @@ -1018,7 +1018,7 @@ func RegisterAvailable( // Squash username to all lowercase letters username = strings.ToLower(username) - if err := validateUsername(username); err != nil { + if err := validateUsername(username, cfg.Matrix.ServerName); err != nil { return *err } @@ -1059,7 +1059,7 @@ func RegisterAvailable( } } -func handleSharedSecretRegistration(userAPI userapi.ClientUserAPI, sr *SharedSecretRegistration, req *http.Request) util.JSONResponse { +func handleSharedSecretRegistration(cfg *config.ClientAPI, userAPI userapi.ClientUserAPI, sr *SharedSecretRegistration, req *http.Request) util.JSONResponse { ssrr, err := NewSharedSecretRegistrationRequest(req.Body) if err != nil { return util.JSONResponse{ @@ -1080,7 +1080,7 @@ func handleSharedSecretRegistration(userAPI userapi.ClientUserAPI, sr *SharedSec // downcase capitals ssrr.User = strings.ToLower(ssrr.User) - if resErr := validateUsername(ssrr.User); resErr != nil { + if resErr := validateUsername(ssrr.User, cfg.Matrix.ServerName); resErr != nil { return *resErr } if resErr := validatePassword(ssrr.Password); resErr != nil { diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 6904a2b34..d7a48d228 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -133,7 +133,7 @@ func Setup( } } if req.Method == http.MethodPost { - return handleSharedSecretRegistration(userAPI, sr, req) + return handleSharedSecretRegistration(cfg, userAPI, sr, req) } return util.JSONResponse{ Code: http.StatusMethodNotAllowed, diff --git a/cmd/create-account/main.go b/cmd/create-account/main.go index bd053f2f7..a9357f6db 100644 --- a/cmd/create-account/main.go +++ b/cmd/create-account/main.go @@ -66,10 +66,11 @@ var ( resetPassword = flag.Bool("reset-password", false, "Deprecated") serverURL = flag.String("url", "https://localhost:8448", "The URL to connect to.") validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`) + timeout = flag.Duration("timeout", time.Second*30, "Timeout for the http client when connecting to the server") ) var cl = http.Client{ - Timeout: time.Second * 10, + Timeout: time.Second * 30, Transport: http.DefaultTransport, } @@ -108,6 +109,8 @@ func main() { logrus.Fatalln(err) } + cl.Timeout = *timeout + accessToken, err := sharedSecretRegister(cfg.ClientAPI.RegistrationSharedSecret, *serverURL, *username, pass, *isAdmin) if err != nil { logrus.Fatalln("Failed to create the account:", err.Error()) @@ -124,8 +127,8 @@ type sharedSecretRegistrationRequest struct { Admin bool `json:"admin"` } -func sharedSecretRegister(sharedSecret, serverURL, localpart, password string, admin bool) (accesToken string, err error) { - registerURL := fmt.Sprintf("%s/_synapse/admin/v1/register", serverURL) +func sharedSecretRegister(sharedSecret, serverURL, localpart, password string, admin bool) (accessToken string, err error) { + registerURL := fmt.Sprintf("%s/_synapse/admin/v1/register", strings.Trim(serverURL, "/")) nonceReq, err := http.NewRequest(http.MethodGet, registerURL, nil) if err != nil { return "", fmt.Errorf("unable to create http request: %w", err) diff --git a/cmd/dendrite-demo-pinecone/main.go b/cmd/dendrite-demo-pinecone/main.go index 75f29fe27..52c9dc8eb 100644 --- a/cmd/dendrite-demo-pinecone/main.go +++ b/cmd/dendrite-demo-pinecone/main.go @@ -24,6 +24,7 @@ import ( "net" "net/http" "os" + "strings" "time" "github.com/gorilla/mux" @@ -42,6 +43,7 @@ import ( "github.com/matrix-org/dendrite/setup" "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/userapi" "github.com/matrix-org/gomatrixserverlib" @@ -70,31 +72,83 @@ func main() { var pk ed25519.PublicKey var sk ed25519.PrivateKey - keyfile := *instanceName + ".key" - if _, err := os.Stat(keyfile); os.IsNotExist(err) { - if pk, sk, err = ed25519.GenerateKey(nil); err != nil { - panic(err) + // iterate through the cli args and check if the config flag was set + configFlagSet := false + for _, arg := range os.Args { + if arg == "--config" || arg == "-config" { + configFlagSet = true + break } - if err = os.WriteFile(keyfile, sk, 0644); err != nil { - panic(err) - } - } else if err == nil { - if sk, err = os.ReadFile(keyfile); err != nil { - panic(err) - } - if len(sk) != ed25519.PrivateKeySize { - panic("the private key is not long enough") - } - pk = sk.Public().(ed25519.PublicKey) } + cfg := &config.Dendrite{} + + // use custom config if config flag is set + if configFlagSet { + cfg = setup.ParseFlags(true) + sk = cfg.Global.PrivateKey + } else { + keyfile := *instanceName + ".pem" + if _, err := os.Stat(keyfile); os.IsNotExist(err) { + oldkeyfile := *instanceName + ".key" + if _, err = os.Stat(oldkeyfile); os.IsNotExist(err) { + if err = test.NewMatrixKey(keyfile); err != nil { + panic("failed to generate a new PEM key: " + err.Error()) + } + if _, sk, err = config.LoadMatrixKey(keyfile, os.ReadFile); err != nil { + panic("failed to load PEM key: " + err.Error()) + } + } else { + if sk, err = os.ReadFile(oldkeyfile); err != nil { + panic("failed to read the old private key: " + err.Error()) + } + if len(sk) != ed25519.PrivateKeySize { + panic("the private key is not long enough") + } + if err := test.SaveMatrixKey(keyfile, sk); err != nil { + panic("failed to convert the private key to PEM format: " + err.Error()) + } + } + } else { + var err error + if _, sk, err = config.LoadMatrixKey(keyfile, os.ReadFile); err != nil { + panic("failed to load PEM key: " + err.Error()) + } + } + cfg.Defaults(true) + cfg.Global.PrivateKey = sk + cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", *instanceName)) + cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", *instanceName)) + cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", *instanceName)) + cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", *instanceName)) + cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName)) + cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-keyserver.db", *instanceName)) + cfg.FederationAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-federationapi.db", *instanceName)) + cfg.MSCs.MSCs = []string{"msc2836", "msc2946"} + cfg.MSCs.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mscs.db", *instanceName)) + cfg.ClientAPI.RegistrationDisabled = false + cfg.ClientAPI.OpenRegistrationWithoutVerificationEnabled = true + if err := cfg.Derive(); err != nil { + panic(err) + } + } + + pk = sk.Public().(ed25519.PublicKey) + cfg.Global.ServerName = gomatrixserverlib.ServerName(hex.EncodeToString(pk)) + cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID) + + base := base.NewBaseDendrite(cfg, "Monolith") + defer base.Close() // nolint: errcheck + pRouter := pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk, false) pQUIC := pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), pRouter, []string{"matrix"}) pMulticast := pineconeMulticast.NewMulticast(logrus.WithField("pinecone", "multicast"), pRouter) pManager := pineconeConnections.NewConnectionManager(pRouter, nil) pMulticast.Start() if instancePeer != nil && *instancePeer != "" { - pManager.AddPeer(*instancePeer) + for _, peer := range strings.Split(*instancePeer, ",") { + pManager.AddPeer(strings.Trim(peer, " \t\r\n")) + } } go func() { @@ -125,29 +179,6 @@ func main() { } }() - cfg := &config.Dendrite{} - cfg.Defaults(true) - cfg.Global.ServerName = gomatrixserverlib.ServerName(hex.EncodeToString(pk)) - cfg.Global.PrivateKey = sk - cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID) - cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", *instanceName)) - cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", *instanceName)) - cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", *instanceName)) - cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", *instanceName)) - cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName)) - cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-keyserver.db", *instanceName)) - cfg.FederationAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-federationapi.db", *instanceName)) - cfg.AppServiceAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-appservice.db", *instanceName)) - cfg.MSCs.MSCs = []string{"msc2836", "msc2946"} - cfg.ClientAPI.RegistrationDisabled = false - cfg.ClientAPI.OpenRegistrationWithoutVerificationEnabled = true - if err := cfg.Derive(); err != nil { - panic(err) - } - - base := base.NewBaseDendrite(cfg, "Monolith") - defer base.Close() // nolint: errcheck - federation := conn.CreateFederationClient(base, pQUIC) serverKeyAPI := &signing.YggdrasilKeys{} diff --git a/cmd/dendrite-demo-yggdrasil/main.go b/cmd/dendrite-demo-yggdrasil/main.go index 619720d6c..086baa264 100644 --- a/cmd/dendrite-demo-yggdrasil/main.go +++ b/cmd/dendrite-demo-yggdrasil/main.go @@ -86,7 +86,6 @@ func main() { cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName)) cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-keyserver.db", *instanceName)) cfg.FederationAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-federationapi.db", *instanceName)) - cfg.AppServiceAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-appservice.db", *instanceName)) cfg.MSCs.MSCs = []string{"msc2836"} cfg.MSCs.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mscs.db", *instanceName)) cfg.ClientAPI.RegistrationDisabled = false diff --git a/cmd/generate-config/main.go b/cmd/generate-config/main.go index 1c585d916..6ae3ff9c9 100644 --- a/cmd/generate-config/main.go +++ b/cmd/generate-config/main.go @@ -24,7 +24,6 @@ func main() { cfg.Global.ServerName = gomatrixserverlib.ServerName(*serverName) } if *dbURI != "" { - cfg.AppServiceAPI.Database.ConnectionString = config.DataSource(*dbURI) cfg.FederationAPI.Database.ConnectionString = config.DataSource(*dbURI) cfg.KeyServer.Database.ConnectionString = config.DataSource(*dbURI) cfg.MSCs.Database.ConnectionString = config.DataSource(*dbURI) diff --git a/cmd/generate-keys/main.go b/cmd/generate-keys/main.go index 8acd28be0..d4c8cf78a 100644 --- a/cmd/generate-keys/main.go +++ b/cmd/generate-keys/main.go @@ -38,6 +38,7 @@ var ( authorityCertFile = flag.String("tls-authority-cert", "", "Optional: Create TLS certificate/keys based on this CA authority. Useful for integration testing.") authorityKeyFile = flag.String("tls-authority-key", "", "Optional: Create TLS certificate/keys based on this CA authority. Useful for integration testing.") serverName = flag.String("server", "", "Optional: Create TLS certificate/keys with this domain name set. Useful for integration testing.") + keySize = flag.Int("keysize", 4096, "Optional: Create TLS RSA private key with the given key size") ) func main() { @@ -58,12 +59,12 @@ func main() { log.Fatal("Zero or both of --tls-key and --tls-cert must be supplied") } if *authorityCertFile == "" && *authorityKeyFile == "" { - if err := test.NewTLSKey(*tlsKeyFile, *tlsCertFile); err != nil { + if err := test.NewTLSKey(*tlsKeyFile, *tlsCertFile, *keySize); err != nil { panic(err) } } else { // generate the TLS cert/key based on the authority given. - if err := test.NewTLSKeyWithAuthority(*serverName, *tlsKeyFile, *tlsCertFile, *authorityKeyFile, *authorityCertFile); err != nil { + if err := test.NewTLSKeyWithAuthority(*serverName, *tlsKeyFile, *tlsCertFile, *authorityKeyFile, *authorityCertFile, *keySize); err != nil { panic(err) } } diff --git a/dendrite-sample.polylith.yaml b/dendrite-sample.polylith.yaml index 856b4ab22..3caf91434 100644 --- a/dendrite-sample.polylith.yaml +++ b/dendrite-sample.polylith.yaml @@ -132,13 +132,6 @@ app_service_api: listen: http://[::]:7777 # The listen address for incoming API requests connect: http://app_service_api:7777 # The connect address for other components to use - # Database configuration for this component. - database: - connection_string: postgresql://username:password@hostname/dendrite_appservice?sslmode=disable - max_open_conns: 10 - max_idle_conns: 2 - conn_max_lifetime: -1 - # Disable the validation of TLS certificates of appservices. This is # not recommended in production since it may allow appservice traffic # to be sent to an insecure endpoint. diff --git a/federationapi/consumers/keychange.go b/federationapi/consumers/keychange.go index 6d3cf0e46..f3314bc98 100644 --- a/federationapi/consumers/keychange.go +++ b/federationapi/consumers/keychange.go @@ -67,14 +67,15 @@ func NewKeyChangeConsumer( // Start consuming from key servers func (t *KeyChangeConsumer) Start() error { return jetstream.JetStreamConsumer( - t.ctx, t.jetstream, t.topic, t.durable, t.onMessage, - nats.DeliverAll(), nats.ManualAck(), + t.ctx, t.jetstream, t.topic, t.durable, 1, + t.onMessage, nats.DeliverAll(), nats.ManualAck(), ) } // onMessage is called in response to a message received on the // key change events topic from the key server. -func (t *KeyChangeConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { +func (t *KeyChangeConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called var m api.DeviceMessage if err := json.Unmarshal(msg.Data, &m); err != nil { logrus.WithError(err).Errorf("failed to read device message from key change topic") diff --git a/federationapi/consumers/presence.go b/federationapi/consumers/presence.go index a65d2aa04..e76103cd3 100644 --- a/federationapi/consumers/presence.go +++ b/federationapi/consumers/presence.go @@ -69,14 +69,15 @@ func (t *OutputPresenceConsumer) Start() error { return nil } return jetstream.JetStreamConsumer( - t.ctx, t.jetstream, t.topic, t.durable, t.onMessage, + t.ctx, t.jetstream, t.topic, t.durable, 1, t.onMessage, nats.DeliverAll(), nats.ManualAck(), nats.HeadersOnly(), ) } // onMessage is called in response to a message received on the presence // events topic from the client api. -func (t *OutputPresenceConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { +func (t *OutputPresenceConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called // only send presence events which originated from us userID := msg.Header.Get(jetstream.UserID) _, serverName, err := gomatrixserverlib.SplitID('@', userID) diff --git a/federationapi/consumers/receipts.go b/federationapi/consumers/receipts.go index 2c9d79bcb..366cb264e 100644 --- a/federationapi/consumers/receipts.go +++ b/federationapi/consumers/receipts.go @@ -65,14 +65,15 @@ func NewOutputReceiptConsumer( // Start consuming from the clientapi func (t *OutputReceiptConsumer) Start() error { return jetstream.JetStreamConsumer( - t.ctx, t.jetstream, t.topic, t.durable, t.onMessage, + t.ctx, t.jetstream, t.topic, t.durable, 1, t.onMessage, nats.DeliverAll(), nats.ManualAck(), nats.HeadersOnly(), ) } // onMessage is called in response to a message received on the receipt // events topic from the client api. -func (t *OutputReceiptConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { +func (t *OutputReceiptConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called receipt := syncTypes.OutputReceiptEvent{ UserID: msg.Header.Get(jetstream.UserID), RoomID: msg.Header.Get(jetstream.RoomID), diff --git a/federationapi/consumers/roomserver.go b/federationapi/consumers/roomserver.go index 2622ecb3f..349b50b05 100644 --- a/federationapi/consumers/roomserver.go +++ b/federationapi/consumers/roomserver.go @@ -68,8 +68,8 @@ func NewOutputRoomEventConsumer( // Start consuming from room servers func (s *OutputRoomEventConsumer) Start() error { return jetstream.JetStreamConsumer( - s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, - nats.DeliverAll(), nats.ManualAck(), + s.ctx, s.jetstream, s.topic, s.durable, 1, + s.onMessage, nats.DeliverAll(), nats.ManualAck(), ) } @@ -77,7 +77,8 @@ func (s *OutputRoomEventConsumer) Start() error { // It is unsafe to call this with messages for the same room in multiple gorountines // because updates it will likely fail with a types.EventIDMismatchError when it // realises that it cannot update the room state using the deltas. -func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { +func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called // Parse out the event JSON var output api.OutputEvent if err := json.Unmarshal(msg.Data, &output); err != nil { diff --git a/federationapi/consumers/sendtodevice.go b/federationapi/consumers/sendtodevice.go index f99a895e0..e44bad723 100644 --- a/federationapi/consumers/sendtodevice.go +++ b/federationapi/consumers/sendtodevice.go @@ -63,14 +63,15 @@ func NewOutputSendToDeviceConsumer( // Start consuming from the client api func (t *OutputSendToDeviceConsumer) Start() error { return jetstream.JetStreamConsumer( - t.ctx, t.jetstream, t.topic, t.durable, t.onMessage, - nats.DeliverAll(), nats.ManualAck(), + t.ctx, t.jetstream, t.topic, t.durable, 1, + t.onMessage, nats.DeliverAll(), nats.ManualAck(), ) } // onMessage is called in response to a message received on the // send-to-device events topic from the client api. -func (t *OutputSendToDeviceConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { +func (t *OutputSendToDeviceConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called // only send send-to-device events which originated from us sender := msg.Header.Get("sender") _, originServerName, err := gomatrixserverlib.SplitID('@', sender) diff --git a/federationapi/consumers/typing.go b/federationapi/consumers/typing.go index 428e1a867..9c7379136 100644 --- a/federationapi/consumers/typing.go +++ b/federationapi/consumers/typing.go @@ -62,14 +62,15 @@ func NewOutputTypingConsumer( // Start consuming from the clientapi func (t *OutputTypingConsumer) Start() error { return jetstream.JetStreamConsumer( - t.ctx, t.jetstream, t.topic, t.durable, t.onMessage, + t.ctx, t.jetstream, t.topic, t.durable, 1, t.onMessage, nats.DeliverAll(), nats.ManualAck(), nats.HeadersOnly(), ) } // onMessage is called in response to a message received on the typing // events topic from the client api. -func (t *OutputTypingConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { +func (t *OutputTypingConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called // Extract the typing event from msg. roomID := msg.Header.Get(jetstream.RoomID) userID := msg.Header.Get(jetstream.UserID) diff --git a/federationapi/routing/join.go b/federationapi/routing/join.go index b48eaf78e..1a1219873 100644 --- a/federationapi/routing/join.go +++ b/federationapi/routing/join.go @@ -329,6 +329,12 @@ func SendJoin( JSON: jsonerror.NotFound("Room does not exist"), } } + if !stateAndAuthChainResponse.StateKnown { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("State not known"), + } + } // Check if the user is already in the room. If they're already in then // there isn't much point in sending another join event into the room. diff --git a/federationapi/routing/state.go b/federationapi/routing/state.go index 6fdce20ce..5377eb88f 100644 --- a/federationapi/routing/state.go +++ b/federationapi/routing/state.go @@ -135,6 +135,12 @@ func getState( return nil, nil, &resErr } + if !response.StateKnown { + return nil, nil, &util.JSONResponse{ + Code: http.StatusNotFound, + JSON: jsonerror.NotFound("State not known"), + } + } if response.IsRejected { return nil, nil, &util.JSONResponse{ Code: http.StatusNotFound, diff --git a/federationapi/statistics/statistics.go b/federationapi/statistics/statistics.go index 8bac99cbc..b8e16a259 100644 --- a/federationapi/statistics/statistics.go +++ b/federationapi/statistics/statistics.go @@ -5,10 +5,11 @@ import ( "sync" "time" - "github.com/matrix-org/dendrite/federationapi/storage" "github.com/matrix-org/gomatrixserverlib" "github.com/sirupsen/logrus" "go.uber.org/atomic" + + "github.com/matrix-org/dendrite/federationapi/storage" ) // Statistics contains information about all of the remote federated @@ -126,13 +127,13 @@ func (s *ServerStatistics) Failure() (time.Time, bool) { go func() { until, ok := s.backoffUntil.Load().(time.Time) - if ok { + if ok && !until.IsZero() { select { case <-time.After(time.Until(until)): case <-s.interrupt: } + s.backoffStarted.Store(false) } - s.backoffStarted.Store(false) }() } diff --git a/federationapi/storage/shared/storage_edus.go b/federationapi/storage/shared/storage_edus.go index b62e5d9c5..ce9632ed3 100644 --- a/federationapi/storage/shared/storage_edus.go +++ b/federationapi/storage/shared/storage_edus.go @@ -110,6 +110,7 @@ func (d *Database) GetPendingEDUs( return fmt.Errorf("json.Unmarshal: %w", err) } edus[&Receipt{nid}] = &event + d.Cache.StoreFederationQueuedEDU(nid, &event) } return nil @@ -177,20 +178,18 @@ func (d *Database) GetPendingEDUServerNames( return d.FederationQueueEDUs.SelectQueueEDUServerNames(ctx, nil) } -// DeleteExpiredEDUs deletes expired EDUs +// DeleteExpiredEDUs deletes expired EDUs and evicts them from the cache. func (d *Database) DeleteExpiredEDUs(ctx context.Context) error { - return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + var jsonNIDs []int64 + err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) (err error) { expiredBefore := gomatrixserverlib.AsTimestamp(time.Now()) - jsonNIDs, err := d.FederationQueueEDUs.SelectExpiredEDUs(ctx, txn, expiredBefore) + jsonNIDs, err = d.FederationQueueEDUs.SelectExpiredEDUs(ctx, txn, expiredBefore) if err != nil { return err } if len(jsonNIDs) == 0 { return nil } - for i := range jsonNIDs { - d.Cache.EvictFederationQueuedEDU(jsonNIDs[i]) - } if err = d.FederationQueueJSON.DeleteQueueJSON(ctx, txn, jsonNIDs); err != nil { return err @@ -198,4 +197,14 @@ func (d *Database) DeleteExpiredEDUs(ctx context.Context) error { return d.FederationQueueEDUs.DeleteExpiredEDUs(ctx, txn, expiredBefore) }) + + if err != nil { + return err + } + + for i := range jsonNIDs { + d.Cache.EvictFederationQueuedEDU(jsonNIDs[i]) + } + + return nil } diff --git a/federationapi/storage/storage_test.go b/federationapi/storage/storage_test.go index 7eba2cbee..3b0268e55 100644 --- a/federationapi/storage/storage_test.go +++ b/federationapi/storage/storage_test.go @@ -31,7 +31,7 @@ func mustCreateFederationDatabase(t *testing.T, dbType test.DBType) (storage.Dat func TestExpireEDUs(t *testing.T) { var expireEDUTypes = map[string]time.Duration{ - gomatrixserverlib.MReceipt: time.Millisecond, + gomatrixserverlib.MReceipt: 0, } ctx := context.Background() diff --git a/go.mod b/go.mod index fe4604a03..8bf8f454d 100644 --- a/go.mod +++ b/go.mod @@ -21,7 +21,7 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 - github.com/matrix-org/gomatrixserverlib v0.0.0-20220815094957-74b7ff4ae09c + github.com/matrix-org/gomatrixserverlib v0.0.0-20220830164018-c71e518537a2 github.com/matrix-org/pinecone v0.0.0-20220803093810-b7a830c08fb9 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.13 diff --git a/go.sum b/go.sum index 7ac9fc6e2..8b8baabcf 100644 --- a/go.sum +++ b/go.sum @@ -343,8 +343,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1 github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20220815094957-74b7ff4ae09c h1:GhKmb8s9iXA9qsFD1SbiRo6Ee7cnbfcgJQ/iy43wczM= -github.com/matrix-org/gomatrixserverlib v0.0.0-20220815094957-74b7ff4ae09c/go.mod h1:jX38yp3SSLJNftBg3PXU1ayd0PCLIiDHQ4xAc9DIixk= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220830164018-c71e518537a2 h1:esbNn9hg//tAStA6TogatAJAursw23A+yfVRQsdiv70= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220830164018-c71e518537a2/go.mod h1:jX38yp3SSLJNftBg3PXU1ayd0PCLIiDHQ4xAc9DIixk= github.com/matrix-org/pinecone v0.0.0-20220803093810-b7a830c08fb9 h1:ed8yvWhTLk7+sNeK/eOZRTvESFTOHDRevoRoyeqPtvY= github.com/matrix-org/pinecone v0.0.0-20220803093810-b7a830c08fb9/go.mod h1:P4MqPf+u83OPulPJ+XTbSDbbWrdFYNY4LZ/B1PIduFE= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= diff --git a/internal/version.go b/internal/version.go index 561e6c06f..108b8ab0f 100644 --- a/internal/version.go +++ b/internal/version.go @@ -17,7 +17,7 @@ var build string const ( VersionMajor = 0 VersionMinor = 9 - VersionPatch = 3 + VersionPatch = 5 VersionTag = "" // example: "rc1" ) diff --git a/keyserver/consumers/devicelistupdate.go b/keyserver/consumers/devicelistupdate.go index f4f246280..d15f94267 100644 --- a/keyserver/consumers/devicelistupdate.go +++ b/keyserver/consumers/devicelistupdate.go @@ -55,14 +55,15 @@ func NewDeviceListUpdateConsumer( // Start consuming from key servers func (t *DeviceListUpdateConsumer) Start() error { return jetstream.JetStreamConsumer( - t.ctx, t.jetstream, t.topic, t.durable, t.onMessage, - nats.DeliverAll(), nats.ManualAck(), + t.ctx, t.jetstream, t.topic, t.durable, 1, + t.onMessage, nats.DeliverAll(), nats.ManualAck(), ) } // onMessage is called in response to a message received on the // key change events topic from the key server. -func (t *DeviceListUpdateConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { +func (t *DeviceListUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called var m gomatrixserverlib.DeviceListUpdateEvent if err := json.Unmarshal(msg.Data, &m); err != nil { logrus.WithError(err).Errorf("Failed to read from device list update input topic") diff --git a/keyserver/internal/device_list_update.go b/keyserver/internal/device_list_update.go index 80efbec51..304b67b23 100644 --- a/keyserver/internal/device_list_update.go +++ b/keyserver/internal/device_list_update.go @@ -335,8 +335,9 @@ func (u *DeviceListUpdater) worker(ch chan gomatrixserverlib.ServerName) { retriesMu := &sync.Mutex{} // restarter goroutine which will inject failed servers into ch when it is time go func() { + var serversToRetry []gomatrixserverlib.ServerName for { - var serversToRetry []gomatrixserverlib.ServerName + serversToRetry = serversToRetry[:0] // reuse memory time.Sleep(time.Second) retriesMu.Lock() now := time.Now() @@ -355,11 +356,17 @@ func (u *DeviceListUpdater) worker(ch chan gomatrixserverlib.ServerName) { } }() for serverName := range ch { + retriesMu.Lock() + _, exists := retries[serverName] + retriesMu.Unlock() + if exists { + // Don't retry a server that we're already waiting for. + continue + } waitTime, shouldRetry := u.processServer(serverName) if shouldRetry { retriesMu.Lock() - _, exists := retries[serverName] - if !exists { + if _, exists = retries[serverName]; !exists { retries[serverName] = time.Now().Add(waitTime) } retriesMu.Unlock() diff --git a/keyserver/keyserver.go b/keyserver/keyserver.go index cd506f981..75cdaedb4 100644 --- a/keyserver/keyserver.go +++ b/keyserver/keyserver.go @@ -60,7 +60,7 @@ func NewInternalAPI( updater := internal.NewDeviceListUpdater(db, ap, keyChangeProducer, fedClient, 8) // 8 workers TODO: configurable ap.Updater = updater go func() { - if err = updater.Start(); err != nil { + if err := updater.Start(); err != nil { logrus.WithError(err).Panicf("failed to start device list updater") } }() @@ -68,7 +68,7 @@ func NewInternalAPI( dlConsumer := consumers.NewDeviceListUpdateConsumer( base.ProcessContext, cfg, js, updater, ) - if err = dlConsumer.Start(); err != nil { + if err := dlConsumer.Start(); err != nil { logrus.WithError(err).Panic("failed to start device list consumer") } diff --git a/roomserver/api/perform.go b/roomserver/api/perform.go index d9ea9dd1c..20931f807 100644 --- a/roomserver/api/perform.go +++ b/roomserver/api/perform.go @@ -5,9 +5,10 @@ import ( "fmt" "net/http" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" + + "github.com/matrix-org/dendrite/clientapi/jsonerror" ) type PerformErrorCode int @@ -161,7 +162,8 @@ func (r *PerformBackfillRequest) PrevEventIDs() []string { // PerformBackfillResponse is a response to PerformBackfill. type PerformBackfillResponse struct { // Missing events, arbritrary order. - Events []*gomatrixserverlib.HeaderedEvent `json:"events"` + Events []*gomatrixserverlib.HeaderedEvent `json:"events"` + HistoryVisibility gomatrixserverlib.HistoryVisibility `json:"history_visibility"` } type PerformPublishRequest struct { diff --git a/roomserver/api/query.go b/roomserver/api/query.go index c8e6f9dc6..32d63bb51 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -227,6 +227,7 @@ type QueryStateAndAuthChainResponse struct { // Do all the previous events exist on this roomserver? // If some of previous events do not exist this will be false and StateEvents will be empty. PrevEventsExist bool `json:"prev_events_exist"` + StateKnown bool `json:"state_known"` // The state and auth chain events that were requested. // The lists will be in an arbitrary order. StateEvents []*gomatrixserverlib.HeaderedEvent `json:"state_events"` diff --git a/roomserver/api/wrapper.go b/roomserver/api/wrapper.go index bc2f28176..8b031982c 100644 --- a/roomserver/api/wrapper.go +++ b/roomserver/api/wrapper.go @@ -19,6 +19,7 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" + "github.com/sirupsen/logrus" ) // SendEvents to the roomserver The events are written with KindNew. @@ -69,6 +70,13 @@ func SendEventWithState( stateEventIDs[i] = stateEvents[i].EventID() } + logrus.WithContext(ctx).WithFields(logrus.Fields{ + "room_id": event.RoomID(), + "event_id": event.EventID(), + "outliers": len(ires), + "state_ids": len(stateEventIDs), + }).Infof("Submitting %q event to roomserver with state snapshot", event.Type()) + ires = append(ires, InputRoomEvent{ Kind: kind, Event: event, diff --git a/roomserver/internal/helpers/auth.go b/roomserver/internal/helpers/auth.go index 648c50cf6..935a045df 100644 --- a/roomserver/internal/helpers/auth.go +++ b/roomserver/internal/helpers/auth.go @@ -39,7 +39,7 @@ func CheckForSoftFail( var authStateEntries []types.StateEntry var err error if rewritesState { - authStateEntries, err = db.StateEntriesForEventIDs(ctx, stateEventIDs) + authStateEntries, err = db.StateEntriesForEventIDs(ctx, stateEventIDs, true) if err != nil { return true, fmt.Errorf("StateEntriesForEventIDs failed: %w", err) } @@ -97,7 +97,7 @@ func CheckAuthEvents( authEventIDs []string, ) ([]types.EventNID, error) { // Grab the numeric IDs for the supplied auth state events from the database. - authStateEntries, err := db.StateEntriesForEventIDs(ctx, authEventIDs) + authStateEntries, err := db.StateEntriesForEventIDs(ctx, authEventIDs, true) if err != nil { return nil, fmt.Errorf("db.StateEntriesForEventIDs: %w", err) } diff --git a/roomserver/internal/helpers/helpers.go b/roomserver/internal/helpers/helpers.go index 6091f8ec2..cbd1561f7 100644 --- a/roomserver/internal/helpers/helpers.go +++ b/roomserver/internal/helpers/helpers.go @@ -254,8 +254,15 @@ func CheckServerAllowedToSeeEvent( return false, err } default: - // Something else went wrong - return false, err + switch err.(type) { + case types.MissingStateError: + // If there's no state then we assume it's open visibility, as Synapse does: + // https://github.com/matrix-org/synapse/blob/aec87a0f9369a3015b2a53469f88d1de274e8b71/synapse/visibility.py#L654-L655 + return true, nil + default: + // Something else went wrong + return false, err + } } return auth.IsServerAllowed(serverName, isServerInRoom, stateAtEvent), nil } diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go index 8d24f3c59..429cc4bd2 100644 --- a/roomserver/internal/input/input.go +++ b/roomserver/internal/input/input.go @@ -36,6 +36,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/internal/query" "github.com/matrix-org/dendrite/roomserver/producers" "github.com/matrix-org/dendrite/roomserver/storage" + "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" @@ -247,14 +248,24 @@ func (w *worker) _next() { // it was a synchronous request. var errString string if err = w.r.processRoomEvent(w.r.ProcessContext.Context(), &inputRoomEvent); err != nil { - if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) { - sentry.CaptureException(err) + switch err.(type) { + case types.RejectedError: + // Don't send events that were rejected to Sentry + logrus.WithError(err).WithFields(logrus.Fields{ + "room_id": w.roomID, + "event_id": inputRoomEvent.Event.EventID(), + "type": inputRoomEvent.Event.Type(), + }).Warn("Roomserver rejected event") + default: + if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) { + sentry.CaptureException(err) + } + logrus.WithError(err).WithFields(logrus.Fields{ + "room_id": w.roomID, + "event_id": inputRoomEvent.Event.EventID(), + "type": inputRoomEvent.Event.Type(), + }).Warn("Roomserver failed to process event") } - logrus.WithError(err).WithFields(logrus.Fields{ - "room_id": w.roomID, - "event_id": inputRoomEvent.Event.EventID(), - "type": inputRoomEvent.Event.Type(), - }).Warn("Roomserver failed to process async event") _ = msg.Term() errString = err.Error() } else { diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 53ccd5973..29af649ad 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -301,7 +301,7 @@ func (r *Inputer) processRoomEvent( // bother doing this if the event was already rejected as it just ends up // burning CPU time. historyVisibility := gomatrixserverlib.HistoryVisibilityShared // Default to shared. - if rejectionErr == nil && !isRejected && !softfail { + if input.Kind != api.KindOutlier && rejectionErr == nil && !isRejected { var err error historyVisibility, rejectionErr, err = r.processStateBefore(ctx, input, missingPrev) if err != nil { @@ -313,7 +313,7 @@ func (r *Inputer) processRoomEvent( } // Store the event. - _, _, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, authEventNIDs, isRejected || softfail) + _, _, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, authEventNIDs, isRejected) if err != nil { return fmt.Errorf("updater.StoreEvent: %w", err) } @@ -353,12 +353,18 @@ func (r *Inputer) processRoomEvent( } } - // We stop here if the event is rejected: We've stored it but won't update forward extremities or notify anyone about it. - if isRejected || softfail { - logger.WithError(rejectionErr).WithFields(logrus.Fields{ - "soft_fail": softfail, - "missing_prev": missingPrev, - }).Warn("Stored rejected event") + // We stop here if the event is rejected: We've stored it but won't update + // forward extremities or notify downstream components about it. + switch { + case isRejected: + logger.WithError(rejectionErr).Warn("Stored rejected event") + if rejectionErr != nil { + return types.RejectedError(rejectionErr.Error()) + } + return nil + + case softfail: + logger.WithError(rejectionErr).Warn("Stored soft-failed event") if rejectionErr != nil { return types.RejectedError(rejectionErr.Error()) } @@ -661,7 +667,7 @@ func (r *Inputer) calculateAndSetState( // We've been told what the state at the event is so we don't need to calculate it. // Check that those state events are in the database and store the state. var entries []types.StateEntry - if entries, err = r.DB.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil { + if entries, err = r.DB.StateEntriesForEventIDs(ctx, input.StateEventIDs, true); err != nil { return fmt.Errorf("updater.StateEntriesForEventIDs: %w", err) } entries = types.DeduplicateStateEntries(entries) diff --git a/roomserver/internal/perform/perform_backfill.go b/roomserver/internal/perform/perform_backfill.go index 298ba04f6..51c66415a 100644 --- a/roomserver/internal/perform/perform_backfill.go +++ b/roomserver/internal/perform/perform_backfill.go @@ -18,7 +18,6 @@ import ( "context" "fmt" - "github.com/getsentry/sentry-go" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/sirupsen/logrus" @@ -140,11 +139,11 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform continue } var entries []types.StateEntry - if entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs); err != nil { + if entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs, true); err != nil { // attempt to fetch the missing events r.fetchAndStoreMissingEvents(ctx, info.RoomVersion, requester, stateIDs) // try again - entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs) + entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs, true) if err != nil { logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to get state entries for event") return err @@ -164,6 +163,7 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform // TODO: update backwards extremities, as that should be moved from syncapi to roomserver at some point. res.Events = events + res.HistoryVisibility = requester.historyVisiblity return nil } @@ -248,6 +248,7 @@ type backfillRequester struct { servers []gomatrixserverlib.ServerName eventIDToBeforeStateIDs map[string][]string eventIDMap map[string]*gomatrixserverlib.Event + historyVisiblity gomatrixserverlib.HistoryVisibility } func newBackfillRequester( @@ -266,6 +267,7 @@ func newBackfillRequester( eventIDMap: make(map[string]*gomatrixserverlib.Event), bwExtrems: bwExtrems, preferServer: preferServer, + historyVisiblity: gomatrixserverlib.HistoryVisibilityShared, } } @@ -317,7 +319,6 @@ FederationHit: b.eventIDToBeforeStateIDs[targetEvent.EventID()] = res return res, nil } - sentry.CaptureException(lastErr) // temporary to see if we might need to raise the server limit return nil, lastErr } @@ -395,7 +396,6 @@ func (b *backfillRequester) StateBeforeEvent(ctx context.Context, roomVer gomatr } return result, nil } - sentry.CaptureException(lastErr) // temporary to see if we might need to raise the server limit return nil, lastErr } @@ -447,7 +447,8 @@ FindSuccessor: } // possibly return all joined servers depending on history visiblity - memberEventsFromVis, err := joinEventsFromHistoryVisibility(ctx, b.db, roomID, stateEntries, b.thisServer) + memberEventsFromVis, visibility, err := joinEventsFromHistoryVisibility(ctx, b.db, roomID, stateEntries, b.thisServer) + b.historyVisiblity = visibility if err != nil { logrus.WithError(err).Error("ServersAtEvent: failed calculate servers from history visibility rules") return nil @@ -528,7 +529,7 @@ func (b *backfillRequester) ProvideEvents(roomVer gomatrixserverlib.RoomVersion, // pull all events and then filter by that table. func joinEventsFromHistoryVisibility( ctx context.Context, db storage.Database, roomID string, stateEntries []types.StateEntry, - thisServer gomatrixserverlib.ServerName) ([]types.Event, error) { + thisServer gomatrixserverlib.ServerName) ([]types.Event, gomatrixserverlib.HistoryVisibility, error) { var eventNIDs []types.EventNID for _, entry := range stateEntries { @@ -542,7 +543,9 @@ func joinEventsFromHistoryVisibility( // Get all of the events in this state stateEvents, err := db.Events(ctx, eventNIDs) if err != nil { - return nil, err + // even though the default should be shared, restricting the visibility to joined + // feels more secure here. + return nil, gomatrixserverlib.HistoryVisibilityJoined, err } events := make([]*gomatrixserverlib.Event, len(stateEvents)) for i := range stateEvents { @@ -551,20 +554,22 @@ func joinEventsFromHistoryVisibility( // Can we see events in the room? canSeeEvents := auth.IsServerAllowed(thisServer, true, events) + visibility := gomatrixserverlib.HistoryVisibility(auth.HistoryVisibilityForRoom(events)) if !canSeeEvents { - logrus.Infof("ServersAtEvent history not visible to us: %s", auth.HistoryVisibilityForRoom(events)) - return nil, nil + logrus.Infof("ServersAtEvent history not visible to us: %s", visibility) + return nil, visibility, nil } // get joined members info, err := db.RoomInfo(ctx, roomID) if err != nil { - return nil, err + return nil, visibility, nil } joinEventNIDs, err := db.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, true, false) if err != nil { - return nil, err + return nil, visibility, err } - return db.Events(ctx, joinEventNIDs) + evs, err := db.Events(ctx, joinEventNIDs) + return evs, visibility, err } func persistEvents(ctx context.Context, db storage.Database, events []*gomatrixserverlib.HeaderedEvent) (types.RoomNID, map[string]types.Event) { diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index f5d8c2d49..6dce2bc3e 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -503,10 +503,11 @@ func (r *Queryer) QueryStateAndAuthChain( } var stateEvents []*gomatrixserverlib.Event - stateEvents, rejected, err := r.loadStateAtEventIDs(ctx, info, request.PrevEventIDs) + stateEvents, rejected, stateMissing, err := r.loadStateAtEventIDs(ctx, info, request.PrevEventIDs) if err != nil { return err } + response.StateKnown = !stateMissing response.IsRejected = rejected response.PrevEventsExist = true @@ -542,15 +543,18 @@ func (r *Queryer) QueryStateAndAuthChain( return err } -func (r *Queryer) loadStateAtEventIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]*gomatrixserverlib.Event, bool, error) { +// first bool: is rejected, second bool: state missing +func (r *Queryer) loadStateAtEventIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]*gomatrixserverlib.Event, bool, bool, error) { roomState := state.NewStateResolution(r.DB, roomInfo) prevStates, err := r.DB.StateAtEventIDs(ctx, eventIDs) if err != nil { switch err.(type) { case types.MissingEventError: - return nil, false, nil + return nil, false, true, nil + case types.MissingStateError: + return nil, false, true, nil default: - return nil, false, err + return nil, false, false, err } } // Currently only used on /state and /state_ids @@ -567,12 +571,11 @@ func (r *Queryer) loadStateAtEventIDs(ctx context.Context, roomInfo *types.RoomI ctx, prevStates, ) if err != nil { - return nil, rejected, err + return nil, rejected, false, err } events, err := helpers.LoadStateEvents(ctx, r.DB, stateEntries) - - return events, rejected, err + return events, rejected, false, err } type eventsFromIDs func(context.Context, []string) ([]types.Event, error) diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 5c068873a..43e8da7bb 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -79,7 +79,7 @@ type Database interface { // Look up the state entries for a list of string event IDs // Returns an error if the there is an error talking to the database // Returns a types.MissingEventError if the event IDs aren't in the database. - StateEntriesForEventIDs(ctx context.Context, eventIDs []string) ([]types.StateEntry, error) + StateEntriesForEventIDs(ctx context.Context, eventIDs []string, excludeRejected bool) ([]types.StateEntry, error) // Look up the string event state keys for a list of numeric event state keys // Returns an error if there was a problem talking to the database. EventStateKeys(ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error) diff --git a/roomserver/storage/postgres/events_table.go b/roomserver/storage/postgres/events_table.go index c7748d2be..1e7ca7669 100644 --- a/roomserver/storage/postgres/events_table.go +++ b/roomserver/storage/postgres/events_table.go @@ -74,7 +74,7 @@ const insertEventSQL = "" + "INSERT INTO roomserver_events AS e (room_nid, event_type_nid, event_state_key_nid, event_id, reference_sha256, auth_event_nids, depth, is_rejected)" + " VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" + " ON CONFLICT ON CONSTRAINT roomserver_event_id_unique DO UPDATE" + - " SET is_rejected = $8 WHERE e.event_id = $4 AND e.is_rejected = FALSE" + + " SET is_rejected = $8 WHERE e.event_id = $4 AND e.is_rejected = TRUE" + " RETURNING event_nid, state_snapshot_nid" const selectEventSQL = "" + @@ -88,6 +88,14 @@ const bulkSelectStateEventByIDSQL = "" + " WHERE event_id = ANY($1)" + " ORDER BY event_type_nid, event_state_key_nid ASC" +// Bulk lookup of events by string ID that aren't excluded. +// Sort by the numeric IDs for event type and state key. +// This means we can use binary search to lookup entries by type and state key. +const bulkSelectStateEventByIDExcludingRejectedSQL = "" + + "SELECT event_type_nid, event_state_key_nid, event_nid FROM roomserver_events" + + " WHERE event_id = ANY($1) AND is_rejected = FALSE" + + " ORDER BY event_type_nid, event_state_key_nid ASC" + // Bulk look up of events by event NID, optionally filtering by the event type // or event state key NIDs if provided. (The CARDINALITY check will return true // if the provided arrays are empty, ergo no filtering). @@ -140,23 +148,24 @@ const selectEventRejectedSQL = "" + "SELECT is_rejected FROM roomserver_events WHERE room_nid = $1 AND event_id = $2" type eventStatements struct { - insertEventStmt *sql.Stmt - selectEventStmt *sql.Stmt - bulkSelectStateEventByIDStmt *sql.Stmt - bulkSelectStateEventByNIDStmt *sql.Stmt - bulkSelectStateAtEventByIDStmt *sql.Stmt - updateEventStateStmt *sql.Stmt - selectEventSentToOutputStmt *sql.Stmt - updateEventSentToOutputStmt *sql.Stmt - selectEventIDStmt *sql.Stmt - bulkSelectStateAtEventAndReferenceStmt *sql.Stmt - bulkSelectEventReferenceStmt *sql.Stmt - bulkSelectEventIDStmt *sql.Stmt - bulkSelectEventNIDStmt *sql.Stmt - bulkSelectUnsentEventNIDStmt *sql.Stmt - selectMaxEventDepthStmt *sql.Stmt - selectRoomNIDsForEventNIDsStmt *sql.Stmt - selectEventRejectedStmt *sql.Stmt + insertEventStmt *sql.Stmt + selectEventStmt *sql.Stmt + bulkSelectStateEventByIDStmt *sql.Stmt + bulkSelectStateEventByIDExcludingRejectedStmt *sql.Stmt + bulkSelectStateEventByNIDStmt *sql.Stmt + bulkSelectStateAtEventByIDStmt *sql.Stmt + updateEventStateStmt *sql.Stmt + selectEventSentToOutputStmt *sql.Stmt + updateEventSentToOutputStmt *sql.Stmt + selectEventIDStmt *sql.Stmt + bulkSelectStateAtEventAndReferenceStmt *sql.Stmt + bulkSelectEventReferenceStmt *sql.Stmt + bulkSelectEventIDStmt *sql.Stmt + bulkSelectEventNIDStmt *sql.Stmt + bulkSelectUnsentEventNIDStmt *sql.Stmt + selectMaxEventDepthStmt *sql.Stmt + selectRoomNIDsForEventNIDsStmt *sql.Stmt + selectEventRejectedStmt *sql.Stmt } func CreateEventsTable(db *sql.DB) error { @@ -171,6 +180,7 @@ func PrepareEventsTable(db *sql.DB) (tables.Events, error) { {&s.insertEventStmt, insertEventSQL}, {&s.selectEventStmt, selectEventSQL}, {&s.bulkSelectStateEventByIDStmt, bulkSelectStateEventByIDSQL}, + {&s.bulkSelectStateEventByIDExcludingRejectedStmt, bulkSelectStateEventByIDExcludingRejectedSQL}, {&s.bulkSelectStateEventByNIDStmt, bulkSelectStateEventByNIDSQL}, {&s.bulkSelectStateAtEventByIDStmt, bulkSelectStateAtEventByIDSQL}, {&s.updateEventStateStmt, updateEventStateSQL}, @@ -221,11 +231,18 @@ func (s *eventStatements) SelectEvent( } // bulkSelectStateEventByID lookups a list of state events by event ID. -// If any of the requested events are missing from the database it returns a types.MissingEventError +// If not excluding rejected events, and any of the requested events are missing from +// the database it returns a types.MissingEventError. If excluding rejected events, +// the events will be silently omitted without error. func (s *eventStatements) BulkSelectStateEventByID( - ctx context.Context, txn *sql.Tx, eventIDs []string, + ctx context.Context, txn *sql.Tx, eventIDs []string, excludeRejected bool, ) ([]types.StateEntry, error) { - stmt := sqlutil.TxStmt(txn, s.bulkSelectStateEventByIDStmt) + var stmt *sql.Stmt + if excludeRejected { + stmt = sqlutil.TxStmt(txn, s.bulkSelectStateEventByIDExcludingRejectedStmt) + } else { + stmt = sqlutil.TxStmt(txn, s.bulkSelectStateEventByIDStmt) + } rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs)) if err != nil { return nil, err @@ -235,10 +252,10 @@ func (s *eventStatements) BulkSelectStateEventByID( // because of the unique constraint on event IDs. // So we can allocate an array of the correct size now. // We might get fewer results than IDs so we adjust the length of the slice before returning it. - results := make([]types.StateEntry, len(eventIDs)) + results := make([]types.StateEntry, 0, len(eventIDs)) i := 0 for ; rows.Next(); i++ { - result := &results[i] + var result types.StateEntry if err = rows.Scan( &result.EventTypeNID, &result.EventStateKeyNID, @@ -246,11 +263,12 @@ func (s *eventStatements) BulkSelectStateEventByID( ); err != nil { return nil, err } + results = append(results, result) } if err = rows.Err(); err != nil { return nil, err } - if i != len(eventIDs) { + if !excludeRejected && i != len(eventIDs) { // If there are fewer rows returned than IDs then we were asked to lookup event IDs we don't have. // We don't know which ones were missing because we don't return the string IDs in the query. // However it should be possible debug this by replaying queries or entries from the input kafka logs. @@ -328,7 +346,7 @@ func (s *eventStatements) BulkSelectStateAtEventByID( // Genuine create events are the only case where it's OK to have no previous state. isCreate := result.EventTypeNID == types.MRoomCreateNID && result.EventStateKeyNID == 1 if result.BeforeStateSnapshotNID == 0 && !isCreate { - return nil, types.MissingEventError( + return nil, types.MissingStateError( fmt.Sprintf("storage: missing state for event NID %d", result.EventNID), ) } diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 4f92adf1f..f35592a76 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -113,9 +113,9 @@ func (d *Database) eventStateKeyNIDs( } func (d *Database) StateEntriesForEventIDs( - ctx context.Context, eventIDs []string, + ctx context.Context, eventIDs []string, excludeRejected bool, ) ([]types.StateEntry, error) { - return d.EventsTable.BulkSelectStateEventByID(ctx, nil, eventIDs) + return d.EventsTable.BulkSelectStateEventByID(ctx, nil, eventIDs, excludeRejected) } func (d *Database) StateEntriesForTuples( diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go index 174e3a9a7..950d03b03 100644 --- a/roomserver/storage/sqlite3/events_table.go +++ b/roomserver/storage/sqlite3/events_table.go @@ -50,7 +50,7 @@ const insertEventSQL = ` INSERT INTO roomserver_events (room_nid, event_type_nid, event_state_key_nid, event_id, reference_sha256, auth_event_nids, depth, is_rejected) VALUES ($1, $2, $3, $4, $5, $6, $7, $8) ON CONFLICT DO UPDATE - SET is_rejected = $8 WHERE is_rejected = 0 + SET is_rejected = $8 WHERE is_rejected = 1 RETURNING event_nid, state_snapshot_nid; ` @@ -65,6 +65,14 @@ const bulkSelectStateEventByIDSQL = "" + " WHERE event_id IN ($1)" + " ORDER BY event_type_nid, event_state_key_nid ASC" +// Bulk lookup of events by string ID that aren't rejected. +// Sort by the numeric IDs for event type and state key. +// This means we can use binary search to lookup entries by type and state key. +const bulkSelectStateEventByIDExcludingRejectedSQL = "" + + "SELECT event_type_nid, event_state_key_nid, event_nid FROM roomserver_events" + + " WHERE event_id IN ($1) AND is_rejected = 0" + + " ORDER BY event_type_nid, event_state_key_nid ASC" + const bulkSelectStateEventByNIDSQL = "" + "SELECT event_type_nid, event_state_key_nid, event_nid FROM roomserver_events" + " WHERE event_nid IN ($1)" @@ -113,19 +121,20 @@ const selectEventRejectedSQL = "" + "SELECT is_rejected FROM roomserver_events WHERE room_nid = $1 AND event_id = $2" type eventStatements struct { - db *sql.DB - insertEventStmt *sql.Stmt - selectEventStmt *sql.Stmt - bulkSelectStateEventByIDStmt *sql.Stmt - bulkSelectStateAtEventByIDStmt *sql.Stmt - updateEventStateStmt *sql.Stmt - selectEventSentToOutputStmt *sql.Stmt - updateEventSentToOutputStmt *sql.Stmt - selectEventIDStmt *sql.Stmt - bulkSelectStateAtEventAndReferenceStmt *sql.Stmt - bulkSelectEventReferenceStmt *sql.Stmt - bulkSelectEventIDStmt *sql.Stmt - selectEventRejectedStmt *sql.Stmt + db *sql.DB + insertEventStmt *sql.Stmt + selectEventStmt *sql.Stmt + bulkSelectStateEventByIDStmt *sql.Stmt + bulkSelectStateEventByIDExcludingRejectedStmt *sql.Stmt + bulkSelectStateAtEventByIDStmt *sql.Stmt + updateEventStateStmt *sql.Stmt + selectEventSentToOutputStmt *sql.Stmt + updateEventSentToOutputStmt *sql.Stmt + selectEventIDStmt *sql.Stmt + bulkSelectStateAtEventAndReferenceStmt *sql.Stmt + bulkSelectEventReferenceStmt *sql.Stmt + bulkSelectEventIDStmt *sql.Stmt + selectEventRejectedStmt *sql.Stmt //bulkSelectEventNIDStmt *sql.Stmt //bulkSelectUnsentEventNIDStmt *sql.Stmt //selectRoomNIDsForEventNIDsStmt *sql.Stmt @@ -145,6 +154,7 @@ func PrepareEventsTable(db *sql.DB) (tables.Events, error) { {&s.insertEventStmt, insertEventSQL}, {&s.selectEventStmt, selectEventSQL}, {&s.bulkSelectStateEventByIDStmt, bulkSelectStateEventByIDSQL}, + {&s.bulkSelectStateEventByIDExcludingRejectedStmt, bulkSelectStateEventByIDExcludingRejectedSQL}, {&s.bulkSelectStateAtEventByIDStmt, bulkSelectStateAtEventByIDSQL}, {&s.updateEventStateStmt, updateEventStateSQL}, {&s.updateEventSentToOutputStmt, updateEventSentToOutputSQL}, @@ -194,16 +204,24 @@ func (s *eventStatements) SelectEvent( } // bulkSelectStateEventByID lookups a list of state events by event ID. -// If any of the requested events are missing from the database it returns a types.MissingEventError +// If not excluding rejected events, and any of the requested events are missing from +// the database it returns a types.MissingEventError. If excluding rejected events, +// the events will be silently omitted without error. func (s *eventStatements) BulkSelectStateEventByID( - ctx context.Context, txn *sql.Tx, eventIDs []string, + ctx context.Context, txn *sql.Tx, eventIDs []string, excludeRejected bool, ) ([]types.StateEntry, error) { /////////////// + var sql string + if excludeRejected { + sql = bulkSelectStateEventByIDExcludingRejectedSQL + } else { + sql = bulkSelectStateEventByIDSQL + } iEventIDs := make([]interface{}, len(eventIDs)) for k, v := range eventIDs { iEventIDs[k] = v } - selectOrig := strings.Replace(bulkSelectStateEventByIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1) + selectOrig := strings.Replace(sql, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1) selectPrep, err := s.db.Prepare(selectOrig) if err != nil { return nil, err @@ -221,10 +239,10 @@ func (s *eventStatements) BulkSelectStateEventByID( // because of the unique constraint on event IDs. // So we can allocate an array of the correct size now. // We might get fewer results than IDs so we adjust the length of the slice before returning it. - results := make([]types.StateEntry, len(eventIDs)) + results := make([]types.StateEntry, 0, len(eventIDs)) i := 0 for ; rows.Next(); i++ { - result := &results[i] + var result types.StateEntry if err = rows.Scan( &result.EventTypeNID, &result.EventStateKeyNID, @@ -232,8 +250,9 @@ func (s *eventStatements) BulkSelectStateEventByID( ); err != nil { return nil, err } + results = append(results, result) } - if i != len(eventIDs) { + if !excludeRejected && i != len(eventIDs) { // If there are fewer rows returned than IDs then we were asked to lookup event IDs we don't have. // We don't know which ones were missing because we don't return the string IDs in the query. // However it should be possible debug this by replaying queries or entries from the input kafka logs. @@ -343,7 +362,7 @@ func (s *eventStatements) BulkSelectStateAtEventByID( // Genuine create events are the only case where it's OK to have no previous state. isCreate := result.EventTypeNID == types.MRoomCreateNID && result.EventStateKeyNID == 1 if result.BeforeStateSnapshotNID == 0 && !isCreate { - return nil, types.MissingEventError( + return nil, types.MissingStateError( fmt.Sprintf("storage: missing state for event NID %d", result.EventNID), ) } diff --git a/roomserver/storage/tables/events_table_test.go b/roomserver/storage/tables/events_table_test.go index 74502a31f..107af4784 100644 --- a/roomserver/storage/tables/events_table_test.go +++ b/roomserver/storage/tables/events_table_test.go @@ -102,7 +102,7 @@ func Test_EventsTable(t *testing.T) { }) } - stateEvents, err := tab.BulkSelectStateEventByID(ctx, nil, eventIDs) + stateEvents, err := tab.BulkSelectStateEventByID(ctx, nil, eventIDs, false) assert.NoError(t, err) assert.Equal(t, len(stateEvents), len(eventIDs)) nids := make([]types.EventNID, 0, len(stateEvents)) diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index ed67c43d8..68d30f994 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -46,7 +46,7 @@ type Events interface { SelectEvent(ctx context.Context, txn *sql.Tx, eventID string) (types.EventNID, types.StateSnapshotNID, error) // bulkSelectStateEventByID lookups a list of state events by event ID. // If any of the requested events are missing from the database it returns a types.MissingEventError - BulkSelectStateEventByID(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StateEntry, error) + BulkSelectStateEventByID(ctx context.Context, txn *sql.Tx, eventIDs []string, excludeRejected bool) ([]types.StateEntry, error) BulkSelectStateEventByNID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntry, error) // BulkSelectStateAtEventByID lookups the state at a list of events by event ID. // If any of the requested events are missing from the database it returns a types.MissingEventError. diff --git a/setup/base/base.go b/setup/base/base.go index b21eeba47..87f415764 100644 --- a/setup/base/base.go +++ b/setup/base/base.go @@ -25,21 +25,23 @@ import ( _ "net/http/pprof" "os" "os/signal" + "sync" "syscall" "time" "github.com/getsentry/sentry-go" sentryhttp "github.com/getsentry/sentry-go/http" - "github.com/matrix-org/dendrite/internal/caching" - "github.com/matrix-org/dendrite/internal/httputil" - "github.com/matrix-org/dendrite/internal/pushgateway" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" "github.com/prometheus/client_golang/prometheus/promhttp" "go.uber.org/atomic" "golang.org/x/net/http2" "golang.org/x/net/http2/h2c" + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/internal/pushgateway" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" @@ -47,6 +49,8 @@ import ( "github.com/gorilla/mux" "github.com/kardianos/minwinsvc" + "github.com/sirupsen/logrus" + appserviceAPI "github.com/matrix-org/dendrite/appservice/api" asinthttp "github.com/matrix-org/dendrite/appservice/inthttp" federationAPI "github.com/matrix-org/dendrite/federationapi/api" @@ -58,7 +62,6 @@ import ( "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" userapiinthttp "github.com/matrix-org/dendrite/userapi/inthttp" - "github.com/sirupsen/logrus" ) // BaseDendrite is a base for creating new instances of dendrite. It parses @@ -87,6 +90,7 @@ type BaseDendrite struct { Database *sql.DB DatabaseWriter sqlutil.Writer EnableMetrics bool + startupLock sync.Mutex } const NoListener = "" @@ -394,6 +398,9 @@ func (b *BaseDendrite) SetupAndServeHTTP( internalHTTPAddr, externalHTTPAddr config.HTTPAddress, certFile, keyFile *string, ) { + // Manually unlocked right before actually serving requests, + // as we don't return from this method (defer doesn't work). + b.startupLock.Lock() internalAddr, _ := internalHTTPAddr.Address() externalAddr, _ := externalHTTPAddr.Address() @@ -472,6 +479,7 @@ func (b *BaseDendrite) SetupAndServeHTTP( externalRouter.PathPrefix(httputil.PublicMediaPathPrefix).Handler(b.PublicMediaAPIMux) externalRouter.PathPrefix(httputil.PublicWellKnownPrefix).Handler(b.PublicWellKnownAPIMux) + b.startupLock.Unlock() if internalAddr != NoListener && internalAddr != externalAddr { go func() { var internalShutdown atomic.Bool // RegisterOnShutdown can be called more than once diff --git a/setup/config/config.go b/setup/config/config.go index 924b51f22..cc9c04470 100644 --- a/setup/config/config.go +++ b/setup/config/config.go @@ -224,12 +224,7 @@ func loadConfig( } privateKeyPath := absPath(basePath, c.Global.PrivateKeyPath) - privateKeyData, err := readFile(privateKeyPath) - if err != nil { - return nil, err - } - - if c.Global.KeyID, c.Global.PrivateKey, err = readKeyPEM(privateKeyPath, privateKeyData, true); err != nil { + if c.Global.KeyID, c.Global.PrivateKey, err = LoadMatrixKey(privateKeyPath, readFile); err != nil { return nil, err } @@ -265,6 +260,14 @@ func loadConfig( return &c, nil } +func LoadMatrixKey(privateKeyPath string, readFile func(string) ([]byte, error)) (gomatrixserverlib.KeyID, ed25519.PrivateKey, error) { + privateKeyData, err := readFile(privateKeyPath) + if err != nil { + return "", nil, err + } + return readKeyPEM(privateKeyPath, privateKeyData, true) +} + // Derive generates data that is derived from various values provided in // the config file. func (config *Dendrite) Derive() error { diff --git a/setup/config/config_appservice.go b/setup/config/config_appservice.go index b8f99a612..9c3771272 100644 --- a/setup/config/config_appservice.go +++ b/setup/config/config_appservice.go @@ -31,8 +31,6 @@ type AppServiceAPI struct { InternalAPI InternalAPIOptions `yaml:"internal_api"` - Database DatabaseOptions `yaml:"database"` - // DisableTLSValidation disables the validation of X.509 TLS certs // on appservice endpoints. This is not recommended in production! DisableTLSValidation bool `yaml:"disable_tls_validation"` @@ -43,16 +41,9 @@ type AppServiceAPI struct { func (c *AppServiceAPI) Defaults(generate bool) { c.InternalAPI.Listen = "http://localhost:7777" c.InternalAPI.Connect = "http://localhost:7777" - c.Database.Defaults(5) - if generate { - c.Database.ConnectionString = "file:appservice.db" - } } func (c *AppServiceAPI) Verify(configErrs *ConfigErrors, isMonolith bool) { - if c.Matrix.DatabaseOptions.ConnectionString == "" { - checkNotEmpty(configErrs, "app_service_api.database.connection_string", string(c.Database.ConnectionString)) - } if isMonolith { // polylith required configs below return } diff --git a/setup/jetstream/helpers.go b/setup/jetstream/helpers.go index 1c07583e9..1ec860b04 100644 --- a/setup/jetstream/helpers.go +++ b/setup/jetstream/helpers.go @@ -9,9 +9,16 @@ import ( "github.com/sirupsen/logrus" ) +// JetStreamConsumer starts a durable consumer on the given subject with the +// given durable name. The function will be called when one or more messages +// is available, up to the maximum batch size specified. If the batch is set to +// 1 then messages will be delivered one at a time. If the function is called, +// the messages array is guaranteed to be at least 1 in size. Any provided NATS +// options will be passed through to the pull subscriber creation. The consumer +// will continue to run until the context expires, at which point it will stop. func JetStreamConsumer( - ctx context.Context, js nats.JetStreamContext, subj, durable string, - f func(ctx context.Context, msg *nats.Msg) bool, + ctx context.Context, js nats.JetStreamContext, subj, durable string, batch int, + f func(ctx context.Context, msgs []*nats.Msg) bool, opts ...nats.SubOpt, ) error { defer func() { @@ -50,7 +57,7 @@ func JetStreamConsumer( // enforce its own deadline (roughly 5 seconds by default). Therefore // it is our responsibility to check whether our context expired or // not when a context error is returned. Footguns. Footguns everywhere. - msgs, err := sub.Fetch(1, nats.Context(ctx)) + msgs, err := sub.Fetch(batch, nats.Context(ctx)) if err != nil { if err == context.Canceled || err == context.DeadlineExceeded { // Work out whether it was the JetStream context that expired @@ -74,21 +81,26 @@ func JetStreamConsumer( if len(msgs) < 1 { continue } - msg := msgs[0] - if err = msg.InProgress(nats.Context(ctx)); err != nil { - logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.InProgress: %w", err)) - sentry.CaptureException(err) - continue - } - if f(ctx, msg) { - if err = msg.AckSync(nats.Context(ctx)); err != nil { - logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.AckSync: %w", err)) + for _, msg := range msgs { + if err = msg.InProgress(nats.Context(ctx)); err != nil { + logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.InProgress: %w", err)) sentry.CaptureException(err) + continue + } + } + if f(ctx, msgs) { + for _, msg := range msgs { + if err = msg.AckSync(nats.Context(ctx)); err != nil { + logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.AckSync: %w", err)) + sentry.CaptureException(err) + } } } else { - if err = msg.Nak(nats.Context(ctx)); err != nil { - logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.Nak: %w", err)) - sentry.CaptureException(err) + for _, msg := range msgs { + if err = msg.Nak(nats.Context(ctx)); err != nil { + logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.Nak: %w", err)) + sentry.CaptureException(err) + } } } } diff --git a/setup/jetstream/nats.go b/setup/jetstream/nats.go index 051d55a35..3660e91e3 100644 --- a/setup/jetstream/nats.go +++ b/setup/jetstream/nats.go @@ -183,6 +183,7 @@ func setupNATS(process *process.ProcessContext, cfg *config.JetStream, nc *natsc OutputReceiptEvent: {"SyncAPIEDUServerReceiptConsumer", "FederationAPIEDUServerConsumer"}, OutputSendToDeviceEvent: {"SyncAPIEDUServerSendToDeviceConsumer", "FederationAPIEDUServerConsumer"}, OutputTypingEvent: {"SyncAPIEDUServerTypingConsumer", "FederationAPIEDUServerConsumer"}, + OutputRoomEvent: {"AppserviceRoomserverConsumer"}, } { streamName := cfg.Matrix.JetStream.Prefixed(stream) for _, consumer := range consumers { diff --git a/syncapi/consumers/clientapi.go b/syncapi/consumers/clientapi.go index 02633b567..f0588cab8 100644 --- a/syncapi/consumers/clientapi.go +++ b/syncapi/consumers/clientapi.go @@ -75,15 +75,16 @@ func NewOutputClientDataConsumer( // Start consuming from room servers func (s *OutputClientDataConsumer) Start() error { return jetstream.JetStreamConsumer( - s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, - nats.DeliverAll(), nats.ManualAck(), + s.ctx, s.jetstream, s.topic, s.durable, 1, + s.onMessage, nats.DeliverAll(), nats.ManualAck(), ) } // onMessage is called when the sync server receives a new event from the client API server output log. // It is not safe for this function to be called from multiple goroutines, or else the // sync stream position may race and be incorrectly calculated. -func (s *OutputClientDataConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { +func (s *OutputClientDataConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called // Parse out the event JSON userID := msg.Header.Get(jetstream.UserID) var output eventutil.AccountData diff --git a/syncapi/consumers/keychange.go b/syncapi/consumers/keychange.go index c8d88ddac..c42e71971 100644 --- a/syncapi/consumers/keychange.go +++ b/syncapi/consumers/keychange.go @@ -75,12 +75,13 @@ func NewOutputKeyChangeEventConsumer( // Start consuming from the key server func (s *OutputKeyChangeEventConsumer) Start() error { return jetstream.JetStreamConsumer( - s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, - nats.DeliverAll(), nats.ManualAck(), + s.ctx, s.jetstream, s.topic, s.durable, 1, + s.onMessage, nats.DeliverAll(), nats.ManualAck(), ) } -func (s *OutputKeyChangeEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { +func (s *OutputKeyChangeEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called var m api.DeviceMessage if err := json.Unmarshal(msg.Data, &m); err != nil { logrus.WithError(err).Errorf("failed to read device message from key change topic") diff --git a/syncapi/consumers/presence.go b/syncapi/consumers/presence.go index db7d67fa6..61bdc13de 100644 --- a/syncapi/consumers/presence.go +++ b/syncapi/consumers/presence.go @@ -128,12 +128,13 @@ func (s *PresenceConsumer) Start() error { return nil } return jetstream.JetStreamConsumer( - s.ctx, s.jetstream, s.presenceTopic, s.durable, s.onMessage, + s.ctx, s.jetstream, s.presenceTopic, s.durable, 1, s.onMessage, nats.DeliverAll(), nats.ManualAck(), nats.HeadersOnly(), ) } -func (s *PresenceConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { +func (s *PresenceConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called userID := msg.Header.Get(jetstream.UserID) presence := msg.Header.Get("presence") timestamp := msg.Header.Get("last_active_ts") diff --git a/syncapi/consumers/receipts.go b/syncapi/consumers/receipts.go index 83156cf93..a18244c44 100644 --- a/syncapi/consumers/receipts.go +++ b/syncapi/consumers/receipts.go @@ -74,12 +74,13 @@ func NewOutputReceiptEventConsumer( // Start consuming receipts events. func (s *OutputReceiptEventConsumer) Start() error { return jetstream.JetStreamConsumer( - s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, - nats.DeliverAll(), nats.ManualAck(), + s.ctx, s.jetstream, s.topic, s.durable, 1, + s.onMessage, nats.DeliverAll(), nats.ManualAck(), ) } -func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { +func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called output := types.OutputReceiptEvent{ UserID: msg.Header.Get(jetstream.UserID), RoomID: msg.Header.Get(jetstream.RoomID), diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index f77b1673b..6979eb484 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -79,15 +79,16 @@ func NewOutputRoomEventConsumer( // Start consuming from room servers func (s *OutputRoomEventConsumer) Start() error { return jetstream.JetStreamConsumer( - s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, - nats.DeliverAll(), nats.ManualAck(), + s.ctx, s.jetstream, s.topic, s.durable, 1, + s.onMessage, nats.DeliverAll(), nats.ManualAck(), ) } // onMessage is called when the sync server receives a new event from the room server output log. // It is not safe for this function to be called from multiple goroutines, or else the // sync stream position may race and be incorrectly calculated. -func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { +func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called // Parse out the event JSON var err error var output api.OutputEvent diff --git a/syncapi/consumers/sendtodevice.go b/syncapi/consumers/sendtodevice.go index 0b9153fcd..89b01d7e5 100644 --- a/syncapi/consumers/sendtodevice.go +++ b/syncapi/consumers/sendtodevice.go @@ -68,12 +68,13 @@ func NewOutputSendToDeviceEventConsumer( // Start consuming send-to-device events. func (s *OutputSendToDeviceEventConsumer) Start() error { return jetstream.JetStreamConsumer( - s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, - nats.DeliverAll(), nats.ManualAck(), + s.ctx, s.jetstream, s.topic, s.durable, 1, + s.onMessage, nats.DeliverAll(), nats.ManualAck(), ) } -func (s *OutputSendToDeviceEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { +func (s *OutputSendToDeviceEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called userID := msg.Header.Get(jetstream.UserID) _, domain, err := gomatrixserverlib.SplitID('@', userID) if err != nil { diff --git a/syncapi/consumers/typing.go b/syncapi/consumers/typing.go index 48e484ec5..88db80f8c 100644 --- a/syncapi/consumers/typing.go +++ b/syncapi/consumers/typing.go @@ -64,12 +64,13 @@ func NewOutputTypingEventConsumer( // Start consuming typing events. func (s *OutputTypingEventConsumer) Start() error { return jetstream.JetStreamConsumer( - s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, - nats.DeliverAll(), nats.ManualAck(), + s.ctx, s.jetstream, s.topic, s.durable, 1, + s.onMessage, nats.DeliverAll(), nats.ManualAck(), ) } -func (s *OutputTypingEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { +func (s *OutputTypingEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called roomID := msg.Header.Get(jetstream.RoomID) userID := msg.Header.Get(jetstream.UserID) typing, err := strconv.ParseBool(msg.Header.Get("typing")) diff --git a/syncapi/consumers/userapi.go b/syncapi/consumers/userapi.go index 010fa7c8e..227823522 100644 --- a/syncapi/consumers/userapi.go +++ b/syncapi/consumers/userapi.go @@ -67,8 +67,8 @@ func NewOutputNotificationDataConsumer( // Start starts consumption. func (s *OutputNotificationDataConsumer) Start() error { return jetstream.JetStreamConsumer( - s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, - nats.DeliverAll(), nats.ManualAck(), + s.ctx, s.jetstream, s.topic, s.durable, 1, + s.onMessage, nats.DeliverAll(), nats.ManualAck(), ) } @@ -76,7 +76,8 @@ func (s *OutputNotificationDataConsumer) Start() error { // the push server. It is not safe for this function to be called from // multiple goroutines, or else the sync stream position may race and // be incorrectly calculated. -func (s *OutputNotificationDataConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { +func (s *OutputNotificationDataConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called userID := string(msg.Header.Get(jetstream.UserID)) // Parse out the event JSON diff --git a/syncapi/internal/keychange.go b/syncapi/internal/keychange.go index 23824e366..3d6b2a7f3 100644 --- a/syncapi/internal/keychange.go +++ b/syncapi/internal/keychange.go @@ -18,14 +18,16 @@ import ( "context" "strings" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" + keyapi "github.com/matrix-org/dendrite/keyserver/api" keytypes "github.com/matrix-org/dendrite/keyserver/types" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" - "github.com/sirupsen/logrus" ) // DeviceOTKCounts adds one-time key counts to the /sync response @@ -125,7 +127,7 @@ func DeviceListCatchup( "from": offset, "to": toOffset, "response_offset": queryRes.Offset, - }).Debugf("QueryKeyChanges request result: %+v", res.DeviceLists) + }).Tracef("QueryKeyChanges request result: %+v", res.DeviceLists) return types.StreamPosition(queryRes.Offset), hasNew, nil } @@ -277,6 +279,10 @@ func membershipEventPresent(events []gomatrixserverlib.ClientEvent, userID strin // it's enough to know that we have our member event here, don't need to check membership content // as it's implied by being in the respective section of the sync response. if ev.Type == gomatrixserverlib.MRoomMember && ev.StateKey != nil && *ev.StateKey == userID { + // ignore e.g. join -> join changes + if gjson.GetBytes(ev.Unsigned, "prev_content.membership").Str == gjson.GetBytes(ev.Content, "membership").Str { + continue + } return true } } diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go index 9db3d8e17..03614302c 100644 --- a/syncapi/routing/messages.go +++ b/syncapi/routing/messages.go @@ -350,8 +350,10 @@ func (r *messagesReq) retrieveEvents() ( startTime := time.Now() filteredEvents, err := internal.ApplyHistoryVisibilityFilter(r.ctx, r.db, r.rsAPI, events, nil, r.device.UserID, "messages") logrus.WithFields(logrus.Fields{ - "duration": time.Since(startTime), - "room_id": r.roomID, + "duration": time.Since(startTime), + "room_id": r.roomID, + "events_before": len(events), + "events_after": len(filteredEvents), }).Debug("applied history visibility (messages)") return gomatrixserverlib.HeaderedToClientEvents(filteredEvents, gomatrixserverlib.FormatAll), start, end, err } @@ -513,6 +515,9 @@ func (r *messagesReq) backfill(roomID string, backwardsExtremities map[string][] // Store the events in the database, while marking them as unfit to show // up in responses to sync requests. + if res.HistoryVisibility == "" { + res.HistoryVisibility = gomatrixserverlib.HistoryVisibilityShared + } for i := range res.Events { _, err = r.db.WriteEvent( context.Background(), @@ -521,7 +526,7 @@ func (r *messagesReq) backfill(roomID string, backwardsExtremities map[string][] []string{}, []string{}, nil, true, - gomatrixserverlib.HistoryVisibilityShared, + res.HistoryVisibility, ) if err != nil { return nil, err @@ -534,6 +539,9 @@ func (r *messagesReq) backfill(roomID string, backwardsExtremities map[string][] // last `limit` events events = events[len(events)-limit:] } + for _, ev := range events { + ev.Visibility = res.HistoryVisibility + } return events, nil } diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index 43a75da95..0c8ba4e3d 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -19,10 +19,11 @@ import ( "github.com/matrix-org/dendrite/internal/eventutil" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" ) type Database interface { diff --git a/syncapi/storage/postgres/send_to_device_table.go b/syncapi/storage/postgres/send_to_device_table.go index fd0c1c56b..6ab1f0f48 100644 --- a/syncapi/storage/postgres/send_to_device_table.go +++ b/syncapi/storage/postgres/send_to_device_table.go @@ -41,6 +41,8 @@ CREATE TABLE IF NOT EXISTS syncapi_send_to_device ( -- The event content JSON. content TEXT NOT NULL ); + +CREATE INDEX IF NOT EXISTS syncapi_send_to_device_user_id_device_id_idx ON syncapi_send_to_device(user_id, device_id); ` const insertSendToDeviceMessageSQL = ` diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index a46e55256..b06d2c6a9 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -20,15 +20,18 @@ import ( "encoding/json" "fmt" + "github.com/tidwall/gjson" + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" + "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrixserverlib" - "github.com/sirupsen/logrus" ) // Database is a temporary struct until we have made syncserver.go the same for both pq/sqlite @@ -683,7 +686,7 @@ func (d *Database) GetStateDeltas( ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter, -) ([]types.StateDelta, []string, error) { +) (deltas []types.StateDelta, joinedRoomsIDs []string, err error) { // Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821 // - Get membership list changes for this user in this sync response // - For each room which has membership list changes: @@ -718,8 +721,6 @@ func (d *Database) GetStateDeltas( } } - var deltas []types.StateDelta - // get all the state events ever (i.e. for all available rooms) between these two positions stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, txn, r, stateFilter, allRoomIDs) if err != nil { @@ -767,15 +768,11 @@ func (d *Database) GetStateDeltas( } // handle newly joined rooms and non-joined rooms + newlyJoinedRooms := make(map[string]bool, len(state)) for roomID, stateStreamEvents := range state { for _, ev := range stateStreamEvents { - // TODO: Currently this will incorrectly add rooms which were ALREADY joined but they sent another no-op join event. - // We should be checking if the user was already joined at fromPos and not proceed if so. As a result of this, - // dupe join events will result in the entire room state coming down to the client again. This is added in - // the 'state' part of the response though, so is transparent modulo bandwidth concerns as it is not added to - // the timeline. - if membership := getMembershipFromEvent(ev.Event, userID); membership != "" { - if membership == gomatrixserverlib.Join { + if membership, prevMembership := getMembershipFromEvent(ev.Event, userID); membership != "" { + if membership == gomatrixserverlib.Join && prevMembership != membership { // send full room state down instead of a delta var s []types.StreamEvent s, err = d.currentStateStreamEventsForRoom(ctx, txn, roomID, stateFilter) @@ -786,6 +783,7 @@ func (d *Database) GetStateDeltas( return nil, nil, err } state[roomID] = s + newlyJoinedRooms[roomID] = true continue // we'll add this room in when we do joined rooms } @@ -806,6 +804,7 @@ func (d *Database) GetStateDeltas( Membership: gomatrixserverlib.Join, StateEvents: d.StreamEventsToEvents(device, state[joinedRoomID]), RoomID: joinedRoomID, + NewlyJoined: newlyJoinedRooms[joinedRoomID], }) } @@ -892,7 +891,7 @@ func (d *Database) GetStateDeltasForFullStateSync( for roomID, stateStreamEvents := range state { for _, ev := range stateStreamEvents { - if membership := getMembershipFromEvent(ev.Event, userID); membership != "" { + if membership, _ := getMembershipFromEvent(ev.Event, userID); membership != "" { if membership != gomatrixserverlib.Join { // We've already added full state for all joined rooms above. deltas[roomID] = types.StateDelta{ Membership: membership, @@ -1003,15 +1002,16 @@ func (d *Database) CleanSendToDeviceUpdates( // getMembershipFromEvent returns the value of content.membership iff the event is a state event // with type 'm.room.member' and state_key of userID. Otherwise, an empty string is returned. -func getMembershipFromEvent(ev *gomatrixserverlib.Event, userID string) string { +func getMembershipFromEvent(ev *gomatrixserverlib.Event, userID string) (string, string) { if ev.Type() != "m.room.member" || !ev.StateKeyEquals(userID) { - return "" + return "", "" } membership, err := ev.Membership() if err != nil { - return "" + return "", "" } - return membership + prevMembership := gjson.GetBytes(ev.Unsigned(), "prev_content.membership").Str + return membership, prevMembership } // StoreReceipt stores user receipts diff --git a/syncapi/storage/sqlite3/send_to_device_table.go b/syncapi/storage/sqlite3/send_to_device_table.go index e3aa1b7a1..0da06506c 100644 --- a/syncapi/storage/sqlite3/send_to_device_table.go +++ b/syncapi/storage/sqlite3/send_to_device_table.go @@ -39,6 +39,8 @@ CREATE TABLE IF NOT EXISTS syncapi_send_to_device ( -- The event content JSON. content TEXT NOT NULL ); + +CREATE INDEX IF NOT EXISTS syncapi_send_to_device_user_id_device_id_idx ON syncapi_send_to_device(user_id, device_id); ` const insertSendToDeviceMessageSQL = ` diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go index 0e9dda577..ffcf64df6 100644 --- a/syncapi/streams/stream_pdu.go +++ b/syncapi/streams/stream_pdu.go @@ -178,24 +178,24 @@ func (p *PDUStreamProvider) IncrementalSync( var err error var stateDeltas []types.StateDelta - var joinedRooms []string + var syncJoinedRooms []string stateFilter := req.Filter.Room.State eventFilter := req.Filter.Room.Timeline if req.WantFullState { - if stateDeltas, joinedRooms, err = p.DB.GetStateDeltasForFullStateSync(ctx, req.Device, r, req.Device.UserID, &stateFilter); err != nil { + if stateDeltas, syncJoinedRooms, err = p.DB.GetStateDeltasForFullStateSync(ctx, req.Device, r, req.Device.UserID, &stateFilter); err != nil { req.Log.WithError(err).Error("p.DB.GetStateDeltasForFullStateSync failed") return } } else { - if stateDeltas, joinedRooms, err = p.DB.GetStateDeltas(ctx, req.Device, r, req.Device.UserID, &stateFilter); err != nil { + if stateDeltas, syncJoinedRooms, err = p.DB.GetStateDeltas(ctx, req.Device, r, req.Device.UserID, &stateFilter); err != nil { req.Log.WithError(err).Error("p.DB.GetStateDeltas failed") return } } - for _, roomID := range joinedRooms { + for _, roomID := range syncJoinedRooms { req.Rooms[roomID] = gomatrixserverlib.Join } @@ -209,11 +209,27 @@ func (p *PDUStreamProvider) IncrementalSync( newPos = from for _, delta := range stateDeltas { + newRange := r + // If this room was joined in this sync, try to fetch + // as much timeline events as allowed by the filter. + if delta.NewlyJoined { + // Reverse the range, so we get the most recent first. + // This will be limited by the eventFilter. + newRange = types.Range{ + From: r.To, + To: 0, + Backwards: true, + } + } var pos types.StreamPosition - if pos, err = p.addRoomDeltaToResponse(ctx, req.Device, r, delta, &eventFilter, &stateFilter, req.Response); err != nil { + if pos, err = p.addRoomDeltaToResponse(ctx, req.Device, newRange, delta, &eventFilter, &stateFilter, req.Response); err != nil { req.Log.WithError(err).Error("d.addRoomDeltaToResponse failed") return to } + // Reset the position, as it is only for the special case of newly joined rooms + if delta.NewlyJoined { + pos = newRange.From + } switch { case r.Backwards && pos < newPos: fallthrough @@ -287,7 +303,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( if stateFilter.LazyLoadMembers { delta.StateEvents, err = p.lazyLoadMembers( - ctx, delta.RoomID, true, limited, stateFilter.IncludeRedundantMembers, + ctx, delta.RoomID, true, limited, stateFilter, device, recentEvents, delta.StateEvents, ) if err != nil && err != sql.ErrNoRows { @@ -309,12 +325,12 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( logrus.WithError(err).Error("unable to apply history visibility filter") } - if len(events) > 0 { - updateLatestPosition(events[len(events)-1].EventID()) - } if len(delta.StateEvents) > 0 { updateLatestPosition(delta.StateEvents[len(delta.StateEvents)-1].EventID()) } + if len(events) > 0 { + updateLatestPosition(events[len(events)-1].EventID()) + } switch delta.Membership { case gomatrixserverlib.Join: @@ -387,6 +403,8 @@ func applyHistoryVisibilityFilter( logrus.WithFields(logrus.Fields{ "duration": time.Since(startTime), "room_id": roomID, + "before": len(recentEvents), + "after": len(events), }).Debug("applied history visibility (sync)") return events, nil } @@ -514,7 +532,7 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( return nil, err } stateEvents, err = p.lazyLoadMembers(ctx, roomID, - false, limited, stateFilter.IncludeRedundantMembers, + false, limited, stateFilter, device, recentEvents, stateEvents, ) if err != nil && err != sql.ErrNoRows { @@ -533,7 +551,7 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( func (p *PDUStreamProvider) lazyLoadMembers( ctx context.Context, roomID string, - incremental, limited, includeRedundant bool, + incremental, limited bool, stateFilter *gomatrixserverlib.StateFilter, device *userapi.Device, timelineEvents, stateEvents []*gomatrixserverlib.HeaderedEvent, ) ([]*gomatrixserverlib.HeaderedEvent, error) { @@ -563,7 +581,7 @@ func (p *PDUStreamProvider) lazyLoadMembers( stateKey := *event.StateKey() if _, ok := timelineUsers[stateKey]; ok || isGappedIncremental { newStateEvents = append(newStateEvents, event) - if !includeRedundant { + if !stateFilter.IncludeRedundantMembers { p.lazyLoadCache.StoreLazyLoadedUser(device, roomID, stateKey, event.EventID()) } delete(timelineUsers, stateKey) @@ -578,6 +596,7 @@ func (p *PDUStreamProvider) lazyLoadMembers( } // Query missing membership events filter := gomatrixserverlib.DefaultStateFilter() + filter.Limit = stateFilter.Limit filter.Senders = &wantUsers filter.Types = &[]string{gomatrixserverlib.MRoomMember} memberships, err := p.DB.GetStateEventsForRoom(ctx, roomID, &filter) diff --git a/syncapi/streams/stream_presence.go b/syncapi/streams/stream_presence.go index 877bcf141..15db4d30e 100644 --- a/syncapi/streams/stream_presence.go +++ b/syncapi/streams/stream_presence.go @@ -19,9 +19,11 @@ import ( "encoding/json" "sync" + "github.com/matrix-org/gomatrixserverlib" + "github.com/tidwall/gjson" + "github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrixserverlib" ) type PresenceStreamProvider struct { @@ -175,6 +177,10 @@ func membershipEventPresent(events []gomatrixserverlib.ClientEvent, userID strin // it's enough to know that we have our member event here, don't need to check membership content // as it's implied by being in the respective section of the sync response. if ev.Type == gomatrixserverlib.MRoomMember && ev.StateKey != nil && *ev.StateKey == userID { + // ignore e.g. join -> join changes + if gjson.GetBytes(ev.Unsigned, "prev_content.membership").Str == gjson.GetBytes(ev.Content, "membership").Str { + continue + } return true } } diff --git a/syncapi/sync/request.go b/syncapi/sync/request.go index 9d4740e93..268ed70c6 100644 --- a/syncapi/sync/request.go +++ b/syncapi/sync/request.go @@ -23,12 +23,13 @@ import ( "strconv" "time" - "github.com/matrix-org/dendrite/syncapi/storage" - "github.com/matrix-org/dendrite/syncapi/types" - userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/sirupsen/logrus" + + "github.com/matrix-org/dendrite/syncapi/storage" + "github.com/matrix-org/dendrite/syncapi/types" + userapi "github.com/matrix-org/dendrite/userapi/api" ) const defaultSyncTimeout = time.Duration(0) @@ -46,15 +47,9 @@ func newSyncRequest(req *http.Request, device userapi.Device, syncDB storage.Dat return nil, err } } - // TODO: read from stored filters too + + // Create a default filter and apply a stored filter on top of it (if specified) filter := gomatrixserverlib.DefaultFilter() - if since.IsEmpty() { - // Send as much account data down for complete syncs as possible - // by default, otherwise clients do weird things while waiting - // for the rest of the data to trickle down. - filter.AccountData.Limit = math.MaxInt32 - filter.Room.AccountData.Limit = math.MaxInt32 - } filterQuery := req.URL.Query().Get("filter") if filterQuery != "" { if filterQuery[0] == '{' { @@ -76,6 +71,17 @@ func newSyncRequest(req *http.Request, device userapi.Device, syncDB storage.Dat } } + // A loaded filter might have overwritten these values, + // so set them after loading the filter. + if since.IsEmpty() { + // Send as much account data down for complete syncs as possible + // by default, otherwise clients do weird things while waiting + // for the rest of the data to trickle down. + filter.AccountData.Limit = math.MaxInt32 + filter.Room.AccountData.Limit = math.MaxInt32 + filter.Room.State.Limit = math.MaxInt32 + } + logger := util.GetLogger(req.Context()).WithFields(logrus.Fields{ "user_id": device.UserID, "device_id": device.ID, diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index d908a9629..c2c9616e8 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -298,8 +298,8 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi. return giveup() case <-userStreamListener.GetNotifyChannel(syncReq.Since): - syncReq.Log.Debugln("Responding to sync after wake-up") currentPos.ApplyUpdates(userStreamListener.GetSyncPosition()) + syncReq.Log.WithField("currentPos", currentPos).Debugln("Responding to sync after wake-up") } } else { syncReq.Log.WithField("currentPos", currentPos).Debugln("Responding to sync immediately") diff --git a/syncapi/syncapi_test.go b/syncapi/syncapi_test.go index 089bdafaf..c81256aa7 100644 --- a/syncapi/syncapi_test.go +++ b/syncapi/syncapi_test.go @@ -154,8 +154,12 @@ func testSyncAccessTokens(t *testing.T, dbType test.DBType) { wantJoinedRooms: []string{room.ID}, }, } - // TODO: find a better way - time.Sleep(500 * time.Millisecond) + + syncUntil(t, base, alice.AccessToken, false, func(syncBody string) bool { + // wait for the last sent eventID to come down sync + path := fmt.Sprintf(`rooms.join.%s.timeline.events.#(event_id=="%s")`, room.ID, room.Events()[len(room.Events())-1].EventID()) + return gjson.Get(syncBody, path).Exists() + }) for _, tc := range testCases { w := httptest.NewRecorder() @@ -191,6 +195,7 @@ func TestSyncAPICreateRoomSyncEarly(t *testing.T) { } func testSyncAPICreateRoomSyncEarly(t *testing.T, dbType test.DBType) { + t.Skip("Skipped, possibly fixed") user := test.NewUser(t) room := test.NewRoom(t, user) alice := userapi.Device{ @@ -343,6 +348,13 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) { // create the users alice := test.NewUser(t) + aliceDev := userapi.Device{ + ID: "ALICEID", + UserID: alice.ID, + AccessToken: "ALICE_BEARER_TOKEN", + DisplayName: "ALICE", + } + bob := test.NewUser(t) bobDev := userapi.Device{ @@ -409,7 +421,7 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) { rsAPI := roomserver.NewInternalAPI(base) rsAPI.SetFederationAPI(nil, nil) - AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{bobDev}}, rsAPI, &syncKeyAPI{}) + AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{aliceDev, bobDev}}, rsAPI, &syncKeyAPI{}) for _, tc := range testCases { testname := fmt.Sprintf("%s - %s", tc.historyVisibility, userType) @@ -418,12 +430,18 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) { room := test.NewRoom(t, alice, test.RoomHistoryVisibility(tc.historyVisibility)) // send the events/messages to NATS to create the rooms - beforeJoinEv := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("Before invite in a %s room", tc.historyVisibility)}) + beforeJoinBody := fmt.Sprintf("Before invite in a %s room", tc.historyVisibility) + beforeJoinEv := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": beforeJoinBody}) eventsToSend := append(room.Events(), beforeJoinEv) if err := api.SendEvents(ctx, rsAPI, api.KindNew, eventsToSend, "test", "test", nil, false); err != nil { t.Fatalf("failed to send events: %v", err) } - time.Sleep(100 * time.Millisecond) // TODO: find a better way + syncUntil(t, base, aliceDev.AccessToken, false, + func(syncBody string) bool { + path := fmt.Sprintf(`rooms.join.%s.timeline.events.#(content.body=="%s")`, room.ID, beforeJoinBody) + return gjson.Get(syncBody, path).Exists() + }, + ) // There is only one event, we expect only to be able to see this, if the room is world_readable w := httptest.NewRecorder() @@ -449,14 +467,20 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) { inviteEv := room.CreateAndInsert(t, alice, "m.room.member", map[string]interface{}{"membership": "invite"}, test.WithStateKey(bob.ID)) afterInviteEv := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("After invite in a %s room", tc.historyVisibility)}) joinEv := room.CreateAndInsert(t, bob, "m.room.member", map[string]interface{}{"membership": "join"}, test.WithStateKey(bob.ID)) - msgEv := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("After join in a %s room", tc.historyVisibility)}) + afterJoinBody := fmt.Sprintf("After join in a %s room", tc.historyVisibility) + msgEv := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": afterJoinBody}) eventsToSend = append([]*gomatrixserverlib.HeaderedEvent{}, inviteEv, afterInviteEv, joinEv, msgEv) if err := api.SendEvents(ctx, rsAPI, api.KindNew, eventsToSend, "test", "test", nil, false); err != nil { t.Fatalf("failed to send events: %v", err) } - time.Sleep(100 * time.Millisecond) // TODO: find a better way + syncUntil(t, base, aliceDev.AccessToken, false, + func(syncBody string) bool { + path := fmt.Sprintf(`rooms.join.%s.timeline.events.#(content.body=="%s")`, room.ID, afterJoinBody) + return gjson.Get(syncBody, path).Exists() + }, + ) // Verify the messages after/before invite are visible or not w = httptest.NewRecorder() @@ -511,8 +535,8 @@ func testSendToDevice(t *testing.T, dbType test.DBType) { AccountType: userapi.AccountTypeUser, } - base, close := testrig.CreateBaseDendrite(t, dbType) - defer close() + base, baseClose := testrig.CreateBaseDendrite(t, dbType) + defer baseClose() jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream) @@ -607,7 +631,14 @@ func testSendToDevice(t *testing.T, dbType test.DBType) { t.Fatalf("unable to send to device message: %v", err) } } - time.Sleep((time.Millisecond * 15) * time.Duration(tc.sendMessagesCount)) // wait a bit, so the messages can be processed + + syncUntil(t, base, alice.AccessToken, + len(tc.want) == 0, + func(body string) bool { + return gjson.Get(body, fmt.Sprintf(`to_device.events.#(content.dummy=="message %d")`, msgCounter)).Exists() + }, + ) + // Execute a /sync request, recording the response w := httptest.NewRecorder() base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", "/_matrix/client/v3/sync", test.WithQueryParams(map[string]string{ @@ -630,6 +661,42 @@ func testSendToDevice(t *testing.T, dbType test.DBType) { } } +func syncUntil(t *testing.T, + base *base.BaseDendrite, accessToken string, + skip bool, + checkFunc func(syncBody string) bool, +) { + if checkFunc == nil { + t.Fatalf("No checkFunc defined") + } + if skip { + return + } + // loop on /sync until we receive the last send message or timeout after 5 seconds, since we don't know if the message made it + // to the syncAPI when hitting /sync + done := make(chan bool) + defer close(done) + go func() { + for { + w := httptest.NewRecorder() + base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", "/_matrix/client/v3/sync", test.WithQueryParams(map[string]string{ + "access_token": accessToken, + "timeout": "1000", + }))) + if checkFunc(w.Body.String()) { + done <- true + return + } + } + }() + + select { + case <-done: + case <-time.After(time.Second * 5): + t.Fatalf("Timed out waiting for messages") + } +} + func toNATSMsgs(t *testing.T, base *base.BaseDendrite, input ...*gomatrixserverlib.HeaderedEvent) []*nats.Msg { result := make([]*nats.Msg, len(input)) for i, ev := range input { diff --git a/syncapi/types/types.go b/syncapi/types/types.go index 39b085d9c..d75d53ca9 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -37,6 +37,7 @@ var ( type StateDelta struct { RoomID string StateEvents []*gomatrixserverlib.HeaderedEvent + NewlyJoined bool Membership string // The PDU stream position of the latest membership event for this user, if applicable. // Can be 0 if there is no membership event in this delta. diff --git a/sytest-whitelist b/sytest-whitelist index dcffeaffb..5c8896b99 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -144,7 +144,6 @@ Server correctly handles incoming m.device_list_update If remote user leaves room, changes device and rejoins we see update in sync If remote user leaves room, changes device and rejoins we see update in /keys/changes If remote user leaves room we no longer receive device updates -If a device list update goes missing, the server resyncs on the next one Server correctly resyncs when client query keys and there is no remote cache Server correctly resyncs when server leaves and rejoins a room Device list doesn't change if remote server is down @@ -633,7 +632,6 @@ Test that rejected pushers are removed. Trying to add push rule with no scope fails with 400 Trying to add push rule with invalid scope fails with 400 Forward extremities remain so even after the next events are populated as outliers -If a device list update goes missing, the server resyncs on the next one uploading self-signing key notifies over federation uploading signed devices gets propagated over federation Device list doesn't change if remote server is down diff --git a/test/http.go b/test/http.go index 37b3648f8..8cd83d0a6 100644 --- a/test/http.go +++ b/test/http.go @@ -68,7 +68,7 @@ func ListenAndServe(t *testing.T, router http.Handler, withTLS bool) (apiURL str if withTLS { certFile := filepath.Join(t.TempDir(), "dendrite.cert") keyFile := filepath.Join(t.TempDir(), "dendrite.key") - err = NewTLSKey(keyFile, certFile) + err = NewTLSKey(keyFile, certFile, 1024) if err != nil { t.Errorf("failed to make TLS key: %s", err) return diff --git a/test/keys.go b/test/keys.go index 327c6ed7b..05f7317cf 100644 --- a/test/keys.go +++ b/test/keys.go @@ -15,6 +15,7 @@ package test import ( + "crypto/ed25519" "crypto/rand" "crypto/rsa" "crypto/x509" @@ -44,6 +45,10 @@ func NewMatrixKey(matrixKeyPath string) (err error) { if err != nil { return err } + return SaveMatrixKey(matrixKeyPath, data[3:]) +} + +func SaveMatrixKey(matrixKeyPath string, data ed25519.PrivateKey) error { keyOut, err := os.OpenFile(matrixKeyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) if err != nil { return err @@ -62,15 +67,15 @@ func NewMatrixKey(matrixKeyPath string) (err error) { Headers: map[string]string{ "Key-ID": fmt.Sprintf("ed25519:%s", keyID[:6]), }, - Bytes: data[3:], + Bytes: data, }) return err } const certificateDuration = time.Hour * 24 * 365 * 10 -func generateTLSTemplate(dnsNames []string) (*rsa.PrivateKey, *x509.Certificate, error) { - priv, err := rsa.GenerateKey(rand.Reader, 4096) +func generateTLSTemplate(dnsNames []string, bitSize int) (*rsa.PrivateKey, *x509.Certificate, error) { + priv, err := rsa.GenerateKey(rand.Reader, bitSize) if err != nil { return nil, nil, err } @@ -118,8 +123,8 @@ func writePrivateKey(tlsKeyPath string, priv *rsa.PrivateKey) error { } // NewTLSKey generates a new RSA TLS key and certificate and writes it to a file. -func NewTLSKey(tlsKeyPath, tlsCertPath string) error { - priv, template, err := generateTLSTemplate(nil) +func NewTLSKey(tlsKeyPath, tlsCertPath string, keySize int) error { + priv, template, err := generateTLSTemplate(nil, keySize) if err != nil { return err } @@ -136,8 +141,8 @@ func NewTLSKey(tlsKeyPath, tlsCertPath string) error { return writePrivateKey(tlsKeyPath, priv) } -func NewTLSKeyWithAuthority(serverName, tlsKeyPath, tlsCertPath, authorityKeyPath, authorityCertPath string) error { - priv, template, err := generateTLSTemplate([]string{serverName}) +func NewTLSKeyWithAuthority(serverName, tlsKeyPath, tlsCertPath, authorityKeyPath, authorityCertPath string, keySize int) error { + priv, template, err := generateTLSTemplate([]string{serverName}, keySize) if err != nil { return err } diff --git a/test/testrig/base.go b/test/testrig/base.go index d13c43129..33230921c 100644 --- a/test/testrig/base.go +++ b/test/testrig/base.go @@ -57,7 +57,6 @@ func CreateBaseDendrite(t *testing.T, dbType test.DBType) (*base.BaseDendrite, f return base.NewBaseDendrite(&cfg, "Test", base.DisableMetrics), func() { // cleanup db files. This risks getting out of sync as we add more database strings :( dbFiles := []config.DataSource{ - cfg.AppServiceAPI.Database.ConnectionString, cfg.FederationAPI.Database.ConnectionString, cfg.KeyServer.Database.ConnectionString, cfg.MSCs.Database.ConnectionString, diff --git a/userapi/consumers/syncapi_readupdate.go b/userapi/consumers/syncapi_readupdate.go index 067f93330..54654f757 100644 --- a/userapi/consumers/syncapi_readupdate.go +++ b/userapi/consumers/syncapi_readupdate.go @@ -56,15 +56,16 @@ func NewOutputReadUpdateConsumer( func (s *OutputReadUpdateConsumer) Start() error { if err := jetstream.JetStreamConsumer( - s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, - nats.DeliverAll(), nats.ManualAck(), + s.ctx, s.jetstream, s.topic, s.durable, 1, + s.onMessage, nats.DeliverAll(), nats.ManualAck(), ); err != nil { return err } return nil } -func (s *OutputReadUpdateConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { +func (s *OutputReadUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called var read types.ReadUpdate if err := json.Unmarshal(msg.Data, &read); err != nil { log.WithError(err).Error("userapi clientapi consumer: message parse failure") diff --git a/userapi/consumers/syncapi_streamevent.go b/userapi/consumers/syncapi_streamevent.go index 7807c7637..3ac6f58d0 100644 --- a/userapi/consumers/syncapi_streamevent.go +++ b/userapi/consumers/syncapi_streamevent.go @@ -7,6 +7,10 @@ import ( "strings" "time" + "github.com/matrix-org/gomatrixserverlib" + "github.com/nats-io/nats.go" + log "github.com/sirupsen/logrus" + "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/pushgateway" "github.com/matrix-org/dendrite/internal/pushrules" @@ -20,9 +24,6 @@ import ( "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/util" - "github.com/matrix-org/gomatrixserverlib" - "github.com/nats-io/nats.go" - log "github.com/sirupsen/logrus" ) type OutputStreamEventConsumer struct { @@ -64,15 +65,16 @@ func NewOutputStreamEventConsumer( func (s *OutputStreamEventConsumer) Start() error { if err := jetstream.JetStreamConsumer( - s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, - nats.DeliverAll(), nats.ManualAck(), + s.ctx, s.jetstream, s.topic, s.durable, 1, + s.onMessage, nats.DeliverAll(), nats.ManualAck(), ); err != nil { return err } return nil } -func (s *OutputStreamEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { +func (s *OutputStreamEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called var output types.StreamedEvent output.Event = &gomatrixserverlib.HeaderedEvent{} if err := json.Unmarshal(msg.Data, &output); err != nil { @@ -529,7 +531,9 @@ func (s *OutputStreamEventConsumer) notifyHTTP(ctx context.Context, event *gomat case "event_id_only": req = pushgateway.NotifyRequest{ Notification: pushgateway.Notification{ - Counts: &pushgateway.Counts{}, + Counts: &pushgateway.Counts{ + Unread: userNumUnreadNotifs, + }, Devices: devices, EventID: event.EventID(), RoomID: event.RoomID(), diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 6ba469327..2f28ee906 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -28,7 +28,6 @@ import ( "github.com/sirupsen/logrus" "golang.org/x/crypto/bcrypt" - "github.com/matrix-org/dendrite/appservice/types" "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/pushrules" @@ -454,7 +453,7 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe // Create a dummy device for AS user dev := api.Device{ // Use AS dummy device ID - ID: types.AppServiceDeviceID, + ID: "AS_Device", // AS dummy device has AS's token. AccessToken: token, AppserviceID: appService.ID,