From 09dff951d6be1fee1cc7c6872e98eb27e81fc778 Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Thu, 22 Dec 2022 13:04:32 +0100 Subject: [PATCH 1/6] More flakey tests --- federationapi/storage/storage_test.go | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/federationapi/storage/storage_test.go b/federationapi/storage/storage_test.go index 14efa2655..5b57d40d4 100644 --- a/federationapi/storage/storage_test.go +++ b/federationapi/storage/storage_test.go @@ -157,11 +157,11 @@ func TestOutboundPeeking(t *testing.T) { if len(outboundPeeks) != len(peekIDs) { t.Fatalf("inserted %d peeks, selected %d", len(peekIDs), len(outboundPeeks)) } - for i := range outboundPeeks { - if outboundPeeks[i].PeekID != peekIDs[i] { - t.Fatalf("unexpected peek ID: %s, want %s", outboundPeeks[i].PeekID, peekIDs[i]) - } + gotPeekIDs := make([]string, 0, len(outboundPeeks)) + for _, p := range outboundPeeks { + gotPeekIDs = append(gotPeekIDs, p.PeekID) } + assert.ElementsMatch(t, gotPeekIDs, peekIDs) }) } @@ -239,10 +239,10 @@ func TestInboundPeeking(t *testing.T) { if len(inboundPeeks) != len(peekIDs) { t.Fatalf("inserted %d peeks, selected %d", len(peekIDs), len(inboundPeeks)) } - for i := range inboundPeeks { - if inboundPeeks[i].PeekID != peekIDs[i] { - t.Fatalf("unexpected peek ID: %s, want %s", inboundPeeks[i].PeekID, peekIDs[i]) - } + gotPeekIDs := make([]string, 0, len(inboundPeeks)) + for _, p := range inboundPeeks { + gotPeekIDs = append(gotPeekIDs, p.PeekID) } + assert.ElementsMatch(t, gotPeekIDs, peekIDs) }) } From 5eed31fea330f5f0500384c98272b9a75a44fba4 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Thu, 22 Dec 2022 13:05:59 +0100 Subject: [PATCH 2/6] Handle guest access [1/2?] (#2872) Needs https://github.com/matrix-org/sytest/pull/1315, as otherwise the membership events aren't persisted yet when hitting `/state` after kicking guest users. Makes the following tests pass: ``` Guest users denied access over federation if guest access prohibited Guest users are kicked from guest_access rooms on revocation of guest_access Guest users are kicked from guest_access rooms on revocation of guest_access over federation ``` Todo (in a follow up PR): - Restrict access to CS API Endpoints as per https://spec.matrix.org/v1.4/client-server-api/#client-behaviour-14 Co-authored-by: kegsay --- clientapi/clientapi.go | 3 +- clientapi/routing/joinroom.go | 10 +- clientapi/routing/joinroom_test.go | 158 ++++++++++++++++++++ roomserver/api/perform.go | 1 + roomserver/internal/api.go | 20 ++- roomserver/internal/input/input.go | 4 + roomserver/internal/input/input_events.go | 105 +++++++++++++ roomserver/internal/perform/perform_join.go | 23 +++ roomserver/roomserver_test.go | 141 +++++++++++++---- setup/config/config_global.go | 2 +- setup/config/config_test.go | 54 +++++++ sytest-blacklist | 5 +- sytest-whitelist | 5 +- test/room.go | 22 ++- userapi/api/api.go | 10 ++ userapi/api/api_trace.go | 6 + userapi/internal/api.go | 5 + userapi/inthttp/client.go | 12 ++ userapi/inthttp/server.go | 5 + userapi/userapi_test.go | 61 ++++++++ 20 files changed, 607 insertions(+), 45 deletions(-) create mode 100644 clientapi/routing/joinroom_test.go diff --git a/clientapi/clientapi.go b/clientapi/clientapi.go index 62ffa6155..2d17e0928 100644 --- a/clientapi/clientapi.go +++ b/clientapi/clientapi.go @@ -15,6 +15,8 @@ package clientapi import ( + "github.com/matrix-org/gomatrixserverlib" + appserviceAPI "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/clientapi/api" "github.com/matrix-org/dendrite/clientapi/producers" @@ -26,7 +28,6 @@ import ( "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/jetstream" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" ) // AddPublicRoutes sets up and registers HTTP handlers for the ClientAPI component. diff --git a/clientapi/routing/joinroom.go b/clientapi/routing/joinroom.go index c50e552bd..e371d9214 100644 --- a/clientapi/routing/joinroom.go +++ b/clientapi/routing/joinroom.go @@ -37,6 +37,7 @@ func JoinRoomByIDOrAlias( joinReq := roomserverAPI.PerformJoinRequest{ RoomIDOrAlias: roomIDOrAlias, UserID: device.UserID, + IsGuest: device.AccountType == api.AccountTypeGuest, Content: map[string]interface{}{}, } joinRes := roomserverAPI.PerformJoinResponse{} @@ -84,7 +85,14 @@ func JoinRoomByIDOrAlias( if err := rsAPI.PerformJoin(req.Context(), &joinReq, &joinRes); err != nil { done <- jsonerror.InternalAPIError(req.Context(), err) } else if joinRes.Error != nil { - done <- joinRes.Error.JSONResponse() + if joinRes.Error.Code == roomserverAPI.PerformErrorNotAllowed && device.AccountType == api.AccountTypeGuest { + done <- util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.GuestAccessForbidden(joinRes.Error.Msg), + } + } else { + done <- joinRes.Error.JSONResponse() + } } else { done <- util.JSONResponse{ Code: http.StatusOK, diff --git a/clientapi/routing/joinroom_test.go b/clientapi/routing/joinroom_test.go new file mode 100644 index 000000000..9e8208e6d --- /dev/null +++ b/clientapi/routing/joinroom_test.go @@ -0,0 +1,158 @@ +package routing + +import ( + "bytes" + "context" + "net/http" + "testing" + "time" + + "github.com/matrix-org/gomatrixserverlib" + + "github.com/matrix-org/dendrite/appservice" + "github.com/matrix-org/dendrite/keyserver" + "github.com/matrix-org/dendrite/roomserver" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + "github.com/matrix-org/dendrite/userapi" + uapi "github.com/matrix-org/dendrite/userapi/api" +) + +func TestJoinRoomByIDOrAlias(t *testing.T) { + alice := test.NewUser(t) + bob := test.NewUser(t) + charlie := test.NewUser(t, test.WithAccountType(uapi.AccountTypeGuest)) + + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, baseClose := testrig.CreateBaseDendrite(t, dbType) + defer baseClose() + + rsAPI := roomserver.NewInternalAPI(base) + keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI) + userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil) + asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI) + rsAPI.SetFederationAPI(nil, nil) // creates the rs.Inputer etc + + // Create the users in the userapi + for _, u := range []*test.User{alice, bob, charlie} { + localpart, serverName, _ := gomatrixserverlib.SplitID('@', u.ID) + userRes := &uapi.PerformAccountCreationResponse{} + if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{ + AccountType: u.AccountType, + Localpart: localpart, + ServerName: serverName, + Password: "someRandomPassword", + }, userRes); err != nil { + t.Errorf("failed to create account: %s", err) + } + + } + + aliceDev := &uapi.Device{UserID: alice.ID} + bobDev := &uapi.Device{UserID: bob.ID} + charlieDev := &uapi.Device{UserID: charlie.ID, AccountType: uapi.AccountTypeGuest} + + // create a room with disabled guest access and invite Bob + resp := createRoom(ctx, createRoomRequest{ + Name: "testing", + IsDirect: true, + Topic: "testing", + Visibility: "public", + Preset: presetPublicChat, + RoomAliasName: "alias", + Invite: []string{bob.ID}, + GuestCanJoin: false, + }, aliceDev, &base.Cfg.ClientAPI, userAPI, rsAPI, asAPI, time.Now()) + crResp, ok := resp.JSON.(createRoomResponse) + if !ok { + t.Fatalf("response is not a createRoomResponse: %+v", resp) + } + + // create a room with guest access enabled and invite Charlie + resp = createRoom(ctx, createRoomRequest{ + Name: "testing", + IsDirect: true, + Topic: "testing", + Visibility: "public", + Preset: presetPublicChat, + Invite: []string{charlie.ID}, + GuestCanJoin: true, + }, aliceDev, &base.Cfg.ClientAPI, userAPI, rsAPI, asAPI, time.Now()) + crRespWithGuestAccess, ok := resp.JSON.(createRoomResponse) + if !ok { + t.Fatalf("response is not a createRoomResponse: %+v", resp) + } + + // Dummy request + body := &bytes.Buffer{} + req, err := http.NewRequest(http.MethodPost, "/?server_name=test", body) + if err != nil { + t.Fatal(err) + } + + testCases := []struct { + name string + device *uapi.Device + roomID string + wantHTTP200 bool + }{ + { + name: "User can join successfully by alias", + device: bobDev, + roomID: crResp.RoomAlias, + wantHTTP200: true, + }, + { + name: "User can join successfully by roomID", + device: bobDev, + roomID: crResp.RoomID, + wantHTTP200: true, + }, + { + name: "join is forbidden if user is guest", + device: charlieDev, + roomID: crResp.RoomID, + }, + { + name: "room does not exist", + device: aliceDev, + roomID: "!doesnotexist:test", + }, + { + name: "user from different server", + device: &uapi.Device{UserID: "@wrong:server"}, + roomID: crResp.RoomAlias, + }, + { + name: "user doesn't exist locally", + device: &uapi.Device{UserID: "@doesnotexist:test"}, + roomID: crResp.RoomAlias, + }, + { + name: "invalid room ID", + device: aliceDev, + roomID: "invalidRoomID", + }, + { + name: "roomAlias does not exist", + device: aliceDev, + roomID: "#doesnotexist:test", + }, + { + name: "room with guest_access event", + device: charlieDev, + roomID: crRespWithGuestAccess.RoomID, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + joinResp := JoinRoomByIDOrAlias(req, tc.device, rsAPI, userAPI, tc.roomID) + if tc.wantHTTP200 && !joinResp.Is2xx() { + t.Fatalf("expected join room to succeed, but didn't: %+v", joinResp) + } + }) + } + }) +} diff --git a/roomserver/api/perform.go b/roomserver/api/perform.go index e70e5ea9c..e789b9568 100644 --- a/roomserver/api/perform.go +++ b/roomserver/api/perform.go @@ -78,6 +78,7 @@ const ( type PerformJoinRequest struct { RoomIDOrAlias string `json:"room_id_or_alias"` UserID string `json:"user_id"` + IsGuest bool `json:"is_guest"` Content map[string]interface{} `json:"content"` ServerNames []gomatrixserverlib.ServerName `json:"server_names"` Unsigned map[string]interface{} `json:"unsigned"` diff --git a/roomserver/internal/api.go b/roomserver/internal/api.go index 1a3626609..451b37696 100644 --- a/roomserver/internal/api.go +++ b/roomserver/internal/api.go @@ -4,6 +4,10 @@ import ( "context" "github.com/getsentry/sentry-go" + "github.com/matrix-org/gomatrixserverlib" + "github.com/nats-io/nats.go" + "github.com/sirupsen/logrus" + asAPI "github.com/matrix-org/dendrite/appservice/api" fsAPI "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal/caching" @@ -19,9 +23,6 @@ import ( "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" - "github.com/nats-io/nats.go" - "github.com/sirupsen/logrus" ) // RoomserverInternalAPI is an implementation of api.RoomserverInternalAPI @@ -104,6 +105,11 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio r.fsAPI = fsAPI r.KeyRing = keyRing + identity, err := r.Cfg.Matrix.SigningIdentityFor(r.ServerName) + if err != nil { + logrus.Panic(err) + } + r.Inputer = &input.Inputer{ Cfg: &r.Base.Cfg.RoomServer, Base: r.Base, @@ -114,7 +120,8 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio JetStream: r.JetStream, NATSClient: r.NATSClient, Durable: nats.Durable(r.Durable), - ServerName: r.Cfg.Matrix.ServerName, + ServerName: r.ServerName, + SigningIdentity: identity, FSAPI: fsAPI, KeyRing: keyRing, ACLs: r.ServerACLs, @@ -135,7 +142,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio Queryer: r.Queryer, } r.Peeker = &perform.Peeker{ - ServerName: r.Cfg.Matrix.ServerName, + ServerName: r.ServerName, Cfg: r.Cfg, DB: r.DB, FSAPI: r.fsAPI, @@ -146,7 +153,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio Inputer: r.Inputer, } r.Unpeeker = &perform.Unpeeker{ - ServerName: r.Cfg.Matrix.ServerName, + ServerName: r.ServerName, Cfg: r.Cfg, DB: r.DB, FSAPI: r.fsAPI, @@ -193,6 +200,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio func (r *RoomserverInternalAPI) SetUserAPI(userAPI userapi.RoomserverUserAPI) { r.Leaver.UserAPI = userAPI + r.Inputer.UserAPI = userAPI } func (r *RoomserverInternalAPI) SetAppserviceAPI(asAPI asAPI.AppServiceInternalAPI) { diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go index e965691c9..941311030 100644 --- a/roomserver/internal/input/input.go +++ b/roomserver/internal/input/input.go @@ -23,6 +23,8 @@ import ( "sync" "time" + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/Arceliar/phony" "github.com/getsentry/sentry-go" "github.com/matrix-org/gomatrixserverlib" @@ -79,6 +81,7 @@ type Inputer struct { JetStream nats.JetStreamContext Durable nats.SubOpt ServerName gomatrixserverlib.ServerName + SigningIdentity *gomatrixserverlib.SigningIdentity FSAPI fedapi.RoomserverFederationAPI KeyRing gomatrixserverlib.JSONVerifier ACLs *acls.ServerACLs @@ -87,6 +90,7 @@ type Inputer struct { workers sync.Map // room ID -> *worker Queryer *query.Queryer + UserAPI userapi.RoomserverUserAPI } // If a room consumer is inactive for a while then we will allow NATS diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 10b8ee27f..4179fc1ef 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -19,6 +19,7 @@ package input import ( "context" "database/sql" + "encoding/json" "errors" "fmt" "time" @@ -31,6 +32,8 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/sirupsen/logrus" + userAPI "github.com/matrix-org/dendrite/userapi/api" + fedapi "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/eventutil" @@ -440,6 +443,13 @@ func (r *Inputer) processRoomEvent( } } + // If guest_access changed and is not can_join, kick all guest users. + if event.Type() == gomatrixserverlib.MRoomGuestAccess && gjson.GetBytes(event.Content(), "guest_access").Str != "can_join" { + if err = r.kickGuests(ctx, event, roomInfo); err != nil { + logrus.WithError(err).Error("failed to kick guest users on m.room.guest_access revocation") + } + } + // Everything was OK — the latest events updater didn't error and // we've sent output events. Finally, generate a hook call. hooks.Run(hooks.KindNewEventPersisted, headered) @@ -729,3 +739,98 @@ func (r *Inputer) calculateAndSetState( succeeded = true return nil } + +// kickGuests kicks guests users from m.room.guest_access rooms, if guest access is now prohibited. +func (r *Inputer) kickGuests(ctx context.Context, event *gomatrixserverlib.Event, roomInfo *types.RoomInfo) error { + membershipNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true) + if err != nil { + return err + } + + memberEvents, err := r.DB.Events(ctx, membershipNIDs) + if err != nil { + return err + } + + inputEvents := make([]api.InputRoomEvent, 0, len(memberEvents)) + latestReq := &api.QueryLatestEventsAndStateRequest{ + RoomID: event.RoomID(), + } + latestRes := &api.QueryLatestEventsAndStateResponse{} + if err = r.Queryer.QueryLatestEventsAndState(ctx, latestReq, latestRes); err != nil { + return err + } + + prevEvents := latestRes.LatestEvents + for _, memberEvent := range memberEvents { + if memberEvent.StateKey() == nil { + continue + } + + localpart, senderDomain, err := gomatrixserverlib.SplitID('@', *memberEvent.StateKey()) + if err != nil { + continue + } + + accountRes := &userAPI.QueryAccountByLocalpartResponse{} + if err = r.UserAPI.QueryAccountByLocalpart(ctx, &userAPI.QueryAccountByLocalpartRequest{ + Localpart: localpart, + ServerName: senderDomain, + }, accountRes); err != nil { + return err + } + if accountRes.Account == nil { + continue + } + + if accountRes.Account.AccountType != userAPI.AccountTypeGuest { + continue + } + + var memberContent gomatrixserverlib.MemberContent + if err = json.Unmarshal(memberEvent.Content(), &memberContent); err != nil { + return err + } + memberContent.Membership = gomatrixserverlib.Leave + + stateKey := *memberEvent.StateKey() + fledglingEvent := &gomatrixserverlib.EventBuilder{ + RoomID: event.RoomID(), + Type: gomatrixserverlib.MRoomMember, + StateKey: &stateKey, + Sender: stateKey, + PrevEvents: prevEvents, + } + + if fledglingEvent.Content, err = json.Marshal(memberContent); err != nil { + return err + } + + eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(fledglingEvent) + if err != nil { + return err + } + + event, err := eventutil.BuildEvent(ctx, fledglingEvent, r.Cfg.Matrix, r.SigningIdentity, time.Now(), &eventsNeeded, latestRes) + if err != nil { + return err + } + + inputEvents = append(inputEvents, api.InputRoomEvent{ + Kind: api.KindNew, + Event: event, + Origin: senderDomain, + SendAsServer: string(senderDomain), + }) + prevEvents = []gomatrixserverlib.EventReference{ + event.EventReference(), + } + } + + inputReq := &api.InputRoomEventsRequest{ + InputRoomEvents: inputEvents, + Asynchronous: true, // Needs to be async, as we otherwise create a deadlock + } + inputRes := &api.InputRoomEventsResponse{} + return r.InputRoomEvents(ctx, inputReq, inputRes) +} diff --git a/roomserver/internal/perform/perform_join.go b/roomserver/internal/perform/perform_join.go index 4de008c66..fc7ba940c 100644 --- a/roomserver/internal/perform/perform_join.go +++ b/roomserver/internal/perform/perform_join.go @@ -16,6 +16,7 @@ package perform import ( "context" + "database/sql" "errors" "fmt" "strings" @@ -270,6 +271,28 @@ func (r *Joiner) performJoinRoomByID( } } + // If a guest is trying to join a room, check that the room has a m.room.guest_access event + if req.IsGuest { + var guestAccessEvent *gomatrixserverlib.HeaderedEvent + guestAccess := "forbidden" + guestAccessEvent, err = r.DB.GetStateEvent(ctx, req.RoomIDOrAlias, gomatrixserverlib.MRoomGuestAccess, "") + if (err != nil && !errors.Is(err, sql.ErrNoRows)) || guestAccessEvent == nil { + logrus.WithError(err).Warn("unable to get m.room.guest_access event, defaulting to 'forbidden'") + } + if guestAccessEvent != nil { + guestAccess = gjson.GetBytes(guestAccessEvent.Content(), "guest_access").String() + } + + // Servers MUST only allow guest users to join rooms if the m.room.guest_access state event + // is present on the room and has the guest_access value can_join. + if guestAccess != "can_join" { + return "", "", &rsAPI.PerformError{ + Code: rsAPI.PerformErrorNotAllowed, + Msg: "Guest access is forbidden", + } + } + } + // If we should do a forced federated join then do that. var joinedVia gomatrixserverlib.ServerName if forceFederatedJoin { diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go index 518bb3722..595ceb526 100644 --- a/roomserver/roomserver_test.go +++ b/roomserver/roomserver_test.go @@ -3,18 +3,23 @@ package roomserver_test import ( "context" "net/http" + "reflect" "testing" "time" "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/setup/base" + "github.com/matrix-org/dendrite/userapi" + + userAPI "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/inthttp" "github.com/matrix-org/dendrite/roomserver/storage" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test/testrig" ) @@ -29,7 +34,28 @@ func mustCreateDatabase(t *testing.T, dbType test.DBType) (*base.BaseDendrite, s return base, db, close } -func Test_SharedUsers(t *testing.T) { +func TestUsers(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, close := testrig.CreateBaseDendrite(t, dbType) + defer close() + rsAPI := roomserver.NewInternalAPI(base) + // SetFederationAPI starts the room event input consumer + rsAPI.SetFederationAPI(nil, nil) + + t.Run("shared users", func(t *testing.T) { + testSharedUsers(t, rsAPI) + }) + + t.Run("kick users", func(t *testing.T) { + usrAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, nil, rsAPI, nil) + rsAPI.SetUserAPI(usrAPI) + testKickUsers(t, rsAPI, usrAPI) + }) + }) + +} + +func testSharedUsers(t *testing.T, rsAPI api.RoomserverInternalAPI) { alice := test.NewUser(t) bob := test.NewUser(t) room := test.NewRoom(t, alice, test.RoomPreset(test.PresetTrustedPrivateChat)) @@ -43,36 +69,93 @@ func Test_SharedUsers(t *testing.T) { }, test.WithStateKey(bob.ID)) ctx := context.Background() - test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - base, _, close := mustCreateDatabase(t, dbType) - defer close() - rsAPI := roomserver.NewInternalAPI(base) - // SetFederationAPI starts the room event input consumer - rsAPI.SetFederationAPI(nil, nil) - // Create the room - if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { - t.Fatalf("failed to send events: %v", err) + // Create the room + if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { + t.Errorf("failed to send events: %v", err) + } + + // Query the shared users for Alice, there should only be Bob. + // This is used by the SyncAPI keychange consumer. + res := &api.QuerySharedUsersResponse{} + if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID}, res); err != nil { + t.Errorf("unable to query known users: %v", err) + } + if _, ok := res.UserIDsToCount[bob.ID]; !ok { + t.Errorf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount) + } + // Also verify that we get the expected result when specifying OtherUserIDs. + // This is used by the SyncAPI when getting device list changes. + if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID, OtherUserIDs: []string{bob.ID}}, res); err != nil { + t.Errorf("unable to query known users: %v", err) + } + if _, ok := res.UserIDsToCount[bob.ID]; !ok { + t.Errorf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount) + } +} + +func testKickUsers(t *testing.T, rsAPI api.RoomserverInternalAPI, usrAPI userAPI.UserInternalAPI) { + // Create users and room; Bob is going to be the guest and kicked on revocation of guest access + alice := test.NewUser(t, test.WithAccountType(userAPI.AccountTypeUser)) + bob := test.NewUser(t, test.WithAccountType(userAPI.AccountTypeGuest)) + + room := test.NewRoom(t, alice, test.RoomPreset(test.PresetPublicChat), test.GuestsCanJoin(true)) + + // Join with the guest user + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + + ctx := context.Background() + + // Create the users in the userapi, so the RSAPI can query the account type later + for _, u := range []*test.User{alice, bob} { + localpart, serverName, _ := gomatrixserverlib.SplitID('@', u.ID) + userRes := &userAPI.PerformAccountCreationResponse{} + if err := usrAPI.PerformAccountCreation(ctx, &userAPI.PerformAccountCreationRequest{ + AccountType: u.AccountType, + Localpart: localpart, + ServerName: serverName, + Password: "someRandomPassword", + }, userRes); err != nil { + t.Errorf("failed to create account: %s", err) + } + } + + // Create the room in the database + if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { + t.Errorf("failed to send events: %v", err) + } + + // Get the membership events BEFORE revoking guest access + membershipRes := &api.QueryMembershipsForRoomResponse{} + if err := rsAPI.QueryMembershipsForRoom(ctx, &api.QueryMembershipsForRoomRequest{LocalOnly: true, JoinedOnly: true, RoomID: room.ID}, membershipRes); err != nil { + t.Errorf("failed to query membership for room: %s", err) + } + + // revoke guest access + revokeEvent := room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomGuestAccess, map[string]string{"guest_access": "forbidden"}, test.WithStateKey("")) + if err := api.SendEvents(ctx, rsAPI, api.KindNew, []*gomatrixserverlib.HeaderedEvent{revokeEvent}, "test", "test", "test", nil, false); err != nil { + t.Errorf("failed to send events: %v", err) + } + + // TODO: Even though we are sending the events sync, the "kickUsers" function is sending the events async, so we need + // to loop and wait for the events to be processed by the roomserver. + for i := 0; i <= 20; i++ { + // Get the membership events AFTER revoking guest access + membershipRes2 := &api.QueryMembershipsForRoomResponse{} + if err := rsAPI.QueryMembershipsForRoom(ctx, &api.QueryMembershipsForRoomRequest{LocalOnly: true, JoinedOnly: true, RoomID: room.ID}, membershipRes2); err != nil { + t.Errorf("failed to query membership for room: %s", err) } - // Query the shared users for Alice, there should only be Bob. - // This is used by the SyncAPI keychange consumer. - res := &api.QuerySharedUsersResponse{} - if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID}, res); err != nil { - t.Fatalf("unable to query known users: %v", err) + // The membership events should NOT match, as Bob (guest user) should now be kicked from the room + if !reflect.DeepEqual(membershipRes, membershipRes2) { + return } - if _, ok := res.UserIDsToCount[bob.ID]; !ok { - t.Fatalf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount) - } - // Also verify that we get the expected result when specifying OtherUserIDs. - // This is used by the SyncAPI when getting device list changes. - if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID, OtherUserIDs: []string{bob.ID}}, res); err != nil { - t.Fatalf("unable to query known users: %v", err) - } - if _, ok := res.UserIDsToCount[bob.ID]; !ok { - t.Fatalf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount) - } - }) + time.Sleep(time.Millisecond * 10) + } + + t.Errorf("memberships didn't change in time") } func Test_QueryLeftUsers(t *testing.T) { diff --git a/setup/config/config_global.go b/setup/config/config_global.go index 511951fe6..804eb1a2d 100644 --- a/setup/config/config_global.go +++ b/setup/config/config_global.go @@ -174,7 +174,7 @@ func (c *Global) SigningIdentityFor(serverName gomatrixserverlib.ServerName) (*g return id, nil } } - return nil, fmt.Errorf("no signing identity %q", serverName) + return nil, fmt.Errorf("no signing identity for %q", serverName) } func (c *Global) SigningIdentities() []*gomatrixserverlib.SigningIdentity { diff --git a/setup/config/config_test.go b/setup/config/config_test.go index ee7e7389c..3408bf46d 100644 --- a/setup/config/config_test.go +++ b/setup/config/config_test.go @@ -16,8 +16,10 @@ package config import ( "fmt" + "reflect" "testing" + "github.com/matrix-org/gomatrixserverlib" "gopkg.in/yaml.v2" ) @@ -290,3 +292,55 @@ func TestUnmarshalDataUnit(t *testing.T) { } } } + +func Test_SigningIdentityFor(t *testing.T) { + tests := []struct { + name string + virtualHosts []*VirtualHost + serverName gomatrixserverlib.ServerName + want *gomatrixserverlib.SigningIdentity + wantErr bool + }{ + { + name: "no virtual hosts defined", + wantErr: true, + }, + { + name: "no identity found", + serverName: gomatrixserverlib.ServerName("doesnotexist"), + wantErr: true, + }, + { + name: "found identity", + serverName: gomatrixserverlib.ServerName("main"), + want: &gomatrixserverlib.SigningIdentity{ServerName: "main"}, + }, + { + name: "identity found on virtual hosts", + serverName: gomatrixserverlib.ServerName("vh2"), + virtualHosts: []*VirtualHost{ + {SigningIdentity: gomatrixserverlib.SigningIdentity{ServerName: "vh1"}}, + {SigningIdentity: gomatrixserverlib.SigningIdentity{ServerName: "vh2"}}, + }, + want: &gomatrixserverlib.SigningIdentity{ServerName: "vh2"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Global{ + VirtualHosts: tt.virtualHosts, + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: "main", + }, + } + got, err := c.SigningIdentityFor(tt.serverName) + if (err != nil) != tt.wantErr { + t.Errorf("SigningIdentityFor() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("SigningIdentityFor() got = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/sytest-blacklist b/sytest-blacklist index c35b03bd7..99cfbabc8 100644 --- a/sytest-blacklist +++ b/sytest-blacklist @@ -48,4 +48,7 @@ If a device list update goes missing, the server resyncs on the next one Leaves are present in non-gapped incremental syncs # Below test was passing for the wrong reason, failing correctly since #2858 -New federated private chats get full presence information (SYN-115) \ No newline at end of file +New federated private chats get full presence information (SYN-115) + +# We don't have any state to calculate m.room.guest_access when accepting invites +Guest users can accept invites to private rooms over federation \ No newline at end of file diff --git a/sytest-whitelist b/sytest-whitelist index 49ffb8fe8..215889a49 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -763,4 +763,7 @@ AS and main public room lists are separate local user has tags copied to the new room remote user has tags copied to the new room /upgrade moves remote aliases to the new room -Local and remote users' homeservers remove a room from their public directory on upgrade \ No newline at end of file +Local and remote users' homeservers remove a room from their public directory on upgrade +Guest users denied access over federation if guest access prohibited +Guest users are kicked from guest_access rooms on revocation of guest_access +Guest users are kicked from guest_access rooms on revocation of guest_access over federation \ No newline at end of file diff --git a/test/room.go b/test/room.go index 4328bf84f..685876cb0 100644 --- a/test/room.go +++ b/test/room.go @@ -38,11 +38,12 @@ var ( ) type Room struct { - ID string - Version gomatrixserverlib.RoomVersion - preset Preset - visibility gomatrixserverlib.HistoryVisibility - creator *User + ID string + Version gomatrixserverlib.RoomVersion + preset Preset + guestCanJoin bool + visibility gomatrixserverlib.HistoryVisibility + creator *User authEvents gomatrixserverlib.AuthEvents currentState map[string]*gomatrixserverlib.HeaderedEvent @@ -120,6 +121,11 @@ func (r *Room) insertCreateEvents(t *testing.T) { r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomPowerLevels, plContent, WithStateKey("")) r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomJoinRules, joinRule, WithStateKey("")) r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomHistoryVisibility, hisVis, WithStateKey("")) + if r.guestCanJoin { + r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomGuestAccess, map[string]string{ + "guest_access": "can_join", + }, WithStateKey("")) + } } // Create an event in this room but do not insert it. Does not modify the room in any way (depth, fwd extremities, etc) so is thread-safe. @@ -268,3 +274,9 @@ func RoomVersion(ver gomatrixserverlib.RoomVersion) roomModifier { r.Version = ver } } + +func GuestsCanJoin(canJoin bool) roomModifier { + return func(t *testing.T, r *Room) { + r.guestCanJoin = canJoin + } +} diff --git a/userapi/api/api.go b/userapi/api/api.go index d3f5aefc8..4ea2e91c3 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -50,6 +50,7 @@ type KeyserverUserAPI interface { type RoomserverUserAPI interface { QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error + QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) (err error) } // api functions required by the media api @@ -671,3 +672,12 @@ type PerformSaveThreePIDAssociationRequest struct { ServerName gomatrixserverlib.ServerName Medium string } + +type QueryAccountByLocalpartRequest struct { + Localpart string + ServerName gomatrixserverlib.ServerName +} + +type QueryAccountByLocalpartResponse struct { + Account *Account +} diff --git a/userapi/api/api_trace.go b/userapi/api/api_trace.go index ce661770f..d10b5767b 100644 --- a/userapi/api/api_trace.go +++ b/userapi/api/api_trace.go @@ -204,6 +204,12 @@ func (t *UserInternalAPITrace) PerformSaveThreePIDAssociation(ctx context.Contex return err } +func (t *UserInternalAPITrace) QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) error { + err := t.Impl.QueryAccountByLocalpart(ctx, req, res) + util.GetLogger(ctx).Infof("QueryAccountByLocalpart req=%+v res=%+v", js(req), js(res)) + return err +} + func js(thing interface{}) string { b, err := json.Marshal(thing) if err != nil { diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 3f256457e..0bb480da6 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -548,6 +548,11 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc return nil } +func (a *UserInternalAPI) QueryAccountByLocalpart(ctx context.Context, req *api.QueryAccountByLocalpartRequest, res *api.QueryAccountByLocalpartResponse) (err error) { + res.Account, err = a.DB.GetAccountByLocalpart(ctx, req.Localpart, req.ServerName) + return +} + // Return the appservice 'device' or nil if the token is not an appservice. Returns an error if there was a problem // creating a 'device'. func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appServiceUserID string) (*api.Device, error) { diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go index 87ae058c2..51b0fe3ef 100644 --- a/userapi/inthttp/client.go +++ b/userapi/inthttp/client.go @@ -60,6 +60,7 @@ const ( QueryAccountByPasswordPath = "/userapi/queryAccountByPassword" QueryLocalpartForThreePIDPath = "/userapi/queryLocalpartForThreePID" QueryThreePIDsForLocalpartPath = "/userapi/queryThreePIDsForLocalpart" + QueryAccountByLocalpartPath = "/userapi/queryAccountType" ) // NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API. @@ -440,3 +441,14 @@ func (h *httpUserInternalAPI) PerformSaveThreePIDAssociation( h.httpClient, ctx, request, response, ) } + +func (h *httpUserInternalAPI) QueryAccountByLocalpart( + ctx context.Context, + req *api.QueryAccountByLocalpartRequest, + res *api.QueryAccountByLocalpartResponse, +) error { + return httputil.CallInternalRPCAPI( + "QueryAccountByLocalpart", h.apiURL+QueryAccountByLocalpartPath, + h.httpClient, ctx, req, res, + ) +} diff --git a/userapi/inthttp/server.go b/userapi/inthttp/server.go index f0579079f..b40b507c2 100644 --- a/userapi/inthttp/server.go +++ b/userapi/inthttp/server.go @@ -189,4 +189,9 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI, enableMetrics PerformSaveThreePIDAssociationPath, httputil.MakeInternalRPCAPI("UserAPIPerformSaveThreePIDAssociation", enableMetrics, s.PerformSaveThreePIDAssociation), ) + + internalAPIMux.Handle( + QueryAccountByLocalpartPath, + httputil.MakeInternalRPCAPI("AccountByLocalpart", enableMetrics, s.QueryAccountByLocalpart), + ) } diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index 8a19af195..dada56de4 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -307,3 +307,64 @@ func TestLoginToken(t *testing.T) { }) }) } + +func TestQueryAccountByLocalpart(t *testing.T) { + alice := test.NewUser(t) + + localpart, userServername, _ := gomatrixserverlib.SplitID('@', alice.ID) + + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + intAPI, db, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType) + defer close() + + createdAcc, err := db.CreateAccount(ctx, localpart, userServername, "", "", alice.AccountType) + if err != nil { + t.Error(err) + } + + testCases := func(t *testing.T, internalAPI api.UserInternalAPI) { + // Query existing account + queryAccResp := &api.QueryAccountByLocalpartResponse{} + if err = internalAPI.QueryAccountByLocalpart(ctx, &api.QueryAccountByLocalpartRequest{ + Localpart: localpart, + ServerName: userServername, + }, queryAccResp); err != nil { + t.Error(err) + } + if !reflect.DeepEqual(createdAcc, queryAccResp.Account) { + t.Fatalf("created and queried accounts don't match:\n%+v vs.\n%+v", createdAcc, queryAccResp.Account) + } + + // Query non-existent account, this should result in an error + err = internalAPI.QueryAccountByLocalpart(ctx, &api.QueryAccountByLocalpartRequest{ + Localpart: "doesnotexist", + ServerName: userServername, + }, queryAccResp) + + if err == nil { + t.Fatalf("expected an error, but got none: %+v", queryAccResp) + } + } + + t.Run("Monolith", func(t *testing.T) { + testCases(t, intAPI) + // also test tracing + testCases(t, &api.UserInternalAPITrace{Impl: intAPI}) + }) + + t.Run("HTTP API", func(t *testing.T) { + router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter() + userapi.AddInternalRoutes(router, intAPI, false) + apiURL, cancel := test.ListenAndServe(t, router, false) + defer cancel() + + userHTTPApi, err := inthttp.NewUserAPIClient(apiURL, &http.Client{Timeout: time.Second * 5}) + if err != nil { + t.Fatalf("failed to create HTTP client: %s", err) + } + testCases(t, userHTTPApi) + + }) + }) +} From f47515e38b0bbf734bf977daedd836bf85465272 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Fri, 23 Dec 2022 12:52:47 +0100 Subject: [PATCH 3/6] Pushrule tweaks, make `pattern` non-optional on `EventMatchCondition` (#2918) This should fix https://github.com/matrix-org/dendrite/issues/2882 (Tested with FluffyChat 1.7.1) Also adds tests that the predefined push rules (as per the spec) is what we have in Dendrite. --- internal/pushrules/condition.go | 2 +- internal/pushrules/default_content.go | 9 +- internal/pushrules/default_override.go | 54 +++++---- internal/pushrules/default_pushrules_test.go | 111 +++++++++++++++++++ internal/pushrules/default_underride.go | 39 ++----- internal/pushrules/evaluate.go | 10 +- internal/pushrules/evaluate_test.go | 51 +++++---- internal/pushrules/pushrules.go | 10 +- internal/pushrules/util.go | 4 + internal/pushrules/validate.go | 5 +- internal/pushrules/validate_test.go | 19 ++-- userapi/consumers/roomserver_test.go | 6 - 12 files changed, 210 insertions(+), 110 deletions(-) create mode 100644 internal/pushrules/default_pushrules_test.go diff --git a/internal/pushrules/condition.go b/internal/pushrules/condition.go index 2d9773c0f..c7b30da8e 100644 --- a/internal/pushrules/condition.go +++ b/internal/pushrules/condition.go @@ -14,7 +14,7 @@ type Condition struct { // Pattern indicates the value pattern that must match. Required // for EventMatchCondition. - Pattern string `json:"pattern,omitempty"` + Pattern *string `json:"pattern,omitempty"` // Is indicates the condition that must be fulfilled. Required for // RoomMemberCountCondition. diff --git a/internal/pushrules/default_content.go b/internal/pushrules/default_content.go index 8982dd587..a055ba03c 100644 --- a/internal/pushrules/default_content.go +++ b/internal/pushrules/default_content.go @@ -15,13 +15,7 @@ func mRuleContainsUserNameDefinition(localpart string) *Rule { RuleID: MRuleContainsUserName, Default: true, Enabled: true, - Pattern: localpart, - Conditions: []*Condition{ - { - Kind: EventMatchCondition, - Key: "content.body", - }, - }, + Pattern: &localpart, Actions: []*Action{ {Kind: NotifyAction}, { @@ -32,7 +26,6 @@ func mRuleContainsUserNameDefinition(localpart string) *Rule { { Kind: SetTweakAction, Tweak: HighlightTweak, - Value: true, }, }, } diff --git a/internal/pushrules/default_override.go b/internal/pushrules/default_override.go index a9788df2f..f97427b71 100644 --- a/internal/pushrules/default_override.go +++ b/internal/pushrules/default_override.go @@ -22,15 +22,15 @@ const ( MRuleTombstone = ".m.rule.tombstone" MRuleRoomNotif = ".m.rule.roomnotif" MRuleReaction = ".m.rule.reaction" + MRuleRoomACLs = ".m.rule.room.server_acl" ) var ( mRuleMasterDefinition = Rule{ - RuleID: MRuleMaster, - Default: true, - Enabled: false, - Conditions: []*Condition{}, - Actions: []*Action{{Kind: DontNotifyAction}}, + RuleID: MRuleMaster, + Default: true, + Enabled: false, + Actions: []*Action{{Kind: DontNotifyAction}}, } mRuleSuppressNoticesDefinition = Rule{ RuleID: MRuleSuppressNotices, @@ -40,7 +40,7 @@ var ( { Kind: EventMatchCondition, Key: "content.msgtype", - Pattern: "m.notice", + Pattern: pointer("m.notice"), }, }, Actions: []*Action{{Kind: DontNotifyAction}}, @@ -53,7 +53,7 @@ var ( { Kind: EventMatchCondition, Key: "type", - Pattern: "m.room.member", + Pattern: pointer("m.room.member"), }, }, Actions: []*Action{{Kind: DontNotifyAction}}, @@ -73,7 +73,6 @@ var ( { Kind: SetTweakAction, Tweak: HighlightTweak, - Value: true, }, }, } @@ -85,12 +84,12 @@ var ( { Kind: EventMatchCondition, Key: "type", - Pattern: "m.room.tombstone", + Pattern: pointer("m.room.tombstone"), }, { Kind: EventMatchCondition, Key: "state_key", - Pattern: "", + Pattern: pointer(""), }, }, Actions: []*Action{ @@ -98,10 +97,27 @@ var ( { Kind: SetTweakAction, Tweak: HighlightTweak, - Value: true, }, }, } + mRuleACLsDefinition = Rule{ + RuleID: MRuleRoomACLs, + Default: true, + Enabled: true, + Conditions: []*Condition{ + { + Kind: EventMatchCondition, + Key: "type", + Pattern: pointer("m.room.server_acl"), + }, + { + Kind: EventMatchCondition, + Key: "state_key", + Pattern: pointer(""), + }, + }, + Actions: []*Action{}, + } mRuleRoomNotifDefinition = Rule{ RuleID: MRuleRoomNotif, Default: true, @@ -110,7 +126,7 @@ var ( { Kind: EventMatchCondition, Key: "content.body", - Pattern: "@room", + Pattern: pointer("@room"), }, { Kind: SenderNotificationPermissionCondition, @@ -122,7 +138,6 @@ var ( { Kind: SetTweakAction, Tweak: HighlightTweak, - Value: true, }, }, } @@ -134,7 +149,7 @@ var ( { Kind: EventMatchCondition, Key: "type", - Pattern: "m.reaction", + Pattern: pointer("m.reaction"), }, }, Actions: []*Action{ @@ -152,17 +167,17 @@ func mRuleInviteForMeDefinition(userID string) *Rule { { Kind: EventMatchCondition, Key: "type", - Pattern: "m.room.member", + Pattern: pointer("m.room.member"), }, { Kind: EventMatchCondition, Key: "content.membership", - Pattern: "invite", + Pattern: pointer("invite"), }, { Kind: EventMatchCondition, Key: "state_key", - Pattern: userID, + Pattern: pointer(userID), }, }, Actions: []*Action{ @@ -172,11 +187,6 @@ func mRuleInviteForMeDefinition(userID string) *Rule { Tweak: SoundTweak, Value: "default", }, - { - Kind: SetTweakAction, - Tweak: HighlightTweak, - Value: false, - }, }, } } diff --git a/internal/pushrules/default_pushrules_test.go b/internal/pushrules/default_pushrules_test.go new file mode 100644 index 000000000..dea829842 --- /dev/null +++ b/internal/pushrules/default_pushrules_test.go @@ -0,0 +1,111 @@ +package pushrules + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" +) + +// Tests that the pre-defined rules as of +// https://spec.matrix.org/v1.4/client-server-api/#predefined-rules +// are correct +func TestDefaultRules(t *testing.T) { + type testCase struct { + name string + inputBytes []byte + want Rule + } + + testCases := []testCase{ + // Default override rules + { + name: ".m.rule.master", + inputBytes: []byte(`{"rule_id":".m.rule.master","default":true,"enabled":false,"actions":["dont_notify"]}`), + want: mRuleMasterDefinition, + }, + { + name: ".m.rule.suppress_notices", + inputBytes: []byte(`{"rule_id":".m.rule.suppress_notices","default":true,"enabled":true,"conditions":[{"kind":"event_match","key":"content.msgtype","pattern":"m.notice"}],"actions":["dont_notify"]}`), + want: mRuleSuppressNoticesDefinition, + }, + { + name: ".m.rule.invite_for_me", + inputBytes: []byte(`{"rule_id":".m.rule.invite_for_me","default":true,"enabled":true,"conditions":[{"kind":"event_match","key":"type","pattern":"m.room.member"},{"kind":"event_match","key":"content.membership","pattern":"invite"},{"kind":"event_match","key":"state_key","pattern":"@test:localhost"}],"actions":["notify",{"set_tweak":"sound","value":"default"}]}`), + want: *mRuleInviteForMeDefinition("@test:localhost"), + }, + { + name: ".m.rule.member_event", + inputBytes: []byte(`{"rule_id":".m.rule.member_event","default":true,"enabled":true,"conditions":[{"kind":"event_match","key":"type","pattern":"m.room.member"}],"actions":["dont_notify"]}`), + want: mRuleMemberEventDefinition, + }, + { + name: ".m.rule.contains_display_name", + inputBytes: []byte(`{"rule_id":".m.rule.contains_display_name","default":true,"enabled":true,"conditions":[{"kind":"contains_display_name"}],"actions":["notify",{"set_tweak":"sound","value":"default"},{"set_tweak":"highlight"}]}`), + want: mRuleContainsDisplayNameDefinition, + }, + { + name: ".m.rule.tombstone", + inputBytes: []byte(`{"rule_id":".m.rule.tombstone","default":true,"enabled":true,"conditions":[{"kind":"event_match","key":"type","pattern":"m.room.tombstone"},{"kind":"event_match","key":"state_key","pattern":""}],"actions":["notify",{"set_tweak":"highlight"}]}`), + want: mRuleTombstoneDefinition, + }, + { + name: ".m.rule.room.server_acl", + inputBytes: []byte(`{"rule_id":".m.rule.room.server_acl","default":true,"enabled":true,"conditions":[{"kind":"event_match","key":"type","pattern":"m.room.server_acl"},{"kind":"event_match","key":"state_key","pattern":""}],"actions":[]}`), + want: mRuleACLsDefinition, + }, + { + name: ".m.rule.roomnotif", + inputBytes: []byte(`{"rule_id":".m.rule.roomnotif","default":true,"enabled":true,"conditions":[{"kind":"event_match","key":"content.body","pattern":"@room"},{"kind":"sender_notification_permission","key":"room"}],"actions":["notify",{"set_tweak":"highlight"}]}`), + want: mRuleRoomNotifDefinition, + }, + // Default content rules + { + name: ".m.rule.contains_user_name", + inputBytes: []byte(`{"rule_id":".m.rule.contains_user_name","default":true,"enabled":true,"actions":["notify",{"set_tweak":"sound","value":"default"},{"set_tweak":"highlight"}],"pattern":"myLocalUser"}`), + want: *mRuleContainsUserNameDefinition("myLocalUser"), + }, + // default underride rules + { + name: ".m.rule.call", + inputBytes: []byte(`{"rule_id":".m.rule.call","default":true,"enabled":true,"conditions":[{"kind":"event_match","key":"type","pattern":"m.call.invite"}],"actions":["notify",{"set_tweak":"sound","value":"ring"}]}`), + want: mRuleCallDefinition, + }, + { + name: ".m.rule.encrypted_room_one_to_one", + inputBytes: []byte(`{"rule_id":".m.rule.encrypted_room_one_to_one","default":true,"enabled":true,"conditions":[{"kind":"room_member_count","is":"2"},{"kind":"event_match","key":"type","pattern":"m.room.encrypted"}],"actions":["notify",{"set_tweak":"sound","value":"default"}]}`), + want: mRuleEncryptedRoomOneToOneDefinition, + }, + { + name: ".m.rule.room_one_to_one", + inputBytes: []byte(`{"rule_id":".m.rule.room_one_to_one","default":true,"enabled":true,"conditions":[{"kind":"room_member_count","is":"2"},{"kind":"event_match","key":"type","pattern":"m.room.message"}],"actions":["notify",{"set_tweak":"sound","value":"default"}]}`), + want: mRuleRoomOneToOneDefinition, + }, + { + name: ".m.rule.message", + inputBytes: []byte(`{"rule_id":".m.rule.message","default":true,"enabled":true,"conditions":[{"kind":"event_match","key":"type","pattern":"m.room.message"}],"actions":["notify"]}`), + want: mRuleMessageDefinition, + }, + { + name: ".m.rule.encrypted", + inputBytes: []byte(`{"rule_id":".m.rule.encrypted","default":true,"enabled":true,"conditions":[{"kind":"event_match","key":"type","pattern":"m.room.encrypted"}],"actions":["notify"]}`), + want: mRuleEncryptedDefinition, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + r := Rule{} + // unmarshal predefined push rules + err := json.Unmarshal(tc.inputBytes, &r) + assert.NoError(t, err) + assert.Equal(t, tc.want, r) + + // and reverse it to check we get the expected result + got, err := json.Marshal(r) + assert.NoError(t, err) + assert.Equal(t, string(got), string(tc.inputBytes)) + }) + + } +} diff --git a/internal/pushrules/default_underride.go b/internal/pushrules/default_underride.go index 8da449a19..118bfae59 100644 --- a/internal/pushrules/default_underride.go +++ b/internal/pushrules/default_underride.go @@ -25,7 +25,7 @@ var ( { Kind: EventMatchCondition, Key: "type", - Pattern: "m.call.invite", + Pattern: pointer("m.call.invite"), }, }, Actions: []*Action{ @@ -35,11 +35,6 @@ var ( Tweak: SoundTweak, Value: "ring", }, - { - Kind: SetTweakAction, - Tweak: HighlightTweak, - Value: false, - }, }, } mRuleEncryptedRoomOneToOneDefinition = Rule{ @@ -54,7 +49,7 @@ var ( { Kind: EventMatchCondition, Key: "type", - Pattern: "m.room.encrypted", + Pattern: pointer("m.room.encrypted"), }, }, Actions: []*Action{ @@ -64,11 +59,6 @@ var ( Tweak: SoundTweak, Value: "default", }, - { - Kind: SetTweakAction, - Tweak: HighlightTweak, - Value: false, - }, }, } mRuleRoomOneToOneDefinition = Rule{ @@ -83,20 +73,15 @@ var ( { Kind: EventMatchCondition, Key: "type", - Pattern: "m.room.message", + Pattern: pointer("m.room.message"), }, }, Actions: []*Action{ {Kind: NotifyAction}, { Kind: SetTweakAction, - Tweak: HighlightTweak, - Value: false, - }, - { - Kind: SetTweakAction, - Tweak: HighlightTweak, - Value: false, + Tweak: SoundTweak, + Value: "default", }, }, } @@ -108,16 +93,11 @@ var ( { Kind: EventMatchCondition, Key: "type", - Pattern: "m.room.message", + Pattern: pointer("m.room.message"), }, }, Actions: []*Action{ {Kind: NotifyAction}, - { - Kind: SetTweakAction, - Tweak: HighlightTweak, - Value: false, - }, }, } mRuleEncryptedDefinition = Rule{ @@ -128,16 +108,11 @@ var ( { Kind: EventMatchCondition, Key: "type", - Pattern: "m.room.encrypted", + Pattern: pointer("m.room.encrypted"), }, }, Actions: []*Action{ {Kind: NotifyAction}, - { - Kind: SetTweakAction, - Tweak: HighlightTweak, - Value: false, - }, }, } ) diff --git a/internal/pushrules/evaluate.go b/internal/pushrules/evaluate.go index 4ff9939a6..fc8e0f174 100644 --- a/internal/pushrules/evaluate.go +++ b/internal/pushrules/evaluate.go @@ -104,7 +104,10 @@ func ruleMatches(rule *Rule, kind Kind, event *gomatrixserverlib.Event, ec Evalu case ContentKind: // TODO: "These configure behaviour for (unencrypted) messages // that match certain patterns." - Does that mean "content.body"? - return patternMatches("content.body", rule.Pattern, event) + if rule.Pattern == nil { + return false, nil + } + return patternMatches("content.body", *rule.Pattern, event) case RoomKind: return rule.RuleID == event.RoomID(), nil @@ -120,7 +123,10 @@ func ruleMatches(rule *Rule, kind Kind, event *gomatrixserverlib.Event, ec Evalu func conditionMatches(cond *Condition, event *gomatrixserverlib.Event, ec EvaluationContext) (bool, error) { switch cond.Kind { case EventMatchCondition: - return patternMatches(cond.Key, cond.Pattern, event) + if cond.Pattern == nil { + return false, fmt.Errorf("missing condition pattern") + } + return patternMatches(cond.Key, *cond.Pattern, event) case ContainsDisplayNameCondition: return patternMatches("content.body", ec.UserDisplayName(), event) diff --git a/internal/pushrules/evaluate_test.go b/internal/pushrules/evaluate_test.go index c5d5abd2a..ca8ae5519 100644 --- a/internal/pushrules/evaluate_test.go +++ b/internal/pushrules/evaluate_test.go @@ -79,8 +79,8 @@ func TestRuleMatches(t *testing.T) { {"underrideConditionMatch", UnderrideKind, Rule{Enabled: true}, `{}`, true}, {"underrideConditionNoMatch", UnderrideKind, Rule{Enabled: true, Conditions: []*Condition{{}}}, `{}`, false}, - {"contentMatch", ContentKind, Rule{Enabled: true, Pattern: "b"}, `{"content":{"body":"abc"}}`, true}, - {"contentNoMatch", ContentKind, Rule{Enabled: true, Pattern: "d"}, `{"content":{"body":"abc"}}`, false}, + {"contentMatch", ContentKind, Rule{Enabled: true, Pattern: pointer("b")}, `{"content":{"body":"abc"}}`, true}, + {"contentNoMatch", ContentKind, Rule{Enabled: true, Pattern: pointer("d")}, `{"content":{"body":"abc"}}`, false}, {"roomMatch", RoomKind, Rule{Enabled: true, RuleID: "!room@example.com"}, `{"room_id":"!room@example.com"}`, true}, {"roomNoMatch", RoomKind, Rule{Enabled: true, RuleID: "!room@example.com"}, `{"room_id":"!otherroom@example.com"}`, false}, @@ -106,41 +106,44 @@ func TestConditionMatches(t *testing.T) { Name string Cond Condition EventJSON string - Want bool + WantMatch bool + WantErr bool }{ - {"empty", Condition{}, `{}`, false}, - {"empty", Condition{Kind: "unknownstring"}, `{}`, false}, + {Name: "empty", Cond: Condition{}, EventJSON: `{}`, WantMatch: false, WantErr: false}, + {Name: "empty", Cond: Condition{Kind: "unknownstring"}, EventJSON: `{}`, WantMatch: false, WantErr: false}, // Neither of these should match because `content` is not a full string match, // and `content.body` is not a string value. - {"eventMatch", Condition{Kind: EventMatchCondition, Key: "content"}, `{"content":{}}`, false}, - {"eventBodyMatch", Condition{Kind: EventMatchCondition, Key: "content.body", Is: "3"}, `{"content":{"body": 3}}`, false}, + {Name: "eventMatch", Cond: Condition{Kind: EventMatchCondition, Key: "content", Pattern: pointer("")}, EventJSON: `{"content":{}}`, WantMatch: false, WantErr: false}, + {Name: "eventBodyMatch", Cond: Condition{Kind: EventMatchCondition, Key: "content.body", Is: "3", Pattern: pointer("")}, EventJSON: `{"content":{"body": "3"}}`, WantMatch: false, WantErr: false}, + {Name: "eventBodyMatch matches", Cond: Condition{Kind: EventMatchCondition, Key: "content.body", Pattern: pointer("world")}, EventJSON: `{"content":{"body": "hello world!"}}`, WantMatch: true, WantErr: false}, + {Name: "EventMatch missing pattern", Cond: Condition{Kind: EventMatchCondition, Key: "content.body"}, EventJSON: `{"content":{"body": "hello world!"}}`, WantMatch: false, WantErr: true}, - {"displayNameNoMatch", Condition{Kind: ContainsDisplayNameCondition}, `{"content":{"body":"something without displayname"}}`, false}, - {"displayNameMatch", Condition{Kind: ContainsDisplayNameCondition}, `{"content":{"body":"hello Dear User, how are you?"}}`, true}, + {Name: "displayNameNoMatch", Cond: Condition{Kind: ContainsDisplayNameCondition}, EventJSON: `{"content":{"body":"something without displayname"}}`, WantMatch: false, WantErr: false}, + {Name: "displayNameMatch", Cond: Condition{Kind: ContainsDisplayNameCondition}, EventJSON: `{"content":{"body":"hello Dear User, how are you?"}}`, WantMatch: true, WantErr: false}, - {"roomMemberCountLessNoMatch", Condition{Kind: RoomMemberCountCondition, Is: "<2"}, `{}`, false}, - {"roomMemberCountLessMatch", Condition{Kind: RoomMemberCountCondition, Is: "<3"}, `{}`, true}, - {"roomMemberCountLessEqualNoMatch", Condition{Kind: RoomMemberCountCondition, Is: "<=1"}, `{}`, false}, - {"roomMemberCountLessEqualMatch", Condition{Kind: RoomMemberCountCondition, Is: "<=2"}, `{}`, true}, - {"roomMemberCountEqualNoMatch", Condition{Kind: RoomMemberCountCondition, Is: "==1"}, `{}`, false}, - {"roomMemberCountEqualMatch", Condition{Kind: RoomMemberCountCondition, Is: "==2"}, `{}`, true}, - {"roomMemberCountGreaterEqualNoMatch", Condition{Kind: RoomMemberCountCondition, Is: ">=3"}, `{}`, false}, - {"roomMemberCountGreaterEqualMatch", Condition{Kind: RoomMemberCountCondition, Is: ">=2"}, `{}`, true}, - {"roomMemberCountGreaterNoMatch", Condition{Kind: RoomMemberCountCondition, Is: ">2"}, `{}`, false}, - {"roomMemberCountGreaterMatch", Condition{Kind: RoomMemberCountCondition, Is: ">1"}, `{}`, true}, + {Name: "roomMemberCountLessNoMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: "<2"}, EventJSON: `{}`, WantMatch: false, WantErr: false}, + {Name: "roomMemberCountLessMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: "<3"}, EventJSON: `{}`, WantMatch: true, WantErr: false}, + {Name: "roomMemberCountLessEqualNoMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: "<=1"}, EventJSON: `{}`, WantMatch: false, WantErr: false}, + {Name: "roomMemberCountLessEqualMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: "<=2"}, EventJSON: `{}`, WantMatch: true, WantErr: false}, + {Name: "roomMemberCountEqualNoMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: "==1"}, EventJSON: `{}`, WantMatch: false, WantErr: false}, + {Name: "roomMemberCountEqualMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: "==2"}, EventJSON: `{}`, WantMatch: true, WantErr: false}, + {Name: "roomMemberCountGreaterEqualNoMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: ">=3"}, EventJSON: `{}`, WantMatch: false, WantErr: false}, + {Name: "roomMemberCountGreaterEqualMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: ">=2"}, EventJSON: `{}`, WantMatch: true, WantErr: false}, + {Name: "roomMemberCountGreaterNoMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: ">2"}, EventJSON: `{}`, WantMatch: false, WantErr: false}, + {Name: "roomMemberCountGreaterMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: ">1"}, EventJSON: `{}`, WantMatch: true, WantErr: false}, - {"senderNotificationPermissionMatch", Condition{Kind: SenderNotificationPermissionCondition, Key: "powerlevel"}, `{"sender":"@poweruser:example.com"}`, true}, - {"senderNotificationPermissionNoMatch", Condition{Kind: SenderNotificationPermissionCondition, Key: "powerlevel"}, `{"sender":"@nobody:example.com"}`, false}, + {Name: "senderNotificationPermissionMatch", Cond: Condition{Kind: SenderNotificationPermissionCondition, Key: "powerlevel"}, EventJSON: `{"sender":"@poweruser:example.com"}`, WantMatch: true, WantErr: false}, + {Name: "senderNotificationPermissionNoMatch", Cond: Condition{Kind: SenderNotificationPermissionCondition, Key: "powerlevel"}, EventJSON: `{"sender":"@nobody:example.com"}`, WantMatch: false, WantErr: false}, } for _, tst := range tsts { t.Run(tst.Name, func(t *testing.T) { got, err := conditionMatches(&tst.Cond, mustEventFromJSON(t, tst.EventJSON), &fakeEvaluationContext{2}) - if err != nil { + if err != nil && !tst.WantErr { t.Fatalf("conditionMatches failed: %v", err) } - if got != tst.Want { - t.Errorf("conditionMatches: got %v, want %v on %s", got, tst.Want, tst.Name) + if got != tst.WantMatch { + t.Errorf("conditionMatches: got %v, want %v on %s", got, tst.WantMatch, tst.Name) } }) } diff --git a/internal/pushrules/pushrules.go b/internal/pushrules/pushrules.go index bbed1f95f..98deaf132 100644 --- a/internal/pushrules/pushrules.go +++ b/internal/pushrules/pushrules.go @@ -36,18 +36,18 @@ type Rule struct { // around. Required. Enabled bool `json:"enabled"` + // Conditions provide the rule's conditions for OverrideKind and + // UnderrideKind. Not allowed for other kinds. + Conditions []*Condition `json:"conditions,omitempty"` + // Actions describe the desired outcome, should the rule // match. Required. Actions []*Action `json:"actions"` - // Conditions provide the rule's conditions for OverrideKind and - // UnderrideKind. Not allowed for other kinds. - Conditions []*Condition `json:"conditions"` - // Pattern is the body pattern to match for ContentKind. Required // for that kind. The interpretation is the same as that of // Condition.Pattern. - Pattern string `json:"pattern"` + Pattern *string `json:"pattern,omitempty"` } // Scope only has one valid value. See also AccountRuleSets. diff --git a/internal/pushrules/util.go b/internal/pushrules/util.go index fb9c05be2..de8fe5cd0 100644 --- a/internal/pushrules/util.go +++ b/internal/pushrules/util.go @@ -128,3 +128,7 @@ func parseRoomMemberCountCondition(s string) (func(int) bool, error) { b = int(v) return cmp, nil } + +func pointer[t any](s t) *t { + return &s +} diff --git a/internal/pushrules/validate.go b/internal/pushrules/validate.go index 5d260f0b9..f50c51bd7 100644 --- a/internal/pushrules/validate.go +++ b/internal/pushrules/validate.go @@ -34,7 +34,10 @@ func ValidateRule(kind Kind, rule *Rule) []error { } case ContentKind: - if rule.Pattern == "" { + if rule.Pattern == nil { + errs = append(errs, fmt.Errorf("missing content rule pattern")) + } + if rule.Pattern != nil && *rule.Pattern == "" { errs = append(errs, fmt.Errorf("missing content rule pattern")) } diff --git a/internal/pushrules/validate_test.go b/internal/pushrules/validate_test.go index b276eb551..966e46259 100644 --- a/internal/pushrules/validate_test.go +++ b/internal/pushrules/validate_test.go @@ -12,15 +12,16 @@ func TestValidateRuleNegatives(t *testing.T) { Rule Rule WantErrString string }{ - {"emptyRuleID", OverrideKind, Rule{}, "invalid rule ID"}, - {"invalidKind", Kind("something else"), Rule{}, "invalid rule kind"}, - {"ruleIDBackslash", OverrideKind, Rule{RuleID: "#foo\\:example.com"}, "invalid rule ID"}, - {"noActions", OverrideKind, Rule{}, "missing actions"}, - {"invalidAction", OverrideKind, Rule{Actions: []*Action{{}}}, "invalid rule action kind"}, - {"invalidCondition", OverrideKind, Rule{Conditions: []*Condition{{}}}, "invalid rule condition kind"}, - {"overrideNoCondition", OverrideKind, Rule{}, "missing rule conditions"}, - {"underrideNoCondition", UnderrideKind, Rule{}, "missing rule conditions"}, - {"contentNoPattern", ContentKind, Rule{}, "missing content rule pattern"}, + {Name: "emptyRuleID", Kind: OverrideKind, Rule: Rule{}, WantErrString: "invalid rule ID"}, + {Name: "invalidKind", Kind: Kind("something else"), Rule: Rule{}, WantErrString: "invalid rule kind"}, + {Name: "ruleIDBackslash", Kind: OverrideKind, Rule: Rule{RuleID: "#foo\\:example.com"}, WantErrString: "invalid rule ID"}, + {Name: "noActions", Kind: OverrideKind, Rule: Rule{}, WantErrString: "missing actions"}, + {Name: "invalidAction", Kind: OverrideKind, Rule: Rule{Actions: []*Action{{}}}, WantErrString: "invalid rule action kind"}, + {Name: "invalidCondition", Kind: OverrideKind, Rule: Rule{Conditions: []*Condition{{}}}, WantErrString: "invalid rule condition kind"}, + {Name: "overrideNoCondition", Kind: OverrideKind, Rule: Rule{}, WantErrString: "missing rule conditions"}, + {Name: "underrideNoCondition", Kind: UnderrideKind, Rule: Rule{}, WantErrString: "missing rule conditions"}, + {Name: "contentNoPattern", Kind: ContentKind, Rule: Rule{}, WantErrString: "missing content rule pattern"}, + {Name: "contentEmptyPattern", Kind: ContentKind, Rule: Rule{Pattern: pointer("")}, WantErrString: "missing content rule pattern"}, } for _, tst := range tsts { t.Run(tst.Name, func(t *testing.T) { diff --git a/userapi/consumers/roomserver_test.go b/userapi/consumers/roomserver_test.go index 265e3a3aa..39f4aab4a 100644 --- a/userapi/consumers/roomserver_test.go +++ b/userapi/consumers/roomserver_test.go @@ -81,11 +81,6 @@ func Test_evaluatePushRules(t *testing.T) { wantAction: pushrules.NotifyAction, wantActions: []*pushrules.Action{ {Kind: pushrules.NotifyAction}, - { - Kind: pushrules.SetTweakAction, - Tweak: pushrules.HighlightTweak, - Value: false, - }, }, }, { @@ -103,7 +98,6 @@ func Test_evaluatePushRules(t *testing.T) { { Kind: pushrules.SetTweakAction, Tweak: pushrules.HighlightTweak, - Value: true, }, }, }, From f762ce1050f2add409a83b1eeb6da5940177cfa7 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Fri, 23 Dec 2022 14:11:11 +0100 Subject: [PATCH 4/6] Add clientapi tests (#2916) This PR - adds several tests for the clientapi, mostly around `/register` and auth fallback. - removes the now deprecated `homeserver` field from responses to `/register` and `/login` - slightly refactors auth fallback handling --- .github/workflows/dendrite.yml | 3 +- clientapi/routing/admin.go | 6 +- clientapi/routing/auth_fallback.go | 115 ++++----- clientapi/routing/auth_fallback_test.go | 149 ++++++++++++ clientapi/routing/login.go | 9 +- clientapi/routing/password.go | 4 +- clientapi/routing/register.go | 148 ++++-------- clientapi/routing/register_test.go | 306 ++++++++++++++++++++++++ clientapi/routing/routing.go | 4 +- cmd/create-account/main.go | 32 +-- internal/httputil/httpapi.go | 9 +- internal/validate.go | 84 ++++++- internal/validate_test.go | 170 +++++++++++++ setup/config/config.go | 12 +- setup/config/config_clientapi.go | 7 +- 15 files changed, 838 insertions(+), 220 deletions(-) create mode 100644 clientapi/routing/auth_fallback_test.go create mode 100644 internal/validate_test.go diff --git a/.github/workflows/dendrite.yml b/.github/workflows/dendrite.yml index 2c04005d2..1de39850d 100644 --- a/.github/workflows/dendrite.yml +++ b/.github/workflows/dendrite.yml @@ -331,8 +331,7 @@ jobs: postgres: postgres api: full-http container: - # Temporary for debugging to see if this image is working better. - image: matrixdotorg/sytest-dendrite@sha256:434ad464a9f4ed3f8c3cc47200275b6ccb5c5031a8063daf4acea62be5a23c73 + image: matrixdotorg/sytest-dendrite volumes: - ${{ github.workspace }}:/src - /root/.cache/go-build:/github/home/.cache/go-build diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index 8419622df..dbd913376 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -137,7 +137,7 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap request := struct { Password string `json:"password"` }{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + if err = json.NewDecoder(req.Body).Decode(&request); err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: jsonerror.Unknown("Failed to decode request body: " + err.Error()), @@ -150,8 +150,8 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap } } - if resErr := internal.ValidatePassword(request.Password); resErr != nil { - return *resErr + if err = internal.ValidatePassword(request.Password); err != nil { + return *internal.PasswordResponse(err) } updateReq := &userapi.PerformPasswordUpdateRequest{ diff --git a/clientapi/routing/auth_fallback.go b/clientapi/routing/auth_fallback.go index ad870993e..f8d3684fe 100644 --- a/clientapi/routing/auth_fallback.go +++ b/clientapi/routing/auth_fallback.go @@ -15,11 +15,11 @@ package routing import ( + "fmt" "html/template" "net/http" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/util" ) @@ -101,14 +101,28 @@ func serveTemplate(w http.ResponseWriter, templateHTML string, data map[string]s func AuthFallback( w http.ResponseWriter, req *http.Request, authType string, cfg *config.ClientAPI, -) *util.JSONResponse { - sessionID := req.URL.Query().Get("session") +) { + // We currently only support "m.login.recaptcha", so fail early if that's not requested + if authType == authtypes.LoginTypeRecaptcha { + if !cfg.RecaptchaEnabled { + writeHTTPMessage(w, req, + "Recaptcha login is disabled on this Homeserver", + http.StatusBadRequest, + ) + return + } + } else { + writeHTTPMessage(w, req, fmt.Sprintf("Unknown authtype %q", authType), http.StatusNotImplemented) + return + } + sessionID := req.URL.Query().Get("session") if sessionID == "" { - return writeHTTPMessage(w, req, + writeHTTPMessage(w, req, "Session ID not provided", http.StatusBadRequest, ) + return } serveRecaptcha := func() { @@ -130,70 +144,44 @@ func AuthFallback( if req.Method == http.MethodGet { // Handle Recaptcha - if authType == authtypes.LoginTypeRecaptcha { - if err := checkRecaptchaEnabled(cfg, w, req); err != nil { - return err - } - - serveRecaptcha() - return nil - } - return &util.JSONResponse{ - Code: http.StatusNotFound, - JSON: jsonerror.NotFound("Unknown auth stage type"), - } + serveRecaptcha() + return } else if req.Method == http.MethodPost { // Handle Recaptcha - if authType == authtypes.LoginTypeRecaptcha { - if err := checkRecaptchaEnabled(cfg, w, req); err != nil { - return err - } - - clientIP := req.RemoteAddr - err := req.ParseForm() - if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("req.ParseForm failed") - res := jsonerror.InternalServerError() - return &res - } - - response := req.Form.Get(cfg.RecaptchaFormField) - if err := validateRecaptcha(cfg, response, clientIP); err != nil { - util.GetLogger(req.Context()).Error(err) - return err - } - - // Success. Add recaptcha as a completed login flow - sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha) - - serveSuccess() - return nil + clientIP := req.RemoteAddr + err := req.ParseForm() + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("req.ParseForm failed") + w.WriteHeader(http.StatusBadRequest) + serveRecaptcha() + return } - return &util.JSONResponse{ - Code: http.StatusNotFound, - JSON: jsonerror.NotFound("Unknown auth stage type"), + response := req.Form.Get(cfg.RecaptchaFormField) + err = validateRecaptcha(cfg, response, clientIP) + switch err { + case ErrMissingResponse: + w.WriteHeader(http.StatusBadRequest) + serveRecaptcha() // serve the initial page again, instead of nothing + return + case ErrInvalidCaptcha: + w.WriteHeader(http.StatusUnauthorized) + serveRecaptcha() + return + case nil: + default: // something else failed + util.GetLogger(req.Context()).WithError(err).Error("failed to validate recaptcha") + serveRecaptcha() + return } - } - return &util.JSONResponse{ - Code: http.StatusMethodNotAllowed, - JSON: jsonerror.NotFound("Bad method"), - } -} -// checkRecaptchaEnabled creates an error response if recaptcha is not usable on homeserver. -func checkRecaptchaEnabled( - cfg *config.ClientAPI, - w http.ResponseWriter, - req *http.Request, -) *util.JSONResponse { - if !cfg.RecaptchaEnabled { - return writeHTTPMessage(w, req, - "Recaptcha login is disabled on this Homeserver", - http.StatusBadRequest, - ) + // Success. Add recaptcha as a completed login flow + sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha) + + serveSuccess() + return } - return nil + writeHTTPMessage(w, req, "Bad method", http.StatusMethodNotAllowed) } // writeHTTPMessage writes the given header and message to the HTTP response writer. @@ -201,13 +189,10 @@ func checkRecaptchaEnabled( func writeHTTPMessage( w http.ResponseWriter, req *http.Request, message string, header int, -) *util.JSONResponse { +) { w.WriteHeader(header) _, err := w.Write([]byte(message)) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("w.Write failed") - res := jsonerror.InternalServerError() - return &res } - return nil } diff --git a/clientapi/routing/auth_fallback_test.go b/clientapi/routing/auth_fallback_test.go new file mode 100644 index 000000000..0d77f9a01 --- /dev/null +++ b/clientapi/routing/auth_fallback_test.go @@ -0,0 +1,149 @@ +package routing + +import ( + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test/testrig" +) + +func Test_AuthFallback(t *testing.T) { + base, _, _ := testrig.Base(nil) + defer base.Close() + + for _, useHCaptcha := range []bool{false, true} { + for _, recaptchaEnabled := range []bool{false, true} { + for _, wantErr := range []bool{false, true} { + t.Run(fmt.Sprintf("useHCaptcha(%v) - recaptchaEnabled(%v) - wantErr(%v)", useHCaptcha, recaptchaEnabled, wantErr), func(t *testing.T) { + // Set the defaults for each test + base.Cfg.ClientAPI.Defaults(config.DefaultOpts{Generate: true, Monolithic: true}) + base.Cfg.ClientAPI.RecaptchaEnabled = recaptchaEnabled + base.Cfg.ClientAPI.RecaptchaPublicKey = "pub" + base.Cfg.ClientAPI.RecaptchaPrivateKey = "priv" + if useHCaptcha { + base.Cfg.ClientAPI.RecaptchaSiteVerifyAPI = "https://hcaptcha.com/siteverify" + base.Cfg.ClientAPI.RecaptchaApiJsUrl = "https://js.hcaptcha.com/1/api.js" + base.Cfg.ClientAPI.RecaptchaFormField = "h-captcha-response" + base.Cfg.ClientAPI.RecaptchaSitekeyClass = "h-captcha" + } + cfgErrs := &config.ConfigErrors{} + base.Cfg.ClientAPI.Verify(cfgErrs, true) + if len(*cfgErrs) > 0 { + t.Fatalf("(hCaptcha=%v) unexpected config errors: %s", useHCaptcha, cfgErrs.Error()) + } + + req := httptest.NewRequest(http.MethodGet, "/?session=1337", nil) + rec := httptest.NewRecorder() + + AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &base.Cfg.ClientAPI) + if !recaptchaEnabled { + if rec.Code != http.StatusBadRequest { + t.Fatalf("unexpected response code: %d, want %d", rec.Code, http.StatusBadRequest) + } + if rec.Body.String() != "Recaptcha login is disabled on this Homeserver" { + t.Fatalf("unexpected response body: %s", rec.Body.String()) + } + } else { + if !strings.Contains(rec.Body.String(), base.Cfg.ClientAPI.RecaptchaSitekeyClass) { + t.Fatalf("body does not contain %s: %s", base.Cfg.ClientAPI.RecaptchaSitekeyClass, rec.Body.String()) + } + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if wantErr { + _, _ = w.Write([]byte(`{"success":false}`)) + return + } + _, _ = w.Write([]byte(`{"success":true}`)) + })) + defer srv.Close() // nolint: errcheck + + base.Cfg.ClientAPI.RecaptchaSiteVerifyAPI = srv.URL + + // check the result after sending the captcha + req = httptest.NewRequest(http.MethodPost, "/?session=1337", nil) + req.Form = url.Values{} + req.Form.Add(base.Cfg.ClientAPI.RecaptchaFormField, "someRandomValue") + rec = httptest.NewRecorder() + AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &base.Cfg.ClientAPI) + if recaptchaEnabled { + if !wantErr { + if rec.Code != http.StatusOK { + t.Fatalf("unexpected response code: %d, want %d", rec.Code, http.StatusOK) + } + if rec.Body.String() != successTemplate { + t.Fatalf("unexpected response: %s, want %s", rec.Body.String(), successTemplate) + } + } else { + if rec.Code != http.StatusUnauthorized { + t.Fatalf("unexpected response code: %d, want %d", rec.Code, http.StatusUnauthorized) + } + wantString := "Authentication" + if !strings.Contains(rec.Body.String(), wantString) { + t.Fatalf("expected response to contain '%s', but didn't: %s", wantString, rec.Body.String()) + } + } + } else { + if rec.Code != http.StatusBadRequest { + t.Fatalf("unexpected response code: %d, want %d", rec.Code, http.StatusBadRequest) + } + if rec.Body.String() != "Recaptcha login is disabled on this Homeserver" { + t.Fatalf("unexpected response: %s, want %s", rec.Body.String(), "successTemplate") + } + } + }) + } + } + } + + t.Run("unknown fallbacks are handled correctly", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/?session=1337", nil) + rec := httptest.NewRecorder() + AuthFallback(rec, req, "DoesNotExist", &base.Cfg.ClientAPI) + if rec.Code != http.StatusNotImplemented { + t.Fatalf("unexpected http status: %d, want %d", rec.Code, http.StatusNotImplemented) + } + }) + + t.Run("unknown methods are handled correctly", func(t *testing.T) { + req := httptest.NewRequest(http.MethodDelete, "/?session=1337", nil) + rec := httptest.NewRecorder() + AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &base.Cfg.ClientAPI) + if rec.Code != http.StatusMethodNotAllowed { + t.Fatalf("unexpected http status: %d, want %d", rec.Code, http.StatusMethodNotAllowed) + } + }) + + t.Run("missing session parameter is handled correctly", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &base.Cfg.ClientAPI) + if rec.Code != http.StatusBadRequest { + t.Fatalf("unexpected http status: %d, want %d", rec.Code, http.StatusBadRequest) + } + }) + + t.Run("missing session parameter is handled correctly", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &base.Cfg.ClientAPI) + if rec.Code != http.StatusBadRequest { + t.Fatalf("unexpected http status: %d, want %d", rec.Code, http.StatusBadRequest) + } + }) + + t.Run("missing 'response' is handled correctly", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/?session=1337", nil) + rec := httptest.NewRecorder() + AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &base.Cfg.ClientAPI) + if rec.Code != http.StatusBadRequest { + t.Fatalf("unexpected http status: %d, want %d", rec.Code, http.StatusBadRequest) + } + }) +} diff --git a/clientapi/routing/login.go b/clientapi/routing/login.go index 0de324da1..778c8c0c3 100644 --- a/clientapi/routing/login.go +++ b/clientapi/routing/login.go @@ -23,15 +23,13 @@ import ( "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) type loginResponse struct { - UserID string `json:"user_id"` - AccessToken string `json:"access_token"` - HomeServer gomatrixserverlib.ServerName `json:"home_server"` - DeviceID string `json:"device_id"` + UserID string `json:"user_id"` + AccessToken string `json:"access_token"` + DeviceID string `json:"device_id"` } type flows struct { @@ -116,7 +114,6 @@ func completeAuth( JSON: loginResponse{ UserID: performRes.Device.UserID, AccessToken: performRes.Device.AccessToken, - HomeServer: serverName, DeviceID: performRes.Device.ID, }, } diff --git a/clientapi/routing/password.go b/clientapi/routing/password.go index cd88b025a..f7f9da622 100644 --- a/clientapi/routing/password.go +++ b/clientapi/routing/password.go @@ -82,8 +82,8 @@ func Password( sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword) // Check the new password strength. - if resErr = internal.ValidatePassword(r.NewPassword); resErr != nil { - return *resErr + if err := internal.ValidatePassword(r.NewPassword); err != nil { + return *internal.PasswordResponse(err) } // Get the local part. diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index 4abbcdf9e..6087bda0c 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -18,12 +18,12 @@ package routing import ( "context" "encoding/json" + "errors" "fmt" "io" "net" "net/http" "net/url" - "regexp" "sort" "strconv" "strings" @@ -60,10 +60,7 @@ var ( ) ) -const ( - maxUsernameLength = 254 // http://matrix.org/speculator/spec/HEAD/intro.html#user-identifiers TODO account for domain - sessionIDLength = 24 -) +const sessionIDLength = 24 // sessionsDict keeps track of completed auth stages for each session. // It shouldn't be passed by value because it contains a mutex. @@ -198,8 +195,7 @@ func (d *sessionsDict) getDeviceToDelete(sessionID string) (string, bool) { } var ( - sessions = newSessionsDict() - validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`) + sessions = newSessionsDict() ) // registerRequest represents the submitted registration request. @@ -262,10 +258,9 @@ func newUserInteractiveResponse( // http://matrix.org/speculator/spec/HEAD/client_server/unstable.html#post-matrix-client-unstable-register type registerResponse struct { - UserID string `json:"user_id"` - AccessToken string `json:"access_token,omitempty"` - HomeServer gomatrixserverlib.ServerName `json:"home_server"` - DeviceID string `json:"device_id,omitempty"` + UserID string `json:"user_id"` + AccessToken string `json:"access_token,omitempty"` + DeviceID string `json:"device_id,omitempty"` } // recaptchaResponse represents the HTTP response from a Google Recaptcha server @@ -276,66 +271,28 @@ type recaptchaResponse struct { ErrorCodes []int `json:"error-codes"` } -// validateUsername returns an error response if the username is invalid -func validateUsername(localpart string, domain gomatrixserverlib.ServerName) *util.JSONResponse { - // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161 - if id := fmt.Sprintf("@%s:%s", localpart, domain); len(id) > maxUsernameLength { - return &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(fmt.Sprintf("%q exceeds the maximum length of %d characters", id, maxUsernameLength)), - } - } else if !validUsernameRegex.MatchString(localpart) { - return &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.InvalidUsername("Username can only contain characters a-z, 0-9, or '_-./='"), - } - } else if localpart[0] == '_' { // Regex checks its not a zero length string - return &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.InvalidUsername("Username cannot start with a '_'"), - } - } - return nil -} - -// validateApplicationServiceUsername returns an error response if the username is invalid for an application service -func validateApplicationServiceUsername(localpart string, domain gomatrixserverlib.ServerName) *util.JSONResponse { - if id := fmt.Sprintf("@%s:%s", localpart, domain); len(id) > maxUsernameLength { - return &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(fmt.Sprintf("%q exceeds the maximum length of %d characters", id, maxUsernameLength)), - } - } else if !validUsernameRegex.MatchString(localpart) { - return &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.InvalidUsername("Username can only contain characters a-z, 0-9, or '_-./='"), - } - } - return nil -} +var ( + ErrInvalidCaptcha = errors.New("invalid captcha response") + ErrMissingResponse = errors.New("captcha response is required") + ErrCaptchaDisabled = errors.New("captcha registration is disabled") +) // validateRecaptcha returns an error response if the captcha response is invalid func validateRecaptcha( cfg *config.ClientAPI, response string, clientip string, -) *util.JSONResponse { +) error { ip, _, _ := net.SplitHostPort(clientip) if !cfg.RecaptchaEnabled { - return &util.JSONResponse{ - Code: http.StatusConflict, - JSON: jsonerror.Unknown("Captcha registration is disabled"), - } + return ErrCaptchaDisabled } if response == "" { - return &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("Captcha response is required"), - } + return ErrMissingResponse } - // Make a POST request to Google's API to check the captcha response + // Make a POST request to the captcha provider API to check the captcha response resp, err := http.PostForm(cfg.RecaptchaSiteVerifyAPI, url.Values{ "secret": {cfg.RecaptchaPrivateKey}, @@ -345,10 +302,7 @@ func validateRecaptcha( ) if err != nil { - return &util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: jsonerror.BadJSON("Error in requesting validation of captcha response"), - } + return err } // Close the request once we're finishing reading from it @@ -358,25 +312,16 @@ func validateRecaptcha( var r recaptchaResponse body, err := io.ReadAll(resp.Body) if err != nil { - return &util.JSONResponse{ - Code: http.StatusGatewayTimeout, - JSON: jsonerror.Unknown("Error in contacting captcha server" + err.Error()), - } + return err } err = json.Unmarshal(body, &r) if err != nil { - return &util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: jsonerror.BadJSON("Error in unmarshaling captcha server's response: " + err.Error()), - } + return err } // Check that we received a "success" if !r.Success { - return &util.JSONResponse{ - Code: http.StatusUnauthorized, - JSON: jsonerror.BadJSON("Invalid captcha response. Please try again."), - } + return ErrInvalidCaptcha } return nil } @@ -508,8 +453,8 @@ func validateApplicationService( } // Check username application service is trying to register is valid - if err := validateApplicationServiceUsername(username, cfg.Matrix.ServerName); err != nil { - return "", err + if err := internal.ValidateApplicationServiceUsername(username, cfg.Matrix.ServerName); err != nil { + return "", internal.UsernameResponse(err) } // No errors, registration valid @@ -564,15 +509,12 @@ func Register( if resErr := httputil.UnmarshalJSON(reqBody, &r); resErr != nil { return *resErr } - if l, d, err := cfg.Matrix.SplitLocalID('@', r.Username); err == nil { - r.Username, r.ServerName = l, d - } if req.URL.Query().Get("kind") == "guest" { return handleGuestRegistration(req, r, cfg, userAPI) } // Don't allow numeric usernames less than MAX_INT64. - if _, err := strconv.ParseInt(r.Username, 10, 64); err == nil { + if _, err = strconv.ParseInt(r.Username, 10, 64); err == nil { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: jsonerror.InvalidUsername("Numeric user IDs are reserved"), @@ -584,7 +526,7 @@ func Register( ServerName: r.ServerName, } nres := &userapi.QueryNumericLocalpartResponse{} - if err := userAPI.QueryNumericLocalpart(req.Context(), nreq, nres); err != nil { + if err = userAPI.QueryNumericLocalpart(req.Context(), nreq, nres); err != nil { util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryNumericLocalpart failed") return jsonerror.InternalServerError() } @@ -601,8 +543,8 @@ func Register( case r.Type == authtypes.LoginTypeApplicationService && accessTokenErr == nil: // Spec-compliant case (the access_token is specified and the login type // is correctly set, so it's an appservice registration) - if resErr := validateApplicationServiceUsername(r.Username, r.ServerName); resErr != nil { - return *resErr + if err = internal.ValidateApplicationServiceUsername(r.Username, r.ServerName); err != nil { + return *internal.UsernameResponse(err) } case accessTokenErr == nil: // Non-spec-compliant case (the access_token is specified but the login @@ -614,12 +556,12 @@ func Register( default: // Spec-compliant case (neither the access_token nor the login type are // specified, so it's a normal user registration) - if resErr := validateUsername(r.Username, r.ServerName); resErr != nil { - return *resErr + if err = internal.ValidateUsername(r.Username, r.ServerName); err != nil { + return *internal.UsernameResponse(err) } } - if resErr := internal.ValidatePassword(r.Password); resErr != nil { - return *resErr + if err = internal.ValidatePassword(r.Password); err != nil { + return *internal.PasswordResponse(err) } logger := util.GetLogger(req.Context()) @@ -697,7 +639,6 @@ func handleGuestRegistration( JSON: registerResponse{ UserID: devRes.Device.UserID, AccessToken: devRes.Device.AccessToken, - HomeServer: res.Account.ServerName, DeviceID: devRes.Device.ID, }, } @@ -761,9 +702,18 @@ func handleRegistrationFlow( switch r.Auth.Type { case authtypes.LoginTypeRecaptcha: // Check given captcha response - resErr := validateRecaptcha(cfg, r.Auth.Response, req.RemoteAddr) - if resErr != nil { - return *resErr + err := validateRecaptcha(cfg, r.Auth.Response, req.RemoteAddr) + switch err { + case ErrCaptchaDisabled: + return util.JSONResponse{Code: http.StatusForbidden, JSON: jsonerror.Unknown(err.Error())} + case ErrMissingResponse: + return util.JSONResponse{Code: http.StatusBadRequest, JSON: jsonerror.BadJSON(err.Error())} + case ErrInvalidCaptcha: + return util.JSONResponse{Code: http.StatusUnauthorized, JSON: jsonerror.BadJSON(err.Error())} + case nil: + default: + util.GetLogger(req.Context()).WithError(err).Error("failed to validate recaptcha") + return util.JSONResponse{Code: http.StatusInternalServerError, JSON: jsonerror.InternalServerError()} } // Add Recaptcha to the list of completed registration stages @@ -924,8 +874,7 @@ func completeRegistration( return util.JSONResponse{ Code: http.StatusOK, JSON: registerResponse{ - UserID: userutil.MakeUserID(username, accRes.Account.ServerName), - HomeServer: accRes.Account.ServerName, + UserID: userutil.MakeUserID(username, accRes.Account.ServerName), }, } } @@ -958,7 +907,6 @@ func completeRegistration( result := registerResponse{ UserID: devRes.Device.UserID, AccessToken: devRes.Device.AccessToken, - HomeServer: accRes.Account.ServerName, DeviceID: devRes.Device.ID, } sessions.addCompletedRegistration(sessionID, result) @@ -1054,8 +1002,8 @@ func RegisterAvailable( } } - if err := validateUsername(username, domain); err != nil { - return *err + if err := internal.ValidateUsername(username, domain); err != nil { + return *internal.UsernameResponse(err) } // Check if this username is reserved by an application service @@ -1117,11 +1065,11 @@ func handleSharedSecretRegistration(cfg *config.ClientAPI, userAPI userapi.Clien // downcase capitals ssrr.User = strings.ToLower(ssrr.User) - if resErr := validateUsername(ssrr.User, cfg.Matrix.ServerName); resErr != nil { - return *resErr + if err = internal.ValidateUsername(ssrr.User, cfg.Matrix.ServerName); err != nil { + return *internal.UsernameResponse(err) } - if resErr := internal.ValidatePassword(ssrr.Password); resErr != nil { - return *resErr + if err = internal.ValidatePassword(ssrr.Password); err != nil { + return *internal.PasswordResponse(err) } deviceID := "shared_secret_registration" diff --git a/clientapi/routing/register_test.go b/clientapi/routing/register_test.go index 85846c7d6..b8fd19e90 100644 --- a/clientapi/routing/register_test.go +++ b/clientapi/routing/register_test.go @@ -15,12 +15,27 @@ package routing import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "reflect" "regexp" + "strings" "testing" "time" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/keyserver" + "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + "github.com/matrix-org/dendrite/userapi" + "github.com/matrix-org/util" ) var ( @@ -264,3 +279,294 @@ func TestSessionCleanUp(t *testing.T) { } }) } + +func Test_register(t *testing.T) { + testCases := []struct { + name string + kind string + password string + username string + loginType string + forceEmpty bool + registrationDisabled bool + guestsDisabled bool + enableRecaptcha bool + captchaBody string + wantResponse util.JSONResponse + }{ + { + name: "disallow guests", + kind: "guest", + guestsDisabled: true, + wantResponse: util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden(`Guest registration is disabled on "test"`), + }, + }, + { + name: "allow guests", + kind: "guest", + }, + { + name: "unknown login type", + loginType: "im.not.known", + wantResponse: util.JSONResponse{ + Code: http.StatusNotImplemented, + JSON: jsonerror.Unknown("unknown/unimplemented auth type"), + }, + }, + { + name: "disabled registration", + registrationDisabled: true, + wantResponse: util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden(`Registration is disabled on "test"`), + }, + }, + { + name: "successful registration, numeric ID", + username: "", + password: "someRandomPassword", + forceEmpty: true, + }, + { + name: "successful registration", + username: "success", + }, + { + name: "failing registration - user already exists", + username: "success", + wantResponse: util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.UserInUse("Desired user ID is already taken."), + }, + }, + { + name: "successful registration uppercase username", + username: "LOWERCASED", // this is going to be lower-cased + }, + { + name: "invalid username", + username: "#totalyNotValid", + wantResponse: *internal.UsernameResponse(internal.ErrUsernameInvalid), + }, + { + name: "numeric username is forbidden", + username: "1337", + wantResponse: util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidUsername("Numeric user IDs are reserved"), + }, + }, + { + name: "disabled recaptcha login", + loginType: authtypes.LoginTypeRecaptcha, + wantResponse: util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Unknown(ErrCaptchaDisabled.Error()), + }, + }, + { + name: "enabled recaptcha, no response defined", + enableRecaptcha: true, + loginType: authtypes.LoginTypeRecaptcha, + wantResponse: util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON(ErrMissingResponse.Error()), + }, + }, + { + name: "invalid captcha response", + enableRecaptcha: true, + loginType: authtypes.LoginTypeRecaptcha, + captchaBody: `notvalid`, + wantResponse: util.JSONResponse{ + Code: http.StatusUnauthorized, + JSON: jsonerror.BadJSON(ErrInvalidCaptcha.Error()), + }, + }, + { + name: "valid captcha response", + enableRecaptcha: true, + loginType: authtypes.LoginTypeRecaptcha, + captchaBody: `success`, + }, + { + name: "captcha invalid from remote", + enableRecaptcha: true, + loginType: authtypes.LoginTypeRecaptcha, + captchaBody: `i should fail for other reasons`, + wantResponse: util.JSONResponse{Code: http.StatusInternalServerError, JSON: jsonerror.InternalServerError()}, + }, + } + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, baseClose := testrig.CreateBaseDendrite(t, dbType) + defer baseClose() + + rsAPI := roomserver.NewInternalAPI(base) + keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI) + userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil) + keyAPI.SetUserAPI(userAPI) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if tc.enableRecaptcha { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + t.Fatal(err) + } + response := r.Form.Get("response") + + // Respond with valid JSON or no JSON at all to test happy/error cases + switch response { + case "success": + json.NewEncoder(w).Encode(recaptchaResponse{Success: true}) + case "notvalid": + json.NewEncoder(w).Encode(recaptchaResponse{Success: false}) + default: + + } + })) + defer srv.Close() + base.Cfg.ClientAPI.RecaptchaSiteVerifyAPI = srv.URL + } + + if err := base.Cfg.Derive(); err != nil { + t.Fatalf("failed to derive config: %s", err) + } + + base.Cfg.ClientAPI.RecaptchaEnabled = tc.enableRecaptcha + base.Cfg.ClientAPI.RegistrationDisabled = tc.registrationDisabled + base.Cfg.ClientAPI.GuestsDisabled = tc.guestsDisabled + + if tc.kind == "" { + tc.kind = "user" + } + if tc.password == "" && !tc.forceEmpty { + tc.password = "someRandomPassword" + } + if tc.username == "" && !tc.forceEmpty { + tc.username = "valid" + } + if tc.loginType == "" { + tc.loginType = "m.login.dummy" + } + + reg := registerRequest{ + Password: tc.password, + Username: tc.username, + } + + body := &bytes.Buffer{} + err := json.NewEncoder(body).Encode(reg) + if err != nil { + t.Fatal(err) + } + + req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("/?kind=%s", tc.kind), body) + + resp := Register(req, userAPI, &base.Cfg.ClientAPI) + t.Logf("Resp: %+v", resp) + + // The first request should return a userInteractiveResponse + switch r := resp.JSON.(type) { + case userInteractiveResponse: + // Check that the flows are the ones we configured + if !reflect.DeepEqual(r.Flows, base.Cfg.Derived.Registration.Flows) { + t.Fatalf("unexpected registration flows: %+v, want %+v", r.Flows, base.Cfg.Derived.Registration.Flows) + } + case *jsonerror.MatrixError: + if !reflect.DeepEqual(tc.wantResponse, resp) { + t.Fatalf("(%s), unexpected response: %+v, want: %+v", tc.name, resp, tc.wantResponse) + } + return + case registerResponse: + // this should only be possible on guest user registration, never for normal users + if tc.kind != "guest" { + t.Fatalf("got register response on first request: %+v", r) + } + // assert we've got a UserID, AccessToken and DeviceID + if r.UserID == "" { + t.Fatalf("missing userID in response") + } + if r.AccessToken == "" { + t.Fatalf("missing accessToken in response") + } + if r.DeviceID == "" { + t.Fatalf("missing deviceID in response") + } + return + default: + t.Logf("Got response: %T", resp.JSON) + } + + // If we reached this, we should have received a UIA response + uia, ok := resp.JSON.(userInteractiveResponse) + if !ok { + t.Fatalf("did not receive a userInteractiveResponse: %T", resp.JSON) + } + t.Logf("%+v", uia) + + // Register the user + reg.Auth = authDict{ + Type: authtypes.LoginType(tc.loginType), + Session: uia.Session, + } + + if tc.captchaBody != "" { + reg.Auth.Response = tc.captchaBody + } + + dummy := "dummy" + reg.DeviceID = &dummy + reg.InitialDisplayName = &dummy + reg.Type = authtypes.LoginType(tc.loginType) + + err = json.NewEncoder(body).Encode(reg) + if err != nil { + t.Fatal(err) + } + + req = httptest.NewRequest(http.MethodPost, "/", body) + + resp = Register(req, userAPI, &base.Cfg.ClientAPI) + + switch resp.JSON.(type) { + case *jsonerror.MatrixError: + if !reflect.DeepEqual(tc.wantResponse, resp) { + t.Fatalf("unexpected response: %+v, want: %+v", resp, tc.wantResponse) + } + return + case util.JSONResponse: + if !reflect.DeepEqual(tc.wantResponse, resp) { + t.Fatalf("unexpected response: %+v, want: %+v", resp, tc.wantResponse) + } + return + } + + rr, ok := resp.JSON.(registerResponse) + if !ok { + t.Fatalf("expected a registerresponse, got %T", resp.JSON) + } + + // validate the response + if tc.forceEmpty { + // when not supplying a username, one will be generated. Given this _SHOULD_ be + // the second user, set the username accordingly + reg.Username = "2" + } + wantUserID := strings.ToLower(fmt.Sprintf("@%s:%s", reg.Username, "test")) + if wantUserID != rr.UserID { + t.Fatalf("unexpected userID: %s, want %s", rr.UserID, wantUserID) + } + if rr.DeviceID != *reg.DeviceID { + t.Fatalf("unexpected deviceID: %s, want %s", rr.DeviceID, *reg.DeviceID) + } + if rr.AccessToken == "" { + t.Fatalf("missing accessToken in response") + } + }) + } + }) +} diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 69b46214c..09c2cd02f 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -639,9 +639,9 @@ func Setup( ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) v3mux.Handle("/auth/{authType}/fallback/web", - httputil.MakeHTMLAPI("auth_fallback", base.EnableMetrics, func(w http.ResponseWriter, req *http.Request) *util.JSONResponse { + httputil.MakeHTMLAPI("auth_fallback", base.EnableMetrics, func(w http.ResponseWriter, req *http.Request) { vars := mux.Vars(req) - return AuthFallback(w, req, vars["authType"], cfg) + AuthFallback(w, req, vars["authType"], cfg) }), ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) diff --git a/cmd/create-account/main.go b/cmd/create-account/main.go index 15b043ed5..772778680 100644 --- a/cmd/create-account/main.go +++ b/cmd/create-account/main.go @@ -25,10 +25,10 @@ import ( "io" "net/http" "os" - "regexp" "strings" "time" + "github.com/matrix-org/dendrite/internal" "github.com/tidwall/gjson" "github.com/sirupsen/logrus" @@ -58,15 +58,14 @@ Arguments: ` var ( - username = flag.String("username", "", "The username of the account to register (specify the localpart only, e.g. 'alice' for '@alice:domain.com')") - password = flag.String("password", "", "The password to associate with the account") - pwdFile = flag.String("passwordfile", "", "The file to use for the password (e.g. for automated account creation)") - pwdStdin = flag.Bool("passwordstdin", false, "Reads the password from stdin") - isAdmin = flag.Bool("admin", false, "Create an admin account") - resetPassword = flag.Bool("reset-password", false, "Deprecated") - serverURL = flag.String("url", "http://localhost:8008", "The URL to connect to.") - validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`) - timeout = flag.Duration("timeout", time.Second*30, "Timeout for the http client when connecting to the server") + username = flag.String("username", "", "The username of the account to register (specify the localpart only, e.g. 'alice' for '@alice:domain.com')") + password = flag.String("password", "", "The password to associate with the account") + pwdFile = flag.String("passwordfile", "", "The file to use for the password (e.g. for automated account creation)") + pwdStdin = flag.Bool("passwordstdin", false, "Reads the password from stdin") + isAdmin = flag.Bool("admin", false, "Create an admin account") + resetPassword = flag.Bool("reset-password", false, "Deprecated") + serverURL = flag.String("url", "http://localhost:8008", "The URL to connect to.") + timeout = flag.Duration("timeout", time.Second*30, "Timeout for the http client when connecting to the server") ) var cl = http.Client{ @@ -95,20 +94,21 @@ func main() { os.Exit(1) } - if !validUsernameRegex.MatchString(*username) { - logrus.Warn("Username can only contain characters a-z, 0-9, or '_-./='") + if err := internal.ValidateUsername(*username, cfg.Global.ServerName); err != nil { + logrus.WithError(err).Error("Specified username is invalid") os.Exit(1) } - if len(fmt.Sprintf("@%s:%s", *username, cfg.Global.ServerName)) > 255 { - logrus.Fatalf("Username can not be longer than 255 characters: %s", fmt.Sprintf("@%s:%s", *username, cfg.Global.ServerName)) - } - pass, err := getPassword(*password, *pwdFile, *pwdStdin, os.Stdin) if err != nil { logrus.Fatalln(err) } + if err = internal.ValidatePassword(pass); err != nil { + logrus.WithError(err).Error("Specified password is invalid") + os.Exit(1) + } + cl.Timeout = *timeout accessToken, err := sharedSecretRegister(cfg.ClientAPI.RegistrationSharedSecret, *serverURL, *username, pass, *isAdmin) diff --git a/internal/httputil/httpapi.go b/internal/httputil/httpapi.go index 383913c60..37d144f4e 100644 --- a/internal/httputil/httpapi.go +++ b/internal/httputil/httpapi.go @@ -198,17 +198,12 @@ func MakeExternalAPI(metricsName string, f func(*http.Request) util.JSONResponse // MakeHTMLAPI adds Span metrics to the HTML Handler function // This is used to serve HTML alongside JSON error messages -func MakeHTMLAPI(metricsName string, enableMetrics bool, f func(http.ResponseWriter, *http.Request) *util.JSONResponse) http.Handler { +func MakeHTMLAPI(metricsName string, enableMetrics bool, f func(http.ResponseWriter, *http.Request)) http.Handler { withSpan := func(w http.ResponseWriter, req *http.Request) { span := opentracing.StartSpan(metricsName) defer span.Finish() req = req.WithContext(opentracing.ContextWithSpan(req.Context(), span)) - if err := f(w, req); err != nil { - h := util.MakeJSONAPI(util.NewJSONRequestHandler(func(req *http.Request) util.JSONResponse { - return *err - })) - h.ServeHTTP(w, req) - } + f(w, req) } if !enableMetrics { diff --git a/internal/validate.go b/internal/validate.go index fc685ad50..0461b897e 100644 --- a/internal/validate.go +++ b/internal/validate.go @@ -15,30 +15,96 @@ package internal import ( + "errors" "fmt" "net/http" + "regexp" "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) -const minPasswordLength = 8 // http://matrix.org/docs/spec/client_server/r0.2.0.html#password-based +const ( + maxUsernameLength = 254 // http://matrix.org/speculator/spec/HEAD/intro.html#user-identifiers TODO account for domain -const maxPasswordLength = 512 // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161 + minPasswordLength = 8 // http://matrix.org/docs/spec/client_server/r0.2.0.html#password-based + maxPasswordLength = 512 // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161 +) -// ValidatePassword returns an error response if the password is invalid -func ValidatePassword(password string) *util.JSONResponse { +var ( + ErrPasswordTooLong = fmt.Errorf("password too long: max %d characters", maxPasswordLength) + ErrPasswordWeak = fmt.Errorf("password too weak: min %d characters", minPasswordLength) + ErrUsernameTooLong = fmt.Errorf("username exceeds the maximum length of %d characters", maxUsernameLength) + ErrUsernameInvalid = errors.New("username can only contain characters a-z, 0-9, or '_-./='") + ErrUsernameUnderscore = errors.New("username cannot start with a '_'") + validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`) +) + +// ValidatePassword returns an error if the password is invalid +func ValidatePassword(password string) error { // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161 if len(password) > maxPasswordLength { - return &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(fmt.Sprintf("password too long: max %d characters", maxPasswordLength)), - } + return ErrPasswordTooLong } else if len(password) > 0 && len(password) < minPasswordLength { + return ErrPasswordWeak + } + return nil +} + +// PasswordResponse returns a util.JSONResponse for a given error, if any. +func PasswordResponse(err error) *util.JSONResponse { + switch err { + case ErrPasswordWeak: return &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.WeakPassword(fmt.Sprintf("password too weak: min %d chars", minPasswordLength)), + JSON: jsonerror.WeakPassword(ErrPasswordWeak.Error()), + } + case ErrPasswordTooLong: + return &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON(ErrPasswordTooLong.Error()), } } return nil } + +// ValidateUsername returns an error if the username is invalid +func ValidateUsername(localpart string, domain gomatrixserverlib.ServerName) error { + // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161 + if id := fmt.Sprintf("@%s:%s", localpart, domain); len(id) > maxUsernameLength { + return ErrUsernameTooLong + } else if !validUsernameRegex.MatchString(localpart) { + return ErrUsernameInvalid + } else if localpart[0] == '_' { // Regex checks its not a zero length string + return ErrUsernameUnderscore + } + return nil +} + +// UsernameResponse returns a util.JSONResponse for the given error, if any. +func UsernameResponse(err error) *util.JSONResponse { + switch err { + case ErrUsernameTooLong: + return &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON(err.Error()), + } + case ErrUsernameInvalid, ErrUsernameUnderscore: + return &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidUsername(err.Error()), + } + } + return nil +} + +// ValidateApplicationServiceUsername returns an error if the username is invalid for an application service +func ValidateApplicationServiceUsername(localpart string, domain gomatrixserverlib.ServerName) error { + if id := fmt.Sprintf("@%s:%s", localpart, domain); len(id) > maxUsernameLength { + return ErrUsernameTooLong + } else if !validUsernameRegex.MatchString(localpart) { + return ErrUsernameInvalid + } + return nil +} diff --git a/internal/validate_test.go b/internal/validate_test.go new file mode 100644 index 000000000..d0ad04707 --- /dev/null +++ b/internal/validate_test.go @@ -0,0 +1,170 @@ +package internal + +import ( + "net/http" + "reflect" + "strings" + "testing" + + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" +) + +func Test_validatePassword(t *testing.T) { + tests := []struct { + name string + password string + wantError error + wantJSON *util.JSONResponse + }{ + { + name: "password too short", + password: "shortpw", + wantError: ErrPasswordWeak, + wantJSON: &util.JSONResponse{Code: http.StatusBadRequest, JSON: jsonerror.WeakPassword(ErrPasswordWeak.Error())}, + }, + { + name: "password too long", + password: strings.Repeat("a", maxPasswordLength+1), + wantError: ErrPasswordTooLong, + wantJSON: &util.JSONResponse{Code: http.StatusBadRequest, JSON: jsonerror.BadJSON(ErrPasswordTooLong.Error())}, + }, + { + name: "password OK", + password: util.RandomString(10), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotErr := ValidatePassword(tt.password) + if !reflect.DeepEqual(gotErr, tt.wantError) { + t.Errorf("validatePassword() = %v, wantJSON %v", gotErr, tt.wantError) + } + + if got := PasswordResponse(gotErr); !reflect.DeepEqual(got, tt.wantJSON) { + t.Errorf("validatePassword() = %v, wantJSON %v", got, tt.wantJSON) + } + }) + } +} + +func Test_validateUsername(t *testing.T) { + tooLongUsername := strings.Repeat("a", maxUsernameLength) + tests := []struct { + name string + localpart string + domain gomatrixserverlib.ServerName + wantErr error + wantJSON *util.JSONResponse + }{ + { + name: "empty username", + localpart: "", + domain: "localhost", + wantErr: ErrUsernameInvalid, + wantJSON: &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidUsername(ErrUsernameInvalid.Error()), + }, + }, + { + name: "invalid username", + localpart: "INVALIDUSERNAME", + domain: "localhost", + wantErr: ErrUsernameInvalid, + wantJSON: &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidUsername(ErrUsernameInvalid.Error()), + }, + }, + { + name: "username too long", + localpart: tooLongUsername, + domain: "localhost", + wantErr: ErrUsernameTooLong, + wantJSON: &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON(ErrUsernameTooLong.Error()), + }, + }, + { + name: "localpart starting with an underscore", + localpart: "_notvalid", + domain: "localhost", + wantErr: ErrUsernameUnderscore, + wantJSON: &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidUsername(ErrUsernameUnderscore.Error()), + }, + }, + { + name: "valid username", + localpart: "valid", + domain: "localhost", + }, + { + name: "complex username", + localpart: "f00_bar-baz.=40/", + domain: "localhost", + }, + { + name: "rejects emoji username 💥", + localpart: "💥", + domain: "localhost", + wantErr: ErrUsernameInvalid, + wantJSON: &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidUsername(ErrUsernameInvalid.Error()), + }, + }, + { + name: "special characters are allowed", + localpart: "/dev/null", + domain: "localhost", + }, + { + name: "special characters are allowed 2", + localpart: "i_am_allowed=1", + domain: "localhost", + }, + { + name: "not all special characters are allowed", + localpart: "notallowed#", // contains # + domain: "localhost", + wantErr: ErrUsernameInvalid, + wantJSON: &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidUsername(ErrUsernameInvalid.Error()), + }, + }, + { + name: "username containing numbers", + localpart: "hello1337", + domain: "localhost", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotErr := ValidateUsername(tt.localpart, tt.domain) + if !reflect.DeepEqual(gotErr, tt.wantErr) { + t.Errorf("ValidateUsername() = %v, wantErr %v", gotErr, tt.wantErr) + } + if gotJSON := UsernameResponse(gotErr); !reflect.DeepEqual(gotJSON, tt.wantJSON) { + t.Errorf("UsernameResponse() = %v, wantJSON %v", gotJSON, tt.wantJSON) + } + + // Application services are allowed usernames starting with an underscore + if tt.wantErr == ErrUsernameUnderscore { + return + } + gotErr = ValidateApplicationServiceUsername(tt.localpart, tt.domain) + if !reflect.DeepEqual(gotErr, tt.wantErr) { + t.Errorf("ValidateUsername() = %v, wantErr %v", gotErr, tt.wantErr) + } + if gotJSON := UsernameResponse(gotErr); !reflect.DeepEqual(gotJSON, tt.wantJSON) { + t.Errorf("UsernameResponse() = %v, wantJSON %v", gotJSON, tt.wantJSON) + } + }) + } +} diff --git a/setup/config/config.go b/setup/config/config.go index 7e7ed1aa1..6523a2452 100644 --- a/setup/config/config.go +++ b/setup/config/config.go @@ -29,7 +29,7 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/sirupsen/logrus" "golang.org/x/crypto/ed25519" - yaml "gopkg.in/yaml.v2" + "gopkg.in/yaml.v2" jaegerconfig "github.com/uber/jaeger-client-go/config" jaegermetrics "github.com/uber/jaeger-lib/metrics" @@ -314,11 +314,13 @@ func (config *Dendrite) Derive() error { if config.ClientAPI.RecaptchaEnabled { config.Derived.Registration.Params[authtypes.LoginTypeRecaptcha] = map[string]string{"public_key": config.ClientAPI.RecaptchaPublicKey} - config.Derived.Registration.Flows = append(config.Derived.Registration.Flows, - authtypes.Flow{Stages: []authtypes.LoginType{authtypes.LoginTypeRecaptcha}}) + config.Derived.Registration.Flows = []authtypes.Flow{ + {Stages: []authtypes.LoginType{authtypes.LoginTypeRecaptcha}}, + } } else { - config.Derived.Registration.Flows = append(config.Derived.Registration.Flows, - authtypes.Flow{Stages: []authtypes.LoginType{authtypes.LoginTypeDummy}}) + config.Derived.Registration.Flows = []authtypes.Flow{ + {Stages: []authtypes.LoginType{authtypes.LoginTypeDummy}}, + } } // Load application service configuration files diff --git a/setup/config/config_clientapi.go b/setup/config/config_clientapi.go index 0a871da18..11628b1b0 100644 --- a/setup/config/config_clientapi.go +++ b/setup/config/config_clientapi.go @@ -78,9 +78,6 @@ func (c *ClientAPI) Verify(configErrs *ConfigErrors, isMonolith bool) { c.TURN.Verify(configErrs) c.RateLimiting.Verify(configErrs) if c.RecaptchaEnabled { - checkNotEmpty(configErrs, "client_api.recaptcha_public_key", c.RecaptchaPublicKey) - checkNotEmpty(configErrs, "client_api.recaptcha_private_key", c.RecaptchaPrivateKey) - checkNotEmpty(configErrs, "client_api.recaptcha_siteverify_api", c.RecaptchaSiteVerifyAPI) if c.RecaptchaSiteVerifyAPI == "" { c.RecaptchaSiteVerifyAPI = "https://www.google.com/recaptcha/api/siteverify" } @@ -93,6 +90,10 @@ func (c *ClientAPI) Verify(configErrs *ConfigErrors, isMonolith bool) { if c.RecaptchaSitekeyClass == "" { c.RecaptchaSitekeyClass = "g-recaptcha-response" } + checkNotEmpty(configErrs, "client_api.recaptcha_public_key", c.RecaptchaPublicKey) + checkNotEmpty(configErrs, "client_api.recaptcha_private_key", c.RecaptchaPrivateKey) + checkNotEmpty(configErrs, "client_api.recaptcha_siteverify_api", c.RecaptchaSiteVerifyAPI) + checkNotEmpty(configErrs, "client_api.recaptcha_sitekey_class", c.RecaptchaSitekeyClass) } // Ensure there is any spam counter measure when enabling registration if !c.RegistrationDisabled && !c.OpenRegistrationWithoutVerificationEnabled { From e449d174ccf7569b2536289f3c8145298e80bc90 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Fri, 23 Dec 2022 14:28:15 +0100 Subject: [PATCH 5/6] Add possibility to run complement with coverage enabled (#2901) This adds the possibility to run Complement with coverage enabled. In combination with https://github.com/matrix-org/complement/pull/566 we should then be able to extract the coverage logs, combine them with https://github.com/wadey/gocovmerge (or similar) and upload them to Codecov (with different flags, depending on SQLite, HTTP etc.) --- Dockerfile | 27 --------------------- build/scripts/Complement.Dockerfile | 7 ++++-- build/scripts/ComplementLocal.Dockerfile | 5 +++- build/scripts/ComplementPostgres.Dockerfile | 7 ++++-- build/scripts/complement-cmd.sh | 22 +++++++++++++++++ 5 files changed, 36 insertions(+), 32 deletions(-) create mode 100755 build/scripts/complement-cmd.sh diff --git a/Dockerfile b/Dockerfile index a9bbce925..ede33e635 100644 --- a/Dockerfile +++ b/Dockerfile @@ -63,30 +63,3 @@ WORKDIR /etc/dendrite ENTRYPOINT ["/usr/bin/dendrite-monolith-server"] EXPOSE 8008 8448 -# -# Builds the Complement image, used for integration tests -# -FROM base AS complement -LABEL org.opencontainers.image.title="Dendrite (Complement)" -RUN apk add --no-cache sqlite openssl ca-certificates - -COPY --from=build /out/generate-config /usr/bin/generate-config -COPY --from=build /out/generate-keys /usr/bin/generate-keys -COPY --from=build /out/dendrite-monolith-server /usr/bin/dendrite-monolith-server - -WORKDIR /dendrite -RUN /usr/bin/generate-keys --private-key matrix_key.pem && \ - mkdir /ca && \ - openssl genrsa -out /ca/ca.key 2048 && \ - openssl req -new -x509 -key /ca/ca.key -days 3650 -subj "/C=GB/ST=London/O=matrix.org/CN=Complement CA" -out /ca/ca.crt - -ENV SERVER_NAME=localhost -ENV API=0 -EXPOSE 8008 8448 - -# At runtime, generate TLS cert based on the CA now mounted at /ca -# At runtime, replace the SERVER_NAME with what we are told -CMD /usr/bin/generate-keys --server $SERVER_NAME --tls-cert server.crt --tls-key server.key --tls-authority-cert /ca/ca.crt --tls-authority-key /ca/ca.key && \ - /usr/bin/generate-config -server $SERVER_NAME --ci > dendrite.yaml && \ - cp /ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates && \ - /usr/bin/dendrite-monolith-server --really-enable-open-registration --tls-cert server.crt --tls-key server.key --config dendrite.yaml -api=${API:-0} diff --git a/build/scripts/Complement.Dockerfile b/build/scripts/Complement.Dockerfile index 79422e645..3a00fbdf0 100644 --- a/build/scripts/Complement.Dockerfile +++ b/build/scripts/Complement.Dockerfile @@ -16,13 +16,16 @@ RUN --mount=target=. \ --mount=type=cache,target=/root/.cache/go-build \ CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-config && \ CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-keys && \ - CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/dendrite-monolith-server + CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/dendrite-monolith-server && \ + CGO_ENABLED=${CGO} go test -c -cover -covermode=atomic -o /dendrite/dendrite-monolith-server-cover -coverpkg "github.com/matrix-org/..." ./cmd/dendrite-monolith-server && \ + cp build/scripts/complement-cmd.sh /complement-cmd.sh WORKDIR /dendrite RUN ./generate-keys --private-key matrix_key.pem ENV SERVER_NAME=localhost ENV API=0 +ENV COVER=0 EXPOSE 8008 8448 # At runtime, generate TLS cert based on the CA now mounted at /ca @@ -30,4 +33,4 @@ EXPOSE 8008 8448 CMD ./generate-keys -keysize 1024 --server $SERVER_NAME --tls-cert server.crt --tls-key server.key --tls-authority-cert /complement/ca/ca.crt --tls-authority-key /complement/ca/ca.key && \ ./generate-config -server $SERVER_NAME --ci > dendrite.yaml && \ cp /complement/ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates && \ - exec ./dendrite-monolith-server --really-enable-open-registration --tls-cert server.crt --tls-key server.key --config dendrite.yaml -api=${API:-0} + exec /complement-cmd.sh diff --git a/build/scripts/ComplementLocal.Dockerfile b/build/scripts/ComplementLocal.Dockerfile index 3a019fc20..e3fbe1aa8 100644 --- a/build/scripts/ComplementLocal.Dockerfile +++ b/build/scripts/ComplementLocal.Dockerfile @@ -12,18 +12,20 @@ FROM golang:1.18-stretch RUN apt-get update && apt-get install -y sqlite3 ENV SERVER_NAME=localhost +ENV COVER=0 EXPOSE 8008 8448 WORKDIR /runtime # This script compiles Dendrite for us. RUN echo '\ #!/bin/bash -eux \n\ - if test -f "/runtime/dendrite-monolith-server"; then \n\ + if test -f "/runtime/dendrite-monolith-server" && test -f "/runtime/dendrite-monolith-server-cover"; then \n\ echo "Skipping compilation; binaries exist" \n\ exit 0 \n\ fi \n\ cd /dendrite \n\ go build -v -o /runtime /dendrite/cmd/dendrite-monolith-server \n\ + go test -c -cover -covermode=atomic -o /runtime/dendrite-monolith-server-cover -coverpkg "github.com/matrix-org/..." /dendrite/cmd/dendrite-monolith-server \n\ ' > compile.sh && chmod +x compile.sh # This script runs Dendrite for us. Must be run in the /runtime directory. @@ -33,6 +35,7 @@ RUN echo '\ ./generate-keys -keysize 1024 --server $SERVER_NAME --tls-cert server.crt --tls-key server.key --tls-authority-cert /complement/ca/ca.crt --tls-authority-key /complement/ca/ca.key \n\ ./generate-config -server $SERVER_NAME --ci > dendrite.yaml \n\ cp /complement/ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates \n\ + [ ${COVER} -eq 1 ] && exec ./dendrite-monolith-server-cover --test.coverprofile=integrationcover.log --really-enable-open-registration --tls-cert server.crt --tls-key server.key --config dendrite.yaml \n\ exec ./dendrite-monolith-server --really-enable-open-registration --tls-cert server.crt --tls-key server.key --config dendrite.yaml \n\ ' > run.sh && chmod +x run.sh diff --git a/build/scripts/ComplementPostgres.Dockerfile b/build/scripts/ComplementPostgres.Dockerfile index 3faf43cc7..444cb947d 100644 --- a/build/scripts/ComplementPostgres.Dockerfile +++ b/build/scripts/ComplementPostgres.Dockerfile @@ -34,13 +34,16 @@ RUN --mount=target=. \ --mount=type=cache,target=/root/.cache/go-build \ CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-config && \ CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-keys && \ - CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/dendrite-monolith-server + CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/dendrite-monolith-server && \ + CGO_ENABLED=${CGO} go test -c -cover -covermode=atomic -o /dendrite/dendrite-monolith-server-cover -coverpkg "github.com/matrix-org/..." ./cmd/dendrite-monolith-server && \ + cp build/scripts/complement-cmd.sh /complement-cmd.sh WORKDIR /dendrite RUN ./generate-keys --private-key matrix_key.pem ENV SERVER_NAME=localhost ENV API=0 +ENV COVER=0 EXPOSE 8008 8448 @@ -51,4 +54,4 @@ CMD /build/run_postgres.sh && ./generate-keys --keysize 1024 --server $SERVER_NA # Bump max_open_conns up here in the global database config sed -i 's/max_open_conns:.*$/max_open_conns: 1990/g' dendrite.yaml && \ cp /complement/ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates && \ - exec ./dendrite-monolith-server --really-enable-open-registration --tls-cert server.crt --tls-key server.key --config dendrite.yaml -api=${API:-0} \ No newline at end of file + exec /complement-cmd.sh \ No newline at end of file diff --git a/build/scripts/complement-cmd.sh b/build/scripts/complement-cmd.sh new file mode 100755 index 000000000..061bd18eb --- /dev/null +++ b/build/scripts/complement-cmd.sh @@ -0,0 +1,22 @@ +#!/bin/bash -e + +# This script is intended to be used inside a docker container for Complement + +if [[ "${COVER}" -eq 1 ]]; then + echo "Running with coverage" + exec /dendrite/dendrite-monolith-server-cover \ + --really-enable-open-registration \ + --tls-cert server.crt \ + --tls-key server.key \ + --config dendrite.yaml \ + -api=${API:-0} \ + --test.coverprofile=integrationcover.log +else + echo "Not running with coverage" + exec /dendrite/dendrite-monolith-server \ + --really-enable-open-registration \ + --tls-cert server.crt \ + --tls-key server.key \ + --config dendrite.yaml \ + -api=${API:-0} +fi From 2e1fe589375b650f9b2d9a09e1fcffb3ab6fe5b6 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Thu, 5 Jan 2023 09:24:00 +0100 Subject: [PATCH 6/6] Fix backfilling (#2926) This should fix https://github.com/matrix-org/dendrite/issues/2923 --- go.mod | 2 +- go.sum | 4 ++-- roomserver/internal/perform/perform_backfill.go | 9 +++++++-- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/go.mod b/go.mod index d3eb4890a..2d7174150 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20221129095800-8835f6db16b8 + github.com/matrix-org/gomatrixserverlib v0.0.0-20230105074811-965b10ae73ab github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.15 diff --git a/go.sum b/go.sum index ad9372c84..b12f65eab 100644 --- a/go.sum +++ b/go.sum @@ -348,8 +348,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20221129095800-8835f6db16b8 h1:jVvlCGs6OosCdvw9MkfiVnTVnIt7vKMHg/F6th9BtSo= -github.com/matrix-org/gomatrixserverlib v0.0.0-20221129095800-8835f6db16b8/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230105074811-965b10ae73ab h1:ChaQdT2mpxMm3GRXNOZzLDQ/wOnlKZ8o60LmZGOjdj8= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230105074811-965b10ae73ab/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847 h1:auIBCi7gfZuvztD0aPr1G/J5Ya5vWr79M/+TJqwD/JM= github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847/go.mod h1:F3GHppRuHCTDeoOmmgjZMeJdbql91+RSGGsATWfC7oc= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= diff --git a/roomserver/internal/perform/perform_backfill.go b/roomserver/internal/perform/perform_backfill.go index 069f017a9..d9214fdc6 100644 --- a/roomserver/internal/perform/perform_backfill.go +++ b/roomserver/internal/perform/perform_backfill.go @@ -122,11 +122,14 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform ctx, req.VirtualHost, requester, r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100, ) - if err != nil { + // Only return an error if we really couldn't get any events. + if err != nil && len(events) == 0 { logrus.WithError(err).Errorf("gomatrixserverlib.RequestBackfill failed") return err } - logrus.WithField("room_id", req.RoomID).Infof("backfilled %d events", len(events)) + // If we got an error but still got events, that's fine, because a server might have returned a 404 (or something) + // but other servers could provide the missing event. + logrus.WithError(err).WithField("room_id", req.RoomID).Infof("backfilled %d events", len(events)) // persist these new events - auth checks have already been done roomNID, backfilledEventMap := persistEvents(ctx, r.DB, events) @@ -319,6 +322,7 @@ FederationHit: FedClient: b.fsAPI, RememberAuthEvents: false, Server: srv, + Origin: b.virtualHost, } res, err := c.StateIDsBeforeEvent(ctx, targetEvent) if err != nil { @@ -394,6 +398,7 @@ func (b *backfillRequester) StateBeforeEvent(ctx context.Context, roomVer gomatr FedClient: b.fsAPI, RememberAuthEvents: false, Server: srv, + Origin: b.virtualHost, } result, err := c.StateBeforeEvent(ctx, roomVer, event, eventIDs) if err != nil {