Merge branch 'main' into s7evink/syncresponse

This commit is contained in:
Neil Alexander 2022-10-05 13:44:10 +01:00 committed by GitHub
commit 14506388e3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
18 changed files with 123 additions and 48 deletions

View file

@ -101,6 +101,11 @@ func (s *OutputRoomEventConsumer) onMessage(
log.WithField("appservice", state.ID).Tracef("Appservice worker received %d message(s) from roomserver", len(msgs)) log.WithField("appservice", state.ID).Tracef("Appservice worker received %d message(s) from roomserver", len(msgs))
events := make([]*gomatrixserverlib.HeaderedEvent, 0, len(msgs)) events := make([]*gomatrixserverlib.HeaderedEvent, 0, len(msgs))
for _, msg := range msgs { for _, msg := range msgs {
// Only handle events we care about
receivedType := api.OutputType(msg.Header.Get(jetstream.RoomEventType))
if receivedType != api.OutputTypeNewRoomEvent && receivedType != api.OutputTypeNewInviteEvent {
continue
}
// Parse out the event JSON // Parse out the event JSON
var output api.OutputEvent var output api.OutputEvent
if err := json.Unmarshal(msg.Data, &output); err != nil { if err := json.Unmarshal(msg.Data, &output); err != nil {

View file

@ -79,6 +79,13 @@ func (s *OutputRoomEventConsumer) Start() error {
// realises that it cannot update the room state using the deltas. // realises that it cannot update the room state using the deltas.
func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called msg := msgs[0] // Guaranteed to exist if onMessage is called
receivedType := api.OutputType(msg.Header.Get(jetstream.RoomEventType))
// Only handle events we care about
if receivedType != api.OutputTypeNewRoomEvent && receivedType != api.OutputTypeNewInboundPeek {
return true
}
// Parse out the event JSON // Parse out the event JSON
var output api.OutputEvent var output api.OutputEvent
if err := json.Unmarshal(msg.Data, &output); err != nil { if err := json.Unmarshal(msg.Data, &output); err != nil {

View file

@ -21,6 +21,29 @@ import (
"strings" "strings"
"github.com/blevesearch/bleve/v2" "github.com/blevesearch/bleve/v2"
// side effect imports to allow all possible languages
_ "github.com/blevesearch/bleve/v2/analysis/lang/ar"
_ "github.com/blevesearch/bleve/v2/analysis/lang/cjk"
_ "github.com/blevesearch/bleve/v2/analysis/lang/ckb"
_ "github.com/blevesearch/bleve/v2/analysis/lang/da"
_ "github.com/blevesearch/bleve/v2/analysis/lang/de"
_ "github.com/blevesearch/bleve/v2/analysis/lang/en"
_ "github.com/blevesearch/bleve/v2/analysis/lang/es"
_ "github.com/blevesearch/bleve/v2/analysis/lang/fa"
_ "github.com/blevesearch/bleve/v2/analysis/lang/fi"
_ "github.com/blevesearch/bleve/v2/analysis/lang/fr"
_ "github.com/blevesearch/bleve/v2/analysis/lang/hi"
_ "github.com/blevesearch/bleve/v2/analysis/lang/hr"
_ "github.com/blevesearch/bleve/v2/analysis/lang/hu"
_ "github.com/blevesearch/bleve/v2/analysis/lang/it"
_ "github.com/blevesearch/bleve/v2/analysis/lang/nl"
_ "github.com/blevesearch/bleve/v2/analysis/lang/no"
_ "github.com/blevesearch/bleve/v2/analysis/lang/pt"
_ "github.com/blevesearch/bleve/v2/analysis/lang/ro"
_ "github.com/blevesearch/bleve/v2/analysis/lang/ru"
_ "github.com/blevesearch/bleve/v2/analysis/lang/sv"
_ "github.com/blevesearch/bleve/v2/analysis/lang/tr"
"github.com/blevesearch/bleve/v2/mapping" "github.com/blevesearch/bleve/v2/mapping"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"

View file

@ -2,6 +2,7 @@ package sqlutil
import ( import (
"database/sql" "database/sql"
"flag"
"fmt" "fmt"
"regexp" "regexp"
@ -9,6 +10,8 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
var skipSanityChecks = flag.Bool("skip-db-sanity", false, "Ignore sanity checks on the database connections (NOT RECOMMENDED!)")
// Open opens a database specified by its database driver name and a driver-specific data source name, // Open opens a database specified by its database driver name and a driver-specific data source name,
// usually consisting of at least a database name and connection information. Includes tracing driver // usually consisting of at least a database name and connection information. Includes tracing driver
// if DENDRITE_TRACE_SQL=1 // if DENDRITE_TRACE_SQL=1
@ -37,15 +40,39 @@ func Open(dbProperties *config.DatabaseOptions, writer Writer) (*sql.DB, error)
return nil, err return nil, err
} }
if driverName != "sqlite3" { if driverName != "sqlite3" {
logrus.WithFields(logrus.Fields{ logger := logrus.WithFields(logrus.Fields{
"MaxOpenConns": dbProperties.MaxOpenConns(), "max_open_conns": dbProperties.MaxOpenConns(),
"MaxIdleConns": dbProperties.MaxIdleConns(), "max_idle_conns": dbProperties.MaxIdleConns(),
"ConnMaxLifetime": dbProperties.ConnMaxLifetime(), "conn_max_lifetime": dbProperties.ConnMaxLifetime(),
"dataSourceName": regexp.MustCompile(`://[^@]*@`).ReplaceAllLiteralString(dsn, "://"), "data_source_name": regexp.MustCompile(`://[^@]*@`).ReplaceAllLiteralString(dsn, "://"),
}).Debug("Setting DB connection limits") })
logger.Debug("Setting DB connection limits")
db.SetMaxOpenConns(dbProperties.MaxOpenConns()) db.SetMaxOpenConns(dbProperties.MaxOpenConns())
db.SetMaxIdleConns(dbProperties.MaxIdleConns()) db.SetMaxIdleConns(dbProperties.MaxIdleConns())
db.SetConnMaxLifetime(dbProperties.ConnMaxLifetime()) db.SetConnMaxLifetime(dbProperties.ConnMaxLifetime())
if !*skipSanityChecks {
if dbProperties.MaxOpenConns() == 0 {
logrus.Warnf("WARNING: Configuring 'max_open_conns' to be unlimited is not recommended. This can result in bad performance or deadlocks.")
}
switch driverName {
case "postgres":
// Perform a quick sanity check if possible that we aren't trying to use more database
// connections than PostgreSQL is willing to give us.
var max, reserved int
if err := db.QueryRow("SELECT setting::integer FROM pg_settings WHERE name='max_connections';").Scan(&max); err != nil {
return nil, fmt.Errorf("failed to find maximum connections: %w", err)
}
if err := db.QueryRow("SELECT setting::integer FROM pg_settings WHERE name='superuser_reserved_connections';").Scan(&reserved); err != nil {
return nil, fmt.Errorf("failed to find reserved connections: %w", err)
}
if configured, allowed := dbProperties.MaxOpenConns(), max-reserved; configured > allowed {
logrus.Errorf("ERROR: The configured 'max_open_conns' is greater than the %d non-superuser connections that PostgreSQL is configured to allow. This can result in bad performance or deadlocks. Please pay close attention to your configured database connection counts. If you REALLY know what you are doing and want to override this error, pass the --skip-db-sanity option to Dendrite.", allowed)
return nil, fmt.Errorf("database sanity checks failed")
}
}
}
} }
return db, nil return db, nil
} }

View file

@ -424,7 +424,7 @@ func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerNam
"succeeded": successCount, "succeeded": successCount,
"failed": len(userIDs) - successCount, "failed": len(userIDs) - successCount,
"wait_time": waitTime, "wait_time": waitTime,
}).Warn("Failed to query device keys for some users") }).Debug("Failed to query device keys for some users")
} }
return waitTime, !allUsersSucceeded return waitTime, !allUsersSucceeded
} }

View file

@ -278,6 +278,7 @@ type QuerySharedUsersRequest struct {
OtherUserIDs []string OtherUserIDs []string
ExcludeRoomIDs []string ExcludeRoomIDs []string
IncludeRoomIDs []string IncludeRoomIDs []string
LocalOnly bool
} }
type QuerySharedUsersResponse struct { type QuerySharedUsersResponse struct {

View file

@ -799,7 +799,7 @@ func (r *Queryer) QuerySharedUsers(ctx context.Context, req *api.QuerySharedUser
} }
roomIDs = roomIDs[:j] roomIDs = roomIDs[:j]
users, err := r.DB.JoinedUsersSetInRooms(ctx, roomIDs, req.OtherUserIDs) users, err := r.DB.JoinedUsersSetInRooms(ctx, roomIDs, req.OtherUserIDs, req.LocalOnly)
if err != nil { if err != nil {
return err return err
} }

View file

@ -17,12 +17,13 @@ package producers
import ( import (
"encoding/json" "encoding/json"
"github.com/matrix-org/dendrite/roomserver/acls"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/nats-io/nats.go" "github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/matrix-org/dendrite/roomserver/acls"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/jetstream"
) )
var keyContentFields = map[string]string{ var keyContentFields = map[string]string{
@ -40,10 +41,8 @@ type RoomEventProducer struct {
func (r *RoomEventProducer) ProduceRoomEvents(roomID string, updates []api.OutputEvent) error { func (r *RoomEventProducer) ProduceRoomEvents(roomID string, updates []api.OutputEvent) error {
var err error var err error
for _, update := range updates { for _, update := range updates {
msg := &nats.Msg{ msg := nats.NewMsg(r.Topic)
Subject: r.Topic, msg.Header.Set(jetstream.RoomEventType, string(update.Type))
Header: nats.Header{},
}
msg.Header.Set(jetstream.RoomID, roomID) msg.Header.Set(jetstream.RoomID, roomID)
msg.Data, err = json.Marshal(update) msg.Data, err = json.Marshal(update)
if err != nil { if err != nil {

View file

@ -157,7 +157,7 @@ type Database interface {
// If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned. // If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned.
GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error) GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error)
// JoinedUsersSetInRooms returns how many times each of the given users appears across the given rooms. // JoinedUsersSetInRooms returns how many times each of the given users appears across the given rooms.
JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs []string) (map[string]int, error) JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs []string, localOnly bool) (map[string]int, error)
// GetLocalServerInRoom returns true if we think we're in a given room or false otherwise. // GetLocalServerInRoom returns true if we think we're in a given room or false otherwise.
GetLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) GetLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error)
// GetServerInRoom returns true if we think a server is in a given room or false otherwise. // GetServerInRoom returns true if we think a server is in a given room or false otherwise.

View file

@ -68,14 +68,18 @@ CREATE TABLE IF NOT EXISTS roomserver_membership (
var selectJoinedUsersSetForRoomsAndUserSQL = "" + var selectJoinedUsersSetForRoomsAndUserSQL = "" +
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" + "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
" WHERE room_nid = ANY($1) AND target_nid = ANY($2) AND" + " WHERE (target_local OR $1 = false)" +
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" + " AND room_nid = ANY($2) AND target_nid = ANY($3)" +
" AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) +
" AND forgotten = false" +
" GROUP BY target_nid" " GROUP BY target_nid"
var selectJoinedUsersSetForRoomsSQL = "" + var selectJoinedUsersSetForRoomsSQL = "" +
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" + "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
" WHERE room_nid = ANY($1) AND" + " WHERE (target_local OR $1 = false) " +
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" + " AND room_nid = ANY($2)" +
" AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) +
" AND forgotten = false" +
" GROUP BY target_nid" " GROUP BY target_nid"
// Insert a row in to membership table so that it can be locked by the // Insert a row in to membership table so that it can be locked by the
@ -334,6 +338,7 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
roomNIDs []types.RoomNID, roomNIDs []types.RoomNID,
userNIDs []types.EventStateKeyNID, userNIDs []types.EventStateKeyNID,
localOnly bool,
) (map[types.EventStateKeyNID]int, error) { ) (map[types.EventStateKeyNID]int, error) {
var ( var (
rows *sql.Rows rows *sql.Rows
@ -342,9 +347,9 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms(
stmt := sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsStmt) stmt := sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsStmt)
if len(userNIDs) > 0 { if len(userNIDs) > 0 {
stmt = sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsAndUserStmt) stmt = sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsAndUserStmt)
rows, err = stmt.QueryContext(ctx, pq.Array(roomNIDs), pq.Array(userNIDs)) rows, err = stmt.QueryContext(ctx, localOnly, pq.Array(roomNIDs), pq.Array(userNIDs))
} else { } else {
rows, err = stmt.QueryContext(ctx, pq.Array(roomNIDs)) rows, err = stmt.QueryContext(ctx, localOnly, pq.Array(roomNIDs))
} }
if err != nil { if err != nil {

View file

@ -1280,7 +1280,7 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu
} }
// JoinedUsersSetInRooms returns a map of how many times the given users appear in the specified rooms. // JoinedUsersSetInRooms returns a map of how many times the given users appear in the specified rooms.
func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs []string) (map[string]int, error) { func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs []string, localOnly bool) (map[string]int, error) {
roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, nil, roomIDs) roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, nil, roomIDs)
if err != nil { if err != nil {
return nil, err return nil, err
@ -1295,7 +1295,7 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs [
userNIDs = append(userNIDs, nid) userNIDs = append(userNIDs, nid)
nidToUserID[nid] = id nidToUserID[nid] = id
} }
userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, nil, roomNIDs, userNIDs) userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, nil, roomNIDs, userNIDs, localOnly)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -44,14 +44,18 @@ const membershipSchema = `
var selectJoinedUsersSetForRoomsAndUserSQL = "" + var selectJoinedUsersSetForRoomsAndUserSQL = "" +
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" + "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
" WHERE room_nid IN ($1) AND target_nid IN ($2) AND" + " WHERE (target_local OR $1 = false)" +
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" + " AND room_nid IN ($2) AND target_nid IN ($3)" +
" AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) +
" AND forgotten = false" +
" GROUP BY target_nid" " GROUP BY target_nid"
var selectJoinedUsersSetForRoomsSQL = "" + var selectJoinedUsersSetForRoomsSQL = "" +
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" + "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
" WHERE room_nid IN ($1) AND " + " WHERE (target_local OR $1 = false)" +
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" + " AND room_nid IN ($2)" +
" AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) +
" AND forgotten = false" +
" GROUP BY target_nid" " GROUP BY target_nid"
// Insert a row in to membership table so that it can be locked by the // Insert a row in to membership table so that it can be locked by the
@ -305,8 +309,9 @@ func (s *membershipStatements) SelectRoomsWithMembership(
return roomNIDs, nil return roomNIDs, nil
} }
func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, userNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]int, error) { func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, userNIDs []types.EventStateKeyNID, localOnly bool) (map[types.EventStateKeyNID]int, error) {
params := make([]interface{}, 0, len(roomNIDs)+len(userNIDs)) params := make([]interface{}, 0, 1+len(roomNIDs)+len(userNIDs))
params = append(params, localOnly)
for _, v := range roomNIDs { for _, v := range roomNIDs {
params = append(params, v) params = append(params, v)
} }
@ -314,10 +319,10 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context,
params = append(params, v) params = append(params, v)
} }
query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1) query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($2)", sqlutil.QueryVariadicOffset(len(roomNIDs), 1), 1)
if len(userNIDs) > 0 { if len(userNIDs) > 0 {
query = strings.Replace(selectJoinedUsersSetForRoomsAndUserSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1) query = strings.Replace(selectJoinedUsersSetForRoomsAndUserSQL, "($2)", sqlutil.QueryVariadicOffset(len(roomNIDs), 1), 1)
query = strings.Replace(query, "($2)", sqlutil.QueryVariadicOffset(len(userNIDs), len(roomNIDs)), 1) query = strings.Replace(query, "($3)", sqlutil.QueryVariadicOffset(len(userNIDs), len(roomNIDs)+1), 1)
} }
var rows *sql.Rows var rows *sql.Rows
var err error var err error

View file

@ -137,7 +137,7 @@ type Membership interface {
UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID, forgotten bool) (bool, error) UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID, forgotten bool) (bool, error)
SelectRoomsWithMembership(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error) SelectRoomsWithMembership(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error)
// SelectJoinedUsersSetForRooms returns how many times each of the given users appears across the given rooms. // SelectJoinedUsersSetForRooms returns how many times each of the given users appears across the given rooms.
SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, userNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]int, error) SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, userNIDs []types.EventStateKeyNID, localOnly bool) (map[types.EventStateKeyNID]int, error)
SelectKnownUsers(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) SelectKnownUsers(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error)
UpdateForgetMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool) error UpdateForgetMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool) error
SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error) SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error)

View file

@ -79,7 +79,7 @@ func TestMembershipTable(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, inRoom) assert.True(t, inRoom)
userJoinedToRooms, err := tab.SelectJoinedUsersSetForRooms(ctx, nil, []types.RoomNID{1}, userNIDs) userJoinedToRooms, err := tab.SelectJoinedUsersSetForRooms(ctx, nil, []types.RoomNID{1}, userNIDs, false)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, 1, len(userJoinedToRooms)) assert.Equal(t, 1, len(userJoinedToRooms))

View file

@ -9,9 +9,10 @@ import (
) )
const ( const (
UserID = "user_id" UserID = "user_id"
RoomID = "room_id" RoomID = "room_id"
EventID = "event_id" EventID = "event_id"
RoomEventType = "output_room_event_type"
) )
var ( var (

View file

@ -111,7 +111,8 @@ func (s *OutputKeyChangeEventConsumer) onDeviceKeyMessage(m api.DeviceMessage, d
// work out who we need to notify about the new key // work out who we need to notify about the new key
var queryRes roomserverAPI.QuerySharedUsersResponse var queryRes roomserverAPI.QuerySharedUsersResponse
err := s.rsAPI.QuerySharedUsers(s.ctx, &roomserverAPI.QuerySharedUsersRequest{ err := s.rsAPI.QuerySharedUsers(s.ctx, &roomserverAPI.QuerySharedUsersRequest{
UserID: output.UserID, UserID: output.UserID,
LocalOnly: true,
}, &queryRes) }, &queryRes)
if err != nil { if err != nil {
logrus.WithError(err).Error("syncapi: failed to QuerySharedUsers for key change event from key server") logrus.WithError(err).Error("syncapi: failed to QuerySharedUsers for key change event from key server")
@ -135,7 +136,8 @@ func (s *OutputKeyChangeEventConsumer) onCrossSigningMessage(m api.DeviceMessage
// work out who we need to notify about the new key // work out who we need to notify about the new key
var queryRes roomserverAPI.QuerySharedUsersResponse var queryRes roomserverAPI.QuerySharedUsersResponse
err := s.rsAPI.QuerySharedUsers(s.ctx, &roomserverAPI.QuerySharedUsersRequest{ err := s.rsAPI.QuerySharedUsers(s.ctx, &roomserverAPI.QuerySharedUsersRequest{
UserID: output.UserID, UserID: output.UserID,
LocalOnly: true,
}, &queryRes) }, &queryRes)
if err != nil { if err != nil {
logrus.WithError(err).Error("syncapi: failed to QuerySharedUsers for key change event from key server") logrus.WithError(err).Error("syncapi: failed to QuerySharedUsers for key change event from key server")

View file

@ -4,10 +4,11 @@ import (
"encoding/json" "encoding/json"
"testing" "testing"
"github.com/nats-io/nats.go"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/jetstream"
"github.com/nats-io/nats.go"
) )
func MustPublishMsgs(t *testing.T, jsctx nats.JetStreamContext, msgs ...*nats.Msg) { func MustPublishMsgs(t *testing.T, jsctx nats.JetStreamContext, msgs ...*nats.Msg) {
@ -21,10 +22,8 @@ func MustPublishMsgs(t *testing.T, jsctx nats.JetStreamContext, msgs ...*nats.Ms
func NewOutputEventMsg(t *testing.T, base *base.BaseDendrite, roomID string, update api.OutputEvent) *nats.Msg { func NewOutputEventMsg(t *testing.T, base *base.BaseDendrite, roomID string, update api.OutputEvent) *nats.Msg {
t.Helper() t.Helper()
msg := &nats.Msg{ msg := nats.NewMsg(base.Cfg.Global.JetStream.Prefixed(jetstream.OutputRoomEvent))
Subject: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputRoomEvent), msg.Header.Set(jetstream.RoomEventType, string(update.Type))
Header: nats.Header{},
}
msg.Header.Set(jetstream.RoomID, roomID) msg.Header.Set(jetstream.RoomID, roomID)
var err error var err error
msg.Data, err = json.Marshal(update) msg.Data, err = json.Marshal(update)

View file

@ -72,15 +72,16 @@ func (s *OutputRoomEventConsumer) Start() error {
func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called msg := msgs[0] // Guaranteed to exist if onMessage is called
// Only handle events we care about
if rsapi.OutputType(msg.Header.Get(jetstream.RoomEventType)) != rsapi.OutputTypeNewRoomEvent {
return true
}
var output rsapi.OutputEvent var output rsapi.OutputEvent
if err := json.Unmarshal(msg.Data, &output); err != nil { if err := json.Unmarshal(msg.Data, &output); err != nil {
// If the message was invalid, log it and move on to the next message in the stream // If the message was invalid, log it and move on to the next message in the stream
log.WithError(err).Errorf("roomserver output log: message parse failure") log.WithError(err).Errorf("roomserver output log: message parse failure")
return true return true
} }
if output.Type != rsapi.OutputTypeNewRoomEvent {
return true
}
event := output.NewRoomEvent.Event event := output.NewRoomEvent.Event
if event == nil { if event == nil {
log.Errorf("userapi consumer: expected event") log.Errorf("userapi consumer: expected event")