mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-07 06:03:09 -06:00
Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/fts
This commit is contained in:
commit
49c5ad62e9
|
|
@ -17,6 +17,7 @@ package producers
|
|||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
|
|
@ -83,7 +84,7 @@ func (p *SyncAPIProducer) SendReceipt(
|
|||
m.Header.Set(jetstream.RoomID, roomID)
|
||||
m.Header.Set(jetstream.EventID, eventID)
|
||||
m.Header.Set("type", receiptType)
|
||||
m.Header.Set("timestamp", strconv.Itoa(int(timestamp)))
|
||||
m.Header.Set("timestamp", fmt.Sprintf("%d", timestamp))
|
||||
|
||||
log.WithFields(log.Fields{}).Tracef("Producing to topic '%s'", p.TopicReceiptEvent)
|
||||
_, err := p.JetStream.PublishMsg(m, nats.Context(ctx))
|
||||
|
|
|
|||
|
|
@ -90,7 +90,7 @@ func (t *OutputReceiptConsumer) onMessage(ctx context.Context, msg *nats.Msg) bo
|
|||
return true
|
||||
}
|
||||
|
||||
timestamp, err := strconv.Atoi(msg.Header.Get("timestamp"))
|
||||
timestamp, err := strconv.ParseUint(msg.Header.Get("timestamp"), 10, 64)
|
||||
if err != nil {
|
||||
// If the message was invalid, log it and move on to the next message in the stream
|
||||
log.WithError(err).Errorf("EDU output log: message parse failure")
|
||||
|
|
|
|||
|
|
@ -53,7 +53,7 @@ func (p *SyncAPIProducer) SendReceipt(
|
|||
m.Header.Set(jetstream.RoomID, roomID)
|
||||
m.Header.Set(jetstream.EventID, eventID)
|
||||
m.Header.Set("type", receiptType)
|
||||
m.Header.Set("timestamp", strconv.Itoa(int(timestamp)))
|
||||
m.Header.Set("timestamp", fmt.Sprintf("%d", timestamp))
|
||||
|
||||
log.WithFields(log.Fields{}).Tracef("Producing to topic '%s'", p.TopicReceiptEvent)
|
||||
_, err := p.JetStream.PublishMsg(m, nats.Context(ctx))
|
||||
|
|
|
|||
|
|
@ -1,36 +1,26 @@
|
|||
package storage
|
||||
package storage_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"os"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/matrix-org/dendrite/keyserver/api"
|
||||
"github.com/matrix-org/dendrite/keyserver/storage"
|
||||
"github.com/matrix-org/dendrite/keyserver/types"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/test"
|
||||
"github.com/matrix-org/dendrite/test/testrig"
|
||||
)
|
||||
|
||||
var ctx = context.Background()
|
||||
|
||||
func MustCreateDatabase(t *testing.T) (Database, func()) {
|
||||
tmpfile, err := ioutil.TempFile("", "keyserver_storage_test")
|
||||
func MustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
|
||||
base, close := testrig.CreateBaseDendrite(t, dbType)
|
||||
db, err := storage.NewDatabase(base, &base.Cfg.KeyServer.Database)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
t.Logf("Database %s", tmpfile.Name())
|
||||
db, err := NewDatabase(nil, &config.DatabaseOptions{
|
||||
ConnectionString: config.DataSource(fmt.Sprintf("file://%s", tmpfile.Name())),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to NewDatabase: %s", err)
|
||||
}
|
||||
return db, func() {
|
||||
os.Remove(tmpfile.Name())
|
||||
t.Fatalf("failed to create new database: %v", err)
|
||||
}
|
||||
return db, close
|
||||
}
|
||||
|
||||
func MustNotError(t *testing.T, err error) {
|
||||
|
|
@ -42,151 +32,159 @@ func MustNotError(t *testing.T, err error) {
|
|||
}
|
||||
|
||||
func TestKeyChanges(t *testing.T) {
|
||||
db, clean := MustCreateDatabase(t)
|
||||
defer clean()
|
||||
_, err := db.StoreKeyChange(ctx, "@alice:localhost")
|
||||
MustNotError(t, err)
|
||||
deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost")
|
||||
MustNotError(t, err)
|
||||
deviceChangeIDC, err := db.StoreKeyChange(ctx, "@charlie:localhost")
|
||||
MustNotError(t, err)
|
||||
userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDB, types.OffsetNewest)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to KeyChanges: %s", err)
|
||||
}
|
||||
if latest != deviceChangeIDC {
|
||||
t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDC)
|
||||
}
|
||||
if !reflect.DeepEqual(userIDs, []string{"@charlie:localhost"}) {
|
||||
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
|
||||
}
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, clean := MustCreateDatabase(t, dbType)
|
||||
defer clean()
|
||||
_, err := db.StoreKeyChange(ctx, "@alice:localhost")
|
||||
MustNotError(t, err)
|
||||
deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost")
|
||||
MustNotError(t, err)
|
||||
deviceChangeIDC, err := db.StoreKeyChange(ctx, "@charlie:localhost")
|
||||
MustNotError(t, err)
|
||||
userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDB, types.OffsetNewest)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to KeyChanges: %s", err)
|
||||
}
|
||||
if latest != deviceChangeIDC {
|
||||
t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDC)
|
||||
}
|
||||
if !reflect.DeepEqual(userIDs, []string{"@charlie:localhost"}) {
|
||||
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestKeyChangesNoDupes(t *testing.T) {
|
||||
db, clean := MustCreateDatabase(t)
|
||||
defer clean()
|
||||
deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost")
|
||||
MustNotError(t, err)
|
||||
deviceChangeIDB, err := db.StoreKeyChange(ctx, "@alice:localhost")
|
||||
MustNotError(t, err)
|
||||
if deviceChangeIDA == deviceChangeIDB {
|
||||
t.Fatalf("Expected change ID to be different even when inserting key change for the same user, got %d for both changes", deviceChangeIDA)
|
||||
}
|
||||
deviceChangeID, err := db.StoreKeyChange(ctx, "@alice:localhost")
|
||||
MustNotError(t, err)
|
||||
userIDs, latest, err := db.KeyChanges(ctx, 0, types.OffsetNewest)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to KeyChanges: %s", err)
|
||||
}
|
||||
if latest != deviceChangeID {
|
||||
t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeID)
|
||||
}
|
||||
if !reflect.DeepEqual(userIDs, []string{"@alice:localhost"}) {
|
||||
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
|
||||
}
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, clean := MustCreateDatabase(t, dbType)
|
||||
defer clean()
|
||||
deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost")
|
||||
MustNotError(t, err)
|
||||
deviceChangeIDB, err := db.StoreKeyChange(ctx, "@alice:localhost")
|
||||
MustNotError(t, err)
|
||||
if deviceChangeIDA == deviceChangeIDB {
|
||||
t.Fatalf("Expected change ID to be different even when inserting key change for the same user, got %d for both changes", deviceChangeIDA)
|
||||
}
|
||||
deviceChangeID, err := db.StoreKeyChange(ctx, "@alice:localhost")
|
||||
MustNotError(t, err)
|
||||
userIDs, latest, err := db.KeyChanges(ctx, 0, types.OffsetNewest)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to KeyChanges: %s", err)
|
||||
}
|
||||
if latest != deviceChangeID {
|
||||
t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeID)
|
||||
}
|
||||
if !reflect.DeepEqual(userIDs, []string{"@alice:localhost"}) {
|
||||
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestKeyChangesUpperLimit(t *testing.T) {
|
||||
db, clean := MustCreateDatabase(t)
|
||||
defer clean()
|
||||
deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost")
|
||||
MustNotError(t, err)
|
||||
deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost")
|
||||
MustNotError(t, err)
|
||||
_, err = db.StoreKeyChange(ctx, "@charlie:localhost")
|
||||
MustNotError(t, err)
|
||||
userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDA, deviceChangeIDB)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to KeyChanges: %s", err)
|
||||
}
|
||||
if latest != deviceChangeIDB {
|
||||
t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDB)
|
||||
}
|
||||
if !reflect.DeepEqual(userIDs, []string{"@bob:localhost"}) {
|
||||
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
|
||||
}
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, clean := MustCreateDatabase(t, dbType)
|
||||
defer clean()
|
||||
deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost")
|
||||
MustNotError(t, err)
|
||||
deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost")
|
||||
MustNotError(t, err)
|
||||
_, err = db.StoreKeyChange(ctx, "@charlie:localhost")
|
||||
MustNotError(t, err)
|
||||
userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDA, deviceChangeIDB)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to KeyChanges: %s", err)
|
||||
}
|
||||
if latest != deviceChangeIDB {
|
||||
t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDB)
|
||||
}
|
||||
if !reflect.DeepEqual(userIDs, []string{"@bob:localhost"}) {
|
||||
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// The purpose of this test is to make sure that the storage layer is generating sequential stream IDs per user,
|
||||
// and that they are returned correctly when querying for device keys.
|
||||
func TestDeviceKeysStreamIDGeneration(t *testing.T) {
|
||||
var err error
|
||||
db, clean := MustCreateDatabase(t)
|
||||
defer clean()
|
||||
alice := "@alice:TestDeviceKeysStreamIDGeneration"
|
||||
bob := "@bob:TestDeviceKeysStreamIDGeneration"
|
||||
msgs := []api.DeviceMessage{
|
||||
{
|
||||
Type: api.TypeDeviceKeyUpdate,
|
||||
DeviceKeys: &api.DeviceKeys{
|
||||
DeviceID: "AAA",
|
||||
UserID: alice,
|
||||
KeyJSON: []byte(`{"key":"v1"}`),
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, clean := MustCreateDatabase(t, dbType)
|
||||
defer clean()
|
||||
alice := "@alice:TestDeviceKeysStreamIDGeneration"
|
||||
bob := "@bob:TestDeviceKeysStreamIDGeneration"
|
||||
msgs := []api.DeviceMessage{
|
||||
{
|
||||
Type: api.TypeDeviceKeyUpdate,
|
||||
DeviceKeys: &api.DeviceKeys{
|
||||
DeviceID: "AAA",
|
||||
UserID: alice,
|
||||
KeyJSON: []byte(`{"key":"v1"}`),
|
||||
},
|
||||
// StreamID: 1
|
||||
},
|
||||
// StreamID: 1
|
||||
},
|
||||
{
|
||||
Type: api.TypeDeviceKeyUpdate,
|
||||
DeviceKeys: &api.DeviceKeys{
|
||||
DeviceID: "AAA",
|
||||
UserID: bob,
|
||||
KeyJSON: []byte(`{"key":"v1"}`),
|
||||
{
|
||||
Type: api.TypeDeviceKeyUpdate,
|
||||
DeviceKeys: &api.DeviceKeys{
|
||||
DeviceID: "AAA",
|
||||
UserID: bob,
|
||||
KeyJSON: []byte(`{"key":"v1"}`),
|
||||
},
|
||||
// StreamID: 1 as this is a different user
|
||||
},
|
||||
// StreamID: 1 as this is a different user
|
||||
},
|
||||
{
|
||||
Type: api.TypeDeviceKeyUpdate,
|
||||
DeviceKeys: &api.DeviceKeys{
|
||||
DeviceID: "another_device",
|
||||
UserID: alice,
|
||||
KeyJSON: []byte(`{"key":"v1"}`),
|
||||
{
|
||||
Type: api.TypeDeviceKeyUpdate,
|
||||
DeviceKeys: &api.DeviceKeys{
|
||||
DeviceID: "another_device",
|
||||
UserID: alice,
|
||||
KeyJSON: []byte(`{"key":"v1"}`),
|
||||
},
|
||||
// StreamID: 2 as this is a 2nd device key
|
||||
},
|
||||
// StreamID: 2 as this is a 2nd device key
|
||||
},
|
||||
}
|
||||
MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs))
|
||||
if msgs[0].StreamID != 1 {
|
||||
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 but got %d", msgs[0].StreamID)
|
||||
}
|
||||
if msgs[1].StreamID != 1 {
|
||||
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 (different user) but got %d", msgs[1].StreamID)
|
||||
}
|
||||
if msgs[2].StreamID != 2 {
|
||||
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=2 (another device) but got %d", msgs[2].StreamID)
|
||||
}
|
||||
|
||||
// updating a device sets the next stream ID for that user
|
||||
msgs = []api.DeviceMessage{
|
||||
{
|
||||
Type: api.TypeDeviceKeyUpdate,
|
||||
DeviceKeys: &api.DeviceKeys{
|
||||
DeviceID: "AAA",
|
||||
UserID: alice,
|
||||
KeyJSON: []byte(`{"key":"v2"}`),
|
||||
},
|
||||
// StreamID: 3
|
||||
},
|
||||
}
|
||||
MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs))
|
||||
if msgs[0].StreamID != 3 {
|
||||
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=3 (new key same device) but got %d", msgs[0].StreamID)
|
||||
}
|
||||
|
||||
// Querying for device keys returns the latest stream IDs
|
||||
msgs, err = db.DeviceKeysForUser(ctx, alice, []string{"AAA", "another_device"}, false)
|
||||
if err != nil {
|
||||
t.Fatalf("DeviceKeysForUser returned error: %s", err)
|
||||
}
|
||||
wantStreamIDs := map[string]int64{
|
||||
"AAA": 3,
|
||||
"another_device": 2,
|
||||
}
|
||||
if len(msgs) != len(wantStreamIDs) {
|
||||
t.Fatalf("DeviceKeysForUser: wrong number of devices, got %d want %d", len(msgs), len(wantStreamIDs))
|
||||
}
|
||||
for _, m := range msgs {
|
||||
if m.StreamID != wantStreamIDs[m.DeviceID] {
|
||||
t.Errorf("DeviceKeysForUser: wrong returned stream ID for key, got %d want %d", m.StreamID, wantStreamIDs[m.DeviceID])
|
||||
}
|
||||
}
|
||||
MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs))
|
||||
if msgs[0].StreamID != 1 {
|
||||
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 but got %d", msgs[0].StreamID)
|
||||
}
|
||||
if msgs[1].StreamID != 1 {
|
||||
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 (different user) but got %d", msgs[1].StreamID)
|
||||
}
|
||||
if msgs[2].StreamID != 2 {
|
||||
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=2 (another device) but got %d", msgs[2].StreamID)
|
||||
}
|
||||
|
||||
// updating a device sets the next stream ID for that user
|
||||
msgs = []api.DeviceMessage{
|
||||
{
|
||||
Type: api.TypeDeviceKeyUpdate,
|
||||
DeviceKeys: &api.DeviceKeys{
|
||||
DeviceID: "AAA",
|
||||
UserID: alice,
|
||||
KeyJSON: []byte(`{"key":"v2"}`),
|
||||
},
|
||||
// StreamID: 3
|
||||
},
|
||||
}
|
||||
MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs))
|
||||
if msgs[0].StreamID != 3 {
|
||||
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=3 (new key same device) but got %d", msgs[0].StreamID)
|
||||
}
|
||||
|
||||
// Querying for device keys returns the latest stream IDs
|
||||
msgs, err = db.DeviceKeysForUser(ctx, alice, []string{"AAA", "another_device"}, false)
|
||||
if err != nil {
|
||||
t.Fatalf("DeviceKeysForUser returned error: %s", err)
|
||||
}
|
||||
wantStreamIDs := map[string]int64{
|
||||
"AAA": 3,
|
||||
"another_device": 2,
|
||||
}
|
||||
if len(msgs) != len(wantStreamIDs) {
|
||||
t.Fatalf("DeviceKeysForUser: wrong number of devices, got %d want %d", len(msgs), len(wantStreamIDs))
|
||||
}
|
||||
for _, m := range msgs {
|
||||
if m.StreamID != wantStreamIDs[m.DeviceID] {
|
||||
t.Errorf("DeviceKeysForUser: wrong returned stream ID for key, got %d want %d", m.StreamID, wantStreamIDs[m.DeviceID])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/roomserver/internal/query"
|
||||
"github.com/matrix-org/dendrite/roomserver/producers"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage"
|
||||
"github.com/matrix-org/dendrite/setup/base"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/setup/jetstream"
|
||||
"github.com/matrix-org/dendrite/setup/process"
|
||||
|
|
@ -39,6 +40,7 @@ type RoomserverInternalAPI struct {
|
|||
*perform.Upgrader
|
||||
*perform.Admin
|
||||
ProcessContext *process.ProcessContext
|
||||
Base *base.BaseDendrite
|
||||
DB storage.Database
|
||||
Cfg *config.RoomServer
|
||||
Cache caching.RoomServerCaches
|
||||
|
|
@ -56,33 +58,38 @@ type RoomserverInternalAPI struct {
|
|||
}
|
||||
|
||||
func NewRoomserverAPI(
|
||||
processCtx *process.ProcessContext, cfg *config.RoomServer, roomserverDB storage.Database,
|
||||
js nats.JetStreamContext, nc *nats.Conn, inputRoomEventTopic string,
|
||||
caches caching.RoomServerCaches, perspectiveServerNames []gomatrixserverlib.ServerName,
|
||||
base *base.BaseDendrite, roomserverDB storage.Database,
|
||||
js nats.JetStreamContext, nc *nats.Conn,
|
||||
) *RoomserverInternalAPI {
|
||||
var perspectiveServerNames []gomatrixserverlib.ServerName
|
||||
for _, kp := range base.Cfg.FederationAPI.KeyPerspectives {
|
||||
perspectiveServerNames = append(perspectiveServerNames, kp.ServerName)
|
||||
}
|
||||
|
||||
serverACLs := acls.NewServerACLs(roomserverDB)
|
||||
producer := &producers.RoomEventProducer{
|
||||
Topic: string(cfg.Matrix.JetStream.Prefixed(jetstream.OutputRoomEvent)),
|
||||
Topic: string(base.Cfg.Global.JetStream.Prefixed(jetstream.OutputRoomEvent)),
|
||||
JetStream: js,
|
||||
ACLs: serverACLs,
|
||||
}
|
||||
a := &RoomserverInternalAPI{
|
||||
ProcessContext: processCtx,
|
||||
ProcessContext: base.ProcessContext,
|
||||
DB: roomserverDB,
|
||||
Cfg: cfg,
|
||||
Cache: caches,
|
||||
ServerName: cfg.Matrix.ServerName,
|
||||
Base: base,
|
||||
Cfg: &base.Cfg.RoomServer,
|
||||
Cache: base.Caches,
|
||||
ServerName: base.Cfg.Global.ServerName,
|
||||
PerspectiveServerNames: perspectiveServerNames,
|
||||
InputRoomEventTopic: inputRoomEventTopic,
|
||||
InputRoomEventTopic: base.Cfg.Global.JetStream.Prefixed(jetstream.InputRoomEvent),
|
||||
OutputProducer: producer,
|
||||
JetStream: js,
|
||||
NATSClient: nc,
|
||||
Durable: cfg.Matrix.JetStream.Durable("RoomserverInputConsumer"),
|
||||
Durable: base.Cfg.Global.JetStream.Durable("RoomserverInputConsumer"),
|
||||
ServerACLs: serverACLs,
|
||||
Queryer: &query.Queryer{
|
||||
DB: roomserverDB,
|
||||
Cache: caches,
|
||||
ServerName: cfg.Matrix.ServerName,
|
||||
Cache: base.Caches,
|
||||
ServerName: base.Cfg.Global.ServerName,
|
||||
ServerACLs: serverACLs,
|
||||
},
|
||||
// perform-er structs get initialised when we have a federation sender to use
|
||||
|
|
@ -98,8 +105,9 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
|
|||
r.KeyRing = keyRing
|
||||
|
||||
r.Inputer = &input.Inputer{
|
||||
Cfg: r.Cfg,
|
||||
ProcessContext: r.ProcessContext,
|
||||
Cfg: &r.Base.Cfg.RoomServer,
|
||||
Base: r.Base,
|
||||
ProcessContext: r.Base.ProcessContext,
|
||||
DB: r.DB,
|
||||
InputRoomEventTopic: r.InputRoomEventTopic,
|
||||
OutputProducer: r.OutputProducer,
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/roomserver/internal/query"
|
||||
"github.com/matrix-org/dendrite/roomserver/producers"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage"
|
||||
"github.com/matrix-org/dendrite/setup/base"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/setup/jetstream"
|
||||
"github.com/matrix-org/dendrite/setup/process"
|
||||
|
|
@ -69,6 +70,7 @@ import (
|
|||
// or C.
|
||||
type Inputer struct {
|
||||
Cfg *config.RoomServer
|
||||
Base *base.BaseDendrite
|
||||
ProcessContext *process.ProcessContext
|
||||
DB storage.Database
|
||||
NATSClient *nats.Conn
|
||||
|
|
@ -160,7 +162,9 @@ func (r *Inputer) startWorkerForRoom(roomID string) {
|
|||
// will look to see if we have a worker for that room which has its
|
||||
// own consumer. If we don't, we'll start one.
|
||||
func (r *Inputer) Start() error {
|
||||
prometheus.MustRegister(roomserverInputBackpressure, processRoomEventDuration)
|
||||
if r.Base.EnableMetrics {
|
||||
prometheus.MustRegister(roomserverInputBackpressure, processRoomEventDuration)
|
||||
}
|
||||
_, err := r.JetStream.Subscribe(
|
||||
"", // This is blank because we specified it in BindStream.
|
||||
func(m *nats.Msg) {
|
||||
|
|
|
|||
|
|
@ -17,13 +17,10 @@ package roomserver
|
|||
import (
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/dendrite/roomserver/inthttp"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
||||
"github.com/matrix-org/dendrite/roomserver/internal"
|
||||
"github.com/matrix-org/dendrite/roomserver/inthttp"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage"
|
||||
"github.com/matrix-org/dendrite/setup/base"
|
||||
"github.com/matrix-org/dendrite/setup/jetstream"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
|
|
@ -40,11 +37,6 @@ func NewInternalAPI(
|
|||
) api.RoomserverInternalAPI {
|
||||
cfg := &base.Cfg.RoomServer
|
||||
|
||||
var perspectiveServerNames []gomatrixserverlib.ServerName
|
||||
for _, kp := range base.Cfg.FederationAPI.KeyPerspectives {
|
||||
perspectiveServerNames = append(perspectiveServerNames, kp.ServerName)
|
||||
}
|
||||
|
||||
roomserverDB, err := storage.Open(base, &cfg.Database, base.Caches)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Panicf("failed to connect to room server db")
|
||||
|
|
@ -53,8 +45,6 @@ func NewInternalAPI(
|
|||
js, nc := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream)
|
||||
|
||||
return internal.NewRoomserverAPI(
|
||||
base.ProcessContext, cfg, roomserverDB, js, nc,
|
||||
cfg.Matrix.JetStream.Prefixed(jetstream.InputRoomEvent),
|
||||
base.Caches, perspectiveServerNames,
|
||||
base, roomserverDB, js, nc,
|
||||
)
|
||||
}
|
||||
|
|
|
|||
69
roomserver/roomserver_test.go
Normal file
69
roomserver/roomserver_test.go
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
package roomserver_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/matrix-org/dendrite/roomserver"
|
||||
"github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage"
|
||||
"github.com/matrix-org/dendrite/setup/base"
|
||||
"github.com/matrix-org/dendrite/test"
|
||||
"github.com/matrix-org/dendrite/test/testrig"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
func mustCreateDatabase(t *testing.T, dbType test.DBType) (*base.BaseDendrite, storage.Database, func()) {
|
||||
base, close := testrig.CreateBaseDendrite(t, dbType)
|
||||
db, err := storage.Open(base, &base.Cfg.KeyServer.Database, base.Caches)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create Database: %v", err)
|
||||
}
|
||||
return base, db, close
|
||||
}
|
||||
|
||||
func Test_SharedUsers(t *testing.T) {
|
||||
alice := test.NewUser(t)
|
||||
bob := test.NewUser(t)
|
||||
room := test.NewRoom(t, alice, test.RoomPreset(test.PresetTrustedPrivateChat))
|
||||
|
||||
// Invite and join Bob
|
||||
room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{
|
||||
"membership": "invite",
|
||||
}, test.WithStateKey(bob.ID))
|
||||
room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
|
||||
"membership": "join",
|
||||
}, test.WithStateKey(bob.ID))
|
||||
|
||||
ctx := context.Background()
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
base, _, close := mustCreateDatabase(t, dbType)
|
||||
defer close()
|
||||
|
||||
rsAPI := roomserver.NewInternalAPI(base)
|
||||
// SetFederationAPI starts the room event input consumer
|
||||
rsAPI.SetFederationAPI(nil, nil)
|
||||
// Create the room
|
||||
if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", nil, false); err != nil {
|
||||
t.Fatalf("failed to send events: %v", err)
|
||||
}
|
||||
|
||||
// Query the shared users for Alice, there should only be Bob.
|
||||
// This is used by the SyncAPI keychange consumer.
|
||||
res := &api.QuerySharedUsersResponse{}
|
||||
if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID}, res); err != nil {
|
||||
t.Fatalf("unable to query known users: %v", err)
|
||||
}
|
||||
if _, ok := res.UserIDsToCount[bob.ID]; !ok {
|
||||
t.Fatalf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount)
|
||||
}
|
||||
// Also verify that we get the expected result when specifying OtherUserIDs.
|
||||
// This is used by the SyncAPI when getting device list changes.
|
||||
if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID, OtherUserIDs: []string{bob.ID}}, res); err != nil {
|
||||
t.Fatalf("unable to query known users: %v", err)
|
||||
}
|
||||
if _, ok := res.UserIDsToCount[bob.ID]; !ok {
|
||||
t.Fatalf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
@ -110,7 +110,7 @@ func (v *StateResolution) LoadStateAtEvent(
|
|||
|
||||
snapshotNID, err := v.db.SnapshotNIDFromEventID(ctx, eventID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID failed for event %s : %s", eventID, err)
|
||||
return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID failed for event %s : %w", eventID, err)
|
||||
}
|
||||
if snapshotNID == 0 {
|
||||
return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID(%s) returned 0 NID, was this event stored?", eventID)
|
||||
|
|
|
|||
|
|
@ -65,12 +65,18 @@ CREATE TABLE IF NOT EXISTS roomserver_membership (
|
|||
);
|
||||
`
|
||||
|
||||
var selectJoinedUsersSetForRoomsSQL = "" +
|
||||
var selectJoinedUsersSetForRoomsAndUserSQL = "" +
|
||||
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
|
||||
" WHERE room_nid = ANY($1) AND target_nid = ANY($2) AND" +
|
||||
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
|
||||
" GROUP BY target_nid"
|
||||
|
||||
var selectJoinedUsersSetForRoomsSQL = "" +
|
||||
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
|
||||
" WHERE room_nid = ANY($1) AND" +
|
||||
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
|
||||
" GROUP BY target_nid"
|
||||
|
||||
// Insert a row in to membership table so that it can be locked by the
|
||||
// SELECT FOR UPDATE
|
||||
const insertMembershipSQL = "" +
|
||||
|
|
@ -153,6 +159,7 @@ type membershipStatements struct {
|
|||
selectLocalMembershipsFromRoomStmt *sql.Stmt
|
||||
updateMembershipStmt *sql.Stmt
|
||||
selectRoomsWithMembershipStmt *sql.Stmt
|
||||
selectJoinedUsersSetForRoomsAndUserStmt *sql.Stmt
|
||||
selectJoinedUsersSetForRoomsStmt *sql.Stmt
|
||||
selectKnownUsersStmt *sql.Stmt
|
||||
updateMembershipForgetRoomStmt *sql.Stmt
|
||||
|
|
@ -178,6 +185,7 @@ func PrepareMembershipTable(db *sql.DB) (tables.Membership, error) {
|
|||
{&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL},
|
||||
{&s.updateMembershipStmt, updateMembershipSQL},
|
||||
{&s.selectRoomsWithMembershipStmt, selectRoomsWithMembershipSQL},
|
||||
{&s.selectJoinedUsersSetForRoomsAndUserStmt, selectJoinedUsersSetForRoomsAndUserSQL},
|
||||
{&s.selectJoinedUsersSetForRoomsStmt, selectJoinedUsersSetForRoomsSQL},
|
||||
{&s.selectKnownUsersStmt, selectKnownUsersSQL},
|
||||
{&s.updateMembershipForgetRoomStmt, updateMembershipForgetRoom},
|
||||
|
|
@ -313,8 +321,18 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms(
|
|||
roomNIDs []types.RoomNID,
|
||||
userNIDs []types.EventStateKeyNID,
|
||||
) (map[types.EventStateKeyNID]int, error) {
|
||||
var (
|
||||
rows *sql.Rows
|
||||
err error
|
||||
)
|
||||
stmt := sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsStmt)
|
||||
rows, err := stmt.QueryContext(ctx, pq.Array(roomNIDs), pq.Array(userNIDs))
|
||||
if len(userNIDs) > 0 {
|
||||
stmt = sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsAndUserStmt)
|
||||
rows, err = stmt.QueryContext(ctx, pq.Array(roomNIDs), pq.Array(userNIDs))
|
||||
} else {
|
||||
rows, err = stmt.QueryContext(ctx, pq.Array(roomNIDs))
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -263,6 +263,12 @@ func (d *Database) snapshotNIDFromEventID(
|
|||
ctx context.Context, txn *sql.Tx, eventID string,
|
||||
) (types.StateSnapshotNID, error) {
|
||||
_, stateNID, err := d.EventsTable.SelectEvent(ctx, txn, eventID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if stateNID == 0 {
|
||||
return 0, sql.ErrNoRows // effectively there's no state entry
|
||||
}
|
||||
return stateNID, err
|
||||
}
|
||||
|
||||
|
|
@ -1214,6 +1220,13 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs [
|
|||
stateKeyNIDs[i] = nid
|
||||
i++
|
||||
}
|
||||
// If we didn't have any userIDs to look up, get the UserIDs for the returned userNIDToCount now
|
||||
if len(userIDs) == 0 {
|
||||
nidToUserID, err = d.EventStateKeys(ctx, stateKeyNIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
result := make(map[string]int, len(userNIDToCount))
|
||||
for nid, count := range userNIDToCount {
|
||||
result[nidToUserID[nid]] = count
|
||||
|
|
|
|||
|
|
@ -41,12 +41,18 @@ const membershipSchema = `
|
|||
);
|
||||
`
|
||||
|
||||
var selectJoinedUsersSetForRoomsSQL = "" +
|
||||
var selectJoinedUsersSetForRoomsAndUserSQL = "" +
|
||||
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
|
||||
" WHERE room_nid IN ($1) AND target_nid IN ($2) AND" +
|
||||
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
|
||||
" GROUP BY target_nid"
|
||||
|
||||
var selectJoinedUsersSetForRoomsSQL = "" +
|
||||
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
|
||||
" WHERE room_nid IN ($1) AND " +
|
||||
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
|
||||
" GROUP BY target_nid"
|
||||
|
||||
// Insert a row in to membership table so that it can be locked by the
|
||||
// SELECT FOR UPDATE
|
||||
const insertMembershipSQL = "" +
|
||||
|
|
@ -293,8 +299,12 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context,
|
|||
for _, v := range userNIDs {
|
||||
params = append(params, v)
|
||||
}
|
||||
|
||||
query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1)
|
||||
query = strings.Replace(query, "($2)", sqlutil.QueryVariadicOffset(len(userNIDs), len(roomNIDs)), 1)
|
||||
if len(userNIDs) > 0 {
|
||||
query = strings.Replace(selectJoinedUsersSetForRoomsAndUserSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1)
|
||||
query = strings.Replace(query, "($2)", sqlutil.QueryVariadicOffset(len(userNIDs), len(roomNIDs)), 1)
|
||||
}
|
||||
var rows *sql.Rows
|
||||
var err error
|
||||
if txn != nil {
|
||||
|
|
|
|||
|
|
@ -187,7 +187,7 @@ func loadAppServices(config *AppServiceAPI, derived *Derived) error {
|
|||
}
|
||||
|
||||
// Load the config data into our struct
|
||||
if err = yaml.UnmarshalStrict(configData, &appservice); err != nil {
|
||||
if err = yaml.Unmarshal(configData, &appservice); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
@ -315,6 +315,20 @@ func checkErrors(config *AppServiceAPI, derived *Derived) (err error) {
|
|||
}
|
||||
}
|
||||
|
||||
// Check required fields
|
||||
if appservice.ID == "" {
|
||||
return ConfigErrors([]string{"Application service ID is required"})
|
||||
}
|
||||
if appservice.ASToken == "" {
|
||||
return ConfigErrors([]string{"Application service Token is required"})
|
||||
}
|
||||
if appservice.HSToken == "" {
|
||||
return ConfigErrors([]string{"Homeserver Token is required"})
|
||||
}
|
||||
if appservice.SenderLocalpart == "" {
|
||||
return ConfigErrors([]string{"Sender Localpart is required"})
|
||||
}
|
||||
|
||||
// Check if the url has trailing /'s. If so, remove them
|
||||
appservice.URL = strings.TrimRight(appservice.URL, "/")
|
||||
|
||||
|
|
|
|||
|
|
@ -87,7 +87,7 @@ func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msg *nats.Ms
|
|||
Type: msg.Header.Get("type"),
|
||||
}
|
||||
|
||||
timestamp, err := strconv.Atoi(msg.Header.Get("timestamp"))
|
||||
timestamp, err := strconv.ParseUint(msg.Header.Get("timestamp"), 10, 64)
|
||||
if err != nil {
|
||||
// If the message was invalid, log it and move on to the next message in the stream
|
||||
log.WithError(err).Errorf("output log: message parse failure")
|
||||
|
|
|
|||
Loading…
Reference in a new issue