diff --git a/.gitignore b/.gitignore index dbc84edb1..092f4501c 100644 --- a/.gitignore +++ b/.gitignore @@ -23,6 +23,7 @@ /vendor/bin /docker/build /logs +/jetstream # Architecture specific extensions/prefixes *.[568vq] diff --git a/CHANGES.md b/CHANGES.md index 95e24de11..07e09480a 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,5 +1,32 @@ # Changelog +## Dendrite 0.6.2 (2022-02-04) + +### Fixes + +* Resolves an issue where the key change consumer in the keyserver could consume extreme amounts of CPU + +## Dendrite 0.6.1 (2022-02-04) + +### Features + +* Roomserver inputs now take place with full transactional isolation in PostgreSQL deployments +* Pull consumers are now used instead of push consumers when retrieving messages from NATS to better guarantee ordering and to reduce redelivery of duplicate messages +* Further logging tweaks, particularly when joining rooms +* Improved calculation of servers in the room, when checking for missing auth/prev events or state +* Dendrite will now skip dead servers more quickly when federating by reducing the TCP dial timeout +* The key change consumers have now been converted to use native NATS code rather than a wrapper +* Go 1.16 is now the minimum supported version for Dendrite + +### Fixes + +* Local clients should now be notified correctly of invites +* The roomserver input API now has more time to process events, particularly when fetching missing events or state, which should fix a number of errors from expired contexts +* Fixed a panic that could happen due to a closed channel in the roomserver input API +* Logging in with uppercase usernames from old installations is now supported again (contributed by [hoernschen](https://github.com/hoernschen)) +* Federated room joins now have more time to complete and should not fail due to expired contexts +* Events that were sent to the roomserver along with a complete state snapshot are now persisted with the correct state, even if they were rejected or soft-failed + ## Dendrite 0.6.0 (2022-01-28) ### Features diff --git a/appservice/appservice.go b/appservice/appservice.go index 924a609ea..7e7c67f53 100644 --- a/appservice/appservice.go +++ b/appservice/appservice.go @@ -58,7 +58,7 @@ func NewInternalAPI( }, }, } - js, _, _ := jetstream.Prepare(&base.Cfg.Global.JetStream) + js := jetstream.Prepare(&base.Cfg.Global.JetStream) // Create a connection to the appservice postgres DB appserviceDB, err := storage.NewDatabase(&base.Cfg.AppServiceAPI.Database) diff --git a/appservice/consumers/roomserver.go b/appservice/consumers/roomserver.go index 8aea5c347..7b59e3704 100644 --- a/appservice/consumers/roomserver.go +++ b/appservice/consumers/roomserver.go @@ -34,7 +34,7 @@ import ( type OutputRoomEventConsumer struct { ctx context.Context jetstream nats.JetStreamContext - durable nats.SubOpt + durable string topic string asDB storage.Database rsAPI api.RoomserverInternalAPI @@ -66,37 +66,37 @@ func NewOutputRoomEventConsumer( // Start consuming from room servers func (s *OutputRoomEventConsumer) Start() error { - _, err := s.jetstream.Subscribe(s.topic, s.onMessage, s.durable) - return err + return jetstream.JetStreamConsumer( + s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, + nats.DeliverAll(), nats.ManualAck(), + ) } // onMessage is called when the appservice component receives a new event from // the room server output log. -func (s *OutputRoomEventConsumer) onMessage(msg *nats.Msg) { - jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool { - // Parse out the event JSON - var output api.OutputEvent - if err := json.Unmarshal(msg.Data, &output); err != nil { - // If the message was invalid, log it and move on to the next message in the stream - log.WithError(err).Errorf("roomserver output log: message parse failure") - return true - } - - if output.Type != api.OutputTypeNewRoomEvent { - return true - } - - events := []*gomatrixserverlib.HeaderedEvent{output.NewRoomEvent.Event} - events = append(events, output.NewRoomEvent.AddStateEvents...) - - // Send event to any relevant application services - if err := s.filterRoomserverEvents(context.TODO(), events); err != nil { - log.WithError(err).Errorf("roomserver output log: filter error") - return true - } - +func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { + // Parse out the event JSON + var output api.OutputEvent + if err := json.Unmarshal(msg.Data, &output); err != nil { + // If the message was invalid, log it and move on to the next message in the stream + log.WithError(err).Errorf("roomserver output log: message parse failure") return true - }) + } + + if output.Type != api.OutputTypeNewRoomEvent { + return true + } + + events := []*gomatrixserverlib.HeaderedEvent{output.NewRoomEvent.Event} + events = append(events, output.NewRoomEvent.AddStateEvents...) + + // Send event to any relevant application services + if err := s.filterRoomserverEvents(context.TODO(), events); err != nil { + log.WithError(err).Errorf("roomserver output log: filter error") + return true + } + + return true } // filterRoomserverEvents takes in events and decides whether any of them need diff --git a/build/gobind-pinecone/monolith.go b/build/gobind-pinecone/monolith.go index 1c9c0ac4e..211b8d653 100644 --- a/build/gobind-pinecone/monolith.go +++ b/build/gobind-pinecone/monolith.go @@ -281,7 +281,7 @@ func (m *DendriteMonolith) Start() { cfg.Global.ServerName = gomatrixserverlib.ServerName(hex.EncodeToString(pk)) cfg.Global.PrivateKey = sk cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID) - cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("file:%s/%s", m.StorageDirectory, prefix)) + cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/%s", m.StorageDirectory, prefix)) cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-account.db", m.StorageDirectory, prefix)) cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-device.db", m.StorageDirectory, prefix)) cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-mediaapi.db", m.CacheDirectory, prefix)) diff --git a/build/gobind-yggdrasil/monolith.go b/build/gobind-yggdrasil/monolith.go index 1aae418d1..3d9ba8aa0 100644 --- a/build/gobind-yggdrasil/monolith.go +++ b/build/gobind-yggdrasil/monolith.go @@ -86,7 +86,7 @@ func (m *DendriteMonolith) Start() { cfg.Global.ServerName = gomatrixserverlib.ServerName(ygg.DerivedServerName()) cfg.Global.PrivateKey = ygg.PrivateKey() cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID) - cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("file:%s/", m.StorageDirectory)) + cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", m.StorageDirectory)) cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-account.db", m.StorageDirectory)) cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-device.db", m.StorageDirectory)) cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-mediaapi.db", m.StorageDirectory)) diff --git a/build/scripts/Complement.Dockerfile b/build/scripts/Complement.Dockerfile index 55b381ba5..401695abf 100644 --- a/build/scripts/Complement.Dockerfile +++ b/build/scripts/Complement.Dockerfile @@ -12,10 +12,14 @@ COPY . . RUN go build ./cmd/dendrite-monolith-server RUN go build ./cmd/generate-keys RUN go build ./cmd/generate-config -RUN ./generate-config --ci > dendrite.yaml -RUN ./generate-keys --private-key matrix_key.pem --tls-cert server.crt --tls-key server.key +RUN ./generate-keys --private-key matrix_key.pem ENV SERVER_NAME=localhost EXPOSE 8008 8448 -CMD sed -i "s/server_name: localhost/server_name: ${SERVER_NAME}/g" dendrite.yaml && ./dendrite-monolith-server --tls-cert server.crt --tls-key server.key --config dendrite.yaml +# At runtime, generate TLS cert based on the CA now mounted at /ca +# At runtime, replace the SERVER_NAME with what we are told +CMD ./generate-keys --server $SERVER_NAME --tls-cert server.crt --tls-key server.key --tls-authority-cert /ca/ca.crt --tls-authority-key /ca/ca.key && \ + ./generate-config -server $SERVER_NAME --ci > dendrite.yaml && \ + cp /ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates && \ + ./dendrite-monolith-server --tls-cert server.crt --tls-key server.key --config dendrite.yaml diff --git a/clientapi/auth/password.go b/clientapi/auth/password.go index 19f7e3773..18cf94979 100644 --- a/clientapi/auth/password.go +++ b/clientapi/auth/password.go @@ -16,6 +16,7 @@ package auth import ( "context" + "database/sql" "net/http" "strings" @@ -59,8 +60,9 @@ func (t *LoginTypePassword) LoginFromJSON(ctx context.Context, reqBytes []byte) return login, func(context.Context, *util.JSONResponse) {}, nil } -func (t *LoginTypePassword) Login(ctx context.Context, r *PasswordRequest) (*Login, *util.JSONResponse) { - username := strings.ToLower(r.Username()) +func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, *util.JSONResponse) { + r := req.(*PasswordRequest) + username := strings.ToLower(r.Username()) if username == "" { return nil, &util.JSONResponse{ Code: http.StatusUnauthorized, @@ -74,8 +76,15 @@ func (t *LoginTypePassword) Login(ctx context.Context, r *PasswordRequest) (*Log JSON: jsonerror.InvalidUsername(err.Error()), } } - _, err = t.GetAccountByPassword(ctx, localpart, r.Password) + // Squash username to all lowercase letters + _, err = t.GetAccountByPassword(ctx, strings.ToLower(localpart), r.Password) if err != nil { + if err == sql.ErrNoRows { + _, err = t.GetAccountByPassword(ctx, localpart, r.Password) + if err == nil { + return &r.Login, nil + } + } // Technically we could tell them if the user does not exist by checking if err == sql.ErrNoRows // but that would leak the existence of the user. return nil, &util.JSONResponse{ diff --git a/clientapi/clientapi.go b/clientapi/clientapi.go index 7c772125a..d678ada96 100644 --- a/clientapi/clientapi.go +++ b/clientapi/clientapi.go @@ -49,7 +49,7 @@ func AddPublicRoutes( extRoomsProvider api.ExtraPublicRoomsProvider, mscCfg *config.MSCs, ) { - js, _, _ := jetstream.Prepare(&cfg.Matrix.JetStream) + js := jetstream.Prepare(&cfg.Matrix.JetStream) syncProducer := &producers.SyncAPIProducer{ JetStream: js, diff --git a/cmd/generate-config/main.go b/cmd/generate-config/main.go index a79470d83..60729672e 100644 --- a/cmd/generate-config/main.go +++ b/cmd/generate-config/main.go @@ -83,7 +83,7 @@ func main() { if *defaultsForCI { cfg.AppServiceAPI.DisableTLSValidation = true cfg.ClientAPI.RateLimiting.Enabled = false - cfg.FederationAPI.DisableTLSValidation = true + cfg.FederationAPI.DisableTLSValidation = false // don't hit matrix.org when running tests!!! cfg.FederationAPI.KeyPerspectives = config.KeyPerspectives{} cfg.MSCs.MSCs = []string{"msc2836", "msc2946", "msc2444", "msc2753"} diff --git a/cmd/generate-keys/main.go b/cmd/generate-keys/main.go index 743109f13..bddf219dc 100644 --- a/cmd/generate-keys/main.go +++ b/cmd/generate-keys/main.go @@ -32,9 +32,12 @@ Arguments: ` var ( - tlsCertFile = flag.String("tls-cert", "", "An X509 certificate file to generate for use for TLS") - tlsKeyFile = flag.String("tls-key", "", "An RSA private key file to generate for use for TLS") - privateKeyFile = flag.String("private-key", "", "An Ed25519 private key to generate for use for object signing") + tlsCertFile = flag.String("tls-cert", "", "An X509 certificate file to generate for use for TLS") + tlsKeyFile = flag.String("tls-key", "", "An RSA private key file to generate for use for TLS") + privateKeyFile = flag.String("private-key", "", "An Ed25519 private key to generate for use for object signing") + authorityCertFile = flag.String("tls-authority-cert", "", "Optional: Create TLS certificate/keys based on this CA authority. Useful for integration testing.") + authorityKeyFile = flag.String("tls-authority-key", "", "Optional: Create TLS certificate/keys based on this CA authority. Useful for integration testing.") + serverName = flag.String("server", "", "Optional: Create TLS certificate/keys with this domain name set. Useful for integration testing.") ) func main() { @@ -54,8 +57,15 @@ func main() { if *tlsCertFile == "" || *tlsKeyFile == "" { log.Fatal("Zero or both of --tls-key and --tls-cert must be supplied") } - if err := test.NewTLSKey(*tlsKeyFile, *tlsCertFile); err != nil { - panic(err) + if *authorityCertFile == "" && *authorityKeyFile == "" { + if err := test.NewTLSKey(*tlsKeyFile, *tlsCertFile); err != nil { + panic(err) + } + } else { + // generate the TLS cert/key based on the authority given. + if err := test.NewTLSKeyWithAuthority(*serverName, *tlsKeyFile, *tlsCertFile, *authorityKeyFile, *authorityCertFile); err != nil { + panic(err) + } } fmt.Printf("Created TLS cert file: %s\n", *tlsCertFile) fmt.Printf("Created TLS key file: %s\n", *tlsKeyFile) diff --git a/eduserver/eduserver.go b/eduserver/eduserver.go index db03001ba..febcf2864 100644 --- a/eduserver/eduserver.go +++ b/eduserver/eduserver.go @@ -42,7 +42,7 @@ func NewInternalAPI( ) api.EDUServerInputAPI { cfg := &base.Cfg.EDUServer - js, _, _ := jetstream.Prepare(&cfg.Matrix.JetStream) + js := jetstream.Prepare(&cfg.Matrix.JetStream) return &input.EDUServerInputAPI{ Cache: eduCache, diff --git a/federationapi/consumers/eduserver.go b/federationapi/consumers/eduserver.go index c3e5b4d49..22fedbeb4 100644 --- a/federationapi/consumers/eduserver.go +++ b/federationapi/consumers/eduserver.go @@ -34,7 +34,7 @@ import ( type OutputEDUConsumer struct { ctx context.Context jetstream nats.JetStreamContext - durable nats.SubOpt + durable string db storage.Database queues *queue.OutgoingQueues ServerName gomatrixserverlib.ServerName @@ -66,13 +66,22 @@ func NewOutputEDUConsumer( // Start consuming from EDU servers func (t *OutputEDUConsumer) Start() error { - if _, err := t.jetstream.Subscribe(t.typingTopic, t.onTypingEvent, t.durable); err != nil { + if err := jetstream.JetStreamConsumer( + t.ctx, t.jetstream, t.typingTopic, t.durable, t.onTypingEvent, + nats.DeliverAll(), nats.ManualAck(), + ); err != nil { return err } - if _, err := t.jetstream.Subscribe(t.sendToDeviceTopic, t.onSendToDeviceEvent, t.durable); err != nil { + if err := jetstream.JetStreamConsumer( + t.ctx, t.jetstream, t.sendToDeviceTopic, t.durable, t.onSendToDeviceEvent, + nats.DeliverAll(), nats.ManualAck(), + ); err != nil { return err } - if _, err := t.jetstream.Subscribe(t.receiptTopic, t.onReceiptEvent, t.durable); err != nil { + if err := jetstream.JetStreamConsumer( + t.ctx, t.jetstream, t.receiptTopic, t.durable, t.onReceiptEvent, + nats.DeliverAll(), nats.ManualAck(), + ); err != nil { return err } return nil @@ -80,175 +89,169 @@ func (t *OutputEDUConsumer) Start() error { // onSendToDeviceEvent is called in response to a message received on the // send-to-device events topic from the EDU server. -func (t *OutputEDUConsumer) onSendToDeviceEvent(msg *nats.Msg) { +func (t *OutputEDUConsumer) onSendToDeviceEvent(ctx context.Context, msg *nats.Msg) bool { // Extract the send-to-device event from msg. - jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool { - var ote api.OutputSendToDeviceEvent - if err := json.Unmarshal(msg.Data, &ote); err != nil { - log.WithError(err).Errorf("eduserver output log: message parse failed (expected send-to-device)") - return true - } - - // only send send-to-device events which originated from us - _, originServerName, err := gomatrixserverlib.SplitID('@', ote.Sender) - if err != nil { - log.WithError(err).WithField("user_id", ote.Sender).Error("Failed to extract domain from send-to-device sender") - return true - } - if originServerName != t.ServerName { - log.WithField("other_server", originServerName).Info("Suppressing send-to-device: originated elsewhere") - return true - } - - _, destServerName, err := gomatrixserverlib.SplitID('@', ote.UserID) - if err != nil { - log.WithError(err).WithField("user_id", ote.UserID).Error("Failed to extract domain from send-to-device destination") - return true - } - - // Pack the EDU and marshal it - edu := &gomatrixserverlib.EDU{ - Type: gomatrixserverlib.MDirectToDevice, - Origin: string(t.ServerName), - } - tdm := gomatrixserverlib.ToDeviceMessage{ - Sender: ote.Sender, - Type: ote.Type, - MessageID: util.RandomString(32), - Messages: map[string]map[string]json.RawMessage{ - ote.UserID: { - ote.DeviceID: ote.Content, - }, - }, - } - if edu.Content, err = json.Marshal(tdm); err != nil { - log.WithError(err).Error("failed to marshal EDU JSON") - return true - } - - log.Infof("Sending send-to-device message into %q destination queue", destServerName) - if err := t.queues.SendEDU(edu, t.ServerName, []gomatrixserverlib.ServerName{destServerName}); err != nil { - log.WithError(err).Error("failed to send EDU") - return false - } - + var ote api.OutputSendToDeviceEvent + if err := json.Unmarshal(msg.Data, &ote); err != nil { + log.WithError(err).Errorf("eduserver output log: message parse failed (expected send-to-device)") return true - }) + } + + // only send send-to-device events which originated from us + _, originServerName, err := gomatrixserverlib.SplitID('@', ote.Sender) + if err != nil { + log.WithError(err).WithField("user_id", ote.Sender).Error("Failed to extract domain from send-to-device sender") + return true + } + if originServerName != t.ServerName { + log.WithField("other_server", originServerName).Info("Suppressing send-to-device: originated elsewhere") + return true + } + + _, destServerName, err := gomatrixserverlib.SplitID('@', ote.UserID) + if err != nil { + log.WithError(err).WithField("user_id", ote.UserID).Error("Failed to extract domain from send-to-device destination") + return true + } + + // Pack the EDU and marshal it + edu := &gomatrixserverlib.EDU{ + Type: gomatrixserverlib.MDirectToDevice, + Origin: string(t.ServerName), + } + tdm := gomatrixserverlib.ToDeviceMessage{ + Sender: ote.Sender, + Type: ote.Type, + MessageID: util.RandomString(32), + Messages: map[string]map[string]json.RawMessage{ + ote.UserID: { + ote.DeviceID: ote.Content, + }, + }, + } + if edu.Content, err = json.Marshal(tdm); err != nil { + log.WithError(err).Error("failed to marshal EDU JSON") + return true + } + + log.Infof("Sending send-to-device message into %q destination queue", destServerName) + if err := t.queues.SendEDU(edu, t.ServerName, []gomatrixserverlib.ServerName{destServerName}); err != nil { + log.WithError(err).Error("failed to send EDU") + return false + } + + return true } // onTypingEvent is called in response to a message received on the typing // events topic from the EDU server. -func (t *OutputEDUConsumer) onTypingEvent(msg *nats.Msg) { - jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool { - // Extract the typing event from msg. - var ote api.OutputTypingEvent - if err := json.Unmarshal(msg.Data, &ote); err != nil { - // Skip this msg but continue processing messages. - log.WithError(err).Errorf("eduserver output log: message parse failed (expected typing)") - _ = msg.Ack() - return true - } - - // only send typing events which originated from us - _, typingServerName, err := gomatrixserverlib.SplitID('@', ote.Event.UserID) - if err != nil { - log.WithError(err).WithField("user_id", ote.Event.UserID).Error("Failed to extract domain from typing sender") - _ = msg.Ack() - return true - } - if typingServerName != t.ServerName { - return true - } - - joined, err := t.db.GetJoinedHosts(t.ctx, ote.Event.RoomID) - if err != nil { - log.WithError(err).WithField("room_id", ote.Event.RoomID).Error("failed to get joined hosts for room") - return false - } - - names := make([]gomatrixserverlib.ServerName, len(joined)) - for i := range joined { - names[i] = joined[i].ServerName - } - - edu := &gomatrixserverlib.EDU{Type: ote.Event.Type} - if edu.Content, err = json.Marshal(map[string]interface{}{ - "room_id": ote.Event.RoomID, - "user_id": ote.Event.UserID, - "typing": ote.Event.Typing, - }); err != nil { - log.WithError(err).Error("failed to marshal EDU JSON") - return true - } - - if err := t.queues.SendEDU(edu, t.ServerName, names); err != nil { - log.WithError(err).Error("failed to send EDU") - return false - } - +func (t *OutputEDUConsumer) onTypingEvent(ctx context.Context, msg *nats.Msg) bool { + // Extract the typing event from msg. + var ote api.OutputTypingEvent + if err := json.Unmarshal(msg.Data, &ote); err != nil { + // Skip this msg but continue processing messages. + log.WithError(err).Errorf("eduserver output log: message parse failed (expected typing)") + _ = msg.Ack() return true - }) + } + + // only send typing events which originated from us + _, typingServerName, err := gomatrixserverlib.SplitID('@', ote.Event.UserID) + if err != nil { + log.WithError(err).WithField("user_id", ote.Event.UserID).Error("Failed to extract domain from typing sender") + _ = msg.Ack() + return true + } + if typingServerName != t.ServerName { + return true + } + + joined, err := t.db.GetJoinedHosts(ctx, ote.Event.RoomID) + if err != nil { + log.WithError(err).WithField("room_id", ote.Event.RoomID).Error("failed to get joined hosts for room") + return false + } + + names := make([]gomatrixserverlib.ServerName, len(joined)) + for i := range joined { + names[i] = joined[i].ServerName + } + + edu := &gomatrixserverlib.EDU{Type: ote.Event.Type} + if edu.Content, err = json.Marshal(map[string]interface{}{ + "room_id": ote.Event.RoomID, + "user_id": ote.Event.UserID, + "typing": ote.Event.Typing, + }); err != nil { + log.WithError(err).Error("failed to marshal EDU JSON") + return true + } + + if err := t.queues.SendEDU(edu, t.ServerName, names); err != nil { + log.WithError(err).Error("failed to send EDU") + return false + } + + return true } // onReceiptEvent is called in response to a message received on the receipt // events topic from the EDU server. -func (t *OutputEDUConsumer) onReceiptEvent(msg *nats.Msg) { - jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool { - // Extract the typing event from msg. - var receipt api.OutputReceiptEvent - if err := json.Unmarshal(msg.Data, &receipt); err != nil { - // Skip this msg but continue processing messages. - log.WithError(err).Errorf("eduserver output log: message parse failed (expected receipt)") - return true - } - - // only send receipt events which originated from us - _, receiptServerName, err := gomatrixserverlib.SplitID('@', receipt.UserID) - if err != nil { - log.WithError(err).WithField("user_id", receipt.UserID).Error("failed to extract domain from receipt sender") - return true - } - if receiptServerName != t.ServerName { - return true - } - - joined, err := t.db.GetJoinedHosts(t.ctx, receipt.RoomID) - if err != nil { - log.WithError(err).WithField("room_id", receipt.RoomID).Error("failed to get joined hosts for room") - return false - } - - names := make([]gomatrixserverlib.ServerName, len(joined)) - for i := range joined { - names[i] = joined[i].ServerName - } - - content := map[string]api.FederationReceiptMRead{} - content[receipt.RoomID] = api.FederationReceiptMRead{ - User: map[string]api.FederationReceiptData{ - receipt.UserID: { - Data: api.ReceiptTS{ - TS: receipt.Timestamp, - }, - EventIDs: []string{receipt.EventID}, - }, - }, - } - - edu := &gomatrixserverlib.EDU{ - Type: gomatrixserverlib.MReceipt, - Origin: string(t.ServerName), - } - if edu.Content, err = json.Marshal(content); err != nil { - log.WithError(err).Error("failed to marshal EDU JSON") - return true - } - - if err := t.queues.SendEDU(edu, t.ServerName, names); err != nil { - log.WithError(err).Error("failed to send EDU") - return false - } - +func (t *OutputEDUConsumer) onReceiptEvent(ctx context.Context, msg *nats.Msg) bool { + // Extract the typing event from msg. + var receipt api.OutputReceiptEvent + if err := json.Unmarshal(msg.Data, &receipt); err != nil { + // Skip this msg but continue processing messages. + log.WithError(err).Errorf("eduserver output log: message parse failed (expected receipt)") return true - }) + } + + // only send receipt events which originated from us + _, receiptServerName, err := gomatrixserverlib.SplitID('@', receipt.UserID) + if err != nil { + log.WithError(err).WithField("user_id", receipt.UserID).Error("failed to extract domain from receipt sender") + return true + } + if receiptServerName != t.ServerName { + return true + } + + joined, err := t.db.GetJoinedHosts(ctx, receipt.RoomID) + if err != nil { + log.WithError(err).WithField("room_id", receipt.RoomID).Error("failed to get joined hosts for room") + return false + } + + names := make([]gomatrixserverlib.ServerName, len(joined)) + for i := range joined { + names[i] = joined[i].ServerName + } + + content := map[string]api.FederationReceiptMRead{} + content[receipt.RoomID] = api.FederationReceiptMRead{ + User: map[string]api.FederationReceiptData{ + receipt.UserID: { + Data: api.ReceiptTS{ + TS: receipt.Timestamp, + }, + EventIDs: []string{receipt.EventID}, + }, + }, + } + + edu := &gomatrixserverlib.EDU{ + Type: gomatrixserverlib.MReceipt, + Origin: string(t.ServerName), + } + if edu.Content, err = json.Marshal(content); err != nil { + log.WithError(err).Error("failed to marshal EDU JSON") + return true + } + + if err := t.queues.SendEDU(edu, t.ServerName, names); err != nil { + log.WithError(err).Error("failed to send EDU") + return false + } + + return true } diff --git a/federationapi/consumers/keychange.go b/federationapi/consumers/keychange.go index 6a737d0ad..1ec9f4c18 100644 --- a/federationapi/consumers/keychange.go +++ b/federationapi/consumers/keychange.go @@ -17,80 +17,73 @@ package consumers import ( "context" "encoding/json" - "fmt" - "github.com/Shopify/sarama" eduserverAPI "github.com/matrix-org/dendrite/eduserver/api" "github.com/matrix-org/dendrite/federationapi/queue" "github.com/matrix-org/dendrite/federationapi/storage" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/keyserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/gomatrixserverlib" + "github.com/nats-io/nats.go" "github.com/sirupsen/logrus" ) // KeyChangeConsumer consumes events that originate in key server. type KeyChangeConsumer struct { ctx context.Context - consumer *internal.ContinualConsumer + jetstream nats.JetStreamContext + durable string db storage.Database queues *queue.OutgoingQueues serverName gomatrixserverlib.ServerName rsAPI roomserverAPI.RoomserverInternalAPI + topic string } // NewKeyChangeConsumer creates a new KeyChangeConsumer. Call Start() to begin consuming from key servers. func NewKeyChangeConsumer( process *process.ProcessContext, cfg *config.KeyServer, - kafkaConsumer sarama.Consumer, + js nats.JetStreamContext, queues *queue.OutgoingQueues, store storage.Database, rsAPI roomserverAPI.RoomserverInternalAPI, ) *KeyChangeConsumer { - c := &KeyChangeConsumer{ - ctx: process.Context(), - consumer: &internal.ContinualConsumer{ - Process: process, - ComponentName: "federationapi/keychange", - Topic: string(cfg.Matrix.JetStream.TopicFor(jetstream.OutputKeyChangeEvent)), - Consumer: kafkaConsumer, - PartitionStore: store, - }, + return &KeyChangeConsumer{ + ctx: process.Context(), + jetstream: js, + durable: cfg.Matrix.JetStream.TopicFor("FederationAPIKeyChangeConsumer"), + topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputKeyChangeEvent), queues: queues, db: store, serverName: cfg.Matrix.ServerName, rsAPI: rsAPI, } - c.consumer.ProcessMessage = c.onMessage - - return c } // Start consuming from key servers func (t *KeyChangeConsumer) Start() error { - if err := t.consumer.Start(); err != nil { - return fmt.Errorf("t.consumer.Start: %w", err) - } - return nil + return jetstream.JetStreamConsumer( + t.ctx, t.jetstream, t.topic, t.durable, t.onMessage, + nats.DeliverAll(), nats.ManualAck(), + ) } // onMessage is called in response to a message received on the // key change events topic from the key server. -func (t *KeyChangeConsumer) onMessage(msg *sarama.ConsumerMessage) error { +func (t *KeyChangeConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { var m api.DeviceMessage - if err := json.Unmarshal(msg.Value, &m); err != nil { + if err := json.Unmarshal(msg.Data, &m); err != nil { logrus.WithError(err).Errorf("failed to read device message from key change topic") - return nil + return true } if m.DeviceKeys == nil && m.OutputCrossSigningKeyUpdate == nil { // This probably shouldn't happen but stops us from panicking if we come // across an update that doesn't satisfy either types. - return nil + return true } switch m.Type { case api.TypeCrossSigningUpdate: @@ -102,9 +95,9 @@ func (t *KeyChangeConsumer) onMessage(msg *sarama.ConsumerMessage) error { } } -func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) error { +func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) bool { if m.DeviceKeys == nil { - return nil + return true } logger := logrus.WithField("user_id", m.UserID) @@ -112,10 +105,10 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) error { _, originServerName, err := gomatrixserverlib.SplitID('@', m.UserID) if err != nil { logger.WithError(err).Error("Failed to extract domain from key change event") - return nil + return true } if originServerName != t.serverName { - return nil + return true } var queryRes roomserverAPI.QueryRoomsForUserResponse @@ -125,13 +118,13 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) error { }, &queryRes) if err != nil { logger.WithError(err).Error("failed to calculate joined rooms for user") - return nil + return true } // send this key change to all servers who share rooms with this user. destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true) if err != nil { logger.WithError(err).Error("failed to calculate joined hosts for rooms user is in") - return nil + return true } // Pack the EDU and marshal it @@ -149,24 +142,26 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) error { Keys: m.KeyJSON, } if edu.Content, err = json.Marshal(event); err != nil { - return err + logger.WithError(err).Error("failed to marshal EDU JSON") + return true } - logrus.Infof("Sending device list update message to %q", destinations) - return t.queues.SendEDU(edu, t.serverName, destinations) + logger.Infof("Sending device list update message to %q", destinations) + err = t.queues.SendEDU(edu, t.serverName, destinations) + return err == nil } -func (t *KeyChangeConsumer) onCrossSigningMessage(m api.DeviceMessage) error { +func (t *KeyChangeConsumer) onCrossSigningMessage(m api.DeviceMessage) bool { output := m.CrossSigningKeyUpdate _, host, err := gomatrixserverlib.SplitID('@', output.UserID) if err != nil { logrus.WithError(err).Errorf("fedsender key change consumer: user ID parse failure") - return nil + return true } if host != gomatrixserverlib.ServerName(t.serverName) { // Ignore any messages that didn't originate locally, otherwise we'll // end up parroting information we received from other servers. - return nil + return true } logger := logrus.WithField("user_id", output.UserID) @@ -177,13 +172,13 @@ func (t *KeyChangeConsumer) onCrossSigningMessage(m api.DeviceMessage) error { }, &queryRes) if err != nil { logger.WithError(err).Error("fedsender key change consumer: failed to calculate joined rooms for user") - return nil + return true } // send this key change to all servers who share rooms with this user. destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true) if err != nil { logger.WithError(err).Error("fedsender key change consumer: failed to calculate joined hosts for rooms user is in") - return nil + return true } // Pack the EDU and marshal it @@ -193,11 +188,12 @@ func (t *KeyChangeConsumer) onCrossSigningMessage(m api.DeviceMessage) error { } if edu.Content, err = json.Marshal(output); err != nil { logger.WithError(err).Error("fedsender key change consumer: failed to marshal output, dropping") - return nil + return true } logger.Infof("Sending cross-signing update message to %q", destinations) - return t.queues.SendEDU(edu, t.serverName, destinations) + err = t.queues.SendEDU(edu, t.serverName, destinations) + return err == nil } func prevID(streamID int) []int { diff --git a/federationapi/consumers/roomserver.go b/federationapi/consumers/roomserver.go index 25ea78274..e9862000a 100644 --- a/federationapi/consumers/roomserver.go +++ b/federationapi/consumers/roomserver.go @@ -37,7 +37,7 @@ type OutputRoomEventConsumer struct { cfg *config.FederationAPI rsAPI api.RoomserverInternalAPI jetstream nats.JetStreamContext - durable nats.SubOpt + durable string db storage.Database queues *queue.OutgoingQueues topic string @@ -66,74 +66,75 @@ func NewOutputRoomEventConsumer( // Start consuming from room servers func (s *OutputRoomEventConsumer) Start() error { - _, err := s.jetstream.Subscribe( - s.topic, s.onMessage, s.durable, - nats.DeliverAll(), - nats.ManualAck(), + return jetstream.JetStreamConsumer( + s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, + nats.DeliverAll(), nats.ManualAck(), ) - return err } // onMessage is called when the federation server receives a new event from the room server output log. // It is unsafe to call this with messages for the same room in multiple gorountines // because updates it will likely fail with a types.EventIDMismatchError when it // realises that it cannot update the room state using the deltas. -func (s *OutputRoomEventConsumer) onMessage(msg *nats.Msg) { - jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool { - // Parse out the event JSON - var output api.OutputEvent - if err := json.Unmarshal(msg.Data, &output); err != nil { - // If the message was invalid, log it and move on to the next message in the stream - log.WithError(err).Errorf("roomserver output log: message parse failure") - return true - } +func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { + // Parse out the event JSON + var output api.OutputEvent + if err := json.Unmarshal(msg.Data, &output); err != nil { + // If the message was invalid, log it and move on to the next message in the stream + log.WithError(err).Errorf("roomserver output log: message parse failure") + return true + } - switch output.Type { - case api.OutputTypeNewRoomEvent: - ev := output.NewRoomEvent.Event + switch output.Type { + case api.OutputTypeNewRoomEvent: + ev := output.NewRoomEvent.Event - if output.NewRoomEvent.RewritesState { - if err := s.db.PurgeRoomState(s.ctx, ev.RoomID()); err != nil { - log.WithError(err).Errorf("roomserver output log: purge room state failure") - return false - } - } - - if err := s.processMessage(*output.NewRoomEvent); err != nil { - switch err.(type) { - case *queue.ErrorFederationDisabled: - log.WithField("error", output.Type).Info( - err.Error(), - ) - default: - // panic rather than continue with an inconsistent database - log.WithFields(log.Fields{ - "event_id": ev.EventID(), - "event": string(ev.JSON()), - "add": output.NewRoomEvent.AddsStateEventIDs, - "del": output.NewRoomEvent.RemovesStateEventIDs, - log.ErrorKey: err, - }).Panicf("roomserver output log: write room event failure") - } - } - - case api.OutputTypeNewInboundPeek: - if err := s.processInboundPeek(*output.NewInboundPeek); err != nil { - log.WithFields(log.Fields{ - "event": output.NewInboundPeek, - log.ErrorKey: err, - }).Panicf("roomserver output log: remote peek event failure") + if output.NewRoomEvent.RewritesState { + if err := s.db.PurgeRoomState(s.ctx, ev.RoomID()); err != nil { + log.WithError(err).Errorf("roomserver output log: purge room state failure") return false } - - default: - log.WithField("type", output.Type).Debug( - "roomserver output log: ignoring unknown output type", - ) } - return true - }) + if err := s.processMessage(*output.NewRoomEvent); err != nil { + switch err.(type) { + case *queue.ErrorFederationDisabled: + log.WithField("error", output.Type).Info( + err.Error(), + ) + default: + // panic rather than continue with an inconsistent database + log.WithFields(log.Fields{ + "event_id": ev.EventID(), + "event": string(ev.JSON()), + "add": output.NewRoomEvent.AddsStateEventIDs, + "del": output.NewRoomEvent.RemovesStateEventIDs, + log.ErrorKey: err, + }).Panicf("roomserver output log: write room event failure") + } + } + + case api.OutputTypeNewInviteEvent: + log.WithField("type", output.Type).Debug( + "received new invite, send device keys", + ) + + case api.OutputTypeNewInboundPeek: + if err := s.processInboundPeek(*output.NewInboundPeek); err != nil { + log.WithFields(log.Fields{ + "event": output.NewInboundPeek, + log.ErrorKey: err, + }).Panicf("roomserver output log: remote peek event failure") + return false + } + + default: + log.WithField("type", output.Type).Debug( + "roomserver output log: ignoring unknown output type", + ) + } + + return true } // processInboundPeek starts tracking a new federated inbound peek (replacing the existing one if any) diff --git a/federationapi/federationapi.go b/federationapi/federationapi.go index 63387b9d8..a982d8009 100644 --- a/federationapi/federationapi.go +++ b/federationapi/federationapi.go @@ -92,7 +92,7 @@ func NewInternalAPI( FailuresUntilBlacklist: cfg.FederationMaxRetries, } - js, consumer, _ := jetstream.Prepare(&cfg.Matrix.JetStream) + js := jetstream.Prepare(&cfg.Matrix.JetStream) queues := queue.NewOutgoingQueues( federationDB, base.ProcessContext, @@ -120,7 +120,7 @@ func NewInternalAPI( logrus.WithError(err).Panic("failed to start typing server consumer") } keyConsumer := consumers.NewKeyChangeConsumer( - base.ProcessContext, &base.Cfg.KeyServer, consumer, queues, federationDB, rsAPI, + base.ProcessContext, &base.Cfg.KeyServer, js, queues, federationDB, rsAPI, ) if err := keyConsumer.Start(); err != nil { logrus.WithError(err).Panic("failed to start key server consumer") diff --git a/federationapi/internal/perform.go b/federationapi/internal/perform.go index 4dd53c11b..7850f206c 100644 --- a/federationapi/internal/perform.go +++ b/federationapi/internal/perform.go @@ -196,29 +196,22 @@ func (r *FederationInternalAPI) performJoinUsingServer( return fmt.Errorf("respMakeJoin.JoinEvent.Build: %w", err) } - // No longer reuse the request context from this point forward. - // We don't want the client timing out to interrupt the join. - var cancel context.CancelFunc - ctx, cancel = context.WithCancel(context.Background()) - // Try to perform a send_join using the newly built event. respSendJoin, err := r.federation.SendJoin( - ctx, + context.Background(), serverName, event, respMakeJoin.RoomVersion, ) if err != nil { r.statistics.ForServer(serverName).Failure() - cancel() return fmt.Errorf("r.federation.SendJoin: %w", err) } r.statistics.ForServer(serverName).Success() // Sanity-check the join response to ensure that it has a create // event, that the room version is known, etc. - if err := sanityCheckAuthChain(respSendJoin.AuthEvents); err != nil { - cancel() + if err = sanityCheckAuthChain(respSendJoin.AuthEvents); err != nil { return fmt.Errorf("sanityCheckAuthChain: %w", err) } @@ -227,41 +220,35 @@ func (r *FederationInternalAPI) performJoinUsingServer( // to complete, but if the client does give up waiting, we'll // still continue to process the join anyway so that we don't // waste the effort. - go func() { - defer cancel() + // TODO: Can we expand Check here to return a list of missing auth + // events rather than failing one at a time? + var respState *gomatrixserverlib.RespState + respState, err = respSendJoin.Check( + context.Background(), + r.keyRing, + event, + federatedAuthProvider(ctx, r.federation, r.keyRing, serverName), + ) + if err != nil { + return fmt.Errorf("respSendJoin.Check: %w", err) + } - // TODO: Can we expand Check here to return a list of missing auth - // events rather than failing one at a time? - respState, err := respSendJoin.Check(ctx, r.keyRing, event, federatedAuthProvider(ctx, r.federation, r.keyRing, serverName)) - if err != nil { - logrus.WithFields(logrus.Fields{ - "room_id": roomID, - "user_id": userID, - }).WithError(err).Error("Failed to process room join response") - return - } + // If we successfully performed a send_join above then the other + // server now thinks we're a part of the room. Send the newly + // returned state to the roomserver to update our local view. + if err = roomserverAPI.SendEventWithState( + context.Background(), + r.rsAPI, + roomserverAPI.KindNew, + respState, + event.Headered(respMakeJoin.RoomVersion), + serverName, + nil, + false, + ); err != nil { + return fmt.Errorf("roomserverAPI.SendEventWithState: %w", err) + } - // If we successfully performed a send_join above then the other - // server now thinks we're a part of the room. Send the newly - // returned state to the roomserver to update our local view. - if err = roomserverAPI.SendEventWithState( - ctx, r.rsAPI, - roomserverAPI.KindNew, - respState, - event.Headered(respMakeJoin.RoomVersion), - serverName, - nil, - false, - ); err != nil { - logrus.WithFields(logrus.Fields{ - "room_id": roomID, - "user_id": userID, - }).WithError(err).Error("Failed to send room join response to roomserver") - return - } - }() - - <-ctx.Done() return nil } diff --git a/federationapi/storage/interface.go b/federationapi/storage/interface.go index 21a919f6a..3fa8d1f7a 100644 --- a/federationapi/storage/interface.go +++ b/federationapi/storage/interface.go @@ -19,12 +19,10 @@ import ( "github.com/matrix-org/dendrite/federationapi/storage/shared" "github.com/matrix-org/dendrite/federationapi/types" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/gomatrixserverlib" ) type Database interface { - internal.PartitionStorer gomatrixserverlib.KeyDatabase UpdateRoom(ctx context.Context, roomID, oldEventID, newEventID string, addHosts []types.JoinedHost, removeHosts []string) (joinedHosts []types.JoinedHost, err error) diff --git a/go.mod b/go.mod index 78fe4b0ef..fc18ce07e 100644 --- a/go.mod +++ b/go.mod @@ -11,18 +11,19 @@ require ( github.com/HdrHistogram/hdrhistogram-go v1.1.2 // indirect github.com/MFAshby/stdemuxerhook v1.0.0 github.com/Masterminds/semver/v3 v3.1.1 - github.com/S7evinK/saramajetstream v0.0.0-20210709110708-de6efc8c4a32 - github.com/Shopify/sarama v1.31.0 + github.com/cespare/xxhash/v2 v2.1.2 // indirect github.com/codeclysm/extract v2.2.0+incompatible github.com/containerd/containerd v1.5.9 // indirect github.com/docker/docker v20.10.12+incompatible github.com/docker/go-connections v0.4.0 + github.com/frankban/quicktest v1.14.0 // indirect github.com/getsentry/sentry-go v0.12.0 github.com/gologme/log v1.3.0 github.com/gorilla/mux v1.8.0 github.com/gorilla/websocket v1.4.2 github.com/h2non/filetype v1.1.3 // indirect github.com/hashicorp/golang-lru v0.5.4 + github.com/json-iterator/go v1.1.12 // indirect github.com/juju/testing v0.0.0-20211215003918-77eb13d6cad2 // indirect github.com/klauspost/compress v1.14.2 // indirect github.com/lib/pq v1.10.4 @@ -40,7 +41,7 @@ require ( github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4 github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 - github.com/matrix-org/gomatrixserverlib v0.0.0-20220131105022-801c51af9f29 + github.com/matrix-org/gomatrixserverlib v0.0.0-20220204110702-c559d2019275 github.com/matrix-org/pinecone v0.0.0-20220121094951-351265543ddf github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.10 @@ -54,7 +55,9 @@ require ( github.com/patrickmn/go-cache v2.1.0+incompatible github.com/pkg/errors v0.9.1 github.com/pressly/goose v2.7.0+incompatible - github.com/prometheus/client_golang v1.12.1 + github.com/prometheus/client_golang v1.11.0 + github.com/prometheus/common v0.32.1 // indirect + github.com/prometheus/procfs v0.7.3 // indirect github.com/sirupsen/logrus v1.8.1 github.com/tidwall/gjson v1.13.0 github.com/tidwall/sjson v1.2.4 @@ -66,9 +69,11 @@ require ( golang.org/x/image v0.0.0-20211028202545-6944b10bf410 golang.org/x/mobile v0.0.0-20220112015953-858099ff7816 golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd + golang.org/x/sys v0.0.0-20220114195835-da31bd327af9 // indirect golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 gopkg.in/h2non/bimg.v1 v1.1.5 gopkg.in/yaml.v2 v2.4.0 + gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect nhooyr.io/websocket v1.8.7 ) diff --git a/go.sum b/go.sum index d332071ec..3f8a99f48 100644 --- a/go.sum +++ b/go.sum @@ -100,17 +100,8 @@ github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578/go.mod h1:uGdko github.com/RoaringBitmap/roaring v0.4.7/go.mod h1:8khRDP4HmeXns4xIj9oGrKSz7XTQiJx2zgh7AcNke4w= github.com/RyanCarrier/dijkstra v1.0.0/go.mod h1:5agGUBNEtUAGIANmbw09fuO3a2htPEkc1jNH01qxCWA= github.com/RyanCarrier/dijkstra-1 v0.0.0-20170512020943-0e5801a26345/go.mod h1:OK4EvWJ441LQqGzed5NGB6vKBAE34n3z7iayPcEwr30= -github.com/S7evinK/saramajetstream v0.0.0-20210709110708-de6efc8c4a32 h1:i3fOph9Hjleo6LbuqN9ODFxnwt7mOtYMpCGeC8qJN50= -github.com/S7evinK/saramajetstream v0.0.0-20210709110708-de6efc8c4a32/go.mod h1:ne+jkLlzafIzaE4Q0Ze81T27dNgXe1wxovVEoAtSHTc= github.com/Shopify/goreferrer v0.0.0-20181106222321-ec9c9a553398/go.mod h1:a1uqRtAwp2Xwc6WNPJEufxJ7fx3npB4UV/JOLmbu5I0= github.com/Shopify/logrus-bugsnag v0.0.0-20171204204709-577dee27f20d/go.mod h1:HI8ITrYtUY+O+ZhtlqUnD8+KwNPOyugEhfP9fdUIaEQ= -github.com/Shopify/sarama v1.29.0/go.mod h1:2QpgD79wpdAESqNQMxNc0KYMkycd4slxGdV3TWSVqrU= -github.com/Shopify/sarama v1.31.0 h1:gObk7jCPutDxf+E6GA5G21noAZsi1SvP9ftCQYqpzus= -github.com/Shopify/sarama v1.31.0/go.mod h1:BeW3gXRc/CxgAsrSly2RE9nIXUfC9ezb7QHBPVhvzjI= -github.com/Shopify/toxiproxy v2.1.4+incompatible h1:TKdv8HiTLgE5wdJuEML90aBgNWsokNbMijUGhmcoBJc= -github.com/Shopify/toxiproxy v2.1.4+incompatible/go.mod h1:OXgGpZ6Cli1/URJOF1DMxUHB2q5Ap20/P/eIdh4G0pI= -github.com/Shopify/toxiproxy/v2 v2.3.0 h1:62YkpiP4bzdhKMH+6uC5E95y608k3zDwdzuBMsnn3uQ= -github.com/Shopify/toxiproxy/v2 v2.3.0/go.mod h1:KvQTtB6RjCJY4zqNJn7C7JDFgsG5uoHYDirfUfpIm0c= github.com/VividCortex/ewma v1.1.1/go.mod h1:2Tkkvm3sRDVXaiyucHiACn4cqf7DpdyLvmxzcbUokwA= github.com/VividCortex/ewma v1.2.0/go.mod h1:nz4BbCtbLyFDeC9SUHbtcT5644juEuWfUAUnGx7j5l4= github.com/aead/siphash v1.0.1/go.mod h1:Nywa3cDsYNNK3gaciGTWPwHt0wlpNV15vwmswBAUSII= @@ -353,12 +344,6 @@ github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815/go.mod h1:WwZ+bS3 github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/dustin/go-humanize v0.0.0-20180421182945-02af3965c54e/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= -github.com/eapache/go-resiliency v1.2.0 h1:v7g92e/KSN71Rq7vSThKaWIq68fL4YHvWyiUKorFR1Q= -github.com/eapache/go-resiliency v1.2.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs= -github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21 h1:YEetp8/yCZMuEPMUDHG0CW/brkkEp8mzqk2+ODEitlw= -github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU= -github.com/eapache/queue v1.1.0 h1:YOEu7KNc61ntiQlcEeUIoDTJ2o8mQznoNvUhiigpIqc= -github.com/eapache/queue v1.1.0/go.mod h1:6eCeP0CKFpHLu8blIFXhExK/dRa7WDZfr6jVFPTqq+I= github.com/eknkc/amber v0.0.0-20171010120322-cdade1c07385/go.mod h1:0vRUJqYpeSZifjYj7uP3BG/gKcuzL9xWVV/Y+cK33KM= github.com/elazarl/goproxy v0.0.0-20180725130230-947c36da3153/go.mod h1:/Zj4wYkgs4iZTTu3o/KG3Itv/qCCa8VVMlb3i9OVuzc= github.com/emicklei/go-restful v0.0.0-20170410110728-ff4f55a20633/go.mod h1:otzb+WCGbkyDHkqmQmT5YD2WR4BBwUdeQoFo8l/7tVs= @@ -379,8 +364,6 @@ github.com/flynn/noise v0.0.0-20180327030543-2492fe189ae6 h1:u/UEqS66A5ckRmS4yNp github.com/flynn/noise v0.0.0-20180327030543-2492fe189ae6/go.mod h1:1i71OnUq3iUe1ma7Lr6yG6/rjvM3emb6yoL7xLFzcVQ= github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k= github.com/form3tech-oss/jwt-go v3.2.2+incompatible/go.mod h1:pbq4aXjuKjdthFRnoDwaVPLA+WlJuPGy+QneDUgJi2k= -github.com/fortytw2/leaktest v1.3.0 h1:u8491cBMTQ8ft8aeV+adlcytMZylmA5nnwwkRZjI8vw= -github.com/fortytw2/leaktest v1.3.0/go.mod h1:jDsjWgpAGjm2CA7WthBh/CdZYEPF31XHquHwclZch5g= github.com/francoispqt/gojay v1.2.13/go.mod h1:ehT5mTG4ua4581f1++1WLG0vPdaA9HaiDsoyrBGkyDY= github.com/frankban/quicktest v1.0.0/go.mod h1:R98jIehRai+d1/3Hv2//jOVCTJhW1VBavT6B6CuGq2k= github.com/frankban/quicktest v1.7.2/go.mod h1:jaStnuzAqU1AJdCO0l53JDCJrVDKcS03DbaAcR7Ks/o= @@ -498,8 +481,6 @@ github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= -github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= -github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/gologme/log v1.2.0/go.mod h1:gq31gQ8wEHkR+WekdWsqDuf8pXTUZA9BnnzTuPz1Y9U= github.com/gologme/log v1.3.0 h1:l781G4dE+pbigClDSDzSaaYKtiueHCILUa/qSDsmHAo= github.com/gologme/log v1.3.0/go.mod h1:yKT+DvIPdDdDoPtqFrFxheooyVmoqi0BAsw+erN3wA4= @@ -552,8 +533,6 @@ github.com/gorilla/handlers v0.0.0-20150720190736-60c7bfde3e33/go.mod h1:Qkdc/uu github.com/gorilla/mux v1.7.2/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= -github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= -github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM= github.com/gorilla/websocket v0.0.0-20170926233335-4201258b820c/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= github.com/gorilla/websocket v1.4.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= @@ -580,8 +559,6 @@ github.com/hashicorp/go-multierror v1.0.0/go.mod h1:dHtQlpGsu+cZNNAkkCN/P3hoUDHh github.com/hashicorp/go-multierror v1.1.0 h1:B9UzwGQJehnUY1yNrnwREHc3fGbC2xefo8g4TbElacI= github.com/hashicorp/go-multierror v1.1.0/go.mod h1:spPvp8C1qA32ftKqdAHm4hHTbPw+vmowP0z+KUhOZdA= github.com/hashicorp/go-syslog v1.0.0/go.mod h1:qPfqrKkXGihmCqbJM2mZgkZGvKG1dFdvsLplgctolz4= -github.com/hashicorp/go-uuid v1.0.2 h1:cfejS+Tpcp13yd5nYHWDI6qVCny6wyX2Mt5SGur2IGE= -github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= github.com/hashicorp/go-version v1.2.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= @@ -667,18 +644,6 @@ github.com/jbenet/goprocess v0.0.0-20160826012719-b497e2f366b8/go.mod h1:Ly/wlsj github.com/jbenet/goprocess v0.1.3/go.mod h1:5yspPrukOVuOLORacaBi858NqyClJPQxYZlqdZVfqY4= github.com/jbenet/goprocess v0.1.4 h1:DRGOFReOMqqDNXwW70QkacFW0YN9QnwLV0Vqk+3oU0o= github.com/jbenet/goprocess v0.1.4/go.mod h1:5yspPrukOVuOLORacaBi858NqyClJPQxYZlqdZVfqY4= -github.com/jcmturner/aescts/v2 v2.0.0 h1:9YKLH6ey7H4eDBXW8khjYslgyqG2xZikXP0EQFKrle8= -github.com/jcmturner/aescts/v2 v2.0.0/go.mod h1:AiaICIRyfYg35RUkr8yESTqvSy7csK90qZ5xfvvsoNs= -github.com/jcmturner/dnsutils/v2 v2.0.0 h1:lltnkeZGL0wILNvrNiVCR6Ro5PGU/SeBvVO/8c/iPbo= -github.com/jcmturner/dnsutils/v2 v2.0.0/go.mod h1:b0TnjGOvI/n42bZa+hmXL+kFJZsFT7G4t3HTlQ184QM= -github.com/jcmturner/gofork v1.0.0 h1:J7uCkflzTEhUZ64xqKnkDxq3kzc96ajM1Gli5ktUem8= -github.com/jcmturner/gofork v1.0.0/go.mod h1:MK8+TM0La+2rjBD4jE12Kj1pCCxK7d2LK/UM3ncEo0o= -github.com/jcmturner/goidentity/v6 v6.0.1 h1:VKnZd2oEIMorCTsFBnJWbExfNN7yZr3EhJAxwOkZg6o= -github.com/jcmturner/goidentity/v6 v6.0.1/go.mod h1:X1YW3bgtvwAXju7V3LCIMpY0Gbxyjn/mY9zx4tFonSg= -github.com/jcmturner/gokrb5/v8 v8.4.2 h1:6ZIM6b/JJN0X8UM43ZOM6Z4SJzla+a/u7scXFJzodkA= -github.com/jcmturner/gokrb5/v8 v8.4.2/go.mod h1:sb+Xq/fTY5yktf/VxLsE3wlfPqQjp0aWNYyvBVK62bc= -github.com/jcmturner/rpc/v2 v2.0.3 h1:7FXXj8Ti1IaVFpSAziCZWNzbNuZmnvw/i6CqLNdWfZY= -github.com/jcmturner/rpc/v2 v2.0.3/go.mod h1:VUJYCIDm3PVOEHw8sgt091/20OJjskO/YJki3ELg/Hc= github.com/jellevandenhooff/dkim v0.0.0-20150330215556-f50fe3d243e1/go.mod h1:E0B/fFc00Y+Rasa88328GlI/XbtyysCtTHZS8h7IrBU= github.com/jessevdk/go-flags v0.0.0-20141203071132-1679536dcc89/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= @@ -747,10 +712,7 @@ github.com/klauspost/compress v1.9.7/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0 github.com/klauspost/compress v1.10.3/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= github.com/klauspost/compress v1.11.3/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= github.com/klauspost/compress v1.11.13/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= -github.com/klauspost/compress v1.12.2/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg= -github.com/klauspost/compress v1.12.3/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg= github.com/klauspost/compress v1.13.4/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg= -github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= github.com/klauspost/compress v1.14.2 h1:S0OHlFk/Gbon/yauFJ4FfJJF5V0fc5HbBTJazi28pRw= github.com/klauspost/compress v1.14.2/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= github.com/klauspost/cpuid v1.2.1/go.mod h1:Pj4uuM528wm8OyEC2QMXAi2YiTZ96dNQPGgoMS4s3ek= @@ -1021,8 +983,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d/go.mod h1 github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20220131105022-801c51af9f29 h1:1t/J3AldUbgRxltlcmMbUefexxzolG5DvV2CkriZ4LM= -github.com/matrix-org/gomatrixserverlib v0.0.0-20220131105022-801c51af9f29/go.mod h1:qFvhfbQ5orQxlH9vCiFnP4dW27xxnWHdNUBKyj/fbiY= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220204110702-c559d2019275 h1:f6Hh7D3EOTl1uUr76FiyHNA1h4pKBhcVUtyHbxn0hKA= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220204110702-c559d2019275/go.mod h1:qFvhfbQ5orQxlH9vCiFnP4dW27xxnWHdNUBKyj/fbiY= github.com/matrix-org/pinecone v0.0.0-20220121094951-351265543ddf h1:/nqfHUdQHr3WVdbZieaYFvHF1rin5pvDTa/NOZ/qCyE= github.com/matrix-org/pinecone v0.0.0-20220121094951-351265543ddf/go.mod h1:r6dsL+ylE0yXe/7zh8y/Bdh6aBYI1r+u4yZni9A4iyk= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= @@ -1250,9 +1212,6 @@ github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/9 github.com/pelletier/go-toml v1.8.1/go.mod h1:T2/BmBdy8dvIRq1a/8aqjN41wvWlN4lrapLU/GW4pbc= github.com/peterbourgon/diskv v2.0.1+incompatible/go.mod h1:uqqh8zWWbv1HBMNONnaR/tNboyR3/BZd58JJSHlUSCU= github.com/philhofer/fwd v1.0.0/go.mod h1:gk3iGcWd9+svBvR0sR+KPcfE+RNWozjowpeBVG3ZVNU= -github.com/pierrec/lz4 v2.6.0+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= -github.com/pierrec/lz4 v2.6.1+incompatible h1:9UY3+iC23yxF0UfGaYrGplQ+79Rg+h/q9FV9ix19jjM= -github.com/pierrec/lz4 v2.6.1+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= github.com/pingcap/errors v0.11.4 h1:lFuQV/oaUMGcD2tqt+01ROSmJs75VG1ToEOkZIZ4nE4= github.com/pingcap/errors v0.11.4/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -1272,9 +1231,8 @@ github.com/prometheus/client_golang v0.9.3/go.mod h1:/TN21ttK/J9q6uSwhBd54HahCDf github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= github.com/prometheus/client_golang v1.1.0/go.mod h1:I1FGZT9+L76gKKOs5djB6ezCbFQP1xR9D75/vuwEF3g= github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M= +github.com/prometheus/client_golang v1.11.0 h1:HNkLOAEQMIDv/K+04rukrLx6ch7msSRwf3/SASFAGtQ= github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0= -github.com/prometheus/client_golang v1.12.1 h1:ZiaPsmm9uiBeaSMRznKsCDNtPCS0T3JVDGF+06gjBzk= -github.com/prometheus/client_golang v1.12.1/go.mod h1:3Z9XVyYiZYEO+YQWt3RD2R3jrbd179Rt297l4aS6nDY= github.com/prometheus/client_model v0.0.0-20171117100541-99fa1f4be8e5/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= @@ -1306,8 +1264,6 @@ github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1 github.com/prometheus/procfs v0.7.3 h1:4jVXhlkAyzOScmCkXBTOLRLTz8EeU+eyjrwB/EPq0VU= github.com/prometheus/procfs v0.7.3/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA= github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU= -github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 h1:N/ElC8H3+5XpJzTSTfLsJV/mx9Q9g7kxmchpfZyxgzM= -github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg= @@ -1433,7 +1389,6 @@ github.com/urfave/cli v0.0.0-20171014202726-7bc6a0acffa5/go.mod h1:70zkFmudgCuE/ github.com/urfave/cli v1.20.0/go.mod h1:70zkFmudgCuE/ngEzBv17Jvp/497gISqfk5gWijbERA= github.com/urfave/cli v1.22.1/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0= github.com/urfave/cli v1.22.2/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0= -github.com/urfave/cli/v2 v2.3.0/go.mod h1:LJmUH05zAU44vOAcrfzZQKsZbVcdbOG8rtL3/XcUArI= github.com/urfave/negroni v1.0.0/go.mod h1:Meg73S6kFm/4PpbYdq35yYWoCZ9mS/YSx+lKnmiohz4= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/fasthttp v1.6.0/go.mod h1:FstJa9V+Pj9vQ7OJie2qMHdwemEDaDiSdBnvPM1Su9w= @@ -1464,11 +1419,6 @@ github.com/willf/bitset v1.1.9/go.mod h1:RjeCKbqT1RxIR/KWY6phxZiaY1IyutSBfGjNPyS github.com/willf/bitset v1.1.11-0.20200630133818-d5bec3311243/go.mod h1:RjeCKbqT1RxIR/KWY6phxZiaY1IyutSBfGjNPySAYV4= github.com/willf/bitset v1.1.11/go.mod h1:83CECat5yLh5zVOf4P1ErAgKA5UDvKtgyUABdr3+MjI= github.com/x-cray/logrus-prefixed-formatter v0.5.2/go.mod h1:2duySbKsL6M18s5GU7VPsoEPHyzalCE06qoARUCeBBE= -github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= -github.com/xdg-go/scram v1.0.2/go.mod h1:1WAq6h33pAW+iRreB34OORO2Nf7qel3VV3fjBj+hCSs= -github.com/xdg-go/stringprep v1.0.2/go.mod h1:8F9zXuvzgwmyT5DUm4GUfZGDdT3W+LCvS6+da4O5kxM= -github.com/xdg/scram v1.0.3/go.mod h1:lB8K/P019DLNhemzwFU4jHLhdvlE6uDZjXFejJXr49I= -github.com/xdg/stringprep v1.0.3/go.mod h1:Jhud4/sHMO4oL310DaZAKk9ZaJ08SJfe+sJh0HrGL1Y= github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ= github.com/xeipuuv/gojsonschema v0.0.0-20180618132009-1d523034197f/go.mod h1:5yf86TLmAcydyeJq5YvxkGPE2fm/u4myDekKRoLuqhs= @@ -1551,7 +1501,6 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.0.0-20200728195943-123391ffb6de/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20201002170205-7f63de1d35b0/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.0.0-20201112155050-0c6587e931a9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= golang.org/x/crypto v0.0.0-20210314154223-e6e6c4f2bb5b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= @@ -1656,7 +1605,6 @@ golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwY golang.org/x/net v0.0.0-20201224014010-6772e930b67b/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= -golang.org/x/net v0.0.0-20210427231257-85d9c07bbe3a/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= golang.org/x/net v0.0.0-20210510120150-4163338589ed/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= @@ -1665,7 +1613,6 @@ golang.org/x/net v0.0.0-20210805182204-aaa1db679c0d/go.mod h1:9nx3DQGgdP8bBQD5qx golang.org/x/net v0.0.0-20210927181540-4e4d966f7476/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211008194852-3b03d305991f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.0.0-20220105145211-5b0dc2dfae98/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd h1:O7DYs+zxREGLKzKoMQrtrEacpb0ZVXA5rIwylE2Xchk= golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= @@ -1799,7 +1746,6 @@ golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3 golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7-0.20210503195748-5c7c50ebbd4f/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= @@ -2035,7 +1981,6 @@ gopkg.in/yaml.v2 v2.0.0-20170712054546-1be3d31502d6/go.mod h1:JAlM8MvJe8wmxCU4Bl gopkg.in/yaml.v2 v2.0.0-20170812160011-eb3733d160e7/go.mod h1:JAlM8MvJe8wmxCU4Bli9HhUf9+ttbYbLASfIpnQbh74= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.3/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.5/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/internal/consumers.go b/internal/consumers.go deleted file mode 100644 index 3a4e0b7f8..000000000 --- a/internal/consumers.go +++ /dev/null @@ -1,139 +0,0 @@ -// Copyright 2017 Vector Creations Ltd -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package internal - -import ( - "context" - "fmt" - - "github.com/Shopify/sarama" - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/setup/process" - "github.com/sirupsen/logrus" -) - -// A PartitionStorer has the storage APIs needed by the consumer. -type PartitionStorer interface { - // PartitionOffsets returns the offsets the consumer has reached for each partition. - PartitionOffsets(ctx context.Context, topic string) ([]sqlutil.PartitionOffset, error) - // SetPartitionOffset records where the consumer has reached for a partition. - SetPartitionOffset(ctx context.Context, topic string, partition int32, offset int64) error -} - -// A ContinualConsumer continually consumes logs even across restarts. It requires a PartitionStorer to -// remember the offset it reached. -type ContinualConsumer struct { - // The parent context for the listener, stop consuming when this context is done - Process *process.ProcessContext - // The component name - ComponentName string - // The kafkaesque topic to consume events from. - // This is the name used in kafka to identify the stream to consume events from. - Topic string - // A kafkaesque stream consumer providing the APIs for talking to the event source. - // The interface is taken from a client library for Apache Kafka. - // But any equivalent event streaming protocol could be made to implement the same interface. - Consumer sarama.Consumer - // A thing which can load and save partition offsets for a topic. - PartitionStore PartitionStorer - // ProcessMessage is a function which will be called for each message in the log. Return an error to - // stop processing messages. See ErrShutdown for specific control signals. - ProcessMessage func(msg *sarama.ConsumerMessage) error - // ShutdownCallback is called when ProcessMessage returns ErrShutdown, after the partition has been saved. - // It is optional. - ShutdownCallback func() -} - -// ErrShutdown can be returned from ContinualConsumer.ProcessMessage to stop the ContinualConsumer. -var ErrShutdown = fmt.Errorf("shutdown") - -// Start starts the consumer consuming. -// Starts up a goroutine for each partition in the kafka stream. -// Returns nil once all the goroutines are started. -// Returns an error if it can't start consuming for any of the partitions. -func (c *ContinualConsumer) Start() error { - _, err := c.StartOffsets() - return err -} - -// StartOffsets is the same as Start but returns the loaded offsets as well. -func (c *ContinualConsumer) StartOffsets() ([]sqlutil.PartitionOffset, error) { - offsets := map[int32]int64{} - - partitions, err := c.Consumer.Partitions(c.Topic) - if err != nil { - return nil, err - } - for _, partition := range partitions { - // Default all the offsets to the beginning of the stream. - offsets[partition] = sarama.OffsetOldest - } - - storedOffsets, err := c.PartitionStore.PartitionOffsets(context.TODO(), c.Topic) - if err != nil { - return nil, err - } - for _, offset := range storedOffsets { - // We've already processed events from this partition so advance the offset to where we got to. - // ConsumePartition will start streaming from the message with the given offset (inclusive), - // so increment 1 to avoid getting the same message a second time. - offsets[offset.Partition] = 1 + offset.Offset - } - - var partitionConsumers []sarama.PartitionConsumer - for partition, offset := range offsets { - pc, err := c.Consumer.ConsumePartition(c.Topic, partition, offset) - if err != nil { - for _, p := range partitionConsumers { - p.Close() // nolint: errcheck - } - return nil, err - } - partitionConsumers = append(partitionConsumers, pc) - } - for _, pc := range partitionConsumers { - go c.consumePartition(pc) - if c.Process != nil { - c.Process.ComponentStarted() - go func(pc sarama.PartitionConsumer) { - <-c.Process.WaitForShutdown() - _ = pc.Close() - c.Process.ComponentFinished() - logrus.Infof("Stopped consumer for %q topic %q", c.ComponentName, c.Topic) - }(pc) - } - } - - return storedOffsets, nil -} - -// consumePartition consumes the room events for a single partition of the kafkaesque stream. -func (c *ContinualConsumer) consumePartition(pc sarama.PartitionConsumer) { - defer pc.Close() // nolint: errcheck - for message := range pc.Messages() { - msgErr := c.ProcessMessage(message) - // Advance our position in the stream so that we will start at the right position after a restart. - if err := c.PartitionStore.SetPartitionOffset(context.TODO(), c.Topic, message.Partition, message.Offset); err != nil { - panic(fmt.Errorf("the ContinualConsumer in %q failed to SetPartitionOffset: %w", c.ComponentName, err)) - } - // Shutdown if we were told to do so. - if msgErr == ErrShutdown { - if c.ShutdownCallback != nil { - c.ShutdownCallback() - } - return - } - } -} diff --git a/internal/test/config.go b/internal/test/config.go index bb2f8a4c6..4fb6a946c 100644 --- a/internal/test/config.go +++ b/internal/test/config.go @@ -20,6 +20,7 @@ import ( "crypto/x509" "encoding/base64" "encoding/pem" + "errors" "fmt" "io/ioutil" "math/big" @@ -158,11 +159,10 @@ func NewMatrixKey(matrixKeyPath string) (err error) { const certificateDuration = time.Hour * 24 * 365 * 10 -// NewTLSKey generates a new RSA TLS key and certificate and writes it to a file. -func NewTLSKey(tlsKeyPath, tlsCertPath string) error { +func generateTLSTemplate(dnsNames []string) (*rsa.PrivateKey, *x509.Certificate, error) { priv, err := rsa.GenerateKey(rand.Reader, 4096) if err != nil { - return err + return nil, nil, err } notBefore := time.Now() @@ -170,7 +170,7 @@ func NewTLSKey(tlsKeyPath, tlsCertPath string) error { serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) if err != nil { - return err + return nil, nil, err } template := x509.Certificate{ @@ -180,20 +180,21 @@ func NewTLSKey(tlsKeyPath, tlsCertPath string) error { KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, BasicConstraintsValid: true, + DNSNames: dnsNames, } - derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) - if err != nil { - return err - } + return priv, &template, nil +} + +func writeCertificate(tlsCertPath string, derBytes []byte) error { certOut, err := os.Create(tlsCertPath) if err != nil { return err } defer certOut.Close() // nolint: errcheck - if err = pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil { - return err - } + return pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) +} +func writePrivateKey(tlsKeyPath string, priv *rsa.PrivateKey) error { keyOut, err := os.OpenFile(tlsKeyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) if err != nil { return err @@ -205,3 +206,73 @@ func NewTLSKey(tlsKeyPath, tlsCertPath string) error { }) return err } + +// NewTLSKey generates a new RSA TLS key and certificate and writes it to a file. +func NewTLSKey(tlsKeyPath, tlsCertPath string) error { + priv, template, err := generateTLSTemplate(nil) + if err != nil { + return err + } + + // Self-signed certificate: template == parent + derBytes, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv) + if err != nil { + return err + } + + if err = writeCertificate(tlsCertPath, derBytes); err != nil { + return err + } + return writePrivateKey(tlsKeyPath, priv) +} + +func NewTLSKeyWithAuthority(serverName, tlsKeyPath, tlsCertPath, authorityKeyPath, authorityCertPath string) error { + priv, template, err := generateTLSTemplate([]string{serverName}) + if err != nil { + return err + } + + // load the authority key + dat, err := ioutil.ReadFile(authorityKeyPath) + if err != nil { + return err + } + block, _ := pem.Decode([]byte(dat)) + if block == nil || block.Type != "RSA PRIVATE KEY" { + return errors.New("authority .key is not a valid pem encoded rsa private key") + } + authorityPriv, err := x509.ParsePKCS1PrivateKey(block.Bytes) + if err != nil { + return err + } + + // load the authority certificate + dat, err = ioutil.ReadFile(authorityCertPath) + if err != nil { + return err + } + block, _ = pem.Decode([]byte(dat)) + if block == nil || block.Type != "CERTIFICATE" { + return errors.New("authority .crt is not a valid pem encoded x509 cert") + } + var caCerts []*x509.Certificate + caCerts, err = x509.ParseCertificates(block.Bytes) + if err != nil { + return err + } + if len(caCerts) != 1 { + return errors.New("authority .crt contains none or more than one cert") + } + authorityCert := caCerts[0] + + // Sign the new certificate using the authority's key/cert + derBytes, err := x509.CreateCertificate(rand.Reader, template, authorityCert, &priv.PublicKey, authorityPriv) + if err != nil { + return err + } + + if err = writeCertificate(tlsCertPath, derBytes); err != nil { + return err + } + return writePrivateKey(tlsKeyPath, priv) +} diff --git a/internal/version.go b/internal/version.go index f09daabd9..de0b7c8c3 100644 --- a/internal/version.go +++ b/internal/version.go @@ -17,7 +17,7 @@ var build string const ( VersionMajor = 0 VersionMinor = 6 - VersionPatch = 0 + VersionPatch = 2 VersionTag = "" // example: "rc1" ) diff --git a/keyserver/api/api.go b/keyserver/api/api.go index 0eea2f0fa..3933961c1 100644 --- a/keyserver/api/api.go +++ b/keyserver/api/api.go @@ -228,7 +228,7 @@ type QueryKeyChangesRequest struct { // The offset of the last received key event, or sarama.OffsetOldest if this is from the beginning Offset int64 // The inclusive offset where to track key changes up to. Messages with this offset are included in the response. - // Use sarama.OffsetNewest if the offset is unknown (then check the response Offset to avoid racing). + // Use types.OffsetNewest if the offset is unknown (then check the response Offset to avoid racing). ToOffset int64 } diff --git a/keyserver/consumers/cross_signing.go b/keyserver/consumers/cross_signing.go index 4b2bd4a9a..aae69e960 100644 --- a/keyserver/consumers/cross_signing.go +++ b/keyserver/consumers/cross_signing.go @@ -18,29 +18,30 @@ import ( "context" "encoding/json" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/storage" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/gomatrixserverlib" + "github.com/nats-io/nats.go" "github.com/sirupsen/logrus" - - "github.com/Shopify/sarama" ) type OutputCrossSigningKeyUpdateConsumer struct { - eduServerConsumer *internal.ContinualConsumer - keyDB storage.Database - keyAPI api.KeyInternalAPI - serverName string + ctx context.Context + keyDB storage.Database + keyAPI api.KeyInternalAPI + serverName string + jetstream nats.JetStreamContext + durable string + topic string } func NewOutputCrossSigningKeyUpdateConsumer( process *process.ProcessContext, cfg *config.Dendrite, - kafkaConsumer sarama.Consumer, + js nats.JetStreamContext, keyDB storage.Database, keyAPI api.KeyInternalAPI, ) *OutputCrossSigningKeyUpdateConsumer { @@ -48,60 +49,58 @@ func NewOutputCrossSigningKeyUpdateConsumer( // topic. We will only produce events where the UserID matches our server name, // and we will only consume events where the UserID does NOT match our server // name (because the update came from a remote server). - consumer := internal.ContinualConsumer{ - Process: process, - ComponentName: "keyserver/keyserver", - Topic: cfg.Global.JetStream.TopicFor(jetstream.OutputKeyChangeEvent), - Consumer: kafkaConsumer, - PartitionStore: keyDB, - } s := &OutputCrossSigningKeyUpdateConsumer{ - eduServerConsumer: &consumer, - keyDB: keyDB, - keyAPI: keyAPI, - serverName: string(cfg.Global.ServerName), + ctx: process.Context(), + keyDB: keyDB, + jetstream: js, + durable: cfg.Global.JetStream.Durable("KeyServerCrossSigningConsumer"), + topic: cfg.Global.JetStream.TopicFor(jetstream.OutputKeyChangeEvent), + keyAPI: keyAPI, + serverName: string(cfg.Global.ServerName), } - consumer.ProcessMessage = s.onMessage return s } func (s *OutputCrossSigningKeyUpdateConsumer) Start() error { - return s.eduServerConsumer.Start() + return jetstream.JetStreamConsumer( + s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, + nats.DeliverAll(), nats.ManualAck(), + ) } // onMessage is called in response to a message received on the // key change events topic from the key server. -func (t *OutputCrossSigningKeyUpdateConsumer) onMessage(msg *sarama.ConsumerMessage) error { +func (t *OutputCrossSigningKeyUpdateConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { var m api.DeviceMessage - if err := json.Unmarshal(msg.Value, &m); err != nil { + if err := json.Unmarshal(msg.Data, &m); err != nil { logrus.WithError(err).Errorf("failed to read device message from key change topic") - return nil + return true } if m.OutputCrossSigningKeyUpdate == nil { // This probably shouldn't happen but stops us from panicking if we come // across an update that doesn't satisfy either types. - return nil + return true } switch m.Type { case api.TypeCrossSigningUpdate: return t.onCrossSigningMessage(m) default: - return nil + return true } } -func (s *OutputCrossSigningKeyUpdateConsumer) onCrossSigningMessage(m api.DeviceMessage) error { +func (s *OutputCrossSigningKeyUpdateConsumer) onCrossSigningMessage(m api.DeviceMessage) bool { output := m.CrossSigningKeyUpdate _, host, err := gomatrixserverlib.SplitID('@', output.UserID) if err != nil { logrus.WithError(err).Errorf("eduserver output log: user ID parse failure") - return nil + return true } if host == gomatrixserverlib.ServerName(s.serverName) { // Ignore any messages that contain information about our own users, as // they already originated from this server. - return nil + return true } uploadReq := &api.PerformUploadDeviceKeysRequest{ UserID: output.UserID, @@ -114,5 +113,11 @@ func (s *OutputCrossSigningKeyUpdateConsumer) onCrossSigningMessage(m api.Device } uploadRes := &api.PerformUploadDeviceKeysResponse{} s.keyAPI.PerformUploadDeviceKeys(context.TODO(), uploadReq, uploadRes) - return uploadRes.Error + if uploadRes.Error != nil { + // If the error is due to a missing or invalid parameter then we'd might + // as well just acknowledge the message, because otherwise otherwise we'll + // just keep getting delivered a faulty message over and over again. + return uploadRes.Error.IsMissingParam || uploadRes.Error.IsInvalidParam + } + return true } diff --git a/keyserver/keyserver.go b/keyserver/keyserver.go index 8cc50ea0d..61ccc0303 100644 --- a/keyserver/keyserver.go +++ b/keyserver/keyserver.go @@ -40,7 +40,7 @@ func AddInternalRoutes(router *mux.Router, intAPI api.KeyInternalAPI) { func NewInternalAPI( base *base.BaseDendrite, cfg *config.KeyServer, fedClient fedsenderapi.FederationClient, ) api.KeyInternalAPI { - js, consumer, _ := jetstream.Prepare(&cfg.Matrix.JetStream) + js := jetstream.Prepare(&cfg.Matrix.JetStream) db, err := storage.NewDatabase(&cfg.Database) if err != nil { @@ -66,7 +66,7 @@ func NewInternalAPI( }() keyconsumer := consumers.NewOutputCrossSigningKeyUpdateConsumer( - base.ProcessContext, base.Cfg, consumer, db, ap, + base.ProcessContext, base.Cfg, js, db, ap, ) if err := keyconsumer.Start(); err != nil { logrus.WithError(err).Panicf("failed to start keyserver EDU server consumer") diff --git a/keyserver/storage/interface.go b/keyserver/storage/interface.go index 87feae47d..0110860ea 100644 --- a/keyserver/storage/interface.go +++ b/keyserver/storage/interface.go @@ -18,15 +18,12 @@ import ( "context" "encoding/json" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/types" "github.com/matrix-org/gomatrixserverlib" ) type Database interface { - internal.PartitionStorer - // ExistingOneTimeKeys returns a map of keyIDWithAlgorithm to key JSON for the given parameters. If no keys exist with this combination // of user/device/key/algorithm 4-uple then it is omitted from the map. Returns an error when failing to communicate with the database. ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) @@ -71,7 +68,7 @@ type Database interface { StoreKeyChange(ctx context.Context, userID string) (int64, error) // KeyChanges returns a list of user IDs who have modified their keys from the offset given (exclusive) to the offset given (inclusive). - // A to offset of sarama.OffsetNewest means no upper limit. + // A to offset of types.OffsetNewest means no upper limit. // Returns the offset of the latest key change. KeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) diff --git a/keyserver/storage/postgres/key_changes_table.go b/keyserver/storage/postgres/key_changes_table.go index 20d227c24..f93a94bd3 100644 --- a/keyserver/storage/postgres/key_changes_table.go +++ b/keyserver/storage/postgres/key_changes_table.go @@ -17,9 +17,7 @@ package postgres import ( "context" "database/sql" - "math" - "github.com/Shopify/sarama" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/keyserver/storage/tables" ) @@ -78,9 +76,6 @@ func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, userID strin func (s *keyChangesStatements) SelectKeyChanges( ctx context.Context, fromOffset, toOffset int64, ) (userIDs []string, latestOffset int64, err error) { - if toOffset == sarama.OffsetNewest { - toOffset = math.MaxInt64 - } latestOffset = fromOffset rows, err := s.selectKeyChangesStmt.QueryContext(ctx, fromOffset, toOffset) if err != nil { diff --git a/keyserver/storage/sqlite3/key_changes_table.go b/keyserver/storage/sqlite3/key_changes_table.go index d43c15ca9..e035e8c9c 100644 --- a/keyserver/storage/sqlite3/key_changes_table.go +++ b/keyserver/storage/sqlite3/key_changes_table.go @@ -17,9 +17,7 @@ package sqlite3 import ( "context" "database/sql" - "math" - "github.com/Shopify/sarama" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/keyserver/storage/tables" ) @@ -76,9 +74,6 @@ func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, userID strin func (s *keyChangesStatements) SelectKeyChanges( ctx context.Context, fromOffset, toOffset int64, ) (userIDs []string, latestOffset int64, err error) { - if toOffset == sarama.OffsetNewest { - toOffset = math.MaxInt64 - } latestOffset = fromOffset rows, err := s.selectKeyChangesStmt.QueryContext(ctx, fromOffset, toOffset) if err != nil { diff --git a/keyserver/storage/storage_test.go b/keyserver/storage/storage_test.go index 2f8cf809b..c4c99d8c4 100644 --- a/keyserver/storage/storage_test.go +++ b/keyserver/storage/storage_test.go @@ -9,8 +9,8 @@ import ( "reflect" "testing" - "github.com/Shopify/sarama" "github.com/matrix-org/dendrite/keyserver/api" + "github.com/matrix-org/dendrite/keyserver/types" "github.com/matrix-org/dendrite/setup/config" ) @@ -50,7 +50,7 @@ func TestKeyChanges(t *testing.T) { MustNotError(t, err) deviceChangeIDC, err := db.StoreKeyChange(ctx, "@charlie:localhost") MustNotError(t, err) - userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDB, sarama.OffsetNewest) + userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDB, types.OffsetNewest) if err != nil { t.Fatalf("Failed to KeyChanges: %s", err) } @@ -74,7 +74,7 @@ func TestKeyChangesNoDupes(t *testing.T) { } deviceChangeID, err := db.StoreKeyChange(ctx, "@alice:localhost") MustNotError(t, err) - userIDs, latest, err := db.KeyChanges(ctx, 0, sarama.OffsetNewest) + userIDs, latest, err := db.KeyChanges(ctx, 0, types.OffsetNewest) if err != nil { t.Fatalf("Failed to KeyChanges: %s", err) } diff --git a/keyserver/storage/tables/interface.go b/keyserver/storage/tables/interface.go index 0d94c94cc..e44757e1a 100644 --- a/keyserver/storage/tables/interface.go +++ b/keyserver/storage/tables/interface.go @@ -46,7 +46,7 @@ type DeviceKeys interface { type KeyChanges interface { InsertKeyChange(ctx context.Context, userID string) (int64, error) // SelectKeyChanges returns the set (de-duplicated) of users who have changed their keys between the two offsets. - // Results are exclusive of fromOffset and inclusive of toOffset. A toOffset of sarama.OffsetNewest means no upper offset. + // Results are exclusive of fromOffset and inclusive of toOffset. A toOffset of types.OffsetNewest means no upper offset. SelectKeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) Prepare() error diff --git a/keyserver/types/storage.go b/keyserver/types/storage.go index 3480ec65f..7fb90454e 100644 --- a/keyserver/types/storage.go +++ b/keyserver/types/storage.go @@ -14,7 +14,18 @@ package types -import "github.com/matrix-org/gomatrixserverlib" +import ( + "math" + + "github.com/matrix-org/gomatrixserverlib" +) + +const ( + // OffsetNewest tells e.g. the database to get the most current data + OffsetNewest int64 = math.MaxInt64 + // OffsetOldest tells e.g. the database to get the oldest data + OffsetOldest int64 = 0 +) // KeyTypePurposeToInt maps a purpose to an integer, which is used in the // database to reduce the amount of space taken up by this column. diff --git a/roomserver/api/input.go b/roomserver/api/input.go index 4b0704b9f..45a9ef497 100644 --- a/roomserver/api/input.go +++ b/roomserver/api/input.go @@ -42,6 +42,19 @@ const ( KindOld ) +func (k Kind) String() string { + switch k { + case KindOutlier: + return "KindOutlier" + case KindNew: + return "KindNew" + case KindOld: + return "KindOld" + default: + return "(unknown)" + } +} + // DoNotSendToOtherServers tells us not to send the event to other matrix // servers. const DoNotSendToOtherServers = "" diff --git a/roomserver/internal/api.go b/roomserver/internal/api.go index 5b87e623d..fd963ad83 100644 --- a/roomserver/internal/api.go +++ b/roomserver/internal/api.go @@ -41,7 +41,7 @@ type RoomserverInternalAPI struct { fsAPI fsAPI.FederationInternalAPI asAPI asAPI.AppServiceQueryAPI JetStream nats.JetStreamContext - Durable nats.SubOpt + Durable string InputRoomEventTopic string // JetStream topic for new input room events OutputRoomEventTopic string // JetStream topic for new output room events PerspectiveServerNames []gomatrixserverlib.ServerName @@ -87,7 +87,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.FederationInternalA InputRoomEventTopic: r.InputRoomEventTopic, OutputRoomEventTopic: r.OutputRoomEventTopic, JetStream: r.JetStream, - Durable: r.Durable, + Durable: nats.Durable(r.Durable), ServerName: r.Cfg.Matrix.ServerName, FSAPI: fsAPI, KeyRing: keyRing, diff --git a/roomserver/internal/helpers/auth.go b/roomserver/internal/helpers/auth.go index ddda8081c..9af0bf591 100644 --- a/roomserver/internal/helpers/auth.go +++ b/roomserver/internal/helpers/auth.go @@ -20,17 +20,22 @@ import ( "sort" "github.com/matrix-org/dendrite/roomserver/state" - "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" ) +type checkForAuthAndSoftFailStorage interface { + state.StateResolutionStorage + StateEntriesForEventIDs(ctx context.Context, eventIDs []string) ([]types.StateEntry, error) + RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) +} + // CheckForSoftFail returns true if the event should be soft-failed // and false otherwise. The return error value should be checked before // the soft-fail bool. func CheckForSoftFail( ctx context.Context, - db storage.Database, + db checkForAuthAndSoftFailStorage, event *gomatrixserverlib.HeaderedEvent, stateEventIDs []string, ) (bool, error) { @@ -92,7 +97,7 @@ func CheckForSoftFail( // Returns the numeric IDs for the auth events. func CheckAuthEvents( ctx context.Context, - db storage.Database, + db checkForAuthAndSoftFailStorage, event *gomatrixserverlib.HeaderedEvent, authEventIDs []string, ) ([]types.EventNID, error) { @@ -193,7 +198,7 @@ func (ae *authEvents) lookupEvent(typeNID types.EventTypeNID, stateKey string) * // loadAuthEvents loads the events needed for authentication from the supplied room state. func loadAuthEvents( ctx context.Context, - db storage.Database, + db state.StateResolutionStorage, needed gomatrixserverlib.StateNeeded, state []types.StateEntry, ) (result authEvents, err error) { diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go index a38d56d7e..5bdec0a24 100644 --- a/roomserver/internal/input/input.go +++ b/roomserver/internal/input/input.go @@ -19,6 +19,7 @@ import ( "context" "encoding/json" "errors" + "fmt" "sync" "time" @@ -38,6 +39,19 @@ import ( "github.com/tidwall/gjson" ) +type retryAction int +type commitAction int + +const ( + doNotRetry retryAction = iota + retryLater +) + +const ( + commitTransaction commitAction = iota + rollbackTransaction +) + var keyContentFields = map[string]string{ "m.room.join_rules": "join_rule", "m.room.history_visibility": "history_visibility", @@ -101,7 +115,8 @@ func (r *Inputer) Start() error { _ = msg.InProgress() // resets the acknowledgement wait timer defer eventsInProgress.Delete(index) defer roomserverInputBackpressure.With(prometheus.Labels{"room_id": roomID}).Dec() - if err := r.processRoomEvent(context.Background(), &inputRoomEvent); err != nil { + action, err := r.processRoomEventUsingUpdater(context.Background(), roomID, &inputRoomEvent) + if err != nil { if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) { sentry.CaptureException(err) } @@ -111,7 +126,12 @@ func (r *Inputer) Start() error { "type": inputRoomEvent.Event.Type(), }).Warn("Roomserver failed to process async event") } - _ = msg.Ack() + switch action { + case retryLater: + _ = msg.Nak() + case doNotRetry: + _ = msg.Ack() + } }) }, // NATS wants to acknowledge automatically by default when the message is @@ -131,6 +151,37 @@ func (r *Inputer) Start() error { return err } +// processRoomEventUsingUpdater opens up a room updater and tries to +// process the event. It returns whether or not we should positively +// or negatively acknowledge the event (i.e. for NATS) and an error +// if it occurred. +func (r *Inputer) processRoomEventUsingUpdater( + ctx context.Context, + roomID string, + inputRoomEvent *api.InputRoomEvent, +) (retryAction, error) { + roomInfo, err := r.DB.RoomInfo(ctx, roomID) + if err != nil { + return doNotRetry, fmt.Errorf("r.DB.RoomInfo: %w", err) + } + updater, err := r.DB.GetRoomUpdater(ctx, roomInfo) + if err != nil { + return retryLater, fmt.Errorf("r.DB.GetRoomUpdater: %w", err) + } + action, err := r.processRoomEvent(ctx, updater, inputRoomEvent) + switch action { + case commitTransaction: + if cerr := updater.Commit(); cerr != nil { + return retryLater, fmt.Errorf("updater.Commit: %w", cerr) + } + case rollbackTransaction: + if rerr := updater.Rollback(); rerr != nil { + return retryLater, fmt.Errorf("updater.Rollback: %w", rerr) + } + } + return doNotRetry, err +} + // InputRoomEvents implements api.RoomserverInternalAPI func (r *Inputer) InputRoomEvents( ctx context.Context, @@ -161,7 +212,6 @@ func (r *Inputer) InputRoomEvents( } } else { responses := make(chan error, len(request.InputRoomEvents)) - defer close(responses) for _, e := range request.InputRoomEvents { inputRoomEvent := e roomID := inputRoomEvent.Event.RoomID() @@ -178,7 +228,7 @@ func (r *Inputer) InputRoomEvents( worker.Act(nil, func() { defer eventsInProgress.Delete(index) defer roomserverInputBackpressure.With(prometheus.Labels{"room_id": roomID}).Dec() - err := r.processRoomEvent(ctx, &inputRoomEvent) + _, err := r.processRoomEventUsingUpdater(ctx, roomID, &inputRoomEvent) if err != nil { if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) { sentry.CaptureException(err) diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 2dc096674..0ca5c31a9 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -29,6 +29,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/internal/helpers" "github.com/matrix-org/dendrite/roomserver/state" + "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -67,14 +68,15 @@ var processRoomEventDuration = prometheus.NewHistogramVec( // nolint:gocyclo func (r *Inputer) processRoomEvent( ctx context.Context, + updater *shared.RoomUpdater, input *api.InputRoomEvent, -) (err error) { +) (commitAction, error) { select { case <-ctx.Done(): // Before we do anything, make sure the context hasn't expired for this pending task. // If it has then we'll give up straight away — it's probably a synchronous input // request and the caller has already given up, but the inbox task was still queued. - return context.DeadlineExceeded + return rollbackTransaction, context.DeadlineExceeded default: } @@ -93,13 +95,21 @@ func (r *Inputer) processRoomEvent( logger := util.GetLogger(ctx).WithFields(logrus.Fields{ "event_id": event.EventID(), "room_id": event.RoomID(), + "kind": input.Kind, + "origin": input.Origin, "type": event.Type(), }) + if input.HasState { + logger = logger.WithFields(logrus.Fields{ + "has_state": input.HasState, + "state_ids": len(input.StateEventIDs), + }) + } // if we have already got this event then do not process it again, if the input kind is an outlier. // Outliers contain no extra information which may warrant a re-processing. if input.Kind == api.KindOutlier { - evs, err2 := r.DB.EventsFromIDs(ctx, []string{event.EventID()}) + evs, err2 := updater.EventsFromIDs(ctx, []string{event.EventID()}) if err2 == nil && len(evs) == 1 { // check hash matches if we're on early room versions where the event ID was a random string idFormat, err2 := headered.RoomVersion.EventIDFormat() @@ -108,11 +118,11 @@ func (r *Inputer) processRoomEvent( case gomatrixserverlib.EventIDFormatV1: if bytes.Equal(event.EventReference().EventSHA256, evs[0].EventReference().EventSHA256) { logger.Debugf("Already processed event; ignoring") - return nil + return rollbackTransaction, nil } default: logger.Debugf("Already processed event; ignoring") - return nil + return rollbackTransaction, nil } } } @@ -126,21 +136,41 @@ func (r *Inputer) processRoomEvent( AuthEventIDs: event.AuthEventIDs(), PrevEventIDs: event.PrevEventIDs(), } - if err = r.Queryer.QueryMissingAuthPrevEvents(ctx, missingReq, missingRes); err != nil { - return fmt.Errorf("r.Queryer.QueryMissingAuthPrevEvents: %w", err) + if err := r.Queryer.QueryMissingAuthPrevEvents(ctx, missingReq, missingRes); err != nil { + return rollbackTransaction, fmt.Errorf("r.Queryer.QueryMissingAuthPrevEvents: %w", err) } } - if len(missingRes.MissingAuthEventIDs) > 0 || len(missingRes.MissingPrevEventIDs) > 0 { + missingAuth := len(missingRes.MissingAuthEventIDs) > 0 + missingPrev := !input.HasState && len(missingRes.MissingPrevEventIDs) > 0 + + if missingAuth || missingPrev { serverReq := &fedapi.QueryJoinedHostServerNamesInRoomRequest{ RoomID: event.RoomID(), ExcludeSelf: true, } - if err = r.FSAPI.QueryJoinedHostServerNamesInRoom(ctx, serverReq, serverRes); err != nil { - return fmt.Errorf("r.FSAPI.QueryJoinedHostServerNamesInRoom: %w", err) + if err := r.FSAPI.QueryJoinedHostServerNamesInRoom(ctx, serverReq, serverRes); err != nil { + return rollbackTransaction, fmt.Errorf("r.FSAPI.QueryJoinedHostServerNamesInRoom: %w", err) + } + // Sort all of the servers into a map so that we can randomise + // their order. Then make sure that the input origin and the + // event origin are first on the list. + servers := map[gomatrixserverlib.ServerName]struct{}{} + for _, server := range serverRes.ServerNames { + servers[server] = struct{}{} + } + serverRes.ServerNames = serverRes.ServerNames[:0] + if input.Origin != "" { + serverRes.ServerNames = append(serverRes.ServerNames, input.Origin) + delete(servers, input.Origin) + } + if origin := event.Origin(); origin != input.Origin { + serverRes.ServerNames = append(serverRes.ServerNames, origin) + delete(servers, origin) + } + for server := range servers { + serverRes.ServerNames = append(serverRes.ServerNames, server) + delete(servers, server) } - } - if input.Origin != "" { - serverRes.ServerNames = append(serverRes.ServerNames, input.Origin) } // First of all, check that the auth events of the event are known. @@ -148,8 +178,8 @@ func (r *Inputer) processRoomEvent( isRejected := false authEvents := gomatrixserverlib.NewAuthEvents(nil) knownEvents := map[string]*types.Event{} - if err = r.fetchAuthEvents(ctx, logger, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil { - return fmt.Errorf("r.checkForMissingAuthEvents: %w", err) + if err := r.fetchAuthEvents(ctx, updater, logger, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil { + return rollbackTransaction, fmt.Errorf("r.fetchAuthEvents: %w", err) } // Check if the event is allowed by its auth events. If it isn't then @@ -157,7 +187,7 @@ func (r *Inputer) processRoomEvent( var rejectionErr error if rejectionErr = gomatrixserverlib.Allowed(event, &authEvents); rejectionErr != nil { isRejected = true - logger.WithError(rejectionErr).Warnf("Event %s rejected", event.EventID()) + logger.WithError(rejectionErr).Warnf("Event %s not allowed by auth events", event.EventID()) } // Accumulate the auth event NIDs. @@ -165,7 +195,7 @@ func (r *Inputer) processRoomEvent( authEventNIDs := make([]types.EventNID, 0, len(authEventIDs)) for _, authEventID := range authEventIDs { if _, ok := knownEvents[authEventID]; !ok { - return fmt.Errorf("missing auth event %s", authEventID) + return rollbackTransaction, fmt.Errorf("missing auth event %s", authEventID) } authEventNIDs = append(authEventNIDs, knownEvents[authEventID].EventNID) } @@ -174,9 +204,10 @@ func (r *Inputer) processRoomEvent( if input.Kind == api.KindNew { // Check that the event passes authentication checks based on the // current room state. - softfail, err = helpers.CheckForSoftFail(ctx, r.DB, headered, input.StateEventIDs) + var err error + softfail, err = helpers.CheckForSoftFail(ctx, updater, headered, input.StateEventIDs) if err != nil { - logger.WithError(err).Info("Error authing soft-failed event") + logger.WithError(err).Warn("Error authing soft-failed event") } } @@ -190,7 +221,6 @@ func (r *Inputer) processRoomEvent( // typical federated room join) then we won't bother trying to fetch prev events // because we may not be allowed to see them and we have no choice but to trust // the state event IDs provided to us in the join instead. - missingPrev := !input.HasState && len(missingRes.MissingPrevEventIDs) > 0 if missingPrev && input.Kind == api.KindNew { // Don't do this for KindOld events, otherwise old events that we fetch // to satisfy missing prev events/state will end up recursively calling @@ -200,18 +230,15 @@ func (r *Inputer) processRoomEvent( origin: input.Origin, inputer: r, queryer: r.Queryer, - db: r.DB, + db: updater, federation: r.FSAPI, keys: r.KeyRing, roomsMu: internal.NewMutexByRoom(), - servers: map[gomatrixserverlib.ServerName]struct{}{}, + servers: serverRes.ServerNames, hadEvents: map[string]bool{}, haveEvents: map[string]*gomatrixserverlib.HeaderedEvent{}, } - for _, serverName := range serverRes.ServerNames { - missingState.servers[serverName] = struct{}{} - } - if err = missingState.processEventWithMissingState(ctx, event, headered.RoomVersion); err != nil { + if err := missingState.processEventWithMissingState(ctx, event, headered.RoomVersion); err != nil { isRejected = true rejectionErr = fmt.Errorf("missingState.processEventWithMissingState: %w", err) } else { @@ -224,16 +251,16 @@ func (r *Inputer) processRoomEvent( } // Store the event. - _, _, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, authEventNIDs, isRejected) + _, _, stateAtEvent, redactionEvent, redactedEventID, err := updater.StoreEvent(ctx, event, authEventNIDs, isRejected) if err != nil { - return fmt.Errorf("r.DB.StoreEvent: %w", err) + return rollbackTransaction, fmt.Errorf("updater.StoreEvent: %w", err) } // if storing this event results in it being redacted then do so. if !isRejected && redactedEventID == event.EventID() { r, rerr := eventutil.RedactEvent(redactionEvent, event) if rerr != nil { - return fmt.Errorf("eventutil.RedactEvent: %w", rerr) + return rollbackTransaction, fmt.Errorf("eventutil.RedactEvent: %w", rerr) } event = r } @@ -244,36 +271,40 @@ func (r *Inputer) processRoomEvent( if input.Kind == api.KindOutlier { logger.Debug("Stored outlier") hooks.Run(hooks.KindNewEventPersisted, headered) - return nil + return commitTransaction, nil } - roomInfo, err := r.DB.RoomInfo(ctx, event.RoomID()) + roomInfo, err := updater.RoomInfo(ctx, event.RoomID()) if err != nil { - return fmt.Errorf("r.DB.RoomInfo: %w", err) + return rollbackTransaction, fmt.Errorf("updater.RoomInfo: %w", err) } if roomInfo == nil { - return fmt.Errorf("r.DB.RoomInfo missing for room %s", event.RoomID()) + return rollbackTransaction, fmt.Errorf("updater.RoomInfo missing for room %s", event.RoomID()) } if !missingPrev && stateAtEvent.BeforeStateSnapshotNID == 0 { // We haven't calculated a state for this event yet. // Lets calculate one. - err = r.calculateAndSetState(ctx, input, roomInfo, &stateAtEvent, event, isRejected) + err = r.calculateAndSetState(ctx, updater, input, roomInfo, &stateAtEvent, event, isRejected) if err != nil { - return fmt.Errorf("r.calculateAndSetState: %w", err) + return rollbackTransaction, fmt.Errorf("r.calculateAndSetState: %w", err) } } // We stop here if the event is rejected: We've stored it but won't update forward extremities or notify anyone about it. if isRejected || softfail { - logger.WithError(rejectionErr).WithField("soft_fail", softfail).Debug("Stored rejected event") - return rejectionErr + logger.WithError(rejectionErr).WithFields(logrus.Fields{ + "soft_fail": softfail, + "missing_prev": missingPrev, + }).Warn("Stored rejected event") + return commitTransaction, rejectionErr } switch input.Kind { case api.KindNew: if err = r.updateLatestEvents( ctx, // context + updater, // room updater roomInfo, // room info for the room being updated stateAtEvent, // state at event (below) event, // event @@ -281,7 +312,7 @@ func (r *Inputer) processRoomEvent( input.TransactionID, // transaction ID input.HasState, // rewrites state? ); err != nil { - return fmt.Errorf("r.updateLatestEvents: %w", err) + return rollbackTransaction, fmt.Errorf("r.updateLatestEvents: %w", err) } case api.KindOld: err = r.WriteOutputEvents(event.RoomID(), []api.OutputEvent{ @@ -293,7 +324,7 @@ func (r *Inputer) processRoomEvent( }, }) if err != nil { - return fmt.Errorf("r.WriteOutputEvents (old): %w", err) + return rollbackTransaction, fmt.Errorf("r.WriteOutputEvents (old): %w", err) } } @@ -312,14 +343,14 @@ func (r *Inputer) processRoomEvent( }, }) if err != nil { - return fmt.Errorf("r.WriteOutputEvents (redactions): %w", err) + return rollbackTransaction, fmt.Errorf("r.WriteOutputEvents (redactions): %w", err) } } // Everything was OK — the latest events updater didn't error and // we've sent output events. Finally, generate a hook call. hooks.Run(hooks.KindNewEventPersisted, headered) - return nil + return commitTransaction, nil } // fetchAuthEvents will check to see if any of the @@ -331,16 +362,13 @@ func (r *Inputer) processRoomEvent( // they are now in the database. func (r *Inputer) fetchAuthEvents( ctx context.Context, + updater *shared.RoomUpdater, logger *logrus.Entry, event *gomatrixserverlib.HeaderedEvent, auth *gomatrixserverlib.AuthEvents, known map[string]*types.Event, servers []gomatrixserverlib.ServerName, ) error { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, MaximumMissingProcessingTime) - defer cancel() - unknown := map[string]struct{}{} authEventIDs := event.AuthEventIDs() if len(authEventIDs) == 0 { @@ -348,7 +376,7 @@ func (r *Inputer) fetchAuthEvents( } for _, authEventID := range authEventIDs { - authEvents, err := r.DB.EventsFromIDs(ctx, []string{authEventID}) + authEvents, err := updater.EventsFromIDs(ctx, []string{authEventID}) if err != nil || len(authEvents) == 0 || authEvents[0].Event == nil { unknown[authEventID] = struct{}{} continue @@ -396,12 +424,11 @@ func (r *Inputer) fetchAuthEvents( continue } - // Check the signatures of the event. - // TODO: It really makes sense for the federation API to be doing this, - // because then it can attempt another server if one serves up an event - // with an invalid signature. For now this will do. + // Check the signatures of the event. If this fails then we'll simply + // skip it, because gomatrixserverlib.Allowed() will notice a problem + // if a critical event is missing anyway. if err := authEvent.VerifyEventSignatures(ctx, r.FSAPI.KeyRing()); err != nil { - return fmt.Errorf("event.VerifyEventSignatures: %w", err) + continue } // In order to store the new auth event, we need to know its auth chain @@ -428,9 +455,9 @@ func (r *Inputer) fetchAuthEvents( } // Finally, store the event in the database. - eventNID, _, _, _, _, err := r.DB.StoreEvent(ctx, authEvent, authEventNIDs, isRejected) + eventNID, _, _, _, _, err := updater.StoreEvent(ctx, authEvent, authEventNIDs, isRejected) if err != nil { - return fmt.Errorf("r.DB.StoreEvent: %w", err) + return fmt.Errorf("updater.StoreEvent: %w", err) } // Now we know about this event, it was stored and the signatures were OK. @@ -445,6 +472,7 @@ func (r *Inputer) fetchAuthEvents( func (r *Inputer) calculateAndSetState( ctx context.Context, + updater *shared.RoomUpdater, input *api.InputRoomEvent, roomInfo *types.RoomInfo, stateAtEvent *types.StateAtEvent, @@ -452,14 +480,14 @@ func (r *Inputer) calculateAndSetState( isRejected bool, ) error { var err error - roomState := state.NewStateResolution(r.DB, roomInfo) + roomState := state.NewStateResolution(updater, roomInfo) - if input.HasState && !isRejected { + if input.HasState { // Check here if we think we're in the room already. stateAtEvent.Overwrite = true var joinEventNIDs []types.EventNID // Request join memberships only for local users only. - if joinEventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true); err == nil { + if joinEventNIDs, err = updater.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true); err == nil { // If we have no local users that are joined to the room then any state about // the room that we have is quite possibly out of date. Therefore in that case // we should overwrite it rather than merge it. @@ -469,13 +497,13 @@ func (r *Inputer) calculateAndSetState( // We've been told what the state at the event is so we don't need to calculate it. // Check that those state events are in the database and store the state. var entries []types.StateEntry - if entries, err = r.DB.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil { - return fmt.Errorf("r.DB.StateEntriesForEventIDs: %w", err) + if entries, err = updater.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil { + return fmt.Errorf("updater.StateEntriesForEventIDs: %w", err) } entries = types.DeduplicateStateEntries(entries) - if stateAtEvent.BeforeStateSnapshotNID, err = r.DB.AddState(ctx, roomInfo.RoomNID, nil, entries); err != nil { - return fmt.Errorf("r.DB.AddState: %w", err) + if stateAtEvent.BeforeStateSnapshotNID, err = updater.AddState(ctx, roomInfo.RoomNID, nil, entries); err != nil { + return fmt.Errorf("updater.AddState: %w", err) } } else { stateAtEvent.Overwrite = false @@ -486,7 +514,7 @@ func (r *Inputer) calculateAndSetState( } } - err = r.DB.SetState(ctx, stateAtEvent.EventNID, stateAtEvent.BeforeStateSnapshotNID) + err = updater.SetState(ctx, stateAtEvent.EventNID, stateAtEvent.BeforeStateSnapshotNID) if err != nil { return fmt.Errorf("r.DB.SetState: %w", err) } diff --git a/roomserver/internal/input/input_latest_events.go b/roomserver/internal/input/input_latest_events.go index 6137941e1..5173d3ab2 100644 --- a/roomserver/internal/input/input_latest_events.go +++ b/roomserver/internal/input/input_latest_events.go @@ -20,7 +20,6 @@ import ( "context" "fmt" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/storage/shared" @@ -48,6 +47,7 @@ import ( // Can only be called once at a time func (r *Inputer) updateLatestEvents( ctx context.Context, + updater *shared.RoomUpdater, roomInfo *types.RoomInfo, stateAtEvent types.StateAtEvent, event *gomatrixserverlib.Event, @@ -55,13 +55,6 @@ func (r *Inputer) updateLatestEvents( transactionID *api.TransactionID, rewritesState bool, ) (err error) { - updater, err := r.DB.GetLatestEventsForUpdate(ctx, *roomInfo) - if err != nil { - return fmt.Errorf("r.DB.GetLatestEventsForUpdate: %w", err) - } - succeeded := false - defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err) - u := latestEventsUpdater{ ctx: ctx, api: r, @@ -78,7 +71,6 @@ func (r *Inputer) updateLatestEvents( return fmt.Errorf("u.doUpdateLatestEvents: %w", err) } - succeeded = true return } @@ -89,7 +81,7 @@ func (r *Inputer) updateLatestEvents( type latestEventsUpdater struct { ctx context.Context api *Inputer - updater *shared.LatestEventsUpdater + updater *shared.RoomUpdater roomInfo *types.RoomInfo stateAtEvent types.StateAtEvent event *gomatrixserverlib.Event @@ -199,7 +191,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error { func (u *latestEventsUpdater) latestState() error { var err error - roomState := state.NewStateResolution(u.api.DB, u.roomInfo) + roomState := state.NewStateResolution(u.updater, u.roomInfo) // Work out if the state at the extremities has actually changed // or not. If they haven't then we won't bother doing all of the @@ -413,7 +405,7 @@ func (u *latestEventsUpdater) extraEventsForIDs(roomVersion gomatrixserverlib.Ro if len(extraEventIDs) == 0 { return nil, nil } - extraEvents, err := u.api.DB.EventsFromIDs(u.ctx, extraEventIDs) + extraEvents, err := u.updater.EventsFromIDs(u.ctx, extraEventIDs) if err != nil { return nil, err } @@ -436,7 +428,7 @@ func (u *latestEventsUpdater) stateEventMap() (map[types.EventNID]string, error) stateEventNIDs = append(stateEventNIDs, entry.EventNID) } stateEventNIDs = stateEventNIDs[:util.SortAndUnique(eventNIDSorter(stateEventNIDs))] - return u.api.DB.EventIDs(u.ctx, stateEventNIDs) + return u.updater.EventIDs(u.ctx, stateEventNIDs) } type eventNIDSorter []types.EventNID diff --git a/roomserver/internal/input/input_membership.go b/roomserver/internal/input/input_membership.go index 2511097d0..ff3ed7e5d 100644 --- a/roomserver/internal/input/input_membership.go +++ b/roomserver/internal/input/input_membership.go @@ -31,7 +31,7 @@ import ( // consumers about the invites added or retired by the change in current state. func (r *Inputer) updateMemberships( ctx context.Context, - updater *shared.LatestEventsUpdater, + updater *shared.RoomUpdater, removed, added []types.StateEntry, ) ([]api.OutputEvent, error) { changes := membershipChanges(removed, added) @@ -79,7 +79,7 @@ func (r *Inputer) updateMemberships( } func (r *Inputer) updateMembership( - updater *shared.LatestEventsUpdater, + updater *shared.RoomUpdater, targetUserNID types.EventStateKeyNID, remove, add *gomatrixserverlib.Event, updates []api.OutputEvent, diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go index aa2b94f8a..4d3306660 100644 --- a/roomserver/internal/input/input_missing.go +++ b/roomserver/internal/input/input_missing.go @@ -11,7 +11,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/internal/query" - "github.com/matrix-org/dendrite/roomserver/storage" + "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/sirupsen/logrus" @@ -19,13 +19,13 @@ import ( type missingStateReq struct { origin gomatrixserverlib.ServerName - db storage.Database + db *shared.RoomUpdater inputer *Inputer queryer *query.Queryer keys gomatrixserverlib.JSONVerifier federation fedapi.FederationInternalAPI roomsMu *internal.MutexByRoom - servers map[gomatrixserverlib.ServerName]struct{} + servers []gomatrixserverlib.ServerName hadEvents map[string]bool hadEventsMutex sync.Mutex haveEvents map[string]*gomatrixserverlib.HeaderedEvent @@ -37,10 +37,6 @@ type missingStateReq struct { func (t *missingStateReq) processEventWithMissingState( ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion, ) error { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, MaximumMissingProcessingTime) - defer cancel() - // We are missing the previous events for this events. // This means that there is a gap in our view of the history of the // room. There two ways that we can handle such a gap: @@ -78,7 +74,7 @@ func (t *missingStateReq) processEventWithMissingState( // we can just inject all the newEvents as new as we may have only missed 1 or 2 events and have filled // in the gap in the DAG for _, newEvent := range newEvents { - err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{ + _, err = t.inputer.processRoomEvent(ctx, t.db, &api.InputRoomEvent{ Kind: api.KindNew, Event: newEvent.Headered(roomVersion), Origin: t.origin, @@ -187,7 +183,7 @@ func (t *missingStateReq) processEventWithMissingState( } // TODO: we could do this concurrently? for _, ire := range outlierRoomEvents { - if err = t.inputer.processRoomEvent(ctx, &ire); err != nil { + if _, err = t.inputer.processRoomEvent(ctx, t.db, &ire); err != nil { return fmt.Errorf("t.inputer.processRoomEvent[outlier]: %w", err) } } @@ -200,7 +196,7 @@ func (t *missingStateReq) processEventWithMissingState( stateIDs = append(stateIDs, event.EventID()) } - err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{ + _, err = t.inputer.processRoomEvent(ctx, t.db, &api.InputRoomEvent{ Kind: api.KindOld, Event: backwardsExtremity.Headered(roomVersion), Origin: t.origin, @@ -217,7 +213,7 @@ func (t *missingStateReq) processEventWithMissingState( // they will automatically fast-forward based on the room state at the // extremity in the last step. for _, newEvent := range newEvents { - err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{ + _, err = t.inputer.processRoomEvent(ctx, t.db, &api.InputRoomEvent{ Kind: api.KindOld, Event: newEvent.Headered(roomVersion), Origin: t.origin, @@ -417,7 +413,7 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e *gomatrixserve } var missingResp *gomatrixserverlib.RespMissingEvents - for server := range t.servers { + for _, server := range t.servers { var m gomatrixserverlib.RespMissingEvents if m, err = t.federation.LookupMissingEvents(ctx, server, e.RoomID(), gomatrixserverlib.MissingEvents{ Limit: 20, @@ -666,7 +662,7 @@ func (t *missingStateReq) createRespStateFromStateIDs(stateIDs gomatrixserverlib for i := range stateIDs.StateEventIDs { ev, ok := t.haveEvents[stateIDs.StateEventIDs[i]] if !ok { - logrus.Warnf("Missing state event in createRespStateFromStateIDs: %s", stateIDs.StateEventIDs[i]) + logrus.Tracef("Missing state event in createRespStateFromStateIDs: %s", stateIDs.StateEventIDs[i]) continue } respState.StateEvents = append(respState.StateEvents, ev.Unwrap()) @@ -674,7 +670,7 @@ func (t *missingStateReq) createRespStateFromStateIDs(stateIDs gomatrixserverlib for i := range stateIDs.AuthEventIDs { ev, ok := t.haveEvents[stateIDs.AuthEventIDs[i]] if !ok { - logrus.Warnf("Missing auth event in createRespStateFromStateIDs: %s", stateIDs.AuthEventIDs[i]) + logrus.Tracef("Missing auth event in createRespStateFromStateIDs: %s", stateIDs.AuthEventIDs[i]) continue } respState.AuthEvents = append(respState.AuthEvents, ev.Unwrap()) @@ -700,7 +696,7 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs } var event *gomatrixserverlib.Event found := false - for serverName := range t.servers { + for _, serverName := range t.servers { reqctx, cancel := context.WithTimeout(ctx, time.Second*30) defer cancel() txn, err := t.federation.GetEvent(reqctx, serverName, missingEventID) @@ -718,7 +714,7 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs } event, err = gomatrixserverlib.NewEventFromUntrustedJSON(txn.PDUs[0], roomVersion) if err != nil { - util.GetLogger(ctx).WithError(err).WithField("event_id", missingEventID).Warnf("Transaction: Failed to parse event JSON of event") + util.GetLogger(ctx).WithError(err).WithField("event_id", missingEventID).Warnf("Failed to parse event JSON of event returned from /event") continue } found = true @@ -729,7 +725,7 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs return nil, fmt.Errorf("wasn't able to find event via %d server(s)", len(t.servers)) } if err := event.VerifyEventSignatures(ctx, t.keys); err != nil { - util.GetLogger(ctx).WithError(err).Warnf("Transaction: Couldn't validate signature of event %q", event.EventID()) + util.GetLogger(ctx).WithError(err).Warnf("Couldn't validate signature of event %q from /event", event.EventID()) return nil, verifySigError{event.EventID(), err} } return t.cacheAndReturn(event.Headered(roomVersion)), nil diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go index e23ed47be..6559cd081 100644 --- a/roomserver/internal/perform/perform_invite.go +++ b/roomserver/internal/perform/perform_invite.go @@ -27,6 +27,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" log "github.com/sirupsen/logrus" ) @@ -54,18 +55,23 @@ func (r *Inviter) PerformInvite( return nil, fmt.Errorf("failed to load RoomInfo: %w", err) } - log.WithFields(log.Fields{ - "event_id": event.EventID(), - "room_id": roomID, - "room_version": req.RoomVersion, - "target_user_id": targetUserID, - "room_info_exists": info != nil, - }).Debug("processing invite event") - _, domain, _ := gomatrixserverlib.SplitID('@', targetUserID) isTargetLocal := domain == r.Cfg.Matrix.ServerName isOriginLocal := event.Origin() == r.Cfg.Matrix.ServerName + logger := util.GetLogger(ctx).WithFields(map[string]interface{}{ + "inviter": event.Sender(), + "invitee": *event.StateKey(), + "room_id": roomID, + "event_id": event.EventID(), + }) + logger.WithFields(log.Fields{ + "room_version": req.RoomVersion, + "room_info_exists": info != nil, + "target_local": isTargetLocal, + "origin_local": isOriginLocal, + }).Debug("processing invite event") + inviteState := req.InviteRoomState if len(inviteState) == 0 && info != nil { var is []gomatrixserverlib.InviteV2StrippedState @@ -122,75 +128,17 @@ func (r *Inviter) PerformInvite( Code: api.PerformErrorNotAllowed, Msg: "User is already joined to room", } + logger.Debugf("user already joined") return nil, nil } - if isOriginLocal { - // The invite originated locally. Therefore we have a responsibility to - // try and see if the user is allowed to make this invite. We can't do - // this for invites coming in over federation - we have to take those on - // trust. - _, err = helpers.CheckAuthEvents(ctx, r.DB, event, event.AuthEventIDs()) - if err != nil { - log.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", event.AuthEventIDs()).Error( - "processInviteEvent.checkAuthEvents failed for event", - ) - res.Error = &api.PerformError{ - Msg: err.Error(), - Code: api.PerformErrorNotAllowed, - } - } - - // If the invite originated from us and the target isn't local then we - // should try and send the invite over federation first. It might be - // that the remote user doesn't exist, in which case we can give up - // processing here. - if req.SendAsServer != api.DoNotSendToOtherServers && !isTargetLocal { - fsReq := &federationAPI.PerformInviteRequest{ - RoomVersion: req.RoomVersion, - Event: event, - InviteRoomState: inviteState, - } - fsRes := &federationAPI.PerformInviteResponse{} - if err = r.FSAPI.PerformInvite(ctx, fsReq, fsRes); err != nil { - res.Error = &api.PerformError{ - Msg: err.Error(), - Code: api.PerformErrorNotAllowed, - } - log.WithError(err).WithField("event_id", event.EventID()).Error("r.FSAPI.PerformInvite failed") - return nil, nil - } - event = fsRes.Event - } - - // Send the invite event to the roomserver input stream. This will - // notify existing users in the room about the invite, update the - // membership table and ensure that the event is ready and available - // to use as an auth event when accepting the invite. - inputReq := &api.InputRoomEventsRequest{ - InputRoomEvents: []api.InputRoomEvent{ - { - Kind: api.KindNew, - Event: event, - Origin: event.Origin(), - SendAsServer: req.SendAsServer, - }, - }, - } - inputRes := &api.InputRoomEventsResponse{} - r.Inputer.InputRoomEvents(context.Background(), inputReq, inputRes) - if err = inputRes.Err(); err != nil { - res.Error = &api.PerformError{ - Msg: fmt.Sprintf("r.InputRoomEvents: %s", err.Error()), - Code: api.PerformErrorNotAllowed, - } - log.WithError(err).WithField("event_id", event.EventID()).Error("r.InputRoomEvents failed") - return nil, nil - } - } else { + if !isOriginLocal { // The invite originated over federation. Process the membership // update, which will notify the sync API etc about the incoming - // invite. + // invite. We do NOT send an InputRoomEvent for the invite as it + // will never pass auth checks due to lacking room state, but we + // still need to tell the client about the invite so we can accept + // it, hence we return an output event to send to the sync api. updater, err := r.DB.MembershipUpdater(ctx, roomID, targetUserID, isTargetLocal, req.RoomVersion) if err != nil { return nil, fmt.Errorf("r.DB.MembershipUpdater: %w", err) @@ -205,10 +153,77 @@ func (r *Inviter) PerformInvite( if err = updater.Commit(); err != nil { return nil, fmt.Errorf("updater.Commit: %w", err) } - + logger.Debugf("updated membership to invite and sending invite OutputEvent") return outputUpdates, nil } + // The invite originated locally. Therefore we have a responsibility to + // try and see if the user is allowed to make this invite. We can't do + // this for invites coming in over federation - we have to take those on + // trust. + _, err = helpers.CheckAuthEvents(ctx, r.DB, event, event.AuthEventIDs()) + if err != nil { + logger.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", event.AuthEventIDs()).Error( + "processInviteEvent.checkAuthEvents failed for event", + ) + res.Error = &api.PerformError{ + Msg: err.Error(), + Code: api.PerformErrorNotAllowed, + } + return nil, nil + } + + // If the invite originated from us and the target isn't local then we + // should try and send the invite over federation first. It might be + // that the remote user doesn't exist, in which case we can give up + // processing here. + if req.SendAsServer != api.DoNotSendToOtherServers && !isTargetLocal { + fsReq := &federationAPI.PerformInviteRequest{ + RoomVersion: req.RoomVersion, + Event: event, + InviteRoomState: inviteState, + } + fsRes := &federationAPI.PerformInviteResponse{} + if err = r.FSAPI.PerformInvite(ctx, fsReq, fsRes); err != nil { + res.Error = &api.PerformError{ + Msg: err.Error(), + Code: api.PerformErrorNotAllowed, + } + logger.WithError(err).WithField("event_id", event.EventID()).Error("r.FSAPI.PerformInvite failed") + return nil, nil + } + event = fsRes.Event + logger.Debugf("Federated PerformInvite success with event ID %s", event.EventID()) + } + + // Send the invite event to the roomserver input stream. This will + // notify existing users in the room about the invite, update the + // membership table and ensure that the event is ready and available + // to use as an auth event when accepting the invite. + // It will NOT notify the invitee of this invite. + inputReq := &api.InputRoomEventsRequest{ + InputRoomEvents: []api.InputRoomEvent{ + { + Kind: api.KindNew, + Event: event, + Origin: event.Origin(), + SendAsServer: req.SendAsServer, + }, + }, + } + inputRes := &api.InputRoomEventsResponse{} + r.Inputer.InputRoomEvents(context.Background(), inputReq, inputRes) + if err = inputRes.Err(); err != nil { + res.Error = &api.PerformError{ + Msg: fmt.Sprintf("r.InputRoomEvents: %s", err.Error()), + Code: api.PerformErrorNotAllowed, + } + logger.WithError(err).WithField("event_id", event.EventID()).Error("r.InputRoomEvents failed") + return nil, nil + } + + // Don't notify the sync api of this event in the same way as a federated invite so the invitee + // gets the invite, as the roomserver will do this when it processes the m.room.member invite. return nil, nil } diff --git a/roomserver/internal/perform/perform_join.go b/roomserver/internal/perform/perform_join.go index 2b0bccda6..9d2a66d4c 100644 --- a/roomserver/internal/perform/perform_join.go +++ b/roomserver/internal/perform/perform_join.go @@ -51,13 +51,15 @@ func (r *Joiner) PerformJoin( req *rsAPI.PerformJoinRequest, res *rsAPI.PerformJoinResponse, ) { - roomID, joinedVia, err := r.performJoin(ctx, req) + logger := logrus.WithContext(ctx).WithFields(logrus.Fields{ + "room_id": req.RoomIDOrAlias, + "user_id": req.UserID, + "servers": req.ServerNames, + }) + logger.Info("User requested to room join") + roomID, joinedVia, err := r.performJoin(context.Background(), req) if err != nil { - logrus.WithContext(ctx).WithFields(logrus.Fields{ - "room_id": req.RoomIDOrAlias, - "user_id": req.UserID, - "servers": req.ServerNames, - }).WithError(err).Error("Failed to join room") + logger.WithError(err).Error("Failed to join room") sentry.CaptureException(err) perr, ok := err.(*rsAPI.PerformError) if ok { @@ -67,7 +69,9 @@ func (r *Joiner) PerformJoin( Msg: err.Error(), } } + return } + logger.Info("User joined room successfully") res.RoomID = roomID res.JoinedVia = joinedVia } diff --git a/roomserver/internal/perform/perform_leave.go b/roomserver/internal/perform/perform_leave.go index b19916491..12784e5f5 100644 --- a/roomserver/internal/perform/perform_leave.go +++ b/roomserver/internal/perform/perform_leave.go @@ -51,13 +51,17 @@ func (r *Leaver) PerformLeave( if domain != r.Cfg.Matrix.ServerName { return nil, fmt.Errorf("user %q does not belong to this homeserver", req.UserID) } + logger := logrus.WithContext(ctx).WithFields(logrus.Fields{ + "room_id": req.RoomID, + "user_id": req.UserID, + }) + logger.Info("User requested to leave join") if strings.HasPrefix(req.RoomID, "!") { - output, err := r.performLeaveRoomByID(ctx, req, res) + output, err := r.performLeaveRoomByID(context.Background(), req, res) if err != nil { - logrus.WithContext(ctx).WithFields(logrus.Fields{ - "room_id": req.RoomID, - "user_id": req.UserID, - }).WithError(err).Error("Failed to leave room") + logger.WithError(err).Error("Failed to leave room") + } else { + logger.Info("User left room successfully") } return output, err } diff --git a/roomserver/roomserver.go b/roomserver/roomserver.go index 669957be1..e1b84b80c 100644 --- a/roomserver/roomserver.go +++ b/roomserver/roomserver.go @@ -50,7 +50,7 @@ func NewInternalAPI( logrus.WithError(err).Panicf("failed to connect to room server db") } - js, _, _ := jetstream.Prepare(&cfg.Matrix.JetStream) + js := jetstream.Prepare(&cfg.Matrix.JetStream) return internal.NewRoomserverAPI( cfg, roomserverDB, js, diff --git a/roomserver/state/state.go b/roomserver/state/state.go index 15d592b46..e5f69521e 100644 --- a/roomserver/state/state.go +++ b/roomserver/state/state.go @@ -22,7 +22,6 @@ import ( "sort" "time" - "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/util" "github.com/prometheus/client_golang/prometheus" @@ -30,13 +29,25 @@ import ( "github.com/matrix-org/gomatrixserverlib" ) +type StateResolutionStorage interface { + EventTypeNIDs(ctx context.Context, eventTypes []string) (map[string]types.EventTypeNID, error) + EventStateKeyNIDs(ctx context.Context, eventStateKeys []string) (map[string]types.EventStateKeyNID, error) + StateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) + StateEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error) + SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error) + StateEntriesForTuples(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error) + StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error) + AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error) + Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error) +} + type StateResolution struct { - db storage.Database + db StateResolutionStorage roomInfo *types.RoomInfo events map[types.EventNID]*gomatrixserverlib.Event } -func NewStateResolution(db storage.Database, roomInfo *types.RoomInfo) StateResolution { +func NewStateResolution(db StateResolutionStorage, roomInfo *types.RoomInfo) StateResolution { return StateResolution{ db: db, roomInfo: roomInfo, diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 15764366b..a9851e05b 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -86,11 +86,10 @@ type Database interface { // Lookup the event IDs for a batch of event numeric IDs. // Returns an error if the retrieval went wrong. EventIDs(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) - // Look up the latest events in a room in preparation for an update. - // The RoomRecentEventsUpdater must have Commit or Rollback called on it if this doesn't return an error. - // Returns the latest events in the room and the last eventID sent to the log along with an updater. + // Opens and returns a room updater, which locks the room and opens a transaction. + // The GetRoomUpdater must have Commit or Rollback called on it if this doesn't return an error. // If this returns an error then no further action is required. - GetLatestEventsForUpdate(ctx context.Context, roomInfo types.RoomInfo) (*shared.LatestEventsUpdater, error) + GetRoomUpdater(ctx context.Context, roomInfo *types.RoomInfo) (*shared.RoomUpdater, error) // Look up event references for the latest events in the room and the current state snapshot. // Returns the latest events, the current state and the maximum depth of the latest events plus 1. // Returns an error if there was a problem talking to the database. diff --git a/roomserver/storage/postgres/event_json_table.go b/roomserver/storage/postgres/event_json_table.go index 32e457821..433e445d8 100644 --- a/roomserver/storage/postgres/event_json_table.go +++ b/roomserver/storage/postgres/event_json_table.go @@ -81,9 +81,10 @@ func (s *eventJSONStatements) InsertEventJSON( } func (s *eventJSONStatements) BulkSelectEventJSON( - ctx context.Context, eventNIDs []types.EventNID, + ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, ) ([]tables.EventJSONPair, error) { - rows, err := s.bulkSelectEventJSONStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) + stmt := sqlutil.TxStmt(txn, s.bulkSelectEventJSONStmt) + rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) if err != nil { return nil, err } diff --git a/roomserver/storage/postgres/event_state_keys_table.go b/roomserver/storage/postgres/event_state_keys_table.go index 3a7cf03e3..762b3a1fc 100644 --- a/roomserver/storage/postgres/event_state_keys_table.go +++ b/roomserver/storage/postgres/event_state_keys_table.go @@ -111,9 +111,10 @@ func (s *eventStateKeyStatements) SelectEventStateKeyNID( } func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID( - ctx context.Context, eventStateKeys []string, + ctx context.Context, txn *sql.Tx, eventStateKeys []string, ) (map[string]types.EventStateKeyNID, error) { - rows, err := s.bulkSelectEventStateKeyNIDStmt.QueryContext( + stmt := sqlutil.TxStmt(txn, s.bulkSelectEventStateKeyNIDStmt) + rows, err := stmt.QueryContext( ctx, pq.StringArray(eventStateKeys), ) if err != nil { @@ -134,13 +135,14 @@ func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID( } func (s *eventStateKeyStatements) BulkSelectEventStateKey( - ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID, + ctx context.Context, txn *sql.Tx, eventStateKeyNIDs []types.EventStateKeyNID, ) (map[types.EventStateKeyNID]string, error) { nIDs := make(pq.Int64Array, len(eventStateKeyNIDs)) for i := range eventStateKeyNIDs { nIDs[i] = int64(eventStateKeyNIDs[i]) } - rows, err := s.bulkSelectEventStateKeyStmt.QueryContext(ctx, nIDs) + stmt := sqlutil.TxStmt(txn, s.bulkSelectEventStateKeyStmt) + rows, err := stmt.QueryContext(ctx, nIDs) if err != nil { return nil, err } diff --git a/roomserver/storage/postgres/event_types_table.go b/roomserver/storage/postgres/event_types_table.go index e558072a5..1d5de5822 100644 --- a/roomserver/storage/postgres/event_types_table.go +++ b/roomserver/storage/postgres/event_types_table.go @@ -133,9 +133,10 @@ func (s *eventTypeStatements) SelectEventTypeNID( } func (s *eventTypeStatements) BulkSelectEventTypeNID( - ctx context.Context, eventTypes []string, + ctx context.Context, txn *sql.Tx, eventTypes []string, ) (map[string]types.EventTypeNID, error) { - rows, err := s.bulkSelectEventTypeNIDStmt.QueryContext(ctx, pq.StringArray(eventTypes)) + stmt := sqlutil.TxStmt(txn, s.bulkSelectEventTypeNIDStmt) + rows, err := stmt.QueryContext(ctx, pq.StringArray(eventTypes)) if err != nil { return nil, err } diff --git a/roomserver/storage/postgres/events_table.go b/roomserver/storage/postgres/events_table.go index 778cd8d73..6c3847752 100644 --- a/roomserver/storage/postgres/events_table.go +++ b/roomserver/storage/postgres/events_table.go @@ -212,9 +212,10 @@ func (s *eventStatements) SelectEvent( // bulkSelectStateEventByID lookups a list of state events by event ID. // If any of the requested events are missing from the database it returns a types.MissingEventError func (s *eventStatements) BulkSelectStateEventByID( - ctx context.Context, eventIDs []string, + ctx context.Context, txn *sql.Tx, eventIDs []string, ) ([]types.StateEntry, error) { - rows, err := s.bulkSelectStateEventByIDStmt.QueryContext(ctx, pq.StringArray(eventIDs)) + stmt := sqlutil.TxStmt(txn, s.bulkSelectStateEventByIDStmt) + rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs)) if err != nil { return nil, err } @@ -254,13 +255,14 @@ func (s *eventStatements) BulkSelectStateEventByID( // bulkSelectStateEventByNID lookups a list of state events by event NID. // If any of the requested events are missing from the database it returns a types.MissingEventError func (s *eventStatements) BulkSelectStateEventByNID( - ctx context.Context, eventNIDs []types.EventNID, + ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, stateKeyTuples []types.StateKeyTuple, ) ([]types.StateEntry, error) { tuples := stateKeyTupleSorter(stateKeyTuples) sort.Sort(tuples) eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays() - rows, err := s.bulkSelectStateEventByNIDStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs), eventTypeNIDArray, eventStateKeyNIDArray) + stmt := sqlutil.TxStmt(txn, s.bulkSelectStateEventByNIDStmt) + rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs), eventTypeNIDArray, eventStateKeyNIDArray) if err != nil { return nil, err } @@ -291,9 +293,10 @@ func (s *eventStatements) BulkSelectStateEventByNID( // If any of the requested events are missing from the database it returns a types.MissingEventError. // If we do not have the state for any of the requested events it returns a types.MissingEventError. func (s *eventStatements) BulkSelectStateAtEventByID( - ctx context.Context, eventIDs []string, + ctx context.Context, txn *sql.Tx, eventIDs []string, ) ([]types.StateAtEvent, error) { - rows, err := s.bulkSelectStateAtEventByIDStmt.QueryContext(ctx, pq.StringArray(eventIDs)) + stmt := sqlutil.TxStmt(txn, s.bulkSelectStateAtEventByIDStmt) + rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs)) if err != nil { return nil, err } @@ -428,8 +431,9 @@ func (s *eventStatements) BulkSelectEventReference( } // bulkSelectEventID returns a map from numeric event ID to string event ID. -func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) { - rows, err := s.bulkSelectEventIDStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) +func (s *eventStatements) BulkSelectEventID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (map[types.EventNID]string, error) { + stmt := sqlutil.TxStmt(txn, s.bulkSelectEventIDStmt) + rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) if err != nil { return nil, err } @@ -455,8 +459,9 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []typ // bulkSelectEventNIDs returns a map from string event ID to numeric event ID. // If an event ID is not in the database then it is omitted from the map. -func (s *eventStatements) BulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) { - rows, err := s.bulkSelectEventNIDStmt.QueryContext(ctx, pq.StringArray(eventIDs)) +func (s *eventStatements) BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) { + stmt := sqlutil.TxStmt(txn, s.bulkSelectEventNIDStmt) + rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs)) if err != nil { return nil, err } @@ -484,9 +489,10 @@ func (s *eventStatements) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, } func (s *eventStatements) SelectRoomNIDsForEventNIDs( - ctx context.Context, eventNIDs []types.EventNID, + ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, ) (map[types.EventNID]types.RoomNID, error) { - rows, err := s.selectRoomNIDsForEventNIDsStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) + stmt := sqlutil.TxStmt(txn, s.selectRoomNIDsForEventNIDsStmt) + rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) if err != nil { return nil, err } diff --git a/roomserver/storage/postgres/invite_table.go b/roomserver/storage/postgres/invite_table.go index 344302c8f..176c16e48 100644 --- a/roomserver/storage/postgres/invite_table.go +++ b/roomserver/storage/postgres/invite_table.go @@ -97,8 +97,8 @@ func prepareInvitesTable(db *sql.DB) (tables.Invites, error) { } func (s *inviteStatements) InsertInviteEvent( - ctx context.Context, - txn *sql.Tx, inviteEventID string, roomNID types.RoomNID, + ctx context.Context, txn *sql.Tx, + inviteEventID string, roomNID types.RoomNID, targetUserNID, senderUserNID types.EventStateKeyNID, inviteEventJSON []byte, ) (bool, error) { @@ -116,8 +116,8 @@ func (s *inviteStatements) InsertInviteEvent( } func (s *inviteStatements) UpdateInviteRetired( - ctx context.Context, - txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, + ctx context.Context, txn *sql.Tx, + roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, ) ([]string, error) { stmt := sqlutil.TxStmt(txn, s.updateInviteRetiredStmt) rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID) @@ -139,10 +139,11 @@ func (s *inviteStatements) UpdateInviteRetired( // SelectInviteActiveForUserInRoom returns a list of sender state key NIDs func (s *inviteStatements) SelectInviteActiveForUserInRoom( - ctx context.Context, + ctx context.Context, txn *sql.Tx, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, ) ([]types.EventStateKeyNID, []string, error) { - rows, err := s.selectInviteActiveForUserInRoomStmt.QueryContext( + stmt := sqlutil.TxStmt(txn, s.selectInviteActiveForUserInRoomStmt) + rows, err := stmt.QueryContext( ctx, targetUserNID, roomNID, ) if err != nil { diff --git a/roomserver/storage/postgres/membership_table.go b/roomserver/storage/postgres/membership_table.go index b0d906c80..48c2c35cd 100644 --- a/roomserver/storage/postgres/membership_table.go +++ b/roomserver/storage/postgres/membership_table.go @@ -186,8 +186,8 @@ func prepareMembershipTable(db *sql.DB) (tables.Membership, error) { } func (s *membershipStatements) InsertMembership( - ctx context.Context, - txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, + ctx context.Context, txn *sql.Tx, + roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, localTarget bool, ) error { stmt := sqlutil.TxStmt(txn, s.insertMembershipStmt) @@ -196,8 +196,8 @@ func (s *membershipStatements) InsertMembership( } func (s *membershipStatements) SelectMembershipForUpdate( - ctx context.Context, - txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, + ctx context.Context, txn *sql.Tx, + roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, ) (membership tables.MembershipState, err error) { err = sqlutil.TxStmt(txn, s.selectMembershipForUpdateStmt).QueryRowContext( ctx, roomNID, targetUserNID, @@ -206,17 +206,19 @@ func (s *membershipStatements) SelectMembershipForUpdate( } func (s *membershipStatements) SelectMembershipFromRoomAndTarget( - ctx context.Context, + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, ) (eventNID types.EventNID, membership tables.MembershipState, forgotten bool, err error) { - err = s.selectMembershipFromRoomAndTargetStmt.QueryRowContext( + stmt := sqlutil.TxStmt(txn, s.selectMembershipFromRoomAndTargetStmt) + err = stmt.QueryRowContext( ctx, roomNID, targetUserNID, ).Scan(&membership, &eventNID, &forgotten) return } func (s *membershipStatements) SelectMembershipsFromRoom( - ctx context.Context, roomNID types.RoomNID, localOnly bool, + ctx context.Context, txn *sql.Tx, + roomNID types.RoomNID, localOnly bool, ) (eventNIDs []types.EventNID, err error) { var stmt *sql.Stmt if localOnly { @@ -224,6 +226,7 @@ func (s *membershipStatements) SelectMembershipsFromRoom( } else { stmt = s.selectMembershipsFromRoomStmt } + stmt = sqlutil.TxStmt(txn, stmt) rows, err := stmt.QueryContext(ctx, roomNID) if err != nil { return @@ -241,7 +244,7 @@ func (s *membershipStatements) SelectMembershipsFromRoom( } func (s *membershipStatements) SelectMembershipsFromRoomAndMembership( - ctx context.Context, + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, membership tables.MembershipState, localOnly bool, ) (eventNIDs []types.EventNID, err error) { var rows *sql.Rows @@ -251,6 +254,7 @@ func (s *membershipStatements) SelectMembershipsFromRoomAndMembership( } else { stmt = s.selectMembershipsFromRoomAndMembershipStmt } + stmt = sqlutil.TxStmt(txn, stmt) rows, err = stmt.QueryContext(ctx, roomNID, membership) if err != nil { return @@ -268,8 +272,8 @@ func (s *membershipStatements) SelectMembershipsFromRoomAndMembership( } func (s *membershipStatements) UpdateMembership( - ctx context.Context, - txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership tables.MembershipState, + ctx context.Context, txn *sql.Tx, + roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership tables.MembershipState, eventNID types.EventNID, forgotten bool, ) error { _, err := sqlutil.TxStmt(txn, s.updateMembershipStmt).ExecContext( @@ -279,9 +283,11 @@ func (s *membershipStatements) UpdateMembership( } func (s *membershipStatements) SelectRoomsWithMembership( - ctx context.Context, userID types.EventStateKeyNID, membershipState tables.MembershipState, + ctx context.Context, txn *sql.Tx, + userID types.EventStateKeyNID, membershipState tables.MembershipState, ) ([]types.RoomNID, error) { - rows, err := s.selectRoomsWithMembershipStmt.QueryContext(ctx, membershipState, userID) + stmt := sqlutil.TxStmt(txn, s.selectRoomsWithMembershipStmt) + rows, err := stmt.QueryContext(ctx, membershipState, userID) if err != nil { return nil, err } @@ -297,12 +303,16 @@ func (s *membershipStatements) SelectRoomsWithMembership( return roomNIDs, nil } -func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) { +func (s *membershipStatements) SelectJoinedUsersSetForRooms( + ctx context.Context, txn *sql.Tx, + roomNIDs []types.RoomNID, +) (map[types.EventStateKeyNID]int, error) { roomIDarray := make([]int64, len(roomNIDs)) for i := range roomNIDs { roomIDarray[i] = int64(roomNIDs[i]) } - rows, err := s.selectJoinedUsersSetForRoomsStmt.QueryContext(ctx, pq.Int64Array(roomIDarray)) + stmt := sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsStmt) + rows, err := stmt.QueryContext(ctx, pq.Int64Array(roomIDarray)) if err != nil { return nil, err } @@ -319,8 +329,12 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, return result, rows.Err() } -func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) { - rows, err := s.selectKnownUsersStmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit) +func (s *membershipStatements) SelectKnownUsers( + ctx context.Context, txn *sql.Tx, + userID types.EventStateKeyNID, searchString string, limit int, +) ([]string, error) { + stmt := sqlutil.TxStmt(txn, s.selectKnownUsersStmt) + rows, err := stmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit) if err != nil { return nil, err } @@ -337,9 +351,8 @@ func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID type } func (s *membershipStatements) UpdateForgetMembership( - ctx context.Context, - txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, - forget bool, + ctx context.Context, txn *sql.Tx, + roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool, ) error { _, err := sqlutil.TxStmt(txn, s.updateMembershipForgetRoomStmt).ExecContext( ctx, roomNID, targetUserNID, forget, @@ -347,9 +360,13 @@ func (s *membershipStatements) UpdateForgetMembership( return err } -func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) { +func (s *membershipStatements) SelectLocalServerInRoom( + ctx context.Context, txn *sql.Tx, + roomNID types.RoomNID, +) (bool, error) { var nid types.RoomNID - err := s.selectLocalServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID).Scan(&nid) + stmt := sqlutil.TxStmt(txn, s.selectLocalServerInRoomStmt) + err := stmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID).Scan(&nid) if err != nil { if err == sql.ErrNoRows { return false, nil @@ -360,9 +377,13 @@ func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, room return found, nil } -func (s *membershipStatements) SelectServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) { +func (s *membershipStatements) SelectServerInRoom( + ctx context.Context, txn *sql.Tx, + roomNID types.RoomNID, serverName gomatrixserverlib.ServerName, +) (bool, error) { var nid types.RoomNID - err := s.selectServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID, serverName).Scan(&nid) + stmt := sqlutil.TxStmt(txn, s.selectServerInRoomStmt) + err := stmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID, serverName).Scan(&nid) if err != nil { if err == sql.ErrNoRows { return false, nil diff --git a/roomserver/storage/postgres/published_table.go b/roomserver/storage/postgres/published_table.go index 8deb68441..15985fcd6 100644 --- a/roomserver/storage/postgres/published_table.go +++ b/roomserver/storage/postgres/published_table.go @@ -73,9 +73,10 @@ func (s *publishedStatements) UpsertRoomPublished( } func (s *publishedStatements) SelectPublishedFromRoomID( - ctx context.Context, roomID string, + ctx context.Context, txn *sql.Tx, roomID string, ) (published bool, err error) { - err = s.selectPublishedStmt.QueryRowContext(ctx, roomID).Scan(&published) + stmt := sqlutil.TxStmt(txn, s.selectPublishedStmt) + err = stmt.QueryRowContext(ctx, roomID).Scan(&published) if err == sql.ErrNoRows { return false, nil } @@ -83,9 +84,10 @@ func (s *publishedStatements) SelectPublishedFromRoomID( } func (s *publishedStatements) SelectAllPublishedRooms( - ctx context.Context, published bool, + ctx context.Context, txn *sql.Tx, published bool, ) ([]string, error) { - rows, err := s.selectAllPublishedStmt.QueryContext(ctx, published) + stmt := sqlutil.TxStmt(txn, s.selectAllPublishedStmt) + rows, err := stmt.QueryContext(ctx, published) if err != nil { return nil, err } diff --git a/roomserver/storage/postgres/room_aliases_table.go b/roomserver/storage/postgres/room_aliases_table.go index 031825fee..d13df8e7f 100644 --- a/roomserver/storage/postgres/room_aliases_table.go +++ b/roomserver/storage/postgres/room_aliases_table.go @@ -87,9 +87,10 @@ func (s *roomAliasesStatements) InsertRoomAlias( } func (s *roomAliasesStatements) SelectRoomIDFromAlias( - ctx context.Context, alias string, + ctx context.Context, txn *sql.Tx, alias string, ) (roomID string, err error) { - err = s.selectRoomIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&roomID) + stmt := sqlutil.TxStmt(txn, s.selectRoomIDFromAliasStmt) + err = stmt.QueryRowContext(ctx, alias).Scan(&roomID) if err == sql.ErrNoRows { return "", nil } @@ -97,9 +98,10 @@ func (s *roomAliasesStatements) SelectRoomIDFromAlias( } func (s *roomAliasesStatements) SelectAliasesFromRoomID( - ctx context.Context, roomID string, + ctx context.Context, txn *sql.Tx, roomID string, ) ([]string, error) { - rows, err := s.selectAliasesFromRoomIDStmt.QueryContext(ctx, roomID) + stmt := sqlutil.TxStmt(txn, s.selectAliasesFromRoomIDStmt) + rows, err := stmt.QueryContext(ctx, roomID) if err != nil { return nil, err } @@ -118,9 +120,10 @@ func (s *roomAliasesStatements) SelectAliasesFromRoomID( } func (s *roomAliasesStatements) SelectCreatorIDFromAlias( - ctx context.Context, alias string, + ctx context.Context, txn *sql.Tx, alias string, ) (creatorID string, err error) { - err = s.selectCreatorIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&creatorID) + stmt := sqlutil.TxStmt(txn, s.selectCreatorIDFromAliasStmt) + err = stmt.QueryRowContext(ctx, alias).Scan(&creatorID) if err == sql.ErrNoRows { return "", nil } diff --git a/roomserver/storage/postgres/rooms_table.go b/roomserver/storage/postgres/rooms_table.go index f51eba4d4..b2685084d 100644 --- a/roomserver/storage/postgres/rooms_table.go +++ b/roomserver/storage/postgres/rooms_table.go @@ -117,8 +117,9 @@ func prepareRoomsTable(db *sql.DB) (tables.Rooms, error) { }.Prepare(db) } -func (s *roomStatements) SelectRoomIDs(ctx context.Context) ([]string, error) { - rows, err := s.selectRoomIDsStmt.QueryContext(ctx) +func (s *roomStatements) SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]string, error) { + stmt := sqlutil.TxStmt(txn, s.selectRoomIDsStmt) + rows, err := stmt.QueryContext(ctx) if err != nil { return nil, err } @@ -143,10 +144,11 @@ func (s *roomStatements) InsertRoomNID( return types.RoomNID(roomNID), err } -func (s *roomStatements) SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) { +func (s *roomStatements) SelectRoomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error) { var info types.RoomInfo var latestNIDs pq.Int64Array - err := s.selectRoomInfoStmt.QueryRowContext(ctx, roomID).Scan( + stmt := sqlutil.TxStmt(txn, s.selectRoomInfoStmt) + err := stmt.QueryRowContext(ctx, roomID).Scan( &info.RoomVersion, &info.RoomNID, &info.StateSnapshotNID, &latestNIDs, ) if err == sql.ErrNoRows { @@ -170,7 +172,7 @@ func (s *roomStatements) SelectLatestEventNIDs( ) ([]types.EventNID, types.StateSnapshotNID, error) { var nids pq.Int64Array var stateSnapshotNID int64 - stmt := s.selectLatestEventNIDsStmt + stmt := sqlutil.TxStmt(txn, s.selectLatestEventNIDsStmt) err := stmt.QueryRowContext(ctx, int64(roomNID)).Scan(&nids, &stateSnapshotNID) if err != nil { return nil, 0, err @@ -220,9 +222,10 @@ func (s *roomStatements) UpdateLatestEventNIDs( } func (s *roomStatements) SelectRoomVersionsForRoomNIDs( - ctx context.Context, roomNIDs []types.RoomNID, + ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, ) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error) { - rows, err := s.selectRoomVersionsForRoomNIDsStmt.QueryContext(ctx, roomNIDsAsArray(roomNIDs)) + stmt := sqlutil.TxStmt(txn, s.selectRoomVersionsForRoomNIDsStmt) + rows, err := stmt.QueryContext(ctx, roomNIDsAsArray(roomNIDs)) if err != nil { return nil, err } @@ -239,12 +242,13 @@ func (s *roomStatements) SelectRoomVersionsForRoomNIDs( return result, nil } -func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) { +func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) ([]string, error) { var array pq.Int64Array for _, nid := range roomNIDs { array = append(array, int64(nid)) } - rows, err := s.bulkSelectRoomIDsStmt.QueryContext(ctx, array) + stmt := sqlutil.TxStmt(txn, s.bulkSelectRoomIDsStmt) + rows, err := stmt.QueryContext(ctx, array) if err != nil { return nil, err } @@ -260,12 +264,13 @@ func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types return roomIDs, nil } -func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, roomIDs []string) ([]types.RoomNID, error) { +func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, roomIDs []string) ([]types.RoomNID, error) { var array pq.StringArray for _, roomID := range roomIDs { array = append(array, roomID) } - rows, err := s.bulkSelectRoomNIDsStmt.QueryContext(ctx, array) + stmt := sqlutil.TxStmt(txn, s.bulkSelectRoomNIDsStmt) + rows, err := stmt.QueryContext(ctx, array) if err != nil { return nil, err } diff --git a/roomserver/storage/postgres/state_block_table.go b/roomserver/storage/postgres/state_block_table.go index 27d85e83b..6f8f9e1b5 100644 --- a/roomserver/storage/postgres/state_block_table.go +++ b/roomserver/storage/postgres/state_block_table.go @@ -86,8 +86,7 @@ func prepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) { } func (s *stateBlockStatements) BulkInsertStateData( - ctx context.Context, - txn *sql.Tx, + ctx context.Context, txn *sql.Tx, entries types.StateEntries, ) (id types.StateBlockNID, err error) { entries = entries[:util.SortAndUnique(entries)] @@ -95,16 +94,18 @@ func (s *stateBlockStatements) BulkInsertStateData( for _, e := range entries { nids = append(nids, e.EventNID) } - err = s.insertStateDataStmt.QueryRowContext( + stmt := sqlutil.TxStmt(txn, s.insertStateDataStmt) + err = stmt.QueryRowContext( ctx, nids.Hash(), eventNIDsAsArray(nids), ).Scan(&id) return } func (s *stateBlockStatements) BulkSelectStateBlockEntries( - ctx context.Context, stateBlockNIDs types.StateBlockNIDs, + ctx context.Context, txn *sql.Tx, stateBlockNIDs types.StateBlockNIDs, ) ([][]types.EventNID, error) { - rows, err := s.bulkSelectStateBlockEntriesStmt.QueryContext(ctx, stateBlockNIDsAsArray(stateBlockNIDs)) + stmt := sqlutil.TxStmt(txn, s.bulkSelectStateBlockEntriesStmt) + rows, err := stmt.QueryContext(ctx, stateBlockNIDsAsArray(stateBlockNIDs)) if err != nil { return nil, err } diff --git a/roomserver/storage/postgres/state_snapshot_table.go b/roomserver/storage/postgres/state_snapshot_table.go index 4fc0fa48a..ce9f24636 100644 --- a/roomserver/storage/postgres/state_snapshot_table.go +++ b/roomserver/storage/postgres/state_snapshot_table.go @@ -105,13 +105,14 @@ func (s *stateSnapshotStatements) InsertState( } func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs( - ctx context.Context, stateNIDs []types.StateSnapshotNID, + ctx context.Context, txn *sql.Tx, stateNIDs []types.StateSnapshotNID, ) ([]types.StateBlockNIDList, error) { nids := make([]int64, len(stateNIDs)) for i := range stateNIDs { nids[i] = int64(stateNIDs[i]) } - rows, err := s.bulkSelectStateBlockNIDsStmt.QueryContext(ctx, pq.Int64Array(nids)) + stmt := sqlutil.TxStmt(txn, s.bulkSelectStateBlockNIDsStmt) + rows, err := stmt.QueryContext(ctx, pq.Int64Array(nids)) if err != nil { return nil, err } diff --git a/roomserver/storage/shared/latest_events_updater.go b/roomserver/storage/shared/latest_events_updater.go deleted file mode 100644 index 36865081a..000000000 --- a/roomserver/storage/shared/latest_events_updater.go +++ /dev/null @@ -1,133 +0,0 @@ -package shared - -import ( - "context" - "database/sql" - "fmt" - - "github.com/matrix-org/dendrite/roomserver/types" - "github.com/matrix-org/gomatrixserverlib" -) - -type LatestEventsUpdater struct { - transaction - d *Database - roomInfo types.RoomInfo - latestEvents []types.StateAtEventAndReference - lastEventIDSent string - currentStateSnapshotNID types.StateSnapshotNID -} - -func rollback(txn *sql.Tx) { - if txn == nil { - return - } - txn.Rollback() // nolint: errcheck -} - -func NewLatestEventsUpdater(ctx context.Context, d *Database, txn *sql.Tx, roomInfo types.RoomInfo) (*LatestEventsUpdater, error) { - eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err := - d.RoomsTable.SelectLatestEventsNIDsForUpdate(ctx, txn, roomInfo.RoomNID) - if err != nil { - rollback(txn) - return nil, err - } - stateAndRefs, err := d.EventsTable.BulkSelectStateAtEventAndReference(ctx, txn, eventNIDs) - if err != nil { - rollback(txn) - return nil, err - } - var lastEventIDSent string - if lastEventNIDSent != 0 { - lastEventIDSent, err = d.EventsTable.SelectEventID(ctx, txn, lastEventNIDSent) - if err != nil { - rollback(txn) - return nil, err - } - } - return &LatestEventsUpdater{ - transaction{ctx, txn}, d, roomInfo, stateAndRefs, lastEventIDSent, currentStateSnapshotNID, - }, nil -} - -// RoomVersion implements types.RoomRecentEventsUpdater -func (u *LatestEventsUpdater) RoomVersion() (version gomatrixserverlib.RoomVersion) { - return u.roomInfo.RoomVersion -} - -// LatestEvents implements types.RoomRecentEventsUpdater -func (u *LatestEventsUpdater) LatestEvents() []types.StateAtEventAndReference { - return u.latestEvents -} - -// LastEventIDSent implements types.RoomRecentEventsUpdater -func (u *LatestEventsUpdater) LastEventIDSent() string { - return u.lastEventIDSent -} - -// CurrentStateSnapshotNID implements types.RoomRecentEventsUpdater -func (u *LatestEventsUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID { - return u.currentStateSnapshotNID -} - -// StorePreviousEvents implements types.RoomRecentEventsUpdater - This must be called from a Writer -func (u *LatestEventsUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error { - for _, ref := range previousEventReferences { - if err := u.d.PrevEventsTable.InsertPreviousEvent(u.ctx, u.txn, ref.EventID, ref.EventSHA256, eventNID); err != nil { - return fmt.Errorf("u.d.PrevEventsTable.InsertPreviousEvent: %w", err) - } - } - return nil -} - -// IsReferenced implements types.RoomRecentEventsUpdater -func (u *LatestEventsUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) { - err := u.d.PrevEventsTable.SelectPreviousEventExists(u.ctx, u.txn, eventReference.EventID, eventReference.EventSHA256) - if err == nil { - return true, nil - } - if err == sql.ErrNoRows { - return false, nil - } - return false, fmt.Errorf("u.d.PrevEventsTable.SelectPreviousEventExists: %w", err) -} - -// SetLatestEvents implements types.RoomRecentEventsUpdater -func (u *LatestEventsUpdater) SetLatestEvents( - roomNID types.RoomNID, latest []types.StateAtEventAndReference, lastEventNIDSent types.EventNID, - currentStateSnapshotNID types.StateSnapshotNID, -) error { - eventNIDs := make([]types.EventNID, len(latest)) - for i := range latest { - eventNIDs[i] = latest[i].EventNID - } - return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { - if err := u.d.RoomsTable.UpdateLatestEventNIDs(u.ctx, txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID); err != nil { - return fmt.Errorf("u.d.RoomsTable.updateLatestEventNIDs: %w", err) - } - if roomID, ok := u.d.Cache.GetRoomServerRoomID(roomNID); ok { - if roomInfo, ok := u.d.Cache.GetRoomInfo(roomID); ok { - roomInfo.StateSnapshotNID = currentStateSnapshotNID - roomInfo.IsStub = false - u.d.Cache.StoreRoomInfo(roomID, roomInfo) - } - } - return nil - }) -} - -// HasEventBeenSent implements types.RoomRecentEventsUpdater -func (u *LatestEventsUpdater) HasEventBeenSent(eventNID types.EventNID) (bool, error) { - return u.d.EventsTable.SelectEventSentToOutput(u.ctx, u.txn, eventNID) -} - -// MarkEventAsSent implements types.RoomRecentEventsUpdater -func (u *LatestEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error { - return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { - return u.d.EventsTable.UpdateEventSentToOutput(u.ctx, txn, eventNID) - }) -} - -func (u *LatestEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (*MembershipUpdater, error) { - return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomInfo.RoomNID, targetUserNID, targetLocal) -} diff --git a/roomserver/storage/shared/room_updater.go b/roomserver/storage/shared/room_updater.go new file mode 100644 index 000000000..bb9f5dc62 --- /dev/null +++ b/roomserver/storage/shared/room_updater.go @@ -0,0 +1,262 @@ +package shared + +import ( + "context" + "database/sql" + "fmt" + + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib" +) + +type RoomUpdater struct { + transaction + d *Database + roomInfo *types.RoomInfo + latestEvents []types.StateAtEventAndReference + lastEventIDSent string + currentStateSnapshotNID types.StateSnapshotNID +} + +func rollback(txn *sql.Tx) { + if txn == nil { + return + } + txn.Rollback() // nolint: errcheck +} + +func NewRoomUpdater(ctx context.Context, d *Database, txn *sql.Tx, roomInfo *types.RoomInfo) (*RoomUpdater, error) { + // If the roomInfo is nil then that means that the room doesn't exist + // yet, so we can't do `SelectLatestEventsNIDsForUpdate` because that + // would involve locking a row on the table that doesn't exist. Instead + // we will just run with a normal database transaction. It'll either + // succeed, processing a create event which creates the room, or it won't. + if roomInfo == nil { + return &RoomUpdater{ + transaction{ctx, txn}, d, nil, nil, "", 0, + }, nil + } + + eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err := + d.RoomsTable.SelectLatestEventsNIDsForUpdate(ctx, txn, roomInfo.RoomNID) + if err != nil { + rollback(txn) + return nil, err + } + stateAndRefs, err := d.EventsTable.BulkSelectStateAtEventAndReference(ctx, txn, eventNIDs) + if err != nil { + rollback(txn) + return nil, err + } + var lastEventIDSent string + if lastEventNIDSent != 0 { + lastEventIDSent, err = d.EventsTable.SelectEventID(ctx, txn, lastEventNIDSent) + if err != nil { + rollback(txn) + return nil, err + } + } + return &RoomUpdater{ + transaction{ctx, txn}, d, roomInfo, stateAndRefs, lastEventIDSent, currentStateSnapshotNID, + }, nil +} + +// Implements sqlutil.Transaction +func (u *RoomUpdater) Commit() error { + if u.txn == nil { // SQLite mode probably + return nil + } + return u.txn.Commit() +} + +// Implements sqlutil.Transaction +func (u *RoomUpdater) Rollback() error { + if u.txn == nil { // SQLite mode probably + return nil + } + return u.txn.Rollback() +} + +// RoomVersion implements types.RoomRecentEventsUpdater +func (u *RoomUpdater) RoomVersion() (version gomatrixserverlib.RoomVersion) { + return u.roomInfo.RoomVersion +} + +// LatestEvents implements types.RoomRecentEventsUpdater +func (u *RoomUpdater) LatestEvents() []types.StateAtEventAndReference { + return u.latestEvents +} + +// LastEventIDSent implements types.RoomRecentEventsUpdater +func (u *RoomUpdater) LastEventIDSent() string { + return u.lastEventIDSent +} + +// CurrentStateSnapshotNID implements types.RoomRecentEventsUpdater +func (u *RoomUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID { + return u.currentStateSnapshotNID +} + +// StorePreviousEvents implements types.RoomRecentEventsUpdater - This must be called from a Writer +func (u *RoomUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error { + return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { + for _, ref := range previousEventReferences { + if err := u.d.PrevEventsTable.InsertPreviousEvent(u.ctx, txn, ref.EventID, ref.EventSHA256, eventNID); err != nil { + return fmt.Errorf("u.d.PrevEventsTable.InsertPreviousEvent: %w", err) + } + } + return nil + }) +} + +func (u *RoomUpdater) Events( + ctx context.Context, eventNIDs []types.EventNID, +) ([]types.Event, error) { + return u.d.events(ctx, u.txn, eventNIDs) +} + +func (u *RoomUpdater) SnapshotNIDFromEventID( + ctx context.Context, eventID string, +) (types.StateSnapshotNID, error) { + return u.d.snapshotNIDFromEventID(ctx, u.txn, eventID) +} + +func (u *RoomUpdater) StoreEvent( + ctx context.Context, event *gomatrixserverlib.Event, + authEventNIDs []types.EventNID, isRejected bool, +) (types.EventNID, types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) { + return u.d.storeEvent(ctx, u, event, authEventNIDs, isRejected) +} + +func (u *RoomUpdater) StateBlockNIDs( + ctx context.Context, stateNIDs []types.StateSnapshotNID, +) ([]types.StateBlockNIDList, error) { + return u.d.stateBlockNIDs(ctx, u.txn, stateNIDs) +} + +func (u *RoomUpdater) StateEntries( + ctx context.Context, stateBlockNIDs []types.StateBlockNID, +) ([]types.StateEntryList, error) { + return u.d.stateEntries(ctx, u.txn, stateBlockNIDs) +} + +func (u *RoomUpdater) StateEntriesForTuples( + ctx context.Context, + stateBlockNIDs []types.StateBlockNID, + stateKeyTuples []types.StateKeyTuple, +) ([]types.StateEntryList, error) { + return u.d.stateEntriesForTuples(ctx, u.txn, stateBlockNIDs, stateKeyTuples) +} + +func (u *RoomUpdater) AddState( + ctx context.Context, + roomNID types.RoomNID, + stateBlockNIDs []types.StateBlockNID, + state []types.StateEntry, +) (stateNID types.StateSnapshotNID, err error) { + return u.d.addState(ctx, u.txn, roomNID, stateBlockNIDs, state) +} + +func (u *RoomUpdater) SetState( + ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID, +) error { + return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { + return u.d.EventsTable.UpdateEventState(ctx, txn, eventNID, stateNID) + }) +} + +func (u *RoomUpdater) EventTypeNIDs( + ctx context.Context, eventTypes []string, +) (map[string]types.EventTypeNID, error) { + return u.d.eventTypeNIDs(ctx, u.txn, eventTypes) +} + +func (u *RoomUpdater) EventStateKeyNIDs( + ctx context.Context, eventStateKeys []string, +) (map[string]types.EventStateKeyNID, error) { + return u.d.eventStateKeyNIDs(ctx, u.txn, eventStateKeys) +} + +func (u *RoomUpdater) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) { + return u.d.roomInfo(ctx, u.txn, roomID) +} + +func (u *RoomUpdater) EventIDs( + ctx context.Context, eventNIDs []types.EventNID, +) (map[types.EventNID]string, error) { + return u.d.EventsTable.BulkSelectEventID(ctx, u.txn, eventNIDs) +} + +func (u *RoomUpdater) StateAtEventIDs( + ctx context.Context, eventIDs []string, +) ([]types.StateAtEvent, error) { + return u.d.EventsTable.BulkSelectStateAtEventByID(ctx, u.txn, eventIDs) +} + +func (u *RoomUpdater) StateEntriesForEventIDs( + ctx context.Context, eventIDs []string, +) ([]types.StateEntry, error) { + return u.d.EventsTable.BulkSelectStateEventByID(ctx, u.txn, eventIDs) +} + +func (u *RoomUpdater) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) { + return u.d.eventsFromIDs(ctx, u.txn, eventIDs) +} + +func (u *RoomUpdater) GetMembershipEventNIDsForRoom( + ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool, +) ([]types.EventNID, error) { + return u.d.getMembershipEventNIDsForRoom(ctx, u.txn, roomNID, joinOnly, localOnly) +} + +// IsReferenced implements types.RoomRecentEventsUpdater +func (u *RoomUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) { + err := u.d.PrevEventsTable.SelectPreviousEventExists(u.ctx, u.txn, eventReference.EventID, eventReference.EventSHA256) + if err == nil { + return true, nil + } + if err == sql.ErrNoRows { + return false, nil + } + return false, fmt.Errorf("u.d.PrevEventsTable.SelectPreviousEventExists: %w", err) +} + +// SetLatestEvents implements types.RoomRecentEventsUpdater +func (u *RoomUpdater) SetLatestEvents( + roomNID types.RoomNID, latest []types.StateAtEventAndReference, lastEventNIDSent types.EventNID, + currentStateSnapshotNID types.StateSnapshotNID, +) error { + eventNIDs := make([]types.EventNID, len(latest)) + for i := range latest { + eventNIDs[i] = latest[i].EventNID + } + return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { + if err := u.d.RoomsTable.UpdateLatestEventNIDs(u.ctx, txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID); err != nil { + return fmt.Errorf("u.d.RoomsTable.updateLatestEventNIDs: %w", err) + } + if roomID, ok := u.d.Cache.GetRoomServerRoomID(roomNID); ok { + if roomInfo, ok := u.d.Cache.GetRoomInfo(roomID); ok { + roomInfo.StateSnapshotNID = currentStateSnapshotNID + roomInfo.IsStub = false + u.d.Cache.StoreRoomInfo(roomID, roomInfo) + } + } + return nil + }) +} + +// HasEventBeenSent implements types.RoomRecentEventsUpdater +func (u *RoomUpdater) HasEventBeenSent(eventNID types.EventNID) (bool, error) { + return u.d.EventsTable.SelectEventSentToOutput(u.ctx, u.txn, eventNID) +} + +// MarkEventAsSent implements types.RoomRecentEventsUpdater +func (u *RoomUpdater) MarkEventAsSent(eventNID types.EventNID) error { + return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { + return u.d.EventsTable.UpdateEventSentToOutput(u.ctx, txn, eventNID) + }) +} + +func (u *RoomUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (*MembershipUpdater, error) { + return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomInfo.RoomNID, targetUserNID, targetLocal) +} diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index d4c5ebb5b..127cd1f52 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -26,23 +26,23 @@ import ( const redactionsArePermanent = true type Database struct { - DB *sql.DB - Cache caching.RoomServerCaches - Writer sqlutil.Writer - EventsTable tables.Events - EventJSONTable tables.EventJSON - EventTypesTable tables.EventTypes - EventStateKeysTable tables.EventStateKeys - RoomsTable tables.Rooms - StateSnapshotTable tables.StateSnapshot - StateBlockTable tables.StateBlock - RoomAliasesTable tables.RoomAliases - PrevEventsTable tables.PreviousEvents - InvitesTable tables.Invites - MembershipTable tables.Membership - PublishedTable tables.Published - RedactionsTable tables.Redactions - GetLatestEventsForUpdateFn func(ctx context.Context, roomInfo types.RoomInfo) (*LatestEventsUpdater, error) + DB *sql.DB + Cache caching.RoomServerCaches + Writer sqlutil.Writer + EventsTable tables.Events + EventJSONTable tables.EventJSON + EventTypesTable tables.EventTypes + EventStateKeysTable tables.EventStateKeys + RoomsTable tables.Rooms + StateSnapshotTable tables.StateSnapshot + StateBlockTable tables.StateBlock + RoomAliasesTable tables.RoomAliases + PrevEventsTable tables.PreviousEvents + InvitesTable tables.Invites + MembershipTable tables.Membership + PublishedTable tables.Published + RedactionsTable tables.Redactions + GetRoomUpdaterFn func(ctx context.Context, roomInfo *types.RoomInfo) (*RoomUpdater, error) } func (d *Database) SupportsConcurrentRoomInputs() bool { @@ -51,6 +51,12 @@ func (d *Database) SupportsConcurrentRoomInputs() bool { func (d *Database) EventTypeNIDs( ctx context.Context, eventTypes []string, +) (map[string]types.EventTypeNID, error) { + return d.eventTypeNIDs(ctx, nil, eventTypes) +} + +func (d *Database) eventTypeNIDs( + ctx context.Context, txn *sql.Tx, eventTypes []string, ) (map[string]types.EventTypeNID, error) { result := make(map[string]types.EventTypeNID) remaining := []string{} @@ -62,7 +68,7 @@ func (d *Database) EventTypeNIDs( } } if len(remaining) > 0 { - nids, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, remaining) + nids, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, txn, remaining) if err != nil { return nil, err } @@ -77,11 +83,17 @@ func (d *Database) EventTypeNIDs( func (d *Database) EventStateKeys( ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID, ) (map[types.EventStateKeyNID]string, error) { - return d.EventStateKeysTable.BulkSelectEventStateKey(ctx, eventStateKeyNIDs) + return d.EventStateKeysTable.BulkSelectEventStateKey(ctx, nil, eventStateKeyNIDs) } func (d *Database) EventStateKeyNIDs( ctx context.Context, eventStateKeys []string, +) (map[string]types.EventStateKeyNID, error) { + return d.eventStateKeyNIDs(ctx, nil, eventStateKeys) +} + +func (d *Database) eventStateKeyNIDs( + ctx context.Context, txn *sql.Tx, eventStateKeys []string, ) (map[string]types.EventStateKeyNID, error) { result := make(map[string]types.EventStateKeyNID) remaining := []string{} @@ -93,7 +105,7 @@ func (d *Database) EventStateKeyNIDs( } } if len(remaining) > 0 { - nids, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, remaining) + nids, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, txn, remaining) if err != nil { return nil, err } @@ -108,23 +120,31 @@ func (d *Database) EventStateKeyNIDs( func (d *Database) StateEntriesForEventIDs( ctx context.Context, eventIDs []string, ) ([]types.StateEntry, error) { - return d.EventsTable.BulkSelectStateEventByID(ctx, eventIDs) + return d.EventsTable.BulkSelectStateEventByID(ctx, nil, eventIDs) } func (d *Database) StateEntriesForTuples( ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple, +) ([]types.StateEntryList, error) { + return d.stateEntriesForTuples(ctx, nil, stateBlockNIDs, stateKeyTuples) +} + +func (d *Database) stateEntriesForTuples( + ctx context.Context, txn *sql.Tx, + stateBlockNIDs []types.StateBlockNID, + stateKeyTuples []types.StateKeyTuple, ) ([]types.StateEntryList, error) { entries, err := d.StateBlockTable.BulkSelectStateBlockEntries( - ctx, stateBlockNIDs, + ctx, txn, stateBlockNIDs, ) if err != nil { return nil, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", err) } lists := []types.StateEntryList{} for i, entry := range entries { - entries, err := d.EventsTable.BulkSelectStateEventByNID(ctx, entry, stateKeyTuples) + entries, err := d.EventsTable.BulkSelectStateEventByNID(ctx, txn, entry, stateKeyTuples) if err != nil { return nil, fmt.Errorf("d.EventsTable.BulkSelectStateEventByNID: %w", err) } @@ -137,10 +157,14 @@ func (d *Database) StateEntriesForTuples( } func (d *Database) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) { + return d.roomInfo(ctx, nil, roomID) +} + +func (d *Database) roomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error) { if roomInfo, ok := d.Cache.GetRoomInfo(roomID); ok { return &roomInfo, nil } - roomInfo, err := d.RoomsTable.SelectRoomInfo(ctx, roomID) + roomInfo, err := d.RoomsTable.SelectRoomInfo(ctx, txn, roomID) if err == nil && roomInfo != nil { d.Cache.StoreRoomServerRoomID(roomInfo.RoomNID, roomID) d.Cache.StoreRoomInfo(roomID, *roomInfo) @@ -153,13 +177,22 @@ func (d *Database) AddState( roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry, +) (stateNID types.StateSnapshotNID, err error) { + return d.addState(ctx, nil, roomNID, stateBlockNIDs, state) +} + +func (d *Database) addState( + ctx context.Context, txn *sql.Tx, + roomNID types.RoomNID, + stateBlockNIDs []types.StateBlockNID, + state []types.StateEntry, ) (stateNID types.StateSnapshotNID, err error) { if len(stateBlockNIDs) > 0 && len(state) > 0 { // Check to see if the event already appears in any of the existing state // blocks. If it does then we should not add it again, as this will just // result in excess state blocks and snapshots. // TODO: Investigate why this is happening - probably input_events.go! - blocks, berr := d.StateBlockTable.BulkSelectStateBlockEntries(ctx, stateBlockNIDs) + blocks, berr := d.StateBlockTable.BulkSelectStateBlockEntries(ctx, txn, stateBlockNIDs) if berr != nil { return 0, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", berr) } @@ -180,7 +213,7 @@ func (d *Database) AddState( } } } - err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + err = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error { if len(state) > 0 { // If there's any state left to add then let's add new blocks. var stateBlockNID types.StateBlockNID @@ -205,7 +238,13 @@ func (d *Database) AddState( func (d *Database) EventNIDs( ctx context.Context, eventIDs []string, ) (map[string]types.EventNID, error) { - return d.EventsTable.BulkSelectEventNID(ctx, eventIDs) + return d.eventNIDs(ctx, nil, eventIDs) +} + +func (d *Database) eventNIDs( + ctx context.Context, txn *sql.Tx, eventIDs []string, +) (map[string]types.EventNID, error) { + return d.EventsTable.BulkSelectEventNID(ctx, txn, eventIDs) } func (d *Database) SetState( @@ -219,24 +258,34 @@ func (d *Database) SetState( func (d *Database) StateAtEventIDs( ctx context.Context, eventIDs []string, ) ([]types.StateAtEvent, error) { - return d.EventsTable.BulkSelectStateAtEventByID(ctx, eventIDs) + return d.EventsTable.BulkSelectStateAtEventByID(ctx, nil, eventIDs) } func (d *Database) SnapshotNIDFromEventID( ctx context.Context, eventID string, ) (types.StateSnapshotNID, error) { - _, stateNID, err := d.EventsTable.SelectEvent(ctx, nil, eventID) + return d.snapshotNIDFromEventID(ctx, nil, eventID) +} + +func (d *Database) snapshotNIDFromEventID( + ctx context.Context, txn *sql.Tx, eventID string, +) (types.StateSnapshotNID, error) { + _, stateNID, err := d.EventsTable.SelectEvent(ctx, txn, eventID) return stateNID, err } func (d *Database) EventIDs( ctx context.Context, eventNIDs []types.EventNID, ) (map[types.EventNID]string, error) { - return d.EventsTable.BulkSelectEventID(ctx, eventNIDs) + return d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs) } func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) { - nidMap, err := d.EventNIDs(ctx, eventIDs) + return d.eventsFromIDs(ctx, nil, eventIDs) +} + +func (d *Database) eventsFromIDs(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.Event, error) { + nidMap, err := d.eventNIDs(ctx, txn, eventIDs) if err != nil { return nil, err } @@ -246,7 +295,7 @@ func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]type nids = append(nids, nid) } - return d.Events(ctx, nids) + return d.events(ctx, txn, nids) } func (d *Database) LatestEventIDs( @@ -271,21 +320,33 @@ func (d *Database) LatestEventIDs( func (d *Database) StateBlockNIDs( ctx context.Context, stateNIDs []types.StateSnapshotNID, ) ([]types.StateBlockNIDList, error) { - return d.StateSnapshotTable.BulkSelectStateBlockNIDs(ctx, stateNIDs) + return d.stateBlockNIDs(ctx, nil, stateNIDs) +} + +func (d *Database) stateBlockNIDs( + ctx context.Context, txn *sql.Tx, stateNIDs []types.StateSnapshotNID, +) ([]types.StateBlockNIDList, error) { + return d.StateSnapshotTable.BulkSelectStateBlockNIDs(ctx, txn, stateNIDs) } func (d *Database) StateEntries( ctx context.Context, stateBlockNIDs []types.StateBlockNID, +) ([]types.StateEntryList, error) { + return d.stateEntries(ctx, nil, stateBlockNIDs) +} + +func (d *Database) stateEntries( + ctx context.Context, txn *sql.Tx, stateBlockNIDs []types.StateBlockNID, ) ([]types.StateEntryList, error) { entries, err := d.StateBlockTable.BulkSelectStateBlockEntries( - ctx, stateBlockNIDs, + ctx, txn, stateBlockNIDs, ) if err != nil { return nil, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", err) } lists := make([]types.StateEntryList, 0, len(entries)) for i, entry := range entries { - eventNIDs, err := d.EventsTable.BulkSelectStateEventByNID(ctx, entry, nil) + eventNIDs, err := d.EventsTable.BulkSelectStateEventByNID(ctx, txn, entry, nil) if err != nil { return nil, fmt.Errorf("d.EventsTable.BulkSelectStateEventByNID: %w", err) } @@ -304,17 +365,17 @@ func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string } func (d *Database) GetRoomIDForAlias(ctx context.Context, alias string) (string, error) { - return d.RoomAliasesTable.SelectRoomIDFromAlias(ctx, alias) + return d.RoomAliasesTable.SelectRoomIDFromAlias(ctx, nil, alias) } func (d *Database) GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error) { - return d.RoomAliasesTable.SelectAliasesFromRoomID(ctx, roomID) + return d.RoomAliasesTable.SelectAliasesFromRoomID(ctx, nil, roomID) } func (d *Database) GetCreatorIDForAlias( ctx context.Context, alias string, ) (string, error) { - return d.RoomAliasesTable.SelectCreatorIDFromAlias(ctx, alias) + return d.RoomAliasesTable.SelectCreatorIDFromAlias(ctx, nil, alias) } func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error { @@ -335,7 +396,7 @@ func (d *Database) GetMembership(ctx context.Context, roomNID types.RoomNID, req senderMembershipEventNID, senderMembership, isRoomforgotten, err := d.MembershipTable.SelectMembershipFromRoomAndTarget( - ctx, roomNID, requestSenderUserNID, + ctx, nil, roomNID, requestSenderUserNID, ) if err == sql.ErrNoRows { // The user has never been a member of that room @@ -349,14 +410,20 @@ func (d *Database) GetMembership(ctx context.Context, roomNID types.RoomNID, req func (d *Database) GetMembershipEventNIDsForRoom( ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool, +) ([]types.EventNID, error) { + return d.getMembershipEventNIDsForRoom(ctx, nil, roomNID, joinOnly, localOnly) +} + +func (d *Database) getMembershipEventNIDsForRoom( + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, joinOnly bool, localOnly bool, ) ([]types.EventNID, error) { if joinOnly { return d.MembershipTable.SelectMembershipsFromRoomAndMembership( - ctx, roomNID, tables.MembershipStateJoin, localOnly, + ctx, txn, roomNID, tables.MembershipStateJoin, localOnly, ) } - return d.MembershipTable.SelectMembershipsFromRoom(ctx, roomNID, localOnly) + return d.MembershipTable.SelectMembershipsFromRoom(ctx, txn, roomNID, localOnly) } func (d *Database) GetInvitesForUser( @@ -364,22 +431,28 @@ func (d *Database) GetInvitesForUser( roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, ) (senderUserIDs []types.EventStateKeyNID, eventIDs []string, err error) { - return d.InvitesTable.SelectInviteActiveForUserInRoom(ctx, targetUserNID, roomNID) + return d.InvitesTable.SelectInviteActiveForUserInRoom(ctx, nil, targetUserNID, roomNID) } func (d *Database) Events( ctx context.Context, eventNIDs []types.EventNID, ) ([]types.Event, error) { - eventJSONs, err := d.EventJSONTable.BulkSelectEventJSON(ctx, eventNIDs) + return d.events(ctx, nil, eventNIDs) +} + +func (d *Database) events( + ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, +) ([]types.Event, error) { + eventJSONs, err := d.EventJSONTable.BulkSelectEventJSON(ctx, txn, eventNIDs) if err != nil { return nil, err } - eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, eventNIDs) + eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, txn, eventNIDs) if err != nil { eventIDs = map[types.EventNID]string{} } var roomNIDs map[types.EventNID]types.RoomNID - roomNIDs, err = d.EventsTable.SelectRoomNIDsForEventNIDs(ctx, eventNIDs) + roomNIDs, err = d.EventsTable.SelectRoomNIDsForEventNIDs(ctx, txn, eventNIDs) if err != nil { return nil, err } @@ -398,7 +471,7 @@ func (d *Database) Events( } fetchNIDList = append(fetchNIDList, n) } - dbRoomVersions, err := d.RoomsTable.SelectRoomVersionsForRoomNIDs(ctx, fetchNIDList) + dbRoomVersions, err := d.RoomsTable.SelectRoomVersionsForRoomNIDs(ctx, txn, fetchNIDList) if err != nil { return nil, err } @@ -440,19 +513,19 @@ func (d *Database) MembershipUpdater( return updater, err } -func (d *Database) GetLatestEventsForUpdate( - ctx context.Context, roomInfo types.RoomInfo, -) (*LatestEventsUpdater, error) { - if d.GetLatestEventsForUpdateFn != nil { - return d.GetLatestEventsForUpdateFn(ctx, roomInfo) +func (d *Database) GetRoomUpdater( + ctx context.Context, roomInfo *types.RoomInfo, +) (*RoomUpdater, error) { + if d.GetRoomUpdaterFn != nil { + return d.GetRoomUpdaterFn(ctx, roomInfo) } txn, err := d.DB.Begin() if err != nil { return nil, err } - var updater *LatestEventsUpdater + var updater *RoomUpdater _ = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error { - updater, err = NewLatestEventsUpdater(ctx, d, txn, roomInfo) + updater, err = NewRoomUpdater(ctx, d, txn, roomInfo) return err }) return updater, err @@ -461,6 +534,13 @@ func (d *Database) GetLatestEventsForUpdate( func (d *Database) StoreEvent( ctx context.Context, event *gomatrixserverlib.Event, authEventNIDs []types.EventNID, isRejected bool, +) (types.EventNID, types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) { + return d.storeEvent(ctx, nil, event, authEventNIDs, isRejected) +} + +func (d *Database) storeEvent( + ctx context.Context, updater *RoomUpdater, event *gomatrixserverlib.Event, + authEventNIDs []types.EventNID, isRejected bool, ) (types.EventNID, types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) { var ( roomNID types.RoomNID @@ -472,8 +552,11 @@ func (d *Database) StoreEvent( redactedEventID string err error ) - - err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + var txn *sql.Tx + if updater != nil { + txn = updater.txn + } + err = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error { // TODO: Here we should aim to have two different code paths for new rooms // vs existing ones. @@ -546,42 +629,32 @@ func (d *Database) StoreEvent( // events updater because it somewhat works as a mutex, ensuring // that there's a row-level lock on the latest room events (well, // on Postgres at least). - var roomInfo *types.RoomInfo - var updater *LatestEventsUpdater if prevEvents := event.PrevEvents(); len(prevEvents) > 0 { - roomInfo, err = d.RoomInfo(ctx, event.RoomID()) - if err != nil { - return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.RoomInfo: %w", err) - } - if roomInfo == nil && len(prevEvents) > 0 { - return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("expected room %q to exist", event.RoomID()) - } // Create an updater - NB: on sqlite this WILL create a txn as we are directly calling the shared DB form of // GetLatestEventsForUpdate - not via the SQLiteDatabase form which has `nil` txns. This // function only does SELECTs though so the created txn (at this point) is just a read txn like // any other so this is fine. If we ever update GetLatestEventsForUpdate or NewLatestEventsUpdater // to do writes however then this will need to go inside `Writer.Do`. - updater, err = d.GetLatestEventsForUpdate(ctx, *roomInfo) - if err != nil { - return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("NewLatestEventsUpdater: %w", err) - } - // Ensure that we atomically store prev events AND commit them. If we don't wrap StorePreviousEvents - // and EndTransaction in a writer then it's possible for a new write txn to be made between the two - // function calls which will then fail with 'database is locked'. This new write txn would HAVE to be - // something like SetRoomAlias/RemoveRoomAlias as normal input events are already done sequentially due to - // SupportsConcurrentRoomInputs() == false on sqlite, though this does not apply to setting room aliases - // as they don't go via InputRoomEvents - err = d.Writer.Do(d.DB, updater.txn, func(txn *sql.Tx) error { - if err = updater.StorePreviousEvents(eventNID, prevEvents); err != nil { - return fmt.Errorf("updater.StorePreviousEvents: %w", err) + succeeded := false + if updater == nil { + var roomInfo *types.RoomInfo + roomInfo, err = d.RoomInfo(ctx, event.RoomID()) + if err != nil { + return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.RoomInfo: %w", err) } - succeeded := true - err = sqlutil.EndTransaction(updater, &succeeded) - return err - }) - if err != nil { - return 0, 0, types.StateAtEvent{}, nil, "", err + if roomInfo == nil && len(prevEvents) > 0 { + return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("expected room %q to exist", event.RoomID()) + } + updater, err = d.GetRoomUpdater(ctx, roomInfo) + if err != nil { + return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("GetRoomUpdater: %w", err) + } + defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err) } + if err = updater.StorePreviousEvents(eventNID, prevEvents); err != nil { + return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("updater.StorePreviousEvents: %w", err) + } + succeeded = true } return eventNID, roomNID, types.StateAtEvent{ @@ -603,7 +676,7 @@ func (d *Database) PublishRoom(ctx context.Context, roomID string, publish bool) } func (d *Database) GetPublishedRooms(ctx context.Context) ([]string, error) { - return d.PublishedTable.SelectAllPublishedRooms(ctx, true) + return d.PublishedTable.SelectAllPublishedRooms(ctx, nil, true) } func (d *Database) assignRoomNID( @@ -875,14 +948,14 @@ func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey s eventNIDs = append(eventNIDs, e.EventNID) } } - eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, eventNIDs) + eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs) if err != nil { eventIDs = map[types.EventNID]string{} } // return the event requested for _, e := range entries { if e.EventTypeNID == eventTypeNID && e.EventStateKeyNID == stateKeyNID { - data, err := d.EventJSONTable.BulkSelectEventJSON(ctx, []types.EventNID{e.EventNID}) + data, err := d.EventJSONTable.BulkSelectEventJSON(ctx, nil, []types.EventNID{e.EventNID}) if err != nil { return nil, err } @@ -922,11 +995,11 @@ func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership } return nil, fmt.Errorf("GetRoomsByMembership: cannot map user ID to state key NID: %w", err) } - roomNIDs, err := d.MembershipTable.SelectRoomsWithMembership(ctx, stateKeyNID, membershipState) + roomNIDs, err := d.MembershipTable.SelectRoomsWithMembership(ctx, nil, stateKeyNID, membershipState) if err != nil { return nil, fmt.Errorf("GetRoomsByMembership: failed to SelectRoomsWithMembership: %w", err) } - roomIDs, err := d.RoomsTable.BulkSelectRoomIDs(ctx, roomNIDs) + roomIDs, err := d.RoomsTable.BulkSelectRoomIDs(ctx, nil, roomNIDs) if err != nil { return nil, fmt.Errorf("GetRoomsByMembership: failed to lookup room nids: %w", err) } @@ -945,7 +1018,7 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu } // we don't bother failing the request if we get asked for event types we don't know about, as all that would result in is no matches which // isn't a failure. - eventTypeNIDMap, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, eventTypes) + eventTypeNIDMap, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, nil, eventTypes) if err != nil { return nil, fmt.Errorf("GetBulkStateContent: failed to map event type nids: %w", err) } @@ -965,7 +1038,7 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu } - eventStateKeyNIDMap, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, eventStateKeys) + eventStateKeyNIDMap, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, nil, eventStateKeys) if err != nil { return nil, fmt.Errorf("GetBulkStateContent: failed to map state key nids: %w", err) } @@ -999,11 +1072,11 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu } } } - eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, eventNIDs) + eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs) if err != nil { eventIDs = map[types.EventNID]string{} } - events, err := d.EventJSONTable.BulkSelectEventJSON(ctx, eventNIDs) + events, err := d.EventJSONTable.BulkSelectEventJSON(ctx, nil, eventNIDs) if err != nil { return nil, fmt.Errorf("GetBulkStateContent: failed to load event JSON for event nids: %w", err) } @@ -1027,11 +1100,11 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu // JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear. func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) { - roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, roomIDs) + roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, nil, roomIDs) if err != nil { return nil, err } - userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, roomNIDs) + userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, nil, roomNIDs) if err != nil { return nil, err } @@ -1041,7 +1114,7 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) stateKeyNIDs[i] = nid i++ } - nidToUserID, err := d.EventStateKeysTable.BulkSelectEventStateKey(ctx, stateKeyNIDs) + nidToUserID, err := d.EventStateKeysTable.BulkSelectEventStateKey(ctx, nil, stateKeyNIDs) if err != nil { return nil, err } @@ -1057,12 +1130,12 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) // GetLocalServerInRoom returns true if we think we're in a given room or false otherwise. func (d *Database) GetLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) { - return d.MembershipTable.SelectLocalServerInRoom(ctx, roomNID) + return d.MembershipTable.SelectLocalServerInRoom(ctx, nil, roomNID) } // GetServerInRoom returns true if we think a server is in a given room or false otherwise. func (d *Database) GetServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) { - return d.MembershipTable.SelectServerInRoom(ctx, roomNID, serverName) + return d.MembershipTable.SelectServerInRoom(ctx, nil, roomNID, serverName) } // GetKnownUsers searches all users that userID knows about. @@ -1071,17 +1144,17 @@ func (d *Database) GetKnownUsers(ctx context.Context, userID, searchString strin if err != nil { return nil, err } - return d.MembershipTable.SelectKnownUsers(ctx, stateKeyNID, searchString, limit) + return d.MembershipTable.SelectKnownUsers(ctx, nil, stateKeyNID, searchString, limit) } // GetKnownRooms returns a list of all rooms we know about. func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) { - return d.RoomsTable.SelectRoomIDs(ctx) + return d.RoomsTable.SelectRoomIDs(ctx, nil) } // ForgetRoom sets a users room to forgotten func (d *Database) ForgetRoom(ctx context.Context, userID, roomID string, forget bool) error { - roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, []string{roomID}) + roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, nil, []string{roomID}) if err != nil { return err } diff --git a/roomserver/storage/sqlite3/event_json_table.go b/roomserver/storage/sqlite3/event_json_table.go index 53b219294..f470ea326 100644 --- a/roomserver/storage/sqlite3/event_json_table.go +++ b/roomserver/storage/sqlite3/event_json_table.go @@ -76,15 +76,20 @@ func (s *eventJSONStatements) InsertEventJSON( } func (s *eventJSONStatements) BulkSelectEventJSON( - ctx context.Context, eventNIDs []types.EventNID, + ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, ) ([]tables.EventJSONPair, error) { iEventNIDs := make([]interface{}, len(eventNIDs)) for k, v := range eventNIDs { iEventNIDs[k] = v } selectOrig := strings.Replace(bulkSelectEventJSONSQL, "($1)", sqlutil.QueryVariadic(len(iEventNIDs)), 1) - - rows, err := s.db.QueryContext(ctx, selectOrig, iEventNIDs...) + var rows *sql.Rows + var err error + if txn != nil { + rows, err = txn.QueryContext(ctx, selectOrig, iEventNIDs...) + } else { + rows, err = s.db.QueryContext(ctx, selectOrig, iEventNIDs...) + } if err != nil { return nil, err } diff --git a/roomserver/storage/sqlite3/event_state_keys_table.go b/roomserver/storage/sqlite3/event_state_keys_table.go index 62fbce2d0..bf12d5b83 100644 --- a/roomserver/storage/sqlite3/event_state_keys_table.go +++ b/roomserver/storage/sqlite3/event_state_keys_table.go @@ -112,15 +112,20 @@ func (s *eventStateKeyStatements) SelectEventStateKeyNID( } func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID( - ctx context.Context, eventStateKeys []string, + ctx context.Context, txn *sql.Tx, eventStateKeys []string, ) (map[string]types.EventStateKeyNID, error) { iEventStateKeys := make([]interface{}, len(eventStateKeys)) for k, v := range eventStateKeys { iEventStateKeys[k] = v } selectOrig := strings.Replace(bulkSelectEventStateKeySQL, "($1)", sqlutil.QueryVariadic(len(eventStateKeys)), 1) - - rows, err := s.db.QueryContext(ctx, selectOrig, iEventStateKeys...) + var rows *sql.Rows + var err error + if txn != nil { + rows, err = txn.QueryContext(ctx, selectOrig, iEventStateKeys...) + } else { + rows, err = s.db.QueryContext(ctx, selectOrig, iEventStateKeys...) + } if err != nil { return nil, err } @@ -138,15 +143,19 @@ func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID( } func (s *eventStateKeyStatements) BulkSelectEventStateKey( - ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID, + ctx context.Context, txn *sql.Tx, eventStateKeyNIDs []types.EventStateKeyNID, ) (map[types.EventStateKeyNID]string, error) { iEventStateKeyNIDs := make([]interface{}, len(eventStateKeyNIDs)) for k, v := range eventStateKeyNIDs { iEventStateKeyNIDs[k] = v } selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", sqlutil.QueryVariadic(len(eventStateKeyNIDs)), 1) - - rows, err := s.db.QueryContext(ctx, selectOrig, iEventStateKeyNIDs...) + selectPrep, err := s.db.Prepare(selectOrig) + if err != nil { + return nil, err + } + stmt := sqlutil.TxStmt(txn, selectPrep) + rows, err := stmt.QueryContext(ctx, iEventStateKeyNIDs...) if err != nil { return nil, err } diff --git a/roomserver/storage/sqlite3/event_types_table.go b/roomserver/storage/sqlite3/event_types_table.go index 22df3fb22..f2c9c42fe 100644 --- a/roomserver/storage/sqlite3/event_types_table.go +++ b/roomserver/storage/sqlite3/event_types_table.go @@ -128,7 +128,7 @@ func (s *eventTypeStatements) SelectEventTypeNID( } func (s *eventTypeStatements) BulkSelectEventTypeNID( - ctx context.Context, eventTypes []string, + ctx context.Context, txn *sql.Tx, eventTypes []string, ) (map[string]types.EventTypeNID, error) { /////////////// iEventTypes := make([]interface{}, len(eventTypes)) @@ -140,9 +140,10 @@ func (s *eventTypeStatements) BulkSelectEventTypeNID( if err != nil { return nil, err } + stmt := sqlutil.TxStmt(txn, selectPrep) /////////////// - rows, err := selectPrep.QueryContext(ctx, iEventTypes...) + rows, err := stmt.QueryContext(ctx, iEventTypes...) if err != nil { return nil, err } diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go index 7483e2815..e1e6a597c 100644 --- a/roomserver/storage/sqlite3/events_table.go +++ b/roomserver/storage/sqlite3/events_table.go @@ -184,7 +184,7 @@ func (s *eventStatements) SelectEvent( // bulkSelectStateEventByID lookups a list of state events by event ID. // If any of the requested events are missing from the database it returns a types.MissingEventError func (s *eventStatements) BulkSelectStateEventByID( - ctx context.Context, eventIDs []string, + ctx context.Context, txn *sql.Tx, eventIDs []string, ) ([]types.StateEntry, error) { /////////////// iEventIDs := make([]interface{}, len(eventIDs)) @@ -196,6 +196,7 @@ func (s *eventStatements) BulkSelectStateEventByID( if err != nil { return nil, err } + selectStmt = sqlutil.TxStmt(txn, selectStmt) /////////////// rows, err := selectStmt.QueryContext(ctx, iEventIDs...) @@ -235,7 +236,7 @@ func (s *eventStatements) BulkSelectStateEventByID( // bulkSelectStateEventByID lookups a list of state events by event ID. // If any of the requested events are missing from the database it returns a types.MissingEventError func (s *eventStatements) BulkSelectStateEventByNID( - ctx context.Context, eventNIDs []types.EventNID, + ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, stateKeyTuples []types.StateKeyTuple, ) ([]types.StateEntry, error) { tuples := stateKeyTupleSorter(stateKeyTuples) @@ -263,6 +264,7 @@ func (s *eventStatements) BulkSelectStateEventByNID( if err != nil { return nil, fmt.Errorf("s.db.Prepare: %w", err) } + selectStmt = sqlutil.TxStmt(txn, selectStmt) rows, err := selectStmt.QueryContext(ctx, params...) if err != nil { return nil, fmt.Errorf("selectStmt.QueryContext: %w", err) @@ -291,7 +293,7 @@ func (s *eventStatements) BulkSelectStateEventByNID( // If any of the requested events are missing from the database it returns a types.MissingEventError. // If we do not have the state for any of the requested events it returns a types.MissingEventError. func (s *eventStatements) BulkSelectStateAtEventByID( - ctx context.Context, eventIDs []string, + ctx context.Context, txn *sql.Tx, eventIDs []string, ) ([]types.StateAtEvent, error) { /////////////// iEventIDs := make([]interface{}, len(eventIDs)) @@ -303,6 +305,7 @@ func (s *eventStatements) BulkSelectStateAtEventByID( if err != nil { return nil, err } + selectStmt = sqlutil.TxStmt(txn, selectStmt) /////////////// rows, err := selectStmt.QueryContext(ctx, iEventIDs...) if err != nil { @@ -381,6 +384,7 @@ func (s *eventStatements) BulkSelectStateAtEventAndReference( if err != nil { return nil, err } + selectPrep = sqlutil.TxStmt(txn, selectPrep) ////////////// rows, err := sqlutil.TxStmt(txn, selectPrep).QueryContext(ctx, iEventNIDs...) @@ -454,7 +458,7 @@ func (s *eventStatements) BulkSelectEventReference( } // bulkSelectEventID returns a map from numeric event ID to string event ID. -func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) { +func (s *eventStatements) BulkSelectEventID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (map[types.EventNID]string, error) { /////////////// iEventNIDs := make([]interface{}, len(eventNIDs)) for k, v := range eventNIDs { @@ -465,6 +469,7 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []typ if err != nil { return nil, err } + selectStmt = sqlutil.TxStmt(txn, selectStmt) /////////////// rows, err := selectStmt.QueryContext(ctx, iEventNIDs...) @@ -490,7 +495,7 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []typ // bulkSelectEventNIDs returns a map from string event ID to numeric event ID. // If an event ID is not in the database then it is omitted from the map. -func (s *eventStatements) BulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) { +func (s *eventStatements) BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) { /////////////// iEventIDs := make([]interface{}, len(eventIDs)) for k, v := range eventIDs { @@ -501,6 +506,7 @@ func (s *eventStatements) BulkSelectEventNID(ctx context.Context, eventIDs []str if err != nil { return nil, err } + selectStmt = sqlutil.TxStmt(txn, selectStmt) /////////////// rows, err := selectStmt.QueryContext(ctx, iEventIDs...) if err != nil { @@ -538,13 +544,14 @@ func (s *eventStatements) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, } func (s *eventStatements) SelectRoomNIDsForEventNIDs( - ctx context.Context, eventNIDs []types.EventNID, + ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, ) (map[types.EventNID]types.RoomNID, error) { sqlStr := strings.Replace(selectRoomNIDsForEventNIDsSQL, "($1)", sqlutil.QueryVariadic(len(eventNIDs)), 1) sqlPrep, err := s.db.Prepare(sqlStr) if err != nil { return nil, err } + sqlPrep = sqlutil.TxStmt(txn, sqlPrep) iEventNIDs := make([]interface{}, len(eventNIDs)) for i, v := range eventNIDs { iEventNIDs[i] = v diff --git a/roomserver/storage/sqlite3/invite_table.go b/roomserver/storage/sqlite3/invite_table.go index c1d7347ae..d54d313a9 100644 --- a/roomserver/storage/sqlite3/invite_table.go +++ b/roomserver/storage/sqlite3/invite_table.go @@ -88,8 +88,8 @@ func prepareInvitesTable(db *sql.DB) (tables.Invites, error) { } func (s *inviteStatements) InsertInviteEvent( - ctx context.Context, - txn *sql.Tx, inviteEventID string, roomNID types.RoomNID, + ctx context.Context, txn *sql.Tx, + inviteEventID string, roomNID types.RoomNID, targetUserNID, senderUserNID types.EventStateKeyNID, inviteEventJSON []byte, ) (bool, error) { @@ -109,8 +109,8 @@ func (s *inviteStatements) InsertInviteEvent( } func (s *inviteStatements) UpdateInviteRetired( - ctx context.Context, - txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, + ctx context.Context, txn *sql.Tx, + roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, ) (eventIDs []string, err error) { // gather all the event IDs we will retire stmt := sqlutil.TxStmt(txn, s.selectInvitesAboutToRetireStmt) @@ -134,10 +134,11 @@ func (s *inviteStatements) UpdateInviteRetired( // selectInviteActiveForUserInRoom returns a list of sender state key NIDs func (s *inviteStatements) SelectInviteActiveForUserInRoom( - ctx context.Context, + ctx context.Context, txn *sql.Tx, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, ) ([]types.EventStateKeyNID, []string, error) { - rows, err := s.selectInviteActiveForUserInRoomStmt.QueryContext( + stmt := sqlutil.TxStmt(txn, s.selectInviteActiveForUserInRoomStmt) + rows, err := stmt.QueryContext( ctx, targetUserNID, roomNID, ) if err != nil { diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go index 2e58431d3..181b4b4c9 100644 --- a/roomserver/storage/sqlite3/membership_table.go +++ b/roomserver/storage/sqlite3/membership_table.go @@ -184,17 +184,18 @@ func (s *membershipStatements) SelectMembershipForUpdate( } func (s *membershipStatements) SelectMembershipFromRoomAndTarget( - ctx context.Context, + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, ) (eventNID types.EventNID, membership tables.MembershipState, forgotten bool, err error) { - err = s.selectMembershipFromRoomAndTargetStmt.QueryRowContext( + stmt := sqlutil.TxStmt(txn, s.selectMembershipFromRoomAndTargetStmt) + err = stmt.QueryRowContext( ctx, roomNID, targetUserNID, ).Scan(&membership, &eventNID, &forgotten) return } func (s *membershipStatements) SelectMembershipsFromRoom( - ctx context.Context, + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, localOnly bool, ) (eventNIDs []types.EventNID, err error) { var selectStmt *sql.Stmt @@ -203,6 +204,7 @@ func (s *membershipStatements) SelectMembershipsFromRoom( } else { selectStmt = s.selectMembershipsFromRoomStmt } + selectStmt = sqlutil.TxStmt(txn, selectStmt) rows, err := selectStmt.QueryContext(ctx, roomNID) if err != nil { return nil, err @@ -220,7 +222,7 @@ func (s *membershipStatements) SelectMembershipsFromRoom( } func (s *membershipStatements) SelectMembershipsFromRoomAndMembership( - ctx context.Context, + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, membership tables.MembershipState, localOnly bool, ) (eventNIDs []types.EventNID, err error) { var stmt *sql.Stmt @@ -229,6 +231,7 @@ func (s *membershipStatements) SelectMembershipsFromRoomAndMembership( } else { stmt = s.selectMembershipsFromRoomAndMembershipStmt } + stmt = sqlutil.TxStmt(txn, stmt) rows, err := stmt.QueryContext(ctx, roomNID, membership) if err != nil { return @@ -258,9 +261,10 @@ func (s *membershipStatements) UpdateMembership( } func (s *membershipStatements) SelectRoomsWithMembership( - ctx context.Context, userID types.EventStateKeyNID, membershipState tables.MembershipState, + ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState tables.MembershipState, ) ([]types.RoomNID, error) { - rows, err := s.selectRoomsWithMembershipStmt.QueryContext(ctx, membershipState, userID) + stmt := sqlutil.TxStmt(txn, s.selectRoomsWithMembershipStmt) + rows, err := stmt.QueryContext(ctx, membershipState, userID) if err != nil { return nil, err } @@ -276,13 +280,19 @@ func (s *membershipStatements) SelectRoomsWithMembership( return roomNIDs, nil } -func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) { +func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) { iRoomNIDs := make([]interface{}, len(roomNIDs)) for i, v := range roomNIDs { iRoomNIDs[i] = v } query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomNIDs)), 1) - rows, err := s.db.QueryContext(ctx, query, iRoomNIDs...) + var rows *sql.Rows + var err error + if txn != nil { + rows, err = txn.QueryContext(ctx, query, iRoomNIDs...) + } else { + rows, err = s.db.QueryContext(ctx, query, iRoomNIDs...) + } if err != nil { return nil, err } @@ -299,8 +309,9 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, return result, rows.Err() } -func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) { - rows, err := s.selectKnownUsersStmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit) +func (s *membershipStatements) SelectKnownUsers(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) { + stmt := sqlutil.TxStmt(txn, s.selectKnownUsersStmt) + rows, err := stmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit) if err != nil { return nil, err } @@ -317,8 +328,8 @@ func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID type } func (s *membershipStatements) UpdateForgetMembership( - ctx context.Context, - txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, + ctx context.Context, txn *sql.Tx, + roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool, ) error { _, err := sqlutil.TxStmt(txn, s.updateMembershipForgetRoomStmt).ExecContext( @@ -327,9 +338,10 @@ func (s *membershipStatements) UpdateForgetMembership( return err } -func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) { +func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error) { var nid types.RoomNID - err := s.selectLocalServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID).Scan(&nid) + stmt := sqlutil.TxStmt(txn, s.selectLocalServerInRoomStmt) + err := stmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID).Scan(&nid) if err != nil { if err == sql.ErrNoRows { return false, nil @@ -340,9 +352,10 @@ func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, room return found, nil } -func (s *membershipStatements) SelectServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) { +func (s *membershipStatements) SelectServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) { var nid types.RoomNID - err := s.selectServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID, serverName).Scan(&nid) + stmt := sqlutil.TxStmt(txn, s.selectServerInRoomStmt) + err := stmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID, serverName).Scan(&nid) if err != nil { if err == sql.ErrNoRows { return false, nil diff --git a/roomserver/storage/sqlite3/published_table.go b/roomserver/storage/sqlite3/published_table.go index b07c0ac42..9e416ace3 100644 --- a/roomserver/storage/sqlite3/published_table.go +++ b/roomserver/storage/sqlite3/published_table.go @@ -75,9 +75,10 @@ func (s *publishedStatements) UpsertRoomPublished( } func (s *publishedStatements) SelectPublishedFromRoomID( - ctx context.Context, roomID string, + ctx context.Context, txn *sql.Tx, roomID string, ) (published bool, err error) { - err = s.selectPublishedStmt.QueryRowContext(ctx, roomID).Scan(&published) + stmt := sqlutil.TxStmt(txn, s.selectPublishedStmt) + err = stmt.QueryRowContext(ctx, roomID).Scan(&published) if err == sql.ErrNoRows { return false, nil } @@ -85,9 +86,10 @@ func (s *publishedStatements) SelectPublishedFromRoomID( } func (s *publishedStatements) SelectAllPublishedRooms( - ctx context.Context, published bool, + ctx context.Context, txn *sql.Tx, published bool, ) ([]string, error) { - rows, err := s.selectAllPublishedStmt.QueryContext(ctx, published) + stmt := sqlutil.TxStmt(txn, s.selectAllPublishedStmt) + rows, err := stmt.QueryContext(ctx, published) if err != nil { return nil, err } diff --git a/roomserver/storage/sqlite3/room_aliases_table.go b/roomserver/storage/sqlite3/room_aliases_table.go index 323945b88..7c7bead95 100644 --- a/roomserver/storage/sqlite3/room_aliases_table.go +++ b/roomserver/storage/sqlite3/room_aliases_table.go @@ -91,9 +91,10 @@ func (s *roomAliasesStatements) InsertRoomAlias( } func (s *roomAliasesStatements) SelectRoomIDFromAlias( - ctx context.Context, alias string, + ctx context.Context, txn *sql.Tx, alias string, ) (roomID string, err error) { - err = s.selectRoomIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&roomID) + stmt := sqlutil.TxStmt(txn, s.selectRoomIDFromAliasStmt) + err = stmt.QueryRowContext(ctx, alias).Scan(&roomID) if err == sql.ErrNoRows { return "", nil } @@ -101,10 +102,11 @@ func (s *roomAliasesStatements) SelectRoomIDFromAlias( } func (s *roomAliasesStatements) SelectAliasesFromRoomID( - ctx context.Context, roomID string, + ctx context.Context, txn *sql.Tx, roomID string, ) (aliases []string, err error) { aliases = []string{} - rows, err := s.selectAliasesFromRoomIDStmt.QueryContext(ctx, roomID) + stmt := sqlutil.TxStmt(txn, s.selectAliasesFromRoomIDStmt) + rows, err := stmt.QueryContext(ctx, roomID) if err != nil { return } @@ -124,9 +126,10 @@ func (s *roomAliasesStatements) SelectAliasesFromRoomID( } func (s *roomAliasesStatements) SelectCreatorIDFromAlias( - ctx context.Context, alias string, + ctx context.Context, txn *sql.Tx, alias string, ) (creatorID string, err error) { - err = s.selectCreatorIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&creatorID) + stmt := sqlutil.TxStmt(txn, s.selectCreatorIDFromAliasStmt) + err = stmt.QueryRowContext(ctx, alias).Scan(&creatorID) if err == sql.ErrNoRows { return "", nil } diff --git a/roomserver/storage/sqlite3/rooms_table.go b/roomserver/storage/sqlite3/rooms_table.go index c441daec0..5413475e2 100644 --- a/roomserver/storage/sqlite3/rooms_table.go +++ b/roomserver/storage/sqlite3/rooms_table.go @@ -107,8 +107,9 @@ func prepareRoomsTable(db *sql.DB) (tables.Rooms, error) { }.Prepare(db) } -func (s *roomStatements) SelectRoomIDs(ctx context.Context) ([]string, error) { - rows, err := s.selectRoomIDsStmt.QueryContext(ctx) +func (s *roomStatements) SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]string, error) { + stmt := sqlutil.TxStmt(txn, s.selectRoomIDsStmt) + rows, err := stmt.QueryContext(ctx) if err != nil { return nil, err } @@ -124,10 +125,11 @@ func (s *roomStatements) SelectRoomIDs(ctx context.Context) ([]string, error) { return roomIDs, nil } -func (s *roomStatements) SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) { +func (s *roomStatements) SelectRoomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error) { var info types.RoomInfo var latestNIDsJSON string - err := s.selectRoomInfoStmt.QueryRowContext(ctx, roomID).Scan( + stmt := sqlutil.TxStmt(txn, s.selectRoomInfoStmt) + err := stmt.QueryRowContext(ctx, roomID).Scan( &info.RoomVersion, &info.RoomNID, &info.StateSnapshotNID, &latestNIDsJSON, ) if err != nil { @@ -224,13 +226,14 @@ func (s *roomStatements) UpdateLatestEventNIDs( } func (s *roomStatements) SelectRoomVersionsForRoomNIDs( - ctx context.Context, roomNIDs []types.RoomNID, + ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, ) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error) { sqlStr := strings.Replace(selectRoomVersionsForRoomNIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1) sqlPrep, err := s.db.Prepare(sqlStr) if err != nil { return nil, err } + sqlPrep = sqlutil.TxStmt(txn, sqlPrep) iRoomNIDs := make([]interface{}, len(roomNIDs)) for i, v := range roomNIDs { iRoomNIDs[i] = v @@ -252,13 +255,19 @@ func (s *roomStatements) SelectRoomVersionsForRoomNIDs( return result, nil } -func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) { +func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) ([]string, error) { iRoomNIDs := make([]interface{}, len(roomNIDs)) for i, v := range roomNIDs { iRoomNIDs[i] = v } sqlQuery := strings.Replace(bulkSelectRoomIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1) - rows, err := s.db.QueryContext(ctx, sqlQuery, iRoomNIDs...) + var rows *sql.Rows + var err error + if txn != nil { + rows, err = txn.QueryContext(ctx, sqlQuery, iRoomNIDs...) + } else { + rows, err = s.db.QueryContext(ctx, sqlQuery, iRoomNIDs...) + } if err != nil { return nil, err } @@ -274,13 +283,19 @@ func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types return roomIDs, nil } -func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, roomIDs []string) ([]types.RoomNID, error) { +func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, roomIDs []string) ([]types.RoomNID, error) { iRoomIDs := make([]interface{}, len(roomIDs)) for i, v := range roomIDs { iRoomIDs[i] = v } sqlQuery := strings.Replace(bulkSelectRoomNIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomIDs)), 1) - rows, err := s.db.QueryContext(ctx, sqlQuery, iRoomIDs...) + var rows *sql.Rows + var err error + if txn != nil { + rows, err = txn.QueryContext(ctx, sqlQuery, iRoomIDs...) + } else { + rows, err = s.db.QueryContext(ctx, sqlQuery, iRoomIDs...) + } if err != nil { return nil, err } diff --git a/roomserver/storage/sqlite3/state_block_table.go b/roomserver/storage/sqlite3/state_block_table.go index 58b0b5dc2..d51fc492d 100644 --- a/roomserver/storage/sqlite3/state_block_table.go +++ b/roomserver/storage/sqlite3/state_block_table.go @@ -81,8 +81,7 @@ func prepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) { } func (s *stateBlockStatements) BulkInsertStateData( - ctx context.Context, - txn *sql.Tx, + ctx context.Context, txn *sql.Tx, entries types.StateEntries, ) (id types.StateBlockNID, err error) { entries = entries[:util.SortAndUnique(entries)] @@ -94,14 +93,15 @@ func (s *stateBlockStatements) BulkInsertStateData( if err != nil { return 0, fmt.Errorf("json.Marshal: %w", err) } - err = s.insertStateDataStmt.QueryRowContext( + stmt := sqlutil.TxStmt(txn, s.insertStateDataStmt) + err = stmt.QueryRowContext( ctx, nids.Hash(), js, ).Scan(&id) return } func (s *stateBlockStatements) BulkSelectStateBlockEntries( - ctx context.Context, stateBlockNIDs types.StateBlockNIDs, + ctx context.Context, txn *sql.Tx, stateBlockNIDs types.StateBlockNIDs, ) ([][]types.EventNID, error) { intfs := make([]interface{}, len(stateBlockNIDs)) for i := range stateBlockNIDs { @@ -112,6 +112,7 @@ func (s *stateBlockStatements) BulkSelectStateBlockEntries( if err != nil { return nil, err } + selectStmt = sqlutil.TxStmt(txn, selectStmt) rows, err := selectStmt.QueryContext(ctx, intfs...) if err != nil { return nil, err diff --git a/roomserver/storage/sqlite3/state_snapshot_table.go b/roomserver/storage/sqlite3/state_snapshot_table.go index 040d99ae6..3c4bde3f5 100644 --- a/roomserver/storage/sqlite3/state_snapshot_table.go +++ b/roomserver/storage/sqlite3/state_snapshot_table.go @@ -106,7 +106,7 @@ func (s *stateSnapshotStatements) InsertState( } func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs( - ctx context.Context, stateNIDs []types.StateSnapshotNID, + ctx context.Context, txn *sql.Tx, stateNIDs []types.StateSnapshotNID, ) ([]types.StateBlockNIDList, error) { nids := make([]interface{}, len(stateNIDs)) for k, v := range stateNIDs { @@ -117,6 +117,7 @@ func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs( if err != nil { return nil, err } + selectStmt = sqlutil.TxStmt(txn, selectStmt) rows, err := selectStmt.QueryContext(ctx, nids...) if err != nil { diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index 1fcc7989d..325c253b5 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -172,23 +172,23 @@ func (d *Database) prepare(db *sql.DB, cache caching.RoomServerCaches) error { return err } d.Database = shared.Database{ - DB: db, - Cache: cache, - Writer: sqlutil.NewExclusiveWriter(), - EventsTable: events, - EventTypesTable: eventTypes, - EventStateKeysTable: eventStateKeys, - EventJSONTable: eventJSON, - RoomsTable: rooms, - StateBlockTable: stateBlock, - StateSnapshotTable: stateSnapshot, - PrevEventsTable: prevEvents, - RoomAliasesTable: roomAliases, - InvitesTable: invites, - MembershipTable: membership, - PublishedTable: published, - RedactionsTable: redactions, - GetLatestEventsForUpdateFn: d.GetLatestEventsForUpdate, + DB: db, + Cache: cache, + Writer: sqlutil.NewExclusiveWriter(), + EventsTable: events, + EventTypesTable: eventTypes, + EventStateKeysTable: eventStateKeys, + EventJSONTable: eventJSON, + RoomsTable: rooms, + StateBlockTable: stateBlock, + StateSnapshotTable: stateSnapshot, + PrevEventsTable: prevEvents, + RoomAliasesTable: roomAliases, + InvitesTable: invites, + MembershipTable: membership, + PublishedTable: published, + RedactionsTable: redactions, + GetRoomUpdaterFn: d.GetRoomUpdater, } return nil } @@ -201,16 +201,16 @@ func (d *Database) SupportsConcurrentRoomInputs() bool { return false } -func (d *Database) GetLatestEventsForUpdate( - ctx context.Context, roomInfo types.RoomInfo, -) (*shared.LatestEventsUpdater, error) { +func (d *Database) GetRoomUpdater( + ctx context.Context, roomInfo *types.RoomInfo, +) (*shared.RoomUpdater, error) { // TODO: Do not use transactions. We should be holding open this transaction but we cannot have // multiple write transactions on sqlite. The code will perform additional // write transactions independent of this one which will consistently cause // 'database is locked' errors. As sqlite doesn't support multi-process on the // same DB anyway, and we only execute updates sequentially, the only worries // are for rolling back when things go wrong. (atomicity) - return shared.NewLatestEventsUpdater(ctx, &d.Database, nil, roomInfo) + return shared.NewRoomUpdater(ctx, &d.Database, nil, roomInfo) } func (d *Database) MembershipUpdater( diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index 6ad7ed2e8..fed39b944 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -18,20 +18,20 @@ type EventJSONPair struct { type EventJSON interface { // Insert the event JSON. On conflict, replace the event JSON with the new value (for redactions). InsertEventJSON(ctx context.Context, tx *sql.Tx, eventNID types.EventNID, eventJSON []byte) error - BulkSelectEventJSON(ctx context.Context, eventNIDs []types.EventNID) ([]EventJSONPair, error) + BulkSelectEventJSON(ctx context.Context, tx *sql.Tx, eventNIDs []types.EventNID) ([]EventJSONPair, error) } type EventTypes interface { InsertEventTypeNID(ctx context.Context, tx *sql.Tx, eventType string) (types.EventTypeNID, error) SelectEventTypeNID(ctx context.Context, tx *sql.Tx, eventType string) (types.EventTypeNID, error) - BulkSelectEventTypeNID(ctx context.Context, eventTypes []string) (map[string]types.EventTypeNID, error) + BulkSelectEventTypeNID(ctx context.Context, txn *sql.Tx, eventTypes []string) (map[string]types.EventTypeNID, error) } type EventStateKeys interface { InsertEventStateKeyNID(ctx context.Context, txn *sql.Tx, eventStateKey string) (types.EventStateKeyNID, error) SelectEventStateKeyNID(ctx context.Context, txn *sql.Tx, eventStateKey string) (types.EventStateKeyNID, error) - BulkSelectEventStateKeyNID(ctx context.Context, eventStateKeys []string) (map[string]types.EventStateKeyNID, error) - BulkSelectEventStateKey(ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error) + BulkSelectEventStateKeyNID(ctx context.Context, txn *sql.Tx, eventStateKeys []string) (map[string]types.EventStateKeyNID, error) + BulkSelectEventStateKey(ctx context.Context, txn *sql.Tx, eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error) } type Events interface { @@ -42,12 +42,12 @@ type Events interface { SelectEvent(ctx context.Context, txn *sql.Tx, eventID string) (types.EventNID, types.StateSnapshotNID, error) // bulkSelectStateEventByID lookups a list of state events by event ID. // If any of the requested events are missing from the database it returns a types.MissingEventError - BulkSelectStateEventByID(ctx context.Context, eventIDs []string) ([]types.StateEntry, error) - BulkSelectStateEventByNID(ctx context.Context, eventNIDs []types.EventNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntry, error) + BulkSelectStateEventByID(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StateEntry, error) + BulkSelectStateEventByNID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntry, error) // BulkSelectStateAtEventByID lookups the state at a list of events by event ID. // If any of the requested events are missing from the database it returns a types.MissingEventError. // If we do not have the state for any of the requested events it returns a types.MissingEventError. - BulkSelectStateAtEventByID(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error) + BulkSelectStateAtEventByID(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StateAtEvent, error) UpdateEventState(ctx context.Context, txn *sql.Tx, eventNID types.EventNID, stateNID types.StateSnapshotNID) error SelectEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) (sentToOutput bool, err error) UpdateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error @@ -55,12 +55,12 @@ type Events interface { BulkSelectStateAtEventAndReference(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) ([]types.StateAtEventAndReference, error) BulkSelectEventReference(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) ([]gomatrixserverlib.EventReference, error) // BulkSelectEventID returns a map from numeric event ID to string event ID. - BulkSelectEventID(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) + BulkSelectEventID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (map[types.EventNID]string, error) // BulkSelectEventNIDs returns a map from string event ID to numeric event ID. // If an event ID is not in the database then it is omitted from the map. - BulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) + BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error) - SelectRoomNIDsForEventNIDs(ctx context.Context, eventNIDs []types.EventNID) (roomNIDs map[types.EventNID]types.RoomNID, err error) + SelectRoomNIDsForEventNIDs(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (roomNIDs map[types.EventNID]types.RoomNID, err error) } type Rooms interface { @@ -69,29 +69,29 @@ type Rooms interface { SelectLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.StateSnapshotNID, error) SelectLatestEventsNIDsForUpdate(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.EventNID, types.StateSnapshotNID, error) UpdateLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID) error - SelectRoomVersionsForRoomNIDs(ctx context.Context, roomNID []types.RoomNID) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error) - SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) - SelectRoomIDs(ctx context.Context) ([]string, error) - BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) - BulkSelectRoomNIDs(ctx context.Context, roomIDs []string) ([]types.RoomNID, error) + SelectRoomVersionsForRoomNIDs(ctx context.Context, txn *sql.Tx, roomNID []types.RoomNID) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error) + SelectRoomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error) + SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]string, error) + BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) ([]string, error) + BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, roomIDs []string) ([]types.RoomNID, error) } type StateSnapshot interface { InsertState(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs types.StateBlockNIDs) (stateNID types.StateSnapshotNID, err error) - BulkSelectStateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) + BulkSelectStateBlockNIDs(ctx context.Context, txn *sql.Tx, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) } type StateBlock interface { BulkInsertStateData(ctx context.Context, txn *sql.Tx, entries types.StateEntries) (types.StateBlockNID, error) - BulkSelectStateBlockEntries(ctx context.Context, stateBlockNIDs types.StateBlockNIDs) ([][]types.EventNID, error) + BulkSelectStateBlockEntries(ctx context.Context, txn *sql.Tx, stateBlockNIDs types.StateBlockNIDs) ([][]types.EventNID, error) //BulkSelectFilteredStateBlockEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error) } type RoomAliases interface { InsertRoomAlias(ctx context.Context, txn *sql.Tx, alias string, roomID string, creatorUserID string) (err error) - SelectRoomIDFromAlias(ctx context.Context, alias string) (roomID string, err error) - SelectAliasesFromRoomID(ctx context.Context, roomID string) ([]string, error) - SelectCreatorIDFromAlias(ctx context.Context, alias string) (creatorID string, err error) + SelectRoomIDFromAlias(ctx context.Context, txn *sql.Tx, alias string) (roomID string, err error) + SelectAliasesFromRoomID(ctx context.Context, txn *sql.Tx, roomID string) ([]string, error) + SelectCreatorIDFromAlias(ctx context.Context, txn *sql.Tx, alias string) (creatorID string, err error) DeleteRoomAlias(ctx context.Context, txn *sql.Tx, alias string) (err error) } @@ -106,7 +106,7 @@ type Invites interface { InsertInviteEvent(ctx context.Context, txn *sql.Tx, inviteEventID string, roomNID types.RoomNID, targetUserNID, senderUserNID types.EventStateKeyNID, inviteEventJSON []byte) (bool, error) UpdateInviteRetired(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) ([]string, error) // SelectInviteActiveForUserInRoom returns a list of sender state key NIDs and invite event IDs matching those nids. - SelectInviteActiveForUserInRoom(ctx context.Context, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID) ([]types.EventStateKeyNID, []string, error) + SelectInviteActiveForUserInRoom(ctx context.Context, txn *sql.Tx, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID) ([]types.EventStateKeyNID, []string, error) } type MembershipState int64 @@ -121,24 +121,24 @@ const ( type Membership interface { InsertMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, localTarget bool) error SelectMembershipForUpdate(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (MembershipState, error) - SelectMembershipFromRoomAndTarget(ctx context.Context, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (types.EventNID, MembershipState, bool, error) - SelectMembershipsFromRoom(ctx context.Context, roomNID types.RoomNID, localOnly bool) (eventNIDs []types.EventNID, err error) - SelectMembershipsFromRoomAndMembership(ctx context.Context, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error) + SelectMembershipFromRoomAndTarget(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (types.EventNID, MembershipState, bool, error) + SelectMembershipsFromRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, localOnly bool) (eventNIDs []types.EventNID, err error) + SelectMembershipsFromRoomAndMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error) UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID, forgotten bool) error - SelectRoomsWithMembership(ctx context.Context, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error) + SelectRoomsWithMembership(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error) // SelectJoinedUsersSetForRooms returns the set of all users in the rooms who are joined to any of these rooms, along with the // counts of how many rooms they are joined. - SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) - SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) + SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) + SelectKnownUsers(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) UpdateForgetMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool) error - SelectLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) - SelectServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) + SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error) + SelectServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) } type Published interface { UpsertRoomPublished(ctx context.Context, txn *sql.Tx, roomID string, published bool) (err error) - SelectPublishedFromRoomID(ctx context.Context, roomID string) (published bool, err error) - SelectAllPublishedRooms(ctx context.Context, published bool) ([]string, error) + SelectPublishedFromRoomID(ctx context.Context, txn *sql.Tx, roomID string) (published bool, err error) + SelectAllPublishedRooms(ctx context.Context, txn *sql.Tx, published bool) ([]string, error) } type RedactionInfo struct { diff --git a/setup/config/config_jetstream.go b/setup/config/config_jetstream.go index 94e2d88b3..9271cd8b4 100644 --- a/setup/config/config_jetstream.go +++ b/setup/config/config_jetstream.go @@ -2,8 +2,6 @@ package config import ( "fmt" - - "github.com/nats-io/nats.go" ) type JetStream struct { @@ -25,8 +23,8 @@ func (c *JetStream) TopicFor(name string) string { return fmt.Sprintf("%s%s", c.TopicPrefix, name) } -func (c *JetStream) Durable(name string) nats.SubOpt { - return nats.Durable(c.TopicFor(name)) +func (c *JetStream) Durable(name string) string { + return c.TopicFor(name) } func (c *JetStream) Defaults(generate bool) { diff --git a/setup/jetstream/helpers.go b/setup/jetstream/helpers.go index 1891b96b3..544b5f0c3 100644 --- a/setup/jetstream/helpers.go +++ b/setup/jetstream/helpers.go @@ -1,12 +1,81 @@ package jetstream -import "github.com/nats-io/nats.go" +import ( + "context" + "fmt" -func WithJetStreamMessage(msg *nats.Msg, f func(msg *nats.Msg) bool) { - _ = msg.InProgress() - if f(msg) { - _ = msg.Ack() - } else { - _ = msg.Nak() + "github.com/nats-io/nats.go" + "github.com/sirupsen/logrus" +) + +func JetStreamConsumer( + ctx context.Context, js nats.JetStreamContext, subj, durable string, + f func(ctx context.Context, msg *nats.Msg) bool, + opts ...nats.SubOpt, +) error { + defer func() { + // If there are existing consumers from before they were pull + // consumers, we need to clean up the old push consumers. However, + // in order to not affect the interest-based policies, we need to + // do this *after* creating the new pull consumers, which have + // "Pull" suffixed to their name. + if _, err := js.ConsumerInfo(subj, durable); err == nil { + if err := js.DeleteConsumer(subj, durable); err != nil { + logrus.WithContext(ctx).Warnf("Failed to clean up old consumer %q", durable) + } + } + }() + + name := durable + "Pull" + sub, err := js.PullSubscribe(subj, name, opts...) + if err != nil { + return fmt.Errorf("nats.SubscribeSync: %w", err) } + go func() { + for { + // The context behaviour here is surprising — we supply a context + // so that we can interrupt the fetch if we want, but NATS will still + // enforce its own deadline (roughly 5 seconds by default). Therefore + // it is our responsibility to check whether our context expired or + // not when a context error is returned. Footguns. Footguns everywhere. + msgs, err := sub.Fetch(1, nats.Context(ctx)) + if err != nil { + if err == context.Canceled || err == context.DeadlineExceeded { + // Work out whether it was the JetStream context that expired + // or whether it was our supplied context. + select { + case <-ctx.Done(): + // The supplied context expired, so we want to stop the + // consumer altogether. + return + default: + // The JetStream context expired, so the fetch probably + // just timed out and we should try again. + continue + } + } else { + // Something else went wrong, so we'll panic. + logrus.WithContext(ctx).WithField("subject", subj).Fatal(err) + } + } + if len(msgs) < 1 { + continue + } + msg := msgs[0] + if err = msg.InProgress(); err != nil { + logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.InProgress: %w", err)) + continue + } + if f(ctx, msg) { + if err = msg.Ack(); err != nil { + logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.Ack: %w", err)) + } + } else { + if err = msg.Nak(); err != nil { + logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.Nak: %w", err)) + } + } + } + }() + return nil } diff --git a/setup/jetstream/nats.go b/setup/jetstream/nats.go index 5d7937b5c..77ad2b721 100644 --- a/setup/jetstream/nats.go +++ b/setup/jetstream/nats.go @@ -5,20 +5,17 @@ import ( "sync" "time" - "github.com/Shopify/sarama" "github.com/matrix-org/dendrite/setup/config" "github.com/sirupsen/logrus" - saramajs "github.com/S7evinK/saramajetstream" natsserver "github.com/nats-io/nats-server/v2/server" - "github.com/nats-io/nats.go" natsclient "github.com/nats-io/nats.go" ) var natsServer *natsserver.Server var natsServerMutex sync.Mutex -func Prepare(cfg *config.JetStream) (nats.JetStreamContext, sarama.Consumer, sarama.SyncProducer) { +func Prepare(cfg *config.JetStream) natsclient.JetStreamContext { // check if we need an in-process NATS Server if len(cfg.Addresses) != 0 { return setupNATS(cfg, nil) @@ -52,20 +49,20 @@ func Prepare(cfg *config.JetStream) (nats.JetStreamContext, sarama.Consumer, sar return setupNATS(cfg, nc) } -func setupNATS(cfg *config.JetStream, nc *natsclient.Conn) (nats.JetStreamContext, sarama.Consumer, sarama.SyncProducer) { +func setupNATS(cfg *config.JetStream, nc *natsclient.Conn) natsclient.JetStreamContext { if nc == nil { var err error - nc, err = nats.Connect(strings.Join(cfg.Addresses, ",")) + nc, err = natsclient.Connect(strings.Join(cfg.Addresses, ",")) if err != nil { logrus.WithError(err).Panic("Unable to connect to NATS") - return nil, nil, nil + return nil } } s, err := nc.JetStream() if err != nil { logrus.WithError(err).Panic("Unable to get JetStream context") - return nil, nil, nil + return nil } for _, stream := range streams { // streams are defined in streams.go @@ -80,7 +77,7 @@ func setupNATS(cfg *config.JetStream, nc *natsclient.Conn) (nats.JetStreamContex // If we're trying to keep everything in memory (e.g. unit tests) // then overwrite the storage policy. if cfg.InMemory { - stream.Storage = nats.MemoryStorage + stream.Storage = natsclient.MemoryStorage } // Namespace the streams without modifying the original streams @@ -93,7 +90,5 @@ func setupNATS(cfg *config.JetStream, nc *natsclient.Conn) (nats.JetStreamContex } } - consumer := saramajs.NewJetStreamConsumer(nc, s, "") - producer := saramajs.NewJetStreamProducer(nc, s, "") - return s, consumer, producer + return s } diff --git a/syncapi/consumers/clientapi.go b/syncapi/consumers/clientapi.go index 3d340a16a..c3650085f 100644 --- a/syncapi/consumers/clientapi.go +++ b/syncapi/consumers/clientapi.go @@ -34,7 +34,7 @@ import ( type OutputClientDataConsumer struct { ctx context.Context jetstream nats.JetStreamContext - durable nats.SubOpt + durable string topic string db storage.Database stream types.StreamProvider @@ -63,45 +63,45 @@ func NewOutputClientDataConsumer( // Start consuming from room servers func (s *OutputClientDataConsumer) Start() error { - _, err := s.jetstream.Subscribe(s.topic, s.onMessage, s.durable) - return err + return jetstream.JetStreamConsumer( + s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, + nats.DeliverAll(), nats.ManualAck(), + ) } // onMessage is called when the sync server receives a new event from the client API server output log. // It is not safe for this function to be called from multiple goroutines, or else the // sync stream position may race and be incorrectly calculated. -func (s *OutputClientDataConsumer) onMessage(msg *nats.Msg) { - jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool { - // Parse out the event JSON - userID := msg.Header.Get(jetstream.UserID) - var output eventutil.AccountData - if err := json.Unmarshal(msg.Data, &output); err != nil { - // If the message was invalid, log it and move on to the next message in the stream - log.WithError(err).Errorf("client API server output log: message parse failure") - sentry.CaptureException(err) - return true - } - - log.WithFields(log.Fields{ - "type": output.Type, - "room_id": output.RoomID, - }).Debug("Received data from client API server") - - streamPos, err := s.db.UpsertAccountData( - s.ctx, userID, output.RoomID, output.Type, - ) - if err != nil { - sentry.CaptureException(err) - log.WithFields(log.Fields{ - "type": output.Type, - "room_id": output.RoomID, - log.ErrorKey: err, - }).Panicf("could not save account data") - } - - s.stream.Advance(streamPos) - s.notifier.OnNewAccountData(userID, types.StreamingToken{AccountDataPosition: streamPos}) - +func (s *OutputClientDataConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { + // Parse out the event JSON + userID := msg.Header.Get(jetstream.UserID) + var output eventutil.AccountData + if err := json.Unmarshal(msg.Data, &output); err != nil { + // If the message was invalid, log it and move on to the next message in the stream + log.WithError(err).Errorf("client API server output log: message parse failure") + sentry.CaptureException(err) return true - }) + } + + log.WithFields(log.Fields{ + "type": output.Type, + "room_id": output.RoomID, + }).Debug("Received data from client API server") + + streamPos, err := s.db.UpsertAccountData( + s.ctx, userID, output.RoomID, output.Type, + ) + if err != nil { + sentry.CaptureException(err) + log.WithFields(log.Fields{ + "type": output.Type, + "room_id": output.RoomID, + log.ErrorKey: err, + }).Panicf("could not save account data") + } + + s.stream.Advance(streamPos) + s.notifier.OnNewAccountData(userID, types.StreamingToken{AccountDataPosition: streamPos}) + + return true } diff --git a/syncapi/consumers/eduserver_receipts.go b/syncapi/consumers/eduserver_receipts.go index 57d69d6fb..392840ece 100644 --- a/syncapi/consumers/eduserver_receipts.go +++ b/syncapi/consumers/eduserver_receipts.go @@ -34,7 +34,7 @@ import ( type OutputReceiptEventConsumer struct { ctx context.Context jetstream nats.JetStreamContext - durable nats.SubOpt + durable string topic string db storage.Database stream types.StreamProvider @@ -64,36 +64,36 @@ func NewOutputReceiptEventConsumer( // Start consuming from EDU api func (s *OutputReceiptEventConsumer) Start() error { - _, err := s.jetstream.Subscribe(s.topic, s.onMessage, s.durable) - return err + return jetstream.JetStreamConsumer( + s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, + nats.DeliverAll(), nats.ManualAck(), + ) } -func (s *OutputReceiptEventConsumer) onMessage(msg *nats.Msg) { - jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool { - var output api.OutputReceiptEvent - if err := json.Unmarshal(msg.Data, &output); err != nil { - // If the message was invalid, log it and move on to the next message in the stream - log.WithError(err).Errorf("EDU server output log: message parse failure") - sentry.CaptureException(err) - return true - } - - streamPos, err := s.db.StoreReceipt( - s.ctx, - output.RoomID, - output.Type, - output.UserID, - output.EventID, - output.Timestamp, - ) - if err != nil { - sentry.CaptureException(err) - return true - } - - s.stream.Advance(streamPos) - s.notifier.OnNewReceipt(output.RoomID, types.StreamingToken{ReceiptPosition: streamPos}) - +func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { + var output api.OutputReceiptEvent + if err := json.Unmarshal(msg.Data, &output); err != nil { + // If the message was invalid, log it and move on to the next message in the stream + log.WithError(err).Errorf("EDU server output log: message parse failure") + sentry.CaptureException(err) return true - }) + } + + streamPos, err := s.db.StoreReceipt( + s.ctx, + output.RoomID, + output.Type, + output.UserID, + output.EventID, + output.Timestamp, + ) + if err != nil { + sentry.CaptureException(err) + return true + } + + s.stream.Advance(streamPos) + s.notifier.OnNewReceipt(output.RoomID, types.StreamingToken{ReceiptPosition: streamPos}) + + return true } diff --git a/syncapi/consumers/eduserver_sendtodevice.go b/syncapi/consumers/eduserver_sendtodevice.go index 54e689fa1..b0beef063 100644 --- a/syncapi/consumers/eduserver_sendtodevice.go +++ b/syncapi/consumers/eduserver_sendtodevice.go @@ -36,7 +36,7 @@ import ( type OutputSendToDeviceEventConsumer struct { ctx context.Context jetstream nats.JetStreamContext - durable nats.SubOpt + durable string topic string db storage.Database serverName gomatrixserverlib.ServerName // our server name @@ -68,52 +68,52 @@ func NewOutputSendToDeviceEventConsumer( // Start consuming from EDU api func (s *OutputSendToDeviceEventConsumer) Start() error { - _, err := s.jetstream.Subscribe(s.topic, s.onMessage, s.durable) - return err + return jetstream.JetStreamConsumer( + s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, + nats.DeliverAll(), nats.ManualAck(), + ) } -func (s *OutputSendToDeviceEventConsumer) onMessage(msg *nats.Msg) { - jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool { - var output api.OutputSendToDeviceEvent - if err := json.Unmarshal(msg.Data, &output); err != nil { - // If the message was invalid, log it and move on to the next message in the stream - log.WithError(err).Errorf("EDU server output log: message parse failure") - sentry.CaptureException(err) - return true - } - - _, domain, err := gomatrixserverlib.SplitID('@', output.UserID) - if err != nil { - sentry.CaptureException(err) - return true - } - if domain != s.serverName { - return true - } - - util.GetLogger(context.TODO()).WithFields(log.Fields{ - "sender": output.Sender, - "user_id": output.UserID, - "device_id": output.DeviceID, - "event_type": output.Type, - }).Info("sync API received send-to-device event from EDU server") - - streamPos, err := s.db.StoreNewSendForDeviceMessage( - s.ctx, output.UserID, output.DeviceID, output.SendToDeviceEvent, - ) - if err != nil { - sentry.CaptureException(err) - log.WithError(err).Errorf("failed to store send-to-device message") - return false - } - - s.stream.Advance(streamPos) - s.notifier.OnNewSendToDevice( - output.UserID, - []string{output.DeviceID}, - types.StreamingToken{SendToDevicePosition: streamPos}, - ) - +func (s *OutputSendToDeviceEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { + var output api.OutputSendToDeviceEvent + if err := json.Unmarshal(msg.Data, &output); err != nil { + // If the message was invalid, log it and move on to the next message in the stream + log.WithError(err).Errorf("EDU server output log: message parse failure") + sentry.CaptureException(err) return true - }) + } + + _, domain, err := gomatrixserverlib.SplitID('@', output.UserID) + if err != nil { + sentry.CaptureException(err) + return true + } + if domain != s.serverName { + return true + } + + util.GetLogger(context.TODO()).WithFields(log.Fields{ + "sender": output.Sender, + "user_id": output.UserID, + "device_id": output.DeviceID, + "event_type": output.Type, + }).Info("sync API received send-to-device event from EDU server") + + streamPos, err := s.db.StoreNewSendForDeviceMessage( + s.ctx, output.UserID, output.DeviceID, output.SendToDeviceEvent, + ) + if err != nil { + sentry.CaptureException(err) + log.WithError(err).Errorf("failed to store send-to-device message") + return false + } + + s.stream.Advance(streamPos) + s.notifier.OnNewSendToDevice( + output.UserID, + []string{output.DeviceID}, + types.StreamingToken{SendToDevicePosition: streamPos}, + ) + + return true } diff --git a/syncapi/consumers/eduserver_typing.go b/syncapi/consumers/eduserver_typing.go index de2f6f950..cae5df8a8 100644 --- a/syncapi/consumers/eduserver_typing.go +++ b/syncapi/consumers/eduserver_typing.go @@ -35,7 +35,7 @@ import ( type OutputTypingEventConsumer struct { ctx context.Context jetstream nats.JetStreamContext - durable nats.SubOpt + durable string topic string eduCache *cache.EDUCache stream types.StreamProvider @@ -66,41 +66,41 @@ func NewOutputTypingEventConsumer( // Start consuming from EDU api func (s *OutputTypingEventConsumer) Start() error { - _, err := s.jetstream.Subscribe(s.topic, s.onMessage, s.durable) - return err + return jetstream.JetStreamConsumer( + s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, + nats.DeliverAll(), nats.ManualAck(), + ) } -func (s *OutputTypingEventConsumer) onMessage(msg *nats.Msg) { - jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool { - var output api.OutputTypingEvent - if err := json.Unmarshal(msg.Data, &output); err != nil { - // If the message was invalid, log it and move on to the next message in the stream - log.WithError(err).Errorf("EDU server output log: message parse failure") - sentry.CaptureException(err) - return true - } - - log.WithFields(log.Fields{ - "room_id": output.Event.RoomID, - "user_id": output.Event.UserID, - "typing": output.Event.Typing, - }).Debug("received data from EDU server") - - var typingPos types.StreamPosition - typingEvent := output.Event - if typingEvent.Typing { - typingPos = types.StreamPosition( - s.eduCache.AddTypingUser(typingEvent.UserID, typingEvent.RoomID, output.ExpireTime), - ) - } else { - typingPos = types.StreamPosition( - s.eduCache.RemoveUser(typingEvent.UserID, typingEvent.RoomID), - ) - } - - s.stream.Advance(typingPos) - s.notifier.OnNewTyping(output.Event.RoomID, types.StreamingToken{TypingPosition: typingPos}) - +func (s *OutputTypingEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { + var output api.OutputTypingEvent + if err := json.Unmarshal(msg.Data, &output); err != nil { + // If the message was invalid, log it and move on to the next message in the stream + log.WithError(err).Errorf("EDU server output log: message parse failure") + sentry.CaptureException(err) return true - }) + } + + log.WithFields(log.Fields{ + "room_id": output.Event.RoomID, + "user_id": output.Event.UserID, + "typing": output.Event.Typing, + }).Debug("received data from EDU server") + + var typingPos types.StreamPosition + typingEvent := output.Event + if typingEvent.Typing { + typingPos = types.StreamPosition( + s.eduCache.AddTypingUser(typingEvent.UserID, typingEvent.RoomID, output.ExpireTime), + ) + } else { + typingPos = types.StreamPosition( + s.eduCache.RemoveUser(typingEvent.UserID, typingEvent.RoomID), + ) + } + + s.stream.Advance(typingPos) + s.notifier.OnNewTyping(output.Event.RoomID, types.StreamingToken{TypingPosition: typingPos}) + + return true } diff --git a/syncapi/consumers/keychange.go b/syncapi/consumers/keychange.go index 97685cc04..e806f76e6 100644 --- a/syncapi/consumers/keychange.go +++ b/syncapi/consumers/keychange.go @@ -18,84 +18,81 @@ import ( "context" "encoding/json" - "github.com/Shopify/sarama" "github.com/getsentry/sentry-go" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/keyserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" + "github.com/nats-io/nats.go" "github.com/sirupsen/logrus" ) // OutputKeyChangeEventConsumer consumes events that originated in the key server. type OutputKeyChangeEventConsumer struct { - ctx context.Context - keyChangeConsumer *internal.ContinualConsumer - db storage.Database - notifier *notifier.Notifier - stream types.StreamProvider - serverName gomatrixserverlib.ServerName // our server name - rsAPI roomserverAPI.RoomserverInternalAPI - keyAPI api.KeyInternalAPI + ctx context.Context + jetstream nats.JetStreamContext + durable string + topic string + db storage.Database + notifier *notifier.Notifier + stream types.StreamProvider + serverName gomatrixserverlib.ServerName // our server name + rsAPI roomserverAPI.RoomserverInternalAPI + keyAPI api.KeyInternalAPI } // NewOutputKeyChangeEventConsumer creates a new OutputKeyChangeEventConsumer. // Call Start() to begin consuming from the key server. func NewOutputKeyChangeEventConsumer( process *process.ProcessContext, - serverName gomatrixserverlib.ServerName, + cfg *config.SyncAPI, topic string, - kafkaConsumer sarama.Consumer, + js nats.JetStreamContext, keyAPI api.KeyInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI, store storage.Database, notifier *notifier.Notifier, stream types.StreamProvider, ) *OutputKeyChangeEventConsumer { - - consumer := internal.ContinualConsumer{ - Process: process, - ComponentName: "syncapi/keychange", - Topic: topic, - Consumer: kafkaConsumer, - PartitionStore: store, - } - s := &OutputKeyChangeEventConsumer{ - ctx: process.Context(), - keyChangeConsumer: &consumer, - db: store, - serverName: serverName, - keyAPI: keyAPI, - rsAPI: rsAPI, - notifier: notifier, - stream: stream, + ctx: process.Context(), + jetstream: js, + durable: cfg.Matrix.JetStream.Durable("SyncAPIKeyChangeConsumer"), + topic: topic, + db: store, + serverName: cfg.Matrix.ServerName, + keyAPI: keyAPI, + rsAPI: rsAPI, + notifier: notifier, + stream: stream, } - consumer.ProcessMessage = s.onMessage - return s } // Start consuming from the key server func (s *OutputKeyChangeEventConsumer) Start() error { - return s.keyChangeConsumer.Start() + return jetstream.JetStreamConsumer( + s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, + nats.DeliverAll(), nats.ManualAck(), + ) } -func (s *OutputKeyChangeEventConsumer) onMessage(msg *sarama.ConsumerMessage) error { +func (s *OutputKeyChangeEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { var m api.DeviceMessage - if err := json.Unmarshal(msg.Value, &m); err != nil { + if err := json.Unmarshal(msg.Data, &m); err != nil { logrus.WithError(err).Errorf("failed to read device message from key change topic") - return nil + return true } if m.DeviceKeys == nil && m.OutputCrossSigningKeyUpdate == nil { // This probably shouldn't happen but stops us from panicking if we come // across an update that doesn't satisfy either types. - return nil + return true } switch m.Type { case api.TypeCrossSigningUpdate: @@ -107,9 +104,9 @@ func (s *OutputKeyChangeEventConsumer) onMessage(msg *sarama.ConsumerMessage) er } } -func (s *OutputKeyChangeEventConsumer) onDeviceKeyMessage(m api.DeviceMessage, deviceChangeID int64) error { +func (s *OutputKeyChangeEventConsumer) onDeviceKeyMessage(m api.DeviceMessage, deviceChangeID int64) bool { if m.DeviceKeys == nil { - return nil + return true } output := m.DeviceKeys // work out who we need to notify about the new key @@ -120,7 +117,7 @@ func (s *OutputKeyChangeEventConsumer) onDeviceKeyMessage(m api.DeviceMessage, d if err != nil { logrus.WithError(err).Error("syncapi: failed to QuerySharedUsers for key change event from key server") sentry.CaptureException(err) - return err + return true } // make sure we get our own key updates too! queryRes.UserIDsToCount[output.UserID] = 1 @@ -131,10 +128,10 @@ func (s *OutputKeyChangeEventConsumer) onDeviceKeyMessage(m api.DeviceMessage, d s.notifier.OnNewKeyChange(types.StreamingToken{DeviceListPosition: posUpdate}, userID, output.UserID) } - return nil + return true } -func (s *OutputKeyChangeEventConsumer) onCrossSigningMessage(m api.DeviceMessage, deviceChangeID int64) error { +func (s *OutputKeyChangeEventConsumer) onCrossSigningMessage(m api.DeviceMessage, deviceChangeID int64) bool { output := m.CrossSigningKeyUpdate // work out who we need to notify about the new key var queryRes roomserverAPI.QuerySharedUsersResponse @@ -144,7 +141,7 @@ func (s *OutputKeyChangeEventConsumer) onCrossSigningMessage(m api.DeviceMessage if err != nil { logrus.WithError(err).Error("syncapi: failed to QuerySharedUsers for key change event from key server") sentry.CaptureException(err) - return err + return true } // make sure we get our own key updates too! queryRes.UserIDsToCount[output.UserID] = 1 @@ -155,5 +152,5 @@ func (s *OutputKeyChangeEventConsumer) onCrossSigningMessage(m api.DeviceMessage s.notifier.OnNewKeyChange(types.StreamingToken{DeviceListPosition: posUpdate}, userID, output.UserID) } - return nil + return true } diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index e9c4abe88..7fe52b728 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -38,7 +38,7 @@ type OutputRoomEventConsumer struct { cfg *config.SyncAPI rsAPI api.RoomserverInternalAPI jetstream nats.JetStreamContext - durable nats.SubOpt + durable string topic string db storage.Database pduStream types.StreamProvider @@ -73,65 +73,61 @@ func NewOutputRoomEventConsumer( // Start consuming from room servers func (s *OutputRoomEventConsumer) Start() error { - _, err := s.jetstream.Subscribe( - s.topic, s.onMessage, s.durable, - nats.DeliverAll(), - nats.ManualAck(), + return jetstream.JetStreamConsumer( + s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, + nats.DeliverAll(), nats.ManualAck(), ) - return err } // onMessage is called when the sync server receives a new event from the room server output log. // It is not safe for this function to be called from multiple goroutines, or else the // sync stream position may race and be incorrectly calculated. -func (s *OutputRoomEventConsumer) onMessage(msg *nats.Msg) { - jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool { - // Parse out the event JSON - var err error - var output api.OutputEvent - if err = json.Unmarshal(msg.Data, &output); err != nil { - // If the message was invalid, log it and move on to the next message in the stream - log.WithError(err).Errorf("roomserver output log: message parse failure") - return true - } - - switch output.Type { - case api.OutputTypeNewRoomEvent: - // Ignore redaction events. We will add them to the database when they are - // validated (when we receive OutputTypeRedactedEvent) - event := output.NewRoomEvent.Event - if event.Type() == gomatrixserverlib.MRoomRedaction && event.StateKey() == nil { - // in the special case where the event redacts itself, just pass the message through because - // we will never see the other part of the pair - if event.Redacts() != event.EventID() { - return true - } - } - err = s.onNewRoomEvent(s.ctx, *output.NewRoomEvent) - case api.OutputTypeOldRoomEvent: - err = s.onOldRoomEvent(s.ctx, *output.OldRoomEvent) - case api.OutputTypeNewInviteEvent: - s.onNewInviteEvent(s.ctx, *output.NewInviteEvent) - case api.OutputTypeRetireInviteEvent: - s.onRetireInviteEvent(s.ctx, *output.RetireInviteEvent) - case api.OutputTypeNewPeek: - s.onNewPeek(s.ctx, *output.NewPeek) - case api.OutputTypeRetirePeek: - s.onRetirePeek(s.ctx, *output.RetirePeek) - case api.OutputTypeRedactedEvent: - err = s.onRedactEvent(s.ctx, *output.RedactedEvent) - default: - log.WithField("type", output.Type).Debug( - "roomserver output log: ignoring unknown output type", - ) - } - if err != nil { - log.WithError(err).Error("roomserver output log: failed to process event") - return false - } - +func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { + // Parse out the event JSON + var err error + var output api.OutputEvent + if err = json.Unmarshal(msg.Data, &output); err != nil { + // If the message was invalid, log it and move on to the next message in the stream + log.WithError(err).Errorf("roomserver output log: message parse failure") return true - }) + } + + switch output.Type { + case api.OutputTypeNewRoomEvent: + // Ignore redaction events. We will add them to the database when they are + // validated (when we receive OutputTypeRedactedEvent) + event := output.NewRoomEvent.Event + if event.Type() == gomatrixserverlib.MRoomRedaction && event.StateKey() == nil { + // in the special case where the event redacts itself, just pass the message through because + // we will never see the other part of the pair + if event.Redacts() != event.EventID() { + return true + } + } + err = s.onNewRoomEvent(s.ctx, *output.NewRoomEvent) + case api.OutputTypeOldRoomEvent: + err = s.onOldRoomEvent(s.ctx, *output.OldRoomEvent) + case api.OutputTypeNewInviteEvent: + s.onNewInviteEvent(s.ctx, *output.NewInviteEvent) + case api.OutputTypeRetireInviteEvent: + s.onRetireInviteEvent(s.ctx, *output.RetireInviteEvent) + case api.OutputTypeNewPeek: + s.onNewPeek(s.ctx, *output.NewPeek) + case api.OutputTypeRetirePeek: + s.onRetirePeek(s.ctx, *output.RetirePeek) + case api.OutputTypeRedactedEvent: + err = s.onRedactEvent(s.ctx, *output.RedactedEvent) + default: + log.WithField("type", output.Type).Debug( + "roomserver output log: ignoring unknown output type", + ) + } + if err != nil { + log.WithError(err).Error("roomserver output log: failed to process event") + return false + } + + return true } func (s *OutputRoomEventConsumer) onRedactEvent( diff --git a/syncapi/internal/keychange.go b/syncapi/internal/keychange.go index 41efd4a07..fa1064b70 100644 --- a/syncapi/internal/keychange.go +++ b/syncapi/internal/keychange.go @@ -18,8 +18,8 @@ import ( "context" "strings" - "github.com/Shopify/sarama" keyapi "github.com/matrix-org/dendrite/keyserver/api" + keytypes "github.com/matrix-org/dendrite/keyserver/types" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" @@ -64,8 +64,8 @@ func DeviceListCatchup( } // now also track users who we already share rooms with but who have updated their devices between the two tokens - offset := sarama.OffsetOldest - toOffset := sarama.OffsetNewest + offset := keytypes.OffsetOldest + toOffset := keytypes.OffsetNewest if to > 0 && to > from { toOffset = int64(to) } diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index 9cff4cad1..b464ad9cd 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -19,7 +19,6 @@ import ( eduAPI "github.com/matrix-org/dendrite/eduserver/api" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" @@ -27,8 +26,6 @@ import ( ) type Database interface { - internal.PartitionStorer - MaxStreamPositionForPDUs(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForReceipts(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForInvites(ctx context.Context) (types.StreamPosition, error) diff --git a/syncapi/syncapi.go b/syncapi/syncapi.go index 39bc233ae..72462459c 100644 --- a/syncapi/syncapi.go +++ b/syncapi/syncapi.go @@ -48,7 +48,7 @@ func AddPublicRoutes( federation *gomatrixserverlib.FederationClient, cfg *config.SyncAPI, ) { - js, consumer, _ := jetstream.Prepare(&cfg.Matrix.JetStream) + js := jetstream.Prepare(&cfg.Matrix.JetStream) syncDB, err := storage.NewSyncServerDatasource(&cfg.Database) if err != nil { @@ -65,8 +65,8 @@ func AddPublicRoutes( requestPool := sync.NewRequestPool(syncDB, cfg, userAPI, keyAPI, rsAPI, streams, notifier) keyChangeConsumer := consumers.NewOutputKeyChangeEventConsumer( - process, cfg.Matrix.ServerName, cfg.Matrix.JetStream.TopicFor(jetstream.OutputKeyChangeEvent), - consumer, keyAPI, rsAPI, syncDB, notifier, + process, cfg, cfg.Matrix.JetStream.TopicFor(jetstream.OutputKeyChangeEvent), + js, keyAPI, rsAPI, syncDB, notifier, streams.DeviceListStreamProvider, ) if err = keyChangeConsumer.Start(); err != nil {