Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/fts2

This commit is contained in:
Till Faelligen 2022-07-06 07:47:24 +02:00
commit e883d688b2
14 changed files with 312 additions and 187 deletions

View file

@ -17,6 +17,7 @@ package producers
import (
"context"
"encoding/json"
"fmt"
"strconv"
"time"
@ -83,7 +84,7 @@ func (p *SyncAPIProducer) SendReceipt(
m.Header.Set(jetstream.RoomID, roomID)
m.Header.Set(jetstream.EventID, eventID)
m.Header.Set("type", receiptType)
m.Header.Set("timestamp", strconv.Itoa(int(timestamp)))
m.Header.Set("timestamp", fmt.Sprintf("%d", timestamp))
log.WithFields(log.Fields{}).Tracef("Producing to topic '%s'", p.TopicReceiptEvent)
_, err := p.JetStream.PublishMsg(m, nats.Context(ctx))

View file

@ -90,7 +90,7 @@ func (t *OutputReceiptConsumer) onMessage(ctx context.Context, msg *nats.Msg) bo
return true
}
timestamp, err := strconv.Atoi(msg.Header.Get("timestamp"))
timestamp, err := strconv.ParseUint(msg.Header.Get("timestamp"), 10, 64)
if err != nil {
// If the message was invalid, log it and move on to the next message in the stream
log.WithError(err).Errorf("EDU output log: message parse failure")

View file

@ -53,7 +53,7 @@ func (p *SyncAPIProducer) SendReceipt(
m.Header.Set(jetstream.RoomID, roomID)
m.Header.Set(jetstream.EventID, eventID)
m.Header.Set("type", receiptType)
m.Header.Set("timestamp", strconv.Itoa(int(timestamp)))
m.Header.Set("timestamp", fmt.Sprintf("%d", timestamp))
log.WithFields(log.Fields{}).Tracef("Producing to topic '%s'", p.TopicReceiptEvent)
_, err := p.JetStream.PublishMsg(m, nats.Context(ctx))

View file

@ -1,36 +1,26 @@
package storage
package storage_test
import (
"context"
"fmt"
"io/ioutil"
"log"
"os"
"reflect"
"testing"
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/storage"
"github.com/matrix-org/dendrite/keyserver/types"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
)
var ctx = context.Background()
func MustCreateDatabase(t *testing.T) (Database, func()) {
tmpfile, err := ioutil.TempFile("", "keyserver_storage_test")
func MustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
base, close := testrig.CreateBaseDendrite(t, dbType)
db, err := storage.NewDatabase(base, &base.Cfg.KeyServer.Database)
if err != nil {
log.Fatal(err)
}
t.Logf("Database %s", tmpfile.Name())
db, err := NewDatabase(nil, &config.DatabaseOptions{
ConnectionString: config.DataSource(fmt.Sprintf("file://%s", tmpfile.Name())),
})
if err != nil {
t.Fatalf("Failed to NewDatabase: %s", err)
}
return db, func() {
os.Remove(tmpfile.Name())
t.Fatalf("failed to create new database: %v", err)
}
return db, close
}
func MustNotError(t *testing.T, err error) {
@ -42,151 +32,159 @@ func MustNotError(t *testing.T, err error) {
}
func TestKeyChanges(t *testing.T) {
db, clean := MustCreateDatabase(t)
defer clean()
_, err := db.StoreKeyChange(ctx, "@alice:localhost")
MustNotError(t, err)
deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost")
MustNotError(t, err)
deviceChangeIDC, err := db.StoreKeyChange(ctx, "@charlie:localhost")
MustNotError(t, err)
userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDB, types.OffsetNewest)
if err != nil {
t.Fatalf("Failed to KeyChanges: %s", err)
}
if latest != deviceChangeIDC {
t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDC)
}
if !reflect.DeepEqual(userIDs, []string{"@charlie:localhost"}) {
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
}
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, clean := MustCreateDatabase(t, dbType)
defer clean()
_, err := db.StoreKeyChange(ctx, "@alice:localhost")
MustNotError(t, err)
deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost")
MustNotError(t, err)
deviceChangeIDC, err := db.StoreKeyChange(ctx, "@charlie:localhost")
MustNotError(t, err)
userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDB, types.OffsetNewest)
if err != nil {
t.Fatalf("Failed to KeyChanges: %s", err)
}
if latest != deviceChangeIDC {
t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDC)
}
if !reflect.DeepEqual(userIDs, []string{"@charlie:localhost"}) {
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
}
})
}
func TestKeyChangesNoDupes(t *testing.T) {
db, clean := MustCreateDatabase(t)
defer clean()
deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost")
MustNotError(t, err)
deviceChangeIDB, err := db.StoreKeyChange(ctx, "@alice:localhost")
MustNotError(t, err)
if deviceChangeIDA == deviceChangeIDB {
t.Fatalf("Expected change ID to be different even when inserting key change for the same user, got %d for both changes", deviceChangeIDA)
}
deviceChangeID, err := db.StoreKeyChange(ctx, "@alice:localhost")
MustNotError(t, err)
userIDs, latest, err := db.KeyChanges(ctx, 0, types.OffsetNewest)
if err != nil {
t.Fatalf("Failed to KeyChanges: %s", err)
}
if latest != deviceChangeID {
t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeID)
}
if !reflect.DeepEqual(userIDs, []string{"@alice:localhost"}) {
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
}
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, clean := MustCreateDatabase(t, dbType)
defer clean()
deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost")
MustNotError(t, err)
deviceChangeIDB, err := db.StoreKeyChange(ctx, "@alice:localhost")
MustNotError(t, err)
if deviceChangeIDA == deviceChangeIDB {
t.Fatalf("Expected change ID to be different even when inserting key change for the same user, got %d for both changes", deviceChangeIDA)
}
deviceChangeID, err := db.StoreKeyChange(ctx, "@alice:localhost")
MustNotError(t, err)
userIDs, latest, err := db.KeyChanges(ctx, 0, types.OffsetNewest)
if err != nil {
t.Fatalf("Failed to KeyChanges: %s", err)
}
if latest != deviceChangeID {
t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeID)
}
if !reflect.DeepEqual(userIDs, []string{"@alice:localhost"}) {
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
}
})
}
func TestKeyChangesUpperLimit(t *testing.T) {
db, clean := MustCreateDatabase(t)
defer clean()
deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost")
MustNotError(t, err)
deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost")
MustNotError(t, err)
_, err = db.StoreKeyChange(ctx, "@charlie:localhost")
MustNotError(t, err)
userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDA, deviceChangeIDB)
if err != nil {
t.Fatalf("Failed to KeyChanges: %s", err)
}
if latest != deviceChangeIDB {
t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDB)
}
if !reflect.DeepEqual(userIDs, []string{"@bob:localhost"}) {
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
}
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, clean := MustCreateDatabase(t, dbType)
defer clean()
deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost")
MustNotError(t, err)
deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost")
MustNotError(t, err)
_, err = db.StoreKeyChange(ctx, "@charlie:localhost")
MustNotError(t, err)
userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDA, deviceChangeIDB)
if err != nil {
t.Fatalf("Failed to KeyChanges: %s", err)
}
if latest != deviceChangeIDB {
t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDB)
}
if !reflect.DeepEqual(userIDs, []string{"@bob:localhost"}) {
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
}
})
}
// The purpose of this test is to make sure that the storage layer is generating sequential stream IDs per user,
// and that they are returned correctly when querying for device keys.
func TestDeviceKeysStreamIDGeneration(t *testing.T) {
var err error
db, clean := MustCreateDatabase(t)
defer clean()
alice := "@alice:TestDeviceKeysStreamIDGeneration"
bob := "@bob:TestDeviceKeysStreamIDGeneration"
msgs := []api.DeviceMessage{
{
Type: api.TypeDeviceKeyUpdate,
DeviceKeys: &api.DeviceKeys{
DeviceID: "AAA",
UserID: alice,
KeyJSON: []byte(`{"key":"v1"}`),
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, clean := MustCreateDatabase(t, dbType)
defer clean()
alice := "@alice:TestDeviceKeysStreamIDGeneration"
bob := "@bob:TestDeviceKeysStreamIDGeneration"
msgs := []api.DeviceMessage{
{
Type: api.TypeDeviceKeyUpdate,
DeviceKeys: &api.DeviceKeys{
DeviceID: "AAA",
UserID: alice,
KeyJSON: []byte(`{"key":"v1"}`),
},
// StreamID: 1
},
// StreamID: 1
},
{
Type: api.TypeDeviceKeyUpdate,
DeviceKeys: &api.DeviceKeys{
DeviceID: "AAA",
UserID: bob,
KeyJSON: []byte(`{"key":"v1"}`),
{
Type: api.TypeDeviceKeyUpdate,
DeviceKeys: &api.DeviceKeys{
DeviceID: "AAA",
UserID: bob,
KeyJSON: []byte(`{"key":"v1"}`),
},
// StreamID: 1 as this is a different user
},
// StreamID: 1 as this is a different user
},
{
Type: api.TypeDeviceKeyUpdate,
DeviceKeys: &api.DeviceKeys{
DeviceID: "another_device",
UserID: alice,
KeyJSON: []byte(`{"key":"v1"}`),
{
Type: api.TypeDeviceKeyUpdate,
DeviceKeys: &api.DeviceKeys{
DeviceID: "another_device",
UserID: alice,
KeyJSON: []byte(`{"key":"v1"}`),
},
// StreamID: 2 as this is a 2nd device key
},
// StreamID: 2 as this is a 2nd device key
},
}
MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs))
if msgs[0].StreamID != 1 {
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 but got %d", msgs[0].StreamID)
}
if msgs[1].StreamID != 1 {
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 (different user) but got %d", msgs[1].StreamID)
}
if msgs[2].StreamID != 2 {
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=2 (another device) but got %d", msgs[2].StreamID)
}
// updating a device sets the next stream ID for that user
msgs = []api.DeviceMessage{
{
Type: api.TypeDeviceKeyUpdate,
DeviceKeys: &api.DeviceKeys{
DeviceID: "AAA",
UserID: alice,
KeyJSON: []byte(`{"key":"v2"}`),
},
// StreamID: 3
},
}
MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs))
if msgs[0].StreamID != 3 {
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=3 (new key same device) but got %d", msgs[0].StreamID)
}
// Querying for device keys returns the latest stream IDs
msgs, err = db.DeviceKeysForUser(ctx, alice, []string{"AAA", "another_device"}, false)
if err != nil {
t.Fatalf("DeviceKeysForUser returned error: %s", err)
}
wantStreamIDs := map[string]int64{
"AAA": 3,
"another_device": 2,
}
if len(msgs) != len(wantStreamIDs) {
t.Fatalf("DeviceKeysForUser: wrong number of devices, got %d want %d", len(msgs), len(wantStreamIDs))
}
for _, m := range msgs {
if m.StreamID != wantStreamIDs[m.DeviceID] {
t.Errorf("DeviceKeysForUser: wrong returned stream ID for key, got %d want %d", m.StreamID, wantStreamIDs[m.DeviceID])
}
}
MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs))
if msgs[0].StreamID != 1 {
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 but got %d", msgs[0].StreamID)
}
if msgs[1].StreamID != 1 {
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 (different user) but got %d", msgs[1].StreamID)
}
if msgs[2].StreamID != 2 {
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=2 (another device) but got %d", msgs[2].StreamID)
}
// updating a device sets the next stream ID for that user
msgs = []api.DeviceMessage{
{
Type: api.TypeDeviceKeyUpdate,
DeviceKeys: &api.DeviceKeys{
DeviceID: "AAA",
UserID: alice,
KeyJSON: []byte(`{"key":"v2"}`),
},
// StreamID: 3
},
}
MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs))
if msgs[0].StreamID != 3 {
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=3 (new key same device) but got %d", msgs[0].StreamID)
}
// Querying for device keys returns the latest stream IDs
msgs, err = db.DeviceKeysForUser(ctx, alice, []string{"AAA", "another_device"}, false)
if err != nil {
t.Fatalf("DeviceKeysForUser returned error: %s", err)
}
wantStreamIDs := map[string]int64{
"AAA": 3,
"another_device": 2,
}
if len(msgs) != len(wantStreamIDs) {
t.Fatalf("DeviceKeysForUser: wrong number of devices, got %d want %d", len(msgs), len(wantStreamIDs))
}
for _, m := range msgs {
if m.StreamID != wantStreamIDs[m.DeviceID] {
t.Errorf("DeviceKeysForUser: wrong returned stream ID for key, got %d want %d", m.StreamID, wantStreamIDs[m.DeviceID])
}
}
})
}

View file

@ -14,6 +14,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/internal/query"
"github.com/matrix-org/dendrite/roomserver/producers"
"github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process"
@ -39,6 +40,7 @@ type RoomserverInternalAPI struct {
*perform.Upgrader
*perform.Admin
ProcessContext *process.ProcessContext
Base *base.BaseDendrite
DB storage.Database
Cfg *config.RoomServer
Cache caching.RoomServerCaches
@ -56,33 +58,38 @@ type RoomserverInternalAPI struct {
}
func NewRoomserverAPI(
processCtx *process.ProcessContext, cfg *config.RoomServer, roomserverDB storage.Database,
js nats.JetStreamContext, nc *nats.Conn, inputRoomEventTopic string,
caches caching.RoomServerCaches, perspectiveServerNames []gomatrixserverlib.ServerName,
base *base.BaseDendrite, roomserverDB storage.Database,
js nats.JetStreamContext, nc *nats.Conn,
) *RoomserverInternalAPI {
var perspectiveServerNames []gomatrixserverlib.ServerName
for _, kp := range base.Cfg.FederationAPI.KeyPerspectives {
perspectiveServerNames = append(perspectiveServerNames, kp.ServerName)
}
serverACLs := acls.NewServerACLs(roomserverDB)
producer := &producers.RoomEventProducer{
Topic: string(cfg.Matrix.JetStream.Prefixed(jetstream.OutputRoomEvent)),
Topic: string(base.Cfg.Global.JetStream.Prefixed(jetstream.OutputRoomEvent)),
JetStream: js,
ACLs: serverACLs,
}
a := &RoomserverInternalAPI{
ProcessContext: processCtx,
ProcessContext: base.ProcessContext,
DB: roomserverDB,
Cfg: cfg,
Cache: caches,
ServerName: cfg.Matrix.ServerName,
Base: base,
Cfg: &base.Cfg.RoomServer,
Cache: base.Caches,
ServerName: base.Cfg.Global.ServerName,
PerspectiveServerNames: perspectiveServerNames,
InputRoomEventTopic: inputRoomEventTopic,
InputRoomEventTopic: base.Cfg.Global.JetStream.Prefixed(jetstream.InputRoomEvent),
OutputProducer: producer,
JetStream: js,
NATSClient: nc,
Durable: cfg.Matrix.JetStream.Durable("RoomserverInputConsumer"),
Durable: base.Cfg.Global.JetStream.Durable("RoomserverInputConsumer"),
ServerACLs: serverACLs,
Queryer: &query.Queryer{
DB: roomserverDB,
Cache: caches,
ServerName: cfg.Matrix.ServerName,
Cache: base.Caches,
ServerName: base.Cfg.Global.ServerName,
ServerACLs: serverACLs,
},
// perform-er structs get initialised when we have a federation sender to use
@ -98,8 +105,9 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
r.KeyRing = keyRing
r.Inputer = &input.Inputer{
Cfg: r.Cfg,
ProcessContext: r.ProcessContext,
Cfg: &r.Base.Cfg.RoomServer,
Base: r.Base,
ProcessContext: r.Base.ProcessContext,
DB: r.DB,
InputRoomEventTopic: r.InputRoomEventTopic,
OutputProducer: r.OutputProducer,

View file

@ -31,6 +31,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/internal/query"
"github.com/matrix-org/dendrite/roomserver/producers"
"github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process"
@ -69,6 +70,7 @@ import (
// or C.
type Inputer struct {
Cfg *config.RoomServer
Base *base.BaseDendrite
ProcessContext *process.ProcessContext
DB storage.Database
NATSClient *nats.Conn
@ -160,7 +162,9 @@ func (r *Inputer) startWorkerForRoom(roomID string) {
// will look to see if we have a worker for that room which has its
// own consumer. If we don't, we'll start one.
func (r *Inputer) Start() error {
prometheus.MustRegister(roomserverInputBackpressure, processRoomEventDuration)
if r.Base.EnableMetrics {
prometheus.MustRegister(roomserverInputBackpressure, processRoomEventDuration)
}
_, err := r.JetStream.Subscribe(
"", // This is blank because we specified it in BindStream.
func(m *nats.Msg) {

View file

@ -17,13 +17,10 @@ package roomserver
import (
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/inthttp"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/roomserver/internal"
"github.com/matrix-org/dendrite/roomserver/inthttp"
"github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/sirupsen/logrus"
)
@ -40,11 +37,6 @@ func NewInternalAPI(
) api.RoomserverInternalAPI {
cfg := &base.Cfg.RoomServer
var perspectiveServerNames []gomatrixserverlib.ServerName
for _, kp := range base.Cfg.FederationAPI.KeyPerspectives {
perspectiveServerNames = append(perspectiveServerNames, kp.ServerName)
}
roomserverDB, err := storage.Open(base, &cfg.Database, base.Caches)
if err != nil {
logrus.WithError(err).Panicf("failed to connect to room server db")
@ -53,8 +45,6 @@ func NewInternalAPI(
js, nc := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream)
return internal.NewRoomserverAPI(
base.ProcessContext, cfg, roomserverDB, js, nc,
cfg.Matrix.JetStream.Prefixed(jetstream.InputRoomEvent),
base.Caches, perspectiveServerNames,
base, roomserverDB, js, nc,
)
}

View file

@ -0,0 +1,69 @@
package roomserver_test
import (
"context"
"testing"
"github.com/matrix-org/dendrite/roomserver"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
"github.com/matrix-org/gomatrixserverlib"
)
func mustCreateDatabase(t *testing.T, dbType test.DBType) (*base.BaseDendrite, storage.Database, func()) {
base, close := testrig.CreateBaseDendrite(t, dbType)
db, err := storage.Open(base, &base.Cfg.KeyServer.Database, base.Caches)
if err != nil {
t.Fatalf("failed to create Database: %v", err)
}
return base, db, close
}
func Test_SharedUsers(t *testing.T) {
alice := test.NewUser(t)
bob := test.NewUser(t)
room := test.NewRoom(t, alice, test.RoomPreset(test.PresetTrustedPrivateChat))
// Invite and join Bob
room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{
"membership": "invite",
}, test.WithStateKey(bob.ID))
room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
"membership": "join",
}, test.WithStateKey(bob.ID))
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
base, _, close := mustCreateDatabase(t, dbType)
defer close()
rsAPI := roomserver.NewInternalAPI(base)
// SetFederationAPI starts the room event input consumer
rsAPI.SetFederationAPI(nil, nil)
// Create the room
if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", nil, false); err != nil {
t.Fatalf("failed to send events: %v", err)
}
// Query the shared users for Alice, there should only be Bob.
// This is used by the SyncAPI keychange consumer.
res := &api.QuerySharedUsersResponse{}
if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID}, res); err != nil {
t.Fatalf("unable to query known users: %v", err)
}
if _, ok := res.UserIDsToCount[bob.ID]; !ok {
t.Fatalf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount)
}
// Also verify that we get the expected result when specifying OtherUserIDs.
// This is used by the SyncAPI when getting device list changes.
if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID, OtherUserIDs: []string{bob.ID}}, res); err != nil {
t.Fatalf("unable to query known users: %v", err)
}
if _, ok := res.UserIDsToCount[bob.ID]; !ok {
t.Fatalf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount)
}
})
}

View file

@ -110,7 +110,7 @@ func (v *StateResolution) LoadStateAtEvent(
snapshotNID, err := v.db.SnapshotNIDFromEventID(ctx, eventID)
if err != nil {
return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID failed for event %s : %s", eventID, err)
return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID failed for event %s : %w", eventID, err)
}
if snapshotNID == 0 {
return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID(%s) returned 0 NID, was this event stored?", eventID)

View file

@ -65,12 +65,18 @@ CREATE TABLE IF NOT EXISTS roomserver_membership (
);
`
var selectJoinedUsersSetForRoomsSQL = "" +
var selectJoinedUsersSetForRoomsAndUserSQL = "" +
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
" WHERE room_nid = ANY($1) AND target_nid = ANY($2) AND" +
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
" GROUP BY target_nid"
var selectJoinedUsersSetForRoomsSQL = "" +
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
" WHERE room_nid = ANY($1) AND" +
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
" GROUP BY target_nid"
// Insert a row in to membership table so that it can be locked by the
// SELECT FOR UPDATE
const insertMembershipSQL = "" +
@ -153,6 +159,7 @@ type membershipStatements struct {
selectLocalMembershipsFromRoomStmt *sql.Stmt
updateMembershipStmt *sql.Stmt
selectRoomsWithMembershipStmt *sql.Stmt
selectJoinedUsersSetForRoomsAndUserStmt *sql.Stmt
selectJoinedUsersSetForRoomsStmt *sql.Stmt
selectKnownUsersStmt *sql.Stmt
updateMembershipForgetRoomStmt *sql.Stmt
@ -178,6 +185,7 @@ func PrepareMembershipTable(db *sql.DB) (tables.Membership, error) {
{&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL},
{&s.updateMembershipStmt, updateMembershipSQL},
{&s.selectRoomsWithMembershipStmt, selectRoomsWithMembershipSQL},
{&s.selectJoinedUsersSetForRoomsAndUserStmt, selectJoinedUsersSetForRoomsAndUserSQL},
{&s.selectJoinedUsersSetForRoomsStmt, selectJoinedUsersSetForRoomsSQL},
{&s.selectKnownUsersStmt, selectKnownUsersSQL},
{&s.updateMembershipForgetRoomStmt, updateMembershipForgetRoom},
@ -313,8 +321,18 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms(
roomNIDs []types.RoomNID,
userNIDs []types.EventStateKeyNID,
) (map[types.EventStateKeyNID]int, error) {
var (
rows *sql.Rows
err error
)
stmt := sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsStmt)
rows, err := stmt.QueryContext(ctx, pq.Array(roomNIDs), pq.Array(userNIDs))
if len(userNIDs) > 0 {
stmt = sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsAndUserStmt)
rows, err = stmt.QueryContext(ctx, pq.Array(roomNIDs), pq.Array(userNIDs))
} else {
rows, err = stmt.QueryContext(ctx, pq.Array(roomNIDs))
}
if err != nil {
return nil, err
}

View file

@ -263,6 +263,12 @@ func (d *Database) snapshotNIDFromEventID(
ctx context.Context, txn *sql.Tx, eventID string,
) (types.StateSnapshotNID, error) {
_, stateNID, err := d.EventsTable.SelectEvent(ctx, txn, eventID)
if err != nil {
return 0, err
}
if stateNID == 0 {
return 0, sql.ErrNoRows // effectively there's no state entry
}
return stateNID, err
}
@ -1214,6 +1220,13 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs [
stateKeyNIDs[i] = nid
i++
}
// If we didn't have any userIDs to look up, get the UserIDs for the returned userNIDToCount now
if len(userIDs) == 0 {
nidToUserID, err = d.EventStateKeys(ctx, stateKeyNIDs)
if err != nil {
return nil, err
}
}
result := make(map[string]int, len(userNIDToCount))
for nid, count := range userNIDToCount {
result[nidToUserID[nid]] = count

View file

@ -41,12 +41,18 @@ const membershipSchema = `
);
`
var selectJoinedUsersSetForRoomsSQL = "" +
var selectJoinedUsersSetForRoomsAndUserSQL = "" +
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
" WHERE room_nid IN ($1) AND target_nid IN ($2) AND" +
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
" GROUP BY target_nid"
var selectJoinedUsersSetForRoomsSQL = "" +
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
" WHERE room_nid IN ($1) AND " +
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
" GROUP BY target_nid"
// Insert a row in to membership table so that it can be locked by the
// SELECT FOR UPDATE
const insertMembershipSQL = "" +
@ -293,8 +299,12 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context,
for _, v := range userNIDs {
params = append(params, v)
}
query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1)
query = strings.Replace(query, "($2)", sqlutil.QueryVariadicOffset(len(userNIDs), len(roomNIDs)), 1)
if len(userNIDs) > 0 {
query = strings.Replace(selectJoinedUsersSetForRoomsAndUserSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1)
query = strings.Replace(query, "($2)", sqlutil.QueryVariadicOffset(len(userNIDs), len(roomNIDs)), 1)
}
var rows *sql.Rows
var err error
if txn != nil {

View file

@ -187,7 +187,7 @@ func loadAppServices(config *AppServiceAPI, derived *Derived) error {
}
// Load the config data into our struct
if err = yaml.UnmarshalStrict(configData, &appservice); err != nil {
if err = yaml.Unmarshal(configData, &appservice); err != nil {
return err
}
@ -315,6 +315,20 @@ func checkErrors(config *AppServiceAPI, derived *Derived) (err error) {
}
}
// Check required fields
if appservice.ID == "" {
return ConfigErrors([]string{"Application service ID is required"})
}
if appservice.ASToken == "" {
return ConfigErrors([]string{"Application service Token is required"})
}
if appservice.HSToken == "" {
return ConfigErrors([]string{"Homeserver Token is required"})
}
if appservice.SenderLocalpart == "" {
return ConfigErrors([]string{"Sender Localpart is required"})
}
// Check if the url has trailing /'s. If so, remove them
appservice.URL = strings.TrimRight(appservice.URL, "/")

View file

@ -87,7 +87,7 @@ func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msg *nats.Ms
Type: msg.Header.Get("type"),
}
timestamp, err := strconv.Atoi(msg.Header.Get("timestamp"))
timestamp, err := strconv.ParseUint(msg.Header.Get("timestamp"), 10, 64)
if err != nil {
// If the message was invalid, log it and move on to the next message in the stream
log.WithError(err).Errorf("output log: message parse failure")