Entry improvements (#11)

* Refactor ApplicationServiceWorkerState to be more robust

* Add launch.json to VS Code

* Implement login with JWT, registering with email, failed login rate limiting and reset password with m.login.email.identity auth type

* Log errors when JWT parsing failed

* Development build script

* Fix linter errors

* Use golangci-lint as a linter in VS Code

* Fix tests with RtFailedLogin

* Pass config load tests - parse JWT public key only if enabled

* Reduce CI steps

Do not support 386 arch and go 1.16, 1.17

* Fix linter errors

* Change RtFailedLogin logic - nil pointer can be provided

* Respect access token in query

* Fix typos

* Use only one mutex in RtFailedLogin

* Remove eventsRemaining across appservice component

* Push dendrite to production registry as well

* Rafactor TestRtFailedLogin
This commit is contained in:
PiotrKozimor 2022-06-30 14:56:45 +02:00 committed by GitHub
parent 83797573be
commit 374b77a3df
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
37 changed files with 711 additions and 260 deletions

View file

@ -13,50 +13,6 @@ concurrency:
cancel-in-progress: true cancel-in-progress: true
jobs: jobs:
wasm:
name: WASM build test
timeout-minutes: 5
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Install Go
uses: actions/setup-go@v2
with:
go-version: 1.16
- uses: actions/cache@v2
with:
path: |
~/.cache/go-build
~/go/pkg/mod
key: ${{ runner.os }}-go-wasm-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go-wasm
- name: Install Node
uses: actions/setup-node@v2
with:
node-version: 14
- uses: actions/cache@v2
with:
path: ~/.npm
key: ${{ runner.os }}-node-${{ hashFiles('**/package-lock.json') }}
restore-keys: |
${{ runner.os }}-node-
- name: Reconfigure Git to use HTTPS auth for repo packages
run: >
git config --global url."https://github.com/".insteadOf
ssh://git@github.com/
- name: Install test dependencies
working-directory: ./test/wasm
run: npm ci
- name: Test
run: ./test-dendritejs.sh
# Run golangci-lint # Run golangci-lint
lint: lint:
@ -68,7 +24,7 @@ jobs:
- name: golangci-lint - name: golangci-lint
uses: golangci/golangci-lint-action@v2 uses: golangci/golangci-lint-action@v2
# run go test with different go versions # run go test with go 1.18
test: test:
timeout-minutes: 5 timeout-minutes: 5
name: Unit tests (Go ${{ matrix.go }}) name: Unit tests (Go ${{ matrix.go }})
@ -96,7 +52,7 @@ jobs:
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
go: ["1.16", "1.17", "1.18"] go: ["1.18"]
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v3
- name: Setup go - name: Setup go
@ -118,7 +74,7 @@ jobs:
POSTGRES_PASSWORD: postgres POSTGRES_PASSWORD: postgres
POSTGRES_DB: dendrite POSTGRES_DB: dendrite
# build Dendrite for linux with different architectures and go versions # build Dendrite for linux amd64 with go 1.18
build: build:
name: Build for Linux name: Build for Linux
timeout-minutes: 10 timeout-minutes: 10
@ -126,9 +82,9 @@ jobs:
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
go: ["1.16", "1.17", "1.18"] go: ["1.18"]
goos: ["linux"] goos: ["linux"]
goarch: ["amd64", "386"] goarch: ["amd64"]
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v3
- name: Setup go - name: Setup go
@ -152,43 +108,10 @@ jobs:
CGO_ENABLED: 1 CGO_ENABLED: 1
run: go build -trimpath -v -o "bin/" ./cmd/... run: go build -trimpath -v -o "bin/" ./cmd/...
# build for Windows 64-bit
build_windows:
name: Build for Windows
timeout-minutes: 10
runs-on: ubuntu-latest
strategy:
matrix:
go: ["1.16", "1.17", "1.18"]
goos: ["windows"]
goarch: ["amd64"]
steps:
- uses: actions/checkout@v3
- name: Setup Go ${{ matrix.go }}
uses: actions/setup-go@v2
with:
go-version: ${{ matrix.go }}
- name: Install dependencies
run: sudo apt update && sudo apt install -y gcc-mingw-w64-x86-64 # install required gcc
- uses: actions/cache@v3
with:
path: |
~/.cache/go-build
~/go/pkg/mod
key: ${{ runner.os }}-go${{ matrix.go }}-${{ matrix.goos }}-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go${{ matrix.go }}-${{ matrix.goos }}
- env:
GOOS: ${{ matrix.goos }}
GOARCH: ${{ matrix.goarch }}
CGO_ENABLED: 1
CC: "/usr/bin/x86_64-w64-mingw32-gcc"
run: go build -trimpath -v -o "bin/" ./cmd/...
# Dummy step to gate other tests on without repeating the whole list # Dummy step to gate other tests on without repeating the whole list
initial-tests-done: initial-tests-done:
name: Initial tests passed name: Initial tests passed
needs: [lint, test, build, build_windows] needs: [lint, test, build]
runs-on: ubuntu-latest runs-on: ubuntu-latest
if: ${{ !cancelled() }} # Run this even if prior jobs were skipped if: ${{ !cancelled() }} # Run this even if prior jobs were skipped
steps: steps:

5
.gitignore vendored
View file

@ -2,6 +2,7 @@
# Hidden files # Hidden files
.* .*
!.vscode
# Allow GitHub config # Allow GitHub config
!.github !.github
@ -73,3 +74,7 @@ complement/
docs/_site docs/_site
media_store/ media_store/
__debug_bin
cmd/dendrite-monolith-server/dendrite-monolith-server

16
.vscode/launch.json vendored Normal file
View file

@ -0,0 +1,16 @@
{
"configurations": [
{
"name": "Launch Package",
"type": "go",
"request": "launch",
"mode": "auto",
"program": "${workspaceFolder}/cmd/dendrite-monolith-server",
"args": [
"-really-enable-open-registration",
"-config",
"../../../adminas/.ci/config/dendrite-local/dendrite.yaml"
],
}
]
}

3
.vscode/settings.json vendored Normal file
View file

@ -0,0 +1,3 @@
{
"go.lintTool": "golangci-lint"
}

View file

@ -70,14 +70,14 @@ func NewInternalAPI(
// Wrap application services in a type that relates the application service and // Wrap application services in a type that relates the application service and
// a sync.Cond object that can be used to notify workers when there are new // a sync.Cond object that can be used to notify workers when there are new
// events to be sent out. // events to be sent out.
workerStates := make([]types.ApplicationServiceWorkerState, len(base.Cfg.Derived.ApplicationServices)) workerStates := make([]*types.ApplicationServiceWorkerState, len(base.Cfg.Derived.ApplicationServices))
for i, appservice := range base.Cfg.Derived.ApplicationServices { for i, appservice := range base.Cfg.Derived.ApplicationServices {
m := sync.Mutex{} m := sync.Mutex{}
ws := types.ApplicationServiceWorkerState{ ws := types.ApplicationServiceWorkerState{
AppService: appservice, AppService: appservice,
Cond: sync.NewCond(&m), Cond: sync.NewCond(&m),
} }
workerStates[i] = ws workerStates[i] = &ws
// Create bot account for this AS if it doesn't already exist // Create bot account for this AS if it doesn't already exist
if err = generateAppServiceAccount(userAPI, appservice); err != nil { if err = generateAppServiceAccount(userAPI, appservice); err != nil {

View file

@ -39,7 +39,7 @@ type OutputRoomEventConsumer struct {
asDB storage.Database asDB storage.Database
rsAPI api.AppserviceRoomserverAPI rsAPI api.AppserviceRoomserverAPI
serverName string serverName string
workerStates []types.ApplicationServiceWorkerState workerStates []*types.ApplicationServiceWorkerState
} }
// NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call // NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call
@ -50,7 +50,7 @@ func NewOutputRoomEventConsumer(
js nats.JetStreamContext, js nats.JetStreamContext,
appserviceDB storage.Database, appserviceDB storage.Database,
rsAPI api.AppserviceRoomserverAPI, rsAPI api.AppserviceRoomserverAPI,
workerStates []types.ApplicationServiceWorkerState, workerStates []*types.ApplicationServiceWorkerState,
) *OutputRoomEventConsumer { ) *OutputRoomEventConsumer {
return &OutputRoomEventConsumer{ return &OutputRoomEventConsumer{
ctx: process.Context(), ctx: process.Context(),
@ -140,13 +140,13 @@ func (s *OutputRoomEventConsumer) filterRoomserverEvents(
// Check if this event is interesting to this application service // Check if this event is interesting to this application service
if s.appserviceIsInterestedInEvent(ctx, event, ws.AppService) { if s.appserviceIsInterestedInEvent(ctx, event, ws.AppService) {
// Queue this event to be sent off to the application service // Queue this event to be sent off to the application service
if err := s.asDB.StoreEvent(ctx, ws.AppService.ID, event); err != nil { if id, err := s.asDB.StoreEvent(ctx, ws.AppService.ID, event); err != nil {
log.WithError(err).Warn("failed to insert incoming event into appservices database") log.WithError(err).Warnf("failed to insert incoming event into appservices database. id: %d", id)
return err return err
} else { } else {
// Tell our worker to send out new messages by updating remaining message // Tell our worker to send out new messages by updating remaining message
// count and waking them up with a broadcast // count and waking them up with a broadcast
ws.NotifyNewEvents() ws.NotifyNewEvents(id)
} }
} }
} }

View file

@ -21,9 +21,9 @@ import (
) )
type Database interface { type Database interface {
StoreEvent(ctx context.Context, appServiceID string, event *gomatrixserverlib.HeaderedEvent) error StoreEvent(ctx context.Context, appServiceID string, event *gomatrixserverlib.HeaderedEvent) (int, error)
GetEventsWithAppServiceID(ctx context.Context, appServiceID string, limit int) (int, int, []gomatrixserverlib.HeaderedEvent, bool, error) GetEventsWithAppServiceID(ctx context.Context, appServiceID string, limit int) (int, int, []gomatrixserverlib.HeaderedEvent, error)
CountEventsWithAppServiceID(ctx context.Context, appServiceID string) (int, error) GetLatestId(ctx context.Context, appServiceID string) (int, error)
UpdateTxnIDForEvents(ctx context.Context, appserviceID string, maxID, txnID int) error UpdateTxnIDForEvents(ctx context.Context, appserviceID string, maxID, txnID int) error
RemoveEventsBeforeAndIncludingID(ctx context.Context, appserviceID string, eventTableID int) error RemoveEventsBeforeAndIncludingID(ctx context.Context, appserviceID string, eventTableID int) error
GetLatestTxnID(ctx context.Context) (int, error) GetLatestTxnID(ctx context.Context) (int, error)

View file

@ -45,12 +45,13 @@ const selectEventsByApplicationServiceIDSQL = "" +
"SELECT id, headered_event_json, txn_id " + "SELECT id, headered_event_json, txn_id " +
"FROM appservice_events WHERE as_id = $1 ORDER BY txn_id DESC, id ASC" "FROM appservice_events WHERE as_id = $1 ORDER BY txn_id DESC, id ASC"
const countEventsByApplicationServiceIDSQL = "" + const getLatestIdSQL = "" +
"SELECT COUNT(id) FROM appservice_events WHERE as_id = $1" "SELECT id FROM appservice_events WHERE as_id = $1 ORDER BY id DESC LIMIT 1"
const insertEventSQL = "" + const insertEventSQL = "" +
"INSERT INTO appservice_events(as_id, headered_event_json, txn_id) " + "INSERT INTO appservice_events(as_id, headered_event_json, txn_id) " +
"VALUES ($1, $2, $3)" "VALUES ($1, $2, $3)" +
"RETURNING id"
const updateTxnIDForEventsSQL = "" + const updateTxnIDForEventsSQL = "" +
"UPDATE appservice_events SET txn_id = $1 WHERE as_id = $2 AND id <= $3" "UPDATE appservice_events SET txn_id = $1 WHERE as_id = $2 AND id <= $3"
@ -66,7 +67,7 @@ const (
type eventsStatements struct { type eventsStatements struct {
selectEventsByApplicationServiceIDStmt *sql.Stmt selectEventsByApplicationServiceIDStmt *sql.Stmt
countEventsByApplicationServiceIDStmt *sql.Stmt getLatestIdStmt *sql.Stmt
insertEventStmt *sql.Stmt insertEventStmt *sql.Stmt
updateTxnIDForEventsStmt *sql.Stmt updateTxnIDForEventsStmt *sql.Stmt
deleteEventsBeforeAndIncludingIDStmt *sql.Stmt deleteEventsBeforeAndIncludingIDStmt *sql.Stmt
@ -81,7 +82,7 @@ func (s *eventsStatements) prepare(db *sql.DB) (err error) {
if s.selectEventsByApplicationServiceIDStmt, err = db.Prepare(selectEventsByApplicationServiceIDSQL); err != nil { if s.selectEventsByApplicationServiceIDStmt, err = db.Prepare(selectEventsByApplicationServiceIDSQL); err != nil {
return return
} }
if s.countEventsByApplicationServiceIDStmt, err = db.Prepare(countEventsByApplicationServiceIDSQL); err != nil { if s.getLatestIdStmt, err = db.Prepare(getLatestIdSQL); err != nil {
return return
} }
if s.insertEventStmt, err = db.Prepare(insertEventSQL); err != nil { if s.insertEventStmt, err = db.Prepare(insertEventSQL); err != nil {
@ -108,7 +109,6 @@ func (s *eventsStatements) selectEventsByApplicationServiceID(
) ( ) (
txnID, maxID int, txnID, maxID int,
events []gomatrixserverlib.HeaderedEvent, events []gomatrixserverlib.HeaderedEvent,
eventsRemaining bool,
err error, err error,
) { ) {
defer func() { defer func() {
@ -124,7 +124,7 @@ func (s *eventsStatements) selectEventsByApplicationServiceID(
return return
} }
defer checkNamedErr(eventRows.Close, &err) defer checkNamedErr(eventRows.Close, &err)
events, maxID, txnID, eventsRemaining, err = retrieveEvents(eventRows, limit) events, maxID, txnID, err = retrieveEvents(eventRows, limit)
if err != nil { if err != nil {
return return
} }
@ -139,7 +139,7 @@ func checkNamedErr(fn func() error, err *error) {
} }
} }
func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.HeaderedEvent, maxID, txnID int, eventsRemaining bool, err error) { func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.HeaderedEvent, maxID, txnID int, err error) {
// Get current time for use in calculating event age // Get current time for use in calculating event age
nowMilli := time.Now().UnixNano() / int64(time.Millisecond) nowMilli := time.Now().UnixNano() / int64(time.Millisecond)
@ -157,18 +157,18 @@ func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.
&txnID, &txnID,
) )
if err != nil { if err != nil {
return nil, 0, 0, false, err return nil, 0, 0, err
} }
// Unmarshal eventJSON // Unmarshal eventJSON
if err = json.Unmarshal(eventJSON, &event); err != nil { if err = json.Unmarshal(eventJSON, &event); err != nil {
return nil, 0, 0, false, err return nil, 0, 0, err
} }
// If txnID has changed on this event from the previous event, then we've // If txnID has changed on this event from the previous event, then we've
// reached the end of a transaction's events. Return only those events. // reached the end of a transaction's events. Return only those events.
if lastTxnID > invalidTxnID && lastTxnID != txnID { if lastTxnID > invalidTxnID && lastTxnID != txnID {
return events, maxID, lastTxnID, true, nil return events, maxID, lastTxnID, nil
} }
lastTxnID = txnID lastTxnID = txnID
@ -176,7 +176,7 @@ func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.
if txnID == -1 { if txnID == -1 {
// Return if we've hit the limit // Return if we've hit the limit
if eventsProcessed++; eventsProcessed > limit { if eventsProcessed++; eventsProcessed > limit {
return events, maxID, lastTxnID, true, nil return events, maxID, lastTxnID, nil
} }
} }
@ -187,7 +187,7 @@ func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.
// Portion of the event that is unsigned due to rapid change // Portion of the event that is unsigned due to rapid change
// TODO: Consider removing age as not many app services use it // TODO: Consider removing age as not many app services use it
if err = event.SetUnsignedField("age", nowMilli-int64(event.OriginServerTS())); err != nil { if err = event.SetUnsignedField("age", nowMilli-int64(event.OriginServerTS())); err != nil {
return nil, 0, 0, false, err return nil, 0, 0, err
} }
events = append(events, event) events = append(events, event)
@ -196,14 +196,12 @@ func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.
return return
} }
// countEventsByApplicationServiceID inserts an event mapped to its corresponding application service func (s *eventsStatements) getLatestId(
// IDs into the db.
func (s *eventsStatements) countEventsByApplicationServiceID(
ctx context.Context, ctx context.Context,
appServiceID string, appServiceID string,
) (int, error) { ) (int, error) {
var count int var count int
err := s.countEventsByApplicationServiceIDStmt.QueryRowContext(ctx, appServiceID).Scan(&count) err := s.getLatestIdStmt.QueryRowContext(ctx, appServiceID).Scan(&count)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
return 0, err return 0, err
} }
@ -217,19 +215,19 @@ func (s *eventsStatements) insertEvent(
ctx context.Context, ctx context.Context,
appServiceID string, appServiceID string,
event *gomatrixserverlib.HeaderedEvent, event *gomatrixserverlib.HeaderedEvent,
) (err error) { ) (id int, err error) {
// Convert event to JSON before inserting // Convert event to JSON before inserting
eventJSON, err := json.Marshal(event) var eventJSON []byte
eventJSON, err = json.Marshal(event)
if err != nil { if err != nil {
return err return 0, err
} }
err = s.insertEventStmt.QueryRowContext(
_, err = s.insertEventStmt.ExecContext(
ctx, ctx,
appServiceID, appServiceID,
eventJSON, eventJSON,
-1, // No transaction ID yet -1, // No transaction ID yet
) ).Scan(&id)
return return
} }

View file

@ -62,7 +62,7 @@ func (d *Database) StoreEvent(
ctx context.Context, ctx context.Context,
appServiceID string, appServiceID string,
event *gomatrixserverlib.HeaderedEvent, event *gomatrixserverlib.HeaderedEvent,
) error { ) (int, error) {
return d.events.insertEvent(ctx, appServiceID, event) return d.events.insertEvent(ctx, appServiceID, event)
} }
@ -72,17 +72,20 @@ func (d *Database) GetEventsWithAppServiceID(
ctx context.Context, ctx context.Context,
appServiceID string, appServiceID string,
limit int, limit int,
) (int, int, []gomatrixserverlib.HeaderedEvent, bool, error) { ) (int, int, []gomatrixserverlib.HeaderedEvent, error) {
return d.events.selectEventsByApplicationServiceID(ctx, appServiceID, limit) return d.events.selectEventsByApplicationServiceID(ctx, appServiceID, limit)
} }
// CountEventsWithAppServiceID returns the number of events destined for an // GetLatestId returns the latest incremental id associated with appservice.
// application service given its ID. func (d *Database) GetLatestId(
func (d *Database) CountEventsWithAppServiceID(
ctx context.Context, ctx context.Context,
appServiceID string, appServiceID string,
) (int, error) { ) (int, error) {
return d.events.countEventsByApplicationServiceID(ctx, appServiceID) id, err := d.events.getLatestId(ctx, appServiceID)
if err == sql.ErrNoRows {
return 0, nil
}
return id, err
} }
// UpdateTxnIDForEvents takes in an application service ID and a // UpdateTxnIDForEvents takes in an application service ID and a

View file

@ -46,12 +46,13 @@ const selectEventsByApplicationServiceIDSQL = "" +
"SELECT id, headered_event_json, txn_id " + "SELECT id, headered_event_json, txn_id " +
"FROM appservice_events WHERE as_id = $1 ORDER BY txn_id DESC, id ASC" "FROM appservice_events WHERE as_id = $1 ORDER BY txn_id DESC, id ASC"
const countEventsByApplicationServiceIDSQL = "" + const getLatestIdSQL = "" +
"SELECT COUNT(id) FROM appservice_events WHERE as_id = $1" "SELECT id FROM appservice_events WHERE as_id = $1 ORDER BY id DESC LIMIT 1"
const insertEventSQL = "" + const insertEventSQL = "" +
"INSERT INTO appservice_events(as_id, headered_event_json, txn_id) " + "INSERT INTO appservice_events(as_id, headered_event_json, txn_id) " +
"VALUES ($1, $2, $3)" "VALUES ($1, $2, $3)" +
"RETURNING id"
const updateTxnIDForEventsSQL = "" + const updateTxnIDForEventsSQL = "" +
"UPDATE appservice_events SET txn_id = $1 WHERE as_id = $2 AND id <= $3" "UPDATE appservice_events SET txn_id = $1 WHERE as_id = $2 AND id <= $3"
@ -69,7 +70,7 @@ type eventsStatements struct {
db *sql.DB db *sql.DB
writer sqlutil.Writer writer sqlutil.Writer
selectEventsByApplicationServiceIDStmt *sql.Stmt selectEventsByApplicationServiceIDStmt *sql.Stmt
countEventsByApplicationServiceIDStmt *sql.Stmt getLatestIdStmt *sql.Stmt
insertEventStmt *sql.Stmt insertEventStmt *sql.Stmt
updateTxnIDForEventsStmt *sql.Stmt updateTxnIDForEventsStmt *sql.Stmt
deleteEventsBeforeAndIncludingIDStmt *sql.Stmt deleteEventsBeforeAndIncludingIDStmt *sql.Stmt
@ -86,7 +87,7 @@ func (s *eventsStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error
if s.selectEventsByApplicationServiceIDStmt, err = db.Prepare(selectEventsByApplicationServiceIDSQL); err != nil { if s.selectEventsByApplicationServiceIDStmt, err = db.Prepare(selectEventsByApplicationServiceIDSQL); err != nil {
return return
} }
if s.countEventsByApplicationServiceIDStmt, err = db.Prepare(countEventsByApplicationServiceIDSQL); err != nil { if s.getLatestIdStmt, err = db.Prepare(getLatestIdSQL); err != nil {
return return
} }
if s.insertEventStmt, err = db.Prepare(insertEventSQL); err != nil { if s.insertEventStmt, err = db.Prepare(insertEventSQL); err != nil {
@ -113,7 +114,6 @@ func (s *eventsStatements) selectEventsByApplicationServiceID(
) ( ) (
txnID, maxID int, txnID, maxID int,
events []gomatrixserverlib.HeaderedEvent, events []gomatrixserverlib.HeaderedEvent,
eventsRemaining bool,
err error, err error,
) { ) {
defer func() { defer func() {
@ -129,7 +129,7 @@ func (s *eventsStatements) selectEventsByApplicationServiceID(
return return
} }
defer checkNamedErr(eventRows.Close, &err) defer checkNamedErr(eventRows.Close, &err)
events, maxID, txnID, eventsRemaining, err = retrieveEvents(eventRows, limit) events, maxID, txnID, err = retrieveEvents(eventRows, limit)
if err != nil { if err != nil {
return return
} }
@ -144,7 +144,7 @@ func checkNamedErr(fn func() error, err *error) {
} }
} }
func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.HeaderedEvent, maxID, txnID int, eventsRemaining bool, err error) { func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.HeaderedEvent, maxID, txnID int, err error) {
// Get current time for use in calculating event age // Get current time for use in calculating event age
nowMilli := time.Now().UnixNano() / int64(time.Millisecond) nowMilli := time.Now().UnixNano() / int64(time.Millisecond)
@ -162,18 +162,18 @@ func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.
&txnID, &txnID,
) )
if err != nil { if err != nil {
return nil, 0, 0, false, err return nil, 0, 0, err
} }
// Unmarshal eventJSON // Unmarshal eventJSON
if err = json.Unmarshal(eventJSON, &event); err != nil { if err = json.Unmarshal(eventJSON, &event); err != nil {
return nil, 0, 0, false, err return nil, 0, 0, err
} }
// If txnID has changed on this event from the previous event, then we've // If txnID has changed on this event from the previous event, then we've
// reached the end of a transaction's events. Return only those events. // reached the end of a transaction's events. Return only those events.
if lastTxnID > invalidTxnID && lastTxnID != txnID { if lastTxnID > invalidTxnID && lastTxnID != txnID {
return events, maxID, lastTxnID, true, nil return events, maxID, lastTxnID, nil
} }
lastTxnID = txnID lastTxnID = txnID
@ -181,7 +181,7 @@ func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.
if txnID == -1 { if txnID == -1 {
// Return if we've hit the limit // Return if we've hit the limit
if eventsProcessed++; eventsProcessed > limit { if eventsProcessed++; eventsProcessed > limit {
return events, maxID, lastTxnID, true, nil return events, maxID, lastTxnID, nil
} }
} }
@ -192,7 +192,7 @@ func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.
// Portion of the event that is unsigned due to rapid change // Portion of the event that is unsigned due to rapid change
// TODO: Consider removing age as not many app services use it // TODO: Consider removing age as not many app services use it
if err = event.SetUnsignedField("age", nowMilli-int64(event.OriginServerTS())); err != nil { if err = event.SetUnsignedField("age", nowMilli-int64(event.OriginServerTS())); err != nil {
return nil, 0, 0, false, err return nil, 0, 0, err
} }
events = append(events, event) events = append(events, event)
@ -201,14 +201,12 @@ func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.
return return
} }
// countEventsByApplicationServiceID inserts an event mapped to its corresponding application service func (s *eventsStatements) getLatestId(
// IDs into the db.
func (s *eventsStatements) countEventsByApplicationServiceID(
ctx context.Context, ctx context.Context,
appServiceID string, appServiceID string,
) (int, error) { ) (int, error) {
var count int var count int
err := s.countEventsByApplicationServiceIDStmt.QueryRowContext(ctx, appServiceID).Scan(&count) err := s.getLatestIdStmt.QueryRowContext(ctx, appServiceID).Scan(&count)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
return 0, err return 0, err
} }
@ -222,22 +220,22 @@ func (s *eventsStatements) insertEvent(
ctx context.Context, ctx context.Context,
appServiceID string, appServiceID string,
event *gomatrixserverlib.HeaderedEvent, event *gomatrixserverlib.HeaderedEvent,
) (err error) { ) (id int, err error) {
// Convert event to JSON before inserting // Convert event to JSON before inserting
eventJSON, err := json.Marshal(event) eventJSON, err := json.Marshal(event)
if err != nil { if err != nil {
return err return 0, err
} }
err = s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { err = s.insertEventStmt.QueryRowContext(
_, err := s.insertEventStmt.ExecContext(
ctx, ctx,
appServiceID, appServiceID,
eventJSON, eventJSON,
-1, // No transaction ID yet -1, // No transaction ID yet
) ).Scan(&id)
return err return err
}) })
return
} }
// updateTxnIDForEvents sets the transactionID for a collection of events. Done // updateTxnIDForEvents sets the transactionID for a collection of events. Done

View file

@ -61,7 +61,7 @@ func (d *Database) StoreEvent(
ctx context.Context, ctx context.Context,
appServiceID string, appServiceID string,
event *gomatrixserverlib.HeaderedEvent, event *gomatrixserverlib.HeaderedEvent,
) error { ) (int, error) {
return d.events.insertEvent(ctx, appServiceID, event) return d.events.insertEvent(ctx, appServiceID, event)
} }
@ -71,17 +71,20 @@ func (d *Database) GetEventsWithAppServiceID(
ctx context.Context, ctx context.Context,
appServiceID string, appServiceID string,
limit int, limit int,
) (int, int, []gomatrixserverlib.HeaderedEvent, bool, error) { ) (int, int, []gomatrixserverlib.HeaderedEvent, error) {
return d.events.selectEventsByApplicationServiceID(ctx, appServiceID, limit) return d.events.selectEventsByApplicationServiceID(ctx, appServiceID, limit)
} }
// CountEventsWithAppServiceID returns the number of events destined for an // GetLatestId returns the latest incremental id associated with appservice.
// application service given its ID. func (d *Database) GetLatestId(
func (d *Database) CountEventsWithAppServiceID(
ctx context.Context, ctx context.Context,
appServiceID string, appServiceID string,
) (int, error) { ) (int, error) {
return d.events.countEventsByApplicationServiceID(ctx, appServiceID) id, err := d.events.getLatestId(ctx, appServiceID)
if err == sql.ErrNoRows {
return 0, nil
}
return id, err
} }
// UpdateTxnIDForEvents takes in an application service ID and a // UpdateTxnIDForEvents takes in an application service ID and a

View file

@ -30,34 +30,26 @@ const (
type ApplicationServiceWorkerState struct { type ApplicationServiceWorkerState struct {
AppService config.ApplicationService AppService config.ApplicationService
Cond *sync.Cond Cond *sync.Cond
// Events ready to be sent // Lastest incremental ID from appservice_events table that is ready to be sent to application service
EventsReady bool latestId int
// Backoff exponent (2^x secs). Max 6, aka 64s. // Backoff exponent (2^x secs). Max 6, aka 64s.
Backoff int Backoff int
} }
// NotifyNewEvents wakes up all waiting goroutines, notifying that events remain // NotifyNewEvents wakes up all waiting goroutines, notifying that events remain
// in the event queue for this application service worker. // in the event queue for this application service worker.
func (a *ApplicationServiceWorkerState) NotifyNewEvents() { func (a *ApplicationServiceWorkerState) NotifyNewEvents(id int) {
a.Cond.L.Lock() a.Cond.L.Lock()
a.EventsReady = true a.latestId = id
a.Cond.Broadcast() a.Cond.Broadcast()
a.Cond.L.Unlock() a.Cond.L.Unlock()
} }
// FinishEventProcessing marks all events of this worker as being sent to the
// application service.
func (a *ApplicationServiceWorkerState) FinishEventProcessing() {
a.Cond.L.Lock()
a.EventsReady = false
a.Cond.L.Unlock()
}
// WaitForNewEvents causes the calling goroutine to wait on the worker state's // WaitForNewEvents causes the calling goroutine to wait on the worker state's
// condition for a broadcast or similar wakeup, if there are no events ready. // condition for a broadcast or similar wakeup, if there are no events ready.
func (a *ApplicationServiceWorkerState) WaitForNewEvents() { func (a *ApplicationServiceWorkerState) WaitForNewEvents(id int) {
a.Cond.L.Lock() a.Cond.L.Lock()
if !a.EventsReady { if a.latestId <= id {
a.Cond.Wait() a.Cond.Wait()
} }
a.Cond.L.Unlock() a.Cond.L.Unlock()

View file

@ -44,7 +44,7 @@ var (
func SetupTransactionWorkers( func SetupTransactionWorkers(
client *http.Client, client *http.Client,
appserviceDB storage.Database, appserviceDB storage.Database,
workerStates []types.ApplicationServiceWorkerState, workerStates []*types.ApplicationServiceWorkerState,
) error { ) error {
// Create a worker that handles transmitting events to a single homeserver // Create a worker that handles transmitting events to a single homeserver
for _, workerState := range workerStates { for _, workerState := range workerStates {
@ -58,31 +58,29 @@ func SetupTransactionWorkers(
// worker is a goroutine that sends any queued events to the application service // worker is a goroutine that sends any queued events to the application service
// it is given. // it is given.
func worker(client *http.Client, db storage.Database, ws types.ApplicationServiceWorkerState) { func worker(client *http.Client, db storage.Database, ws *types.ApplicationServiceWorkerState) {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"appservice": ws.AppService.ID, "appservice": ws.AppService.ID,
}).Info("Starting application service") }).Info("Starting application service")
ctx := context.Background() ctx := context.Background()
// Initial check for any leftover events to send from last time // Initial check for any leftover events to send from last time
eventCount, err := db.CountEventsWithAppServiceID(ctx, ws.AppService.ID) latestId, err := db.GetLatestId(ctx, ws.AppService.ID)
if err != nil { if err != nil {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"appservice": ws.AppService.ID, "appservice": ws.AppService.ID,
}).WithError(err).Fatal("appservice worker unable to read queued events from DB") }).WithError(err).Fatal("appservice worker unable to read queued events from DB")
return return
} }
if eventCount > 0 { ws.NotifyNewEvents(latestId)
ws.NotifyNewEvents() id := 0
}
// Loop forever and keep waiting for more events to send // Loop forever and keep waiting for more events to send
for { for {
// Wait for more events if we've sent all the events in the database // Wait for more events if we've sent all the events in the database
ws.WaitForNewEvents() ws.WaitForNewEvents(id)
// Batch events up into a transaction // Batch events up into a transaction
transactionJSON, txnID, maxEventID, eventsRemaining, err := createTransaction(ctx, db, ws.AppService.ID) transactionJSON, txnID, maxEventID, err := createTransaction(ctx, db, ws.AppService.ID)
if err != nil { if err != nil {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"appservice": ws.AppService.ID, "appservice": ws.AppService.ID,
@ -90,6 +88,10 @@ func worker(client *http.Client, db storage.Database, ws types.ApplicationServic
return return
} }
// Transactions have a maximum event size (or new events may arrive while
// transaction is processed by Application Service), so there may still be
// some events left over to send. We will keep sending if id < ws.latestID.
id = maxEventID
// Send the events off to the application service // Send the events off to the application service
// Backoff if the application service does not respond // Backoff if the application service does not respond
@ -99,19 +101,13 @@ func worker(client *http.Client, db storage.Database, ws types.ApplicationServic
"appservice": ws.AppService.ID, "appservice": ws.AppService.ID,
}).WithError(err).Error("unable to send event") }).WithError(err).Error("unable to send event")
// Backoff // Backoff
backoff(&ws, err) backoff(ws, err)
continue continue
} }
// We sent successfully, hooray! // We sent successfully, hooray!
ws.Backoff = 0 ws.Backoff = 0
// Transactions have a maximum event size, so there may still be some events
// left over to send. Keep sending until none are left
if !eventsRemaining {
ws.FinishEventProcessing()
}
// Remove sent events from the DB // Remove sent events from the DB
err = db.RemoveEventsBeforeAndIncludingID(ctx, ws.AppService.ID, maxEventID) err = db.RemoveEventsBeforeAndIncludingID(ctx, ws.AppService.ID, maxEventID)
if err != nil { if err != nil {
@ -152,11 +148,10 @@ func createTransaction(
) ( ) (
transactionJSON []byte, transactionJSON []byte,
txnID, maxID int, txnID, maxID int,
eventsRemaining bool,
err error, err error,
) { ) {
// Retrieve the latest events from the DB (will return old events if they weren't successfully sent) // Retrieve the latest events from the DB (will return old events if they weren't successfully sent)
txnID, maxID, events, eventsRemaining, err := db.GetEventsWithAppServiceID(ctx, appserviceID, transactionBatchSize) txnID, maxID, events, err := db.GetEventsWithAppServiceID(ctx, appserviceID, transactionBatchSize)
if err != nil { if err != nil {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"appservice": appserviceID, "appservice": appserviceID,
@ -170,12 +165,12 @@ func createTransaction(
// If not, grab next available ID from the DB // If not, grab next available ID from the DB
txnID, err = db.GetLatestTxnID(ctx) txnID, err = db.GetLatestTxnID(ctx)
if err != nil { if err != nil {
return nil, 0, 0, false, err return nil, 0, 0, err
} }
// Mark new events with current transactionID // Mark new events with current transactionID
if err = db.UpdateTxnIDForEvents(ctx, appserviceID, maxID, txnID); err != nil { if err = db.UpdateTxnIDForEvents(ctx, appserviceID, maxID, txnID); err != nil {
return nil, 0, 0, false, err return nil, 0, 0, err
} }
} }

View file

@ -11,4 +11,6 @@ const (
LoginTypeRecaptcha = "m.login.recaptcha" LoginTypeRecaptcha = "m.login.recaptcha"
LoginTypeApplicationService = "m.login.application_service" LoginTypeApplicationService = "m.login.application_service"
LoginTypeToken = "m.login.token" LoginTypeToken = "m.login.token"
LoginTypeJwt = "org.matrix.login.jwt"
LoginTypeEmail = "m.login.email.identity"
) )

View file

@ -23,6 +23,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/clientapi/ratelimit"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
uapi "github.com/matrix-org/dendrite/userapi/api" uapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/util" "github.com/matrix-org/util"
@ -33,7 +34,7 @@ import (
// called after authorization has completed, with the result of the authorization. // called after authorization has completed, with the result of the authorization.
// If the final return value is non-nil, an error occurred and the cleanup function // If the final return value is non-nil, an error occurred and the cleanup function
// is nil. // is nil.
func LoginFromJSONReader(ctx context.Context, r io.Reader, useraccountAPI uapi.UserLoginAPI, userAPI UserInternalAPIForLogin, cfg *config.ClientAPI) (*Login, LoginCleanupFunc, *util.JSONResponse) { func LoginFromJSONReader(ctx context.Context, r io.Reader, useraccountAPI uapi.ClientUserAPI, cfg *config.ClientAPI, rt *ratelimit.RtFailedLogin) (*Login, LoginCleanupFunc, *util.JSONResponse) {
reqBytes, err := ioutil.ReadAll(r) reqBytes, err := ioutil.ReadAll(r)
if err != nil { if err != nil {
err := &util.JSONResponse{ err := &util.JSONResponse{
@ -58,14 +59,19 @@ func LoginFromJSONReader(ctx context.Context, r io.Reader, useraccountAPI uapi.U
switch header.Type { switch header.Type {
case authtypes.LoginTypePassword: case authtypes.LoginTypePassword:
typ = &LoginTypePassword{ typ = &LoginTypePassword{
GetAccountByPassword: useraccountAPI.QueryAccountByPassword, UserApi: useraccountAPI,
Config: cfg, Config: cfg,
Rt: rt,
} }
case authtypes.LoginTypeToken: case authtypes.LoginTypeToken:
typ = &LoginTypeToken{ typ = &LoginTypeToken{
UserAPI: userAPI, UserAPI: useraccountAPI,
Config: cfg, Config: cfg,
} }
case authtypes.LoginTypeJwt:
typ = &LoginTypeTokenJwt{
Config: cfg,
}
default: default:
err := util.JSONResponse{ err := util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,

View file

@ -0,0 +1,74 @@
package auth
import (
"context"
"fmt"
"net/http"
"github.com/golang-jwt/jwt/v4"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/util"
)
// LoginTypeToken describes how to authenticate with a login token.
type LoginTypeTokenJwt struct {
// UserAPI uapi.LoginTokenInternalAPI
Config *config.ClientAPI
}
// Name implements Type.
func (t *LoginTypeTokenJwt) Name() string {
return authtypes.LoginTypeJwt
}
type Claims struct {
jwt.StandardClaims
}
const mIdUser = "m.id.user"
// LoginFromJSON implements Type. The cleanup function deletes the token from
// the database on success.
func (t *LoginTypeTokenJwt) LoginFromJSON(ctx context.Context, reqBytes []byte) (*Login, LoginCleanupFunc, *util.JSONResponse) {
var r loginTokenRequest
if err := httputil.UnmarshalJSON(reqBytes, &r); err != nil {
return nil, nil, err
}
if r.Token == "" {
return nil, nil, &util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("Token field for JWT is missing"),
}
}
c := &Claims{}
token, err := jwt.ParseWithClaims(r.Token, c, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodEd25519); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Method.Alg())
}
return t.Config.JwtConfig.SecretKey, nil
})
if err != nil {
util.GetLogger(ctx).WithError(err).Error("jwt.ParseWithClaims failed")
return nil, nil, &util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("Couldn't parse JWT"),
}
}
if !token.Valid {
return nil, nil, &util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("Invalid JWT"),
}
}
r.Login.Identifier.User = c.Subject
r.Login.Identifier.Type = mIdUser
return &r.Login, func(context.Context, *util.JSONResponse) {}, nil
}

View file

@ -22,6 +22,7 @@ import (
"testing" "testing"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/clientapi/ratelimit"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
uapi "github.com/matrix-org/dendrite/userapi/api" uapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/util" "github.com/matrix-org/util"
@ -68,8 +69,11 @@ func TestLoginFromJSONReader(t *testing.T) {
Matrix: &config.Global{ Matrix: &config.Global{
ServerName: serverName, ServerName: serverName,
}, },
RtFailedLogin: ratelimit.RtFailedLoginConfig{
Enabled: false,
},
} }
login, cleanup, err := LoginFromJSONReader(ctx, strings.NewReader(tst.Body), &userAPI, &userAPI, cfg) login, cleanup, err := LoginFromJSONReader(ctx, strings.NewReader(tst.Body), &userAPI, cfg, nil)
if err != nil { if err != nil {
t.Fatalf("LoginFromJSONReader failed: %+v", err) t.Fatalf("LoginFromJSONReader failed: %+v", err)
} }
@ -147,7 +151,7 @@ func TestBadLoginFromJSONReader(t *testing.T) {
ServerName: serverName, ServerName: serverName,
}, },
} }
_, cleanup, errRes := LoginFromJSONReader(ctx, strings.NewReader(tst.Body), &userAPI, &userAPI, cfg) _, cleanup, errRes := LoginFromJSONReader(ctx, strings.NewReader(tst.Body), &userAPI, cfg, nil)
if errRes == nil { if errRes == nil {
cleanup(ctx, nil) cleanup(ctx, nil)
t.Fatalf("LoginFromJSONReader err: got %+v, want code %q", errRes, tst.WantErrCode) t.Fatalf("LoginFromJSONReader err: got %+v, want code %q", errRes, tst.WantErrCode)
@ -159,6 +163,7 @@ func TestBadLoginFromJSONReader(t *testing.T) {
} }
type fakeUserInternalAPI struct { type fakeUserInternalAPI struct {
uapi.ClientUserAPI
UserInternalAPIForLogin UserInternalAPIForLogin
DeletedTokens []string DeletedTokens []string
} }

View file

@ -58,7 +58,7 @@ func (t *LoginTypeToken) LoginFromJSON(ctx context.Context, reqBytes []byte) (*L
} }
} }
r.Login.Identifier.Type = "m.id.user" r.Login.Identifier.Type = mIdUser
r.Login.Identifier.User = res.Data.UserID r.Login.Identifier.User = res.Data.UserID
cleanup := func(ctx context.Context, authRes *util.JSONResponse) { cleanup := func(ctx context.Context, authRes *util.JSONResponse) {

View file

@ -22,6 +22,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/clientapi/ratelimit"
"github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
@ -33,12 +34,17 @@ type GetAccountByPassword func(ctx context.Context, req *api.QueryAccountByPassw
type PasswordRequest struct { type PasswordRequest struct {
Login Login
Password string `json:"password"` Password string `json:"password"`
Address string `json:"address"`
Medium string `json:"medium"`
} }
const email = "email"
// LoginTypePassword implements https://matrix.org/docs/spec/client_server/r0.6.1#password-based // LoginTypePassword implements https://matrix.org/docs/spec/client_server/r0.6.1#password-based
type LoginTypePassword struct { type LoginTypePassword struct {
GetAccountByPassword GetAccountByPassword UserApi api.ClientUserAPI
Config *config.ClientAPI Config *config.ClientAPI
Rt *ratelimit.RtFailedLogin
} }
func (t *LoginTypePassword) Name() string { func (t *LoginTypePassword) Name() string {
@ -61,7 +67,35 @@ func (t *LoginTypePassword) LoginFromJSON(ctx context.Context, reqBytes []byte)
func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, *util.JSONResponse) { func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, *util.JSONResponse) {
r := req.(*PasswordRequest) r := req.(*PasswordRequest)
username := strings.ToLower(r.Username()) if r.Identifier.Address != "" {
r.Address = r.Identifier.Address
}
if r.Identifier.Medium != "" {
r.Medium = r.Identifier.Medium
}
var username string
if r.Medium == email && r.Address != "" {
r.Address = strings.ToLower(r.Address)
res := api.QueryLocalpartForThreePIDResponse{}
err := t.UserApi.QueryLocalpartForThreePID(ctx, &api.QueryLocalpartForThreePIDRequest{
ThreePID: r.Address,
Medium: email,
}, &res)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("userApi.QueryLocalpartForThreePID failed")
resp := jsonerror.InternalServerError()
return nil, &resp
}
username = res.Localpart
if username == "" {
return nil, &util.JSONResponse{
Code: http.StatusUnauthorized,
JSON: jsonerror.Forbidden("Invalid username or password"),
}
}
} else {
username = strings.ToLower(r.Username())
}
if username == "" { if username == "" {
return nil, &util.JSONResponse{ return nil, &util.JSONResponse{
Code: http.StatusUnauthorized, Code: http.StatusUnauthorized,
@ -77,7 +111,17 @@ func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login,
} }
// Squash username to all lowercase letters // Squash username to all lowercase letters
res := &api.QueryAccountByPasswordResponse{} res := &api.QueryAccountByPasswordResponse{}
err = t.GetAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{Localpart: strings.ToLower(localpart), PlaintextPassword: r.Password}, res) localpart = strings.ToLower(localpart)
if t.Rt != nil {
ok, retryIn := t.Rt.CanAct(localpart)
if !ok {
return nil, &util.JSONResponse{
Code: http.StatusTooManyRequests,
JSON: jsonerror.LimitExceeded("Too Many Requests", retryIn.Milliseconds()),
}
}
}
err = t.UserApi.QueryAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{Localpart: localpart, PlaintextPassword: r.Password}, res)
if err != nil { if err != nil {
return nil, &util.JSONResponse{ return nil, &util.JSONResponse{
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,
@ -86,7 +130,7 @@ func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login,
} }
if !res.Exists { if !res.Exists {
err = t.GetAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{ err = t.UserApi.QueryAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{
Localpart: localpart, Localpart: localpart,
PlaintextPassword: r.Password, PlaintextPassword: r.Password,
}, res) }, res)
@ -99,11 +143,15 @@ func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login,
// Technically we could tell them if the user does not exist by checking if err == sql.ErrNoRows // Technically we could tell them if the user does not exist by checking if err == sql.ErrNoRows
// but that would leak the existence of the user. // but that would leak the existence of the user.
if !res.Exists { if !res.Exists {
if t.Rt != nil {
t.Rt.Act(localpart)
}
return nil, &util.JSONResponse{ return nil, &util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("The username or password was incorrect or the account does not exist."), JSON: jsonerror.Forbidden("Invalid username or password"),
} }
} }
} }
r.Login.User = username
return &r.Login, nil return &r.Login, nil
} }

View file

@ -74,7 +74,7 @@ type Login struct {
// Username returns the user localpart/user_id in this request, if it exists. // Username returns the user localpart/user_id in this request, if it exists.
func (r *Login) Username() string { func (r *Login) Username() string {
if r.Identifier.Type == "m.id.user" { if r.Identifier.Type == mIdUser {
return r.Identifier.User return r.Identifier.User
} }
// deprecated but without it Element iOS won't log in // deprecated but without it Element iOS won't log in
@ -87,8 +87,8 @@ func (r *Login) ThirdPartyID() (medium, address string) {
return r.Identifier.Medium, r.Identifier.Address return r.Identifier.Medium, r.Identifier.Address
} }
// deprecated // deprecated
if r.Medium == "email" { if r.Medium == email {
return "email", r.Address return email, r.Address
} }
return "", "" return "", ""
} }
@ -109,10 +109,10 @@ type UserInteractive struct {
Sessions map[string][]string Sessions map[string][]string
} }
func NewUserInteractive(userAccountAPI api.UserLoginAPI, cfg *config.ClientAPI) *UserInteractive { func NewUserInteractive(userAccountAPI api.ClientUserAPI, cfg *config.ClientAPI) *UserInteractive {
typePassword := &LoginTypePassword{ typePassword := &LoginTypePassword{
GetAccountByPassword: userAccountAPI.QueryAccountByPassword, UserApi: userAccountAPI,
Config: cfg, Config: cfg,
} }
return &UserInteractive{ return &UserInteractive{
Flows: []userInteractiveFlow{ Flows: []userInteractiveFlow{

View file

@ -24,7 +24,9 @@ var (
} }
) )
type fakeAccountDatabase struct{} type fakeAccountDatabase struct {
api.ClientUserAPI
}
func (d *fakeAccountDatabase) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error { func (d *fakeAccountDatabase) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error {
return nil return nil

View file

@ -0,0 +1,117 @@
package ratelimit
import (
"container/list"
"sync"
"time"
)
type rateLimit struct {
cfg *RtFailedLoginConfig
times *list.List
}
type RtFailedLogin struct {
cfg *RtFailedLoginConfig
mtx sync.RWMutex
rts map[string]*rateLimit
}
type RtFailedLoginConfig struct {
Enabled bool `yaml:"enabled"`
Limit int `yaml:"burst"`
Interval time.Duration `yaml:"interval"`
}
// New creates a new rate limiter for the limit and interval.
func NewRtFailedLogin(cfg *RtFailedLoginConfig) *RtFailedLogin {
if !cfg.Enabled {
return nil
}
rt := &RtFailedLogin{
cfg: cfg,
mtx: sync.RWMutex{},
rts: make(map[string]*rateLimit),
}
go rt.clean()
return rt
}
// CanAct is expected to be called before Act
func (r *RtFailedLogin) CanAct(key string) (ok bool, remaining time.Duration) {
r.mtx.RLock()
rt, ok := r.rts[key]
if !ok {
r.mtx.RUnlock()
return true, 0
}
ok, remaining = rt.canAct()
r.mtx.RUnlock()
return
}
// Act can be called after CanAct returns true.
func (r *RtFailedLogin) Act(key string) {
r.mtx.Lock()
rt, ok := r.rts[key]
if !ok {
rt = &rateLimit{
cfg: r.cfg,
times: list.New(),
}
r.rts[key] = rt
}
rt.act()
r.mtx.Unlock()
}
func (r *RtFailedLogin) clean() {
for {
r.mtx.Lock()
for k, v := range r.rts {
if v.empty() {
delete(r.rts, k)
}
}
r.mtx.Unlock()
time.Sleep(time.Hour)
}
}
func (r *rateLimit) empty() bool {
back := r.times.Back()
if back == nil {
return true
}
v := back.Value
b := v.(time.Time)
now := time.Now()
return now.Sub(b) > r.cfg.Interval
}
func (r *rateLimit) canAct() (ok bool, remaining time.Duration) {
now := time.Now()
l := r.times.Len()
if l < r.cfg.Limit {
return true, 0
}
frnt := r.times.Front()
t := frnt.Value.(time.Time)
diff := now.Sub(t)
if diff < r.cfg.Interval {
return false, r.cfg.Interval - diff
}
return true, 0
}
func (r *rateLimit) act() {
now := time.Now()
l := r.times.Len()
if l < r.cfg.Limit {
r.times.PushBack(now)
return
}
frnt := r.times.Front()
frnt.Value = now
r.times.MoveToBack(frnt)
}

View file

@ -0,0 +1,40 @@
package ratelimit
import (
"testing"
"time"
"github.com/matryer/is"
)
func TestRtFailedLogin(t *testing.T) {
is := is.New(t)
rtfl := NewRtFailedLogin(&RtFailedLoginConfig{
Enabled: true,
Limit: 3,
Interval: 10 * time.Millisecond,
})
var (
can bool
remaining time.Duration
remainingB time.Duration
)
for i := 0; i < 3; i++ {
can, remaining = rtfl.CanAct("foo")
is.True(can)
is.Equal(remaining, time.Duration(0))
rtfl.Act("foo")
}
can, remaining = rtfl.CanAct("foo")
is.True(!can)
is.True(remaining > time.Millisecond*9)
can, remainingB = rtfl.CanAct("bar")
is.True(can)
is.Equal(remainingB, time.Duration(0))
rtfl.Act("bar")
rtfl.Act("bar")
time.Sleep(remaining + time.Millisecond)
can, remaining = rtfl.CanAct("foo")
is.True(can)
is.Equal(remaining, time.Duration(0))
}

View file

@ -63,8 +63,8 @@ func UploadCrossSigningDeviceKeys(
} }
} }
typePassword := auth.LoginTypePassword{ typePassword := auth.LoginTypePassword{
GetAccountByPassword: accountAPI.QueryAccountByPassword, UserApi: accountAPI,
Config: cfg, Config: cfg,
} }
if _, authErr := typePassword.Login(req.Context(), &uploadReq.Auth.PasswordRequest); authErr != nil { if _, authErr := typePassword.Login(req.Context(), &uploadReq.Auth.PasswordRequest); authErr != nil {
return *authErr return *authErr

View file

@ -20,6 +20,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth" "github.com/matrix-org/dendrite/clientapi/auth"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/clientapi/ratelimit"
"github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
@ -55,6 +56,7 @@ func passwordLogin() flows {
func Login( func Login(
req *http.Request, userAPI userapi.ClientUserAPI, req *http.Request, userAPI userapi.ClientUserAPI,
cfg *config.ClientAPI, cfg *config.ClientAPI,
rt *ratelimit.RtFailedLogin,
) util.JSONResponse { ) util.JSONResponse {
if req.Method == http.MethodGet { if req.Method == http.MethodGet {
// TODO: support other forms of login other than password, depending on config options // TODO: support other forms of login other than password, depending on config options
@ -63,7 +65,7 @@ func Login(
JSON: passwordLogin(), JSON: passwordLogin(),
} }
} else if req.Method == http.MethodPost { } else if req.Method == http.MethodPost {
login, cleanup, authErr := auth.LoginFromJSONReader(req.Context(), req.Body, userAPI, userAPI, cfg) login, cleanup, authErr := auth.LoginFromJSONReader(req.Context(), req.Body, userAPI, cfg, rt)
if authErr != nil { if authErr != nil {
return *authErr return *authErr
} }

View file

@ -1,12 +1,14 @@
package routing package routing
import ( import (
"fmt"
"net/http" "net/http"
"github.com/matrix-org/dendrite/clientapi/auth" "github.com/matrix-org/dendrite/clientapi/auth"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/clientapi/threepid"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -24,6 +26,7 @@ type newPasswordAuth struct {
Type string `json:"type"` Type string `json:"type"`
Session string `json:"session"` Session string `json:"session"`
auth.PasswordRequest auth.PasswordRequest
ThreePidCreds threepid.Credentials `json:"threepid_creds"`
} }
func Password( func Password(
@ -33,13 +36,17 @@ func Password(
cfg *config.ClientAPI, cfg *config.ClientAPI,
) util.JSONResponse { ) util.JSONResponse {
// Check that the existing password is right. // Check that the existing password is right.
var fields logrus.Fields
if device != nil {
fields = logrus.Fields{
"sessionId": device.SessionID,
"userId": device.UserID,
}
}
var r newPasswordRequest var r newPasswordRequest
r.LogoutDevices = true r.LogoutDevices = true
logrus.WithFields(logrus.Fields{ logrus.WithFields(fields).Debug("Changing password")
"sessionId": device.SessionID,
"userId": device.UserID,
}).Debug("Changing password")
// Unmarshal the request. // Unmarshal the request.
resErr := httputil.UnmarshalJSONRequest(req, &r) resErr := httputil.UnmarshalJSONRequest(req, &r)
@ -53,45 +60,95 @@ func Password(
// Generate a new, random session ID // Generate a new, random session ID
sessionID = util.RandomString(sessionIDLength) sessionID = util.RandomString(sessionIDLength)
} }
var localpart string
// Require password auth to change the password. switch r.Auth.Type {
if r.Auth.Type != authtypes.LoginTypePassword { case authtypes.LoginTypePassword:
return util.JSONResponse{ // Check if the existing password is correct.
Code: http.StatusUnauthorized, typePassword := auth.LoginTypePassword{
JSON: newUserInteractiveResponse( UserApi: userAPI,
sessionID, Config: cfg,
[]authtypes.Flow{ }
{ if _, authErr := typePassword.Login(req.Context(), &r.Auth.PasswordRequest); authErr != nil {
Stages: []authtypes.LoginType{authtypes.LoginTypePassword}, return *authErr
}, }
// Get the local part.
var err error
localpart, _, err = gomatrixserverlib.SplitID('@', device.UserID)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError()
}
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword)
case authtypes.LoginTypeEmail:
threePid := &authtypes.ThreePID{}
r.Auth.ThreePidCreds.IDServer = cfg.ThreePidDelegate
var (
bound bool
err error
)
bound, threePid.Address, threePid.Medium, err = threepid.CheckAssociation(req.Context(), r.Auth.ThreePidCreds, cfg)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("threepid.CheckAssociation failed")
return jsonerror.InternalServerError()
}
if !bound {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.MatrixError{
ErrCode: "M_THREEPID_AUTH_FAILED",
Err: "Failed to auth 3pid",
}, },
nil, }
), }
var res api.QueryLocalpartForThreePIDResponse
err = userAPI.QueryLocalpartForThreePID(req.Context(), &api.QueryLocalpartForThreePIDRequest{
Medium: threePid.Medium,
ThreePID: threePid.Address,
}, &res)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryLocalpartForThreePID failed")
return jsonerror.InternalServerError()
}
if res.Localpart == "" {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.MatrixError{
ErrCode: "M_THREEPID_NOT_FOUND",
Err: "3pid is not bound to any account",
},
}
}
localpart = res.Localpart
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeEmail)
default:
flows := []authtypes.Flow{
{
Stages: []authtypes.LoginType{authtypes.LoginTypePassword},
},
}
if cfg.ThreePidDelegate != "" {
flows = append(flows, authtypes.Flow{
Stages: []authtypes.LoginType{authtypes.LoginTypeEmail},
})
}
// Require password auth to change the password.
if r.Auth.Type == authtypes.LoginTypePassword {
return util.JSONResponse{
Code: http.StatusUnauthorized,
JSON: newUserInteractiveResponse(
sessionID,
flows,
nil,
),
}
} }
} }
// Check if the existing password is correct.
typePassword := auth.LoginTypePassword{
GetAccountByPassword: userAPI.QueryAccountByPassword,
Config: cfg,
}
if _, authErr := typePassword.Login(req.Context(), &r.Auth.PasswordRequest); authErr != nil {
return *authErr
}
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword)
// Check the new password strength. // Check the new password strength.
if resErr = validatePassword(r.NewPassword); resErr != nil { if resErr = validatePassword(r.NewPassword); resErr != nil {
return *resErr return *resErr
} }
// Get the local part.
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError()
}
// Ask the user API to perform the password change. // Ask the user API to perform the password change.
passwordReq := &api.PerformPasswordUpdateRequest{ passwordReq := &api.PerformPasswordUpdateRequest{
Localpart: localpart, Localpart: localpart,
@ -109,11 +166,23 @@ func Password(
// If the request asks us to log out all other devices then // If the request asks us to log out all other devices then
// ask the user API to do that. // ask the user API to do that.
if r.LogoutDevices { if r.LogoutDevices {
logoutReq := &api.PerformDeviceDeletionRequest{ var logoutReq *api.PerformDeviceDeletionRequest
UserID: device.UserID, var sessionId int64
DeviceIDs: nil, if device == nil {
ExceptDeviceID: device.ID, logoutReq = &api.PerformDeviceDeletionRequest{
UserID: fmt.Sprintf("@%s:%s", localpart, cfg.Matrix.ServerName),
DeviceIDs: []string{},
}
sessionId = 0
} else {
logoutReq = &api.PerformDeviceDeletionRequest{
UserID: device.UserID,
DeviceIDs: nil,
ExceptDeviceID: device.ID,
}
sessionId = device.SessionID
} }
logoutRes := &api.PerformDeviceDeletionResponse{} logoutRes := &api.PerformDeviceDeletionResponse{}
if err := userAPI.PerformDeviceDeletion(req.Context(), logoutReq, logoutRes); err != nil { if err := userAPI.PerformDeviceDeletion(req.Context(), logoutReq, logoutRes); err != nil {
@ -123,7 +192,7 @@ func Password(
pushersReq := &api.PerformPusherDeletionRequest{ pushersReq := &api.PerformPusherDeletionRequest{
Localpart: localpart, Localpart: localpart,
SessionID: device.SessionID, SessionID: sessionId,
} }
if err := userAPI.PerformPusherDeletion(req.Context(), pushersReq, &struct{}{}); err != nil { if err := userAPI.PerformPusherDeletion(req.Context(), pushersReq, &struct{}{}); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("PerformPusherDeletion failed") util.GetLogger(req.Context()).WithError(err).Error("PerformPusherDeletion failed")

View file

@ -44,6 +44,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/clientapi/threepid"
"github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/clientapi/userutil"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
) )
@ -237,6 +238,7 @@ type authDict struct {
// Recaptcha // Recaptcha
Response string `json:"response"` Response string `json:"response"`
// TODO: Lots of custom keys depending on the type // TODO: Lots of custom keys depending on the type
ThreePidCreds threepid.Credentials `json:"threepid_creds"`
} }
// http://matrix.org/speculator/spec/HEAD/client_server/unstable.html#user-interactive-authentication-api // http://matrix.org/speculator/spec/HEAD/client_server/unstable.html#user-interactive-authentication-api
@ -745,6 +747,7 @@ func handleRegistrationFlow(
} }
} }
var threePid *authtypes.ThreePID
switch r.Auth.Type { switch r.Auth.Type {
case authtypes.LoginTypeRecaptcha: case authtypes.LoginTypeRecaptcha:
// Check given captcha response // Check given captcha response
@ -761,6 +764,29 @@ func handleRegistrationFlow(
// Add Dummy to the list of completed registration stages // Add Dummy to the list of completed registration stages
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeDummy) sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeDummy)
case authtypes.LoginTypeEmail:
threePid = &authtypes.ThreePID{}
r.Auth.ThreePidCreds.IDServer = cfg.ThreePidDelegate
var (
bound bool
err error
)
bound, threePid.Address, threePid.Medium, err = threepid.CheckAssociation(req.Context(), r.Auth.ThreePidCreds, cfg)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("threepid.CheckAssociation failed")
return jsonerror.InternalServerError()
}
if !bound {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.MatrixError{
ErrCode: "M_THREEPID_AUTH_FAILED",
Err: "Failed to auth 3pid",
},
}
}
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeEmail)
case "": case "":
// An empty auth type means that we want to fetch the available // An empty auth type means that we want to fetch the available
// flows. It can also mean that we want to register as an appservice // flows. It can also mean that we want to register as an appservice
@ -776,7 +802,7 @@ func handleRegistrationFlow(
// A response with current registration flow and remaining available methods // A response with current registration flow and remaining available methods
// will be returned if a flow has not been successfully completed yet // will be returned if a flow has not been successfully completed yet
return checkAndCompleteFlow(sessions.getCompletedStages(sessionID), return checkAndCompleteFlow(sessions.getCompletedStages(sessionID),
req, r, sessionID, cfg, userAPI) req, r, sessionID, cfg, userAPI, threePid)
} }
// handleApplicationServiceRegistration handles the registration of an // handleApplicationServiceRegistration handles the registration of an
@ -818,7 +844,7 @@ func handleApplicationServiceRegistration(
// application service registration is entirely separate. // application service registration is entirely separate.
return completeRegistration( return completeRegistration(
req.Context(), userAPI, r.Username, "", appserviceID, req.RemoteAddr, req.UserAgent(), r.Auth.Session, req.Context(), userAPI, r.Username, "", appserviceID, req.RemoteAddr, req.UserAgent(), r.Auth.Session,
r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeAppService, r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeAppService, nil,
) )
} }
@ -832,12 +858,13 @@ func checkAndCompleteFlow(
sessionID string, sessionID string,
cfg *config.ClientAPI, cfg *config.ClientAPI,
userAPI userapi.ClientUserAPI, userAPI userapi.ClientUserAPI,
threePid *authtypes.ThreePID,
) util.JSONResponse { ) util.JSONResponse {
if checkFlowCompleted(flow, cfg.Derived.Registration.Flows) { if checkFlowCompleted(flow, cfg.Derived.Registration.Flows) {
// This flow was completed, registration can continue // This flow was completed, registration can continue
return completeRegistration( return completeRegistration(
req.Context(), userAPI, r.Username, r.Password, "", req.RemoteAddr, req.UserAgent(), sessionID, req.Context(), userAPI, r.Username, r.Password, "", req.RemoteAddr, req.UserAgent(), sessionID,
r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeUser, r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeUser, threePid,
) )
} }
sessions.addParams(sessionID, r) sessions.addParams(sessionID, r)
@ -863,6 +890,7 @@ func completeRegistration(
inhibitLogin eventutil.WeakBoolean, inhibitLogin eventutil.WeakBoolean,
displayName, deviceID *string, displayName, deviceID *string,
accType userapi.AccountType, accType userapi.AccountType,
threePid *authtypes.ThreePID,
) util.JSONResponse { ) util.JSONResponse {
if username == "" { if username == "" {
return util.JSONResponse{ return util.JSONResponse{
@ -901,6 +929,21 @@ func completeRegistration(
// Increment prometheus counter for created users // Increment prometheus counter for created users
amtRegUsers.Inc() amtRegUsers.Inc()
// TODO-entry refuse register if threepid is already bound to account.
if threePid != nil {
err = userAPI.PerformSaveThreePIDAssociation(ctx, &userapi.PerformSaveThreePIDAssociationRequest{
Medium: threePid.Medium,
ThreePID: threePid.Address,
Localpart: accRes.Account.Localpart,
}, &struct{}{})
if err != nil {
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: jsonerror.Unknown("Failed to save 3PID association: " + err.Error()),
}
}
}
// Check whether inhibit_login option is set. If so, don't create an access // Check whether inhibit_login option is set. If so, don't create an access
// token or a device for this user // token or a device for this user
if inhibitLogin { if inhibitLogin {
@ -1092,5 +1135,5 @@ func handleSharedSecretRegistration(userAPI userapi.ClientUserAPI, sr *SharedSec
if ssrr.Admin { if ssrr.Admin {
accType = userapi.AccountTypeAdmin accType = userapi.AccountTypeAdmin
} }
return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), "", false, &ssrr.User, &deviceID, accType) return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), "", false, &ssrr.User, &deviceID, accType, nil)
} }

View file

@ -26,6 +26,7 @@ import (
clientutil "github.com/matrix-org/dendrite/clientapi/httputil" clientutil "github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/clientapi/producers" "github.com/matrix-org/dendrite/clientapi/producers"
"github.com/matrix-org/dendrite/clientapi/ratelimit"
federationAPI "github.com/matrix-org/dendrite/federationapi/api" federationAPI "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/internal/transactions" "github.com/matrix-org/dendrite/internal/transactions"
@ -65,6 +66,7 @@ func Setup(
prometheus.MustRegister(amtRegUsers, sendEventDuration) prometheus.MustRegister(amtRegUsers, sendEventDuration)
rateLimits := httputil.NewRateLimits(&cfg.RateLimiting) rateLimits := httputil.NewRateLimits(&cfg.RateLimiting)
rateLimitsFailedLogin := ratelimit.NewRtFailedLogin(&cfg.RtFailedLogin)
userInteractiveAuth := auth.NewUserInteractive(userAPI, cfg) userInteractiveAuth := auth.NewUserInteractive(userAPI, cfg)
unstableFeatures := map[string]bool{ unstableFeatures := map[string]bool{
@ -538,7 +540,7 @@ func Setup(
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/account/password", v3mux.Handle("/account/password",
httputil.MakeAuthAPI("password", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeConditionalAuthAPI("password", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req, device); r != nil { if r := rateLimits.Limit(req, device); r != nil {
return *r return *r
} }
@ -562,7 +564,7 @@ func Setup(
if r := rateLimits.Limit(req, nil); r != nil { if r := rateLimits.Limit(req, nil); r != nil {
return *r return *r
} }
return Login(req, userAPI, cfg) return Login(req, userAPI, cfg, rateLimitsFailedLogin)
}), }),
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)

View file

@ -103,11 +103,8 @@ func CreateSession(
func CheckAssociation( func CheckAssociation(
ctx context.Context, creds Credentials, cfg *config.ClientAPI, ctx context.Context, creds Credentials, cfg *config.ClientAPI,
) (bool, string, string, error) { ) (bool, string, string, error) {
if err := isTrusted(creds.IDServer, cfg); err != nil {
return false, "", "", err
}
requestURL := fmt.Sprintf("https://%s/_matrix/identity/api/v1/3pid/getValidated3pid?sid=%s&client_secret=%s", creds.IDServer, creds.SID, creds.Secret) requestURL := fmt.Sprintf("%s/_matrix/identity/api/v1/3pid/getValidated3pid?sid=%s&client_secret=%s", cfg.ThreePidDelegate, creds.SID, creds.Secret)
req, err := http.NewRequest(http.MethodGet, requestURL, nil) req, err := http.NewRequest(http.MethodGet, requestURL, nil)
if err != nil { if err != nil {
return false, "", "", err return false, "", "", err

View file

@ -0,0 +1,8 @@
FROM alpine:latest
COPY dendrite-monolith-server /usr/bin/
VOLUME /etc/dendrite
WORKDIR /etc/dendrite
ENTRYPOINT ["/usr/bin/dendrite-monolith-server"]

View file

@ -0,0 +1,12 @@
set -xe
if [ -z "$(git status --porcelain)" ]; then
CGO_ENABLED=0 go build .
TAG=$(git rev-parse --short HEAD)
docker build -f Dockerfile.dev -t gcr.io/globekeeper-development/dendrite-monolith:$TAG -t gcr.io/globekeeper-development/dendrite-monolith -t gcr.io/globekeeper-production/dendrite-monolith:$TAG .
docker push gcr.io/globekeeper-development/dendrite-monolith:$TAG
docker push gcr.io/globekeeper-production/dendrite-monolith:$TAG
docker push gcr.io/globekeeper-development/dendrite-monolith
else
echo "Please commit changes"
exit 0
fi

2
go.mod
View file

@ -21,6 +21,7 @@ require (
github.com/frankban/quicktest v1.14.3 // indirect github.com/frankban/quicktest v1.14.3 // indirect
github.com/getsentry/sentry-go v0.13.0 github.com/getsentry/sentry-go v0.13.0
github.com/gogo/protobuf v1.3.2 // indirect github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang-jwt/jwt/v4 v4.4.1
github.com/gologme/log v1.3.0 github.com/gologme/log v1.3.0
github.com/google/go-cmp v0.5.8 github.com/google/go-cmp v0.5.8
github.com/google/uuid v1.3.0 github.com/google/uuid v1.3.0
@ -37,6 +38,7 @@ require (
github.com/matrix-org/gomatrixserverlib v0.0.0-20220607143425-e55d796fd0b3 github.com/matrix-org/gomatrixserverlib v0.0.0-20220607143425-e55d796fd0b3
github.com/matrix-org/pinecone v0.0.0-20220408153826-2999ea29ed48 github.com/matrix-org/pinecone v0.0.0-20220408153826-2999ea29ed48
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
github.com/matryer/is v1.4.0
github.com/mattn/go-sqlite3 v1.14.13 github.com/mattn/go-sqlite3 v1.14.13
github.com/matttproud/golang_protobuf_extensions v1.0.2-0.20181231171920-c182affec369 // indirect github.com/matttproud/golang_protobuf_extensions v1.0.2-0.20181231171920-c182affec369 // indirect
github.com/miekg/dns v1.1.49 // indirect github.com/miekg/dns v1.1.49 // indirect

5
go.sum
View file

@ -202,7 +202,10 @@ github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/E
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
github.com/golang-jwt/jwt/v4 v4.4.1 h1:pC5DB52sCeK48Wlb9oPcdhnjkz1TKt1D/P7WKJ0kUcQ=
github.com/golang-jwt/jwt/v4 v4.4.1/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k= github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
@ -425,6 +428,8 @@ github.com/matrix-org/pinecone v0.0.0-20220408153826-2999ea29ed48/go.mod h1:ulJz
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U=
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk=
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U=
github.com/matryer/is v1.4.0 h1:sosSmIWwkYITGrxZ25ULNDeKiMNzFSr4V/eqBQP0PeE=
github.com/matryer/is v1.4.0/go.mod h1:8I/i5uYgLzgsgEloJE1U6xx5HkBQpAZvepWuujKwMRU=
github.com/mattn/go-colorable v0.0.6/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= github.com/mattn/go-colorable v0.0.6/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU=
github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE=
github.com/mattn/go-colorable v0.1.8/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= github.com/mattn/go-colorable v0.1.8/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc=

View file

@ -83,6 +83,57 @@ func MakeAuthAPI(
return MakeExternalAPI(metricsName, h) return MakeExternalAPI(metricsName, h)
} }
// MakeConditionalAuthAPI turns a util.JSONRequestHandler function into an http.Handler which authenticates the request.
// It passes nil device if header is not provided.
func MakeConditionalAuthAPI(
metricsName string, userAPI userapi.QueryAcccessTokenAPI,
f func(*http.Request, *userapi.Device) util.JSONResponse,
) http.Handler {
h := func(req *http.Request) util.JSONResponse {
var (
jsonRes util.JSONResponse
dev *userapi.Device
)
if _, err := auth.ExtractAccessToken(req); err != nil {
dev = nil
} else {
logger := util.GetLogger(req.Context())
var err *util.JSONResponse
dev, err = auth.VerifyUserFromRequest(req, userAPI)
if err != nil {
logger.Debugf("VerifyUserFromRequest %s -> HTTP %d", req.RemoteAddr, err.Code)
return *err
}
// add the user ID to the logger
logger = logger.WithField("user_id", dev.UserID)
req = req.WithContext(util.ContextWithLogger(req.Context(), logger))
}
// add the user to Sentry, if enabled
hub := sentry.GetHubFromContext(req.Context())
if hub != nil {
hub.Scope().SetTag("user_id", dev.UserID)
hub.Scope().SetTag("device_id", dev.ID)
}
defer func() {
if r := recover(); r != nil {
if hub != nil {
hub.CaptureException(fmt.Errorf("%s panicked", req.URL.Path))
}
// re-panic to return the 500
panic(r)
}
}()
jsonRes = f(req, dev)
// do not log 4xx as errors as they are client fails, not server fails
if hub != nil && jsonRes.Code >= 500 {
hub.Scope().SetExtra("response", jsonRes)
hub.CaptureException(fmt.Errorf("%s returned HTTP %d", req.URL.Path, jsonRes.Code))
}
return jsonRes
}
return MakeExternalAPI(metricsName, h)
}
// MakeExternalAPI turns a util.JSONRequestHandler function into an http.Handler. // MakeExternalAPI turns a util.JSONRequestHandler function into an http.Handler.
// This is used for APIs that are called from the internet. // This is used for APIs that are called from the internet.
func MakeExternalAPI(metricsName string, f func(*http.Request) util.JSONResponse) http.Handler { func MakeExternalAPI(metricsName string, f func(*http.Request) util.JSONResponse) http.Handler {

View file

@ -16,6 +16,7 @@ package config
import ( import (
"bytes" "bytes"
"crypto/x509"
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"io" "io"
@ -252,6 +253,15 @@ func loadConfig(
c.Global.OldVerifyKeys[i].KeyID, c.Global.OldVerifyKeys[i].PrivateKey = keyID, privateKey c.Global.OldVerifyKeys[i].KeyID, c.Global.OldVerifyKeys[i].PrivateKey = keyID, privateKey
} }
if c.ClientAPI.JwtConfig.Enabled {
pubPki, _ := pem.Decode([]byte(c.ClientAPI.JwtConfig.Secret))
var pub interface{}
pub, err = x509.ParsePKIXPublicKey(pubPki.Bytes)
if err != nil {
return nil, err
}
c.ClientAPI.JwtConfig.SecretKey = pub.(ed25519.PublicKey)
}
c.MediaAPI.AbsBasePath = Path(absPath(basePath, c.MediaAPI.BasePath)) c.MediaAPI.AbsBasePath = Path(absPath(basePath, c.MediaAPI.BasePath))
@ -283,7 +293,10 @@ func (config *Dendrite) Derive() error {
config.Derived.Registration.Flows = append(config.Derived.Registration.Flows, config.Derived.Registration.Flows = append(config.Derived.Registration.Flows,
authtypes.Flow{Stages: []authtypes.LoginType{authtypes.LoginTypeDummy}}) authtypes.Flow{Stages: []authtypes.LoginType{authtypes.LoginTypeDummy}})
} }
if config.ClientAPI.ThreePidDelegate != "" {
config.Derived.Registration.Flows = append(config.Derived.Registration.Flows,
authtypes.Flow{Stages: []authtypes.LoginType{authtypes.LoginTypeEmail}})
}
// Load application service configuration files // Load application service configuration files
if err := loadAppServices(&config.AppServiceAPI, &config.Derived); err != nil { if err := loadAppServices(&config.AppServiceAPI, &config.Derived); err != nil {
return err return err

View file

@ -3,6 +3,9 @@ package config
import ( import (
"fmt" "fmt"
"time" "time"
"github.com/matrix-org/dendrite/clientapi/ratelimit"
"golang.org/x/crypto/ed25519"
) )
type ClientAPI struct { type ClientAPI struct {
@ -46,9 +49,23 @@ type ClientAPI struct {
TURN TURN `yaml:"turn"` TURN TURN `yaml:"turn"`
// Rate-limiting options // Rate-limiting options
RateLimiting RateLimiting `yaml:"rate_limiting"` RateLimiting RateLimiting `yaml:"rate_limiting"`
RtFailedLogin ratelimit.RtFailedLoginConfig `yaml:"rate_limiting_failed_login"`
MSCs *MSCs `yaml:"mscs"` MSCs *MSCs `yaml:"mscs"`
ThreePidDelegate string `yaml:"three_pid_delegate"`
JwtConfig JwtConfig `yaml:"jwt_config"`
}
type JwtConfig struct {
Enabled bool `yaml:"enabled"`
Algorithm string `yaml:"algorithm"`
Issuer string `yaml:"issuer"`
Secret string `yaml:"secret"`
SecretKey ed25519.PublicKey
Audiences []string `yaml:"audiences"`
} }
func (c *ClientAPI) Defaults(generate bool) { func (c *ClientAPI) Defaults(generate bool) {

View file

@ -517,7 +517,7 @@ type PerformPusherSetRequest struct {
type PerformPusherDeletionRequest struct { type PerformPusherDeletionRequest struct {
Localpart string Localpart string
SessionID int64 SessionID int64 // Pusher corresponding to this SessionID will not be deleted
} }
// Pusher represents a push notification subscriber // Pusher represents a push notification subscriber