diff --git a/.cloudbuild/dev.yaml b/.cloudbuild/dev.yaml new file mode 100644 index 000000000..de6148b38 --- /dev/null +++ b/.cloudbuild/dev.yaml @@ -0,0 +1,12 @@ +steps: + - name: gcr.io/cloud-builders/docker + args: ['build', '-t', 'gcr.io/$PROJECT_ID/dendrite-monolith:$COMMIT_SHA', '-f', 'build/docker/Dockerfile.monolith', '.'] + - name: gcr.io/cloud-builders/kubectl + args: ['-n', 'dendrite', 'set', 'image', 'deployment/dendrite', 'dendrite=gcr.io/$PROJECT_ID/dendrite-monolith:$COMMIT_SHA'] + env: + - CLOUDSDK_CORE_PROJECT=globekeeper-development + - CLOUDSDK_COMPUTE_ZONE=europe-west2-a + - CLOUDSDK_CONTAINER_CLUSTER=synapse +images: + - gcr.io/$PROJECT_ID/dendrite-monolith:$COMMIT_SHA +timeout: 360s diff --git a/.cloudbuild/prod.yaml b/.cloudbuild/prod.yaml new file mode 100644 index 000000000..cbde9043a --- /dev/null +++ b/.cloudbuild/prod.yaml @@ -0,0 +1,12 @@ +steps: + - name: gcr.io/cloud-builders/docker + args: ['build', '-t', 'gcr.io/$PROJECT_ID/dendrite-monolith:$TAG_NAME', '-f', 'build/docker/Dockerfile.monolith', '.'] + - name: gcr.io/cloud-builders/kubectl + args: ['set', 'image', 'deployment/dendrite', 'dendrite=gcr.io/$PROJECT_ID/dendrite-monolith:$TAG_NAME'] + env: + - CLOUDSDK_CORE_PROJECT=globekeeper-production + - CLOUDSDK_COMPUTE_ZONE=europe-west2-a + - CLOUDSDK_CONTAINER_CLUSTER=synapse-production +images: + - gcr.io/$PROJECT_ID/dendrite-monolith:$TAG_NAME +timeout: 360s diff --git a/.github/workflows/dendrite.yml b/.github/workflows/dendrite.yml index be3c7c173..2b978dc7a 100644 --- a/.github/workflows/dendrite.yml +++ b/.github/workflows/dendrite.yml @@ -14,51 +14,6 @@ concurrency: cancel-in-progress: true jobs: - wasm: - name: WASM build test - timeout-minutes: 5 - runs-on: ubuntu-latest - if: ${{ false }} # disable for now - steps: - - uses: actions/checkout@v3 - - - name: Install Go - uses: actions/setup-go@v3 - with: - go-version: 1.18 - - - uses: actions/cache@v2 - with: - path: | - ~/.cache/go-build - ~/go/pkg/mod - key: ${{ runner.os }}-go-wasm-${{ hashFiles('**/go.sum') }} - restore-keys: | - ${{ runner.os }}-go-wasm - - - name: Install Node - uses: actions/setup-node@v2 - with: - node-version: 14 - - - uses: actions/cache@v2 - with: - path: ~/.npm - key: ${{ runner.os }}-node-${{ hashFiles('**/package-lock.json') }} - restore-keys: | - ${{ runner.os }}-node- - - - name: Reconfigure Git to use HTTPS auth for repo packages - run: > - git config --global url."https://github.com/".insteadOf - ssh://git@github.com/ - - - name: Install test dependencies - working-directory: ./test/wasm - run: npm ci - - - name: Test - run: ./test-dendritejs.sh # Run golangci-lint lint: @@ -74,7 +29,7 @@ jobs: - name: golangci-lint uses: golangci/golangci-lint-action@v3 - # run go test with different go versions + # run go test with go 1.18 test: timeout-minutes: 5 name: Unit tests (Go ${{ matrix.go }}) @@ -124,7 +79,7 @@ jobs: POSTGRES_PASSWORD: postgres POSTGRES_DB: dendrite - # build Dendrite for linux with different architectures and go versions + # build Dendrite for linux amd64 with go 1.18 build: name: Build for Linux timeout-minutes: 10 @@ -134,7 +89,7 @@ jobs: matrix: go: ["1.18", "1.19"] goos: ["linux"] - goarch: ["amd64", "386"] + goarch: ["amd64"] steps: - uses: actions/checkout@v3 - name: Setup go @@ -159,43 +114,10 @@ jobs: CGO_CFLAGS: -fno-stack-protector run: go build -trimpath -v -o "bin/" ./cmd/... - # build for Windows 64-bit - build_windows: - name: Build for Windows - timeout-minutes: 10 - runs-on: ubuntu-latest - strategy: - matrix: - go: ["1.18", "1.19"] - goos: ["windows"] - goarch: ["amd64"] - steps: - - uses: actions/checkout@v3 - - name: Setup Go ${{ matrix.go }} - uses: actions/setup-go@v3 - with: - go-version: ${{ matrix.go }} - - name: Install dependencies - run: sudo apt update && sudo apt install -y gcc-mingw-w64-x86-64 # install required gcc - - uses: actions/cache@v3 - with: - path: | - ~/.cache/go-build - ~/go/pkg/mod - key: ${{ runner.os }}-go${{ matrix.go }}-${{ matrix.goos }}-${{ hashFiles('**/go.sum') }} - restore-keys: | - ${{ runner.os }}-go${{ matrix.go }}-${{ matrix.goos }} - - env: - GOOS: ${{ matrix.goos }} - GOARCH: ${{ matrix.goarch }} - CGO_ENABLED: 1 - CC: "/usr/bin/x86_64-w64-mingw32-gcc" - run: go build -trimpath -v -o "bin/" ./cmd/... - # Dummy step to gate other tests on without repeating the whole list initial-tests-done: name: Initial tests passed - needs: [lint, test, build, build_windows] + needs: [lint, test, build] runs-on: ubuntu-latest if: ${{ !cancelled() }} # Run this even if prior jobs were skipped steps: @@ -300,13 +222,12 @@ jobs: run: /src/are-we-synapse-yet.py /logs/results.tap -v continue-on-error: true # not fatal - name: Upload Sytest logs - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 if: ${{ always() }} with: name: Sytest Logs - ${{ job.status }} - (Dendrite, ${{ join(matrix.*, ', ') }}) path: | - /logs/results.tap - /logs/**/*.log* + /logs # run Complement complement: @@ -370,7 +291,7 @@ jobs: continue fi - (wget -O - "https://github.com/matrix-org/complement/archive/$BRANCH_NAME.tar.gz" | tar -xz --strip-components=1 -C complement) && break + (wget -O - "https://github.com/globekeeper/complement/archive/$BRANCH_NAME.tar.gz" | tar -xz --strip-components=1 -C complement) && break done # Build initial Dendrite image diff --git a/.gitignore b/.gitignore index e4f0112c4..662d3ae97 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,8 @@ # Hidden files .* +!.vscode +!.cloudbuild # Allow GitHub config !.github @@ -73,3 +75,7 @@ complement/ docs/_site media_store/ + +__debug_bin + +cmd/dendrite-monolith-server/dendrite-monolith-server \ No newline at end of file diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 000000000..6142a8df0 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,16 @@ +{ + "configurations": [ + { + "name": "Launch Package", + "type": "go", + "request": "launch", + "mode": "auto", + "program": "${workspaceFolder}/cmd/dendrite-monolith-server", + "args": [ + "-really-enable-open-registration", + "-config", + "../../../adminas/.ci/config/dendrite-local/dendrite.yaml" + ], + } + ] +} \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 000000000..f9731b7f8 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,9 @@ +{ + "go.lintTool": "golangci-lint", + "go.testEnvVars": { + "POSTGRES_HOST": "localhost", + "POSTGRES_USER": "postgres", + "POSTGRES_PASSWORD": "foobar", + "POSTGRES_DB": "postgres" + } +} \ No newline at end of file diff --git a/build/gobind-pinecone/monolith.go b/build/gobind-pinecone/monolith.go index 44eab4cd7..bc7c2b0f5 100644 --- a/build/gobind-pinecone/monolith.go +++ b/build/gobind-pinecone/monolith.go @@ -22,7 +22,6 @@ import ( "encoding/hex" "fmt" "io" - "io/ioutil" "net" "net/http" "os" @@ -212,11 +211,11 @@ func (m *DendriteMonolith) Start() { if pk, sk, err = ed25519.GenerateKey(nil); err != nil { panic(err) } - if err = ioutil.WriteFile(keyfile, sk, 0644); err != nil { + if err = os.WriteFile(keyfile, sk, 0644); err != nil { panic(err) } } else if err == nil { - if sk, err = ioutil.ReadFile(keyfile); err != nil { + if sk, err = os.ReadFile(keyfile); err != nil { panic(err) } if len(sk) != ed25519.PrivateKeySize { diff --git a/clientapi/auth/authtypes/logintypes.go b/clientapi/auth/authtypes/logintypes.go index f01e48f80..00253fede 100644 --- a/clientapi/auth/authtypes/logintypes.go +++ b/clientapi/auth/authtypes/logintypes.go @@ -11,4 +11,6 @@ const ( LoginTypeRecaptcha = "m.login.recaptcha" LoginTypeApplicationService = "m.login.application_service" LoginTypeToken = "m.login.token" + LoginTypeJwt = "org.matrix.login.jwt" + LoginTypeEmail = "m.login.email.identity" ) diff --git a/clientapi/auth/login.go b/clientapi/auth/login.go index 5467e814d..fbae6f792 100644 --- a/clientapi/auth/login.go +++ b/clientapi/auth/login.go @@ -22,6 +22,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/clientapi/ratelimit" "github.com/matrix-org/dendrite/setup/config" uapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/util" @@ -32,7 +33,7 @@ import ( // called after authorization has completed, with the result of the authorization. // If the final return value is non-nil, an error occurred and the cleanup function // is nil. -func LoginFromJSONReader(ctx context.Context, r io.Reader, useraccountAPI uapi.UserLoginAPI, userAPI UserInternalAPIForLogin, cfg *config.ClientAPI) (*Login, LoginCleanupFunc, *util.JSONResponse) { +func LoginFromJSONReader(ctx context.Context, r io.Reader, useraccountAPI uapi.ClientUserAPI, cfg *config.ClientAPI, rt *ratelimit.RtFailedLogin) (*Login, LoginCleanupFunc, *util.JSONResponse) { reqBytes, err := io.ReadAll(r) if err != nil { err := &util.JSONResponse{ @@ -57,14 +58,19 @@ func LoginFromJSONReader(ctx context.Context, r io.Reader, useraccountAPI uapi.U switch header.Type { case authtypes.LoginTypePassword: typ = &LoginTypePassword{ - GetAccountByPassword: useraccountAPI.QueryAccountByPassword, - Config: cfg, + UserApi: useraccountAPI, + Config: cfg, + Rt: rt, } case authtypes.LoginTypeToken: typ = &LoginTypeToken{ - UserAPI: userAPI, + UserAPI: useraccountAPI, Config: cfg, } + case authtypes.LoginTypeJwt: + typ = &LoginTypeTokenJwt{ + Config: cfg, + } default: err := util.JSONResponse{ Code: http.StatusBadRequest, diff --git a/clientapi/auth/login_jwt.go b/clientapi/auth/login_jwt.go new file mode 100644 index 000000000..35c7d1948 --- /dev/null +++ b/clientapi/auth/login_jwt.go @@ -0,0 +1,74 @@ +package auth + +import ( + "context" + "fmt" + "net/http" + + "github.com/golang-jwt/jwt/v4" + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/clientapi/httputil" + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/util" +) + +// LoginTypeToken describes how to authenticate with a login token. +type LoginTypeTokenJwt struct { + // UserAPI uapi.LoginTokenInternalAPI + Config *config.ClientAPI +} + +// Name implements Type. +func (t *LoginTypeTokenJwt) Name() string { + return authtypes.LoginTypeJwt +} + +type Claims struct { + jwt.StandardClaims +} + +const mIdUser = "m.id.user" + +// LoginFromJSON implements Type. The cleanup function deletes the token from +// the database on success. +func (t *LoginTypeTokenJwt) LoginFromJSON(ctx context.Context, reqBytes []byte) (*Login, LoginCleanupFunc, *util.JSONResponse) { + var r loginTokenRequest + if err := httputil.UnmarshalJSON(reqBytes, &r); err != nil { + return nil, nil, err + } + + if r.Token == "" { + return nil, nil, &util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("Token field for JWT is missing"), + } + } + c := &Claims{} + token, err := jwt.ParseWithClaims(r.Token, c, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodEd25519); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Method.Alg()) + } + return t.Config.JwtConfig.SecretKey, nil + }) + + if err != nil { + util.GetLogger(ctx).WithError(err).Error("jwt.ParseWithClaims failed") + return nil, nil, &util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("Couldn't parse JWT"), + } + } + + if !token.Valid { + return nil, nil, &util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("Invalid JWT"), + } + } + + r.Login.Identifier.User = c.Subject + r.Login.Identifier.Type = mIdUser + + return &r.Login, func(context.Context, *util.JSONResponse) {}, nil +} diff --git a/clientapi/auth/login_test.go b/clientapi/auth/login_test.go index 5085f0170..4017c26d5 100644 --- a/clientapi/auth/login_test.go +++ b/clientapi/auth/login_test.go @@ -22,6 +22,7 @@ import ( "testing" "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/clientapi/ratelimit" "github.com/matrix-org/dendrite/setup/config" uapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/util" @@ -68,8 +69,11 @@ func TestLoginFromJSONReader(t *testing.T) { Matrix: &config.Global{ ServerName: serverName, }, + RtFailedLogin: ratelimit.RtFailedLoginConfig{ + Enabled: false, + }, } - login, cleanup, err := LoginFromJSONReader(ctx, strings.NewReader(tst.Body), &userAPI, &userAPI, cfg) + login, cleanup, err := LoginFromJSONReader(ctx, strings.NewReader(tst.Body), &userAPI, cfg, nil) if err != nil { t.Fatalf("LoginFromJSONReader failed: %+v", err) } @@ -147,7 +151,7 @@ func TestBadLoginFromJSONReader(t *testing.T) { ServerName: serverName, }, } - _, cleanup, errRes := LoginFromJSONReader(ctx, strings.NewReader(tst.Body), &userAPI, &userAPI, cfg) + _, cleanup, errRes := LoginFromJSONReader(ctx, strings.NewReader(tst.Body), &userAPI, cfg, nil) if errRes == nil { cleanup(ctx, nil) t.Fatalf("LoginFromJSONReader err: got %+v, want code %q", errRes, tst.WantErrCode) @@ -159,6 +163,7 @@ func TestBadLoginFromJSONReader(t *testing.T) { } type fakeUserInternalAPI struct { + uapi.ClientUserAPI UserInternalAPIForLogin DeletedTokens []string } diff --git a/clientapi/auth/login_token.go b/clientapi/auth/login_token.go index 845eb5de9..293b9a460 100644 --- a/clientapi/auth/login_token.go +++ b/clientapi/auth/login_token.go @@ -58,7 +58,7 @@ func (t *LoginTypeToken) LoginFromJSON(ctx context.Context, reqBytes []byte) (*L } } - r.Login.Identifier.Type = "m.id.user" + r.Login.Identifier.Type = mIdUser r.Login.Identifier.User = res.Data.UserID cleanup := func(ctx context.Context, authRes *util.JSONResponse) { diff --git a/clientapi/auth/password.go b/clientapi/auth/password.go index bcb4ca97b..54019a8a8 100644 --- a/clientapi/auth/password.go +++ b/clientapi/auth/password.go @@ -22,6 +22,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/clientapi/ratelimit" "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" @@ -33,12 +34,17 @@ type GetAccountByPassword func(ctx context.Context, req *api.QueryAccountByPassw type PasswordRequest struct { Login Password string `json:"password"` + Address string `json:"address"` + Medium string `json:"medium"` } +const email = "email" + // LoginTypePassword implements https://matrix.org/docs/spec/client_server/r0.6.1#password-based type LoginTypePassword struct { - GetAccountByPassword GetAccountByPassword - Config *config.ClientAPI + UserApi api.ClientUserAPI + Config *config.ClientAPI + Rt *ratelimit.RtFailedLogin } func (t *LoginTypePassword) Name() string { @@ -61,7 +67,35 @@ func (t *LoginTypePassword) LoginFromJSON(ctx context.Context, reqBytes []byte) func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, *util.JSONResponse) { r := req.(*PasswordRequest) - username := strings.ToLower(r.Username()) + if r.Identifier.Address != "" { + r.Address = r.Identifier.Address + } + if r.Identifier.Medium != "" { + r.Medium = r.Identifier.Medium + } + var username string + if r.Medium == email && r.Address != "" { + r.Address = strings.ToLower(r.Address) + res := api.QueryLocalpartForThreePIDResponse{} + err := t.UserApi.QueryLocalpartForThreePID(ctx, &api.QueryLocalpartForThreePIDRequest{ + ThreePID: r.Address, + Medium: email, + }, &res) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("userApi.QueryLocalpartForThreePID failed") + resp := jsonerror.InternalServerError() + return nil, &resp + } + username = res.Localpart + if username == "" { + return nil, &util.JSONResponse{ + Code: http.StatusUnauthorized, + JSON: jsonerror.Forbidden("Invalid username or password"), + } + } + } else { + username = strings.ToLower(r.Username()) + } if username == "" { return nil, &util.JSONResponse{ Code: http.StatusUnauthorized, @@ -77,7 +111,17 @@ func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, } // Squash username to all lowercase letters res := &api.QueryAccountByPasswordResponse{} - err = t.GetAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{Localpart: strings.ToLower(localpart), PlaintextPassword: r.Password}, res) + localpart = strings.ToLower(localpart) + if t.Rt != nil { + ok, retryIn := t.Rt.CanAct(localpart) + if !ok { + return nil, &util.JSONResponse{ + Code: http.StatusTooManyRequests, + JSON: jsonerror.LimitExceeded("Too Many Requests", retryIn.Milliseconds()), + } + } + } + err = t.UserApi.QueryAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{Localpart: localpart, PlaintextPassword: r.Password}, res) if err != nil { return nil, &util.JSONResponse{ Code: http.StatusInternalServerError, @@ -86,7 +130,7 @@ func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, } if !res.Exists { - err = t.GetAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{ + err = t.UserApi.QueryAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{ Localpart: localpart, PlaintextPassword: r.Password, }, res) @@ -99,11 +143,15 @@ func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, // Technically we could tell them if the user does not exist by checking if err == sql.ErrNoRows // but that would leak the existence of the user. if !res.Exists { + if t.Rt != nil { + t.Rt.Act(localpart) + } return nil, &util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("The username or password was incorrect or the account does not exist."), + JSON: jsonerror.Forbidden("Invalid username or password"), } } } + r.Login.User = username return &r.Login, nil } diff --git a/clientapi/auth/user_interactive.go b/clientapi/auth/user_interactive.go index 9971bf8a4..3b2473ea2 100644 --- a/clientapi/auth/user_interactive.go +++ b/clientapi/auth/user_interactive.go @@ -75,7 +75,7 @@ type Login struct { // Username returns the user localpart/user_id in this request, if it exists. func (r *Login) Username() string { - if r.Identifier.Type == "m.id.user" { + if r.Identifier.Type == mIdUser { return r.Identifier.User } // deprecated but without it Element iOS won't log in @@ -88,8 +88,8 @@ func (r *Login) ThirdPartyID() (medium, address string) { return r.Identifier.Medium, r.Identifier.Address } // deprecated - if r.Medium == "email" { - return "email", r.Address + if r.Medium == email { + return email, r.Address } return "", "" } @@ -111,10 +111,10 @@ type UserInteractive struct { Sessions map[string][]string } -func NewUserInteractive(userAccountAPI api.UserLoginAPI, cfg *config.ClientAPI) *UserInteractive { +func NewUserInteractive(userAccountAPI api.ClientUserAPI, cfg *config.ClientAPI) *UserInteractive { typePassword := &LoginTypePassword{ - GetAccountByPassword: userAccountAPI.QueryAccountByPassword, - Config: cfg, + UserApi: userAccountAPI, + Config: cfg, } return &UserInteractive{ Flows: []userInteractiveFlow{ diff --git a/clientapi/auth/user_interactive_test.go b/clientapi/auth/user_interactive_test.go index 001b1a6d4..8267d2222 100644 --- a/clientapi/auth/user_interactive_test.go +++ b/clientapi/auth/user_interactive_test.go @@ -24,7 +24,9 @@ var ( } ) -type fakeAccountDatabase struct{} +type fakeAccountDatabase struct { + api.ClientUserAPI +} func (d *fakeAccountDatabase) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error { return nil diff --git a/clientapi/ratelimit/rt_failed_login.go b/clientapi/ratelimit/rt_failed_login.go new file mode 100644 index 000000000..291af581d --- /dev/null +++ b/clientapi/ratelimit/rt_failed_login.go @@ -0,0 +1,117 @@ +package ratelimit + +import ( + "container/list" + "sync" + "time" +) + +type rateLimit struct { + cfg *RtFailedLoginConfig + times *list.List +} + +type RtFailedLogin struct { + cfg *RtFailedLoginConfig + mtx sync.RWMutex + rts map[string]*rateLimit +} + +type RtFailedLoginConfig struct { + Enabled bool `yaml:"enabled"` + Limit int `yaml:"burst"` + Interval time.Duration `yaml:"interval"` +} + +// New creates a new rate limiter for the limit and interval. +func NewRtFailedLogin(cfg *RtFailedLoginConfig) *RtFailedLogin { + if !cfg.Enabled { + return nil + } + rt := &RtFailedLogin{ + cfg: cfg, + mtx: sync.RWMutex{}, + rts: make(map[string]*rateLimit), + } + go rt.clean() + return rt +} + +// CanAct is expected to be called before Act +func (r *RtFailedLogin) CanAct(key string) (ok bool, remaining time.Duration) { + r.mtx.RLock() + rt, ok := r.rts[key] + if !ok { + r.mtx.RUnlock() + return true, 0 + } + ok, remaining = rt.canAct() + r.mtx.RUnlock() + return +} + +// Act can be called after CanAct returns true. +func (r *RtFailedLogin) Act(key string) { + r.mtx.Lock() + rt, ok := r.rts[key] + if !ok { + rt = &rateLimit{ + cfg: r.cfg, + times: list.New(), + } + r.rts[key] = rt + } + rt.act() + r.mtx.Unlock() +} + +func (r *RtFailedLogin) clean() { + for { + r.mtx.Lock() + for k, v := range r.rts { + if v.empty() { + delete(r.rts, k) + } + } + r.mtx.Unlock() + time.Sleep(time.Hour) + } +} + +func (r *rateLimit) empty() bool { + back := r.times.Back() + if back == nil { + return true + } + v := back.Value + b := v.(time.Time) + now := time.Now() + return now.Sub(b) > r.cfg.Interval +} + +func (r *rateLimit) canAct() (ok bool, remaining time.Duration) { + now := time.Now() + l := r.times.Len() + if l < r.cfg.Limit { + return true, 0 + } + frnt := r.times.Front() + t := frnt.Value.(time.Time) + diff := now.Sub(t) + if diff < r.cfg.Interval { + return false, r.cfg.Interval - diff + } + return true, 0 +} + +func (r *rateLimit) act() { + now := time.Now() + l := r.times.Len() + if l < r.cfg.Limit { + r.times.PushBack(now) + return + } + frnt := r.times.Front() + frnt.Value = now + r.times.MoveToBack(frnt) +} diff --git a/clientapi/ratelimit/rt_failed_login_test.go b/clientapi/ratelimit/rt_failed_login_test.go new file mode 100644 index 000000000..5281bc765 --- /dev/null +++ b/clientapi/ratelimit/rt_failed_login_test.go @@ -0,0 +1,40 @@ +package ratelimit + +import ( + "testing" + "time" + + "github.com/matryer/is" +) + +func TestRtFailedLogin(t *testing.T) { + is := is.New(t) + rtfl := NewRtFailedLogin(&RtFailedLoginConfig{ + Enabled: true, + Limit: 3, + Interval: 10 * time.Millisecond, + }) + var ( + can bool + remaining time.Duration + remainingB time.Duration + ) + for i := 0; i < 3; i++ { + can, remaining = rtfl.CanAct("foo") + is.True(can) + is.Equal(remaining, time.Duration(0)) + rtfl.Act("foo") + } + can, remaining = rtfl.CanAct("foo") + is.True(!can) + is.True(remaining > time.Millisecond*9) + can, remainingB = rtfl.CanAct("bar") + is.True(can) + is.Equal(remainingB, time.Duration(0)) + rtfl.Act("bar") + rtfl.Act("bar") + time.Sleep(remaining + time.Millisecond) + can, remaining = rtfl.CanAct("foo") + is.True(can) + is.Equal(remaining, time.Duration(0)) +} diff --git a/clientapi/routing/deactivate.go b/clientapi/routing/deactivate.go index f213db7f3..9640b7f59 100644 --- a/clientapi/routing/deactivate.go +++ b/clientapi/routing/deactivate.go @@ -27,13 +27,17 @@ func Deactivate( JSON: jsonerror.BadJSON("The request body could not be read: " + err.Error()), } } - - login, errRes := userInteractiveAuth.Verify(ctx, bodyBytes, deviceAPI) - if errRes != nil { - return *errRes + var userId string + if deviceAPI.AccountType != api.AccountTypeAppService { + login, errRes := userInteractiveAuth.Verify(ctx, bodyBytes, deviceAPI) + if errRes != nil { + return *errRes + } + userId = login.Username() + } else { + userId = deviceAPI.UserID } - - localpart, _, err := gomatrixserverlib.SplitID('@', login.Username()) + localpart, _, err := gomatrixserverlib.SplitID('@', userId) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") return jsonerror.InternalServerError() diff --git a/clientapi/routing/getevent.go b/clientapi/routing/getevent.go index 7f5842800..836935ca8 100644 --- a/clientapi/routing/getevent.go +++ b/clientapi/routing/getevent.go @@ -82,7 +82,7 @@ func GetEvent( }}, } var stateResp api.QueryStateAfterEventsResponse - if err := rsAPI.QueryStateAfterEvents(req.Context(), &stateReq, &stateResp); err != nil { + if err = rsAPI.QueryStateAfterEvents(req.Context(), &stateReq, &stateResp); err != nil { util.GetLogger(req.Context()).WithError(err).Error("queryAPI.QueryStateAfterEvents failed") return jsonerror.InternalServerError() } @@ -118,12 +118,13 @@ func GetEvent( } else if !stateEvent.StateKeyEquals(device.UserID) { continue } - membership, err := stateEvent.Membership() + var membership string + membership, err = stateEvent.Membership() if err != nil { util.GetLogger(req.Context()).WithError(err).Error("stateEvent.Membership failed") return jsonerror.InternalServerError() } - if membership == gomatrixserverlib.Join { + if membership == gomatrixserverlib.Join || membership == gomatrixserverlib.Invite { return util.JSONResponse{ Code: http.StatusOK, JSON: gomatrixserverlib.ToClientEvent(r.requestedEvent, gomatrixserverlib.FormatAll), @@ -131,8 +132,28 @@ func GetEvent( } } - return util.JSONResponse{ - Code: http.StatusNotFound, - JSON: jsonerror.NotFound("The event was not found or you do not have permission to read this event"), + // we might fail to retrieve correct state above, let's check user membership and allow to fetch event if they are invited or joined, since we always use m.room.history_visibility shared. + var membershipRes api.QueryMembershipForUserResponse + ctx := req.Context() + err = rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{ + RoomID: roomID, + UserID: device.UserID, + }, &membershipRes) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("Failed to QueryMembershipForUser") + return jsonerror.InternalServerError() + } + // If the user has never been in the room then stop at this point. + // We won't tell the user about a room they have never joined. + if !membershipRes.HasBeenInRoom && membershipRes.Membership != gomatrixserverlib.Invite || membershipRes.Membership == gomatrixserverlib.Ban { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: jsonerror.NotFound("The event was not found or you do not have permission to read this event"), + } + } else { + return util.JSONResponse{ + Code: http.StatusOK, + JSON: gomatrixserverlib.ToClientEvent(r.requestedEvent, gomatrixserverlib.FormatAll), + } } } diff --git a/clientapi/routing/key_crosssigning.go b/clientapi/routing/key_crosssigning.go index 2570db09c..ca6ecefd2 100644 --- a/clientapi/routing/key_crosssigning.go +++ b/clientapi/routing/key_crosssigning.go @@ -63,8 +63,8 @@ func UploadCrossSigningDeviceKeys( } } typePassword := auth.LoginTypePassword{ - GetAccountByPassword: accountAPI.QueryAccountByPassword, - Config: cfg, + UserApi: accountAPI, + Config: cfg, } if _, authErr := typePassword.Login(req.Context(), &uploadReq.Auth.PasswordRequest); authErr != nil { return *authErr diff --git a/clientapi/routing/login.go b/clientapi/routing/login.go index 6017b5840..cae24ce96 100644 --- a/clientapi/routing/login.go +++ b/clientapi/routing/login.go @@ -20,6 +20,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth" "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/clientapi/ratelimit" "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" @@ -55,6 +56,7 @@ func passwordLogin() flows { func Login( req *http.Request, userAPI userapi.ClientUserAPI, cfg *config.ClientAPI, + rt *ratelimit.RtFailedLogin, ) util.JSONResponse { if req.Method == http.MethodGet { // TODO: support other forms of login other than password, depending on config options @@ -63,7 +65,7 @@ func Login( JSON: passwordLogin(), } } else if req.Method == http.MethodPost { - login, cleanup, authErr := auth.LoginFromJSONReader(req.Context(), req.Body, userAPI, userAPI, cfg) + login, cleanup, authErr := auth.LoginFromJSONReader(req.Context(), req.Body, userAPI, cfg, rt) if authErr != nil { return *authErr } diff --git a/clientapi/routing/password.go b/clientapi/routing/password.go index 6dc9af508..44ca153f2 100644 --- a/clientapi/routing/password.go +++ b/clientapi/routing/password.go @@ -1,12 +1,14 @@ package routing import ( + "fmt" "net/http" "github.com/matrix-org/dendrite/clientapi/auth" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/clientapi/threepid" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" @@ -24,6 +26,7 @@ type newPasswordAuth struct { Type string `json:"type"` Session string `json:"session"` auth.PasswordRequest + ThreePidCreds threepid.Credentials `json:"threepid_creds"` } func Password( @@ -33,13 +36,17 @@ func Password( cfg *config.ClientAPI, ) util.JSONResponse { // Check that the existing password is right. + var fields logrus.Fields + if device != nil { + fields = logrus.Fields{ + "sessionId": device.SessionID, + "userId": device.UserID, + } + } var r newPasswordRequest r.LogoutDevices = true - logrus.WithFields(logrus.Fields{ - "sessionId": device.SessionID, - "userId": device.UserID, - }).Debug("Changing password") + logrus.WithFields(fields).Debug("Changing password") // Unmarshal the request. resErr := httputil.UnmarshalJSONRequest(req, &r) @@ -53,45 +60,95 @@ func Password( // Generate a new, random session ID sessionID = util.RandomString(sessionIDLength) } - - // Require password auth to change the password. - if r.Auth.Type != authtypes.LoginTypePassword { - return util.JSONResponse{ - Code: http.StatusUnauthorized, - JSON: newUserInteractiveResponse( - sessionID, - []authtypes.Flow{ - { - Stages: []authtypes.LoginType{authtypes.LoginTypePassword}, - }, + var localpart string + switch r.Auth.Type { + case authtypes.LoginTypePassword: + // Check if the existing password is correct. + typePassword := auth.LoginTypePassword{ + UserApi: userAPI, + Config: cfg, + } + if _, authErr := typePassword.Login(req.Context(), &r.Auth.PasswordRequest); authErr != nil { + return *authErr + } + // Get the local part. + var err error + localpart, _, err = gomatrixserverlib.SplitID('@', device.UserID) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") + return jsonerror.InternalServerError() + } + sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword) + case authtypes.LoginTypeEmail: + threePid := &authtypes.ThreePID{} + r.Auth.ThreePidCreds.IDServer = cfg.ThreePidDelegate + var ( + bound bool + err error + ) + bound, threePid.Address, threePid.Medium, err = threepid.CheckAssociation(req.Context(), r.Auth.ThreePidCreds, cfg) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("threepid.CheckAssociation failed") + return jsonerror.InternalServerError() + } + if !bound { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.MatrixError{ + ErrCode: "M_THREEPID_AUTH_FAILED", + Err: "Failed to auth 3pid", }, - nil, - ), + } + } + var res api.QueryLocalpartForThreePIDResponse + err = userAPI.QueryLocalpartForThreePID(req.Context(), &api.QueryLocalpartForThreePIDRequest{ + Medium: threePid.Medium, + ThreePID: threePid.Address, + }, &res) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryLocalpartForThreePID failed") + return jsonerror.InternalServerError() + } + if res.Localpart == "" { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.MatrixError{ + ErrCode: "M_THREEPID_NOT_FOUND", + Err: "3pid is not bound to any account", + }, + } + } + localpart = res.Localpart + sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeEmail) + default: + flows := []authtypes.Flow{ + { + Stages: []authtypes.LoginType{authtypes.LoginTypePassword}, + }, + } + if cfg.ThreePidDelegate != "" { + flows = append(flows, authtypes.Flow{ + Stages: []authtypes.LoginType{authtypes.LoginTypeEmail}, + }) + } + // Require password auth to change the password. + if r.Auth.Type == authtypes.LoginTypePassword { + return util.JSONResponse{ + Code: http.StatusUnauthorized, + JSON: newUserInteractiveResponse( + sessionID, + flows, + nil, + ), + } } } - // Check if the existing password is correct. - typePassword := auth.LoginTypePassword{ - GetAccountByPassword: userAPI.QueryAccountByPassword, - Config: cfg, - } - if _, authErr := typePassword.Login(req.Context(), &r.Auth.PasswordRequest); authErr != nil { - return *authErr - } - sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword) - // Check the new password strength. if resErr = validatePassword(r.NewPassword); resErr != nil { return *resErr } - // Get the local part. - localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) - if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") - return jsonerror.InternalServerError() - } - // Ask the user API to perform the password change. passwordReq := &api.PerformPasswordUpdateRequest{ Localpart: localpart, @@ -109,11 +166,23 @@ func Password( // If the request asks us to log out all other devices then // ask the user API to do that. + if r.LogoutDevices { - logoutReq := &api.PerformDeviceDeletionRequest{ - UserID: device.UserID, - DeviceIDs: nil, - ExceptDeviceID: device.ID, + var logoutReq *api.PerformDeviceDeletionRequest + var sessionId int64 + if device == nil { + logoutReq = &api.PerformDeviceDeletionRequest{ + UserID: fmt.Sprintf("@%s:%s", localpart, cfg.Matrix.ServerName), + DeviceIDs: []string{}, + } + sessionId = 0 + } else { + logoutReq = &api.PerformDeviceDeletionRequest{ + UserID: device.UserID, + DeviceIDs: nil, + ExceptDeviceID: device.ID, + } + sessionId = device.SessionID } logoutRes := &api.PerformDeviceDeletionResponse{} if err := userAPI.PerformDeviceDeletion(req.Context(), logoutReq, logoutRes); err != nil { @@ -123,7 +192,7 @@ func Password( pushersReq := &api.PerformPusherDeletionRequest{ Localpart: localpart, - SessionID: device.SessionID, + SessionID: sessionId, } if err := userAPI.PerformPusherDeletion(req.Context(), pushersReq, &struct{}{}); err != nil { util.GetLogger(req.Context()).WithError(err).Error("PerformPusherDeletion failed") diff --git a/clientapi/routing/profile.go b/clientapi/routing/profile.go index 0685c7352..af4702fc6 100644 --- a/clientapi/routing/profile.go +++ b/clientapi/routing/profile.go @@ -105,12 +105,6 @@ func SetAvatarURL( if resErr := httputil.UnmarshalJSONRequest(req, &r); resErr != nil { return *resErr } - if r.AvatarURL == "" { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("'avatar_url' must be supplied."), - } - } localpart, _, err := gomatrixserverlib.SplitID('@', userID) if err != nil { diff --git a/clientapi/routing/pusher.go b/clientapi/routing/pusher.go index d6a6eb936..48d319ebd 100644 --- a/clientapi/routing/pusher.go +++ b/clientapi/routing/pusher.go @@ -86,8 +86,8 @@ func SetPusher( if err != nil { return invalidParam("malformed url passed") } - if pushUrl.Scheme != "https" { - return invalidParam("only https scheme is allowed") + if pushUrl.Scheme != "https" && pushUrl.Scheme != "http" { + return invalidParam("only https and http schemes are allowed") } } diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index 0bda1e488..23151883d 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -44,6 +44,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/clientapi/threepid" "github.com/matrix-org/dendrite/clientapi/userutil" userapi "github.com/matrix-org/dendrite/userapi/api" ) @@ -237,6 +238,7 @@ type authDict struct { // Recaptcha Response string `json:"response"` // TODO: Lots of custom keys depending on the type + ThreePidCreds threepid.Credentials `json:"threepid_creds"` } // http://matrix.org/speculator/spec/HEAD/client_server/unstable.html#user-interactive-authentication-api @@ -745,6 +747,7 @@ func handleRegistrationFlow( } } + var threePid *authtypes.ThreePID switch r.Auth.Type { case authtypes.LoginTypeRecaptcha: // Check given captcha response @@ -761,6 +764,29 @@ func handleRegistrationFlow( // Add Dummy to the list of completed registration stages sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeDummy) + case authtypes.LoginTypeEmail: + threePid = &authtypes.ThreePID{} + r.Auth.ThreePidCreds.IDServer = cfg.ThreePidDelegate + var ( + bound bool + err error + ) + bound, threePid.Address, threePid.Medium, err = threepid.CheckAssociation(req.Context(), r.Auth.ThreePidCreds, cfg) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("threepid.CheckAssociation failed") + return jsonerror.InternalServerError() + } + if !bound { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.MatrixError{ + ErrCode: "M_THREEPID_AUTH_FAILED", + Err: "Failed to auth 3pid", + }, + } + } + sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeEmail) + case "": // An empty auth type means that we want to fetch the available // flows. It can also mean that we want to register as an appservice @@ -776,7 +802,7 @@ func handleRegistrationFlow( // A response with current registration flow and remaining available methods // will be returned if a flow has not been successfully completed yet return checkAndCompleteFlow(sessions.getCompletedStages(sessionID), - req, r, sessionID, cfg, userAPI) + req, r, sessionID, cfg, userAPI, threePid) } // handleApplicationServiceRegistration handles the registration of an @@ -818,7 +844,7 @@ func handleApplicationServiceRegistration( // application service registration is entirely separate. return completeRegistration( req.Context(), userAPI, r.Username, "", appserviceID, req.RemoteAddr, req.UserAgent(), r.Auth.Session, - r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeAppService, + r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeAppService, nil, ) } @@ -832,12 +858,13 @@ func checkAndCompleteFlow( sessionID string, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI, + threePid *authtypes.ThreePID, ) util.JSONResponse { if checkFlowCompleted(flow, cfg.Derived.Registration.Flows) { // This flow was completed, registration can continue return completeRegistration( req.Context(), userAPI, r.Username, r.Password, "", req.RemoteAddr, req.UserAgent(), sessionID, - r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeUser, + r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeUser, threePid, ) } sessions.addParams(sessionID, r) @@ -863,6 +890,7 @@ func completeRegistration( inhibitLogin eventutil.WeakBoolean, displayName, deviceID *string, accType userapi.AccountType, + threePid *authtypes.ThreePID, ) util.JSONResponse { if username == "" { return util.JSONResponse{ @@ -901,6 +929,21 @@ func completeRegistration( // Increment prometheus counter for created users amtRegUsers.Inc() + // TODO-entry refuse register if threepid is already bound to account. + if threePid != nil { + err = userAPI.PerformSaveThreePIDAssociation(ctx, &userapi.PerformSaveThreePIDAssociationRequest{ + Medium: threePid.Medium, + ThreePID: threePid.Address, + Localpart: accRes.Account.Localpart, + }, &struct{}{}) + if err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: jsonerror.Unknown("Failed to save 3PID association: " + err.Error()), + } + } + } + // Check whether inhibit_login option is set. If so, don't create an access // token or a device for this user if inhibitLogin { @@ -1092,5 +1135,5 @@ func handleSharedSecretRegistration(cfg *config.ClientAPI, userAPI userapi.Clien if ssrr.Admin { accType = userapi.AccountTypeAdmin } - return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), "", false, &ssrr.User, &deviceID, accType) + return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), "", false, &ssrr.User, &deviceID, accType, nil) } diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index d7a48d228..52178997e 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -26,6 +26,7 @@ import ( clientutil "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/producers" + "github.com/matrix-org/dendrite/clientapi/ratelimit" federationAPI "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/transactions" @@ -65,6 +66,7 @@ func Setup( prometheus.MustRegister(amtRegUsers, sendEventDuration) rateLimits := httputil.NewRateLimits(&cfg.RateLimiting) + rateLimitsFailedLogin := ratelimit.NewRtFailedLogin(&cfg.RtFailedLogin) userInteractiveAuth := auth.NewUserInteractive(userAPI, cfg) unstableFeatures := map[string]bool{ @@ -570,7 +572,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/account/password", - httputil.MakeAuthAPI("password", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeConditionalAuthAPI("password", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req, device); r != nil { return *r } @@ -594,7 +596,7 @@ func Setup( if r := rateLimits.Limit(req, nil); r != nil { return *r } - return Login(req, userAPI, cfg) + return Login(req, userAPI, cfg, rateLimitsFailedLogin) }), ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) diff --git a/clientapi/routing/state.go b/clientapi/routing/state.go index 12984c39a..a687211b6 100644 --- a/clientapi/routing/state.go +++ b/clientapi/routing/state.go @@ -101,7 +101,7 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a } // If the user has never been in the room then stop at this point. // We won't tell the user about a room they have never joined. - if !membershipRes.HasBeenInRoom { + if !membershipRes.HasBeenInRoom && membershipRes.Membership != gomatrixserverlib.Invite { return util.JSONResponse{ Code: http.StatusForbidden, JSON: jsonerror.Forbidden(fmt.Sprintf("Unknown room %q or user %q has never joined this room", roomID, device.UserID)), @@ -241,7 +241,7 @@ func OnIncomingStateTypeRequest( } // If the user has never been in the room then stop at this point. // We won't tell the user about a room they have never joined. - if !membershipRes.HasBeenInRoom || membershipRes.Membership == gomatrixserverlib.Ban { + if !membershipRes.HasBeenInRoom && membershipRes.Membership != gomatrixserverlib.Invite || membershipRes.Membership == gomatrixserverlib.Ban { return util.JSONResponse{ Code: http.StatusForbidden, JSON: jsonerror.Forbidden(fmt.Sprintf("Unknown room %q or user %q has never joined this room", roomID, device.UserID)), diff --git a/clientapi/threepid/threepid.go b/clientapi/threepid/threepid.go index 1e64e3034..a6a469670 100644 --- a/clientapi/threepid/threepid.go +++ b/clientapi/threepid/threepid.go @@ -103,11 +103,8 @@ func CreateSession( func CheckAssociation( ctx context.Context, creds Credentials, cfg *config.ClientAPI, ) (bool, string, string, error) { - if err := isTrusted(creds.IDServer, cfg); err != nil { - return false, "", "", err - } - requestURL := fmt.Sprintf("https://%s/_matrix/identity/api/v1/3pid/getValidated3pid?sid=%s&client_secret=%s", creds.IDServer, creds.SID, creds.Secret) + requestURL := fmt.Sprintf("%s/_matrix/identity/api/v1/3pid/getValidated3pid?sid=%s&client_secret=%s", cfg.ThreePidDelegate, creds.SID, creds.Secret) req, err := http.NewRequest(http.MethodGet, requestURL, nil) if err != nil { return false, "", "", err diff --git a/cmd/dendrite-monolith-server/Dockerfile.dev b/cmd/dendrite-monolith-server/Dockerfile.dev new file mode 100644 index 000000000..7fbf6c667 --- /dev/null +++ b/cmd/dendrite-monolith-server/Dockerfile.dev @@ -0,0 +1,8 @@ +FROM alpine:latest + +COPY dendrite-monolith-server /usr/bin/ + +VOLUME /etc/dendrite +WORKDIR /etc/dendrite + +ENTRYPOINT ["/usr/bin/dendrite-monolith-server"] diff --git a/cmd/dendrite-monolith-server/build_dev.sh b/cmd/dendrite-monolith-server/build_dev.sh new file mode 100755 index 000000000..5d121890a --- /dev/null +++ b/cmd/dendrite-monolith-server/build_dev.sh @@ -0,0 +1,12 @@ +set -xe +if [ -z "$(git status --porcelain)" ]; then + CGO_ENABLED=0 go build . + TAG=$(git rev-parse --short HEAD) + docker build -f Dockerfile.dev -t gcr.io/globekeeper-development/dendrite-monolith:$TAG -t gcr.io/globekeeper-development/dendrite-monolith -t gcr.io/globekeeper-production/dendrite-monolith:$TAG . + docker push gcr.io/globekeeper-development/dendrite-monolith:$TAG + docker push gcr.io/globekeeper-production/dendrite-monolith:$TAG + docker push gcr.io/globekeeper-development/dendrite-monolith +else + echo "Please commit changes" + exit 0 +fi \ No newline at end of file diff --git a/cmd/dendrite-monolith-server/main.go b/cmd/dendrite-monolith-server/main.go index 845b9e465..a864b1185 100644 --- a/cmd/dendrite-monolith-server/main.go +++ b/cmd/dendrite-monolith-server/main.go @@ -16,6 +16,7 @@ package main import ( "flag" + "log" "os" "github.com/matrix-org/dendrite/appservice" @@ -47,6 +48,16 @@ var ( func main() { cfg := setup.ParseFlags(true) httpAddr := config.HTTPAddress("http://" + *httpBindAddr) + for _, logging := range cfg.Logging { + if logging.Type == "std" { + level, err := logrus.ParseLevel(logging.Level) + if err != nil { + log.Fatal(err) + } + logrus.SetLevel(level) + logrus.SetFormatter(&logrus.JSONFormatter{}) + } + } httpsAddr := config.HTTPAddress("https://" + *httpsBindAddr) httpAPIAddr := httpAddr options := []basepkg.BaseDendriteOptions{} diff --git a/cmd/dendritejs-pinecone/jsServer.go b/cmd/dendritejs-pinecone/jsServer.go index 4298c2ae9..a2fc39d42 100644 --- a/cmd/dendritejs-pinecone/jsServer.go +++ b/cmd/dendritejs-pinecone/jsServer.go @@ -34,13 +34,16 @@ type JSServer struct { // OnRequestFromJS is the function that JS will invoke when there is a new request. // The JS function signature is: -// function(reqString: string): Promise<{result: string, error: string}> +// +// function(reqString: string): Promise<{result: string, error: string}> +// // Usage is like: -// const res = await global._go_js_server.fetch(reqString); -// if (res.error) { -// // handle error: this is a 'network' error, not a non-2xx error. -// } -// const rawHttpResponse = res.result; +// +// const res = await global._go_js_server.fetch(reqString); +// if (res.error) { +// // handle error: this is a 'network' error, not a non-2xx error. +// } +// const rawHttpResponse = res.result; func (h *JSServer) OnRequestFromJS(this js.Value, args []js.Value) interface{} { // we HAVE to spawn a new goroutine and return immediately or else Go will deadlock // if this request blocks at all e.g for /sync calls diff --git a/docs/installation/10_optimisation.md b/docs/installation/10_optimisation.md new file mode 100644 index 000000000..c19b7a75e --- /dev/null +++ b/docs/installation/10_optimisation.md @@ -0,0 +1,71 @@ +--- +title: Optimise your installation +parent: Installation +has_toc: true +nav_order: 10 +permalink: /installation/start/optimisation +--- + +# Optimise your installation + +Now that you have Dendrite running, the following tweaks will improve the reliability +and performance of your installation. + +## File descriptor limit + +Most platforms have a limit on how many file descriptors a single process can open. All +connections made by Dendrite consume file descriptors — this includes database connections +and network requests to remote homeservers. When participating in large federated rooms +where Dendrite must talk to many remote servers, it is often very easy to exhaust default +limits which are quite low. + +We currently recommend setting the file descriptor limit to 65535 to avoid such +issues. Dendrite will log immediately after startup if the file descriptor limit is too low: + +``` +level=warning msg="IMPORTANT: Process file descriptor limit is currently 1024, it is recommended to raise the limit for Dendrite to at least 65535 to avoid issues" +``` + +UNIX systems have two limits: a hard limit and a soft limit. You can view the soft limit +by running `ulimit -Sn` and the hard limit with `ulimit -Hn`: + +```bash +$ ulimit -Hn +1048576 + +$ ulimit -Sn +1024 +``` + +Increase the soft limit before starting Dendrite: + +```bash +ulimit -Sn 65535 +``` + +The log line at startup should no longer appear if the limit is sufficient. + +If you are running under a systemd service, you can instead add `LimitNOFILE=65535` option +to the `[Service]` section of your service unit file. + +## DNS caching + +Dendrite has a built-in DNS cache which significantly reduces the load that Dendrite will +place on your DNS resolver. This may also speed up outbound federation. + +Consider enabling the DNS cache by modifying the `global` section of your configuration file: + +```yaml + dns_cache: + enabled: true + cache_size: 4096 + cache_lifetime: 600s +``` + +## Time synchronisation + +Matrix relies heavily on TLS which requires the system time to be correct. If the clock +drifts then you may find that federation no works reliably (or at all) and clients may +struggle to connect to your Dendrite server. + +Ensure that the time is synchronised on your system by enabling NTP sync. diff --git a/go.mod b/go.mod index 8bf8f454d..cfee84d47 100644 --- a/go.mod +++ b/go.mod @@ -4,13 +4,14 @@ require ( github.com/Arceliar/ironwood v0.0.0-20220306165321-319147a02d98 github.com/Arceliar/phony v0.0.0-20210209235338-dde1a8dca979 github.com/DATA-DOG/go-sqlmock v1.5.0 - github.com/MFAshby/stdemuxerhook v1.0.0 github.com/Masterminds/semver/v3 v3.1.1 github.com/codeclysm/extract v2.2.0+incompatible github.com/dgraph-io/ristretto v0.1.1-0.20220403145359-8e850b710d6d github.com/docker/docker v20.10.16+incompatible github.com/docker/go-connections v0.4.0 github.com/getsentry/sentry-go v0.13.0 + github.com/gogo/protobuf v1.3.2 // indirect + github.com/golang-jwt/jwt/v4 v4.4.1 github.com/gologme/log v1.3.0 github.com/google/go-cmp v0.5.8 github.com/google/uuid v1.3.0 @@ -24,6 +25,7 @@ require ( github.com/matrix-org/gomatrixserverlib v0.0.0-20220830164018-c71e518537a2 github.com/matrix-org/pinecone v0.0.0-20220803093810-b7a830c08fb9 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 + github.com/matryer/is v1.4.0 github.com/mattn/go-sqlite3 v1.14.13 github.com/nats-io/nats-server/v2 v2.8.5-0.20220811224153-d8d25d9b0b1c github.com/nats-io/nats.go v1.16.1-0.20220810192301-fb5ca2cbc995 @@ -66,7 +68,6 @@ require ( github.com/frankban/quicktest v1.14.3 // indirect github.com/fsnotify/fsnotify v1.4.9 // indirect github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 // indirect - github.com/gogo/protobuf v1.3.2 // indirect github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b // indirect github.com/golang/protobuf v1.5.2 // indirect github.com/h2non/filetype v1.1.3 // indirect diff --git a/go.sum b/go.sum index 8b8baabcf..959caaf26 100644 --- a/go.sum +++ b/go.sum @@ -52,8 +52,6 @@ github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20O github.com/DATA-DOG/go-sqlmock v1.5.0/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM= github.com/HdrHistogram/hdrhistogram-go v1.1.2 h1:5IcZpTvzydCQeHzK4Ef/D5rrSqwxob0t8PQPMybUNFM= github.com/HdrHistogram/hdrhistogram-go v1.1.2/go.mod h1:yDgFjdqOqDEKOvasDdhWNXYg9BVp4O+o5f6V/ehm6Oo= -github.com/MFAshby/stdemuxerhook v1.0.0 h1:1XFGzakrsHMv76AeanPDL26NOgwjPl/OUxbGhJthwMc= -github.com/MFAshby/stdemuxerhook v1.0.0/go.mod h1:nLMI9FUf9Hz98n+yAXsTMUR4RZQy28uCTLG1Fzvj/uY= github.com/Masterminds/semver/v3 v3.1.1 h1:hLg3sBzpNErnxhQtUy/mmLR2I9foDujNK030IGemrRc= github.com/Masterminds/semver/v3 v3.1.1/go.mod h1:VPu/7SZ7ePZ3QOrcuXROw5FAcLl4a0cBrbBpGY/8hQs= github.com/Microsoft/go-winio v0.5.1 h1:aPJp2QD7OOrhO5tQXqQoGSJc+DjDtWTGLOmNyAm6FgY= @@ -184,6 +182,8 @@ github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/E github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= +github.com/golang-jwt/jwt/v4 v4.4.1 h1:pC5DB52sCeK48Wlb9oPcdhnjkz1TKt1D/P7WKJ0kUcQ= +github.com/golang-jwt/jwt/v4 v4.4.1/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b h1:VKtxabqXZkF25pY9ekfRL6a582T4P37/31XEstQ5p58= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= @@ -350,6 +350,8 @@ github.com/matrix-org/pinecone v0.0.0-20220803093810-b7a830c08fb9/go.mod h1:P4Mq github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= +github.com/matryer/is v1.4.0 h1:sosSmIWwkYITGrxZ25ULNDeKiMNzFSr4V/eqBQP0PeE= +github.com/matryer/is v1.4.0/go.mod h1:8I/i5uYgLzgsgEloJE1U6xx5HkBQpAZvepWuujKwMRU= github.com/mattn/go-colorable v0.1.8/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/mattn/go-isatty v0.0.13/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= diff --git a/internal/httputil/httpapi.go b/internal/httputil/httpapi.go index e0436c60a..f63662ba4 100644 --- a/internal/httputil/httpapi.go +++ b/internal/httputil/httpapi.go @@ -84,6 +84,57 @@ func MakeAuthAPI( return MakeExternalAPI(metricsName, h) } +// MakeConditionalAuthAPI turns a util.JSONRequestHandler function into an http.Handler which authenticates the request. +// It passes nil device if header is not provided. +func MakeConditionalAuthAPI( + metricsName string, userAPI userapi.QueryAcccessTokenAPI, + f func(*http.Request, *userapi.Device) util.JSONResponse, +) http.Handler { + h := func(req *http.Request) util.JSONResponse { + var ( + jsonRes util.JSONResponse + dev *userapi.Device + ) + if _, err := auth.ExtractAccessToken(req); err != nil { + dev = nil + } else { + logger := util.GetLogger(req.Context()) + var err *util.JSONResponse + dev, err = auth.VerifyUserFromRequest(req, userAPI) + if err != nil { + logger.Debugf("VerifyUserFromRequest %s -> HTTP %d", req.RemoteAddr, err.Code) + return *err + } + // add the user ID to the logger + logger = logger.WithField("user_id", dev.UserID) + req = req.WithContext(util.ContextWithLogger(req.Context(), logger)) + } + // add the user to Sentry, if enabled + hub := sentry.GetHubFromContext(req.Context()) + if hub != nil { + hub.Scope().SetTag("user_id", dev.UserID) + hub.Scope().SetTag("device_id", dev.ID) + } + defer func() { + if r := recover(); r != nil { + if hub != nil { + hub.CaptureException(fmt.Errorf("%s panicked", req.URL.Path)) + } + // re-panic to return the 500 + panic(r) + } + }() + jsonRes = f(req, dev) + // do not log 4xx as errors as they are client fails, not server fails + if hub != nil && jsonRes.Code >= 500 { + hub.Scope().SetExtra("response", jsonRes) + hub.CaptureException(fmt.Errorf("%s returned HTTP %d", req.URL.Path, jsonRes.Code)) + } + return jsonRes + } + return MakeExternalAPI(metricsName, h) +} + // MakeAdminAPI is a wrapper around MakeAuthAPI which enforces that the request can only be // completed by a user that is a server administrator. func MakeAdminAPI( diff --git a/internal/log_unix.go b/internal/log_unix.go index 75332af73..5e8dcaad6 100644 --- a/internal/log_unix.go +++ b/internal/log_unix.go @@ -18,10 +18,8 @@ package internal import ( - "io" "log/syslog" - "github.com/MFAshby/stdemuxerhook" "github.com/matrix-org/dendrite/setup/config" "github.com/sirupsen/logrus" lSyslog "github.com/sirupsen/logrus/hooks/syslog" @@ -31,7 +29,6 @@ import ( // If something fails here it means that the logging was improperly configured, // so we just exit with the error func SetupHookLogging(hooks []config.LogrusHook, componentName string) { - stdLogAdded := false for _, hook := range hooks { // Check we received a proper logging level level, err := logrus.ParseLevel(hook.Level) @@ -39,12 +36,6 @@ func SetupHookLogging(hooks []config.LogrusHook, componentName string) { logrus.Fatalf("Unrecognised logging level %s: %q", hook.Level, err) } - // Perform a first filter on the logs according to the lowest level of all - // (Eg: If we have hook for info and above, prevent logrus from processing debug logs) - if logrus.GetLevel() < level { - logrus.SetLevel(level) - } - switch hook.Type { case "file": checkFileHookParams(hook.Params) @@ -53,17 +44,10 @@ func SetupHookLogging(hooks []config.LogrusHook, componentName string) { checkSyslogHookParams(hook.Params) setupSyslogHook(hook, level, componentName) case "std": - setupStdLogHook(level) - stdLogAdded = true default: logrus.Fatalf("Unrecognised logging hook type: %s", hook.Type) } } - if !stdLogAdded { - setupStdLogHook(logrus.InfoLevel) - } - // Hooks are now configured for stdout/err, so throw away the default logger output - logrus.SetOutput(io.Discard) } func checkSyslogHookParams(params map[string]interface{}) { @@ -87,10 +71,6 @@ func checkSyslogHookParams(params map[string]interface{}) { } -func setupStdLogHook(level logrus.Level) { - logrus.AddHook(&logLevelHook{level, stdemuxerhook.New(logrus.StandardLogger())}) -} - func setupSyslogHook(hook config.LogrusHook, level logrus.Level, componentName string) { syslogHook, err := lSyslog.NewSyslogHook(hook.Params["protocol"].(string), hook.Params["address"].(string), syslog.LOG_INFO, componentName) if err == nil { diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go index 483e78c3f..03145ffd0 100644 --- a/roomserver/internal/perform/perform_invite.go +++ b/roomserver/internal/perform/perform_invite.go @@ -136,7 +136,7 @@ func (r *Inviter) PerformInvite( var isAlreadyJoined bool if info != nil { - _, isAlreadyJoined, _, err = r.DB.GetMembership(ctx, info.RoomNID, *event.StateKey()) + _, _, isAlreadyJoined, _, err = r.DB.GetMembership(ctx, info.RoomNID, *event.StateKey()) if err != nil { return nil, fmt.Errorf("r.DB.GetMembership: %w", err) } diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index 6dce2bc3e..2503793d2 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -32,6 +32,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/internal/helpers" "github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/storage" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/version" ) @@ -180,11 +181,16 @@ func (r *Queryer) QueryMembershipForUser( } response.RoomExists = true - membershipEventNID, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, request.UserID) + membershipEventNID, membershipState, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, request.UserID) if err != nil { return err } + if membershipState == tables.MembershipStateInvite { + response.Membership = gomatrixserverlib.Invite + response.IsInRoom = true + } + response.IsRoomForgotten = isRoomforgotten if membershipEventNID == 0 { @@ -294,7 +300,7 @@ func (r *Queryer) QueryMembershipsForRoom( return nil } - membershipEventNID, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, request.Sender) + membershipEventNID, _, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, request.Sender) if err != nil { return err } @@ -907,7 +913,7 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, req *api.Query } // At this point we're happy that we are in the room, so now let's // see if the target user is in the room. - _, isIn, _, err = r.DB.GetMembership(ctx, targetRoomInfo.RoomNID, req.UserID) + _, _, isIn, _, err = r.DB.GetMembership(ctx, targetRoomInfo.RoomNID, req.UserID) if err != nil { continue } diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 43e8da7bb..144b8dd6f 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -127,7 +127,7 @@ type Database interface { // in this room, along a boolean set to true if the user is still in this room, // false if not. // Returns an error if there was a problem talking to the database. - GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderUserID string) (membershipEventNID types.EventNID, stillInRoom, isRoomForgotten bool, err error) + GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderUserID string) (membershipEventNID types.EventNID, membershipNID tables.MembershipState, stillInRoom, isRoomForgotten bool, err error) // Lookup the membership event numeric IDs for all user that are or have // been members of a given room. Only lookup events of "join" membership if // joinOnly is set to true. diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index f35592a76..e602aa3b1 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -399,14 +399,14 @@ func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error { }) } -func (d *Database) GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderUserID string) (membershipEventNID types.EventNID, stillInRoom, isRoomforgotten bool, err error) { +func (d *Database) GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderUserID string) (membershipEventNID types.EventNID, membershipState tables.MembershipState, stillInRoom, isRoomforgotten bool, err error) { var requestSenderUserNID types.EventStateKeyNID err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { requestSenderUserNID, err = d.assignStateKeyNID(ctx, requestSenderUserID) return err }) if err != nil { - return 0, false, false, fmt.Errorf("d.assignStateKeyNID: %w", err) + return 0, 0, false, false, fmt.Errorf("d.assignStateKeyNID: %w", err) } senderMembershipEventNID, senderMembership, isRoomforgotten, err := @@ -415,12 +415,12 @@ func (d *Database) GetMembership(ctx context.Context, roomNID types.RoomNID, req ) if err == sql.ErrNoRows { // The user has never been a member of that room - return 0, false, false, nil + return 0, 0, false, false, nil } else if err != nil { return } - return senderMembershipEventNID, senderMembership == tables.MembershipStateJoin, isRoomforgotten, nil + return senderMembershipEventNID, senderMembership, senderMembership == tables.MembershipStateJoin, isRoomforgotten, nil } func (d *Database) GetMembershipEventNIDsForRoom( diff --git a/setup/base/base.go b/setup/base/base.go index 87f415764..2ccb2ba7f 100644 --- a/setup/base/base.go +++ b/setup/base/base.go @@ -135,7 +135,6 @@ func NewBaseDendrite(cfg *config.Dendrite, componentName string, options ...Base logrus.Fatalf("Failed to start due to configuration errors") } - internal.SetupStdLogging() internal.SetupHookLogging(cfg.Logging, componentName) internal.SetupPprof() diff --git a/setup/config/config.go b/setup/config/config.go index 5a618d671..f6a82681a 100644 --- a/setup/config/config.go +++ b/setup/config/config.go @@ -16,6 +16,7 @@ package config import ( "bytes" + "crypto/x509" "encoding/pem" "fmt" "io" @@ -250,6 +251,15 @@ func loadConfig( c.Global.OldVerifyKeys[i].KeyID, c.Global.OldVerifyKeys[i].PrivateKey = keyID, privateKey } + if c.ClientAPI.JwtConfig.Enabled { + pubPki, _ := pem.Decode([]byte(c.ClientAPI.JwtConfig.Secret)) + var pub interface{} + pub, err = x509.ParsePKIXPublicKey(pubPki.Bytes) + if err != nil { + return nil, err + } + c.ClientAPI.JwtConfig.SecretKey = pub.(ed25519.PublicKey) + } c.MediaAPI.AbsBasePath = Path(absPath(basePath, c.MediaAPI.BasePath)) @@ -289,7 +299,10 @@ func (config *Dendrite) Derive() error { config.Derived.Registration.Flows = append(config.Derived.Registration.Flows, authtypes.Flow{Stages: []authtypes.LoginType{authtypes.LoginTypeDummy}}) } - + if config.ClientAPI.ThreePidDelegate != "" { + config.Derived.Registration.Flows = append(config.Derived.Registration.Flows, + authtypes.Flow{Stages: []authtypes.LoginType{authtypes.LoginTypeEmail}}) + } // Load application service configuration files if err := loadAppServices(&config.AppServiceAPI, &config.Derived); err != nil { return err diff --git a/setup/config/config_appservice.go b/setup/config/config_appservice.go index bd21826fe..706d2dfd2 100644 --- a/setup/config/config_appservice.go +++ b/setup/config/config_appservice.go @@ -21,7 +21,6 @@ import ( "regexp" "strings" - log "github.com/sirupsen/logrus" yaml "gopkg.in/yaml.v2" ) @@ -346,11 +345,11 @@ func checkErrors(config *AppServiceAPI, derived *Derived) (err error) { // TODO: Remove once rate_limited is implemented if appservice.RateLimited { - log.Warn("WARNING: Application service option rate_limited is currently unimplemented") + // log.Warn("WARNING: Application service option rate_limited is currently unimplemented") } // TODO: Remove once protocols is implemented if len(appservice.Protocols) > 0 { - log.Warn("WARNING: Application service option protocols is currently unimplemented") + // log.Warn("WARNING: Application service option protocols is currently unimplemented") } } @@ -376,7 +375,7 @@ func validateNamespace( // Check if GroupID for the users namespace is in the correct format if key == "users" && namespace.GroupID != "" { // TODO: Remove once group_id is implemented - log.Warn("WARNING: Application service option group_id is currently unimplemented") + // log.Warn("WARNING: Application service option group_id is currently unimplemented") correctFormat := groupIDRegexp.MatchString(namespace.GroupID) if !correctFormat { diff --git a/setup/config/config_clientapi.go b/setup/config/config_clientapi.go index 56f4b3f92..9fa6e85c2 100644 --- a/setup/config/config_clientapi.go +++ b/setup/config/config_clientapi.go @@ -3,6 +3,9 @@ package config import ( "fmt" "time" + + "github.com/matrix-org/dendrite/clientapi/ratelimit" + "golang.org/x/crypto/ed25519" ) type ClientAPI struct { @@ -46,9 +49,23 @@ type ClientAPI struct { TURN TURN `yaml:"turn"` // Rate-limiting options - RateLimiting RateLimiting `yaml:"rate_limiting"` + RateLimiting RateLimiting `yaml:"rate_limiting"` + RtFailedLogin ratelimit.RtFailedLoginConfig `yaml:"rate_limiting_failed_login"` MSCs *MSCs `yaml:"-"` + + ThreePidDelegate string `yaml:"three_pid_delegate"` + + JwtConfig JwtConfig `yaml:"jwt_config"` +} + +type JwtConfig struct { + Enabled bool `yaml:"enabled"` + Algorithm string `yaml:"algorithm"` + Issuer string `yaml:"issuer"` + Secret string `yaml:"secret"` + SecretKey ed25519.PublicKey + Audiences []string `yaml:"audiences"` } func (c *ClientAPI) Defaults(opts DefaultOpts) { diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index 0c8ba4e3d..028f123a9 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -173,6 +173,8 @@ type Presence interface { GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error) + ExpirePresence(ctx context.Context) ([]types.PresenceNotify, error) + UpdateLastActive(ctx context.Context, userId string, lastActiveTs uint64) error } type SharedUsers interface { diff --git a/syncapi/storage/postgres/presence_table.go b/syncapi/storage/postgres/presence_table.go index 7194afea6..6f0aa8991 100644 --- a/syncapi/storage/postgres/presence_table.go +++ b/syncapi/storage/postgres/presence_table.go @@ -62,6 +62,10 @@ const upsertPresenceFromSyncSQL = "" + " presence = $2, last_active_ts = $3" + " RETURNING id" +const updateLastActiveSQL = `UPDATE syncapi_presence +SET last_active_ts = $1 +WHERE user_id = $2` + const selectPresenceForUserSQL = "" + "SELECT presence, status_msg, last_active_ts" + " FROM syncapi_presence" + @@ -76,12 +80,24 @@ const selectPresenceAfter = "" + " WHERE id > $1 AND last_active_ts >= $2" + " ORDER BY id ASC LIMIT $3" +const expirePresenceSQL = `UPDATE syncapi_presence SET + id = nextval('syncapi_presence_id'), + presence = 3 +WHERE + to_timestamp(last_active_ts / 1000) < NOW() - INTERVAL` + types.PresenceExpire + ` +AND + presence != 3 +RETURNING id, user_id +` + type presenceStatements struct { upsertPresenceStmt *sql.Stmt upsertPresenceFromSyncStmt *sql.Stmt selectPresenceForUsersStmt *sql.Stmt selectMaxPresenceStmt *sql.Stmt selectPresenceAfterStmt *sql.Stmt + expirePresenceStmt *sql.Stmt + updateLastActiveStmt *sql.Stmt } func NewPostgresPresenceTable(db *sql.DB) (*presenceStatements, error) { @@ -96,6 +112,8 @@ func NewPostgresPresenceTable(db *sql.DB) (*presenceStatements, error) { {&s.selectPresenceForUsersStmt, selectPresenceForUserSQL}, {&s.selectMaxPresenceStmt, selectMaxPresenceSQL}, {&s.selectPresenceAfterStmt, selectPresenceAfter}, + {&s.expirePresenceStmt, expirePresenceSQL}, + {&s.updateLastActiveStmt, updateLastActiveSQL}, }.Prepare(db) } @@ -166,3 +184,28 @@ func (p *presenceStatements) GetPresenceAfter( } return presences, rows.Err() } + +func (p *presenceStatements) ExpirePresence( + ctx context.Context, +) ([]types.PresenceNotify, error) { + rows, err := p.expirePresenceStmt.QueryContext(ctx) + if err != nil { + return nil, err + } + presences := make([]types.PresenceNotify, 0) + i := 0 + for rows.Next() { + presences = append(presences, types.PresenceNotify{}) + err = rows.Scan(&presences[i].StreamPos, &presences[i].UserID) + if err != nil { + return nil, err + } + i++ + } + return presences, err +} + +func (p *presenceStatements) UpdateLastActive(ctx context.Context, userId string, lastActiveTs uint64) error { + _, err := p.updateLastActiveStmt.Exec(&lastActiveTs, &userId) + return err +} diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index b06d2c6a9..80d9b7391 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -1078,3 +1078,11 @@ func (d *Database) MaxStreamPositionForPresence(ctx context.Context) (types.Stre func (d *Database) SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error) { return d.Memberships.SelectMembershipForUser(ctx, nil, roomID, userID, pos) } + +func (s *Database) ExpirePresence(ctx context.Context) ([]types.PresenceNotify, error) { + return s.Presence.ExpirePresence(ctx) +} + +func (s *Database) UpdateLastActive(ctx context.Context, userId string, lastActiveTs uint64) error { + return s.Presence.UpdateLastActive(ctx, userId, lastActiveTs) +} diff --git a/syncapi/storage/sqlite3/presence_table.go b/syncapi/storage/sqlite3/presence_table.go index b61a825df..fe6b3ce84 100644 --- a/syncapi/storage/sqlite3/presence_table.go +++ b/syncapi/storage/sqlite3/presence_table.go @@ -180,3 +180,15 @@ func (p *presenceStatements) GetPresenceAfter( } return presences, rows.Err() } + +func (p *presenceStatements) ExpirePresence( + ctx context.Context, +) ([]types.PresenceNotify, error) { + // TODO implement + return nil, nil +} + +func (p *presenceStatements) UpdateLastActive(ctx context.Context, userId string, lastActiveTs uint64) error { + // TODO implement + return nil +} diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index 468d26aca..3ee4b61dc 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -204,4 +204,6 @@ type Presence interface { GetPresenceForUser(ctx context.Context, txn *sql.Tx, userID string) (presence *types.PresenceInternal, err error) GetMaxPresenceID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) GetPresenceAfter(ctx context.Context, txn *sql.Tx, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (presences map[string]*types.PresenceInternal, err error) + ExpirePresence(ctx context.Context) ([]types.PresenceNotify, error) + UpdateLastActive(ctx context.Context, userId string, lastActiveTs uint64) error } diff --git a/syncapi/streams/stream_presence.go b/syncapi/streams/stream_presence.go index 15db4d30e..637a65042 100644 --- a/syncapi/streams/stream_presence.go +++ b/syncapi/streams/stream_presence.go @@ -17,7 +17,6 @@ package streams import ( "context" "encoding/json" - "sync" "github.com/matrix-org/gomatrixserverlib" "github.com/tidwall/gjson" @@ -28,8 +27,6 @@ import ( type PresenceStreamProvider struct { StreamProvider - // cache contains previously sent presence updates to avoid unneeded updates - cache sync.Map notifier *notifier.Notifier } @@ -105,18 +102,6 @@ func (p *PresenceStreamProvider) IncrementalSync( if req.Device.UserID != presence.UserID && !p.notifier.IsSharedUser(req.Device.UserID, presence.UserID) { continue } - cacheKey := req.Device.UserID + req.Device.ID + presence.UserID - pres, ok := p.cache.Load(cacheKey) - if ok { - // skip already sent presence - prevPresence := pres.(*types.PresenceInternal) - currentlyActive := prevPresence.CurrentlyActive() - skip := prevPresence.Equals(presence) && currentlyActive && req.Device.UserID != presence.UserID - if skip { - req.Log.Tracef("Skipping presence, no change (%s)", presence.UserID) - continue - } - } if _, known := types.PresenceFromString(presence.ClientFields.Presence); known { presence.ClientFields.LastActiveAgo = presence.LastActiveAgo() @@ -144,7 +129,6 @@ func (p *PresenceStreamProvider) IncrementalSync( if len(req.Response.Presence.Events) == req.Filter.Presence.Limit { break } - p.cache.Store(cacheKey, presence) } if len(req.Response.Presence.Events) == 0 { diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index c2c9616e8..d18060a37 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -50,7 +50,7 @@ type RequestPool struct { keyAPI keyapi.SyncKeyAPI rsAPI roomserverAPI.SyncRoomserverAPI lastseen *sync.Map - presence *sync.Map + Presence *sync.Map streams *streams.Streams Notifier *notifier.Notifier producer PresencePublisher @@ -85,14 +85,14 @@ func NewRequestPool( keyAPI: keyAPI, rsAPI: rsAPI, lastseen: &sync.Map{}, - presence: &sync.Map{}, + Presence: &sync.Map{}, streams: streams, Notifier: notifier, producer: producer, consumer: consumer, } go rp.cleanLastSeen() - go rp.cleanPresence(db, time.Minute*5) + // go rp.cleanPresence(db, time.Minute*5) return rp } @@ -111,11 +111,11 @@ func (rp *RequestPool) cleanPresence(db storage.Presence, cleanupTime time.Durat return } for { - rp.presence.Range(func(key interface{}, v interface{}) bool { + rp.Presence.Range(func(key interface{}, v interface{}) bool { p := v.(types.PresenceInternal) if time.Since(p.LastActiveTS.Time()) > cleanupTime { rp.updatePresence(db, types.PresenceUnavailable.String(), p.UserID) - rp.presence.Delete(key) + rp.Presence.Delete(key) } return true }) @@ -153,13 +153,22 @@ func (rp *RequestPool) updatePresence(db storage.Presence, presence string, user } newPresence.ClientFields.Presence = presenceID.String() - defer rp.presence.Store(userID, newPresence) + defer rp.Presence.Store(userID, newPresence) // avoid spamming presence updates when syncing - existingPresence, ok := rp.presence.LoadOrStore(userID, newPresence) + existingPresence, ok := rp.Presence.LoadOrStore(userID, newPresence) if ok { p := existingPresence.(types.PresenceInternal) - if p.ClientFields.Presence == newPresence.ClientFields.Presence { - return + if dbPresence != nil { + if p.Presence == newPresence.Presence && newPresence.LastActiveTS-dbPresence.LastActiveTS < types.PresenceNoOpMs { + return + } + if dbPresence.Presence == types.PresenceOnline && presenceID == types.PresenceOnline && newPresence.LastActiveTS-dbPresence.LastActiveTS >= types.PresenceNoOpMs { + err := db.UpdateLastActive(context.Background(), userID, uint64(newPresence.LastActiveTS)) + if err != nil { + logrus.WithError(err).Error("failed to update last active") + } + return + } } } @@ -247,7 +256,7 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi. defer activeSyncRequests.Dec() rp.updateLastSeen(req, device) - rp.updatePresence(rp.db, req.FormValue("set_presence"), device.UserID) + rp.updatePresence(rp.db, "", device.UserID) waitingSyncRequests.Inc() defer waitingSyncRequests.Dec() diff --git a/syncapi/sync/requestpool_test.go b/syncapi/sync/requestpool_test.go index 3e5769d8c..cdc658331 100644 --- a/syncapi/sync/requestpool_test.go +++ b/syncapi/sync/requestpool_test.go @@ -7,6 +7,7 @@ import ( "time" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -23,7 +24,9 @@ func (d *dummyPublisher) SendPresence(userID string, presence types.Presence, st return nil } -type dummyDB struct{} +type dummyDB struct { + storage.Database +} func (d dummyDB) UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error) { return 0, nil @@ -109,7 +112,7 @@ func TestRequestPool_updatePresence(t *testing.T) { }, } rp := &RequestPool{ - presence: &syncMap, + Presence: &syncMap, producer: publisher, consumer: consumer, cfg: &config.SyncAPI{ diff --git a/syncapi/syncapi.go b/syncapi/syncapi.go index 92db18d56..2fff62d0a 100644 --- a/syncapi/syncapi.go +++ b/syncapi/syncapi.go @@ -16,6 +16,7 @@ package syncapi import ( "context" + "time" "github.com/matrix-org/dendrite/internal/caching" "github.com/sirupsen/logrus" @@ -33,6 +34,7 @@ import ( "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/streams" "github.com/matrix-org/dendrite/syncapi/sync" + "github.com/matrix-org/dendrite/syncapi/types" ) // AddPublicRoutes sets up and registers HTTP handlers for the SyncAPI @@ -144,4 +146,24 @@ func AddPublicRoutes( base.PublicClientAPIMux, requestPool, syncDB, userAPI, rsAPI, cfg, base.Caches, ) + + go func() { + ctx := context.Background() + for { + notify, err := syncDB.ExpirePresence(ctx) + if err != nil { + logrus.WithError(err).Error("failed to expire presence") + } + for i := range notify { + requestPool.Presence.Store(notify[i].UserID, types.PresenceInternal{ + Presence: types.PresenceOffline, + }) + notifier.OnNewPresence(types.StreamingToken{ + PresencePosition: notify[i].StreamPos, + }, notify[i].UserID) + + } + time.Sleep(types.PresenceExpireInterval) + } + }() } diff --git a/syncapi/types/presence.go b/syncapi/types/presence.go index 30e025b9f..760225de8 100644 --- a/syncapi/types/presence.go +++ b/syncapi/types/presence.go @@ -21,6 +21,12 @@ import ( "github.com/matrix-org/gomatrixserverlib" ) +const ( + PresenceNoOpMs = 60_000 + PresenceExpire = "'4 minutes'" + PresenceExpireInterval = time.Second * 30 +) + type Presence uint8 const ( @@ -66,6 +72,11 @@ type PresenceInternal struct { Presence Presence `json:"-"` } +type PresenceNotify struct { + StreamPos StreamPosition + UserID string +} + // Equals compares p1 with p2. func (p1 *PresenceInternal) Equals(p2 *PresenceInternal) bool { return p1.ClientFields.Presence == p2.ClientFields.Presence && diff --git a/syncapi/types/types.go b/syncapi/types/types.go index d75d53ca9..6ccbc8172 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -22,6 +22,7 @@ import ( "strings" "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" "github.com/tidwall/gjson" "github.com/matrix-org/dendrite/roomserver/api" @@ -283,9 +284,11 @@ func NewTopologyTokenFromString(tok string) (token TopologyToken, err error) { func NewStreamTokenFromString(tok string) (token StreamingToken, err error) { if len(tok) < 1 { err = ErrMalformedSyncToken + logrus.WithField("token", tok).Info("invalid stream token: bad length") return } if tok[0] != SyncTokenTypeStream[0] { + logrus.WithField("token", tok).Info("invalid stream token: not starting from s") err = ErrMalformedSyncToken return } @@ -301,6 +304,7 @@ func NewStreamTokenFromString(tok string) (token StreamingToken, err error) { var pos int pos, err = strconv.Atoi(p) if err != nil { + logrus.WithField("token", tok).Info("invalid stream token: strconv") err = ErrMalformedSyncToken return } diff --git a/sytest-blacklist b/sytest-blacklist index bcc345f6e..4b76654b2 100644 --- a/sytest-blacklist +++ b/sytest-blacklist @@ -49,3 +49,18 @@ Notifications can be viewed with GET /notifications If remote user leaves room we no longer receive device updates Guest users can join guest_access rooms + +# You'll be shocked to discover this is flakey too + +Inbound /v1/send_join rejects joins from other servers + +# For notifications extension on iOS + +/event/ does not allow access to events before the user joined + +# Failing after recent updates with presence + +Newly joined room includes presence in incremental sync +User sees their own presence in a sync +User is offline if they set_presence=offline in their sync +User sees updates to presence from other users in the incremental sync. \ No newline at end of file diff --git a/sytest-whitelist b/sytest-whitelist index 5c8896b99..194f33799 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -204,7 +204,6 @@ Deleted tags appear in an incremental v2 /sync /event/ on non world readable room does not work Outbound federation can query profile data /event/ on joined room works -/event/ does not allow access to events before the user joined Federation key API allows unsigned requests for keys GET /publicRooms lists rooms GET /publicRooms includes avatar URLs @@ -741,4 +740,5 @@ Newly joined room includes presence in incremental sync User in private room doesn't appear in user directory User joining then leaving public room appears and dissappears from directory User in remote room doesn't appear in user directory after server left room -User in shared private room does appear in user directory until leave \ No newline at end of file +User in shared private room does appear in user directory until leave +Existing members see new member's presence \ No newline at end of file diff --git a/userapi/api/api.go b/userapi/api/api.go index 66ee9c7c8..f1e30aeda 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -518,7 +518,7 @@ type PerformPusherSetRequest struct { type PerformPusherDeletionRequest struct { Localpart string - SessionID int64 + SessionID int64 // Pusher corresponding to this SessionID will not be deleted } // Pusher represents a push notification subscriber diff --git a/userapi/storage/postgres/accounts_table.go b/userapi/storage/postgres/accounts_table.go index 33fb6dd42..afd1ad410 100644 --- a/userapi/storage/postgres/accounts_table.go +++ b/userapi/storage/postgres/accounts_table.go @@ -98,6 +98,11 @@ func NewPostgresAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerNam Up: deltas.UpAddAccountType, Down: deltas.DownAddAccountType, }, + { + Version: "userapi: no guests", + Up: deltas.UpNoGuests, + Down: deltas.DownNoGuests, + }, }...) err = m.Up(context.Background()) if err != nil { diff --git a/userapi/storage/postgres/deltas/2022080800000000_no_guests.go b/userapi/storage/postgres/deltas/2022080800000000_no_guests.go new file mode 100644 index 000000000..cc6126aad --- /dev/null +++ b/userapi/storage/postgres/deltas/2022080800000000_no_guests.go @@ -0,0 +1,20 @@ +package deltas + +import ( + "context" + "database/sql" + "fmt" +) + +func UpNoGuests(ctx context.Context, tx *sql.Tx) error { + // AddAccountType introduced a bug where each user that had was registered as a regular user, but without user_id, became a guest. + _, err := tx.ExecContext(ctx, "UPDATE account_accounts SET account_type = 1 WHERE account_type = 2;") + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + return nil +} + +func DownNoGuests(ctx context.Context, tx *sql.Tx) error { + return nil +} diff --git a/userapi/storage/postgres/notifications_table.go b/userapi/storage/postgres/notifications_table.go index a27c1125e..d59297c1d 100644 --- a/userapi/storage/postgres/notifications_table.go +++ b/userapi/storage/postgres/notifications_table.go @@ -71,7 +71,7 @@ const selectNotificationSQL = "" + ") AND NOT read ORDER BY localpart, id LIMIT $4" const selectNotificationCountSQL = "" + - "SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND (" + + "SELECT COUNT(DISTINCT(room_id)) FROM userapi_notifications WHERE localpart = $1 AND (" + "(($2 & 1) <> 0 AND highlight) OR (($2 & 2) <> 0 AND NOT highlight)" + ") AND NOT read" diff --git a/userapi/storage/sqlite3/notifications_table.go b/userapi/storage/sqlite3/notifications_table.go index df8260251..52c9fb042 100644 --- a/userapi/storage/sqlite3/notifications_table.go +++ b/userapi/storage/sqlite3/notifications_table.go @@ -71,7 +71,7 @@ const selectNotificationSQL = "" + ") AND NOT read ORDER BY localpart, id LIMIT $4" const selectNotificationCountSQL = "" + - "SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND (" + + "SELECT COUNT(DISTINCT(room_id)) FROM userapi_notifications WHERE localpart = $1 AND (" + "(($2 & 1) <> 0 AND highlight) OR (($2 & 2) <> 0 AND NOT highlight)" + ") AND NOT read" diff --git a/userapi/storage/storage_test.go b/userapi/storage/storage_test.go index a26097338..930392428 100644 --- a/userapi/storage/storage_test.go +++ b/userapi/storage/storage_test.go @@ -520,7 +520,7 @@ func Test_Notification(t *testing.T) { // get notifications count, err := db.GetNotificationCount(ctx, aliceLocalpart, tables.AllNotifications) assert.NoError(t, err, "unable to get notification count") - assert.Equal(t, int64(10), count) + assert.Equal(t, int64(2), count) notifs, count, err := db.GetNotifications(ctx, aliceLocalpart, 0, 15, tables.AllNotifications) assert.NoError(t, err, "unable to get notifications") assert.Equal(t, int64(10), count)