From 0843bd776e9156b7ae62b504a5c7e8c8b26ff476 Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Mon, 24 Oct 2022 07:10:50 +0200 Subject: [PATCH 01/10] Fix wrong config key --- setup/config/config_global.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup/config/config_global.go b/setup/config/config_global.go index 2efae0d5a..784893d24 100644 --- a/setup/config/config_global.go +++ b/setup/config/config_global.go @@ -170,7 +170,7 @@ type ServerNotices struct { // The displayname to be used when sending notices DisplayName string `yaml:"display_name"` // The avatar of this user - AvatarURL string `yaml:"avatar"` + AvatarURL string `yaml:"avatar_url"` // The roomname to be used when creating messages RoomName string `yaml:"room_name"` } From a553fe770575b027809fc0a0f81e6709e6d068df Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Mon, 24 Oct 2022 10:07:50 +0100 Subject: [PATCH 02/10] Fix slow querying of cross-signing signatures --- clientapi/routing/keys.go | 6 +++++- keyserver/internal/internal.go | 7 ++++--- keyserver/storage/postgres/cross_signing_sigs_table.go | 2 +- keyserver/storage/sqlite3/cross_signing_sigs_table.go | 4 ++-- 4 files changed, 12 insertions(+), 7 deletions(-) diff --git a/clientapi/routing/keys.go b/clientapi/routing/keys.go index 5c3681382..0c12b1117 100644 --- a/clientapi/routing/keys.go +++ b/clientapi/routing/keys.go @@ -99,7 +99,11 @@ func (r *queryKeysRequest) GetTimeout() time.Duration { if r.Timeout == 0 { return 10 * time.Second } - return time.Duration(r.Timeout) * time.Millisecond + timeout := time.Duration(r.Timeout) * time.Millisecond + if timeout > time.Second*20 { + timeout = time.Second * 20 + } + return timeout } func QueryKeys(req *http.Request, keyAPI api.ClientKeyAPI, device *userapi.Device) util.JSONResponse { diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index 49ef03054..ff0968b27 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -257,9 +257,6 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques res.UserSigningKeys = make(map[string]gomatrixserverlib.CrossSigningKey) res.Failures = make(map[string]interface{}) - // get cross-signing keys from the database - a.crossSigningKeysFromDatabase(ctx, req, res) - // make a map from domain to device keys domainToDeviceKeys := make(map[string]map[string][]string) domainToCrossSigningKeys := make(map[string]map[string]struct{}) @@ -336,6 +333,10 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques a.queryRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys, domainToCrossSigningKeys) } + // Now that we've done the potentially expensive work of asking the federation, + // try filling the cross-signing keys from the database that we know about. + a.crossSigningKeysFromDatabase(ctx, req, res) + // Finally, append signatures that we know about // TODO: This is horrible because we need to round-trip the signature from // JSON, add the signatures and marshal it again, for some reason? diff --git a/keyserver/storage/postgres/cross_signing_sigs_table.go b/keyserver/storage/postgres/cross_signing_sigs_table.go index 8b2a865b9..4536b7d80 100644 --- a/keyserver/storage/postgres/cross_signing_sigs_table.go +++ b/keyserver/storage/postgres/cross_signing_sigs_table.go @@ -42,7 +42,7 @@ CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_s const selectCrossSigningSigsForTargetSQL = "" + "SELECT origin_user_id, origin_key_id, signature FROM keyserver_cross_signing_sigs" + - " WHERE (origin_user_id = $1 OR origin_user_id = target_user_id) AND target_user_id = $2 AND target_key_id = $3" + " WHERE (origin_user_id = $1 OR origin_user_id = $2) AND target_user_id = $2 AND target_key_id = $3" const upsertCrossSigningSigsForTargetSQL = "" + "INSERT INTO keyserver_cross_signing_sigs (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)" + diff --git a/keyserver/storage/sqlite3/cross_signing_sigs_table.go b/keyserver/storage/sqlite3/cross_signing_sigs_table.go index ea431151e..7a153e8fb 100644 --- a/keyserver/storage/sqlite3/cross_signing_sigs_table.go +++ b/keyserver/storage/sqlite3/cross_signing_sigs_table.go @@ -42,7 +42,7 @@ CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_s const selectCrossSigningSigsForTargetSQL = "" + "SELECT origin_user_id, origin_key_id, signature FROM keyserver_cross_signing_sigs" + - " WHERE (origin_user_id = $1 OR origin_user_id = target_user_id) AND target_user_id = $2 AND target_key_id = $3" + " WHERE (origin_user_id = $1 OR origin_user_id = $2) AND target_user_id = $3 AND target_key_id = $4" const upsertCrossSigningSigsForTargetSQL = "" + "INSERT OR REPLACE INTO keyserver_cross_signing_sigs (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)" + @@ -85,7 +85,7 @@ func NewSqliteCrossSigningSigsTable(db *sql.DB) (tables.CrossSigningSigs, error) func (s *crossSigningSigsStatements) SelectCrossSigningSigsForTarget( ctx context.Context, txn *sql.Tx, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, ) (r types.CrossSigningSigMap, err error) { - rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningSigsForTargetStmt).QueryContext(ctx, originUserID, targetUserID, targetKeyID) + rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningSigsForTargetStmt).QueryContext(ctx, originUserID, targetUserID, targetUserID, targetKeyID) if err != nil { return nil, err } From 7506e3303e78e47a7bea454de1e726c6f6640d2f Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Mon, 24 Oct 2022 17:03:04 +0200 Subject: [PATCH 03/10] Get messages from before user left the room (#2824) This is going to make `Can get rooms/{roomId}/messages for a departed room (SPEC-216)` pass, since we now only grep events from before the user left the room. --- syncapi/routing/messages.go | 33 +++++++++++++++++++++++++-------- sytest-whitelist | 4 +++- 2 files changed, 28 insertions(+), 9 deletions(-) diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go index 8f3ed3f5b..86cf8e736 100644 --- a/syncapi/routing/messages.go +++ b/syncapi/routing/messages.go @@ -83,18 +83,18 @@ func OnIncomingMessagesRequest( defer sqlutil.EndTransactionWithCheck(snapshot, &succeeded, &err) // check if the user has already forgotten about this room - isForgotten, roomExists, err := checkIsRoomForgotten(req.Context(), roomID, device.UserID, rsAPI) + membershipResp, err := getMembershipForUser(req.Context(), roomID, device.UserID, rsAPI) if err != nil { return jsonerror.InternalServerError() } - if !roomExists { + if !membershipResp.RoomExists { return util.JSONResponse{ Code: http.StatusForbidden, JSON: jsonerror.Forbidden("room does not exist"), } } - if isForgotten { + if membershipResp.IsRoomForgotten { return util.JSONResponse{ Code: http.StatusForbidden, JSON: jsonerror.Forbidden("user already forgot about this room"), @@ -195,6 +195,20 @@ func OnIncomingMessagesRequest( } } + // If the user already left the room, grep events from before that + if membershipResp.Membership == gomatrixserverlib.Leave { + var token types.TopologyToken + token, err = snapshot.EventPositionInTopology(req.Context(), membershipResp.EventID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + } + } + if backwardOrdering { + from = token + } + } + mReq := messagesReq{ ctx: req.Context(), db: db, @@ -283,17 +297,16 @@ func (m *messagesResp) applyLazyLoadMembers( } } -func checkIsRoomForgotten(ctx context.Context, roomID, userID string, rsAPI api.SyncRoomserverAPI) (forgotten bool, exists bool, err error) { +func getMembershipForUser(ctx context.Context, roomID, userID string, rsAPI api.SyncRoomserverAPI) (resp api.QueryMembershipForUserResponse, err error) { req := api.QueryMembershipForUserRequest{ RoomID: roomID, UserID: userID, } - resp := api.QueryMembershipForUserResponse{} if err := rsAPI.QueryMembershipForUser(ctx, &req, &resp); err != nil { - return false, false, err + return api.QueryMembershipForUserResponse{}, err } - return resp.IsRoomForgotten, resp.RoomExists, nil + return resp, nil } // retrieveEvents retrieves events from the local database for a request on @@ -313,7 +326,11 @@ func (r *messagesReq) retrieveEvents() ( } var events []*gomatrixserverlib.HeaderedEvent - util.GetLogger(r.ctx).WithField("start", start).WithField("end", end).Infof("Fetched %d events locally", len(streamEvents)) + util.GetLogger(r.ctx).WithFields(logrus.Fields{ + "start": r.from, + "end": r.to, + "backwards": r.backwardOrdering, + }).Infof("Fetched %d events locally", len(streamEvents)) // There can be two reasons for streamEvents to be empty: either we've // reached the oldest event in the room (or the most recent one, depending diff --git a/sytest-whitelist b/sytest-whitelist index 1387838f7..e92ae6495 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -752,4 +752,6 @@ When user joins a room the state is included in the next sync When user joins a room the state is included in a gapped sync Messages that notify from another user increment notification_count Messages that highlight from another user increment unread highlight count -Notifications can be viewed with GET /notifications \ No newline at end of file +Notifications can be viewed with GET /notifications +Can get rooms/{roomId}/messages for a departed room (SPEC-216) +Local device key changes appear in /keys/changes \ No newline at end of file From 313cb3fd193397536b069d819f8346d625d82af8 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Tue, 25 Oct 2022 12:39:10 +0200 Subject: [PATCH 04/10] Filter `/members`, return members at given point (#2827) Makes the tests ``` Can get rooms/{roomId}/members at a given point Can filter rooms/{roomId}/members ``` pass, by moving `/members` and `/joined_members` to the SyncAPI. --- clientapi/routing/joined_rooms.go | 52 +++++++++ clientapi/routing/routing.go | 20 ---- docs/caddy/polylith/Caddyfile | 2 +- docs/hiawatha/polylith-sample.conf | 4 +- docs/nginx/polylith-sample.conf | 4 +- {clientapi => syncapi}/routing/memberships.go | 100 ++++++++++-------- syncapi/routing/routing.go | 33 ++++++ syncapi/storage/interface.go | 5 + syncapi/storage/postgres/memberships_table.go | 35 +++++- syncapi/storage/shared/storage_consumer.go | 8 ++ syncapi/storage/sqlite3/memberships_table.go | 32 +++++- syncapi/storage/tables/interface.go | 5 + syncapi/streams/stream_pdu.go | 8 +- syncapi/types/types.go | 3 + sytest-whitelist | 4 +- 15 files changed, 243 insertions(+), 72 deletions(-) create mode 100644 clientapi/routing/joined_rooms.go rename {clientapi => syncapi}/routing/memberships.go (55%) diff --git a/clientapi/routing/joined_rooms.go b/clientapi/routing/joined_rooms.go new file mode 100644 index 000000000..4bb353ea9 --- /dev/null +++ b/clientapi/routing/joined_rooms.go @@ -0,0 +1,52 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "net/http" + + "github.com/matrix-org/util" + + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/roomserver/api" + userapi "github.com/matrix-org/dendrite/userapi/api" +) + +type getJoinedRoomsResponse struct { + JoinedRooms []string `json:"joined_rooms"` +} + +func GetJoinedRooms( + req *http.Request, + device *userapi.Device, + rsAPI api.ClientRoomserverAPI, +) util.JSONResponse { + var res api.QueryRoomsForUserResponse + err := rsAPI.QueryRoomsForUser(req.Context(), &api.QueryRoomsForUserRequest{ + UserID: device.UserID, + WantMembership: "join", + }, &res) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("QueryRoomsForUser failed") + return jsonerror.InternalServerError() + } + if res.RoomIDs == nil { + res.RoomIDs = []string{} + } + return util.JSONResponse{ + Code: http.StatusOK, + JSON: getJoinedRoomsResponse{res.RoomIDs}, + } +} diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 4ca8e59c5..e0e3e33d4 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -950,26 +950,6 @@ func Setup( }), ).Methods(http.MethodPost, http.MethodOptions) - v3mux.Handle("/rooms/{roomID}/members", - httputil.MakeAuthAPI("rooms_members", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) - if err != nil { - return util.ErrorResponse(err) - } - return GetMemberships(req, device, vars["roomID"], false, cfg, rsAPI) - }), - ).Methods(http.MethodGet, http.MethodOptions) - - v3mux.Handle("/rooms/{roomID}/joined_members", - httputil.MakeAuthAPI("rooms_members", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) - if err != nil { - return util.ErrorResponse(err) - } - return GetMemberships(req, device, vars["roomID"], true, cfg, rsAPI) - }), - ).Methods(http.MethodGet, http.MethodOptions) - v3mux.Handle("/rooms/{roomID}/read_markers", httputil.MakeAuthAPI("rooms_read_markers", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req, device); r != nil { diff --git a/docs/caddy/polylith/Caddyfile b/docs/caddy/polylith/Caddyfile index 8aeb9317f..c2d81b49b 100644 --- a/docs/caddy/polylith/Caddyfile +++ b/docs/caddy/polylith/Caddyfile @@ -74,7 +74,7 @@ matrix.example.com { # Change the end of each reverse_proxy line to the correct # address for your various services. @sync_api { - path_regexp /_matrix/client/.*?/(sync|user/.*?/filter/?.*|keys/changes|rooms/.*?/(messages|context/.*?|relations/.*?|event/.*?))$ + path_regexp /_matrix/client/.*?/(sync|user/.*?/filter/?.*|keys/changes|rooms/.*?/(messages|.*?_?members|context/.*?|relations/.*?|event/.*?))$ } reverse_proxy @sync_api sync_api:8073 diff --git a/docs/hiawatha/polylith-sample.conf b/docs/hiawatha/polylith-sample.conf index 0093fdcf2..eb1dd4f9a 100644 --- a/docs/hiawatha/polylith-sample.conf +++ b/docs/hiawatha/polylith-sample.conf @@ -23,8 +23,10 @@ VirtualHost { # /_matrix/client/.*/rooms/{roomId}/relations/{eventID} # /_matrix/client/.*/rooms/{roomId}/relations/{eventID}/{relType} # /_matrix/client/.*/rooms/{roomId}/relations/{eventID}/{relType}/{eventType} + # /_matrix/client/.*/rooms/{roomId}/members + # /_matrix/client/.*/rooms/{roomId}/joined_members # to sync_api - ReverseProxy = /_matrix/client/.*?/(sync|user/.*?/filter/?.*|keys/changes|rooms/.*?/(messages|context/.*?|relations/.*?|event/.*?))$ http://localhost:8073 600 + ReverseProxy = /_matrix/client/.*?/(sync|user/.*?/filter/?.*|keys/changes|rooms/.*?/(messages|.*?_?members|context/.*?|relations/.*?|event/.*?))$ http://localhost:8073 600 ReverseProxy = /_matrix/client http://localhost:8071 600 ReverseProxy = /_matrix/federation http://localhost:8072 600 ReverseProxy = /_matrix/key http://localhost:8072 600 diff --git a/docs/nginx/polylith-sample.conf b/docs/nginx/polylith-sample.conf index 6e81eb5f2..0ad24509a 100644 --- a/docs/nginx/polylith-sample.conf +++ b/docs/nginx/polylith-sample.conf @@ -33,8 +33,10 @@ server { # /_matrix/client/.*/rooms/{roomId}/relations/{eventID} # /_matrix/client/.*/rooms/{roomId}/relations/{eventID}/{relType} # /_matrix/client/.*/rooms/{roomId}/relations/{eventID}/{relType}/{eventType} + # /_matrix/client/.*/rooms/{roomId}/members + # /_matrix/client/.*/rooms/{roomId}/joined_members # to sync_api - location ~ /_matrix/client/.*?/(sync|user/.*?/filter/?.*|keys/changes|rooms/.*?/(messages|context/.*?|relations/.*?|event/.*?))$ { + location ~ /_matrix/client/.*?/(sync|user/.*?/filter/?.*|keys/changes|rooms/.*?/(messages|.*?_?members|context/.*?|relations/.*?|event/.*?))$ { proxy_pass http://sync_api:8073; } diff --git a/clientapi/routing/memberships.go b/syncapi/routing/memberships.go similarity index 55% rename from clientapi/routing/memberships.go rename to syncapi/routing/memberships.go index 9bdd8a4f4..b4e342251 100644 --- a/clientapi/routing/memberships.go +++ b/syncapi/routing/memberships.go @@ -18,22 +18,20 @@ import ( "encoding/json" "net/http" - "github.com/matrix-org/dendrite/clientapi/jsonerror" - "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/dendrite/setup/config" - userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" + + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/syncapi/storage" + "github.com/matrix-org/dendrite/syncapi/types" + userapi "github.com/matrix-org/dendrite/userapi/api" ) type getMembershipResponse struct { Chunk []gomatrixserverlib.ClientEvent `json:"chunk"` } -type getJoinedRoomsResponse struct { - JoinedRooms []string `json:"joined_rooms"` -} - // https://matrix.org/docs/spec/client_server/r0.6.0#get-matrix-client-r0-rooms-roomid-joined-members type getJoinedMembersResponse struct { Joined map[string]joinedMember `json:"joined"` @@ -51,19 +49,22 @@ type databaseJoinedMember struct { AvatarURL string `json:"avatar_url"` } -// GetMemberships implements GET /rooms/{roomId}/members +// GetMemberships implements +// +// GET /rooms/{roomId}/members +// GET /rooms/{roomId}/joined_members func GetMemberships( - req *http.Request, device *userapi.Device, roomID string, joinedOnly bool, - _ *config.ClientAPI, - rsAPI api.ClientRoomserverAPI, + req *http.Request, device *userapi.Device, roomID string, + syncDB storage.Database, rsAPI api.SyncRoomserverAPI, + joinedOnly bool, membership, notMembership *string, at string, ) util.JSONResponse { - queryReq := api.QueryMembershipsForRoomRequest{ - JoinedOnly: joinedOnly, - RoomID: roomID, - Sender: device.UserID, + queryReq := api.QueryMembershipForUserRequest{ + RoomID: roomID, + UserID: device.UserID, } - var queryRes api.QueryMembershipsForRoomResponse - if err := rsAPI.QueryMembershipsForRoom(req.Context(), &queryReq, &queryRes); err != nil { + + var queryRes api.QueryMembershipForUserResponse + if err := rsAPI.QueryMembershipForUser(req.Context(), &queryReq, &queryRes); err != nil { util.GetLogger(req.Context()).WithError(err).Error("rsAPI.QueryMembershipsForRoom failed") return jsonerror.InternalServerError() } @@ -75,16 +76,48 @@ func GetMemberships( } } + db, err := syncDB.NewDatabaseSnapshot(req.Context()) + if err != nil { + return jsonerror.InternalServerError() + } + + atToken, err := types.NewTopologyTokenFromString(at) + if err != nil { + if queryRes.HasBeenInRoom && !queryRes.IsInRoom { + // If you have left the room then this will be the members of the room when you left. + atToken, err = db.EventPositionInTopology(req.Context(), queryRes.EventID) + } else { + // If you are joined to the room then this will be the current members of the room. + atToken, err = db.MaxTopologicalPosition(req.Context(), roomID) + } + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("unable to get 'atToken'") + return jsonerror.InternalServerError() + } + } + + eventIDs, err := db.SelectMemberships(req.Context(), roomID, atToken, membership, notMembership) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("db.SelectMemberships failed") + return jsonerror.InternalServerError() + } + + result, err := db.Events(req.Context(), eventIDs) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("db.Events failed") + return jsonerror.InternalServerError() + } + if joinedOnly { var res getJoinedMembersResponse res.Joined = make(map[string]joinedMember) - for _, ev := range queryRes.JoinEvents { + for _, ev := range result { var content databaseJoinedMember - if err := json.Unmarshal(ev.Content, &content); err != nil { + if err := json.Unmarshal(ev.Content(), &content); err != nil { util.GetLogger(req.Context()).WithError(err).Error("failed to unmarshal event content") return jsonerror.InternalServerError() } - res.Joined[ev.Sender] = joinedMember(content) + res.Joined[ev.Sender()] = joinedMember(content) } return util.JSONResponse{ Code: http.StatusOK, @@ -93,29 +126,6 @@ func GetMemberships( } return util.JSONResponse{ Code: http.StatusOK, - JSON: getMembershipResponse{queryRes.JoinEvents}, - } -} - -func GetJoinedRooms( - req *http.Request, - device *userapi.Device, - rsAPI api.ClientRoomserverAPI, -) util.JSONResponse { - var res api.QueryRoomsForUserResponse - err := rsAPI.QueryRoomsForUser(req.Context(), &api.QueryRoomsForUserRequest{ - UserID: device.UserID, - WantMembership: "join", - }, &res) - if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("QueryRoomsForUser failed") - return jsonerror.InternalServerError() - } - if res.RoomIDs == nil { - res.RoomIDs = []string{} - } - return util.JSONResponse{ - Code: http.StatusOK, - JSON: getJoinedRoomsResponse{res.RoomIDs}, + JSON: getMembershipResponse{gomatrixserverlib.HeaderedToClientEvents(result, gomatrixserverlib.FormatSync)}, } } diff --git a/syncapi/routing/routing.go b/syncapi/routing/routing.go index 71fa93c1e..bc3ad2384 100644 --- a/syncapi/routing/routing.go +++ b/syncapi/routing/routing.go @@ -172,4 +172,37 @@ func Setup( return Search(req, device, syncDB, fts, nextBatch) }), ).Methods(http.MethodPost, http.MethodOptions) + + v3mux.Handle("/rooms/{roomID}/members", + httputil.MakeAuthAPI("rooms_members", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + var membership, notMembership *string + if req.URL.Query().Has("membership") { + m := req.URL.Query().Get("membership") + membership = &m + } + if req.URL.Query().Has("not_membership") { + m := req.URL.Query().Get("not_membership") + notMembership = &m + } + + at := req.URL.Query().Get("at") + return GetMemberships(req, device, vars["roomID"], syncDB, rsAPI, false, membership, notMembership, at) + }), + ).Methods(http.MethodGet, http.MethodOptions) + + v3mux.Handle("/rooms/{roomID}/joined_members", + httputil.MakeAuthAPI("rooms_members", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + at := req.URL.Query().Get("at") + membership := gomatrixserverlib.Join + return GetMemberships(req, device, vars["roomID"], syncDB, rsAPI, true, &membership, nil, at) + }), + ).Methods(http.MethodGet, http.MethodOptions) } diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index 02d45f801..af4fce44e 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -178,6 +178,11 @@ type Database interface { ReIndex(ctx context.Context, limit, afterID int64) (map[int64]gomatrixserverlib.HeaderedEvent, error) UpdateRelations(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error RedactRelations(ctx context.Context, roomID, redactedEventID string) error + SelectMemberships( + ctx context.Context, + roomID string, pos types.TopologyToken, + membership, notMembership *string, + ) (eventIDs []string, err error) } type Presence interface { diff --git a/syncapi/storage/postgres/memberships_table.go b/syncapi/storage/postgres/memberships_table.go index 939d6b3f5..b555e8456 100644 --- a/syncapi/storage/postgres/memberships_table.go +++ b/syncapi/storage/postgres/memberships_table.go @@ -20,11 +20,12 @@ import ( "fmt" "github.com/lib/pq" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrixserverlib" ) // The memberships table is designed to track the last time that @@ -69,11 +70,20 @@ const selectHeroesSQL = "" + const selectMembershipBeforeSQL = "" + "SELECT membership, topological_pos FROM syncapi_memberships WHERE room_id = $1 and user_id = $2 AND topological_pos <= $3 ORDER BY topological_pos DESC LIMIT 1" +const selectMembersSQL = ` +SELECT event_id FROM ( + SELECT DISTINCT ON (room_id, user_id) room_id, user_id, event_id, membership FROM syncapi_memberships WHERE room_id = $1 AND topological_pos <= $2 ORDER BY room_id, user_id, stream_pos DESC +) t +WHERE ($3::text IS NULL OR t.membership = $3) + AND ($4::text IS NULL OR t.membership <> $4) +` + type membershipsStatements struct { upsertMembershipStmt *sql.Stmt selectMembershipCountStmt *sql.Stmt selectHeroesStmt *sql.Stmt selectMembershipForUserStmt *sql.Stmt + selectMembersStmt *sql.Stmt } func NewPostgresMembershipsTable(db *sql.DB) (tables.Memberships, error) { @@ -87,6 +97,7 @@ func NewPostgresMembershipsTable(db *sql.DB) (tables.Memberships, error) { {&s.selectMembershipCountStmt, selectMembershipCountSQL}, {&s.selectHeroesStmt, selectHeroesSQL}, {&s.selectMembershipForUserStmt, selectMembershipBeforeSQL}, + {&s.selectMembersStmt, selectMembersSQL}, }.Prepare(db) } @@ -154,3 +165,25 @@ func (s *membershipsStatements) SelectMembershipForUser( } return membership, topologyPos, nil } + +func (s *membershipsStatements) SelectMemberships( + ctx context.Context, txn *sql.Tx, + roomID string, pos types.TopologyToken, + membership, notMembership *string, +) (eventIDs []string, err error) { + stmt := sqlutil.TxStmt(txn, s.selectMembersStmt) + rows, err := stmt.QueryContext(ctx, roomID, pos.Depth, membership, notMembership) + if err != nil { + return + } + var ( + eventID string + ) + for rows.Next() { + if err = rows.Scan(&eventID); err != nil { + return + } + eventIDs = append(eventIDs, eventID) + } + return eventIDs, rows.Err() +} diff --git a/syncapi/storage/shared/storage_consumer.go b/syncapi/storage/shared/storage_consumer.go index bf12203db..23f53d11f 100644 --- a/syncapi/storage/shared/storage_consumer.go +++ b/syncapi/storage/shared/storage_consumer.go @@ -617,3 +617,11 @@ func (d *Database) RedactRelations(ctx context.Context, roomID, redactedEventID return d.Relations.DeleteRelation(ctx, txn, roomID, redactedEventID) }) } + +func (d *Database) SelectMemberships( + ctx context.Context, + roomID string, pos types.TopologyToken, + membership, notMembership *string, +) (eventIDs []string, err error) { + return d.Memberships.SelectMemberships(ctx, nil, roomID, pos, membership, notMembership) +} diff --git a/syncapi/storage/sqlite3/memberships_table.go b/syncapi/storage/sqlite3/memberships_table.go index 0c966fca0..7e54fac17 100644 --- a/syncapi/storage/sqlite3/memberships_table.go +++ b/syncapi/storage/sqlite3/memberships_table.go @@ -20,11 +20,12 @@ import ( "fmt" "strings" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrixserverlib" ) // The memberships table is designed to track the last time that @@ -69,12 +70,20 @@ const selectHeroesSQL = "" + const selectMembershipBeforeSQL = "" + "SELECT membership, topological_pos FROM syncapi_memberships WHERE room_id = $1 and user_id = $2 AND topological_pos <= $3 ORDER BY topological_pos DESC LIMIT 1" +const selectMembersSQL = ` +SELECT event_id FROM + ( SELECT event_id, membership FROM syncapi_memberships WHERE room_id = $1 AND topological_pos <= $2 GROUP BY user_id HAVING(max(stream_pos))) t + WHERE ($3 IS NULL OR t.membership = $3) + AND ($4 IS NULL OR t.membership <> $4) +` + type membershipsStatements struct { db *sql.DB upsertMembershipStmt *sql.Stmt selectMembershipCountStmt *sql.Stmt //selectHeroesStmt *sql.Stmt - prepared at runtime due to variadic selectMembershipForUserStmt *sql.Stmt + selectMembersStmt *sql.Stmt } func NewSqliteMembershipsTable(db *sql.DB) (tables.Memberships, error) { @@ -89,6 +98,7 @@ func NewSqliteMembershipsTable(db *sql.DB) (tables.Memberships, error) { {&s.upsertMembershipStmt, upsertMembershipSQL}, {&s.selectMembershipCountStmt, selectMembershipCountSQL}, {&s.selectMembershipForUserStmt, selectMembershipBeforeSQL}, + {&s.selectMembersStmt, selectMembersSQL}, // {&s.selectHeroesStmt, selectHeroesSQL}, - prepared at runtime due to variadic }.Prepare(db) } @@ -170,3 +180,23 @@ func (s *membershipsStatements) SelectMembershipForUser( } return membership, topologyPos, nil } + +func (s *membershipsStatements) SelectMemberships( + ctx context.Context, txn *sql.Tx, + roomID string, pos types.TopologyToken, + membership, notMembership *string, +) (eventIDs []string, err error) { + stmt := sqlutil.TxStmt(txn, s.selectMembersStmt) + rows, err := stmt.QueryContext(ctx, roomID, pos.Depth, membership, notMembership) + if err != nil { + return + } + var eventID string + for rows.Next() { + if err = rows.Scan(&eventID); err != nil { + return + } + eventIDs = append(eventIDs, eventID) + } + return eventIDs, rows.Err() +} diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index e48c050dd..2c4f04ec2 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -187,6 +187,11 @@ type Memberships interface { SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition) (count int, err error) SelectHeroes(ctx context.Context, txn *sql.Tx, roomID, userID string, memberships []string) (heroes []string, err error) SelectMembershipForUser(ctx context.Context, txn *sql.Tx, roomID, userID string, pos int64) (membership string, topologicalPos int, err error) + SelectMemberships( + ctx context.Context, txn *sql.Tx, + roomID string, pos types.TopologyToken, + membership, notMembership *string, + ) (eventIDs []string, err error) } type NotificationData interface { diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go index 9ec2b61cd..707dbe8dc 100644 --- a/syncapi/streams/stream_pdu.go +++ b/syncapi/streams/stream_pdu.go @@ -473,7 +473,13 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( var prevBatch *types.TopologyToken if len(recentStreamEvents) > 0 { var backwardTopologyPos, backwardStreamPos types.StreamPosition - backwardTopologyPos, backwardStreamPos, err = snapshot.PositionInTopology(ctx, recentStreamEvents[0].EventID()) + event := recentStreamEvents[0] + // If this is the beginning of the room, we can't go back further. We're going to return + // the TopologyToken from the last event instead. (Synapse returns the /sync next_Batch) + if event.Type() == gomatrixserverlib.MRoomCreate && event.StateKeyEquals("") { + event = recentStreamEvents[len(recentStreamEvents)-1] + } + backwardTopologyPos, backwardStreamPos, err = snapshot.PositionInTopology(ctx, event.EventID()) if err != nil { return } diff --git a/syncapi/types/types.go b/syncapi/types/types.go index 57ce7b6ff..295187acc 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -234,6 +234,9 @@ func (t *TopologyToken) StreamToken() StreamingToken { } func (t TopologyToken) String() string { + if t.Depth <= 0 && t.PDUPosition <= 0 { + return "" + } return fmt.Sprintf("t%d_%d", t.Depth, t.PDUPosition) } diff --git a/sytest-whitelist b/sytest-whitelist index e92ae6495..e5e405af6 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -754,4 +754,6 @@ Messages that notify from another user increment notification_count Messages that highlight from another user increment unread highlight count Notifications can be viewed with GET /notifications Can get rooms/{roomId}/messages for a departed room (SPEC-216) -Local device key changes appear in /keys/changes \ No newline at end of file +Local device key changes appear in /keys/changes +Can get rooms/{roomId}/members at a given point +Can filter rooms/{roomId}/members \ No newline at end of file From db6a214b046c83d8cacd00608aa464bd040c4997 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 25 Oct 2022 12:28:34 +0100 Subject: [PATCH 05/10] Prettify unit test output --- .github/workflows/dendrite.yml | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/.github/workflows/dendrite.yml b/.github/workflows/dendrite.yml index f8019b3ea..a8271b675 100644 --- a/.github/workflows/dendrite.yml +++ b/.github/workflows/dendrite.yml @@ -109,6 +109,11 @@ jobs: uses: actions/setup-go@v3 with: go-version: ${{ matrix.go }} + - name: Set up gotestfmt + uses: gotesttools/gotestfmt-action@v2 + with: + # Optional: pass GITHUB_TOKEN to avoid rate limiting. + token: ${{ secrets.GITHUB_TOKEN }} - uses: actions/cache@v3 with: path: | @@ -117,7 +122,7 @@ jobs: key: ${{ runner.os }}-go${{ matrix.go }}-test-${{ hashFiles('**/go.sum') }} restore-keys: | ${{ runner.os }}-go${{ matrix.go }}-test- - - run: go test ./... + - run: go test -json -v ./... 2>&1 | gotestfmt env: POSTGRES_HOST: localhost POSTGRES_USER: postgres From 8b7bf5e7d7dbb7d87848156c27666fc2353efeba Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Tue, 25 Oct 2022 15:00:52 +0200 Subject: [PATCH 06/10] Return forbidden if not a member anymore (fix #2802) --- syncapi/routing/memberships.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/syncapi/routing/memberships.go b/syncapi/routing/memberships.go index b4e342251..c9acc5d2b 100644 --- a/syncapi/routing/memberships.go +++ b/syncapi/routing/memberships.go @@ -109,6 +109,12 @@ func GetMemberships( } if joinedOnly { + if !queryRes.IsInRoom { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("You aren't a member of the room and weren't previously a member of the room."), + } + } var res getJoinedMembersResponse res.Joined = make(map[string]joinedMember) for _, ev := range result { From c62ac3d6ad5c60f5f28a0f50bba50f7cbc2436ce Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Tue, 25 Oct 2022 15:15:24 +0200 Subject: [PATCH 07/10] Fix `Current state appears in timeline in private history with many messages after` (#2830) The problem was that we weren't getting enough recent events, as most of them were removed by the history visibility filter. Now we're getting all events between the given input range and re-slice the returned values after applying history visibility. --- syncapi/streams/stream_pdu.go | 18 ++++++++++-------- sytest-whitelist | 3 ++- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go index 707dbe8dc..90cf8ce53 100644 --- a/syncapi/streams/stream_pdu.go +++ b/syncapi/streams/stream_pdu.go @@ -227,14 +227,10 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( stateFilter *gomatrixserverlib.StateFilter, req *types.SyncRequest, ) (types.StreamPosition, error) { - if delta.MembershipPos > 0 && delta.Membership == gomatrixserverlib.Leave { - // make sure we don't leak recent events after the leave event. - // TODO: History visibility makes this somewhat complex to handle correctly. For example: - // TODO: This doesn't work for join -> leave in a single /sync request (see events prior to join). - // TODO: This will fail on join -> leave -> sensitive msg -> join -> leave - // in a single /sync request - // This is all "okay" assuming history_visibility == "shared" which it is by default. - r.To = delta.MembershipPos + + originalLimit := eventFilter.Limit + if r.Backwards { + eventFilter.Limit = int(r.From - r.To) } recentStreamEvents, limited, err := snapshot.RecentEvents( ctx, delta.RoomID, r, @@ -303,6 +299,12 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( logrus.WithError(err).Error("unable to apply history visibility filter") } + if r.Backwards && len(events) > originalLimit { + // We're going backwards and the events are ordered chronologically, so take the last `limit` events + events = events[len(events)-originalLimit:] + limited = true + } + if len(delta.StateEvents) > 0 { updateLatestPosition(delta.StateEvents[len(delta.StateEvents)-1].EventID()) } diff --git a/sytest-whitelist b/sytest-whitelist index e5e405af6..60610929a 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -756,4 +756,5 @@ Notifications can be viewed with GET /notifications Can get rooms/{roomId}/messages for a departed room (SPEC-216) Local device key changes appear in /keys/changes Can get rooms/{roomId}/members at a given point -Can filter rooms/{roomId}/members \ No newline at end of file +Can filter rooms/{roomId}/members +Current state appears in timeline in private history with many messages after \ No newline at end of file From 2a4c7f45b37a9bcd1a37d42b0668e0c3dfb29762 Mon Sep 17 00:00:00 2001 From: Neboer <43609792+Neboer@users.noreply.github.com> Date: Wed, 26 Oct 2022 17:04:53 +0800 Subject: [PATCH 08/10] Add support for config "auto_join_rooms" (#2823) Add support for config "auto_join_rooms". Now new accounts can join the rooms in config file automatically. ### Pull Request Checklist * [x] I have justified why this PR doesn't need tests. * [x] Pull request includes a [sign off below using a legally identifiable name](https://matrix-org.github.io/dendrite/development/contributing#sign-off) Signed-off-by: `Rubin Poster ` --- dendrite-sample.monolith.yaml | 8 +++++++ dendrite-sample.polylith.yaml | 8 +++++++ roomserver/api/api.go | 1 + setup/config/config_userapi.go | 4 ++++ userapi/internal/api.go | 42 ++++++++++++++++++++++++++++++++++ userapi/userapi.go | 1 + 6 files changed, 64 insertions(+) diff --git a/dendrite-sample.monolith.yaml b/dendrite-sample.monolith.yaml index eadb74a2a..5195c29bc 100644 --- a/dendrite-sample.monolith.yaml +++ b/dendrite-sample.monolith.yaml @@ -310,6 +310,14 @@ user_api: # The default lifetime is 3600000ms (60 minutes). # openid_token_lifetime_ms: 3600000 + # Users who register on this homeserver will automatically be joined to the rooms listed under "auto_join_rooms" option. + # By default, any room aliases included in this list will be created as a publicly joinable room + # when the first user registers for the homeserver. If the room already exists, + # make certain it is a publicly joinable room, i.e. the join rule of the room must be set to 'public'. + # As Spaces are just rooms under the hood, Space aliases may also be used. + auto_join_rooms: + # - "#main:matrix.org" + # Configuration for Opentracing. # See https://github.com/matrix-org/dendrite/tree/master/docs/tracing for information on # how this works and how to set it up. diff --git a/dendrite-sample.polylith.yaml b/dendrite-sample.polylith.yaml index aa7e0cc38..bbbe16fdc 100644 --- a/dendrite-sample.polylith.yaml +++ b/dendrite-sample.polylith.yaml @@ -375,6 +375,14 @@ user_api: # The default lifetime is 3600000ms (60 minutes). # openid_token_lifetime_ms: 3600000 + # Users who register on this homeserver will automatically be joined to the rooms listed under "auto_join_rooms" option. + # By default, any room aliases included in this list will be created as a publicly joinable room + # when the first user registers for the homeserver. If the room already exists, + # make certain it is a publicly joinable room, i.e. the join rule of the room must be set to 'public'. + # As Spaces are just rooms under the hood, Space aliases may also be used. + auto_join_rooms: + # - "#main:matrix.org" + # Configuration for Opentracing. # See https://github.com/matrix-org/dendrite/tree/master/docs/tracing for information on # how this works and how to set it up. diff --git a/roomserver/api/api.go b/roomserver/api/api.go index baf63aa31..403bbe8be 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -167,6 +167,7 @@ type UserRoomserverAPI interface { QueryCurrentState(ctx context.Context, req *QueryCurrentStateRequest, res *QueryCurrentStateResponse) error QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error PerformAdminEvacuateUser(ctx context.Context, req *PerformAdminEvacuateUserRequest, res *PerformAdminEvacuateUserResponse) error + PerformJoin(ctx context.Context, req *PerformJoinRequest, res *PerformJoinResponse) error } type FederationRoomserverAPI interface { diff --git a/setup/config/config_userapi.go b/setup/config/config_userapi.go index 97a6d738b..f8ad41d93 100644 --- a/setup/config/config_userapi.go +++ b/setup/config/config_userapi.go @@ -19,6 +19,10 @@ type UserAPI struct { // The Account database stores the login details and account information // for local users. It is accessed by the UserAPI. AccountDatabase DatabaseOptions `yaml:"account_database,omitempty"` + + // Users who register on this homeserver will automatically + // be joined to the rooms listed under this option. + AutoJoinRooms []string `yaml:"auto_join_rooms"` } const DefaultOpenIDTokenLifetimeMS = 3600000 // 60 minutes diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 63044eedb..7b94b3da7 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -54,6 +54,7 @@ type UserInternalAPI struct { KeyAPI keyapi.UserKeyAPI RSAPI rsapi.UserRoomserverAPI PgClient pushgateway.Client + Cfg *config.UserAPI } func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error { @@ -130,6 +131,45 @@ func (a *UserInternalAPI) setFullyRead(ctx context.Context, req *api.InputAccoun return nil } +func postRegisterJoinRooms(cfg *config.UserAPI, acc *api.Account, rsAPI rsapi.UserRoomserverAPI) { + // POST register behaviour: check if the user is a normal user. + // If the user is a normal user, add user to room specified in the configuration "auto_join_rooms". + if acc.AccountType != api.AccountTypeAppService && acc.AppServiceID == "" { + for room := range cfg.AutoJoinRooms { + userID := userutil.MakeUserID(acc.Localpart, cfg.Matrix.ServerName) + err := addUserToRoom(context.Background(), rsAPI, cfg.AutoJoinRooms[room], acc.Localpart, userID) + if err != nil { + logrus.WithFields(logrus.Fields{ + "user_id": userID, + "room": cfg.AutoJoinRooms[room], + }).WithError(err).Errorf("user failed to auto-join room") + } + } + } +} + +// Add user to a room. This function currently working for auto_join_rooms config, +// which can add a newly registered user to a specified room. +func addUserToRoom( + ctx context.Context, + rsAPI rsapi.UserRoomserverAPI, + roomID string, + username string, + userID string, +) error { + addGroupContent := make(map[string]interface{}) + // This make sure the user's username can be displayed correctly. + // Because the newly-registered user doesn't have an avatar, the avatar_url is not needed. + addGroupContent["displayname"] = username + joinReq := rsapi.PerformJoinRequest{ + RoomIDOrAlias: roomID, + UserID: userID, + Content: addGroupContent, + } + joinRes := rsapi.PerformJoinResponse{} + return rsAPI.PerformJoin(ctx, &joinReq, &joinRes) +} + func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error { acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType) if err != nil { @@ -174,6 +214,8 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P return err } + postRegisterJoinRooms(a.Cfg, acc, a.RSAPI) + res.AccountCreated = true res.Account = acc return nil diff --git a/userapi/userapi.go b/userapi/userapi.go index d26b4e19a..c077248e2 100644 --- a/userapi/userapi.go +++ b/userapi/userapi.go @@ -82,6 +82,7 @@ func NewInternalAPI( RSAPI: rsAPI, DisableTLSValidation: cfg.PushGatewayDisableTLSValidation, PgClient: pgClient, + Cfg: cfg, } receiptConsumer := consumers.NewOutputReceiptEventConsumer( From f6dea712d2e9c71f6ebe61f90e45a142852432e8 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 26 Oct 2022 12:59:19 +0100 Subject: [PATCH 09/10] Initial support for multiple server names (#2829) This PR is the first step towards virtual hosting by laying the groundwork for multiple server names being configured. --- clientapi/auth/password.go | 2 +- clientapi/routing/admin.go | 4 +- clientapi/routing/createroom.go | 25 ++++++-- clientapi/routing/directory.go | 4 +- clientapi/routing/directory_public.go | 3 +- clientapi/routing/login.go | 6 +- clientapi/routing/membership.go | 9 +-- clientapi/routing/openid.go | 2 +- clientapi/routing/profile.go | 34 +++++++++-- clientapi/routing/redaction.go | 3 +- clientapi/routing/register.go | 2 +- clientapi/routing/sendevent.go | 5 +- clientapi/threepid/invites.go | 2 +- clientapi/userutil/userutil.go | 13 +++-- clientapi/userutil/userutil_test.go | 25 ++++++-- federationapi/federationapi.go | 4 +- federationapi/federationapi_keys_test.go | 2 +- federationapi/federationapi_test.go | 1 + federationapi/internal/keys.go | 2 +- federationapi/internal/perform.go | 2 +- federationapi/producers/syncapi.go | 5 +- federationapi/queue/queue_test.go | 2 +- federationapi/routing/routing.go | 50 ++++++++-------- federationapi/storage/postgres/storage.go | 4 +- federationapi/storage/shared/storage.go | 4 +- federationapi/storage/sqlite3/storage.go | 4 +- federationapi/storage/storage.go | 6 +- federationapi/storage/storage_test.go | 2 +- go.mod | 2 +- go.sum | 4 +- roomserver/internal/perform/perform_admin.go | 11 +++- roomserver/internal/perform/perform_invite.go | 4 +- roomserver/internal/perform/perform_join.go | 23 +++++--- roomserver/internal/perform/perform_leave.go | 6 +- roomserver/internal/perform/perform_peek.go | 6 +- roomserver/internal/perform/perform_unpeek.go | 2 +- .../internal/perform/perform_upgrade.go | 37 +++++++----- setup/config/config_global.go | 15 +++++ setup/mscs/msc2836/msc2836.go | 2 +- setup/mscs/msc2946/msc2946.go | 2 +- test/testrig/base.go | 2 + userapi/api/api.go | 31 ++++++++-- userapi/internal/api.go | 58 +++++++++++++------ userapi/internal/api_logintoken.go | 8 +-- userapi/userapi.go | 2 +- userapi/userapi_test.go | 4 +- 46 files changed, 291 insertions(+), 155 deletions(-) diff --git a/clientapi/auth/password.go b/clientapi/auth/password.go index 890b18183..700a72f5d 100644 --- a/clientapi/auth/password.go +++ b/clientapi/auth/password.go @@ -74,7 +74,7 @@ func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, JSON: jsonerror.BadJSON("A password must be supplied."), } } - localpart, err := userutil.ParseUsernameParam(username, &t.Config.Matrix.ServerName) + localpart, _, err := userutil.ParseUsernameParam(username, t.Config.Matrix) if err != nil { return nil, &util.JSONResponse{ Code: http.StatusUnauthorized, diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index 89c269f1a..69bca13be 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -70,7 +70,7 @@ func AdminEvacuateUser(req *http.Request, cfg *config.ClientAPI, device *userapi if err != nil { return util.MessageResponse(http.StatusBadRequest, err.Error()) } - if domain != cfg.Matrix.ServerName { + if !cfg.Matrix.IsLocalServerName(domain) { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: jsonerror.MissingArgument("User ID must belong to this server."), @@ -169,7 +169,7 @@ func AdminMarkAsStale(req *http.Request, cfg *config.ClientAPI, keyAPI api.Clien if err != nil { return util.MessageResponse(http.StatusBadRequest, err.Error()) } - if domain == cfg.Matrix.ServerName { + if cfg.Matrix.IsLocalServerName(domain) { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: jsonerror.InvalidParam("Can not mark local device list as stale"), diff --git a/clientapi/routing/createroom.go b/clientapi/routing/createroom.go index 3e837c864..eefe8e24b 100644 --- a/clientapi/routing/createroom.go +++ b/clientapi/routing/createroom.go @@ -169,9 +169,21 @@ func createRoom( asAPI appserviceAPI.AppServiceInternalAPI, evTime time.Time, ) util.JSONResponse { + _, userDomain, err := gomatrixserverlib.SplitID('@', device.UserID) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed") + return jsonerror.InternalServerError() + } + if !cfg.Matrix.IsLocalServerName(userDomain) { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden(fmt.Sprintf("User domain %q not configured locally", userDomain)), + } + } + // TODO (#267): Check room ID doesn't clash with an existing one, and we // probably shouldn't be using pseudo-random strings, maybe GUIDs? - roomID := fmt.Sprintf("!%s:%s", util.RandomString(16), cfg.Matrix.ServerName) + roomID := fmt.Sprintf("!%s:%s", util.RandomString(16), userDomain) logger := util.GetLogger(ctx) userID := device.UserID @@ -314,7 +326,7 @@ func createRoom( var roomAlias string if r.RoomAliasName != "" { - roomAlias = fmt.Sprintf("#%s:%s", r.RoomAliasName, cfg.Matrix.ServerName) + roomAlias = fmt.Sprintf("#%s:%s", r.RoomAliasName, userDomain) // check it's free TODO: This races but is better than nothing hasAliasReq := roomserverAPI.GetRoomIDForAliasRequest{ Alias: roomAlias, @@ -436,7 +448,7 @@ func createRoom( builder.PrevEvents = []gomatrixserverlib.EventReference{builtEvents[i-1].EventReference()} } var ev *gomatrixserverlib.Event - ev, err = buildEvent(&builder, &authEvents, cfg, evTime, roomVersion) + ev, err = buildEvent(&builder, userDomain, &authEvents, cfg, evTime, roomVersion) if err != nil { util.GetLogger(ctx).WithError(err).Error("buildEvent failed") return jsonerror.InternalServerError() @@ -461,7 +473,7 @@ func createRoom( inputs = append(inputs, roomserverAPI.InputRoomEvent{ Kind: roomserverAPI.KindNew, Event: event, - Origin: cfg.Matrix.ServerName, + Origin: userDomain, SendAsServer: roomserverAPI.DoNotSendToOtherServers, }) } @@ -548,7 +560,7 @@ func createRoom( Event: event, InviteRoomState: inviteStrippedState, RoomVersion: event.RoomVersion, - SendAsServer: string(cfg.Matrix.ServerName), + SendAsServer: string(userDomain), }, &inviteRes); err != nil { util.GetLogger(ctx).WithError(err).Error("PerformInvite failed") return util.JSONResponse{ @@ -591,6 +603,7 @@ func createRoom( // buildEvent fills out auth_events for the builder then builds the event func buildEvent( builder *gomatrixserverlib.EventBuilder, + serverName gomatrixserverlib.ServerName, provider gomatrixserverlib.AuthEventProvider, cfg *config.ClientAPI, evTime time.Time, @@ -606,7 +619,7 @@ func buildEvent( } builder.AuthEvents = refs event, err := builder.Build( - evTime, cfg.Matrix.ServerName, cfg.Matrix.KeyID, + evTime, serverName, cfg.Matrix.KeyID, cfg.Matrix.PrivateKey, roomVersion, ) if err != nil { diff --git a/clientapi/routing/directory.go b/clientapi/routing/directory.go index 836d9e152..33bc63d18 100644 --- a/clientapi/routing/directory.go +++ b/clientapi/routing/directory.go @@ -75,7 +75,7 @@ func DirectoryRoom( if res.RoomID == "" { // If we don't know it locally, do a federation query. // But don't send the query to ourselves. - if domain != cfg.Matrix.ServerName { + if !cfg.Matrix.IsLocalServerName(domain) { fedRes, fedErr := federation.LookupRoomAlias(req.Context(), domain, roomAlias) if fedErr != nil { // TODO: Return 502 if the remote server errored. @@ -127,7 +127,7 @@ func SetLocalAlias( } } - if domain != cfg.Matrix.ServerName { + if !cfg.Matrix.IsLocalServerName(domain) { return util.JSONResponse{ Code: http.StatusForbidden, JSON: jsonerror.Forbidden("Alias must be on local homeserver"), diff --git a/clientapi/routing/directory_public.go b/clientapi/routing/directory_public.go index 8ddb3267a..4ebf2295a 100644 --- a/clientapi/routing/directory_public.go +++ b/clientapi/routing/directory_public.go @@ -62,8 +62,7 @@ func GetPostPublicRooms( } serverName := gomatrixserverlib.ServerName(request.Server) - - if serverName != "" && serverName != cfg.Matrix.ServerName { + if serverName != "" && !cfg.Matrix.IsLocalServerName(serverName) { res, err := federation.GetPublicRoomsFiltered( req.Context(), serverName, int(request.Limit), request.Since, diff --git a/clientapi/routing/login.go b/clientapi/routing/login.go index 6017b5840..7f5a8c4f8 100644 --- a/clientapi/routing/login.go +++ b/clientapi/routing/login.go @@ -68,7 +68,7 @@ func Login( return *authErr } // make a device/access token - authErr2 := completeAuth(req.Context(), cfg.Matrix.ServerName, userAPI, login, req.RemoteAddr, req.UserAgent()) + authErr2 := completeAuth(req.Context(), cfg.Matrix, userAPI, login, req.RemoteAddr, req.UserAgent()) cleanup(req.Context(), &authErr2) return authErr2 } @@ -79,7 +79,7 @@ func Login( } func completeAuth( - ctx context.Context, serverName gomatrixserverlib.ServerName, userAPI userapi.ClientUserAPI, login *auth.Login, + ctx context.Context, cfg *config.Global, userAPI userapi.ClientUserAPI, login *auth.Login, ipAddr, userAgent string, ) util.JSONResponse { token, err := auth.GenerateAccessToken() @@ -88,7 +88,7 @@ func completeAuth( return jsonerror.InternalServerError() } - localpart, err := userutil.ParseUsernameParam(login.Username(), &serverName) + localpart, serverName, err := userutil.ParseUsernameParam(login.Username(), cfg) if err != nil { util.GetLogger(ctx).WithError(err).Error("auth.ParseUsernameParam failed") return jsonerror.InternalServerError() diff --git a/clientapi/routing/membership.go b/clientapi/routing/membership.go index 77f627eb2..94ba17a02 100644 --- a/clientapi/routing/membership.go +++ b/clientapi/routing/membership.go @@ -105,12 +105,13 @@ func sendMembership(ctx context.Context, profileAPI userapi.ClientUserAPI, devic return jsonerror.InternalServerError() } + serverName := device.UserDomain() if err = roomserverAPI.SendEvents( ctx, rsAPI, roomserverAPI.KindNew, []*gomatrixserverlib.HeaderedEvent{event.Event.Headered(roomVer)}, - cfg.Matrix.ServerName, - cfg.Matrix.ServerName, + serverName, + serverName, nil, false, ); err != nil { @@ -271,7 +272,7 @@ func sendInvite( Event: event, InviteRoomState: nil, // ask the roomserver to draw up invite room state for us RoomVersion: event.RoomVersion, - SendAsServer: string(cfg.Matrix.ServerName), + SendAsServer: string(device.UserDomain()), }, &inviteRes); err != nil { util.GetLogger(ctx).WithError(err).Error("PerformInvite failed") return util.JSONResponse{ @@ -341,7 +342,7 @@ func loadProfile( } var profile *authtypes.Profile - if serverName == cfg.Matrix.ServerName { + if cfg.Matrix.IsLocalServerName(serverName) { profile, err = appserviceAPI.RetrieveUserProfile(ctx, userID, asAPI, profileAPI) } else { profile = &authtypes.Profile{} diff --git a/clientapi/routing/openid.go b/clientapi/routing/openid.go index cfb440bea..8e9be7889 100644 --- a/clientapi/routing/openid.go +++ b/clientapi/routing/openid.go @@ -63,7 +63,7 @@ func CreateOpenIDToken( JSON: openIDTokenResponse{ AccessToken: response.Token.Token, TokenType: "Bearer", - MatrixServerName: string(cfg.Matrix.ServerName), + MatrixServerName: string(device.UserDomain()), ExpiresIn: response.Token.ExpiresAtMS / 1000, // convert ms to s }, } diff --git a/clientapi/routing/profile.go b/clientapi/routing/profile.go index c9647eb1b..4d9e1f8a5 100644 --- a/clientapi/routing/profile.go +++ b/clientapi/routing/profile.go @@ -113,12 +113,19 @@ func SetAvatarURL( } } - localpart, _, err := gomatrixserverlib.SplitID('@', userID) + localpart, domain, err := gomatrixserverlib.SplitID('@', userID) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") return jsonerror.InternalServerError() } + if !cfg.Matrix.IsLocalServerName(domain) { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("userID does not belong to a locally configured domain"), + } + } + evTime, err := httputil.ParseTSParam(req) if err != nil { return util.JSONResponse{ @@ -129,8 +136,9 @@ func SetAvatarURL( setRes := &userapi.PerformSetAvatarURLResponse{} if err = profileAPI.SetAvatarURL(req.Context(), &userapi.PerformSetAvatarURLRequest{ - Localpart: localpart, - AvatarURL: r.AvatarURL, + Localpart: localpart, + ServerName: domain, + AvatarURL: r.AvatarURL, }, setRes); err != nil { util.GetLogger(req.Context()).WithError(err).Error("profileAPI.SetAvatarURL failed") return jsonerror.InternalServerError() @@ -204,12 +212,19 @@ func SetDisplayName( } } - localpart, _, err := gomatrixserverlib.SplitID('@', userID) + localpart, domain, err := gomatrixserverlib.SplitID('@', userID) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") return jsonerror.InternalServerError() } + if !cfg.Matrix.IsLocalServerName(domain) { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("userID does not belong to a locally configured domain"), + } + } + evTime, err := httputil.ParseTSParam(req) if err != nil { return util.JSONResponse{ @@ -221,6 +236,7 @@ func SetDisplayName( profileRes := &userapi.PerformUpdateDisplayNameResponse{} err = profileAPI.SetDisplayName(req.Context(), &userapi.PerformUpdateDisplayNameRequest{ Localpart: localpart, + ServerName: domain, DisplayName: r.DisplayName, }, profileRes) if err != nil { @@ -261,6 +277,12 @@ func updateProfile( return jsonerror.InternalServerError(), err } + _, domain, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed") + return jsonerror.InternalServerError(), err + } + events, err := buildMembershipEvents( ctx, res.RoomIDs, *profile, userID, cfg, evTime, rsAPI, ) @@ -276,7 +298,7 @@ func updateProfile( return jsonerror.InternalServerError(), e } - if err := api.SendEvents(ctx, rsAPI, api.KindNew, events, cfg.Matrix.ServerName, cfg.Matrix.ServerName, nil, true); err != nil { + if err := api.SendEvents(ctx, rsAPI, api.KindNew, events, domain, domain, nil, true); err != nil { util.GetLogger(ctx).WithError(err).Error("SendEvents failed") return jsonerror.InternalServerError(), err } @@ -298,7 +320,7 @@ func getProfile( return nil, err } - if domain != cfg.Matrix.ServerName { + if !cfg.Matrix.IsLocalServerName(domain) { profile, fedErr := federation.LookupProfile(ctx, domain, userID, "") if fedErr != nil { if x, ok := fedErr.(gomatrix.HTTPError); ok { diff --git a/clientapi/routing/redaction.go b/clientapi/routing/redaction.go index a0f3b1152..778a02fd4 100644 --- a/clientapi/routing/redaction.go +++ b/clientapi/routing/redaction.go @@ -131,7 +131,8 @@ func SendRedaction( JSON: jsonerror.NotFound("Room does not exist"), } } - if err = roomserverAPI.SendEvents(context.Background(), rsAPI, roomserverAPI.KindNew, []*gomatrixserverlib.HeaderedEvent{e}, cfg.Matrix.ServerName, cfg.Matrix.ServerName, nil, false); err != nil { + domain := device.UserDomain() + if err = roomserverAPI.SendEvents(context.Background(), rsAPI, roomserverAPI.KindNew, []*gomatrixserverlib.HeaderedEvent{e}, domain, domain, nil, false); err != nil { util.GetLogger(req.Context()).WithError(err).Errorf("failed to SendEvents") return jsonerror.InternalServerError() } diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index 0bda1e488..698d185b4 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -412,7 +412,7 @@ func UserIDIsWithinApplicationServiceNamespace( return false } - if domain != cfg.Matrix.ServerName { + if !cfg.Matrix.IsLocalServerName(domain) { return false } diff --git a/clientapi/routing/sendevent.go b/clientapi/routing/sendevent.go index 114e9088d..bb66cf6fc 100644 --- a/clientapi/routing/sendevent.go +++ b/clientapi/routing/sendevent.go @@ -94,6 +94,7 @@ func SendEvent( // create a mutex for the specific user in the specific room // this avoids a situation where events that are received in quick succession are sent to the roomserver in a jumbled order userID := device.UserID + domain := device.UserDomain() mutex, _ := userRoomSendMutexes.LoadOrStore(roomID+userID, &sync.Mutex{}) mutex.(*sync.Mutex).Lock() defer mutex.(*sync.Mutex).Unlock() @@ -185,8 +186,8 @@ func SendEvent( []*gomatrixserverlib.HeaderedEvent{ e.Headered(verRes.RoomVersion), }, - cfg.Matrix.ServerName, - cfg.Matrix.ServerName, + domain, + domain, txnAndSessionID, false, ); err != nil { diff --git a/clientapi/threepid/invites.go b/clientapi/threepid/invites.go index 9670fecad..99fb8171d 100644 --- a/clientapi/threepid/invites.go +++ b/clientapi/threepid/invites.go @@ -215,7 +215,7 @@ func queryIDServerStoreInvite( } var profile *authtypes.Profile - if serverName == cfg.Matrix.ServerName { + if cfg.Matrix.IsLocalServerName(serverName) { res := &userapi.QueryProfileResponse{} err = userAPI.QueryProfile(ctx, &userapi.QueryProfileRequest{UserID: device.UserID}, res) if err != nil { diff --git a/clientapi/userutil/userutil.go b/clientapi/userutil/userutil.go index 7e909ffad..9be1e9b31 100644 --- a/clientapi/userutil/userutil.go +++ b/clientapi/userutil/userutil.go @@ -17,6 +17,7 @@ import ( "fmt" "strings" + "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" ) @@ -24,23 +25,23 @@ import ( // usernameParam can either be a user ID or just the localpart/username. // If serverName is passed, it is verified against the domain obtained from usernameParam (if present) // Returns error in case of invalid usernameParam. -func ParseUsernameParam(usernameParam string, expectedServerName *gomatrixserverlib.ServerName) (string, error) { +func ParseUsernameParam(usernameParam string, cfg *config.Global) (string, gomatrixserverlib.ServerName, error) { localpart := usernameParam if strings.HasPrefix(usernameParam, "@") { lp, domain, err := gomatrixserverlib.SplitID('@', usernameParam) if err != nil { - return "", errors.New("invalid username") + return "", "", errors.New("invalid username") } - if expectedServerName != nil && domain != *expectedServerName { - return "", errors.New("user ID does not belong to this server") + if !cfg.IsLocalServerName(domain) { + return "", "", errors.New("user ID does not belong to this server") } - localpart = lp + return lp, domain, nil } - return localpart, nil + return localpart, cfg.ServerName, nil } // MakeUserID generates user ID from localpart & server name diff --git a/clientapi/userutil/userutil_test.go b/clientapi/userutil/userutil_test.go index 2628642fb..ccd6647b2 100644 --- a/clientapi/userutil/userutil_test.go +++ b/clientapi/userutil/userutil_test.go @@ -15,6 +15,7 @@ package userutil import ( "testing" + "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" ) @@ -28,7 +29,11 @@ var ( // TestGoodUserID checks that correct localpart is returned for a valid user ID. func TestGoodUserID(t *testing.T) { - lp, err := ParseUsernameParam(goodUserID, &serverName) + cfg := &config.Global{ + ServerName: serverName, + } + + lp, _, err := ParseUsernameParam(goodUserID, cfg) if err != nil { t.Error("User ID Parsing failed for ", goodUserID, " with error: ", err.Error()) @@ -41,7 +46,11 @@ func TestGoodUserID(t *testing.T) { // TestWithLocalpartOnly checks that localpart is returned when usernameParam contains only localpart. func TestWithLocalpartOnly(t *testing.T) { - lp, err := ParseUsernameParam(localpart, &serverName) + cfg := &config.Global{ + ServerName: serverName, + } + + lp, _, err := ParseUsernameParam(localpart, cfg) if err != nil { t.Error("User ID Parsing failed for ", localpart, " with error: ", err.Error()) @@ -54,7 +63,11 @@ func TestWithLocalpartOnly(t *testing.T) { // TestIncorrectDomain checks for error when there's server name mismatch. func TestIncorrectDomain(t *testing.T) { - _, err := ParseUsernameParam(goodUserID, &invalidServerName) + cfg := &config.Global{ + ServerName: invalidServerName, + } + + _, _, err := ParseUsernameParam(goodUserID, cfg) if err == nil { t.Error("Invalid Domain should return an error") @@ -63,7 +76,11 @@ func TestIncorrectDomain(t *testing.T) { // TestBadUserID checks that ParseUsernameParam fails for invalid user ID func TestBadUserID(t *testing.T) { - _, err := ParseUsernameParam(badUserID, &serverName) + cfg := &config.Global{ + ServerName: serverName, + } + + _, _, err := ParseUsernameParam(badUserID, cfg) if err == nil { t.Error("Illegal User ID should return an error") diff --git a/federationapi/federationapi.go b/federationapi/federationapi.go index f6dace702..a58cba1b1 100644 --- a/federationapi/federationapi.go +++ b/federationapi/federationapi.go @@ -69,7 +69,7 @@ func AddPublicRoutes( TopicPresenceEvent: cfg.Matrix.JetStream.Prefixed(jetstream.OutputPresenceEvent), TopicDeviceListUpdate: cfg.Matrix.JetStream.Prefixed(jetstream.InputDeviceListUpdate), TopicSigningKeyUpdate: cfg.Matrix.JetStream.Prefixed(jetstream.InputSigningKeyUpdate), - ServerName: cfg.Matrix.ServerName, + Config: cfg, UserAPI: userAPI, } @@ -107,7 +107,7 @@ func NewInternalAPI( ) api.FederationInternalAPI { cfg := &base.Cfg.FederationAPI - federationDB, err := storage.NewDatabase(base, &cfg.Database, base.Caches, base.Cfg.Global.ServerName) + federationDB, err := storage.NewDatabase(base, &cfg.Database, base.Caches, base.Cfg.Global.IsLocalServerName) if err != nil { logrus.WithError(err).Panic("failed to connect to federation sender db") } diff --git a/federationapi/federationapi_keys_test.go b/federationapi/federationapi_keys_test.go index 85cc43aa5..7ccc02f76 100644 --- a/federationapi/federationapi_keys_test.go +++ b/federationapi/federationapi_keys_test.go @@ -87,6 +87,7 @@ func TestMain(m *testing.M) { cfg.Global.JetStream.StoragePath = config.Path(d) cfg.Global.KeyID = serverKeyID cfg.Global.KeyValidityPeriod = s.validity + cfg.FederationAPI.KeyPerspectives = nil f, err := os.CreateTemp(d, "federation_keys_test*.db") if err != nil { return -1 @@ -207,7 +208,6 @@ func TestRenewalBehaviour(t *testing.T) { // happy at this point that the key that we already have is from the past // then repeating a key fetch should cause us to try and renew the key. // If so, then the new key will end up in our cache. - serverC.renew() res, err = serverA.api.FetchKeys( diff --git a/federationapi/federationapi_test.go b/federationapi/federationapi_test.go index e923143a7..c37bc87c2 100644 --- a/federationapi/federationapi_test.go +++ b/federationapi/federationapi_test.go @@ -164,6 +164,7 @@ func TestFederationAPIJoinThenKeyUpdate(t *testing.T) { func testFederationAPIJoinThenKeyUpdate(t *testing.T, dbType test.DBType) { base, close := testrig.CreateBaseDendrite(t, dbType) base.Cfg.FederationAPI.PreferDirectFetch = true + base.Cfg.FederationAPI.KeyPerspectives = nil defer close() jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream) diff --git a/federationapi/internal/keys.go b/federationapi/internal/keys.go index 2b7a8219a..258bd88bf 100644 --- a/federationapi/internal/keys.go +++ b/federationapi/internal/keys.go @@ -99,7 +99,7 @@ func (s *FederationInternalAPI) handleLocalKeys( results map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, ) { for req := range requests { - if req.ServerName != s.cfg.Matrix.ServerName { + if !s.cfg.Matrix.IsLocalServerName(req.ServerName) { continue } if req.KeyID == s.cfg.Matrix.KeyID { diff --git a/federationapi/internal/perform.go b/federationapi/internal/perform.go index 28ec48d7b..1b61ec711 100644 --- a/federationapi/internal/perform.go +++ b/federationapi/internal/perform.go @@ -77,7 +77,7 @@ func (r *FederationInternalAPI) PerformJoin( seenSet := make(map[gomatrixserverlib.ServerName]bool) var uniqueList []gomatrixserverlib.ServerName for _, srv := range request.ServerNames { - if seenSet[srv] || srv == r.cfg.Matrix.ServerName { + if seenSet[srv] || r.cfg.Matrix.IsLocalServerName(srv) { continue } seenSet[srv] = true diff --git a/federationapi/producers/syncapi.go b/federationapi/producers/syncapi.go index 659ff1bcf..7cce13a7d 100644 --- a/federationapi/producers/syncapi.go +++ b/federationapi/producers/syncapi.go @@ -25,6 +25,7 @@ import ( "github.com/nats-io/nats.go" log "github.com/sirupsen/logrus" + "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" @@ -39,7 +40,7 @@ type SyncAPIProducer struct { TopicDeviceListUpdate string TopicSigningKeyUpdate string JetStream nats.JetStreamContext - ServerName gomatrixserverlib.ServerName + Config *config.FederationAPI UserAPI userapi.UserInternalAPI } @@ -77,7 +78,7 @@ func (p *SyncAPIProducer) SendToDevice( // device. If the event isn't targeted locally then we can't expand the // wildcard as we don't know about the remote devices, so instead we leave it // as-is, so that the federation sender can send it on with the wildcard intact. - if domain == p.ServerName && deviceID == "*" { + if p.Config.Matrix.IsLocalServerName(domain) && deviceID == "*" { var res userapi.QueryDevicesResponse err = p.UserAPI.QueryDevices(context.TODO(), &userapi.QueryDevicesRequest{ UserID: userID, diff --git a/federationapi/queue/queue_test.go b/federationapi/queue/queue_test.go index a1b280103..7ef4646f7 100644 --- a/federationapi/queue/queue_test.go +++ b/federationapi/queue/queue_test.go @@ -47,7 +47,7 @@ func mustCreateFederationDatabase(t *testing.T, dbType test.DBType, realDatabase connStr, dbClose := test.PrepareDBConnectionString(t, dbType) db, err := storage.NewDatabase(b, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), - }, b.Caches, b.Cfg.Global.ServerName) + }, b.Caches, b.Cfg.Global.IsLocalServerName) if err != nil { t.Fatalf("NewDatabase returned %s", err) } diff --git a/federationapi/routing/routing.go b/federationapi/routing/routing.go index e25f9866e..9f16e5093 100644 --- a/federationapi/routing/routing.go +++ b/federationapi/routing/routing.go @@ -124,7 +124,7 @@ func Setup( mu := internal.NewMutexByRoom() v1fedmux.Handle("/send/{txnID}", MakeFedAPI( - "federation_send", cfg.Matrix.ServerName, keys, wakeup, + "federation_send", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return Send( httpReq, request, gomatrixserverlib.TransactionID(vars["txnID"]), @@ -134,7 +134,7 @@ func Setup( )).Methods(http.MethodPut, http.MethodOptions) v1fedmux.Handle("/invite/{roomID}/{eventID}", MakeFedAPI( - "federation_invite", cfg.Matrix.ServerName, keys, wakeup, + "federation_invite", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ @@ -150,7 +150,7 @@ func Setup( )).Methods(http.MethodPut, http.MethodOptions) v2fedmux.Handle("/invite/{roomID}/{eventID}", MakeFedAPI( - "federation_invite", cfg.Matrix.ServerName, keys, wakeup, + "federation_invite", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ @@ -172,7 +172,7 @@ func Setup( )).Methods(http.MethodPost, http.MethodOptions) v1fedmux.Handle("/exchange_third_party_invite/{roomID}", MakeFedAPI( - "exchange_third_party_invite", cfg.Matrix.ServerName, keys, wakeup, + "exchange_third_party_invite", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return ExchangeThirdPartyInvite( httpReq, request, vars["roomID"], rsAPI, cfg, federation, @@ -181,7 +181,7 @@ func Setup( )).Methods(http.MethodPut, http.MethodOptions) v1fedmux.Handle("/event/{eventID}", MakeFedAPI( - "federation_get_event", cfg.Matrix.ServerName, keys, wakeup, + "federation_get_event", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return GetEvent( httpReq.Context(), request, rsAPI, vars["eventID"], cfg.Matrix.ServerName, @@ -190,7 +190,7 @@ func Setup( )).Methods(http.MethodGet) v1fedmux.Handle("/state/{roomID}", MakeFedAPI( - "federation_get_state", cfg.Matrix.ServerName, keys, wakeup, + "federation_get_state", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ @@ -205,7 +205,7 @@ func Setup( )).Methods(http.MethodGet) v1fedmux.Handle("/state_ids/{roomID}", MakeFedAPI( - "federation_get_state_ids", cfg.Matrix.ServerName, keys, wakeup, + "federation_get_state_ids", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ @@ -220,7 +220,7 @@ func Setup( )).Methods(http.MethodGet) v1fedmux.Handle("/event_auth/{roomID}/{eventID}", MakeFedAPI( - "federation_get_event_auth", cfg.Matrix.ServerName, keys, wakeup, + "federation_get_event_auth", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ @@ -235,7 +235,7 @@ func Setup( )).Methods(http.MethodGet) v1fedmux.Handle("/query/directory", MakeFedAPI( - "federation_query_room_alias", cfg.Matrix.ServerName, keys, wakeup, + "federation_query_room_alias", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return RoomAliasToID( httpReq, federation, cfg, rsAPI, fsAPI, @@ -244,7 +244,7 @@ func Setup( )).Methods(http.MethodGet) v1fedmux.Handle("/query/profile", MakeFedAPI( - "federation_query_profile", cfg.Matrix.ServerName, keys, wakeup, + "federation_query_profile", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return GetProfile( httpReq, userAPI, cfg, @@ -253,7 +253,7 @@ func Setup( )).Methods(http.MethodGet) v1fedmux.Handle("/user/devices/{userID}", MakeFedAPI( - "federation_user_devices", cfg.Matrix.ServerName, keys, wakeup, + "federation_user_devices", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return GetUserDevices( httpReq, keyAPI, vars["userID"], @@ -263,7 +263,7 @@ func Setup( if mscCfg.Enabled("msc2444") { v1fedmux.Handle("/peek/{roomID}/{peekID}", MakeFedAPI( - "federation_peek", cfg.Matrix.ServerName, keys, wakeup, + "federation_peek", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ @@ -294,7 +294,7 @@ func Setup( } v1fedmux.Handle("/make_join/{roomID}/{userID}", MakeFedAPI( - "federation_make_join", cfg.Matrix.ServerName, keys, wakeup, + "federation_make_join", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ @@ -325,7 +325,7 @@ func Setup( )).Methods(http.MethodGet) v1fedmux.Handle("/send_join/{roomID}/{eventID}", MakeFedAPI( - "federation_send_join", cfg.Matrix.ServerName, keys, wakeup, + "federation_send_join", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ @@ -357,7 +357,7 @@ func Setup( )).Methods(http.MethodPut) v2fedmux.Handle("/send_join/{roomID}/{eventID}", MakeFedAPI( - "federation_send_join", cfg.Matrix.ServerName, keys, wakeup, + "federation_send_join", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ @@ -374,7 +374,7 @@ func Setup( )).Methods(http.MethodPut) v1fedmux.Handle("/make_leave/{roomID}/{eventID}", MakeFedAPI( - "federation_make_leave", cfg.Matrix.ServerName, keys, wakeup, + "federation_make_leave", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ @@ -391,7 +391,7 @@ func Setup( )).Methods(http.MethodGet) v1fedmux.Handle("/send_leave/{roomID}/{eventID}", MakeFedAPI( - "federation_send_leave", cfg.Matrix.ServerName, keys, wakeup, + "federation_send_leave", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ @@ -423,7 +423,7 @@ func Setup( )).Methods(http.MethodPut) v2fedmux.Handle("/send_leave/{roomID}/{eventID}", MakeFedAPI( - "federation_send_leave", cfg.Matrix.ServerName, keys, wakeup, + "federation_send_leave", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ @@ -447,7 +447,7 @@ func Setup( )).Methods(http.MethodGet) v1fedmux.Handle("/get_missing_events/{roomID}", MakeFedAPI( - "federation_get_missing_events", cfg.Matrix.ServerName, keys, wakeup, + "federation_get_missing_events", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ @@ -460,7 +460,7 @@ func Setup( )).Methods(http.MethodPost) v1fedmux.Handle("/backfill/{roomID}", MakeFedAPI( - "federation_backfill", cfg.Matrix.ServerName, keys, wakeup, + "federation_backfill", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ @@ -479,14 +479,14 @@ func Setup( ).Methods(http.MethodGet, http.MethodPost) v1fedmux.Handle("/user/keys/claim", MakeFedAPI( - "federation_keys_claim", cfg.Matrix.ServerName, keys, wakeup, + "federation_keys_claim", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return ClaimOneTimeKeys(httpReq, request, keyAPI, cfg.Matrix.ServerName) }, )).Methods(http.MethodPost) v1fedmux.Handle("/user/keys/query", MakeFedAPI( - "federation_keys_query", cfg.Matrix.ServerName, keys, wakeup, + "federation_keys_query", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return QueryDeviceKeys(httpReq, request, keyAPI, cfg.Matrix.ServerName) }, @@ -525,15 +525,15 @@ func ErrorIfLocalServerNotInRoom( // MakeFedAPI makes an http.Handler that checks matrix federation authentication. func MakeFedAPI( - metricsName string, - serverName gomatrixserverlib.ServerName, + metricsName string, serverName gomatrixserverlib.ServerName, + isLocalServerName func(gomatrixserverlib.ServerName) bool, keyRing gomatrixserverlib.JSONVerifier, wakeup *FederationWakeups, f func(*http.Request, *gomatrixserverlib.FederationRequest, map[string]string) util.JSONResponse, ) http.Handler { h := func(req *http.Request) util.JSONResponse { fedReq, errResp := gomatrixserverlib.VerifyHTTPRequest( - req, time.Now(), serverName, keyRing, + req, time.Now(), serverName, isLocalServerName, keyRing, ) if fedReq == nil { return errResp diff --git a/federationapi/storage/postgres/storage.go b/federationapi/storage/postgres/storage.go index 6e208d096..a33fa4a43 100644 --- a/federationapi/storage/postgres/storage.go +++ b/federationapi/storage/postgres/storage.go @@ -36,7 +36,7 @@ type Database struct { } // NewDatabase opens a new database -func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.FederationCache, serverName gomatrixserverlib.ServerName) (*Database, error) { +func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(gomatrixserverlib.ServerName) bool) (*Database, error) { var d Database var err error if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()); err != nil { @@ -96,7 +96,7 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, } d.Database = shared.Database{ DB: d.db, - ServerName: serverName, + IsLocalServerName: isLocalServerName, Cache: cache, Writer: d.writer, FederationJoinedHosts: joinedHosts, diff --git a/federationapi/storage/shared/storage.go b/federationapi/storage/shared/storage.go index 6afb313a8..4fabff7d4 100644 --- a/federationapi/storage/shared/storage.go +++ b/federationapi/storage/shared/storage.go @@ -29,7 +29,7 @@ import ( type Database struct { DB *sql.DB - ServerName gomatrixserverlib.ServerName + IsLocalServerName func(gomatrixserverlib.ServerName) bool Cache caching.FederationCache Writer sqlutil.Writer FederationQueuePDUs tables.FederationQueuePDUs @@ -124,7 +124,7 @@ func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, } if excludeSelf { for i, server := range servers { - if server == d.ServerName { + if d.IsLocalServerName(server) { servers = append(servers[:i], servers[i+1:]...) } } diff --git a/federationapi/storage/sqlite3/storage.go b/federationapi/storage/sqlite3/storage.go index c89cb6bea..e86ac817b 100644 --- a/federationapi/storage/sqlite3/storage.go +++ b/federationapi/storage/sqlite3/storage.go @@ -35,7 +35,7 @@ type Database struct { } // NewDatabase opens a new database -func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.FederationCache, serverName gomatrixserverlib.ServerName) (*Database, error) { +func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(gomatrixserverlib.ServerName) bool) (*Database, error) { var d Database var err error if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()); err != nil { @@ -95,7 +95,7 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, } d.Database = shared.Database{ DB: d.db, - ServerName: serverName, + IsLocalServerName: isLocalServerName, Cache: cache, Writer: d.writer, FederationJoinedHosts: joinedHosts, diff --git a/federationapi/storage/storage.go b/federationapi/storage/storage.go index f246b9bc9..142e281ea 100644 --- a/federationapi/storage/storage.go +++ b/federationapi/storage/storage.go @@ -29,12 +29,12 @@ import ( ) // NewDatabase opens a new database -func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.FederationCache, serverName gomatrixserverlib.ServerName) (Database, error) { +func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(gomatrixserverlib.ServerName) bool) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(base, dbProperties, cache, serverName) + return sqlite3.NewDatabase(base, dbProperties, cache, isLocalServerName) case dbProperties.ConnectionString.IsPostgres(): - return postgres.NewDatabase(base, dbProperties, cache, serverName) + return postgres.NewDatabase(base, dbProperties, cache, isLocalServerName) default: return nil, fmt.Errorf("unexpected database type") } diff --git a/federationapi/storage/storage_test.go b/federationapi/storage/storage_test.go index 6272fd2b1..f7408fa9f 100644 --- a/federationapi/storage/storage_test.go +++ b/federationapi/storage/storage_test.go @@ -19,7 +19,7 @@ func mustCreateFederationDatabase(t *testing.T, dbType test.DBType) (storage.Dat connStr, dbClose := test.PrepareDBConnectionString(t, dbType) db, err := storage.NewDatabase(b, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), - }, b.Caches, b.Cfg.Global.ServerName) + }, b.Caches, func(server gomatrixserverlib.ServerName) bool { return server == "localhost" }) if err != nil { t.Fatalf("NewDatabase returned %s", err) } diff --git a/go.mod b/go.mod index 7f9bb3897..39dfb0fe1 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20221021091412-7c772f1b388a + github.com/matrix-org/gomatrixserverlib v0.0.0-20221025142407-17b0be811afa github.com/matrix-org/pinecone v0.0.0-20221007145426-3adc85477dd3 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.15 diff --git a/go.sum b/go.sum index 5cce7e0d8..5e6253860 100644 --- a/go.sum +++ b/go.sum @@ -387,8 +387,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20221021091412-7c772f1b388a h1:6rJFN5NBuzZ7h5meYkLtXKa6VFZfDc8oVXHd4SDXr5o= -github.com/matrix-org/gomatrixserverlib v0.0.0-20221021091412-7c772f1b388a/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= +github.com/matrix-org/gomatrixserverlib v0.0.0-20221025142407-17b0be811afa h1:S98DShDv3sn7O4n4HjtJOejypseYVpv1R/XPg+cDnfI= +github.com/matrix-org/gomatrixserverlib v0.0.0-20221025142407-17b0be811afa/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= github.com/matrix-org/pinecone v0.0.0-20221007145426-3adc85477dd3 h1:lzkSQvBv8TuqKJCPoVwOVvEnARTlua5rrNy/Qw2Vxeo= github.com/matrix-org/pinecone v0.0.0-20221007145426-3adc85477dd3/go.mod h1:K0N1ixHQxXoCyqolDqVxPM3ArrDtcMs8yegOx2Lfv9k= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= diff --git a/roomserver/internal/perform/perform_admin.go b/roomserver/internal/perform/perform_admin.go index cb6b22d32..6a6d51b0a 100644 --- a/roomserver/internal/perform/perform_admin.go +++ b/roomserver/internal/perform/perform_admin.go @@ -117,6 +117,11 @@ func (r *Admin) PerformAdminEvacuateRoom( PrevEvents: prevEvents, } + _, senderDomain, err := gomatrixserverlib.SplitID('@', fledglingEvent.Sender) + if err != nil { + continue + } + if fledglingEvent.Content, err = json.Marshal(memberContent); err != nil { res.Error = &api.PerformError{ Code: api.PerformErrorBadRequest, @@ -146,8 +151,8 @@ func (r *Admin) PerformAdminEvacuateRoom( inputEvents = append(inputEvents, api.InputRoomEvent{ Kind: api.KindNew, Event: event, - Origin: r.Cfg.Matrix.ServerName, - SendAsServer: string(r.Cfg.Matrix.ServerName), + Origin: senderDomain, + SendAsServer: string(senderDomain), }) res.Affected = append(res.Affected, stateKey) prevEvents = []gomatrixserverlib.EventReference{ @@ -176,7 +181,7 @@ func (r *Admin) PerformAdminEvacuateUser( } return nil } - if domain != r.Cfg.Matrix.ServerName { + if !r.Cfg.Matrix.IsLocalServerName(domain) { res.Error = &api.PerformError{ Code: api.PerformErrorBadRequest, Msg: "Can only evacuate local users using this endpoint", diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go index 3fbdf332e..f60247cd7 100644 --- a/roomserver/internal/perform/perform_invite.go +++ b/roomserver/internal/perform/perform_invite.go @@ -70,8 +70,8 @@ func (r *Inviter) PerformInvite( } return nil, nil } - isTargetLocal := domain == r.Cfg.Matrix.ServerName - isOriginLocal := senderDomain == r.Cfg.Matrix.ServerName + isTargetLocal := r.Cfg.Matrix.IsLocalServerName(domain) + isOriginLocal := r.Cfg.Matrix.IsLocalServerName(senderDomain) if !isOriginLocal && !isTargetLocal { res.Error = &api.PerformError{ Code: api.PerformErrorBadRequest, diff --git a/roomserver/internal/perform/perform_join.go b/roomserver/internal/perform/perform_join.go index 262273ff5..9d596ab30 100644 --- a/roomserver/internal/perform/perform_join.go +++ b/roomserver/internal/perform/perform_join.go @@ -92,7 +92,7 @@ func (r *Joiner) performJoin( Msg: fmt.Sprintf("Supplied user ID %q in incorrect format", req.UserID), } } - if domain != r.Cfg.Matrix.ServerName { + if !r.Cfg.Matrix.IsLocalServerName(domain) { return "", "", &rsAPI.PerformError{ Code: rsAPI.PerformErrorBadRequest, Msg: fmt.Sprintf("User %q does not belong to this homeserver", req.UserID), @@ -124,7 +124,7 @@ func (r *Joiner) performJoinRoomByAlias( // Check if this alias matches our own server configuration. If it // doesn't then we'll need to try a federated join. var roomID string - if domain != r.Cfg.Matrix.ServerName { + if !r.Cfg.Matrix.IsLocalServerName(domain) { // The alias isn't owned by us, so we will need to try joining using // a remote server. dirReq := fsAPI.PerformDirectoryLookupRequest{ @@ -172,7 +172,7 @@ func (r *Joiner) performJoinRoomByID( // The original client request ?server_name=... may include this HS so filter that out so we // don't attempt to make_join with ourselves for i := 0; i < len(req.ServerNames); i++ { - if req.ServerNames[i] == r.Cfg.Matrix.ServerName { + if r.Cfg.Matrix.IsLocalServerName(req.ServerNames[i]) { // delete this entry req.ServerNames = append(req.ServerNames[:i], req.ServerNames[i+1:]...) i-- @@ -191,12 +191,19 @@ func (r *Joiner) performJoinRoomByID( // If the server name in the room ID isn't ours then it's a // possible candidate for finding the room via federation. Add // it to the list of servers to try. - if domain != r.Cfg.Matrix.ServerName { + if !r.Cfg.Matrix.IsLocalServerName(domain) { req.ServerNames = append(req.ServerNames, domain) } // Prepare the template for the join event. userID := req.UserID + _, userDomain, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + return "", "", &rsAPI.PerformError{ + Code: rsAPI.PerformErrorBadRequest, + Msg: fmt.Sprintf("User ID %q is invalid: %s", userID, err), + } + } eb := gomatrixserverlib.EventBuilder{ Type: gomatrixserverlib.MRoomMember, Sender: userID, @@ -247,7 +254,7 @@ func (r *Joiner) performJoinRoomByID( // If we were invited by someone from another server then we can // assume they are in the room so we can join via them. - if inviterDomain != r.Cfg.Matrix.ServerName { + if !r.Cfg.Matrix.IsLocalServerName(inviterDomain) { req.ServerNames = append(req.ServerNames, inviterDomain) forceFederatedJoin = true memberEvent := gjson.Parse(string(inviteEvent.JSON())) @@ -300,7 +307,7 @@ func (r *Joiner) performJoinRoomByID( { Kind: rsAPI.KindNew, Event: event.Headered(buildRes.RoomVersion), - SendAsServer: string(r.Cfg.Matrix.ServerName), + SendAsServer: string(userDomain), }, }, } @@ -323,7 +330,7 @@ func (r *Joiner) performJoinRoomByID( // The room doesn't exist locally. If the room ID looks like it should // be ours then this probably means that we've nuked our database at // some point. - if domain == r.Cfg.Matrix.ServerName { + if r.Cfg.Matrix.IsLocalServerName(domain) { // If there are no more server names to try then give up here. // Otherwise we'll try a federated join as normal, since it's quite // possible that the room still exists on other servers. @@ -348,7 +355,7 @@ func (r *Joiner) performJoinRoomByID( // it will have been overwritten with a room ID by performJoinRoomByAlias. // We should now include this in the response so that the CS API can // return the right room ID. - return req.RoomIDOrAlias, r.Cfg.Matrix.ServerName, nil + return req.RoomIDOrAlias, userDomain, nil } func (r *Joiner) performFederatedJoinRoomByID( diff --git a/roomserver/internal/perform/perform_leave.go b/roomserver/internal/perform/perform_leave.go index 85b659814..49e4b479a 100644 --- a/roomserver/internal/perform/perform_leave.go +++ b/roomserver/internal/perform/perform_leave.go @@ -52,7 +52,7 @@ func (r *Leaver) PerformLeave( if err != nil { return nil, fmt.Errorf("supplied user ID %q in incorrect format", req.UserID) } - if domain != r.Cfg.Matrix.ServerName { + if !r.Cfg.Matrix.IsLocalServerName(domain) { return nil, fmt.Errorf("user %q does not belong to this homeserver", req.UserID) } logger := logrus.WithContext(ctx).WithFields(logrus.Fields{ @@ -85,7 +85,7 @@ func (r *Leaver) performLeaveRoomByID( if serr != nil { return nil, fmt.Errorf("sender %q is invalid", senderUser) } - if senderDomain != r.Cfg.Matrix.ServerName { + if !r.Cfg.Matrix.IsLocalServerName(senderDomain) { return r.performFederatedRejectInvite(ctx, req, res, senderUser, eventID) } // check that this is not a "server notice room" @@ -186,7 +186,7 @@ func (r *Leaver) performLeaveRoomByID( Kind: api.KindNew, Event: event.Headered(buildRes.RoomVersion), Origin: senderDomain, - SendAsServer: string(r.Cfg.Matrix.ServerName), + SendAsServer: string(senderDomain), }, }, } diff --git a/roomserver/internal/perform/perform_peek.go b/roomserver/internal/perform/perform_peek.go index 74d87a5b4..436d137ff 100644 --- a/roomserver/internal/perform/perform_peek.go +++ b/roomserver/internal/perform/perform_peek.go @@ -72,7 +72,7 @@ func (r *Peeker) performPeek( Msg: fmt.Sprintf("Supplied user ID %q in incorrect format", req.UserID), } } - if domain != r.Cfg.Matrix.ServerName { + if !r.Cfg.Matrix.IsLocalServerName(domain) { return "", &api.PerformError{ Code: api.PerformErrorBadRequest, Msg: fmt.Sprintf("User %q does not belong to this homeserver", req.UserID), @@ -104,7 +104,7 @@ func (r *Peeker) performPeekRoomByAlias( // Check if this alias matches our own server configuration. If it // doesn't then we'll need to try a federated peek. var roomID string - if domain != r.Cfg.Matrix.ServerName { + if !r.Cfg.Matrix.IsLocalServerName(domain) { // The alias isn't owned by us, so we will need to try peeking using // a remote server. dirReq := fsAPI.PerformDirectoryLookupRequest{ @@ -154,7 +154,7 @@ func (r *Peeker) performPeekRoomByID( // handle federated peeks // FIXME: don't create an outbound peek if we already have one going. - if domain != r.Cfg.Matrix.ServerName { + if !r.Cfg.Matrix.IsLocalServerName(domain) { // If the server name in the room ID isn't ours then it's a // possible candidate for finding the room via federation. Add // it to the list of servers to try. diff --git a/roomserver/internal/perform/perform_unpeek.go b/roomserver/internal/perform/perform_unpeek.go index 49e9067c9..0d97da4d6 100644 --- a/roomserver/internal/perform/perform_unpeek.go +++ b/roomserver/internal/perform/perform_unpeek.go @@ -67,7 +67,7 @@ func (r *Unpeeker) performUnpeek( Msg: fmt.Sprintf("Supplied user ID %q in incorrect format", req.UserID), } } - if domain != r.Cfg.Matrix.ServerName { + if !r.Cfg.Matrix.IsLocalServerName(domain) { return &api.PerformError{ Code: api.PerformErrorBadRequest, Msg: fmt.Sprintf("User %q does not belong to this homeserver", req.UserID), diff --git a/roomserver/internal/perform/perform_upgrade.go b/roomserver/internal/perform/perform_upgrade.go index d6dc9708c..38abe323c 100644 --- a/roomserver/internal/perform/perform_upgrade.go +++ b/roomserver/internal/perform/perform_upgrade.go @@ -60,6 +60,13 @@ func (r *Upgrader) performRoomUpgrade( ) (string, *api.PerformError) { roomID := req.RoomID userID := req.UserID + _, userDomain, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + return "", &api.PerformError{ + Code: api.PerformErrorNotAllowed, + Msg: "Error validating the user ID", + } + } evTime := time.Now() // Return an immediate error if the room does not exist @@ -80,7 +87,7 @@ func (r *Upgrader) performRoomUpgrade( // TODO (#267): Check room ID doesn't clash with an existing one, and we // probably shouldn't be using pseudo-random strings, maybe GUIDs? - newRoomID := fmt.Sprintf("!%s:%s", util.RandomString(16), r.Cfg.Matrix.ServerName) + newRoomID := fmt.Sprintf("!%s:%s", util.RandomString(16), userDomain) // Get the existing room state for the old room. oldRoomReq := &api.QueryLatestEventsAndStateRequest{ @@ -107,12 +114,12 @@ func (r *Upgrader) performRoomUpgrade( } // Send the setup events to the new room - if pErr = r.sendInitialEvents(ctx, evTime, userID, newRoomID, string(req.RoomVersion), eventsToMake); pErr != nil { + if pErr = r.sendInitialEvents(ctx, evTime, userID, userDomain, newRoomID, string(req.RoomVersion), eventsToMake); pErr != nil { return "", pErr } // 5. Send the tombstone event to the old room - if pErr = r.sendHeaderedEvent(ctx, tombstoneEvent, string(r.Cfg.Matrix.ServerName)); pErr != nil { + if pErr = r.sendHeaderedEvent(ctx, userDomain, tombstoneEvent, string(userDomain)); pErr != nil { return "", pErr } @@ -122,7 +129,7 @@ func (r *Upgrader) performRoomUpgrade( } // If the old room had a canonical alias event, it should be deleted in the old room - if pErr = r.clearOldCanonicalAliasEvent(ctx, oldRoomRes, evTime, userID, roomID); pErr != nil { + if pErr = r.clearOldCanonicalAliasEvent(ctx, oldRoomRes, evTime, userID, userDomain, roomID); pErr != nil { return "", pErr } @@ -132,7 +139,7 @@ func (r *Upgrader) performRoomUpgrade( } // 6. Restrict power levels in the old room - if pErr = r.restrictOldRoomPowerLevels(ctx, evTime, userID, roomID); pErr != nil { + if pErr = r.restrictOldRoomPowerLevels(ctx, evTime, userID, userDomain, roomID); pErr != nil { return "", pErr } @@ -154,7 +161,7 @@ func (r *Upgrader) getRoomPowerLevels(ctx context.Context, roomID string) (*goma return powerLevelContent, nil } -func (r *Upgrader) restrictOldRoomPowerLevels(ctx context.Context, evTime time.Time, userID, roomID string) *api.PerformError { +func (r *Upgrader) restrictOldRoomPowerLevels(ctx context.Context, evTime time.Time, userID string, userDomain gomatrixserverlib.ServerName, roomID string) *api.PerformError { restrictedPowerLevelContent, pErr := r.getRoomPowerLevels(ctx, roomID) if pErr != nil { return pErr @@ -183,7 +190,7 @@ func (r *Upgrader) restrictOldRoomPowerLevels(ctx context.Context, evTime time.T return resErr } } else { - if resErr = r.sendHeaderedEvent(ctx, restrictedPowerLevelsHeadered, api.DoNotSendToOtherServers); resErr != nil { + if resErr = r.sendHeaderedEvent(ctx, userDomain, restrictedPowerLevelsHeadered, api.DoNotSendToOtherServers); resErr != nil { return resErr } } @@ -223,7 +230,7 @@ func moveLocalAliases(ctx context.Context, return nil } -func (r *Upgrader) clearOldCanonicalAliasEvent(ctx context.Context, oldRoom *api.QueryLatestEventsAndStateResponse, evTime time.Time, userID, roomID string) *api.PerformError { +func (r *Upgrader) clearOldCanonicalAliasEvent(ctx context.Context, oldRoom *api.QueryLatestEventsAndStateResponse, evTime time.Time, userID string, userDomain gomatrixserverlib.ServerName, roomID string) *api.PerformError { for _, event := range oldRoom.StateEvents { if event.Type() != gomatrixserverlib.MRoomCanonicalAlias || !event.StateKeyEquals("") { continue @@ -254,7 +261,7 @@ func (r *Upgrader) clearOldCanonicalAliasEvent(ctx context.Context, oldRoom *api return resErr } } else { - if resErr = r.sendHeaderedEvent(ctx, emptyCanonicalAliasEvent, api.DoNotSendToOtherServers); resErr != nil { + if resErr = r.sendHeaderedEvent(ctx, userDomain, emptyCanonicalAliasEvent, api.DoNotSendToOtherServers); resErr != nil { return resErr } } @@ -495,7 +502,7 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query return eventsToMake, nil } -func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, userID, newRoomID, newVersion string, eventsToMake []fledglingEvent) *api.PerformError { +func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, userID string, userDomain gomatrixserverlib.ServerName, newRoomID, newVersion string, eventsToMake []fledglingEvent) *api.PerformError { var err error var builtEvents []*gomatrixserverlib.HeaderedEvent authEvents := gomatrixserverlib.NewAuthEvents(nil) @@ -519,7 +526,7 @@ func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, user builder.PrevEvents = []gomatrixserverlib.EventReference{builtEvents[i-1].EventReference()} } var event *gomatrixserverlib.Event - event, err = r.buildEvent(&builder, &authEvents, evTime, gomatrixserverlib.RoomVersion(newVersion)) + event, err = r.buildEvent(&builder, userDomain, &authEvents, evTime, gomatrixserverlib.RoomVersion(newVersion)) if err != nil { return &api.PerformError{ Msg: fmt.Sprintf("Failed to build new %q event: %s", builder.Type, err), @@ -547,7 +554,7 @@ func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, user inputs = append(inputs, api.InputRoomEvent{ Kind: api.KindNew, Event: event, - Origin: r.Cfg.Matrix.ServerName, + Origin: userDomain, SendAsServer: api.DoNotSendToOtherServers, }) } @@ -668,6 +675,7 @@ func createTemporaryPowerLevels(powerLevelContent *gomatrixserverlib.PowerLevelC func (r *Upgrader) sendHeaderedEvent( ctx context.Context, + serverName gomatrixserverlib.ServerName, headeredEvent *gomatrixserverlib.HeaderedEvent, sendAsServer string, ) *api.PerformError { @@ -675,7 +683,7 @@ func (r *Upgrader) sendHeaderedEvent( inputs = append(inputs, api.InputRoomEvent{ Kind: api.KindNew, Event: headeredEvent, - Origin: r.Cfg.Matrix.ServerName, + Origin: serverName, SendAsServer: sendAsServer, }) if err := api.SendInputRoomEvents(ctx, r.URSAPI, inputs, false); err != nil { @@ -689,6 +697,7 @@ func (r *Upgrader) sendHeaderedEvent( func (r *Upgrader) buildEvent( builder *gomatrixserverlib.EventBuilder, + serverName gomatrixserverlib.ServerName, provider gomatrixserverlib.AuthEventProvider, evTime time.Time, roomVersion gomatrixserverlib.RoomVersion, @@ -703,7 +712,7 @@ func (r *Upgrader) buildEvent( } builder.AuthEvents = refs event, err := builder.Build( - evTime, r.Cfg.Matrix.ServerName, r.Cfg.Matrix.KeyID, + evTime, serverName, r.Cfg.Matrix.KeyID, r.Cfg.Matrix.PrivateKey, roomVersion, ) if err != nil { diff --git a/setup/config/config_global.go b/setup/config/config_global.go index 784893d24..825772827 100644 --- a/setup/config/config_global.go +++ b/setup/config/config_global.go @@ -14,6 +14,9 @@ type Global struct { // The name of the server. This is usually the domain name, e.g 'matrix.org', 'localhost'. ServerName gomatrixserverlib.ServerName `yaml:"server_name"` + // The secondary server names, used for virtual hosting. + SecondaryServerNames []gomatrixserverlib.ServerName `yaml:"-"` + // Path to the private key which will be used to sign requests and events. PrivateKeyPath Path `yaml:"private_key"` @@ -120,6 +123,18 @@ func (c *Global) Verify(configErrs *ConfigErrors, isMonolith bool) { c.Cache.Verify(configErrs, isMonolith) } +func (c *Global) IsLocalServerName(serverName gomatrixserverlib.ServerName) bool { + if c.ServerName == serverName { + return true + } + for _, secondaryName := range c.SecondaryServerNames { + if secondaryName == serverName { + return true + } + } + return false +} + type OldVerifyKeys struct { // Path to the private key. PrivateKeyPath Path `yaml:"private_key"` diff --git a/setup/mscs/msc2836/msc2836.go b/setup/mscs/msc2836/msc2836.go index 452b14580..98502f5cb 100644 --- a/setup/mscs/msc2836/msc2836.go +++ b/setup/mscs/msc2836/msc2836.go @@ -132,7 +132,7 @@ func Enable( base.PublicFederationAPIMux.Handle("/unstable/event_relationships", httputil.MakeExternalAPI( "msc2836_event_relationships", func(req *http.Request) util.JSONResponse { fedReq, errResp := gomatrixserverlib.VerifyHTTPRequest( - req, time.Now(), base.Cfg.Global.ServerName, keyRing, + req, time.Now(), base.Cfg.Global.ServerName, base.Cfg.Global.IsLocalServerName, keyRing, ) if fedReq == nil { return errResp diff --git a/setup/mscs/msc2946/msc2946.go b/setup/mscs/msc2946/msc2946.go index a92a16a27..bc9df0f96 100644 --- a/setup/mscs/msc2946/msc2946.go +++ b/setup/mscs/msc2946/msc2946.go @@ -64,7 +64,7 @@ func Enable( fedAPI := httputil.MakeExternalAPI( "msc2946_fed_spaces", func(req *http.Request) util.JSONResponse { fedReq, errResp := gomatrixserverlib.VerifyHTTPRequest( - req, time.Now(), base.Cfg.Global.ServerName, keyRing, + req, time.Now(), base.Cfg.Global.ServerName, base.Cfg.Global.IsLocalServerName, keyRing, ) if fedReq == nil { return errResp diff --git a/test/testrig/base.go b/test/testrig/base.go index 10cc2407b..15fb5c370 100644 --- a/test/testrig/base.go +++ b/test/testrig/base.go @@ -36,6 +36,7 @@ func CreateBaseDendrite(t *testing.T, dbType test.DBType) (*base.BaseDendrite, f Monolithic: true, }) cfg.Global.JetStream.InMemory = true + cfg.FederationAPI.KeyPerspectives = nil switch dbType { case test.DBTypePostgres: cfg.Global.Defaults(config.DefaultOpts{ // autogen a signing key @@ -106,6 +107,7 @@ func Base(cfg *config.Dendrite) (*base.BaseDendrite, nats.JetStreamContext, *nat } cfg.Global.JetStream.InMemory = true cfg.SyncAPI.Fulltext.InMemory = true + cfg.FederationAPI.KeyPerspectives = nil base := base.NewBaseDendrite(cfg, "Tests") js, jc := base.NATS.Prepare(base.ProcessContext, &cfg.Global.JetStream) return base, js, jc diff --git a/userapi/api/api.go b/userapi/api/api.go index eef29144a..8d7f783de 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -318,8 +318,9 @@ type QuerySearchProfilesResponse struct { // PerformAccountCreationRequest is the request for PerformAccountCreation type PerformAccountCreationRequest struct { - AccountType AccountType // Required: whether this is a guest or user account - Localpart string // Required: The localpart for this account. Ignored if account type is guest. + AccountType AccountType // Required: whether this is a guest or user account + Localpart string // Required: The localpart for this account. Ignored if account type is guest. + ServerName gomatrixserverlib.ServerName // optional: if not specified, default server name used instead AppServiceID string // optional: the application service ID (not user ID) creating this account, if any. Password string // optional: if missing then this account will be a passwordless account @@ -360,7 +361,8 @@ type PerformLastSeenUpdateResponse struct { // PerformDeviceCreationRequest is the request for PerformDeviceCreation type PerformDeviceCreationRequest struct { Localpart string - AccessToken string // optional: if blank one will be made on your behalf + ServerName gomatrixserverlib.ServerName // optional: if blank, default server name used + AccessToken string // optional: if blank one will be made on your behalf // optional: if nil an ID is generated for you. If set, replaces any existing device session, // which will generate a new access token and invalidate the old one. DeviceID *string @@ -384,7 +386,8 @@ type PerformDeviceCreationResponse struct { // PerformAccountDeactivationRequest is the request for PerformAccountDeactivation type PerformAccountDeactivationRequest struct { - Localpart string + Localpart string + ServerName gomatrixserverlib.ServerName // optional: if blank, default server name used } // PerformAccountDeactivationResponse is the response for PerformAccountDeactivation @@ -434,6 +437,18 @@ type Device struct { AccountType AccountType } +func (d *Device) UserDomain() gomatrixserverlib.ServerName { + _, domain, err := gomatrixserverlib.SplitID('@', d.UserID) + if err != nil { + // This really is catastrophic because it means that someone + // managed to forge a malformed user ID for a device during + // login. + // TODO: Is there a better way to deal with this than panic? + panic(err) + } + return domain +} + // Account represents a Matrix account on this home server. type Account struct { UserID string @@ -577,7 +592,9 @@ type Notification struct { } type PerformSetAvatarURLRequest struct { - Localpart, AvatarURL string + Localpart string + ServerName gomatrixserverlib.ServerName + AvatarURL string } type PerformSetAvatarURLResponse struct { Profile *authtypes.Profile `json:"profile"` @@ -606,7 +623,9 @@ type QueryAccountByPasswordResponse struct { } type PerformUpdateDisplayNameRequest struct { - Localpart, DisplayName string + Localpart string + ServerName gomatrixserverlib.ServerName + DisplayName string } type PerformUpdateDisplayNameResponse struct { diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 7b94b3da7..9ca76965d 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -46,9 +46,9 @@ import ( type UserInternalAPI struct { DB storage.Database SyncProducer *producers.SyncAPI + Config *config.UserAPI DisableTLSValidation bool - ServerName gomatrixserverlib.ServerName // AppServices is the list of all registered AS AppServices []config.ApplicationService KeyAPI keyapi.UserKeyAPI @@ -62,8 +62,8 @@ func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAc if err != nil { return err } - if domain != a.ServerName { - return fmt.Errorf("cannot query profile of remote users: got %s want %s", domain, a.ServerName) + if !a.Config.Matrix.IsLocalServerName(domain) { + return fmt.Errorf("cannot update account data of remote users (server name %s)", domain) } if req.DataType == "" { return fmt.Errorf("data type must not be empty") @@ -104,7 +104,7 @@ func (a *UserInternalAPI) setFullyRead(ctx context.Context, req *api.InputAccoun logrus.WithError(err).Error("UserInternalAPI.setFullyRead: SplitID failure") return nil } - if domain != a.ServerName { + if !a.Config.Matrix.IsLocalServerName(domain) { return nil } @@ -171,6 +171,11 @@ func addUserToRoom( } func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error { + serverName := req.ServerName + if serverName == "" { + serverName = a.Config.Matrix.ServerName + } + // XXXX: Use the server name here acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType) if err != nil { if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists @@ -188,8 +193,8 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P res.Account = &api.Account{ AppServiceID: req.AppServiceID, Localpart: req.Localpart, - ServerName: a.ServerName, - UserID: fmt.Sprintf("@%s:%s", req.Localpart, a.ServerName), + ServerName: serverName, + UserID: fmt.Sprintf("@%s:%s", req.Localpart, serverName), AccountType: req.AccountType, } return nil @@ -235,6 +240,12 @@ func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.Pe } func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.PerformDeviceCreationRequest, res *api.PerformDeviceCreationResponse) error { + serverName := req.ServerName + if serverName == "" { + serverName = a.Config.Matrix.ServerName + } + _ = serverName + // XXXX: Use the server name here util.GetLogger(ctx).WithFields(logrus.Fields{ "localpart": req.Localpart, "device_id": req.DeviceID, @@ -259,8 +270,8 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe if err != nil { return err } - if domain != a.ServerName { - return fmt.Errorf("cannot PerformDeviceDeletion of remote users: got %s want %s", domain, a.ServerName) + if !a.Config.Matrix.IsLocalServerName(domain) { + return fmt.Errorf("cannot PerformDeviceDeletion of remote users (server name %s)", domain) } deletedDeviceIDs := req.DeviceIDs if len(req.DeviceIDs) == 0 { @@ -392,8 +403,8 @@ func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfil if err != nil { return err } - if domain != a.ServerName { - return fmt.Errorf("cannot query profile of remote users: got %s want %s", domain, a.ServerName) + if !a.Config.Matrix.IsLocalServerName(domain) { + return fmt.Errorf("cannot query profile of remote users (server name %s)", domain) } prof, err := a.DB.GetProfileByLocalpart(ctx, local) if err != nil { @@ -443,8 +454,8 @@ func (a *UserInternalAPI) QueryDevices(ctx context.Context, req *api.QueryDevice if err != nil { return err } - if domain != a.ServerName { - return fmt.Errorf("cannot query devices of remote users: got %s want %s", domain, a.ServerName) + if !a.Config.Matrix.IsLocalServerName(domain) { + return fmt.Errorf("cannot query devices of remote users (server name %s)", domain) } devs, err := a.DB.GetDevicesByLocalpart(ctx, local) if err != nil { @@ -460,8 +471,8 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc if err != nil { return err } - if domain != a.ServerName { - return fmt.Errorf("cannot query account data of remote users: got %s want %s", domain, a.ServerName) + if !a.Config.Matrix.IsLocalServerName(domain) { + return fmt.Errorf("cannot query account data of remote users (server name %s)", domain) } if req.DataType != "" { var data json.RawMessage @@ -509,10 +520,13 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc } return err } - localPart, _, err := gomatrixserverlib.SplitID('@', device.UserID) + localPart, domain, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { return err } + if !a.Config.Matrix.IsLocalServerName(domain) { + return nil + } acc, err := a.DB.GetAccountByLocalpart(ctx, localPart) if err != nil { return err @@ -547,7 +561,7 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe AccountType: api.AccountTypeAppService, } - localpart, err := userutil.ParseUsernameParam(appServiceUserID, &a.ServerName) + localpart, _, err := userutil.ParseUsernameParam(appServiceUserID, a.Config.Matrix) if err != nil { return nil, err } @@ -572,8 +586,16 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe // PerformAccountDeactivation deactivates the user's account, removing all ability for the user to login again. func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *api.PerformAccountDeactivationRequest, res *api.PerformAccountDeactivationResponse) error { + serverName := req.ServerName + if serverName == "" { + serverName = a.Config.Matrix.ServerName + } + if !a.Config.Matrix.IsLocalServerName(serverName) { + return fmt.Errorf("server name %q not locally configured", serverName) + } + evacuateReq := &rsapi.PerformAdminEvacuateUserRequest{ - UserID: fmt.Sprintf("@%s:%s", req.Localpart, a.ServerName), + UserID: fmt.Sprintf("@%s:%s", req.Localpart, serverName), } evacuateRes := &rsapi.PerformAdminEvacuateUserResponse{} if err := a.RSAPI.PerformAdminEvacuateUser(ctx, evacuateReq, evacuateRes); err != nil { @@ -584,7 +606,7 @@ func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *a } deviceReq := &api.PerformDeviceDeletionRequest{ - UserID: fmt.Sprintf("@%s:%s", req.Localpart, a.ServerName), + UserID: fmt.Sprintf("@%s:%s", req.Localpart, serverName), } deviceRes := &api.PerformDeviceDeletionResponse{} if err := a.PerformDeviceDeletion(ctx, deviceReq, deviceRes); err != nil { diff --git a/userapi/internal/api_logintoken.go b/userapi/internal/api_logintoken.go index f1bf391e4..87f25e5e2 100644 --- a/userapi/internal/api_logintoken.go +++ b/userapi/internal/api_logintoken.go @@ -31,8 +31,8 @@ func (a *UserInternalAPI) PerformLoginTokenCreation(ctx context.Context, req *ap if err != nil { return err } - if domain != a.ServerName { - return fmt.Errorf("cannot create a login token for a remote user: got %s want %s", domain, a.ServerName) + if !a.Config.Matrix.IsLocalServerName(domain) { + return fmt.Errorf("cannot create a login token for a remote user (server name %s)", domain) } tokenMeta, err := a.DB.CreateLoginToken(ctx, &req.Data) if err != nil { @@ -63,8 +63,8 @@ func (a *UserInternalAPI) QueryLoginToken(ctx context.Context, req *api.QueryLog if err != nil { return err } - if domain != a.ServerName { - return fmt.Errorf("cannot return a login token for a remote user: got %s want %s", domain, a.ServerName) + if !a.Config.Matrix.IsLocalServerName(domain) { + return fmt.Errorf("cannot return a login token for a remote user (server name %s)", domain) } if _, err := a.DB.GetAccountByLocalpart(ctx, localpart); err != nil { res.Data = nil diff --git a/userapi/userapi.go b/userapi/userapi.go index c077248e2..e46a8e76e 100644 --- a/userapi/userapi.go +++ b/userapi/userapi.go @@ -76,7 +76,7 @@ func NewInternalAPI( userAPI := &internal.UserInternalAPI{ DB: db, SyncProducer: syncProducer, - ServerName: cfg.Matrix.ServerName, + Config: cfg, AppServices: appServices, KeyAPI: keyAPI, RSAPI: rsAPI, diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index aaa93f45b..2a43c0bd4 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -66,8 +66,8 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType) (ap } return &internal.UserInternalAPI{ - DB: accountDB, - ServerName: cfg.Matrix.ServerName, + DB: accountDB, + Config: cfg, }, accountDB, func() { close() baseclose() From 5298dd1133948606172944b7e5ec3805ccc72644 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 26 Oct 2022 14:52:33 +0100 Subject: [PATCH 10/10] Update federation API consumers --- federationapi/consumers/keychange.go | 44 ++++++++++++------------- federationapi/consumers/presence.go | 10 +++--- federationapi/consumers/receipts.go | 34 +++++++++---------- federationapi/consumers/sendtodevice.go | 36 ++++++++++---------- federationapi/consumers/typing.go | 32 +++++++++--------- 5 files changed, 78 insertions(+), 78 deletions(-) diff --git a/federationapi/consumers/keychange.go b/federationapi/consumers/keychange.go index 67dfdc1d3..7d1ae0f81 100644 --- a/federationapi/consumers/keychange.go +++ b/federationapi/consumers/keychange.go @@ -35,14 +35,14 @@ import ( // KeyChangeConsumer consumes events that originate in key server. type KeyChangeConsumer struct { - ctx context.Context - jetstream nats.JetStreamContext - durable string - db storage.Database - queues *queue.OutgoingQueues - serverName gomatrixserverlib.ServerName - rsAPI roomserverAPI.FederationRoomserverAPI - topic string + ctx context.Context + jetstream nats.JetStreamContext + durable string + db storage.Database + queues *queue.OutgoingQueues + isLocalServerName func(gomatrixserverlib.ServerName) bool + rsAPI roomserverAPI.FederationRoomserverAPI + topic string } // NewKeyChangeConsumer creates a new KeyChangeConsumer. Call Start() to begin consuming from key servers. @@ -55,14 +55,14 @@ func NewKeyChangeConsumer( rsAPI roomserverAPI.FederationRoomserverAPI, ) *KeyChangeConsumer { return &KeyChangeConsumer{ - ctx: process.Context(), - jetstream: js, - durable: cfg.Matrix.JetStream.Prefixed("FederationAPIKeyChangeConsumer"), - topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputKeyChangeEvent), - queues: queues, - db: store, - serverName: cfg.Matrix.ServerName, - rsAPI: rsAPI, + ctx: process.Context(), + jetstream: js, + durable: cfg.Matrix.JetStream.Prefixed("FederationAPIKeyChangeConsumer"), + topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputKeyChangeEvent), + queues: queues, + db: store, + isLocalServerName: cfg.Matrix.IsLocalServerName, + rsAPI: rsAPI, } } @@ -112,7 +112,7 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) bool { logger.WithError(err).Error("Failed to extract domain from key change event") return true } - if originServerName != t.serverName { + if !t.isLocalServerName(originServerName) { return true } @@ -141,7 +141,7 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) bool { // Pack the EDU and marshal it edu := &gomatrixserverlib.EDU{ Type: gomatrixserverlib.MDeviceListUpdate, - Origin: string(t.serverName), + Origin: string(originServerName), } event := gomatrixserverlib.DeviceListUpdateEvent{ UserID: m.UserID, @@ -159,7 +159,7 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) bool { } logger.Debugf("Sending device list update message to %q", destinations) - err = t.queues.SendEDU(edu, t.serverName, destinations) + err = t.queues.SendEDU(edu, originServerName, destinations) return err == nil } @@ -171,7 +171,7 @@ func (t *KeyChangeConsumer) onCrossSigningMessage(m api.DeviceMessage) bool { logrus.WithError(err).Errorf("fedsender key change consumer: user ID parse failure") return true } - if host != gomatrixserverlib.ServerName(t.serverName) { + if !t.isLocalServerName(host) { // Ignore any messages that didn't originate locally, otherwise we'll // end up parroting information we received from other servers. return true @@ -203,7 +203,7 @@ func (t *KeyChangeConsumer) onCrossSigningMessage(m api.DeviceMessage) bool { // Pack the EDU and marshal it edu := &gomatrixserverlib.EDU{ Type: types.MSigningKeyUpdate, - Origin: string(t.serverName), + Origin: string(host), } if edu.Content, err = json.Marshal(output); err != nil { sentry.CaptureException(err) @@ -212,7 +212,7 @@ func (t *KeyChangeConsumer) onCrossSigningMessage(m api.DeviceMessage) bool { } logger.Debugf("Sending cross-signing update message to %q", destinations) - err = t.queues.SendEDU(edu, t.serverName, destinations) + err = t.queues.SendEDU(edu, host, destinations) return err == nil } diff --git a/federationapi/consumers/presence.go b/federationapi/consumers/presence.go index e76103cd3..3445d34a9 100644 --- a/federationapi/consumers/presence.go +++ b/federationapi/consumers/presence.go @@ -38,7 +38,7 @@ type OutputPresenceConsumer struct { durable string db storage.Database queues *queue.OutgoingQueues - ServerName gomatrixserverlib.ServerName + isLocalServerName func(gomatrixserverlib.ServerName) bool topic string outboundPresenceEnabled bool } @@ -56,7 +56,7 @@ func NewOutputPresenceConsumer( jetstream: js, queues: queues, db: store, - ServerName: cfg.Matrix.ServerName, + isLocalServerName: cfg.Matrix.IsLocalServerName, durable: cfg.Matrix.JetStream.Durable("FederationAPIPresenceConsumer"), topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputPresenceEvent), outboundPresenceEnabled: cfg.Matrix.Presence.EnableOutbound, @@ -85,7 +85,7 @@ func (t *OutputPresenceConsumer) onMessage(ctx context.Context, msgs []*nats.Msg log.WithError(err).WithField("user_id", userID).Error("failed to extract domain from receipt sender") return true } - if serverName != t.ServerName { + if !t.isLocalServerName(serverName) { return true } @@ -127,7 +127,7 @@ func (t *OutputPresenceConsumer) onMessage(ctx context.Context, msgs []*nats.Msg edu := &gomatrixserverlib.EDU{ Type: gomatrixserverlib.MPresence, - Origin: string(t.ServerName), + Origin: string(serverName), } if edu.Content, err = json.Marshal(content); err != nil { log.WithError(err).Error("failed to marshal EDU JSON") @@ -135,7 +135,7 @@ func (t *OutputPresenceConsumer) onMessage(ctx context.Context, msgs []*nats.Msg } log.Tracef("sending presence EDU to %d servers", len(joined)) - if err = t.queues.SendEDU(edu, t.ServerName, joined); err != nil { + if err = t.queues.SendEDU(edu, serverName, joined); err != nil { log.WithError(err).Error("failed to send EDU") return false } diff --git a/federationapi/consumers/receipts.go b/federationapi/consumers/receipts.go index 75827cb68..200c06e6c 100644 --- a/federationapi/consumers/receipts.go +++ b/federationapi/consumers/receipts.go @@ -34,13 +34,13 @@ import ( // OutputReceiptConsumer consumes events that originate in the clientapi. type OutputReceiptConsumer struct { - ctx context.Context - jetstream nats.JetStreamContext - durable string - db storage.Database - queues *queue.OutgoingQueues - ServerName gomatrixserverlib.ServerName - topic string + ctx context.Context + jetstream nats.JetStreamContext + durable string + db storage.Database + queues *queue.OutgoingQueues + isLocalServerName func(gomatrixserverlib.ServerName) bool + topic string } // NewOutputReceiptConsumer creates a new OutputReceiptConsumer. Call Start() to begin consuming typing events. @@ -52,13 +52,13 @@ func NewOutputReceiptConsumer( store storage.Database, ) *OutputReceiptConsumer { return &OutputReceiptConsumer{ - ctx: process.Context(), - jetstream: js, - queues: queues, - db: store, - ServerName: cfg.Matrix.ServerName, - durable: cfg.Matrix.JetStream.Durable("FederationAPIReceiptConsumer"), - topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputReceiptEvent), + ctx: process.Context(), + jetstream: js, + queues: queues, + db: store, + isLocalServerName: cfg.Matrix.IsLocalServerName, + durable: cfg.Matrix.JetStream.Durable("FederationAPIReceiptConsumer"), + topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputReceiptEvent), } } @@ -95,7 +95,7 @@ func (t *OutputReceiptConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) log.WithError(err).WithField("user_id", receipt.UserID).Error("failed to extract domain from receipt sender") return true } - if receiptServerName != t.ServerName { + if !t.isLocalServerName(receiptServerName) { return true } @@ -134,14 +134,14 @@ func (t *OutputReceiptConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) edu := &gomatrixserverlib.EDU{ Type: gomatrixserverlib.MReceipt, - Origin: string(t.ServerName), + Origin: string(receiptServerName), } if edu.Content, err = json.Marshal(content); err != nil { log.WithError(err).Error("failed to marshal EDU JSON") return true } - if err := t.queues.SendEDU(edu, t.ServerName, names); err != nil { + if err := t.queues.SendEDU(edu, receiptServerName, names); err != nil { log.WithError(err).Error("failed to send EDU") return false } diff --git a/federationapi/consumers/sendtodevice.go b/federationapi/consumers/sendtodevice.go index 9aec22a3e..9620d1612 100644 --- a/federationapi/consumers/sendtodevice.go +++ b/federationapi/consumers/sendtodevice.go @@ -34,13 +34,13 @@ import ( // OutputSendToDeviceConsumer consumes events that originate in the clientapi. type OutputSendToDeviceConsumer struct { - ctx context.Context - jetstream nats.JetStreamContext - durable string - db storage.Database - queues *queue.OutgoingQueues - ServerName gomatrixserverlib.ServerName - topic string + ctx context.Context + jetstream nats.JetStreamContext + durable string + db storage.Database + queues *queue.OutgoingQueues + isLocalServerName func(gomatrixserverlib.ServerName) bool + topic string } // NewOutputSendToDeviceConsumer creates a new OutputSendToDeviceConsumer. Call Start() to begin consuming send-to-device events. @@ -52,13 +52,13 @@ func NewOutputSendToDeviceConsumer( store storage.Database, ) *OutputSendToDeviceConsumer { return &OutputSendToDeviceConsumer{ - ctx: process.Context(), - jetstream: js, - queues: queues, - db: store, - ServerName: cfg.Matrix.ServerName, - durable: cfg.Matrix.JetStream.Durable("FederationAPIESendToDeviceConsumer"), - topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent), + ctx: process.Context(), + jetstream: js, + queues: queues, + db: store, + isLocalServerName: cfg.Matrix.IsLocalServerName, + durable: cfg.Matrix.JetStream.Durable("FederationAPIESendToDeviceConsumer"), + topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent), } } @@ -82,7 +82,7 @@ func (t *OutputSendToDeviceConsumer) onMessage(ctx context.Context, msgs []*nats log.WithError(err).WithField("user_id", sender).Error("Failed to extract domain from send-to-device sender") return true } - if originServerName != t.ServerName { + if !t.isLocalServerName(originServerName) { return true } // Extract the send-to-device event from msg. @@ -101,14 +101,14 @@ func (t *OutputSendToDeviceConsumer) onMessage(ctx context.Context, msgs []*nats } // The SyncAPI is already handling sendToDevice for the local server - if destServerName == t.ServerName { + if t.isLocalServerName(destServerName) { return true } // Pack the EDU and marshal it edu := &gomatrixserverlib.EDU{ Type: gomatrixserverlib.MDirectToDevice, - Origin: string(t.ServerName), + Origin: string(originServerName), } tdm := gomatrixserverlib.ToDeviceMessage{ Sender: ote.Sender, @@ -127,7 +127,7 @@ func (t *OutputSendToDeviceConsumer) onMessage(ctx context.Context, msgs []*nats } log.Debugf("Sending send-to-device message into %q destination queue", destServerName) - if err := t.queues.SendEDU(edu, t.ServerName, []gomatrixserverlib.ServerName{destServerName}); err != nil { + if err := t.queues.SendEDU(edu, originServerName, []gomatrixserverlib.ServerName{destServerName}); err != nil { log.WithError(err).Error("failed to send EDU") return false } diff --git a/federationapi/consumers/typing.go b/federationapi/consumers/typing.go index 9c7379136..c66f97519 100644 --- a/federationapi/consumers/typing.go +++ b/federationapi/consumers/typing.go @@ -31,13 +31,13 @@ import ( // OutputTypingConsumer consumes events that originate in the clientapi. type OutputTypingConsumer struct { - ctx context.Context - jetstream nats.JetStreamContext - durable string - db storage.Database - queues *queue.OutgoingQueues - ServerName gomatrixserverlib.ServerName - topic string + ctx context.Context + jetstream nats.JetStreamContext + durable string + db storage.Database + queues *queue.OutgoingQueues + isLocalServerName func(gomatrixserverlib.ServerName) bool + topic string } // NewOutputTypingConsumer creates a new OutputTypingConsumer. Call Start() to begin consuming typing events. @@ -49,13 +49,13 @@ func NewOutputTypingConsumer( store storage.Database, ) *OutputTypingConsumer { return &OutputTypingConsumer{ - ctx: process.Context(), - jetstream: js, - queues: queues, - db: store, - ServerName: cfg.Matrix.ServerName, - durable: cfg.Matrix.JetStream.Durable("FederationAPITypingConsumer"), - topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputTypingEvent), + ctx: process.Context(), + jetstream: js, + queues: queues, + db: store, + isLocalServerName: cfg.Matrix.IsLocalServerName, + durable: cfg.Matrix.JetStream.Durable("FederationAPITypingConsumer"), + topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputTypingEvent), } } @@ -87,7 +87,7 @@ func (t *OutputTypingConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) _ = msg.Ack() return true } - if typingServerName != t.ServerName { + if !t.isLocalServerName(typingServerName) { return true } @@ -111,7 +111,7 @@ func (t *OutputTypingConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) log.WithError(err).Error("failed to marshal EDU JSON") return true } - if err := t.queues.SendEDU(edu, t.ServerName, names); err != nil { + if err := t.queues.SendEDU(edu, typingServerName, names); err != nil { log.WithError(err).Error("failed to send EDU") return false }