Merge branch 'main' into release/upstream

This commit is contained in:
Piotr Kozimor 2022-08-22 14:45:25 +02:00
commit 4aaa80a56e
73 changed files with 1054 additions and 360 deletions

12
.cloudbuild/dev.yaml Normal file
View file

@ -0,0 +1,12 @@
steps:
- name: gcr.io/cloud-builders/docker
args: ['build', '-t', 'gcr.io/$PROJECT_ID/dendrite-monolith:$COMMIT_SHA', '-f', 'build/docker/Dockerfile.monolith', '.']
- name: gcr.io/cloud-builders/kubectl
args: ['-n', 'dendrite', 'set', 'image', 'deployment/dendrite', 'dendrite=gcr.io/$PROJECT_ID/dendrite-monolith:$COMMIT_SHA']
env:
- CLOUDSDK_CORE_PROJECT=globekeeper-development
- CLOUDSDK_COMPUTE_ZONE=europe-west2-a
- CLOUDSDK_CONTAINER_CLUSTER=synapse
images:
- gcr.io/$PROJECT_ID/dendrite-monolith:$COMMIT_SHA
timeout: 360s

12
.cloudbuild/prod.yaml Normal file
View file

@ -0,0 +1,12 @@
steps:
- name: gcr.io/cloud-builders/docker
args: ['build', '-t', 'gcr.io/$PROJECT_ID/dendrite-monolith:$TAG_NAME', '-f', 'build/docker/Dockerfile.monolith', '.']
- name: gcr.io/cloud-builders/kubectl
args: ['set', 'image', 'deployment/dendrite', 'dendrite=gcr.io/$PROJECT_ID/dendrite-monolith:$TAG_NAME']
env:
- CLOUDSDK_CORE_PROJECT=globekeeper-production
- CLOUDSDK_COMPUTE_ZONE=europe-west2-a
- CLOUDSDK_CONTAINER_CLUSTER=synapse-production
images:
- gcr.io/$PROJECT_ID/dendrite-monolith:$TAG_NAME
timeout: 360s

View file

@ -13,51 +13,6 @@ concurrency:
cancel-in-progress: true cancel-in-progress: true
jobs: jobs:
wasm:
name: WASM build test
timeout-minutes: 5
runs-on: ubuntu-latest
if: ${{ false }} # disable for now
steps:
- uses: actions/checkout@v3
- name: Install Go
uses: actions/setup-go@v3
with:
go-version: 1.18
- uses: actions/cache@v2
with:
path: |
~/.cache/go-build
~/go/pkg/mod
key: ${{ runner.os }}-go-wasm-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go-wasm
- name: Install Node
uses: actions/setup-node@v2
with:
node-version: 14
- uses: actions/cache@v2
with:
path: ~/.npm
key: ${{ runner.os }}-node-${{ hashFiles('**/package-lock.json') }}
restore-keys: |
${{ runner.os }}-node-
- name: Reconfigure Git to use HTTPS auth for repo packages
run: >
git config --global url."https://github.com/".insteadOf
ssh://git@github.com/
- name: Install test dependencies
working-directory: ./test/wasm
run: npm ci
- name: Test
run: ./test-dendritejs.sh
# Run golangci-lint # Run golangci-lint
lint: lint:
@ -73,7 +28,7 @@ jobs:
- name: golangci-lint - name: golangci-lint
uses: golangci/golangci-lint-action@v3 uses: golangci/golangci-lint-action@v3
# run go test with different go versions # run go test with go 1.18
test: test:
timeout-minutes: 5 timeout-minutes: 5
name: Unit tests (Go ${{ matrix.go }}) name: Unit tests (Go ${{ matrix.go }})
@ -123,7 +78,7 @@ jobs:
POSTGRES_PASSWORD: postgres POSTGRES_PASSWORD: postgres
POSTGRES_DB: dendrite POSTGRES_DB: dendrite
# build Dendrite for linux with different architectures and go versions # build Dendrite for linux amd64 with go 1.18
build: build:
name: Build for Linux name: Build for Linux
timeout-minutes: 10 timeout-minutes: 10
@ -133,7 +88,7 @@ jobs:
matrix: matrix:
go: ["1.18", "1.19"] go: ["1.18", "1.19"]
goos: ["linux"] goos: ["linux"]
goarch: ["amd64", "386"] goarch: ["amd64"]
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v3
- name: Setup go - name: Setup go
@ -158,43 +113,10 @@ jobs:
CGO_CFLAGS: -fno-stack-protector CGO_CFLAGS: -fno-stack-protector
run: go build -trimpath -v -o "bin/" ./cmd/... run: go build -trimpath -v -o "bin/" ./cmd/...
# build for Windows 64-bit
build_windows:
name: Build for Windows
timeout-minutes: 10
runs-on: ubuntu-latest
strategy:
matrix:
go: ["1.18", "1.19"]
goos: ["windows"]
goarch: ["amd64"]
steps:
- uses: actions/checkout@v3
- name: Setup Go ${{ matrix.go }}
uses: actions/setup-go@v3
with:
go-version: ${{ matrix.go }}
- name: Install dependencies
run: sudo apt update && sudo apt install -y gcc-mingw-w64-x86-64 # install required gcc
- uses: actions/cache@v3
with:
path: |
~/.cache/go-build
~/go/pkg/mod
key: ${{ runner.os }}-go${{ matrix.go }}-${{ matrix.goos }}-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go${{ matrix.go }}-${{ matrix.goos }}
- env:
GOOS: ${{ matrix.goos }}
GOARCH: ${{ matrix.goarch }}
CGO_ENABLED: 1
CC: "/usr/bin/x86_64-w64-mingw32-gcc"
run: go build -trimpath -v -o "bin/" ./cmd/...
# Dummy step to gate other tests on without repeating the whole list # Dummy step to gate other tests on without repeating the whole list
initial-tests-done: initial-tests-done:
name: Initial tests passed name: Initial tests passed
needs: [lint, test, build, build_windows] needs: [lint, test, build]
runs-on: ubuntu-latest runs-on: ubuntu-latest
if: ${{ !cancelled() }} # Run this even if prior jobs were skipped if: ${{ !cancelled() }} # Run this even if prior jobs were skipped
steps: steps:
@ -299,13 +221,12 @@ jobs:
run: /src/are-we-synapse-yet.py /logs/results.tap -v run: /src/are-we-synapse-yet.py /logs/results.tap -v
continue-on-error: true # not fatal continue-on-error: true # not fatal
- name: Upload Sytest logs - name: Upload Sytest logs
uses: actions/upload-artifact@v2 uses: actions/upload-artifact@v3
if: ${{ always() }} if: ${{ always() }}
with: with:
name: Sytest Logs - ${{ job.status }} - (Dendrite, ${{ join(matrix.*, ', ') }}) name: Sytest Logs - ${{ job.status }} - (Dendrite, ${{ join(matrix.*, ', ') }})
path: | path: |
/logs/results.tap /logs
/logs/**/*.log*
# run Complement # run Complement
complement: complement:
@ -369,7 +290,7 @@ jobs:
continue continue
fi fi
(wget -O - "https://github.com/matrix-org/complement/archive/$BRANCH_NAME.tar.gz" | tar -xz --strip-components=1 -C complement) && break (wget -O - "https://github.com/globekeeper/complement/archive/$BRANCH_NAME.tar.gz" | tar -xz --strip-components=1 -C complement) && break
done done
# Build initial Dendrite image # Build initial Dendrite image

6
.gitignore vendored
View file

@ -2,6 +2,8 @@
# Hidden files # Hidden files
.* .*
!.vscode
!.cloudbuild
# Allow GitHub config # Allow GitHub config
!.github !.github
@ -73,3 +75,7 @@ complement/
docs/_site docs/_site
media_store/ media_store/
__debug_bin
cmd/dendrite-monolith-server/dendrite-monolith-server

16
.vscode/launch.json vendored Normal file
View file

@ -0,0 +1,16 @@
{
"configurations": [
{
"name": "Launch Package",
"type": "go",
"request": "launch",
"mode": "auto",
"program": "${workspaceFolder}/cmd/dendrite-monolith-server",
"args": [
"-really-enable-open-registration",
"-config",
"../../../adminas/.ci/config/dendrite-local/dendrite.yaml"
],
}
]
}

9
.vscode/settings.json vendored Normal file
View file

@ -0,0 +1,9 @@
{
"go.lintTool": "golangci-lint",
"go.testEnvVars": {
"POSTGRES_HOST": "localhost",
"POSTGRES_USER": "postgres",
"POSTGRES_PASSWORD": "foobar",
"POSTGRES_DB": "postgres"
}
}

View file

@ -70,14 +70,14 @@ func NewInternalAPI(
// Wrap application services in a type that relates the application service and // Wrap application services in a type that relates the application service and
// a sync.Cond object that can be used to notify workers when there are new // a sync.Cond object that can be used to notify workers when there are new
// events to be sent out. // events to be sent out.
workerStates := make([]types.ApplicationServiceWorkerState, len(base.Cfg.Derived.ApplicationServices)) workerStates := make([]*types.ApplicationServiceWorkerState, len(base.Cfg.Derived.ApplicationServices))
for i, appservice := range base.Cfg.Derived.ApplicationServices { for i, appservice := range base.Cfg.Derived.ApplicationServices {
m := sync.Mutex{} m := sync.Mutex{}
ws := types.ApplicationServiceWorkerState{ ws := types.ApplicationServiceWorkerState{
AppService: appservice, AppService: appservice,
Cond: sync.NewCond(&m), Cond: sync.NewCond(&m),
} }
workerStates[i] = ws workerStates[i] = &ws
// Create bot account for this AS if it doesn't already exist // Create bot account for this AS if it doesn't already exist
if err = generateAppServiceAccount(userAPI, appservice); err != nil { if err = generateAppServiceAccount(userAPI, appservice); err != nil {

View file

@ -39,7 +39,7 @@ type OutputRoomEventConsumer struct {
asDB storage.Database asDB storage.Database
rsAPI api.AppserviceRoomserverAPI rsAPI api.AppserviceRoomserverAPI
serverName string serverName string
workerStates []types.ApplicationServiceWorkerState workerStates []*types.ApplicationServiceWorkerState
} }
// NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call // NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call
@ -50,7 +50,7 @@ func NewOutputRoomEventConsumer(
js nats.JetStreamContext, js nats.JetStreamContext,
appserviceDB storage.Database, appserviceDB storage.Database,
rsAPI api.AppserviceRoomserverAPI, rsAPI api.AppserviceRoomserverAPI,
workerStates []types.ApplicationServiceWorkerState, workerStates []*types.ApplicationServiceWorkerState,
) *OutputRoomEventConsumer { ) *OutputRoomEventConsumer {
return &OutputRoomEventConsumer{ return &OutputRoomEventConsumer{
ctx: process.Context(), ctx: process.Context(),
@ -140,13 +140,13 @@ func (s *OutputRoomEventConsumer) filterRoomserverEvents(
// Check if this event is interesting to this application service // Check if this event is interesting to this application service
if s.appserviceIsInterestedInEvent(ctx, event, ws.AppService) { if s.appserviceIsInterestedInEvent(ctx, event, ws.AppService) {
// Queue this event to be sent off to the application service // Queue this event to be sent off to the application service
if err := s.asDB.StoreEvent(ctx, ws.AppService.ID, event); err != nil { if id, err := s.asDB.StoreEvent(ctx, ws.AppService.ID, event); err != nil {
log.WithError(err).Warn("failed to insert incoming event into appservices database") log.WithError(err).Warnf("failed to insert incoming event into appservices database. id: %d", id)
return err return err
} else { } else {
// Tell our worker to send out new messages by updating remaining message // Tell our worker to send out new messages by updating remaining message
// count and waking them up with a broadcast // count and waking them up with a broadcast
ws.NotifyNewEvents() ws.NotifyNewEvents(id)
} }
} }
} }

View file

@ -21,9 +21,9 @@ import (
) )
type Database interface { type Database interface {
StoreEvent(ctx context.Context, appServiceID string, event *gomatrixserverlib.HeaderedEvent) error StoreEvent(ctx context.Context, appServiceID string, event *gomatrixserverlib.HeaderedEvent) (int, error)
GetEventsWithAppServiceID(ctx context.Context, appServiceID string, limit int) (int, int, []gomatrixserverlib.HeaderedEvent, bool, error) GetEventsWithAppServiceID(ctx context.Context, appServiceID string, limit int) (int, int, []gomatrixserverlib.HeaderedEvent, error)
CountEventsWithAppServiceID(ctx context.Context, appServiceID string) (int, error) GetLatestId(ctx context.Context, appServiceID string) (int, error)
UpdateTxnIDForEvents(ctx context.Context, appserviceID string, maxID, txnID int) error UpdateTxnIDForEvents(ctx context.Context, appserviceID string, maxID, txnID int) error
RemoveEventsBeforeAndIncludingID(ctx context.Context, appserviceID string, eventTableID int) error RemoveEventsBeforeAndIncludingID(ctx context.Context, appserviceID string, eventTableID int) error
GetLatestTxnID(ctx context.Context) (int, error) GetLatestTxnID(ctx context.Context) (int, error)

View file

@ -45,12 +45,13 @@ const selectEventsByApplicationServiceIDSQL = "" +
"SELECT id, headered_event_json, txn_id " + "SELECT id, headered_event_json, txn_id " +
"FROM appservice_events WHERE as_id = $1 ORDER BY txn_id DESC, id ASC" "FROM appservice_events WHERE as_id = $1 ORDER BY txn_id DESC, id ASC"
const countEventsByApplicationServiceIDSQL = "" + const getLatestIdSQL = "" +
"SELECT COUNT(id) FROM appservice_events WHERE as_id = $1" "SELECT id FROM appservice_events WHERE as_id = $1 ORDER BY id DESC LIMIT 1"
const insertEventSQL = "" + const insertEventSQL = "" +
"INSERT INTO appservice_events(as_id, headered_event_json, txn_id) " + "INSERT INTO appservice_events(as_id, headered_event_json, txn_id) " +
"VALUES ($1, $2, $3)" "VALUES ($1, $2, $3)" +
"RETURNING id"
const updateTxnIDForEventsSQL = "" + const updateTxnIDForEventsSQL = "" +
"UPDATE appservice_events SET txn_id = $1 WHERE as_id = $2 AND id <= $3" "UPDATE appservice_events SET txn_id = $1 WHERE as_id = $2 AND id <= $3"
@ -66,7 +67,7 @@ const (
type eventsStatements struct { type eventsStatements struct {
selectEventsByApplicationServiceIDStmt *sql.Stmt selectEventsByApplicationServiceIDStmt *sql.Stmt
countEventsByApplicationServiceIDStmt *sql.Stmt getLatestIdStmt *sql.Stmt
insertEventStmt *sql.Stmt insertEventStmt *sql.Stmt
updateTxnIDForEventsStmt *sql.Stmt updateTxnIDForEventsStmt *sql.Stmt
deleteEventsBeforeAndIncludingIDStmt *sql.Stmt deleteEventsBeforeAndIncludingIDStmt *sql.Stmt
@ -81,7 +82,7 @@ func (s *eventsStatements) prepare(db *sql.DB) (err error) {
if s.selectEventsByApplicationServiceIDStmt, err = db.Prepare(selectEventsByApplicationServiceIDSQL); err != nil { if s.selectEventsByApplicationServiceIDStmt, err = db.Prepare(selectEventsByApplicationServiceIDSQL); err != nil {
return return
} }
if s.countEventsByApplicationServiceIDStmt, err = db.Prepare(countEventsByApplicationServiceIDSQL); err != nil { if s.getLatestIdStmt, err = db.Prepare(getLatestIdSQL); err != nil {
return return
} }
if s.insertEventStmt, err = db.Prepare(insertEventSQL); err != nil { if s.insertEventStmt, err = db.Prepare(insertEventSQL); err != nil {
@ -108,7 +109,6 @@ func (s *eventsStatements) selectEventsByApplicationServiceID(
) ( ) (
txnID, maxID int, txnID, maxID int,
events []gomatrixserverlib.HeaderedEvent, events []gomatrixserverlib.HeaderedEvent,
eventsRemaining bool,
err error, err error,
) { ) {
defer func() { defer func() {
@ -124,7 +124,7 @@ func (s *eventsStatements) selectEventsByApplicationServiceID(
return return
} }
defer checkNamedErr(eventRows.Close, &err) defer checkNamedErr(eventRows.Close, &err)
events, maxID, txnID, eventsRemaining, err = retrieveEvents(eventRows, limit) events, maxID, txnID, err = retrieveEvents(eventRows, limit)
if err != nil { if err != nil {
return return
} }
@ -139,7 +139,7 @@ func checkNamedErr(fn func() error, err *error) {
} }
} }
func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.HeaderedEvent, maxID, txnID int, eventsRemaining bool, err error) { func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.HeaderedEvent, maxID, txnID int, err error) {
// Get current time for use in calculating event age // Get current time for use in calculating event age
nowMilli := time.Now().UnixNano() / int64(time.Millisecond) nowMilli := time.Now().UnixNano() / int64(time.Millisecond)
@ -157,18 +157,18 @@ func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.
&txnID, &txnID,
) )
if err != nil { if err != nil {
return nil, 0, 0, false, err return nil, 0, 0, err
} }
// Unmarshal eventJSON // Unmarshal eventJSON
if err = json.Unmarshal(eventJSON, &event); err != nil { if err = json.Unmarshal(eventJSON, &event); err != nil {
return nil, 0, 0, false, err return nil, 0, 0, err
} }
// If txnID has changed on this event from the previous event, then we've // If txnID has changed on this event from the previous event, then we've
// reached the end of a transaction's events. Return only those events. // reached the end of a transaction's events. Return only those events.
if lastTxnID > invalidTxnID && lastTxnID != txnID { if lastTxnID > invalidTxnID && lastTxnID != txnID {
return events, maxID, lastTxnID, true, nil return events, maxID, lastTxnID, nil
} }
lastTxnID = txnID lastTxnID = txnID
@ -176,7 +176,7 @@ func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.
if txnID == -1 { if txnID == -1 {
// Return if we've hit the limit // Return if we've hit the limit
if eventsProcessed++; eventsProcessed > limit { if eventsProcessed++; eventsProcessed > limit {
return events, maxID, lastTxnID, true, nil return events, maxID, lastTxnID, nil
} }
} }
@ -187,7 +187,7 @@ func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.
// Portion of the event that is unsigned due to rapid change // Portion of the event that is unsigned due to rapid change
// TODO: Consider removing age as not many app services use it // TODO: Consider removing age as not many app services use it
if err = event.SetUnsignedField("age", nowMilli-int64(event.OriginServerTS())); err != nil { if err = event.SetUnsignedField("age", nowMilli-int64(event.OriginServerTS())); err != nil {
return nil, 0, 0, false, err return nil, 0, 0, err
} }
events = append(events, event) events = append(events, event)
@ -196,14 +196,12 @@ func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.
return return
} }
// countEventsByApplicationServiceID inserts an event mapped to its corresponding application service func (s *eventsStatements) getLatestId(
// IDs into the db.
func (s *eventsStatements) countEventsByApplicationServiceID(
ctx context.Context, ctx context.Context,
appServiceID string, appServiceID string,
) (int, error) { ) (int, error) {
var count int var count int
err := s.countEventsByApplicationServiceIDStmt.QueryRowContext(ctx, appServiceID).Scan(&count) err := s.getLatestIdStmt.QueryRowContext(ctx, appServiceID).Scan(&count)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
return 0, err return 0, err
} }
@ -217,19 +215,19 @@ func (s *eventsStatements) insertEvent(
ctx context.Context, ctx context.Context,
appServiceID string, appServiceID string,
event *gomatrixserverlib.HeaderedEvent, event *gomatrixserverlib.HeaderedEvent,
) (err error) { ) (id int, err error) {
// Convert event to JSON before inserting // Convert event to JSON before inserting
eventJSON, err := json.Marshal(event) var eventJSON []byte
eventJSON, err = json.Marshal(event)
if err != nil { if err != nil {
return err return 0, err
} }
err = s.insertEventStmt.QueryRowContext(
_, err = s.insertEventStmt.ExecContext(
ctx, ctx,
appServiceID, appServiceID,
eventJSON, eventJSON,
-1, // No transaction ID yet -1, // No transaction ID yet
) ).Scan(&id)
return return
} }

View file

@ -62,7 +62,7 @@ func (d *Database) StoreEvent(
ctx context.Context, ctx context.Context,
appServiceID string, appServiceID string,
event *gomatrixserverlib.HeaderedEvent, event *gomatrixserverlib.HeaderedEvent,
) error { ) (int, error) {
return d.events.insertEvent(ctx, appServiceID, event) return d.events.insertEvent(ctx, appServiceID, event)
} }
@ -72,17 +72,20 @@ func (d *Database) GetEventsWithAppServiceID(
ctx context.Context, ctx context.Context,
appServiceID string, appServiceID string,
limit int, limit int,
) (int, int, []gomatrixserverlib.HeaderedEvent, bool, error) { ) (int, int, []gomatrixserverlib.HeaderedEvent, error) {
return d.events.selectEventsByApplicationServiceID(ctx, appServiceID, limit) return d.events.selectEventsByApplicationServiceID(ctx, appServiceID, limit)
} }
// CountEventsWithAppServiceID returns the number of events destined for an // GetLatestId returns the latest incremental id associated with appservice.
// application service given its ID. func (d *Database) GetLatestId(
func (d *Database) CountEventsWithAppServiceID(
ctx context.Context, ctx context.Context,
appServiceID string, appServiceID string,
) (int, error) { ) (int, error) {
return d.events.countEventsByApplicationServiceID(ctx, appServiceID) id, err := d.events.getLatestId(ctx, appServiceID)
if err == sql.ErrNoRows {
return 0, nil
}
return id, err
} }
// UpdateTxnIDForEvents takes in an application service ID and a // UpdateTxnIDForEvents takes in an application service ID and a

View file

@ -46,12 +46,13 @@ const selectEventsByApplicationServiceIDSQL = "" +
"SELECT id, headered_event_json, txn_id " + "SELECT id, headered_event_json, txn_id " +
"FROM appservice_events WHERE as_id = $1 ORDER BY txn_id DESC, id ASC" "FROM appservice_events WHERE as_id = $1 ORDER BY txn_id DESC, id ASC"
const countEventsByApplicationServiceIDSQL = "" + const getLatestIdSQL = "" +
"SELECT COUNT(id) FROM appservice_events WHERE as_id = $1" "SELECT id FROM appservice_events WHERE as_id = $1 ORDER BY id DESC LIMIT 1"
const insertEventSQL = "" + const insertEventSQL = "" +
"INSERT INTO appservice_events(as_id, headered_event_json, txn_id) " + "INSERT INTO appservice_events(as_id, headered_event_json, txn_id) " +
"VALUES ($1, $2, $3)" "VALUES ($1, $2, $3)" +
"RETURNING id"
const updateTxnIDForEventsSQL = "" + const updateTxnIDForEventsSQL = "" +
"UPDATE appservice_events SET txn_id = $1 WHERE as_id = $2 AND id <= $3" "UPDATE appservice_events SET txn_id = $1 WHERE as_id = $2 AND id <= $3"
@ -69,7 +70,7 @@ type eventsStatements struct {
db *sql.DB db *sql.DB
writer sqlutil.Writer writer sqlutil.Writer
selectEventsByApplicationServiceIDStmt *sql.Stmt selectEventsByApplicationServiceIDStmt *sql.Stmt
countEventsByApplicationServiceIDStmt *sql.Stmt getLatestIdStmt *sql.Stmt
insertEventStmt *sql.Stmt insertEventStmt *sql.Stmt
updateTxnIDForEventsStmt *sql.Stmt updateTxnIDForEventsStmt *sql.Stmt
deleteEventsBeforeAndIncludingIDStmt *sql.Stmt deleteEventsBeforeAndIncludingIDStmt *sql.Stmt
@ -86,7 +87,7 @@ func (s *eventsStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error
if s.selectEventsByApplicationServiceIDStmt, err = db.Prepare(selectEventsByApplicationServiceIDSQL); err != nil { if s.selectEventsByApplicationServiceIDStmt, err = db.Prepare(selectEventsByApplicationServiceIDSQL); err != nil {
return return
} }
if s.countEventsByApplicationServiceIDStmt, err = db.Prepare(countEventsByApplicationServiceIDSQL); err != nil { if s.getLatestIdStmt, err = db.Prepare(getLatestIdSQL); err != nil {
return return
} }
if s.insertEventStmt, err = db.Prepare(insertEventSQL); err != nil { if s.insertEventStmt, err = db.Prepare(insertEventSQL); err != nil {
@ -113,7 +114,6 @@ func (s *eventsStatements) selectEventsByApplicationServiceID(
) ( ) (
txnID, maxID int, txnID, maxID int,
events []gomatrixserverlib.HeaderedEvent, events []gomatrixserverlib.HeaderedEvent,
eventsRemaining bool,
err error, err error,
) { ) {
defer func() { defer func() {
@ -129,7 +129,7 @@ func (s *eventsStatements) selectEventsByApplicationServiceID(
return return
} }
defer checkNamedErr(eventRows.Close, &err) defer checkNamedErr(eventRows.Close, &err)
events, maxID, txnID, eventsRemaining, err = retrieveEvents(eventRows, limit) events, maxID, txnID, err = retrieveEvents(eventRows, limit)
if err != nil { if err != nil {
return return
} }
@ -144,7 +144,7 @@ func checkNamedErr(fn func() error, err *error) {
} }
} }
func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.HeaderedEvent, maxID, txnID int, eventsRemaining bool, err error) { func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.HeaderedEvent, maxID, txnID int, err error) {
// Get current time for use in calculating event age // Get current time for use in calculating event age
nowMilli := time.Now().UnixNano() / int64(time.Millisecond) nowMilli := time.Now().UnixNano() / int64(time.Millisecond)
@ -162,18 +162,18 @@ func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.
&txnID, &txnID,
) )
if err != nil { if err != nil {
return nil, 0, 0, false, err return nil, 0, 0, err
} }
// Unmarshal eventJSON // Unmarshal eventJSON
if err = json.Unmarshal(eventJSON, &event); err != nil { if err = json.Unmarshal(eventJSON, &event); err != nil {
return nil, 0, 0, false, err return nil, 0, 0, err
} }
// If txnID has changed on this event from the previous event, then we've // If txnID has changed on this event from the previous event, then we've
// reached the end of a transaction's events. Return only those events. // reached the end of a transaction's events. Return only those events.
if lastTxnID > invalidTxnID && lastTxnID != txnID { if lastTxnID > invalidTxnID && lastTxnID != txnID {
return events, maxID, lastTxnID, true, nil return events, maxID, lastTxnID, nil
} }
lastTxnID = txnID lastTxnID = txnID
@ -181,7 +181,7 @@ func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.
if txnID == -1 { if txnID == -1 {
// Return if we've hit the limit // Return if we've hit the limit
if eventsProcessed++; eventsProcessed > limit { if eventsProcessed++; eventsProcessed > limit {
return events, maxID, lastTxnID, true, nil return events, maxID, lastTxnID, nil
} }
} }
@ -192,7 +192,7 @@ func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.
// Portion of the event that is unsigned due to rapid change // Portion of the event that is unsigned due to rapid change
// TODO: Consider removing age as not many app services use it // TODO: Consider removing age as not many app services use it
if err = event.SetUnsignedField("age", nowMilli-int64(event.OriginServerTS())); err != nil { if err = event.SetUnsignedField("age", nowMilli-int64(event.OriginServerTS())); err != nil {
return nil, 0, 0, false, err return nil, 0, 0, err
} }
events = append(events, event) events = append(events, event)
@ -201,14 +201,12 @@ func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.
return return
} }
// countEventsByApplicationServiceID inserts an event mapped to its corresponding application service func (s *eventsStatements) getLatestId(
// IDs into the db.
func (s *eventsStatements) countEventsByApplicationServiceID(
ctx context.Context, ctx context.Context,
appServiceID string, appServiceID string,
) (int, error) { ) (int, error) {
var count int var count int
err := s.countEventsByApplicationServiceIDStmt.QueryRowContext(ctx, appServiceID).Scan(&count) err := s.getLatestIdStmt.QueryRowContext(ctx, appServiceID).Scan(&count)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
return 0, err return 0, err
} }
@ -222,22 +220,22 @@ func (s *eventsStatements) insertEvent(
ctx context.Context, ctx context.Context,
appServiceID string, appServiceID string,
event *gomatrixserverlib.HeaderedEvent, event *gomatrixserverlib.HeaderedEvent,
) (err error) { ) (id int, err error) {
// Convert event to JSON before inserting // Convert event to JSON before inserting
eventJSON, err := json.Marshal(event) eventJSON, err := json.Marshal(event)
if err != nil { if err != nil {
return err return 0, err
} }
err = s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { err = s.insertEventStmt.QueryRowContext(
_, err := s.insertEventStmt.ExecContext(
ctx, ctx,
appServiceID, appServiceID,
eventJSON, eventJSON,
-1, // No transaction ID yet -1, // No transaction ID yet
) ).Scan(&id)
return err return err
}) })
return
} }
// updateTxnIDForEvents sets the transactionID for a collection of events. Done // updateTxnIDForEvents sets the transactionID for a collection of events. Done

View file

@ -61,7 +61,7 @@ func (d *Database) StoreEvent(
ctx context.Context, ctx context.Context,
appServiceID string, appServiceID string,
event *gomatrixserverlib.HeaderedEvent, event *gomatrixserverlib.HeaderedEvent,
) error { ) (int, error) {
return d.events.insertEvent(ctx, appServiceID, event) return d.events.insertEvent(ctx, appServiceID, event)
} }
@ -71,17 +71,20 @@ func (d *Database) GetEventsWithAppServiceID(
ctx context.Context, ctx context.Context,
appServiceID string, appServiceID string,
limit int, limit int,
) (int, int, []gomatrixserverlib.HeaderedEvent, bool, error) { ) (int, int, []gomatrixserverlib.HeaderedEvent, error) {
return d.events.selectEventsByApplicationServiceID(ctx, appServiceID, limit) return d.events.selectEventsByApplicationServiceID(ctx, appServiceID, limit)
} }
// CountEventsWithAppServiceID returns the number of events destined for an // GetLatestId returns the latest incremental id associated with appservice.
// application service given its ID. func (d *Database) GetLatestId(
func (d *Database) CountEventsWithAppServiceID(
ctx context.Context, ctx context.Context,
appServiceID string, appServiceID string,
) (int, error) { ) (int, error) {
return d.events.countEventsByApplicationServiceID(ctx, appServiceID) id, err := d.events.getLatestId(ctx, appServiceID)
if err == sql.ErrNoRows {
return 0, nil
}
return id, err
} }
// UpdateTxnIDForEvents takes in an application service ID and a // UpdateTxnIDForEvents takes in an application service ID and a

View file

@ -30,34 +30,26 @@ const (
type ApplicationServiceWorkerState struct { type ApplicationServiceWorkerState struct {
AppService config.ApplicationService AppService config.ApplicationService
Cond *sync.Cond Cond *sync.Cond
// Events ready to be sent // Lastest incremental ID from appservice_events table that is ready to be sent to application service
EventsReady bool latestId int
// Backoff exponent (2^x secs). Max 6, aka 64s. // Backoff exponent (2^x secs). Max 6, aka 64s.
Backoff int Backoff int
} }
// NotifyNewEvents wakes up all waiting goroutines, notifying that events remain // NotifyNewEvents wakes up all waiting goroutines, notifying that events remain
// in the event queue for this application service worker. // in the event queue for this application service worker.
func (a *ApplicationServiceWorkerState) NotifyNewEvents() { func (a *ApplicationServiceWorkerState) NotifyNewEvents(id int) {
a.Cond.L.Lock() a.Cond.L.Lock()
a.EventsReady = true a.latestId = id
a.Cond.Broadcast() a.Cond.Broadcast()
a.Cond.L.Unlock() a.Cond.L.Unlock()
} }
// FinishEventProcessing marks all events of this worker as being sent to the
// application service.
func (a *ApplicationServiceWorkerState) FinishEventProcessing() {
a.Cond.L.Lock()
a.EventsReady = false
a.Cond.L.Unlock()
}
// WaitForNewEvents causes the calling goroutine to wait on the worker state's // WaitForNewEvents causes the calling goroutine to wait on the worker state's
// condition for a broadcast or similar wakeup, if there are no events ready. // condition for a broadcast or similar wakeup, if there are no events ready.
func (a *ApplicationServiceWorkerState) WaitForNewEvents() { func (a *ApplicationServiceWorkerState) WaitForNewEvents(id int) {
a.Cond.L.Lock() a.Cond.L.Lock()
if !a.EventsReady { if a.latestId <= id {
a.Cond.Wait() a.Cond.Wait()
} }
a.Cond.L.Unlock() a.Cond.L.Unlock()

View file

@ -44,7 +44,7 @@ var (
func SetupTransactionWorkers( func SetupTransactionWorkers(
client *http.Client, client *http.Client,
appserviceDB storage.Database, appserviceDB storage.Database,
workerStates []types.ApplicationServiceWorkerState, workerStates []*types.ApplicationServiceWorkerState,
) error { ) error {
// Create a worker that handles transmitting events to a single homeserver // Create a worker that handles transmitting events to a single homeserver
for _, workerState := range workerStates { for _, workerState := range workerStates {
@ -58,31 +58,29 @@ func SetupTransactionWorkers(
// worker is a goroutine that sends any queued events to the application service // worker is a goroutine that sends any queued events to the application service
// it is given. // it is given.
func worker(client *http.Client, db storage.Database, ws types.ApplicationServiceWorkerState) { func worker(client *http.Client, db storage.Database, ws *types.ApplicationServiceWorkerState) {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"appservice": ws.AppService.ID, "appservice": ws.AppService.ID,
}).Info("Starting application service") }).Info("Starting application service")
ctx := context.Background() ctx := context.Background()
// Initial check for any leftover events to send from last time // Initial check for any leftover events to send from last time
eventCount, err := db.CountEventsWithAppServiceID(ctx, ws.AppService.ID) latestId, err := db.GetLatestId(ctx, ws.AppService.ID)
if err != nil { if err != nil {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"appservice": ws.AppService.ID, "appservice": ws.AppService.ID,
}).WithError(err).Fatal("appservice worker unable to read queued events from DB") }).WithError(err).Fatal("appservice worker unable to read queued events from DB")
return return
} }
if eventCount > 0 { ws.NotifyNewEvents(latestId)
ws.NotifyNewEvents() id := 0
}
// Loop forever and keep waiting for more events to send // Loop forever and keep waiting for more events to send
for { for {
// Wait for more events if we've sent all the events in the database // Wait for more events if we've sent all the events in the database
ws.WaitForNewEvents() ws.WaitForNewEvents(id)
// Batch events up into a transaction // Batch events up into a transaction
transactionJSON, txnID, maxEventID, eventsRemaining, err := createTransaction(ctx, db, ws.AppService.ID) transactionJSON, txnID, maxEventID, err := createTransaction(ctx, db, ws.AppService.ID)
if err != nil { if err != nil {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"appservice": ws.AppService.ID, "appservice": ws.AppService.ID,
@ -90,6 +88,10 @@ func worker(client *http.Client, db storage.Database, ws types.ApplicationServic
return return
} }
// Transactions have a maximum event size (or new events may arrive while
// transaction is processed by Application Service), so there may still be
// some events left over to send. We will keep sending if id < ws.latestID.
id = maxEventID
// Send the events off to the application service // Send the events off to the application service
// Backoff if the application service does not respond // Backoff if the application service does not respond
@ -99,19 +101,13 @@ func worker(client *http.Client, db storage.Database, ws types.ApplicationServic
"appservice": ws.AppService.ID, "appservice": ws.AppService.ID,
}).WithError(err).Error("unable to send event") }).WithError(err).Error("unable to send event")
// Backoff // Backoff
backoff(&ws, err) backoff(ws, err)
continue continue
} }
// We sent successfully, hooray! // We sent successfully, hooray!
ws.Backoff = 0 ws.Backoff = 0
// Transactions have a maximum event size, so there may still be some events
// left over to send. Keep sending until none are left
if !eventsRemaining {
ws.FinishEventProcessing()
}
// Remove sent events from the DB // Remove sent events from the DB
err = db.RemoveEventsBeforeAndIncludingID(ctx, ws.AppService.ID, maxEventID) err = db.RemoveEventsBeforeAndIncludingID(ctx, ws.AppService.ID, maxEventID)
if err != nil { if err != nil {
@ -152,11 +148,10 @@ func createTransaction(
) ( ) (
transactionJSON []byte, transactionJSON []byte,
txnID, maxID int, txnID, maxID int,
eventsRemaining bool,
err error, err error,
) { ) {
// Retrieve the latest events from the DB (will return old events if they weren't successfully sent) // Retrieve the latest events from the DB (will return old events if they weren't successfully sent)
txnID, maxID, events, eventsRemaining, err := db.GetEventsWithAppServiceID(ctx, appserviceID, transactionBatchSize) txnID, maxID, events, err := db.GetEventsWithAppServiceID(ctx, appserviceID, transactionBatchSize)
if err != nil { if err != nil {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"appservice": appserviceID, "appservice": appserviceID,
@ -170,12 +165,12 @@ func createTransaction(
// If not, grab next available ID from the DB // If not, grab next available ID from the DB
txnID, err = db.GetLatestTxnID(ctx) txnID, err = db.GetLatestTxnID(ctx)
if err != nil { if err != nil {
return nil, 0, 0, false, err return nil, 0, 0, err
} }
// Mark new events with current transactionID // Mark new events with current transactionID
if err = db.UpdateTxnIDForEvents(ctx, appserviceID, maxID, txnID); err != nil { if err = db.UpdateTxnIDForEvents(ctx, appserviceID, maxID, txnID); err != nil {
return nil, 0, 0, false, err return nil, 0, 0, err
} }
} }

View file

@ -22,7 +22,6 @@ import (
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"net" "net"
"net/http" "net/http"
"os" "os"
@ -212,11 +211,11 @@ func (m *DendriteMonolith) Start() {
if pk, sk, err = ed25519.GenerateKey(nil); err != nil { if pk, sk, err = ed25519.GenerateKey(nil); err != nil {
panic(err) panic(err)
} }
if err = ioutil.WriteFile(keyfile, sk, 0644); err != nil { if err = os.WriteFile(keyfile, sk, 0644); err != nil {
panic(err) panic(err)
} }
} else if err == nil { } else if err == nil {
if sk, err = ioutil.ReadFile(keyfile); err != nil { if sk, err = os.ReadFile(keyfile); err != nil {
panic(err) panic(err)
} }
if len(sk) != ed25519.PrivateKeySize { if len(sk) != ed25519.PrivateKeySize {

View file

@ -11,4 +11,6 @@ const (
LoginTypeRecaptcha = "m.login.recaptcha" LoginTypeRecaptcha = "m.login.recaptcha"
LoginTypeApplicationService = "m.login.application_service" LoginTypeApplicationService = "m.login.application_service"
LoginTypeToken = "m.login.token" LoginTypeToken = "m.login.token"
LoginTypeJwt = "org.matrix.login.jwt"
LoginTypeEmail = "m.login.email.identity"
) )

View file

@ -22,6 +22,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/clientapi/ratelimit"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
uapi "github.com/matrix-org/dendrite/userapi/api" uapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/util" "github.com/matrix-org/util"
@ -32,7 +33,7 @@ import (
// called after authorization has completed, with the result of the authorization. // called after authorization has completed, with the result of the authorization.
// If the final return value is non-nil, an error occurred and the cleanup function // If the final return value is non-nil, an error occurred and the cleanup function
// is nil. // is nil.
func LoginFromJSONReader(ctx context.Context, r io.Reader, useraccountAPI uapi.UserLoginAPI, userAPI UserInternalAPIForLogin, cfg *config.ClientAPI) (*Login, LoginCleanupFunc, *util.JSONResponse) { func LoginFromJSONReader(ctx context.Context, r io.Reader, useraccountAPI uapi.ClientUserAPI, cfg *config.ClientAPI, rt *ratelimit.RtFailedLogin) (*Login, LoginCleanupFunc, *util.JSONResponse) {
reqBytes, err := io.ReadAll(r) reqBytes, err := io.ReadAll(r)
if err != nil { if err != nil {
err := &util.JSONResponse{ err := &util.JSONResponse{
@ -57,12 +58,17 @@ func LoginFromJSONReader(ctx context.Context, r io.Reader, useraccountAPI uapi.U
switch header.Type { switch header.Type {
case authtypes.LoginTypePassword: case authtypes.LoginTypePassword:
typ = &LoginTypePassword{ typ = &LoginTypePassword{
GetAccountByPassword: useraccountAPI.QueryAccountByPassword, UserApi: useraccountAPI,
Config: cfg, Config: cfg,
Rt: rt,
} }
case authtypes.LoginTypeToken: case authtypes.LoginTypeToken:
typ = &LoginTypeToken{ typ = &LoginTypeToken{
UserAPI: userAPI, UserAPI: useraccountAPI,
Config: cfg,
}
case authtypes.LoginTypeJwt:
typ = &LoginTypeTokenJwt{
Config: cfg, Config: cfg,
} }
default: default:

View file

@ -0,0 +1,74 @@
package auth
import (
"context"
"fmt"
"net/http"
"github.com/golang-jwt/jwt/v4"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/util"
)
// LoginTypeToken describes how to authenticate with a login token.
type LoginTypeTokenJwt struct {
// UserAPI uapi.LoginTokenInternalAPI
Config *config.ClientAPI
}
// Name implements Type.
func (t *LoginTypeTokenJwt) Name() string {
return authtypes.LoginTypeJwt
}
type Claims struct {
jwt.StandardClaims
}
const mIdUser = "m.id.user"
// LoginFromJSON implements Type. The cleanup function deletes the token from
// the database on success.
func (t *LoginTypeTokenJwt) LoginFromJSON(ctx context.Context, reqBytes []byte) (*Login, LoginCleanupFunc, *util.JSONResponse) {
var r loginTokenRequest
if err := httputil.UnmarshalJSON(reqBytes, &r); err != nil {
return nil, nil, err
}
if r.Token == "" {
return nil, nil, &util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("Token field for JWT is missing"),
}
}
c := &Claims{}
token, err := jwt.ParseWithClaims(r.Token, c, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodEd25519); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Method.Alg())
}
return t.Config.JwtConfig.SecretKey, nil
})
if err != nil {
util.GetLogger(ctx).WithError(err).Error("jwt.ParseWithClaims failed")
return nil, nil, &util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("Couldn't parse JWT"),
}
}
if !token.Valid {
return nil, nil, &util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("Invalid JWT"),
}
}
r.Login.Identifier.User = c.Subject
r.Login.Identifier.Type = mIdUser
return &r.Login, func(context.Context, *util.JSONResponse) {}, nil
}

View file

@ -22,6 +22,7 @@ import (
"testing" "testing"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/clientapi/ratelimit"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
uapi "github.com/matrix-org/dendrite/userapi/api" uapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/util" "github.com/matrix-org/util"
@ -68,8 +69,11 @@ func TestLoginFromJSONReader(t *testing.T) {
Matrix: &config.Global{ Matrix: &config.Global{
ServerName: serverName, ServerName: serverName,
}, },
RtFailedLogin: ratelimit.RtFailedLoginConfig{
Enabled: false,
},
} }
login, cleanup, err := LoginFromJSONReader(ctx, strings.NewReader(tst.Body), &userAPI, &userAPI, cfg) login, cleanup, err := LoginFromJSONReader(ctx, strings.NewReader(tst.Body), &userAPI, cfg, nil)
if err != nil { if err != nil {
t.Fatalf("LoginFromJSONReader failed: %+v", err) t.Fatalf("LoginFromJSONReader failed: %+v", err)
} }
@ -147,7 +151,7 @@ func TestBadLoginFromJSONReader(t *testing.T) {
ServerName: serverName, ServerName: serverName,
}, },
} }
_, cleanup, errRes := LoginFromJSONReader(ctx, strings.NewReader(tst.Body), &userAPI, &userAPI, cfg) _, cleanup, errRes := LoginFromJSONReader(ctx, strings.NewReader(tst.Body), &userAPI, cfg, nil)
if errRes == nil { if errRes == nil {
cleanup(ctx, nil) cleanup(ctx, nil)
t.Fatalf("LoginFromJSONReader err: got %+v, want code %q", errRes, tst.WantErrCode) t.Fatalf("LoginFromJSONReader err: got %+v, want code %q", errRes, tst.WantErrCode)
@ -159,6 +163,7 @@ func TestBadLoginFromJSONReader(t *testing.T) {
} }
type fakeUserInternalAPI struct { type fakeUserInternalAPI struct {
uapi.ClientUserAPI
UserInternalAPIForLogin UserInternalAPIForLogin
DeletedTokens []string DeletedTokens []string
} }

View file

@ -58,7 +58,7 @@ func (t *LoginTypeToken) LoginFromJSON(ctx context.Context, reqBytes []byte) (*L
} }
} }
r.Login.Identifier.Type = "m.id.user" r.Login.Identifier.Type = mIdUser
r.Login.Identifier.User = res.Data.UserID r.Login.Identifier.User = res.Data.UserID
cleanup := func(ctx context.Context, authRes *util.JSONResponse) { cleanup := func(ctx context.Context, authRes *util.JSONResponse) {

View file

@ -22,6 +22,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/clientapi/ratelimit"
"github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
@ -33,12 +34,17 @@ type GetAccountByPassword func(ctx context.Context, req *api.QueryAccountByPassw
type PasswordRequest struct { type PasswordRequest struct {
Login Login
Password string `json:"password"` Password string `json:"password"`
Address string `json:"address"`
Medium string `json:"medium"`
} }
const email = "email"
// LoginTypePassword implements https://matrix.org/docs/spec/client_server/r0.6.1#password-based // LoginTypePassword implements https://matrix.org/docs/spec/client_server/r0.6.1#password-based
type LoginTypePassword struct { type LoginTypePassword struct {
GetAccountByPassword GetAccountByPassword UserApi api.ClientUserAPI
Config *config.ClientAPI Config *config.ClientAPI
Rt *ratelimit.RtFailedLogin
} }
func (t *LoginTypePassword) Name() string { func (t *LoginTypePassword) Name() string {
@ -61,7 +67,35 @@ func (t *LoginTypePassword) LoginFromJSON(ctx context.Context, reqBytes []byte)
func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, *util.JSONResponse) { func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, *util.JSONResponse) {
r := req.(*PasswordRequest) r := req.(*PasswordRequest)
username := strings.ToLower(r.Username()) if r.Identifier.Address != "" {
r.Address = r.Identifier.Address
}
if r.Identifier.Medium != "" {
r.Medium = r.Identifier.Medium
}
var username string
if r.Medium == email && r.Address != "" {
r.Address = strings.ToLower(r.Address)
res := api.QueryLocalpartForThreePIDResponse{}
err := t.UserApi.QueryLocalpartForThreePID(ctx, &api.QueryLocalpartForThreePIDRequest{
ThreePID: r.Address,
Medium: email,
}, &res)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("userApi.QueryLocalpartForThreePID failed")
resp := jsonerror.InternalServerError()
return nil, &resp
}
username = res.Localpart
if username == "" {
return nil, &util.JSONResponse{
Code: http.StatusUnauthorized,
JSON: jsonerror.Forbidden("Invalid username or password"),
}
}
} else {
username = strings.ToLower(r.Username())
}
if username == "" { if username == "" {
return nil, &util.JSONResponse{ return nil, &util.JSONResponse{
Code: http.StatusUnauthorized, Code: http.StatusUnauthorized,
@ -77,7 +111,17 @@ func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login,
} }
// Squash username to all lowercase letters // Squash username to all lowercase letters
res := &api.QueryAccountByPasswordResponse{} res := &api.QueryAccountByPasswordResponse{}
err = t.GetAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{Localpart: strings.ToLower(localpart), PlaintextPassword: r.Password}, res) localpart = strings.ToLower(localpart)
if t.Rt != nil {
ok, retryIn := t.Rt.CanAct(localpart)
if !ok {
return nil, &util.JSONResponse{
Code: http.StatusTooManyRequests,
JSON: jsonerror.LimitExceeded("Too Many Requests", retryIn.Milliseconds()),
}
}
}
err = t.UserApi.QueryAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{Localpart: localpart, PlaintextPassword: r.Password}, res)
if err != nil { if err != nil {
return nil, &util.JSONResponse{ return nil, &util.JSONResponse{
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,
@ -86,7 +130,7 @@ func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login,
} }
if !res.Exists { if !res.Exists {
err = t.GetAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{ err = t.UserApi.QueryAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{
Localpart: localpart, Localpart: localpart,
PlaintextPassword: r.Password, PlaintextPassword: r.Password,
}, res) }, res)
@ -99,11 +143,15 @@ func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login,
// Technically we could tell them if the user does not exist by checking if err == sql.ErrNoRows // Technically we could tell them if the user does not exist by checking if err == sql.ErrNoRows
// but that would leak the existence of the user. // but that would leak the existence of the user.
if !res.Exists { if !res.Exists {
if t.Rt != nil {
t.Rt.Act(localpart)
}
return nil, &util.JSONResponse{ return nil, &util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("The username or password was incorrect or the account does not exist."), JSON: jsonerror.Forbidden("Invalid username or password"),
} }
} }
} }
r.Login.User = username
return &r.Login, nil return &r.Login, nil
} }

View file

@ -75,7 +75,7 @@ type Login struct {
// Username returns the user localpart/user_id in this request, if it exists. // Username returns the user localpart/user_id in this request, if it exists.
func (r *Login) Username() string { func (r *Login) Username() string {
if r.Identifier.Type == "m.id.user" { if r.Identifier.Type == mIdUser {
return r.Identifier.User return r.Identifier.User
} }
// deprecated but without it Element iOS won't log in // deprecated but without it Element iOS won't log in
@ -88,8 +88,8 @@ func (r *Login) ThirdPartyID() (medium, address string) {
return r.Identifier.Medium, r.Identifier.Address return r.Identifier.Medium, r.Identifier.Address
} }
// deprecated // deprecated
if r.Medium == "email" { if r.Medium == email {
return "email", r.Address return email, r.Address
} }
return "", "" return "", ""
} }
@ -111,9 +111,9 @@ type UserInteractive struct {
Sessions map[string][]string Sessions map[string][]string
} }
func NewUserInteractive(userAccountAPI api.UserLoginAPI, cfg *config.ClientAPI) *UserInteractive { func NewUserInteractive(userAccountAPI api.ClientUserAPI, cfg *config.ClientAPI) *UserInteractive {
typePassword := &LoginTypePassword{ typePassword := &LoginTypePassword{
GetAccountByPassword: userAccountAPI.QueryAccountByPassword, UserApi: userAccountAPI,
Config: cfg, Config: cfg,
} }
return &UserInteractive{ return &UserInteractive{

View file

@ -24,7 +24,9 @@ var (
} }
) )
type fakeAccountDatabase struct{} type fakeAccountDatabase struct {
api.ClientUserAPI
}
func (d *fakeAccountDatabase) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error { func (d *fakeAccountDatabase) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error {
return nil return nil

View file

@ -0,0 +1,117 @@
package ratelimit
import (
"container/list"
"sync"
"time"
)
type rateLimit struct {
cfg *RtFailedLoginConfig
times *list.List
}
type RtFailedLogin struct {
cfg *RtFailedLoginConfig
mtx sync.RWMutex
rts map[string]*rateLimit
}
type RtFailedLoginConfig struct {
Enabled bool `yaml:"enabled"`
Limit int `yaml:"burst"`
Interval time.Duration `yaml:"interval"`
}
// New creates a new rate limiter for the limit and interval.
func NewRtFailedLogin(cfg *RtFailedLoginConfig) *RtFailedLogin {
if !cfg.Enabled {
return nil
}
rt := &RtFailedLogin{
cfg: cfg,
mtx: sync.RWMutex{},
rts: make(map[string]*rateLimit),
}
go rt.clean()
return rt
}
// CanAct is expected to be called before Act
func (r *RtFailedLogin) CanAct(key string) (ok bool, remaining time.Duration) {
r.mtx.RLock()
rt, ok := r.rts[key]
if !ok {
r.mtx.RUnlock()
return true, 0
}
ok, remaining = rt.canAct()
r.mtx.RUnlock()
return
}
// Act can be called after CanAct returns true.
func (r *RtFailedLogin) Act(key string) {
r.mtx.Lock()
rt, ok := r.rts[key]
if !ok {
rt = &rateLimit{
cfg: r.cfg,
times: list.New(),
}
r.rts[key] = rt
}
rt.act()
r.mtx.Unlock()
}
func (r *RtFailedLogin) clean() {
for {
r.mtx.Lock()
for k, v := range r.rts {
if v.empty() {
delete(r.rts, k)
}
}
r.mtx.Unlock()
time.Sleep(time.Hour)
}
}
func (r *rateLimit) empty() bool {
back := r.times.Back()
if back == nil {
return true
}
v := back.Value
b := v.(time.Time)
now := time.Now()
return now.Sub(b) > r.cfg.Interval
}
func (r *rateLimit) canAct() (ok bool, remaining time.Duration) {
now := time.Now()
l := r.times.Len()
if l < r.cfg.Limit {
return true, 0
}
frnt := r.times.Front()
t := frnt.Value.(time.Time)
diff := now.Sub(t)
if diff < r.cfg.Interval {
return false, r.cfg.Interval - diff
}
return true, 0
}
func (r *rateLimit) act() {
now := time.Now()
l := r.times.Len()
if l < r.cfg.Limit {
r.times.PushBack(now)
return
}
frnt := r.times.Front()
frnt.Value = now
r.times.MoveToBack(frnt)
}

View file

@ -0,0 +1,40 @@
package ratelimit
import (
"testing"
"time"
"github.com/matryer/is"
)
func TestRtFailedLogin(t *testing.T) {
is := is.New(t)
rtfl := NewRtFailedLogin(&RtFailedLoginConfig{
Enabled: true,
Limit: 3,
Interval: 10 * time.Millisecond,
})
var (
can bool
remaining time.Duration
remainingB time.Duration
)
for i := 0; i < 3; i++ {
can, remaining = rtfl.CanAct("foo")
is.True(can)
is.Equal(remaining, time.Duration(0))
rtfl.Act("foo")
}
can, remaining = rtfl.CanAct("foo")
is.True(!can)
is.True(remaining > time.Millisecond*9)
can, remainingB = rtfl.CanAct("bar")
is.True(can)
is.Equal(remainingB, time.Duration(0))
rtfl.Act("bar")
rtfl.Act("bar")
time.Sleep(remaining + time.Millisecond)
can, remaining = rtfl.CanAct("foo")
is.True(can)
is.Equal(remaining, time.Duration(0))
}

View file

@ -82,7 +82,7 @@ func GetEvent(
}}, }},
} }
var stateResp api.QueryStateAfterEventsResponse var stateResp api.QueryStateAfterEventsResponse
if err := rsAPI.QueryStateAfterEvents(req.Context(), &stateReq, &stateResp); err != nil { if err = rsAPI.QueryStateAfterEvents(req.Context(), &stateReq, &stateResp); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("queryAPI.QueryStateAfterEvents failed") util.GetLogger(req.Context()).WithError(err).Error("queryAPI.QueryStateAfterEvents failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
@ -118,12 +118,13 @@ func GetEvent(
} else if !stateEvent.StateKeyEquals(device.UserID) { } else if !stateEvent.StateKeyEquals(device.UserID) {
continue continue
} }
membership, err := stateEvent.Membership() var membership string
membership, err = stateEvent.Membership()
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("stateEvent.Membership failed") util.GetLogger(req.Context()).WithError(err).Error("stateEvent.Membership failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
if membership == gomatrixserverlib.Join { if membership == gomatrixserverlib.Join || membership == gomatrixserverlib.Invite {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
JSON: gomatrixserverlib.ToClientEvent(r.requestedEvent, gomatrixserverlib.FormatAll), JSON: gomatrixserverlib.ToClientEvent(r.requestedEvent, gomatrixserverlib.FormatAll),
@ -131,8 +132,28 @@ func GetEvent(
} }
} }
// we might fail to retrieve correct state above, let's check user membership and allow to fetch event if they are invited or joined, since we always use m.room.history_visibility shared.
var membershipRes api.QueryMembershipForUserResponse
ctx := req.Context()
err = rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{
RoomID: roomID,
UserID: device.UserID,
}, &membershipRes)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("Failed to QueryMembershipForUser")
return jsonerror.InternalServerError()
}
// If the user has never been in the room then stop at this point.
// We won't tell the user about a room they have never joined.
if !membershipRes.HasBeenInRoom && membershipRes.Membership != gomatrixserverlib.Invite || membershipRes.Membership == gomatrixserverlib.Ban {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusNotFound, Code: http.StatusNotFound,
JSON: jsonerror.NotFound("The event was not found or you do not have permission to read this event"), JSON: jsonerror.NotFound("The event was not found or you do not have permission to read this event"),
} }
} else {
return util.JSONResponse{
Code: http.StatusOK,
JSON: gomatrixserverlib.ToClientEvent(r.requestedEvent, gomatrixserverlib.FormatAll),
}
}
} }

View file

@ -63,7 +63,7 @@ func UploadCrossSigningDeviceKeys(
} }
} }
typePassword := auth.LoginTypePassword{ typePassword := auth.LoginTypePassword{
GetAccountByPassword: accountAPI.QueryAccountByPassword, UserApi: accountAPI,
Config: cfg, Config: cfg,
} }
if _, authErr := typePassword.Login(req.Context(), &uploadReq.Auth.PasswordRequest); authErr != nil { if _, authErr := typePassword.Login(req.Context(), &uploadReq.Auth.PasswordRequest); authErr != nil {

View file

@ -20,6 +20,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth" "github.com/matrix-org/dendrite/clientapi/auth"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/clientapi/ratelimit"
"github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
@ -55,6 +56,7 @@ func passwordLogin() flows {
func Login( func Login(
req *http.Request, userAPI userapi.ClientUserAPI, req *http.Request, userAPI userapi.ClientUserAPI,
cfg *config.ClientAPI, cfg *config.ClientAPI,
rt *ratelimit.RtFailedLogin,
) util.JSONResponse { ) util.JSONResponse {
if req.Method == http.MethodGet { if req.Method == http.MethodGet {
// TODO: support other forms of login other than password, depending on config options // TODO: support other forms of login other than password, depending on config options
@ -63,7 +65,7 @@ func Login(
JSON: passwordLogin(), JSON: passwordLogin(),
} }
} else if req.Method == http.MethodPost { } else if req.Method == http.MethodPost {
login, cleanup, authErr := auth.LoginFromJSONReader(req.Context(), req.Body, userAPI, userAPI, cfg) login, cleanup, authErr := auth.LoginFromJSONReader(req.Context(), req.Body, userAPI, cfg, rt)
if authErr != nil { if authErr != nil {
return *authErr return *authErr
} }

View file

@ -1,12 +1,14 @@
package routing package routing
import ( import (
"fmt"
"net/http" "net/http"
"github.com/matrix-org/dendrite/clientapi/auth" "github.com/matrix-org/dendrite/clientapi/auth"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/clientapi/threepid"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -24,6 +26,7 @@ type newPasswordAuth struct {
Type string `json:"type"` Type string `json:"type"`
Session string `json:"session"` Session string `json:"session"`
auth.PasswordRequest auth.PasswordRequest
ThreePidCreds threepid.Credentials `json:"threepid_creds"`
} }
func Password( func Password(
@ -33,13 +36,17 @@ func Password(
cfg *config.ClientAPI, cfg *config.ClientAPI,
) util.JSONResponse { ) util.JSONResponse {
// Check that the existing password is right. // Check that the existing password is right.
var fields logrus.Fields
if device != nil {
fields = logrus.Fields{
"sessionId": device.SessionID,
"userId": device.UserID,
}
}
var r newPasswordRequest var r newPasswordRequest
r.LogoutDevices = true r.LogoutDevices = true
logrus.WithFields(logrus.Fields{ logrus.WithFields(fields).Debug("Changing password")
"sessionId": device.SessionID,
"userId": device.UserID,
}).Debug("Changing password")
// Unmarshal the request. // Unmarshal the request.
resErr := httputil.UnmarshalJSONRequest(req, &r) resErr := httputil.UnmarshalJSONRequest(req, &r)
@ -53,45 +60,95 @@ func Password(
// Generate a new, random session ID // Generate a new, random session ID
sessionID = util.RandomString(sessionIDLength) sessionID = util.RandomString(sessionIDLength)
} }
var localpart string
// Require password auth to change the password. switch r.Auth.Type {
if r.Auth.Type != authtypes.LoginTypePassword { case authtypes.LoginTypePassword:
return util.JSONResponse{
Code: http.StatusUnauthorized,
JSON: newUserInteractiveResponse(
sessionID,
[]authtypes.Flow{
{
Stages: []authtypes.LoginType{authtypes.LoginTypePassword},
},
},
nil,
),
}
}
// Check if the existing password is correct. // Check if the existing password is correct.
typePassword := auth.LoginTypePassword{ typePassword := auth.LoginTypePassword{
GetAccountByPassword: userAPI.QueryAccountByPassword, UserApi: userAPI,
Config: cfg, Config: cfg,
} }
if _, authErr := typePassword.Login(req.Context(), &r.Auth.PasswordRequest); authErr != nil { if _, authErr := typePassword.Login(req.Context(), &r.Auth.PasswordRequest); authErr != nil {
return *authErr return *authErr
} }
// Get the local part.
var err error
localpart, _, err = gomatrixserverlib.SplitID('@', device.UserID)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError()
}
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword) sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword)
case authtypes.LoginTypeEmail:
threePid := &authtypes.ThreePID{}
r.Auth.ThreePidCreds.IDServer = cfg.ThreePidDelegate
var (
bound bool
err error
)
bound, threePid.Address, threePid.Medium, err = threepid.CheckAssociation(req.Context(), r.Auth.ThreePidCreds, cfg)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("threepid.CheckAssociation failed")
return jsonerror.InternalServerError()
}
if !bound {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.MatrixError{
ErrCode: "M_THREEPID_AUTH_FAILED",
Err: "Failed to auth 3pid",
},
}
}
var res api.QueryLocalpartForThreePIDResponse
err = userAPI.QueryLocalpartForThreePID(req.Context(), &api.QueryLocalpartForThreePIDRequest{
Medium: threePid.Medium,
ThreePID: threePid.Address,
}, &res)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryLocalpartForThreePID failed")
return jsonerror.InternalServerError()
}
if res.Localpart == "" {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.MatrixError{
ErrCode: "M_THREEPID_NOT_FOUND",
Err: "3pid is not bound to any account",
},
}
}
localpart = res.Localpart
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeEmail)
default:
flows := []authtypes.Flow{
{
Stages: []authtypes.LoginType{authtypes.LoginTypePassword},
},
}
if cfg.ThreePidDelegate != "" {
flows = append(flows, authtypes.Flow{
Stages: []authtypes.LoginType{authtypes.LoginTypeEmail},
})
}
// Require password auth to change the password.
if r.Auth.Type == authtypes.LoginTypePassword {
return util.JSONResponse{
Code: http.StatusUnauthorized,
JSON: newUserInteractiveResponse(
sessionID,
flows,
nil,
),
}
}
}
// Check the new password strength. // Check the new password strength.
if resErr = validatePassword(r.NewPassword); resErr != nil { if resErr = validatePassword(r.NewPassword); resErr != nil {
return *resErr return *resErr
} }
// Get the local part.
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError()
}
// Ask the user API to perform the password change. // Ask the user API to perform the password change.
passwordReq := &api.PerformPasswordUpdateRequest{ passwordReq := &api.PerformPasswordUpdateRequest{
Localpart: localpart, Localpart: localpart,
@ -109,12 +166,24 @@ func Password(
// If the request asks us to log out all other devices then // If the request asks us to log out all other devices then
// ask the user API to do that. // ask the user API to do that.
if r.LogoutDevices { if r.LogoutDevices {
logoutReq := &api.PerformDeviceDeletionRequest{ var logoutReq *api.PerformDeviceDeletionRequest
var sessionId int64
if device == nil {
logoutReq = &api.PerformDeviceDeletionRequest{
UserID: fmt.Sprintf("@%s:%s", localpart, cfg.Matrix.ServerName),
DeviceIDs: []string{},
}
sessionId = 0
} else {
logoutReq = &api.PerformDeviceDeletionRequest{
UserID: device.UserID, UserID: device.UserID,
DeviceIDs: nil, DeviceIDs: nil,
ExceptDeviceID: device.ID, ExceptDeviceID: device.ID,
} }
sessionId = device.SessionID
}
logoutRes := &api.PerformDeviceDeletionResponse{} logoutRes := &api.PerformDeviceDeletionResponse{}
if err := userAPI.PerformDeviceDeletion(req.Context(), logoutReq, logoutRes); err != nil { if err := userAPI.PerformDeviceDeletion(req.Context(), logoutReq, logoutRes); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("PerformDeviceDeletion failed") util.GetLogger(req.Context()).WithError(err).Error("PerformDeviceDeletion failed")
@ -123,7 +192,7 @@ func Password(
pushersReq := &api.PerformPusherDeletionRequest{ pushersReq := &api.PerformPusherDeletionRequest{
Localpart: localpart, Localpart: localpart,
SessionID: device.SessionID, SessionID: sessionId,
} }
if err := userAPI.PerformPusherDeletion(req.Context(), pushersReq, &struct{}{}); err != nil { if err := userAPI.PerformPusherDeletion(req.Context(), pushersReq, &struct{}{}); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("PerformPusherDeletion failed") util.GetLogger(req.Context()).WithError(err).Error("PerformPusherDeletion failed")

View file

@ -105,12 +105,6 @@ func SetAvatarURL(
if resErr := httputil.UnmarshalJSONRequest(req, &r); resErr != nil { if resErr := httputil.UnmarshalJSONRequest(req, &r); resErr != nil {
return *resErr return *resErr
} }
if r.AvatarURL == "" {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON("'avatar_url' must be supplied."),
}
}
localpart, _, err := gomatrixserverlib.SplitID('@', userID) localpart, _, err := gomatrixserverlib.SplitID('@', userID)
if err != nil { if err != nil {

View file

@ -86,8 +86,8 @@ func SetPusher(
if err != nil { if err != nil {
return invalidParam("malformed url passed") return invalidParam("malformed url passed")
} }
if pushUrl.Scheme != "https" { if pushUrl.Scheme != "https" && pushUrl.Scheme != "http" {
return invalidParam("only https scheme is allowed") return invalidParam("only https and http schemes are allowed")
} }
} }

View file

@ -44,6 +44,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/clientapi/threepid"
"github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/clientapi/userutil"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
) )
@ -237,6 +238,7 @@ type authDict struct {
// Recaptcha // Recaptcha
Response string `json:"response"` Response string `json:"response"`
// TODO: Lots of custom keys depending on the type // TODO: Lots of custom keys depending on the type
ThreePidCreds threepid.Credentials `json:"threepid_creds"`
} }
// http://matrix.org/speculator/spec/HEAD/client_server/unstable.html#user-interactive-authentication-api // http://matrix.org/speculator/spec/HEAD/client_server/unstable.html#user-interactive-authentication-api
@ -745,6 +747,7 @@ func handleRegistrationFlow(
} }
} }
var threePid *authtypes.ThreePID
switch r.Auth.Type { switch r.Auth.Type {
case authtypes.LoginTypeRecaptcha: case authtypes.LoginTypeRecaptcha:
// Check given captcha response // Check given captcha response
@ -761,6 +764,29 @@ func handleRegistrationFlow(
// Add Dummy to the list of completed registration stages // Add Dummy to the list of completed registration stages
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeDummy) sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeDummy)
case authtypes.LoginTypeEmail:
threePid = &authtypes.ThreePID{}
r.Auth.ThreePidCreds.IDServer = cfg.ThreePidDelegate
var (
bound bool
err error
)
bound, threePid.Address, threePid.Medium, err = threepid.CheckAssociation(req.Context(), r.Auth.ThreePidCreds, cfg)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("threepid.CheckAssociation failed")
return jsonerror.InternalServerError()
}
if !bound {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.MatrixError{
ErrCode: "M_THREEPID_AUTH_FAILED",
Err: "Failed to auth 3pid",
},
}
}
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeEmail)
case "": case "":
// An empty auth type means that we want to fetch the available // An empty auth type means that we want to fetch the available
// flows. It can also mean that we want to register as an appservice // flows. It can also mean that we want to register as an appservice
@ -776,7 +802,7 @@ func handleRegistrationFlow(
// A response with current registration flow and remaining available methods // A response with current registration flow and remaining available methods
// will be returned if a flow has not been successfully completed yet // will be returned if a flow has not been successfully completed yet
return checkAndCompleteFlow(sessions.getCompletedStages(sessionID), return checkAndCompleteFlow(sessions.getCompletedStages(sessionID),
req, r, sessionID, cfg, userAPI) req, r, sessionID, cfg, userAPI, threePid)
} }
// handleApplicationServiceRegistration handles the registration of an // handleApplicationServiceRegistration handles the registration of an
@ -818,7 +844,7 @@ func handleApplicationServiceRegistration(
// application service registration is entirely separate. // application service registration is entirely separate.
return completeRegistration( return completeRegistration(
req.Context(), userAPI, r.Username, "", appserviceID, req.RemoteAddr, req.UserAgent(), r.Auth.Session, req.Context(), userAPI, r.Username, "", appserviceID, req.RemoteAddr, req.UserAgent(), r.Auth.Session,
r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeAppService, r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeAppService, nil,
) )
} }
@ -832,12 +858,13 @@ func checkAndCompleteFlow(
sessionID string, sessionID string,
cfg *config.ClientAPI, cfg *config.ClientAPI,
userAPI userapi.ClientUserAPI, userAPI userapi.ClientUserAPI,
threePid *authtypes.ThreePID,
) util.JSONResponse { ) util.JSONResponse {
if checkFlowCompleted(flow, cfg.Derived.Registration.Flows) { if checkFlowCompleted(flow, cfg.Derived.Registration.Flows) {
// This flow was completed, registration can continue // This flow was completed, registration can continue
return completeRegistration( return completeRegistration(
req.Context(), userAPI, r.Username, r.Password, "", req.RemoteAddr, req.UserAgent(), sessionID, req.Context(), userAPI, r.Username, r.Password, "", req.RemoteAddr, req.UserAgent(), sessionID,
r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeUser, r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeUser, threePid,
) )
} }
sessions.addParams(sessionID, r) sessions.addParams(sessionID, r)
@ -863,6 +890,7 @@ func completeRegistration(
inhibitLogin eventutil.WeakBoolean, inhibitLogin eventutil.WeakBoolean,
displayName, deviceID *string, displayName, deviceID *string,
accType userapi.AccountType, accType userapi.AccountType,
threePid *authtypes.ThreePID,
) util.JSONResponse { ) util.JSONResponse {
if username == "" { if username == "" {
return util.JSONResponse{ return util.JSONResponse{
@ -901,6 +929,21 @@ func completeRegistration(
// Increment prometheus counter for created users // Increment prometheus counter for created users
amtRegUsers.Inc() amtRegUsers.Inc()
// TODO-entry refuse register if threepid is already bound to account.
if threePid != nil {
err = userAPI.PerformSaveThreePIDAssociation(ctx, &userapi.PerformSaveThreePIDAssociationRequest{
Medium: threePid.Medium,
ThreePID: threePid.Address,
Localpart: accRes.Account.Localpart,
}, &struct{}{})
if err != nil {
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: jsonerror.Unknown("Failed to save 3PID association: " + err.Error()),
}
}
}
// Check whether inhibit_login option is set. If so, don't create an access // Check whether inhibit_login option is set. If so, don't create an access
// token or a device for this user // token or a device for this user
if inhibitLogin { if inhibitLogin {
@ -1092,5 +1135,5 @@ func handleSharedSecretRegistration(userAPI userapi.ClientUserAPI, sr *SharedSec
if ssrr.Admin { if ssrr.Admin {
accType = userapi.AccountTypeAdmin accType = userapi.AccountTypeAdmin
} }
return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), "", false, &ssrr.User, &deviceID, accType) return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), "", false, &ssrr.User, &deviceID, accType, nil)
} }

View file

@ -26,6 +26,7 @@ import (
clientutil "github.com/matrix-org/dendrite/clientapi/httputil" clientutil "github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/clientapi/producers" "github.com/matrix-org/dendrite/clientapi/producers"
"github.com/matrix-org/dendrite/clientapi/ratelimit"
federationAPI "github.com/matrix-org/dendrite/federationapi/api" federationAPI "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/internal/transactions" "github.com/matrix-org/dendrite/internal/transactions"
@ -65,6 +66,7 @@ func Setup(
prometheus.MustRegister(amtRegUsers, sendEventDuration) prometheus.MustRegister(amtRegUsers, sendEventDuration)
rateLimits := httputil.NewRateLimits(&cfg.RateLimiting) rateLimits := httputil.NewRateLimits(&cfg.RateLimiting)
rateLimitsFailedLogin := ratelimit.NewRtFailedLogin(&cfg.RtFailedLogin)
userInteractiveAuth := auth.NewUserInteractive(userAPI, cfg) userInteractiveAuth := auth.NewUserInteractive(userAPI, cfg)
unstableFeatures := map[string]bool{ unstableFeatures := map[string]bool{
@ -570,7 +572,7 @@ func Setup(
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/account/password", v3mux.Handle("/account/password",
httputil.MakeAuthAPI("password", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeConditionalAuthAPI("password", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req, device); r != nil { if r := rateLimits.Limit(req, device); r != nil {
return *r return *r
} }
@ -594,7 +596,7 @@ func Setup(
if r := rateLimits.Limit(req, nil); r != nil { if r := rateLimits.Limit(req, nil); r != nil {
return *r return *r
} }
return Login(req, userAPI, cfg) return Login(req, userAPI, cfg, rateLimitsFailedLogin)
}), }),
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)

View file

@ -101,7 +101,7 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a
} }
// If the user has never been in the room then stop at this point. // If the user has never been in the room then stop at this point.
// We won't tell the user about a room they have never joined. // We won't tell the user about a room they have never joined.
if !membershipRes.HasBeenInRoom { if !membershipRes.HasBeenInRoom && membershipRes.Membership != gomatrixserverlib.Invite {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,
JSON: jsonerror.Forbidden(fmt.Sprintf("Unknown room %q or user %q has never joined this room", roomID, device.UserID)), JSON: jsonerror.Forbidden(fmt.Sprintf("Unknown room %q or user %q has never joined this room", roomID, device.UserID)),
@ -241,7 +241,7 @@ func OnIncomingStateTypeRequest(
} }
// If the user has never been in the room then stop at this point. // If the user has never been in the room then stop at this point.
// We won't tell the user about a room they have never joined. // We won't tell the user about a room they have never joined.
if !membershipRes.HasBeenInRoom || membershipRes.Membership == gomatrixserverlib.Ban { if !membershipRes.HasBeenInRoom && membershipRes.Membership != gomatrixserverlib.Invite || membershipRes.Membership == gomatrixserverlib.Ban {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,
JSON: jsonerror.Forbidden(fmt.Sprintf("Unknown room %q or user %q has never joined this room", roomID, device.UserID)), JSON: jsonerror.Forbidden(fmt.Sprintf("Unknown room %q or user %q has never joined this room", roomID, device.UserID)),

View file

@ -103,11 +103,8 @@ func CreateSession(
func CheckAssociation( func CheckAssociation(
ctx context.Context, creds Credentials, cfg *config.ClientAPI, ctx context.Context, creds Credentials, cfg *config.ClientAPI,
) (bool, string, string, error) { ) (bool, string, string, error) {
if err := isTrusted(creds.IDServer, cfg); err != nil {
return false, "", "", err
}
requestURL := fmt.Sprintf("https://%s/_matrix/identity/api/v1/3pid/getValidated3pid?sid=%s&client_secret=%s", creds.IDServer, creds.SID, creds.Secret) requestURL := fmt.Sprintf("%s/_matrix/identity/api/v1/3pid/getValidated3pid?sid=%s&client_secret=%s", cfg.ThreePidDelegate, creds.SID, creds.Secret)
req, err := http.NewRequest(http.MethodGet, requestURL, nil) req, err := http.NewRequest(http.MethodGet, requestURL, nil)
if err != nil { if err != nil {
return false, "", "", err return false, "", "", err

View file

@ -0,0 +1,8 @@
FROM alpine:latest
COPY dendrite-monolith-server /usr/bin/
VOLUME /etc/dendrite
WORKDIR /etc/dendrite
ENTRYPOINT ["/usr/bin/dendrite-monolith-server"]

View file

@ -0,0 +1,12 @@
set -xe
if [ -z "$(git status --porcelain)" ]; then
CGO_ENABLED=0 go build .
TAG=$(git rev-parse --short HEAD)
docker build -f Dockerfile.dev -t gcr.io/globekeeper-development/dendrite-monolith:$TAG -t gcr.io/globekeeper-development/dendrite-monolith -t gcr.io/globekeeper-production/dendrite-monolith:$TAG .
docker push gcr.io/globekeeper-development/dendrite-monolith:$TAG
docker push gcr.io/globekeeper-production/dendrite-monolith:$TAG
docker push gcr.io/globekeeper-development/dendrite-monolith
else
echo "Please commit changes"
exit 0
fi

View file

@ -16,6 +16,7 @@ package main
import ( import (
"flag" "flag"
"log"
"os" "os"
"github.com/matrix-org/dendrite/appservice" "github.com/matrix-org/dendrite/appservice"
@ -47,6 +48,16 @@ var (
func main() { func main() {
cfg := setup.ParseFlags(true) cfg := setup.ParseFlags(true)
httpAddr := config.HTTPAddress("http://" + *httpBindAddr) httpAddr := config.HTTPAddress("http://" + *httpBindAddr)
for _, logging := range cfg.Logging {
if logging.Type == "std" {
level, err := logrus.ParseLevel(logging.Level)
if err != nil {
log.Fatal(err)
}
logrus.SetLevel(level)
logrus.SetFormatter(&logrus.JSONFormatter{})
}
}
httpsAddr := config.HTTPAddress("https://" + *httpsBindAddr) httpsAddr := config.HTTPAddress("https://" + *httpsBindAddr)
httpAPIAddr := httpAddr httpAPIAddr := httpAddr
options := []basepkg.BaseDendriteOptions{} options := []basepkg.BaseDendriteOptions{}

View file

@ -34,8 +34,11 @@ type JSServer struct {
// OnRequestFromJS is the function that JS will invoke when there is a new request. // OnRequestFromJS is the function that JS will invoke when there is a new request.
// The JS function signature is: // The JS function signature is:
//
// function(reqString: string): Promise<{result: string, error: string}> // function(reqString: string): Promise<{result: string, error: string}>
//
// Usage is like: // Usage is like:
//
// const res = await global._go_js_server.fetch(reqString); // const res = await global._go_js_server.fetch(reqString);
// if (res.error) { // if (res.error) {
// // handle error: this is a 'network' error, not a non-2xx error. // // handle error: this is a 'network' error, not a non-2xx error.

View file

@ -0,0 +1,71 @@
---
title: Optimise your installation
parent: Installation
has_toc: true
nav_order: 10
permalink: /installation/start/optimisation
---
# Optimise your installation
Now that you have Dendrite running, the following tweaks will improve the reliability
and performance of your installation.
## File descriptor limit
Most platforms have a limit on how many file descriptors a single process can open. All
connections made by Dendrite consume file descriptors — this includes database connections
and network requests to remote homeservers. When participating in large federated rooms
where Dendrite must talk to many remote servers, it is often very easy to exhaust default
limits which are quite low.
We currently recommend setting the file descriptor limit to 65535 to avoid such
issues. Dendrite will log immediately after startup if the file descriptor limit is too low:
```
level=warning msg="IMPORTANT: Process file descriptor limit is currently 1024, it is recommended to raise the limit for Dendrite to at least 65535 to avoid issues"
```
UNIX systems have two limits: a hard limit and a soft limit. You can view the soft limit
by running `ulimit -Sn` and the hard limit with `ulimit -Hn`:
```bash
$ ulimit -Hn
1048576
$ ulimit -Sn
1024
```
Increase the soft limit before starting Dendrite:
```bash
ulimit -Sn 65535
```
The log line at startup should no longer appear if the limit is sufficient.
If you are running under a systemd service, you can instead add `LimitNOFILE=65535` option
to the `[Service]` section of your service unit file.
## DNS caching
Dendrite has a built-in DNS cache which significantly reduces the load that Dendrite will
place on your DNS resolver. This may also speed up outbound federation.
Consider enabling the DNS cache by modifying the `global` section of your configuration file:
```yaml
dns_cache:
enabled: true
cache_size: 4096
cache_lifetime: 600s
```
## Time synchronisation
Matrix relies heavily on TLS which requires the system time to be correct. If the clock
drifts then you may find that federation no works reliably (or at all) and clients may
struggle to connect to your Dendrite server.
Ensure that the time is synchronised on your system by enabling NTP sync.

5
go.mod
View file

@ -4,13 +4,14 @@ require (
github.com/Arceliar/ironwood v0.0.0-20220306165321-319147a02d98 github.com/Arceliar/ironwood v0.0.0-20220306165321-319147a02d98
github.com/Arceliar/phony v0.0.0-20210209235338-dde1a8dca979 github.com/Arceliar/phony v0.0.0-20210209235338-dde1a8dca979
github.com/DATA-DOG/go-sqlmock v1.5.0 github.com/DATA-DOG/go-sqlmock v1.5.0
github.com/MFAshby/stdemuxerhook v1.0.0
github.com/Masterminds/semver/v3 v3.1.1 github.com/Masterminds/semver/v3 v3.1.1
github.com/codeclysm/extract v2.2.0+incompatible github.com/codeclysm/extract v2.2.0+incompatible
github.com/dgraph-io/ristretto v0.1.1-0.20220403145359-8e850b710d6d github.com/dgraph-io/ristretto v0.1.1-0.20220403145359-8e850b710d6d
github.com/docker/docker v20.10.16+incompatible github.com/docker/docker v20.10.16+incompatible
github.com/docker/go-connections v0.4.0 github.com/docker/go-connections v0.4.0
github.com/getsentry/sentry-go v0.13.0 github.com/getsentry/sentry-go v0.13.0
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang-jwt/jwt/v4 v4.4.1
github.com/gologme/log v1.3.0 github.com/gologme/log v1.3.0
github.com/google/go-cmp v0.5.8 github.com/google/go-cmp v0.5.8
github.com/google/uuid v1.3.0 github.com/google/uuid v1.3.0
@ -24,6 +25,7 @@ require (
github.com/matrix-org/gomatrixserverlib v0.0.0-20220815094957-74b7ff4ae09c github.com/matrix-org/gomatrixserverlib v0.0.0-20220815094957-74b7ff4ae09c
github.com/matrix-org/pinecone v0.0.0-20220803093810-b7a830c08fb9 github.com/matrix-org/pinecone v0.0.0-20220803093810-b7a830c08fb9
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
github.com/matryer/is v1.4.0
github.com/mattn/go-sqlite3 v1.14.13 github.com/mattn/go-sqlite3 v1.14.13
github.com/nats-io/nats-server/v2 v2.8.5-0.20220811224153-d8d25d9b0b1c github.com/nats-io/nats-server/v2 v2.8.5-0.20220811224153-d8d25d9b0b1c
github.com/nats-io/nats.go v1.16.1-0.20220810192301-fb5ca2cbc995 github.com/nats-io/nats.go v1.16.1-0.20220810192301-fb5ca2cbc995
@ -66,7 +68,6 @@ require (
github.com/frankban/quicktest v1.14.3 // indirect github.com/frankban/quicktest v1.14.3 // indirect
github.com/fsnotify/fsnotify v1.4.9 // indirect github.com/fsnotify/fsnotify v1.4.9 // indirect
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 // indirect github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b // indirect github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b // indirect
github.com/golang/protobuf v1.5.2 // indirect github.com/golang/protobuf v1.5.2 // indirect
github.com/h2non/filetype v1.1.3 // indirect github.com/h2non/filetype v1.1.3 // indirect

6
go.sum
View file

@ -52,8 +52,6 @@ github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20O
github.com/DATA-DOG/go-sqlmock v1.5.0/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM= github.com/DATA-DOG/go-sqlmock v1.5.0/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM=
github.com/HdrHistogram/hdrhistogram-go v1.1.2 h1:5IcZpTvzydCQeHzK4Ef/D5rrSqwxob0t8PQPMybUNFM= github.com/HdrHistogram/hdrhistogram-go v1.1.2 h1:5IcZpTvzydCQeHzK4Ef/D5rrSqwxob0t8PQPMybUNFM=
github.com/HdrHistogram/hdrhistogram-go v1.1.2/go.mod h1:yDgFjdqOqDEKOvasDdhWNXYg9BVp4O+o5f6V/ehm6Oo= github.com/HdrHistogram/hdrhistogram-go v1.1.2/go.mod h1:yDgFjdqOqDEKOvasDdhWNXYg9BVp4O+o5f6V/ehm6Oo=
github.com/MFAshby/stdemuxerhook v1.0.0 h1:1XFGzakrsHMv76AeanPDL26NOgwjPl/OUxbGhJthwMc=
github.com/MFAshby/stdemuxerhook v1.0.0/go.mod h1:nLMI9FUf9Hz98n+yAXsTMUR4RZQy28uCTLG1Fzvj/uY=
github.com/Masterminds/semver/v3 v3.1.1 h1:hLg3sBzpNErnxhQtUy/mmLR2I9foDujNK030IGemrRc= github.com/Masterminds/semver/v3 v3.1.1 h1:hLg3sBzpNErnxhQtUy/mmLR2I9foDujNK030IGemrRc=
github.com/Masterminds/semver/v3 v3.1.1/go.mod h1:VPu/7SZ7ePZ3QOrcuXROw5FAcLl4a0cBrbBpGY/8hQs= github.com/Masterminds/semver/v3 v3.1.1/go.mod h1:VPu/7SZ7ePZ3QOrcuXROw5FAcLl4a0cBrbBpGY/8hQs=
github.com/Microsoft/go-winio v0.5.1 h1:aPJp2QD7OOrhO5tQXqQoGSJc+DjDtWTGLOmNyAm6FgY= github.com/Microsoft/go-winio v0.5.1 h1:aPJp2QD7OOrhO5tQXqQoGSJc+DjDtWTGLOmNyAm6FgY=
@ -184,6 +182,8 @@ github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/E
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang-jwt/jwt/v4 v4.4.1 h1:pC5DB52sCeK48Wlb9oPcdhnjkz1TKt1D/P7WKJ0kUcQ=
github.com/golang-jwt/jwt/v4 v4.4.1/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k= github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b h1:VKtxabqXZkF25pY9ekfRL6a582T4P37/31XEstQ5p58= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b h1:VKtxabqXZkF25pY9ekfRL6a582T4P37/31XEstQ5p58=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
@ -350,6 +350,8 @@ github.com/matrix-org/pinecone v0.0.0-20220803093810-b7a830c08fb9/go.mod h1:P4Mq
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U=
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk=
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U=
github.com/matryer/is v1.4.0 h1:sosSmIWwkYITGrxZ25ULNDeKiMNzFSr4V/eqBQP0PeE=
github.com/matryer/is v1.4.0/go.mod h1:8I/i5uYgLzgsgEloJE1U6xx5HkBQpAZvepWuujKwMRU=
github.com/mattn/go-colorable v0.1.8/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= github.com/mattn/go-colorable v0.1.8/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc=
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
github.com/mattn/go-isatty v0.0.13/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/mattn/go-isatty v0.0.13/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=

View file

@ -84,6 +84,57 @@ func MakeAuthAPI(
return MakeExternalAPI(metricsName, h) return MakeExternalAPI(metricsName, h)
} }
// MakeConditionalAuthAPI turns a util.JSONRequestHandler function into an http.Handler which authenticates the request.
// It passes nil device if header is not provided.
func MakeConditionalAuthAPI(
metricsName string, userAPI userapi.QueryAcccessTokenAPI,
f func(*http.Request, *userapi.Device) util.JSONResponse,
) http.Handler {
h := func(req *http.Request) util.JSONResponse {
var (
jsonRes util.JSONResponse
dev *userapi.Device
)
if _, err := auth.ExtractAccessToken(req); err != nil {
dev = nil
} else {
logger := util.GetLogger(req.Context())
var err *util.JSONResponse
dev, err = auth.VerifyUserFromRequest(req, userAPI)
if err != nil {
logger.Debugf("VerifyUserFromRequest %s -> HTTP %d", req.RemoteAddr, err.Code)
return *err
}
// add the user ID to the logger
logger = logger.WithField("user_id", dev.UserID)
req = req.WithContext(util.ContextWithLogger(req.Context(), logger))
}
// add the user to Sentry, if enabled
hub := sentry.GetHubFromContext(req.Context())
if hub != nil {
hub.Scope().SetTag("user_id", dev.UserID)
hub.Scope().SetTag("device_id", dev.ID)
}
defer func() {
if r := recover(); r != nil {
if hub != nil {
hub.CaptureException(fmt.Errorf("%s panicked", req.URL.Path))
}
// re-panic to return the 500
panic(r)
}
}()
jsonRes = f(req, dev)
// do not log 4xx as errors as they are client fails, not server fails
if hub != nil && jsonRes.Code >= 500 {
hub.Scope().SetExtra("response", jsonRes)
hub.CaptureException(fmt.Errorf("%s returned HTTP %d", req.URL.Path, jsonRes.Code))
}
return jsonRes
}
return MakeExternalAPI(metricsName, h)
}
// MakeAdminAPI is a wrapper around MakeAuthAPI which enforces that the request can only be // MakeAdminAPI is a wrapper around MakeAuthAPI which enforces that the request can only be
// completed by a user that is a server administrator. // completed by a user that is a server administrator.
func MakeAdminAPI( func MakeAdminAPI(

View file

@ -18,10 +18,8 @@
package internal package internal
import ( import (
"io"
"log/syslog" "log/syslog"
"github.com/MFAshby/stdemuxerhook"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
lSyslog "github.com/sirupsen/logrus/hooks/syslog" lSyslog "github.com/sirupsen/logrus/hooks/syslog"
@ -31,7 +29,6 @@ import (
// If something fails here it means that the logging was improperly configured, // If something fails here it means that the logging was improperly configured,
// so we just exit with the error // so we just exit with the error
func SetupHookLogging(hooks []config.LogrusHook, componentName string) { func SetupHookLogging(hooks []config.LogrusHook, componentName string) {
stdLogAdded := false
for _, hook := range hooks { for _, hook := range hooks {
// Check we received a proper logging level // Check we received a proper logging level
level, err := logrus.ParseLevel(hook.Level) level, err := logrus.ParseLevel(hook.Level)
@ -39,12 +36,6 @@ func SetupHookLogging(hooks []config.LogrusHook, componentName string) {
logrus.Fatalf("Unrecognised logging level %s: %q", hook.Level, err) logrus.Fatalf("Unrecognised logging level %s: %q", hook.Level, err)
} }
// Perform a first filter on the logs according to the lowest level of all
// (Eg: If we have hook for info and above, prevent logrus from processing debug logs)
if logrus.GetLevel() < level {
logrus.SetLevel(level)
}
switch hook.Type { switch hook.Type {
case "file": case "file":
checkFileHookParams(hook.Params) checkFileHookParams(hook.Params)
@ -53,17 +44,10 @@ func SetupHookLogging(hooks []config.LogrusHook, componentName string) {
checkSyslogHookParams(hook.Params) checkSyslogHookParams(hook.Params)
setupSyslogHook(hook, level, componentName) setupSyslogHook(hook, level, componentName)
case "std": case "std":
setupStdLogHook(level)
stdLogAdded = true
default: default:
logrus.Fatalf("Unrecognised logging hook type: %s", hook.Type) logrus.Fatalf("Unrecognised logging hook type: %s", hook.Type)
} }
} }
if !stdLogAdded {
setupStdLogHook(logrus.InfoLevel)
}
// Hooks are now configured for stdout/err, so throw away the default logger output
logrus.SetOutput(io.Discard)
} }
func checkSyslogHookParams(params map[string]interface{}) { func checkSyslogHookParams(params map[string]interface{}) {
@ -87,10 +71,6 @@ func checkSyslogHookParams(params map[string]interface{}) {
} }
func setupStdLogHook(level logrus.Level) {
logrus.AddHook(&logLevelHook{level, stdemuxerhook.New(logrus.StandardLogger())})
}
func setupSyslogHook(hook config.LogrusHook, level logrus.Level, componentName string) { func setupSyslogHook(hook config.LogrusHook, level logrus.Level, componentName string) {
syslogHook, err := lSyslog.NewSyslogHook(hook.Params["protocol"].(string), hook.Params["address"].(string), syslog.LOG_INFO, componentName) syslogHook, err := lSyslog.NewSyslogHook(hook.Params["protocol"].(string), hook.Params["address"].(string), syslog.LOG_INFO, componentName)
if err == nil { if err == nil {

View file

@ -136,7 +136,7 @@ func (r *Inviter) PerformInvite(
var isAlreadyJoined bool var isAlreadyJoined bool
if info != nil { if info != nil {
_, isAlreadyJoined, _, err = r.DB.GetMembership(ctx, info.RoomNID, *event.StateKey()) _, _, isAlreadyJoined, _, err = r.DB.GetMembership(ctx, info.RoomNID, *event.StateKey())
if err != nil { if err != nil {
return nil, fmt.Errorf("r.DB.GetMembership: %w", err) return nil, fmt.Errorf("r.DB.GetMembership: %w", err)
} }

View file

@ -32,6 +32,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/internal/helpers" "github.com/matrix-org/dendrite/roomserver/internal/helpers"
"github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/state"
"github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/roomserver/version" "github.com/matrix-org/dendrite/roomserver/version"
) )
@ -177,11 +178,16 @@ func (r *Queryer) QueryMembershipForUser(
} }
response.RoomExists = true response.RoomExists = true
membershipEventNID, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, request.UserID) membershipEventNID, membershipState, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, request.UserID)
if err != nil { if err != nil {
return err return err
} }
if membershipState == tables.MembershipStateInvite {
response.Membership = gomatrixserverlib.Invite
response.IsInRoom = true
}
response.IsRoomForgotten = isRoomforgotten response.IsRoomForgotten = isRoomforgotten
if membershipEventNID == 0 { if membershipEventNID == 0 {
@ -291,7 +297,7 @@ func (r *Queryer) QueryMembershipsForRoom(
return nil return nil
} }
membershipEventNID, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, request.Sender) membershipEventNID, _, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, request.Sender)
if err != nil { if err != nil {
return err return err
} }
@ -901,7 +907,7 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, req *api.Query
} }
// At this point we're happy that we are in the room, so now let's // At this point we're happy that we are in the room, so now let's
// see if the target user is in the room. // see if the target user is in the room.
_, isIn, _, err = r.DB.GetMembership(ctx, targetRoomInfo.RoomNID, req.UserID) _, _, isIn, _, err = r.DB.GetMembership(ctx, targetRoomInfo.RoomNID, req.UserID)
if err != nil { if err != nil {
continue continue
} }

View file

@ -125,7 +125,7 @@ type Database interface {
// in this room, along a boolean set to true if the user is still in this room, // in this room, along a boolean set to true if the user is still in this room,
// false if not. // false if not.
// Returns an error if there was a problem talking to the database. // Returns an error if there was a problem talking to the database.
GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderUserID string) (membershipEventNID types.EventNID, stillInRoom, isRoomForgotten bool, err error) GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderUserID string) (membershipEventNID types.EventNID, membershipNID tables.MembershipState, stillInRoom, isRoomForgotten bool, err error)
// Lookup the membership event numeric IDs for all user that are or have // Lookup the membership event numeric IDs for all user that are or have
// been members of a given room. Only lookup events of "join" membership if // been members of a given room. Only lookup events of "join" membership if
// joinOnly is set to true. // joinOnly is set to true.

View file

@ -399,14 +399,14 @@ func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error {
}) })
} }
func (d *Database) GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderUserID string) (membershipEventNID types.EventNID, stillInRoom, isRoomforgotten bool, err error) { func (d *Database) GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderUserID string) (membershipEventNID types.EventNID, membershipState tables.MembershipState, stillInRoom, isRoomforgotten bool, err error) {
var requestSenderUserNID types.EventStateKeyNID var requestSenderUserNID types.EventStateKeyNID
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
requestSenderUserNID, err = d.assignStateKeyNID(ctx, requestSenderUserID) requestSenderUserNID, err = d.assignStateKeyNID(ctx, requestSenderUserID)
return err return err
}) })
if err != nil { if err != nil {
return 0, false, false, fmt.Errorf("d.assignStateKeyNID: %w", err) return 0, 0, false, false, fmt.Errorf("d.assignStateKeyNID: %w", err)
} }
senderMembershipEventNID, senderMembership, isRoomforgotten, err := senderMembershipEventNID, senderMembership, isRoomforgotten, err :=
@ -415,12 +415,12 @@ func (d *Database) GetMembership(ctx context.Context, roomNID types.RoomNID, req
) )
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
// The user has never been a member of that room // The user has never been a member of that room
return 0, false, false, nil return 0, 0, false, false, nil
} else if err != nil { } else if err != nil {
return return
} }
return senderMembershipEventNID, senderMembership == tables.MembershipStateJoin, isRoomforgotten, nil return senderMembershipEventNID, senderMembership, senderMembership == tables.MembershipStateJoin, isRoomforgotten, nil
} }
func (d *Database) GetMembershipEventNIDsForRoom( func (d *Database) GetMembershipEventNIDsForRoom(

View file

@ -131,7 +131,6 @@ func NewBaseDendrite(cfg *config.Dendrite, componentName string, options ...Base
logrus.Fatalf("Failed to start due to configuration errors") logrus.Fatalf("Failed to start due to configuration errors")
} }
internal.SetupStdLogging()
internal.SetupHookLogging(cfg.Logging, componentName) internal.SetupHookLogging(cfg.Logging, componentName)
internal.SetupPprof() internal.SetupPprof()

View file

@ -16,6 +16,7 @@ package config
import ( import (
"bytes" "bytes"
"crypto/x509"
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"io" "io"
@ -252,6 +253,15 @@ func loadConfig(
c.Global.OldVerifyKeys[i].KeyID, c.Global.OldVerifyKeys[i].PrivateKey = keyID, privateKey c.Global.OldVerifyKeys[i].KeyID, c.Global.OldVerifyKeys[i].PrivateKey = keyID, privateKey
} }
if c.ClientAPI.JwtConfig.Enabled {
pubPki, _ := pem.Decode([]byte(c.ClientAPI.JwtConfig.Secret))
var pub interface{}
pub, err = x509.ParsePKIXPublicKey(pubPki.Bytes)
if err != nil {
return nil, err
}
c.ClientAPI.JwtConfig.SecretKey = pub.(ed25519.PublicKey)
}
c.MediaAPI.AbsBasePath = Path(absPath(basePath, c.MediaAPI.BasePath)) c.MediaAPI.AbsBasePath = Path(absPath(basePath, c.MediaAPI.BasePath))
@ -283,7 +293,10 @@ func (config *Dendrite) Derive() error {
config.Derived.Registration.Flows = append(config.Derived.Registration.Flows, config.Derived.Registration.Flows = append(config.Derived.Registration.Flows,
authtypes.Flow{Stages: []authtypes.LoginType{authtypes.LoginTypeDummy}}) authtypes.Flow{Stages: []authtypes.LoginType{authtypes.LoginTypeDummy}})
} }
if config.ClientAPI.ThreePidDelegate != "" {
config.Derived.Registration.Flows = append(config.Derived.Registration.Flows,
authtypes.Flow{Stages: []authtypes.LoginType{authtypes.LoginTypeEmail}})
}
// Load application service configuration files // Load application service configuration files
if err := loadAppServices(&config.AppServiceAPI, &config.Derived); err != nil { if err := loadAppServices(&config.AppServiceAPI, &config.Derived); err != nil {
return err return err

View file

@ -21,7 +21,6 @@ import (
"regexp" "regexp"
"strings" "strings"
log "github.com/sirupsen/logrus"
yaml "gopkg.in/yaml.v2" yaml "gopkg.in/yaml.v2"
) )
@ -353,11 +352,11 @@ func checkErrors(config *AppServiceAPI, derived *Derived) (err error) {
// TODO: Remove once rate_limited is implemented // TODO: Remove once rate_limited is implemented
if appservice.RateLimited { if appservice.RateLimited {
log.Warn("WARNING: Application service option rate_limited is currently unimplemented") // log.Warn("WARNING: Application service option rate_limited is currently unimplemented")
} }
// TODO: Remove once protocols is implemented // TODO: Remove once protocols is implemented
if len(appservice.Protocols) > 0 { if len(appservice.Protocols) > 0 {
log.Warn("WARNING: Application service option protocols is currently unimplemented") // log.Warn("WARNING: Application service option protocols is currently unimplemented")
} }
} }
@ -383,7 +382,7 @@ func validateNamespace(
// Check if GroupID for the users namespace is in the correct format // Check if GroupID for the users namespace is in the correct format
if key == "users" && namespace.GroupID != "" { if key == "users" && namespace.GroupID != "" {
// TODO: Remove once group_id is implemented // TODO: Remove once group_id is implemented
log.Warn("WARNING: Application service option group_id is currently unimplemented") // log.Warn("WARNING: Application service option group_id is currently unimplemented")
correctFormat := groupIDRegexp.MatchString(namespace.GroupID) correctFormat := groupIDRegexp.MatchString(namespace.GroupID)
if !correctFormat { if !correctFormat {

View file

@ -3,6 +3,9 @@ package config
import ( import (
"fmt" "fmt"
"time" "time"
"github.com/matrix-org/dendrite/clientapi/ratelimit"
"golang.org/x/crypto/ed25519"
) )
type ClientAPI struct { type ClientAPI struct {
@ -47,8 +50,22 @@ type ClientAPI struct {
// Rate-limiting options // Rate-limiting options
RateLimiting RateLimiting `yaml:"rate_limiting"` RateLimiting RateLimiting `yaml:"rate_limiting"`
RtFailedLogin ratelimit.RtFailedLoginConfig `yaml:"rate_limiting_failed_login"`
MSCs *MSCs `yaml:"mscs"` MSCs *MSCs `yaml:"mscs"`
ThreePidDelegate string `yaml:"three_pid_delegate"`
JwtConfig JwtConfig `yaml:"jwt_config"`
}
type JwtConfig struct {
Enabled bool `yaml:"enabled"`
Algorithm string `yaml:"algorithm"`
Issuer string `yaml:"issuer"`
Secret string `yaml:"secret"`
SecretKey ed25519.PublicKey
Audiences []string `yaml:"audiences"`
} }
func (c *ClientAPI) Defaults(generate bool) { func (c *ClientAPI) Defaults(generate bool) {

View file

@ -172,6 +172,8 @@ type Presence interface {
GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error)
PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error)
MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error)
ExpirePresence(ctx context.Context) ([]types.PresenceNotify, error)
UpdateLastActive(ctx context.Context, userId string, lastActiveTs uint64) error
} }
type SharedUsers interface { type SharedUsers interface {

View file

@ -62,6 +62,10 @@ const upsertPresenceFromSyncSQL = "" +
" presence = $2, last_active_ts = $3" + " presence = $2, last_active_ts = $3" +
" RETURNING id" " RETURNING id"
const updateLastActiveSQL = `UPDATE syncapi_presence
SET last_active_ts = $1
WHERE user_id = $2`
const selectPresenceForUserSQL = "" + const selectPresenceForUserSQL = "" +
"SELECT presence, status_msg, last_active_ts" + "SELECT presence, status_msg, last_active_ts" +
" FROM syncapi_presence" + " FROM syncapi_presence" +
@ -76,12 +80,24 @@ const selectPresenceAfter = "" +
" WHERE id > $1 AND last_active_ts >= $2" + " WHERE id > $1 AND last_active_ts >= $2" +
" ORDER BY id ASC LIMIT $3" " ORDER BY id ASC LIMIT $3"
const expirePresenceSQL = `UPDATE syncapi_presence SET
id = nextval('syncapi_presence_id'),
presence = 3
WHERE
to_timestamp(last_active_ts / 1000) < NOW() - INTERVAL` + types.PresenceExpire + `
AND
presence != 3
RETURNING id, user_id
`
type presenceStatements struct { type presenceStatements struct {
upsertPresenceStmt *sql.Stmt upsertPresenceStmt *sql.Stmt
upsertPresenceFromSyncStmt *sql.Stmt upsertPresenceFromSyncStmt *sql.Stmt
selectPresenceForUsersStmt *sql.Stmt selectPresenceForUsersStmt *sql.Stmt
selectMaxPresenceStmt *sql.Stmt selectMaxPresenceStmt *sql.Stmt
selectPresenceAfterStmt *sql.Stmt selectPresenceAfterStmt *sql.Stmt
expirePresenceStmt *sql.Stmt
updateLastActiveStmt *sql.Stmt
} }
func NewPostgresPresenceTable(db *sql.DB) (*presenceStatements, error) { func NewPostgresPresenceTable(db *sql.DB) (*presenceStatements, error) {
@ -96,6 +112,8 @@ func NewPostgresPresenceTable(db *sql.DB) (*presenceStatements, error) {
{&s.selectPresenceForUsersStmt, selectPresenceForUserSQL}, {&s.selectPresenceForUsersStmt, selectPresenceForUserSQL},
{&s.selectMaxPresenceStmt, selectMaxPresenceSQL}, {&s.selectMaxPresenceStmt, selectMaxPresenceSQL},
{&s.selectPresenceAfterStmt, selectPresenceAfter}, {&s.selectPresenceAfterStmt, selectPresenceAfter},
{&s.expirePresenceStmt, expirePresenceSQL},
{&s.updateLastActiveStmt, updateLastActiveSQL},
}.Prepare(db) }.Prepare(db)
} }
@ -166,3 +184,28 @@ func (p *presenceStatements) GetPresenceAfter(
} }
return presences, rows.Err() return presences, rows.Err()
} }
func (p *presenceStatements) ExpirePresence(
ctx context.Context,
) ([]types.PresenceNotify, error) {
rows, err := p.expirePresenceStmt.QueryContext(ctx)
if err != nil {
return nil, err
}
presences := make([]types.PresenceNotify, 0)
i := 0
for rows.Next() {
presences = append(presences, types.PresenceNotify{})
err = rows.Scan(&presences[i].StreamPos, &presences[i].UserID)
if err != nil {
return nil, err
}
i++
}
return presences, err
}
func (p *presenceStatements) UpdateLastActive(ctx context.Context, userId string, lastActiveTs uint64) error {
_, err := p.updateLastActiveStmt.Exec(&lastActiveTs, &userId)
return err
}

View file

@ -1078,3 +1078,11 @@ func (d *Database) MaxStreamPositionForPresence(ctx context.Context) (types.Stre
func (d *Database) SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error) { func (d *Database) SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error) {
return d.Memberships.SelectMembershipForUser(ctx, nil, roomID, userID, pos) return d.Memberships.SelectMembershipForUser(ctx, nil, roomID, userID, pos)
} }
func (s *Database) ExpirePresence(ctx context.Context) ([]types.PresenceNotify, error) {
return s.Presence.ExpirePresence(ctx)
}
func (s *Database) UpdateLastActive(ctx context.Context, userId string, lastActiveTs uint64) error {
return s.Presence.UpdateLastActive(ctx, userId, lastActiveTs)
}

View file

@ -180,3 +180,15 @@ func (p *presenceStatements) GetPresenceAfter(
} }
return presences, rows.Err() return presences, rows.Err()
} }
func (p *presenceStatements) ExpirePresence(
ctx context.Context,
) ([]types.PresenceNotify, error) {
// TODO implement
return nil, nil
}
func (p *presenceStatements) UpdateLastActive(ctx context.Context, userId string, lastActiveTs uint64) error {
// TODO implement
return nil
}

View file

@ -204,4 +204,6 @@ type Presence interface {
GetPresenceForUser(ctx context.Context, txn *sql.Tx, userID string) (presence *types.PresenceInternal, err error) GetPresenceForUser(ctx context.Context, txn *sql.Tx, userID string) (presence *types.PresenceInternal, err error)
GetMaxPresenceID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) GetMaxPresenceID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error)
GetPresenceAfter(ctx context.Context, txn *sql.Tx, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (presences map[string]*types.PresenceInternal, err error) GetPresenceAfter(ctx context.Context, txn *sql.Tx, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (presences map[string]*types.PresenceInternal, err error)
ExpirePresence(ctx context.Context) ([]types.PresenceNotify, error)
UpdateLastActive(ctx context.Context, userId string, lastActiveTs uint64) error
} }

View file

@ -17,7 +17,6 @@ package streams
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"sync"
"github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
@ -26,8 +25,6 @@ import (
type PresenceStreamProvider struct { type PresenceStreamProvider struct {
StreamProvider StreamProvider
// cache contains previously sent presence updates to avoid unneeded updates
cache sync.Map
notifier *notifier.Notifier notifier *notifier.Notifier
} }
@ -103,18 +100,6 @@ func (p *PresenceStreamProvider) IncrementalSync(
if req.Device.UserID != presence.UserID && !p.notifier.IsSharedUser(req.Device.UserID, presence.UserID) { if req.Device.UserID != presence.UserID && !p.notifier.IsSharedUser(req.Device.UserID, presence.UserID) {
continue continue
} }
cacheKey := req.Device.UserID + req.Device.ID + presence.UserID
pres, ok := p.cache.Load(cacheKey)
if ok {
// skip already sent presence
prevPresence := pres.(*types.PresenceInternal)
currentlyActive := prevPresence.CurrentlyActive()
skip := prevPresence.Equals(presence) && currentlyActive && req.Device.UserID != presence.UserID
if skip {
req.Log.Tracef("Skipping presence, no change (%s)", presence.UserID)
continue
}
}
if _, known := types.PresenceFromString(presence.ClientFields.Presence); known { if _, known := types.PresenceFromString(presence.ClientFields.Presence); known {
presence.ClientFields.LastActiveAgo = presence.LastActiveAgo() presence.ClientFields.LastActiveAgo = presence.LastActiveAgo()
@ -142,7 +127,6 @@ func (p *PresenceStreamProvider) IncrementalSync(
if len(req.Response.Presence.Events) == req.Filter.Presence.Limit { if len(req.Response.Presence.Events) == req.Filter.Presence.Limit {
break break
} }
p.cache.Store(cacheKey, presence)
} }
if len(req.Response.Presence.Events) == 0 { if len(req.Response.Presence.Events) == 0 {

View file

@ -50,7 +50,7 @@ type RequestPool struct {
keyAPI keyapi.SyncKeyAPI keyAPI keyapi.SyncKeyAPI
rsAPI roomserverAPI.SyncRoomserverAPI rsAPI roomserverAPI.SyncRoomserverAPI
lastseen *sync.Map lastseen *sync.Map
presence *sync.Map Presence *sync.Map
streams *streams.Streams streams *streams.Streams
Notifier *notifier.Notifier Notifier *notifier.Notifier
producer PresencePublisher producer PresencePublisher
@ -85,14 +85,14 @@ func NewRequestPool(
keyAPI: keyAPI, keyAPI: keyAPI,
rsAPI: rsAPI, rsAPI: rsAPI,
lastseen: &sync.Map{}, lastseen: &sync.Map{},
presence: &sync.Map{}, Presence: &sync.Map{},
streams: streams, streams: streams,
Notifier: notifier, Notifier: notifier,
producer: producer, producer: producer,
consumer: consumer, consumer: consumer,
} }
go rp.cleanLastSeen() go rp.cleanLastSeen()
go rp.cleanPresence(db, time.Minute*5) // go rp.cleanPresence(db, time.Minute*5)
return rp return rp
} }
@ -111,11 +111,11 @@ func (rp *RequestPool) cleanPresence(db storage.Presence, cleanupTime time.Durat
return return
} }
for { for {
rp.presence.Range(func(key interface{}, v interface{}) bool { rp.Presence.Range(func(key interface{}, v interface{}) bool {
p := v.(types.PresenceInternal) p := v.(types.PresenceInternal)
if time.Since(p.LastActiveTS.Time()) > cleanupTime { if time.Since(p.LastActiveTS.Time()) > cleanupTime {
rp.updatePresence(db, types.PresenceUnavailable.String(), p.UserID) rp.updatePresence(db, types.PresenceUnavailable.String(), p.UserID)
rp.presence.Delete(key) rp.Presence.Delete(key)
} }
return true return true
}) })
@ -153,14 +153,23 @@ func (rp *RequestPool) updatePresence(db storage.Presence, presence string, user
} }
newPresence.ClientFields.Presence = presenceID.String() newPresence.ClientFields.Presence = presenceID.String()
defer rp.presence.Store(userID, newPresence) defer rp.Presence.Store(userID, newPresence)
// avoid spamming presence updates when syncing // avoid spamming presence updates when syncing
existingPresence, ok := rp.presence.LoadOrStore(userID, newPresence) existingPresence, ok := rp.Presence.LoadOrStore(userID, newPresence)
if ok { if ok {
p := existingPresence.(types.PresenceInternal) p := existingPresence.(types.PresenceInternal)
if p.ClientFields.Presence == newPresence.ClientFields.Presence { if dbPresence != nil {
if p.Presence == newPresence.Presence && newPresence.LastActiveTS-dbPresence.LastActiveTS < types.PresenceNoOpMs {
return return
} }
if dbPresence.Presence == types.PresenceOnline && presenceID == types.PresenceOnline && newPresence.LastActiveTS-dbPresence.LastActiveTS >= types.PresenceNoOpMs {
err := db.UpdateLastActive(context.Background(), userID, uint64(newPresence.LastActiveTS))
if err != nil {
logrus.WithError(err).Error("failed to update last active")
}
return
}
}
} }
if err := rp.producer.SendPresence(userID, presenceID, newPresence.ClientFields.StatusMsg); err != nil { if err := rp.producer.SendPresence(userID, presenceID, newPresence.ClientFields.StatusMsg); err != nil {
@ -247,7 +256,7 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.
defer activeSyncRequests.Dec() defer activeSyncRequests.Dec()
rp.updateLastSeen(req, device) rp.updateLastSeen(req, device)
rp.updatePresence(rp.db, req.FormValue("set_presence"), device.UserID) rp.updatePresence(rp.db, "", device.UserID)
waitingSyncRequests.Inc() waitingSyncRequests.Inc()
defer waitingSyncRequests.Dec() defer waitingSyncRequests.Dec()

View file

@ -7,6 +7,7 @@ import (
"time" "time"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -23,7 +24,9 @@ func (d *dummyPublisher) SendPresence(userID string, presence types.Presence, st
return nil return nil
} }
type dummyDB struct{} type dummyDB struct {
storage.Database
}
func (d dummyDB) UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error) { func (d dummyDB) UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error) {
return 0, nil return 0, nil
@ -109,7 +112,7 @@ func TestRequestPool_updatePresence(t *testing.T) {
}, },
} }
rp := &RequestPool{ rp := &RequestPool{
presence: &syncMap, Presence: &syncMap,
producer: publisher, producer: publisher,
consumer: consumer, consumer: consumer,
cfg: &config.SyncAPI{ cfg: &config.SyncAPI{

View file

@ -16,6 +16,7 @@ package syncapi
import ( import (
"context" "context"
"time"
"github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/caching"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -33,6 +34,7 @@ import (
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/streams" "github.com/matrix-org/dendrite/syncapi/streams"
"github.com/matrix-org/dendrite/syncapi/sync" "github.com/matrix-org/dendrite/syncapi/sync"
"github.com/matrix-org/dendrite/syncapi/types"
) )
// AddPublicRoutes sets up and registers HTTP handlers for the SyncAPI // AddPublicRoutes sets up and registers HTTP handlers for the SyncAPI
@ -144,4 +146,24 @@ func AddPublicRoutes(
base.PublicClientAPIMux, requestPool, syncDB, userAPI, base.PublicClientAPIMux, requestPool, syncDB, userAPI,
rsAPI, cfg, base.Caches, rsAPI, cfg, base.Caches,
) )
go func() {
ctx := context.Background()
for {
notify, err := syncDB.ExpirePresence(ctx)
if err != nil {
logrus.WithError(err).Error("failed to expire presence")
}
for i := range notify {
requestPool.Presence.Store(notify[i].UserID, types.PresenceInternal{
Presence: types.PresenceOffline,
})
notifier.OnNewPresence(types.StreamingToken{
PresencePosition: notify[i].StreamPos,
}, notify[i].UserID)
}
time.Sleep(types.PresenceExpireInterval)
}
}()
} }

View file

@ -21,6 +21,12 @@ import (
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
const (
PresenceNoOpMs = 60_000
PresenceExpire = "'4 minutes'"
PresenceExpireInterval = time.Second * 30
)
type Presence uint8 type Presence uint8
const ( const (
@ -66,6 +72,11 @@ type PresenceInternal struct {
Presence Presence `json:"-"` Presence Presence `json:"-"`
} }
type PresenceNotify struct {
StreamPos StreamPosition
UserID string
}
// Equals compares p1 with p2. // Equals compares p1 with p2.
func (p1 *PresenceInternal) Equals(p2 *PresenceInternal) bool { func (p1 *PresenceInternal) Equals(p2 *PresenceInternal) bool {
return p1.ClientFields.Presence == p2.ClientFields.Presence && return p1.ClientFields.Presence == p2.ClientFields.Presence &&

View file

@ -22,6 +22,7 @@ import (
"strings" "strings"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
@ -282,9 +283,11 @@ func NewTopologyTokenFromString(tok string) (token TopologyToken, err error) {
func NewStreamTokenFromString(tok string) (token StreamingToken, err error) { func NewStreamTokenFromString(tok string) (token StreamingToken, err error) {
if len(tok) < 1 { if len(tok) < 1 {
err = ErrMalformedSyncToken err = ErrMalformedSyncToken
logrus.WithField("token", tok).Info("invalid stream token: bad length")
return return
} }
if tok[0] != SyncTokenTypeStream[0] { if tok[0] != SyncTokenTypeStream[0] {
logrus.WithField("token", tok).Info("invalid stream token: not starting from s")
err = ErrMalformedSyncToken err = ErrMalformedSyncToken
return return
} }
@ -300,6 +303,7 @@ func NewStreamTokenFromString(tok string) (token StreamingToken, err error) {
var pos int var pos int
pos, err = strconv.Atoi(p) pos, err = strconv.Atoi(p)
if err != nil { if err != nil {
logrus.WithField("token", tok).Info("invalid stream token: strconv")
err = ErrMalformedSyncToken err = ErrMalformedSyncToken
return return
} }

View file

@ -49,3 +49,11 @@ Notifications can be viewed with GET /notifications
If remote user leaves room we no longer receive device updates If remote user leaves room we no longer receive device updates
Guest users can join guest_access rooms Guest users can join guest_access rooms
# You'll be shocked to discover this is flakey too
Inbound /v1/send_join rejects joins from other servers
# For notifications extension on iOS
/event/ does not allow access to events before the user joined

View file

@ -205,7 +205,6 @@ Deleted tags appear in an incremental v2 /sync
/event/ on non world readable room does not work /event/ on non world readable room does not work
Outbound federation can query profile data Outbound federation can query profile data
/event/ on joined room works /event/ on joined room works
/event/ does not allow access to events before the user joined
Federation key API allows unsigned requests for keys Federation key API allows unsigned requests for keys
GET /publicRooms lists rooms GET /publicRooms lists rooms
GET /publicRooms includes avatar URLs GET /publicRooms includes avatar URLs
@ -744,3 +743,4 @@ User in private room doesn't appear in user directory
User joining then leaving public room appears and dissappears from directory User joining then leaving public room appears and dissappears from directory
User in remote room doesn't appear in user directory after server left room User in remote room doesn't appear in user directory after server left room
User in shared private room does appear in user directory until leave User in shared private room does appear in user directory until leave
Existing members see new member's presence

View file

@ -518,7 +518,7 @@ type PerformPusherSetRequest struct {
type PerformPusherDeletionRequest struct { type PerformPusherDeletionRequest struct {
Localpart string Localpart string
SessionID int64 SessionID int64 // Pusher corresponding to this SessionID will not be deleted
} }
// Pusher represents a push notification subscriber // Pusher represents a push notification subscriber

View file

@ -529,7 +529,9 @@ func (s *OutputStreamEventConsumer) notifyHTTP(ctx context.Context, event *gomat
case "event_id_only": case "event_id_only":
req = pushgateway.NotifyRequest{ req = pushgateway.NotifyRequest{
Notification: pushgateway.Notification{ Notification: pushgateway.Notification{
Counts: &pushgateway.Counts{}, Counts: &pushgateway.Counts{
Unread: userNumUnreadNotifs,
},
Devices: devices, Devices: devices,
EventID: event.EventID(), EventID: event.EventID(),
RoomID: event.RoomID(), RoomID: event.RoomID(),

View file

@ -98,6 +98,11 @@ func NewPostgresAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerNam
Up: deltas.UpAddAccountType, Up: deltas.UpAddAccountType,
Down: deltas.DownAddAccountType, Down: deltas.DownAddAccountType,
}, },
{
Version: "userapi: no guests",
Up: deltas.UpNoGuests,
Down: deltas.DownNoGuests,
},
}...) }...)
err = m.Up(context.Background()) err = m.Up(context.Background())
if err != nil { if err != nil {

View file

@ -0,0 +1,20 @@
package deltas
import (
"context"
"database/sql"
"fmt"
)
func UpNoGuests(ctx context.Context, tx *sql.Tx) error {
// AddAccountType introduced a bug where each user that had was registered as a regular user, but without user_id, became a guest.
_, err := tx.ExecContext(ctx, "UPDATE account_accounts SET account_type = 1 WHERE account_type = 2;")
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil
}
func DownNoGuests(ctx context.Context, tx *sql.Tx) error {
return nil
}

View file

@ -71,7 +71,7 @@ const selectNotificationSQL = "" +
") AND NOT read ORDER BY localpart, id LIMIT $4" ") AND NOT read ORDER BY localpart, id LIMIT $4"
const selectNotificationCountSQL = "" + const selectNotificationCountSQL = "" +
"SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND (" + "SELECT COUNT(DISTINCT(room_id)) FROM userapi_notifications WHERE localpart = $1 AND (" +
"(($2 & 1) <> 0 AND highlight) OR (($2 & 2) <> 0 AND NOT highlight)" + "(($2 & 1) <> 0 AND highlight) OR (($2 & 2) <> 0 AND NOT highlight)" +
") AND NOT read" ") AND NOT read"

View file

@ -71,7 +71,7 @@ const selectNotificationSQL = "" +
") AND NOT read ORDER BY localpart, id LIMIT $4" ") AND NOT read ORDER BY localpart, id LIMIT $4"
const selectNotificationCountSQL = "" + const selectNotificationCountSQL = "" +
"SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND (" + "SELECT COUNT(DISTINCT(room_id)) FROM userapi_notifications WHERE localpart = $1 AND (" +
"(($2 & 1) <> 0 AND highlight) OR (($2 & 2) <> 0 AND NOT highlight)" + "(($2 & 1) <> 0 AND highlight) OR (($2 & 2) <> 0 AND NOT highlight)" +
") AND NOT read" ") AND NOT read"

View file

@ -520,7 +520,7 @@ func Test_Notification(t *testing.T) {
// get notifications // get notifications
count, err := db.GetNotificationCount(ctx, aliceLocalpart, tables.AllNotifications) count, err := db.GetNotificationCount(ctx, aliceLocalpart, tables.AllNotifications)
assert.NoError(t, err, "unable to get notification count") assert.NoError(t, err, "unable to get notification count")
assert.Equal(t, int64(10), count) assert.Equal(t, int64(2), count)
notifs, count, err := db.GetNotifications(ctx, aliceLocalpart, 0, 15, tables.AllNotifications) notifs, count, err := db.GetNotifications(ctx, aliceLocalpart, 0, 15, tables.AllNotifications)
assert.NoError(t, err, "unable to get notifications") assert.NoError(t, err, "unable to get notifications")
assert.Equal(t, int64(10), count) assert.Equal(t, int64(10), count)