From 374b77a3dfd431d967cab7aee964c9cd709fbe9c Mon Sep 17 00:00:00 2001 From: PiotrKozimor <37144818+PiotrKozimor@users.noreply.github.com> Date: Thu, 30 Jun 2022 14:56:45 +0200 Subject: [PATCH] Entry improvements (#11) * Refactor ApplicationServiceWorkerState to be more robust * Add launch.json to VS Code * Implement login with JWT, registering with email, failed login rate limiting and reset password with m.login.email.identity auth type * Log errors when JWT parsing failed * Development build script * Fix linter errors * Use golangci-lint as a linter in VS Code * Fix tests with RtFailedLogin * Pass config load tests - parse JWT public key only if enabled * Reduce CI steps Do not support 386 arch and go 1.16, 1.17 * Fix linter errors * Change RtFailedLogin logic - nil pointer can be provided * Respect access token in query * Fix typos * Use only one mutex in RtFailedLogin * Remove eventsRemaining across appservice component * Push dendrite to production registry as well * Rafactor TestRtFailedLogin --- .github/workflows/dendrite.yml | 89 +---------- .gitignore | 5 + .vscode/launch.json | 16 ++ .vscode/settings.json | 3 + appservice/appservice.go | 4 +- appservice/consumers/roomserver.go | 10 +- appservice/storage/interface.go | 6 +- .../postgres/appservice_events_table.go | 44 +++--- appservice/storage/postgres/storage.go | 15 +- .../sqlite3/appservice_events_table.go | 44 +++--- appservice/storage/sqlite3/storage.go | 15 +- appservice/types/types.go | 20 +-- appservice/workers/transaction_scheduler.go | 35 ++--- clientapi/auth/authtypes/logintypes.go | 2 + clientapi/auth/login.go | 14 +- clientapi/auth/login_jwt.go | 74 +++++++++ clientapi/auth/login_test.go | 9 +- clientapi/auth/login_token.go | 2 +- clientapi/auth/password.go | 60 ++++++- clientapi/auth/user_interactive.go | 12 +- clientapi/auth/user_interactive_test.go | 4 +- clientapi/ratelimit/rt_failed_login.go | 117 ++++++++++++++ clientapi/ratelimit/rt_failed_login_test.go | 40 +++++ clientapi/routing/key_crosssigning.go | 4 +- clientapi/routing/login.go | 4 +- clientapi/routing/password.go | 147 +++++++++++++----- clientapi/routing/register.go | 51 +++++- clientapi/routing/routing.go | 6 +- clientapi/threepid/threepid.go | 5 +- cmd/dendrite-monolith-server/Dockerfile.dev | 8 + cmd/dendrite-monolith-server/build_dev.sh | 12 ++ go.mod | 2 + go.sum | 5 + internal/httputil/httpapi.go | 51 ++++++ setup/config/config.go | 15 +- setup/config/config_clientapi.go | 19 ++- userapi/api/api.go | 2 +- 37 files changed, 711 insertions(+), 260 deletions(-) create mode 100644 .vscode/launch.json create mode 100644 .vscode/settings.json create mode 100644 clientapi/auth/login_jwt.go create mode 100644 clientapi/ratelimit/rt_failed_login.go create mode 100644 clientapi/ratelimit/rt_failed_login_test.go create mode 100644 cmd/dendrite-monolith-server/Dockerfile.dev create mode 100755 cmd/dendrite-monolith-server/build_dev.sh diff --git a/.github/workflows/dendrite.yml b/.github/workflows/dendrite.yml index 5d60301c7..9a4cec591 100644 --- a/.github/workflows/dendrite.yml +++ b/.github/workflows/dendrite.yml @@ -13,50 +13,6 @@ concurrency: cancel-in-progress: true jobs: - wasm: - name: WASM build test - timeout-minutes: 5 - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v2 - - - name: Install Go - uses: actions/setup-go@v2 - with: - go-version: 1.16 - - - uses: actions/cache@v2 - with: - path: | - ~/.cache/go-build - ~/go/pkg/mod - key: ${{ runner.os }}-go-wasm-${{ hashFiles('**/go.sum') }} - restore-keys: | - ${{ runner.os }}-go-wasm - - - name: Install Node - uses: actions/setup-node@v2 - with: - node-version: 14 - - - uses: actions/cache@v2 - with: - path: ~/.npm - key: ${{ runner.os }}-node-${{ hashFiles('**/package-lock.json') }} - restore-keys: | - ${{ runner.os }}-node- - - - name: Reconfigure Git to use HTTPS auth for repo packages - run: > - git config --global url."https://github.com/".insteadOf - ssh://git@github.com/ - - - name: Install test dependencies - working-directory: ./test/wasm - run: npm ci - - - name: Test - run: ./test-dendritejs.sh # Run golangci-lint lint: @@ -68,7 +24,7 @@ jobs: - name: golangci-lint uses: golangci/golangci-lint-action@v2 - # run go test with different go versions + # run go test with go 1.18 test: timeout-minutes: 5 name: Unit tests (Go ${{ matrix.go }}) @@ -96,7 +52,7 @@ jobs: strategy: fail-fast: false matrix: - go: ["1.16", "1.17", "1.18"] + go: ["1.18"] steps: - uses: actions/checkout@v3 - name: Setup go @@ -118,7 +74,7 @@ jobs: POSTGRES_PASSWORD: postgres POSTGRES_DB: dendrite - # build Dendrite for linux with different architectures and go versions + # build Dendrite for linux amd64 with go 1.18 build: name: Build for Linux timeout-minutes: 10 @@ -126,9 +82,9 @@ jobs: strategy: fail-fast: false matrix: - go: ["1.16", "1.17", "1.18"] + go: ["1.18"] goos: ["linux"] - goarch: ["amd64", "386"] + goarch: ["amd64"] steps: - uses: actions/checkout@v3 - name: Setup go @@ -152,43 +108,10 @@ jobs: CGO_ENABLED: 1 run: go build -trimpath -v -o "bin/" ./cmd/... - # build for Windows 64-bit - build_windows: - name: Build for Windows - timeout-minutes: 10 - runs-on: ubuntu-latest - strategy: - matrix: - go: ["1.16", "1.17", "1.18"] - goos: ["windows"] - goarch: ["amd64"] - steps: - - uses: actions/checkout@v3 - - name: Setup Go ${{ matrix.go }} - uses: actions/setup-go@v2 - with: - go-version: ${{ matrix.go }} - - name: Install dependencies - run: sudo apt update && sudo apt install -y gcc-mingw-w64-x86-64 # install required gcc - - uses: actions/cache@v3 - with: - path: | - ~/.cache/go-build - ~/go/pkg/mod - key: ${{ runner.os }}-go${{ matrix.go }}-${{ matrix.goos }}-${{ hashFiles('**/go.sum') }} - restore-keys: | - ${{ runner.os }}-go${{ matrix.go }}-${{ matrix.goos }} - - env: - GOOS: ${{ matrix.goos }} - GOARCH: ${{ matrix.goarch }} - CGO_ENABLED: 1 - CC: "/usr/bin/x86_64-w64-mingw32-gcc" - run: go build -trimpath -v -o "bin/" ./cmd/... - # Dummy step to gate other tests on without repeating the whole list initial-tests-done: name: Initial tests passed - needs: [lint, test, build, build_windows] + needs: [lint, test, build] runs-on: ubuntu-latest if: ${{ !cancelled() }} # Run this even if prior jobs were skipped steps: diff --git a/.gitignore b/.gitignore index e4f0112c4..63116857a 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ # Hidden files .* +!.vscode # Allow GitHub config !.github @@ -73,3 +74,7 @@ complement/ docs/_site media_store/ + +__debug_bin + +cmd/dendrite-monolith-server/dendrite-monolith-server \ No newline at end of file diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 000000000..6142a8df0 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,16 @@ +{ + "configurations": [ + { + "name": "Launch Package", + "type": "go", + "request": "launch", + "mode": "auto", + "program": "${workspaceFolder}/cmd/dendrite-monolith-server", + "args": [ + "-really-enable-open-registration", + "-config", + "../../../adminas/.ci/config/dendrite-local/dendrite.yaml" + ], + } + ] +} \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 000000000..9012f14ad --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,3 @@ +{ + "go.lintTool": "golangci-lint" +} \ No newline at end of file diff --git a/appservice/appservice.go b/appservice/appservice.go index 8fe1b2fc4..00b2f2759 100644 --- a/appservice/appservice.go +++ b/appservice/appservice.go @@ -70,14 +70,14 @@ func NewInternalAPI( // Wrap application services in a type that relates the application service and // a sync.Cond object that can be used to notify workers when there are new // events to be sent out. - workerStates := make([]types.ApplicationServiceWorkerState, len(base.Cfg.Derived.ApplicationServices)) + workerStates := make([]*types.ApplicationServiceWorkerState, len(base.Cfg.Derived.ApplicationServices)) for i, appservice := range base.Cfg.Derived.ApplicationServices { m := sync.Mutex{} ws := types.ApplicationServiceWorkerState{ AppService: appservice, Cond: sync.NewCond(&m), } - workerStates[i] = ws + workerStates[i] = &ws // Create bot account for this AS if it doesn't already exist if err = generateAppServiceAccount(userAPI, appservice); err != nil { diff --git a/appservice/consumers/roomserver.go b/appservice/consumers/roomserver.go index 37d4ef9c2..a2c56192e 100644 --- a/appservice/consumers/roomserver.go +++ b/appservice/consumers/roomserver.go @@ -39,7 +39,7 @@ type OutputRoomEventConsumer struct { asDB storage.Database rsAPI api.AppserviceRoomserverAPI serverName string - workerStates []types.ApplicationServiceWorkerState + workerStates []*types.ApplicationServiceWorkerState } // NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call @@ -50,7 +50,7 @@ func NewOutputRoomEventConsumer( js nats.JetStreamContext, appserviceDB storage.Database, rsAPI api.AppserviceRoomserverAPI, - workerStates []types.ApplicationServiceWorkerState, + workerStates []*types.ApplicationServiceWorkerState, ) *OutputRoomEventConsumer { return &OutputRoomEventConsumer{ ctx: process.Context(), @@ -140,13 +140,13 @@ func (s *OutputRoomEventConsumer) filterRoomserverEvents( // Check if this event is interesting to this application service if s.appserviceIsInterestedInEvent(ctx, event, ws.AppService) { // Queue this event to be sent off to the application service - if err := s.asDB.StoreEvent(ctx, ws.AppService.ID, event); err != nil { - log.WithError(err).Warn("failed to insert incoming event into appservices database") + if id, err := s.asDB.StoreEvent(ctx, ws.AppService.ID, event); err != nil { + log.WithError(err).Warnf("failed to insert incoming event into appservices database. id: %d", id) return err } else { // Tell our worker to send out new messages by updating remaining message // count and waking them up with a broadcast - ws.NotifyNewEvents() + ws.NotifyNewEvents(id) } } } diff --git a/appservice/storage/interface.go b/appservice/storage/interface.go index 25d35af6c..1da6f6ffa 100644 --- a/appservice/storage/interface.go +++ b/appservice/storage/interface.go @@ -21,9 +21,9 @@ import ( ) type Database interface { - StoreEvent(ctx context.Context, appServiceID string, event *gomatrixserverlib.HeaderedEvent) error - GetEventsWithAppServiceID(ctx context.Context, appServiceID string, limit int) (int, int, []gomatrixserverlib.HeaderedEvent, bool, error) - CountEventsWithAppServiceID(ctx context.Context, appServiceID string) (int, error) + StoreEvent(ctx context.Context, appServiceID string, event *gomatrixserverlib.HeaderedEvent) (int, error) + GetEventsWithAppServiceID(ctx context.Context, appServiceID string, limit int) (int, int, []gomatrixserverlib.HeaderedEvent, error) + GetLatestId(ctx context.Context, appServiceID string) (int, error) UpdateTxnIDForEvents(ctx context.Context, appserviceID string, maxID, txnID int) error RemoveEventsBeforeAndIncludingID(ctx context.Context, appserviceID string, eventTableID int) error GetLatestTxnID(ctx context.Context) (int, error) diff --git a/appservice/storage/postgres/appservice_events_table.go b/appservice/storage/postgres/appservice_events_table.go index a95be6b8a..a4a2513bb 100644 --- a/appservice/storage/postgres/appservice_events_table.go +++ b/appservice/storage/postgres/appservice_events_table.go @@ -45,12 +45,13 @@ const selectEventsByApplicationServiceIDSQL = "" + "SELECT id, headered_event_json, txn_id " + "FROM appservice_events WHERE as_id = $1 ORDER BY txn_id DESC, id ASC" -const countEventsByApplicationServiceIDSQL = "" + - "SELECT COUNT(id) FROM appservice_events WHERE as_id = $1" +const getLatestIdSQL = "" + + "SELECT id FROM appservice_events WHERE as_id = $1 ORDER BY id DESC LIMIT 1" const insertEventSQL = "" + "INSERT INTO appservice_events(as_id, headered_event_json, txn_id) " + - "VALUES ($1, $2, $3)" + "VALUES ($1, $2, $3)" + + "RETURNING id" const updateTxnIDForEventsSQL = "" + "UPDATE appservice_events SET txn_id = $1 WHERE as_id = $2 AND id <= $3" @@ -66,7 +67,7 @@ const ( type eventsStatements struct { selectEventsByApplicationServiceIDStmt *sql.Stmt - countEventsByApplicationServiceIDStmt *sql.Stmt + getLatestIdStmt *sql.Stmt insertEventStmt *sql.Stmt updateTxnIDForEventsStmt *sql.Stmt deleteEventsBeforeAndIncludingIDStmt *sql.Stmt @@ -81,7 +82,7 @@ func (s *eventsStatements) prepare(db *sql.DB) (err error) { if s.selectEventsByApplicationServiceIDStmt, err = db.Prepare(selectEventsByApplicationServiceIDSQL); err != nil { return } - if s.countEventsByApplicationServiceIDStmt, err = db.Prepare(countEventsByApplicationServiceIDSQL); err != nil { + if s.getLatestIdStmt, err = db.Prepare(getLatestIdSQL); err != nil { return } if s.insertEventStmt, err = db.Prepare(insertEventSQL); err != nil { @@ -108,7 +109,6 @@ func (s *eventsStatements) selectEventsByApplicationServiceID( ) ( txnID, maxID int, events []gomatrixserverlib.HeaderedEvent, - eventsRemaining bool, err error, ) { defer func() { @@ -124,7 +124,7 @@ func (s *eventsStatements) selectEventsByApplicationServiceID( return } defer checkNamedErr(eventRows.Close, &err) - events, maxID, txnID, eventsRemaining, err = retrieveEvents(eventRows, limit) + events, maxID, txnID, err = retrieveEvents(eventRows, limit) if err != nil { return } @@ -139,7 +139,7 @@ func checkNamedErr(fn func() error, err *error) { } } -func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.HeaderedEvent, maxID, txnID int, eventsRemaining bool, err error) { +func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.HeaderedEvent, maxID, txnID int, err error) { // Get current time for use in calculating event age nowMilli := time.Now().UnixNano() / int64(time.Millisecond) @@ -157,18 +157,18 @@ func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib. &txnID, ) if err != nil { - return nil, 0, 0, false, err + return nil, 0, 0, err } // Unmarshal eventJSON if err = json.Unmarshal(eventJSON, &event); err != nil { - return nil, 0, 0, false, err + return nil, 0, 0, err } // If txnID has changed on this event from the previous event, then we've // reached the end of a transaction's events. Return only those events. if lastTxnID > invalidTxnID && lastTxnID != txnID { - return events, maxID, lastTxnID, true, nil + return events, maxID, lastTxnID, nil } lastTxnID = txnID @@ -176,7 +176,7 @@ func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib. if txnID == -1 { // Return if we've hit the limit if eventsProcessed++; eventsProcessed > limit { - return events, maxID, lastTxnID, true, nil + return events, maxID, lastTxnID, nil } } @@ -187,7 +187,7 @@ func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib. // Portion of the event that is unsigned due to rapid change // TODO: Consider removing age as not many app services use it if err = event.SetUnsignedField("age", nowMilli-int64(event.OriginServerTS())); err != nil { - return nil, 0, 0, false, err + return nil, 0, 0, err } events = append(events, event) @@ -196,14 +196,12 @@ func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib. return } -// countEventsByApplicationServiceID inserts an event mapped to its corresponding application service -// IDs into the db. -func (s *eventsStatements) countEventsByApplicationServiceID( +func (s *eventsStatements) getLatestId( ctx context.Context, appServiceID string, ) (int, error) { var count int - err := s.countEventsByApplicationServiceIDStmt.QueryRowContext(ctx, appServiceID).Scan(&count) + err := s.getLatestIdStmt.QueryRowContext(ctx, appServiceID).Scan(&count) if err != nil && err != sql.ErrNoRows { return 0, err } @@ -217,19 +215,19 @@ func (s *eventsStatements) insertEvent( ctx context.Context, appServiceID string, event *gomatrixserverlib.HeaderedEvent, -) (err error) { +) (id int, err error) { // Convert event to JSON before inserting - eventJSON, err := json.Marshal(event) + var eventJSON []byte + eventJSON, err = json.Marshal(event) if err != nil { - return err + return 0, err } - - _, err = s.insertEventStmt.ExecContext( + err = s.insertEventStmt.QueryRowContext( ctx, appServiceID, eventJSON, -1, // No transaction ID yet - ) + ).Scan(&id) return } diff --git a/appservice/storage/postgres/storage.go b/appservice/storage/postgres/storage.go index a4c04b2cc..12010522a 100644 --- a/appservice/storage/postgres/storage.go +++ b/appservice/storage/postgres/storage.go @@ -62,7 +62,7 @@ func (d *Database) StoreEvent( ctx context.Context, appServiceID string, event *gomatrixserverlib.HeaderedEvent, -) error { +) (int, error) { return d.events.insertEvent(ctx, appServiceID, event) } @@ -72,17 +72,20 @@ func (d *Database) GetEventsWithAppServiceID( ctx context.Context, appServiceID string, limit int, -) (int, int, []gomatrixserverlib.HeaderedEvent, bool, error) { +) (int, int, []gomatrixserverlib.HeaderedEvent, error) { return d.events.selectEventsByApplicationServiceID(ctx, appServiceID, limit) } -// CountEventsWithAppServiceID returns the number of events destined for an -// application service given its ID. -func (d *Database) CountEventsWithAppServiceID( +// GetLatestId returns the latest incremental id associated with appservice. +func (d *Database) GetLatestId( ctx context.Context, appServiceID string, ) (int, error) { - return d.events.countEventsByApplicationServiceID(ctx, appServiceID) + id, err := d.events.getLatestId(ctx, appServiceID) + if err == sql.ErrNoRows { + return 0, nil + } + return id, err } // UpdateTxnIDForEvents takes in an application service ID and a diff --git a/appservice/storage/sqlite3/appservice_events_table.go b/appservice/storage/sqlite3/appservice_events_table.go index 34b4859ea..c2d265df5 100644 --- a/appservice/storage/sqlite3/appservice_events_table.go +++ b/appservice/storage/sqlite3/appservice_events_table.go @@ -46,12 +46,13 @@ const selectEventsByApplicationServiceIDSQL = "" + "SELECT id, headered_event_json, txn_id " + "FROM appservice_events WHERE as_id = $1 ORDER BY txn_id DESC, id ASC" -const countEventsByApplicationServiceIDSQL = "" + - "SELECT COUNT(id) FROM appservice_events WHERE as_id = $1" +const getLatestIdSQL = "" + + "SELECT id FROM appservice_events WHERE as_id = $1 ORDER BY id DESC LIMIT 1" const insertEventSQL = "" + "INSERT INTO appservice_events(as_id, headered_event_json, txn_id) " + - "VALUES ($1, $2, $3)" + "VALUES ($1, $2, $3)" + + "RETURNING id" const updateTxnIDForEventsSQL = "" + "UPDATE appservice_events SET txn_id = $1 WHERE as_id = $2 AND id <= $3" @@ -69,7 +70,7 @@ type eventsStatements struct { db *sql.DB writer sqlutil.Writer selectEventsByApplicationServiceIDStmt *sql.Stmt - countEventsByApplicationServiceIDStmt *sql.Stmt + getLatestIdStmt *sql.Stmt insertEventStmt *sql.Stmt updateTxnIDForEventsStmt *sql.Stmt deleteEventsBeforeAndIncludingIDStmt *sql.Stmt @@ -86,7 +87,7 @@ func (s *eventsStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error if s.selectEventsByApplicationServiceIDStmt, err = db.Prepare(selectEventsByApplicationServiceIDSQL); err != nil { return } - if s.countEventsByApplicationServiceIDStmt, err = db.Prepare(countEventsByApplicationServiceIDSQL); err != nil { + if s.getLatestIdStmt, err = db.Prepare(getLatestIdSQL); err != nil { return } if s.insertEventStmt, err = db.Prepare(insertEventSQL); err != nil { @@ -113,7 +114,6 @@ func (s *eventsStatements) selectEventsByApplicationServiceID( ) ( txnID, maxID int, events []gomatrixserverlib.HeaderedEvent, - eventsRemaining bool, err error, ) { defer func() { @@ -129,7 +129,7 @@ func (s *eventsStatements) selectEventsByApplicationServiceID( return } defer checkNamedErr(eventRows.Close, &err) - events, maxID, txnID, eventsRemaining, err = retrieveEvents(eventRows, limit) + events, maxID, txnID, err = retrieveEvents(eventRows, limit) if err != nil { return } @@ -144,7 +144,7 @@ func checkNamedErr(fn func() error, err *error) { } } -func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.HeaderedEvent, maxID, txnID int, eventsRemaining bool, err error) { +func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.HeaderedEvent, maxID, txnID int, err error) { // Get current time for use in calculating event age nowMilli := time.Now().UnixNano() / int64(time.Millisecond) @@ -162,18 +162,18 @@ func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib. &txnID, ) if err != nil { - return nil, 0, 0, false, err + return nil, 0, 0, err } // Unmarshal eventJSON if err = json.Unmarshal(eventJSON, &event); err != nil { - return nil, 0, 0, false, err + return nil, 0, 0, err } // If txnID has changed on this event from the previous event, then we've // reached the end of a transaction's events. Return only those events. if lastTxnID > invalidTxnID && lastTxnID != txnID { - return events, maxID, lastTxnID, true, nil + return events, maxID, lastTxnID, nil } lastTxnID = txnID @@ -181,7 +181,7 @@ func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib. if txnID == -1 { // Return if we've hit the limit if eventsProcessed++; eventsProcessed > limit { - return events, maxID, lastTxnID, true, nil + return events, maxID, lastTxnID, nil } } @@ -192,7 +192,7 @@ func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib. // Portion of the event that is unsigned due to rapid change // TODO: Consider removing age as not many app services use it if err = event.SetUnsignedField("age", nowMilli-int64(event.OriginServerTS())); err != nil { - return nil, 0, 0, false, err + return nil, 0, 0, err } events = append(events, event) @@ -201,14 +201,12 @@ func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib. return } -// countEventsByApplicationServiceID inserts an event mapped to its corresponding application service -// IDs into the db. -func (s *eventsStatements) countEventsByApplicationServiceID( +func (s *eventsStatements) getLatestId( ctx context.Context, appServiceID string, ) (int, error) { var count int - err := s.countEventsByApplicationServiceIDStmt.QueryRowContext(ctx, appServiceID).Scan(&count) + err := s.getLatestIdStmt.QueryRowContext(ctx, appServiceID).Scan(&count) if err != nil && err != sql.ErrNoRows { return 0, err } @@ -222,22 +220,22 @@ func (s *eventsStatements) insertEvent( ctx context.Context, appServiceID string, event *gomatrixserverlib.HeaderedEvent, -) (err error) { +) (id int, err error) { // Convert event to JSON before inserting eventJSON, err := json.Marshal(event) if err != nil { - return err + return 0, err } - - return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { - _, err := s.insertEventStmt.ExecContext( + err = s.writer.Do(s.db, nil, func(txn *sql.Tx) error { + err = s.insertEventStmt.QueryRowContext( ctx, appServiceID, eventJSON, -1, // No transaction ID yet - ) + ).Scan(&id) return err }) + return } // updateTxnIDForEvents sets the transactionID for a collection of events. Done diff --git a/appservice/storage/sqlite3/storage.go b/appservice/storage/sqlite3/storage.go index ad62b3628..ad1b7e192 100644 --- a/appservice/storage/sqlite3/storage.go +++ b/appservice/storage/sqlite3/storage.go @@ -61,7 +61,7 @@ func (d *Database) StoreEvent( ctx context.Context, appServiceID string, event *gomatrixserverlib.HeaderedEvent, -) error { +) (int, error) { return d.events.insertEvent(ctx, appServiceID, event) } @@ -71,17 +71,20 @@ func (d *Database) GetEventsWithAppServiceID( ctx context.Context, appServiceID string, limit int, -) (int, int, []gomatrixserverlib.HeaderedEvent, bool, error) { +) (int, int, []gomatrixserverlib.HeaderedEvent, error) { return d.events.selectEventsByApplicationServiceID(ctx, appServiceID, limit) } -// CountEventsWithAppServiceID returns the number of events destined for an -// application service given its ID. -func (d *Database) CountEventsWithAppServiceID( +// GetLatestId returns the latest incremental id associated with appservice. +func (d *Database) GetLatestId( ctx context.Context, appServiceID string, ) (int, error) { - return d.events.countEventsByApplicationServiceID(ctx, appServiceID) + id, err := d.events.getLatestId(ctx, appServiceID) + if err == sql.ErrNoRows { + return 0, nil + } + return id, err } // UpdateTxnIDForEvents takes in an application service ID and a diff --git a/appservice/types/types.go b/appservice/types/types.go index 098face62..50b18d998 100644 --- a/appservice/types/types.go +++ b/appservice/types/types.go @@ -30,34 +30,26 @@ const ( type ApplicationServiceWorkerState struct { AppService config.ApplicationService Cond *sync.Cond - // Events ready to be sent - EventsReady bool + // Lastest incremental ID from appservice_events table that is ready to be sent to application service + latestId int // Backoff exponent (2^x secs). Max 6, aka 64s. Backoff int } // NotifyNewEvents wakes up all waiting goroutines, notifying that events remain // in the event queue for this application service worker. -func (a *ApplicationServiceWorkerState) NotifyNewEvents() { +func (a *ApplicationServiceWorkerState) NotifyNewEvents(id int) { a.Cond.L.Lock() - a.EventsReady = true + a.latestId = id a.Cond.Broadcast() a.Cond.L.Unlock() } -// FinishEventProcessing marks all events of this worker as being sent to the -// application service. -func (a *ApplicationServiceWorkerState) FinishEventProcessing() { - a.Cond.L.Lock() - a.EventsReady = false - a.Cond.L.Unlock() -} - // WaitForNewEvents causes the calling goroutine to wait on the worker state's // condition for a broadcast or similar wakeup, if there are no events ready. -func (a *ApplicationServiceWorkerState) WaitForNewEvents() { +func (a *ApplicationServiceWorkerState) WaitForNewEvents(id int) { a.Cond.L.Lock() - if !a.EventsReady { + if a.latestId <= id { a.Cond.Wait() } a.Cond.L.Unlock() diff --git a/appservice/workers/transaction_scheduler.go b/appservice/workers/transaction_scheduler.go index 4dab00bd7..132b6f012 100644 --- a/appservice/workers/transaction_scheduler.go +++ b/appservice/workers/transaction_scheduler.go @@ -44,7 +44,7 @@ var ( func SetupTransactionWorkers( client *http.Client, appserviceDB storage.Database, - workerStates []types.ApplicationServiceWorkerState, + workerStates []*types.ApplicationServiceWorkerState, ) error { // Create a worker that handles transmitting events to a single homeserver for _, workerState := range workerStates { @@ -58,31 +58,29 @@ func SetupTransactionWorkers( // worker is a goroutine that sends any queued events to the application service // it is given. -func worker(client *http.Client, db storage.Database, ws types.ApplicationServiceWorkerState) { +func worker(client *http.Client, db storage.Database, ws *types.ApplicationServiceWorkerState) { log.WithFields(log.Fields{ "appservice": ws.AppService.ID, }).Info("Starting application service") ctx := context.Background() // Initial check for any leftover events to send from last time - eventCount, err := db.CountEventsWithAppServiceID(ctx, ws.AppService.ID) + latestId, err := db.GetLatestId(ctx, ws.AppService.ID) if err != nil { log.WithFields(log.Fields{ "appservice": ws.AppService.ID, }).WithError(err).Fatal("appservice worker unable to read queued events from DB") return } - if eventCount > 0 { - ws.NotifyNewEvents() - } - + ws.NotifyNewEvents(latestId) + id := 0 // Loop forever and keep waiting for more events to send for { // Wait for more events if we've sent all the events in the database - ws.WaitForNewEvents() + ws.WaitForNewEvents(id) // Batch events up into a transaction - transactionJSON, txnID, maxEventID, eventsRemaining, err := createTransaction(ctx, db, ws.AppService.ID) + transactionJSON, txnID, maxEventID, err := createTransaction(ctx, db, ws.AppService.ID) if err != nil { log.WithFields(log.Fields{ "appservice": ws.AppService.ID, @@ -90,6 +88,10 @@ func worker(client *http.Client, db storage.Database, ws types.ApplicationServic return } + // Transactions have a maximum event size (or new events may arrive while + // transaction is processed by Application Service), so there may still be + // some events left over to send. We will keep sending if id < ws.latestID. + id = maxEventID // Send the events off to the application service // Backoff if the application service does not respond @@ -99,19 +101,13 @@ func worker(client *http.Client, db storage.Database, ws types.ApplicationServic "appservice": ws.AppService.ID, }).WithError(err).Error("unable to send event") // Backoff - backoff(&ws, err) + backoff(ws, err) continue } // We sent successfully, hooray! ws.Backoff = 0 - // Transactions have a maximum event size, so there may still be some events - // left over to send. Keep sending until none are left - if !eventsRemaining { - ws.FinishEventProcessing() - } - // Remove sent events from the DB err = db.RemoveEventsBeforeAndIncludingID(ctx, ws.AppService.ID, maxEventID) if err != nil { @@ -152,11 +148,10 @@ func createTransaction( ) ( transactionJSON []byte, txnID, maxID int, - eventsRemaining bool, err error, ) { // Retrieve the latest events from the DB (will return old events if they weren't successfully sent) - txnID, maxID, events, eventsRemaining, err := db.GetEventsWithAppServiceID(ctx, appserviceID, transactionBatchSize) + txnID, maxID, events, err := db.GetEventsWithAppServiceID(ctx, appserviceID, transactionBatchSize) if err != nil { log.WithFields(log.Fields{ "appservice": appserviceID, @@ -170,12 +165,12 @@ func createTransaction( // If not, grab next available ID from the DB txnID, err = db.GetLatestTxnID(ctx) if err != nil { - return nil, 0, 0, false, err + return nil, 0, 0, err } // Mark new events with current transactionID if err = db.UpdateTxnIDForEvents(ctx, appserviceID, maxID, txnID); err != nil { - return nil, 0, 0, false, err + return nil, 0, 0, err } } diff --git a/clientapi/auth/authtypes/logintypes.go b/clientapi/auth/authtypes/logintypes.go index f01e48f80..00253fede 100644 --- a/clientapi/auth/authtypes/logintypes.go +++ b/clientapi/auth/authtypes/logintypes.go @@ -11,4 +11,6 @@ const ( LoginTypeRecaptcha = "m.login.recaptcha" LoginTypeApplicationService = "m.login.application_service" LoginTypeToken = "m.login.token" + LoginTypeJwt = "org.matrix.login.jwt" + LoginTypeEmail = "m.login.email.identity" ) diff --git a/clientapi/auth/login.go b/clientapi/auth/login.go index 5f51c662a..de5473ad3 100644 --- a/clientapi/auth/login.go +++ b/clientapi/auth/login.go @@ -23,6 +23,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/clientapi/ratelimit" "github.com/matrix-org/dendrite/setup/config" uapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/util" @@ -33,7 +34,7 @@ import ( // called after authorization has completed, with the result of the authorization. // If the final return value is non-nil, an error occurred and the cleanup function // is nil. -func LoginFromJSONReader(ctx context.Context, r io.Reader, useraccountAPI uapi.UserLoginAPI, userAPI UserInternalAPIForLogin, cfg *config.ClientAPI) (*Login, LoginCleanupFunc, *util.JSONResponse) { +func LoginFromJSONReader(ctx context.Context, r io.Reader, useraccountAPI uapi.ClientUserAPI, cfg *config.ClientAPI, rt *ratelimit.RtFailedLogin) (*Login, LoginCleanupFunc, *util.JSONResponse) { reqBytes, err := ioutil.ReadAll(r) if err != nil { err := &util.JSONResponse{ @@ -58,14 +59,19 @@ func LoginFromJSONReader(ctx context.Context, r io.Reader, useraccountAPI uapi.U switch header.Type { case authtypes.LoginTypePassword: typ = &LoginTypePassword{ - GetAccountByPassword: useraccountAPI.QueryAccountByPassword, - Config: cfg, + UserApi: useraccountAPI, + Config: cfg, + Rt: rt, } case authtypes.LoginTypeToken: typ = &LoginTypeToken{ - UserAPI: userAPI, + UserAPI: useraccountAPI, Config: cfg, } + case authtypes.LoginTypeJwt: + typ = &LoginTypeTokenJwt{ + Config: cfg, + } default: err := util.JSONResponse{ Code: http.StatusBadRequest, diff --git a/clientapi/auth/login_jwt.go b/clientapi/auth/login_jwt.go new file mode 100644 index 000000000..35c7d1948 --- /dev/null +++ b/clientapi/auth/login_jwt.go @@ -0,0 +1,74 @@ +package auth + +import ( + "context" + "fmt" + "net/http" + + "github.com/golang-jwt/jwt/v4" + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/clientapi/httputil" + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/util" +) + +// LoginTypeToken describes how to authenticate with a login token. +type LoginTypeTokenJwt struct { + // UserAPI uapi.LoginTokenInternalAPI + Config *config.ClientAPI +} + +// Name implements Type. +func (t *LoginTypeTokenJwt) Name() string { + return authtypes.LoginTypeJwt +} + +type Claims struct { + jwt.StandardClaims +} + +const mIdUser = "m.id.user" + +// LoginFromJSON implements Type. The cleanup function deletes the token from +// the database on success. +func (t *LoginTypeTokenJwt) LoginFromJSON(ctx context.Context, reqBytes []byte) (*Login, LoginCleanupFunc, *util.JSONResponse) { + var r loginTokenRequest + if err := httputil.UnmarshalJSON(reqBytes, &r); err != nil { + return nil, nil, err + } + + if r.Token == "" { + return nil, nil, &util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("Token field for JWT is missing"), + } + } + c := &Claims{} + token, err := jwt.ParseWithClaims(r.Token, c, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodEd25519); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Method.Alg()) + } + return t.Config.JwtConfig.SecretKey, nil + }) + + if err != nil { + util.GetLogger(ctx).WithError(err).Error("jwt.ParseWithClaims failed") + return nil, nil, &util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("Couldn't parse JWT"), + } + } + + if !token.Valid { + return nil, nil, &util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("Invalid JWT"), + } + } + + r.Login.Identifier.User = c.Subject + r.Login.Identifier.Type = mIdUser + + return &r.Login, func(context.Context, *util.JSONResponse) {}, nil +} diff --git a/clientapi/auth/login_test.go b/clientapi/auth/login_test.go index 5085f0170..4017c26d5 100644 --- a/clientapi/auth/login_test.go +++ b/clientapi/auth/login_test.go @@ -22,6 +22,7 @@ import ( "testing" "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/clientapi/ratelimit" "github.com/matrix-org/dendrite/setup/config" uapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/util" @@ -68,8 +69,11 @@ func TestLoginFromJSONReader(t *testing.T) { Matrix: &config.Global{ ServerName: serverName, }, + RtFailedLogin: ratelimit.RtFailedLoginConfig{ + Enabled: false, + }, } - login, cleanup, err := LoginFromJSONReader(ctx, strings.NewReader(tst.Body), &userAPI, &userAPI, cfg) + login, cleanup, err := LoginFromJSONReader(ctx, strings.NewReader(tst.Body), &userAPI, cfg, nil) if err != nil { t.Fatalf("LoginFromJSONReader failed: %+v", err) } @@ -147,7 +151,7 @@ func TestBadLoginFromJSONReader(t *testing.T) { ServerName: serverName, }, } - _, cleanup, errRes := LoginFromJSONReader(ctx, strings.NewReader(tst.Body), &userAPI, &userAPI, cfg) + _, cleanup, errRes := LoginFromJSONReader(ctx, strings.NewReader(tst.Body), &userAPI, cfg, nil) if errRes == nil { cleanup(ctx, nil) t.Fatalf("LoginFromJSONReader err: got %+v, want code %q", errRes, tst.WantErrCode) @@ -159,6 +163,7 @@ func TestBadLoginFromJSONReader(t *testing.T) { } type fakeUserInternalAPI struct { + uapi.ClientUserAPI UserInternalAPIForLogin DeletedTokens []string } diff --git a/clientapi/auth/login_token.go b/clientapi/auth/login_token.go index 845eb5de9..293b9a460 100644 --- a/clientapi/auth/login_token.go +++ b/clientapi/auth/login_token.go @@ -58,7 +58,7 @@ func (t *LoginTypeToken) LoginFromJSON(ctx context.Context, reqBytes []byte) (*L } } - r.Login.Identifier.Type = "m.id.user" + r.Login.Identifier.Type = mIdUser r.Login.Identifier.User = res.Data.UserID cleanup := func(ctx context.Context, authRes *util.JSONResponse) { diff --git a/clientapi/auth/password.go b/clientapi/auth/password.go index bcb4ca97b..54019a8a8 100644 --- a/clientapi/auth/password.go +++ b/clientapi/auth/password.go @@ -22,6 +22,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/clientapi/ratelimit" "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" @@ -33,12 +34,17 @@ type GetAccountByPassword func(ctx context.Context, req *api.QueryAccountByPassw type PasswordRequest struct { Login Password string `json:"password"` + Address string `json:"address"` + Medium string `json:"medium"` } +const email = "email" + // LoginTypePassword implements https://matrix.org/docs/spec/client_server/r0.6.1#password-based type LoginTypePassword struct { - GetAccountByPassword GetAccountByPassword - Config *config.ClientAPI + UserApi api.ClientUserAPI + Config *config.ClientAPI + Rt *ratelimit.RtFailedLogin } func (t *LoginTypePassword) Name() string { @@ -61,7 +67,35 @@ func (t *LoginTypePassword) LoginFromJSON(ctx context.Context, reqBytes []byte) func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, *util.JSONResponse) { r := req.(*PasswordRequest) - username := strings.ToLower(r.Username()) + if r.Identifier.Address != "" { + r.Address = r.Identifier.Address + } + if r.Identifier.Medium != "" { + r.Medium = r.Identifier.Medium + } + var username string + if r.Medium == email && r.Address != "" { + r.Address = strings.ToLower(r.Address) + res := api.QueryLocalpartForThreePIDResponse{} + err := t.UserApi.QueryLocalpartForThreePID(ctx, &api.QueryLocalpartForThreePIDRequest{ + ThreePID: r.Address, + Medium: email, + }, &res) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("userApi.QueryLocalpartForThreePID failed") + resp := jsonerror.InternalServerError() + return nil, &resp + } + username = res.Localpart + if username == "" { + return nil, &util.JSONResponse{ + Code: http.StatusUnauthorized, + JSON: jsonerror.Forbidden("Invalid username or password"), + } + } + } else { + username = strings.ToLower(r.Username()) + } if username == "" { return nil, &util.JSONResponse{ Code: http.StatusUnauthorized, @@ -77,7 +111,17 @@ func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, } // Squash username to all lowercase letters res := &api.QueryAccountByPasswordResponse{} - err = t.GetAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{Localpart: strings.ToLower(localpart), PlaintextPassword: r.Password}, res) + localpart = strings.ToLower(localpart) + if t.Rt != nil { + ok, retryIn := t.Rt.CanAct(localpart) + if !ok { + return nil, &util.JSONResponse{ + Code: http.StatusTooManyRequests, + JSON: jsonerror.LimitExceeded("Too Many Requests", retryIn.Milliseconds()), + } + } + } + err = t.UserApi.QueryAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{Localpart: localpart, PlaintextPassword: r.Password}, res) if err != nil { return nil, &util.JSONResponse{ Code: http.StatusInternalServerError, @@ -86,7 +130,7 @@ func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, } if !res.Exists { - err = t.GetAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{ + err = t.UserApi.QueryAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{ Localpart: localpart, PlaintextPassword: r.Password, }, res) @@ -99,11 +143,15 @@ func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, // Technically we could tell them if the user does not exist by checking if err == sql.ErrNoRows // but that would leak the existence of the user. if !res.Exists { + if t.Rt != nil { + t.Rt.Act(localpart) + } return nil, &util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("The username or password was incorrect or the account does not exist."), + JSON: jsonerror.Forbidden("Invalid username or password"), } } } + r.Login.User = username return &r.Login, nil } diff --git a/clientapi/auth/user_interactive.go b/clientapi/auth/user_interactive.go index 82ecf674c..318f64aa0 100644 --- a/clientapi/auth/user_interactive.go +++ b/clientapi/auth/user_interactive.go @@ -74,7 +74,7 @@ type Login struct { // Username returns the user localpart/user_id in this request, if it exists. func (r *Login) Username() string { - if r.Identifier.Type == "m.id.user" { + if r.Identifier.Type == mIdUser { return r.Identifier.User } // deprecated but without it Element iOS won't log in @@ -87,8 +87,8 @@ func (r *Login) ThirdPartyID() (medium, address string) { return r.Identifier.Medium, r.Identifier.Address } // deprecated - if r.Medium == "email" { - return "email", r.Address + if r.Medium == email { + return email, r.Address } return "", "" } @@ -109,10 +109,10 @@ type UserInteractive struct { Sessions map[string][]string } -func NewUserInteractive(userAccountAPI api.UserLoginAPI, cfg *config.ClientAPI) *UserInteractive { +func NewUserInteractive(userAccountAPI api.ClientUserAPI, cfg *config.ClientAPI) *UserInteractive { typePassword := &LoginTypePassword{ - GetAccountByPassword: userAccountAPI.QueryAccountByPassword, - Config: cfg, + UserApi: userAccountAPI, + Config: cfg, } return &UserInteractive{ Flows: []userInteractiveFlow{ diff --git a/clientapi/auth/user_interactive_test.go b/clientapi/auth/user_interactive_test.go index 001b1a6d4..8267d2222 100644 --- a/clientapi/auth/user_interactive_test.go +++ b/clientapi/auth/user_interactive_test.go @@ -24,7 +24,9 @@ var ( } ) -type fakeAccountDatabase struct{} +type fakeAccountDatabase struct { + api.ClientUserAPI +} func (d *fakeAccountDatabase) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error { return nil diff --git a/clientapi/ratelimit/rt_failed_login.go b/clientapi/ratelimit/rt_failed_login.go new file mode 100644 index 000000000..291af581d --- /dev/null +++ b/clientapi/ratelimit/rt_failed_login.go @@ -0,0 +1,117 @@ +package ratelimit + +import ( + "container/list" + "sync" + "time" +) + +type rateLimit struct { + cfg *RtFailedLoginConfig + times *list.List +} + +type RtFailedLogin struct { + cfg *RtFailedLoginConfig + mtx sync.RWMutex + rts map[string]*rateLimit +} + +type RtFailedLoginConfig struct { + Enabled bool `yaml:"enabled"` + Limit int `yaml:"burst"` + Interval time.Duration `yaml:"interval"` +} + +// New creates a new rate limiter for the limit and interval. +func NewRtFailedLogin(cfg *RtFailedLoginConfig) *RtFailedLogin { + if !cfg.Enabled { + return nil + } + rt := &RtFailedLogin{ + cfg: cfg, + mtx: sync.RWMutex{}, + rts: make(map[string]*rateLimit), + } + go rt.clean() + return rt +} + +// CanAct is expected to be called before Act +func (r *RtFailedLogin) CanAct(key string) (ok bool, remaining time.Duration) { + r.mtx.RLock() + rt, ok := r.rts[key] + if !ok { + r.mtx.RUnlock() + return true, 0 + } + ok, remaining = rt.canAct() + r.mtx.RUnlock() + return +} + +// Act can be called after CanAct returns true. +func (r *RtFailedLogin) Act(key string) { + r.mtx.Lock() + rt, ok := r.rts[key] + if !ok { + rt = &rateLimit{ + cfg: r.cfg, + times: list.New(), + } + r.rts[key] = rt + } + rt.act() + r.mtx.Unlock() +} + +func (r *RtFailedLogin) clean() { + for { + r.mtx.Lock() + for k, v := range r.rts { + if v.empty() { + delete(r.rts, k) + } + } + r.mtx.Unlock() + time.Sleep(time.Hour) + } +} + +func (r *rateLimit) empty() bool { + back := r.times.Back() + if back == nil { + return true + } + v := back.Value + b := v.(time.Time) + now := time.Now() + return now.Sub(b) > r.cfg.Interval +} + +func (r *rateLimit) canAct() (ok bool, remaining time.Duration) { + now := time.Now() + l := r.times.Len() + if l < r.cfg.Limit { + return true, 0 + } + frnt := r.times.Front() + t := frnt.Value.(time.Time) + diff := now.Sub(t) + if diff < r.cfg.Interval { + return false, r.cfg.Interval - diff + } + return true, 0 +} + +func (r *rateLimit) act() { + now := time.Now() + l := r.times.Len() + if l < r.cfg.Limit { + r.times.PushBack(now) + return + } + frnt := r.times.Front() + frnt.Value = now + r.times.MoveToBack(frnt) +} diff --git a/clientapi/ratelimit/rt_failed_login_test.go b/clientapi/ratelimit/rt_failed_login_test.go new file mode 100644 index 000000000..5281bc765 --- /dev/null +++ b/clientapi/ratelimit/rt_failed_login_test.go @@ -0,0 +1,40 @@ +package ratelimit + +import ( + "testing" + "time" + + "github.com/matryer/is" +) + +func TestRtFailedLogin(t *testing.T) { + is := is.New(t) + rtfl := NewRtFailedLogin(&RtFailedLoginConfig{ + Enabled: true, + Limit: 3, + Interval: 10 * time.Millisecond, + }) + var ( + can bool + remaining time.Duration + remainingB time.Duration + ) + for i := 0; i < 3; i++ { + can, remaining = rtfl.CanAct("foo") + is.True(can) + is.Equal(remaining, time.Duration(0)) + rtfl.Act("foo") + } + can, remaining = rtfl.CanAct("foo") + is.True(!can) + is.True(remaining > time.Millisecond*9) + can, remainingB = rtfl.CanAct("bar") + is.True(can) + is.Equal(remainingB, time.Duration(0)) + rtfl.Act("bar") + rtfl.Act("bar") + time.Sleep(remaining + time.Millisecond) + can, remaining = rtfl.CanAct("foo") + is.True(can) + is.Equal(remaining, time.Duration(0)) +} diff --git a/clientapi/routing/key_crosssigning.go b/clientapi/routing/key_crosssigning.go index 8fbb86f7a..1def3c4d3 100644 --- a/clientapi/routing/key_crosssigning.go +++ b/clientapi/routing/key_crosssigning.go @@ -63,8 +63,8 @@ func UploadCrossSigningDeviceKeys( } } typePassword := auth.LoginTypePassword{ - GetAccountByPassword: accountAPI.QueryAccountByPassword, - Config: cfg, + UserApi: accountAPI, + Config: cfg, } if _, authErr := typePassword.Login(req.Context(), &uploadReq.Auth.PasswordRequest); authErr != nil { return *authErr diff --git a/clientapi/routing/login.go b/clientapi/routing/login.go index 6017b5840..cae24ce96 100644 --- a/clientapi/routing/login.go +++ b/clientapi/routing/login.go @@ -20,6 +20,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth" "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/clientapi/ratelimit" "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" @@ -55,6 +56,7 @@ func passwordLogin() flows { func Login( req *http.Request, userAPI userapi.ClientUserAPI, cfg *config.ClientAPI, + rt *ratelimit.RtFailedLogin, ) util.JSONResponse { if req.Method == http.MethodGet { // TODO: support other forms of login other than password, depending on config options @@ -63,7 +65,7 @@ func Login( JSON: passwordLogin(), } } else if req.Method == http.MethodPost { - login, cleanup, authErr := auth.LoginFromJSONReader(req.Context(), req.Body, userAPI, userAPI, cfg) + login, cleanup, authErr := auth.LoginFromJSONReader(req.Context(), req.Body, userAPI, cfg, rt) if authErr != nil { return *authErr } diff --git a/clientapi/routing/password.go b/clientapi/routing/password.go index 6dc9af508..44ca153f2 100644 --- a/clientapi/routing/password.go +++ b/clientapi/routing/password.go @@ -1,12 +1,14 @@ package routing import ( + "fmt" "net/http" "github.com/matrix-org/dendrite/clientapi/auth" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/clientapi/threepid" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" @@ -24,6 +26,7 @@ type newPasswordAuth struct { Type string `json:"type"` Session string `json:"session"` auth.PasswordRequest + ThreePidCreds threepid.Credentials `json:"threepid_creds"` } func Password( @@ -33,13 +36,17 @@ func Password( cfg *config.ClientAPI, ) util.JSONResponse { // Check that the existing password is right. + var fields logrus.Fields + if device != nil { + fields = logrus.Fields{ + "sessionId": device.SessionID, + "userId": device.UserID, + } + } var r newPasswordRequest r.LogoutDevices = true - logrus.WithFields(logrus.Fields{ - "sessionId": device.SessionID, - "userId": device.UserID, - }).Debug("Changing password") + logrus.WithFields(fields).Debug("Changing password") // Unmarshal the request. resErr := httputil.UnmarshalJSONRequest(req, &r) @@ -53,45 +60,95 @@ func Password( // Generate a new, random session ID sessionID = util.RandomString(sessionIDLength) } - - // Require password auth to change the password. - if r.Auth.Type != authtypes.LoginTypePassword { - return util.JSONResponse{ - Code: http.StatusUnauthorized, - JSON: newUserInteractiveResponse( - sessionID, - []authtypes.Flow{ - { - Stages: []authtypes.LoginType{authtypes.LoginTypePassword}, - }, + var localpart string + switch r.Auth.Type { + case authtypes.LoginTypePassword: + // Check if the existing password is correct. + typePassword := auth.LoginTypePassword{ + UserApi: userAPI, + Config: cfg, + } + if _, authErr := typePassword.Login(req.Context(), &r.Auth.PasswordRequest); authErr != nil { + return *authErr + } + // Get the local part. + var err error + localpart, _, err = gomatrixserverlib.SplitID('@', device.UserID) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") + return jsonerror.InternalServerError() + } + sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword) + case authtypes.LoginTypeEmail: + threePid := &authtypes.ThreePID{} + r.Auth.ThreePidCreds.IDServer = cfg.ThreePidDelegate + var ( + bound bool + err error + ) + bound, threePid.Address, threePid.Medium, err = threepid.CheckAssociation(req.Context(), r.Auth.ThreePidCreds, cfg) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("threepid.CheckAssociation failed") + return jsonerror.InternalServerError() + } + if !bound { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.MatrixError{ + ErrCode: "M_THREEPID_AUTH_FAILED", + Err: "Failed to auth 3pid", }, - nil, - ), + } + } + var res api.QueryLocalpartForThreePIDResponse + err = userAPI.QueryLocalpartForThreePID(req.Context(), &api.QueryLocalpartForThreePIDRequest{ + Medium: threePid.Medium, + ThreePID: threePid.Address, + }, &res) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryLocalpartForThreePID failed") + return jsonerror.InternalServerError() + } + if res.Localpart == "" { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.MatrixError{ + ErrCode: "M_THREEPID_NOT_FOUND", + Err: "3pid is not bound to any account", + }, + } + } + localpart = res.Localpart + sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeEmail) + default: + flows := []authtypes.Flow{ + { + Stages: []authtypes.LoginType{authtypes.LoginTypePassword}, + }, + } + if cfg.ThreePidDelegate != "" { + flows = append(flows, authtypes.Flow{ + Stages: []authtypes.LoginType{authtypes.LoginTypeEmail}, + }) + } + // Require password auth to change the password. + if r.Auth.Type == authtypes.LoginTypePassword { + return util.JSONResponse{ + Code: http.StatusUnauthorized, + JSON: newUserInteractiveResponse( + sessionID, + flows, + nil, + ), + } } } - // Check if the existing password is correct. - typePassword := auth.LoginTypePassword{ - GetAccountByPassword: userAPI.QueryAccountByPassword, - Config: cfg, - } - if _, authErr := typePassword.Login(req.Context(), &r.Auth.PasswordRequest); authErr != nil { - return *authErr - } - sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword) - // Check the new password strength. if resErr = validatePassword(r.NewPassword); resErr != nil { return *resErr } - // Get the local part. - localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) - if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") - return jsonerror.InternalServerError() - } - // Ask the user API to perform the password change. passwordReq := &api.PerformPasswordUpdateRequest{ Localpart: localpart, @@ -109,11 +166,23 @@ func Password( // If the request asks us to log out all other devices then // ask the user API to do that. + if r.LogoutDevices { - logoutReq := &api.PerformDeviceDeletionRequest{ - UserID: device.UserID, - DeviceIDs: nil, - ExceptDeviceID: device.ID, + var logoutReq *api.PerformDeviceDeletionRequest + var sessionId int64 + if device == nil { + logoutReq = &api.PerformDeviceDeletionRequest{ + UserID: fmt.Sprintf("@%s:%s", localpart, cfg.Matrix.ServerName), + DeviceIDs: []string{}, + } + sessionId = 0 + } else { + logoutReq = &api.PerformDeviceDeletionRequest{ + UserID: device.UserID, + DeviceIDs: nil, + ExceptDeviceID: device.ID, + } + sessionId = device.SessionID } logoutRes := &api.PerformDeviceDeletionResponse{} if err := userAPI.PerformDeviceDeletion(req.Context(), logoutReq, logoutRes); err != nil { @@ -123,7 +192,7 @@ func Password( pushersReq := &api.PerformPusherDeletionRequest{ Localpart: localpart, - SessionID: device.SessionID, + SessionID: sessionId, } if err := userAPI.PerformPusherDeletion(req.Context(), pushersReq, &struct{}{}); err != nil { util.GetLogger(req.Context()).WithError(err).Error("PerformPusherDeletion failed") diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index c4ac0f2e7..ad1fa1e8a 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -44,6 +44,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/clientapi/threepid" "github.com/matrix-org/dendrite/clientapi/userutil" userapi "github.com/matrix-org/dendrite/userapi/api" ) @@ -237,6 +238,7 @@ type authDict struct { // Recaptcha Response string `json:"response"` // TODO: Lots of custom keys depending on the type + ThreePidCreds threepid.Credentials `json:"threepid_creds"` } // http://matrix.org/speculator/spec/HEAD/client_server/unstable.html#user-interactive-authentication-api @@ -745,6 +747,7 @@ func handleRegistrationFlow( } } + var threePid *authtypes.ThreePID switch r.Auth.Type { case authtypes.LoginTypeRecaptcha: // Check given captcha response @@ -761,6 +764,29 @@ func handleRegistrationFlow( // Add Dummy to the list of completed registration stages sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeDummy) + case authtypes.LoginTypeEmail: + threePid = &authtypes.ThreePID{} + r.Auth.ThreePidCreds.IDServer = cfg.ThreePidDelegate + var ( + bound bool + err error + ) + bound, threePid.Address, threePid.Medium, err = threepid.CheckAssociation(req.Context(), r.Auth.ThreePidCreds, cfg) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("threepid.CheckAssociation failed") + return jsonerror.InternalServerError() + } + if !bound { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.MatrixError{ + ErrCode: "M_THREEPID_AUTH_FAILED", + Err: "Failed to auth 3pid", + }, + } + } + sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeEmail) + case "": // An empty auth type means that we want to fetch the available // flows. It can also mean that we want to register as an appservice @@ -776,7 +802,7 @@ func handleRegistrationFlow( // A response with current registration flow and remaining available methods // will be returned if a flow has not been successfully completed yet return checkAndCompleteFlow(sessions.getCompletedStages(sessionID), - req, r, sessionID, cfg, userAPI) + req, r, sessionID, cfg, userAPI, threePid) } // handleApplicationServiceRegistration handles the registration of an @@ -818,7 +844,7 @@ func handleApplicationServiceRegistration( // application service registration is entirely separate. return completeRegistration( req.Context(), userAPI, r.Username, "", appserviceID, req.RemoteAddr, req.UserAgent(), r.Auth.Session, - r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeAppService, + r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeAppService, nil, ) } @@ -832,12 +858,13 @@ func checkAndCompleteFlow( sessionID string, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI, + threePid *authtypes.ThreePID, ) util.JSONResponse { if checkFlowCompleted(flow, cfg.Derived.Registration.Flows) { // This flow was completed, registration can continue return completeRegistration( req.Context(), userAPI, r.Username, r.Password, "", req.RemoteAddr, req.UserAgent(), sessionID, - r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeUser, + r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeUser, threePid, ) } sessions.addParams(sessionID, r) @@ -863,6 +890,7 @@ func completeRegistration( inhibitLogin eventutil.WeakBoolean, displayName, deviceID *string, accType userapi.AccountType, + threePid *authtypes.ThreePID, ) util.JSONResponse { if username == "" { return util.JSONResponse{ @@ -901,6 +929,21 @@ func completeRegistration( // Increment prometheus counter for created users amtRegUsers.Inc() + // TODO-entry refuse register if threepid is already bound to account. + if threePid != nil { + err = userAPI.PerformSaveThreePIDAssociation(ctx, &userapi.PerformSaveThreePIDAssociationRequest{ + Medium: threePid.Medium, + ThreePID: threePid.Address, + Localpart: accRes.Account.Localpart, + }, &struct{}{}) + if err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: jsonerror.Unknown("Failed to save 3PID association: " + err.Error()), + } + } + } + // Check whether inhibit_login option is set. If so, don't create an access // token or a device for this user if inhibitLogin { @@ -1092,5 +1135,5 @@ func handleSharedSecretRegistration(userAPI userapi.ClientUserAPI, sr *SharedSec if ssrr.Admin { accType = userapi.AccountTypeAdmin } - return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), "", false, &ssrr.User, &deviceID, accType) + return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), "", false, &ssrr.User, &deviceID, accType, nil) } diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index aa4b5a235..72e58b4d0 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -26,6 +26,7 @@ import ( clientutil "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/producers" + "github.com/matrix-org/dendrite/clientapi/ratelimit" federationAPI "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/transactions" @@ -65,6 +66,7 @@ func Setup( prometheus.MustRegister(amtRegUsers, sendEventDuration) rateLimits := httputil.NewRateLimits(&cfg.RateLimiting) + rateLimitsFailedLogin := ratelimit.NewRtFailedLogin(&cfg.RtFailedLogin) userInteractiveAuth := auth.NewUserInteractive(userAPI, cfg) unstableFeatures := map[string]bool{ @@ -538,7 +540,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/account/password", - httputil.MakeAuthAPI("password", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeConditionalAuthAPI("password", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req, device); r != nil { return *r } @@ -562,7 +564,7 @@ func Setup( if r := rateLimits.Limit(req, nil); r != nil { return *r } - return Login(req, userAPI, cfg) + return Login(req, userAPI, cfg, rateLimitsFailedLogin) }), ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) diff --git a/clientapi/threepid/threepid.go b/clientapi/threepid/threepid.go index 1e64e3034..a6a469670 100644 --- a/clientapi/threepid/threepid.go +++ b/clientapi/threepid/threepid.go @@ -103,11 +103,8 @@ func CreateSession( func CheckAssociation( ctx context.Context, creds Credentials, cfg *config.ClientAPI, ) (bool, string, string, error) { - if err := isTrusted(creds.IDServer, cfg); err != nil { - return false, "", "", err - } - requestURL := fmt.Sprintf("https://%s/_matrix/identity/api/v1/3pid/getValidated3pid?sid=%s&client_secret=%s", creds.IDServer, creds.SID, creds.Secret) + requestURL := fmt.Sprintf("%s/_matrix/identity/api/v1/3pid/getValidated3pid?sid=%s&client_secret=%s", cfg.ThreePidDelegate, creds.SID, creds.Secret) req, err := http.NewRequest(http.MethodGet, requestURL, nil) if err != nil { return false, "", "", err diff --git a/cmd/dendrite-monolith-server/Dockerfile.dev b/cmd/dendrite-monolith-server/Dockerfile.dev new file mode 100644 index 000000000..7fbf6c667 --- /dev/null +++ b/cmd/dendrite-monolith-server/Dockerfile.dev @@ -0,0 +1,8 @@ +FROM alpine:latest + +COPY dendrite-monolith-server /usr/bin/ + +VOLUME /etc/dendrite +WORKDIR /etc/dendrite + +ENTRYPOINT ["/usr/bin/dendrite-monolith-server"] diff --git a/cmd/dendrite-monolith-server/build_dev.sh b/cmd/dendrite-monolith-server/build_dev.sh new file mode 100755 index 000000000..5d121890a --- /dev/null +++ b/cmd/dendrite-monolith-server/build_dev.sh @@ -0,0 +1,12 @@ +set -xe +if [ -z "$(git status --porcelain)" ]; then + CGO_ENABLED=0 go build . + TAG=$(git rev-parse --short HEAD) + docker build -f Dockerfile.dev -t gcr.io/globekeeper-development/dendrite-monolith:$TAG -t gcr.io/globekeeper-development/dendrite-monolith -t gcr.io/globekeeper-production/dendrite-monolith:$TAG . + docker push gcr.io/globekeeper-development/dendrite-monolith:$TAG + docker push gcr.io/globekeeper-production/dendrite-monolith:$TAG + docker push gcr.io/globekeeper-development/dendrite-monolith +else + echo "Please commit changes" + exit 0 +fi \ No newline at end of file diff --git a/go.mod b/go.mod index ea6e8caeb..90aeff601 100644 --- a/go.mod +++ b/go.mod @@ -21,6 +21,7 @@ require ( github.com/frankban/quicktest v1.14.3 // indirect github.com/getsentry/sentry-go v0.13.0 github.com/gogo/protobuf v1.3.2 // indirect + github.com/golang-jwt/jwt/v4 v4.4.1 github.com/gologme/log v1.3.0 github.com/google/go-cmp v0.5.8 github.com/google/uuid v1.3.0 @@ -37,6 +38,7 @@ require ( github.com/matrix-org/gomatrixserverlib v0.0.0-20220607143425-e55d796fd0b3 github.com/matrix-org/pinecone v0.0.0-20220408153826-2999ea29ed48 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 + github.com/matryer/is v1.4.0 github.com/mattn/go-sqlite3 v1.14.13 github.com/matttproud/golang_protobuf_extensions v1.0.2-0.20181231171920-c182affec369 // indirect github.com/miekg/dns v1.1.49 // indirect diff --git a/go.sum b/go.sum index e21794f41..1b127f248 100644 --- a/go.sum +++ b/go.sum @@ -202,7 +202,10 @@ github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/E github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= +github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= +github.com/golang-jwt/jwt/v4 v4.4.1 h1:pC5DB52sCeK48Wlb9oPcdhnjkz1TKt1D/P7WKJ0kUcQ= +github.com/golang-jwt/jwt/v4 v4.4.1/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= @@ -425,6 +428,8 @@ github.com/matrix-org/pinecone v0.0.0-20220408153826-2999ea29ed48/go.mod h1:ulJz github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= +github.com/matryer/is v1.4.0 h1:sosSmIWwkYITGrxZ25ULNDeKiMNzFSr4V/eqBQP0PeE= +github.com/matryer/is v1.4.0/go.mod h1:8I/i5uYgLzgsgEloJE1U6xx5HkBQpAZvepWuujKwMRU= github.com/mattn/go-colorable v0.0.6/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= github.com/mattn/go-colorable v0.1.8/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= diff --git a/internal/httputil/httpapi.go b/internal/httputil/httpapi.go index aba50ae4d..162c11bdc 100644 --- a/internal/httputil/httpapi.go +++ b/internal/httputil/httpapi.go @@ -83,6 +83,57 @@ func MakeAuthAPI( return MakeExternalAPI(metricsName, h) } +// MakeConditionalAuthAPI turns a util.JSONRequestHandler function into an http.Handler which authenticates the request. +// It passes nil device if header is not provided. +func MakeConditionalAuthAPI( + metricsName string, userAPI userapi.QueryAcccessTokenAPI, + f func(*http.Request, *userapi.Device) util.JSONResponse, +) http.Handler { + h := func(req *http.Request) util.JSONResponse { + var ( + jsonRes util.JSONResponse + dev *userapi.Device + ) + if _, err := auth.ExtractAccessToken(req); err != nil { + dev = nil + } else { + logger := util.GetLogger(req.Context()) + var err *util.JSONResponse + dev, err = auth.VerifyUserFromRequest(req, userAPI) + if err != nil { + logger.Debugf("VerifyUserFromRequest %s -> HTTP %d", req.RemoteAddr, err.Code) + return *err + } + // add the user ID to the logger + logger = logger.WithField("user_id", dev.UserID) + req = req.WithContext(util.ContextWithLogger(req.Context(), logger)) + } + // add the user to Sentry, if enabled + hub := sentry.GetHubFromContext(req.Context()) + if hub != nil { + hub.Scope().SetTag("user_id", dev.UserID) + hub.Scope().SetTag("device_id", dev.ID) + } + defer func() { + if r := recover(); r != nil { + if hub != nil { + hub.CaptureException(fmt.Errorf("%s panicked", req.URL.Path)) + } + // re-panic to return the 500 + panic(r) + } + }() + jsonRes = f(req, dev) + // do not log 4xx as errors as they are client fails, not server fails + if hub != nil && jsonRes.Code >= 500 { + hub.Scope().SetExtra("response", jsonRes) + hub.CaptureException(fmt.Errorf("%s returned HTTP %d", req.URL.Path, jsonRes.Code)) + } + return jsonRes + } + return MakeExternalAPI(metricsName, h) +} + // MakeExternalAPI turns a util.JSONRequestHandler function into an http.Handler. // This is used for APIs that are called from the internet. func MakeExternalAPI(metricsName string, f func(*http.Request) util.JSONResponse) http.Handler { diff --git a/setup/config/config.go b/setup/config/config.go index 9b9000a62..68f4b70f6 100644 --- a/setup/config/config.go +++ b/setup/config/config.go @@ -16,6 +16,7 @@ package config import ( "bytes" + "crypto/x509" "encoding/pem" "fmt" "io" @@ -252,6 +253,15 @@ func loadConfig( c.Global.OldVerifyKeys[i].KeyID, c.Global.OldVerifyKeys[i].PrivateKey = keyID, privateKey } + if c.ClientAPI.JwtConfig.Enabled { + pubPki, _ := pem.Decode([]byte(c.ClientAPI.JwtConfig.Secret)) + var pub interface{} + pub, err = x509.ParsePKIXPublicKey(pubPki.Bytes) + if err != nil { + return nil, err + } + c.ClientAPI.JwtConfig.SecretKey = pub.(ed25519.PublicKey) + } c.MediaAPI.AbsBasePath = Path(absPath(basePath, c.MediaAPI.BasePath)) @@ -283,7 +293,10 @@ func (config *Dendrite) Derive() error { config.Derived.Registration.Flows = append(config.Derived.Registration.Flows, authtypes.Flow{Stages: []authtypes.LoginType{authtypes.LoginTypeDummy}}) } - + if config.ClientAPI.ThreePidDelegate != "" { + config.Derived.Registration.Flows = append(config.Derived.Registration.Flows, + authtypes.Flow{Stages: []authtypes.LoginType{authtypes.LoginTypeEmail}}) + } // Load application service configuration files if err := loadAppServices(&config.AppServiceAPI, &config.Derived); err != nil { return err diff --git a/setup/config/config_clientapi.go b/setup/config/config_clientapi.go index ecf8f6bd5..b0a9c12ed 100644 --- a/setup/config/config_clientapi.go +++ b/setup/config/config_clientapi.go @@ -3,6 +3,9 @@ package config import ( "fmt" "time" + + "github.com/matrix-org/dendrite/clientapi/ratelimit" + "golang.org/x/crypto/ed25519" ) type ClientAPI struct { @@ -46,9 +49,23 @@ type ClientAPI struct { TURN TURN `yaml:"turn"` // Rate-limiting options - RateLimiting RateLimiting `yaml:"rate_limiting"` + RateLimiting RateLimiting `yaml:"rate_limiting"` + RtFailedLogin ratelimit.RtFailedLoginConfig `yaml:"rate_limiting_failed_login"` MSCs *MSCs `yaml:"mscs"` + + ThreePidDelegate string `yaml:"three_pid_delegate"` + + JwtConfig JwtConfig `yaml:"jwt_config"` +} + +type JwtConfig struct { + Enabled bool `yaml:"enabled"` + Algorithm string `yaml:"algorithm"` + Issuer string `yaml:"issuer"` + Secret string `yaml:"secret"` + SecretKey ed25519.PublicKey + Audiences []string `yaml:"audiences"` } func (c *ClientAPI) Defaults(generate bool) { diff --git a/userapi/api/api.go b/userapi/api/api.go index df9408acb..c4bc976c8 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -517,7 +517,7 @@ type PerformPusherSetRequest struct { type PerformPusherDeletionRequest struct { Localpart string - SessionID int64 + SessionID int64 // Pusher corresponding to this SessionID will not be deleted } // Pusher represents a push notification subscriber