mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-06 14:33:10 -06:00
Entry improvements (#11)
* Refactor ApplicationServiceWorkerState to be more robust * Add launch.json to VS Code * Implement login with JWT, registering with email, failed login rate limiting and reset password with m.login.email.identity auth type * Log errors when JWT parsing failed * Development build script * Fix linter errors * Use golangci-lint as a linter in VS Code * Fix tests with RtFailedLogin * Pass config load tests - parse JWT public key only if enabled * Reduce CI steps Do not support 386 arch and go 1.16, 1.17 * Fix linter errors * Change RtFailedLogin logic - nil pointer can be provided * Respect access token in query * Fix typos * Use only one mutex in RtFailedLogin * Remove eventsRemaining across appservice component * Push dendrite to production registry as well * Rafactor TestRtFailedLogin
This commit is contained in:
parent
83797573be
commit
374b77a3df
89
.github/workflows/dendrite.yml
vendored
89
.github/workflows/dendrite.yml
vendored
|
|
@ -13,50 +13,6 @@ concurrency:
|
|||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
wasm:
|
||||
name: WASM build test
|
||||
timeout-minutes: 5
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v2
|
||||
with:
|
||||
go-version: 1.16
|
||||
|
||||
- uses: actions/cache@v2
|
||||
with:
|
||||
path: |
|
||||
~/.cache/go-build
|
||||
~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go-wasm-${{ hashFiles('**/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-go-wasm
|
||||
|
||||
- name: Install Node
|
||||
uses: actions/setup-node@v2
|
||||
with:
|
||||
node-version: 14
|
||||
|
||||
- uses: actions/cache@v2
|
||||
with:
|
||||
path: ~/.npm
|
||||
key: ${{ runner.os }}-node-${{ hashFiles('**/package-lock.json') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-node-
|
||||
|
||||
- name: Reconfigure Git to use HTTPS auth for repo packages
|
||||
run: >
|
||||
git config --global url."https://github.com/".insteadOf
|
||||
ssh://git@github.com/
|
||||
|
||||
- name: Install test dependencies
|
||||
working-directory: ./test/wasm
|
||||
run: npm ci
|
||||
|
||||
- name: Test
|
||||
run: ./test-dendritejs.sh
|
||||
|
||||
# Run golangci-lint
|
||||
lint:
|
||||
|
|
@ -68,7 +24,7 @@ jobs:
|
|||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v2
|
||||
|
||||
# run go test with different go versions
|
||||
# run go test with go 1.18
|
||||
test:
|
||||
timeout-minutes: 5
|
||||
name: Unit tests (Go ${{ matrix.go }})
|
||||
|
|
@ -96,7 +52,7 @@ jobs:
|
|||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
go: ["1.16", "1.17", "1.18"]
|
||||
go: ["1.18"]
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Setup go
|
||||
|
|
@ -118,7 +74,7 @@ jobs:
|
|||
POSTGRES_PASSWORD: postgres
|
||||
POSTGRES_DB: dendrite
|
||||
|
||||
# build Dendrite for linux with different architectures and go versions
|
||||
# build Dendrite for linux amd64 with go 1.18
|
||||
build:
|
||||
name: Build for Linux
|
||||
timeout-minutes: 10
|
||||
|
|
@ -126,9 +82,9 @@ jobs:
|
|||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
go: ["1.16", "1.17", "1.18"]
|
||||
go: ["1.18"]
|
||||
goos: ["linux"]
|
||||
goarch: ["amd64", "386"]
|
||||
goarch: ["amd64"]
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Setup go
|
||||
|
|
@ -152,43 +108,10 @@ jobs:
|
|||
CGO_ENABLED: 1
|
||||
run: go build -trimpath -v -o "bin/" ./cmd/...
|
||||
|
||||
# build for Windows 64-bit
|
||||
build_windows:
|
||||
name: Build for Windows
|
||||
timeout-minutes: 10
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
go: ["1.16", "1.17", "1.18"]
|
||||
goos: ["windows"]
|
||||
goarch: ["amd64"]
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Setup Go ${{ matrix.go }}
|
||||
uses: actions/setup-go@v2
|
||||
with:
|
||||
go-version: ${{ matrix.go }}
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y gcc-mingw-w64-x86-64 # install required gcc
|
||||
- uses: actions/cache@v3
|
||||
with:
|
||||
path: |
|
||||
~/.cache/go-build
|
||||
~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go${{ matrix.go }}-${{ matrix.goos }}-${{ hashFiles('**/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-go${{ matrix.go }}-${{ matrix.goos }}
|
||||
- env:
|
||||
GOOS: ${{ matrix.goos }}
|
||||
GOARCH: ${{ matrix.goarch }}
|
||||
CGO_ENABLED: 1
|
||||
CC: "/usr/bin/x86_64-w64-mingw32-gcc"
|
||||
run: go build -trimpath -v -o "bin/" ./cmd/...
|
||||
|
||||
# Dummy step to gate other tests on without repeating the whole list
|
||||
initial-tests-done:
|
||||
name: Initial tests passed
|
||||
needs: [lint, test, build, build_windows]
|
||||
needs: [lint, test, build]
|
||||
runs-on: ubuntu-latest
|
||||
if: ${{ !cancelled() }} # Run this even if prior jobs were skipped
|
||||
steps:
|
||||
|
|
|
|||
5
.gitignore
vendored
5
.gitignore
vendored
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
# Hidden files
|
||||
.*
|
||||
!.vscode
|
||||
|
||||
# Allow GitHub config
|
||||
!.github
|
||||
|
|
@ -73,3 +74,7 @@ complement/
|
|||
docs/_site
|
||||
|
||||
media_store/
|
||||
|
||||
__debug_bin
|
||||
|
||||
cmd/dendrite-monolith-server/dendrite-monolith-server
|
||||
16
.vscode/launch.json
vendored
Normal file
16
.vscode/launch.json
vendored
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
{
|
||||
"configurations": [
|
||||
{
|
||||
"name": "Launch Package",
|
||||
"type": "go",
|
||||
"request": "launch",
|
||||
"mode": "auto",
|
||||
"program": "${workspaceFolder}/cmd/dendrite-monolith-server",
|
||||
"args": [
|
||||
"-really-enable-open-registration",
|
||||
"-config",
|
||||
"../../../adminas/.ci/config/dendrite-local/dendrite.yaml"
|
||||
],
|
||||
}
|
||||
]
|
||||
}
|
||||
3
.vscode/settings.json
vendored
Normal file
3
.vscode/settings.json
vendored
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
{
|
||||
"go.lintTool": "golangci-lint"
|
||||
}
|
||||
|
|
@ -70,14 +70,14 @@ func NewInternalAPI(
|
|||
// Wrap application services in a type that relates the application service and
|
||||
// a sync.Cond object that can be used to notify workers when there are new
|
||||
// events to be sent out.
|
||||
workerStates := make([]types.ApplicationServiceWorkerState, len(base.Cfg.Derived.ApplicationServices))
|
||||
workerStates := make([]*types.ApplicationServiceWorkerState, len(base.Cfg.Derived.ApplicationServices))
|
||||
for i, appservice := range base.Cfg.Derived.ApplicationServices {
|
||||
m := sync.Mutex{}
|
||||
ws := types.ApplicationServiceWorkerState{
|
||||
AppService: appservice,
|
||||
Cond: sync.NewCond(&m),
|
||||
}
|
||||
workerStates[i] = ws
|
||||
workerStates[i] = &ws
|
||||
|
||||
// Create bot account for this AS if it doesn't already exist
|
||||
if err = generateAppServiceAccount(userAPI, appservice); err != nil {
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ type OutputRoomEventConsumer struct {
|
|||
asDB storage.Database
|
||||
rsAPI api.AppserviceRoomserverAPI
|
||||
serverName string
|
||||
workerStates []types.ApplicationServiceWorkerState
|
||||
workerStates []*types.ApplicationServiceWorkerState
|
||||
}
|
||||
|
||||
// NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call
|
||||
|
|
@ -50,7 +50,7 @@ func NewOutputRoomEventConsumer(
|
|||
js nats.JetStreamContext,
|
||||
appserviceDB storage.Database,
|
||||
rsAPI api.AppserviceRoomserverAPI,
|
||||
workerStates []types.ApplicationServiceWorkerState,
|
||||
workerStates []*types.ApplicationServiceWorkerState,
|
||||
) *OutputRoomEventConsumer {
|
||||
return &OutputRoomEventConsumer{
|
||||
ctx: process.Context(),
|
||||
|
|
@ -140,13 +140,13 @@ func (s *OutputRoomEventConsumer) filterRoomserverEvents(
|
|||
// Check if this event is interesting to this application service
|
||||
if s.appserviceIsInterestedInEvent(ctx, event, ws.AppService) {
|
||||
// Queue this event to be sent off to the application service
|
||||
if err := s.asDB.StoreEvent(ctx, ws.AppService.ID, event); err != nil {
|
||||
log.WithError(err).Warn("failed to insert incoming event into appservices database")
|
||||
if id, err := s.asDB.StoreEvent(ctx, ws.AppService.ID, event); err != nil {
|
||||
log.WithError(err).Warnf("failed to insert incoming event into appservices database. id: %d", id)
|
||||
return err
|
||||
} else {
|
||||
// Tell our worker to send out new messages by updating remaining message
|
||||
// count and waking them up with a broadcast
|
||||
ws.NotifyNewEvents()
|
||||
ws.NotifyNewEvents(id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -21,9 +21,9 @@ import (
|
|||
)
|
||||
|
||||
type Database interface {
|
||||
StoreEvent(ctx context.Context, appServiceID string, event *gomatrixserverlib.HeaderedEvent) error
|
||||
GetEventsWithAppServiceID(ctx context.Context, appServiceID string, limit int) (int, int, []gomatrixserverlib.HeaderedEvent, bool, error)
|
||||
CountEventsWithAppServiceID(ctx context.Context, appServiceID string) (int, error)
|
||||
StoreEvent(ctx context.Context, appServiceID string, event *gomatrixserverlib.HeaderedEvent) (int, error)
|
||||
GetEventsWithAppServiceID(ctx context.Context, appServiceID string, limit int) (int, int, []gomatrixserverlib.HeaderedEvent, error)
|
||||
GetLatestId(ctx context.Context, appServiceID string) (int, error)
|
||||
UpdateTxnIDForEvents(ctx context.Context, appserviceID string, maxID, txnID int) error
|
||||
RemoveEventsBeforeAndIncludingID(ctx context.Context, appserviceID string, eventTableID int) error
|
||||
GetLatestTxnID(ctx context.Context) (int, error)
|
||||
|
|
|
|||
|
|
@ -45,12 +45,13 @@ const selectEventsByApplicationServiceIDSQL = "" +
|
|||
"SELECT id, headered_event_json, txn_id " +
|
||||
"FROM appservice_events WHERE as_id = $1 ORDER BY txn_id DESC, id ASC"
|
||||
|
||||
const countEventsByApplicationServiceIDSQL = "" +
|
||||
"SELECT COUNT(id) FROM appservice_events WHERE as_id = $1"
|
||||
const getLatestIdSQL = "" +
|
||||
"SELECT id FROM appservice_events WHERE as_id = $1 ORDER BY id DESC LIMIT 1"
|
||||
|
||||
const insertEventSQL = "" +
|
||||
"INSERT INTO appservice_events(as_id, headered_event_json, txn_id) " +
|
||||
"VALUES ($1, $2, $3)"
|
||||
"VALUES ($1, $2, $3)" +
|
||||
"RETURNING id"
|
||||
|
||||
const updateTxnIDForEventsSQL = "" +
|
||||
"UPDATE appservice_events SET txn_id = $1 WHERE as_id = $2 AND id <= $3"
|
||||
|
|
@ -66,7 +67,7 @@ const (
|
|||
|
||||
type eventsStatements struct {
|
||||
selectEventsByApplicationServiceIDStmt *sql.Stmt
|
||||
countEventsByApplicationServiceIDStmt *sql.Stmt
|
||||
getLatestIdStmt *sql.Stmt
|
||||
insertEventStmt *sql.Stmt
|
||||
updateTxnIDForEventsStmt *sql.Stmt
|
||||
deleteEventsBeforeAndIncludingIDStmt *sql.Stmt
|
||||
|
|
@ -81,7 +82,7 @@ func (s *eventsStatements) prepare(db *sql.DB) (err error) {
|
|||
if s.selectEventsByApplicationServiceIDStmt, err = db.Prepare(selectEventsByApplicationServiceIDSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.countEventsByApplicationServiceIDStmt, err = db.Prepare(countEventsByApplicationServiceIDSQL); err != nil {
|
||||
if s.getLatestIdStmt, err = db.Prepare(getLatestIdSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.insertEventStmt, err = db.Prepare(insertEventSQL); err != nil {
|
||||
|
|
@ -108,7 +109,6 @@ func (s *eventsStatements) selectEventsByApplicationServiceID(
|
|||
) (
|
||||
txnID, maxID int,
|
||||
events []gomatrixserverlib.HeaderedEvent,
|
||||
eventsRemaining bool,
|
||||
err error,
|
||||
) {
|
||||
defer func() {
|
||||
|
|
@ -124,7 +124,7 @@ func (s *eventsStatements) selectEventsByApplicationServiceID(
|
|||
return
|
||||
}
|
||||
defer checkNamedErr(eventRows.Close, &err)
|
||||
events, maxID, txnID, eventsRemaining, err = retrieveEvents(eventRows, limit)
|
||||
events, maxID, txnID, err = retrieveEvents(eventRows, limit)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
|
@ -139,7 +139,7 @@ func checkNamedErr(fn func() error, err *error) {
|
|||
}
|
||||
}
|
||||
|
||||
func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.HeaderedEvent, maxID, txnID int, eventsRemaining bool, err error) {
|
||||
func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.HeaderedEvent, maxID, txnID int, err error) {
|
||||
// Get current time for use in calculating event age
|
||||
nowMilli := time.Now().UnixNano() / int64(time.Millisecond)
|
||||
|
||||
|
|
@ -157,18 +157,18 @@ func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.
|
|||
&txnID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, 0, 0, false, err
|
||||
return nil, 0, 0, err
|
||||
}
|
||||
|
||||
// Unmarshal eventJSON
|
||||
if err = json.Unmarshal(eventJSON, &event); err != nil {
|
||||
return nil, 0, 0, false, err
|
||||
return nil, 0, 0, err
|
||||
}
|
||||
|
||||
// If txnID has changed on this event from the previous event, then we've
|
||||
// reached the end of a transaction's events. Return only those events.
|
||||
if lastTxnID > invalidTxnID && lastTxnID != txnID {
|
||||
return events, maxID, lastTxnID, true, nil
|
||||
return events, maxID, lastTxnID, nil
|
||||
}
|
||||
lastTxnID = txnID
|
||||
|
||||
|
|
@ -176,7 +176,7 @@ func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.
|
|||
if txnID == -1 {
|
||||
// Return if we've hit the limit
|
||||
if eventsProcessed++; eventsProcessed > limit {
|
||||
return events, maxID, lastTxnID, true, nil
|
||||
return events, maxID, lastTxnID, nil
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -187,7 +187,7 @@ func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.
|
|||
// Portion of the event that is unsigned due to rapid change
|
||||
// TODO: Consider removing age as not many app services use it
|
||||
if err = event.SetUnsignedField("age", nowMilli-int64(event.OriginServerTS())); err != nil {
|
||||
return nil, 0, 0, false, err
|
||||
return nil, 0, 0, err
|
||||
}
|
||||
|
||||
events = append(events, event)
|
||||
|
|
@ -196,14 +196,12 @@ func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.
|
|||
return
|
||||
}
|
||||
|
||||
// countEventsByApplicationServiceID inserts an event mapped to its corresponding application service
|
||||
// IDs into the db.
|
||||
func (s *eventsStatements) countEventsByApplicationServiceID(
|
||||
func (s *eventsStatements) getLatestId(
|
||||
ctx context.Context,
|
||||
appServiceID string,
|
||||
) (int, error) {
|
||||
var count int
|
||||
err := s.countEventsByApplicationServiceIDStmt.QueryRowContext(ctx, appServiceID).Scan(&count)
|
||||
err := s.getLatestIdStmt.QueryRowContext(ctx, appServiceID).Scan(&count)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return 0, err
|
||||
}
|
||||
|
|
@ -217,19 +215,19 @@ func (s *eventsStatements) insertEvent(
|
|||
ctx context.Context,
|
||||
appServiceID string,
|
||||
event *gomatrixserverlib.HeaderedEvent,
|
||||
) (err error) {
|
||||
) (id int, err error) {
|
||||
// Convert event to JSON before inserting
|
||||
eventJSON, err := json.Marshal(event)
|
||||
var eventJSON []byte
|
||||
eventJSON, err = json.Marshal(event)
|
||||
if err != nil {
|
||||
return err
|
||||
return 0, err
|
||||
}
|
||||
|
||||
_, err = s.insertEventStmt.ExecContext(
|
||||
err = s.insertEventStmt.QueryRowContext(
|
||||
ctx,
|
||||
appServiceID,
|
||||
eventJSON,
|
||||
-1, // No transaction ID yet
|
||||
)
|
||||
).Scan(&id)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -62,7 +62,7 @@ func (d *Database) StoreEvent(
|
|||
ctx context.Context,
|
||||
appServiceID string,
|
||||
event *gomatrixserverlib.HeaderedEvent,
|
||||
) error {
|
||||
) (int, error) {
|
||||
return d.events.insertEvent(ctx, appServiceID, event)
|
||||
}
|
||||
|
||||
|
|
@ -72,17 +72,20 @@ func (d *Database) GetEventsWithAppServiceID(
|
|||
ctx context.Context,
|
||||
appServiceID string,
|
||||
limit int,
|
||||
) (int, int, []gomatrixserverlib.HeaderedEvent, bool, error) {
|
||||
) (int, int, []gomatrixserverlib.HeaderedEvent, error) {
|
||||
return d.events.selectEventsByApplicationServiceID(ctx, appServiceID, limit)
|
||||
}
|
||||
|
||||
// CountEventsWithAppServiceID returns the number of events destined for an
|
||||
// application service given its ID.
|
||||
func (d *Database) CountEventsWithAppServiceID(
|
||||
// GetLatestId returns the latest incremental id associated with appservice.
|
||||
func (d *Database) GetLatestId(
|
||||
ctx context.Context,
|
||||
appServiceID string,
|
||||
) (int, error) {
|
||||
return d.events.countEventsByApplicationServiceID(ctx, appServiceID)
|
||||
id, err := d.events.getLatestId(ctx, appServiceID)
|
||||
if err == sql.ErrNoRows {
|
||||
return 0, nil
|
||||
}
|
||||
return id, err
|
||||
}
|
||||
|
||||
// UpdateTxnIDForEvents takes in an application service ID and a
|
||||
|
|
|
|||
|
|
@ -46,12 +46,13 @@ const selectEventsByApplicationServiceIDSQL = "" +
|
|||
"SELECT id, headered_event_json, txn_id " +
|
||||
"FROM appservice_events WHERE as_id = $1 ORDER BY txn_id DESC, id ASC"
|
||||
|
||||
const countEventsByApplicationServiceIDSQL = "" +
|
||||
"SELECT COUNT(id) FROM appservice_events WHERE as_id = $1"
|
||||
const getLatestIdSQL = "" +
|
||||
"SELECT id FROM appservice_events WHERE as_id = $1 ORDER BY id DESC LIMIT 1"
|
||||
|
||||
const insertEventSQL = "" +
|
||||
"INSERT INTO appservice_events(as_id, headered_event_json, txn_id) " +
|
||||
"VALUES ($1, $2, $3)"
|
||||
"VALUES ($1, $2, $3)" +
|
||||
"RETURNING id"
|
||||
|
||||
const updateTxnIDForEventsSQL = "" +
|
||||
"UPDATE appservice_events SET txn_id = $1 WHERE as_id = $2 AND id <= $3"
|
||||
|
|
@ -69,7 +70,7 @@ type eventsStatements struct {
|
|||
db *sql.DB
|
||||
writer sqlutil.Writer
|
||||
selectEventsByApplicationServiceIDStmt *sql.Stmt
|
||||
countEventsByApplicationServiceIDStmt *sql.Stmt
|
||||
getLatestIdStmt *sql.Stmt
|
||||
insertEventStmt *sql.Stmt
|
||||
updateTxnIDForEventsStmt *sql.Stmt
|
||||
deleteEventsBeforeAndIncludingIDStmt *sql.Stmt
|
||||
|
|
@ -86,7 +87,7 @@ func (s *eventsStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error
|
|||
if s.selectEventsByApplicationServiceIDStmt, err = db.Prepare(selectEventsByApplicationServiceIDSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.countEventsByApplicationServiceIDStmt, err = db.Prepare(countEventsByApplicationServiceIDSQL); err != nil {
|
||||
if s.getLatestIdStmt, err = db.Prepare(getLatestIdSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.insertEventStmt, err = db.Prepare(insertEventSQL); err != nil {
|
||||
|
|
@ -113,7 +114,6 @@ func (s *eventsStatements) selectEventsByApplicationServiceID(
|
|||
) (
|
||||
txnID, maxID int,
|
||||
events []gomatrixserverlib.HeaderedEvent,
|
||||
eventsRemaining bool,
|
||||
err error,
|
||||
) {
|
||||
defer func() {
|
||||
|
|
@ -129,7 +129,7 @@ func (s *eventsStatements) selectEventsByApplicationServiceID(
|
|||
return
|
||||
}
|
||||
defer checkNamedErr(eventRows.Close, &err)
|
||||
events, maxID, txnID, eventsRemaining, err = retrieveEvents(eventRows, limit)
|
||||
events, maxID, txnID, err = retrieveEvents(eventRows, limit)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
|
@ -144,7 +144,7 @@ func checkNamedErr(fn func() error, err *error) {
|
|||
}
|
||||
}
|
||||
|
||||
func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.HeaderedEvent, maxID, txnID int, eventsRemaining bool, err error) {
|
||||
func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.HeaderedEvent, maxID, txnID int, err error) {
|
||||
// Get current time for use in calculating event age
|
||||
nowMilli := time.Now().UnixNano() / int64(time.Millisecond)
|
||||
|
||||
|
|
@ -162,18 +162,18 @@ func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.
|
|||
&txnID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, 0, 0, false, err
|
||||
return nil, 0, 0, err
|
||||
}
|
||||
|
||||
// Unmarshal eventJSON
|
||||
if err = json.Unmarshal(eventJSON, &event); err != nil {
|
||||
return nil, 0, 0, false, err
|
||||
return nil, 0, 0, err
|
||||
}
|
||||
|
||||
// If txnID has changed on this event from the previous event, then we've
|
||||
// reached the end of a transaction's events. Return only those events.
|
||||
if lastTxnID > invalidTxnID && lastTxnID != txnID {
|
||||
return events, maxID, lastTxnID, true, nil
|
||||
return events, maxID, lastTxnID, nil
|
||||
}
|
||||
lastTxnID = txnID
|
||||
|
||||
|
|
@ -181,7 +181,7 @@ func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.
|
|||
if txnID == -1 {
|
||||
// Return if we've hit the limit
|
||||
if eventsProcessed++; eventsProcessed > limit {
|
||||
return events, maxID, lastTxnID, true, nil
|
||||
return events, maxID, lastTxnID, nil
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -192,7 +192,7 @@ func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.
|
|||
// Portion of the event that is unsigned due to rapid change
|
||||
// TODO: Consider removing age as not many app services use it
|
||||
if err = event.SetUnsignedField("age", nowMilli-int64(event.OriginServerTS())); err != nil {
|
||||
return nil, 0, 0, false, err
|
||||
return nil, 0, 0, err
|
||||
}
|
||||
|
||||
events = append(events, event)
|
||||
|
|
@ -201,14 +201,12 @@ func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.
|
|||
return
|
||||
}
|
||||
|
||||
// countEventsByApplicationServiceID inserts an event mapped to its corresponding application service
|
||||
// IDs into the db.
|
||||
func (s *eventsStatements) countEventsByApplicationServiceID(
|
||||
func (s *eventsStatements) getLatestId(
|
||||
ctx context.Context,
|
||||
appServiceID string,
|
||||
) (int, error) {
|
||||
var count int
|
||||
err := s.countEventsByApplicationServiceIDStmt.QueryRowContext(ctx, appServiceID).Scan(&count)
|
||||
err := s.getLatestIdStmt.QueryRowContext(ctx, appServiceID).Scan(&count)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return 0, err
|
||||
}
|
||||
|
|
@ -222,22 +220,22 @@ func (s *eventsStatements) insertEvent(
|
|||
ctx context.Context,
|
||||
appServiceID string,
|
||||
event *gomatrixserverlib.HeaderedEvent,
|
||||
) (err error) {
|
||||
) (id int, err error) {
|
||||
// Convert event to JSON before inserting
|
||||
eventJSON, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
return err
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
|
||||
_, err := s.insertEventStmt.ExecContext(
|
||||
err = s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
|
||||
err = s.insertEventStmt.QueryRowContext(
|
||||
ctx,
|
||||
appServiceID,
|
||||
eventJSON,
|
||||
-1, // No transaction ID yet
|
||||
)
|
||||
).Scan(&id)
|
||||
return err
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// updateTxnIDForEvents sets the transactionID for a collection of events. Done
|
||||
|
|
|
|||
|
|
@ -61,7 +61,7 @@ func (d *Database) StoreEvent(
|
|||
ctx context.Context,
|
||||
appServiceID string,
|
||||
event *gomatrixserverlib.HeaderedEvent,
|
||||
) error {
|
||||
) (int, error) {
|
||||
return d.events.insertEvent(ctx, appServiceID, event)
|
||||
}
|
||||
|
||||
|
|
@ -71,17 +71,20 @@ func (d *Database) GetEventsWithAppServiceID(
|
|||
ctx context.Context,
|
||||
appServiceID string,
|
||||
limit int,
|
||||
) (int, int, []gomatrixserverlib.HeaderedEvent, bool, error) {
|
||||
) (int, int, []gomatrixserverlib.HeaderedEvent, error) {
|
||||
return d.events.selectEventsByApplicationServiceID(ctx, appServiceID, limit)
|
||||
}
|
||||
|
||||
// CountEventsWithAppServiceID returns the number of events destined for an
|
||||
// application service given its ID.
|
||||
func (d *Database) CountEventsWithAppServiceID(
|
||||
// GetLatestId returns the latest incremental id associated with appservice.
|
||||
func (d *Database) GetLatestId(
|
||||
ctx context.Context,
|
||||
appServiceID string,
|
||||
) (int, error) {
|
||||
return d.events.countEventsByApplicationServiceID(ctx, appServiceID)
|
||||
id, err := d.events.getLatestId(ctx, appServiceID)
|
||||
if err == sql.ErrNoRows {
|
||||
return 0, nil
|
||||
}
|
||||
return id, err
|
||||
}
|
||||
|
||||
// UpdateTxnIDForEvents takes in an application service ID and a
|
||||
|
|
|
|||
|
|
@ -30,34 +30,26 @@ const (
|
|||
type ApplicationServiceWorkerState struct {
|
||||
AppService config.ApplicationService
|
||||
Cond *sync.Cond
|
||||
// Events ready to be sent
|
||||
EventsReady bool
|
||||
// Lastest incremental ID from appservice_events table that is ready to be sent to application service
|
||||
latestId int
|
||||
// Backoff exponent (2^x secs). Max 6, aka 64s.
|
||||
Backoff int
|
||||
}
|
||||
|
||||
// NotifyNewEvents wakes up all waiting goroutines, notifying that events remain
|
||||
// in the event queue for this application service worker.
|
||||
func (a *ApplicationServiceWorkerState) NotifyNewEvents() {
|
||||
func (a *ApplicationServiceWorkerState) NotifyNewEvents(id int) {
|
||||
a.Cond.L.Lock()
|
||||
a.EventsReady = true
|
||||
a.latestId = id
|
||||
a.Cond.Broadcast()
|
||||
a.Cond.L.Unlock()
|
||||
}
|
||||
|
||||
// FinishEventProcessing marks all events of this worker as being sent to the
|
||||
// application service.
|
||||
func (a *ApplicationServiceWorkerState) FinishEventProcessing() {
|
||||
a.Cond.L.Lock()
|
||||
a.EventsReady = false
|
||||
a.Cond.L.Unlock()
|
||||
}
|
||||
|
||||
// WaitForNewEvents causes the calling goroutine to wait on the worker state's
|
||||
// condition for a broadcast or similar wakeup, if there are no events ready.
|
||||
func (a *ApplicationServiceWorkerState) WaitForNewEvents() {
|
||||
func (a *ApplicationServiceWorkerState) WaitForNewEvents(id int) {
|
||||
a.Cond.L.Lock()
|
||||
if !a.EventsReady {
|
||||
if a.latestId <= id {
|
||||
a.Cond.Wait()
|
||||
}
|
||||
a.Cond.L.Unlock()
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ var (
|
|||
func SetupTransactionWorkers(
|
||||
client *http.Client,
|
||||
appserviceDB storage.Database,
|
||||
workerStates []types.ApplicationServiceWorkerState,
|
||||
workerStates []*types.ApplicationServiceWorkerState,
|
||||
) error {
|
||||
// Create a worker that handles transmitting events to a single homeserver
|
||||
for _, workerState := range workerStates {
|
||||
|
|
@ -58,31 +58,29 @@ func SetupTransactionWorkers(
|
|||
|
||||
// worker is a goroutine that sends any queued events to the application service
|
||||
// it is given.
|
||||
func worker(client *http.Client, db storage.Database, ws types.ApplicationServiceWorkerState) {
|
||||
func worker(client *http.Client, db storage.Database, ws *types.ApplicationServiceWorkerState) {
|
||||
log.WithFields(log.Fields{
|
||||
"appservice": ws.AppService.ID,
|
||||
}).Info("Starting application service")
|
||||
ctx := context.Background()
|
||||
|
||||
// Initial check for any leftover events to send from last time
|
||||
eventCount, err := db.CountEventsWithAppServiceID(ctx, ws.AppService.ID)
|
||||
latestId, err := db.GetLatestId(ctx, ws.AppService.ID)
|
||||
if err != nil {
|
||||
log.WithFields(log.Fields{
|
||||
"appservice": ws.AppService.ID,
|
||||
}).WithError(err).Fatal("appservice worker unable to read queued events from DB")
|
||||
return
|
||||
}
|
||||
if eventCount > 0 {
|
||||
ws.NotifyNewEvents()
|
||||
}
|
||||
|
||||
ws.NotifyNewEvents(latestId)
|
||||
id := 0
|
||||
// Loop forever and keep waiting for more events to send
|
||||
for {
|
||||
// Wait for more events if we've sent all the events in the database
|
||||
ws.WaitForNewEvents()
|
||||
ws.WaitForNewEvents(id)
|
||||
|
||||
// Batch events up into a transaction
|
||||
transactionJSON, txnID, maxEventID, eventsRemaining, err := createTransaction(ctx, db, ws.AppService.ID)
|
||||
transactionJSON, txnID, maxEventID, err := createTransaction(ctx, db, ws.AppService.ID)
|
||||
if err != nil {
|
||||
log.WithFields(log.Fields{
|
||||
"appservice": ws.AppService.ID,
|
||||
|
|
@ -90,6 +88,10 @@ func worker(client *http.Client, db storage.Database, ws types.ApplicationServic
|
|||
|
||||
return
|
||||
}
|
||||
// Transactions have a maximum event size (or new events may arrive while
|
||||
// transaction is processed by Application Service), so there may still be
|
||||
// some events left over to send. We will keep sending if id < ws.latestID.
|
||||
id = maxEventID
|
||||
|
||||
// Send the events off to the application service
|
||||
// Backoff if the application service does not respond
|
||||
|
|
@ -99,19 +101,13 @@ func worker(client *http.Client, db storage.Database, ws types.ApplicationServic
|
|||
"appservice": ws.AppService.ID,
|
||||
}).WithError(err).Error("unable to send event")
|
||||
// Backoff
|
||||
backoff(&ws, err)
|
||||
backoff(ws, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// We sent successfully, hooray!
|
||||
ws.Backoff = 0
|
||||
|
||||
// Transactions have a maximum event size, so there may still be some events
|
||||
// left over to send. Keep sending until none are left
|
||||
if !eventsRemaining {
|
||||
ws.FinishEventProcessing()
|
||||
}
|
||||
|
||||
// Remove sent events from the DB
|
||||
err = db.RemoveEventsBeforeAndIncludingID(ctx, ws.AppService.ID, maxEventID)
|
||||
if err != nil {
|
||||
|
|
@ -152,11 +148,10 @@ func createTransaction(
|
|||
) (
|
||||
transactionJSON []byte,
|
||||
txnID, maxID int,
|
||||
eventsRemaining bool,
|
||||
err error,
|
||||
) {
|
||||
// Retrieve the latest events from the DB (will return old events if they weren't successfully sent)
|
||||
txnID, maxID, events, eventsRemaining, err := db.GetEventsWithAppServiceID(ctx, appserviceID, transactionBatchSize)
|
||||
txnID, maxID, events, err := db.GetEventsWithAppServiceID(ctx, appserviceID, transactionBatchSize)
|
||||
if err != nil {
|
||||
log.WithFields(log.Fields{
|
||||
"appservice": appserviceID,
|
||||
|
|
@ -170,12 +165,12 @@ func createTransaction(
|
|||
// If not, grab next available ID from the DB
|
||||
txnID, err = db.GetLatestTxnID(ctx)
|
||||
if err != nil {
|
||||
return nil, 0, 0, false, err
|
||||
return nil, 0, 0, err
|
||||
}
|
||||
|
||||
// Mark new events with current transactionID
|
||||
if err = db.UpdateTxnIDForEvents(ctx, appserviceID, maxID, txnID); err != nil {
|
||||
return nil, 0, 0, false, err
|
||||
return nil, 0, 0, err
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -11,4 +11,6 @@ const (
|
|||
LoginTypeRecaptcha = "m.login.recaptcha"
|
||||
LoginTypeApplicationService = "m.login.application_service"
|
||||
LoginTypeToken = "m.login.token"
|
||||
LoginTypeJwt = "org.matrix.login.jwt"
|
||||
LoginTypeEmail = "m.login.email.identity"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ import (
|
|||
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||
"github.com/matrix-org/dendrite/clientapi/ratelimit"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
uapi "github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/util"
|
||||
|
|
@ -33,7 +34,7 @@ import (
|
|||
// called after authorization has completed, with the result of the authorization.
|
||||
// If the final return value is non-nil, an error occurred and the cleanup function
|
||||
// is nil.
|
||||
func LoginFromJSONReader(ctx context.Context, r io.Reader, useraccountAPI uapi.UserLoginAPI, userAPI UserInternalAPIForLogin, cfg *config.ClientAPI) (*Login, LoginCleanupFunc, *util.JSONResponse) {
|
||||
func LoginFromJSONReader(ctx context.Context, r io.Reader, useraccountAPI uapi.ClientUserAPI, cfg *config.ClientAPI, rt *ratelimit.RtFailedLogin) (*Login, LoginCleanupFunc, *util.JSONResponse) {
|
||||
reqBytes, err := ioutil.ReadAll(r)
|
||||
if err != nil {
|
||||
err := &util.JSONResponse{
|
||||
|
|
@ -58,12 +59,17 @@ func LoginFromJSONReader(ctx context.Context, r io.Reader, useraccountAPI uapi.U
|
|||
switch header.Type {
|
||||
case authtypes.LoginTypePassword:
|
||||
typ = &LoginTypePassword{
|
||||
GetAccountByPassword: useraccountAPI.QueryAccountByPassword,
|
||||
UserApi: useraccountAPI,
|
||||
Config: cfg,
|
||||
Rt: rt,
|
||||
}
|
||||
case authtypes.LoginTypeToken:
|
||||
typ = &LoginTypeToken{
|
||||
UserAPI: userAPI,
|
||||
UserAPI: useraccountAPI,
|
||||
Config: cfg,
|
||||
}
|
||||
case authtypes.LoginTypeJwt:
|
||||
typ = &LoginTypeTokenJwt{
|
||||
Config: cfg,
|
||||
}
|
||||
default:
|
||||
|
|
|
|||
74
clientapi/auth/login_jwt.go
Normal file
74
clientapi/auth/login_jwt.go
Normal file
|
|
@ -0,0 +1,74 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
"github.com/matrix-org/dendrite/clientapi/httputil"
|
||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/util"
|
||||
)
|
||||
|
||||
// LoginTypeToken describes how to authenticate with a login token.
|
||||
type LoginTypeTokenJwt struct {
|
||||
// UserAPI uapi.LoginTokenInternalAPI
|
||||
Config *config.ClientAPI
|
||||
}
|
||||
|
||||
// Name implements Type.
|
||||
func (t *LoginTypeTokenJwt) Name() string {
|
||||
return authtypes.LoginTypeJwt
|
||||
}
|
||||
|
||||
type Claims struct {
|
||||
jwt.StandardClaims
|
||||
}
|
||||
|
||||
const mIdUser = "m.id.user"
|
||||
|
||||
// LoginFromJSON implements Type. The cleanup function deletes the token from
|
||||
// the database on success.
|
||||
func (t *LoginTypeTokenJwt) LoginFromJSON(ctx context.Context, reqBytes []byte) (*Login, LoginCleanupFunc, *util.JSONResponse) {
|
||||
var r loginTokenRequest
|
||||
if err := httputil.UnmarshalJSON(reqBytes, &r); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if r.Token == "" {
|
||||
return nil, nil, &util.JSONResponse{
|
||||
Code: http.StatusForbidden,
|
||||
JSON: jsonerror.Forbidden("Token field for JWT is missing"),
|
||||
}
|
||||
}
|
||||
c := &Claims{}
|
||||
token, err := jwt.ParseWithClaims(r.Token, c, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodEd25519); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Method.Alg())
|
||||
}
|
||||
return t.Config.JwtConfig.SecretKey, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("jwt.ParseWithClaims failed")
|
||||
return nil, nil, &util.JSONResponse{
|
||||
Code: http.StatusForbidden,
|
||||
JSON: jsonerror.Forbidden("Couldn't parse JWT"),
|
||||
}
|
||||
}
|
||||
|
||||
if !token.Valid {
|
||||
return nil, nil, &util.JSONResponse{
|
||||
Code: http.StatusForbidden,
|
||||
JSON: jsonerror.Forbidden("Invalid JWT"),
|
||||
}
|
||||
}
|
||||
|
||||
r.Login.Identifier.User = c.Subject
|
||||
r.Login.Identifier.Type = mIdUser
|
||||
|
||||
return &r.Login, func(context.Context, *util.JSONResponse) {}, nil
|
||||
}
|
||||
|
|
@ -22,6 +22,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||
"github.com/matrix-org/dendrite/clientapi/ratelimit"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
uapi "github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/util"
|
||||
|
|
@ -68,8 +69,11 @@ func TestLoginFromJSONReader(t *testing.T) {
|
|||
Matrix: &config.Global{
|
||||
ServerName: serverName,
|
||||
},
|
||||
RtFailedLogin: ratelimit.RtFailedLoginConfig{
|
||||
Enabled: false,
|
||||
},
|
||||
}
|
||||
login, cleanup, err := LoginFromJSONReader(ctx, strings.NewReader(tst.Body), &userAPI, &userAPI, cfg)
|
||||
login, cleanup, err := LoginFromJSONReader(ctx, strings.NewReader(tst.Body), &userAPI, cfg, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("LoginFromJSONReader failed: %+v", err)
|
||||
}
|
||||
|
|
@ -147,7 +151,7 @@ func TestBadLoginFromJSONReader(t *testing.T) {
|
|||
ServerName: serverName,
|
||||
},
|
||||
}
|
||||
_, cleanup, errRes := LoginFromJSONReader(ctx, strings.NewReader(tst.Body), &userAPI, &userAPI, cfg)
|
||||
_, cleanup, errRes := LoginFromJSONReader(ctx, strings.NewReader(tst.Body), &userAPI, cfg, nil)
|
||||
if errRes == nil {
|
||||
cleanup(ctx, nil)
|
||||
t.Fatalf("LoginFromJSONReader err: got %+v, want code %q", errRes, tst.WantErrCode)
|
||||
|
|
@ -159,6 +163,7 @@ func TestBadLoginFromJSONReader(t *testing.T) {
|
|||
}
|
||||
|
||||
type fakeUserInternalAPI struct {
|
||||
uapi.ClientUserAPI
|
||||
UserInternalAPIForLogin
|
||||
DeletedTokens []string
|
||||
}
|
||||
|
|
|
|||
|
|
@ -58,7 +58,7 @@ func (t *LoginTypeToken) LoginFromJSON(ctx context.Context, reqBytes []byte) (*L
|
|||
}
|
||||
}
|
||||
|
||||
r.Login.Identifier.Type = "m.id.user"
|
||||
r.Login.Identifier.Type = mIdUser
|
||||
r.Login.Identifier.User = res.Data.UserID
|
||||
|
||||
cleanup := func(ctx context.Context, authRes *util.JSONResponse) {
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
"github.com/matrix-org/dendrite/clientapi/httputil"
|
||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||
"github.com/matrix-org/dendrite/clientapi/ratelimit"
|
||||
"github.com/matrix-org/dendrite/clientapi/userutil"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
|
|
@ -33,12 +34,17 @@ type GetAccountByPassword func(ctx context.Context, req *api.QueryAccountByPassw
|
|||
type PasswordRequest struct {
|
||||
Login
|
||||
Password string `json:"password"`
|
||||
Address string `json:"address"`
|
||||
Medium string `json:"medium"`
|
||||
}
|
||||
|
||||
const email = "email"
|
||||
|
||||
// LoginTypePassword implements https://matrix.org/docs/spec/client_server/r0.6.1#password-based
|
||||
type LoginTypePassword struct {
|
||||
GetAccountByPassword GetAccountByPassword
|
||||
UserApi api.ClientUserAPI
|
||||
Config *config.ClientAPI
|
||||
Rt *ratelimit.RtFailedLogin
|
||||
}
|
||||
|
||||
func (t *LoginTypePassword) Name() string {
|
||||
|
|
@ -61,7 +67,35 @@ func (t *LoginTypePassword) LoginFromJSON(ctx context.Context, reqBytes []byte)
|
|||
|
||||
func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, *util.JSONResponse) {
|
||||
r := req.(*PasswordRequest)
|
||||
username := strings.ToLower(r.Username())
|
||||
if r.Identifier.Address != "" {
|
||||
r.Address = r.Identifier.Address
|
||||
}
|
||||
if r.Identifier.Medium != "" {
|
||||
r.Medium = r.Identifier.Medium
|
||||
}
|
||||
var username string
|
||||
if r.Medium == email && r.Address != "" {
|
||||
r.Address = strings.ToLower(r.Address)
|
||||
res := api.QueryLocalpartForThreePIDResponse{}
|
||||
err := t.UserApi.QueryLocalpartForThreePID(ctx, &api.QueryLocalpartForThreePIDRequest{
|
||||
ThreePID: r.Address,
|
||||
Medium: email,
|
||||
}, &res)
|
||||
if err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("userApi.QueryLocalpartForThreePID failed")
|
||||
resp := jsonerror.InternalServerError()
|
||||
return nil, &resp
|
||||
}
|
||||
username = res.Localpart
|
||||
if username == "" {
|
||||
return nil, &util.JSONResponse{
|
||||
Code: http.StatusUnauthorized,
|
||||
JSON: jsonerror.Forbidden("Invalid username or password"),
|
||||
}
|
||||
}
|
||||
} else {
|
||||
username = strings.ToLower(r.Username())
|
||||
}
|
||||
if username == "" {
|
||||
return nil, &util.JSONResponse{
|
||||
Code: http.StatusUnauthorized,
|
||||
|
|
@ -77,7 +111,17 @@ func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login,
|
|||
}
|
||||
// Squash username to all lowercase letters
|
||||
res := &api.QueryAccountByPasswordResponse{}
|
||||
err = t.GetAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{Localpart: strings.ToLower(localpart), PlaintextPassword: r.Password}, res)
|
||||
localpart = strings.ToLower(localpart)
|
||||
if t.Rt != nil {
|
||||
ok, retryIn := t.Rt.CanAct(localpart)
|
||||
if !ok {
|
||||
return nil, &util.JSONResponse{
|
||||
Code: http.StatusTooManyRequests,
|
||||
JSON: jsonerror.LimitExceeded("Too Many Requests", retryIn.Milliseconds()),
|
||||
}
|
||||
}
|
||||
}
|
||||
err = t.UserApi.QueryAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{Localpart: localpart, PlaintextPassword: r.Password}, res)
|
||||
if err != nil {
|
||||
return nil, &util.JSONResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
|
|
@ -86,7 +130,7 @@ func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login,
|
|||
}
|
||||
|
||||
if !res.Exists {
|
||||
err = t.GetAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{
|
||||
err = t.UserApi.QueryAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{
|
||||
Localpart: localpart,
|
||||
PlaintextPassword: r.Password,
|
||||
}, res)
|
||||
|
|
@ -99,11 +143,15 @@ func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login,
|
|||
// Technically we could tell them if the user does not exist by checking if err == sql.ErrNoRows
|
||||
// but that would leak the existence of the user.
|
||||
if !res.Exists {
|
||||
if t.Rt != nil {
|
||||
t.Rt.Act(localpart)
|
||||
}
|
||||
return nil, &util.JSONResponse{
|
||||
Code: http.StatusForbidden,
|
||||
JSON: jsonerror.Forbidden("The username or password was incorrect or the account does not exist."),
|
||||
JSON: jsonerror.Forbidden("Invalid username or password"),
|
||||
}
|
||||
}
|
||||
}
|
||||
r.Login.User = username
|
||||
return &r.Login, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -74,7 +74,7 @@ type Login struct {
|
|||
|
||||
// Username returns the user localpart/user_id in this request, if it exists.
|
||||
func (r *Login) Username() string {
|
||||
if r.Identifier.Type == "m.id.user" {
|
||||
if r.Identifier.Type == mIdUser {
|
||||
return r.Identifier.User
|
||||
}
|
||||
// deprecated but without it Element iOS won't log in
|
||||
|
|
@ -87,8 +87,8 @@ func (r *Login) ThirdPartyID() (medium, address string) {
|
|||
return r.Identifier.Medium, r.Identifier.Address
|
||||
}
|
||||
// deprecated
|
||||
if r.Medium == "email" {
|
||||
return "email", r.Address
|
||||
if r.Medium == email {
|
||||
return email, r.Address
|
||||
}
|
||||
return "", ""
|
||||
}
|
||||
|
|
@ -109,9 +109,9 @@ type UserInteractive struct {
|
|||
Sessions map[string][]string
|
||||
}
|
||||
|
||||
func NewUserInteractive(userAccountAPI api.UserLoginAPI, cfg *config.ClientAPI) *UserInteractive {
|
||||
func NewUserInteractive(userAccountAPI api.ClientUserAPI, cfg *config.ClientAPI) *UserInteractive {
|
||||
typePassword := &LoginTypePassword{
|
||||
GetAccountByPassword: userAccountAPI.QueryAccountByPassword,
|
||||
UserApi: userAccountAPI,
|
||||
Config: cfg,
|
||||
}
|
||||
return &UserInteractive{
|
||||
|
|
|
|||
|
|
@ -24,7 +24,9 @@ var (
|
|||
}
|
||||
)
|
||||
|
||||
type fakeAccountDatabase struct{}
|
||||
type fakeAccountDatabase struct {
|
||||
api.ClientUserAPI
|
||||
}
|
||||
|
||||
func (d *fakeAccountDatabase) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error {
|
||||
return nil
|
||||
|
|
|
|||
117
clientapi/ratelimit/rt_failed_login.go
Normal file
117
clientapi/ratelimit/rt_failed_login.go
Normal file
|
|
@ -0,0 +1,117 @@
|
|||
package ratelimit
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type rateLimit struct {
|
||||
cfg *RtFailedLoginConfig
|
||||
times *list.List
|
||||
}
|
||||
|
||||
type RtFailedLogin struct {
|
||||
cfg *RtFailedLoginConfig
|
||||
mtx sync.RWMutex
|
||||
rts map[string]*rateLimit
|
||||
}
|
||||
|
||||
type RtFailedLoginConfig struct {
|
||||
Enabled bool `yaml:"enabled"`
|
||||
Limit int `yaml:"burst"`
|
||||
Interval time.Duration `yaml:"interval"`
|
||||
}
|
||||
|
||||
// New creates a new rate limiter for the limit and interval.
|
||||
func NewRtFailedLogin(cfg *RtFailedLoginConfig) *RtFailedLogin {
|
||||
if !cfg.Enabled {
|
||||
return nil
|
||||
}
|
||||
rt := &RtFailedLogin{
|
||||
cfg: cfg,
|
||||
mtx: sync.RWMutex{},
|
||||
rts: make(map[string]*rateLimit),
|
||||
}
|
||||
go rt.clean()
|
||||
return rt
|
||||
}
|
||||
|
||||
// CanAct is expected to be called before Act
|
||||
func (r *RtFailedLogin) CanAct(key string) (ok bool, remaining time.Duration) {
|
||||
r.mtx.RLock()
|
||||
rt, ok := r.rts[key]
|
||||
if !ok {
|
||||
r.mtx.RUnlock()
|
||||
return true, 0
|
||||
}
|
||||
ok, remaining = rt.canAct()
|
||||
r.mtx.RUnlock()
|
||||
return
|
||||
}
|
||||
|
||||
// Act can be called after CanAct returns true.
|
||||
func (r *RtFailedLogin) Act(key string) {
|
||||
r.mtx.Lock()
|
||||
rt, ok := r.rts[key]
|
||||
if !ok {
|
||||
rt = &rateLimit{
|
||||
cfg: r.cfg,
|
||||
times: list.New(),
|
||||
}
|
||||
r.rts[key] = rt
|
||||
}
|
||||
rt.act()
|
||||
r.mtx.Unlock()
|
||||
}
|
||||
|
||||
func (r *RtFailedLogin) clean() {
|
||||
for {
|
||||
r.mtx.Lock()
|
||||
for k, v := range r.rts {
|
||||
if v.empty() {
|
||||
delete(r.rts, k)
|
||||
}
|
||||
}
|
||||
r.mtx.Unlock()
|
||||
time.Sleep(time.Hour)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *rateLimit) empty() bool {
|
||||
back := r.times.Back()
|
||||
if back == nil {
|
||||
return true
|
||||
}
|
||||
v := back.Value
|
||||
b := v.(time.Time)
|
||||
now := time.Now()
|
||||
return now.Sub(b) > r.cfg.Interval
|
||||
}
|
||||
|
||||
func (r *rateLimit) canAct() (ok bool, remaining time.Duration) {
|
||||
now := time.Now()
|
||||
l := r.times.Len()
|
||||
if l < r.cfg.Limit {
|
||||
return true, 0
|
||||
}
|
||||
frnt := r.times.Front()
|
||||
t := frnt.Value.(time.Time)
|
||||
diff := now.Sub(t)
|
||||
if diff < r.cfg.Interval {
|
||||
return false, r.cfg.Interval - diff
|
||||
}
|
||||
return true, 0
|
||||
}
|
||||
|
||||
func (r *rateLimit) act() {
|
||||
now := time.Now()
|
||||
l := r.times.Len()
|
||||
if l < r.cfg.Limit {
|
||||
r.times.PushBack(now)
|
||||
return
|
||||
}
|
||||
frnt := r.times.Front()
|
||||
frnt.Value = now
|
||||
r.times.MoveToBack(frnt)
|
||||
}
|
||||
40
clientapi/ratelimit/rt_failed_login_test.go
Normal file
40
clientapi/ratelimit/rt_failed_login_test.go
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
package ratelimit
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/matryer/is"
|
||||
)
|
||||
|
||||
func TestRtFailedLogin(t *testing.T) {
|
||||
is := is.New(t)
|
||||
rtfl := NewRtFailedLogin(&RtFailedLoginConfig{
|
||||
Enabled: true,
|
||||
Limit: 3,
|
||||
Interval: 10 * time.Millisecond,
|
||||
})
|
||||
var (
|
||||
can bool
|
||||
remaining time.Duration
|
||||
remainingB time.Duration
|
||||
)
|
||||
for i := 0; i < 3; i++ {
|
||||
can, remaining = rtfl.CanAct("foo")
|
||||
is.True(can)
|
||||
is.Equal(remaining, time.Duration(0))
|
||||
rtfl.Act("foo")
|
||||
}
|
||||
can, remaining = rtfl.CanAct("foo")
|
||||
is.True(!can)
|
||||
is.True(remaining > time.Millisecond*9)
|
||||
can, remainingB = rtfl.CanAct("bar")
|
||||
is.True(can)
|
||||
is.Equal(remainingB, time.Duration(0))
|
||||
rtfl.Act("bar")
|
||||
rtfl.Act("bar")
|
||||
time.Sleep(remaining + time.Millisecond)
|
||||
can, remaining = rtfl.CanAct("foo")
|
||||
is.True(can)
|
||||
is.Equal(remaining, time.Duration(0))
|
||||
}
|
||||
|
|
@ -63,7 +63,7 @@ func UploadCrossSigningDeviceKeys(
|
|||
}
|
||||
}
|
||||
typePassword := auth.LoginTypePassword{
|
||||
GetAccountByPassword: accountAPI.QueryAccountByPassword,
|
||||
UserApi: accountAPI,
|
||||
Config: cfg,
|
||||
}
|
||||
if _, authErr := typePassword.Login(req.Context(), &uploadReq.Auth.PasswordRequest); authErr != nil {
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ import (
|
|||
|
||||
"github.com/matrix-org/dendrite/clientapi/auth"
|
||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||
"github.com/matrix-org/dendrite/clientapi/ratelimit"
|
||||
"github.com/matrix-org/dendrite/clientapi/userutil"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||
|
|
@ -55,6 +56,7 @@ func passwordLogin() flows {
|
|||
func Login(
|
||||
req *http.Request, userAPI userapi.ClientUserAPI,
|
||||
cfg *config.ClientAPI,
|
||||
rt *ratelimit.RtFailedLogin,
|
||||
) util.JSONResponse {
|
||||
if req.Method == http.MethodGet {
|
||||
// TODO: support other forms of login other than password, depending on config options
|
||||
|
|
@ -63,7 +65,7 @@ func Login(
|
|||
JSON: passwordLogin(),
|
||||
}
|
||||
} else if req.Method == http.MethodPost {
|
||||
login, cleanup, authErr := auth.LoginFromJSONReader(req.Context(), req.Body, userAPI, userAPI, cfg)
|
||||
login, cleanup, authErr := auth.LoginFromJSONReader(req.Context(), req.Body, userAPI, cfg, rt)
|
||||
if authErr != nil {
|
||||
return *authErr
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,12 +1,14 @@
|
|||
package routing
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/auth"
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
"github.com/matrix-org/dendrite/clientapi/httputil"
|
||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||
"github.com/matrix-org/dendrite/clientapi/threepid"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
|
@ -24,6 +26,7 @@ type newPasswordAuth struct {
|
|||
Type string `json:"type"`
|
||||
Session string `json:"session"`
|
||||
auth.PasswordRequest
|
||||
ThreePidCreds threepid.Credentials `json:"threepid_creds"`
|
||||
}
|
||||
|
||||
func Password(
|
||||
|
|
@ -33,13 +36,17 @@ func Password(
|
|||
cfg *config.ClientAPI,
|
||||
) util.JSONResponse {
|
||||
// Check that the existing password is right.
|
||||
var fields logrus.Fields
|
||||
if device != nil {
|
||||
fields = logrus.Fields{
|
||||
"sessionId": device.SessionID,
|
||||
"userId": device.UserID,
|
||||
}
|
||||
}
|
||||
var r newPasswordRequest
|
||||
r.LogoutDevices = true
|
||||
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"sessionId": device.SessionID,
|
||||
"userId": device.UserID,
|
||||
}).Debug("Changing password")
|
||||
logrus.WithFields(fields).Debug("Changing password")
|
||||
|
||||
// Unmarshal the request.
|
||||
resErr := httputil.UnmarshalJSONRequest(req, &r)
|
||||
|
|
@ -53,45 +60,95 @@ func Password(
|
|||
// Generate a new, random session ID
|
||||
sessionID = util.RandomString(sessionIDLength)
|
||||
}
|
||||
|
||||
// Require password auth to change the password.
|
||||
if r.Auth.Type != authtypes.LoginTypePassword {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusUnauthorized,
|
||||
JSON: newUserInteractiveResponse(
|
||||
sessionID,
|
||||
[]authtypes.Flow{
|
||||
{
|
||||
Stages: []authtypes.LoginType{authtypes.LoginTypePassword},
|
||||
},
|
||||
},
|
||||
nil,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
var localpart string
|
||||
switch r.Auth.Type {
|
||||
case authtypes.LoginTypePassword:
|
||||
// Check if the existing password is correct.
|
||||
typePassword := auth.LoginTypePassword{
|
||||
GetAccountByPassword: userAPI.QueryAccountByPassword,
|
||||
UserApi: userAPI,
|
||||
Config: cfg,
|
||||
}
|
||||
if _, authErr := typePassword.Login(req.Context(), &r.Auth.PasswordRequest); authErr != nil {
|
||||
return *authErr
|
||||
}
|
||||
// Get the local part.
|
||||
var err error
|
||||
localpart, _, err = gomatrixserverlib.SplitID('@', device.UserID)
|
||||
if err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
|
||||
return jsonerror.InternalServerError()
|
||||
}
|
||||
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword)
|
||||
case authtypes.LoginTypeEmail:
|
||||
threePid := &authtypes.ThreePID{}
|
||||
r.Auth.ThreePidCreds.IDServer = cfg.ThreePidDelegate
|
||||
var (
|
||||
bound bool
|
||||
err error
|
||||
)
|
||||
bound, threePid.Address, threePid.Medium, err = threepid.CheckAssociation(req.Context(), r.Auth.ThreePidCreds, cfg)
|
||||
if err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("threepid.CheckAssociation failed")
|
||||
return jsonerror.InternalServerError()
|
||||
}
|
||||
if !bound {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.MatrixError{
|
||||
ErrCode: "M_THREEPID_AUTH_FAILED",
|
||||
Err: "Failed to auth 3pid",
|
||||
},
|
||||
}
|
||||
}
|
||||
var res api.QueryLocalpartForThreePIDResponse
|
||||
err = userAPI.QueryLocalpartForThreePID(req.Context(), &api.QueryLocalpartForThreePIDRequest{
|
||||
Medium: threePid.Medium,
|
||||
ThreePID: threePid.Address,
|
||||
}, &res)
|
||||
if err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryLocalpartForThreePID failed")
|
||||
return jsonerror.InternalServerError()
|
||||
}
|
||||
if res.Localpart == "" {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.MatrixError{
|
||||
ErrCode: "M_THREEPID_NOT_FOUND",
|
||||
Err: "3pid is not bound to any account",
|
||||
},
|
||||
}
|
||||
}
|
||||
localpart = res.Localpart
|
||||
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeEmail)
|
||||
default:
|
||||
flows := []authtypes.Flow{
|
||||
{
|
||||
Stages: []authtypes.LoginType{authtypes.LoginTypePassword},
|
||||
},
|
||||
}
|
||||
if cfg.ThreePidDelegate != "" {
|
||||
flows = append(flows, authtypes.Flow{
|
||||
Stages: []authtypes.LoginType{authtypes.LoginTypeEmail},
|
||||
})
|
||||
}
|
||||
// Require password auth to change the password.
|
||||
if r.Auth.Type == authtypes.LoginTypePassword {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusUnauthorized,
|
||||
JSON: newUserInteractiveResponse(
|
||||
sessionID,
|
||||
flows,
|
||||
nil,
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check the new password strength.
|
||||
if resErr = validatePassword(r.NewPassword); resErr != nil {
|
||||
return *resErr
|
||||
}
|
||||
|
||||
// Get the local part.
|
||||
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||
if err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
|
||||
return jsonerror.InternalServerError()
|
||||
}
|
||||
|
||||
// Ask the user API to perform the password change.
|
||||
passwordReq := &api.PerformPasswordUpdateRequest{
|
||||
Localpart: localpart,
|
||||
|
|
@ -109,12 +166,24 @@ func Password(
|
|||
|
||||
// If the request asks us to log out all other devices then
|
||||
// ask the user API to do that.
|
||||
|
||||
if r.LogoutDevices {
|
||||
logoutReq := &api.PerformDeviceDeletionRequest{
|
||||
var logoutReq *api.PerformDeviceDeletionRequest
|
||||
var sessionId int64
|
||||
if device == nil {
|
||||
logoutReq = &api.PerformDeviceDeletionRequest{
|
||||
UserID: fmt.Sprintf("@%s:%s", localpart, cfg.Matrix.ServerName),
|
||||
DeviceIDs: []string{},
|
||||
}
|
||||
sessionId = 0
|
||||
} else {
|
||||
logoutReq = &api.PerformDeviceDeletionRequest{
|
||||
UserID: device.UserID,
|
||||
DeviceIDs: nil,
|
||||
ExceptDeviceID: device.ID,
|
||||
}
|
||||
sessionId = device.SessionID
|
||||
}
|
||||
logoutRes := &api.PerformDeviceDeletionResponse{}
|
||||
if err := userAPI.PerformDeviceDeletion(req.Context(), logoutReq, logoutRes); err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("PerformDeviceDeletion failed")
|
||||
|
|
@ -123,7 +192,7 @@ func Password(
|
|||
|
||||
pushersReq := &api.PerformPusherDeletionRequest{
|
||||
Localpart: localpart,
|
||||
SessionID: device.SessionID,
|
||||
SessionID: sessionId,
|
||||
}
|
||||
if err := userAPI.PerformPusherDeletion(req.Context(), pushersReq, &struct{}{}); err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("PerformPusherDeletion failed")
|
||||
|
|
|
|||
|
|
@ -44,6 +44,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
"github.com/matrix-org/dendrite/clientapi/httputil"
|
||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||
"github.com/matrix-org/dendrite/clientapi/threepid"
|
||||
"github.com/matrix-org/dendrite/clientapi/userutil"
|
||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||
)
|
||||
|
|
@ -237,6 +238,7 @@ type authDict struct {
|
|||
// Recaptcha
|
||||
Response string `json:"response"`
|
||||
// TODO: Lots of custom keys depending on the type
|
||||
ThreePidCreds threepid.Credentials `json:"threepid_creds"`
|
||||
}
|
||||
|
||||
// http://matrix.org/speculator/spec/HEAD/client_server/unstable.html#user-interactive-authentication-api
|
||||
|
|
@ -745,6 +747,7 @@ func handleRegistrationFlow(
|
|||
}
|
||||
}
|
||||
|
||||
var threePid *authtypes.ThreePID
|
||||
switch r.Auth.Type {
|
||||
case authtypes.LoginTypeRecaptcha:
|
||||
// Check given captcha response
|
||||
|
|
@ -761,6 +764,29 @@ func handleRegistrationFlow(
|
|||
// Add Dummy to the list of completed registration stages
|
||||
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeDummy)
|
||||
|
||||
case authtypes.LoginTypeEmail:
|
||||
threePid = &authtypes.ThreePID{}
|
||||
r.Auth.ThreePidCreds.IDServer = cfg.ThreePidDelegate
|
||||
var (
|
||||
bound bool
|
||||
err error
|
||||
)
|
||||
bound, threePid.Address, threePid.Medium, err = threepid.CheckAssociation(req.Context(), r.Auth.ThreePidCreds, cfg)
|
||||
if err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("threepid.CheckAssociation failed")
|
||||
return jsonerror.InternalServerError()
|
||||
}
|
||||
if !bound {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.MatrixError{
|
||||
ErrCode: "M_THREEPID_AUTH_FAILED",
|
||||
Err: "Failed to auth 3pid",
|
||||
},
|
||||
}
|
||||
}
|
||||
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeEmail)
|
||||
|
||||
case "":
|
||||
// An empty auth type means that we want to fetch the available
|
||||
// flows. It can also mean that we want to register as an appservice
|
||||
|
|
@ -776,7 +802,7 @@ func handleRegistrationFlow(
|
|||
// A response with current registration flow and remaining available methods
|
||||
// will be returned if a flow has not been successfully completed yet
|
||||
return checkAndCompleteFlow(sessions.getCompletedStages(sessionID),
|
||||
req, r, sessionID, cfg, userAPI)
|
||||
req, r, sessionID, cfg, userAPI, threePid)
|
||||
}
|
||||
|
||||
// handleApplicationServiceRegistration handles the registration of an
|
||||
|
|
@ -818,7 +844,7 @@ func handleApplicationServiceRegistration(
|
|||
// application service registration is entirely separate.
|
||||
return completeRegistration(
|
||||
req.Context(), userAPI, r.Username, "", appserviceID, req.RemoteAddr, req.UserAgent(), r.Auth.Session,
|
||||
r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeAppService,
|
||||
r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeAppService, nil,
|
||||
)
|
||||
}
|
||||
|
||||
|
|
@ -832,12 +858,13 @@ func checkAndCompleteFlow(
|
|||
sessionID string,
|
||||
cfg *config.ClientAPI,
|
||||
userAPI userapi.ClientUserAPI,
|
||||
threePid *authtypes.ThreePID,
|
||||
) util.JSONResponse {
|
||||
if checkFlowCompleted(flow, cfg.Derived.Registration.Flows) {
|
||||
// This flow was completed, registration can continue
|
||||
return completeRegistration(
|
||||
req.Context(), userAPI, r.Username, r.Password, "", req.RemoteAddr, req.UserAgent(), sessionID,
|
||||
r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeUser,
|
||||
r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeUser, threePid,
|
||||
)
|
||||
}
|
||||
sessions.addParams(sessionID, r)
|
||||
|
|
@ -863,6 +890,7 @@ func completeRegistration(
|
|||
inhibitLogin eventutil.WeakBoolean,
|
||||
displayName, deviceID *string,
|
||||
accType userapi.AccountType,
|
||||
threePid *authtypes.ThreePID,
|
||||
) util.JSONResponse {
|
||||
if username == "" {
|
||||
return util.JSONResponse{
|
||||
|
|
@ -901,6 +929,21 @@ func completeRegistration(
|
|||
// Increment prometheus counter for created users
|
||||
amtRegUsers.Inc()
|
||||
|
||||
// TODO-entry refuse register if threepid is already bound to account.
|
||||
if threePid != nil {
|
||||
err = userAPI.PerformSaveThreePIDAssociation(ctx, &userapi.PerformSaveThreePIDAssociationRequest{
|
||||
Medium: threePid.Medium,
|
||||
ThreePID: threePid.Address,
|
||||
Localpart: accRes.Account.Localpart,
|
||||
}, &struct{}{})
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
JSON: jsonerror.Unknown("Failed to save 3PID association: " + err.Error()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check whether inhibit_login option is set. If so, don't create an access
|
||||
// token or a device for this user
|
||||
if inhibitLogin {
|
||||
|
|
@ -1092,5 +1135,5 @@ func handleSharedSecretRegistration(userAPI userapi.ClientUserAPI, sr *SharedSec
|
|||
if ssrr.Admin {
|
||||
accType = userapi.AccountTypeAdmin
|
||||
}
|
||||
return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), "", false, &ssrr.User, &deviceID, accType)
|
||||
return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), "", false, &ssrr.User, &deviceID, accType, nil)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ import (
|
|||
clientutil "github.com/matrix-org/dendrite/clientapi/httputil"
|
||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||
"github.com/matrix-org/dendrite/clientapi/producers"
|
||||
"github.com/matrix-org/dendrite/clientapi/ratelimit"
|
||||
federationAPI "github.com/matrix-org/dendrite/federationapi/api"
|
||||
"github.com/matrix-org/dendrite/internal/httputil"
|
||||
"github.com/matrix-org/dendrite/internal/transactions"
|
||||
|
|
@ -65,6 +66,7 @@ func Setup(
|
|||
prometheus.MustRegister(amtRegUsers, sendEventDuration)
|
||||
|
||||
rateLimits := httputil.NewRateLimits(&cfg.RateLimiting)
|
||||
rateLimitsFailedLogin := ratelimit.NewRtFailedLogin(&cfg.RtFailedLogin)
|
||||
userInteractiveAuth := auth.NewUserInteractive(userAPI, cfg)
|
||||
|
||||
unstableFeatures := map[string]bool{
|
||||
|
|
@ -538,7 +540,7 @@ func Setup(
|
|||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/account/password",
|
||||
httputil.MakeAuthAPI("password", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
httputil.MakeConditionalAuthAPI("password", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
if r := rateLimits.Limit(req, device); r != nil {
|
||||
return *r
|
||||
}
|
||||
|
|
@ -562,7 +564,7 @@ func Setup(
|
|||
if r := rateLimits.Limit(req, nil); r != nil {
|
||||
return *r
|
||||
}
|
||||
return Login(req, userAPI, cfg)
|
||||
return Login(req, userAPI, cfg, rateLimitsFailedLogin)
|
||||
}),
|
||||
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
|
||||
|
||||
|
|
|
|||
|
|
@ -103,11 +103,8 @@ func CreateSession(
|
|||
func CheckAssociation(
|
||||
ctx context.Context, creds Credentials, cfg *config.ClientAPI,
|
||||
) (bool, string, string, error) {
|
||||
if err := isTrusted(creds.IDServer, cfg); err != nil {
|
||||
return false, "", "", err
|
||||
}
|
||||
|
||||
requestURL := fmt.Sprintf("https://%s/_matrix/identity/api/v1/3pid/getValidated3pid?sid=%s&client_secret=%s", creds.IDServer, creds.SID, creds.Secret)
|
||||
requestURL := fmt.Sprintf("%s/_matrix/identity/api/v1/3pid/getValidated3pid?sid=%s&client_secret=%s", cfg.ThreePidDelegate, creds.SID, creds.Secret)
|
||||
req, err := http.NewRequest(http.MethodGet, requestURL, nil)
|
||||
if err != nil {
|
||||
return false, "", "", err
|
||||
|
|
|
|||
8
cmd/dendrite-monolith-server/Dockerfile.dev
Normal file
8
cmd/dendrite-monolith-server/Dockerfile.dev
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
FROM alpine:latest
|
||||
|
||||
COPY dendrite-monolith-server /usr/bin/
|
||||
|
||||
VOLUME /etc/dendrite
|
||||
WORKDIR /etc/dendrite
|
||||
|
||||
ENTRYPOINT ["/usr/bin/dendrite-monolith-server"]
|
||||
12
cmd/dendrite-monolith-server/build_dev.sh
Executable file
12
cmd/dendrite-monolith-server/build_dev.sh
Executable file
|
|
@ -0,0 +1,12 @@
|
|||
set -xe
|
||||
if [ -z "$(git status --porcelain)" ]; then
|
||||
CGO_ENABLED=0 go build .
|
||||
TAG=$(git rev-parse --short HEAD)
|
||||
docker build -f Dockerfile.dev -t gcr.io/globekeeper-development/dendrite-monolith:$TAG -t gcr.io/globekeeper-development/dendrite-monolith -t gcr.io/globekeeper-production/dendrite-monolith:$TAG .
|
||||
docker push gcr.io/globekeeper-development/dendrite-monolith:$TAG
|
||||
docker push gcr.io/globekeeper-production/dendrite-monolith:$TAG
|
||||
docker push gcr.io/globekeeper-development/dendrite-monolith
|
||||
else
|
||||
echo "Please commit changes"
|
||||
exit 0
|
||||
fi
|
||||
2
go.mod
2
go.mod
|
|
@ -21,6 +21,7 @@ require (
|
|||
github.com/frankban/quicktest v1.14.3 // indirect
|
||||
github.com/getsentry/sentry-go v0.13.0
|
||||
github.com/gogo/protobuf v1.3.2 // indirect
|
||||
github.com/golang-jwt/jwt/v4 v4.4.1
|
||||
github.com/gologme/log v1.3.0
|
||||
github.com/google/go-cmp v0.5.8
|
||||
github.com/google/uuid v1.3.0
|
||||
|
|
@ -37,6 +38,7 @@ require (
|
|||
github.com/matrix-org/gomatrixserverlib v0.0.0-20220607143425-e55d796fd0b3
|
||||
github.com/matrix-org/pinecone v0.0.0-20220408153826-2999ea29ed48
|
||||
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
|
||||
github.com/matryer/is v1.4.0
|
||||
github.com/mattn/go-sqlite3 v1.14.13
|
||||
github.com/matttproud/golang_protobuf_extensions v1.0.2-0.20181231171920-c182affec369 // indirect
|
||||
github.com/miekg/dns v1.1.49 // indirect
|
||||
|
|
|
|||
5
go.sum
5
go.sum
|
|
@ -202,7 +202,10 @@ github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/E
|
|||
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
|
||||
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
|
||||
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
|
||||
github.com/golang-jwt/jwt/v4 v4.4.1 h1:pC5DB52sCeK48Wlb9oPcdhnjkz1TKt1D/P7WKJ0kUcQ=
|
||||
github.com/golang-jwt/jwt/v4 v4.4.1/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
|
||||
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k=
|
||||
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
|
||||
github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
|
||||
|
|
@ -425,6 +428,8 @@ github.com/matrix-org/pinecone v0.0.0-20220408153826-2999ea29ed48/go.mod h1:ulJz
|
|||
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U=
|
||||
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk=
|
||||
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U=
|
||||
github.com/matryer/is v1.4.0 h1:sosSmIWwkYITGrxZ25ULNDeKiMNzFSr4V/eqBQP0PeE=
|
||||
github.com/matryer/is v1.4.0/go.mod h1:8I/i5uYgLzgsgEloJE1U6xx5HkBQpAZvepWuujKwMRU=
|
||||
github.com/mattn/go-colorable v0.0.6/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU=
|
||||
github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE=
|
||||
github.com/mattn/go-colorable v0.1.8/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc=
|
||||
|
|
|
|||
|
|
@ -83,6 +83,57 @@ func MakeAuthAPI(
|
|||
return MakeExternalAPI(metricsName, h)
|
||||
}
|
||||
|
||||
// MakeConditionalAuthAPI turns a util.JSONRequestHandler function into an http.Handler which authenticates the request.
|
||||
// It passes nil device if header is not provided.
|
||||
func MakeConditionalAuthAPI(
|
||||
metricsName string, userAPI userapi.QueryAcccessTokenAPI,
|
||||
f func(*http.Request, *userapi.Device) util.JSONResponse,
|
||||
) http.Handler {
|
||||
h := func(req *http.Request) util.JSONResponse {
|
||||
var (
|
||||
jsonRes util.JSONResponse
|
||||
dev *userapi.Device
|
||||
)
|
||||
if _, err := auth.ExtractAccessToken(req); err != nil {
|
||||
dev = nil
|
||||
} else {
|
||||
logger := util.GetLogger(req.Context())
|
||||
var err *util.JSONResponse
|
||||
dev, err = auth.VerifyUserFromRequest(req, userAPI)
|
||||
if err != nil {
|
||||
logger.Debugf("VerifyUserFromRequest %s -> HTTP %d", req.RemoteAddr, err.Code)
|
||||
return *err
|
||||
}
|
||||
// add the user ID to the logger
|
||||
logger = logger.WithField("user_id", dev.UserID)
|
||||
req = req.WithContext(util.ContextWithLogger(req.Context(), logger))
|
||||
}
|
||||
// add the user to Sentry, if enabled
|
||||
hub := sentry.GetHubFromContext(req.Context())
|
||||
if hub != nil {
|
||||
hub.Scope().SetTag("user_id", dev.UserID)
|
||||
hub.Scope().SetTag("device_id", dev.ID)
|
||||
}
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
if hub != nil {
|
||||
hub.CaptureException(fmt.Errorf("%s panicked", req.URL.Path))
|
||||
}
|
||||
// re-panic to return the 500
|
||||
panic(r)
|
||||
}
|
||||
}()
|
||||
jsonRes = f(req, dev)
|
||||
// do not log 4xx as errors as they are client fails, not server fails
|
||||
if hub != nil && jsonRes.Code >= 500 {
|
||||
hub.Scope().SetExtra("response", jsonRes)
|
||||
hub.CaptureException(fmt.Errorf("%s returned HTTP %d", req.URL.Path, jsonRes.Code))
|
||||
}
|
||||
return jsonRes
|
||||
}
|
||||
return MakeExternalAPI(metricsName, h)
|
||||
}
|
||||
|
||||
// MakeExternalAPI turns a util.JSONRequestHandler function into an http.Handler.
|
||||
// This is used for APIs that are called from the internet.
|
||||
func MakeExternalAPI(metricsName string, f func(*http.Request) util.JSONResponse) http.Handler {
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ package config
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"io"
|
||||
|
|
@ -252,6 +253,15 @@ func loadConfig(
|
|||
|
||||
c.Global.OldVerifyKeys[i].KeyID, c.Global.OldVerifyKeys[i].PrivateKey = keyID, privateKey
|
||||
}
|
||||
if c.ClientAPI.JwtConfig.Enabled {
|
||||
pubPki, _ := pem.Decode([]byte(c.ClientAPI.JwtConfig.Secret))
|
||||
var pub interface{}
|
||||
pub, err = x509.ParsePKIXPublicKey(pubPki.Bytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.ClientAPI.JwtConfig.SecretKey = pub.(ed25519.PublicKey)
|
||||
}
|
||||
|
||||
c.MediaAPI.AbsBasePath = Path(absPath(basePath, c.MediaAPI.BasePath))
|
||||
|
||||
|
|
@ -283,7 +293,10 @@ func (config *Dendrite) Derive() error {
|
|||
config.Derived.Registration.Flows = append(config.Derived.Registration.Flows,
|
||||
authtypes.Flow{Stages: []authtypes.LoginType{authtypes.LoginTypeDummy}})
|
||||
}
|
||||
|
||||
if config.ClientAPI.ThreePidDelegate != "" {
|
||||
config.Derived.Registration.Flows = append(config.Derived.Registration.Flows,
|
||||
authtypes.Flow{Stages: []authtypes.LoginType{authtypes.LoginTypeEmail}})
|
||||
}
|
||||
// Load application service configuration files
|
||||
if err := loadAppServices(&config.AppServiceAPI, &config.Derived); err != nil {
|
||||
return err
|
||||
|
|
|
|||
|
|
@ -3,6 +3,9 @@ package config
|
|||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/ratelimit"
|
||||
"golang.org/x/crypto/ed25519"
|
||||
)
|
||||
|
||||
type ClientAPI struct {
|
||||
|
|
@ -47,8 +50,22 @@ type ClientAPI struct {
|
|||
|
||||
// Rate-limiting options
|
||||
RateLimiting RateLimiting `yaml:"rate_limiting"`
|
||||
RtFailedLogin ratelimit.RtFailedLoginConfig `yaml:"rate_limiting_failed_login"`
|
||||
|
||||
MSCs *MSCs `yaml:"mscs"`
|
||||
|
||||
ThreePidDelegate string `yaml:"three_pid_delegate"`
|
||||
|
||||
JwtConfig JwtConfig `yaml:"jwt_config"`
|
||||
}
|
||||
|
||||
type JwtConfig struct {
|
||||
Enabled bool `yaml:"enabled"`
|
||||
Algorithm string `yaml:"algorithm"`
|
||||
Issuer string `yaml:"issuer"`
|
||||
Secret string `yaml:"secret"`
|
||||
SecretKey ed25519.PublicKey
|
||||
Audiences []string `yaml:"audiences"`
|
||||
}
|
||||
|
||||
func (c *ClientAPI) Defaults(generate bool) {
|
||||
|
|
|
|||
|
|
@ -517,7 +517,7 @@ type PerformPusherSetRequest struct {
|
|||
|
||||
type PerformPusherDeletionRequest struct {
|
||||
Localpart string
|
||||
SessionID int64
|
||||
SessionID int64 // Pusher corresponding to this SessionID will not be deleted
|
||||
}
|
||||
|
||||
// Pusher represents a push notification subscriber
|
||||
|
|
|
|||
Loading…
Reference in a new issue