Mirror of https://github.com/matrix-org/dendrite.git
Synced 2025-12-29 09:43:10 -06:00

Commit 0ab352b148: Merge branch 'master' into logintoken

.gitignore (vendored): 1 line changed

@@ -23,6 +23,7 @@
 /vendor/bin
 /docker/build
 /logs
+/jetstream

 # Architecture specific extensions/prefixes
 *.[568vq]

CHANGES.md: 27 lines changed

@@ -1,5 +1,32 @@
 # Changelog

+## Dendrite 0.6.2 (2022-02-04)
+
+### Fixes
+
+* Resolves an issue where the key change consumer in the keyserver could consume extreme amounts of CPU
+
+## Dendrite 0.6.1 (2022-02-04)
+
+### Features
+
+* Roomserver inputs now take place with full transactional isolation in PostgreSQL deployments
+* Pull consumers are now used instead of push consumers when retrieving messages from NATS to better guarantee ordering and to reduce redelivery of duplicate messages
+* Further logging tweaks, particularly when joining rooms
+* Improved calculation of servers in the room, when checking for missing auth/prev events or state
+* Dendrite will now skip dead servers more quickly when federating by reducing the TCP dial timeout
+* The key change consumers have now been converted to use native NATS code rather than a wrapper
+* Go 1.16 is now the minimum supported version for Dendrite
+
+### Fixes
+
+* Local clients should now be notified correctly of invites
+* The roomserver input API now has more time to process events, particularly when fetching missing events or state, which should fix a number of errors from expired contexts
+* Fixed a panic that could happen due to a closed channel in the roomserver input API
+* Logging in with uppercase usernames from old installations is now supported again (contributed by [hoernschen](https://github.com/hoernschen))
+* Federated room joins now have more time to complete and should not fail due to expired contexts
+* Events that were sent to the roomserver along with a complete state snapshot are now persisted with the correct state, even if they were rejected or soft-failed
+
 ## Dendrite 0.6.0 (2022-01-28)

 ### Features

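The 0.6.1 notes above mention that Dendrite now uses pull consumers rather than push consumers when reading from NATS JetStream, and the hunks below repeatedly replace direct js.Subscribe(...) calls with a shared jetstream.JetStreamConsumer(ctx, js, topic, durable, handler, nats.DeliverAll(), nats.ManualAck()) helper whose handler returns a bool. The following is only a rough sketch of what such a pull-based consume loop can look like using the public nats.go client; the subject and durable names are invented for illustration, and this is not Dendrite's actual helper implementation.

package main

import (
    "context"
    "errors"
    "log"
    "time"

    "github.com/nats-io/nats.go"
)

// consume sketches a pull-based JetStream consumer: fetch a small batch,
// hand each message to the handler, and only ack when the handler reports
// success so that unprocessed messages are redelivered.
func consume(ctx context.Context, js nats.JetStreamContext, subject, durable string,
    handler func(ctx context.Context, msg *nats.Msg) bool) error {
    sub, err := js.PullSubscribe(subject, durable, nats.DeliverAll())
    if err != nil {
        return err
    }
    for ctx.Err() == nil {
        msgs, err := sub.Fetch(1, nats.MaxWait(5*time.Second))
        if err != nil {
            if errors.Is(err, nats.ErrTimeout) {
                continue // nothing waiting on the stream right now
            }
            return err
        }
        for _, msg := range msgs {
            if handler(ctx, msg) {
                _ = msg.Ack() // handled, do not redeliver
            } else {
                _ = msg.Nak() // ask JetStream to redeliver later
            }
        }
    }
    return ctx.Err()
}

func main() {
    nc, err := nats.Connect(nats.DefaultURL)
    if err != nil {
        log.Fatal(err)
    }
    js, err := nc.JetStream()
    if err != nil {
        log.Fatal(err)
    }
    // "OutputRoomEvent" and "ExampleDurable" are placeholder names.
    err = consume(context.Background(), js, "OutputRoomEvent", "ExampleDurable",
        func(ctx context.Context, msg *nats.Msg) bool {
            log.Printf("received %d bytes", len(msg.Data))
            return true
        })
    if err != nil {
        log.Fatal(err)
    }
}
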
@@ -58,7 +58,7 @@ func NewInternalAPI(
         },
     },
 }
-js, _, _ := jetstream.Prepare(&base.Cfg.Global.JetStream)
+js := jetstream.Prepare(&base.Cfg.Global.JetStream)

 // Create a connection to the appservice postgres DB
 appserviceDB, err := storage.NewDatabase(&base.Cfg.AppServiceAPI.Database)

@@ -34,7 +34,7 @@ import (
 type OutputRoomEventConsumer struct {
     ctx       context.Context
     jetstream nats.JetStreamContext
-    durable   nats.SubOpt
+    durable   string
     topic     string
     asDB      storage.Database
     rsAPI     api.RoomserverInternalAPI

@@ -66,37 +66,37 @@ func NewOutputRoomEventConsumer(

 // Start consuming from room servers
 func (s *OutputRoomEventConsumer) Start() error {
-    _, err := s.jetstream.Subscribe(s.topic, s.onMessage, s.durable)
-    return err
+    return jetstream.JetStreamConsumer(
+        s.ctx, s.jetstream, s.topic, s.durable, s.onMessage,
+        nats.DeliverAll(), nats.ManualAck(),
+    )
 }

 // onMessage is called when the appservice component receives a new event from
 // the room server output log.
-func (s *OutputRoomEventConsumer) onMessage(msg *nats.Msg) {
-    jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool {
+func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
     // Parse out the event JSON
     var output api.OutputEvent
     if err := json.Unmarshal(msg.Data, &output); err != nil {
         // If the message was invalid, log it and move on to the next message in the stream
         log.WithError(err).Errorf("roomserver output log: message parse failure")
         return true
     }

     if output.Type != api.OutputTypeNewRoomEvent {
         return true
     }

     events := []*gomatrixserverlib.HeaderedEvent{output.NewRoomEvent.Event}
     events = append(events, output.NewRoomEvent.AddStateEvents...)

     // Send event to any relevant application services
     if err := s.filterRoomserverEvents(context.TODO(), events); err != nil {
         log.WithError(err).Errorf("roomserver output log: filter error")
         return true
     }

     return true
-    })
 }

 // filterRoomserverEvents takes in events and decides whether any of them need

@@ -281,7 +281,7 @@ func (m *DendriteMonolith) Start() {
     cfg.Global.ServerName = gomatrixserverlib.ServerName(hex.EncodeToString(pk))
     cfg.Global.PrivateKey = sk
     cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID)
-    cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("file:%s/%s", m.StorageDirectory, prefix))
+    cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/%s", m.StorageDirectory, prefix))
     cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-account.db", m.StorageDirectory, prefix))
     cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-device.db", m.StorageDirectory, prefix))
     cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-mediaapi.db", m.CacheDirectory, prefix))

@@ -86,7 +86,7 @@ func (m *DendriteMonolith) Start() {
     cfg.Global.ServerName = gomatrixserverlib.ServerName(ygg.DerivedServerName())
     cfg.Global.PrivateKey = ygg.PrivateKey()
     cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID)
-    cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("file:%s/", m.StorageDirectory))
+    cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", m.StorageDirectory))
     cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-account.db", m.StorageDirectory))
     cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-device.db", m.StorageDirectory))
     cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-mediaapi.db", m.StorageDirectory))

@@ -12,10 +12,14 @@ COPY . .
 RUN go build ./cmd/dendrite-monolith-server
 RUN go build ./cmd/generate-keys
 RUN go build ./cmd/generate-config
 RUN ./generate-config --ci > dendrite.yaml
-RUN ./generate-keys --private-key matrix_key.pem --tls-cert server.crt --tls-key server.key
+RUN ./generate-keys --private-key matrix_key.pem

 ENV SERVER_NAME=localhost
 EXPOSE 8008 8448

-CMD sed -i "s/server_name: localhost/server_name: ${SERVER_NAME}/g" dendrite.yaml && ./dendrite-monolith-server --tls-cert server.crt --tls-key server.key --config dendrite.yaml
+# At runtime, generate TLS cert based on the CA now mounted at /ca
+# At runtime, replace the SERVER_NAME with what we are told
+CMD ./generate-keys --server $SERVER_NAME --tls-cert server.crt --tls-key server.key --tls-authority-cert /ca/ca.crt --tls-authority-key /ca/ca.key && \
+    ./generate-config -server $SERVER_NAME --ci > dendrite.yaml && \
+    cp /ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates && \
+    ./dendrite-monolith-server --tls-cert server.crt --tls-key server.key --config dendrite.yaml

@@ -16,6 +16,7 @@ package auth

 import (
     "context"
+    "database/sql"
     "net/http"
     "strings"

@@ -59,8 +60,9 @@ func (t *LoginTypePassword) LoginFromJSON(ctx context.Context, reqBytes []byte)
     return login, func(context.Context, *util.JSONResponse) {}, nil
 }

-func (t *LoginTypePassword) Login(ctx context.Context, r *PasswordRequest) (*Login, *util.JSONResponse) {
+func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, *util.JSONResponse) {
+    r := req.(*PasswordRequest)
     username := strings.ToLower(r.Username())
     if username == "" {
         return nil, &util.JSONResponse{
             Code: http.StatusUnauthorized,

@@ -74,8 +76,15 @@ func (t *LoginTypePassword) Login(ctx context.Context, r *PasswordRequest) (*Log
             JSON: jsonerror.InvalidUsername(err.Error()),
         }
     }
-    _, err = t.GetAccountByPassword(ctx, localpart, r.Password)
+    // Squash username to all lowercase letters
+    _, err = t.GetAccountByPassword(ctx, strings.ToLower(localpart), r.Password)
     if err != nil {
+        if err == sql.ErrNoRows {
+            _, err = t.GetAccountByPassword(ctx, localpart, r.Password)
+            if err == nil {
+                return &r.Login, nil
+            }
+        }
         // Technically we could tell them if the user does not exist by checking if err == sql.ErrNoRows
         // but that would leak the existence of the user.
         return nil, &util.JSONResponse{

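The login change above (together with the 0.6.1 note about uppercase usernames from old installations) boils down to: look the account up under the lowercased localpart first, and only if that row does not exist retry with the original casing, without ever telling the client which lookup failed. A condensed, self-contained sketch of that flow follows; getAccountByPassword is a hypothetical stand-in for the real user API storage call, not Dendrite's actual function.

package main

import (
    "context"
    "database/sql"
    "errors"
    "fmt"
    "strings"
)

// getAccountByPassword is a hypothetical stand-in for the storage lookup.
// Here it pretends a single legacy mixed-case account exists.
func getAccountByPassword(ctx context.Context, localpart, password string) (string, error) {
    if localpart == "Alice" && password == "s3cret" {
        return "@Alice:example.com", nil
    }
    return "", sql.ErrNoRows
}

// authenticate mirrors the fallback introduced above: lowercase first,
// then the original casing for accounts created before usernames were squashed.
func authenticate(ctx context.Context, localpart, password string) (string, error) {
    userID, err := getAccountByPassword(ctx, strings.ToLower(localpart), password)
    if err == nil {
        return userID, nil
    }
    if err == sql.ErrNoRows {
        if userID, err := getAccountByPassword(ctx, localpart, password); err == nil {
            return userID, nil
        }
    }
    // Deliberately vague so we do not leak whether the account exists.
    return "", errors.New("username or password were incorrect")
}

func main() {
    userID, err := authenticate(context.Background(), "Alice", "s3cret")
    fmt.Println(userID, err)
}
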
@@ -49,7 +49,7 @@ func AddPublicRoutes(
     extRoomsProvider api.ExtraPublicRoomsProvider,
     mscCfg *config.MSCs,
 ) {
-    js, _, _ := jetstream.Prepare(&cfg.Matrix.JetStream)
+    js := jetstream.Prepare(&cfg.Matrix.JetStream)

     syncProducer := &producers.SyncAPIProducer{
         JetStream: js,

@@ -83,7 +83,7 @@ func main() {
     if *defaultsForCI {
         cfg.AppServiceAPI.DisableTLSValidation = true
         cfg.ClientAPI.RateLimiting.Enabled = false
-        cfg.FederationAPI.DisableTLSValidation = true
+        cfg.FederationAPI.DisableTLSValidation = false
         // don't hit matrix.org when running tests!!!
         cfg.FederationAPI.KeyPerspectives = config.KeyPerspectives{}
         cfg.MSCs.MSCs = []string{"msc2836", "msc2946", "msc2444", "msc2753"}

@@ -32,9 +32,12 @@ Arguments:
 `

 var (
     tlsCertFile    = flag.String("tls-cert", "", "An X509 certificate file to generate for use for TLS")
     tlsKeyFile     = flag.String("tls-key", "", "An RSA private key file to generate for use for TLS")
     privateKeyFile = flag.String("private-key", "", "An Ed25519 private key to generate for use for object signing")
+    authorityCertFile = flag.String("tls-authority-cert", "", "Optional: Create TLS certificate/keys based on this CA authority. Useful for integration testing.")
+    authorityKeyFile  = flag.String("tls-authority-key", "", "Optional: Create TLS certificate/keys based on this CA authority. Useful for integration testing.")
+    serverName        = flag.String("server", "", "Optional: Create TLS certificate/keys with this domain name set. Useful for integration testing.")
 )

 func main() {

@@ -54,8 +57,15 @@ func main() {
     if *tlsCertFile == "" || *tlsKeyFile == "" {
         log.Fatal("Zero or both of --tls-key and --tls-cert must be supplied")
     }
-    if err := test.NewTLSKey(*tlsKeyFile, *tlsCertFile); err != nil {
-        panic(err)
+    if *authorityCertFile == "" && *authorityKeyFile == "" {
+        if err := test.NewTLSKey(*tlsKeyFile, *tlsCertFile); err != nil {
+            panic(err)
+        }
+    } else {
+        // generate the TLS cert/key based on the authority given.
+        if err := test.NewTLSKeyWithAuthority(*serverName, *tlsKeyFile, *tlsCertFile, *authorityKeyFile, *authorityCertFile); err != nil {
+            panic(err)
+        }
     }
     fmt.Printf("Created TLS cert file: %s\n", *tlsCertFile)
     fmt.Printf("Created TLS key file: %s\n", *tlsKeyFile)

@@ -42,7 +42,7 @@ func NewInternalAPI(
 ) api.EDUServerInputAPI {
     cfg := &base.Cfg.EDUServer

-    js, _, _ := jetstream.Prepare(&cfg.Matrix.JetStream)
+    js := jetstream.Prepare(&cfg.Matrix.JetStream)

     return &input.EDUServerInputAPI{
         Cache: eduCache,

@@ -34,7 +34,7 @@ import (
 type OutputEDUConsumer struct {
     ctx        context.Context
     jetstream  nats.JetStreamContext
-    durable    nats.SubOpt
+    durable    string
     db         storage.Database
     queues     *queue.OutgoingQueues
     ServerName gomatrixserverlib.ServerName

@@ -66,13 +66,22 @@ func NewOutputEDUConsumer(

 // Start consuming from EDU servers
 func (t *OutputEDUConsumer) Start() error {
-    if _, err := t.jetstream.Subscribe(t.typingTopic, t.onTypingEvent, t.durable); err != nil {
+    if err := jetstream.JetStreamConsumer(
+        t.ctx, t.jetstream, t.typingTopic, t.durable, t.onTypingEvent,
+        nats.DeliverAll(), nats.ManualAck(),
+    ); err != nil {
         return err
     }
-    if _, err := t.jetstream.Subscribe(t.sendToDeviceTopic, t.onSendToDeviceEvent, t.durable); err != nil {
+    if err := jetstream.JetStreamConsumer(
+        t.ctx, t.jetstream, t.sendToDeviceTopic, t.durable, t.onSendToDeviceEvent,
+        nats.DeliverAll(), nats.ManualAck(),
+    ); err != nil {
         return err
     }
-    if _, err := t.jetstream.Subscribe(t.receiptTopic, t.onReceiptEvent, t.durable); err != nil {
+    if err := jetstream.JetStreamConsumer(
+        t.ctx, t.jetstream, t.receiptTopic, t.durable, t.onReceiptEvent,
+        nats.DeliverAll(), nats.ManualAck(),
+    ); err != nil {
         return err
     }
     return nil

@@ -80,175 +89,169 @@ func (t *OutputEDUConsumer) Start() error {

 // onSendToDeviceEvent is called in response to a message received on the
 // send-to-device events topic from the EDU server.
-func (t *OutputEDUConsumer) onSendToDeviceEvent(msg *nats.Msg) {
-    // Extract the send-to-device event from msg.
-    jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool {
+func (t *OutputEDUConsumer) onSendToDeviceEvent(ctx context.Context, msg *nats.Msg) bool {
+    // Extract the send-to-device event from msg.
     var ote api.OutputSendToDeviceEvent
     if err := json.Unmarshal(msg.Data, &ote); err != nil {
         log.WithError(err).Errorf("eduserver output log: message parse failed (expected send-to-device)")
         return true
     }

     // only send send-to-device events which originated from us
     _, originServerName, err := gomatrixserverlib.SplitID('@', ote.Sender)
     if err != nil {
         log.WithError(err).WithField("user_id", ote.Sender).Error("Failed to extract domain from send-to-device sender")
         return true
     }
     if originServerName != t.ServerName {
         log.WithField("other_server", originServerName).Info("Suppressing send-to-device: originated elsewhere")
         return true
     }

     _, destServerName, err := gomatrixserverlib.SplitID('@', ote.UserID)
     if err != nil {
         log.WithError(err).WithField("user_id", ote.UserID).Error("Failed to extract domain from send-to-device destination")
         return true
     }

     // Pack the EDU and marshal it
     edu := &gomatrixserverlib.EDU{
         Type:   gomatrixserverlib.MDirectToDevice,
         Origin: string(t.ServerName),
     }
     tdm := gomatrixserverlib.ToDeviceMessage{
         Sender:    ote.Sender,
         Type:      ote.Type,
         MessageID: util.RandomString(32),
         Messages: map[string]map[string]json.RawMessage{
             ote.UserID: {
                 ote.DeviceID: ote.Content,
             },
         },
     }
     if edu.Content, err = json.Marshal(tdm); err != nil {
         log.WithError(err).Error("failed to marshal EDU JSON")
         return true
     }

     log.Infof("Sending send-to-device message into %q destination queue", destServerName)
     if err := t.queues.SendEDU(edu, t.ServerName, []gomatrixserverlib.ServerName{destServerName}); err != nil {
         log.WithError(err).Error("failed to send EDU")
         return false
     }

     return true
-    })
 }

 // onTypingEvent is called in response to a message received on the typing
 // events topic from the EDU server.
-func (t *OutputEDUConsumer) onTypingEvent(msg *nats.Msg) {
-    jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool {
+func (t *OutputEDUConsumer) onTypingEvent(ctx context.Context, msg *nats.Msg) bool {
     // Extract the typing event from msg.
     var ote api.OutputTypingEvent
     if err := json.Unmarshal(msg.Data, &ote); err != nil {
         // Skip this msg but continue processing messages.
         log.WithError(err).Errorf("eduserver output log: message parse failed (expected typing)")
         _ = msg.Ack()
         return true
     }

     // only send typing events which originated from us
     _, typingServerName, err := gomatrixserverlib.SplitID('@', ote.Event.UserID)
     if err != nil {
         log.WithError(err).WithField("user_id", ote.Event.UserID).Error("Failed to extract domain from typing sender")
         _ = msg.Ack()
         return true
     }
     if typingServerName != t.ServerName {
         return true
     }

-    joined, err := t.db.GetJoinedHosts(t.ctx, ote.Event.RoomID)
+    joined, err := t.db.GetJoinedHosts(ctx, ote.Event.RoomID)
     if err != nil {
         log.WithError(err).WithField("room_id", ote.Event.RoomID).Error("failed to get joined hosts for room")
         return false
     }

     names := make([]gomatrixserverlib.ServerName, len(joined))
     for i := range joined {
         names[i] = joined[i].ServerName
     }

     edu := &gomatrixserverlib.EDU{Type: ote.Event.Type}
     if edu.Content, err = json.Marshal(map[string]interface{}{
         "room_id": ote.Event.RoomID,
         "user_id": ote.Event.UserID,
         "typing":  ote.Event.Typing,
     }); err != nil {
         log.WithError(err).Error("failed to marshal EDU JSON")
         return true
     }

     if err := t.queues.SendEDU(edu, t.ServerName, names); err != nil {
         log.WithError(err).Error("failed to send EDU")
         return false
     }

     return true
-    })
 }

 // onReceiptEvent is called in response to a message received on the receipt
 // events topic from the EDU server.
-func (t *OutputEDUConsumer) onReceiptEvent(msg *nats.Msg) {
-    jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool {
+func (t *OutputEDUConsumer) onReceiptEvent(ctx context.Context, msg *nats.Msg) bool {
     // Extract the typing event from msg.
     var receipt api.OutputReceiptEvent
     if err := json.Unmarshal(msg.Data, &receipt); err != nil {
         // Skip this msg but continue processing messages.
         log.WithError(err).Errorf("eduserver output log: message parse failed (expected receipt)")
         return true
     }

     // only send receipt events which originated from us
     _, receiptServerName, err := gomatrixserverlib.SplitID('@', receipt.UserID)
     if err != nil {
         log.WithError(err).WithField("user_id", receipt.UserID).Error("failed to extract domain from receipt sender")
         return true
     }
     if receiptServerName != t.ServerName {
         return true
     }

-    joined, err := t.db.GetJoinedHosts(t.ctx, receipt.RoomID)
+    joined, err := t.db.GetJoinedHosts(ctx, receipt.RoomID)
     if err != nil {
         log.WithError(err).WithField("room_id", receipt.RoomID).Error("failed to get joined hosts for room")
         return false
     }

     names := make([]gomatrixserverlib.ServerName, len(joined))
     for i := range joined {
         names[i] = joined[i].ServerName
     }

     content := map[string]api.FederationReceiptMRead{}
     content[receipt.RoomID] = api.FederationReceiptMRead{
         User: map[string]api.FederationReceiptData{
             receipt.UserID: {
                 Data: api.ReceiptTS{
                     TS: receipt.Timestamp,
                 },
                 EventIDs: []string{receipt.EventID},
             },
         },
     }

     edu := &gomatrixserverlib.EDU{
         Type:   gomatrixserverlib.MReceipt,
         Origin: string(t.ServerName),
     }
     if edu.Content, err = json.Marshal(content); err != nil {
         log.WithError(err).Error("failed to marshal EDU JSON")
         return true
     }

     if err := t.queues.SendEDU(edu, t.ServerName, names); err != nil {
         log.WithError(err).Error("failed to send EDU")
         return false
     }

     return true
 }

@@ -17,80 +17,73 @@ package consumers
 import (
     "context"
     "encoding/json"
-    "fmt"

-    "github.com/Shopify/sarama"
     eduserverAPI "github.com/matrix-org/dendrite/eduserver/api"
     "github.com/matrix-org/dendrite/federationapi/queue"
     "github.com/matrix-org/dendrite/federationapi/storage"
-    "github.com/matrix-org/dendrite/internal"
     "github.com/matrix-org/dendrite/keyserver/api"
     roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
     "github.com/matrix-org/dendrite/setup/config"
     "github.com/matrix-org/dendrite/setup/jetstream"
     "github.com/matrix-org/dendrite/setup/process"
     "github.com/matrix-org/gomatrixserverlib"
     "github.com/nats-io/nats.go"
     "github.com/sirupsen/logrus"
 )

 // KeyChangeConsumer consumes events that originate in key server.
 type KeyChangeConsumer struct {
     ctx        context.Context
-    consumer   *internal.ContinualConsumer
+    jetstream  nats.JetStreamContext
+    durable    string
     db         storage.Database
     queues     *queue.OutgoingQueues
     serverName gomatrixserverlib.ServerName
     rsAPI      roomserverAPI.RoomserverInternalAPI
+    topic      string
 }

 // NewKeyChangeConsumer creates a new KeyChangeConsumer. Call Start() to begin consuming from key servers.
 func NewKeyChangeConsumer(
     process *process.ProcessContext,
     cfg *config.KeyServer,
-    kafkaConsumer sarama.Consumer,
+    js nats.JetStreamContext,
     queues *queue.OutgoingQueues,
     store storage.Database,
     rsAPI roomserverAPI.RoomserverInternalAPI,
 ) *KeyChangeConsumer {
-    c := &KeyChangeConsumer{
-        ctx: process.Context(),
-        consumer: &internal.ContinualConsumer{
-            Process:        process,
-            ComponentName:  "federationapi/keychange",
-            Topic:          string(cfg.Matrix.JetStream.TopicFor(jetstream.OutputKeyChangeEvent)),
-            Consumer:       kafkaConsumer,
-            PartitionStore: store,
-        },
+    return &KeyChangeConsumer{
+        ctx:        process.Context(),
+        jetstream:  js,
+        durable:    cfg.Matrix.JetStream.TopicFor("FederationAPIKeyChangeConsumer"),
+        topic:      cfg.Matrix.JetStream.TopicFor(jetstream.OutputKeyChangeEvent),
         queues:     queues,
         db:         store,
         serverName: cfg.Matrix.ServerName,
         rsAPI:      rsAPI,
     }
-    c.consumer.ProcessMessage = c.onMessage
-
-    return c
 }

 // Start consuming from key servers
 func (t *KeyChangeConsumer) Start() error {
-    if err := t.consumer.Start(); err != nil {
-        return fmt.Errorf("t.consumer.Start: %w", err)
-    }
-    return nil
+    return jetstream.JetStreamConsumer(
+        t.ctx, t.jetstream, t.topic, t.durable, t.onMessage,
+        nats.DeliverAll(), nats.ManualAck(),
+    )
 }

 // onMessage is called in response to a message received on the
 // key change events topic from the key server.
-func (t *KeyChangeConsumer) onMessage(msg *sarama.ConsumerMessage) error {
+func (t *KeyChangeConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
     var m api.DeviceMessage
-    if err := json.Unmarshal(msg.Value, &m); err != nil {
+    if err := json.Unmarshal(msg.Data, &m); err != nil {
         logrus.WithError(err).Errorf("failed to read device message from key change topic")
-        return nil
+        return true
     }
     if m.DeviceKeys == nil && m.OutputCrossSigningKeyUpdate == nil {
         // This probably shouldn't happen but stops us from panicking if we come
         // across an update that doesn't satisfy either types.
-        return nil
+        return true
     }
     switch m.Type {
     case api.TypeCrossSigningUpdate:

@@ -102,9 +95,9 @@ func (t *KeyChangeConsumer) onMessage(msg *sarama.ConsumerMessage) error {
     }
 }

-func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) error {
+func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) bool {
     if m.DeviceKeys == nil {
-        return nil
+        return true
     }
     logger := logrus.WithField("user_id", m.UserID)

@@ -112,10 +105,10 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) error {
     _, originServerName, err := gomatrixserverlib.SplitID('@', m.UserID)
     if err != nil {
         logger.WithError(err).Error("Failed to extract domain from key change event")
-        return nil
+        return true
     }
     if originServerName != t.serverName {
-        return nil
+        return true
     }

     var queryRes roomserverAPI.QueryRoomsForUserResponse

@@ -125,13 +118,13 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) error {
     }, &queryRes)
     if err != nil {
         logger.WithError(err).Error("failed to calculate joined rooms for user")
-        return nil
+        return true
     }
     // send this key change to all servers who share rooms with this user.
     destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true)
     if err != nil {
         logger.WithError(err).Error("failed to calculate joined hosts for rooms user is in")
-        return nil
+        return true
     }

     // Pack the EDU and marshal it

@@ -149,24 +142,26 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) error {
         Keys: m.KeyJSON,
     }
     if edu.Content, err = json.Marshal(event); err != nil {
-        return err
+        logger.WithError(err).Error("failed to marshal EDU JSON")
+        return true
     }

-    logrus.Infof("Sending device list update message to %q", destinations)
-    return t.queues.SendEDU(edu, t.serverName, destinations)
+    logger.Infof("Sending device list update message to %q", destinations)
+    err = t.queues.SendEDU(edu, t.serverName, destinations)
+    return err == nil
 }

-func (t *KeyChangeConsumer) onCrossSigningMessage(m api.DeviceMessage) error {
+func (t *KeyChangeConsumer) onCrossSigningMessage(m api.DeviceMessage) bool {
     output := m.CrossSigningKeyUpdate
     _, host, err := gomatrixserverlib.SplitID('@', output.UserID)
     if err != nil {
         logrus.WithError(err).Errorf("fedsender key change consumer: user ID parse failure")
-        return nil
+        return true
     }
     if host != gomatrixserverlib.ServerName(t.serverName) {
         // Ignore any messages that didn't originate locally, otherwise we'll
         // end up parroting information we received from other servers.
-        return nil
+        return true
     }
     logger := logrus.WithField("user_id", output.UserID)

@@ -177,13 +172,13 @@ func (t *KeyChangeConsumer) onCrossSigningMessage(m api.DeviceMessage) error {
     }, &queryRes)
     if err != nil {
         logger.WithError(err).Error("fedsender key change consumer: failed to calculate joined rooms for user")
-        return nil
+        return true
     }
     // send this key change to all servers who share rooms with this user.
     destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true)
     if err != nil {
         logger.WithError(err).Error("fedsender key change consumer: failed to calculate joined hosts for rooms user is in")
-        return nil
+        return true
     }

     // Pack the EDU and marshal it

@@ -193,11 +188,12 @@ func (t *KeyChangeConsumer) onCrossSigningMessage(m api.DeviceMessage) error {
     }
     if edu.Content, err = json.Marshal(output); err != nil {
         logger.WithError(err).Error("fedsender key change consumer: failed to marshal output, dropping")
-        return nil
+        return true
     }

     logger.Infof("Sending cross-signing update message to %q", destinations)
-    return t.queues.SendEDU(edu, t.serverName, destinations)
+    err = t.queues.SendEDU(edu, t.serverName, destinations)
+    return err == nil
 }

 func prevID(streamID int) []int {

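One detail worth noting in the key change consumer above, and in the other rewritten consumers: handlers no longer return an error, they return a bool. Returning true acknowledges the message even when it is unusable (malformed JSON, wrong origin server), because retrying it could never succeed; returning false leaves the message un-acked so JetStream redelivers it, which is why transient failures such as SendEDU now end with err = ...; return err == nil. A small illustration of that convention, using a hypothetical payload type and downstream call rather than Dendrite's real ones:

package main

import (
    "context"
    "encoding/json"
    "log"

    "github.com/nats-io/nats.go"
)

// deviceMessage is a stand-in for the real api.DeviceMessage payload.
type deviceMessage struct {
    UserID string `json:"user_id"`
}

// sendToFederation is a hypothetical downstream call that can fail transiently.
func sendToFederation(ctx context.Context, m deviceMessage) error { return nil }

// onMessage follows the convention used above: return true when the message is
// handled (or is unrecoverable and should be dropped), return false only for
// transient failures so the un-acked message is redelivered later.
func onMessage(ctx context.Context, msg *nats.Msg) bool {
    var m deviceMessage
    if err := json.Unmarshal(msg.Data, &m); err != nil {
        log.Printf("bad payload, dropping: %v", err)
        return true // a malformed message will never parse; ack and move on
    }
    if err := sendToFederation(ctx, m); err != nil {
        log.Printf("send failed, will retry: %v", err)
        return false // leave un-acked so JetStream redelivers
    }
    return true
}

func main() {
    msg := &nats.Msg{Data: []byte(`{"user_id":"@alice:example.com"}`)}
    log.Printf("acknowledged: %v", onMessage(context.Background(), msg))
}
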
@@ -37,7 +37,7 @@ type OutputRoomEventConsumer struct {
     cfg       *config.FederationAPI
     rsAPI     api.RoomserverInternalAPI
     jetstream nats.JetStreamContext
-    durable   nats.SubOpt
+    durable   string
     db        storage.Database
     queues    *queue.OutgoingQueues
     topic     string

@@ -66,74 +66,75 @@ func NewOutputRoomEventConsumer(

 // Start consuming from room servers
 func (s *OutputRoomEventConsumer) Start() error {
-    _, err := s.jetstream.Subscribe(
-        s.topic, s.onMessage, s.durable,
-        nats.DeliverAll(),
-        nats.ManualAck(),
+    return jetstream.JetStreamConsumer(
+        s.ctx, s.jetstream, s.topic, s.durable, s.onMessage,
+        nats.DeliverAll(), nats.ManualAck(),
     )
-    return err
 }

 // onMessage is called when the federation server receives a new event from the room server output log.
 // It is unsafe to call this with messages for the same room in multiple gorountines
 // because updates it will likely fail with a types.EventIDMismatchError when it
 // realises that it cannot update the room state using the deltas.
-func (s *OutputRoomEventConsumer) onMessage(msg *nats.Msg) {
-    jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool {
+func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
     // Parse out the event JSON
     var output api.OutputEvent
     if err := json.Unmarshal(msg.Data, &output); err != nil {
         // If the message was invalid, log it and move on to the next message in the stream
         log.WithError(err).Errorf("roomserver output log: message parse failure")
         return true
     }

     switch output.Type {
     case api.OutputTypeNewRoomEvent:
         ev := output.NewRoomEvent.Event

         if output.NewRoomEvent.RewritesState {
             if err := s.db.PurgeRoomState(s.ctx, ev.RoomID()); err != nil {
                 log.WithError(err).Errorf("roomserver output log: purge room state failure")
                 return false
             }
         }

         if err := s.processMessage(*output.NewRoomEvent); err != nil {
             switch err.(type) {
             case *queue.ErrorFederationDisabled:
                 log.WithField("error", output.Type).Info(
                     err.Error(),
                 )
             default:
                 // panic rather than continue with an inconsistent database
                 log.WithFields(log.Fields{
                     "event_id":   ev.EventID(),
                     "event":      string(ev.JSON()),
                     "add":        output.NewRoomEvent.AddsStateEventIDs,
                     "del":        output.NewRoomEvent.RemovesStateEventIDs,
                     log.ErrorKey: err,
                 }).Panicf("roomserver output log: write room event failure")
             }
         }

+    case api.OutputTypeNewInviteEvent:
+        log.WithField("type", output.Type).Debug(
+            "received new invite, send device keys",
+        )
+
     case api.OutputTypeNewInboundPeek:
         if err := s.processInboundPeek(*output.NewInboundPeek); err != nil {
             log.WithFields(log.Fields{
                 "event":      output.NewInboundPeek,
                 log.ErrorKey: err,
             }).Panicf("roomserver output log: remote peek event failure")
+            return false
         }

     default:
         log.WithField("type", output.Type).Debug(
             "roomserver output log: ignoring unknown output type",
         )
     }

     return true
-    })
 }

 // processInboundPeek starts tracking a new federated inbound peek (replacing the existing one if any)

@@ -92,7 +92,7 @@ func NewInternalAPI(
         FailuresUntilBlacklist: cfg.FederationMaxRetries,
     }

-    js, consumer, _ := jetstream.Prepare(&cfg.Matrix.JetStream)
+    js := jetstream.Prepare(&cfg.Matrix.JetStream)

     queues := queue.NewOutgoingQueues(
         federationDB, base.ProcessContext,

@@ -120,7 +120,7 @@ func NewInternalAPI(
         logrus.WithError(err).Panic("failed to start typing server consumer")
     }
     keyConsumer := consumers.NewKeyChangeConsumer(
-        base.ProcessContext, &base.Cfg.KeyServer, consumer, queues, federationDB, rsAPI,
+        base.ProcessContext, &base.Cfg.KeyServer, js, queues, federationDB, rsAPI,
     )
     if err := keyConsumer.Start(); err != nil {
         logrus.WithError(err).Panic("failed to start key server consumer")

@@ -196,29 +196,22 @@ func (r *FederationInternalAPI) performJoinUsingServer(
         return fmt.Errorf("respMakeJoin.JoinEvent.Build: %w", err)
     }

-    // No longer reuse the request context from this point forward.
-    // We don't want the client timing out to interrupt the join.
-    var cancel context.CancelFunc
-    ctx, cancel = context.WithCancel(context.Background())
-
     // Try to perform a send_join using the newly built event.
     respSendJoin, err := r.federation.SendJoin(
-        ctx,
+        context.Background(),
         serverName,
         event,
         respMakeJoin.RoomVersion,
     )
     if err != nil {
         r.statistics.ForServer(serverName).Failure()
-        cancel()
         return fmt.Errorf("r.federation.SendJoin: %w", err)
     }
     r.statistics.ForServer(serverName).Success()

     // Sanity-check the join response to ensure that it has a create
     // event, that the room version is known, etc.
-    if err := sanityCheckAuthChain(respSendJoin.AuthEvents); err != nil {
-        cancel()
+    if err = sanityCheckAuthChain(respSendJoin.AuthEvents); err != nil {
         return fmt.Errorf("sanityCheckAuthChain: %w", err)
     }

@@ -227,41 +220,35 @@ func (r *FederationInternalAPI) performJoinUsingServer(
     // to complete, but if the client does give up waiting, we'll
     // still continue to process the join anyway so that we don't
     // waste the effort.
-    go func() {
-        defer cancel()
-        // TODO: Can we expand Check here to return a list of missing auth
-        // events rather than failing one at a time?
-        var respState *gomatrixserverlib.RespState
-        respState, err = respSendJoin.Check(
-            context.Background(),
-            r.keyRing,
-            event,
-            federatedAuthProvider(ctx, r.federation, r.keyRing, serverName),
-        )
-        if err != nil {
-            logrus.WithFields(logrus.Fields{
-                "room_id": roomID,
-                "user_id": userID,
-            }).WithError(err).Error("Failed to process room join response")
-            return
-        }
-        // If we successfully performed a send_join above then the other
-        // server now thinks we're a part of the room. Send the newly
-        // returned state to the roomserver to update our local view.
-        if err = roomserverAPI.SendEventWithState(
-            context.Background(),
-            r.rsAPI,
-            roomserverAPI.KindNew,
-            respState,
-            event.Headered(respMakeJoin.RoomVersion),
-            serverName,
-            nil,
-            false,
-        ); err != nil {
-            logrus.WithFields(logrus.Fields{
-                "room_id": roomID,
-                "user_id": userID,
-            }).WithError(err).Error("Failed to send room join response to roomserver")
-            return
-        }
-    }()
-
-    <-ctx.Done()
+    // TODO: Can we expand Check here to return a list of missing auth
+    // events rather than failing one at a time?
+    respState, err := respSendJoin.Check(ctx, r.keyRing, event, federatedAuthProvider(ctx, r.federation, r.keyRing, serverName))
+    if err != nil {
+        return fmt.Errorf("respSendJoin.Check: %w", err)
+    }
+
+    // If we successfully performed a send_join above then the other
+    // server now thinks we're a part of the room. Send the newly
+    // returned state to the roomserver to update our local view.
+    if err = roomserverAPI.SendEventWithState(
+        ctx, r.rsAPI,
+        roomserverAPI.KindNew,
+        respState,
+        event.Headered(respMakeJoin.RoomVersion),
+        serverName,
+        nil,
+        false,
+    ); err != nil {
+        return fmt.Errorf("roomserverAPI.SendEventWithState: %w", err)
+    }
+
     return nil
 }

@@ -19,12 +19,10 @@ import (

     "github.com/matrix-org/dendrite/federationapi/storage/shared"
     "github.com/matrix-org/dendrite/federationapi/types"
-    "github.com/matrix-org/dendrite/internal"
     "github.com/matrix-org/gomatrixserverlib"
 )

 type Database interface {
-    internal.PartitionStorer
     gomatrixserverlib.KeyDatabase

     UpdateRoom(ctx context.Context, roomID, oldEventID, newEventID string, addHosts []types.JoinedHost, removeHosts []string) (joinedHosts []types.JoinedHost, err error)

go.mod: 13 lines changed

@@ -11,18 +11,19 @@ require (
 	github.com/HdrHistogram/hdrhistogram-go v1.1.2 // indirect
 	github.com/MFAshby/stdemuxerhook v1.0.0
 	github.com/Masterminds/semver/v3 v3.1.1
-	github.com/S7evinK/saramajetstream v0.0.0-20210709110708-de6efc8c4a32
-	github.com/Shopify/sarama v1.31.0
 	github.com/cespare/xxhash/v2 v2.1.2 // indirect
 	github.com/codeclysm/extract v2.2.0+incompatible
 	github.com/containerd/containerd v1.5.9 // indirect
 	github.com/docker/docker v20.10.12+incompatible
 	github.com/docker/go-connections v0.4.0
 	github.com/frankban/quicktest v1.14.0 // indirect
 	github.com/getsentry/sentry-go v0.12.0
 	github.com/gologme/log v1.3.0
 	github.com/gorilla/mux v1.8.0
 	github.com/gorilla/websocket v1.4.2
 	github.com/h2non/filetype v1.1.3 // indirect
 	github.com/hashicorp/golang-lru v0.5.4
 	github.com/json-iterator/go v1.1.12 // indirect
 	github.com/juju/testing v0.0.0-20211215003918-77eb13d6cad2 // indirect
 	github.com/klauspost/compress v1.14.2 // indirect
 	github.com/lib/pq v1.10.4

@@ -40,7 +41,7 @@ require (
 	github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4
 	github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d
 	github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16
-	github.com/matrix-org/gomatrixserverlib v0.0.0-20220131105022-801c51af9f29
+	github.com/matrix-org/gomatrixserverlib v0.0.0-20220204110702-c559d2019275
 	github.com/matrix-org/pinecone v0.0.0-20220121094951-351265543ddf
 	github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
 	github.com/mattn/go-sqlite3 v1.14.10

@@ -54,7 +55,9 @@ require (
 	github.com/patrickmn/go-cache v2.1.0+incompatible
 	github.com/pkg/errors v0.9.1
 	github.com/pressly/goose v2.7.0+incompatible
-	github.com/prometheus/client_golang v1.12.1
+	github.com/prometheus/client_golang v1.11.0
+	github.com/prometheus/common v0.32.1 // indirect
+	github.com/prometheus/procfs v0.7.3 // indirect
 	github.com/sirupsen/logrus v1.8.1
 	github.com/tidwall/gjson v1.13.0
 	github.com/tidwall/sjson v1.2.4

@@ -66,9 +69,11 @@ require (
 	golang.org/x/image v0.0.0-20211028202545-6944b10bf410
 	golang.org/x/mobile v0.0.0-20220112015953-858099ff7816
 	golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd
+	golang.org/x/sys v0.0.0-20220114195835-da31bd327af9 // indirect
 	golang.org/x/term v0.0.0-20210927222741-03fcf44c2211
 	gopkg.in/h2non/bimg.v1 v1.1.5
 	gopkg.in/yaml.v2 v2.4.0
+	gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect
 	nhooyr.io/websocket v1.8.7
 )

go.sum: 61 lines changed

@@ -100,17 +100,8 @@ github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578/go.mod h1:uGdko
 github.com/RoaringBitmap/roaring v0.4.7/go.mod h1:8khRDP4HmeXns4xIj9oGrKSz7XTQiJx2zgh7AcNke4w=
 github.com/RyanCarrier/dijkstra v1.0.0/go.mod h1:5agGUBNEtUAGIANmbw09fuO3a2htPEkc1jNH01qxCWA=
 github.com/RyanCarrier/dijkstra-1 v0.0.0-20170512020943-0e5801a26345/go.mod h1:OK4EvWJ441LQqGzed5NGB6vKBAE34n3z7iayPcEwr30=
-github.com/S7evinK/saramajetstream v0.0.0-20210709110708-de6efc8c4a32 h1:i3fOph9Hjleo6LbuqN9ODFxnwt7mOtYMpCGeC8qJN50=
-github.com/S7evinK/saramajetstream v0.0.0-20210709110708-de6efc8c4a32/go.mod h1:ne+jkLlzafIzaE4Q0Ze81T27dNgXe1wxovVEoAtSHTc=
 github.com/Shopify/goreferrer v0.0.0-20181106222321-ec9c9a553398/go.mod h1:a1uqRtAwp2Xwc6WNPJEufxJ7fx3npB4UV/JOLmbu5I0=
 github.com/Shopify/logrus-bugsnag v0.0.0-20171204204709-577dee27f20d/go.mod h1:HI8ITrYtUY+O+ZhtlqUnD8+KwNPOyugEhfP9fdUIaEQ=
-github.com/Shopify/sarama v1.29.0/go.mod h1:2QpgD79wpdAESqNQMxNc0KYMkycd4slxGdV3TWSVqrU=
-github.com/Shopify/sarama v1.31.0 h1:gObk7jCPutDxf+E6GA5G21noAZsi1SvP9ftCQYqpzus=
-github.com/Shopify/sarama v1.31.0/go.mod h1:BeW3gXRc/CxgAsrSly2RE9nIXUfC9ezb7QHBPVhvzjI=
-github.com/Shopify/toxiproxy v2.1.4+incompatible h1:TKdv8HiTLgE5wdJuEML90aBgNWsokNbMijUGhmcoBJc=
-github.com/Shopify/toxiproxy v2.1.4+incompatible/go.mod h1:OXgGpZ6Cli1/URJOF1DMxUHB2q5Ap20/P/eIdh4G0pI=
-github.com/Shopify/toxiproxy/v2 v2.3.0 h1:62YkpiP4bzdhKMH+6uC5E95y608k3zDwdzuBMsnn3uQ=
-github.com/Shopify/toxiproxy/v2 v2.3.0/go.mod h1:KvQTtB6RjCJY4zqNJn7C7JDFgsG5uoHYDirfUfpIm0c=
 github.com/VividCortex/ewma v1.1.1/go.mod h1:2Tkkvm3sRDVXaiyucHiACn4cqf7DpdyLvmxzcbUokwA=
 github.com/VividCortex/ewma v1.2.0/go.mod h1:nz4BbCtbLyFDeC9SUHbtcT5644juEuWfUAUnGx7j5l4=
 github.com/aead/siphash v1.0.1/go.mod h1:Nywa3cDsYNNK3gaciGTWPwHt0wlpNV15vwmswBAUSII=

@@ -353,12 +344,6 @@ github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815/go.mod h1:WwZ+bS3
 github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
 github.com/dustin/go-humanize v0.0.0-20180421182945-02af3965c54e/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
 github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
-github.com/eapache/go-resiliency v1.2.0 h1:v7g92e/KSN71Rq7vSThKaWIq68fL4YHvWyiUKorFR1Q=
-github.com/eapache/go-resiliency v1.2.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs=
-github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21 h1:YEetp8/yCZMuEPMUDHG0CW/brkkEp8mzqk2+ODEitlw=
-github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU=
-github.com/eapache/queue v1.1.0 h1:YOEu7KNc61ntiQlcEeUIoDTJ2o8mQznoNvUhiigpIqc=
-github.com/eapache/queue v1.1.0/go.mod h1:6eCeP0CKFpHLu8blIFXhExK/dRa7WDZfr6jVFPTqq+I=
 github.com/eknkc/amber v0.0.0-20171010120322-cdade1c07385/go.mod h1:0vRUJqYpeSZifjYj7uP3BG/gKcuzL9xWVV/Y+cK33KM=
 github.com/elazarl/goproxy v0.0.0-20180725130230-947c36da3153/go.mod h1:/Zj4wYkgs4iZTTu3o/KG3Itv/qCCa8VVMlb3i9OVuzc=
 github.com/emicklei/go-restful v0.0.0-20170410110728-ff4f55a20633/go.mod h1:otzb+WCGbkyDHkqmQmT5YD2WR4BBwUdeQoFo8l/7tVs=

@@ -379,8 +364,6 @@ github.com/flynn/noise v0.0.0-20180327030543-2492fe189ae6 h1:u/UEqS66A5ckRmS4yNp
 github.com/flynn/noise v0.0.0-20180327030543-2492fe189ae6/go.mod h1:1i71OnUq3iUe1ma7Lr6yG6/rjvM3emb6yoL7xLFzcVQ=
 github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k=
 github.com/form3tech-oss/jwt-go v3.2.2+incompatible/go.mod h1:pbq4aXjuKjdthFRnoDwaVPLA+WlJuPGy+QneDUgJi2k=
-github.com/fortytw2/leaktest v1.3.0 h1:u8491cBMTQ8ft8aeV+adlcytMZylmA5nnwwkRZjI8vw=
-github.com/fortytw2/leaktest v1.3.0/go.mod h1:jDsjWgpAGjm2CA7WthBh/CdZYEPF31XHquHwclZch5g=
 github.com/francoispqt/gojay v1.2.13/go.mod h1:ehT5mTG4ua4581f1++1WLG0vPdaA9HaiDsoyrBGkyDY=
 github.com/frankban/quicktest v1.0.0/go.mod h1:R98jIehRai+d1/3Hv2//jOVCTJhW1VBavT6B6CuGq2k=
 github.com/frankban/quicktest v1.7.2/go.mod h1:jaStnuzAqU1AJdCO0l53JDCJrVDKcS03DbaAcR7Ks/o=

@@ -498,8 +481,6 @@ github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw
 github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
 github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
 github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
-github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM=
-github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
 github.com/gologme/log v1.2.0/go.mod h1:gq31gQ8wEHkR+WekdWsqDuf8pXTUZA9BnnzTuPz1Y9U=
 github.com/gologme/log v1.3.0 h1:l781G4dE+pbigClDSDzSaaYKtiueHCILUa/qSDsmHAo=
 github.com/gologme/log v1.3.0/go.mod h1:yKT+DvIPdDdDoPtqFrFxheooyVmoqi0BAsw+erN3wA4=

@@ -552,8 +533,6 @@ github.com/gorilla/handlers v0.0.0-20150720190736-60c7bfde3e33/go.mod h1:Qkdc/uu
 github.com/gorilla/mux v1.7.2/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs=
 github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI=
 github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
-github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4=
-github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM=
 github.com/gorilla/websocket v0.0.0-20170926233335-4201258b820c/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ=
 github.com/gorilla/websocket v1.4.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ=
 github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=

@@ -580,8 +559,6 @@ github.com/hashicorp/go-multierror v1.0.0/go.mod h1:dHtQlpGsu+cZNNAkkCN/P3hoUDHh
 github.com/hashicorp/go-multierror v1.1.0 h1:B9UzwGQJehnUY1yNrnwREHc3fGbC2xefo8g4TbElacI=
 github.com/hashicorp/go-multierror v1.1.0/go.mod h1:spPvp8C1qA32ftKqdAHm4hHTbPw+vmowP0z+KUhOZdA=
 github.com/hashicorp/go-syslog v1.0.0/go.mod h1:qPfqrKkXGihmCqbJM2mZgkZGvKG1dFdvsLplgctolz4=
-github.com/hashicorp/go-uuid v1.0.2 h1:cfejS+Tpcp13yd5nYHWDI6qVCny6wyX2Mt5SGur2IGE=
-github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
 github.com/hashicorp/go-version v1.2.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA=
 github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
 github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=

@@ -667,18 +644,6 @@ github.com/jbenet/goprocess v0.0.0-20160826012719-b497e2f366b8/go.mod h1:Ly/wlsj
 github.com/jbenet/goprocess v0.1.3/go.mod h1:5yspPrukOVuOLORacaBi858NqyClJPQxYZlqdZVfqY4=
 github.com/jbenet/goprocess v0.1.4 h1:DRGOFReOMqqDNXwW70QkacFW0YN9QnwLV0Vqk+3oU0o=
 github.com/jbenet/goprocess v0.1.4/go.mod h1:5yspPrukOVuOLORacaBi858NqyClJPQxYZlqdZVfqY4=
-github.com/jcmturner/aescts/v2 v2.0.0 h1:9YKLH6ey7H4eDBXW8khjYslgyqG2xZikXP0EQFKrle8=
-github.com/jcmturner/aescts/v2 v2.0.0/go.mod h1:AiaICIRyfYg35RUkr8yESTqvSy7csK90qZ5xfvvsoNs=
-github.com/jcmturner/dnsutils/v2 v2.0.0 h1:lltnkeZGL0wILNvrNiVCR6Ro5PGU/SeBvVO/8c/iPbo=
-github.com/jcmturner/dnsutils/v2 v2.0.0/go.mod h1:b0TnjGOvI/n42bZa+hmXL+kFJZsFT7G4t3HTlQ184QM=
-github.com/jcmturner/gofork v1.0.0 h1:J7uCkflzTEhUZ64xqKnkDxq3kzc96ajM1Gli5ktUem8=
-github.com/jcmturner/gofork v1.0.0/go.mod h1:MK8+TM0La+2rjBD4jE12Kj1pCCxK7d2LK/UM3ncEo0o=
-github.com/jcmturner/goidentity/v6 v6.0.1 h1:VKnZd2oEIMorCTsFBnJWbExfNN7yZr3EhJAxwOkZg6o=
-github.com/jcmturner/goidentity/v6 v6.0.1/go.mod h1:X1YW3bgtvwAXju7V3LCIMpY0Gbxyjn/mY9zx4tFonSg=
-github.com/jcmturner/gokrb5/v8 v8.4.2 h1:6ZIM6b/JJN0X8UM43ZOM6Z4SJzla+a/u7scXFJzodkA=
-github.com/jcmturner/gokrb5/v8 v8.4.2/go.mod h1:sb+Xq/fTY5yktf/VxLsE3wlfPqQjp0aWNYyvBVK62bc=
-github.com/jcmturner/rpc/v2 v2.0.3 h1:7FXXj8Ti1IaVFpSAziCZWNzbNuZmnvw/i6CqLNdWfZY=
-github.com/jcmturner/rpc/v2 v2.0.3/go.mod h1:VUJYCIDm3PVOEHw8sgt091/20OJjskO/YJki3ELg/Hc=
 github.com/jellevandenhooff/dkim v0.0.0-20150330215556-f50fe3d243e1/go.mod h1:E0B/fFc00Y+Rasa88328GlI/XbtyysCtTHZS8h7IrBU=
 github.com/jessevdk/go-flags v0.0.0-20141203071132-1679536dcc89/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=
 github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=

@@ -747,10 +712,7 @@ github.com/klauspost/compress v1.9.7/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0
 github.com/klauspost/compress v1.10.3/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs=
 github.com/klauspost/compress v1.11.3/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs=
 github.com/klauspost/compress v1.11.13/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs=
-github.com/klauspost/compress v1.12.2/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg=
-github.com/klauspost/compress v1.12.3/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg=
-github.com/klauspost/compress v1.13.4/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg=
 github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk=
 github.com/klauspost/compress v1.14.2 h1:S0OHlFk/Gbon/yauFJ4FfJJF5V0fc5HbBTJazi28pRw=
 github.com/klauspost/compress v1.14.2/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk=
 github.com/klauspost/cpuid v1.2.1/go.mod h1:Pj4uuM528wm8OyEC2QMXAi2YiTZ96dNQPGgoMS4s3ek=

@@ -1021,8 +983,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d/go.mod h1
 github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0=
 github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4=
 github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
-github.com/matrix-org/gomatrixserverlib v0.0.0-20220131105022-801c51af9f29 h1:1t/J3AldUbgRxltlcmMbUefexxzolG5DvV2CkriZ4LM=
-github.com/matrix-org/gomatrixserverlib v0.0.0-20220131105022-801c51af9f29/go.mod h1:qFvhfbQ5orQxlH9vCiFnP4dW27xxnWHdNUBKyj/fbiY=
+github.com/matrix-org/gomatrixserverlib v0.0.0-20220204110702-c559d2019275 h1:f6Hh7D3EOTl1uUr76FiyHNA1h4pKBhcVUtyHbxn0hKA=
+github.com/matrix-org/gomatrixserverlib v0.0.0-20220204110702-c559d2019275/go.mod h1:qFvhfbQ5orQxlH9vCiFnP4dW27xxnWHdNUBKyj/fbiY=
 github.com/matrix-org/pinecone v0.0.0-20220121094951-351265543ddf h1:/nqfHUdQHr3WVdbZieaYFvHF1rin5pvDTa/NOZ/qCyE=
 github.com/matrix-org/pinecone v0.0.0-20220121094951-351265543ddf/go.mod h1:r6dsL+ylE0yXe/7zh8y/Bdh6aBYI1r+u4yZni9A4iyk=
 github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U=

@@ -1250,9 +1212,6 @@ github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/9
 github.com/pelletier/go-toml v1.8.1/go.mod h1:T2/BmBdy8dvIRq1a/8aqjN41wvWlN4lrapLU/GW4pbc=
 github.com/peterbourgon/diskv v2.0.1+incompatible/go.mod h1:uqqh8zWWbv1HBMNONnaR/tNboyR3/BZd58JJSHlUSCU=
 github.com/philhofer/fwd v1.0.0/go.mod h1:gk3iGcWd9+svBvR0sR+KPcfE+RNWozjowpeBVG3ZVNU=
-github.com/pierrec/lz4 v2.6.0+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY
|
||||
github.com/pierrec/lz4 v2.6.1+incompatible h1:9UY3+iC23yxF0UfGaYrGplQ+79Rg+h/q9FV9ix19jjM=
|
||||
github.com/pierrec/lz4 v2.6.1+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY=
|
||||
github.com/pingcap/errors v0.11.4 h1:lFuQV/oaUMGcD2tqt+01ROSmJs75VG1ToEOkZIZ4nE4=
|
||||
github.com/pingcap/errors v0.11.4/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8=
|
||||
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
|
|
@ -1272,9 +1231,8 @@ github.com/prometheus/client_golang v0.9.3/go.mod h1:/TN21ttK/J9q6uSwhBd54HahCDf
|
|||
github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo=
|
||||
github.com/prometheus/client_golang v1.1.0/go.mod h1:I1FGZT9+L76gKKOs5djB6ezCbFQP1xR9D75/vuwEF3g=
|
||||
github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M=
|
||||
github.com/prometheus/client_golang v1.11.0 h1:HNkLOAEQMIDv/K+04rukrLx6ch7msSRwf3/SASFAGtQ=
|
||||
github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0=
|
||||
github.com/prometheus/client_golang v1.12.1 h1:ZiaPsmm9uiBeaSMRznKsCDNtPCS0T3JVDGF+06gjBzk=
|
||||
github.com/prometheus/client_golang v1.12.1/go.mod h1:3Z9XVyYiZYEO+YQWt3RD2R3jrbd179Rt297l4aS6nDY=
|
||||
github.com/prometheus/client_model v0.0.0-20171117100541-99fa1f4be8e5/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
|
||||
github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
|
||||
github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
|
||||
|
|
@ -1306,8 +1264,6 @@ github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1
|
|||
github.com/prometheus/procfs v0.7.3 h1:4jVXhlkAyzOScmCkXBTOLRLTz8EeU+eyjrwB/EPq0VU=
|
||||
github.com/prometheus/procfs v0.7.3/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
|
||||
github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU=
|
||||
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 h1:N/ElC8H3+5XpJzTSTfLsJV/mx9Q9g7kxmchpfZyxgzM=
|
||||
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
|
||||
github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg=
|
||||
|
|
@ -1433,7 +1389,6 @@ github.com/urfave/cli v0.0.0-20171014202726-7bc6a0acffa5/go.mod h1:70zkFmudgCuE/
|
|||
github.com/urfave/cli v1.20.0/go.mod h1:70zkFmudgCuE/ngEzBv17Jvp/497gISqfk5gWijbERA=
|
||||
github.com/urfave/cli v1.22.1/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0=
|
||||
github.com/urfave/cli v1.22.2/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0=
|
||||
github.com/urfave/cli/v2 v2.3.0/go.mod h1:LJmUH05zAU44vOAcrfzZQKsZbVcdbOG8rtL3/XcUArI=
|
||||
github.com/urfave/negroni v1.0.0/go.mod h1:Meg73S6kFm/4PpbYdq35yYWoCZ9mS/YSx+lKnmiohz4=
|
||||
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
|
||||
github.com/valyala/fasthttp v1.6.0/go.mod h1:FstJa9V+Pj9vQ7OJie2qMHdwemEDaDiSdBnvPM1Su9w=
|
||||
|
|
@ -1464,11 +1419,6 @@ github.com/willf/bitset v1.1.9/go.mod h1:RjeCKbqT1RxIR/KWY6phxZiaY1IyutSBfGjNPyS
|
|||
github.com/willf/bitset v1.1.11-0.20200630133818-d5bec3311243/go.mod h1:RjeCKbqT1RxIR/KWY6phxZiaY1IyutSBfGjNPySAYV4=
|
||||
github.com/willf/bitset v1.1.11/go.mod h1:83CECat5yLh5zVOf4P1ErAgKA5UDvKtgyUABdr3+MjI=
|
||||
github.com/x-cray/logrus-prefixed-formatter v0.5.2/go.mod h1:2duySbKsL6M18s5GU7VPsoEPHyzalCE06qoARUCeBBE=
|
||||
github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI=
|
||||
github.com/xdg-go/scram v1.0.2/go.mod h1:1WAq6h33pAW+iRreB34OORO2Nf7qel3VV3fjBj+hCSs=
|
||||
github.com/xdg-go/stringprep v1.0.2/go.mod h1:8F9zXuvzgwmyT5DUm4GUfZGDdT3W+LCvS6+da4O5kxM=
|
||||
github.com/xdg/scram v1.0.3/go.mod h1:lB8K/P019DLNhemzwFU4jHLhdvlE6uDZjXFejJXr49I=
|
||||
github.com/xdg/stringprep v1.0.3/go.mod h1:Jhud4/sHMO4oL310DaZAKk9ZaJ08SJfe+sJh0HrGL1Y=
|
||||
github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU=
|
||||
github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ=
|
||||
github.com/xeipuuv/gojsonschema v0.0.0-20180618132009-1d523034197f/go.mod h1:5yf86TLmAcydyeJq5YvxkGPE2fm/u4myDekKRoLuqhs=
|
||||
|
|
@ -1551,7 +1501,6 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh
|
|||
golang.org/x/crypto v0.0.0-20200728195943-123391ffb6de/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.0.0-20201002170205-7f63de1d35b0/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.0.0-20201112155050-0c6587e931a9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I=
|
||||
golang.org/x/crypto v0.0.0-20210314154223-e6e6c4f2bb5b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
|
||||
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
|
||||
|
|
@ -1656,7 +1605,6 @@ golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwY
|
|||
golang.org/x/net v0.0.0-20201224014010-6772e930b67b/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
|
||||
golang.org/x/net v0.0.0-20210427231257-85d9c07bbe3a/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk=
|
||||
golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk=
|
||||
golang.org/x/net v0.0.0-20210510120150-4163338589ed/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
|
|
@ -1665,7 +1613,6 @@ golang.org/x/net v0.0.0-20210805182204-aaa1db679c0d/go.mod h1:9nx3DQGgdP8bBQD5qx
|
|||
golang.org/x/net v0.0.0-20210927181540-4e4d966f7476/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.0.0-20211008194852-3b03d305991f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.0.0-20220105145211-5b0dc2dfae98/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd h1:O7DYs+zxREGLKzKoMQrtrEacpb0ZVXA5rIwylE2Xchk=
|
||||
golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
|
||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
|
|
@ -1799,7 +1746,6 @@ golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3
|
|||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.7-0.20210503195748-5c7c50ebbd4f/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk=
|
||||
|
|
@ -2035,7 +1981,6 @@ gopkg.in/yaml.v2 v2.0.0-20170712054546-1be3d31502d6/go.mod h1:JAlM8MvJe8wmxCU4Bl
|
|||
gopkg.in/yaml.v2 v2.0.0-20170812160011-eb3733d160e7/go.mod h1:JAlM8MvJe8wmxCU4Bli9HhUf9+ttbYbLASfIpnQbh74=
|
||||
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.2.3/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.2.5/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
|
|
|
|||
|
|
@ -1,139 +0,0 @@
|
|||
// Copyright 2017 Vector Creations Ltd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/Shopify/sarama"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/setup/process"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// A PartitionStorer has the storage APIs needed by the consumer.
|
||||
type PartitionStorer interface {
|
||||
// PartitionOffsets returns the offsets the consumer has reached for each partition.
|
||||
PartitionOffsets(ctx context.Context, topic string) ([]sqlutil.PartitionOffset, error)
|
||||
// SetPartitionOffset records where the consumer has reached for a partition.
|
||||
SetPartitionOffset(ctx context.Context, topic string, partition int32, offset int64) error
|
||||
}
|
||||
|
||||
// A ContinualConsumer continually consumes logs even across restarts. It requires a PartitionStorer to
|
||||
// remember the offset it reached.
|
||||
type ContinualConsumer struct {
|
||||
// The parent context for the listener, stop consuming when this context is done
|
||||
Process *process.ProcessContext
|
||||
// The component name
|
||||
ComponentName string
|
||||
// The kafkaesque topic to consume events from.
|
||||
// This is the name used in kafka to identify the stream to consume events from.
|
||||
Topic string
|
||||
// A kafkaesque stream consumer providing the APIs for talking to the event source.
|
||||
// The interface is taken from a client library for Apache Kafka.
|
||||
// But any equivalent event streaming protocol could be made to implement the same interface.
|
||||
Consumer sarama.Consumer
|
||||
// A thing which can load and save partition offsets for a topic.
|
||||
PartitionStore PartitionStorer
|
||||
// ProcessMessage is a function which will be called for each message in the log. Return an error to
|
||||
// stop processing messages. See ErrShutdown for specific control signals.
|
||||
ProcessMessage func(msg *sarama.ConsumerMessage) error
|
||||
// ShutdownCallback is called when ProcessMessage returns ErrShutdown, after the partition has been saved.
|
||||
// It is optional.
|
||||
ShutdownCallback func()
|
||||
}
|
||||
|
||||
// ErrShutdown can be returned from ContinualConsumer.ProcessMessage to stop the ContinualConsumer.
|
||||
var ErrShutdown = fmt.Errorf("shutdown")
|
||||
|
||||
// Start starts the consumer consuming.
|
||||
// Starts up a goroutine for each partition in the kafka stream.
|
||||
// Returns nil once all the goroutines are started.
|
||||
// Returns an error if it can't start consuming for any of the partitions.
|
||||
func (c *ContinualConsumer) Start() error {
|
||||
_, err := c.StartOffsets()
|
||||
return err
|
||||
}
|
||||
|
||||
// StartOffsets is the same as Start but returns the loaded offsets as well.
|
||||
func (c *ContinualConsumer) StartOffsets() ([]sqlutil.PartitionOffset, error) {
|
||||
offsets := map[int32]int64{}
|
||||
|
||||
partitions, err := c.Consumer.Partitions(c.Topic)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, partition := range partitions {
|
||||
// Default all the offsets to the beginning of the stream.
|
||||
offsets[partition] = sarama.OffsetOldest
|
||||
}
|
||||
|
||||
storedOffsets, err := c.PartitionStore.PartitionOffsets(context.TODO(), c.Topic)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, offset := range storedOffsets {
|
||||
// We've already processed events from this partition so advance the offset to where we got to.
|
||||
// ConsumePartition will start streaming from the message with the given offset (inclusive),
|
||||
// so increment 1 to avoid getting the same message a second time.
|
||||
offsets[offset.Partition] = 1 + offset.Offset
|
||||
}
|
||||
|
||||
var partitionConsumers []sarama.PartitionConsumer
|
||||
for partition, offset := range offsets {
|
||||
pc, err := c.Consumer.ConsumePartition(c.Topic, partition, offset)
|
||||
if err != nil {
|
||||
for _, p := range partitionConsumers {
|
||||
p.Close() // nolint: errcheck
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
partitionConsumers = append(partitionConsumers, pc)
|
||||
}
|
||||
for _, pc := range partitionConsumers {
|
||||
go c.consumePartition(pc)
|
||||
if c.Process != nil {
|
||||
c.Process.ComponentStarted()
|
||||
go func(pc sarama.PartitionConsumer) {
|
||||
<-c.Process.WaitForShutdown()
|
||||
_ = pc.Close()
|
||||
c.Process.ComponentFinished()
|
||||
logrus.Infof("Stopped consumer for %q topic %q", c.ComponentName, c.Topic)
|
||||
}(pc)
|
||||
}
|
||||
}
|
||||
|
||||
return storedOffsets, nil
|
||||
}
|
||||
|
||||
// consumePartition consumes the room events for a single partition of the kafkaesque stream.
|
||||
func (c *ContinualConsumer) consumePartition(pc sarama.PartitionConsumer) {
|
||||
defer pc.Close() // nolint: errcheck
|
||||
for message := range pc.Messages() {
|
||||
msgErr := c.ProcessMessage(message)
|
||||
// Advance our position in the stream so that we will start at the right position after a restart.
|
||||
if err := c.PartitionStore.SetPartitionOffset(context.TODO(), c.Topic, message.Partition, message.Offset); err != nil {
|
||||
panic(fmt.Errorf("the ContinualConsumer in %q failed to SetPartitionOffset: %w", c.ComponentName, err))
|
||||
}
|
||||
// Shutdown if we were told to do so.
|
||||
if msgErr == ErrShutdown {
|
||||
if c.ShutdownCallback != nil {
|
||||
c.ShutdownCallback()
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -20,6 +20,7 @@ import (
|
|||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"math/big"
|
||||
|
|
@ -158,11 +159,10 @@ func NewMatrixKey(matrixKeyPath string) (err error) {
|
|||
|
||||
const certificateDuration = time.Hour * 24 * 365 * 10
|
||||
|
||||
// NewTLSKey generates a new RSA TLS key and certificate and writes it to a file.
|
||||
func NewTLSKey(tlsKeyPath, tlsCertPath string) error {
|
||||
func generateTLSTemplate(dnsNames []string) (*rsa.PrivateKey, *x509.Certificate, error) {
|
||||
priv, err := rsa.GenerateKey(rand.Reader, 4096)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
notBefore := time.Now()
|
||||
|
|
@ -170,7 +170,7 @@ func NewTLSKey(tlsKeyPath, tlsCertPath string) error {
|
|||
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
|
||||
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
template := x509.Certificate{
|
||||
|
|
@ -180,20 +180,21 @@ func NewTLSKey(tlsKeyPath, tlsCertPath string) error {
|
|||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
BasicConstraintsValid: true,
|
||||
DNSNames: dnsNames,
|
||||
}
|
||||
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return priv, &template, nil
|
||||
}
|
||||
|
||||
func writeCertificate(tlsCertPath string, derBytes []byte) error {
|
||||
certOut, err := os.Create(tlsCertPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer certOut.Close() // nolint: errcheck
|
||||
if err = pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil {
|
||||
return err
|
||||
}
|
||||
return pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
|
||||
}
|
||||
|
||||
func writePrivateKey(tlsKeyPath string, priv *rsa.PrivateKey) error {
|
||||
keyOut, err := os.OpenFile(tlsKeyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -205,3 +206,73 @@ func NewTLSKey(tlsKeyPath, tlsCertPath string) error {
|
|||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// NewTLSKey generates a new RSA TLS key and certificate and writes it to a file.
|
||||
func NewTLSKey(tlsKeyPath, tlsCertPath string) error {
|
||||
priv, template, err := generateTLSTemplate(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Self-signed certificate: template == parent
|
||||
derBytes, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = writeCertificate(tlsCertPath, derBytes); err != nil {
|
||||
return err
|
||||
}
|
||||
return writePrivateKey(tlsKeyPath, priv)
|
||||
}
|
||||
|
||||
func NewTLSKeyWithAuthority(serverName, tlsKeyPath, tlsCertPath, authorityKeyPath, authorityCertPath string) error {
|
||||
priv, template, err := generateTLSTemplate([]string{serverName})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// load the authority key
|
||||
dat, err := ioutil.ReadFile(authorityKeyPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
block, _ := pem.Decode([]byte(dat))
|
||||
if block == nil || block.Type != "RSA PRIVATE KEY" {
|
||||
return errors.New("authority .key is not a valid pem encoded rsa private key")
|
||||
}
|
||||
authorityPriv, err := x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// load the authority certificate
|
||||
dat, err = ioutil.ReadFile(authorityCertPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
block, _ = pem.Decode([]byte(dat))
|
||||
if block == nil || block.Type != "CERTIFICATE" {
|
||||
return errors.New("authority .crt is not a valid pem encoded x509 cert")
|
||||
}
|
||||
var caCerts []*x509.Certificate
|
||||
caCerts, err = x509.ParseCertificates(block.Bytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(caCerts) != 1 {
|
||||
return errors.New("authority .crt contains none or more than one cert")
|
||||
}
|
||||
authorityCert := caCerts[0]
|
||||
|
||||
// Sign the new certificate using the authority's key/cert
|
||||
derBytes, err := x509.CreateCertificate(rand.Reader, template, authorityCert, &priv.PublicKey, authorityPriv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = writeCertificate(tlsCertPath, derBytes); err != nil {
|
||||
return err
|
||||
}
|
||||
return writePrivateKey(tlsKeyPath, priv)
|
||||
}
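To show how the refactored pieces fit together, here is a small usage sketch. The file paths and server name are invented for the example; the only assumption beyond the code above is that the authority key and certificate read by NewTLSKeyWithAuthority are in the PEM forms it checks for (an RSA private key and a single certificate).

// Hypothetical helper living alongside the functions above, for illustration only.
func generateExampleCerts() error {
	// Self-signed certificate: generateTLSTemplate builds the template and
	// NewTLSKey signs it with its own key (template is also the parent).
	if err := NewTLSKey("ca.key", "ca.crt"); err != nil {
		return err
	}
	// Certificate for a specific server name, signed by the authority above,
	// so the DNSNames list carries "matrix.example.com".
	return NewTLSKeyWithAuthority(
		"matrix.example.com",
		"server.key", "server.crt",
		"ca.key", "ca.crt",
	)
}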
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ var build string
const (
	VersionMajor = 0
	VersionMinor = 6
	VersionPatch = 0
	VersionPatch = 2
	VersionTag   = "" // example: "rc1"
)
|
||||
|
||||
|
|
|
|||
|
|
@ -228,7 +228,7 @@ type QueryKeyChangesRequest struct {
	// The offset of the last received key event, or sarama.OffsetOldest if this is from the beginning
	Offset int64
	// The inclusive offset where to track key changes up to. Messages with this offset are included in the response.
	// Use sarama.OffsetNewest if the offset is unknown (then check the response Offset to avoid racing).
	// Use types.OffsetNewest if the offset is unknown (then check the response Offset to avoid racing).
	ToOffset int64
}
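For callers of the key change query API this change only swaps the sentinel used for an open-ended query. A minimal sketch of building such a request follows; the QueryKeyChanges method name, the response type and the lastOffsetSeen variable are assumed for illustration rather than taken from this diff.

req := &api.QueryKeyChangesRequest{
	Offset:   lastOffsetSeen,     // last key change offset already processed (exclusive)
	ToOffset: types.OffsetNewest, // no upper bound; check res.Offset afterwards to avoid racing
}
res := &api.QueryKeyChangesResponse{}
keyAPI.QueryKeyChanges(ctx, req, res)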
|
||||
|
||||
|
|
|
|||
|
|
@ -18,29 +18,30 @@ import (
|
|||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/keyserver/api"
|
||||
"github.com/matrix-org/dendrite/keyserver/storage"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/setup/jetstream"
|
||||
"github.com/matrix-org/dendrite/setup/process"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/nats-io/nats.go"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/Shopify/sarama"
|
||||
)
|
||||
|
||||
type OutputCrossSigningKeyUpdateConsumer struct {
|
||||
eduServerConsumer *internal.ContinualConsumer
|
||||
keyDB storage.Database
|
||||
keyAPI api.KeyInternalAPI
|
||||
serverName string
|
||||
ctx context.Context
|
||||
keyDB storage.Database
|
||||
keyAPI api.KeyInternalAPI
|
||||
serverName string
|
||||
jetstream nats.JetStreamContext
|
||||
durable string
|
||||
topic string
|
||||
}
|
||||
|
||||
func NewOutputCrossSigningKeyUpdateConsumer(
|
||||
process *process.ProcessContext,
|
||||
cfg *config.Dendrite,
|
||||
kafkaConsumer sarama.Consumer,
|
||||
js nats.JetStreamContext,
|
||||
keyDB storage.Database,
|
||||
keyAPI api.KeyInternalAPI,
|
||||
) *OutputCrossSigningKeyUpdateConsumer {
|
||||
|
|
@ -48,60 +49,58 @@ func NewOutputCrossSigningKeyUpdateConsumer(
|
|||
// topic. We will only produce events where the UserID matches our server name,
|
||||
// and we will only consume events where the UserID does NOT match our server
|
||||
// name (because the update came from a remote server).
|
||||
consumer := internal.ContinualConsumer{
|
||||
Process: process,
|
||||
ComponentName: "keyserver/keyserver",
|
||||
Topic: cfg.Global.JetStream.TopicFor(jetstream.OutputKeyChangeEvent),
|
||||
Consumer: kafkaConsumer,
|
||||
PartitionStore: keyDB,
|
||||
}
|
||||
s := &OutputCrossSigningKeyUpdateConsumer{
|
||||
eduServerConsumer: &consumer,
|
||||
keyDB: keyDB,
|
||||
keyAPI: keyAPI,
|
||||
serverName: string(cfg.Global.ServerName),
|
||||
ctx: process.Context(),
|
||||
keyDB: keyDB,
|
||||
jetstream: js,
|
||||
durable: cfg.Global.JetStream.Durable("KeyServerCrossSigningConsumer"),
|
||||
topic: cfg.Global.JetStream.TopicFor(jetstream.OutputKeyChangeEvent),
|
||||
keyAPI: keyAPI,
|
||||
serverName: string(cfg.Global.ServerName),
|
||||
}
|
||||
consumer.ProcessMessage = s.onMessage
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *OutputCrossSigningKeyUpdateConsumer) Start() error {
|
||||
return s.eduServerConsumer.Start()
|
||||
return jetstream.JetStreamConsumer(
|
||||
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage,
|
||||
nats.DeliverAll(), nats.ManualAck(),
|
||||
)
|
||||
}
|
||||
|
||||
// onMessage is called in response to a message received on the
|
||||
// key change events topic from the key server.
|
||||
func (t *OutputCrossSigningKeyUpdateConsumer) onMessage(msg *sarama.ConsumerMessage) error {
|
||||
func (t *OutputCrossSigningKeyUpdateConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
|
||||
var m api.DeviceMessage
|
||||
if err := json.Unmarshal(msg.Value, &m); err != nil {
|
||||
if err := json.Unmarshal(msg.Data, &m); err != nil {
|
||||
logrus.WithError(err).Errorf("failed to read device message from key change topic")
|
||||
return nil
|
||||
return true
|
||||
}
|
||||
if m.OutputCrossSigningKeyUpdate == nil {
|
||||
// This probably shouldn't happen but stops us from panicking if we come
// across an update that doesn't satisfy either type.
|
||||
return nil
|
||||
return true
|
||||
}
|
||||
switch m.Type {
|
||||
case api.TypeCrossSigningUpdate:
|
||||
return t.onCrossSigningMessage(m)
|
||||
default:
|
||||
return nil
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OutputCrossSigningKeyUpdateConsumer) onCrossSigningMessage(m api.DeviceMessage) error {
|
||||
func (s *OutputCrossSigningKeyUpdateConsumer) onCrossSigningMessage(m api.DeviceMessage) bool {
|
||||
output := m.CrossSigningKeyUpdate
|
||||
_, host, err := gomatrixserverlib.SplitID('@', output.UserID)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Errorf("eduserver output log: user ID parse failure")
|
||||
return nil
|
||||
return true
|
||||
}
|
||||
if host == gomatrixserverlib.ServerName(s.serverName) {
|
||||
// Ignore any messages that contain information about our own users, as
|
||||
// they already originated from this server.
|
||||
return nil
|
||||
return true
|
||||
}
|
||||
uploadReq := &api.PerformUploadDeviceKeysRequest{
|
||||
UserID: output.UserID,
|
||||
|
|
@ -114,5 +113,11 @@ func (s *OutputCrossSigningKeyUpdateConsumer) onCrossSigningMessage(m api.Device
|
|||
}
|
||||
uploadRes := &api.PerformUploadDeviceKeysResponse{}
|
||||
s.keyAPI.PerformUploadDeviceKeys(context.TODO(), uploadReq, uploadRes)
|
||||
return uploadRes.Error
|
||||
if uploadRes.Error != nil {
|
||||
// If the error is due to a missing or invalid parameter then we might
// as well just acknowledge the message, because otherwise we'll just
// keep being delivered a faulty message over and over again.
|
||||
return uploadRes.Error.IsMissingParam || uploadRes.Error.IsInvalidParam
|
||||
}
|
||||
return true
|
||||
}
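The conversion above illustrates the handler shape shared by the NATS-based consumers in this change: the JetStreamConsumer wrapper passes each message to the handler, and the boolean return appears to decide whether the message is acknowledged (true) or left for redelivery (false), as implied by the comments above. A minimal sketch of a handler in that style; ExampleConsumer, doSomethingWith and the payload field are made up, and the imports mirror those of the file above.

func (c *ExampleConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
	var payload struct {
		UserID string `json:"user_id"`
	}
	if err := json.Unmarshal(msg.Data, &payload); err != nil {
		// A malformed message can never succeed, so ack it to avoid an endless redelivery loop.
		return true
	}
	if err := doSomethingWith(ctx, payload.UserID); err != nil {
		// Transient failure: leave the message unacknowledged so JetStream redelivers it.
		return false
	}
	return true
}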
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ func AddInternalRoutes(router *mux.Router, intAPI api.KeyInternalAPI) {
func NewInternalAPI(
	base *base.BaseDendrite, cfg *config.KeyServer, fedClient fedsenderapi.FederationClient,
) api.KeyInternalAPI {
	js, consumer, _ := jetstream.Prepare(&cfg.Matrix.JetStream)
	js := jetstream.Prepare(&cfg.Matrix.JetStream)

	db, err := storage.NewDatabase(&cfg.Database)
	if err != nil {
@ -66,7 +66,7 @@ func NewInternalAPI(
	}()

	keyconsumer := consumers.NewOutputCrossSigningKeyUpdateConsumer(
		base.ProcessContext, base.Cfg, consumer, db, ap,
		base.ProcessContext, base.Cfg, js, db, ap,
	)
	if err := keyconsumer.Start(); err != nil {
		logrus.WithError(err).Panicf("failed to start keyserver EDU server consumer")
|
||||
|
|
|
|||
|
|
@ -18,15 +18,12 @@ import (
	"context"
	"encoding/json"

	"github.com/matrix-org/dendrite/internal"
	"github.com/matrix-org/dendrite/keyserver/api"
	"github.com/matrix-org/dendrite/keyserver/types"
	"github.com/matrix-org/gomatrixserverlib"
)

type Database interface {
	internal.PartitionStorer

	// ExistingOneTimeKeys returns a map of keyIDWithAlgorithm to key JSON for the given parameters. If no keys exist with this combination
	// of user/device/key/algorithm 4-uple then it is omitted from the map. Returns an error when failing to communicate with the database.
	ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
@ -71,7 +68,7 @@ type Database interface {
	StoreKeyChange(ctx context.Context, userID string) (int64, error)

	// KeyChanges returns a list of user IDs who have modified their keys from the offset given (exclusive) to the offset given (inclusive).
	// A to offset of sarama.OffsetNewest means no upper limit.
	// A to offset of types.OffsetNewest means no upper limit.
	// Returns the offset of the latest key change.
	KeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error)
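To make the offset semantics concrete, a small usage sketch follows: fromOffset is exclusive, toOffset is inclusive, and types.OffsetNewest removes the upper bound entirely. The concrete offset and the processKeyChanges helper are invented for the example.

userIDs, latestOffset, err := db.KeyChanges(ctx, 42, types.OffsetNewest)
if err != nil {
	return err
}
// userIDs covers offsets 43 up to latestOffset inclusive; pass latestOffset
// back as fromOffset next time to continue from where this call stopped.
processKeyChanges(userIDs)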
|
||||
|
||||
|
|
|
|||
|
|
@ -17,9 +17,7 @@ package postgres
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"math"
|
||||
|
||||
"github.com/Shopify/sarama"
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
||||
)
|
||||
|
|
@ -78,9 +76,6 @@ func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, userID strin
|
|||
func (s *keyChangesStatements) SelectKeyChanges(
|
||||
ctx context.Context, fromOffset, toOffset int64,
|
||||
) (userIDs []string, latestOffset int64, err error) {
|
||||
if toOffset == sarama.OffsetNewest {
|
||||
toOffset = math.MaxInt64
|
||||
}
|
||||
latestOffset = fromOffset
|
||||
rows, err := s.selectKeyChangesStmt.QueryContext(ctx, fromOffset, toOffset)
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -17,9 +17,7 @@ package sqlite3
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"math"
|
||||
|
||||
"github.com/Shopify/sarama"
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
||||
)
|
||||
|
|
@ -76,9 +74,6 @@ func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, userID strin
|
|||
func (s *keyChangesStatements) SelectKeyChanges(
|
||||
ctx context.Context, fromOffset, toOffset int64,
|
||||
) (userIDs []string, latestOffset int64, err error) {
|
||||
if toOffset == sarama.OffsetNewest {
|
||||
toOffset = math.MaxInt64
|
||||
}
|
||||
latestOffset = fromOffset
|
||||
rows, err := s.selectKeyChangesStmt.QueryContext(ctx, fromOffset, toOffset)
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -9,8 +9,8 @@ import (
|
|||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/Shopify/sarama"
|
||||
"github.com/matrix-org/dendrite/keyserver/api"
|
||||
"github.com/matrix-org/dendrite/keyserver/types"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
)
|
||||
|
||||
|
|
@ -50,7 +50,7 @@ func TestKeyChanges(t *testing.T) {
|
|||
MustNotError(t, err)
|
||||
deviceChangeIDC, err := db.StoreKeyChange(ctx, "@charlie:localhost")
|
||||
MustNotError(t, err)
|
||||
userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDB, sarama.OffsetNewest)
|
||||
userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDB, types.OffsetNewest)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to KeyChanges: %s", err)
|
||||
}
|
||||
|
|
@ -74,7 +74,7 @@ func TestKeyChangesNoDupes(t *testing.T) {
|
|||
}
|
||||
deviceChangeID, err := db.StoreKeyChange(ctx, "@alice:localhost")
|
||||
MustNotError(t, err)
|
||||
userIDs, latest, err := db.KeyChanges(ctx, 0, sarama.OffsetNewest)
|
||||
userIDs, latest, err := db.KeyChanges(ctx, 0, types.OffsetNewest)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to KeyChanges: %s", err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ type DeviceKeys interface {
|
|||
type KeyChanges interface {
|
||||
InsertKeyChange(ctx context.Context, userID string) (int64, error)
|
||||
// SelectKeyChanges returns the set (de-duplicated) of users who have changed their keys between the two offsets.
|
||||
// Results are exclusive of fromOffset and inclusive of toOffset. A toOffset of sarama.OffsetNewest means no upper offset.
|
||||
// Results are exclusive of fromOffset and inclusive of toOffset. A toOffset of types.OffsetNewest means no upper offset.
|
||||
SelectKeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error)
|
||||
|
||||
Prepare() error
|
||||
|
|
|
|||
|
|
@ -14,7 +14,18 @@

package types

import "github.com/matrix-org/gomatrixserverlib"
import (
	"math"

	"github.com/matrix-org/gomatrixserverlib"
)

const (
	// OffsetNewest tells e.g. the database to get the most current data
	OffsetNewest int64 = math.MaxInt64
	// OffsetOldest tells e.g. the database to get the oldest data
	OffsetOldest int64 = 0
)

// KeyTypePurposeToInt maps a purpose to an integer, which is used in the
// database to reduce the amount of space taken up by this column.
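The choice of math.MaxInt64 for OffsetNewest is what lets the postgres and sqlite SelectKeyChanges implementations above drop their special case: sarama's sentinel is negative (sarama.OffsetNewest is -1 in that library), so it had to be rewritten before use, whereas the new sentinel already works as a direct upper bound. The deleted guard is quoted below for contrast.

// Previously needed because sarama.OffsetNewest (-1) cannot be used as an upper bound:
if toOffset == sarama.OffsetNewest {
	toOffset = math.MaxInt64
}
// types.OffsetNewest is already math.MaxInt64, so it can be passed straight into
// the "<= toOffset" comparison with no special casing.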
|
||||
|
|
|
|||
|
|
@ -42,6 +42,19 @@ const (
|
|||
KindOld
|
||||
)
|
||||
|
||||
func (k Kind) String() string {
|
||||
switch k {
|
||||
case KindOutlier:
|
||||
return "KindOutlier"
|
||||
case KindNew:
|
||||
return "KindNew"
|
||||
case KindOld:
|
||||
return "KindOld"
|
||||
default:
|
||||
return "(unknown)"
|
||||
}
|
||||
}
|
||||
|
||||
// DoNotSendToOtherServers tells us not to send the event to other matrix
|
||||
// servers.
|
||||
const DoNotSendToOtherServers = ""
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@ type RoomserverInternalAPI struct {
|
|||
fsAPI fsAPI.FederationInternalAPI
|
||||
asAPI asAPI.AppServiceQueryAPI
|
||||
JetStream nats.JetStreamContext
|
||||
Durable nats.SubOpt
|
||||
Durable string
|
||||
InputRoomEventTopic string // JetStream topic for new input room events
|
||||
OutputRoomEventTopic string // JetStream topic for new output room events
|
||||
PerspectiveServerNames []gomatrixserverlib.ServerName
|
||||
|
|
@ -87,7 +87,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.FederationInternalA
|
|||
InputRoomEventTopic: r.InputRoomEventTopic,
|
||||
OutputRoomEventTopic: r.OutputRoomEventTopic,
|
||||
JetStream: r.JetStream,
|
||||
Durable: r.Durable,
|
||||
Durable: nats.Durable(r.Durable),
|
||||
ServerName: r.Cfg.Matrix.ServerName,
|
||||
FSAPI: fsAPI,
|
||||
KeyRing: keyRing,
|
||||
|
|
|
|||
|
|
@ -20,17 +20,22 @@ import (
|
|||
"sort"
|
||||
|
||||
"github.com/matrix-org/dendrite/roomserver/state"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage"
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
type checkForAuthAndSoftFailStorage interface {
|
||||
state.StateResolutionStorage
|
||||
StateEntriesForEventIDs(ctx context.Context, eventIDs []string) ([]types.StateEntry, error)
|
||||
RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error)
|
||||
}
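The new checkForAuthAndSoftFailStorage interface narrows what these helpers require: anything providing state resolution storage plus the two lookups shown can be passed, so the same functions work both against the full roomserver database and against the per-room updater used later in this change. A sketch of what that enables, assuming both types satisfy the interface; the variable names are illustrative.

// Outside a transaction, the full database can still be used:
softFailed, err := helpers.CheckForSoftFail(ctx, roomserverDB, headeredEvent, stateEventIDs)

// Inside the input path, the room updater (a transaction handle) is passed instead:
softFailed, err = helpers.CheckForSoftFail(ctx, updater, headeredEvent, stateEventIDs)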
|
||||
|
||||
// CheckForSoftFail returns true if the event should be soft-failed
|
||||
// and false otherwise. The return error value should be checked before
|
||||
// the soft-fail bool.
|
||||
func CheckForSoftFail(
|
||||
ctx context.Context,
|
||||
db storage.Database,
|
||||
db checkForAuthAndSoftFailStorage,
|
||||
event *gomatrixserverlib.HeaderedEvent,
|
||||
stateEventIDs []string,
|
||||
) (bool, error) {
|
||||
|
|
@ -92,7 +97,7 @@ func CheckForSoftFail(
|
|||
// Returns the numeric IDs for the auth events.
|
||||
func CheckAuthEvents(
|
||||
ctx context.Context,
|
||||
db storage.Database,
|
||||
db checkForAuthAndSoftFailStorage,
|
||||
event *gomatrixserverlib.HeaderedEvent,
|
||||
authEventIDs []string,
|
||||
) ([]types.EventNID, error) {
|
||||
|
|
@ -193,7 +198,7 @@ func (ae *authEvents) lookupEvent(typeNID types.EventTypeNID, stateKey string) *
|
|||
// loadAuthEvents loads the events needed for authentication from the supplied room state.
|
||||
func loadAuthEvents(
|
||||
ctx context.Context,
|
||||
db storage.Database,
|
||||
db state.StateResolutionStorage,
|
||||
needed gomatrixserverlib.StateNeeded,
|
||||
state []types.StateEntry,
|
||||
) (result authEvents, err error) {
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ import (
|
|||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
|
|
@ -38,6 +39,19 @@ import (
|
|||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
type retryAction int
|
||||
type commitAction int
|
||||
|
||||
const (
|
||||
doNotRetry retryAction = iota
|
||||
retryLater
|
||||
)
|
||||
|
||||
const (
|
||||
commitTransaction commitAction = iota
|
||||
rollbackTransaction
|
||||
)
|
||||
|
||||
var keyContentFields = map[string]string{
|
||||
"m.room.join_rules": "join_rule",
|
||||
"m.room.history_visibility": "history_visibility",
|
||||
|
|
@ -101,7 +115,8 @@ func (r *Inputer) Start() error {
|
|||
_ = msg.InProgress() // resets the acknowledgement wait timer
|
||||
defer eventsInProgress.Delete(index)
|
||||
defer roomserverInputBackpressure.With(prometheus.Labels{"room_id": roomID}).Dec()
|
||||
if err := r.processRoomEvent(context.Background(), &inputRoomEvent); err != nil {
|
||||
action, err := r.processRoomEventUsingUpdater(context.Background(), roomID, &inputRoomEvent)
|
||||
if err != nil {
|
||||
if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) {
|
||||
sentry.CaptureException(err)
|
||||
}
|
||||
|
|
@ -111,7 +126,12 @@ func (r *Inputer) Start() error {
|
|||
"type": inputRoomEvent.Event.Type(),
|
||||
}).Warn("Roomserver failed to process async event")
|
||||
}
|
||||
_ = msg.Ack()
|
||||
switch action {
|
||||
case retryLater:
|
||||
_ = msg.Nak()
|
||||
case doNotRetry:
|
||||
_ = msg.Ack()
|
||||
}
|
||||
})
|
||||
},
|
||||
// NATS wants to acknowledge automatically by default when the message is
|
||||
|
|
@ -131,6 +151,37 @@ func (r *Inputer) Start() error {
|
|||
return err
|
||||
}
|
||||
|
||||
// processRoomEventUsingUpdater opens up a room updater and tries to
|
||||
// process the event. It returns whether or not we should positively
|
||||
// or negatively acknowledge the event (i.e. for NATS) and an error
|
||||
// if it occurred.
|
||||
func (r *Inputer) processRoomEventUsingUpdater(
|
||||
ctx context.Context,
|
||||
roomID string,
|
||||
inputRoomEvent *api.InputRoomEvent,
|
||||
) (retryAction, error) {
|
||||
roomInfo, err := r.DB.RoomInfo(ctx, roomID)
|
||||
if err != nil {
|
||||
return doNotRetry, fmt.Errorf("r.DB.RoomInfo: %w", err)
|
||||
}
|
||||
updater, err := r.DB.GetRoomUpdater(ctx, roomInfo)
|
||||
if err != nil {
|
||||
return retryLater, fmt.Errorf("r.DB.GetRoomUpdater: %w", err)
|
||||
}
|
||||
action, err := r.processRoomEvent(ctx, updater, inputRoomEvent)
|
||||
switch action {
|
||||
case commitTransaction:
|
||||
if cerr := updater.Commit(); cerr != nil {
|
||||
return retryLater, fmt.Errorf("updater.Commit: %w", cerr)
|
||||
}
|
||||
case rollbackTransaction:
|
||||
if rerr := updater.Rollback(); rerr != nil {
|
||||
return retryLater, fmt.Errorf("updater.Rollback: %w", rerr)
|
||||
}
|
||||
}
|
||||
return doNotRetry, err
|
||||
}
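The function above fixes the commit/rollback policy in one place. A generic version of the same pattern is sketched below purely to make the contract explicit; the real code inlines it rather than using a helper like this, and the helper name is hypothetical.

// withRoomUpdater opens a room updater, runs fn inside it, and then commits or
// rolls back based on the action fn reports, returning fn's error either way.
func withRoomUpdater(
	ctx context.Context, db storage.Database, roomInfo *types.RoomInfo,
	fn func(*shared.RoomUpdater) (commitAction, error),
) error {
	updater, err := db.GetRoomUpdater(ctx, roomInfo)
	if err != nil {
		return fmt.Errorf("db.GetRoomUpdater: %w", err)
	}
	action, ferr := fn(updater)
	switch action {
	case commitTransaction:
		if cerr := updater.Commit(); cerr != nil {
			return fmt.Errorf("updater.Commit: %w", cerr)
		}
	case rollbackTransaction:
		if rerr := updater.Rollback(); rerr != nil {
			return fmt.Errorf("updater.Rollback: %w", rerr)
		}
	}
	return ferr
}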
|
||||
|
||||
// InputRoomEvents implements api.RoomserverInternalAPI
|
||||
func (r *Inputer) InputRoomEvents(
|
||||
ctx context.Context,
|
||||
|
|
@ -161,7 +212,6 @@ func (r *Inputer) InputRoomEvents(
|
|||
}
|
||||
} else {
|
||||
responses := make(chan error, len(request.InputRoomEvents))
|
||||
defer close(responses)
|
||||
for _, e := range request.InputRoomEvents {
|
||||
inputRoomEvent := e
|
||||
roomID := inputRoomEvent.Event.RoomID()
|
||||
|
|
@ -178,7 +228,7 @@ func (r *Inputer) InputRoomEvents(
|
|||
worker.Act(nil, func() {
|
||||
defer eventsInProgress.Delete(index)
|
||||
defer roomserverInputBackpressure.With(prometheus.Labels{"room_id": roomID}).Dec()
|
||||
err := r.processRoomEvent(ctx, &inputRoomEvent)
|
||||
_, err := r.processRoomEventUsingUpdater(ctx, roomID, &inputRoomEvent)
|
||||
if err != nil {
|
||||
if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) {
|
||||
sentry.CaptureException(err)
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/dendrite/roomserver/internal/helpers"
|
||||
"github.com/matrix-org/dendrite/roomserver/state"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
|
|
@ -67,14 +68,15 @@ var processRoomEventDuration = prometheus.NewHistogramVec(
|
|||
// nolint:gocyclo
|
||||
func (r *Inputer) processRoomEvent(
|
||||
ctx context.Context,
|
||||
updater *shared.RoomUpdater,
|
||||
input *api.InputRoomEvent,
|
||||
) (err error) {
|
||||
) (commitAction, error) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// Before we do anything, make sure the context hasn't expired for this pending task.
|
||||
// If it has then we'll give up straight away — it's probably a synchronous input
|
||||
// request and the caller has already given up, but the inbox task was still queued.
|
||||
return context.DeadlineExceeded
|
||||
return rollbackTransaction, context.DeadlineExceeded
|
||||
default:
|
||||
}
|
||||
|
||||
|
|
@ -93,13 +95,21 @@ func (r *Inputer) processRoomEvent(
|
|||
logger := util.GetLogger(ctx).WithFields(logrus.Fields{
|
||||
"event_id": event.EventID(),
|
||||
"room_id": event.RoomID(),
|
||||
"kind": input.Kind,
|
||||
"origin": input.Origin,
|
||||
"type": event.Type(),
|
||||
})
|
||||
if input.HasState {
|
||||
logger = logger.WithFields(logrus.Fields{
|
||||
"has_state": input.HasState,
|
||||
"state_ids": len(input.StateEventIDs),
|
||||
})
|
||||
}
|
||||
|
||||
// if we have already got this event then do not process it again, if the input kind is an outlier.
|
||||
// Outliers contain no extra information which may warrant a re-processing.
|
||||
if input.Kind == api.KindOutlier {
|
||||
evs, err2 := r.DB.EventsFromIDs(ctx, []string{event.EventID()})
|
||||
evs, err2 := updater.EventsFromIDs(ctx, []string{event.EventID()})
|
||||
if err2 == nil && len(evs) == 1 {
|
||||
// check hash matches if we're on early room versions where the event ID was a random string
|
||||
idFormat, err2 := headered.RoomVersion.EventIDFormat()
|
||||
|
|
@ -108,11 +118,11 @@ func (r *Inputer) processRoomEvent(
|
|||
case gomatrixserverlib.EventIDFormatV1:
|
||||
if bytes.Equal(event.EventReference().EventSHA256, evs[0].EventReference().EventSHA256) {
|
||||
logger.Debugf("Already processed event; ignoring")
|
||||
return nil
|
||||
return rollbackTransaction, nil
|
||||
}
|
||||
default:
|
||||
logger.Debugf("Already processed event; ignoring")
|
||||
return nil
|
||||
return rollbackTransaction, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -126,21 +136,41 @@ func (r *Inputer) processRoomEvent(
|
|||
AuthEventIDs: event.AuthEventIDs(),
|
||||
PrevEventIDs: event.PrevEventIDs(),
|
||||
}
|
||||
if err = r.Queryer.QueryMissingAuthPrevEvents(ctx, missingReq, missingRes); err != nil {
|
||||
return fmt.Errorf("r.Queryer.QueryMissingAuthPrevEvents: %w", err)
|
||||
if err := r.Queryer.QueryMissingAuthPrevEvents(ctx, missingReq, missingRes); err != nil {
|
||||
return rollbackTransaction, fmt.Errorf("r.Queryer.QueryMissingAuthPrevEvents: %w", err)
|
||||
}
|
||||
}
|
||||
if len(missingRes.MissingAuthEventIDs) > 0 || len(missingRes.MissingPrevEventIDs) > 0 {
|
||||
missingAuth := len(missingRes.MissingAuthEventIDs) > 0
|
||||
missingPrev := !input.HasState && len(missingRes.MissingPrevEventIDs) > 0
|
||||
|
||||
if missingAuth || missingPrev {
|
||||
serverReq := &fedapi.QueryJoinedHostServerNamesInRoomRequest{
|
||||
RoomID: event.RoomID(),
|
||||
ExcludeSelf: true,
|
||||
}
|
||||
if err = r.FSAPI.QueryJoinedHostServerNamesInRoom(ctx, serverReq, serverRes); err != nil {
|
||||
return fmt.Errorf("r.FSAPI.QueryJoinedHostServerNamesInRoom: %w", err)
|
||||
if err := r.FSAPI.QueryJoinedHostServerNamesInRoom(ctx, serverReq, serverRes); err != nil {
|
||||
return rollbackTransaction, fmt.Errorf("r.FSAPI.QueryJoinedHostServerNamesInRoom: %w", err)
|
||||
}
|
||||
// Gather all of the servers into a map so that their order is
// effectively randomised. Then make sure that the input origin and the
// event origin are first on the list.
|
||||
servers := map[gomatrixserverlib.ServerName]struct{}{}
|
||||
for _, server := range serverRes.ServerNames {
|
||||
servers[server] = struct{}{}
|
||||
}
|
||||
serverRes.ServerNames = serverRes.ServerNames[:0]
|
||||
if input.Origin != "" {
|
||||
serverRes.ServerNames = append(serverRes.ServerNames, input.Origin)
|
||||
delete(servers, input.Origin)
|
||||
}
|
||||
if origin := event.Origin(); origin != input.Origin {
|
||||
serverRes.ServerNames = append(serverRes.ServerNames, origin)
|
||||
delete(servers, origin)
|
||||
}
|
||||
for server := range servers {
|
||||
serverRes.ServerNames = append(serverRes.ServerNames, server)
|
||||
delete(servers, server)
|
||||
}
|
||||
}
|
||||
if input.Origin != "" {
|
||||
serverRes.ServerNames = append(serverRes.ServerNames, input.Origin)
|
||||
}
|
||||
|
||||
// First of all, check that the auth events of the event are known.
|
||||
|
|
@ -148,8 +178,8 @@ func (r *Inputer) processRoomEvent(
|
|||
isRejected := false
|
||||
authEvents := gomatrixserverlib.NewAuthEvents(nil)
|
||||
knownEvents := map[string]*types.Event{}
|
||||
if err = r.fetchAuthEvents(ctx, logger, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil {
|
||||
return fmt.Errorf("r.checkForMissingAuthEvents: %w", err)
|
||||
if err := r.fetchAuthEvents(ctx, updater, logger, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil {
|
||||
return rollbackTransaction, fmt.Errorf("r.fetchAuthEvents: %w", err)
|
||||
}
|
||||
|
||||
// Check if the event is allowed by its auth events. If it isn't then
|
||||
|
|
@ -157,7 +187,7 @@ func (r *Inputer) processRoomEvent(
|
|||
var rejectionErr error
|
||||
if rejectionErr = gomatrixserverlib.Allowed(event, &authEvents); rejectionErr != nil {
|
||||
isRejected = true
|
||||
logger.WithError(rejectionErr).Warnf("Event %s rejected", event.EventID())
|
||||
logger.WithError(rejectionErr).Warnf("Event %s not allowed by auth events", event.EventID())
|
||||
}
|
||||
|
||||
// Accumulate the auth event NIDs.
|
||||
|
|
@ -165,7 +195,7 @@ func (r *Inputer) processRoomEvent(
|
|||
authEventNIDs := make([]types.EventNID, 0, len(authEventIDs))
|
||||
for _, authEventID := range authEventIDs {
|
||||
if _, ok := knownEvents[authEventID]; !ok {
|
||||
return fmt.Errorf("missing auth event %s", authEventID)
|
||||
return rollbackTransaction, fmt.Errorf("missing auth event %s", authEventID)
|
||||
}
|
||||
authEventNIDs = append(authEventNIDs, knownEvents[authEventID].EventNID)
|
||||
}
|
||||
|
|
@ -174,9 +204,10 @@ func (r *Inputer) processRoomEvent(
|
|||
if input.Kind == api.KindNew {
|
||||
// Check that the event passes authentication checks based on the
|
||||
// current room state.
|
||||
softfail, err = helpers.CheckForSoftFail(ctx, r.DB, headered, input.StateEventIDs)
|
||||
var err error
|
||||
softfail, err = helpers.CheckForSoftFail(ctx, updater, headered, input.StateEventIDs)
|
||||
if err != nil {
|
||||
logger.WithError(err).Info("Error authing soft-failed event")
|
||||
logger.WithError(err).Warn("Error authing soft-failed event")
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -190,7 +221,6 @@ func (r *Inputer) processRoomEvent(
|
|||
// typical federated room join) then we won't bother trying to fetch prev events
|
||||
// because we may not be allowed to see them and we have no choice but to trust
|
||||
// the state event IDs provided to us in the join instead.
|
||||
missingPrev := !input.HasState && len(missingRes.MissingPrevEventIDs) > 0
|
||||
if missingPrev && input.Kind == api.KindNew {
|
||||
// Don't do this for KindOld events, otherwise old events that we fetch
|
||||
// to satisfy missing prev events/state will end up recursively calling
|
||||
|
|
@ -200,18 +230,15 @@ func (r *Inputer) processRoomEvent(
|
|||
origin: input.Origin,
|
||||
inputer: r,
|
||||
queryer: r.Queryer,
|
||||
db: r.DB,
|
||||
db: updater,
|
||||
federation: r.FSAPI,
|
||||
keys: r.KeyRing,
|
||||
roomsMu: internal.NewMutexByRoom(),
|
||||
servers: map[gomatrixserverlib.ServerName]struct{}{},
|
||||
servers: serverRes.ServerNames,
|
||||
hadEvents: map[string]bool{},
|
||||
haveEvents: map[string]*gomatrixserverlib.HeaderedEvent{},
|
||||
}
|
||||
for _, serverName := range serverRes.ServerNames {
|
||||
missingState.servers[serverName] = struct{}{}
|
||||
}
|
||||
if err = missingState.processEventWithMissingState(ctx, event, headered.RoomVersion); err != nil {
|
||||
if err := missingState.processEventWithMissingState(ctx, event, headered.RoomVersion); err != nil {
|
||||
isRejected = true
|
||||
rejectionErr = fmt.Errorf("missingState.processEventWithMissingState: %w", err)
|
||||
} else {
|
||||
|
|
@ -224,16 +251,16 @@ func (r *Inputer) processRoomEvent(
|
|||
}
|
||||
|
||||
// Store the event.
|
||||
_, _, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, authEventNIDs, isRejected)
|
||||
_, _, stateAtEvent, redactionEvent, redactedEventID, err := updater.StoreEvent(ctx, event, authEventNIDs, isRejected)
|
||||
if err != nil {
|
||||
return fmt.Errorf("r.DB.StoreEvent: %w", err)
|
||||
return rollbackTransaction, fmt.Errorf("updater.StoreEvent: %w", err)
|
||||
}
|
||||
|
||||
// if storing this event results in it being redacted then do so.
|
||||
if !isRejected && redactedEventID == event.EventID() {
|
||||
r, rerr := eventutil.RedactEvent(redactionEvent, event)
|
||||
if rerr != nil {
|
||||
return fmt.Errorf("eventutil.RedactEvent: %w", rerr)
|
||||
return rollbackTransaction, fmt.Errorf("eventutil.RedactEvent: %w", rerr)
|
||||
}
|
||||
event = r
|
||||
}
|
||||
|
|
@ -244,36 +271,40 @@ func (r *Inputer) processRoomEvent(
|
|||
if input.Kind == api.KindOutlier {
|
||||
logger.Debug("Stored outlier")
|
||||
hooks.Run(hooks.KindNewEventPersisted, headered)
|
||||
return nil
|
||||
return commitTransaction, nil
|
||||
}
|
||||
|
||||
roomInfo, err := r.DB.RoomInfo(ctx, event.RoomID())
|
||||
roomInfo, err := updater.RoomInfo(ctx, event.RoomID())
|
||||
if err != nil {
|
||||
return fmt.Errorf("r.DB.RoomInfo: %w", err)
|
||||
return rollbackTransaction, fmt.Errorf("updater.RoomInfo: %w", err)
}
if roomInfo == nil {
return fmt.Errorf("r.DB.RoomInfo missing for room %s", event.RoomID())
return rollbackTransaction, fmt.Errorf("updater.RoomInfo missing for room %s", event.RoomID())
}

if !missingPrev && stateAtEvent.BeforeStateSnapshotNID == 0 {
// We haven't calculated a state for this event yet.
// Let's calculate one.
err = r.calculateAndSetState(ctx, input, roomInfo, &stateAtEvent, event, isRejected)
err = r.calculateAndSetState(ctx, updater, input, roomInfo, &stateAtEvent, event, isRejected)
if err != nil {
return fmt.Errorf("r.calculateAndSetState: %w", err)
return rollbackTransaction, fmt.Errorf("r.calculateAndSetState: %w", err)
}
}

// We stop here if the event is rejected: We've stored it but won't update forward extremities or notify anyone about it.
if isRejected || softfail {
logger.WithError(rejectionErr).WithField("soft_fail", softfail).Debug("Stored rejected event")
return rejectionErr
logger.WithError(rejectionErr).WithFields(logrus.Fields{
"soft_fail": softfail,
"missing_prev": missingPrev,
}).Warn("Stored rejected event")
return commitTransaction, rejectionErr
}

switch input.Kind {
case api.KindNew:
if err = r.updateLatestEvents(
ctx, // context
updater, // room updater
roomInfo, // room info for the room being updated
stateAtEvent, // state at event (below)
event, // event

@ -281,7 +312,7 @@ func (r *Inputer) processRoomEvent(
input.TransactionID, // transaction ID
input.HasState, // rewrites state?
); err != nil {
return fmt.Errorf("r.updateLatestEvents: %w", err)
return rollbackTransaction, fmt.Errorf("r.updateLatestEvents: %w", err)
}
case api.KindOld:
err = r.WriteOutputEvents(event.RoomID(), []api.OutputEvent{

@ -293,7 +324,7 @@ func (r *Inputer) processRoomEvent(
},
})
if err != nil {
return fmt.Errorf("r.WriteOutputEvents (old): %w", err)
return rollbackTransaction, fmt.Errorf("r.WriteOutputEvents (old): %w", err)
}
}

@ -312,14 +343,14 @@ func (r *Inputer) processRoomEvent(
},
})
if err != nil {
return fmt.Errorf("r.WriteOutputEvents (redactions): %w", err)
return rollbackTransaction, fmt.Errorf("r.WriteOutputEvents (redactions): %w", err)
}
}

// Everything was OK — the latest events updater didn't error and
// we've sent output events. Finally, generate a hook call.
hooks.Run(hooks.KindNewEventPersisted, headered)
return nil
return commitTransaction, nil
}

// fetchAuthEvents will check to see if any of the

@ -331,16 +362,13 @@ func (r *Inputer) processRoomEvent(
// they are now in the database.
func (r *Inputer) fetchAuthEvents(
ctx context.Context,
updater *shared.RoomUpdater,
logger *logrus.Entry,
event *gomatrixserverlib.HeaderedEvent,
auth *gomatrixserverlib.AuthEvents,
known map[string]*types.Event,
servers []gomatrixserverlib.ServerName,
) error {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, MaximumMissingProcessingTime)
defer cancel()

unknown := map[string]struct{}{}
authEventIDs := event.AuthEventIDs()
if len(authEventIDs) == 0 {

@ -348,7 +376,7 @@ func (r *Inputer) fetchAuthEvents(
}

for _, authEventID := range authEventIDs {
authEvents, err := r.DB.EventsFromIDs(ctx, []string{authEventID})
authEvents, err := updater.EventsFromIDs(ctx, []string{authEventID})
if err != nil || len(authEvents) == 0 || authEvents[0].Event == nil {
unknown[authEventID] = struct{}{}
continue

@ -396,12 +424,11 @@ func (r *Inputer) fetchAuthEvents(
continue
}

// Check the signatures of the event.
// TODO: It really makes sense for the federation API to be doing this,
// because then it can attempt another server if one serves up an event
// with an invalid signature. For now this will do.
// Check the signatures of the event. If this fails then we'll simply
// skip it, because gomatrixserverlib.Allowed() will notice a problem
// if a critical event is missing anyway.
if err := authEvent.VerifyEventSignatures(ctx, r.FSAPI.KeyRing()); err != nil {
return fmt.Errorf("event.VerifyEventSignatures: %w", err)
continue
}

// In order to store the new auth event, we need to know its auth chain

@ -428,9 +455,9 @@ func (r *Inputer) fetchAuthEvents(
}

// Finally, store the event in the database.
eventNID, _, _, _, _, err := r.DB.StoreEvent(ctx, authEvent, authEventNIDs, isRejected)
eventNID, _, _, _, _, err := updater.StoreEvent(ctx, authEvent, authEventNIDs, isRejected)
if err != nil {
return fmt.Errorf("r.DB.StoreEvent: %w", err)
return fmt.Errorf("updater.StoreEvent: %w", err)
}

// Now we know about this event, it was stored and the signatures were OK.

@ -445,6 +472,7 @@ func (r *Inputer) fetchAuthEvents(

func (r *Inputer) calculateAndSetState(
ctx context.Context,
updater *shared.RoomUpdater,
input *api.InputRoomEvent,
roomInfo *types.RoomInfo,
stateAtEvent *types.StateAtEvent,

@ -452,14 +480,14 @@ func (r *Inputer) calculateAndSetState(
isRejected bool,
) error {
var err error
roomState := state.NewStateResolution(r.DB, roomInfo)
roomState := state.NewStateResolution(updater, roomInfo)

if input.HasState && !isRejected {
if input.HasState {
// Check here if we think we're in the room already.
stateAtEvent.Overwrite = true
var joinEventNIDs []types.EventNID
// Request join memberships for local users only.
if joinEventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true); err == nil {
if joinEventNIDs, err = updater.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true); err == nil {
// If we have no local users that are joined to the room then any state about
// the room that we have is quite possibly out of date. Therefore in that case
// we should overwrite it rather than merge it.

@ -469,13 +497,13 @@ func (r *Inputer) calculateAndSetState(
// We've been told what the state at the event is so we don't need to calculate it.
// Check that those state events are in the database and store the state.
var entries []types.StateEntry
if entries, err = r.DB.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil {
return fmt.Errorf("r.DB.StateEntriesForEventIDs: %w", err)
if entries, err = updater.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil {
return fmt.Errorf("updater.StateEntriesForEventIDs: %w", err)
}
entries = types.DeduplicateStateEntries(entries)

if stateAtEvent.BeforeStateSnapshotNID, err = r.DB.AddState(ctx, roomInfo.RoomNID, nil, entries); err != nil {
return fmt.Errorf("r.DB.AddState: %w", err)
if stateAtEvent.BeforeStateSnapshotNID, err = updater.AddState(ctx, roomInfo.RoomNID, nil, entries); err != nil {
return fmt.Errorf("updater.AddState: %w", err)
}
} else {
stateAtEvent.Overwrite = false

@ -486,7 +514,7 @@ func (r *Inputer) calculateAndSetState(
}
}

err = r.DB.SetState(ctx, stateAtEvent.EventNID, stateAtEvent.BeforeStateSnapshotNID)
err = updater.SetState(ctx, stateAtEvent.EventNID, stateAtEvent.BeforeStateSnapshotNID)
if err != nil {
return fmt.Errorf("r.DB.SetState: %w", err)
}
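The hunks above change processRoomEvent so that it no longer decides commit/rollback internally; it now returns a flag telling the caller whether the room updater's transaction should be committed or rolled back. The following standalone sketch illustrates that signalling pattern only; the type and function names (roomUpdater, processEvent) are invented for illustration and are not the real roomserver code.

// Sketch only: shows the commit-or-rollback signalling pattern used above.
package main

import "fmt"

const (
	commitTransaction   = true
	rollbackTransaction = false
)

type roomUpdater struct{ committed bool }

func (u *roomUpdater) Commit() error   { u.committed = true; return nil }
func (u *roomUpdater) Rollback() error { return nil }

// processEvent does some work and tells the caller whether the updater's
// transaction should be committed or rolled back.
func processEvent(u *roomUpdater, fail bool) (bool, error) {
	if fail {
		return rollbackTransaction, fmt.Errorf("storing event failed")
	}
	return commitTransaction, nil
}

func main() {
	u := &roomUpdater{}
	commit, err := processEvent(u, false)
	if commit {
		_ = u.Commit()
	} else {
		_ = u.Rollback()
	}
	fmt.Println("committed:", u.committed, "err:", err)
}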
@ -20,7 +20,6 @@ import (
"context"
"fmt"

"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/state"
"github.com/matrix-org/dendrite/roomserver/storage/shared"

@ -48,6 +47,7 @@ import (
// Can only be called once at a time
func (r *Inputer) updateLatestEvents(
ctx context.Context,
updater *shared.RoomUpdater,
roomInfo *types.RoomInfo,
stateAtEvent types.StateAtEvent,
event *gomatrixserverlib.Event,

@ -55,13 +55,6 @@ func (r *Inputer) updateLatestEvents(
transactionID *api.TransactionID,
rewritesState bool,
) (err error) {
updater, err := r.DB.GetLatestEventsForUpdate(ctx, *roomInfo)
if err != nil {
return fmt.Errorf("r.DB.GetLatestEventsForUpdate: %w", err)
}
succeeded := false
defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err)

u := latestEventsUpdater{
ctx: ctx,
api: r,

@ -78,7 +71,6 @@ func (r *Inputer) updateLatestEvents(
return fmt.Errorf("u.doUpdateLatestEvents: %w", err)
}

succeeded = true
return
}

@ -89,7 +81,7 @@ func (r *Inputer) updateLatestEvents(
type latestEventsUpdater struct {
ctx context.Context
api *Inputer
updater *shared.LatestEventsUpdater
updater *shared.RoomUpdater
roomInfo *types.RoomInfo
stateAtEvent types.StateAtEvent
event *gomatrixserverlib.Event

@ -199,7 +191,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error {

func (u *latestEventsUpdater) latestState() error {
var err error
roomState := state.NewStateResolution(u.api.DB, u.roomInfo)
roomState := state.NewStateResolution(u.updater, u.roomInfo)

// Work out if the state at the extremities has actually changed
// or not. If they haven't then we won't bother doing all of the

@ -413,7 +405,7 @@ func (u *latestEventsUpdater) extraEventsForIDs(roomVersion gomatrixserverlib.Ro
if len(extraEventIDs) == 0 {
return nil, nil
}
extraEvents, err := u.api.DB.EventsFromIDs(u.ctx, extraEventIDs)
extraEvents, err := u.updater.EventsFromIDs(u.ctx, extraEventIDs)
if err != nil {
return nil, err
}

@ -436,7 +428,7 @@ func (u *latestEventsUpdater) stateEventMap() (map[types.EventNID]string, error)
stateEventNIDs = append(stateEventNIDs, entry.EventNID)
}
stateEventNIDs = stateEventNIDs[:util.SortAndUnique(eventNIDSorter(stateEventNIDs))]
return u.api.DB.EventIDs(u.ctx, stateEventNIDs)
return u.updater.EventIDs(u.ctx, stateEventNIDs)
}

type eventNIDSorter []types.EventNID
@ -31,7 +31,7 @@ import (
// consumers about the invites added or retired by the change in current state.
func (r *Inputer) updateMemberships(
ctx context.Context,
updater *shared.LatestEventsUpdater,
updater *shared.RoomUpdater,
removed, added []types.StateEntry,
) ([]api.OutputEvent, error) {
changes := membershipChanges(removed, added)

@ -79,7 +79,7 @@ func (r *Inputer) updateMemberships(
}

func (r *Inputer) updateMembership(
updater *shared.LatestEventsUpdater,
updater *shared.RoomUpdater,
targetUserNID types.EventStateKeyNID,
remove, add *gomatrixserverlib.Event,
updates []api.OutputEvent,
@ -11,7 +11,7 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/internal/query"
"github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/sirupsen/logrus"

@ -19,13 +19,13 @@ import (

type missingStateReq struct {
origin gomatrixserverlib.ServerName
db storage.Database
db *shared.RoomUpdater
inputer *Inputer
queryer *query.Queryer
keys gomatrixserverlib.JSONVerifier
federation fedapi.FederationInternalAPI
roomsMu *internal.MutexByRoom
servers map[gomatrixserverlib.ServerName]struct{}
servers []gomatrixserverlib.ServerName
hadEvents map[string]bool
hadEventsMutex sync.Mutex
haveEvents map[string]*gomatrixserverlib.HeaderedEvent

@ -37,10 +37,6 @@ type missingStateReq struct {
func (t *missingStateReq) processEventWithMissingState(
ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion,
) error {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, MaximumMissingProcessingTime)
defer cancel()

// We are missing the previous events for this event.
// This means that there is a gap in our view of the history of the
// room. There are two ways that we can handle such a gap:

@ -78,7 +74,7 @@ func (t *missingStateReq) processEventWithMissingState(
// we can just inject all the newEvents as new as we may have only missed 1 or 2 events and have filled
// in the gap in the DAG
for _, newEvent := range newEvents {
err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{
_, err = t.inputer.processRoomEvent(ctx, t.db, &api.InputRoomEvent{
Kind: api.KindNew,
Event: newEvent.Headered(roomVersion),
Origin: t.origin,

@ -187,7 +183,7 @@ func (t *missingStateReq) processEventWithMissingState(
}
// TODO: we could do this concurrently?
for _, ire := range outlierRoomEvents {
if err = t.inputer.processRoomEvent(ctx, &ire); err != nil {
if _, err = t.inputer.processRoomEvent(ctx, t.db, &ire); err != nil {
return fmt.Errorf("t.inputer.processRoomEvent[outlier]: %w", err)
}
}

@ -200,7 +196,7 @@ func (t *missingStateReq) processEventWithMissingState(
stateIDs = append(stateIDs, event.EventID())
}

err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{
_, err = t.inputer.processRoomEvent(ctx, t.db, &api.InputRoomEvent{
Kind: api.KindOld,
Event: backwardsExtremity.Headered(roomVersion),
Origin: t.origin,

@ -217,7 +213,7 @@ func (t *missingStateReq) processEventWithMissingState(
// they will automatically fast-forward based on the room state at the
// extremity in the last step.
for _, newEvent := range newEvents {
err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{
_, err = t.inputer.processRoomEvent(ctx, t.db, &api.InputRoomEvent{
Kind: api.KindOld,
Event: newEvent.Headered(roomVersion),
Origin: t.origin,

@ -417,7 +413,7 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e *gomatrixserve
}

var missingResp *gomatrixserverlib.RespMissingEvents
for server := range t.servers {
for _, server := range t.servers {
var m gomatrixserverlib.RespMissingEvents
if m, err = t.federation.LookupMissingEvents(ctx, server, e.RoomID(), gomatrixserverlib.MissingEvents{
Limit: 20,

@ -666,7 +662,7 @@ func (t *missingStateReq) createRespStateFromStateIDs(stateIDs gomatrixserverlib
for i := range stateIDs.StateEventIDs {
ev, ok := t.haveEvents[stateIDs.StateEventIDs[i]]
if !ok {
logrus.Warnf("Missing state event in createRespStateFromStateIDs: %s", stateIDs.StateEventIDs[i])
logrus.Tracef("Missing state event in createRespStateFromStateIDs: %s", stateIDs.StateEventIDs[i])
continue
}
respState.StateEvents = append(respState.StateEvents, ev.Unwrap())

@ -674,7 +670,7 @@ func (t *missingStateReq) createRespStateFromStateIDs(stateIDs gomatrixserverlib
for i := range stateIDs.AuthEventIDs {
ev, ok := t.haveEvents[stateIDs.AuthEventIDs[i]]
if !ok {
logrus.Warnf("Missing auth event in createRespStateFromStateIDs: %s", stateIDs.AuthEventIDs[i])
logrus.Tracef("Missing auth event in createRespStateFromStateIDs: %s", stateIDs.AuthEventIDs[i])
continue
}
respState.AuthEvents = append(respState.AuthEvents, ev.Unwrap())

@ -700,7 +696,7 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs
}
var event *gomatrixserverlib.Event
found := false
for serverName := range t.servers {
for _, serverName := range t.servers {
reqctx, cancel := context.WithTimeout(ctx, time.Second*30)
defer cancel()
txn, err := t.federation.GetEvent(reqctx, serverName, missingEventID)

@ -718,7 +714,7 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs
}
event, err = gomatrixserverlib.NewEventFromUntrustedJSON(txn.PDUs[0], roomVersion)
if err != nil {
util.GetLogger(ctx).WithError(err).WithField("event_id", missingEventID).Warnf("Transaction: Failed to parse event JSON of event")
util.GetLogger(ctx).WithError(err).WithField("event_id", missingEventID).Warnf("Failed to parse event JSON of event returned from /event")
continue
}
found = true

@ -729,7 +725,7 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs
return nil, fmt.Errorf("wasn't able to find event via %d server(s)", len(t.servers))
}
if err := event.VerifyEventSignatures(ctx, t.keys); err != nil {
util.GetLogger(ctx).WithError(err).Warnf("Transaction: Couldn't validate signature of event %q", event.EventID())
util.GetLogger(ctx).WithError(err).Warnf("Couldn't validate signature of event %q from /event", event.EventID())
return nil, verifySigError{event.EventID(), err}
}
return t.cacheAndReturn(event.Headered(roomVersion)), nil
@ -27,6 +27,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
log "github.com/sirupsen/logrus"
)

@ -54,18 +55,23 @@ func (r *Inviter) PerformInvite(
return nil, fmt.Errorf("failed to load RoomInfo: %w", err)
}

log.WithFields(log.Fields{
"event_id": event.EventID(),
"room_id": roomID,
"room_version": req.RoomVersion,
"target_user_id": targetUserID,
"room_info_exists": info != nil,
}).Debug("processing invite event")

_, domain, _ := gomatrixserverlib.SplitID('@', targetUserID)
isTargetLocal := domain == r.Cfg.Matrix.ServerName
isOriginLocal := event.Origin() == r.Cfg.Matrix.ServerName

logger := util.GetLogger(ctx).WithFields(map[string]interface{}{
"inviter": event.Sender(),
"invitee": *event.StateKey(),
"room_id": roomID,
"event_id": event.EventID(),
})
logger.WithFields(log.Fields{
"room_version": req.RoomVersion,
"room_info_exists": info != nil,
"target_local": isTargetLocal,
"origin_local": isOriginLocal,
}).Debug("processing invite event")

inviteState := req.InviteRoomState
if len(inviteState) == 0 && info != nil {
var is []gomatrixserverlib.InviteV2StrippedState

@ -122,75 +128,17 @@ func (r *Inviter) PerformInvite(
Code: api.PerformErrorNotAllowed,
Msg: "User is already joined to room",
}
logger.Debugf("user already joined")
return nil, nil
}

if isOriginLocal {
// The invite originated locally. Therefore we have a responsibility to
// try and see if the user is allowed to make this invite. We can't do
// this for invites coming in over federation - we have to take those on
// trust.
_, err = helpers.CheckAuthEvents(ctx, r.DB, event, event.AuthEventIDs())
if err != nil {
log.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", event.AuthEventIDs()).Error(
"processInviteEvent.checkAuthEvents failed for event",
)
res.Error = &api.PerformError{
Msg: err.Error(),
Code: api.PerformErrorNotAllowed,
}
}

// If the invite originated from us and the target isn't local then we
// should try and send the invite over federation first. It might be
// that the remote user doesn't exist, in which case we can give up
// processing here.
if req.SendAsServer != api.DoNotSendToOtherServers && !isTargetLocal {
fsReq := &federationAPI.PerformInviteRequest{
RoomVersion: req.RoomVersion,
Event: event,
InviteRoomState: inviteState,
}
fsRes := &federationAPI.PerformInviteResponse{}
if err = r.FSAPI.PerformInvite(ctx, fsReq, fsRes); err != nil {
res.Error = &api.PerformError{
Msg: err.Error(),
Code: api.PerformErrorNotAllowed,
}
log.WithError(err).WithField("event_id", event.EventID()).Error("r.FSAPI.PerformInvite failed")
return nil, nil
}
event = fsRes.Event
}

// Send the invite event to the roomserver input stream. This will
// notify existing users in the room about the invite, update the
// membership table and ensure that the event is ready and available
// to use as an auth event when accepting the invite.
inputReq := &api.InputRoomEventsRequest{
InputRoomEvents: []api.InputRoomEvent{
{
Kind: api.KindNew,
Event: event,
Origin: event.Origin(),
SendAsServer: req.SendAsServer,
},
},
}
inputRes := &api.InputRoomEventsResponse{}
r.Inputer.InputRoomEvents(context.Background(), inputReq, inputRes)
if err = inputRes.Err(); err != nil {
res.Error = &api.PerformError{
Msg: fmt.Sprintf("r.InputRoomEvents: %s", err.Error()),
Code: api.PerformErrorNotAllowed,
}
log.WithError(err).WithField("event_id", event.EventID()).Error("r.InputRoomEvents failed")
return nil, nil
}
} else {
if !isOriginLocal {
// The invite originated over federation. Process the membership
// update, which will notify the sync API etc about the incoming
// invite.
// invite. We do NOT send an InputRoomEvent for the invite as it
// will never pass auth checks due to lacking room state, but we
// still need to tell the client about the invite so we can accept
// it, hence we return an output event to send to the sync api.
updater, err := r.DB.MembershipUpdater(ctx, roomID, targetUserID, isTargetLocal, req.RoomVersion)
if err != nil {
return nil, fmt.Errorf("r.DB.MembershipUpdater: %w", err)

@ -205,10 +153,77 @@ func (r *Inviter) PerformInvite(
if err = updater.Commit(); err != nil {
return nil, fmt.Errorf("updater.Commit: %w", err)
}

logger.Debugf("updated membership to invite and sending invite OutputEvent")
return outputUpdates, nil
}

// The invite originated locally. Therefore we have a responsibility to
// try and see if the user is allowed to make this invite. We can't do
// this for invites coming in over federation - we have to take those on
// trust.
_, err = helpers.CheckAuthEvents(ctx, r.DB, event, event.AuthEventIDs())
if err != nil {
logger.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", event.AuthEventIDs()).Error(
"processInviteEvent.checkAuthEvents failed for event",
)
res.Error = &api.PerformError{
Msg: err.Error(),
Code: api.PerformErrorNotAllowed,
}
return nil, nil
}

// If the invite originated from us and the target isn't local then we
// should try and send the invite over federation first. It might be
// that the remote user doesn't exist, in which case we can give up
// processing here.
if req.SendAsServer != api.DoNotSendToOtherServers && !isTargetLocal {
fsReq := &federationAPI.PerformInviteRequest{
RoomVersion: req.RoomVersion,
Event: event,
InviteRoomState: inviteState,
}
fsRes := &federationAPI.PerformInviteResponse{}
if err = r.FSAPI.PerformInvite(ctx, fsReq, fsRes); err != nil {
res.Error = &api.PerformError{
Msg: err.Error(),
Code: api.PerformErrorNotAllowed,
}
logger.WithError(err).WithField("event_id", event.EventID()).Error("r.FSAPI.PerformInvite failed")
return nil, nil
}
event = fsRes.Event
logger.Debugf("Federated PerformInvite success with event ID %s", event.EventID())
}

// Send the invite event to the roomserver input stream. This will
// notify existing users in the room about the invite, update the
// membership table and ensure that the event is ready and available
// to use as an auth event when accepting the invite.
// It will NOT notify the invitee of this invite.
inputReq := &api.InputRoomEventsRequest{
InputRoomEvents: []api.InputRoomEvent{
{
Kind: api.KindNew,
Event: event,
Origin: event.Origin(),
SendAsServer: req.SendAsServer,
},
},
}
inputRes := &api.InputRoomEventsResponse{}
r.Inputer.InputRoomEvents(context.Background(), inputReq, inputRes)
if err = inputRes.Err(); err != nil {
res.Error = &api.PerformError{
Msg: fmt.Sprintf("r.InputRoomEvents: %s", err.Error()),
Code: api.PerformErrorNotAllowed,
}
logger.WithError(err).WithField("event_id", event.EventID()).Error("r.InputRoomEvents failed")
return nil, nil
}

// Don't notify the sync api of this event in the same way as a federated invite so the invitee
// gets the invite, as the roomserver will do this when it processes the m.room.member invite.
return nil, nil
}
@ -51,13 +51,15 @@ func (r *Joiner) PerformJoin(
req *rsAPI.PerformJoinRequest,
res *rsAPI.PerformJoinResponse,
) {
roomID, joinedVia, err := r.performJoin(ctx, req)
logger := logrus.WithContext(ctx).WithFields(logrus.Fields{
"room_id": req.RoomIDOrAlias,
"user_id": req.UserID,
"servers": req.ServerNames,
})
logger.Info("User requested to join room")
roomID, joinedVia, err := r.performJoin(context.Background(), req)
if err != nil {
logrus.WithContext(ctx).WithFields(logrus.Fields{
"room_id": req.RoomIDOrAlias,
"user_id": req.UserID,
"servers": req.ServerNames,
}).WithError(err).Error("Failed to join room")
logger.WithError(err).Error("Failed to join room")
sentry.CaptureException(err)
perr, ok := err.(*rsAPI.PerformError)
if ok {

@ -67,7 +69,9 @@ func (r *Joiner) PerformJoin(
Msg: err.Error(),
}
}
return
}
logger.Info("User joined room successfully")
res.RoomID = roomID
res.JoinedVia = joinedVia
}
@ -51,13 +51,17 @@ func (r *Leaver) PerformLeave(
if domain != r.Cfg.Matrix.ServerName {
return nil, fmt.Errorf("user %q does not belong to this homeserver", req.UserID)
}
logger := logrus.WithContext(ctx).WithFields(logrus.Fields{
"room_id": req.RoomID,
"user_id": req.UserID,
})
logger.Info("User requested to leave room")
if strings.HasPrefix(req.RoomID, "!") {
output, err := r.performLeaveRoomByID(ctx, req, res)
output, err := r.performLeaveRoomByID(context.Background(), req, res)
if err != nil {
logrus.WithContext(ctx).WithFields(logrus.Fields{
"room_id": req.RoomID,
"user_id": req.UserID,
}).WithError(err).Error("Failed to leave room")
logger.WithError(err).Error("Failed to leave room")
} else {
logger.Info("User left room successfully")
}
return output, err
}
@ -50,7 +50,7 @@ func NewInternalAPI(
logrus.WithError(err).Panicf("failed to connect to room server db")
}

js, _, _ := jetstream.Prepare(&cfg.Matrix.JetStream)
js := jetstream.Prepare(&cfg.Matrix.JetStream)

return internal.NewRoomserverAPI(
cfg, roomserverDB, js,
@ -22,7 +22,6 @@ import (
"sort"
"time"

"github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/util"
"github.com/prometheus/client_golang/prometheus"

@ -30,13 +29,25 @@ import (
"github.com/matrix-org/gomatrixserverlib"
)

type StateResolutionStorage interface {
EventTypeNIDs(ctx context.Context, eventTypes []string) (map[string]types.EventTypeNID, error)
EventStateKeyNIDs(ctx context.Context, eventStateKeys []string) (map[string]types.EventStateKeyNID, error)
StateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error)
StateEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error)
SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error)
StateEntriesForTuples(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error)
StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error)
AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error)
Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error)
}

type StateResolution struct {
db storage.Database
db StateResolutionStorage
roomInfo *types.RoomInfo
events map[types.EventNID]*gomatrixserverlib.Event
}

func NewStateResolution(db storage.Database, roomInfo *types.RoomInfo) StateResolution {
func NewStateResolution(db StateResolutionStorage, roomInfo *types.RoomInfo) StateResolution {
return StateResolution{
db: db,
roomInfo: roomInfo,
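The new StateResolutionStorage interface above narrows what state resolution needs from storage, which is why NewStateResolution can now be handed either the full database or a transaction-scoped room updater. A rough standalone illustration of that idea, using invented stand-in types (stateStorage, database, roomUpdater) rather than the real Dendrite ones:

// Sketch: both a plain database handle and a transaction-scoped updater can
// satisfy the same narrow interface, so the same resolver works with either.
package main

import "fmt"

type stateStorage interface {
	SnapshotNIDFromEventID(eventID string) (int64, error)
}

type database struct{}

func (database) SnapshotNIDFromEventID(eventID string) (int64, error) { return 1, nil }

type roomUpdater struct{} // would wrap an open transaction in the real code

func (roomUpdater) SnapshotNIDFromEventID(eventID string) (int64, error) { return 2, nil }

type stateResolution struct{ db stateStorage }

func newStateResolution(db stateStorage) stateResolution { return stateResolution{db: db} }

func main() {
	for _, s := range []stateStorage{database{}, roomUpdater{}} {
		r := newStateResolution(s)
		nid, _ := r.db.SnapshotNIDFromEventID("$event:example.org")
		fmt.Println(nid)
	}
}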
@ -86,11 +86,10 @@ type Database interface {
// Lookup the event IDs for a batch of event numeric IDs.
// Returns an error if the retrieval went wrong.
EventIDs(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error)
// Look up the latest events in a room in preparation for an update.
// The RoomRecentEventsUpdater must have Commit or Rollback called on it if this doesn't return an error.
// Returns the latest events in the room and the last eventID sent to the log along with an updater.
// Opens and returns a room updater, which locks the room and opens a transaction.
// The GetRoomUpdater must have Commit or Rollback called on it if this doesn't return an error.
// If this returns an error then no further action is required.
GetLatestEventsForUpdate(ctx context.Context, roomInfo types.RoomInfo) (*shared.LatestEventsUpdater, error)
GetRoomUpdater(ctx context.Context, roomInfo *types.RoomInfo) (*shared.RoomUpdater, error)
// Look up event references for the latest events in the room and the current state snapshot.
// Returns the latest events, the current state and the maximum depth of the latest events plus 1.
// Returns an error if there was a problem talking to the database.
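The interface change above replaces GetLatestEventsForUpdate with GetRoomUpdater, which locks the room and opens a transaction that the caller must finish with Commit or Rollback. A hedged sketch of the expected calling pattern follows; EndTransactionWithCheck is the helper already used elsewhere in this commit, while withRoomUpdater and doSomething are hypothetical names added purely for illustration.

// Sketch of the intended lifecycle: open the updater, do work, then let the
// deferred check commit on success or roll back on failure.
package example

import (
	"context"
	"fmt"

	"github.com/matrix-org/dendrite/internal/sqlutil"
	"github.com/matrix-org/dendrite/roomserver/storage"
	"github.com/matrix-org/dendrite/roomserver/storage/shared"
	"github.com/matrix-org/dendrite/roomserver/types"
)

// doSomething stands in for whatever work needs the locked room.
func doSomething(ctx context.Context, updater *shared.RoomUpdater) error { return nil }

func withRoomUpdater(ctx context.Context, db storage.Database, info *types.RoomInfo) (err error) {
	updater, err := db.GetRoomUpdater(ctx, info)
	if err != nil {
		return fmt.Errorf("db.GetRoomUpdater: %w", err)
	}
	succeeded := false
	// Commits when succeeded is true by the time we return, rolls back otherwise.
	defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err)

	if err = doSomething(ctx, updater); err != nil {
		return err
	}
	succeeded = true
	return nil
}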
@ -81,9 +81,10 @@ func (s *eventJSONStatements) InsertEventJSON(
}

func (s *eventJSONStatements) BulkSelectEventJSON(
ctx context.Context, eventNIDs []types.EventNID,
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
) ([]tables.EventJSONPair, error) {
rows, err := s.bulkSelectEventJSONStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
stmt := sqlutil.TxStmt(txn, s.bulkSelectEventJSONStmt)
rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
if err != nil {
return nil, err
}
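From here onwards most of the table methods gain a txn *sql.Tx parameter and wrap their prepared statements with sqlutil.TxStmt, so that reads join the caller's transaction when one is supplied. The helper is roughly equivalent to the following standalone sketch (txStmt here is an illustrative reimplementation, not Dendrite's actual code):

// Sketch of the TxStmt pattern: rebind a prepared statement to a transaction
// when one is provided, otherwise use the statement as prepared on the pool.
package example

import "database/sql"

func txStmt(txn *sql.Tx, stmt *sql.Stmt) *sql.Stmt {
	if txn != nil {
		// Stmt returns a transaction-specific copy of the prepared statement.
		stmt = txn.Stmt(stmt)
	}
	return stmt
}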
@ -111,9 +111,10 @@ func (s *eventStateKeyStatements) SelectEventStateKeyNID(
}

func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID(
ctx context.Context, eventStateKeys []string,
ctx context.Context, txn *sql.Tx, eventStateKeys []string,
) (map[string]types.EventStateKeyNID, error) {
rows, err := s.bulkSelectEventStateKeyNIDStmt.QueryContext(
stmt := sqlutil.TxStmt(txn, s.bulkSelectEventStateKeyNIDStmt)
rows, err := stmt.QueryContext(
ctx, pq.StringArray(eventStateKeys),
)
if err != nil {

@ -134,13 +135,14 @@ func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID(
}

func (s *eventStateKeyStatements) BulkSelectEventStateKey(
ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID,
ctx context.Context, txn *sql.Tx, eventStateKeyNIDs []types.EventStateKeyNID,
) (map[types.EventStateKeyNID]string, error) {
nIDs := make(pq.Int64Array, len(eventStateKeyNIDs))
for i := range eventStateKeyNIDs {
nIDs[i] = int64(eventStateKeyNIDs[i])
}
rows, err := s.bulkSelectEventStateKeyStmt.QueryContext(ctx, nIDs)
stmt := sqlutil.TxStmt(txn, s.bulkSelectEventStateKeyStmt)
rows, err := stmt.QueryContext(ctx, nIDs)
if err != nil {
return nil, err
}
@ -133,9 +133,10 @@ func (s *eventTypeStatements) SelectEventTypeNID(
}

func (s *eventTypeStatements) BulkSelectEventTypeNID(
ctx context.Context, eventTypes []string,
ctx context.Context, txn *sql.Tx, eventTypes []string,
) (map[string]types.EventTypeNID, error) {
rows, err := s.bulkSelectEventTypeNIDStmt.QueryContext(ctx, pq.StringArray(eventTypes))
stmt := sqlutil.TxStmt(txn, s.bulkSelectEventTypeNIDStmt)
rows, err := stmt.QueryContext(ctx, pq.StringArray(eventTypes))
if err != nil {
return nil, err
}
@ -212,9 +212,10 @@ func (s *eventStatements) SelectEvent(
// bulkSelectStateEventByID looks up a list of state events by event ID.
// If any of the requested events are missing from the database it returns a types.MissingEventError
func (s *eventStatements) BulkSelectStateEventByID(
ctx context.Context, eventIDs []string,
ctx context.Context, txn *sql.Tx, eventIDs []string,
) ([]types.StateEntry, error) {
rows, err := s.bulkSelectStateEventByIDStmt.QueryContext(ctx, pq.StringArray(eventIDs))
stmt := sqlutil.TxStmt(txn, s.bulkSelectStateEventByIDStmt)
rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs))
if err != nil {
return nil, err
}

@ -254,13 +255,14 @@ func (s *eventStatements) BulkSelectStateEventByID(
// bulkSelectStateEventByNID looks up a list of state events by event NID.
// If any of the requested events are missing from the database it returns a types.MissingEventError
func (s *eventStatements) BulkSelectStateEventByNID(
ctx context.Context, eventNIDs []types.EventNID,
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntry, error) {
tuples := stateKeyTupleSorter(stateKeyTuples)
sort.Sort(tuples)
eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays()
rows, err := s.bulkSelectStateEventByNIDStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs), eventTypeNIDArray, eventStateKeyNIDArray)
stmt := sqlutil.TxStmt(txn, s.bulkSelectStateEventByNIDStmt)
rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs), eventTypeNIDArray, eventStateKeyNIDArray)
if err != nil {
return nil, err
}

@ -291,9 +293,10 @@ func (s *eventStatements) BulkSelectStateEventByNID(
// If any of the requested events are missing from the database it returns a types.MissingEventError.
// If we do not have the state for any of the requested events it returns a types.MissingEventError.
func (s *eventStatements) BulkSelectStateAtEventByID(
ctx context.Context, eventIDs []string,
ctx context.Context, txn *sql.Tx, eventIDs []string,
) ([]types.StateAtEvent, error) {
rows, err := s.bulkSelectStateAtEventByIDStmt.QueryContext(ctx, pq.StringArray(eventIDs))
stmt := sqlutil.TxStmt(txn, s.bulkSelectStateAtEventByIDStmt)
rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs))
if err != nil {
return nil, err
}

@ -428,8 +431,9 @@ func (s *eventStatements) BulkSelectEventReference(
}

// bulkSelectEventID returns a map from numeric event ID to string event ID.
func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) {
rows, err := s.bulkSelectEventIDStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
func (s *eventStatements) BulkSelectEventID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (map[types.EventNID]string, error) {
stmt := sqlutil.TxStmt(txn, s.bulkSelectEventIDStmt)
rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
if err != nil {
return nil, err
}

@ -455,8 +459,9 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []typ

// bulkSelectEventNIDs returns a map from string event ID to numeric event ID.
// If an event ID is not in the database then it is omitted from the map.
func (s *eventStatements) BulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) {
rows, err := s.bulkSelectEventNIDStmt.QueryContext(ctx, pq.StringArray(eventIDs))
func (s *eventStatements) BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) {
stmt := sqlutil.TxStmt(txn, s.bulkSelectEventNIDStmt)
rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs))
if err != nil {
return nil, err
}

@ -484,9 +489,10 @@ func (s *eventStatements) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx,
}

func (s *eventStatements) SelectRoomNIDsForEventNIDs(
ctx context.Context, eventNIDs []types.EventNID,
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
) (map[types.EventNID]types.RoomNID, error) {
rows, err := s.selectRoomNIDsForEventNIDsStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
stmt := sqlutil.TxStmt(txn, s.selectRoomNIDsForEventNIDsStmt)
rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
if err != nil {
return nil, err
}
@ -97,8 +97,8 @@ func prepareInvitesTable(db *sql.DB) (tables.Invites, error) {
}

func (s *inviteStatements) InsertInviteEvent(
ctx context.Context,
txn *sql.Tx, inviteEventID string, roomNID types.RoomNID,
ctx context.Context, txn *sql.Tx,
inviteEventID string, roomNID types.RoomNID,
targetUserNID, senderUserNID types.EventStateKeyNID,
inviteEventJSON []byte,
) (bool, error) {

@ -116,8 +116,8 @@ func (s *inviteStatements) InsertInviteEvent(
}

func (s *inviteStatements) UpdateInviteRetired(
ctx context.Context,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) ([]string, error) {
stmt := sqlutil.TxStmt(txn, s.updateInviteRetiredStmt)
rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID)

@ -139,10 +139,11 @@ func (s *inviteStatements) UpdateInviteRetired(

// SelectInviteActiveForUserInRoom returns a list of sender state key NIDs
func (s *inviteStatements) SelectInviteActiveForUserInRoom(
ctx context.Context,
ctx context.Context, txn *sql.Tx,
targetUserNID types.EventStateKeyNID, roomNID types.RoomNID,
) ([]types.EventStateKeyNID, []string, error) {
rows, err := s.selectInviteActiveForUserInRoomStmt.QueryContext(
stmt := sqlutil.TxStmt(txn, s.selectInviteActiveForUserInRoomStmt)
rows, err := stmt.QueryContext(
ctx, targetUserNID, roomNID,
)
if err != nil {
@ -186,8 +186,8 @@ func prepareMembershipTable(db *sql.DB) (tables.Membership, error) {
}

func (s *membershipStatements) InsertMembership(
ctx context.Context,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
localTarget bool,
) error {
stmt := sqlutil.TxStmt(txn, s.insertMembershipStmt)

@ -196,8 +196,8 @@ func (s *membershipStatements) InsertMembership(
}

func (s *membershipStatements) SelectMembershipForUpdate(
ctx context.Context,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) (membership tables.MembershipState, err error) {
err = sqlutil.TxStmt(txn, s.selectMembershipForUpdateStmt).QueryRowContext(
ctx, roomNID, targetUserNID,

@ -206,17 +206,19 @@ func (s *membershipStatements) SelectMembershipForUpdate(
}

func (s *membershipStatements) SelectMembershipFromRoomAndTarget(
ctx context.Context,
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) (eventNID types.EventNID, membership tables.MembershipState, forgotten bool, err error) {
err = s.selectMembershipFromRoomAndTargetStmt.QueryRowContext(
stmt := sqlutil.TxStmt(txn, s.selectMembershipFromRoomAndTargetStmt)
err = stmt.QueryRowContext(
ctx, roomNID, targetUserNID,
).Scan(&membership, &eventNID, &forgotten)
return
}

func (s *membershipStatements) SelectMembershipsFromRoom(
ctx context.Context, roomNID types.RoomNID, localOnly bool,
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, localOnly bool,
) (eventNIDs []types.EventNID, err error) {
var stmt *sql.Stmt
if localOnly {

@ -224,6 +226,7 @@ func (s *membershipStatements) SelectMembershipsFromRoom(
} else {
stmt = s.selectMembershipsFromRoomStmt
}
stmt = sqlutil.TxStmt(txn, stmt)
rows, err := stmt.QueryContext(ctx, roomNID)
if err != nil {
return

@ -241,7 +244,7 @@ func (s *membershipStatements) SelectMembershipsFromRoom(
}

func (s *membershipStatements) SelectMembershipsFromRoomAndMembership(
ctx context.Context,
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, membership tables.MembershipState, localOnly bool,
) (eventNIDs []types.EventNID, err error) {
var rows *sql.Rows

@ -251,6 +254,7 @@ func (s *membershipStatements) SelectMembershipsFromRoomAndMembership(
} else {
stmt = s.selectMembershipsFromRoomAndMembershipStmt
}
stmt = sqlutil.TxStmt(txn, stmt)
rows, err = stmt.QueryContext(ctx, roomNID, membership)
if err != nil {
return

@ -268,8 +272,8 @@ func (s *membershipStatements) SelectMembershipsFromRoomAndMembership(
}

func (s *membershipStatements) UpdateMembership(
ctx context.Context,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership tables.MembershipState,
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership tables.MembershipState,
eventNID types.EventNID, forgotten bool,
) error {
_, err := sqlutil.TxStmt(txn, s.updateMembershipStmt).ExecContext(

@ -279,9 +283,11 @@ func (s *membershipStatements) UpdateMembership(
}

func (s *membershipStatements) SelectRoomsWithMembership(
ctx context.Context, userID types.EventStateKeyNID, membershipState tables.MembershipState,
ctx context.Context, txn *sql.Tx,
userID types.EventStateKeyNID, membershipState tables.MembershipState,
) ([]types.RoomNID, error) {
rows, err := s.selectRoomsWithMembershipStmt.QueryContext(ctx, membershipState, userID)
stmt := sqlutil.TxStmt(txn, s.selectRoomsWithMembershipStmt)
rows, err := stmt.QueryContext(ctx, membershipState, userID)
if err != nil {
return nil, err
}

@ -297,12 +303,16 @@ func (s *membershipStatements) SelectRoomsWithMembership(
return roomNIDs, nil
}

func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) {
func (s *membershipStatements) SelectJoinedUsersSetForRooms(
ctx context.Context, txn *sql.Tx,
roomNIDs []types.RoomNID,
) (map[types.EventStateKeyNID]int, error) {
roomIDarray := make([]int64, len(roomNIDs))
for i := range roomNIDs {
roomIDarray[i] = int64(roomNIDs[i])
}
rows, err := s.selectJoinedUsersSetForRoomsStmt.QueryContext(ctx, pq.Int64Array(roomIDarray))
stmt := sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsStmt)
rows, err := stmt.QueryContext(ctx, pq.Int64Array(roomIDarray))
if err != nil {
return nil, err
}

@ -319,8 +329,12 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context,
return result, rows.Err()
}

func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) {
rows, err := s.selectKnownUsersStmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit)
func (s *membershipStatements) SelectKnownUsers(
ctx context.Context, txn *sql.Tx,
userID types.EventStateKeyNID, searchString string, limit int,
) ([]string, error) {
stmt := sqlutil.TxStmt(txn, s.selectKnownUsersStmt)
rows, err := stmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit)
if err != nil {
return nil, err
}

@ -337,9 +351,8 @@ func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID type
}

func (s *membershipStatements) UpdateForgetMembership(
ctx context.Context,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
forget bool,
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool,
) error {
_, err := sqlutil.TxStmt(txn, s.updateMembershipForgetRoomStmt).ExecContext(
ctx, roomNID, targetUserNID, forget,

@ -347,9 +360,13 @@ func (s *membershipStatements) UpdateForgetMembership(
return err
}

func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) {
func (s *membershipStatements) SelectLocalServerInRoom(
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID,
) (bool, error) {
var nid types.RoomNID
err := s.selectLocalServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID).Scan(&nid)
stmt := sqlutil.TxStmt(txn, s.selectLocalServerInRoomStmt)
err := stmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID).Scan(&nid)
if err != nil {
if err == sql.ErrNoRows {
return false, nil

@ -360,9 +377,13 @@ func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, room
return found, nil
}

func (s *membershipStatements) SelectServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) {
func (s *membershipStatements) SelectServerInRoom(
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, serverName gomatrixserverlib.ServerName,
) (bool, error) {
var nid types.RoomNID
err := s.selectServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID, serverName).Scan(&nid)
stmt := sqlutil.TxStmt(txn, s.selectServerInRoomStmt)
err := stmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID, serverName).Scan(&nid)
if err != nil {
if err == sql.ErrNoRows {
return false, nil
@ -73,9 +73,10 @@ func (s *publishedStatements) UpsertRoomPublished(
}

func (s *publishedStatements) SelectPublishedFromRoomID(
ctx context.Context, roomID string,
ctx context.Context, txn *sql.Tx, roomID string,
) (published bool, err error) {
err = s.selectPublishedStmt.QueryRowContext(ctx, roomID).Scan(&published)
stmt := sqlutil.TxStmt(txn, s.selectPublishedStmt)
err = stmt.QueryRowContext(ctx, roomID).Scan(&published)
if err == sql.ErrNoRows {
return false, nil
}

@ -83,9 +84,10 @@ func (s *publishedStatements) SelectPublishedFromRoomID(
}

func (s *publishedStatements) SelectAllPublishedRooms(
ctx context.Context, published bool,
ctx context.Context, txn *sql.Tx, published bool,
) ([]string, error) {
rows, err := s.selectAllPublishedStmt.QueryContext(ctx, published)
stmt := sqlutil.TxStmt(txn, s.selectAllPublishedStmt)
rows, err := stmt.QueryContext(ctx, published)
if err != nil {
return nil, err
}
@ -87,9 +87,10 @@ func (s *roomAliasesStatements) InsertRoomAlias(
}

func (s *roomAliasesStatements) SelectRoomIDFromAlias(
ctx context.Context, alias string,
ctx context.Context, txn *sql.Tx, alias string,
) (roomID string, err error) {
err = s.selectRoomIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&roomID)
stmt := sqlutil.TxStmt(txn, s.selectRoomIDFromAliasStmt)
err = stmt.QueryRowContext(ctx, alias).Scan(&roomID)
if err == sql.ErrNoRows {
return "", nil
}

@ -97,9 +98,10 @@ func (s *roomAliasesStatements) SelectRoomIDFromAlias(
}

func (s *roomAliasesStatements) SelectAliasesFromRoomID(
ctx context.Context, roomID string,
ctx context.Context, txn *sql.Tx, roomID string,
) ([]string, error) {
rows, err := s.selectAliasesFromRoomIDStmt.QueryContext(ctx, roomID)
stmt := sqlutil.TxStmt(txn, s.selectAliasesFromRoomIDStmt)
rows, err := stmt.QueryContext(ctx, roomID)
if err != nil {
return nil, err
}

@ -118,9 +120,10 @@ func (s *roomAliasesStatements) SelectAliasesFromRoomID(
}

func (s *roomAliasesStatements) SelectCreatorIDFromAlias(
ctx context.Context, alias string,
ctx context.Context, txn *sql.Tx, alias string,
) (creatorID string, err error) {
err = s.selectCreatorIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&creatorID)
stmt := sqlutil.TxStmt(txn, s.selectCreatorIDFromAliasStmt)
err = stmt.QueryRowContext(ctx, alias).Scan(&creatorID)
if err == sql.ErrNoRows {
return "", nil
}
@ -117,8 +117,9 @@ func prepareRoomsTable(db *sql.DB) (tables.Rooms, error) {
}.Prepare(db)
}

func (s *roomStatements) SelectRoomIDs(ctx context.Context) ([]string, error) {
rows, err := s.selectRoomIDsStmt.QueryContext(ctx)
func (s *roomStatements) SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]string, error) {
stmt := sqlutil.TxStmt(txn, s.selectRoomIDsStmt)
rows, err := stmt.QueryContext(ctx)
if err != nil {
return nil, err
}

@ -143,10 +144,11 @@ func (s *roomStatements) InsertRoomNID(
return types.RoomNID(roomNID), err
}

func (s *roomStatements) SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) {
func (s *roomStatements) SelectRoomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error) {
var info types.RoomInfo
var latestNIDs pq.Int64Array
err := s.selectRoomInfoStmt.QueryRowContext(ctx, roomID).Scan(
stmt := sqlutil.TxStmt(txn, s.selectRoomInfoStmt)
err := stmt.QueryRowContext(ctx, roomID).Scan(
&info.RoomVersion, &info.RoomNID, &info.StateSnapshotNID, &latestNIDs,
)
if err == sql.ErrNoRows {

@ -170,7 +172,7 @@ func (s *roomStatements) SelectLatestEventNIDs(
) ([]types.EventNID, types.StateSnapshotNID, error) {
var nids pq.Int64Array
var stateSnapshotNID int64
stmt := s.selectLatestEventNIDsStmt
stmt := sqlutil.TxStmt(txn, s.selectLatestEventNIDsStmt)
err := stmt.QueryRowContext(ctx, int64(roomNID)).Scan(&nids, &stateSnapshotNID)
if err != nil {
return nil, 0, err

@ -220,9 +222,10 @@ func (s *roomStatements) UpdateLatestEventNIDs(
}

func (s *roomStatements) SelectRoomVersionsForRoomNIDs(
ctx context.Context, roomNIDs []types.RoomNID,
ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID,
) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error) {
rows, err := s.selectRoomVersionsForRoomNIDsStmt.QueryContext(ctx, roomNIDsAsArray(roomNIDs))
stmt := sqlutil.TxStmt(txn, s.selectRoomVersionsForRoomNIDsStmt)
rows, err := stmt.QueryContext(ctx, roomNIDsAsArray(roomNIDs))
if err != nil {
return nil, err
}

@ -239,12 +242,13 @@ func (s *roomStatements) SelectRoomVersionsForRoomNIDs(
return result, nil
}

func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) {
func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) ([]string, error) {
var array pq.Int64Array
for _, nid := range roomNIDs {
array = append(array, int64(nid))
}
rows, err := s.bulkSelectRoomIDsStmt.QueryContext(ctx, array)
stmt := sqlutil.TxStmt(txn, s.bulkSelectRoomIDsStmt)
rows, err := stmt.QueryContext(ctx, array)
if err != nil {
return nil, err
}

@ -260,12 +264,13 @@ func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types
return roomIDs, nil
}

func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, roomIDs []string) ([]types.RoomNID, error) {
func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, roomIDs []string) ([]types.RoomNID, error) {
var array pq.StringArray
for _, roomID := range roomIDs {
array = append(array, roomID)
}
rows, err := s.bulkSelectRoomNIDsStmt.QueryContext(ctx, array)
stmt := sqlutil.TxStmt(txn, s.bulkSelectRoomNIDsStmt)
rows, err := stmt.QueryContext(ctx, array)
if err != nil {
return nil, err
}
@ -86,8 +86,7 @@ func prepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
}

func (s *stateBlockStatements) BulkInsertStateData(
ctx context.Context,
txn *sql.Tx,
ctx context.Context, txn *sql.Tx,
entries types.StateEntries,
) (id types.StateBlockNID, err error) {
entries = entries[:util.SortAndUnique(entries)]

@ -95,16 +94,18 @@ func (s *stateBlockStatements) BulkInsertStateData(
for _, e := range entries {
nids = append(nids, e.EventNID)
}
err = s.insertStateDataStmt.QueryRowContext(
stmt := sqlutil.TxStmt(txn, s.insertStateDataStmt)
err = stmt.QueryRowContext(
ctx, nids.Hash(), eventNIDsAsArray(nids),
).Scan(&id)
return
}

func (s *stateBlockStatements) BulkSelectStateBlockEntries(
ctx context.Context, stateBlockNIDs types.StateBlockNIDs,
ctx context.Context, txn *sql.Tx, stateBlockNIDs types.StateBlockNIDs,
) ([][]types.EventNID, error) {
rows, err := s.bulkSelectStateBlockEntriesStmt.QueryContext(ctx, stateBlockNIDsAsArray(stateBlockNIDs))
stmt := sqlutil.TxStmt(txn, s.bulkSelectStateBlockEntriesStmt)
rows, err := stmt.QueryContext(ctx, stateBlockNIDsAsArray(stateBlockNIDs))
if err != nil {
return nil, err
}
@ -105,13 +105,14 @@ func (s *stateSnapshotStatements) InsertState(
}

func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
ctx context.Context, stateNIDs []types.StateSnapshotNID,
ctx context.Context, txn *sql.Tx, stateNIDs []types.StateSnapshotNID,
) ([]types.StateBlockNIDList, error) {
nids := make([]int64, len(stateNIDs))
for i := range stateNIDs {
nids[i] = int64(stateNIDs[i])
}
rows, err := s.bulkSelectStateBlockNIDsStmt.QueryContext(ctx, pq.Int64Array(nids))
stmt := sqlutil.TxStmt(txn, s.bulkSelectStateBlockNIDsStmt)
rows, err := stmt.QueryContext(ctx, pq.Int64Array(nids))
if err != nil {
return nil, err
}
@ -1,133 +0,0 @@
package shared

import (
"context"
"database/sql"
"fmt"

"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
)

type LatestEventsUpdater struct {
transaction
d *Database
roomInfo types.RoomInfo
latestEvents []types.StateAtEventAndReference
lastEventIDSent string
currentStateSnapshotNID types.StateSnapshotNID
}

func rollback(txn *sql.Tx) {
if txn == nil {
return
}
txn.Rollback() // nolint: errcheck
}

func NewLatestEventsUpdater(ctx context.Context, d *Database, txn *sql.Tx, roomInfo types.RoomInfo) (*LatestEventsUpdater, error) {
eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err :=
d.RoomsTable.SelectLatestEventsNIDsForUpdate(ctx, txn, roomInfo.RoomNID)
if err != nil {
rollback(txn)
return nil, err
}
stateAndRefs, err := d.EventsTable.BulkSelectStateAtEventAndReference(ctx, txn, eventNIDs)
if err != nil {
rollback(txn)
return nil, err
}
var lastEventIDSent string
if lastEventNIDSent != 0 {
lastEventIDSent, err = d.EventsTable.SelectEventID(ctx, txn, lastEventNIDSent)
if err != nil {
rollback(txn)
return nil, err
}
}
return &LatestEventsUpdater{
transaction{ctx, txn}, d, roomInfo, stateAndRefs, lastEventIDSent, currentStateSnapshotNID,
}, nil
}

// RoomVersion implements types.RoomRecentEventsUpdater
func (u *LatestEventsUpdater) RoomVersion() (version gomatrixserverlib.RoomVersion) {
return u.roomInfo.RoomVersion
}

// LatestEvents implements types.RoomRecentEventsUpdater
func (u *LatestEventsUpdater) LatestEvents() []types.StateAtEventAndReference {
return u.latestEvents
}

// LastEventIDSent implements types.RoomRecentEventsUpdater
func (u *LatestEventsUpdater) LastEventIDSent() string {
return u.lastEventIDSent
}

// CurrentStateSnapshotNID implements types.RoomRecentEventsUpdater
func (u *LatestEventsUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID {
return u.currentStateSnapshotNID
}

// StorePreviousEvents implements types.RoomRecentEventsUpdater - This must be called from a Writer
func (u *LatestEventsUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error {
for _, ref := range previousEventReferences {
if err := u.d.PrevEventsTable.InsertPreviousEvent(u.ctx, u.txn, ref.EventID, ref.EventSHA256, eventNID); err != nil {
return fmt.Errorf("u.d.PrevEventsTable.InsertPreviousEvent: %w", err)
}
}
return nil
}

// IsReferenced implements types.RoomRecentEventsUpdater
func (u *LatestEventsUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) {
err := u.d.PrevEventsTable.SelectPreviousEventExists(u.ctx, u.txn, eventReference.EventID, eventReference.EventSHA256)
if err == nil {
return true, nil
}
if err == sql.ErrNoRows {
return false, nil
}
return false, fmt.Errorf("u.d.PrevEventsTable.SelectPreviousEventExists: %w", err)
}

// SetLatestEvents implements types.RoomRecentEventsUpdater
func (u *LatestEventsUpdater) SetLatestEvents(
roomNID types.RoomNID, latest []types.StateAtEventAndReference, lastEventNIDSent types.EventNID,
currentStateSnapshotNID types.StateSnapshotNID,
) error {
eventNIDs := make([]types.EventNID, len(latest))
for i := range latest {
eventNIDs[i] = latest[i].EventNID
}
return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
if err := u.d.RoomsTable.UpdateLatestEventNIDs(u.ctx, txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID); err != nil {
return fmt.Errorf("u.d.RoomsTable.updateLatestEventNIDs: %w", err)
}
if roomID, ok := u.d.Cache.GetRoomServerRoomID(roomNID); ok {
if roomInfo, ok := u.d.Cache.GetRoomInfo(roomID); ok {
roomInfo.StateSnapshotNID = currentStateSnapshotNID
roomInfo.IsStub = false
u.d.Cache.StoreRoomInfo(roomID, roomInfo)
}
}
return nil
})
}

// HasEventBeenSent implements types.RoomRecentEventsUpdater
func (u *LatestEventsUpdater) HasEventBeenSent(eventNID types.EventNID) (bool, error) {
return u.d.EventsTable.SelectEventSentToOutput(u.ctx, u.txn, eventNID)
}

// MarkEventAsSent implements types.RoomRecentEventsUpdater
func (u *LatestEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error {
return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
return u.d.EventsTable.UpdateEventSentToOutput(u.ctx, txn, eventNID)
})
}

func (u *LatestEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (*MembershipUpdater, error) {
return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomInfo.RoomNID, targetUserNID, targetLocal)
}

262
roomserver/storage/shared/room_updater.go
Normal file
@ -0,0 +1,262 @@
package shared

import (
"context"
"database/sql"
"fmt"

"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
)

type RoomUpdater struct {
transaction
d *Database
roomInfo *types.RoomInfo
latestEvents []types.StateAtEventAndReference
lastEventIDSent string
currentStateSnapshotNID types.StateSnapshotNID
}

func rollback(txn *sql.Tx) {
if txn == nil {
return
}
txn.Rollback() // nolint: errcheck
}

func NewRoomUpdater(ctx context.Context, d *Database, txn *sql.Tx, roomInfo *types.RoomInfo) (*RoomUpdater, error) {
// If the roomInfo is nil then that means that the room doesn't exist
// yet, so we can't do `SelectLatestEventsNIDsForUpdate` because that
// would involve locking a row on the table that doesn't exist. Instead
// we will just run with a normal database transaction. It'll either
// succeed, processing a create event which creates the room, or it won't.
if roomInfo == nil {
return &RoomUpdater{
transaction{ctx, txn}, d, nil, nil, "", 0,
}, nil
}

eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err :=
d.RoomsTable.SelectLatestEventsNIDsForUpdate(ctx, txn, roomInfo.RoomNID)
if err != nil {
rollback(txn)
return nil, err
}
stateAndRefs, err := d.EventsTable.BulkSelectStateAtEventAndReference(ctx, txn, eventNIDs)
if err != nil {
rollback(txn)
return nil, err
}
var lastEventIDSent string
if lastEventNIDSent != 0 {
lastEventIDSent, err = d.EventsTable.SelectEventID(ctx, txn, lastEventNIDSent)
if err != nil {
rollback(txn)
return nil, err
}
}
return &RoomUpdater{
transaction{ctx, txn}, d, roomInfo, stateAndRefs, lastEventIDSent, currentStateSnapshotNID,
}, nil
}

// Implements sqlutil.Transaction
func (u *RoomUpdater) Commit() error {
if u.txn == nil { // SQLite mode probably
return nil
}
return u.txn.Commit()
}

// Implements sqlutil.Transaction
func (u *RoomUpdater) Rollback() error {
if u.txn == nil { // SQLite mode probably
return nil
}
return u.txn.Rollback()
}

// RoomVersion implements types.RoomRecentEventsUpdater
func (u *RoomUpdater) RoomVersion() (version gomatrixserverlib.RoomVersion) {
return u.roomInfo.RoomVersion
}

// LatestEvents implements types.RoomRecentEventsUpdater
func (u *RoomUpdater) LatestEvents() []types.StateAtEventAndReference {
return u.latestEvents
}

// LastEventIDSent implements types.RoomRecentEventsUpdater
func (u *RoomUpdater) LastEventIDSent() string {
return u.lastEventIDSent
}

// CurrentStateSnapshotNID implements types.RoomRecentEventsUpdater
func (u *RoomUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID {
return u.currentStateSnapshotNID
}

// StorePreviousEvents implements types.RoomRecentEventsUpdater - This must be called from a Writer
func (u *RoomUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error {
return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
for _, ref := range previousEventReferences {
if err := u.d.PrevEventsTable.InsertPreviousEvent(u.ctx, txn, ref.EventID, ref.EventSHA256, eventNID); err != nil {
return fmt.Errorf("u.d.PrevEventsTable.InsertPreviousEvent: %w", err)
}
}
return nil
})
}

func (u *RoomUpdater) Events(
ctx context.Context, eventNIDs []types.EventNID,
) ([]types.Event, error) {
return u.d.events(ctx, u.txn, eventNIDs)
}

func (u *RoomUpdater) SnapshotNIDFromEventID(
ctx context.Context, eventID string,
) (types.StateSnapshotNID, error) {
return u.d.snapshotNIDFromEventID(ctx, u.txn, eventID)
}

func (u *RoomUpdater) StoreEvent(
ctx context.Context, event *gomatrixserverlib.Event,
authEventNIDs []types.EventNID, isRejected bool,
) (types.EventNID, types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) {
return u.d.storeEvent(ctx, u, event, authEventNIDs, isRejected)
}

func (u *RoomUpdater) StateBlockNIDs(
ctx context.Context, stateNIDs []types.StateSnapshotNID,
) ([]types.StateBlockNIDList, error) {
return u.d.stateBlockNIDs(ctx, u.txn, stateNIDs)
}

func (u *RoomUpdater) StateEntries(
ctx context.Context, stateBlockNIDs []types.StateBlockNID,
) ([]types.StateEntryList, error) {
return u.d.stateEntries(ctx, u.txn, stateBlockNIDs)
}

func (u *RoomUpdater) StateEntriesForTuples(
ctx context.Context,
stateBlockNIDs []types.StateBlockNID,
stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntryList, error) {
return u.d.stateEntriesForTuples(ctx, u.txn, stateBlockNIDs, stateKeyTuples)
}

func (u *RoomUpdater) AddState(
ctx context.Context,
roomNID types.RoomNID,
stateBlockNIDs []types.StateBlockNID,
state []types.StateEntry,
) (stateNID types.StateSnapshotNID, err error) {
return u.d.addState(ctx, u.txn, roomNID, stateBlockNIDs, state)
}

func (u *RoomUpdater) SetState(
ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID,
) error {
return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
return u.d.EventsTable.UpdateEventState(ctx, txn, eventNID, stateNID)
})
}

func (u *RoomUpdater) EventTypeNIDs(
ctx context.Context, eventTypes []string,
) (map[string]types.EventTypeNID, error) {
return u.d.eventTypeNIDs(ctx, u.txn, eventTypes)
}

func (u *RoomUpdater) EventStateKeyNIDs(
ctx context.Context, eventStateKeys []string,
) (map[string]types.EventStateKeyNID, error) {
return u.d.eventStateKeyNIDs(ctx, u.txn, eventStateKeys)
}

func (u *RoomUpdater) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) {
return u.d.roomInfo(ctx, u.txn, roomID)
}

func (u *RoomUpdater) EventIDs(
ctx context.Context, eventNIDs []types.EventNID,
) (map[types.EventNID]string, error) {
return u.d.EventsTable.BulkSelectEventID(ctx, u.txn, eventNIDs)
}

func (u *RoomUpdater) StateAtEventIDs(
ctx context.Context, eventIDs []string,
) ([]types.StateAtEvent, error) {
return u.d.EventsTable.BulkSelectStateAtEventByID(ctx, u.txn, eventIDs)
}

func (u *RoomUpdater) StateEntriesForEventIDs(
ctx context.Context, eventIDs []string,
) ([]types.StateEntry, error) {
return u.d.EventsTable.BulkSelectStateEventByID(ctx, u.txn, eventIDs)
}

func (u *RoomUpdater) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) {
return u.d.eventsFromIDs(ctx, u.txn, eventIDs)
}

func (u *RoomUpdater) GetMembershipEventNIDsForRoom(
ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool,
) ([]types.EventNID, error) {
return u.d.getMembershipEventNIDsForRoom(ctx, u.txn, roomNID, joinOnly, localOnly)
}

// IsReferenced implements types.RoomRecentEventsUpdater
func (u *RoomUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) {
err := u.d.PrevEventsTable.SelectPreviousEventExists(u.ctx, u.txn, eventReference.EventID, eventReference.EventSHA256)
if err == nil {
return true, nil
}
if err == sql.ErrNoRows {
return false, nil
}
return false, fmt.Errorf("u.d.PrevEventsTable.SelectPreviousEventExists: %w", err)
}

// SetLatestEvents implements types.RoomRecentEventsUpdater
func (u *RoomUpdater) SetLatestEvents(
roomNID types.RoomNID, latest []types.StateAtEventAndReference, lastEventNIDSent types.EventNID,
currentStateSnapshotNID types.StateSnapshotNID,
) error {
eventNIDs := make([]types.EventNID, len(latest))
for i := range latest {
eventNIDs[i] = latest[i].EventNID
}
return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
if err := u.d.RoomsTable.UpdateLatestEventNIDs(u.ctx, txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID); err != nil {
return fmt.Errorf("u.d.RoomsTable.updateLatestEventNIDs: %w", err)
}
if roomID, ok := u.d.Cache.GetRoomServerRoomID(roomNID); ok {
if roomInfo, ok := u.d.Cache.GetRoomInfo(roomID); ok {
roomInfo.StateSnapshotNID = currentStateSnapshotNID
roomInfo.IsStub = false
u.d.Cache.StoreRoomInfo(roomID, roomInfo)
}
}
return nil
})
}

// HasEventBeenSent implements types.RoomRecentEventsUpdater
func (u *RoomUpdater) HasEventBeenSent(eventNID types.EventNID) (bool, error) {
return u.d.EventsTable.SelectEventSentToOutput(u.ctx, u.txn, eventNID)
}

// MarkEventAsSent implements types.RoomRecentEventsUpdater
func (u *RoomUpdater) MarkEventAsSent(eventNID types.EventNID) error {
return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
return u.d.EventsTable.UpdateEventSentToOutput(u.ctx, txn, eventNID)
})
}

func (u *RoomUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (*MembershipUpdater, error) {
return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomInfo.RoomNID, targetUserNID, targetLocal)
}
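For orientation, a hedged sketch of how a caller might drive the new RoomUpdater; the identifiers `d`, `roomInfo` and `eventIDs` are assumed to be in scope and this is not a call site taken from the patch. Because the updater implements sqlutil.Transaction, the caller decides when the row lock taken by SelectLatestEventsNIDsForUpdate is released:

// Sketch only, not part of this patch.
updater, err := d.GetRoomUpdater(ctx, roomInfo) // may take a row-level lock on Postgres
if err != nil {
	return err
}
if _, err = updater.StateAtEventIDs(ctx, eventIDs); err != nil {
	_ = updater.Rollback() // release the lock without committing anything
	return err
}
return updater.Commit() // release the lock and keep any writes made via the updater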
@ -26,23 +26,23 @@ import (
const redactionsArePermanent = true

type Database struct {
DB *sql.DB
Cache caching.RoomServerCaches
Writer sqlutil.Writer
EventsTable tables.Events
EventJSONTable tables.EventJSON
EventTypesTable tables.EventTypes
EventStateKeysTable tables.EventStateKeys
RoomsTable tables.Rooms
StateSnapshotTable tables.StateSnapshot
StateBlockTable tables.StateBlock
RoomAliasesTable tables.RoomAliases
PrevEventsTable tables.PreviousEvents
InvitesTable tables.Invites
MembershipTable tables.Membership
PublishedTable tables.Published
RedactionsTable tables.Redactions
GetLatestEventsForUpdateFn func(ctx context.Context, roomInfo types.RoomInfo) (*LatestEventsUpdater, error)
DB *sql.DB
Cache caching.RoomServerCaches
Writer sqlutil.Writer
EventsTable tables.Events
EventJSONTable tables.EventJSON
EventTypesTable tables.EventTypes
EventStateKeysTable tables.EventStateKeys
RoomsTable tables.Rooms
StateSnapshotTable tables.StateSnapshot
StateBlockTable tables.StateBlock
RoomAliasesTable tables.RoomAliases
PrevEventsTable tables.PreviousEvents
InvitesTable tables.Invites
MembershipTable tables.Membership
PublishedTable tables.Published
RedactionsTable tables.Redactions
GetRoomUpdaterFn func(ctx context.Context, roomInfo *types.RoomInfo) (*RoomUpdater, error)
}

func (d *Database) SupportsConcurrentRoomInputs() bool {

@ -51,6 +51,12 @@ func (d *Database) SupportsConcurrentRoomInputs() bool {

func (d *Database) EventTypeNIDs(
ctx context.Context, eventTypes []string,
) (map[string]types.EventTypeNID, error) {
return d.eventTypeNIDs(ctx, nil, eventTypes)
}

func (d *Database) eventTypeNIDs(
ctx context.Context, txn *sql.Tx, eventTypes []string,
) (map[string]types.EventTypeNID, error) {
result := make(map[string]types.EventTypeNID)
remaining := []string{}

@ -62,7 +68,7 @@ func (d *Database) EventTypeNIDs(
}
}
if len(remaining) > 0 {
nids, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, remaining)
nids, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, txn, remaining)
if err != nil {
return nil, err
}

@ -77,11 +83,17 @@ func (d *Database) EventTypeNIDs(
func (d *Database) EventStateKeys(
ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID,
) (map[types.EventStateKeyNID]string, error) {
return d.EventStateKeysTable.BulkSelectEventStateKey(ctx, eventStateKeyNIDs)
return d.EventStateKeysTable.BulkSelectEventStateKey(ctx, nil, eventStateKeyNIDs)
}

func (d *Database) EventStateKeyNIDs(
ctx context.Context, eventStateKeys []string,
) (map[string]types.EventStateKeyNID, error) {
return d.eventStateKeyNIDs(ctx, nil, eventStateKeys)
}

func (d *Database) eventStateKeyNIDs(
ctx context.Context, txn *sql.Tx, eventStateKeys []string,
) (map[string]types.EventStateKeyNID, error) {
result := make(map[string]types.EventStateKeyNID)
remaining := []string{}

@ -93,7 +105,7 @@ func (d *Database) EventStateKeyNIDs(
}
}
if len(remaining) > 0 {
nids, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, remaining)
nids, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, txn, remaining)
if err != nil {
return nil, err
}
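The same refactoring is applied to most Database methods below: the exported method keeps its signature and delegates to an unexported sibling that accepts a *sql.Tx, which a RoomUpdater can populate with its own transaction. A generic sketch of that shape, using a hypothetical lookup (WidgetNIDs/WidgetsTable are not real Dendrite names):

// Sketch with hypothetical names, not part of this patch.
func (d *Database) WidgetNIDs(ctx context.Context, ids []string) (map[string]int64, error) {
	// No transaction in hand: pass nil so the table layer uses the pooled connection.
	return d.widgetNIDs(ctx, nil, ids)
}

func (d *Database) widgetNIDs(ctx context.Context, txn *sql.Tx, ids []string) (map[string]int64, error) {
	// A RoomUpdater passes its own txn here so reads observe rows written
	// earlier in the same transaction.
	return d.WidgetsTable.BulkSelectWidgetNID(ctx, txn, ids)
}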
@ -108,23 +120,31 @@ func (d *Database) EventStateKeyNIDs(
func (d *Database) StateEntriesForEventIDs(
ctx context.Context, eventIDs []string,
) ([]types.StateEntry, error) {
return d.EventsTable.BulkSelectStateEventByID(ctx, eventIDs)
return d.EventsTable.BulkSelectStateEventByID(ctx, nil, eventIDs)
}

func (d *Database) StateEntriesForTuples(
ctx context.Context,
stateBlockNIDs []types.StateBlockNID,
stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntryList, error) {
return d.stateEntriesForTuples(ctx, nil, stateBlockNIDs, stateKeyTuples)
}

func (d *Database) stateEntriesForTuples(
ctx context.Context, txn *sql.Tx,
stateBlockNIDs []types.StateBlockNID,
stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntryList, error) {
entries, err := d.StateBlockTable.BulkSelectStateBlockEntries(
ctx, stateBlockNIDs,
ctx, txn, stateBlockNIDs,
)
if err != nil {
return nil, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", err)
}
lists := []types.StateEntryList{}
for i, entry := range entries {
entries, err := d.EventsTable.BulkSelectStateEventByNID(ctx, entry, stateKeyTuples)
entries, err := d.EventsTable.BulkSelectStateEventByNID(ctx, txn, entry, stateKeyTuples)
if err != nil {
return nil, fmt.Errorf("d.EventsTable.BulkSelectStateEventByNID: %w", err)
}

@ -137,10 +157,14 @@ func (d *Database) StateEntriesForTuples(
}

func (d *Database) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) {
return d.roomInfo(ctx, nil, roomID)
}

func (d *Database) roomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error) {
if roomInfo, ok := d.Cache.GetRoomInfo(roomID); ok {
return &roomInfo, nil
}
roomInfo, err := d.RoomsTable.SelectRoomInfo(ctx, roomID)
roomInfo, err := d.RoomsTable.SelectRoomInfo(ctx, txn, roomID)
if err == nil && roomInfo != nil {
d.Cache.StoreRoomServerRoomID(roomInfo.RoomNID, roomID)
d.Cache.StoreRoomInfo(roomID, *roomInfo)

@ -153,13 +177,22 @@ func (d *Database) AddState(
roomNID types.RoomNID,
stateBlockNIDs []types.StateBlockNID,
state []types.StateEntry,
) (stateNID types.StateSnapshotNID, err error) {
return d.addState(ctx, nil, roomNID, stateBlockNIDs, state)
}

func (d *Database) addState(
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID,
stateBlockNIDs []types.StateBlockNID,
state []types.StateEntry,
) (stateNID types.StateSnapshotNID, err error) {
if len(stateBlockNIDs) > 0 && len(state) > 0 {
// Check to see if the event already appears in any of the existing state
// blocks. If it does then we should not add it again, as this will just
// result in excess state blocks and snapshots.
// TODO: Investigate why this is happening - probably input_events.go!
blocks, berr := d.StateBlockTable.BulkSelectStateBlockEntries(ctx, stateBlockNIDs)
blocks, berr := d.StateBlockTable.BulkSelectStateBlockEntries(ctx, txn, stateBlockNIDs)
if berr != nil {
return 0, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", berr)
}

@ -180,7 +213,7 @@ func (d *Database) AddState(
}
}
}
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
err = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error {
if len(state) > 0 {
// If there's any state left to add then let's add new blocks.
var stateBlockNID types.StateBlockNID
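The comment in addState above explains that entries already present in the existing state blocks are filtered out before new blocks are written. The filtering loop itself sits outside this hunk, so the following is only an approximation of that step, with variable names assumed:

// Approximate sketch (the real loop is not shown in this hunk).
seen := make(map[types.EventNID]struct{})
for _, block := range blocks { // blocks is [][]types.EventNID from BulkSelectStateBlockEntries
	for _, nid := range block {
		seen[nid] = struct{}{}
	}
}
filtered := state[:0]
for _, entry := range state {
	if _, ok := seen[entry.EventNID]; !ok {
		filtered = append(filtered, entry) // only genuinely new events need a new state block
	}
}
state = filtered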
@ -205,7 +238,13 @@ func (d *Database) AddState(
func (d *Database) EventNIDs(
ctx context.Context, eventIDs []string,
) (map[string]types.EventNID, error) {
return d.EventsTable.BulkSelectEventNID(ctx, eventIDs)
return d.eventNIDs(ctx, nil, eventIDs)
}

func (d *Database) eventNIDs(
ctx context.Context, txn *sql.Tx, eventIDs []string,
) (map[string]types.EventNID, error) {
return d.EventsTable.BulkSelectEventNID(ctx, txn, eventIDs)
}

func (d *Database) SetState(

@ -219,24 +258,34 @@ func (d *Database) SetState(
func (d *Database) StateAtEventIDs(
ctx context.Context, eventIDs []string,
) ([]types.StateAtEvent, error) {
return d.EventsTable.BulkSelectStateAtEventByID(ctx, eventIDs)
return d.EventsTable.BulkSelectStateAtEventByID(ctx, nil, eventIDs)
}

func (d *Database) SnapshotNIDFromEventID(
ctx context.Context, eventID string,
) (types.StateSnapshotNID, error) {
_, stateNID, err := d.EventsTable.SelectEvent(ctx, nil, eventID)
return d.snapshotNIDFromEventID(ctx, nil, eventID)
}

func (d *Database) snapshotNIDFromEventID(
ctx context.Context, txn *sql.Tx, eventID string,
) (types.StateSnapshotNID, error) {
_, stateNID, err := d.EventsTable.SelectEvent(ctx, txn, eventID)
return stateNID, err
}

func (d *Database) EventIDs(
ctx context.Context, eventNIDs []types.EventNID,
) (map[types.EventNID]string, error) {
return d.EventsTable.BulkSelectEventID(ctx, eventNIDs)
return d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs)
}

func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) {
nidMap, err := d.EventNIDs(ctx, eventIDs)
return d.eventsFromIDs(ctx, nil, eventIDs)
}

func (d *Database) eventsFromIDs(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.Event, error) {
nidMap, err := d.eventNIDs(ctx, txn, eventIDs)
if err != nil {
return nil, err
}

@ -246,7 +295,7 @@ func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]type
nids = append(nids, nid)
}

return d.Events(ctx, nids)
return d.events(ctx, txn, nids)
}

func (d *Database) LatestEventIDs(

@ -271,21 +320,33 @@ func (d *Database) LatestEventIDs(
func (d *Database) StateBlockNIDs(
ctx context.Context, stateNIDs []types.StateSnapshotNID,
) ([]types.StateBlockNIDList, error) {
return d.StateSnapshotTable.BulkSelectStateBlockNIDs(ctx, stateNIDs)
return d.stateBlockNIDs(ctx, nil, stateNIDs)
}

func (d *Database) stateBlockNIDs(
ctx context.Context, txn *sql.Tx, stateNIDs []types.StateSnapshotNID,
) ([]types.StateBlockNIDList, error) {
return d.StateSnapshotTable.BulkSelectStateBlockNIDs(ctx, txn, stateNIDs)
}

func (d *Database) StateEntries(
ctx context.Context, stateBlockNIDs []types.StateBlockNID,
) ([]types.StateEntryList, error) {
return d.stateEntries(ctx, nil, stateBlockNIDs)
}

func (d *Database) stateEntries(
ctx context.Context, txn *sql.Tx, stateBlockNIDs []types.StateBlockNID,
) ([]types.StateEntryList, error) {
entries, err := d.StateBlockTable.BulkSelectStateBlockEntries(
ctx, stateBlockNIDs,
ctx, txn, stateBlockNIDs,
)
if err != nil {
return nil, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", err)
}
lists := make([]types.StateEntryList, 0, len(entries))
for i, entry := range entries {
eventNIDs, err := d.EventsTable.BulkSelectStateEventByNID(ctx, entry, nil)
eventNIDs, err := d.EventsTable.BulkSelectStateEventByNID(ctx, txn, entry, nil)
if err != nil {
return nil, fmt.Errorf("d.EventsTable.BulkSelectStateEventByNID: %w", err)
}

@ -304,17 +365,17 @@ func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string
}

func (d *Database) GetRoomIDForAlias(ctx context.Context, alias string) (string, error) {
return d.RoomAliasesTable.SelectRoomIDFromAlias(ctx, alias)
return d.RoomAliasesTable.SelectRoomIDFromAlias(ctx, nil, alias)
}

func (d *Database) GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error) {
return d.RoomAliasesTable.SelectAliasesFromRoomID(ctx, roomID)
return d.RoomAliasesTable.SelectAliasesFromRoomID(ctx, nil, roomID)
}

func (d *Database) GetCreatorIDForAlias(
ctx context.Context, alias string,
) (string, error) {
return d.RoomAliasesTable.SelectCreatorIDFromAlias(ctx, alias)
return d.RoomAliasesTable.SelectCreatorIDFromAlias(ctx, nil, alias)
}

func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error {

@ -335,7 +396,7 @@ func (d *Database) GetMembership(ctx context.Context, roomNID types.RoomNID, req

senderMembershipEventNID, senderMembership, isRoomforgotten, err :=
d.MembershipTable.SelectMembershipFromRoomAndTarget(
ctx, roomNID, requestSenderUserNID,
ctx, nil, roomNID, requestSenderUserNID,
)
if err == sql.ErrNoRows {
// The user has never been a member of that room

@ -349,14 +410,20 @@ func (d *Database) GetMembership(ctx context.Context, roomNID types.RoomNID, req

func (d *Database) GetMembershipEventNIDsForRoom(
ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool,
) ([]types.EventNID, error) {
return d.getMembershipEventNIDsForRoom(ctx, nil, roomNID, joinOnly, localOnly)
}

func (d *Database) getMembershipEventNIDsForRoom(
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, joinOnly bool, localOnly bool,
) ([]types.EventNID, error) {
if joinOnly {
return d.MembershipTable.SelectMembershipsFromRoomAndMembership(
ctx, roomNID, tables.MembershipStateJoin, localOnly,
ctx, txn, roomNID, tables.MembershipStateJoin, localOnly,
)
}

return d.MembershipTable.SelectMembershipsFromRoom(ctx, roomNID, localOnly)
return d.MembershipTable.SelectMembershipsFromRoom(ctx, txn, roomNID, localOnly)
}

func (d *Database) GetInvitesForUser(

@ -364,22 +431,28 @@ func (d *Database) GetInvitesForUser(
roomNID types.RoomNID,
targetUserNID types.EventStateKeyNID,
) (senderUserIDs []types.EventStateKeyNID, eventIDs []string, err error) {
return d.InvitesTable.SelectInviteActiveForUserInRoom(ctx, targetUserNID, roomNID)
return d.InvitesTable.SelectInviteActiveForUserInRoom(ctx, nil, targetUserNID, roomNID)
}

func (d *Database) Events(
ctx context.Context, eventNIDs []types.EventNID,
) ([]types.Event, error) {
eventJSONs, err := d.EventJSONTable.BulkSelectEventJSON(ctx, eventNIDs)
return d.events(ctx, nil, eventNIDs)
}

func (d *Database) events(
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
) ([]types.Event, error) {
eventJSONs, err := d.EventJSONTable.BulkSelectEventJSON(ctx, txn, eventNIDs)
if err != nil {
return nil, err
}
eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, eventNIDs)
eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, txn, eventNIDs)
if err != nil {
eventIDs = map[types.EventNID]string{}
}
var roomNIDs map[types.EventNID]types.RoomNID
roomNIDs, err = d.EventsTable.SelectRoomNIDsForEventNIDs(ctx, eventNIDs)
roomNIDs, err = d.EventsTable.SelectRoomNIDsForEventNIDs(ctx, txn, eventNIDs)
if err != nil {
return nil, err
}

@ -398,7 +471,7 @@ func (d *Database) Events(
}
fetchNIDList = append(fetchNIDList, n)
}
dbRoomVersions, err := d.RoomsTable.SelectRoomVersionsForRoomNIDs(ctx, fetchNIDList)
dbRoomVersions, err := d.RoomsTable.SelectRoomVersionsForRoomNIDs(ctx, txn, fetchNIDList)
if err != nil {
return nil, err
}

@ -440,19 +513,19 @@ func (d *Database) MembershipUpdater(
return updater, err
}

func (d *Database) GetLatestEventsForUpdate(
ctx context.Context, roomInfo types.RoomInfo,
) (*LatestEventsUpdater, error) {
if d.GetLatestEventsForUpdateFn != nil {
return d.GetLatestEventsForUpdateFn(ctx, roomInfo)
func (d *Database) GetRoomUpdater(
ctx context.Context, roomInfo *types.RoomInfo,
) (*RoomUpdater, error) {
if d.GetRoomUpdaterFn != nil {
return d.GetRoomUpdaterFn(ctx, roomInfo)
}
txn, err := d.DB.Begin()
if err != nil {
return nil, err
}
var updater *LatestEventsUpdater
var updater *RoomUpdater
_ = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error {
updater, err = NewLatestEventsUpdater(ctx, d, txn, roomInfo)
updater, err = NewRoomUpdater(ctx, d, txn, roomInfo)
return err
})
return updater, err

@ -461,6 +534,13 @@ func (d *Database) GetLatestEventsForUpdate(
func (d *Database) StoreEvent(
ctx context.Context, event *gomatrixserverlib.Event,
authEventNIDs []types.EventNID, isRejected bool,
) (types.EventNID, types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) {
return d.storeEvent(ctx, nil, event, authEventNIDs, isRejected)
}

func (d *Database) storeEvent(
ctx context.Context, updater *RoomUpdater, event *gomatrixserverlib.Event,
authEventNIDs []types.EventNID, isRejected bool,
) (types.EventNID, types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) {
var (
roomNID types.RoomNID

@ -472,8 +552,11 @@ func (d *Database) StoreEvent(
redactedEventID string
err error
)

err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
var txn *sql.Tx
if updater != nil {
txn = updater.txn
}
err = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error {
// TODO: Here we should aim to have two different code paths for new rooms
// vs existing ones.

@ -546,42 +629,32 @@ func (d *Database) StoreEvent(
// events updater because it somewhat works as a mutex, ensuring
// that there's a row-level lock on the latest room events (well,
// on Postgres at least).
var roomInfo *types.RoomInfo
var updater *LatestEventsUpdater
if prevEvents := event.PrevEvents(); len(prevEvents) > 0 {
roomInfo, err = d.RoomInfo(ctx, event.RoomID())
if err != nil {
return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.RoomInfo: %w", err)
}
if roomInfo == nil && len(prevEvents) > 0 {
return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("expected room %q to exist", event.RoomID())
}
// Create an updater - NB: on sqlite this WILL create a txn as we are directly calling the shared DB form of
// GetLatestEventsForUpdate - not via the SQLiteDatabase form which has `nil` txns. This
// function only does SELECTs though so the created txn (at this point) is just a read txn like
// any other so this is fine. If we ever update GetLatestEventsForUpdate or NewLatestEventsUpdater
// to do writes however then this will need to go inside `Writer.Do`.
updater, err = d.GetLatestEventsForUpdate(ctx, *roomInfo)
if err != nil {
return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("NewLatestEventsUpdater: %w", err)
}
// Ensure that we atomically store prev events AND commit them. If we don't wrap StorePreviousEvents
// and EndTransaction in a writer then it's possible for a new write txn to be made between the two
// function calls which will then fail with 'database is locked'. This new write txn would HAVE to be
// something like SetRoomAlias/RemoveRoomAlias as normal input events are already done sequentially due to
// SupportsConcurrentRoomInputs() == false on sqlite, though this does not apply to setting room aliases
// as they don't go via InputRoomEvents
err = d.Writer.Do(d.DB, updater.txn, func(txn *sql.Tx) error {
if err = updater.StorePreviousEvents(eventNID, prevEvents); err != nil {
return fmt.Errorf("updater.StorePreviousEvents: %w", err)
succeeded := false
if updater == nil {
var roomInfo *types.RoomInfo
roomInfo, err = d.RoomInfo(ctx, event.RoomID())
if err != nil {
return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.RoomInfo: %w", err)
}
succeeded := true
err = sqlutil.EndTransaction(updater, &succeeded)
return err
})
if err != nil {
return 0, 0, types.StateAtEvent{}, nil, "", err
if roomInfo == nil && len(prevEvents) > 0 {
return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("expected room %q to exist", event.RoomID())
}
updater, err = d.GetRoomUpdater(ctx, roomInfo)
if err != nil {
return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("GetRoomUpdater: %w", err)
}
defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err)
}
if err = updater.StorePreviousEvents(eventNID, prevEvents); err != nil {
return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("updater.StorePreviousEvents: %w", err)
}
succeeded = true
}

return eventNID, roomNID, types.StateAtEvent{
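The rewritten path above defers sqlutil.EndTransactionWithCheck(updater, &succeeded, &err) and only flips `succeeded` once the previous events are stored, so the updater's transaction is rolled back on any early return. A sketch of the shape such a helper typically has, assuming the sqlutil.Transaction interface referenced by RoomUpdater's Commit/Rollback methods; the real implementation in internal/sqlutil may differ:

// Sketch only, not the actual Dendrite implementation.
type Transaction interface {
	Commit() error
	Rollback() error
}

func EndTransactionWithCheck(txn Transaction, succeeded *bool, err *error) {
	if *succeeded {
		if e := txn.Commit(); e != nil && *err == nil {
			*err = e // surface a commit failure if nothing else already failed
		}
		return
	}
	txn.Rollback() // nolint: errcheck - best-effort rollback on the failure path
}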
@ -603,7 +676,7 @@ func (d *Database) PublishRoom(ctx context.Context, roomID string, publish bool)
}

func (d *Database) GetPublishedRooms(ctx context.Context) ([]string, error) {
return d.PublishedTable.SelectAllPublishedRooms(ctx, true)
return d.PublishedTable.SelectAllPublishedRooms(ctx, nil, true)
}

func (d *Database) assignRoomNID(

@ -875,14 +948,14 @@ func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey s
eventNIDs = append(eventNIDs, e.EventNID)
}
}
eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, eventNIDs)
eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs)
if err != nil {
eventIDs = map[types.EventNID]string{}
}
// return the event requested
for _, e := range entries {
if e.EventTypeNID == eventTypeNID && e.EventStateKeyNID == stateKeyNID {
data, err := d.EventJSONTable.BulkSelectEventJSON(ctx, []types.EventNID{e.EventNID})
data, err := d.EventJSONTable.BulkSelectEventJSON(ctx, nil, []types.EventNID{e.EventNID})
if err != nil {
return nil, err
}

@ -922,11 +995,11 @@ func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership
}
return nil, fmt.Errorf("GetRoomsByMembership: cannot map user ID to state key NID: %w", err)
}
roomNIDs, err := d.MembershipTable.SelectRoomsWithMembership(ctx, stateKeyNID, membershipState)
roomNIDs, err := d.MembershipTable.SelectRoomsWithMembership(ctx, nil, stateKeyNID, membershipState)
if err != nil {
return nil, fmt.Errorf("GetRoomsByMembership: failed to SelectRoomsWithMembership: %w", err)
}
roomIDs, err := d.RoomsTable.BulkSelectRoomIDs(ctx, roomNIDs)
roomIDs, err := d.RoomsTable.BulkSelectRoomIDs(ctx, nil, roomNIDs)
if err != nil {
return nil, fmt.Errorf("GetRoomsByMembership: failed to lookup room nids: %w", err)
}

@ -945,7 +1018,7 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu
}
// we don't bother failing the request if we get asked for event types we don't know about, as all that would result in is no matches which
// isn't a failure.
eventTypeNIDMap, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, eventTypes)
eventTypeNIDMap, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, nil, eventTypes)
if err != nil {
return nil, fmt.Errorf("GetBulkStateContent: failed to map event type nids: %w", err)
}

@ -965,7 +1038,7 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu

}

eventStateKeyNIDMap, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, eventStateKeys)
eventStateKeyNIDMap, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, nil, eventStateKeys)
if err != nil {
return nil, fmt.Errorf("GetBulkStateContent: failed to map state key nids: %w", err)
}

@ -999,11 +1072,11 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu
}
}
}
eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, eventNIDs)
eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs)
if err != nil {
eventIDs = map[types.EventNID]string{}
}
events, err := d.EventJSONTable.BulkSelectEventJSON(ctx, eventNIDs)
events, err := d.EventJSONTable.BulkSelectEventJSON(ctx, nil, eventNIDs)
if err != nil {
return nil, fmt.Errorf("GetBulkStateContent: failed to load event JSON for event nids: %w", err)
}

@ -1027,11 +1100,11 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu

// JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear.
func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) {
roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, roomIDs)
roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, nil, roomIDs)
if err != nil {
return nil, err
}
userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, roomNIDs)
userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, nil, roomNIDs)
if err != nil {
return nil, err
}

@ -1041,7 +1114,7 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string)
stateKeyNIDs[i] = nid
i++
}
nidToUserID, err := d.EventStateKeysTable.BulkSelectEventStateKey(ctx, stateKeyNIDs)
nidToUserID, err := d.EventStateKeysTable.BulkSelectEventStateKey(ctx, nil, stateKeyNIDs)
if err != nil {
return nil, err
}

@ -1057,12 +1130,12 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string)

// GetLocalServerInRoom returns true if we think we're in a given room or false otherwise.
func (d *Database) GetLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) {
return d.MembershipTable.SelectLocalServerInRoom(ctx, roomNID)
return d.MembershipTable.SelectLocalServerInRoom(ctx, nil, roomNID)
}

// GetServerInRoom returns true if we think a server is in a given room or false otherwise.
func (d *Database) GetServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) {
return d.MembershipTable.SelectServerInRoom(ctx, roomNID, serverName)
return d.MembershipTable.SelectServerInRoom(ctx, nil, roomNID, serverName)
}

// GetKnownUsers searches all users that userID knows about.

@ -1071,17 +1144,17 @@ func (d *Database) GetKnownUsers(ctx context.Context, userID, searchString strin
if err != nil {
return nil, err
}
return d.MembershipTable.SelectKnownUsers(ctx, stateKeyNID, searchString, limit)
return d.MembershipTable.SelectKnownUsers(ctx, nil, stateKeyNID, searchString, limit)
}

// GetKnownRooms returns a list of all rooms we know about.
func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) {
return d.RoomsTable.SelectRoomIDs(ctx)
return d.RoomsTable.SelectRoomIDs(ctx, nil)
}

// ForgetRoom sets a users room to forgotten
func (d *Database) ForgetRoom(ctx context.Context, userID, roomID string, forget bool) error {
roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, []string{roomID})
roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, nil, []string{roomID})
if err != nil {
return err
}

@ -76,15 +76,20 @@ func (s *eventJSONStatements) InsertEventJSON(
}

func (s *eventJSONStatements) BulkSelectEventJSON(
ctx context.Context, eventNIDs []types.EventNID,
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
) ([]tables.EventJSONPair, error) {
iEventNIDs := make([]interface{}, len(eventNIDs))
for k, v := range eventNIDs {
iEventNIDs[k] = v
}
selectOrig := strings.Replace(bulkSelectEventJSONSQL, "($1)", sqlutil.QueryVariadic(len(iEventNIDs)), 1)

rows, err := s.db.QueryContext(ctx, selectOrig, iEventNIDs...)
var rows *sql.Rows
var err error
if txn != nil {
rows, err = txn.QueryContext(ctx, selectOrig, iEventNIDs...)
} else {
rows, err = s.db.QueryContext(ctx, selectOrig, iEventNIDs...)
}
if err != nil {
return nil, err
}
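The SQLite variants build their IN clauses at query time with sqlutil.QueryVariadic, so there is no long-lived prepared statement to rebind; the code instead branches on whether a transaction was supplied. A hedged sketch of what a QueryVariadic-style helper produces; the actual helper in internal/sqlutil may differ, for instance in placeholder style:

package sqlutil

import (
	"fmt"
	"strings"
)

// Sketch only: expand n placeholders into "($1, $2, ..., $n)".
func QueryVariadic(count int) string {
	parts := make([]string, count)
	for i := range parts {
		parts[i] = fmt.Sprintf("$%d", i+1)
	}
	return "(" + strings.Join(parts, ", ") + ")"
}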
@ -112,15 +112,20 @@ func (s *eventStateKeyStatements) SelectEventStateKeyNID(
|
|||
}
|
||||
|
||||
func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID(
|
||||
ctx context.Context, eventStateKeys []string,
|
||||
ctx context.Context, txn *sql.Tx, eventStateKeys []string,
|
||||
) (map[string]types.EventStateKeyNID, error) {
|
||||
iEventStateKeys := make([]interface{}, len(eventStateKeys))
|
||||
for k, v := range eventStateKeys {
|
||||
iEventStateKeys[k] = v
|
||||
}
|
||||
selectOrig := strings.Replace(bulkSelectEventStateKeySQL, "($1)", sqlutil.QueryVariadic(len(eventStateKeys)), 1)
|
||||
|
||||
rows, err := s.db.QueryContext(ctx, selectOrig, iEventStateKeys...)
|
||||
var rows *sql.Rows
|
||||
var err error
|
||||
if txn != nil {
|
||||
rows, err = txn.QueryContext(ctx, selectOrig, iEventStateKeys...)
|
||||
} else {
|
||||
rows, err = s.db.QueryContext(ctx, selectOrig, iEventStateKeys...)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -138,15 +143,19 @@ func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID(
|
|||
}
|
||||
|
||||
func (s *eventStateKeyStatements) BulkSelectEventStateKey(
|
||||
ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID,
|
||||
ctx context.Context, txn *sql.Tx, eventStateKeyNIDs []types.EventStateKeyNID,
|
||||
) (map[types.EventStateKeyNID]string, error) {
|
||||
iEventStateKeyNIDs := make([]interface{}, len(eventStateKeyNIDs))
|
||||
for k, v := range eventStateKeyNIDs {
|
||||
iEventStateKeyNIDs[k] = v
|
||||
}
|
||||
selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", sqlutil.QueryVariadic(len(eventStateKeyNIDs)), 1)
|
||||
|
||||
rows, err := s.db.QueryContext(ctx, selectOrig, iEventStateKeyNIDs...)
|
||||
selectPrep, err := s.db.Prepare(selectOrig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stmt := sqlutil.TxStmt(txn, selectPrep)
|
||||
rows, err := stmt.QueryContext(ctx, iEventStateKeyNIDs...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -128,7 +128,7 @@ func (s *eventTypeStatements) SelectEventTypeNID(
|
|||
}
|
||||
|
||||
func (s *eventTypeStatements) BulkSelectEventTypeNID(
|
||||
ctx context.Context, eventTypes []string,
|
||||
ctx context.Context, txn *sql.Tx, eventTypes []string,
|
||||
) (map[string]types.EventTypeNID, error) {
|
||||
///////////////
|
||||
iEventTypes := make([]interface{}, len(eventTypes))
|
||||
|
|
@ -140,9 +140,10 @@ func (s *eventTypeStatements) BulkSelectEventTypeNID(
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stmt := sqlutil.TxStmt(txn, selectPrep)
|
||||
///////////////
|
||||
|
||||
rows, err := selectPrep.QueryContext(ctx, iEventTypes...)
|
||||
rows, err := stmt.QueryContext(ctx, iEventTypes...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -184,7 +184,7 @@ func (s *eventStatements) SelectEvent(
|
|||
// bulkSelectStateEventByID lookups a list of state events by event ID.
|
||||
// If any of the requested events are missing from the database it returns a types.MissingEventError
|
||||
func (s *eventStatements) BulkSelectStateEventByID(
|
||||
ctx context.Context, eventIDs []string,
|
||||
ctx context.Context, txn *sql.Tx, eventIDs []string,
|
||||
) ([]types.StateEntry, error) {
|
||||
///////////////
|
||||
iEventIDs := make([]interface{}, len(eventIDs))
|
||||
|
|
@ -196,6 +196,7 @@ func (s *eventStatements) BulkSelectStateEventByID(
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
selectStmt = sqlutil.TxStmt(txn, selectStmt)
|
||||
///////////////
|
||||
|
||||
rows, err := selectStmt.QueryContext(ctx, iEventIDs...)
|
||||
|
|
@ -235,7 +236,7 @@ func (s *eventStatements) BulkSelectStateEventByID(
|
|||
// bulkSelectStateEventByID lookups a list of state events by event ID.
|
||||
// If any of the requested events are missing from the database it returns a types.MissingEventError
|
||||
func (s *eventStatements) BulkSelectStateEventByNID(
|
||||
ctx context.Context, eventNIDs []types.EventNID,
|
||||
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
|
||||
stateKeyTuples []types.StateKeyTuple,
|
||||
) ([]types.StateEntry, error) {
|
||||
tuples := stateKeyTupleSorter(stateKeyTuples)
|
||||
|
|
@ -263,6 +264,7 @@ func (s *eventStatements) BulkSelectStateEventByNID(
|
|||
if err != nil {
|
||||
return nil, fmt.Errorf("s.db.Prepare: %w", err)
|
||||
}
|
||||
selectStmt = sqlutil.TxStmt(txn, selectStmt)
|
||||
rows, err := selectStmt.QueryContext(ctx, params...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("selectStmt.QueryContext: %w", err)
|
||||
|
|
@ -291,7 +293,7 @@ func (s *eventStatements) BulkSelectStateEventByNID(
|
|||
// If any of the requested events are missing from the database it returns a types.MissingEventError.
|
||||
// If we do not have the state for any of the requested events it returns a types.MissingEventError.
|
||||
func (s *eventStatements) BulkSelectStateAtEventByID(
|
||||
ctx context.Context, eventIDs []string,
|
||||
ctx context.Context, txn *sql.Tx, eventIDs []string,
|
||||
) ([]types.StateAtEvent, error) {
|
||||
///////////////
|
||||
iEventIDs := make([]interface{}, len(eventIDs))
|
||||
|
|
@ -303,6 +305,7 @@ func (s *eventStatements) BulkSelectStateAtEventByID(
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
selectStmt = sqlutil.TxStmt(txn, selectStmt)
|
||||
///////////////
|
||||
rows, err := selectStmt.QueryContext(ctx, iEventIDs...)
|
||||
if err != nil {
|
||||
|
|
@ -381,6 +384,7 @@ func (s *eventStatements) BulkSelectStateAtEventAndReference(
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
selectPrep = sqlutil.TxStmt(txn, selectPrep)
|
||||
//////////////
|
||||
|
||||
rows, err := sqlutil.TxStmt(txn, selectPrep).QueryContext(ctx, iEventNIDs...)
|
||||
|
|
@ -454,7 +458,7 @@ func (s *eventStatements) BulkSelectEventReference(
|
|||
}
|
||||
|
||||
// bulkSelectEventID returns a map from numeric event ID to string event ID.
|
||||
func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) {
|
||||
func (s *eventStatements) BulkSelectEventID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (map[types.EventNID]string, error) {
|
||||
///////////////
|
||||
iEventNIDs := make([]interface{}, len(eventNIDs))
|
||||
for k, v := range eventNIDs {
|
||||
|
|
@ -465,6 +469,7 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []typ
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
selectStmt = sqlutil.TxStmt(txn, selectStmt)
|
||||
///////////////
|
||||
|
||||
rows, err := selectStmt.QueryContext(ctx, iEventNIDs...)
|
||||
|
|
@ -490,7 +495,7 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []typ
|
|||
|
||||
// bulkSelectEventNIDs returns a map from string event ID to numeric event ID.
|
||||
// If an event ID is not in the database then it is omitted from the map.
|
||||
func (s *eventStatements) BulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) {
|
||||
func (s *eventStatements) BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) {
|
||||
///////////////
|
||||
iEventIDs := make([]interface{}, len(eventIDs))
|
||||
for k, v := range eventIDs {
|
||||
|
|
@ -501,6 +506,7 @@ func (s *eventStatements) BulkSelectEventNID(ctx context.Context, eventIDs []str
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
selectStmt = sqlutil.TxStmt(txn, selectStmt)
|
||||
///////////////
|
||||
rows, err := selectStmt.QueryContext(ctx, iEventIDs...)
|
||||
if err != nil {
|
||||
|
|
@ -538,13 +544,14 @@ func (s *eventStatements) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx,
|
|||
}
|
||||
|
||||
func (s *eventStatements) SelectRoomNIDsForEventNIDs(
|
||||
ctx context.Context, eventNIDs []types.EventNID,
|
||||
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
|
||||
) (map[types.EventNID]types.RoomNID, error) {
|
||||
sqlStr := strings.Replace(selectRoomNIDsForEventNIDsSQL, "($1)", sqlutil.QueryVariadic(len(eventNIDs)), 1)
|
||||
sqlPrep, err := s.db.Prepare(sqlStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sqlPrep = sqlutil.TxStmt(txn, sqlPrep)
|
||||
iEventNIDs := make([]interface{}, len(eventNIDs))
|
||||
for i, v := range eventNIDs {
|
||||
iEventNIDs[i] = v
|
||||
|
|
|
|||
|
|
@ -88,8 +88,8 @@ func prepareInvitesTable(db *sql.DB) (tables.Invites, error) {
|
|||
}
|
||||
|
||||
func (s *inviteStatements) InsertInviteEvent(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx, inviteEventID string, roomNID types.RoomNID,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
inviteEventID string, roomNID types.RoomNID,
|
||||
targetUserNID, senderUserNID types.EventStateKeyNID,
|
||||
inviteEventJSON []byte,
|
||||
) (bool, error) {
|
||||
|
|
@ -109,8 +109,8 @@ func (s *inviteStatements) InsertInviteEvent(
|
|||
}
|
||||
|
||||
func (s *inviteStatements) UpdateInviteRetired(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||
) (eventIDs []string, err error) {
|
||||
// gather all the event IDs we will retire
|
||||
stmt := sqlutil.TxStmt(txn, s.selectInvitesAboutToRetireStmt)
|
||||
|
|
@ -134,10 +134,11 @@ func (s *inviteStatements) UpdateInviteRetired(
|
|||
|
||||
// selectInviteActiveForUserInRoom returns a list of sender state key NIDs
|
||||
func (s *inviteStatements) SelectInviteActiveForUserInRoom(
|
||||
ctx context.Context,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
targetUserNID types.EventStateKeyNID, roomNID types.RoomNID,
|
||||
) ([]types.EventStateKeyNID, []string, error) {
|
||||
rows, err := s.selectInviteActiveForUserInRoomStmt.QueryContext(
|
||||
stmt := sqlutil.TxStmt(txn, s.selectInviteActiveForUserInRoomStmt)
|
||||
rows, err := stmt.QueryContext(
|
||||
ctx, targetUserNID, roomNID,
|
||||
)
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -184,17 +184,18 @@ func (s *membershipStatements) SelectMembershipForUpdate(
|
|||
}
|
||||
|
||||
func (s *membershipStatements) SelectMembershipFromRoomAndTarget(
|
||||
ctx context.Context,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||
) (eventNID types.EventNID, membership tables.MembershipState, forgotten bool, err error) {
|
||||
err = s.selectMembershipFromRoomAndTargetStmt.QueryRowContext(
|
||||
stmt := sqlutil.TxStmt(txn, s.selectMembershipFromRoomAndTargetStmt)
|
||||
err = stmt.QueryRowContext(
|
||||
ctx, roomNID, targetUserNID,
|
||||
).Scan(&membership, &eventNID, &forgotten)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *membershipStatements) SelectMembershipsFromRoom(
|
||||
ctx context.Context,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
roomNID types.RoomNID, localOnly bool,
|
||||
) (eventNIDs []types.EventNID, err error) {
|
||||
var selectStmt *sql.Stmt
|
||||
|
|
@ -203,6 +204,7 @@ func (s *membershipStatements) SelectMembershipsFromRoom(
|
|||
} else {
|
||||
selectStmt = s.selectMembershipsFromRoomStmt
|
||||
}
|
||||
selectStmt = sqlutil.TxStmt(txn, selectStmt)
|
||||
rows, err := selectStmt.QueryContext(ctx, roomNID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -220,7 +222,7 @@ func (s *membershipStatements) SelectMembershipsFromRoom(
|
|||
}
|
||||
|
||||
func (s *membershipStatements) SelectMembershipsFromRoomAndMembership(
|
||||
ctx context.Context,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
roomNID types.RoomNID, membership tables.MembershipState, localOnly bool,
|
||||
) (eventNIDs []types.EventNID, err error) {
|
||||
var stmt *sql.Stmt
|
||||
|
|
@ -229,6 +231,7 @@ func (s *membershipStatements) SelectMembershipsFromRoomAndMembership(
|
|||
} else {
|
||||
stmt = s.selectMembershipsFromRoomAndMembershipStmt
|
||||
}
|
||||
stmt = sqlutil.TxStmt(txn, stmt)
|
||||
rows, err := stmt.QueryContext(ctx, roomNID, membership)
|
||||
if err != nil {
|
||||
return
|
||||
|
|
@ -258,9 +261,10 @@ func (s *membershipStatements) UpdateMembership(
|
|||
}
|
||||
|
||||
func (s *membershipStatements) SelectRoomsWithMembership(
|
||||
ctx context.Context, userID types.EventStateKeyNID, membershipState tables.MembershipState,
|
||||
ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState tables.MembershipState,
|
||||
) ([]types.RoomNID, error) {
|
||||
rows, err := s.selectRoomsWithMembershipStmt.QueryContext(ctx, membershipState, userID)
|
||||
stmt := sqlutil.TxStmt(txn, s.selectRoomsWithMembershipStmt)
|
||||
rows, err := stmt.QueryContext(ctx, membershipState, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -276,13 +280,19 @@ func (s *membershipStatements) SelectRoomsWithMembership(
return roomNIDs, nil
}

func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) {
func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) {
iRoomNIDs := make([]interface{}, len(roomNIDs))
for i, v := range roomNIDs {
iRoomNIDs[i] = v
}
query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomNIDs)), 1)
rows, err := s.db.QueryContext(ctx, query, iRoomNIDs...)
var rows *sql.Rows
var err error
if txn != nil {
rows, err = txn.QueryContext(ctx, query, iRoomNIDs...)
} else {
rows, err = s.db.QueryContext(ctx, query, iRoomNIDs...)
}
if err != nil {
return nil, err
}
@ -299,8 +309,9 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context,
return result, rows.Err()
}
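Queries built with sqlutil.QueryVariadic have a variable number of placeholders, so they cannot reuse a long-lived prepared statement and the method branches on the transaction explicitly instead of going through TxStmt. A hedged sketch of that pattern as a standalone helper (hypothetical name, not part of Dendrite):

package sqlutil

import (
    "context"
    "database/sql"
)

// queryWithTxn runs a dynamically built query either inside the supplied
// transaction or, when txn is nil, directly against the database handle.
func queryWithTxn(ctx context.Context, db *sql.DB, txn *sql.Tx, query string, args ...interface{}) (*sql.Rows, error) {
    if txn != nil {
        return txn.QueryContext(ctx, query, args...)
    }
    return db.QueryContext(ctx, query, args...)
}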
|
||||
|
||||
func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) {
|
||||
rows, err := s.selectKnownUsersStmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit)
|
||||
func (s *membershipStatements) SelectKnownUsers(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectKnownUsersStmt)
|
||||
rows, err := stmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -317,8 +328,8 @@ func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID type
|
|||
}
|
||||
|
||||
func (s *membershipStatements) UpdateForgetMembership(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||
forget bool,
|
||||
) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.updateMembershipForgetRoomStmt).ExecContext(
|
||||
|
|
@ -327,9 +338,10 @@ func (s *membershipStatements) UpdateForgetMembership(
|
|||
return err
|
||||
}
|
||||
|
||||
func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) {
|
||||
func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error) {
|
||||
var nid types.RoomNID
|
||||
err := s.selectLocalServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID).Scan(&nid)
|
||||
stmt := sqlutil.TxStmt(txn, s.selectLocalServerInRoomStmt)
|
||||
err := stmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID).Scan(&nid)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return false, nil
|
||||
|
|
@ -340,9 +352,10 @@ func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, room
|
|||
return found, nil
|
||||
}
|
||||
|
||||
func (s *membershipStatements) SelectServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) {
|
||||
func (s *membershipStatements) SelectServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) {
|
||||
var nid types.RoomNID
|
||||
err := s.selectServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID, serverName).Scan(&nid)
|
||||
stmt := sqlutil.TxStmt(txn, s.selectServerInRoomStmt)
|
||||
err := stmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID, serverName).Scan(&nid)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return false, nil
|
||||
|
|
|
|||
|
|
@ -75,9 +75,10 @@ func (s *publishedStatements) UpsertRoomPublished(
|
|||
}
|
||||
|
||||
func (s *publishedStatements) SelectPublishedFromRoomID(
|
||||
ctx context.Context, roomID string,
|
||||
ctx context.Context, txn *sql.Tx, roomID string,
|
||||
) (published bool, err error) {
|
||||
err = s.selectPublishedStmt.QueryRowContext(ctx, roomID).Scan(&published)
|
||||
stmt := sqlutil.TxStmt(txn, s.selectPublishedStmt)
|
||||
err = stmt.QueryRowContext(ctx, roomID).Scan(&published)
|
||||
if err == sql.ErrNoRows {
|
||||
return false, nil
|
||||
}
|
||||
|
|
@ -85,9 +86,10 @@ func (s *publishedStatements) SelectPublishedFromRoomID(
|
|||
}
|
||||
|
||||
func (s *publishedStatements) SelectAllPublishedRooms(
|
||||
ctx context.Context, published bool,
|
||||
ctx context.Context, txn *sql.Tx, published bool,
|
||||
) ([]string, error) {
|
||||
rows, err := s.selectAllPublishedStmt.QueryContext(ctx, published)
|
||||
stmt := sqlutil.TxStmt(txn, s.selectAllPublishedStmt)
|
||||
rows, err := stmt.QueryContext(ctx, published)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -91,9 +91,10 @@ func (s *roomAliasesStatements) InsertRoomAlias(
|
|||
}
|
||||
|
||||
func (s *roomAliasesStatements) SelectRoomIDFromAlias(
|
||||
ctx context.Context, alias string,
|
||||
ctx context.Context, txn *sql.Tx, alias string,
|
||||
) (roomID string, err error) {
|
||||
err = s.selectRoomIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&roomID)
|
||||
stmt := sqlutil.TxStmt(txn, s.selectRoomIDFromAliasStmt)
|
||||
err = stmt.QueryRowContext(ctx, alias).Scan(&roomID)
|
||||
if err == sql.ErrNoRows {
|
||||
return "", nil
|
||||
}
|
||||
|
|
@ -101,10 +102,11 @@ func (s *roomAliasesStatements) SelectRoomIDFromAlias(
|
|||
}
|
||||
|
||||
func (s *roomAliasesStatements) SelectAliasesFromRoomID(
|
||||
ctx context.Context, roomID string,
|
||||
ctx context.Context, txn *sql.Tx, roomID string,
|
||||
) (aliases []string, err error) {
|
||||
aliases = []string{}
|
||||
rows, err := s.selectAliasesFromRoomIDStmt.QueryContext(ctx, roomID)
|
||||
stmt := sqlutil.TxStmt(txn, s.selectAliasesFromRoomIDStmt)
|
||||
rows, err := stmt.QueryContext(ctx, roomID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
|
@ -124,9 +126,10 @@ func (s *roomAliasesStatements) SelectAliasesFromRoomID(
|
|||
}
|
||||
|
||||
func (s *roomAliasesStatements) SelectCreatorIDFromAlias(
|
||||
ctx context.Context, alias string,
|
||||
ctx context.Context, txn *sql.Tx, alias string,
|
||||
) (creatorID string, err error) {
|
||||
err = s.selectCreatorIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&creatorID)
|
||||
stmt := sqlutil.TxStmt(txn, s.selectCreatorIDFromAliasStmt)
|
||||
err = stmt.QueryRowContext(ctx, alias).Scan(&creatorID)
|
||||
if err == sql.ErrNoRows {
|
||||
return "", nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -107,8 +107,9 @@ func prepareRoomsTable(db *sql.DB) (tables.Rooms, error) {
|
|||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *roomStatements) SelectRoomIDs(ctx context.Context) ([]string, error) {
|
||||
rows, err := s.selectRoomIDsStmt.QueryContext(ctx)
|
||||
func (s *roomStatements) SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]string, error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectRoomIDsStmt)
|
||||
rows, err := stmt.QueryContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -124,10 +125,11 @@ func (s *roomStatements) SelectRoomIDs(ctx context.Context) ([]string, error) {
|
|||
return roomIDs, nil
|
||||
}
|
||||
|
||||
func (s *roomStatements) SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) {
|
||||
func (s *roomStatements) SelectRoomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error) {
|
||||
var info types.RoomInfo
|
||||
var latestNIDsJSON string
|
||||
err := s.selectRoomInfoStmt.QueryRowContext(ctx, roomID).Scan(
|
||||
stmt := sqlutil.TxStmt(txn, s.selectRoomInfoStmt)
|
||||
err := stmt.QueryRowContext(ctx, roomID).Scan(
|
||||
&info.RoomVersion, &info.RoomNID, &info.StateSnapshotNID, &latestNIDsJSON,
|
||||
)
|
||||
if err != nil {
|
||||
|
|
@ -224,13 +226,14 @@ func (s *roomStatements) UpdateLatestEventNIDs(
|
|||
}
|
||||
|
||||
func (s *roomStatements) SelectRoomVersionsForRoomNIDs(
|
||||
ctx context.Context, roomNIDs []types.RoomNID,
|
||||
ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID,
|
||||
) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error) {
|
||||
sqlStr := strings.Replace(selectRoomVersionsForRoomNIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1)
|
||||
sqlPrep, err := s.db.Prepare(sqlStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sqlPrep = sqlutil.TxStmt(txn, sqlPrep)
|
||||
iRoomNIDs := make([]interface{}, len(roomNIDs))
|
||||
for i, v := range roomNIDs {
|
||||
iRoomNIDs[i] = v
|
||||
|
|
@ -252,13 +255,19 @@ func (s *roomStatements) SelectRoomVersionsForRoomNIDs(
|
|||
return result, nil
|
||||
}
|
||||
|
||||
func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) {
|
||||
func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) ([]string, error) {
|
||||
iRoomNIDs := make([]interface{}, len(roomNIDs))
|
||||
for i, v := range roomNIDs {
|
||||
iRoomNIDs[i] = v
|
||||
}
|
||||
sqlQuery := strings.Replace(bulkSelectRoomIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1)
|
||||
rows, err := s.db.QueryContext(ctx, sqlQuery, iRoomNIDs...)
|
||||
var rows *sql.Rows
|
||||
var err error
|
||||
if txn != nil {
|
||||
rows, err = txn.QueryContext(ctx, sqlQuery, iRoomNIDs...)
|
||||
} else {
|
||||
rows, err = s.db.QueryContext(ctx, sqlQuery, iRoomNIDs...)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -274,13 +283,19 @@ func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types
|
|||
return roomIDs, nil
|
||||
}
|
||||
|
||||
func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, roomIDs []string) ([]types.RoomNID, error) {
|
||||
func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, roomIDs []string) ([]types.RoomNID, error) {
|
||||
iRoomIDs := make([]interface{}, len(roomIDs))
|
||||
for i, v := range roomIDs {
|
||||
iRoomIDs[i] = v
|
||||
}
|
||||
sqlQuery := strings.Replace(bulkSelectRoomNIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomIDs)), 1)
|
||||
rows, err := s.db.QueryContext(ctx, sqlQuery, iRoomIDs...)
|
||||
var rows *sql.Rows
|
||||
var err error
|
||||
if txn != nil {
|
||||
rows, err = txn.QueryContext(ctx, sqlQuery, iRoomIDs...)
|
||||
} else {
|
||||
rows, err = s.db.QueryContext(ctx, sqlQuery, iRoomIDs...)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -81,8 +81,7 @@ func prepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
|
|||
}
|
||||
|
||||
func (s *stateBlockStatements) BulkInsertStateData(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
entries types.StateEntries,
|
||||
) (id types.StateBlockNID, err error) {
|
||||
entries = entries[:util.SortAndUnique(entries)]
|
||||
|
|
@ -94,14 +93,15 @@ func (s *stateBlockStatements) BulkInsertStateData(
|
|||
if err != nil {
|
||||
return 0, fmt.Errorf("json.Marshal: %w", err)
|
||||
}
|
||||
err = s.insertStateDataStmt.QueryRowContext(
|
||||
stmt := sqlutil.TxStmt(txn, s.insertStateDataStmt)
|
||||
err = stmt.QueryRowContext(
|
||||
ctx, nids.Hash(), js,
|
||||
).Scan(&id)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *stateBlockStatements) BulkSelectStateBlockEntries(
|
||||
ctx context.Context, stateBlockNIDs types.StateBlockNIDs,
|
||||
ctx context.Context, txn *sql.Tx, stateBlockNIDs types.StateBlockNIDs,
|
||||
) ([][]types.EventNID, error) {
|
||||
intfs := make([]interface{}, len(stateBlockNIDs))
|
||||
for i := range stateBlockNIDs {
|
||||
|
|
@ -112,6 +112,7 @@ func (s *stateBlockStatements) BulkSelectStateBlockEntries(
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
selectStmt = sqlutil.TxStmt(txn, selectStmt)
|
||||
rows, err := selectStmt.QueryContext(ctx, intfs...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
|||
|
|
@ -106,7 +106,7 @@ func (s *stateSnapshotStatements) InsertState(
|
|||
}
|
||||
|
||||
func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
|
||||
ctx context.Context, stateNIDs []types.StateSnapshotNID,
|
||||
ctx context.Context, txn *sql.Tx, stateNIDs []types.StateSnapshotNID,
|
||||
) ([]types.StateBlockNIDList, error) {
|
||||
nids := make([]interface{}, len(stateNIDs))
|
||||
for k, v := range stateNIDs {
|
||||
|
|
@ -117,6 +117,7 @@ func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
selectStmt = sqlutil.TxStmt(txn, selectStmt)
|
||||
|
||||
rows, err := selectStmt.QueryContext(ctx, nids...)
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -172,23 +172,23 @@ func (d *Database) prepare(db *sql.DB, cache caching.RoomServerCaches) error {
|
|||
return err
|
||||
}
|
||||
d.Database = shared.Database{
|
||||
DB: db,
|
||||
Cache: cache,
|
||||
Writer: sqlutil.NewExclusiveWriter(),
|
||||
EventsTable: events,
|
||||
EventTypesTable: eventTypes,
|
||||
EventStateKeysTable: eventStateKeys,
|
||||
EventJSONTable: eventJSON,
|
||||
RoomsTable: rooms,
|
||||
StateBlockTable: stateBlock,
|
||||
StateSnapshotTable: stateSnapshot,
|
||||
PrevEventsTable: prevEvents,
|
||||
RoomAliasesTable: roomAliases,
|
||||
InvitesTable: invites,
|
||||
MembershipTable: membership,
|
||||
PublishedTable: published,
|
||||
RedactionsTable: redactions,
|
||||
GetLatestEventsForUpdateFn: d.GetLatestEventsForUpdate,
|
||||
DB: db,
|
||||
Cache: cache,
|
||||
Writer: sqlutil.NewExclusiveWriter(),
|
||||
EventsTable: events,
|
||||
EventTypesTable: eventTypes,
|
||||
EventStateKeysTable: eventStateKeys,
|
||||
EventJSONTable: eventJSON,
|
||||
RoomsTable: rooms,
|
||||
StateBlockTable: stateBlock,
|
||||
StateSnapshotTable: stateSnapshot,
|
||||
PrevEventsTable: prevEvents,
|
||||
RoomAliasesTable: roomAliases,
|
||||
InvitesTable: invites,
|
||||
MembershipTable: membership,
|
||||
PublishedTable: published,
|
||||
RedactionsTable: redactions,
|
||||
GetRoomUpdaterFn: d.GetRoomUpdater,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
@ -201,16 +201,16 @@ func (d *Database) SupportsConcurrentRoomInputs() bool {
|
|||
return false
|
||||
}
|
||||
|
||||
func (d *Database) GetLatestEventsForUpdate(
|
||||
ctx context.Context, roomInfo types.RoomInfo,
|
||||
) (*shared.LatestEventsUpdater, error) {
|
||||
func (d *Database) GetRoomUpdater(
|
||||
ctx context.Context, roomInfo *types.RoomInfo,
|
||||
) (*shared.RoomUpdater, error) {
|
||||
// TODO: Do not use transactions. We should be holding open this transaction but we cannot have
|
||||
// multiple write transactions on sqlite. The code will perform additional
|
||||
// write transactions independent of this one which will consistently cause
|
||||
// 'database is locked' errors. As sqlite doesn't support multi-process on the
|
||||
// same DB anyway, and we only execute updates sequentially, the only worries
|
||||
// are for rolling back when things go wrong. (atomicity)
|
||||
return shared.NewLatestEventsUpdater(ctx, &d.Database, nil, roomInfo)
|
||||
return shared.NewRoomUpdater(ctx, &d.Database, nil, roomInfo)
|
||||
}
|
||||
|
||||
func (d *Database) MembershipUpdater(
|
||||
|
|
|
|||
|
|
@ -18,20 +18,20 @@ type EventJSONPair struct {
|
|||
type EventJSON interface {
|
||||
// Insert the event JSON. On conflict, replace the event JSON with the new value (for redactions).
|
||||
InsertEventJSON(ctx context.Context, tx *sql.Tx, eventNID types.EventNID, eventJSON []byte) error
|
||||
BulkSelectEventJSON(ctx context.Context, eventNIDs []types.EventNID) ([]EventJSONPair, error)
|
||||
BulkSelectEventJSON(ctx context.Context, tx *sql.Tx, eventNIDs []types.EventNID) ([]EventJSONPair, error)
|
||||
}
|
||||
|
||||
type EventTypes interface {
|
||||
InsertEventTypeNID(ctx context.Context, tx *sql.Tx, eventType string) (types.EventTypeNID, error)
|
||||
SelectEventTypeNID(ctx context.Context, tx *sql.Tx, eventType string) (types.EventTypeNID, error)
|
||||
BulkSelectEventTypeNID(ctx context.Context, eventTypes []string) (map[string]types.EventTypeNID, error)
|
||||
BulkSelectEventTypeNID(ctx context.Context, txn *sql.Tx, eventTypes []string) (map[string]types.EventTypeNID, error)
|
||||
}
|
||||
|
||||
type EventStateKeys interface {
|
||||
InsertEventStateKeyNID(ctx context.Context, txn *sql.Tx, eventStateKey string) (types.EventStateKeyNID, error)
|
||||
SelectEventStateKeyNID(ctx context.Context, txn *sql.Tx, eventStateKey string) (types.EventStateKeyNID, error)
|
||||
BulkSelectEventStateKeyNID(ctx context.Context, eventStateKeys []string) (map[string]types.EventStateKeyNID, error)
|
||||
BulkSelectEventStateKey(ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error)
|
||||
BulkSelectEventStateKeyNID(ctx context.Context, txn *sql.Tx, eventStateKeys []string) (map[string]types.EventStateKeyNID, error)
|
||||
BulkSelectEventStateKey(ctx context.Context, txn *sql.Tx, eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error)
|
||||
}
|
||||
|
||||
type Events interface {
|
||||
|
|
@ -42,12 +42,12 @@ type Events interface {
|
|||
SelectEvent(ctx context.Context, txn *sql.Tx, eventID string) (types.EventNID, types.StateSnapshotNID, error)
|
||||
// bulkSelectStateEventByID lookups a list of state events by event ID.
|
||||
// If any of the requested events are missing from the database it returns a types.MissingEventError
|
||||
BulkSelectStateEventByID(ctx context.Context, eventIDs []string) ([]types.StateEntry, error)
|
||||
BulkSelectStateEventByNID(ctx context.Context, eventNIDs []types.EventNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntry, error)
|
||||
BulkSelectStateEventByID(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StateEntry, error)
|
||||
BulkSelectStateEventByNID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntry, error)
|
||||
// BulkSelectStateAtEventByID lookups the state at a list of events by event ID.
|
||||
// If any of the requested events are missing from the database it returns a types.MissingEventError.
|
||||
// If we do not have the state for any of the requested events it returns a types.MissingEventError.
|
||||
BulkSelectStateAtEventByID(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error)
|
||||
BulkSelectStateAtEventByID(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StateAtEvent, error)
|
||||
UpdateEventState(ctx context.Context, txn *sql.Tx, eventNID types.EventNID, stateNID types.StateSnapshotNID) error
|
||||
SelectEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) (sentToOutput bool, err error)
|
||||
UpdateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error
|
||||
|
|
@ -55,12 +55,12 @@ type Events interface {
|
|||
BulkSelectStateAtEventAndReference(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) ([]types.StateAtEventAndReference, error)
|
||||
BulkSelectEventReference(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) ([]gomatrixserverlib.EventReference, error)
|
||||
// BulkSelectEventID returns a map from numeric event ID to string event ID.
|
||||
BulkSelectEventID(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error)
|
||||
BulkSelectEventID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (map[types.EventNID]string, error)
|
||||
// BulkSelectEventNIDs returns a map from string event ID to numeric event ID.
|
||||
// If an event ID is not in the database then it is omitted from the map.
|
||||
BulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error)
|
||||
BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error)
|
||||
SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error)
|
||||
SelectRoomNIDsForEventNIDs(ctx context.Context, eventNIDs []types.EventNID) (roomNIDs map[types.EventNID]types.RoomNID, err error)
|
||||
SelectRoomNIDsForEventNIDs(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (roomNIDs map[types.EventNID]types.RoomNID, err error)
|
||||
}
|
||||
|
||||
type Rooms interface {
|
||||
|
|
@ -69,29 +69,29 @@ type Rooms interface {
|
|||
SelectLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.StateSnapshotNID, error)
|
||||
SelectLatestEventsNIDsForUpdate(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.EventNID, types.StateSnapshotNID, error)
|
||||
UpdateLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID) error
|
||||
SelectRoomVersionsForRoomNIDs(ctx context.Context, roomNID []types.RoomNID) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error)
|
||||
SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error)
|
||||
SelectRoomIDs(ctx context.Context) ([]string, error)
|
||||
BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error)
|
||||
BulkSelectRoomNIDs(ctx context.Context, roomIDs []string) ([]types.RoomNID, error)
|
||||
SelectRoomVersionsForRoomNIDs(ctx context.Context, txn *sql.Tx, roomNID []types.RoomNID) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error)
|
||||
SelectRoomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error)
|
||||
SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]string, error)
|
||||
BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) ([]string, error)
|
||||
BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, roomIDs []string) ([]types.RoomNID, error)
|
||||
}
|
||||
|
||||
type StateSnapshot interface {
|
||||
InsertState(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs types.StateBlockNIDs) (stateNID types.StateSnapshotNID, err error)
|
||||
BulkSelectStateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error)
|
||||
BulkSelectStateBlockNIDs(ctx context.Context, txn *sql.Tx, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error)
|
||||
}
|
||||
|
||||
type StateBlock interface {
|
||||
BulkInsertStateData(ctx context.Context, txn *sql.Tx, entries types.StateEntries) (types.StateBlockNID, error)
|
||||
BulkSelectStateBlockEntries(ctx context.Context, stateBlockNIDs types.StateBlockNIDs) ([][]types.EventNID, error)
|
||||
BulkSelectStateBlockEntries(ctx context.Context, txn *sql.Tx, stateBlockNIDs types.StateBlockNIDs) ([][]types.EventNID, error)
|
||||
//BulkSelectFilteredStateBlockEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error)
|
||||
}
|
||||
|
||||
type RoomAliases interface {
|
||||
InsertRoomAlias(ctx context.Context, txn *sql.Tx, alias string, roomID string, creatorUserID string) (err error)
|
||||
SelectRoomIDFromAlias(ctx context.Context, alias string) (roomID string, err error)
|
||||
SelectAliasesFromRoomID(ctx context.Context, roomID string) ([]string, error)
|
||||
SelectCreatorIDFromAlias(ctx context.Context, alias string) (creatorID string, err error)
|
||||
SelectRoomIDFromAlias(ctx context.Context, txn *sql.Tx, alias string) (roomID string, err error)
|
||||
SelectAliasesFromRoomID(ctx context.Context, txn *sql.Tx, roomID string) ([]string, error)
|
||||
SelectCreatorIDFromAlias(ctx context.Context, txn *sql.Tx, alias string) (creatorID string, err error)
|
||||
DeleteRoomAlias(ctx context.Context, txn *sql.Tx, alias string) (err error)
|
||||
}
|
||||
|
||||
|
|
@ -106,7 +106,7 @@ type Invites interface {
|
|||
InsertInviteEvent(ctx context.Context, txn *sql.Tx, inviteEventID string, roomNID types.RoomNID, targetUserNID, senderUserNID types.EventStateKeyNID, inviteEventJSON []byte) (bool, error)
|
||||
UpdateInviteRetired(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) ([]string, error)
|
||||
// SelectInviteActiveForUserInRoom returns a list of sender state key NIDs and invite event IDs matching those nids.
|
||||
SelectInviteActiveForUserInRoom(ctx context.Context, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID) ([]types.EventStateKeyNID, []string, error)
|
||||
SelectInviteActiveForUserInRoom(ctx context.Context, txn *sql.Tx, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID) ([]types.EventStateKeyNID, []string, error)
|
||||
}
|
||||
|
||||
type MembershipState int64
|
||||
|
|
@ -121,24 +121,24 @@ const (
|
|||
type Membership interface {
|
||||
InsertMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, localTarget bool) error
|
||||
SelectMembershipForUpdate(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (MembershipState, error)
|
||||
SelectMembershipFromRoomAndTarget(ctx context.Context, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (types.EventNID, MembershipState, bool, error)
|
||||
SelectMembershipsFromRoom(ctx context.Context, roomNID types.RoomNID, localOnly bool) (eventNIDs []types.EventNID, err error)
|
||||
SelectMembershipsFromRoomAndMembership(ctx context.Context, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error)
|
||||
SelectMembershipFromRoomAndTarget(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (types.EventNID, MembershipState, bool, error)
|
||||
SelectMembershipsFromRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, localOnly bool) (eventNIDs []types.EventNID, err error)
|
||||
SelectMembershipsFromRoomAndMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error)
|
||||
UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID, forgotten bool) error
|
||||
SelectRoomsWithMembership(ctx context.Context, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error)
|
||||
SelectRoomsWithMembership(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error)
|
||||
// SelectJoinedUsersSetForRooms returns the set of all users in the rooms who are joined to any of these rooms, along with the
|
||||
// counts of how many rooms they are joined.
|
||||
SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error)
|
||||
SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error)
|
||||
SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error)
|
||||
SelectKnownUsers(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error)
|
||||
UpdateForgetMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool) error
|
||||
SelectLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error)
|
||||
SelectServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error)
|
||||
SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error)
|
||||
SelectServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error)
|
||||
}
|
||||
|
||||
type Published interface {
|
||||
UpsertRoomPublished(ctx context.Context, txn *sql.Tx, roomID string, published bool) (err error)
|
||||
SelectPublishedFromRoomID(ctx context.Context, roomID string) (published bool, err error)
|
||||
SelectAllPublishedRooms(ctx context.Context, published bool) ([]string, error)
|
||||
SelectPublishedFromRoomID(ctx context.Context, txn *sql.Tx, roomID string) (published bool, err error)
|
||||
SelectAllPublishedRooms(ctx context.Context, txn *sql.Tx, published bool) ([]string, error)
|
||||
}
|
||||
|
||||
type RedactionInfo struct {
|
||||
|
|
|
|||
|
|
@ -2,8 +2,6 @@ package config

import (
"fmt"

"github.com/nats-io/nats.go"
)

type JetStream struct {
@ -25,8 +23,8 @@ func (c *JetStream) TopicFor(name string) string {
return fmt.Sprintf("%s%s", c.TopicPrefix, name)
}

func (c *JetStream) Durable(name string) nats.SubOpt {
return nats.Durable(c.TopicFor(name))
func (c *JetStream) Durable(name string) string {
return c.TopicFor(name)
}

func (c *JetStream) Defaults(generate bool) {
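Durable now returns a plain string (the prefixed topic name) rather than a nats.SubOpt, because the pull-consumer helper introduced below derives the actual consumer name itself. A rough illustration of the call, assuming a "Dendrite" topic prefix (the real prefix comes from configuration):

package main

import (
    "fmt"

    "github.com/matrix-org/dendrite/setup/config"
)

func main() {
    // Assumed prefix; the real value is taken from the Dendrite config file.
    cfg := config.JetStream{TopicPrefix: "Dendrite"}
    durable := cfg.Durable("SyncAPIKeyChangeConsumer")
    fmt.Println(durable) // e.g. "DendriteSyncAPIKeyChangeConsumer"
    // JetStreamConsumer then creates the pull consumer as durable+"Pull".
}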
@ -1,12 +1,81 @@
package jetstream

import "github.com/nats-io/nats.go"
import (
"context"
"fmt"

func WithJetStreamMessage(msg *nats.Msg, f func(msg *nats.Msg) bool) {
_ = msg.InProgress()
if f(msg) {
_ = msg.Ack()
} else {
_ = msg.Nak()
"github.com/nats-io/nats.go"
"github.com/sirupsen/logrus"
)

func JetStreamConsumer(
ctx context.Context, js nats.JetStreamContext, subj, durable string,
f func(ctx context.Context, msg *nats.Msg) bool,
opts ...nats.SubOpt,
) error {
defer func() {
// If there are existing consumers from before they were pull
// consumers, we need to clean up the old push consumers. However,
// in order to not affect the interest-based policies, we need to
// do this *after* creating the new pull consumers, which have
// "Pull" suffixed to their name.
if _, err := js.ConsumerInfo(subj, durable); err == nil {
if err := js.DeleteConsumer(subj, durable); err != nil {
logrus.WithContext(ctx).Warnf("Failed to clean up old consumer %q", durable)
}
}
}()

name := durable + "Pull"
sub, err := js.PullSubscribe(subj, name, opts...)
if err != nil {
return fmt.Errorf("nats.SubscribeSync: %w", err)
}
go func() {
for {
// The context behaviour here is surprising — we supply a context
// so that we can interrupt the fetch if we want, but NATS will still
// enforce its own deadline (roughly 5 seconds by default). Therefore
// it is our responsibility to check whether our context expired or
// not when a context error is returned. Footguns. Footguns everywhere.
msgs, err := sub.Fetch(1, nats.Context(ctx))
if err != nil {
if err == context.Canceled || err == context.DeadlineExceeded {
// Work out whether it was the JetStream context that expired
// or whether it was our supplied context.
select {
case <-ctx.Done():
// The supplied context expired, so we want to stop the
// consumer altogether.
return
default:
// The JetStream context expired, so the fetch probably
// just timed out and we should try again.
continue
}
} else {
// Something else went wrong, so we'll panic.
logrus.WithContext(ctx).WithField("subject", subj).Fatal(err)
}
}
if len(msgs) < 1 {
continue
}
msg := msgs[0]
if err = msg.InProgress(); err != nil {
logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.InProgress: %w", err))
continue
}
if f(ctx, msg) {
if err = msg.Ack(); err != nil {
logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.Ack: %w", err))
}
} else {
if err = msg.Nak(); err != nil {
logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.Nak: %w", err))
}
}
}
}()
return nil
}
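The consumer components in this change all wire themselves up to this helper in the same way: Start hands the subject, durable name and a handler to JetStreamConsumer, and the handler returns true to Ack the message or false to Nak it so JetStream redelivers it later. A minimal sketch of that wiring (illustrative names, not a real Dendrite component):

package example

import (
    "context"

    "github.com/matrix-org/dendrite/setup/jetstream"
    "github.com/nats-io/nats.go"
)

type consumer struct {
    ctx     context.Context
    js      nats.JetStreamContext
    topic   string
    durable string
}

// Start registers the pull consumer and begins fetching messages.
func (c *consumer) Start() error {
    return jetstream.JetStreamConsumer(
        c.ctx, c.js, c.topic, c.durable, c.onMessage,
        nats.DeliverAll(), nats.ManualAck(),
    )
}

// onMessage processes one message; returning false causes a Nak and redelivery.
func (c *consumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
    _ = msg.Data // handle the payload here
    return true
}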
@ -5,20 +5,17 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Shopify/sarama"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
saramajs "github.com/S7evinK/saramajetstream"
|
||||
natsserver "github.com/nats-io/nats-server/v2/server"
|
||||
"github.com/nats-io/nats.go"
|
||||
natsclient "github.com/nats-io/nats.go"
|
||||
)
|
||||
|
||||
var natsServer *natsserver.Server
|
||||
var natsServerMutex sync.Mutex
|
||||
|
||||
func Prepare(cfg *config.JetStream) (nats.JetStreamContext, sarama.Consumer, sarama.SyncProducer) {
|
||||
func Prepare(cfg *config.JetStream) natsclient.JetStreamContext {
|
||||
// check if we need an in-process NATS Server
|
||||
if len(cfg.Addresses) != 0 {
|
||||
return setupNATS(cfg, nil)
|
||||
|
|
@ -52,20 +49,20 @@ func Prepare(cfg *config.JetStream) (nats.JetStreamContext, sarama.Consumer, sar
|
|||
return setupNATS(cfg, nc)
|
||||
}
|
||||
|
||||
func setupNATS(cfg *config.JetStream, nc *natsclient.Conn) (nats.JetStreamContext, sarama.Consumer, sarama.SyncProducer) {
|
||||
func setupNATS(cfg *config.JetStream, nc *natsclient.Conn) natsclient.JetStreamContext {
|
||||
if nc == nil {
|
||||
var err error
|
||||
nc, err = nats.Connect(strings.Join(cfg.Addresses, ","))
|
||||
nc, err = natsclient.Connect(strings.Join(cfg.Addresses, ","))
|
||||
if err != nil {
|
||||
logrus.WithError(err).Panic("Unable to connect to NATS")
|
||||
return nil, nil, nil
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
s, err := nc.JetStream()
|
||||
if err != nil {
|
||||
logrus.WithError(err).Panic("Unable to get JetStream context")
|
||||
return nil, nil, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, stream := range streams { // streams are defined in streams.go
|
||||
|
|
@ -80,7 +77,7 @@ func setupNATS(cfg *config.JetStream, nc *natsclient.Conn) (nats.JetStreamContex
|
|||
// If we're trying to keep everything in memory (e.g. unit tests)
|
||||
// then overwrite the storage policy.
|
||||
if cfg.InMemory {
|
||||
stream.Storage = nats.MemoryStorage
|
||||
stream.Storage = natsclient.MemoryStorage
|
||||
}
|
||||
|
||||
// Namespace the streams without modifying the original streams
|
||||
|
|
@ -93,7 +90,5 @@ func setupNATS(cfg *config.JetStream, nc *natsclient.Conn) (nats.JetStreamContex
|
|||
}
|
||||
}
|
||||
|
||||
consumer := saramajs.NewJetStreamConsumer(nc, s, "")
|
||||
producer := saramajs.NewJetStreamProducer(nc, s, "")
|
||||
return s, consumer, producer
|
||||
return s
|
||||
}
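With the Kafka shim removed, Prepare now hands back only the JetStream context. A hedged sketch of obtaining it under the new signature; the field values here are assumptions for a test-style setup, not defaults, and with no Addresses configured Prepare falls back to the embedded NATS server:

package example

import (
    "github.com/matrix-org/dendrite/setup/config"
    "github.com/matrix-org/dendrite/setup/jetstream"
)

func newJetStream() {
    cfg := &config.JetStream{
        TopicPrefix: "Dendrite", // assumed prefix
        InMemory:    true,       // hypothetical test configuration, keeps streams in memory
    }
    js := jetstream.Prepare(cfg) // single return value after this change
    _ = js
}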
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ import (
|
|||
type OutputClientDataConsumer struct {
|
||||
ctx context.Context
|
||||
jetstream nats.JetStreamContext
|
||||
durable nats.SubOpt
|
||||
durable string
|
||||
topic string
|
||||
db storage.Database
|
||||
stream types.StreamProvider
|
||||
|
|
@ -63,45 +63,45 @@ func NewOutputClientDataConsumer(
|
|||
|
||||
// Start consuming from room servers
|
||||
func (s *OutputClientDataConsumer) Start() error {
|
||||
_, err := s.jetstream.Subscribe(s.topic, s.onMessage, s.durable)
|
||||
return err
|
||||
return jetstream.JetStreamConsumer(
|
||||
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage,
|
||||
nats.DeliverAll(), nats.ManualAck(),
|
||||
)
|
||||
}
|
||||
|
||||
// onMessage is called when the sync server receives a new event from the client API server output log.
|
||||
// It is not safe for this function to be called from multiple goroutines, or else the
|
||||
// sync stream position may race and be incorrectly calculated.
|
||||
func (s *OutputClientDataConsumer) onMessage(msg *nats.Msg) {
|
||||
jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool {
|
||||
// Parse out the event JSON
|
||||
userID := msg.Header.Get(jetstream.UserID)
|
||||
var output eventutil.AccountData
|
||||
if err := json.Unmarshal(msg.Data, &output); err != nil {
|
||||
// If the message was invalid, log it and move on to the next message in the stream
|
||||
log.WithError(err).Errorf("client API server output log: message parse failure")
|
||||
sentry.CaptureException(err)
|
||||
return true
|
||||
}
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"type": output.Type,
|
||||
"room_id": output.RoomID,
|
||||
}).Debug("Received data from client API server")
|
||||
|
||||
streamPos, err := s.db.UpsertAccountData(
|
||||
s.ctx, userID, output.RoomID, output.Type,
|
||||
)
|
||||
if err != nil {
|
||||
sentry.CaptureException(err)
|
||||
log.WithFields(log.Fields{
|
||||
"type": output.Type,
|
||||
"room_id": output.RoomID,
|
||||
log.ErrorKey: err,
|
||||
}).Panicf("could not save account data")
|
||||
}
|
||||
|
||||
s.stream.Advance(streamPos)
|
||||
s.notifier.OnNewAccountData(userID, types.StreamingToken{AccountDataPosition: streamPos})
|
||||
|
||||
func (s *OutputClientDataConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
|
||||
// Parse out the event JSON
|
||||
userID := msg.Header.Get(jetstream.UserID)
|
||||
var output eventutil.AccountData
|
||||
if err := json.Unmarshal(msg.Data, &output); err != nil {
|
||||
// If the message was invalid, log it and move on to the next message in the stream
|
||||
log.WithError(err).Errorf("client API server output log: message parse failure")
|
||||
sentry.CaptureException(err)
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"type": output.Type,
|
||||
"room_id": output.RoomID,
|
||||
}).Debug("Received data from client API server")
|
||||
|
||||
streamPos, err := s.db.UpsertAccountData(
|
||||
s.ctx, userID, output.RoomID, output.Type,
|
||||
)
|
||||
if err != nil {
|
||||
sentry.CaptureException(err)
|
||||
log.WithFields(log.Fields{
|
||||
"type": output.Type,
|
||||
"room_id": output.RoomID,
|
||||
log.ErrorKey: err,
|
||||
}).Panicf("could not save account data")
|
||||
}
|
||||
|
||||
s.stream.Advance(streamPos)
|
||||
s.notifier.OnNewAccountData(userID, types.StreamingToken{AccountDataPosition: streamPos})
|
||||
|
||||
return true
|
||||
}
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ import (
|
|||
type OutputReceiptEventConsumer struct {
|
||||
ctx context.Context
|
||||
jetstream nats.JetStreamContext
|
||||
durable nats.SubOpt
|
||||
durable string
|
||||
topic string
|
||||
db storage.Database
|
||||
stream types.StreamProvider
|
||||
|
|
@ -64,36 +64,36 @@ func NewOutputReceiptEventConsumer(
|
|||
|
||||
// Start consuming from EDU api
|
||||
func (s *OutputReceiptEventConsumer) Start() error {
|
||||
_, err := s.jetstream.Subscribe(s.topic, s.onMessage, s.durable)
|
||||
return err
|
||||
return jetstream.JetStreamConsumer(
|
||||
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage,
|
||||
nats.DeliverAll(), nats.ManualAck(),
|
||||
)
|
||||
}
|
||||
|
||||
func (s *OutputReceiptEventConsumer) onMessage(msg *nats.Msg) {
|
||||
jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool {
|
||||
var output api.OutputReceiptEvent
|
||||
if err := json.Unmarshal(msg.Data, &output); err != nil {
|
||||
// If the message was invalid, log it and move on to the next message in the stream
|
||||
log.WithError(err).Errorf("EDU server output log: message parse failure")
|
||||
sentry.CaptureException(err)
|
||||
return true
|
||||
}
|
||||
|
||||
streamPos, err := s.db.StoreReceipt(
|
||||
s.ctx,
|
||||
output.RoomID,
|
||||
output.Type,
|
||||
output.UserID,
|
||||
output.EventID,
|
||||
output.Timestamp,
|
||||
)
|
||||
if err != nil {
|
||||
sentry.CaptureException(err)
|
||||
return true
|
||||
}
|
||||
|
||||
s.stream.Advance(streamPos)
|
||||
s.notifier.OnNewReceipt(output.RoomID, types.StreamingToken{ReceiptPosition: streamPos})
|
||||
|
||||
func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
|
||||
var output api.OutputReceiptEvent
|
||||
if err := json.Unmarshal(msg.Data, &output); err != nil {
|
||||
// If the message was invalid, log it and move on to the next message in the stream
|
||||
log.WithError(err).Errorf("EDU server output log: message parse failure")
|
||||
sentry.CaptureException(err)
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
streamPos, err := s.db.StoreReceipt(
|
||||
s.ctx,
|
||||
output.RoomID,
|
||||
output.Type,
|
||||
output.UserID,
|
||||
output.EventID,
|
||||
output.Timestamp,
|
||||
)
|
||||
if err != nil {
|
||||
sentry.CaptureException(err)
|
||||
return true
|
||||
}
|
||||
|
||||
s.stream.Advance(streamPos)
|
||||
s.notifier.OnNewReceipt(output.RoomID, types.StreamingToken{ReceiptPosition: streamPos})
|
||||
|
||||
return true
|
||||
}
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ import (
|
|||
type OutputSendToDeviceEventConsumer struct {
|
||||
ctx context.Context
|
||||
jetstream nats.JetStreamContext
|
||||
durable nats.SubOpt
|
||||
durable string
|
||||
topic string
|
||||
db storage.Database
|
||||
serverName gomatrixserverlib.ServerName // our server name
|
||||
|
|
@ -68,52 +68,52 @@ func NewOutputSendToDeviceEventConsumer(
|
|||
|
||||
// Start consuming from EDU api
|
||||
func (s *OutputSendToDeviceEventConsumer) Start() error {
|
||||
_, err := s.jetstream.Subscribe(s.topic, s.onMessage, s.durable)
|
||||
return err
|
||||
return jetstream.JetStreamConsumer(
|
||||
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage,
|
||||
nats.DeliverAll(), nats.ManualAck(),
|
||||
)
|
||||
}
|
||||
|
||||
func (s *OutputSendToDeviceEventConsumer) onMessage(msg *nats.Msg) {
|
||||
jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool {
|
||||
var output api.OutputSendToDeviceEvent
|
||||
if err := json.Unmarshal(msg.Data, &output); err != nil {
|
||||
// If the message was invalid, log it and move on to the next message in the stream
|
||||
log.WithError(err).Errorf("EDU server output log: message parse failure")
|
||||
sentry.CaptureException(err)
|
||||
return true
|
||||
}
|
||||
|
||||
_, domain, err := gomatrixserverlib.SplitID('@', output.UserID)
|
||||
if err != nil {
|
||||
sentry.CaptureException(err)
|
||||
return true
|
||||
}
|
||||
if domain != s.serverName {
|
||||
return true
|
||||
}
|
||||
|
||||
util.GetLogger(context.TODO()).WithFields(log.Fields{
|
||||
"sender": output.Sender,
|
||||
"user_id": output.UserID,
|
||||
"device_id": output.DeviceID,
|
||||
"event_type": output.Type,
|
||||
}).Info("sync API received send-to-device event from EDU server")
|
||||
|
||||
streamPos, err := s.db.StoreNewSendForDeviceMessage(
|
||||
s.ctx, output.UserID, output.DeviceID, output.SendToDeviceEvent,
|
||||
)
|
||||
if err != nil {
|
||||
sentry.CaptureException(err)
|
||||
log.WithError(err).Errorf("failed to store send-to-device message")
|
||||
return false
|
||||
}
|
||||
|
||||
s.stream.Advance(streamPos)
|
||||
s.notifier.OnNewSendToDevice(
|
||||
output.UserID,
|
||||
[]string{output.DeviceID},
|
||||
types.StreamingToken{SendToDevicePosition: streamPos},
|
||||
)
|
||||
|
||||
func (s *OutputSendToDeviceEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
|
||||
var output api.OutputSendToDeviceEvent
|
||||
if err := json.Unmarshal(msg.Data, &output); err != nil {
|
||||
// If the message was invalid, log it and move on to the next message in the stream
|
||||
log.WithError(err).Errorf("EDU server output log: message parse failure")
|
||||
sentry.CaptureException(err)
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
_, domain, err := gomatrixserverlib.SplitID('@', output.UserID)
|
||||
if err != nil {
|
||||
sentry.CaptureException(err)
|
||||
return true
|
||||
}
|
||||
if domain != s.serverName {
|
||||
return true
|
||||
}
|
||||
|
||||
util.GetLogger(context.TODO()).WithFields(log.Fields{
|
||||
"sender": output.Sender,
|
||||
"user_id": output.UserID,
|
||||
"device_id": output.DeviceID,
|
||||
"event_type": output.Type,
|
||||
}).Info("sync API received send-to-device event from EDU server")
|
||||
|
||||
streamPos, err := s.db.StoreNewSendForDeviceMessage(
|
||||
s.ctx, output.UserID, output.DeviceID, output.SendToDeviceEvent,
|
||||
)
|
||||
if err != nil {
|
||||
sentry.CaptureException(err)
|
||||
log.WithError(err).Errorf("failed to store send-to-device message")
|
||||
return false
|
||||
}
|
||||
|
||||
s.stream.Advance(streamPos)
|
||||
s.notifier.OnNewSendToDevice(
|
||||
output.UserID,
|
||||
[]string{output.DeviceID},
|
||||
types.StreamingToken{SendToDevicePosition: streamPos},
|
||||
)
|
||||
|
||||
return true
|
||||
}
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ import (
|
|||
type OutputTypingEventConsumer struct {
|
||||
ctx context.Context
|
||||
jetstream nats.JetStreamContext
|
||||
durable nats.SubOpt
|
||||
durable string
|
||||
topic string
|
||||
eduCache *cache.EDUCache
|
||||
stream types.StreamProvider
|
||||
|
|
@ -66,41 +66,41 @@ func NewOutputTypingEventConsumer(
|
|||
|
||||
// Start consuming from EDU api
|
||||
func (s *OutputTypingEventConsumer) Start() error {
|
||||
_, err := s.jetstream.Subscribe(s.topic, s.onMessage, s.durable)
|
||||
return err
|
||||
return jetstream.JetStreamConsumer(
|
||||
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage,
|
||||
nats.DeliverAll(), nats.ManualAck(),
|
||||
)
|
||||
}
|
||||
|
||||
func (s *OutputTypingEventConsumer) onMessage(msg *nats.Msg) {
|
||||
jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool {
|
||||
var output api.OutputTypingEvent
|
||||
if err := json.Unmarshal(msg.Data, &output); err != nil {
|
||||
// If the message was invalid, log it and move on to the next message in the stream
|
||||
log.WithError(err).Errorf("EDU server output log: message parse failure")
|
||||
sentry.CaptureException(err)
|
||||
return true
|
||||
}
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"room_id": output.Event.RoomID,
|
||||
"user_id": output.Event.UserID,
|
||||
"typing": output.Event.Typing,
|
||||
}).Debug("received data from EDU server")
|
||||
|
||||
var typingPos types.StreamPosition
|
||||
typingEvent := output.Event
|
||||
if typingEvent.Typing {
|
||||
typingPos = types.StreamPosition(
|
||||
s.eduCache.AddTypingUser(typingEvent.UserID, typingEvent.RoomID, output.ExpireTime),
|
||||
)
|
||||
} else {
|
||||
typingPos = types.StreamPosition(
|
||||
s.eduCache.RemoveUser(typingEvent.UserID, typingEvent.RoomID),
|
||||
)
|
||||
}
|
||||
|
||||
s.stream.Advance(typingPos)
|
||||
s.notifier.OnNewTyping(output.Event.RoomID, types.StreamingToken{TypingPosition: typingPos})
|
||||
|
||||
func (s *OutputTypingEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
|
||||
var output api.OutputTypingEvent
|
||||
if err := json.Unmarshal(msg.Data, &output); err != nil {
|
||||
// If the message was invalid, log it and move on to the next message in the stream
|
||||
log.WithError(err).Errorf("EDU server output log: message parse failure")
|
||||
sentry.CaptureException(err)
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"room_id": output.Event.RoomID,
|
||||
"user_id": output.Event.UserID,
|
||||
"typing": output.Event.Typing,
|
||||
}).Debug("received data from EDU server")
|
||||
|
||||
var typingPos types.StreamPosition
|
||||
typingEvent := output.Event
|
||||
if typingEvent.Typing {
|
||||
typingPos = types.StreamPosition(
|
||||
s.eduCache.AddTypingUser(typingEvent.UserID, typingEvent.RoomID, output.ExpireTime),
|
||||
)
|
||||
} else {
|
||||
typingPos = types.StreamPosition(
|
||||
s.eduCache.RemoveUser(typingEvent.UserID, typingEvent.RoomID),
|
||||
)
|
||||
}
|
||||
|
||||
s.stream.Advance(typingPos)
|
||||
s.notifier.OnNewTyping(output.Event.RoomID, types.StreamingToken{TypingPosition: typingPos})
|
||||
|
||||
return true
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,84 +18,81 @@ import (
|
|||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/Shopify/sarama"
|
||||
"github.com/getsentry/sentry-go"
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/keyserver/api"
|
||||
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/setup/jetstream"
|
||||
"github.com/matrix-org/dendrite/setup/process"
|
||||
"github.com/matrix-org/dendrite/syncapi/notifier"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage"
|
||||
"github.com/matrix-org/dendrite/syncapi/types"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/nats-io/nats.go"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// OutputKeyChangeEventConsumer consumes events that originated in the key server.
|
||||
type OutputKeyChangeEventConsumer struct {
|
||||
ctx context.Context
|
||||
keyChangeConsumer *internal.ContinualConsumer
|
||||
db storage.Database
|
||||
notifier *notifier.Notifier
|
||||
stream types.StreamProvider
|
||||
serverName gomatrixserverlib.ServerName // our server name
|
||||
rsAPI roomserverAPI.RoomserverInternalAPI
|
||||
keyAPI api.KeyInternalAPI
|
||||
ctx context.Context
|
||||
jetstream nats.JetStreamContext
|
||||
durable string
|
||||
topic string
|
||||
db storage.Database
|
||||
notifier *notifier.Notifier
|
||||
stream types.StreamProvider
|
||||
serverName gomatrixserverlib.ServerName // our server name
|
||||
rsAPI roomserverAPI.RoomserverInternalAPI
|
||||
keyAPI api.KeyInternalAPI
|
||||
}
|
||||
|
||||
// NewOutputKeyChangeEventConsumer creates a new OutputKeyChangeEventConsumer.
|
||||
// Call Start() to begin consuming from the key server.
|
||||
func NewOutputKeyChangeEventConsumer(
|
||||
process *process.ProcessContext,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
cfg *config.SyncAPI,
|
||||
topic string,
|
||||
kafkaConsumer sarama.Consumer,
|
||||
js nats.JetStreamContext,
|
||||
keyAPI api.KeyInternalAPI,
|
||||
rsAPI roomserverAPI.RoomserverInternalAPI,
|
||||
store storage.Database,
|
||||
notifier *notifier.Notifier,
|
||||
stream types.StreamProvider,
|
||||
) *OutputKeyChangeEventConsumer {
|
||||
|
||||
consumer := internal.ContinualConsumer{
|
||||
Process: process,
|
||||
ComponentName: "syncapi/keychange",
|
||||
Topic: topic,
|
||||
Consumer: kafkaConsumer,
|
||||
PartitionStore: store,
|
||||
}
|
||||
|
||||
s := &OutputKeyChangeEventConsumer{
|
||||
ctx: process.Context(),
|
||||
keyChangeConsumer: &consumer,
|
||||
db: store,
|
||||
serverName: serverName,
|
||||
keyAPI: keyAPI,
|
||||
rsAPI: rsAPI,
|
||||
notifier: notifier,
|
||||
stream: stream,
|
||||
ctx: process.Context(),
|
||||
jetstream: js,
|
||||
durable: cfg.Matrix.JetStream.Durable("SyncAPIKeyChangeConsumer"),
|
||||
topic: topic,
|
||||
db: store,
|
||||
serverName: cfg.Matrix.ServerName,
|
||||
keyAPI: keyAPI,
|
||||
rsAPI: rsAPI,
|
||||
notifier: notifier,
|
||||
stream: stream,
|
||||
}
|
||||
|
||||
consumer.ProcessMessage = s.onMessage
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// Start consuming from the key server
|
||||
func (s *OutputKeyChangeEventConsumer) Start() error {
|
||||
return s.keyChangeConsumer.Start()
|
||||
return jetstream.JetStreamConsumer(
|
||||
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage,
|
||||
nats.DeliverAll(), nats.ManualAck(),
|
||||
)
|
||||
}
|
||||
|
||||
func (s *OutputKeyChangeEventConsumer) onMessage(msg *sarama.ConsumerMessage) error {
|
||||
func (s *OutputKeyChangeEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
|
||||
var m api.DeviceMessage
|
||||
if err := json.Unmarshal(msg.Value, &m); err != nil {
|
||||
if err := json.Unmarshal(msg.Data, &m); err != nil {
|
||||
logrus.WithError(err).Errorf("failed to read device message from key change topic")
|
||||
return nil
|
||||
return true
|
||||
}
|
||||
if m.DeviceKeys == nil && m.OutputCrossSigningKeyUpdate == nil {
|
||||
// This probably shouldn't happen but stops us from panicking if we come
|
||||
// across an update that doesn't satisfy either types.
|
||||
return nil
|
||||
return true
|
||||
}
|
||||
switch m.Type {
|
||||
case api.TypeCrossSigningUpdate:
|
||||
|
|
@ -107,9 +104,9 @@ func (s *OutputKeyChangeEventConsumer) onMessage(msg *sarama.ConsumerMessage) er
|
|||
}
|
||||
}
|
||||
|
||||
func (s *OutputKeyChangeEventConsumer) onDeviceKeyMessage(m api.DeviceMessage, deviceChangeID int64) error {
|
||||
func (s *OutputKeyChangeEventConsumer) onDeviceKeyMessage(m api.DeviceMessage, deviceChangeID int64) bool {
|
||||
if m.DeviceKeys == nil {
|
||||
return nil
|
||||
return true
|
||||
}
|
||||
output := m.DeviceKeys
|
||||
// work out who we need to notify about the new key
|
||||
|
|
@ -120,7 +117,7 @@ func (s *OutputKeyChangeEventConsumer) onDeviceKeyMessage(m api.DeviceMessage, d
|
|||
if err != nil {
|
||||
logrus.WithError(err).Error("syncapi: failed to QuerySharedUsers for key change event from key server")
|
||||
sentry.CaptureException(err)
|
||||
return err
|
||||
return true
|
||||
}
|
||||
// make sure we get our own key updates too!
|
||||
queryRes.UserIDsToCount[output.UserID] = 1
|
||||
|
|
@ -131,10 +128,10 @@ func (s *OutputKeyChangeEventConsumer) onDeviceKeyMessage(m api.DeviceMessage, d
|
|||
s.notifier.OnNewKeyChange(types.StreamingToken{DeviceListPosition: posUpdate}, userID, output.UserID)
|
||||
}
|
||||
|
||||
return nil
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *OutputKeyChangeEventConsumer) onCrossSigningMessage(m api.DeviceMessage, deviceChangeID int64) error {
|
||||
func (s *OutputKeyChangeEventConsumer) onCrossSigningMessage(m api.DeviceMessage, deviceChangeID int64) bool {
|
||||
output := m.CrossSigningKeyUpdate
|
||||
// work out who we need to notify about the new key
|
||||
var queryRes roomserverAPI.QuerySharedUsersResponse
|
||||
|
|
@ -144,7 +141,7 @@ func (s *OutputKeyChangeEventConsumer) onCrossSigningMessage(m api.DeviceMessage
|
|||
if err != nil {
|
||||
logrus.WithError(err).Error("syncapi: failed to QuerySharedUsers for key change event from key server")
|
||||
sentry.CaptureException(err)
|
||||
return err
|
||||
return true
|
||||
}
|
||||
// make sure we get our own key updates too!
|
||||
queryRes.UserIDsToCount[output.UserID] = 1
|
||||
|
|
@ -155,5 +152,5 @@ func (s *OutputKeyChangeEventConsumer) onCrossSigningMessage(m api.DeviceMessage
|
|||
s.notifier.OnNewKeyChange(types.StreamingToken{DeviceListPosition: posUpdate}, userID, output.UserID)
|
||||
}
|
||||
|
||||
return nil
|
||||
return true
|
||||
}
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ type OutputRoomEventConsumer struct {
|
|||
cfg *config.SyncAPI
|
||||
rsAPI api.RoomserverInternalAPI
|
||||
jetstream nats.JetStreamContext
|
||||
durable nats.SubOpt
|
||||
durable string
|
||||
topic string
|
||||
db storage.Database
|
||||
pduStream types.StreamProvider
|
||||
|
|
@ -73,65 +73,61 @@ func NewOutputRoomEventConsumer(

// Start consuming from room servers
func (s *OutputRoomEventConsumer) Start() error {
_, err := s.jetstream.Subscribe(
s.topic, s.onMessage, s.durable,
nats.DeliverAll(),
nats.ManualAck(),
return jetstream.JetStreamConsumer(
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage,
nats.DeliverAll(), nats.ManualAck(),
)
return err
}

// onMessage is called when the sync server receives a new event from the room server output log.
// It is not safe for this function to be called from multiple goroutines, or else the
// sync stream position may race and be incorrectly calculated.
func (s *OutputRoomEventConsumer) onMessage(msg *nats.Msg) {
jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool {
// Parse out the event JSON
var err error
var output api.OutputEvent
if err = json.Unmarshal(msg.Data, &output); err != nil {
// If the message was invalid, log it and move on to the next message in the stream
log.WithError(err).Errorf("roomserver output log: message parse failure")
return true
}

switch output.Type {
case api.OutputTypeNewRoomEvent:
// Ignore redaction events. We will add them to the database when they are
// validated (when we receive OutputTypeRedactedEvent)
event := output.NewRoomEvent.Event
if event.Type() == gomatrixserverlib.MRoomRedaction && event.StateKey() == nil {
// in the special case where the event redacts itself, just pass the message through because
// we will never see the other part of the pair
if event.Redacts() != event.EventID() {
return true
}
}
err = s.onNewRoomEvent(s.ctx, *output.NewRoomEvent)
case api.OutputTypeOldRoomEvent:
err = s.onOldRoomEvent(s.ctx, *output.OldRoomEvent)
case api.OutputTypeNewInviteEvent:
s.onNewInviteEvent(s.ctx, *output.NewInviteEvent)
case api.OutputTypeRetireInviteEvent:
s.onRetireInviteEvent(s.ctx, *output.RetireInviteEvent)
case api.OutputTypeNewPeek:
s.onNewPeek(s.ctx, *output.NewPeek)
case api.OutputTypeRetirePeek:
s.onRetirePeek(s.ctx, *output.RetirePeek)
case api.OutputTypeRedactedEvent:
err = s.onRedactEvent(s.ctx, *output.RedactedEvent)
default:
log.WithField("type", output.Type).Debug(
"roomserver output log: ignoring unknown output type",
)
}
if err != nil {
log.WithError(err).Error("roomserver output log: failed to process event")
return false
}

func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
// Parse out the event JSON
var err error
var output api.OutputEvent
if err = json.Unmarshal(msg.Data, &output); err != nil {
// If the message was invalid, log it and move on to the next message in the stream
log.WithError(err).Errorf("roomserver output log: message parse failure")
return true
})
}

switch output.Type {
case api.OutputTypeNewRoomEvent:
// Ignore redaction events. We will add them to the database when they are
// validated (when we receive OutputTypeRedactedEvent)
event := output.NewRoomEvent.Event
if event.Type() == gomatrixserverlib.MRoomRedaction && event.StateKey() == nil {
// in the special case where the event redacts itself, just pass the message through because
// we will never see the other part of the pair
if event.Redacts() != event.EventID() {
return true
}
}
err = s.onNewRoomEvent(s.ctx, *output.NewRoomEvent)
case api.OutputTypeOldRoomEvent:
err = s.onOldRoomEvent(s.ctx, *output.OldRoomEvent)
case api.OutputTypeNewInviteEvent:
s.onNewInviteEvent(s.ctx, *output.NewInviteEvent)
case api.OutputTypeRetireInviteEvent:
s.onRetireInviteEvent(s.ctx, *output.RetireInviteEvent)
case api.OutputTypeNewPeek:
s.onNewPeek(s.ctx, *output.NewPeek)
case api.OutputTypeRetirePeek:
s.onRetirePeek(s.ctx, *output.RetirePeek)
case api.OutputTypeRedactedEvent:
err = s.onRedactEvent(s.ctx, *output.RedactedEvent)
default:
log.WithField("type", output.Type).Debug(
"roomserver output log: ignoring unknown output type",
)
}
if err != nil {
log.WithError(err).Error("roomserver output log: failed to process event")
return false
}

return true
}

func (s *OutputRoomEventConsumer) onRedactEvent(

@ -18,8 +18,8 @@ import (
"context"
"strings"

"github.com/Shopify/sarama"
keyapi "github.com/matrix-org/dendrite/keyserver/api"
keytypes "github.com/matrix-org/dendrite/keyserver/types"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"

@ -64,8 +64,8 @@ func DeviceListCatchup(
}

// now also track users who we already share rooms with but who have updated their devices between the two tokens
offset := sarama.OffsetOldest
toOffset := sarama.OffsetNewest
offset := keytypes.OffsetOldest
toOffset := keytypes.OffsetNewest
if to > 0 && to > from {
toOffset = int64(to)
}

@ -19,7 +19,6 @@ import (

eduAPI "github.com/matrix-org/dendrite/eduserver/api"

"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/syncapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api"

@ -27,8 +26,6 @@ import (
)

type Database interface {
internal.PartitionStorer

MaxStreamPositionForPDUs(ctx context.Context) (types.StreamPosition, error)
MaxStreamPositionForReceipts(ctx context.Context) (types.StreamPosition, error)
MaxStreamPositionForInvites(ctx context.Context) (types.StreamPosition, error)

@ -48,7 +48,7 @@ func AddPublicRoutes(
federation *gomatrixserverlib.FederationClient,
cfg *config.SyncAPI,
) {
js, consumer, _ := jetstream.Prepare(&cfg.Matrix.JetStream)
js := jetstream.Prepare(&cfg.Matrix.JetStream)

syncDB, err := storage.NewSyncServerDatasource(&cfg.Database)
if err != nil {

@ -65,8 +65,8 @@ func AddPublicRoutes(
requestPool := sync.NewRequestPool(syncDB, cfg, userAPI, keyAPI, rsAPI, streams, notifier)

keyChangeConsumer := consumers.NewOutputKeyChangeEventConsumer(
process, cfg.Matrix.ServerName, cfg.Matrix.JetStream.TopicFor(jetstream.OutputKeyChangeEvent),
consumer, keyAPI, rsAPI, syncDB, notifier,
process, cfg, cfg.Matrix.JetStream.TopicFor(jetstream.OutputKeyChangeEvent),
js, keyAPI, rsAPI, syncDB, notifier,
streams.DeviceListStreamProvider,
)
if err = keyChangeConsumer.Start(); err != nil {