From 75834783055b0c70f8b411d9c3741e57461832f0 Mon Sep 17 00:00:00 2001 From: kegsay Date: Mon, 5 Dec 2022 16:54:01 +0000 Subject: [PATCH 01/15] Update contributing guidelines (#2904) --- docs/CONTRIBUTING.md | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/docs/CONTRIBUTING.md b/docs/CONTRIBUTING.md index 262a93a7c..21b0e7ab1 100644 --- a/docs/CONTRIBUTING.md +++ b/docs/CONTRIBUTING.md @@ -9,6 +9,28 @@ permalink: /development/contributing Everyone is welcome to contribute to Dendrite! We aim to make it as easy as possible to get started. + ## Contribution types + +We are a small team maintaining a large project. As a result, we cannot merge every feature, even if it +is bug-free and useful, because we then commit to maintaining it indefinitely. We will always accept: + - bug fixes + - security fixes (please responsibly disclose via security@matrix.org *before* creating pull requests) + +We will accept the following with caveats: + - documentation fixes, provided they do not add additional instructions which can end up going out-of-date, + e.g example configs, shell commands. + - performance fixes, provided they do not add significantly more maintenance burden. + - additional functionality on existing features, provided the functionality is small and maintainable. + - additional functionality that, in its absence, would impact the ecosystem e.g spam and abuse mitigations + - test-only changes, provided they help improve coverage or test tricky code. + +The following items are at risk of not being accepted: + - Configuration or CLI changes, particularly ones which increase the overall configuration surface. + +The following items are unlikely to be accepted into a main Dendrite release for now: + - New MSC implementations. + - New features which are not in the specification. + ## Sign off We require that everyone who contributes to the project signs off their contributions From ded43e0f2d07adc399da5962454d9873021b59ac Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Tue, 6 Dec 2022 13:27:33 +0100 Subject: [PATCH 02/15] Fix issue with sending presence events to invalid servers --- federationapi/consumers/roomserver.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/federationapi/consumers/roomserver.go b/federationapi/consumers/roomserver.go index d16af6626..0c1080afa 100644 --- a/federationapi/consumers/roomserver.go +++ b/federationapi/consumers/roomserver.go @@ -232,7 +232,7 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent, rew } func (s *OutputRoomEventConsumer) sendPresence(roomID string, addedJoined []types.JoinedHost) { - joined := make([]gomatrixserverlib.ServerName, len(addedJoined)) + joined := make([]gomatrixserverlib.ServerName, 0, len(addedJoined)) for _, added := range addedJoined { joined = append(joined, added.ServerName) } From ba2ffb7da9b86b64dc8091d90645ab80cf2831db Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 6 Dec 2022 18:16:17 +0000 Subject: [PATCH 03/15] Repeatable reads for `/sync` (#2783) This puts repeatable reads into all sync streams. Co-authored-by: kegsay --- syncapi/storage/shared/storage_consumer.go | 42 +++++++++------------- 1 file changed, 17 insertions(+), 25 deletions(-) diff --git a/syncapi/storage/shared/storage_consumer.go b/syncapi/storage/shared/storage_consumer.go index 23f53d11f..f2064fb89 100644 --- a/syncapi/storage/shared/storage_consumer.go +++ b/syncapi/storage/shared/storage_consumer.go @@ -57,31 +57,23 @@ type Database struct { } func (d *Database) NewDatabaseSnapshot(ctx context.Context) (*DatabaseTransaction, error) { - return d.NewDatabaseTransaction(ctx) - - /* - TODO: Repeatable read is probably the right thing to do here, - but it seems to cause some problems with the invite tests, so - need to investigate that further. - - txn, err := d.DB.BeginTx(ctx, &sql.TxOptions{ - // Set the isolation level so that we see a snapshot of the database. - // In PostgreSQL repeatable read transactions will see a snapshot taken - // at the first query, and since the transaction is read-only it can't - // run into any serialisation errors. - // https://www.postgresql.org/docs/9.5/static/transaction-iso.html#XACT-REPEATABLE-READ - Isolation: sql.LevelRepeatableRead, - ReadOnly: true, - }) - if err != nil { - return nil, err - } - return &DatabaseTransaction{ - Database: d, - ctx: ctx, - txn: txn, - }, nil - */ + txn, err := d.DB.BeginTx(ctx, &sql.TxOptions{ + // Set the isolation level so that we see a snapshot of the database. + // In PostgreSQL repeatable read transactions will see a snapshot taken + // at the first query, and since the transaction is read-only it can't + // run into any serialisation errors. + // https://www.postgresql.org/docs/9.5/static/transaction-iso.html#XACT-REPEATABLE-READ + Isolation: sql.LevelRepeatableRead, + ReadOnly: true, + }) + if err != nil { + return nil, err + } + return &DatabaseTransaction{ + Database: d, + ctx: ctx, + txn: txn, + }, nil } func (d *Database) NewDatabaseTransaction(ctx context.Context) (*DatabaseTransaction, error) { From 27a1dea5225c828ad787098aadafa83ed73c4940 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Thu, 8 Dec 2022 08:24:06 +0100 Subject: [PATCH 04/15] Fix issue with multiple/duplicate log entries during tests (#2906) --- internal/log.go | 5 +++++ internal/log_unix.go | 13 +++++++------ 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/internal/log.go b/internal/log.go index a171555ab..d7e852c81 100644 --- a/internal/log.go +++ b/internal/log.go @@ -33,6 +33,11 @@ import ( "github.com/matrix-org/dendrite/setup/config" ) +// logrus is using a global variable when we're using `logrus.AddHook` +// this unfortunately results in us adding the same hook multiple times. +// This map ensures we only ever add one level hook. +var stdLevelLogAdded = make(map[logrus.Level]bool) + type utcFormatter struct { logrus.Formatter } diff --git a/internal/log_unix.go b/internal/log_unix.go index 75332af73..b38e7c2e8 100644 --- a/internal/log_unix.go +++ b/internal/log_unix.go @@ -22,16 +22,16 @@ import ( "log/syslog" "github.com/MFAshby/stdemuxerhook" - "github.com/matrix-org/dendrite/setup/config" "github.com/sirupsen/logrus" lSyslog "github.com/sirupsen/logrus/hooks/syslog" + + "github.com/matrix-org/dendrite/setup/config" ) // SetupHookLogging configures the logging hooks defined in the configuration. // If something fails here it means that the logging was improperly configured, // so we just exit with the error func SetupHookLogging(hooks []config.LogrusHook, componentName string) { - stdLogAdded := false for _, hook := range hooks { // Check we received a proper logging level level, err := logrus.ParseLevel(hook.Level) @@ -54,14 +54,11 @@ func SetupHookLogging(hooks []config.LogrusHook, componentName string) { setupSyslogHook(hook, level, componentName) case "std": setupStdLogHook(level) - stdLogAdded = true default: logrus.Fatalf("Unrecognised logging hook type: %s", hook.Type) } } - if !stdLogAdded { - setupStdLogHook(logrus.InfoLevel) - } + setupStdLogHook(logrus.InfoLevel) // Hooks are now configured for stdout/err, so throw away the default logger output logrus.SetOutput(io.Discard) } @@ -88,7 +85,11 @@ func checkSyslogHookParams(params map[string]interface{}) { } func setupStdLogHook(level logrus.Level) { + if stdLevelLogAdded[level] { + return + } logrus.AddHook(&logLevelHook{level, stdemuxerhook.New(logrus.StandardLogger())}) + stdLevelLogAdded[level] = true } func setupSyslogHook(hook config.LogrusHook, level logrus.Level, componentName string) { From 0351618ff4e7d569e14a165be59a1a7e9e979684 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Thu, 8 Dec 2022 08:24:24 +0100 Subject: [PATCH 05/15] Add UserAPI util tests (#2907) This adds some `userapi/util` tests. --- test/testrig/base.go | 2 +- userapi/util/notify_test.go | 119 ++++++++++++++++++++++++++++ userapi/util/phonehomestats.go | 10 +-- userapi/util/phonehomestats_test.go | 84 ++++++++++++++++++++ 4 files changed, 208 insertions(+), 7 deletions(-) create mode 100644 userapi/util/notify_test.go create mode 100644 userapi/util/phonehomestats_test.go diff --git a/test/testrig/base.go b/test/testrig/base.go index 15fb5c370..7bc26a5c5 100644 --- a/test/testrig/base.go +++ b/test/testrig/base.go @@ -108,7 +108,7 @@ func Base(cfg *config.Dendrite) (*base.BaseDendrite, nats.JetStreamContext, *nat cfg.Global.JetStream.InMemory = true cfg.SyncAPI.Fulltext.InMemory = true cfg.FederationAPI.KeyPerspectives = nil - base := base.NewBaseDendrite(cfg, "Tests") + base := base.NewBaseDendrite(cfg, "Tests", base.DisableMetrics) js, jc := base.NATS.Prepare(base.ProcessContext, &cfg.Global.JetStream) return base, js, jc } diff --git a/userapi/util/notify_test.go b/userapi/util/notify_test.go new file mode 100644 index 000000000..f1d20259c --- /dev/null +++ b/userapi/util/notify_test.go @@ -0,0 +1,119 @@ +package util_test + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "golang.org/x/crypto/bcrypt" + + "github.com/matrix-org/dendrite/internal/pushgateway" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage" + userUtil "github.com/matrix-org/dendrite/userapi/util" +) + +func TestNotifyUserCountsAsync(t *testing.T) { + alice := test.NewUser(t) + aliceLocalpart, serverName, err := gomatrixserverlib.SplitID('@', alice.ID) + if err != nil { + t.Error(err) + } + ctx := context.Background() + + // Create a test room, just used to provide events + room := test.NewRoom(t, alice) + dummyEvent := room.Events()[len(room.Events())-1] + + appID := util.RandomString(8) + pushKey := util.RandomString(8) + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + receivedRequest := make(chan bool, 1) + // create a test server which responds to our /notify call + srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var data pushgateway.NotifyRequest + if err := json.NewDecoder(r.Body).Decode(&data); err != nil { + t.Error(err) + } + notification := data.Notification + // Validate the request + if notification.Counts == nil { + t.Fatal("no unread notification counts in request") + } + if unread := notification.Counts.Unread; unread != 1 { + t.Errorf("expected one unread notification, got %d", unread) + } + + if len(notification.Devices) == 0 { + t.Fatal("expected devices in request") + } + + // We only created one push device, so access it directly + device := notification.Devices[0] + if device.AppID != appID { + t.Errorf("unexpected app_id: %s, want %s", device.AppID, appID) + } + if device.PushKey != pushKey { + t.Errorf("unexpected push_key: %s, want %s", device.PushKey, pushKey) + } + + // Return empty result, otherwise the call is handled as failed + if _, err := w.Write([]byte("{}")); err != nil { + t.Error(err) + } + close(receivedRequest) + })) + defer srv.Close() + + // Create DB and Dendrite base + connStr, close := test.PrepareDBConnectionString(t, dbType) + defer close() + base, _, _ := testrig.Base(nil) + defer base.Close() + db, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, "test", bcrypt.MinCost, 0, 0, "") + if err != nil { + t.Error(err) + } + + // Prepare pusher with our test server URL + if err := db.UpsertPusher(ctx, api.Pusher{ + Kind: api.HTTPKind, + AppID: appID, + PushKey: pushKey, + Data: map[string]interface{}{ + "url": srv.URL, + }, + }, aliceLocalpart, serverName); err != nil { + t.Error(err) + } + + // Insert a dummy event + if err := db.InsertNotification(ctx, aliceLocalpart, serverName, dummyEvent.EventID(), 0, nil, &api.Notification{ + Event: gomatrixserverlib.HeaderedToClientEvent(dummyEvent, gomatrixserverlib.FormatAll), + }); err != nil { + t.Error(err) + } + + // Notify the user about a new notification + if err := userUtil.NotifyUserCountsAsync(ctx, pushgateway.NewHTTPClient(true), aliceLocalpart, serverName, db); err != nil { + t.Error(err) + } + select { + case <-time.After(time.Second * 5): + t.Error("timed out waiting for response") + case <-receivedRequest: + } + }) + +} diff --git a/userapi/util/phonehomestats.go b/userapi/util/phonehomestats.go index 6f36568c9..42c8f5d7c 100644 --- a/userapi/util/phonehomestats.go +++ b/userapi/util/phonehomestats.go @@ -97,12 +97,10 @@ func (p *phoneHomeStats) collect() { // configuration information p.stats["federation_disabled"] = p.cfg.Global.DisableFederation - p.stats["nats_embedded"] = true - p.stats["nats_in_memory"] = p.cfg.Global.JetStream.InMemory - if len(p.cfg.Global.JetStream.Addresses) > 0 { - p.stats["nats_embedded"] = false - p.stats["nats_in_memory"] = false // probably - } + natsEmbedded := len(p.cfg.Global.JetStream.Addresses) == 0 + p.stats["nats_embedded"] = natsEmbedded + p.stats["nats_in_memory"] = p.cfg.Global.JetStream.InMemory && natsEmbedded + if len(p.cfg.Logging) > 0 { p.stats["log_level"] = p.cfg.Logging[0].Level } else { diff --git a/userapi/util/phonehomestats_test.go b/userapi/util/phonehomestats_test.go new file mode 100644 index 000000000..6e62210e8 --- /dev/null +++ b/userapi/util/phonehomestats_test.go @@ -0,0 +1,84 @@ +package util + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "golang.org/x/crypto/bcrypt" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + "github.com/matrix-org/dendrite/userapi/storage" +) + +func TestCollect(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + b, _, _ := testrig.Base(nil) + connStr, closeDB := test.PrepareDBConnectionString(t, dbType) + defer closeDB() + db, err := storage.NewUserAPIDatabase(b, &config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, "localhost", bcrypt.MinCost, 1000, 1000, "") + if err != nil { + t.Error(err) + } + + receivedRequest := make(chan struct{}, 1) + // create a test server which responds to our call + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var data map[string]interface{} + if err := json.NewDecoder(r.Body).Decode(&data); err != nil { + t.Error(err) + } + defer r.Body.Close() + if _, err := w.Write([]byte("{}")); err != nil { + t.Error(err) + } + + // verify the received data matches our expectations + dbEngine, ok := data["database_engine"] + if !ok { + t.Errorf("missing database_engine in JSON request: %+v", data) + } + version, ok := data["version"] + if !ok { + t.Errorf("missing version in JSON request: %+v", data) + } + if version != internal.VersionString() { + t.Errorf("unexpected version: %q, expected %q", version, internal.VersionString()) + } + switch { + case dbType == test.DBTypeSQLite && dbEngine != "SQLite": + t.Errorf("unexpected database_engine: %s", dbEngine) + case dbType == test.DBTypePostgres && dbEngine != "Postgres": + t.Errorf("unexpected database_engine: %s", dbEngine) + } + close(receivedRequest) + })) + defer srv.Close() + + b.Cfg.Global.ReportStats.Endpoint = srv.URL + stats := phoneHomeStats{ + prevData: timestampToRUUsage{}, + serverName: "localhost", + startTime: time.Now(), + cfg: b.Cfg, + db: db, + isMonolith: false, + client: &http.Client{Timeout: time.Second}, + } + + stats.collect() + + select { + case <-time.After(time.Second * 5): + t.Error("timed out waiting for response") + case <-receivedRequest: + } + }) +} From c136a450d5196cf22a91419f493bb73c29481122 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Thu, 8 Dec 2022 08:25:03 +0100 Subject: [PATCH 06/15] Fix newly joined users presence (#2854) Fixes #2803 Also refactors the presence stream to not hit the database for every user, instead queries all users at once now. --- syncapi/consumers/presence.go | 12 +- syncapi/storage/interface.go | 4 +- syncapi/storage/postgres/presence_table.go | 38 +++-- syncapi/storage/shared/storage_consumer.go | 4 +- syncapi/storage/shared/storage_sync.go | 4 +- syncapi/storage/sqlite3/presence_table.go | 48 +++++-- syncapi/storage/tables/interface.go | 2 +- syncapi/storage/tables/presence_table_test.go | 136 ++++++++++++++++++ syncapi/streams/stream_presence.go | 80 +++++++---- syncapi/sync/requestpool.go | 6 +- syncapi/sync/requestpool_test.go | 4 +- 11 files changed, 263 insertions(+), 75 deletions(-) create mode 100644 syncapi/storage/tables/presence_table_test.go diff --git a/syncapi/consumers/presence.go b/syncapi/consumers/presence.go index 145059c2d..6e3150c29 100644 --- a/syncapi/consumers/presence.go +++ b/syncapi/consumers/presence.go @@ -78,7 +78,7 @@ func (s *PresenceConsumer) Start() error { // Normal NATS subscription, used by Request/Reply _, err := s.nats.Subscribe(s.requestTopic, func(msg *nats.Msg) { userID := msg.Header.Get(jetstream.UserID) - presence, err := s.db.GetPresence(context.Background(), userID) + presences, err := s.db.GetPresences(context.Background(), []string{userID}) m := &nats.Msg{ Header: nats.Header{}, } @@ -89,10 +89,12 @@ func (s *PresenceConsumer) Start() error { } return } - if presence == nil { - presence = &types.PresenceInternal{ - UserID: userID, - } + + presence := &types.PresenceInternal{ + UserID: userID, + } + if len(presences) > 0 { + presence = presences[0] } deviceRes := api.QueryDevicesResponse{} diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index 97c2ced49..75afbce15 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -106,7 +106,7 @@ type DatabaseTransaction interface { SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error) // getUserUnreadNotificationCountsForRooms returns the unread notifications for the given rooms GetUserUnreadNotificationCountsForRooms(ctx context.Context, userID string, roomIDs map[string]string) (map[string]*eventutil.NotificationData, error) - GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) + GetPresences(ctx context.Context, userID []string) ([]*types.PresenceInternal, error) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) RelationsFor(ctx context.Context, roomID, eventID, relType, eventType string, from, to types.StreamPosition, backwards bool, limit int) (events []types.StreamEvent, prevBatch, nextBatch string, err error) } @@ -186,7 +186,7 @@ type Database interface { } type Presence interface { - GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) + GetPresences(ctx context.Context, userIDs []string) ([]*types.PresenceInternal, error) UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error) } diff --git a/syncapi/storage/postgres/presence_table.go b/syncapi/storage/postgres/presence_table.go index 7194afea6..a3f7c5213 100644 --- a/syncapi/storage/postgres/presence_table.go +++ b/syncapi/storage/postgres/presence_table.go @@ -19,10 +19,12 @@ import ( "database/sql" "time" + "github.com/lib/pq" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrixserverlib" ) const presenceSchema = ` @@ -63,9 +65,9 @@ const upsertPresenceFromSyncSQL = "" + " RETURNING id" const selectPresenceForUserSQL = "" + - "SELECT presence, status_msg, last_active_ts" + + "SELECT user_id, presence, status_msg, last_active_ts" + " FROM syncapi_presence" + - " WHERE user_id = $1 LIMIT 1" + " WHERE user_id = ANY($1)" const selectMaxPresenceSQL = "" + "SELECT COALESCE(MAX(id), 0) FROM syncapi_presence" @@ -119,20 +121,28 @@ func (p *presenceStatements) UpsertPresence( return } -// GetPresenceForUser returns the current presence of a user. -func (p *presenceStatements) GetPresenceForUser( +// GetPresenceForUsers returns the current presence for a list of users. +// If the user doesn't have a presence status yet, it is omitted from the response. +func (p *presenceStatements) GetPresenceForUsers( ctx context.Context, txn *sql.Tx, - userID string, -) (*types.PresenceInternal, error) { - result := &types.PresenceInternal{ - UserID: userID, - } + userIDs []string, +) ([]*types.PresenceInternal, error) { + result := make([]*types.PresenceInternal, 0, len(userIDs)) stmt := sqlutil.TxStmt(txn, p.selectPresenceForUsersStmt) - err := stmt.QueryRowContext(ctx, userID).Scan(&result.Presence, &result.ClientFields.StatusMsg, &result.LastActiveTS) - if err == sql.ErrNoRows { - return nil, nil + rows, err := stmt.QueryContext(ctx, pq.Array(userIDs)) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "GetPresenceForUsers: rows.close() failed") + + for rows.Next() { + presence := &types.PresenceInternal{} + if err = rows.Scan(&presence.UserID, &presence.Presence, &presence.ClientFields.StatusMsg, &presence.LastActiveTS); err != nil { + return nil, err + } + presence.ClientFields.Presence = presence.Presence.String() + result = append(result, presence) } - result.ClientFields.Presence = result.Presence.String() return result, err } diff --git a/syncapi/storage/shared/storage_consumer.go b/syncapi/storage/shared/storage_consumer.go index f2064fb89..df2338cf8 100644 --- a/syncapi/storage/shared/storage_consumer.go +++ b/syncapi/storage/shared/storage_consumer.go @@ -564,8 +564,8 @@ func (d *Database) UpdatePresence(ctx context.Context, userID string, presence t return pos, err } -func (d *Database) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) { - return d.Presence.GetPresenceForUser(ctx, nil, userID) +func (d *Database) GetPresences(ctx context.Context, userIDs []string) ([]*types.PresenceInternal, error) { + return d.Presence.GetPresenceForUsers(ctx, nil, userIDs) } func (d *Database) SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error) { diff --git a/syncapi/storage/shared/storage_sync.go b/syncapi/storage/shared/storage_sync.go index c3763521c..77afa0290 100644 --- a/syncapi/storage/shared/storage_sync.go +++ b/syncapi/storage/shared/storage_sync.go @@ -596,8 +596,8 @@ func (d *DatabaseTransaction) GetUserUnreadNotificationCountsForRooms(ctx contex return d.NotificationData.SelectUserUnreadCountsForRooms(ctx, d.txn, userID, roomIDs) } -func (d *DatabaseTransaction) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) { - return d.Presence.GetPresenceForUser(ctx, d.txn, userID) +func (d *DatabaseTransaction) GetPresences(ctx context.Context, userIDs []string) ([]*types.PresenceInternal, error) { + return d.Presence.GetPresenceForUsers(ctx, d.txn, userIDs) } func (d *DatabaseTransaction) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) { diff --git a/syncapi/storage/sqlite3/presence_table.go b/syncapi/storage/sqlite3/presence_table.go index b61a825df..7641de92f 100644 --- a/syncapi/storage/sqlite3/presence_table.go +++ b/syncapi/storage/sqlite3/presence_table.go @@ -17,12 +17,14 @@ package sqlite3 import ( "context" "database/sql" + "strings" "time" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrixserverlib" ) const presenceSchema = ` @@ -62,9 +64,9 @@ const upsertPresenceFromSyncSQL = "" + " RETURNING id" const selectPresenceForUserSQL = "" + - "SELECT presence, status_msg, last_active_ts" + + "SELECT user_id, presence, status_msg, last_active_ts" + " FROM syncapi_presence" + - " WHERE user_id = $1 LIMIT 1" + " WHERE user_id IN ($1)" const selectMaxPresenceSQL = "" + "SELECT COALESCE(MAX(id), 0) FROM syncapi_presence" @@ -134,20 +136,38 @@ func (p *presenceStatements) UpsertPresence( return } -// GetPresenceForUser returns the current presence of a user. -func (p *presenceStatements) GetPresenceForUser( +// GetPresenceForUsers returns the current presence for a list of users. +// If the user doesn't have a presence status yet, it is omitted from the response. +func (p *presenceStatements) GetPresenceForUsers( ctx context.Context, txn *sql.Tx, - userID string, -) (*types.PresenceInternal, error) { - result := &types.PresenceInternal{ - UserID: userID, + userIDs []string, +) ([]*types.PresenceInternal, error) { + qry := strings.Replace(selectPresenceForUserSQL, "($1)", sqlutil.QueryVariadic(len(userIDs)), 1) + prepStmt, err := p.db.Prepare(qry) + if err != nil { + return nil, err } - stmt := sqlutil.TxStmt(txn, p.selectPresenceForUsersStmt) - err := stmt.QueryRowContext(ctx, userID).Scan(&result.Presence, &result.ClientFields.StatusMsg, &result.LastActiveTS) - if err == sql.ErrNoRows { - return nil, nil + defer internal.CloseAndLogIfError(ctx, prepStmt, "GetPresenceForUsers: stmt.close() failed") + + params := make([]interface{}, len(userIDs)) + for i := range userIDs { + params[i] = userIDs[i] + } + + rows, err := sqlutil.TxStmt(txn, prepStmt).QueryContext(ctx, params...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "GetPresenceForUsers: rows.close() failed") + result := make([]*types.PresenceInternal, 0, len(userIDs)) + for rows.Next() { + presence := &types.PresenceInternal{} + if err = rows.Scan(&presence.UserID, &presence.Presence, &presence.ClientFields.StatusMsg, &presence.LastActiveTS); err != nil { + return nil, err + } + presence.ClientFields.Presence = presence.Presence.String() + result = append(result, presence) } - result.ClientFields.Presence = result.Presence.String() return result, err } diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index 2c4f04ec2..a0574b257 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -207,7 +207,7 @@ type Ignores interface { type Presence interface { UpsertPresence(ctx context.Context, txn *sql.Tx, userID string, statusMsg *string, presence types.Presence, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (pos types.StreamPosition, err error) - GetPresenceForUser(ctx context.Context, txn *sql.Tx, userID string) (presence *types.PresenceInternal, err error) + GetPresenceForUsers(ctx context.Context, txn *sql.Tx, userIDs []string) (presence []*types.PresenceInternal, err error) GetMaxPresenceID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) GetPresenceAfter(ctx context.Context, txn *sql.Tx, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (presences map[string]*types.PresenceInternal, err error) } diff --git a/syncapi/storage/tables/presence_table_test.go b/syncapi/storage/tables/presence_table_test.go new file mode 100644 index 000000000..dce0c695a --- /dev/null +++ b/syncapi/storage/tables/presence_table_test.go @@ -0,0 +1,136 @@ +package tables_test + +import ( + "context" + "database/sql" + "reflect" + "testing" + "time" + + "github.com/matrix-org/gomatrixserverlib" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/syncapi/storage/postgres" + "github.com/matrix-org/dendrite/syncapi/storage/sqlite3" + "github.com/matrix-org/dendrite/syncapi/storage/tables" + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/dendrite/test" +) + +func mustPresenceTable(t *testing.T, dbType test.DBType) (tables.Presence, func()) { + t.Helper() + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, sqlutil.NewExclusiveWriter()) + if err != nil { + t.Fatalf("failed to open db: %s", err) + } + + var tab tables.Presence + switch dbType { + case test.DBTypePostgres: + tab, err = postgres.NewPostgresPresenceTable(db) + case test.DBTypeSQLite: + var stream sqlite3.StreamIDStatements + if err = stream.Prepare(db); err != nil { + t.Fatalf("failed to prepare stream stmts: %s", err) + } + tab, err = sqlite3.NewSqlitePresenceTable(db, &stream) + } + if err != nil { + t.Fatalf("failed to make new table: %s", err) + } + return tab, close +} + +func TestPresence(t *testing.T) { + alice := test.NewUser(t) + bob := test.NewUser(t) + ctx := context.Background() + + statusMsg := "Hello World!" + timestamp := gomatrixserverlib.AsTimestamp(time.Now()) + + var txn *sql.Tx + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + tab, closeDB := mustPresenceTable(t, dbType) + defer closeDB() + + // Insert some presences + pos, err := tab.UpsertPresence(ctx, txn, alice.ID, &statusMsg, types.PresenceOnline, timestamp, false) + if err != nil { + t.Error(err) + } + wantPos := types.StreamPosition(1) + if pos != wantPos { + t.Errorf("expected pos to be %d, got %d", wantPos, pos) + } + pos, err = tab.UpsertPresence(ctx, txn, bob.ID, &statusMsg, types.PresenceOnline, timestamp, false) + if err != nil { + t.Error(err) + } + wantPos = 2 + if pos != wantPos { + t.Errorf("expected pos to be %d, got %d", wantPos, pos) + } + + // verify the expected max presence ID + maxPos, err := tab.GetMaxPresenceID(ctx, txn) + if err != nil { + t.Error(err) + } + if maxPos != wantPos { + t.Errorf("expected max pos to be %d, got %d", wantPos, maxPos) + } + + // This should increment the position + pos, err = tab.UpsertPresence(ctx, txn, bob.ID, &statusMsg, types.PresenceOnline, timestamp, true) + if err != nil { + t.Error(err) + } + wantPos = pos + if wantPos <= maxPos { + t.Errorf("expected pos to be %d incremented, got %d", wantPos, pos) + } + + // This should return only Bobs status + presences, err := tab.GetPresenceAfter(ctx, txn, maxPos, gomatrixserverlib.EventFilter{Limit: 10}) + if err != nil { + t.Error(err) + } + + if c := len(presences); c > 1 { + t.Errorf("expected only one presence, got %d", c) + } + + // Validate the response + wantPresence := &types.PresenceInternal{ + UserID: bob.ID, + Presence: types.PresenceOnline, + StreamPos: wantPos, + LastActiveTS: timestamp, + ClientFields: types.PresenceClientResponse{ + LastActiveAgo: 0, + Presence: types.PresenceOnline.String(), + StatusMsg: &statusMsg, + }, + } + if !reflect.DeepEqual(wantPresence, presences[bob.ID]) { + t.Errorf("unexpected presence result:\n%+v, want\n%+v", presences[bob.ID], wantPresence) + } + + // Try getting presences for existing and non-existing users + getUsers := []string{alice.ID, bob.ID, "@doesntexist:test"} + presencesForUsers, err := tab.GetPresenceForUsers(ctx, nil, getUsers) + if err != nil { + t.Error(err) + } + + if len(presencesForUsers) >= len(getUsers) { + t.Errorf("expected less presences, but they are the same/more as requested: %d >= %d", len(presencesForUsers), len(getUsers)) + } + }) + +} diff --git a/syncapi/streams/stream_presence.go b/syncapi/streams/stream_presence.go index 030b7c5d5..445e46b3a 100644 --- a/syncapi/streams/stream_presence.go +++ b/syncapi/streams/stream_presence.go @@ -17,6 +17,7 @@ package streams import ( "context" "encoding/json" + "fmt" "sync" "github.com/matrix-org/gomatrixserverlib" @@ -70,39 +71,25 @@ func (p *PresenceStreamProvider) IncrementalSync( return from } - if len(presences) == 0 { + getPresenceForUsers, err := p.getNeededUsersFromRequest(ctx, req, presences) + if err != nil { + req.Log.WithError(err).Error("getNeededUsersFromRequest failed") + return from + } + + // Got no presence between range and no presence to get from the database + if len(getPresenceForUsers) == 0 && len(presences) == 0 { return to } - // add newly joined rooms user presences - newlyJoined := joinedRooms(req.Response, req.Device.UserID) - if len(newlyJoined) > 0 { - // TODO: Check if this is working better than before. - if err = p.notifier.LoadRooms(ctx, p.DB, newlyJoined); err != nil { - req.Log.WithError(err).Error("unable to refresh notifier lists") - return from - } - NewlyJoinedLoop: - for _, roomID := range newlyJoined { - roomUsers := p.notifier.JoinedUsers(roomID) - for i := range roomUsers { - // we already got a presence from this user - if _, ok := presences[roomUsers[i]]; ok { - continue - } - // Bear in mind that this might return nil, but at least populating - // a nil means that there's a map entry so we won't repeat this call. - presences[roomUsers[i]], err = snapshot.GetPresence(ctx, roomUsers[i]) - if err != nil { - req.Log.WithError(err).Error("unable to query presence for user") - _ = snapshot.Rollback() - return from - } - if len(presences) > req.Filter.Presence.Limit { - break NewlyJoinedLoop - } - } - } + dbPresences, err := snapshot.GetPresences(ctx, getPresenceForUsers) + if err != nil { + req.Log.WithError(err).Error("unable to query presence for user") + _ = snapshot.Rollback() + return from + } + for _, presence := range dbPresences { + presences[presence.UserID] = presence } lastPos := from @@ -164,6 +151,39 @@ func (p *PresenceStreamProvider) IncrementalSync( return lastPos } +func (p *PresenceStreamProvider) getNeededUsersFromRequest(ctx context.Context, req *types.SyncRequest, presences map[string]*types.PresenceInternal) ([]string, error) { + getPresenceForUsers := []string{} + // Add presence for users which newly joined a room + for userID := range req.MembershipChanges { + if _, ok := presences[userID]; ok { + continue + } + getPresenceForUsers = append(getPresenceForUsers, userID) + } + + // add newly joined rooms user presences + newlyJoined := joinedRooms(req.Response, req.Device.UserID) + if len(newlyJoined) == 0 { + return getPresenceForUsers, nil + } + + // TODO: Check if this is working better than before. + if err := p.notifier.LoadRooms(ctx, p.DB, newlyJoined); err != nil { + return getPresenceForUsers, fmt.Errorf("unable to refresh notifier lists: %w", err) + } + for _, roomID := range newlyJoined { + roomUsers := p.notifier.JoinedUsers(roomID) + for i := range roomUsers { + // we already got a presence from this user + if _, ok := presences[roomUsers[i]]; ok { + continue + } + getPresenceForUsers = append(getPresenceForUsers, roomUsers[i]) + } + } + return getPresenceForUsers, nil +} + func joinedRooms(res *types.Response, userID string) []string { var roomIDs []string for roomID, join := range res.Rooms.Join { diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index 29d92b293..b086567b8 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -145,12 +145,12 @@ func (rp *RequestPool) updatePresence(db storage.Presence, presence string, user } // ensure we also send the current status_msg to federated servers and not nil - dbPresence, err := db.GetPresence(context.Background(), userID) + dbPresence, err := db.GetPresences(context.Background(), []string{userID}) if err != nil && err != sql.ErrNoRows { return } - if dbPresence != nil { - newPresence.ClientFields = dbPresence.ClientFields + if len(dbPresence) > 0 && dbPresence[0] != nil { + newPresence.ClientFields = dbPresence[0].ClientFields } newPresence.ClientFields.Presence = presenceID.String() diff --git a/syncapi/sync/requestpool_test.go b/syncapi/sync/requestpool_test.go index 3e5769d8c..faa0b49c6 100644 --- a/syncapi/sync/requestpool_test.go +++ b/syncapi/sync/requestpool_test.go @@ -29,8 +29,8 @@ func (d dummyDB) UpdatePresence(ctx context.Context, userID string, presence typ return 0, nil } -func (d dummyDB) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) { - return &types.PresenceInternal{}, nil +func (d dummyDB) GetPresences(ctx context.Context, userID []string) ([]*types.PresenceInternal, error) { + return []*types.PresenceInternal{}, nil } func (d dummyDB) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) { From 8846de7312d447cd2866236493b6ae172fbebfa9 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 8 Dec 2022 10:19:55 +0000 Subject: [PATCH 07/15] Bump nokogiri from 1.13.9 to 1.13.10 in /docs (#2909) Bumps [nokogiri](https://github.com/sparklemotion/nokogiri) from 1.13.9 to 1.13.10.
Release notes

Sourced from nokogiri's releases.

1.13.10 / 2022-12-07

Security

  • [CRuby] Address CVE-2022-23476, unchecked return value from xmlTextReaderExpand. See GHSA-qv4q-mr5r-qprj for more information.

Improvements

  • [CRuby] XML::Reader#attribute_hash now returns nil on parse errors. This restores the behavior of #attributes from v1.13.7 and earlier. [#2715]

sha256 checksums:

777ce2e80f64772e91459b943e531dfef387e768f2255f9bc7a1655f254bbaa1
nokogiri-1.13.10-aarch64-linux.gem
b432ff47c51386e07f7e275374fe031c1349e37eaef2216759063bc5fa5624aa
nokogiri-1.13.10-arm64-darwin.gem
73ac581ddcb680a912e92da928ffdbac7b36afd3368418f2cee861b96e8c830b
nokogiri-1.13.10-java.gem
916aa17e624611dddbf2976ecce1b4a80633c6378f8465cff0efab022ebc2900
nokogiri-1.13.10-x64-mingw-ucrt.gem
0f85a1ad8c2b02c166a6637237133505b71a05f1bb41b91447005449769bced0
nokogiri-1.13.10-x64-mingw32.gem
91fa3a8724a1ce20fccbd718dafd9acbde099258183ac486992a61b00bb17020
nokogiri-1.13.10-x86-linux.gem
d6663f5900ccd8f72d43660d7f082565b7ffcaade0b9a59a74b3ef8791034168
nokogiri-1.13.10-x86-mingw32.gem
81755fc4b8130ef9678c76a2e5af3db7a0a6664b3cba7d9fe8ef75e7d979e91b
nokogiri-1.13.10-x86_64-darwin.gem
51d5246705dedad0a09b374d09cc193e7383a5dd32136a690a3cd56e95adf0a3
nokogiri-1.13.10-x86_64-linux.gem
d3ee00f26c151763da1691c7fc6871ddd03e532f74f85101f5acedc2d099e958
nokogiri-1.13.10.gem
Changelog

Sourced from nokogiri's changelog.

1.13.10 / 2022-12-07

Security

  • [CRuby] Address CVE-2022-23476, unchecked return value from xmlTextReaderExpand. See GHSA-qv4q-mr5r-qprj for more information.

Improvements

  • [CRuby] XML::Reader#attribute_hash now returns nil on parse errors. This restores the behavior of #attributes from v1.13.7 and earlier. [#2715]
Commits
  • 4c80121 version bump to v1.13.10
  • 85410e3 Merge pull request #2715 from sparklemotion/flavorjones-fix-reader-error-hand...
  • 9fe0761 fix(cruby): XML::Reader#attribute_hash returns nil on error
  • 3b9c736 Merge pull request #2717 from sparklemotion/flavorjones-lock-psych-to-fix-bui...
  • 2efa87b test: skip large cdata test on system libxml2
  • 3187d67 dep(dev): pin psych to v4 until v5 builds in CI
  • a16b4bf style(rubocop): disable Minitest/EmptyLineBeforeAssertionMethods
  • See full diff in compare view

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=nokogiri&package-manager=bundler&previous-version=1.13.9&new-version=1.13.10)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself) - `@dependabot use these labels` will set the current labels as the default for future PRs for this repo and language - `@dependabot use these reviewers` will set the current reviewers as the default for future PRs for this repo and language - `@dependabot use these assignees` will set the current assignees as the default for future PRs for this repo and language - `@dependabot use this milestone` will set the current milestone as the default for future PRs for this repo and language You can disable automated security fix PRs for this repo from the [Security Alerts page](https://github.com/matrix-org/dendrite/network/alerts).
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- docs/Gemfile.lock | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/Gemfile.lock b/docs/Gemfile.lock index c7ba43711..509a8cbcf 100644 --- a/docs/Gemfile.lock +++ b/docs/Gemfile.lock @@ -231,9 +231,9 @@ GEM jekyll-seo-tag (~> 2.1) minitest (5.15.0) multipart-post (2.1.1) - nokogiri (1.13.9-arm64-darwin) + nokogiri (1.13.10-arm64-darwin) racc (~> 1.4) - nokogiri (1.13.9-x86_64-linux) + nokogiri (1.13.10-x86_64-linux) racc (~> 1.4) octokit (4.22.0) faraday (>= 0.9) @@ -241,7 +241,7 @@ GEM pathutil (0.16.2) forwardable-extended (~> 2.6) public_suffix (4.0.7) - racc (1.6.0) + racc (1.6.1) rb-fsevent (0.11.1) rb-inotify (0.10.1) ffi (~> 1.0) From aaf4e5c8654463cc5431d57db9163ad9ed558f53 Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Fri, 9 Dec 2022 18:45:42 +0100 Subject: [PATCH 08/15] Use older sytest-dendrite image --- .github/workflows/dendrite.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/dendrite.yml b/.github/workflows/dendrite.yml index 593012ef3..2c04005d2 100644 --- a/.github/workflows/dendrite.yml +++ b/.github/workflows/dendrite.yml @@ -331,7 +331,8 @@ jobs: postgres: postgres api: full-http container: - image: matrixdotorg/sytest-dendrite:latest + # Temporary for debugging to see if this image is working better. + image: matrixdotorg/sytest-dendrite@sha256:434ad464a9f4ed3f8c3cc47200275b6ccb5c5031a8063daf4acea62be5a23c73 volumes: - ${{ github.workspace }}:/src - /root/.cache/go-build:/github/home/.cache/go-build From 7d2344049d0780e92b071939addc41219ce94f5a Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Mon, 12 Dec 2022 08:20:59 +0100 Subject: [PATCH 09/15] Cleanup stale device lists for users we don't share a room with anymore (#2857) The stale device lists table might contain entries for users we don't share a room with anymore. This now asks the roomserver about left users and removes those entries from the table. Co-authored-by: Neil Alexander --- build/dendritejs-pinecone/main.go | 6 +- build/gobind-pinecone/monolith.go | 2 +- build/gobind-yggdrasil/monolith.go | 2 +- cmd/dendrite-demo-pinecone/main.go | 2 +- cmd/dendrite-demo-yggdrasil/main.go | 5 +- cmd/dendrite-monolith-server/main.go | 2 +- .../personalities/keyserver.go | 3 +- keyserver/internal/device_list_update.go | 29 +++++- keyserver/internal/device_list_update_test.go | 86 ++++++++++++++++- keyserver/keyserver.go | 12 ++- keyserver/keyserver_test.go | 29 ++++++ keyserver/storage/interface.go | 5 + .../storage/postgres/stale_device_lists.go | 33 +++++-- keyserver/storage/shared/storage.go | 10 ++ .../storage/sqlite3/stale_device_lists.go | 44 +++++++-- keyserver/storage/tables/interface.go | 1 + .../storage/tables/stale_device_lists_test.go | 94 ++++++++++++++++++ roomserver/api/api.go | 5 + roomserver/api/api_trace.go | 6 ++ roomserver/api/query.go | 12 +++ roomserver/internal/query/query.go | 6 ++ roomserver/inthttp/client.go | 8 ++ roomserver/inthttp/server.go | 5 + roomserver/roomserver_test.go | 77 ++++++++++++++- roomserver/storage/interface.go | 1 + .../storage/postgres/membership_table.go | 34 ++++++- roomserver/storage/shared/storage.go | 37 +++++++ roomserver/storage/shared/storage_test.go | 96 +++++++++++++++++++ .../storage/sqlite3/membership_table.go | 47 ++++++++- roomserver/storage/tables/interface.go | 1 + .../storage/tables/membership_table_test.go | 6 ++ 31 files changed, 666 insertions(+), 40 deletions(-) create mode 100644 keyserver/keyserver_test.go create mode 100644 keyserver/storage/tables/stale_device_lists_test.go create mode 100644 roomserver/storage/shared/storage_test.go diff --git a/build/dendritejs-pinecone/main.go b/build/dendritejs-pinecone/main.go index e070173aa..f44a77488 100644 --- a/build/dendritejs-pinecone/main.go +++ b/build/dendritejs-pinecone/main.go @@ -180,14 +180,14 @@ func startup() { base := base.NewBaseDendrite(cfg, "Monolith") defer base.Close() // nolint: errcheck + rsAPI := roomserver.NewInternalAPI(base) + federation := conn.CreateFederationClient(base, pSessions) - keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation) + keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation, rsAPI) serverKeyAPI := &signing.YggdrasilKeys{} keyRing := serverKeyAPI.KeyRing() - rsAPI := roomserver.NewInternalAPI(base) - userAPI := userapi.NewInternalAPI(base, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient()) keyAPI.SetUserAPI(userAPI) diff --git a/build/gobind-pinecone/monolith.go b/build/gobind-pinecone/monolith.go index e8ed8fe85..b8f8111d2 100644 --- a/build/gobind-pinecone/monolith.go +++ b/build/gobind-pinecone/monolith.go @@ -350,7 +350,7 @@ func (m *DendriteMonolith) Start() { base, federation, rsAPI, base.Caches, keyRing, true, ) - keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI) + keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI, rsAPI) m.userAPI = userapi.NewInternalAPI(base, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, base.PushGatewayHTTPClient()) keyAPI.SetUserAPI(m.userAPI) diff --git a/build/gobind-yggdrasil/monolith.go b/build/gobind-yggdrasil/monolith.go index 9a3ac5d7b..8c2d0a006 100644 --- a/build/gobind-yggdrasil/monolith.go +++ b/build/gobind-yggdrasil/monolith.go @@ -165,7 +165,7 @@ func (m *DendriteMonolith) Start() { base, federation, rsAPI, base.Caches, keyRing, true, ) - keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation) + keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation, rsAPI) userAPI := userapi.NewInternalAPI(base, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, base.PushGatewayHTTPClient()) keyAPI.SetUserAPI(userAPI) diff --git a/cmd/dendrite-demo-pinecone/main.go b/cmd/dendrite-demo-pinecone/main.go index 2f647a41b..3f627b41d 100644 --- a/cmd/dendrite-demo-pinecone/main.go +++ b/cmd/dendrite-demo-pinecone/main.go @@ -213,7 +213,7 @@ func main() { base, federation, rsAPI, base.Caches, keyRing, true, ) - keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI) + keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI, rsComponent) userAPI := userapi.NewInternalAPI(base, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient()) keyAPI.SetUserAPI(userAPI) diff --git a/cmd/dendrite-demo-yggdrasil/main.go b/cmd/dendrite-demo-yggdrasil/main.go index 5dd61b1b7..3ea4a08b0 100644 --- a/cmd/dendrite-demo-yggdrasil/main.go +++ b/cmd/dendrite-demo-yggdrasil/main.go @@ -157,11 +157,12 @@ func main() { serverKeyAPI := &signing.YggdrasilKeys{} keyRing := serverKeyAPI.KeyRing() - keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation) - rsComponent := roomserver.NewInternalAPI( base, ) + + keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation, rsComponent) + rsAPI := rsComponent userAPI := userapi.NewInternalAPI(base, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient()) diff --git a/cmd/dendrite-monolith-server/main.go b/cmd/dendrite-monolith-server/main.go index 2d2f32b00..6836b6426 100644 --- a/cmd/dendrite-monolith-server/main.go +++ b/cmd/dendrite-monolith-server/main.go @@ -95,7 +95,7 @@ func main() { } keyRing := fsAPI.KeyRing() - keyImpl := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI) + keyImpl := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI, rsAPI) keyAPI := keyImpl if base.UseHTTPAPIs { keyserver.AddInternalRoutes(base.InternalAPIMux, keyAPI, base.EnableMetrics) diff --git a/cmd/dendrite-polylith-multi/personalities/keyserver.go b/cmd/dendrite-polylith-multi/personalities/keyserver.go index d2924b892..ad0bd0e54 100644 --- a/cmd/dendrite-polylith-multi/personalities/keyserver.go +++ b/cmd/dendrite-polylith-multi/personalities/keyserver.go @@ -22,7 +22,8 @@ import ( func KeyServer(base *basepkg.BaseDendrite, cfg *config.Dendrite) { fsAPI := base.FederationAPIHTTPClient() - intAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI) + rsAPI := base.RoomserverHTTPClient() + intAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI, rsAPI) intAPI.SetUserAPI(base.UserAPIClient()) keyserver.AddInternalRoutes(base.InternalAPIMux, intAPI, base.EnableMetrics) diff --git a/keyserver/internal/device_list_update.go b/keyserver/internal/device_list_update.go index 8ff9dfc31..c7bf8da53 100644 --- a/keyserver/internal/device_list_update.go +++ b/keyserver/internal/device_list_update.go @@ -24,6 +24,8 @@ import ( "sync" "time" + rsapi "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -102,6 +104,7 @@ type DeviceListUpdater struct { // block on or timeout via a select. userIDToChan map[string]chan bool userIDToChanMu *sync.Mutex + rsAPI rsapi.KeyserverRoomserverAPI } // DeviceListUpdaterDatabase is the subset of functionality from storage.Database required for the updater. @@ -124,6 +127,8 @@ type DeviceListUpdaterDatabase interface { // DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced. DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error + + DeleteStaleDeviceLists(ctx context.Context, userIDs []string) error } type DeviceListUpdaterAPI interface { @@ -140,7 +145,7 @@ func NewDeviceListUpdater( process *process.ProcessContext, db DeviceListUpdaterDatabase, api DeviceListUpdaterAPI, producer KeyChangeProducer, fedClient fedsenderapi.KeyserverFederationAPI, numWorkers int, - thisServer gomatrixserverlib.ServerName, + rsAPI rsapi.KeyserverRoomserverAPI, thisServer gomatrixserverlib.ServerName, ) *DeviceListUpdater { return &DeviceListUpdater{ process: process, @@ -154,6 +159,7 @@ func NewDeviceListUpdater( workerChans: make([]chan gomatrixserverlib.ServerName, numWorkers), userIDToChan: make(map[string]chan bool), userIDToChanMu: &sync.Mutex{}, + rsAPI: rsAPI, } } @@ -168,7 +174,7 @@ func (u *DeviceListUpdater) Start() error { go u.worker(ch) } - staleLists, err := u.db.StaleDeviceLists(context.Background(), []gomatrixserverlib.ServerName{}) + staleLists, err := u.db.StaleDeviceLists(u.process.Context(), []gomatrixserverlib.ServerName{}) if err != nil { return err } @@ -186,6 +192,25 @@ func (u *DeviceListUpdater) Start() error { return nil } +// CleanUp removes stale device entries for users we don't share a room with anymore +func (u *DeviceListUpdater) CleanUp() error { + staleUsers, err := u.db.StaleDeviceLists(u.process.Context(), []gomatrixserverlib.ServerName{}) + if err != nil { + return err + } + + res := rsapi.QueryLeftUsersResponse{} + if err = u.rsAPI.QueryLeftUsers(u.process.Context(), &rsapi.QueryLeftUsersRequest{StaleDeviceListUsers: staleUsers}, &res); err != nil { + return err + } + + if len(res.LeftUsers) == 0 { + return nil + } + logrus.Debugf("Deleting %d stale device list entries", len(res.LeftUsers)) + return u.db.DeleteStaleDeviceLists(u.process.Context(), res.LeftUsers) +} + func (u *DeviceListUpdater) mutex(userID string) *sync.Mutex { u.mu.Lock() defer u.mu.Unlock() diff --git a/keyserver/internal/device_list_update_test.go b/keyserver/internal/device_list_update_test.go index a374c9516..60a2c2f30 100644 --- a/keyserver/internal/device_list_update_test.go +++ b/keyserver/internal/device_list_update_test.go @@ -30,7 +30,12 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/keyserver/api" + "github.com/matrix-org/dendrite/keyserver/storage" + roomserver "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" ) var ( @@ -53,6 +58,10 @@ type mockDeviceListUpdaterDatabase struct { mu sync.Mutex // protect staleUsers } +func (d *mockDeviceListUpdaterDatabase) DeleteStaleDeviceLists(ctx context.Context, userIDs []string) error { + return nil +} + // StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. // If no domains are given, all user IDs with stale device lists are returned. func (d *mockDeviceListUpdaterDatabase) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { @@ -153,7 +162,7 @@ func TestUpdateHavePrevID(t *testing.T) { } ap := &mockDeviceListUpdaterAPI{} producer := &mockKeyChangeProducer{} - updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, nil, 1, "localhost") + updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, nil, 1, nil, "localhost") event := gomatrixserverlib.DeviceListUpdateEvent{ DeviceDisplayName: "Foo Bar", Deleted: false, @@ -225,7 +234,7 @@ func TestUpdateNoPrevID(t *testing.T) { `)), }, nil }) - updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 2, "example.test") + updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 2, nil, "example.test") if err := updater.Start(); err != nil { t.Fatalf("failed to start updater: %s", err) } @@ -239,6 +248,7 @@ func TestUpdateNoPrevID(t *testing.T) { UserID: remoteUserID, } err := updater.Update(ctx, event) + if err != nil { t.Fatalf("Update returned an error: %s", err) } @@ -294,7 +304,7 @@ func TestDebounce(t *testing.T) { close(incomingFedReq) return <-fedCh, nil }) - updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 1, "localhost") + updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 1, nil, "localhost") if err := updater.Start(); err != nil { t.Fatalf("failed to start updater: %s", err) } @@ -349,3 +359,73 @@ func TestDebounce(t *testing.T) { t.Errorf("user %s is marked as stale", userID) } } + +func mustCreateKeyserverDB(t *testing.T, dbType test.DBType) (storage.Database, func()) { + t.Helper() + + base, _, _ := testrig.Base(nil) + connStr, clearDB := test.PrepareDBConnectionString(t, dbType) + db, err := storage.NewDatabase(base, &config.DatabaseOptions{ConnectionString: config.DataSource(connStr)}) + if err != nil { + t.Fatal(err) + } + + return db, clearDB +} + +type mockKeyserverRoomserverAPI struct { + leftUsers []string +} + +func (m *mockKeyserverRoomserverAPI) QueryLeftUsers(ctx context.Context, req *roomserver.QueryLeftUsersRequest, res *roomserver.QueryLeftUsersResponse) error { + res.LeftUsers = m.leftUsers + return nil +} + +func TestDeviceListUpdater_CleanUp(t *testing.T) { + processCtx := process.NewProcessContext() + + alice := test.NewUser(t) + bob := test.NewUser(t) + + // Bob is not joined to any of our rooms + rsAPI := &mockKeyserverRoomserverAPI{leftUsers: []string{bob.ID}} + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, clearDB := mustCreateKeyserverDB(t, dbType) + defer clearDB() + + // This should not get deleted + if err := db.MarkDeviceListStale(processCtx.Context(), alice.ID, true); err != nil { + t.Error(err) + } + + // this one should get deleted + if err := db.MarkDeviceListStale(processCtx.Context(), bob.ID, true); err != nil { + t.Error(err) + } + + updater := NewDeviceListUpdater(processCtx, db, nil, + nil, nil, + 0, rsAPI, "test") + if err := updater.CleanUp(); err != nil { + t.Error(err) + } + + // check that we still have Alice in our stale list + staleUsers, err := db.StaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"}) + if err != nil { + t.Error(err) + } + + // There should only be Alice + wantCount := 1 + if count := len(staleUsers); count != wantCount { + t.Fatalf("expected there to be %d stale device lists, got %d", wantCount, count) + } + + if staleUsers[0] != alice.ID { + t.Fatalf("unexpected stale device list user: %s, want %s", staleUsers[0], alice.ID) + } + }) +} diff --git a/keyserver/keyserver.go b/keyserver/keyserver.go index 5360c06fd..275576773 100644 --- a/keyserver/keyserver.go +++ b/keyserver/keyserver.go @@ -18,6 +18,8 @@ import ( "github.com/gorilla/mux" "github.com/sirupsen/logrus" + rsapi "github.com/matrix-org/dendrite/roomserver/api" + fedsenderapi "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/consumers" @@ -40,6 +42,7 @@ func AddInternalRoutes(router *mux.Router, intAPI api.KeyInternalAPI, enableMetr // can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. func NewInternalAPI( base *base.BaseDendrite, cfg *config.KeyServer, fedClient fedsenderapi.KeyserverFederationAPI, + rsAPI rsapi.KeyserverRoomserverAPI, ) api.KeyInternalAPI { js, _ := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) @@ -47,6 +50,7 @@ func NewInternalAPI( if err != nil { logrus.WithError(err).Panicf("failed to connect to key server database") } + keyChangeProducer := &producers.KeyChange{ Topic: string(cfg.Matrix.JetStream.Prefixed(jetstream.OutputKeyChangeEvent)), JetStream: js, @@ -58,8 +62,14 @@ func NewInternalAPI( FedClient: fedClient, Producer: keyChangeProducer, } - updater := internal.NewDeviceListUpdater(base.ProcessContext, db, ap, keyChangeProducer, fedClient, 8, cfg.Matrix.ServerName) // 8 workers TODO: configurable + updater := internal.NewDeviceListUpdater(base.ProcessContext, db, ap, keyChangeProducer, fedClient, 8, rsAPI, cfg.Matrix.ServerName) // 8 workers TODO: configurable ap.Updater = updater + + // Remove users which we don't share a room with anymore + if err := updater.CleanUp(); err != nil { + logrus.WithError(err).Error("failed to cleanup stale device lists") + } + go func() { if err := updater.Start(); err != nil { logrus.WithError(err).Panicf("failed to start device list updater") diff --git a/keyserver/keyserver_test.go b/keyserver/keyserver_test.go new file mode 100644 index 000000000..159b280f5 --- /dev/null +++ b/keyserver/keyserver_test.go @@ -0,0 +1,29 @@ +package keyserver + +import ( + "context" + "testing" + + roomserver "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" +) + +type mockKeyserverRoomserverAPI struct { + leftUsers []string +} + +func (m *mockKeyserverRoomserverAPI) QueryLeftUsers(ctx context.Context, req *roomserver.QueryLeftUsersRequest, res *roomserver.QueryLeftUsersResponse) error { + res.LeftUsers = m.leftUsers + return nil +} + +// Merely tests that we can create an internal keyserver API +func Test_NewInternalAPI(t *testing.T) { + rsAPI := &mockKeyserverRoomserverAPI{} + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, closeBase := testrig.CreateBaseDendrite(t, dbType) + defer closeBase() + _ = NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI) + }) +} diff --git a/keyserver/storage/interface.go b/keyserver/storage/interface.go index 242e16a06..c6a8f44cd 100644 --- a/keyserver/storage/interface.go +++ b/keyserver/storage/interface.go @@ -85,4 +85,9 @@ type Database interface { StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error StoreCrossSigningSigsForTarget(ctx context.Context, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error + + DeleteStaleDeviceLists( + ctx context.Context, + userIDs []string, + ) error } diff --git a/keyserver/storage/postgres/stale_device_lists.go b/keyserver/storage/postgres/stale_device_lists.go index d0fe50d00..248ddfb45 100644 --- a/keyserver/storage/postgres/stale_device_lists.go +++ b/keyserver/storage/postgres/stale_device_lists.go @@ -19,6 +19,10 @@ import ( "database/sql" "time" + "github.com/lib/pq" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/keyserver/storage/tables" "github.com/matrix-org/gomatrixserverlib" @@ -48,10 +52,14 @@ const selectStaleDeviceListsWithDomainsSQL = "" + const selectStaleDeviceListsSQL = "" + "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 ORDER BY ts_added_secs DESC" +const deleteStaleDevicesSQL = "" + + "DELETE FROM keyserver_stale_device_lists WHERE user_id = ANY($1)" + type staleDeviceListsStatements struct { upsertStaleDeviceListStmt *sql.Stmt selectStaleDeviceListsWithDomainsStmt *sql.Stmt selectStaleDeviceListsStmt *sql.Stmt + deleteStaleDeviceListsStmt *sql.Stmt } func NewPostgresStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) { @@ -60,16 +68,12 @@ func NewPostgresStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, erro if err != nil { return nil, err } - if s.upsertStaleDeviceListStmt, err = db.Prepare(upsertStaleDeviceListSQL); err != nil { - return nil, err - } - if s.selectStaleDeviceListsStmt, err = db.Prepare(selectStaleDeviceListsSQL); err != nil { - return nil, err - } - if s.selectStaleDeviceListsWithDomainsStmt, err = db.Prepare(selectStaleDeviceListsWithDomainsSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.upsertStaleDeviceListStmt, upsertStaleDeviceListSQL}, + {&s.selectStaleDeviceListsStmt, selectStaleDeviceListsSQL}, + {&s.selectStaleDeviceListsWithDomainsStmt, selectStaleDeviceListsWithDomainsSQL}, + {&s.deleteStaleDeviceListsStmt, deleteStaleDevicesSQL}, + }.Prepare(db) } func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error { @@ -105,6 +109,15 @@ func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx conte return result, nil } +// DeleteStaleDeviceLists removes users from stale device lists +func (s *staleDeviceListsStatements) DeleteStaleDeviceLists( + ctx context.Context, txn *sql.Tx, userIDs []string, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteStaleDeviceListsStmt) + _, err := stmt.ExecContext(ctx, pq.Array(userIDs)) + return err +} + func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) { defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed") for rows.Next() { diff --git a/keyserver/storage/shared/storage.go b/keyserver/storage/shared/storage.go index 5beeed0f1..54dd6ddc9 100644 --- a/keyserver/storage/shared/storage.go +++ b/keyserver/storage/shared/storage.go @@ -249,3 +249,13 @@ func (d *Database) StoreCrossSigningSigsForTarget( return nil }) } + +// DeleteStaleDeviceLists deletes stale device list entries for users we don't share a room with anymore. +func (d *Database) DeleteStaleDeviceLists( + ctx context.Context, + userIDs []string, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.StaleDeviceListsTable.DeleteStaleDeviceLists(ctx, txn, userIDs) + }) +} diff --git a/keyserver/storage/sqlite3/stale_device_lists.go b/keyserver/storage/sqlite3/stale_device_lists.go index 1e08b266c..fd76a6e3b 100644 --- a/keyserver/storage/sqlite3/stale_device_lists.go +++ b/keyserver/storage/sqlite3/stale_device_lists.go @@ -17,8 +17,11 @@ package sqlite3 import ( "context" "database/sql" + "strings" "time" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/keyserver/storage/tables" "github.com/matrix-org/gomatrixserverlib" @@ -48,11 +51,15 @@ const selectStaleDeviceListsWithDomainsSQL = "" + const selectStaleDeviceListsSQL = "" + "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 ORDER BY ts_added_secs DESC" +const deleteStaleDevicesSQL = "" + + "DELETE FROM keyserver_stale_device_lists WHERE user_id IN ($1)" + type staleDeviceListsStatements struct { db *sql.DB upsertStaleDeviceListStmt *sql.Stmt selectStaleDeviceListsWithDomainsStmt *sql.Stmt selectStaleDeviceListsStmt *sql.Stmt + // deleteStaleDeviceListsStmt *sql.Stmt // Prepared at runtime } func NewSqliteStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) { @@ -63,16 +70,12 @@ func NewSqliteStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) if err != nil { return nil, err } - if s.upsertStaleDeviceListStmt, err = db.Prepare(upsertStaleDeviceListSQL); err != nil { - return nil, err - } - if s.selectStaleDeviceListsStmt, err = db.Prepare(selectStaleDeviceListsSQL); err != nil { - return nil, err - } - if s.selectStaleDeviceListsWithDomainsStmt, err = db.Prepare(selectStaleDeviceListsWithDomainsSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.upsertStaleDeviceListStmt, upsertStaleDeviceListSQL}, + {&s.selectStaleDeviceListsStmt, selectStaleDeviceListsSQL}, + {&s.selectStaleDeviceListsWithDomainsStmt, selectStaleDeviceListsWithDomainsSQL}, + // { &s.deleteStaleDeviceListsStmt, deleteStaleDevicesSQL}, // Prepared at runtime + }.Prepare(db) } func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error { @@ -108,6 +111,27 @@ func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx conte return result, nil } +// DeleteStaleDeviceLists removes users from stale device lists +func (s *staleDeviceListsStatements) DeleteStaleDeviceLists( + ctx context.Context, txn *sql.Tx, userIDs []string, +) error { + qry := strings.Replace(deleteStaleDevicesSQL, "($1)", sqlutil.QueryVariadic(len(userIDs)), 1) + stmt, err := s.db.Prepare(qry) + if err != nil { + return err + } + defer internal.CloseAndLogIfError(ctx, stmt, "DeleteStaleDeviceLists: stmt.Close failed") + stmt = sqlutil.TxStmt(txn, stmt) + + params := make([]any, len(userIDs)) + for i := range userIDs { + params[i] = userIDs[i] + } + + _, err = stmt.ExecContext(ctx, params...) + return err +} + func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) { defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed") for rows.Next() { diff --git a/keyserver/storage/tables/interface.go b/keyserver/storage/tables/interface.go index 37a010a7c..24da1125e 100644 --- a/keyserver/storage/tables/interface.go +++ b/keyserver/storage/tables/interface.go @@ -56,6 +56,7 @@ type KeyChanges interface { type StaleDeviceLists interface { InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) + DeleteStaleDeviceLists(ctx context.Context, txn *sql.Tx, userIDs []string) error } type CrossSigningKeys interface { diff --git a/keyserver/storage/tables/stale_device_lists_test.go b/keyserver/storage/tables/stale_device_lists_test.go new file mode 100644 index 000000000..76d3baddd --- /dev/null +++ b/keyserver/storage/tables/stale_device_lists_test.go @@ -0,0 +1,94 @@ +package tables_test + +import ( + "context" + "testing" + + "github.com/matrix-org/gomatrixserverlib" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/keyserver/storage/sqlite3" + "github.com/matrix-org/dendrite/setup/config" + + "github.com/matrix-org/dendrite/keyserver/storage/postgres" + "github.com/matrix-org/dendrite/keyserver/storage/tables" + "github.com/matrix-org/dendrite/test" +) + +func mustCreateTable(t *testing.T, dbType test.DBType) (tab tables.StaleDeviceLists, close func()) { + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, nil) + if err != nil { + t.Fatalf("failed to open database: %s", err) + } + switch dbType { + case test.DBTypePostgres: + tab, err = postgres.NewPostgresStaleDeviceListsTable(db) + case test.DBTypeSQLite: + tab, err = sqlite3.NewSqliteStaleDeviceListsTable(db) + } + if err != nil { + t.Fatalf("failed to create new table: %s", err) + } + return tab, close +} + +func TestStaleDeviceLists(t *testing.T) { + alice := test.NewUser(t) + bob := test.NewUser(t) + charlie := "@charlie:localhost" + ctx := context.Background() + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + tab, closeDB := mustCreateTable(t, dbType) + defer closeDB() + + if err := tab.InsertStaleDeviceList(ctx, alice.ID, true); err != nil { + t.Fatalf("failed to insert stale device: %s", err) + } + if err := tab.InsertStaleDeviceList(ctx, bob.ID, true); err != nil { + t.Fatalf("failed to insert stale device: %s", err) + } + if err := tab.InsertStaleDeviceList(ctx, charlie, true); err != nil { + t.Fatalf("failed to insert stale device: %s", err) + } + + // Query one server + wantStaleUsers := []string{alice.ID, bob.ID} + gotStaleUsers, err := tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"}) + if err != nil { + t.Fatalf("failed to query stale device lists: %s", err) + } + if !test.UnsortedStringSliceEqual(wantStaleUsers, gotStaleUsers) { + t.Fatalf("expected stale users %v, got %v", wantStaleUsers, gotStaleUsers) + } + + // Query all servers + wantStaleUsers = []string{alice.ID, bob.ID, charlie} + gotStaleUsers, err = tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{}) + if err != nil { + t.Fatalf("failed to query stale device lists: %s", err) + } + if !test.UnsortedStringSliceEqual(wantStaleUsers, gotStaleUsers) { + t.Fatalf("expected stale users %v, got %v", wantStaleUsers, gotStaleUsers) + } + + // Delete stale devices + deleteUsers := []string{alice.ID, bob.ID} + if err = tab.DeleteStaleDeviceLists(ctx, nil, deleteUsers); err != nil { + t.Fatalf("failed to delete stale device lists: %s", err) + } + + // Verify we don't get anything back after deleting + gotStaleUsers, err = tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"}) + if err != nil { + t.Fatalf("failed to query stale device lists: %s", err) + } + + if gotCount := len(gotStaleUsers); gotCount > 0 { + t.Fatalf("expected no stale users, got %d", gotCount) + } + }) +} diff --git a/roomserver/api/api.go b/roomserver/api/api.go index 01e87ec8a..420ef278a 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -17,6 +17,7 @@ type RoomserverInternalAPI interface { ClientRoomserverAPI UserRoomserverAPI FederationRoomserverAPI + KeyserverRoomserverAPI // needed to avoid chicken and egg scenario when setting up the // interdependencies between the roomserver and other input APIs @@ -199,3 +200,7 @@ type FederationRoomserverAPI interface { // Query a given amount (or less) of events prior to a given set of events. PerformBackfill(ctx context.Context, req *PerformBackfillRequest, res *PerformBackfillResponse) error } + +type KeyserverRoomserverAPI interface { + QueryLeftUsers(ctx context.Context, req *QueryLeftUsersRequest, res *QueryLeftUsersResponse) error +} diff --git a/roomserver/api/api_trace.go b/roomserver/api/api_trace.go index 342a3904c..b23263d17 100644 --- a/roomserver/api/api_trace.go +++ b/roomserver/api/api_trace.go @@ -19,6 +19,12 @@ type RoomserverInternalAPITrace struct { Impl RoomserverInternalAPI } +func (t *RoomserverInternalAPITrace) QueryLeftUsers(ctx context.Context, req *QueryLeftUsersRequest, res *QueryLeftUsersResponse) error { + err := t.Impl.QueryLeftUsers(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("QueryLeftUsers req=%+v res=%+v", js(req), js(res)) + return err +} + func (t *RoomserverInternalAPITrace) SetFederationAPI(fsAPI fsAPI.RoomserverFederationAPI, keyRing *gomatrixserverlib.KeyRing) { t.Impl.SetFederationAPI(fsAPI, keyRing) } diff --git a/roomserver/api/query.go b/roomserver/api/query.go index b62907f3c..76f8298ca 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -447,3 +447,15 @@ type QueryMembershipAtEventResponse struct { // do not have known state will return an empty array here. Memberships map[string][]*gomatrixserverlib.HeaderedEvent `json:"memberships"` } + +// QueryLeftUsersRequest is a request to calculate users that we (the server) don't share a +// a room with anymore. This is used to cleanup stale device list entries, where we would +// otherwise keep on trying to get device lists. +type QueryLeftUsersRequest struct { + StaleDeviceListUsers []string `json:"user_ids"` +} + +// QueryLeftUsersResponse is the response to QueryLeftUsersRequest. +type QueryLeftUsersResponse struct { + LeftUsers []string `json:"user_ids"` +} diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index d8456fb43..69d841dda 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -805,6 +805,12 @@ func (r *Queryer) QueryBulkStateContent(ctx context.Context, req *api.QueryBulkS return nil } +func (r *Queryer) QueryLeftUsers(ctx context.Context, req *api.QueryLeftUsersRequest, res *api.QueryLeftUsersResponse) error { + var err error + res.LeftUsers, err = r.DB.GetLeftUsers(ctx, req.StaleDeviceListUsers) + return err +} + func (r *Queryer) QuerySharedUsers(ctx context.Context, req *api.QuerySharedUsersRequest, res *api.QuerySharedUsersResponse) error { roomIDs, err := r.DB.GetRoomsByMembership(ctx, req.UserID, "join") if err != nil { diff --git a/roomserver/inthttp/client.go b/roomserver/inthttp/client.go index 1bd1b3fb7..8a2e0a03c 100644 --- a/roomserver/inthttp/client.go +++ b/roomserver/inthttp/client.go @@ -63,6 +63,7 @@ const ( RoomserverQueryAuthChainPath = "/roomserver/queryAuthChain" RoomserverQueryRestrictedJoinAllowed = "/roomserver/queryRestrictedJoinAllowed" RoomserverQueryMembershipAtEventPath = "/roomserver/queryMembershipAtEvent" + RoomserverQueryLeftMembersPath = "/roomserver/queryLeftMembers" ) type httpRoomserverInternalAPI struct { @@ -553,3 +554,10 @@ func (h *httpRoomserverInternalAPI) QueryMembershipAtEvent(ctx context.Context, h.httpClient, ctx, request, response, ) } + +func (h *httpRoomserverInternalAPI) QueryLeftUsers(ctx context.Context, request *api.QueryLeftUsersRequest, response *api.QueryLeftUsersResponse) error { + return httputil.CallInternalRPCAPI( + "RoomserverQueryLeftMembers", h.roomserverURL+RoomserverQueryLeftMembersPath, + h.httpClient, ctx, request, response, + ) +} diff --git a/roomserver/inthttp/server.go b/roomserver/inthttp/server.go index 6e7c2d985..4d21909b7 100644 --- a/roomserver/inthttp/server.go +++ b/roomserver/inthttp/server.go @@ -203,4 +203,9 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router, enableMe RoomserverQueryMembershipAtEventPath, httputil.MakeInternalRPCAPI("RoomserverQueryMembershipAtEventPath", enableMetrics, r.QueryMembershipAtEvent), ) + + internalAPIMux.Handle( + RoomserverQueryLeftMembersPath, + httputil.MakeInternalRPCAPI("RoomserverQueryLeftMembersPath", enableMetrics, r.QueryLeftUsers), + ) } diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go index 24b5515e5..518bb3722 100644 --- a/roomserver/roomserver_test.go +++ b/roomserver/roomserver_test.go @@ -2,20 +2,27 @@ package roomserver_test import ( "context" + "net/http" "testing" + "time" + "github.com/gorilla/mux" + "github.com/matrix-org/gomatrixserverlib" + + "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/inthttp" "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test/testrig" - "github.com/matrix-org/gomatrixserverlib" ) func mustCreateDatabase(t *testing.T, dbType test.DBType) (*base.BaseDendrite, storage.Database, func()) { + t.Helper() base, close := testrig.CreateBaseDendrite(t, dbType) - db, err := storage.Open(base, &base.Cfg.KeyServer.Database, base.Caches) + db, err := storage.Open(base, &base.Cfg.RoomServer.Database, base.Caches) if err != nil { t.Fatalf("failed to create Database: %v", err) } @@ -67,3 +74,69 @@ func Test_SharedUsers(t *testing.T) { } }) } + +func Test_QueryLeftUsers(t *testing.T) { + alice := test.NewUser(t) + bob := test.NewUser(t) + room := test.NewRoom(t, alice, test.RoomPreset(test.PresetTrustedPrivateChat)) + + // Invite and join Bob + room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "invite", + }, test.WithStateKey(bob.ID)) + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, _, close := mustCreateDatabase(t, dbType) + defer close() + + rsAPI := roomserver.NewInternalAPI(base) + // SetFederationAPI starts the room event input consumer + rsAPI.SetFederationAPI(nil, nil) + // Create the room + if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { + t.Fatalf("failed to send events: %v", err) + } + + // Query the left users, there should only be "@idontexist:test", + // as Alice and Bob are still joined. + res := &api.QueryLeftUsersResponse{} + leftUserID := "@idontexist:test" + getLeftUsersList := []string{alice.ID, bob.ID, leftUserID} + + testCase := func(rsAPI api.RoomserverInternalAPI) { + if err := rsAPI.QueryLeftUsers(ctx, &api.QueryLeftUsersRequest{StaleDeviceListUsers: getLeftUsersList}, res); err != nil { + t.Fatalf("unable to query left users: %v", err) + } + wantCount := 1 + if count := len(res.LeftUsers); count > wantCount { + t.Fatalf("unexpected left users count: want %d, got %d", wantCount, count) + } + if res.LeftUsers[0] != leftUserID { + t.Fatalf("unexpected left users : want %s, got %s", leftUserID, res.LeftUsers[0]) + } + } + + t.Run("HTTP API", func(t *testing.T) { + router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter() + roomserver.AddInternalRoutes(router, rsAPI, false) + apiURL, cancel := test.ListenAndServe(t, router, false) + defer cancel() + httpAPI, err := inthttp.NewRoomserverClient(apiURL, &http.Client{Timeout: time.Second * 5}, nil) + if err != nil { + t.Fatalf("failed to create HTTP client") + } + testCase(httpAPI) + }) + t.Run("Monolith", func(t *testing.T) { + testCase(rsAPI) + // also test tracing + traceAPI := &api.RoomserverInternalAPITrace{Impl: rsAPI} + testCase(traceAPI) + }) + + }) +} diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 06db4b2d8..92bc2e66f 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -172,5 +172,6 @@ type Database interface { ForgetRoom(ctx context.Context, userID, roomID string, forget bool) error GetHistoryVisibilityState(ctx context.Context, roomInfo *types.RoomInfo, eventID string, domain string) ([]*gomatrixserverlib.Event, error) + GetLeftUsers(ctx context.Context, userIDs []string) ([]string, error) UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventSender string) error } diff --git a/roomserver/storage/postgres/membership_table.go b/roomserver/storage/postgres/membership_table.go index 0150534e1..d774b7892 100644 --- a/roomserver/storage/postgres/membership_table.go +++ b/roomserver/storage/postgres/membership_table.go @@ -21,12 +21,13 @@ import ( "fmt" "github.com/lib/pq" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/postgres/deltas" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" - "github.com/matrix-org/gomatrixserverlib" ) const membershipSchema = ` @@ -157,6 +158,12 @@ const selectServerInRoomSQL = "" + " JOIN roomserver_event_state_keys ON roomserver_membership.target_nid = roomserver_event_state_keys.event_state_key_nid" + " WHERE membership_nid = $1 AND room_nid = $2 AND event_state_key LIKE '%:' || $3 LIMIT 1" +const selectJoinedUsersSQL = ` +SELECT DISTINCT target_nid +FROM roomserver_membership m +WHERE membership_nid > $1 AND target_nid = ANY($2) +` + type membershipStatements struct { insertMembershipStmt *sql.Stmt selectMembershipForUpdateStmt *sql.Stmt @@ -174,6 +181,7 @@ type membershipStatements struct { selectLocalServerInRoomStmt *sql.Stmt selectServerInRoomStmt *sql.Stmt deleteMembershipStmt *sql.Stmt + selectJoinedUsersStmt *sql.Stmt } func CreateMembershipTable(db *sql.DB) error { @@ -209,9 +217,33 @@ func PrepareMembershipTable(db *sql.DB) (tables.Membership, error) { {&s.selectLocalServerInRoomStmt, selectLocalServerInRoomSQL}, {&s.selectServerInRoomStmt, selectServerInRoomSQL}, {&s.deleteMembershipStmt, deleteMembershipSQL}, + {&s.selectJoinedUsersStmt, selectJoinedUsersSQL}, }.Prepare(db) } +func (s *membershipStatements) SelectJoinedUsers( + ctx context.Context, txn *sql.Tx, + targetUserNIDs []types.EventStateKeyNID, +) ([]types.EventStateKeyNID, error) { + result := make([]types.EventStateKeyNID, 0, len(targetUserNIDs)) + + stmt := sqlutil.TxStmt(txn, s.selectJoinedUsersStmt) + rows, err := stmt.QueryContext(ctx, tables.MembershipStateLeaveOrBan, pq.Array(targetUserNIDs)) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectJoinedUsers: rows.close() failed") + var targetNID types.EventStateKeyNID + for rows.Next() { + if err = rows.Scan(&targetNID); err != nil { + return nil, err + } + result = append(result, targetNID) + } + + return result, rows.Err() +} + func (s *membershipStatements) InsertMembership( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 16898bcb1..725cc5bc7 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -1365,6 +1365,43 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs [ return result, nil } +// GetLeftUsers calculates users we (the server) don't share a room with anymore. +func (d *Database) GetLeftUsers(ctx context.Context, userIDs []string) ([]string, error) { + // Get the userNID for all users with a stale device list + stateKeyNIDMap, err := d.EventStateKeyNIDs(ctx, userIDs) + if err != nil { + return nil, err + } + + userNIDs := make([]types.EventStateKeyNID, 0, len(stateKeyNIDMap)) + userNIDtoUserID := make(map[types.EventStateKeyNID]string, len(stateKeyNIDMap)) + // Create a map from userNID -> userID + for userID, nid := range stateKeyNIDMap { + userNIDs = append(userNIDs, nid) + userNIDtoUserID[nid] = userID + } + + // Get all users whose membership is still join, knock or invite. + stillJoinedUsersNIDs, err := d.MembershipTable.SelectJoinedUsers(ctx, nil, userNIDs) + if err != nil { + return nil, err + } + + // Remove joined users from the "user with stale devices" list, which contains left AND joined users + for _, joinedUser := range stillJoinedUsersNIDs { + delete(userNIDtoUserID, joinedUser) + } + + // The users still in our userNIDtoUserID map are the users we don't share a room with anymore, + // and the return value we are looking for. + leftUsers := make([]string, 0, len(userNIDtoUserID)) + for _, userID := range userNIDtoUserID { + leftUsers = append(leftUsers, userID) + } + + return leftUsers, nil +} + // GetLocalServerInRoom returns true if we think we're in a given room or false otherwise. func (d *Database) GetLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) { return d.MembershipTable.SelectLocalServerInRoom(ctx, nil, roomNID) diff --git a/roomserver/storage/shared/storage_test.go b/roomserver/storage/shared/storage_test.go new file mode 100644 index 000000000..58724340c --- /dev/null +++ b/roomserver/storage/shared/storage_test.go @@ -0,0 +1,96 @@ +package shared_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/postgres" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/sqlite3" + "github.com/matrix-org/dendrite/roomserver/storage/tables" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" +) + +func mustCreateRoomserverDatabase(t *testing.T, dbType test.DBType) (*shared.Database, func()) { + t.Helper() + + connStr, clearDB := test.PrepareDBConnectionString(t, dbType) + base, _, _ := testrig.Base(nil) + dbOpts := &config.DatabaseOptions{ConnectionString: config.DataSource(connStr)} + + db, err := sqlutil.Open(dbOpts, sqlutil.NewExclusiveWriter()) + assert.NoError(t, err) + + var membershipTable tables.Membership + var stateKeyTable tables.EventStateKeys + switch dbType { + case test.DBTypePostgres: + err = postgres.CreateEventStateKeysTable(db) + assert.NoError(t, err) + err = postgres.CreateMembershipTable(db) + assert.NoError(t, err) + membershipTable, err = postgres.PrepareMembershipTable(db) + assert.NoError(t, err) + stateKeyTable, err = postgres.PrepareEventStateKeysTable(db) + case test.DBTypeSQLite: + err = sqlite3.CreateEventStateKeysTable(db) + assert.NoError(t, err) + err = sqlite3.CreateMembershipTable(db) + assert.NoError(t, err) + membershipTable, err = sqlite3.PrepareMembershipTable(db) + assert.NoError(t, err) + stateKeyTable, err = sqlite3.PrepareEventStateKeysTable(db) + } + assert.NoError(t, err) + + return &shared.Database{ + DB: db, + EventStateKeysTable: stateKeyTable, + MembershipTable: membershipTable, + Writer: sqlutil.NewExclusiveWriter(), + }, func() { + err := base.Close() + assert.NoError(t, err) + clearDB() + err = db.Close() + assert.NoError(t, err) + } +} + +func Test_GetLeftUsers(t *testing.T) { + alice := test.NewUser(t) + bob := test.NewUser(t) + charlie := test.NewUser(t) + + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateRoomserverDatabase(t, dbType) + defer close() + + // Create dummy entries + for _, user := range []*test.User{alice, bob, charlie} { + nid, err := db.EventStateKeysTable.InsertEventStateKeyNID(ctx, nil, user.ID) + assert.NoError(t, err) + err = db.MembershipTable.InsertMembership(ctx, nil, 1, nid, true) + assert.NoError(t, err) + // We must update the membership with a non-zero event NID or it will get filtered out in later queries + membershipNID := tables.MembershipStateLeaveOrBan + if user == alice { + membershipNID = tables.MembershipStateJoin + } + _, err = db.MembershipTable.UpdateMembership(ctx, nil, 1, nid, nid, membershipNID, 1, false) + assert.NoError(t, err) + } + + // Now try to get the left users, this should be Bob and Charlie, since they have a "leave" membership + expectedUserIDs := []string{bob.ID, charlie.ID} + leftUsers, err := db.GetLeftUsers(context.Background(), []string{alice.ID, bob.ID, charlie.ID}) + assert.NoError(t, err) + assert.ElementsMatch(t, expectedUserIDs, leftUsers) + }) +} diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go index cd149f0ed..8a60b359f 100644 --- a/roomserver/storage/sqlite3/membership_table.go +++ b/roomserver/storage/sqlite3/membership_table.go @@ -21,12 +21,13 @@ import ( "fmt" "strings" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/sqlite3/deltas" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" - "github.com/matrix-org/gomatrixserverlib" ) const membershipSchema = ` @@ -133,6 +134,12 @@ const selectServerInRoomSQL = "" + const deleteMembershipSQL = "" + "DELETE FROM roomserver_membership WHERE room_nid = $1 AND target_nid = $2" +const selectJoinedUsersSQL = ` +SELECT DISTINCT target_nid +FROM roomserver_membership m +WHERE membership_nid > $1 AND target_nid IN ($2) +` + type membershipStatements struct { db *sql.DB insertMembershipStmt *sql.Stmt @@ -149,6 +156,7 @@ type membershipStatements struct { selectLocalServerInRoomStmt *sql.Stmt selectServerInRoomStmt *sql.Stmt deleteMembershipStmt *sql.Stmt + // selectJoinedUsersStmt *sql.Stmt // Prepared at runtime } func CreateMembershipTable(db *sql.DB) error { @@ -412,3 +420,40 @@ func (s *membershipStatements) DeleteMembership( ) return err } + +func (s *membershipStatements) SelectJoinedUsers( + ctx context.Context, txn *sql.Tx, + targetUserNIDs []types.EventStateKeyNID, +) ([]types.EventStateKeyNID, error) { + result := make([]types.EventStateKeyNID, 0, len(targetUserNIDs)) + + qry := strings.Replace(selectJoinedUsersSQL, "($2)", sqlutil.QueryVariadicOffset(len(targetUserNIDs), 1), 1) + + stmt, err := s.db.Prepare(qry) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, stmt, "SelectJoinedUsers: stmt.Close failed") + + params := make([]any, len(targetUserNIDs)+1) + params[0] = tables.MembershipStateLeaveOrBan + for i := range targetUserNIDs { + params[i+1] = targetUserNIDs[i] + } + + stmt = sqlutil.TxStmt(txn, stmt) + rows, err := stmt.QueryContext(ctx, params...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectJoinedUsers: rows.close() failed") + var targetNID types.EventStateKeyNID + for rows.Next() { + if err = rows.Scan(&targetNID); err != nil { + return nil, err + } + result = append(result, targetNID) + } + + return result, rows.Err() +} diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index 50d27c756..80fcf72dd 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -144,6 +144,7 @@ type Membership interface { SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error) SelectServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) DeleteMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) error + SelectJoinedUsers(ctx context.Context, txn *sql.Tx, targetUserNIDs []types.EventStateKeyNID) ([]types.EventStateKeyNID, error) } type Published interface { diff --git a/roomserver/storage/tables/membership_table_test.go b/roomserver/storage/tables/membership_table_test.go index c9541d9d2..c4524ee44 100644 --- a/roomserver/storage/tables/membership_table_test.go +++ b/roomserver/storage/tables/membership_table_test.go @@ -129,5 +129,11 @@ func TestMembershipTable(t *testing.T) { knownUsers, err := tab.SelectKnownUsers(ctx, nil, userNIDs[0], "localhost", 2) assert.NoError(t, err) assert.Equal(t, 1, len(knownUsers)) + + // get users we share a room with, given their userNID + joinedUsers, err := tab.SelectJoinedUsers(ctx, nil, userNIDs) + assert.NoError(t, err) + // Only userNIDs[0] is actually joined, so we only expect this userNID + assert.Equal(t, userNIDs[:1], joinedUsers) }) } From 76db8e90defdfb9e61f6caea8a312c5d60bcc005 Mon Sep 17 00:00:00 2001 From: Kento Okamoto Date: Mon, 12 Dec 2022 08:46:37 -0800 Subject: [PATCH 10/15] Dendrite Documentation Fix (#2913) ### Pull Request Checklist I was reading through the Dendrite documentation on https://matrix-org.github.io/dendrite/development/contributing and noticed the installation link leads to a 404 error. This link works fine if it is viewed directly from [docs/CONTRIBUTING.md](https://github.com/matrix-org/dendrite/blob/main/docs/CONTRIBUTING.md) but this might not be very obvious to new contributors who are reading through the [contribution page](https://matrix-org.github.io/dendrite/development/contributing) directly. This PR is mainly a small re-organization of the online documentation mainly in the [Development](https://matrix-org.github.io/dendrite/development) tab along with any links throughout the doc that may be impacted by the change. This does not contain any Go unit tests as this does not actually touch core dendrite functionality. * [ ] I have added Go unit tests or [Complement integration tests](https://github.com/matrix-org/complement) for this PR _or_ I have justified why this PR doesn't need tests * [x] Pull request includes a [sign off below using a legally identifiable name](https://matrix-org.github.io/dendrite/development/contributing#sign-off) _or_ I have already signed off privately Signed-off-by: `Kento Okamoto ` --- docs/FAQ.md | 4 ++-- docs/{ => development}/CONTRIBUTING.md | 6 +++--- docs/{ => development}/PROFILING.md | 0 docs/{ => development}/coverage.md | 0 docs/{ => development}/sytest.md | 0 docs/{ => development}/tracing/opentracing.md | 0 docs/{ => development}/tracing/setup.md | 0 7 files changed, 5 insertions(+), 5 deletions(-) rename docs/{ => development}/CONTRIBUTING.md (97%) rename docs/{ => development}/PROFILING.md (100%) rename docs/{ => development}/coverage.md (100%) rename docs/{ => development}/sytest.md (100%) rename docs/{ => development}/tracing/opentracing.md (100%) rename docs/{ => development}/tracing/setup.md (100%) diff --git a/docs/FAQ.md b/docs/FAQ.md index ca72b151d..816130515 100644 --- a/docs/FAQ.md +++ b/docs/FAQ.md @@ -91,7 +91,7 @@ Please use PostgreSQL wherever possible, especially if you are planning to run a ## Dendrite is using a lot of CPU Generally speaking, you should expect to see some CPU spikes, particularly if you are joining or participating in large rooms. However, constant/sustained high CPU usage is not expected - if you are experiencing that, please join `#dendrite-dev:matrix.org` and let us know what you were doing when the -CPU usage shot up, or file a GitHub issue. If you can take a [CPU profile](PROFILING.md) then that would +CPU usage shot up, or file a GitHub issue. If you can take a [CPU profile](development/PROFILING.md) then that would be a huge help too, as that will help us to understand where the CPU time is going. ## Dendrite is using a lot of RAM @@ -99,7 +99,7 @@ be a huge help too, as that will help us to understand where the CPU time is goi As above with CPU usage, some memory spikes are expected if Dendrite is doing particularly heavy work at a given instant. However, if it is using more RAM than you expect for a long time, that's probably not expected. Join `#dendrite-dev:matrix.org` and let us know what you were doing when the memory usage -ballooned, or file a GitHub issue if you can. If you can take a [memory profile](PROFILING.md) then that +ballooned, or file a GitHub issue if you can. If you can take a [memory profile](development/PROFILING.md) then that would be a huge help too, as that will help us to understand where the memory usage is happening. ## Dendrite is running out of PostgreSQL database connections diff --git a/docs/CONTRIBUTING.md b/docs/development/CONTRIBUTING.md similarity index 97% rename from docs/CONTRIBUTING.md rename to docs/development/CONTRIBUTING.md index 21b0e7ab1..2aec4c363 100644 --- a/docs/CONTRIBUTING.md +++ b/docs/development/CONTRIBUTING.md @@ -9,7 +9,7 @@ permalink: /development/contributing Everyone is welcome to contribute to Dendrite! We aim to make it as easy as possible to get started. - ## Contribution types +## Contribution types We are a small team maintaining a large project. As a result, we cannot merge every feature, even if it is bug-free and useful, because we then commit to maintaining it indefinitely. We will always accept: @@ -57,7 +57,7 @@ to do so for future contributions. ## Getting up and running -See the [Installation](installation) section for information on how to build an +See the [Installation](../installation) section for information on how to build an instance of Dendrite. You will likely need this in order to test your changes. ## Code style @@ -151,7 +151,7 @@ significant amount of CPU and RAM. Once the code builds, run [Sytest](https://github.com/matrix-org/sytest) according to the guide in -[docs/sytest.md](https://github.com/matrix-org/dendrite/blob/main/docs/sytest.md#using-a-sytest-docker-image) +[docs/development/sytest.md](https://github.com/matrix-org/dendrite/blob/main/docs/development/sytest.md#using-a-sytest-docker-image) so you can see whether something is being broken and whether there are newly passing tests. diff --git a/docs/PROFILING.md b/docs/development/PROFILING.md similarity index 100% rename from docs/PROFILING.md rename to docs/development/PROFILING.md diff --git a/docs/coverage.md b/docs/development/coverage.md similarity index 100% rename from docs/coverage.md rename to docs/development/coverage.md diff --git a/docs/sytest.md b/docs/development/sytest.md similarity index 100% rename from docs/sytest.md rename to docs/development/sytest.md diff --git a/docs/tracing/opentracing.md b/docs/development/tracing/opentracing.md similarity index 100% rename from docs/tracing/opentracing.md rename to docs/development/tracing/opentracing.md diff --git a/docs/tracing/setup.md b/docs/development/tracing/setup.md similarity index 100% rename from docs/tracing/setup.md rename to docs/development/tracing/setup.md From d3db542fbf5b35377586567b21bc5c28872167a1 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Thu, 22 Dec 2022 10:56:20 +0100 Subject: [PATCH 11/15] Add federation peeking table tests (#2920) As the title says, adds tests for inbound/outbound peeking federation table tests. Also removes some unused code --- federationapi/queue/queue_test.go | 22 --- federationapi/storage/interface.go | 3 - .../storage/postgres/inbound_peeks_table.go | 32 ++-- .../storage/postgres/outbound_peeks_table.go | 31 ++-- .../storage/postgres/queue_edus_table.go | 21 --- .../storage/postgres/queue_pdus_table.go | 23 --- federationapi/storage/shared/storage_edus.go | 9 - federationapi/storage/shared/storage_pdus.go | 9 - .../storage/sqlite3/inbound_peeks_table.go | 32 ++-- .../storage/sqlite3/outbound_peeks_table.go | 31 ++-- .../storage/sqlite3/queue_edus_table.go | 21 --- .../storage/sqlite3/queue_pdus_table.go | 23 --- federationapi/storage/storage_test.go | 166 ++++++++++++++++++ .../tables/inbound_peeks_table_test.go | 148 ++++++++++++++++ federationapi/storage/tables/interface.go | 2 - .../tables/outbound_peeks_table_test.go | 147 ++++++++++++++++ 16 files changed, 503 insertions(+), 217 deletions(-) create mode 100644 federationapi/storage/tables/inbound_peeks_table_test.go create mode 100644 federationapi/storage/tables/outbound_peeks_table_test.go diff --git a/federationapi/queue/queue_test.go b/federationapi/queue/queue_test.go index b2ec4b836..c317edc21 100644 --- a/federationapi/queue/queue_test.go +++ b/federationapi/queue/queue_test.go @@ -221,28 +221,6 @@ func (d *fakeDatabase) CleanEDUs(ctx context.Context, serverName gomatrixserverl return nil } -func (d *fakeDatabase) GetPendingPDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error) { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - var count int64 - if pdus, ok := d.associatedPDUs[serverName]; ok { - count = int64(len(pdus)) - } - return count, nil -} - -func (d *fakeDatabase) GetPendingEDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error) { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - var count int64 - if edus, ok := d.associatedEDUs[serverName]; ok { - count = int64(len(edus)) - } - return count, nil -} - func (d *fakeDatabase) GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) { d.dbMutex.Lock() defer d.dbMutex.Unlock() diff --git a/federationapi/storage/interface.go b/federationapi/storage/interface.go index b15b8bfae..276cd9a50 100644 --- a/federationapi/storage/interface.go +++ b/federationapi/storage/interface.go @@ -45,9 +45,6 @@ type Database interface { CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error - GetPendingPDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error) - GetPendingEDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error) - GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) GetPendingEDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) diff --git a/federationapi/storage/postgres/inbound_peeks_table.go b/federationapi/storage/postgres/inbound_peeks_table.go index df5c60761..ad2afcb15 100644 --- a/federationapi/storage/postgres/inbound_peeks_table.go +++ b/federationapi/storage/postgres/inbound_peeks_table.go @@ -44,13 +44,13 @@ const selectInboundPeekSQL = "" + "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3" const selectInboundPeeksSQL = "" + - "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1" + "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1 ORDER by creation_ts" const renewInboundPeekSQL = "" + "UPDATE federationsender_inbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5" const deleteInboundPeekSQL = "" + - "DELETE FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2" + "DELETE FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3" const deleteInboundPeeksSQL = "" + "DELETE FROM federationsender_inbound_peeks WHERE room_id = $1" @@ -74,25 +74,15 @@ func NewPostgresInboundPeeksTable(db *sql.DB) (s *inboundPeeksStatements, err er return } - if s.insertInboundPeekStmt, err = db.Prepare(insertInboundPeekSQL); err != nil { - return - } - if s.selectInboundPeekStmt, err = db.Prepare(selectInboundPeekSQL); err != nil { - return - } - if s.selectInboundPeeksStmt, err = db.Prepare(selectInboundPeeksSQL); err != nil { - return - } - if s.renewInboundPeekStmt, err = db.Prepare(renewInboundPeekSQL); err != nil { - return - } - if s.deleteInboundPeeksStmt, err = db.Prepare(deleteInboundPeeksSQL); err != nil { - return - } - if s.deleteInboundPeekStmt, err = db.Prepare(deleteInboundPeekSQL); err != nil { - return - } - return + return s, sqlutil.StatementList{ + {&s.insertInboundPeekStmt, insertInboundPeekSQL}, + {&s.selectInboundPeekStmt, selectInboundPeekSQL}, + {&s.selectInboundPeekStmt, selectInboundPeekSQL}, + {&s.selectInboundPeeksStmt, selectInboundPeeksSQL}, + {&s.renewInboundPeekStmt, renewInboundPeekSQL}, + {&s.deleteInboundPeeksStmt, deleteInboundPeeksSQL}, + {&s.deleteInboundPeekStmt, deleteInboundPeekSQL}, + }.Prepare(db) } func (s *inboundPeeksStatements) InsertInboundPeek( diff --git a/federationapi/storage/postgres/outbound_peeks_table.go b/federationapi/storage/postgres/outbound_peeks_table.go index c22d893f7..5df684318 100644 --- a/federationapi/storage/postgres/outbound_peeks_table.go +++ b/federationapi/storage/postgres/outbound_peeks_table.go @@ -44,13 +44,13 @@ const selectOutboundPeekSQL = "" + "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3" const selectOutboundPeeksSQL = "" + - "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1" + "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1 ORDER BY creation_ts" const renewOutboundPeekSQL = "" + "UPDATE federationsender_outbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5" const deleteOutboundPeekSQL = "" + - "DELETE FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2" + "DELETE FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3" const deleteOutboundPeeksSQL = "" + "DELETE FROM federationsender_outbound_peeks WHERE room_id = $1" @@ -74,25 +74,14 @@ func NewPostgresOutboundPeeksTable(db *sql.DB) (s *outboundPeeksStatements, err return } - if s.insertOutboundPeekStmt, err = db.Prepare(insertOutboundPeekSQL); err != nil { - return - } - if s.selectOutboundPeekStmt, err = db.Prepare(selectOutboundPeekSQL); err != nil { - return - } - if s.selectOutboundPeeksStmt, err = db.Prepare(selectOutboundPeeksSQL); err != nil { - return - } - if s.renewOutboundPeekStmt, err = db.Prepare(renewOutboundPeekSQL); err != nil { - return - } - if s.deleteOutboundPeeksStmt, err = db.Prepare(deleteOutboundPeeksSQL); err != nil { - return - } - if s.deleteOutboundPeekStmt, err = db.Prepare(deleteOutboundPeekSQL); err != nil { - return - } - return + return s, sqlutil.StatementList{ + {&s.insertOutboundPeekStmt, insertOutboundPeekSQL}, + {&s.selectOutboundPeekStmt, selectOutboundPeekSQL}, + {&s.selectOutboundPeeksStmt, selectOutboundPeeksSQL}, + {&s.renewOutboundPeekStmt, renewOutboundPeekSQL}, + {&s.deleteOutboundPeeksStmt, deleteOutboundPeeksSQL}, + {&s.deleteOutboundPeekStmt, deleteOutboundPeekSQL}, + }.Prepare(db) } func (s *outboundPeeksStatements) InsertOutboundPeek( diff --git a/federationapi/storage/postgres/queue_edus_table.go b/federationapi/storage/postgres/queue_edus_table.go index d6507e13b..8870dc88d 100644 --- a/federationapi/storage/postgres/queue_edus_table.go +++ b/federationapi/storage/postgres/queue_edus_table.go @@ -62,10 +62,6 @@ const selectQueueEDUReferenceJSONCountSQL = "" + "SELECT COUNT(*) FROM federationsender_queue_edus" + " WHERE json_nid = $1" -const selectQueueEDUCountSQL = "" + - "SELECT COUNT(*) FROM federationsender_queue_edus" + - " WHERE server_name = $1" - const selectQueueServerNamesSQL = "" + "SELECT DISTINCT server_name FROM federationsender_queue_edus" @@ -81,7 +77,6 @@ type queueEDUsStatements struct { deleteQueueEDUStmt *sql.Stmt selectQueueEDUStmt *sql.Stmt selectQueueEDUReferenceJSONCountStmt *sql.Stmt - selectQueueEDUCountStmt *sql.Stmt selectQueueEDUServerNamesStmt *sql.Stmt selectExpiredEDUsStmt *sql.Stmt deleteExpiredEDUsStmt *sql.Stmt @@ -116,7 +111,6 @@ func (s *queueEDUsStatements) Prepare() error { {&s.deleteQueueEDUStmt, deleteQueueEDUSQL}, {&s.selectQueueEDUStmt, selectQueueEDUSQL}, {&s.selectQueueEDUReferenceJSONCountStmt, selectQueueEDUReferenceJSONCountSQL}, - {&s.selectQueueEDUCountStmt, selectQueueEDUCountSQL}, {&s.selectQueueEDUServerNamesStmt, selectQueueServerNamesSQL}, {&s.selectExpiredEDUsStmt, selectExpiredEDUsSQL}, {&s.deleteExpiredEDUsStmt, deleteExpiredEDUsSQL}, @@ -186,21 +180,6 @@ func (s *queueEDUsStatements) SelectQueueEDUReferenceJSONCount( return count, err } -func (s *queueEDUsStatements) SelectQueueEDUCount( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, -) (int64, error) { - var count int64 - stmt := sqlutil.TxStmt(txn, s.selectQueueEDUCountStmt) - err := stmt.QueryRowContext(ctx, serverName).Scan(&count) - if err == sql.ErrNoRows { - // It's acceptable for there to be no rows referencing a given - // JSON NID but it's not an error condition. Just return as if - // there's a zero count. - return 0, nil - } - return count, err -} - func (s *queueEDUsStatements) SelectQueueEDUServerNames( ctx context.Context, txn *sql.Tx, ) ([]gomatrixserverlib.ServerName, error) { diff --git a/federationapi/storage/postgres/queue_pdus_table.go b/federationapi/storage/postgres/queue_pdus_table.go index 38ac5a6eb..3b0bef9af 100644 --- a/federationapi/storage/postgres/queue_pdus_table.go +++ b/federationapi/storage/postgres/queue_pdus_table.go @@ -58,10 +58,6 @@ const selectQueuePDUReferenceJSONCountSQL = "" + "SELECT COUNT(*) FROM federationsender_queue_pdus" + " WHERE json_nid = $1" -const selectQueuePDUsCountSQL = "" + - "SELECT COUNT(*) FROM federationsender_queue_pdus" + - " WHERE server_name = $1" - const selectQueuePDUServerNamesSQL = "" + "SELECT DISTINCT server_name FROM federationsender_queue_pdus" @@ -71,7 +67,6 @@ type queuePDUsStatements struct { deleteQueuePDUsStmt *sql.Stmt selectQueuePDUsStmt *sql.Stmt selectQueuePDUReferenceJSONCountStmt *sql.Stmt - selectQueuePDUsCountStmt *sql.Stmt selectQueuePDUServerNamesStmt *sql.Stmt } @@ -95,9 +90,6 @@ func NewPostgresQueuePDUsTable(db *sql.DB) (s *queuePDUsStatements, err error) { if s.selectQueuePDUReferenceJSONCountStmt, err = s.db.Prepare(selectQueuePDUReferenceJSONCountSQL); err != nil { return } - if s.selectQueuePDUsCountStmt, err = s.db.Prepare(selectQueuePDUsCountSQL); err != nil { - return - } if s.selectQueuePDUServerNamesStmt, err = s.db.Prepare(selectQueuePDUServerNamesSQL); err != nil { return } @@ -146,21 +138,6 @@ func (s *queuePDUsStatements) SelectQueuePDUReferenceJSONCount( return count, err } -func (s *queuePDUsStatements) SelectQueuePDUCount( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, -) (int64, error) { - var count int64 - stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsCountStmt) - err := stmt.QueryRowContext(ctx, serverName).Scan(&count) - if err == sql.ErrNoRows { - // It's acceptable for there to be no rows referencing a given - // JSON NID but it's not an error condition. Just return as if - // there's a zero count. - return 0, nil - } - return count, err -} - func (s *queuePDUsStatements) SelectQueuePDUs( ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, diff --git a/federationapi/storage/shared/storage_edus.go b/federationapi/storage/shared/storage_edus.go index c796d2f8f..be8355f31 100644 --- a/federationapi/storage/shared/storage_edus.go +++ b/federationapi/storage/shared/storage_edus.go @@ -162,15 +162,6 @@ func (d *Database) CleanEDUs( }) } -// GetPendingEDUCount returns the number of EDUs waiting to be -// sent for a given servername. -func (d *Database) GetPendingEDUCount( - ctx context.Context, - serverName gomatrixserverlib.ServerName, -) (int64, error) { - return d.FederationQueueEDUs.SelectQueueEDUCount(ctx, nil, serverName) -} - // GetPendingServerNames returns the server names that have EDUs // waiting to be sent. func (d *Database) GetPendingEDUServerNames( diff --git a/federationapi/storage/shared/storage_pdus.go b/federationapi/storage/shared/storage_pdus.go index dc37d7507..da4cb979d 100644 --- a/federationapi/storage/shared/storage_pdus.go +++ b/federationapi/storage/shared/storage_pdus.go @@ -141,15 +141,6 @@ func (d *Database) CleanPDUs( }) } -// GetPendingPDUCount returns the number of PDUs waiting to be -// sent for a given servername. -func (d *Database) GetPendingPDUCount( - ctx context.Context, - serverName gomatrixserverlib.ServerName, -) (int64, error) { - return d.FederationQueuePDUs.SelectQueuePDUCount(ctx, nil, serverName) -} - // GetPendingServerNames returns the server names that have PDUs // waiting to be sent. func (d *Database) GetPendingPDUServerNames( diff --git a/federationapi/storage/sqlite3/inbound_peeks_table.go b/federationapi/storage/sqlite3/inbound_peeks_table.go index ad3c4a6dd..8c3567934 100644 --- a/federationapi/storage/sqlite3/inbound_peeks_table.go +++ b/federationapi/storage/sqlite3/inbound_peeks_table.go @@ -44,13 +44,13 @@ const selectInboundPeekSQL = "" + "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3" const selectInboundPeeksSQL = "" + - "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1" + "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1 ORDER BY creation_ts" const renewInboundPeekSQL = "" + "UPDATE federationsender_inbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5" const deleteInboundPeekSQL = "" + - "DELETE FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2" + "DELETE FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3" const deleteInboundPeeksSQL = "" + "DELETE FROM federationsender_inbound_peeks WHERE room_id = $1" @@ -74,25 +74,15 @@ func NewSQLiteInboundPeeksTable(db *sql.DB) (s *inboundPeeksStatements, err erro return } - if s.insertInboundPeekStmt, err = db.Prepare(insertInboundPeekSQL); err != nil { - return - } - if s.selectInboundPeekStmt, err = db.Prepare(selectInboundPeekSQL); err != nil { - return - } - if s.selectInboundPeeksStmt, err = db.Prepare(selectInboundPeeksSQL); err != nil { - return - } - if s.renewInboundPeekStmt, err = db.Prepare(renewInboundPeekSQL); err != nil { - return - } - if s.deleteInboundPeeksStmt, err = db.Prepare(deleteInboundPeeksSQL); err != nil { - return - } - if s.deleteInboundPeekStmt, err = db.Prepare(deleteInboundPeekSQL); err != nil { - return - } - return + return s, sqlutil.StatementList{ + {&s.insertInboundPeekStmt, insertInboundPeekSQL}, + {&s.selectInboundPeekStmt, selectInboundPeekSQL}, + {&s.selectInboundPeekStmt, selectInboundPeekSQL}, + {&s.selectInboundPeeksStmt, selectInboundPeeksSQL}, + {&s.renewInboundPeekStmt, renewInboundPeekSQL}, + {&s.deleteInboundPeeksStmt, deleteInboundPeeksSQL}, + {&s.deleteInboundPeekStmt, deleteInboundPeekSQL}, + }.Prepare(db) } func (s *inboundPeeksStatements) InsertInboundPeek( diff --git a/federationapi/storage/sqlite3/outbound_peeks_table.go b/federationapi/storage/sqlite3/outbound_peeks_table.go index e29026fab..33f452b68 100644 --- a/federationapi/storage/sqlite3/outbound_peeks_table.go +++ b/federationapi/storage/sqlite3/outbound_peeks_table.go @@ -44,13 +44,13 @@ const selectOutboundPeekSQL = "" + "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3" const selectOutboundPeeksSQL = "" + - "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1" + "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1 ORDER BY creation_ts" const renewOutboundPeekSQL = "" + "UPDATE federationsender_outbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5" const deleteOutboundPeekSQL = "" + - "DELETE FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2" + "DELETE FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3" const deleteOutboundPeeksSQL = "" + "DELETE FROM federationsender_outbound_peeks WHERE room_id = $1" @@ -74,25 +74,14 @@ func NewSQLiteOutboundPeeksTable(db *sql.DB) (s *outboundPeeksStatements, err er return } - if s.insertOutboundPeekStmt, err = db.Prepare(insertOutboundPeekSQL); err != nil { - return - } - if s.selectOutboundPeekStmt, err = db.Prepare(selectOutboundPeekSQL); err != nil { - return - } - if s.selectOutboundPeeksStmt, err = db.Prepare(selectOutboundPeeksSQL); err != nil { - return - } - if s.renewOutboundPeekStmt, err = db.Prepare(renewOutboundPeekSQL); err != nil { - return - } - if s.deleteOutboundPeeksStmt, err = db.Prepare(deleteOutboundPeeksSQL); err != nil { - return - } - if s.deleteOutboundPeekStmt, err = db.Prepare(deleteOutboundPeekSQL); err != nil { - return - } - return + return s, sqlutil.StatementList{ + {&s.insertOutboundPeekStmt, insertOutboundPeekSQL}, + {&s.selectOutboundPeekStmt, selectOutboundPeekSQL}, + {&s.selectOutboundPeeksStmt, selectOutboundPeeksSQL}, + {&s.renewOutboundPeekStmt, renewOutboundPeekSQL}, + {&s.deleteOutboundPeeksStmt, deleteOutboundPeeksSQL}, + {&s.deleteOutboundPeekStmt, deleteOutboundPeekSQL}, + }.Prepare(db) } func (s *outboundPeeksStatements) InsertOutboundPeek( diff --git a/federationapi/storage/sqlite3/queue_edus_table.go b/federationapi/storage/sqlite3/queue_edus_table.go index 8e7e7901f..0dc914328 100644 --- a/federationapi/storage/sqlite3/queue_edus_table.go +++ b/federationapi/storage/sqlite3/queue_edus_table.go @@ -63,10 +63,6 @@ const selectQueueEDUReferenceJSONCountSQL = "" + "SELECT COUNT(*) FROM federationsender_queue_edus" + " WHERE json_nid = $1" -const selectQueueEDUCountSQL = "" + - "SELECT COUNT(*) FROM federationsender_queue_edus" + - " WHERE server_name = $1" - const selectQueueServerNamesSQL = "" + "SELECT DISTINCT server_name FROM federationsender_queue_edus" @@ -82,7 +78,6 @@ type queueEDUsStatements struct { // deleteQueueEDUStmt *sql.Stmt - prepared at runtime due to variadic selectQueueEDUStmt *sql.Stmt selectQueueEDUReferenceJSONCountStmt *sql.Stmt - selectQueueEDUCountStmt *sql.Stmt selectQueueEDUServerNamesStmt *sql.Stmt selectExpiredEDUsStmt *sql.Stmt deleteExpiredEDUsStmt *sql.Stmt @@ -116,7 +111,6 @@ func (s *queueEDUsStatements) Prepare() error { {&s.insertQueueEDUStmt, insertQueueEDUSQL}, {&s.selectQueueEDUStmt, selectQueueEDUSQL}, {&s.selectQueueEDUReferenceJSONCountStmt, selectQueueEDUReferenceJSONCountSQL}, - {&s.selectQueueEDUCountStmt, selectQueueEDUCountSQL}, {&s.selectQueueEDUServerNamesStmt, selectQueueServerNamesSQL}, {&s.selectExpiredEDUsStmt, selectExpiredEDUsSQL}, {&s.deleteExpiredEDUsStmt, deleteExpiredEDUsSQL}, @@ -198,21 +192,6 @@ func (s *queueEDUsStatements) SelectQueueEDUReferenceJSONCount( return count, err } -func (s *queueEDUsStatements) SelectQueueEDUCount( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, -) (int64, error) { - var count int64 - stmt := sqlutil.TxStmt(txn, s.selectQueueEDUCountStmt) - err := stmt.QueryRowContext(ctx, serverName).Scan(&count) - if err == sql.ErrNoRows { - // It's acceptable for there to be no rows referencing a given - // JSON NID but it's not an error condition. Just return as if - // there's a zero count. - return 0, nil - } - return count, err -} - func (s *queueEDUsStatements) SelectQueueEDUServerNames( ctx context.Context, txn *sql.Tx, ) ([]gomatrixserverlib.ServerName, error) { diff --git a/federationapi/storage/sqlite3/queue_pdus_table.go b/federationapi/storage/sqlite3/queue_pdus_table.go index e818585a5..aee8b03d6 100644 --- a/federationapi/storage/sqlite3/queue_pdus_table.go +++ b/federationapi/storage/sqlite3/queue_pdus_table.go @@ -66,10 +66,6 @@ const selectQueuePDUsReferenceJSONCountSQL = "" + "SELECT COUNT(*) FROM federationsender_queue_pdus" + " WHERE json_nid = $1" -const selectQueuePDUsCountSQL = "" + - "SELECT COUNT(*) FROM federationsender_queue_pdus" + - " WHERE server_name = $1" - const selectQueuePDUsServerNamesSQL = "" + "SELECT DISTINCT server_name FROM federationsender_queue_pdus" @@ -79,7 +75,6 @@ type queuePDUsStatements struct { selectQueueNextTransactionIDStmt *sql.Stmt selectQueuePDUsStmt *sql.Stmt selectQueueReferenceJSONCountStmt *sql.Stmt - selectQueuePDUsCountStmt *sql.Stmt selectQueueServerNamesStmt *sql.Stmt // deleteQueuePDUsStmt *sql.Stmt - prepared at runtime due to variadic } @@ -107,9 +102,6 @@ func NewSQLiteQueuePDUsTable(db *sql.DB) (s *queuePDUsStatements, err error) { if s.selectQueueReferenceJSONCountStmt, err = db.Prepare(selectQueuePDUsReferenceJSONCountSQL); err != nil { return } - if s.selectQueuePDUsCountStmt, err = db.Prepare(selectQueuePDUsCountSQL); err != nil { - return - } if s.selectQueueServerNamesStmt, err = db.Prepare(selectQueuePDUsServerNamesSQL); err != nil { return } @@ -179,21 +171,6 @@ func (s *queuePDUsStatements) SelectQueuePDUReferenceJSONCount( return count, err } -func (s *queuePDUsStatements) SelectQueuePDUCount( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, -) (int64, error) { - var count int64 - stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsCountStmt) - err := stmt.QueryRowContext(ctx, serverName).Scan(&count) - if err == sql.ErrNoRows { - // It's acceptable for there to be no rows referencing a given - // JSON NID but it's not an error condition. Just return as if - // there's a zero count. - return 0, nil - } - return count, err -} - func (s *queuePDUsStatements) SelectQueuePDUs( ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, diff --git a/federationapi/storage/storage_test.go b/federationapi/storage/storage_test.go index f7408fa9f..14efa2655 100644 --- a/federationapi/storage/storage_test.go +++ b/federationapi/storage/storage_test.go @@ -2,10 +2,12 @@ package storage_test import ( "context" + "reflect" "testing" "time" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" "github.com/stretchr/testify/assert" "github.com/matrix-org/dendrite/federationapi/storage" @@ -80,3 +82,167 @@ func TestExpireEDUs(t *testing.T) { assert.Equal(t, 2, len(data)) }) } + +func TestOutboundPeeking(t *testing.T) { + alice := test.NewUser(t) + room := test.NewRoom(t, alice) + _, serverName, _ := gomatrixserverlib.SplitID('@', alice.ID) + ctx := context.Background() + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, closeDB := mustCreateFederationDatabase(t, dbType) + defer closeDB() + peekID := util.RandomString(8) + var renewalInterval int64 = 1000 + + // Add outbound peek + if err := db.AddOutboundPeek(ctx, serverName, room.ID, peekID, renewalInterval); err != nil { + t.Fatal(err) + } + + // select the newly inserted peek + outboundPeek1, err := db.GetOutboundPeek(ctx, serverName, room.ID, peekID) + if err != nil { + t.Fatal(err) + } + + // Assert fields are set as expected + if outboundPeek1.PeekID != peekID { + t.Fatalf("unexpected outbound peek ID: %s, want %s", outboundPeek1.PeekID, peekID) + } + if outboundPeek1.RoomID != room.ID { + t.Fatalf("unexpected outbound peek room ID: %s, want %s", outboundPeek1.RoomID, peekID) + } + if outboundPeek1.ServerName != serverName { + t.Fatalf("unexpected outbound peek servername: %s, want %s", outboundPeek1.ServerName, serverName) + } + if outboundPeek1.RenewalInterval != renewalInterval { + t.Fatalf("unexpected outbound peek renewal interval: %d, want %d", outboundPeek1.RenewalInterval, renewalInterval) + } + // Renew the peek + if err = db.RenewOutboundPeek(ctx, serverName, room.ID, peekID, 2000); err != nil { + t.Fatal(err) + } + + // verify the values changed + outboundPeek2, err := db.GetOutboundPeek(ctx, serverName, room.ID, peekID) + if err != nil { + t.Fatal(err) + } + if reflect.DeepEqual(outboundPeek1, outboundPeek2) { + t.Fatal("expected a change peek, but they are the same") + } + if outboundPeek1.ServerName != outboundPeek2.ServerName { + t.Fatalf("unexpected servername change: %s -> %s", outboundPeek1.ServerName, outboundPeek2.ServerName) + } + if outboundPeek1.RoomID != outboundPeek2.RoomID { + t.Fatalf("unexpected roomID change: %s -> %s", outboundPeek1.RoomID, outboundPeek2.RoomID) + } + + // insert some peeks + peekIDs := []string{peekID} + for i := 0; i < 5; i++ { + peekID = util.RandomString(8) + if err = db.AddOutboundPeek(ctx, serverName, room.ID, peekID, 1000); err != nil { + t.Fatal(err) + } + peekIDs = append(peekIDs, peekID) + } + + // Now select them + outboundPeeks, err := db.GetOutboundPeeks(ctx, room.ID) + if err != nil { + t.Fatal(err) + } + if len(outboundPeeks) != len(peekIDs) { + t.Fatalf("inserted %d peeks, selected %d", len(peekIDs), len(outboundPeeks)) + } + for i := range outboundPeeks { + if outboundPeeks[i].PeekID != peekIDs[i] { + t.Fatalf("unexpected peek ID: %s, want %s", outboundPeeks[i].PeekID, peekIDs[i]) + } + } + }) +} + +func TestInboundPeeking(t *testing.T) { + alice := test.NewUser(t) + room := test.NewRoom(t, alice) + _, serverName, _ := gomatrixserverlib.SplitID('@', alice.ID) + ctx := context.Background() + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, closeDB := mustCreateFederationDatabase(t, dbType) + defer closeDB() + peekID := util.RandomString(8) + var renewalInterval int64 = 1000 + + // Add inbound peek + if err := db.AddInboundPeek(ctx, serverName, room.ID, peekID, renewalInterval); err != nil { + t.Fatal(err) + } + + // select the newly inserted peek + inboundPeek1, err := db.GetInboundPeek(ctx, serverName, room.ID, peekID) + if err != nil { + t.Fatal(err) + } + + // Assert fields are set as expected + if inboundPeek1.PeekID != peekID { + t.Fatalf("unexpected inbound peek ID: %s, want %s", inboundPeek1.PeekID, peekID) + } + if inboundPeek1.RoomID != room.ID { + t.Fatalf("unexpected inbound peek room ID: %s, want %s", inboundPeek1.RoomID, peekID) + } + if inboundPeek1.ServerName != serverName { + t.Fatalf("unexpected inbound peek servername: %s, want %s", inboundPeek1.ServerName, serverName) + } + if inboundPeek1.RenewalInterval != renewalInterval { + t.Fatalf("unexpected inbound peek renewal interval: %d, want %d", inboundPeek1.RenewalInterval, renewalInterval) + } + // Renew the peek + if err = db.RenewInboundPeek(ctx, serverName, room.ID, peekID, 2000); err != nil { + t.Fatal(err) + } + + // verify the values changed + inboundPeek2, err := db.GetInboundPeek(ctx, serverName, room.ID, peekID) + if err != nil { + t.Fatal(err) + } + if reflect.DeepEqual(inboundPeek1, inboundPeek2) { + t.Fatal("expected a change peek, but they are the same") + } + if inboundPeek1.ServerName != inboundPeek2.ServerName { + t.Fatalf("unexpected servername change: %s -> %s", inboundPeek1.ServerName, inboundPeek2.ServerName) + } + if inboundPeek1.RoomID != inboundPeek2.RoomID { + t.Fatalf("unexpected roomID change: %s -> %s", inboundPeek1.RoomID, inboundPeek2.RoomID) + } + + // insert some peeks + peekIDs := []string{peekID} + for i := 0; i < 5; i++ { + peekID = util.RandomString(8) + if err = db.AddInboundPeek(ctx, serverName, room.ID, peekID, 1000); err != nil { + t.Fatal(err) + } + peekIDs = append(peekIDs, peekID) + } + + // Now select them + inboundPeeks, err := db.GetInboundPeeks(ctx, room.ID) + if err != nil { + t.Fatal(err) + } + if len(inboundPeeks) != len(peekIDs) { + t.Fatalf("inserted %d peeks, selected %d", len(peekIDs), len(inboundPeeks)) + } + for i := range inboundPeeks { + if inboundPeeks[i].PeekID != peekIDs[i] { + t.Fatalf("unexpected peek ID: %s, want %s", inboundPeeks[i].PeekID, peekIDs[i]) + } + } + }) +} diff --git a/federationapi/storage/tables/inbound_peeks_table_test.go b/federationapi/storage/tables/inbound_peeks_table_test.go new file mode 100644 index 000000000..3a76a8576 --- /dev/null +++ b/federationapi/storage/tables/inbound_peeks_table_test.go @@ -0,0 +1,148 @@ +package tables_test + +import ( + "context" + "reflect" + "testing" + + "github.com/matrix-org/dendrite/federationapi/storage/postgres" + "github.com/matrix-org/dendrite/federationapi/storage/sqlite3" + "github.com/matrix-org/dendrite/federationapi/storage/tables" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" +) + +func mustCreateInboundpeeksTable(t *testing.T, dbType test.DBType) (tables.FederationInboundPeeks, func()) { + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, sqlutil.NewExclusiveWriter()) + if err != nil { + t.Fatalf("failed to open database: %s", err) + } + var tab tables.FederationInboundPeeks + switch dbType { + case test.DBTypePostgres: + tab, err = postgres.NewPostgresInboundPeeksTable(db) + case test.DBTypeSQLite: + tab, err = sqlite3.NewSQLiteInboundPeeksTable(db) + } + if err != nil { + t.Fatalf("failed to create table: %s", err) + } + return tab, close +} + +func TestInboundPeeksTable(t *testing.T) { + ctx := context.Background() + alice := test.NewUser(t) + room := test.NewRoom(t, alice) + _, serverName, _ := gomatrixserverlib.SplitID('@', alice.ID) + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + tab, closeDB := mustCreateInboundpeeksTable(t, dbType) + defer closeDB() + + // Insert a peek + peekID := util.RandomString(8) + var renewalInterval int64 = 1000 + if err := tab.InsertInboundPeek(ctx, nil, serverName, room.ID, peekID, renewalInterval); err != nil { + t.Fatal(err) + } + + // select the newly inserted peek + inboundPeek1, err := tab.SelectInboundPeek(ctx, nil, serverName, room.ID, peekID) + if err != nil { + t.Fatal(err) + } + + // Assert fields are set as expected + if inboundPeek1.PeekID != peekID { + t.Fatalf("unexpected inbound peek ID: %s, want %s", inboundPeek1.PeekID, peekID) + } + if inboundPeek1.RoomID != room.ID { + t.Fatalf("unexpected inbound peek room ID: %s, want %s", inboundPeek1.RoomID, peekID) + } + if inboundPeek1.ServerName != serverName { + t.Fatalf("unexpected inbound peek servername: %s, want %s", inboundPeek1.ServerName, serverName) + } + if inboundPeek1.RenewalInterval != renewalInterval { + t.Fatalf("unexpected inbound peek renewal interval: %d, want %d", inboundPeek1.RenewalInterval, renewalInterval) + } + + // Renew the peek + if err = tab.RenewInboundPeek(ctx, nil, serverName, room.ID, peekID, 2000); err != nil { + t.Fatal(err) + } + + // verify the values changed + inboundPeek2, err := tab.SelectInboundPeek(ctx, nil, serverName, room.ID, peekID) + if err != nil { + t.Fatal(err) + } + if reflect.DeepEqual(inboundPeek1, inboundPeek2) { + t.Fatal("expected a change peek, but they are the same") + } + if inboundPeek1.ServerName != inboundPeek2.ServerName { + t.Fatalf("unexpected servername change: %s -> %s", inboundPeek1.ServerName, inboundPeek2.ServerName) + } + if inboundPeek1.RoomID != inboundPeek2.RoomID { + t.Fatalf("unexpected roomID change: %s -> %s", inboundPeek1.RoomID, inboundPeek2.RoomID) + } + + // delete the peek + if err = tab.DeleteInboundPeek(ctx, nil, serverName, room.ID, peekID); err != nil { + t.Fatal(err) + } + + // There should be no peek anymore + peek, err := tab.SelectInboundPeek(ctx, nil, serverName, room.ID, peekID) + if err != nil { + t.Fatal(err) + } + if peek != nil { + t.Fatalf("got a peek which should be deleted: %+v", peek) + } + + // insert some peeks + var peekIDs []string + for i := 0; i < 5; i++ { + peekID = util.RandomString(8) + if err = tab.InsertInboundPeek(ctx, nil, serverName, room.ID, peekID, 1000); err != nil { + t.Fatal(err) + } + peekIDs = append(peekIDs, peekID) + } + + // Now select them + inboundPeeks, err := tab.SelectInboundPeeks(ctx, nil, room.ID) + if err != nil { + t.Fatal(err) + } + if len(inboundPeeks) != len(peekIDs) { + t.Fatalf("inserted %d peeks, selected %d", len(peekIDs), len(inboundPeeks)) + } + for i := range inboundPeeks { + if inboundPeeks[i].PeekID != peekIDs[i] { + t.Fatalf("") + } + } + + // And delete them again + if err = tab.DeleteInboundPeeks(ctx, nil, room.ID); err != nil { + t.Fatal(err) + } + + // they should be gone now + inboundPeeks, err = tab.SelectInboundPeeks(ctx, nil, room.ID) + if err != nil { + t.Fatal(err) + } + if len(inboundPeeks) > 0 { + t.Fatal("got inbound peeks which should be deleted") + } + + }) +} diff --git a/federationapi/storage/tables/interface.go b/federationapi/storage/tables/interface.go index 9f4e86a6e..2b36edb46 100644 --- a/federationapi/storage/tables/interface.go +++ b/federationapi/storage/tables/interface.go @@ -28,7 +28,6 @@ type FederationQueuePDUs interface { InsertQueuePDU(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64) error DeleteQueuePDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error SelectQueuePDUReferenceJSONCount(ctx context.Context, txn *sql.Tx, jsonNID int64) (int64, error) - SelectQueuePDUCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error) SelectQueuePDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error) SelectQueuePDUServerNames(ctx context.Context, txn *sql.Tx) ([]gomatrixserverlib.ServerName, error) } @@ -38,7 +37,6 @@ type FederationQueueEDUs interface { DeleteQueueEDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error SelectQueueEDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error) SelectQueueEDUReferenceJSONCount(ctx context.Context, txn *sql.Tx, jsonNID int64) (int64, error) - SelectQueueEDUCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error) SelectQueueEDUServerNames(ctx context.Context, txn *sql.Tx) ([]gomatrixserverlib.ServerName, error) SelectExpiredEDUs(ctx context.Context, txn *sql.Tx, expiredBefore gomatrixserverlib.Timestamp) ([]int64, error) DeleteExpiredEDUs(ctx context.Context, txn *sql.Tx, expiredBefore gomatrixserverlib.Timestamp) error diff --git a/federationapi/storage/tables/outbound_peeks_table_test.go b/federationapi/storage/tables/outbound_peeks_table_test.go new file mode 100644 index 000000000..dad6b9825 --- /dev/null +++ b/federationapi/storage/tables/outbound_peeks_table_test.go @@ -0,0 +1,147 @@ +package tables_test + +import ( + "context" + "reflect" + "testing" + + "github.com/matrix-org/dendrite/federationapi/storage/postgres" + "github.com/matrix-org/dendrite/federationapi/storage/sqlite3" + "github.com/matrix-org/dendrite/federationapi/storage/tables" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" +) + +func mustCreateOutboundpeeksTable(t *testing.T, dbType test.DBType) (tables.FederationOutboundPeeks, func()) { + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, sqlutil.NewExclusiveWriter()) + if err != nil { + t.Fatalf("failed to open database: %s", err) + } + var tab tables.FederationOutboundPeeks + switch dbType { + case test.DBTypePostgres: + tab, err = postgres.NewPostgresOutboundPeeksTable(db) + case test.DBTypeSQLite: + tab, err = sqlite3.NewSQLiteOutboundPeeksTable(db) + } + if err != nil { + t.Fatalf("failed to create table: %s", err) + } + return tab, close +} + +func TestOutboundPeeksTable(t *testing.T) { + ctx := context.Background() + alice := test.NewUser(t) + room := test.NewRoom(t, alice) + _, serverName, _ := gomatrixserverlib.SplitID('@', alice.ID) + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + tab, closeDB := mustCreateOutboundpeeksTable(t, dbType) + defer closeDB() + + // Insert a peek + peekID := util.RandomString(8) + var renewalInterval int64 = 1000 + if err := tab.InsertOutboundPeek(ctx, nil, serverName, room.ID, peekID, renewalInterval); err != nil { + t.Fatal(err) + } + + // select the newly inserted peek + outboundPeek1, err := tab.SelectOutboundPeek(ctx, nil, serverName, room.ID, peekID) + if err != nil { + t.Fatal(err) + } + + // Assert fields are set as expected + if outboundPeek1.PeekID != peekID { + t.Fatalf("unexpected outbound peek ID: %s, want %s", outboundPeek1.PeekID, peekID) + } + if outboundPeek1.RoomID != room.ID { + t.Fatalf("unexpected outbound peek room ID: %s, want %s", outboundPeek1.RoomID, peekID) + } + if outboundPeek1.ServerName != serverName { + t.Fatalf("unexpected outbound peek servername: %s, want %s", outboundPeek1.ServerName, serverName) + } + if outboundPeek1.RenewalInterval != renewalInterval { + t.Fatalf("unexpected outbound peek renewal interval: %d, want %d", outboundPeek1.RenewalInterval, renewalInterval) + } + + // Renew the peek + if err = tab.RenewOutboundPeek(ctx, nil, serverName, room.ID, peekID, 2000); err != nil { + t.Fatal(err) + } + + // verify the values changed + outboundPeek2, err := tab.SelectOutboundPeek(ctx, nil, serverName, room.ID, peekID) + if err != nil { + t.Fatal(err) + } + if reflect.DeepEqual(outboundPeek1, outboundPeek2) { + t.Fatal("expected a change peek, but they are the same") + } + if outboundPeek1.ServerName != outboundPeek2.ServerName { + t.Fatalf("unexpected servername change: %s -> %s", outboundPeek1.ServerName, outboundPeek2.ServerName) + } + if outboundPeek1.RoomID != outboundPeek2.RoomID { + t.Fatalf("unexpected roomID change: %s -> %s", outboundPeek1.RoomID, outboundPeek2.RoomID) + } + + // delete the peek + if err = tab.DeleteOutboundPeek(ctx, nil, serverName, room.ID, peekID); err != nil { + t.Fatal(err) + } + + // There should be no peek anymore + peek, err := tab.SelectOutboundPeek(ctx, nil, serverName, room.ID, peekID) + if err != nil { + t.Fatal(err) + } + if peek != nil { + t.Fatalf("got a peek which should be deleted: %+v", peek) + } + + // insert some peeks + var peekIDs []string + for i := 0; i < 5; i++ { + peekID = util.RandomString(8) + if err = tab.InsertOutboundPeek(ctx, nil, serverName, room.ID, peekID, 1000); err != nil { + t.Fatal(err) + } + peekIDs = append(peekIDs, peekID) + } + + // Now select them + outboundPeeks, err := tab.SelectOutboundPeeks(ctx, nil, room.ID) + if err != nil { + t.Fatal(err) + } + if len(outboundPeeks) != len(peekIDs) { + t.Fatalf("inserted %d peeks, selected %d", len(peekIDs), len(outboundPeeks)) + } + for i := range outboundPeeks { + if outboundPeeks[i].PeekID != peekIDs[i] { + t.Fatalf("") + } + } + + // And delete them again + if err = tab.DeleteOutboundPeeks(ctx, nil, room.ID); err != nil { + t.Fatal(err) + } + + // they should be gone now + outboundPeeks, err = tab.SelectOutboundPeeks(ctx, nil, room.ID) + if err != nil { + t.Fatal(err) + } + if len(outboundPeeks) > 0 { + t.Fatal("got outbound peeks which should be deleted") + } + }) +} From beea2432e6144a98370138f8d3f6334c19a044bb Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Thu, 22 Dec 2022 11:31:54 +0100 Subject: [PATCH 12/15] Fix flakey test --- federationapi/storage/tables/inbound_peeks_table_test.go | 9 +++++---- .../storage/tables/outbound_peeks_table_test.go | 9 +++++---- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/federationapi/storage/tables/inbound_peeks_table_test.go b/federationapi/storage/tables/inbound_peeks_table_test.go index 3a76a8576..e5d898b3a 100644 --- a/federationapi/storage/tables/inbound_peeks_table_test.go +++ b/federationapi/storage/tables/inbound_peeks_table_test.go @@ -13,6 +13,7 @@ import ( "github.com/matrix-org/dendrite/test" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" + "github.com/stretchr/testify/assert" ) func mustCreateInboundpeeksTable(t *testing.T, dbType test.DBType) (tables.FederationInboundPeeks, func()) { @@ -124,11 +125,11 @@ func TestInboundPeeksTable(t *testing.T) { if len(inboundPeeks) != len(peekIDs) { t.Fatalf("inserted %d peeks, selected %d", len(peekIDs), len(inboundPeeks)) } - for i := range inboundPeeks { - if inboundPeeks[i].PeekID != peekIDs[i] { - t.Fatalf("") - } + gotPeekIDs := make([]string, 0, len(inboundPeeks)) + for _, p := range inboundPeeks { + gotPeekIDs = append(gotPeekIDs, p.PeekID) } + assert.ElementsMatch(t, gotPeekIDs, peekIDs) // And delete them again if err = tab.DeleteInboundPeeks(ctx, nil, room.ID); err != nil { diff --git a/federationapi/storage/tables/outbound_peeks_table_test.go b/federationapi/storage/tables/outbound_peeks_table_test.go index dad6b9825..a460af09d 100644 --- a/federationapi/storage/tables/outbound_peeks_table_test.go +++ b/federationapi/storage/tables/outbound_peeks_table_test.go @@ -13,6 +13,7 @@ import ( "github.com/matrix-org/dendrite/test" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" + "github.com/stretchr/testify/assert" ) func mustCreateOutboundpeeksTable(t *testing.T, dbType test.DBType) (tables.FederationOutboundPeeks, func()) { @@ -124,11 +125,11 @@ func TestOutboundPeeksTable(t *testing.T) { if len(outboundPeeks) != len(peekIDs) { t.Fatalf("inserted %d peeks, selected %d", len(peekIDs), len(outboundPeeks)) } - for i := range outboundPeeks { - if outboundPeeks[i].PeekID != peekIDs[i] { - t.Fatalf("") - } + gotPeekIDs := make([]string, 0, len(outboundPeeks)) + for _, p := range outboundPeeks { + gotPeekIDs = append(gotPeekIDs, p.PeekID) } + assert.ElementsMatch(t, gotPeekIDs, peekIDs) // And delete them again if err = tab.DeleteOutboundPeeks(ctx, nil, room.ID); err != nil { From d1d2d16738a248846ea4367fe2b33485d56db6cd Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Thu, 22 Dec 2022 11:54:03 +0100 Subject: [PATCH 13/15] Fix reset password endpoint (#2921) Fixes the admin password reset endpoint. It was using a wrong variable, so could not detect the user. Adds some more checks to validate we can actually change the password. --- clientapi/admin_test.go | 134 ++++++++++++++++++++++++++++++ clientapi/clientapi.go | 5 +- clientapi/routing/admin.go | 38 +++++++-- clientapi/routing/password.go | 3 +- clientapi/routing/register.go | 24 +----- clientapi/routing/routing.go | 14 +++- docs/administration/4_adminapi.md | 7 +- internal/httputil/httpapi.go | 6 +- internal/validate.go | 44 ++++++++++ test/user.go | 4 +- 10 files changed, 237 insertions(+), 42 deletions(-) create mode 100644 clientapi/admin_test.go create mode 100644 internal/validate.go diff --git a/clientapi/admin_test.go b/clientapi/admin_test.go new file mode 100644 index 000000000..0d973f350 --- /dev/null +++ b/clientapi/admin_test.go @@ -0,0 +1,134 @@ +package clientapi + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/keyserver" + "github.com/matrix-org/dendrite/roomserver" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/tidwall/gjson" + + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + "github.com/matrix-org/dendrite/userapi" + uapi "github.com/matrix-org/dendrite/userapi/api" +) + +func TestAdminResetPassword(t *testing.T) { + aliceAdmin := test.NewUser(t, test.WithAccountType(uapi.AccountTypeAdmin)) + bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser)) + vhUser := &test.User{ID: "@vhuser:vh1"} + + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, baseClose := testrig.CreateBaseDendrite(t, dbType) + defer baseClose() + + // add a vhost + base.Cfg.Global.VirtualHosts = append(base.Cfg.Global.VirtualHosts, &config.VirtualHost{ + SigningIdentity: gomatrixserverlib.SigningIdentity{ServerName: "vh1"}, + }) + + rsAPI := roomserver.NewInternalAPI(base) + // Needed for changing the password/login + keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI) + userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil) + keyAPI.SetUserAPI(userAPI) + // We mostly need the userAPI for this test, so nil for other APIs/caches etc. + AddPublicRoutes(base, nil, nil, nil, nil, nil, userAPI, nil, nil, nil) + + // Create the users in the userapi and login + accessTokens := map[*test.User]string{ + aliceAdmin: "", + bob: "", + vhUser: "", + } + for u := range accessTokens { + localpart, serverName, _ := gomatrixserverlib.SplitID('@', u.ID) + userRes := &uapi.PerformAccountCreationResponse{} + password := util.RandomString(8) + if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{ + AccountType: u.AccountType, + Localpart: localpart, + ServerName: serverName, + Password: password, + }, userRes); err != nil { + t.Errorf("failed to create account: %s", err) + } + + req := test.NewRequest(t, http.MethodPost, "/_matrix/client/v3/login", test.WithJSONBody(t, map[string]interface{}{ + "type": authtypes.LoginTypePassword, + "identifier": map[string]interface{}{ + "type": "m.id.user", + "user": u.ID, + }, + "password": password, + })) + rec := httptest.NewRecorder() + base.PublicClientAPIMux.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("failed to login: %s", rec.Body.String()) + } + accessTokens[u] = gjson.GetBytes(rec.Body.Bytes(), "access_token").String() + } + + testCases := []struct { + name string + requestingUser *test.User + userID string + requestOpt test.HTTPRequestOpt + wantOK bool + withHeader bool + }{ + {name: "Missing auth", requestingUser: bob, wantOK: false, userID: bob.ID}, + {name: "Bob is denied access", requestingUser: bob, wantOK: false, withHeader: true, userID: bob.ID}, + {name: "Alice is allowed access", requestingUser: aliceAdmin, wantOK: true, withHeader: true, userID: bob.ID, requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "password": util.RandomString(8), + })}, + {name: "missing userID does not call function", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: ""}, // this 404s + {name: "rejects empty password", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: bob.ID, requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "password": "", + })}, + {name: "rejects unknown server name", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: "@doesnotexist:localhost", requestOpt: test.WithJSONBody(t, map[string]interface{}{})}, + {name: "rejects unknown user", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: "@doesnotexist:test", requestOpt: test.WithJSONBody(t, map[string]interface{}{})}, + {name: "allows changing password for different vhost", requestingUser: aliceAdmin, wantOK: true, withHeader: true, userID: vhUser.ID, requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "password": util.RandomString(8), + })}, + {name: "rejects existing user, missing body", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: bob.ID}, + {name: "rejects invalid userID", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: "!notauserid:test", requestOpt: test.WithJSONBody(t, map[string]interface{}{})}, + {name: "rejects invalid json", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: bob.ID, requestOpt: test.WithJSONBody(t, `{invalidJSON}`)}, + {name: "rejects too weak password", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: bob.ID, requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "password": util.RandomString(6), + })}, + {name: "rejects too long password", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: bob.ID, requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "password": util.RandomString(513), + })}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req := test.NewRequest(t, http.MethodPost, "/_dendrite/admin/resetPassword/"+tc.userID) + if tc.requestOpt != nil { + req = test.NewRequest(t, http.MethodPost, "/_dendrite/admin/resetPassword/"+tc.userID, tc.requestOpt) + } + + if tc.withHeader { + req.Header.Set("Authorization", "Bearer "+accessTokens[tc.requestingUser]) + } + + rec := httptest.NewRecorder() + base.DendriteAdminMux.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + if tc.wantOK && rec.Code != http.StatusOK { + t.Fatalf("expected http status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String()) + } + }) + } + }) +} diff --git a/clientapi/clientapi.go b/clientapi/clientapi.go index 080d4d9fa..62ffa6155 100644 --- a/clientapi/clientapi.go +++ b/clientapi/clientapi.go @@ -57,10 +57,7 @@ func AddPublicRoutes( } routing.Setup( - base.PublicClientAPIMux, - base.PublicWellKnownAPIMux, - base.SynapseAdminMux, - base.DendriteAdminMux, + base, cfg, rsAPI, asAPI, userAPI, userDirectoryProvider, federation, syncProducer, transactionsCache, fsAPI, keyAPI, diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index be8073c33..8419622df 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -7,6 +7,7 @@ import ( "time" "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/nats-io/nats.go" @@ -98,20 +99,40 @@ func AdminEvacuateUser(req *http.Request, cfg *config.ClientAPI, device *userapi } func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse { + if req.Body == nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.Unknown("Missing request body"), + } + } vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } - serverName := cfg.Matrix.ServerName - localpart, ok := vars["localpart"] - if !ok { + var localpart string + userID := vars["userID"] + localpart, serverName, err := cfg.Matrix.SplitLocalID('@', userID) + if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.MissingArgument("Expecting user localpart."), + JSON: jsonerror.InvalidArgumentValue(err.Error()), } } - if l, s, err := cfg.Matrix.SplitLocalID('@', localpart); err == nil { - localpart, serverName = l, s + accAvailableResp := &userapi.QueryAccountAvailabilityResponse{} + if err = userAPI.QueryAccountAvailability(req.Context(), &userapi.QueryAccountAvailabilityRequest{ + Localpart: localpart, + ServerName: serverName, + }, accAvailableResp); err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: jsonerror.InternalAPIError(req.Context(), err), + } + } + if accAvailableResp.Available { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: jsonerror.Unknown("User does not exist"), + } } request := struct { Password string `json:"password"` @@ -128,6 +149,11 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap JSON: jsonerror.MissingArgument("Expecting non-empty password."), } } + + if resErr := internal.ValidatePassword(request.Password); resErr != nil { + return *resErr + } + updateReq := &userapi.PerformPasswordUpdateRequest{ Localpart: localpart, ServerName: serverName, diff --git a/clientapi/routing/password.go b/clientapi/routing/password.go index 9772f669a..cd88b025a 100644 --- a/clientapi/routing/password.go +++ b/clientapi/routing/password.go @@ -7,6 +7,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" @@ -81,7 +82,7 @@ func Password( sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword) // Check the new password strength. - if resErr = validatePassword(r.NewPassword); resErr != nil { + if resErr = internal.ValidatePassword(r.NewPassword); resErr != nil { return *resErr } diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index 801000f61..4abbcdf9e 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -30,6 +30,7 @@ import ( "sync" "time" + "github.com/matrix-org/dendrite/internal" "github.com/tidwall/gjson" "github.com/matrix-org/dendrite/internal/eventutil" @@ -60,8 +61,6 @@ var ( ) const ( - minPasswordLength = 8 // http://matrix.org/docs/spec/client_server/r0.2.0.html#password-based - maxPasswordLength = 512 // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161 maxUsernameLength = 254 // http://matrix.org/speculator/spec/HEAD/intro.html#user-identifiers TODO account for domain sessionIDLength = 24 ) @@ -315,23 +314,6 @@ func validateApplicationServiceUsername(localpart string, domain gomatrixserverl return nil } -// validatePassword returns an error response if the password is invalid -func validatePassword(password string) *util.JSONResponse { - // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161 - if len(password) > maxPasswordLength { - return &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(fmt.Sprintf("'password' >%d characters", maxPasswordLength)), - } - } else if len(password) > 0 && len(password) < minPasswordLength { - return &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.WeakPassword(fmt.Sprintf("password too weak: min %d chars", minPasswordLength)), - } - } - return nil -} - // validateRecaptcha returns an error response if the captcha response is invalid func validateRecaptcha( cfg *config.ClientAPI, @@ -636,7 +618,7 @@ func Register( return *resErr } } - if resErr := validatePassword(r.Password); resErr != nil { + if resErr := internal.ValidatePassword(r.Password); resErr != nil { return *resErr } @@ -1138,7 +1120,7 @@ func handleSharedSecretRegistration(cfg *config.ClientAPI, userAPI userapi.Clien if resErr := validateUsername(ssrr.User, cfg.Matrix.ServerName); resErr != nil { return *resErr } - if resErr := validatePassword(ssrr.Password); resErr != nil { + if resErr := internal.ValidatePassword(ssrr.Password); resErr != nil { return *resErr } deviceID := "shared_secret_registration" diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index a510761eb..69b46214c 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -20,6 +20,7 @@ import ( "strings" "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/nats-io/nats.go" @@ -49,7 +50,7 @@ import ( // applied: // nolint: gocyclo func Setup( - publicAPIMux, wkMux, synapseAdminRouter, dendriteAdminRouter *mux.Router, + base *base.BaseDendrite, cfg *config.ClientAPI, rsAPI roomserverAPI.ClientRoomserverAPI, asAPI appserviceAPI.AppServiceInternalAPI, @@ -63,7 +64,14 @@ func Setup( extRoomsProvider api.ExtraPublicRoomsProvider, mscCfg *config.MSCs, natsClient *nats.Conn, ) { - prometheus.MustRegister(amtRegUsers, sendEventDuration) + publicAPIMux := base.PublicClientAPIMux + wkMux := base.PublicWellKnownAPIMux + synapseAdminRouter := base.SynapseAdminMux + dendriteAdminRouter := base.DendriteAdminMux + + if base.EnableMetrics { + prometheus.MustRegister(amtRegUsers, sendEventDuration) + } rateLimits := httputil.NewRateLimits(&cfg.RateLimiting) userInteractiveAuth := auth.NewUserInteractive(userAPI, cfg) @@ -631,7 +639,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) v3mux.Handle("/auth/{authType}/fallback/web", - httputil.MakeHTMLAPI("auth_fallback", func(w http.ResponseWriter, req *http.Request) *util.JSONResponse { + httputil.MakeHTMLAPI("auth_fallback", base.EnableMetrics, func(w http.ResponseWriter, req *http.Request) *util.JSONResponse { vars := mux.Vars(req) return AuthFallback(w, req, vars["authType"], cfg) }), diff --git a/docs/administration/4_adminapi.md b/docs/administration/4_adminapi.md index 56e19a8b4..c521cbc90 100644 --- a/docs/administration/4_adminapi.md +++ b/docs/administration/4_adminapi.md @@ -44,7 +44,9 @@ This endpoint will instruct Dendrite to part the given local `userID` in the URL all rooms which they are currently joined. A JSON body will be returned containing the room IDs of all affected rooms. -## POST `/_dendrite/admin/resetPassword/{localpart}` +## POST `/_dendrite/admin/resetPassword/{userID}` + +Reset the password of a local user. Request body format: @@ -54,9 +56,6 @@ Request body format: } ``` -Reset the password of a local user. The `localpart` is the username only, i.e. if -the full user ID is `@alice:domain.com` then the local part is `alice`. - ## GET `/_dendrite/admin/fulltext/reindex` This endpoint instructs Dendrite to reindex all searchable events (`m.room.message`, `m.room.topic` and `m.room.name`). An empty JSON body will be returned immediately. diff --git a/internal/httputil/httpapi.go b/internal/httputil/httpapi.go index 127d1fac7..383913c60 100644 --- a/internal/httputil/httpapi.go +++ b/internal/httputil/httpapi.go @@ -198,7 +198,7 @@ func MakeExternalAPI(metricsName string, f func(*http.Request) util.JSONResponse // MakeHTMLAPI adds Span metrics to the HTML Handler function // This is used to serve HTML alongside JSON error messages -func MakeHTMLAPI(metricsName string, f func(http.ResponseWriter, *http.Request) *util.JSONResponse) http.Handler { +func MakeHTMLAPI(metricsName string, enableMetrics bool, f func(http.ResponseWriter, *http.Request) *util.JSONResponse) http.Handler { withSpan := func(w http.ResponseWriter, req *http.Request) { span := opentracing.StartSpan(metricsName) defer span.Finish() @@ -211,6 +211,10 @@ func MakeHTMLAPI(metricsName string, f func(http.ResponseWriter, *http.Request) } } + if !enableMetrics { + return http.HandlerFunc(withSpan) + } + return promhttp.InstrumentHandlerCounter( promauto.NewCounterVec( prometheus.CounterOpts{ diff --git a/internal/validate.go b/internal/validate.go new file mode 100644 index 000000000..fc685ad50 --- /dev/null +++ b/internal/validate.go @@ -0,0 +1,44 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "fmt" + "net/http" + + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/util" +) + +const minPasswordLength = 8 // http://matrix.org/docs/spec/client_server/r0.2.0.html#password-based + +const maxPasswordLength = 512 // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161 + +// ValidatePassword returns an error response if the password is invalid +func ValidatePassword(password string) *util.JSONResponse { + // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161 + if len(password) > maxPasswordLength { + return &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON(fmt.Sprintf("password too long: max %d characters", maxPasswordLength)), + } + } else if len(password) > 0 && len(password) < minPasswordLength { + return &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.WeakPassword(fmt.Sprintf("password too weak: min %d chars", minPasswordLength)), + } + } + return nil +} diff --git a/test/user.go b/test/user.go index 692eae351..95a8f83e6 100644 --- a/test/user.go +++ b/test/user.go @@ -47,7 +47,7 @@ var ( type User struct { ID string - accountType api.AccountType + AccountType api.AccountType // key ID and private key of the server who has this user, if known. keyID gomatrixserverlib.KeyID privKey ed25519.PrivateKey @@ -66,7 +66,7 @@ func WithSigningServer(srvName gomatrixserverlib.ServerName, keyID gomatrixserve func WithAccountType(accountType api.AccountType) UserOpt { return func(u *User) { - u.accountType = accountType + u.AccountType = accountType } } From 09dff951d6be1fee1cc7c6872e98eb27e81fc778 Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Thu, 22 Dec 2022 13:04:32 +0100 Subject: [PATCH 14/15] More flakey tests --- federationapi/storage/storage_test.go | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/federationapi/storage/storage_test.go b/federationapi/storage/storage_test.go index 14efa2655..5b57d40d4 100644 --- a/federationapi/storage/storage_test.go +++ b/federationapi/storage/storage_test.go @@ -157,11 +157,11 @@ func TestOutboundPeeking(t *testing.T) { if len(outboundPeeks) != len(peekIDs) { t.Fatalf("inserted %d peeks, selected %d", len(peekIDs), len(outboundPeeks)) } - for i := range outboundPeeks { - if outboundPeeks[i].PeekID != peekIDs[i] { - t.Fatalf("unexpected peek ID: %s, want %s", outboundPeeks[i].PeekID, peekIDs[i]) - } + gotPeekIDs := make([]string, 0, len(outboundPeeks)) + for _, p := range outboundPeeks { + gotPeekIDs = append(gotPeekIDs, p.PeekID) } + assert.ElementsMatch(t, gotPeekIDs, peekIDs) }) } @@ -239,10 +239,10 @@ func TestInboundPeeking(t *testing.T) { if len(inboundPeeks) != len(peekIDs) { t.Fatalf("inserted %d peeks, selected %d", len(peekIDs), len(inboundPeeks)) } - for i := range inboundPeeks { - if inboundPeeks[i].PeekID != peekIDs[i] { - t.Fatalf("unexpected peek ID: %s, want %s", inboundPeeks[i].PeekID, peekIDs[i]) - } + gotPeekIDs := make([]string, 0, len(inboundPeeks)) + for _, p := range inboundPeeks { + gotPeekIDs = append(gotPeekIDs, p.PeekID) } + assert.ElementsMatch(t, gotPeekIDs, peekIDs) }) } From 5eed31fea330f5f0500384c98272b9a75a44fba4 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Thu, 22 Dec 2022 13:05:59 +0100 Subject: [PATCH 15/15] Handle guest access [1/2?] (#2872) Needs https://github.com/matrix-org/sytest/pull/1315, as otherwise the membership events aren't persisted yet when hitting `/state` after kicking guest users. Makes the following tests pass: ``` Guest users denied access over federation if guest access prohibited Guest users are kicked from guest_access rooms on revocation of guest_access Guest users are kicked from guest_access rooms on revocation of guest_access over federation ``` Todo (in a follow up PR): - Restrict access to CS API Endpoints as per https://spec.matrix.org/v1.4/client-server-api/#client-behaviour-14 Co-authored-by: kegsay --- clientapi/clientapi.go | 3 +- clientapi/routing/joinroom.go | 10 +- clientapi/routing/joinroom_test.go | 158 ++++++++++++++++++++ roomserver/api/perform.go | 1 + roomserver/internal/api.go | 20 ++- roomserver/internal/input/input.go | 4 + roomserver/internal/input/input_events.go | 105 +++++++++++++ roomserver/internal/perform/perform_join.go | 23 +++ roomserver/roomserver_test.go | 141 +++++++++++++---- setup/config/config_global.go | 2 +- setup/config/config_test.go | 54 +++++++ sytest-blacklist | 5 +- sytest-whitelist | 5 +- test/room.go | 22 ++- userapi/api/api.go | 10 ++ userapi/api/api_trace.go | 6 + userapi/internal/api.go | 5 + userapi/inthttp/client.go | 12 ++ userapi/inthttp/server.go | 5 + userapi/userapi_test.go | 61 ++++++++ 20 files changed, 607 insertions(+), 45 deletions(-) create mode 100644 clientapi/routing/joinroom_test.go diff --git a/clientapi/clientapi.go b/clientapi/clientapi.go index 62ffa6155..2d17e0928 100644 --- a/clientapi/clientapi.go +++ b/clientapi/clientapi.go @@ -15,6 +15,8 @@ package clientapi import ( + "github.com/matrix-org/gomatrixserverlib" + appserviceAPI "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/clientapi/api" "github.com/matrix-org/dendrite/clientapi/producers" @@ -26,7 +28,6 @@ import ( "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/jetstream" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" ) // AddPublicRoutes sets up and registers HTTP handlers for the ClientAPI component. diff --git a/clientapi/routing/joinroom.go b/clientapi/routing/joinroom.go index c50e552bd..e371d9214 100644 --- a/clientapi/routing/joinroom.go +++ b/clientapi/routing/joinroom.go @@ -37,6 +37,7 @@ func JoinRoomByIDOrAlias( joinReq := roomserverAPI.PerformJoinRequest{ RoomIDOrAlias: roomIDOrAlias, UserID: device.UserID, + IsGuest: device.AccountType == api.AccountTypeGuest, Content: map[string]interface{}{}, } joinRes := roomserverAPI.PerformJoinResponse{} @@ -84,7 +85,14 @@ func JoinRoomByIDOrAlias( if err := rsAPI.PerformJoin(req.Context(), &joinReq, &joinRes); err != nil { done <- jsonerror.InternalAPIError(req.Context(), err) } else if joinRes.Error != nil { - done <- joinRes.Error.JSONResponse() + if joinRes.Error.Code == roomserverAPI.PerformErrorNotAllowed && device.AccountType == api.AccountTypeGuest { + done <- util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.GuestAccessForbidden(joinRes.Error.Msg), + } + } else { + done <- joinRes.Error.JSONResponse() + } } else { done <- util.JSONResponse{ Code: http.StatusOK, diff --git a/clientapi/routing/joinroom_test.go b/clientapi/routing/joinroom_test.go new file mode 100644 index 000000000..9e8208e6d --- /dev/null +++ b/clientapi/routing/joinroom_test.go @@ -0,0 +1,158 @@ +package routing + +import ( + "bytes" + "context" + "net/http" + "testing" + "time" + + "github.com/matrix-org/gomatrixserverlib" + + "github.com/matrix-org/dendrite/appservice" + "github.com/matrix-org/dendrite/keyserver" + "github.com/matrix-org/dendrite/roomserver" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + "github.com/matrix-org/dendrite/userapi" + uapi "github.com/matrix-org/dendrite/userapi/api" +) + +func TestJoinRoomByIDOrAlias(t *testing.T) { + alice := test.NewUser(t) + bob := test.NewUser(t) + charlie := test.NewUser(t, test.WithAccountType(uapi.AccountTypeGuest)) + + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, baseClose := testrig.CreateBaseDendrite(t, dbType) + defer baseClose() + + rsAPI := roomserver.NewInternalAPI(base) + keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI) + userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil) + asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI) + rsAPI.SetFederationAPI(nil, nil) // creates the rs.Inputer etc + + // Create the users in the userapi + for _, u := range []*test.User{alice, bob, charlie} { + localpart, serverName, _ := gomatrixserverlib.SplitID('@', u.ID) + userRes := &uapi.PerformAccountCreationResponse{} + if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{ + AccountType: u.AccountType, + Localpart: localpart, + ServerName: serverName, + Password: "someRandomPassword", + }, userRes); err != nil { + t.Errorf("failed to create account: %s", err) + } + + } + + aliceDev := &uapi.Device{UserID: alice.ID} + bobDev := &uapi.Device{UserID: bob.ID} + charlieDev := &uapi.Device{UserID: charlie.ID, AccountType: uapi.AccountTypeGuest} + + // create a room with disabled guest access and invite Bob + resp := createRoom(ctx, createRoomRequest{ + Name: "testing", + IsDirect: true, + Topic: "testing", + Visibility: "public", + Preset: presetPublicChat, + RoomAliasName: "alias", + Invite: []string{bob.ID}, + GuestCanJoin: false, + }, aliceDev, &base.Cfg.ClientAPI, userAPI, rsAPI, asAPI, time.Now()) + crResp, ok := resp.JSON.(createRoomResponse) + if !ok { + t.Fatalf("response is not a createRoomResponse: %+v", resp) + } + + // create a room with guest access enabled and invite Charlie + resp = createRoom(ctx, createRoomRequest{ + Name: "testing", + IsDirect: true, + Topic: "testing", + Visibility: "public", + Preset: presetPublicChat, + Invite: []string{charlie.ID}, + GuestCanJoin: true, + }, aliceDev, &base.Cfg.ClientAPI, userAPI, rsAPI, asAPI, time.Now()) + crRespWithGuestAccess, ok := resp.JSON.(createRoomResponse) + if !ok { + t.Fatalf("response is not a createRoomResponse: %+v", resp) + } + + // Dummy request + body := &bytes.Buffer{} + req, err := http.NewRequest(http.MethodPost, "/?server_name=test", body) + if err != nil { + t.Fatal(err) + } + + testCases := []struct { + name string + device *uapi.Device + roomID string + wantHTTP200 bool + }{ + { + name: "User can join successfully by alias", + device: bobDev, + roomID: crResp.RoomAlias, + wantHTTP200: true, + }, + { + name: "User can join successfully by roomID", + device: bobDev, + roomID: crResp.RoomID, + wantHTTP200: true, + }, + { + name: "join is forbidden if user is guest", + device: charlieDev, + roomID: crResp.RoomID, + }, + { + name: "room does not exist", + device: aliceDev, + roomID: "!doesnotexist:test", + }, + { + name: "user from different server", + device: &uapi.Device{UserID: "@wrong:server"}, + roomID: crResp.RoomAlias, + }, + { + name: "user doesn't exist locally", + device: &uapi.Device{UserID: "@doesnotexist:test"}, + roomID: crResp.RoomAlias, + }, + { + name: "invalid room ID", + device: aliceDev, + roomID: "invalidRoomID", + }, + { + name: "roomAlias does not exist", + device: aliceDev, + roomID: "#doesnotexist:test", + }, + { + name: "room with guest_access event", + device: charlieDev, + roomID: crRespWithGuestAccess.RoomID, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + joinResp := JoinRoomByIDOrAlias(req, tc.device, rsAPI, userAPI, tc.roomID) + if tc.wantHTTP200 && !joinResp.Is2xx() { + t.Fatalf("expected join room to succeed, but didn't: %+v", joinResp) + } + }) + } + }) +} diff --git a/roomserver/api/perform.go b/roomserver/api/perform.go index e70e5ea9c..e789b9568 100644 --- a/roomserver/api/perform.go +++ b/roomserver/api/perform.go @@ -78,6 +78,7 @@ const ( type PerformJoinRequest struct { RoomIDOrAlias string `json:"room_id_or_alias"` UserID string `json:"user_id"` + IsGuest bool `json:"is_guest"` Content map[string]interface{} `json:"content"` ServerNames []gomatrixserverlib.ServerName `json:"server_names"` Unsigned map[string]interface{} `json:"unsigned"` diff --git a/roomserver/internal/api.go b/roomserver/internal/api.go index 1a3626609..451b37696 100644 --- a/roomserver/internal/api.go +++ b/roomserver/internal/api.go @@ -4,6 +4,10 @@ import ( "context" "github.com/getsentry/sentry-go" + "github.com/matrix-org/gomatrixserverlib" + "github.com/nats-io/nats.go" + "github.com/sirupsen/logrus" + asAPI "github.com/matrix-org/dendrite/appservice/api" fsAPI "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal/caching" @@ -19,9 +23,6 @@ import ( "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" - "github.com/nats-io/nats.go" - "github.com/sirupsen/logrus" ) // RoomserverInternalAPI is an implementation of api.RoomserverInternalAPI @@ -104,6 +105,11 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio r.fsAPI = fsAPI r.KeyRing = keyRing + identity, err := r.Cfg.Matrix.SigningIdentityFor(r.ServerName) + if err != nil { + logrus.Panic(err) + } + r.Inputer = &input.Inputer{ Cfg: &r.Base.Cfg.RoomServer, Base: r.Base, @@ -114,7 +120,8 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio JetStream: r.JetStream, NATSClient: r.NATSClient, Durable: nats.Durable(r.Durable), - ServerName: r.Cfg.Matrix.ServerName, + ServerName: r.ServerName, + SigningIdentity: identity, FSAPI: fsAPI, KeyRing: keyRing, ACLs: r.ServerACLs, @@ -135,7 +142,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio Queryer: r.Queryer, } r.Peeker = &perform.Peeker{ - ServerName: r.Cfg.Matrix.ServerName, + ServerName: r.ServerName, Cfg: r.Cfg, DB: r.DB, FSAPI: r.fsAPI, @@ -146,7 +153,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio Inputer: r.Inputer, } r.Unpeeker = &perform.Unpeeker{ - ServerName: r.Cfg.Matrix.ServerName, + ServerName: r.ServerName, Cfg: r.Cfg, DB: r.DB, FSAPI: r.fsAPI, @@ -193,6 +200,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio func (r *RoomserverInternalAPI) SetUserAPI(userAPI userapi.RoomserverUserAPI) { r.Leaver.UserAPI = userAPI + r.Inputer.UserAPI = userAPI } func (r *RoomserverInternalAPI) SetAppserviceAPI(asAPI asAPI.AppServiceInternalAPI) { diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go index e965691c9..941311030 100644 --- a/roomserver/internal/input/input.go +++ b/roomserver/internal/input/input.go @@ -23,6 +23,8 @@ import ( "sync" "time" + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/Arceliar/phony" "github.com/getsentry/sentry-go" "github.com/matrix-org/gomatrixserverlib" @@ -79,6 +81,7 @@ type Inputer struct { JetStream nats.JetStreamContext Durable nats.SubOpt ServerName gomatrixserverlib.ServerName + SigningIdentity *gomatrixserverlib.SigningIdentity FSAPI fedapi.RoomserverFederationAPI KeyRing gomatrixserverlib.JSONVerifier ACLs *acls.ServerACLs @@ -87,6 +90,7 @@ type Inputer struct { workers sync.Map // room ID -> *worker Queryer *query.Queryer + UserAPI userapi.RoomserverUserAPI } // If a room consumer is inactive for a while then we will allow NATS diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 10b8ee27f..4179fc1ef 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -19,6 +19,7 @@ package input import ( "context" "database/sql" + "encoding/json" "errors" "fmt" "time" @@ -31,6 +32,8 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/sirupsen/logrus" + userAPI "github.com/matrix-org/dendrite/userapi/api" + fedapi "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/eventutil" @@ -440,6 +443,13 @@ func (r *Inputer) processRoomEvent( } } + // If guest_access changed and is not can_join, kick all guest users. + if event.Type() == gomatrixserverlib.MRoomGuestAccess && gjson.GetBytes(event.Content(), "guest_access").Str != "can_join" { + if err = r.kickGuests(ctx, event, roomInfo); err != nil { + logrus.WithError(err).Error("failed to kick guest users on m.room.guest_access revocation") + } + } + // Everything was OK — the latest events updater didn't error and // we've sent output events. Finally, generate a hook call. hooks.Run(hooks.KindNewEventPersisted, headered) @@ -729,3 +739,98 @@ func (r *Inputer) calculateAndSetState( succeeded = true return nil } + +// kickGuests kicks guests users from m.room.guest_access rooms, if guest access is now prohibited. +func (r *Inputer) kickGuests(ctx context.Context, event *gomatrixserverlib.Event, roomInfo *types.RoomInfo) error { + membershipNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true) + if err != nil { + return err + } + + memberEvents, err := r.DB.Events(ctx, membershipNIDs) + if err != nil { + return err + } + + inputEvents := make([]api.InputRoomEvent, 0, len(memberEvents)) + latestReq := &api.QueryLatestEventsAndStateRequest{ + RoomID: event.RoomID(), + } + latestRes := &api.QueryLatestEventsAndStateResponse{} + if err = r.Queryer.QueryLatestEventsAndState(ctx, latestReq, latestRes); err != nil { + return err + } + + prevEvents := latestRes.LatestEvents + for _, memberEvent := range memberEvents { + if memberEvent.StateKey() == nil { + continue + } + + localpart, senderDomain, err := gomatrixserverlib.SplitID('@', *memberEvent.StateKey()) + if err != nil { + continue + } + + accountRes := &userAPI.QueryAccountByLocalpartResponse{} + if err = r.UserAPI.QueryAccountByLocalpart(ctx, &userAPI.QueryAccountByLocalpartRequest{ + Localpart: localpart, + ServerName: senderDomain, + }, accountRes); err != nil { + return err + } + if accountRes.Account == nil { + continue + } + + if accountRes.Account.AccountType != userAPI.AccountTypeGuest { + continue + } + + var memberContent gomatrixserverlib.MemberContent + if err = json.Unmarshal(memberEvent.Content(), &memberContent); err != nil { + return err + } + memberContent.Membership = gomatrixserverlib.Leave + + stateKey := *memberEvent.StateKey() + fledglingEvent := &gomatrixserverlib.EventBuilder{ + RoomID: event.RoomID(), + Type: gomatrixserverlib.MRoomMember, + StateKey: &stateKey, + Sender: stateKey, + PrevEvents: prevEvents, + } + + if fledglingEvent.Content, err = json.Marshal(memberContent); err != nil { + return err + } + + eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(fledglingEvent) + if err != nil { + return err + } + + event, err := eventutil.BuildEvent(ctx, fledglingEvent, r.Cfg.Matrix, r.SigningIdentity, time.Now(), &eventsNeeded, latestRes) + if err != nil { + return err + } + + inputEvents = append(inputEvents, api.InputRoomEvent{ + Kind: api.KindNew, + Event: event, + Origin: senderDomain, + SendAsServer: string(senderDomain), + }) + prevEvents = []gomatrixserverlib.EventReference{ + event.EventReference(), + } + } + + inputReq := &api.InputRoomEventsRequest{ + InputRoomEvents: inputEvents, + Asynchronous: true, // Needs to be async, as we otherwise create a deadlock + } + inputRes := &api.InputRoomEventsResponse{} + return r.InputRoomEvents(ctx, inputReq, inputRes) +} diff --git a/roomserver/internal/perform/perform_join.go b/roomserver/internal/perform/perform_join.go index 4de008c66..fc7ba940c 100644 --- a/roomserver/internal/perform/perform_join.go +++ b/roomserver/internal/perform/perform_join.go @@ -16,6 +16,7 @@ package perform import ( "context" + "database/sql" "errors" "fmt" "strings" @@ -270,6 +271,28 @@ func (r *Joiner) performJoinRoomByID( } } + // If a guest is trying to join a room, check that the room has a m.room.guest_access event + if req.IsGuest { + var guestAccessEvent *gomatrixserverlib.HeaderedEvent + guestAccess := "forbidden" + guestAccessEvent, err = r.DB.GetStateEvent(ctx, req.RoomIDOrAlias, gomatrixserverlib.MRoomGuestAccess, "") + if (err != nil && !errors.Is(err, sql.ErrNoRows)) || guestAccessEvent == nil { + logrus.WithError(err).Warn("unable to get m.room.guest_access event, defaulting to 'forbidden'") + } + if guestAccessEvent != nil { + guestAccess = gjson.GetBytes(guestAccessEvent.Content(), "guest_access").String() + } + + // Servers MUST only allow guest users to join rooms if the m.room.guest_access state event + // is present on the room and has the guest_access value can_join. + if guestAccess != "can_join" { + return "", "", &rsAPI.PerformError{ + Code: rsAPI.PerformErrorNotAllowed, + Msg: "Guest access is forbidden", + } + } + } + // If we should do a forced federated join then do that. var joinedVia gomatrixserverlib.ServerName if forceFederatedJoin { diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go index 518bb3722..595ceb526 100644 --- a/roomserver/roomserver_test.go +++ b/roomserver/roomserver_test.go @@ -3,18 +3,23 @@ package roomserver_test import ( "context" "net/http" + "reflect" "testing" "time" "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/setup/base" + "github.com/matrix-org/dendrite/userapi" + + userAPI "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/inthttp" "github.com/matrix-org/dendrite/roomserver/storage" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test/testrig" ) @@ -29,7 +34,28 @@ func mustCreateDatabase(t *testing.T, dbType test.DBType) (*base.BaseDendrite, s return base, db, close } -func Test_SharedUsers(t *testing.T) { +func TestUsers(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, close := testrig.CreateBaseDendrite(t, dbType) + defer close() + rsAPI := roomserver.NewInternalAPI(base) + // SetFederationAPI starts the room event input consumer + rsAPI.SetFederationAPI(nil, nil) + + t.Run("shared users", func(t *testing.T) { + testSharedUsers(t, rsAPI) + }) + + t.Run("kick users", func(t *testing.T) { + usrAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, nil, rsAPI, nil) + rsAPI.SetUserAPI(usrAPI) + testKickUsers(t, rsAPI, usrAPI) + }) + }) + +} + +func testSharedUsers(t *testing.T, rsAPI api.RoomserverInternalAPI) { alice := test.NewUser(t) bob := test.NewUser(t) room := test.NewRoom(t, alice, test.RoomPreset(test.PresetTrustedPrivateChat)) @@ -43,36 +69,93 @@ func Test_SharedUsers(t *testing.T) { }, test.WithStateKey(bob.ID)) ctx := context.Background() - test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - base, _, close := mustCreateDatabase(t, dbType) - defer close() - rsAPI := roomserver.NewInternalAPI(base) - // SetFederationAPI starts the room event input consumer - rsAPI.SetFederationAPI(nil, nil) - // Create the room - if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { - t.Fatalf("failed to send events: %v", err) + // Create the room + if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { + t.Errorf("failed to send events: %v", err) + } + + // Query the shared users for Alice, there should only be Bob. + // This is used by the SyncAPI keychange consumer. + res := &api.QuerySharedUsersResponse{} + if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID}, res); err != nil { + t.Errorf("unable to query known users: %v", err) + } + if _, ok := res.UserIDsToCount[bob.ID]; !ok { + t.Errorf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount) + } + // Also verify that we get the expected result when specifying OtherUserIDs. + // This is used by the SyncAPI when getting device list changes. + if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID, OtherUserIDs: []string{bob.ID}}, res); err != nil { + t.Errorf("unable to query known users: %v", err) + } + if _, ok := res.UserIDsToCount[bob.ID]; !ok { + t.Errorf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount) + } +} + +func testKickUsers(t *testing.T, rsAPI api.RoomserverInternalAPI, usrAPI userAPI.UserInternalAPI) { + // Create users and room; Bob is going to be the guest and kicked on revocation of guest access + alice := test.NewUser(t, test.WithAccountType(userAPI.AccountTypeUser)) + bob := test.NewUser(t, test.WithAccountType(userAPI.AccountTypeGuest)) + + room := test.NewRoom(t, alice, test.RoomPreset(test.PresetPublicChat), test.GuestsCanJoin(true)) + + // Join with the guest user + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + + ctx := context.Background() + + // Create the users in the userapi, so the RSAPI can query the account type later + for _, u := range []*test.User{alice, bob} { + localpart, serverName, _ := gomatrixserverlib.SplitID('@', u.ID) + userRes := &userAPI.PerformAccountCreationResponse{} + if err := usrAPI.PerformAccountCreation(ctx, &userAPI.PerformAccountCreationRequest{ + AccountType: u.AccountType, + Localpart: localpart, + ServerName: serverName, + Password: "someRandomPassword", + }, userRes); err != nil { + t.Errorf("failed to create account: %s", err) + } + } + + // Create the room in the database + if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { + t.Errorf("failed to send events: %v", err) + } + + // Get the membership events BEFORE revoking guest access + membershipRes := &api.QueryMembershipsForRoomResponse{} + if err := rsAPI.QueryMembershipsForRoom(ctx, &api.QueryMembershipsForRoomRequest{LocalOnly: true, JoinedOnly: true, RoomID: room.ID}, membershipRes); err != nil { + t.Errorf("failed to query membership for room: %s", err) + } + + // revoke guest access + revokeEvent := room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomGuestAccess, map[string]string{"guest_access": "forbidden"}, test.WithStateKey("")) + if err := api.SendEvents(ctx, rsAPI, api.KindNew, []*gomatrixserverlib.HeaderedEvent{revokeEvent}, "test", "test", "test", nil, false); err != nil { + t.Errorf("failed to send events: %v", err) + } + + // TODO: Even though we are sending the events sync, the "kickUsers" function is sending the events async, so we need + // to loop and wait for the events to be processed by the roomserver. + for i := 0; i <= 20; i++ { + // Get the membership events AFTER revoking guest access + membershipRes2 := &api.QueryMembershipsForRoomResponse{} + if err := rsAPI.QueryMembershipsForRoom(ctx, &api.QueryMembershipsForRoomRequest{LocalOnly: true, JoinedOnly: true, RoomID: room.ID}, membershipRes2); err != nil { + t.Errorf("failed to query membership for room: %s", err) } - // Query the shared users for Alice, there should only be Bob. - // This is used by the SyncAPI keychange consumer. - res := &api.QuerySharedUsersResponse{} - if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID}, res); err != nil { - t.Fatalf("unable to query known users: %v", err) + // The membership events should NOT match, as Bob (guest user) should now be kicked from the room + if !reflect.DeepEqual(membershipRes, membershipRes2) { + return } - if _, ok := res.UserIDsToCount[bob.ID]; !ok { - t.Fatalf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount) - } - // Also verify that we get the expected result when specifying OtherUserIDs. - // This is used by the SyncAPI when getting device list changes. - if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID, OtherUserIDs: []string{bob.ID}}, res); err != nil { - t.Fatalf("unable to query known users: %v", err) - } - if _, ok := res.UserIDsToCount[bob.ID]; !ok { - t.Fatalf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount) - } - }) + time.Sleep(time.Millisecond * 10) + } + + t.Errorf("memberships didn't change in time") } func Test_QueryLeftUsers(t *testing.T) { diff --git a/setup/config/config_global.go b/setup/config/config_global.go index 511951fe6..804eb1a2d 100644 --- a/setup/config/config_global.go +++ b/setup/config/config_global.go @@ -174,7 +174,7 @@ func (c *Global) SigningIdentityFor(serverName gomatrixserverlib.ServerName) (*g return id, nil } } - return nil, fmt.Errorf("no signing identity %q", serverName) + return nil, fmt.Errorf("no signing identity for %q", serverName) } func (c *Global) SigningIdentities() []*gomatrixserverlib.SigningIdentity { diff --git a/setup/config/config_test.go b/setup/config/config_test.go index ee7e7389c..3408bf46d 100644 --- a/setup/config/config_test.go +++ b/setup/config/config_test.go @@ -16,8 +16,10 @@ package config import ( "fmt" + "reflect" "testing" + "github.com/matrix-org/gomatrixserverlib" "gopkg.in/yaml.v2" ) @@ -290,3 +292,55 @@ func TestUnmarshalDataUnit(t *testing.T) { } } } + +func Test_SigningIdentityFor(t *testing.T) { + tests := []struct { + name string + virtualHosts []*VirtualHost + serverName gomatrixserverlib.ServerName + want *gomatrixserverlib.SigningIdentity + wantErr bool + }{ + { + name: "no virtual hosts defined", + wantErr: true, + }, + { + name: "no identity found", + serverName: gomatrixserverlib.ServerName("doesnotexist"), + wantErr: true, + }, + { + name: "found identity", + serverName: gomatrixserverlib.ServerName("main"), + want: &gomatrixserverlib.SigningIdentity{ServerName: "main"}, + }, + { + name: "identity found on virtual hosts", + serverName: gomatrixserverlib.ServerName("vh2"), + virtualHosts: []*VirtualHost{ + {SigningIdentity: gomatrixserverlib.SigningIdentity{ServerName: "vh1"}}, + {SigningIdentity: gomatrixserverlib.SigningIdentity{ServerName: "vh2"}}, + }, + want: &gomatrixserverlib.SigningIdentity{ServerName: "vh2"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Global{ + VirtualHosts: tt.virtualHosts, + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: "main", + }, + } + got, err := c.SigningIdentityFor(tt.serverName) + if (err != nil) != tt.wantErr { + t.Errorf("SigningIdentityFor() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("SigningIdentityFor() got = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/sytest-blacklist b/sytest-blacklist index c35b03bd7..99cfbabc8 100644 --- a/sytest-blacklist +++ b/sytest-blacklist @@ -48,4 +48,7 @@ If a device list update goes missing, the server resyncs on the next one Leaves are present in non-gapped incremental syncs # Below test was passing for the wrong reason, failing correctly since #2858 -New federated private chats get full presence information (SYN-115) \ No newline at end of file +New federated private chats get full presence information (SYN-115) + +# We don't have any state to calculate m.room.guest_access when accepting invites +Guest users can accept invites to private rooms over federation \ No newline at end of file diff --git a/sytest-whitelist b/sytest-whitelist index 49ffb8fe8..215889a49 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -763,4 +763,7 @@ AS and main public room lists are separate local user has tags copied to the new room remote user has tags copied to the new room /upgrade moves remote aliases to the new room -Local and remote users' homeservers remove a room from their public directory on upgrade \ No newline at end of file +Local and remote users' homeservers remove a room from their public directory on upgrade +Guest users denied access over federation if guest access prohibited +Guest users are kicked from guest_access rooms on revocation of guest_access +Guest users are kicked from guest_access rooms on revocation of guest_access over federation \ No newline at end of file diff --git a/test/room.go b/test/room.go index 4328bf84f..685876cb0 100644 --- a/test/room.go +++ b/test/room.go @@ -38,11 +38,12 @@ var ( ) type Room struct { - ID string - Version gomatrixserverlib.RoomVersion - preset Preset - visibility gomatrixserverlib.HistoryVisibility - creator *User + ID string + Version gomatrixserverlib.RoomVersion + preset Preset + guestCanJoin bool + visibility gomatrixserverlib.HistoryVisibility + creator *User authEvents gomatrixserverlib.AuthEvents currentState map[string]*gomatrixserverlib.HeaderedEvent @@ -120,6 +121,11 @@ func (r *Room) insertCreateEvents(t *testing.T) { r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomPowerLevels, plContent, WithStateKey("")) r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomJoinRules, joinRule, WithStateKey("")) r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomHistoryVisibility, hisVis, WithStateKey("")) + if r.guestCanJoin { + r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomGuestAccess, map[string]string{ + "guest_access": "can_join", + }, WithStateKey("")) + } } // Create an event in this room but do not insert it. Does not modify the room in any way (depth, fwd extremities, etc) so is thread-safe. @@ -268,3 +274,9 @@ func RoomVersion(ver gomatrixserverlib.RoomVersion) roomModifier { r.Version = ver } } + +func GuestsCanJoin(canJoin bool) roomModifier { + return func(t *testing.T, r *Room) { + r.guestCanJoin = canJoin + } +} diff --git a/userapi/api/api.go b/userapi/api/api.go index d3f5aefc8..4ea2e91c3 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -50,6 +50,7 @@ type KeyserverUserAPI interface { type RoomserverUserAPI interface { QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error + QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) (err error) } // api functions required by the media api @@ -671,3 +672,12 @@ type PerformSaveThreePIDAssociationRequest struct { ServerName gomatrixserverlib.ServerName Medium string } + +type QueryAccountByLocalpartRequest struct { + Localpart string + ServerName gomatrixserverlib.ServerName +} + +type QueryAccountByLocalpartResponse struct { + Account *Account +} diff --git a/userapi/api/api_trace.go b/userapi/api/api_trace.go index ce661770f..d10b5767b 100644 --- a/userapi/api/api_trace.go +++ b/userapi/api/api_trace.go @@ -204,6 +204,12 @@ func (t *UserInternalAPITrace) PerformSaveThreePIDAssociation(ctx context.Contex return err } +func (t *UserInternalAPITrace) QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) error { + err := t.Impl.QueryAccountByLocalpart(ctx, req, res) + util.GetLogger(ctx).Infof("QueryAccountByLocalpart req=%+v res=%+v", js(req), js(res)) + return err +} + func js(thing interface{}) string { b, err := json.Marshal(thing) if err != nil { diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 3f256457e..0bb480da6 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -548,6 +548,11 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc return nil } +func (a *UserInternalAPI) QueryAccountByLocalpart(ctx context.Context, req *api.QueryAccountByLocalpartRequest, res *api.QueryAccountByLocalpartResponse) (err error) { + res.Account, err = a.DB.GetAccountByLocalpart(ctx, req.Localpart, req.ServerName) + return +} + // Return the appservice 'device' or nil if the token is not an appservice. Returns an error if there was a problem // creating a 'device'. func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appServiceUserID string) (*api.Device, error) { diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go index 87ae058c2..51b0fe3ef 100644 --- a/userapi/inthttp/client.go +++ b/userapi/inthttp/client.go @@ -60,6 +60,7 @@ const ( QueryAccountByPasswordPath = "/userapi/queryAccountByPassword" QueryLocalpartForThreePIDPath = "/userapi/queryLocalpartForThreePID" QueryThreePIDsForLocalpartPath = "/userapi/queryThreePIDsForLocalpart" + QueryAccountByLocalpartPath = "/userapi/queryAccountType" ) // NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API. @@ -440,3 +441,14 @@ func (h *httpUserInternalAPI) PerformSaveThreePIDAssociation( h.httpClient, ctx, request, response, ) } + +func (h *httpUserInternalAPI) QueryAccountByLocalpart( + ctx context.Context, + req *api.QueryAccountByLocalpartRequest, + res *api.QueryAccountByLocalpartResponse, +) error { + return httputil.CallInternalRPCAPI( + "QueryAccountByLocalpart", h.apiURL+QueryAccountByLocalpartPath, + h.httpClient, ctx, req, res, + ) +} diff --git a/userapi/inthttp/server.go b/userapi/inthttp/server.go index f0579079f..b40b507c2 100644 --- a/userapi/inthttp/server.go +++ b/userapi/inthttp/server.go @@ -189,4 +189,9 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI, enableMetrics PerformSaveThreePIDAssociationPath, httputil.MakeInternalRPCAPI("UserAPIPerformSaveThreePIDAssociation", enableMetrics, s.PerformSaveThreePIDAssociation), ) + + internalAPIMux.Handle( + QueryAccountByLocalpartPath, + httputil.MakeInternalRPCAPI("AccountByLocalpart", enableMetrics, s.QueryAccountByLocalpart), + ) } diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index 8a19af195..dada56de4 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -307,3 +307,64 @@ func TestLoginToken(t *testing.T) { }) }) } + +func TestQueryAccountByLocalpart(t *testing.T) { + alice := test.NewUser(t) + + localpart, userServername, _ := gomatrixserverlib.SplitID('@', alice.ID) + + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + intAPI, db, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType) + defer close() + + createdAcc, err := db.CreateAccount(ctx, localpart, userServername, "", "", alice.AccountType) + if err != nil { + t.Error(err) + } + + testCases := func(t *testing.T, internalAPI api.UserInternalAPI) { + // Query existing account + queryAccResp := &api.QueryAccountByLocalpartResponse{} + if err = internalAPI.QueryAccountByLocalpart(ctx, &api.QueryAccountByLocalpartRequest{ + Localpart: localpart, + ServerName: userServername, + }, queryAccResp); err != nil { + t.Error(err) + } + if !reflect.DeepEqual(createdAcc, queryAccResp.Account) { + t.Fatalf("created and queried accounts don't match:\n%+v vs.\n%+v", createdAcc, queryAccResp.Account) + } + + // Query non-existent account, this should result in an error + err = internalAPI.QueryAccountByLocalpart(ctx, &api.QueryAccountByLocalpartRequest{ + Localpart: "doesnotexist", + ServerName: userServername, + }, queryAccResp) + + if err == nil { + t.Fatalf("expected an error, but got none: %+v", queryAccResp) + } + } + + t.Run("Monolith", func(t *testing.T) { + testCases(t, intAPI) + // also test tracing + testCases(t, &api.UserInternalAPITrace{Impl: intAPI}) + }) + + t.Run("HTTP API", func(t *testing.T) { + router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter() + userapi.AddInternalRoutes(router, intAPI, false) + apiURL, cancel := test.ListenAndServe(t, router, false) + defer cancel() + + userHTTPApi, err := inthttp.NewUserAPIClient(apiURL, &http.Client{Timeout: time.Second * 5}) + if err != nil { + t.Fatalf("failed to create HTTP client: %s", err) + } + testCases(t, userHTTPApi) + + }) + }) +}