diff --git a/.github/workflows/dendrite.yml b/.github/workflows/dendrite.yml index f8019b3ea..a8271b675 100644 --- a/.github/workflows/dendrite.yml +++ b/.github/workflows/dendrite.yml @@ -109,6 +109,11 @@ jobs: uses: actions/setup-go@v3 with: go-version: ${{ matrix.go }} + - name: Set up gotestfmt + uses: gotesttools/gotestfmt-action@v2 + with: + # Optional: pass GITHUB_TOKEN to avoid rate limiting. + token: ${{ secrets.GITHUB_TOKEN }} - uses: actions/cache@v3 with: path: | @@ -117,7 +122,7 @@ jobs: key: ${{ runner.os }}-go${{ matrix.go }}-test-${{ hashFiles('**/go.sum') }} restore-keys: | ${{ runner.os }}-go${{ matrix.go }}-test- - - run: go test ./... + - run: go test -json -v ./... 2>&1 | gotestfmt env: POSTGRES_HOST: localhost POSTGRES_USER: postgres diff --git a/clientapi/auth/password.go b/clientapi/auth/password.go index 890b18183..700a72f5d 100644 --- a/clientapi/auth/password.go +++ b/clientapi/auth/password.go @@ -74,7 +74,7 @@ func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, JSON: jsonerror.BadJSON("A password must be supplied."), } } - localpart, err := userutil.ParseUsernameParam(username, &t.Config.Matrix.ServerName) + localpart, _, err := userutil.ParseUsernameParam(username, t.Config.Matrix) if err != nil { return nil, &util.JSONResponse{ Code: http.StatusUnauthorized, diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index 89c269f1a..69bca13be 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -70,7 +70,7 @@ func AdminEvacuateUser(req *http.Request, cfg *config.ClientAPI, device *userapi if err != nil { return util.MessageResponse(http.StatusBadRequest, err.Error()) } - if domain != cfg.Matrix.ServerName { + if !cfg.Matrix.IsLocalServerName(domain) { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: jsonerror.MissingArgument("User ID must belong to this server."), @@ -169,7 +169,7 @@ func AdminMarkAsStale(req *http.Request, cfg *config.ClientAPI, keyAPI api.Clien if err != nil { return util.MessageResponse(http.StatusBadRequest, err.Error()) } - if domain == cfg.Matrix.ServerName { + if cfg.Matrix.IsLocalServerName(domain) { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: jsonerror.InvalidParam("Can not mark local device list as stale"), diff --git a/clientapi/routing/createroom.go b/clientapi/routing/createroom.go index 3e837c864..eefe8e24b 100644 --- a/clientapi/routing/createroom.go +++ b/clientapi/routing/createroom.go @@ -169,9 +169,21 @@ func createRoom( asAPI appserviceAPI.AppServiceInternalAPI, evTime time.Time, ) util.JSONResponse { + _, userDomain, err := gomatrixserverlib.SplitID('@', device.UserID) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed") + return jsonerror.InternalServerError() + } + if !cfg.Matrix.IsLocalServerName(userDomain) { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden(fmt.Sprintf("User domain %q not configured locally", userDomain)), + } + } + // TODO (#267): Check room ID doesn't clash with an existing one, and we // probably shouldn't be using pseudo-random strings, maybe GUIDs? - roomID := fmt.Sprintf("!%s:%s", util.RandomString(16), cfg.Matrix.ServerName) + roomID := fmt.Sprintf("!%s:%s", util.RandomString(16), userDomain) logger := util.GetLogger(ctx) userID := device.UserID @@ -314,7 +326,7 @@ func createRoom( var roomAlias string if r.RoomAliasName != "" { - roomAlias = fmt.Sprintf("#%s:%s", r.RoomAliasName, cfg.Matrix.ServerName) + roomAlias = fmt.Sprintf("#%s:%s", r.RoomAliasName, userDomain) // check it's free TODO: This races but is better than nothing hasAliasReq := roomserverAPI.GetRoomIDForAliasRequest{ Alias: roomAlias, @@ -436,7 +448,7 @@ func createRoom( builder.PrevEvents = []gomatrixserverlib.EventReference{builtEvents[i-1].EventReference()} } var ev *gomatrixserverlib.Event - ev, err = buildEvent(&builder, &authEvents, cfg, evTime, roomVersion) + ev, err = buildEvent(&builder, userDomain, &authEvents, cfg, evTime, roomVersion) if err != nil { util.GetLogger(ctx).WithError(err).Error("buildEvent failed") return jsonerror.InternalServerError() @@ -461,7 +473,7 @@ func createRoom( inputs = append(inputs, roomserverAPI.InputRoomEvent{ Kind: roomserverAPI.KindNew, Event: event, - Origin: cfg.Matrix.ServerName, + Origin: userDomain, SendAsServer: roomserverAPI.DoNotSendToOtherServers, }) } @@ -548,7 +560,7 @@ func createRoom( Event: event, InviteRoomState: inviteStrippedState, RoomVersion: event.RoomVersion, - SendAsServer: string(cfg.Matrix.ServerName), + SendAsServer: string(userDomain), }, &inviteRes); err != nil { util.GetLogger(ctx).WithError(err).Error("PerformInvite failed") return util.JSONResponse{ @@ -591,6 +603,7 @@ func createRoom( // buildEvent fills out auth_events for the builder then builds the event func buildEvent( builder *gomatrixserverlib.EventBuilder, + serverName gomatrixserverlib.ServerName, provider gomatrixserverlib.AuthEventProvider, cfg *config.ClientAPI, evTime time.Time, @@ -606,7 +619,7 @@ func buildEvent( } builder.AuthEvents = refs event, err := builder.Build( - evTime, cfg.Matrix.ServerName, cfg.Matrix.KeyID, + evTime, serverName, cfg.Matrix.KeyID, cfg.Matrix.PrivateKey, roomVersion, ) if err != nil { diff --git a/clientapi/routing/directory.go b/clientapi/routing/directory.go index 836d9e152..33bc63d18 100644 --- a/clientapi/routing/directory.go +++ b/clientapi/routing/directory.go @@ -75,7 +75,7 @@ func DirectoryRoom( if res.RoomID == "" { // If we don't know it locally, do a federation query. // But don't send the query to ourselves. - if domain != cfg.Matrix.ServerName { + if !cfg.Matrix.IsLocalServerName(domain) { fedRes, fedErr := federation.LookupRoomAlias(req.Context(), domain, roomAlias) if fedErr != nil { // TODO: Return 502 if the remote server errored. @@ -127,7 +127,7 @@ func SetLocalAlias( } } - if domain != cfg.Matrix.ServerName { + if !cfg.Matrix.IsLocalServerName(domain) { return util.JSONResponse{ Code: http.StatusForbidden, JSON: jsonerror.Forbidden("Alias must be on local homeserver"), diff --git a/clientapi/routing/directory_public.go b/clientapi/routing/directory_public.go index 8ddb3267a..4ebf2295a 100644 --- a/clientapi/routing/directory_public.go +++ b/clientapi/routing/directory_public.go @@ -62,8 +62,7 @@ func GetPostPublicRooms( } serverName := gomatrixserverlib.ServerName(request.Server) - - if serverName != "" && serverName != cfg.Matrix.ServerName { + if serverName != "" && !cfg.Matrix.IsLocalServerName(serverName) { res, err := federation.GetPublicRoomsFiltered( req.Context(), serverName, int(request.Limit), request.Since, diff --git a/clientapi/routing/joined_rooms.go b/clientapi/routing/joined_rooms.go new file mode 100644 index 000000000..4bb353ea9 --- /dev/null +++ b/clientapi/routing/joined_rooms.go @@ -0,0 +1,52 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "net/http" + + "github.com/matrix-org/util" + + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/roomserver/api" + userapi "github.com/matrix-org/dendrite/userapi/api" +) + +type getJoinedRoomsResponse struct { + JoinedRooms []string `json:"joined_rooms"` +} + +func GetJoinedRooms( + req *http.Request, + device *userapi.Device, + rsAPI api.ClientRoomserverAPI, +) util.JSONResponse { + var res api.QueryRoomsForUserResponse + err := rsAPI.QueryRoomsForUser(req.Context(), &api.QueryRoomsForUserRequest{ + UserID: device.UserID, + WantMembership: "join", + }, &res) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("QueryRoomsForUser failed") + return jsonerror.InternalServerError() + } + if res.RoomIDs == nil { + res.RoomIDs = []string{} + } + return util.JSONResponse{ + Code: http.StatusOK, + JSON: getJoinedRoomsResponse{res.RoomIDs}, + } +} diff --git a/clientapi/routing/keys.go b/clientapi/routing/keys.go index 5c3681382..0c12b1117 100644 --- a/clientapi/routing/keys.go +++ b/clientapi/routing/keys.go @@ -99,7 +99,11 @@ func (r *queryKeysRequest) GetTimeout() time.Duration { if r.Timeout == 0 { return 10 * time.Second } - return time.Duration(r.Timeout) * time.Millisecond + timeout := time.Duration(r.Timeout) * time.Millisecond + if timeout > time.Second*20 { + timeout = time.Second * 20 + } + return timeout } func QueryKeys(req *http.Request, keyAPI api.ClientKeyAPI, device *userapi.Device) util.JSONResponse { diff --git a/clientapi/routing/login.go b/clientapi/routing/login.go index 6017b5840..7f5a8c4f8 100644 --- a/clientapi/routing/login.go +++ b/clientapi/routing/login.go @@ -68,7 +68,7 @@ func Login( return *authErr } // make a device/access token - authErr2 := completeAuth(req.Context(), cfg.Matrix.ServerName, userAPI, login, req.RemoteAddr, req.UserAgent()) + authErr2 := completeAuth(req.Context(), cfg.Matrix, userAPI, login, req.RemoteAddr, req.UserAgent()) cleanup(req.Context(), &authErr2) return authErr2 } @@ -79,7 +79,7 @@ func Login( } func completeAuth( - ctx context.Context, serverName gomatrixserverlib.ServerName, userAPI userapi.ClientUserAPI, login *auth.Login, + ctx context.Context, cfg *config.Global, userAPI userapi.ClientUserAPI, login *auth.Login, ipAddr, userAgent string, ) util.JSONResponse { token, err := auth.GenerateAccessToken() @@ -88,7 +88,7 @@ func completeAuth( return jsonerror.InternalServerError() } - localpart, err := userutil.ParseUsernameParam(login.Username(), &serverName) + localpart, serverName, err := userutil.ParseUsernameParam(login.Username(), cfg) if err != nil { util.GetLogger(ctx).WithError(err).Error("auth.ParseUsernameParam failed") return jsonerror.InternalServerError() diff --git a/clientapi/routing/membership.go b/clientapi/routing/membership.go index 77f627eb2..94ba17a02 100644 --- a/clientapi/routing/membership.go +++ b/clientapi/routing/membership.go @@ -105,12 +105,13 @@ func sendMembership(ctx context.Context, profileAPI userapi.ClientUserAPI, devic return jsonerror.InternalServerError() } + serverName := device.UserDomain() if err = roomserverAPI.SendEvents( ctx, rsAPI, roomserverAPI.KindNew, []*gomatrixserverlib.HeaderedEvent{event.Event.Headered(roomVer)}, - cfg.Matrix.ServerName, - cfg.Matrix.ServerName, + serverName, + serverName, nil, false, ); err != nil { @@ -271,7 +272,7 @@ func sendInvite( Event: event, InviteRoomState: nil, // ask the roomserver to draw up invite room state for us RoomVersion: event.RoomVersion, - SendAsServer: string(cfg.Matrix.ServerName), + SendAsServer: string(device.UserDomain()), }, &inviteRes); err != nil { util.GetLogger(ctx).WithError(err).Error("PerformInvite failed") return util.JSONResponse{ @@ -341,7 +342,7 @@ func loadProfile( } var profile *authtypes.Profile - if serverName == cfg.Matrix.ServerName { + if cfg.Matrix.IsLocalServerName(serverName) { profile, err = appserviceAPI.RetrieveUserProfile(ctx, userID, asAPI, profileAPI) } else { profile = &authtypes.Profile{} diff --git a/clientapi/routing/openid.go b/clientapi/routing/openid.go index cfb440bea..8e9be7889 100644 --- a/clientapi/routing/openid.go +++ b/clientapi/routing/openid.go @@ -63,7 +63,7 @@ func CreateOpenIDToken( JSON: openIDTokenResponse{ AccessToken: response.Token.Token, TokenType: "Bearer", - MatrixServerName: string(cfg.Matrix.ServerName), + MatrixServerName: string(device.UserDomain()), ExpiresIn: response.Token.ExpiresAtMS / 1000, // convert ms to s }, } diff --git a/clientapi/routing/profile.go b/clientapi/routing/profile.go index c9647eb1b..4d9e1f8a5 100644 --- a/clientapi/routing/profile.go +++ b/clientapi/routing/profile.go @@ -113,12 +113,19 @@ func SetAvatarURL( } } - localpart, _, err := gomatrixserverlib.SplitID('@', userID) + localpart, domain, err := gomatrixserverlib.SplitID('@', userID) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") return jsonerror.InternalServerError() } + if !cfg.Matrix.IsLocalServerName(domain) { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("userID does not belong to a locally configured domain"), + } + } + evTime, err := httputil.ParseTSParam(req) if err != nil { return util.JSONResponse{ @@ -129,8 +136,9 @@ func SetAvatarURL( setRes := &userapi.PerformSetAvatarURLResponse{} if err = profileAPI.SetAvatarURL(req.Context(), &userapi.PerformSetAvatarURLRequest{ - Localpart: localpart, - AvatarURL: r.AvatarURL, + Localpart: localpart, + ServerName: domain, + AvatarURL: r.AvatarURL, }, setRes); err != nil { util.GetLogger(req.Context()).WithError(err).Error("profileAPI.SetAvatarURL failed") return jsonerror.InternalServerError() @@ -204,12 +212,19 @@ func SetDisplayName( } } - localpart, _, err := gomatrixserverlib.SplitID('@', userID) + localpart, domain, err := gomatrixserverlib.SplitID('@', userID) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") return jsonerror.InternalServerError() } + if !cfg.Matrix.IsLocalServerName(domain) { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("userID does not belong to a locally configured domain"), + } + } + evTime, err := httputil.ParseTSParam(req) if err != nil { return util.JSONResponse{ @@ -221,6 +236,7 @@ func SetDisplayName( profileRes := &userapi.PerformUpdateDisplayNameResponse{} err = profileAPI.SetDisplayName(req.Context(), &userapi.PerformUpdateDisplayNameRequest{ Localpart: localpart, + ServerName: domain, DisplayName: r.DisplayName, }, profileRes) if err != nil { @@ -261,6 +277,12 @@ func updateProfile( return jsonerror.InternalServerError(), err } + _, domain, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed") + return jsonerror.InternalServerError(), err + } + events, err := buildMembershipEvents( ctx, res.RoomIDs, *profile, userID, cfg, evTime, rsAPI, ) @@ -276,7 +298,7 @@ func updateProfile( return jsonerror.InternalServerError(), e } - if err := api.SendEvents(ctx, rsAPI, api.KindNew, events, cfg.Matrix.ServerName, cfg.Matrix.ServerName, nil, true); err != nil { + if err := api.SendEvents(ctx, rsAPI, api.KindNew, events, domain, domain, nil, true); err != nil { util.GetLogger(ctx).WithError(err).Error("SendEvents failed") return jsonerror.InternalServerError(), err } @@ -298,7 +320,7 @@ func getProfile( return nil, err } - if domain != cfg.Matrix.ServerName { + if !cfg.Matrix.IsLocalServerName(domain) { profile, fedErr := federation.LookupProfile(ctx, domain, userID, "") if fedErr != nil { if x, ok := fedErr.(gomatrix.HTTPError); ok { diff --git a/clientapi/routing/redaction.go b/clientapi/routing/redaction.go index a0f3b1152..778a02fd4 100644 --- a/clientapi/routing/redaction.go +++ b/clientapi/routing/redaction.go @@ -131,7 +131,8 @@ func SendRedaction( JSON: jsonerror.NotFound("Room does not exist"), } } - if err = roomserverAPI.SendEvents(context.Background(), rsAPI, roomserverAPI.KindNew, []*gomatrixserverlib.HeaderedEvent{e}, cfg.Matrix.ServerName, cfg.Matrix.ServerName, nil, false); err != nil { + domain := device.UserDomain() + if err = roomserverAPI.SendEvents(context.Background(), rsAPI, roomserverAPI.KindNew, []*gomatrixserverlib.HeaderedEvent{e}, domain, domain, nil, false); err != nil { util.GetLogger(req.Context()).WithError(err).Errorf("failed to SendEvents") return jsonerror.InternalServerError() } diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index 0bda1e488..698d185b4 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -412,7 +412,7 @@ func UserIDIsWithinApplicationServiceNamespace( return false } - if domain != cfg.Matrix.ServerName { + if !cfg.Matrix.IsLocalServerName(domain) { return false } diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 4ca8e59c5..e0e3e33d4 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -950,26 +950,6 @@ func Setup( }), ).Methods(http.MethodPost, http.MethodOptions) - v3mux.Handle("/rooms/{roomID}/members", - httputil.MakeAuthAPI("rooms_members", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) - if err != nil { - return util.ErrorResponse(err) - } - return GetMemberships(req, device, vars["roomID"], false, cfg, rsAPI) - }), - ).Methods(http.MethodGet, http.MethodOptions) - - v3mux.Handle("/rooms/{roomID}/joined_members", - httputil.MakeAuthAPI("rooms_members", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) - if err != nil { - return util.ErrorResponse(err) - } - return GetMemberships(req, device, vars["roomID"], true, cfg, rsAPI) - }), - ).Methods(http.MethodGet, http.MethodOptions) - v3mux.Handle("/rooms/{roomID}/read_markers", httputil.MakeAuthAPI("rooms_read_markers", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req, device); r != nil { diff --git a/clientapi/routing/sendevent.go b/clientapi/routing/sendevent.go index 114e9088d..bb66cf6fc 100644 --- a/clientapi/routing/sendevent.go +++ b/clientapi/routing/sendevent.go @@ -94,6 +94,7 @@ func SendEvent( // create a mutex for the specific user in the specific room // this avoids a situation where events that are received in quick succession are sent to the roomserver in a jumbled order userID := device.UserID + domain := device.UserDomain() mutex, _ := userRoomSendMutexes.LoadOrStore(roomID+userID, &sync.Mutex{}) mutex.(*sync.Mutex).Lock() defer mutex.(*sync.Mutex).Unlock() @@ -185,8 +186,8 @@ func SendEvent( []*gomatrixserverlib.HeaderedEvent{ e.Headered(verRes.RoomVersion), }, - cfg.Matrix.ServerName, - cfg.Matrix.ServerName, + domain, + domain, txnAndSessionID, false, ); err != nil { diff --git a/clientapi/threepid/invites.go b/clientapi/threepid/invites.go index 9670fecad..99fb8171d 100644 --- a/clientapi/threepid/invites.go +++ b/clientapi/threepid/invites.go @@ -215,7 +215,7 @@ func queryIDServerStoreInvite( } var profile *authtypes.Profile - if serverName == cfg.Matrix.ServerName { + if cfg.Matrix.IsLocalServerName(serverName) { res := &userapi.QueryProfileResponse{} err = userAPI.QueryProfile(ctx, &userapi.QueryProfileRequest{UserID: device.UserID}, res) if err != nil { diff --git a/clientapi/userutil/userutil.go b/clientapi/userutil/userutil.go index 7e909ffad..9be1e9b31 100644 --- a/clientapi/userutil/userutil.go +++ b/clientapi/userutil/userutil.go @@ -17,6 +17,7 @@ import ( "fmt" "strings" + "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" ) @@ -24,23 +25,23 @@ import ( // usernameParam can either be a user ID or just the localpart/username. // If serverName is passed, it is verified against the domain obtained from usernameParam (if present) // Returns error in case of invalid usernameParam. -func ParseUsernameParam(usernameParam string, expectedServerName *gomatrixserverlib.ServerName) (string, error) { +func ParseUsernameParam(usernameParam string, cfg *config.Global) (string, gomatrixserverlib.ServerName, error) { localpart := usernameParam if strings.HasPrefix(usernameParam, "@") { lp, domain, err := gomatrixserverlib.SplitID('@', usernameParam) if err != nil { - return "", errors.New("invalid username") + return "", "", errors.New("invalid username") } - if expectedServerName != nil && domain != *expectedServerName { - return "", errors.New("user ID does not belong to this server") + if !cfg.IsLocalServerName(domain) { + return "", "", errors.New("user ID does not belong to this server") } - localpart = lp + return lp, domain, nil } - return localpart, nil + return localpart, cfg.ServerName, nil } // MakeUserID generates user ID from localpart & server name diff --git a/clientapi/userutil/userutil_test.go b/clientapi/userutil/userutil_test.go index 2628642fb..ccd6647b2 100644 --- a/clientapi/userutil/userutil_test.go +++ b/clientapi/userutil/userutil_test.go @@ -15,6 +15,7 @@ package userutil import ( "testing" + "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" ) @@ -28,7 +29,11 @@ var ( // TestGoodUserID checks that correct localpart is returned for a valid user ID. func TestGoodUserID(t *testing.T) { - lp, err := ParseUsernameParam(goodUserID, &serverName) + cfg := &config.Global{ + ServerName: serverName, + } + + lp, _, err := ParseUsernameParam(goodUserID, cfg) if err != nil { t.Error("User ID Parsing failed for ", goodUserID, " with error: ", err.Error()) @@ -41,7 +46,11 @@ func TestGoodUserID(t *testing.T) { // TestWithLocalpartOnly checks that localpart is returned when usernameParam contains only localpart. func TestWithLocalpartOnly(t *testing.T) { - lp, err := ParseUsernameParam(localpart, &serverName) + cfg := &config.Global{ + ServerName: serverName, + } + + lp, _, err := ParseUsernameParam(localpart, cfg) if err != nil { t.Error("User ID Parsing failed for ", localpart, " with error: ", err.Error()) @@ -54,7 +63,11 @@ func TestWithLocalpartOnly(t *testing.T) { // TestIncorrectDomain checks for error when there's server name mismatch. func TestIncorrectDomain(t *testing.T) { - _, err := ParseUsernameParam(goodUserID, &invalidServerName) + cfg := &config.Global{ + ServerName: invalidServerName, + } + + _, _, err := ParseUsernameParam(goodUserID, cfg) if err == nil { t.Error("Invalid Domain should return an error") @@ -63,7 +76,11 @@ func TestIncorrectDomain(t *testing.T) { // TestBadUserID checks that ParseUsernameParam fails for invalid user ID func TestBadUserID(t *testing.T) { - _, err := ParseUsernameParam(badUserID, &serverName) + cfg := &config.Global{ + ServerName: serverName, + } + + _, _, err := ParseUsernameParam(badUserID, cfg) if err == nil { t.Error("Illegal User ID should return an error") diff --git a/dendrite-sample.monolith.yaml b/dendrite-sample.monolith.yaml index eadb74a2a..5195c29bc 100644 --- a/dendrite-sample.monolith.yaml +++ b/dendrite-sample.monolith.yaml @@ -310,6 +310,14 @@ user_api: # The default lifetime is 3600000ms (60 minutes). # openid_token_lifetime_ms: 3600000 + # Users who register on this homeserver will automatically be joined to the rooms listed under "auto_join_rooms" option. + # By default, any room aliases included in this list will be created as a publicly joinable room + # when the first user registers for the homeserver. If the room already exists, + # make certain it is a publicly joinable room, i.e. the join rule of the room must be set to 'public'. + # As Spaces are just rooms under the hood, Space aliases may also be used. + auto_join_rooms: + # - "#main:matrix.org" + # Configuration for Opentracing. # See https://github.com/matrix-org/dendrite/tree/master/docs/tracing for information on # how this works and how to set it up. diff --git a/dendrite-sample.polylith.yaml b/dendrite-sample.polylith.yaml index aa7e0cc38..bbbe16fdc 100644 --- a/dendrite-sample.polylith.yaml +++ b/dendrite-sample.polylith.yaml @@ -375,6 +375,14 @@ user_api: # The default lifetime is 3600000ms (60 minutes). # openid_token_lifetime_ms: 3600000 + # Users who register on this homeserver will automatically be joined to the rooms listed under "auto_join_rooms" option. + # By default, any room aliases included in this list will be created as a publicly joinable room + # when the first user registers for the homeserver. If the room already exists, + # make certain it is a publicly joinable room, i.e. the join rule of the room must be set to 'public'. + # As Spaces are just rooms under the hood, Space aliases may also be used. + auto_join_rooms: + # - "#main:matrix.org" + # Configuration for Opentracing. # See https://github.com/matrix-org/dendrite/tree/master/docs/tracing for information on # how this works and how to set it up. diff --git a/docs/caddy/polylith/Caddyfile b/docs/caddy/polylith/Caddyfile index 8aeb9317f..c2d81b49b 100644 --- a/docs/caddy/polylith/Caddyfile +++ b/docs/caddy/polylith/Caddyfile @@ -74,7 +74,7 @@ matrix.example.com { # Change the end of each reverse_proxy line to the correct # address for your various services. @sync_api { - path_regexp /_matrix/client/.*?/(sync|user/.*?/filter/?.*|keys/changes|rooms/.*?/(messages|context/.*?|relations/.*?|event/.*?))$ + path_regexp /_matrix/client/.*?/(sync|user/.*?/filter/?.*|keys/changes|rooms/.*?/(messages|.*?_?members|context/.*?|relations/.*?|event/.*?))$ } reverse_proxy @sync_api sync_api:8073 diff --git a/docs/hiawatha/polylith-sample.conf b/docs/hiawatha/polylith-sample.conf index 0093fdcf2..eb1dd4f9a 100644 --- a/docs/hiawatha/polylith-sample.conf +++ b/docs/hiawatha/polylith-sample.conf @@ -23,8 +23,10 @@ VirtualHost { # /_matrix/client/.*/rooms/{roomId}/relations/{eventID} # /_matrix/client/.*/rooms/{roomId}/relations/{eventID}/{relType} # /_matrix/client/.*/rooms/{roomId}/relations/{eventID}/{relType}/{eventType} + # /_matrix/client/.*/rooms/{roomId}/members + # /_matrix/client/.*/rooms/{roomId}/joined_members # to sync_api - ReverseProxy = /_matrix/client/.*?/(sync|user/.*?/filter/?.*|keys/changes|rooms/.*?/(messages|context/.*?|relations/.*?|event/.*?))$ http://localhost:8073 600 + ReverseProxy = /_matrix/client/.*?/(sync|user/.*?/filter/?.*|keys/changes|rooms/.*?/(messages|.*?_?members|context/.*?|relations/.*?|event/.*?))$ http://localhost:8073 600 ReverseProxy = /_matrix/client http://localhost:8071 600 ReverseProxy = /_matrix/federation http://localhost:8072 600 ReverseProxy = /_matrix/key http://localhost:8072 600 diff --git a/docs/nginx/polylith-sample.conf b/docs/nginx/polylith-sample.conf index 6e81eb5f2..0ad24509a 100644 --- a/docs/nginx/polylith-sample.conf +++ b/docs/nginx/polylith-sample.conf @@ -33,8 +33,10 @@ server { # /_matrix/client/.*/rooms/{roomId}/relations/{eventID} # /_matrix/client/.*/rooms/{roomId}/relations/{eventID}/{relType} # /_matrix/client/.*/rooms/{roomId}/relations/{eventID}/{relType}/{eventType} + # /_matrix/client/.*/rooms/{roomId}/members + # /_matrix/client/.*/rooms/{roomId}/joined_members # to sync_api - location ~ /_matrix/client/.*?/(sync|user/.*?/filter/?.*|keys/changes|rooms/.*?/(messages|context/.*?|relations/.*?|event/.*?))$ { + location ~ /_matrix/client/.*?/(sync|user/.*?/filter/?.*|keys/changes|rooms/.*?/(messages|.*?_?members|context/.*?|relations/.*?|event/.*?))$ { proxy_pass http://sync_api:8073; } diff --git a/federationapi/consumers/keychange.go b/federationapi/consumers/keychange.go index 67dfdc1d3..7d1ae0f81 100644 --- a/federationapi/consumers/keychange.go +++ b/federationapi/consumers/keychange.go @@ -35,14 +35,14 @@ import ( // KeyChangeConsumer consumes events that originate in key server. type KeyChangeConsumer struct { - ctx context.Context - jetstream nats.JetStreamContext - durable string - db storage.Database - queues *queue.OutgoingQueues - serverName gomatrixserverlib.ServerName - rsAPI roomserverAPI.FederationRoomserverAPI - topic string + ctx context.Context + jetstream nats.JetStreamContext + durable string + db storage.Database + queues *queue.OutgoingQueues + isLocalServerName func(gomatrixserverlib.ServerName) bool + rsAPI roomserverAPI.FederationRoomserverAPI + topic string } // NewKeyChangeConsumer creates a new KeyChangeConsumer. Call Start() to begin consuming from key servers. @@ -55,14 +55,14 @@ func NewKeyChangeConsumer( rsAPI roomserverAPI.FederationRoomserverAPI, ) *KeyChangeConsumer { return &KeyChangeConsumer{ - ctx: process.Context(), - jetstream: js, - durable: cfg.Matrix.JetStream.Prefixed("FederationAPIKeyChangeConsumer"), - topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputKeyChangeEvent), - queues: queues, - db: store, - serverName: cfg.Matrix.ServerName, - rsAPI: rsAPI, + ctx: process.Context(), + jetstream: js, + durable: cfg.Matrix.JetStream.Prefixed("FederationAPIKeyChangeConsumer"), + topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputKeyChangeEvent), + queues: queues, + db: store, + isLocalServerName: cfg.Matrix.IsLocalServerName, + rsAPI: rsAPI, } } @@ -112,7 +112,7 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) bool { logger.WithError(err).Error("Failed to extract domain from key change event") return true } - if originServerName != t.serverName { + if !t.isLocalServerName(originServerName) { return true } @@ -141,7 +141,7 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) bool { // Pack the EDU and marshal it edu := &gomatrixserverlib.EDU{ Type: gomatrixserverlib.MDeviceListUpdate, - Origin: string(t.serverName), + Origin: string(originServerName), } event := gomatrixserverlib.DeviceListUpdateEvent{ UserID: m.UserID, @@ -159,7 +159,7 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) bool { } logger.Debugf("Sending device list update message to %q", destinations) - err = t.queues.SendEDU(edu, t.serverName, destinations) + err = t.queues.SendEDU(edu, originServerName, destinations) return err == nil } @@ -171,7 +171,7 @@ func (t *KeyChangeConsumer) onCrossSigningMessage(m api.DeviceMessage) bool { logrus.WithError(err).Errorf("fedsender key change consumer: user ID parse failure") return true } - if host != gomatrixserverlib.ServerName(t.serverName) { + if !t.isLocalServerName(host) { // Ignore any messages that didn't originate locally, otherwise we'll // end up parroting information we received from other servers. return true @@ -203,7 +203,7 @@ func (t *KeyChangeConsumer) onCrossSigningMessage(m api.DeviceMessage) bool { // Pack the EDU and marshal it edu := &gomatrixserverlib.EDU{ Type: types.MSigningKeyUpdate, - Origin: string(t.serverName), + Origin: string(host), } if edu.Content, err = json.Marshal(output); err != nil { sentry.CaptureException(err) @@ -212,7 +212,7 @@ func (t *KeyChangeConsumer) onCrossSigningMessage(m api.DeviceMessage) bool { } logger.Debugf("Sending cross-signing update message to %q", destinations) - err = t.queues.SendEDU(edu, t.serverName, destinations) + err = t.queues.SendEDU(edu, host, destinations) return err == nil } diff --git a/federationapi/consumers/presence.go b/federationapi/consumers/presence.go index e76103cd3..3445d34a9 100644 --- a/federationapi/consumers/presence.go +++ b/federationapi/consumers/presence.go @@ -38,7 +38,7 @@ type OutputPresenceConsumer struct { durable string db storage.Database queues *queue.OutgoingQueues - ServerName gomatrixserverlib.ServerName + isLocalServerName func(gomatrixserverlib.ServerName) bool topic string outboundPresenceEnabled bool } @@ -56,7 +56,7 @@ func NewOutputPresenceConsumer( jetstream: js, queues: queues, db: store, - ServerName: cfg.Matrix.ServerName, + isLocalServerName: cfg.Matrix.IsLocalServerName, durable: cfg.Matrix.JetStream.Durable("FederationAPIPresenceConsumer"), topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputPresenceEvent), outboundPresenceEnabled: cfg.Matrix.Presence.EnableOutbound, @@ -85,7 +85,7 @@ func (t *OutputPresenceConsumer) onMessage(ctx context.Context, msgs []*nats.Msg log.WithError(err).WithField("user_id", userID).Error("failed to extract domain from receipt sender") return true } - if serverName != t.ServerName { + if !t.isLocalServerName(serverName) { return true } @@ -127,7 +127,7 @@ func (t *OutputPresenceConsumer) onMessage(ctx context.Context, msgs []*nats.Msg edu := &gomatrixserverlib.EDU{ Type: gomatrixserverlib.MPresence, - Origin: string(t.ServerName), + Origin: string(serverName), } if edu.Content, err = json.Marshal(content); err != nil { log.WithError(err).Error("failed to marshal EDU JSON") @@ -135,7 +135,7 @@ func (t *OutputPresenceConsumer) onMessage(ctx context.Context, msgs []*nats.Msg } log.Tracef("sending presence EDU to %d servers", len(joined)) - if err = t.queues.SendEDU(edu, t.ServerName, joined); err != nil { + if err = t.queues.SendEDU(edu, serverName, joined); err != nil { log.WithError(err).Error("failed to send EDU") return false } diff --git a/federationapi/consumers/receipts.go b/federationapi/consumers/receipts.go index 75827cb68..200c06e6c 100644 --- a/federationapi/consumers/receipts.go +++ b/federationapi/consumers/receipts.go @@ -34,13 +34,13 @@ import ( // OutputReceiptConsumer consumes events that originate in the clientapi. type OutputReceiptConsumer struct { - ctx context.Context - jetstream nats.JetStreamContext - durable string - db storage.Database - queues *queue.OutgoingQueues - ServerName gomatrixserverlib.ServerName - topic string + ctx context.Context + jetstream nats.JetStreamContext + durable string + db storage.Database + queues *queue.OutgoingQueues + isLocalServerName func(gomatrixserverlib.ServerName) bool + topic string } // NewOutputReceiptConsumer creates a new OutputReceiptConsumer. Call Start() to begin consuming typing events. @@ -52,13 +52,13 @@ func NewOutputReceiptConsumer( store storage.Database, ) *OutputReceiptConsumer { return &OutputReceiptConsumer{ - ctx: process.Context(), - jetstream: js, - queues: queues, - db: store, - ServerName: cfg.Matrix.ServerName, - durable: cfg.Matrix.JetStream.Durable("FederationAPIReceiptConsumer"), - topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputReceiptEvent), + ctx: process.Context(), + jetstream: js, + queues: queues, + db: store, + isLocalServerName: cfg.Matrix.IsLocalServerName, + durable: cfg.Matrix.JetStream.Durable("FederationAPIReceiptConsumer"), + topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputReceiptEvent), } } @@ -95,7 +95,7 @@ func (t *OutputReceiptConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) log.WithError(err).WithField("user_id", receipt.UserID).Error("failed to extract domain from receipt sender") return true } - if receiptServerName != t.ServerName { + if !t.isLocalServerName(receiptServerName) { return true } @@ -134,14 +134,14 @@ func (t *OutputReceiptConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) edu := &gomatrixserverlib.EDU{ Type: gomatrixserverlib.MReceipt, - Origin: string(t.ServerName), + Origin: string(receiptServerName), } if edu.Content, err = json.Marshal(content); err != nil { log.WithError(err).Error("failed to marshal EDU JSON") return true } - if err := t.queues.SendEDU(edu, t.ServerName, names); err != nil { + if err := t.queues.SendEDU(edu, receiptServerName, names); err != nil { log.WithError(err).Error("failed to send EDU") return false } diff --git a/federationapi/consumers/sendtodevice.go b/federationapi/consumers/sendtodevice.go index 9aec22a3e..9620d1612 100644 --- a/federationapi/consumers/sendtodevice.go +++ b/federationapi/consumers/sendtodevice.go @@ -34,13 +34,13 @@ import ( // OutputSendToDeviceConsumer consumes events that originate in the clientapi. type OutputSendToDeviceConsumer struct { - ctx context.Context - jetstream nats.JetStreamContext - durable string - db storage.Database - queues *queue.OutgoingQueues - ServerName gomatrixserverlib.ServerName - topic string + ctx context.Context + jetstream nats.JetStreamContext + durable string + db storage.Database + queues *queue.OutgoingQueues + isLocalServerName func(gomatrixserverlib.ServerName) bool + topic string } // NewOutputSendToDeviceConsumer creates a new OutputSendToDeviceConsumer. Call Start() to begin consuming send-to-device events. @@ -52,13 +52,13 @@ func NewOutputSendToDeviceConsumer( store storage.Database, ) *OutputSendToDeviceConsumer { return &OutputSendToDeviceConsumer{ - ctx: process.Context(), - jetstream: js, - queues: queues, - db: store, - ServerName: cfg.Matrix.ServerName, - durable: cfg.Matrix.JetStream.Durable("FederationAPIESendToDeviceConsumer"), - topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent), + ctx: process.Context(), + jetstream: js, + queues: queues, + db: store, + isLocalServerName: cfg.Matrix.IsLocalServerName, + durable: cfg.Matrix.JetStream.Durable("FederationAPIESendToDeviceConsumer"), + topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent), } } @@ -82,7 +82,7 @@ func (t *OutputSendToDeviceConsumer) onMessage(ctx context.Context, msgs []*nats log.WithError(err).WithField("user_id", sender).Error("Failed to extract domain from send-to-device sender") return true } - if originServerName != t.ServerName { + if !t.isLocalServerName(originServerName) { return true } // Extract the send-to-device event from msg. @@ -101,14 +101,14 @@ func (t *OutputSendToDeviceConsumer) onMessage(ctx context.Context, msgs []*nats } // The SyncAPI is already handling sendToDevice for the local server - if destServerName == t.ServerName { + if t.isLocalServerName(destServerName) { return true } // Pack the EDU and marshal it edu := &gomatrixserverlib.EDU{ Type: gomatrixserverlib.MDirectToDevice, - Origin: string(t.ServerName), + Origin: string(originServerName), } tdm := gomatrixserverlib.ToDeviceMessage{ Sender: ote.Sender, @@ -127,7 +127,7 @@ func (t *OutputSendToDeviceConsumer) onMessage(ctx context.Context, msgs []*nats } log.Debugf("Sending send-to-device message into %q destination queue", destServerName) - if err := t.queues.SendEDU(edu, t.ServerName, []gomatrixserverlib.ServerName{destServerName}); err != nil { + if err := t.queues.SendEDU(edu, originServerName, []gomatrixserverlib.ServerName{destServerName}); err != nil { log.WithError(err).Error("failed to send EDU") return false } diff --git a/federationapi/consumers/typing.go b/federationapi/consumers/typing.go index 9c7379136..c66f97519 100644 --- a/federationapi/consumers/typing.go +++ b/federationapi/consumers/typing.go @@ -31,13 +31,13 @@ import ( // OutputTypingConsumer consumes events that originate in the clientapi. type OutputTypingConsumer struct { - ctx context.Context - jetstream nats.JetStreamContext - durable string - db storage.Database - queues *queue.OutgoingQueues - ServerName gomatrixserverlib.ServerName - topic string + ctx context.Context + jetstream nats.JetStreamContext + durable string + db storage.Database + queues *queue.OutgoingQueues + isLocalServerName func(gomatrixserverlib.ServerName) bool + topic string } // NewOutputTypingConsumer creates a new OutputTypingConsumer. Call Start() to begin consuming typing events. @@ -49,13 +49,13 @@ func NewOutputTypingConsumer( store storage.Database, ) *OutputTypingConsumer { return &OutputTypingConsumer{ - ctx: process.Context(), - jetstream: js, - queues: queues, - db: store, - ServerName: cfg.Matrix.ServerName, - durable: cfg.Matrix.JetStream.Durable("FederationAPITypingConsumer"), - topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputTypingEvent), + ctx: process.Context(), + jetstream: js, + queues: queues, + db: store, + isLocalServerName: cfg.Matrix.IsLocalServerName, + durable: cfg.Matrix.JetStream.Durable("FederationAPITypingConsumer"), + topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputTypingEvent), } } @@ -87,7 +87,7 @@ func (t *OutputTypingConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) _ = msg.Ack() return true } - if typingServerName != t.ServerName { + if !t.isLocalServerName(typingServerName) { return true } @@ -111,7 +111,7 @@ func (t *OutputTypingConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) log.WithError(err).Error("failed to marshal EDU JSON") return true } - if err := t.queues.SendEDU(edu, t.ServerName, names); err != nil { + if err := t.queues.SendEDU(edu, typingServerName, names); err != nil { log.WithError(err).Error("failed to send EDU") return false } diff --git a/federationapi/federationapi.go b/federationapi/federationapi.go index f6dace702..a58cba1b1 100644 --- a/federationapi/federationapi.go +++ b/federationapi/federationapi.go @@ -69,7 +69,7 @@ func AddPublicRoutes( TopicPresenceEvent: cfg.Matrix.JetStream.Prefixed(jetstream.OutputPresenceEvent), TopicDeviceListUpdate: cfg.Matrix.JetStream.Prefixed(jetstream.InputDeviceListUpdate), TopicSigningKeyUpdate: cfg.Matrix.JetStream.Prefixed(jetstream.InputSigningKeyUpdate), - ServerName: cfg.Matrix.ServerName, + Config: cfg, UserAPI: userAPI, } @@ -107,7 +107,7 @@ func NewInternalAPI( ) api.FederationInternalAPI { cfg := &base.Cfg.FederationAPI - federationDB, err := storage.NewDatabase(base, &cfg.Database, base.Caches, base.Cfg.Global.ServerName) + federationDB, err := storage.NewDatabase(base, &cfg.Database, base.Caches, base.Cfg.Global.IsLocalServerName) if err != nil { logrus.WithError(err).Panic("failed to connect to federation sender db") } diff --git a/federationapi/federationapi_keys_test.go b/federationapi/federationapi_keys_test.go index 85cc43aa5..7ccc02f76 100644 --- a/federationapi/federationapi_keys_test.go +++ b/federationapi/federationapi_keys_test.go @@ -87,6 +87,7 @@ func TestMain(m *testing.M) { cfg.Global.JetStream.StoragePath = config.Path(d) cfg.Global.KeyID = serverKeyID cfg.Global.KeyValidityPeriod = s.validity + cfg.FederationAPI.KeyPerspectives = nil f, err := os.CreateTemp(d, "federation_keys_test*.db") if err != nil { return -1 @@ -207,7 +208,6 @@ func TestRenewalBehaviour(t *testing.T) { // happy at this point that the key that we already have is from the past // then repeating a key fetch should cause us to try and renew the key. // If so, then the new key will end up in our cache. - serverC.renew() res, err = serverA.api.FetchKeys( diff --git a/federationapi/federationapi_test.go b/federationapi/federationapi_test.go index e923143a7..c37bc87c2 100644 --- a/federationapi/federationapi_test.go +++ b/federationapi/federationapi_test.go @@ -164,6 +164,7 @@ func TestFederationAPIJoinThenKeyUpdate(t *testing.T) { func testFederationAPIJoinThenKeyUpdate(t *testing.T, dbType test.DBType) { base, close := testrig.CreateBaseDendrite(t, dbType) base.Cfg.FederationAPI.PreferDirectFetch = true + base.Cfg.FederationAPI.KeyPerspectives = nil defer close() jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream) diff --git a/federationapi/internal/keys.go b/federationapi/internal/keys.go index 2b7a8219a..258bd88bf 100644 --- a/federationapi/internal/keys.go +++ b/federationapi/internal/keys.go @@ -99,7 +99,7 @@ func (s *FederationInternalAPI) handleLocalKeys( results map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, ) { for req := range requests { - if req.ServerName != s.cfg.Matrix.ServerName { + if !s.cfg.Matrix.IsLocalServerName(req.ServerName) { continue } if req.KeyID == s.cfg.Matrix.KeyID { diff --git a/federationapi/internal/perform.go b/federationapi/internal/perform.go index 28ec48d7b..1b61ec711 100644 --- a/federationapi/internal/perform.go +++ b/federationapi/internal/perform.go @@ -77,7 +77,7 @@ func (r *FederationInternalAPI) PerformJoin( seenSet := make(map[gomatrixserverlib.ServerName]bool) var uniqueList []gomatrixserverlib.ServerName for _, srv := range request.ServerNames { - if seenSet[srv] || srv == r.cfg.Matrix.ServerName { + if seenSet[srv] || r.cfg.Matrix.IsLocalServerName(srv) { continue } seenSet[srv] = true diff --git a/federationapi/producers/syncapi.go b/federationapi/producers/syncapi.go index 659ff1bcf..7cce13a7d 100644 --- a/federationapi/producers/syncapi.go +++ b/federationapi/producers/syncapi.go @@ -25,6 +25,7 @@ import ( "github.com/nats-io/nats.go" log "github.com/sirupsen/logrus" + "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" @@ -39,7 +40,7 @@ type SyncAPIProducer struct { TopicDeviceListUpdate string TopicSigningKeyUpdate string JetStream nats.JetStreamContext - ServerName gomatrixserverlib.ServerName + Config *config.FederationAPI UserAPI userapi.UserInternalAPI } @@ -77,7 +78,7 @@ func (p *SyncAPIProducer) SendToDevice( // device. If the event isn't targeted locally then we can't expand the // wildcard as we don't know about the remote devices, so instead we leave it // as-is, so that the federation sender can send it on with the wildcard intact. - if domain == p.ServerName && deviceID == "*" { + if p.Config.Matrix.IsLocalServerName(domain) && deviceID == "*" { var res userapi.QueryDevicesResponse err = p.UserAPI.QueryDevices(context.TODO(), &userapi.QueryDevicesRequest{ UserID: userID, diff --git a/federationapi/queue/queue_test.go b/federationapi/queue/queue_test.go index a1b280103..7ef4646f7 100644 --- a/federationapi/queue/queue_test.go +++ b/federationapi/queue/queue_test.go @@ -47,7 +47,7 @@ func mustCreateFederationDatabase(t *testing.T, dbType test.DBType, realDatabase connStr, dbClose := test.PrepareDBConnectionString(t, dbType) db, err := storage.NewDatabase(b, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), - }, b.Caches, b.Cfg.Global.ServerName) + }, b.Caches, b.Cfg.Global.IsLocalServerName) if err != nil { t.Fatalf("NewDatabase returned %s", err) } diff --git a/federationapi/routing/routing.go b/federationapi/routing/routing.go index e25f9866e..9f16e5093 100644 --- a/federationapi/routing/routing.go +++ b/federationapi/routing/routing.go @@ -124,7 +124,7 @@ func Setup( mu := internal.NewMutexByRoom() v1fedmux.Handle("/send/{txnID}", MakeFedAPI( - "federation_send", cfg.Matrix.ServerName, keys, wakeup, + "federation_send", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return Send( httpReq, request, gomatrixserverlib.TransactionID(vars["txnID"]), @@ -134,7 +134,7 @@ func Setup( )).Methods(http.MethodPut, http.MethodOptions) v1fedmux.Handle("/invite/{roomID}/{eventID}", MakeFedAPI( - "federation_invite", cfg.Matrix.ServerName, keys, wakeup, + "federation_invite", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ @@ -150,7 +150,7 @@ func Setup( )).Methods(http.MethodPut, http.MethodOptions) v2fedmux.Handle("/invite/{roomID}/{eventID}", MakeFedAPI( - "federation_invite", cfg.Matrix.ServerName, keys, wakeup, + "federation_invite", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ @@ -172,7 +172,7 @@ func Setup( )).Methods(http.MethodPost, http.MethodOptions) v1fedmux.Handle("/exchange_third_party_invite/{roomID}", MakeFedAPI( - "exchange_third_party_invite", cfg.Matrix.ServerName, keys, wakeup, + "exchange_third_party_invite", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return ExchangeThirdPartyInvite( httpReq, request, vars["roomID"], rsAPI, cfg, federation, @@ -181,7 +181,7 @@ func Setup( )).Methods(http.MethodPut, http.MethodOptions) v1fedmux.Handle("/event/{eventID}", MakeFedAPI( - "federation_get_event", cfg.Matrix.ServerName, keys, wakeup, + "federation_get_event", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return GetEvent( httpReq.Context(), request, rsAPI, vars["eventID"], cfg.Matrix.ServerName, @@ -190,7 +190,7 @@ func Setup( )).Methods(http.MethodGet) v1fedmux.Handle("/state/{roomID}", MakeFedAPI( - "federation_get_state", cfg.Matrix.ServerName, keys, wakeup, + "federation_get_state", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ @@ -205,7 +205,7 @@ func Setup( )).Methods(http.MethodGet) v1fedmux.Handle("/state_ids/{roomID}", MakeFedAPI( - "federation_get_state_ids", cfg.Matrix.ServerName, keys, wakeup, + "federation_get_state_ids", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ @@ -220,7 +220,7 @@ func Setup( )).Methods(http.MethodGet) v1fedmux.Handle("/event_auth/{roomID}/{eventID}", MakeFedAPI( - "federation_get_event_auth", cfg.Matrix.ServerName, keys, wakeup, + "federation_get_event_auth", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ @@ -235,7 +235,7 @@ func Setup( )).Methods(http.MethodGet) v1fedmux.Handle("/query/directory", MakeFedAPI( - "federation_query_room_alias", cfg.Matrix.ServerName, keys, wakeup, + "federation_query_room_alias", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return RoomAliasToID( httpReq, federation, cfg, rsAPI, fsAPI, @@ -244,7 +244,7 @@ func Setup( )).Methods(http.MethodGet) v1fedmux.Handle("/query/profile", MakeFedAPI( - "federation_query_profile", cfg.Matrix.ServerName, keys, wakeup, + "federation_query_profile", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return GetProfile( httpReq, userAPI, cfg, @@ -253,7 +253,7 @@ func Setup( )).Methods(http.MethodGet) v1fedmux.Handle("/user/devices/{userID}", MakeFedAPI( - "federation_user_devices", cfg.Matrix.ServerName, keys, wakeup, + "federation_user_devices", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return GetUserDevices( httpReq, keyAPI, vars["userID"], @@ -263,7 +263,7 @@ func Setup( if mscCfg.Enabled("msc2444") { v1fedmux.Handle("/peek/{roomID}/{peekID}", MakeFedAPI( - "federation_peek", cfg.Matrix.ServerName, keys, wakeup, + "federation_peek", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ @@ -294,7 +294,7 @@ func Setup( } v1fedmux.Handle("/make_join/{roomID}/{userID}", MakeFedAPI( - "federation_make_join", cfg.Matrix.ServerName, keys, wakeup, + "federation_make_join", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ @@ -325,7 +325,7 @@ func Setup( )).Methods(http.MethodGet) v1fedmux.Handle("/send_join/{roomID}/{eventID}", MakeFedAPI( - "federation_send_join", cfg.Matrix.ServerName, keys, wakeup, + "federation_send_join", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ @@ -357,7 +357,7 @@ func Setup( )).Methods(http.MethodPut) v2fedmux.Handle("/send_join/{roomID}/{eventID}", MakeFedAPI( - "federation_send_join", cfg.Matrix.ServerName, keys, wakeup, + "federation_send_join", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ @@ -374,7 +374,7 @@ func Setup( )).Methods(http.MethodPut) v1fedmux.Handle("/make_leave/{roomID}/{eventID}", MakeFedAPI( - "federation_make_leave", cfg.Matrix.ServerName, keys, wakeup, + "federation_make_leave", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ @@ -391,7 +391,7 @@ func Setup( )).Methods(http.MethodGet) v1fedmux.Handle("/send_leave/{roomID}/{eventID}", MakeFedAPI( - "federation_send_leave", cfg.Matrix.ServerName, keys, wakeup, + "federation_send_leave", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ @@ -423,7 +423,7 @@ func Setup( )).Methods(http.MethodPut) v2fedmux.Handle("/send_leave/{roomID}/{eventID}", MakeFedAPI( - "federation_send_leave", cfg.Matrix.ServerName, keys, wakeup, + "federation_send_leave", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ @@ -447,7 +447,7 @@ func Setup( )).Methods(http.MethodGet) v1fedmux.Handle("/get_missing_events/{roomID}", MakeFedAPI( - "federation_get_missing_events", cfg.Matrix.ServerName, keys, wakeup, + "federation_get_missing_events", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ @@ -460,7 +460,7 @@ func Setup( )).Methods(http.MethodPost) v1fedmux.Handle("/backfill/{roomID}", MakeFedAPI( - "federation_backfill", cfg.Matrix.ServerName, keys, wakeup, + "federation_backfill", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ @@ -479,14 +479,14 @@ func Setup( ).Methods(http.MethodGet, http.MethodPost) v1fedmux.Handle("/user/keys/claim", MakeFedAPI( - "federation_keys_claim", cfg.Matrix.ServerName, keys, wakeup, + "federation_keys_claim", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return ClaimOneTimeKeys(httpReq, request, keyAPI, cfg.Matrix.ServerName) }, )).Methods(http.MethodPost) v1fedmux.Handle("/user/keys/query", MakeFedAPI( - "federation_keys_query", cfg.Matrix.ServerName, keys, wakeup, + "federation_keys_query", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return QueryDeviceKeys(httpReq, request, keyAPI, cfg.Matrix.ServerName) }, @@ -525,15 +525,15 @@ func ErrorIfLocalServerNotInRoom( // MakeFedAPI makes an http.Handler that checks matrix federation authentication. func MakeFedAPI( - metricsName string, - serverName gomatrixserverlib.ServerName, + metricsName string, serverName gomatrixserverlib.ServerName, + isLocalServerName func(gomatrixserverlib.ServerName) bool, keyRing gomatrixserverlib.JSONVerifier, wakeup *FederationWakeups, f func(*http.Request, *gomatrixserverlib.FederationRequest, map[string]string) util.JSONResponse, ) http.Handler { h := func(req *http.Request) util.JSONResponse { fedReq, errResp := gomatrixserverlib.VerifyHTTPRequest( - req, time.Now(), serverName, keyRing, + req, time.Now(), serverName, isLocalServerName, keyRing, ) if fedReq == nil { return errResp diff --git a/federationapi/storage/postgres/storage.go b/federationapi/storage/postgres/storage.go index 6e208d096..a33fa4a43 100644 --- a/federationapi/storage/postgres/storage.go +++ b/federationapi/storage/postgres/storage.go @@ -36,7 +36,7 @@ type Database struct { } // NewDatabase opens a new database -func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.FederationCache, serverName gomatrixserverlib.ServerName) (*Database, error) { +func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(gomatrixserverlib.ServerName) bool) (*Database, error) { var d Database var err error if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()); err != nil { @@ -96,7 +96,7 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, } d.Database = shared.Database{ DB: d.db, - ServerName: serverName, + IsLocalServerName: isLocalServerName, Cache: cache, Writer: d.writer, FederationJoinedHosts: joinedHosts, diff --git a/federationapi/storage/shared/storage.go b/federationapi/storage/shared/storage.go index 6afb313a8..4fabff7d4 100644 --- a/federationapi/storage/shared/storage.go +++ b/federationapi/storage/shared/storage.go @@ -29,7 +29,7 @@ import ( type Database struct { DB *sql.DB - ServerName gomatrixserverlib.ServerName + IsLocalServerName func(gomatrixserverlib.ServerName) bool Cache caching.FederationCache Writer sqlutil.Writer FederationQueuePDUs tables.FederationQueuePDUs @@ -124,7 +124,7 @@ func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, } if excludeSelf { for i, server := range servers { - if server == d.ServerName { + if d.IsLocalServerName(server) { servers = append(servers[:i], servers[i+1:]...) } } diff --git a/federationapi/storage/sqlite3/storage.go b/federationapi/storage/sqlite3/storage.go index c89cb6bea..e86ac817b 100644 --- a/federationapi/storage/sqlite3/storage.go +++ b/federationapi/storage/sqlite3/storage.go @@ -35,7 +35,7 @@ type Database struct { } // NewDatabase opens a new database -func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.FederationCache, serverName gomatrixserverlib.ServerName) (*Database, error) { +func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(gomatrixserverlib.ServerName) bool) (*Database, error) { var d Database var err error if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()); err != nil { @@ -95,7 +95,7 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, } d.Database = shared.Database{ DB: d.db, - ServerName: serverName, + IsLocalServerName: isLocalServerName, Cache: cache, Writer: d.writer, FederationJoinedHosts: joinedHosts, diff --git a/federationapi/storage/storage.go b/federationapi/storage/storage.go index f246b9bc9..142e281ea 100644 --- a/federationapi/storage/storage.go +++ b/federationapi/storage/storage.go @@ -29,12 +29,12 @@ import ( ) // NewDatabase opens a new database -func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.FederationCache, serverName gomatrixserverlib.ServerName) (Database, error) { +func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(gomatrixserverlib.ServerName) bool) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(base, dbProperties, cache, serverName) + return sqlite3.NewDatabase(base, dbProperties, cache, isLocalServerName) case dbProperties.ConnectionString.IsPostgres(): - return postgres.NewDatabase(base, dbProperties, cache, serverName) + return postgres.NewDatabase(base, dbProperties, cache, isLocalServerName) default: return nil, fmt.Errorf("unexpected database type") } diff --git a/federationapi/storage/storage_test.go b/federationapi/storage/storage_test.go index 6272fd2b1..f7408fa9f 100644 --- a/federationapi/storage/storage_test.go +++ b/federationapi/storage/storage_test.go @@ -19,7 +19,7 @@ func mustCreateFederationDatabase(t *testing.T, dbType test.DBType) (storage.Dat connStr, dbClose := test.PrepareDBConnectionString(t, dbType) db, err := storage.NewDatabase(b, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), - }, b.Caches, b.Cfg.Global.ServerName) + }, b.Caches, func(server gomatrixserverlib.ServerName) bool { return server == "localhost" }) if err != nil { t.Fatalf("NewDatabase returned %s", err) } diff --git a/go.mod b/go.mod index bf001fdb9..be5099fcf 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20221021091412-7c772f1b388a + github.com/matrix-org/gomatrixserverlib v0.0.0-20221025142407-17b0be811afa github.com/matrix-org/pinecone v0.0.0-20221026160848-639feeff74d6 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.15 diff --git a/go.sum b/go.sum index 809562f2e..c7903b0ce 100644 --- a/go.sum +++ b/go.sum @@ -387,8 +387,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20221021091412-7c772f1b388a h1:6rJFN5NBuzZ7h5meYkLtXKa6VFZfDc8oVXHd4SDXr5o= -github.com/matrix-org/gomatrixserverlib v0.0.0-20221021091412-7c772f1b388a/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= +github.com/matrix-org/gomatrixserverlib v0.0.0-20221025142407-17b0be811afa h1:S98DShDv3sn7O4n4HjtJOejypseYVpv1R/XPg+cDnfI= +github.com/matrix-org/gomatrixserverlib v0.0.0-20221025142407-17b0be811afa/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= github.com/matrix-org/pinecone v0.0.0-20221026160848-639feeff74d6 h1:nAT5w41Q9uWTSnpKW55/hBwP91j2IFYPDRs0jJ8TyFI= github.com/matrix-org/pinecone v0.0.0-20221026160848-639feeff74d6/go.mod h1:K0N1ixHQxXoCyqolDqVxPM3ArrDtcMs8yegOx2Lfv9k= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index 49ef03054..ff0968b27 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -257,9 +257,6 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques res.UserSigningKeys = make(map[string]gomatrixserverlib.CrossSigningKey) res.Failures = make(map[string]interface{}) - // get cross-signing keys from the database - a.crossSigningKeysFromDatabase(ctx, req, res) - // make a map from domain to device keys domainToDeviceKeys := make(map[string]map[string][]string) domainToCrossSigningKeys := make(map[string]map[string]struct{}) @@ -336,6 +333,10 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques a.queryRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys, domainToCrossSigningKeys) } + // Now that we've done the potentially expensive work of asking the federation, + // try filling the cross-signing keys from the database that we know about. + a.crossSigningKeysFromDatabase(ctx, req, res) + // Finally, append signatures that we know about // TODO: This is horrible because we need to round-trip the signature from // JSON, add the signatures and marshal it again, for some reason? diff --git a/keyserver/storage/postgres/cross_signing_sigs_table.go b/keyserver/storage/postgres/cross_signing_sigs_table.go index 8b2a865b9..4536b7d80 100644 --- a/keyserver/storage/postgres/cross_signing_sigs_table.go +++ b/keyserver/storage/postgres/cross_signing_sigs_table.go @@ -42,7 +42,7 @@ CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_s const selectCrossSigningSigsForTargetSQL = "" + "SELECT origin_user_id, origin_key_id, signature FROM keyserver_cross_signing_sigs" + - " WHERE (origin_user_id = $1 OR origin_user_id = target_user_id) AND target_user_id = $2 AND target_key_id = $3" + " WHERE (origin_user_id = $1 OR origin_user_id = $2) AND target_user_id = $2 AND target_key_id = $3" const upsertCrossSigningSigsForTargetSQL = "" + "INSERT INTO keyserver_cross_signing_sigs (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)" + diff --git a/keyserver/storage/sqlite3/cross_signing_sigs_table.go b/keyserver/storage/sqlite3/cross_signing_sigs_table.go index ea431151e..7a153e8fb 100644 --- a/keyserver/storage/sqlite3/cross_signing_sigs_table.go +++ b/keyserver/storage/sqlite3/cross_signing_sigs_table.go @@ -42,7 +42,7 @@ CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_s const selectCrossSigningSigsForTargetSQL = "" + "SELECT origin_user_id, origin_key_id, signature FROM keyserver_cross_signing_sigs" + - " WHERE (origin_user_id = $1 OR origin_user_id = target_user_id) AND target_user_id = $2 AND target_key_id = $3" + " WHERE (origin_user_id = $1 OR origin_user_id = $2) AND target_user_id = $3 AND target_key_id = $4" const upsertCrossSigningSigsForTargetSQL = "" + "INSERT OR REPLACE INTO keyserver_cross_signing_sigs (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)" + @@ -85,7 +85,7 @@ func NewSqliteCrossSigningSigsTable(db *sql.DB) (tables.CrossSigningSigs, error) func (s *crossSigningSigsStatements) SelectCrossSigningSigsForTarget( ctx context.Context, txn *sql.Tx, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, ) (r types.CrossSigningSigMap, err error) { - rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningSigsForTargetStmt).QueryContext(ctx, originUserID, targetUserID, targetKeyID) + rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningSigsForTargetStmt).QueryContext(ctx, originUserID, targetUserID, targetUserID, targetKeyID) if err != nil { return nil, err } diff --git a/roomserver/api/api.go b/roomserver/api/api.go index baf63aa31..403bbe8be 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -167,6 +167,7 @@ type UserRoomserverAPI interface { QueryCurrentState(ctx context.Context, req *QueryCurrentStateRequest, res *QueryCurrentStateResponse) error QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error PerformAdminEvacuateUser(ctx context.Context, req *PerformAdminEvacuateUserRequest, res *PerformAdminEvacuateUserResponse) error + PerformJoin(ctx context.Context, req *PerformJoinRequest, res *PerformJoinResponse) error } type FederationRoomserverAPI interface { diff --git a/roomserver/internal/perform/perform_admin.go b/roomserver/internal/perform/perform_admin.go index cb6b22d32..6a6d51b0a 100644 --- a/roomserver/internal/perform/perform_admin.go +++ b/roomserver/internal/perform/perform_admin.go @@ -117,6 +117,11 @@ func (r *Admin) PerformAdminEvacuateRoom( PrevEvents: prevEvents, } + _, senderDomain, err := gomatrixserverlib.SplitID('@', fledglingEvent.Sender) + if err != nil { + continue + } + if fledglingEvent.Content, err = json.Marshal(memberContent); err != nil { res.Error = &api.PerformError{ Code: api.PerformErrorBadRequest, @@ -146,8 +151,8 @@ func (r *Admin) PerformAdminEvacuateRoom( inputEvents = append(inputEvents, api.InputRoomEvent{ Kind: api.KindNew, Event: event, - Origin: r.Cfg.Matrix.ServerName, - SendAsServer: string(r.Cfg.Matrix.ServerName), + Origin: senderDomain, + SendAsServer: string(senderDomain), }) res.Affected = append(res.Affected, stateKey) prevEvents = []gomatrixserverlib.EventReference{ @@ -176,7 +181,7 @@ func (r *Admin) PerformAdminEvacuateUser( } return nil } - if domain != r.Cfg.Matrix.ServerName { + if !r.Cfg.Matrix.IsLocalServerName(domain) { res.Error = &api.PerformError{ Code: api.PerformErrorBadRequest, Msg: "Can only evacuate local users using this endpoint", diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go index 3fbdf332e..f60247cd7 100644 --- a/roomserver/internal/perform/perform_invite.go +++ b/roomserver/internal/perform/perform_invite.go @@ -70,8 +70,8 @@ func (r *Inviter) PerformInvite( } return nil, nil } - isTargetLocal := domain == r.Cfg.Matrix.ServerName - isOriginLocal := senderDomain == r.Cfg.Matrix.ServerName + isTargetLocal := r.Cfg.Matrix.IsLocalServerName(domain) + isOriginLocal := r.Cfg.Matrix.IsLocalServerName(senderDomain) if !isOriginLocal && !isTargetLocal { res.Error = &api.PerformError{ Code: api.PerformErrorBadRequest, diff --git a/roomserver/internal/perform/perform_join.go b/roomserver/internal/perform/perform_join.go index 262273ff5..9d596ab30 100644 --- a/roomserver/internal/perform/perform_join.go +++ b/roomserver/internal/perform/perform_join.go @@ -92,7 +92,7 @@ func (r *Joiner) performJoin( Msg: fmt.Sprintf("Supplied user ID %q in incorrect format", req.UserID), } } - if domain != r.Cfg.Matrix.ServerName { + if !r.Cfg.Matrix.IsLocalServerName(domain) { return "", "", &rsAPI.PerformError{ Code: rsAPI.PerformErrorBadRequest, Msg: fmt.Sprintf("User %q does not belong to this homeserver", req.UserID), @@ -124,7 +124,7 @@ func (r *Joiner) performJoinRoomByAlias( // Check if this alias matches our own server configuration. If it // doesn't then we'll need to try a federated join. var roomID string - if domain != r.Cfg.Matrix.ServerName { + if !r.Cfg.Matrix.IsLocalServerName(domain) { // The alias isn't owned by us, so we will need to try joining using // a remote server. dirReq := fsAPI.PerformDirectoryLookupRequest{ @@ -172,7 +172,7 @@ func (r *Joiner) performJoinRoomByID( // The original client request ?server_name=... may include this HS so filter that out so we // don't attempt to make_join with ourselves for i := 0; i < len(req.ServerNames); i++ { - if req.ServerNames[i] == r.Cfg.Matrix.ServerName { + if r.Cfg.Matrix.IsLocalServerName(req.ServerNames[i]) { // delete this entry req.ServerNames = append(req.ServerNames[:i], req.ServerNames[i+1:]...) i-- @@ -191,12 +191,19 @@ func (r *Joiner) performJoinRoomByID( // If the server name in the room ID isn't ours then it's a // possible candidate for finding the room via federation. Add // it to the list of servers to try. - if domain != r.Cfg.Matrix.ServerName { + if !r.Cfg.Matrix.IsLocalServerName(domain) { req.ServerNames = append(req.ServerNames, domain) } // Prepare the template for the join event. userID := req.UserID + _, userDomain, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + return "", "", &rsAPI.PerformError{ + Code: rsAPI.PerformErrorBadRequest, + Msg: fmt.Sprintf("User ID %q is invalid: %s", userID, err), + } + } eb := gomatrixserverlib.EventBuilder{ Type: gomatrixserverlib.MRoomMember, Sender: userID, @@ -247,7 +254,7 @@ func (r *Joiner) performJoinRoomByID( // If we were invited by someone from another server then we can // assume they are in the room so we can join via them. - if inviterDomain != r.Cfg.Matrix.ServerName { + if !r.Cfg.Matrix.IsLocalServerName(inviterDomain) { req.ServerNames = append(req.ServerNames, inviterDomain) forceFederatedJoin = true memberEvent := gjson.Parse(string(inviteEvent.JSON())) @@ -300,7 +307,7 @@ func (r *Joiner) performJoinRoomByID( { Kind: rsAPI.KindNew, Event: event.Headered(buildRes.RoomVersion), - SendAsServer: string(r.Cfg.Matrix.ServerName), + SendAsServer: string(userDomain), }, }, } @@ -323,7 +330,7 @@ func (r *Joiner) performJoinRoomByID( // The room doesn't exist locally. If the room ID looks like it should // be ours then this probably means that we've nuked our database at // some point. - if domain == r.Cfg.Matrix.ServerName { + if r.Cfg.Matrix.IsLocalServerName(domain) { // If there are no more server names to try then give up here. // Otherwise we'll try a federated join as normal, since it's quite // possible that the room still exists on other servers. @@ -348,7 +355,7 @@ func (r *Joiner) performJoinRoomByID( // it will have been overwritten with a room ID by performJoinRoomByAlias. // We should now include this in the response so that the CS API can // return the right room ID. - return req.RoomIDOrAlias, r.Cfg.Matrix.ServerName, nil + return req.RoomIDOrAlias, userDomain, nil } func (r *Joiner) performFederatedJoinRoomByID( diff --git a/roomserver/internal/perform/perform_leave.go b/roomserver/internal/perform/perform_leave.go index 85b659814..49e4b479a 100644 --- a/roomserver/internal/perform/perform_leave.go +++ b/roomserver/internal/perform/perform_leave.go @@ -52,7 +52,7 @@ func (r *Leaver) PerformLeave( if err != nil { return nil, fmt.Errorf("supplied user ID %q in incorrect format", req.UserID) } - if domain != r.Cfg.Matrix.ServerName { + if !r.Cfg.Matrix.IsLocalServerName(domain) { return nil, fmt.Errorf("user %q does not belong to this homeserver", req.UserID) } logger := logrus.WithContext(ctx).WithFields(logrus.Fields{ @@ -85,7 +85,7 @@ func (r *Leaver) performLeaveRoomByID( if serr != nil { return nil, fmt.Errorf("sender %q is invalid", senderUser) } - if senderDomain != r.Cfg.Matrix.ServerName { + if !r.Cfg.Matrix.IsLocalServerName(senderDomain) { return r.performFederatedRejectInvite(ctx, req, res, senderUser, eventID) } // check that this is not a "server notice room" @@ -186,7 +186,7 @@ func (r *Leaver) performLeaveRoomByID( Kind: api.KindNew, Event: event.Headered(buildRes.RoomVersion), Origin: senderDomain, - SendAsServer: string(r.Cfg.Matrix.ServerName), + SendAsServer: string(senderDomain), }, }, } diff --git a/roomserver/internal/perform/perform_peek.go b/roomserver/internal/perform/perform_peek.go index 74d87a5b4..436d137ff 100644 --- a/roomserver/internal/perform/perform_peek.go +++ b/roomserver/internal/perform/perform_peek.go @@ -72,7 +72,7 @@ func (r *Peeker) performPeek( Msg: fmt.Sprintf("Supplied user ID %q in incorrect format", req.UserID), } } - if domain != r.Cfg.Matrix.ServerName { + if !r.Cfg.Matrix.IsLocalServerName(domain) { return "", &api.PerformError{ Code: api.PerformErrorBadRequest, Msg: fmt.Sprintf("User %q does not belong to this homeserver", req.UserID), @@ -104,7 +104,7 @@ func (r *Peeker) performPeekRoomByAlias( // Check if this alias matches our own server configuration. If it // doesn't then we'll need to try a federated peek. var roomID string - if domain != r.Cfg.Matrix.ServerName { + if !r.Cfg.Matrix.IsLocalServerName(domain) { // The alias isn't owned by us, so we will need to try peeking using // a remote server. dirReq := fsAPI.PerformDirectoryLookupRequest{ @@ -154,7 +154,7 @@ func (r *Peeker) performPeekRoomByID( // handle federated peeks // FIXME: don't create an outbound peek if we already have one going. - if domain != r.Cfg.Matrix.ServerName { + if !r.Cfg.Matrix.IsLocalServerName(domain) { // If the server name in the room ID isn't ours then it's a // possible candidate for finding the room via federation. Add // it to the list of servers to try. diff --git a/roomserver/internal/perform/perform_unpeek.go b/roomserver/internal/perform/perform_unpeek.go index 49e9067c9..0d97da4d6 100644 --- a/roomserver/internal/perform/perform_unpeek.go +++ b/roomserver/internal/perform/perform_unpeek.go @@ -67,7 +67,7 @@ func (r *Unpeeker) performUnpeek( Msg: fmt.Sprintf("Supplied user ID %q in incorrect format", req.UserID), } } - if domain != r.Cfg.Matrix.ServerName { + if !r.Cfg.Matrix.IsLocalServerName(domain) { return &api.PerformError{ Code: api.PerformErrorBadRequest, Msg: fmt.Sprintf("User %q does not belong to this homeserver", req.UserID), diff --git a/roomserver/internal/perform/perform_upgrade.go b/roomserver/internal/perform/perform_upgrade.go index d6dc9708c..38abe323c 100644 --- a/roomserver/internal/perform/perform_upgrade.go +++ b/roomserver/internal/perform/perform_upgrade.go @@ -60,6 +60,13 @@ func (r *Upgrader) performRoomUpgrade( ) (string, *api.PerformError) { roomID := req.RoomID userID := req.UserID + _, userDomain, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + return "", &api.PerformError{ + Code: api.PerformErrorNotAllowed, + Msg: "Error validating the user ID", + } + } evTime := time.Now() // Return an immediate error if the room does not exist @@ -80,7 +87,7 @@ func (r *Upgrader) performRoomUpgrade( // TODO (#267): Check room ID doesn't clash with an existing one, and we // probably shouldn't be using pseudo-random strings, maybe GUIDs? - newRoomID := fmt.Sprintf("!%s:%s", util.RandomString(16), r.Cfg.Matrix.ServerName) + newRoomID := fmt.Sprintf("!%s:%s", util.RandomString(16), userDomain) // Get the existing room state for the old room. oldRoomReq := &api.QueryLatestEventsAndStateRequest{ @@ -107,12 +114,12 @@ func (r *Upgrader) performRoomUpgrade( } // Send the setup events to the new room - if pErr = r.sendInitialEvents(ctx, evTime, userID, newRoomID, string(req.RoomVersion), eventsToMake); pErr != nil { + if pErr = r.sendInitialEvents(ctx, evTime, userID, userDomain, newRoomID, string(req.RoomVersion), eventsToMake); pErr != nil { return "", pErr } // 5. Send the tombstone event to the old room - if pErr = r.sendHeaderedEvent(ctx, tombstoneEvent, string(r.Cfg.Matrix.ServerName)); pErr != nil { + if pErr = r.sendHeaderedEvent(ctx, userDomain, tombstoneEvent, string(userDomain)); pErr != nil { return "", pErr } @@ -122,7 +129,7 @@ func (r *Upgrader) performRoomUpgrade( } // If the old room had a canonical alias event, it should be deleted in the old room - if pErr = r.clearOldCanonicalAliasEvent(ctx, oldRoomRes, evTime, userID, roomID); pErr != nil { + if pErr = r.clearOldCanonicalAliasEvent(ctx, oldRoomRes, evTime, userID, userDomain, roomID); pErr != nil { return "", pErr } @@ -132,7 +139,7 @@ func (r *Upgrader) performRoomUpgrade( } // 6. Restrict power levels in the old room - if pErr = r.restrictOldRoomPowerLevels(ctx, evTime, userID, roomID); pErr != nil { + if pErr = r.restrictOldRoomPowerLevels(ctx, evTime, userID, userDomain, roomID); pErr != nil { return "", pErr } @@ -154,7 +161,7 @@ func (r *Upgrader) getRoomPowerLevels(ctx context.Context, roomID string) (*goma return powerLevelContent, nil } -func (r *Upgrader) restrictOldRoomPowerLevels(ctx context.Context, evTime time.Time, userID, roomID string) *api.PerformError { +func (r *Upgrader) restrictOldRoomPowerLevels(ctx context.Context, evTime time.Time, userID string, userDomain gomatrixserverlib.ServerName, roomID string) *api.PerformError { restrictedPowerLevelContent, pErr := r.getRoomPowerLevels(ctx, roomID) if pErr != nil { return pErr @@ -183,7 +190,7 @@ func (r *Upgrader) restrictOldRoomPowerLevels(ctx context.Context, evTime time.T return resErr } } else { - if resErr = r.sendHeaderedEvent(ctx, restrictedPowerLevelsHeadered, api.DoNotSendToOtherServers); resErr != nil { + if resErr = r.sendHeaderedEvent(ctx, userDomain, restrictedPowerLevelsHeadered, api.DoNotSendToOtherServers); resErr != nil { return resErr } } @@ -223,7 +230,7 @@ func moveLocalAliases(ctx context.Context, return nil } -func (r *Upgrader) clearOldCanonicalAliasEvent(ctx context.Context, oldRoom *api.QueryLatestEventsAndStateResponse, evTime time.Time, userID, roomID string) *api.PerformError { +func (r *Upgrader) clearOldCanonicalAliasEvent(ctx context.Context, oldRoom *api.QueryLatestEventsAndStateResponse, evTime time.Time, userID string, userDomain gomatrixserverlib.ServerName, roomID string) *api.PerformError { for _, event := range oldRoom.StateEvents { if event.Type() != gomatrixserverlib.MRoomCanonicalAlias || !event.StateKeyEquals("") { continue @@ -254,7 +261,7 @@ func (r *Upgrader) clearOldCanonicalAliasEvent(ctx context.Context, oldRoom *api return resErr } } else { - if resErr = r.sendHeaderedEvent(ctx, emptyCanonicalAliasEvent, api.DoNotSendToOtherServers); resErr != nil { + if resErr = r.sendHeaderedEvent(ctx, userDomain, emptyCanonicalAliasEvent, api.DoNotSendToOtherServers); resErr != nil { return resErr } } @@ -495,7 +502,7 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query return eventsToMake, nil } -func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, userID, newRoomID, newVersion string, eventsToMake []fledglingEvent) *api.PerformError { +func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, userID string, userDomain gomatrixserverlib.ServerName, newRoomID, newVersion string, eventsToMake []fledglingEvent) *api.PerformError { var err error var builtEvents []*gomatrixserverlib.HeaderedEvent authEvents := gomatrixserverlib.NewAuthEvents(nil) @@ -519,7 +526,7 @@ func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, user builder.PrevEvents = []gomatrixserverlib.EventReference{builtEvents[i-1].EventReference()} } var event *gomatrixserverlib.Event - event, err = r.buildEvent(&builder, &authEvents, evTime, gomatrixserverlib.RoomVersion(newVersion)) + event, err = r.buildEvent(&builder, userDomain, &authEvents, evTime, gomatrixserverlib.RoomVersion(newVersion)) if err != nil { return &api.PerformError{ Msg: fmt.Sprintf("Failed to build new %q event: %s", builder.Type, err), @@ -547,7 +554,7 @@ func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, user inputs = append(inputs, api.InputRoomEvent{ Kind: api.KindNew, Event: event, - Origin: r.Cfg.Matrix.ServerName, + Origin: userDomain, SendAsServer: api.DoNotSendToOtherServers, }) } @@ -668,6 +675,7 @@ func createTemporaryPowerLevels(powerLevelContent *gomatrixserverlib.PowerLevelC func (r *Upgrader) sendHeaderedEvent( ctx context.Context, + serverName gomatrixserverlib.ServerName, headeredEvent *gomatrixserverlib.HeaderedEvent, sendAsServer string, ) *api.PerformError { @@ -675,7 +683,7 @@ func (r *Upgrader) sendHeaderedEvent( inputs = append(inputs, api.InputRoomEvent{ Kind: api.KindNew, Event: headeredEvent, - Origin: r.Cfg.Matrix.ServerName, + Origin: serverName, SendAsServer: sendAsServer, }) if err := api.SendInputRoomEvents(ctx, r.URSAPI, inputs, false); err != nil { @@ -689,6 +697,7 @@ func (r *Upgrader) sendHeaderedEvent( func (r *Upgrader) buildEvent( builder *gomatrixserverlib.EventBuilder, + serverName gomatrixserverlib.ServerName, provider gomatrixserverlib.AuthEventProvider, evTime time.Time, roomVersion gomatrixserverlib.RoomVersion, @@ -703,7 +712,7 @@ func (r *Upgrader) buildEvent( } builder.AuthEvents = refs event, err := builder.Build( - evTime, r.Cfg.Matrix.ServerName, r.Cfg.Matrix.KeyID, + evTime, serverName, r.Cfg.Matrix.KeyID, r.Cfg.Matrix.PrivateKey, roomVersion, ) if err != nil { diff --git a/setup/config/config_global.go b/setup/config/config_global.go index 2efae0d5a..825772827 100644 --- a/setup/config/config_global.go +++ b/setup/config/config_global.go @@ -14,6 +14,9 @@ type Global struct { // The name of the server. This is usually the domain name, e.g 'matrix.org', 'localhost'. ServerName gomatrixserverlib.ServerName `yaml:"server_name"` + // The secondary server names, used for virtual hosting. + SecondaryServerNames []gomatrixserverlib.ServerName `yaml:"-"` + // Path to the private key which will be used to sign requests and events. PrivateKeyPath Path `yaml:"private_key"` @@ -120,6 +123,18 @@ func (c *Global) Verify(configErrs *ConfigErrors, isMonolith bool) { c.Cache.Verify(configErrs, isMonolith) } +func (c *Global) IsLocalServerName(serverName gomatrixserverlib.ServerName) bool { + if c.ServerName == serverName { + return true + } + for _, secondaryName := range c.SecondaryServerNames { + if secondaryName == serverName { + return true + } + } + return false +} + type OldVerifyKeys struct { // Path to the private key. PrivateKeyPath Path `yaml:"private_key"` @@ -170,7 +185,7 @@ type ServerNotices struct { // The displayname to be used when sending notices DisplayName string `yaml:"display_name"` // The avatar of this user - AvatarURL string `yaml:"avatar"` + AvatarURL string `yaml:"avatar_url"` // The roomname to be used when creating messages RoomName string `yaml:"room_name"` } diff --git a/setup/config/config_userapi.go b/setup/config/config_userapi.go index 97a6d738b..f8ad41d93 100644 --- a/setup/config/config_userapi.go +++ b/setup/config/config_userapi.go @@ -19,6 +19,10 @@ type UserAPI struct { // The Account database stores the login details and account information // for local users. It is accessed by the UserAPI. AccountDatabase DatabaseOptions `yaml:"account_database,omitempty"` + + // Users who register on this homeserver will automatically + // be joined to the rooms listed under this option. + AutoJoinRooms []string `yaml:"auto_join_rooms"` } const DefaultOpenIDTokenLifetimeMS = 3600000 // 60 minutes diff --git a/setup/mscs/msc2836/msc2836.go b/setup/mscs/msc2836/msc2836.go index 452b14580..98502f5cb 100644 --- a/setup/mscs/msc2836/msc2836.go +++ b/setup/mscs/msc2836/msc2836.go @@ -132,7 +132,7 @@ func Enable( base.PublicFederationAPIMux.Handle("/unstable/event_relationships", httputil.MakeExternalAPI( "msc2836_event_relationships", func(req *http.Request) util.JSONResponse { fedReq, errResp := gomatrixserverlib.VerifyHTTPRequest( - req, time.Now(), base.Cfg.Global.ServerName, keyRing, + req, time.Now(), base.Cfg.Global.ServerName, base.Cfg.Global.IsLocalServerName, keyRing, ) if fedReq == nil { return errResp diff --git a/setup/mscs/msc2946/msc2946.go b/setup/mscs/msc2946/msc2946.go index a92a16a27..bc9df0f96 100644 --- a/setup/mscs/msc2946/msc2946.go +++ b/setup/mscs/msc2946/msc2946.go @@ -64,7 +64,7 @@ func Enable( fedAPI := httputil.MakeExternalAPI( "msc2946_fed_spaces", func(req *http.Request) util.JSONResponse { fedReq, errResp := gomatrixserverlib.VerifyHTTPRequest( - req, time.Now(), base.Cfg.Global.ServerName, keyRing, + req, time.Now(), base.Cfg.Global.ServerName, base.Cfg.Global.IsLocalServerName, keyRing, ) if fedReq == nil { return errResp diff --git a/clientapi/routing/memberships.go b/syncapi/routing/memberships.go similarity index 52% rename from clientapi/routing/memberships.go rename to syncapi/routing/memberships.go index 9bdd8a4f4..c9acc5d2b 100644 --- a/clientapi/routing/memberships.go +++ b/syncapi/routing/memberships.go @@ -18,22 +18,20 @@ import ( "encoding/json" "net/http" - "github.com/matrix-org/dendrite/clientapi/jsonerror" - "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/dendrite/setup/config" - userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" + + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/syncapi/storage" + "github.com/matrix-org/dendrite/syncapi/types" + userapi "github.com/matrix-org/dendrite/userapi/api" ) type getMembershipResponse struct { Chunk []gomatrixserverlib.ClientEvent `json:"chunk"` } -type getJoinedRoomsResponse struct { - JoinedRooms []string `json:"joined_rooms"` -} - // https://matrix.org/docs/spec/client_server/r0.6.0#get-matrix-client-r0-rooms-roomid-joined-members type getJoinedMembersResponse struct { Joined map[string]joinedMember `json:"joined"` @@ -51,19 +49,22 @@ type databaseJoinedMember struct { AvatarURL string `json:"avatar_url"` } -// GetMemberships implements GET /rooms/{roomId}/members +// GetMemberships implements +// +// GET /rooms/{roomId}/members +// GET /rooms/{roomId}/joined_members func GetMemberships( - req *http.Request, device *userapi.Device, roomID string, joinedOnly bool, - _ *config.ClientAPI, - rsAPI api.ClientRoomserverAPI, + req *http.Request, device *userapi.Device, roomID string, + syncDB storage.Database, rsAPI api.SyncRoomserverAPI, + joinedOnly bool, membership, notMembership *string, at string, ) util.JSONResponse { - queryReq := api.QueryMembershipsForRoomRequest{ - JoinedOnly: joinedOnly, - RoomID: roomID, - Sender: device.UserID, + queryReq := api.QueryMembershipForUserRequest{ + RoomID: roomID, + UserID: device.UserID, } - var queryRes api.QueryMembershipsForRoomResponse - if err := rsAPI.QueryMembershipsForRoom(req.Context(), &queryReq, &queryRes); err != nil { + + var queryRes api.QueryMembershipForUserResponse + if err := rsAPI.QueryMembershipForUser(req.Context(), &queryReq, &queryRes); err != nil { util.GetLogger(req.Context()).WithError(err).Error("rsAPI.QueryMembershipsForRoom failed") return jsonerror.InternalServerError() } @@ -75,16 +76,54 @@ func GetMemberships( } } + db, err := syncDB.NewDatabaseSnapshot(req.Context()) + if err != nil { + return jsonerror.InternalServerError() + } + + atToken, err := types.NewTopologyTokenFromString(at) + if err != nil { + if queryRes.HasBeenInRoom && !queryRes.IsInRoom { + // If you have left the room then this will be the members of the room when you left. + atToken, err = db.EventPositionInTopology(req.Context(), queryRes.EventID) + } else { + // If you are joined to the room then this will be the current members of the room. + atToken, err = db.MaxTopologicalPosition(req.Context(), roomID) + } + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("unable to get 'atToken'") + return jsonerror.InternalServerError() + } + } + + eventIDs, err := db.SelectMemberships(req.Context(), roomID, atToken, membership, notMembership) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("db.SelectMemberships failed") + return jsonerror.InternalServerError() + } + + result, err := db.Events(req.Context(), eventIDs) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("db.Events failed") + return jsonerror.InternalServerError() + } + if joinedOnly { + if !queryRes.IsInRoom { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("You aren't a member of the room and weren't previously a member of the room."), + } + } var res getJoinedMembersResponse res.Joined = make(map[string]joinedMember) - for _, ev := range queryRes.JoinEvents { + for _, ev := range result { var content databaseJoinedMember - if err := json.Unmarshal(ev.Content, &content); err != nil { + if err := json.Unmarshal(ev.Content(), &content); err != nil { util.GetLogger(req.Context()).WithError(err).Error("failed to unmarshal event content") return jsonerror.InternalServerError() } - res.Joined[ev.Sender] = joinedMember(content) + res.Joined[ev.Sender()] = joinedMember(content) } return util.JSONResponse{ Code: http.StatusOK, @@ -93,29 +132,6 @@ func GetMemberships( } return util.JSONResponse{ Code: http.StatusOK, - JSON: getMembershipResponse{queryRes.JoinEvents}, - } -} - -func GetJoinedRooms( - req *http.Request, - device *userapi.Device, - rsAPI api.ClientRoomserverAPI, -) util.JSONResponse { - var res api.QueryRoomsForUserResponse - err := rsAPI.QueryRoomsForUser(req.Context(), &api.QueryRoomsForUserRequest{ - UserID: device.UserID, - WantMembership: "join", - }, &res) - if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("QueryRoomsForUser failed") - return jsonerror.InternalServerError() - } - if res.RoomIDs == nil { - res.RoomIDs = []string{} - } - return util.JSONResponse{ - Code: http.StatusOK, - JSON: getJoinedRoomsResponse{res.RoomIDs}, + JSON: getMembershipResponse{gomatrixserverlib.HeaderedToClientEvents(result, gomatrixserverlib.FormatSync)}, } } diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go index 8f3ed3f5b..86cf8e736 100644 --- a/syncapi/routing/messages.go +++ b/syncapi/routing/messages.go @@ -83,18 +83,18 @@ func OnIncomingMessagesRequest( defer sqlutil.EndTransactionWithCheck(snapshot, &succeeded, &err) // check if the user has already forgotten about this room - isForgotten, roomExists, err := checkIsRoomForgotten(req.Context(), roomID, device.UserID, rsAPI) + membershipResp, err := getMembershipForUser(req.Context(), roomID, device.UserID, rsAPI) if err != nil { return jsonerror.InternalServerError() } - if !roomExists { + if !membershipResp.RoomExists { return util.JSONResponse{ Code: http.StatusForbidden, JSON: jsonerror.Forbidden("room does not exist"), } } - if isForgotten { + if membershipResp.IsRoomForgotten { return util.JSONResponse{ Code: http.StatusForbidden, JSON: jsonerror.Forbidden("user already forgot about this room"), @@ -195,6 +195,20 @@ func OnIncomingMessagesRequest( } } + // If the user already left the room, grep events from before that + if membershipResp.Membership == gomatrixserverlib.Leave { + var token types.TopologyToken + token, err = snapshot.EventPositionInTopology(req.Context(), membershipResp.EventID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + } + } + if backwardOrdering { + from = token + } + } + mReq := messagesReq{ ctx: req.Context(), db: db, @@ -283,17 +297,16 @@ func (m *messagesResp) applyLazyLoadMembers( } } -func checkIsRoomForgotten(ctx context.Context, roomID, userID string, rsAPI api.SyncRoomserverAPI) (forgotten bool, exists bool, err error) { +func getMembershipForUser(ctx context.Context, roomID, userID string, rsAPI api.SyncRoomserverAPI) (resp api.QueryMembershipForUserResponse, err error) { req := api.QueryMembershipForUserRequest{ RoomID: roomID, UserID: userID, } - resp := api.QueryMembershipForUserResponse{} if err := rsAPI.QueryMembershipForUser(ctx, &req, &resp); err != nil { - return false, false, err + return api.QueryMembershipForUserResponse{}, err } - return resp.IsRoomForgotten, resp.RoomExists, nil + return resp, nil } // retrieveEvents retrieves events from the local database for a request on @@ -313,7 +326,11 @@ func (r *messagesReq) retrieveEvents() ( } var events []*gomatrixserverlib.HeaderedEvent - util.GetLogger(r.ctx).WithField("start", start).WithField("end", end).Infof("Fetched %d events locally", len(streamEvents)) + util.GetLogger(r.ctx).WithFields(logrus.Fields{ + "start": r.from, + "end": r.to, + "backwards": r.backwardOrdering, + }).Infof("Fetched %d events locally", len(streamEvents)) // There can be two reasons for streamEvents to be empty: either we've // reached the oldest event in the room (or the most recent one, depending diff --git a/syncapi/routing/routing.go b/syncapi/routing/routing.go index 71fa93c1e..bc3ad2384 100644 --- a/syncapi/routing/routing.go +++ b/syncapi/routing/routing.go @@ -172,4 +172,37 @@ func Setup( return Search(req, device, syncDB, fts, nextBatch) }), ).Methods(http.MethodPost, http.MethodOptions) + + v3mux.Handle("/rooms/{roomID}/members", + httputil.MakeAuthAPI("rooms_members", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + var membership, notMembership *string + if req.URL.Query().Has("membership") { + m := req.URL.Query().Get("membership") + membership = &m + } + if req.URL.Query().Has("not_membership") { + m := req.URL.Query().Get("not_membership") + notMembership = &m + } + + at := req.URL.Query().Get("at") + return GetMemberships(req, device, vars["roomID"], syncDB, rsAPI, false, membership, notMembership, at) + }), + ).Methods(http.MethodGet, http.MethodOptions) + + v3mux.Handle("/rooms/{roomID}/joined_members", + httputil.MakeAuthAPI("rooms_members", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + at := req.URL.Query().Get("at") + membership := gomatrixserverlib.Join + return GetMemberships(req, device, vars["roomID"], syncDB, rsAPI, true, &membership, nil, at) + }), + ).Methods(http.MethodGet, http.MethodOptions) } diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index 02d45f801..af4fce44e 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -178,6 +178,11 @@ type Database interface { ReIndex(ctx context.Context, limit, afterID int64) (map[int64]gomatrixserverlib.HeaderedEvent, error) UpdateRelations(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error RedactRelations(ctx context.Context, roomID, redactedEventID string) error + SelectMemberships( + ctx context.Context, + roomID string, pos types.TopologyToken, + membership, notMembership *string, + ) (eventIDs []string, err error) } type Presence interface { diff --git a/syncapi/storage/postgres/memberships_table.go b/syncapi/storage/postgres/memberships_table.go index 939d6b3f5..b555e8456 100644 --- a/syncapi/storage/postgres/memberships_table.go +++ b/syncapi/storage/postgres/memberships_table.go @@ -20,11 +20,12 @@ import ( "fmt" "github.com/lib/pq" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrixserverlib" ) // The memberships table is designed to track the last time that @@ -69,11 +70,20 @@ const selectHeroesSQL = "" + const selectMembershipBeforeSQL = "" + "SELECT membership, topological_pos FROM syncapi_memberships WHERE room_id = $1 and user_id = $2 AND topological_pos <= $3 ORDER BY topological_pos DESC LIMIT 1" +const selectMembersSQL = ` +SELECT event_id FROM ( + SELECT DISTINCT ON (room_id, user_id) room_id, user_id, event_id, membership FROM syncapi_memberships WHERE room_id = $1 AND topological_pos <= $2 ORDER BY room_id, user_id, stream_pos DESC +) t +WHERE ($3::text IS NULL OR t.membership = $3) + AND ($4::text IS NULL OR t.membership <> $4) +` + type membershipsStatements struct { upsertMembershipStmt *sql.Stmt selectMembershipCountStmt *sql.Stmt selectHeroesStmt *sql.Stmt selectMembershipForUserStmt *sql.Stmt + selectMembersStmt *sql.Stmt } func NewPostgresMembershipsTable(db *sql.DB) (tables.Memberships, error) { @@ -87,6 +97,7 @@ func NewPostgresMembershipsTable(db *sql.DB) (tables.Memberships, error) { {&s.selectMembershipCountStmt, selectMembershipCountSQL}, {&s.selectHeroesStmt, selectHeroesSQL}, {&s.selectMembershipForUserStmt, selectMembershipBeforeSQL}, + {&s.selectMembersStmt, selectMembersSQL}, }.Prepare(db) } @@ -154,3 +165,25 @@ func (s *membershipsStatements) SelectMembershipForUser( } return membership, topologyPos, nil } + +func (s *membershipsStatements) SelectMemberships( + ctx context.Context, txn *sql.Tx, + roomID string, pos types.TopologyToken, + membership, notMembership *string, +) (eventIDs []string, err error) { + stmt := sqlutil.TxStmt(txn, s.selectMembersStmt) + rows, err := stmt.QueryContext(ctx, roomID, pos.Depth, membership, notMembership) + if err != nil { + return + } + var ( + eventID string + ) + for rows.Next() { + if err = rows.Scan(&eventID); err != nil { + return + } + eventIDs = append(eventIDs, eventID) + } + return eventIDs, rows.Err() +} diff --git a/syncapi/storage/shared/storage_consumer.go b/syncapi/storage/shared/storage_consumer.go index bf12203db..23f53d11f 100644 --- a/syncapi/storage/shared/storage_consumer.go +++ b/syncapi/storage/shared/storage_consumer.go @@ -617,3 +617,11 @@ func (d *Database) RedactRelations(ctx context.Context, roomID, redactedEventID return d.Relations.DeleteRelation(ctx, txn, roomID, redactedEventID) }) } + +func (d *Database) SelectMemberships( + ctx context.Context, + roomID string, pos types.TopologyToken, + membership, notMembership *string, +) (eventIDs []string, err error) { + return d.Memberships.SelectMemberships(ctx, nil, roomID, pos, membership, notMembership) +} diff --git a/syncapi/storage/sqlite3/memberships_table.go b/syncapi/storage/sqlite3/memberships_table.go index 0c966fca0..7e54fac17 100644 --- a/syncapi/storage/sqlite3/memberships_table.go +++ b/syncapi/storage/sqlite3/memberships_table.go @@ -20,11 +20,12 @@ import ( "fmt" "strings" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrixserverlib" ) // The memberships table is designed to track the last time that @@ -69,12 +70,20 @@ const selectHeroesSQL = "" + const selectMembershipBeforeSQL = "" + "SELECT membership, topological_pos FROM syncapi_memberships WHERE room_id = $1 and user_id = $2 AND topological_pos <= $3 ORDER BY topological_pos DESC LIMIT 1" +const selectMembersSQL = ` +SELECT event_id FROM + ( SELECT event_id, membership FROM syncapi_memberships WHERE room_id = $1 AND topological_pos <= $2 GROUP BY user_id HAVING(max(stream_pos))) t + WHERE ($3 IS NULL OR t.membership = $3) + AND ($4 IS NULL OR t.membership <> $4) +` + type membershipsStatements struct { db *sql.DB upsertMembershipStmt *sql.Stmt selectMembershipCountStmt *sql.Stmt //selectHeroesStmt *sql.Stmt - prepared at runtime due to variadic selectMembershipForUserStmt *sql.Stmt + selectMembersStmt *sql.Stmt } func NewSqliteMembershipsTable(db *sql.DB) (tables.Memberships, error) { @@ -89,6 +98,7 @@ func NewSqliteMembershipsTable(db *sql.DB) (tables.Memberships, error) { {&s.upsertMembershipStmt, upsertMembershipSQL}, {&s.selectMembershipCountStmt, selectMembershipCountSQL}, {&s.selectMembershipForUserStmt, selectMembershipBeforeSQL}, + {&s.selectMembersStmt, selectMembersSQL}, // {&s.selectHeroesStmt, selectHeroesSQL}, - prepared at runtime due to variadic }.Prepare(db) } @@ -170,3 +180,23 @@ func (s *membershipsStatements) SelectMembershipForUser( } return membership, topologyPos, nil } + +func (s *membershipsStatements) SelectMemberships( + ctx context.Context, txn *sql.Tx, + roomID string, pos types.TopologyToken, + membership, notMembership *string, +) (eventIDs []string, err error) { + stmt := sqlutil.TxStmt(txn, s.selectMembersStmt) + rows, err := stmt.QueryContext(ctx, roomID, pos.Depth, membership, notMembership) + if err != nil { + return + } + var eventID string + for rows.Next() { + if err = rows.Scan(&eventID); err != nil { + return + } + eventIDs = append(eventIDs, eventID) + } + return eventIDs, rows.Err() +} diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index e48c050dd..2c4f04ec2 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -187,6 +187,11 @@ type Memberships interface { SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition) (count int, err error) SelectHeroes(ctx context.Context, txn *sql.Tx, roomID, userID string, memberships []string) (heroes []string, err error) SelectMembershipForUser(ctx context.Context, txn *sql.Tx, roomID, userID string, pos int64) (membership string, topologicalPos int, err error) + SelectMemberships( + ctx context.Context, txn *sql.Tx, + roomID string, pos types.TopologyToken, + membership, notMembership *string, + ) (eventIDs []string, err error) } type NotificationData interface { diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go index 9ec2b61cd..90cf8ce53 100644 --- a/syncapi/streams/stream_pdu.go +++ b/syncapi/streams/stream_pdu.go @@ -227,14 +227,10 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( stateFilter *gomatrixserverlib.StateFilter, req *types.SyncRequest, ) (types.StreamPosition, error) { - if delta.MembershipPos > 0 && delta.Membership == gomatrixserverlib.Leave { - // make sure we don't leak recent events after the leave event. - // TODO: History visibility makes this somewhat complex to handle correctly. For example: - // TODO: This doesn't work for join -> leave in a single /sync request (see events prior to join). - // TODO: This will fail on join -> leave -> sensitive msg -> join -> leave - // in a single /sync request - // This is all "okay" assuming history_visibility == "shared" which it is by default. - r.To = delta.MembershipPos + + originalLimit := eventFilter.Limit + if r.Backwards { + eventFilter.Limit = int(r.From - r.To) } recentStreamEvents, limited, err := snapshot.RecentEvents( ctx, delta.RoomID, r, @@ -303,6 +299,12 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( logrus.WithError(err).Error("unable to apply history visibility filter") } + if r.Backwards && len(events) > originalLimit { + // We're going backwards and the events are ordered chronologically, so take the last `limit` events + events = events[len(events)-originalLimit:] + limited = true + } + if len(delta.StateEvents) > 0 { updateLatestPosition(delta.StateEvents[len(delta.StateEvents)-1].EventID()) } @@ -473,7 +475,13 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( var prevBatch *types.TopologyToken if len(recentStreamEvents) > 0 { var backwardTopologyPos, backwardStreamPos types.StreamPosition - backwardTopologyPos, backwardStreamPos, err = snapshot.PositionInTopology(ctx, recentStreamEvents[0].EventID()) + event := recentStreamEvents[0] + // If this is the beginning of the room, we can't go back further. We're going to return + // the TopologyToken from the last event instead. (Synapse returns the /sync next_Batch) + if event.Type() == gomatrixserverlib.MRoomCreate && event.StateKeyEquals("") { + event = recentStreamEvents[len(recentStreamEvents)-1] + } + backwardTopologyPos, backwardStreamPos, err = snapshot.PositionInTopology(ctx, event.EventID()) if err != nil { return } diff --git a/syncapi/types/types.go b/syncapi/types/types.go index 57ce7b6ff..295187acc 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -234,6 +234,9 @@ func (t *TopologyToken) StreamToken() StreamingToken { } func (t TopologyToken) String() string { + if t.Depth <= 0 && t.PDUPosition <= 0 { + return "" + } return fmt.Sprintf("t%d_%d", t.Depth, t.PDUPosition) } diff --git a/sytest-whitelist b/sytest-whitelist index 1387838f7..60610929a 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -752,4 +752,9 @@ When user joins a room the state is included in the next sync When user joins a room the state is included in a gapped sync Messages that notify from another user increment notification_count Messages that highlight from another user increment unread highlight count -Notifications can be viewed with GET /notifications \ No newline at end of file +Notifications can be viewed with GET /notifications +Can get rooms/{roomId}/messages for a departed room (SPEC-216) +Local device key changes appear in /keys/changes +Can get rooms/{roomId}/members at a given point +Can filter rooms/{roomId}/members +Current state appears in timeline in private history with many messages after \ No newline at end of file diff --git a/test/testrig/base.go b/test/testrig/base.go index 10cc2407b..15fb5c370 100644 --- a/test/testrig/base.go +++ b/test/testrig/base.go @@ -36,6 +36,7 @@ func CreateBaseDendrite(t *testing.T, dbType test.DBType) (*base.BaseDendrite, f Monolithic: true, }) cfg.Global.JetStream.InMemory = true + cfg.FederationAPI.KeyPerspectives = nil switch dbType { case test.DBTypePostgres: cfg.Global.Defaults(config.DefaultOpts{ // autogen a signing key @@ -106,6 +107,7 @@ func Base(cfg *config.Dendrite) (*base.BaseDendrite, nats.JetStreamContext, *nat } cfg.Global.JetStream.InMemory = true cfg.SyncAPI.Fulltext.InMemory = true + cfg.FederationAPI.KeyPerspectives = nil base := base.NewBaseDendrite(cfg, "Tests") js, jc := base.NATS.Prepare(base.ProcessContext, &cfg.Global.JetStream) return base, js, jc diff --git a/userapi/api/api.go b/userapi/api/api.go index eef29144a..8d7f783de 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -318,8 +318,9 @@ type QuerySearchProfilesResponse struct { // PerformAccountCreationRequest is the request for PerformAccountCreation type PerformAccountCreationRequest struct { - AccountType AccountType // Required: whether this is a guest or user account - Localpart string // Required: The localpart for this account. Ignored if account type is guest. + AccountType AccountType // Required: whether this is a guest or user account + Localpart string // Required: The localpart for this account. Ignored if account type is guest. + ServerName gomatrixserverlib.ServerName // optional: if not specified, default server name used instead AppServiceID string // optional: the application service ID (not user ID) creating this account, if any. Password string // optional: if missing then this account will be a passwordless account @@ -360,7 +361,8 @@ type PerformLastSeenUpdateResponse struct { // PerformDeviceCreationRequest is the request for PerformDeviceCreation type PerformDeviceCreationRequest struct { Localpart string - AccessToken string // optional: if blank one will be made on your behalf + ServerName gomatrixserverlib.ServerName // optional: if blank, default server name used + AccessToken string // optional: if blank one will be made on your behalf // optional: if nil an ID is generated for you. If set, replaces any existing device session, // which will generate a new access token and invalidate the old one. DeviceID *string @@ -384,7 +386,8 @@ type PerformDeviceCreationResponse struct { // PerformAccountDeactivationRequest is the request for PerformAccountDeactivation type PerformAccountDeactivationRequest struct { - Localpart string + Localpart string + ServerName gomatrixserverlib.ServerName // optional: if blank, default server name used } // PerformAccountDeactivationResponse is the response for PerformAccountDeactivation @@ -434,6 +437,18 @@ type Device struct { AccountType AccountType } +func (d *Device) UserDomain() gomatrixserverlib.ServerName { + _, domain, err := gomatrixserverlib.SplitID('@', d.UserID) + if err != nil { + // This really is catastrophic because it means that someone + // managed to forge a malformed user ID for a device during + // login. + // TODO: Is there a better way to deal with this than panic? + panic(err) + } + return domain +} + // Account represents a Matrix account on this home server. type Account struct { UserID string @@ -577,7 +592,9 @@ type Notification struct { } type PerformSetAvatarURLRequest struct { - Localpart, AvatarURL string + Localpart string + ServerName gomatrixserverlib.ServerName + AvatarURL string } type PerformSetAvatarURLResponse struct { Profile *authtypes.Profile `json:"profile"` @@ -606,7 +623,9 @@ type QueryAccountByPasswordResponse struct { } type PerformUpdateDisplayNameRequest struct { - Localpart, DisplayName string + Localpart string + ServerName gomatrixserverlib.ServerName + DisplayName string } type PerformUpdateDisplayNameResponse struct { diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 63044eedb..9ca76965d 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -46,14 +46,15 @@ import ( type UserInternalAPI struct { DB storage.Database SyncProducer *producers.SyncAPI + Config *config.UserAPI DisableTLSValidation bool - ServerName gomatrixserverlib.ServerName // AppServices is the list of all registered AS AppServices []config.ApplicationService KeyAPI keyapi.UserKeyAPI RSAPI rsapi.UserRoomserverAPI PgClient pushgateway.Client + Cfg *config.UserAPI } func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error { @@ -61,8 +62,8 @@ func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAc if err != nil { return err } - if domain != a.ServerName { - return fmt.Errorf("cannot query profile of remote users: got %s want %s", domain, a.ServerName) + if !a.Config.Matrix.IsLocalServerName(domain) { + return fmt.Errorf("cannot update account data of remote users (server name %s)", domain) } if req.DataType == "" { return fmt.Errorf("data type must not be empty") @@ -103,7 +104,7 @@ func (a *UserInternalAPI) setFullyRead(ctx context.Context, req *api.InputAccoun logrus.WithError(err).Error("UserInternalAPI.setFullyRead: SplitID failure") return nil } - if domain != a.ServerName { + if !a.Config.Matrix.IsLocalServerName(domain) { return nil } @@ -130,7 +131,51 @@ func (a *UserInternalAPI) setFullyRead(ctx context.Context, req *api.InputAccoun return nil } +func postRegisterJoinRooms(cfg *config.UserAPI, acc *api.Account, rsAPI rsapi.UserRoomserverAPI) { + // POST register behaviour: check if the user is a normal user. + // If the user is a normal user, add user to room specified in the configuration "auto_join_rooms". + if acc.AccountType != api.AccountTypeAppService && acc.AppServiceID == "" { + for room := range cfg.AutoJoinRooms { + userID := userutil.MakeUserID(acc.Localpart, cfg.Matrix.ServerName) + err := addUserToRoom(context.Background(), rsAPI, cfg.AutoJoinRooms[room], acc.Localpart, userID) + if err != nil { + logrus.WithFields(logrus.Fields{ + "user_id": userID, + "room": cfg.AutoJoinRooms[room], + }).WithError(err).Errorf("user failed to auto-join room") + } + } + } +} + +// Add user to a room. This function currently working for auto_join_rooms config, +// which can add a newly registered user to a specified room. +func addUserToRoom( + ctx context.Context, + rsAPI rsapi.UserRoomserverAPI, + roomID string, + username string, + userID string, +) error { + addGroupContent := make(map[string]interface{}) + // This make sure the user's username can be displayed correctly. + // Because the newly-registered user doesn't have an avatar, the avatar_url is not needed. + addGroupContent["displayname"] = username + joinReq := rsapi.PerformJoinRequest{ + RoomIDOrAlias: roomID, + UserID: userID, + Content: addGroupContent, + } + joinRes := rsapi.PerformJoinResponse{} + return rsAPI.PerformJoin(ctx, &joinReq, &joinRes) +} + func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error { + serverName := req.ServerName + if serverName == "" { + serverName = a.Config.Matrix.ServerName + } + // XXXX: Use the server name here acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType) if err != nil { if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists @@ -148,8 +193,8 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P res.Account = &api.Account{ AppServiceID: req.AppServiceID, Localpart: req.Localpart, - ServerName: a.ServerName, - UserID: fmt.Sprintf("@%s:%s", req.Localpart, a.ServerName), + ServerName: serverName, + UserID: fmt.Sprintf("@%s:%s", req.Localpart, serverName), AccountType: req.AccountType, } return nil @@ -174,6 +219,8 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P return err } + postRegisterJoinRooms(a.Cfg, acc, a.RSAPI) + res.AccountCreated = true res.Account = acc return nil @@ -193,6 +240,12 @@ func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.Pe } func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.PerformDeviceCreationRequest, res *api.PerformDeviceCreationResponse) error { + serverName := req.ServerName + if serverName == "" { + serverName = a.Config.Matrix.ServerName + } + _ = serverName + // XXXX: Use the server name here util.GetLogger(ctx).WithFields(logrus.Fields{ "localpart": req.Localpart, "device_id": req.DeviceID, @@ -217,8 +270,8 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe if err != nil { return err } - if domain != a.ServerName { - return fmt.Errorf("cannot PerformDeviceDeletion of remote users: got %s want %s", domain, a.ServerName) + if !a.Config.Matrix.IsLocalServerName(domain) { + return fmt.Errorf("cannot PerformDeviceDeletion of remote users (server name %s)", domain) } deletedDeviceIDs := req.DeviceIDs if len(req.DeviceIDs) == 0 { @@ -350,8 +403,8 @@ func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfil if err != nil { return err } - if domain != a.ServerName { - return fmt.Errorf("cannot query profile of remote users: got %s want %s", domain, a.ServerName) + if !a.Config.Matrix.IsLocalServerName(domain) { + return fmt.Errorf("cannot query profile of remote users (server name %s)", domain) } prof, err := a.DB.GetProfileByLocalpart(ctx, local) if err != nil { @@ -401,8 +454,8 @@ func (a *UserInternalAPI) QueryDevices(ctx context.Context, req *api.QueryDevice if err != nil { return err } - if domain != a.ServerName { - return fmt.Errorf("cannot query devices of remote users: got %s want %s", domain, a.ServerName) + if !a.Config.Matrix.IsLocalServerName(domain) { + return fmt.Errorf("cannot query devices of remote users (server name %s)", domain) } devs, err := a.DB.GetDevicesByLocalpart(ctx, local) if err != nil { @@ -418,8 +471,8 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc if err != nil { return err } - if domain != a.ServerName { - return fmt.Errorf("cannot query account data of remote users: got %s want %s", domain, a.ServerName) + if !a.Config.Matrix.IsLocalServerName(domain) { + return fmt.Errorf("cannot query account data of remote users (server name %s)", domain) } if req.DataType != "" { var data json.RawMessage @@ -467,10 +520,13 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc } return err } - localPart, _, err := gomatrixserverlib.SplitID('@', device.UserID) + localPart, domain, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { return err } + if !a.Config.Matrix.IsLocalServerName(domain) { + return nil + } acc, err := a.DB.GetAccountByLocalpart(ctx, localPart) if err != nil { return err @@ -505,7 +561,7 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe AccountType: api.AccountTypeAppService, } - localpart, err := userutil.ParseUsernameParam(appServiceUserID, &a.ServerName) + localpart, _, err := userutil.ParseUsernameParam(appServiceUserID, a.Config.Matrix) if err != nil { return nil, err } @@ -530,8 +586,16 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe // PerformAccountDeactivation deactivates the user's account, removing all ability for the user to login again. func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *api.PerformAccountDeactivationRequest, res *api.PerformAccountDeactivationResponse) error { + serverName := req.ServerName + if serverName == "" { + serverName = a.Config.Matrix.ServerName + } + if !a.Config.Matrix.IsLocalServerName(serverName) { + return fmt.Errorf("server name %q not locally configured", serverName) + } + evacuateReq := &rsapi.PerformAdminEvacuateUserRequest{ - UserID: fmt.Sprintf("@%s:%s", req.Localpart, a.ServerName), + UserID: fmt.Sprintf("@%s:%s", req.Localpart, serverName), } evacuateRes := &rsapi.PerformAdminEvacuateUserResponse{} if err := a.RSAPI.PerformAdminEvacuateUser(ctx, evacuateReq, evacuateRes); err != nil { @@ -542,7 +606,7 @@ func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *a } deviceReq := &api.PerformDeviceDeletionRequest{ - UserID: fmt.Sprintf("@%s:%s", req.Localpart, a.ServerName), + UserID: fmt.Sprintf("@%s:%s", req.Localpart, serverName), } deviceRes := &api.PerformDeviceDeletionResponse{} if err := a.PerformDeviceDeletion(ctx, deviceReq, deviceRes); err != nil { diff --git a/userapi/internal/api_logintoken.go b/userapi/internal/api_logintoken.go index f1bf391e4..87f25e5e2 100644 --- a/userapi/internal/api_logintoken.go +++ b/userapi/internal/api_logintoken.go @@ -31,8 +31,8 @@ func (a *UserInternalAPI) PerformLoginTokenCreation(ctx context.Context, req *ap if err != nil { return err } - if domain != a.ServerName { - return fmt.Errorf("cannot create a login token for a remote user: got %s want %s", domain, a.ServerName) + if !a.Config.Matrix.IsLocalServerName(domain) { + return fmt.Errorf("cannot create a login token for a remote user (server name %s)", domain) } tokenMeta, err := a.DB.CreateLoginToken(ctx, &req.Data) if err != nil { @@ -63,8 +63,8 @@ func (a *UserInternalAPI) QueryLoginToken(ctx context.Context, req *api.QueryLog if err != nil { return err } - if domain != a.ServerName { - return fmt.Errorf("cannot return a login token for a remote user: got %s want %s", domain, a.ServerName) + if !a.Config.Matrix.IsLocalServerName(domain) { + return fmt.Errorf("cannot return a login token for a remote user (server name %s)", domain) } if _, err := a.DB.GetAccountByLocalpart(ctx, localpart); err != nil { res.Data = nil diff --git a/userapi/userapi.go b/userapi/userapi.go index d26b4e19a..e46a8e76e 100644 --- a/userapi/userapi.go +++ b/userapi/userapi.go @@ -76,12 +76,13 @@ func NewInternalAPI( userAPI := &internal.UserInternalAPI{ DB: db, SyncProducer: syncProducer, - ServerName: cfg.Matrix.ServerName, + Config: cfg, AppServices: appServices, KeyAPI: keyAPI, RSAPI: rsAPI, DisableTLSValidation: cfg.PushGatewayDisableTLSValidation, PgClient: pgClient, + Cfg: cfg, } receiptConsumer := consumers.NewOutputReceiptEventConsumer( diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index aaa93f45b..2a43c0bd4 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -66,8 +66,8 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType) (ap } return &internal.UserInternalAPI{ - DB: accountDB, - ServerName: cfg.Matrix.ServerName, + DB: accountDB, + Config: cfg, }, accountDB, func() { close() baseclose()