From 264165eb8c0e8ae618e53de47b814f8a9781d567 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Fri, 25 Feb 2022 17:35:10 +0000 Subject: [PATCH 01/33] Update systemd example to set `LimitNOFILE` --- docs/systemd/monolith-example.service | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/systemd/monolith-example.service b/docs/systemd/monolith-example.service index 731c6159b..237120ffb 100644 --- a/docs/systemd/monolith-example.service +++ b/docs/systemd/monolith-example.service @@ -13,6 +13,7 @@ Group=dendrite WorkingDirectory=/opt/dendrite/ ExecStart=/opt/dendrite/bin/dendrite-monolith-server Restart=always +LimitNOFILE=65535 [Install] WantedBy=multi-user.target From ac77732185a2473dd39984669ea3f454548aa7f5 Mon Sep 17 00:00:00 2001 From: S7evinK <2353100+S7evinK@users.noreply.github.com> Date: Mon, 28 Feb 2022 13:57:56 +0100 Subject: [PATCH 02/33] Add possibility to reset password using create-account (#2231) * Add possibility to reset password * Invalidate logins * Fix test --- cmd/create-account/main.go | 30 +++++++++++++++++++++++------- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/cmd/create-account/main.go b/cmd/create-account/main.go index 3003896c8..307fa17e6 100644 --- a/cmd/create-account/main.go +++ b/cmd/create-account/main.go @@ -48,24 +48,27 @@ Example: # read password from stdin %s --config dendrite.yaml -username alice -passwordstdin < my.pass cat my.pass | %s --config dendrite.yaml -username alice -passwordstdin + # reset password for a user, can be used with a combination above to read the password + %s --config dendrite.yaml -reset-password -username alice -password foobarbaz Arguments: ` var ( - username = flag.String("username", "", "The username of the account to register (specify the localpart only, e.g. 'alice' for '@alice:domain.com')") - password = flag.String("password", "", "The password to associate with the account (optional, account will be password-less if not specified)") - pwdFile = flag.String("passwordfile", "", "The file to use for the password (e.g. for automated account creation)") - pwdStdin = flag.Bool("passwordstdin", false, "Reads the password from stdin") - askPass = flag.Bool("ask-pass", false, "Ask for the password to use") - isAdmin = flag.Bool("admin", false, "Create an admin account") + username = flag.String("username", "", "The username of the account to register (specify the localpart only, e.g. 'alice' for '@alice:domain.com')") + password = flag.String("password", "", "The password to associate with the account (optional, account will be password-less if not specified)") + pwdFile = flag.String("passwordfile", "", "The file to use for the password (e.g. for automated account creation)") + pwdStdin = flag.Bool("passwordstdin", false, "Reads the password from stdin") + askPass = flag.Bool("ask-pass", false, "Ask for the password to use") + isAdmin = flag.Bool("admin", false, "Create an admin account") + resetPassword = flag.Bool("reset-password", false, "Resets the password for the given username") ) func main() { name := os.Args[0] flag.Usage = func() { - _, _ = fmt.Fprintf(os.Stderr, usage, name, name, name, name, name, name) + _, _ = fmt.Fprintf(os.Stderr, usage, name, name, name, name, name, name, name) flag.PrintDefaults() } cfg := setup.ParseFlags(true) @@ -93,6 +96,19 @@ func main() { if *isAdmin { accType = api.AccountTypeAdmin } + + if *resetPassword { + err = accountDB.SetPassword(context.Background(), *username, pass) + if err != nil { + logrus.Fatalf("Failed to update password for user %s: %s", *username, err.Error()) + } + if _, err = accountDB.RemoveAllDevices(context.Background(), *username, ""); err != nil { + logrus.Fatalf("Failed to remove all devices: %s", err.Error()) + } + logrus.Infof("Updated password for user %s and invalidated all logins\n", *username) + return + } + _, err = accountDB.CreateAccount(context.Background(), *username, pass, "", accType) if err != nil { logrus.Fatalln("Failed to create the account:", err.Error()) From a23fda662607e9160230335503e912f626abf616 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Mon, 28 Feb 2022 14:51:40 +0000 Subject: [PATCH 03/33] Update `Events` call-sites which now don't return an error, update `parsedRespState` to sort (#2227) * Topologically sort with `SendEventWithState`, so that earlier events should satisfy auth for later ones * Revert "Topologically sort with `SendEventWithState`, so that earlier events should satisfy auth for later ones" This reverts commit b0cd706012b4c9b6724b11e16f19c4cb732ab286. * Update to matrix-org/gomatrixserverlib#293 * `Events` no longer returns an error, other tweaks * Make sure `Events` is sorted for `parsedRespState` too --- go.mod | 2 +- go.sum | 4 ++-- roomserver/api/wrapper.go | 8 ++------ roomserver/internal/input/input_missing.go | 22 +++++++++++++++++----- setup/mscs/msc2836/msc2836.go | 8 ++------ 5 files changed, 24 insertions(+), 20 deletions(-) diff --git a/go.mod b/go.mod index a416ec98f..78af67e6e 100644 --- a/go.mod +++ b/go.mod @@ -39,7 +39,7 @@ require ( github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4 github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 - github.com/matrix-org/gomatrixserverlib v0.0.0-20220224170509-f6ab9c54d052 + github.com/matrix-org/gomatrixserverlib v0.0.0-20220225115648-d2a338a15438 github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.10 diff --git a/go.sum b/go.sum index ecb0496bc..4ba0723f2 100644 --- a/go.sum +++ b/go.sum @@ -983,8 +983,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d/go.mod h1 github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20220224170509-f6ab9c54d052 h1:+4Q/JQ3fGgA7sIHaLMlqREX8yEpsI+HlVoW9WId7SNc= -github.com/matrix-org/gomatrixserverlib v0.0.0-20220224170509-f6ab9c54d052/go.mod h1:+WF5InseAMgi1fTnU46JH39IDpEvLep0fDzx9LDf2Bo= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220225115648-d2a338a15438 h1:3B0ZEJ5YVVbRKHV7WFgji5z6s262YIVRZvtDotdpbsI= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220225115648-d2a338a15438/go.mod h1:+WF5InseAMgi1fTnU46JH39IDpEvLep0fDzx9LDf2Bo= github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa h1:rMYFNVto66gp+eWS8XAUzgp4m0qmUBid6l1HX3mHstk= github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa/go.mod h1:r6dsL+ylE0yXe/7zh8y/Bdh6aBYI1r+u4yZni9A4iyk= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= diff --git a/roomserver/api/wrapper.go b/roomserver/api/wrapper.go index 012094c62..5491d36b3 100644 --- a/roomserver/api/wrapper.go +++ b/roomserver/api/wrapper.go @@ -51,12 +51,8 @@ func SendEventWithState( state *gomatrixserverlib.RespState, event *gomatrixserverlib.HeaderedEvent, origin gomatrixserverlib.ServerName, haveEventIDs map[string]bool, async bool, ) error { - outliers, err := state.Events(event.RoomVersion) - if err != nil { - return err - } - - var ires []InputRoomEvent + outliers := state.Events(event.RoomVersion) + ires := make([]InputRoomEvent, 0, len(outliers)) for _, outlier := range outliers { if haveEventIDs[outlier.EventID()] { continue diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go index 4655e92a9..a7da9b06d 100644 --- a/roomserver/internal/input/input_missing.go +++ b/roomserver/internal/input/input_missing.go @@ -23,6 +23,21 @@ type parsedRespState struct { StateEvents []*gomatrixserverlib.Event } +func (p *parsedRespState) Events() []*gomatrixserverlib.Event { + eventsByID := make(map[string]*gomatrixserverlib.Event, len(p.AuthEvents)+len(p.StateEvents)) + for i, event := range p.AuthEvents { + eventsByID[event.EventID()] = p.AuthEvents[i] + } + for i, event := range p.StateEvents { + eventsByID[event.EventID()] = p.StateEvents[i] + } + allEvents := make([]*gomatrixserverlib.Event, 0, len(eventsByID)) + for _, event := range eventsByID { + allEvents = append(allEvents, event) + } + return gomatrixserverlib.ReverseTopologicalOrdering(allEvents, gomatrixserverlib.TopologicalOrderByAuthEvents) +} + type missingStateReq struct { origin gomatrixserverlib.ServerName db storage.Database @@ -124,11 +139,8 @@ func (t *missingStateReq) processEventWithMissingState( t.hadEventsMutex.Unlock() sendOutliers := func(resolvedState *parsedRespState) error { - outliers, oerr := gomatrixserverlib.OrderAuthAndStateEvents(resolvedState.AuthEvents, resolvedState.StateEvents, roomVersion) - if oerr != nil { - return fmt.Errorf("gomatrixserverlib.OrderAuthAndStateEvents: %w", oerr) - } - var outlierRoomEvents []api.InputRoomEvent + outliers := resolvedState.Events() + outlierRoomEvents := make([]api.InputRoomEvent, 0, len(outliers)) for _, outlier := range outliers { if hadEvents[outlier.EventID()] { continue diff --git a/setup/mscs/msc2836/msc2836.go b/setup/mscs/msc2836/msc2836.go index 0af22c19a..29c781a88 100644 --- a/setup/mscs/msc2836/msc2836.go +++ b/setup/mscs/msc2836/msc2836.go @@ -654,11 +654,7 @@ func (rc *reqCtx) injectResponseToRoomserver(res *MSC2836EventRelationshipsRespo AuthEvents: res.AuthChain, StateEvents: stateEvents, } - eventsInOrder, err := respState.Events(rc.roomVersion) - if err != nil { - util.GetLogger(rc.ctx).WithError(err).Error("failed to calculate order to send events in MSC2836EventRelationshipsResponse") - return - } + eventsInOrder := respState.Events(rc.roomVersion) // everything gets sent as an outlier because auth chain events may be disjoint from the DAG // as may the threaded events. var ires []roomserver.InputRoomEvent @@ -669,7 +665,7 @@ func (rc *reqCtx) injectResponseToRoomserver(res *MSC2836EventRelationshipsRespo }) } // we've got the data by this point so use a background context - err = roomserver.SendInputRoomEvents(context.Background(), rc.rsAPI, ires, false) + err := roomserver.SendInputRoomEvents(context.Background(), rc.rsAPI, ires, false) if err != nil { util.GetLogger(rc.ctx).WithError(err).Error("failed to inject MSC2836EventRelationshipsResponse into the roomserver") } From 58bf91a585ec78f6ca6ff0c9ad0c10c5db9715a7 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 1 Mar 2022 11:00:54 +0000 Subject: [PATCH 04/33] Check for changes in `PerformUploadDeviceKeys` (#2233) * Don't generate key change notifs if nothing changed on cross-signing upload * Check both directions of changes --- keyserver/internal/cross_signing.go | 53 ++++++++++++++++++++++------- 1 file changed, 40 insertions(+), 13 deletions(-) diff --git a/keyserver/internal/cross_signing.go b/keyserver/internal/cross_signing.go index bfb2037f8..5124f37e6 100644 --- a/keyserver/internal/cross_signing.go +++ b/keyserver/internal/cross_signing.go @@ -166,26 +166,53 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P } // We can't have a self-signing or user-signing key without a master - // key, so make sure we have one of those. - if !hasMasterKey { - existingKeys, err := a.DB.CrossSigningKeysDataForUser(ctx, req.UserID) - if err != nil { - res.Error = &api.KeyError{ - Err: "Retrieving cross-signing keys from database failed: " + err.Error(), - } - return + // key, so make sure we have one of those. We will also only actually do + // something if any of the specified keys in the request are different + // to what we've got in the database, to avoid generating key change + // notifications unnecessarily. + existingKeys, err := a.DB.CrossSigningKeysDataForUser(ctx, req.UserID) + if err != nil { + res.Error = &api.KeyError{ + Err: "Retrieving cross-signing keys from database failed: " + err.Error(), } - - _, hasMasterKey = existingKeys[gomatrixserverlib.CrossSigningKeyPurposeMaster] + return } // If we still can't find a master key for the user then stop the upload. // This satisfies the "Fails to upload self-signing key without master key" test. if !hasMasterKey { - res.Error = &api.KeyError{ - Err: "No master key was found", - IsMissingParam: true, + if _, hasMasterKey = existingKeys[gomatrixserverlib.CrossSigningKeyPurposeMaster]; !hasMasterKey { + res.Error = &api.KeyError{ + Err: "No master key was found", + IsMissingParam: true, + } + return } + } + + // Check if anything actually changed compared to what we have in the database. + changed := false + for _, purpose := range []gomatrixserverlib.CrossSigningKeyPurpose{ + gomatrixserverlib.CrossSigningKeyPurposeMaster, + gomatrixserverlib.CrossSigningKeyPurposeSelfSigning, + gomatrixserverlib.CrossSigningKeyPurposeUserSigning, + } { + old, gotOld := existingKeys[purpose] + new, gotNew := toStore[purpose] + if gotOld != gotNew { + // A new key purpose has been specified that we didn't know before, + // or one has been removed. + changed = true + break + } + if !bytes.Equal(old, new) { + // One of the existing keys for a purpose we already knew about has + // changed. + changed = true + break + } + } + if !changed { return } From 530f05885dccba91559aff09eaaa20540a08a419 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 1 Mar 2022 13:01:38 +0000 Subject: [PATCH 05/33] Limit `JoinedUsersSetInRooms` to interested users (#2234) * Limit database work in `JoinedUsersSetInRooms` to changed user IDs only * Comments * Fix variadic params for SQLite, update comments --- roomserver/api/query.go | 1 + roomserver/internal/query/query.go | 2 +- roomserver/storage/interface.go | 4 ++-- .../storage/postgres/membership_table.go | 10 ++++----- roomserver/storage/shared/storage.go | 20 +++++++++++------- .../storage/sqlite3/membership_table.go | 21 ++++++++++++------- roomserver/storage/tables/interface.go | 5 ++--- syncapi/internal/keychange.go | 17 +++++++++++---- 8 files changed, 49 insertions(+), 31 deletions(-) diff --git a/roomserver/api/query.go b/roomserver/api/query.go index 96d6711c6..6255ffa3a 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -269,6 +269,7 @@ type QueryAuthChainResponse struct { type QuerySharedUsersRequest struct { UserID string + OtherUserIDs []string ExcludeRoomIDs []string IncludeRoomIDs []string } diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index c8bbe7705..20c5a2585 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -696,7 +696,7 @@ func (r *Queryer) QuerySharedUsers(ctx context.Context, req *api.QuerySharedUser } roomIDs = roomIDs[:j] - users, err := r.DB.JoinedUsersSetInRooms(ctx, roomIDs) + users, err := r.DB.JoinedUsersSetInRooms(ctx, roomIDs, req.OtherUserIDs) if err != nil { return err } diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 685505d52..cd232e3e9 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -151,8 +151,8 @@ type Database interface { // GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match. // If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned. GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error) - // JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear. - JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) + // JoinedUsersSetInRooms returns how many times each of the given users appears across the given rooms. + JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs []string) (map[string]int, error) // GetLocalServerInRoom returns true if we think we're in a given room or false otherwise. GetLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) // GetServerInRoom returns true if we think a server is in a given room or false otherwise. diff --git a/roomserver/storage/postgres/membership_table.go b/roomserver/storage/postgres/membership_table.go index 48c2c35cd..127178743 100644 --- a/roomserver/storage/postgres/membership_table.go +++ b/roomserver/storage/postgres/membership_table.go @@ -66,7 +66,8 @@ CREATE TABLE IF NOT EXISTS roomserver_membership ( ` var selectJoinedUsersSetForRoomsSQL = "" + - "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid = ANY($1) AND" + + "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" + + " WHERE room_nid = ANY($1) AND target_nid = ANY($2) AND" + " membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" + " GROUP BY target_nid" @@ -306,13 +307,10 @@ func (s *membershipStatements) SelectRoomsWithMembership( func (s *membershipStatements) SelectJoinedUsersSetForRooms( ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, + userNIDs []types.EventStateKeyNID, ) (map[types.EventStateKeyNID]int, error) { - roomIDarray := make([]int64, len(roomNIDs)) - for i := range roomNIDs { - roomIDarray[i] = int64(roomNIDs[i]) - } stmt := sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsStmt) - rows, err := stmt.QueryContext(ctx, pq.Int64Array(roomIDarray)) + rows, err := stmt.QueryContext(ctx, pq.Array(roomNIDs), pq.Array(userNIDs)) if err != nil { return nil, err } diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 6e84b2832..6dc40816c 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -1104,13 +1104,23 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu return result, nil } -// JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear. -func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) { +// JoinedUsersSetInRooms returns a map of how many times the given users appear in the specified rooms. +func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs []string) (map[string]int, error) { roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, nil, roomIDs) if err != nil { return nil, err } - userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, nil, roomNIDs) + userNIDsMap, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, nil, userIDs) + if err != nil { + return nil, err + } + userNIDs := make([]types.EventStateKeyNID, 0, len(userNIDsMap)) + nidToUserID := make(map[types.EventStateKeyNID]string, len(userNIDsMap)) + for id, nid := range userNIDsMap { + userNIDs = append(userNIDs, nid) + nidToUserID[nid] = id + } + userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, nil, roomNIDs, userNIDs) if err != nil { return nil, err } @@ -1120,10 +1130,6 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) stateKeyNIDs[i] = nid i++ } - nidToUserID, err := d.EventStateKeysTable.BulkSelectEventStateKey(ctx, nil, stateKeyNIDs) - if err != nil { - return nil, err - } if len(nidToUserID) != len(userNIDToCount) { logrus.Warnf("SelectJoinedUsersSetForRooms found %d users but BulkSelectEventStateKey only returned state key NIDs for %d of them", len(userNIDToCount), len(nidToUserID)) } diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go index 181b4b4c9..43567a94c 100644 --- a/roomserver/storage/sqlite3/membership_table.go +++ b/roomserver/storage/sqlite3/membership_table.go @@ -42,7 +42,8 @@ const membershipSchema = ` ` var selectJoinedUsersSetForRoomsSQL = "" + - "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid IN ($1) AND" + + "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" + + " WHERE room_nid IN ($1) AND target_nid IN ($2) AND" + " membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" + " GROUP BY target_nid" @@ -280,18 +281,22 @@ func (s *membershipStatements) SelectRoomsWithMembership( return roomNIDs, nil } -func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) { - iRoomNIDs := make([]interface{}, len(roomNIDs)) - for i, v := range roomNIDs { - iRoomNIDs[i] = v +func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, userNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]int, error) { + params := make([]interface{}, 0, len(roomNIDs)+len(userNIDs)) + for _, v := range roomNIDs { + params = append(params, v) } - query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomNIDs)), 1) + for _, v := range userNIDs { + params = append(params, v) + } + query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1) + query = strings.Replace(query, "($2)", sqlutil.QueryVariadicOffset(len(userNIDs), len(roomNIDs)), 1) var rows *sql.Rows var err error if txn != nil { - rows, err = txn.QueryContext(ctx, query, iRoomNIDs...) + rows, err = txn.QueryContext(ctx, query, params...) } else { - rows, err = s.db.QueryContext(ctx, query, iRoomNIDs...) + rows, err = s.db.QueryContext(ctx, query, params...) } if err != nil { return nil, err diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index e3fed700b..04e3c96cc 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -127,9 +127,8 @@ type Membership interface { SelectMembershipsFromRoomAndMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error) UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID, forgotten bool) error SelectRoomsWithMembership(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error) - // SelectJoinedUsersSetForRooms returns the set of all users in the rooms who are joined to any of these rooms, along with the - // counts of how many rooms they are joined. - SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) + // SelectJoinedUsersSetForRooms returns how many times each of the given users appears across the given rooms. + SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, userNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]int, error) SelectKnownUsers(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) UpdateForgetMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool) error SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error) diff --git a/syncapi/internal/keychange.go b/syncapi/internal/keychange.go index 37a9e2d39..dc4acd8da 100644 --- a/syncapi/internal/keychange.go +++ b/syncapi/internal/keychange.go @@ -82,7 +82,16 @@ func DeviceListCatchup( util.GetLogger(ctx).WithError(queryRes.Error).Error("QueryKeyChanges failed") return to, hasNew, nil } - // QueryKeyChanges gets ALL users who have changed keys, we want the ones who share rooms with the user. + + // Work out which user IDs we care about — that includes those in the original request, + // the response from QueryKeyChanges (which includes ALL users who have changed keys) + // as well as every user who has a join or leave event in the current sync response. We + // will request information about which rooms these users are joined to, so that we can + // see if we still share any rooms with them. + joinUserIDs, leaveUserIDs := membershipEvents(res) + queryRes.UserIDs = append(queryRes.UserIDs, joinUserIDs...) + queryRes.UserIDs = append(queryRes.UserIDs, leaveUserIDs...) + queryRes.UserIDs = util.UniqueStrings(queryRes.UserIDs) var sharedUsersMap map[string]int sharedUsersMap, queryRes.UserIDs = filterSharedUsers(ctx, rsAPI, userID, queryRes.UserIDs) util.GetLogger(ctx).Debugf( @@ -100,9 +109,8 @@ func DeviceListCatchup( userSet[userID] = true } } - // if the response has any join/leave events, add them now. + // Finally, add in users who have joined or left. // TODO: This is sub-optimal because we will add users to `changed` even if we already shared a room with them. - joinUserIDs, leaveUserIDs := membershipEvents(res) for _, userID := range joinUserIDs { if !userSet[userID] { res.DeviceLists.Changed = append(res.DeviceLists.Changed, userID) @@ -213,7 +221,8 @@ func filterSharedUsers( var result []string var sharedUsersRes roomserverAPI.QuerySharedUsersResponse err := rsAPI.QuerySharedUsers(ctx, &roomserverAPI.QuerySharedUsersRequest{ - UserID: userID, + UserID: userID, + OtherUserIDs: usersWithChangedKeys, }, &sharedUsersRes) if err != nil { // default to all users so we do needless queries rather than miss some important device update From f1b92de0175d74be7b167c5e3f168d0048c84843 Mon Sep 17 00:00:00 2001 From: kegsay Date: Tue, 1 Mar 2022 13:40:07 +0000 Subject: [PATCH 06/33] MSC2946: Spaces Summary (round 2) (#2232) * Initial cut at fixing up MSC2946 to work with latest spec * bugfix: send response back correctly * Initial working version of MSC2946 * msc2946: handle suggested_only; remove custom database As the MSC doesn't require reverse lookups, we can just pull the room state and inspect via the roomserver database. To handle this, expand QueryCurrentState to support wildcards. Use all this and handle `?suggested_only`. * Sort child rooms * msc2946: Make TestClientSpacesSummary pass * msc2946: allow invited rooms to be spidered * msc2946: support basic federation requests * fix up go mod --- federationapi/api/api.go | 2 +- federationapi/internal/federationclient.go | 4 +- federationapi/inthttp/client.go | 18 +- federationapi/inthttp/server.go | 2 +- go.mod | 7 +- go.sum | 12 +- roomserver/api/query.go | 5 +- roomserver/internal/query/query.go | 25 +- roomserver/storage/interface.go | 1 + roomserver/storage/shared/storage.go | 56 +++ setup/mscs/msc2946/msc2946.go | 543 ++++++++++++--------- setup/mscs/msc2946/msc2946_test.go | 464 ------------------ setup/mscs/msc2946/storage.go | 182 ------- 13 files changed, 411 insertions(+), 910 deletions(-) delete mode 100644 setup/mscs/msc2946/msc2946_test.go delete mode 100644 setup/mscs/msc2946/storage.go diff --git a/federationapi/api/api.go b/federationapi/api/api.go index f5ee75b4b..4d6b0211c 100644 --- a/federationapi/api/api.go +++ b/federationapi/api/api.go @@ -21,7 +21,7 @@ type FederationClient interface { QueryKeys(ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string) (res gomatrixserverlib.RespQueryKeys, err error) GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) MSC2836EventRelationships(ctx context.Context, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) - MSC2946Spaces(ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, r gomatrixserverlib.MSC2946SpacesRequest) (res gomatrixserverlib.MSC2946SpacesResponse, err error) + MSC2946Spaces(ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res gomatrixserverlib.MSC2946SpacesResponse, err error) LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error) GetEventAuth(ctx context.Context, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res gomatrixserverlib.RespEventAuth, err error) LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) diff --git a/federationapi/internal/federationclient.go b/federationapi/internal/federationclient.go index b31db466c..b8bd5beda 100644 --- a/federationapi/internal/federationclient.go +++ b/federationapi/internal/federationclient.go @@ -166,12 +166,12 @@ func (a *FederationInternalAPI) MSC2836EventRelationships( } func (a *FederationInternalAPI) MSC2946Spaces( - ctx context.Context, s gomatrixserverlib.ServerName, roomID string, r gomatrixserverlib.MSC2946SpacesRequest, + ctx context.Context, s gomatrixserverlib.ServerName, roomID string, suggestedOnly bool, ) (res gomatrixserverlib.MSC2946SpacesResponse, err error) { ctx, cancel := context.WithTimeout(ctx, time.Minute) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { - return a.federation.MSC2946Spaces(ctx, s, roomID, r) + return a.federation.MSC2946Spaces(ctx, s, roomID, suggestedOnly) }) if err != nil { return res, err diff --git a/federationapi/inthttp/client.go b/federationapi/inthttp/client.go index f9b2a33d2..01ca6595d 100644 --- a/federationapi/inthttp/client.go +++ b/federationapi/inthttp/client.go @@ -526,23 +526,23 @@ func (h *httpFederationInternalAPI) MSC2836EventRelationships( } type spacesReq struct { - S gomatrixserverlib.ServerName - Req gomatrixserverlib.MSC2946SpacesRequest - RoomID string - Res gomatrixserverlib.MSC2946SpacesResponse - Err *api.FederationClientError + S gomatrixserverlib.ServerName + SuggestedOnly bool + RoomID string + Res gomatrixserverlib.MSC2946SpacesResponse + Err *api.FederationClientError } func (h *httpFederationInternalAPI) MSC2946Spaces( - ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, r gomatrixserverlib.MSC2946SpacesRequest, + ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool, ) (res gomatrixserverlib.MSC2946SpacesResponse, err error) { span, ctx := opentracing.StartSpanFromContext(ctx, "MSC2946Spaces") defer span.Finish() request := spacesReq{ - S: dst, - Req: r, - RoomID: roomID, + S: dst, + SuggestedOnly: suggestedOnly, + RoomID: roomID, } var response spacesReq apiURL := h.federationAPIURL + FederationAPISpacesSummaryPath diff --git a/federationapi/inthttp/server.go b/federationapi/inthttp/server.go index 8d193d9c9..ca4930f20 100644 --- a/federationapi/inthttp/server.go +++ b/federationapi/inthttp/server.go @@ -378,7 +378,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { if err := json.NewDecoder(req.Body).Decode(&request); err != nil { return util.MessageResponse(http.StatusBadRequest, err.Error()) } - res, err := intAPI.MSC2946Spaces(req.Context(), request.S, request.RoomID, request.Req) + res, err := intAPI.MSC2946Spaces(req.Context(), request.S, request.RoomID, request.SuggestedOnly) if err != nil { ferr, ok := err.(*api.FederationClientError) if ok { diff --git a/go.mod b/go.mod index 78af67e6e..e44e7393d 100644 --- a/go.mod +++ b/go.mod @@ -18,6 +18,7 @@ require ( github.com/frankban/quicktest v1.14.0 // indirect github.com/getsentry/sentry-go v0.12.0 github.com/gologme/log v1.3.0 + github.com/google/uuid v1.2.0 github.com/gorilla/mux v1.8.0 github.com/gorilla/websocket v1.4.2 github.com/h2non/filetype v1.1.3 // indirect @@ -39,7 +40,7 @@ require ( github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4 github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 - github.com/matrix-org/gomatrixserverlib v0.0.0-20220225115648-d2a338a15438 + github.com/matrix-org/gomatrixserverlib v0.0.0-20220301125114-e6012a13a6e6 github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.10 @@ -61,11 +62,11 @@ require ( github.com/uber/jaeger-lib v2.4.1+incompatible github.com/yggdrasil-network/yggdrasil-go v0.4.2 go.uber.org/atomic v1.9.0 - golang.org/x/crypto v0.0.0-20220209195652-db638375bc3a + golang.org/x/crypto v0.0.0-20220214200702-86341886e292 golang.org/x/image v0.0.0-20211028202545-6944b10bf410 golang.org/x/mobile v0.0.0-20220112015953-858099ff7816 golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd - golang.org/x/sys v0.0.0-20220207234003-57398862261d // indirect + golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9 // indirect golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 gopkg.in/h2non/bimg.v1 v1.1.5 gopkg.in/yaml.v2 v2.4.0 diff --git a/go.sum b/go.sum index 4ba0723f2..43b363d30 100644 --- a/go.sum +++ b/go.sum @@ -983,8 +983,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d/go.mod h1 github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20220225115648-d2a338a15438 h1:3B0ZEJ5YVVbRKHV7WFgji5z6s262YIVRZvtDotdpbsI= -github.com/matrix-org/gomatrixserverlib v0.0.0-20220225115648-d2a338a15438/go.mod h1:+WF5InseAMgi1fTnU46JH39IDpEvLep0fDzx9LDf2Bo= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220301125114-e6012a13a6e6 h1:CqbM5tPbF1mV2h1J6N0NSF8et6HvFlwA98TLCUcjiSI= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220301125114-e6012a13a6e6/go.mod h1:+WF5InseAMgi1fTnU46JH39IDpEvLep0fDzx9LDf2Bo= github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa h1:rMYFNVto66gp+eWS8XAUzgp4m0qmUBid6l1HX3mHstk= github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa/go.mod h1:r6dsL+ylE0yXe/7zh8y/Bdh6aBYI1r+u4yZni9A4iyk= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= @@ -1510,8 +1510,8 @@ golang.org/x/crypto v0.0.0-20210506145944-38f3c27a63bf/go.mod h1:P+XmwS30IXTQdn5 golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20220112180741-5e0467b6c7ce/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/crypto v0.0.0-20220209195652-db638375bc3a h1:atOEWVSedO4ksXBe/UrlbSLVxQQ9RxM/tT2Jy10IaHo= -golang.org/x/crypto v0.0.0-20220209195652-db638375bc3a/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/crypto v0.0.0-20220214200702-86341886e292 h1:f+lwQ+GtmgoY+A2YaQxlSOnDjXcQ7ZRLWOHbC6HtRqE= +golang.org/x/crypto v0.0.0-20220214200702-86341886e292/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -1737,8 +1737,8 @@ golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220111092808-5a964db01320/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220207234003-57398862261d h1:Bm7BNOQt2Qv7ZqysjeLjgCBanX+88Z/OtdvsrEv1Djc= -golang.org/x/sys v0.0.0-20220207234003-57398862261d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9 h1:nhht2DYV/Sn3qOayu8lM+cU1ii9sTLUeBQwQQfUHtrs= +golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 h1:JGgROgKl9N8DuW20oFS5gxc+lE67/N3FcwmBPMe7ArY= diff --git a/roomserver/api/query.go b/roomserver/api/query.go index 6255ffa3a..66e85f2f3 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -313,7 +313,10 @@ type QueryBulkStateContentResponse struct { } type QueryCurrentStateRequest struct { - RoomID string + RoomID string + AllowWildcards bool + // State key tuples. If a state_key has '*' and AllowWidlcards is true, returns all matching + // state events with that event type. StateTuples []gomatrixserverlib.StateKeyTuple } diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index 20c5a2585..70cc5d62c 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -621,12 +621,25 @@ func (r *Queryer) QueryPublishedRooms( func (r *Queryer) QueryCurrentState(ctx context.Context, req *api.QueryCurrentStateRequest, res *api.QueryCurrentStateResponse) error { res.StateEvents = make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent) for _, tuple := range req.StateTuples { - ev, err := r.DB.GetStateEvent(ctx, req.RoomID, tuple.EventType, tuple.StateKey) - if err != nil { - return err - } - if ev != nil { - res.StateEvents[tuple] = ev + if tuple.StateKey == "*" && req.AllowWildcards { + events, err := r.DB.GetStateEventsWithEventType(ctx, req.RoomID, tuple.EventType) + if err != nil { + return err + } + for _, e := range events { + res.StateEvents[gomatrixserverlib.StateKeyTuple{ + EventType: e.Type(), + StateKey: *e.StateKey(), + }] = e + } + } else { + ev, err := r.DB.GetStateEvent(ctx, req.RoomID, tuple.EventType, tuple.StateKey) + if err != nil { + return err + } + if ev != nil { + res.StateEvents[tuple] = ev + } } } return nil diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index cd232e3e9..a2b22b401 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -146,6 +146,7 @@ type Database interface { // If no event could be found, returns nil // If there was an issue during the retrieval, returns an error GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) + GetStateEventsWithEventType(ctx context.Context, roomID, evType string) ([]*gomatrixserverlib.HeaderedEvent, error) // GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key). GetRoomsByMembership(ctx context.Context, userID, membership string) ([]string, error) // GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match. diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 6dc40816c..f87782776 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -979,6 +979,62 @@ func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey s return nil, nil } +// Same as GetStateEvent but returns all matching state events with this event type. Returns no error +// if there are no events with this event type. +func (d *Database) GetStateEventsWithEventType(ctx context.Context, roomID, evType string) ([]*gomatrixserverlib.HeaderedEvent, error) { + roomInfo, err := d.RoomInfo(ctx, roomID) + if err != nil { + return nil, err + } + if roomInfo == nil { + return nil, fmt.Errorf("room %s doesn't exist", roomID) + } + // e.g invited rooms + if roomInfo.IsStub { + return nil, nil + } + eventTypeNID, err := d.EventTypesTable.SelectEventTypeNID(ctx, nil, evType) + if err == sql.ErrNoRows { + // No rooms have an event of this type, otherwise we'd have an event type NID + return nil, nil + } + if err != nil { + return nil, err + } + entries, err := d.loadStateAtSnapshot(ctx, roomInfo.StateSnapshotNID) + if err != nil { + return nil, err + } + var eventNIDs []types.EventNID + for _, e := range entries { + if e.EventTypeNID == eventTypeNID { + eventNIDs = append(eventNIDs, e.EventNID) + } + } + eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs) + if err != nil { + eventIDs = map[types.EventNID]string{} + } + // return the events requested + eventPairs, err := d.EventJSONTable.BulkSelectEventJSON(ctx, nil, eventNIDs) + if err != nil { + return nil, err + } + if len(eventPairs) == 0 { + return nil, nil + } + var result []*gomatrixserverlib.HeaderedEvent + for _, pair := range eventPairs { + ev, err := gomatrixserverlib.NewEventFromTrustedJSONWithEventID(eventIDs[pair.EventNID], pair.EventJSON, false, roomInfo.RoomVersion) + if err != nil { + return nil, err + } + result = append(result, ev.Headered(roomInfo.RoomVersion)) + } + + return result, nil +} + // GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key). func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership string) ([]string, error) { var membershipState tables.MembershipState diff --git a/setup/mscs/msc2946/msc2946.go b/setup/mscs/msc2946/msc2946.go index 3824c99a2..3bb56f4b3 100644 --- a/setup/mscs/msc2946/msc2946.go +++ b/setup/mscs/msc2946/msc2946.go @@ -18,17 +18,18 @@ package msc2946 import ( "context" "encoding/json" - "fmt" "net/http" + "net/url" + "sort" + "strconv" "strings" "sync" "time" + "github.com/google/uuid" "github.com/gorilla/mux" - chttputil "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" fs "github.com/matrix-org/dendrite/federationapi/api" - "github.com/matrix-org/dendrite/internal/hooks" "github.com/matrix-org/dendrite/internal/httputil" roomserver "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/base" @@ -39,15 +40,15 @@ import ( ) const ( - ConstCreateEventContentKey = "type" - ConstSpaceChildEventType = "m.space.child" - ConstSpaceParentEventType = "m.space.parent" + ConstCreateEventContentKey = "type" + ConstCreateEventContentValueSpace = "m.space" + ConstSpaceChildEventType = "m.space.child" + ConstSpaceParentEventType = "m.space.parent" ) -// Defaults sets the request defaults -func Defaults(r *gomatrixserverlib.MSC2946SpacesRequest) { - r.Limit = 2000 - r.MaxRoomsPerSpace = -1 +type MSC2946ClientResponse struct { + Rooms []gomatrixserverlib.MSC2946Room `json:"rooms"` + NextBatch string `json:"next_batch,omitempty"` } // Enable this MSC @@ -55,26 +56,11 @@ func Enable( base *base.BaseDendrite, rsAPI roomserver.RoomserverInternalAPI, userAPI userapi.UserInternalAPI, fsAPI fs.FederationInternalAPI, keyRing gomatrixserverlib.JSONVerifier, ) error { - db, err := NewDatabase(&base.Cfg.MSCs.Database) - if err != nil { - return fmt.Errorf("cannot enable MSC2946: %w", err) - } - hooks.Enable() - hooks.Attach(hooks.KindNewEventPersisted, func(headeredEvent interface{}) { - he := headeredEvent.(*gomatrixserverlib.HeaderedEvent) - hookErr := db.StoreReference(context.Background(), he) - if hookErr != nil { - util.GetLogger(context.Background()).WithError(hookErr).WithField("event_id", he.EventID()).Error( - "failed to StoreReference", - ) - } - }) + clientAPI := httputil.MakeAuthAPI("spaces", userAPI, spacesHandler(rsAPI, fsAPI, base.Cfg.Global.ServerName)) + base.PublicClientAPIMux.Handle("/v1/rooms/{roomID}/hierarchy", clientAPI).Methods(http.MethodGet, http.MethodOptions) + base.PublicClientAPIMux.Handle("/unstable/org.matrix.msc2946/rooms/{roomID}/hierarchy", clientAPI).Methods(http.MethodGet, http.MethodOptions) - base.PublicClientAPIMux.Handle("/unstable/org.matrix.msc2946/rooms/{roomID}/spaces", - httputil.MakeAuthAPI("spaces", userAPI, spacesHandler(db, rsAPI, fsAPI, base.Cfg.Global.ServerName)), - ).Methods(http.MethodPost, http.MethodOptions) - - base.PublicFederationAPIMux.Handle("/unstable/org.matrix.msc2946/spaces/{roomID}", httputil.MakeExternalAPI( + fedAPI := httputil.MakeExternalAPI( "msc2946_fed_spaces", func(req *http.Request) util.JSONResponse { fedReq, errResp := gomatrixserverlib.VerifyHTTPRequest( req, time.Now(), base.Cfg.Global.ServerName, keyRing, @@ -88,105 +74,99 @@ func Enable( return util.ErrorResponse(err) } roomID := params["roomID"] - return federatedSpacesHandler(req.Context(), fedReq, roomID, db, rsAPI, fsAPI, base.Cfg.Global.ServerName) + return federatedSpacesHandler(req.Context(), fedReq, roomID, rsAPI, fsAPI, base.Cfg.Global.ServerName) }, - )).Methods(http.MethodPost, http.MethodOptions) + ) + base.PublicFederationAPIMux.Handle("/unstable/org.matrix.msc2946/hierarchy/{roomID}", fedAPI).Methods(http.MethodGet) + base.PublicFederationAPIMux.Handle("/v1/hierarchy/{roomID}", fedAPI).Methods(http.MethodGet) return nil } func federatedSpacesHandler( - ctx context.Context, fedReq *gomatrixserverlib.FederationRequest, roomID string, db Database, + ctx context.Context, fedReq *gomatrixserverlib.FederationRequest, roomID string, rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationInternalAPI, thisServer gomatrixserverlib.ServerName, ) util.JSONResponse { - inMemoryBatchCache := make(map[string]set) - var r gomatrixserverlib.MSC2946SpacesRequest - Defaults(&r) - if err := json.Unmarshal(fedReq.Content(), &r); err != nil { + u, err := url.Parse(fedReq.RequestURI()) + if err != nil { return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("The request body could not be decoded into valid JSON. " + err.Error()), + Code: 400, + JSON: jsonerror.InvalidParam("bad request uri"), } } - w := walker{ - req: &r, - rootRoomID: roomID, - serverName: fedReq.Origin(), - thisServer: thisServer, - ctx: ctx, - db: db, - rsAPI: rsAPI, - fsAPI: fsAPI, - inMemoryBatchCache: inMemoryBatchCache, - } - res := w.walk() - return util.JSONResponse{ - Code: 200, - JSON: res, + w := walker{ + rootRoomID: roomID, + serverName: fedReq.Origin(), + thisServer: thisServer, + ctx: ctx, + suggestedOnly: u.Query().Get("suggested_only") == "true", + limit: 1000, + // The main difference is that it does not recurse into spaces and does not support pagination. + // This is somewhat equivalent to a Client-Server request with a max_depth=1. + maxDepth: 1, + + rsAPI: rsAPI, + fsAPI: fsAPI, + // inline cache as we don't have pagination in federation mode + paginationCache: make(map[string]paginationInfo), } + return w.walk() } func spacesHandler( - db Database, rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationInternalAPI, + rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationInternalAPI, thisServer gomatrixserverlib.ServerName, ) func(*http.Request, *userapi.Device) util.JSONResponse { + // declared outside the returned handler so it persists between calls + // TODO: clear based on... time? + paginationCache := make(map[string]paginationInfo) + return func(req *http.Request, device *userapi.Device) util.JSONResponse { - inMemoryBatchCache := make(map[string]set) // Extract the room ID from the request. Sanity check request data. params, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } roomID := params["roomID"] - var r gomatrixserverlib.MSC2946SpacesRequest - Defaults(&r) - if resErr := chttputil.UnmarshalJSONRequest(req, &r); resErr != nil { - return *resErr - } w := walker{ - req: &r, - rootRoomID: roomID, - caller: device, - thisServer: thisServer, - ctx: req.Context(), + suggestedOnly: req.URL.Query().Get("suggested_only") == "true", + limit: parseInt(req.URL.Query().Get("limit"), 1000), + maxDepth: parseInt(req.URL.Query().Get("max_depth"), -1), + paginationToken: req.URL.Query().Get("from"), + rootRoomID: roomID, + caller: device, + thisServer: thisServer, + ctx: req.Context(), - db: db, - rsAPI: rsAPI, - fsAPI: fsAPI, - inMemoryBatchCache: inMemoryBatchCache, - } - res := w.walk() - return util.JSONResponse{ - Code: 200, - JSON: res, + rsAPI: rsAPI, + fsAPI: fsAPI, + paginationCache: paginationCache, } + return w.walk() } } +type paginationInfo struct { + processed set + unvisited []roomVisit +} + type walker struct { - req *gomatrixserverlib.MSC2946SpacesRequest - rootRoomID string - caller *userapi.Device - serverName gomatrixserverlib.ServerName - thisServer gomatrixserverlib.ServerName - db Database - rsAPI roomserver.RoomserverInternalAPI - fsAPI fs.FederationInternalAPI - ctx context.Context + rootRoomID string + caller *userapi.Device + serverName gomatrixserverlib.ServerName + thisServer gomatrixserverlib.ServerName + rsAPI roomserver.RoomserverInternalAPI + fsAPI fs.FederationInternalAPI + ctx context.Context + suggestedOnly bool + limit int + maxDepth int + paginationToken string - // user ID|device ID|batch_num => event/room IDs sent to client - inMemoryBatchCache map[string]set - mu sync.Mutex -} - -func (w *walker) roomIsExcluded(roomID string) bool { - for _, exclRoom := range w.req.ExcludeRooms { - if exclRoom == roomID { - return true - } - } - return false + paginationCache map[string]paginationInfo + mu sync.Mutex } func (w *walker) callerID() string { @@ -196,144 +176,207 @@ func (w *walker) callerID() string { return string(w.serverName) } -func (w *walker) alreadySent(id string) bool { - w.mu.Lock() - defer w.mu.Unlock() - m, ok := w.inMemoryBatchCache[w.callerID()] - if !ok { - return false +func (w *walker) newPaginationCache() (string, paginationInfo) { + p := paginationInfo{ + processed: make(set), + unvisited: nil, } - return m[id] + tok := uuid.NewString() + return tok, p } -func (w *walker) markSent(id string) { +func (w *walker) loadPaginationCache(paginationToken string) *paginationInfo { w.mu.Lock() defer w.mu.Unlock() - m := w.inMemoryBatchCache[w.callerID()] - if m == nil { - m = make(set) - } - m[id] = true - w.inMemoryBatchCache[w.callerID()] = m + p := w.paginationCache[paginationToken] + return &p } -func (w *walker) walk() *gomatrixserverlib.MSC2946SpacesResponse { - var res gomatrixserverlib.MSC2946SpacesResponse - // Begin walking the graph starting with the room ID in the request in a queue of unvisited rooms - unvisited := []string{w.rootRoomID} - processed := make(set) +func (w *walker) storePaginationCache(paginationToken string, cache paginationInfo) { + w.mu.Lock() + defer w.mu.Unlock() + w.paginationCache[paginationToken] = cache +} + +type roomVisit struct { + roomID string + depth int + vias []string // vias to query this room by +} + +func (w *walker) walk() util.JSONResponse { + if !w.authorised(w.rootRoomID) { + if w.caller != nil { + // CS API format + return util.JSONResponse{ + Code: 403, + JSON: jsonerror.Forbidden("room is unknown/forbidden"), + } + } else { + // SS API format + return util.JSONResponse{ + Code: 404, + JSON: jsonerror.NotFound("room is unknown/forbidden"), + } + } + } + + var discoveredRooms []gomatrixserverlib.MSC2946Room + + var cache *paginationInfo + if w.paginationToken != "" { + cache = w.loadPaginationCache(w.paginationToken) + if cache == nil { + return util.JSONResponse{ + Code: 400, + JSON: jsonerror.InvalidArgumentValue("invalid from"), + } + } + } else { + tok, c := w.newPaginationCache() + cache = &c + w.paginationToken = tok + // Begin walking the graph starting with the room ID in the request in a queue of unvisited rooms + c.unvisited = append(c.unvisited, roomVisit{ + roomID: w.rootRoomID, + depth: 0, + }) + } + + processed := cache.processed + unvisited := cache.unvisited + + // Depth first -> stack data structure for len(unvisited) > 0 { - roomID := unvisited[0] - unvisited = unvisited[1:] - // If this room has already been processed, skip. NB: do not remember this between calls - if processed[roomID] || roomID == "" { + if len(discoveredRooms) >= w.limit { + break + } + + // pop the stack + rv := unvisited[len(unvisited)-1] + unvisited = unvisited[:len(unvisited)-1] + // If this room has already been processed, skip. + // If this room exceeds the specified depth, skip. + if processed.isSet(rv.roomID) || rv.roomID == "" || (w.maxDepth > 0 && rv.depth > w.maxDepth) { continue } + // Mark this room as processed. - processed[roomID] = true + processed.set(rv.roomID) + + // if this room is not a space room, skip. + var roomType string + create := w.stateEvent(rv.roomID, gomatrixserverlib.MRoomCreate, "") + if create != nil { + // escape the `.`s so gjson doesn't think it's nested + roomType = gjson.GetBytes(create.Content(), strings.ReplaceAll(ConstCreateEventContentKey, ".", `\.`)).Str + } // Collect rooms/events to send back (either locally or fetched via federation) - var discoveredRooms []gomatrixserverlib.MSC2946Room - var discoveredEvents []gomatrixserverlib.MSC2946StrippedEvent + var discoveredChildEvents []gomatrixserverlib.MSC2946StrippedEvent // If we know about this room and the caller is authorised (joined/world_readable) then pull // events locally - if w.roomExists(roomID) && w.authorised(roomID) { - // Get all `m.space.child` and `m.space.parent` state events for the room. *In addition*, get - // all `m.space.child` and `m.space.parent` state events which *point to* (via `state_key` or `content.room_id`) - // this room. This requires servers to store reverse lookups. - events, err := w.references(roomID) + if w.roomExists(rv.roomID) && w.authorised(rv.roomID) { + // Get all `m.space.child` state events for this room + events, err := w.childReferences(rv.roomID) if err != nil { - util.GetLogger(w.ctx).WithError(err).WithField("room_id", roomID).Error("failed to extract references for room") + util.GetLogger(w.ctx).WithError(err).WithField("room_id", rv.roomID).Error("failed to extract references for room") continue } - discoveredEvents = events + discoveredChildEvents = events - pubRoom := w.publicRoomsChunk(roomID) - roomType := "" - create := w.stateEvent(roomID, gomatrixserverlib.MRoomCreate, "") - if create != nil { - // escape the `.`s so gjson doesn't think it's nested - roomType = gjson.GetBytes(create.Content(), strings.ReplaceAll(ConstCreateEventContentKey, ".", `\.`)).Str - } + pubRoom := w.publicRoomsChunk(rv.roomID) - // Add the total number of events to `PublicRoomsChunk` under `num_refs`. Add `PublicRoomsChunk` to `rooms`. discoveredRooms = append(discoveredRooms, gomatrixserverlib.MSC2946Room{ - PublicRoom: *pubRoom, - NumRefs: len(discoveredEvents), - RoomType: roomType, + PublicRoom: *pubRoom, + RoomType: roomType, + ChildrenState: events, }) } else { // attempt to query this room over federation, as either we've never heard of it before // or we've left it and hence are not authorised (but info may be exposed regardless) - fedRes, err := w.federatedRoomInfo(roomID) + fedRes, err := w.federatedRoomInfo(rv.roomID, rv.vias) if err != nil { - util.GetLogger(w.ctx).WithError(err).WithField("room_id", roomID).Errorf("failed to query federated spaces") + util.GetLogger(w.ctx).WithError(err).WithField("room_id", rv.roomID).Errorf("failed to query federated spaces") continue } if fedRes != nil { - discoveredRooms = fedRes.Rooms - discoveredEvents = fedRes.Events + discoveredChildEvents = fedRes.Room.ChildrenState + discoveredRooms = append(discoveredRooms, fedRes.Room) + if len(fedRes.Children) > 0 { + discoveredRooms = append(discoveredRooms, fedRes.Children...) + } + // mark this room as a space room as the federated server responded. + // we need to do this so we add the children of this room to the unvisited stack + // as these children may be rooms we do know about. + roomType = ConstCreateEventContentValueSpace } } - // If this room has not ever been in `rooms` (across multiple requests), send it now - for _, room := range discoveredRooms { - if !w.alreadySent(room.RoomID) && !w.roomIsExcluded(room.RoomID) { - res.Rooms = append(res.Rooms, room) - w.markSent(room.RoomID) - } + // don't walk the children + // if the parent is not a space room + if roomType != ConstCreateEventContentValueSpace { + continue } - uniqueRooms := make(set) - - // If this is the root room from the original request, insert all these events into `events` if - // they haven't been added before (across multiple requests). - if w.rootRoomID == roomID { - for _, ev := range discoveredEvents { - if !w.alreadySent(eventKey(&ev)) { - res.Events = append(res.Events, ev) - uniqueRooms[ev.RoomID] = true - uniqueRooms[spaceTargetStripped(&ev)] = true - w.markSent(eventKey(&ev)) - } - } - } else { - // Else add them to `events` honouring the `limit` and `max_rooms_per_space` values. If either - // are exceeded, stop adding events. If the event has already been added, do not add it again. - numAdded := 0 - for _, ev := range discoveredEvents { - if w.req.Limit > 0 && len(res.Events) >= w.req.Limit { - break - } - if w.req.MaxRoomsPerSpace > 0 && numAdded >= w.req.MaxRoomsPerSpace { - break - } - if w.alreadySent(eventKey(&ev)) { - continue - } - // Skip the room if it's part of exclude_rooms but ONLY IF the source matches, as we still - // want to catch arrows which point to excluded rooms. - if w.roomIsExcluded(ev.RoomID) { - continue - } - res.Events = append(res.Events, ev) - uniqueRooms[ev.RoomID] = true - uniqueRooms[spaceTargetStripped(&ev)] = true - w.markSent(eventKey(&ev)) - // we don't distinguish between child state events and parent state events for the purposes of - // max_rooms_per_space, maybe we should? - numAdded++ - } - } - - // For each referenced room ID in the events being returned to the caller (both parent and child) + // For each referenced room ID in the child events being returned to the caller // add the room ID to the queue of unvisited rooms. Loop from the beginning. - for roomID := range uniqueRooms { - unvisited = append(unvisited, roomID) + // We need to invert the order here because the child events are lo->hi on the timestamp, + // so we need to ensure we pop in the same lo->hi order, which won't be the case if we + // insert the highest timestamp last in a stack. + for i := len(discoveredChildEvents) - 1; i >= 0; i-- { + spaceContent := struct { + Via []string `json:"via"` + }{} + ev := discoveredChildEvents[i] + _ = json.Unmarshal(ev.Content, &spaceContent) + unvisited = append(unvisited, roomVisit{ + roomID: ev.StateKey, + depth: rv.depth + 1, + vias: spaceContent.Via, + }) } } - return &res + + if len(unvisited) > 0 { + // we still have more rooms so we need to send back a pagination token, + // we probably hit a room limit + cache.processed = processed + cache.unvisited = unvisited + w.storePaginationCache(w.paginationToken, *cache) + } else { + // clear the pagination token so we don't send it back to the client + // Note we do NOT nuke the cache just in case this response is lost + // and the client retries it. + w.paginationToken = "" + } + + if w.caller != nil { + // return CS API format + return util.JSONResponse{ + Code: 200, + JSON: MSC2946ClientResponse{ + Rooms: discoveredRooms, + NextBatch: w.paginationToken, + }, + } + } + // return SS API format + // the first discovered room will be the room asked for, and subsequent ones the depth=1 children + if len(discoveredRooms) == 0 { + return util.JSONResponse{ + Code: 404, + JSON: jsonerror.NotFound("room is unknown/forbidden"), + } + } + return util.JSONResponse{ + Code: 200, + JSON: gomatrixserverlib.MSC2946SpacesResponse{ + Room: discoveredRooms[0], + Children: discoveredRooms[1:], + }, + } } func (w *walker) stateEvent(roomID, evType, stateKey string) *gomatrixserverlib.HeaderedEvent { @@ -366,42 +409,19 @@ func (w *walker) publicRoomsChunk(roomID string) *gomatrixserverlib.PublicRoom { // federatedRoomInfo returns more of the spaces graph from another server. Returns nil if this was // unsuccessful. -func (w *walker) federatedRoomInfo(roomID string) (*gomatrixserverlib.MSC2946SpacesResponse, error) { +func (w *walker) federatedRoomInfo(roomID string, vias []string) (*gomatrixserverlib.MSC2946SpacesResponse, error) { // only do federated requests for client requests if w.caller == nil { return nil, nil } - // extract events which point to this room ID and extract their vias - events, err := w.db.References(w.ctx, roomID) - if err != nil { - return nil, fmt.Errorf("failed to get References events: %w", err) - } - vias := make(set) - for _, ev := range events { - if ev.StateKeyEquals(roomID) { - // event points at this room, extract vias - content := struct { - Vias []string `json:"via"` - }{} - if err = json.Unmarshal(ev.Content(), &content); err != nil { - continue // silently ignore corrupted state events - } - for _, v := range content.Vias { - vias[v] = true - } - } - } - util.GetLogger(w.ctx).Infof("Querying federatedRoomInfo via %+v", vias) + util.GetLogger(w.ctx).Infof("Querying %s via %+v", roomID, vias) ctx := context.Background() // query more of the spaces graph using these servers - for serverName := range vias { + for _, serverName := range vias { if serverName == string(w.thisServer) { continue } - res, err := w.fsAPI.MSC2946Spaces(ctx, gomatrixserverlib.ServerName(serverName), roomID, gomatrixserverlib.MSC2946SpacesRequest{ - Limit: w.req.Limit, - MaxRoomsPerSpace: w.req.MaxRoomsPerSpace, - }) + res, err := w.fsAPI.MSC2946Spaces(ctx, gomatrixserverlib.ServerName(serverName), roomID, w.suggestedOnly) if err != nil { util.GetLogger(w.ctx).WithError(err).Warnf("failed to call MSC2946Spaces on server %s", serverName) continue @@ -501,7 +521,7 @@ func (w *walker) authorisedUser(roomID string) bool { hisVisEv := queryRes.StateEvents[hisVisTuple] if memberEv != nil { membership, _ := memberEv.Membership() - if membership == gomatrixserverlib.Join { + if membership == gomatrixserverlib.Join || membership == gomatrixserverlib.Invite { return true } } @@ -514,40 +534,85 @@ func (w *walker) authorisedUser(roomID string) bool { return false } -// references returns all references pointing to or from this room. -func (w *walker) references(roomID string) ([]gomatrixserverlib.MSC2946StrippedEvent, error) { - events, err := w.db.References(w.ctx, roomID) +// references returns all child references pointing to or from this room. +func (w *walker) childReferences(roomID string) ([]gomatrixserverlib.MSC2946StrippedEvent, error) { + createTuple := gomatrixserverlib.StateKeyTuple{ + EventType: gomatrixserverlib.MRoomCreate, + StateKey: "", + } + var res roomserver.QueryCurrentStateResponse + err := w.rsAPI.QueryCurrentState(context.Background(), &roomserver.QueryCurrentStateRequest{ + RoomID: roomID, + AllowWildcards: true, + StateTuples: []gomatrixserverlib.StateKeyTuple{ + createTuple, { + EventType: ConstSpaceChildEventType, + StateKey: "*", + }, + }, + }, &res) if err != nil { return nil, err } - el := make([]gomatrixserverlib.MSC2946StrippedEvent, 0, len(events)) - for _, ev := range events { + + // don't return any child refs if the room is not a space room + if res.StateEvents[createTuple] != nil { + // escape the `.`s so gjson doesn't think it's nested + roomType := gjson.GetBytes(res.StateEvents[createTuple].Content(), strings.ReplaceAll(ConstCreateEventContentKey, ".", `\.`)).Str + if roomType != ConstCreateEventContentValueSpace { + return nil, nil + } + } + delete(res.StateEvents, createTuple) + + el := make([]gomatrixserverlib.MSC2946StrippedEvent, 0, len(res.StateEvents)) + for _, ev := range res.StateEvents { + content := gjson.ParseBytes(ev.Content()) // only return events that have a `via` key as per MSC1772 // else we'll incorrectly walk redacted events (as the link // is in the state_key) - if gjson.GetBytes(ev.Content(), "via").Exists() { + if content.Get("via").Exists() { strip := stripped(ev.Event) if strip == nil { continue } + // if suggested only and this child isn't suggested, skip it. + // if suggested only = false we include everything so don't need to check the content. + if w.suggestedOnly && !content.Get("suggested").Bool() { + continue + } el = append(el, *strip) } } + // sort by origin_server_ts as per MSC2946 + sort.Slice(el, func(i, j int) bool { + return el[i].OriginServerTS < el[j].OriginServerTS + }) + return el, nil } -type set map[string]bool +type set map[string]struct{} + +func (s set) set(val string) { + s[val] = struct{}{} +} +func (s set) isSet(val string) bool { + _, ok := s[val] + return ok +} func stripped(ev *gomatrixserverlib.Event) *gomatrixserverlib.MSC2946StrippedEvent { if ev.StateKey() == nil { return nil } return &gomatrixserverlib.MSC2946StrippedEvent{ - Type: ev.Type(), - StateKey: *ev.StateKey(), - Content: ev.Content(), - Sender: ev.Sender(), - RoomID: ev.RoomID(), + Type: ev.Type(), + StateKey: *ev.StateKey(), + Content: ev.Content(), + Sender: ev.Sender(), + RoomID: ev.RoomID(), + OriginServerTS: ev.OriginServerTS(), } } @@ -567,3 +632,11 @@ func spaceTargetStripped(event *gomatrixserverlib.MSC2946StrippedEvent) string { } return "" } + +func parseInt(intstr string, defaultVal int) int { + i, err := strconv.ParseInt(intstr, 10, 32) + if err != nil { + return defaultVal + } + return int(i) +} diff --git a/setup/mscs/msc2946/msc2946_test.go b/setup/mscs/msc2946/msc2946_test.go deleted file mode 100644 index e8066c34d..000000000 --- a/setup/mscs/msc2946/msc2946_test.go +++ /dev/null @@ -1,464 +0,0 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package msc2946_test - -import ( - "bytes" - "context" - "crypto/ed25519" - "encoding/json" - "io/ioutil" - "net/http" - "net/url" - "testing" - "time" - - "github.com/gorilla/mux" - "github.com/matrix-org/dendrite/internal/hooks" - "github.com/matrix-org/dendrite/internal/httputil" - roomserver "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/setup/mscs/msc2946" - userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" -) - -var ( - client = &http.Client{ - Timeout: 10 * time.Second, - } - roomVer = gomatrixserverlib.RoomVersionV6 -) - -// Basic sanity check of MSC2946 logic. Tests a single room with a few state events -// and a bit of recursion to subspaces. Makes a graph like: -// Root -// ____|_____ -// | | | -// R1 R2 S1 -// |_________ -// | | | -// R3 R4 S2 -// | <-- this link is just a parent, not a child -// R5 -// -// Alice is not joined to R4, but R4 is "world_readable". -func TestMSC2946(t *testing.T) { - alice := "@alice:localhost" - // give access token to alice - nopUserAPI := &testUserAPI{ - accessTokens: make(map[string]userapi.Device), - } - nopUserAPI.accessTokens["alice"] = userapi.Device{ - AccessToken: "alice", - DisplayName: "Alice", - UserID: alice, - } - rootSpace := "!rootspace:localhost" - subSpaceS1 := "!subspaceS1:localhost" - subSpaceS2 := "!subspaceS2:localhost" - room1 := "!room1:localhost" - room2 := "!room2:localhost" - room3 := "!room3:localhost" - room4 := "!room4:localhost" - empty := "" - room5 := "!room5:localhost" - allRooms := []string{ - rootSpace, subSpaceS1, subSpaceS2, - room1, room2, room3, room4, room5, - } - rootToR1 := mustCreateEvent(t, fledglingEvent{ - RoomID: rootSpace, - Sender: alice, - Type: msc2946.ConstSpaceChildEventType, - StateKey: &room1, - Content: map[string]interface{}{ - "via": []string{"localhost"}, - }, - }) - rootToR2 := mustCreateEvent(t, fledglingEvent{ - RoomID: rootSpace, - Sender: alice, - Type: msc2946.ConstSpaceChildEventType, - StateKey: &room2, - Content: map[string]interface{}{ - "via": []string{"localhost"}, - }, - }) - rootToS1 := mustCreateEvent(t, fledglingEvent{ - RoomID: rootSpace, - Sender: alice, - Type: msc2946.ConstSpaceChildEventType, - StateKey: &subSpaceS1, - Content: map[string]interface{}{ - "via": []string{"localhost"}, - }, - }) - s1ToR3 := mustCreateEvent(t, fledglingEvent{ - RoomID: subSpaceS1, - Sender: alice, - Type: msc2946.ConstSpaceChildEventType, - StateKey: &room3, - Content: map[string]interface{}{ - "via": []string{"localhost"}, - }, - }) - s1ToR4 := mustCreateEvent(t, fledglingEvent{ - RoomID: subSpaceS1, - Sender: alice, - Type: msc2946.ConstSpaceChildEventType, - StateKey: &room4, - Content: map[string]interface{}{ - "via": []string{"localhost"}, - }, - }) - s1ToS2 := mustCreateEvent(t, fledglingEvent{ - RoomID: subSpaceS1, - Sender: alice, - Type: msc2946.ConstSpaceChildEventType, - StateKey: &subSpaceS2, - Content: map[string]interface{}{ - "via": []string{"localhost"}, - }, - }) - // This is a parent link only - s2ToR5 := mustCreateEvent(t, fledglingEvent{ - RoomID: room5, - Sender: alice, - Type: msc2946.ConstSpaceParentEventType, - StateKey: &subSpaceS2, - Content: map[string]interface{}{ - "via": []string{"localhost"}, - }, - }) - // history visibility for R4 - r4HisVis := mustCreateEvent(t, fledglingEvent{ - RoomID: room4, - Sender: "@someone:localhost", - Type: gomatrixserverlib.MRoomHistoryVisibility, - StateKey: &empty, - Content: map[string]interface{}{ - "history_visibility": "world_readable", - }, - }) - var joinEvents []*gomatrixserverlib.HeaderedEvent - for _, roomID := range allRooms { - if roomID == room4 { - continue // not joined to that room - } - joinEvents = append(joinEvents, mustCreateEvent(t, fledglingEvent{ - RoomID: roomID, - Sender: alice, - StateKey: &alice, - Type: gomatrixserverlib.MRoomMember, - Content: map[string]interface{}{ - "membership": "join", - }, - })) - } - roomNameTuple := gomatrixserverlib.StateKeyTuple{ - EventType: "m.room.name", - StateKey: "", - } - hisVisTuple := gomatrixserverlib.StateKeyTuple{ - EventType: "m.room.history_visibility", - StateKey: "", - } - nopRsAPI := &testRoomserverAPI{ - joinEvents: joinEvents, - events: map[string]*gomatrixserverlib.HeaderedEvent{ - rootToR1.EventID(): rootToR1, - rootToR2.EventID(): rootToR2, - rootToS1.EventID(): rootToS1, - s1ToR3.EventID(): s1ToR3, - s1ToR4.EventID(): s1ToR4, - s1ToS2.EventID(): s1ToS2, - s2ToR5.EventID(): s2ToR5, - r4HisVis.EventID(): r4HisVis, - }, - pubRoomState: map[string]map[gomatrixserverlib.StateKeyTuple]string{ - rootSpace: { - roomNameTuple: "Root", - hisVisTuple: "shared", - }, - subSpaceS1: { - roomNameTuple: "Sub-Space 1", - hisVisTuple: "joined", - }, - subSpaceS2: { - roomNameTuple: "Sub-Space 2", - hisVisTuple: "shared", - }, - room1: { - hisVisTuple: "joined", - }, - room2: { - hisVisTuple: "joined", - }, - room3: { - hisVisTuple: "joined", - }, - room4: { - hisVisTuple: "world_readable", - }, - room5: { - hisVisTuple: "joined", - }, - }, - } - allEvents := []*gomatrixserverlib.HeaderedEvent{ - rootToR1, rootToR2, rootToS1, - s1ToR3, s1ToR4, s1ToS2, - s2ToR5, r4HisVis, - } - allEvents = append(allEvents, joinEvents...) - router := injectEvents(t, nopUserAPI, nopRsAPI, allEvents) - cancel := runServer(t, router) - defer cancel() - - t.Run("returns no events for unknown rooms", func(t *testing.T) { - res := postSpaces(t, 200, "alice", "!unknown:localhost", newReq(t, map[string]interface{}{})) - if len(res.Events) > 0 { - t.Errorf("got %d events, want 0", len(res.Events)) - } - if len(res.Rooms) > 0 { - t.Errorf("got %d rooms, want 0", len(res.Rooms)) - } - }) - t.Run("returns the entire graph", func(t *testing.T) { - res := postSpaces(t, 200, "alice", rootSpace, newReq(t, map[string]interface{}{})) - if len(res.Events) != 7 { - t.Errorf("got %d events, want 7", len(res.Events)) - } - if len(res.Rooms) != len(allRooms) { - t.Errorf("got %d rooms, want %d", len(res.Rooms), len(allRooms)) - } - }) - t.Run("can update the graph", func(t *testing.T) { - // remove R3 from the graph - rmS1ToR3 := mustCreateEvent(t, fledglingEvent{ - RoomID: subSpaceS1, - Sender: alice, - Type: msc2946.ConstSpaceChildEventType, - StateKey: &room3, - Content: map[string]interface{}{}, // redacted - }) - nopRsAPI.events[rmS1ToR3.EventID()] = rmS1ToR3 - hooks.Run(hooks.KindNewEventPersisted, rmS1ToR3) - - res := postSpaces(t, 200, "alice", rootSpace, newReq(t, map[string]interface{}{})) - if len(res.Events) != 6 { // one less since we don't return redacted events - t.Errorf("got %d events, want 6", len(res.Events)) - } - if len(res.Rooms) != (len(allRooms) - 1) { // one less due to lack of R3 - t.Errorf("got %d rooms, want %d", len(res.Rooms), len(allRooms)-1) - } - }) -} - -func newReq(t *testing.T, jsonBody map[string]interface{}) *gomatrixserverlib.MSC2946SpacesRequest { - t.Helper() - b, err := json.Marshal(jsonBody) - if err != nil { - t.Fatalf("Failed to marshal request: %s", err) - } - var r gomatrixserverlib.MSC2946SpacesRequest - if err := json.Unmarshal(b, &r); err != nil { - t.Fatalf("Failed to unmarshal request: %s", err) - } - return &r -} - -func runServer(t *testing.T, router *mux.Router) func() { - t.Helper() - externalServ := &http.Server{ - Addr: string(":8010"), - WriteTimeout: 60 * time.Second, - Handler: router, - } - go func() { - externalServ.ListenAndServe() - }() - // wait to listen on the port - time.Sleep(500 * time.Millisecond) - return func() { - externalServ.Shutdown(context.TODO()) - } -} - -func postSpaces(t *testing.T, expectCode int, accessToken, roomID string, req *gomatrixserverlib.MSC2946SpacesRequest) *gomatrixserverlib.MSC2946SpacesResponse { - t.Helper() - var r gomatrixserverlib.MSC2946SpacesRequest - msc2946.Defaults(&r) - data, err := json.Marshal(req) - if err != nil { - t.Fatalf("failed to marshal request: %s", err) - } - httpReq, err := http.NewRequest( - "POST", "http://localhost:8010/_matrix/client/unstable/org.matrix.msc2946/rooms/"+url.PathEscape(roomID)+"/spaces", - bytes.NewBuffer(data), - ) - httpReq.Header.Set("Authorization", "Bearer "+accessToken) - if err != nil { - t.Fatalf("failed to prepare request: %s", err) - } - res, err := client.Do(httpReq) - if err != nil { - t.Fatalf("failed to do request: %s", err) - } - if res.StatusCode != expectCode { - body, _ := ioutil.ReadAll(res.Body) - t.Fatalf("wrong response code, got %d want %d - body: %s", res.StatusCode, expectCode, string(body)) - } - if res.StatusCode == 200 { - var result gomatrixserverlib.MSC2946SpacesResponse - body, err := ioutil.ReadAll(res.Body) - if err != nil { - t.Fatalf("response 200 OK but failed to read response body: %s", err) - } - t.Logf("Body: %s", string(body)) - if err := json.Unmarshal(body, &result); err != nil { - t.Fatalf("response 200 OK but failed to deserialise JSON : %s\nbody: %s", err, string(body)) - } - return &result - } - return nil -} - -type testUserAPI struct { - userapi.UserInternalAPITrace - accessTokens map[string]userapi.Device -} - -func (u *testUserAPI) QueryAccessToken(ctx context.Context, req *userapi.QueryAccessTokenRequest, res *userapi.QueryAccessTokenResponse) error { - dev, ok := u.accessTokens[req.AccessToken] - if !ok { - res.Err = "unknown token" - return nil - } - res.Device = &dev - return nil -} - -type testRoomserverAPI struct { - // use a trace API as it implements method stubs so we don't need to have them here. - // We'll override the functions we care about. - roomserver.RoomserverInternalAPITrace - joinEvents []*gomatrixserverlib.HeaderedEvent - events map[string]*gomatrixserverlib.HeaderedEvent - pubRoomState map[string]map[gomatrixserverlib.StateKeyTuple]string -} - -func (r *testRoomserverAPI) QueryServerJoinedToRoom(ctx context.Context, req *roomserver.QueryServerJoinedToRoomRequest, res *roomserver.QueryServerJoinedToRoomResponse) error { - res.IsInRoom = true - res.RoomExists = true - return nil -} - -func (r *testRoomserverAPI) QueryBulkStateContent(ctx context.Context, req *roomserver.QueryBulkStateContentRequest, res *roomserver.QueryBulkStateContentResponse) error { - res.Rooms = make(map[string]map[gomatrixserverlib.StateKeyTuple]string) - for _, roomID := range req.RoomIDs { - pubRoomData, ok := r.pubRoomState[roomID] - if ok { - res.Rooms[roomID] = pubRoomData - } - } - return nil -} - -func (r *testRoomserverAPI) QueryCurrentState(ctx context.Context, req *roomserver.QueryCurrentStateRequest, res *roomserver.QueryCurrentStateResponse) error { - res.StateEvents = make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent) - checkEvent := func(he *gomatrixserverlib.HeaderedEvent) { - if he.RoomID() != req.RoomID { - return - } - if he.StateKey() == nil { - return - } - tuple := gomatrixserverlib.StateKeyTuple{ - EventType: he.Type(), - StateKey: *he.StateKey(), - } - for _, t := range req.StateTuples { - if t == tuple { - res.StateEvents[t] = he - } - } - } - for _, he := range r.joinEvents { - checkEvent(he) - } - for _, he := range r.events { - checkEvent(he) - } - return nil -} - -func injectEvents(t *testing.T, userAPI userapi.UserInternalAPI, rsAPI roomserver.RoomserverInternalAPI, events []*gomatrixserverlib.HeaderedEvent) *mux.Router { - t.Helper() - cfg := &config.Dendrite{} - cfg.Defaults(true) - cfg.Global.ServerName = "localhost" - cfg.MSCs.Database.ConnectionString = "file:msc2946_test.db" - cfg.MSCs.MSCs = []string{"msc2946"} - base := &base.BaseDendrite{ - Cfg: cfg, - PublicClientAPIMux: mux.NewRouter().PathPrefix(httputil.PublicClientPathPrefix).Subrouter(), - PublicFederationAPIMux: mux.NewRouter().PathPrefix(httputil.PublicFederationPathPrefix).Subrouter(), - } - - err := msc2946.Enable(base, rsAPI, userAPI, nil, nil) - if err != nil { - t.Fatalf("failed to enable MSC2946: %s", err) - } - for _, ev := range events { - hooks.Run(hooks.KindNewEventPersisted, ev) - } - return base.PublicClientAPIMux -} - -type fledglingEvent struct { - Type string - StateKey *string - Content interface{} - Sender string - RoomID string -} - -func mustCreateEvent(t *testing.T, ev fledglingEvent) (result *gomatrixserverlib.HeaderedEvent) { - t.Helper() - seed := make([]byte, ed25519.SeedSize) // zero seed - key := ed25519.NewKeyFromSeed(seed) - eb := gomatrixserverlib.EventBuilder{ - Sender: ev.Sender, - Depth: 999, - Type: ev.Type, - StateKey: ev.StateKey, - RoomID: ev.RoomID, - } - err := eb.SetContent(ev.Content) - if err != nil { - t.Fatalf("mustCreateEvent: failed to marshal event content %+v", ev.Content) - } - // make sure the origin_server_ts changes so we can test recency - time.Sleep(1 * time.Millisecond) - signedEvent, err := eb.Build(time.Now(), gomatrixserverlib.ServerName("localhost"), "ed25519:test", key, roomVer) - if err != nil { - t.Fatalf("mustCreateEvent: failed to sign event: %s", err) - } - h := signedEvent.Headered(roomVer) - return h -} diff --git a/setup/mscs/msc2946/storage.go b/setup/mscs/msc2946/storage.go deleted file mode 100644 index 20db18594..000000000 --- a/setup/mscs/msc2946/storage.go +++ /dev/null @@ -1,182 +0,0 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package msc2946 - -import ( - "context" - "database/sql" - - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/gomatrixserverlib" -) - -var ( - relTypes = map[string]int{ - ConstSpaceChildEventType: 1, - ConstSpaceParentEventType: 2, - } -) - -type Database interface { - // StoreReference persists a child or parent space mapping. - StoreReference(ctx context.Context, he *gomatrixserverlib.HeaderedEvent) error - // References returns all events which have the given roomID as a parent or child space. - References(ctx context.Context, roomID string) ([]*gomatrixserverlib.HeaderedEvent, error) -} - -type DB struct { - db *sql.DB - writer sqlutil.Writer - insertEdgeStmt *sql.Stmt - selectEdgesStmt *sql.Stmt -} - -// NewDatabase loads the database for msc2836 -func NewDatabase(dbOpts *config.DatabaseOptions) (Database, error) { - if dbOpts.ConnectionString.IsPostgres() { - return newPostgresDatabase(dbOpts) - } - return newSQLiteDatabase(dbOpts) -} - -func newPostgresDatabase(dbOpts *config.DatabaseOptions) (Database, error) { - d := DB{ - writer: sqlutil.NewDummyWriter(), - } - var err error - if d.db, err = sqlutil.Open(dbOpts); err != nil { - return nil, err - } - _, err = d.db.Exec(` - CREATE TABLE IF NOT EXISTS msc2946_edges ( - room_version TEXT NOT NULL, - -- the room ID of the event, the source of the arrow - source_room_id TEXT NOT NULL, - -- the target room ID, the arrow destination - dest_room_id TEXT NOT NULL, - -- the kind of relation, either child or parent (1,2) - rel_type SMALLINT NOT NULL, - event_json TEXT NOT NULL, - CONSTRAINT msc2946_edges_uniq UNIQUE (source_room_id, dest_room_id, rel_type) - ); - `) - if err != nil { - return nil, err - } - if d.insertEdgeStmt, err = d.db.Prepare(` - INSERT INTO msc2946_edges(room_version, source_room_id, dest_room_id, rel_type, event_json) - VALUES($1, $2, $3, $4, $5) - ON CONFLICT ON CONSTRAINT msc2946_edges_uniq DO UPDATE SET event_json = $5 - `); err != nil { - return nil, err - } - if d.selectEdgesStmt, err = d.db.Prepare(` - SELECT room_version, event_json FROM msc2946_edges - WHERE source_room_id = $1 OR dest_room_id = $2 - `); err != nil { - return nil, err - } - return &d, err -} - -func newSQLiteDatabase(dbOpts *config.DatabaseOptions) (Database, error) { - d := DB{ - writer: sqlutil.NewExclusiveWriter(), - } - var err error - if d.db, err = sqlutil.Open(dbOpts); err != nil { - return nil, err - } - _, err = d.db.Exec(` - CREATE TABLE IF NOT EXISTS msc2946_edges ( - room_version TEXT NOT NULL, - -- the room ID of the event, the source of the arrow - source_room_id TEXT NOT NULL, - -- the target room ID, the arrow destination - dest_room_id TEXT NOT NULL, - -- the kind of relation, either child or parent (1,2) - rel_type SMALLINT NOT NULL, - event_json TEXT NOT NULL, - UNIQUE (source_room_id, dest_room_id, rel_type) - ); - `) - if err != nil { - return nil, err - } - if d.insertEdgeStmt, err = d.db.Prepare(` - INSERT INTO msc2946_edges(room_version, source_room_id, dest_room_id, rel_type, event_json) - VALUES($1, $2, $3, $4, $5) - ON CONFLICT (source_room_id, dest_room_id, rel_type) DO UPDATE SET event_json = $5 - `); err != nil { - return nil, err - } - if d.selectEdgesStmt, err = d.db.Prepare(` - SELECT room_version, event_json FROM msc2946_edges - WHERE source_room_id = $1 OR dest_room_id = $2 - `); err != nil { - return nil, err - } - return &d, err -} - -func (d *DB) StoreReference(ctx context.Context, he *gomatrixserverlib.HeaderedEvent) error { - target := SpaceTarget(he) - if target == "" { - return nil // malformed event - } - relType := relTypes[he.Type()] - _, err := d.insertEdgeStmt.ExecContext(ctx, he.RoomVersion, he.RoomID(), target, relType, he.JSON()) - return err -} - -func (d *DB) References(ctx context.Context, roomID string) ([]*gomatrixserverlib.HeaderedEvent, error) { - rows, err := d.selectEdgesStmt.QueryContext(ctx, roomID, roomID) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "failed to close References") - refs := make([]*gomatrixserverlib.HeaderedEvent, 0) - for rows.Next() { - var roomVer string - var jsonBytes []byte - if err := rows.Scan(&roomVer, &jsonBytes); err != nil { - return nil, err - } - ev, err := gomatrixserverlib.NewEventFromTrustedJSON(jsonBytes, false, gomatrixserverlib.RoomVersion(roomVer)) - if err != nil { - return nil, err - } - he := ev.Headered(gomatrixserverlib.RoomVersion(roomVer)) - refs = append(refs, he) - } - return refs, nil -} - -// SpaceTarget returns the destination room ID for the space event. This is either a child or a parent -// depending on the event type. -func SpaceTarget(he *gomatrixserverlib.HeaderedEvent) string { - if he.StateKey() == nil { - return "" // no-op - } - switch he.Type() { - case ConstSpaceParentEventType: - return *he.StateKey() - case ConstSpaceChildEventType: - return *he.StateKey() - } - return "" -} From 1a79060b46d3b9c6dc0083fd7b20a2471f795a90 Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Tue, 1 Mar 2022 14:16:47 +0000 Subject: [PATCH 07/33] Bump GMSL version --- go.mod | 2 +- go.sum | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index e44e7393d..ddf3f2175 100644 --- a/go.mod +++ b/go.mod @@ -40,7 +40,7 @@ require ( github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4 github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 - github.com/matrix-org/gomatrixserverlib v0.0.0-20220301125114-e6012a13a6e6 + github.com/matrix-org/gomatrixserverlib v0.0.0-20220301141554-e124bd7d7902 github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.10 diff --git a/go.sum b/go.sum index 43b363d30..74a181528 100644 --- a/go.sum +++ b/go.sum @@ -985,6 +985,8 @@ github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5d github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= github.com/matrix-org/gomatrixserverlib v0.0.0-20220301125114-e6012a13a6e6 h1:CqbM5tPbF1mV2h1J6N0NSF8et6HvFlwA98TLCUcjiSI= github.com/matrix-org/gomatrixserverlib v0.0.0-20220301125114-e6012a13a6e6/go.mod h1:+WF5InseAMgi1fTnU46JH39IDpEvLep0fDzx9LDf2Bo= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220301141554-e124bd7d7902 h1:WHlrE8BYh/hzn1RKwq3YMAlhHivX47jQKAjZFtkJyPE= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220301141554-e124bd7d7902/go.mod h1:+WF5InseAMgi1fTnU46JH39IDpEvLep0fDzx9LDf2Bo= github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa h1:rMYFNVto66gp+eWS8XAUzgp4m0qmUBid6l1HX3mHstk= github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa/go.mod h1:r6dsL+ylE0yXe/7zh8y/Bdh6aBYI1r+u4yZni9A4iyk= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= From 18e3c40da43017a9e1abe775a406844aca5e1643 Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Tue, 1 Mar 2022 14:22:59 +0000 Subject: [PATCH 08/33] Always send [] from federated rooms, not null --- setup/mscs/msc2946/msc2946.go | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/setup/mscs/msc2946/msc2946.go b/setup/mscs/msc2946/msc2946.go index 3bb56f4b3..bf26e9910 100644 --- a/setup/mscs/msc2946/msc2946.go +++ b/setup/mscs/msc2946/msc2946.go @@ -426,6 +426,18 @@ func (w *walker) federatedRoomInfo(roomID string, vias []string) (*gomatrixserve util.GetLogger(w.ctx).WithError(err).Warnf("failed to call MSC2946Spaces on server %s", serverName) continue } + // ensure nil slices are empty as we send this to the client sometimes + if res.Room.ChildrenState == nil { + res.Room.ChildrenState = []gomatrixserverlib.MSC2946StrippedEvent{} + } + for i := 0; i < len(res.Children); i++ { + child := res.Children[i] + if child.ChildrenState == nil { + child.ChildrenState = []gomatrixserverlib.MSC2946StrippedEvent{} + } + res.Children[i] = child + } + return &res, nil } return nil, nil From 471fda810a24f8945f5da3a76d2730a1ffe80166 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 1 Mar 2022 14:39:06 +0000 Subject: [PATCH 09/33] Remove unnecessary error line (#2237) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Previously this error line would print because we were pulling out all user memberships, but now this is no longer necessary — an event state key that we don't know will no longer get passed to `SelectJoinedUsersSetForRooms` at all. --- roomserver/storage/shared/storage.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index f87782776..dd49f35ca 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -13,7 +13,6 @@ import ( "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" - "github.com/sirupsen/logrus" "github.com/tidwall/gjson" ) @@ -1186,9 +1185,6 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs [ stateKeyNIDs[i] = nid i++ } - if len(nidToUserID) != len(userNIDToCount) { - logrus.Warnf("SelectJoinedUsersSetForRooms found %d users but BulkSelectEventStateKey only returned state key NIDs for %d of them", len(userNIDToCount), len(nidToUserID)) - } result := make(map[string]int, len(userNIDToCount)) for nid, count := range userNIDToCount { result[nidToUserID[nid]] = count From af610df85a6a2145d5ddcd6419bb8085ae224207 Mon Sep 17 00:00:00 2001 From: S7evinK <2353100+S7evinK@users.noreply.github.com> Date: Tue, 1 Mar 2022 15:39:56 +0100 Subject: [PATCH 10/33] Return state on calls to /message and lazy load members (#2218) Co-authored-by: Neil Alexander --- syncapi/routing/context.go | 4 +-- syncapi/routing/context_test.go | 6 ++-- syncapi/routing/messages.go | 61 +++++++++++++++++++-------------- sytest-whitelist | 1 + 4 files changed, 42 insertions(+), 30 deletions(-) diff --git a/syncapi/routing/context.go b/syncapi/routing/context.go index ef7efa2b0..591139717 100644 --- a/syncapi/routing/context.go +++ b/syncapi/routing/context.go @@ -44,7 +44,7 @@ func Context( syncDB storage.Database, roomID, eventID string, ) util.JSONResponse { - filter, err := parseContextParams(req) + filter, err := parseRoomEventFilter(req) if err != nil { errMsg := "" switch err.(type) { @@ -164,7 +164,7 @@ func applyLazyLoadMembers(filter *gomatrixserverlib.RoomEventFilter, eventsAfter return newState } -func parseContextParams(req *http.Request) (*gomatrixserverlib.RoomEventFilter, error) { +func parseRoomEventFilter(req *http.Request) (*gomatrixserverlib.RoomEventFilter, error) { // Default room filter filter := &gomatrixserverlib.RoomEventFilter{Limit: 10} diff --git a/syncapi/routing/context_test.go b/syncapi/routing/context_test.go index 1b430d83a..e79a5d5f1 100644 --- a/syncapi/routing/context_test.go +++ b/syncapi/routing/context_test.go @@ -55,13 +55,13 @@ func Test_parseContextParams(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotFilter, err := parseContextParams(tt.req) + gotFilter, err := parseRoomEventFilter(tt.req) if (err != nil) != tt.wantErr { - t.Errorf("parseContextParams() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("parseRoomEventFilter() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(gotFilter, tt.wantFilter) { - t.Errorf("parseContextParams() gotFilter = %v, want %v", gotFilter, tt.wantFilter) + t.Errorf("parseRoomEventFilter() gotFilter = %v, want %v", gotFilter, tt.wantFilter) } }) } diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go index 9bb8c6d24..7cd54eef0 100644 --- a/syncapi/routing/messages.go +++ b/syncapi/routing/messages.go @@ -19,7 +19,6 @@ import ( "fmt" "net/http" "sort" - "strconv" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/roomserver/api" @@ -45,8 +44,8 @@ type messagesReq struct { fromStream *types.StreamingToken device *userapi.Device wasToProvided bool - limit int backwardOrdering bool + filter *gomatrixserverlib.RoomEventFilter } type messagesResp struct { @@ -54,10 +53,9 @@ type messagesResp struct { StartStream string `json:"start_stream,omitempty"` // NOTSPEC: so clients can hit /messages then immediately /sync with a latest sync token End string `json:"end"` Chunk []gomatrixserverlib.ClientEvent `json:"chunk"` + State []gomatrixserverlib.ClientEvent `json:"state"` } -const defaultMessagesLimit = 10 - // OnIncomingMessagesRequest implements the /messages endpoint from the // client-server API. // See: https://matrix.org/docs/spec/client_server/latest.html#get-matrix-client-r0-rooms-roomid-messages @@ -83,6 +81,14 @@ func OnIncomingMessagesRequest( } } + filter, err := parseRoomEventFilter(req) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidArgumentValue("unable to parse filter"), + } + } + // Extract parameters from the request's URL. // Pagination tokens. var fromStream *types.StreamingToken @@ -143,18 +149,6 @@ func OnIncomingMessagesRequest( wasToProvided = false } - // Maximum number of events to return; defaults to 10. - limit := defaultMessagesLimit - if len(req.URL.Query().Get("limit")) > 0 { - limit, err = strconv.Atoi(req.URL.Query().Get("limit")) - - if err != nil { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.InvalidArgumentValue("limit could not be parsed into an integer: " + err.Error()), - } - } - } // TODO: Implement filtering (#587) // Check the room ID's format. @@ -176,7 +170,7 @@ func OnIncomingMessagesRequest( to: &to, fromStream: fromStream, wasToProvided: wasToProvided, - limit: limit, + filter: filter, backwardOrdering: backwardOrdering, device: device, } @@ -187,10 +181,27 @@ func OnIncomingMessagesRequest( return jsonerror.InternalServerError() } + // at least fetch the membership events for the users returned in chunk if LazyLoadMembers is set + state := []gomatrixserverlib.ClientEvent{} + if filter.LazyLoadMembers { + memberShipToUser := make(map[string]*gomatrixserverlib.HeaderedEvent) + for _, evt := range clientEvents { + memberShip, err := db.GetStateEvent(req.Context(), roomID, gomatrixserverlib.MRoomMember, evt.Sender) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("failed to get membership event for user") + continue + } + memberShipToUser[evt.Sender] = memberShip + } + for _, evt := range memberShipToUser { + state = append(state, gomatrixserverlib.HeaderedToClientEvent(evt, gomatrixserverlib.FormatAll)) + } + } + util.GetLogger(req.Context()).WithFields(logrus.Fields{ "from": from.String(), "to": to.String(), - "limit": limit, + "limit": filter.Limit, "backwards": backwardOrdering, "return_start": start.String(), "return_end": end.String(), @@ -200,6 +211,7 @@ func OnIncomingMessagesRequest( Chunk: clientEvents, Start: start.String(), End: end.String(), + State: state, } if emptyFromSupplied { res.StartStream = fromStream.String() @@ -234,19 +246,18 @@ func (r *messagesReq) retrieveEvents() ( clientEvents []gomatrixserverlib.ClientEvent, start, end types.TopologyToken, err error, ) { - eventFilter := gomatrixserverlib.DefaultRoomEventFilter() - eventFilter.Limit = r.limit + eventFilter := r.filter // Retrieve the events from the local database. var streamEvents []types.StreamEvent if r.fromStream != nil { toStream := r.to.StreamToken() streamEvents, err = r.db.GetEventsInStreamingRange( - r.ctx, r.fromStream, &toStream, r.roomID, &eventFilter, r.backwardOrdering, + r.ctx, r.fromStream, &toStream, r.roomID, eventFilter, r.backwardOrdering, ) } else { streamEvents, err = r.db.GetEventsInTopologicalRange( - r.ctx, r.from, r.to, r.roomID, r.limit, r.backwardOrdering, + r.ctx, r.from, r.to, r.roomID, eventFilter.Limit, r.backwardOrdering, ) } if err != nil { @@ -434,7 +445,7 @@ func (r *messagesReq) handleEmptyEventsSlice() ( // Check if we have backward extremities for this room. if len(backwardExtremities) > 0 { // If so, retrieve as much events as needed through backfilling. - events, err = r.backfill(r.roomID, backwardExtremities, r.limit) + events, err = r.backfill(r.roomID, backwardExtremities, r.filter.Limit) if err != nil { return } @@ -456,7 +467,7 @@ func (r *messagesReq) handleNonEmptyEventsSlice(streamEvents []types.StreamEvent events []*gomatrixserverlib.HeaderedEvent, err error, ) { // Check if we have enough events. - isSetLargeEnough := len(streamEvents) >= r.limit + isSetLargeEnough := len(streamEvents) >= r.filter.Limit if !isSetLargeEnough { // it might be fine we don't have up to 'limit' events, let's find out if r.backwardOrdering { @@ -483,7 +494,7 @@ func (r *messagesReq) handleNonEmptyEventsSlice(streamEvents []types.StreamEvent if len(backwardExtremities) > 0 && !isSetLargeEnough && r.backwardOrdering { var pdus []*gomatrixserverlib.HeaderedEvent // Only ask the remote server for enough events to reach the limit. - pdus, err = r.backfill(r.roomID, backwardExtremities, r.limit-len(streamEvents)) + pdus, err = r.backfill(r.roomID, backwardExtremities, r.filter.Limit-len(streamEvents)) if err != nil { return } diff --git a/sytest-whitelist b/sytest-whitelist index 12522cfb3..2dd56d260 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -597,6 +597,7 @@ Device list doesn't change if remote server is down /context/ on non world readable room does not work /context/ returns correct number of events /context/ with lazy_load_members filter works +GET /rooms/:room_id/messages lazy loads members correctly Can query remote device keys using POST after notification Device deletion propagates over federation Get left notifs in sync and /keys/changes when other user leaves From 8dfc958dddc5f7125633f9b5cf55371ab35ceb6c Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Tue, 1 Mar 2022 14:40:47 +0000 Subject: [PATCH 11/33] Also don't send null back when the target room isn't a space room --- setup/mscs/msc2946/msc2946.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup/mscs/msc2946/msc2946.go b/setup/mscs/msc2946/msc2946.go index bf26e9910..f53e7b249 100644 --- a/setup/mscs/msc2946/msc2946.go +++ b/setup/mscs/msc2946/msc2946.go @@ -572,7 +572,7 @@ func (w *walker) childReferences(roomID string) ([]gomatrixserverlib.MSC2946Stri // escape the `.`s so gjson doesn't think it's nested roomType := gjson.GetBytes(res.StateEvents[createTuple].Content(), strings.ReplaceAll(ConstCreateEventContentKey, ".", `\.`)).Str if roomType != ConstCreateEventContentValueSpace { - return nil, nil + return []gomatrixserverlib.MSC2946StrippedEvent{}, nil } } delete(res.StateEvents, createTuple) From ae840590b643bd74786f3e8a869a8ea9e769995b Mon Sep 17 00:00:00 2001 From: kegsay Date: Tue, 1 Mar 2022 16:03:54 +0000 Subject: [PATCH 12/33] Make complement go fast (#2240) --- .github/workflows/tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 124940f71..4a1720295 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -63,7 +63,7 @@ jobs: # Run Complement - run: | set -o pipefail && - go test -v -p 1 -json -tags dendrite_blacklist ./tests/... 2>&1 | gotestfmt + go test -v -json -tags dendrite_blacklist ./tests/... 2>&1 | gotestfmt shell: bash name: Run Complement Tests env: From 352e63915f110cbe4907349a7e59f43f179657e6 Mon Sep 17 00:00:00 2001 From: kegsay Date: Tue, 1 Mar 2022 16:32:48 +0000 Subject: [PATCH 13/33] msc2946: add federation cache (#2238) --- internal/caching/cache_space_rooms.go | 28 +++++++++++++++++++++++++ internal/caching/caches.go | 1 + internal/caching/impl_inmemorylru.go | 12 ++++++++++- setup/mscs/msc2946/msc2946.go | 30 ++++++++++++++++----------- setup/mscs/mscs.go | 2 +- 5 files changed, 59 insertions(+), 14 deletions(-) create mode 100644 internal/caching/cache_space_rooms.go diff --git a/internal/caching/cache_space_rooms.go b/internal/caching/cache_space_rooms.go new file mode 100644 index 000000000..71b81abcf --- /dev/null +++ b/internal/caching/cache_space_rooms.go @@ -0,0 +1,28 @@ +package caching + +import "github.com/matrix-org/gomatrixserverlib" + +const ( + SpaceSummaryRoomsCacheName = "space_summary_rooms" + SpaceSummaryRoomsCacheMaxEntries = 100 + SpaceSummaryRoomsCacheMutable = true +) + +type SpaceSummaryRoomsCache interface { + GetSpaceSummary(roomID string) (r gomatrixserverlib.MSC2946SpacesResponse, ok bool) + StoreSpaceSummary(roomID string, r gomatrixserverlib.MSC2946SpacesResponse) +} + +func (c Caches) GetSpaceSummary(roomID string) (r gomatrixserverlib.MSC2946SpacesResponse, ok bool) { + val, found := c.SpaceSummaryRooms.Get(roomID) + if found && val != nil { + if resp, ok := val.(gomatrixserverlib.MSC2946SpacesResponse); ok { + return resp, true + } + } + return r, false +} + +func (c Caches) StoreSpaceSummary(roomID string, r gomatrixserverlib.MSC2946SpacesResponse) { + c.SpaceSummaryRooms.Set(roomID, r) +} diff --git a/internal/caching/caches.go b/internal/caching/caches.go index e1642a663..39926d735 100644 --- a/internal/caching/caches.go +++ b/internal/caching/caches.go @@ -10,6 +10,7 @@ type Caches struct { RoomServerRoomIDs Cache // RoomServerNIDsCache RoomInfos Cache // RoomInfoCache FederationEvents Cache // FederationEventsCache + SpaceSummaryRooms Cache // SpaceSummaryRoomsCache } // Cache is the interface that an implementation must satisfy. diff --git a/internal/caching/impl_inmemorylru.go b/internal/caching/impl_inmemorylru.go index ccb92852b..7b21f1362 100644 --- a/internal/caching/impl_inmemorylru.go +++ b/internal/caching/impl_inmemorylru.go @@ -55,9 +55,18 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) { if err != nil { return nil, err } + spaceRooms, err := NewInMemoryLRUCachePartition( + SpaceSummaryRoomsCacheName, + SpaceSummaryRoomsCacheMutable, + SpaceSummaryRoomsCacheMaxEntries, + enablePrometheus, + ) + if err != nil { + return nil, err + } go cacheCleaner( roomVersions, serverKeys, roomServerRoomIDs, - roomInfos, federationEvents, + roomInfos, federationEvents, spaceRooms, ) return &Caches{ RoomVersions: roomVersions, @@ -65,6 +74,7 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) { RoomServerRoomIDs: roomServerRoomIDs, RoomInfos: roomInfos, FederationEvents: federationEvents, + SpaceSummaryRooms: spaceRooms, }, nil } diff --git a/setup/mscs/msc2946/msc2946.go b/setup/mscs/msc2946/msc2946.go index f53e7b249..7ab50c32e 100644 --- a/setup/mscs/msc2946/msc2946.go +++ b/setup/mscs/msc2946/msc2946.go @@ -30,6 +30,7 @@ import ( "github.com/gorilla/mux" "github.com/matrix-org/dendrite/clientapi/jsonerror" fs "github.com/matrix-org/dendrite/federationapi/api" + "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/httputil" roomserver "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/base" @@ -54,9 +55,9 @@ type MSC2946ClientResponse struct { // Enable this MSC func Enable( base *base.BaseDendrite, rsAPI roomserver.RoomserverInternalAPI, userAPI userapi.UserInternalAPI, - fsAPI fs.FederationInternalAPI, keyRing gomatrixserverlib.JSONVerifier, + fsAPI fs.FederationInternalAPI, keyRing gomatrixserverlib.JSONVerifier, cache caching.SpaceSummaryRoomsCache, ) error { - clientAPI := httputil.MakeAuthAPI("spaces", userAPI, spacesHandler(rsAPI, fsAPI, base.Cfg.Global.ServerName)) + clientAPI := httputil.MakeAuthAPI("spaces", userAPI, spacesHandler(rsAPI, fsAPI, cache, base.Cfg.Global.ServerName)) base.PublicClientAPIMux.Handle("/v1/rooms/{roomID}/hierarchy", clientAPI).Methods(http.MethodGet, http.MethodOptions) base.PublicClientAPIMux.Handle("/unstable/org.matrix.msc2946/rooms/{roomID}/hierarchy", clientAPI).Methods(http.MethodGet, http.MethodOptions) @@ -74,7 +75,7 @@ func Enable( return util.ErrorResponse(err) } roomID := params["roomID"] - return federatedSpacesHandler(req.Context(), fedReq, roomID, rsAPI, fsAPI, base.Cfg.Global.ServerName) + return federatedSpacesHandler(req.Context(), fedReq, roomID, cache, rsAPI, fsAPI, base.Cfg.Global.ServerName) }, ) base.PublicFederationAPIMux.Handle("/unstable/org.matrix.msc2946/hierarchy/{roomID}", fedAPI).Methods(http.MethodGet) @@ -84,6 +85,7 @@ func Enable( func federatedSpacesHandler( ctx context.Context, fedReq *gomatrixserverlib.FederationRequest, roomID string, + cache caching.SpaceSummaryRoomsCache, rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationInternalAPI, thisServer gomatrixserverlib.ServerName, ) util.JSONResponse { @@ -100,6 +102,7 @@ func federatedSpacesHandler( serverName: fedReq.Origin(), thisServer: thisServer, ctx: ctx, + cache: cache, suggestedOnly: u.Query().Get("suggested_only") == "true", limit: 1000, // The main difference is that it does not recurse into spaces and does not support pagination. @@ -115,7 +118,9 @@ func federatedSpacesHandler( } func spacesHandler( - rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationInternalAPI, + rsAPI roomserver.RoomserverInternalAPI, + fsAPI fs.FederationInternalAPI, + cache caching.SpaceSummaryRoomsCache, thisServer gomatrixserverlib.ServerName, ) func(*http.Request, *userapi.Device) util.JSONResponse { // declared outside the returned handler so it persists between calls @@ -138,6 +143,7 @@ func spacesHandler( caller: device, thisServer: thisServer, ctx: req.Context(), + cache: cache, rsAPI: rsAPI, fsAPI: fsAPI, @@ -160,6 +166,7 @@ type walker struct { rsAPI roomserver.RoomserverInternalAPI fsAPI fs.FederationInternalAPI ctx context.Context + cache caching.SpaceSummaryRoomsCache suggestedOnly bool limit int maxDepth int @@ -169,13 +176,6 @@ type walker struct { mu sync.Mutex } -func (w *walker) callerID() string { - if w.caller != nil { - return w.caller.UserID + "|" + w.caller.ID - } - return string(w.serverName) -} - func (w *walker) newPaginationCache() (string, paginationInfo) { p := paginationInfo{ processed: make(set), @@ -414,7 +414,12 @@ func (w *walker) federatedRoomInfo(roomID string, vias []string) (*gomatrixserve if w.caller == nil { return nil, nil } - util.GetLogger(w.ctx).Infof("Querying %s via %+v", roomID, vias) + resp, ok := w.cache.GetSpaceSummary(roomID) + if ok { + util.GetLogger(w.ctx).Debugf("Returning cached response for %s", roomID) + return &resp, nil + } + util.GetLogger(w.ctx).Debugf("Querying %s via %+v", roomID, vias) ctx := context.Background() // query more of the spaces graph using these servers for _, serverName := range vias { @@ -437,6 +442,7 @@ func (w *walker) federatedRoomInfo(roomID string, vias []string) (*gomatrixserve } res.Children[i] = child } + w.cache.StoreSpaceSummary(roomID, res) return &res, nil } diff --git a/setup/mscs/mscs.go b/setup/mscs/mscs.go index 10df6a15d..35b7bba3b 100644 --- a/setup/mscs/mscs.go +++ b/setup/mscs/mscs.go @@ -42,7 +42,7 @@ func EnableMSC(base *base.BaseDendrite, monolith *setup.Monolith, msc string) er case "msc2836": return msc2836.Enable(base, monolith.RoomserverAPI, monolith.FederationAPI, monolith.UserAPI, monolith.KeyRing) case "msc2946": - return msc2946.Enable(base, monolith.RoomserverAPI, monolith.UserAPI, monolith.FederationAPI, monolith.KeyRing) + return msc2946.Enable(base, monolith.RoomserverAPI, monolith.UserAPI, monolith.FederationAPI, monolith.KeyRing, base.Caches) case "msc2444": // enabled inside federationapi case "msc2753": // enabled inside clientapi default: From cda2452ba00afffa9a73870ca09047ce52dd28c7 Mon Sep 17 00:00:00 2001 From: S7evinK <2353100+S7evinK@users.noreply.github.com> Date: Tue, 1 Mar 2022 17:39:57 +0100 Subject: [PATCH 14/33] Only allow device deletion from session UIA was initiated from (#2235) * Only allow device deletion if the session matches * Make the challenge response available to other packages * Remove userID, as it's not in the spec * Remove tests * Add passing test & remove obsolete config * Rename field, add comment Co-authored-by: Neil Alexander --- clientapi/auth/user_interactive.go | 24 ++++++++++++---------- clientapi/routing/device.go | 33 ++++++++++++++++++++++++++++++ clientapi/routing/register.go | 26 ++++++++++++++++++++--- clientapi/routing/register_test.go | 10 +++++++++ dendrite-config.yaml | 5 ----- sytest-whitelist | 2 ++ 6 files changed, 81 insertions(+), 19 deletions(-) diff --git a/clientapi/auth/user_interactive.go b/clientapi/auth/user_interactive.go index 9cab7956c..4db75809f 100644 --- a/clientapi/auth/user_interactive.go +++ b/clientapi/auth/user_interactive.go @@ -144,21 +144,23 @@ func (u *UserInteractive) AddCompletedStage(sessionID, authType string) { delete(u.Sessions, sessionID) } +type Challenge struct { + Completed []string `json:"completed"` + Flows []userInteractiveFlow `json:"flows"` + Session string `json:"session"` + // TODO: Return any additional `params` + Params map[string]interface{} `json:"params"` +} + // Challenge returns an HTTP 401 with the supported flows for authenticating func (u *UserInteractive) Challenge(sessionID string) *util.JSONResponse { return &util.JSONResponse{ Code: 401, - JSON: struct { - Completed []string `json:"completed"` - Flows []userInteractiveFlow `json:"flows"` - Session string `json:"session"` - // TODO: Return any additional `params` - Params map[string]interface{} `json:"params"` - }{ - u.Completed, - u.Flows, - sessionID, - make(map[string]interface{}), + JSON: Challenge{ + Completed: u.Completed, + Flows: u.Flows, + Session: sessionID, + Params: make(map[string]interface{}), }, } } diff --git a/clientapi/routing/device.go b/clientapi/routing/device.go index 9f54a625a..161bc2731 100644 --- a/clientapi/routing/device.go +++ b/clientapi/routing/device.go @@ -25,6 +25,7 @@ import ( "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" + "github.com/tidwall/gjson" ) // https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-devices @@ -163,6 +164,15 @@ func DeleteDeviceById( req *http.Request, userInteractiveAuth *auth.UserInteractive, userAPI api.UserInternalAPI, device *api.Device, deviceID string, ) util.JSONResponse { + var ( + deleteOK bool + sessionID string + ) + defer func() { + if deleteOK { + sessions.deleteSession(sessionID) + } + }() ctx := req.Context() defer req.Body.Close() // nolint:errcheck bodyBytes, err := ioutil.ReadAll(req.Body) @@ -172,8 +182,29 @@ func DeleteDeviceById( JSON: jsonerror.BadJSON("The request body could not be read: " + err.Error()), } } + + // check that we know this session, and it matches with the device to delete + s := gjson.GetBytes(bodyBytes, "auth.session").Str + if dev, ok := sessions.getDeviceToDelete(s); ok { + if dev != deviceID { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("session & device mismatch"), + } + } + } + + if s != "" { + sessionID = s + } + login, errRes := userInteractiveAuth.Verify(ctx, bodyBytes, device) if errRes != nil { + switch data := errRes.JSON.(type) { + case auth.Challenge: + sessions.addDeviceToDelete(data.Session, deviceID) + default: + } return *errRes } @@ -201,6 +232,8 @@ func DeleteDeviceById( return jsonerror.InternalServerError() } + deleteOK = true + return util.JSONResponse{ Code: http.StatusOK, JSON: struct{}{}, diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index 10cfa4325..c97876594 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -76,6 +76,10 @@ type sessionsDict struct { sessions map[string][]authtypes.LoginType params map[string]registerRequest timer map[string]*time.Timer + // deleteSessionToDeviceID protects requests to DELETE /devices/{deviceID} from being abused. + // If a UIA session is started by trying to delete device1, and then UIA is completed by deleting device2, + // the delete request will fail for device2 since the UIA was initiated by trying to delete device1. + deleteSessionToDeviceID map[string]string } // defaultTimeout is the timeout used to clean up sessions @@ -115,6 +119,7 @@ func (d *sessionsDict) deleteSession(sessionID string) { defer d.Unlock() delete(d.params, sessionID) delete(d.sessions, sessionID) + delete(d.deleteSessionToDeviceID, sessionID) // stop the timer, e.g. because the registration was completed if t, ok := d.timer[sessionID]; ok { if !t.Stop() { @@ -129,9 +134,10 @@ func (d *sessionsDict) deleteSession(sessionID string) { func newSessionsDict() *sessionsDict { return &sessionsDict{ - sessions: make(map[string][]authtypes.LoginType), - params: make(map[string]registerRequest), - timer: make(map[string]*time.Timer), + sessions: make(map[string][]authtypes.LoginType), + params: make(map[string]registerRequest), + timer: make(map[string]*time.Timer), + deleteSessionToDeviceID: make(map[string]string), } } @@ -165,6 +171,20 @@ func (d *sessionsDict) addCompletedSessionStage(sessionID string, stage authtype d.sessions[sessionID] = append(sessions.sessions[sessionID], stage) } +func (d *sessionsDict) addDeviceToDelete(sessionID, deviceID string) { + d.startTimer(defaultTimeOut, sessionID) + d.Lock() + defer d.Unlock() + d.deleteSessionToDeviceID[sessionID] = deviceID +} + +func (d *sessionsDict) getDeviceToDelete(sessionID string) (string, bool) { + d.RLock() + defer d.RUnlock() + deviceID, ok := d.deleteSessionToDeviceID[sessionID] + return deviceID, ok +} + var ( sessions = newSessionsDict() validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`) diff --git a/clientapi/routing/register_test.go b/clientapi/routing/register_test.go index c6b7e61cf..520f5ddeb 100644 --- a/clientapi/routing/register_test.go +++ b/clientapi/routing/register_test.go @@ -242,6 +242,7 @@ func TestSessionCleanUp(t *testing.T) { s.addParams(dummySession, registerRequest{Username: "Testing"}) s.addCompletedSessionStage(dummySession, authtypes.LoginTypeRecaptcha) s.addCompletedSessionStage(dummySession, authtypes.LoginTypeDummy) + s.addDeviceToDelete(dummySession, "dummyDevice") s.getCompletedStages(dummySession) // reset the timer with a lower timeout s.startTimer(time.Millisecond, dummySession) @@ -249,5 +250,14 @@ func TestSessionCleanUp(t *testing.T) { if data, ok := s.getParams(dummySession); ok { t.Errorf("expected session to be deleted: %+v", data) } + if _, ok := s.timer[dummySession]; ok { + t.Error("expected timer to be delete") + } + if _, ok := s.sessions[dummySession]; ok { + t.Error("expected session to be delete") + } + if _, ok := s.getDeviceToDelete(dummySession); ok { + t.Error("expected session to device to be delete") + } }) } \ No newline at end of file diff --git a/dendrite-config.yaml b/dendrite-config.yaml index 6d086ed77..533b5c952 100644 --- a/dendrite-config.yaml +++ b/dendrite-config.yaml @@ -345,11 +345,6 @@ user_api: max_open_conns: 10 max_idle_conns: 2 conn_max_lifetime: -1 - device_database: - connection_string: file:userapi_devices.db - max_open_conns: 10 - max_idle_conns: 2 - conn_max_lifetime: -1 # The length of time that a token issued for a relying party from # /_matrix/client/r0/user/{userId}/openid/request_token endpoint # is considered to be valid in milliseconds. diff --git a/sytest-whitelist b/sytest-whitelist index 2dd56d260..0cb80b7b6 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -605,4 +605,6 @@ Remote banned user is kicked and may not rejoin until unbanned registration remembers parameters registration accepts non-ascii passwords registration with inhibit_login inhibits login +The operation must be consistent through an interactive authentication session +Multiple calls to /sync should not cause 500 errors From 726529fe996519c93f4f329c03a968a432b0bb0e Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 1 Mar 2022 16:59:11 +0000 Subject: [PATCH 15/33] Hopefully fix read receipts (#2241) --- syncapi/storage/postgres/receipt_table.go | 2 +- syncapi/storage/sqlite3/receipt_table.go | 2 +- syncapi/streams/stream_receipt.go | 5 ++--- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/syncapi/storage/postgres/receipt_table.go b/syncapi/storage/postgres/receipt_table.go index f93081e1a..474d0c020 100644 --- a/syncapi/storage/postgres/receipt_table.go +++ b/syncapi/storage/postgres/receipt_table.go @@ -96,7 +96,7 @@ func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, room } func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []api.OutputReceiptEvent, error) { - lastPos := streamPos + var lastPos types.StreamPosition rows, err := r.selectRoomReceipts.QueryContext(ctx, pq.Array(roomIDs), streamPos) if err != nil { return 0, nil, fmt.Errorf("unable to query room receipts: %w", err) diff --git a/syncapi/storage/sqlite3/receipt_table.go b/syncapi/storage/sqlite3/receipt_table.go index 6b39ee879..9111a39f6 100644 --- a/syncapi/storage/sqlite3/receipt_table.go +++ b/syncapi/storage/sqlite3/receipt_table.go @@ -101,7 +101,7 @@ func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, room // SelectRoomReceiptsAfter select all receipts for a given room after a specific timestamp func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []api.OutputReceiptEvent, error) { selectSQL := strings.Replace(selectRoomReceipts, "($2)", sqlutil.QueryVariadicOffset(len(roomIDs), 1), 1) - lastPos := streamPos + var lastPos types.StreamPosition params := make([]interface{}, len(roomIDs)+1) params[0] = streamPos for k, v := range roomIDs { diff --git a/syncapi/streams/stream_receipt.go b/syncapi/streams/stream_receipt.go index cccadb525..35ffd3a1e 100644 --- a/syncapi/streams/stream_receipt.go +++ b/syncapi/streams/stream_receipt.go @@ -63,7 +63,6 @@ func (p *ReceiptStreamProvider) IncrementalSync( if existing, ok := req.Response.Rooms.Join[roomID]; ok { jr = existing } - var ok bool ev := gomatrixserverlib.ClientEvent{ Type: gomatrixserverlib.MReceipt, @@ -71,8 +70,8 @@ func (p *ReceiptStreamProvider) IncrementalSync( } content := make(map[string]eduAPI.ReceiptMRead) for _, receipt := range receipts { - var read eduAPI.ReceiptMRead - if read, ok = content[receipt.EventID]; !ok { + read, ok := content[receipt.EventID] + if !ok { read = eduAPI.ReceiptMRead{ User: make(map[string]eduAPI.ReceiptTS), } From bb2380c254b65a6586137e9e0ab9e08354aa19f4 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 1 Mar 2022 16:59:52 +0000 Subject: [PATCH 16/33] Allow specifying max age for caches (#2239) * Allow specifying max age for caches * Evict cache entry if it's found to be stale when we call Get * Fix bugs --- internal/caching/cache_federationevents.go | 1 + internal/caching/cache_roominfo.go | 3 ++ internal/caching/cache_roomservernids.go | 1 + internal/caching/cache_roomversions.go | 1 + internal/caching/cache_serverkeys.go | 1 + internal/caching/caches.go | 4 +++ internal/caching/impl_inmemorylru.go | 42 +++++++++++++++++++--- 7 files changed, 48 insertions(+), 5 deletions(-) diff --git a/internal/caching/cache_federationevents.go b/internal/caching/cache_federationevents.go index d10b333a6..b79cc809f 100644 --- a/internal/caching/cache_federationevents.go +++ b/internal/caching/cache_federationevents.go @@ -10,6 +10,7 @@ const ( FederationEventCacheName = "federation_event" FederationEventCacheMaxEntries = 256 FederationEventCacheMutable = true // to allow use of Unset only + FederationEventCacheMaxAge = CacheNoMaxAge ) // FederationCache contains the subset of functions needed for diff --git a/internal/caching/cache_roominfo.go b/internal/caching/cache_roominfo.go index f32d6ba9b..60d221285 100644 --- a/internal/caching/cache_roominfo.go +++ b/internal/caching/cache_roominfo.go @@ -1,6 +1,8 @@ package caching import ( + "time" + "github.com/matrix-org/dendrite/roomserver/types" ) @@ -16,6 +18,7 @@ const ( RoomInfoCacheName = "roominfo" RoomInfoCacheMaxEntries = 1024 RoomInfoCacheMutable = true + RoomInfoCacheMaxAge = time.Minute * 5 ) // RoomInfosCache contains the subset of functions needed for diff --git a/internal/caching/cache_roomservernids.go b/internal/caching/cache_roomservernids.go index 6d413093f..1918a2f1e 100644 --- a/internal/caching/cache_roomservernids.go +++ b/internal/caching/cache_roomservernids.go @@ -10,6 +10,7 @@ const ( RoomServerRoomIDsCacheName = "roomserver_room_ids" RoomServerRoomIDsCacheMaxEntries = 1024 RoomServerRoomIDsCacheMutable = false + RoomServerRoomIDsCacheMaxAge = CacheNoMaxAge ) type RoomServerCaches interface { diff --git a/internal/caching/cache_roomversions.go b/internal/caching/cache_roomversions.go index 0b46d3d4b..92d2eab08 100644 --- a/internal/caching/cache_roomversions.go +++ b/internal/caching/cache_roomversions.go @@ -6,6 +6,7 @@ const ( RoomVersionCacheName = "room_versions" RoomVersionCacheMaxEntries = 1024 RoomVersionCacheMutable = false + RoomVersionCacheMaxAge = CacheNoMaxAge ) // RoomVersionsCache contains the subset of functions needed for diff --git a/internal/caching/cache_serverkeys.go b/internal/caching/cache_serverkeys.go index 4697fb4d2..4eb10fe6f 100644 --- a/internal/caching/cache_serverkeys.go +++ b/internal/caching/cache_serverkeys.go @@ -10,6 +10,7 @@ const ( ServerKeyCacheName = "server_key" ServerKeyCacheMaxEntries = 4096 ServerKeyCacheMutable = true + ServerKeyCacheMaxAge = CacheNoMaxAge ) // ServerKeyCache contains the subset of functions needed for diff --git a/internal/caching/caches.go b/internal/caching/caches.go index 39926d735..722405de6 100644 --- a/internal/caching/caches.go +++ b/internal/caching/caches.go @@ -1,5 +1,7 @@ package caching +import "time" + // Caches contains a set of references to caches. They may be // different implementations as long as they satisfy the Cache // interface. @@ -19,3 +21,5 @@ type Cache interface { Set(key string, value interface{}) Unset(key string) } + +const CacheNoMaxAge = time.Duration(0) diff --git a/internal/caching/impl_inmemorylru.go b/internal/caching/impl_inmemorylru.go index 7b21f1362..c9c8fd08d 100644 --- a/internal/caching/impl_inmemorylru.go +++ b/internal/caching/impl_inmemorylru.go @@ -14,6 +14,7 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) { RoomVersionCacheName, RoomVersionCacheMutable, RoomVersionCacheMaxEntries, + RoomVersionCacheMaxAge, enablePrometheus, ) if err != nil { @@ -23,6 +24,7 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) { ServerKeyCacheName, ServerKeyCacheMutable, ServerKeyCacheMaxEntries, + ServerKeyCacheMaxAge, enablePrometheus, ) if err != nil { @@ -32,6 +34,7 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) { RoomServerRoomIDsCacheName, RoomServerRoomIDsCacheMutable, RoomServerRoomIDsCacheMaxEntries, + RoomServerRoomIDsCacheMaxAge, enablePrometheus, ) if err != nil { @@ -41,6 +44,7 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) { RoomInfoCacheName, RoomInfoCacheMutable, RoomInfoCacheMaxEntries, + RoomInfoCacheMaxAge, enablePrometheus, ) if err != nil { @@ -50,6 +54,7 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) { FederationEventCacheName, FederationEventCacheMutable, FederationEventCacheMaxEntries, + FederationEventCacheMaxAge, enablePrometheus, ) if err != nil { @@ -96,15 +101,22 @@ type InMemoryLRUCachePartition struct { name string mutable bool maxEntries int + maxAge time.Duration lru *lru.Cache } -func NewInMemoryLRUCachePartition(name string, mutable bool, maxEntries int, enablePrometheus bool) (*InMemoryLRUCachePartition, error) { +type inMemoryLRUCacheEntry struct { + value interface{} + created time.Time +} + +func NewInMemoryLRUCachePartition(name string, mutable bool, maxEntries int, maxAge time.Duration, enablePrometheus bool) (*InMemoryLRUCachePartition, error) { var err error cache := InMemoryLRUCachePartition{ name: name, mutable: mutable, maxEntries: maxEntries, + maxAge: maxAge, } cache.lru, err = lru.New(maxEntries) if err != nil { @@ -124,11 +136,16 @@ func NewInMemoryLRUCachePartition(name string, mutable bool, maxEntries int, ena func (c *InMemoryLRUCachePartition) Set(key string, value interface{}) { if !c.mutable { - if peek, ok := c.lru.Peek(key); ok && peek != value { - panic(fmt.Sprintf("invalid use of immutable cache tries to mutate existing value of %q", key)) + if peek, ok := c.lru.Peek(key); ok { + if entry, ok := peek.(*inMemoryLRUCacheEntry); ok && entry.value != value { + panic(fmt.Sprintf("invalid use of immutable cache tries to mutate existing value of %q", key)) + } } } - c.lru.Add(key, value) + c.lru.Add(key, &inMemoryLRUCacheEntry{ + value: value, + created: time.Now(), + }) } func (c *InMemoryLRUCachePartition) Unset(key string) { @@ -139,5 +156,20 @@ func (c *InMemoryLRUCachePartition) Unset(key string) { } func (c *InMemoryLRUCachePartition) Get(key string) (value interface{}, ok bool) { - return c.lru.Get(key) + v, ok := c.lru.Get(key) + if !ok { + return nil, false + } + entry, ok := v.(*inMemoryLRUCacheEntry) + switch { + case ok && c.maxAge == CacheNoMaxAge: + return entry.value, ok // There's no maximum age policy + case ok && time.Since(entry.created) < c.maxAge: + return entry.value, ok // The value for the key isn't stale + default: + // Either the key was found and it was stale, or the key + // wasn't found at all + c.lru.Remove(key) + return nil, false + } } From 8e82739d77ea472932fd2ff7800ce744ace53bfa Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 1 Mar 2022 17:01:08 +0000 Subject: [PATCH 17/33] Set max age of 5 minutes for spaces summary cache --- internal/caching/cache_space_rooms.go | 7 ++++++- internal/caching/impl_inmemorylru.go | 1 + 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/internal/caching/cache_space_rooms.go b/internal/caching/cache_space_rooms.go index 71b81abcf..6d56cce5f 100644 --- a/internal/caching/cache_space_rooms.go +++ b/internal/caching/cache_space_rooms.go @@ -1,11 +1,16 @@ package caching -import "github.com/matrix-org/gomatrixserverlib" +import ( + "time" + + "github.com/matrix-org/gomatrixserverlib" +) const ( SpaceSummaryRoomsCacheName = "space_summary_rooms" SpaceSummaryRoomsCacheMaxEntries = 100 SpaceSummaryRoomsCacheMutable = true + SpaceSummaryRoomsCacheMaxAge = time.Minute * 5 ) type SpaceSummaryRoomsCache interface { diff --git a/internal/caching/impl_inmemorylru.go b/internal/caching/impl_inmemorylru.go index c9c8fd08d..94fdd1a9b 100644 --- a/internal/caching/impl_inmemorylru.go +++ b/internal/caching/impl_inmemorylru.go @@ -64,6 +64,7 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) { SpaceSummaryRoomsCacheName, SpaceSummaryRoomsCacheMutable, SpaceSummaryRoomsCacheMaxEntries, + SpaceSummaryRoomsCacheMaxAge, enablePrometheus, ) if err != nil { From 23f028cf6e492f9e176b788797f2da2d274c1705 Mon Sep 17 00:00:00 2001 From: kegsay Date: Tue, 1 Mar 2022 17:18:06 +0000 Subject: [PATCH 18/33] Add unit test for device list update debouncing (#2220) * Add unit test for device list update debouncing * bugfix: actually return stale device lists in the test... Co-authored-by: Neil Alexander --- keyserver/internal/device_list_update_test.go | 89 ++++++++++++++++++- 1 file changed, 88 insertions(+), 1 deletion(-) diff --git a/keyserver/internal/device_list_update_test.go b/keyserver/internal/device_list_update_test.go index 164be6bec..e8d1bfe8b 100644 --- a/keyserver/internal/device_list_update_test.go +++ b/keyserver/internal/device_list_update_test.go @@ -48,13 +48,19 @@ type mockDeviceListUpdaterDatabase struct { staleUsers map[string]bool prevIDsExist func(string, []int) bool storedKeys []api.DeviceMessage + mu sync.Mutex // protect staleUsers } // StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. // If no domains are given, all user IDs with stale device lists are returned. func (d *mockDeviceListUpdaterDatabase) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { + d.mu.Lock() + defer d.mu.Unlock() var result []string - for userID := range d.staleUsers { + for userID, isStale := range d.staleUsers { + if !isStale { + continue + } _, remoteServer, err := gomatrixserverlib.SplitID('@', userID) if err != nil { return nil, err @@ -75,6 +81,8 @@ func (d *mockDeviceListUpdaterDatabase) StaleDeviceLists(ctx context.Context, do // MarkDeviceListStale sets the stale bit for this user to isStale. func (d *mockDeviceListUpdaterDatabase) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error { + d.mu.Lock() + defer d.mu.Unlock() d.staleUsers[userID] = isStale return nil } @@ -247,3 +255,82 @@ func TestUpdateNoPrevID(t *testing.T) { } } + +// Test that if we make N calls to ManualUpdate for the same user, we only do it once, assuming the +// update is still ongoing. +func TestDebounce(t *testing.T) { + db := &mockDeviceListUpdaterDatabase{ + staleUsers: make(map[string]bool), + prevIDsExist: func(string, []int) bool { + return true + }, + } + ap := &mockDeviceListUpdaterAPI{} + producer := &mockKeyChangeProducer{} + fedCh := make(chan *http.Response, 1) + srv := gomatrixserverlib.ServerName("example.com") + userID := "@alice:example.com" + keyJSON := `{"user_id":"` + userID + `","device_id":"JLAFKJWSCS","algorithms":["m.olm.v1.curve25519-aes-sha2","m.megolm.v1.aes-sha2"],"keys":{"curve25519:JLAFKJWSCS":"3C5BFWi2Y8MaVvjM8M22DBmh24PmgR0nPvJOIArzgyI","ed25519:JLAFKJWSCS":"lEuiRJBit0IG6nUf5pUzWTUEsRVVe/HJkoKuEww9ULI"},"signatures":{"` + userID + `":{"ed25519:JLAFKJWSCS":"dSO80A01XiigH3uBiDVx/EjzaoycHcjq9lfQX0uWsqxl2giMIiSPR8a4d291W1ihKJL/a+myXS367WT6NAIcBA"}}}` + incomingFedReq := make(chan struct{}) + fedClient := newFedClient(func(req *http.Request) (*http.Response, error) { + if req.URL.Path != "/_matrix/federation/v1/user/devices/"+url.PathEscape(userID) { + return nil, fmt.Errorf("test: invalid path: %s", req.URL.Path) + } + close(incomingFedReq) + return <-fedCh, nil + }) + updater := NewDeviceListUpdater(db, ap, producer, fedClient, 1) + if err := updater.Start(); err != nil { + t.Fatalf("failed to start updater: %s", err) + } + + // hit this 5 times + var wg sync.WaitGroup + wg.Add(5) + for i := 0; i < 5; i++ { + go func() { + defer wg.Done() + if err := updater.ManualUpdate(context.Background(), srv, userID); err != nil { + t.Errorf("ManualUpdate: %s", err) + } + }() + } + + // wait until the updater hits federation + select { + case <-incomingFedReq: + case <-time.After(time.Second): + t.Fatalf("timed out waiting for updater to hit federation") + } + + // user should be marked as stale + if !db.staleUsers[userID] { + t.Errorf("user %s not marked as stale", userID) + } + // now send the response over federation + fedCh <- &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader(` + { + "user_id": "` + userID + `", + "stream_id": 5, + "devices": [ + { + "device_id": "JLAFKJWSCS", + "keys": ` + keyJSON + `, + "device_display_name": "Mobile Phone" + } + ] + } + `)), + } + close(fedCh) + // wait until all 5 ManualUpdates return. If we hit federation again we won't send a response + // and should panic with read on a closed channel + wg.Wait() + + // user is no longer stale now + if db.staleUsers[userID] { + t.Errorf("user %s is marked as stale", userID) + } +} From 849e40d45653d3d42bbddf2d58d6d5f488cebdba Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 1 Mar 2022 17:25:26 +0000 Subject: [PATCH 19/33] Use correct stream provider in `Latest` for `ReceiptPosition` --- syncapi/streams/streams.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/syncapi/streams/streams.go b/syncapi/streams/streams.go index 6b02c75ea..c71095af6 100644 --- a/syncapi/streams/streams.go +++ b/syncapi/streams/streams.go @@ -69,7 +69,7 @@ func (s *Streams) Latest(ctx context.Context) types.StreamingToken { return types.StreamingToken{ PDUPosition: s.PDUStreamProvider.LatestPosition(ctx), TypingPosition: s.TypingStreamProvider.LatestPosition(ctx), - ReceiptPosition: s.PDUStreamProvider.LatestPosition(ctx), + ReceiptPosition: s.ReceiptStreamProvider.LatestPosition(ctx), InvitePosition: s.InviteStreamProvider.LatestPosition(ctx), SendToDevicePosition: s.SendToDeviceStreamProvider.LatestPosition(ctx), AccountDataPosition: s.AccountDataStreamProvider.LatestPosition(ctx), From 00b3545b14ca5e6987be93cabafc51e771f8bcc7 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 2 Mar 2022 10:36:20 +0000 Subject: [PATCH 20/33] Update NATS Server to v2.7.3 --- go.mod | 5 ++--- go.sum | 16 ++++++---------- 2 files changed, 8 insertions(+), 13 deletions(-) diff --git a/go.mod b/go.mod index ddf3f2175..dbcae5d54 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/matrix-org/dendrite -replace github.com/nats-io/nats-server/v2 => github.com/neilalexander/nats-server/v2 v2.7.2-0.20220217100407-087330ed46ad +replace github.com/nats-io/nats-server/v2 => github.com/neilalexander/nats-server/v2 v2.7.4-0.20220302103432-6b04b9f12740 replace github.com/nats-io/nats.go => github.com/neilalexander/nats.go v1.11.1-0.20220104162523-f4ddebe1061c @@ -24,7 +24,6 @@ require ( github.com/h2non/filetype v1.1.3 // indirect github.com/hashicorp/golang-lru v0.5.4 github.com/juju/testing v0.0.0-20211215003918-77eb13d6cad2 // indirect - github.com/klauspost/compress v1.14.2 // indirect github.com/lib/pq v1.10.4 github.com/libp2p/go-libp2p v0.13.0 github.com/libp2p/go-libp2p-circuit v0.4.0 @@ -45,7 +44,7 @@ require ( github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.10 github.com/morikuni/aec v1.0.0 // indirect - github.com/nats-io/nats-server/v2 v2.3.2 + github.com/nats-io/nats-server/v2 v2.7.3 github.com/nats-io/nats.go v1.13.1-0.20220121202836-972a071d373d github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9 github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 diff --git a/go.sum b/go.sum index 74a181528..25697f59c 100644 --- a/go.sum +++ b/go.sum @@ -480,7 +480,6 @@ github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaS github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= -github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/gologme/log v1.2.0/go.mod h1:gq31gQ8wEHkR+WekdWsqDuf8pXTUZA9BnnzTuPz1Y9U= github.com/gologme/log v1.3.0 h1:l781G4dE+pbigClDSDzSaaYKtiueHCILUa/qSDsmHAo= github.com/gologme/log v1.3.0/go.mod h1:yKT+DvIPdDdDoPtqFrFxheooyVmoqi0BAsw+erN3wA4= @@ -712,9 +711,8 @@ github.com/klauspost/compress v1.9.7/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0 github.com/klauspost/compress v1.10.3/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= github.com/klauspost/compress v1.11.3/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= github.com/klauspost/compress v1.11.13/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= -github.com/klauspost/compress v1.13.4/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg= -github.com/klauspost/compress v1.14.2 h1:S0OHlFk/Gbon/yauFJ4FfJJF5V0fc5HbBTJazi28pRw= -github.com/klauspost/compress v1.14.2/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= +github.com/klauspost/compress v1.14.4 h1:eijASRJcobkVtSt81Olfh7JX43osYLwy5krOJo6YEu4= +github.com/klauspost/compress v1.14.4/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= github.com/klauspost/cpuid v1.2.1/go.mod h1:Pj4uuM528wm8OyEC2QMXAi2YiTZ96dNQPGgoMS4s3ek= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= @@ -983,8 +981,6 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d/go.mod h1 github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20220301125114-e6012a13a6e6 h1:CqbM5tPbF1mV2h1J6N0NSF8et6HvFlwA98TLCUcjiSI= -github.com/matrix-org/gomatrixserverlib v0.0.0-20220301125114-e6012a13a6e6/go.mod h1:+WF5InseAMgi1fTnU46JH39IDpEvLep0fDzx9LDf2Bo= github.com/matrix-org/gomatrixserverlib v0.0.0-20220301141554-e124bd7d7902 h1:WHlrE8BYh/hzn1RKwq3YMAlhHivX47jQKAjZFtkJyPE= github.com/matrix-org/gomatrixserverlib v0.0.0-20220301141554-e124bd7d7902/go.mod h1:+WF5InseAMgi1fTnU46JH39IDpEvLep0fDzx9LDf2Bo= github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa h1:rMYFNVto66gp+eWS8XAUzgp4m0qmUBid6l1HX3mHstk= @@ -1031,8 +1027,8 @@ github.com/miekg/dns v1.1.31/go.mod h1:KNUDUusw/aVsxyTYZM1oqvCicbwhgbNgztCETuNZ7 github.com/miekg/pkcs11 v1.0.3/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs= github.com/minio/blake2b-simd v0.0.0-20160723061019-3f5f724cb5b1 h1:lYpkrQH5ajf0OXOcUbGjvZxxijuBwbbmlSxLiuofa+g= github.com/minio/blake2b-simd v0.0.0-20160723061019-3f5f724cb5b1/go.mod h1:pD8RvIylQ358TN4wwqatJ8rNavkEINozVn9DtGI3dfQ= -github.com/minio/highwayhash v1.0.1 h1:dZ6IIu8Z14VlC0VpfKofAhCy74wu/Qb5gcn52yWoz/0= -github.com/minio/highwayhash v1.0.1/go.mod h1:BQskDq+xkJ12lmlUUi7U0M5Swg3EWR+dLTk+kldvVxY= +github.com/minio/highwayhash v1.0.2 h1:Aak5U0nElisjDCfPSG79Tgzkn2gl66NxOMspRrKnA/g= +github.com/minio/highwayhash v1.0.2/go.mod h1:BQskDq+xkJ12lmlUUi7U0M5Swg3EWR+dLTk+kldvVxY= github.com/minio/sha256-simd v0.0.0-20190131020904-2d45a736cd16/go.mod h1:2FMWW+8GMoPweT6+pI63m9YE3Lmw4J71hV56Chs1E/U= github.com/minio/sha256-simd v0.0.0-20190328051042-05b4dd3047e5/go.mod h1:2FMWW+8GMoPweT6+pI63m9YE3Lmw4J71hV56Chs1E/U= github.com/minio/sha256-simd v0.1.0/go.mod h1:2FMWW+8GMoPweT6+pI63m9YE3Lmw4J71hV56Chs1E/U= @@ -1134,8 +1130,8 @@ github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32/go.mod h1:9wM+0iRr9ahx58uY github.com/ncw/swift v1.0.47/go.mod h1:23YIA4yWVnGwv2dQlN4bB7egfYX6YLn0Yo/S6zZO/ZM= github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo= github.com/neelance/sourcemap v0.0.0-20151028013722-8c68805598ab/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM= -github.com/neilalexander/nats-server/v2 v2.7.2-0.20220217100407-087330ed46ad h1:Z2nWMQsXWWqzj89nW6OaLJSdkFknqhaR5whEOz4++Y8= -github.com/neilalexander/nats-server/v2 v2.7.2-0.20220217100407-087330ed46ad/go.mod h1:tckmrt0M6bVaDT3kmh9UrIq/CBOBBse+TpXQi5ldaa8= +github.com/neilalexander/nats-server/v2 v2.7.4-0.20220302103432-6b04b9f12740 h1:RJrc+z35RHZlrjR6UBt9UmVRAlFh4SgYyEA0YpQdPHM= +github.com/neilalexander/nats-server/v2 v2.7.4-0.20220302103432-6b04b9f12740/go.mod h1:eJUrA5gm0ch6sJTEv85xmXIgQWsB0OyjkTsKXvlHbYc= github.com/neilalexander/nats.go v1.11.1-0.20220104162523-f4ddebe1061c h1:G2qsv7D0rY94HAu8pXmElMluuMHQ85waxIDQBhIzV2Q= github.com/neilalexander/nats.go v1.11.1-0.20220104162523-f4ddebe1061c/go.mod h1:BPko4oXsySz4aSWeFgOHLZs3G4Jq4ZAyE6/zMCxRT6w= github.com/neilalexander/utp v0.1.1-0.20210622132614-ee9a34a30488/go.mod h1:NPHGhPc0/wudcaCqL/H5AOddkRf8GPRhzOujuUKGQu8= From a4c918ee17d342e704aee738e4f0d5712b98eaa6 Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Wed, 2 Mar 2022 10:49:29 +0000 Subject: [PATCH 21/33] Fix data race in unit tests --- keyserver/internal/device_list_update_test.go | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/keyserver/internal/device_list_update_test.go b/keyserver/internal/device_list_update_test.go index e8d1bfe8b..59a66ec8e 100644 --- a/keyserver/internal/device_list_update_test.go +++ b/keyserver/internal/device_list_update_test.go @@ -87,6 +87,12 @@ func (d *mockDeviceListUpdaterDatabase) MarkDeviceListStale(ctx context.Context, return nil } +func (d *mockDeviceListUpdaterDatabase) isStale(userID string) bool { + d.mu.Lock() + defer d.mu.Unlock() + return d.staleUsers[userID] +} + // StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key // for this (user, device). Does not modify the stream ID for keys. func (d *mockDeviceListUpdaterDatabase) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clear []string) error { @@ -169,7 +175,7 @@ func TestUpdateHavePrevID(t *testing.T) { if !reflect.DeepEqual(db.storedKeys, []api.DeviceMessage{want}) { t.Errorf("DB didn't store correct event, got %v want %v", db.storedKeys, want) } - if db.staleUsers[event.UserID] { + if db.isStale(event.UserID) { t.Errorf("%s incorrectly marked as stale", event.UserID) } } @@ -243,7 +249,7 @@ func TestUpdateNoPrevID(t *testing.T) { }, } // Now we should have a fresh list and the keys and emitted something - if db.staleUsers[event.UserID] { + if db.isStale(event.UserID) { t.Errorf("%s still marked as stale", event.UserID) } if !reflect.DeepEqual(producer.events, []api.DeviceMessage{want}) { @@ -304,7 +310,7 @@ func TestDebounce(t *testing.T) { } // user should be marked as stale - if !db.staleUsers[userID] { + if !db.isStale(userID) { t.Errorf("user %s not marked as stale", userID) } // now send the response over federation @@ -330,7 +336,7 @@ func TestDebounce(t *testing.T) { wg.Wait() // user is no longer stale now - if db.staleUsers[userID] { + if db.isStale(userID) { t.Errorf("user %s is marked as stale", userID) } } From 8996cc805913e9927f57cd3c4b8247a09a47472b Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 2 Mar 2022 11:35:35 +0000 Subject: [PATCH 22/33] Media endpoints on `/v3` (#2242) * Media endpoints on `/v3` * Keep `/v1` too? --- mediaapi/routing/routing.go | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/mediaapi/routing/routing.go b/mediaapi/routing/routing.go index 44f9a9d69..fc2136bb6 100644 --- a/mediaapi/routing/routing.go +++ b/mediaapi/routing/routing.go @@ -53,8 +53,7 @@ func Setup( ) { rateLimits := httputil.NewRateLimits(rateLimit) - r0mux := publicAPIMux.PathPrefix("/r0").Subrouter() - v1mux := publicAPIMux.PathPrefix("/v1").Subrouter() + v3mux := publicAPIMux.PathPrefix("/{apiversion:(?:r0|v1|v3)}/").Subrouter() activeThumbnailGeneration := &types.ActiveThumbnailGeneration{ PathToResult: map[string]*types.ThumbnailGenerationResult{}, @@ -80,21 +79,18 @@ func Setup( } }) - r0mux.Handle("/upload", uploadHandler).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/config", configHandler).Methods(http.MethodGet, http.MethodOptions) - v1mux.Handle("/upload", uploadHandler).Methods(http.MethodPost, http.MethodOptions) + v3mux.Handle("/upload", uploadHandler).Methods(http.MethodPost, http.MethodOptions) + v3mux.Handle("/config", configHandler).Methods(http.MethodGet, http.MethodOptions) activeRemoteRequests := &types.ActiveRemoteRequests{ MXCToResult: map[string]*types.RemoteRequestResult{}, } downloadHandler := makeDownloadAPI("download", cfg, rateLimits, db, client, activeRemoteRequests, activeThumbnailGeneration) - r0mux.Handle("/download/{serverName}/{mediaId}", downloadHandler).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/download/{serverName}/{mediaId}/{downloadName}", downloadHandler).Methods(http.MethodGet, http.MethodOptions) - v1mux.Handle("/download/{serverName}/{mediaId}", downloadHandler).Methods(http.MethodGet, http.MethodOptions) // TODO: remove when synapse is fixed - v1mux.Handle("/download/{serverName}/{mediaId}/{downloadName}", downloadHandler).Methods(http.MethodGet, http.MethodOptions) // TODO: remove when synapse is fixed + v3mux.Handle("/download/{serverName}/{mediaId}", downloadHandler).Methods(http.MethodGet, http.MethodOptions) + v3mux.Handle("/download/{serverName}/{mediaId}/{downloadName}", downloadHandler).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/thumbnail/{serverName}/{mediaId}", + v3mux.Handle("/thumbnail/{serverName}/{mediaId}", makeDownloadAPI("thumbnail", cfg, rateLimits, db, client, activeRemoteRequests, activeThumbnailGeneration), ).Methods(http.MethodGet, http.MethodOptions) } From e46a61c49e5a0116b5df3df34a95cf5b8c811a8d Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Wed, 2 Mar 2022 11:38:13 +0000 Subject: [PATCH 23/33] Skip flakey test for now --- keyserver/internal/device_list_update_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/keyserver/internal/device_list_update_test.go b/keyserver/internal/device_list_update_test.go index 59a66ec8e..ff939355f 100644 --- a/keyserver/internal/device_list_update_test.go +++ b/keyserver/internal/device_list_update_test.go @@ -265,6 +265,7 @@ func TestUpdateNoPrevID(t *testing.T) { // Test that if we make N calls to ManualUpdate for the same user, we only do it once, assuming the // update is still ongoing. func TestDebounce(t *testing.T) { + t.Skipf("panic on closed channel on GHA") db := &mockDeviceListUpdaterDatabase{ staleUsers: make(map[string]bool), prevIDsExist: func(string, []int) bool { From 111f01ddc81d775dfdaab6e6a3a6afa6fa5608ea Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 3 Mar 2022 09:18:40 +0000 Subject: [PATCH 24/33] Update `sytest-whitelist` for changes in matrix-org/sytest#1200 --- sytest-whitelist | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sytest-whitelist b/sytest-whitelist index 0cb80b7b6..3e38176f4 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -78,7 +78,7 @@ Room creation reports m.room.member to myself Outbound federation rejects send_join responses with no m.room.create event Outbound federation rejects m.room.create events with an unknown room version Invited user can see room metadata -# Blacklisted because these tests call /r0/events which we don't implement +# Blacklisted because these tests call /v3/events which we don't implement # New room members see their own join event # Existing members see new members' join events setting 'm.room.power_levels' respects room powerlevel @@ -257,7 +257,7 @@ Guest non-joined users cannot send messages to guest_access rooms if not joined Real non-joined users cannot room initalSync for non-world_readable rooms Push rules come down in an initial /sync Regular users can add and delete aliases in the default room configuration -GET /r0/capabilities is not public +GET /v3/capabilities is not public GET /joined_rooms lists newly-created room /joined_rooms returns only joined rooms Message history can be paginated over federation @@ -366,8 +366,8 @@ Outbound federation will ignore a missing event with bad JSON for room version 6 Server correctly handles transactions that break edu limits Server rejects invalid JSON in a version 6 room Can download without a file name over federation -POST /media/r0/upload can create an upload -GET /media/r0/download can fetch the value again +POST /media/v3/upload can create an upload +GET /media/v3/download can fetch the value again Remote users can join room by alias Alias creators can delete alias with no ops Alias creators can delete canonical alias with no ops From f05ce478f05dcaf650fbae68a39aaf5d9880a580 Mon Sep 17 00:00:00 2001 From: Dan Date: Thu, 3 Mar 2022 13:40:53 +0200 Subject: [PATCH 25/33] Implement Push Notifications (#1842) * Add Pushserver component with Pushers API Co-authored-by: Tommie Gannert Co-authored-by: Dan Peleg * Wire Pushserver component Co-authored-by: Neil Alexander * Add PushGatewayClient. The full event format is required for Sytest. * Add a pushrules module. * Change user API account creation to use the new pushrules module's defaults. Introduces "scope" as required by client API, and some small field tweaks to make some 61push Sytests pass. * Add push rules query/put API in Pushserver. This manipulates account data over User API, and fires sync messages for changes. Those sync messages should, according to an existing TODO in clientapi, be moved to userapi. Forks clientapi/producers/syncapi.go to pushserver/ for later extension. * Add clientapi routes for push rules to Pushserver. A cleanup would be to move more of the name-splitting logic into pushrules.go, to depollute routing.go. * Output rooms.join.unread_notifications in /sync. This is the read-side. Pushserver will be the write-side. * Implement pushserver/storage for notifications. * Use PushGatewayClient and the pushrules module in Pushserver's room consumer. * Use one goroutine per user to avoid locking up the entire server for one bad push gateway. * Split pushing by format. * Send one device per push. Sytest does not support coalescing multiple devices into one push. Matches Synapse. Either we change Sytest, or remove the group-by-url-and-format logic. * Write OutputNotificationData from push server. Sync API is already the consumer. * Implement read receipt consumers in Pushserver. Supports m.read and m.fully_read receipts. * Add clientapi route for /unstable/notifications. * Rename to UpsertPusher for clarity and handle pusher update * Fix linter errors * Ignore body.Close() error check * Fix push server internal http wiring * Add 40 newly passing 61push tests to whitelist * Add next 12 newly passing 61push tests to whitelist * Send notification data before notifying users in EDU server consumer * NATS JetStream * Goodbye sarama * Fix `NewStreamTokenFromString` * Consume on the correct topic for the roomserver * Don't panic, NAK instead * Move push notifications into the User API * Don't set null values since that apparently causes Element upsetti * Also set omitempty on conditions * Fix bug so that we don't override the push rules unnecessarily * Tweak defaults * Update defaults * More tweaks * Move `/notifications` onto `r0`/`v3` mux * User API will consume events and read/fully read markers from the sync API with stream positions, instead of consuming directly Co-authored-by: Piotr Kozimor Co-authored-by: Tommie Gannert Co-authored-by: Neil Alexander --- .gitignore | 6 +- build/docker/config/dendrite.yaml | 11 + build/gobind-pinecone/monolith.go | 2 +- build/gobind-yggdrasil/monolith.go | 2 +- clientapi/clientapi.go | 3 +- clientapi/producers/syncapi.go | 7 +- clientapi/routing/account_data.go | 12 +- clientapi/routing/notification.go | 63 ++ clientapi/routing/password.go | 15 + clientapi/routing/pusher.go | 114 ++++ clientapi/routing/pushrules.go | 386 ++++++++++++ clientapi/routing/room_tagging.go | 4 +- clientapi/routing/routing.go | 167 ++++- cmd/dendrite-demo-libp2p/main.go | 6 +- cmd/dendrite-demo-pinecone/main.go | 2 +- cmd/dendrite-demo-yggdrasil/main.go | 5 +- cmd/dendrite-monolith-server/main.go | 3 +- .../personalities/userapi.go | 6 +- cmd/dendritejs-pinecone/main.go | 6 +- cmd/dendritejs/main.go | 3 + dendrite-config.yaml | 86 +-- go.mod | 1 + internal/eventutil/types.go | 24 +- internal/pushgateway/client.go | 66 ++ internal/pushgateway/pushgateway.go | 62 ++ internal/pushrules/action.go | 102 +++ internal/pushrules/action_test.go | 39 ++ internal/pushrules/condition.go | 49 ++ internal/pushrules/default.go | 23 + internal/pushrules/default_content.go | 33 + internal/pushrules/default_override.go | 165 +++++ internal/pushrules/default_underride.go | 119 ++++ internal/pushrules/evaluate.go | 165 +++++ internal/pushrules/evaluate_test.go | 189 ++++++ internal/pushrules/pushrules.go | 71 +++ internal/pushrules/util.go | 125 ++++ internal/pushrules/util_test.go | 169 +++++ internal/pushrules/validate.go | 85 +++ internal/pushrules/validate_test.go | 163 +++++ internal/sqlutil/sql.go | 1 + q.sqlite | Bin 0 -> 12288 bytes run-sytest.sh | 63 ++ setup/base/base.go | 6 + setup/config/config_test.go | 5 + setup/config/config_userapi.go | 3 + setup/jetstream/streams.go | 18 + setup/monolith.go | 4 +- syncapi/consumers/clientapi.go | 77 ++- syncapi/consumers/eduserver_receipts.go | 71 ++- syncapi/consumers/roomserver.go | 10 + syncapi/consumers/userapi.go | 110 ++++ syncapi/notifier/notifier.go | 11 + syncapi/notifier/notifier_test.go | 2 +- syncapi/producers/userapi_readupdate.go | 62 ++ syncapi/producers/userapi_streamevent.go | 60 ++ syncapi/storage/interface.go | 8 + .../postgres/notification_data_table.go | 108 ++++ syncapi/storage/postgres/syncserver.go | 5 + syncapi/storage/shared/syncserver.go | 21 + .../sqlite3/notification_data_table.go | 108 ++++ .../sqlite3/output_room_events_table.go | 6 + syncapi/storage/sqlite3/syncserver.go | 5 + syncapi/storage/tables/interface.go | 7 + syncapi/streams/stream_notificationdata.go | 55 ++ syncapi/streams/streams.go | 34 +- syncapi/sync/requestpool.go | 9 +- syncapi/syncapi.go | 24 +- syncapi/types/types.go | 61 +- syncapi/types/types_test.go | 8 +- sytest-blacklist | 6 + sytest-whitelist | 79 ++- userapi/api/api.go | 83 +++ userapi/api/api_trace.go | 30 + userapi/consumers/syncapi_readupdate.go | 136 ++++ userapi/consumers/syncapi_streamevent.go | 588 ++++++++++++++++++ userapi/internal/api.go | 171 ++++- userapi/inthttp/client.go | 61 ++ userapi/inthttp/server.go | 82 +++ userapi/producers/syncapi.go | 104 ++++ userapi/storage/interface.go | 13 + .../storage/postgres/notifications_table.go | 219 +++++++ userapi/storage/postgres/pusher_table.go | 157 +++++ userapi/storage/postgres/storage.go | 10 + userapi/storage/shared/storage.go | 109 +++- .../storage/sqlite3/notifications_table.go | 219 +++++++ userapi/storage/sqlite3/pusher_table.go | 157 +++++ userapi/storage/sqlite3/storage.go | 10 + userapi/storage/tables/interface.go | 40 ++ userapi/userapi.go | 57 +- userapi/userapi_test.go | 6 +- userapi/util/devices.go | 100 +++ userapi/util/notify.go | 76 +++ 92 files changed, 5840 insertions(+), 194 deletions(-) create mode 100644 clientapi/routing/notification.go create mode 100644 clientapi/routing/pusher.go create mode 100644 clientapi/routing/pushrules.go create mode 100644 internal/pushgateway/client.go create mode 100644 internal/pushgateway/pushgateway.go create mode 100644 internal/pushrules/action.go create mode 100644 internal/pushrules/action_test.go create mode 100644 internal/pushrules/condition.go create mode 100644 internal/pushrules/default.go create mode 100644 internal/pushrules/default_content.go create mode 100644 internal/pushrules/default_override.go create mode 100644 internal/pushrules/default_underride.go create mode 100644 internal/pushrules/evaluate.go create mode 100644 internal/pushrules/evaluate_test.go create mode 100644 internal/pushrules/pushrules.go create mode 100644 internal/pushrules/util.go create mode 100644 internal/pushrules/util_test.go create mode 100644 internal/pushrules/validate.go create mode 100644 internal/pushrules/validate_test.go create mode 100644 q.sqlite create mode 100755 run-sytest.sh create mode 100644 syncapi/consumers/userapi.go create mode 100644 syncapi/producers/userapi_readupdate.go create mode 100644 syncapi/producers/userapi_streamevent.go create mode 100644 syncapi/storage/postgres/notification_data_table.go create mode 100644 syncapi/storage/sqlite3/notification_data_table.go create mode 100644 syncapi/streams/stream_notificationdata.go create mode 100644 userapi/consumers/syncapi_readupdate.go create mode 100644 userapi/consumers/syncapi_streamevent.go create mode 100644 userapi/producers/syncapi.go create mode 100644 userapi/storage/postgres/notifications_table.go create mode 100644 userapi/storage/postgres/pusher_table.go create mode 100644 userapi/storage/sqlite3/notifications_table.go create mode 100644 userapi/storage/sqlite3/pusher_table.go create mode 100644 userapi/util/devices.go create mode 100644 userapi/util/notify.go diff --git a/.gitignore b/.gitignore index 092f4501c..2a8c2cf55 100644 --- a/.gitignore +++ b/.gitignore @@ -54,7 +54,7 @@ dendrite.yaml *.db # Log files -*.log* +*.log* # Generated code cmd/dendrite-demo-yggdrasil/embed/fs*.go @@ -62,5 +62,7 @@ cmd/dendrite-demo-yggdrasil/embed/fs*.go # Test dependencies test/wasm/node_modules -media_store/ +# Ignore complement folder when running locally +complement/ +media_store/ diff --git a/build/docker/config/dendrite.yaml b/build/docker/config/dendrite.yaml index 6d5ebc9fd..ebae50132 100644 --- a/build/docker/config/dendrite.yaml +++ b/build/docker/config/dendrite.yaml @@ -318,6 +318,17 @@ user_api: max_idle_conns: 2 conn_max_lifetime: -1 +# Configuration for the Push Server API. +push_server: + internal_api: + listen: http://localhost:7782 + connect: http://localhost:7782 + database: + connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_pushserver?sslmode=disable + max_open_conns: 10 + max_idle_conns: 2 + conn_max_lifetime: -1 + # Configuration for Opentracing. # See https://github.com/matrix-org/dendrite/tree/master/docs/tracing for information on # how this works and how to set it up. diff --git a/build/gobind-pinecone/monolith.go b/build/gobind-pinecone/monolith.go index aa8cc6e6e..5ab90adaf 100644 --- a/build/gobind-pinecone/monolith.go +++ b/build/gobind-pinecone/monolith.go @@ -312,7 +312,7 @@ func (m *DendriteMonolith) Start() { ) keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI) - m.userAPI = userapi.NewInternalAPI(accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI) + m.userAPI = userapi.NewInternalAPI(base, accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, base.PushGatewayHTTPClient()) keyAPI.SetUserAPI(m.userAPI) eduInputAPI := eduserver.NewInternalAPI( diff --git a/build/gobind-yggdrasil/monolith.go b/build/gobind-yggdrasil/monolith.go index 8b9c88f2a..3329485aa 100644 --- a/build/gobind-yggdrasil/monolith.go +++ b/build/gobind-yggdrasil/monolith.go @@ -116,7 +116,7 @@ func (m *DendriteMonolith) Start() { ) keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation) - userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI) + userAPI := userapi.NewInternalAPI(base, accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, base.PushGatewayHTTPClient()) keyAPI.SetUserAPI(userAPI) eduInputAPI := eduserver.NewInternalAPI( diff --git a/clientapi/clientapi.go b/clientapi/clientapi.go index a65f3b70d..918476674 100644 --- a/clientapi/clientapi.go +++ b/clientapi/clientapi.go @@ -59,6 +59,7 @@ func AddPublicRoutes( routing.Setup( router, synapseAdminRouter, cfg, eduInputAPI, rsAPI, asAPI, accountsDB, userAPI, federation, - syncProducer, transactionsCache, fsAPI, keyAPI, extRoomsProvider, mscCfg, + syncProducer, transactionsCache, fsAPI, keyAPI, + extRoomsProvider, mscCfg, ) } diff --git a/clientapi/producers/syncapi.go b/clientapi/producers/syncapi.go index 9b1d6b1a2..9ab90391d 100644 --- a/clientapi/producers/syncapi.go +++ b/clientapi/producers/syncapi.go @@ -30,7 +30,7 @@ type SyncAPIProducer struct { } // SendData sends account data to the sync API server -func (p *SyncAPIProducer) SendData(userID string, roomID string, dataType string) error { +func (p *SyncAPIProducer) SendData(userID string, roomID string, dataType string, readMarker *eventutil.ReadMarkerJSON) error { m := &nats.Msg{ Subject: p.Topic, Header: nats.Header{}, @@ -38,8 +38,9 @@ func (p *SyncAPIProducer) SendData(userID string, roomID string, dataType string m.Header.Set(jetstream.UserID, userID) data := eventutil.AccountData{ - RoomID: roomID, - Type: dataType, + RoomID: roomID, + Type: dataType, + ReadMarker: readMarker, } var err error m.Data, err = json.Marshal(data) diff --git a/clientapi/routing/account_data.go b/clientapi/routing/account_data.go index 03025f1da..d8e982690 100644 --- a/clientapi/routing/account_data.go +++ b/clientapi/routing/account_data.go @@ -24,6 +24,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/producers" eduserverAPI "github.com/matrix-org/dendrite/eduserver/api" + "github.com/matrix-org/dendrite/internal/eventutil" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/userapi/api" @@ -127,7 +128,7 @@ func SaveAccountData( } // TODO: user API should do this since it's account data - if err := syncProducer.SendData(userID, roomID, dataType); err != nil { + if err := syncProducer.SendData(userID, roomID, dataType, nil); err != nil { util.GetLogger(req.Context()).WithError(err).Error("syncProducer.SendData failed") return jsonerror.InternalServerError() } @@ -138,11 +139,6 @@ func SaveAccountData( } } -type readMarkerJSON struct { - FullyRead string `json:"m.fully_read"` - Read string `json:"m.read"` -} - type fullyReadEvent struct { EventID string `json:"event_id"` } @@ -159,7 +155,7 @@ func SaveReadMarker( return *resErr } - var r readMarkerJSON + var r eventutil.ReadMarkerJSON resErr = httputil.UnmarshalJSONRequest(req, &r) if resErr != nil { return *resErr @@ -189,7 +185,7 @@ func SaveReadMarker( return util.ErrorResponse(err) } - if err := syncProducer.SendData(device.UserID, roomID, "m.fully_read"); err != nil { + if err := syncProducer.SendData(device.UserID, roomID, "m.fully_read", &r); err != nil { util.GetLogger(req.Context()).WithError(err).Error("syncProducer.SendData failed") return jsonerror.InternalServerError() } diff --git a/clientapi/routing/notification.go b/clientapi/routing/notification.go new file mode 100644 index 000000000..ee715d323 --- /dev/null +++ b/clientapi/routing/notification.go @@ -0,0 +1,63 @@ +// Copyright 2021 Dan Peleg +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "net/http" + "strconv" + + "github.com/matrix-org/dendrite/clientapi/jsonerror" + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" +) + +// GetNotifications handles /_matrix/client/r0/notifications +func GetNotifications( + req *http.Request, device *userapi.Device, + userAPI userapi.UserInternalAPI, +) util.JSONResponse { + var limit int64 + if limitStr := req.URL.Query().Get("limit"); limitStr != "" { + var err error + limit, err = strconv.ParseInt(limitStr, 10, 64) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("ParseInt(limit) failed") + return jsonerror.InternalServerError() + } + } + + var queryRes userapi.QueryNotificationsResponse + localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("SplitID failed") + return jsonerror.InternalServerError() + } + err = userAPI.QueryNotifications(req.Context(), &userapi.QueryNotificationsRequest{ + Localpart: localpart, + From: req.URL.Query().Get("from"), + Limit: int(limit), + Only: req.URL.Query().Get("only"), + }, &queryRes) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("QueryNotifications failed") + return jsonerror.InternalServerError() + } + util.GetLogger(req.Context()).WithField("from", req.URL.Query().Get("from")).WithField("limit", limit).WithField("only", req.URL.Query().Get("only")).WithField("next", queryRes.NextToken).Infof("QueryNotifications: len %d", len(queryRes.Notifications)) + return util.JSONResponse{ + Code: http.StatusOK, + JSON: queryRes, + } +} diff --git a/clientapi/routing/password.go b/clientapi/routing/password.go index acac60fa5..c63412d08 100644 --- a/clientapi/routing/password.go +++ b/clientapi/routing/password.go @@ -12,6 +12,7 @@ import ( userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" + "github.com/sirupsen/logrus" ) type newPasswordRequest struct { @@ -37,6 +38,11 @@ func Password( var r newPasswordRequest r.LogoutDevices = true + logrus.WithFields(logrus.Fields{ + "sessionId": device.SessionID, + "userId": device.UserID, + }).Debug("Changing password") + // Unmarshal the request. resErr := httputil.UnmarshalJSONRequest(req, &r) if resErr != nil { @@ -116,6 +122,15 @@ func Password( util.GetLogger(req.Context()).WithError(err).Error("PerformDeviceDeletion failed") return jsonerror.InternalServerError() } + + pushersReq := &api.PerformPusherDeletionRequest{ + Localpart: localpart, + SessionID: device.SessionID, + } + if err := userAPI.PerformPusherDeletion(req.Context(), pushersReq, &struct{}{}); err != nil { + util.GetLogger(req.Context()).WithError(err).Error("PerformPusherDeletion failed") + return jsonerror.InternalServerError() + } } // Return a success code. diff --git a/clientapi/routing/pusher.go b/clientapi/routing/pusher.go new file mode 100644 index 000000000..9d6bef8bd --- /dev/null +++ b/clientapi/routing/pusher.go @@ -0,0 +1,114 @@ +// Copyright 2021 Dan Peleg +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "net/http" + "net/url" + + "github.com/matrix-org/dendrite/clientapi/httputil" + "github.com/matrix-org/dendrite/clientapi/jsonerror" + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" +) + +// GetPushers handles /_matrix/client/r0/pushers +func GetPushers( + req *http.Request, device *userapi.Device, + userAPI userapi.UserInternalAPI, +) util.JSONResponse { + var queryRes userapi.QueryPushersResponse + localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("SplitID failed") + return jsonerror.InternalServerError() + } + err = userAPI.QueryPushers(req.Context(), &userapi.QueryPushersRequest{ + Localpart: localpart, + }, &queryRes) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("QueryPushers failed") + return jsonerror.InternalServerError() + } + for i := range queryRes.Pushers { + queryRes.Pushers[i].SessionID = 0 + } + return util.JSONResponse{ + Code: http.StatusOK, + JSON: queryRes, + } +} + +// SetPusher handles /_matrix/client/r0/pushers/set +// This endpoint allows the creation, modification and deletion of pushers for this user ID. +// The behaviour of this endpoint varies depending on the values in the JSON body. +func SetPusher( + req *http.Request, device *userapi.Device, + userAPI userapi.UserInternalAPI, +) util.JSONResponse { + localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("SplitID failed") + return jsonerror.InternalServerError() + } + body := userapi.PerformPusherSetRequest{} + if resErr := httputil.UnmarshalJSONRequest(req, &body); resErr != nil { + return *resErr + } + if len(body.AppID) > 64 { + return invalidParam("length of app_id must be no more than 64 characters") + } + if len(body.PushKey) > 512 { + return invalidParam("length of pushkey must be no more than 512 bytes") + } + uInt := body.Data["url"] + if uInt != nil { + u, ok := uInt.(string) + if !ok { + return invalidParam("url must be string") + } + if u != "" { + var pushUrl *url.URL + pushUrl, err = url.Parse(u) + if err != nil { + return invalidParam("malformed url passed") + } + if pushUrl.Scheme != "https" { + return invalidParam("only https scheme is allowed") + } + } + + } + body.Localpart = localpart + body.SessionID = device.SessionID + err = userAPI.PerformPusherSet(req.Context(), &body, &struct{}{}) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("PerformPusherSet failed") + return jsonerror.InternalServerError() + } + + return util.JSONResponse{ + Code: http.StatusOK, + JSON: struct{}{}, + } +} + +func invalidParam(msg string) util.JSONResponse { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidParam(msg), + } +} diff --git a/clientapi/routing/pushrules.go b/clientapi/routing/pushrules.go new file mode 100644 index 000000000..81a33b25a --- /dev/null +++ b/clientapi/routing/pushrules.go @@ -0,0 +1,386 @@ +package routing + +import ( + "context" + "encoding/json" + "io" + "net/http" + "reflect" + + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/internal/pushrules" + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/util" +) + +func errorResponse(ctx context.Context, err error, msg string, args ...interface{}) util.JSONResponse { + if eerr, ok := err.(*jsonerror.MatrixError); ok { + var status int + switch eerr.ErrCode { + case "M_INVALID_ARGUMENT_VALUE": + status = http.StatusBadRequest + case "M_NOT_FOUND": + status = http.StatusNotFound + default: + status = http.StatusInternalServerError + } + return util.MatrixErrorResponse(status, eerr.ErrCode, eerr.Err) + } + util.GetLogger(ctx).WithError(err).Errorf(msg, args...) + return jsonerror.InternalServerError() +} + +func GetAllPushRules(ctx context.Context, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse { + ruleSets, err := queryPushRules(ctx, device.UserID, userAPI) + if err != nil { + return errorResponse(ctx, err, "queryPushRulesJSON failed") + } + return util.JSONResponse{ + Code: http.StatusOK, + JSON: ruleSets, + } +} + +func GetPushRulesByScope(ctx context.Context, scope string, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse { + ruleSets, err := queryPushRules(ctx, device.UserID, userAPI) + if err != nil { + return errorResponse(ctx, err, "queryPushRulesJSON failed") + } + ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope)) + if ruleSet == nil { + return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed") + } + return util.JSONResponse{ + Code: http.StatusOK, + JSON: ruleSet, + } +} + +func GetPushRulesByKind(ctx context.Context, scope, kind string, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse { + ruleSets, err := queryPushRules(ctx, device.UserID, userAPI) + if err != nil { + return errorResponse(ctx, err, "queryPushRules failed") + } + ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope)) + if ruleSet == nil { + return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed") + } + rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind)) + if rulesPtr == nil { + return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed") + } + return util.JSONResponse{ + Code: http.StatusOK, + JSON: *rulesPtr, + } +} + +func GetPushRuleByRuleID(ctx context.Context, scope, kind, ruleID string, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse { + ruleSets, err := queryPushRules(ctx, device.UserID, userAPI) + if err != nil { + return errorResponse(ctx, err, "queryPushRules failed") + } + ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope)) + if ruleSet == nil { + return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed") + } + rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind)) + if rulesPtr == nil { + return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed") + } + i := pushRuleIndexByID(*rulesPtr, ruleID) + if i < 0 { + return errorResponse(ctx, jsonerror.NotFound("push rule ID not found"), "pushRuleIndexByID failed") + } + return util.JSONResponse{ + Code: http.StatusOK, + JSON: (*rulesPtr)[i], + } +} + +func PutPushRuleByRuleID(ctx context.Context, scope, kind, ruleID, afterRuleID, beforeRuleID string, body io.Reader, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse { + var newRule pushrules.Rule + if err := json.NewDecoder(body).Decode(&newRule); err != nil { + return errorResponse(ctx, err, "JSON Decode failed") + } + newRule.RuleID = ruleID + + errs := pushrules.ValidateRule(pushrules.Kind(kind), &newRule) + if len(errs) > 0 { + return errorResponse(ctx, jsonerror.InvalidArgumentValue(errs[0].Error()), "rule sanity check failed: %v", errs) + } + + ruleSets, err := queryPushRules(ctx, device.UserID, userAPI) + if err != nil { + return errorResponse(ctx, err, "queryPushRules failed") + } + ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope)) + if ruleSet == nil { + return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed") + } + rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind)) + if rulesPtr == nil { + return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed") + } + i := pushRuleIndexByID(*rulesPtr, ruleID) + if i >= 0 && afterRuleID == "" && beforeRuleID == "" { + // Modify rule at the same index. + + // TODO: The spec does not say what to do in this case, but + // this feels reasonable. + *((*rulesPtr)[i]) = newRule + util.GetLogger(ctx).Infof("Modified existing push rule at %d", i) + } else { + if i >= 0 { + // Delete old rule. + *rulesPtr = append((*rulesPtr)[:i], (*rulesPtr)[i+1:]...) + util.GetLogger(ctx).Infof("Deleted old push rule at %d", i) + } else { + // SPEC: When creating push rules, they MUST be enabled by default. + // + // TODO: it's unclear if we must reject disabled rules, or force + // the value to true. Sytests fail if we don't force it. + newRule.Enabled = true + } + + // Add new rule. + i, err := findPushRuleInsertionIndex(*rulesPtr, afterRuleID, beforeRuleID) + if err != nil { + return errorResponse(ctx, err, "findPushRuleInsertionIndex failed") + } + + *rulesPtr = append((*rulesPtr)[:i], append([]*pushrules.Rule{&newRule}, (*rulesPtr)[i:]...)...) + util.GetLogger(ctx).WithField("after", afterRuleID).WithField("before", beforeRuleID).Infof("Added new push rule at %d", i) + } + + if err := putPushRules(ctx, device.UserID, ruleSets, userAPI); err != nil { + return errorResponse(ctx, err, "putPushRules failed") + } + + return util.JSONResponse{Code: http.StatusOK, JSON: struct{}{}} +} + +func DeletePushRuleByRuleID(ctx context.Context, scope, kind, ruleID string, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse { + ruleSets, err := queryPushRules(ctx, device.UserID, userAPI) + if err != nil { + return errorResponse(ctx, err, "queryPushRules failed") + } + ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope)) + if ruleSet == nil { + return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed") + } + rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind)) + if rulesPtr == nil { + return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed") + } + i := pushRuleIndexByID(*rulesPtr, ruleID) + if i < 0 { + return errorResponse(ctx, jsonerror.NotFound("push rule ID not found"), "pushRuleIndexByID failed") + } + + *rulesPtr = append((*rulesPtr)[:i], (*rulesPtr)[i+1:]...) + + if err := putPushRules(ctx, device.UserID, ruleSets, userAPI); err != nil { + return errorResponse(ctx, err, "putPushRules failed") + } + + return util.JSONResponse{Code: http.StatusOK, JSON: struct{}{}} +} + +func GetPushRuleAttrByRuleID(ctx context.Context, scope, kind, ruleID, attr string, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse { + attrGet, err := pushRuleAttrGetter(attr) + if err != nil { + return errorResponse(ctx, err, "pushRuleAttrGetter failed") + } + ruleSets, err := queryPushRules(ctx, device.UserID, userAPI) + if err != nil { + return errorResponse(ctx, err, "queryPushRules failed") + } + ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope)) + if ruleSet == nil { + return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed") + } + rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind)) + if rulesPtr == nil { + return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed") + } + i := pushRuleIndexByID(*rulesPtr, ruleID) + if i < 0 { + return errorResponse(ctx, jsonerror.NotFound("push rule ID not found"), "pushRuleIndexByID failed") + } + return util.JSONResponse{ + Code: http.StatusOK, + JSON: map[string]interface{}{ + attr: attrGet((*rulesPtr)[i]), + }, + } +} + +func PutPushRuleAttrByRuleID(ctx context.Context, scope, kind, ruleID, attr string, body io.Reader, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse { + var newPartialRule pushrules.Rule + if err := json.NewDecoder(body).Decode(&newPartialRule); err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON(err.Error()), + } + } + if newPartialRule.Actions == nil { + // This ensures json.Marshal encodes the empty list as [] rather than null. + newPartialRule.Actions = []*pushrules.Action{} + } + + attrGet, err := pushRuleAttrGetter(attr) + if err != nil { + return errorResponse(ctx, err, "pushRuleAttrGetter failed") + } + attrSet, err := pushRuleAttrSetter(attr) + if err != nil { + return errorResponse(ctx, err, "pushRuleAttrSetter failed") + } + + ruleSets, err := queryPushRules(ctx, device.UserID, userAPI) + if err != nil { + return errorResponse(ctx, err, "queryPushRules failed") + } + ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope)) + if ruleSet == nil { + return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed") + } + rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind)) + if rulesPtr == nil { + return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed") + } + i := pushRuleIndexByID(*rulesPtr, ruleID) + if i < 0 { + return errorResponse(ctx, jsonerror.NotFound("push rule ID not found"), "pushRuleIndexByID failed") + } + + if !reflect.DeepEqual(attrGet((*rulesPtr)[i]), attrGet(&newPartialRule)) { + attrSet((*rulesPtr)[i], &newPartialRule) + + if err := putPushRules(ctx, device.UserID, ruleSets, userAPI); err != nil { + return errorResponse(ctx, err, "putPushRules failed") + } + } + + return util.JSONResponse{Code: http.StatusOK, JSON: struct{}{}} +} + +func queryPushRules(ctx context.Context, userID string, userAPI userapi.UserInternalAPI) (*pushrules.AccountRuleSets, error) { + var res userapi.QueryPushRulesResponse + if err := userAPI.QueryPushRules(ctx, &userapi.QueryPushRulesRequest{UserID: userID}, &res); err != nil { + util.GetLogger(ctx).WithError(err).Error("userAPI.QueryPushRules failed") + return nil, err + } + return res.RuleSets, nil +} + +func putPushRules(ctx context.Context, userID string, ruleSets *pushrules.AccountRuleSets, userAPI userapi.UserInternalAPI) error { + req := userapi.PerformPushRulesPutRequest{ + UserID: userID, + RuleSets: ruleSets, + } + var res struct{} + if err := userAPI.PerformPushRulesPut(ctx, &req, &res); err != nil { + util.GetLogger(ctx).WithError(err).Error("userAPI.PerformPushRulesPut failed") + return err + } + return nil +} + +func pushRuleSetByScope(ruleSets *pushrules.AccountRuleSets, scope pushrules.Scope) *pushrules.RuleSet { + switch scope { + case pushrules.GlobalScope: + return &ruleSets.Global + default: + return nil + } +} + +func pushRuleSetKindPointer(ruleSet *pushrules.RuleSet, kind pushrules.Kind) *[]*pushrules.Rule { + switch kind { + case pushrules.OverrideKind: + return &ruleSet.Override + case pushrules.ContentKind: + return &ruleSet.Content + case pushrules.RoomKind: + return &ruleSet.Room + case pushrules.SenderKind: + return &ruleSet.Sender + case pushrules.UnderrideKind: + return &ruleSet.Underride + default: + return nil + } +} + +func pushRuleIndexByID(rules []*pushrules.Rule, id string) int { + for i, rule := range rules { + if rule.RuleID == id { + return i + } + } + return -1 +} + +func pushRuleAttrGetter(attr string) (func(*pushrules.Rule) interface{}, error) { + switch attr { + case "actions": + return func(rule *pushrules.Rule) interface{} { return rule.Actions }, nil + case "enabled": + return func(rule *pushrules.Rule) interface{} { return rule.Enabled }, nil + default: + return nil, jsonerror.InvalidArgumentValue("invalid push rule attribute") + } +} + +func pushRuleAttrSetter(attr string) (func(dest, src *pushrules.Rule), error) { + switch attr { + case "actions": + return func(dest, src *pushrules.Rule) { dest.Actions = src.Actions }, nil + case "enabled": + return func(dest, src *pushrules.Rule) { dest.Enabled = src.Enabled }, nil + default: + return nil, jsonerror.InvalidArgumentValue("invalid push rule attribute") + } +} + +func findPushRuleInsertionIndex(rules []*pushrules.Rule, afterID, beforeID string) (int, error) { + var i int + + if afterID != "" { + for ; i < len(rules); i++ { + if rules[i].RuleID == afterID { + break + } + } + if i == len(rules) { + return 0, jsonerror.NotFound("after: rule ID not found") + } + if rules[i].Default { + return 0, jsonerror.NotFound("after: rule ID must not be a default rule") + } + // We stopped on the "after" match to differentiate + // not-found from is-last-entry. Now we move to the earliest + // insertion point. + i++ + } + + if beforeID != "" { + for ; i < len(rules); i++ { + if rules[i].RuleID == beforeID { + break + } + } + if i == len(rules) { + return 0, jsonerror.NotFound("before: rule ID not found") + } + if rules[i].Default { + return 0, jsonerror.NotFound("before: rule ID must not be a default rule") + } + } + + // UNSPEC: The spec does not say what to do if no after/before is + // given. Sytest fails if it doesn't go first. + return i, nil +} diff --git a/clientapi/routing/room_tagging.go b/clientapi/routing/room_tagging.go index c683cc949..83294b180 100644 --- a/clientapi/routing/room_tagging.go +++ b/clientapi/routing/room_tagging.go @@ -98,7 +98,7 @@ func PutTag( return jsonerror.InternalServerError() } - if err = syncProducer.SendData(userID, roomID, "m.tag"); err != nil { + if err = syncProducer.SendData(userID, roomID, "m.tag", nil); err != nil { logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi") } @@ -151,7 +151,7 @@ func DeleteTag( } // TODO: user API should do this since it's account data - if err := syncProducer.SendData(userID, roomID, "m.tag"); err != nil { + if err := syncProducer.SendData(userID, roomID, "m.tag", nil); err != nil { logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi") } diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index d75f58b81..d22fbd809 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -16,7 +16,6 @@ package routing import ( "context" - "encoding/json" "net/http" "strings" @@ -561,25 +560,142 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) - v3mux.Handle("/pushrules/", - httputil.MakeExternalAPI("push_rules", func(req *http.Request) util.JSONResponse { - // TODO: Implement push rules API - res := json.RawMessage(`{ - "global": { - "content": [], - "override": [], - "room": [], - "sender": [], - "underride": [] - } - }`) + // Push rules + + v3mux.Handle("/pushrules", + httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return util.JSONResponse{ - Code: http.StatusOK, - JSON: &res, + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidArgumentValue("missing trailing slash"), } }), ).Methods(http.MethodGet, http.MethodOptions) + v3mux.Handle("/pushrules/", + httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return GetAllPushRules(req.Context(), device, userAPI) + }), + ).Methods(http.MethodGet, http.MethodOptions) + + v3mux.Handle("/pushrules/", + httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidArgumentValue("scope, kind and rule ID must be specified"), + } + }), + ).Methods(http.MethodPut) + + v3mux.Handle("/pushrules/{scope}/", + httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + return GetPushRulesByScope(req.Context(), vars["scope"], device, userAPI) + }), + ).Methods(http.MethodGet, http.MethodOptions) + + v3mux.Handle("/pushrules/{scope}", + httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidArgumentValue("missing trailing slash after scope"), + } + }), + ).Methods(http.MethodGet, http.MethodOptions) + + v3mux.Handle("/pushrules/{scope:[^/]+/?}", + httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidArgumentValue("kind and rule ID must be specified"), + } + }), + ).Methods(http.MethodPut) + + v3mux.Handle("/pushrules/{scope}/{kind}/", + httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + return GetPushRulesByKind(req.Context(), vars["scope"], vars["kind"], device, userAPI) + }), + ).Methods(http.MethodGet, http.MethodOptions) + + v3mux.Handle("/pushrules/{scope}/{kind}", + httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidArgumentValue("missing trailing slash after kind"), + } + }), + ).Methods(http.MethodGet, http.MethodOptions) + + v3mux.Handle("/pushrules/{scope}/{kind:[^/]+/?}", + httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidArgumentValue("rule ID must be specified"), + } + }), + ).Methods(http.MethodPut) + + v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}", + httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + return GetPushRuleByRuleID(req.Context(), vars["scope"], vars["kind"], vars["ruleID"], device, userAPI) + }), + ).Methods(http.MethodGet, http.MethodOptions) + + v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}", + httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + if r := rateLimits.Limit(req); r != nil { + return *r + } + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + query := req.URL.Query() + return PutPushRuleByRuleID(req.Context(), vars["scope"], vars["kind"], vars["ruleID"], query.Get("after"), query.Get("before"), req.Body, device, userAPI) + }), + ).Methods(http.MethodPut) + + v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}", + httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + return DeletePushRuleByRuleID(req.Context(), vars["scope"], vars["kind"], vars["ruleID"], device, userAPI) + }), + ).Methods(http.MethodDelete) + + v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}/{attr}", + httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + return GetPushRuleAttrByRuleID(req.Context(), vars["scope"], vars["kind"], vars["ruleID"], vars["attr"], device, userAPI) + }), + ).Methods(http.MethodGet, http.MethodOptions) + + v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}/{attr}", + httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + return PutPushRuleAttrByRuleID(req.Context(), vars["scope"], vars["kind"], vars["ruleID"], vars["attr"], req.Body, device, userAPI) + }), + ).Methods(http.MethodPut) + // Element user settings v3mux.Handle("/profile/{userID}", @@ -885,6 +1001,27 @@ func Setup( }), ).Methods(http.MethodPost, http.MethodOptions) + v3mux.Handle("/notifications", + httputil.MakeAuthAPI("get_notifications", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return GetNotifications(req, device, userAPI) + }), + ).Methods(http.MethodGet, http.MethodOptions) + + v3mux.Handle("/pushers", + httputil.MakeAuthAPI("get_pushers", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return GetPushers(req, device, userAPI) + }), + ).Methods(http.MethodGet, http.MethodOptions) + + v3mux.Handle("/pushers/set", + httputil.MakeAuthAPI("set_pushers", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + if r := rateLimits.Limit(req); r != nil { + return *r + } + return SetPusher(req, device, userAPI) + }), + ).Methods(http.MethodPost, http.MethodOptions) + // Stub implementations for sytest v3mux.Handle("/events", httputil.MakeExternalAPI("events", func(req *http.Request) util.JSONResponse { diff --git a/cmd/dendrite-demo-libp2p/main.go b/cmd/dendrite-demo-libp2p/main.go index 78536901c..8ce641914 100644 --- a/cmd/dendrite-demo-libp2p/main.go +++ b/cmd/dendrite-demo-libp2p/main.go @@ -144,12 +144,14 @@ func main() { accountDB := base.Base.CreateAccountsDB() federation := createFederationClient(base) keyAPI := keyserver.NewInternalAPI(&base.Base, &base.Base.Cfg.KeyServer, federation) - userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, nil, keyAPI) - keyAPI.SetUserAPI(userAPI) rsAPI := roomserver.NewInternalAPI( &base.Base, ) + + userAPI := userapi.NewInternalAPI(&base.Base, accountDB, &cfg.UserAPI, nil, keyAPI, rsAPI, base.Base.PushGatewayHTTPClient()) + keyAPI.SetUserAPI(userAPI) + eduInputAPI := eduserver.NewInternalAPI( &base.Base, cache.New(), userAPI, ) diff --git a/cmd/dendrite-demo-pinecone/main.go b/cmd/dendrite-demo-pinecone/main.go index 5810a7f18..45f186985 100644 --- a/cmd/dendrite-demo-pinecone/main.go +++ b/cmd/dendrite-demo-pinecone/main.go @@ -187,7 +187,7 @@ func main() { ) keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI) - userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, nil, keyAPI) + userAPI := userapi.NewInternalAPI(base, accountDB, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient()) keyAPI.SetUserAPI(userAPI) eduInputAPI := eduserver.NewInternalAPI( diff --git a/cmd/dendrite-demo-yggdrasil/main.go b/cmd/dendrite-demo-yggdrasil/main.go index d16f0e9e5..b7e30ba2e 100644 --- a/cmd/dendrite-demo-yggdrasil/main.go +++ b/cmd/dendrite-demo-yggdrasil/main.go @@ -111,14 +111,15 @@ func main() { keyRing := serverKeyAPI.KeyRing() keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation) - userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, nil, keyAPI) - keyAPI.SetUserAPI(userAPI) rsComponent := roomserver.NewInternalAPI( base, ) rsAPI := rsComponent + userAPI := userapi.NewInternalAPI(base, accountDB, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient()) + keyAPI.SetUserAPI(userAPI) + eduInputAPI := eduserver.NewInternalAPI( base, cache.New(), userAPI, ) diff --git a/cmd/dendrite-monolith-server/main.go b/cmd/dendrite-monolith-server/main.go index bb2685208..3b952504b 100644 --- a/cmd/dendrite-monolith-server/main.go +++ b/cmd/dendrite-monolith-server/main.go @@ -106,7 +106,8 @@ func main() { keyAPI = base.KeyServerHTTPClient() } - userImpl := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI) + pgClient := base.PushGatewayHTTPClient() + userImpl := userapi.NewInternalAPI(base, accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, pgClient) userAPI := userImpl if base.UseHTTPAPIs { userapi.AddInternalRoutes(base.InternalAPIMux, userAPI) diff --git a/cmd/dendrite-polylith-multi/personalities/userapi.go b/cmd/dendrite-polylith-multi/personalities/userapi.go index f147cda14..f1fa379c7 100644 --- a/cmd/dendrite-polylith-multi/personalities/userapi.go +++ b/cmd/dendrite-polylith-multi/personalities/userapi.go @@ -23,7 +23,11 @@ import ( func UserAPI(base *basepkg.BaseDendrite, cfg *config.Dendrite) { accountDB := base.CreateAccountsDB() - userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, base.KeyServerHTTPClient()) + userAPI := userapi.NewInternalAPI( + base, accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, + base.KeyServerHTTPClient(), base.RoomserverHTTPClient(), + base.PushGatewayHTTPClient(), + ) userapi.AddInternalRoutes(base.InternalAPIMux, userAPI) diff --git a/cmd/dendritejs-pinecone/main.go b/cmd/dendritejs-pinecone/main.go index 664f644f3..407081f59 100644 --- a/cmd/dendritejs-pinecone/main.go +++ b/cmd/dendritejs-pinecone/main.go @@ -184,13 +184,15 @@ func startup() { accountDB := base.CreateAccountsDB() federation := conn.CreateFederationClient(base, pSessions) keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation) - userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, nil, keyAPI) - keyAPI.SetUserAPI(userAPI) serverKeyAPI := &signing.YggdrasilKeys{} keyRing := serverKeyAPI.KeyRing() rsAPI := roomserver.NewInternalAPI(base) + + userAPI := userapi.NewInternalAPI(base, accountDB, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient()) + keyAPI.SetUserAPI(userAPI) + eduInputAPI := eduserver.NewInternalAPI(base, cache.New(), userAPI) asQuery := appservice.NewInternalAPI( base, userAPI, rsAPI, diff --git a/cmd/dendritejs/main.go b/cmd/dendritejs/main.go index 0ea41b4c4..37cbb12dd 100644 --- a/cmd/dendritejs/main.go +++ b/cmd/dendritejs/main.go @@ -212,6 +212,8 @@ func main() { rsAPI.SetFederationAPI(fedSenderAPI, keyRing) p2pPublicRoomProvider := NewLibP2PPublicRoomsProvider(node, fedSenderAPI, federation) + psAPI := pushserver.NewInternalAPI(base) + monolith := setup.Monolith{ Config: base.Cfg, AccountDB: accountDB, @@ -225,6 +227,7 @@ func main() { RoomserverAPI: rsAPI, UserAPI: userAPI, KeyAPI: keyAPI, + PushserverAPI: psAPI, //ServerKeyAPI: serverKeyAPI, ExtPublicRoomsProvider: p2pPublicRoomProvider, } diff --git a/dendrite-config.yaml b/dendrite-config.yaml index 533b5c952..0236851c4 100644 --- a/dendrite-config.yaml +++ b/dendrite-config.yaml @@ -6,7 +6,7 @@ # # At a minimum, to get started, you will need to update the settings in the # "global" section for your deployment, and you will need to check that the -# database "connection_string" line in each component section is correct. +# database "connection_string" line in each component section is correct. # # Each component with a "database" section can accept the following formats # for "connection_string": @@ -21,13 +21,13 @@ # small number of users and likely will perform worse still with a higher volume # of users. # -# The "max_open_conns" and "max_idle_conns" settings configure the maximum +# The "max_open_conns" and "max_idle_conns" settings configure the maximum # number of open/idle database connections. The value 0 will use the database # engine default, and a negative value will use unlimited connections. The # "conn_max_lifetime" option controls the maximum length of time a database # connection can be idle in seconds - a negative value is unlimited. -# The version of the configuration file. +# The version of the configuration file. version: 2 # Global Matrix configuration. This configuration applies to all components. @@ -61,8 +61,8 @@ global: # Lists of domains that the server will trust as identity servers to verify third # party identifiers such as phone numbers and email addresses. trusted_third_party_id_servers: - - matrix.org - - vector.im + - matrix.org + - vector.im # Disables federation. Dendrite will not be able to make any outbound HTTP requests # to other servers and the federation API will not be exposed. @@ -87,14 +87,14 @@ global: # in monolith mode. It is required to specify the address of at least one # NATS Server node if running in polylith mode. addresses: - # - localhost:4222 + # - localhost:4222 # Keep all NATS streams in memory, rather than persisting it to the storage # path below. This option is present primarily for integration testing and # should not be used on a real world Dendrite deployment. in_memory: false - # Persistent directory to store JetStream streams in. This directory + # Persistent directory to store JetStream streams in. This directory # should be preserved across Dendrite restarts. storage_path: ./ @@ -126,7 +126,7 @@ global: # Configuration for the Appservice API. app_service_api: internal_api: - listen: http://localhost:7777 # Only used in polylith deployments + listen: http://localhost:7777 # Only used in polylith deployments connect: http://localhost:7777 # Only used in polylith deployments database: connection_string: file:appservice.db @@ -145,7 +145,7 @@ app_service_api: # Configuration for the Client API. client_api: internal_api: - listen: http://localhost:7771 # Only used in polylith deployments + listen: http://localhost:7771 # Only used in polylith deployments connect: http://localhost:7771 # Only used in polylith deployments external_api: listen: http://[::]:8071 @@ -165,13 +165,13 @@ client_api: # Whether to require reCAPTCHA for registration. enable_registration_captcha: false - # Settings for ReCAPTCHA. + # Settings for ReCAPTCHA. recaptcha_public_key: "" recaptcha_private_key: "" recaptcha_bypass_secret: "" recaptcha_siteverify_api: "" - # TURN server information that this homeserver should send to clients. + # TURN server information that this homeserver should send to clients. turn: turn_user_lifetime: "" turn_uris: [] @@ -180,7 +180,7 @@ client_api: turn_password: "" # Settings for rate-limited endpoints. Rate limiting will kick in after the - # threshold number of "slots" have been taken by requests from a specific + # threshold number of "slots" have been taken by requests from a specific # host. Each "slot" will be released after the cooloff time in milliseconds. rate_limiting: enabled: true @@ -190,13 +190,13 @@ client_api: # Configuration for the EDU server. edu_server: internal_api: - listen: http://localhost:7778 # Only used in polylith deployments + listen: http://localhost:7778 # Only used in polylith deployments connect: http://localhost:7778 # Only used in polylith deployments # Configuration for the Federation API. federation_api: internal_api: - listen: http://localhost:7772 # Only used in polylith deployments + listen: http://localhost:7772 # Only used in polylith deployments connect: http://localhost:7772 # Only used in polylith deployments external_api: listen: http://[::]:8072 @@ -224,12 +224,12 @@ federation_api: # be required to satisfy key requests for servers that are no longer online when # joining some rooms. key_perspectives: - - server_name: matrix.org - keys: - - key_id: ed25519:auto - public_key: Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw - - key_id: ed25519:a_RXGa - public_key: l8Hft5qXKn1vfHrg3p4+W8gELQVo8N13JkluMfmn2sQ + - server_name: matrix.org + keys: + - key_id: ed25519:auto + public_key: Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw + - key_id: ed25519:a_RXGa + public_key: l8Hft5qXKn1vfHrg3p4+W8gELQVo8N13JkluMfmn2sQ # This option will control whether Dendrite will prefer to look up keys directly # or whether it should try perspective servers first, using direct fetches as a @@ -239,7 +239,7 @@ federation_api: # Configuration for the Key Server (for end-to-end encryption). key_server: internal_api: - listen: http://localhost:7779 # Only used in polylith deployments + listen: http://localhost:7779 # Only used in polylith deployments connect: http://localhost:7779 # Only used in polylith deployments database: connection_string: file:keyserver.db @@ -250,7 +250,7 @@ key_server: # Configuration for the Media API. media_api: internal_api: - listen: http://localhost:7774 # Only used in polylith deployments + listen: http://localhost:7774 # Only used in polylith deployments connect: http://localhost:7774 # Only used in polylith deployments external_api: listen: http://[::]:8074 @@ -276,15 +276,15 @@ media_api: # A list of thumbnail sizes to be generated for media content. thumbnail_sizes: - - width: 32 - height: 32 - method: crop - - width: 96 - height: 96 - method: crop - - width: 640 - height: 480 - method: scale + - width: 32 + height: 32 + method: crop + - width: 96 + height: 96 + method: crop + - width: 640 + height: 480 + method: scale # Configuration for experimental MSC's mscs: @@ -302,7 +302,7 @@ mscs: # Configuration for the Room Server. room_server: internal_api: - listen: http://localhost:7770 # Only used in polylith deployments + listen: http://localhost:7770 # Only used in polylith deployments connect: http://localhost:7770 # Only used in polylith deployments database: connection_string: file:roomserver.db @@ -313,7 +313,7 @@ room_server: # Configuration for the Sync API. sync_api: internal_api: - listen: http://localhost:7773 # Only used in polylith deployments + listen: http://localhost:7773 # Only used in polylith deployments connect: http://localhost:7773 # Only used in polylith deployments external_api: listen: http://[::]:8073 @@ -338,16 +338,16 @@ user_api: # This value can be low if performing tests or on embedded Dendrite instances (e.g WASM builds) # bcrypt_cost: 10 internal_api: - listen: http://localhost:7781 # Only used in polylith deployments + listen: http://localhost:7781 # Only used in polylith deployments connect: http://localhost:7781 # Only used in polylith deployments account_database: connection_string: file:userapi_accounts.db max_open_conns: 10 max_idle_conns: 2 conn_max_lifetime: -1 - # The length of time that a token issued for a relying party from + # The length of time that a token issued for a relying party from # /_matrix/client/r0/user/{userId}/openid/request_token endpoint - # is considered to be valid in milliseconds. + # is considered to be valid in milliseconds. # The default lifetime is 3600000ms (60 minutes). # openid_token_lifetime_ms: 3600000 @@ -369,10 +369,10 @@ tracing: # Logging configuration logging: -- type: std - level: info -- type: file - # The logging level, must be one of debug, info, warn, error, fatal, panic. - level: info - params: - path: ./logs + - type: std + level: info + - type: file + # The logging level, must be one of debug, info, warn, error, fatal, panic. + level: info + params: + path: ./logs diff --git a/go.mod b/go.mod index dbcae5d54..525950daa 100644 --- a/go.mod +++ b/go.mod @@ -18,6 +18,7 @@ require ( github.com/frankban/quicktest v1.14.0 // indirect github.com/getsentry/sentry-go v0.12.0 github.com/gologme/log v1.3.0 + github.com/google/go-cmp v0.5.6 github.com/google/uuid v1.2.0 github.com/gorilla/mux v1.8.0 github.com/gorilla/websocket v1.4.2 diff --git a/internal/eventutil/types.go b/internal/eventutil/types.go index 6d119ce6d..17861d6c5 100644 --- a/internal/eventutil/types.go +++ b/internal/eventutil/types.go @@ -26,8 +26,30 @@ var ErrProfileNoExists = errors.New("no known profile for given user ID") // AccountData represents account data sent from the client API server to the // sync API server type AccountData struct { + RoomID string `json:"room_id"` + Type string `json:"type"` + ReadMarker *ReadMarkerJSON `json:"read_marker,omitempty"` // optional +} + +type ReadMarkerJSON struct { + FullyRead string `json:"m.fully_read"` + Read string `json:"m.read"` +} + +// NotificationData contains statistics about notifications, sent from +// the Push Server to the Sync API server. +type NotificationData struct { + // RoomID identifies the scope of the statistics, together with + // MXID (which is encoded in the Kafka key). RoomID string `json:"room_id"` - Type string `json:"type"` + + // HighlightCount is the number of unread notifications with the + // highlight tweak. + UnreadHighlightCount int `json:"unread_highlight_count"` + + // UnreadNotificationCount is the total number of unread + // notifications. + UnreadNotificationCount int `json:"unread_notification_count"` } // ProfileResponse is a struct containing all known user profile data diff --git a/internal/pushgateway/client.go b/internal/pushgateway/client.go new file mode 100644 index 000000000..49907cee8 --- /dev/null +++ b/internal/pushgateway/client.go @@ -0,0 +1,66 @@ +package pushgateway + +import ( + "bytes" + "context" + "crypto/tls" + "encoding/json" + "fmt" + "net/http" + "time" + + "github.com/opentracing/opentracing-go" +) + +type httpClient struct { + hc *http.Client +} + +// NewHTTPClient creates a new Push Gateway client. +func NewHTTPClient(disableTLSValidation bool) Client { + hc := &http.Client{ + Timeout: 30 * time.Second, + Transport: &http.Transport{ + DisableKeepAlives: true, + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: disableTLSValidation, + }, + }, + } + return &httpClient{hc: hc} +} + +func (h *httpClient) Notify(ctx context.Context, url string, req *NotifyRequest, resp *NotifyResponse) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "Notify") + defer span.Finish() + + body, err := json.Marshal(req) + if err != nil { + return err + } + hreq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return err + } + hreq.Header.Set("Content-Type", "application/json") + + hresp, err := h.hc.Do(hreq) + if err != nil { + return err + } + + //nolint:errcheck + defer hresp.Body.Close() + + if hresp.StatusCode == http.StatusOK { + return json.NewDecoder(hresp.Body).Decode(resp) + } + + var errorBody struct { + Message string `json:"message"` + } + if err := json.NewDecoder(hresp.Body).Decode(&errorBody); err == nil { + return fmt.Errorf("push gateway: %d from %s: %s", hresp.StatusCode, url, errorBody.Message) + } + return fmt.Errorf("push gateway: %d from %s", hresp.StatusCode, url) +} diff --git a/internal/pushgateway/pushgateway.go b/internal/pushgateway/pushgateway.go new file mode 100644 index 000000000..88c326eb2 --- /dev/null +++ b/internal/pushgateway/pushgateway.go @@ -0,0 +1,62 @@ +package pushgateway + +import ( + "context" + "encoding/json" + + "github.com/matrix-org/gomatrixserverlib" +) + +// A Client is how interactions with a Push Gateway is done. +type Client interface { + // Notify sends a notification to the gateway at the given URL. + Notify(ctx context.Context, url string, req *NotifyRequest, resp *NotifyResponse) error +} + +type NotifyRequest struct { + Notification Notification `json:"notification"` // Required +} + +type NotifyResponse struct { + // Rejected is the list of device push keys that were rejected + // during the push. The caller should remove the push keys so they + // are not used again. + Rejected []string `json:"rejected"` // Required +} + +type Notification struct { + Content json.RawMessage `json:"content,omitempty"` + Counts *Counts `json:"counts,omitempty"` + Devices []*Device `json:"devices"` // Required + EventID string `json:"event_id,omitempty"` + ID string `json:"id,omitempty"` // Deprecated name for EventID. + Membership string `json:"membership,omitempty"` // UNSPEC: required for Sytest. + Prio Prio `json:"prio,omitempty"` + RoomAlias string `json:"room_alias,omitempty"` + RoomID string `json:"room_id,omitempty"` + RoomName string `json:"room_name,omitempty"` + Sender string `json:"sender,omitempty"` + SenderDisplayName string `json:"sender_display_name,omitempty"` + Type string `json:"type,omitempty"` + UserIsTarget bool `json:"user_is_target,omitempty"` +} + +type Counts struct { + MissedCalls int `json:"missed_calls,omitempty"` + Unread int `json:"unread"` // TODO: UNSPEC: the spec says zero must be omitted, but Sytest 61push/01message-pushed.pl requires it. +} + +type Device struct { + AppID string `json:"app_id"` // Required + Data map[string]interface{} `json:"data"` // Required. UNSPEC: Sytests require this to allow unknown keys. + PushKey string `json:"pushkey"` // Required + PushKeyTS gomatrixserverlib.Timestamp `json:"pushkey_ts,omitempty"` + Tweaks map[string]interface{} `json:"tweaks,omitempty"` +} + +type Prio string + +const ( + HighPrio Prio = "high" + LowPrio Prio = "low" +) diff --git a/internal/pushrules/action.go b/internal/pushrules/action.go new file mode 100644 index 000000000..c7b8cec83 --- /dev/null +++ b/internal/pushrules/action.go @@ -0,0 +1,102 @@ +package pushrules + +import ( + "bytes" + "encoding/json" + "fmt" +) + +// An Action is (part of) an outcome of a rule. There are +// (unofficially) terminal actions, and modifier actions. +type Action struct { + // Kind is the type of action. Has custom encoding in JSON. + Kind ActionKind `json:"-"` + + // Tweak is the property to tweak. Has custom encoding in JSON. + Tweak TweakKey `json:"-"` + + // Value is some value interpreted according to Kind and Tweak. + Value interface{} `json:"value"` +} + +func (a *Action) MarshalJSON() ([]byte, error) { + if a.Tweak == UnknownTweak && a.Value == nil { + return json.Marshal(a.Kind) + } + + if a.Kind != SetTweakAction { + return nil, fmt.Errorf("only set_tweak actions may have a value, but got kind %q", a.Kind) + } + + m := map[string]interface{}{ + string(a.Kind): a.Tweak, + } + if a.Value != nil { + m["value"] = a.Value + } + + return json.Marshal(m) +} + +func (a *Action) UnmarshalJSON(bs []byte) error { + if bytes.HasPrefix(bs, []byte("\"")) { + return json.Unmarshal(bs, &a.Kind) + } + + var raw struct { + SetTweak TweakKey `json:"set_tweak"` + Value interface{} `json:"value"` + } + if err := json.Unmarshal(bs, &raw); err != nil { + return err + } + if raw.SetTweak == UnknownTweak { + return fmt.Errorf("got unknown action JSON: %s", string(bs)) + } + a.Kind = SetTweakAction + a.Tweak = raw.SetTweak + if raw.Value != nil { + a.Value = raw.Value + } + + return nil +} + +// ActionKind is the primary discriminator for actions. +type ActionKind string + +const ( + UnknownAction ActionKind = "" + + // NotifyAction indicates the clients should show a notification. + NotifyAction ActionKind = "notify" + + // DontNotifyAction indicates the clients should not show a notification. + DontNotifyAction ActionKind = "dont_notify" + + // CoalesceAction tells the clients to show a notification, and + // tells both servers and clients that multiple events can be + // coalesced into a single notification. The behaviour is + // implementation-specific. + CoalesceAction ActionKind = "coalesce" + + // SetTweakAction uses the Tweak and Value fields to add a + // tweak. Multiple SetTweakAction can be provided in a rule, + // combined with NotifyAction or CoalesceAction. + SetTweakAction ActionKind = "set_tweak" +) + +// A TweakKey describes a property to be modified/tweaked for events +// that match the rule. +type TweakKey string + +const ( + UnknownTweak TweakKey = "" + + // SoundTweak describes which sound to play. Using "default" means + // "enable sound". + SoundTweak TweakKey = "sound" + + // HighlightTweak asks the clients to highlight the conversation. + HighlightTweak TweakKey = "highlight" +) diff --git a/internal/pushrules/action_test.go b/internal/pushrules/action_test.go new file mode 100644 index 000000000..72db9c998 --- /dev/null +++ b/internal/pushrules/action_test.go @@ -0,0 +1,39 @@ +package pushrules + +import ( + "encoding/json" + "fmt" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestActionJSON(t *testing.T) { + tsts := []struct { + Want Action + }{ + {Action{Kind: NotifyAction}}, + {Action{Kind: DontNotifyAction}}, + {Action{Kind: CoalesceAction}}, + {Action{Kind: SetTweakAction}}, + + {Action{Kind: SetTweakAction, Tweak: SoundTweak, Value: "default"}}, + {Action{Kind: SetTweakAction, Tweak: HighlightTweak}}, + {Action{Kind: SetTweakAction, Tweak: HighlightTweak, Value: "false"}}, + } + for _, tst := range tsts { + t.Run(fmt.Sprintf("%+v", tst.Want), func(t *testing.T) { + bs, err := json.Marshal(&tst.Want) + if err != nil { + t.Fatalf("Marshal failed: %v", err) + } + var got Action + if err := json.Unmarshal(bs, &got); err != nil { + t.Fatalf("Unmarshal failed: %v", err) + } + if diff := cmp.Diff(tst.Want, got); diff != "" { + t.Errorf("+got -want:\n%s", diff) + } + }) + } +} diff --git a/internal/pushrules/condition.go b/internal/pushrules/condition.go new file mode 100644 index 000000000..2d9773c0f --- /dev/null +++ b/internal/pushrules/condition.go @@ -0,0 +1,49 @@ +package pushrules + +// A Condition dictates extra conditions for a matching rules. See +// ConditionKind. +type Condition struct { + // Kind is the primary discriminator for the condition + // type. Required. + Kind ConditionKind `json:"kind"` + + // Key indicates the dot-separated path of Event fields to + // match. Required for EventMatchCondition and + // SenderNotificationPermissionCondition. + Key string `json:"key,omitempty"` + + // Pattern indicates the value pattern that must match. Required + // for EventMatchCondition. + Pattern string `json:"pattern,omitempty"` + + // Is indicates the condition that must be fulfilled. Required for + // RoomMemberCountCondition. + Is string `json:"is,omitempty"` +} + +// ConditionKind represents a kind of condition. +// +// SPEC: Unrecognised conditions MUST NOT match any events, +// effectively making the push rule disabled. +type ConditionKind string + +const ( + UnknownCondition ConditionKind = "" + + // EventMatchCondition indicates the condition looks for a key + // path and matches a pattern. How paths that don't reference a + // simple value match against rules is implementation-specific. + EventMatchCondition ConditionKind = "event_match" + + // ContainsDisplayNameCondition indicates the current user's + // display name must be found in the content body. + ContainsDisplayNameCondition ConditionKind = "contains_display_name" + + // RoomMemberCountCondition matches a simple arithmetic comparison + // against the total number of members in a room. + RoomMemberCountCondition ConditionKind = "room_member_count" + + // SenderNotificationPermissionCondition compares power level for + // the sender in the event's room. + SenderNotificationPermissionCondition ConditionKind = "sender_notification_permission" +) diff --git a/internal/pushrules/default.go b/internal/pushrules/default.go new file mode 100644 index 000000000..996985514 --- /dev/null +++ b/internal/pushrules/default.go @@ -0,0 +1,23 @@ +package pushrules + +import ( + "github.com/matrix-org/gomatrixserverlib" +) + +// DefaultAccountRuleSets is the complete set of default push rules +// for an account. +func DefaultAccountRuleSets(localpart string, serverName gomatrixserverlib.ServerName) *AccountRuleSets { + return &AccountRuleSets{ + Global: *DefaultGlobalRuleSet(localpart, serverName), + } +} + +// DefaultGlobalRuleSet returns the default ruleset for a given (fully +// qualified) MXID. +func DefaultGlobalRuleSet(localpart string, serverName gomatrixserverlib.ServerName) *RuleSet { + return &RuleSet{ + Override: defaultOverrideRules("@" + localpart + ":" + string(serverName)), + Content: defaultContentRules(localpart), + Underride: defaultUnderrideRules, + } +} diff --git a/internal/pushrules/default_content.go b/internal/pushrules/default_content.go new file mode 100644 index 000000000..158afd18b --- /dev/null +++ b/internal/pushrules/default_content.go @@ -0,0 +1,33 @@ +package pushrules + +func defaultContentRules(localpart string) []*Rule { + return []*Rule{ + mRuleContainsUserNameDefinition(localpart), + } +} + +const ( + MRuleContainsUserName = ".m.rule.contains_user_name" +) + +func mRuleContainsUserNameDefinition(localpart string) *Rule { + return &Rule{ + RuleID: MRuleContainsUserName, + Default: true, + Enabled: true, + Pattern: localpart, + Actions: []*Action{ + {Kind: NotifyAction}, + { + Kind: SetTweakAction, + Tweak: SoundTweak, + Value: "default", + }, + { + Kind: SetTweakAction, + Tweak: HighlightTweak, + Value: true, + }, + }, + } +} diff --git a/internal/pushrules/default_override.go b/internal/pushrules/default_override.go new file mode 100644 index 000000000..6f66fd66a --- /dev/null +++ b/internal/pushrules/default_override.go @@ -0,0 +1,165 @@ +package pushrules + +func defaultOverrideRules(userID string) []*Rule { + return []*Rule{ + &mRuleMasterDefinition, + &mRuleSuppressNoticesDefinition, + mRuleInviteForMeDefinition(userID), + &mRuleMemberEventDefinition, + &mRuleContainsDisplayNameDefinition, + &mRuleTombstoneDefinition, + &mRuleRoomNotifDefinition, + } +} + +const ( + MRuleMaster = ".m.rule.master" + MRuleSuppressNotices = ".m.rule.suppress_notices" + MRuleInviteForMe = ".m.rule.invite_for_me" + MRuleMemberEvent = ".m.rule.member_event" + MRuleContainsDisplayName = ".m.rule.contains_display_name" + MRuleTombstone = ".m.rule.tombstone" + MRuleRoomNotif = ".m.rule.roomnotif" +) + +var ( + mRuleMasterDefinition = Rule{ + RuleID: MRuleMaster, + Default: true, + Enabled: false, + Conditions: []*Condition{}, + Actions: []*Action{{Kind: DontNotifyAction}}, + } + mRuleSuppressNoticesDefinition = Rule{ + RuleID: MRuleSuppressNotices, + Default: true, + Enabled: true, + Conditions: []*Condition{ + { + Kind: EventMatchCondition, + Key: "content.msgtype", + Pattern: "m.notice", + }, + }, + Actions: []*Action{{Kind: DontNotifyAction}}, + } + mRuleMemberEventDefinition = Rule{ + RuleID: MRuleMemberEvent, + Default: true, + Enabled: true, + Conditions: []*Condition{ + { + Kind: EventMatchCondition, + Key: "type", + Pattern: "m.room.member", + }, + }, + Actions: []*Action{{Kind: DontNotifyAction}}, + } + mRuleContainsDisplayNameDefinition = Rule{ + RuleID: MRuleContainsDisplayName, + Default: true, + Enabled: true, + Conditions: []*Condition{{Kind: ContainsDisplayNameCondition}}, + Actions: []*Action{ + {Kind: NotifyAction}, + { + Kind: SetTweakAction, + Tweak: SoundTweak, + Value: "default", + }, + { + Kind: SetTweakAction, + Tweak: HighlightTweak, + Value: true, + }, + }, + } + mRuleTombstoneDefinition = Rule{ + RuleID: MRuleTombstone, + Default: true, + Enabled: true, + Conditions: []*Condition{ + { + Kind: EventMatchCondition, + Key: "type", + Pattern: "m.room.tombstone", + }, + { + Kind: EventMatchCondition, + Key: "state_key", + Pattern: "", + }, + }, + Actions: []*Action{ + {Kind: NotifyAction}, + { + Kind: SetTweakAction, + Tweak: HighlightTweak, + Value: false, + }, + }, + } + mRuleRoomNotifDefinition = Rule{ + RuleID: MRuleRoomNotif, + Default: true, + Enabled: true, + Conditions: []*Condition{ + { + Kind: EventMatchCondition, + Key: "content.body", + Pattern: "@room", + }, + { + Kind: SenderNotificationPermissionCondition, + Key: "room", + }, + }, + Actions: []*Action{ + {Kind: NotifyAction}, + { + Kind: SetTweakAction, + Tweak: HighlightTweak, + Value: false, + }, + }, + } +) + +func mRuleInviteForMeDefinition(userID string) *Rule { + return &Rule{ + RuleID: MRuleInviteForMe, + Default: true, + Enabled: true, + Conditions: []*Condition{ + { + Kind: EventMatchCondition, + Key: "type", + Pattern: "m.room.member", + }, + { + Kind: EventMatchCondition, + Key: "content.membership", + Pattern: "invite", + }, + { + Kind: EventMatchCondition, + Key: "state_key", + Pattern: userID, + }, + }, + Actions: []*Action{ + {Kind: NotifyAction}, + { + Kind: SetTweakAction, + Tweak: SoundTweak, + Value: "default", + }, + { + Kind: SetTweakAction, + Tweak: HighlightTweak, + Value: false, + }, + }, + } +} diff --git a/internal/pushrules/default_underride.go b/internal/pushrules/default_underride.go new file mode 100644 index 000000000..de72bd526 --- /dev/null +++ b/internal/pushrules/default_underride.go @@ -0,0 +1,119 @@ +package pushrules + +const ( + MRuleCall = ".m.rule.call" + MRuleEncryptedRoomOneToOne = ".m.rule.encrypted_room_one_to_one" + MRuleRoomOneToOne = ".m.rule.room_one_to_one" + MRuleMessage = ".m.rule.message" + MRuleEncrypted = ".m.rule.encrypted" +) + +var defaultUnderrideRules = []*Rule{ + &mRuleCallDefinition, + &mRuleEncryptedRoomOneToOneDefinition, + &mRuleRoomOneToOneDefinition, + &mRuleMessageDefinition, + &mRuleEncryptedDefinition, +} + +var ( + mRuleCallDefinition = Rule{ + RuleID: MRuleCall, + Default: true, + Enabled: true, + Conditions: []*Condition{ + { + Kind: EventMatchCondition, + Key: "type", + Pattern: "m.call.invite", + }, + }, + Actions: []*Action{ + {Kind: NotifyAction}, + { + Kind: SetTweakAction, + Tweak: SoundTweak, + Value: "ring", + }, + { + Kind: SetTweakAction, + Tweak: HighlightTweak, + Value: false, + }, + }, + } + mRuleEncryptedRoomOneToOneDefinition = Rule{ + RuleID: MRuleEncryptedRoomOneToOne, + Default: true, + Enabled: true, + Conditions: []*Condition{ + { + Kind: RoomMemberCountCondition, + Is: "2", + }, + { + Kind: EventMatchCondition, + Key: "type", + Pattern: "m.room.encrypted", + }, + }, + Actions: []*Action{ + {Kind: NotifyAction}, + { + Kind: SetTweakAction, + Tweak: HighlightTweak, + Value: false, + }, + }, + } + mRuleRoomOneToOneDefinition = Rule{ + RuleID: MRuleRoomOneToOne, + Default: true, + Enabled: true, + Conditions: []*Condition{ + { + Kind: RoomMemberCountCondition, + Is: "2", + }, + { + Kind: EventMatchCondition, + Key: "type", + Pattern: "m.room.message", + }, + }, + Actions: []*Action{ + {Kind: NotifyAction}, + { + Kind: SetTweakAction, + Tweak: HighlightTweak, + Value: false, + }, + }, + } + mRuleMessageDefinition = Rule{ + RuleID: MRuleMessage, + Default: true, + Enabled: true, + Conditions: []*Condition{ + { + Kind: EventMatchCondition, + Key: "type", + Pattern: "m.room.message", + }, + }, + Actions: []*Action{{Kind: NotifyAction}}, + } + mRuleEncryptedDefinition = Rule{ + RuleID: MRuleEncrypted, + Default: true, + Enabled: true, + Conditions: []*Condition{ + { + Kind: EventMatchCondition, + Key: "type", + Pattern: "m.room.encrypted", + }, + }, + Actions: []*Action{{Kind: NotifyAction}}, + } +) diff --git a/internal/pushrules/evaluate.go b/internal/pushrules/evaluate.go new file mode 100644 index 000000000..df22cb042 --- /dev/null +++ b/internal/pushrules/evaluate.go @@ -0,0 +1,165 @@ +package pushrules + +import ( + "encoding/json" + "fmt" + "strings" + + "github.com/matrix-org/gomatrixserverlib" +) + +// A RuleSetEvaluator encapsulates context to evaluate an event +// against a rule set. +type RuleSetEvaluator struct { + ec EvaluationContext + ruleSet []kindAndRules +} + +// An EvaluationContext gives a RuleSetEvaluator access to the +// environment, for rules that require that. +type EvaluationContext interface { + // UserDisplayName returns the current user's display name. + UserDisplayName() string + + // RoomMemberCount returns the number of members in the room of + // the current event. + RoomMemberCount() (int, error) + + // HasPowerLevel returns whether the user has at least the given + // power in the room of the current event. + HasPowerLevel(userID, levelKey string) (bool, error) +} + +// A kindAndRules is just here to simplify iteration of the (ordered) +// kinds of rules. +type kindAndRules struct { + Kind Kind + Rules []*Rule +} + +// NewRuleSetEvaluator creates a new evaluator for the given rule set. +func NewRuleSetEvaluator(ec EvaluationContext, ruleSet *RuleSet) *RuleSetEvaluator { + return &RuleSetEvaluator{ + ec: ec, + ruleSet: []kindAndRules{ + {OverrideKind, ruleSet.Override}, + {ContentKind, ruleSet.Content}, + {RoomKind, ruleSet.Room}, + {SenderKind, ruleSet.Sender}, + {UnderrideKind, ruleSet.Underride}, + }, + } +} + +// MatchEvent returns the first matching rule. Returns nil if there +// was no match rule. +func (rse *RuleSetEvaluator) MatchEvent(event *gomatrixserverlib.Event) (*Rule, error) { + // TODO: server-default rules have lower priority than user rules, + // but they are stored together with the user rules. It's a bit + // unclear what the specification (11.14.1.4 Predefined rules) + // means the ordering should be. + // + // The most reasonable interpretation is that default overrides + // still have lower priority than user content rules, so we + // iterate twice. + for _, rsat := range rse.ruleSet { + for _, defRules := range []bool{false, true} { + for _, rule := range rsat.Rules { + if rule.Default != defRules { + continue + } + ok, err := ruleMatches(rule, rsat.Kind, event, rse.ec) + if err != nil { + return nil, err + } + if ok { + return rule, nil + } + } + } + } + + // No matching rule. + return nil, nil +} + +func ruleMatches(rule *Rule, kind Kind, event *gomatrixserverlib.Event, ec EvaluationContext) (bool, error) { + if !rule.Enabled { + return false, nil + } + + switch kind { + case OverrideKind, UnderrideKind: + for _, cond := range rule.Conditions { + ok, err := conditionMatches(cond, event, ec) + if err != nil { + return false, err + } + if !ok { + return false, nil + } + } + return true, nil + + case ContentKind: + // TODO: "These configure behaviour for (unencrypted) messages + // that match certain patterns." - Does that mean "content.body"? + return patternMatches("content.body", rule.Pattern, event) + + case RoomKind: + return rule.RuleID == event.RoomID(), nil + + case SenderKind: + return rule.RuleID == event.Sender(), nil + + default: + return false, nil + } +} + +func conditionMatches(cond *Condition, event *gomatrixserverlib.Event, ec EvaluationContext) (bool, error) { + switch cond.Kind { + case EventMatchCondition: + return patternMatches(cond.Key, cond.Pattern, event) + + case ContainsDisplayNameCondition: + return patternMatches("content.body", ec.UserDisplayName(), event) + + case RoomMemberCountCondition: + cmp, err := parseRoomMemberCountCondition(cond.Is) + if err != nil { + return false, fmt.Errorf("parsing room_member_count condition: %w", err) + } + n, err := ec.RoomMemberCount() + if err != nil { + return false, fmt.Errorf("RoomMemberCount failed: %w", err) + } + return cmp(n), nil + + case SenderNotificationPermissionCondition: + return ec.HasPowerLevel(event.Sender(), cond.Key) + + default: + return false, nil + } +} + +func patternMatches(key, pattern string, event *gomatrixserverlib.Event) (bool, error) { + re, err := globToRegexp(pattern) + if err != nil { + return false, err + } + + var eventMap map[string]interface{} + if err = json.Unmarshal(event.JSON(), &eventMap); err != nil { + return false, fmt.Errorf("parsing event: %w", err) + } + v, err := lookupMapPath(strings.Split(key, "."), eventMap) + if err != nil { + // An unknown path is a benign error that shouldn't stop rule + // processing. It's just a non-match. + return false, nil + } + + return re.MatchString(fmt.Sprint(v)), nil +} diff --git a/internal/pushrules/evaluate_test.go b/internal/pushrules/evaluate_test.go new file mode 100644 index 000000000..50e703365 --- /dev/null +++ b/internal/pushrules/evaluate_test.go @@ -0,0 +1,189 @@ +package pushrules + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/matrix-org/gomatrixserverlib" +) + +func TestRuleSetEvaluatorMatchEvent(t *testing.T) { + ev := mustEventFromJSON(t, `{}`) + defaultEnabled := &Rule{ + RuleID: ".default.enabled", + Default: true, + Enabled: true, + } + userEnabled := &Rule{ + RuleID: ".user.enabled", + Default: false, + Enabled: true, + } + userEnabled2 := &Rule{ + RuleID: ".user.enabled.2", + Default: false, + Enabled: true, + } + tsts := []struct { + Name string + RuleSet RuleSet + Want *Rule + }{ + {"empty", RuleSet{}, nil}, + {"defaultCanWin", RuleSet{Override: []*Rule{defaultEnabled}}, defaultEnabled}, + {"userWins", RuleSet{Override: []*Rule{defaultEnabled, userEnabled}}, userEnabled}, + {"defaultOverrideWins", RuleSet{Override: []*Rule{defaultEnabled}, Underride: []*Rule{userEnabled}}, defaultEnabled}, + {"overrideContent", RuleSet{Override: []*Rule{userEnabled}, Content: []*Rule{userEnabled2}}, userEnabled}, + {"overrideRoom", RuleSet{Override: []*Rule{userEnabled}, Room: []*Rule{userEnabled2}}, userEnabled}, + {"overrideSender", RuleSet{Override: []*Rule{userEnabled}, Sender: []*Rule{userEnabled2}}, userEnabled}, + {"overrideUnderride", RuleSet{Override: []*Rule{userEnabled}, Underride: []*Rule{userEnabled2}}, userEnabled}, + } + for _, tst := range tsts { + t.Run(tst.Name, func(t *testing.T) { + rse := NewRuleSetEvaluator(nil, &tst.RuleSet) + got, err := rse.MatchEvent(ev) + if err != nil { + t.Fatalf("MatchEvent failed: %v", err) + } + if diff := cmp.Diff(tst.Want, got); diff != "" { + t.Errorf("MatchEvent rule: +got -want:\n%s", diff) + } + }) + } +} + +func TestRuleMatches(t *testing.T) { + emptyRule := Rule{Enabled: true} + tsts := []struct { + Name string + Kind Kind + Rule Rule + EventJSON string + Want bool + }{ + {"emptyOverride", OverrideKind, emptyRule, `{}`, true}, + {"emptyContent", ContentKind, emptyRule, `{}`, false}, + {"emptyRoom", RoomKind, emptyRule, `{}`, true}, + {"emptySender", SenderKind, emptyRule, `{}`, true}, + {"emptyUnderride", UnderrideKind, emptyRule, `{}`, true}, + + {"disabled", OverrideKind, Rule{}, `{}`, false}, + + {"overrideConditionMatch", OverrideKind, Rule{Enabled: true}, `{}`, true}, + {"overrideConditionNoMatch", OverrideKind, Rule{Enabled: true, Conditions: []*Condition{{}}}, `{}`, false}, + + {"underrideConditionMatch", UnderrideKind, Rule{Enabled: true}, `{}`, true}, + {"underrideConditionNoMatch", UnderrideKind, Rule{Enabled: true, Conditions: []*Condition{{}}}, `{}`, false}, + + {"contentMatch", ContentKind, Rule{Enabled: true, Pattern: "b"}, `{"content":{"body":"abc"}}`, true}, + {"contentNoMatch", ContentKind, Rule{Enabled: true, Pattern: "d"}, `{"content":{"body":"abc"}}`, false}, + + {"roomMatch", RoomKind, Rule{Enabled: true, RuleID: "!room@example.com"}, `{"room_id":"!room@example.com"}`, true}, + {"roomNoMatch", RoomKind, Rule{Enabled: true, RuleID: "!room@example.com"}, `{"room_id":"!otherroom@example.com"}`, false}, + + {"senderMatch", SenderKind, Rule{Enabled: true, RuleID: "@user@example.com"}, `{"sender":"@user@example.com"}`, true}, + {"senderNoMatch", SenderKind, Rule{Enabled: true, RuleID: "@user@example.com"}, `{"sender":"@otheruser@example.com"}`, false}, + } + for _, tst := range tsts { + t.Run(tst.Name, func(t *testing.T) { + got, err := ruleMatches(&tst.Rule, tst.Kind, mustEventFromJSON(t, tst.EventJSON), nil) + if err != nil { + t.Fatalf("ruleMatches failed: %v", err) + } + if got != tst.Want { + t.Errorf("ruleMatches: got %v, want %v", got, tst.Want) + } + }) + } +} + +func TestConditionMatches(t *testing.T) { + tsts := []struct { + Name string + Cond Condition + EventJSON string + Want bool + }{ + {"empty", Condition{}, `{}`, false}, + {"empty", Condition{Kind: "unknownstring"}, `{}`, false}, + + {"eventMatch", Condition{Kind: EventMatchCondition, Key: "content"}, `{"content":{}}`, true}, + + {"displayNameNoMatch", Condition{Kind: ContainsDisplayNameCondition}, `{"content":{"body":"something without displayname"}}`, false}, + {"displayNameMatch", Condition{Kind: ContainsDisplayNameCondition}, `{"content":{"body":"hello Dear User, how are you?"}}`, true}, + + {"roomMemberCountLessNoMatch", Condition{Kind: RoomMemberCountCondition, Is: "<2"}, `{}`, false}, + {"roomMemberCountLessMatch", Condition{Kind: RoomMemberCountCondition, Is: "<3"}, `{}`, true}, + {"roomMemberCountLessEqualNoMatch", Condition{Kind: RoomMemberCountCondition, Is: "<=1"}, `{}`, false}, + {"roomMemberCountLessEqualMatch", Condition{Kind: RoomMemberCountCondition, Is: "<=2"}, `{}`, true}, + {"roomMemberCountEqualNoMatch", Condition{Kind: RoomMemberCountCondition, Is: "==1"}, `{}`, false}, + {"roomMemberCountEqualMatch", Condition{Kind: RoomMemberCountCondition, Is: "==2"}, `{}`, true}, + {"roomMemberCountGreaterEqualNoMatch", Condition{Kind: RoomMemberCountCondition, Is: ">=3"}, `{}`, false}, + {"roomMemberCountGreaterEqualMatch", Condition{Kind: RoomMemberCountCondition, Is: ">=2"}, `{}`, true}, + {"roomMemberCountGreaterNoMatch", Condition{Kind: RoomMemberCountCondition, Is: ">2"}, `{}`, false}, + {"roomMemberCountGreaterMatch", Condition{Kind: RoomMemberCountCondition, Is: ">1"}, `{}`, true}, + + {"senderNotificationPermissionMatch", Condition{Kind: SenderNotificationPermissionCondition, Key: "powerlevel"}, `{"sender":"@poweruser:example.com"}`, true}, + {"senderNotificationPermissionNoMatch", Condition{Kind: SenderNotificationPermissionCondition, Key: "powerlevel"}, `{"sender":"@nobody:example.com"}`, false}, + } + for _, tst := range tsts { + t.Run(tst.Name, func(t *testing.T) { + got, err := conditionMatches(&tst.Cond, mustEventFromJSON(t, tst.EventJSON), &fakeEvaluationContext{}) + if err != nil { + t.Fatalf("conditionMatches failed: %v", err) + } + if got != tst.Want { + t.Errorf("conditionMatches: got %v, want %v", got, tst.Want) + } + }) + } +} + +type fakeEvaluationContext struct{} + +func (fakeEvaluationContext) UserDisplayName() string { return "Dear User" } +func (fakeEvaluationContext) RoomMemberCount() (int, error) { return 2, nil } +func (fakeEvaluationContext) HasPowerLevel(userID, levelKey string) (bool, error) { + return userID == "@poweruser:example.com" && levelKey == "powerlevel", nil +} + +func TestPatternMatches(t *testing.T) { + tsts := []struct { + Name string + Key string + Pattern string + EventJSON string + Want bool + }{ + {"empty", "", "", `{}`, false}, + + // Note that an empty pattern contains no wildcard characters, + // which implicitly means "*". + {"patternEmpty", "content", "", `{"content":{}}`, true}, + + {"literal", "content.creator", "acreator", `{"content":{"creator":"acreator"}}`, true}, + {"substring", "content.creator", "reat", `{"content":{"creator":"acreator"}}`, true}, + {"singlePattern", "content.creator", "acr?ator", `{"content":{"creator":"acreator"}}`, true}, + {"multiPattern", "content.creator", "a*ea*r", `{"content":{"creator":"acreator"}}`, true}, + {"patternNoSubstring", "content.creator", "r*t", `{"content":{"creator":"acreator"}}`, false}, + } + for _, tst := range tsts { + t.Run(tst.Name, func(t *testing.T) { + got, err := patternMatches(tst.Key, tst.Pattern, mustEventFromJSON(t, tst.EventJSON)) + if err != nil { + t.Fatalf("patternMatches failed: %v", err) + } + if got != tst.Want { + t.Errorf("patternMatches: got %v, want %v", got, tst.Want) + } + }) + } +} + +func mustEventFromJSON(t *testing.T, json string) *gomatrixserverlib.Event { + ev, err := gomatrixserverlib.NewEventFromTrustedJSON([]byte(json), false, gomatrixserverlib.RoomVersionV7) + if err != nil { + t.Fatal(err) + } + return ev +} diff --git a/internal/pushrules/pushrules.go b/internal/pushrules/pushrules.go new file mode 100644 index 000000000..bbed1f95f --- /dev/null +++ b/internal/pushrules/pushrules.go @@ -0,0 +1,71 @@ +package pushrules + +// An AccountRuleSets carries the rule sets associated with an +// account. +type AccountRuleSets struct { + Global RuleSet `json:"global"` // Required +} + +// A RuleSet contains all the various push rules for an +// account. Listed in decreasing order of priority. +type RuleSet struct { + Override []*Rule `json:"override,omitempty"` + Content []*Rule `json:"content,omitempty"` + Room []*Rule `json:"room,omitempty"` + Sender []*Rule `json:"sender,omitempty"` + Underride []*Rule `json:"underride,omitempty"` +} + +// A Rule contains matchers, conditions and final actions. While +// evaluating, at most one rule is considered matching. +// +// Kind and scope are part of the push rules request/responses, but +// not of the core data model. +type Rule struct { + // RuleID is either a free identifier, or the sender's MXID for + // SenderKind. Required. + RuleID string `json:"rule_id"` + + // Default indicates whether this is a server-defined default, or + // a user-provided rule. Required. + // + // The server-default rules have the lowest priority. + Default bool `json:"default"` + + // Enabled allows the user to disable rules while keeping them + // around. Required. + Enabled bool `json:"enabled"` + + // Actions describe the desired outcome, should the rule + // match. Required. + Actions []*Action `json:"actions"` + + // Conditions provide the rule's conditions for OverrideKind and + // UnderrideKind. Not allowed for other kinds. + Conditions []*Condition `json:"conditions"` + + // Pattern is the body pattern to match for ContentKind. Required + // for that kind. The interpretation is the same as that of + // Condition.Pattern. + Pattern string `json:"pattern"` +} + +// Scope only has one valid value. See also AccountRuleSets. +type Scope string + +const ( + UnknownScope Scope = "" + GlobalScope Scope = "global" +) + +// Kind is the type of push rule. See also RuleSet. +type Kind string + +const ( + UnknownKind Kind = "" + OverrideKind Kind = "override" + ContentKind Kind = "content" + RoomKind Kind = "room" + SenderKind Kind = "sender" + UnderrideKind Kind = "underride" +) diff --git a/internal/pushrules/util.go b/internal/pushrules/util.go new file mode 100644 index 000000000..027d35ef6 --- /dev/null +++ b/internal/pushrules/util.go @@ -0,0 +1,125 @@ +package pushrules + +import ( + "fmt" + "regexp" + "strconv" + "strings" +) + +// ActionsToTweaks converts a list of actions into a primary action +// kind and a tweaks map. Returns a nil map if it would have been +// empty. +func ActionsToTweaks(as []*Action) (ActionKind, map[string]interface{}, error) { + var kind ActionKind + tweaks := map[string]interface{}{} + + for _, a := range as { + if a.Kind == SetTweakAction { + tweaks[string(a.Tweak)] = a.Value + continue + } + if kind != UnknownAction { + return UnknownAction, nil, fmt.Errorf("got multiple primary actions: already had %q, got %s", kind, a.Kind) + } + kind = a.Kind + } + + if len(tweaks) == 0 { + tweaks = nil + } + + return kind, tweaks, nil +} + +// BoolTweakOr returns the named tweak as a boolean, and returns `def` +// on failure. +func BoolTweakOr(tweaks map[string]interface{}, key TweakKey, def bool) bool { + v, ok := tweaks[string(key)] + if !ok { + return def + } + b, ok := v.(bool) + if !ok { + return def + } + return b +} + +// globToRegexp converts a Matrix glob-style pattern to a Regular expression. +func globToRegexp(pattern string) (*regexp.Regexp, error) { + // TODO: It's unclear which glob characters are supported. The only + // place this is discussed is for the unrelated "m.policy.rule.*" + // events. Assuming, the same: /[*?]/ + if !strings.ContainsAny(pattern, "*?") { + pattern = "*" + pattern + "*" + } + + // The defined syntax doesn't allow escaping the glob wildcard + // characters, which makes this a straight-forward + // replace-after-quote. + pattern = globNonMetaRegexp.ReplaceAllStringFunc(pattern, regexp.QuoteMeta) + pattern = strings.Replace(pattern, "*", ".*", -1) + pattern = strings.Replace(pattern, "?", ".", -1) + return regexp.Compile("^(" + pattern + ")$") +} + +// globNonMetaRegexp are the characters that are not considered glob +// meta-characters (i.e. may need escaping). +var globNonMetaRegexp = regexp.MustCompile("[^*?]+") + +// lookupMapPath traverses a hierarchical map structure, like the one +// produced by json.Unmarshal, to return the leaf value. Traversing +// arrays/slices is not supported, only objects/maps. +func lookupMapPath(path []string, m map[string]interface{}) (interface{}, error) { + if len(path) == 0 { + return nil, fmt.Errorf("empty path") + } + + var v interface{} = m + for i, key := range path { + m, ok := v.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("expected an object for path %q, but got %T", strings.Join(path[:i+1], "."), v) + } + + v, ok = m[key] + if !ok { + return nil, fmt.Errorf("path not found: %s", strings.Join(path[:i+1], ".")) + } + } + + return v, nil +} + +// parseRoomMemberCountCondition parses a string like "2", "==2", "<2" +// into a function that checks if the argument to it fulfils the +// condition. +func parseRoomMemberCountCondition(s string) (func(int) bool, error) { + var b int + var cmp = func(a int) bool { return a == b } + switch { + case strings.HasPrefix(s, "<="): + cmp = func(a int) bool { return a <= b } + s = s[2:] + case strings.HasPrefix(s, ">="): + cmp = func(a int) bool { return a >= b } + s = s[2:] + case strings.HasPrefix(s, "<"): + cmp = func(a int) bool { return a < b } + s = s[1:] + case strings.HasPrefix(s, ">"): + cmp = func(a int) bool { return a > b } + s = s[1:] + case strings.HasPrefix(s, "=="): + // Same cmp as the default. + s = s[2:] + } + + v, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return nil, err + } + b = int(v) + return cmp, nil +} diff --git a/internal/pushrules/util_test.go b/internal/pushrules/util_test.go new file mode 100644 index 000000000..a951c55a2 --- /dev/null +++ b/internal/pushrules/util_test.go @@ -0,0 +1,169 @@ +package pushrules + +import ( + "fmt" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestActionsToTweaks(t *testing.T) { + tsts := []struct { + Name string + Input []*Action + WantKind ActionKind + WantTweaks map[string]interface{} + }{ + {"empty", nil, UnknownAction, nil}, + {"zero", []*Action{{}}, UnknownAction, nil}, + {"onlyPrimary", []*Action{{Kind: NotifyAction}}, NotifyAction, nil}, + {"onlyTweak", []*Action{{Kind: SetTweakAction, Tweak: HighlightTweak}}, UnknownAction, map[string]interface{}{"highlight": nil}}, + {"onlyTweakWithValue", []*Action{{Kind: SetTweakAction, Tweak: SoundTweak, Value: "default"}}, UnknownAction, map[string]interface{}{"sound": "default"}}, + { + "all", + []*Action{ + {Kind: CoalesceAction}, + {Kind: SetTweakAction, Tweak: HighlightTweak}, + {Kind: SetTweakAction, Tweak: SoundTweak, Value: "default"}, + }, + CoalesceAction, + map[string]interface{}{"highlight": nil, "sound": "default"}, + }, + } + for _, tst := range tsts { + t.Run(tst.Name, func(t *testing.T) { + gotKind, gotTweaks, err := ActionsToTweaks(tst.Input) + if err != nil { + t.Fatalf("ActionsToTweaks failed: %v", err) + } + if gotKind != tst.WantKind { + t.Errorf("kind: got %v, want %v", gotKind, tst.WantKind) + } + if diff := cmp.Diff(tst.WantTweaks, gotTweaks); diff != "" { + t.Errorf("tweaks: +got -want:\n%s", diff) + } + }) + } +} + +func TestBoolTweakOr(t *testing.T) { + tsts := []struct { + Name string + Input map[string]interface{} + Def bool + Want bool + }{ + {"nil", nil, false, false}, + {"nilValue", map[string]interface{}{"highlight": nil}, true, true}, + {"false", map[string]interface{}{"highlight": false}, true, false}, + {"true", map[string]interface{}{"highlight": true}, false, true}, + } + for _, tst := range tsts { + t.Run(tst.Name, func(t *testing.T) { + got := BoolTweakOr(tst.Input, HighlightTweak, tst.Def) + if got != tst.Want { + t.Errorf("BoolTweakOr: got %v, want %v", got, tst.Want) + } + }) + } +} + +func TestGlobToRegexp(t *testing.T) { + tsts := []struct { + Input string + Want string + }{ + {"", "^(.*.*)$"}, + {"a", "^(.*a.*)$"}, + {"a.b", "^(.*a\\.b.*)$"}, + {"a?b", "^(a.b)$"}, + {"a*b*", "^(a.*b.*)$"}, + {"a*b?", "^(a.*b.)$"}, + } + for _, tst := range tsts { + t.Run(tst.Want, func(t *testing.T) { + got, err := globToRegexp(tst.Input) + if err != nil { + t.Fatalf("globToRegexp failed: %v", err) + } + if got.String() != tst.Want { + t.Errorf("got %v, want %v", got.String(), tst.Want) + } + }) + } +} + +func TestLookupMapPath(t *testing.T) { + tsts := []struct { + Path []string + Root map[string]interface{} + Want interface{} + }{ + {[]string{"a"}, map[string]interface{}{"a": "b"}, "b"}, + {[]string{"a"}, map[string]interface{}{"a": 42}, 42}, + {[]string{"a", "b"}, map[string]interface{}{"a": map[string]interface{}{"b": "c"}}, "c"}, + } + for _, tst := range tsts { + t.Run(fmt.Sprint(tst.Path, "/", tst.Want), func(t *testing.T) { + got, err := lookupMapPath(tst.Path, tst.Root) + if err != nil { + t.Fatalf("lookupMapPath failed: %v", err) + } + if diff := cmp.Diff(tst.Want, got); diff != "" { + t.Errorf("+got -want:\n%s", diff) + } + }) + } +} + +func TestLookupMapPathInvalid(t *testing.T) { + tsts := []struct { + Path []string + Root map[string]interface{} + }{ + {nil, nil}, + {[]string{"a"}, nil}, + {[]string{"a", "b"}, map[string]interface{}{"a": "c"}}, + } + for _, tst := range tsts { + t.Run(fmt.Sprint(tst.Path), func(t *testing.T) { + got, err := lookupMapPath(tst.Path, tst.Root) + if err == nil { + t.Fatalf("lookupMapPath succeeded with %#v, but want failure", got) + } + }) + } +} + +func TestParseRoomMemberCountCondition(t *testing.T) { + tsts := []struct { + Input string + WantTrue []int + WantFalse []int + }{ + {"1", []int{1}, []int{0, 2}}, + {"==1", []int{1}, []int{0, 2}}, + {"<1", []int{0}, []int{1, 2}}, + {"<=1", []int{0, 1}, []int{2}}, + {">1", []int{2}, []int{0, 1}}, + {">=42", []int{42, 43}, []int{41}}, + } + for _, tst := range tsts { + t.Run(tst.Input, func(t *testing.T) { + got, err := parseRoomMemberCountCondition(tst.Input) + if err != nil { + t.Fatalf("parseRoomMemberCountCondition failed: %v", err) + } + for _, v := range tst.WantTrue { + if !got(v) { + t.Errorf("parseRoomMemberCountCondition(%q)(%d): got false, want true", tst.Input, v) + } + } + for _, v := range tst.WantFalse { + if got(v) { + t.Errorf("parseRoomMemberCountCondition(%q)(%d): got true, want false", tst.Input, v) + } + } + }) + } +} diff --git a/internal/pushrules/validate.go b/internal/pushrules/validate.go new file mode 100644 index 000000000..5d260f0b9 --- /dev/null +++ b/internal/pushrules/validate.go @@ -0,0 +1,85 @@ +package pushrules + +import ( + "fmt" + "regexp" +) + +// ValidateRule checks the rule for errors. These follow from Sytests +// and the specification. +func ValidateRule(kind Kind, rule *Rule) []error { + var errs []error + + if !validRuleIDRE.MatchString(rule.RuleID) { + errs = append(errs, fmt.Errorf("invalid rule ID: %s", rule.RuleID)) + } + + if len(rule.Actions) == 0 { + errs = append(errs, fmt.Errorf("missing actions")) + } + for _, action := range rule.Actions { + errs = append(errs, validateAction(action)...) + } + + for _, cond := range rule.Conditions { + errs = append(errs, validateCondition(cond)...) + } + + switch kind { + case OverrideKind, UnderrideKind: + // The empty list is allowed, but for JSON-encoding reasons, + // it must not be nil. + if rule.Conditions == nil { + errs = append(errs, fmt.Errorf("missing rule conditions")) + } + + case ContentKind: + if rule.Pattern == "" { + errs = append(errs, fmt.Errorf("missing content rule pattern")) + } + + case RoomKind, SenderKind: + // Do nothing. + + default: + errs = append(errs, fmt.Errorf("invalid rule kind: %s", kind)) + } + + return errs +} + +// validRuleIDRE is a regexp for valid IDs. +// +// TODO: the specification doesn't seem to say what the rule ID syntax +// is. A Sytest fails if it contains a backslash. +var validRuleIDRE = regexp.MustCompile(`^([^\\]+)$`) + +// validateAction returns issues with an Action. +func validateAction(action *Action) []error { + var errs []error + + switch action.Kind { + case NotifyAction, DontNotifyAction, CoalesceAction, SetTweakAction: + // Do nothing. + + default: + errs = append(errs, fmt.Errorf("invalid rule action kind: %s", action.Kind)) + } + + return errs +} + +// validateCondition returns issues with a Condition. +func validateCondition(cond *Condition) []error { + var errs []error + + switch cond.Kind { + case EventMatchCondition, ContainsDisplayNameCondition, RoomMemberCountCondition, SenderNotificationPermissionCondition: + // Do nothing. + + default: + errs = append(errs, fmt.Errorf("invalid rule condition kind: %s", cond.Kind)) + } + + return errs +} diff --git a/internal/pushrules/validate_test.go b/internal/pushrules/validate_test.go new file mode 100644 index 000000000..b276eb551 --- /dev/null +++ b/internal/pushrules/validate_test.go @@ -0,0 +1,163 @@ +package pushrules + +import ( + "strings" + "testing" +) + +func TestValidateRuleNegatives(t *testing.T) { + tsts := []struct { + Name string + Kind Kind + Rule Rule + WantErrString string + }{ + {"emptyRuleID", OverrideKind, Rule{}, "invalid rule ID"}, + {"invalidKind", Kind("something else"), Rule{}, "invalid rule kind"}, + {"ruleIDBackslash", OverrideKind, Rule{RuleID: "#foo\\:example.com"}, "invalid rule ID"}, + {"noActions", OverrideKind, Rule{}, "missing actions"}, + {"invalidAction", OverrideKind, Rule{Actions: []*Action{{}}}, "invalid rule action kind"}, + {"invalidCondition", OverrideKind, Rule{Conditions: []*Condition{{}}}, "invalid rule condition kind"}, + {"overrideNoCondition", OverrideKind, Rule{}, "missing rule conditions"}, + {"underrideNoCondition", UnderrideKind, Rule{}, "missing rule conditions"}, + {"contentNoPattern", ContentKind, Rule{}, "missing content rule pattern"}, + } + for _, tst := range tsts { + t.Run(tst.Name, func(t *testing.T) { + errs := ValidateRule(tst.Kind, &tst.Rule) + var foundErr error + for _, err := range errs { + t.Logf("Got error %#v", err) + if strings.Contains(err.Error(), tst.WantErrString) { + foundErr = err + } + } + if foundErr == nil { + t.Errorf("errs: got %#v, want containing %q", errs, tst.WantErrString) + } + }) + } +} + +func TestValidateRulePositives(t *testing.T) { + tsts := []struct { + Name string + Kind Kind + Rule Rule + WantNoErrString string + }{ + {"invalidKind", OverrideKind, Rule{}, "invalid rule kind"}, + {"invalidActionNoActions", OverrideKind, Rule{}, "invalid rule action kind"}, + {"invalidConditionNoConditions", OverrideKind, Rule{}, "invalid rule condition kind"}, + {"contentNoCondition", ContentKind, Rule{}, "missing rule conditions"}, + {"roomNoCondition", RoomKind, Rule{}, "missing rule conditions"}, + {"senderNoCondition", SenderKind, Rule{}, "missing rule conditions"}, + {"overrideNoPattern", OverrideKind, Rule{}, "missing content rule pattern"}, + {"overrideEmptyConditions", OverrideKind, Rule{Conditions: []*Condition{}}, "missing rule conditions"}, + } + for _, tst := range tsts { + t.Run(tst.Name, func(t *testing.T) { + errs := ValidateRule(tst.Kind, &tst.Rule) + for _, err := range errs { + t.Logf("Got error %#v", err) + if strings.Contains(err.Error(), tst.WantNoErrString) { + t.Errorf("errs: got %#v, want none containing %q", errs, tst.WantNoErrString) + } + } + }) + } +} + +func TestValidateActionNegatives(t *testing.T) { + tsts := []struct { + Name string + Action Action + WantErrString string + }{ + {"emptyKind", Action{}, "invalid rule action kind"}, + {"invalidKind", Action{Kind: ActionKind("something else")}, "invalid rule action kind"}, + } + for _, tst := range tsts { + t.Run(tst.Name, func(t *testing.T) { + errs := validateAction(&tst.Action) + var foundErr error + for _, err := range errs { + t.Logf("Got error %#v", err) + if strings.Contains(err.Error(), tst.WantErrString) { + foundErr = err + } + } + if foundErr == nil { + t.Errorf("errs: got %#v, want containing %q", errs, tst.WantErrString) + } + }) + } +} + +func TestValidateActionPositives(t *testing.T) { + tsts := []struct { + Name string + Action Action + WantNoErrString string + }{ + {"invalidKind", Action{Kind: NotifyAction}, "invalid rule action kind"}, + } + for _, tst := range tsts { + t.Run(tst.Name, func(t *testing.T) { + errs := validateAction(&tst.Action) + for _, err := range errs { + t.Logf("Got error %#v", err) + if strings.Contains(err.Error(), tst.WantNoErrString) { + t.Errorf("errs: got %#v, want none containing %q", errs, tst.WantNoErrString) + } + } + }) + } +} + +func TestValidateConditionNegatives(t *testing.T) { + tsts := []struct { + Name string + Condition Condition + WantErrString string + }{ + {"emptyKind", Condition{}, "invalid rule condition kind"}, + {"invalidKind", Condition{Kind: ConditionKind("something else")}, "invalid rule condition kind"}, + } + for _, tst := range tsts { + t.Run(tst.Name, func(t *testing.T) { + errs := validateCondition(&tst.Condition) + var foundErr error + for _, err := range errs { + t.Logf("Got error %#v", err) + if strings.Contains(err.Error(), tst.WantErrString) { + foundErr = err + } + } + if foundErr == nil { + t.Errorf("errs: got %#v, want containing %q", errs, tst.WantErrString) + } + }) + } +} + +func TestValidateConditionPositives(t *testing.T) { + tsts := []struct { + Name string + Condition Condition + WantNoErrString string + }{ + {"invalidKind", Condition{Kind: EventMatchCondition}, "invalid rule condition kind"}, + } + for _, tst := range tsts { + t.Run(tst.Name, func(t *testing.T) { + errs := validateCondition(&tst.Condition) + for _, err := range errs { + t.Logf("Got error %#v", err) + if strings.Contains(err.Error(), tst.WantNoErrString) { + t.Errorf("errs: got %#v, want none containing %q", errs, tst.WantNoErrString) + } + } + }) + } +} diff --git a/internal/sqlutil/sql.go b/internal/sqlutil/sql.go index 8d0d2dfa5..19483b268 100644 --- a/internal/sqlutil/sql.go +++ b/internal/sqlutil/sql.go @@ -163,6 +163,7 @@ type StatementList []struct { func (s StatementList) Prepare(db *sql.DB) (err error) { for _, statement := range s { if *statement.Statement, err = db.Prepare(statement.SQL); err != nil { + err = fmt.Errorf("Error %q while preparing statement: %s", err, statement.SQL) return } } diff --git a/q.sqlite b/q.sqlite new file mode 100644 index 0000000000000000000000000000000000000000..b7d6268e2f2f204e4718d45be1e3ef2bf9e96ae9 GIT binary patch literal 12288 zcmeI#F-rq66u|N16cweUTi55NAcBayRc^8f)$YfIMhRZ;?3sghbm#~0!}(ntO$wd6 zAO8pMg-6Knewp6ubmD`Px29c`L2lJhX|3)>lu|n8LCjsG{&3gCGxhsItLvwqf%#LJ z*(kFyzxy6=j{pJ)Ab"$tmpdir/buildlog" || status=$? + if (( status != 0 )); then + # Docker is very verbose, and we don't really care about + # building SyTest. So we accumulate and only output on + # failure. + cat "$tmpdir/buildlog" >&2 + return $status + fi + fi + + runargs+=( -v "$PWD/../sytest:/sytest:ro" ) + fi + if [ -n "$SYTEST_POSTGRES" ]; then + runargs+=( -e POSTGRES=1 ) + fi + + local sytestout=$PWD/../sytestout + mkdir -p "$sytestout"/{logs,cache/go-build,cache/go-pkg} + docker run \ + --rm \ + --name "sytest-dendrite-${LOGNAME}" \ + -e LOGS_USER=$(id -u) \ + -e LOGS_GROUP=$(id -g) \ + -v "$PWD:/src/:ro" \ + -v "$sytestout/logs:/logs/" \ + -v "$sytestout/cache/go-build:/root/.cache/go-build" \ + -v "$sytestout/cache/go-pkg:/gopath/pkg" \ + "${runargs[@]}" \ + matrixdotorg/sytest-dendrite:latest "$@" +} + +main "$@" diff --git a/setup/base/base.go b/setup/base/base.go index e39977541..ef3b2be29 100644 --- a/setup/base/base.go +++ b/setup/base/base.go @@ -30,6 +30,7 @@ import ( sentryhttp "github.com/getsentry/sentry-go/http" "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/internal/pushgateway" "github.com/matrix-org/gomatrixserverlib" "github.com/prometheus/client_golang/prometheus/promhttp" "go.uber.org/atomic" @@ -271,6 +272,11 @@ func (b *BaseDendrite) KeyServerHTTPClient() keyserverAPI.KeyInternalAPI { return f } +// PushGatewayHTTPClient returns a new client for interacting with (external) Push Gateways. +func (b *BaseDendrite) PushGatewayHTTPClient() pushgateway.Client { + return pushgateway.NewHTTPClient(b.Cfg.UserAPI.PushGatewayDisableTLSValidation) +} + // CreateAccountsDB creates a new instance of the accounts database. Should only // be called once per component. func (b *BaseDendrite) CreateAccountsDB() userdb.Database { diff --git a/setup/config/config_test.go b/setup/config/config_test.go index 8f7611f0a..6467b7c82 100644 --- a/setup/config/config_test.go +++ b/setup/config/config_test.go @@ -205,6 +205,11 @@ user_api: max_open_conns: 100 max_idle_conns: 2 conn_max_lifetime: -1 + pusher_database: + connection_string: file:pushserver.db + max_open_conns: 100 + max_idle_conns: 2 + conn_max_lifetime: -1 tracing: enabled: false jaeger: diff --git a/setup/config/config_userapi.go b/setup/config/config_userapi.go index 1cb5eba18..570dc6030 100644 --- a/setup/config/config_userapi.go +++ b/setup/config/config_userapi.go @@ -13,6 +13,9 @@ type UserAPI struct { // The length of time an OpenID token is condidered valid in milliseconds OpenIDTokenLifetimeMS int64 `yaml:"openid_token_lifetime_ms"` + // Disable TLS validation on HTTPS calls to push gatways. NOT RECOMMENDED! + PushGatewayDisableTLSValidation bool `yaml:"push_gateway_disable_tls_validation"` + // The Account database stores the login details and account information // for local users. It is accessed by the UserAPI. AccountDatabase DatabaseOptions `yaml:"account_database"` diff --git a/setup/jetstream/streams.go b/setup/jetstream/streams.go index 5810a2a91..3f07488f9 100644 --- a/setup/jetstream/streams.go +++ b/setup/jetstream/streams.go @@ -18,7 +18,10 @@ var ( OutputKeyChangeEvent = "OutputKeyChangeEvent" OutputTypingEvent = "OutputTypingEvent" OutputClientData = "OutputClientData" + OutputNotificationData = "OutputNotificationData" OutputReceiptEvent = "OutputReceiptEvent" + OutputStreamEvent = "OutputStreamEvent" + OutputReadUpdate = "OutputReadUpdate" ) var streams = []*nats.StreamConfig{ @@ -58,4 +61,19 @@ var streams = []*nats.StreamConfig{ Retention: nats.InterestPolicy, Storage: nats.FileStorage, }, + { + Name: OutputNotificationData, + Retention: nats.InterestPolicy, + Storage: nats.FileStorage, + }, + { + Name: OutputStreamEvent, + Retention: nats.InterestPolicy, + Storage: nats.FileStorage, + }, + { + Name: OutputReadUpdate, + Retention: nats.InterestPolicy, + Storage: nats.FileStorage, + }, } diff --git a/setup/monolith.go b/setup/monolith.go index 61125e4a9..7dbd2eeaa 100644 --- a/setup/monolith.go +++ b/setup/monolith.go @@ -60,8 +60,8 @@ func (m *Monolith) AddAllPublicRoutes(process *process.ProcessContext, csMux, ss csMux, synapseMux, &m.Config.ClientAPI, m.AccountDB, m.FedClient, m.RoomserverAPI, m.EDUInternalAPI, m.AppserviceAPI, transactions.New(), - m.FederationAPI, m.UserAPI, m.KeyAPI, m.ExtPublicRoomsProvider, - &m.Config.MSCs, + m.FederationAPI, m.UserAPI, m.KeyAPI, + m.ExtPublicRoomsProvider, &m.Config.MSCs, ) federationapi.AddPublicRoutes( ssMux, keyMux, wkMux, &m.Config.FederationAPI, m.UserAPI, m.FedClient, diff --git a/syncapi/consumers/clientapi.go b/syncapi/consumers/clientapi.go index c3650085f..f01afce6d 100644 --- a/syncapi/consumers/clientapi.go +++ b/syncapi/consumers/clientapi.go @@ -17,6 +17,7 @@ package consumers import ( "context" "encoding/json" + "fmt" "github.com/getsentry/sentry-go" "github.com/matrix-org/dendrite/internal/eventutil" @@ -24,21 +25,26 @@ import ( "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/syncapi/notifier" + "github.com/matrix-org/dendrite/syncapi/producers" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/gomatrixserverlib" "github.com/nats-io/nats.go" + "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus" ) // OutputClientDataConsumer consumes events that originated in the client API server. type OutputClientDataConsumer struct { - ctx context.Context - jetstream nats.JetStreamContext - durable string - topic string - db storage.Database - stream types.StreamProvider - notifier *notifier.Notifier + ctx context.Context + jetstream nats.JetStreamContext + durable string + topic string + db storage.Database + stream types.StreamProvider + notifier *notifier.Notifier + serverName gomatrixserverlib.ServerName + producer *producers.UserAPIReadProducer } // NewOutputClientDataConsumer creates a new OutputClientData consumer. Call Start() to begin consuming from room servers. @@ -49,15 +55,18 @@ func NewOutputClientDataConsumer( store storage.Database, notifier *notifier.Notifier, stream types.StreamProvider, + producer *producers.UserAPIReadProducer, ) *OutputClientDataConsumer { return &OutputClientDataConsumer{ - ctx: process.Context(), - jetstream: js, - topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputClientData), - durable: cfg.Matrix.JetStream.Durable("SyncAPIClientAPIConsumer"), - db: store, - notifier: notifier, - stream: stream, + ctx: process.Context(), + jetstream: js, + topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputClientData), + durable: cfg.Matrix.JetStream.Durable("SyncAPIClientAPIConsumer"), + db: store, + notifier: notifier, + stream: stream, + serverName: cfg.Matrix.ServerName, + producer: producer, } } @@ -100,8 +109,48 @@ func (s *OutputClientDataConsumer) onMessage(ctx context.Context, msg *nats.Msg) }).Panicf("could not save account data") } + if err = s.sendReadUpdate(ctx, userID, output); err != nil { + log.WithError(err).WithFields(logrus.Fields{ + "user_id": userID, + "room_id": output.RoomID, + }).Errorf("Failed to generate read update") + sentry.CaptureException(err) + return false + } + s.stream.Advance(streamPos) s.notifier.OnNewAccountData(userID, types.StreamingToken{AccountDataPosition: streamPos}) return true } + +func (s *OutputClientDataConsumer) sendReadUpdate(ctx context.Context, userID string, output eventutil.AccountData) error { + if output.Type != "m.fully_read" || output.ReadMarker == nil { + return nil + } + _, serverName, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + return fmt.Errorf("gomatrixserverlib.SplitID: %w", err) + } + if serverName != s.serverName { + return nil + } + var readPos types.StreamPosition + var fullyReadPos types.StreamPosition + if output.ReadMarker.Read != "" { + if _, readPos, err = s.db.PositionInTopology(ctx, output.ReadMarker.Read); err != nil { + return fmt.Errorf("s.db.PositionInTopology (Read): %w", err) + } + } + if output.ReadMarker.FullyRead != "" { + if _, fullyReadPos, err = s.db.PositionInTopology(ctx, output.ReadMarker.FullyRead); err != nil { + return fmt.Errorf("s.db.PositionInTopology (FullyRead): %w", err) + } + } + if readPos > 0 || fullyReadPos > 0 { + if err := s.producer.SendReadUpdate(userID, output.RoomID, readPos, fullyReadPos); err != nil { + return fmt.Errorf("s.producer.SendReadUpdate: %w", err) + } + } + return nil +} diff --git a/syncapi/consumers/eduserver_receipts.go b/syncapi/consumers/eduserver_receipts.go index 392840ece..881583449 100644 --- a/syncapi/consumers/eduserver_receipts.go +++ b/syncapi/consumers/eduserver_receipts.go @@ -17,6 +17,7 @@ package consumers import ( "context" "encoding/json" + "fmt" "github.com/getsentry/sentry-go" "github.com/matrix-org/dendrite/eduserver/api" @@ -24,21 +25,26 @@ import ( "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/syncapi/notifier" + "github.com/matrix-org/dendrite/syncapi/producers" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/gomatrixserverlib" "github.com/nats-io/nats.go" + "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus" ) // OutputReceiptEventConsumer consumes events that originated in the EDU server. type OutputReceiptEventConsumer struct { - ctx context.Context - jetstream nats.JetStreamContext - durable string - topic string - db storage.Database - stream types.StreamProvider - notifier *notifier.Notifier + ctx context.Context + jetstream nats.JetStreamContext + durable string + topic string + db storage.Database + stream types.StreamProvider + notifier *notifier.Notifier + serverName gomatrixserverlib.ServerName + producer *producers.UserAPIReadProducer } // NewOutputReceiptEventConsumer creates a new OutputReceiptEventConsumer. @@ -50,15 +56,18 @@ func NewOutputReceiptEventConsumer( store storage.Database, notifier *notifier.Notifier, stream types.StreamProvider, + producer *producers.UserAPIReadProducer, ) *OutputReceiptEventConsumer { return &OutputReceiptEventConsumer{ - ctx: process.Context(), - jetstream: js, - topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputReceiptEvent), - durable: cfg.Matrix.JetStream.Durable("SyncAPIEDUServerReceiptConsumer"), - db: store, - notifier: notifier, - stream: stream, + ctx: process.Context(), + jetstream: js, + topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputReceiptEvent), + durable: cfg.Matrix.JetStream.Durable("SyncAPIEDUServerReceiptConsumer"), + db: store, + notifier: notifier, + stream: stream, + serverName: cfg.Matrix.ServerName, + producer: producer, } } @@ -92,8 +101,42 @@ func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msg *nats.Ms return true } + if err = s.sendReadUpdate(ctx, output); err != nil { + log.WithError(err).WithFields(logrus.Fields{ + "user_id": output.UserID, + "room_id": output.RoomID, + }).Errorf("Failed to generate read update") + sentry.CaptureException(err) + return false + } + s.stream.Advance(streamPos) s.notifier.OnNewReceipt(output.RoomID, types.StreamingToken{ReceiptPosition: streamPos}) return true } + +func (s *OutputReceiptEventConsumer) sendReadUpdate(ctx context.Context, output api.OutputReceiptEvent) error { + if output.Type != "m.read" { + return nil + } + _, serverName, err := gomatrixserverlib.SplitID('@', output.UserID) + if err != nil { + return fmt.Errorf("gomatrixserverlib.SplitID: %w", err) + } + if serverName != s.serverName { + return nil + } + var readPos types.StreamPosition + if output.EventID != "" { + if _, readPos, err = s.db.PositionInTopology(ctx, output.EventID); err != nil { + return fmt.Errorf("s.db.PositionInTopology (Read): %w", err) + } + } + if readPos > 0 { + if err := s.producer.SendReadUpdate(output.UserID, output.RoomID, readPos, 0); err != nil { + return fmt.Errorf("s.producer.SendReadUpdate: %w", err) + } + } + return nil +} diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index 15485bb35..159657f9f 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -26,6 +26,7 @@ import ( "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/syncapi/notifier" + "github.com/matrix-org/dendrite/syncapi/producers" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" @@ -45,6 +46,7 @@ type OutputRoomEventConsumer struct { pduStream types.StreamProvider inviteStream types.StreamProvider notifier *notifier.Notifier + producer *producers.UserAPIStreamEventProducer } // NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call Start() to begin consuming from room servers. @@ -57,6 +59,7 @@ func NewOutputRoomEventConsumer( pduStream types.StreamProvider, inviteStream types.StreamProvider, rsAPI api.RoomserverInternalAPI, + producer *producers.UserAPIStreamEventProducer, ) *OutputRoomEventConsumer { return &OutputRoomEventConsumer{ ctx: process.Context(), @@ -69,6 +72,7 @@ func NewOutputRoomEventConsumer( pduStream: pduStream, inviteStream: inviteStream, rsAPI: rsAPI, + producer: producer, } } @@ -194,6 +198,12 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent( return nil } + if err = s.producer.SendStreamEvent(ev.RoomID(), ev, pduPos); err != nil { + log.WithError(err).Errorf("Failed to send stream output event for event %s", ev.EventID()) + sentry.CaptureException(err) + return err + } + if pduPos, err = s.notifyJoinedPeeks(ctx, ev, pduPos); err != nil { log.WithError(err).Errorf("Failed to notifyJoinedPeeks for PDU pos %d", pduPos) sentry.CaptureException(err) diff --git a/syncapi/consumers/userapi.go b/syncapi/consumers/userapi.go new file mode 100644 index 000000000..a3b2dd53d --- /dev/null +++ b/syncapi/consumers/userapi.go @@ -0,0 +1,110 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package consumers + +import ( + "context" + "encoding/json" + + "github.com/getsentry/sentry-go" + "github.com/matrix-org/dendrite/internal/eventutil" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/syncapi/notifier" + "github.com/matrix-org/dendrite/syncapi/storage" + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/nats-io/nats.go" + log "github.com/sirupsen/logrus" +) + +// OutputNotificationDataConsumer consumes events that originated in +// the Push server. +type OutputNotificationDataConsumer struct { + ctx context.Context + jetstream nats.JetStreamContext + durable string + topic string + db storage.Database + notifier *notifier.Notifier + stream types.StreamProvider +} + +// NewOutputNotificationDataConsumer creates a new consumer. Call +// Start() to begin consuming. +func NewOutputNotificationDataConsumer( + process *process.ProcessContext, + cfg *config.SyncAPI, + js nats.JetStreamContext, + store storage.Database, + notifier *notifier.Notifier, + stream types.StreamProvider, +) *OutputNotificationDataConsumer { + s := &OutputNotificationDataConsumer{ + ctx: process.Context(), + jetstream: js, + durable: cfg.Matrix.JetStream.Durable("SyncAPINotificationDataConsumer"), + topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputNotificationData), + db: store, + notifier: notifier, + stream: stream, + } + return s +} + +// Start starts consumption. +func (s *OutputNotificationDataConsumer) Start() error { + return jetstream.JetStreamConsumer( + s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, + nats.DeliverAll(), nats.ManualAck(), + ) +} + +// onMessage is called when the Sync server receives a new event from +// the push server. It is not safe for this function to be called from +// multiple goroutines, or else the sync stream position may race and +// be incorrectly calculated. +func (s *OutputNotificationDataConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { + userID := string(msg.Header.Get(jetstream.UserID)) + + // Parse out the event JSON + var data eventutil.NotificationData + if err := json.Unmarshal(msg.Data, &data); err != nil { + sentry.CaptureException(err) + log.WithField("user_id", userID).WithError(err).Error("user API consumer: message parse failure") + return true + } + + streamPos, err := s.db.UpsertRoomUnreadNotificationCounts(ctx, userID, data.RoomID, data.UnreadNotificationCount, data.UnreadHighlightCount) + if err != nil { + sentry.CaptureException(err) + log.WithFields(log.Fields{ + "user_id": userID, + "room_id": data.RoomID, + }).WithError(err).Error("Could not save notification counts") + return false + } + + s.stream.Advance(streamPos) + s.notifier.OnNewNotificationData(userID, types.StreamingToken{NotificationDataPosition: streamPos}) + + log.WithFields(log.Fields{ + "user_id": userID, + "room_id": data.RoomID, + "streamPos": streamPos, + }).Trace("Received notification data from user API") + + return true +} diff --git a/syncapi/notifier/notifier.go b/syncapi/notifier/notifier.go index d853cc0e4..6a641e6f8 100644 --- a/syncapi/notifier/notifier.go +++ b/syncapi/notifier/notifier.go @@ -217,6 +217,17 @@ func (n *Notifier) OnNewInvite( n.wakeupUsers([]string{wakeUserID}, nil, n.currPos) } +func (n *Notifier) OnNewNotificationData( + userID string, + posUpdate types.StreamingToken, +) { + n.streamLock.Lock() + defer n.streamLock.Unlock() + + n.currPos.ApplyUpdates(posUpdate) + n.wakeupUsers([]string{userID}, nil, n.currPos) +} + // GetListener returns a UserStreamListener that can be used to wait for // updates for a user. Must be closed. // notify for anything before sincePos diff --git a/syncapi/notifier/notifier_test.go b/syncapi/notifier/notifier_test.go index c6d3df7ee..60403d5d5 100644 --- a/syncapi/notifier/notifier_test.go +++ b/syncapi/notifier/notifier_test.go @@ -219,7 +219,7 @@ func TestEDUWakeup(t *testing.T) { go func() { pos, err := waitForEvents(n, newTestSyncRequest(bob, bobDev, syncPositionAfter)) if err != nil { - t.Errorf("TestNewInviteEventForUser error: %w", err) + t.Errorf("TestNewInviteEventForUser error: %v", err) } mustEqualPositions(t, pos, syncPositionNewEDU) wg.Done() diff --git a/syncapi/producers/userapi_readupdate.go b/syncapi/producers/userapi_readupdate.go new file mode 100644 index 000000000..d56cab776 --- /dev/null +++ b/syncapi/producers/userapi_readupdate.go @@ -0,0 +1,62 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package producers + +import ( + "encoding/json" + + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/nats-io/nats.go" + log "github.com/sirupsen/logrus" +) + +// UserAPIProducer produces events for the user API server to consume +type UserAPIReadProducer struct { + Topic string + JetStream nats.JetStreamContext +} + +// SendData sends account data to the user API server +func (p *UserAPIReadProducer) SendReadUpdate(userID, roomID string, readPos, fullyReadPos types.StreamPosition) error { + m := &nats.Msg{ + Subject: p.Topic, + Header: nats.Header{}, + } + m.Header.Set(jetstream.UserID, userID) + m.Header.Set(jetstream.RoomID, roomID) + + data := types.ReadUpdate{ + UserID: userID, + RoomID: roomID, + Read: readPos, + FullyRead: fullyReadPos, + } + var err error + m.Data, err = json.Marshal(data) + if err != nil { + return err + } + + log.WithFields(log.Fields{ + "user_id": userID, + "room_id": roomID, + "read_pos": readPos, + "fully_read_pos": fullyReadPos, + }).Tracef("Producing to topic '%s'", p.Topic) + + _, err = p.JetStream.PublishMsg(m) + return err +} diff --git a/syncapi/producers/userapi_streamevent.go b/syncapi/producers/userapi_streamevent.go new file mode 100644 index 000000000..2bbd19c0b --- /dev/null +++ b/syncapi/producers/userapi_streamevent.go @@ -0,0 +1,60 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package producers + +import ( + "encoding/json" + + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/gomatrixserverlib" + "github.com/nats-io/nats.go" + log "github.com/sirupsen/logrus" +) + +// UserAPIProducer produces events for the user API server to consume +type UserAPIStreamEventProducer struct { + Topic string + JetStream nats.JetStreamContext +} + +// SendData sends account data to the user API server +func (p *UserAPIStreamEventProducer) SendStreamEvent(roomID string, event *gomatrixserverlib.HeaderedEvent, pos types.StreamPosition) error { + m := &nats.Msg{ + Subject: p.Topic, + Header: nats.Header{}, + } + m.Header.Set(jetstream.RoomID, roomID) + + data := types.StreamedEvent{ + Event: event, + StreamPosition: pos, + } + var err error + m.Data, err = json.Marshal(data) + if err != nil { + return err + } + + log.WithFields(log.Fields{ + "room_id": roomID, + "event_id": event.EventID(), + "event_type": event.Type(), + "stream_pos": pos, + }).Tracef("Producing to topic '%s'", p.Topic) + + _, err = p.JetStream.PublishMsg(m) + return err +} diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index 126bc8658..e44766338 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -18,6 +18,7 @@ import ( "context" eduAPI "github.com/matrix-org/dendrite/eduserver/api" + "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/types" @@ -31,6 +32,7 @@ type Database interface { MaxStreamPositionForInvites(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForAccountData(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForSendToDeviceMessages(ctx context.Context) (types.StreamPosition, error) + MaxStreamPositionForNotificationData(ctx context.Context) (types.StreamPosition, error) CurrentState(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter, excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) GetStateDeltasForFullStateSync(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error) @@ -138,6 +140,12 @@ type Database interface { // GetRoomReceipts gets all receipts for a given roomID GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]eduAPI.OutputReceiptEvent, error) + // UpsertRoomUnreadNotificationCounts updates the notification statistics about a (user, room) key. + UpsertRoomUnreadNotificationCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (types.StreamPosition, error) + + // GetUserUnreadNotificationCounts returns statistics per room a user is interested in. + GetUserUnreadNotificationCounts(ctx context.Context, userID string, from, to types.StreamPosition) (map[string]*eventutil.NotificationData, error) + SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error) SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error) diff --git a/syncapi/storage/postgres/notification_data_table.go b/syncapi/storage/postgres/notification_data_table.go new file mode 100644 index 000000000..f3fc4451f --- /dev/null +++ b/syncapi/storage/postgres/notification_data_table.go @@ -0,0 +1,108 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/eventutil" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/syncapi/storage/tables" + "github.com/matrix-org/dendrite/syncapi/types" +) + +func NewPostgresNotificationDataTable(db *sql.DB) (tables.NotificationData, error) { + _, err := db.Exec(notificationDataSchema) + if err != nil { + return nil, err + } + r := ¬ificationDataStatements{} + return r, sqlutil.StatementList{ + {&r.upsertRoomUnreadCounts, upsertRoomUnreadNotificationCountsSQL}, + {&r.selectUserUnreadCounts, selectUserUnreadNotificationCountsSQL}, + {&r.selectMaxID, selectMaxNotificationIDSQL}, + }.Prepare(db) +} + +type notificationDataStatements struct { + upsertRoomUnreadCounts *sql.Stmt + selectUserUnreadCounts *sql.Stmt + selectMaxID *sql.Stmt +} + +const notificationDataSchema = ` +CREATE TABLE IF NOT EXISTS syncapi_notification_data ( + id BIGSERIAL PRIMARY KEY, + user_id TEXT NOT NULL, + room_id TEXT NOT NULL, + notification_count BIGINT NOT NULL DEFAULT 0, + highlight_count BIGINT NOT NULL DEFAULT 0, + CONSTRAINT syncapi_notification_data_unique UNIQUE (user_id, room_id) +);` + +const upsertRoomUnreadNotificationCountsSQL = `INSERT INTO syncapi_notification_data + (user_id, room_id, notification_count, highlight_count) + VALUES ($1, $2, $3, $4) + ON CONFLICT (user_id, room_id) + DO UPDATE SET notification_count = $3, highlight_count = $4 + RETURNING id` + +const selectUserUnreadNotificationCountsSQL = `SELECT + id, room_id, notification_count, highlight_count + FROM syncapi_notification_data + WHERE + user_id = $1 AND + id BETWEEN $2 + 1 AND $3` + +const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data` + +func (r *notificationDataStatements) UpsertRoomUnreadCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (pos types.StreamPosition, err error) { + err = r.upsertRoomUnreadCounts.QueryRowContext(ctx, userID, roomID, notificationCount, highlightCount).Scan(&pos) + return +} + +func (r *notificationDataStatements) SelectUserUnreadCounts(ctx context.Context, userID string, fromExcl, toIncl types.StreamPosition) (map[string]*eventutil.NotificationData, error) { + rows, err := r.selectUserUnreadCounts.QueryContext(ctx, userID, fromExcl, toIncl) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectUserUnreadCounts: rows.close() failed") + + roomCounts := map[string]*eventutil.NotificationData{} + for rows.Next() { + var id types.StreamPosition + var roomID string + var notificationCount, highlightCount int + + if err = rows.Scan(&id, &roomID, ¬ificationCount, &highlightCount); err != nil { + return nil, err + } + + roomCounts[roomID] = &eventutil.NotificationData{ + RoomID: roomID, + UnreadNotificationCount: notificationCount, + UnreadHighlightCount: highlightCount, + } + } + return roomCounts, rows.Err() +} + +func (r *notificationDataStatements) SelectMaxID(ctx context.Context) (int64, error) { + var id int64 + err := r.selectMaxID.QueryRowContext(ctx).Scan(&id) + return id, err +} diff --git a/syncapi/storage/postgres/syncserver.go b/syncapi/storage/postgres/syncserver.go index 6f4e7749d..60fe5b54d 100644 --- a/syncapi/storage/postgres/syncserver.go +++ b/syncapi/storage/postgres/syncserver.go @@ -90,6 +90,10 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e if err != nil { return nil, err } + notificationData, err := NewPostgresNotificationDataTable(d.db) + if err != nil { + return nil, err + } m := sqlutil.NewMigrations() deltas.LoadFixSequences(m) deltas.LoadRemoveSendToDeviceSentColumn(m) @@ -110,6 +114,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e SendToDevice: sendToDevice, Receipts: receipts, Memberships: memberships, + NotificationData: notificationData, } return &d, nil } diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index 819851b33..87d7c6df7 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -48,6 +48,7 @@ type Database struct { Filter tables.Filter Receipts tables.Receipts Memberships tables.Memberships + NotificationData tables.NotificationData } func (d *Database) readOnlySnapshot(ctx context.Context) (*sql.Tx, error) { @@ -102,6 +103,14 @@ func (d *Database) MaxStreamPositionForAccountData(ctx context.Context) (types.S return types.StreamPosition(id), nil } +func (d *Database) MaxStreamPositionForNotificationData(ctx context.Context) (types.StreamPosition, error) { + id, err := d.NotificationData.SelectMaxID(ctx) + if err != nil { + return 0, fmt.Errorf("d.NotificationData.SelectMaxID: %w", err) + } + return types.StreamPosition(id), nil +} + func (d *Database) CurrentState(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter, excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) { return d.CurrentRoomState.SelectCurrentState(ctx, nil, roomID, stateFilterPart, excludeEventIDs) } @@ -956,6 +965,18 @@ func (d *Database) GetRoomReceipts(ctx context.Context, roomIDs []string, stream return receipts, err } +func (d *Database) UpsertRoomUnreadNotificationCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (pos types.StreamPosition, err error) { + err = d.Writer.Do(nil, nil, func(_ *sql.Tx) error { + pos, err = d.NotificationData.UpsertRoomUnreadCounts(ctx, userID, roomID, notificationCount, highlightCount) + return err + }) + return +} + +func (d *Database) GetUserUnreadNotificationCounts(ctx context.Context, userID string, from, to types.StreamPosition) (map[string]*eventutil.NotificationData, error) { + return d.NotificationData.SelectUserUnreadCounts(ctx, userID, from, to) +} + func (s *Database) SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) { return s.OutputEvents.SelectContextEvent(ctx, nil, roomID, eventID) } diff --git a/syncapi/storage/sqlite3/notification_data_table.go b/syncapi/storage/sqlite3/notification_data_table.go new file mode 100644 index 000000000..4b3f074db --- /dev/null +++ b/syncapi/storage/sqlite3/notification_data_table.go @@ -0,0 +1,108 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/eventutil" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/syncapi/storage/tables" + "github.com/matrix-org/dendrite/syncapi/types" +) + +func NewSqliteNotificationDataTable(db *sql.DB) (tables.NotificationData, error) { + _, err := db.Exec(notificationDataSchema) + if err != nil { + return nil, err + } + r := ¬ificationDataStatements{} + return r, sqlutil.StatementList{ + {&r.upsertRoomUnreadCounts, upsertRoomUnreadNotificationCountsSQL}, + {&r.selectUserUnreadCounts, selectUserUnreadNotificationCountsSQL}, + {&r.selectMaxID, selectMaxNotificationIDSQL}, + }.Prepare(db) +} + +type notificationDataStatements struct { + upsertRoomUnreadCounts *sql.Stmt + selectUserUnreadCounts *sql.Stmt + selectMaxID *sql.Stmt +} + +const notificationDataSchema = ` +CREATE TABLE IF NOT EXISTS syncapi_notification_data ( + id INTEGER PRIMARY KEY, + user_id TEXT NOT NULL, + room_id TEXT NOT NULL, + notification_count BIGINT NOT NULL DEFAULT 0, + highlight_count BIGINT NOT NULL DEFAULT 0, + CONSTRAINT syncapi_notifications_unique UNIQUE (user_id, room_id) +);` + +const upsertRoomUnreadNotificationCountsSQL = `INSERT INTO syncapi_notification_data + (user_id, room_id, notification_count, highlight_count) + VALUES ($1, $2, $3, $4) + ON CONFLICT (user_id, room_id) + DO UPDATE SET notification_count = $3, highlight_count = $4 + RETURNING id` + +const selectUserUnreadNotificationCountsSQL = `SELECT + id, room_id, notification_count, highlight_count + FROM syncapi_notification_data + WHERE + user_id = $1 AND + id BETWEEN $2 + 1 AND $3` + +const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data` + +func (r *notificationDataStatements) UpsertRoomUnreadCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (pos types.StreamPosition, err error) { + err = r.upsertRoomUnreadCounts.QueryRowContext(ctx, userID, roomID, notificationCount, highlightCount).Scan(&pos) + return +} + +func (r *notificationDataStatements) SelectUserUnreadCounts(ctx context.Context, userID string, fromExcl, toIncl types.StreamPosition) (map[string]*eventutil.NotificationData, error) { + rows, err := r.selectUserUnreadCounts.QueryContext(ctx, userID, fromExcl, toIncl) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectUserUnreadCounts: rows.close() failed") + + roomCounts := map[string]*eventutil.NotificationData{} + for rows.Next() { + var id types.StreamPosition + var roomID string + var notificationCount, highlightCount int + + if err = rows.Scan(&id, &roomID, ¬ificationCount, &highlightCount); err != nil { + return nil, err + } + + roomCounts[roomID] = &eventutil.NotificationData{ + RoomID: roomID, + UnreadNotificationCount: notificationCount, + UnreadHighlightCount: highlightCount, + } + } + return roomCounts, rows.Err() +} + +func (r *notificationDataStatements) SelectMaxID(ctx context.Context) (int64, error) { + var id int64 + err := r.selectMaxID.QueryRowContext(ctx).Scan(&id) + return id, err +} diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go index 581ee6928..1b256f91a 100644 --- a/syncapi/storage/sqlite3/output_room_events_table.go +++ b/syncapi/storage/sqlite3/output_room_events_table.go @@ -62,16 +62,19 @@ const selectEventsSQL = "" + const selectRecentEventsSQL = "" + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + " WHERE room_id = $1 AND id > $2 AND id <= $3" + // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters const selectRecentEventsForSyncSQL = "" + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + " WHERE room_id = $1 AND id > $2 AND id <= $3 AND exclude_from_sync = FALSE" + // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters const selectEarlyEventsSQL = "" + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + " WHERE room_id = $1 AND id > $2 AND id <= $3" + // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters const selectMaxEventIDSQL = "" + @@ -85,6 +88,7 @@ const selectStateInRangeSQL = "" + " FROM syncapi_output_room_events" + " WHERE (id > $1 AND id <= $2)" + " AND ((add_state_ids IS NOT NULL AND add_state_ids != '') OR (remove_state_ids IS NOT NULL AND remove_state_ids != ''))" + // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters const deleteEventsForRoomSQL = "" + @@ -95,10 +99,12 @@ const selectContextEventSQL = "" + const selectContextBeforeEventSQL = "" + "SELECT headered_event_json FROM syncapi_output_room_events WHERE room_id = $1 AND id < $2" + // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters const selectContextAfterEventSQL = "" + "SELECT id, headered_event_json FROM syncapi_output_room_events WHERE room_id = $1 AND id > $2" + // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters type outputRoomEventsStatements struct { diff --git a/syncapi/storage/sqlite3/syncserver.go b/syncapi/storage/sqlite3/syncserver.go index 706d43f81..f5ae9fdd7 100644 --- a/syncapi/storage/sqlite3/syncserver.go +++ b/syncapi/storage/sqlite3/syncserver.go @@ -100,6 +100,10 @@ func (d *SyncServerDatasource) prepare(dbProperties *config.DatabaseOptions) (er if err != nil { return err } + notificationData, err := NewSqliteNotificationDataTable(d.db) + if err != nil { + return err + } m := sqlutil.NewMigrations() deltas.LoadFixSequences(m) deltas.LoadRemoveSendToDeviceSentColumn(m) @@ -120,6 +124,7 @@ func (d *SyncServerDatasource) prepare(dbProperties *config.DatabaseOptions) (er SendToDevice: sendToDevice, Receipts: receipts, Memberships: memberships, + NotificationData: notificationData, } return nil } diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index 1d807ee6b..1ebb42651 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -19,6 +19,7 @@ import ( "database/sql" eduAPI "github.com/matrix-org/dendrite/eduserver/api" + "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" @@ -171,3 +172,9 @@ type Memberships interface { UpsertMembership(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, streamPos, topologicalPos types.StreamPosition) error SelectMembership(ctx context.Context, txn *sql.Tx, roomID, userID, memberships []string) (eventID string, streamPos, topologyPos types.StreamPosition, err error) } + +type NotificationData interface { + UpsertRoomUnreadCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (types.StreamPosition, error) + SelectUserUnreadCounts(ctx context.Context, userID string, fromExcl, toIncl types.StreamPosition) (map[string]*eventutil.NotificationData, error) + SelectMaxID(ctx context.Context) (int64, error) +} diff --git a/syncapi/streams/stream_notificationdata.go b/syncapi/streams/stream_notificationdata.go new file mode 100644 index 000000000..8ba9e07ca --- /dev/null +++ b/syncapi/streams/stream_notificationdata.go @@ -0,0 +1,55 @@ +package streams + +import ( + "context" + + "github.com/matrix-org/dendrite/syncapi/types" +) + +type NotificationDataStreamProvider struct { + StreamProvider +} + +func (p *NotificationDataStreamProvider) Setup() { + p.StreamProvider.Setup() + + id, err := p.DB.MaxStreamPositionForNotificationData(context.Background()) + if err != nil { + panic(err) + } + p.latest = id +} + +func (p *NotificationDataStreamProvider) CompleteSync( + ctx context.Context, + req *types.SyncRequest, +) types.StreamPosition { + return p.IncrementalSync(ctx, req, 0, p.LatestPosition(ctx)) +} + +func (p *NotificationDataStreamProvider) IncrementalSync( + ctx context.Context, + req *types.SyncRequest, + from, to types.StreamPosition, +) types.StreamPosition { + // We want counts for all possible rooms, so always start from zero. + countsByRoom, err := p.DB.GetUserUnreadNotificationCounts(ctx, req.Device.UserID, from, to) + if err != nil { + req.Log.WithError(err).Error("GetUserUnreadNotificationCounts failed") + return from + } + + // We're merely decorating existing rooms. Note that the Join map + // values are not pointers. + for roomID, jr := range req.Response.Rooms.Join { + counts := countsByRoom[roomID] + if counts == nil { + continue + } + + jr.UnreadNotifications.HighlightCount = counts.UnreadHighlightCount + jr.UnreadNotifications.NotificationCount = counts.UnreadNotificationCount + req.Response.Rooms.Join[roomID] = jr + } + return to +} diff --git a/syncapi/streams/streams.go b/syncapi/streams/streams.go index c71095af6..17951acb4 100644 --- a/syncapi/streams/streams.go +++ b/syncapi/streams/streams.go @@ -12,13 +12,14 @@ import ( ) type Streams struct { - PDUStreamProvider types.StreamProvider - TypingStreamProvider types.StreamProvider - ReceiptStreamProvider types.StreamProvider - InviteStreamProvider types.StreamProvider - SendToDeviceStreamProvider types.StreamProvider - AccountDataStreamProvider types.StreamProvider - DeviceListStreamProvider types.StreamProvider + PDUStreamProvider types.StreamProvider + TypingStreamProvider types.StreamProvider + ReceiptStreamProvider types.StreamProvider + InviteStreamProvider types.StreamProvider + SendToDeviceStreamProvider types.StreamProvider + AccountDataStreamProvider types.StreamProvider + DeviceListStreamProvider types.StreamProvider + NotificationDataStreamProvider types.StreamProvider } func NewSyncStreamProviders( @@ -47,6 +48,9 @@ func NewSyncStreamProviders( StreamProvider: StreamProvider{DB: d}, userAPI: userAPI, }, + NotificationDataStreamProvider: &NotificationDataStreamProvider{ + StreamProvider: StreamProvider{DB: d}, + }, DeviceListStreamProvider: &DeviceListStreamProvider{ StreamProvider: StreamProvider{DB: d}, rsAPI: rsAPI, @@ -60,6 +64,7 @@ func NewSyncStreamProviders( streams.InviteStreamProvider.Setup() streams.SendToDeviceStreamProvider.Setup() streams.AccountDataStreamProvider.Setup() + streams.NotificationDataStreamProvider.Setup() streams.DeviceListStreamProvider.Setup() return streams @@ -67,12 +72,13 @@ func NewSyncStreamProviders( func (s *Streams) Latest(ctx context.Context) types.StreamingToken { return types.StreamingToken{ - PDUPosition: s.PDUStreamProvider.LatestPosition(ctx), - TypingPosition: s.TypingStreamProvider.LatestPosition(ctx), - ReceiptPosition: s.ReceiptStreamProvider.LatestPosition(ctx), - InvitePosition: s.InviteStreamProvider.LatestPosition(ctx), - SendToDevicePosition: s.SendToDeviceStreamProvider.LatestPosition(ctx), - AccountDataPosition: s.AccountDataStreamProvider.LatestPosition(ctx), - DeviceListPosition: s.DeviceListStreamProvider.LatestPosition(ctx), + PDUPosition: s.PDUStreamProvider.LatestPosition(ctx), + TypingPosition: s.TypingStreamProvider.LatestPosition(ctx), + ReceiptPosition: s.ReceiptStreamProvider.LatestPosition(ctx), + InvitePosition: s.InviteStreamProvider.LatestPosition(ctx), + SendToDevicePosition: s.SendToDeviceStreamProvider.LatestPosition(ctx), + AccountDataPosition: s.AccountDataStreamProvider.LatestPosition(ctx), + NotificationDataPosition: s.NotificationDataStreamProvider.LatestPosition(ctx), + DeviceListPosition: s.DeviceListStreamProvider.LatestPosition(ctx), } } diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index ca35951a0..2c9920d18 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -189,7 +189,7 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi. currentPos.ApplyUpdates(userStreamListener.GetSyncPosition()) } } else { - syncReq.Log.Debugln("Responding to sync immediately") + syncReq.Log.WithField("currentPos", currentPos).Debugln("Responding to sync immediately") } if syncReq.Since.IsEmpty() { @@ -213,6 +213,9 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi. AccountDataPosition: rp.streams.AccountDataStreamProvider.CompleteSync( syncReq.Context, syncReq, ), + NotificationDataPosition: rp.streams.NotificationDataStreamProvider.CompleteSync( + syncReq.Context, syncReq, + ), DeviceListPosition: rp.streams.DeviceListStreamProvider.CompleteSync( syncReq.Context, syncReq, ), @@ -244,6 +247,10 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi. syncReq.Context, syncReq, syncReq.Since.AccountDataPosition, currentPos.AccountDataPosition, ), + NotificationDataPosition: rp.streams.NotificationDataStreamProvider.IncrementalSync( + syncReq.Context, syncReq, + syncReq.Since.NotificationDataPosition, currentPos.NotificationDataPosition, + ), DeviceListPosition: rp.streams.DeviceListStreamProvider.IncrementalSync( syncReq.Context, syncReq, syncReq.Since.DeviceListPosition, currentPos.DeviceListPosition, diff --git a/syncapi/syncapi.go b/syncapi/syncapi.go index 72462459c..cb9890ff7 100644 --- a/syncapi/syncapi.go +++ b/syncapi/syncapi.go @@ -31,6 +31,7 @@ import ( "github.com/matrix-org/dendrite/syncapi/consumers" "github.com/matrix-org/dendrite/syncapi/notifier" + "github.com/matrix-org/dendrite/syncapi/producers" "github.com/matrix-org/dendrite/syncapi/routing" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/streams" @@ -64,6 +65,18 @@ func AddPublicRoutes( requestPool := sync.NewRequestPool(syncDB, cfg, userAPI, keyAPI, rsAPI, streams, notifier) + userAPIStreamEventProducer := &producers.UserAPIStreamEventProducer{ + JetStream: js, + Topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputStreamEvent), + } + + userAPIReadUpdateProducer := &producers.UserAPIReadProducer{ + JetStream: js, + Topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputReadUpdate), + } + + _ = userAPIReadUpdateProducer + keyChangeConsumer := consumers.NewOutputKeyChangeEventConsumer( process, cfg, cfg.Matrix.JetStream.TopicFor(jetstream.OutputKeyChangeEvent), js, keyAPI, rsAPI, syncDB, notifier, @@ -75,7 +88,7 @@ func AddPublicRoutes( roomConsumer := consumers.NewOutputRoomEventConsumer( process, cfg, js, syncDB, notifier, streams.PDUStreamProvider, - streams.InviteStreamProvider, rsAPI, + streams.InviteStreamProvider, rsAPI, userAPIStreamEventProducer, ) if err = roomConsumer.Start(); err != nil { logrus.WithError(err).Panicf("failed to start room server consumer") @@ -83,11 +96,19 @@ func AddPublicRoutes( clientConsumer := consumers.NewOutputClientDataConsumer( process, cfg, js, syncDB, notifier, streams.AccountDataStreamProvider, + userAPIReadUpdateProducer, ) if err = clientConsumer.Start(); err != nil { logrus.WithError(err).Panicf("failed to start client data consumer") } + notificationConsumer := consumers.NewOutputNotificationDataConsumer( + process, cfg, js, syncDB, notifier, streams.NotificationDataStreamProvider, + ) + if err = notificationConsumer.Start(); err != nil { + logrus.WithError(err).Panicf("failed to start notification data consumer") + } + typingConsumer := consumers.NewOutputTypingEventConsumer( process, cfg, js, syncDB, eduCache, notifier, streams.TypingStreamProvider, ) @@ -104,6 +125,7 @@ func AddPublicRoutes( receiptConsumer := consumers.NewOutputReceiptEventConsumer( process, cfg, js, syncDB, notifier, streams.ReceiptStreamProvider, + userAPIReadUpdateProducer, ) if err = receiptConsumer.Start(); err != nil { logrus.WithError(err).Panicf("failed to start receipts consumer") diff --git a/syncapi/types/types.go b/syncapi/types/types.go index c2e8ed01c..4150e6c98 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -95,13 +95,14 @@ const ( ) type StreamingToken struct { - PDUPosition StreamPosition - TypingPosition StreamPosition - ReceiptPosition StreamPosition - SendToDevicePosition StreamPosition - InvitePosition StreamPosition - AccountDataPosition StreamPosition - DeviceListPosition StreamPosition + PDUPosition StreamPosition + TypingPosition StreamPosition + ReceiptPosition StreamPosition + SendToDevicePosition StreamPosition + InvitePosition StreamPosition + AccountDataPosition StreamPosition + DeviceListPosition StreamPosition + NotificationDataPosition StreamPosition } // This will be used as a fallback by json.Marshal. @@ -117,10 +118,11 @@ func (s *StreamingToken) UnmarshalText(text []byte) (err error) { func (t StreamingToken) String() string { posStr := fmt.Sprintf( - "s%d_%d_%d_%d_%d_%d_%d", + "s%d_%d_%d_%d_%d_%d_%d_%d", t.PDUPosition, t.TypingPosition, t.ReceiptPosition, t.SendToDevicePosition, - t.InvitePosition, t.AccountDataPosition, t.DeviceListPosition, + t.InvitePosition, t.AccountDataPosition, + t.DeviceListPosition, t.NotificationDataPosition, ) return posStr } @@ -142,12 +144,14 @@ func (t *StreamingToken) IsAfter(other StreamingToken) bool { return true case t.DeviceListPosition > other.DeviceListPosition: return true + case t.NotificationDataPosition > other.NotificationDataPosition: + return true } return false } func (t *StreamingToken) IsEmpty() bool { - return t == nil || t.PDUPosition+t.TypingPosition+t.ReceiptPosition+t.SendToDevicePosition+t.InvitePosition+t.AccountDataPosition+t.DeviceListPosition == 0 + return t == nil || t.PDUPosition+t.TypingPosition+t.ReceiptPosition+t.SendToDevicePosition+t.InvitePosition+t.AccountDataPosition+t.DeviceListPosition+t.NotificationDataPosition == 0 } // WithUpdates returns a copy of the StreamingToken with updates applied from another StreamingToken. @@ -185,6 +189,9 @@ func (t *StreamingToken) ApplyUpdates(other StreamingToken) { if other.DeviceListPosition > t.DeviceListPosition { t.DeviceListPosition = other.DeviceListPosition } + if other.NotificationDataPosition > t.NotificationDataPosition { + t.NotificationDataPosition = other.NotificationDataPosition + } } type TopologyToken struct { @@ -277,7 +284,7 @@ func NewStreamTokenFromString(tok string) (token StreamingToken, err error) { // s478_0_0_0_0_13.dl-0-2 but we have now removed partitioned stream positions tok = strings.Split(tok, ".")[0] parts := strings.Split(tok[1:], "_") - var positions [7]StreamPosition + var positions [8]StreamPosition for i, p := range parts { if i >= len(positions) { break @@ -291,13 +298,14 @@ func NewStreamTokenFromString(tok string) (token StreamingToken, err error) { positions[i] = StreamPosition(pos) } token = StreamingToken{ - PDUPosition: positions[0], - TypingPosition: positions[1], - ReceiptPosition: positions[2], - SendToDevicePosition: positions[3], - InvitePosition: positions[4], - AccountDataPosition: positions[5], - DeviceListPosition: positions[6], + PDUPosition: positions[0], + TypingPosition: positions[1], + ReceiptPosition: positions[2], + SendToDevicePosition: positions[3], + InvitePosition: positions[4], + AccountDataPosition: positions[5], + DeviceListPosition: positions[6], + NotificationDataPosition: positions[7], } return token, nil } @@ -383,6 +391,10 @@ type JoinResponse struct { AccountData struct { Events []gomatrixserverlib.ClientEvent `json:"events"` } `json:"account_data"` + UnreadNotifications struct { + HighlightCount int `json:"highlight_count"` + NotificationCount int `json:"notification_count"` + } `json:"unread_notifications"` } // NewJoinResponse creates an empty response with initialised arrays. @@ -462,3 +474,16 @@ type Peek struct { New bool Deleted bool } + +type ReadUpdate struct { + UserID string `json:"user_id"` + RoomID string `json:"room_id"` + Read StreamPosition `json:"read,omitempty"` + FullyRead StreamPosition `json:"fully_read,omitempty"` +} + +// StreamEvent is the same as gomatrixserverlib.Event but also has the PDU stream position for this event. +type StreamedEvent struct { + Event *gomatrixserverlib.HeaderedEvent `json:"event"` + StreamPosition StreamPosition `json:"stream_position"` +} diff --git a/syncapi/types/types_test.go b/syncapi/types/types_test.go index cda178b37..ff78bfb9d 100644 --- a/syncapi/types/types_test.go +++ b/syncapi/types/types_test.go @@ -9,10 +9,10 @@ import ( func TestSyncTokens(t *testing.T) { shouldPass := map[string]string{ - "s4_0_0_0_0_0_0": StreamingToken{4, 0, 0, 0, 0, 0, 0}.String(), - "s3_1_0_0_0_0_2": StreamingToken{3, 1, 0, 0, 0, 0, 2}.String(), - "s3_1_2_3_5_0_0": StreamingToken{3, 1, 2, 3, 5, 0, 0}.String(), - "t3_1": TopologyToken{3, 1}.String(), + "s4_0_0_0_0_0_0_0": StreamingToken{4, 0, 0, 0, 0, 0, 0, 0}.String(), + "s3_1_0_0_0_0_2_0": StreamingToken{3, 1, 0, 0, 0, 0, 2, 0}.String(), + "s3_1_2_3_5_0_0_0": StreamingToken{3, 1, 2, 3, 5, 0, 0, 0}.String(), + "t3_1": TopologyToken{3, 1}.String(), } for a, b := range shouldPass { diff --git a/sytest-blacklist b/sytest-blacklist index e8617dcdf..7f518b21a 100644 --- a/sytest-blacklist +++ b/sytest-blacklist @@ -30,3 +30,9 @@ Local device key changes appear in /keys/changes Remove group category Remove group role +# Flakey +AS-ghosted users can use rooms themselves + +# Flakey, need additional investigation +Messages that notify from another user increment notification_count +Messages that highlight from another user increment unread highlight count diff --git a/sytest-whitelist b/sytest-whitelist index 3e38176f4..602f86465 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -339,17 +339,17 @@ Existing members see new members' join events Inbound federation can receive events Inbound federation can receive redacted events Can logout current device -Can send a message directly to a device using PUT /sendToDevice -Can recv a device message using /sync -Can recv device messages until they are acknowledged -Device messages with the same txn_id are deduplicated -Device messages wake up /sync -Can recv device messages over federation -Device messages over federation wake up /sync -Can send messages with a wildcard device id -Can send messages with a wildcard device id to two devices -Wildcard device messages wake up /sync -Wildcard device messages over federation wake up /sync +Can send a message directly to a device using PUT /sendToDevice +Can recv a device message using /sync +Can recv device messages until they are acknowledged +Device messages with the same txn_id are deduplicated +Device messages wake up /sync +Can recv device messages over federation +Device messages over federation wake up /sync +Can send messages with a wildcard device id +Can send messages with a wildcard device id to two devices +Wildcard device messages wake up /sync +Wildcard device messages over federation wake up /sync Can send a to-device message to two users which both receive it using /sync User can create and send/receive messages in a room with version 6 local user can join room with version 6 @@ -477,7 +477,7 @@ Federation key API can act as a notary server via a GET request Inbound /make_join rejects attempts to join rooms where all users have left Inbound federation rejects invites which include invalid JSON for room version 6 Inbound federation rejects invite rejections which include invalid JSON for room version 6 -GET /capabilities is present and well formed for registered user +GET /capabilities is present and well formed for registered user m.room.history_visibility == "joined" allows/forbids appropriately for Guest users m.room.history_visibility == "joined" allows/forbids appropriately for Real users POST rejects invalid utf-8 in JSON @@ -588,6 +588,59 @@ User can invite remote user to room with version 9 Remote user can backfill in a room with version 9 Can reject invites over federation for rooms with version 9 Can receive redactions from regular users over federation in room version 9 +Pushers created with a different access token are deleted on password change +Pushers created with a the same access token are not deleted on password change +Can fetch a user's pushers +Can add global push rule for room +Can add global push rule for sender +Can add global push rule for content +Can add global push rule for override +Can add global push rule for underride +Can add global push rule for content +New rules appear before old rules by default +Can add global push rule before an existing rule +Can add global push rule after an existing rule +Can delete a push rule +Can disable a push rule +Adding the same push rule twice is idempotent +Can change the actions of default rules +Can change the actions of a user specified rule +Adding a push rule wakes up an incremental /sync +Disabling a push rule wakes up an incremental /sync +Enabling a push rule wakes up an incremental /sync +Setting actions for a push rule wakes up an incremental /sync +Can enable/disable default rules +Trying to add push rule with missing template fails with 400 +Trying to add push rule with missing rule_id fails with 400 +Trying to add push rule with empty rule_id fails with 400 +Trying to add push rule with invalid template fails with 400 +Trying to add push rule with rule_id with slashes fails with 400 +Trying to add push rule with override rule without conditions fails with 400 +Trying to add push rule with underride rule without conditions fails with 400 +Trying to add push rule with condition without kind fails with 400 +Trying to add push rule with content rule without pattern fails with 400 +Trying to add push rule with no actions fails with 400 +Trying to add push rule with invalid action fails with 400 +Trying to add push rule with invalid attr fails with 400 +Trying to add push rule with invalid value for enabled fails with 400 +Trying to get push rules with no trailing slash fails with 400 +Trying to get push rules with scope without trailing slash fails with 400 +Trying to get push rules with template without tailing slash fails with 400 +Trying to get push rules with unknown scope fails with 400 +Trying to get push rules with unknown template fails with 400 +Trying to get push rules with unknown attribute fails with 400 +Getting push rules doesn't corrupt the cache SYN-390 +Test that a message is pushed +Invites are pushed +Rooms with names are correctly named in pushes +Rooms with canonical alias are correctly named in pushed +Rooms with many users are correctly pushed +Don't get pushed for rooms you've muted +Rejected events are not pushed +Test that rejected pushers are removed. +Notifications can be viewed with GET /notifications +Trying to add push rule with no scope fails with 400 +Trying to add push rule with invalid scope fails with 400 Forward extremities remain so even after the next events are populated as outliers If a device list update goes missing, the server resyncs on the next one uploading self-signing key notifies over federation @@ -607,4 +660,4 @@ registration accepts non-ascii passwords registration with inhibit_login inhibits login The operation must be consistent through an interactive authentication session Multiple calls to /sync should not cause 500 errors - +/context/ with lazy_load_members filter works diff --git a/userapi/api/api.go b/userapi/api/api.go index 2be662e55..e9cdbe01c 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -21,6 +21,7 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/internal/pushrules" ) // UserInternalAPI is the internal API for information about users and devices. @@ -28,6 +29,7 @@ type UserInternalAPI interface { LoginTokenInternalAPI InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error + PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error PerformPasswordUpdate(ctx context.Context, req *PerformPasswordUpdateRequest, res *PerformPasswordUpdateResponse) error PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error @@ -37,6 +39,10 @@ type UserInternalAPI interface { PerformAccountDeactivation(ctx context.Context, req *PerformAccountDeactivationRequest, res *PerformAccountDeactivationResponse) error PerformOpenIDTokenCreation(ctx context.Context, req *PerformOpenIDTokenCreationRequest, res *PerformOpenIDTokenCreationResponse) error PerformKeyBackup(ctx context.Context, req *PerformKeyBackupRequest, res *PerformKeyBackupResponse) error + PerformPusherSet(ctx context.Context, req *PerformPusherSetRequest, res *struct{}) error + PerformPusherDeletion(ctx context.Context, req *PerformPusherDeletionRequest, res *struct{}) error + PerformPushRulesPut(ctx context.Context, req *PerformPushRulesPutRequest, res *struct{}) error + QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse) QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error @@ -45,6 +51,9 @@ type UserInternalAPI interface { QueryDeviceInfos(ctx context.Context, req *QueryDeviceInfosRequest, res *QueryDeviceInfosResponse) error QuerySearchProfiles(ctx context.Context, req *QuerySearchProfilesRequest, res *QuerySearchProfilesResponse) error QueryOpenIDToken(ctx context.Context, req *QueryOpenIDTokenRequest, res *QueryOpenIDTokenResponse) error + QueryPushers(ctx context.Context, req *QueryPushersRequest, res *QueryPushersResponse) error + QueryPushRules(ctx context.Context, req *QueryPushRulesRequest, res *QueryPushRulesResponse) error + QueryNotifications(ctx context.Context, req *QueryNotificationsRequest, res *QueryNotificationsResponse) error } type PerformKeyBackupRequest struct { @@ -424,3 +433,77 @@ const ( // AccountTypeAppService indicates this is an appservice account AccountTypeAppService AccountType = 4 ) + +type QueryPushersRequest struct { + Localpart string +} + +type QueryPushersResponse struct { + Pushers []Pusher `json:"pushers"` +} + +type PerformPusherSetRequest struct { + Pusher // Anonymous field because that's how clientapi unmarshals it. + Localpart string + Append bool `json:"append"` +} + +type PerformPusherDeletionRequest struct { + Localpart string + SessionID int64 +} + +// Pusher represents a push notification subscriber +type Pusher struct { + SessionID int64 `json:"session_id,omitempty"` + PushKey string `json:"pushkey"` + PushKeyTS gomatrixserverlib.Timestamp `json:"pushkey_ts,omitempty"` + Kind PusherKind `json:"kind"` + AppID string `json:"app_id"` + AppDisplayName string `json:"app_display_name"` + DeviceDisplayName string `json:"device_display_name"` + ProfileTag string `json:"profile_tag"` + Language string `json:"lang"` + Data map[string]interface{} `json:"data"` +} + +type PusherKind string + +const ( + EmailKind PusherKind = "email" + HTTPKind PusherKind = "http" +) + +type PerformPushRulesPutRequest struct { + UserID string `json:"user_id"` + RuleSets *pushrules.AccountRuleSets `json:"rule_sets"` +} + +type QueryPushRulesRequest struct { + UserID string `json:"user_id"` +} + +type QueryPushRulesResponse struct { + RuleSets *pushrules.AccountRuleSets `json:"rule_sets"` +} + +type QueryNotificationsRequest struct { + Localpart string `json:"localpart"` // Required. + From string `json:"from,omitempty"` + Limit int `json:"limit,omitempty"` + Only string `json:"only,omitempty"` +} + +type QueryNotificationsResponse struct { + NextToken string `json:"next_token"` + Notifications []*Notification `json:"notifications"` // Required. +} + +type Notification struct { + Actions []*pushrules.Action `json:"actions"` // Required. + Event gomatrixserverlib.ClientEvent `json:"event"` // Required. + ProfileTag string `json:"profile_tag"` // Required by Sytest, but actually optional. + Read bool `json:"read"` // Required. + RoomID string `json:"room_id"` // Required. + TS gomatrixserverlib.Timestamp `json:"ts"` // Required. +} diff --git a/userapi/api/api_trace.go b/userapi/api/api_trace.go index aa069f40b..9334f4455 100644 --- a/userapi/api/api_trace.go +++ b/userapi/api/api_trace.go @@ -79,6 +79,21 @@ func (t *UserInternalAPITrace) PerformKeyBackup(ctx context.Context, req *Perfor util.GetLogger(ctx).Infof("PerformKeyBackup req=%+v res=%+v", js(req), js(res)) return err } +func (t *UserInternalAPITrace) PerformPusherSet(ctx context.Context, req *PerformPusherSetRequest, res *struct{}) error { + err := t.Impl.PerformPusherSet(ctx, req, res) + util.GetLogger(ctx).Infof("PerformPusherSet req=%+v res=%+v", js(req), js(res)) + return err +} +func (t *UserInternalAPITrace) PerformPusherDeletion(ctx context.Context, req *PerformPusherDeletionRequest, res *struct{}) error { + err := t.Impl.PerformPusherDeletion(ctx, req, res) + util.GetLogger(ctx).Infof("PerformPusherDeletion req=%+v res=%+v", js(req), js(res)) + return err +} +func (t *UserInternalAPITrace) PerformPushRulesPut(ctx context.Context, req *PerformPushRulesPutRequest, res *struct{}) error { + err := t.Impl.PerformPushRulesPut(ctx, req, res) + util.GetLogger(ctx).Infof("PerformPushRulesPut req=%+v res=%+v", js(req), js(res)) + return err +} func (t *UserInternalAPITrace) QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse) { t.Impl.QueryKeyBackup(ctx, req, res) util.GetLogger(ctx).Infof("QueryKeyBackup req=%+v res=%+v", js(req), js(res)) @@ -118,6 +133,21 @@ func (t *UserInternalAPITrace) QueryOpenIDToken(ctx context.Context, req *QueryO util.GetLogger(ctx).Infof("QueryOpenIDToken req=%+v res=%+v", js(req), js(res)) return err } +func (t *UserInternalAPITrace) QueryPushers(ctx context.Context, req *QueryPushersRequest, res *QueryPushersResponse) error { + err := t.Impl.QueryPushers(ctx, req, res) + util.GetLogger(ctx).Infof("QueryPushers req=%+v res=%+v", js(req), js(res)) + return err +} +func (t *UserInternalAPITrace) QueryPushRules(ctx context.Context, req *QueryPushRulesRequest, res *QueryPushRulesResponse) error { + err := t.Impl.QueryPushRules(ctx, req, res) + util.GetLogger(ctx).Infof("QueryPushRules req=%+v res=%+v", js(req), js(res)) + return err +} +func (t *UserInternalAPITrace) QueryNotifications(ctx context.Context, req *QueryNotificationsRequest, res *QueryNotificationsResponse) error { + err := t.Impl.QueryNotifications(ctx, req, res) + util.GetLogger(ctx).Infof("QueryNotifications req=%+v res=%+v", js(req), js(res)) + return err +} func js(thing interface{}) string { b, err := json.Marshal(thing) diff --git a/userapi/consumers/syncapi_readupdate.go b/userapi/consumers/syncapi_readupdate.go new file mode 100644 index 000000000..2e58020b4 --- /dev/null +++ b/userapi/consumers/syncapi_readupdate.go @@ -0,0 +1,136 @@ +package consumers + +import ( + "context" + "encoding/json" + + "github.com/matrix-org/dendrite/internal/pushgateway" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/syncapi/types" + uapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/producers" + "github.com/matrix-org/dendrite/userapi/storage" + "github.com/matrix-org/dendrite/userapi/util" + "github.com/matrix-org/gomatrixserverlib" + "github.com/nats-io/nats.go" + log "github.com/sirupsen/logrus" +) + +type OutputReadUpdateConsumer struct { + ctx context.Context + cfg *config.UserAPI + jetstream nats.JetStreamContext + durable string + db storage.Database + pgClient pushgateway.Client + ServerName gomatrixserverlib.ServerName + topic string + userAPI uapi.UserInternalAPI + syncProducer *producers.SyncAPI +} + +func NewOutputReadUpdateConsumer( + process *process.ProcessContext, + cfg *config.UserAPI, + js nats.JetStreamContext, + store storage.Database, + pgClient pushgateway.Client, + userAPI uapi.UserInternalAPI, + syncProducer *producers.SyncAPI, +) *OutputReadUpdateConsumer { + return &OutputReadUpdateConsumer{ + ctx: process.Context(), + cfg: cfg, + jetstream: js, + db: store, + ServerName: cfg.Matrix.ServerName, + durable: cfg.Matrix.JetStream.Durable("UserAPISyncAPIReadUpdateConsumer"), + topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputReadUpdate), + pgClient: pgClient, + userAPI: userAPI, + syncProducer: syncProducer, + } +} + +func (s *OutputReadUpdateConsumer) Start() error { + if err := jetstream.JetStreamConsumer( + s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, + nats.DeliverAll(), nats.ManualAck(), + ); err != nil { + return err + } + return nil +} + +func (s *OutputReadUpdateConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { + var read types.ReadUpdate + if err := json.Unmarshal(msg.Data, &read); err != nil { + log.WithError(err).Error("userapi clientapi consumer: message parse failure") + return true + } + if read.FullyRead == 0 && read.Read == 0 { + return true + } + + userID := string(msg.Header.Get(jetstream.UserID)) + roomID := string(msg.Header.Get(jetstream.RoomID)) + + localpart, domain, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + log.WithError(err).Error("userapi clientapi consumer: SplitID failure") + return true + } + if domain != s.ServerName { + log.Error("userapi clientapi consumer: not a local user") + return true + } + + log := log.WithFields(log.Fields{ + "room_id": roomID, + "user_id": userID, + }) + log.Tracef("Received read update from sync API: %#v", read) + + if read.Read > 0 { + updated, err := s.db.SetNotificationsRead(ctx, localpart, roomID, int64(read.Read), true) + if err != nil { + log.WithError(err).Error("userapi EDU consumer") + return false + } + + if updated { + if err = s.syncProducer.GetAndSendNotificationData(ctx, userID, roomID); err != nil { + log.WithError(err).Error("userapi EDU consumer: GetAndSendNotificationData failed") + return false + } + if err = util.NotifyUserCountsAsync(ctx, s.pgClient, localpart, s.db); err != nil { + log.WithError(err).Error("userapi EDU consumer: NotifyUserCounts failed") + return false + } + } + } + + if read.FullyRead > 0 { + deleted, err := s.db.DeleteNotificationsUpTo(ctx, localpart, roomID, int64(read.FullyRead)) + if err != nil { + log.WithError(err).Errorf("userapi clientapi consumer: DeleteNotificationsUpTo failed") + return false + } + + if deleted { + if err := util.NotifyUserCountsAsync(ctx, s.pgClient, localpart, s.db); err != nil { + log.WithError(err).Error("userapi clientapi consumer: NotifyUserCounts failed") + return false + } + + if err := s.syncProducer.GetAndSendNotificationData(ctx, userID, read.RoomID); err != nil { + log.WithError(err).Errorf("userapi clientapi consumer: GetAndSendNotificationData failed") + return false + } + } + } + + return true +} diff --git a/userapi/consumers/syncapi_streamevent.go b/userapi/consumers/syncapi_streamevent.go new file mode 100644 index 000000000..110813274 --- /dev/null +++ b/userapi/consumers/syncapi_streamevent.go @@ -0,0 +1,588 @@ +package consumers + +import ( + "context" + "encoding/json" + "fmt" + "strings" + "time" + + "github.com/matrix-org/dendrite/internal/eventutil" + "github.com/matrix-org/dendrite/internal/pushgateway" + "github.com/matrix-org/dendrite/internal/pushrules" + rsapi "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/producers" + "github.com/matrix-org/dendrite/userapi/storage" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/dendrite/userapi/util" + "github.com/matrix-org/gomatrixserverlib" + "github.com/nats-io/nats.go" + log "github.com/sirupsen/logrus" +) + +type OutputStreamEventConsumer struct { + ctx context.Context + cfg *config.UserAPI + userAPI api.UserInternalAPI + rsAPI rsapi.RoomserverInternalAPI + jetstream nats.JetStreamContext + durable string + db storage.Database + topic string + pgClient pushgateway.Client + syncProducer *producers.SyncAPI +} + +func NewOutputStreamEventConsumer( + process *process.ProcessContext, + cfg *config.UserAPI, + js nats.JetStreamContext, + store storage.Database, + pgClient pushgateway.Client, + userAPI api.UserInternalAPI, + rsAPI rsapi.RoomserverInternalAPI, + syncProducer *producers.SyncAPI, +) *OutputStreamEventConsumer { + return &OutputStreamEventConsumer{ + ctx: process.Context(), + cfg: cfg, + jetstream: js, + db: store, + durable: cfg.Matrix.JetStream.Durable("UserAPISyncAPIStreamEventConsumer"), + topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputStreamEvent), + pgClient: pgClient, + userAPI: userAPI, + rsAPI: rsAPI, + syncProducer: syncProducer, + } +} + +func (s *OutputStreamEventConsumer) Start() error { + if err := jetstream.JetStreamConsumer( + s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, + nats.DeliverAll(), nats.ManualAck(), + ); err != nil { + return err + } + return nil +} + +func (s *OutputStreamEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { + var output types.StreamedEvent + output.Event = &gomatrixserverlib.HeaderedEvent{} + if err := json.Unmarshal(msg.Data, &output); err != nil { + log.WithError(err).Errorf("userapi consumer: message parse failure") + return true + } + if output.Event.Event == nil { + log.Errorf("userapi consumer: expected event") + return true + } + + log.WithFields(log.Fields{ + "event_id": output.Event.EventID(), + "event_type": output.Event.Type(), + "stream_pos": output.StreamPosition, + }).Tracef("Received message from sync API: %#v", output) + + if err := s.processMessage(ctx, output.Event, int64(output.StreamPosition)); err != nil { + log.WithFields(log.Fields{ + "event_id": output.Event.EventID(), + }).WithError(err).Errorf("userapi consumer: process room event failure") + } + + return true +} + +func (s *OutputStreamEventConsumer) processMessage(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, pos int64) error { + members, roomSize, err := s.localRoomMembers(ctx, event.RoomID()) + if err != nil { + return fmt.Errorf("s.localRoomMembers: %w", err) + } + + if event.Type() == gomatrixserverlib.MRoomMember { + cevent := gomatrixserverlib.HeaderedToClientEvent(event, gomatrixserverlib.FormatAll) + var member *localMembership + member, err = newLocalMembership(&cevent) + if err != nil { + return fmt.Errorf("newLocalMembership: %w", err) + } + if member.Membership == gomatrixserverlib.Invite && member.Domain == s.cfg.Matrix.ServerName { + // localRoomMembers only adds joined members. An invite + // should also be pushed to the target user. + members = append(members, member) + } + } + + // TODO: run in parallel with localRoomMembers. + roomName, err := s.roomName(ctx, event) + if err != nil { + return fmt.Errorf("s.roomName: %w", err) + } + + log.WithFields(log.Fields{ + "event_id": event.EventID(), + "room_id": event.RoomID(), + "num_members": len(members), + "room_size": roomSize, + }).Tracef("Notifying members") + + // Notification.UserIsTarget is a per-member field, so we + // cannot group all users in a single request. + // + // TODO: does it have to be set? It's not required, and + // removing it means we can send all notifications to + // e.g. Element's Push gateway in one go. + for _, mem := range members { + if err := s.notifyLocal(ctx, event, pos, mem, roomSize, roomName); err != nil { + log.WithFields(log.Fields{ + "localpart": mem.Localpart, + }).WithError(err).Debugf("Unable to push to local user") + continue + } + } + + return nil +} + +type localMembership struct { + gomatrixserverlib.MemberContent + UserID string + Localpart string + Domain gomatrixserverlib.ServerName +} + +func newLocalMembership(event *gomatrixserverlib.ClientEvent) (*localMembership, error) { + if event.StateKey == nil { + return nil, fmt.Errorf("missing state_key") + } + + var member localMembership + if err := json.Unmarshal(event.Content, &member.MemberContent); err != nil { + return nil, err + } + + localpart, domain, err := gomatrixserverlib.SplitID('@', *event.StateKey) + if err != nil { + return nil, err + } + + member.UserID = *event.StateKey + member.Localpart = localpart + member.Domain = domain + return &member, nil +} + +// localRoomMembers fetches the current local members of a room, and +// the total number of members. +func (s *OutputStreamEventConsumer) localRoomMembers(ctx context.Context, roomID string) ([]*localMembership, int, error) { + req := &rsapi.QueryMembershipsForRoomRequest{ + RoomID: roomID, + JoinedOnly: true, + } + var res rsapi.QueryMembershipsForRoomResponse + + // XXX: This could potentially race if the state for the event is not known yet + // e.g. the event came over federation but we do not have the full state persisted. + if err := s.rsAPI.QueryMembershipsForRoom(ctx, req, &res); err != nil { + return nil, 0, err + } + + var members []*localMembership + var ntotal int + for _, event := range res.JoinEvents { + member, err := newLocalMembership(&event) + if err != nil { + log.WithError(err).Errorf("Parsing MemberContent") + continue + } + if member.Membership != gomatrixserverlib.Join { + continue + } + if member.Domain != s.cfg.Matrix.ServerName { + continue + } + + ntotal++ + members = append(members, member) + } + + return members, ntotal, nil +} + +// roomName returns the name in the event (if type==m.room.name), or +// looks it up in roomserver. If there is no name, +// m.room.canonical_alias is consulted. Returns an empty string if the +// room has no name. +func (s *OutputStreamEventConsumer) roomName(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) (string, error) { + if event.Type() == gomatrixserverlib.MRoomName { + name, err := unmarshalRoomName(event) + if err != nil { + return "", err + } + + if name != "" { + return name, nil + } + } + + req := &rsapi.QueryCurrentStateRequest{ + RoomID: event.RoomID(), + StateTuples: []gomatrixserverlib.StateKeyTuple{roomNameTuple, canonicalAliasTuple}, + } + var res rsapi.QueryCurrentStateResponse + + if err := s.rsAPI.QueryCurrentState(ctx, req, &res); err != nil { + return "", nil + } + + if eventS := res.StateEvents[roomNameTuple]; eventS != nil { + return unmarshalRoomName(eventS) + } + + if event.Type() == gomatrixserverlib.MRoomCanonicalAlias { + alias, err := unmarshalCanonicalAlias(event) + if err != nil { + return "", err + } + + if alias != "" { + return alias, nil + } + } + + if event = res.StateEvents[canonicalAliasTuple]; event != nil { + return unmarshalCanonicalAlias(event) + } + + return "", nil +} + +var ( + canonicalAliasTuple = gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomCanonicalAlias} + roomNameTuple = gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomName} +) + +func unmarshalRoomName(event *gomatrixserverlib.HeaderedEvent) (string, error) { + var nc eventutil.NameContent + if err := json.Unmarshal(event.Content(), &nc); err != nil { + return "", fmt.Errorf("unmarshaling NameContent: %w", err) + } + + return nc.Name, nil +} + +func unmarshalCanonicalAlias(event *gomatrixserverlib.HeaderedEvent) (string, error) { + var cac eventutil.CanonicalAliasContent + if err := json.Unmarshal(event.Content(), &cac); err != nil { + return "", fmt.Errorf("unmarshaling CanonicalAliasContent: %w", err) + } + + return cac.Alias, nil +} + +// notifyLocal finds the right push actions for a local user, given an event. +func (s *OutputStreamEventConsumer) notifyLocal(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, pos int64, mem *localMembership, roomSize int, roomName string) error { + actions, err := s.evaluatePushRules(ctx, event, mem, roomSize) + if err != nil { + return err + } + a, tweaks, err := pushrules.ActionsToTweaks(actions) + if err != nil { + return err + } + // TODO: support coalescing. + if a != pushrules.NotifyAction && a != pushrules.CoalesceAction { + log.WithFields(log.Fields{ + "event_id": event.EventID(), + "room_id": event.RoomID(), + "localpart": mem.Localpart, + }).Tracef("Push rule evaluation rejected the event") + return nil + } + + devicesByURLAndFormat, profileTag, err := s.localPushDevices(ctx, mem.Localpart, tweaks) + if err != nil { + return err + } + + n := &api.Notification{ + Actions: actions, + // UNSPEC: the spec doesn't say this is a ClientEvent, but the + // fields seem to match. room_id should be missing, which + // matches the behaviour of FormatSync. + Event: gomatrixserverlib.HeaderedToClientEvent(event, gomatrixserverlib.FormatSync), + // TODO: this is per-device, but it's not part of the primary + // key. So inserting one notification per profile tag doesn't + // make sense. What is this supposed to be? Sytests require it + // to "work", but they only use a single device. + ProfileTag: profileTag, + RoomID: event.RoomID(), + TS: gomatrixserverlib.AsTimestamp(time.Now()), + } + if err = s.db.InsertNotification(ctx, mem.Localpart, event.EventID(), pos, tweaks, n); err != nil { + return err + } + + if err = s.syncProducer.GetAndSendNotificationData(ctx, mem.UserID, event.RoomID()); err != nil { + return err + } + + // We do this after InsertNotification. Thus, this should always return >=1. + userNumUnreadNotifs, err := s.db.GetNotificationCount(ctx, mem.Localpart, tables.AllNotifications) + if err != nil { + return err + } + + log.WithFields(log.Fields{ + "event_id": event.EventID(), + "room_id": event.RoomID(), + "localpart": mem.Localpart, + "num_urls": len(devicesByURLAndFormat), + "num_unread": userNumUnreadNotifs, + }).Tracef("Notifying single member") + + // Push gateways are out of our control, and we cannot risk + // looking up the server on a misbehaving push gateway. Each user + // receives a goroutine now that all internal API calls have been + // made. + // + // TODO: think about bounding this to one per user, and what + // ordering guarantees we must provide. + go func() { + // This background processing cannot be tied to a request. + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + var rejected []*pushgateway.Device + for url, fmts := range devicesByURLAndFormat { + for format, devices := range fmts { + // TODO: support "email". + if !strings.HasPrefix(url, "http") { + continue + } + + // UNSPEC: the specification suggests there can be + // more than one device per request. There is at least + // one Sytest that expects one HTTP request per + // device, rather than per URL. For now, we must + // notify each one separately. + for _, dev := range devices { + rej, err := s.notifyHTTP(ctx, event, url, format, []*pushgateway.Device{dev}, mem.Localpart, roomName, int(userNumUnreadNotifs)) + if err != nil { + log.WithFields(log.Fields{ + "event_id": event.EventID(), + "localpart": mem.Localpart, + }).WithError(err).Errorf("Unable to notify HTTP pusher") + continue + } + rejected = append(rejected, rej...) + } + } + } + + if len(rejected) > 0 { + s.deleteRejectedPushers(ctx, rejected, mem.Localpart) + } + }() + + return nil +} + +// evaluatePushRules fetches and evaluates the push rules of a local +// user. Returns actions (including dont_notify). +func (s *OutputStreamEventConsumer) evaluatePushRules(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, mem *localMembership, roomSize int) ([]*pushrules.Action, error) { + if event.Sender() == mem.UserID { + // SPEC: Homeservers MUST NOT notify the Push Gateway for + // events that the user has sent themselves. + return nil, nil + } + + var res api.QueryPushRulesResponse + if err := s.userAPI.QueryPushRules(ctx, &api.QueryPushRulesRequest{UserID: mem.UserID}, &res); err != nil { + return nil, err + } + + ec := &ruleSetEvalContext{ + ctx: ctx, + rsAPI: s.rsAPI, + mem: mem, + roomID: event.RoomID(), + roomSize: roomSize, + } + eval := pushrules.NewRuleSetEvaluator(ec, &res.RuleSets.Global) + rule, err := eval.MatchEvent(event.Event) + if err != nil { + return nil, err + } + if rule == nil { + // SPEC: If no rules match an event, the homeserver MUST NOT + // notify the Push Gateway for that event. + return nil, err + } + + log.WithFields(log.Fields{ + "event_id": event.EventID(), + "room_id": event.RoomID(), + "localpart": mem.Localpart, + "rule_id": rule.RuleID, + }).Tracef("Matched a push rule") + + return rule.Actions, nil +} + +type ruleSetEvalContext struct { + ctx context.Context + rsAPI rsapi.RoomserverInternalAPI + mem *localMembership + roomID string + roomSize int +} + +func (rse *ruleSetEvalContext) UserDisplayName() string { return rse.mem.DisplayName } + +func (rse *ruleSetEvalContext) RoomMemberCount() (int, error) { return rse.roomSize, nil } + +func (rse *ruleSetEvalContext) HasPowerLevel(userID, levelKey string) (bool, error) { + req := &rsapi.QueryLatestEventsAndStateRequest{ + RoomID: rse.roomID, + StateToFetch: []gomatrixserverlib.StateKeyTuple{ + {EventType: gomatrixserverlib.MRoomPowerLevels}, + }, + } + var res rsapi.QueryLatestEventsAndStateResponse + if err := rse.rsAPI.QueryLatestEventsAndState(rse.ctx, req, &res); err != nil { + return false, err + } + for _, ev := range res.StateEvents { + if ev.Type() != gomatrixserverlib.MRoomPowerLevels { + continue + } + + plc, err := gomatrixserverlib.NewPowerLevelContentFromEvent(ev.Event) + if err != nil { + return false, err + } + return plc.UserLevel(userID) >= plc.NotificationLevel(levelKey), nil + } + return true, nil +} + +// localPushDevices pushes to the configured devices of a local +// user. The map keys are [url][format]. +func (s *OutputStreamEventConsumer) localPushDevices(ctx context.Context, localpart string, tweaks map[string]interface{}) (map[string]map[string][]*pushgateway.Device, string, error) { + pusherDevices, err := util.GetPushDevices(ctx, localpart, tweaks, s.db) + if err != nil { + return nil, "", err + } + + var profileTag string + devicesByURL := make(map[string]map[string][]*pushgateway.Device, len(pusherDevices)) + for _, pusherDevice := range pusherDevices { + if profileTag == "" { + profileTag = pusherDevice.Pusher.ProfileTag + } + + url := pusherDevice.URL + if devicesByURL[url] == nil { + devicesByURL[url] = make(map[string][]*pushgateway.Device, 2) + } + devicesByURL[url][pusherDevice.Format] = append(devicesByURL[url][pusherDevice.Format], &pusherDevice.Device) + } + + return devicesByURL, profileTag, nil +} + +// notifyHTTP performs a notificatation to a Push Gateway. +func (s *OutputStreamEventConsumer) notifyHTTP(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, url, format string, devices []*pushgateway.Device, localpart, roomName string, userNumUnreadNotifs int) ([]*pushgateway.Device, error) { + logger := log.WithFields(log.Fields{ + "event_id": event.EventID(), + "url": url, + "localpart": localpart, + "num_devices": len(devices), + }) + + var req pushgateway.NotifyRequest + switch format { + case "event_id_only": + req = pushgateway.NotifyRequest{ + Notification: pushgateway.Notification{ + Counts: &pushgateway.Counts{}, + Devices: devices, + EventID: event.EventID(), + RoomID: event.RoomID(), + }, + } + + default: + req = pushgateway.NotifyRequest{ + Notification: pushgateway.Notification{ + Content: event.Content(), + Counts: &pushgateway.Counts{ + Unread: userNumUnreadNotifs, + }, + Devices: devices, + EventID: event.EventID(), + ID: event.EventID(), + RoomID: event.RoomID(), + RoomName: roomName, + Sender: event.Sender(), + Type: event.Type(), + }, + } + if mem, err := event.Membership(); err == nil { + req.Notification.Membership = mem + } + if event.StateKey() != nil && *event.StateKey() == fmt.Sprintf("@%s:%s", localpart, s.cfg.Matrix.ServerName) { + req.Notification.UserIsTarget = true + } + } + + logger.Debugf("Notifying push gateway %s", url) + var res pushgateway.NotifyResponse + if err := s.pgClient.Notify(ctx, url, &req, &res); err != nil { + logger.WithError(err).Errorf("Failed to notify push gateway %s", url) + return nil, err + } + logger.WithField("num_rejected", len(res.Rejected)).Tracef("Push gateway result") + + if len(res.Rejected) == 0 { + return nil, nil + } + + devMap := make(map[string]*pushgateway.Device, len(devices)) + for _, d := range devices { + devMap[d.PushKey] = d + } + rejected := make([]*pushgateway.Device, 0, len(res.Rejected)) + for _, pushKey := range res.Rejected { + d := devMap[pushKey] + if d != nil { + rejected = append(rejected, d) + } + } + + return rejected, nil +} + +// deleteRejectedPushers deletes the pushers associated with the given devices. +func (s *OutputStreamEventConsumer) deleteRejectedPushers(ctx context.Context, devices []*pushgateway.Device, localpart string) { + log.WithFields(log.Fields{ + "localpart": localpart, + "app_id0": devices[0].AppID, + "num_devices": len(devices), + }).Warnf("Deleting pushers rejected by the HTTP push gateway") + + for _, d := range devices { + if err := s.db.RemovePusher(ctx, d.AppID, d.PushKey, localpart); err != nil { + log.WithFields(log.Fields{ + "localpart": localpart, + }).WithError(err).Errorf("Unable to delete rejected pusher") + } + } +} diff --git a/userapi/internal/api.go b/userapi/internal/api.go index f54cc6137..7a42fc605 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -20,6 +20,8 @@ import ( "encoding/json" "errors" "fmt" + "strconv" + "time" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -27,16 +29,22 @@ import ( "github.com/matrix-org/dendrite/appservice/types" "github.com/matrix-org/dendrite/clientapi/userutil" + "github.com/matrix-org/dendrite/internal/pushrules" "github.com/matrix-org/dendrite/internal/sqlutil" keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/producers" "github.com/matrix-org/dendrite/userapi/storage" + "github.com/matrix-org/dendrite/userapi/storage/tables" ) type UserInternalAPI struct { - DB storage.Database - ServerName gomatrixserverlib.ServerName + DB storage.Database + SyncProducer *producers.SyncAPI + + DisableTLSValidation bool + ServerName gomatrixserverlib.ServerName // AppServices is the list of all registered AS AppServices []config.ApplicationService KeyAPI keyapi.KeyInternalAPI @@ -595,3 +603,162 @@ func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyB } res.Keys = result } + +func (a *UserInternalAPI) QueryNotifications(ctx context.Context, req *api.QueryNotificationsRequest, res *api.QueryNotificationsResponse) error { + if req.Limit == 0 || req.Limit > 1000 { + req.Limit = 1000 + } + + var fromID int64 + var err error + if req.From != "" { + fromID, err = strconv.ParseInt(req.From, 10, 64) + if err != nil { + return fmt.Errorf("QueryNotifications: parsing 'from': %w", err) + } + } + var filter tables.NotificationFilter = tables.AllNotifications + if req.Only == "highlight" { + filter = tables.HighlightNotifications + } + notifs, lastID, err := a.DB.GetNotifications(ctx, req.Localpart, fromID, req.Limit, filter) + if err != nil { + return err + } + if notifs == nil { + // This ensures empty is JSON-encoded as [] instead of null. + notifs = []*api.Notification{} + } + res.Notifications = notifs + if lastID >= 0 { + res.NextToken = strconv.FormatInt(lastID+1, 10) + } + return nil +} + +func (a *UserInternalAPI) PerformPusherSet(ctx context.Context, req *api.PerformPusherSetRequest, res *struct{}) error { + util.GetLogger(ctx).WithFields(logrus.Fields{ + "localpart": req.Localpart, + "pushkey": req.Pusher.PushKey, + "display_name": req.Pusher.AppDisplayName, + }).Info("PerformPusherCreation") + if !req.Append { + err := a.DB.RemovePushers(ctx, req.Pusher.AppID, req.Pusher.PushKey) + if err != nil { + return err + } + } + if req.Pusher.Kind == "" { + return a.DB.RemovePusher(ctx, req.Pusher.AppID, req.Pusher.PushKey, req.Localpart) + } + if req.Pusher.PushKeyTS == 0 { + req.Pusher.PushKeyTS = gomatrixserverlib.AsTimestamp(time.Now()) + } + return a.DB.UpsertPusher(ctx, req.Pusher, req.Localpart) +} + +func (a *UserInternalAPI) PerformPusherDeletion(ctx context.Context, req *api.PerformPusherDeletionRequest, res *struct{}) error { + pushers, err := a.DB.GetPushers(ctx, req.Localpart) + if err != nil { + return err + } + for i := range pushers { + logrus.Warnf("pusher session: %d, req session: %d", pushers[i].SessionID, req.SessionID) + if pushers[i].SessionID != req.SessionID { + err := a.DB.RemovePusher(ctx, pushers[i].AppID, pushers[i].PushKey, req.Localpart) + if err != nil { + return err + } + } + } + return nil +} + +func (a *UserInternalAPI) QueryPushers(ctx context.Context, req *api.QueryPushersRequest, res *api.QueryPushersResponse) error { + var err error + res.Pushers, err = a.DB.GetPushers(ctx, req.Localpart) + return err +} + +func (a *UserInternalAPI) PerformPushRulesPut( + ctx context.Context, + req *api.PerformPushRulesPutRequest, + _ *struct{}, +) error { + bs, err := json.Marshal(&req.RuleSets) + if err != nil { + return err + } + userReq := api.InputAccountDataRequest{ + UserID: req.UserID, + DataType: pushRulesAccountDataType, + AccountData: json.RawMessage(bs), + } + var userRes api.InputAccountDataResponse // empty + if err := a.InputAccountData(ctx, &userReq, &userRes); err != nil { + return err + } + + if err := a.SyncProducer.SendAccountData(req.UserID, "" /* roomID */, pushRulesAccountDataType); err != nil { + util.GetLogger(ctx).WithError(err).Errorf("syncProducer.SendData failed") + } + + return nil +} + +func (a *UserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPushRulesRequest, res *api.QueryPushRulesResponse) error { + userReq := api.QueryAccountDataRequest{ + UserID: req.UserID, + DataType: pushRulesAccountDataType, + } + var userRes api.QueryAccountDataResponse + if err := a.QueryAccountData(ctx, &userReq, &userRes); err != nil { + return err + } + bs, ok := userRes.GlobalAccountData[pushRulesAccountDataType] + if ok { + // Legacy Dendrite users will have completely empty push rules, so we should + // detect that situation and set some defaults. + var rules struct { + G struct { + Content []json.RawMessage `json:"content"` + Override []json.RawMessage `json:"override"` + Room []json.RawMessage `json:"room"` + Sender []json.RawMessage `json:"sender"` + Underride []json.RawMessage `json:"underride"` + } `json:"global"` + } + if err := json.Unmarshal([]byte(bs), &rules); err == nil { + count := len(rules.G.Content) + len(rules.G.Override) + + len(rules.G.Room) + len(rules.G.Sender) + len(rules.G.Underride) + ok = count > 0 + } + } + if !ok { + // If we didn't find any default push rules then we should just generate some + // fresh ones. + localpart, _, err := gomatrixserverlib.SplitID('@', req.UserID) + if err != nil { + return fmt.Errorf("failed to split user ID %q for push rules", req.UserID) + } + pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, a.ServerName) + prbs, err := json.Marshal(pushRuleSets) + if err != nil { + return fmt.Errorf("failed to marshal default push rules: %w", err) + } + if err := a.DB.SaveAccountData(ctx, localpart, "", pushRulesAccountDataType, json.RawMessage(prbs)); err != nil { + return fmt.Errorf("failed to save default push rules: %w", err) + } + res.RuleSets = pushRuleSets + return nil + } + var data pushrules.AccountRuleSets + if err := json.Unmarshal([]byte(bs), &data); err != nil { + util.GetLogger(ctx).WithError(err).Error("json.Unmarshal of push rules failed") + return err + } + res.RuleSets = &data + return nil +} + +const pushRulesAccountDataType = "m.push_rules" diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go index 1599d4639..8ec649ad0 100644 --- a/userapi/inthttp/client.go +++ b/userapi/inthttp/client.go @@ -37,6 +37,9 @@ const ( PerformAccountDeactivationPath = "/userapi/performAccountDeactivation" PerformOpenIDTokenCreationPath = "/userapi/performOpenIDTokenCreation" PerformKeyBackupPath = "/userapi/performKeyBackup" + PerformPusherSetPath = "/pushserver/performPusherSet" + PerformPusherDeletionPath = "/pushserver/performPusherDeletion" + PerformPushRulesPutPath = "/pushserver/performPushRulesPut" QueryKeyBackupPath = "/userapi/queryKeyBackup" QueryProfilePath = "/userapi/queryProfile" @@ -46,6 +49,9 @@ const ( QueryDeviceInfosPath = "/userapi/queryDeviceInfos" QuerySearchProfilesPath = "/userapi/querySearchProfiles" QueryOpenIDTokenPath = "/userapi/queryOpenIDToken" + QueryPushersPath = "/pushserver/queryPushers" + QueryPushRulesPath = "/pushserver/queryPushRules" + QueryNotificationsPath = "/pushserver/queryNotifications" ) // NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API. @@ -249,3 +255,58 @@ func (h *httpUserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.Query res.Error = err.Error() } } + +func (h *httpUserInternalAPI) QueryNotifications(ctx context.Context, req *api.QueryNotificationsRequest, res *api.QueryNotificationsResponse) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryNotifications") + defer span.Finish() + + return httputil.PostJSON(ctx, span, h.httpClient, h.apiURL+QueryNotificationsPath, req, res) +} + +func (h *httpUserInternalAPI) PerformPusherSet( + ctx context.Context, + request *api.PerformPusherSetRequest, + response *struct{}, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPusherSet") + defer span.Finish() + + apiURL := h.apiURL + PerformPusherSetPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +func (h *httpUserInternalAPI) PerformPusherDeletion(ctx context.Context, req *api.PerformPusherDeletionRequest, res *struct{}) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPusherDeletion") + defer span.Finish() + + apiURL := h.apiURL + PerformPusherDeletionPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +} + +func (h *httpUserInternalAPI) QueryPushers(ctx context.Context, req *api.QueryPushersRequest, res *api.QueryPushersResponse) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryPushers") + defer span.Finish() + + apiURL := h.apiURL + QueryPushersPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +} + +func (h *httpUserInternalAPI) PerformPushRulesPut( + ctx context.Context, + request *api.PerformPushRulesPutRequest, + response *struct{}, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPushRulesPut") + defer span.Finish() + + apiURL := h.apiURL + PerformPushRulesPutPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +func (h *httpUserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPushRulesRequest, res *api.QueryPushRulesResponse) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryPushRules") + defer span.Finish() + + apiURL := h.apiURL + QueryPushRulesPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +} diff --git a/userapi/inthttp/server.go b/userapi/inthttp/server.go index d00ee042c..526f99575 100644 --- a/userapi/inthttp/server.go +++ b/userapi/inthttp/server.go @@ -265,4 +265,86 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + internalAPIMux.Handle(QueryNotificationsPath, + httputil.MakeInternalAPI("queryNotifications", func(req *http.Request) util.JSONResponse { + var request api.QueryNotificationsRequest + var response api.QueryNotificationsResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.QueryNotifications(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + + internalAPIMux.Handle(PerformPusherSetPath, + httputil.MakeInternalAPI("performPusherSet", func(req *http.Request) util.JSONResponse { + request := api.PerformPusherSetRequest{} + response := struct{}{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.PerformPusherSet(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle(PerformPusherDeletionPath, + httputil.MakeInternalAPI("performPusherDeletion", func(req *http.Request) util.JSONResponse { + request := api.PerformPusherDeletionRequest{} + response := struct{}{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.PerformPusherDeletion(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + + internalAPIMux.Handle(QueryPushersPath, + httputil.MakeInternalAPI("queryPushers", func(req *http.Request) util.JSONResponse { + request := api.QueryPushersRequest{} + response := api.QueryPushersResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.QueryPushers(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + + internalAPIMux.Handle(PerformPushRulesPutPath, + httputil.MakeInternalAPI("performPushRulesPut", func(req *http.Request) util.JSONResponse { + request := api.PerformPushRulesPutRequest{} + response := struct{}{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.PerformPushRulesPut(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + + internalAPIMux.Handle(QueryPushRulesPath, + httputil.MakeInternalAPI("queryPushRules", func(req *http.Request) util.JSONResponse { + request := api.QueryPushRulesRequest{} + response := api.QueryPushRulesResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.QueryPushRules(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) } diff --git a/userapi/producers/syncapi.go b/userapi/producers/syncapi.go new file mode 100644 index 000000000..4a206f333 --- /dev/null +++ b/userapi/producers/syncapi.go @@ -0,0 +1,104 @@ +package producers + +import ( + "context" + "encoding/json" + + "github.com/matrix-org/dendrite/internal/eventutil" + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/userapi/storage" + "github.com/matrix-org/gomatrixserverlib" + "github.com/nats-io/nats.go" + log "github.com/sirupsen/logrus" +) + +type JetStreamPublisher interface { + PublishMsg(*nats.Msg, ...nats.PubOpt) (*nats.PubAck, error) +} + +// SyncAPI produces messages for the Sync API server to consume. +type SyncAPI struct { + db storage.Database + producer JetStreamPublisher + clientDataTopic string + notificationDataTopic string +} + +func NewSyncAPI(db storage.Database, js JetStreamPublisher, clientDataTopic string, notificationDataTopic string) *SyncAPI { + return &SyncAPI{ + db: db, + producer: js, + clientDataTopic: clientDataTopic, + notificationDataTopic: notificationDataTopic, + } +} + +// SendAccountData sends account data to the Sync API server. +func (p *SyncAPI) SendAccountData(userID string, roomID string, dataType string) error { + m := &nats.Msg{ + Subject: p.clientDataTopic, + Header: nats.Header{}, + } + m.Header.Set(jetstream.UserID, userID) + + var err error + m.Data, err = json.Marshal(eventutil.AccountData{ + RoomID: roomID, + Type: dataType, + }) + if err != nil { + return err + } + + log.WithFields(log.Fields{ + "user_id": userID, + "room_id": roomID, + "data_type": dataType, + }).Tracef("Producing to topic '%s'", p.clientDataTopic) + + _, err = p.producer.PublishMsg(m) + return err +} + +// GetAndSendNotificationData reads the database and sends data about unread +// notifications to the Sync API server. +func (p *SyncAPI) GetAndSendNotificationData(ctx context.Context, userID, roomID string) error { + localpart, _, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + return err + } + + ntotal, nhighlight, err := p.db.GetRoomNotificationCounts(ctx, localpart, roomID) + if err != nil { + return err + } + + return p.sendNotificationData(userID, &eventutil.NotificationData{ + RoomID: roomID, + UnreadHighlightCount: int(nhighlight), + UnreadNotificationCount: int(ntotal), + }) +} + +// sendNotificationData sends data about unread notifications to the Sync API server. +func (p *SyncAPI) sendNotificationData(userID string, data *eventutil.NotificationData) error { + m := &nats.Msg{ + Subject: p.notificationDataTopic, + Header: nats.Header{}, + } + m.Header.Set(jetstream.UserID, userID) + + var err error + m.Data, err = json.Marshal(data) + if err != nil { + return err + } + + log.WithFields(log.Fields{ + "user_id": userID, + "room_id": data.RoomID, + }).Tracef("Producing to topic '%s'", p.clientDataTopic) + + _, err = p.producer.PublishMsg(m) + return err +} diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index a131dac47..6d22fea9d 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -21,6 +21,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" ) type Database interface { @@ -89,6 +90,18 @@ type Database interface { // GetLoginTokenDataByToken returns the data associated with the given token. // May return sql.ErrNoRows. GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) + + InsertNotification(ctx context.Context, localpart, eventID string, pos int64, tweaks map[string]interface{}, n *api.Notification) error + DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos int64) (affected bool, err error) + SetNotificationsRead(ctx context.Context, localpart, roomID string, pos int64, b bool) (affected bool, err error) + GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) + GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error) + GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error) + + UpsertPusher(ctx context.Context, p api.Pusher, localpart string) error + GetPushers(ctx context.Context, localpart string) ([]api.Pusher, error) + RemovePusher(ctx context.Context, appid, pushkey, localpart string) error + RemovePushers(ctx context.Context, appid, pushkey string) error } // Err3PIDInUse is the error returned when trying to save an association involving diff --git a/userapi/storage/postgres/notifications_table.go b/userapi/storage/postgres/notifications_table.go new file mode 100644 index 000000000..7bcc0f9cd --- /dev/null +++ b/userapi/storage/postgres/notifications_table.go @@ -0,0 +1,219 @@ +// Copyright 2021 Dan Peleg +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "encoding/json" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" + log "github.com/sirupsen/logrus" +) + +type notificationsStatements struct { + insertStmt *sql.Stmt + deleteUpToStmt *sql.Stmt + updateReadStmt *sql.Stmt + selectStmt *sql.Stmt + selectCountStmt *sql.Stmt + selectRoomCountsStmt *sql.Stmt +} + +const notificationSchema = ` +CREATE TABLE IF NOT EXISTS userapi_notifications ( + id BIGSERIAL PRIMARY KEY, + localpart TEXT NOT NULL, + room_id TEXT NOT NULL, + event_id TEXT NOT NULL, + stream_pos BIGINT NOT NULL, + ts_ms BIGINT NOT NULL, + highlight BOOLEAN NOT NULL, + notification_json TEXT NOT NULL, + read BOOLEAN NOT NULL DEFAULT FALSE +); + +CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, room_id, event_id); +CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, room_id, id); +CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, id); +` + +const insertNotificationSQL = "" + + "INSERT INTO userapi_notifications (localpart, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7)" + +const deleteNotificationsUpToSQL = "" + + "DELETE FROM userapi_notifications WHERE localpart = $1 AND room_id = $2 AND stream_pos <= $3" + +const updateNotificationReadSQL = "" + + "UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND room_id = $3 AND stream_pos <= $4 AND read <> $1" + +const selectNotificationSQL = "" + + "SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND id > $2 AND (" + + "(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" + + ") AND NOT read ORDER BY localpart, id LIMIT $4" + +const selectNotificationCountSQL = "" + + "SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND (" + + "(($2 & 1) <> 0 AND highlight) OR (($2 & 2) <> 0 AND NOT highlight)" + + ") AND NOT read" + +const selectRoomNotificationCountsSQL = "" + + "SELECT COUNT(*), COUNT(*) FILTER (WHERE highlight) FROM userapi_notifications " + + "WHERE localpart = $1 AND room_id = $2 AND NOT read" + +func NewPostgresNotificationTable(db *sql.DB) (tables.NotificationTable, error) { + s := ¬ificationsStatements{} + _, err := db.Exec(notificationSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.insertStmt, insertNotificationSQL}, + {&s.deleteUpToStmt, deleteNotificationsUpToSQL}, + {&s.updateReadStmt, updateNotificationReadSQL}, + {&s.selectStmt, selectNotificationSQL}, + {&s.selectCountStmt, selectNotificationCountSQL}, + {&s.selectRoomCountsStmt, selectRoomNotificationCountsSQL}, + }.Prepare(db) +} + +// Insert inserts a notification into the database. +func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos int64, highlight bool, n *api.Notification) error { + roomID, tsMS := n.RoomID, n.TS + nn := *n + // Clears out fields that have their own columns to (1) shrink the + // data and (2) avoid difficult-to-debug inconsistency bugs. + nn.RoomID = "" + nn.TS, nn.Read = 0, false + bs, err := json.Marshal(nn) + if err != nil { + return err + } + _, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, roomID, eventID, pos, tsMS, highlight, string(bs)) + return err +} + +// DeleteUpTo deletes all previous notifications, up to and including the event. +func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64) (affected bool, _ error) { + res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos) + if err != nil { + return false, err + } + nrows, err := res.RowsAffected() + if err != nil { + return true, err + } + log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "stream_pos": pos}).Tracef("DeleteUpTo: %d rows affected", nrows) + return nrows > 0, nil +} + +// UpdateRead updates the "read" value for an event. +func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64, v bool) (affected bool, _ error) { + res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos) + if err != nil { + return false, err + } + nrows, err := res.RowsAffected() + if err != nil { + return true, err + } + log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "stream_pos": pos}).Tracef("UpdateRead: %d rows affected", nrows) + return nrows > 0, nil +} + +func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) { + rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, fromID, uint32(filter), limit) + + if err != nil { + return nil, 0, err + } + defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed") + + var maxID int64 = -1 + var notifs []*api.Notification + for rows.Next() { + var id int64 + var roomID string + var ts gomatrixserverlib.Timestamp + var read bool + var jsonStr string + err = rows.Scan( + &id, + &roomID, + &ts, + &read, + &jsonStr) + if err != nil { + return nil, 0, err + } + + var n api.Notification + err := json.Unmarshal([]byte(jsonStr), &n) + if err != nil { + return nil, 0, err + } + n.RoomID = roomID + n.TS = ts + n.Read = read + notifs = append(notifs, &n) + + if maxID < id { + maxID = id + } + } + return notifs, maxID, rows.Err() +} + +func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (int64, error) { + rows, err := sqlutil.TxStmt(txn, s.selectCountStmt).QueryContext(ctx, localpart, uint32(filter)) + + if err != nil { + return 0, err + } + defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed") + + if rows.Next() { + var count int64 + if err := rows.Scan(&count); err != nil { + return 0, err + } + + return count, nil + } + return 0, rows.Err() +} + +func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error) { + rows, err := sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryContext(ctx, localpart, roomID) + + if err != nil { + return 0, 0, err + } + defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed") + + if rows.Next() { + var total, highlight int64 + if err := rows.Scan(&total, &highlight); err != nil { + return 0, 0, err + } + + return total, highlight, nil + } + return 0, 0, rows.Err() +} diff --git a/userapi/storage/postgres/pusher_table.go b/userapi/storage/postgres/pusher_table.go new file mode 100644 index 000000000..670dc916f --- /dev/null +++ b/userapi/storage/postgres/pusher_table.go @@ -0,0 +1,157 @@ +// Copyright 2021 Dan Peleg +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "encoding/json" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" +) + +// See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers +const pushersSchema = ` +CREATE TABLE IF NOT EXISTS userapi_pushers ( + id BIGSERIAL PRIMARY KEY, + -- The Matrix user ID localpart for this pusher + localpart TEXT NOT NULL, + session_id BIGINT DEFAULT NULL, + profile_tag TEXT, + kind TEXT NOT NULL, + app_id TEXT NOT NULL, + app_display_name TEXT NOT NULL, + device_display_name TEXT NOT NULL, + pushkey TEXT NOT NULL, + pushkey_ts_ms BIGINT NOT NULL DEFAULT 0, + lang TEXT NOT NULL, + data TEXT NOT NULL +); + +-- For faster deleting by app_id, pushkey pair. +CREATE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_idx ON userapi_pushers(app_id, pushkey); + +-- For faster retrieving by localpart. +CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart); + +-- Pushkey must be unique for a given user and app. +CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart); +` + +const insertPusherSQL = "" + + "INSERT INTO userapi_pushers (localpart, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" + + "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)" + + "ON CONFLICT (app_id, pushkey, localpart) DO UPDATE SET session_id = $2, pushkey_ts_ms = $4, kind = $5, app_display_name = $7, device_display_name = $8, profile_tag = $9, lang = $10, data = $11" + +const selectPushersSQL = "" + + "SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1" + +const deletePusherSQL = "" + + "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3" + +const deletePushersByAppIdAndPushKeySQL = "" + + "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2" + +func NewPostgresPusherTable(db *sql.DB) (tables.PusherTable, error) { + s := &pushersStatements{} + _, err := db.Exec(pushersSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.insertPusherStmt, insertPusherSQL}, + {&s.selectPushersStmt, selectPushersSQL}, + {&s.deletePusherStmt, deletePusherSQL}, + {&s.deletePushersByAppIdAndPushKeyStmt, deletePushersByAppIdAndPushKeySQL}, + }.Prepare(db) +} + +type pushersStatements struct { + insertPusherStmt *sql.Stmt + selectPushersStmt *sql.Stmt + deletePusherStmt *sql.Stmt + deletePushersByAppIdAndPushKeyStmt *sql.Stmt +} + +// insertPusher creates a new pusher. +// Returns an error if the user already has a pusher with the given pusher pushkey. +// Returns nil error success. +func (s *pushersStatements) InsertPusher( + ctx context.Context, txn *sql.Tx, session_id int64, + pushkey string, pushkeyTS gomatrixserverlib.Timestamp, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string, +) error { + _, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data) + logrus.Debugf("Created pusher %d", session_id) + return err +} + +func (s *pushersStatements) SelectPushers( + ctx context.Context, txn *sql.Tx, localpart string, +) ([]api.Pusher, error) { + pushers := []api.Pusher{} + rows, err := sqlutil.TxStmt(txn, s.selectPushersStmt).QueryContext(ctx, localpart) + + if err != nil { + return pushers, err + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectPushers: rows.close() failed") + + for rows.Next() { + var pusher api.Pusher + var data []byte + err = rows.Scan( + &pusher.SessionID, + &pusher.PushKey, + &pusher.PushKeyTS, + &pusher.Kind, + &pusher.AppID, + &pusher.AppDisplayName, + &pusher.DeviceDisplayName, + &pusher.ProfileTag, + &pusher.Language, + &data) + if err != nil { + return pushers, err + } + err := json.Unmarshal(data, &pusher.Data) + if err != nil { + return pushers, err + } + pushers = append(pushers, pusher) + } + + logrus.Debugf("Database returned %d pushers", len(pushers)) + return pushers, rows.Err() +} + +// deletePusher removes a single pusher by pushkey and user localpart. +func (s *pushersStatements) DeletePusher( + ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string, +) error { + _, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart) + return err +} + +func (s *pushersStatements) DeletePushers( + ctx context.Context, txn *sql.Tx, appid, pushkey string, +) error { + _, err := sqlutil.TxStmt(txn, s.deletePushersByAppIdAndPushKeyStmt).ExecContext(ctx, appid, pushkey) + return err +} diff --git a/userapi/storage/postgres/storage.go b/userapi/storage/postgres/storage.go index ac5c59b81..c74a999f4 100644 --- a/userapi/storage/postgres/storage.go +++ b/userapi/storage/postgres/storage.go @@ -85,6 +85,14 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver if err != nil { return nil, fmt.Errorf("NewPostgresThreePIDTable: %w", err) } + pusherTable, err := NewPostgresPusherTable(db) + if err != nil { + return nil, fmt.Errorf("NewPostgresPusherTable: %w", err) + } + notificationsTable, err := NewPostgresNotificationTable(db) + if err != nil { + return nil, fmt.Errorf("NewPostgresNotificationTable: %w", err) + } return &shared.Database{ AccountDatas: accountDataTable, Accounts: accountsTable, @@ -95,6 +103,8 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver OpenIDTokens: openIDTable, Profiles: profilesTable, ThreePIDs: threePIDTable, + Pushers: pusherTable, + Notifications: notificationsTable, ServerName: serverName, DB: db, Writer: sqlutil.NewDummyWriter(), diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index 5f1f95005..a58974b41 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -29,6 +29,7 @@ import ( "golang.org/x/crypto/bcrypt" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/internal/pushrules" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/tables" @@ -47,6 +48,8 @@ type Database struct { KeyBackupVersions tables.KeyBackupVersionTable Devices tables.DevicesTable LoginTokens tables.LoginTokenTable + Notifications tables.NotificationTable + Pushers tables.PusherTable LoginTokenLifetime time.Duration ServerName gomatrixserverlib.ServerName BcryptCost int @@ -160,15 +163,12 @@ func (d *Database) createAccount( if err = d.Profiles.InsertProfile(ctx, txn, localpart); err != nil { return nil, err } - if err = d.AccountDatas.InsertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{ - "global": { - "content": [], - "override": [], - "room": [], - "sender": [], - "underride": [] - } - }`)); err != nil { + pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, d.ServerName) + prbs, err := json.Marshal(pushRuleSets) + if err != nil { + return nil, err + } + if err = d.AccountDatas.InsertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(prbs)); err != nil { return nil, err } return account, nil @@ -670,3 +670,94 @@ func (d *Database) RemoveLoginToken(ctx context.Context, token string) error { func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) { return d.LoginTokens.SelectLoginToken(ctx, token) } + +func (d *Database) InsertNotification(ctx context.Context, localpart, eventID string, pos int64, tweaks map[string]interface{}, n *api.Notification) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.Notifications.Insert(ctx, txn, localpart, eventID, pos, pushrules.BoolTweakOr(tweaks, pushrules.HighlightTweak, false), n) + }) +} + +func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos int64) (affected bool, err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + affected, err = d.Notifications.DeleteUpTo(ctx, txn, localpart, roomID, pos) + return err + }) + return +} + +func (d *Database) SetNotificationsRead(ctx context.Context, localpart, roomID string, pos int64, b bool) (affected bool, err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + affected, err = d.Notifications.UpdateRead(ctx, txn, localpart, roomID, pos, b) + return err + }) + return +} + +func (d *Database) GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) { + return d.Notifications.Select(ctx, nil, localpart, fromID, limit, filter) +} + +func (d *Database) GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error) { + return d.Notifications.SelectCount(ctx, nil, localpart, filter) +} + +func (d *Database) GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error) { + return d.Notifications.SelectRoomCounts(ctx, nil, localpart, roomID) +} + +func (d *Database) UpsertPusher( + ctx context.Context, p api.Pusher, localpart string, +) error { + data, err := json.Marshal(p.Data) + if err != nil { + return err + } + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.Pushers.InsertPusher( + ctx, txn, + p.SessionID, + p.PushKey, + p.PushKeyTS, + p.Kind, + p.AppID, + p.AppDisplayName, + p.DeviceDisplayName, + p.ProfileTag, + p.Language, + string(data), + localpart) + }) +} + +// GetPushers returns the pushers matching the given localpart. +func (d *Database) GetPushers( + ctx context.Context, localpart string, +) ([]api.Pusher, error) { + return d.Pushers.SelectPushers(ctx, nil, localpart) +} + +// RemovePusher deletes one pusher +// Invoked when `append` is true and `kind` is null in +// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-pushers-set +func (d *Database) RemovePusher( + ctx context.Context, appid, pushkey, localpart string, +) error { + return d.Writer.Do(nil, nil, func(txn *sql.Tx) error { + err := d.Pushers.DeletePusher(ctx, txn, appid, pushkey, localpart) + if err == sql.ErrNoRows { + return nil + } + return err + }) +} + +// RemovePushers deletes all pushers that match given App Id and Push Key pair. +// Invoked when `append` parameter is false in +// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-pushers-set +func (d *Database) RemovePushers( + ctx context.Context, appid, pushkey string, +) error { + return d.Writer.Do(nil, nil, func(txn *sql.Tx) error { + return d.Pushers.DeletePushers(ctx, txn, appid, pushkey) + }) +} diff --git a/userapi/storage/sqlite3/notifications_table.go b/userapi/storage/sqlite3/notifications_table.go new file mode 100644 index 000000000..fcfb1aadc --- /dev/null +++ b/userapi/storage/sqlite3/notifications_table.go @@ -0,0 +1,219 @@ +// Copyright 2021 Dan Peleg +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "encoding/json" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" + log "github.com/sirupsen/logrus" +) + +type notificationsStatements struct { + insertStmt *sql.Stmt + deleteUpToStmt *sql.Stmt + updateReadStmt *sql.Stmt + selectStmt *sql.Stmt + selectCountStmt *sql.Stmt + selectRoomCountsStmt *sql.Stmt +} + +const notificationSchema = ` +CREATE TABLE IF NOT EXISTS userapi_notifications ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + localpart TEXT NOT NULL, + room_id TEXT NOT NULL, + event_id TEXT NOT NULL, + stream_pos BIGINT NOT NULL, + ts_ms BIGINT NOT NULL, + highlight BOOLEAN NOT NULL, + notification_json TEXT NOT NULL, + read BOOLEAN NOT NULL DEFAULT FALSE +); + +CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, room_id, event_id); +CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, room_id, id); +CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, id); +` + +const insertNotificationSQL = "" + + "INSERT INTO userapi_notifications (localpart, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7)" + +const deleteNotificationsUpToSQL = "" + + "DELETE FROM userapi_notifications WHERE localpart = $1 AND room_id = $2 AND stream_pos <= $3" + +const updateNotificationReadSQL = "" + + "UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND room_id = $3 AND stream_pos <= $4 AND read <> $1" + +const selectNotificationSQL = "" + + "SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND id > $2 AND (" + + "(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" + + ") AND NOT read ORDER BY localpart, id LIMIT $4" + +const selectNotificationCountSQL = "" + + "SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND (" + + "(($2 & 1) <> 0 AND highlight) OR (($2 & 2) <> 0 AND NOT highlight)" + + ") AND NOT read" + +const selectRoomNotificationCountsSQL = "" + + "SELECT COUNT(*), COUNT(*) FILTER (WHERE highlight) FROM userapi_notifications " + + "WHERE localpart = $1 AND room_id = $2 AND NOT read" + +func NewSQLiteNotificationTable(db *sql.DB) (tables.NotificationTable, error) { + s := ¬ificationsStatements{} + _, err := db.Exec(notificationSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.insertStmt, insertNotificationSQL}, + {&s.deleteUpToStmt, deleteNotificationsUpToSQL}, + {&s.updateReadStmt, updateNotificationReadSQL}, + {&s.selectStmt, selectNotificationSQL}, + {&s.selectCountStmt, selectNotificationCountSQL}, + {&s.selectRoomCountsStmt, selectRoomNotificationCountsSQL}, + }.Prepare(db) +} + +// Insert inserts a notification into the database. +func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos int64, highlight bool, n *api.Notification) error { + roomID, tsMS := n.RoomID, n.TS + nn := *n + // Clears out fields that have their own columns to (1) shrink the + // data and (2) avoid difficult-to-debug inconsistency bugs. + nn.RoomID = "" + nn.TS, nn.Read = 0, false + bs, err := json.Marshal(nn) + if err != nil { + return err + } + _, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, roomID, eventID, pos, tsMS, highlight, string(bs)) + return err +} + +// DeleteUpTo deletes all previous notifications, up to and including the event. +func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64) (affected bool, _ error) { + res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos) + if err != nil { + return false, err + } + nrows, err := res.RowsAffected() + if err != nil { + return true, err + } + log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "stream_pos": pos}).Tracef("DeleteUpTo: %d rows affected", nrows) + return nrows > 0, nil +} + +// UpdateRead updates the "read" value for an event. +func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64, v bool) (affected bool, _ error) { + res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos) + if err != nil { + return false, err + } + nrows, err := res.RowsAffected() + if err != nil { + return true, err + } + log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "stream_pos": pos}).Tracef("UpdateRead: %d rows affected", nrows) + return nrows > 0, nil +} + +func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) { + rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, fromID, uint32(filter), limit) + + if err != nil { + return nil, 0, err + } + defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed") + + var maxID int64 = -1 + var notifs []*api.Notification + for rows.Next() { + var id int64 + var roomID string + var ts gomatrixserverlib.Timestamp + var read bool + var jsonStr string + err = rows.Scan( + &id, + &roomID, + &ts, + &read, + &jsonStr) + if err != nil { + return nil, 0, err + } + + var n api.Notification + err := json.Unmarshal([]byte(jsonStr), &n) + if err != nil { + return nil, 0, err + } + n.RoomID = roomID + n.TS = ts + n.Read = read + notifs = append(notifs, &n) + + if maxID < id { + maxID = id + } + } + return notifs, maxID, rows.Err() +} + +func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (int64, error) { + rows, err := sqlutil.TxStmt(txn, s.selectCountStmt).QueryContext(ctx, localpart, uint32(filter)) + + if err != nil { + return 0, err + } + defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed") + + if rows.Next() { + var count int64 + if err := rows.Scan(&count); err != nil { + return 0, err + } + + return count, nil + } + return 0, rows.Err() +} + +func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error) { + rows, err := sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryContext(ctx, localpart, roomID) + + if err != nil { + return 0, 0, err + } + defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed") + + if rows.Next() { + var total, highlight int64 + if err := rows.Scan(&total, &highlight); err != nil { + return 0, 0, err + } + + return total, highlight, nil + } + return 0, 0, rows.Err() +} diff --git a/userapi/storage/sqlite3/pusher_table.go b/userapi/storage/sqlite3/pusher_table.go new file mode 100644 index 000000000..e718792e1 --- /dev/null +++ b/userapi/storage/sqlite3/pusher_table.go @@ -0,0 +1,157 @@ +// Copyright 2021 Dan Peleg +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "encoding/json" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" +) + +// See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers +const pushersSchema = ` +CREATE TABLE IF NOT EXISTS userapi_pushers ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + -- The Matrix user ID localpart for this pusher + localpart TEXT NOT NULL, + session_id BIGINT DEFAULT NULL, + profile_tag TEXT, + kind TEXT NOT NULL, + app_id TEXT NOT NULL, + app_display_name TEXT NOT NULL, + device_display_name TEXT NOT NULL, + pushkey TEXT NOT NULL, + pushkey_ts_ms BIGINT NOT NULL DEFAULT 0, + lang TEXT NOT NULL, + data TEXT NOT NULL +); + +-- For faster deleting by app_id, pushkey pair. +CREATE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_idx ON userapi_pushers(app_id, pushkey); + +-- For faster retrieving by localpart. +CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart); + +-- Pushkey must be unique for a given user and app. +CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart); +` + +const insertPusherSQL = "" + + "INSERT INTO userapi_pushers (localpart, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" + + "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)" + + "ON CONFLICT (app_id, pushkey, localpart) DO UPDATE SET session_id = $2, pushkey_ts_ms = $4, kind = $5, app_display_name = $7, device_display_name = $8, profile_tag = $9, lang = $10, data = $11" + +const selectPushersSQL = "" + + "SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1" + +const deletePusherSQL = "" + + "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3" + +const deletePushersByAppIdAndPushKeySQL = "" + + "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2" + +func NewSQLitePusherTable(db *sql.DB) (tables.PusherTable, error) { + s := &pushersStatements{} + _, err := db.Exec(pushersSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.insertPusherStmt, insertPusherSQL}, + {&s.selectPushersStmt, selectPushersSQL}, + {&s.deletePusherStmt, deletePusherSQL}, + {&s.deletePushersByAppIdAndPushKeyStmt, deletePushersByAppIdAndPushKeySQL}, + }.Prepare(db) +} + +type pushersStatements struct { + insertPusherStmt *sql.Stmt + selectPushersStmt *sql.Stmt + deletePusherStmt *sql.Stmt + deletePushersByAppIdAndPushKeyStmt *sql.Stmt +} + +// insertPusher creates a new pusher. +// Returns an error if the user already has a pusher with the given pusher pushkey. +// Returns nil error success. +func (s *pushersStatements) InsertPusher( + ctx context.Context, txn *sql.Tx, session_id int64, + pushkey string, pushkeyTS gomatrixserverlib.Timestamp, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string, +) error { + _, err := s.insertPusherStmt.ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data) + logrus.Debugf("Created pusher %d", session_id) + return err +} + +func (s *pushersStatements) SelectPushers( + ctx context.Context, txn *sql.Tx, localpart string, +) ([]api.Pusher, error) { + pushers := []api.Pusher{} + rows, err := s.selectPushersStmt.QueryContext(ctx, localpart) + + if err != nil { + return pushers, err + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectPushers: rows.close() failed") + + for rows.Next() { + var pusher api.Pusher + var data []byte + err = rows.Scan( + &pusher.SessionID, + &pusher.PushKey, + &pusher.PushKeyTS, + &pusher.Kind, + &pusher.AppID, + &pusher.AppDisplayName, + &pusher.DeviceDisplayName, + &pusher.ProfileTag, + &pusher.Language, + &data) + if err != nil { + return pushers, err + } + err := json.Unmarshal(data, &pusher.Data) + if err != nil { + return pushers, err + } + pushers = append(pushers, pusher) + } + + logrus.Debugf("Database returned %d pushers", len(pushers)) + return pushers, rows.Err() +} + +// deletePusher removes a single pusher by pushkey and user localpart. +func (s *pushersStatements) DeletePusher( + ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string, +) error { + _, err := s.deletePusherStmt.ExecContext(ctx, appid, pushkey, localpart) + return err +} + +func (s *pushersStatements) DeletePushers( + ctx context.Context, txn *sql.Tx, appid, pushkey string, +) error { + _, err := s.deletePushersByAppIdAndPushKeyStmt.ExecContext(ctx, appid, pushkey) + return err +} diff --git a/userapi/storage/sqlite3/storage.go b/userapi/storage/sqlite3/storage.go index 98c244977..b5bb96c42 100644 --- a/userapi/storage/sqlite3/storage.go +++ b/userapi/storage/sqlite3/storage.go @@ -86,6 +86,14 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver if err != nil { return nil, fmt.Errorf("NewSQLiteThreePIDTable: %w", err) } + pusherTable, err := NewSQLitePusherTable(db) + if err != nil { + return nil, fmt.Errorf("NewPostgresPusherTable: %w", err) + } + notificationsTable, err := NewSQLiteNotificationTable(db) + if err != nil { + return nil, fmt.Errorf("NewPostgresNotificationTable: %w", err) + } return &shared.Database{ AccountDatas: accountDataTable, Accounts: accountsTable, @@ -96,6 +104,8 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver OpenIDTokens: openIDTable, Profiles: profilesTable, ThreePIDs: threePIDTable, + Pushers: pusherTable, + Notifications: notificationsTable, ServerName: serverName, DB: db, Writer: sqlutil.NewExclusiveWriter(), diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go index 12939ced5..815e51193 100644 --- a/userapi/storage/tables/interface.go +++ b/userapi/storage/tables/interface.go @@ -21,6 +21,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" ) type AccountDataTable interface { @@ -93,3 +94,42 @@ type ThreePIDTable interface { InsertThreePID(ctx context.Context, txn *sql.Tx, threepid, medium, localpart string) (err error) DeleteThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error) } + +type PusherTable interface { + InsertPusher(ctx context.Context, txn *sql.Tx, session_id int64, pushkey string, pushkeyTS gomatrixserverlib.Timestamp, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string) error + SelectPushers(ctx context.Context, txn *sql.Tx, localpart string) ([]api.Pusher, error) + DeletePusher(ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string) error + DeletePushers(ctx context.Context, txn *sql.Tx, appid, pushkey string) error +} + +type NotificationTable interface { + Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos int64, highlight bool, n *api.Notification) error + DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64) (affected bool, _ error) + UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64, v bool) (affected bool, _ error) + Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter NotificationFilter) ([]*api.Notification, int64, error) + SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter NotificationFilter) (int64, error) + SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error) +} + +type NotificationFilter uint32 + +const ( + // HighlightNotifications returns notifications that had a + // "highlight" tweak assigned to them from evaluating push rules. + HighlightNotifications NotificationFilter = 1 << iota + + // NonHighlightNotifications returns notifications that don't + // match HighlightNotifications. + NonHighlightNotifications + + // NoNotifications is a filter to exclude all types of + // notifications. It's useful as a zero value, but isn't likely to + // be used in a call to Notifications.Select*. + NoNotifications NotificationFilter = 0 + + // AllNotifications is a filter to include all types of + // notifications in Notifications.Select*. Note that PostgreSQL + // balks if this doesn't fit in INTEGER, even though we use + // uint32. + AllNotifications NotificationFilter = (1 << 31) - 1 +) diff --git a/userapi/userapi.go b/userapi/userapi.go index 4a5793abb..2382e9512 100644 --- a/userapi/userapi.go +++ b/userapi/userapi.go @@ -18,11 +18,17 @@ import ( "time" "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/internal/pushgateway" keyapi "github.com/matrix-org/dendrite/keyserver/api" + rsapi "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/consumers" "github.com/matrix-org/dendrite/userapi/internal" "github.com/matrix-org/dendrite/userapi/inthttp" + "github.com/matrix-org/dendrite/userapi/producers" "github.com/matrix-org/dendrite/userapi/storage" "github.com/sirupsen/logrus" ) @@ -36,26 +42,49 @@ func AddInternalRoutes(router *mux.Router, intAPI api.UserInternalAPI) { // NewInternalAPI returns a concerete implementation of the internal API. Callers // can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. func NewInternalAPI( - accountDB storage.Database, cfg *config.UserAPI, appServices []config.ApplicationService, keyAPI keyapi.KeyInternalAPI, + base *base.BaseDendrite, db storage.Database, cfg *config.UserAPI, + appServices []config.ApplicationService, keyAPI keyapi.KeyInternalAPI, + rsAPI rsapi.RoomserverInternalAPI, pgClient pushgateway.Client, ) api.UserInternalAPI { db, err := storage.NewDatabase(&cfg.AccountDatabase, cfg.Matrix.ServerName, cfg.BCryptCost, int64(api.DefaultLoginTokenLifetime*time.Millisecond), api.DefaultLoginTokenLifetime) if err != nil { logrus.WithError(err).Panicf("failed to connect to device db") } - return newInternalAPI(db, cfg, appServices, keyAPI) -} + js := jetstream.Prepare(&cfg.Matrix.JetStream) -func newInternalAPI( - db storage.Database, - cfg *config.UserAPI, - appServices []config.ApplicationService, - keyAPI keyapi.KeyInternalAPI, -) api.UserInternalAPI { - return &internal.UserInternalAPI{ - DB: db, - ServerName: cfg.Matrix.ServerName, - AppServices: appServices, - KeyAPI: keyAPI, + syncProducer := producers.NewSyncAPI( + db, js, + // TODO: user API should handle syncs for account data. Right now, + // it's handled by clientapi, and hence uses its topic. When user + // API handles it for all account data, we can remove it from + // here. + cfg.Matrix.JetStream.TopicFor(jetstream.OutputClientData), + cfg.Matrix.JetStream.TopicFor(jetstream.OutputNotificationData), + ) + + userAPI := &internal.UserInternalAPI{ + DB: db, + SyncProducer: syncProducer, + ServerName: cfg.Matrix.ServerName, + AppServices: appServices, + KeyAPI: keyAPI, + DisableTLSValidation: cfg.PushGatewayDisableTLSValidation, } + + readConsumer := consumers.NewOutputReadUpdateConsumer( + base.ProcessContext, cfg, js, db, pgClient, userAPI, syncProducer, + ) + if err := readConsumer.Start(); err != nil { + logrus.WithError(err).Panic("failed to start user API read update consumer") + } + + eventConsumer := consumers.NewOutputStreamEventConsumer( + base.ProcessContext, cfg, js, db, pgClient, userAPI, rsAPI, syncProducer, + ) + if err := eventConsumer.Start(); err != nil { + logrus.WithError(err).Panic("failed to start user API streamed event consumer") + } + + return userAPI } diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index 4214c07f7..25319c4bf 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -30,6 +30,7 @@ import ( "github.com/matrix-org/dendrite/internal/test" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/internal" "github.com/matrix-org/dendrite/userapi/inthttp" "github.com/matrix-org/dendrite/userapi/storage" ) @@ -62,7 +63,10 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts) (api.UserInternalAPI, s }, } - return newInternalAPI(accountDB, cfg, nil, nil), accountDB + return &internal.UserInternalAPI{ + DB: accountDB, + ServerName: cfg.Matrix.ServerName, + }, accountDB } func TestQueryProfile(t *testing.T) { diff --git a/userapi/util/devices.go b/userapi/util/devices.go new file mode 100644 index 000000000..cbf3bd28f --- /dev/null +++ b/userapi/util/devices.go @@ -0,0 +1,100 @@ +package util + +import ( + "context" + + "github.com/matrix-org/dendrite/internal/pushgateway" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage" + log "github.com/sirupsen/logrus" +) + +type PusherDevice struct { + Device pushgateway.Device + Pusher *api.Pusher + URL string + Format string +} + +// GetPushDevices pushes to the configured devices of a local user. +func GetPushDevices(ctx context.Context, localpart string, tweaks map[string]interface{}, db storage.Database) ([]*PusherDevice, error) { + pushers, err := db.GetPushers(ctx, localpart) + if err != nil { + return nil, err + } + + devices := make([]*PusherDevice, 0, len(pushers)) + for _, pusher := range pushers { + var url, format string + data := pusher.Data + switch pusher.Kind { + case api.EmailKind: + url = "mailto:" + + case api.HTTPKind: + // TODO: The spec says only event_id_only is supported, + // but Sytests assume "" means "full notification". + fmtIface := pusher.Data["format"] + var ok bool + format, ok = fmtIface.(string) + if ok && format != "event_id_only" { + log.WithFields(log.Fields{ + "localpart": localpart, + "app_id": pusher.AppID, + }).Errorf("Only data.format event_id_only or empty is supported") + continue + } + + urlIface := pusher.Data["url"] + url, ok = urlIface.(string) + if !ok { + log.WithFields(log.Fields{ + "localpart": localpart, + "app_id": pusher.AppID, + }).Errorf("No data.url configured for HTTP Pusher") + continue + } + data = mapWithout(data, "url") + + default: + log.WithFields(log.Fields{ + "localpart": localpart, + "app_id": pusher.AppID, + "kind": pusher.Kind, + }).Errorf("Unhandled pusher kind") + continue + } + + devices = append(devices, &PusherDevice{ + Device: pushgateway.Device{ + AppID: pusher.AppID, + Data: data, + PushKey: pusher.PushKey, + PushKeyTS: pusher.PushKeyTS, + Tweaks: tweaks, + }, + Pusher: &pusher, + URL: url, + Format: format, + }) + } + + return devices, nil +} + +// mapWithout returns a shallow copy of the map, without the given +// key. Returns nil if the resulting map is empty. +func mapWithout(m map[string]interface{}, key string) map[string]interface{} { + ret := make(map[string]interface{}, len(m)) + for k, v := range m { + // The specification says we do not send "url". + if k == key { + continue + } + ret[k] = v + } + if len(ret) == 0 { + return nil + } + return ret +} diff --git a/userapi/util/notify.go b/userapi/util/notify.go new file mode 100644 index 000000000..ff206bd3c --- /dev/null +++ b/userapi/util/notify.go @@ -0,0 +1,76 @@ +package util + +import ( + "context" + "strings" + "time" + + "github.com/matrix-org/dendrite/internal/pushgateway" + "github.com/matrix-org/dendrite/userapi/storage" + "github.com/matrix-org/dendrite/userapi/storage/tables" + log "github.com/sirupsen/logrus" +) + +// NotifyUserCountsAsync sends notifications to a local user's +// notification destinations. Database lookups run synchronously, but +// a single goroutine is started when talking to the Push +// gateways. There is no way to know when the background goroutine has +// finished. +func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, localpart string, db storage.Database) error { + pusherDevices, err := GetPushDevices(ctx, localpart, nil, db) + if err != nil { + return err + } + + if len(pusherDevices) == 0 { + return nil + } + + userNumUnreadNotifs, err := db.GetNotificationCount(ctx, localpart, tables.AllNotifications) + if err != nil { + return err + } + + log.WithFields(log.Fields{ + "localpart": localpart, + "app_id0": pusherDevices[0].Device.AppID, + "pushkey": pusherDevices[0].Device.PushKey, + }).Tracef("Notifying HTTP push gateway about notification counts") + + // TODO: think about bounding this to one per user, and what + // ordering guarantees we must provide. + go func() { + // This background processing cannot be tied to a request. + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + // TODO: we could batch all devices with the same URL, but + // Sytest requires consumers/roomserver.go to do it + // one-by-one, so we do the same here. + for _, pusherDevice := range pusherDevices { + // TODO: support "email". + if !strings.HasPrefix(pusherDevice.URL, "http") { + continue + } + + req := pushgateway.NotifyRequest{ + Notification: pushgateway.Notification{ + Counts: &pushgateway.Counts{ + Unread: int(userNumUnreadNotifs), + }, + Devices: []*pushgateway.Device{&pusherDevice.Device}, + }, + } + if err := pgClient.Notify(ctx, pusherDevice.URL, &req, &pushgateway.NotifyResponse{}); err != nil { + log.WithFields(log.Fields{ + "localpart": localpart, + "app_id0": pusherDevice.Device.AppID, + "pushkey": pusherDevice.Device.PushKey, + }).WithError(err).Error("HTTP push gateway request failed") + return + } + } + }() + + return nil +} From bcc27e9e1884add0ba272eaf2bb4c45a03f76e85 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 3 Mar 2022 12:01:19 +0000 Subject: [PATCH 26/33] Only store notifications for users with pushers, de-parallelise `TestSessionCleanUp` for now --- clientapi/routing/register_test.go | 8 ++++---- userapi/consumers/syncapi_streamevent.go | 3 +++ 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/clientapi/routing/register_test.go b/clientapi/routing/register_test.go index 520f5ddeb..46bae6979 100644 --- a/clientapi/routing/register_test.go +++ b/clientapi/routing/register_test.go @@ -214,7 +214,7 @@ func TestSessionCleanUp(t *testing.T) { s := newSessionsDict() t.Run("session is cleaned up after a while", func(t *testing.T) { - t.Parallel() + // t.Parallel() dummySession := "helloWorld" // manually added, as s.addParams() would start the timer with the default timeout s.params[dummySession] = registerRequest{Username: "Testing"} @@ -226,7 +226,7 @@ func TestSessionCleanUp(t *testing.T) { }) t.Run("session is deleted, once the registration completed", func(t *testing.T) { - t.Parallel() + // t.Parallel() dummySession := "helloWorld2" s.startTimer(time.Minute, dummySession) s.deleteSession(dummySession) @@ -236,7 +236,7 @@ func TestSessionCleanUp(t *testing.T) { }) t.Run("session timer is restarted after second call", func(t *testing.T) { - t.Parallel() + // t.Parallel() dummySession := "helloWorld3" // the following will start a timer with the default timeout of 5min s.addParams(dummySession, registerRequest{Username: "Testing"}) @@ -260,4 +260,4 @@ func TestSessionCleanUp(t *testing.T) { t.Error("expected session to device to be delete") } }) -} \ No newline at end of file +} diff --git a/userapi/consumers/syncapi_streamevent.go b/userapi/consumers/syncapi_streamevent.go index 110813274..d86078cbf 100644 --- a/userapi/consumers/syncapi_streamevent.go +++ b/userapi/consumers/syncapi_streamevent.go @@ -139,6 +139,9 @@ func (s *OutputStreamEventConsumer) processMessage(ctx context.Context, event *g // removing it means we can send all notifications to // e.g. Element's Push gateway in one go. for _, mem := range members { + if p, err := s.db.GetPushers(ctx, mem.Localpart); err != nil || len(p) == 0 { + continue + } if err := s.notifyLocal(ctx, event, pos, mem, roomSize, roomName); err != nil { log.WithFields(log.Fields{ "localpart": mem.Localpart, From 6ed8cf0e07322fd78c41b22fc94c40dfa9c6d47a Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 3 Mar 2022 12:09:16 +0000 Subject: [PATCH 27/33] Handle `ErrNoRows` when sending read updates --- syncapi/consumers/clientapi.go | 5 +++-- syncapi/consumers/eduserver_receipts.go | 3 ++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/syncapi/consumers/clientapi.go b/syncapi/consumers/clientapi.go index f01afce6d..fcb7b5b1c 100644 --- a/syncapi/consumers/clientapi.go +++ b/syncapi/consumers/clientapi.go @@ -16,6 +16,7 @@ package consumers import ( "context" + "database/sql" "encoding/json" "fmt" @@ -138,12 +139,12 @@ func (s *OutputClientDataConsumer) sendReadUpdate(ctx context.Context, userID st var readPos types.StreamPosition var fullyReadPos types.StreamPosition if output.ReadMarker.Read != "" { - if _, readPos, err = s.db.PositionInTopology(ctx, output.ReadMarker.Read); err != nil { + if _, readPos, err = s.db.PositionInTopology(ctx, output.ReadMarker.Read); err != nil && err != sql.ErrNoRows { return fmt.Errorf("s.db.PositionInTopology (Read): %w", err) } } if output.ReadMarker.FullyRead != "" { - if _, fullyReadPos, err = s.db.PositionInTopology(ctx, output.ReadMarker.FullyRead); err != nil { + if _, fullyReadPos, err = s.db.PositionInTopology(ctx, output.ReadMarker.FullyRead); err != nil && err != sql.ErrNoRows { return fmt.Errorf("s.db.PositionInTopology (FullyRead): %w", err) } } diff --git a/syncapi/consumers/eduserver_receipts.go b/syncapi/consumers/eduserver_receipts.go index 881583449..4e4c61c67 100644 --- a/syncapi/consumers/eduserver_receipts.go +++ b/syncapi/consumers/eduserver_receipts.go @@ -16,6 +16,7 @@ package consumers import ( "context" + "database/sql" "encoding/json" "fmt" @@ -129,7 +130,7 @@ func (s *OutputReceiptEventConsumer) sendReadUpdate(ctx context.Context, output } var readPos types.StreamPosition if output.EventID != "" { - if _, readPos, err = s.db.PositionInTopology(ctx, output.EventID); err != nil { + if _, readPos, err = s.db.PositionInTopology(ctx, output.EventID); err != nil && err != sql.ErrNoRows { return fmt.Errorf("s.db.PositionInTopology (Read): %w", err) } } From 43ab0288f4975453a5532c310c7ed774f5bb398d Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 3 Mar 2022 12:37:12 +0000 Subject: [PATCH 28/33] Give more time to `TestSessionCleanUp` tests --- clientapi/routing/register_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/clientapi/routing/register_test.go b/clientapi/routing/register_test.go index 46bae6979..0507116f1 100644 --- a/clientapi/routing/register_test.go +++ b/clientapi/routing/register_test.go @@ -219,7 +219,7 @@ func TestSessionCleanUp(t *testing.T) { // manually added, as s.addParams() would start the timer with the default timeout s.params[dummySession] = registerRequest{Username: "Testing"} s.startTimer(time.Millisecond, dummySession) - time.Sleep(time.Millisecond * 2) + time.Sleep(time.Millisecond * 50) if data, ok := s.getParams(dummySession); ok { t.Errorf("expected session to be deleted: %+v", data) } @@ -246,7 +246,7 @@ func TestSessionCleanUp(t *testing.T) { s.getCompletedStages(dummySession) // reset the timer with a lower timeout s.startTimer(time.Millisecond, dummySession) - time.Sleep(time.Millisecond * 2) + time.Sleep(time.Millisecond * 50) if data, ok := s.getParams(dummySession); ok { t.Errorf("expected session to be deleted: %+v", data) } From b6b2455ecd11538445aa938870db3bc7201fdde1 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 3 Mar 2022 14:01:14 +0000 Subject: [PATCH 29/33] Test `/context/ with lazy_load_members filter works` should be OK now --- sytest-blacklist | 1 - 1 file changed, 1 deletion(-) diff --git a/sytest-blacklist b/sytest-blacklist index 7f518b21a..978dcde84 100644 --- a/sytest-blacklist +++ b/sytest-blacklist @@ -24,7 +24,6 @@ Local device key changes get to remote servers with correct prev_id # Flakey Local device key changes appear in /keys/changes -/context/ with lazy_load_members filter works # we don't support groups Remove group category From c44029f269d0b17102dbbf31e50669c7c93fe50e Mon Sep 17 00:00:00 2001 From: Till Faelligen Date: Thu, 3 Mar 2022 17:04:18 +0100 Subject: [PATCH 30/33] Don't open two connections for the userapi --- userapi/userapi.go | 7 ------- 1 file changed, 7 deletions(-) diff --git a/userapi/userapi.go b/userapi/userapi.go index 2382e9512..8dbc095fb 100644 --- a/userapi/userapi.go +++ b/userapi/userapi.go @@ -15,8 +15,6 @@ package userapi import ( - "time" - "github.com/gorilla/mux" "github.com/matrix-org/dendrite/internal/pushgateway" keyapi "github.com/matrix-org/dendrite/keyserver/api" @@ -46,11 +44,6 @@ func NewInternalAPI( appServices []config.ApplicationService, keyAPI keyapi.KeyInternalAPI, rsAPI rsapi.RoomserverInternalAPI, pgClient pushgateway.Client, ) api.UserInternalAPI { - db, err := storage.NewDatabase(&cfg.AccountDatabase, cfg.Matrix.ServerName, cfg.BCryptCost, int64(api.DefaultLoginTokenLifetime*time.Millisecond), api.DefaultLoginTokenLifetime) - if err != nil { - logrus.WithError(err).Panicf("failed to connect to device db") - } - js := jetstream.Prepare(&cfg.Matrix.JetStream) syncProducer := producers.NewSyncAPI( From 5592322e13d0bf741130425079e37979f7637564 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 3 Mar 2022 16:45:06 +0000 Subject: [PATCH 31/33] Clean old notifications regularly (#2244) * Clean old notifications regularly We'll keep highlights for a month and non-highlights for a day, to stop the `userapi_notifications` table from growing indefinitely. We'll also allow storing events even if no pushers are present, because apparently Element Web expects to work that way. * Fix the milliseconds * Use process context * Update sytest lists * Fix build issue --- sytest-blacklist | 1 + sytest-whitelist | 1 - userapi/consumers/syncapi_streamevent.go | 3 -- userapi/storage/interface.go | 1 + .../storage/postgres/notifications_table.go | 28 +++++++++++++++---- userapi/storage/shared/storage.go | 4 +++ .../storage/sqlite3/notifications_table.go | 28 +++++++++++++++---- userapi/storage/tables/interface.go | 1 + userapi/userapi.go | 12 ++++++++ 9 files changed, 63 insertions(+), 16 deletions(-) diff --git a/sytest-blacklist b/sytest-blacklist index 978dcde84..becc500ec 100644 --- a/sytest-blacklist +++ b/sytest-blacklist @@ -35,3 +35,4 @@ AS-ghosted users can use rooms themselves # Flakey, need additional investigation Messages that notify from another user increment notification_count Messages that highlight from another user increment unread highlight count +Notifications can be viewed with GET /notifications \ No newline at end of file diff --git a/sytest-whitelist b/sytest-whitelist index 602f86465..63d779bf1 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -638,7 +638,6 @@ Rooms with many users are correctly pushed Don't get pushed for rooms you've muted Rejected events are not pushed Test that rejected pushers are removed. -Notifications can be viewed with GET /notifications Trying to add push rule with no scope fails with 400 Trying to add push rule with invalid scope fails with 400 Forward extremities remain so even after the next events are populated as outliers diff --git a/userapi/consumers/syncapi_streamevent.go b/userapi/consumers/syncapi_streamevent.go index d86078cbf..110813274 100644 --- a/userapi/consumers/syncapi_streamevent.go +++ b/userapi/consumers/syncapi_streamevent.go @@ -139,9 +139,6 @@ func (s *OutputStreamEventConsumer) processMessage(ctx context.Context, event *g // removing it means we can send all notifications to // e.g. Element's Push gateway in one go. for _, mem := range members { - if p, err := s.db.GetPushers(ctx, mem.Localpart); err != nil || len(p) == 0 { - continue - } if err := s.notifyLocal(ctx, event, pos, mem, roomSize, roomName); err != nil { log.WithFields(log.Fields{ "localpart": mem.Localpart, diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index 6d22fea9d..777067109 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -97,6 +97,7 @@ type Database interface { GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error) GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error) + DeleteOldNotifications(ctx context.Context) error UpsertPusher(ctx context.Context, p api.Pusher, localpart string) error GetPushers(ctx context.Context, localpart string) ([]api.Pusher, error) diff --git a/userapi/storage/postgres/notifications_table.go b/userapi/storage/postgres/notifications_table.go index 7bcc0f9cd..a27c1125e 100644 --- a/userapi/storage/postgres/notifications_table.go +++ b/userapi/storage/postgres/notifications_table.go @@ -18,6 +18,7 @@ import ( "context" "database/sql" "encoding/json" + "time" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" @@ -28,12 +29,13 @@ import ( ) type notificationsStatements struct { - insertStmt *sql.Stmt - deleteUpToStmt *sql.Stmt - updateReadStmt *sql.Stmt - selectStmt *sql.Stmt - selectCountStmt *sql.Stmt - selectRoomCountsStmt *sql.Stmt + insertStmt *sql.Stmt + deleteUpToStmt *sql.Stmt + updateReadStmt *sql.Stmt + selectStmt *sql.Stmt + selectCountStmt *sql.Stmt + selectRoomCountsStmt *sql.Stmt + cleanNotificationsStmt *sql.Stmt } const notificationSchema = ` @@ -77,6 +79,10 @@ const selectRoomNotificationCountsSQL = "" + "SELECT COUNT(*), COUNT(*) FILTER (WHERE highlight) FROM userapi_notifications " + "WHERE localpart = $1 AND room_id = $2 AND NOT read" +const cleanNotificationsSQL = "" + + "DELETE FROM userapi_notifications WHERE" + + " (highlight = FALSE AND ts_ms < $1) OR (highlight = TRUE AND ts_ms < $2)" + func NewPostgresNotificationTable(db *sql.DB) (tables.NotificationTable, error) { s := ¬ificationsStatements{} _, err := db.Exec(notificationSchema) @@ -90,9 +96,19 @@ func NewPostgresNotificationTable(db *sql.DB) (tables.NotificationTable, error) {&s.selectStmt, selectNotificationSQL}, {&s.selectCountStmt, selectNotificationCountSQL}, {&s.selectRoomCountsStmt, selectRoomNotificationCountsSQL}, + {&s.cleanNotificationsStmt, cleanNotificationsSQL}, }.Prepare(db) } +func (s *notificationsStatements) Clean(ctx context.Context, txn *sql.Tx) error { + _, err := sqlutil.TxStmt(txn, s.cleanNotificationsStmt).ExecContext( + ctx, + time.Now().AddDate(0, 0, -1).UnixNano()/int64(time.Millisecond), // keep non-highlights for a day + time.Now().AddDate(0, -1, 0).UnixNano()/int64(time.Millisecond), // keep highlights for a month + ) + return err +} + // Insert inserts a notification into the database. func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos int64, highlight bool, n *api.Notification) error { roomID, tsMS := n.RoomID, n.TS diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index a58974b41..febf03221 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -705,6 +705,10 @@ func (d *Database) GetRoomNotificationCounts(ctx context.Context, localpart, roo return d.Notifications.SelectRoomCounts(ctx, nil, localpart, roomID) } +func (d *Database) DeleteOldNotifications(ctx context.Context) error { + return d.Notifications.Clean(ctx, nil) +} + func (d *Database) UpsertPusher( ctx context.Context, p api.Pusher, localpart string, ) error { diff --git a/userapi/storage/sqlite3/notifications_table.go b/userapi/storage/sqlite3/notifications_table.go index fcfb1aadc..df8260251 100644 --- a/userapi/storage/sqlite3/notifications_table.go +++ b/userapi/storage/sqlite3/notifications_table.go @@ -18,6 +18,7 @@ import ( "context" "database/sql" "encoding/json" + "time" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" @@ -28,12 +29,13 @@ import ( ) type notificationsStatements struct { - insertStmt *sql.Stmt - deleteUpToStmt *sql.Stmt - updateReadStmt *sql.Stmt - selectStmt *sql.Stmt - selectCountStmt *sql.Stmt - selectRoomCountsStmt *sql.Stmt + insertStmt *sql.Stmt + deleteUpToStmt *sql.Stmt + updateReadStmt *sql.Stmt + selectStmt *sql.Stmt + selectCountStmt *sql.Stmt + selectRoomCountsStmt *sql.Stmt + cleanNotificationsStmt *sql.Stmt } const notificationSchema = ` @@ -77,6 +79,10 @@ const selectRoomNotificationCountsSQL = "" + "SELECT COUNT(*), COUNT(*) FILTER (WHERE highlight) FROM userapi_notifications " + "WHERE localpart = $1 AND room_id = $2 AND NOT read" +const cleanNotificationsSQL = "" + + "DELETE FROM userapi_notifications WHERE" + + " (highlight = FALSE AND ts_ms < $1) OR (highlight = TRUE AND ts_ms < $2)" + func NewSQLiteNotificationTable(db *sql.DB) (tables.NotificationTable, error) { s := ¬ificationsStatements{} _, err := db.Exec(notificationSchema) @@ -90,9 +96,19 @@ func NewSQLiteNotificationTable(db *sql.DB) (tables.NotificationTable, error) { {&s.selectStmt, selectNotificationSQL}, {&s.selectCountStmt, selectNotificationCountSQL}, {&s.selectRoomCountsStmt, selectRoomNotificationCountsSQL}, + {&s.cleanNotificationsStmt, cleanNotificationsSQL}, }.Prepare(db) } +func (s *notificationsStatements) Clean(ctx context.Context, txn *sql.Tx) error { + _, err := sqlutil.TxStmt(txn, s.cleanNotificationsStmt).ExecContext( + ctx, + time.Now().AddDate(0, 0, -1).UnixNano()/int64(time.Millisecond), // keep non-highlights for a day + time.Now().AddDate(0, -1, 0).UnixNano()/int64(time.Millisecond), // keep highlights for a month + ) + return err +} + // Insert inserts a notification into the database. func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos int64, highlight bool, n *api.Notification) error { roomID, tsMS := n.RoomID, n.TS diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go index 815e51193..99c907b85 100644 --- a/userapi/storage/tables/interface.go +++ b/userapi/storage/tables/interface.go @@ -103,6 +103,7 @@ type PusherTable interface { } type NotificationTable interface { + Clean(ctx context.Context, txn *sql.Tx) error Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos int64, highlight bool, n *api.Notification) error DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64) (affected bool, _ error) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64, v bool) (affected bool, _ error) diff --git a/userapi/userapi.go b/userapi/userapi.go index 8dbc095fb..251a4edad 100644 --- a/userapi/userapi.go +++ b/userapi/userapi.go @@ -15,6 +15,8 @@ package userapi import ( + "time" + "github.com/gorilla/mux" "github.com/matrix-org/dendrite/internal/pushgateway" keyapi "github.com/matrix-org/dendrite/keyserver/api" @@ -79,5 +81,15 @@ func NewInternalAPI( logrus.WithError(err).Panic("failed to start user API streamed event consumer") } + var cleanOldNotifs func() + cleanOldNotifs = func() { + logrus.Infof("Cleaning old notifications") + if err := db.DeleteOldNotifications(base.Context()); err != nil { + logrus.WithError(err).Error("Failed to clean old notifications") + } + time.AfterFunc(time.Hour, cleanOldNotifs) + } + time.AfterFunc(time.Minute, cleanOldNotifs) + return userAPI } From 72022a6ecf80177a28883d7016790ea41646d396 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 3 Mar 2022 17:58:24 +0000 Subject: [PATCH 32/33] Return 404 if event given to `/context` was not found (#2245) --- syncapi/routing/context.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/syncapi/routing/context.go b/syncapi/routing/context.go index 591139717..d07fc0c64 100644 --- a/syncapi/routing/context.go +++ b/syncapi/routing/context.go @@ -17,6 +17,7 @@ package routing import ( "database/sql" "encoding/json" + "fmt" "net/http" "strconv" @@ -102,6 +103,12 @@ func Context( id, requestedEvent, err := syncDB.SelectContextEvent(ctx, roomID, eventID) if err != nil { + if err == sql.ErrNoRows { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: jsonerror.NotFound(fmt.Sprintf("Event %s not found", eventID)), + } + } logrus.WithError(err).WithField("eventID", eventID).Error("unable to find requested event") return jsonerror.InternalServerError() } From f75169c3535488ac584c4e270bc9e6b0cb69e539 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 3 Mar 2022 18:24:14 +0000 Subject: [PATCH 33/33] Send profile updates asynchronously (#2246) --- clientapi/routing/profile.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/clientapi/routing/profile.go b/clientapi/routing/profile.go index 717cbda75..4aab03714 100644 --- a/clientapi/routing/profile.go +++ b/clientapi/routing/profile.go @@ -286,7 +286,7 @@ func SetDisplayName( return jsonerror.InternalServerError() } - if err := api.SendEvents(req.Context(), rsAPI, api.KindNew, events, cfg.Matrix.ServerName, cfg.Matrix.ServerName, nil, false); err != nil { + if err := api.SendEvents(req.Context(), rsAPI, api.KindNew, events, cfg.Matrix.ServerName, cfg.Matrix.ServerName, nil, true); err != nil { util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed") return jsonerror.InternalServerError() }