diff --git a/.github/workflows/dendrite.yml b/.github/workflows/dendrite.yml index 033c5864b..17f02b09e 100644 --- a/.github/workflows/dendrite.yml +++ b/.github/workflows/dendrite.yml @@ -6,10 +6,12 @@ on: - main paths: - '**.go' # only execute on changes to go files + - 'go.sum' # or dependency updates - '.github/workflows/**' # or workflow changes pull_request: paths: - '**.go' + - 'go.sum' # or dependency updates - '.github/workflows/**' release: types: [published] @@ -391,7 +393,7 @@ jobs: # See https://github.com/actions/virtual-environments/blob/main/images/linux/Ubuntu2004-Readme.md specifically GOROOT_1_17_X64 run: | sudo apt-get update && sudo apt-get install -y libolm3 libolm-dev - go get -v github.com/gotesttools/gotestfmt/v2/cmd/gotestfmt@latest + go install github.com/gotesttools/gotestfmt/v2/cmd/gotestfmt@latest - name: Run actions/checkout@v3 for dendrite uses: actions/checkout@v3 with: diff --git a/.github/workflows/helm.yml b/.github/workflows/helm.yml index 7cdc369ba..a9c1718a0 100644 --- a/.github/workflows/helm.yml +++ b/.github/workflows/helm.yml @@ -6,6 +6,7 @@ on: - main paths: - 'helm/**' # only execute if we have helm chart changes + workflow_dispatch: jobs: release: diff --git a/.github/workflows/k8s.yml b/.github/workflows/k8s.yml index fc5e8c906..af2750356 100644 --- a/.github/workflows/k8s.yml +++ b/.github/workflows/k8s.yml @@ -84,6 +84,7 @@ jobs: kubectl get pods -A kubectl get services kubectl get ingress + kubectl logs -l app.kubernetes.io/name=dendrite - name: Run create account run: | podName=$(kubectl get pods -l app.kubernetes.io/name=dendrite -o name) diff --git a/.github/workflows/schedules.yaml b/.github/workflows/schedules.yaml index dff9b34c9..e76cc82f3 100644 --- a/.github/workflows/schedules.yaml +++ b/.github/workflows/schedules.yaml @@ -65,10 +65,11 @@ jobs: uses: actions/upload-artifact@v2 if: ${{ always() }} with: - name: Sytest Logs - ${{ job.status }} - (Dendrite, ${{ join(matrix.*, ', ') }}) + name: Sytest Logs - ${{ job.status }} - (Dendrite ${{ join(matrix.*, ' ') }}) path: | /logs/results.tap /logs/**/*.log* + /logs/**/covdatafiles/** sytest-coverage: timeout-minutes: 5 @@ -85,16 +86,15 @@ jobs: cache: true - name: Download all artifacts uses: actions/download-artifact@v3 - - name: Install gocovmerge - run: go install github.com/wadey/gocovmerge@latest - - name: Run gocovmerge + - name: Collect coverage run: | - find -name 'integrationcover.log' -printf '"%p"\n' | xargs gocovmerge | grep -Ev 'relayapi|setup/mscs|api_trace' > sytest.cov - go tool cover -func=sytest.cov + go tool covdata textfmt -i="$(find Sytest* -name 'covmeta*' -type f -exec dirname {} \; | uniq | paste -s -d ',' -)" -o sytest.cov + grep -Ev 'relayapi|setup/mscs|api_trace' sytest.cov > final.cov + go tool covdata func -i="$(find Sytest* -name 'covmeta*' -type f -exec dirname {} \; | uniq | paste -s -d ',' -)" - name: Upload coverage to Codecov uses: codecov/codecov-action@v3 with: - files: ./sytest.cov + files: ./final.cov flags: sytest fail_ci_if_error: true @@ -167,7 +167,7 @@ jobs: cat < /tmp/posttest.sh #!/bin/bash mkdir -p /tmp/Complement/logs/\$2/\$1/ - docker cp \$1:/dendrite/complementcover.log /tmp/Complement/logs/\$2/\$1/ + docker cp \$1:/tmp/covdatafiles/. /tmp/Complement/logs/\$2/\$1/ EOF chmod +x /tmp/posttest.sh @@ -188,9 +188,9 @@ jobs: uses: actions/upload-artifact@v2 if: ${{ always() }} with: - name: Complement Logs - (Dendrite, ${{ join(matrix.*, ', ') }}) + name: Complement Logs - (Dendrite ${{ join(matrix.*, ' ') }}) path: | - /tmp/Complement/**/complementcover.log + /tmp/Complement/logs/** complement-coverage: timeout-minutes: 5 @@ -207,20 +207,19 @@ jobs: cache: true - name: Download all artifacts uses: actions/download-artifact@v3 - - name: Install gocovmerge - run: go install github.com/wadey/gocovmerge@latest - - name: Run gocovmerge + - name: Collect coverage run: | - find -name 'complementcover.log' -printf '"%p"\n' | xargs gocovmerge | grep -Ev 'relayapi|setup/mscs|api_trace' > complement.cov - go tool cover -func=complement.cov + go tool covdata textfmt -i="$(find Complement* -name 'covmeta*' -type f -exec dirname {} \; | uniq | paste -s -d ',' -)" -o complement.cov + grep -Ev 'relayapi|setup/mscs|api_trace' complement.cov > final.cov + go tool covdata func -i="$(find Complement* -name 'covmeta*' -type f -exec dirname {} \; | uniq | paste -s -d ',' -)" - name: Upload coverage to Codecov uses: codecov/codecov-action@v3 with: - files: ./complement.cov + files: ./final.cov flags: complement fail_ci_if_error: true - element_web: + element-web: timeout-minutes: 120 runs-on: ubuntu-latest steps: @@ -258,3 +257,42 @@ jobs: env: PUPPETEER_SKIP_CHROMIUM_DOWNLOAD: true TMPDIR: ${{ runner.temp }} + + element-web-pinecone: + timeout-minutes: 120 + runs-on: ubuntu-latest + steps: + - uses: tecolicom/actions-use-apt-tools@v1 + with: + # Our test suite includes some screenshot tests with unusual diacritics, which are + # supposed to be covered by STIXGeneral. + tools: fonts-stix + - uses: actions/checkout@v2 + with: + repository: matrix-org/matrix-react-sdk + - uses: actions/setup-node@v3 + with: + cache: 'yarn' + - name: Fetch layered build + run: scripts/ci/layered.sh + - name: Copy config + run: cp element.io/develop/config.json config.json + working-directory: ./element-web + - name: Build + env: + CI_PACKAGE: true + NODE_OPTIONS: "--openssl-legacy-provider" + run: yarn build + working-directory: ./element-web + - name: Edit Test Config + run: | + sed -i '/HOMESERVER/c\ HOMESERVER: "dendritePinecone",' cypress.config.ts + - name: "Run cypress tests" + uses: cypress-io/github-action@v4.1.1 + with: + browser: chrome + start: npx serve -p 8080 ./element-web/webapp + wait-on: 'http://localhost:8080' + env: + PUPPETEER_SKIP_CHROMIUM_DOWNLOAD: true + TMPDIR: ${{ runner.temp }} diff --git a/.gitignore b/.gitignore index fe5e82797..dcfbf8007 100644 --- a/.gitignore +++ b/.gitignore @@ -74,3 +74,4 @@ complement/ docs/_site media_store/ +build \ No newline at end of file diff --git a/.golangci.yml b/.golangci.yml index a327370e1..bb8d38a8b 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -179,7 +179,6 @@ linters-settings: linters: enable: - - deadcode - errcheck - goconst - gocyclo @@ -191,10 +190,8 @@ linters: - misspell # Check code comments, whereas misspell in CI checks *.md files - nakedret - staticcheck - - structcheck - unparam - unused - - varcheck enable-all: false disable: - bodyclose diff --git a/CHANGES.md b/CHANGES.md index f0cd2b4b9..8052efd8a 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,5 +1,30 @@ # Changelog +## Dendrite 0.12.0 (2023-03-13) + +### Features + +- The userapi and keyserver have been merged (no actions needed regarding the database) +- The internal NATS JetStream server is now using logrus for logging (contributed by [dvob](https://github.com/dvob)) +- The roomserver database has been refactored to have separate interfaces when working with rooms and events. Also includes increased usage of the cache to avoid database round trips. (database is unchanged) +- The pinecone demo now shuts down more cleanly +- The Helm chart now has the ability to deploy a Grafana chart as well (contributed by [genofire](https://github.com/genofire)) +- Support for listening on unix sockets has been added (contributed by [cyberb](https://github.com/cyberb)) +- The internal NATS server was updated to v2.9.15 +- Initial support for `runtime/trace` has been added, to further track down long-running tasks + +### Fixes + +- The `session_id` is now correctly set when using SQLite +- An issue where device keys could be removed if a device ID is reused has been fixed +- A possible DoS issue related to relations has been fixed (reported by [sleroq](https://github.com/sleroq)) +- When backfilling events, errors are now ignored if we still could fetch events + +### Other + +- **⚠️ DEPRECATION: Polylith/HTTP API mode has been removed** +- The default endpoint to report usages stats to has been updated + ## Dendrite 0.11.1 (2023-02-10) **⚠️ DEPRECATION WARNING: This is the last release to have polylith and HTTP API mode. Future releases are monolith only.** diff --git a/appservice/api/query.go b/appservice/api/query.go index eb567b2ee..472266d9e 100644 --- a/appservice/api/query.go +++ b/appservice/api/query.go @@ -22,8 +22,6 @@ import ( "encoding/json" "errors" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" userapi "github.com/matrix-org/dendrite/userapi/api" ) @@ -150,6 +148,10 @@ type ASLocationResponse struct { Fields json.RawMessage `json:"fields"` } +// ErrProfileNotExists is returned when trying to lookup a user's profile that +// doesn't exist locally. +var ErrProfileNotExists = errors.New("no known profile for given user ID") + // RetrieveUserProfile is a wrapper that queries both the local database and // application services for a given user's profile // TODO: Remove this, it's called from federationapi and clientapi but is a pure function @@ -157,25 +159,11 @@ func RetrieveUserProfile( ctx context.Context, userID string, asAPI AppServiceInternalAPI, - profileAPI userapi.ClientUserAPI, + profileAPI userapi.ProfileAPI, ) (*authtypes.Profile, error) { - localpart, _, err := gomatrixserverlib.SplitID('@', userID) - if err != nil { - return nil, err - } - // Try to query the user from the local database - res := &userapi.QueryProfileResponse{} - err = profileAPI.QueryProfile(ctx, &userapi.QueryProfileRequest{UserID: userID}, res) - if err != nil { - return nil, err - } - profile := &authtypes.Profile{ - Localpart: localpart, - DisplayName: res.DisplayName, - AvatarURL: res.AvatarURL, - } - if res.UserExists { + profile, err := profileAPI.QueryProfile(ctx, userID) + if err == nil { return profile, nil } @@ -188,19 +176,15 @@ func RetrieveUserProfile( // If no user exists, return if !userResp.UserIDExists { - return nil, errors.New("no known profile for given user ID") + return nil, ErrProfileNotExists } // Try to query the user from the local database again - err = profileAPI.QueryProfile(ctx, &userapi.QueryProfileRequest{UserID: userID}, res) + profile, err = profileAPI.QueryProfile(ctx, userID) if err != nil { return nil, err } // profile should not be nil at this point - return &authtypes.Profile{ - Localpart: localpart, - DisplayName: res.DisplayName, - AvatarURL: res.AvatarURL, - }, nil + return profile, nil } diff --git a/appservice/appservice.go b/appservice/appservice.go index 5b1b93de2..d94a483e0 100644 --- a/appservice/appservice.go +++ b/appservice/appservice.go @@ -16,20 +16,17 @@ package appservice import ( "context" - "crypto/tls" - "net/http" "sync" - "time" + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/sirupsen/logrus" - "github.com/matrix-org/gomatrixserverlib" - appserviceAPI "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/appservice/consumers" "github.com/matrix-org/dendrite/appservice/query" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" ) @@ -37,39 +34,31 @@ import ( // NewInternalAPI returns a concerete implementation of the internal API. Callers // can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. func NewInternalAPI( - base *base.BaseDendrite, + processContext *process.ProcessContext, + cfg *config.Dendrite, + natsInstance *jetstream.NATSInstance, userAPI userapi.AppserviceUserAPI, rsAPI roomserverAPI.RoomserverInternalAPI, ) appserviceAPI.AppServiceInternalAPI { - client := &http.Client{ - Timeout: time.Second * 30, - Transport: &http.Transport{ - DisableKeepAlives: true, - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: base.Cfg.AppServiceAPI.DisableTLSValidation, - }, - Proxy: http.ProxyFromEnvironment, - }, - } + // Create appserivce query API with an HTTP client that will be used for all // outbound and inbound requests (inbound only for the internal API) appserviceQueryAPI := &query.AppServiceQueryAPI{ - HTTPClient: client, - Cfg: &base.Cfg.AppServiceAPI, + Cfg: &cfg.AppServiceAPI, ProtocolCache: map[string]appserviceAPI.ASProtocolResponse{}, CacheMu: sync.Mutex{}, } - if len(base.Cfg.Derived.ApplicationServices) == 0 { + if len(cfg.Derived.ApplicationServices) == 0 { return appserviceQueryAPI } // Wrap application services in a type that relates the application service and // a sync.Cond object that can be used to notify workers when there are new // events to be sent out. - for _, appservice := range base.Cfg.Derived.ApplicationServices { + for _, appservice := range cfg.Derived.ApplicationServices { // Create bot account for this AS if it doesn't already exist - if err := generateAppServiceAccount(userAPI, appservice, base.Cfg.Global.ServerName); err != nil { + if err := generateAppServiceAccount(userAPI, appservice, cfg.Global.ServerName); err != nil { logrus.WithFields(logrus.Fields{ "appservice": appservice.ID, }).WithError(err).Panicf("failed to generate bot account for appservice") @@ -78,10 +67,10 @@ func NewInternalAPI( // Only consume if we actually have ASes to track, else we'll just chew cycles needlessly. // We can't add ASes at runtime so this is safe to do. - js, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) + js, _ := natsInstance.Prepare(processContext, &cfg.Global.JetStream) consumer := consumers.NewOutputRoomEventConsumer( - base.ProcessContext, &base.Cfg.AppServiceAPI, - client, js, rsAPI, + processContext, &cfg.AppServiceAPI, + js, rsAPI, ) if err := consumer.Start(); err != nil { logrus.WithError(err).Panicf("failed to start appservice roomserver consumer") @@ -96,7 +85,7 @@ func NewInternalAPI( func generateAppServiceAccount( userAPI userapi.AppserviceUserAPI, as config.ApplicationService, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, ) error { var accRes userapi.PerformAccountCreationResponse err := userAPI.PerformAccountCreation(context.Background(), &userapi.PerformAccountCreationRequest{ diff --git a/appservice/appservice_test.go b/appservice/appservice_test.go index de9f5aaf1..878ca5666 100644 --- a/appservice/appservice_test.go +++ b/appservice/appservice_test.go @@ -3,19 +3,31 @@ package appservice_test import ( "context" "encoding/json" + "fmt" + "net" "net/http" "net/http/httptest" + "path" "reflect" "regexp" "strings" "testing" + "time" + + "github.com/stretchr/testify/assert" "github.com/matrix-org/dendrite/appservice" "github.com/matrix-org/dendrite/appservice/api" + "github.com/matrix-org/dendrite/appservice/consumers" + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver" + rsapi "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/userapi" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/test/testrig" ) @@ -104,34 +116,138 @@ func TestAppserviceInternalAPI(t *testing.T) { } test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - base, closeBase := testrig.CreateBaseDendrite(t, dbType) - defer closeBase() + cfg, ctx, close := testrig.CreateConfig(t, dbType) + defer close() // Create a dummy application service - base.Cfg.AppServiceAPI.Derived.ApplicationServices = []config.ApplicationService{ - { - ID: "someID", - URL: srv.URL, - ASToken: "", - HSToken: "", - SenderLocalpart: "senderLocalPart", - NamespaceMap: map[string][]config.ApplicationServiceNamespace{ - "users": {{RegexpObject: regexp.MustCompile("as-.*")}}, - "aliases": {{RegexpObject: regexp.MustCompile("asroom-.*")}}, - }, - Protocols: []string{existingProtocol}, + as := &config.ApplicationService{ + ID: "someID", + URL: srv.URL, + ASToken: "", + HSToken: "", + SenderLocalpart: "senderLocalPart", + NamespaceMap: map[string][]config.ApplicationServiceNamespace{ + "users": {{RegexpObject: regexp.MustCompile("as-.*")}}, + "aliases": {{RegexpObject: regexp.MustCompile("asroom-.*")}}, }, + Protocols: []string{existingProtocol}, } + as.CreateHTTPClient(cfg.AppServiceAPI.DisableTLSValidation) + cfg.AppServiceAPI.Derived.ApplicationServices = []config.ApplicationService{*as} + t.Cleanup(func() { + ctx.ShutdownDendrite() + ctx.WaitForShutdown() + }) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) // Create required internal APIs - rsAPI := roomserver.NewInternalAPI(base) - usrAPI := userapi.NewInternalAPI(base, rsAPI, nil) - asAPI := appservice.NewInternalAPI(base, usrAPI, rsAPI) + natsInstance := jetstream.NATSInstance{} + cm := sqlutil.NewConnectionManager(ctx, cfg.Global.DatabaseOptions) + rsAPI := roomserver.NewInternalAPI(ctx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) + usrAPI := userapi.NewInternalAPI(ctx, cfg, cm, &natsInstance, rsAPI, nil) + asAPI := appservice.NewInternalAPI(ctx, cfg, &natsInstance, usrAPI, rsAPI) runCases(t, asAPI) }) } +func TestAppserviceInternalAPI_UnixSocket_Simple(t *testing.T) { + + // Set expected results + existingProtocol := "irc" + wantLocationResponse := []api.ASLocationResponse{{Protocol: existingProtocol, Fields: []byte("{}")}} + wantUserResponse := []api.ASUserResponse{{Protocol: existingProtocol, Fields: []byte("{}")}} + wantProtocolResponse := api.ASProtocolResponse{Instances: []api.ProtocolInstance{{Fields: []byte("{}")}}} + + // create a dummy AS url, handling some cases + srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "location"): + // Check if we've got an existing protocol, if so, return a proper response. + if r.URL.Path[len(r.URL.Path)-len(existingProtocol):] == existingProtocol { + if err := json.NewEncoder(w).Encode(wantLocationResponse); err != nil { + t.Fatalf("failed to encode response: %s", err) + } + return + } + if err := json.NewEncoder(w).Encode([]api.ASLocationResponse{}); err != nil { + t.Fatalf("failed to encode response: %s", err) + } + return + case strings.Contains(r.URL.Path, "user"): + if r.URL.Path[len(r.URL.Path)-len(existingProtocol):] == existingProtocol { + if err := json.NewEncoder(w).Encode(wantUserResponse); err != nil { + t.Fatalf("failed to encode response: %s", err) + } + return + } + if err := json.NewEncoder(w).Encode([]api.UserResponse{}); err != nil { + t.Fatalf("failed to encode response: %s", err) + } + return + case strings.Contains(r.URL.Path, "protocol"): + if r.URL.Path[len(r.URL.Path)-len(existingProtocol):] == existingProtocol { + if err := json.NewEncoder(w).Encode(wantProtocolResponse); err != nil { + t.Fatalf("failed to encode response: %s", err) + } + return + } + if err := json.NewEncoder(w).Encode(nil); err != nil { + t.Fatalf("failed to encode response: %s", err) + } + return + default: + t.Logf("hit location: %s", r.URL.Path) + } + })) + + tmpDir := t.TempDir() + socket := path.Join(tmpDir, "socket") + l, err := net.Listen("unix", socket) + assert.NoError(t, err) + _ = srv.Listener.Close() + srv.Listener = l + srv.Start() + defer srv.Close() + + cfg, ctx, tearDown := testrig.CreateConfig(t, test.DBTypeSQLite) + defer tearDown() + + // Create a dummy application service + as := &config.ApplicationService{ + ID: "someID", + URL: fmt.Sprintf("unix://%s", socket), + ASToken: "", + HSToken: "", + SenderLocalpart: "senderLocalPart", + NamespaceMap: map[string][]config.ApplicationServiceNamespace{ + "users": {{RegexpObject: regexp.MustCompile("as-.*")}}, + "aliases": {{RegexpObject: regexp.MustCompile("asroom-.*")}}, + }, + Protocols: []string{existingProtocol}, + } + as.CreateHTTPClient(cfg.AppServiceAPI.DisableTLSValidation) + cfg.AppServiceAPI.Derived.ApplicationServices = []config.ApplicationService{*as} + + t.Cleanup(func() { + ctx.ShutdownDendrite() + ctx.WaitForShutdown() + }) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + // Create required internal APIs + natsInstance := jetstream.NATSInstance{} + cm := sqlutil.NewConnectionManager(ctx, cfg.Global.DatabaseOptions) + rsAPI := roomserver.NewInternalAPI(ctx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) + usrAPI := userapi.NewInternalAPI(ctx, cfg, cm, &natsInstance, rsAPI, nil) + asAPI := appservice.NewInternalAPI(ctx, cfg, &natsInstance, usrAPI, rsAPI) + + t.Run("UserIDExists", func(t *testing.T) { + testUserIDExists(t, asAPI, "@as-testing:test", true) + testUserIDExists(t, asAPI, "@as1-testing:test", false) + }) + +} + func testUserIDExists(t *testing.T, asAPI api.AppServiceInternalAPI, userID string, wantExists bool) { ctx := context.Background() userResp := &api.UserIDExistsResponse{} @@ -201,3 +317,87 @@ func testProtocol(t *testing.T, asAPI api.AppServiceInternalAPI, proto string, w t.Errorf("unexpected result for Protocols(%s): %+v, expected %+v", proto, protoResp.Protocols[proto], wantResult) } } + +// Tests that the roomserver consumer only receives one invite +func TestRoomserverConsumerOneInvite(t *testing.T) { + + alice := test.NewUser(t) + bob := test.NewUser(t) + room := test.NewRoom(t, alice) + + // Invite Bob + room.CreateAndInsert(t, alice, spec.MRoomMember, map[string]interface{}{ + "membership": "invite", + }, test.WithStateKey(bob.ID)) + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, closeDB := testrig.CreateConfig(t, dbType) + defer closeDB() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + natsInstance := &jetstream.NATSInstance{} + + evChan := make(chan struct{}) + // create a dummy AS url, handling the events + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var txn consumers.ApplicationServiceTransaction + err := json.NewDecoder(r.Body).Decode(&txn) + if err != nil { + t.Fatal(err) + } + for _, ev := range txn.Events { + if ev.Type != spec.MRoomMember { + continue + } + // Usually we would check the event content for the membership, but since + // we only invited bob, this should be fine for this test. + if ev.StateKey != nil && *ev.StateKey == bob.ID { + evChan <- struct{}{} + } + } + })) + defer srv.Close() + + as := &config.ApplicationService{ + ID: "someID", + URL: srv.URL, + ASToken: "", + HSToken: "", + SenderLocalpart: "senderLocalPart", + NamespaceMap: map[string][]config.ApplicationServiceNamespace{ + "users": {{RegexpObject: regexp.MustCompile(bob.ID)}}, + "aliases": {{RegexpObject: regexp.MustCompile(room.ID)}}, + }, + } + as.CreateHTTPClient(cfg.AppServiceAPI.DisableTLSValidation) + + // Create a dummy application service + cfg.AppServiceAPI.Derived.ApplicationServices = []config.ApplicationService{*as} + + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + // Create required internal APIs + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, natsInstance, caches, caching.DisableMetrics) + rsAPI.SetFederationAPI(nil, nil) + usrAPI := userapi.NewInternalAPI(processCtx, cfg, cm, natsInstance, rsAPI, nil) + // start the consumer + appservice.NewInternalAPI(processCtx, cfg, natsInstance, usrAPI, rsAPI) + + // Create the room + if err := rsapi.SendEvents(context.Background(), rsAPI, rsapi.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { + t.Fatalf("failed to send events: %v", err) + } + var seenInvitesForBob int + waitLoop: + for { + select { + case <-time.After(time.Millisecond * 50): // wait for the AS to process the events + break waitLoop + case <-evChan: + seenInvitesForBob++ + if seenInvitesForBob != 1 { + t.Fatalf("received unexpected invites: %d", seenInvitesForBob) + } + } + } + close(evChan) + }) +} diff --git a/appservice/consumers/roomserver.go b/appservice/consumers/roomserver.go index 528de63e8..8ca8220d0 100644 --- a/appservice/consumers/roomserver.go +++ b/appservice/consumers/roomserver.go @@ -26,21 +26,28 @@ import ( "time" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/nats-io/nats.go" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/syncapi/synctypes" log "github.com/sirupsen/logrus" ) +// ApplicationServiceTransaction is the transaction that is sent off to an +// application service. +type ApplicationServiceTransaction struct { + Events []synctypes.ClientEvent `json:"events"` +} + // OutputRoomEventConsumer consumes events that originated in the room server. type OutputRoomEventConsumer struct { ctx context.Context cfg *config.AppServiceAPI - client *http.Client jetstream nats.JetStreamContext topic string rsAPI api.AppserviceRoomserverAPI @@ -56,14 +63,12 @@ type appserviceState struct { func NewOutputRoomEventConsumer( process *process.ProcessContext, cfg *config.AppServiceAPI, - client *http.Client, js nats.JetStreamContext, rsAPI api.AppserviceRoomserverAPI, ) *OutputRoomEventConsumer { return &OutputRoomEventConsumer{ ctx: process.Context(), cfg: cfg, - client: client, jetstream: js, topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputRoomEvent), rsAPI: rsAPI, @@ -140,12 +145,6 @@ func (s *OutputRoomEventConsumer) onMessage( } } - case api.OutputTypeNewInviteEvent: - if output.NewInviteEvent == nil || !s.appserviceIsInterestedInEvent(ctx, output.NewInviteEvent.Event, state.ApplicationService) { - continue - } - events = append(events, output.NewInviteEvent.Event) - default: continue } @@ -180,8 +179,8 @@ func (s *OutputRoomEventConsumer) sendEvents( ) error { // Create the transaction body. transaction, err := json.Marshal( - gomatrixserverlib.ApplicationServiceTransaction{ - Events: gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatAll), + ApplicationServiceTransaction{ + Events: synctypes.HeaderedToClientEvents(events, synctypes.FormatAll), }, ) if err != nil { @@ -195,13 +194,13 @@ func (s *OutputRoomEventConsumer) sendEvents( // Send the transaction to the appservice. // https://matrix.org/docs/spec/application_service/r0.1.2#put-matrix-app-v1-transactions-txnid - address := fmt.Sprintf("%s/transactions/%s?access_token=%s", state.URL, txnID, url.QueryEscape(state.HSToken)) + address := fmt.Sprintf("%s/transactions/%s?access_token=%s", state.RequestUrl(), txnID, url.QueryEscape(state.HSToken)) req, err := http.NewRequestWithContext(ctx, "PUT", address, bytes.NewBuffer(transaction)) if err != nil { return err } req.Header.Set("Content-Type", "application/json") - resp, err := s.client.Do(req) + resp, err := state.HTTPClient.Do(req) if err != nil { return state.backoffAndPause(err) } @@ -212,7 +211,7 @@ func (s *OutputRoomEventConsumer) sendEvents( case http.StatusOK: state.backoff = 0 default: - return state.backoffAndPause(fmt.Errorf("received HTTP status code %d from appservice", resp.StatusCode)) + return state.backoffAndPause(fmt.Errorf("received HTTP status code %d from appservice url %s", resp.StatusCode, address)) } return nil } @@ -242,7 +241,7 @@ func (s *OutputRoomEventConsumer) appserviceIsInterestedInEvent(ctx context.Cont return true } - if event.Type() == gomatrixserverlib.MRoomMember && event.StateKey() != nil { + if event.Type() == spec.MRoomMember && event.StateKey() != nil { if appservice.IsInterestedInUserID(*event.StateKey()) { return true } @@ -288,7 +287,7 @@ func (s *OutputRoomEventConsumer) appserviceJoinedAtEvent(ctx context.Context, e switch { case ev.StateKey == nil: continue - case ev.Type != gomatrixserverlib.MRoomMember: + case ev.Type != spec.MRoomMember: continue } var membership gomatrixserverlib.MemberContent @@ -296,7 +295,7 @@ func (s *OutputRoomEventConsumer) appserviceJoinedAtEvent(ctx context.Context, e switch { case err != nil: continue - case membership.Membership == gomatrixserverlib.Join: + case membership.Membership == spec.Join: if appservice.IsInterestedInUserID(*ev.StateKey) { return true } diff --git a/appservice/query/query.go b/appservice/query/query.go index 2348eab4b..ca8d7b3a3 100644 --- a/appservice/query/query.go +++ b/appservice/query/query.go @@ -25,10 +25,10 @@ import ( "strings" "sync" - "github.com/opentracing/opentracing-go" log "github.com/sirupsen/logrus" "github.com/matrix-org/dendrite/appservice/api" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/setup/config" ) @@ -37,7 +37,6 @@ const userIDExistsPath = "/users/" // AppServiceQueryAPI is an implementation of api.AppServiceQueryAPI type AppServiceQueryAPI struct { - HTTPClient *http.Client Cfg *config.AppServiceAPI ProtocolCache map[string]api.ASProtocolResponse CacheMu sync.Mutex @@ -50,14 +49,14 @@ func (a *AppServiceQueryAPI) RoomAliasExists( request *api.RoomAliasExistsRequest, response *api.RoomAliasExistsResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "ApplicationServiceRoomAlias") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "ApplicationServiceRoomAlias") + defer trace.EndRegion() // Determine which application service should handle this request for _, appservice := range a.Cfg.Derived.ApplicationServices { if appservice.URL != "" && appservice.IsInterestedInRoomAlias(request.Alias) { // The full path to the rooms API, includes hs token - URL, err := url.Parse(appservice.URL + roomAliasExistsPath) + URL, err := url.Parse(appservice.RequestUrl() + roomAliasExistsPath) if err != nil { return err } @@ -73,7 +72,7 @@ func (a *AppServiceQueryAPI) RoomAliasExists( } req = req.WithContext(ctx) - resp, err := a.HTTPClient.Do(req) + resp, err := appservice.HTTPClient.Do(req) if resp != nil { defer func() { err = resp.Body.Close() @@ -117,14 +116,14 @@ func (a *AppServiceQueryAPI) UserIDExists( request *api.UserIDExistsRequest, response *api.UserIDExistsResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "ApplicationServiceUserID") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "ApplicationServiceUserID") + defer trace.EndRegion() // Determine which application service should handle this request for _, appservice := range a.Cfg.Derived.ApplicationServices { if appservice.URL != "" && appservice.IsInterestedInUserID(request.UserID) { // The full path to the rooms API, includes hs token - URL, err := url.Parse(appservice.URL + userIDExistsPath) + URL, err := url.Parse(appservice.RequestUrl() + userIDExistsPath) if err != nil { return err } @@ -137,7 +136,7 @@ func (a *AppServiceQueryAPI) UserIDExists( if err != nil { return err } - resp, err := a.HTTPClient.Do(req.WithContext(ctx)) + resp, err := appservice.HTTPClient.Do(req.WithContext(ctx)) if resp != nil { defer func() { err = resp.Body.Close() @@ -212,12 +211,12 @@ func (a *AppServiceQueryAPI) Locations( var asLocations []api.ASLocationResponse params.Set("access_token", as.HSToken) - url := as.URL + api.ASLocationPath + url := as.RequestUrl() + api.ASLocationPath if req.Protocol != "" { url += "/" + req.Protocol } - if err := requestDo[[]api.ASLocationResponse](a.HTTPClient, url+"?"+params.Encode(), &asLocations); err != nil { + if err := requestDo[[]api.ASLocationResponse](as.HTTPClient, url+"?"+params.Encode(), &asLocations); err != nil { log.WithError(err).Error("unable to get 'locations' from application service") continue } @@ -247,12 +246,12 @@ func (a *AppServiceQueryAPI) User( var asUsers []api.ASUserResponse params.Set("access_token", as.HSToken) - url := as.URL + api.ASUserPath + url := as.RequestUrl() + api.ASUserPath if req.Protocol != "" { url += "/" + req.Protocol } - if err := requestDo[[]api.ASUserResponse](a.HTTPClient, url+"?"+params.Encode(), &asUsers); err != nil { + if err := requestDo[[]api.ASUserResponse](as.HTTPClient, url+"?"+params.Encode(), &asUsers); err != nil { log.WithError(err).Error("unable to get 'user' from application service") continue } @@ -290,7 +289,7 @@ func (a *AppServiceQueryAPI) Protocols( response := api.ASProtocolResponse{} for _, as := range a.Cfg.Derived.ApplicationServices { var proto api.ASProtocolResponse - if err := requestDo[api.ASProtocolResponse](a.HTTPClient, as.URL+api.ASProtocolPath+req.Protocol, &proto); err != nil { + if err := requestDo[api.ASProtocolResponse](as.HTTPClient, as.RequestUrl()+api.ASProtocolPath+req.Protocol, &proto); err != nil { log.WithError(err).Error("unable to get 'protocol' from application service") continue } @@ -320,7 +319,7 @@ func (a *AppServiceQueryAPI) Protocols( for _, as := range a.Cfg.Derived.ApplicationServices { for _, p := range as.Protocols { var proto api.ASProtocolResponse - if err := requestDo[api.ASProtocolResponse](a.HTTPClient, as.URL+api.ASProtocolPath+p, &proto); err != nil { + if err := requestDo[api.ASProtocolResponse](as.HTTPClient, as.RequestUrl()+api.ASProtocolPath+p, &proto); err != nil { log.WithError(err).Error("unable to get 'protocol' from application service") continue } diff --git a/build/dendritejs-pinecone/main.go b/build/dendritejs-pinecone/main.go index 44e52286f..61baed902 100644 --- a/build/dendritejs-pinecone/main.go +++ b/build/dendritejs-pinecone/main.go @@ -29,11 +29,14 @@ import ( "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/rooms" "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" "github.com/matrix-org/dendrite/federationapi" + "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/setup" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/userapi" "github.com/matrix-org/gomatrixserverlib" @@ -157,9 +160,8 @@ func startup() { pManager.AddPeer("wss://pinecone.matrix.org/public") cfg := &config.Dendrite{} - cfg.Defaults(true) + cfg.Defaults(config.DefaultOpts{Generate: true, SingleDatabase: false}) cfg.UserAPI.AccountDatabase.ConnectionString = "file:/idb/dendritejs_account.db" - cfg.AppServiceAPI.Database.ConnectionString = "file:/idb/dendritejs_appservice.db" cfg.FederationAPI.Database.ConnectionString = "file:/idb/dendritejs_fedsender.db" cfg.MediaAPI.Database.ConnectionString = "file:/idb/dendritejs_mediaapi.db" cfg.RoomServer.Database.ConnectionString = "file:/idb/dendritejs_roomserver.db" @@ -169,35 +171,37 @@ func startup() { cfg.Global.TrustedIDServers = []string{} cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID) cfg.Global.PrivateKey = sk - cfg.Global.ServerName = gomatrixserverlib.ServerName(hex.EncodeToString(pk)) + cfg.Global.ServerName = spec.ServerName(hex.EncodeToString(pk)) cfg.ClientAPI.RegistrationDisabled = false cfg.ClientAPI.OpenRegistrationWithoutVerificationEnabled = true if err := cfg.Derive(); err != nil { logrus.Fatalf("Failed to derive values from config: %s", err) } - base := base.NewBaseDendrite(cfg) - defer base.Close() // nolint: errcheck + natsInstance := jetstream.NATSInstance{} + processCtx := process.NewProcessContext() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + routers := httputil.NewRouters() + caches := caching.NewRistrettoCache(cfg.Global.Cache.EstimatedMaxSize, cfg.Global.Cache.MaxAge, caching.EnableMetrics) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.EnableMetrics) - rsAPI := roomserver.NewInternalAPI(base) - - federation := conn.CreateFederationClient(base, pSessions) + federation := conn.CreateFederationClient(cfg, pSessions) serverKeyAPI := &signing.YggdrasilKeys{} keyRing := serverKeyAPI.KeyRing() - userAPI := userapi.NewInternalAPI(base, rsAPI, federation) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, federation) asQuery := appservice.NewInternalAPI( - base, userAPI, rsAPI, + processCtx, cfg, &natsInstance, userAPI, rsAPI, ) rsAPI.SetAppserviceAPI(asQuery) - fedSenderAPI := federationapi.NewInternalAPI(base, federation, rsAPI, base.Caches, keyRing, true) + fedSenderAPI := federationapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, federation, rsAPI, caches, keyRing, true) rsAPI.SetFederationAPI(fedSenderAPI, keyRing) monolith := setup.Monolith{ - Config: base.Cfg, - Client: conn.CreateClient(base, pSessions), + Config: cfg, + Client: conn.CreateClient(pSessions), FedClient: federation, KeyRing: keyRing, @@ -208,15 +212,15 @@ func startup() { //ServerKeyAPI: serverKeyAPI, ExtPublicRoomsProvider: rooms.NewPineconeRoomProvider(pRouter, pSessions, fedSenderAPI, federation), } - monolith.AddAllPublicRoutes(base) + monolith.AddAllPublicRoutes(processCtx, cfg, routers, cm, &natsInstance, caches, caching.EnableMetrics) httpRouter := mux.NewRouter().SkipClean(true).UseEncodedPath() - httpRouter.PathPrefix(httputil.PublicClientPathPrefix).Handler(base.PublicClientAPIMux) - httpRouter.PathPrefix(httputil.PublicMediaPathPrefix).Handler(base.PublicMediaAPIMux) + httpRouter.PathPrefix(httputil.PublicClientPathPrefix).Handler(routers.Client) + httpRouter.PathPrefix(httputil.PublicMediaPathPrefix).Handler(routers.Media) p2pRouter := pSessions.Protocol("matrix").HTTP().Mux() - p2pRouter.Handle(httputil.PublicFederationPathPrefix, base.PublicFederationAPIMux) - p2pRouter.Handle(httputil.PublicMediaPathPrefix, base.PublicMediaAPIMux) + p2pRouter.Handle(httputil.PublicFederationPathPrefix, routers.Federation) + p2pRouter.Handle(httputil.PublicMediaPathPrefix, routers.Media) // Expose the matrix APIs via fetch - for local traffic go func() { diff --git a/build/gobind-pinecone/monolith.go b/build/gobind-pinecone/monolith.go index 16797eec0..8718c71fd 100644 --- a/build/gobind-pinecone/monolith.go +++ b/build/gobind-pinecone/monolith.go @@ -30,8 +30,12 @@ import ( "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/relay" "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" "github.com/matrix-org/dendrite/federationapi/api" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/process" userapiAPI "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/pinecone/types" "github.com/sirupsen/logrus" @@ -137,9 +141,9 @@ func (m *DendriteMonolith) SetStaticPeer(uri string) { } } -func getServerKeyFromString(nodeID string) (gomatrixserverlib.ServerName, error) { - var nodeKey gomatrixserverlib.ServerName - if userID, err := gomatrixserverlib.NewUserID(nodeID, false); err == nil { +func getServerKeyFromString(nodeID string) (spec.ServerName, error) { + var nodeKey spec.ServerName + if userID, err := spec.NewUserID(nodeID, false); err == nil { hexKey, decodeErr := hex.DecodeString(string(userID.Domain())) if decodeErr != nil || len(hexKey) != ed25519.PublicKeySize { return "", fmt.Errorf("UserID domain is not a valid ed25519 public key: %v", userID.Domain()) @@ -151,7 +155,7 @@ func getServerKeyFromString(nodeID string) (gomatrixserverlib.ServerName, error) if decodeErr != nil || len(hexKey) != ed25519.PublicKeySize { return "", fmt.Errorf("Relay server uri is not a valid ed25519 public key: %v", nodeID) } else { - nodeKey = gomatrixserverlib.ServerName(nodeID) + nodeKey = spec.ServerName(nodeID) } } @@ -159,7 +163,7 @@ func getServerKeyFromString(nodeID string) (gomatrixserverlib.ServerName, error) } func (m *DendriteMonolith) SetRelayServers(nodeID string, uris string) { - relays := []gomatrixserverlib.ServerName{} + relays := []spec.ServerName{} for _, uri := range strings.Split(uris, ",") { uri = strings.TrimSpace(uri) if len(uri) == 0 { @@ -185,9 +189,9 @@ func (m *DendriteMonolith) SetRelayServers(nodeID string, uris string) { m.p2pMonolith.RelayRetriever.SetRelayServers(relays) } else { relay.UpdateNodeRelayServers( - gomatrixserverlib.ServerName(nodeKey), + spec.ServerName(nodeKey), relays, - m.p2pMonolith.BaseDendrite.Context(), + m.p2pMonolith.ProcessCtx.Context(), m.p2pMonolith.GetFederationAPI(), ) } @@ -212,9 +216,9 @@ func (m *DendriteMonolith) GetRelayServers(nodeID string) string { relaysString += string(relay) } } else { - request := api.P2PQueryRelayServersRequest{Server: gomatrixserverlib.ServerName(nodeKey)} + request := api.P2PQueryRelayServersRequest{Server: spec.ServerName(nodeKey)} response := api.P2PQueryRelayServersResponse{} - err := m.p2pMonolith.GetFederationAPI().P2PQueryRelayServers(m.p2pMonolith.BaseDendrite.Context(), &request, &response) + err := m.p2pMonolith.GetFederationAPI().P2PQueryRelayServers(m.p2pMonolith.ProcessCtx.Context(), &request, &response) if err != nil { logrus.Warnf("Failed obtaining list of this node's relay servers: %s", err.Error()) return "" @@ -288,7 +292,7 @@ func (m *DendriteMonolith) RegisterUser(localpart, password string) (string, err pubkey := m.p2pMonolith.Router.PublicKey() userID := userutil.MakeUserID( localpart, - gomatrixserverlib.ServerName(hex.EncodeToString(pubkey[:])), + spec.ServerName(hex.EncodeToString(pubkey[:])), ) userReq := &userapiAPI.PerformAccountCreationRequest{ AccountType: userapiAPI.AccountTypeUser, @@ -339,17 +343,21 @@ func (m *DendriteMonolith) Start() { prefix := hex.EncodeToString(pk) cfg := monolith.GenerateDefaultConfig(sk, m.StorageDirectory, m.CacheDirectory, prefix) - cfg.Global.ServerName = gomatrixserverlib.ServerName(hex.EncodeToString(pk)) + cfg.Global.ServerName = spec.ServerName(hex.EncodeToString(pk)) cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID) cfg.Global.JetStream.InMemory = false // NOTE : disabled for now since there is a 64 bit alignment panic on 32 bit systems // This isn't actually fixed: https://github.com/blevesearch/zapx/pull/147 cfg.SyncAPI.Fulltext.Enabled = false + processCtx := process.NewProcessContext() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + routers := httputil.NewRouters() + enableRelaying := false enableMetrics := false enableWebsockets := false - m.p2pMonolith.SetupDendrite(cfg, 65432, enableRelaying, enableMetrics, enableWebsockets) + m.p2pMonolith.SetupDendrite(processCtx, cfg, cm, routers, 65432, enableRelaying, enableMetrics, enableWebsockets) m.p2pMonolith.StartMonolith() } diff --git a/build/gobind-pinecone/monolith_test.go b/build/gobind-pinecone/monolith_test.go index 434e07ef2..7a7e36c7e 100644 --- a/build/gobind-pinecone/monolith_test.go +++ b/build/gobind-pinecone/monolith_test.go @@ -18,7 +18,7 @@ import ( "strings" "testing" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) func TestMonolithStarts(t *testing.T) { @@ -110,7 +110,7 @@ func TestParseServerKey(t *testing.T) { name string serverKey string expectedErr bool - expectedKey gomatrixserverlib.ServerName + expectedKey spec.ServerName }{ { name: "valid userid as key", diff --git a/build/gobind-yggdrasil/monolith.go b/build/gobind-yggdrasil/monolith.go index 32af611ae..720ce37eb 100644 --- a/build/gobind-yggdrasil/monolith.go +++ b/build/gobind-yggdrasil/monolith.go @@ -12,6 +12,7 @@ import ( "path/filepath" "time" + "github.com/getsentry/sentry-go" "github.com/gorilla/mux" "github.com/matrix-org/dendrite/appservice" "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" @@ -19,15 +20,20 @@ import ( "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/yggrooms" "github.com/matrix-org/dendrite/federationapi" "github.com/matrix-org/dendrite/federationapi/api" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/setup" - "github.com/matrix-org/dendrite/setup/base" + basepkg "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/userapi" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/sirupsen/logrus" _ "golang.org/x/mobile/bind" @@ -129,7 +135,7 @@ func (m *DendriteMonolith) Start() { Generate: true, SingleDatabase: true, }) - cfg.Global.ServerName = gomatrixserverlib.ServerName(hex.EncodeToString(pk)) + cfg.Global.ServerName = spec.ServerName(hex.EncodeToString(pk)) cfg.Global.PrivateKey = sk cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID) cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", m.StorageDirectory)) @@ -148,25 +154,71 @@ func (m *DendriteMonolith) Start() { panic(err) } - base := base.NewBaseDendrite(cfg) - base.ConfigureAdminEndpoints() - m.processContext = base.ProcessContext - defer base.Close() // nolint: errcheck + configErrors := &config.ConfigErrors{} + cfg.Verify(configErrors) + if len(*configErrors) > 0 { + for _, err := range *configErrors { + logrus.Errorf("Configuration error: %s", err) + } + logrus.Fatalf("Failed to start due to configuration errors") + } - federation := ygg.CreateFederationClient(base) + internal.SetupStdLogging() + internal.SetupHookLogging(cfg.Logging) + internal.SetupPprof() + + logrus.Infof("Dendrite version %s", internal.VersionString()) + + if !cfg.ClientAPI.RegistrationDisabled && cfg.ClientAPI.OpenRegistrationWithoutVerificationEnabled { + logrus.Warn("Open registration is enabled") + } + + closer, err := cfg.SetupTracing() + if err != nil { + logrus.WithError(err).Panicf("failed to start opentracing") + } + defer closer.Close() + + if cfg.Global.Sentry.Enabled { + logrus.Info("Setting up Sentry for debugging...") + err = sentry.Init(sentry.ClientOptions{ + Dsn: cfg.Global.Sentry.DSN, + Environment: cfg.Global.Sentry.Environment, + Debug: true, + ServerName: string(cfg.Global.ServerName), + Release: "dendrite@" + internal.VersionString(), + AttachStacktrace: true, + }) + if err != nil { + logrus.WithError(err).Panic("failed to start Sentry") + } + } + processCtx := process.NewProcessContext() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + routers := httputil.NewRouters() + basepkg.ConfigureAdminEndpoints(processCtx, routers) + m.processContext = processCtx + defer func() { + processCtx.ShutdownDendrite() + processCtx.WaitForShutdown() + }() // nolint: errcheck + + federation := ygg.CreateFederationClient(cfg) serverKeyAPI := &signing.YggdrasilKeys{} keyRing := serverKeyAPI.KeyRing() - rsAPI := roomserver.NewInternalAPI(base) + caches := caching.NewRistrettoCache(cfg.Global.Cache.EstimatedMaxSize, cfg.Global.Cache.MaxAge, caching.EnableMetrics) + natsInstance := jetstream.NATSInstance{} + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.EnableMetrics) fsAPI := federationapi.NewInternalAPI( - base, federation, rsAPI, base.Caches, keyRing, true, + processCtx, cfg, cm, &natsInstance, federation, rsAPI, caches, keyRing, true, ) - userAPI := userapi.NewInternalAPI(base, rsAPI, federation) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, federation) - asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI) + asAPI := appservice.NewInternalAPI(processCtx, cfg, &natsInstance, userAPI, rsAPI) rsAPI.SetAppserviceAPI(asAPI) // The underlying roomserver implementation needs to be able to call the fedsender. @@ -174,8 +226,8 @@ func (m *DendriteMonolith) Start() { rsAPI.SetFederationAPI(fsAPI, keyRing) monolith := setup.Monolith{ - Config: base.Cfg, - Client: ygg.CreateClient(base), + Config: cfg, + Client: ygg.CreateClient(), FedClient: federation, KeyRing: keyRing, @@ -187,17 +239,17 @@ func (m *DendriteMonolith) Start() { ygg, fsAPI, federation, ), } - monolith.AddAllPublicRoutes(base) + monolith.AddAllPublicRoutes(processCtx, cfg, routers, cm, &natsInstance, caches, caching.EnableMetrics) httpRouter := mux.NewRouter() - httpRouter.PathPrefix(httputil.PublicClientPathPrefix).Handler(base.PublicClientAPIMux) - httpRouter.PathPrefix(httputil.PublicMediaPathPrefix).Handler(base.PublicMediaAPIMux) - httpRouter.PathPrefix(httputil.DendriteAdminPathPrefix).Handler(base.DendriteAdminMux) - httpRouter.PathPrefix(httputil.SynapseAdminPathPrefix).Handler(base.SynapseAdminMux) + httpRouter.PathPrefix(httputil.PublicClientPathPrefix).Handler(routers.Client) + httpRouter.PathPrefix(httputil.PublicMediaPathPrefix).Handler(routers.Media) + httpRouter.PathPrefix(httputil.DendriteAdminPathPrefix).Handler(routers.DendriteAdmin) + httpRouter.PathPrefix(httputil.SynapseAdminPathPrefix).Handler(routers.SynapseAdmin) yggRouter := mux.NewRouter() - yggRouter.PathPrefix(httputil.PublicFederationPathPrefix).Handler(base.PublicFederationAPIMux) - yggRouter.PathPrefix(httputil.PublicMediaPathPrefix).Handler(base.PublicMediaAPIMux) + yggRouter.PathPrefix(httputil.PublicFederationPathPrefix).Handler(routers.Federation) + yggRouter.PathPrefix(httputil.PublicMediaPathPrefix).Handler(routers.Media) // Build both ends of a HTTP multiplex. m.httpServer = &http.Server{ diff --git a/build/scripts/Complement.Dockerfile b/build/scripts/Complement.Dockerfile index 70bbe8f95..453d89765 100644 --- a/build/scripts/Complement.Dockerfile +++ b/build/scripts/Complement.Dockerfile @@ -1,6 +1,6 @@ #syntax=docker/dockerfile:1.2 -FROM golang:1.18-stretch as build +FROM golang:1.20-bullseye as build RUN apt-get update && apt-get install -y sqlite3 WORKDIR /build @@ -17,7 +17,7 @@ RUN --mount=target=. \ CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-config && \ CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-keys && \ CGO_ENABLED=${CGO} go build -o /dendrite/dendrite ./cmd/dendrite && \ - CGO_ENABLED=${CGO} go test -c -cover -covermode=atomic -o /dendrite/dendrite-cover -coverpkg "github.com/matrix-org/..." ./cmd/dendrite && \ + CGO_ENABLED=${CGO} go build -cover -covermode=atomic -o /dendrite/dendrite-cover -coverpkg "github.com/matrix-org/..." ./cmd/dendrite && \ cp build/scripts/complement-cmd.sh /complement-cmd.sh WORKDIR /dendrite diff --git a/build/scripts/ComplementPostgres.Dockerfile b/build/scripts/ComplementPostgres.Dockerfile index d4b6d3f75..77071b450 100644 --- a/build/scripts/ComplementPostgres.Dockerfile +++ b/build/scripts/ComplementPostgres.Dockerfile @@ -1,19 +1,19 @@ #syntax=docker/dockerfile:1.2 -FROM golang:1.18-stretch as build +FROM golang:1.20-bullseye as build RUN apt-get update && apt-get install -y postgresql WORKDIR /build # No password when connecting over localhost -RUN sed -i "s%127.0.0.1/32 md5%127.0.0.1/32 trust%g" /etc/postgresql/9.6/main/pg_hba.conf && \ +RUN sed -i "s%127.0.0.1/32 md5%127.0.0.1/32 trust%g" /etc/postgresql/13/main/pg_hba.conf && \ # Bump up max conns for moar concurrency - sed -i 's/max_connections = 100/max_connections = 2000/g' /etc/postgresql/9.6/main/postgresql.conf + sed -i 's/max_connections = 100/max_connections = 2000/g' /etc/postgresql/13/main/postgresql.conf # This entry script starts postgres, waits for it to be up then starts dendrite RUN echo '\ #!/bin/bash -eu \n\ pg_lsclusters \n\ - pg_ctlcluster 9.6 main start \n\ + pg_ctlcluster 13 main start \n\ \n\ until pg_isready \n\ do \n\ @@ -35,7 +35,7 @@ RUN --mount=target=. \ CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-config && \ CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-keys && \ CGO_ENABLED=${CGO} go build -o /dendrite/dendrite ./cmd/dendrite && \ - CGO_ENABLED=${CGO} go test -c -cover -covermode=atomic -o /dendrite/dendrite-cover -coverpkg "github.com/matrix-org/..." ./cmd/dendrite && \ + CGO_ENABLED=${CGO} go build -cover -covermode=atomic -o /dendrite/dendrite-cover -coverpkg "github.com/matrix-org/..." ./cmd/dendrite && \ cp build/scripts/complement-cmd.sh /complement-cmd.sh WORKDIR /dendrite diff --git a/build/scripts/complement-cmd.sh b/build/scripts/complement-cmd.sh index 52b063d01..3f6ed070c 100755 --- a/build/scripts/complement-cmd.sh +++ b/build/scripts/complement-cmd.sh @@ -2,14 +2,15 @@ # This script is intended to be used inside a docker container for Complement +export GOCOVERDIR=/tmp/covdatafiles +mkdir -p "${GOCOVERDIR}" if [[ "${COVER}" -eq 1 ]]; then echo "Running with coverage" exec /dendrite/dendrite-cover \ --really-enable-open-registration \ --tls-cert server.crt \ --tls-key server.key \ - --config dendrite.yaml \ - --test.coverprofile=complementcover.log + --config dendrite.yaml else echo "Not running with coverage" exec /dendrite/dendrite \ diff --git a/clientapi/admin_test.go b/clientapi/admin_test.go index 300d3a88a..cf1c4ef30 100644 --- a/clientapi/admin_test.go +++ b/clientapi/admin_test.go @@ -4,15 +4,22 @@ import ( "context" "net/http" "net/http/httptest" + "reflect" "testing" + "time" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/federationapi" + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/roomserver/api" + basepkg "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/syncapi" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/tidwall/gjson" @@ -29,54 +36,30 @@ func TestAdminResetPassword(t *testing.T) { ctx := context.Background() test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - base, baseClose := testrig.CreateBaseDendrite(t, dbType) - defer baseClose() - + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + defer close() + natsInstance := jetstream.NATSInstance{} // add a vhost - base.Cfg.Global.VirtualHosts = append(base.Cfg.Global.VirtualHosts, &config.VirtualHost{ - SigningIdentity: gomatrixserverlib.SigningIdentity{ServerName: "vh1"}, + cfg.Global.VirtualHosts = append(cfg.Global.VirtualHosts, &config.VirtualHost{ + SigningIdentity: fclient.SigningIdentity{ServerName: "vh1"}, }) - rsAPI := roomserver.NewInternalAPI(base) + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) // Needed for changing the password/login - userAPI := userapi.NewInternalAPI(base, rsAPI, nil) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) // We mostly need the userAPI for this test, so nil for other APIs/caches etc. - AddPublicRoutes(base, nil, rsAPI, nil, nil, nil, userAPI, nil, nil) + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) // Create the users in the userapi and login - accessTokens := map[*test.User]string{ - aliceAdmin: "", - bob: "", - vhUser: "", - } - for u := range accessTokens { - localpart, serverName, _ := gomatrixserverlib.SplitID('@', u.ID) - userRes := &uapi.PerformAccountCreationResponse{} - password := util.RandomString(8) - if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{ - AccountType: u.AccountType, - Localpart: localpart, - ServerName: serverName, - Password: password, - }, userRes); err != nil { - t.Errorf("failed to create account: %s", err) - } - - req := test.NewRequest(t, http.MethodPost, "/_matrix/client/v3/login", test.WithJSONBody(t, map[string]interface{}{ - "type": authtypes.LoginTypePassword, - "identifier": map[string]interface{}{ - "type": "m.id.user", - "user": u.ID, - }, - "password": password, - })) - rec := httptest.NewRecorder() - base.PublicClientAPIMux.ServeHTTP(rec, req) - if rec.Code != http.StatusOK { - t.Fatalf("failed to login: %s", rec.Body.String()) - } - accessTokens[u] = gjson.GetBytes(rec.Body.Bytes(), "access_token").String() + accessTokens := map[*test.User]userDevice{ + aliceAdmin: {}, + bob: {}, + vhUser: {}, } + createAccessTokens(t, accessTokens, userAPI, ctx, routers) testCases := []struct { name string @@ -120,11 +103,11 @@ func TestAdminResetPassword(t *testing.T) { } if tc.withHeader { - req.Header.Set("Authorization", "Bearer "+accessTokens[tc.requestingUser]) + req.Header.Set("Authorization", "Bearer "+accessTokens[tc.requestingUser].accessToken) } rec := httptest.NewRecorder() - base.DendriteAdminMux.ServeHTTP(rec, req) + routers.DendriteAdmin.ServeHTTP(rec, req) t.Logf("%s", rec.Body.String()) if tc.wantOK && rec.Code != http.StatusOK { t.Fatalf("expected http status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String()) @@ -140,23 +123,26 @@ func TestPurgeRoom(t *testing.T) { room := test.NewRoom(t, aliceAdmin, test.RoomPreset(test.PresetTrustedPrivateChat)) // Invite Bob - room.CreateAndInsert(t, aliceAdmin, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, aliceAdmin, spec.MRoomMember, map[string]interface{}{ "membership": "invite", }, test.WithStateKey(bob.ID)) ctx := context.Background() test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - base, baseClose := testrig.CreateBaseDendrite(t, dbType) - defer baseClose() + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + natsInstance := jetstream.NATSInstance{} + defer close() - fedClient := base.CreateFederationClient() - rsAPI := roomserver.NewInternalAPI(base) - userAPI := userapi.NewInternalAPI(base, rsAPI, nil) + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) // this starts the JetStream consumers - syncapi.AddPublicRoutes(base, userAPI, rsAPI) - federationapi.NewInternalAPI(base, fedClient, rsAPI, base.Caches, nil, true) + syncapi.AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, userAPI, rsAPI, caches, caching.DisableMetrics) + federationapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, nil, rsAPI, caches, nil, true) rsAPI.SetFederationAPI(nil, nil) // Create the room @@ -165,40 +151,13 @@ func TestPurgeRoom(t *testing.T) { } // We mostly need the rsAPI for this test, so nil for other APIs/caches etc. - AddPublicRoutes(base, nil, rsAPI, nil, nil, nil, userAPI, nil, nil) + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) // Create the users in the userapi and login - accessTokens := map[*test.User]string{ - aliceAdmin: "", - } - for u := range accessTokens { - localpart, serverName, _ := gomatrixserverlib.SplitID('@', u.ID) - userRes := &uapi.PerformAccountCreationResponse{} - password := util.RandomString(8) - if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{ - AccountType: u.AccountType, - Localpart: localpart, - ServerName: serverName, - Password: password, - }, userRes); err != nil { - t.Errorf("failed to create account: %s", err) - } - - req := test.NewRequest(t, http.MethodPost, "/_matrix/client/v3/login", test.WithJSONBody(t, map[string]interface{}{ - "type": authtypes.LoginTypePassword, - "identifier": map[string]interface{}{ - "type": "m.id.user", - "user": u.ID, - }, - "password": password, - })) - rec := httptest.NewRecorder() - base.PublicClientAPIMux.ServeHTTP(rec, req) - if rec.Code != http.StatusOK { - t.Fatalf("failed to login: %s", rec.Body.String()) - } - accessTokens[u] = gjson.GetBytes(rec.Body.Bytes(), "access_token").String() + accessTokens := map[*test.User]userDevice{ + aliceAdmin: {}, } + createAccessTokens(t, accessTokens, userAPI, ctx, routers) testCases := []struct { name string @@ -215,10 +174,10 @@ func TestPurgeRoom(t *testing.T) { t.Run(tc.name, func(t *testing.T) { req := test.NewRequest(t, http.MethodPost, "/_dendrite/admin/purgeRoom/"+tc.roomID) - req.Header.Set("Authorization", "Bearer "+accessTokens[aliceAdmin]) + req.Header.Set("Authorization", "Bearer "+accessTokens[aliceAdmin].accessToken) rec := httptest.NewRecorder() - base.DendriteAdminMux.ServeHTTP(rec, req) + routers.DendriteAdmin.ServeHTTP(rec, req) t.Logf("%s", rec.Body.String()) if tc.wantOK && rec.Code != http.StatusOK { t.Fatalf("expected http status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String()) @@ -228,3 +187,256 @@ func TestPurgeRoom(t *testing.T) { }) } + +func TestAdminEvacuateRoom(t *testing.T) { + aliceAdmin := test.NewUser(t, test.WithAccountType(uapi.AccountTypeAdmin)) + bob := test.NewUser(t) + room := test.NewRoom(t, aliceAdmin) + + // Join Bob + room.CreateAndInsert(t, bob, spec.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + + ctx := context.Background() + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + natsInstance := jetstream.NATSInstance{} + defer close() + + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + + // this starts the JetStream consumers + fsAPI := federationapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, nil, rsAPI, caches, nil, true) + rsAPI.SetFederationAPI(fsAPI, nil) + + // Create the room + if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", api.DoNotSendToOtherServers, nil, false); err != nil { + t.Fatalf("failed to send events: %v", err) + } + + // We mostly need the rsAPI for this test, so nil for other APIs/caches etc. + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + + // Create the users in the userapi and login + accessTokens := map[*test.User]userDevice{ + aliceAdmin: {}, + } + createAccessTokens(t, accessTokens, userAPI, ctx, routers) + + testCases := []struct { + name string + roomID string + wantOK bool + wantAffected []string + }{ + {name: "Can evacuate existing room", wantOK: true, roomID: room.ID, wantAffected: []string{aliceAdmin.ID, bob.ID}}, + {name: "Can not evacuate non-existent room", wantOK: false, roomID: "!doesnotexist:localhost", wantAffected: []string{}}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req := test.NewRequest(t, http.MethodPost, "/_dendrite/admin/evacuateRoom/"+tc.roomID) + + req.Header.Set("Authorization", "Bearer "+accessTokens[aliceAdmin].accessToken) + + rec := httptest.NewRecorder() + routers.DendriteAdmin.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + if tc.wantOK && rec.Code != http.StatusOK { + t.Fatalf("expected http status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String()) + } + + affectedArr := gjson.GetBytes(rec.Body.Bytes(), "affected").Array() + affected := make([]string, 0, len(affectedArr)) + for _, x := range affectedArr { + affected = append(affected, x.Str) + } + if !reflect.DeepEqual(affected, tc.wantAffected) { + t.Fatalf("expected affected %#v, but got %#v", tc.wantAffected, affected) + } + }) + } + + // Wait for the FS API to have consumed every message + js, _ := natsInstance.Prepare(processCtx, &cfg.Global.JetStream) + timeout := time.After(time.Second) + for { + select { + case <-timeout: + t.Fatalf("FS API didn't process all events in time") + default: + } + info, err := js.ConsumerInfo(cfg.Global.JetStream.Prefixed(jetstream.OutputRoomEvent), cfg.Global.JetStream.Durable("FederationAPIRoomServerConsumer")+"Pull") + if err != nil { + time.Sleep(time.Millisecond * 10) + continue + } + if info.NumPending == 0 && info.NumAckPending == 0 { + break + } + } + }) +} + +func TestAdminEvacuateUser(t *testing.T) { + aliceAdmin := test.NewUser(t, test.WithAccountType(uapi.AccountTypeAdmin)) + bob := test.NewUser(t) + room := test.NewRoom(t, aliceAdmin) + room2 := test.NewRoom(t, aliceAdmin) + + // Join Bob + room.CreateAndInsert(t, bob, spec.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + room2.CreateAndInsert(t, bob, spec.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + + ctx := context.Background() + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + natsInstance := jetstream.NATSInstance{} + defer close() + + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + + // this starts the JetStream consumers + fsAPI := federationapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, basepkg.CreateFederationClient(cfg, nil), rsAPI, caches, nil, true) + rsAPI.SetFederationAPI(fsAPI, nil) + + // Create the room + if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", api.DoNotSendToOtherServers, nil, false); err != nil { + t.Fatalf("failed to send events: %v", err) + } + if err := api.SendEvents(ctx, rsAPI, api.KindNew, room2.Events(), "test", "test", api.DoNotSendToOtherServers, nil, false); err != nil { + t.Fatalf("failed to send events: %v", err) + } + + // We mostly need the rsAPI for this test, so nil for other APIs/caches etc. + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + + // Create the users in the userapi and login + accessTokens := map[*test.User]userDevice{ + aliceAdmin: {}, + } + createAccessTokens(t, accessTokens, userAPI, ctx, routers) + + testCases := []struct { + name string + userID string + wantOK bool + wantAffectedRooms []string + }{ + {name: "Can evacuate existing user", wantOK: true, userID: bob.ID, wantAffectedRooms: []string{room.ID, room2.ID}}, + {name: "invalid userID is rejected", wantOK: false, userID: "!notauserid:test", wantAffectedRooms: []string{}}, + {name: "Can not evacuate user from different server", wantOK: false, userID: "@doesnotexist:localhost", wantAffectedRooms: []string{}}, + {name: "Can not evacuate non-existent user", wantOK: false, userID: "@doesnotexist:test", wantAffectedRooms: []string{}}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req := test.NewRequest(t, http.MethodPost, "/_dendrite/admin/evacuateUser/"+tc.userID) + + req.Header.Set("Authorization", "Bearer "+accessTokens[aliceAdmin].accessToken) + + rec := httptest.NewRecorder() + routers.DendriteAdmin.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + if tc.wantOK && rec.Code != http.StatusOK { + t.Fatalf("expected http status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String()) + } + + affectedArr := gjson.GetBytes(rec.Body.Bytes(), "affected").Array() + affected := make([]string, 0, len(affectedArr)) + for _, x := range affectedArr { + affected = append(affected, x.Str) + } + if !reflect.DeepEqual(affected, tc.wantAffectedRooms) { + t.Fatalf("expected affected %#v, but got %#v", tc.wantAffectedRooms, affected) + } + + }) + } + // Wait for the FS API to have consumed every message + js, _ := natsInstance.Prepare(processCtx, &cfg.Global.JetStream) + timeout := time.After(time.Second) + for { + select { + case <-timeout: + t.Fatalf("FS API didn't process all events in time") + default: + } + info, err := js.ConsumerInfo(cfg.Global.JetStream.Prefixed(jetstream.OutputRoomEvent), cfg.Global.JetStream.Durable("FederationAPIRoomServerConsumer")+"Pull") + if err != nil { + time.Sleep(time.Millisecond * 10) + continue + } + if info.NumPending == 0 && info.NumAckPending == 0 { + break + } + } + }) +} + +func TestAdminMarkAsStale(t *testing.T) { + aliceAdmin := test.NewUser(t, test.WithAccountType(uapi.AccountTypeAdmin)) + + ctx := context.Background() + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + natsInstance := jetstream.NATSInstance{} + defer close() + + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + + // We mostly need the rsAPI for this test, so nil for other APIs/caches etc. + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + + // Create the users in the userapi and login + accessTokens := map[*test.User]userDevice{ + aliceAdmin: {}, + } + createAccessTokens(t, accessTokens, userAPI, ctx, routers) + + testCases := []struct { + name string + userID string + wantOK bool + }{ + {name: "local user is not allowed", userID: aliceAdmin.ID}, + {name: "invalid userID", userID: "!notvalid:test"}, + {name: "remote user is allowed", userID: "@alice:localhost", wantOK: true}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req := test.NewRequest(t, http.MethodPost, "/_dendrite/admin/refreshDevices/"+tc.userID) + + req.Header.Set("Authorization", "Bearer "+accessTokens[aliceAdmin].accessToken) + + rec := httptest.NewRecorder() + routers.DendriteAdmin.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + if tc.wantOK && rec.Code != http.StatusOK { + t.Fatalf("expected http status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String()) + } + }) + } + }) +} diff --git a/clientapi/api/api.go b/clientapi/api/api.go index d96b032f0..23974c865 100644 --- a/clientapi/api/api.go +++ b/clientapi/api/api.go @@ -14,10 +14,10 @@ package api -import "github.com/matrix-org/gomatrixserverlib" +import "github.com/matrix-org/gomatrixserverlib/fclient" // ExtraPublicRoomsProvider provides a way to inject extra published rooms into /publicRooms requests. type ExtraPublicRoomsProvider interface { // Rooms returns the extra rooms. This is called on-demand by clients, so cache appropriately. - Rooms() []gomatrixserverlib.PublicRoom + Rooms() []fclient.PublicRoom } diff --git a/clientapi/auth/authtypes/threepid.go b/clientapi/auth/authtypes/threepid.go index 60d77dc6f..ebafbe6aa 100644 --- a/clientapi/auth/authtypes/threepid.go +++ b/clientapi/auth/authtypes/threepid.go @@ -16,6 +16,8 @@ package authtypes // ThreePID represents a third-party identifier type ThreePID struct { - Address string `json:"address"` - Medium string `json:"medium"` + Address string `json:"address"` + Medium string `json:"medium"` + AddedAt int64 `json:"added_at"` + ValidatedAt int64 `json:"validated_at"` } diff --git a/clientapi/auth/login_test.go b/clientapi/auth/login_test.go index 044062c42..c91cba241 100644 --- a/clientapi/auth/login_test.go +++ b/clientapi/auth/login_test.go @@ -25,7 +25,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/setup/config" uapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/util" ) @@ -68,7 +68,7 @@ func TestLoginFromJSONReader(t *testing.T) { var userAPI fakeUserInternalAPI cfg := &config.ClientAPI{ Matrix: &config.Global{ - SigningIdentity: gomatrixserverlib.SigningIdentity{ + SigningIdentity: fclient.SigningIdentity{ ServerName: serverName, }, }, @@ -148,7 +148,7 @@ func TestBadLoginFromJSONReader(t *testing.T) { var userAPI fakeUserInternalAPI cfg := &config.ClientAPI{ Matrix: &config.Global{ - SigningIdentity: gomatrixserverlib.SigningIdentity{ + SigningIdentity: fclient.SigningIdentity{ ServerName: serverName, }, }, diff --git a/clientapi/auth/user_interactive_test.go b/clientapi/auth/user_interactive_test.go index 5d97b31ce..4003e9647 100644 --- a/clientapi/auth/user_interactive_test.go +++ b/clientapi/auth/user_interactive_test.go @@ -8,13 +8,14 @@ import ( "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) var ( ctx = context.Background() - serverName = gomatrixserverlib.ServerName("example.com") + serverName = spec.ServerName("example.com") // space separated localpart+password -> account lookup = make(map[string]*api.Account) device = &api.Device{ @@ -47,7 +48,7 @@ func (d *fakeAccountDatabase) QueryAccountByPassword(ctx context.Context, req *a func setup() *UserInteractive { cfg := &config.ClientAPI{ Matrix: &config.Global{ - SigningIdentity: gomatrixserverlib.SigningIdentity{ + SigningIdentity: fclient.SigningIdentity{ ServerName: serverName, }, }, diff --git a/clientapi/clientapi.go b/clientapi/clientapi.go index e9985d43f..b57c8061e 100644 --- a/clientapi/clientapi.go +++ b/clientapi/clientapi.go @@ -15,8 +15,11 @@ package clientapi import ( + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/process" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" appserviceAPI "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/clientapi/api" @@ -25,41 +28,41 @@ import ( federationAPI "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal/transactions" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/jetstream" ) // AddPublicRoutes sets up and registers HTTP handlers for the ClientAPI component. func AddPublicRoutes( - base *base.BaseDendrite, - federation *gomatrixserverlib.FederationClient, + processContext *process.ProcessContext, + routers httputil.Routers, + cfg *config.Dendrite, + natsInstance *jetstream.NATSInstance, + federation *fclient.FederationClient, rsAPI roomserverAPI.ClientRoomserverAPI, asAPI appserviceAPI.AppServiceInternalAPI, transactionsCache *transactions.Cache, fsAPI federationAPI.ClientFederationAPI, userAPI userapi.ClientUserAPI, userDirectoryProvider userapi.QuerySearchProfilesAPI, - extRoomsProvider api.ExtraPublicRoomsProvider, + extRoomsProvider api.ExtraPublicRoomsProvider, enableMetrics bool, ) { - cfg := &base.Cfg.ClientAPI - mscCfg := &base.Cfg.MSCs - js, natsClient := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) + js, natsClient := natsInstance.Prepare(processContext, &cfg.Global.JetStream) syncProducer := &producers.SyncAPIProducer{ JetStream: js, - TopicReceiptEvent: cfg.Matrix.JetStream.Prefixed(jetstream.OutputReceiptEvent), - TopicSendToDeviceEvent: cfg.Matrix.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent), - TopicTypingEvent: cfg.Matrix.JetStream.Prefixed(jetstream.OutputTypingEvent), - TopicPresenceEvent: cfg.Matrix.JetStream.Prefixed(jetstream.OutputPresenceEvent), + TopicReceiptEvent: cfg.Global.JetStream.Prefixed(jetstream.OutputReceiptEvent), + TopicSendToDeviceEvent: cfg.Global.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent), + TopicTypingEvent: cfg.Global.JetStream.Prefixed(jetstream.OutputTypingEvent), + TopicPresenceEvent: cfg.Global.JetStream.Prefixed(jetstream.OutputPresenceEvent), UserAPI: userAPI, - ServerName: cfg.Matrix.ServerName, + ServerName: cfg.Global.ServerName, } routing.Setup( - base, + routers, cfg, rsAPI, asAPI, userAPI, userDirectoryProvider, federation, syncProducer, transactionsCache, fsAPI, - extRoomsProvider, mscCfg, natsClient, + extRoomsProvider, natsClient, enableMetrics, ) } diff --git a/clientapi/clientapi_test.go b/clientapi/clientapi_test.go new file mode 100644 index 000000000..0be262735 --- /dev/null +++ b/clientapi/clientapi_test.go @@ -0,0 +1,1632 @@ +package clientapi + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + + "github.com/matrix-org/dendrite/internal/pushrules" + "github.com/matrix-org/gomatrix" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/stretchr/testify/assert" + "github.com/tidwall/gjson" + + "github.com/matrix-org/dendrite/appservice" + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/clientapi/routing" + "github.com/matrix-org/dendrite/clientapi/threepid" + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/version" + "github.com/matrix-org/dendrite/setup/base" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + "github.com/matrix-org/dendrite/userapi" + uapi "github.com/matrix-org/dendrite/userapi/api" +) + +type userDevice struct { + accessToken string + deviceID string + password string +} + +func TestGetPutDevices(t *testing.T) { + alice := test.NewUser(t) + bob := test.NewUser(t) + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + testCases := []struct { + name string + requestUser *test.User + deviceUser *test.User + request *http.Request + wantStatusCode int + validateFunc func(t *testing.T, device userDevice, routers httputil.Routers) + }{ + { + name: "can get all devices", + requestUser: alice, + request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/devices", strings.NewReader("")), + wantStatusCode: http.StatusOK, + }, + { + name: "can get specific own device", + requestUser: alice, + deviceUser: alice, + request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/devices/", strings.NewReader("")), + wantStatusCode: http.StatusOK, + }, + { + name: "can not get device for different user", + requestUser: alice, + deviceUser: bob, + request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/devices/", strings.NewReader("")), + wantStatusCode: http.StatusNotFound, + }, + { + name: "can update own device", + requestUser: alice, + deviceUser: alice, + request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/devices/", strings.NewReader(`{"display_name":"my new displayname"}`)), + wantStatusCode: http.StatusOK, + validateFunc: func(t *testing.T, device userDevice, routers httputil.Routers) { + req := httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/devices/"+device.deviceID, strings.NewReader("")) + req.Header.Set("Authorization", "Bearer "+device.accessToken) + rec := httptest.NewRecorder() + routers.Client.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("expected HTTP 200, got %d: %s", rec.Code, rec.Body.String()) + } + gotDisplayName := gjson.GetBytes(rec.Body.Bytes(), "display_name").Str + if gotDisplayName != "my new displayname" { + t.Fatalf("expected displayname '%s', got '%s'", "my new displayname", gotDisplayName) + } + }, + }, + { + // this should return "device does not exist" + name: "can not update device for different user", + requestUser: alice, + deviceUser: bob, + request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/devices/", strings.NewReader(`{"display_name":"my new displayname"}`)), + wantStatusCode: http.StatusNotFound, + }, + } + + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + natsInstance := jetstream.NATSInstance{} + defer close() + + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + + // We mostly need the rsAPI for this test, so nil for other APIs/caches etc. + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + + accessTokens := map[*test.User]userDevice{ + alice: {}, + bob: {}, + } + createAccessTokens(t, accessTokens, userAPI, processCtx.Context(), routers) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + dev := accessTokens[tc.requestUser] + if tc.deviceUser != nil { + tc.request = httptest.NewRequest(tc.request.Method, tc.request.RequestURI+accessTokens[tc.deviceUser].deviceID, tc.request.Body) + } + tc.request.Header.Set("Authorization", "Bearer "+dev.accessToken) + rec := httptest.NewRecorder() + routers.Client.ServeHTTP(rec, tc.request) + if rec.Code != tc.wantStatusCode { + t.Fatalf("expected HTTP 200, got %d: %s", rec.Code, rec.Body.String()) + } + if tc.wantStatusCode != http.StatusOK && rec.Code != http.StatusOK { + return + } + if tc.validateFunc != nil { + tc.validateFunc(t, dev, routers) + } + }) + } + }) +} + +// Deleting devices requires the UIA dance, so do this in a different test +func TestDeleteDevice(t *testing.T) { + alice := test.NewUser(t) + localpart, serverName, _ := gomatrixserverlib.SplitID('@', alice.ID) + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, closeDB := testrig.CreateConfig(t, dbType) + defer closeDB() + + natsInstance := jetstream.NATSInstance{} + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + + // We mostly need the rsAPI/ for this test, so nil for other APIs/caches etc. + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + + accessTokens := map[*test.User]userDevice{ + alice: {}, + } + + // create the account and an initial device + createAccessTokens(t, accessTokens, userAPI, processCtx.Context(), routers) + + // create some more devices + accessToken := util.RandomString(8) + devRes := &uapi.PerformDeviceCreationResponse{} + if err := userAPI.PerformDeviceCreation(processCtx.Context(), &uapi.PerformDeviceCreationRequest{ + Localpart: localpart, + ServerName: serverName, + AccessToken: accessToken, + NoDeviceListUpdate: true, + }, devRes); err != nil { + t.Fatal(err) + } + if !devRes.DeviceCreated { + t.Fatalf("failed to create device") + } + secondDeviceID := devRes.Device.ID + + // initiate UIA for the second device + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/devices/"+secondDeviceID, strings.NewReader("")) + req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken) + routers.Client.ServeHTTP(rec, req) + if rec.Code != http.StatusUnauthorized { + t.Fatalf("expected HTTP 401, got %d: %s", rec.Code, rec.Body.String()) + } + // get the session ID + sessionID := gjson.GetBytes(rec.Body.Bytes(), "session").Str + + // prepare UIA request body + reqBody := bytes.Buffer{} + if err := json.NewEncoder(&reqBody).Encode(map[string]interface{}{ + "auth": map[string]string{ + "session": sessionID, + "type": authtypes.LoginTypePassword, + "user": alice.ID, + "password": accessTokens[alice].password, + }, + }); err != nil { + t.Fatal(err) + } + + // copy the request body, so we can use it again for the successful delete + reqBody2 := reqBody + + // do the same request again, this time with our UIA, but for a different device ID, this should fail + rec = httptest.NewRecorder() + + req = httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/devices/"+accessTokens[alice].deviceID, &reqBody) + req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken) + routers.Client.ServeHTTP(rec, req) + if rec.Code != http.StatusForbidden { + t.Fatalf("expected HTTP 403, got %d: %s", rec.Code, rec.Body.String()) + } + + // do the same request again, this time with our UIA, but for the correct device ID, this should be fine + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/devices/"+secondDeviceID, &reqBody2) + req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken) + routers.Client.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("expected HTTP 200, got %d: %s", rec.Code, rec.Body.String()) + } + + // verify devices are deleted + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/devices", strings.NewReader("")) + req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken) + routers.Client.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("expected HTTP 200, got %d: %s", rec.Code, rec.Body.String()) + } + for _, device := range gjson.GetBytes(rec.Body.Bytes(), "devices.#.device_id").Array() { + if device.Str == secondDeviceID { + t.Fatalf("expected device %s to be deleted, but wasn't", secondDeviceID) + } + } + }) +} + +// Deleting devices requires the UIA dance, so do this in a different test +func TestDeleteDevices(t *testing.T) { + alice := test.NewUser(t) + localpart, serverName, _ := gomatrixserverlib.SplitID('@', alice.ID) + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, closeDB := testrig.CreateConfig(t, dbType) + defer closeDB() + + natsInstance := jetstream.NATSInstance{} + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + + // We mostly need the rsAPI/ for this test, so nil for other APIs/caches etc. + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + + accessTokens := map[*test.User]userDevice{ + alice: {}, + } + + // create the account and an initial device + createAccessTokens(t, accessTokens, userAPI, processCtx.Context(), routers) + + // create some more devices + var devices []string + for i := 0; i < 10; i++ { + accessToken := util.RandomString(8) + devRes := &uapi.PerformDeviceCreationResponse{} + if err := userAPI.PerformDeviceCreation(processCtx.Context(), &uapi.PerformDeviceCreationRequest{ + Localpart: localpart, + ServerName: serverName, + AccessToken: accessToken, + NoDeviceListUpdate: true, + }, devRes); err != nil { + t.Fatal(err) + } + if !devRes.DeviceCreated { + t.Fatalf("failed to create device") + } + devices = append(devices, devRes.Device.ID) + } + + // initiate UIA + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/_matrix/client/v3/delete_devices", strings.NewReader("")) + req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken) + routers.Client.ServeHTTP(rec, req) + if rec.Code != http.StatusUnauthorized { + t.Fatalf("expected HTTP 401, got %d: %s", rec.Code, rec.Body.String()) + } + // get the session ID + sessionID := gjson.GetBytes(rec.Body.Bytes(), "session").Str + + // prepare UIA request body + reqBody := bytes.Buffer{} + if err := json.NewEncoder(&reqBody).Encode(map[string]interface{}{ + "auth": map[string]string{ + "session": sessionID, + "type": authtypes.LoginTypePassword, + "user": alice.ID, + "password": accessTokens[alice].password, + }, + "devices": devices[5:], + }); err != nil { + t.Fatal(err) + } + + // do the same request again, this time with our UIA, + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodPost, "/_matrix/client/v3/delete_devices", &reqBody) + req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken) + routers.Client.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("expected HTTP 200, got %d: %s", rec.Code, rec.Body.String()) + } + + // verify devices are deleted + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/devices", strings.NewReader("")) + req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken) + routers.Client.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("expected HTTP 200, got %d: %s", rec.Code, rec.Body.String()) + } + for _, device := range gjson.GetBytes(rec.Body.Bytes(), "devices.#.device_id").Array() { + for _, deletedDevice := range devices[5:] { + if device.Str == deletedDevice { + t.Fatalf("expected device %s to be deleted, but wasn't", deletedDevice) + } + } + } + }) +} + +func createAccessTokens(t *testing.T, accessTokens map[*test.User]userDevice, userAPI uapi.UserInternalAPI, ctx context.Context, routers httputil.Routers) { + t.Helper() + for u := range accessTokens { + localpart, serverName, _ := gomatrixserverlib.SplitID('@', u.ID) + userRes := &uapi.PerformAccountCreationResponse{} + password := util.RandomString(8) + if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{ + AccountType: u.AccountType, + Localpart: localpart, + ServerName: serverName, + Password: password, + }, userRes); err != nil { + t.Errorf("failed to create account: %s", err) + } + req := test.NewRequest(t, http.MethodPost, "/_matrix/client/v3/login", test.WithJSONBody(t, map[string]interface{}{ + "type": authtypes.LoginTypePassword, + "identifier": map[string]interface{}{ + "type": "m.id.user", + "user": u.ID, + }, + "password": password, + })) + rec := httptest.NewRecorder() + routers.Client.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("failed to login: %s", rec.Body.String()) + } + accessTokens[u] = userDevice{ + accessToken: gjson.GetBytes(rec.Body.Bytes(), "access_token").String(), + deviceID: gjson.GetBytes(rec.Body.Bytes(), "device_id").String(), + password: password, + } + } +} + +func TestSetDisplayname(t *testing.T) { + alice := test.NewUser(t) + bob := test.NewUser(t) + notLocalUser := &test.User{ID: "@charlie:localhost", Localpart: "charlie"} + changeDisplayName := "my new display name" + + testCases := []struct { + name string + user *test.User + wantOK bool + changeReq io.Reader + wantDisplayName string + }{ + { + name: "invalid user", + user: &test.User{ID: "!notauser"}, + }, + { + name: "non-existent user", + user: &test.User{ID: "@doesnotexist:test"}, + }, + { + name: "non-local user is not allowed", + user: notLocalUser, + }, + { + name: "existing user is allowed to change own name", + user: alice, + wantOK: true, + wantDisplayName: changeDisplayName, + }, + { + name: "existing user is not allowed to change own name if name is empty", + user: bob, + wantOK: false, + wantDisplayName: "", + }, + } + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, closeDB := testrig.CreateConfig(t, dbType) + defer closeDB() + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + natsInstance := &jetstream.NATSInstance{} + + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, natsInstance, caches, caching.DisableMetrics) + rsAPI.SetFederationAPI(nil, nil) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, natsInstance, rsAPI, nil) + asPI := appservice.NewInternalAPI(processCtx, cfg, natsInstance, userAPI, rsAPI) + + AddPublicRoutes(processCtx, routers, cfg, natsInstance, base.CreateFederationClient(cfg, nil), rsAPI, asPI, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + + accessTokens := map[*test.User]userDevice{ + alice: {}, + bob: {}, + } + + createAccessTokens(t, accessTokens, userAPI, processCtx.Context(), routers) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + wantDisplayName := tc.user.Localpart + if tc.changeReq == nil { + tc.changeReq = strings.NewReader("") + } + + // check profile after initial account creation + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/profile/"+tc.user.ID, strings.NewReader("")) + t.Logf("%s", req.URL.String()) + routers.Client.ServeHTTP(rec, req) + + if tc.wantOK && rec.Code != http.StatusOK { + t.Fatalf("expected HTTP 200, got %d", rec.Code) + } + + if gotDisplayName := gjson.GetBytes(rec.Body.Bytes(), "displayname").Str; tc.wantOK && gotDisplayName != wantDisplayName { + t.Fatalf("expected displayname to be '%s', but got '%s'", wantDisplayName, gotDisplayName) + } + + // now set the new display name + wantDisplayName = tc.wantDisplayName + tc.changeReq = strings.NewReader(fmt.Sprintf(`{"displayname":"%s"}`, tc.wantDisplayName)) + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/profile/"+tc.user.ID+"/displayname", tc.changeReq) + req.Header.Set("Authorization", "Bearer "+accessTokens[tc.user].accessToken) + + routers.Client.ServeHTTP(rec, req) + if tc.wantOK && rec.Code != http.StatusOK { + t.Fatalf("expected HTTP 200, got %d: %s", rec.Code, rec.Body.String()) + } + + // now only get the display name + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/profile/"+tc.user.ID+"/displayname", strings.NewReader("")) + + routers.Client.ServeHTTP(rec, req) + if tc.wantOK && rec.Code != http.StatusOK { + t.Fatalf("expected HTTP 200, got %d: %s", rec.Code, rec.Body.String()) + } + + if gotDisplayName := gjson.GetBytes(rec.Body.Bytes(), "displayname").Str; tc.wantOK && gotDisplayName != wantDisplayName { + t.Fatalf("expected displayname to be '%s', but got '%s'", wantDisplayName, gotDisplayName) + } + }) + } + }) +} + +func TestSetAvatarURL(t *testing.T) { + alice := test.NewUser(t) + bob := test.NewUser(t) + notLocalUser := &test.User{ID: "@charlie:localhost", Localpart: "charlie"} + changeDisplayName := "mxc://newMXID" + + testCases := []struct { + name string + user *test.User + wantOK bool + changeReq io.Reader + avatar_url string + }{ + { + name: "invalid user", + user: &test.User{ID: "!notauser"}, + }, + { + name: "non-existent user", + user: &test.User{ID: "@doesnotexist:test"}, + }, + { + name: "non-local user is not allowed", + user: notLocalUser, + }, + { + name: "existing user is allowed to change own avatar", + user: alice, + wantOK: true, + avatar_url: changeDisplayName, + }, + { + name: "existing user is not allowed to change own avatar if avatar is empty", + user: bob, + wantOK: false, + avatar_url: "", + }, + } + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, closeDB := testrig.CreateConfig(t, dbType) + defer closeDB() + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + natsInstance := &jetstream.NATSInstance{} + + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, natsInstance, caches, caching.DisableMetrics) + rsAPI.SetFederationAPI(nil, nil) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, natsInstance, rsAPI, nil) + asPI := appservice.NewInternalAPI(processCtx, cfg, natsInstance, userAPI, rsAPI) + + AddPublicRoutes(processCtx, routers, cfg, natsInstance, base.CreateFederationClient(cfg, nil), rsAPI, asPI, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + + accessTokens := map[*test.User]userDevice{ + alice: {}, + bob: {}, + } + + createAccessTokens(t, accessTokens, userAPI, processCtx.Context(), routers) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + wantAvatarURL := "" + if tc.changeReq == nil { + tc.changeReq = strings.NewReader("") + } + + // check profile after initial account creation + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/profile/"+tc.user.ID, strings.NewReader("")) + t.Logf("%s", req.URL.String()) + routers.Client.ServeHTTP(rec, req) + + if tc.wantOK && rec.Code != http.StatusOK { + t.Fatalf("expected HTTP 200, got %d", rec.Code) + } + + if gotDisplayName := gjson.GetBytes(rec.Body.Bytes(), "avatar_url").Str; tc.wantOK && gotDisplayName != wantAvatarURL { + t.Fatalf("expected displayname to be '%s', but got '%s'", wantAvatarURL, gotDisplayName) + } + + // now set the new display name + wantAvatarURL = tc.avatar_url + tc.changeReq = strings.NewReader(fmt.Sprintf(`{"avatar_url":"%s"}`, tc.avatar_url)) + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/profile/"+tc.user.ID+"/avatar_url", tc.changeReq) + req.Header.Set("Authorization", "Bearer "+accessTokens[tc.user].accessToken) + + routers.Client.ServeHTTP(rec, req) + if tc.wantOK && rec.Code != http.StatusOK { + t.Fatalf("expected HTTP 200, got %d: %s", rec.Code, rec.Body.String()) + } + + // now only get the display name + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/profile/"+tc.user.ID+"/avatar_url", strings.NewReader("")) + + routers.Client.ServeHTTP(rec, req) + if tc.wantOK && rec.Code != http.StatusOK { + t.Fatalf("expected HTTP 200, got %d: %s", rec.Code, rec.Body.String()) + } + + if gotDisplayName := gjson.GetBytes(rec.Body.Bytes(), "avatar_url").Str; tc.wantOK && gotDisplayName != wantAvatarURL { + t.Fatalf("expected displayname to be '%s', but got '%s'", wantAvatarURL, gotDisplayName) + } + }) + } + }) +} + +func TestTyping(t *testing.T) { + alice := test.NewUser(t) + room := test.NewRoom(t, alice) + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + defer close() + natsInstance := jetstream.NATSInstance{} + + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) + rsAPI.SetFederationAPI(nil, nil) + // Needed to create accounts + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + // We mostly need the rsAPI/userAPI for this test, so nil for other APIs etc. + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + + // Create the users in the userapi and login + accessTokens := map[*test.User]userDevice{ + alice: {}, + } + createAccessTokens(t, accessTokens, userAPI, ctx, routers) + + // Create the room + if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { + t.Fatal(err) + } + + testCases := []struct { + name string + typingForUser string + roomID string + requestBody io.Reader + wantOK bool + }{ + { + name: "can not set typing for different user", + typingForUser: "@notourself:test", + roomID: room.ID, + requestBody: strings.NewReader(""), + }, + { + name: "invalid request body", + typingForUser: alice.ID, + roomID: room.ID, + requestBody: strings.NewReader(""), + }, + { + name: "non-existent room", + typingForUser: alice.ID, + roomID: "!doesnotexist:test", + }, + { + name: "invalid room ID", + typingForUser: alice.ID, + roomID: "@notaroomid:test", + }, + { + name: "allowed to set own typing status", + typingForUser: alice.ID, + roomID: room.ID, + requestBody: strings.NewReader(`{"typing":true}`), + wantOK: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/rooms/"+tc.roomID+"/typing/"+tc.typingForUser, tc.requestBody) + req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken) + routers.Client.ServeHTTP(rec, req) + if tc.wantOK && rec.Code != http.StatusOK { + t.Fatalf("expected HTTP 200, got %d: %s", rec.Code, rec.Body.String()) + } + }) + } + }) +} + +func TestMembership(t *testing.T) { + alice := test.NewUser(t) + bob := test.NewUser(t) + room := test.NewRoom(t, alice) + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + cfg.ClientAPI.RateLimiting.Enabled = false + defer close() + natsInstance := jetstream.NATSInstance{} + + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) + rsAPI.SetFederationAPI(nil, nil) + // Needed to create accounts + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + rsAPI.SetUserAPI(userAPI) + // We mostly need the rsAPI/userAPI for this test, so nil for other APIs etc. + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + + // Create the users in the userapi and login + accessTokens := map[*test.User]userDevice{ + alice: {}, + bob: {}, + } + createAccessTokens(t, accessTokens, userAPI, ctx, routers) + + // Create the room + if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { + t.Fatal(err) + } + + invalidBodyRequest := func(roomID, membershipType string) *http.Request { + return httptest.NewRequest(http.MethodPost, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s", roomID, membershipType), strings.NewReader("")) + } + + missingUserIDRequest := func(roomID, membershipType string) *http.Request { + return httptest.NewRequest(http.MethodPost, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s", roomID, membershipType), strings.NewReader("{}")) + } + + testCases := []struct { + name string + roomID string + request *http.Request + wantOK bool + asUser *test.User + }{ + { + name: "ban - invalid request body", + request: invalidBodyRequest(room.ID, "ban"), + }, + { + name: "kick - invalid request body", + request: invalidBodyRequest(room.ID, "kick"), + }, + { + name: "unban - invalid request body", + request: invalidBodyRequest(room.ID, "unban"), + }, + { + name: "invite - invalid request body", + request: invalidBodyRequest(room.ID, "invite"), + }, + { + name: "ban - missing user_id body", + request: missingUserIDRequest(room.ID, "ban"), + }, + { + name: "kick - missing user_id body", + request: missingUserIDRequest(room.ID, "kick"), + }, + { + name: "unban - missing user_id body", + request: missingUserIDRequest(room.ID, "unban"), + }, + { + name: "invite - missing user_id body", + request: missingUserIDRequest(room.ID, "invite"), + }, + { + name: "Bob forgets invalid room", + request: httptest.NewRequest(http.MethodPost, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s", "!doesnotexist", "forget"), strings.NewReader("")), + asUser: bob, + }, + { + name: "Alice can not ban Bob in non-existent room", // fails because "not joined" + request: httptest.NewRequest(http.MethodPost, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s", "!doesnotexist:test", "ban"), strings.NewReader(fmt.Sprintf(`{"user_id":"%s"}`, bob.ID))), + }, + { + name: "Alice can not kick Bob in non-existent room", // fails because "not joined" + request: httptest.NewRequest(http.MethodPost, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s", "!doesnotexist:test", "kick"), strings.NewReader(fmt.Sprintf(`{"user_id":"%s"}`, bob.ID))), + }, + // the following must run in sequence, as they build up on each other + { + name: "Alice invites Bob", + request: httptest.NewRequest(http.MethodPost, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s", room.ID, "invite"), strings.NewReader(fmt.Sprintf(`{"user_id":"%s"}`, bob.ID))), + wantOK: true, + }, + { + name: "Bob accepts invite", + request: httptest.NewRequest(http.MethodPost, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s", room.ID, "join"), strings.NewReader("")), + wantOK: true, + asUser: bob, + }, + { + name: "Alice verifies that Bob is joined", // returns an error if no membership event can be found + request: httptest.NewRequest(http.MethodGet, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s/m.room.member/%s", room.ID, "state", bob.ID), strings.NewReader("")), + wantOK: true, + }, + { + name: "Bob forgets the room but is still a member", + request: httptest.NewRequest(http.MethodPost, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s", room.ID, "forget"), strings.NewReader("")), + wantOK: false, // user is still in the room + asUser: bob, + }, + { + name: "Bob can not kick Alice", + request: httptest.NewRequest(http.MethodPost, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s", room.ID, "kick"), strings.NewReader(fmt.Sprintf(`{"user_id":"%s"}`, alice.ID))), + wantOK: false, // powerlevel too low + asUser: bob, + }, + { + name: "Bob can not ban Alice", + request: httptest.NewRequest(http.MethodPost, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s", room.ID, "ban"), strings.NewReader(fmt.Sprintf(`{"user_id":"%s"}`, alice.ID))), + wantOK: false, // powerlevel too low + asUser: bob, + }, + { + name: "Alice can kick Bob", + request: httptest.NewRequest(http.MethodPost, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s", room.ID, "kick"), strings.NewReader(fmt.Sprintf(`{"user_id":"%s"}`, bob.ID))), + wantOK: true, + }, + { + name: "Alice can ban Bob", + request: httptest.NewRequest(http.MethodPost, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s", room.ID, "ban"), strings.NewReader(fmt.Sprintf(`{"user_id":"%s"}`, bob.ID))), + wantOK: true, + }, + { + name: "Alice can not kick Bob again", + request: httptest.NewRequest(http.MethodPost, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s", room.ID, "kick"), strings.NewReader(fmt.Sprintf(`{"user_id":"%s"}`, bob.ID))), + wantOK: false, // can not kick banned/left user + }, + { + name: "Bob can not unban himself", // mostly because of not being a member of the room + request: httptest.NewRequest(http.MethodPost, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s", room.ID, "unban"), strings.NewReader(fmt.Sprintf(`{"user_id":"%s"}`, bob.ID))), + asUser: bob, + }, + { + name: "Alice can not invite Bob again", + request: httptest.NewRequest(http.MethodPost, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s", room.ID, "invite"), strings.NewReader(fmt.Sprintf(`{"user_id":"%s"}`, bob.ID))), + wantOK: false, // user still banned + }, + { + name: "Alice can unban Bob", + request: httptest.NewRequest(http.MethodPost, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s", room.ID, "unban"), strings.NewReader(fmt.Sprintf(`{"user_id":"%s"}`, bob.ID))), + wantOK: true, + }, + { + name: "Alice can not unban Bob again", + request: httptest.NewRequest(http.MethodPost, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s", room.ID, "unban"), strings.NewReader(fmt.Sprintf(`{"user_id":"%s"}`, bob.ID))), + wantOK: false, + }, + { + name: "Alice can invite Bob again", + request: httptest.NewRequest(http.MethodPost, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s", room.ID, "invite"), strings.NewReader(fmt.Sprintf(`{"user_id":"%s"}`, bob.ID))), + wantOK: true, + }, + { + name: "Bob can reject the invite by leaving", + request: httptest.NewRequest(http.MethodPost, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s", room.ID, "leave"), strings.NewReader("")), + wantOK: true, + asUser: bob, + }, + { + name: "Bob can forget the room", + request: httptest.NewRequest(http.MethodPost, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s", room.ID, "forget"), strings.NewReader("")), + wantOK: true, + asUser: bob, + }, + { + name: "Bob can forget the room again", + request: httptest.NewRequest(http.MethodPost, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s", room.ID, "forget"), strings.NewReader("")), + wantOK: true, + asUser: bob, + }, + // END must run in sequence + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if tc.asUser == nil { + tc.asUser = alice + } + rec := httptest.NewRecorder() + tc.request.Header.Set("Authorization", "Bearer "+accessTokens[tc.asUser].accessToken) + routers.Client.ServeHTTP(rec, tc.request) + if tc.wantOK && rec.Code != http.StatusOK { + t.Fatalf("expected HTTP 200, got %d: %s", rec.Code, rec.Body.String()) + } + if !tc.wantOK && rec.Code == http.StatusOK { + t.Fatalf("expected request to fail, but didn't: %s", rec.Body.String()) + } + t.Logf("%s", rec.Body.String()) + }) + } + }) +} + +func TestCapabilities(t *testing.T) { + alice := test.NewUser(t) + ctx := context.Background() + + // construct the expected result + versionsMap := map[gomatrixserverlib.RoomVersion]string{} + for v, desc := range version.SupportedRoomVersions() { + if desc.Stable { + versionsMap[v] = "stable" + } else { + versionsMap[v] = "unstable" + } + } + + expectedMap := map[string]interface{}{ + "capabilities": map[string]interface{}{ + "m.change_password": map[string]bool{ + "enabled": true, + }, + "m.room_versions": map[string]interface{}{ + "default": version.DefaultRoomVersion(), + "available": versionsMap, + }, + }, + } + + expectedBuf := &bytes.Buffer{} + err := json.NewEncoder(expectedBuf).Encode(expectedMap) + assert.NoError(t, err) + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + cfg.ClientAPI.RateLimiting.Enabled = false + defer close() + natsInstance := jetstream.NATSInstance{} + + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + + // Needed to create accounts + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, nil, caching.DisableMetrics) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + // We mostly need the rsAPI/userAPI for this test, so nil for other APIs etc. + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + + // Create the users in the userapi and login + accessTokens := map[*test.User]userDevice{ + alice: {}, + } + createAccessTokens(t, accessTokens, userAPI, ctx, routers) + + testCases := []struct { + name string + request *http.Request + }{ + { + name: "can get capabilities", + request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/capabilities", strings.NewReader("")), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + rec := httptest.NewRecorder() + tc.request.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken) + routers.Client.ServeHTTP(rec, tc.request) + assert.Equal(t, http.StatusOK, rec.Code) + assert.ObjectsAreEqual(expectedBuf.Bytes(), rec.Body.Bytes()) + }) + } + }) +} + +func TestTurnserver(t *testing.T) { + alice := test.NewUser(t) + ctx := context.Background() + + cfg, processCtx, close := testrig.CreateConfig(t, test.DBTypeSQLite) + cfg.ClientAPI.RateLimiting.Enabled = false + defer close() + natsInstance := jetstream.NATSInstance{} + + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + + // Needed to create accounts + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, nil, caching.DisableMetrics) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + //rsAPI.SetUserAPI(userAPI) + // We mostly need the rsAPI/userAPI for this test, so nil for other APIs etc. + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + + // Create the users in the userapi and login + accessTokens := map[*test.User]userDevice{ + alice: {}, + } + createAccessTokens(t, accessTokens, userAPI, ctx, routers) + + testCases := []struct { + name string + turnConfig config.TURN + wantEmptyResponse bool + }{ + { + name: "no turn server configured", + wantEmptyResponse: true, + }, + { + name: "servers configured but not userLifeTime", + wantEmptyResponse: true, + turnConfig: config.TURN{URIs: []string{""}}, + }, + { + name: "missing sharedSecret/username/password", + wantEmptyResponse: true, + turnConfig: config.TURN{URIs: []string{""}, UserLifetime: "1m"}, + }, + { + name: "with shared secret", + turnConfig: config.TURN{URIs: []string{""}, UserLifetime: "1m", SharedSecret: "iAmSecret"}, + }, + { + name: "with username/password secret", + turnConfig: config.TURN{URIs: []string{""}, UserLifetime: "1m", Username: "username", Password: "iAmSecret"}, + }, + { + name: "only username set", + turnConfig: config.TURN{URIs: []string{""}, UserLifetime: "1m", Username: "username"}, + wantEmptyResponse: true, + }, + { + name: "only password set", + turnConfig: config.TURN{URIs: []string{""}, UserLifetime: "1m", Username: "username"}, + wantEmptyResponse: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/voip/turnServer", strings.NewReader("")) + req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken) + cfg.ClientAPI.TURN = tc.turnConfig + routers.Client.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code) + + if tc.wantEmptyResponse && rec.Body.String() != "{}" { + t.Fatalf("expected an empty response, but got %s", rec.Body.String()) + } + if !tc.wantEmptyResponse { + assert.NotEqual(t, "{}", rec.Body.String()) + + resp := gomatrix.RespTurnServer{} + err := json.NewDecoder(rec.Body).Decode(&resp) + assert.NoError(t, err) + + duration, _ := time.ParseDuration(tc.turnConfig.UserLifetime) + assert.Equal(t, tc.turnConfig.URIs, resp.URIs) + assert.Equal(t, int(duration.Seconds()), resp.TTL) + if tc.turnConfig.Username != "" && tc.turnConfig.Password != "" { + assert.Equal(t, tc.turnConfig.Username, resp.Username) + assert.Equal(t, tc.turnConfig.Password, resp.Password) + } + } + }) + } +} + +func Test3PID(t *testing.T) { + alice := test.NewUser(t) + ctx := context.Background() + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + cfg.ClientAPI.RateLimiting.Enabled = false + cfg.FederationAPI.DisableTLSValidation = true // needed to be able to connect to our identityServer below + defer close() + natsInstance := jetstream.NATSInstance{} + + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + + // Needed to create accounts + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, nil, caching.DisableMetrics) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + // We mostly need the rsAPI/userAPI for this test, so nil for other APIs etc. + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + + // Create the users in the userapi and login + accessTokens := map[*test.User]userDevice{ + alice: {}, + } + createAccessTokens(t, accessTokens, userAPI, ctx, routers) + + identityServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.String(), "getValidated3pid"): + resp := threepid.GetValidatedResponse{} + switch r.URL.Query().Get("client_secret") { + case "fail": + resp.ErrCode = "M_SESSION_NOT_VALIDATED" + case "fail2": + resp.ErrCode = "some other error" + case "fail3": + _, _ = w.Write([]byte("{invalidJson")) + return + case "success": + resp.Medium = "email" + case "success2": + resp.Medium = "email" + resp.Address = "somerandom@address.com" + } + _ = json.NewEncoder(w).Encode(resp) + case strings.Contains(r.URL.String(), "requestToken"): + resp := threepid.SID{SID: "randomSID"} + _ = json.NewEncoder(w).Encode(resp) + } + })) + defer identityServer.Close() + + identityServerBase := strings.TrimPrefix(identityServer.URL, "https://") + + testCases := []struct { + name string + request *http.Request + wantOK bool + setTrustedServer bool + wantLen3PIDs int + }{ + { + name: "can get associated threepid info", + request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/account/3pid", strings.NewReader("")), + wantOK: true, + }, + { + name: "can not set threepid info with invalid JSON", + request: httptest.NewRequest(http.MethodPost, "/_matrix/client/v3/account/3pid", strings.NewReader("")), + }, + { + name: "can not set threepid info with untrusted server", + request: httptest.NewRequest(http.MethodPost, "/_matrix/client/v3/account/3pid", strings.NewReader("{}")), + }, + { + name: "can check threepid info with trusted server, but unverified", + request: httptest.NewRequest(http.MethodPost, "/_matrix/client/v3/account/3pid", strings.NewReader(fmt.Sprintf(`{"three_pid_creds":{"id_server":"%s","client_secret":"fail"}}`, identityServerBase))), + setTrustedServer: true, + wantOK: false, + }, + { + name: "can check threepid info with trusted server, but fails for some other reason", + request: httptest.NewRequest(http.MethodPost, "/_matrix/client/v3/account/3pid", strings.NewReader(fmt.Sprintf(`{"three_pid_creds":{"id_server":"%s","client_secret":"fail2"}}`, identityServerBase))), + setTrustedServer: true, + wantOK: false, + }, + { + name: "can check threepid info with trusted server, but fails because of invalid json", + request: httptest.NewRequest(http.MethodPost, "/_matrix/client/v3/account/3pid", strings.NewReader(fmt.Sprintf(`{"three_pid_creds":{"id_server":"%s","client_secret":"fail3"}}`, identityServerBase))), + setTrustedServer: true, + wantOK: false, + }, + { + name: "can save threepid info with trusted server", + request: httptest.NewRequest(http.MethodPost, "/_matrix/client/v3/account/3pid", strings.NewReader(fmt.Sprintf(`{"three_pid_creds":{"id_server":"%s","client_secret":"success"}}`, identityServerBase))), + setTrustedServer: true, + wantOK: true, + }, + { + name: "can save threepid info with trusted server using bind=true", + request: httptest.NewRequest(http.MethodPost, "/_matrix/client/v3/account/3pid", strings.NewReader(fmt.Sprintf(`{"three_pid_creds":{"id_server":"%s","client_secret":"success2"},"bind":true}`, identityServerBase))), + setTrustedServer: true, + wantOK: true, + }, + { + name: "can get associated threepid info again", + request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/account/3pid", strings.NewReader("")), + wantOK: true, + wantLen3PIDs: 2, + }, + { + name: "can delete associated threepid info", + request: httptest.NewRequest(http.MethodPost, "/_matrix/client/v3/account/3pid/delete", strings.NewReader(`{"medium":"email","address":"somerandom@address.com"}`)), + wantOK: true, + }, + { + name: "can get associated threepid after deleting association", + request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/account/3pid", strings.NewReader("")), + wantOK: true, + wantLen3PIDs: 1, + }, + { + name: "can not request emailToken with invalid request body", + request: httptest.NewRequest(http.MethodPost, "/_matrix/client/v3/account/3pid/email/requestToken", strings.NewReader("")), + }, + { + name: "can not request emailToken for in use address", + request: httptest.NewRequest(http.MethodPost, "/_matrix/client/v3/account/3pid/email/requestToken", strings.NewReader(fmt.Sprintf(`{"client_secret":"somesecret","email":"","send_attempt":1,"id_server":"%s"}`, identityServerBase))), + }, + { + name: "can request emailToken", + request: httptest.NewRequest(http.MethodPost, "/_matrix/client/v3/account/3pid/email/requestToken", strings.NewReader(fmt.Sprintf(`{"client_secret":"somesecret","email":"somerandom@address.com","send_attempt":1,"id_server":"%s"}`, identityServerBase))), + wantOK: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + + if tc.setTrustedServer { + cfg.Global.TrustedIDServers = []string{identityServerBase} + } + + rec := httptest.NewRecorder() + tc.request.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken) + + routers.Client.ServeHTTP(rec, tc.request) + t.Logf("Response: %s", rec.Body.String()) + if tc.wantOK && rec.Code != http.StatusOK { + t.Fatalf("expected HTTP 200, got %d: %s", rec.Code, rec.Body.String()) + } + if !tc.wantOK && rec.Code == http.StatusOK { + t.Fatalf("expected request to fail, but didn't: %s", rec.Body.String()) + } + if tc.wantLen3PIDs > 0 { + var resp routing.ThreePIDsResponse + if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil { + t.Fatal(err) + } + if len(resp.ThreePIDs) != tc.wantLen3PIDs { + t.Fatalf("expected %d threepids, got %d", tc.wantLen3PIDs, len(resp.ThreePIDs)) + } + } + }) + } + }) +} + +func TestPushRules(t *testing.T) { + alice := test.NewUser(t) + + // create the default push rules, used when validating responses + localpart, serverName, _ := gomatrixserverlib.SplitID('@', alice.ID) + pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, serverName) + defaultRules, err := json.Marshal(pushRuleSets) + assert.NoError(t, err) + + ruleID1 := "myrule" + ruleID2 := "myrule2" + ruleID3 := "myrule3" + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + cfg.ClientAPI.RateLimiting.Enabled = false + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + natsInstance := jetstream.NATSInstance{} + defer close() + + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + + // We mostly need the rsAPI for this test, so nil for other APIs/caches etc. + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + + accessTokens := map[*test.User]userDevice{ + alice: {}, + } + createAccessTokens(t, accessTokens, userAPI, processCtx.Context(), routers) + + testCases := []struct { + name string + request *http.Request + wantStatusCode int + validateFunc func(t *testing.T, respBody *bytes.Buffer) // used when updating rules, otherwise wantStatusCode should be enough + queryAttr map[string]string + }{ + { + name: "can not get rules without trailing slash", + request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules", strings.NewReader("")), + wantStatusCode: http.StatusBadRequest, + }, + { + name: "can get default rules", + request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/", strings.NewReader("")), + wantStatusCode: http.StatusOK, + validateFunc: func(t *testing.T, respBody *bytes.Buffer) { + assert.Equal(t, defaultRules, respBody.Bytes()) + }, + }, + { + name: "can get rules by scope", + request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/", strings.NewReader("")), + wantStatusCode: http.StatusOK, + validateFunc: func(t *testing.T, respBody *bytes.Buffer) { + assert.Equal(t, gjson.GetBytes(defaultRules, "global").Raw, respBody.String()) + }, + }, + { + name: "can not get invalid rules by scope", + request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/doesnotexist/", strings.NewReader("")), + wantStatusCode: http.StatusBadRequest, + }, + { + name: "can not get rules for invalid scope and kind", + request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/doesnotexist/invalid/", strings.NewReader("")), + wantStatusCode: http.StatusBadRequest, + }, + { + name: "can not get rules for invalid kind", + request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/invalid/", strings.NewReader("")), + wantStatusCode: http.StatusBadRequest, + }, + { + name: "can get rules by scope and kind", + request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/override/", strings.NewReader("")), + wantStatusCode: http.StatusOK, + validateFunc: func(t *testing.T, respBody *bytes.Buffer) { + assert.Equal(t, gjson.GetBytes(defaultRules, "global.override").Raw, respBody.String()) + }, + }, + { + name: "can get rules by scope and content kind", + request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/content/", strings.NewReader("")), + wantStatusCode: http.StatusOK, + validateFunc: func(t *testing.T, respBody *bytes.Buffer) { + assert.Equal(t, gjson.GetBytes(defaultRules, "global.content").Raw, respBody.String()) + }, + }, + { + name: "can not get rules by scope and room kind", + request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/room/", strings.NewReader("")), + wantStatusCode: http.StatusBadRequest, + }, + { + name: "can not get rules by scope and sender kind", + request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/sender/", strings.NewReader("")), + wantStatusCode: http.StatusBadRequest, + }, + { + name: "can get rules by scope and underride kind", + request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/underride/", strings.NewReader("")), + wantStatusCode: http.StatusOK, + validateFunc: func(t *testing.T, respBody *bytes.Buffer) { + assert.Equal(t, gjson.GetBytes(defaultRules, "global.underride").Raw, respBody.String()) + }, + }, + { + name: "can not get rules by scope, kind and ID for invalid scope", + request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/doesnotexist/doesnotexist/.m.rule.master", strings.NewReader("")), + wantStatusCode: http.StatusBadRequest, + }, + { + name: "can not get rules by scope, kind and ID for invalid kind", + request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/doesnotexist/.m.rule.master", strings.NewReader("")), + wantStatusCode: http.StatusBadRequest, + }, + { + name: "can get rules by scope, kind and ID", + request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/override/.m.rule.master", strings.NewReader("")), + wantStatusCode: http.StatusOK, + }, + { + name: "can not get rules by scope, kind and ID for invalid ID", + request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/override/.m.rule.doesnotexist", strings.NewReader("")), + wantStatusCode: http.StatusNotFound, + }, + { + name: "can not get status for invalid attribute", + request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/override/.m.rule.master/invalid", strings.NewReader("")), + wantStatusCode: http.StatusBadRequest, + }, + { + name: "can not get status for invalid kind", + request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/invalid/.m.rule.master/enabled", strings.NewReader("")), + wantStatusCode: http.StatusBadRequest, + }, + { + name: "can not get enabled status for invalid scope", + request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/invalid/override/.m.rule.master/enabled", strings.NewReader("")), + wantStatusCode: http.StatusBadRequest, + }, + { + name: "can not get enabled status for invalid rule", + request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/override/doesnotexist/enabled", strings.NewReader("")), + wantStatusCode: http.StatusNotFound, + }, + { + name: "can get enabled rules by scope, kind and ID", + request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/override/.m.rule.master/enabled", strings.NewReader("")), + wantStatusCode: http.StatusOK, + validateFunc: func(t *testing.T, respBody *bytes.Buffer) { + assert.False(t, gjson.GetBytes(respBody.Bytes(), "enabled").Bool(), "expected master rule to be disabled") + }, + }, + { + name: "can get actions scope, kind and ID", + request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/override/.m.rule.master/actions", strings.NewReader("")), + wantStatusCode: http.StatusOK, + validateFunc: func(t *testing.T, respBody *bytes.Buffer) { + actions := gjson.GetBytes(respBody.Bytes(), "actions").Array() + // only a basic check + assert.Equal(t, 1, len(actions)) + }, + }, + { + name: "can not set enabled status with invalid JSON", + request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/override/.m.rule.master/enabled", strings.NewReader("")), + wantStatusCode: http.StatusBadRequest, + }, + { + name: "can not set attribute for invalid attribute", + request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/override/.m.rule.master/doesnotexist", strings.NewReader("{}")), + wantStatusCode: http.StatusBadRequest, + }, + { + name: "can not set attribute for invalid scope", + request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/invalid/override/.m.rule.master/enabled", strings.NewReader("{}")), + wantStatusCode: http.StatusBadRequest, + }, + { + name: "can not set attribute for invalid kind", + request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/invalid/.m.rule.master/enabled", strings.NewReader("{}")), + wantStatusCode: http.StatusBadRequest, + }, + { + name: "can not set attribute for invalid rule", + request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/override/invalid/enabled", strings.NewReader("{}")), + wantStatusCode: http.StatusNotFound, + }, + { + name: "can set enabled status with valid JSON", + request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/override/.m.rule.master/enabled", strings.NewReader(`{"enabled":true}`)), + wantStatusCode: http.StatusOK, + validateFunc: func(t *testing.T, respBody *bytes.Buffer) { + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/override/.m.rule.master/enabled", strings.NewReader("")) + req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken) + routers.Client.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code, rec.Body.String()) + assert.True(t, gjson.GetBytes(rec.Body.Bytes(), "enabled").Bool(), "expected master rule to be enabled: %s", rec.Body.String()) + }, + }, + { + name: "can set actions with valid JSON", + request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/override/.m.rule.master/actions", strings.NewReader(`{"actions":["dont_notify","notify"]}`)), + wantStatusCode: http.StatusOK, + validateFunc: func(t *testing.T, respBody *bytes.Buffer) { + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/override/.m.rule.master/actions", strings.NewReader("")) + req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken) + routers.Client.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code, rec.Body.String()) + assert.Equal(t, 2, len(gjson.GetBytes(rec.Body.Bytes(), "actions").Array()), "expected 2 actions %s", rec.Body.String()) + }, + }, + { + name: "can not create new push rule with invalid JSON", + request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/content/myrule", strings.NewReader("")), + wantStatusCode: http.StatusBadRequest, + }, + { + name: "can not create new push rule with invalid rule content", + request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/content/myrule", strings.NewReader("{}")), + wantStatusCode: http.StatusBadRequest, + }, + { + name: "can not create new push rule with invalid scope", + request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/invalid/content/myrule", strings.NewReader(`{"actions":["notify"],"pattern":"world"}`)), + wantStatusCode: http.StatusBadRequest, + }, + { + name: "can create new push rule with valid rule content", + request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/content/myrule", strings.NewReader(`{"actions":["notify"],"pattern":"world"}`)), + wantStatusCode: http.StatusOK, + validateFunc: func(t *testing.T, respBody *bytes.Buffer) { + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/content/myrule/actions", strings.NewReader("")) + req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken) + routers.Client.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code, rec.Body.String()) + assert.Equal(t, 1, len(gjson.GetBytes(rec.Body.Bytes(), "actions").Array()), "expected 1 action %s", rec.Body.String()) + }, + }, + { + name: "can not create new push starting with a dot", + request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/content/.myrule", strings.NewReader(`{"actions":["notify"],"pattern":"world"}`)), + wantStatusCode: http.StatusBadRequest, + }, + { + name: "can create new push rule after existing", + request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/content/myrule2", strings.NewReader(`{"actions":["notify"],"pattern":"world"}`)), + queryAttr: map[string]string{ + "after": ruleID1, + }, + wantStatusCode: http.StatusOK, + validateFunc: func(t *testing.T, respBody *bytes.Buffer) { + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/content/", strings.NewReader("")) + req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken) + routers.Client.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code, rec.Body.String()) + rules := gjson.ParseBytes(rec.Body.Bytes()) + for i, rule := range rules.Array() { + if rule.Get("rule_id").Str == ruleID1 && i != 0 { + t.Fatalf("expected '%s' to be the first, but wasn't", ruleID1) + } + if rule.Get("rule_id").Str == ruleID2 && i != 1 { + t.Fatalf("expected '%s' to be the second, but wasn't", ruleID2) + } + } + }, + }, + { + name: "can create new push rule before existing", + request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/content/myrule3", strings.NewReader(`{"actions":["notify"],"pattern":"world"}`)), + queryAttr: map[string]string{ + "before": ruleID1, + }, + wantStatusCode: http.StatusOK, + validateFunc: func(t *testing.T, respBody *bytes.Buffer) { + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/content/", strings.NewReader("")) + req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken) + routers.Client.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code, rec.Body.String()) + rules := gjson.ParseBytes(rec.Body.Bytes()) + for i, rule := range rules.Array() { + if rule.Get("rule_id").Str == ruleID3 && i != 0 { + t.Fatalf("expected '%s' to be the first, but wasn't", ruleID3) + } + if rule.Get("rule_id").Str == ruleID1 && i != 1 { + t.Fatalf("expected '%s' to be the second, but wasn't", ruleID1) + } + if rule.Get("rule_id").Str == ruleID2 && i != 2 { + t.Fatalf("expected '%s' to be the third, but wasn't", ruleID1) + } + } + }, + }, + { + name: "can modify existing push rule", + request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/content/myrule2", strings.NewReader(`{"actions":["dont_notify"],"pattern":"world"}`)), + wantStatusCode: http.StatusOK, + validateFunc: func(t *testing.T, respBody *bytes.Buffer) { + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/content/myrule2/actions", strings.NewReader("")) + req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken) + routers.Client.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code, rec.Body.String()) + actions := gjson.GetBytes(rec.Body.Bytes(), "actions").Array() + // there should only be one action + assert.Equal(t, "dont_notify", actions[0].Str) + }, + }, + { + name: "can move existing push rule to the front", + request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/content/myrule2", strings.NewReader(`{"actions":["dont_notify"],"pattern":"world"}`)), + queryAttr: map[string]string{ + "before": ruleID3, + }, + wantStatusCode: http.StatusOK, + validateFunc: func(t *testing.T, respBody *bytes.Buffer) { + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/content/", strings.NewReader("")) + req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken) + routers.Client.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code, rec.Body.String()) + rules := gjson.ParseBytes(rec.Body.Bytes()) + for i, rule := range rules.Array() { + if rule.Get("rule_id").Str == ruleID2 && i != 0 { + t.Fatalf("expected '%s' to be the first, but wasn't", ruleID2) + } + if rule.Get("rule_id").Str == ruleID3 && i != 1 { + t.Fatalf("expected '%s' to be the second, but wasn't", ruleID3) + } + if rule.Get("rule_id").Str == ruleID1 && i != 2 { + t.Fatalf("expected '%s' to be the third, but wasn't", ruleID1) + } + } + }, + }, + { + name: "can not delete push rule with invalid scope", + request: httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/pushrules/invalid/content/myrule2", strings.NewReader("")), + wantStatusCode: http.StatusBadRequest, + }, + { + name: "can not delete push rule with invalid kind", + request: httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/pushrules/global/invalid/myrule2", strings.NewReader("")), + wantStatusCode: http.StatusBadRequest, + }, + { + name: "can not delete push rule with non-existent rule", + request: httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/pushrules/global/content/doesnotexist", strings.NewReader("")), + wantStatusCode: http.StatusNotFound, + }, + { + name: "can delete existing push rule", + request: httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/pushrules/global/content/myrule2", strings.NewReader("")), + wantStatusCode: http.StatusOK, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + rec := httptest.NewRecorder() + + if tc.queryAttr != nil { + params := url.Values{} + for k, v := range tc.queryAttr { + params.Set(k, v) + } + + tc.request = httptest.NewRequest(tc.request.Method, tc.request.URL.String()+"?"+params.Encode(), tc.request.Body) + } + + tc.request.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken) + + routers.Client.ServeHTTP(rec, tc.request) + assert.Equal(t, tc.wantStatusCode, rec.Code, rec.Body.String()) + if tc.validateFunc != nil { + tc.validateFunc(t, rec.Body) + } + t.Logf("%s", rec.Body.String()) + }) + } + }) +} diff --git a/clientapi/producers/syncapi.go b/clientapi/producers/syncapi.go index 2dc0c4843..905ecce4d 100644 --- a/clientapi/producers/syncapi.go +++ b/clientapi/producers/syncapi.go @@ -22,6 +22,7 @@ import ( "time" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/nats-io/nats.go" log "github.com/sirupsen/logrus" @@ -37,13 +38,13 @@ type SyncAPIProducer struct { TopicTypingEvent string TopicPresenceEvent string JetStream nats.JetStreamContext - ServerName gomatrixserverlib.ServerName + ServerName spec.ServerName UserAPI userapi.ClientUserAPI } func (p *SyncAPIProducer) SendReceipt( ctx context.Context, - userID, roomID, eventID, receiptType string, timestamp gomatrixserverlib.Timestamp, + userID, roomID, eventID, receiptType string, timestamp spec.Timestamp, ) error { m := &nats.Msg{ Subject: p.TopicReceiptEvent, @@ -154,7 +155,7 @@ func (p *SyncAPIProducer) SendPresence( m.Header.Set("status_msg", *statusMsg) } - m.Header.Set("last_active_ts", strconv.Itoa(int(gomatrixserverlib.AsTimestamp(time.Now())))) + m.Header.Set("last_active_ts", strconv.Itoa(int(spec.AsTimestamp(time.Now())))) _, err := p.JetStream.PublishMsg(m, nats.Context(ctx)) return err diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index a01f6b944..809d486d2 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -10,6 +10,7 @@ import ( "github.com/gorilla/mux" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/nats-io/nats.go" "github.com/sirupsen/logrus" @@ -22,23 +23,16 @@ import ( "github.com/matrix-org/dendrite/userapi/api" ) -func AdminEvacuateRoom(req *http.Request, cfg *config.ClientAPI, device *api.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { +func AdminEvacuateRoom(req *http.Request, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } - roomID, ok := vars["roomID"] - if !ok { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.MissingArgument("Expecting room ID."), - } - } res := &roomserverAPI.PerformAdminEvacuateRoomResponse{} if err := rsAPI.PerformAdminEvacuateRoom( req.Context(), &roomserverAPI.PerformAdminEvacuateRoomRequest{ - RoomID: roomID, + RoomID: vars["roomID"], }, res, ); err != nil { @@ -55,18 +49,13 @@ func AdminEvacuateRoom(req *http.Request, cfg *config.ClientAPI, device *api.Dev } } -func AdminEvacuateUser(req *http.Request, cfg *config.ClientAPI, device *api.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { +func AdminEvacuateUser(req *http.Request, cfg *config.ClientAPI, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } - userID, ok := vars["userID"] - if !ok { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.MissingArgument("Expecting user ID."), - } - } + userID := vars["userID"] + _, domain, err := gomatrixserverlib.SplitID('@', userID) if err != nil { return util.MessageResponse(http.StatusBadRequest, err.Error()) @@ -103,13 +92,8 @@ func AdminPurgeRoom(req *http.Request, cfg *config.ClientAPI, device *api.Device if err != nil { return util.ErrorResponse(err) } - roomID, ok := vars["roomID"] - if !ok { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.MissingArgument("Expecting room ID."), - } - } + roomID := vars["roomID"] + res := &roomserverAPI.PerformAdminPurgeRoomResponse{} if err := rsAPI.PerformAdminPurgeRoom( context.Background(), @@ -279,7 +263,7 @@ func AdminDownloadState(req *http.Request, cfg *config.ClientAPI, device *api.De &roomserverAPI.PerformAdminDownloadStateRequest{ UserID: device.UserID, RoomID: roomID, - ServerName: gomatrixserverlib.ServerName(serverName), + ServerName: spec.ServerName(serverName), }, res, ); err != nil { diff --git a/clientapi/routing/aliases.go b/clientapi/routing/aliases.go index 68d0f4195..90621dafa 100644 --- a/clientapi/routing/aliases.go +++ b/clientapi/routing/aliases.go @@ -22,6 +22,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/api" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -31,7 +32,7 @@ func GetAliases( req *http.Request, rsAPI api.ClientRoomserverAPI, device *userapi.Device, roomID string, ) util.JSONResponse { stateTuple := gomatrixserverlib.StateKeyTuple{ - EventType: gomatrixserverlib.MRoomHistoryVisibility, + EventType: spec.MRoomHistoryVisibility, StateKey: "", } stateReq := &api.QueryCurrentStateRequest{ @@ -53,7 +54,7 @@ func GetAliases( return util.ErrorResponse(fmt.Errorf("historyVisEvent.HistoryVisibility: %w", err)) } } - if visibility != gomatrixserverlib.WorldReadable { + if visibility != spec.WorldReadable { queryReq := api.QueryMembershipForUserRequest{ RoomID: roomID, UserID: device.UserID, diff --git a/clientapi/routing/auth_fallback_test.go b/clientapi/routing/auth_fallback_test.go index 534581bdd..afeca051b 100644 --- a/clientapi/routing/auth_fallback_test.go +++ b/clientapi/routing/auth_fallback_test.go @@ -10,30 +10,28 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/test/testrig" ) func Test_AuthFallback(t *testing.T) { - base, _, _ := testrig.Base(nil) - defer base.Close() - + cfg := config.Dendrite{} + cfg.Defaults(config.DefaultOpts{Generate: true, SingleDatabase: true}) for _, useHCaptcha := range []bool{false, true} { for _, recaptchaEnabled := range []bool{false, true} { for _, wantErr := range []bool{false, true} { t.Run(fmt.Sprintf("useHCaptcha(%v) - recaptchaEnabled(%v) - wantErr(%v)", useHCaptcha, recaptchaEnabled, wantErr), func(t *testing.T) { // Set the defaults for each test - base.Cfg.ClientAPI.Defaults(config.DefaultOpts{Generate: true, SingleDatabase: true}) - base.Cfg.ClientAPI.RecaptchaEnabled = recaptchaEnabled - base.Cfg.ClientAPI.RecaptchaPublicKey = "pub" - base.Cfg.ClientAPI.RecaptchaPrivateKey = "priv" + cfg.ClientAPI.Defaults(config.DefaultOpts{Generate: true, SingleDatabase: true}) + cfg.ClientAPI.RecaptchaEnabled = recaptchaEnabled + cfg.ClientAPI.RecaptchaPublicKey = "pub" + cfg.ClientAPI.RecaptchaPrivateKey = "priv" if useHCaptcha { - base.Cfg.ClientAPI.RecaptchaSiteVerifyAPI = "https://hcaptcha.com/siteverify" - base.Cfg.ClientAPI.RecaptchaApiJsUrl = "https://js.hcaptcha.com/1/api.js" - base.Cfg.ClientAPI.RecaptchaFormField = "h-captcha-response" - base.Cfg.ClientAPI.RecaptchaSitekeyClass = "h-captcha" + cfg.ClientAPI.RecaptchaSiteVerifyAPI = "https://hcaptcha.com/siteverify" + cfg.ClientAPI.RecaptchaApiJsUrl = "https://js.hcaptcha.com/1/api.js" + cfg.ClientAPI.RecaptchaFormField = "h-captcha-response" + cfg.ClientAPI.RecaptchaSitekeyClass = "h-captcha" } cfgErrs := &config.ConfigErrors{} - base.Cfg.ClientAPI.Verify(cfgErrs) + cfg.ClientAPI.Verify(cfgErrs) if len(*cfgErrs) > 0 { t.Fatalf("(hCaptcha=%v) unexpected config errors: %s", useHCaptcha, cfgErrs.Error()) } @@ -41,7 +39,7 @@ func Test_AuthFallback(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/?session=1337", nil) rec := httptest.NewRecorder() - AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &base.Cfg.ClientAPI) + AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &cfg.ClientAPI) if !recaptchaEnabled { if rec.Code != http.StatusBadRequest { t.Fatalf("unexpected response code: %d, want %d", rec.Code, http.StatusBadRequest) @@ -50,8 +48,8 @@ func Test_AuthFallback(t *testing.T) { t.Fatalf("unexpected response body: %s", rec.Body.String()) } } else { - if !strings.Contains(rec.Body.String(), base.Cfg.ClientAPI.RecaptchaSitekeyClass) { - t.Fatalf("body does not contain %s: %s", base.Cfg.ClientAPI.RecaptchaSitekeyClass, rec.Body.String()) + if !strings.Contains(rec.Body.String(), cfg.ClientAPI.RecaptchaSitekeyClass) { + t.Fatalf("body does not contain %s: %s", cfg.ClientAPI.RecaptchaSitekeyClass, rec.Body.String()) } } @@ -64,14 +62,14 @@ func Test_AuthFallback(t *testing.T) { })) defer srv.Close() // nolint: errcheck - base.Cfg.ClientAPI.RecaptchaSiteVerifyAPI = srv.URL + cfg.ClientAPI.RecaptchaSiteVerifyAPI = srv.URL // check the result after sending the captcha req = httptest.NewRequest(http.MethodPost, "/?session=1337", nil) req.Form = url.Values{} - req.Form.Add(base.Cfg.ClientAPI.RecaptchaFormField, "someRandomValue") + req.Form.Add(cfg.ClientAPI.RecaptchaFormField, "someRandomValue") rec = httptest.NewRecorder() - AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &base.Cfg.ClientAPI) + AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &cfg.ClientAPI) if recaptchaEnabled { if !wantErr { if rec.Code != http.StatusOK { @@ -105,7 +103,7 @@ func Test_AuthFallback(t *testing.T) { t.Run("unknown fallbacks are handled correctly", func(t *testing.T) { req := httptest.NewRequest(http.MethodPost, "/?session=1337", nil) rec := httptest.NewRecorder() - AuthFallback(rec, req, "DoesNotExist", &base.Cfg.ClientAPI) + AuthFallback(rec, req, "DoesNotExist", &cfg.ClientAPI) if rec.Code != http.StatusNotImplemented { t.Fatalf("unexpected http status: %d, want %d", rec.Code, http.StatusNotImplemented) } @@ -114,7 +112,7 @@ func Test_AuthFallback(t *testing.T) { t.Run("unknown methods are handled correctly", func(t *testing.T) { req := httptest.NewRequest(http.MethodDelete, "/?session=1337", nil) rec := httptest.NewRecorder() - AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &base.Cfg.ClientAPI) + AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &cfg.ClientAPI) if rec.Code != http.StatusMethodNotAllowed { t.Fatalf("unexpected http status: %d, want %d", rec.Code, http.StatusMethodNotAllowed) } @@ -123,7 +121,7 @@ func Test_AuthFallback(t *testing.T) { t.Run("missing session parameter is handled correctly", func(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() - AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &base.Cfg.ClientAPI) + AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &cfg.ClientAPI) if rec.Code != http.StatusBadRequest { t.Fatalf("unexpected http status: %d, want %d", rec.Code, http.StatusBadRequest) } @@ -132,7 +130,7 @@ func Test_AuthFallback(t *testing.T) { t.Run("missing session parameter is handled correctly", func(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() - AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &base.Cfg.ClientAPI) + AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &cfg.ClientAPI) if rec.Code != http.StatusBadRequest { t.Fatalf("unexpected http status: %d, want %d", rec.Code, http.StatusBadRequest) } @@ -141,7 +139,7 @@ func Test_AuthFallback(t *testing.T) { t.Run("missing 'response' is handled correctly", func(t *testing.T) { req := httptest.NewRequest(http.MethodPost, "/?session=1337", nil) rec := httptest.NewRecorder() - AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &base.Cfg.ClientAPI) + AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &cfg.ClientAPI) if rec.Code != http.StatusBadRequest { t.Fatalf("unexpected http status: %d, want %d", rec.Code, http.StatusBadRequest) } diff --git a/clientapi/routing/capabilities.go b/clientapi/routing/capabilities.go index b7d47e916..e6c1a9b8c 100644 --- a/clientapi/routing/capabilities.go +++ b/clientapi/routing/capabilities.go @@ -17,26 +17,21 @@ package routing import ( "net/http" - "github.com/matrix-org/dendrite/clientapi/jsonerror" - roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" - + "github.com/matrix-org/dendrite/roomserver/version" + "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) // GetCapabilities returns information about the server's supported feature set // and other relevant capabilities to an authenticated user. -func GetCapabilities( - req *http.Request, rsAPI roomserverAPI.ClientRoomserverAPI, -) util.JSONResponse { - roomVersionsQueryReq := roomserverAPI.QueryRoomVersionCapabilitiesRequest{} - roomVersionsQueryRes := roomserverAPI.QueryRoomVersionCapabilitiesResponse{} - if err := rsAPI.QueryRoomVersionCapabilities( - req.Context(), - &roomVersionsQueryReq, - &roomVersionsQueryRes, - ); err != nil { - util.GetLogger(req.Context()).WithError(err).Error("queryAPI.QueryRoomVersionCapabilities failed") - return jsonerror.InternalServerError() +func GetCapabilities() util.JSONResponse { + versionsMap := map[gomatrixserverlib.RoomVersion]string{} + for v, desc := range version.SupportedRoomVersions() { + if desc.Stable { + versionsMap[v] = "stable" + } else { + versionsMap[v] = "unstable" + } } response := map[string]interface{}{ @@ -44,7 +39,10 @@ func GetCapabilities( "m.change_password": map[string]bool{ "enabled": true, }, - "m.room_versions": roomVersionsQueryRes, + "m.room_versions": map[string]interface{}{ + "default": version.DefaultRoomVersion(), + "available": versionsMap, + }, }, } diff --git a/clientapi/routing/createroom.go b/clientapi/routing/createroom.go index a0d80903d..a07ef2f50 100644 --- a/clientapi/routing/createroom.go +++ b/clientapi/routing/createroom.go @@ -26,6 +26,8 @@ import ( roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverVersion "github.com/matrix-org/dendrite/roomserver/version" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" @@ -233,7 +235,7 @@ func createRoom( createContent["room_version"] = roomVersion powerLevelContent := eventutil.InitialPowerLevelsContent(userID) joinRuleContent := gomatrixserverlib.JoinRuleContent{ - JoinRule: gomatrixserverlib.Invite, + JoinRule: spec.Invite, } historyVisibilityContent := gomatrixserverlib.HistoryVisibilityContent{ HistoryVisibility: historyVisibilityShared, @@ -253,40 +255,40 @@ func createRoom( switch r.Preset { case presetPrivateChat: - joinRuleContent.JoinRule = gomatrixserverlib.Invite + joinRuleContent.JoinRule = spec.Invite historyVisibilityContent.HistoryVisibility = historyVisibilityShared case presetTrustedPrivateChat: - joinRuleContent.JoinRule = gomatrixserverlib.Invite + joinRuleContent.JoinRule = spec.Invite historyVisibilityContent.HistoryVisibility = historyVisibilityShared for _, invitee := range r.Invite { powerLevelContent.Users[invitee] = 100 } case presetPublicChat: - joinRuleContent.JoinRule = gomatrixserverlib.Public + joinRuleContent.JoinRule = spec.Public historyVisibilityContent.HistoryVisibility = historyVisibilityShared } createEvent := fledglingEvent{ - Type: gomatrixserverlib.MRoomCreate, + Type: spec.MRoomCreate, Content: createContent, } powerLevelEvent := fledglingEvent{ - Type: gomatrixserverlib.MRoomPowerLevels, + Type: spec.MRoomPowerLevels, Content: powerLevelContent, } joinRuleEvent := fledglingEvent{ - Type: gomatrixserverlib.MRoomJoinRules, + Type: spec.MRoomJoinRules, Content: joinRuleContent, } historyVisibilityEvent := fledglingEvent{ - Type: gomatrixserverlib.MRoomHistoryVisibility, + Type: spec.MRoomHistoryVisibility, Content: historyVisibilityContent, } membershipEvent := fledglingEvent{ - Type: gomatrixserverlib.MRoomMember, + Type: spec.MRoomMember, StateKey: userID, Content: gomatrixserverlib.MemberContent{ - Membership: gomatrixserverlib.Join, + Membership: spec.Join, DisplayName: profile.DisplayName, AvatarURL: profile.AvatarURL, }, @@ -299,7 +301,7 @@ func createRoom( if r.Name != "" { nameEvent = &fledglingEvent{ - Type: gomatrixserverlib.MRoomName, + Type: spec.MRoomName, Content: eventutil.NameContent{ Name: r.Name, }, @@ -308,7 +310,7 @@ func createRoom( if r.Topic != "" { topicEvent = &fledglingEvent{ - Type: gomatrixserverlib.MRoomTopic, + Type: spec.MRoomTopic, Content: eventutil.TopicContent{ Topic: r.Topic, }, @@ -317,7 +319,7 @@ func createRoom( if r.GuestCanJoin { guestAccessEvent = &fledglingEvent{ - Type: gomatrixserverlib.MRoomGuestAccess, + Type: spec.MRoomGuestAccess, Content: eventutil.GuestAccessContent{ GuestAccess: "can_join", }, @@ -347,7 +349,7 @@ func createRoom( } aliasEvent = &fledglingEvent{ - Type: gomatrixserverlib.MRoomCanonicalAlias, + Type: spec.MRoomCanonicalAlias, Content: eventutil.CanonicalAlias{ Alias: roomAlias, }, @@ -362,25 +364,25 @@ func createRoom( } switch r.InitialState[i].Type { - case gomatrixserverlib.MRoomCreate: + case spec.MRoomCreate: continue - case gomatrixserverlib.MRoomPowerLevels: + case spec.MRoomPowerLevels: powerLevelEvent = r.InitialState[i] - case gomatrixserverlib.MRoomJoinRules: + case spec.MRoomJoinRules: joinRuleEvent = r.InitialState[i] - case gomatrixserverlib.MRoomHistoryVisibility: + case spec.MRoomHistoryVisibility: historyVisibilityEvent = r.InitialState[i] - case gomatrixserverlib.MRoomGuestAccess: + case spec.MRoomGuestAccess: guestAccessEvent = &r.InitialState[i] - case gomatrixserverlib.MRoomName: + case spec.MRoomName: nameEvent = &r.InitialState[i] - case gomatrixserverlib.MRoomTopic: + case spec.MRoomTopic: topicEvent = &r.InitialState[i] default: @@ -448,7 +450,7 @@ func createRoom( builder.PrevEvents = []gomatrixserverlib.EventReference{builtEvents[i-1].EventReference()} } var ev *gomatrixserverlib.Event - ev, err = buildEvent(&builder, userDomain, &authEvents, cfg, evTime, roomVersion) + ev, err = builder.AddAuthEventsAndBuild(userDomain, &authEvents, evTime, roomVersion, cfg.Matrix.KeyID, cfg.Matrix.PrivateKey) if err != nil { util.GetLogger(ctx).WithError(err).Error("buildEvent failed") return jsonerror.InternalServerError() @@ -510,30 +512,30 @@ func createRoom( // If this is a direct message then we should invite the participants. if len(r.Invite) > 0 { // Build some stripped state for the invite. - var globalStrippedState []gomatrixserverlib.InviteV2StrippedState + var globalStrippedState []fclient.InviteV2StrippedState for _, event := range builtEvents { // Chosen events from the spec: // https://spec.matrix.org/v1.3/client-server-api/#stripped-state switch event.Type() { - case gomatrixserverlib.MRoomCreate: + case spec.MRoomCreate: fallthrough - case gomatrixserverlib.MRoomName: + case spec.MRoomName: fallthrough - case gomatrixserverlib.MRoomAvatar: + case spec.MRoomAvatar: fallthrough - case gomatrixserverlib.MRoomTopic: + case spec.MRoomTopic: fallthrough - case gomatrixserverlib.MRoomCanonicalAlias: + case spec.MRoomCanonicalAlias: fallthrough - case gomatrixserverlib.MRoomEncryption: + case spec.MRoomEncryption: fallthrough - case gomatrixserverlib.MRoomMember: + case spec.MRoomMember: fallthrough - case gomatrixserverlib.MRoomJoinRules: + case spec.MRoomJoinRules: ev := event.Event globalStrippedState = append( globalStrippedState, - gomatrixserverlib.NewInviteV2StrippedState(ev), + fclient.NewInviteV2StrippedState(ev), ) } } @@ -542,7 +544,7 @@ func createRoom( for _, invitee := range r.Invite { // Build the invite event. inviteEvent, err := buildMembershipEvent( - ctx, invitee, "", profileAPI, device, gomatrixserverlib.Invite, + ctx, invitee, "", profileAPI, device, spec.Invite, roomID, r.IsDirect, cfg, evTime, rsAPI, asAPI, ) if err != nil { @@ -551,7 +553,7 @@ func createRoom( } inviteStrippedState := append( globalStrippedState, - gomatrixserverlib.NewInviteV2StrippedState(inviteEvent.Event), + fclient.NewInviteV2StrippedState(inviteEvent.Event), ) // Send the invite event to the roomserver. var inviteRes roomserverAPI.PerformInviteResponse @@ -599,31 +601,3 @@ func createRoom( JSON: response, } } - -// buildEvent fills out auth_events for the builder then builds the event -func buildEvent( - builder *gomatrixserverlib.EventBuilder, - serverName gomatrixserverlib.ServerName, - provider gomatrixserverlib.AuthEventProvider, - cfg *config.ClientAPI, - evTime time.Time, - roomVersion gomatrixserverlib.RoomVersion, -) (*gomatrixserverlib.Event, error) { - eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder) - if err != nil { - return nil, err - } - refs, err := eventsNeeded.AuthEventReferences(provider) - if err != nil { - return nil, err - } - builder.AuthEvents = refs - event, err := builder.Build( - evTime, serverName, cfg.Matrix.KeyID, - cfg.Matrix.PrivateKey, roomVersion, - ) - if err != nil { - return nil, fmt.Errorf("cannot build event %s : Builder failed to build. %w", builder.Type, err) - } - return event, nil -} diff --git a/clientapi/routing/deactivate.go b/clientapi/routing/deactivate.go index f213db7f3..3f4f539f6 100644 --- a/clientapi/routing/deactivate.go +++ b/clientapi/routing/deactivate.go @@ -33,7 +33,7 @@ func Deactivate( return *errRes } - localpart, _, err := gomatrixserverlib.SplitID('@', login.Username()) + localpart, serverName, err := gomatrixserverlib.SplitID('@', login.Username()) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") return jsonerror.InternalServerError() @@ -41,7 +41,8 @@ func Deactivate( var res api.PerformAccountDeactivationResponse err = accountAPI.PerformAccountDeactivation(ctx, &api.PerformAccountDeactivationRequest{ - Localpart: localpart, + Localpart: localpart, + ServerName: serverName, }, &res) if err != nil { util.GetLogger(ctx).WithError(err).Error("userAPI.PerformAccountDeactivation failed") diff --git a/clientapi/routing/device.go b/clientapi/routing/device.go index e3a02661c..331bacc3c 100644 --- a/clientapi/routing/device.go +++ b/clientapi/routing/device.go @@ -15,6 +15,7 @@ package routing import ( + "encoding/json" "io" "net" "net/http" @@ -146,12 +147,6 @@ func UpdateDeviceByID( JSON: jsonerror.Forbidden("device does not exist"), } } - if performRes.Forbidden { - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("device not owned by current user"), - } - } return util.JSONResponse{ Code: http.StatusOK, @@ -189,7 +184,7 @@ func DeleteDeviceById( if dev != deviceID { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("session & device mismatch"), + JSON: jsonerror.Forbidden("session and device mismatch"), } } } @@ -242,16 +237,37 @@ func DeleteDeviceById( // DeleteDevices handles POST requests to /delete_devices func DeleteDevices( - req *http.Request, userAPI api.ClientUserAPI, device *api.Device, + req *http.Request, userInteractiveAuth *auth.UserInteractive, userAPI api.ClientUserAPI, device *api.Device, ) util.JSONResponse { ctx := req.Context() - payload := devicesDeleteJSON{} - if resErr := httputil.UnmarshalJSONRequest(req, &payload); resErr != nil { - return *resErr + bodyBytes, err := io.ReadAll(req.Body) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON("The request body could not be read: " + err.Error()), + } + } + defer req.Body.Close() // nolint:errcheck + + // initiate UIA + login, errRes := userInteractiveAuth.Verify(ctx, bodyBytes, device) + if errRes != nil { + return *errRes } - defer req.Body.Close() // nolint: errcheck + if login.Username() != device.UserID { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("unable to delete devices for other user"), + } + } + + payload := devicesDeleteJSON{} + if err = json.Unmarshal(bodyBytes, &payload); err != nil { + util.GetLogger(ctx).WithError(err).Error("unable to unmarshal device deletion request") + return jsonerror.InternalServerError() + } var res api.PerformDeviceDeletionResponse if err := userAPI.PerformDeviceDeletion(ctx, &api.PerformDeviceDeletionRequest{ diff --git a/clientapi/routing/directory.go b/clientapi/routing/directory.go index b3c5aae45..e108f3cf8 100644 --- a/clientapi/routing/directory.go +++ b/clientapi/routing/directory.go @@ -19,6 +19,8 @@ import ( "net/http" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/matrix-org/dendrite/clientapi/httputil" @@ -34,7 +36,7 @@ type roomDirectoryResponse struct { Servers []string `json:"servers"` } -func (r *roomDirectoryResponse) fillServers(servers []gomatrixserverlib.ServerName) { +func (r *roomDirectoryResponse) fillServers(servers []spec.ServerName) { r.Servers = make([]string, len(servers)) for i, s := range servers { r.Servers[i] = string(s) @@ -45,7 +47,7 @@ func (r *roomDirectoryResponse) fillServers(servers []gomatrixserverlib.ServerNa func DirectoryRoom( req *http.Request, roomAlias string, - federation *gomatrixserverlib.FederationClient, + federation *fclient.FederationClient, cfg *config.ClientAPI, rsAPI roomserverAPI.ClientRoomserverAPI, fedSenderAPI federationAPI.ClientFederationAPI, @@ -252,7 +254,7 @@ func GetVisibility( var v roomVisibility if len(res.RoomIDs) == 1 { - v.Visibility = gomatrixserverlib.Public + v.Visibility = spec.Public } else { v.Visibility = "private" } @@ -277,7 +279,7 @@ func SetVisibility( queryEventsReq := roomserverAPI.QueryLatestEventsAndStateRequest{ RoomID: roomID, StateToFetch: []gomatrixserverlib.StateKeyTuple{{ - EventType: gomatrixserverlib.MRoomPowerLevels, + EventType: spec.MRoomPowerLevels, StateKey: "", }}, } @@ -290,7 +292,7 @@ func SetVisibility( // NOTSPEC: Check if the user's power is greater than power required to change m.room.canonical_alias event power, _ := gomatrixserverlib.NewPowerLevelContentFromEvent(queryEventsRes.StateEvents[0].Event) - if power.UserLevel(dev.UserID) < power.EventLevel(gomatrixserverlib.MRoomCanonicalAlias, true) { + if power.UserLevel(dev.UserID) < power.EventLevel(spec.MRoomCanonicalAlias, true) { return util.JSONResponse{ Code: http.StatusForbidden, JSON: jsonerror.Forbidden("userID doesn't have power level to change visibility"), diff --git a/clientapi/routing/directory_public.go b/clientapi/routing/directory_public.go index 606744767..0616e79ee 100644 --- a/clientapi/routing/directory_public.go +++ b/clientapi/routing/directory_public.go @@ -23,7 +23,8 @@ import ( "strings" "sync" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/matrix-org/dendrite/clientapi/api" @@ -35,7 +36,7 @@ import ( var ( cacheMu sync.Mutex - publicRoomsCache []gomatrixserverlib.PublicRoom + publicRoomsCache []fclient.PublicRoom ) type PublicRoomReq struct { @@ -56,7 +57,7 @@ type filter struct { func GetPostPublicRooms( req *http.Request, rsAPI roomserverAPI.ClientRoomserverAPI, extRoomsProvider api.ExtraPublicRoomsProvider, - federation *gomatrixserverlib.FederationClient, + federation *fclient.FederationClient, cfg *config.ClientAPI, ) util.JSONResponse { var request PublicRoomReq @@ -71,7 +72,7 @@ func GetPostPublicRooms( } } - serverName := gomatrixserverlib.ServerName(request.Server) + serverName := spec.ServerName(request.Server) if serverName != "" && !cfg.Matrix.IsLocalServerName(serverName) { res, err := federation.GetPublicRoomsFiltered( req.Context(), cfg.Matrix.ServerName, serverName, @@ -102,10 +103,10 @@ func GetPostPublicRooms( func publicRooms( ctx context.Context, request PublicRoomReq, rsAPI roomserverAPI.ClientRoomserverAPI, extRoomsProvider api.ExtraPublicRoomsProvider, -) (*gomatrixserverlib.RespPublicRooms, error) { +) (*fclient.RespPublicRooms, error) { - response := gomatrixserverlib.RespPublicRooms{ - Chunk: []gomatrixserverlib.PublicRoom{}, + response := fclient.RespPublicRooms{ + Chunk: []fclient.PublicRoom{}, } var limit int64 var offset int64 @@ -122,7 +123,7 @@ func publicRooms( } err = nil - var rooms []gomatrixserverlib.PublicRoom + var rooms []fclient.PublicRoom if request.Since == "" { rooms = refreshPublicRoomCache(ctx, rsAPI, extRoomsProvider, request) } else { @@ -146,14 +147,14 @@ func publicRooms( return &response, err } -func filterRooms(rooms []gomatrixserverlib.PublicRoom, searchTerm string) []gomatrixserverlib.PublicRoom { +func filterRooms(rooms []fclient.PublicRoom, searchTerm string) []fclient.PublicRoom { if searchTerm == "" { return rooms } normalizedTerm := strings.ToLower(searchTerm) - result := make([]gomatrixserverlib.PublicRoom, 0) + result := make([]fclient.PublicRoom, 0) for _, room := range rooms { if strings.Contains(strings.ToLower(room.Name), normalizedTerm) || strings.Contains(strings.ToLower(room.Topic), normalizedTerm) || @@ -214,7 +215,7 @@ func fillPublicRoomsReq(httpReq *http.Request, request *PublicRoomReq) *util.JSO // limit=3&since=6 => G (prev='3', next='') // // A value of '-1' for prev/next indicates no position. -func sliceInto(slice []gomatrixserverlib.PublicRoom, since int64, limit int64) (subset []gomatrixserverlib.PublicRoom, prev, next int) { +func sliceInto(slice []fclient.PublicRoom, since int64, limit int64) (subset []fclient.PublicRoom, prev, next int) { prev = -1 next = -1 @@ -241,10 +242,10 @@ func sliceInto(slice []gomatrixserverlib.PublicRoom, since int64, limit int64) ( func refreshPublicRoomCache( ctx context.Context, rsAPI roomserverAPI.ClientRoomserverAPI, extRoomsProvider api.ExtraPublicRoomsProvider, request PublicRoomReq, -) []gomatrixserverlib.PublicRoom { +) []fclient.PublicRoom { cacheMu.Lock() defer cacheMu.Unlock() - var extraRooms []gomatrixserverlib.PublicRoom + var extraRooms []fclient.PublicRoom if extRoomsProvider != nil { extraRooms = extRoomsProvider.Rooms() } @@ -269,7 +270,7 @@ func refreshPublicRoomCache( util.GetLogger(ctx).WithError(err).Error("PopulatePublicRooms failed") return publicRoomsCache } - publicRoomsCache = []gomatrixserverlib.PublicRoom{} + publicRoomsCache = []fclient.PublicRoom{} publicRoomsCache = append(publicRoomsCache, pubRooms...) publicRoomsCache = append(publicRoomsCache, extraRooms...) publicRoomsCache = dedupeAndShuffle(publicRoomsCache) @@ -281,16 +282,16 @@ func refreshPublicRoomCache( return publicRoomsCache } -func getPublicRoomsFromCache() []gomatrixserverlib.PublicRoom { +func getPublicRoomsFromCache() []fclient.PublicRoom { cacheMu.Lock() defer cacheMu.Unlock() return publicRoomsCache } -func dedupeAndShuffle(in []gomatrixserverlib.PublicRoom) []gomatrixserverlib.PublicRoom { +func dedupeAndShuffle(in []fclient.PublicRoom) []fclient.PublicRoom { // de-duplicate rooms with the same room ID. We can join the room via any of these aliases as we know these servers // are alive and well, so we arbitrarily pick one (purposefully shuffling them to spread the load a bit) - var publicRooms []gomatrixserverlib.PublicRoom + var publicRooms []fclient.PublicRoom haveRoomIDs := make(map[string]bool) rand.Shuffle(len(in), func(i, j int) { in[i], in[j] = in[j], in[i] diff --git a/clientapi/routing/directory_public_test.go b/clientapi/routing/directory_public_test.go index 65ad392c2..de2f01b19 100644 --- a/clientapi/routing/directory_public_test.go +++ b/clientapi/routing/directory_public_test.go @@ -4,17 +4,17 @@ import ( "reflect" "testing" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" ) -func pubRoom(name string) gomatrixserverlib.PublicRoom { - return gomatrixserverlib.PublicRoom{ +func pubRoom(name string) fclient.PublicRoom { + return fclient.PublicRoom{ Name: name, } } func TestSliceInto(t *testing.T) { - slice := []gomatrixserverlib.PublicRoom{ + slice := []fclient.PublicRoom{ pubRoom("a"), pubRoom("b"), pubRoom("c"), pubRoom("d"), pubRoom("e"), pubRoom("f"), pubRoom("g"), } limit := int64(3) @@ -22,7 +22,7 @@ func TestSliceInto(t *testing.T) { since int64 wantPrev int wantNext int - wantSubset []gomatrixserverlib.PublicRoom + wantSubset []fclient.PublicRoom }{ { since: 0, diff --git a/clientapi/routing/joinroom.go b/clientapi/routing/joinroom.go index e371d9214..d736a9588 100644 --- a/clientapi/routing/joinroom.go +++ b/clientapi/routing/joinroom.go @@ -18,11 +18,12 @@ import ( "net/http" "time" + appserviceAPI "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -48,7 +49,7 @@ func JoinRoomByIDOrAlias( for _, serverName := range serverNames { joinReq.ServerNames = append( joinReq.ServerNames, - gomatrixserverlib.ServerName(serverName), + spec.ServerName(serverName), ) } } @@ -61,21 +62,19 @@ func JoinRoomByIDOrAlias( // Work out our localpart for the client profile request. // Request our profile content to populate the request content with. - res := &api.QueryProfileResponse{} - err := profileAPI.QueryProfile(req.Context(), &api.QueryProfileRequest{UserID: device.UserID}, res) - if err != nil || !res.UserExists { - if !res.UserExists { - util.GetLogger(req.Context()).Error("Unable to query user profile, no profile found.") - return util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: jsonerror.Unknown("Unable to query user profile, no profile found."), - } - } + profile, err := profileAPI.QueryProfile(req.Context(), device.UserID) - util.GetLogger(req.Context()).WithError(err).Error("UserProfileAPI.QueryProfile failed") - } else { - joinReq.Content["displayname"] = res.DisplayName - joinReq.Content["avatar_url"] = res.AvatarURL + switch err { + case nil: + joinReq.Content["displayname"] = profile.DisplayName + joinReq.Content["avatar_url"] = profile.AvatarURL + case appserviceAPI.ErrProfileNotExists: + util.GetLogger(req.Context()).Error("Unable to query user profile, no profile found.") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: jsonerror.Unknown("Unable to query user profile, no profile found."), + } + default: } // Ask the roomserver to perform the join. diff --git a/clientapi/routing/joinroom_test.go b/clientapi/routing/joinroom_test.go index 1450ef4bd..fd58ff5d5 100644 --- a/clientapi/routing/joinroom_test.go +++ b/clientapi/routing/joinroom_test.go @@ -7,6 +7,9 @@ import ( "testing" "time" + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/appservice" @@ -24,12 +27,15 @@ func TestJoinRoomByIDOrAlias(t *testing.T) { ctx := context.Background() test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - base, baseClose := testrig.CreateBaseDendrite(t, dbType) - defer baseClose() + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + defer close() - rsAPI := roomserver.NewInternalAPI(base) - userAPI := userapi.NewInternalAPI(base, rsAPI, nil) - asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI) + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + natsInstance := jetstream.NATSInstance{} + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + asAPI := appservice.NewInternalAPI(processCtx, cfg, &natsInstance, userAPI, rsAPI) rsAPI.SetFederationAPI(nil, nil) // creates the rs.Inputer etc // Create the users in the userapi @@ -61,7 +67,7 @@ func TestJoinRoomByIDOrAlias(t *testing.T) { RoomAliasName: "alias", Invite: []string{bob.ID}, GuestCanJoin: false, - }, aliceDev, &base.Cfg.ClientAPI, userAPI, rsAPI, asAPI, time.Now()) + }, aliceDev, &cfg.ClientAPI, userAPI, rsAPI, asAPI, time.Now()) crResp, ok := resp.JSON.(createRoomResponse) if !ok { t.Fatalf("response is not a createRoomResponse: %+v", resp) @@ -76,7 +82,7 @@ func TestJoinRoomByIDOrAlias(t *testing.T) { Preset: presetPublicChat, Invite: []string{charlie.ID}, GuestCanJoin: true, - }, aliceDev, &base.Cfg.ClientAPI, userAPI, rsAPI, asAPI, time.Now()) + }, aliceDev, &cfg.ClientAPI, userAPI, rsAPI, asAPI, time.Now()) crRespWithGuestAccess, ok := resp.JSON.(createRoomResponse) if !ok { t.Fatalf("response is not a createRoomResponse: %+v", resp) diff --git a/clientapi/routing/login_test.go b/clientapi/routing/login_test.go index b72db9d8b..bff676826 100644 --- a/clientapi/routing/login_test.go +++ b/clientapi/routing/login_test.go @@ -7,11 +7,17 @@ import ( "net/http/httptest" "strings" "testing" + "time" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/util" "github.com/matrix-org/dendrite/test" @@ -28,20 +34,24 @@ func TestLogin(t *testing.T) { ctx := context.Background() test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - base, baseClose := testrig.CreateBaseDendrite(t, dbType) - defer baseClose() - base.Cfg.ClientAPI.RateLimiting.Enabled = false + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + defer close() + cfg.ClientAPI.RateLimiting.Enabled = false + natsInstance := jetstream.NATSInstance{} // add a vhost - base.Cfg.Global.VirtualHosts = append(base.Cfg.Global.VirtualHosts, &config.VirtualHost{ - SigningIdentity: gomatrixserverlib.SigningIdentity{ServerName: "vh1"}, + cfg.Global.VirtualHosts = append(cfg.Global.VirtualHosts, &config.VirtualHost{ + SigningIdentity: fclient.SigningIdentity{ServerName: "vh1"}, }) - rsAPI := roomserver.NewInternalAPI(base) + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + routers := httputil.NewRouters() + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) // Needed for /login - userAPI := userapi.NewInternalAPI(base, rsAPI, nil) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) // We mostly need the userAPI for this test, so nil for other APIs/caches etc. - Setup(base, &base.Cfg.ClientAPI, nil, nil, userAPI, nil, nil, nil, nil, nil, nil, &base.Cfg.MSCs, nil) + Setup(routers, cfg, nil, nil, userAPI, nil, nil, nil, nil, nil, nil, nil, caching.DisableMetrics) // Create password password := util.RandomString(8) @@ -114,7 +124,7 @@ func TestLogin(t *testing.T) { "password": password, })) rec := httptest.NewRecorder() - base.PublicClientAPIMux.ServeHTTP(rec, req) + routers.Client.ServeHTTP(rec, req) if tc.wantOK && rec.Code != http.StatusOK { t.Fatalf("failed to login: %s", rec.Body.String()) } diff --git a/clientapi/routing/membership.go b/clientapi/routing/membership.go index 482c1f5f7..13de4d008 100644 --- a/clientapi/routing/membership.go +++ b/clientapi/routing/membership.go @@ -16,11 +16,13 @@ package routing import ( "context" - "errors" "fmt" "net/http" "time" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" + appserviceAPI "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/httputil" @@ -31,76 +33,56 @@ import ( roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) -var errMissingUserID = errors.New("'user_id' must be supplied") - func SendBan( req *http.Request, profileAPI userapi.ClientUserAPI, device *userapi.Device, roomID string, cfg *config.ClientAPI, rsAPI roomserverAPI.ClientRoomserverAPI, asAPI appserviceAPI.AppServiceInternalAPI, ) util.JSONResponse { - body, evTime, roomVer, reqErr := extractRequestData(req, roomID, rsAPI) + body, evTime, reqErr := extractRequestData(req) if reqErr != nil { return *reqErr } + if body.UserID == "" { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON("missing user_id"), + } + } + errRes := checkMemberInRoom(req.Context(), rsAPI, device.UserID, roomID) if errRes != nil { return *errRes } - plEvent := roomserverAPI.GetStateEvent(req.Context(), rsAPI, roomID, gomatrixserverlib.StateKeyTuple{ - EventType: gomatrixserverlib.MRoomPowerLevels, - StateKey: "", - }) - if plEvent == nil { - return util.JSONResponse{ - Code: 403, - JSON: jsonerror.Forbidden("You don't have permission to ban this user, no power_levels event in this room."), - } - } - pl, err := plEvent.PowerLevels() - if err != nil { - return util.JSONResponse{ - Code: 403, - JSON: jsonerror.Forbidden("You don't have permission to ban this user, the power_levels event for this room is malformed so auth checks cannot be performed."), - } + pl, errRes := getPowerlevels(req, rsAPI, roomID) + if errRes != nil { + return *errRes } allowedToBan := pl.UserLevel(device.UserID) >= pl.Ban if !allowedToBan { return util.JSONResponse{ - Code: 403, + Code: http.StatusForbidden, JSON: jsonerror.Forbidden("You don't have permission to ban this user, power level too low."), } } - return sendMembership(req.Context(), profileAPI, device, roomID, "ban", body.Reason, cfg, body.UserID, evTime, roomVer, rsAPI, asAPI) + return sendMembership(req.Context(), profileAPI, device, roomID, spec.Ban, body.Reason, cfg, body.UserID, evTime, rsAPI, asAPI) } func sendMembership(ctx context.Context, profileAPI userapi.ClientUserAPI, device *userapi.Device, roomID, membership, reason string, cfg *config.ClientAPI, targetUserID string, evTime time.Time, - roomVer gomatrixserverlib.RoomVersion, rsAPI roomserverAPI.ClientRoomserverAPI, asAPI appserviceAPI.AppServiceInternalAPI) util.JSONResponse { event, err := buildMembershipEvent( ctx, targetUserID, reason, profileAPI, device, membership, roomID, false, cfg, evTime, rsAPI, asAPI, ) - if err == errMissingUserID { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(err.Error()), - } - } else if err == eventutil.ErrRoomNoExists { - return util.JSONResponse{ - Code: http.StatusNotFound, - JSON: jsonerror.NotFound(err.Error()), - } - } else if err != nil { + if err != nil { util.GetLogger(ctx).WithError(err).Error("buildMembershipEvent failed") return jsonerror.InternalServerError() } @@ -109,7 +91,7 @@ func sendMembership(ctx context.Context, profileAPI userapi.ClientUserAPI, devic if err = roomserverAPI.SendEvents( ctx, rsAPI, roomserverAPI.KindNew, - []*gomatrixserverlib.HeaderedEvent{event.Event.Headered(roomVer)}, + []*gomatrixserverlib.HeaderedEvent{event}, device.UserDomain(), serverName, serverName, @@ -131,13 +113,65 @@ func SendKick( roomID string, cfg *config.ClientAPI, rsAPI roomserverAPI.ClientRoomserverAPI, asAPI appserviceAPI.AppServiceInternalAPI, ) util.JSONResponse { - body, evTime, roomVer, reqErr := extractRequestData(req, roomID, rsAPI) + body, evTime, reqErr := extractRequestData(req) if reqErr != nil { return *reqErr } if body.UserID == "" { return util.JSONResponse{ - Code: 400, + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON("missing user_id"), + } + } + + errRes := checkMemberInRoom(req.Context(), rsAPI, device.UserID, roomID) + if errRes != nil { + return *errRes + } + + pl, errRes := getPowerlevels(req, rsAPI, roomID) + if errRes != nil { + return *errRes + } + allowedToKick := pl.UserLevel(device.UserID) >= pl.Kick + if !allowedToKick { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("You don't have permission to kick this user, power level too low."), + } + } + + var queryRes roomserverAPI.QueryMembershipForUserResponse + err := rsAPI.QueryMembershipForUser(req.Context(), &roomserverAPI.QueryMembershipForUserRequest{ + RoomID: roomID, + UserID: body.UserID, + }, &queryRes) + if err != nil { + return util.ErrorResponse(err) + } + // kick is only valid if the user is not currently banned or left (that is, they are joined or invited) + if queryRes.Membership != spec.Join && queryRes.Membership != spec.Invite { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Unknown("cannot /kick banned or left users"), + } + } + // TODO: should we be using SendLeave instead? + return sendMembership(req.Context(), profileAPI, device, roomID, spec.Leave, body.Reason, cfg, body.UserID, evTime, rsAPI, asAPI) +} + +func SendUnban( + req *http.Request, profileAPI userapi.ClientUserAPI, device *userapi.Device, + roomID string, cfg *config.ClientAPI, + rsAPI roomserverAPI.ClientRoomserverAPI, asAPI appserviceAPI.AppServiceInternalAPI, +) util.JSONResponse { + body, evTime, reqErr := extractRequestData(req) + if reqErr != nil { + return *reqErr + } + if body.UserID == "" { + return util.JSONResponse{ + Code: http.StatusBadRequest, JSON: jsonerror.BadJSON("missing user_id"), } } @@ -155,56 +189,16 @@ func SendKick( if err != nil { return util.ErrorResponse(err) } - // kick is only valid if the user is not currently banned or left (that is, they are joined or invited) - if queryRes.Membership != "join" && queryRes.Membership != "invite" { - return util.JSONResponse{ - Code: 403, - JSON: jsonerror.Unknown("cannot /kick banned or left users"), - } - } - // TODO: should we be using SendLeave instead? - return sendMembership(req.Context(), profileAPI, device, roomID, "leave", body.Reason, cfg, body.UserID, evTime, roomVer, rsAPI, asAPI) -} -func SendUnban( - req *http.Request, profileAPI userapi.ClientUserAPI, device *userapi.Device, - roomID string, cfg *config.ClientAPI, - rsAPI roomserverAPI.ClientRoomserverAPI, asAPI appserviceAPI.AppServiceInternalAPI, -) util.JSONResponse { - body, evTime, roomVer, reqErr := extractRequestData(req, roomID, rsAPI) - if reqErr != nil { - return *reqErr - } - if body.UserID == "" { - return util.JSONResponse{ - Code: 400, - JSON: jsonerror.BadJSON("missing user_id"), - } - } - - var queryRes roomserverAPI.QueryMembershipForUserResponse - err := rsAPI.QueryMembershipForUser(req.Context(), &roomserverAPI.QueryMembershipForUserRequest{ - RoomID: roomID, - UserID: body.UserID, - }, &queryRes) - if err != nil { - return util.ErrorResponse(err) - } - if !queryRes.RoomExists { - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("room does not exist"), - } - } // unban is only valid if the user is currently banned - if queryRes.Membership != "ban" { + if queryRes.Membership != spec.Ban { return util.JSONResponse{ - Code: 400, + Code: http.StatusBadRequest, JSON: jsonerror.Unknown("can only /unban users that are banned"), } } // TODO: should we be using SendLeave instead? - return sendMembership(req.Context(), profileAPI, device, roomID, "leave", body.Reason, cfg, body.UserID, evTime, roomVer, rsAPI, asAPI) + return sendMembership(req.Context(), profileAPI, device, roomID, spec.Leave, body.Reason, cfg, body.UserID, evTime, rsAPI, asAPI) } func SendInvite( @@ -212,7 +206,7 @@ func SendInvite( roomID string, cfg *config.ClientAPI, rsAPI roomserverAPI.ClientRoomserverAPI, asAPI appserviceAPI.AppServiceInternalAPI, ) util.JSONResponse { - body, evTime, _, reqErr := extractRequestData(req, roomID, rsAPI) + body, evTime, reqErr := extractRequestData(req) if reqErr != nil { return *reqErr } @@ -234,6 +228,18 @@ func SendInvite( } } + if body.UserID == "" { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON("missing user_id"), + } + } + + errRes := checkMemberInRoom(req.Context(), rsAPI, device.UserID, roomID) + if errRes != nil { + return *errRes + } + // We already received the return value, so no need to check for an error here. response, _ := sendInvite(req.Context(), profileAPI, device, roomID, body.UserID, body.Reason, cfg, rsAPI, asAPI, evTime) return response @@ -250,20 +256,10 @@ func sendInvite( asAPI appserviceAPI.AppServiceInternalAPI, evTime time.Time, ) (util.JSONResponse, error) { event, err := buildMembershipEvent( - ctx, userID, reason, profileAPI, device, "invite", + ctx, userID, reason, profileAPI, device, spec.Invite, roomID, false, cfg, evTime, rsAPI, asAPI, ) - if err == errMissingUserID { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(err.Error()), - }, err - } else if err == eventutil.ErrRoomNoExists { - return util.JSONResponse{ - Code: http.StatusNotFound, - JSON: jsonerror.NotFound(err.Error()), - }, err - } else if err != nil { + if err != nil { util.GetLogger(ctx).WithError(err).Error("buildMembershipEvent failed") return jsonerror.InternalServerError(), err } @@ -357,19 +353,7 @@ func loadProfile( return profile, err } -func extractRequestData(req *http.Request, roomID string, rsAPI roomserverAPI.ClientRoomserverAPI) ( - body *threepid.MembershipRequest, evTime time.Time, roomVer gomatrixserverlib.RoomVersion, resErr *util.JSONResponse, -) { - verReq := roomserverAPI.QueryRoomVersionForRoomRequest{RoomID: roomID} - verRes := roomserverAPI.QueryRoomVersionForRoomResponse{} - if err := rsAPI.QueryRoomVersionForRoom(req.Context(), &verReq, &verRes); err != nil { - resErr = &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.UnsupportedRoomVersion(err.Error()), - } - return - } - roomVer = verRes.RoomVersion +func extractRequestData(req *http.Request) (body *threepid.MembershipRequest, evTime time.Time, resErr *util.JSONResponse) { if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil { resErr = reqErr @@ -432,34 +416,17 @@ func checkAndProcessThreepid( } func checkMemberInRoom(ctx context.Context, rsAPI roomserverAPI.ClientRoomserverAPI, userID, roomID string) *util.JSONResponse { - tuple := gomatrixserverlib.StateKeyTuple{ - EventType: gomatrixserverlib.MRoomMember, - StateKey: userID, - } - var membershipRes roomserverAPI.QueryCurrentStateResponse - err := rsAPI.QueryCurrentState(ctx, &roomserverAPI.QueryCurrentStateRequest{ - RoomID: roomID, - StateTuples: []gomatrixserverlib.StateKeyTuple{tuple}, + var membershipRes roomserverAPI.QueryMembershipForUserResponse + err := rsAPI.QueryMembershipForUser(ctx, &roomserverAPI.QueryMembershipForUserRequest{ + RoomID: roomID, + UserID: userID, }, &membershipRes) if err != nil { - util.GetLogger(ctx).WithError(err).Error("QueryCurrentState: could not query membership for user") + util.GetLogger(ctx).WithError(err).Error("QueryMembershipForUser: could not query membership for user") e := jsonerror.InternalServerError() return &e } - ev := membershipRes.StateEvents[tuple] - if ev == nil { - return &util.JSONResponse{ - Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("user does not belong to room"), - } - } - membership, err := ev.Membership() - if err != nil { - util.GetLogger(ctx).WithError(err).Error("Member event isn't valid") - e := jsonerror.InternalServerError() - return &e - } - if membership != gomatrixserverlib.Join { + if !membershipRes.IsInRoom { return &util.JSONResponse{ Code: http.StatusForbidden, JSON: jsonerror.Forbidden("user does not belong to room"), @@ -511,3 +478,24 @@ func SendForget( JSON: struct{}{}, } } + +func getPowerlevels(req *http.Request, rsAPI roomserverAPI.ClientRoomserverAPI, roomID string) (*gomatrixserverlib.PowerLevelContent, *util.JSONResponse) { + plEvent := roomserverAPI.GetStateEvent(req.Context(), rsAPI, roomID, gomatrixserverlib.StateKeyTuple{ + EventType: spec.MRoomPowerLevels, + StateKey: "", + }) + if plEvent == nil { + return nil, &util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("You don't have permission to perform this action, no power_levels event in this room."), + } + } + pl, err := plEvent.PowerLevels() + if err != nil { + return nil, &util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("You don't have permission to perform this action, the power_levels event for this room is malformed so auth checks cannot be performed."), + } + } + return pl, nil +} diff --git a/clientapi/routing/peekroom.go b/clientapi/routing/peekroom.go index 9b2592eb5..9a8f4378a 100644 --- a/clientapi/routing/peekroom.go +++ b/clientapi/routing/peekroom.go @@ -20,7 +20,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -49,7 +49,7 @@ func PeekRoomByIDOrAlias( for _, serverName := range serverNames { peekReq.ServerNames = append( peekReq.ServerNames, - gomatrixserverlib.ServerName(serverName), + spec.ServerName(serverName), ) } } diff --git a/clientapi/routing/presence.go b/clientapi/routing/presence.go index 093a62464..c50b09434 100644 --- a/clientapi/routing/presence.go +++ b/clientapi/routing/presence.go @@ -27,7 +27,7 @@ import ( "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/nats-io/nats.go" log "github.com/sirupsen/logrus" @@ -123,7 +123,7 @@ func GetPresence( } } - p := types.PresenceInternal{LastActiveTS: gomatrixserverlib.Timestamp(lastActive)} + p := types.PresenceInternal{LastActiveTS: spec.Timestamp(lastActive)} currentlyActive := p.CurrentlyActive() return util.JSONResponse{ Code: http.StatusOK, diff --git a/clientapi/routing/profile.go b/clientapi/routing/profile.go index 92a75fc78..6d1a08244 100644 --- a/clientapi/routing/profile.go +++ b/clientapi/routing/profile.go @@ -20,6 +20,8 @@ import ( "time" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" appserviceAPI "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" @@ -36,14 +38,14 @@ import ( // GetProfile implements GET /profile/{userID} func GetProfile( - req *http.Request, profileAPI userapi.ClientUserAPI, cfg *config.ClientAPI, + req *http.Request, profileAPI userapi.ProfileAPI, cfg *config.ClientAPI, userID string, asAPI appserviceAPI.AppServiceInternalAPI, - federation *gomatrixserverlib.FederationClient, + federation *fclient.FederationClient, ) util.JSONResponse { profile, err := getProfile(req.Context(), profileAPI, cfg, userID, asAPI, federation) if err != nil { - if err == eventutil.ErrProfileNoExists { + if err == appserviceAPI.ErrProfileNotExists { return util.JSONResponse{ Code: http.StatusNotFound, JSON: jsonerror.NotFound("The user does not exist or does not have a profile"), @@ -56,7 +58,7 @@ func GetProfile( return util.JSONResponse{ Code: http.StatusOK, - JSON: eventutil.ProfileResponse{ + JSON: eventutil.UserProfile{ AvatarURL: profile.AvatarURL, DisplayName: profile.DisplayName, }, @@ -65,34 +67,28 @@ func GetProfile( // GetAvatarURL implements GET /profile/{userID}/avatar_url func GetAvatarURL( - req *http.Request, profileAPI userapi.ClientUserAPI, cfg *config.ClientAPI, + req *http.Request, profileAPI userapi.ProfileAPI, cfg *config.ClientAPI, userID string, asAPI appserviceAPI.AppServiceInternalAPI, - federation *gomatrixserverlib.FederationClient, + federation *fclient.FederationClient, ) util.JSONResponse { - profile, err := getProfile(req.Context(), profileAPI, cfg, userID, asAPI, federation) - if err != nil { - if err == eventutil.ErrProfileNoExists { - return util.JSONResponse{ - Code: http.StatusNotFound, - JSON: jsonerror.NotFound("The user does not exist or does not have a profile"), - } - } - - util.GetLogger(req.Context()).WithError(err).Error("getProfile failed") - return jsonerror.InternalServerError() + profile := GetProfile(req, profileAPI, cfg, userID, asAPI, federation) + p, ok := profile.JSON.(eventutil.UserProfile) + // not a profile response, so most likely an error, return that + if !ok { + return profile } return util.JSONResponse{ Code: http.StatusOK, - JSON: eventutil.AvatarURL{ - AvatarURL: profile.AvatarURL, + JSON: eventutil.UserProfile{ + AvatarURL: p.AvatarURL, }, } } // SetAvatarURL implements PUT /profile/{userID}/avatar_url func SetAvatarURL( - req *http.Request, profileAPI userapi.ClientUserAPI, + req *http.Request, profileAPI userapi.ProfileAPI, device *userapi.Device, userID string, cfg *config.ClientAPI, rsAPI api.ClientRoomserverAPI, ) util.JSONResponse { if userID != device.UserID { @@ -102,7 +98,7 @@ func SetAvatarURL( } } - var r eventutil.AvatarURL + var r eventutil.UserProfile if resErr := httputil.UnmarshalJSONRequest(req, &r); resErr != nil { return *resErr } @@ -134,24 +130,20 @@ func SetAvatarURL( } } - setRes := &userapi.PerformSetAvatarURLResponse{} - if err = profileAPI.SetAvatarURL(req.Context(), &userapi.PerformSetAvatarURLRequest{ - Localpart: localpart, - ServerName: domain, - AvatarURL: r.AvatarURL, - }, setRes); err != nil { + profile, changed, err := profileAPI.SetAvatarURL(req.Context(), localpart, domain, r.AvatarURL) + if err != nil { util.GetLogger(req.Context()).WithError(err).Error("profileAPI.SetAvatarURL failed") return jsonerror.InternalServerError() } // No need to build new membership events, since nothing changed - if !setRes.Changed { + if !changed { return util.JSONResponse{ Code: http.StatusOK, JSON: struct{}{}, } } - response, err := updateProfile(req.Context(), rsAPI, device, setRes.Profile, userID, cfg, evTime) + response, err := updateProfile(req.Context(), rsAPI, device, profile, userID, cfg, evTime) if err != nil { return response } @@ -164,34 +156,28 @@ func SetAvatarURL( // GetDisplayName implements GET /profile/{userID}/displayname func GetDisplayName( - req *http.Request, profileAPI userapi.ClientUserAPI, cfg *config.ClientAPI, + req *http.Request, profileAPI userapi.ProfileAPI, cfg *config.ClientAPI, userID string, asAPI appserviceAPI.AppServiceInternalAPI, - federation *gomatrixserverlib.FederationClient, + federation *fclient.FederationClient, ) util.JSONResponse { - profile, err := getProfile(req.Context(), profileAPI, cfg, userID, asAPI, federation) - if err != nil { - if err == eventutil.ErrProfileNoExists { - return util.JSONResponse{ - Code: http.StatusNotFound, - JSON: jsonerror.NotFound("The user does not exist or does not have a profile"), - } - } - - util.GetLogger(req.Context()).WithError(err).Error("getProfile failed") - return jsonerror.InternalServerError() + profile := GetProfile(req, profileAPI, cfg, userID, asAPI, federation) + p, ok := profile.JSON.(eventutil.UserProfile) + // not a profile response, so most likely an error, return that + if !ok { + return profile } return util.JSONResponse{ Code: http.StatusOK, - JSON: eventutil.DisplayName{ - DisplayName: profile.DisplayName, + JSON: eventutil.UserProfile{ + DisplayName: p.DisplayName, }, } } // SetDisplayName implements PUT /profile/{userID}/displayname func SetDisplayName( - req *http.Request, profileAPI userapi.ClientUserAPI, + req *http.Request, profileAPI userapi.ProfileAPI, device *userapi.Device, userID string, cfg *config.ClientAPI, rsAPI api.ClientRoomserverAPI, ) util.JSONResponse { if userID != device.UserID { @@ -201,7 +187,7 @@ func SetDisplayName( } } - var r eventutil.DisplayName + var r eventutil.UserProfile if resErr := httputil.UnmarshalJSONRequest(req, &r); resErr != nil { return *resErr } @@ -233,25 +219,20 @@ func SetDisplayName( } } - profileRes := &userapi.PerformUpdateDisplayNameResponse{} - err = profileAPI.SetDisplayName(req.Context(), &userapi.PerformUpdateDisplayNameRequest{ - Localpart: localpart, - ServerName: domain, - DisplayName: r.DisplayName, - }, profileRes) + profile, changed, err := profileAPI.SetDisplayName(req.Context(), localpart, domain, r.DisplayName) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("profileAPI.SetDisplayName failed") return jsonerror.InternalServerError() } // No need to build new membership events, since nothing changed - if !profileRes.Changed { + if !changed { return util.JSONResponse{ Code: http.StatusOK, JSON: struct{}{}, } } - response, err := updateProfile(req.Context(), rsAPI, device, profileRes.Profile, userID, cfg, evTime) + response, err := updateProfile(req.Context(), rsAPI, device, profile, userID, cfg, evTime) if err != nil { return response } @@ -308,12 +289,12 @@ func updateProfile( // getProfile gets the full profile of a user by querying the database or a // remote homeserver. // Returns an error when something goes wrong or specifically -// eventutil.ErrProfileNoExists when the profile doesn't exist. +// eventutil.ErrProfileNotExists when the profile doesn't exist. func getProfile( - ctx context.Context, profileAPI userapi.ClientUserAPI, cfg *config.ClientAPI, + ctx context.Context, profileAPI userapi.ProfileAPI, cfg *config.ClientAPI, userID string, asAPI appserviceAPI.AppServiceInternalAPI, - federation *gomatrixserverlib.FederationClient, + federation *fclient.FederationClient, ) (*authtypes.Profile, error) { localpart, domain, err := gomatrixserverlib.SplitID('@', userID) if err != nil { @@ -325,7 +306,7 @@ func getProfile( if fedErr != nil { if x, ok := fedErr.(gomatrix.HTTPError); ok { if x.Code == http.StatusNotFound { - return nil, eventutil.ErrProfileNoExists + return nil, appserviceAPI.ErrProfileNotExists } } @@ -371,7 +352,7 @@ func buildMembershipEvents( } content := gomatrixserverlib.MemberContent{ - Membership: gomatrixserverlib.Join, + Membership: spec.Join, } content.DisplayName = newProfile.DisplayName diff --git a/clientapi/routing/pushrules.go b/clientapi/routing/pushrules.go index 856f52c75..f1a539adf 100644 --- a/clientapi/routing/pushrules.go +++ b/clientapi/routing/pushrules.go @@ -31,7 +31,7 @@ func errorResponse(ctx context.Context, err error, msg string, args ...interface } func GetAllPushRules(ctx context.Context, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse { - ruleSets, err := queryPushRules(ctx, device.UserID, userAPI) + ruleSets, err := userAPI.QueryPushRules(ctx, device.UserID) if err != nil { return errorResponse(ctx, err, "queryPushRulesJSON failed") } @@ -42,7 +42,7 @@ func GetAllPushRules(ctx context.Context, device *userapi.Device, userAPI userap } func GetPushRulesByScope(ctx context.Context, scope string, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse { - ruleSets, err := queryPushRules(ctx, device.UserID, userAPI) + ruleSets, err := userAPI.QueryPushRules(ctx, device.UserID) if err != nil { return errorResponse(ctx, err, "queryPushRulesJSON failed") } @@ -57,7 +57,7 @@ func GetPushRulesByScope(ctx context.Context, scope string, device *userapi.Devi } func GetPushRulesByKind(ctx context.Context, scope, kind string, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse { - ruleSets, err := queryPushRules(ctx, device.UserID, userAPI) + ruleSets, err := userAPI.QueryPushRules(ctx, device.UserID) if err != nil { return errorResponse(ctx, err, "queryPushRules failed") } @@ -66,7 +66,8 @@ func GetPushRulesByKind(ctx context.Context, scope, kind string, device *userapi return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed") } rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind)) - if rulesPtr == nil { + // Even if rulesPtr is not nil, there may not be any rules for this kind + if rulesPtr == nil || (rulesPtr != nil && len(*rulesPtr) == 0) { return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed") } return util.JSONResponse{ @@ -76,7 +77,7 @@ func GetPushRulesByKind(ctx context.Context, scope, kind string, device *userapi } func GetPushRuleByRuleID(ctx context.Context, scope, kind, ruleID string, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse { - ruleSets, err := queryPushRules(ctx, device.UserID, userAPI) + ruleSets, err := userAPI.QueryPushRules(ctx, device.UserID) if err != nil { return errorResponse(ctx, err, "queryPushRules failed") } @@ -101,7 +102,10 @@ func GetPushRuleByRuleID(ctx context.Context, scope, kind, ruleID string, device func PutPushRuleByRuleID(ctx context.Context, scope, kind, ruleID, afterRuleID, beforeRuleID string, body io.Reader, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse { var newRule pushrules.Rule if err := json.NewDecoder(body).Decode(&newRule); err != nil { - return errorResponse(ctx, err, "JSON Decode failed") + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON(err.Error()), + } } newRule.RuleID = ruleID @@ -110,7 +114,7 @@ func PutPushRuleByRuleID(ctx context.Context, scope, kind, ruleID, afterRuleID, return errorResponse(ctx, jsonerror.InvalidArgumentValue(errs[0].Error()), "rule sanity check failed: %v", errs) } - ruleSets, err := queryPushRules(ctx, device.UserID, userAPI) + ruleSets, err := userAPI.QueryPushRules(ctx, device.UserID) if err != nil { return errorResponse(ctx, err, "queryPushRules failed") } @@ -120,6 +124,7 @@ func PutPushRuleByRuleID(ctx context.Context, scope, kind, ruleID, afterRuleID, } rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind)) if rulesPtr == nil { + // while this should be impossible (ValidateRule would already return an error), better keep it around return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed") } i := pushRuleIndexByID(*rulesPtr, ruleID) @@ -144,7 +149,7 @@ func PutPushRuleByRuleID(ctx context.Context, scope, kind, ruleID, afterRuleID, } // Add new rule. - i, err := findPushRuleInsertionIndex(*rulesPtr, afterRuleID, beforeRuleID) + i, err = findPushRuleInsertionIndex(*rulesPtr, afterRuleID, beforeRuleID) if err != nil { return errorResponse(ctx, err, "findPushRuleInsertionIndex failed") } @@ -153,7 +158,7 @@ func PutPushRuleByRuleID(ctx context.Context, scope, kind, ruleID, afterRuleID, util.GetLogger(ctx).WithField("after", afterRuleID).WithField("before", beforeRuleID).Infof("Added new push rule at %d", i) } - if err := putPushRules(ctx, device.UserID, ruleSets, userAPI); err != nil { + if err = userAPI.PerformPushRulesPut(ctx, device.UserID, ruleSets); err != nil { return errorResponse(ctx, err, "putPushRules failed") } @@ -161,7 +166,7 @@ func PutPushRuleByRuleID(ctx context.Context, scope, kind, ruleID, afterRuleID, } func DeletePushRuleByRuleID(ctx context.Context, scope, kind, ruleID string, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse { - ruleSets, err := queryPushRules(ctx, device.UserID, userAPI) + ruleSets, err := userAPI.QueryPushRules(ctx, device.UserID) if err != nil { return errorResponse(ctx, err, "queryPushRules failed") } @@ -180,7 +185,7 @@ func DeletePushRuleByRuleID(ctx context.Context, scope, kind, ruleID string, dev *rulesPtr = append((*rulesPtr)[:i], (*rulesPtr)[i+1:]...) - if err := putPushRules(ctx, device.UserID, ruleSets, userAPI); err != nil { + if err = userAPI.PerformPushRulesPut(ctx, device.UserID, ruleSets); err != nil { return errorResponse(ctx, err, "putPushRules failed") } @@ -192,7 +197,7 @@ func GetPushRuleAttrByRuleID(ctx context.Context, scope, kind, ruleID, attr stri if err != nil { return errorResponse(ctx, err, "pushRuleAttrGetter failed") } - ruleSets, err := queryPushRules(ctx, device.UserID, userAPI) + ruleSets, err := userAPI.QueryPushRules(ctx, device.UserID) if err != nil { return errorResponse(ctx, err, "queryPushRules failed") } @@ -238,7 +243,7 @@ func PutPushRuleAttrByRuleID(ctx context.Context, scope, kind, ruleID, attr stri return errorResponse(ctx, err, "pushRuleAttrSetter failed") } - ruleSets, err := queryPushRules(ctx, device.UserID, userAPI) + ruleSets, err := userAPI.QueryPushRules(ctx, device.UserID) if err != nil { return errorResponse(ctx, err, "queryPushRules failed") } @@ -258,7 +263,7 @@ func PutPushRuleAttrByRuleID(ctx context.Context, scope, kind, ruleID, attr stri if !reflect.DeepEqual(attrGet((*rulesPtr)[i]), attrGet(&newPartialRule)) { attrSet((*rulesPtr)[i], &newPartialRule) - if err := putPushRules(ctx, device.UserID, ruleSets, userAPI); err != nil { + if err = userAPI.PerformPushRulesPut(ctx, device.UserID, ruleSets); err != nil { return errorResponse(ctx, err, "putPushRules failed") } } @@ -266,28 +271,6 @@ func PutPushRuleAttrByRuleID(ctx context.Context, scope, kind, ruleID, attr stri return util.JSONResponse{Code: http.StatusOK, JSON: struct{}{}} } -func queryPushRules(ctx context.Context, userID string, userAPI userapi.ClientUserAPI) (*pushrules.AccountRuleSets, error) { - var res userapi.QueryPushRulesResponse - if err := userAPI.QueryPushRules(ctx, &userapi.QueryPushRulesRequest{UserID: userID}, &res); err != nil { - util.GetLogger(ctx).WithError(err).Error("userAPI.QueryPushRules failed") - return nil, err - } - return res.RuleSets, nil -} - -func putPushRules(ctx context.Context, userID string, ruleSets *pushrules.AccountRuleSets, userAPI userapi.ClientUserAPI) error { - req := userapi.PerformPushRulesPutRequest{ - UserID: userID, - RuleSets: ruleSets, - } - var res struct{} - if err := userAPI.PerformPushRulesPut(ctx, &req, &res); err != nil { - util.GetLogger(ctx).WithError(err).Error("userAPI.PerformPushRulesPut failed") - return err - } - return nil -} - func pushRuleSetByScope(ruleSets *pushrules.AccountRuleSets, scope pushrules.Scope) *pushrules.RuleSet { switch scope { case pushrules.GlobalScope: diff --git a/clientapi/routing/receipt.go b/clientapi/routing/receipt.go index 99217a780..634b60b71 100644 --- a/clientapi/routing/receipt.go +++ b/clientapi/routing/receipt.go @@ -22,7 +22,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/producers" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api" @@ -31,7 +31,7 @@ import ( ) func SetReceipt(req *http.Request, userAPI api.ClientUserAPI, syncProducer *producers.SyncAPIProducer, device *userapi.Device, roomID, receiptType, eventID string) util.JSONResponse { - timestamp := gomatrixserverlib.AsTimestamp(time.Now()) + timestamp := spec.AsTimestamp(time.Now()) logrus.WithFields(logrus.Fields{ "roomID": roomID, "receiptType": receiptType, diff --git a/clientapi/routing/redaction.go b/clientapi/routing/redaction.go index f86bbc8fd..2bdf6705e 100644 --- a/clientapi/routing/redaction.go +++ b/clientapi/routing/redaction.go @@ -20,6 +20,7 @@ import ( "time" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/matrix-org/dendrite/clientapi/httputil" @@ -77,7 +78,7 @@ func SendRedaction( allowedToRedact := ev.Sender() == device.UserID if !allowedToRedact { plEvent := roomserverAPI.GetStateEvent(req.Context(), rsAPI, roomID, gomatrixserverlib.StateKeyTuple{ - EventType: gomatrixserverlib.MRoomPowerLevels, + EventType: spec.MRoomPowerLevels, StateKey: "", }) if plEvent == nil { @@ -114,7 +115,7 @@ func SendRedaction( builder := gomatrixserverlib.EventBuilder{ Sender: device.UserID, RoomID: roomID, - Type: gomatrixserverlib.MRoomRedaction, + Type: spec.MRoomRedaction, Redacts: eventID, } err := builder.SetContent(r) diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index ff6a0900e..8e7f77586 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -37,6 +37,7 @@ import ( "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/tokens" "github.com/matrix-org/util" "github.com/prometheus/client_golang/prometheus" @@ -206,10 +207,10 @@ var ( // previous parameters with the ones supplied. This mean you cannot "build up" request params. type registerRequest struct { // registration parameters - Password string `json:"password"` - Username string `json:"username"` - ServerName gomatrixserverlib.ServerName `json:"-"` - Admin bool `json:"admin"` + Password string `json:"password"` + Username string `json:"username"` + ServerName spec.ServerName `json:"-"` + Admin bool `json:"admin"` // user-interactive auth params Auth authDict `json:"auth"` @@ -478,7 +479,7 @@ func Register( } var r registerRequest - host := gomatrixserverlib.ServerName(req.Host) + host := spec.ServerName(req.Host) if v := cfg.Matrix.VirtualHostForHTTPHost(host); v != nil { r.ServerName = v.ServerName } else { @@ -824,7 +825,7 @@ func checkAndCompleteFlow( func completeRegistration( ctx context.Context, userAPI userapi.ClientUserAPI, - username string, serverName gomatrixserverlib.ServerName, displayName string, + username string, serverName spec.ServerName, displayName string, password, appserviceID, ipAddr, userAgent, sessionID string, inhibitLogin eventutil.WeakBoolean, deviceDisplayName, deviceID *string, @@ -888,13 +889,7 @@ func completeRegistration( } if displayName != "" { - nameReq := userapi.PerformUpdateDisplayNameRequest{ - Localpart: username, - ServerName: serverName, - DisplayName: displayName, - } - var nameRes userapi.PerformUpdateDisplayNameResponse - err = userAPI.SetDisplayName(ctx, &nameReq, &nameRes) + _, _, err = userAPI.SetDisplayName(ctx, username, serverName, displayName) if err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, @@ -1000,7 +995,7 @@ func RegisterAvailable( // Squash username to all lowercase letters username = strings.ToLower(username) domain := cfg.Matrix.ServerName - host := gomatrixserverlib.ServerName(req.Host) + host := spec.ServerName(req.Host) if v := cfg.Matrix.VirtualHostForHTTPHost(host); v != nil { domain = v.ServerName } diff --git a/clientapi/routing/register_test.go b/clientapi/routing/register_test.go index 651e3d3d6..b07f636dd 100644 --- a/clientapi/routing/register_test.go +++ b/clientapi/routing/register_test.go @@ -30,8 +30,11 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/dendrite/userapi" @@ -404,11 +407,15 @@ func Test_register(t *testing.T) { } test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - base, baseClose := testrig.CreateBaseDendrite(t, dbType) - defer baseClose() + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + defer close() - rsAPI := roomserver.NewInternalAPI(base) - userAPI := userapi.NewInternalAPI(base, rsAPI, nil) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + natsInstance := jetstream.NATSInstance{} + + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { @@ -430,16 +437,16 @@ func Test_register(t *testing.T) { } })) defer srv.Close() - base.Cfg.ClientAPI.RecaptchaSiteVerifyAPI = srv.URL + cfg.ClientAPI.RecaptchaSiteVerifyAPI = srv.URL } - if err := base.Cfg.Derive(); err != nil { + if err := cfg.Derive(); err != nil { t.Fatalf("failed to derive config: %s", err) } - base.Cfg.ClientAPI.RecaptchaEnabled = tc.enableRecaptcha - base.Cfg.ClientAPI.RegistrationDisabled = tc.registrationDisabled - base.Cfg.ClientAPI.GuestsDisabled = tc.guestsDisabled + cfg.ClientAPI.RecaptchaEnabled = tc.enableRecaptcha + cfg.ClientAPI.RegistrationDisabled = tc.registrationDisabled + cfg.ClientAPI.GuestsDisabled = tc.guestsDisabled if tc.kind == "" { tc.kind = "user" @@ -467,15 +474,15 @@ func Test_register(t *testing.T) { req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("/?kind=%s", tc.kind), body) - resp := Register(req, userAPI, &base.Cfg.ClientAPI) + resp := Register(req, userAPI, &cfg.ClientAPI) t.Logf("Resp: %+v", resp) // The first request should return a userInteractiveResponse switch r := resp.JSON.(type) { case userInteractiveResponse: // Check that the flows are the ones we configured - if !reflect.DeepEqual(r.Flows, base.Cfg.Derived.Registration.Flows) { - t.Fatalf("unexpected registration flows: %+v, want %+v", r.Flows, base.Cfg.Derived.Registration.Flows) + if !reflect.DeepEqual(r.Flows, cfg.Derived.Registration.Flows) { + t.Fatalf("unexpected registration flows: %+v, want %+v", r.Flows, cfg.Derived.Registration.Flows) } case *jsonerror.MatrixError: if !reflect.DeepEqual(tc.wantResponse, resp) { @@ -531,7 +538,7 @@ func Test_register(t *testing.T) { req = httptest.NewRequest(http.MethodPost, "/", body) - resp = Register(req, userAPI, &base.Cfg.ClientAPI) + resp = Register(req, userAPI, &cfg.ClientAPI) switch resp.JSON.(type) { case *jsonerror.MatrixError: @@ -574,16 +581,19 @@ func Test_register(t *testing.T) { func TestRegisterUserWithDisplayName(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - base, baseClose := testrig.CreateBaseDendrite(t, dbType) - defer baseClose() - base.Cfg.Global.ServerName = "server" + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + defer close() + cfg.Global.ServerName = "server" - rsAPI := roomserver.NewInternalAPI(base) - userAPI := userapi.NewInternalAPI(base, rsAPI, nil) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + natsInstance := jetstream.NATSInstance{} + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) deviceName, deviceID := "deviceName", "deviceID" expectedDisplayName := "DisplayName" response := completeRegistration( - base.Context(), + processCtx.Context(), userAPI, "user", "server", @@ -601,24 +611,25 @@ func TestRegisterUserWithDisplayName(t *testing.T) { assert.Equal(t, http.StatusOK, response.Code) - req := api.QueryProfileRequest{UserID: "@user:server"} - var res api.QueryProfileResponse - err := userAPI.QueryProfile(base.Context(), &req, &res) + profile, err := userAPI.QueryProfile(processCtx.Context(), "@user:server") assert.NoError(t, err) - assert.Equal(t, expectedDisplayName, res.DisplayName) + assert.Equal(t, expectedDisplayName, profile.DisplayName) }) } func TestRegisterAdminUsingSharedSecret(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - base, baseClose := testrig.CreateBaseDendrite(t, dbType) - defer baseClose() - base.Cfg.Global.ServerName = "server" + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + defer close() + natsInstance := jetstream.NATSInstance{} + cfg.Global.ServerName = "server" sharedSecret := "dendritetest" - base.Cfg.ClientAPI.RegistrationSharedSecret = sharedSecret + cfg.ClientAPI.RegistrationSharedSecret = sharedSecret - rsAPI := roomserver.NewInternalAPI(base) - userAPI := userapi.NewInternalAPI(base, rsAPI, nil) + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) expectedDisplayName := "rabbit" jsonStr := []byte(`{"admin":true,"mac":"24dca3bba410e43fe64b9b5c28306693bf3baa9f","nonce":"759f047f312b99ff428b21d581256f8592b8976e58bc1b543972dc6147e529a79657605b52d7becd160ff5137f3de11975684319187e06901955f79e5a6c5a79","password":"wonderland","username":"alice","displayname":"rabbit"}`) @@ -642,17 +653,15 @@ func TestRegisterAdminUsingSharedSecret(t *testing.T) { ssrr := httptest.NewRequest(http.MethodPost, "/", body) response := handleSharedSecretRegistration( - &base.Cfg.ClientAPI, + &cfg.ClientAPI, userAPI, r, ssrr, ) assert.Equal(t, http.StatusOK, response.Code) - profilReq := api.QueryProfileRequest{UserID: "@alice:server"} - var profileRes api.QueryProfileResponse - err = userAPI.QueryProfile(base.Context(), &profilReq, &profileRes) + profile, err := userAPI.QueryProfile(processCtx.Context(), "@alice:server") assert.NoError(t, err) - assert.Equal(t, expectedDisplayName, profileRes.DisplayName) + assert.Equal(t, expectedDisplayName, profile.DisplayName) }) } diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 028d02e97..9bf9ac80c 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -18,12 +18,12 @@ import ( "context" "net/http" "strings" - "sync" "github.com/gorilla/mux" "github.com/matrix-org/dendrite/setup/base" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/nats-io/nats.go" "github.com/prometheus/client_golang/prometheus" @@ -50,25 +50,27 @@ import ( // applied: // nolint: gocyclo func Setup( - base *base.BaseDendrite, - cfg *config.ClientAPI, + routers httputil.Routers, + dendriteCfg *config.Dendrite, rsAPI roomserverAPI.ClientRoomserverAPI, asAPI appserviceAPI.AppServiceInternalAPI, userAPI userapi.ClientUserAPI, userDirectoryProvider userapi.QuerySearchProfilesAPI, - federation *gomatrixserverlib.FederationClient, + federation *fclient.FederationClient, syncProducer *producers.SyncAPIProducer, transactionsCache *transactions.Cache, federationSender federationAPI.ClientFederationAPI, extRoomsProvider api.ExtraPublicRoomsProvider, - mscCfg *config.MSCs, natsClient *nats.Conn, + natsClient *nats.Conn, enableMetrics bool, ) { - publicAPIMux := base.PublicClientAPIMux - wkMux := base.PublicWellKnownAPIMux - synapseAdminRouter := base.SynapseAdminMux - dendriteAdminRouter := base.DendriteAdminMux + cfg := &dendriteCfg.ClientAPI + mscCfg := &dendriteCfg.MSCs + publicAPIMux := routers.Client + wkMux := routers.WellKnown + synapseAdminRouter := routers.SynapseAdmin + dendriteAdminRouter := routers.DendriteAdmin - if base.EnableMetrics { + if enableMetrics { prometheus.MustRegister(amtRegUsers, sendEventDuration) } @@ -154,15 +156,15 @@ func Setup( dendriteAdminRouter.Handle("/admin/evacuateRoom/{roomID}", httputil.MakeAdminAPI("admin_evacuate_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - return AdminEvacuateRoom(req, cfg, device, rsAPI) + return AdminEvacuateRoom(req, rsAPI) }), - ).Methods(http.MethodGet, http.MethodOptions) + ).Methods(http.MethodPost, http.MethodOptions) dendriteAdminRouter.Handle("/admin/evacuateUser/{userID}", httputil.MakeAdminAPI("admin_evacuate_user", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - return AdminEvacuateUser(req, cfg, device, rsAPI) + return AdminEvacuateUser(req, cfg, rsAPI) }), - ).Methods(http.MethodGet, http.MethodOptions) + ).Methods(http.MethodPost, http.MethodOptions) dendriteAdminRouter.Handle("/admin/purgeRoom/{roomID}", httputil.MakeAdminAPI("admin_purge_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { @@ -197,18 +199,13 @@ func Setup( // server notifications if cfg.Matrix.ServerNotices.Enabled { logrus.Info("Enabling server notices at /_synapse/admin/v1/send_server_notice") - var serverNotificationSender *userapi.Device - var err error - notificationSenderOnce := &sync.Once{} + serverNotificationSender, err := getSenderDevice(context.Background(), rsAPI, userAPI, cfg) + if err != nil { + logrus.WithError(err).Fatal("unable to get account for sending sending server notices") + } synapseAdminRouter.Handle("/admin/v1/send_server_notice/{txnID}", httputil.MakeAuthAPI("send_server_notice", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - notificationSenderOnce.Do(func() { - serverNotificationSender, err = getSenderDevice(context.Background(), rsAPI, userAPI, cfg) - if err != nil { - logrus.WithError(err).Fatal("unable to get account for sending sending server notices") - } - }) // not specced, but ensure we're rate limiting requests to this endpoint if r := rateLimits.Limit(req, device); r != nil { return *r @@ -230,12 +227,6 @@ func Setup( synapseAdminRouter.Handle("/admin/v1/send_server_notice", httputil.MakeAuthAPI("send_server_notice", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - notificationSenderOnce.Do(func() { - serverNotificationSender, err = getSenderDevice(context.Background(), rsAPI, userAPI, cfg) - if err != nil { - logrus.WithError(err).Fatal("unable to get account for sending sending server notices") - } - }) // not specced, but ensure we're rate limiting requests to this endpoint if r := rateLimits.Limit(req, device); r != nil { return *r @@ -266,7 +257,7 @@ func Setup( }), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/join/{roomIDOrAlias}", - httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI(spec.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req, device); r != nil { return *r } @@ -282,7 +273,7 @@ func Setup( if mscCfg.Enabled("msc2753") { v3mux.Handle("/peek/{roomIDOrAlias}", - httputil.MakeAuthAPI(gomatrixserverlib.Peek, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI(spec.Peek, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req, device); r != nil { return *r } @@ -302,7 +293,7 @@ func Setup( }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/join", - httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI(spec.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req, device); r != nil { return *r } @@ -656,7 +647,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) v3mux.Handle("/auth/{authType}/fallback/web", - httputil.MakeHTMLAPI("auth_fallback", base.EnableMetrics, func(w http.ResponseWriter, req *http.Request) { + httputil.MakeHTMLAPI("auth_fallback", enableMetrics, func(w http.ResponseWriter, req *http.Request) { vars := mux.Vars(req) AuthFallback(w, req, vars["authType"], cfg) }), @@ -860,6 +851,8 @@ func Setup( // Browsers use the OPTIONS HTTP method to check if the CORS policy allows // PUT requests, so we need to allow this method + threePIDClient := base.CreateClient(dendriteCfg, nil) // TODO: Move this somewhere else, e.g. pass in as parameter + v3mux.Handle("/account/3pid", httputil.MakeAuthAPI("account_3pid", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return GetAssociated3PIDs(req, userAPI, device) @@ -868,11 +861,11 @@ func Setup( v3mux.Handle("/account/3pid", httputil.MakeAuthAPI("account_3pid", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - return CheckAndSave3PIDAssociation(req, userAPI, device, cfg) + return CheckAndSave3PIDAssociation(req, userAPI, device, cfg, threePIDClient) }), ).Methods(http.MethodPost, http.MethodOptions) - unstableMux.Handle("/account/3pid/delete", + v3mux.Handle("/account/3pid/delete", httputil.MakeAuthAPI("account_3pid", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return Forget3PID(req, userAPI) }), @@ -880,7 +873,7 @@ func Setup( v3mux.Handle("/{path:(?:account/3pid|register)}/email/requestToken", httputil.MakeExternalAPI("account_3pid_request_token", func(req *http.Request) util.JSONResponse { - return RequestEmailToken(req, userAPI, cfg) + return RequestEmailToken(req, userAPI, cfg, threePIDClient) }), ).Methods(http.MethodPost, http.MethodOptions) @@ -1114,7 +1107,7 @@ func Setup( v3mux.Handle("/delete_devices", httputil.MakeAuthAPI("delete_devices", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - return DeleteDevices(req, userAPI, device) + return DeleteDevices(req, userInteractiveAuth, userAPI, device) }), ).Methods(http.MethodPost, http.MethodOptions) @@ -1193,7 +1186,7 @@ func Setup( if r := rateLimits.Limit(req, device); r != nil { return *r } - return GetCapabilities(req, rsAPI) + return GetCapabilities() }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) @@ -1405,7 +1398,7 @@ func Setup( }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/rooms/{roomId}/receipt/{receiptType}/{eventId}", - httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI(spec.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req, device); r != nil { return *r } diff --git a/clientapi/routing/sendevent.go b/clientapi/routing/sendevent.go index 90af9ac4d..b1f8fa039 100644 --- a/clientapi/routing/sendevent.go +++ b/clientapi/routing/sendevent.go @@ -24,6 +24,7 @@ import ( "time" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/prometheus/client_golang/prometheus" "github.com/sirupsen/logrus" @@ -117,7 +118,7 @@ func SendEvent( // If we're sending a membership update, make sure to strip the authorised // via key if it is present, otherwise other servers won't be able to auth // the event if the room is set to the "restricted" join rule. - if eventType == gomatrixserverlib.MRoomMember { + if eventType == spec.MRoomMember { delete(r, "join_authorised_via_users_server") } @@ -136,7 +137,7 @@ func SendEvent( timeToGenerateEvent := time.Since(startedGeneratingEvent) // validate that the aliases exists - if eventType == gomatrixserverlib.MRoomCanonicalAlias && stateKey != nil && *stateKey == "" { + if eventType == spec.MRoomCanonicalAlias && stateKey != nil && *stateKey == "" { aliasReq := api.AliasEvent{} if err = json.Unmarshal(e.Content(), &aliasReq); err != nil { return util.ErrorResponse(fmt.Errorf("unable to parse alias event: %w", err)) diff --git a/clientapi/routing/sendtyping.go b/clientapi/routing/sendtyping.go index 3f92e4227..9dc884d62 100644 --- a/clientapi/routing/sendtyping.go +++ b/clientapi/routing/sendtyping.go @@ -15,12 +15,13 @@ package routing import ( "net/http" + "github.com/matrix-org/util" + "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/producers" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/util" ) type typingContentJSON struct { diff --git a/clientapi/routing/server_notices.go b/clientapi/routing/server_notices.go index fb93d8783..d6191f3b4 100644 --- a/clientapi/routing/server_notices.go +++ b/clientapi/routing/server_notices.go @@ -295,30 +295,28 @@ func getSenderDevice( } // Set the avatarurl for the user - avatarRes := &userapi.PerformSetAvatarURLResponse{} - if err = userAPI.SetAvatarURL(ctx, &userapi.PerformSetAvatarURLRequest{ - Localpart: cfg.Matrix.ServerNotices.LocalPart, - ServerName: cfg.Matrix.ServerName, - AvatarURL: cfg.Matrix.ServerNotices.AvatarURL, - }, avatarRes); err != nil { + profile, avatarChanged, err := userAPI.SetAvatarURL(ctx, + cfg.Matrix.ServerNotices.LocalPart, + cfg.Matrix.ServerName, + cfg.Matrix.ServerNotices.AvatarURL, + ) + if err != nil { util.GetLogger(ctx).WithError(err).Error("userAPI.SetAvatarURL failed") return nil, err } - profile := avatarRes.Profile - // Set the displayname for the user - displayNameRes := &userapi.PerformUpdateDisplayNameResponse{} - if err = userAPI.SetDisplayName(ctx, &userapi.PerformUpdateDisplayNameRequest{ - Localpart: cfg.Matrix.ServerNotices.LocalPart, - ServerName: cfg.Matrix.ServerName, - DisplayName: cfg.Matrix.ServerNotices.DisplayName, - }, displayNameRes); err != nil { + _, displayNameChanged, err := userAPI.SetDisplayName(ctx, + cfg.Matrix.ServerNotices.LocalPart, + cfg.Matrix.ServerName, + cfg.Matrix.ServerNotices.DisplayName, + ) + if err != nil { util.GetLogger(ctx).WithError(err).Error("userAPI.SetDisplayName failed") return nil, err } - if displayNameRes.Changed { + if displayNameChanged { profile.DisplayName = cfg.Matrix.ServerNotices.DisplayName } @@ -334,7 +332,7 @@ func getSenderDevice( // We've got an existing account, return the first device of it if len(deviceRes.Devices) > 0 { // If there were changes to the profile, create a new membership event - if displayNameRes.Changed || avatarRes.Changed { + if displayNameChanged || avatarChanged { _, err = updateProfile(ctx, rsAPI, &deviceRes.Devices[0], profile, accRes.Account.UserID, cfg, time.Now()) if err != nil { return nil, err diff --git a/clientapi/routing/state.go b/clientapi/routing/state.go index 12984c39a..3fb157660 100644 --- a/clientapi/routing/state.go +++ b/clientapi/routing/state.go @@ -22,14 +22,16 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/syncapi/synctypes" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" log "github.com/sirupsen/logrus" ) type stateEventInStateResp struct { - gomatrixserverlib.ClientEvent + synctypes.ClientEvent PrevContent json.RawMessage `json:"prev_content,omitempty"` ReplacesState string `json:"replaces_state,omitempty"` } @@ -67,7 +69,7 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a // that marks the room as world-readable. If we don't then we assume that // the room is not world-readable. for _, ev := range stateRes.StateEvents { - if ev.Type() == gomatrixserverlib.MRoomHistoryVisibility { + if ev.Type() == spec.MRoomHistoryVisibility { content := map[string]string{} if err := json.Unmarshal(ev.Content(), &content); err != nil { util.GetLogger(ctx).WithError(err).Error("json.Unmarshal for history visibility failed") @@ -122,7 +124,7 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a "state_at_event": !wantLatestState, }).Info("Fetching all state") - stateEvents := []gomatrixserverlib.ClientEvent{} + stateEvents := []synctypes.ClientEvent{} if wantLatestState { // If we are happy to use the latest state, either because the user is // still in the room, or because the room is world-readable, then just @@ -131,7 +133,7 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a for _, ev := range stateRes.StateEvents { stateEvents = append( stateEvents, - gomatrixserverlib.HeaderedToClientEvent(ev, gomatrixserverlib.FormatAll), + synctypes.HeaderedToClientEvent(ev, synctypes.FormatAll), ) } } else { @@ -150,7 +152,7 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a for _, ev := range stateAfterRes.StateEvents { stateEvents = append( stateEvents, - gomatrixserverlib.HeaderedToClientEvent(ev, gomatrixserverlib.FormatAll), + synctypes.HeaderedToClientEvent(ev, synctypes.FormatAll), ) } } @@ -184,9 +186,9 @@ func OnIncomingStateTypeRequest( StateKey: stateKey, }, } - if evType != gomatrixserverlib.MRoomHistoryVisibility && stateKey != "" { + if evType != spec.MRoomHistoryVisibility && stateKey != "" { stateToFetch = append(stateToFetch, gomatrixserverlib.StateKeyTuple{ - EventType: gomatrixserverlib.MRoomHistoryVisibility, + EventType: spec.MRoomHistoryVisibility, StateKey: "", }) } @@ -207,7 +209,7 @@ func OnIncomingStateTypeRequest( // that marks the room as world-readable. If we don't then we assume that // the room is not world-readable. for _, ev := range stateRes.StateEvents { - if ev.Type() == gomatrixserverlib.MRoomHistoryVisibility { + if ev.Type() == spec.MRoomHistoryVisibility { content := map[string]string{} if err := json.Unmarshal(ev.Content(), &content); err != nil { util.GetLogger(ctx).WithError(err).Error("json.Unmarshal for history visibility failed") @@ -241,7 +243,7 @@ func OnIncomingStateTypeRequest( } // If the user has never been in the room then stop at this point. // We won't tell the user about a room they have never joined. - if !membershipRes.HasBeenInRoom || membershipRes.Membership == gomatrixserverlib.Ban { + if !membershipRes.HasBeenInRoom || membershipRes.Membership == spec.Ban { return util.JSONResponse{ Code: http.StatusForbidden, JSON: jsonerror.Forbidden(fmt.Sprintf("Unknown room %q or user %q has never joined this room", roomID, device.UserID)), @@ -309,7 +311,7 @@ func OnIncomingStateTypeRequest( } stateEvent := stateEventInStateResp{ - ClientEvent: gomatrixserverlib.HeaderedToClientEvent(event, gomatrixserverlib.FormatAll), + ClientEvent: synctypes.HeaderedToClientEvent(event, synctypes.FormatAll), } var res interface{} diff --git a/clientapi/routing/threepid.go b/clientapi/routing/threepid.go index 971bfcad3..102b1d1cb 100644 --- a/clientapi/routing/threepid.go +++ b/clientapi/routing/threepid.go @@ -24,6 +24,7 @@ import ( "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" userdb "github.com/matrix-org/dendrite/userapi/storage" + "github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -33,7 +34,7 @@ type reqTokenResponse struct { SID string `json:"sid"` } -type threePIDsResponse struct { +type ThreePIDsResponse struct { ThreePIDs []authtypes.ThreePID `json:"threepids"` } @@ -41,7 +42,7 @@ type threePIDsResponse struct { // // POST /account/3pid/email/requestToken // POST /register/email/requestToken -func RequestEmailToken(req *http.Request, threePIDAPI api.ClientUserAPI, cfg *config.ClientAPI) util.JSONResponse { +func RequestEmailToken(req *http.Request, threePIDAPI api.ClientUserAPI, cfg *config.ClientAPI, client *fclient.Client) util.JSONResponse { var body threepid.EmailAssociationRequest if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil { return *reqErr @@ -72,7 +73,7 @@ func RequestEmailToken(req *http.Request, threePIDAPI api.ClientUserAPI, cfg *co } } - resp.SID, err = threepid.CreateSession(req.Context(), body, cfg) + resp.SID, err = threepid.CreateSession(req.Context(), body, cfg, client) if err == threepid.ErrNotTrusted { return util.JSONResponse{ Code: http.StatusBadRequest, @@ -92,7 +93,7 @@ func RequestEmailToken(req *http.Request, threePIDAPI api.ClientUserAPI, cfg *co // CheckAndSave3PIDAssociation implements POST /account/3pid func CheckAndSave3PIDAssociation( req *http.Request, threePIDAPI api.ClientUserAPI, device *api.Device, - cfg *config.ClientAPI, + cfg *config.ClientAPI, client *fclient.Client, ) util.JSONResponse { var body threepid.EmailAssociationCheckRequest if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil { @@ -100,7 +101,7 @@ func CheckAndSave3PIDAssociation( } // Check if the association has been validated - verified, address, medium, err := threepid.CheckAssociation(req.Context(), body.Creds, cfg) + verified, address, medium, err := threepid.CheckAssociation(req.Context(), body.Creds, cfg, client) if err == threepid.ErrNotTrusted { return util.JSONResponse{ Code: http.StatusBadRequest, @@ -123,13 +124,8 @@ func CheckAndSave3PIDAssociation( if body.Bind { // Publish the association on the identity server if requested - err = threepid.PublishAssociation(body.Creds, device.UserID, cfg) - if err == threepid.ErrNotTrusted { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.NotTrusted(body.Creds.IDServer), - } - } else if err != nil { + err = threepid.PublishAssociation(req.Context(), body.Creds, device.UserID, cfg, client) + if err != nil { util.GetLogger(req.Context()).WithError(err).Error("threepid.PublishAssociation failed") return jsonerror.InternalServerError() } @@ -180,7 +176,7 @@ func GetAssociated3PIDs( return util.JSONResponse{ Code: http.StatusOK, - JSON: threePIDsResponse{res.ThreePIDs}, + JSON: ThreePIDsResponse{res.ThreePIDs}, } } @@ -191,7 +187,10 @@ func Forget3PID(req *http.Request, threepidAPI api.ClientUserAPI) util.JSONRespo return *reqErr } - if err := threepidAPI.PerformForgetThreePID(req.Context(), &api.PerformForgetThreePIDRequest{}, &struct{}{}); err != nil { + if err := threepidAPI.PerformForgetThreePID(req.Context(), &api.PerformForgetThreePIDRequest{ + ThreePID: body.Address, + Medium: body.Medium, + }, &struct{}{}); err != nil { util.GetLogger(req.Context()).WithError(err).Error("threepidAPI.PerformForgetThreePID failed") return jsonerror.InternalServerError() } diff --git a/clientapi/routing/userdirectory.go b/clientapi/routing/userdirectory.go index 62af9efa4..0d021c7aa 100644 --- a/clientapi/routing/userdirectory.go +++ b/clientapi/routing/userdirectory.go @@ -26,6 +26,8 @@ import ( userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -41,8 +43,8 @@ func SearchUserDirectory( provider userapi.QuerySearchProfilesAPI, searchString string, limit int, - federation *gomatrixserverlib.FederationClient, - localServerName gomatrixserverlib.ServerName, + federation *fclient.FederationClient, + localServerName spec.ServerName, ) util.JSONResponse { if limit < 10 { limit = 10 diff --git a/clientapi/threepid/invites.go b/clientapi/threepid/invites.go index 1f294a032..34fa54152 100644 --- a/clientapi/threepid/invites.go +++ b/clientapi/threepid/invites.go @@ -30,6 +30,7 @@ import ( "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) // MembershipRequest represents the body of an incoming POST request @@ -209,24 +210,17 @@ func queryIDServerStoreInvite( body *MembershipRequest, roomID string, ) (*idServerStoreInviteResponse, error) { // Retrieve the sender's profile to get their display name - localpart, serverName, err := gomatrixserverlib.SplitID('@', device.UserID) + _, serverName, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { return nil, err } var profile *authtypes.Profile if cfg.Matrix.IsLocalServerName(serverName) { - res := &userapi.QueryProfileResponse{} - err = userAPI.QueryProfile(ctx, &userapi.QueryProfileRequest{UserID: device.UserID}, res) + profile, err = userAPI.QueryProfile(ctx, device.UserID) if err != nil { return nil, err } - profile = &authtypes.Profile{ - Localpart: localpart, - DisplayName: res.DisplayName, - AvatarURL: res.AvatarURL, - } - } else { profile = &authtypes.Profile{} } @@ -285,7 +279,7 @@ func queryIDServerPubKey(ctx context.Context, idServerName string, keyID string) } var pubKeyRes struct { - PublicKey gomatrixserverlib.Base64Bytes `json:"public_key"` + PublicKey spec.Base64Bytes `json:"public_key"` } if resp.StatusCode != http.StatusOK { diff --git a/clientapi/threepid/threepid.go b/clientapi/threepid/threepid.go index 1e64e3034..1fe573b1b 100644 --- a/clientapi/threepid/threepid.go +++ b/clientapi/threepid/threepid.go @@ -25,6 +25,7 @@ import ( "strings" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib/fclient" ) // EmailAssociationRequest represents the request defined at https://matrix.org/docs/spec/client_server/r0.2.0.html#post-matrix-client-r0-register-email-requesttoken @@ -37,7 +38,7 @@ type EmailAssociationRequest struct { // EmailAssociationCheckRequest represents the request defined at https://matrix.org/docs/spec/client_server/r0.2.0.html#post-matrix-client-r0-account-3pid type EmailAssociationCheckRequest struct { - Creds Credentials `json:"threePidCreds"` + Creds Credentials `json:"three_pid_creds"` Bind bool `json:"bind"` } @@ -48,12 +49,16 @@ type Credentials struct { Secret string `json:"client_secret"` } +type SID struct { + SID string `json:"sid"` +} + // CreateSession creates a session on an identity server. // Returns the session's ID. // Returns an error if there was a problem sending the request or decoding the // response, or if the identity server responded with a non-OK status. func CreateSession( - ctx context.Context, req EmailAssociationRequest, cfg *config.ClientAPI, + ctx context.Context, req EmailAssociationRequest, cfg *config.ClientAPI, client *fclient.Client, ) (string, error) { if err := isTrusted(req.IDServer, cfg); err != nil { return "", err @@ -73,8 +78,7 @@ func CreateSession( } request.Header.Add("Content-Type", "application/x-www-form-urlencoded") - client := http.Client{} - resp, err := client.Do(request.WithContext(ctx)) + resp, err := client.DoHTTPRequest(ctx, request) if err != nil { return "", err } @@ -85,14 +89,20 @@ func CreateSession( } // Extract the SID from the response and return it - var sid struct { - SID string `json:"sid"` - } + var sid SID err = json.NewDecoder(resp.Body).Decode(&sid) return sid.SID, err } +type GetValidatedResponse struct { + Medium string `json:"medium"` + ValidatedAt int64 `json:"validated_at"` + Address string `json:"address"` + ErrCode string `json:"errcode"` + Error string `json:"error"` +} + // CheckAssociation checks the status of an ongoing association validation on an // identity server. // Returns a boolean set to true if the association has been validated, false if not. @@ -102,6 +112,7 @@ func CreateSession( // response, or if the identity server responded with a non-OK status. func CheckAssociation( ctx context.Context, creds Credentials, cfg *config.ClientAPI, + client *fclient.Client, ) (bool, string, string, error) { if err := isTrusted(creds.IDServer, cfg); err != nil { return false, "", "", err @@ -112,19 +123,12 @@ func CheckAssociation( if err != nil { return false, "", "", err } - resp, err := http.DefaultClient.Do(req.WithContext(ctx)) + resp, err := client.DoHTTPRequest(ctx, req) if err != nil { return false, "", "", err } - var respBody struct { - Medium string `json:"medium"` - ValidatedAt int64 `json:"validated_at"` - Address string `json:"address"` - ErrCode string `json:"errcode"` - Error string `json:"error"` - } - + var respBody GetValidatedResponse if err = json.NewDecoder(resp.Body).Decode(&respBody); err != nil { return false, "", "", err } @@ -142,7 +146,7 @@ func CheckAssociation( // identifier and a Matrix ID. // Returns an error if there was a problem sending the request or decoding the // response, or if the identity server responded with a non-OK status. -func PublishAssociation(creds Credentials, userID string, cfg *config.ClientAPI) error { +func PublishAssociation(ctx context.Context, creds Credentials, userID string, cfg *config.ClientAPI, client *fclient.Client) error { if err := isTrusted(creds.IDServer, cfg); err != nil { return err } @@ -160,8 +164,7 @@ func PublishAssociation(creds Credentials, userID string, cfg *config.ClientAPI) } request.Header.Add("Content-Type", "application/x-www-form-urlencoded") - client := http.Client{} - resp, err := client.Do(request) + resp, err := client.DoHTTPRequest(ctx, request) if err != nil { return err } diff --git a/clientapi/userutil/userutil.go b/clientapi/userutil/userutil.go index 9be1e9b31..26237142b 100644 --- a/clientapi/userutil/userutil.go +++ b/clientapi/userutil/userutil.go @@ -19,13 +19,14 @@ import ( "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) // ParseUsernameParam extracts localpart from usernameParam. // usernameParam can either be a user ID or just the localpart/username. // If serverName is passed, it is verified against the domain obtained from usernameParam (if present) // Returns error in case of invalid usernameParam. -func ParseUsernameParam(usernameParam string, cfg *config.Global) (string, gomatrixserverlib.ServerName, error) { +func ParseUsernameParam(usernameParam string, cfg *config.Global) (string, spec.ServerName, error) { localpart := usernameParam if strings.HasPrefix(usernameParam, "@") { @@ -45,6 +46,6 @@ func ParseUsernameParam(usernameParam string, cfg *config.Global) (string, gomat } // MakeUserID generates user ID from localpart & server name -func MakeUserID(localpart string, server gomatrixserverlib.ServerName) string { +func MakeUserID(localpart string, server spec.ServerName) string { return fmt.Sprintf("@%s:%s", localpart, string(server)) } diff --git a/clientapi/userutil/userutil_test.go b/clientapi/userutil/userutil_test.go index ee6bf8a01..cdda0a88a 100644 --- a/clientapi/userutil/userutil_test.go +++ b/clientapi/userutil/userutil_test.go @@ -16,21 +16,22 @@ import ( "testing" "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" ) var ( - localpart = "somelocalpart" - serverName gomatrixserverlib.ServerName = "someservername" - invalidServerName gomatrixserverlib.ServerName = "invalidservername" - goodUserID = "@" + localpart + ":" + string(serverName) - badUserID = "@bad:user:name@noservername:" + localpart = "somelocalpart" + serverName spec.ServerName = "someservername" + invalidServerName spec.ServerName = "invalidservername" + goodUserID = "@" + localpart + ":" + string(serverName) + badUserID = "@bad:user:name@noservername:" ) // TestGoodUserID checks that correct localpart is returned for a valid user ID. func TestGoodUserID(t *testing.T) { cfg := &config.Global{ - SigningIdentity: gomatrixserverlib.SigningIdentity{ + SigningIdentity: fclient.SigningIdentity{ ServerName: serverName, }, } @@ -49,7 +50,7 @@ func TestGoodUserID(t *testing.T) { // TestWithLocalpartOnly checks that localpart is returned when usernameParam contains only localpart. func TestWithLocalpartOnly(t *testing.T) { cfg := &config.Global{ - SigningIdentity: gomatrixserverlib.SigningIdentity{ + SigningIdentity: fclient.SigningIdentity{ ServerName: serverName, }, } @@ -68,7 +69,7 @@ func TestWithLocalpartOnly(t *testing.T) { // TestIncorrectDomain checks for error when there's server name mismatch. func TestIncorrectDomain(t *testing.T) { cfg := &config.Global{ - SigningIdentity: gomatrixserverlib.SigningIdentity{ + SigningIdentity: fclient.SigningIdentity{ ServerName: invalidServerName, }, } @@ -83,7 +84,7 @@ func TestIncorrectDomain(t *testing.T) { // TestBadUserID checks that ParseUsernameParam fails for invalid user ID func TestBadUserID(t *testing.T) { cfg := &config.Global{ - SigningIdentity: gomatrixserverlib.SigningIdentity{ + SigningIdentity: fclient.SigningIdentity{ ServerName: serverName, }, } diff --git a/cmd/dendrite-demo-pinecone/conn/client.go b/cmd/dendrite-demo-pinecone/conn/client.go index a91434f62..4571de157 100644 --- a/cmd/dendrite-demo-pinecone/conn/client.go +++ b/cmd/dendrite-demo-pinecone/conn/client.go @@ -21,8 +21,8 @@ import ( "net/http" "strings" - "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib/fclient" "nhooyr.io/websocket" pineconeRouter "github.com/matrix-org/pinecone/router" @@ -90,18 +90,18 @@ func createTransport(s *pineconeSessions.Sessions) *http.Transport { } func CreateClient( - base *base.BaseDendrite, s *pineconeSessions.Sessions, -) *gomatrixserverlib.Client { - return gomatrixserverlib.NewClient( - gomatrixserverlib.WithTransport(createTransport(s)), + s *pineconeSessions.Sessions, +) *fclient.Client { + return fclient.NewClient( + fclient.WithTransport(createTransport(s)), ) } func CreateFederationClient( - base *base.BaseDendrite, s *pineconeSessions.Sessions, -) *gomatrixserverlib.FederationClient { - return gomatrixserverlib.NewFederationClient( - base.Cfg.Global.SigningIdentities(), - gomatrixserverlib.WithTransport(createTransport(s)), + cfg *config.Dendrite, s *pineconeSessions.Sessions, +) *fclient.FederationClient { + return fclient.NewFederationClient( + cfg.Global.SigningIdentities(), + fclient.WithTransport(createTransport(s)), ) } diff --git a/cmd/dendrite-demo-pinecone/defaults/defaults.go b/cmd/dendrite-demo-pinecone/defaults/defaults.go index c92493137..9da54d5f5 100644 --- a/cmd/dendrite-demo-pinecone/defaults/defaults.go +++ b/cmd/dendrite-demo-pinecone/defaults/defaults.go @@ -14,8 +14,8 @@ package defaults -import "github.com/matrix-org/gomatrixserverlib" +import "github.com/matrix-org/gomatrixserverlib/spec" -var DefaultServerNames = map[gomatrixserverlib.ServerName]struct{}{ +var DefaultServerNames = map[spec.ServerName]struct{}{ "3bf0258d23c60952639cc4c69c71d1508a7d43a0475d9000ff900a1848411ec7": {}, } diff --git a/cmd/dendrite-demo-pinecone/main.go b/cmd/dendrite-demo-pinecone/main.go index 7706a73f5..18d1dd31f 100644 --- a/cmd/dendrite-demo-pinecone/main.go +++ b/cmd/dendrite-demo-pinecone/main.go @@ -27,9 +27,13 @@ import ( "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/monolith" "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/sirupsen/logrus" pineconeRouter "github.com/matrix-org/pinecone/router" @@ -74,7 +78,7 @@ func main() { cfg = monolith.GenerateDefaultConfig(sk, *instanceDir, *instanceDir, *instanceName) } - cfg.Global.ServerName = gomatrixserverlib.ServerName(hex.EncodeToString(pk)) + cfg.Global.ServerName = spec.ServerName(hex.EncodeToString(pk)) cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID) p2pMonolith := monolith.P2PMonolith{} @@ -87,9 +91,13 @@ func main() { } } + processCtx := process.NewProcessContext() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + routers := httputil.NewRouters() + enableMetrics := true enableWebsockets := true - p2pMonolith.SetupDendrite(cfg, *instancePort, *instanceRelayingEnabled, enableMetrics, enableWebsockets) + p2pMonolith.SetupDendrite(processCtx, cfg, cm, routers, *instancePort, *instanceRelayingEnabled, enableMetrics, enableWebsockets) p2pMonolith.StartMonolith() p2pMonolith.WaitForShutdown() diff --git a/cmd/dendrite-demo-pinecone/monolith/monolith.go b/cmd/dendrite-demo-pinecone/monolith/monolith.go index ea8e985c7..397473865 100644 --- a/cmd/dendrite-demo-pinecone/monolith/monolith.go +++ b/cmd/dendrite-demo-pinecone/monolith/monolith.go @@ -23,6 +23,7 @@ import ( "net" "net/http" "path/filepath" + "sync" "time" "github.com/gorilla/mux" @@ -37,7 +38,9 @@ import ( "github.com/matrix-org/dendrite/federationapi" federationAPI "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/producers" + "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/relayapi" relayAPI "github.com/matrix-org/dendrite/relayapi/api" "github.com/matrix-org/dendrite/roomserver" @@ -45,9 +48,10 @@ import ( "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/userapi" userAPI "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/sirupsen/logrus" pineconeConnections "github.com/matrix-org/pinecone/connections" @@ -60,13 +64,13 @@ import ( const SessionProtocol = "matrix" type P2PMonolith struct { - BaseDendrite *base.BaseDendrite Sessions *pineconeSessions.Sessions Multicast *pineconeMulticast.Multicast ConnManager *pineconeConnections.ConnectionManager Router *pineconeRouter.Router EventChannel chan pineconeEvents.Event RelayRetriever relay.RelayServerRetriever + ProcessCtx *process.ProcessContext dendrite setup.Monolith port int @@ -76,6 +80,7 @@ type P2PMonolith struct { listener net.Listener httpListenAddr string stopHandlingEvents chan bool + httpServerMu sync.Mutex } func GenerateDefaultConfig(sk ed25519.PrivateKey, storageDir string, cacheDir string, dbPrefix string) *config.Dendrite { @@ -120,53 +125,52 @@ func (p *P2PMonolith) SetupPinecone(sk ed25519.PrivateKey) { p.ConnManager = pineconeConnections.NewConnectionManager(p.Router, nil) } -func (p *P2PMonolith) SetupDendrite(cfg *config.Dendrite, port int, enableRelaying bool, enableMetrics bool, enableWebsockets bool) { - if enableMetrics { - p.BaseDendrite = base.NewBaseDendrite(cfg) - } else { - p.BaseDendrite = base.NewBaseDendrite(cfg, base.DisableMetrics) - } - p.port = port - p.BaseDendrite.ConfigureAdminEndpoints() +func (p *P2PMonolith) SetupDendrite( + processCtx *process.ProcessContext, cfg *config.Dendrite, cm sqlutil.Connections, routers httputil.Routers, + port int, enableRelaying bool, enableMetrics bool, enableWebsockets bool) { - federation := conn.CreateFederationClient(p.BaseDendrite, p.Sessions) + p.port = port + base.ConfigureAdminEndpoints(processCtx, routers) + + federation := conn.CreateFederationClient(cfg, p.Sessions) serverKeyAPI := &signing.YggdrasilKeys{} keyRing := serverKeyAPI.KeyRing() - rsComponent := roomserver.NewInternalAPI(p.BaseDendrite) - rsAPI := rsComponent + caches := caching.NewRistrettoCache(cfg.Global.Cache.EstimatedMaxSize, cfg.Global.Cache.MaxAge, enableMetrics) + natsInstance := jetstream.NATSInstance{} + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, enableMetrics) fsAPI := federationapi.NewInternalAPI( - p.BaseDendrite, federation, rsAPI, p.BaseDendrite.Caches, keyRing, true, + processCtx, cfg, cm, &natsInstance, federation, rsAPI, caches, keyRing, true, ) - userAPI := userapi.NewInternalAPI(p.BaseDendrite, rsAPI, federation) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, federation) - asAPI := appservice.NewInternalAPI(p.BaseDendrite, userAPI, rsAPI) + asAPI := appservice.NewInternalAPI(processCtx, cfg, &natsInstance, userAPI, rsAPI) - rsComponent.SetFederationAPI(fsAPI, keyRing) + rsAPI.SetFederationAPI(fsAPI, keyRing) userProvider := users.NewPineconeUserProvider(p.Router, p.Sessions, userAPI, federation) roomProvider := rooms.NewPineconeRoomProvider(p.Router, p.Sessions, fsAPI, federation) - js, _ := p.BaseDendrite.NATS.Prepare(p.BaseDendrite.ProcessContext, &p.BaseDendrite.Cfg.Global.JetStream) + js, _ := natsInstance.Prepare(processCtx, &cfg.Global.JetStream) producer := &producers.SyncAPIProducer{ JetStream: js, - TopicReceiptEvent: p.BaseDendrite.Cfg.Global.JetStream.Prefixed(jetstream.OutputReceiptEvent), - TopicSendToDeviceEvent: p.BaseDendrite.Cfg.Global.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent), - TopicTypingEvent: p.BaseDendrite.Cfg.Global.JetStream.Prefixed(jetstream.OutputTypingEvent), - TopicPresenceEvent: p.BaseDendrite.Cfg.Global.JetStream.Prefixed(jetstream.OutputPresenceEvent), - TopicDeviceListUpdate: p.BaseDendrite.Cfg.Global.JetStream.Prefixed(jetstream.InputDeviceListUpdate), - TopicSigningKeyUpdate: p.BaseDendrite.Cfg.Global.JetStream.Prefixed(jetstream.InputSigningKeyUpdate), - Config: &p.BaseDendrite.Cfg.FederationAPI, + TopicReceiptEvent: cfg.Global.JetStream.Prefixed(jetstream.OutputReceiptEvent), + TopicSendToDeviceEvent: cfg.Global.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent), + TopicTypingEvent: cfg.Global.JetStream.Prefixed(jetstream.OutputTypingEvent), + TopicPresenceEvent: cfg.Global.JetStream.Prefixed(jetstream.OutputPresenceEvent), + TopicDeviceListUpdate: cfg.Global.JetStream.Prefixed(jetstream.InputDeviceListUpdate), + TopicSigningKeyUpdate: cfg.Global.JetStream.Prefixed(jetstream.InputSigningKeyUpdate), + Config: &cfg.FederationAPI, UserAPI: userAPI, } - relayAPI := relayapi.NewRelayInternalAPI(p.BaseDendrite, federation, rsAPI, keyRing, producer, enableRelaying) + relayAPI := relayapi.NewRelayInternalAPI(cfg, cm, federation, rsAPI, keyRing, producer, enableRelaying, caches) logrus.Infof("Relaying enabled: %v", relayAPI.RelayingEnabled()) p.dendrite = setup.Monolith{ - Config: p.BaseDendrite.Cfg, - Client: conn.CreateClient(p.BaseDendrite, p.Sessions), + Config: cfg, + Client: conn.CreateClient(p.Sessions), FedClient: federation, KeyRing: keyRing, @@ -178,9 +182,10 @@ func (p *P2PMonolith) SetupDendrite(cfg *config.Dendrite, port int, enableRelayi ExtPublicRoomsProvider: roomProvider, ExtUserDirectoryProvider: userProvider, } - p.dendrite.AddAllPublicRoutes(p.BaseDendrite) + p.ProcessCtx = processCtx + p.dendrite.AddAllPublicRoutes(processCtx, cfg, routers, cm, &natsInstance, caches, enableMetrics) - p.setupHttpServers(userProvider, enableWebsockets) + p.setupHttpServers(userProvider, routers, enableWebsockets) } func (p *P2PMonolith) GetFederationAPI() federationAPI.FederationInternalAPI { @@ -202,20 +207,22 @@ func (p *P2PMonolith) StartMonolith() { func (p *P2PMonolith) Stop() { logrus.Info("Stopping monolith") - _ = p.BaseDendrite.Close() + p.ProcessCtx.ShutdownDendrite() p.WaitForShutdown() logrus.Info("Stopped monolith") } func (p *P2PMonolith) WaitForShutdown() { - p.BaseDendrite.WaitForShutdown() + base.WaitForShutdown(p.ProcessCtx) p.closeAllResources() } func (p *P2PMonolith) closeAllResources() { logrus.Info("Closing monolith resources") + p.httpServerMu.Lock() if p.httpServer != nil { _ = p.httpServer.Shutdown(context.Background()) + p.httpServerMu.Unlock() } select { @@ -245,12 +252,12 @@ func (p *P2PMonolith) Addr() string { return p.httpListenAddr } -func (p *P2PMonolith) setupHttpServers(userProvider *users.PineconeUserProvider, enableWebsockets bool) { +func (p *P2PMonolith) setupHttpServers(userProvider *users.PineconeUserProvider, routers httputil.Routers, enableWebsockets bool) { p.httpMux = mux.NewRouter().SkipClean(true).UseEncodedPath() - p.httpMux.PathPrefix(httputil.PublicClientPathPrefix).Handler(p.BaseDendrite.PublicClientAPIMux) - p.httpMux.PathPrefix(httputil.PublicMediaPathPrefix).Handler(p.BaseDendrite.PublicMediaAPIMux) - p.httpMux.PathPrefix(httputil.DendriteAdminPathPrefix).Handler(p.BaseDendrite.DendriteAdminMux) - p.httpMux.PathPrefix(httputil.SynapseAdminPathPrefix).Handler(p.BaseDendrite.SynapseAdminMux) + p.httpMux.PathPrefix(httputil.PublicClientPathPrefix).Handler(routers.Client) + p.httpMux.PathPrefix(httputil.PublicMediaPathPrefix).Handler(routers.Media) + p.httpMux.PathPrefix(httputil.DendriteAdminPathPrefix).Handler(routers.DendriteAdmin) + p.httpMux.PathPrefix(httputil.SynapseAdminPathPrefix).Handler(routers.SynapseAdmin) if enableWebsockets { wsUpgrader := websocket.Upgrader{ @@ -283,8 +290,8 @@ func (p *P2PMonolith) setupHttpServers(userProvider *users.PineconeUserProvider, p.pineconeMux = mux.NewRouter().SkipClean(true).UseEncodedPath() p.pineconeMux.PathPrefix(users.PublicURL).HandlerFunc(userProvider.FederatedUserProfiles) - p.pineconeMux.PathPrefix(httputil.PublicFederationPathPrefix).Handler(p.BaseDendrite.PublicFederationAPIMux) - p.pineconeMux.PathPrefix(httputil.PublicMediaPathPrefix).Handler(p.BaseDendrite.PublicMediaAPIMux) + p.pineconeMux.PathPrefix(httputil.PublicFederationPathPrefix).Handler(routers.Federation) + p.pineconeMux.PathPrefix(httputil.PublicMediaPathPrefix).Handler(routers.Media) pHTTP := p.Sessions.Protocol(SessionProtocol).HTTP() pHTTP.Mux().Handle(users.PublicURL, p.pineconeMux) @@ -294,6 +301,7 @@ func (p *P2PMonolith) setupHttpServers(userProvider *users.PineconeUserProvider, func (p *P2PMonolith) startHTTPServers() { go func() { + p.httpServerMu.Lock() // Build both ends of a HTTP multiplex. p.httpServer = &http.Server{ Addr: ":0", @@ -306,7 +314,7 @@ func (p *P2PMonolith) startHTTPServers() { }, Handler: p.pineconeMux, } - + p.httpServerMu.Unlock() pubkey := p.Router.PublicKey() pubkeyString := hex.EncodeToString(pubkey[:]) logrus.Info("Listening on ", pubkeyString) @@ -339,7 +347,7 @@ func (p *P2PMonolith) startEventHandler() { eLog := logrus.WithField("pinecone", "events") p.RelayRetriever = relay.NewRelayServerRetriever( context.Background(), - gomatrixserverlib.ServerName(p.Router.PublicKey().String()), + spec.ServerName(p.Router.PublicKey().String()), p.dendrite.FederationAPI, p.dendrite.RelayAPI, stopRelayServerSync, @@ -365,10 +373,10 @@ func (p *P2PMonolith) startEventHandler() { // eLog.Info("Broadcast received from: ", e.PeerID) req := &federationAPI.PerformWakeupServersRequest{ - ServerNames: []gomatrixserverlib.ServerName{gomatrixserverlib.ServerName(e.PeerID)}, + ServerNames: []spec.ServerName{spec.ServerName(e.PeerID)}, } res := &federationAPI.PerformWakeupServersResponse{} - if err := p.dendrite.FederationAPI.PerformWakeupServers(p.BaseDendrite.Context(), req, res); err != nil { + if err := p.dendrite.FederationAPI.PerformWakeupServers(p.ProcessCtx.Context(), req, res); err != nil { eLog.WithError(err).Error("Failed to wakeup destination", e.PeerID) } } diff --git a/cmd/dendrite-demo-pinecone/relay/retriever.go b/cmd/dendrite-demo-pinecone/relay/retriever.go index 6b34f6416..3c76ad600 100644 --- a/cmd/dendrite-demo-pinecone/relay/retriever.go +++ b/cmd/dendrite-demo-pinecone/relay/retriever.go @@ -21,7 +21,7 @@ import ( federationAPI "github.com/matrix-org/dendrite/federationapi/api" relayServerAPI "github.com/matrix-org/dendrite/relayapi/api" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/sirupsen/logrus" "go.uber.org/atomic" ) @@ -32,10 +32,10 @@ const ( type RelayServerRetriever struct { ctx context.Context - serverName gomatrixserverlib.ServerName + serverName spec.ServerName federationAPI federationAPI.FederationInternalAPI relayAPI relayServerAPI.RelayInternalAPI - relayServersQueried map[gomatrixserverlib.ServerName]bool + relayServersQueried map[spec.ServerName]bool queriedServersMutex sync.Mutex running atomic.Bool quit chan bool @@ -43,7 +43,7 @@ type RelayServerRetriever struct { func NewRelayServerRetriever( ctx context.Context, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, federationAPI federationAPI.FederationInternalAPI, relayAPI relayServerAPI.RelayInternalAPI, quit chan bool, @@ -53,14 +53,14 @@ func NewRelayServerRetriever( serverName: serverName, federationAPI: federationAPI, relayAPI: relayAPI, - relayServersQueried: make(map[gomatrixserverlib.ServerName]bool), + relayServersQueried: make(map[spec.ServerName]bool), running: *atomic.NewBool(false), quit: quit, } } func (r *RelayServerRetriever) InitializeRelayServers(eLog *logrus.Entry) { - request := federationAPI.P2PQueryRelayServersRequest{Server: gomatrixserverlib.ServerName(r.serverName)} + request := federationAPI.P2PQueryRelayServersRequest{Server: spec.ServerName(r.serverName)} response := federationAPI.P2PQueryRelayServersResponse{} err := r.federationAPI.P2PQueryRelayServers(r.ctx, &request, &response) if err != nil { @@ -76,13 +76,13 @@ func (r *RelayServerRetriever) InitializeRelayServers(eLog *logrus.Entry) { eLog.Infof("Registered relay servers: %v", response.RelayServers) } -func (r *RelayServerRetriever) SetRelayServers(servers []gomatrixserverlib.ServerName) { +func (r *RelayServerRetriever) SetRelayServers(servers []spec.ServerName) { UpdateNodeRelayServers(r.serverName, servers, r.ctx, r.federationAPI) // Replace list of servers to sync with and mark them all as unsynced. r.queriedServersMutex.Lock() defer r.queriedServersMutex.Unlock() - r.relayServersQueried = make(map[gomatrixserverlib.ServerName]bool) + r.relayServersQueried = make(map[spec.ServerName]bool) for _, server := range servers { r.relayServersQueried[server] = false } @@ -90,10 +90,10 @@ func (r *RelayServerRetriever) SetRelayServers(servers []gomatrixserverlib.Serve r.StartSync() } -func (r *RelayServerRetriever) GetRelayServers() []gomatrixserverlib.ServerName { +func (r *RelayServerRetriever) GetRelayServers() []spec.ServerName { r.queriedServersMutex.Lock() defer r.queriedServersMutex.Unlock() - relayServers := []gomatrixserverlib.ServerName{} + relayServers := []spec.ServerName{} for server := range r.relayServersQueried { relayServers = append(relayServers, server) } @@ -101,11 +101,11 @@ func (r *RelayServerRetriever) GetRelayServers() []gomatrixserverlib.ServerName return relayServers } -func (r *RelayServerRetriever) GetQueriedServerStatus() map[gomatrixserverlib.ServerName]bool { +func (r *RelayServerRetriever) GetQueriedServerStatus() map[spec.ServerName]bool { r.queriedServersMutex.Lock() defer r.queriedServersMutex.Unlock() - result := map[gomatrixserverlib.ServerName]bool{} + result := map[spec.ServerName]bool{} for server, queried := range r.relayServersQueried { result[server] = queried } @@ -128,7 +128,7 @@ func (r *RelayServerRetriever) SyncRelayServers(stop <-chan bool) { t := time.NewTimer(relayServerRetryInterval) for { - relayServersToQuery := []gomatrixserverlib.ServerName{} + relayServersToQuery := []spec.ServerName{} func() { r.queriedServersMutex.Lock() defer r.queriedServersMutex.Unlock() @@ -158,10 +158,10 @@ func (r *RelayServerRetriever) SyncRelayServers(stop <-chan bool) { } } -func (r *RelayServerRetriever) queryRelayServers(relayServers []gomatrixserverlib.ServerName) { +func (r *RelayServerRetriever) queryRelayServers(relayServers []spec.ServerName) { logrus.Info("Querying relay servers for any available transactions") for _, server := range relayServers { - userID, err := gomatrixserverlib.NewUserID("@user:"+string(r.serverName), false) + userID, err := spec.NewUserID("@user:"+string(r.serverName), false) if err != nil { return } @@ -187,8 +187,8 @@ func (r *RelayServerRetriever) queryRelayServers(relayServers []gomatrixserverli } func UpdateNodeRelayServers( - node gomatrixserverlib.ServerName, - relays []gomatrixserverlib.ServerName, + node spec.ServerName, + relays []spec.ServerName, ctx context.Context, fedAPI federationAPI.FederationInternalAPI, ) { @@ -201,7 +201,7 @@ func UpdateNodeRelayServers( } // Remove old, non-matching relays - var serversToRemove []gomatrixserverlib.ServerName + var serversToRemove []spec.ServerName for _, existingServer := range response.RelayServers { shouldRemove := true for _, newServer := range relays { diff --git a/cmd/dendrite-demo-pinecone/relay/retriever_test.go b/cmd/dendrite-demo-pinecone/relay/retriever_test.go index 6c4c3a529..6ec9aaf5d 100644 --- a/cmd/dendrite-demo-pinecone/relay/retriever_test.go +++ b/cmd/dendrite-demo-pinecone/relay/retriever_test.go @@ -21,13 +21,13 @@ import ( federationAPI "github.com/matrix-org/dendrite/federationapi/api" relayServerAPI "github.com/matrix-org/dendrite/relayapi/api" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "gotest.tools/v3/poll" ) -var testRelayServers = []gomatrixserverlib.ServerName{"relay1", "relay2"} +var testRelayServers = []spec.ServerName{"relay1", "relay2"} type FakeFedAPI struct { federationAPI.FederationInternalAPI @@ -48,8 +48,8 @@ type FakeRelayAPI struct { func (r *FakeRelayAPI) PerformRelayServerSync( ctx context.Context, - userID gomatrixserverlib.UserID, - relayServer gomatrixserverlib.ServerName, + userID spec.UserID, + relayServer spec.ServerName, ) error { return nil } diff --git a/cmd/dendrite-demo-pinecone/rooms/rooms.go b/cmd/dendrite-demo-pinecone/rooms/rooms.go index 0ac705cc1..73423ef65 100644 --- a/cmd/dendrite-demo-pinecone/rooms/rooms.go +++ b/cmd/dendrite-demo-pinecone/rooms/rooms.go @@ -21,7 +21,8 @@ import ( "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/defaults" "github.com/matrix-org/dendrite/federationapi/api" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" pineconeRouter "github.com/matrix-org/pinecone/router" @@ -32,14 +33,14 @@ type PineconeRoomProvider struct { r *pineconeRouter.Router s *pineconeSessions.Sessions fedSender api.FederationInternalAPI - fedClient *gomatrixserverlib.FederationClient + fedClient *fclient.FederationClient } func NewPineconeRoomProvider( r *pineconeRouter.Router, s *pineconeSessions.Sessions, fedSender api.FederationInternalAPI, - fedClient *gomatrixserverlib.FederationClient, + fedClient *fclient.FederationClient, ) *PineconeRoomProvider { p := &PineconeRoomProvider{ r: r, @@ -50,31 +51,31 @@ func NewPineconeRoomProvider( return p } -func (p *PineconeRoomProvider) Rooms() []gomatrixserverlib.PublicRoom { - list := map[gomatrixserverlib.ServerName]struct{}{} +func (p *PineconeRoomProvider) Rooms() []fclient.PublicRoom { + list := map[spec.ServerName]struct{}{} for k := range defaults.DefaultServerNames { list[k] = struct{}{} } for _, k := range p.r.Peers() { - list[gomatrixserverlib.ServerName(k.PublicKey)] = struct{}{} + list[spec.ServerName(k.PublicKey)] = struct{}{} } return bulkFetchPublicRoomsFromServers( context.Background(), p.fedClient, - gomatrixserverlib.ServerName(p.r.PublicKey().String()), list, + spec.ServerName(p.r.PublicKey().String()), list, ) } // bulkFetchPublicRoomsFromServers fetches public rooms from the list of homeservers. // Returns a list of public rooms. func bulkFetchPublicRoomsFromServers( - ctx context.Context, fedClient *gomatrixserverlib.FederationClient, - origin gomatrixserverlib.ServerName, - homeservers map[gomatrixserverlib.ServerName]struct{}, -) (publicRooms []gomatrixserverlib.PublicRoom) { + ctx context.Context, fedClient *fclient.FederationClient, + origin spec.ServerName, + homeservers map[spec.ServerName]struct{}, +) (publicRooms []fclient.PublicRoom) { limit := 200 // follow pipeline semantics, see https://blog.golang.org/pipelines for more info. // goroutines send rooms to this channel - roomCh := make(chan gomatrixserverlib.PublicRoom, int(limit)) + roomCh := make(chan fclient.PublicRoom, int(limit)) // signalling channel to tell goroutines to stop sending rooms and quit done := make(chan bool) // signalling to say when we can close the room channel @@ -83,7 +84,7 @@ func bulkFetchPublicRoomsFromServers( // concurrently query for public rooms reqctx, reqcancel := context.WithTimeout(ctx, time.Second*5) for hs := range homeservers { - go func(homeserverDomain gomatrixserverlib.ServerName) { + go func(homeserverDomain spec.ServerName) { defer wg.Done() util.GetLogger(reqctx).WithField("hs", homeserverDomain).Info("Querying HS for public rooms") fres, err := fedClient.GetPublicRooms(reqctx, origin, homeserverDomain, int(limit), "", false, "") diff --git a/cmd/dendrite-demo-pinecone/users/users.go b/cmd/dendrite-demo-pinecone/users/users.go index fc66bf299..9c844ab67 100644 --- a/cmd/dendrite-demo-pinecone/users/users.go +++ b/cmd/dendrite-demo-pinecone/users/users.go @@ -27,7 +27,8 @@ import ( clienthttputil "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/defaults" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" pineconeRouter "github.com/matrix-org/pinecone/router" @@ -38,7 +39,7 @@ type PineconeUserProvider struct { r *pineconeRouter.Router s *pineconeSessions.Sessions userAPI userapi.QuerySearchProfilesAPI - fedClient *gomatrixserverlib.FederationClient + fedClient *fclient.FederationClient } const PublicURL = "/_matrix/p2p/profiles" @@ -47,7 +48,7 @@ func NewPineconeUserProvider( r *pineconeRouter.Router, s *pineconeSessions.Sessions, userAPI userapi.QuerySearchProfilesAPI, - fedClient *gomatrixserverlib.FederationClient, + fedClient *fclient.FederationClient, ) *PineconeUserProvider { p := &PineconeUserProvider{ r: r, @@ -79,12 +80,12 @@ func (p *PineconeUserProvider) FederatedUserProfiles(w http.ResponseWriter, r *h } func (p *PineconeUserProvider) QuerySearchProfiles(ctx context.Context, req *userapi.QuerySearchProfilesRequest, res *userapi.QuerySearchProfilesResponse) error { - list := map[gomatrixserverlib.ServerName]struct{}{} + list := map[spec.ServerName]struct{}{} for k := range defaults.DefaultServerNames { list[k] = struct{}{} } for _, k := range p.r.Peers() { - list[gomatrixserverlib.ServerName(k.PublicKey)] = struct{}{} + list[spec.ServerName(k.PublicKey)] = struct{}{} } res.Profiles = bulkFetchUserDirectoriesFromServers(context.Background(), req, p.fedClient, list) return nil @@ -94,8 +95,8 @@ func (p *PineconeUserProvider) QuerySearchProfiles(ctx context.Context, req *use // Returns a list of user profiles. func bulkFetchUserDirectoriesFromServers( ctx context.Context, req *userapi.QuerySearchProfilesRequest, - fedClient *gomatrixserverlib.FederationClient, - homeservers map[gomatrixserverlib.ServerName]struct{}, + fedClient *fclient.FederationClient, + homeservers map[spec.ServerName]struct{}, ) (profiles []authtypes.Profile) { jsonBody, err := json.Marshal(req) if err != nil { @@ -114,7 +115,7 @@ func bulkFetchUserDirectoriesFromServers( // concurrently query for public rooms reqctx, reqcancel := context.WithTimeout(ctx, time.Second*5) for hs := range homeservers { - go func(homeserverDomain gomatrixserverlib.ServerName) { + go func(homeserverDomain spec.ServerName) { defer wg.Done() util.GetLogger(reqctx).WithField("hs", homeserverDomain).Info("Querying HS for users") diff --git a/cmd/dendrite-demo-yggdrasil/main.go b/cmd/dendrite-demo-yggdrasil/main.go index d759c6a73..25c1475cb 100644 --- a/cmd/dendrite-demo-yggdrasil/main.go +++ b/cmd/dendrite-demo-yggdrasil/main.go @@ -27,7 +27,13 @@ import ( "path/filepath" "time" + "github.com/getsentry/sentry-go" + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/gorilla/mux" "github.com/matrix-org/dendrite/appservice" @@ -41,7 +47,7 @@ import ( "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/setup" - "github.com/matrix-org/dendrite/setup/base" + basepkg "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/mscs" "github.com/matrix-org/dendrite/test" @@ -57,6 +63,7 @@ var ( instanceDir = flag.String("dir", ".", "the directory to store the databases in (if --config not specified)") ) +// nolint: gocyclo func main() { flag.Parse() internal.SetupPprof() @@ -139,40 +146,86 @@ func main() { } } - cfg.Global.ServerName = gomatrixserverlib.ServerName(hex.EncodeToString(pk)) + cfg.Global.ServerName = spec.ServerName(hex.EncodeToString(pk)) cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID) - base := base.NewBaseDendrite(cfg) - base.ConfigureAdminEndpoints() - defer base.Close() // nolint: errcheck + configErrors := &config.ConfigErrors{} + cfg.Verify(configErrors) + if len(*configErrors) > 0 { + for _, err := range *configErrors { + logrus.Errorf("Configuration error: %s", err) + } + logrus.Fatalf("Failed to start due to configuration errors") + } + + internal.SetupStdLogging() + internal.SetupHookLogging(cfg.Logging) + internal.SetupPprof() + + logrus.Infof("Dendrite version %s", internal.VersionString()) + + if !cfg.ClientAPI.RegistrationDisabled && cfg.ClientAPI.OpenRegistrationWithoutVerificationEnabled { + logrus.Warn("Open registration is enabled") + } + + closer, err := cfg.SetupTracing() + if err != nil { + logrus.WithError(err).Panicf("failed to start opentracing") + } + defer closer.Close() // nolint: errcheck + + if cfg.Global.Sentry.Enabled { + logrus.Info("Setting up Sentry for debugging...") + err = sentry.Init(sentry.ClientOptions{ + Dsn: cfg.Global.Sentry.DSN, + Environment: cfg.Global.Sentry.Environment, + Debug: true, + ServerName: string(cfg.Global.ServerName), + Release: "dendrite@" + internal.VersionString(), + AttachStacktrace: true, + }) + if err != nil { + logrus.WithError(err).Panic("failed to start Sentry") + } + } + + processCtx := process.NewProcessContext() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + routers := httputil.NewRouters() + + basepkg.ConfigureAdminEndpoints(processCtx, routers) + defer func() { + processCtx.ShutdownDendrite() + processCtx.WaitForShutdown() + }() // nolint: errcheck ygg, err := yggconn.Setup(sk, *instanceName, ".", *instancePeer, *instanceListen) if err != nil { panic(err) } - federation := ygg.CreateFederationClient(base) + federation := ygg.CreateFederationClient(cfg) serverKeyAPI := &signing.YggdrasilKeys{} keyRing := serverKeyAPI.KeyRing() - rsAPI := roomserver.NewInternalAPI( - base, - ) + caches := caching.NewRistrettoCache(cfg.Global.Cache.EstimatedMaxSize, cfg.Global.Cache.MaxAge, caching.EnableMetrics) + natsInstance := jetstream.NATSInstance{} + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.EnableMetrics) - userAPI := userapi.NewInternalAPI(base, rsAPI, federation) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, federation) - asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI) + asAPI := appservice.NewInternalAPI(processCtx, cfg, &natsInstance, userAPI, rsAPI) rsAPI.SetAppserviceAPI(asAPI) fsAPI := federationapi.NewInternalAPI( - base, federation, rsAPI, base.Caches, keyRing, true, + processCtx, cfg, cm, &natsInstance, federation, rsAPI, caches, keyRing, true, ) rsAPI.SetFederationAPI(fsAPI, keyRing) monolith := setup.Monolith{ - Config: base.Cfg, - Client: ygg.CreateClient(base), + Config: cfg, + Client: ygg.CreateClient(), FedClient: federation, KeyRing: keyRing, @@ -184,21 +237,21 @@ func main() { ygg, fsAPI, federation, ), } - monolith.AddAllPublicRoutes(base) - if err := mscs.Enable(base, &monolith); err != nil { + monolith.AddAllPublicRoutes(processCtx, cfg, routers, cm, &natsInstance, caches, caching.EnableMetrics) + if err := mscs.Enable(cfg, cm, routers, &monolith, caches); err != nil { logrus.WithError(err).Fatalf("Failed to enable MSCs") } httpRouter := mux.NewRouter().SkipClean(true).UseEncodedPath() - httpRouter.PathPrefix(httputil.PublicClientPathPrefix).Handler(base.PublicClientAPIMux) - httpRouter.PathPrefix(httputil.PublicMediaPathPrefix).Handler(base.PublicMediaAPIMux) - httpRouter.PathPrefix(httputil.DendriteAdminPathPrefix).Handler(base.DendriteAdminMux) - httpRouter.PathPrefix(httputil.SynapseAdminPathPrefix).Handler(base.SynapseAdminMux) + httpRouter.PathPrefix(httputil.PublicClientPathPrefix).Handler(routers.Client) + httpRouter.PathPrefix(httputil.PublicMediaPathPrefix).Handler(routers.Media) + httpRouter.PathPrefix(httputil.DendriteAdminPathPrefix).Handler(routers.DendriteAdmin) + httpRouter.PathPrefix(httputil.SynapseAdminPathPrefix).Handler(routers.SynapseAdmin) embed.Embed(httpRouter, *instancePort, "Yggdrasil Demo") yggRouter := mux.NewRouter().SkipClean(true).UseEncodedPath() - yggRouter.PathPrefix(httputil.PublicFederationPathPrefix).Handler(base.PublicFederationAPIMux) - yggRouter.PathPrefix(httputil.PublicMediaPathPrefix).Handler(base.PublicMediaAPIMux) + yggRouter.PathPrefix(httputil.PublicFederationPathPrefix).Handler(routers.Federation) + yggRouter.PathPrefix(httputil.PublicMediaPathPrefix).Handler(routers.Media) // Build both ends of a HTTP multiplex. httpServer := &http.Server{ @@ -232,5 +285,5 @@ func main() { }() // We want to block forever to let the HTTP and HTTPS handler serve the APIs - base.WaitForShutdown() + basepkg.WaitForShutdown(processCtx) } diff --git a/cmd/dendrite-demo-yggdrasil/signing/fetcher.go b/cmd/dendrite-demo-yggdrasil/signing/fetcher.go index bcec0cbec..aeaa2022e 100644 --- a/cmd/dendrite-demo-yggdrasil/signing/fetcher.go +++ b/cmd/dendrite-demo-yggdrasil/signing/fetcher.go @@ -21,6 +21,7 @@ import ( "time" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const KeyID = "ed25519:dendrite-demo-yggdrasil" @@ -36,7 +37,7 @@ func (f *YggdrasilKeys) KeyRing() *gomatrixserverlib.KeyRing { func (f *YggdrasilKeys) FetchKeys( ctx context.Context, - requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, + requests map[gomatrixserverlib.PublicKeyLookupRequest]spec.Timestamp, ) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) { res := make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult) for req := range requests { @@ -54,7 +55,7 @@ func (f *YggdrasilKeys) FetchKeys( Key: hexkey, }, ExpiredTS: gomatrixserverlib.PublicKeyNotExpired, - ValidUntilTS: gomatrixserverlib.AsTimestamp(time.Now().Add(24 * time.Hour * 365)), + ValidUntilTS: spec.AsTimestamp(time.Now().Add(24 * time.Hour * 365)), } } return res, nil diff --git a/cmd/dendrite-demo-yggdrasil/yggconn/client.go b/cmd/dendrite-demo-yggdrasil/yggconn/client.go index 41a9ec123..c25acf2ec 100644 --- a/cmd/dendrite-demo-yggdrasil/yggconn/client.go +++ b/cmd/dendrite-demo-yggdrasil/yggconn/client.go @@ -4,8 +4,8 @@ import ( "net/http" "time" - "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib/fclient" ) type yggroundtripper struct { @@ -17,9 +17,7 @@ func (y *yggroundtripper) RoundTrip(req *http.Request) (*http.Response, error) { return y.inner.RoundTrip(req) } -func (n *Node) CreateClient( - base *base.BaseDendrite, -) *gomatrixserverlib.Client { +func (n *Node) CreateClient() *fclient.Client { tr := &http.Transport{} tr.RegisterProtocol( "matrix", &yggroundtripper{ @@ -33,14 +31,14 @@ func (n *Node) CreateClient( }, }, ) - return gomatrixserverlib.NewClient( - gomatrixserverlib.WithTransport(tr), + return fclient.NewClient( + fclient.WithTransport(tr), ) } func (n *Node) CreateFederationClient( - base *base.BaseDendrite, -) *gomatrixserverlib.FederationClient { + cfg *config.Dendrite, +) *fclient.FederationClient { tr := &http.Transport{} tr.RegisterProtocol( "matrix", &yggroundtripper{ @@ -54,8 +52,8 @@ func (n *Node) CreateFederationClient( }, }, ) - return gomatrixserverlib.NewFederationClient( - base.Cfg.Global.SigningIdentities(), - gomatrixserverlib.WithTransport(tr), + return fclient.NewFederationClient( + cfg.Global.SigningIdentities(), + fclient.WithTransport(tr), ) } diff --git a/cmd/dendrite-demo-yggdrasil/yggconn/node.go b/cmd/dendrite-demo-yggdrasil/yggconn/node.go index 6df5fa879..26c30e489 100644 --- a/cmd/dendrite-demo-yggdrasil/yggconn/node.go +++ b/cmd/dendrite-demo-yggdrasil/yggconn/node.go @@ -23,7 +23,7 @@ import ( "regexp" "strings" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/neilalexander/utp" "github.com/sirupsen/logrus" @@ -134,14 +134,14 @@ func (n *Node) PeerCount() int { return len(n.core.GetPeers()) } -func (n *Node) KnownNodes() []gomatrixserverlib.ServerName { +func (n *Node) KnownNodes() []spec.ServerName { nodemap := map[string]struct{}{} for _, peer := range n.core.GetPeers() { nodemap[hex.EncodeToString(peer.Key)] = struct{}{} } - var nodes []gomatrixserverlib.ServerName + var nodes []spec.ServerName for node := range nodemap { - nodes = append(nodes, gomatrixserverlib.ServerName(node)) + nodes = append(nodes, spec.ServerName(node)) } return nodes } diff --git a/cmd/dendrite-demo-yggdrasil/yggrooms/yggrooms.go b/cmd/dendrite-demo-yggdrasil/yggrooms/yggrooms.go index 0de64755e..f838dea0b 100644 --- a/cmd/dendrite-demo-yggdrasil/yggrooms/yggrooms.go +++ b/cmd/dendrite-demo-yggdrasil/yggrooms/yggrooms.go @@ -21,18 +21,19 @@ import ( "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/yggconn" "github.com/matrix-org/dendrite/federationapi/api" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) type YggdrasilRoomProvider struct { node *yggconn.Node fedSender api.FederationInternalAPI - fedClient *gomatrixserverlib.FederationClient + fedClient *fclient.FederationClient } func NewYggdrasilRoomProvider( - node *yggconn.Node, fedSender api.FederationInternalAPI, fedClient *gomatrixserverlib.FederationClient, + node *yggconn.Node, fedSender api.FederationInternalAPI, fedClient *fclient.FederationClient, ) *YggdrasilRoomProvider { p := &YggdrasilRoomProvider{ node: node, @@ -42,10 +43,10 @@ func NewYggdrasilRoomProvider( return p } -func (p *YggdrasilRoomProvider) Rooms() []gomatrixserverlib.PublicRoom { +func (p *YggdrasilRoomProvider) Rooms() []fclient.PublicRoom { return bulkFetchPublicRoomsFromServers( context.Background(), p.fedClient, - gomatrixserverlib.ServerName(p.node.DerivedServerName()), + spec.ServerName(p.node.DerivedServerName()), p.node.KnownNodes(), ) } @@ -53,14 +54,14 @@ func (p *YggdrasilRoomProvider) Rooms() []gomatrixserverlib.PublicRoom { // bulkFetchPublicRoomsFromServers fetches public rooms from the list of homeservers. // Returns a list of public rooms. func bulkFetchPublicRoomsFromServers( - ctx context.Context, fedClient *gomatrixserverlib.FederationClient, - origin gomatrixserverlib.ServerName, - homeservers []gomatrixserverlib.ServerName, -) (publicRooms []gomatrixserverlib.PublicRoom) { + ctx context.Context, fedClient *fclient.FederationClient, + origin spec.ServerName, + homeservers []spec.ServerName, +) (publicRooms []fclient.PublicRoom) { limit := 200 // follow pipeline semantics, see https://blog.golang.org/pipelines for more info. // goroutines send rooms to this channel - roomCh := make(chan gomatrixserverlib.PublicRoom, int(limit)) + roomCh := make(chan fclient.PublicRoom, int(limit)) // signalling channel to tell goroutines to stop sending rooms and quit done := make(chan bool) // signalling to say when we can close the room channel @@ -68,7 +69,7 @@ func bulkFetchPublicRoomsFromServers( wg.Add(len(homeservers)) // concurrently query for public rooms for _, hs := range homeservers { - go func(homeserverDomain gomatrixserverlib.ServerName) { + go func(homeserverDomain spec.ServerName) { defer wg.Done() util.GetLogger(ctx).WithField("hs", homeserverDomain).Info("Querying HS for public rooms") fres, err := fedClient.GetPublicRooms(ctx, origin, homeserverDomain, int(limit), "", false, "") diff --git a/cmd/dendrite-upgrade-tests/main.go b/cmd/dendrite-upgrade-tests/main.go index 174a80a3e..6a0e21799 100644 --- a/cmd/dendrite-upgrade-tests/main.go +++ b/cmd/dendrite-upgrade-tests/main.go @@ -259,10 +259,20 @@ func buildDendrite(httpClient *http.Client, dockerClient *client.Client, tmpDir func getAndSortVersionsFromGithub(httpClient *http.Client) (semVers []*semver.Version, err error) { u := "https://api.github.com/repos/matrix-org/dendrite/tags" - res, err := httpClient.Get(u) - if err != nil { - return nil, err + + var res *http.Response + for i := 0; i < 3; i++ { + res, err = httpClient.Get(u) + if err != nil { + return nil, err + } + if res.StatusCode == 200 { + break + } + log.Printf("Github API returned HTTP %d, retrying\n", res.StatusCode) + time.Sleep(time.Second * 5) } + if res.StatusCode != 200 { return nil, fmt.Errorf("%s returned HTTP %d", u, res.StatusCode) } diff --git a/cmd/dendrite/main.go b/cmd/dendrite/main.go index 1ae348cfa..66eb88f87 100644 --- a/cmd/dendrite/main.go +++ b/cmd/dendrite/main.go @@ -16,8 +16,16 @@ package main import ( "flag" - "io/fs" + "time" + "github.com/getsentry/sentry-go" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/gomatrixserverlib/fclient" "github.com/sirupsen/logrus" "github.com/matrix-org/dendrite/appservice" @@ -34,8 +42,8 @@ var ( unixSocket = flag.String("unix-socket", "", "EXPERIMENTAL(unstable): The HTTP listening unix socket for the server (disables http[s]-bind-address feature)", ) - unixSocketPermission = flag.Int("unix-socket-permission", 0755, - "EXPERIMENTAL(unstable): The HTTP listening unix socket permission for the server", + unixSocketPermission = flag.String("unix-socket-permission", "755", + "EXPERIMENTAL(unstable): The HTTP listening unix socket permission for the server (in chmod format like 755)", ) httpBindAddr = flag.String("http-bind-address", ":8008", "The HTTP listening port for the server") httpsBindAddr = flag.String("https-bind-address", ":8448", "The HTTPS listening port for the server") @@ -59,27 +67,97 @@ func main() { } httpsAddr = https } else { - httpAddr = config.UnixSocketAddress(*unixSocket, fs.FileMode(*unixSocketPermission)) + socket, err := config.UnixSocketAddress(*unixSocket, *unixSocketPermission) + if err != nil { + logrus.WithError(err).Fatalf("Failed to parse unix socket") + } + httpAddr = socket } - options := []basepkg.BaseDendriteOptions{} + configErrors := &config.ConfigErrors{} + cfg.Verify(configErrors) + if len(*configErrors) > 0 { + for _, err := range *configErrors { + logrus.Errorf("Configuration error: %s", err) + } + logrus.Fatalf("Failed to start due to configuration errors") + } + processCtx := process.NewProcessContext() - base := basepkg.NewBaseDendrite(cfg, options...) - defer base.Close() // nolint: errcheck + internal.SetupStdLogging() + internal.SetupHookLogging(cfg.Logging) + internal.SetupPprof() - federation := base.CreateFederationClient() + basepkg.PlatformSanityChecks() - rsAPI := roomserver.NewInternalAPI(base) + logrus.Infof("Dendrite version %s", internal.VersionString()) + if !cfg.ClientAPI.RegistrationDisabled && cfg.ClientAPI.OpenRegistrationWithoutVerificationEnabled { + logrus.Warn("Open registration is enabled") + } + // create DNS cache + var dnsCache *fclient.DNSCache + if cfg.Global.DNSCache.Enabled { + dnsCache = fclient.NewDNSCache( + cfg.Global.DNSCache.CacheSize, + cfg.Global.DNSCache.CacheLifetime, + ) + logrus.Infof( + "DNS cache enabled (size %d, lifetime %s)", + cfg.Global.DNSCache.CacheSize, + cfg.Global.DNSCache.CacheLifetime, + ) + } + + // setup tracing + closer, err := cfg.SetupTracing() + if err != nil { + logrus.WithError(err).Panicf("failed to start opentracing") + } + defer closer.Close() // nolint: errcheck + + // setup sentry + if cfg.Global.Sentry.Enabled { + logrus.Info("Setting up Sentry for debugging...") + err = sentry.Init(sentry.ClientOptions{ + Dsn: cfg.Global.Sentry.DSN, + Environment: cfg.Global.Sentry.Environment, + Debug: true, + ServerName: string(cfg.Global.ServerName), + Release: "dendrite@" + internal.VersionString(), + AttachStacktrace: true, + }) + if err != nil { + logrus.WithError(err).Panic("failed to start Sentry") + } + go func() { + processCtx.ComponentStarted() + <-processCtx.WaitForShutdown() + if !sentry.Flush(time.Second * 5) { + logrus.Warnf("failed to flush all Sentry events!") + } + processCtx.ComponentFinished() + }() + } + + federationClient := basepkg.CreateFederationClient(cfg, dnsCache) + httpClient := basepkg.CreateClient(cfg, dnsCache) + + // prepare required dependencies + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + routers := httputil.NewRouters() + + caches := caching.NewRistrettoCache(cfg.Global.Cache.EstimatedMaxSize, cfg.Global.Cache.MaxAge, caching.EnableMetrics) + natsInstance := jetstream.NATSInstance{} + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.EnableMetrics) fsAPI := federationapi.NewInternalAPI( - base, federation, rsAPI, base.Caches, nil, false, + processCtx, cfg, cm, &natsInstance, federationClient, rsAPI, caches, nil, false, ) keyRing := fsAPI.KeyRing() - userAPI := userapi.NewInternalAPI(base, rsAPI, federation) - - asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, federationClient) + asAPI := appservice.NewInternalAPI(processCtx, cfg, &natsInstance, userAPI, rsAPI) // The underlying roomserver implementation needs to be able to call the fedsender. // This is different to rsAPI which can be the http client which doesn't need this @@ -89,9 +167,9 @@ func main() { rsAPI.SetUserAPI(userAPI) monolith := setup.Monolith{ - Config: base.Cfg, - Client: base.CreateClient(), - FedClient: federation, + Config: cfg, + Client: httpClient, + FedClient: federationClient, KeyRing: keyRing, AppserviceAPI: asAPI, @@ -101,25 +179,25 @@ func main() { RoomserverAPI: rsAPI, UserAPI: userAPI, } - monolith.AddAllPublicRoutes(base) + monolith.AddAllPublicRoutes(processCtx, cfg, routers, cm, &natsInstance, caches, caching.EnableMetrics) - if len(base.Cfg.MSCs.MSCs) > 0 { - if err := mscs.Enable(base, &monolith); err != nil { + if len(cfg.MSCs.MSCs) > 0 { + if err := mscs.Enable(cfg, cm, routers, &monolith, caches); err != nil { logrus.WithError(err).Fatalf("Failed to enable MSCs") } } // Expose the matrix APIs directly rather than putting them under a /api path. go func() { - base.SetupAndServeHTTP(httpAddr, nil, nil) + basepkg.SetupAndServeHTTP(processCtx, cfg, routers, httpAddr, nil, nil) }() // Handle HTTPS if certificate and key are provided if *unixSocket == "" && *certFile != "" && *keyFile != "" { go func() { - base.SetupAndServeHTTP(httpsAddr, certFile, keyFile) + basepkg.SetupAndServeHTTP(processCtx, cfg, routers, httpsAddr, certFile, keyFile) }() } // We want to block forever to let the HTTP and HTTPS handler serve the APIs - base.WaitForShutdown() + basepkg.WaitForShutdown(processCtx) } diff --git a/cmd/furl/main.go b/cmd/furl/main.go index b208ba868..cdfef09f7 100644 --- a/cmd/furl/main.go +++ b/cmd/furl/main.go @@ -13,6 +13,8 @@ import ( "os" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" ) var requestFrom = flag.String("from", "", "the server name that the request should originate from") @@ -48,9 +50,9 @@ func main() { panic("unexpected key block") } - serverName := gomatrixserverlib.ServerName(*requestFrom) - client := gomatrixserverlib.NewFederationClient( - []*gomatrixserverlib.SigningIdentity{ + serverName := spec.ServerName(*requestFrom) + client := fclient.NewFederationClient( + []*fclient.SigningIdentity{ { ServerName: serverName, KeyID: gomatrixserverlib.KeyID(keyBlock.Headers["Key-ID"]), @@ -82,10 +84,10 @@ func main() { } } - req := gomatrixserverlib.NewFederationRequest( + req := fclient.NewFederationRequest( method, serverName, - gomatrixserverlib.ServerName(u.Host), + spec.ServerName(u.Host), u.RequestURI(), ) @@ -96,7 +98,7 @@ func main() { } if err = req.Sign( - gomatrixserverlib.ServerName(*requestFrom), + spec.ServerName(*requestFrom), gomatrixserverlib.KeyID(keyBlock.Headers["Key-ID"]), privateKey, ); err != nil { diff --git a/cmd/generate-config/main.go b/cmd/generate-config/main.go index b0707be11..cb57ed78f 100644 --- a/cmd/generate-config/main.go +++ b/cmd/generate-config/main.go @@ -5,11 +5,11 @@ import ( "fmt" "path/filepath" - "github.com/matrix-org/gomatrixserverlib" "golang.org/x/crypto/bcrypt" "gopkg.in/yaml.v2" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib/spec" ) func main() { @@ -30,7 +30,7 @@ func main() { SingleDatabase: true, }) if *serverName != "" { - cfg.Global.ServerName = gomatrixserverlib.ServerName(*serverName) + cfg.Global.ServerName = spec.ServerName(*serverName) } uri := config.DataSource(*dbURI) if uri.IsSQLite() || uri == "" { diff --git a/cmd/resolve-state/main.go b/cmd/resolve-state/main.go index a9cc80cb7..09c0e6907 100644 --- a/cmd/resolve-state/main.go +++ b/cmd/resolve-state/main.go @@ -10,12 +10,13 @@ import ( "time" "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/gomatrixserverlib" ) @@ -40,7 +41,7 @@ func main() { Level: "error", }) cfg.ClientAPI.RegistrationDisabled = true - base := base.NewBaseDendrite(cfg, base.DisableMetrics) + args := flag.Args() fmt.Println("Room version", *roomVersion) @@ -54,8 +55,10 @@ func main() { fmt.Println("Fetching", len(snapshotNIDs), "snapshot NIDs") + processCtx := process.NewProcessContext() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) roomserverDB, err := storage.Open( - base, &cfg.RoomServer.Database, + processCtx.Context(), cm, &cfg.RoomServer.Database, caching.NewRistrettoCache(128*1024*1024, time.Hour, true), ) if err != nil { diff --git a/dendrite-sample.yaml b/dendrite-sample.yaml index 39be8d8b8..6b3ea74f2 100644 --- a/dendrite-sample.yaml +++ b/dendrite-sample.yaml @@ -95,7 +95,7 @@ global: # We use this information to understand how Dendrite is being used in the wild. report_stats: enabled: false - endpoint: https://matrix.org/report-usage-stats/push + endpoint: https://panopticon.matrix.org/push # Server notices allows server admins to send messages to all users on the server. server_notices: diff --git a/docs/FAQ.md b/docs/FAQ.md index ecf43d693..f1a14c613 100644 --- a/docs/FAQ.md +++ b/docs/FAQ.md @@ -18,7 +18,7 @@ Mostly, although there are still bugs and missing features. If you are a confide ## Is Dendrite feature-complete? -No, although a good portion of the Matrix specification has been implemented. Mostly missing are client features - see the [readme](../README.md) at the root of the repository for more information. +No, although a good portion of the Matrix specification has been implemented. Mostly missing are client features - see the [readme](https://github.com/matrix-org/dendrite/blob/main/README.md) at the root of the repository for more information. ## Why doesn't Dendrite have "x" yet? diff --git a/docs/Gemfile.lock b/docs/Gemfile.lock index a61786c1d..0901965c1 100644 --- a/docs/Gemfile.lock +++ b/docs/Gemfile.lock @@ -14,7 +14,7 @@ GEM execjs coffee-script-source (1.11.1) colorator (1.1.0) - commonmarker (0.23.7) + commonmarker (0.23.9) concurrent-ruby (1.2.0) dnsruby (1.61.9) simpleidn (~> 0.1) @@ -231,9 +231,9 @@ GEM jekyll-seo-tag (~> 2.1) minitest (5.17.0) multipart-post (2.1.1) - nokogiri (1.13.10-arm64-darwin) + nokogiri (1.14.3-arm64-darwin) racc (~> 1.4) - nokogiri (1.13.10-x86_64-linux) + nokogiri (1.14.3-x86_64-linux) racc (~> 1.4) octokit (4.22.0) faraday (>= 0.9) @@ -241,7 +241,7 @@ GEM pathutil (0.16.2) forwardable-extended (~> 2.6) public_suffix (4.0.7) - racc (1.6.1) + racc (1.6.2) rb-fsevent (0.11.1) rb-inotify (0.10.1) ffi (~> 1.0) diff --git a/docs/administration/4_adminapi.md b/docs/administration/4_adminapi.md index 46cfac220..b11aeb1a6 100644 --- a/docs/administration/4_adminapi.md +++ b/docs/administration/4_adminapi.md @@ -32,7 +32,7 @@ UPDATE userapi_accounts SET account_type = 3 WHERE localpart = '$localpart'; Where `$localpart` is the username only (e.g. `alice`). -## GET `/_dendrite/admin/evacuateRoom/{roomID}` +## POST `/_dendrite/admin/evacuateRoom/{roomID}` This endpoint will instruct Dendrite to part all local users from the given `roomID` in the URL. It may take some time to complete. A JSON body will be returned containing @@ -41,7 +41,7 @@ the user IDs of all affected users. If the room has an alias set (e.g. is published), the room's ID will not be visible in the URL, but it can be found as the room's "internal ID" in Element Web (Settings -> Advanced) -## GET `/_dendrite/admin/evacuateUser/{userID}` +## POST `/_dendrite/admin/evacuateUser/{userID}` This endpoint will instruct Dendrite to part the given local `userID` in the URL from all rooms which they are currently joined. A JSON body will be returned containing diff --git a/docs/development/sytest.md b/docs/development/sytest.md index 3cfb99e60..4fae2ea3d 100644 --- a/docs/development/sytest.md +++ b/docs/development/sytest.md @@ -61,7 +61,6 @@ When debugging, the following Docker `run` options may also be useful: * `-e "DENDRITE_TRACE_HTTP=1"`: Adds HTTP tracing to server logs. * `-e "DENDRITE_TRACE_INTERNAL=1"`: Adds roomserver internal API tracing to server logs. -* `-e "DENDRITE_TRACE_SQL=1"`: Adds tracing to all SQL statements to server logs. The docker command also supports a single positional argument for the test file to run, so you can run a single `.pl` file rather than the whole test suite. For example: diff --git a/docs/installation/7_configuration.md b/docs/installation/7_configuration.md index 5f123bfca..0cc67b156 100644 --- a/docs/installation/7_configuration.md +++ b/docs/installation/7_configuration.md @@ -10,7 +10,7 @@ permalink: /installation/configuration A YAML configuration file is used to configure Dendrite. A sample configuration file is present in the top level of the Dendrite repository: -* [`dendrite-sample.monolith.yaml`](https://github.com/matrix-org/dendrite/blob/main/dendrite-sample.monolith.yaml) +* [`dendrite-sample.yaml`](https://github.com/matrix-org/dendrite/blob/main/dendrite-sample.yaml) You will need to duplicate the sample, calling it `dendrite.yaml` for example, and then tailor it to your installation. At a minimum, you will need to populate the following @@ -86,9 +86,8 @@ that you chose. ### Global connection pool -If you are running a monolith deployment and want to use a single connection pool to a -single PostgreSQL database, then you must uncomment and configure the `database` section -within the `global` section: +If you want to use a single connection pool to a single PostgreSQL database, then you must +uncomment and configure the `database` section within the `global` section: ```yaml global: @@ -102,15 +101,15 @@ global: **You must then remove or comment out** the `database` sections from other areas of the configuration file, e.g. under the `app_service_api`, `federation_api`, `key_server`, -`media_api`, `mscs`, `room_server`, `sync_api` and `user_api` blocks, otherwise these will -override the `global` database configuration. +`media_api`, `mscs`, `relay_api`, `room_server`, `sync_api` and `user_api` blocks, otherwise +these will override the `global` database configuration. ### Per-component connections (all other configurations) If you are are using SQLite databases or separate PostgreSQL databases per component, then you must instead configure the `database` sections under each of the component blocks ,e.g. under the `app_service_api`, `federation_api`, `key_server`, -`media_api`, `mscs`, `room_server`, `sync_api` and `user_api` blocks. +`media_api`, `mscs`, `relay_api`, `room_server`, `sync_api` and `user_api` blocks. For example, with PostgreSQL: diff --git a/federationapi/api/api.go b/federationapi/api/api.go index e4c0b2714..6b09b0cd6 100644 --- a/federationapi/api/api.go +++ b/federationapi/api/api.go @@ -7,6 +7,8 @@ import ( "github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/federationapi/types" ) @@ -21,9 +23,9 @@ type FederationInternalAPI interface { P2PFederationAPI QueryServerKeys(ctx context.Context, request *QueryServerKeysRequest, response *QueryServerKeysResponse) error - LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error) - MSC2836EventRelationships(ctx context.Context, origin, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) - MSC2946Spaces(ctx context.Context, origin, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res gomatrixserverlib.MSC2946SpacesResponse, err error) + LookupServerKeys(ctx context.Context, s spec.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]spec.Timestamp) ([]gomatrixserverlib.ServerKeys, error) + MSC2836EventRelationships(ctx context.Context, origin, dst spec.ServerName, r fclient.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res fclient.MSC2836EventRelationshipsResponse, err error) + MSC2946Spaces(ctx context.Context, origin, dst spec.ServerName, roomID string, suggestedOnly bool) (res fclient.MSC2946SpacesResponse, err error) // Broadcasts an EDU to all servers in rooms we are joined to. Used in the yggdrasil demos. PerformBroadcastEDU( @@ -66,9 +68,9 @@ type RoomserverFederationAPI interface { // containing only the server names (without information for membership events). // The response will include this server if they are joined to the room. QueryJoinedHostServerNamesInRoom(ctx context.Context, request *QueryJoinedHostServerNamesInRoomRequest, response *QueryJoinedHostServerNamesInRoomResponse) error - GetEventAuth(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res gomatrixserverlib.RespEventAuth, err error) - GetEvent(ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) - LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) + GetEventAuth(ctx context.Context, origin, s spec.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res fclient.RespEventAuth, err error) + GetEvent(ctx context.Context, origin, s spec.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) + LookupMissingEvents(ctx context.Context, origin, s spec.ServerName, roomID string, missing fclient.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res fclient.RespMissingEvents, err error) } type P2PFederationAPI interface { @@ -98,45 +100,45 @@ type P2PFederationAPI interface { // implements as proxy calls, with built-in backoff/retries/etc. Errors returned from functions in // this interface are of type FederationClientError type KeyserverFederationAPI interface { - GetUserDevices(ctx context.Context, origin, s gomatrixserverlib.ServerName, userID string) (res gomatrixserverlib.RespUserDevices, err error) - ClaimKeys(ctx context.Context, origin, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (res gomatrixserverlib.RespClaimKeys, err error) - QueryKeys(ctx context.Context, origin, s gomatrixserverlib.ServerName, keys map[string][]string) (res gomatrixserverlib.RespQueryKeys, err error) + GetUserDevices(ctx context.Context, origin, s spec.ServerName, userID string) (res fclient.RespUserDevices, err error) + ClaimKeys(ctx context.Context, origin, s spec.ServerName, oneTimeKeys map[string]map[string]string) (res fclient.RespClaimKeys, err error) + QueryKeys(ctx context.Context, origin, s spec.ServerName, keys map[string][]string) (res fclient.RespQueryKeys, err error) } // an interface for gmsl.FederationClient - contains functions called by federationapi only. type FederationClient interface { P2PFederationClient gomatrixserverlib.KeyClient - SendTransaction(ctx context.Context, t gomatrixserverlib.Transaction) (res gomatrixserverlib.RespSend, err error) + SendTransaction(ctx context.Context, t gomatrixserverlib.Transaction) (res fclient.RespSend, err error) // Perform operations - LookupRoomAlias(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomAlias string) (res gomatrixserverlib.RespDirectory, err error) - Peek(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, peekID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespPeek, err error) - MakeJoin(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, userID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMakeJoin, err error) - SendJoin(ctx context.Context, origin, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (res gomatrixserverlib.RespSendJoin, err error) - MakeLeave(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, userID string) (res gomatrixserverlib.RespMakeLeave, err error) - SendLeave(ctx context.Context, origin, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (err error) - SendInviteV2(ctx context.Context, origin, s gomatrixserverlib.ServerName, request gomatrixserverlib.InviteV2Request) (res gomatrixserverlib.RespInviteV2, err error) + LookupRoomAlias(ctx context.Context, origin, s spec.ServerName, roomAlias string) (res fclient.RespDirectory, err error) + Peek(ctx context.Context, origin, s spec.ServerName, roomID, peekID string, roomVersions []gomatrixserverlib.RoomVersion) (res fclient.RespPeek, err error) + MakeJoin(ctx context.Context, origin, s spec.ServerName, roomID, userID string, roomVersions []gomatrixserverlib.RoomVersion) (res fclient.RespMakeJoin, err error) + SendJoin(ctx context.Context, origin, s spec.ServerName, event *gomatrixserverlib.Event) (res fclient.RespSendJoin, err error) + MakeLeave(ctx context.Context, origin, s spec.ServerName, roomID, userID string) (res fclient.RespMakeLeave, err error) + SendLeave(ctx context.Context, origin, s spec.ServerName, event *gomatrixserverlib.Event) (err error) + SendInviteV2(ctx context.Context, origin, s spec.ServerName, request fclient.InviteV2Request) (res fclient.RespInviteV2, err error) - GetEvent(ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) + GetEvent(ctx context.Context, origin, s spec.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) - GetEventAuth(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res gomatrixserverlib.RespEventAuth, err error) - GetUserDevices(ctx context.Context, origin, s gomatrixserverlib.ServerName, userID string) (gomatrixserverlib.RespUserDevices, error) - ClaimKeys(ctx context.Context, origin, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (gomatrixserverlib.RespClaimKeys, error) - QueryKeys(ctx context.Context, origin, s gomatrixserverlib.ServerName, keys map[string][]string) (gomatrixserverlib.RespQueryKeys, error) - Backfill(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string) (res gomatrixserverlib.Transaction, err error) - MSC2836EventRelationships(ctx context.Context, origin, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) - MSC2946Spaces(ctx context.Context, origin, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res gomatrixserverlib.MSC2946SpacesResponse, err error) + GetEventAuth(ctx context.Context, origin, s spec.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res fclient.RespEventAuth, err error) + GetUserDevices(ctx context.Context, origin, s spec.ServerName, userID string) (fclient.RespUserDevices, error) + ClaimKeys(ctx context.Context, origin, s spec.ServerName, oneTimeKeys map[string]map[string]string) (fclient.RespClaimKeys, error) + QueryKeys(ctx context.Context, origin, s spec.ServerName, keys map[string][]string) (fclient.RespQueryKeys, error) + Backfill(ctx context.Context, origin, s spec.ServerName, roomID string, limit int, eventIDs []string) (res gomatrixserverlib.Transaction, err error) + MSC2836EventRelationships(ctx context.Context, origin, dst spec.ServerName, r fclient.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res fclient.MSC2836EventRelationshipsResponse, err error) + MSC2946Spaces(ctx context.Context, origin, dst spec.ServerName, roomID string, suggestedOnly bool) (res fclient.MSC2946SpacesResponse, err error) - ExchangeThirdPartyInvite(ctx context.Context, origin, s gomatrixserverlib.ServerName, builder gomatrixserverlib.EventBuilder) (err error) - LookupState(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespState, err error) - LookupStateIDs(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error) - LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) + ExchangeThirdPartyInvite(ctx context.Context, origin, s spec.ServerName, builder gomatrixserverlib.EventBuilder) (err error) + LookupState(ctx context.Context, origin, s spec.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) (res fclient.RespState, err error) + LookupStateIDs(ctx context.Context, origin, s spec.ServerName, roomID string, eventID string) (res fclient.RespStateIDs, err error) + LookupMissingEvents(ctx context.Context, origin, s spec.ServerName, roomID string, missing fclient.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res fclient.RespMissingEvents, err error) } type P2PFederationClient interface { - P2PSendTransactionToRelay(ctx context.Context, u gomatrixserverlib.UserID, t gomatrixserverlib.Transaction, forwardingServer gomatrixserverlib.ServerName) (res gomatrixserverlib.EmptyResp, err error) - P2PGetTransactionFromRelay(ctx context.Context, u gomatrixserverlib.UserID, prev gomatrixserverlib.RelayEntry, relayServer gomatrixserverlib.ServerName) (res gomatrixserverlib.RespGetRelayTransaction, err error) + P2PSendTransactionToRelay(ctx context.Context, u spec.UserID, t gomatrixserverlib.Transaction, forwardingServer spec.ServerName) (res fclient.EmptyResp, err error) + P2PGetTransactionFromRelay(ctx context.Context, u spec.UserID, prev fclient.RelayEntry, relayServer spec.ServerName) (res fclient.RespGetRelayTransaction, err error) } // FederationClientError is returned from FederationClient methods in the event of a problem. @@ -152,7 +154,7 @@ func (e FederationClientError) Error() string { } type QueryServerKeysRequest struct { - ServerName gomatrixserverlib.ServerName + ServerName spec.ServerName KeyIDToCriteria map[gomatrixserverlib.KeyID]gomatrixserverlib.PublicKeyNotaryQueryCriteria } @@ -171,7 +173,7 @@ type QueryServerKeysResponse struct { } type QueryPublicKeysRequest struct { - Requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp `json:"requests"` + Requests map[gomatrixserverlib.PublicKeyLookupRequest]spec.Timestamp `json:"requests"` } type QueryPublicKeysResponse struct { @@ -179,13 +181,13 @@ type QueryPublicKeysResponse struct { } type PerformDirectoryLookupRequest struct { - RoomAlias string `json:"room_alias"` - ServerName gomatrixserverlib.ServerName `json:"server_name"` + RoomAlias string `json:"room_alias"` + ServerName spec.ServerName `json:"server_name"` } type PerformDirectoryLookupResponse struct { - RoomID string `json:"room_id"` - ServerNames []gomatrixserverlib.ServerName `json:"server_names"` + RoomID string `json:"room_id"` + ServerNames []spec.ServerName `json:"server_names"` } type PerformJoinRequest struct { @@ -198,7 +200,7 @@ type PerformJoinRequest struct { } type PerformJoinResponse struct { - JoinedVia gomatrixserverlib.ServerName + JoinedVia spec.ServerName LastError *gomatrix.HTTPError } @@ -222,9 +224,9 @@ type PerformLeaveResponse struct { } type PerformInviteRequest struct { - RoomVersion gomatrixserverlib.RoomVersion `json:"room_version"` - Event *gomatrixserverlib.HeaderedEvent `json:"event"` - InviteRoomState []gomatrixserverlib.InviteV2StrippedState `json:"invite_room_state"` + RoomVersion gomatrixserverlib.RoomVersion `json:"room_version"` + Event *gomatrixserverlib.HeaderedEvent `json:"event"` + InviteRoomState []fclient.InviteV2StrippedState `json:"invite_room_state"` } type PerformInviteResponse struct { @@ -240,7 +242,7 @@ type QueryJoinedHostServerNamesInRoomRequest struct { // QueryJoinedHostServerNamesInRoomResponse is a response to QueryJoinedHostServerNames type QueryJoinedHostServerNamesInRoomResponse struct { - ServerNames []gomatrixserverlib.ServerName `json:"server_names"` + ServerNames []spec.ServerName `json:"server_names"` } type PerformBroadcastEDURequest struct { @@ -250,7 +252,7 @@ type PerformBroadcastEDUResponse struct { } type PerformWakeupServersRequest struct { - ServerNames []gomatrixserverlib.ServerName `json:"server_names"` + ServerNames []spec.ServerName `json:"server_names"` } type PerformWakeupServersResponse struct { @@ -264,24 +266,24 @@ type InputPublicKeysResponse struct { } type P2PQueryRelayServersRequest struct { - Server gomatrixserverlib.ServerName + Server spec.ServerName } type P2PQueryRelayServersResponse struct { - RelayServers []gomatrixserverlib.ServerName + RelayServers []spec.ServerName } type P2PAddRelayServersRequest struct { - Server gomatrixserverlib.ServerName - RelayServers []gomatrixserverlib.ServerName + Server spec.ServerName + RelayServers []spec.ServerName } type P2PAddRelayServersResponse struct { } type P2PRemoveRelayServersRequest struct { - Server gomatrixserverlib.ServerName - RelayServers []gomatrixserverlib.ServerName + Server spec.ServerName + RelayServers []spec.ServerName } type P2PRemoveRelayServersResponse struct { diff --git a/federationapi/api/servers.go b/federationapi/api/servers.go index 6bb15763d..ff4dc6c99 100644 --- a/federationapi/api/servers.go +++ b/federationapi/api/servers.go @@ -4,8 +4,9 @@ import ( "context" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) type ServersInRoomProvider interface { - GetServersForRoom(ctx context.Context, roomID string, event *gomatrixserverlib.Event) []gomatrixserverlib.ServerName + GetServersForRoom(ctx context.Context, roomID string, event *gomatrixserverlib.Event) []spec.ServerName } diff --git a/federationapi/consumers/keychange.go b/federationapi/consumers/keychange.go index 7d9df3d78..3fdc835bb 100644 --- a/federationapi/consumers/keychange.go +++ b/federationapi/consumers/keychange.go @@ -20,6 +20,7 @@ import ( "github.com/getsentry/sentry-go" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/nats-io/nats.go" "github.com/sirupsen/logrus" @@ -40,7 +41,7 @@ type KeyChangeConsumer struct { durable string db storage.Database queues *queue.OutgoingQueues - isLocalServerName func(gomatrixserverlib.ServerName) bool + isLocalServerName func(spec.ServerName) bool rsAPI roomserverAPI.FederationRoomserverAPI topic string } @@ -140,7 +141,7 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) bool { } // Pack the EDU and marshal it edu := &gomatrixserverlib.EDU{ - Type: gomatrixserverlib.MDeviceListUpdate, + Type: spec.MDeviceListUpdate, Origin: string(originServerName), } event := gomatrixserverlib.DeviceListUpdateEvent{ diff --git a/federationapi/consumers/presence.go b/federationapi/consumers/presence.go index 29b16f373..e751b65d4 100644 --- a/federationapi/consumers/presence.go +++ b/federationapi/consumers/presence.go @@ -28,6 +28,7 @@ import ( "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/nats-io/nats.go" log "github.com/sirupsen/logrus" ) @@ -39,7 +40,7 @@ type OutputPresenceConsumer struct { durable string db storage.Database queues *queue.OutgoingQueues - isLocalServerName func(gomatrixserverlib.ServerName) bool + isLocalServerName func(spec.ServerName) bool rsAPI roomserverAPI.FederationRoomserverAPI topic string outboundPresenceEnabled bool @@ -127,7 +128,7 @@ func (t *OutputPresenceConsumer) onMessage(ctx context.Context, msgs []*nats.Msg statusMsg = &status } - p := types.PresenceInternal{LastActiveTS: gomatrixserverlib.Timestamp(ts)} + p := types.PresenceInternal{LastActiveTS: spec.Timestamp(ts)} content := fedTypes.Presence{ Push: []fedTypes.PresenceContent{ @@ -142,7 +143,7 @@ func (t *OutputPresenceConsumer) onMessage(ctx context.Context, msgs []*nats.Msg } edu := &gomatrixserverlib.EDU{ - Type: gomatrixserverlib.MPresence, + Type: spec.MPresence, Origin: string(serverName), } if edu.Content, err = json.Marshal(content); err != nil { diff --git a/federationapi/consumers/receipts.go b/federationapi/consumers/receipts.go index 200c06e6c..1407a88b7 100644 --- a/federationapi/consumers/receipts.go +++ b/federationapi/consumers/receipts.go @@ -28,6 +28,7 @@ import ( "github.com/matrix-org/dendrite/setup/process" syncTypes "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/nats-io/nats.go" log "github.com/sirupsen/logrus" ) @@ -39,7 +40,7 @@ type OutputReceiptConsumer struct { durable string db storage.Database queues *queue.OutgoingQueues - isLocalServerName func(gomatrixserverlib.ServerName) bool + isLocalServerName func(spec.ServerName) bool topic string } @@ -107,7 +108,7 @@ func (t *OutputReceiptConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) return true } - receipt.Timestamp = gomatrixserverlib.Timestamp(timestamp) + receipt.Timestamp = spec.Timestamp(timestamp) joined, err := t.db.GetJoinedHosts(ctx, receipt.RoomID) if err != nil { @@ -115,7 +116,7 @@ func (t *OutputReceiptConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) return false } - names := make([]gomatrixserverlib.ServerName, len(joined)) + names := make([]spec.ServerName, len(joined)) for i := range joined { names[i] = joined[i].ServerName } @@ -133,7 +134,7 @@ func (t *OutputReceiptConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) } edu := &gomatrixserverlib.EDU{ - Type: gomatrixserverlib.MReceipt, + Type: spec.MReceipt, Origin: string(receiptServerName), } if edu.Content, err = json.Marshal(content); err != nil { diff --git a/federationapi/consumers/roomserver.go b/federationapi/consumers/roomserver.go index 378b96ba0..5ef65ee5b 100644 --- a/federationapi/consumers/roomserver.go +++ b/federationapi/consumers/roomserver.go @@ -22,6 +22,7 @@ import ( "time" syncAPITypes "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib" "github.com/nats-io/nats.go" @@ -207,9 +208,9 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent, rew } // If we added new hosts, inform them about our known presence events for this room - if s.cfg.Matrix.Presence.EnableOutbound && len(addsJoinedHosts) > 0 && ore.Event.Type() == gomatrixserverlib.MRoomMember && ore.Event.StateKey() != nil { + if s.cfg.Matrix.Presence.EnableOutbound && len(addsJoinedHosts) > 0 && ore.Event.Type() == spec.MRoomMember && ore.Event.StateKey() != nil { membership, _ := ore.Event.Membership() - if membership == gomatrixserverlib.Join { + if membership == spec.Join { s.sendPresence(ore.Event.RoomID(), addsJoinedHosts) } } @@ -239,12 +240,12 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent, rew // Send the event. return s.queues.SendEvent( - ore.Event, gomatrixserverlib.ServerName(ore.SendAsServer), joinedHostsAtEvent, + ore.Event, spec.ServerName(ore.SendAsServer), joinedHostsAtEvent, ) } func (s *OutputRoomEventConsumer) sendPresence(roomID string, addedJoined []types.JoinedHost) { - joined := make([]gomatrixserverlib.ServerName, 0, len(addedJoined)) + joined := make([]spec.ServerName, 0, len(addedJoined)) for _, added := range addedJoined { joined = append(joined, added.ServerName) } @@ -285,7 +286,7 @@ func (s *OutputRoomEventConsumer) sendPresence(roomID string, addedJoined []type continue } - p := syncAPITypes.PresenceInternal{LastActiveTS: gomatrixserverlib.Timestamp(lastActive)} + p := syncAPITypes.PresenceInternal{LastActiveTS: spec.Timestamp(lastActive)} content.Push = append(content.Push, types.PresenceContent{ CurrentlyActive: p.CurrentlyActive(), @@ -301,7 +302,7 @@ func (s *OutputRoomEventConsumer) sendPresence(roomID string, addedJoined []type } edu := &gomatrixserverlib.EDU{ - Type: gomatrixserverlib.MPresence, + Type: spec.MPresence, Origin: string(s.cfg.Matrix.ServerName), } if edu.Content, err = json.Marshal(content); err != nil { @@ -326,7 +327,7 @@ func (s *OutputRoomEventConsumer) sendPresence(roomID string, addedJoined []type // Returns an error if there was a problem talking to the room server. func (s *OutputRoomEventConsumer) joinedHostsAtEvent( ore api.OutputNewRoomEvent, oldJoinedHosts []types.JoinedHost, -) ([]gomatrixserverlib.ServerName, error) { +) ([]spec.ServerName, error) { // Combine the delta into a single delta so that the adds and removes can // cancel each other out. This should reduce the number of times we need // to fetch a state event from the room server. @@ -349,7 +350,7 @@ func (s *OutputRoomEventConsumer) joinedHostsAtEvent( removed[eventID] = true } - joined := map[gomatrixserverlib.ServerName]bool{} + joined := map[spec.ServerName]bool{} for _, joinedHost := range oldJoinedHosts { if removed[joinedHost.MemberEventID] { // This m.room.member event is part of the current state of the @@ -376,7 +377,7 @@ func (s *OutputRoomEventConsumer) joinedHostsAtEvent( joined[inboundPeek.ServerName] = true } - var result []gomatrixserverlib.ServerName + var result []spec.ServerName for serverName, include := range joined { if include { result = append(result, serverName) @@ -398,7 +399,7 @@ func JoinedHostsFromEvents(evs []*gomatrixserverlib.Event) ([]types.JoinedHost, if err != nil { return nil, err } - if membership != gomatrixserverlib.Join { + if membership != spec.Join { continue } _, serverName, err := gomatrixserverlib.SplitID('@', *ev.StateKey()) diff --git a/federationapi/consumers/sendtodevice.go b/federationapi/consumers/sendtodevice.go index 9620d1612..91b28cdbf 100644 --- a/federationapi/consumers/sendtodevice.go +++ b/federationapi/consumers/sendtodevice.go @@ -20,6 +20,7 @@ import ( "github.com/getsentry/sentry-go" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/nats-io/nats.go" log "github.com/sirupsen/logrus" @@ -39,7 +40,7 @@ type OutputSendToDeviceConsumer struct { durable string db storage.Database queues *queue.OutgoingQueues - isLocalServerName func(gomatrixserverlib.ServerName) bool + isLocalServerName func(spec.ServerName) bool topic string } @@ -107,7 +108,7 @@ func (t *OutputSendToDeviceConsumer) onMessage(ctx context.Context, msgs []*nats // Pack the EDU and marshal it edu := &gomatrixserverlib.EDU{ - Type: gomatrixserverlib.MDirectToDevice, + Type: spec.MDirectToDevice, Origin: string(originServerName), } tdm := gomatrixserverlib.ToDeviceMessage{ @@ -127,7 +128,7 @@ func (t *OutputSendToDeviceConsumer) onMessage(ctx context.Context, msgs []*nats } log.Debugf("Sending send-to-device message into %q destination queue", destServerName) - if err := t.queues.SendEDU(edu, originServerName, []gomatrixserverlib.ServerName{destServerName}); err != nil { + if err := t.queues.SendEDU(edu, originServerName, []spec.ServerName{destServerName}); err != nil { log.WithError(err).Error("failed to send EDU") return false } diff --git a/federationapi/consumers/typing.go b/federationapi/consumers/typing.go index c66f97519..134f2174f 100644 --- a/federationapi/consumers/typing.go +++ b/federationapi/consumers/typing.go @@ -25,6 +25,7 @@ import ( "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/nats-io/nats.go" log "github.com/sirupsen/logrus" ) @@ -36,7 +37,7 @@ type OutputTypingConsumer struct { durable string db storage.Database queues *queue.OutgoingQueues - isLocalServerName func(gomatrixserverlib.ServerName) bool + isLocalServerName func(spec.ServerName) bool topic string } @@ -97,7 +98,7 @@ func (t *OutputTypingConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) return false } - names := make([]gomatrixserverlib.ServerName, len(joined)) + names := make([]spec.ServerName, len(joined)) for i := range joined { names[i] = joined[i].ServerName } diff --git a/federationapi/federationapi.go b/federationapi/federationapi.go index ec482659a..c64fa550d 100644 --- a/federationapi/federationapi.go +++ b/federationapi/federationapi.go @@ -17,6 +17,11 @@ package federationapi import ( "time" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/gomatrixserverlib/fclient" "github.com/sirupsen/logrus" "github.com/matrix-org/dendrite/federationapi/api" @@ -29,7 +34,6 @@ import ( "github.com/matrix-org/dendrite/federationapi/storage" "github.com/matrix-org/dendrite/internal/caching" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/jetstream" userapi "github.com/matrix-org/dendrite/userapi/api" @@ -40,17 +44,21 @@ import ( // AddPublicRoutes sets up and registers HTTP handlers on the base API muxes for the FederationAPI component. func AddPublicRoutes( - base *base.BaseDendrite, + processContext *process.ProcessContext, + routers httputil.Routers, + dendriteConfig *config.Dendrite, + natsInstance *jetstream.NATSInstance, userAPI userapi.FederationUserAPI, - federation *gomatrixserverlib.FederationClient, + federation *fclient.FederationClient, keyRing gomatrixserverlib.JSONVerifier, rsAPI roomserverAPI.FederationRoomserverAPI, fedAPI federationAPI.FederationInternalAPI, servers federationAPI.ServersInRoomProvider, + enableMetrics bool, ) { - cfg := &base.Cfg.FederationAPI - mscCfg := &base.Cfg.MSCs - js, _ := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) + cfg := &dendriteConfig.FederationAPI + mscCfg := &dendriteConfig.MSCs + js, _ := natsInstance.Prepare(processContext, &cfg.Matrix.JetStream) producer := &producers.SyncAPIProducer{ JetStream: js, TopicReceiptEvent: cfg.Matrix.JetStream.Prefixed(jetstream.OutputReceiptEvent), @@ -75,26 +83,30 @@ func AddPublicRoutes( } routing.Setup( - base, + routers, + dendriteConfig, rsAPI, f, keyRing, federation, userAPI, mscCfg, - servers, producer, + servers, producer, enableMetrics, ) } // NewInternalAPI returns a concerete implementation of the internal API. Callers // can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. func NewInternalAPI( - base *base.BaseDendrite, + processContext *process.ProcessContext, + dendriteCfg *config.Dendrite, + cm sqlutil.Connections, + natsInstance *jetstream.NATSInstance, federation api.FederationClient, rsAPI roomserverAPI.FederationRoomserverAPI, caches *caching.Caches, keyRing *gomatrixserverlib.KeyRing, resetBlacklist bool, ) api.FederationInternalAPI { - cfg := &base.Cfg.FederationAPI + cfg := &dendriteCfg.FederationAPI - federationDB, err := storage.NewDatabase(base, &cfg.Database, base.Caches, base.Cfg.Global.IsLocalServerName) + federationDB, err := storage.NewDatabase(processContext.Context(), cm, &cfg.Database, caches, dendriteCfg.Global.IsLocalServerName) if err != nil { logrus.WithError(err).Panic("failed to connect to federation sender db") } @@ -108,51 +120,51 @@ func NewInternalAPI( cfg.FederationMaxRetries+1, cfg.P2PFederationRetriesUntilAssumedOffline+1) - js, nats := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) + js, nats := natsInstance.Prepare(processContext, &cfg.Matrix.JetStream) - signingInfo := base.Cfg.Global.SigningIdentities() + signingInfo := dendriteCfg.Global.SigningIdentities() queues := queue.NewOutgoingQueues( - federationDB, base.ProcessContext, + federationDB, processContext, cfg.Matrix.DisableFederation, cfg.Matrix.ServerName, federation, rsAPI, &stats, signingInfo, ) rsConsumer := consumers.NewOutputRoomEventConsumer( - base.ProcessContext, cfg, js, nats, queues, + processContext, cfg, js, nats, queues, federationDB, rsAPI, ) if err = rsConsumer.Start(); err != nil { logrus.WithError(err).Panic("failed to start room server consumer") } tsConsumer := consumers.NewOutputSendToDeviceConsumer( - base.ProcessContext, cfg, js, queues, federationDB, + processContext, cfg, js, queues, federationDB, ) if err = tsConsumer.Start(); err != nil { logrus.WithError(err).Panic("failed to start send-to-device consumer") } receiptConsumer := consumers.NewOutputReceiptConsumer( - base.ProcessContext, cfg, js, queues, federationDB, + processContext, cfg, js, queues, federationDB, ) if err = receiptConsumer.Start(); err != nil { logrus.WithError(err).Panic("failed to start receipt consumer") } typingConsumer := consumers.NewOutputTypingConsumer( - base.ProcessContext, cfg, js, queues, federationDB, + processContext, cfg, js, queues, federationDB, ) if err = typingConsumer.Start(); err != nil { logrus.WithError(err).Panic("failed to start typing consumer") } keyConsumer := consumers.NewKeyChangeConsumer( - base.ProcessContext, &base.Cfg.KeyServer, js, queues, federationDB, rsAPI, + processContext, &dendriteCfg.KeyServer, js, queues, federationDB, rsAPI, ) if err = keyConsumer.Start(); err != nil { logrus.WithError(err).Panic("failed to start key server consumer") } presenceConsumer := consumers.NewOutputPresenceConsumer( - base.ProcessContext, cfg, js, queues, federationDB, rsAPI, + processContext, cfg, js, queues, federationDB, rsAPI, ) if err = presenceConsumer.Start(); err != nil { logrus.WithError(err).Panic("failed to start presence consumer") @@ -161,7 +173,7 @@ func NewInternalAPI( var cleanExpiredEDUs func() cleanExpiredEDUs = func() { logrus.Infof("Cleaning expired EDUs") - if err := federationDB.DeleteExpiredEDUs(base.Context()); err != nil { + if err := federationDB.DeleteExpiredEDUs(processContext.Context()); err != nil { logrus.WithError(err).Error("Failed to clean expired EDUs") } time.AfterFunc(time.Hour, cleanExpiredEDUs) diff --git a/federationapi/federationapi_keys_test.go b/federationapi/federationapi_keys_test.go index bb6ee8935..50691334d 100644 --- a/federationapi/federationapi_keys_test.go +++ b/federationapi/federationapi_keys_test.go @@ -12,22 +12,26 @@ import ( "testing" "time" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/routing" "github.com/matrix-org/dendrite/internal/caching" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" ) type server struct { - name gomatrixserverlib.ServerName // server name - validity time.Duration // key validity duration from now - config *config.FederationAPI // skeleton config, from TestMain - fedclient *gomatrixserverlib.FederationClient // uses MockRoundTripper - cache *caching.Caches // server-specific cache - api api.FederationInternalAPI // server-specific server key API + name spec.ServerName // server name + validity time.Duration // key validity duration from now + config *config.FederationAPI // skeleton config, from TestMain + fedclient *fclient.FederationClient // uses MockRoundTripper + cache *caching.Caches // server-specific cache + api api.FederationInternalAPI // server-specific server key API } func (s *server) renew() { @@ -65,7 +69,7 @@ func TestMain(m *testing.M) { // Create a new cache but don't enable prometheus! s.cache = caching.NewRistrettoCache(8*1024*1024, time.Hour, false) - + natsInstance := jetstream.NATSInstance{} // Create a temporary directory for JetStream. d, err := os.MkdirTemp("./", "jetstream*") if err != nil { @@ -80,7 +84,7 @@ func TestMain(m *testing.M) { Generate: true, SingleDatabase: false, }) - cfg.Global.ServerName = gomatrixserverlib.ServerName(s.name) + cfg.Global.ServerName = spec.ServerName(s.name) cfg.Global.PrivateKey = testPriv cfg.Global.JetStream.InMemory = true cfg.Global.JetStream.TopicPrefix = string(s.name[:1]) @@ -103,14 +107,15 @@ func TestMain(m *testing.M) { transport.RegisterProtocol("matrix", &MockRoundTripper{}) // Create the federation client. - s.fedclient = gomatrixserverlib.NewFederationClient( + s.fedclient = fclient.NewFederationClient( s.config.Matrix.SigningIdentities(), - gomatrixserverlib.WithTransport(transport), + fclient.WithTransport(transport), ) // Finally, build the server key APIs. - sbase := base.NewBaseDendrite(cfg, base.DisableMetrics) - s.api = NewInternalAPI(sbase, s.fedclient, nil, s.cache, nil, true) + processCtx := process.NewProcessContext() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + s.api = NewInternalAPI(processCtx, cfg, cm, &natsInstance, s.fedclient, nil, s.cache, nil, true) } // Now that we have built our server key APIs, start the @@ -137,7 +142,7 @@ func (m *MockRoundTripper) RoundTrip(req *http.Request) (res *http.Response, err } // Get the keys and JSON-ify them. - keys := routing.LocalKeys(s.config, gomatrixserverlib.ServerName(req.Host)) + keys := routing.LocalKeys(s.config, spec.ServerName(req.Host)) body, err := json.MarshalIndent(keys.JSON, "", " ") if err != nil { return nil, err @@ -162,8 +167,8 @@ func TestServersRequestOwnKeys(t *testing.T) { } res, err := s.api.FetchKeys( context.Background(), - map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp{ - req: gomatrixserverlib.AsTimestamp(time.Now()), + map[gomatrixserverlib.PublicKeyLookupRequest]spec.Timestamp{ + req: spec.AsTimestamp(time.Now()), }, ) if err != nil { @@ -188,8 +193,8 @@ func TestRenewalBehaviour(t *testing.T) { res, err := serverA.api.FetchKeys( context.Background(), - map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp{ - req: gomatrixserverlib.AsTimestamp(time.Now()), + map[gomatrixserverlib.PublicKeyLookupRequest]spec.Timestamp{ + req: spec.AsTimestamp(time.Now()), }, ) if err != nil { @@ -212,8 +217,8 @@ func TestRenewalBehaviour(t *testing.T) { res, err = serverA.api.FetchKeys( context.Background(), - map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp{ - req: gomatrixserverlib.AsTimestamp(time.Now()), + map[gomatrixserverlib.PublicKeyLookupRequest]spec.Timestamp{ + req: spec.AsTimestamp(time.Now()), }, ) if err != nil { diff --git a/federationapi/federationapi_test.go b/federationapi/federationapi_test.go index 57d4b9644..a8af00006 100644 --- a/federationapi/federationapi_test.go +++ b/federationapi/federationapi_test.go @@ -10,16 +10,19 @@ import ( "testing" "time" + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/nats-io/nats.go" "github.com/matrix-org/dendrite/federationapi" "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/internal" rsapi "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test/testrig" @@ -54,7 +57,7 @@ type fedClient struct { fedClientMutex sync.Mutex api.FederationClient allowJoins []*test.Room - keys map[gomatrixserverlib.ServerName]struct { + keys map[spec.ServerName]struct { key ed25519.PrivateKey keyID gomatrixserverlib.KeyID } @@ -62,7 +65,7 @@ type fedClient struct { sentTxn bool } -func (f *fedClient) GetServerKeys(ctx context.Context, matrixServer gomatrixserverlib.ServerName) (gomatrixserverlib.ServerKeys, error) { +func (f *fedClient) GetServerKeys(ctx context.Context, matrixServer spec.ServerName) (gomatrixserverlib.ServerKeys, error) { f.fedClientMutex.Lock() defer f.fedClientMutex.Unlock() fmt.Println("GetServerKeys:", matrixServer) @@ -81,11 +84,11 @@ func (f *fedClient) GetServerKeys(ctx context.Context, matrixServer gomatrixserv } keys.ServerName = matrixServer - keys.ValidUntilTS = gomatrixserverlib.AsTimestamp(time.Now().Add(10 * time.Hour)) + keys.ValidUntilTS = spec.AsTimestamp(time.Now().Add(10 * time.Hour)) publicKey := pkey.Public().(ed25519.PublicKey) keys.VerifyKeys = map[gomatrixserverlib.KeyID]gomatrixserverlib.VerifyKey{ keyID: { - Key: gomatrixserverlib.Base64Bytes(publicKey), + Key: spec.Base64Bytes(publicKey), }, } toSign, err := json.Marshal(keys.ServerKeyFields) @@ -103,7 +106,7 @@ func (f *fedClient) GetServerKeys(ctx context.Context, matrixServer gomatrixserv return keys, nil } -func (f *fedClient) MakeJoin(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, userID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMakeJoin, err error) { +func (f *fedClient) MakeJoin(ctx context.Context, origin, s spec.ServerName, roomID, userID string, roomVersions []gomatrixserverlib.RoomVersion) (res fclient.RespMakeJoin, err error) { for _, r := range f.allowJoins { if r.ID == roomID { res.RoomVersion = r.Version @@ -112,7 +115,7 @@ func (f *fedClient) MakeJoin(ctx context.Context, origin, s gomatrixserverlib.Se RoomID: roomID, Type: "m.room.member", StateKey: &userID, - Content: gomatrixserverlib.RawJSON([]byte(`{"membership":"join"}`)), + Content: spec.RawJSON([]byte(`{"membership":"join"}`)), PrevEvents: r.ForwardExtremities(), } var needed gomatrixserverlib.StateNeeded @@ -127,7 +130,7 @@ func (f *fedClient) MakeJoin(ctx context.Context, origin, s gomatrixserverlib.Se } return } -func (f *fedClient) SendJoin(ctx context.Context, origin, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (res gomatrixserverlib.RespSendJoin, err error) { +func (f *fedClient) SendJoin(ctx context.Context, origin, s spec.ServerName, event *gomatrixserverlib.Event) (res fclient.RespSendJoin, err error) { f.fedClientMutex.Lock() defer f.fedClientMutex.Unlock() for _, r := range f.allowJoins { @@ -141,11 +144,11 @@ func (f *fedClient) SendJoin(ctx context.Context, origin, s gomatrixserverlib.Se return } -func (f *fedClient) SendTransaction(ctx context.Context, t gomatrixserverlib.Transaction) (res gomatrixserverlib.RespSend, err error) { +func (f *fedClient) SendTransaction(ctx context.Context, t gomatrixserverlib.Transaction) (res fclient.RespSend, err error) { f.fedClientMutex.Lock() defer f.fedClientMutex.Unlock() for _, edu := range t.EDUs { - if edu.Type == gomatrixserverlib.MDeviceListUpdate { + if edu.Type == spec.MDeviceListUpdate { f.sentTxn = true } } @@ -162,21 +165,24 @@ func TestFederationAPIJoinThenKeyUpdate(t *testing.T) { } func testFederationAPIJoinThenKeyUpdate(t *testing.T, dbType test.DBType) { - base, close := testrig.CreateBaseDendrite(t, dbType) - base.Cfg.FederationAPI.PreferDirectFetch = true - base.Cfg.FederationAPI.KeyPerspectives = nil + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + natsInstance := jetstream.NATSInstance{} + cfg.FederationAPI.PreferDirectFetch = true + cfg.FederationAPI.KeyPerspectives = nil defer close() - jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) - defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream) + jsctx, _ := natsInstance.Prepare(processCtx, &cfg.Global.JetStream) + defer jetstream.DeleteAllStreams(jsctx, &cfg.Global.JetStream) - serverA := gomatrixserverlib.ServerName("server.a") + serverA := spec.ServerName("server.a") serverAKeyID := gomatrixserverlib.KeyID("ed25519:servera") serverAPrivKey := test.PrivateKeyA creator := test.NewUser(t, test.WithSigningServer(serverA, serverAKeyID, serverAPrivKey)) - myServer := base.Cfg.Global.ServerName - myServerKeyID := base.Cfg.Global.KeyID - myServerPrivKey := base.Cfg.Global.PrivateKey + myServer := cfg.Global.ServerName + myServerKeyID := cfg.Global.KeyID + myServerPrivKey := cfg.Global.PrivateKey joiningUser := test.NewUser(t, test.WithSigningServer(myServer, myServerKeyID, myServerPrivKey)) fmt.Printf("creator: %v joining user: %v\n", creator.ID, joiningUser.ID) room := test.NewRoom(t, creator) @@ -198,7 +204,7 @@ func testFederationAPIJoinThenKeyUpdate(t *testing.T, dbType test.DBType) { fc := &fedClient{ allowJoins: []*test.Room{room}, t: t, - keys: map[gomatrixserverlib.ServerName]struct { + keys: map[spec.ServerName]struct { key ed25519.PrivateKey keyID gomatrixserverlib.KeyID }{ @@ -212,13 +218,13 @@ func testFederationAPIJoinThenKeyUpdate(t *testing.T, dbType test.DBType) { }, }, } - fsapi := federationapi.NewInternalAPI(base, fc, rsapi, base.Caches, nil, false) + fsapi := federationapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, fc, rsapi, caches, nil, false) var resp api.PerformJoinResponse fsapi.PerformJoin(context.Background(), &api.PerformJoinRequest{ RoomID: room.ID, UserID: joiningUser.ID, - ServerNames: []gomatrixserverlib.ServerName{serverA}, + ServerNames: []spec.ServerName{serverA}, }, &resp) if resp.JoinedVia != serverA { t.Errorf("PerformJoin: joined via %v want %v", resp.JoinedVia, serverA) @@ -245,7 +251,7 @@ func testFederationAPIJoinThenKeyUpdate(t *testing.T, dbType test.DBType) { } msg := &nats.Msg{ - Subject: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputKeyChangeEvent), + Subject: cfg.Global.JetStream.Prefixed(jetstream.OutputKeyChangeEvent), Header: nats.Header{}, Data: b, } @@ -263,30 +269,6 @@ func testFederationAPIJoinThenKeyUpdate(t *testing.T, dbType test.DBType) { // Tests that event IDs with '/' in them (escaped as %2F) are correctly passed to the right handler and don't 404. // Relevant for v3 rooms and a cause of flakey sytests as the IDs are randomly generated. func TestRoomsV3URLEscapeDoNot404(t *testing.T) { - _, privKey, _ := ed25519.GenerateKey(nil) - cfg := &config.Dendrite{} - cfg.Defaults(config.DefaultOpts{ - Generate: true, - SingleDatabase: false, - }) - cfg.Global.KeyID = gomatrixserverlib.KeyID("ed25519:auto") - cfg.Global.ServerName = gomatrixserverlib.ServerName("localhost") - cfg.Global.PrivateKey = privKey - cfg.Global.JetStream.InMemory = true - b := base.NewBaseDendrite(cfg, base.DisableMetrics) - keyRing := &test.NopJSONVerifier{} - // TODO: This is pretty fragile, as if anything calls anything on these nils this test will break. - // Unfortunately, it makes little sense to instantiate these dependencies when we just want to test routing. - federationapi.AddPublicRoutes(b, nil, nil, keyRing, nil, &internal.FederationInternalAPI{}, nil) - baseURL, cancel := test.ListenAndServe(t, b.PublicFederationAPIMux, true) - defer cancel() - serverName := gomatrixserverlib.ServerName(strings.TrimPrefix(baseURL, "https://")) - - fedCli := gomatrixserverlib.NewFederationClient( - cfg.Global.SigningIdentities(), - gomatrixserverlib.WithSkipVerify(true), - ) - testCases := []struct { roomVer gomatrixserverlib.RoomVersion eventJSON string @@ -315,13 +297,36 @@ func TestRoomsV3URLEscapeDoNot404(t *testing.T) { }, } + cfg, processCtx, close := testrig.CreateConfig(t, test.DBTypeSQLite) + defer close() + routers := httputil.NewRouters() + + _, privKey, _ := ed25519.GenerateKey(nil) + cfg.Global.KeyID = gomatrixserverlib.KeyID("ed25519:auto") + cfg.Global.ServerName = spec.ServerName("localhost") + cfg.Global.PrivateKey = privKey + cfg.Global.JetStream.InMemory = true + keyRing := &test.NopJSONVerifier{} + natsInstance := jetstream.NATSInstance{} + // TODO: This is pretty fragile, as if anything calls anything on these nils this test will break. + // Unfortunately, it makes little sense to instantiate these dependencies when we just want to test routing. + federationapi.AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, nil, keyRing, nil, &internal.FederationInternalAPI{}, nil, caching.DisableMetrics) + baseURL, cancel := test.ListenAndServe(t, routers.Federation, true) + defer cancel() + serverName := spec.ServerName(strings.TrimPrefix(baseURL, "https://")) + + fedCli := fclient.NewFederationClient( + cfg.Global.SigningIdentities(), + fclient.WithSkipVerify(true), + ) + for _, tc := range testCases { ev, err := gomatrixserverlib.NewEventFromTrustedJSON([]byte(tc.eventJSON), false, tc.roomVer) if err != nil { t.Errorf("failed to parse event: %s", err) } he := ev.Headered(tc.roomVer) - invReq, err := gomatrixserverlib.NewInviteV2Request(he, nil) + invReq, err := fclient.NewInviteV2Request(he, nil) if err != nil { t.Errorf("failed to create invite v2 request: %s", err) continue diff --git a/federationapi/internal/api.go b/federationapi/internal/api.go index 99773a750..529cf777d 100644 --- a/federationapi/internal/api.go +++ b/federationapi/internal/api.go @@ -17,6 +17,7 @@ import ( "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/sirupsen/logrus" ) @@ -107,7 +108,7 @@ func NewFederationInternalAPI( } } -func (a *FederationInternalAPI) isBlacklistedOrBackingOff(s gomatrixserverlib.ServerName) (*statistics.ServerStatistics, error) { +func (a *FederationInternalAPI) isBlacklistedOrBackingOff(s spec.ServerName) (*statistics.ServerStatistics, error) { stats := a.statistics.ForServer(s) if stats.Blacklisted() { return stats, &api.FederationClientError{ @@ -144,7 +145,7 @@ func failBlacklistableError(err error, stats *statistics.ServerStatistics) (unti } func (a *FederationInternalAPI) doRequestIfNotBackingOffOrBlacklisted( - s gomatrixserverlib.ServerName, request func() (interface{}, error), + s spec.ServerName, request func() (interface{}, error), ) (interface{}, error) { stats, err := a.isBlacklistedOrBackingOff(s) if err != nil { @@ -169,7 +170,7 @@ func (a *FederationInternalAPI) doRequestIfNotBackingOffOrBlacklisted( } func (a *FederationInternalAPI) doRequestIfNotBlacklisted( - s gomatrixserverlib.ServerName, request func() (interface{}, error), + s spec.ServerName, request func() (interface{}, error), ) (interface{}, error) { stats := a.statistics.ForServer(s) if blacklisted := stats.Blacklisted(); blacklisted { diff --git a/federationapi/internal/federationclient.go b/federationapi/internal/federationclient.go index db6348ec1..e4288a20c 100644 --- a/federationapi/internal/federationclient.go +++ b/federationapi/internal/federationclient.go @@ -5,68 +5,70 @@ import ( "time" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" ) // Functions here are "proxying" calls to the gomatrixserverlib federation // client. func (a *FederationInternalAPI) GetEventAuth( - ctx context.Context, origin, s gomatrixserverlib.ServerName, + ctx context.Context, origin, s spec.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string, -) (res gomatrixserverlib.RespEventAuth, err error) { +) (res fclient.RespEventAuth, err error) { ctx, cancel := context.WithTimeout(ctx, time.Second*30) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { return a.federation.GetEventAuth(ctx, origin, s, roomVersion, roomID, eventID) }) if err != nil { - return gomatrixserverlib.RespEventAuth{}, err + return fclient.RespEventAuth{}, err } - return ires.(gomatrixserverlib.RespEventAuth), nil + return ires.(fclient.RespEventAuth), nil } func (a *FederationInternalAPI) GetUserDevices( - ctx context.Context, origin, s gomatrixserverlib.ServerName, userID string, -) (gomatrixserverlib.RespUserDevices, error) { + ctx context.Context, origin, s spec.ServerName, userID string, +) (fclient.RespUserDevices, error) { ctx, cancel := context.WithTimeout(ctx, time.Second*30) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { return a.federation.GetUserDevices(ctx, origin, s, userID) }) if err != nil { - return gomatrixserverlib.RespUserDevices{}, err + return fclient.RespUserDevices{}, err } - return ires.(gomatrixserverlib.RespUserDevices), nil + return ires.(fclient.RespUserDevices), nil } func (a *FederationInternalAPI) ClaimKeys( - ctx context.Context, origin, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string, -) (gomatrixserverlib.RespClaimKeys, error) { + ctx context.Context, origin, s spec.ServerName, oneTimeKeys map[string]map[string]string, +) (fclient.RespClaimKeys, error) { ctx, cancel := context.WithTimeout(ctx, time.Second*30) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { return a.federation.ClaimKeys(ctx, origin, s, oneTimeKeys) }) if err != nil { - return gomatrixserverlib.RespClaimKeys{}, err + return fclient.RespClaimKeys{}, err } - return ires.(gomatrixserverlib.RespClaimKeys), nil + return ires.(fclient.RespClaimKeys), nil } func (a *FederationInternalAPI) QueryKeys( - ctx context.Context, origin, s gomatrixserverlib.ServerName, keys map[string][]string, -) (gomatrixserverlib.RespQueryKeys, error) { + ctx context.Context, origin, s spec.ServerName, keys map[string][]string, +) (fclient.RespQueryKeys, error) { ires, err := a.doRequestIfNotBackingOffOrBlacklisted(s, func() (interface{}, error) { return a.federation.QueryKeys(ctx, origin, s, keys) }) if err != nil { - return gomatrixserverlib.RespQueryKeys{}, err + return fclient.RespQueryKeys{}, err } - return ires.(gomatrixserverlib.RespQueryKeys), nil + return ires.(fclient.RespQueryKeys), nil } func (a *FederationInternalAPI) Backfill( - ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string, + ctx context.Context, origin, s spec.ServerName, roomID string, limit int, eventIDs []string, ) (res gomatrixserverlib.Transaction, err error) { ctx, cancel := context.WithTimeout(ctx, time.Second*30) defer cancel() @@ -80,50 +82,51 @@ func (a *FederationInternalAPI) Backfill( } func (a *FederationInternalAPI) LookupState( - ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, -) (res gomatrixserverlib.RespState, err error) { + ctx context.Context, origin, s spec.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, +) (res gomatrixserverlib.StateResponse, err error) { ctx, cancel := context.WithTimeout(ctx, time.Second*30) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { return a.federation.LookupState(ctx, origin, s, roomID, eventID, roomVersion) }) if err != nil { - return gomatrixserverlib.RespState{}, err + return &fclient.RespState{}, err } - return ires.(gomatrixserverlib.RespState), nil + r := ires.(fclient.RespState) + return &r, nil } func (a *FederationInternalAPI) LookupStateIDs( - ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, eventID string, -) (res gomatrixserverlib.RespStateIDs, err error) { + ctx context.Context, origin, s spec.ServerName, roomID, eventID string, +) (res gomatrixserverlib.StateIDResponse, err error) { ctx, cancel := context.WithTimeout(ctx, time.Second*30) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { return a.federation.LookupStateIDs(ctx, origin, s, roomID, eventID) }) if err != nil { - return gomatrixserverlib.RespStateIDs{}, err + return fclient.RespStateIDs{}, err } - return ires.(gomatrixserverlib.RespStateIDs), nil + return ires.(fclient.RespStateIDs), nil } func (a *FederationInternalAPI) LookupMissingEvents( - ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, - missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion, -) (res gomatrixserverlib.RespMissingEvents, err error) { + ctx context.Context, origin, s spec.ServerName, roomID string, + missing fclient.MissingEvents, roomVersion gomatrixserverlib.RoomVersion, +) (res fclient.RespMissingEvents, err error) { ctx, cancel := context.WithTimeout(ctx, time.Second*30) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { return a.federation.LookupMissingEvents(ctx, origin, s, roomID, missing, roomVersion) }) if err != nil { - return gomatrixserverlib.RespMissingEvents{}, err + return fclient.RespMissingEvents{}, err } - return ires.(gomatrixserverlib.RespMissingEvents), nil + return ires.(fclient.RespMissingEvents), nil } func (a *FederationInternalAPI) GetEvent( - ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string, + ctx context.Context, origin, s spec.ServerName, eventID string, ) (res gomatrixserverlib.Transaction, err error) { ctx, cancel := context.WithTimeout(ctx, time.Second*30) defer cancel() @@ -137,7 +140,7 @@ func (a *FederationInternalAPI) GetEvent( } func (a *FederationInternalAPI) LookupServerKeys( - ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, + ctx context.Context, s spec.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]spec.Timestamp, ) ([]gomatrixserverlib.ServerKeys, error) { ctx, cancel := context.WithTimeout(ctx, time.Minute) defer cancel() @@ -151,9 +154,9 @@ func (a *FederationInternalAPI) LookupServerKeys( } func (a *FederationInternalAPI) MSC2836EventRelationships( - ctx context.Context, origin, s gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, + ctx context.Context, origin, s spec.ServerName, r fclient.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion, -) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) { +) (res fclient.MSC2836EventRelationshipsResponse, err error) { ctx, cancel := context.WithTimeout(ctx, time.Minute) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { @@ -162,12 +165,12 @@ func (a *FederationInternalAPI) MSC2836EventRelationships( if err != nil { return res, err } - return ires.(gomatrixserverlib.MSC2836EventRelationshipsResponse), nil + return ires.(fclient.MSC2836EventRelationshipsResponse), nil } func (a *FederationInternalAPI) MSC2946Spaces( - ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, suggestedOnly bool, -) (res gomatrixserverlib.MSC2946SpacesResponse, err error) { + ctx context.Context, origin, s spec.ServerName, roomID string, suggestedOnly bool, +) (res fclient.MSC2946SpacesResponse, err error) { ctx, cancel := context.WithTimeout(ctx, time.Minute) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { @@ -176,5 +179,5 @@ func (a *FederationInternalAPI) MSC2946Spaces( if err != nil { return res, err } - return ires.(gomatrixserverlib.MSC2946SpacesResponse), nil + return ires.(fclient.MSC2946SpacesResponse), nil } diff --git a/federationapi/internal/federationclient_test.go b/federationapi/internal/federationclient_test.go index 49137e2d8..8c562dd61 100644 --- a/federationapi/internal/federationclient_test.go +++ b/federationapi/internal/federationclient_test.go @@ -24,7 +24,8 @@ import ( "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/test" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/stretchr/testify/assert" ) @@ -33,20 +34,20 @@ const ( FailuresUntilBlacklist = 8 ) -func (t *testFedClient) QueryKeys(ctx context.Context, origin, s gomatrixserverlib.ServerName, keys map[string][]string) (gomatrixserverlib.RespQueryKeys, error) { +func (t *testFedClient) QueryKeys(ctx context.Context, origin, s spec.ServerName, keys map[string][]string) (fclient.RespQueryKeys, error) { t.queryKeysCalled = true if t.shouldFail { - return gomatrixserverlib.RespQueryKeys{}, fmt.Errorf("Failure") + return fclient.RespQueryKeys{}, fmt.Errorf("Failure") } - return gomatrixserverlib.RespQueryKeys{}, nil + return fclient.RespQueryKeys{}, nil } -func (t *testFedClient) ClaimKeys(ctx context.Context, origin, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (gomatrixserverlib.RespClaimKeys, error) { +func (t *testFedClient) ClaimKeys(ctx context.Context, origin, s spec.ServerName, oneTimeKeys map[string]map[string]string) (fclient.RespClaimKeys, error) { t.claimKeysCalled = true if t.shouldFail { - return gomatrixserverlib.RespClaimKeys{}, fmt.Errorf("Failure") + return fclient.RespClaimKeys{}, fmt.Errorf("Failure") } - return gomatrixserverlib.RespClaimKeys{}, nil + return fclient.RespClaimKeys{}, nil } func TestFederationClientQueryKeys(t *testing.T) { @@ -54,7 +55,7 @@ func TestFederationClientQueryKeys(t *testing.T) { cfg := config.FederationAPI{ Matrix: &config.Global{ - SigningIdentity: gomatrixserverlib.SigningIdentity{ + SigningIdentity: fclient.SigningIdentity{ ServerName: "server", }, }, @@ -85,7 +86,7 @@ func TestFederationClientQueryKeysBlacklisted(t *testing.T) { cfg := config.FederationAPI{ Matrix: &config.Global{ - SigningIdentity: gomatrixserverlib.SigningIdentity{ + SigningIdentity: fclient.SigningIdentity{ ServerName: "server", }, }, @@ -115,7 +116,7 @@ func TestFederationClientQueryKeysFailure(t *testing.T) { cfg := config.FederationAPI{ Matrix: &config.Global{ - SigningIdentity: gomatrixserverlib.SigningIdentity{ + SigningIdentity: fclient.SigningIdentity{ ServerName: "server", }, }, @@ -145,7 +146,7 @@ func TestFederationClientClaimKeys(t *testing.T) { cfg := config.FederationAPI{ Matrix: &config.Global{ - SigningIdentity: gomatrixserverlib.SigningIdentity{ + SigningIdentity: fclient.SigningIdentity{ ServerName: "server", }, }, @@ -176,7 +177,7 @@ func TestFederationClientClaimKeysBlacklisted(t *testing.T) { cfg := config.FederationAPI{ Matrix: &config.Global{ - SigningIdentity: gomatrixserverlib.SigningIdentity{ + SigningIdentity: fclient.SigningIdentity{ ServerName: "server", }, }, diff --git a/federationapi/internal/keys.go b/federationapi/internal/keys.go index 258bd88bf..00e78a1c1 100644 --- a/federationapi/internal/keys.go +++ b/federationapi/internal/keys.go @@ -7,6 +7,7 @@ import ( "time" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/sirupsen/logrus" ) @@ -31,14 +32,14 @@ func (s *FederationInternalAPI) StoreKeys( func (s *FederationInternalAPI) FetchKeys( _ context.Context, - requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, + requests map[gomatrixserverlib.PublicKeyLookupRequest]spec.Timestamp, ) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) { // Run in a background context - we don't want to stop this work just // because the caller gives up waiting. ctx := context.Background() - now := gomatrixserverlib.AsTimestamp(time.Now()) + now := spec.AsTimestamp(time.Now()) results := map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult{} - origRequests := map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp{} + origRequests := map[gomatrixserverlib.PublicKeyLookupRequest]spec.Timestamp{} for k, v := range requests { origRequests[k] = v } @@ -95,7 +96,7 @@ func (s *FederationInternalAPI) FetcherName() string { // a request for our own server keys, either current or old. func (s *FederationInternalAPI) handleLocalKeys( _ context.Context, - requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, + requests map[gomatrixserverlib.PublicKeyLookupRequest]spec.Timestamp, results map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, ) { for req := range requests { @@ -111,10 +112,10 @@ func (s *FederationInternalAPI) handleLocalKeys( // Insert our own key into the response. results[req] = gomatrixserverlib.PublicKeyLookupResult{ VerifyKey: gomatrixserverlib.VerifyKey{ - Key: gomatrixserverlib.Base64Bytes(s.cfg.Matrix.PrivateKey.Public().(ed25519.PublicKey)), + Key: spec.Base64Bytes(s.cfg.Matrix.PrivateKey.Public().(ed25519.PublicKey)), }, ExpiredTS: gomatrixserverlib.PublicKeyNotExpired, - ValidUntilTS: gomatrixserverlib.AsTimestamp(time.Now().Add(s.cfg.Matrix.KeyValidityPeriod)), + ValidUntilTS: spec.AsTimestamp(time.Now().Add(s.cfg.Matrix.KeyValidityPeriod)), } } else { // The key request doesn't match our current key. Let's see @@ -128,7 +129,7 @@ func (s *FederationInternalAPI) handleLocalKeys( // Insert our own key into the response. results[req] = gomatrixserverlib.PublicKeyLookupResult{ VerifyKey: gomatrixserverlib.VerifyKey{ - Key: gomatrixserverlib.Base64Bytes(oldVerifyKey.PrivateKey.Public().(ed25519.PublicKey)), + Key: spec.Base64Bytes(oldVerifyKey.PrivateKey.Public().(ed25519.PublicKey)), }, ExpiredTS: oldVerifyKey.ExpiredAt, ValidUntilTS: gomatrixserverlib.PublicKeyNotValid, @@ -146,8 +147,8 @@ func (s *FederationInternalAPI) handleLocalKeys( // satisfied from our local database/cache. func (s *FederationInternalAPI) handleDatabaseKeys( ctx context.Context, - now gomatrixserverlib.Timestamp, - requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, + now spec.Timestamp, + requests map[gomatrixserverlib.PublicKeyLookupRequest]spec.Timestamp, results map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, ) error { // Ask the database/cache for the keys. @@ -180,9 +181,9 @@ func (s *FederationInternalAPI) handleDatabaseKeys( // the remaining requests. func (s *FederationInternalAPI) handleFetcherKeys( ctx context.Context, - _ gomatrixserverlib.Timestamp, + _ spec.Timestamp, fetcher gomatrixserverlib.KeyFetcher, - requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, + requests map[gomatrixserverlib.PublicKeyLookupRequest]spec.Timestamp, results map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, ) error { logrus.WithFields(logrus.Fields{ diff --git a/federationapi/internal/perform.go b/federationapi/internal/perform.go index dadb2b2b3..62a08e887 100644 --- a/federationapi/internal/perform.go +++ b/federationapi/internal/perform.go @@ -9,6 +9,8 @@ import ( "github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/sirupsen/logrus" @@ -80,8 +82,8 @@ func (r *FederationInternalAPI) PerformJoin( // Deduplicate the server names we were provided but keep the ordering // as this encodes useful information about which servers are most likely // to respond. - seenSet := make(map[gomatrixserverlib.ServerName]bool) - var uniqueList []gomatrixserverlib.ServerName + seenSet := make(map[spec.ServerName]bool) + var uniqueList []spec.ServerName for _, srv := range request.ServerNames { if seenSet[srv] || r.cfg.Matrix.IsLocalServerName(srv) { continue @@ -143,7 +145,7 @@ func (r *FederationInternalAPI) performJoinUsingServer( ctx context.Context, roomID, userID string, content map[string]interface{}, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, supportedVersions []gomatrixserverlib.RoomVersion, unsigned map[string]interface{}, ) error { @@ -175,7 +177,7 @@ func (r *FederationInternalAPI) performJoinUsingServer( // Set all the fields to be what they should be, this should be a no-op // but it's possible that the remote server returned us something "odd" - respMakeJoin.JoinEvent.Type = gomatrixserverlib.MRoomMember + respMakeJoin.JoinEvent.Type = spec.MRoomMember respMakeJoin.JoinEvent.Sender = userID respMakeJoin.JoinEvent.StateKey = &userID respMakeJoin.JoinEvent.RoomID = roomID @@ -184,7 +186,7 @@ func (r *FederationInternalAPI) performJoinUsingServer( content = map[string]interface{}{} } _ = json.Unmarshal(respMakeJoin.JoinEvent.Content, &content) - content["membership"] = gomatrixserverlib.Join + content["membership"] = spec.Join if err = respMakeJoin.JoinEvent.SetContent(content); err != nil { return fmt.Errorf("respMakeJoin.JoinEvent.SetContent: %w", err) } @@ -233,7 +235,7 @@ func (r *FederationInternalAPI) performJoinUsingServer( // contain signatures that we don't know about. if len(respSendJoin.Event) > 0 { var remoteEvent *gomatrixserverlib.Event - remoteEvent, err = respSendJoin.Event.UntrustedEvent(respMakeJoin.RoomVersion) + remoteEvent, err = gomatrixserverlib.UntrustedEvent(respSendJoin.Event, respMakeJoin.RoomVersion) if err == nil && isWellFormedMembershipEvent( remoteEvent, roomID, userID, ) { @@ -255,10 +257,10 @@ func (r *FederationInternalAPI) performJoinUsingServer( // waste the effort. // TODO: Can we expand Check here to return a list of missing auth // events rather than failing one at a time? - var respState *gomatrixserverlib.RespState - respState, err = respSendJoin.Check( + var respState gomatrixserverlib.StateResponse + respState, err = gomatrixserverlib.CheckSendJoinResponse( context.Background(), - respMakeJoin.RoomVersion, + respMakeJoin.RoomVersion, &respSendJoin, r.keyRing, event, federatedAuthProvider(ctx, r.federation, r.keyRing, origin, serverName), @@ -273,7 +275,7 @@ func (r *FederationInternalAPI) performJoinUsingServer( // joining a room, waiting for 200 OK then changing device keys and have those keys not be sent // to other servers (this was a cause of a flakey sytest "Local device key changes get to remote servers") // The events are trusted now as we performed auth checks above. - joinedHosts, err := consumers.JoinedHostsFromEvents(respState.StateEvents.TrustedEvents(respMakeJoin.RoomVersion, false)) + joinedHosts, err := consumers.JoinedHostsFromEvents(respState.GetStateEvents().TrustedEvents(respMakeJoin.RoomVersion, false)) if err != nil { return fmt.Errorf("JoinedHostsFromEvents: failed to get joined hosts: %s", err) } @@ -315,7 +317,7 @@ func (r *FederationInternalAPI) performJoinUsingServer( func isWellFormedMembershipEvent(event *gomatrixserverlib.Event, roomID, userID string) bool { if membership, err := event.Membership(); err != nil { return false - } else if membership != gomatrixserverlib.Join { + } else if membership != spec.Join { return false } if event.RoomID() != roomID { @@ -342,8 +344,8 @@ func (r *FederationInternalAPI) PerformOutboundPeek( // Deduplicate the server names we were provided but keep the ordering // as this encodes useful information about which servers are most likely // to respond. - seenSet := make(map[gomatrixserverlib.ServerName]bool) - var uniqueList []gomatrixserverlib.ServerName + seenSet := make(map[spec.ServerName]bool) + var uniqueList []spec.ServerName for _, srv := range request.ServerNames { if seenSet[srv] { continue @@ -409,7 +411,7 @@ func (r *FederationInternalAPI) PerformOutboundPeek( func (r *FederationInternalAPI) performOutboundPeekUsingServer( ctx context.Context, roomID string, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, supportedVersions []gomatrixserverlib.RoomVersion, ) error { if !r.shouldAttemptDirectFederation(serverName) { @@ -468,11 +470,12 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer( // we have the peek state now so let's process regardless of whether upstream gives up ctx = context.Background() - respState := respPeek.ToRespState() // authenticate the state returned (check its auth events etc) // the equivalent of CheckSendJoinResponse() - authEvents, _, err := respState.Check(ctx, respPeek.RoomVersion, r.keyRing, federatedAuthProvider(ctx, r.federation, r.keyRing, r.cfg.Matrix.ServerName, serverName)) + authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse( + ctx, &respPeek, respPeek.RoomVersion, r.keyRing, federatedAuthProvider(ctx, r.federation, r.keyRing, r.cfg.Matrix.ServerName, serverName), + ) if err != nil { return fmt.Errorf("error checking state returned from peeking: %w", err) } @@ -496,7 +499,11 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer( if err = roomserverAPI.SendEventWithState( ctx, r.rsAPI, r.cfg.Matrix.ServerName, roomserverAPI.KindNew, - &respState, + // use the authorized state from CheckStateResponse + &fclient.RespState{ + StateEvents: gomatrixserverlib.NewEventJSONsFromEvents(stateEvents), + AuthEvents: gomatrixserverlib.NewEventJSONsFromEvents(authEvents), + }, respPeek.LatestEvent.Headered(respPeek.RoomVersion), serverName, nil, @@ -547,7 +554,7 @@ func (r *FederationInternalAPI) PerformLeave( // Set all the fields to be what they should be, this should be a no-op // but it's possible that the remote server returned us something "odd" - respMakeLeave.LeaveEvent.Type = gomatrixserverlib.MRoomMember + respMakeLeave.LeaveEvent.Type = spec.MRoomMember respMakeLeave.LeaveEvent.Sender = request.UserID respMakeLeave.LeaveEvent.StateKey = &request.UserID respMakeLeave.LeaveEvent.RoomID = request.RoomID @@ -643,7 +650,7 @@ func (r *FederationInternalAPI) PerformInvite( "destination": destination, }).Info("Sending invite") - inviteReq, err := gomatrixserverlib.NewInviteV2Request(request.Event, request.InviteRoomState) + inviteReq, err := fclient.NewInviteV2Request(request.Event, request.InviteRoomState) if err != nil { return fmt.Errorf("gomatrixserverlib.NewInviteV2Request: %w", err) } @@ -653,7 +660,7 @@ func (r *FederationInternalAPI) PerformInvite( return fmt.Errorf("r.federation.SendInviteV2: failed to send invite: %w", err) } - inviteEvent, err := inviteRes.Event.UntrustedEvent(request.RoomVersion) + inviteEvent, err := gomatrixserverlib.UntrustedEvent(inviteRes.Event, request.RoomVersion) if err != nil { return fmt.Errorf("r.federation.SendInviteV2 failed to decode event response: %w", err) } @@ -699,7 +706,7 @@ func (r *FederationInternalAPI) PerformWakeupServers( return nil } -func (r *FederationInternalAPI) MarkServersAlive(destinations []gomatrixserverlib.ServerName) { +func (r *FederationInternalAPI) MarkServersAlive(destinations []spec.ServerName) { for _, srv := range destinations { wasBlacklisted := r.statistics.ForServer(srv).MarkServerAlive() r.queues.RetryServer(srv, wasBlacklisted) @@ -709,7 +716,7 @@ func (r *FederationInternalAPI) MarkServersAlive(destinations []gomatrixserverli func sanityCheckAuthChain(authChain []*gomatrixserverlib.Event) error { // sanity check we have a create event and it has a known room version for _, ev := range authChain { - if ev.Type() == gomatrixserverlib.MRoomCreate && ev.StateKeyEquals("") { + if ev.Type() == spec.MRoomCreate && ev.StateKeyEquals("") { // make sure the room version is known content := ev.Content() verBody := struct { @@ -761,7 +768,7 @@ func setDefaultRoomVersionFromJoinEvent( // FederatedAuthProvider is an auth chain provider which fetches events from the server provided func federatedAuthProvider( ctx context.Context, federation api.FederationClient, - keyRing gomatrixserverlib.JSONVerifier, origin, server gomatrixserverlib.ServerName, + keyRing gomatrixserverlib.JSONVerifier, origin, server spec.ServerName, ) gomatrixserverlib.AuthChainProvider { // A list of events that we have retried, if they were not included in // the auth events supplied in the send_join. @@ -867,7 +874,7 @@ func (r *FederationInternalAPI) P2PRemoveRelayServers( } func (r *FederationInternalAPI) shouldAttemptDirectFederation( - destination gomatrixserverlib.ServerName, + destination spec.ServerName, ) bool { var shouldRelay bool stats := r.statistics.ForServer(destination) diff --git a/federationapi/internal/perform_test.go b/federationapi/internal/perform_test.go index e6e366f99..4f8362425 100644 --- a/federationapi/internal/perform_test.go +++ b/federationapi/internal/perform_test.go @@ -24,7 +24,8 @@ import ( "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/test" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/stretchr/testify/assert" ) @@ -35,14 +36,14 @@ type testFedClient struct { shouldFail bool } -func (t *testFedClient) LookupRoomAlias(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomAlias string) (res gomatrixserverlib.RespDirectory, err error) { - return gomatrixserverlib.RespDirectory{}, nil +func (t *testFedClient) LookupRoomAlias(ctx context.Context, origin, s spec.ServerName, roomAlias string) (res fclient.RespDirectory, err error) { + return fclient.RespDirectory{}, nil } func TestPerformWakeupServers(t *testing.T) { testDB := test.NewInMemoryFederationDatabase() - server := gomatrixserverlib.ServerName("wakeup") + server := spec.ServerName("wakeup") testDB.AddServerToBlacklist(server) testDB.SetServerAssumedOffline(context.Background(), server) blacklisted, err := testDB.IsServerBlacklisted(server) @@ -54,7 +55,7 @@ func TestPerformWakeupServers(t *testing.T) { cfg := config.FederationAPI{ Matrix: &config.Global{ - SigningIdentity: gomatrixserverlib.SigningIdentity{ + SigningIdentity: fclient.SigningIdentity{ ServerName: "relay", }, }, @@ -72,7 +73,7 @@ func TestPerformWakeupServers(t *testing.T) { ) req := api.PerformWakeupServersRequest{ - ServerNames: []gomatrixserverlib.ServerName{server}, + ServerNames: []spec.ServerName{server}, } res := api.PerformWakeupServersResponse{} err = fedAPI.PerformWakeupServers(context.Background(), &req, &res) @@ -89,14 +90,14 @@ func TestPerformWakeupServers(t *testing.T) { func TestQueryRelayServers(t *testing.T) { testDB := test.NewInMemoryFederationDatabase() - server := gomatrixserverlib.ServerName("wakeup") - relayServers := []gomatrixserverlib.ServerName{"relay1", "relay2"} + server := spec.ServerName("wakeup") + relayServers := []spec.ServerName{"relay1", "relay2"} err := testDB.P2PAddRelayServersForServer(context.Background(), server, relayServers) assert.NoError(t, err) cfg := config.FederationAPI{ Matrix: &config.Global{ - SigningIdentity: gomatrixserverlib.SigningIdentity{ + SigningIdentity: fclient.SigningIdentity{ ServerName: "relay", }, }, @@ -126,14 +127,14 @@ func TestQueryRelayServers(t *testing.T) { func TestRemoveRelayServers(t *testing.T) { testDB := test.NewInMemoryFederationDatabase() - server := gomatrixserverlib.ServerName("wakeup") - relayServers := []gomatrixserverlib.ServerName{"relay1", "relay2"} + server := spec.ServerName("wakeup") + relayServers := []spec.ServerName{"relay1", "relay2"} err := testDB.P2PAddRelayServersForServer(context.Background(), server, relayServers) assert.NoError(t, err) cfg := config.FederationAPI{ Matrix: &config.Global{ - SigningIdentity: gomatrixserverlib.SigningIdentity{ + SigningIdentity: fclient.SigningIdentity{ ServerName: "relay", }, }, @@ -152,7 +153,7 @@ func TestRemoveRelayServers(t *testing.T) { req := api.P2PRemoveRelayServersRequest{ Server: server, - RelayServers: []gomatrixserverlib.ServerName{"relay1"}, + RelayServers: []spec.ServerName{"relay1"}, } res := api.P2PRemoveRelayServersResponse{} err = fedAPI.P2PRemoveRelayServers(context.Background(), &req, &res) @@ -161,7 +162,7 @@ func TestRemoveRelayServers(t *testing.T) { finalRelays, err := testDB.P2PGetRelayServersForServer(context.Background(), server) assert.NoError(t, err) assert.Equal(t, 1, len(finalRelays)) - assert.Equal(t, gomatrixserverlib.ServerName("relay2"), finalRelays[0]) + assert.Equal(t, spec.ServerName("relay2"), finalRelays[0]) } func TestPerformDirectoryLookup(t *testing.T) { @@ -169,7 +170,7 @@ func TestPerformDirectoryLookup(t *testing.T) { cfg := config.FederationAPI{ Matrix: &config.Global{ - SigningIdentity: gomatrixserverlib.SigningIdentity{ + SigningIdentity: fclient.SigningIdentity{ ServerName: "relay", }, }, @@ -198,13 +199,13 @@ func TestPerformDirectoryLookup(t *testing.T) { func TestPerformDirectoryLookupRelaying(t *testing.T) { testDB := test.NewInMemoryFederationDatabase() - server := gomatrixserverlib.ServerName("wakeup") + server := spec.ServerName("wakeup") testDB.SetServerAssumedOffline(context.Background(), server) - testDB.P2PAddRelayServersForServer(context.Background(), server, []gomatrixserverlib.ServerName{"relay"}) + testDB.P2PAddRelayServersForServer(context.Background(), server, []spec.ServerName{"relay"}) cfg := config.FederationAPI{ Matrix: &config.Global{ - SigningIdentity: gomatrixserverlib.SigningIdentity{ + SigningIdentity: fclient.SigningIdentity{ ServerName: server, }, }, diff --git a/federationapi/internal/query.go b/federationapi/internal/query.go index 688afa8ea..e53f19ff8 100644 --- a/federationapi/internal/query.go +++ b/federationapi/internal/query.go @@ -7,6 +7,7 @@ import ( "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -25,7 +26,7 @@ func (f *FederationInternalAPI) QueryJoinedHostServerNamesInRoom( return } -func (a *FederationInternalAPI) fetchServerKeysDirectly(ctx context.Context, serverName gomatrixserverlib.ServerName) (*gomatrixserverlib.ServerKeys, error) { +func (a *FederationInternalAPI) fetchServerKeysDirectly(ctx context.Context, serverName spec.ServerName) (*gomatrixserverlib.ServerKeys, error) { ctx, cancel := context.WithTimeout(ctx, time.Second*30) defer cancel() ires, err := a.doRequestIfNotBackingOffOrBlacklisted(serverName, func() (interface{}, error) { diff --git a/federationapi/producers/syncapi.go b/federationapi/producers/syncapi.go index 6bcfafa39..ede56436a 100644 --- a/federationapi/producers/syncapi.go +++ b/federationapi/producers/syncapi.go @@ -22,6 +22,7 @@ import ( "time" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/nats-io/nats.go" log "github.com/sirupsen/logrus" @@ -46,7 +47,7 @@ type SyncAPIProducer struct { func (p *SyncAPIProducer) SendReceipt( ctx context.Context, - userID, roomID, eventID, receiptType string, timestamp gomatrixserverlib.Timestamp, + userID, roomID, eventID, receiptType string, timestamp spec.Timestamp, ) error { m := &nats.Msg{ Subject: p.TopicReceiptEvent, @@ -155,7 +156,7 @@ func (p *SyncAPIProducer) SendPresence( if statusMsg != nil { m.Header.Set("status_msg", *statusMsg) } - lastActiveTS := gomatrixserverlib.AsTimestamp(time.Now().Add(-(time.Duration(lastActiveAgo) * time.Millisecond))) + lastActiveTS := spec.AsTimestamp(time.Now().Add(-(time.Duration(lastActiveAgo) * time.Millisecond))) m.Header.Set("last_active_ts", strconv.Itoa(int(lastActiveTS))) log.Tracef("Sending presence to syncAPI: %+v", m.Header) @@ -164,7 +165,7 @@ func (p *SyncAPIProducer) SendPresence( } func (p *SyncAPIProducer) SendDeviceListUpdate( - ctx context.Context, deviceListUpdate gomatrixserverlib.RawJSON, origin gomatrixserverlib.ServerName, + ctx context.Context, deviceListUpdate spec.RawJSON, origin spec.ServerName, ) (err error) { m := nats.NewMsg(p.TopicDeviceListUpdate) m.Header.Set("origin", string(origin)) @@ -175,7 +176,7 @@ func (p *SyncAPIProducer) SendDeviceListUpdate( } func (p *SyncAPIProducer) SendSigningKeyUpdate( - ctx context.Context, data gomatrixserverlib.RawJSON, origin gomatrixserverlib.ServerName, + ctx context.Context, data spec.RawJSON, origin spec.ServerName, ) (err error) { m := nats.NewMsg(p.TopicSigningKeyUpdate) m.Header.Set("origin", string(origin)) diff --git a/federationapi/queue/destinationqueue.go b/federationapi/queue/destinationqueue.go index 12e6db9fa..c136fe751 100644 --- a/federationapi/queue/destinationqueue.go +++ b/federationapi/queue/destinationqueue.go @@ -23,6 +23,8 @@ import ( "github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/sirupsen/logrus" "go.uber.org/atomic" @@ -50,11 +52,11 @@ type destinationQueue struct { queues *OutgoingQueues db storage.Database process *process.ProcessContext - signing map[gomatrixserverlib.ServerName]*gomatrixserverlib.SigningIdentity + signing map[spec.ServerName]*fclient.SigningIdentity rsAPI api.FederationRoomserverAPI client fedapi.FederationClient // federation client - origin gomatrixserverlib.ServerName // origin of requests - destination gomatrixserverlib.ServerName // destination of requests + origin spec.ServerName // origin of requests + destination spec.ServerName // destination of requests running atomic.Bool // is the queue worker running? backingOff atomic.Bool // true if we're backing off overflowed atomic.Bool // the queues exceed maxPDUsInMemory/maxEDUsInMemory, so we should consult the database for more @@ -425,7 +427,7 @@ func (oq *destinationQueue) nextTransaction( relaySuccess := false logrus.Infof("Sending %q to relay servers: %v", t.TransactionID, relayServers) // TODO : how to pass through actual userID here?!?!?!?! - userID, userErr := gomatrixserverlib.NewUserID("@user:"+string(oq.destination), false) + userID, userErr := spec.NewUserID("@user:"+string(oq.destination), false) if userErr != nil { return userErr, sendMethod } @@ -506,7 +508,7 @@ func (oq *destinationQueue) createTransaction( // it so that we retry with the same transaction ID. oq.transactionIDMutex.Lock() if oq.transactionID == "" { - now := gomatrixserverlib.AsTimestamp(time.Now()) + now := spec.AsTimestamp(time.Now()) oq.transactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, oq.statistics.SuccessCount())) } oq.transactionIDMutex.Unlock() @@ -517,7 +519,7 @@ func (oq *destinationQueue) createTransaction( } t.Origin = oq.origin t.Destination = oq.destination - t.OriginServerTS = gomatrixserverlib.AsTimestamp(time.Now()) + t.OriginServerTS = spec.AsTimestamp(time.Now()) t.TransactionID = oq.transactionID var pduReceipts []*receipt.Receipt diff --git a/federationapi/queue/queue.go b/federationapi/queue/queue.go index 5d6b8d44c..860446a05 100644 --- a/federationapi/queue/queue.go +++ b/federationapi/queue/queue.go @@ -22,6 +22,8 @@ import ( "github.com/getsentry/sentry-go" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/prometheus/client_golang/prometheus" "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus" @@ -42,12 +44,12 @@ type OutgoingQueues struct { process *process.ProcessContext disabled bool rsAPI api.FederationRoomserverAPI - origin gomatrixserverlib.ServerName + origin spec.ServerName client fedapi.FederationClient statistics *statistics.Statistics - signing map[gomatrixserverlib.ServerName]*gomatrixserverlib.SigningIdentity + signing map[spec.ServerName]*fclient.SigningIdentity queuesMutex sync.Mutex // protects the below - queues map[gomatrixserverlib.ServerName]*destinationQueue + queues map[spec.ServerName]*destinationQueue } func init() { @@ -86,11 +88,11 @@ func NewOutgoingQueues( db storage.Database, process *process.ProcessContext, disabled bool, - origin gomatrixserverlib.ServerName, + origin spec.ServerName, client fedapi.FederationClient, rsAPI api.FederationRoomserverAPI, statistics *statistics.Statistics, - signing []*gomatrixserverlib.SigningIdentity, + signing []*fclient.SigningIdentity, ) *OutgoingQueues { queues := &OutgoingQueues{ disabled: disabled, @@ -100,15 +102,15 @@ func NewOutgoingQueues( origin: origin, client: client, statistics: statistics, - signing: map[gomatrixserverlib.ServerName]*gomatrixserverlib.SigningIdentity{}, - queues: map[gomatrixserverlib.ServerName]*destinationQueue{}, + signing: map[spec.ServerName]*fclient.SigningIdentity{}, + queues: map[spec.ServerName]*destinationQueue{}, } for _, identity := range signing { queues.signing[identity.ServerName] = identity } // Look up which servers we have pending items for and then rehydrate those queues. if !disabled { - serverNames := map[gomatrixserverlib.ServerName]struct{}{} + serverNames := map[spec.ServerName]struct{}{} if names, err := db.GetPendingPDUServerNames(process.Context()); err == nil { for _, serverName := range names { serverNames[serverName] = struct{}{} @@ -147,7 +149,7 @@ type queuedEDU struct { edu *gomatrixserverlib.EDU } -func (oqs *OutgoingQueues) getQueue(destination gomatrixserverlib.ServerName) *destinationQueue { +func (oqs *OutgoingQueues) getQueue(destination spec.ServerName) *destinationQueue { if oqs.statistics.ForServer(destination).Blacklisted() { return nil } @@ -186,8 +188,8 @@ func (oqs *OutgoingQueues) clearQueue(oq *destinationQueue) { // SendEvent sends an event to the destinations func (oqs *OutgoingQueues) SendEvent( - ev *gomatrixserverlib.HeaderedEvent, origin gomatrixserverlib.ServerName, - destinations []gomatrixserverlib.ServerName, + ev *gomatrixserverlib.HeaderedEvent, origin spec.ServerName, + destinations []spec.ServerName, ) error { if oqs.disabled { log.Trace("Federation is disabled, not sending event") @@ -202,7 +204,7 @@ func (oqs *OutgoingQueues) SendEvent( // Deduplicate destinations and remove the origin from the list of // destinations just to be sure. - destmap := map[gomatrixserverlib.ServerName]struct{}{} + destmap := map[spec.ServerName]struct{}{} for _, d := range destinations { destmap[d] = struct{}{} } @@ -276,8 +278,8 @@ func (oqs *OutgoingQueues) SendEvent( // SendEDU sends an EDU event to the destinations. func (oqs *OutgoingQueues) SendEDU( - e *gomatrixserverlib.EDU, origin gomatrixserverlib.ServerName, - destinations []gomatrixserverlib.ServerName, + e *gomatrixserverlib.EDU, origin spec.ServerName, + destinations []spec.ServerName, ) error { if oqs.disabled { log.Trace("Federation is disabled, not sending EDU") @@ -292,7 +294,7 @@ func (oqs *OutgoingQueues) SendEDU( // Deduplicate destinations and remove the origin from the list of // destinations just to be sure. - destmap := map[gomatrixserverlib.ServerName]struct{}{} + destmap := map[spec.ServerName]struct{}{} for _, d := range destinations { destmap[d] = struct{}{} } @@ -375,7 +377,7 @@ func (oqs *OutgoingQueues) SendEDU( } // RetryServer attempts to resend events to the given server if we had given up. -func (oqs *OutgoingQueues) RetryServer(srv gomatrixserverlib.ServerName, wasBlacklisted bool) { +func (oqs *OutgoingQueues) RetryServer(srv spec.ServerName, wasBlacklisted bool) { if oqs.disabled { return } diff --git a/federationapi/queue/queue_test.go b/federationapi/queue/queue_test.go index bccfb3428..b153cba07 100644 --- a/federationapi/queue/queue_test.go +++ b/federationapi/queue/queue_test.go @@ -21,6 +21,11 @@ import ( "testing" "time" + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/test/testrig" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "go.uber.org/atomic" "gotest.tools/v3/poll" @@ -34,31 +39,29 @@ import ( "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/test" - "github.com/matrix-org/dendrite/test/testrig" ) func mustCreateFederationDatabase(t *testing.T, dbType test.DBType, realDatabase bool) (storage.Database, *process.ProcessContext, func()) { if realDatabase { // Real Database/s - b, baseClose := testrig.CreateBaseDendrite(t, dbType) + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) connStr, dbClose := test.PrepareDBConnectionString(t, dbType) - db, err := storage.NewDatabase(b, &config.DatabaseOptions{ + db, err := storage.NewDatabase(processCtx.Context(), cm, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), - }, b.Caches, b.Cfg.Global.IsLocalServerName) + }, caches, cfg.Global.IsLocalServerName) if err != nil { t.Fatalf("NewDatabase returned %s", err) } - return db, b.ProcessContext, func() { + return db, processCtx, func() { + close() dbClose() - baseClose() } } else { // Fake Database db := test.NewInMemoryFederationDatabase() - b := struct { - ProcessContext *process.ProcessContext - }{ProcessContext: process.NewProcessContext()} - return db, b.ProcessContext, func() {} + return db, process.NewProcessContext(), func() {} } } @@ -79,24 +82,24 @@ type stubFederationClient struct { txRelayCount atomic.Uint32 } -func (f *stubFederationClient) SendTransaction(ctx context.Context, t gomatrixserverlib.Transaction) (res gomatrixserverlib.RespSend, err error) { +func (f *stubFederationClient) SendTransaction(ctx context.Context, t gomatrixserverlib.Transaction) (res fclient.RespSend, err error) { var result error if !f.shouldTxSucceed { result = fmt.Errorf("transaction failed") } f.txCount.Add(1) - return gomatrixserverlib.RespSend{}, result + return fclient.RespSend{}, result } -func (f *stubFederationClient) P2PSendTransactionToRelay(ctx context.Context, u gomatrixserverlib.UserID, t gomatrixserverlib.Transaction, forwardingServer gomatrixserverlib.ServerName) (res gomatrixserverlib.EmptyResp, err error) { +func (f *stubFederationClient) P2PSendTransactionToRelay(ctx context.Context, u spec.UserID, t gomatrixserverlib.Transaction, forwardingServer spec.ServerName) (res fclient.EmptyResp, err error) { var result error if !f.shouldTxRelaySucceed { result = fmt.Errorf("relay transaction failed") } f.txRelayCount.Add(1) - return gomatrixserverlib.EmptyResp{}, result + return fclient.EmptyResp{}, result } func mustCreatePDU(t *testing.T) *gomatrixserverlib.HeaderedEvent { @@ -111,7 +114,7 @@ func mustCreatePDU(t *testing.T) *gomatrixserverlib.HeaderedEvent { func mustCreateEDU(t *testing.T) *gomatrixserverlib.EDU { t.Helper() - return &gomatrixserverlib.EDU{Type: gomatrixserverlib.MTyping} + return &gomatrixserverlib.EDU{Type: spec.MTyping} } func testSetup(failuresUntilBlacklist uint32, failuresUntilAssumedOffline uint32, shouldTxSucceed bool, shouldTxRelaySucceed bool, t *testing.T, dbType test.DBType, realDatabase bool) (storage.Database, *stubFederationClient, *OutgoingQueues, *process.ProcessContext, func()) { @@ -126,7 +129,7 @@ func testSetup(failuresUntilBlacklist uint32, failuresUntilAssumedOffline uint32 rs := &stubFederationRoomServerAPI{} stats := statistics.NewStatistics(db, failuresUntilBlacklist, failuresUntilAssumedOffline) - signingInfo := []*gomatrixserverlib.SigningIdentity{ + signingInfo := []*fclient.SigningIdentity{ { KeyID: "ed21019:auto", PrivateKey: test.PrivateKeyA, @@ -141,7 +144,7 @@ func testSetup(failuresUntilBlacklist uint32, failuresUntilAssumedOffline uint32 func TestSendPDUOnSuccessRemovedFromDB(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) - destination := gomatrixserverlib.ServerName("remotehost") + destination := spec.ServerName("remotehost") db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, true, false, t, test.DBTypeSQLite, false) defer close() defer func() { @@ -150,7 +153,7 @@ func TestSendPDUOnSuccessRemovedFromDB(t *testing.T) { }() ev := mustCreatePDU(t) - err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + err := queues.SendEvent(ev, "localhost", []spec.ServerName{destination}) assert.NoError(t, err) check := func(log poll.LogT) poll.Result { @@ -170,7 +173,7 @@ func TestSendPDUOnSuccessRemovedFromDB(t *testing.T) { func TestSendEDUOnSuccessRemovedFromDB(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) - destination := gomatrixserverlib.ServerName("remotehost") + destination := spec.ServerName("remotehost") db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, true, false, t, test.DBTypeSQLite, false) defer close() defer func() { @@ -179,7 +182,7 @@ func TestSendEDUOnSuccessRemovedFromDB(t *testing.T) { }() ev := mustCreateEDU(t) - err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + err := queues.SendEDU(ev, "localhost", []spec.ServerName{destination}) assert.NoError(t, err) check := func(log poll.LogT) poll.Result { @@ -199,7 +202,7 @@ func TestSendEDUOnSuccessRemovedFromDB(t *testing.T) { func TestSendPDUOnFailStoredInDB(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) - destination := gomatrixserverlib.ServerName("remotehost") + destination := spec.ServerName("remotehost") db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { @@ -208,7 +211,7 @@ func TestSendPDUOnFailStoredInDB(t *testing.T) { }() ev := mustCreatePDU(t) - err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + err := queues.SendEvent(ev, "localhost", []spec.ServerName{destination}) assert.NoError(t, err) check := func(log poll.LogT) poll.Result { @@ -229,7 +232,7 @@ func TestSendPDUOnFailStoredInDB(t *testing.T) { func TestSendEDUOnFailStoredInDB(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) - destination := gomatrixserverlib.ServerName("remotehost") + destination := spec.ServerName("remotehost") db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { @@ -238,7 +241,7 @@ func TestSendEDUOnFailStoredInDB(t *testing.T) { }() ev := mustCreateEDU(t) - err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + err := queues.SendEDU(ev, "localhost", []spec.ServerName{destination}) assert.NoError(t, err) check := func(log poll.LogT) poll.Result { @@ -259,7 +262,7 @@ func TestSendEDUOnFailStoredInDB(t *testing.T) { func TestSendPDUAgainDoesntInterruptBackoff(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) - destination := gomatrixserverlib.ServerName("remotehost") + destination := spec.ServerName("remotehost") db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { @@ -268,7 +271,7 @@ func TestSendPDUAgainDoesntInterruptBackoff(t *testing.T) { }() ev := mustCreatePDU(t) - err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + err := queues.SendEvent(ev, "localhost", []spec.ServerName{destination}) assert.NoError(t, err) check := func(log poll.LogT) poll.Result { @@ -287,7 +290,7 @@ func TestSendPDUAgainDoesntInterruptBackoff(t *testing.T) { fc.shouldTxSucceed = true ev = mustCreatePDU(t) - err = queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + err = queues.SendEvent(ev, "localhost", []spec.ServerName{destination}) assert.NoError(t, err) pollEnd := time.Now().Add(1 * time.Second) @@ -310,7 +313,7 @@ func TestSendPDUAgainDoesntInterruptBackoff(t *testing.T) { func TestSendEDUAgainDoesntInterruptBackoff(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) - destination := gomatrixserverlib.ServerName("remotehost") + destination := spec.ServerName("remotehost") db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { @@ -319,7 +322,7 @@ func TestSendEDUAgainDoesntInterruptBackoff(t *testing.T) { }() ev := mustCreateEDU(t) - err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + err := queues.SendEDU(ev, "localhost", []spec.ServerName{destination}) assert.NoError(t, err) check := func(log poll.LogT) poll.Result { @@ -338,7 +341,7 @@ func TestSendEDUAgainDoesntInterruptBackoff(t *testing.T) { fc.shouldTxSucceed = true ev = mustCreateEDU(t) - err = queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + err = queues.SendEDU(ev, "localhost", []spec.ServerName{destination}) assert.NoError(t, err) pollEnd := time.Now().Add(1 * time.Second) @@ -361,7 +364,7 @@ func TestSendEDUAgainDoesntInterruptBackoff(t *testing.T) { func TestSendPDUMultipleFailuresBlacklisted(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(2) - destination := gomatrixserverlib.ServerName("remotehost") + destination := spec.ServerName("remotehost") db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { @@ -370,7 +373,7 @@ func TestSendPDUMultipleFailuresBlacklisted(t *testing.T) { }() ev := mustCreatePDU(t) - err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + err := queues.SendEvent(ev, "localhost", []spec.ServerName{destination}) assert.NoError(t, err) check := func(log poll.LogT) poll.Result { @@ -393,7 +396,7 @@ func TestSendPDUMultipleFailuresBlacklisted(t *testing.T) { func TestSendEDUMultipleFailuresBlacklisted(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(2) - destination := gomatrixserverlib.ServerName("remotehost") + destination := spec.ServerName("remotehost") db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { @@ -402,7 +405,7 @@ func TestSendEDUMultipleFailuresBlacklisted(t *testing.T) { }() ev := mustCreateEDU(t) - err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + err := queues.SendEDU(ev, "localhost", []spec.ServerName{destination}) assert.NoError(t, err) check := func(log poll.LogT) poll.Result { @@ -425,7 +428,7 @@ func TestSendEDUMultipleFailuresBlacklisted(t *testing.T) { func TestSendPDUBlacklistedWithPriorExternalFailure(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(2) - destination := gomatrixserverlib.ServerName("remotehost") + destination := spec.ServerName("remotehost") db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { @@ -436,7 +439,7 @@ func TestSendPDUBlacklistedWithPriorExternalFailure(t *testing.T) { queues.statistics.ForServer(destination).Failure() ev := mustCreatePDU(t) - err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + err := queues.SendEvent(ev, "localhost", []spec.ServerName{destination}) assert.NoError(t, err) check := func(log poll.LogT) poll.Result { @@ -459,7 +462,7 @@ func TestSendPDUBlacklistedWithPriorExternalFailure(t *testing.T) { func TestSendEDUBlacklistedWithPriorExternalFailure(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(2) - destination := gomatrixserverlib.ServerName("remotehost") + destination := spec.ServerName("remotehost") db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { @@ -470,7 +473,7 @@ func TestSendEDUBlacklistedWithPriorExternalFailure(t *testing.T) { queues.statistics.ForServer(destination).Failure() ev := mustCreateEDU(t) - err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + err := queues.SendEDU(ev, "localhost", []spec.ServerName{destination}) assert.NoError(t, err) check := func(log poll.LogT) poll.Result { @@ -493,7 +496,7 @@ func TestSendEDUBlacklistedWithPriorExternalFailure(t *testing.T) { func TestRetryServerSendsPDUSuccessfully(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(1) - destination := gomatrixserverlib.ServerName("remotehost") + destination := spec.ServerName("remotehost") db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { @@ -505,7 +508,7 @@ func TestRetryServerSendsPDUSuccessfully(t *testing.T) { // before it is blacklisted and deleted. dest := queues.getQueue(destination) ev := mustCreatePDU(t) - err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + err := queues.SendEvent(ev, "localhost", []spec.ServerName{destination}) assert.NoError(t, err) checkBlacklisted := func(log poll.LogT) poll.Result { @@ -544,7 +547,7 @@ func TestRetryServerSendsPDUSuccessfully(t *testing.T) { func TestRetryServerSendsEDUSuccessfully(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(1) - destination := gomatrixserverlib.ServerName("remotehost") + destination := spec.ServerName("remotehost") db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { @@ -556,7 +559,7 @@ func TestRetryServerSendsEDUSuccessfully(t *testing.T) { // before it is blacklisted and deleted. dest := queues.getQueue(destination) ev := mustCreateEDU(t) - err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + err := queues.SendEDU(ev, "localhost", []spec.ServerName{destination}) assert.NoError(t, err) checkBlacklisted := func(log poll.LogT) poll.Result { @@ -595,7 +598,7 @@ func TestRetryServerSendsEDUSuccessfully(t *testing.T) { func TestSendPDUBatches(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) - destination := gomatrixserverlib.ServerName("remotehost") + destination := spec.ServerName("remotehost") // test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { // db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, dbType, true) @@ -606,7 +609,7 @@ func TestSendPDUBatches(t *testing.T) { <-pc.WaitForShutdown() }() - destinations := map[gomatrixserverlib.ServerName]struct{}{destination: {}} + destinations := map[spec.ServerName]struct{}{destination: {}} // Populate database with > maxPDUsPerTransaction pduMultiplier := uint32(3) for i := 0; i < maxPDUsPerTransaction*int(pduMultiplier); i++ { @@ -618,7 +621,7 @@ func TestSendPDUBatches(t *testing.T) { } ev := mustCreatePDU(t) - err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + err := queues.SendEvent(ev, "localhost", []spec.ServerName{destination}) assert.NoError(t, err) check := func(log poll.LogT) poll.Result { @@ -639,7 +642,7 @@ func TestSendPDUBatches(t *testing.T) { func TestSendEDUBatches(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) - destination := gomatrixserverlib.ServerName("remotehost") + destination := spec.ServerName("remotehost") // test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { // db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, dbType, true) @@ -650,7 +653,7 @@ func TestSendEDUBatches(t *testing.T) { <-pc.WaitForShutdown() }() - destinations := map[gomatrixserverlib.ServerName]struct{}{destination: {}} + destinations := map[spec.ServerName]struct{}{destination: {}} // Populate database with > maxEDUsPerTransaction eduMultiplier := uint32(3) for i := 0; i < maxEDUsPerTransaction*int(eduMultiplier); i++ { @@ -662,7 +665,7 @@ func TestSendEDUBatches(t *testing.T) { } ev := mustCreateEDU(t) - err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + err := queues.SendEDU(ev, "localhost", []spec.ServerName{destination}) assert.NoError(t, err) check := func(log poll.LogT) poll.Result { @@ -683,7 +686,7 @@ func TestSendEDUBatches(t *testing.T) { func TestSendPDUAndEDUBatches(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) - destination := gomatrixserverlib.ServerName("remotehost") + destination := spec.ServerName("remotehost") // test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { // db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, dbType, true) @@ -694,7 +697,7 @@ func TestSendPDUAndEDUBatches(t *testing.T) { <-pc.WaitForShutdown() }() - destinations := map[gomatrixserverlib.ServerName]struct{}{destination: {}} + destinations := map[spec.ServerName]struct{}{destination: {}} // Populate database with > maxEDUsPerTransaction multiplier := uint32(3) for i := 0; i < maxPDUsPerTransaction*int(multiplier)+1; i++ { @@ -714,7 +717,7 @@ func TestSendPDUAndEDUBatches(t *testing.T) { } ev := mustCreateEDU(t) - err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + err := queues.SendEDU(ev, "localhost", []spec.ServerName{destination}) assert.NoError(t, err) check := func(log poll.LogT) poll.Result { @@ -737,7 +740,7 @@ func TestSendPDUAndEDUBatches(t *testing.T) { func TestExternalFailureBackoffDoesntStartQueue(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) - destination := gomatrixserverlib.ServerName("remotehost") + destination := spec.ServerName("remotehost") db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, true, false, t, test.DBTypeSQLite, false) defer close() defer func() { @@ -747,7 +750,7 @@ func TestExternalFailureBackoffDoesntStartQueue(t *testing.T) { dest := queues.getQueue(destination) queues.statistics.ForServer(destination).Failure() - destinations := map[gomatrixserverlib.ServerName]struct{}{destination: {}} + destinations := map[spec.ServerName]struct{}{destination: {}} ev := mustCreatePDU(t) headeredJSON, _ := json.Marshal(ev) nid, _ := db.StoreJSON(pc.Context(), string(headeredJSON)) @@ -773,8 +776,8 @@ func TestQueueInteractsWithRealDatabasePDUAndEDU(t *testing.T) { // NOTE : Only one test case against real databases can be run at a time. t.Parallel() failuresUntilBlacklist := uint32(1) - destination := gomatrixserverlib.ServerName("remotehost") - destinations := map[gomatrixserverlib.ServerName]struct{}{destination: {}} + destination := spec.ServerName("remotehost") + destinations := map[spec.ServerName]struct{}{destination: {}} test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, dbType, true) // NOTE : These defers aren't called if go test is killed so the dbs may not get cleaned up. @@ -788,7 +791,7 @@ func TestQueueInteractsWithRealDatabasePDUAndEDU(t *testing.T) { // before it is blacklisted and deleted. dest := queues.getQueue(destination) ev := mustCreatePDU(t) - err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + err := queues.SendEvent(ev, "localhost", []spec.ServerName{destination}) assert.NoError(t, err) // NOTE : The server can be blacklisted before this, so manually inject the event @@ -841,7 +844,7 @@ func TestSendPDUMultipleFailuresAssumedOffline(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(7) failuresUntilAssumedOffline := uint32(2) - destination := gomatrixserverlib.ServerName("remotehost") + destination := spec.ServerName("remotehost") db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilAssumedOffline, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { @@ -850,7 +853,7 @@ func TestSendPDUMultipleFailuresAssumedOffline(t *testing.T) { }() ev := mustCreatePDU(t) - err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + err := queues.SendEvent(ev, "localhost", []spec.ServerName{destination}) assert.NoError(t, err) check := func(log poll.LogT) poll.Result { @@ -874,7 +877,7 @@ func TestSendEDUMultipleFailuresAssumedOffline(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(7) failuresUntilAssumedOffline := uint32(2) - destination := gomatrixserverlib.ServerName("remotehost") + destination := spec.ServerName("remotehost") db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilAssumedOffline, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { @@ -883,7 +886,7 @@ func TestSendEDUMultipleFailuresAssumedOffline(t *testing.T) { }() ev := mustCreateEDU(t) - err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + err := queues.SendEDU(ev, "localhost", []spec.ServerName{destination}) assert.NoError(t, err) check := func(log poll.LogT) poll.Result { @@ -907,7 +910,7 @@ func TestSendPDUOnRelaySuccessRemovedFromDB(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) failuresUntilAssumedOffline := uint32(1) - destination := gomatrixserverlib.ServerName("remotehost") + destination := spec.ServerName("remotehost") db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilAssumedOffline, false, true, t, test.DBTypeSQLite, false) defer close() defer func() { @@ -915,11 +918,11 @@ func TestSendPDUOnRelaySuccessRemovedFromDB(t *testing.T) { <-pc.WaitForShutdown() }() - relayServers := []gomatrixserverlib.ServerName{"relayserver"} + relayServers := []spec.ServerName{"relayserver"} queues.statistics.ForServer(destination).AddRelayServers(relayServers) ev := mustCreatePDU(t) - err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + err := queues.SendEvent(ev, "localhost", []spec.ServerName{destination}) assert.NoError(t, err) check := func(log poll.LogT) poll.Result { @@ -946,7 +949,7 @@ func TestSendEDUOnRelaySuccessRemovedFromDB(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) failuresUntilAssumedOffline := uint32(1) - destination := gomatrixserverlib.ServerName("remotehost") + destination := spec.ServerName("remotehost") db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilAssumedOffline, false, true, t, test.DBTypeSQLite, false) defer close() defer func() { @@ -954,11 +957,11 @@ func TestSendEDUOnRelaySuccessRemovedFromDB(t *testing.T) { <-pc.WaitForShutdown() }() - relayServers := []gomatrixserverlib.ServerName{"relayserver"} + relayServers := []spec.ServerName{"relayserver"} queues.statistics.ForServer(destination).AddRelayServers(relayServers) ev := mustCreateEDU(t) - err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + err := queues.SendEDU(ev, "localhost", []spec.ServerName{destination}) assert.NoError(t, err) check := func(log poll.LogT) poll.Result { diff --git a/federationapi/routing/backfill.go b/federationapi/routing/backfill.go index 272f5e9d8..e99d54a63 100644 --- a/federationapi/routing/backfill.go +++ b/federationapi/routing/backfill.go @@ -25,6 +25,8 @@ import ( "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -32,7 +34,7 @@ import ( // https://matrix.org/docs/spec/server_server/unstable.html#get-matrix-federation-v1-backfill-roomid func Backfill( httpReq *http.Request, - request *gomatrixserverlib.FederationRequest, + request *fclient.FederationRequest, rsAPI api.FederationRoomserverAPI, roomID string, cfg *config.FederationAPI, @@ -126,7 +128,7 @@ func Backfill( txn := gomatrixserverlib.Transaction{ Origin: request.Destination(), PDUs: eventJSONs, - OriginServerTS: gomatrixserverlib.AsTimestamp(time.Now()), + OriginServerTS: spec.AsTimestamp(time.Now()), } // Send the events to the client. diff --git a/federationapi/routing/devices.go b/federationapi/routing/devices.go index 871d26cd4..6a2ef1527 100644 --- a/federationapi/routing/devices.go +++ b/federationapi/routing/devices.go @@ -19,6 +19,8 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/tidwall/gjson" ) @@ -53,10 +55,10 @@ func GetUserDevices( return jsonerror.InternalAPIError(req.Context(), err) } - response := gomatrixserverlib.RespUserDevices{ + response := fclient.RespUserDevices{ UserID: userID, StreamID: res.StreamID, - Devices: []gomatrixserverlib.RespUserDevice{}, + Devices: []fclient.RespUserDevice{}, } if masterKey, ok := sigRes.MasterKeys[userID]; ok { @@ -67,7 +69,7 @@ func GetUserDevices( } for _, dev := range res.Devices { - var key gomatrixserverlib.RespUserDeviceKeys + var key fclient.RespUserDeviceKeys err := json.Unmarshal(dev.DeviceKeys.KeyJSON, &key) if err != nil { util.GetLogger(req.Context()).WithError(err).Warnf("malformed device key: %s", string(dev.DeviceKeys.KeyJSON)) @@ -79,7 +81,7 @@ func GetUserDevices( displayName = gjson.GetBytes(dev.DeviceKeys.KeyJSON, "unsigned.device_display_name").Str } - device := gomatrixserverlib.RespUserDevice{ + device := fclient.RespUserDevice{ DeviceID: dev.DeviceID, DisplayName: displayName, Keys: key, @@ -90,10 +92,10 @@ func GetUserDevices( for sourceUserID, forSourceUser := range targetKey { for sourceKeyID, sourceKey := range forSourceUser { if device.Keys.Signatures == nil { - device.Keys.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} + device.Keys.Signatures = map[string]map[gomatrixserverlib.KeyID]spec.Base64Bytes{} } if _, ok := device.Keys.Signatures[sourceUserID]; !ok { - device.Keys.Signatures[sourceUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} + device.Keys.Signatures[sourceUserID] = map[gomatrixserverlib.KeyID]spec.Base64Bytes{} } device.Keys.Signatures[sourceUserID][sourceKeyID] = sourceKey } diff --git a/federationapi/routing/eventauth.go b/federationapi/routing/eventauth.go index 2f1f3baf6..89db3e98e 100644 --- a/federationapi/routing/eventauth.go +++ b/federationapi/routing/eventauth.go @@ -19,13 +19,14 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/util" ) // GetEventAuth returns event auth for the roomID and eventID func GetEventAuth( ctx context.Context, - request *gomatrixserverlib.FederationRequest, + request *fclient.FederationRequest, rsAPI api.FederationRoomserverAPI, roomID string, eventID string, @@ -70,7 +71,7 @@ func GetEventAuth( return util.JSONResponse{ Code: http.StatusOK, - JSON: gomatrixserverlib.RespEventAuth{ + JSON: fclient.RespEventAuth{ AuthEvents: gomatrixserverlib.NewEventJSONsFromHeaderedEvents(response.AuthChainEvents), }, } diff --git a/federationapi/routing/events.go b/federationapi/routing/events.go index b41292415..e25473890 100644 --- a/federationapi/routing/events.go +++ b/federationapi/routing/events.go @@ -21,6 +21,8 @@ import ( "time" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/matrix-org/dendrite/clientapi/jsonerror" @@ -30,10 +32,10 @@ import ( // GetEvent returns the requested event func GetEvent( ctx context.Context, - request *gomatrixserverlib.FederationRequest, + request *fclient.FederationRequest, rsAPI api.FederationRoomserverAPI, eventID string, - origin gomatrixserverlib.ServerName, + origin spec.ServerName, ) util.JSONResponse { err := allowedToSeeEvent(ctx, request.Origin(), rsAPI, eventID) if err != nil { @@ -48,7 +50,7 @@ func GetEvent( return util.JSONResponse{Code: http.StatusOK, JSON: gomatrixserverlib.Transaction{ Origin: origin, - OriginServerTS: gomatrixserverlib.AsTimestamp(time.Now()), + OriginServerTS: spec.AsTimestamp(time.Now()), PDUs: []json.RawMessage{ event.JSON(), }, @@ -59,7 +61,7 @@ func GetEvent( // otherwise it returns an error response which can be sent to the client. func allowedToSeeEvent( ctx context.Context, - origin gomatrixserverlib.ServerName, + origin spec.ServerName, rsAPI api.FederationRoomserverAPI, eventID string, ) *util.JSONResponse { diff --git a/federationapi/routing/invite.go b/federationapi/routing/invite.go index f424fcacd..b188041ed 100644 --- a/federationapi/routing/invite.go +++ b/federationapi/routing/invite.go @@ -25,20 +25,21 @@ import ( roomserverVersion "github.com/matrix-org/dendrite/roomserver/version" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/util" ) // InviteV2 implements /_matrix/federation/v2/invite/{roomID}/{eventID} func InviteV2( httpReq *http.Request, - request *gomatrixserverlib.FederationRequest, + request *fclient.FederationRequest, roomID string, eventID string, cfg *config.FederationAPI, rsAPI api.FederationRoomserverAPI, keys gomatrixserverlib.JSONVerifier, ) util.JSONResponse { - inviteReq := gomatrixserverlib.InviteV2Request{} + inviteReq := fclient.InviteV2Request{} err := json.Unmarshal(request.Content(), &inviteReq) switch e := err.(type) { case gomatrixserverlib.UnsupportedRoomVersionError: @@ -68,7 +69,7 @@ func InviteV2( // InviteV1 implements /_matrix/federation/v1/invite/{roomID}/{eventID} func InviteV1( httpReq *http.Request, - request *gomatrixserverlib.FederationRequest, + request *fclient.FederationRequest, roomID string, eventID string, cfg *config.FederationAPI, @@ -91,7 +92,7 @@ func InviteV1( JSON: jsonerror.NotJSON("The request body could not be decoded into an invite v1 request. " + err.Error()), } } - var strippedState []gomatrixserverlib.InviteV2StrippedState + var strippedState []fclient.InviteV2StrippedState if err := json.Unmarshal(event.Unsigned(), &strippedState); err != nil { // just warn, they may not have added any. util.GetLogger(httpReq.Context()).Warnf("failed to extract stripped state from invite event") @@ -106,7 +107,7 @@ func processInvite( isInviteV2 bool, event *gomatrixserverlib.Event, roomVer gomatrixserverlib.RoomVersion, - strippedState []gomatrixserverlib.InviteV2StrippedState, + strippedState []fclient.InviteV2StrippedState, roomID string, eventID string, cfg *config.FederationAPI, @@ -218,12 +219,12 @@ func processInvite( if isInviteV2 { return util.JSONResponse{ Code: http.StatusOK, - JSON: gomatrixserverlib.RespInviteV2{Event: signedEvent.JSON()}, + JSON: fclient.RespInviteV2{Event: signedEvent.JSON()}, } } else { return util.JSONResponse{ Code: http.StatusOK, - JSON: gomatrixserverlib.RespInvite{Event: signedEvent.JSON()}, + JSON: fclient.RespInvite{Event: signedEvent.JSON()}, } } } diff --git a/federationapi/routing/join.go b/federationapi/routing/join.go index 03809df75..750d5ce9d 100644 --- a/federationapi/routing/join.go +++ b/federationapi/routing/join.go @@ -22,6 +22,8 @@ import ( "time" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/sirupsen/logrus" @@ -34,7 +36,7 @@ import ( // MakeJoin implements the /make_join API func MakeJoin( httpReq *http.Request, - request *gomatrixserverlib.FederationRequest, + request *fclient.FederationRequest, cfg *config.FederationAPI, rsAPI api.FederationRoomserverAPI, roomID, userID string, @@ -123,7 +125,7 @@ func MakeJoin( StateKey: &userID, } content := gomatrixserverlib.MemberContent{ - Membership: gomatrixserverlib.Join, + Membership: spec.Join, AuthorisedVia: authorisedVia, } if err = builder.SetContent(content); err != nil { @@ -189,7 +191,7 @@ func MakeJoin( // nolint:gocyclo func SendJoin( httpReq *http.Request, - request *gomatrixserverlib.FederationRequest, + request *fclient.FederationRequest, cfg *config.FederationAPI, rsAPI api.FederationRoomserverAPI, keys gomatrixserverlib.JSONVerifier, @@ -230,7 +232,7 @@ func SendJoin( // Check that the sender belongs to the server that is sending us // the request. By this point we've already asserted that the sender // and the state key are equal so we don't need to check both. - var serverName gomatrixserverlib.ServerName + var serverName spec.ServerName if _, serverName, err = gomatrixserverlib.SplitID('@', event.Sender()); err != nil { return util.JSONResponse{ Code: http.StatusForbidden, @@ -277,7 +279,7 @@ func SendJoin( JSON: jsonerror.BadJSON("missing content.membership key"), } } - if membership != gomatrixserverlib.Join { + if membership != spec.Join { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: jsonerror.BadJSON("membership must be 'join'"), @@ -348,8 +350,8 @@ func SendJoin( continue } if membership, merr := se.Membership(); merr == nil { - alreadyJoined = (membership == gomatrixserverlib.Join) - isBanned = (membership == gomatrixserverlib.Ban) + alreadyJoined = (membership == spec.Join) + isBanned = (membership == spec.Ban) break } } @@ -433,7 +435,7 @@ func SendJoin( // https://matrix.org/docs/spec/server_server/latest#put-matrix-federation-v1-send-join-roomid-eventid return util.JSONResponse{ Code: http.StatusOK, - JSON: gomatrixserverlib.RespSendJoin{ + JSON: fclient.RespSendJoin{ StateEvents: gomatrixserverlib.NewEventJSONsFromHeaderedEvents(stateAndAuthChainResponse.StateEvents), AuthEvents: gomatrixserverlib.NewEventJSONsFromHeaderedEvents(stateAndAuthChainResponse.AuthChainEvents), Origin: cfg.Matrix.ServerName, diff --git a/federationapi/routing/keys.go b/federationapi/routing/keys.go index 2885cc916..6c30e5b06 100644 --- a/federationapi/routing/keys.go +++ b/federationapi/routing/keys.go @@ -25,6 +25,8 @@ import ( "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/sirupsen/logrus" "golang.org/x/crypto/ed25519" @@ -37,7 +39,7 @@ type queryKeysRequest struct { // QueryDeviceKeys returns device keys for users on this server. // https://matrix.org/docs/spec/server_server/latest#post-matrix-federation-v1-user-keys-query func QueryDeviceKeys( - httpReq *http.Request, request *gomatrixserverlib.FederationRequest, keyAPI api.FederationKeyAPI, thisServer gomatrixserverlib.ServerName, + httpReq *http.Request, request *fclient.FederationRequest, keyAPI api.FederationKeyAPI, thisServer spec.ServerName, ) util.JSONResponse { var qkr queryKeysRequest err := json.Unmarshal(request.Content(), &qkr) @@ -91,7 +93,7 @@ type claimOTKsRequest struct { // ClaimOneTimeKeys claims OTKs for users on this server. // https://matrix.org/docs/spec/server_server/latest#post-matrix-federation-v1-user-keys-claim func ClaimOneTimeKeys( - httpReq *http.Request, request *gomatrixserverlib.FederationRequest, keyAPI api.FederationKeyAPI, thisServer gomatrixserverlib.ServerName, + httpReq *http.Request, request *fclient.FederationRequest, keyAPI api.FederationKeyAPI, thisServer spec.ServerName, ) util.JSONResponse { var cor claimOTKsRequest err := json.Unmarshal(request.Content(), &cor) @@ -134,7 +136,7 @@ func ClaimOneTimeKeys( // LocalKeys returns the local keys for the server. // See https://matrix.org/docs/spec/server_server/unstable.html#publishing-keys -func LocalKeys(cfg *config.FederationAPI, serverName gomatrixserverlib.ServerName) util.JSONResponse { +func LocalKeys(cfg *config.FederationAPI, serverName spec.ServerName) util.JSONResponse { keys, err := localKeys(cfg, serverName) if err != nil { return util.MessageResponse(http.StatusNotFound, err.Error()) @@ -142,9 +144,9 @@ func LocalKeys(cfg *config.FederationAPI, serverName gomatrixserverlib.ServerNam return util.JSONResponse{Code: http.StatusOK, JSON: keys} } -func localKeys(cfg *config.FederationAPI, serverName gomatrixserverlib.ServerName) (*gomatrixserverlib.ServerKeys, error) { +func localKeys(cfg *config.FederationAPI, serverName spec.ServerName) (*gomatrixserverlib.ServerKeys, error) { var keys gomatrixserverlib.ServerKeys - var identity *gomatrixserverlib.SigningIdentity + var identity *fclient.SigningIdentity var err error if virtualHost := cfg.Matrix.VirtualHostForHTTPHost(serverName); virtualHost == nil { if identity, err = cfg.Matrix.SigningIdentityFor(cfg.Matrix.ServerName); err != nil { @@ -152,10 +154,10 @@ func localKeys(cfg *config.FederationAPI, serverName gomatrixserverlib.ServerNam } publicKey := cfg.Matrix.PrivateKey.Public().(ed25519.PublicKey) keys.ServerName = cfg.Matrix.ServerName - keys.ValidUntilTS = gomatrixserverlib.AsTimestamp(time.Now().Add(cfg.Matrix.KeyValidityPeriod)) + keys.ValidUntilTS = spec.AsTimestamp(time.Now().Add(cfg.Matrix.KeyValidityPeriod)) keys.VerifyKeys = map[gomatrixserverlib.KeyID]gomatrixserverlib.VerifyKey{ cfg.Matrix.KeyID: { - Key: gomatrixserverlib.Base64Bytes(publicKey), + Key: spec.Base64Bytes(publicKey), }, } keys.OldVerifyKeys = map[gomatrixserverlib.KeyID]gomatrixserverlib.OldVerifyKey{} @@ -173,10 +175,10 @@ func localKeys(cfg *config.FederationAPI, serverName gomatrixserverlib.ServerNam } publicKey := virtualHost.PrivateKey.Public().(ed25519.PublicKey) keys.ServerName = virtualHost.ServerName - keys.ValidUntilTS = gomatrixserverlib.AsTimestamp(time.Now().Add(virtualHost.KeyValidityPeriod)) + keys.ValidUntilTS = spec.AsTimestamp(time.Now().Add(virtualHost.KeyValidityPeriod)) keys.VerifyKeys = map[gomatrixserverlib.KeyID]gomatrixserverlib.VerifyKey{ virtualHost.KeyID: { - Key: gomatrixserverlib.Base64Bytes(publicKey), + Key: spec.Base64Bytes(publicKey), }, } // TODO: Virtual hosts probably want to be able to specify old signing @@ -199,7 +201,7 @@ func NotaryKeys( fsAPI federationAPI.FederationInternalAPI, req *gomatrixserverlib.PublicKeyNotaryLookupRequest, ) util.JSONResponse { - serverName := gomatrixserverlib.ServerName(httpReq.Host) // TODO: this is not ideal + serverName := spec.ServerName(httpReq.Host) // TODO: this is not ideal if !cfg.Matrix.IsLocalServerName(serverName) { return util.JSONResponse{ Code: http.StatusNotFound, diff --git a/federationapi/routing/leave.go b/federationapi/routing/leave.go index f1e9f49ba..afd4b24ac 100644 --- a/federationapi/routing/leave.go +++ b/federationapi/routing/leave.go @@ -22,6 +22,8 @@ import ( "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/sirupsen/logrus" ) @@ -29,7 +31,7 @@ import ( // MakeLeave implements the /make_leave API func MakeLeave( httpReq *http.Request, - request *gomatrixserverlib.FederationRequest, + request *fclient.FederationRequest, cfg *config.FederationAPI, rsAPI api.FederationRoomserverAPI, roomID, userID string, @@ -55,7 +57,7 @@ func MakeLeave( Type: "m.room.member", StateKey: &userID, } - err = builder.SetContent(map[string]interface{}{"membership": gomatrixserverlib.Leave}) + err = builder.SetContent(map[string]interface{}{"membership": spec.Leave}) if err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("builder.SetContent failed") return jsonerror.InternalServerError() @@ -95,7 +97,7 @@ func MakeLeave( if !state.StateKeyEquals(userID) { continue } - if mem, merr := state.Membership(); merr == nil && mem == gomatrixserverlib.Leave { + if mem, merr := state.Membership(); merr == nil && mem == spec.Leave { return util.JSONResponse{ Code: http.StatusOK, JSON: map[string]interface{}{ @@ -132,7 +134,7 @@ func MakeLeave( // nolint:gocyclo func SendLeave( httpReq *http.Request, - request *gomatrixserverlib.FederationRequest, + request *fclient.FederationRequest, cfg *config.FederationAPI, rsAPI api.FederationRoomserverAPI, keys gomatrixserverlib.JSONVerifier, @@ -195,7 +197,7 @@ func SendLeave( // Check that the sender belongs to the server that is sending us // the request. By this point we've already asserted that the sender // and the state key are equal so we don't need to check both. - var serverName gomatrixserverlib.ServerName + var serverName spec.ServerName if _, serverName, err = gomatrixserverlib.SplitID('@', event.Sender()); err != nil { return util.JSONResponse{ Code: http.StatusForbidden, @@ -213,7 +215,7 @@ func SendLeave( RoomID: roomID, StateToFetch: []gomatrixserverlib.StateKeyTuple{ { - EventType: gomatrixserverlib.MRoomMember, + EventType: spec.MRoomMember, StateKey: *event.StateKey(), }, }, @@ -242,7 +244,7 @@ func SendLeave( // We are/were joined/invited/banned or something. Check if // we can no-op here. if len(queryRes.StateEvents) == 1 { - if mem, merr := queryRes.StateEvents[0].Membership(); merr == nil && mem == gomatrixserverlib.Leave { + if mem, merr := queryRes.StateEvents[0].Membership(); merr == nil && mem == spec.Leave { return util.JSONResponse{ Code: http.StatusOK, JSON: struct{}{}, @@ -286,7 +288,7 @@ func SendLeave( JSON: jsonerror.BadJSON("missing content.membership key"), } } - if mem != gomatrixserverlib.Leave { + if mem != spec.Leave { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: jsonerror.BadJSON("The membership in the event content must be set to leave"), diff --git a/federationapi/routing/missingevents.go b/federationapi/routing/missingevents.go index 531cb9e28..28dc31da6 100644 --- a/federationapi/routing/missingevents.go +++ b/federationapi/routing/missingevents.go @@ -19,6 +19,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/util" ) @@ -33,7 +34,7 @@ type getMissingEventRequest struct { // Events are fetched from room DAG starting from latest_events until we reach earliest_events or the limit. func GetMissingEvents( httpReq *http.Request, - request *gomatrixserverlib.FederationRequest, + request *fclient.FederationRequest, rsAPI api.FederationRoomserverAPI, roomID string, ) util.JSONResponse { @@ -67,7 +68,7 @@ func GetMissingEvents( eventsResponse.Events = filterEvents(eventsResponse.Events, roomID) - resp := gomatrixserverlib.RespMissingEvents{ + resp := fclient.RespMissingEvents{ Events: gomatrixserverlib.NewEventJSONsFromHeaderedEvents(eventsResponse.Events), } diff --git a/federationapi/routing/peek.go b/federationapi/routing/peek.go index bc4dac90f..2ccf7cfc4 100644 --- a/federationapi/routing/peek.go +++ b/federationapi/routing/peek.go @@ -21,13 +21,14 @@ import ( "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/util" ) // Peek implements the SS /peek API, handling inbound peeks func Peek( httpReq *http.Request, - request *gomatrixserverlib.FederationRequest, + request *fclient.FederationRequest, cfg *config.FederationAPI, rsAPI api.FederationRoomserverAPI, roomID, peekID string, @@ -87,7 +88,7 @@ func Peek( return util.JSONResponse{Code: http.StatusNotFound, JSON: nil} } - respPeek := gomatrixserverlib.RespPeek{ + respPeek := fclient.RespPeek{ StateEvents: gomatrixserverlib.NewEventJSONsFromHeaderedEvents(response.StateEvents), AuthEvents: gomatrixserverlib.NewEventJSONsFromHeaderedEvents(response.AuthChainEvents), RoomVersion: response.RoomVersion, diff --git a/federationapi/routing/profile.go b/federationapi/routing/profile.go index e4d2230ad..55641b216 100644 --- a/federationapi/routing/profile.go +++ b/federationapi/routing/profile.go @@ -50,10 +50,7 @@ func GetProfile( } } - var profileRes userapi.QueryProfileResponse - err = userAPI.QueryProfile(httpReq.Context(), &userapi.QueryProfileRequest{ - UserID: userID, - }, &profileRes) + profile, err := userAPI.QueryProfile(httpReq.Context(), userID) if err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("userAPI.QueryProfile failed") return jsonerror.InternalServerError() @@ -65,21 +62,21 @@ func GetProfile( if field != "" { switch field { case "displayname": - res = eventutil.DisplayName{ - DisplayName: profileRes.DisplayName, + res = eventutil.UserProfile{ + DisplayName: profile.DisplayName, } case "avatar_url": - res = eventutil.AvatarURL{ - AvatarURL: profileRes.AvatarURL, + res = eventutil.UserProfile{ + AvatarURL: profile.AvatarURL, } default: code = http.StatusBadRequest res = jsonerror.InvalidArgumentValue("The request body did not contain an allowed value of argument 'field'. Allowed values are either: 'avatar_url', 'displayname'.") } } else { - res = eventutil.ProfileResponse{ - AvatarURL: profileRes.AvatarURL, - DisplayName: profileRes.DisplayName, + res = eventutil.UserProfile{ + AvatarURL: profile.AvatarURL, + DisplayName: profile.DisplayName, } } diff --git a/federationapi/routing/profile_test.go b/federationapi/routing/profile_test.go index 3b9d576bf..18a908e4c 100644 --- a/federationapi/routing/profile_test.go +++ b/federationapi/routing/profile_test.go @@ -23,15 +23,21 @@ import ( "testing" "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" fedAPI "github.com/matrix-org/dendrite/federationapi" fedInternal "github.com/matrix-org/dendrite/federationapi/internal" "github.com/matrix-org/dendrite/federationapi/routing" + "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test/testrig" userAPI "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/stretchr/testify/assert" "golang.org/x/crypto/ed25519" ) @@ -40,36 +46,39 @@ type fakeUserAPI struct { userAPI.FederationUserAPI } -func (u *fakeUserAPI) QueryProfile(ctx context.Context, req *userAPI.QueryProfileRequest, res *userAPI.QueryProfileResponse) error { - return nil +func (u *fakeUserAPI) QueryProfile(ctx context.Context, userID string) (*authtypes.Profile, error) { + return &authtypes.Profile{}, nil } func TestHandleQueryProfile(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - base, close := testrig.CreateBaseDendrite(t, dbType) + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + routers := httputil.NewRouters() defer close() fedMux := mux.NewRouter().SkipClean(true).PathPrefix(httputil.PublicFederationPathPrefix).Subrouter().UseEncodedPath() - base.PublicFederationAPIMux = fedMux - base.Cfg.FederationAPI.Matrix.SigningIdentity.ServerName = testOrigin - base.Cfg.FederationAPI.Matrix.Metrics.Enabled = false + natsInstance := jetstream.NATSInstance{} + routers.Federation = fedMux + cfg.FederationAPI.Matrix.SigningIdentity.ServerName = testOrigin + cfg.FederationAPI.Matrix.Metrics.Enabled = false fedClient := fakeFedClient{} serverKeyAPI := &signing.YggdrasilKeys{} keyRing := serverKeyAPI.KeyRing() - fedapi := fedAPI.NewInternalAPI(base, &fedClient, nil, nil, keyRing, true) + fedapi := fedAPI.NewInternalAPI(processCtx, cfg, cm, &natsInstance, &fedClient, nil, nil, keyRing, true) userapi := fakeUserAPI{} r, ok := fedapi.(*fedInternal.FederationInternalAPI) if !ok { panic("This is a programming error.") } - routing.Setup(base, nil, r, keyRing, &fedClient, &userapi, &base.Cfg.MSCs, nil, nil) + routing.Setup(routers, cfg, nil, r, keyRing, &fedClient, &userapi, &cfg.MSCs, nil, nil, caching.DisableMetrics) handler := fedMux.Get(routing.QueryProfileRouteName).GetHandler().ServeHTTP _, sk, _ := ed25519.GenerateKey(nil) keyID := signing.KeyID pk := sk.Public().(ed25519.PublicKey) - serverName := gomatrixserverlib.ServerName(hex.EncodeToString(pk)) - req := gomatrixserverlib.NewFederationRequest("GET", serverName, testOrigin, "/query/profile?user_id="+url.QueryEscape("@user:"+string(testOrigin))) + serverName := spec.ServerName(hex.EncodeToString(pk)) + req := fclient.NewFederationRequest("GET", serverName, testOrigin, "/query/profile?user_id="+url.QueryEscape("@user:"+string(testOrigin))) type queryContent struct{} content := queryContent{} err := req.SetContent(content) diff --git a/federationapi/routing/publicrooms.go b/federationapi/routing/publicrooms.go index 34025932a..80343d93a 100644 --- a/federationapi/routing/publicrooms.go +++ b/federationapi/routing/publicrooms.go @@ -7,6 +7,8 @@ import ( "strconv" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/matrix-org/dendrite/clientapi/httputil" @@ -48,9 +50,9 @@ func GetPostPublicRooms(req *http.Request, rsAPI roomserverAPI.FederationRoomser func publicRooms( ctx context.Context, request PublicRoomReq, rsAPI roomserverAPI.FederationRoomserverAPI, -) (*gomatrixserverlib.RespPublicRooms, error) { +) (*fclient.RespPublicRooms, error) { - var response gomatrixserverlib.RespPublicRooms + var response fclient.RespPublicRooms var limit int16 var offset int64 limit = request.Limit @@ -122,14 +124,14 @@ func fillPublicRoomsReq(httpReq *http.Request, request *PublicRoomReq) *util.JSO } // due to lots of switches -func fillInRooms(ctx context.Context, roomIDs []string, rsAPI roomserverAPI.FederationRoomserverAPI) ([]gomatrixserverlib.PublicRoom, error) { +func fillInRooms(ctx context.Context, roomIDs []string, rsAPI roomserverAPI.FederationRoomserverAPI) ([]fclient.PublicRoom, error) { avatarTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.avatar", StateKey: ""} nameTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.name", StateKey: ""} - canonicalTuple := gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomCanonicalAlias, StateKey: ""} + canonicalTuple := gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomCanonicalAlias, StateKey: ""} topicTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.topic", StateKey: ""} guestTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.guest_access", StateKey: ""} - visibilityTuple := gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomHistoryVisibility, StateKey: ""} - joinRuleTuple := gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomJoinRules, StateKey: ""} + visibilityTuple := gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomHistoryVisibility, StateKey: ""} + joinRuleTuple := gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomJoinRules, StateKey: ""} var stateRes roomserverAPI.QueryBulkStateContentResponse err := rsAPI.QueryBulkStateContent(ctx, &roomserverAPI.QueryBulkStateContentRequest{ @@ -137,23 +139,23 @@ func fillInRooms(ctx context.Context, roomIDs []string, rsAPI roomserverAPI.Fede AllowWildcards: true, StateTuples: []gomatrixserverlib.StateKeyTuple{ nameTuple, canonicalTuple, topicTuple, guestTuple, visibilityTuple, joinRuleTuple, avatarTuple, - {EventType: gomatrixserverlib.MRoomMember, StateKey: "*"}, + {EventType: spec.MRoomMember, StateKey: "*"}, }, }, &stateRes) if err != nil { util.GetLogger(ctx).WithError(err).Error("QueryBulkStateContent failed") return nil, err } - chunk := make([]gomatrixserverlib.PublicRoom, len(roomIDs)) + chunk := make([]fclient.PublicRoom, len(roomIDs)) i := 0 for roomID, data := range stateRes.Rooms { - pub := gomatrixserverlib.PublicRoom{ + pub := fclient.PublicRoom{ RoomID: roomID, } joinCount := 0 var joinRule, guestAccess string for tuple, contentVal := range data { - if tuple.EventType == gomatrixserverlib.MRoomMember && contentVal == "join" { + if tuple.EventType == spec.MRoomMember && contentVal == "join" { joinCount++ continue } @@ -177,7 +179,7 @@ func fillInRooms(ctx context.Context, roomIDs []string, rsAPI roomserverAPI.Fede guestAccess = contentVal } } - if joinRule == gomatrixserverlib.Public && guestAccess == "can_join" { + if joinRule == spec.Public && guestAccess == "can_join" { pub.GuestCanJoin = true } pub.JoinedMembersCount = joinCount diff --git a/federationapi/routing/query.go b/federationapi/routing/query.go index e6dc52601..6b1c371ec 100644 --- a/federationapi/routing/query.go +++ b/federationapi/routing/query.go @@ -24,6 +24,7 @@ import ( "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/util" ) @@ -50,7 +51,7 @@ func RoomAliasToID( } } - var resp gomatrixserverlib.RespDirectory + var resp fclient.RespDirectory if domain == cfg.Matrix.ServerName { queryReq := &roomserverAPI.GetRoomIDForAliasRequest{ @@ -71,7 +72,7 @@ func RoomAliasToID( return jsonerror.InternalServerError() } - resp = gomatrixserverlib.RespDirectory{ + resp = fclient.RespDirectory{ RoomID: queryRes.RoomID, Servers: serverQueryRes.ServerNames, } diff --git a/federationapi/routing/query_test.go b/federationapi/routing/query_test.go index d839a16b8..1004fe1ea 100644 --- a/federationapi/routing/query_test.go +++ b/federationapi/routing/query_test.go @@ -28,10 +28,15 @@ import ( fedclient "github.com/matrix-org/dendrite/federationapi/api" fedInternal "github.com/matrix-org/dendrite/federationapi/internal" "github.com/matrix-org/dendrite/federationapi/routing" + "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/stretchr/testify/assert" "golang.org/x/crypto/ed25519" ) @@ -40,36 +45,39 @@ type fakeFedClient struct { fedclient.FederationClient } -func (f *fakeFedClient) LookupRoomAlias(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomAlias string) (res gomatrixserverlib.RespDirectory, err error) { +func (f *fakeFedClient) LookupRoomAlias(ctx context.Context, origin, s spec.ServerName, roomAlias string) (res fclient.RespDirectory, err error) { return } func TestHandleQueryDirectory(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - base, close := testrig.CreateBaseDendrite(t, dbType) + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + routers := httputil.NewRouters() defer close() fedMux := mux.NewRouter().SkipClean(true).PathPrefix(httputil.PublicFederationPathPrefix).Subrouter().UseEncodedPath() - base.PublicFederationAPIMux = fedMux - base.Cfg.FederationAPI.Matrix.SigningIdentity.ServerName = testOrigin - base.Cfg.FederationAPI.Matrix.Metrics.Enabled = false + natsInstance := jetstream.NATSInstance{} + routers.Federation = fedMux + cfg.FederationAPI.Matrix.SigningIdentity.ServerName = testOrigin + cfg.FederationAPI.Matrix.Metrics.Enabled = false fedClient := fakeFedClient{} serverKeyAPI := &signing.YggdrasilKeys{} keyRing := serverKeyAPI.KeyRing() - fedapi := fedAPI.NewInternalAPI(base, &fedClient, nil, nil, keyRing, true) + fedapi := fedAPI.NewInternalAPI(processCtx, cfg, cm, &natsInstance, &fedClient, nil, nil, keyRing, true) userapi := fakeUserAPI{} r, ok := fedapi.(*fedInternal.FederationInternalAPI) if !ok { panic("This is a programming error.") } - routing.Setup(base, nil, r, keyRing, &fedClient, &userapi, &base.Cfg.MSCs, nil, nil) + routing.Setup(routers, cfg, nil, r, keyRing, &fedClient, &userapi, &cfg.MSCs, nil, nil, caching.DisableMetrics) handler := fedMux.Get(routing.QueryDirectoryRouteName).GetHandler().ServeHTTP _, sk, _ := ed25519.GenerateKey(nil) keyID := signing.KeyID pk := sk.Public().(ed25519.PublicKey) - serverName := gomatrixserverlib.ServerName(hex.EncodeToString(pk)) - req := gomatrixserverlib.NewFederationRequest("GET", serverName, testOrigin, "/query/directory?room_alias="+url.QueryEscape("#room:server")) + serverName := spec.ServerName(hex.EncodeToString(pk)) + req := fclient.NewFederationRequest("GET", serverName, testOrigin, "/query/directory?room_alias="+url.QueryEscape("#room:server")) type queryContent struct{} content := queryContent{} err := req.SetContent(content) diff --git a/federationapi/routing/routing.go b/federationapi/routing/routing.go index 324740ddc..41af5e7d1 100644 --- a/federationapi/routing/routing.go +++ b/federationapi/routing/routing.go @@ -31,10 +31,11 @@ import ( "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/prometheus/client_golang/prometheus" "github.com/sirupsen/logrus" @@ -55,7 +56,8 @@ const ( // applied: // nolint: gocyclo func Setup( - base *base.BaseDendrite, + routers httputil.Routers, + dendriteCfg *config.Dendrite, rsAPI roomserverAPI.FederationRoomserverAPI, fsAPI *fedInternal.FederationInternalAPI, keys gomatrixserverlib.JSONVerifier, @@ -63,14 +65,14 @@ func Setup( userAPI userapi.FederationUserAPI, mscCfg *config.MSCs, servers federationAPI.ServersInRoomProvider, - producer *producers.SyncAPIProducer, + producer *producers.SyncAPIProducer, enableMetrics bool, ) { - fedMux := base.PublicFederationAPIMux - keyMux := base.PublicKeyAPIMux - wkMux := base.PublicWellKnownAPIMux - cfg := &base.Cfg.FederationAPI + fedMux := routers.Federation + keyMux := routers.Keys + wkMux := routers.WellKnown + cfg := &dendriteCfg.FederationAPI - if base.EnableMetrics { + if enableMetrics { prometheus.MustRegister( internal.PDUCountTotal, internal.EDUCountTotal, ) @@ -85,7 +87,7 @@ func Setup( } localKeys := httputil.MakeExternalAPI("localkeys", func(req *http.Request) util.JSONResponse { - return LocalKeys(cfg, gomatrixserverlib.ServerName(req.Host)) + return LocalKeys(cfg, spec.ServerName(req.Host)) }) notaryKeys := httputil.MakeExternalAPI("notarykeys", func(req *http.Request) util.JSONResponse { @@ -94,11 +96,11 @@ func Setup( return util.ErrorResponse(err) } var pkReq *gomatrixserverlib.PublicKeyNotaryLookupRequest - serverName := gomatrixserverlib.ServerName(vars["serverName"]) + serverName := spec.ServerName(vars["serverName"]) keyID := gomatrixserverlib.KeyID(vars["keyID"]) if serverName != "" && keyID != "" { pkReq = &gomatrixserverlib.PublicKeyNotaryLookupRequest{ - ServerKeys: map[gomatrixserverlib.ServerName]map[gomatrixserverlib.KeyID]gomatrixserverlib.PublicKeyNotaryQueryCriteria{ + ServerKeys: map[spec.ServerName]map[gomatrixserverlib.KeyID]gomatrixserverlib.PublicKeyNotaryQueryCriteria{ serverName: { keyID: gomatrixserverlib.PublicKeyNotaryQueryCriteria{}, }, @@ -136,7 +138,7 @@ func Setup( mu := internal.NewMutexByRoom() v1fedmux.Handle("/send/{txnID}", MakeFedAPI( "federation_send", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + func(httpReq *http.Request, request *fclient.FederationRequest, vars map[string]string) util.JSONResponse { return Send( httpReq, request, gomatrixserverlib.TransactionID(vars["txnID"]), cfg, rsAPI, userAPI, keys, federation, mu, servers, producer, @@ -146,7 +148,7 @@ func Setup( v1fedmux.Handle("/invite/{roomID}/{eventID}", MakeFedAPI( "federation_invite", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + func(httpReq *http.Request, request *fclient.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ Code: http.StatusForbidden, @@ -162,7 +164,7 @@ func Setup( v2fedmux.Handle("/invite/{roomID}/{eventID}", MakeFedAPI( "federation_invite", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + func(httpReq *http.Request, request *fclient.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ Code: http.StatusForbidden, @@ -184,7 +186,7 @@ func Setup( v1fedmux.Handle("/exchange_third_party_invite/{roomID}", MakeFedAPI( "exchange_third_party_invite", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + func(httpReq *http.Request, request *fclient.FederationRequest, vars map[string]string) util.JSONResponse { return ExchangeThirdPartyInvite( httpReq, request, vars["roomID"], rsAPI, cfg, federation, ) @@ -193,7 +195,7 @@ func Setup( v1fedmux.Handle("/event/{eventID}", MakeFedAPI( "federation_get_event", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + func(httpReq *http.Request, request *fclient.FederationRequest, vars map[string]string) util.JSONResponse { return GetEvent( httpReq.Context(), request, rsAPI, vars["eventID"], cfg.Matrix.ServerName, ) @@ -202,7 +204,7 @@ func Setup( v1fedmux.Handle("/state/{roomID}", MakeFedAPI( "federation_get_state", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + func(httpReq *http.Request, request *fclient.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ Code: http.StatusForbidden, @@ -217,7 +219,7 @@ func Setup( v1fedmux.Handle("/state_ids/{roomID}", MakeFedAPI( "federation_get_state_ids", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + func(httpReq *http.Request, request *fclient.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ Code: http.StatusForbidden, @@ -232,7 +234,7 @@ func Setup( v1fedmux.Handle("/event_auth/{roomID}/{eventID}", MakeFedAPI( "federation_get_event_auth", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + func(httpReq *http.Request, request *fclient.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ Code: http.StatusForbidden, @@ -247,7 +249,7 @@ func Setup( v1fedmux.Handle("/query/directory", MakeFedAPI( "federation_query_room_alias", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + func(httpReq *http.Request, request *fclient.FederationRequest, vars map[string]string) util.JSONResponse { return RoomAliasToID( httpReq, federation, cfg, rsAPI, fsAPI, ) @@ -256,7 +258,7 @@ func Setup( v1fedmux.Handle("/query/profile", MakeFedAPI( "federation_query_profile", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + func(httpReq *http.Request, request *fclient.FederationRequest, vars map[string]string) util.JSONResponse { return GetProfile( httpReq, userAPI, cfg, ) @@ -265,7 +267,7 @@ func Setup( v1fedmux.Handle("/user/devices/{userID}", MakeFedAPI( "federation_user_devices", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + func(httpReq *http.Request, request *fclient.FederationRequest, vars map[string]string) util.JSONResponse { return GetUserDevices( httpReq, userAPI, vars["userID"], ) @@ -275,7 +277,7 @@ func Setup( if mscCfg.Enabled("msc2444") { v1fedmux.Handle("/peek/{roomID}/{peekID}", MakeFedAPI( "federation_peek", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + func(httpReq *http.Request, request *fclient.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ Code: http.StatusForbidden, @@ -306,7 +308,7 @@ func Setup( v1fedmux.Handle("/make_join/{roomID}/{userID}", MakeFedAPI( "federation_make_join", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + func(httpReq *http.Request, request *fclient.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ Code: http.StatusForbidden, @@ -337,7 +339,7 @@ func Setup( v1fedmux.Handle("/send_join/{roomID}/{eventID}", MakeFedAPI( "federation_send_join", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + func(httpReq *http.Request, request *fclient.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ Code: http.StatusForbidden, @@ -369,7 +371,7 @@ func Setup( v2fedmux.Handle("/send_join/{roomID}/{eventID}", MakeFedAPI( "federation_send_join", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + func(httpReq *http.Request, request *fclient.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ Code: http.StatusForbidden, @@ -386,7 +388,7 @@ func Setup( v1fedmux.Handle("/make_leave/{roomID}/{eventID}", MakeFedAPI( "federation_make_leave", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + func(httpReq *http.Request, request *fclient.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ Code: http.StatusForbidden, @@ -403,7 +405,7 @@ func Setup( v1fedmux.Handle("/send_leave/{roomID}/{eventID}", MakeFedAPI( "federation_send_leave", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + func(httpReq *http.Request, request *fclient.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ Code: http.StatusForbidden, @@ -435,7 +437,7 @@ func Setup( v2fedmux.Handle("/send_leave/{roomID}/{eventID}", MakeFedAPI( "federation_send_leave", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + func(httpReq *http.Request, request *fclient.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ Code: http.StatusForbidden, @@ -459,7 +461,7 @@ func Setup( v1fedmux.Handle("/get_missing_events/{roomID}", MakeFedAPI( "federation_get_missing_events", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + func(httpReq *http.Request, request *fclient.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ Code: http.StatusForbidden, @@ -472,7 +474,7 @@ func Setup( v1fedmux.Handle("/backfill/{roomID}", MakeFedAPI( "federation_backfill", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + func(httpReq *http.Request, request *fclient.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ Code: http.StatusForbidden, @@ -491,14 +493,14 @@ func Setup( v1fedmux.Handle("/user/keys/claim", MakeFedAPI( "federation_keys_claim", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + func(httpReq *http.Request, request *fclient.FederationRequest, vars map[string]string) util.JSONResponse { return ClaimOneTimeKeys(httpReq, request, userAPI, cfg.Matrix.ServerName) }, )).Methods(http.MethodPost) v1fedmux.Handle("/user/keys/query", MakeFedAPI( "federation_keys_query", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + func(httpReq *http.Request, request *fclient.FederationRequest, vars map[string]string) util.JSONResponse { return QueryDeviceKeys(httpReq, request, userAPI, cfg.Matrix.ServerName) }, )).Methods(http.MethodPost) @@ -536,14 +538,14 @@ func ErrorIfLocalServerNotInRoom( // MakeFedAPI makes an http.Handler that checks matrix federation authentication. func MakeFedAPI( - metricsName string, serverName gomatrixserverlib.ServerName, - isLocalServerName func(gomatrixserverlib.ServerName) bool, + metricsName string, serverName spec.ServerName, + isLocalServerName func(spec.ServerName) bool, keyRing gomatrixserverlib.JSONVerifier, wakeup *FederationWakeups, - f func(*http.Request, *gomatrixserverlib.FederationRequest, map[string]string) util.JSONResponse, + f func(*http.Request, *fclient.FederationRequest, map[string]string) util.JSONResponse, ) http.Handler { h := func(req *http.Request) util.JSONResponse { - fedReq, errResp := gomatrixserverlib.VerifyHTTPRequest( + fedReq, errResp := fclient.VerifyHTTPRequest( req, time.Now(), serverName, isLocalServerName, keyRing, ) if fedReq == nil { @@ -586,7 +588,7 @@ type FederationWakeups struct { origins sync.Map } -func (f *FederationWakeups) Wakeup(ctx context.Context, origin gomatrixserverlib.ServerName) { +func (f *FederationWakeups) Wakeup(ctx context.Context, origin spec.ServerName) { key, keyok := f.origins.Load(origin) if keyok { lastTime, ok := key.(time.Time) @@ -594,6 +596,6 @@ func (f *FederationWakeups) Wakeup(ctx context.Context, origin gomatrixserverlib return } } - f.FsAPI.MarkServersAlive([]gomatrixserverlib.ServerName{origin}) + f.FsAPI.MarkServersAlive([]spec.ServerName{origin}) f.origins.Store(origin, time.Now()) } diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index 82651719f..deefe5f1a 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -22,6 +22,7 @@ import ( "time" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/util" "github.com/matrix-org/dendrite/clientapi/jsonerror" @@ -55,7 +56,7 @@ var inFlightTxnsPerOrigin sync.Map // transaction ID -> chan util.JSONResponse // Send implements /_matrix/federation/v1/send/{txnID} func Send( httpReq *http.Request, - request *gomatrixserverlib.FederationRequest, + request *fclient.FederationRequest, txnID gomatrixserverlib.TransactionID, cfg *config.FederationAPI, rsAPI api.FederationRoomserverAPI, diff --git a/federationapi/routing/send_test.go b/federationapi/routing/send_test.go index eed4e7e69..55b156e50 100644 --- a/federationapi/routing/send_test.go +++ b/federationapi/routing/send_test.go @@ -25,16 +25,21 @@ import ( fedAPI "github.com/matrix-org/dendrite/federationapi" fedInternal "github.com/matrix-org/dendrite/federationapi/internal" "github.com/matrix-org/dendrite/federationapi/routing" + "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/stretchr/testify/assert" "golang.org/x/crypto/ed25519" ) const ( - testOrigin = gomatrixserverlib.ServerName("kaer.morhen") + testOrigin = spec.ServerName("kaer.morhen") ) type sendContent struct { @@ -44,28 +49,31 @@ type sendContent struct { func TestHandleSend(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - base, close := testrig.CreateBaseDendrite(t, dbType) + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + routers := httputil.NewRouters() defer close() fedMux := mux.NewRouter().SkipClean(true).PathPrefix(httputil.PublicFederationPathPrefix).Subrouter().UseEncodedPath() - base.PublicFederationAPIMux = fedMux - base.Cfg.FederationAPI.Matrix.SigningIdentity.ServerName = testOrigin - base.Cfg.FederationAPI.Matrix.Metrics.Enabled = false - fedapi := fedAPI.NewInternalAPI(base, nil, nil, nil, nil, true) + natsInstance := jetstream.NATSInstance{} + routers.Federation = fedMux + cfg.FederationAPI.Matrix.SigningIdentity.ServerName = testOrigin + cfg.FederationAPI.Matrix.Metrics.Enabled = false + fedapi := fedAPI.NewInternalAPI(processCtx, cfg, cm, &natsInstance, nil, nil, nil, nil, true) serverKeyAPI := &signing.YggdrasilKeys{} keyRing := serverKeyAPI.KeyRing() r, ok := fedapi.(*fedInternal.FederationInternalAPI) if !ok { panic("This is a programming error.") } - routing.Setup(base, nil, r, keyRing, nil, nil, &base.Cfg.MSCs, nil, nil) + routing.Setup(routers, cfg, nil, r, keyRing, nil, nil, &cfg.MSCs, nil, nil, caching.DisableMetrics) handler := fedMux.Get(routing.SendRouteName).GetHandler().ServeHTTP _, sk, _ := ed25519.GenerateKey(nil) keyID := signing.KeyID pk := sk.Public().(ed25519.PublicKey) - serverName := gomatrixserverlib.ServerName(hex.EncodeToString(pk)) - req := gomatrixserverlib.NewFederationRequest("PUT", serverName, testOrigin, "/send/1234") + serverName := spec.ServerName(hex.EncodeToString(pk)) + req := fclient.NewFederationRequest("PUT", serverName, testOrigin, "/send/1234") content := sendContent{} err := req.SetContent(content) if err != nil { diff --git a/federationapi/routing/state.go b/federationapi/routing/state.go index 1120cf260..52a8c2b17 100644 --- a/federationapi/routing/state.go +++ b/federationapi/routing/state.go @@ -20,13 +20,14 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/util" ) // GetState returns state events & auth events for the roomID, eventID func GetState( ctx context.Context, - request *gomatrixserverlib.FederationRequest, + request *fclient.FederationRequest, rsAPI api.FederationRoomserverAPI, roomID string, ) util.JSONResponse { @@ -40,7 +41,7 @@ func GetState( return *err } - return util.JSONResponse{Code: http.StatusOK, JSON: &gomatrixserverlib.RespState{ + return util.JSONResponse{Code: http.StatusOK, JSON: &fclient.RespState{ AuthEvents: gomatrixserverlib.NewEventJSONsFromHeaderedEvents(authChain), StateEvents: gomatrixserverlib.NewEventJSONsFromHeaderedEvents(stateEvents), }} @@ -49,7 +50,7 @@ func GetState( // GetStateIDs returns state event IDs & auth event IDs for the roomID, eventID func GetStateIDs( ctx context.Context, - request *gomatrixserverlib.FederationRequest, + request *fclient.FederationRequest, rsAPI api.FederationRoomserverAPI, roomID string, ) util.JSONResponse { @@ -66,7 +67,7 @@ func GetStateIDs( stateEventIDs := getIDsFromEvent(stateEvents) authEventIDs := getIDsFromEvent(authEvents) - return util.JSONResponse{Code: http.StatusOK, JSON: gomatrixserverlib.RespStateIDs{ + return util.JSONResponse{Code: http.StatusOK, JSON: fclient.RespStateIDs{ StateEventIDs: stateEventIDs, AuthEventIDs: authEventIDs, }, @@ -74,7 +75,7 @@ func GetStateIDs( } func parseEventIDParam( - request *gomatrixserverlib.FederationRequest, + request *fclient.FederationRequest, ) (eventID string, resErr *util.JSONResponse) { URL, err := url.Parse(request.RequestURI()) if err != nil { @@ -96,7 +97,7 @@ func parseEventIDParam( func getState( ctx context.Context, - request *gomatrixserverlib.FederationRequest, + request *fclient.FederationRequest, rsAPI api.FederationRoomserverAPI, roomID string, eventID string, diff --git a/federationapi/routing/threepid.go b/federationapi/routing/threepid.go index d07faef39..ccdca4826 100644 --- a/federationapi/routing/threepid.go +++ b/federationapi/routing/threepid.go @@ -27,6 +27,8 @@ import ( "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -114,7 +116,7 @@ func CreateInvitesFrom3PIDInvites( // ExchangeThirdPartyInvite implements PUT /_matrix/federation/v1/exchange_third_party_invite/{roomID} func ExchangeThirdPartyInvite( httpReq *http.Request, - request *gomatrixserverlib.FederationRequest, + request *fclient.FederationRequest, roomID string, rsAPI api.FederationRoomserverAPI, cfg *config.FederationAPI, @@ -184,7 +186,7 @@ func ExchangeThirdPartyInvite( // Ask the requesting server to sign the newly created event so we know it // acknowledged it - inviteReq, err := gomatrixserverlib.NewInviteV2Request(event.Headered(verRes.RoomVersion), nil) + inviteReq, err := fclient.NewInviteV2Request(event.Headered(verRes.RoomVersion), nil) if err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("failed to make invite v2 request") return jsonerror.InternalServerError() @@ -194,7 +196,7 @@ func ExchangeThirdPartyInvite( util.GetLogger(httpReq.Context()).WithError(err).Error("federation.SendInvite failed") return jsonerror.InternalServerError() } - inviteEvent, err := signedEvent.Event.UntrustedEvent(verRes.RoomVersion) + inviteEvent, err := gomatrixserverlib.UntrustedEvent(signedEvent.Event, verRes.RoomVersion) if err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("federation.SendInvite failed") return jsonerror.InternalServerError() @@ -256,18 +258,15 @@ func createInviteFrom3PIDInvite( StateKey: &inv.MXID, } - var res userapi.QueryProfileResponse - err = userAPI.QueryProfile(ctx, &userapi.QueryProfileRequest{ - UserID: inv.MXID, - }, &res) + profile, err := userAPI.QueryProfile(ctx, inv.MXID) if err != nil { return nil, err } content := gomatrixserverlib.MemberContent{ - AvatarURL: res.AvatarURL, - DisplayName: res.DisplayName, - Membership: gomatrixserverlib.Invite, + AvatarURL: profile.AvatarURL, + DisplayName: profile.DisplayName, + Membership: spec.Invite, ThirdPartyInvite: &gomatrixserverlib.MemberThirdPartyInvite{ Signed: inv.Signed, }, @@ -363,7 +362,7 @@ func sendToRemoteServer( federation federationAPI.FederationClient, cfg *config.FederationAPI, builder gomatrixserverlib.EventBuilder, ) (err error) { - remoteServers := make([]gomatrixserverlib.ServerName, 2) + remoteServers := make([]spec.ServerName, 2) _, remoteServers[0], err = gomatrixserverlib.SplitID('@', inv.Sender) if err != nil { return diff --git a/federationapi/statistics/statistics.go b/federationapi/statistics/statistics.go index e29e3b140..e5fc4b940 100644 --- a/federationapi/statistics/statistics.go +++ b/federationapi/statistics/statistics.go @@ -7,11 +7,11 @@ import ( "sync" "time" - "github.com/matrix-org/gomatrixserverlib" "github.com/sirupsen/logrus" "go.uber.org/atomic" "github.com/matrix-org/dendrite/federationapi/storage" + "github.com/matrix-org/gomatrixserverlib/spec" ) // Statistics contains information about all of the remote federated @@ -19,10 +19,10 @@ import ( // wrapper. type Statistics struct { DB storage.Database - servers map[gomatrixserverlib.ServerName]*ServerStatistics + servers map[spec.ServerName]*ServerStatistics mutex sync.RWMutex - backoffTimers map[gomatrixserverlib.ServerName]*time.Timer + backoffTimers map[spec.ServerName]*time.Timer backoffMutex sync.RWMutex // How many times should we tolerate consecutive failures before we @@ -45,14 +45,14 @@ func NewStatistics( DB: db, FailuresUntilBlacklist: failuresUntilBlacklist, FailuresUntilAssumedOffline: failuresUntilAssumedOffline, - backoffTimers: make(map[gomatrixserverlib.ServerName]*time.Timer), - servers: make(map[gomatrixserverlib.ServerName]*ServerStatistics), + backoffTimers: make(map[spec.ServerName]*time.Timer), + servers: make(map[spec.ServerName]*ServerStatistics), } } // ForServer returns server statistics for the given server name. If it // does not exist, it will create empty statistics and return those. -func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerStatistics { +func (s *Statistics) ForServer(serverName spec.ServerName) *ServerStatistics { // Look up if we have statistics for this server already. s.mutex.RLock() server, found := s.servers[serverName] @@ -63,7 +63,7 @@ func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerS server = &ServerStatistics{ statistics: s, serverName: serverName, - knownRelayServers: []gomatrixserverlib.ServerName{}, + knownRelayServers: []spec.ServerName{}, } s.servers[serverName] = server s.mutex.Unlock() @@ -104,17 +104,17 @@ const ( // many times we failed etc. It also manages the backoff time and black- // listing a remote host if it remains uncooperative. type ServerStatistics struct { - statistics *Statistics // - serverName gomatrixserverlib.ServerName // - blacklisted atomic.Bool // is the node blacklisted - assumedOffline atomic.Bool // is the node assumed to be offline - backoffStarted atomic.Bool // is the backoff started - backoffUntil atomic.Value // time.Time until this backoff interval ends - backoffCount atomic.Uint32 // number of times BackoffDuration has been called - successCounter atomic.Uint32 // how many times have we succeeded? - backoffNotifier func() // notifies destination queue when backoff completes + statistics *Statistics // + serverName spec.ServerName // + blacklisted atomic.Bool // is the node blacklisted + assumedOffline atomic.Bool // is the node assumed to be offline + backoffStarted atomic.Bool // is the backoff started + backoffUntil atomic.Value // time.Time until this backoff interval ends + backoffCount atomic.Uint32 // number of times BackoffDuration has been called + successCounter atomic.Uint32 // how many times have we succeeded? + backoffNotifier func() // notifies destination queue when backoff completes notifierMutex sync.Mutex - knownRelayServers []gomatrixserverlib.ServerName + knownRelayServers []spec.ServerName relayMutex sync.Mutex } @@ -307,15 +307,15 @@ func (s *ServerStatistics) SuccessCount() uint32 { // KnownRelayServers returns the list of relay servers associated with this // server. -func (s *ServerStatistics) KnownRelayServers() []gomatrixserverlib.ServerName { +func (s *ServerStatistics) KnownRelayServers() []spec.ServerName { s.relayMutex.Lock() defer s.relayMutex.Unlock() return s.knownRelayServers } -func (s *ServerStatistics) AddRelayServers(relayServers []gomatrixserverlib.ServerName) { - seenSet := make(map[gomatrixserverlib.ServerName]bool) - uniqueList := []gomatrixserverlib.ServerName{} +func (s *ServerStatistics) AddRelayServers(relayServers []spec.ServerName) { + seenSet := make(map[spec.ServerName]bool) + uniqueList := []spec.ServerName{} for _, srv := range relayServers { if seenSet[srv] { continue diff --git a/federationapi/statistics/statistics_test.go b/federationapi/statistics/statistics_test.go index 183b9aa0c..a930bc3b0 100644 --- a/federationapi/statistics/statistics_test.go +++ b/federationapi/statistics/statistics_test.go @@ -6,7 +6,7 @@ import ( "time" "github.com/matrix-org/dendrite/test" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/stretchr/testify/assert" ) @@ -108,10 +108,10 @@ func TestBackoff(t *testing.T) { func TestRelayServersListing(t *testing.T) { stats := NewStatistics(test.NewInMemoryFederationDatabase(), FailuresUntilBlacklist, FailuresUntilAssumedOffline) server := ServerStatistics{statistics: &stats} - server.AddRelayServers([]gomatrixserverlib.ServerName{"relay1", "relay1", "relay2"}) + server.AddRelayServers([]spec.ServerName{"relay1", "relay1", "relay2"}) relayServers := server.KnownRelayServers() - assert.Equal(t, []gomatrixserverlib.ServerName{"relay1", "relay2"}, relayServers) - server.AddRelayServers([]gomatrixserverlib.ServerName{"relay1", "relay1", "relay2"}) + assert.Equal(t, []spec.ServerName{"relay1", "relay2"}, relayServers) + server.AddRelayServers([]spec.ServerName{"relay1", "relay1", "relay2"}) relayServers = server.KnownRelayServers() - assert.Equal(t, []gomatrixserverlib.ServerName{"relay1", "relay2"}, relayServers) + assert.Equal(t, []spec.ServerName{"relay1", "relay2"}, relayServers) } diff --git a/federationapi/storage/cache/keydb.go b/federationapi/storage/cache/keydb.go index 2063dfc55..b53695ca4 100644 --- a/federationapi/storage/cache/keydb.go +++ b/federationapi/storage/cache/keydb.go @@ -6,6 +6,7 @@ import ( "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) // A Database implements gomatrixserverlib.KeyDatabase and is used to store @@ -36,7 +37,7 @@ func (d KeyDatabase) FetcherName() string { // FetchKeys implements gomatrixserverlib.KeyDatabase func (d *KeyDatabase) FetchKeys( ctx context.Context, - requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, + requests map[gomatrixserverlib.PublicKeyLookupRequest]spec.Timestamp, ) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) { results := make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult) for req, ts := range requests { diff --git a/federationapi/storage/interface.go b/federationapi/storage/interface.go index 4f5300af1..7652efec0 100644 --- a/federationapi/storage/interface.go +++ b/federationapi/storage/interface.go @@ -19,6 +19,7 @@ import ( "time" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" "github.com/matrix-org/dendrite/federationapi/types" @@ -31,57 +32,57 @@ type Database interface { UpdateRoom(ctx context.Context, roomID string, addHosts []types.JoinedHost, removeHosts []string, purgeRoomFirst bool) (joinedHosts []types.JoinedHost, err error) GetJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error) - GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) + GetAllJoinedHosts(ctx context.Context) ([]spec.ServerName, error) // GetJoinedHostsForRooms returns the complete set of servers in the rooms given. - GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf, excludeBlacklisted bool) ([]gomatrixserverlib.ServerName, error) + GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf, excludeBlacklisted bool) ([]spec.ServerName, error) StoreJSON(ctx context.Context, js string) (*receipt.Receipt, error) - GetPendingPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (pdus map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent, err error) - GetPendingEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (edus map[*receipt.Receipt]*gomatrixserverlib.EDU, err error) + GetPendingPDUs(ctx context.Context, serverName spec.ServerName, limit int) (pdus map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent, err error) + GetPendingEDUs(ctx context.Context, serverName spec.ServerName, limit int) (edus map[*receipt.Receipt]*gomatrixserverlib.EDU, err error) - AssociatePDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, dbReceipt *receipt.Receipt) error - AssociateEDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, dbReceipt *receipt.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error + AssociatePDUWithDestinations(ctx context.Context, destinations map[spec.ServerName]struct{}, dbReceipt *receipt.Receipt) error + AssociateEDUWithDestinations(ctx context.Context, destinations map[spec.ServerName]struct{}, dbReceipt *receipt.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error - CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*receipt.Receipt) error - CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*receipt.Receipt) error + CleanPDUs(ctx context.Context, serverName spec.ServerName, receipts []*receipt.Receipt) error + CleanEDUs(ctx context.Context, serverName spec.ServerName, receipts []*receipt.Receipt) error - GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) - GetPendingEDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) + GetPendingPDUServerNames(ctx context.Context) ([]spec.ServerName, error) + GetPendingEDUServerNames(ctx context.Context) ([]spec.ServerName, error) // these don't have contexts passed in as we want things to happen regardless of the request context - AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error - RemoveServerFromBlacklist(serverName gomatrixserverlib.ServerName) error + AddServerToBlacklist(serverName spec.ServerName) error + RemoveServerFromBlacklist(serverName spec.ServerName) error RemoveAllServersFromBlacklist() error - IsServerBlacklisted(serverName gomatrixserverlib.ServerName) (bool, error) + IsServerBlacklisted(serverName spec.ServerName) (bool, error) // Adds the server to the list of assumed offline servers. // If the server already exists in the table, nothing happens and returns success. - SetServerAssumedOffline(ctx context.Context, serverName gomatrixserverlib.ServerName) error + SetServerAssumedOffline(ctx context.Context, serverName spec.ServerName) error // Removes the server from the list of assumed offline servers. // If the server doesn't exist in the table, nothing happens and returns success. - RemoveServerAssumedOffline(ctx context.Context, serverName gomatrixserverlib.ServerName) error + RemoveServerAssumedOffline(ctx context.Context, serverName spec.ServerName) error // Purges all entries from the assumed offline table. RemoveAllServersAssumedOffline(ctx context.Context) error // Gets whether the provided server is present in the table. // If it is present, returns true. If not, returns false. - IsServerAssumedOffline(ctx context.Context, serverName gomatrixserverlib.ServerName) (bool, error) + IsServerAssumedOffline(ctx context.Context, serverName spec.ServerName) (bool, error) - AddOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error - RenewOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error - GetOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.OutboundPeek, error) + AddOutboundPeek(ctx context.Context, serverName spec.ServerName, roomID, peekID string, renewalInterval int64) error + RenewOutboundPeek(ctx context.Context, serverName spec.ServerName, roomID, peekID string, renewalInterval int64) error + GetOutboundPeek(ctx context.Context, serverName spec.ServerName, roomID, peekID string) (*types.OutboundPeek, error) GetOutboundPeeks(ctx context.Context, roomID string) ([]types.OutboundPeek, error) - AddInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error - RenewInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error - GetInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.InboundPeek, error) + AddInboundPeek(ctx context.Context, serverName spec.ServerName, roomID, peekID string, renewalInterval int64) error + RenewInboundPeek(ctx context.Context, serverName spec.ServerName, roomID, peekID string, renewalInterval int64) error + GetInboundPeek(ctx context.Context, serverName spec.ServerName, roomID, peekID string) (*types.InboundPeek, error) GetInboundPeeks(ctx context.Context, roomID string) ([]types.InboundPeek, error) // Update the notary with the given server keys from the given server name. - UpdateNotaryKeys(ctx context.Context, serverName gomatrixserverlib.ServerName, serverKeys gomatrixserverlib.ServerKeys) error + UpdateNotaryKeys(ctx context.Context, serverName spec.ServerName, serverKeys gomatrixserverlib.ServerKeys) error // Query the notary for the server keys for the given server. If `optKeyIDs` is not empty, multiple server keys may be returned (between 1 - len(optKeyIDs)) // such that the combination of all server keys will include all the `optKeyIDs`. - GetNotaryKeys(ctx context.Context, serverName gomatrixserverlib.ServerName, optKeyIDs []gomatrixserverlib.KeyID) ([]gomatrixserverlib.ServerKeys, error) + GetNotaryKeys(ctx context.Context, serverName spec.ServerName, optKeyIDs []gomatrixserverlib.KeyID) ([]gomatrixserverlib.ServerKeys, error) // DeleteExpiredEDUs cleans up expired EDUs DeleteExpiredEDUs(ctx context.Context) error @@ -91,17 +92,17 @@ type Database interface { type P2PDatabase interface { // Stores the given list of servers as relay servers for the provided destination server. // Providing duplicates will only lead to a single entry and won't lead to an error. - P2PAddRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error + P2PAddRelayServersForServer(ctx context.Context, serverName spec.ServerName, relayServers []spec.ServerName) error // Get the list of relay servers associated with the provided destination server. // If no entry exists in the table, an empty list is returned and does not result in an error. - P2PGetRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName) ([]gomatrixserverlib.ServerName, error) + P2PGetRelayServersForServer(ctx context.Context, serverName spec.ServerName) ([]spec.ServerName, error) // Deletes any entries for the provided destination server that match the provided relayServers list. // If any of the provided servers don't match an entry, nothing happens and no error is returned. - P2PRemoveRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error + P2PRemoveRelayServersForServer(ctx context.Context, serverName spec.ServerName, relayServers []spec.ServerName) error // Deletes all entries for the provided destination server. // If the destination server doesn't exist in the table, nothing happens and no error is returned. - P2PRemoveAllRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName) error + P2PRemoveAllRelayServersForServer(ctx context.Context, serverName spec.ServerName) error } diff --git a/federationapi/storage/postgres/assumed_offline_table.go b/federationapi/storage/postgres/assumed_offline_table.go index 5695d2e54..d8d389d86 100644 --- a/federationapi/storage/postgres/assumed_offline_table.go +++ b/federationapi/storage/postgres/assumed_offline_table.go @@ -19,7 +19,7 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const assumedOfflineSchema = ` @@ -68,7 +68,7 @@ func NewPostgresAssumedOfflineTable(db *sql.DB) (s *assumedOfflineStatements, er } func (s *assumedOfflineStatements) InsertAssumedOffline( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, ) error { stmt := sqlutil.TxStmt(txn, s.insertAssumedOfflineStmt) _, err := stmt.ExecContext(ctx, serverName) @@ -76,7 +76,7 @@ func (s *assumedOfflineStatements) InsertAssumedOffline( } func (s *assumedOfflineStatements) SelectAssumedOffline( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, ) (bool, error) { stmt := sqlutil.TxStmt(txn, s.selectAssumedOfflineStmt) res, err := stmt.QueryContext(ctx, serverName) @@ -91,7 +91,7 @@ func (s *assumedOfflineStatements) SelectAssumedOffline( } func (s *assumedOfflineStatements) DeleteAssumedOffline( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, ) error { stmt := sqlutil.TxStmt(txn, s.deleteAssumedOfflineStmt) _, err := stmt.ExecContext(ctx, serverName) diff --git a/federationapi/storage/postgres/blacklist_table.go b/federationapi/storage/postgres/blacklist_table.go index eef37318b..48b6d72e1 100644 --- a/federationapi/storage/postgres/blacklist_table.go +++ b/federationapi/storage/postgres/blacklist_table.go @@ -19,7 +19,7 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const blacklistSchema = ` @@ -60,23 +60,16 @@ func NewPostgresBlacklistTable(db *sql.DB) (s *blacklistStatements, err error) { return } - if s.insertBlacklistStmt, err = db.Prepare(insertBlacklistSQL); err != nil { - return - } - if s.selectBlacklistStmt, err = db.Prepare(selectBlacklistSQL); err != nil { - return - } - if s.deleteBlacklistStmt, err = db.Prepare(deleteBlacklistSQL); err != nil { - return - } - if s.deleteAllBlacklistStmt, err = db.Prepare(deleteAllBlacklistSQL); err != nil { - return - } - return + return s, sqlutil.StatementList{ + {&s.insertBlacklistStmt, insertBlacklistSQL}, + {&s.selectBlacklistStmt, selectBlacklistSQL}, + {&s.deleteBlacklistStmt, deleteBlacklistSQL}, + {&s.deleteAllBlacklistStmt, deleteAllBlacklistSQL}, + }.Prepare(db) } func (s *blacklistStatements) InsertBlacklist( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, ) error { stmt := sqlutil.TxStmt(txn, s.insertBlacklistStmt) _, err := stmt.ExecContext(ctx, serverName) @@ -84,7 +77,7 @@ func (s *blacklistStatements) InsertBlacklist( } func (s *blacklistStatements) SelectBlacklist( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, ) (bool, error) { stmt := sqlutil.TxStmt(txn, s.selectBlacklistStmt) res, err := stmt.QueryContext(ctx, serverName) @@ -99,7 +92,7 @@ func (s *blacklistStatements) SelectBlacklist( } func (s *blacklistStatements) DeleteBlacklist( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, ) error { stmt := sqlutil.TxStmt(txn, s.deleteBlacklistStmt) _, err := stmt.ExecContext(ctx, serverName) diff --git a/federationapi/storage/postgres/deltas/2022042812473400_addexpiresat.go b/federationapi/storage/postgres/deltas/2022042812473400_addexpiresat.go index 53a7a025e..cf2d94b20 100644 --- a/federationapi/storage/postgres/deltas/2022042812473400_addexpiresat.go +++ b/federationapi/storage/postgres/deltas/2022042812473400_addexpiresat.go @@ -20,7 +20,7 @@ import ( "fmt" "time" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) func UpAddexpiresat(ctx context.Context, tx *sql.Tx) error { @@ -28,7 +28,7 @@ func UpAddexpiresat(ctx context.Context, tx *sql.Tx) error { if err != nil { return fmt.Errorf("failed to execute upgrade: %w", err) } - _, err = tx.ExecContext(ctx, "UPDATE federationsender_queue_edus SET expires_at = $1 WHERE edu_type != 'm.direct_to_device'", gomatrixserverlib.AsTimestamp(time.Now().Add(time.Hour*24))) + _, err = tx.ExecContext(ctx, "UPDATE federationsender_queue_edus SET expires_at = $1 WHERE edu_type != 'm.direct_to_device'", spec.AsTimestamp(time.Now().Add(time.Hour*24))) if err != nil { return fmt.Errorf("failed to update queue_edus: %w", err) } diff --git a/federationapi/storage/postgres/inbound_peeks_table.go b/federationapi/storage/postgres/inbound_peeks_table.go index ad2afcb15..a6fffc0e1 100644 --- a/federationapi/storage/postgres/inbound_peeks_table.go +++ b/federationapi/storage/postgres/inbound_peeks_table.go @@ -22,7 +22,7 @@ import ( "github.com/matrix-org/dendrite/federationapi/types" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const inboundPeeksSchema = ` @@ -86,7 +86,7 @@ func NewPostgresInboundPeeksTable(db *sql.DB) (s *inboundPeeksStatements, err er } func (s *inboundPeeksStatements) InsertInboundPeek( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, roomID, peekID string, renewalInterval int64, ) (err error) { nowMilli := time.Now().UnixNano() / int64(time.Millisecond) stmt := sqlutil.TxStmt(txn, s.insertInboundPeekStmt) @@ -95,7 +95,7 @@ func (s *inboundPeeksStatements) InsertInboundPeek( } func (s *inboundPeeksStatements) RenewInboundPeek( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, roomID, peekID string, renewalInterval int64, ) (err error) { nowMilli := time.Now().UnixNano() / int64(time.Millisecond) _, err = sqlutil.TxStmt(txn, s.renewInboundPeekStmt).ExecContext(ctx, nowMilli, renewalInterval, roomID, serverName, peekID) @@ -103,7 +103,7 @@ func (s *inboundPeeksStatements) RenewInboundPeek( } func (s *inboundPeeksStatements) SelectInboundPeek( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, roomID, peekID string, ) (*types.InboundPeek, error) { row := sqlutil.TxStmt(txn, s.selectInboundPeeksStmt).QueryRowContext(ctx, roomID) inboundPeek := types.InboundPeek{} @@ -152,7 +152,7 @@ func (s *inboundPeeksStatements) SelectInboundPeeks( } func (s *inboundPeeksStatements) DeleteInboundPeek( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, roomID, peekID string, ) (err error) { _, err = sqlutil.TxStmt(txn, s.deleteInboundPeekStmt).ExecContext(ctx, roomID, serverName, peekID) return diff --git a/federationapi/storage/postgres/joined_hosts_table.go b/federationapi/storage/postgres/joined_hosts_table.go index 9a3977560..2b0aebad1 100644 --- a/federationapi/storage/postgres/joined_hosts_table.go +++ b/federationapi/storage/postgres/joined_hosts_table.go @@ -23,7 +23,7 @@ import ( "github.com/matrix-org/dendrite/federationapi/types" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const joinedHostsSchema = ` @@ -90,35 +90,22 @@ func NewPostgresJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err erro if err != nil { return } - if s.insertJoinedHostsStmt, err = s.db.Prepare(insertJoinedHostsSQL); err != nil { - return - } - if s.deleteJoinedHostsStmt, err = s.db.Prepare(deleteJoinedHostsSQL); err != nil { - return - } - if s.deleteJoinedHostsForRoomStmt, err = s.db.Prepare(deleteJoinedHostsForRoomSQL); err != nil { - return - } - if s.selectJoinedHostsStmt, err = s.db.Prepare(selectJoinedHostsSQL); err != nil { - return - } - if s.selectAllJoinedHostsStmt, err = s.db.Prepare(selectAllJoinedHostsSQL); err != nil { - return - } - if s.selectJoinedHostsForRoomsStmt, err = s.db.Prepare(selectJoinedHostsForRoomsSQL); err != nil { - return - } - if s.selectJoinedHostsForRoomsExcludingBlacklistedStmt, err = s.db.Prepare(selectJoinedHostsForRoomsExcludingBlacklistedSQL); err != nil { - return - } - return + return s, sqlutil.StatementList{ + {&s.insertJoinedHostsStmt, insertJoinedHostsSQL}, + {&s.deleteJoinedHostsStmt, deleteJoinedHostsSQL}, + {&s.deleteJoinedHostsForRoomStmt, deleteJoinedHostsForRoomSQL}, + {&s.selectJoinedHostsStmt, selectJoinedHostsSQL}, + {&s.selectAllJoinedHostsStmt, selectAllJoinedHostsSQL}, + {&s.selectJoinedHostsForRoomsStmt, selectJoinedHostsForRoomsSQL}, + {&s.selectJoinedHostsForRoomsExcludingBlacklistedStmt, selectJoinedHostsForRoomsExcludingBlacklistedSQL}, + }.Prepare(db) } func (s *joinedHostsStatements) InsertJoinedHosts( ctx context.Context, txn *sql.Tx, roomID, eventID string, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, ) error { stmt := sqlutil.TxStmt(txn, s.insertJoinedHostsStmt) _, err := stmt.ExecContext(ctx, roomID, eventID, serverName) @@ -156,20 +143,20 @@ func (s *joinedHostsStatements) SelectJoinedHosts( func (s *joinedHostsStatements) SelectAllJoinedHosts( ctx context.Context, -) ([]gomatrixserverlib.ServerName, error) { +) ([]spec.ServerName, error) { rows, err := s.selectAllJoinedHostsStmt.QueryContext(ctx) if err != nil { return nil, err } defer internal.CloseAndLogIfError(ctx, rows, "selectAllJoinedHosts: rows.close() failed") - var result []gomatrixserverlib.ServerName + var result []spec.ServerName for rows.Next() { var serverName string if err = rows.Scan(&serverName); err != nil { return nil, err } - result = append(result, gomatrixserverlib.ServerName(serverName)) + result = append(result, spec.ServerName(serverName)) } return result, rows.Err() @@ -177,7 +164,7 @@ func (s *joinedHostsStatements) SelectAllJoinedHosts( func (s *joinedHostsStatements) SelectJoinedHostsForRooms( ctx context.Context, roomIDs []string, excludingBlacklisted bool, -) ([]gomatrixserverlib.ServerName, error) { +) ([]spec.ServerName, error) { stmt := s.selectJoinedHostsForRoomsStmt if excludingBlacklisted { stmt = s.selectJoinedHostsForRoomsExcludingBlacklistedStmt @@ -188,13 +175,13 @@ func (s *joinedHostsStatements) SelectJoinedHostsForRooms( } defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedHostsForRoomsStmt: rows.close() failed") - var result []gomatrixserverlib.ServerName + var result []spec.ServerName for rows.Next() { var serverName string if err = rows.Scan(&serverName); err != nil { return nil, err } - result = append(result, gomatrixserverlib.ServerName(serverName)) + result = append(result, spec.ServerName(serverName)) } return result, rows.Err() @@ -217,7 +204,7 @@ func joinedHostsFromStmt( } result = append(result, types.JoinedHost{ MemberEventID: eventID, - ServerName: gomatrixserverlib.ServerName(serverName), + ServerName: spec.ServerName(serverName), }) } diff --git a/federationapi/storage/postgres/notary_server_keys_json_table.go b/federationapi/storage/postgres/notary_server_keys_json_table.go index 65221c088..af98a0d4e 100644 --- a/federationapi/storage/postgres/notary_server_keys_json_table.go +++ b/federationapi/storage/postgres/notary_server_keys_json_table.go @@ -19,7 +19,9 @@ import ( "database/sql" "github.com/matrix-org/dendrite/federationapi/storage/tables" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const notaryServerKeysJSONSchema = ` @@ -50,14 +52,13 @@ func NewPostgresNotaryServerKeysTable(db *sql.DB) (s *notaryServerKeysStatements return } - if s.insertServerKeysJSONStmt, err = db.Prepare(insertServerKeysJSONSQL); err != nil { - return - } - return + return s, sqlutil.StatementList{ + {&s.insertServerKeysJSONStmt, insertServerKeysJSONSQL}, + }.Prepare(db) } func (s *notaryServerKeysStatements) InsertJSONResponse( - ctx context.Context, txn *sql.Tx, keyQueryResponseJSON gomatrixserverlib.ServerKeys, serverName gomatrixserverlib.ServerName, validUntil gomatrixserverlib.Timestamp, + ctx context.Context, txn *sql.Tx, keyQueryResponseJSON gomatrixserverlib.ServerKeys, serverName spec.ServerName, validUntil spec.Timestamp, ) (tables.NotaryID, error) { var notaryID tables.NotaryID return notaryID, txn.Stmt(s.insertServerKeysJSONStmt).QueryRowContext(ctx, string(keyQueryResponseJSON.Raw), serverName, validUntil).Scan(¬aryID) diff --git a/federationapi/storage/postgres/notary_server_keys_metadata_table.go b/federationapi/storage/postgres/notary_server_keys_metadata_table.go index 72faf4809..7a1ec4122 100644 --- a/federationapi/storage/postgres/notary_server_keys_metadata_table.go +++ b/federationapi/storage/postgres/notary_server_keys_metadata_table.go @@ -22,7 +22,9 @@ import ( "github.com/lib/pq" "github.com/matrix-org/dendrite/federationapi/storage/tables" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const notaryServerKeysMetadataSchema = ` @@ -91,31 +93,22 @@ func NewPostgresNotaryServerKeysMetadataTable(db *sql.DB) (s *notaryServerKeysMe return } - if s.upsertServerKeysStmt, err = db.Prepare(upsertServerKeysSQL); err != nil { - return - } - if s.selectNotaryKeyResponsesStmt, err = db.Prepare(selectNotaryKeyResponsesSQL); err != nil { - return - } - if s.selectNotaryKeyResponsesWithKeyIDsStmt, err = db.Prepare(selectNotaryKeyResponsesWithKeyIDsSQL); err != nil { - return - } - if s.selectNotaryKeyMetadataStmt, err = db.Prepare(selectNotaryKeyMetadataSQL); err != nil { - return - } - if s.deleteUnusedServerKeysJSONStmt, err = db.Prepare(deleteUnusedServerKeysJSONSQL); err != nil { - return - } - return + return s, sqlutil.StatementList{ + {&s.upsertServerKeysStmt, upsertServerKeysSQL}, + {&s.selectNotaryKeyResponsesStmt, selectNotaryKeyResponsesSQL}, + {&s.selectNotaryKeyResponsesWithKeyIDsStmt, selectNotaryKeyResponsesWithKeyIDsSQL}, + {&s.selectNotaryKeyMetadataStmt, selectNotaryKeyMetadataSQL}, + {&s.deleteUnusedServerKeysJSONStmt, deleteUnusedServerKeysJSONSQL}, + }.Prepare(db) } func (s *notaryServerKeysMetadataStatements) UpsertKey( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, keyID gomatrixserverlib.KeyID, newNotaryID tables.NotaryID, newValidUntil gomatrixserverlib.Timestamp, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, keyID gomatrixserverlib.KeyID, newNotaryID tables.NotaryID, newValidUntil spec.Timestamp, ) (tables.NotaryID, error) { notaryID := newNotaryID // see if the existing notary ID a) exists, b) has a longer valid_until var existingNotaryID tables.NotaryID - var existingValidUntil gomatrixserverlib.Timestamp + var existingValidUntil spec.Timestamp if err := txn.Stmt(s.selectNotaryKeyMetadataStmt).QueryRowContext(ctx, serverName, keyID).Scan(&existingNotaryID, &existingValidUntil); err != nil { if err != sql.ErrNoRows { return 0, err @@ -130,7 +123,7 @@ func (s *notaryServerKeysMetadataStatements) UpsertKey( return notaryID, err } -func (s *notaryServerKeysMetadataStatements) SelectKeys(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, keyIDs []gomatrixserverlib.KeyID) ([]gomatrixserverlib.ServerKeys, error) { +func (s *notaryServerKeysMetadataStatements) SelectKeys(ctx context.Context, txn *sql.Tx, serverName spec.ServerName, keyIDs []gomatrixserverlib.KeyID) ([]gomatrixserverlib.ServerKeys, error) { var rows *sql.Rows var err error if len(keyIDs) == 0 { diff --git a/federationapi/storage/postgres/outbound_peeks_table.go b/federationapi/storage/postgres/outbound_peeks_table.go index 5df684318..bd2b10e67 100644 --- a/federationapi/storage/postgres/outbound_peeks_table.go +++ b/federationapi/storage/postgres/outbound_peeks_table.go @@ -22,7 +22,7 @@ import ( "github.com/matrix-org/dendrite/federationapi/types" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const outboundPeeksSchema = ` @@ -85,7 +85,7 @@ func NewPostgresOutboundPeeksTable(db *sql.DB) (s *outboundPeeksStatements, err } func (s *outboundPeeksStatements) InsertOutboundPeek( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, roomID, peekID string, renewalInterval int64, ) (err error) { nowMilli := time.Now().UnixNano() / int64(time.Millisecond) stmt := sqlutil.TxStmt(txn, s.insertOutboundPeekStmt) @@ -94,7 +94,7 @@ func (s *outboundPeeksStatements) InsertOutboundPeek( } func (s *outboundPeeksStatements) RenewOutboundPeek( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, roomID, peekID string, renewalInterval int64, ) (err error) { nowMilli := time.Now().UnixNano() / int64(time.Millisecond) _, err = sqlutil.TxStmt(txn, s.renewOutboundPeekStmt).ExecContext(ctx, nowMilli, renewalInterval, roomID, serverName, peekID) @@ -102,7 +102,7 @@ func (s *outboundPeeksStatements) RenewOutboundPeek( } func (s *outboundPeeksStatements) SelectOutboundPeek( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, roomID, peekID string, ) (*types.OutboundPeek, error) { row := sqlutil.TxStmt(txn, s.selectOutboundPeeksStmt).QueryRowContext(ctx, roomID) outboundPeek := types.OutboundPeek{} @@ -151,7 +151,7 @@ func (s *outboundPeeksStatements) SelectOutboundPeeks( } func (s *outboundPeeksStatements) DeleteOutboundPeek( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, roomID, peekID string, ) (err error) { _, err = sqlutil.TxStmt(txn, s.deleteOutboundPeekStmt).ExecContext(ctx, roomID, serverName, peekID) return diff --git a/federationapi/storage/postgres/queue_edus_table.go b/federationapi/storage/postgres/queue_edus_table.go index 8870dc88d..7c57ed0cc 100644 --- a/federationapi/storage/postgres/queue_edus_table.go +++ b/federationapi/storage/postgres/queue_edus_table.go @@ -19,11 +19,11 @@ import ( "database/sql" "github.com/lib/pq" - "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/federationapi/storage/postgres/deltas" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib/spec" ) const queueEDUsSchema = ` @@ -121,9 +121,9 @@ func (s *queueEDUsStatements) InsertQueueEDU( ctx context.Context, txn *sql.Tx, eduType string, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, nid int64, - expiresAt gomatrixserverlib.Timestamp, + expiresAt spec.Timestamp, ) error { stmt := sqlutil.TxStmt(txn, s.insertQueueEDUStmt) _, err := stmt.ExecContext( @@ -138,7 +138,7 @@ func (s *queueEDUsStatements) InsertQueueEDU( func (s *queueEDUsStatements) DeleteQueueEDUs( ctx context.Context, txn *sql.Tx, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, jsonNIDs []int64, ) error { stmt := sqlutil.TxStmt(txn, s.deleteQueueEDUStmt) @@ -148,7 +148,7 @@ func (s *queueEDUsStatements) DeleteQueueEDUs( func (s *queueEDUsStatements) SelectQueueEDUs( ctx context.Context, txn *sql.Tx, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, limit int, ) ([]int64, error) { stmt := sqlutil.TxStmt(txn, s.selectQueueEDUStmt) @@ -182,16 +182,16 @@ func (s *queueEDUsStatements) SelectQueueEDUReferenceJSONCount( func (s *queueEDUsStatements) SelectQueueEDUServerNames( ctx context.Context, txn *sql.Tx, -) ([]gomatrixserverlib.ServerName, error) { +) ([]spec.ServerName, error) { stmt := sqlutil.TxStmt(txn, s.selectQueueEDUServerNamesStmt) rows, err := stmt.QueryContext(ctx) if err != nil { return nil, err } defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed") - var result []gomatrixserverlib.ServerName + var result []spec.ServerName for rows.Next() { - var serverName gomatrixserverlib.ServerName + var serverName spec.ServerName if err = rows.Scan(&serverName); err != nil { return nil, err } @@ -203,7 +203,7 @@ func (s *queueEDUsStatements) SelectQueueEDUServerNames( func (s *queueEDUsStatements) SelectExpiredEDUs( ctx context.Context, txn *sql.Tx, - expiredBefore gomatrixserverlib.Timestamp, + expiredBefore spec.Timestamp, ) ([]int64, error) { stmt := sqlutil.TxStmt(txn, s.selectExpiredEDUsStmt) rows, err := stmt.QueryContext(ctx, expiredBefore) @@ -224,7 +224,7 @@ func (s *queueEDUsStatements) SelectExpiredEDUs( func (s *queueEDUsStatements) DeleteExpiredEDUs( ctx context.Context, txn *sql.Tx, - expiredBefore gomatrixserverlib.Timestamp, + expiredBefore spec.Timestamp, ) error { stmt := sqlutil.TxStmt(txn, s.deleteExpiredEDUsStmt) _, err := stmt.ExecContext(ctx, expiredBefore) diff --git a/federationapi/storage/postgres/queue_json_table.go b/federationapi/storage/postgres/queue_json_table.go index e33074182..563738dd5 100644 --- a/federationapi/storage/postgres/queue_json_table.go +++ b/federationapi/storage/postgres/queue_json_table.go @@ -65,16 +65,11 @@ func NewPostgresQueueJSONTable(db *sql.DB) (s *queueJSONStatements, err error) { if err != nil { return } - if s.insertJSONStmt, err = s.db.Prepare(insertJSONSQL); err != nil { - return - } - if s.deleteJSONStmt, err = s.db.Prepare(deleteJSONSQL); err != nil { - return - } - if s.selectJSONStmt, err = s.db.Prepare(selectJSONSQL); err != nil { - return - } - return + return s, sqlutil.StatementList{ + {&s.insertJSONStmt, insertJSONSQL}, + {&s.deleteJSONStmt, deleteJSONSQL}, + {&s.selectJSONStmt, selectJSONSQL}, + }.Prepare(db) } func (s *queueJSONStatements) InsertQueueJSON( diff --git a/federationapi/storage/postgres/queue_pdus_table.go b/federationapi/storage/postgres/queue_pdus_table.go index 3b0bef9af..a767ec41d 100644 --- a/federationapi/storage/postgres/queue_pdus_table.go +++ b/federationapi/storage/postgres/queue_pdus_table.go @@ -22,6 +22,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const queuePDUsSchema = ` @@ -78,29 +79,20 @@ func NewPostgresQueuePDUsTable(db *sql.DB) (s *queuePDUsStatements, err error) { if err != nil { return } - if s.insertQueuePDUStmt, err = s.db.Prepare(insertQueuePDUSQL); err != nil { - return - } - if s.deleteQueuePDUsStmt, err = s.db.Prepare(deleteQueuePDUSQL); err != nil { - return - } - if s.selectQueuePDUsStmt, err = s.db.Prepare(selectQueuePDUsSQL); err != nil { - return - } - if s.selectQueuePDUReferenceJSONCountStmt, err = s.db.Prepare(selectQueuePDUReferenceJSONCountSQL); err != nil { - return - } - if s.selectQueuePDUServerNamesStmt, err = s.db.Prepare(selectQueuePDUServerNamesSQL); err != nil { - return - } - return + return s, sqlutil.StatementList{ + {&s.insertQueuePDUStmt, insertQueuePDUSQL}, + {&s.deleteQueuePDUsStmt, deleteQueuePDUSQL}, + {&s.selectQueuePDUsStmt, selectQueuePDUsSQL}, + {&s.selectQueuePDUReferenceJSONCountStmt, selectQueuePDUReferenceJSONCountSQL}, + {&s.selectQueuePDUServerNamesStmt, selectQueuePDUServerNamesSQL}, + }.Prepare(db) } func (s *queuePDUsStatements) InsertQueuePDU( ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, nid int64, ) error { stmt := sqlutil.TxStmt(txn, s.insertQueuePDUStmt) @@ -115,7 +107,7 @@ func (s *queuePDUsStatements) InsertQueuePDU( func (s *queuePDUsStatements) DeleteQueuePDUs( ctx context.Context, txn *sql.Tx, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, jsonNIDs []int64, ) error { stmt := sqlutil.TxStmt(txn, s.deleteQueuePDUsStmt) @@ -140,7 +132,7 @@ func (s *queuePDUsStatements) SelectQueuePDUReferenceJSONCount( func (s *queuePDUsStatements) SelectQueuePDUs( ctx context.Context, txn *sql.Tx, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, limit int, ) ([]int64, error) { stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsStmt) @@ -163,16 +155,16 @@ func (s *queuePDUsStatements) SelectQueuePDUs( func (s *queuePDUsStatements) SelectQueuePDUServerNames( ctx context.Context, txn *sql.Tx, -) ([]gomatrixserverlib.ServerName, error) { +) ([]spec.ServerName, error) { stmt := sqlutil.TxStmt(txn, s.selectQueuePDUServerNamesStmt) rows, err := stmt.QueryContext(ctx) if err != nil { return nil, err } defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed") - var result []gomatrixserverlib.ServerName + var result []spec.ServerName for rows.Next() { - var serverName gomatrixserverlib.ServerName + var serverName spec.ServerName if err = rows.Scan(&serverName); err != nil { return nil, err } diff --git a/federationapi/storage/postgres/relay_servers_table.go b/federationapi/storage/postgres/relay_servers_table.go index f7267978f..9e1bc5d40 100644 --- a/federationapi/storage/postgres/relay_servers_table.go +++ b/federationapi/storage/postgres/relay_servers_table.go @@ -21,7 +21,7 @@ import ( "github.com/lib/pq" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const relayServersSchema = ` @@ -78,8 +78,8 @@ func NewPostgresRelayServersTable(db *sql.DB) (s *relayServersStatements, err er func (s *relayServersStatements) InsertRelayServers( ctx context.Context, txn *sql.Tx, - serverName gomatrixserverlib.ServerName, - relayServers []gomatrixserverlib.ServerName, + serverName spec.ServerName, + relayServers []spec.ServerName, ) error { for _, relayServer := range relayServers { stmt := sqlutil.TxStmt(txn, s.insertRelayServersStmt) @@ -93,8 +93,8 @@ func (s *relayServersStatements) InsertRelayServers( func (s *relayServersStatements) SelectRelayServers( ctx context.Context, txn *sql.Tx, - serverName gomatrixserverlib.ServerName, -) ([]gomatrixserverlib.ServerName, error) { + serverName spec.ServerName, +) ([]spec.ServerName, error) { stmt := sqlutil.TxStmt(txn, s.selectRelayServersStmt) rows, err := stmt.QueryContext(ctx, serverName) if err != nil { @@ -102,13 +102,13 @@ func (s *relayServersStatements) SelectRelayServers( } defer internal.CloseAndLogIfError(ctx, rows, "SelectRelayServers: rows.close() failed") - var result []gomatrixserverlib.ServerName + var result []spec.ServerName for rows.Next() { var relayServer string if err = rows.Scan(&relayServer); err != nil { return nil, err } - result = append(result, gomatrixserverlib.ServerName(relayServer)) + result = append(result, spec.ServerName(relayServer)) } return result, nil } @@ -116,8 +116,8 @@ func (s *relayServersStatements) SelectRelayServers( func (s *relayServersStatements) DeleteRelayServers( ctx context.Context, txn *sql.Tx, - serverName gomatrixserverlib.ServerName, - relayServers []gomatrixserverlib.ServerName, + serverName spec.ServerName, + relayServers []spec.ServerName, ) error { stmt := sqlutil.TxStmt(txn, s.deleteRelayServersStmt) _, err := stmt.ExecContext(ctx, serverName, pq.Array(relayServers)) @@ -127,7 +127,7 @@ func (s *relayServersStatements) DeleteRelayServers( func (s *relayServersStatements) DeleteAllRelayServers( ctx context.Context, txn *sql.Tx, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, ) error { stmt := sqlutil.TxStmt(txn, s.deleteAllRelayServersStmt) if _, err := stmt.ExecContext(ctx, serverName); err != nil { diff --git a/federationapi/storage/postgres/server_key_table.go b/federationapi/storage/postgres/server_key_table.go index 16e294e05..c62446da5 100644 --- a/federationapi/storage/postgres/server_key_table.go +++ b/federationapi/storage/postgres/server_key_table.go @@ -23,6 +23,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const serverSigningKeysSchema = ` @@ -72,18 +73,15 @@ func NewPostgresServerSigningKeysTable(db *sql.DB) (s *serverSigningKeyStatement if err != nil { return } - if s.bulkSelectServerKeysStmt, err = db.Prepare(bulkSelectServerSigningKeysSQL); err != nil { - return - } - if s.upsertServerKeysStmt, err = db.Prepare(upsertServerSigningKeysSQL); err != nil { - return - } - return s, nil + return s, sqlutil.StatementList{ + {&s.bulkSelectServerKeysStmt, bulkSelectServerSigningKeysSQL}, + {&s.upsertServerKeysStmt, upsertServerSigningKeysSQL}, + }.Prepare(db) } func (s *serverSigningKeyStatements) BulkSelectServerKeys( ctx context.Context, txn *sql.Tx, - requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, + requests map[gomatrixserverlib.PublicKeyLookupRequest]spec.Timestamp, ) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) { var nameAndKeyIDs []string for request := range requests { @@ -106,7 +104,7 @@ func (s *serverSigningKeyStatements) BulkSelectServerKeys( return nil, err } r := gomatrixserverlib.PublicKeyLookupRequest{ - ServerName: gomatrixserverlib.ServerName(serverName), + ServerName: spec.ServerName(serverName), KeyID: gomatrixserverlib.KeyID(keyID), } vk := gomatrixserverlib.VerifyKey{} @@ -116,8 +114,8 @@ func (s *serverSigningKeyStatements) BulkSelectServerKeys( } results[r] = gomatrixserverlib.PublicKeyLookupResult{ VerifyKey: vk, - ValidUntilTS: gomatrixserverlib.Timestamp(validUntilTS), - ExpiredTS: gomatrixserverlib.Timestamp(expiredTS), + ValidUntilTS: spec.Timestamp(validUntilTS), + ExpiredTS: spec.Timestamp(expiredTS), } } return results, rows.Err() diff --git a/federationapi/storage/postgres/storage.go b/federationapi/storage/postgres/storage.go index b81f128e7..30665bc56 100644 --- a/federationapi/storage/postgres/storage.go +++ b/federationapi/storage/postgres/storage.go @@ -16,6 +16,7 @@ package postgres import ( + "context" "database/sql" "fmt" @@ -23,9 +24,8 @@ import ( "github.com/matrix-org/dendrite/federationapi/storage/shared" "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) // Database stores information needed by the federation sender @@ -36,10 +36,10 @@ type Database struct { } // NewDatabase opens a new database -func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(gomatrixserverlib.ServerName) bool) (*Database, error) { +func NewDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(spec.ServerName) bool) (*Database, error) { var d Database var err error - if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()); err != nil { + if d.db, d.writer, err = conMan.Connection(dbProperties); err != nil { return nil, err } blacklist, err := NewPostgresBlacklistTable(d.db) @@ -95,7 +95,7 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, Version: "federationsender: drop federationsender_rooms", Up: deltas.UpRemoveRoomsTable, }) - err = m.Up(base.Context()) + err = m.Up(ctx) if err != nil { return nil, err } diff --git a/federationapi/storage/shared/storage.go b/federationapi/storage/shared/storage.go index 6769637bc..8c73967c6 100644 --- a/federationapi/storage/shared/storage.go +++ b/federationapi/storage/shared/storage.go @@ -26,11 +26,12 @@ import ( "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) type Database struct { DB *sql.DB - IsLocalServerName func(gomatrixserverlib.ServerName) bool + IsLocalServerName func(spec.ServerName) bool Cache caching.FederationCache Writer sqlutil.Writer FederationQueuePDUs tables.FederationQueuePDUs @@ -102,7 +103,7 @@ func (d *Database) GetJoinedHosts( // Returns an error if something goes wrong. func (d *Database) GetAllJoinedHosts( ctx context.Context, -) ([]gomatrixserverlib.ServerName, error) { +) ([]spec.ServerName, error) { return d.FederationJoinedHosts.SelectAllJoinedHosts(ctx) } @@ -111,7 +112,7 @@ func (d *Database) GetJoinedHostsForRooms( roomIDs []string, excludeSelf, excludeBlacklisted bool, -) ([]gomatrixserverlib.ServerName, error) { +) ([]spec.ServerName, error) { servers, err := d.FederationJoinedHosts.SelectJoinedHostsForRooms(ctx, roomIDs, excludeBlacklisted) if err != nil { return nil, err @@ -148,7 +149,7 @@ func (d *Database) StoreJSON( } func (d *Database) AddServerToBlacklist( - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, ) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.FederationBlacklist.InsertBlacklist(context.TODO(), txn, serverName) @@ -156,7 +157,7 @@ func (d *Database) AddServerToBlacklist( } func (d *Database) RemoveServerFromBlacklist( - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, ) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.FederationBlacklist.DeleteBlacklist(context.TODO(), txn, serverName) @@ -170,14 +171,14 @@ func (d *Database) RemoveAllServersFromBlacklist() error { } func (d *Database) IsServerBlacklisted( - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, ) (bool, error) { return d.FederationBlacklist.SelectBlacklist(context.TODO(), nil, serverName) } func (d *Database) SetServerAssumedOffline( ctx context.Context, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, ) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.FederationAssumedOffline.InsertAssumedOffline(ctx, txn, serverName) @@ -186,7 +187,7 @@ func (d *Database) SetServerAssumedOffline( func (d *Database) RemoveServerAssumedOffline( ctx context.Context, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, ) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.FederationAssumedOffline.DeleteAssumedOffline(ctx, txn, serverName) @@ -203,15 +204,15 @@ func (d *Database) RemoveAllServersAssumedOffline( func (d *Database) IsServerAssumedOffline( ctx context.Context, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, ) (bool, error) { return d.FederationAssumedOffline.SelectAssumedOffline(ctx, nil, serverName) } func (d *Database) P2PAddRelayServersForServer( ctx context.Context, - serverName gomatrixserverlib.ServerName, - relayServers []gomatrixserverlib.ServerName, + serverName spec.ServerName, + relayServers []spec.ServerName, ) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.FederationRelayServers.InsertRelayServers(ctx, txn, serverName, relayServers) @@ -220,15 +221,15 @@ func (d *Database) P2PAddRelayServersForServer( func (d *Database) P2PGetRelayServersForServer( ctx context.Context, - serverName gomatrixserverlib.ServerName, -) ([]gomatrixserverlib.ServerName, error) { + serverName spec.ServerName, +) ([]spec.ServerName, error) { return d.FederationRelayServers.SelectRelayServers(ctx, nil, serverName) } func (d *Database) P2PRemoveRelayServersForServer( ctx context.Context, - serverName gomatrixserverlib.ServerName, - relayServers []gomatrixserverlib.ServerName, + serverName spec.ServerName, + relayServers []spec.ServerName, ) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.FederationRelayServers.DeleteRelayServers(ctx, txn, serverName, relayServers) @@ -237,7 +238,7 @@ func (d *Database) P2PRemoveRelayServersForServer( func (d *Database) P2PRemoveAllRelayServersForServer( ctx context.Context, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, ) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.FederationRelayServers.DeleteAllRelayServers(ctx, txn, serverName) @@ -246,7 +247,7 @@ func (d *Database) P2PRemoveAllRelayServersForServer( func (d *Database) AddOutboundPeek( ctx context.Context, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, roomID string, peekID string, renewalInterval int64, @@ -258,7 +259,7 @@ func (d *Database) AddOutboundPeek( func (d *Database) RenewOutboundPeek( ctx context.Context, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, roomID string, peekID string, renewalInterval int64, @@ -270,7 +271,7 @@ func (d *Database) RenewOutboundPeek( func (d *Database) GetOutboundPeek( ctx context.Context, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, roomID, peekID string, ) (*types.OutboundPeek, error) { @@ -286,7 +287,7 @@ func (d *Database) GetOutboundPeeks( func (d *Database) AddInboundPeek( ctx context.Context, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, roomID string, peekID string, renewalInterval int64, @@ -298,7 +299,7 @@ func (d *Database) AddInboundPeek( func (d *Database) RenewInboundPeek( ctx context.Context, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, roomID string, peekID string, renewalInterval int64, @@ -310,7 +311,7 @@ func (d *Database) RenewInboundPeek( func (d *Database) GetInboundPeek( ctx context.Context, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, roomID string, peekID string, ) (*types.InboundPeek, error) { @@ -326,7 +327,7 @@ func (d *Database) GetInboundPeeks( func (d *Database) UpdateNotaryKeys( ctx context.Context, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, serverKeys gomatrixserverlib.ServerKeys, ) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { @@ -337,7 +338,7 @@ func (d *Database) UpdateNotaryKeys( // https://spec.matrix.org/unstable/server-server-api/#querying-keys-through-another-server weekIntoFuture := time.Now().Add(7 * 24 * time.Hour) if weekIntoFuture.Before(validUntil.Time()) { - validUntil = gomatrixserverlib.AsTimestamp(weekIntoFuture) + validUntil = spec.AsTimestamp(weekIntoFuture) } notaryID, err := d.NotaryServerKeysJSON.InsertJSONResponse(ctx, txn, serverKeys, serverName, validUntil) if err != nil { @@ -364,7 +365,7 @@ func (d *Database) UpdateNotaryKeys( func (d *Database) GetNotaryKeys( ctx context.Context, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, optKeyIDs []gomatrixserverlib.KeyID, ) (sks []gomatrixserverlib.ServerKeys, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { diff --git a/federationapi/storage/shared/storage_edus.go b/federationapi/storage/shared/storage_edus.go index cff1ade6f..e8d1d3733 100644 --- a/federationapi/storage/shared/storage_edus.go +++ b/federationapi/storage/shared/storage_edus.go @@ -24,6 +24,7 @@ import ( "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) // defaultExpiry for EDUs if not listed below @@ -32,8 +33,8 @@ var defaultExpiry = time.Hour * 24 // defaultExpireEDUTypes contains EDUs which can/should be expired after a given time // if the target server isn't reachable for some reason. var defaultExpireEDUTypes = map[string]time.Duration{ - gomatrixserverlib.MTyping: time.Minute, - gomatrixserverlib.MPresence: time.Minute * 10, + spec.MTyping: time.Minute, + spec.MPresence: time.Minute * 10, } // AssociateEDUWithDestination creates an association that the @@ -41,7 +42,7 @@ var defaultExpireEDUTypes = map[string]time.Duration{ // to which servers. func (d *Database) AssociateEDUWithDestinations( ctx context.Context, - destinations map[gomatrixserverlib.ServerName]struct{}, + destinations map[spec.ServerName]struct{}, dbReceipt *receipt.Receipt, eduType string, expireEDUTypes map[string]time.Duration, @@ -49,14 +50,14 @@ func (d *Database) AssociateEDUWithDestinations( if expireEDUTypes == nil { expireEDUTypes = defaultExpireEDUTypes } - expiresAt := gomatrixserverlib.AsTimestamp(time.Now().Add(defaultExpiry)) + expiresAt := spec.AsTimestamp(time.Now().Add(defaultExpiry)) if duration, ok := expireEDUTypes[eduType]; ok { // Keep EDUs for at least x minutes before deleting them - expiresAt = gomatrixserverlib.AsTimestamp(time.Now().Add(duration)) + expiresAt = spec.AsTimestamp(time.Now().Add(duration)) } // We forcibly set m.direct_to_device and m.device_list_update events // to 0, as we always want them to be delivered. (required for E2EE) - if eduType == gomatrixserverlib.MDirectToDevice || eduType == gomatrixserverlib.MDeviceListUpdate { + if eduType == spec.MDirectToDevice || eduType == spec.MDeviceListUpdate { expiresAt = 0 } return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { @@ -79,7 +80,7 @@ func (d *Database) AssociateEDUWithDestinations( // the next pending transaction, up to the limit specified. func (d *Database) GetPendingEDUs( ctx context.Context, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, limit int, ) ( edus map[*receipt.Receipt]*gomatrixserverlib.EDU, @@ -126,7 +127,7 @@ func (d *Database) GetPendingEDUs( // transaction was sent successfully. func (d *Database) CleanEDUs( ctx context.Context, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, receipts []*receipt.Receipt, ) error { if len(receipts) == 0 { @@ -169,7 +170,7 @@ func (d *Database) CleanEDUs( // waiting to be sent. func (d *Database) GetPendingEDUServerNames( ctx context.Context, -) ([]gomatrixserverlib.ServerName, error) { +) ([]spec.ServerName, error) { return d.FederationQueueEDUs.SelectQueueEDUServerNames(ctx, nil) } @@ -177,7 +178,7 @@ func (d *Database) GetPendingEDUServerNames( func (d *Database) DeleteExpiredEDUs(ctx context.Context) error { var jsonNIDs []int64 err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) (err error) { - expiredBefore := gomatrixserverlib.AsTimestamp(time.Now()) + expiredBefore := spec.AsTimestamp(time.Now()) jsonNIDs, err = d.FederationQueueEDUs.SelectExpiredEDUs(ctx, txn, expiredBefore) if err != nil { return err diff --git a/federationapi/storage/shared/storage_keys.go b/federationapi/storage/shared/storage_keys.go index 3222b1224..580cf1d84 100644 --- a/federationapi/storage/shared/storage_keys.go +++ b/federationapi/storage/shared/storage_keys.go @@ -20,6 +20,7 @@ import ( "database/sql" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) // FetcherName implements KeyFetcher @@ -30,7 +31,7 @@ func (d Database) FetcherName() string { // FetchKeys implements gomatrixserverlib.KeyDatabase func (d *Database) FetchKeys( ctx context.Context, - requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, + requests map[gomatrixserverlib.PublicKeyLookupRequest]spec.Timestamp, ) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) { return d.ServerSigningKeys.BulkSelectServerKeys(ctx, nil, requests) } diff --git a/federationapi/storage/shared/storage_pdus.go b/federationapi/storage/shared/storage_pdus.go index 854e00553..b7fad06bd 100644 --- a/federationapi/storage/shared/storage_pdus.go +++ b/federationapi/storage/shared/storage_pdus.go @@ -23,6 +23,7 @@ import ( "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) // AssociatePDUWithDestination creates an association that the @@ -30,7 +31,7 @@ import ( // to which servers. func (d *Database) AssociatePDUWithDestinations( ctx context.Context, - destinations map[gomatrixserverlib.ServerName]struct{}, + destinations map[spec.ServerName]struct{}, dbReceipt *receipt.Receipt, ) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { @@ -52,7 +53,7 @@ func (d *Database) AssociatePDUWithDestinations( // the next pending transaction, up to the limit specified. func (d *Database) GetPendingPDUs( ctx context.Context, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, limit int, ) ( events map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent, @@ -105,7 +106,7 @@ func (d *Database) GetPendingPDUs( // successfully. func (d *Database) CleanPDUs( ctx context.Context, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, receipts []*receipt.Receipt, ) error { if len(receipts) == 0 { @@ -148,6 +149,6 @@ func (d *Database) CleanPDUs( // waiting to be sent. func (d *Database) GetPendingPDUServerNames( ctx context.Context, -) ([]gomatrixserverlib.ServerName, error) { +) ([]spec.ServerName, error) { return d.FederationQueuePDUs.SelectQueuePDUServerNames(ctx, nil) } diff --git a/federationapi/storage/sqlite3/assumed_offline_table.go b/federationapi/storage/sqlite3/assumed_offline_table.go index ff2afb4da..f8de7f0c5 100644 --- a/federationapi/storage/sqlite3/assumed_offline_table.go +++ b/federationapi/storage/sqlite3/assumed_offline_table.go @@ -19,7 +19,7 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const assumedOfflineSchema = ` @@ -68,7 +68,7 @@ func NewSQLiteAssumedOfflineTable(db *sql.DB) (s *assumedOfflineStatements, err } func (s *assumedOfflineStatements) InsertAssumedOffline( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, ) error { stmt := sqlutil.TxStmt(txn, s.insertAssumedOfflineStmt) _, err := stmt.ExecContext(ctx, serverName) @@ -76,7 +76,7 @@ func (s *assumedOfflineStatements) InsertAssumedOffline( } func (s *assumedOfflineStatements) SelectAssumedOffline( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, ) (bool, error) { stmt := sqlutil.TxStmt(txn, s.selectAssumedOfflineStmt) res, err := stmt.QueryContext(ctx, serverName) @@ -91,7 +91,7 @@ func (s *assumedOfflineStatements) SelectAssumedOffline( } func (s *assumedOfflineStatements) DeleteAssumedOffline( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, ) error { stmt := sqlutil.TxStmt(txn, s.deleteAssumedOfflineStmt) _, err := stmt.ExecContext(ctx, serverName) diff --git a/federationapi/storage/sqlite3/blacklist_table.go b/federationapi/storage/sqlite3/blacklist_table.go index 2694e630d..2c65c487c 100644 --- a/federationapi/storage/sqlite3/blacklist_table.go +++ b/federationapi/storage/sqlite3/blacklist_table.go @@ -19,7 +19,7 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const blacklistSchema = ` @@ -60,23 +60,16 @@ func NewSQLiteBlacklistTable(db *sql.DB) (s *blacklistStatements, err error) { return } - if s.insertBlacklistStmt, err = db.Prepare(insertBlacklistSQL); err != nil { - return - } - if s.selectBlacklistStmt, err = db.Prepare(selectBlacklistSQL); err != nil { - return - } - if s.deleteBlacklistStmt, err = db.Prepare(deleteBlacklistSQL); err != nil { - return - } - if s.deleteAllBlacklistStmt, err = db.Prepare(deleteAllBlacklistSQL); err != nil { - return - } - return + return s, sqlutil.StatementList{ + {&s.insertBlacklistStmt, insertBlacklistSQL}, + {&s.selectBlacklistStmt, selectBlacklistSQL}, + {&s.deleteBlacklistStmt, deleteBlacklistSQL}, + {&s.deleteAllBlacklistStmt, deleteAllBlacklistSQL}, + }.Prepare(db) } func (s *blacklistStatements) InsertBlacklist( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, ) error { stmt := sqlutil.TxStmt(txn, s.insertBlacklistStmt) _, err := stmt.ExecContext(ctx, serverName) @@ -84,7 +77,7 @@ func (s *blacklistStatements) InsertBlacklist( } func (s *blacklistStatements) SelectBlacklist( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, ) (bool, error) { stmt := sqlutil.TxStmt(txn, s.selectBlacklistStmt) res, err := stmt.QueryContext(ctx, serverName) @@ -99,7 +92,7 @@ func (s *blacklistStatements) SelectBlacklist( } func (s *blacklistStatements) DeleteBlacklist( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, ) error { stmt := sqlutil.TxStmt(txn, s.deleteBlacklistStmt) _, err := stmt.ExecContext(ctx, serverName) diff --git a/federationapi/storage/sqlite3/deltas/2022042812473400_addexpiresat.go b/federationapi/storage/sqlite3/deltas/2022042812473400_addexpiresat.go index c5030163b..d8be4695e 100644 --- a/federationapi/storage/sqlite3/deltas/2022042812473400_addexpiresat.go +++ b/federationapi/storage/sqlite3/deltas/2022042812473400_addexpiresat.go @@ -20,7 +20,7 @@ import ( "fmt" "time" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) func UpAddexpiresat(ctx context.Context, tx *sql.Tx) error { @@ -52,7 +52,7 @@ INSERT if err != nil { return fmt.Errorf("failed to update queue_edus: %w", err) } - _, err = tx.ExecContext(ctx, "UPDATE federationsender_queue_edus SET expires_at = $1 WHERE edu_type != 'm.direct_to_device'", gomatrixserverlib.AsTimestamp(time.Now().Add(time.Hour*24))) + _, err = tx.ExecContext(ctx, "UPDATE federationsender_queue_edus SET expires_at = $1 WHERE edu_type != 'm.direct_to_device'", spec.AsTimestamp(time.Now().Add(time.Hour*24))) if err != nil { return fmt.Errorf("failed to update queue_edus: %w", err) } diff --git a/federationapi/storage/sqlite3/inbound_peeks_table.go b/federationapi/storage/sqlite3/inbound_peeks_table.go index 8c3567934..e58d53778 100644 --- a/federationapi/storage/sqlite3/inbound_peeks_table.go +++ b/federationapi/storage/sqlite3/inbound_peeks_table.go @@ -22,7 +22,7 @@ import ( "github.com/matrix-org/dendrite/federationapi/types" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const inboundPeeksSchema = ` @@ -86,7 +86,7 @@ func NewSQLiteInboundPeeksTable(db *sql.DB) (s *inboundPeeksStatements, err erro } func (s *inboundPeeksStatements) InsertInboundPeek( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, roomID, peekID string, renewalInterval int64, ) (err error) { nowMilli := time.Now().UnixNano() / int64(time.Millisecond) stmt := sqlutil.TxStmt(txn, s.insertInboundPeekStmt) @@ -95,7 +95,7 @@ func (s *inboundPeeksStatements) InsertInboundPeek( } func (s *inboundPeeksStatements) RenewInboundPeek( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, roomID, peekID string, renewalInterval int64, ) (err error) { nowMilli := time.Now().UnixNano() / int64(time.Millisecond) _, err = sqlutil.TxStmt(txn, s.renewInboundPeekStmt).ExecContext(ctx, nowMilli, renewalInterval, roomID, serverName, peekID) @@ -103,7 +103,7 @@ func (s *inboundPeeksStatements) RenewInboundPeek( } func (s *inboundPeeksStatements) SelectInboundPeek( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, roomID, peekID string, ) (*types.InboundPeek, error) { row := sqlutil.TxStmt(txn, s.selectInboundPeeksStmt).QueryRowContext(ctx, roomID) inboundPeek := types.InboundPeek{} @@ -152,7 +152,7 @@ func (s *inboundPeeksStatements) SelectInboundPeeks( } func (s *inboundPeeksStatements) DeleteInboundPeek( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, roomID, peekID string, ) (err error) { _, err = sqlutil.TxStmt(txn, s.deleteInboundPeekStmt).ExecContext(ctx, roomID, serverName, peekID) return diff --git a/federationapi/storage/sqlite3/joined_hosts_table.go b/federationapi/storage/sqlite3/joined_hosts_table.go index 83112c150..2412cacdb 100644 --- a/federationapi/storage/sqlite3/joined_hosts_table.go +++ b/federationapi/storage/sqlite3/joined_hosts_table.go @@ -23,7 +23,7 @@ import ( "github.com/matrix-org/dendrite/federationapi/types" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const joinedHostsSchema = ` @@ -90,29 +90,21 @@ func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) if err != nil { return } - if s.insertJoinedHostsStmt, err = db.Prepare(insertJoinedHostsSQL); err != nil { - return - } - if s.deleteJoinedHostsStmt, err = db.Prepare(deleteJoinedHostsSQL); err != nil { - return - } - if s.deleteJoinedHostsForRoomStmt, err = s.db.Prepare(deleteJoinedHostsForRoomSQL); err != nil { - return - } - if s.selectJoinedHostsStmt, err = db.Prepare(selectJoinedHostsSQL); err != nil { - return - } - if s.selectAllJoinedHostsStmt, err = db.Prepare(selectAllJoinedHostsSQL); err != nil { - return - } - return + + return s, sqlutil.StatementList{ + {&s.insertJoinedHostsStmt, insertJoinedHostsSQL}, + {&s.deleteJoinedHostsStmt, deleteJoinedHostsSQL}, + {&s.deleteJoinedHostsForRoomStmt, deleteJoinedHostsForRoomSQL}, + {&s.selectJoinedHostsStmt, selectJoinedHostsSQL}, + {&s.selectAllJoinedHostsStmt, selectAllJoinedHostsSQL}, + }.Prepare(db) } func (s *joinedHostsStatements) InsertJoinedHosts( ctx context.Context, txn *sql.Tx, roomID, eventID string, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, ) error { stmt := sqlutil.TxStmt(txn, s.insertJoinedHostsStmt) _, err := stmt.ExecContext(ctx, roomID, eventID, serverName) @@ -154,20 +146,20 @@ func (s *joinedHostsStatements) SelectJoinedHosts( func (s *joinedHostsStatements) SelectAllJoinedHosts( ctx context.Context, -) ([]gomatrixserverlib.ServerName, error) { +) ([]spec.ServerName, error) { rows, err := s.selectAllJoinedHostsStmt.QueryContext(ctx) if err != nil { return nil, err } defer internal.CloseAndLogIfError(ctx, rows, "selectAllJoinedHosts: rows.close() failed") - var result []gomatrixserverlib.ServerName + var result []spec.ServerName for rows.Next() { var serverName string if err = rows.Scan(&serverName); err != nil { return nil, err } - result = append(result, gomatrixserverlib.ServerName(serverName)) + result = append(result, spec.ServerName(serverName)) } return result, rows.Err() @@ -175,7 +167,7 @@ func (s *joinedHostsStatements) SelectAllJoinedHosts( func (s *joinedHostsStatements) SelectJoinedHostsForRooms( ctx context.Context, roomIDs []string, excludingBlacklisted bool, -) ([]gomatrixserverlib.ServerName, error) { +) ([]spec.ServerName, error) { iRoomIDs := make([]interface{}, len(roomIDs)) for i := range roomIDs { iRoomIDs[i] = roomIDs[i] @@ -191,13 +183,13 @@ func (s *joinedHostsStatements) SelectJoinedHostsForRooms( } defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedHostsForRoomsStmt: rows.close() failed") - var result []gomatrixserverlib.ServerName + var result []spec.ServerName for rows.Next() { var serverName string if err = rows.Scan(&serverName); err != nil { return nil, err } - result = append(result, gomatrixserverlib.ServerName(serverName)) + result = append(result, spec.ServerName(serverName)) } return result, rows.Err() @@ -220,7 +212,7 @@ func joinedHostsFromStmt( } result = append(result, types.JoinedHost{ MemberEventID: eventID, - ServerName: gomatrixserverlib.ServerName(serverName), + ServerName: spec.ServerName(serverName), }) } diff --git a/federationapi/storage/sqlite3/notary_server_keys_json_table.go b/federationapi/storage/sqlite3/notary_server_keys_json_table.go index 4a028fa20..ad6d1b57f 100644 --- a/federationapi/storage/sqlite3/notary_server_keys_json_table.go +++ b/federationapi/storage/sqlite3/notary_server_keys_json_table.go @@ -19,7 +19,9 @@ import ( "database/sql" "github.com/matrix-org/dendrite/federationapi/storage/tables" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const notaryServerKeysJSONSchema = ` @@ -49,14 +51,13 @@ func NewSQLiteNotaryServerKeysTable(db *sql.DB) (s *notaryServerKeysStatements, return } - if s.insertServerKeysJSONStmt, err = db.Prepare(insertServerKeysJSONSQL); err != nil { - return - } - return + return s, sqlutil.StatementList{ + {&s.insertServerKeysJSONStmt, insertServerKeysJSONSQL}, + }.Prepare(db) } func (s *notaryServerKeysStatements) InsertJSONResponse( - ctx context.Context, txn *sql.Tx, keyQueryResponseJSON gomatrixserverlib.ServerKeys, serverName gomatrixserverlib.ServerName, validUntil gomatrixserverlib.Timestamp, + ctx context.Context, txn *sql.Tx, keyQueryResponseJSON gomatrixserverlib.ServerKeys, serverName spec.ServerName, validUntil spec.Timestamp, ) (tables.NotaryID, error) { var notaryID tables.NotaryID return notaryID, txn.Stmt(s.insertServerKeysJSONStmt).QueryRowContext(ctx, string(keyQueryResponseJSON.Raw), serverName, validUntil).Scan(¬aryID) diff --git a/federationapi/storage/sqlite3/notary_server_keys_metadata_table.go b/federationapi/storage/sqlite3/notary_server_keys_metadata_table.go index 55709a962..2fd9ef211 100644 --- a/federationapi/storage/sqlite3/notary_server_keys_metadata_table.go +++ b/federationapi/storage/sqlite3/notary_server_keys_metadata_table.go @@ -25,6 +25,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const notaryServerKeysMetadataSchema = ` @@ -92,28 +93,21 @@ func NewSQLiteNotaryServerKeysMetadataTable(db *sql.DB) (s *notaryServerKeysMeta return } - if s.upsertServerKeysStmt, err = db.Prepare(upsertServerKeysSQL); err != nil { - return - } - if s.selectNotaryKeyResponsesStmt, err = db.Prepare(selectNotaryKeyResponsesSQL); err != nil { - return - } - if s.selectNotaryKeyMetadataStmt, err = db.Prepare(selectNotaryKeyMetadataSQL); err != nil { - return - } - if s.deleteUnusedServerKeysJSONStmt, err = db.Prepare(deleteUnusedServerKeysJSONSQL); err != nil { - return - } - return + return s, sqlutil.StatementList{ + {&s.upsertServerKeysStmt, upsertServerKeysSQL}, + {&s.selectNotaryKeyResponsesStmt, selectNotaryKeyResponsesSQL}, + {&s.selectNotaryKeyMetadataStmt, selectNotaryKeyMetadataSQL}, + {&s.deleteUnusedServerKeysJSONStmt, deleteUnusedServerKeysJSONSQL}, + }.Prepare(db) } func (s *notaryServerKeysMetadataStatements) UpsertKey( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, keyID gomatrixserverlib.KeyID, newNotaryID tables.NotaryID, newValidUntil gomatrixserverlib.Timestamp, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, keyID gomatrixserverlib.KeyID, newNotaryID tables.NotaryID, newValidUntil spec.Timestamp, ) (tables.NotaryID, error) { notaryID := newNotaryID // see if the existing notary ID a) exists, b) has a longer valid_until var existingNotaryID tables.NotaryID - var existingValidUntil gomatrixserverlib.Timestamp + var existingValidUntil spec.Timestamp if err := txn.Stmt(s.selectNotaryKeyMetadataStmt).QueryRowContext(ctx, serverName, keyID).Scan(&existingNotaryID, &existingValidUntil); err != nil { if err != sql.ErrNoRows { return 0, err @@ -128,7 +122,7 @@ func (s *notaryServerKeysMetadataStatements) UpsertKey( return notaryID, err } -func (s *notaryServerKeysMetadataStatements) SelectKeys(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, keyIDs []gomatrixserverlib.KeyID) ([]gomatrixserverlib.ServerKeys, error) { +func (s *notaryServerKeysMetadataStatements) SelectKeys(ctx context.Context, txn *sql.Tx, serverName spec.ServerName, keyIDs []gomatrixserverlib.KeyID) ([]gomatrixserverlib.ServerKeys, error) { var rows *sql.Rows var err error if len(keyIDs) == 0 { diff --git a/federationapi/storage/sqlite3/outbound_peeks_table.go b/federationapi/storage/sqlite3/outbound_peeks_table.go index 33f452b68..b6684e9b3 100644 --- a/federationapi/storage/sqlite3/outbound_peeks_table.go +++ b/federationapi/storage/sqlite3/outbound_peeks_table.go @@ -22,7 +22,7 @@ import ( "github.com/matrix-org/dendrite/federationapi/types" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const outboundPeeksSchema = ` @@ -85,7 +85,7 @@ func NewSQLiteOutboundPeeksTable(db *sql.DB) (s *outboundPeeksStatements, err er } func (s *outboundPeeksStatements) InsertOutboundPeek( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, roomID, peekID string, renewalInterval int64, ) (err error) { nowMilli := time.Now().UnixNano() / int64(time.Millisecond) stmt := sqlutil.TxStmt(txn, s.insertOutboundPeekStmt) @@ -94,7 +94,7 @@ func (s *outboundPeeksStatements) InsertOutboundPeek( } func (s *outboundPeeksStatements) RenewOutboundPeek( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, roomID, peekID string, renewalInterval int64, ) (err error) { nowMilli := time.Now().UnixNano() / int64(time.Millisecond) _, err = sqlutil.TxStmt(txn, s.renewOutboundPeekStmt).ExecContext(ctx, nowMilli, renewalInterval, roomID, serverName, peekID) @@ -102,7 +102,7 @@ func (s *outboundPeeksStatements) RenewOutboundPeek( } func (s *outboundPeeksStatements) SelectOutboundPeek( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, roomID, peekID string, ) (*types.OutboundPeek, error) { row := sqlutil.TxStmt(txn, s.selectOutboundPeeksStmt).QueryRowContext(ctx, roomID) outboundPeek := types.OutboundPeek{} @@ -151,7 +151,7 @@ func (s *outboundPeeksStatements) SelectOutboundPeeks( } func (s *outboundPeeksStatements) DeleteOutboundPeek( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, roomID, peekID string, ) (err error) { _, err = sqlutil.TxStmt(txn, s.deleteOutboundPeekStmt).ExecContext(ctx, roomID, serverName, peekID) return diff --git a/federationapi/storage/sqlite3/queue_edus_table.go b/federationapi/storage/sqlite3/queue_edus_table.go index 0dc914328..f500a6317 100644 --- a/federationapi/storage/sqlite3/queue_edus_table.go +++ b/federationapi/storage/sqlite3/queue_edus_table.go @@ -20,11 +20,10 @@ import ( "fmt" "strings" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/dendrite/federationapi/storage/sqlite3/deltas" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib/spec" ) const queueEDUsSchema = ` @@ -121,9 +120,9 @@ func (s *queueEDUsStatements) InsertQueueEDU( ctx context.Context, txn *sql.Tx, eduType string, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, nid int64, - expiresAt gomatrixserverlib.Timestamp, + expiresAt spec.Timestamp, ) error { stmt := sqlutil.TxStmt(txn, s.insertQueueEDUStmt) _, err := stmt.ExecContext( @@ -138,7 +137,7 @@ func (s *queueEDUsStatements) InsertQueueEDU( func (s *queueEDUsStatements) DeleteQueueEDUs( ctx context.Context, txn *sql.Tx, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, jsonNIDs []int64, ) error { deleteSQL := strings.Replace(deleteQueueEDUsSQL, "($2)", sqlutil.QueryVariadicOffset(len(jsonNIDs), 1), 1) @@ -160,7 +159,7 @@ func (s *queueEDUsStatements) DeleteQueueEDUs( func (s *queueEDUsStatements) SelectQueueEDUs( ctx context.Context, txn *sql.Tx, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, limit int, ) ([]int64, error) { stmt := sqlutil.TxStmt(txn, s.selectQueueEDUStmt) @@ -194,16 +193,16 @@ func (s *queueEDUsStatements) SelectQueueEDUReferenceJSONCount( func (s *queueEDUsStatements) SelectQueueEDUServerNames( ctx context.Context, txn *sql.Tx, -) ([]gomatrixserverlib.ServerName, error) { +) ([]spec.ServerName, error) { stmt := sqlutil.TxStmt(txn, s.selectQueueEDUServerNamesStmt) rows, err := stmt.QueryContext(ctx) if err != nil { return nil, err } defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed") - var result []gomatrixserverlib.ServerName + var result []spec.ServerName for rows.Next() { - var serverName gomatrixserverlib.ServerName + var serverName spec.ServerName if err = rows.Scan(&serverName); err != nil { return nil, err } @@ -215,7 +214,7 @@ func (s *queueEDUsStatements) SelectQueueEDUServerNames( func (s *queueEDUsStatements) SelectExpiredEDUs( ctx context.Context, txn *sql.Tx, - expiredBefore gomatrixserverlib.Timestamp, + expiredBefore spec.Timestamp, ) ([]int64, error) { stmt := sqlutil.TxStmt(txn, s.selectExpiredEDUsStmt) rows, err := stmt.QueryContext(ctx, expiredBefore) @@ -236,7 +235,7 @@ func (s *queueEDUsStatements) SelectExpiredEDUs( func (s *queueEDUsStatements) DeleteExpiredEDUs( ctx context.Context, txn *sql.Tx, - expiredBefore gomatrixserverlib.Timestamp, + expiredBefore spec.Timestamp, ) error { stmt := sqlutil.TxStmt(txn, s.deleteExpiredEDUsStmt) _, err := stmt.ExecContext(ctx, expiredBefore) diff --git a/federationapi/storage/sqlite3/queue_json_table.go b/federationapi/storage/sqlite3/queue_json_table.go index fe5896c7f..0e2806d56 100644 --- a/federationapi/storage/sqlite3/queue_json_table.go +++ b/federationapi/storage/sqlite3/queue_json_table.go @@ -66,10 +66,10 @@ func NewSQLiteQueueJSONTable(db *sql.DB) (s *queueJSONStatements, err error) { if err != nil { return } - if s.insertJSONStmt, err = db.Prepare(insertJSONSQL); err != nil { - return - } - return + + return s, sqlutil.StatementList{ + {&s.insertJSONStmt, insertJSONSQL}, + }.Prepare(db) } func (s *queueJSONStatements) InsertQueueJSON( diff --git a/federationapi/storage/sqlite3/queue_pdus_table.go b/federationapi/storage/sqlite3/queue_pdus_table.go index aee8b03d6..92075ff90 100644 --- a/federationapi/storage/sqlite3/queue_pdus_table.go +++ b/federationapi/storage/sqlite3/queue_pdus_table.go @@ -24,6 +24,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const queuePDUsSchema = ` @@ -87,32 +88,20 @@ func NewSQLiteQueuePDUsTable(db *sql.DB) (s *queuePDUsStatements, err error) { if err != nil { return } - if s.insertQueuePDUStmt, err = db.Prepare(insertQueuePDUSQL); err != nil { - return - } - //if s.deleteQueuePDUsStmt, err = db.Prepare(deleteQueuePDUsSQL); err != nil { - // return - //} - if s.selectQueueNextTransactionIDStmt, err = db.Prepare(selectQueueNextTransactionIDSQL); err != nil { - return - } - if s.selectQueuePDUsStmt, err = db.Prepare(selectQueuePDUsSQL); err != nil { - return - } - if s.selectQueueReferenceJSONCountStmt, err = db.Prepare(selectQueuePDUsReferenceJSONCountSQL); err != nil { - return - } - if s.selectQueueServerNamesStmt, err = db.Prepare(selectQueuePDUsServerNamesSQL); err != nil { - return - } - return + return s, sqlutil.StatementList{ + {&s.insertQueuePDUStmt, insertQueuePDUSQL}, + {&s.selectQueueNextTransactionIDStmt, selectQueueNextTransactionIDSQL}, + {&s.selectQueuePDUsStmt, selectQueuePDUsSQL}, + {&s.selectQueueReferenceJSONCountStmt, selectQueuePDUsReferenceJSONCountSQL}, + {&s.selectQueueServerNamesStmt, selectQueuePDUsServerNamesSQL}, + }.Prepare(db) } func (s *queuePDUsStatements) InsertQueuePDU( ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, nid int64, ) error { stmt := sqlutil.TxStmt(txn, s.insertQueuePDUStmt) @@ -127,7 +116,7 @@ func (s *queuePDUsStatements) InsertQueuePDU( func (s *queuePDUsStatements) DeleteQueuePDUs( ctx context.Context, txn *sql.Tx, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, jsonNIDs []int64, ) error { deleteSQL := strings.Replace(deleteQueuePDUsSQL, "($2)", sqlutil.QueryVariadicOffset(len(jsonNIDs), 1), 1) @@ -148,7 +137,7 @@ func (s *queuePDUsStatements) DeleteQueuePDUs( } func (s *queuePDUsStatements) SelectQueuePDUNextTransactionID( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, ) (gomatrixserverlib.TransactionID, error) { var transactionID gomatrixserverlib.TransactionID stmt := sqlutil.TxStmt(txn, s.selectQueueNextTransactionIDStmt) @@ -173,7 +162,7 @@ func (s *queuePDUsStatements) SelectQueuePDUReferenceJSONCount( func (s *queuePDUsStatements) SelectQueuePDUs( ctx context.Context, txn *sql.Tx, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, limit int, ) ([]int64, error) { stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsStmt) @@ -196,16 +185,16 @@ func (s *queuePDUsStatements) SelectQueuePDUs( func (s *queuePDUsStatements) SelectQueuePDUServerNames( ctx context.Context, txn *sql.Tx, -) ([]gomatrixserverlib.ServerName, error) { +) ([]spec.ServerName, error) { stmt := sqlutil.TxStmt(txn, s.selectQueueServerNamesStmt) rows, err := stmt.QueryContext(ctx) if err != nil { return nil, err } defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed") - var result []gomatrixserverlib.ServerName + var result []spec.ServerName for rows.Next() { - var serverName gomatrixserverlib.ServerName + var serverName spec.ServerName if err = rows.Scan(&serverName); err != nil { return nil, err } diff --git a/federationapi/storage/sqlite3/relay_servers_table.go b/federationapi/storage/sqlite3/relay_servers_table.go index 27c3cca2c..36cabeb4d 100644 --- a/federationapi/storage/sqlite3/relay_servers_table.go +++ b/federationapi/storage/sqlite3/relay_servers_table.go @@ -21,7 +21,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const relayServersSchema = ` @@ -77,8 +77,8 @@ func NewSQLiteRelayServersTable(db *sql.DB) (s *relayServersStatements, err erro func (s *relayServersStatements) InsertRelayServers( ctx context.Context, txn *sql.Tx, - serverName gomatrixserverlib.ServerName, - relayServers []gomatrixserverlib.ServerName, + serverName spec.ServerName, + relayServers []spec.ServerName, ) error { for _, relayServer := range relayServers { stmt := sqlutil.TxStmt(txn, s.insertRelayServersStmt) @@ -92,8 +92,8 @@ func (s *relayServersStatements) InsertRelayServers( func (s *relayServersStatements) SelectRelayServers( ctx context.Context, txn *sql.Tx, - serverName gomatrixserverlib.ServerName, -) ([]gomatrixserverlib.ServerName, error) { + serverName spec.ServerName, +) ([]spec.ServerName, error) { stmt := sqlutil.TxStmt(txn, s.selectRelayServersStmt) rows, err := stmt.QueryContext(ctx, serverName) if err != nil { @@ -101,13 +101,13 @@ func (s *relayServersStatements) SelectRelayServers( } defer internal.CloseAndLogIfError(ctx, rows, "SelectRelayServers: rows.close() failed") - var result []gomatrixserverlib.ServerName + var result []spec.ServerName for rows.Next() { var relayServer string if err = rows.Scan(&relayServer); err != nil { return nil, err } - result = append(result, gomatrixserverlib.ServerName(relayServer)) + result = append(result, spec.ServerName(relayServer)) } return result, nil } @@ -115,8 +115,8 @@ func (s *relayServersStatements) SelectRelayServers( func (s *relayServersStatements) DeleteRelayServers( ctx context.Context, txn *sql.Tx, - serverName gomatrixserverlib.ServerName, - relayServers []gomatrixserverlib.ServerName, + serverName spec.ServerName, + relayServers []spec.ServerName, ) error { deleteSQL := strings.Replace(deleteRelayServersSQL, "($2)", sqlutil.QueryVariadicOffset(len(relayServers), 1), 1) deleteStmt, err := s.db.Prepare(deleteSQL) @@ -138,7 +138,7 @@ func (s *relayServersStatements) DeleteRelayServers( func (s *relayServersStatements) DeleteAllRelayServers( ctx context.Context, txn *sql.Tx, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, ) error { stmt := sqlutil.TxStmt(txn, s.deleteAllRelayServersStmt) if _, err := stmt.ExecContext(ctx, serverName); err != nil { diff --git a/federationapi/storage/sqlite3/server_key_table.go b/federationapi/storage/sqlite3/server_key_table.go index 9b89649f6..f28b89940 100644 --- a/federationapi/storage/sqlite3/server_key_table.go +++ b/federationapi/storage/sqlite3/server_key_table.go @@ -22,6 +22,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const serverSigningKeysSchema = ` @@ -74,18 +75,15 @@ func NewSQLiteServerSigningKeysTable(db *sql.DB) (s *serverSigningKeyStatements, if err != nil { return } - if s.bulkSelectServerKeysStmt, err = db.Prepare(bulkSelectServerSigningKeysSQL); err != nil { - return - } - if s.upsertServerKeysStmt, err = db.Prepare(upsertServerSigningKeysSQL); err != nil { - return - } - return s, nil + return s, sqlutil.StatementList{ + {&s.bulkSelectServerKeysStmt, bulkSelectServerSigningKeysSQL}, + {&s.upsertServerKeysStmt, upsertServerSigningKeysSQL}, + }.Prepare(db) } func (s *serverSigningKeyStatements) BulkSelectServerKeys( ctx context.Context, txn *sql.Tx, - requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, + requests map[gomatrixserverlib.PublicKeyLookupRequest]spec.Timestamp, ) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) { nameAndKeyIDs := make([]string, 0, len(requests)) for request := range requests { @@ -110,7 +108,7 @@ func (s *serverSigningKeyStatements) BulkSelectServerKeys( return fmt.Errorf("bulkSelectServerKeys: %v", err) } r := gomatrixserverlib.PublicKeyLookupRequest{ - ServerName: gomatrixserverlib.ServerName(serverName), + ServerName: spec.ServerName(serverName), KeyID: gomatrixserverlib.KeyID(keyID), } vk := gomatrixserverlib.VerifyKey{} @@ -120,8 +118,8 @@ func (s *serverSigningKeyStatements) BulkSelectServerKeys( } results[r] = gomatrixserverlib.PublicKeyLookupResult{ VerifyKey: vk, - ValidUntilTS: gomatrixserverlib.Timestamp(validUntilTS), - ExpiredTS: gomatrixserverlib.Timestamp(expiredTS), + ValidUntilTS: spec.Timestamp(validUntilTS), + ExpiredTS: spec.Timestamp(expiredTS), } } return nil diff --git a/federationapi/storage/sqlite3/storage.go b/federationapi/storage/sqlite3/storage.go index 1e7e41a2c..00c8afa05 100644 --- a/federationapi/storage/sqlite3/storage.go +++ b/federationapi/storage/sqlite3/storage.go @@ -15,15 +15,15 @@ package sqlite3 import ( + "context" "database/sql" "github.com/matrix-org/dendrite/federationapi/storage/shared" "github.com/matrix-org/dendrite/federationapi/storage/sqlite3/deltas" "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) // Database stores information needed by the federation sender @@ -34,10 +34,10 @@ type Database struct { } // NewDatabase opens a new database -func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(gomatrixserverlib.ServerName) bool) (*Database, error) { +func NewDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(spec.ServerName) bool) (*Database, error) { var d Database var err error - if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()); err != nil { + if d.db, d.writer, err = conMan.Connection(dbProperties); err != nil { return nil, err } blacklist, err := NewSQLiteBlacklistTable(d.db) @@ -93,7 +93,7 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, Version: "federationsender: drop federationsender_rooms", Up: deltas.UpRemoveRoomsTable, }) - err = m.Up(base.Context()) + err = m.Up(ctx) if err != nil { return nil, err } diff --git a/federationapi/storage/storage.go b/federationapi/storage/storage.go index 142e281ea..322a6c75b 100644 --- a/federationapi/storage/storage.go +++ b/federationapi/storage/storage.go @@ -18,23 +18,24 @@ package storage import ( + "context" "fmt" "github.com/matrix-org/dendrite/federationapi/storage/postgres" "github.com/matrix-org/dendrite/federationapi/storage/sqlite3" "github.com/matrix-org/dendrite/internal/caching" - "github.com/matrix-org/dendrite/setup/base" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) // NewDatabase opens a new database -func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(gomatrixserverlib.ServerName) bool) (Database, error) { +func NewDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(spec.ServerName) bool) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(base, dbProperties, cache, isLocalServerName) + return sqlite3.NewDatabase(ctx, conMan, dbProperties, cache, isLocalServerName) case dbProperties.ConnectionString.IsPostgres(): - return postgres.NewDatabase(base, dbProperties, cache, isLocalServerName) + return postgres.NewDatabase(ctx, conMan, dbProperties, cache, isLocalServerName) default: return nil, fmt.Errorf("unexpected database type") } diff --git a/federationapi/storage/storage_test.go b/federationapi/storage/storage_test.go index 1d2a13e81..db71f2c13 100644 --- a/federationapi/storage/storage_test.go +++ b/federationapi/storage/storage_test.go @@ -7,36 +7,39 @@ import ( "time" "github.com/matrix-org/dendrite/federationapi/storage" + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" - "github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/stretchr/testify/assert" ) func mustCreateFederationDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { - b, baseClose := testrig.CreateBaseDendrite(t, dbType) + caches := caching.NewRistrettoCache(8*1024*1024, time.Hour, false) connStr, dbClose := test.PrepareDBConnectionString(t, dbType) - db, err := storage.NewDatabase(b, &config.DatabaseOptions{ + ctx := context.Background() + cm := sqlutil.NewConnectionManager(nil, config.DatabaseOptions{}) + db, err := storage.NewDatabase(ctx, cm, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), - }, b.Caches, func(server gomatrixserverlib.ServerName) bool { return server == "localhost" }) + }, caches, func(server spec.ServerName) bool { return server == "localhost" }) if err != nil { t.Fatalf("NewDatabase returned %s", err) } return db, func() { dbClose() - baseClose() } } func TestExpireEDUs(t *testing.T) { var expireEDUTypes = map[string]time.Duration{ - gomatrixserverlib.MReceipt: 0, + spec.MReceipt: 0, } ctx := context.Background() - destinations := map[gomatrixserverlib.ServerName]struct{}{"localhost": {}} + destinations := map[spec.ServerName]struct{}{"localhost": {}} test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, close := mustCreateFederationDatabase(t, dbType) defer close() @@ -45,7 +48,7 @@ func TestExpireEDUs(t *testing.T) { receipt, err := db.StoreJSON(ctx, "{}") assert.NoError(t, err) - err = db.AssociateEDUWithDestinations(ctx, destinations, receipt, gomatrixserverlib.MReceipt, expireEDUTypes) + err = db.AssociateEDUWithDestinations(ctx, destinations, receipt, spec.MReceipt, expireEDUTypes) assert.NoError(t, err) } // add data without expiry @@ -69,7 +72,7 @@ func TestExpireEDUs(t *testing.T) { receipt, err = db.StoreJSON(ctx, "{}") assert.NoError(t, err) - err = db.AssociateEDUWithDestinations(ctx, destinations, receipt, gomatrixserverlib.MDirectToDevice, expireEDUTypes) + err = db.AssociateEDUWithDestinations(ctx, destinations, receipt, spec.MDirectToDevice, expireEDUTypes) assert.NoError(t, err) err = db.DeleteExpiredEDUs(ctx) @@ -247,8 +250,8 @@ func TestInboundPeeking(t *testing.T) { } func TestServersAssumedOffline(t *testing.T) { - server1 := gomatrixserverlib.ServerName("server1") - server2 := gomatrixserverlib.ServerName("server2") + server1 := spec.ServerName("server1") + server2 := spec.ServerName("server2") test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, closeDB := mustCreateFederationDatabase(t, dbType) @@ -303,29 +306,29 @@ func TestServersAssumedOffline(t *testing.T) { } func TestRelayServersStored(t *testing.T) { - server := gomatrixserverlib.ServerName("server") - relayServer1 := gomatrixserverlib.ServerName("relayserver1") - relayServer2 := gomatrixserverlib.ServerName("relayserver2") + server := spec.ServerName("server") + relayServer1 := spec.ServerName("relayserver1") + relayServer2 := spec.ServerName("relayserver2") test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, closeDB := mustCreateFederationDatabase(t, dbType) defer closeDB() - err := db.P2PAddRelayServersForServer(context.Background(), server, []gomatrixserverlib.ServerName{relayServer1}) + err := db.P2PAddRelayServersForServer(context.Background(), server, []spec.ServerName{relayServer1}) assert.Nil(t, err) relayServers, err := db.P2PGetRelayServersForServer(context.Background(), server) assert.Nil(t, err) assert.Equal(t, relayServer1, relayServers[0]) - err = db.P2PRemoveRelayServersForServer(context.Background(), server, []gomatrixserverlib.ServerName{relayServer1}) + err = db.P2PRemoveRelayServersForServer(context.Background(), server, []spec.ServerName{relayServer1}) assert.Nil(t, err) relayServers, err = db.P2PGetRelayServersForServer(context.Background(), server) assert.Nil(t, err) assert.Zero(t, len(relayServers)) - err = db.P2PAddRelayServersForServer(context.Background(), server, []gomatrixserverlib.ServerName{relayServer1, relayServer2}) + err = db.P2PAddRelayServersForServer(context.Background(), server, []spec.ServerName{relayServer1, relayServer2}) assert.Nil(t, err) relayServers, err = db.P2PGetRelayServersForServer(context.Background(), server) diff --git a/federationapi/storage/storage_wasm.go b/federationapi/storage/storage_wasm.go index 84d5a3a4c..e19a45642 100644 --- a/federationapi/storage/storage_wasm.go +++ b/federationapi/storage/storage_wasm.go @@ -15,20 +15,21 @@ package storage import ( + "context" "fmt" "github.com/matrix-org/dendrite/federationapi/storage/sqlite3" "github.com/matrix-org/dendrite/internal/caching" - "github.com/matrix-org/dendrite/setup/base" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" ) // NewDatabase opens a new database -func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.FederationCache, serverName gomatrixserverlib.ServerName) (Database, error) { +func NewDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(spec.ServerName) bool) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(base, dbProperties, cache, serverName) + return sqlite3.NewDatabase(ctx, conMan, dbProperties, cache, isLocalServerName) case dbProperties.ConnectionString.IsPostgres(): return nil, fmt.Errorf("can't use Postgres implementation") default: diff --git a/federationapi/storage/tables/interface.go b/federationapi/storage/tables/interface.go index 762504e45..f8de42da7 100644 --- a/federationapi/storage/tables/interface.go +++ b/federationapi/storage/tables/interface.go @@ -20,26 +20,27 @@ import ( "github.com/matrix-org/dendrite/federationapi/types" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) type NotaryID int64 type FederationQueuePDUs interface { - InsertQueuePDU(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64) error - DeleteQueuePDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error + InsertQueuePDU(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName spec.ServerName, nid int64) error + DeleteQueuePDUs(ctx context.Context, txn *sql.Tx, serverName spec.ServerName, jsonNIDs []int64) error SelectQueuePDUReferenceJSONCount(ctx context.Context, txn *sql.Tx, jsonNID int64) (int64, error) - SelectQueuePDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error) - SelectQueuePDUServerNames(ctx context.Context, txn *sql.Tx) ([]gomatrixserverlib.ServerName, error) + SelectQueuePDUs(ctx context.Context, txn *sql.Tx, serverName spec.ServerName, limit int) ([]int64, error) + SelectQueuePDUServerNames(ctx context.Context, txn *sql.Tx) ([]spec.ServerName, error) } type FederationQueueEDUs interface { - InsertQueueEDU(ctx context.Context, txn *sql.Tx, eduType string, serverName gomatrixserverlib.ServerName, nid int64, expiresAt gomatrixserverlib.Timestamp) error - DeleteQueueEDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error - SelectQueueEDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error) + InsertQueueEDU(ctx context.Context, txn *sql.Tx, eduType string, serverName spec.ServerName, nid int64, expiresAt spec.Timestamp) error + DeleteQueueEDUs(ctx context.Context, txn *sql.Tx, serverName spec.ServerName, jsonNIDs []int64) error + SelectQueueEDUs(ctx context.Context, txn *sql.Tx, serverName spec.ServerName, limit int) ([]int64, error) SelectQueueEDUReferenceJSONCount(ctx context.Context, txn *sql.Tx, jsonNID int64) (int64, error) - SelectQueueEDUServerNames(ctx context.Context, txn *sql.Tx) ([]gomatrixserverlib.ServerName, error) - SelectExpiredEDUs(ctx context.Context, txn *sql.Tx, expiredBefore gomatrixserverlib.Timestamp) ([]int64, error) - DeleteExpiredEDUs(ctx context.Context, txn *sql.Tx, expiredBefore gomatrixserverlib.Timestamp) error + SelectQueueEDUServerNames(ctx context.Context, txn *sql.Tx) ([]spec.ServerName, error) + SelectExpiredEDUs(ctx context.Context, txn *sql.Tx, expiredBefore spec.Timestamp) ([]int64, error) + DeleteExpiredEDUs(ctx context.Context, txn *sql.Tx, expiredBefore spec.Timestamp) error Prepare() error } @@ -50,10 +51,10 @@ type FederationQueueJSON interface { } type FederationQueueTransactions interface { - InsertQueueTransaction(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64) error - DeleteQueueTransactions(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error - SelectQueueTransactions(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error) - SelectQueueTransactionCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error) + InsertQueueTransaction(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName spec.ServerName, nid int64) error + DeleteQueueTransactions(ctx context.Context, txn *sql.Tx, serverName spec.ServerName, jsonNIDs []int64) error + SelectQueueTransactions(ctx context.Context, txn *sql.Tx, serverName spec.ServerName, limit int) ([]int64, error) + SelectQueueTransactionCount(ctx context.Context, txn *sql.Tx, serverName spec.ServerName) (int64, error) } type FederationTransactionJSON interface { @@ -63,51 +64,51 @@ type FederationTransactionJSON interface { } type FederationJoinedHosts interface { - InsertJoinedHosts(ctx context.Context, txn *sql.Tx, roomID, eventID string, serverName gomatrixserverlib.ServerName) error + InsertJoinedHosts(ctx context.Context, txn *sql.Tx, roomID, eventID string, serverName spec.ServerName) error DeleteJoinedHosts(ctx context.Context, txn *sql.Tx, eventIDs []string) error DeleteJoinedHostsForRoom(ctx context.Context, txn *sql.Tx, roomID string) error SelectJoinedHostsWithTx(ctx context.Context, txn *sql.Tx, roomID string) ([]types.JoinedHost, error) SelectJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error) - SelectAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) - SelectJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludingBlacklisted bool) ([]gomatrixserverlib.ServerName, error) + SelectAllJoinedHosts(ctx context.Context) ([]spec.ServerName, error) + SelectJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludingBlacklisted bool) ([]spec.ServerName, error) } type FederationBlacklist interface { - InsertBlacklist(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) error - SelectBlacklist(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (bool, error) - DeleteBlacklist(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) error + InsertBlacklist(ctx context.Context, txn *sql.Tx, serverName spec.ServerName) error + SelectBlacklist(ctx context.Context, txn *sql.Tx, serverName spec.ServerName) (bool, error) + DeleteBlacklist(ctx context.Context, txn *sql.Tx, serverName spec.ServerName) error DeleteAllBlacklist(ctx context.Context, txn *sql.Tx) error } type FederationAssumedOffline interface { - InsertAssumedOffline(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) error - SelectAssumedOffline(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (bool, error) - DeleteAssumedOffline(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) error + InsertAssumedOffline(ctx context.Context, txn *sql.Tx, serverName spec.ServerName) error + SelectAssumedOffline(ctx context.Context, txn *sql.Tx, serverName spec.ServerName) (bool, error) + DeleteAssumedOffline(ctx context.Context, txn *sql.Tx, serverName spec.ServerName) error DeleteAllAssumedOffline(ctx context.Context, txn *sql.Tx) error } type FederationRelayServers interface { - InsertRelayServers(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error - SelectRelayServers(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) ([]gomatrixserverlib.ServerName, error) - DeleteRelayServers(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error - DeleteAllRelayServers(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) error + InsertRelayServers(ctx context.Context, txn *sql.Tx, serverName spec.ServerName, relayServers []spec.ServerName) error + SelectRelayServers(ctx context.Context, txn *sql.Tx, serverName spec.ServerName) ([]spec.ServerName, error) + DeleteRelayServers(ctx context.Context, txn *sql.Tx, serverName spec.ServerName, relayServers []spec.ServerName) error + DeleteAllRelayServers(ctx context.Context, txn *sql.Tx, serverName spec.ServerName) error } type FederationOutboundPeeks interface { - InsertOutboundPeek(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) (err error) - RenewOutboundPeek(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) (err error) - SelectOutboundPeek(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string) (outboundPeek *types.OutboundPeek, err error) + InsertOutboundPeek(ctx context.Context, txn *sql.Tx, serverName spec.ServerName, roomID, peekID string, renewalInterval int64) (err error) + RenewOutboundPeek(ctx context.Context, txn *sql.Tx, serverName spec.ServerName, roomID, peekID string, renewalInterval int64) (err error) + SelectOutboundPeek(ctx context.Context, txn *sql.Tx, serverName spec.ServerName, roomID, peekID string) (outboundPeek *types.OutboundPeek, err error) SelectOutboundPeeks(ctx context.Context, txn *sql.Tx, roomID string) (outboundPeeks []types.OutboundPeek, err error) - DeleteOutboundPeek(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string) (err error) + DeleteOutboundPeek(ctx context.Context, txn *sql.Tx, serverName spec.ServerName, roomID, peekID string) (err error) DeleteOutboundPeeks(ctx context.Context, txn *sql.Tx, roomID string) (err error) } type FederationInboundPeeks interface { - InsertInboundPeek(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) (err error) - RenewInboundPeek(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) (err error) - SelectInboundPeek(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string) (inboundPeek *types.InboundPeek, err error) + InsertInboundPeek(ctx context.Context, txn *sql.Tx, serverName spec.ServerName, roomID, peekID string, renewalInterval int64) (err error) + RenewInboundPeek(ctx context.Context, txn *sql.Tx, serverName spec.ServerName, roomID, peekID string, renewalInterval int64) (err error) + SelectInboundPeek(ctx context.Context, txn *sql.Tx, serverName spec.ServerName, roomID, peekID string) (inboundPeek *types.InboundPeek, err error) SelectInboundPeeks(ctx context.Context, txn *sql.Tx, roomID string) (inboundPeeks []types.InboundPeek, err error) - DeleteInboundPeek(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string) (err error) + DeleteInboundPeek(ctx context.Context, txn *sql.Tx, serverName spec.ServerName, roomID, peekID string) (err error) DeleteInboundPeeks(ctx context.Context, txn *sql.Tx, roomID string) (err error) } @@ -118,22 +119,22 @@ type FederationNotaryServerKeysJSON interface { // "Servers MUST use the lesser of this field and 7 days into the future when determining if a key is valid. // This is to avoid a situation where an attacker publishes a key which is valid for a significant amount of time // without a way for the homeserver owner to revoke it."" - InsertJSONResponse(ctx context.Context, txn *sql.Tx, keyQueryResponseJSON gomatrixserverlib.ServerKeys, serverName gomatrixserverlib.ServerName, validUntil gomatrixserverlib.Timestamp) (NotaryID, error) + InsertJSONResponse(ctx context.Context, txn *sql.Tx, keyQueryResponseJSON gomatrixserverlib.ServerKeys, serverName spec.ServerName, validUntil spec.Timestamp) (NotaryID, error) } // FederationNotaryServerKeysMetadata persists the metadata for FederationNotaryServerKeysJSON type FederationNotaryServerKeysMetadata interface { // UpsertKey updates or inserts a (server_name, key_id) tuple, pointing it via NotaryID at the the response which has the longest valid_until_ts // `newNotaryID` and `newValidUntil` should be the notary ID / valid_until which has this (server_name, key_id) tuple already, e.g one you just inserted. - UpsertKey(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, keyID gomatrixserverlib.KeyID, newNotaryID NotaryID, newValidUntil gomatrixserverlib.Timestamp) (NotaryID, error) + UpsertKey(ctx context.Context, txn *sql.Tx, serverName spec.ServerName, keyID gomatrixserverlib.KeyID, newNotaryID NotaryID, newValidUntil spec.Timestamp) (NotaryID, error) // SelectKeys returns the signed JSON objects which contain the given key IDs. This will be at most the length of `keyIDs` and at least 1 (assuming // the keys exist in the first place). If `keyIDs` is empty, the signed JSON object with the longest valid_until_ts will be returned. - SelectKeys(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, keyIDs []gomatrixserverlib.KeyID) ([]gomatrixserverlib.ServerKeys, error) + SelectKeys(ctx context.Context, txn *sql.Tx, serverName spec.ServerName, keyIDs []gomatrixserverlib.KeyID) ([]gomatrixserverlib.ServerKeys, error) // DeleteOldJSONResponses removes all responses which are not referenced in FederationNotaryServerKeysMetadata DeleteOldJSONResponses(ctx context.Context, txn *sql.Tx) error } type FederationServerSigningKeys interface { - BulkSelectServerKeys(ctx context.Context, txn *sql.Tx, requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) + BulkSelectServerKeys(ctx context.Context, txn *sql.Tx, requests map[gomatrixserverlib.PublicKeyLookupRequest]spec.Timestamp) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) UpsertServerKeys(ctx context.Context, txn *sql.Tx, request gomatrixserverlib.PublicKeyLookupRequest, key gomatrixserverlib.PublicKeyLookupResult) error } diff --git a/federationapi/storage/tables/relay_servers_table_test.go b/federationapi/storage/tables/relay_servers_table_test.go index b41211551..6a14e3f16 100644 --- a/federationapi/storage/tables/relay_servers_table_test.go +++ b/federationapi/storage/tables/relay_servers_table_test.go @@ -11,7 +11,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/stretchr/testify/assert" ) @@ -57,7 +57,7 @@ func mustCreateRelayServersTable( return database, close } -func Equal(a, b []gomatrixserverlib.ServerName) bool { +func Equal(a, b []spec.ServerName) bool { if len(a) != len(b) { return false } @@ -74,7 +74,7 @@ func TestShouldInsertRelayServers(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, close := mustCreateRelayServersTable(t, dbType) defer close() - expectedRelayServers := []gomatrixserverlib.ServerName{server2, server3} + expectedRelayServers := []spec.ServerName{server2, server3} err := db.Table.InsertRelayServers(ctx, nil, server1, expectedRelayServers) if err != nil { @@ -97,8 +97,8 @@ func TestShouldInsertRelayServersWithDuplicates(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, close := mustCreateRelayServersTable(t, dbType) defer close() - insertRelayServers := []gomatrixserverlib.ServerName{server2, server2, server2, server3, server2} - expectedRelayServers := []gomatrixserverlib.ServerName{server2, server3} + insertRelayServers := []spec.ServerName{server2, server2, server2, server3, server2} + expectedRelayServers := []spec.ServerName{server2, server3} err := db.Table.InsertRelayServers(ctx, nil, server1, insertRelayServers) if err != nil { @@ -134,8 +134,8 @@ func TestShouldGetRelayServersUnknownDestination(t *testing.T) { t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error()) } - if !Equal(relayServers, []gomatrixserverlib.ServerName{}) { - t.Fatalf("Expected: %v \nActual: %v", []gomatrixserverlib.ServerName{}, relayServers) + if !Equal(relayServers, []spec.ServerName{}) { + t.Fatalf("Expected: %v \nActual: %v", []spec.ServerName{}, relayServers) } }) } @@ -145,8 +145,8 @@ func TestShouldDeleteCorrectRelayServers(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, close := mustCreateRelayServersTable(t, dbType) defer close() - relayServers1 := []gomatrixserverlib.ServerName{server2, server3} - relayServers2 := []gomatrixserverlib.ServerName{server1, server3, server4} + relayServers1 := []spec.ServerName{server2, server3} + relayServers2 := []spec.ServerName{server1, server3, server4} err := db.Table.InsertRelayServers(ctx, nil, server1, relayServers1) if err != nil { @@ -157,16 +157,16 @@ func TestShouldDeleteCorrectRelayServers(t *testing.T) { t.Fatalf("Failed inserting transaction: %s", err.Error()) } - err = db.Table.DeleteRelayServers(ctx, nil, server1, []gomatrixserverlib.ServerName{server2}) + err = db.Table.DeleteRelayServers(ctx, nil, server1, []spec.ServerName{server2}) if err != nil { t.Fatalf("Failed deleting relay servers for %s: %s", server1, err.Error()) } - err = db.Table.DeleteRelayServers(ctx, nil, server2, []gomatrixserverlib.ServerName{server1, server4}) + err = db.Table.DeleteRelayServers(ctx, nil, server2, []spec.ServerName{server1, server4}) if err != nil { t.Fatalf("Failed deleting relay servers for %s: %s", server2, err.Error()) } - expectedRelayServers := []gomatrixserverlib.ServerName{server3} + expectedRelayServers := []spec.ServerName{server3} relayServers, err := db.Table.SelectRelayServers(ctx, nil, server1) if err != nil { t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error()) @@ -189,7 +189,7 @@ func TestShouldDeleteAllRelayServers(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, close := mustCreateRelayServersTable(t, dbType) defer close() - expectedRelayServers := []gomatrixserverlib.ServerName{server2, server3} + expectedRelayServers := []spec.ServerName{server2, server3} err := db.Table.InsertRelayServers(ctx, nil, server1, expectedRelayServers) if err != nil { @@ -205,7 +205,7 @@ func TestShouldDeleteAllRelayServers(t *testing.T) { t.Fatalf("Failed deleting relay servers for %s: %s", server1, err.Error()) } - expectedRelayServers1 := []gomatrixserverlib.ServerName{} + expectedRelayServers1 := []spec.ServerName{} relayServers, err := db.Table.SelectRelayServers(ctx, nil, server1) if err != nil { t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error()) diff --git a/federationapi/types/types.go b/federationapi/types/types.go index 5821000cc..20f92e804 100644 --- a/federationapi/types/types.go +++ b/federationapi/types/types.go @@ -14,9 +14,7 @@ package types -import ( - "github.com/matrix-org/gomatrixserverlib" -) +import "github.com/matrix-org/gomatrixserverlib/spec" const MSigningKeyUpdate = "m.signing_key_update" // TODO: move to gomatrixserverlib @@ -25,10 +23,10 @@ type JoinedHost struct { // The MemberEventID of a m.room.member join event. MemberEventID string // The domain part of the state key of the m.room.member join event - ServerName gomatrixserverlib.ServerName + ServerName spec.ServerName } -type ServerNames []gomatrixserverlib.ServerName +type ServerNames []spec.ServerName func (s ServerNames) Len() int { return len(s) } func (s ServerNames) Swap(i, j int) { s[i], s[j] = s[j], s[i] } @@ -38,7 +36,7 @@ func (s ServerNames) Less(i, j int) bool { return s[i] < s[j] } type OutboundPeek struct { PeekID string RoomID string - ServerName gomatrixserverlib.ServerName + ServerName spec.ServerName CreationTimestamp int64 RenewedTimestamp int64 RenewalInterval int64 @@ -48,7 +46,7 @@ type OutboundPeek struct { type InboundPeek struct { PeekID string RoomID string - ServerName gomatrixserverlib.ServerName + ServerName spec.ServerName CreationTimestamp int64 RenewedTimestamp int64 RenewalInterval int64 @@ -64,7 +62,7 @@ type FederationReceiptData struct { } type ReceiptTS struct { - TS gomatrixserverlib.Timestamp `json:"ts"` + TS spec.Timestamp `json:"ts"` } type Presence struct { diff --git a/go.mod b/go.mod index 6857290f3..515f9a367 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,7 @@ require ( github.com/blevesearch/bleve/v2 v2.3.6 github.com/codeclysm/extract v2.2.0+incompatible github.com/dgraph-io/ristretto v0.1.1 - github.com/docker/docker v20.10.19+incompatible + github.com/docker/docker v20.10.24+incompatible github.com/docker/go-connections v0.4.0 github.com/getsentry/sentry-go v0.14.0 github.com/gologme/log v1.3.0 @@ -22,15 +22,14 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20230131183213-122f1e0e3fa1 + github.com/matrix-org/gomatrixserverlib v0.0.0-20230419135026-7f3c8ee774a0 github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 - github.com/mattn/go-sqlite3 v1.14.15 - github.com/nats-io/nats-server/v2 v2.9.8 - github.com/nats-io/nats.go v1.20.0 + github.com/mattn/go-sqlite3 v1.14.16 + github.com/nats-io/nats-server/v2 v2.9.15 + github.com/nats-io/nats.go v1.24.0 github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9 github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 - github.com/ngrok/sqlmw v0.0.0-20220520173518-97c9c04efc79 github.com/opentracing/opentracing-go v1.2.0 github.com/patrickmn/go-cache v2.1.0+incompatible github.com/pkg/errors v0.9.1 @@ -43,10 +42,10 @@ require ( github.com/uber/jaeger-lib v2.4.1+incompatible github.com/yggdrasil-network/yggdrasil-go v0.4.6 go.uber.org/atomic v1.10.0 - golang.org/x/crypto v0.5.0 + golang.org/x/crypto v0.8.0 golang.org/x/image v0.5.0 golang.org/x/mobile v0.0.0-20221020085226-b36e6246172e - golang.org/x/term v0.5.0 + golang.org/x/term v0.7.0 gopkg.in/h2non/bimg.v1 v1.1.9 gopkg.in/yaml.v2 v2.4.0 gotest.tools/v3 v3.4.0 @@ -92,7 +91,7 @@ require ( github.com/json-iterator/go v1.1.12 // indirect github.com/juju/errors v1.0.0 // indirect github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect - github.com/klauspost/compress v1.15.11 // indirect + github.com/klauspost/compress v1.16.0 // indirect github.com/kr/pretty v0.3.1 // indirect github.com/mattn/go-isatty v0.0.16 // indirect github.com/matttproud/golang_protobuf_extensions v1.0.2-0.20181231171920-c182affec369 // indirect @@ -122,12 +121,12 @@ require ( github.com/tidwall/pretty v1.2.1 // indirect go.etcd.io/bbolt v1.3.6 // indirect golang.org/x/exp v0.0.0-20221205204356-47842c84f3db // indirect - golang.org/x/mod v0.6.0 // indirect - golang.org/x/net v0.7.0 // indirect - golang.org/x/sys v0.5.0 // indirect - golang.org/x/text v0.7.0 // indirect - golang.org/x/time v0.1.0 // indirect - golang.org/x/tools v0.2.0 // indirect + golang.org/x/mod v0.8.0 // indirect + golang.org/x/net v0.9.0 // indirect + golang.org/x/sys v0.7.0 // indirect + golang.org/x/text v0.9.0 // indirect + golang.org/x/time v0.3.0 // indirect + golang.org/x/tools v0.6.0 // indirect google.golang.org/protobuf v1.28.1 // indirect gopkg.in/macaroon.v2 v2.1.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index 2272e5404..70a5a0568 100644 --- a/go.sum +++ b/go.sum @@ -134,8 +134,8 @@ github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2 h1:tdlZCpZ/P9DhczC github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= github.com/docker/distribution v2.8.1+incompatible h1:Q50tZOPR6T/hjNsyc9g8/syEs6bk8XXApsHjKukMl68= github.com/docker/distribution v2.8.1+incompatible/go.mod h1:J2gT2udsDAN96Uj4KfcMRqY0/ypR+oyYUYmja8H+y+w= -github.com/docker/docker v20.10.19+incompatible h1:lzEmjivyNHFHMNAFLXORMBXyGIhw/UP4DvJwvyKYq64= -github.com/docker/docker v20.10.19+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= +github.com/docker/docker v20.10.24+incompatible h1:Ugvxm7a8+Gz6vqQYQQ2W7GYq5EUPaAiuPgIfVyI3dYE= +github.com/docker/docker v20.10.24+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= github.com/docker/go-connections v0.4.0 h1:El9xVISelRB7BuFusrZozjnkIM5YnzCViNKohAFqRJQ= github.com/docker/go-connections v0.4.0/go.mod h1:Gbd7IOopHjR8Iph03tsViu4nIes5XhDvyHbTtUxmeec= github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4= @@ -299,8 +299,8 @@ github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:C github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/klauspost/compress v1.10.3/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= -github.com/klauspost/compress v1.15.11 h1:Lcadnb3RKGin4FYM/orgq0qde+nc15E5Cbqg4B9Sx9c= -github.com/klauspost/compress v1.15.11/go.mod h1:QPwzmACJjUTFsnSHH934V6woptycfrDDJnH7hvFVbGM= +github.com/klauspost/compress v1.16.0 h1:iULayQNOReoYUe+1qtKOqw9CwJv3aNQu8ivo7lw1HU4= +github.com/klauspost/compress v1.16.0/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= @@ -321,8 +321,16 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230131183213-122f1e0e3fa1 h1:JSw0nmjMrgBmoM2aQsa78LTpI5BnuD9+vOiEQ4Qo0qw= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230131183213-122f1e0e3fa1/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230414140439-3cf4cd94d75f h1:sULN+zkwjt9bBy3dy5W98B6J/Pd4xOiF6yjWOBRRKmw= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230414140439-3cf4cd94d75f/go.mod h1:7HTbSZe+CIdmeqVyFMekwD5dFU8khWQyngKATvd12FU= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230418093913-f0ab3b996ed5 h1:WVjB9i6+0WvX/pc4jDQPYVWIf/SyYf0Aa9hhiQBdLq8= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230418093913-f0ab3b996ed5/go.mod h1:7HTbSZe+CIdmeqVyFMekwD5dFU8khWQyngKATvd12FU= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230418110303-b59eb925da4a h1:84t3ixdKTJOzgmr8+KesjPY9G9g1LX0fz5zD6E4ybHA= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230418110303-b59eb925da4a/go.mod h1:7HTbSZe+CIdmeqVyFMekwD5dFU8khWQyngKATvd12FU= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230418132954-2113ec20fc8e h1:nLZjU+z/W2n48/9mlCjx+QWo3Q6z777RrLi6Id2CzmM= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230418132954-2113ec20fc8e/go.mod h1:7HTbSZe+CIdmeqVyFMekwD5dFU8khWQyngKATvd12FU= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230419135026-7f3c8ee774a0 h1:u+1vK+Sj2gL8IWgaaSF4WM38iL/Cx7jf6Q/aa1g0Icc= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230419135026-7f3c8ee774a0/go.mod h1:7HTbSZe+CIdmeqVyFMekwD5dFU8khWQyngKATvd12FU= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ= github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y= @@ -330,8 +338,8 @@ github.com/matrix-org/util v0.0.0-20221111132719-399730281e66/go.mod h1:iBI1foel github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/mattn/go-isatty v0.0.16 h1:bq3VjFmv/sOjHtdEhmkEV4x1AJtvUvOJ2PFAZ5+peKQ= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= -github.com/mattn/go-sqlite3 v1.14.15 h1:vfoHhTN1af61xCRSWzFIWzx2YskyMTwHLrExkBOjvxI= -github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= +github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y= +github.com/mattn/go-sqlite3 v1.14.16/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/matttproud/golang_protobuf_extensions v1.0.2-0.20181231171920-c182affec369 h1:I0XW9+e1XWDxdcEniV4rQAIOPUGDq67JSCiRCgGCZLI= github.com/matttproud/golang_protobuf_extensions v1.0.2-0.20181231171920-c182affec369/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4= @@ -356,10 +364,10 @@ github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRW github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/nats-io/jwt/v2 v2.3.0 h1:z2mA1a7tIf5ShggOFlR1oBPgd6hGqcDYsISxZByUzdI= github.com/nats-io/jwt/v2 v2.3.0/go.mod h1:0tqz9Hlu6bCBFLWAASKhE5vUA4c24L9KPUUgvwumE/k= -github.com/nats-io/nats-server/v2 v2.9.8 h1:jgxZsv+A3Reb3MgwxaINcNq/za8xZInKhDg9Q0cGN1o= -github.com/nats-io/nats-server/v2 v2.9.8/go.mod h1:AB6hAnGZDlYfqb7CTAm66ZKMZy9DpfierY1/PbpvI2g= -github.com/nats-io/nats.go v1.20.0 h1:T8JJnQfVSdh1CzGiwAOv5hEobYCBho/0EupGznYw0oM= -github.com/nats-io/nats.go v1.20.0/go.mod h1:tLqubohF7t4z3du1QDPYJIQQyhb4wl6DhjxEajSI7UA= +github.com/nats-io/nats-server/v2 v2.9.15 h1:MuwEJheIwpvFgqvbs20W8Ish2azcygjf4Z0liVu2I4c= +github.com/nats-io/nats-server/v2 v2.9.15/go.mod h1:QlCTy115fqpx4KSOPFIxSV7DdI6OxtZsGOL1JLdeRlE= +github.com/nats-io/nats.go v1.24.0 h1:CRiD8L5GOQu/DcfkmgBcTTIQORMwizF+rPk6T0RaHVQ= +github.com/nats-io/nats.go v1.24.0/go.mod h1:dVQF+BK3SzUZpwyzHedXsvH3EO38aVKuOPkkHlv5hXA= github.com/nats-io/nkeys v0.3.0 h1:cgM5tL53EvYRU+2YLXIK0G2mJtK12Ft9oeooSZMA2G8= github.com/nats-io/nkeys v0.3.0/go.mod h1:gvUNGjVcM2IPr5rCsRsC6Wb3Hr2CQAm08dsxtV6A5y4= github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw= @@ -368,8 +376,6 @@ github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9 h1:lrVQzBtkeQE github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9/go.mod h1:NPHGhPc0/wudcaCqL/H5AOddkRf8GPRhzOujuUKGQu8= github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ= github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8= -github.com/ngrok/sqlmw v0.0.0-20220520173518-97c9c04efc79 h1:Dmx8g2747UTVPzSkmohk84S3g/uWqd6+f4SSLPhLcfA= -github.com/ngrok/sqlmw v0.0.0-20220520173518-97c9c04efc79/go.mod h1:E26fwEtRNigBfFfHDWsklmo0T7Ixbg0XXgck+Hq4O9k= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= github.com/onsi/ginkgo/v2 v2.3.0 h1:kUMoxMoQG3ogk/QWyKh3zibV7BKZ+xBpWil1cTylVqc= github.com/onsi/ginkgo/v2 v2.3.0/go.mod h1:Eew0uilEqZmIEZr8JrvYlvOM7Rr6xzTmMV8AyFNU9d0= @@ -504,8 +510,8 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.0.0-20210314154223-e6e6c4f2bb5b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.5.0 h1:U/0M97KRkSFvyD/3FSmdP5W5swImpNgle/EHFhOsQPE= -golang.org/x/crypto v0.5.0/go.mod h1:NK/OQwhpMQP3MwtdjgLlYHnH9ebylxKWv3e0fK+mkQU= +golang.org/x/crypto v0.8.0 h1:pd9TJtTueMTVQXzk8E2XESSMQDj/U7OUu0PqJqPXQjQ= +golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -548,8 +554,8 @@ golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= -golang.org/x/mod v0.6.0 h1:b9gGHsz9/HhJ3HF5DHQytPpuwocVTChQJK3AvoLRD5I= -golang.org/x/mod v0.6.0/go.mod h1:4mET923SAdbXp2ki8ey+zGs1SLqsuM2Y0uvdZR/fUNI= +golang.org/x/mod v0.8.0 h1:LUYupSeNrTNCGzR/hVBk2NHZO4hXcVaW1k4Qx7rjPx8= +golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -585,8 +591,8 @@ golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qx golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g= -golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= +golang.org/x/net v0.9.0 h1:aWJ/m6xSmxWBx+V0XRHTlrYrPG56jKsLdTFmsSsCzOM= +golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -606,6 +612,7 @@ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -658,12 +665,12 @@ golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU= -golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.7.0 h1:3jlCCIQZPdOYu1h8BkNvLz8Kgwtae2cagcG/VamtZRU= +golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.5.0 h1:n2a8QNdAb0sZNpU9R1ALUXBbY+w51fCQDN+7EdxNBsY= -golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= +golang.org/x/term v0.7.0 h1:BEvjmm5fURWqcfbSKTdpkDXYBrUS1c0m8agp14W48vQ= +golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -671,13 +678,14 @@ golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= +golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.1.0 h1:xYY+Bajn2a7VBmTM5GikTmnK8ZuX8YgnQCqZpbBNtmA= -golang.org/x/time v0.1.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= +golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -726,8 +734,8 @@ golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4f golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0= golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= -golang.org/x/tools v0.2.0 h1:G6AHpWxTMGY1KyEYoAQ5WTtIekUUvDNjan3ugu60JvE= -golang.org/x/tools v0.2.0/go.mod h1:y4OqIKeOV/fWJetJ8bXPU1sEVniLMIyDAZWeHdV+NTA= +golang.org/x/tools v0.6.0 h1:BOw41kyTf3PuCW1pVQf8+Cyg8pMlkYB1oo9iJ6D/lKM= +golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/helm/dendrite/.helm-docs/monitoring.gotmpl b/helm/dendrite/.helm-docs/monitoring.gotmpl index 3618a1c1a..eaa336ebd 100644 --- a/helm/dendrite/.helm-docs/monitoring.gotmpl +++ b/helm/dendrite/.helm-docs/monitoring.gotmpl @@ -1,10 +1,15 @@ {{ define "chart.monitoringSection" }} ## Monitoring -[![Grafana Dashboard](https://grafana.com/api/dashboards/13916/images/9894/image)](https://grafana.com/grafana/dashboards/13916-dendrite/) +![Grafana Dashboard](grafana_dashboards/dendrite-rev2.png) * Works well with [Prometheus Operator](https://prometheus-operator.dev/) ([Helmchart](https://artifacthub.io/packages/helm/prometheus-community/kube-prometheus-stack)) and their setup of [Grafana](https://grafana.com/grafana/), by enabling the following values: ```yaml +dendrite_config: + global: + metrics: + enabled: true + prometheus: servicemonitor: enabled: true @@ -19,4 +24,4 @@ grafana: enabled: true # will deploy default dashboards ``` PS: The label `release=kube-prometheus-stack` is setup with the helmchart of the Prometheus Operator. For Grafana Dashboards it may be necessary to enable scanning in the correct namespaces (or ALL), enabled by `sidecar.dashboards.searchNamespace` in [Helmchart of grafana](https://artifacthub.io/packages/helm/grafana/grafana) (which is part of PrometheusOperator, so `grafana.sidecar.dashboards.searchNamespace`) -{{ end }} \ No newline at end of file +{{ end }} diff --git a/helm/dendrite/Chart.yaml b/helm/dendrite/Chart.yaml index dc2764939..6a428e00f 100644 --- a/helm/dendrite/Chart.yaml +++ b/helm/dendrite/Chart.yaml @@ -1,7 +1,7 @@ apiVersion: v2 name: dendrite -version: "0.11.2" -appVersion: "0.11.1" +version: "0.12.2" +appVersion: "0.12.0" description: Dendrite Matrix Homeserver type: application keywords: diff --git a/helm/dendrite/README.md b/helm/dendrite/README.md index 51587b766..ca5705c03 100644 --- a/helm/dendrite/README.md +++ b/helm/dendrite/README.md @@ -1,6 +1,7 @@ + # dendrite -![Version: 0.11.2](https://img.shields.io/badge/Version-0.11.2-informational?style=flat-square) ![Type: application](https://img.shields.io/badge/Type-application-informational?style=flat-square) ![AppVersion: 0.11.1](https://img.shields.io/badge/AppVersion-0.11.1-informational?style=flat-square) +![Version: 0.12.2](https://img.shields.io/badge/Version-0.12.2-informational?style=flat-square) ![Type: application](https://img.shields.io/badge/Type-application-informational?style=flat-square) ![AppVersion: 0.12.0](https://img.shields.io/badge/AppVersion-0.12.0-informational?style=flat-square) Dendrite Matrix Homeserver Status: **NOT PRODUCTION READY** @@ -54,6 +55,11 @@ Create a folder `appservices` and place your configurations in there. The confi | persistence.media.capacity | string | `"1Gi"` | PVC Storage Request for the media volume | | persistence.search.existingClaim | string | `""` | Use an existing volume claim for the fulltext search index | | persistence.search.capacity | string | `"1Gi"` | PVC Storage Request for the search volume | +| extraVolumes | list | `[]` | Add additional volumes to the Dendrite Pod | +| extraVolumeMounts | list | `[]` | Configure additional mount points volumes in the Dendrite Pod | +| strategy.type | string | `"RollingUpdate"` | Strategy to use for rolling updates (e.g. Recreate, RollingUpdate) If you are using ReadWriteOnce volumes, you should probably use Recreate | +| strategy.rollingUpdate.maxUnavailable | string | `"25%"` | Maximum number of pods that can be unavailable during the update process | +| strategy.rollingUpdate.maxSurge | string | `"25%"` | Maximum number of pods that can be scheduled above the desired number of pods | | dendrite_config.version | int | `2` | | | dendrite_config.global.server_name | string | `""` | **REQUIRED** Servername for this Dendrite deployment. | | dendrite_config.global.private_key | string | `"/etc/dendrite/secrets/signing.key"` | The private key to use. (**NOTE**: This is overriden in Helm) | @@ -157,10 +163,15 @@ Create a folder `appservices` and place your configurations in there. The confi ## Monitoring -[![Grafana Dashboard](https://grafana.com/api/dashboards/13916/images/9894/image)](https://grafana.com/grafana/dashboards/13916-dendrite/) +![Grafana Dashboard](grafana_dashboards/dendrite-rev2.png) * Works well with [Prometheus Operator](https://prometheus-operator.dev/) ([Helmchart](https://artifacthub.io/packages/helm/prometheus-community/kube-prometheus-stack)) and their setup of [Grafana](https://grafana.com/grafana/), by enabling the following values: ```yaml +dendrite_config: + global: + metrics: + enabled: true + prometheus: servicemonitor: enabled: true @@ -176,5 +187,3 @@ grafana: ``` PS: The label `release=kube-prometheus-stack` is setup with the helmchart of the Prometheus Operator. For Grafana Dashboards it may be necessary to enable scanning in the correct namespaces (or ALL), enabled by `sidecar.dashboards.searchNamespace` in [Helmchart of grafana](https://artifacthub.io/packages/helm/grafana/grafana) (which is part of PrometheusOperator, so `grafana.sidecar.dashboards.searchNamespace`) ----------------------------------------------- -Autogenerated from chart metadata using [helm-docs vv1.11.0](https://github.com/norwoodj/helm-docs/releases/vv1.11.0) \ No newline at end of file diff --git a/helm/dendrite/grafana_dashboards/dendrite-rev1.json b/helm/dendrite/grafana_dashboards/dendrite-rev1.json deleted file mode 100644 index 206e8af87..000000000 --- a/helm/dendrite/grafana_dashboards/dendrite-rev1.json +++ /dev/null @@ -1,1119 +0,0 @@ -{ - "__inputs": [ - { - "name": "DS_INFLUXDB_DOMOTICA", - "label": "", - "description": "", - "type": "datasource", - "pluginId": "influxdb", - "pluginName": "InfluxDB" - }, - { - "name": "DS_PROMETHEUS", - "label": "Prometheus", - "description": "", - "type": "datasource", - "pluginId": "prometheus", - "pluginName": "Prometheus" - } - ], - "__requires": [ - { - "type": "grafana", - "id": "grafana", - "name": "Grafana", - "version": "7.4.2" - }, - { - "type": "panel", - "id": "graph", - "name": "Graph", - "version": "" - }, - { - "type": "panel", - "id": "heatmap", - "name": "Heatmap", - "version": "" - }, - { - "type": "datasource", - "id": "influxdb", - "name": "InfluxDB", - "version": "1.0.0" - }, - { - "type": "datasource", - "id": "prometheus", - "name": "Prometheus", - "version": "1.0.0" - }, - { - "type": "panel", - "id": "stat", - "name": "Stat", - "version": "" - } - ], - "annotations": { - "list": [ - { - "builtIn": 1, - "datasource": "-- Grafana --", - "enable": true, - "hide": true, - "iconColor": "rgba(0, 211, 255, 1)", - "name": "Annotations & Alerts", - "type": "dashboard" - } - ] - }, - "description": "Dendrite dashboard from https://github.com/matrix-org/dendrite/", - "editable": true, - "gnetId": 13916, - "graphTooltip": 0, - "id": null, - "iteration": 1613683251329, - "links": [], - "panels": [ - { - "collapsed": false, - "datasource": "${DS_INFLUXDB_DOMOTICA}", - "gridPos": { - "h": 1, - "w": 24, - "x": 0, - "y": 0 - }, - "id": 4, - "panels": [], - "title": "Overview", - "type": "row" - }, - { - "aliasColors": {}, - "bars": false, - "dashLength": 10, - "dashes": false, - "datasource": "$datasource", - "fieldConfig": { - "defaults": { - "custom": {} - }, - "overrides": [] - }, - "fill": 1, - "fillGradient": 0, - "gridPos": { - "h": 5, - "w": 10, - "x": 0, - "y": 1 - }, - "hiddenSeries": false, - "id": 2, - "legend": { - "avg": false, - "current": false, - "max": false, - "min": false, - "show": true, - "total": false, - "values": false - }, - "lines": true, - "linewidth": 1, - "links": [], - "nullPointMode": "null", - "options": { - "alertThreshold": true - }, - "percentage": false, - "pluginVersion": "7.4.2", - "pointradius": 5, - "points": false, - "renderer": "flot", - "seriesOverrides": [], - "spaceLength": 10, - "stack": false, - "steppedLine": false, - "targets": [ - { - "expr": "rate(process_cpu_seconds_total{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}[$bucket_size])", - "format": "time_series", - "intervalFactor": 1, - "legendFormat": "{{job}}-{{index}} ", - "refId": "A" - } - ], - "thresholds": [], - "timeFrom": null, - "timeRegions": [], - "timeShift": null, - "title": "CPU usage", - "tooltip": { - "shared": true, - "sort": 0, - "value_type": "individual" - }, - "type": "graph", - "xaxis": { - "buckets": null, - "mode": "time", - "name": null, - "show": true, - "values": [] - }, - "yaxes": [ - { - "decimals": null, - "format": "percentunit", - "label": null, - "logBase": 1, - "max": "1", - "min": "0", - "show": true - }, - { - "format": "short", - "label": null, - "logBase": 1, - "max": null, - "min": null, - "show": true - } - ], - "yaxis": { - "align": false, - "alignLevel": null - } - }, - { - "datasource": "${DS_PROMETHEUS}", - "description": "Total number of registered users", - "fieldConfig": { - "defaults": { - "color": { - "mode": "thresholds" - }, - "custom": {}, - "mappings": [], - "thresholds": { - "mode": "absolute", - "steps": [ - { - "color": "green", - "value": null - }, - { - "color": "red", - "value": 80 - } - ] - } - }, - "overrides": [] - }, - "gridPos": { - "h": 5, - "w": 4, - "x": 10, - "y": 1 - }, - "id": 20, - "options": { - "colorMode": "value", - "graphMode": "area", - "justifyMode": "auto", - "orientation": "auto", - "reduceOptions": { - "calcs": [ - "lastNotNull" - ], - "fields": "", - "values": false - }, - "text": {}, - "textMode": "auto" - }, - "pluginVersion": "7.4.2", - "targets": [ - { - "exemplar": false, - "expr": "dendrite_clientapi_reg_users_total", - "instant": false, - "interval": "", - "legendFormat": "Users", - "refId": "A" - } - ], - "title": "Registerd Users", - "type": "stat" - }, - { - "aliasColors": {}, - "bars": false, - "dashLength": 10, - "dashes": false, - "datasource": "${DS_PROMETHEUS}", - "description": "The number of sync requests that are active right now and are waiting to be woken by a notifier", - "fieldConfig": { - "defaults": { - "custom": {} - }, - "overrides": [] - }, - "fill": 1, - "fillGradient": 0, - "gridPos": { - "h": 5, - "w": 10, - "x": 14, - "y": 1 - }, - "hiddenSeries": false, - "id": 6, - "legend": { - "avg": false, - "current": false, - "max": false, - "min": false, - "show": true, - "total": false, - "values": false - }, - "lines": true, - "linewidth": 2, - "nullPointMode": "null", - "options": { - "alertThreshold": true - }, - "percentage": false, - "pluginVersion": "7.4.2", - "pointradius": 2, - "points": false, - "renderer": "flot", - "seriesOverrides": [], - "spaceLength": 10, - "stack": false, - "steppedLine": false, - "targets": [ - { - "expr": "sum(rate(dendrite_syncapi_active_sync_requests{instance=\"$instance\"}[$bucket_size]))without (job,index)", - "hide": false, - "interval": "", - "legendFormat": "active", - "refId": "A" - }, - { - "expr": "sum(rate(dendrite_syncapi_waiting_sync_requests{instance=\"$instance\"}[$bucket_size]))without (job,index)", - "hide": false, - "interval": "", - "legendFormat": "waiting", - "refId": "B" - } - ], - "thresholds": [], - "timeFrom": null, - "timeRegions": [], - "timeShift": null, - "title": "Sync API", - "tooltip": { - "shared": true, - "sort": 0, - "value_type": "individual" - }, - "type": "graph", - "xaxis": { - "buckets": null, - "mode": "time", - "name": null, - "show": true, - "values": [] - }, - "yaxes": [ - { - "$$hashKey": "object:232", - "format": "hertz", - "label": null, - "logBase": 1, - "max": null, - "min": "0", - "show": true - }, - { - "$$hashKey": "object:233", - "format": "short", - "label": null, - "logBase": 1, - "max": null, - "min": null, - "show": true - } - ], - "yaxis": { - "align": false, - "alignLevel": null - } - }, - { - "cards": { - "cardPadding": null, - "cardRound": null - }, - "color": { - "cardColor": "#b4ff00", - "colorScale": "sqrt", - "colorScheme": "interpolateOranges", - "exponent": 0.5, - "mode": "spectrum" - }, - "dataFormat": "tsbuckets", - "datasource": "${DS_PROMETHEUS}", - "description": "How long it takes to build and submit a new event from the client API to the roomserver", - "fieldConfig": { - "defaults": { - "custom": {} - }, - "overrides": [] - }, - "gridPos": { - "h": 5, - "w": 24, - "x": 0, - "y": 6 - }, - "heatmap": {}, - "hideZeroBuckets": false, - "highlightCards": true, - "id": 24, - "legend": { - "show": false - }, - "pluginVersion": "7.4.2", - "reverseYBuckets": false, - "targets": [ - { - "expr": "dendrite_clientapi_sendevent_duration_millis_bucket{action=\"build\",instance=\"$instance\"}", - "interval": "", - "legendFormat": "{{le}}", - "refId": "A" - } - ], - "title": "Sendevent Duration", - "tooltip": { - "show": true, - "showHistogram": false - }, - "type": "heatmap", - "xAxis": { - "show": true - }, - "xBucketNumber": null, - "xBucketSize": null, - "yAxis": { - "decimals": null, - "format": "s", - "logBase": 1, - "max": null, - "min": "0", - "show": true, - "splitFactor": null - }, - "yBucketBound": "auto", - "yBucketNumber": null, - "yBucketSize": null - }, - { - "collapsed": false, - "datasource": "${DS_INFLUXDB_DOMOTICA}", - "gridPos": { - "h": 1, - "w": 24, - "x": 0, - "y": 11 - }, - "id": 8, - "panels": [], - "title": "Federation", - "type": "row" - }, - { - "aliasColors": {}, - "bars": false, - "dashLength": 10, - "dashes": false, - "datasource": "${DS_PROMETHEUS}", - "description": "Collection of queues for sending transactions to other matrix servers", - "fieldConfig": { - "defaults": { - "custom": {} - }, - "overrides": [] - }, - "fill": 1, - "fillGradient": 0, - "gridPos": { - "h": 6, - "w": 24, - "x": 0, - "y": 12 - }, - "hiddenSeries": false, - "id": 10, - "legend": { - "avg": false, - "current": false, - "max": false, - "min": false, - "show": true, - "total": false, - "values": false - }, - "lines": true, - "linewidth": 1, - "nullPointMode": "null", - "options": { - "alertThreshold": true - }, - "percentage": false, - "pluginVersion": "7.4.2", - "pointradius": 2, - "points": false, - "renderer": "flot", - "seriesOverrides": [], - "spaceLength": 10, - "stack": false, - "steppedLine": false, - "targets": [ - { - "expr": "dendrite_federationsender_destination_queues_running", - "interval": "", - "legendFormat": "Queue Running", - "refId": "A" - }, - { - "expr": "dendrite_federationsender_destination_queues_total", - "hide": false, - "interval": "", - "legendFormat": "Queue Total", - "refId": "B" - }, - { - "expr": "dendrite_federationsender_destination_queues_backing_off", - "hide": false, - "interval": "", - "legendFormat": "Backing Off", - "refId": "C" - } - ], - "thresholds": [], - "timeFrom": null, - "timeRegions": [], - "timeShift": null, - "title": "Federation Sender Destination", - "tooltip": { - "shared": true, - "sort": 0, - "value_type": "individual" - }, - "type": "graph", - "xaxis": { - "buckets": null, - "mode": "time", - "name": null, - "show": true, - "values": [] - }, - "yaxes": [ - { - "$$hashKey": "object:443", - "format": "short", - "label": null, - "logBase": 1, - "max": null, - "min": null, - "show": true - }, - { - "$$hashKey": "object:444", - "format": "short", - "label": null, - "logBase": 1, - "max": null, - "min": null, - "show": true - } - ], - "yaxis": { - "align": false, - "alignLevel": null - } - }, - { - "collapsed": false, - "datasource": "${DS_INFLUXDB_DOMOTICA}", - "gridPos": { - "h": 1, - "w": 24, - "x": 0, - "y": 18 - }, - "id": 26, - "panels": [], - "title": "Rooms", - "type": "row" - }, - { - "cards": { - "cardPadding": null, - "cardRound": null - }, - "color": { - "cardColor": "#b4ff00", - "colorScale": "sqrt", - "colorScheme": "interpolateOranges", - "exponent": 0.5, - "mode": "spectrum" - }, - "dataFormat": "timeseries", - "datasource": "${DS_PROMETHEUS}", - "description": "How long it takes the roomserver to process an event", - "fieldConfig": { - "defaults": { - "custom": {} - }, - "overrides": [] - }, - "gridPos": { - "h": 7, - "w": 24, - "x": 0, - "y": 19 - }, - "heatmap": {}, - "hideZeroBuckets": false, - "highlightCards": true, - "id": 28, - "legend": { - "show": false - }, - "pluginVersion": "7.4.2", - "reverseYBuckets": false, - "targets": [ - { - "expr": "sum(rate(dendrite_roomserver_processroomevent_duration_millis_bucket{instance=\"$instance\"}[$bucket_size])) by (le)", - "interval": "", - "legendFormat": "{{le}}", - "refId": "A" - } - ], - "title": "Room Event Processing", - "tooltip": { - "show": true, - "showHistogram": false - }, - "type": "heatmap", - "xAxis": { - "show": true - }, - "xBucketNumber": null, - "xBucketSize": null, - "yAxis": { - "decimals": null, - "format": "s", - "logBase": 1, - "max": null, - "min": null, - "show": true, - "splitFactor": null - }, - "yBucketBound": "auto", - "yBucketNumber": null, - "yBucketSize": null - }, - { - "collapsed": false, - "datasource": "${DS_INFLUXDB_DOMOTICA}", - "gridPos": { - "h": 1, - "w": 24, - "x": 0, - "y": 26 - }, - "id": 12, - "panels": [], - "title": "Caches", - "type": "row" - }, - { - "aliasColors": {}, - "bars": false, - "dashLength": 10, - "dashes": false, - "datasource": "${DS_PROMETHEUS}", - "fieldConfig": { - "defaults": { - "custom": {} - }, - "overrides": [] - }, - "fill": 1, - "fillGradient": 0, - "gridPos": { - "h": 7, - "w": 8, - "x": 0, - "y": 27 - }, - "hiddenSeries": false, - "id": 14, - "legend": { - "avg": false, - "current": false, - "max": false, - "min": false, - "show": false, - "total": false, - "values": false - }, - "lines": true, - "linewidth": 1, - "nullPointMode": "null", - "options": { - "alertThreshold": true - }, - "percentage": false, - "pluginVersion": "7.4.2", - "pointradius": 2, - "points": false, - "renderer": "flot", - "seriesOverrides": [], - "spaceLength": 10, - "stack": false, - "steppedLine": false, - "targets": [ - { - "expr": "dendrite_caching_in_memory_lru_server_key", - "interval": "", - "legendFormat": "Server keys", - "refId": "A" - } - ], - "thresholds": [], - "timeFrom": null, - "timeRegions": [], - "timeShift": null, - "title": "Server Keys", - "tooltip": { - "shared": true, - "sort": 0, - "value_type": "individual" - }, - "type": "graph", - "xaxis": { - "buckets": null, - "mode": "time", - "name": null, - "show": true, - "values": [] - }, - "yaxes": [ - { - "$$hashKey": "object:667", - "format": "short", - "label": null, - "logBase": 1, - "max": null, - "min": null, - "show": true - }, - { - "$$hashKey": "object:668", - "format": "short", - "label": null, - "logBase": 1, - "max": null, - "min": null, - "show": true - } - ], - "yaxis": { - "align": false, - "alignLevel": null - } - }, - { - "aliasColors": {}, - "bars": false, - "dashLength": 10, - "dashes": false, - "datasource": "${DS_PROMETHEUS}", - "fieldConfig": { - "defaults": { - "custom": {} - }, - "overrides": [] - }, - "fill": 1, - "fillGradient": 0, - "gridPos": { - "h": 7, - "w": 8, - "x": 8, - "y": 27 - }, - "hiddenSeries": false, - "id": 16, - "legend": { - "avg": false, - "current": false, - "max": false, - "min": false, - "show": false, - "total": false, - "values": false - }, - "lines": true, - "linewidth": 1, - "nullPointMode": "null", - "options": { - "alertThreshold": true - }, - "percentage": false, - "pluginVersion": "7.4.2", - "pointradius": 2, - "points": false, - "renderer": "flot", - "seriesOverrides": [], - "spaceLength": 10, - "stack": false, - "steppedLine": false, - "targets": [ - { - "expr": "dendrite_caching_in_memory_lru_federation_event", - "interval": "", - "legendFormat": "Federation Event", - "refId": "A" - } - ], - "thresholds": [], - "timeFrom": null, - "timeRegions": [], - "timeShift": null, - "title": "Federation Events", - "tooltip": { - "shared": true, - "sort": 0, - "value_type": "individual" - }, - "type": "graph", - "xaxis": { - "buckets": null, - "mode": "time", - "name": null, - "show": true, - "values": [] - }, - "yaxes": [ - { - "$$hashKey": "object:784", - "format": "short", - "label": null, - "logBase": 1, - "max": null, - "min": null, - "show": true - }, - { - "$$hashKey": "object:785", - "format": "short", - "label": null, - "logBase": 1, - "max": null, - "min": null, - "show": true - } - ], - "yaxis": { - "align": false, - "alignLevel": null - } - }, - { - "aliasColors": {}, - "bars": false, - "dashLength": 10, - "dashes": false, - "datasource": "${DS_PROMETHEUS}", - "fieldConfig": { - "defaults": { - "custom": {} - }, - "overrides": [] - }, - "fill": 1, - "fillGradient": 0, - "gridPos": { - "h": 7, - "w": 8, - "x": 16, - "y": 27 - }, - "hiddenSeries": false, - "id": 18, - "legend": { - "avg": false, - "current": false, - "max": false, - "min": false, - "show": false, - "total": false, - "values": false - }, - "lines": true, - "linewidth": 1, - "nullPointMode": "null", - "options": { - "alertThreshold": true - }, - "percentage": false, - "pluginVersion": "7.4.2", - "pointradius": 2, - "points": false, - "renderer": "flot", - "seriesOverrides": [], - "spaceLength": 10, - "stack": false, - "steppedLine": false, - "targets": [ - { - "expr": "dendrite_caching_in_memory_lru_roomserver_room_ids", - "interval": "", - "legendFormat": "Room IDs", - "refId": "A" - } - ], - "thresholds": [], - "timeFrom": null, - "timeRegions": [], - "timeShift": null, - "title": "Room IDs", - "tooltip": { - "shared": true, - "sort": 0, - "value_type": "individual" - }, - "type": "graph", - "xaxis": { - "buckets": null, - "mode": "time", - "name": null, - "show": true, - "values": [] - }, - "yaxes": [ - { - "$$hashKey": "object:898", - "format": "short", - "label": null, - "logBase": 1, - "max": null, - "min": null, - "show": true - }, - { - "$$hashKey": "object:899", - "format": "short", - "label": null, - "logBase": 1, - "max": null, - "min": null, - "show": true - } - ], - "yaxis": { - "align": false, - "alignLevel": null - } - } - ], - "refresh": "10s", - "schemaVersion": 27, - "style": "dark", - "tags": [ - "matrix", - "dendrite" - ], - "templating": { - "list": [ - { - "current": { - "selected": false, - "text": "Prometheus", - "value": "Prometheus" - }, - "description": null, - "error": null, - "hide": 0, - "includeAll": false, - "label": null, - "multi": false, - "name": "datasource", - "options": [], - "query": "prometheus", - "refresh": 1, - "regex": "", - "skipUrlSync": false, - "type": "datasource" - }, - { - "auto": true, - "auto_count": 100, - "auto_min": "30s", - "current": { - "selected": false, - "text": "auto", - "value": "$__auto_interval_bucket_size" - }, - "description": null, - "error": null, - "hide": 0, - "label": "Bucket Size", - "name": "bucket_size", - "options": [ - { - "selected": true, - "text": "auto", - "value": "$__auto_interval_bucket_size" - }, - { - "selected": false, - "text": "30s", - "value": "30s" - }, - { - "selected": false, - "text": "1m", - "value": "1m" - }, - { - "selected": false, - "text": "2m", - "value": "2m" - }, - { - "selected": false, - "text": "5m", - "value": "5m" - }, - { - "selected": false, - "text": "10m", - "value": "10m" - }, - { - "selected": false, - "text": "15m", - "value": "15m" - } - ], - "query": "30s,1m,2m,5m,10m,15m", - "queryValue": "", - "refresh": 2, - "skipUrlSync": false, - "type": "interval" - }, - { - "allValue": null, - "current": {}, - "datasource": "${DS_PROMETHEUS}", - "definition": "label_values(dendrite_caching_in_memory_lru_roominfo, instance)", - "description": null, - "error": null, - "hide": 0, - "includeAll": false, - "label": null, - "multi": false, - "name": "instance", - "options": [], - "query": { - "query": "label_values(dendrite_caching_in_memory_lru_roominfo, instance)", - "refId": "StandardVariableQuery" - }, - "refresh": 2, - "regex": "", - "skipUrlSync": false, - "sort": 0, - "tagValuesQuery": "", - "tags": [], - "tagsQuery": "", - "type": "query", - "useTags": false - }, - { - "allValue": null, - "current": {}, - "datasource": "${DS_PROMETHEUS}", - "definition": "label_values(dendrite_caching_in_memory_lru_roominfo, job)", - "description": null, - "error": null, - "hide": 0, - "includeAll": true, - "label": "Job", - "multi": true, - "name": "job", - "options": [], - "query": { - "query": "label_values(dendrite_caching_in_memory_lru_roominfo, job)", - "refId": "StandardVariableQuery" - }, - "refresh": 1, - "regex": "", - "skipUrlSync": false, - "sort": 1, - "tagValuesQuery": "", - "tags": [], - "tagsQuery": "", - "type": "query", - "useTags": false - }, - { - "allValue": ".*", - "current": {}, - "datasource": "${DS_PROMETHEUS}", - "definition": "label_values(dendrite_caching_in_memory_lru_roominfo, index)", - "description": null, - "error": null, - "hide": 0, - "includeAll": true, - "label": null, - "multi": true, - "name": "index", - "options": [], - "query": { - "query": "label_values(dendrite_caching_in_memory_lru_roominfo, index)", - "refId": "StandardVariableQuery" - }, - "refresh": 2, - "regex": "", - "skipUrlSync": false, - "sort": 3, - "tagValuesQuery": "", - "tags": [], - "tagsQuery": "", - "type": "query", - "useTags": false - } - ] - }, - "time": { - "from": "now-3h", - "to": "now" - }, - "timepicker": {}, - "timezone": "", - "title": "Dendrite", - "uid": "RoRt1jEGz", - "version": 8 -} \ No newline at end of file diff --git a/helm/dendrite/grafana_dashboards/dendrite-rev2.json b/helm/dendrite/grafana_dashboards/dendrite-rev2.json new file mode 100644 index 000000000..817f950b3 --- /dev/null +++ b/helm/dendrite/grafana_dashboards/dendrite-rev2.json @@ -0,0 +1,479 @@ +{ + "annotations": { + "list": [ + { + "builtIn": 1, + "datasource": { + "type": "datasource", + "uid": "grafana" + }, + "enable": true, + "hide": true, + "iconColor": "rgba(0, 211, 255, 1)", + "name": "Annotations & Alerts", + "target": { + "limit": 100, + "matchAny": false, + "tags": [], + "type": "dashboard" + }, + "type": "dashboard" + } + ] + }, + "description": "Dendrite dashboard from https://github.com/matrix-org/dendrite/", + "editable": true, + "fiscalYearStartMonth": 0, + "gnetId": 13916, + "graphTooltip": 0, + "id": 60, + "links": [], + "liveNow": false, + "panels": [ + { + "collapsed": false, + "datasource": { + "uid": "${DS_PROMETHEUS}" + }, + "gridPos": { + "h": 1, + "w": 24, + "x": 0, + "y": 0 + }, + "id": 4, + "panels": [], + "targets": [ + { + "datasource": { + "uid": "${DS_PROMETHEUS}" + }, + "refId": "A" + } + ], + "title": "Overview", + "type": "row" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "description": "Total number of registered users", + "fieldConfig": { + "defaults": { + "color": { + "mode": "thresholds" + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 7, + "x": 0, + "y": 1 + }, + "id": 20, + "options": { + "colorMode": "value", + "graphMode": "area", + "justifyMode": "auto", + "orientation": "auto", + "reduceOptions": { + "calcs": [ + "lastNotNull" + ], + "fields": "", + "values": false + }, + "text": {}, + "textMode": "auto" + }, + "pluginVersion": "9.3.6", + "targets": [ + { + "datasource": { + "uid": "${DS_PROMETHEUS}" + }, + "editorMode": "code", + "exemplar": false, + "expr": "sum(dendrite_clientapi_reg_users_total{namespace=~\"$namespace\",service=~\"$service\"}) by (namespace,service)", + "instant": false, + "interval": "", + "legendFormat": "{{namespace}}: {{service}}", + "refId": "A" + } + ], + "title": "Registerd Users", + "type": "stat" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "description": "The number of sync requests that are active right now and are waiting to be woken by a notifier", + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 10, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "lineInterpolation": "linear", + "lineWidth": 2, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "never", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "min": 0, + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "ops" + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 17, + "x": 7, + "y": 1 + }, + "id": 6, + "options": { + "legend": { + "calcs": [ + "mean", + "lastNotNull" + ], + "displayMode": "table", + "placement": "right", + "showLegend": true + }, + "tooltip": { + "mode": "multi", + "sort": "none" + } + }, + "pluginVersion": "9.3.6", + "targets": [ + { + "datasource": { + "uid": "${DS_PROMETHEUS}" + }, + "editorMode": "code", + "expr": "sum(rate(dendrite_syncapi_active_sync_requests{namespace=~\"$namespace\",service=~\"$service\"}[$__rate_interval]))by (namspace,service)", + "hide": false, + "interval": "", + "legendFormat": "active: {{namspace}} - {{service}}", + "range": true, + "refId": "A" + }, + { + "datasource": { + "uid": "${DS_PROMETHEUS}" + }, + "editorMode": "code", + "expr": "sum(rate(dendrite_syncapi_waiting_sync_requests{namespace=~\"$namespace\",service=~\"$service\"}[$__rate_interval]))by (namespace,service)", + "hide": false, + "interval": "", + "legendFormat": "waiting: {{namspace}} - {{service}}", + "range": true, + "refId": "B" + } + ], + "title": "Sync API", + "type": "timeseries" + }, + { + "collapsed": false, + "datasource": { + "uid": "${DS_PROMETHEUS}" + }, + "gridPos": { + "h": 1, + "w": 24, + "x": 0, + "y": 9 + }, + "id": 8, + "panels": [], + "targets": [ + { + "datasource": { + "uid": "${DS_PROMETHEUS}" + }, + "refId": "A" + } + ], + "title": "Federation", + "type": "row" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "description": "Collection of queues for sending transactions to other matrix servers", + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 10, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "never", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "short" + }, + "overrides": [] + }, + "gridPos": { + "h": 10, + "w": 24, + "x": 0, + "y": 10 + }, + "id": 10, + "options": { + "legend": { + "calcs": [ + "mean", + "lastNotNull" + ], + "displayMode": "table", + "placement": "right", + "showLegend": true + }, + "tooltip": { + "mode": "multi", + "sort": "none" + } + }, + "pluginVersion": "9.3.6", + "targets": [ + { + "datasource": { + "uid": "${DS_PROMETHEUS}" + }, + "editorMode": "code", + "expr": "dendrite_federationapi_destination_queues_running{namespace=~\"$namespace\",service=~\"$service\"}", + "interval": "", + "legendFormat": "Queue Running: {{namespace}}-{{service}}", + "range": true, + "refId": "A" + }, + { + "datasource": { + "uid": "${DS_PROMETHEUS}" + }, + "editorMode": "code", + "expr": "dendrite_federationapi_destination_queues_total{namespace=~\"$namespace\",service=~\"$service\"}", + "hide": false, + "interval": "", + "legendFormat": "Queue Total: {{namespace}}-{{service}}", + "range": true, + "refId": "B" + }, + { + "datasource": { + "uid": "${DS_PROMETHEUS}" + }, + "editorMode": "code", + "expr": "dendrite_federationapi_destination_queues_backing_off{namespace=~\"$namespace\",service=~\"$service\"}", + "hide": false, + "interval": "", + "legendFormat": "Backing Off: {{namespace}}-{{service}}", + "range": true, + "refId": "C" + } + ], + "title": "Federation Sender Destination", + "type": "timeseries" + } + ], + "refresh": "10s", + "schemaVersion": 37, + "style": "dark", + "tags": [ + "matrix", + "dendrite" + ], + "templating": { + "list": [ + { + "current": { + "selected": false, + "text": "Prometheus", + "value": "Prometheus" + }, + "hide": 0, + "includeAll": false, + "label": "datasource", + "multi": false, + "name": "DS_PROMETHEUS", + "options": [], + "query": "prometheus", + "refresh": 1, + "regex": "", + "skipUrlSync": false, + "type": "datasource" + }, + { + "current": { + "selected": true, + "text": [ + "All" + ], + "value": [ + "$__all" + ] + }, + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "definition": "label_values(dendrite_syncapi_active_sync_requests, namespace)", + "hide": 0, + "includeAll": true, + "multi": true, + "name": "namespace", + "options": [], + "query": { + "query": "label_values(dendrite_syncapi_active_sync_requests, namespace)", + "refId": "StandardVariableQuery" + }, + "refresh": 1, + "regex": "", + "skipUrlSync": false, + "sort": 0, + "type": "query" + }, + { + "current": { + "selected": false, + "text": "All", + "value": "$__all" + }, + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "definition": "label_values(dendrite_syncapi_active_sync_requests{namespace=~\"$namespace\"}, service)", + "hide": 0, + "includeAll": true, + "multi": true, + "name": "service", + "options": [], + "query": { + "query": "label_values(dendrite_syncapi_active_sync_requests{namespace=~\"$namespace\"}, service)", + "refId": "StandardVariableQuery" + }, + "refresh": 1, + "regex": "", + "skipUrlSync": false, + "sort": 0, + "type": "query" + } + ] + }, + "time": { + "from": "now-3h", + "to": "now" + }, + "timepicker": {}, + "timezone": "", + "title": "Dendrite", + "uid": "RoRt1jEGz", + "version": 1, + "weekStart": "" +} diff --git a/helm/dendrite/grafana_dashboards/dendrite-rev2.png b/helm/dendrite/grafana_dashboards/dendrite-rev2.png new file mode 100644 index 000000000..179f4b5a1 Binary files /dev/null and b/helm/dendrite/grafana_dashboards/dendrite-rev2.png differ diff --git a/helm/dendrite/templates/_helpers.tpl b/helm/dendrite/templates/_helpers.tpl index 026706588..36bcefd8f 100644 --- a/helm/dendrite/templates/_helpers.tpl +++ b/helm/dendrite/templates/_helpers.tpl @@ -1,15 +1,9 @@ {{- define "validate.config" }} -{{- if not .Values.signing_key.create -}} -{{- fail "You must create a signing key for configuration.signing_key. (see https://github.com/matrix-org/dendrite/blob/master/docs/INSTALL.md#server-key-generation)" -}} +{{- if and (not .Values.signing_key.create) (eq .Values.signing_key.existingSecret "") -}} +{{- fail "You must create a signing key for configuration.signing_key OR specify an existing secret name in .Values.signing_key.existingSecret to mount it. (see https://github.com/matrix-org/dendrite/blob/master/docs/INSTALL.md#server-key-generation)" -}} {{- end -}} -{{- if not (or .Values.dendrite_config.global.database.host .Values.postgresql.enabled) -}} -{{- fail "Database server must be set." -}} -{{- end -}} -{{- if not (or .Values.dendrite_config.global.database.user .Values.postgresql.enabled) -}} -{{- fail "Database user must be set." -}} -{{- end -}} -{{- if not (or .Values.dendrite_config.global.database.password .Values.postgresql.enabled) -}} -{{- fail "Database password must be set." -}} +{{- if and (not .Values.postgresql.enabled) (eq .Values.dendrite_config.global.database.connection_string "") -}} +{{- fail "Database connection string must be set." -}} {{- end -}} {{- end -}} diff --git a/helm/dendrite/templates/deployment.yaml b/helm/dendrite/templates/deployment.yaml index b463c7d0b..df7dbbdc3 100644 --- a/helm/dendrite/templates/deployment.yaml +++ b/helm/dendrite/templates/deployment.yaml @@ -12,16 +12,19 @@ spec: matchLabels: {{- include "dendrite.selectorLabels" . | nindent 6 }} replicas: 1 + strategy: + type: {{ $.Values.strategy.type }} + {{- if eq $.Values.strategy.type "RollingUpdate" }} + rollingUpdate: + maxSurge: {{ $.Values.strategy.rollingUpdate.maxSurge }} + maxUnavailable: {{ $.Values.strategy.rollingUpdate.maxUnavailable }} + {{- end }} template: metadata: labels: {{- include "dendrite.selectorLabels" . | nindent 8 }} annotations: - confighash-global: secret-{{ .Values.global | toYaml | sha256sum | trunc 32 }} - confighash-clientapi: clientapi-{{ .Values.clientapi | toYaml | sha256sum | trunc 32 }} - confighash-federationapi: federationapi-{{ .Values.federationapi | toYaml | sha256sum | trunc 32 }} - confighash-mediaapi: mediaapi-{{ .Values.mediaapi | toYaml | sha256sum | trunc 32 }} - confighash-syncapi: syncapi-{{ .Values.syncapi | toYaml | sha256sum | trunc 32 }} + confighash: secret-{{ .Values.dendrite_config | toYaml | sha256sum | trunc 32 }} spec: volumes: - name: {{ include "dendrite.fullname" . }}-conf-vol @@ -44,6 +47,9 @@ spec: - name: {{ include "dendrite.fullname" . }}-search persistentVolumeClaim: claimName: {{ default (print ( include "dendrite.fullname" . ) "-search-pvc") $.Values.persistence.search.existingClaim | quote }} + {{- with .Values.extraVolumes }} + {{ . | toYaml | nindent 6 }} + {{- end }} containers: - name: {{ .Chart.Name }} {{- include "image.name" . | nindent 8 }} @@ -57,7 +63,7 @@ spec: {{- if $.Values.dendrite_config.global.profiling.enabled }} env: - name: PPROFLISTEN - value: "localhost:{{- $.Values.global.profiling.port -}}" + value: "localhost:{{- $.Values.dendrite_config.global.profiling.port -}}" {{- end }} resources: {{- toYaml $.Values.resources | nindent 10 }} @@ -77,6 +83,9 @@ spec: name: {{ include "dendrite.fullname" . }}-jetstream - mountPath: {{ .Values.dendrite_config.sync_api.search.index_path }} name: {{ include "dendrite.fullname" . }}-search + {{- with .Values.extraVolumeMounts }} + {{ . | toYaml | nindent 8 }} + {{- end }} livenessProbe: initialDelaySeconds: 10 periodSeconds: 10 diff --git a/helm/dendrite/values.yaml b/helm/dendrite/values.yaml index c219d27f8..41ec1c390 100644 --- a/helm/dendrite/values.yaml +++ b/helm/dendrite/values.yaml @@ -43,6 +43,30 @@ persistence: # -- PVC Storage Request for the search volume capacity: "1Gi" +# -- Add additional volumes to the Dendrite Pod +extraVolumes: [] +# ex. +# - name: extra-config +# secret: +# secretName: extra-config + + +# -- Configure additional mount points volumes in the Dendrite Pod +extraVolumeMounts: [] +# ex. +# - mountPath: /etc/dendrite/extra-config +# name: extra-config + +strategy: + # -- Strategy to use for rolling updates (e.g. Recreate, RollingUpdate) + # If you are using ReadWriteOnce volumes, you should probably use Recreate + type: RollingUpdate + rollingUpdate: + # -- Maximum number of pods that can be unavailable during the update process + maxUnavailable: 25% + # -- Maximum number of pods that can be scheduled above the desired number of pods + maxSurge: 25% + dendrite_config: version: 2 global: diff --git a/internal/caching/cache_roomevents.go b/internal/caching/cache_roomevents.go index 9d5d3b912..14b6c3af8 100644 --- a/internal/caching/cache_roomevents.go +++ b/internal/caching/cache_roomevents.go @@ -10,6 +10,7 @@ import ( type RoomServerEventsCache interface { GetRoomServerEvent(eventNID types.EventNID) (*gomatrixserverlib.Event, bool) StoreRoomServerEvent(eventNID types.EventNID, event *gomatrixserverlib.Event) + InvalidateRoomServerEvent(eventNID types.EventNID) } func (c Caches) GetRoomServerEvent(eventNID types.EventNID) (*gomatrixserverlib.Event, bool) { @@ -19,3 +20,7 @@ func (c Caches) GetRoomServerEvent(eventNID types.EventNID) (*gomatrixserverlib. func (c Caches) StoreRoomServerEvent(eventNID types.EventNID, event *gomatrixserverlib.Event) { c.RoomServerEvents.Set(int64(eventNID), event) } + +func (c Caches) InvalidateRoomServerEvent(eventNID types.EventNID) { + c.RoomServerEvents.Unset(int64(eventNID)) +} diff --git a/internal/caching/cache_serverkeys.go b/internal/caching/cache_serverkeys.go index cffa101d5..37e331ab0 100644 --- a/internal/caching/cache_serverkeys.go +++ b/internal/caching/cache_serverkeys.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) // ServerKeyCache contains the subset of functions needed for @@ -14,7 +15,7 @@ type ServerKeyCache interface { // The timestamp should be the timestamp of the event that is being // verified. We will not return keys from the cache that are not valid // at this timestamp. - GetServerKey(request gomatrixserverlib.PublicKeyLookupRequest, timestamp gomatrixserverlib.Timestamp) (response gomatrixserverlib.PublicKeyLookupResult, ok bool) + GetServerKey(request gomatrixserverlib.PublicKeyLookupRequest, timestamp spec.Timestamp) (response gomatrixserverlib.PublicKeyLookupResult, ok bool) // request -> result is emulating gomatrixserverlib.StoreKeys: // https://github.com/matrix-org/gomatrixserverlib/blob/f69539c86ea55d1e2cc76fd8e944e2d82d30397c/keyring.go#L112 @@ -23,7 +24,7 @@ type ServerKeyCache interface { func (c Caches) GetServerKey( request gomatrixserverlib.PublicKeyLookupRequest, - timestamp gomatrixserverlib.Timestamp, + timestamp spec.Timestamp, ) (gomatrixserverlib.PublicKeyLookupResult, bool) { key := fmt.Sprintf("%s/%s", request.ServerName, request.KeyID) val, found := c.ServerKeys.Get(key) diff --git a/internal/caching/cache_space_rooms.go b/internal/caching/cache_space_rooms.go index 697f99269..100ab9023 100644 --- a/internal/caching/cache_space_rooms.go +++ b/internal/caching/cache_space_rooms.go @@ -1,18 +1,16 @@ package caching -import ( - "github.com/matrix-org/gomatrixserverlib" -) +import "github.com/matrix-org/gomatrixserverlib/fclient" type SpaceSummaryRoomsCache interface { - GetSpaceSummary(roomID string) (r gomatrixserverlib.MSC2946SpacesResponse, ok bool) - StoreSpaceSummary(roomID string, r gomatrixserverlib.MSC2946SpacesResponse) + GetSpaceSummary(roomID string) (r fclient.MSC2946SpacesResponse, ok bool) + StoreSpaceSummary(roomID string, r fclient.MSC2946SpacesResponse) } -func (c Caches) GetSpaceSummary(roomID string) (r gomatrixserverlib.MSC2946SpacesResponse, ok bool) { +func (c Caches) GetSpaceSummary(roomID string) (r fclient.MSC2946SpacesResponse, ok bool) { return c.SpaceSummaryRooms.Get(roomID) } -func (c Caches) StoreSpaceSummary(roomID string, r gomatrixserverlib.MSC2946SpacesResponse) { +func (c Caches) StoreSpaceSummary(roomID string, r fclient.MSC2946SpacesResponse) { c.SpaceSummaryRooms.Set(roomID, r) } diff --git a/internal/caching/caches.go b/internal/caching/caches.go index 479920466..a678632eb 100644 --- a/internal/caching/caches.go +++ b/internal/caching/caches.go @@ -17,6 +17,7 @@ package caching import ( "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" ) // Caches contains a set of references to caches. They may be @@ -34,7 +35,7 @@ type Caches struct { RoomServerEventTypes Cache[types.EventTypeNID, string] // eventType NID -> eventType FederationPDUs Cache[int64, *gomatrixserverlib.HeaderedEvent] // queue NID -> PDU FederationEDUs Cache[int64, *gomatrixserverlib.EDU] // queue NID -> EDU - SpaceSummaryRooms Cache[string, gomatrixserverlib.MSC2946SpacesResponse] // room ID -> space response + SpaceSummaryRooms Cache[string, fclient.MSC2946SpacesResponse] // room ID -> space response LazyLoading Cache[lazyLoadingCacheKey, string] // composite key -> event ID } diff --git a/internal/caching/impl_ristretto.go b/internal/caching/impl_ristretto.go index fca93afd1..4656b6b7e 100644 --- a/internal/caching/impl_ristretto.go +++ b/internal/caching/impl_ristretto.go @@ -22,11 +22,13 @@ import ( "github.com/dgraph-io/ristretto" "github.com/dgraph-io/ristretto/z" - "github.com/matrix-org/dendrite/roomserver/types" - "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" + + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/dendrite/setup/config" ) const ( @@ -45,6 +47,11 @@ const ( eventStateKeyNIDCache ) +const ( + DisableMetrics = false + EnableMetrics = true +) + func NewRistrettoCache(maxCost config.DataUnit, maxAge time.Duration, enablePrometheus bool) *Caches { cache, err := ristretto.NewCache(&ristretto.Config{ NumCounters: int64((maxCost / 1024) * 10), // 10 counters per 1KB data, affects bloom filter size @@ -98,9 +105,10 @@ func NewRistrettoCache(maxCost config.DataUnit, maxAge time.Duration, enableProm }, RoomServerEvents: &RistrettoCostedCachePartition[int64, *gomatrixserverlib.Event]{ // event NID -> event &RistrettoCachePartition[int64, *gomatrixserverlib.Event]{ - cache: cache, - Prefix: roomEventsCache, - MaxAge: maxAge, + cache: cache, + Prefix: roomEventsCache, + MaxAge: maxAge, + Mutable: true, }, }, RoomServerStateKeys: &RistrettoCachePartition[types.EventStateKeyNID, string]{ // event NID -> event state key @@ -139,7 +147,7 @@ func NewRistrettoCache(maxCost config.DataUnit, maxAge time.Duration, enableProm MaxAge: lesserOf(time.Hour/2, maxAge), }, }, - SpaceSummaryRooms: &RistrettoCachePartition[string, gomatrixserverlib.MSC2946SpacesResponse]{ // room ID -> space response + SpaceSummaryRooms: &RistrettoCachePartition[string, fclient.MSC2946SpacesResponse]{ // room ID -> space response cache: cache, Prefix: spaceSummaryRoomsCache, Mutable: true, diff --git a/internal/eventutil/events.go b/internal/eventutil/events.go index c572d8830..c7dee3464 100644 --- a/internal/eventutil/events.go +++ b/internal/eventutil/events.go @@ -22,6 +22,8 @@ import ( "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib" ) @@ -39,7 +41,7 @@ var ErrRoomNoExists = errors.New("room does not exist") func QueryAndBuildEvent( ctx context.Context, builder *gomatrixserverlib.EventBuilder, cfg *config.Global, - identity *gomatrixserverlib.SigningIdentity, evTime time.Time, + identity *fclient.SigningIdentity, evTime time.Time, rsAPI api.QueryLatestEventsAndStateAPI, queryRes *api.QueryLatestEventsAndStateResponse, ) (*gomatrixserverlib.HeaderedEvent, error) { if queryRes == nil { @@ -59,7 +61,7 @@ func QueryAndBuildEvent( func BuildEvent( ctx context.Context, builder *gomatrixserverlib.EventBuilder, cfg *config.Global, - identity *gomatrixserverlib.SigningIdentity, evTime time.Time, + identity *fclient.SigningIdentity, evTime time.Time, eventsNeeded *gomatrixserverlib.StateNeeded, queryRes *api.QueryLatestEventsAndStateResponse, ) (*gomatrixserverlib.HeaderedEvent, error) { if err := addPrevEventsToEvent(builder, eventsNeeded, queryRes); err != nil { @@ -173,7 +175,7 @@ func truncateAuthAndPrevEvents(auth, prev []gomatrixserverlib.EventReference) ( // downstream components to the roomserver when an OutputTypeRedactedEvent occurs. func RedactEvent(redactionEvent, redactedEvent *gomatrixserverlib.Event) error { // sanity check - if redactionEvent.Type() != gomatrixserverlib.MRoomRedaction { + if redactionEvent.Type() != spec.MRoomRedaction { return fmt.Errorf("RedactEvent: redactionEvent isn't a redaction event, is '%s'", redactionEvent.Type()) } redactedEvent.Redact() diff --git a/internal/eventutil/types.go b/internal/eventutil/types.go index 18175d6a0..e43c0412e 100644 --- a/internal/eventutil/types.go +++ b/internal/eventutil/types.go @@ -15,16 +15,11 @@ package eventutil import ( - "errors" "strconv" "github.com/matrix-org/dendrite/syncapi/types" ) -// ErrProfileNoExists is returned when trying to lookup a user's profile that -// doesn't exist locally. -var ErrProfileNoExists = errors.New("no known profile for given user ID") - // AccountData represents account data sent from the client API server to the // sync API server type AccountData struct { @@ -56,20 +51,10 @@ type NotificationData struct { UnreadNotificationCount int `json:"unread_notification_count"` } -// ProfileResponse is a struct containing all known user profile data -type ProfileResponse struct { - AvatarURL string `json:"avatar_url"` - DisplayName string `json:"displayname"` -} - -// AvatarURL is a struct containing only the URL to a user's avatar -type AvatarURL struct { - AvatarURL string `json:"avatar_url"` -} - -// DisplayName is a struct containing only a user's display name -type DisplayName struct { - DisplayName string `json:"displayname"` +// UserProfile is a struct containing all known user profile data +type UserProfile struct { + AvatarURL string `json:"avatar_url,omitempty"` + DisplayName string `json:"displayname,omitempty"` } // WeakBoolean is a type that will Unmarshal to true or false even if the encoded diff --git a/internal/fulltext/bleve.go b/internal/fulltext/bleve.go index 7187861dd..d2807198a 100644 --- a/internal/fulltext/bleve.go +++ b/internal/fulltext/bleve.go @@ -18,9 +18,12 @@ package fulltext import ( + "regexp" "strings" "github.com/blevesearch/bleve/v2" + "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/gomatrixserverlib/spec" // side effect imports to allow all possible languages _ "github.com/blevesearch/bleve/v2/analysis/lang/ar" @@ -45,7 +48,6 @@ import ( _ "github.com/blevesearch/bleve/v2/analysis/lang/sv" _ "github.com/blevesearch/bleve/v2/analysis/lang/tr" "github.com/blevesearch/bleve/v2/mapping" - "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/setup/config" ) @@ -55,6 +57,14 @@ type Search struct { FulltextIndex bleve.Index } +type Indexer interface { + Index(elements ...IndexElement) error + Delete(eventID string) error + Search(term string, roomIDs, keys []string, limit, from int, orderByStreamPos bool) (*bleve.SearchResult, error) + GetHighlights(result *bleve.SearchResult) []string + Close() error +} + // IndexElement describes the layout of an element to index type IndexElement struct { EventID string @@ -69,20 +79,27 @@ func (i *IndexElement) SetContentType(v string) { switch v { case "m.room.message": i.ContentType = "content.body" - case gomatrixserverlib.MRoomName: + case spec.MRoomName: i.ContentType = "content.name" - case gomatrixserverlib.MRoomTopic: + case spec.MRoomTopic: i.ContentType = "content.topic" } } // New opens a new/existing fulltext index -func New(cfg config.Fulltext) (fts *Search, err error) { +func New(processCtx *process.ProcessContext, cfg config.Fulltext) (fts *Search, err error) { fts = &Search{} fts.FulltextIndex, err = openIndex(cfg) if err != nil { return nil, err } + go func() { + processCtx.ComponentStarted() + // Wait for the processContext to be done, indicating that Dendrite is shutting down. + <-processCtx.WaitForShutdown() + _ = fts.Close() + processCtx.ComponentFinished() + }() return fts, nil } @@ -109,6 +126,47 @@ func (f *Search) Delete(eventID string) error { return f.FulltextIndex.Delete(eventID) } +var highlightMatcher = regexp.MustCompile("(.*?)") + +// GetHighlights extracts the highlights from a SearchResult. +func (f *Search) GetHighlights(result *bleve.SearchResult) []string { + if result == nil { + return []string{} + } + + seenMatches := make(map[string]struct{}) + + for _, hit := range result.Hits { + if hit.Fragments == nil { + continue + } + fragments, ok := hit.Fragments["Content"] + if !ok { + continue + } + for _, x := range fragments { + substringMatches := highlightMatcher.FindAllStringSubmatch(x, -1) + for _, matches := range substringMatches { + for i := range matches { + if i == 0 { // skip first match, this is the complete substring match + continue + } + if _, ok := seenMatches[matches[i]]; ok { + continue + } + seenMatches[matches[i]] = struct{}{} + } + } + } + } + + res := make([]string, 0, len(seenMatches)) + for m := range seenMatches { + res = append(res, m) + } + return res +} + // Search searches the index given a search term, roomIDs and keys. func (f *Search) Search(term string, roomIDs, keys []string, limit, from int, orderByStreamPos bool) (*bleve.SearchResult, error) { qry := bleve.NewConjunctionQuery() @@ -148,6 +206,10 @@ func (f *Search) Search(term string, roomIDs, keys []string, limit, from int, or s.SortBy([]string{"-StreamPosition"}) } + // Highlight some words + s.Highlight = bleve.NewHighlight() + s.Highlight.Fields = []string{"Content"} + return f.FulltextIndex.Search(s) } diff --git a/internal/fulltext/bleve_test.go b/internal/fulltext/bleve_test.go index d16397a45..decb5eccb 100644 --- a/internal/fulltext/bleve_test.go +++ b/internal/fulltext/bleve_test.go @@ -18,14 +18,15 @@ import ( "reflect" "testing" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/matrix-org/dendrite/internal/fulltext" "github.com/matrix-org/dendrite/setup/config" ) -func mustOpenIndex(t *testing.T, tempDir string) *fulltext.Search { +func mustOpenIndex(t *testing.T, tempDir string) (*fulltext.Search, *process.ProcessContext) { t.Helper() cfg := config.Fulltext{ Enabled: true, @@ -36,11 +37,12 @@ func mustOpenIndex(t *testing.T, tempDir string) *fulltext.Search { cfg.IndexPath = config.Path(tempDir) cfg.InMemory = false } - fts, err := fulltext.New(cfg) + ctx := process.NewProcessContext() + fts, err := fulltext.New(ctx, cfg) if err != nil { t.Fatal("failed to open fulltext index:", err) } - return fts + return fts, ctx } func mustAddTestData(t *testing.T, fts *fulltext.Search, firstStreamPos int64) (eventIDs, roomIDs []string) { @@ -75,7 +77,7 @@ func mustAddTestData(t *testing.T, fts *fulltext.Search, firstStreamPos int64) ( Content: "Roomname testing", StreamPosition: streamPos, } - e.SetContentType(gomatrixserverlib.MRoomName) + e.SetContentType(spec.MRoomName) batchItems = append(batchItems, e) e = fulltext.IndexElement{ EventID: util.RandomString(16), @@ -83,7 +85,7 @@ func mustAddTestData(t *testing.T, fts *fulltext.Search, firstStreamPos int64) ( Content: "Room topic fulltext", StreamPosition: streamPos, } - e.SetContentType(gomatrixserverlib.MRoomTopic) + e.SetContentType(spec.MRoomTopic) batchItems = append(batchItems, e) if err := fts.Index(batchItems...); err != nil { t.Fatalf("failed to batch insert elements: %v", err) @@ -93,19 +95,17 @@ func mustAddTestData(t *testing.T, fts *fulltext.Search, firstStreamPos int64) ( func TestOpen(t *testing.T) { dataDir := t.TempDir() - fts := mustOpenIndex(t, dataDir) - if err := fts.Close(); err != nil { - t.Fatal("unable to close fulltext index", err) - } + _, ctx := mustOpenIndex(t, dataDir) + ctx.ShutdownDendrite() // open existing index - fts = mustOpenIndex(t, dataDir) - defer fts.Close() + _, ctx = mustOpenIndex(t, dataDir) + ctx.ShutdownDendrite() } func TestIndex(t *testing.T) { - fts := mustOpenIndex(t, "") - defer fts.Close() + fts, ctx := mustOpenIndex(t, "") + defer ctx.ShutdownDendrite() // add some data var streamPos int64 = 1 @@ -128,8 +128,8 @@ func TestIndex(t *testing.T) { } func TestDelete(t *testing.T) { - fts := mustOpenIndex(t, "") - defer fts.Close() + fts, ctx := mustOpenIndex(t, "") + defer ctx.ShutdownDendrite() eventIDs, roomIDs := mustAddTestData(t, fts, 0) res1, err := fts.Search("lorem", roomIDs[:1], nil, 50, 0, false) if err != nil { @@ -160,14 +160,16 @@ func TestSearch(t *testing.T) { roomIndex []int } tests := []struct { - name string - args args - wantCount int - wantErr bool + name string + args args + wantCount int + wantErr bool + wantHighlights []string }{ { - name: "Can search for many results in one room", - wantCount: 16, + name: "Can search for many results in one room", + wantCount: 16, + wantHighlights: []string{"lorem"}, args: args{ term: "lorem", roomIndex: []int{0}, @@ -175,8 +177,9 @@ func TestSearch(t *testing.T) { }, }, { - name: "Can search for one result in one room", - wantCount: 1, + name: "Can search for one result in one room", + wantCount: 1, + wantHighlights: []string{"lorem"}, args: args{ term: "lorem", roomIndex: []int{16}, @@ -184,8 +187,9 @@ func TestSearch(t *testing.T) { }, }, { - name: "Can search for many results in multiple rooms", - wantCount: 17, + name: "Can search for many results in multiple rooms", + wantCount: 17, + wantHighlights: []string{"lorem"}, args: args{ term: "lorem", roomIndex: []int{0, 16}, @@ -193,8 +197,9 @@ func TestSearch(t *testing.T) { }, }, { - name: "Can search for many results in all rooms, reversed", - wantCount: 30, + name: "Can search for many results in all rooms, reversed", + wantCount: 30, + wantHighlights: []string{"lorem"}, args: args{ term: "lorem", limit: 30, @@ -202,8 +207,9 @@ func TestSearch(t *testing.T) { }, }, { - name: "Can search for specific search room name", - wantCount: 1, + name: "Can search for specific search room name", + wantCount: 1, + wantHighlights: []string{"testing"}, args: args{ term: "testing", roomIndex: []int{}, @@ -212,8 +218,9 @@ func TestSearch(t *testing.T) { }, }, { - name: "Can search for specific search room topic", - wantCount: 1, + name: "Can search for specific search room topic", + wantCount: 1, + wantHighlights: []string{"fulltext"}, args: args{ term: "fulltext", roomIndex: []int{}, @@ -222,9 +229,11 @@ func TestSearch(t *testing.T) { }, }, } + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - f := mustOpenIndex(t, "") + f, ctx := mustOpenIndex(t, "") + defer ctx.ShutdownDendrite() eventIDs, roomIDs := mustAddTestData(t, f, 0) var searchRooms []string for _, x := range tt.args.roomIndex { @@ -237,6 +246,12 @@ func TestSearch(t *testing.T) { t.Errorf("Search() error = %v, wantErr %v", err, tt.wantErr) return } + + highlights := f.GetHighlights(got) + if !reflect.DeepEqual(highlights, tt.wantHighlights) { + t.Errorf("Search() got highligts = %v, want %v", highlights, tt.wantHighlights) + } + if !reflect.DeepEqual(len(got.Hits), tt.wantCount) { t.Errorf("Search() got = %v, want %v", len(got.Hits), tt.wantCount) } diff --git a/internal/fulltext/bleve_wasm.go b/internal/fulltext/bleve_wasm.go index a69a8926e..12709900b 100644 --- a/internal/fulltext/bleve_wasm.go +++ b/internal/fulltext/bleve_wasm.go @@ -15,8 +15,9 @@ package fulltext import ( - "github.com/matrix-org/dendrite/setup/config" "time" + + "github.com/matrix-org/dendrite/setup/config" ) type Search struct{} @@ -28,6 +29,14 @@ type IndexElement struct { StreamPosition int64 } +type Indexer interface { + Index(elements ...IndexElement) error + Delete(eventID string) error + Search(term string, roomIDs, keys []string, limit, from int, orderByStreamPos bool) (SearchResult, error) + GetHighlights(result SearchResult) []string + Close() error +} + type SearchResult struct { Status interface{} `json:"status"` Request *interface{} `json:"request"` @@ -48,7 +57,7 @@ func (f *Search) Close() error { return nil } -func (f *Search) Index(e IndexElement) error { +func (f *Search) Index(e ...IndexElement) error { return nil } @@ -63,3 +72,7 @@ func (f *Search) Delete(eventID string) error { func (f *Search) Search(term string, roomIDs, keys []string, limit, from int, orderByStreamPos bool) (SearchResult, error) { return SearchResult{}, nil } + +func (f *Search) GetHighlights(result SearchResult) []string { + return []string{} +} diff --git a/internal/httputil/httpapi.go b/internal/httputil/httpapi.go index 37d144f4e..289d1d2ca 100644 --- a/internal/httputil/httpapi.go +++ b/internal/httputil/httpapi.go @@ -25,8 +25,6 @@ import ( "github.com/getsentry/sentry-go" "github.com/matrix-org/util" - "github.com/opentracing/opentracing-go" - "github.com/opentracing/opentracing-go/ext" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" "github.com/prometheus/client_golang/prometheus/promhttp" @@ -34,6 +32,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth" "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/internal" userapi "github.com/matrix-org/dendrite/userapi/api" ) @@ -186,9 +185,9 @@ func MakeExternalAPI(metricsName string, f func(*http.Request) util.JSONResponse } } - span := opentracing.StartSpan(metricsName) - defer span.Finish() - req = req.WithContext(opentracing.ContextWithSpan(req.Context(), span)) + trace, ctx := internal.StartTask(req.Context(), metricsName) + defer trace.EndTask() + req = req.WithContext(ctx) h.ServeHTTP(nextWriter, req) } @@ -200,9 +199,9 @@ func MakeExternalAPI(metricsName string, f func(*http.Request) util.JSONResponse // This is used to serve HTML alongside JSON error messages func MakeHTMLAPI(metricsName string, enableMetrics bool, f func(http.ResponseWriter, *http.Request)) http.Handler { withSpan := func(w http.ResponseWriter, req *http.Request) { - span := opentracing.StartSpan(metricsName) - defer span.Finish() - req = req.WithContext(opentracing.ContextWithSpan(req.Context(), span)) + trace, ctx := internal.StartTask(req.Context(), metricsName) + defer trace.EndTask() + req = req.WithContext(ctx) f(w, req) } @@ -223,57 +222,6 @@ func MakeHTMLAPI(metricsName string, enableMetrics bool, f func(http.ResponseWri ) } -// MakeInternalAPI turns a util.JSONRequestHandler function into an http.Handler. -// This is used for APIs that are internal to dendrite. -// If we are passed a tracing context in the request headers then we use that -// as the parent of any tracing spans we create. -func MakeInternalAPI(metricsName string, enableMetrics bool, f func(*http.Request) util.JSONResponse) http.Handler { - h := util.MakeJSONAPI(util.NewJSONRequestHandler(f)) - withSpan := func(w http.ResponseWriter, req *http.Request) { - carrier := opentracing.HTTPHeadersCarrier(req.Header) - tracer := opentracing.GlobalTracer() - clientContext, err := tracer.Extract(opentracing.HTTPHeaders, carrier) - var span opentracing.Span - if err == nil { - // Default to a span without RPC context. - span = tracer.StartSpan(metricsName) - } else { - // Set the RPC context. - span = tracer.StartSpan(metricsName, ext.RPCServerOption(clientContext)) - } - defer span.Finish() - req = req.WithContext(opentracing.ContextWithSpan(req.Context(), span)) - h.ServeHTTP(w, req) - } - - if !enableMetrics { - return http.HandlerFunc(withSpan) - } - - return promhttp.InstrumentHandlerCounter( - promauto.NewCounterVec( - prometheus.CounterOpts{ - Name: metricsName + "_requests_total", - Help: "Total number of internal API calls", - Namespace: "dendrite", - }, - []string{"code"}, - ), - promhttp.InstrumentHandlerResponseSize( - promauto.NewHistogramVec( - prometheus.HistogramOpts{ - Namespace: "dendrite", - Name: metricsName + "_response_size_bytes", - Help: "A histogram of response sizes for requests.", - Buckets: []float64{200, 500, 900, 1500, 5000, 15000, 50000, 100000}, - }, - []string{}, - ), - http.HandlerFunc(withSpan), - ), - ) -} - // WrapHandlerInBasicAuth adds basic auth to a handler. Only used for /metrics func WrapHandlerInBasicAuth(h http.Handler, b BasicAuth) http.HandlerFunc { if b.Username == "" || b.Password == "" { diff --git a/internal/httputil/routing.go b/internal/httputil/routing.go index 0bd3655ec..c733c8ce7 100644 --- a/internal/httputil/routing.go +++ b/internal/httputil/routing.go @@ -15,7 +15,10 @@ package httputil import ( + "net/http" "net/url" + + "github.com/gorilla/mux" ) // URLDecodeMapValues is a function that iterates through each of the items in a @@ -33,3 +36,52 @@ func URLDecodeMapValues(vmap map[string]string) (map[string]string, error) { return decoded, nil } + +type Routers struct { + Client *mux.Router + Federation *mux.Router + Keys *mux.Router + Media *mux.Router + WellKnown *mux.Router + Static *mux.Router + DendriteAdmin *mux.Router + SynapseAdmin *mux.Router +} + +func NewRouters() Routers { + r := Routers{ + Client: mux.NewRouter().SkipClean(true).PathPrefix(PublicClientPathPrefix).Subrouter().UseEncodedPath(), + Federation: mux.NewRouter().SkipClean(true).PathPrefix(PublicFederationPathPrefix).Subrouter().UseEncodedPath(), + Keys: mux.NewRouter().SkipClean(true).PathPrefix(PublicKeyPathPrefix).Subrouter().UseEncodedPath(), + Media: mux.NewRouter().SkipClean(true).PathPrefix(PublicMediaPathPrefix).Subrouter().UseEncodedPath(), + WellKnown: mux.NewRouter().SkipClean(true).PathPrefix(PublicWellKnownPrefix).Subrouter().UseEncodedPath(), + Static: mux.NewRouter().SkipClean(true).PathPrefix(PublicStaticPath).Subrouter().UseEncodedPath(), + DendriteAdmin: mux.NewRouter().SkipClean(true).PathPrefix(DendriteAdminPathPrefix).Subrouter().UseEncodedPath(), + SynapseAdmin: mux.NewRouter().SkipClean(true).PathPrefix(SynapseAdminPathPrefix).Subrouter().UseEncodedPath(), + } + r.configureHTTPErrors() + return r +} + +var NotAllowedHandler = WrapHandlerInCORS(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusMethodNotAllowed) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"errcode":"M_UNRECOGNIZED","error":"Unrecognized request"}`)) // nolint:misspell +})) + +var NotFoundCORSHandler = WrapHandlerInCORS(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"errcode":"M_UNRECOGNIZED","error":"Unrecognized request"}`)) // nolint:misspell +})) + +func (r *Routers) configureHTTPErrors() { + for _, router := range []*mux.Router{ + r.Client, r.Federation, r.Keys, + r.Media, r.WellKnown, r.Static, + r.DendriteAdmin, r.SynapseAdmin, + } { + router.NotFoundHandler = NotFoundCORSHandler + router.MethodNotAllowedHandler = NotAllowedHandler + } +} diff --git a/internal/httputil/routing_test.go b/internal/httputil/routing_test.go new file mode 100644 index 000000000..21e2bf48a --- /dev/null +++ b/internal/httputil/routing_test.go @@ -0,0 +1,38 @@ +package httputil + +import ( + "net/http" + "net/http/httptest" + "path/filepath" + "testing" +) + +func TestRoutersError(t *testing.T) { + r := NewRouters() + + // not found test + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, filepath.Join(PublicFederationPathPrefix, "doesnotexist"), nil) + r.Federation.ServeHTTP(rec, req) + if rec.Code != http.StatusNotFound { + t.Fatalf("unexpected status code: %d - %s", rec.Code, rec.Body.String()) + } + if ct := rec.Header().Get("Content-Type"); ct != "application/json" { + t.Fatalf("unexpected content-type: %s", ct) + } + + // not allowed test + r.DendriteAdmin. + Handle("/test", http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {})). + Methods(http.MethodPost) + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, filepath.Join(DendriteAdminPathPrefix, "test"), nil) + r.DendriteAdmin.ServeHTTP(rec, req) + if rec.Code != http.StatusMethodNotAllowed { + t.Fatalf("unexpected status code: %d - %s", rec.Code, rec.Body.String()) + } + if ct := rec.Header().Get("Content-Type"); ct != "application/json" { + t.Fatalf("unexpected content-type: %s", ct) + } +} diff --git a/internal/pushgateway/client.go b/internal/pushgateway/client.go index 259239b87..d5671be3b 100644 --- a/internal/pushgateway/client.go +++ b/internal/pushgateway/client.go @@ -10,8 +10,6 @@ import ( "time" "github.com/matrix-org/dendrite/internal" - - "github.com/opentracing/opentracing-go" ) type httpClient struct { @@ -34,8 +32,8 @@ func NewHTTPClient(disableTLSValidation bool) Client { } func (h *httpClient) Notify(ctx context.Context, url string, req *NotifyRequest, resp *NotifyResponse) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "Notify") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "Notify") + defer trace.EndRegion() body, err := json.Marshal(req) if err != nil { diff --git a/internal/pushrules/default.go b/internal/pushrules/default.go index 996985514..202a10d79 100644 --- a/internal/pushrules/default.go +++ b/internal/pushrules/default.go @@ -1,12 +1,10 @@ package pushrules -import ( - "github.com/matrix-org/gomatrixserverlib" -) +import "github.com/matrix-org/gomatrixserverlib/spec" // DefaultAccountRuleSets is the complete set of default push rules // for an account. -func DefaultAccountRuleSets(localpart string, serverName gomatrixserverlib.ServerName) *AccountRuleSets { +func DefaultAccountRuleSets(localpart string, serverName spec.ServerName) *AccountRuleSets { return &AccountRuleSets{ Global: *DefaultGlobalRuleSet(localpart, serverName), } @@ -14,7 +12,7 @@ func DefaultAccountRuleSets(localpart string, serverName gomatrixserverlib.Serve // DefaultGlobalRuleSet returns the default ruleset for a given (fully // qualified) MXID. -func DefaultGlobalRuleSet(localpart string, serverName gomatrixserverlib.ServerName) *RuleSet { +func DefaultGlobalRuleSet(localpart string, serverName spec.ServerName) *RuleSet { return &RuleSet{ Override: defaultOverrideRules("@" + localpart + ":" + string(serverName)), Content: defaultContentRules(localpart), diff --git a/internal/pushrules/validate.go b/internal/pushrules/validate.go index f50c51bd7..b54ec3fb0 100644 --- a/internal/pushrules/validate.go +++ b/internal/pushrules/validate.go @@ -10,6 +10,10 @@ import ( func ValidateRule(kind Kind, rule *Rule) []error { var errs []error + if len(rule.RuleID) > 0 && rule.RuleID[:1] == "." { + errs = append(errs, fmt.Errorf("invalid rule ID: rule can not start with a dot")) + } + if !validRuleIDRE.MatchString(rule.RuleID) { errs = append(errs, fmt.Errorf("invalid rule ID: %s", rule.RuleID)) } diff --git a/internal/sqlutil/connection_manager.go b/internal/sqlutil/connection_manager.go new file mode 100644 index 000000000..934a2954a --- /dev/null +++ b/internal/sqlutil/connection_manager.go @@ -0,0 +1,75 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlutil + +import ( + "database/sql" + "fmt" + + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/process" +) + +type Connections struct { + db *sql.DB + writer Writer + globalConfig config.DatabaseOptions + processContext *process.ProcessContext +} + +func NewConnectionManager(processCtx *process.ProcessContext, globalConfig config.DatabaseOptions) Connections { + return Connections{ + globalConfig: globalConfig, + processContext: processCtx, + } +} + +func (c *Connections) Connection(dbProperties *config.DatabaseOptions) (*sql.DB, Writer, error) { + writer := NewDummyWriter() + if dbProperties.ConnectionString.IsSQLite() { + writer = NewExclusiveWriter() + } + var err error + if dbProperties.ConnectionString == "" { + // if no connectionString was provided, try the global one + dbProperties = &c.globalConfig + } + if dbProperties.ConnectionString != "" || c.db == nil { + // Open a new database connection using the supplied config. + c.db, err = Open(dbProperties, writer) + if err != nil { + return nil, nil, err + } + c.writer = writer + go func() { + if c.processContext == nil { + return + } + // If we have a ProcessContext, start a component and wait for + // Dendrite to shut down to cleanly close the database connection. + c.processContext.ComponentStarted() + <-c.processContext.WaitForShutdown() + _ = c.db.Close() + c.processContext.ComponentFinished() + }() + return c.db, c.writer, nil + } + if c.db != nil && c.writer != nil { + // Ignore the supplied config and return the global pool and + // writer. + return c.db, c.writer, nil + } + return nil, nil, fmt.Errorf("no database connections configured") +} diff --git a/internal/sqlutil/connection_manager_test.go b/internal/sqlutil/connection_manager_test.go new file mode 100644 index 000000000..a9ac8d57f --- /dev/null +++ b/internal/sqlutil/connection_manager_test.go @@ -0,0 +1,56 @@ +package sqlutil_test + +import ( + "reflect" + "testing" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" +) + +func TestConnectionManager(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + conStr, close := test.PrepareDBConnectionString(t, dbType) + t.Cleanup(close) + cm := sqlutil.NewConnectionManager(nil, config.DatabaseOptions{}) + + dbProps := &config.DatabaseOptions{ConnectionString: config.DataSource(conStr)} + db, writer, err := cm.Connection(dbProps) + if err != nil { + t.Fatal(err) + } + + switch dbType { + case test.DBTypeSQLite: + _, ok := writer.(*sqlutil.ExclusiveWriter) + if !ok { + t.Fatalf("expected exclusive writer") + } + case test.DBTypePostgres: + _, ok := writer.(*sqlutil.DummyWriter) + if !ok { + t.Fatalf("expected dummy writer") + } + } + + // test global db pool + dbGlobal, writerGlobal, err := cm.Connection(&config.DatabaseOptions{}) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(db, dbGlobal) { + t.Fatalf("expected database connection to be reused") + } + if !reflect.DeepEqual(writer, writerGlobal) { + t.Fatalf("expected database writer to be reused") + } + + // test invalid connection string configured + cm2 := sqlutil.NewConnectionManager(nil, config.DatabaseOptions{}) + _, _, err = cm2.Connection(&config.DatabaseOptions{ConnectionString: "http://"}) + if err == nil { + t.Fatal("expected an error but got none") + } + }) +} diff --git a/internal/sqlutil/sqlite_cgo.go b/internal/sqlutil/sqlite_cgo.go index efb743fc7..2fe2396ff 100644 --- a/internal/sqlutil/sqlite_cgo.go +++ b/internal/sqlutil/sqlite_cgo.go @@ -4,7 +4,6 @@ package sqlutil import ( - "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3" ) @@ -13,7 +12,3 @@ const SQLITE_DRIVER_NAME = "sqlite3" func sqliteDSNExtension(dsn string) string { return dsn } - -func sqliteDriver() *sqlite3.SQLiteDriver { - return &sqlite3.SQLiteDriver{} -} diff --git a/internal/sqlutil/sqlite_native.go b/internal/sqlutil/sqlite_native.go index ed500afc6..3c545b659 100644 --- a/internal/sqlutil/sqlite_native.go +++ b/internal/sqlutil/sqlite_native.go @@ -4,7 +4,6 @@ package sqlutil import ( - "modernc.org/sqlite" "strings" ) @@ -23,7 +22,3 @@ func sqliteDSNExtension(dsn string) string { dsn += "_pragma=busy_timeout%3d10000" return dsn } - -func sqliteDriver() *sqlite.Driver { - return &sqlite.Driver{} -} diff --git a/internal/sqlutil/sqlutil.go b/internal/sqlutil/sqlutil.go index 39a067e52..b2d0d2186 100644 --- a/internal/sqlutil/sqlutil.go +++ b/internal/sqlutil/sqlutil.go @@ -13,8 +13,7 @@ import ( var skipSanityChecks = flag.Bool("skip-db-sanity", false, "Ignore sanity checks on the database connections (NOT RECOMMENDED!)") // Open opens a database specified by its database driver name and a driver-specific data source name, -// usually consisting of at least a database name and connection information. Includes tracing driver -// if DENDRITE_TRACE_SQL=1 +// usually consisting of at least a database name and connection information. func Open(dbProperties *config.DatabaseOptions, writer Writer) (*sql.DB, error) { var err error var driverName, dsn string @@ -32,10 +31,6 @@ func Open(dbProperties *config.DatabaseOptions, writer Writer) (*sql.DB, error) default: return nil, fmt.Errorf("invalid database connection string %q", dbProperties.ConnectionString) } - if tracingEnabled { - // install the wrapped driver - driverName += "-trace" - } db, err := sql.Open(driverName, dsn) if err != nil { return nil, err diff --git a/internal/sqlutil/trace.go b/internal/sqlutil/trace.go deleted file mode 100644 index 7b637106b..000000000 --- a/internal/sqlutil/trace.go +++ /dev/null @@ -1,109 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sqlutil - -import ( - "context" - "database/sql/driver" - "fmt" - "io" - "os" - "runtime" - "strconv" - "strings" - "sync" - "time" - - "github.com/ngrok/sqlmw" - "github.com/sirupsen/logrus" -) - -var tracingEnabled = os.Getenv("DENDRITE_TRACE_SQL") == "1" -var goidToWriter sync.Map - -type traceInterceptor struct { - sqlmw.NullInterceptor -} - -func (in *traceInterceptor) StmtQueryContext(ctx context.Context, stmt driver.StmtQueryContext, query string, args []driver.NamedValue) (context.Context, driver.Rows, error) { - startedAt := time.Now() - rows, err := stmt.QueryContext(ctx, args) - - trackGoID(query) - - logrus.WithField("duration", time.Since(startedAt)).WithField(logrus.ErrorKey, err).Debug("executed sql query ", query, " args: ", args) - - return ctx, rows, err -} - -func (in *traceInterceptor) StmtExecContext(ctx context.Context, stmt driver.StmtExecContext, query string, args []driver.NamedValue) (driver.Result, error) { - startedAt := time.Now() - result, err := stmt.ExecContext(ctx, args) - - trackGoID(query) - - logrus.WithField("duration", time.Since(startedAt)).WithField(logrus.ErrorKey, err).Debug("executed sql query ", query, " args: ", args) - - return result, err -} - -func (in *traceInterceptor) RowsNext(c context.Context, rows driver.Rows, dest []driver.Value) error { - err := rows.Next(dest) - if err == io.EOF { - // For all cases, we call Next() n+1 times, the first to populate the initial dest, then eventually - // it will io.EOF. If we log on each Next() call we log the last element twice, so don't. - return err - } - cols := rows.Columns() - logrus.Debug(strings.Join(cols, " | ")) - - b := strings.Builder{} - for i, val := range dest { - b.WriteString(fmt.Sprintf("%q", val)) - if i+1 <= len(dest)-1 { - b.WriteString(" | ") - } - } - logrus.Debug(b.String()) - return err -} - -func trackGoID(query string) { - thisGoID := goid() - if _, ok := goidToWriter.Load(thisGoID); ok { - return // we're on a writer goroutine - } - - q := strings.TrimSpace(query) - if strings.HasPrefix(q, "SELECT") { - return // SELECTs can go on other goroutines - } - logrus.Warnf("unsafe goid %d: SQL executed not on an ExclusiveWriter: %s", thisGoID, q) -} - -func init() { - registerDrivers() -} - -func goid() int { - var buf [64]byte - n := runtime.Stack(buf[:], false) - idField := strings.Fields(strings.TrimPrefix(string(buf[:n]), "goroutine "))[0] - id, err := strconv.Atoi(idField) - if err != nil { - panic(fmt.Sprintf("cannot get goroutine id: %v", err)) - } - return id -} diff --git a/internal/sqlutil/trace_driver.go b/internal/sqlutil/trace_driver.go deleted file mode 100644 index a2e0d12e2..000000000 --- a/internal/sqlutil/trace_driver.go +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//go:build !wasm -// +build !wasm - -package sqlutil - -import ( - "database/sql" - - "github.com/lib/pq" - "github.com/ngrok/sqlmw" -) - -func registerDrivers() { - if !tracingEnabled { - return - } - // install the wrapped drivers - sql.Register("postgres-trace", sqlmw.Driver(&pq.Driver{}, new(traceInterceptor))) - sql.Register("sqlite3-trace", sqlmw.Driver(sqliteDriver(), new(traceInterceptor))) - -} diff --git a/internal/sqlutil/trace_driver_wasm.go b/internal/sqlutil/trace_driver_wasm.go deleted file mode 100644 index 51b60c3c8..000000000 --- a/internal/sqlutil/trace_driver_wasm.go +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//go:build wasm -// +build wasm - -package sqlutil - -import ( - "database/sql" - - sqlitejs "github.com/matrix-org/go-sqlite3-js" - "github.com/ngrok/sqlmw" -) - -func registerDrivers() { - if !tracingEnabled { - return - } - // install the wrapped drivers - sql.Register("sqlite3_js-trace", sqlmw.Driver(&sqlitejs.SqliteJsDriver{}, new(traceInterceptor))) - -} diff --git a/internal/sqlutil/writer_exclusive.go b/internal/sqlutil/writer_exclusive.go index 8eff3ce55..c6a271c1c 100644 --- a/internal/sqlutil/writer_exclusive.go +++ b/internal/sqlutil/writer_exclusive.go @@ -60,11 +60,6 @@ func (w *ExclusiveWriter) run() { if !w.running.CompareAndSwap(false, true) { return } - if tracingEnabled { - gid := goid() - goidToWriter.Store(gid, w) - defer goidToWriter.Delete(gid) - } defer w.running.Store(false) for task := range w.todo { diff --git a/internal/tracing.go b/internal/tracing.go new file mode 100644 index 000000000..4e062aed3 --- /dev/null +++ b/internal/tracing.go @@ -0,0 +1,64 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "context" + "runtime/trace" + + "github.com/opentracing/opentracing-go" +) + +type Trace struct { + span opentracing.Span + region *trace.Region + task *trace.Task +} + +func StartTask(inCtx context.Context, name string) (Trace, context.Context) { + ctx, task := trace.NewTask(inCtx, name) + span, ctx := opentracing.StartSpanFromContext(ctx, name) + return Trace{ + span: span, + task: task, + }, ctx +} + +func StartRegion(inCtx context.Context, name string) (Trace, context.Context) { + region := trace.StartRegion(inCtx, name) + span, ctx := opentracing.StartSpanFromContext(inCtx, name) + return Trace{ + span: span, + region: region, + }, ctx +} + +func (t Trace) EndRegion() { + t.span.Finish() + if t.region != nil { + t.region.End() + } +} + +func (t Trace) EndTask() { + t.span.Finish() + if t.task != nil { + t.task.End() + } +} + +func (t Trace) SetTag(key string, value any) { + t.span.SetTag(key, value) +} diff --git a/internal/tracing_test.go b/internal/tracing_test.go new file mode 100644 index 000000000..582f50c3a --- /dev/null +++ b/internal/tracing_test.go @@ -0,0 +1,25 @@ +package internal + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestTracing(t *testing.T) { + inCtx := context.Background() + + task, ctx := StartTask(inCtx, "testing") + assert.NotNil(t, ctx) + assert.NotNil(t, task) + assert.NotEqual(t, inCtx, ctx) + task.SetTag("key", "value") + + region, ctx2 := StartRegion(ctx, "testing") + assert.NotNil(t, ctx) + assert.NotNil(t, region) + assert.NotEqual(t, ctx, ctx2) + defer task.EndTask() + defer region.EndRegion() +} diff --git a/internal/transactionrequest.go b/internal/transactionrequest.go index 13b00af50..235a3123f 100644 --- a/internal/transactionrequest.go +++ b/internal/transactionrequest.go @@ -28,6 +28,8 @@ import ( syncTypes "github.com/matrix-org/dendrite/syncapi/types" userAPI "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/prometheus/client_golang/prometheus" "github.com/sirupsen/logrus" @@ -57,7 +59,7 @@ type TxnReq struct { gomatrixserverlib.Transaction rsAPI api.FederationRoomserverAPI userAPI userAPI.FederationUserAPI - ourServerName gomatrixserverlib.ServerName + ourServerName spec.ServerName keys gomatrixserverlib.JSONVerifier roomsMu *MutexByRoom producer *producers.SyncAPIProducer @@ -67,16 +69,16 @@ type TxnReq struct { func NewTxnReq( rsAPI api.FederationRoomserverAPI, userAPI userAPI.FederationUserAPI, - ourServerName gomatrixserverlib.ServerName, + ourServerName spec.ServerName, keys gomatrixserverlib.JSONVerifier, roomsMu *MutexByRoom, producer *producers.SyncAPIProducer, inboundPresenceEnabled bool, pdus []json.RawMessage, edus []gomatrixserverlib.EDU, - origin gomatrixserverlib.ServerName, + origin spec.ServerName, transactionID gomatrixserverlib.TransactionID, - destination gomatrixserverlib.ServerName, + destination spec.ServerName, ) TxnReq { t := TxnReq{ rsAPI: rsAPI, @@ -97,7 +99,7 @@ func NewTxnReq( return t } -func (t *TxnReq) ProcessTransaction(ctx context.Context) (*gomatrixserverlib.RespSend, *util.JSONResponse) { +func (t *TxnReq) ProcessTransaction(ctx context.Context) (*fclient.RespSend, *util.JSONResponse) { var wg sync.WaitGroup wg.Add(1) go func() { @@ -107,7 +109,7 @@ func (t *TxnReq) ProcessTransaction(ctx context.Context) (*gomatrixserverlib.Res } }() - results := make(map[string]gomatrixserverlib.PDUResult) + results := make(map[string]fclient.PDUResult) roomVersions := make(map[string]gomatrixserverlib.RoomVersion) getRoomVersion := func(roomID string) gomatrixserverlib.RoomVersion { if v, ok := roomVersions[roomID]; ok { @@ -153,18 +155,18 @@ func (t *TxnReq) ProcessTransaction(ctx context.Context) (*gomatrixserverlib.Res util.GetLogger(ctx).WithError(err).Debugf("Transaction: Failed to parse event JSON of event %s", string(pdu)) continue } - if event.Type() == gomatrixserverlib.MRoomCreate && event.StateKeyEquals("") { + if event.Type() == spec.MRoomCreate && event.StateKeyEquals("") { continue } if api.IsServerBannedFromRoom(ctx, t.rsAPI, event.RoomID(), t.Origin) { - results[event.EventID()] = gomatrixserverlib.PDUResult{ + results[event.EventID()] = fclient.PDUResult{ Error: "Forbidden by server ACLs", } continue } if err = event.VerifyEventSignatures(ctx, t.keys); err != nil { util.GetLogger(ctx).WithError(err).Debugf("Transaction: Couldn't validate signature of event %q", event.EventID()) - results[event.EventID()] = gomatrixserverlib.PDUResult{ + results[event.EventID()] = fclient.PDUResult{ Error: err.Error(), } continue @@ -187,18 +189,18 @@ func (t *TxnReq) ProcessTransaction(ctx context.Context) (*gomatrixserverlib.Res true, ); err != nil { util.GetLogger(ctx).WithError(err).Errorf("Transaction: Couldn't submit event %q to input queue: %s", event.EventID(), err) - results[event.EventID()] = gomatrixserverlib.PDUResult{ + results[event.EventID()] = fclient.PDUResult{ Error: err.Error(), } continue } - results[event.EventID()] = gomatrixserverlib.PDUResult{} + results[event.EventID()] = fclient.PDUResult{} PDUCountTotal.WithLabelValues("success").Inc() } wg.Wait() - return &gomatrixserverlib.RespSend{PDUs: results}, nil + return &fclient.RespSend{PDUs: results}, nil } // nolint:gocyclo @@ -206,7 +208,7 @@ func (t *TxnReq) processEDUs(ctx context.Context) { for _, e := range t.EDUs { EDUCountTotal.Inc() switch e.Type { - case gomatrixserverlib.MTyping: + case spec.MTyping: // https://matrix.org/docs/spec/server_server/latest#typing-notifications var typingPayload struct { RoomID string `json:"room_id"` @@ -227,7 +229,7 @@ func (t *TxnReq) processEDUs(ctx context.Context) { if err := t.producer.SendTyping(ctx, typingPayload.UserID, typingPayload.RoomID, typingPayload.Typing, 30*1000); err != nil { util.GetLogger(ctx).WithError(err).Error("Failed to send typing event to JetStream") } - case gomatrixserverlib.MDirectToDevice: + case spec.MDirectToDevice: // https://matrix.org/docs/spec/server_server/r0.1.3#m-direct-to-device-schema var directPayload gomatrixserverlib.ToDeviceMessage if err := json.Unmarshal(e.Content, &directPayload); err != nil { @@ -254,12 +256,12 @@ func (t *TxnReq) processEDUs(ctx context.Context) { } } } - case gomatrixserverlib.MDeviceListUpdate: + case spec.MDeviceListUpdate: if err := t.producer.SendDeviceListUpdate(ctx, e.Content, t.Origin); err != nil { sentry.CaptureException(err) util.GetLogger(ctx).WithError(err).Error("failed to InputDeviceListUpdate") } - case gomatrixserverlib.MReceipt: + case spec.MReceipt: // https://matrix.org/docs/spec/server_server/r0.1.4#receipts payload := map[string]types.FederationReceiptMRead{} @@ -295,7 +297,7 @@ func (t *TxnReq) processEDUs(ctx context.Context) { sentry.CaptureException(err) logrus.WithError(err).Errorf("Failed to process signing key update") } - case gomatrixserverlib.MPresence: + case spec.MPresence: if t.inboundPresenceEnabled { if err := t.processPresence(ctx, e); err != nil { logrus.WithError(err).Errorf("Failed to process presence update") @@ -335,7 +337,7 @@ func (t *TxnReq) processPresence(ctx context.Context, e gomatrixserverlib.EDU) e // processReceiptEvent sends receipt events to JetStream func (t *TxnReq) processReceiptEvent(ctx context.Context, userID, roomID, receiptType string, - timestamp gomatrixserverlib.Timestamp, + timestamp spec.Timestamp, eventIDs []string, ) error { if _, serverName, err := gomatrixserverlib.SplitID('@', userID); err != nil { diff --git a/internal/transactionrequest_test.go b/internal/transactionrequest_test.go index 8597ae24b..b539e1a57 100644 --- a/internal/transactionrequest_test.go +++ b/internal/transactionrequest_test.go @@ -22,6 +22,13 @@ import ( "testing" "time" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" + "github.com/nats-io/nats.go" + "github.com/stretchr/testify/assert" + "go.uber.org/atomic" + "gotest.tools/v3/poll" + "github.com/matrix-org/dendrite/federationapi/producers" rsAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" @@ -30,16 +37,11 @@ import ( "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/test" keyAPI "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" - "github.com/nats-io/nats.go" - "github.com/stretchr/testify/assert" - "go.uber.org/atomic" - "gotest.tools/v3/poll" ) const ( - testOrigin = gomatrixserverlib.ServerName("kaer.morhen") - testDestination = gomatrixserverlib.ServerName("white.orchard") + testOrigin = spec.ServerName("kaer.morhen") + testDestination = spec.ServerName("white.orchard") ) var ( @@ -234,7 +236,7 @@ func TestProcessTransactionRequestEDUTyping(t *testing.T) { t.Errorf("failed to marshal EDU JSON") } badEDU := gomatrixserverlib.EDU{Type: "m.typing"} - badEDU.Content = gomatrixserverlib.RawJSON("badjson") + badEDU.Content = spec.RawJSON("badjson") edus := []gomatrixserverlib.EDU{badEDU, edu} ctx := process.NewProcessContext() @@ -300,7 +302,7 @@ func TestProcessTransactionRequestEDUToDevice(t *testing.T) { t.Errorf("failed to marshal EDU JSON") } badEDU := gomatrixserverlib.EDU{Type: "m.direct_to_device"} - badEDU.Content = gomatrixserverlib.RawJSON("badjson") + badEDU.Content = spec.RawJSON("badjson") edus := []gomatrixserverlib.EDU{badEDU, edu} ctx := process.NewProcessContext() @@ -377,7 +379,7 @@ func TestProcessTransactionRequestEDUDeviceListUpdate(t *testing.T) { t.Errorf("failed to marshal EDU JSON") } badEDU := gomatrixserverlib.EDU{Type: "m.device_list_update"} - badEDU.Content = gomatrixserverlib.RawJSON("badjson") + badEDU.Content = spec.RawJSON("badjson") edus := []gomatrixserverlib.EDU{badEDU, edu} ctx := process.NewProcessContext() @@ -427,7 +429,7 @@ func TestProcessTransactionRequestEDUReceipt(t *testing.T) { roomID: map[string]interface{}{ "m.read": map[string]interface{}{ "@john:kaer.morhen": map[string]interface{}{ - "data": map[string]interface{}{ + "data": map[string]int64{ "ts": 1533358089009, }, "event_ids": []string{ @@ -440,13 +442,13 @@ func TestProcessTransactionRequestEDUReceipt(t *testing.T) { t.Errorf("failed to marshal EDU JSON") } badEDU := gomatrixserverlib.EDU{Type: "m.receipt"} - badEDU.Content = gomatrixserverlib.RawJSON("badjson") + badEDU.Content = spec.RawJSON("badjson") badUser := gomatrixserverlib.EDU{Type: "m.receipt"} if badUser.Content, err = json.Marshal(map[string]interface{}{ roomID: map[string]interface{}{ "m.read": map[string]interface{}{ "johnkaer.morhen": map[string]interface{}{ - "data": map[string]interface{}{ + "data": map[string]int64{ "ts": 1533358089009, }, "event_ids": []string{ @@ -463,7 +465,7 @@ func TestProcessTransactionRequestEDUReceipt(t *testing.T) { roomID: map[string]interface{}{ "m.read": map[string]interface{}{ "@john:bad.domain": map[string]interface{}{ - "data": map[string]interface{}{ + "data": map[string]int64{ "ts": 1533358089009, }, "event_ids": []string{ @@ -518,7 +520,7 @@ func TestProcessTransactionRequestEDUSigningKeyUpdate(t *testing.T) { t.Errorf("failed to marshal EDU JSON") } badEDU := gomatrixserverlib.EDU{Type: "m.signing_key_update"} - badEDU.Content = gomatrixserverlib.RawJSON("badjson") + badEDU.Content = spec.RawJSON("badjson") edus := []gomatrixserverlib.EDU{badEDU, edu} ctx := process.NewProcessContext() @@ -575,7 +577,7 @@ func TestProcessTransactionRequestEDUPresence(t *testing.T) { t.Errorf("failed to marshal EDU JSON") } badEDU := gomatrixserverlib.EDU{Type: "m.presence"} - badEDU.Content = gomatrixserverlib.RawJSON("badjson") + badEDU.Content = spec.RawJSON("badjson") edus := []gomatrixserverlib.EDU{badEDU, edu} ctx := process.NewProcessContext() diff --git a/internal/validate.go b/internal/validate.go index 0461b897e..f794d7a5b 100644 --- a/internal/validate.go +++ b/internal/validate.go @@ -21,7 +21,7 @@ import ( "regexp" "github.com/matrix-org/dendrite/clientapi/jsonerror" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -70,7 +70,7 @@ func PasswordResponse(err error) *util.JSONResponse { } // ValidateUsername returns an error if the username is invalid -func ValidateUsername(localpart string, domain gomatrixserverlib.ServerName) error { +func ValidateUsername(localpart string, domain spec.ServerName) error { // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161 if id := fmt.Sprintf("@%s:%s", localpart, domain); len(id) > maxUsernameLength { return ErrUsernameTooLong @@ -100,7 +100,7 @@ func UsernameResponse(err error) *util.JSONResponse { } // ValidateApplicationServiceUsername returns an error if the username is invalid for an application service -func ValidateApplicationServiceUsername(localpart string, domain gomatrixserverlib.ServerName) error { +func ValidateApplicationServiceUsername(localpart string, domain spec.ServerName) error { if id := fmt.Sprintf("@%s:%s", localpart, domain); len(id) > maxUsernameLength { return ErrUsernameTooLong } else if !validUsernameRegex.MatchString(localpart) { diff --git a/internal/validate_test.go b/internal/validate_test.go index d0ad04707..2244b7a96 100644 --- a/internal/validate_test.go +++ b/internal/validate_test.go @@ -7,7 +7,7 @@ import ( "testing" "github.com/matrix-org/dendrite/clientapi/jsonerror" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -54,7 +54,7 @@ func Test_validateUsername(t *testing.T) { tests := []struct { name string localpart string - domain gomatrixserverlib.ServerName + domain spec.ServerName wantErr error wantJSON *util.JSONResponse }{ diff --git a/internal/version.go b/internal/version.go index e2655cb63..907547589 100644 --- a/internal/version.go +++ b/internal/version.go @@ -16,8 +16,8 @@ var build string const ( VersionMajor = 0 - VersionMinor = 11 - VersionPatch = 1 + VersionMinor = 12 + VersionPatch = 0 VersionTag = "" // example: "rc1" ) diff --git a/mediaapi/mediaapi.go b/mediaapi/mediaapi.go index 4792c996d..284071a53 100644 --- a/mediaapi/mediaapi.go +++ b/mediaapi/mediaapi.go @@ -15,29 +15,30 @@ package mediaapi import ( + "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/mediaapi/routing" "github.com/matrix-org/dendrite/mediaapi/storage" - "github.com/matrix-org/dendrite/setup/base" + "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" "github.com/sirupsen/logrus" ) // AddPublicRoutes sets up and registers HTTP handlers for the MediaAPI component. func AddPublicRoutes( - base *base.BaseDendrite, + mediaRouter *mux.Router, + cm sqlutil.Connections, + cfg *config.Dendrite, userAPI userapi.MediaUserAPI, - client *gomatrixserverlib.Client, + client *fclient.Client, ) { - cfg := &base.Cfg.MediaAPI - rateCfg := &base.Cfg.ClientAPI.RateLimiting - - mediaDB, err := storage.NewMediaAPIDatasource(base, &cfg.Database) + mediaDB, err := storage.NewMediaAPIDatasource(cm, &cfg.MediaAPI.Database) if err != nil { logrus.WithError(err).Panicf("failed to connect to media db") } routing.Setup( - base.PublicMediaAPIMux, cfg, rateCfg, mediaDB, userAPI, client, + mediaRouter, cfg, mediaDB, userAPI, client, ) } diff --git a/mediaapi/routing/download.go b/mediaapi/routing/download.go index c9299b1fc..bba24327b 100644 --- a/mediaapi/routing/download.go +++ b/mediaapi/routing/download.go @@ -36,7 +36,8 @@ import ( "github.com/matrix-org/dendrite/mediaapi/thumbnailer" "github.com/matrix-org/dendrite/mediaapi/types" "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/pkg/errors" log "github.com/sirupsen/logrus" @@ -71,11 +72,11 @@ type downloadRequest struct { func Download( w http.ResponseWriter, req *http.Request, - origin gomatrixserverlib.ServerName, + origin spec.ServerName, mediaID types.MediaID, cfg *config.MediaAPI, db storage.Database, - client *gomatrixserverlib.Client, + client *fclient.Client, activeRemoteRequests *types.ActiveRemoteRequests, activeThumbnailGeneration *types.ActiveThumbnailGeneration, isThumbnailRequest bool, @@ -205,7 +206,7 @@ func (r *downloadRequest) doDownload( w http.ResponseWriter, cfg *config.MediaAPI, db storage.Database, - client *gomatrixserverlib.Client, + client *fclient.Client, activeRemoteRequests *types.ActiveRemoteRequests, activeThumbnailGeneration *types.ActiveThumbnailGeneration, ) (*types.MediaMetadata, error) { @@ -513,7 +514,7 @@ func (r *downloadRequest) generateThumbnail( // Note: The named errorResponse return variable is used in a deferred broadcast of the metadata and error response to waiting goroutines. func (r *downloadRequest) getRemoteFile( ctx context.Context, - client *gomatrixserverlib.Client, + client *fclient.Client, cfg *config.MediaAPI, db storage.Database, activeRemoteRequests *types.ActiveRemoteRequests, @@ -615,7 +616,7 @@ func (r *downloadRequest) broadcastMediaMetadata(activeRemoteRequests *types.Act // fetchRemoteFileAndStoreMetadata fetches the file from the remote server and stores its metadata in the database func (r *downloadRequest) fetchRemoteFileAndStoreMetadata( ctx context.Context, - client *gomatrixserverlib.Client, + client *fclient.Client, absBasePath config.Path, maxFileSizeBytes config.FileSizeBytes, db storage.Database, @@ -713,7 +714,7 @@ func (r *downloadRequest) GetContentLengthAndReader(contentLengthHeader string, func (r *downloadRequest) fetchRemoteFile( ctx context.Context, - client *gomatrixserverlib.Client, + client *fclient.Client, absBasePath config.Path, maxFileSizeBytes config.FileSizeBytes, ) (types.Path, bool, error) { diff --git a/mediaapi/routing/routing.go b/mediaapi/routing/routing.go index 50af2f884..e0af4a911 100644 --- a/mediaapi/routing/routing.go +++ b/mediaapi/routing/routing.go @@ -25,7 +25,8 @@ import ( "github.com/matrix-org/dendrite/mediaapi/types" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" @@ -35,7 +36,7 @@ import ( // configResponse is the response to GET /_matrix/media/r0/config // https://matrix.org/docs/spec/client_server/latest#get-matrix-media-r0-config type configResponse struct { - UploadSize *config.FileSizeBytes `json:"m.upload.size"` + UploadSize *config.FileSizeBytes `json:"m.upload.size,omitempty"` } // Setup registers the media API HTTP handlers @@ -45,13 +46,12 @@ type configResponse struct { // nolint: gocyclo func Setup( publicAPIMux *mux.Router, - cfg *config.MediaAPI, - rateLimit *config.RateLimiting, + cfg *config.Dendrite, db storage.Database, userAPI userapi.MediaUserAPI, - client *gomatrixserverlib.Client, + client *fclient.Client, ) { - rateLimits := httputil.NewRateLimits(rateLimit) + rateLimits := httputil.NewRateLimits(&cfg.ClientAPI.RateLimiting) v3mux := publicAPIMux.PathPrefix("/{apiversion:(?:r0|v1|v3)}/").Subrouter() @@ -65,7 +65,7 @@ func Setup( if r := rateLimits.Limit(req, dev); r != nil { return *r } - return Upload(req, cfg, dev, db, activeThumbnailGeneration) + return Upload(req, &cfg.MediaAPI, dev, db, activeThumbnailGeneration) }, ) @@ -73,8 +73,8 @@ func Setup( if r := rateLimits.Limit(req, device); r != nil { return *r } - respondSize := &cfg.MaxFileSizeBytes - if cfg.MaxFileSizeBytes == 0 { + respondSize := &cfg.MediaAPI.MaxFileSizeBytes + if cfg.MediaAPI.MaxFileSizeBytes == 0 { respondSize = nil } return util.JSONResponse{ @@ -90,12 +90,12 @@ func Setup( MXCToResult: map[string]*types.RemoteRequestResult{}, } - downloadHandler := makeDownloadAPI("download", cfg, rateLimits, db, client, activeRemoteRequests, activeThumbnailGeneration) + downloadHandler := makeDownloadAPI("download", &cfg.MediaAPI, rateLimits, db, client, activeRemoteRequests, activeThumbnailGeneration) v3mux.Handle("/download/{serverName}/{mediaId}", downloadHandler).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/download/{serverName}/{mediaId}/{downloadName}", downloadHandler).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/thumbnail/{serverName}/{mediaId}", - makeDownloadAPI("thumbnail", cfg, rateLimits, db, client, activeRemoteRequests, activeThumbnailGeneration), + makeDownloadAPI("thumbnail", &cfg.MediaAPI, rateLimits, db, client, activeRemoteRequests, activeThumbnailGeneration), ).Methods(http.MethodGet, http.MethodOptions) } @@ -104,7 +104,7 @@ func makeDownloadAPI( cfg *config.MediaAPI, rateLimits *httputil.RateLimits, db storage.Database, - client *gomatrixserverlib.Client, + client *fclient.Client, activeRemoteRequests *types.ActiveRemoteRequests, activeThumbnailGeneration *types.ActiveThumbnailGeneration, ) http.HandlerFunc { @@ -140,7 +140,7 @@ func makeDownloadAPI( } vars, _ := httputil.URLDecodeMapValues(mux.Vars(req)) - serverName := gomatrixserverlib.ServerName(vars["serverName"]) + serverName := spec.ServerName(vars["serverName"]) // For the purposes of loop avoidance, we will return a 404 if allow_remote is set to // false in the query string and the target server name isn't our own. diff --git a/mediaapi/routing/upload_test.go b/mediaapi/routing/upload_test.go index 420d0eba9..d088950ca 100644 --- a/mediaapi/routing/upload_test.go +++ b/mediaapi/routing/upload_test.go @@ -9,6 +9,7 @@ import ( "strings" "testing" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/mediaapi/fileutils" "github.com/matrix-org/dendrite/mediaapi/storage" "github.com/matrix-org/dendrite/mediaapi/types" @@ -49,8 +50,8 @@ func Test_uploadRequest_doUpload(t *testing.T) { // create testdata folder and remove when done _ = os.Mkdir(testdataPath, os.ModePerm) defer fileutils.RemoveDir(types.Path(testdataPath), nil) - - db, err := storage.NewMediaAPIDatasource(nil, &config.DatabaseOptions{ + cm := sqlutil.NewConnectionManager(nil, config.DatabaseOptions{}) + db, err := storage.NewMediaAPIDatasource(cm, &config.DatabaseOptions{ ConnectionString: "file::memory:?cache=shared", MaxOpenConnections: 100, MaxIdleConnections: 2, diff --git a/mediaapi/storage/interface.go b/mediaapi/storage/interface.go index d083be1eb..cf3e7df57 100644 --- a/mediaapi/storage/interface.go +++ b/mediaapi/storage/interface.go @@ -18,7 +18,7 @@ import ( "context" "github.com/matrix-org/dendrite/mediaapi/types" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) type Database interface { @@ -28,12 +28,12 @@ type Database interface { type MediaRepository interface { StoreMediaMetadata(ctx context.Context, mediaMetadata *types.MediaMetadata) error - GetMediaMetadata(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error) - GetMediaMetadataByHash(ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error) + GetMediaMetadata(ctx context.Context, mediaID types.MediaID, mediaOrigin spec.ServerName) (*types.MediaMetadata, error) + GetMediaMetadataByHash(ctx context.Context, mediaHash types.Base64Hash, mediaOrigin spec.ServerName) (*types.MediaMetadata, error) } type Thumbnails interface { StoreThumbnail(ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata) error - GetThumbnail(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, width, height int, resizeMethod string) (*types.ThumbnailMetadata, error) - GetThumbnails(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) ([]*types.ThumbnailMetadata, error) + GetThumbnail(ctx context.Context, mediaID types.MediaID, mediaOrigin spec.ServerName, width, height int, resizeMethod string) (*types.ThumbnailMetadata, error) + GetThumbnails(ctx context.Context, mediaID types.MediaID, mediaOrigin spec.ServerName) ([]*types.ThumbnailMetadata, error) } diff --git a/mediaapi/storage/postgres/media_repository_table.go b/mediaapi/storage/postgres/media_repository_table.go index 41cee4878..0583dd017 100644 --- a/mediaapi/storage/postgres/media_repository_table.go +++ b/mediaapi/storage/postgres/media_repository_table.go @@ -23,7 +23,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/mediaapi/storage/tables" "github.com/matrix-org/dendrite/mediaapi/types" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const mediaSchema = ` @@ -88,7 +88,7 @@ func NewPostgresMediaRepositoryTable(db *sql.DB) (tables.MediaRepository, error) func (s *mediaStatements) InsertMedia( ctx context.Context, txn *sql.Tx, mediaMetadata *types.MediaMetadata, ) error { - mediaMetadata.CreationTimestamp = gomatrixserverlib.AsTimestamp(time.Now()) + mediaMetadata.CreationTimestamp = spec.AsTimestamp(time.Now()) _, err := sqlutil.TxStmtContext(ctx, txn, s.insertMediaStmt).ExecContext( ctx, mediaMetadata.MediaID, @@ -104,7 +104,7 @@ func (s *mediaStatements) InsertMedia( } func (s *mediaStatements) SelectMedia( - ctx context.Context, txn *sql.Tx, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, + ctx context.Context, txn *sql.Tx, mediaID types.MediaID, mediaOrigin spec.ServerName, ) (*types.MediaMetadata, error) { mediaMetadata := types.MediaMetadata{ MediaID: mediaID, @@ -124,7 +124,7 @@ func (s *mediaStatements) SelectMedia( } func (s *mediaStatements) SelectMediaByHash( - ctx context.Context, txn *sql.Tx, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName, + ctx context.Context, txn *sql.Tx, mediaHash types.Base64Hash, mediaOrigin spec.ServerName, ) (*types.MediaMetadata, error) { mediaMetadata := types.MediaMetadata{ Base64Hash: mediaHash, diff --git a/mediaapi/storage/postgres/mediaapi.go b/mediaapi/storage/postgres/mediaapi.go index 30ec64f84..5b6687743 100644 --- a/mediaapi/storage/postgres/mediaapi.go +++ b/mediaapi/storage/postgres/mediaapi.go @@ -20,13 +20,12 @@ import ( _ "github.com/lib/pq" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/mediaapi/storage/shared" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" ) // NewDatabase opens a postgres database. -func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*shared.Database, error) { - db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()) +func NewDatabase(conMan sqlutil.Connections, dbProperties *config.DatabaseOptions) (*shared.Database, error) { + db, writer, err := conMan.Connection(dbProperties) if err != nil { return nil, err } diff --git a/mediaapi/storage/postgres/thumbnail_table.go b/mediaapi/storage/postgres/thumbnail_table.go index 7e07b476e..854485528 100644 --- a/mediaapi/storage/postgres/thumbnail_table.go +++ b/mediaapi/storage/postgres/thumbnail_table.go @@ -24,7 +24,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/mediaapi/storage/tables" "github.com/matrix-org/dendrite/mediaapi/types" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const thumbnailSchema = ` @@ -91,7 +91,7 @@ func NewPostgresThumbnailsTable(db *sql.DB) (tables.Thumbnails, error) { func (s *thumbnailStatements) InsertThumbnail( ctx context.Context, txn *sql.Tx, thumbnailMetadata *types.ThumbnailMetadata, ) error { - thumbnailMetadata.MediaMetadata.CreationTimestamp = gomatrixserverlib.AsTimestamp(time.Now()) + thumbnailMetadata.MediaMetadata.CreationTimestamp = spec.AsTimestamp(time.Now()) _, err := sqlutil.TxStmtContext(ctx, txn, s.insertThumbnailStmt).ExecContext( ctx, thumbnailMetadata.MediaMetadata.MediaID, @@ -110,7 +110,7 @@ func (s *thumbnailStatements) SelectThumbnail( ctx context.Context, txn *sql.Tx, mediaID types.MediaID, - mediaOrigin gomatrixserverlib.ServerName, + mediaOrigin spec.ServerName, width, height int, resizeMethod string, ) (*types.ThumbnailMetadata, error) { @@ -141,7 +141,7 @@ func (s *thumbnailStatements) SelectThumbnail( } func (s *thumbnailStatements) SelectThumbnails( - ctx context.Context, txn *sql.Tx, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, + ctx context.Context, txn *sql.Tx, mediaID types.MediaID, mediaOrigin spec.ServerName, ) ([]*types.ThumbnailMetadata, error) { rows, err := sqlutil.TxStmtContext(ctx, txn, s.selectThumbnailsStmt).QueryContext( ctx, mediaID, mediaOrigin, diff --git a/mediaapi/storage/shared/mediaapi.go b/mediaapi/storage/shared/mediaapi.go index c8d9ad6ab..867405fb3 100644 --- a/mediaapi/storage/shared/mediaapi.go +++ b/mediaapi/storage/shared/mediaapi.go @@ -21,7 +21,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/mediaapi/storage/tables" "github.com/matrix-org/dendrite/mediaapi/types" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) type Database struct { @@ -42,7 +42,7 @@ func (d Database) StoreMediaMetadata(ctx context.Context, mediaMetadata *types.M // GetMediaMetadata returns metadata about media stored on this server. // The media could have been uploaded to this server or fetched from another server and cached here. // Returns nil metadata if there is no metadata associated with this media. -func (d Database) GetMediaMetadata(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error) { +func (d Database) GetMediaMetadata(ctx context.Context, mediaID types.MediaID, mediaOrigin spec.ServerName) (*types.MediaMetadata, error) { mediaMetadata, err := d.MediaRepository.SelectMedia(ctx, nil, mediaID, mediaOrigin) if err != nil && err == sql.ErrNoRows { return nil, nil @@ -53,7 +53,7 @@ func (d Database) GetMediaMetadata(ctx context.Context, mediaID types.MediaID, m // GetMediaMetadataByHash returns metadata about media stored on this server. // The media could have been uploaded to this server or fetched from another server and cached here. // Returns nil metadata if there is no metadata associated with this media. -func (d Database) GetMediaMetadataByHash(ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error) { +func (d Database) GetMediaMetadataByHash(ctx context.Context, mediaHash types.Base64Hash, mediaOrigin spec.ServerName) (*types.MediaMetadata, error) { mediaMetadata, err := d.MediaRepository.SelectMediaByHash(ctx, nil, mediaHash, mediaOrigin) if err != nil && err == sql.ErrNoRows { return nil, nil @@ -72,7 +72,7 @@ func (d Database) StoreThumbnail(ctx context.Context, thumbnailMetadata *types.T // GetThumbnail returns metadata about a specific thumbnail. // The media could have been uploaded to this server or fetched from another server and cached here. // Returns nil metadata if there is no metadata associated with this thumbnail. -func (d Database) GetThumbnail(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, width, height int, resizeMethod string) (*types.ThumbnailMetadata, error) { +func (d Database) GetThumbnail(ctx context.Context, mediaID types.MediaID, mediaOrigin spec.ServerName, width, height int, resizeMethod string) (*types.ThumbnailMetadata, error) { metadata, err := d.Thumbnails.SelectThumbnail(ctx, nil, mediaID, mediaOrigin, width, height, resizeMethod) if err != nil { if err == sql.ErrNoRows { @@ -86,7 +86,7 @@ func (d Database) GetThumbnail(ctx context.Context, mediaID types.MediaID, media // GetThumbnails returns metadata about all thumbnails for a specific media stored on this server. // The media could have been uploaded to this server or fetched from another server and cached here. // Returns nil metadata if there are no thumbnails associated with this media. -func (d Database) GetThumbnails(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) ([]*types.ThumbnailMetadata, error) { +func (d Database) GetThumbnails(ctx context.Context, mediaID types.MediaID, mediaOrigin spec.ServerName) ([]*types.ThumbnailMetadata, error) { metadatas, err := d.Thumbnails.SelectThumbnails(ctx, nil, mediaID, mediaOrigin) if err != nil { if err == sql.ErrNoRows { diff --git a/mediaapi/storage/sqlite3/media_repository_table.go b/mediaapi/storage/sqlite3/media_repository_table.go index 78431967f..625688cd7 100644 --- a/mediaapi/storage/sqlite3/media_repository_table.go +++ b/mediaapi/storage/sqlite3/media_repository_table.go @@ -23,7 +23,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/mediaapi/storage/tables" "github.com/matrix-org/dendrite/mediaapi/types" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const mediaSchema = ` @@ -91,7 +91,7 @@ func NewSQLiteMediaRepositoryTable(db *sql.DB) (tables.MediaRepository, error) { func (s *mediaStatements) InsertMedia( ctx context.Context, txn *sql.Tx, mediaMetadata *types.MediaMetadata, ) error { - mediaMetadata.CreationTimestamp = gomatrixserverlib.AsTimestamp(time.Now()) + mediaMetadata.CreationTimestamp = spec.AsTimestamp(time.Now()) _, err := sqlutil.TxStmtContext(ctx, txn, s.insertMediaStmt).ExecContext( ctx, mediaMetadata.MediaID, @@ -107,7 +107,7 @@ func (s *mediaStatements) InsertMedia( } func (s *mediaStatements) SelectMedia( - ctx context.Context, txn *sql.Tx, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, + ctx context.Context, txn *sql.Tx, mediaID types.MediaID, mediaOrigin spec.ServerName, ) (*types.MediaMetadata, error) { mediaMetadata := types.MediaMetadata{ MediaID: mediaID, @@ -127,7 +127,7 @@ func (s *mediaStatements) SelectMedia( } func (s *mediaStatements) SelectMediaByHash( - ctx context.Context, txn *sql.Tx, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName, + ctx context.Context, txn *sql.Tx, mediaHash types.Base64Hash, mediaOrigin spec.ServerName, ) (*types.MediaMetadata, error) { mediaMetadata := types.MediaMetadata{ Base64Hash: mediaHash, diff --git a/mediaapi/storage/sqlite3/mediaapi.go b/mediaapi/storage/sqlite3/mediaapi.go index c0ab10e9f..4d484f326 100644 --- a/mediaapi/storage/sqlite3/mediaapi.go +++ b/mediaapi/storage/sqlite3/mediaapi.go @@ -19,13 +19,12 @@ import ( // Import the postgres database driver. "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/mediaapi/storage/shared" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" ) // NewDatabase opens a SQLIte database. -func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*shared.Database, error) { - db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()) +func NewDatabase(conMan sqlutil.Connections, dbProperties *config.DatabaseOptions) (*shared.Database, error) { + db, writer, err := conMan.Connection(dbProperties) if err != nil { return nil, err } diff --git a/mediaapi/storage/sqlite3/thumbnail_table.go b/mediaapi/storage/sqlite3/thumbnail_table.go index 5ff2fece0..259d55b73 100644 --- a/mediaapi/storage/sqlite3/thumbnail_table.go +++ b/mediaapi/storage/sqlite3/thumbnail_table.go @@ -24,7 +24,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/mediaapi/storage/tables" "github.com/matrix-org/dendrite/mediaapi/types" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const thumbnailSchema = ` @@ -79,7 +79,7 @@ func NewSQLiteThumbnailsTable(db *sql.DB) (tables.Thumbnails, error) { } func (s *thumbnailStatements) InsertThumbnail(ctx context.Context, txn *sql.Tx, thumbnailMetadata *types.ThumbnailMetadata) error { - thumbnailMetadata.MediaMetadata.CreationTimestamp = gomatrixserverlib.AsTimestamp(time.Now()) + thumbnailMetadata.MediaMetadata.CreationTimestamp = spec.AsTimestamp(time.Now()) _, err := sqlutil.TxStmtContext(ctx, txn, s.insertThumbnailStmt).ExecContext( ctx, thumbnailMetadata.MediaMetadata.MediaID, @@ -98,7 +98,7 @@ func (s *thumbnailStatements) SelectThumbnail( ctx context.Context, txn *sql.Tx, mediaID types.MediaID, - mediaOrigin gomatrixserverlib.ServerName, + mediaOrigin spec.ServerName, width, height int, resizeMethod string, ) (*types.ThumbnailMetadata, error) { @@ -130,7 +130,7 @@ func (s *thumbnailStatements) SelectThumbnail( func (s *thumbnailStatements) SelectThumbnails( ctx context.Context, txn *sql.Tx, mediaID types.MediaID, - mediaOrigin gomatrixserverlib.ServerName, + mediaOrigin spec.ServerName, ) ([]*types.ThumbnailMetadata, error) { rows, err := sqlutil.TxStmtContext(ctx, txn, s.selectThumbnailsStmt).QueryContext( ctx, mediaID, mediaOrigin, diff --git a/mediaapi/storage/storage.go b/mediaapi/storage/storage.go index f673ae7e6..8e67af9f9 100644 --- a/mediaapi/storage/storage.go +++ b/mediaapi/storage/storage.go @@ -20,19 +20,19 @@ package storage import ( "fmt" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/mediaapi/storage/postgres" "github.com/matrix-org/dendrite/mediaapi/storage/sqlite3" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" ) // NewMediaAPIDatasource opens a database connection. -func NewMediaAPIDatasource(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (Database, error) { +func NewMediaAPIDatasource(conMan sqlutil.Connections, dbProperties *config.DatabaseOptions) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(base, dbProperties) + return sqlite3.NewDatabase(conMan, dbProperties) case dbProperties.ConnectionString.IsPostgres(): - return postgres.NewDatabase(base, dbProperties) + return postgres.NewDatabase(conMan, dbProperties) default: return nil, fmt.Errorf("unexpected database type") } diff --git a/mediaapi/storage/storage_test.go b/mediaapi/storage/storage_test.go index 81f0a5d24..8cd29a54d 100644 --- a/mediaapi/storage/storage_test.go +++ b/mediaapi/storage/storage_test.go @@ -5,6 +5,7 @@ import ( "reflect" "testing" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/mediaapi/storage" "github.com/matrix-org/dendrite/mediaapi/types" "github.com/matrix-org/dendrite/setup/config" @@ -13,7 +14,8 @@ import ( func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { connStr, close := test.PrepareDBConnectionString(t, dbType) - db, err := storage.NewMediaAPIDatasource(nil, &config.DatabaseOptions{ + cm := sqlutil.NewConnectionManager(nil, config.DatabaseOptions{}) + db, err := storage.NewMediaAPIDatasource(cm, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }) if err != nil { diff --git a/mediaapi/storage/storage_wasm.go b/mediaapi/storage/storage_wasm.go index 41e4a28c0..47ee3792c 100644 --- a/mediaapi/storage/storage_wasm.go +++ b/mediaapi/storage/storage_wasm.go @@ -17,16 +17,16 @@ package storage import ( "fmt" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/mediaapi/storage/sqlite3" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" ) // Open opens a postgres database. -func NewMediaAPIDatasource(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (Database, error) { +func NewMediaAPIDatasource(conMan sqlutil.Connections, dbProperties *config.DatabaseOptions) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(base, dbProperties) + return sqlite3.NewDatabase(conMan, dbProperties) case dbProperties.ConnectionString.IsPostgres(): return nil, fmt.Errorf("can't use Postgres implementation") default: diff --git a/mediaapi/storage/tables/interface.go b/mediaapi/storage/tables/interface.go index bf63bc6ab..2ff8039b4 100644 --- a/mediaapi/storage/tables/interface.go +++ b/mediaapi/storage/tables/interface.go @@ -19,28 +19,28 @@ import ( "database/sql" "github.com/matrix-org/dendrite/mediaapi/types" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) type Thumbnails interface { InsertThumbnail(ctx context.Context, txn *sql.Tx, thumbnailMetadata *types.ThumbnailMetadata) error SelectThumbnail( ctx context.Context, txn *sql.Tx, - mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, + mediaID types.MediaID, mediaOrigin spec.ServerName, width, height int, resizeMethod string, ) (*types.ThumbnailMetadata, error) SelectThumbnails( ctx context.Context, txn *sql.Tx, mediaID types.MediaID, - mediaOrigin gomatrixserverlib.ServerName, + mediaOrigin spec.ServerName, ) ([]*types.ThumbnailMetadata, error) } type MediaRepository interface { InsertMedia(ctx context.Context, txn *sql.Tx, mediaMetadata *types.MediaMetadata) error - SelectMedia(ctx context.Context, txn *sql.Tx, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error) + SelectMedia(ctx context.Context, txn *sql.Tx, mediaID types.MediaID, mediaOrigin spec.ServerName) (*types.MediaMetadata, error) SelectMediaByHash( ctx context.Context, txn *sql.Tx, - mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName, + mediaHash types.Base64Hash, mediaOrigin spec.ServerName, ) (*types.MediaMetadata, error) } diff --git a/mediaapi/types/types.go b/mediaapi/types/types.go index ab28b3410..e1c29e0f6 100644 --- a/mediaapi/types/types.go +++ b/mediaapi/types/types.go @@ -18,7 +18,7 @@ import ( "sync" "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) // FileSizeBytes is a file size in bytes @@ -48,10 +48,10 @@ type MatrixUserID string // MediaMetadata is metadata associated with a media file type MediaMetadata struct { MediaID MediaID - Origin gomatrixserverlib.ServerName + Origin spec.ServerName ContentType ContentType FileSizeBytes FileSizeBytes - CreationTimestamp gomatrixserverlib.Timestamp + CreationTimestamp spec.Timestamp UploadName Filename Base64Hash Base64Hash UserID MatrixUserID diff --git a/relayapi/api/api.go b/relayapi/api/api.go index 9b4b62e58..83ff2890b 100644 --- a/relayapi/api/api.go +++ b/relayapi/api/api.go @@ -18,6 +18,8 @@ import ( "context" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" ) // RelayInternalAPI is used to query information from the relay server. @@ -27,8 +29,8 @@ type RelayInternalAPI interface { // Retrieve from external relay server all transactions stored for us and process them. PerformRelayServerSync( ctx context.Context, - userID gomatrixserverlib.UserID, - relayServer gomatrixserverlib.ServerName, + userID spec.UserID, + relayServer spec.ServerName, ) error // Tells the relayapi whether or not it should act as a relay server for external servers. @@ -44,14 +46,14 @@ type RelayServerAPI interface { PerformStoreTransaction( ctx context.Context, transaction gomatrixserverlib.Transaction, - userID gomatrixserverlib.UserID, + userID spec.UserID, ) error // Obtain the oldest stored transaction for the specified userID. QueryTransactions( ctx context.Context, - userID gomatrixserverlib.UserID, - previousEntry gomatrixserverlib.RelayEntry, + userID spec.UserID, + previousEntry fclient.RelayEntry, ) (QueryRelayTransactionsResponse, error) } diff --git a/relayapi/internal/api.go b/relayapi/internal/api.go index 3a5762fb8..55e86aef9 100644 --- a/relayapi/internal/api.go +++ b/relayapi/internal/api.go @@ -22,6 +22,7 @@ import ( "github.com/matrix-org/dendrite/relayapi/storage" rsAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) type RelayInternalAPI struct { @@ -31,7 +32,7 @@ type RelayInternalAPI struct { keyRing *gomatrixserverlib.KeyRing producer *producers.SyncAPIProducer presenceEnabledInbound bool - serverName gomatrixserverlib.ServerName + serverName spec.ServerName relayingEnabledMutex sync.Mutex relayingEnabled bool } @@ -43,7 +44,7 @@ func NewRelayInternalAPI( keyRing *gomatrixserverlib.KeyRing, producer *producers.SyncAPIProducer, presenceEnabledInbound bool, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, relayingEnabled bool, ) *RelayInternalAPI { return &RelayInternalAPI{ diff --git a/relayapi/internal/perform.go b/relayapi/internal/perform.go index 62c7d446e..457652112 100644 --- a/relayapi/internal/perform.go +++ b/relayapi/internal/perform.go @@ -21,6 +21,8 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/relayapi/api" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/sirupsen/logrus" ) @@ -41,12 +43,12 @@ func (r *RelayInternalAPI) RelayingEnabled() bool { // PerformRelayServerSync implements api.RelayInternalAPI func (r *RelayInternalAPI) PerformRelayServerSync( ctx context.Context, - userID gomatrixserverlib.UserID, - relayServer gomatrixserverlib.ServerName, + userID spec.UserID, + relayServer spec.ServerName, ) error { // Providing a default RelayEntry (EntryID = 0) is done to ask the relay if there are any // transactions available for this node. - prevEntry := gomatrixserverlib.RelayEntry{} + prevEntry := fclient.RelayEntry{} asyncResponse, err := r.fedClient.P2PGetTransactionFromRelay(ctx, userID, prevEntry, relayServer) if err != nil { logrus.Errorf("P2PGetTransactionFromRelay: %s", err.Error()) @@ -54,12 +56,12 @@ func (r *RelayInternalAPI) PerformRelayServerSync( } r.processTransaction(&asyncResponse.Transaction) - prevEntry = gomatrixserverlib.RelayEntry{EntryID: asyncResponse.EntryID} + prevEntry = fclient.RelayEntry{EntryID: asyncResponse.EntryID} for asyncResponse.EntriesQueued { // There are still more entries available for this node from the relay. logrus.Infof("Retrieving next entry from relay, previous: %v", prevEntry) asyncResponse, err = r.fedClient.P2PGetTransactionFromRelay(ctx, userID, prevEntry, relayServer) - prevEntry = gomatrixserverlib.RelayEntry{EntryID: asyncResponse.EntryID} + prevEntry = fclient.RelayEntry{EntryID: asyncResponse.EntryID} if err != nil { logrus.Errorf("P2PGetTransactionFromRelay: %s", err.Error()) return err @@ -74,7 +76,7 @@ func (r *RelayInternalAPI) PerformRelayServerSync( func (r *RelayInternalAPI) PerformStoreTransaction( ctx context.Context, transaction gomatrixserverlib.Transaction, - userID gomatrixserverlib.UserID, + userID spec.UserID, ) error { logrus.Warnf("Storing transaction for %v", userID) receipt, err := r.db.StoreTransaction(ctx, transaction) @@ -84,7 +86,7 @@ func (r *RelayInternalAPI) PerformStoreTransaction( } err = r.db.AssociateTransactionWithDestinations( ctx, - map[gomatrixserverlib.UserID]struct{}{ + map[spec.UserID]struct{}{ userID: {}, }, transaction.TransactionID, @@ -96,8 +98,8 @@ func (r *RelayInternalAPI) PerformStoreTransaction( // QueryTransactions implements api.RelayInternalAPI func (r *RelayInternalAPI) QueryTransactions( ctx context.Context, - userID gomatrixserverlib.UserID, - previousEntry gomatrixserverlib.RelayEntry, + userID spec.UserID, + previousEntry fclient.RelayEntry, ) (api.QueryRelayTransactionsResponse, error) { logrus.Infof("QueryTransactions for %s", userID.Raw()) if previousEntry.EntryID > 0 { diff --git a/relayapi/internal/perform_test.go b/relayapi/internal/perform_test.go index 278706a3e..111fb46b0 100644 --- a/relayapi/internal/perform_test.go +++ b/relayapi/internal/perform_test.go @@ -24,6 +24,8 @@ import ( "github.com/matrix-org/dendrite/relayapi/storage/shared" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/stretchr/testify/assert" ) @@ -36,16 +38,16 @@ type testFedClient struct { func (f *testFedClient) P2PGetTransactionFromRelay( ctx context.Context, - u gomatrixserverlib.UserID, - prev gomatrixserverlib.RelayEntry, - relayServer gomatrixserverlib.ServerName, -) (res gomatrixserverlib.RespGetRelayTransaction, err error) { + u spec.UserID, + prev fclient.RelayEntry, + relayServer spec.ServerName, +) (res fclient.RespGetRelayTransaction, err error) { f.queryCount++ if f.shouldFail { return res, fmt.Errorf("Error") } - res = gomatrixserverlib.RespGetRelayTransaction{ + res = fclient.RespGetRelayTransaction{ Transaction: gomatrixserverlib.Transaction{}, EntryID: 0, } @@ -67,7 +69,7 @@ func TestPerformRelayServerSync(t *testing.T) { RelayQueueJSON: testDB, } - userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + userID, err := spec.NewUserID("@local:domain", false) assert.Nil(t, err, "Invalid userID") fedClient := &testFedClient{} @@ -75,7 +77,7 @@ func TestPerformRelayServerSync(t *testing.T) { &db, fedClient, nil, nil, nil, false, "", true, ) - err = relayAPI.PerformRelayServerSync(context.Background(), *userID, gomatrixserverlib.ServerName("relay")) + err = relayAPI.PerformRelayServerSync(context.Background(), *userID, spec.ServerName("relay")) assert.NoError(t, err) } @@ -87,7 +89,7 @@ func TestPerformRelayServerSyncFedError(t *testing.T) { RelayQueueJSON: testDB, } - userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + userID, err := spec.NewUserID("@local:domain", false) assert.Nil(t, err, "Invalid userID") fedClient := &testFedClient{shouldFail: true} @@ -95,7 +97,7 @@ func TestPerformRelayServerSyncFedError(t *testing.T) { &db, fedClient, nil, nil, nil, false, "", true, ) - err = relayAPI.PerformRelayServerSync(context.Background(), *userID, gomatrixserverlib.ServerName("relay")) + err = relayAPI.PerformRelayServerSync(context.Background(), *userID, spec.ServerName("relay")) assert.Error(t, err) } @@ -107,7 +109,7 @@ func TestPerformRelayServerSyncRunsUntilQueueEmpty(t *testing.T) { RelayQueueJSON: testDB, } - userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + userID, err := spec.NewUserID("@local:domain", false) assert.Nil(t, err, "Invalid userID") fedClient := &testFedClient{queueDepth: 2} @@ -115,7 +117,7 @@ func TestPerformRelayServerSyncRunsUntilQueueEmpty(t *testing.T) { &db, fedClient, nil, nil, nil, false, "", true, ) - err = relayAPI.PerformRelayServerSync(context.Background(), *userID, gomatrixserverlib.ServerName("relay")) + err = relayAPI.PerformRelayServerSync(context.Background(), *userID, spec.ServerName("relay")) assert.NoError(t, err) assert.Equal(t, uint(3), fedClient.queryCount) } diff --git a/relayapi/relayapi.go b/relayapi/relayapi.go index 200a1814a..bae6b0cf5 100644 --- a/relayapi/relayapi.go +++ b/relayapi/relayapi.go @@ -16,24 +16,27 @@ package relayapi import ( "github.com/matrix-org/dendrite/federationapi/producers" + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/relayapi/api" "github.com/matrix-org/dendrite/relayapi/internal" "github.com/matrix-org/dendrite/relayapi/routing" "github.com/matrix-org/dendrite/relayapi/storage" rsAPI "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/dendrite/setup/base" + "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" "github.com/sirupsen/logrus" ) // AddPublicRoutes sets up and registers HTTP handlers on the base API muxes for the FederationAPI component. func AddPublicRoutes( - base *base.BaseDendrite, + routers httputil.Routers, + dendriteCfg *config.Dendrite, keyRing gomatrixserverlib.JSONVerifier, relayAPI api.RelayInternalAPI, ) { - fedCfg := &base.Cfg.FederationAPI - relay, ok := relayAPI.(*internal.RelayInternalAPI) if !ok { panic("relayapi.AddPublicRoutes called with a RelayInternalAPI impl which was not " + @@ -41,24 +44,24 @@ func AddPublicRoutes( } routing.Setup( - base.PublicFederationAPIMux, - fedCfg, + routers.Federation, + &dendriteCfg.FederationAPI, relay, keyRing, ) } func NewRelayInternalAPI( - base *base.BaseDendrite, - fedClient *gomatrixserverlib.FederationClient, + dendriteCfg *config.Dendrite, + cm sqlutil.Connections, + fedClient *fclient.FederationClient, rsAPI rsAPI.RoomserverInternalAPI, keyRing *gomatrixserverlib.KeyRing, producer *producers.SyncAPIProducer, relayingEnabled bool, + caches caching.FederationCache, ) api.RelayInternalAPI { - cfg := &base.Cfg.RelayAPI - - relayDB, err := storage.NewDatabase(base, &cfg.Database, base.Caches, base.Cfg.Global.IsLocalServerName) + relayDB, err := storage.NewDatabase(cm, &dendriteCfg.RelayAPI.Database, caches, dendriteCfg.Global.IsLocalServerName) if err != nil { logrus.WithError(err).Panic("failed to connect to relay db") } @@ -69,8 +72,8 @@ func NewRelayInternalAPI( rsAPI, keyRing, producer, - base.Cfg.Global.Presence.EnableInbound, - base.Cfg.Global.ServerName, + dendriteCfg.Global.Presence.EnableInbound, + dendriteCfg.Global.ServerName, relayingEnabled, ) } diff --git a/relayapi/relayapi_test.go b/relayapi/relayapi_test.go index f1b3262aa..9d67cdf95 100644 --- a/relayapi/relayapi_test.go +++ b/relayapi/relayapi_test.go @@ -21,60 +21,67 @@ import ( "net/http" "net/http/httptest" "testing" + "time" "github.com/gorilla/mux" "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/relayapi" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/stretchr/testify/assert" ) func TestCreateNewRelayInternalAPI(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - base, close := testrig.CreateBaseDendrite(t, dbType) + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) defer close() - - relayAPI := relayapi.NewRelayInternalAPI(base, nil, nil, nil, nil, true) + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + relayAPI := relayapi.NewRelayInternalAPI(cfg, cm, nil, nil, nil, nil, true, caches) assert.NotNil(t, relayAPI) }) } func TestCreateRelayInternalInvalidDatabasePanics(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - base, close := testrig.CreateBaseDendrite(t, dbType) + cfg, processCtx, close := testrig.CreateConfig(t, dbType) if dbType == test.DBTypeSQLite { - base.Cfg.RelayAPI.Database.ConnectionString = "file:" + cfg.RelayAPI.Database.ConnectionString = "file:" } else { - base.Cfg.RelayAPI.Database.ConnectionString = "test" + cfg.RelayAPI.Database.ConnectionString = "test" } defer close() - + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) assert.Panics(t, func() { - relayapi.NewRelayInternalAPI(base, nil, nil, nil, nil, true) + relayapi.NewRelayInternalAPI(cfg, cm, nil, nil, nil, nil, true, nil) }) }) } func TestCreateInvalidRelayPublicRoutesPanics(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - base, close := testrig.CreateBaseDendrite(t, dbType) + cfg, _, close := testrig.CreateConfig(t, dbType) defer close() - + routers := httputil.NewRouters() assert.Panics(t, func() { - relayapi.AddPublicRoutes(base, nil, nil) + relayapi.AddPublicRoutes(routers, cfg, nil, nil) }) }) } -func createGetRelayTxnHTTPRequest(serverName gomatrixserverlib.ServerName, userID string) *http.Request { +func createGetRelayTxnHTTPRequest(serverName spec.ServerName, userID string) *http.Request { _, sk, _ := ed25519.GenerateKey(nil) keyID := signing.KeyID pk := sk.Public().(ed25519.PublicKey) - origin := gomatrixserverlib.ServerName(hex.EncodeToString(pk)) - req := gomatrixserverlib.NewFederationRequest("GET", origin, serverName, "/_matrix/federation/v1/relay_txn/"+userID) - content := gomatrixserverlib.RelayEntry{EntryID: 0} + origin := spec.ServerName(hex.EncodeToString(pk)) + req := fclient.NewFederationRequest("GET", origin, serverName, "/_matrix/federation/v1/relay_txn/"+userID) + content := fclient.RelayEntry{EntryID: 0} req.SetContent(content) req.Sign(origin, gomatrixserverlib.KeyID(keyID), sk) httpreq, _ := req.HTTPRequest() @@ -88,12 +95,12 @@ type sendRelayContent struct { EDUs []gomatrixserverlib.EDU `json:"edus"` } -func createSendRelayTxnHTTPRequest(serverName gomatrixserverlib.ServerName, txnID string, userID string) *http.Request { +func createSendRelayTxnHTTPRequest(serverName spec.ServerName, txnID string, userID string) *http.Request { _, sk, _ := ed25519.GenerateKey(nil) keyID := signing.KeyID pk := sk.Public().(ed25519.PublicKey) - origin := gomatrixserverlib.ServerName(hex.EncodeToString(pk)) - req := gomatrixserverlib.NewFederationRequest("PUT", origin, serverName, "/_matrix/federation/v1/send_relay/"+txnID+"/"+userID) + origin := spec.ServerName(hex.EncodeToString(pk)) + req := fclient.NewFederationRequest("PUT", origin, serverName, "/_matrix/federation/v1/send_relay/"+txnID+"/"+userID) content := sendRelayContent{} req.SetContent(content) req.Sign(origin, gomatrixserverlib.KeyID(keyID), sk) @@ -105,15 +112,19 @@ func createSendRelayTxnHTTPRequest(serverName gomatrixserverlib.ServerName, txnI func TestCreateRelayPublicRoutes(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - base, close := testrig.CreateBaseDendrite(t, dbType) + cfg, processCtx, close := testrig.CreateConfig(t, dbType) defer close() + routers := httputil.NewRouters() + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) - relayAPI := relayapi.NewRelayInternalAPI(base, nil, nil, nil, nil, true) + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + + relayAPI := relayapi.NewRelayInternalAPI(cfg, cm, nil, nil, nil, nil, true, caches) assert.NotNil(t, relayAPI) serverKeyAPI := &signing.YggdrasilKeys{} keyRing := serverKeyAPI.KeyRing() - relayapi.AddPublicRoutes(base, keyRing, relayAPI) + relayapi.AddPublicRoutes(routers, cfg, keyRing, relayAPI) testCases := []struct { name string @@ -122,29 +133,29 @@ func TestCreateRelayPublicRoutes(t *testing.T) { }{ { name: "relay_txn invalid user id", - req: createGetRelayTxnHTTPRequest(base.Cfg.Global.ServerName, "user:local"), + req: createGetRelayTxnHTTPRequest(cfg.Global.ServerName, "user:local"), wantCode: 400, }, { name: "relay_txn valid user id", - req: createGetRelayTxnHTTPRequest(base.Cfg.Global.ServerName, "@user:local"), + req: createGetRelayTxnHTTPRequest(cfg.Global.ServerName, "@user:local"), wantCode: 200, }, { name: "send_relay invalid user id", - req: createSendRelayTxnHTTPRequest(base.Cfg.Global.ServerName, "123", "user:local"), + req: createSendRelayTxnHTTPRequest(cfg.Global.ServerName, "123", "user:local"), wantCode: 400, }, { name: "send_relay valid user id", - req: createSendRelayTxnHTTPRequest(base.Cfg.Global.ServerName, "123", "@user:local"), + req: createSendRelayTxnHTTPRequest(cfg.Global.ServerName, "123", "@user:local"), wantCode: 200, }, } for _, tc := range testCases { w := httptest.NewRecorder() - base.PublicFederationAPIMux.ServeHTTP(w, tc.req) + routers.Federation.ServeHTTP(w, tc.req) if w.Code != tc.wantCode { t.Fatalf("%s: got HTTP %d want %d", tc.name, w.Code, tc.wantCode) } @@ -154,15 +165,19 @@ func TestCreateRelayPublicRoutes(t *testing.T) { func TestDisableRelayPublicRoutes(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - base, close := testrig.CreateBaseDendrite(t, dbType) + cfg, processCtx, close := testrig.CreateConfig(t, dbType) defer close() + routers := httputil.NewRouters() + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) - relayAPI := relayapi.NewRelayInternalAPI(base, nil, nil, nil, nil, false) + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + + relayAPI := relayapi.NewRelayInternalAPI(cfg, cm, nil, nil, nil, nil, false, caches) assert.NotNil(t, relayAPI) serverKeyAPI := &signing.YggdrasilKeys{} keyRing := serverKeyAPI.KeyRing() - relayapi.AddPublicRoutes(base, keyRing, relayAPI) + relayapi.AddPublicRoutes(routers, cfg, keyRing, relayAPI) testCases := []struct { name string @@ -171,19 +186,19 @@ func TestDisableRelayPublicRoutes(t *testing.T) { }{ { name: "relay_txn valid user id", - req: createGetRelayTxnHTTPRequest(base.Cfg.Global.ServerName, "@user:local"), + req: createGetRelayTxnHTTPRequest(cfg.Global.ServerName, "@user:local"), wantCode: 404, }, { name: "send_relay valid user id", - req: createSendRelayTxnHTTPRequest(base.Cfg.Global.ServerName, "123", "@user:local"), + req: createSendRelayTxnHTTPRequest(cfg.Global.ServerName, "123", "@user:local"), wantCode: 404, }, } for _, tc := range testCases { w := httptest.NewRecorder() - base.PublicFederationAPIMux.ServeHTTP(w, tc.req) + routers.Federation.ServeHTTP(w, tc.req) if w.Code != tc.wantCode { t.Fatalf("%s: got HTTP %d want %d", tc.name, w.Code, tc.wantCode) } diff --git a/relayapi/routing/relaytxn.go b/relayapi/routing/relaytxn.go index 63b42ec7d..2fc61373a 100644 --- a/relayapi/routing/relaytxn.go +++ b/relayapi/routing/relaytxn.go @@ -20,7 +20,8 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/relayapi/api" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/sirupsen/logrus" ) @@ -29,13 +30,13 @@ import ( // This endpoint can be extracted into a separate relay server service. func GetTransactionFromRelay( httpReq *http.Request, - fedReq *gomatrixserverlib.FederationRequest, + fedReq *fclient.FederationRequest, relayAPI api.RelayInternalAPI, - userID gomatrixserverlib.UserID, + userID spec.UserID, ) util.JSONResponse { logrus.Infof("Processing relay_txn for %s", userID.Raw()) - var previousEntry gomatrixserverlib.RelayEntry + var previousEntry fclient.RelayEntry if err := json.Unmarshal(fedReq.Content(), &previousEntry); err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, @@ -59,7 +60,7 @@ func GetTransactionFromRelay( return util.JSONResponse{ Code: http.StatusOK, - JSON: gomatrixserverlib.RespGetRelayTransaction{ + JSON: fclient.RespGetRelayTransaction{ Transaction: response.Transaction, EntryID: response.EntryID, EntriesQueued: response.EntriesQueued, diff --git a/relayapi/routing/relaytxn_test.go b/relayapi/routing/relaytxn_test.go index 4c099a642..e6d2d9e5d 100644 --- a/relayapi/routing/relaytxn_test.go +++ b/relayapi/routing/relaytxn_test.go @@ -25,16 +25,18 @@ import ( "github.com/matrix-org/dendrite/relayapi/storage/shared" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/stretchr/testify/assert" ) func createQuery( - userID gomatrixserverlib.UserID, - prevEntry gomatrixserverlib.RelayEntry, -) gomatrixserverlib.FederationRequest { + userID spec.UserID, + prevEntry fclient.RelayEntry, +) fclient.FederationRequest { var federationPathPrefixV1 = "/_matrix/federation/v1" path := federationPathPrefixV1 + "/relay_txn/" + userID.Raw() - request := gomatrixserverlib.NewFederationRequest("GET", userID.Domain(), "relay", path) + request := fclient.NewFederationRequest("GET", userID.Domain(), "relay", path) request.SetContent(prevEntry) return request @@ -48,7 +50,7 @@ func TestGetEmptyDatabaseReturnsNothing(t *testing.T) { RelayQueueJSON: testDB, } httpReq := &http.Request{} - userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + userID, err := spec.NewUserID("@local:domain", false) assert.NoError(t, err, "Invalid userID") transaction := createTransaction() @@ -60,11 +62,11 @@ func TestGetEmptyDatabaseReturnsNothing(t *testing.T) { &db, nil, nil, nil, nil, false, "", true, ) - request := createQuery(*userID, gomatrixserverlib.RelayEntry{}) + request := createQuery(*userID, fclient.RelayEntry{}) response := routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID) assert.Equal(t, http.StatusOK, response.Code) - jsonResponse := response.JSON.(gomatrixserverlib.RespGetRelayTransaction) + jsonResponse := response.JSON.(fclient.RespGetRelayTransaction) assert.Equal(t, false, jsonResponse.EntriesQueued) assert.Equal(t, gomatrixserverlib.Transaction{}, jsonResponse.Transaction) @@ -81,7 +83,7 @@ func TestGetInvalidPrevEntryFails(t *testing.T) { RelayQueueJSON: testDB, } httpReq := &http.Request{} - userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + userID, err := spec.NewUserID("@local:domain", false) assert.NoError(t, err, "Invalid userID") transaction := createTransaction() @@ -93,7 +95,7 @@ func TestGetInvalidPrevEntryFails(t *testing.T) { &db, nil, nil, nil, nil, false, "", true, ) - request := createQuery(*userID, gomatrixserverlib.RelayEntry{EntryID: -1}) + request := createQuery(*userID, fclient.RelayEntry{EntryID: -1}) response := routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID) assert.Equal(t, http.StatusInternalServerError, response.Code) } @@ -106,7 +108,7 @@ func TestGetReturnsSavedTransaction(t *testing.T) { RelayQueueJSON: testDB, } httpReq := &http.Request{} - userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + userID, err := spec.NewUserID("@local:domain", false) assert.NoError(t, err, "Invalid userID") transaction := createTransaction() @@ -115,7 +117,7 @@ func TestGetReturnsSavedTransaction(t *testing.T) { err = db.AssociateTransactionWithDestinations( context.Background(), - map[gomatrixserverlib.UserID]struct{}{ + map[spec.UserID]struct{}{ *userID: {}, }, transaction.TransactionID, @@ -126,20 +128,20 @@ func TestGetReturnsSavedTransaction(t *testing.T) { &db, nil, nil, nil, nil, false, "", true, ) - request := createQuery(*userID, gomatrixserverlib.RelayEntry{}) + request := createQuery(*userID, fclient.RelayEntry{}) response := routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID) assert.Equal(t, http.StatusOK, response.Code) - jsonResponse := response.JSON.(gomatrixserverlib.RespGetRelayTransaction) + jsonResponse := response.JSON.(fclient.RespGetRelayTransaction) assert.True(t, jsonResponse.EntriesQueued) assert.Equal(t, transaction, jsonResponse.Transaction) // And once more to clear the queue - request = createQuery(*userID, gomatrixserverlib.RelayEntry{EntryID: jsonResponse.EntryID}) + request = createQuery(*userID, fclient.RelayEntry{EntryID: jsonResponse.EntryID}) response = routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID) assert.Equal(t, http.StatusOK, response.Code) - jsonResponse = response.JSON.(gomatrixserverlib.RespGetRelayTransaction) + jsonResponse = response.JSON.(fclient.RespGetRelayTransaction) assert.False(t, jsonResponse.EntriesQueued) assert.Equal(t, gomatrixserverlib.Transaction{}, jsonResponse.Transaction) @@ -156,7 +158,7 @@ func TestGetReturnsMultipleSavedTransactions(t *testing.T) { RelayQueueJSON: testDB, } httpReq := &http.Request{} - userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + userID, err := spec.NewUserID("@local:domain", false) assert.NoError(t, err, "Invalid userID") transaction := createTransaction() @@ -165,7 +167,7 @@ func TestGetReturnsMultipleSavedTransactions(t *testing.T) { err = db.AssociateTransactionWithDestinations( context.Background(), - map[gomatrixserverlib.UserID]struct{}{ + map[spec.UserID]struct{}{ *userID: {}, }, transaction.TransactionID, @@ -178,7 +180,7 @@ func TestGetReturnsMultipleSavedTransactions(t *testing.T) { err = db.AssociateTransactionWithDestinations( context.Background(), - map[gomatrixserverlib.UserID]struct{}{ + map[spec.UserID]struct{}{ *userID: {}, }, transaction2.TransactionID, @@ -189,28 +191,28 @@ func TestGetReturnsMultipleSavedTransactions(t *testing.T) { &db, nil, nil, nil, nil, false, "", true, ) - request := createQuery(*userID, gomatrixserverlib.RelayEntry{}) + request := createQuery(*userID, fclient.RelayEntry{}) response := routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID) assert.Equal(t, http.StatusOK, response.Code) - jsonResponse := response.JSON.(gomatrixserverlib.RespGetRelayTransaction) + jsonResponse := response.JSON.(fclient.RespGetRelayTransaction) assert.True(t, jsonResponse.EntriesQueued) assert.Equal(t, transaction, jsonResponse.Transaction) - request = createQuery(*userID, gomatrixserverlib.RelayEntry{EntryID: jsonResponse.EntryID}) + request = createQuery(*userID, fclient.RelayEntry{EntryID: jsonResponse.EntryID}) response = routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID) assert.Equal(t, http.StatusOK, response.Code) - jsonResponse = response.JSON.(gomatrixserverlib.RespGetRelayTransaction) + jsonResponse = response.JSON.(fclient.RespGetRelayTransaction) assert.True(t, jsonResponse.EntriesQueued) assert.Equal(t, transaction2, jsonResponse.Transaction) // And once more to clear the queue - request = createQuery(*userID, gomatrixserverlib.RelayEntry{EntryID: jsonResponse.EntryID}) + request = createQuery(*userID, fclient.RelayEntry{EntryID: jsonResponse.EntryID}) response = routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID) assert.Equal(t, http.StatusOK, response.Code) - jsonResponse = response.JSON.(gomatrixserverlib.RespGetRelayTransaction) + jsonResponse = response.JSON.(fclient.RespGetRelayTransaction) assert.False(t, jsonResponse.EntriesQueued) assert.Equal(t, gomatrixserverlib.Transaction{}, jsonResponse.Transaction) diff --git a/relayapi/routing/routing.go b/relayapi/routing/routing.go index 8ee0743ef..6140d0326 100644 --- a/relayapi/routing/routing.go +++ b/relayapi/routing/routing.go @@ -26,6 +26,8 @@ import ( relayInternal "github.com/matrix-org/dendrite/relayapi/internal" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/sirupsen/logrus" ) @@ -44,7 +46,7 @@ func Setup( v1fedmux.Handle("/send_relay/{txnID}/{userID}", MakeRelayAPI( "send_relay_transaction", "", cfg.Matrix.IsLocalServerName, keys, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + func(httpReq *http.Request, request *fclient.FederationRequest, vars map[string]string) util.JSONResponse { logrus.Infof("Handling send_relay from: %s", request.Origin()) if !relayAPI.RelayingEnabled() { logrus.Warnf("Dropping send_relay from: %s", request.Origin()) @@ -53,7 +55,7 @@ func Setup( } } - userID, err := gomatrixserverlib.NewUserID(vars["userID"], false) + userID, err := spec.NewUserID(vars["userID"], false) if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, @@ -69,7 +71,7 @@ func Setup( v1fedmux.Handle("/relay_txn/{userID}", MakeRelayAPI( "get_relay_transaction", "", cfg.Matrix.IsLocalServerName, keys, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + func(httpReq *http.Request, request *fclient.FederationRequest, vars map[string]string) util.JSONResponse { logrus.Infof("Handling relay_txn from: %s", request.Origin()) if !relayAPI.RelayingEnabled() { logrus.Warnf("Dropping relay_txn from: %s", request.Origin()) @@ -78,7 +80,7 @@ func Setup( } } - userID, err := gomatrixserverlib.NewUserID(vars["userID"], false) + userID, err := spec.NewUserID(vars["userID"], false) if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, @@ -92,13 +94,13 @@ func Setup( // MakeRelayAPI makes an http.Handler that checks matrix relay authentication. func MakeRelayAPI( - metricsName string, serverName gomatrixserverlib.ServerName, - isLocalServerName func(gomatrixserverlib.ServerName) bool, + metricsName string, serverName spec.ServerName, + isLocalServerName func(spec.ServerName) bool, keyRing gomatrixserverlib.JSONVerifier, - f func(*http.Request, *gomatrixserverlib.FederationRequest, map[string]string) util.JSONResponse, + f func(*http.Request, *fclient.FederationRequest, map[string]string) util.JSONResponse, ) http.Handler { h := func(req *http.Request) util.JSONResponse { - fedReq, errResp := gomatrixserverlib.VerifyHTTPRequest( + fedReq, errResp := fclient.VerifyHTTPRequest( req, time.Now(), serverName, isLocalServerName, keyRing, ) if fedReq == nil { diff --git a/relayapi/routing/sendrelay.go b/relayapi/routing/sendrelay.go index ce744cb49..e4794dc4d 100644 --- a/relayapi/routing/sendrelay.go +++ b/relayapi/routing/sendrelay.go @@ -21,6 +21,8 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/relayapi/api" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/sirupsen/logrus" ) @@ -29,14 +31,14 @@ import ( // This endpoint can be extracted into a separate relay server service. func SendTransactionToRelay( httpReq *http.Request, - fedReq *gomatrixserverlib.FederationRequest, + fedReq *fclient.FederationRequest, relayAPI api.RelayInternalAPI, txnID gomatrixserverlib.TransactionID, - userID gomatrixserverlib.UserID, + userID spec.UserID, ) util.JSONResponse { logrus.Infof("Processing send_relay for %s", userID.Raw()) - var txnEvents gomatrixserverlib.RelayEvents + var txnEvents fclient.RelayEvents if err := json.Unmarshal(fedReq.Content(), &txnEvents); err != nil { logrus.Info("The request body could not be decoded into valid JSON." + err.Error()) return util.JSONResponse{ diff --git a/relayapi/routing/sendrelay_test.go b/relayapi/routing/sendrelay_test.go index 66594c47c..05dfbe6d1 100644 --- a/relayapi/routing/sendrelay_test.go +++ b/relayapi/routing/sendrelay_test.go @@ -26,11 +26,13 @@ import ( "github.com/matrix-org/dendrite/relayapi/storage/shared" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/stretchr/testify/assert" ) const ( - testOrigin = gomatrixserverlib.ServerName("kaer.morhen") + testOrigin = spec.ServerName("kaer.morhen") ) func createTransaction() gomatrixserverlib.Transaction { @@ -43,15 +45,15 @@ func createTransaction() gomatrixserverlib.Transaction { } func createFederationRequest( - userID gomatrixserverlib.UserID, + userID spec.UserID, txnID gomatrixserverlib.TransactionID, - origin gomatrixserverlib.ServerName, - destination gomatrixserverlib.ServerName, + origin spec.ServerName, + destination spec.ServerName, content interface{}, -) gomatrixserverlib.FederationRequest { +) fclient.FederationRequest { var federationPathPrefixV1 = "/_matrix/federation/v1" path := federationPathPrefixV1 + "/send_relay/" + string(txnID) + "/" + userID.Raw() - request := gomatrixserverlib.NewFederationRequest("PUT", origin, destination, path) + request := fclient.NewFederationRequest("PUT", origin, destination, path) request.SetContent(content) return request @@ -65,7 +67,7 @@ func TestForwardEmptyReturnsOk(t *testing.T) { RelayQueueJSON: testDB, } httpReq := &http.Request{} - userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + userID, err := spec.NewUserID("@local:domain", false) assert.NoError(t, err, "Invalid userID") txn := createTransaction() @@ -88,7 +90,7 @@ func TestForwardBadJSONReturnsError(t *testing.T) { RelayQueueJSON: testDB, } httpReq := &http.Request{} - userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + userID, err := spec.NewUserID("@local:domain", false) assert.NoError(t, err, "Invalid userID") type BadData struct { @@ -117,7 +119,7 @@ func TestForwardTooManyPDUsReturnsError(t *testing.T) { RelayQueueJSON: testDB, } httpReq := &http.Request{} - userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + userID, err := spec.NewUserID("@local:domain", false) assert.NoError(t, err, "Invalid userID") type BadData struct { @@ -151,7 +153,7 @@ func TestForwardTooManyEDUsReturnsError(t *testing.T) { RelayQueueJSON: testDB, } httpReq := &http.Request{} - userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + userID, err := spec.NewUserID("@local:domain", false) assert.NoError(t, err, "Invalid userID") type BadData struct { @@ -161,7 +163,7 @@ func TestForwardTooManyEDUsReturnsError(t *testing.T) { Field: []gomatrixserverlib.EDU{}, } for i := 0; i < 101; i++ { - content.Field = append(content.Field, gomatrixserverlib.EDU{Type: gomatrixserverlib.MTyping}) + content.Field = append(content.Field, gomatrixserverlib.EDU{Type: spec.MTyping}) } assert.Greater(t, len(content.Field), 100) @@ -185,7 +187,7 @@ func TestUniqueTransactionStoredInDatabase(t *testing.T) { RelayQueueJSON: testDB, } httpReq := &http.Request{} - userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + userID, err := spec.NewUserID("@local:domain", false) assert.NoError(t, err, "Invalid userID") txn := createTransaction() diff --git a/relayapi/storage/interface.go b/relayapi/storage/interface.go index f5f9a06e5..bc1722cd9 100644 --- a/relayapi/storage/interface.go +++ b/relayapi/storage/interface.go @@ -19,6 +19,7 @@ import ( "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) type Database interface { @@ -29,19 +30,19 @@ type Database interface { // Adds a new transaction_id: server_name mapping with associated json table nid to the queue // entry table for each provided destination. - AssociateTransactionWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.UserID]struct{}, transactionID gomatrixserverlib.TransactionID, dbReceipt *receipt.Receipt) error + AssociateTransactionWithDestinations(ctx context.Context, destinations map[spec.UserID]struct{}, transactionID gomatrixserverlib.TransactionID, dbReceipt *receipt.Receipt) error // Removes every server_name: receipt pair provided from the queue entries table. // Will then remove every entry for each receipt provided from the queue json table. // If any of the entries don't exist in either table, nothing will happen for that entry and // an error will not be generated. - CleanTransactions(ctx context.Context, userID gomatrixserverlib.UserID, receipts []*receipt.Receipt) error + CleanTransactions(ctx context.Context, userID spec.UserID, receipts []*receipt.Receipt) error // Gets the oldest transaction for the provided server_name. // If no transactions exist, returns nil and no error. - GetTransaction(ctx context.Context, userID gomatrixserverlib.UserID) (*gomatrixserverlib.Transaction, *receipt.Receipt, error) + GetTransaction(ctx context.Context, userID spec.UserID) (*gomatrixserverlib.Transaction, *receipt.Receipt, error) // Gets the number of transactions being stored for the provided server_name. // If the server doesn't exist in the database then 0 is returned with no error. - GetTransactionCount(ctx context.Context, userID gomatrixserverlib.UserID) (int64, error) + GetTransactionCount(ctx context.Context, userID spec.UserID) (int64, error) } diff --git a/relayapi/storage/postgres/relay_queue_table.go b/relayapi/storage/postgres/relay_queue_table.go index e97cf8cc0..5873af760 100644 --- a/relayapi/storage/postgres/relay_queue_table.go +++ b/relayapi/storage/postgres/relay_queue_table.go @@ -22,6 +22,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const relayQueueSchema = ` @@ -90,7 +91,7 @@ func (s *relayQueueStatements) InsertQueueEntry( ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, nid int64, ) error { stmt := sqlutil.TxStmt(txn, s.insertQueueEntryStmt) @@ -106,7 +107,7 @@ func (s *relayQueueStatements) InsertQueueEntry( func (s *relayQueueStatements) DeleteQueueEntries( ctx context.Context, txn *sql.Tx, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, jsonNIDs []int64, ) error { stmt := sqlutil.TxStmt(txn, s.deleteQueueEntriesStmt) @@ -117,7 +118,7 @@ func (s *relayQueueStatements) DeleteQueueEntries( func (s *relayQueueStatements) SelectQueueEntries( ctx context.Context, txn *sql.Tx, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, limit int, ) ([]int64, error) { stmt := sqlutil.TxStmt(txn, s.selectQueueEntriesStmt) @@ -141,7 +142,7 @@ func (s *relayQueueStatements) SelectQueueEntries( func (s *relayQueueStatements) SelectQueueEntryCount( ctx context.Context, txn *sql.Tx, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, ) (int64, error) { var count int64 stmt := sqlutil.TxStmt(txn, s.selectQueueEntryCountStmt) diff --git a/relayapi/storage/postgres/storage.go b/relayapi/storage/postgres/storage.go index 1042beba7..35c08c283 100644 --- a/relayapi/storage/postgres/storage.go +++ b/relayapi/storage/postgres/storage.go @@ -20,9 +20,8 @@ import ( "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/relayapi/storage/shared" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) // Database stores information needed by the relayapi @@ -34,14 +33,14 @@ type Database struct { // NewDatabase opens a new database func NewDatabase( - base *base.BaseDendrite, + conMan sqlutil.Connections, dbProperties *config.DatabaseOptions, cache caching.FederationCache, - isLocalServerName func(gomatrixserverlib.ServerName) bool, + isLocalServerName func(spec.ServerName) bool, ) (*Database, error) { var d Database var err error - if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()); err != nil { + if d.db, d.writer, err = conMan.Connection(dbProperties); err != nil { return nil, err } queue, err := NewPostgresRelayQueueTable(d.db) diff --git a/relayapi/storage/shared/storage.go b/relayapi/storage/shared/storage.go index 0993707bf..fc1f12e74 100644 --- a/relayapi/storage/shared/storage.go +++ b/relayapi/storage/shared/storage.go @@ -25,11 +25,12 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/relayapi/storage/tables" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) type Database struct { DB *sql.DB - IsLocalServerName func(gomatrixserverlib.ServerName) bool + IsLocalServerName func(spec.ServerName) bool Cache caching.FederationCache Writer sqlutil.Writer RelayQueue tables.RelayQueue @@ -61,7 +62,7 @@ func (d *Database) StoreTransaction( func (d *Database) AssociateTransactionWithDestinations( ctx context.Context, - destinations map[gomatrixserverlib.UserID]struct{}, + destinations map[spec.UserID]struct{}, transactionID gomatrixserverlib.TransactionID, dbReceipt *receipt.Receipt, ) error { @@ -88,7 +89,7 @@ func (d *Database) AssociateTransactionWithDestinations( func (d *Database) CleanTransactions( ctx context.Context, - userID gomatrixserverlib.UserID, + userID spec.UserID, receipts []*receipt.Receipt, ) error { nids := make([]int64, len(receipts)) @@ -123,7 +124,7 @@ func (d *Database) CleanTransactions( func (d *Database) GetTransaction( ctx context.Context, - userID gomatrixserverlib.UserID, + userID spec.UserID, ) (*gomatrixserverlib.Transaction, *receipt.Receipt, error) { entriesRequested := 1 nids, err := d.RelayQueue.SelectQueueEntries(ctx, nil, userID.Domain(), entriesRequested) @@ -160,7 +161,7 @@ func (d *Database) GetTransaction( func (d *Database) GetTransactionCount( ctx context.Context, - userID gomatrixserverlib.UserID, + userID spec.UserID, ) (int64, error) { count, err := d.RelayQueue.SelectQueueEntryCount(ctx, nil, userID.Domain()) if err != nil { diff --git a/relayapi/storage/sqlite3/relay_queue_table.go b/relayapi/storage/sqlite3/relay_queue_table.go index 49c6b4de5..30482ae97 100644 --- a/relayapi/storage/sqlite3/relay_queue_table.go +++ b/relayapi/storage/sqlite3/relay_queue_table.go @@ -23,6 +23,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const relayQueueSchema = ` @@ -90,7 +91,7 @@ func (s *relayQueueStatements) InsertQueueEntry( ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, nid int64, ) error { stmt := sqlutil.TxStmt(txn, s.insertQueueEntryStmt) @@ -106,7 +107,7 @@ func (s *relayQueueStatements) InsertQueueEntry( func (s *relayQueueStatements) DeleteQueueEntries( ctx context.Context, txn *sql.Tx, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, jsonNIDs []int64, ) error { deleteSQL := strings.Replace(deleteQueueEntriesSQL, "($2)", sqlutil.QueryVariadicOffset(len(jsonNIDs), 1), 1) @@ -129,7 +130,7 @@ func (s *relayQueueStatements) DeleteQueueEntries( func (s *relayQueueStatements) SelectQueueEntries( ctx context.Context, txn *sql.Tx, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, limit int, ) ([]int64, error) { stmt := sqlutil.TxStmt(txn, s.selectQueueEntriesStmt) @@ -153,7 +154,7 @@ func (s *relayQueueStatements) SelectQueueEntries( func (s *relayQueueStatements) SelectQueueEntryCount( ctx context.Context, txn *sql.Tx, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, ) (int64, error) { var count int64 stmt := sqlutil.TxStmt(txn, s.selectQueueEntryCountStmt) diff --git a/relayapi/storage/sqlite3/storage.go b/relayapi/storage/sqlite3/storage.go index 3ed4ab046..7b46396fd 100644 --- a/relayapi/storage/sqlite3/storage.go +++ b/relayapi/storage/sqlite3/storage.go @@ -20,9 +20,8 @@ import ( "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/relayapi/storage/shared" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) // Database stores information needed by the federation sender @@ -34,14 +33,14 @@ type Database struct { // NewDatabase opens a new database func NewDatabase( - base *base.BaseDendrite, + conMan sqlutil.Connections, dbProperties *config.DatabaseOptions, cache caching.FederationCache, - isLocalServerName func(gomatrixserverlib.ServerName) bool, + isLocalServerName func(spec.ServerName) bool, ) (*Database, error) { var d Database var err error - if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()); err != nil { + if d.db, d.writer, err = conMan.Connection(dbProperties); err != nil { return nil, err } queue, err := NewSQLiteRelayQueueTable(d.db) diff --git a/relayapi/storage/storage.go b/relayapi/storage/storage.go index 16ecbcfb7..6fce1efe3 100644 --- a/relayapi/storage/storage.go +++ b/relayapi/storage/storage.go @@ -21,25 +21,25 @@ import ( "fmt" "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/relayapi/storage/postgres" "github.com/matrix-org/dendrite/relayapi/storage/sqlite3" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) // NewDatabase opens a new database func NewDatabase( - base *base.BaseDendrite, + conMan sqlutil.Connections, dbProperties *config.DatabaseOptions, cache caching.FederationCache, - isLocalServerName func(gomatrixserverlib.ServerName) bool, + isLocalServerName func(spec.ServerName) bool, ) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(base, dbProperties, cache, isLocalServerName) + return sqlite3.NewDatabase(conMan, dbProperties, cache, isLocalServerName) case dbProperties.ConnectionString.IsPostgres(): - return postgres.NewDatabase(base, dbProperties, cache, isLocalServerName) + return postgres.NewDatabase(conMan, dbProperties, cache, isLocalServerName) default: return nil, fmt.Errorf("unexpected database type") } diff --git a/relayapi/storage/storage_wasm.go b/relayapi/storage/storage_wasm.go new file mode 100644 index 000000000..7e7323347 --- /dev/null +++ b/relayapi/storage/storage_wasm.go @@ -0,0 +1,42 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package storage + +import ( + "fmt" + + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/relayapi/storage/sqlite3" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib" +) + +// NewDatabase opens a new database +func NewDatabase( + conMan sqlutil.Connections, + dbProperties *config.DatabaseOptions, + cache caching.FederationCache, + isLocalServerName func(spec.ServerName) bool, +) (Database, error) { + switch { + case dbProperties.ConnectionString.IsSQLite(): + return sqlite3.NewDatabase(conMan, dbProperties, cache, isLocalServerName) + case dbProperties.ConnectionString.IsPostgres(): + return nil, fmt.Errorf("can't use Postgres implementation") + default: + return nil, fmt.Errorf("unexpected database type") + } +} diff --git a/relayapi/storage/tables/interface.go b/relayapi/storage/tables/interface.go index 9056a5678..27f43a8d5 100644 --- a/relayapi/storage/tables/interface.go +++ b/relayapi/storage/tables/interface.go @@ -19,6 +19,7 @@ import ( "database/sql" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) // RelayQueue table contains a mapping of server name to transaction id and the corresponding nid. @@ -28,21 +29,21 @@ type RelayQueue interface { // Adds a new transaction_id: server_name mapping with associated json table nid to the table. // Will ensure only one transaction id is present for each server_name: nid mapping. // Adding duplicates will silently do nothing. - InsertQueueEntry(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64) error + InsertQueueEntry(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName spec.ServerName, nid int64) error // Removes multiple entries from the table corresponding the the list of nids provided. // If any of the provided nids don't match a row in the table, that deletion is considered // successful. - DeleteQueueEntries(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error + DeleteQueueEntries(ctx context.Context, txn *sql.Tx, serverName spec.ServerName, jsonNIDs []int64) error // Get a list of nids associated with the provided server name. // Returns up to `limit` nids. The entries are returned oldest first. // Will return an empty list if no matches were found. - SelectQueueEntries(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error) + SelectQueueEntries(ctx context.Context, txn *sql.Tx, serverName spec.ServerName, limit int) ([]int64, error) // Get the number of entries in the table associated with the provided server name. // If there are no matching rows, a count of 0 is returned with err set to nil. - SelectQueueEntryCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error) + SelectQueueEntryCount(ctx context.Context, txn *sql.Tx, serverName spec.ServerName) (int64, error) } // RelayQueueJSON table contains a map of nid to the raw transaction json. diff --git a/relayapi/storage/tables/relay_queue_json_table_test.go b/relayapi/storage/tables/relay_queue_json_table_test.go index efa3363e5..97af7eaeb 100644 --- a/relayapi/storage/tables/relay_queue_json_table_test.go +++ b/relayapi/storage/tables/relay_queue_json_table_test.go @@ -27,11 +27,12 @@ import ( "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/stretchr/testify/assert" ) const ( - testOrigin = gomatrixserverlib.ServerName("kaer.morhen") + testOrigin = spec.ServerName("kaer.morhen") ) func mustCreateTransaction() gomatrixserverlib.Transaction { diff --git a/relayapi/storage/tables/relay_queue_table_test.go b/relayapi/storage/tables/relay_queue_table_test.go index 99f9922c0..d196eaf57 100644 --- a/relayapi/storage/tables/relay_queue_table_test.go +++ b/relayapi/storage/tables/relay_queue_table_test.go @@ -28,6 +28,7 @@ import ( "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/stretchr/testify/assert" ) @@ -73,7 +74,7 @@ func TestShoudInsertQueueTransaction(t *testing.T) { defer close() transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) - serverName := gomatrixserverlib.ServerName("domain") + serverName := spec.ServerName("domain") nid := int64(1) err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid) if err != nil { @@ -89,7 +90,7 @@ func TestShouldRetrieveInsertedQueueTransaction(t *testing.T) { defer close() transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) - serverName := gomatrixserverlib.ServerName("domain") + serverName := spec.ServerName("domain") nid := int64(1) err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid) @@ -114,7 +115,7 @@ func TestShouldRetrieveOldestInsertedQueueTransaction(t *testing.T) { defer close() transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) - serverName := gomatrixserverlib.ServerName("domain") + serverName := spec.ServerName("domain") nid := int64(2) err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid) if err != nil { @@ -122,7 +123,7 @@ func TestShouldRetrieveOldestInsertedQueueTransaction(t *testing.T) { } transactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) - serverName = gomatrixserverlib.ServerName("domain") + serverName = spec.ServerName("domain") oldestNID := int64(1) err = db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, oldestNID) if err != nil { @@ -155,7 +156,7 @@ func TestShouldDeleteQueueTransaction(t *testing.T) { defer close() transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) - serverName := gomatrixserverlib.ServerName("domain") + serverName := spec.ServerName("domain") nid := int64(1) err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid) @@ -186,10 +187,10 @@ func TestShouldDeleteOnlySpecifiedQueueTransaction(t *testing.T) { defer close() transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) - serverName := gomatrixserverlib.ServerName("domain") + serverName := spec.ServerName("domain") nid := int64(1) transactionID2 := gomatrixserverlib.TransactionID(fmt.Sprintf("%d2", time.Now().UnixNano())) - serverName2 := gomatrixserverlib.ServerName("domain2") + serverName2 := spec.ServerName("domain2") nid2 := int64(2) transactionID3 := gomatrixserverlib.TransactionID(fmt.Sprintf("%d3", time.Now().UnixNano())) diff --git a/roomserver/acls/acls.go b/roomserver/acls/acls.go index b18daa3de..80d45e8c3 100644 --- a/roomserver/acls/acls.go +++ b/roomserver/acls/acls.go @@ -24,6 +24,7 @@ import ( "sync" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/sirupsen/logrus" ) @@ -120,7 +121,7 @@ func (s *ServerACLs) OnServerACLUpdate(state *gomatrixserverlib.Event) { s.acls[state.RoomID()] = acls } -func (s *ServerACLs) IsServerBannedFromRoom(serverName gomatrixserverlib.ServerName, roomID string) bool { +func (s *ServerACLs) IsServerBannedFromRoom(serverName spec.ServerName, roomID string) bool { s.aclsMutex.RLock() // First of all check if we have an ACL for this room. If we don't then // no servers are banned from the room. @@ -133,7 +134,7 @@ func (s *ServerACLs) IsServerBannedFromRoom(serverName gomatrixserverlib.ServerN // Split the host and port apart. This is because the spec calls on us to // validate the hostname only in cases where the port is also present. if serverNameOnly, _, err := net.SplitHostPort(string(serverName)); err == nil { - serverName = gomatrixserverlib.ServerName(serverNameOnly) + serverName = spec.ServerName(serverNameOnly) } // Check if the hostname is an IPv4 or IPv6 literal. We cheat here by adding // a /0 prefix length just to trick ParseCIDR into working. If we find that diff --git a/roomserver/api/api.go b/roomserver/api/api.go index f6d003a44..4ce40e3ea 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -4,6 +4,7 @@ import ( "context" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" asAPI "github.com/matrix-org/dendrite/appservice/api" fsAPI "github.com/matrix-org/dendrite/federationapi/api" @@ -144,7 +145,6 @@ type ClientRoomserverAPI interface { QueryKnownUsers(ctx context.Context, req *QueryKnownUsersRequest, res *QueryKnownUsersResponse) error QueryRoomVersionForRoom(ctx context.Context, req *QueryRoomVersionForRoomRequest, res *QueryRoomVersionForRoomResponse) error QueryPublishedRooms(ctx context.Context, req *QueryPublishedRoomsRequest, res *QueryPublishedRoomsResponse) error - QueryRoomVersionCapabilities(ctx context.Context, req *QueryRoomVersionCapabilitiesRequest, res *QueryRoomVersionCapabilitiesResponse) error GetRoomIDForAlias(ctx context.Context, req *GetRoomIDForAliasRequest, res *GetRoomIDForAliasResponse) error GetAliasesForRoomID(ctx context.Context, req *GetAliasesForRoomIDRequest, res *GetAliasesForRoomIDResponse) error @@ -198,7 +198,7 @@ type FederationRoomserverAPI interface { // Query missing events for a room from roomserver QueryMissingEvents(ctx context.Context, req *QueryMissingEventsRequest, res *QueryMissingEventsResponse) error // Query whether a server is allowed to see an event - QueryServerAllowedToSeeEvent(ctx context.Context, serverName gomatrixserverlib.ServerName, eventID string) (allowed bool, err error) + QueryServerAllowedToSeeEvent(ctx context.Context, serverName spec.ServerName, eventID string) (allowed bool, err error) QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error QueryRestrictedJoinAllowed(ctx context.Context, req *QueryRestrictedJoinAllowedRequest, res *QueryRestrictedJoinAllowedResponse) error PerformInboundPeek(ctx context.Context, req *PerformInboundPeekRequest, res *PerformInboundPeekResponse) error diff --git a/roomserver/api/input.go b/roomserver/api/input.go index 88d523270..b52317cd5 100644 --- a/roomserver/api/input.go +++ b/roomserver/api/input.go @@ -19,6 +19,7 @@ import ( "fmt" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) type Kind int @@ -68,7 +69,7 @@ type InputRoomEvent struct { // The event JSON for the event to add. Event *gomatrixserverlib.HeaderedEvent `json:"event"` // Which server told us about this event. - Origin gomatrixserverlib.ServerName `json:"origin"` + Origin spec.ServerName `json:"origin"` // Whether the state is supplied as a list of event IDs or whether it // should be derived from the state at the previous events. HasState bool `json:"has_state"` @@ -94,9 +95,9 @@ type TransactionID struct { // InputRoomEventsRequest is a request to InputRoomEvents type InputRoomEventsRequest struct { - InputRoomEvents []InputRoomEvent `json:"input_room_events"` - Asynchronous bool `json:"async"` - VirtualHost gomatrixserverlib.ServerName `json:"virtual_host"` + InputRoomEvents []InputRoomEvent `json:"input_room_events"` + Asynchronous bool `json:"async"` + VirtualHost spec.ServerName `json:"virtual_host"` } // InputRoomEventsResponse is a response to InputRoomEvents diff --git a/roomserver/api/output.go b/roomserver/api/output.go index 0c0f52c45..737887473 100644 --- a/roomserver/api/output.go +++ b/roomserver/api/output.go @@ -16,6 +16,7 @@ package api import ( "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) // An OutputType is a type of roomserver output. @@ -250,7 +251,7 @@ type OutputNewInboundPeek struct { // a race between tracking the state returned by /peek and emitting subsequent // peeked events) LatestEventID string - ServerName gomatrixserverlib.ServerName + ServerName spec.ServerName // how often we told the peeking server to renew the peek RenewalInterval int64 } diff --git a/roomserver/api/perform.go b/roomserver/api/perform.go index 83cb0460a..e125a3008 100644 --- a/roomserver/api/perform.go +++ b/roomserver/api/perform.go @@ -6,6 +6,8 @@ import ( "net/http" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/matrix-org/dendrite/clientapi/jsonerror" @@ -76,18 +78,18 @@ const ( ) type PerformJoinRequest struct { - RoomIDOrAlias string `json:"room_id_or_alias"` - UserID string `json:"user_id"` - IsGuest bool `json:"is_guest"` - Content map[string]interface{} `json:"content"` - ServerNames []gomatrixserverlib.ServerName `json:"server_names"` - Unsigned map[string]interface{} `json:"unsigned"` + RoomIDOrAlias string `json:"room_id_or_alias"` + UserID string `json:"user_id"` + IsGuest bool `json:"is_guest"` + Content map[string]interface{} `json:"content"` + ServerNames []spec.ServerName `json:"server_names"` + Unsigned map[string]interface{} `json:"unsigned"` } type PerformJoinResponse struct { // The room ID, populated on success. RoomID string `json:"room_id"` - JoinedVia gomatrixserverlib.ServerName + JoinedVia spec.ServerName // If non-nil, the join request failed. Contains more information why it failed. Error *PerformError } @@ -103,11 +105,11 @@ type PerformLeaveResponse struct { } type PerformInviteRequest struct { - RoomVersion gomatrixserverlib.RoomVersion `json:"room_version"` - Event *gomatrixserverlib.HeaderedEvent `json:"event"` - InviteRoomState []gomatrixserverlib.InviteV2StrippedState `json:"invite_room_state"` - SendAsServer string `json:"send_as_server"` - TransactionID *TransactionID `json:"transaction_id"` + RoomVersion gomatrixserverlib.RoomVersion `json:"room_version"` + Event *gomatrixserverlib.HeaderedEvent `json:"event"` + InviteRoomState []fclient.InviteV2StrippedState `json:"invite_room_state"` + SendAsServer string `json:"send_as_server"` + TransactionID *TransactionID `json:"transaction_id"` } type PerformInviteResponse struct { @@ -115,10 +117,10 @@ type PerformInviteResponse struct { } type PerformPeekRequest struct { - RoomIDOrAlias string `json:"room_id_or_alias"` - UserID string `json:"user_id"` - DeviceID string `json:"device_id"` - ServerNames []gomatrixserverlib.ServerName `json:"server_names"` + RoomIDOrAlias string `json:"room_id_or_alias"` + UserID string `json:"user_id"` + DeviceID string `json:"device_id"` + ServerNames []spec.ServerName `json:"server_names"` } type PerformPeekResponse struct { @@ -148,9 +150,9 @@ type PerformBackfillRequest struct { // The maximum number of events to retrieve. Limit int `json:"limit"` // The server interested in the events. - ServerName gomatrixserverlib.ServerName `json:"server_name"` + ServerName spec.ServerName `json:"server_name"` // Which virtual host are we doing this for? - VirtualHost gomatrixserverlib.ServerName `json:"virtual_host"` + VirtualHost spec.ServerName `json:"virtual_host"` } // PrevEventIDs returns the prev_event IDs of all backwards extremities, de-duplicated in a lexicographically sorted order. @@ -183,11 +185,11 @@ type PerformPublishResponse struct { } type PerformInboundPeekRequest struct { - UserID string `json:"user_id"` - RoomID string `json:"room_id"` - PeekID string `json:"peek_id"` - ServerName gomatrixserverlib.ServerName `json:"server_name"` - RenewalInterval int64 `json:"renewal_interval"` + UserID string `json:"user_id"` + RoomID string `json:"room_id"` + PeekID string `json:"peek_id"` + ServerName spec.ServerName `json:"server_name"` + RenewalInterval int64 `json:"renewal_interval"` } type PerformInboundPeekResponse struct { @@ -250,9 +252,9 @@ type PerformAdminPurgeRoomResponse struct { } type PerformAdminDownloadStateRequest struct { - RoomID string `json:"room_id"` - UserID string `json:"user_id"` - ServerName gomatrixserverlib.ServerName `json:"server_name"` + RoomID string `json:"room_id"` + UserID string `json:"user_id"` + ServerName spec.ServerName `json:"server_name"` } type PerformAdminDownloadStateResponse struct { diff --git a/roomserver/api/query.go b/roomserver/api/query.go index 24722db0b..56915b1e4 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -22,8 +22,10 @@ import ( "strings" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/syncapi/synctypes" ) // QueryLatestEventsAndStateRequest is a request to QueryLatestEventsAndState @@ -146,7 +148,7 @@ type QueryMembershipsForRoomRequest struct { // QueryMembershipsForRoomResponse is a response to QueryMembershipsForRoom type QueryMembershipsForRoomResponse struct { // The "m.room.member" events (of "join" membership) in the client format - JoinEvents []gomatrixserverlib.ClientEvent `json:"join_events"` + JoinEvents []synctypes.ClientEvent `json:"join_events"` // True if the user has been in room before and has either stayed in it or // left it. HasBeenInRoom bool `json:"has_been_in_room"` @@ -158,7 +160,7 @@ type QueryMembershipsForRoomResponse struct { type QueryServerJoinedToRoomRequest struct { // Server name of the server to find. If not specified, we will // default to checking if the local server is joined. - ServerName gomatrixserverlib.ServerName `json:"server_name"` + ServerName spec.ServerName `json:"server_name"` // ID of the room to see if we are still joined to RoomID string `json:"room_id"` } @@ -176,7 +178,7 @@ type QueryServerAllowedToSeeEventRequest struct { // The event ID to look up invites in. EventID string `json:"event_id"` // The server interested in the event - ServerName gomatrixserverlib.ServerName `json:"server_name"` + ServerName spec.ServerName `json:"server_name"` } // QueryServerAllowedToSeeEventResponse is a response to QueryServerAllowedToSeeEvent @@ -194,7 +196,7 @@ type QueryMissingEventsRequest struct { // Limit the number of events this query returns. Limit int `json:"limit"` // The server interested in the event - ServerName gomatrixserverlib.ServerName `json:"server_name"` + ServerName spec.ServerName `json:"server_name"` } // QueryMissingEventsResponse is a response to QueryMissingEvents @@ -240,15 +242,6 @@ type QueryStateAndAuthChainResponse struct { IsRejected bool `json:"is_rejected"` } -// QueryRoomVersionCapabilitiesRequest asks for the default room version -type QueryRoomVersionCapabilitiesRequest struct{} - -// QueryRoomVersionCapabilitiesResponse is a response to QueryRoomVersionCapabilitiesRequest -type QueryRoomVersionCapabilitiesResponse struct { - DefaultRoomVersion gomatrixserverlib.RoomVersion `json:"default"` - AvailableRoomVersions map[gomatrixserverlib.RoomVersion]string `json:"available"` -} - // QueryRoomVersionForRoomRequest asks for the room version for a given room. type QueryRoomVersionForRoomRequest struct { RoomID string `json:"room_id"` @@ -348,8 +341,8 @@ type QueryKnownUsersResponse struct { } type QueryServerBannedFromRoomRequest struct { - ServerName gomatrixserverlib.ServerName `json:"server_name"` - RoomID string `json:"room_id"` + ServerName spec.ServerName `json:"server_name"` + RoomID string `json:"room_id"` } type QueryServerBannedFromRoomResponse struct { diff --git a/roomserver/api/wrapper.go b/roomserver/api/wrapper.go index f220560ed..ff4445bfa 100644 --- a/roomserver/api/wrapper.go +++ b/roomserver/api/wrapper.go @@ -18,6 +18,8 @@ import ( "context" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/sirupsen/logrus" ) @@ -26,8 +28,8 @@ import ( func SendEvents( ctx context.Context, rsAPI InputRoomEventsAPI, kind Kind, events []*gomatrixserverlib.HeaderedEvent, - virtualHost, origin gomatrixserverlib.ServerName, - sendAsServer gomatrixserverlib.ServerName, txnID *TransactionID, + virtualHost, origin spec.ServerName, + sendAsServer spec.ServerName, txnID *TransactionID, async bool, ) error { ires := make([]InputRoomEvent, len(events)) @@ -48,11 +50,11 @@ func SendEvents( // marked as `true` in haveEventIDs. func SendEventWithState( ctx context.Context, rsAPI InputRoomEventsAPI, - virtualHost gomatrixserverlib.ServerName, kind Kind, - state *gomatrixserverlib.RespState, event *gomatrixserverlib.HeaderedEvent, - origin gomatrixserverlib.ServerName, haveEventIDs map[string]bool, async bool, + virtualHost spec.ServerName, kind Kind, + state gomatrixserverlib.StateResponse, event *gomatrixserverlib.HeaderedEvent, + origin spec.ServerName, haveEventIDs map[string]bool, async bool, ) error { - outliers := state.Events(event.RoomVersion) + outliers := gomatrixserverlib.LineariseStateResponse(event.RoomVersion, state) ires := make([]InputRoomEvent, 0, len(outliers)) for _, outlier := range outliers { if haveEventIDs[outlier.EventID()] { @@ -65,7 +67,7 @@ func SendEventWithState( }) } - stateEvents := state.StateEvents.UntrustedEvents(event.RoomVersion) + stateEvents := state.GetStateEvents().UntrustedEvents(event.RoomVersion) stateEventIDs := make([]string, len(stateEvents)) for i := range stateEvents { stateEventIDs[i] = stateEvents[i].EventID() @@ -92,7 +94,7 @@ func SendEventWithState( // SendInputRoomEvents to the roomserver. func SendInputRoomEvents( ctx context.Context, rsAPI InputRoomEventsAPI, - virtualHost gomatrixserverlib.ServerName, + virtualHost spec.ServerName, ires []InputRoomEvent, async bool, ) error { request := InputRoomEventsRequest{ @@ -143,7 +145,7 @@ func GetStateEvent(ctx context.Context, rsAPI QueryEventsAPI, roomID string, tup } // IsServerBannedFromRoom returns whether the server is banned from a room by server ACLs. -func IsServerBannedFromRoom(ctx context.Context, rsAPI FederationRoomserverAPI, roomID string, serverName gomatrixserverlib.ServerName) bool { +func IsServerBannedFromRoom(ctx context.Context, rsAPI FederationRoomserverAPI, roomID string, serverName spec.ServerName) bool { req := &QueryServerBannedFromRoomRequest{ ServerName: serverName, RoomID: roomID, @@ -159,14 +161,14 @@ func IsServerBannedFromRoom(ctx context.Context, rsAPI FederationRoomserverAPI, // PopulatePublicRooms extracts PublicRoom information for all the provided room IDs. The IDs are not checked to see if they are visible in the // published room directory. // due to lots of switches -func PopulatePublicRooms(ctx context.Context, roomIDs []string, rsAPI QueryBulkStateContentAPI) ([]gomatrixserverlib.PublicRoom, error) { +func PopulatePublicRooms(ctx context.Context, roomIDs []string, rsAPI QueryBulkStateContentAPI) ([]fclient.PublicRoom, error) { avatarTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.avatar", StateKey: ""} nameTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.name", StateKey: ""} - canonicalTuple := gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomCanonicalAlias, StateKey: ""} + canonicalTuple := gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomCanonicalAlias, StateKey: ""} topicTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.topic", StateKey: ""} guestTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.guest_access", StateKey: ""} - visibilityTuple := gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomHistoryVisibility, StateKey: ""} - joinRuleTuple := gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomJoinRules, StateKey: ""} + visibilityTuple := gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomHistoryVisibility, StateKey: ""} + joinRuleTuple := gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomJoinRules, StateKey: ""} var stateRes QueryBulkStateContentResponse err := rsAPI.QueryBulkStateContent(ctx, &QueryBulkStateContentRequest{ @@ -174,23 +176,23 @@ func PopulatePublicRooms(ctx context.Context, roomIDs []string, rsAPI QueryBulkS AllowWildcards: true, StateTuples: []gomatrixserverlib.StateKeyTuple{ nameTuple, canonicalTuple, topicTuple, guestTuple, visibilityTuple, joinRuleTuple, avatarTuple, - {EventType: gomatrixserverlib.MRoomMember, StateKey: "*"}, + {EventType: spec.MRoomMember, StateKey: "*"}, }, }, &stateRes) if err != nil { util.GetLogger(ctx).WithError(err).Error("QueryBulkStateContent failed") return nil, err } - chunk := make([]gomatrixserverlib.PublicRoom, len(roomIDs)) + chunk := make([]fclient.PublicRoom, len(roomIDs)) i := 0 for roomID, data := range stateRes.Rooms { - pub := gomatrixserverlib.PublicRoom{ + pub := fclient.PublicRoom{ RoomID: roomID, } joinCount := 0 var joinRule, guestAccess string for tuple, contentVal := range data { - if tuple.EventType == gomatrixserverlib.MRoomMember && contentVal == "join" { + if tuple.EventType == spec.MRoomMember && contentVal == "join" { joinCount++ continue } @@ -214,7 +216,7 @@ func PopulatePublicRooms(ctx context.Context, roomIDs []string, rsAPI QueryBulkS guestAccess = contentVal } } - if joinRule == gomatrixserverlib.Public && guestAccess == "can_join" { + if joinRule == spec.Public && guestAccess == "can_join" { pub.GuestCanJoin = true } pub.JoinedMembersCount = joinCount diff --git a/roomserver/auth/auth.go b/roomserver/auth/auth.go index 31a856e8e..5f72454ae 100644 --- a/roomserver/auth/auth.go +++ b/roomserver/auth/auth.go @@ -14,6 +14,7 @@ package auth import ( "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) // TODO: This logic should live in gomatrixserverlib @@ -21,7 +22,7 @@ import ( // IsServerAllowed returns true if the server is allowed to see events in the room // at this particular state. This function implements https://matrix.org/docs/spec/client_server/r0.6.0#id87 func IsServerAllowed( - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, serverCurrentlyInRoom bool, authEvents []*gomatrixserverlib.Event, ) bool { @@ -32,7 +33,7 @@ func IsServerAllowed( return true } // 2. If the user's membership was join, allow. - joinedUserExists := IsAnyUserOnServerWithMembership(serverName, authEvents, gomatrixserverlib.Join) + joinedUserExists := IsAnyUserOnServerWithMembership(serverName, authEvents, spec.Join) if joinedUserExists { return true } @@ -41,7 +42,7 @@ func IsServerAllowed( return true } // 4. If the user's membership was invite, and the history_visibility was set to invited, allow. - invitedUserExists := IsAnyUserOnServerWithMembership(serverName, authEvents, gomatrixserverlib.Invite) + invitedUserExists := IsAnyUserOnServerWithMembership(serverName, authEvents, spec.Invite) if invitedUserExists && historyVisibility == gomatrixserverlib.HistoryVisibilityInvited { return true } @@ -55,7 +56,7 @@ func HistoryVisibilityForRoom(authEvents []*gomatrixserverlib.Event) gomatrixser // By default if no history_visibility is set, or if the value is not understood, the visibility is assumed to be shared. visibility := gomatrixserverlib.HistoryVisibilityShared for _, ev := range authEvents { - if ev.Type() != gomatrixserverlib.MRoomHistoryVisibility { + if ev.Type() != spec.MRoomHistoryVisibility { continue } if vis, err := ev.HistoryVisibility(); err == nil { @@ -65,9 +66,9 @@ func HistoryVisibilityForRoom(authEvents []*gomatrixserverlib.Event) gomatrixser return visibility } -func IsAnyUserOnServerWithMembership(serverName gomatrixserverlib.ServerName, authEvents []*gomatrixserverlib.Event, wantMembership string) bool { +func IsAnyUserOnServerWithMembership(serverName spec.ServerName, authEvents []*gomatrixserverlib.Event, wantMembership string) bool { for _, ev := range authEvents { - if ev.Type() != gomatrixserverlib.MRoomMember { + if ev.Type() != spec.MRoomMember { continue } membership, err := ev.Membership() diff --git a/roomserver/internal/alias.go b/roomserver/internal/alias.go index fc61b7f4a..46598746c 100644 --- a/roomserver/internal/alias.go +++ b/roomserver/internal/alias.go @@ -26,6 +26,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/internal/helpers" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/tidwall/gjson" "github.com/tidwall/sjson" ) @@ -117,7 +118,7 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias( request *api.RemoveRoomAliasRequest, response *api.RemoveRoomAliasResponse, ) error { - _, virtualHost, err := r.Cfg.Matrix.SplitLocalID('@', request.UserID) + _, virtualHost, err := r.Cfg.Global.SplitLocalID('@', request.UserID) if err != nil { return err } @@ -142,7 +143,7 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias( var plEvent *gomatrixserverlib.HeaderedEvent var pls *gomatrixserverlib.PowerLevelContent - plEvent, err = r.DB.GetStateEvent(ctx, roomID, gomatrixserverlib.MRoomPowerLevels, "") + plEvent, err = r.DB.GetStateEvent(ctx, roomID, spec.MRoomPowerLevels, "") if err != nil { return fmt.Errorf("r.DB.GetStateEvent: %w", err) } @@ -152,13 +153,13 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias( return fmt.Errorf("plEvent.PowerLevels: %w", err) } - if pls.UserLevel(request.UserID) < pls.EventLevel(gomatrixserverlib.MRoomCanonicalAlias, true) { + if pls.UserLevel(request.UserID) < pls.EventLevel(spec.MRoomCanonicalAlias, true) { response.Removed = false return nil } } - ev, err := r.DB.GetStateEvent(ctx, roomID, gomatrixserverlib.MRoomCanonicalAlias, "") + ev, err := r.DB.GetStateEvent(ctx, roomID, spec.MRoomCanonicalAlias, "") if err != nil && err != sql.ErrNoRows { return err } else if ev != nil { @@ -175,12 +176,12 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias( sender = ev.Sender() } - _, senderDomain, err := r.Cfg.Matrix.SplitLocalID('@', sender) + _, senderDomain, err := r.Cfg.Global.SplitLocalID('@', sender) if err != nil { return err } - identity, err := r.Cfg.Matrix.SigningIdentityFor(senderDomain) + identity, err := r.Cfg.Global.SigningIdentityFor(senderDomain) if err != nil { return err } @@ -206,7 +207,7 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias( return err } - newEvent, err := eventutil.BuildEvent(ctx, builder, r.Cfg.Matrix, identity, time.Now(), &eventsNeeded, stateRes) + newEvent, err := eventutil.BuildEvent(ctx, builder, &r.Cfg.Global, identity, time.Now(), &eventsNeeded, stateRes) if err != nil { return err } diff --git a/roomserver/internal/api.go b/roomserver/internal/api.go index c43b9d049..975e77966 100644 --- a/roomserver/internal/api.go +++ b/roomserver/internal/api.go @@ -5,6 +5,7 @@ import ( "github.com/getsentry/sentry-go" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/nats-io/nats.go" "github.com/sirupsen/logrus" @@ -18,7 +19,6 @@ import ( "github.com/matrix-org/dendrite/roomserver/internal/query" "github.com/matrix-org/dendrite/roomserver/producers" "github.com/matrix-org/dendrite/roomserver/storage" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" @@ -41,11 +41,10 @@ type RoomserverInternalAPI struct { *perform.Upgrader *perform.Admin ProcessContext *process.ProcessContext - Base *base.BaseDendrite DB storage.Database - Cfg *config.RoomServer + Cfg *config.Dendrite Cache caching.RoomServerCaches - ServerName gomatrixserverlib.ServerName + ServerName spec.ServerName KeyRing gomatrixserverlib.JSONVerifier ServerACLs *acls.ServerACLs fsAPI fsAPI.RoomserverFederationAPI @@ -55,44 +54,45 @@ type RoomserverInternalAPI struct { Durable string InputRoomEventTopic string // JetStream topic for new input room events OutputProducer *producers.RoomEventProducer - PerspectiveServerNames []gomatrixserverlib.ServerName + PerspectiveServerNames []spec.ServerName + enableMetrics bool } func NewRoomserverAPI( - base *base.BaseDendrite, roomserverDB storage.Database, - js nats.JetStreamContext, nc *nats.Conn, + processContext *process.ProcessContext, dendriteCfg *config.Dendrite, roomserverDB storage.Database, + js nats.JetStreamContext, nc *nats.Conn, caches caching.RoomServerCaches, enableMetrics bool, ) *RoomserverInternalAPI { - var perspectiveServerNames []gomatrixserverlib.ServerName - for _, kp := range base.Cfg.FederationAPI.KeyPerspectives { + var perspectiveServerNames []spec.ServerName + for _, kp := range dendriteCfg.FederationAPI.KeyPerspectives { perspectiveServerNames = append(perspectiveServerNames, kp.ServerName) } serverACLs := acls.NewServerACLs(roomserverDB) producer := &producers.RoomEventProducer{ - Topic: string(base.Cfg.Global.JetStream.Prefixed(jetstream.OutputRoomEvent)), + Topic: string(dendriteCfg.Global.JetStream.Prefixed(jetstream.OutputRoomEvent)), JetStream: js, ACLs: serverACLs, } a := &RoomserverInternalAPI{ - ProcessContext: base.ProcessContext, + ProcessContext: processContext, DB: roomserverDB, - Base: base, - Cfg: &base.Cfg.RoomServer, - Cache: base.Caches, - ServerName: base.Cfg.Global.ServerName, + Cfg: dendriteCfg, + Cache: caches, + ServerName: dendriteCfg.Global.ServerName, PerspectiveServerNames: perspectiveServerNames, - InputRoomEventTopic: base.Cfg.Global.JetStream.Prefixed(jetstream.InputRoomEvent), + InputRoomEventTopic: dendriteCfg.Global.JetStream.Prefixed(jetstream.InputRoomEvent), OutputProducer: producer, JetStream: js, NATSClient: nc, - Durable: base.Cfg.Global.JetStream.Durable("RoomserverInputConsumer"), + Durable: dendriteCfg.Global.JetStream.Durable("RoomserverInputConsumer"), ServerACLs: serverACLs, Queryer: &query.Queryer{ DB: roomserverDB, - Cache: base.Caches, - IsLocalServerName: base.Cfg.Global.IsLocalServerName, + Cache: caches, + IsLocalServerName: dendriteCfg.Global.IsLocalServerName, ServerACLs: serverACLs, }, + enableMetrics: enableMetrics, // perform-er structs get initialised when we have a federation sender to use } return a @@ -105,15 +105,14 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio r.fsAPI = fsAPI r.KeyRing = keyRing - identity, err := r.Cfg.Matrix.SigningIdentityFor(r.ServerName) + identity, err := r.Cfg.Global.SigningIdentityFor(r.ServerName) if err != nil { logrus.Panic(err) } r.Inputer = &input.Inputer{ - Cfg: &r.Base.Cfg.RoomServer, - Base: r.Base, - ProcessContext: r.Base.ProcessContext, + Cfg: &r.Cfg.RoomServer, + ProcessContext: r.ProcessContext, DB: r.DB, InputRoomEventTopic: r.InputRoomEventTopic, OutputProducer: r.OutputProducer, @@ -129,12 +128,12 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio } r.Inviter = &perform.Inviter{ DB: r.DB, - Cfg: r.Cfg, + Cfg: &r.Cfg.RoomServer, FSAPI: r.fsAPI, Inputer: r.Inputer, } r.Joiner = &perform.Joiner{ - Cfg: r.Cfg, + Cfg: &r.Cfg.RoomServer, DB: r.DB, FSAPI: r.fsAPI, RSAPI: r, @@ -143,7 +142,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio } r.Peeker = &perform.Peeker{ ServerName: r.ServerName, - Cfg: r.Cfg, + Cfg: &r.Cfg.RoomServer, DB: r.DB, FSAPI: r.fsAPI, Inputer: r.Inputer, @@ -154,21 +153,22 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio } r.Unpeeker = &perform.Unpeeker{ ServerName: r.ServerName, - Cfg: r.Cfg, + Cfg: &r.Cfg.RoomServer, FSAPI: r.fsAPI, Inputer: r.Inputer, } r.Leaver = &perform.Leaver{ - Cfg: r.Cfg, + Cfg: &r.Cfg.RoomServer, DB: r.DB, FSAPI: r.fsAPI, + RSAPI: r, Inputer: r.Inputer, } r.Publisher = &perform.Publisher{ DB: r.DB, } r.Backfiller = &perform.Backfiller{ - IsLocalServerName: r.Cfg.Matrix.IsLocalServerName, + IsLocalServerName: r.Cfg.Global.IsLocalServerName, DB: r.DB, FSAPI: r.fsAPI, KeyRing: r.KeyRing, @@ -181,12 +181,12 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio DB: r.DB, } r.Upgrader = &perform.Upgrader{ - Cfg: r.Cfg, + Cfg: &r.Cfg.RoomServer, URSAPI: r, } r.Admin = &perform.Admin{ DB: r.DB, - Cfg: r.Cfg, + Cfg: &r.Cfg.RoomServer, Inputer: r.Inputer, Queryer: r.Queryer, Leaver: r.Leaver, diff --git a/roomserver/internal/helpers/auth.go b/roomserver/internal/helpers/auth.go index 9defe7945..4ef6e2480 100644 --- a/roomserver/internal/helpers/auth.go +++ b/roomserver/internal/helpers/auth.go @@ -20,6 +20,7 @@ import ( "sort" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/storage" @@ -59,7 +60,7 @@ func CheckForSoftFail( // state because we haven't received a m.room.create event yet. // If we're now processing the first create event then never // soft-fail it. - if len(authStateEntries) == 0 && event.Type() == gomatrixserverlib.MRoomCreate { + if len(authStateEntries) == 0 && event.Type() == spec.MRoomCreate { return false, nil } diff --git a/roomserver/internal/helpers/helpers.go b/roomserver/internal/helpers/helpers.go index 9a70bcc9c..959f5cf7c 100644 --- a/roomserver/internal/helpers/helpers.go +++ b/roomserver/internal/helpers/helpers.go @@ -9,6 +9,7 @@ import ( "strings" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/matrix-org/dendrite/roomserver/api" @@ -54,7 +55,7 @@ func UpdateToInviteMembership( Type: api.OutputTypeRetireInviteEvent, RetireInviteEvent: &api.OutputRetireInviteEvent{ EventID: eventID, - Membership: gomatrixserverlib.Join, + Membership: spec.Join, RetiredByEventID: add.EventID(), TargetUserID: *add.StateKey(), }, @@ -67,7 +68,7 @@ func UpdateToInviteMembership( // memberships. If the servername is not supplied then the local server will be // checked instead using a faster code path. // TODO: This should probably be replaced by an API call. -func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverName gomatrixserverlib.ServerName, roomID string) (bool, error) { +func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverName spec.ServerName, roomID string) (bool, error) { info, err := db.RoomInfo(ctx, roomID) if err != nil { return false, err @@ -93,7 +94,7 @@ func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverNam for i := range events { gmslEvents[i] = events[i].Event } - return auth.IsAnyUserOnServerWithMembership(serverName, gmslEvents, gomatrixserverlib.Join), nil + return auth.IsAnyUserOnServerWithMembership(serverName, gmslEvents, spec.Join), nil } func IsInvitePending( @@ -194,7 +195,7 @@ func GetMembershipsAtState( return nil, err } - if membership == gomatrixserverlib.Join { + if membership == spec.Join { events = append(events, event) } } @@ -252,7 +253,7 @@ func LoadStateEvents( } func CheckServerAllowedToSeeEvent( - ctx context.Context, db storage.Database, info *types.RoomInfo, eventID string, serverName gomatrixserverlib.ServerName, isServerInRoom bool, + ctx context.Context, db storage.Database, info *types.RoomInfo, eventID string, serverName spec.ServerName, isServerInRoom bool, ) (bool, error) { stateAtEvent, err := db.GetHistoryVisibilityState(ctx, info, eventID, string(serverName)) switch err { @@ -280,7 +281,7 @@ func CheckServerAllowedToSeeEvent( } func slowGetHistoryVisibilityState( - ctx context.Context, db storage.Database, info *types.RoomInfo, eventID string, serverName gomatrixserverlib.ServerName, + ctx context.Context, db storage.Database, info *types.RoomInfo, eventID string, serverName spec.ServerName, ) ([]*gomatrixserverlib.Event, error) { roomState := state.NewStateResolution(db, info) stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID) @@ -332,7 +333,7 @@ func slowGetHistoryVisibilityState( // TODO: Remove this when we have tests to assert correctness of this function func ScanEventTree( ctx context.Context, db storage.Database, info *types.RoomInfo, front []string, visited map[string]bool, limit int, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, ) ([]types.EventNID, map[string]struct{}, error) { var resultNIDs []types.EventNID var err error diff --git a/roomserver/internal/helpers/helpers_test.go b/roomserver/internal/helpers/helpers_test.go index c056e704c..dd74b844a 100644 --- a/roomserver/internal/helpers/helpers_test.go +++ b/roomserver/internal/helpers/helpers_test.go @@ -3,26 +3,28 @@ package helpers import ( "context" "testing" + "time" + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/config" "github.com/stretchr/testify/assert" "github.com/matrix-org/dendrite/roomserver/types" - "github.com/matrix-org/dendrite/setup/base" - - "github.com/matrix-org/dendrite/test" - "github.com/matrix-org/dendrite/test/testrig" - "github.com/matrix-org/dendrite/roomserver/storage" + "github.com/matrix-org/dendrite/test" ) -func mustCreateDatabase(t *testing.T, dbType test.DBType) (*base.BaseDendrite, storage.Database, func()) { - base, close := testrig.CreateBaseDendrite(t, dbType) - db, err := storage.Open(base, &base.Cfg.RoomServer.Database, base.Caches) +func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { + conStr, close := test.PrepareDBConnectionString(t, dbType) + caches := caching.NewRistrettoCache(8*1024*1024, time.Hour, caching.DisableMetrics) + cm := sqlutil.NewConnectionManager(nil, config.DatabaseOptions{}) + db, err := storage.Open(context.Background(), cm, &config.DatabaseOptions{ConnectionString: config.DataSource(conStr)}, caches) if err != nil { t.Fatalf("failed to create Database: %v", err) } - return base, db, close + return db, close } func TestIsInvitePendingWithoutNID(t *testing.T) { @@ -32,7 +34,7 @@ func TestIsInvitePendingWithoutNID(t *testing.T) { room := test.NewRoom(t, alice, test.RoomPreset(test.PresetPublicChat)) _ = bob test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - _, db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateDatabase(t, dbType) defer close() // store all events diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go index 2ec19f010..3e7ff7f7c 100644 --- a/roomserver/internal/input/input.go +++ b/roomserver/internal/input/input.go @@ -24,6 +24,8 @@ import ( "time" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/Arceliar/phony" "github.com/getsentry/sentry-go" @@ -39,7 +41,6 @@ import ( "github.com/matrix-org/dendrite/roomserver/producers" "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/types" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" @@ -74,14 +75,13 @@ import ( // or C. type Inputer struct { Cfg *config.RoomServer - Base *base.BaseDendrite ProcessContext *process.ProcessContext DB storage.RoomDatabase NATSClient *nats.Conn JetStream nats.JetStreamContext Durable nats.SubOpt - ServerName gomatrixserverlib.ServerName - SigningIdentity *gomatrixserverlib.SigningIdentity + ServerName spec.ServerName + SigningIdentity *fclient.SigningIdentity FSAPI fedapi.RoomserverFederationAPI KeyRing gomatrixserverlib.JSONVerifier ACLs *acls.ServerACLs @@ -89,8 +89,9 @@ type Inputer struct { OutputProducer *producers.RoomEventProducer workers sync.Map // room ID -> *worker - Queryer *query.Queryer - UserAPI userapi.RoomserverUserAPI + Queryer *query.Queryer + UserAPI userapi.RoomserverUserAPI + enableMetrics bool } // If a room consumer is inactive for a while then we will allow NATS @@ -177,7 +178,7 @@ func (r *Inputer) startWorkerForRoom(roomID string) { // will look to see if we have a worker for that room which has its // own consumer. If we don't, we'll start one. func (r *Inputer) Start() error { - if r.Base.EnableMetrics { + if r.enableMetrics { prometheus.MustRegister(roomserverInputBackpressure, processRoomEventDuration) } _, err := r.JetStream.Subscribe( @@ -284,7 +285,7 @@ func (w *worker) _next() { var errString string if err = w.r.processRoomEvent( w.r.ProcessContext.Context(), - gomatrixserverlib.ServerName(msg.Header.Get("virtual_host")), + spec.ServerName(msg.Header.Get("virtual_host")), &inputRoomEvent, ); err != nil { switch err.(type) { diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index d709541be..334e68b9a 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -27,8 +27,9 @@ import ( "github.com/tidwall/gjson" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" - "github.com/opentracing/opentracing-go" "github.com/prometheus/client_golang/prometheus" "github.com/sirupsen/logrus" @@ -73,7 +74,7 @@ var processRoomEventDuration = prometheus.NewHistogramVec( // nolint:gocyclo func (r *Inputer) processRoomEvent( ctx context.Context, - virtualHost gomatrixserverlib.ServerName, + virtualHost spec.ServerName, input *api.InputRoomEvent, ) error { select { @@ -85,10 +86,10 @@ func (r *Inputer) processRoomEvent( default: } - span, ctx := opentracing.StartSpanFromContext(ctx, "processRoomEvent") - span.SetTag("room_id", input.Event.RoomID()) - span.SetTag("event_id", input.Event.EventID()) - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "processRoomEvent") + trace.SetTag("room_id", input.Event.RoomID()) + trace.SetTag("event_id", input.Event.EventID()) + defer trace.EndRegion() // Measure how long it takes to process this event. started := time.Now() @@ -123,7 +124,7 @@ func (r *Inputer) processRoomEvent( if rerr != nil { return fmt.Errorf("r.DB.RoomInfo: %w", rerr) } - isCreateEvent := event.Type() == gomatrixserverlib.MRoomCreate && event.StateKeyEquals("") + isCreateEvent := event.Type() == spec.MRoomCreate && event.StateKeyEquals("") if roomInfo == nil && !isCreateEvent { return fmt.Errorf("room %s does not exist for event %s", event.RoomID(), event.EventID()) } @@ -180,7 +181,7 @@ func (r *Inputer) processRoomEvent( // Sort all of the servers into a map so that we can randomise // their order. Then make sure that the input origin and the // event origin are first on the list. - servers := map[gomatrixserverlib.ServerName]struct{}{} + servers := map[spec.ServerName]struct{}{} for _, server := range serverRes.ServerNames { servers[server] = struct{}{} } @@ -476,7 +477,7 @@ func (r *Inputer) processRoomEvent( } // If guest_access changed and is not can_join, kick all guest users. - if event.Type() == gomatrixserverlib.MRoomGuestAccess && gjson.GetBytes(event.Content(), "guest_access").Str != "can_join" { + if event.Type() == spec.MRoomGuestAccess && gjson.GetBytes(event.Content(), "guest_access").Str != "can_join" { if err = r.kickGuests(ctx, event, roomInfo); err != nil { logrus.WithError(err).Error("failed to kick guest users on m.room.guest_access revocation") } @@ -509,7 +510,7 @@ func (r *Inputer) processStateBefore( ) (historyVisibility gomatrixserverlib.HistoryVisibility, rejectionErr error, err error) { historyVisibility = gomatrixserverlib.HistoryVisibilityShared // Default to shared. event := input.Event.Unwrap() - isCreateEvent := event.Type() == gomatrixserverlib.MRoomCreate && event.StateKeyEquals("") + isCreateEvent := event.Type() == spec.MRoomCreate && event.StateKeyEquals("") var stateBeforeEvent []*gomatrixserverlib.Event switch { case isCreateEvent: @@ -546,7 +547,7 @@ func (r *Inputer) processStateBefore( // output events. tuplesNeeded := gomatrixserverlib.StateNeededForAuth([]*gomatrixserverlib.Event{event}).Tuples() tuplesNeeded = append(tuplesNeeded, gomatrixserverlib.StateKeyTuple{ - EventType: gomatrixserverlib.MRoomHistoryVisibility, + EventType: spec.MRoomHistoryVisibility, StateKey: "", }) stateBeforeReq := &api.QueryStateAfterEventsRequest{ @@ -579,7 +580,7 @@ func (r *Inputer) processStateBefore( // Work out what the history visibility was at the time of the // event. for _, event := range stateBeforeEvent { - if event.Type() != gomatrixserverlib.MRoomHistoryVisibility || !event.StateKeyEquals("") { + if event.Type() != spec.MRoomHistoryVisibility || !event.StateKeyEquals("") { continue } if hisVis, err := event.HistoryVisibility(); err == nil { @@ -602,14 +603,14 @@ func (r *Inputer) fetchAuthEvents( ctx context.Context, logger *logrus.Entry, roomInfo *types.RoomInfo, - virtualHost gomatrixserverlib.ServerName, + virtualHost spec.ServerName, event *gomatrixserverlib.HeaderedEvent, auth *gomatrixserverlib.AuthEvents, known map[string]*types.Event, - servers []gomatrixserverlib.ServerName, + servers []spec.ServerName, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "fetchAuthEvents") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "fetchAuthEvents") + defer trace.EndRegion() unknown := map[string]struct{}{} authEventIDs := event.AuthEventIDs() @@ -647,7 +648,7 @@ func (r *Inputer) fetchAuthEvents( } var err error - var res gomatrixserverlib.RespEventAuth + var res fclient.RespEventAuth var found bool for _, serverName := range servers { // Request the entire auth chain for the event in question. This should @@ -753,8 +754,8 @@ func (r *Inputer) calculateAndSetState( event *gomatrixserverlib.Event, isRejected bool, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "calculateAndSetState") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "calculateAndSetState") + defer trace.EndRegion() var succeeded bool updater, err := r.DB.GetRoomUpdater(ctx, roomInfo) @@ -842,12 +843,12 @@ func (r *Inputer) kickGuests(ctx context.Context, event *gomatrixserverlib.Event if err = json.Unmarshal(memberEvent.Content(), &memberContent); err != nil { return err } - memberContent.Membership = gomatrixserverlib.Leave + memberContent.Membership = spec.Leave stateKey := *memberEvent.StateKey() fledglingEvent := &gomatrixserverlib.EventBuilder{ RoomID: event.RoomID(), - Type: gomatrixserverlib.MRoomMember, + Type: spec.MRoomMember, StateKey: &stateKey, Sender: stateKey, PrevEvents: prevEvents, diff --git a/roomserver/internal/input/input_events_test.go b/roomserver/internal/input/input_events_test.go index 818e7715c..9acc435b6 100644 --- a/roomserver/internal/input/input_events_test.go +++ b/roomserver/internal/input/input_events_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/test" ) @@ -21,15 +22,15 @@ func Test_EventAuth(t *testing.T) { // Add the legal auth events from room2 for _, x := range room2.Events() { - if x.Type() == gomatrixserverlib.MRoomCreate { + if x.Type() == spec.MRoomCreate { authEventIDs = append(authEventIDs, x.EventID()) authEvents = append(authEvents, x.Event) } - if x.Type() == gomatrixserverlib.MRoomPowerLevels { + if x.Type() == spec.MRoomPowerLevels { authEventIDs = append(authEventIDs, x.EventID()) authEvents = append(authEvents, x.Event) } - if x.Type() == gomatrixserverlib.MRoomJoinRules { + if x.Type() == spec.MRoomJoinRules { authEventIDs = append(authEventIDs, x.EventID()) authEvents = append(authEvents, x.Event) } @@ -37,7 +38,7 @@ func Test_EventAuth(t *testing.T) { // Add the illegal auth event from room1 (rooms are different) for _, x := range room1.Events() { - if x.Type() == gomatrixserverlib.MRoomMember { + if x.Type() == spec.MRoomMember { authEventIDs = append(authEventIDs, x.EventID()) authEvents = append(authEvents, x.Event) } diff --git a/roomserver/internal/input/input_latest_events.go b/roomserver/internal/input/input_latest_events.go index a223820ef..09db18431 100644 --- a/roomserver/internal/input/input_latest_events.go +++ b/roomserver/internal/input/input_latest_events.go @@ -23,9 +23,9 @@ import ( "github.com/getsentry/sentry-go" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" - "github.com/opentracing/opentracing-go" "github.com/sirupsen/logrus" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/state" @@ -59,8 +59,8 @@ func (r *Inputer) updateLatestEvents( rewritesState bool, historyVisibility gomatrixserverlib.HistoryVisibility, ) (err error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "updateLatestEvents") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "updateLatestEvents") + defer trace.EndRegion() var succeeded bool updater, err := r.DB.GetRoomUpdater(ctx, roomInfo) @@ -209,8 +209,8 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error { } func (u *latestEventsUpdater) latestState() error { - span, ctx := opentracing.StartSpanFromContext(u.ctx, "processEventWithMissingState") - defer span.Finish() + trace, ctx := internal.StartRegion(u.ctx, "processEventWithMissingState") + defer trace.EndRegion() var err error roomState := state.NewStateResolution(u.updater, u.roomInfo) @@ -329,8 +329,8 @@ func (u *latestEventsUpdater) calculateLatest( newEvent *gomatrixserverlib.Event, newStateAndRef types.StateAtEventAndReference, ) (bool, error) { - span, _ := opentracing.StartSpanFromContext(u.ctx, "calculateLatest") - defer span.Finish() + trace, _ := internal.StartRegion(u.ctx, "calculateLatest") + defer trace.EndRegion() // First of all, get a list of all of the events in our current // set of forward extremities. diff --git a/roomserver/internal/input/input_membership.go b/roomserver/internal/input/input_membership.go index e1dfa6cfa..947f6c150 100644 --- a/roomserver/internal/input/input_membership.go +++ b/roomserver/internal/input/input_membership.go @@ -18,13 +18,15 @@ import ( "context" "fmt" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" + + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/internal/helpers" "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" - "github.com/matrix-org/gomatrixserverlib" - "github.com/opentracing/opentracing-go" ) // updateMembership updates the current membership and the invites for each @@ -36,8 +38,8 @@ func (r *Inputer) updateMemberships( updater *shared.RoomUpdater, removed, added []types.StateEntry, ) ([]api.OutputEvent, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "updateMemberships") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "updateMemberships") + defer trace.EndRegion() changes := membershipChanges(removed, added) var eventNIDs []types.EventNID @@ -85,7 +87,7 @@ func (r *Inputer) updateMembership( ) ([]api.OutputEvent, error) { var err error // Default the membership to Leave if no event was added or removed. - newMembership := gomatrixserverlib.Leave + newMembership := spec.Leave if add != nil { newMembership, err = add.Membership() if err != nil { @@ -119,13 +121,13 @@ func (r *Inputer) updateMembership( } switch newMembership { - case gomatrixserverlib.Invite: + case spec.Invite: return helpers.UpdateToInviteMembership(mu, add, updates, updater.RoomVersion()) - case gomatrixserverlib.Join: + case spec.Join: return updateToJoinMembership(mu, add, updates) - case gomatrixserverlib.Leave, gomatrixserverlib.Ban: + case spec.Leave, spec.Ban: return updateToLeaveMembership(mu, add, newMembership, updates) - case gomatrixserverlib.Knock: + case spec.Knock: return updateToKnockMembership(mu, add, updates) default: panic(fmt.Errorf( @@ -159,7 +161,7 @@ func updateToJoinMembership( Type: api.OutputTypeRetireInviteEvent, RetireInviteEvent: &api.OutputRetireInviteEvent{ EventID: eventID, - Membership: gomatrixserverlib.Join, + Membership: spec.Join, RetiredByEventID: add.EventID(), TargetUserID: *add.StateKey(), }, diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go index 9627f15ac..c7b647978 100644 --- a/roomserver/internal/input/input_missing.go +++ b/roomserver/internal/input/input_missing.go @@ -7,16 +7,18 @@ import ( "sync" "time" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" + "github.com/matrix-org/util" + "github.com/sirupsen/logrus" + fedapi "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/types" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" - "github.com/opentracing/opentracing-go" - "github.com/sirupsen/logrus" ) type parsedRespState struct { @@ -41,15 +43,15 @@ func (p *parsedRespState) Events() []*gomatrixserverlib.Event { type missingStateReq struct { log *logrus.Entry - virtualHost gomatrixserverlib.ServerName - origin gomatrixserverlib.ServerName + virtualHost spec.ServerName + origin spec.ServerName db storage.RoomDatabase roomInfo *types.RoomInfo inputer *Inputer keys gomatrixserverlib.JSONVerifier federation fedapi.RoomserverFederationAPI roomsMu *internal.MutexByRoom - servers []gomatrixserverlib.ServerName + servers []spec.ServerName hadEvents map[string]bool hadEventsMutex sync.Mutex haveEvents map[string]*gomatrixserverlib.Event @@ -62,8 +64,8 @@ type missingStateReq struct { func (t *missingStateReq) processEventWithMissingState( ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion, ) (*parsedRespState, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "processEventWithMissingState") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "processEventWithMissingState") + defer trace.EndRegion() // We are missing the previous events for this events. // This means that there is a gap in our view of the history of the @@ -241,8 +243,8 @@ func (t *missingStateReq) processEventWithMissingState( } func (t *missingStateReq) lookupResolvedStateBeforeEvent(ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion) (*parsedRespState, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "lookupResolvedStateBeforeEvent") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "lookupResolvedStateBeforeEvent") + defer trace.EndRegion() type respState struct { // A snapshot is considered trustworthy if it came from our own roomserver. @@ -278,7 +280,7 @@ func (t *missingStateReq) lookupResolvedStateBeforeEvent(ctx context.Context, e resolvedState := &parsedRespState{} switch len(states) { case 0: - extremityIsCreate := e.Type() == gomatrixserverlib.MRoomCreate && e.StateKeyEquals("") + extremityIsCreate := e.Type() == spec.MRoomCreate && e.StateKeyEquals("") if !extremityIsCreate { // There are no previous states and this isn't the beginning of the // room - this is an error condition! @@ -290,7 +292,7 @@ func (t *missingStateReq) lookupResolvedStateBeforeEvent(ctx context.Context, e // use it as-is. There's no point in resolving it again. Only trust a // trustworthy state snapshot if it actually contains some state for all // non-create events, otherwise we need to resolve what came from federation. - isCreate := e.Type() == gomatrixserverlib.MRoomCreate && e.StateKeyEquals("") + isCreate := e.Type() == spec.MRoomCreate && e.StateKeyEquals("") if states[0].trustworthy && (isCreate || len(states[0].StateEvents) > 0) { resolvedState = states[0].parsedRespState break @@ -319,8 +321,8 @@ func (t *missingStateReq) lookupResolvedStateBeforeEvent(ctx context.Context, e // lookupStateAfterEvent returns the room state after `eventID`, which is the state before eventID with the state of `eventID` (if it's a state event) // added into the mix. func (t *missingStateReq) lookupStateAfterEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (*parsedRespState, bool, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "lookupStateAfterEvent") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "lookupStateAfterEvent") + defer trace.EndRegion() // try doing all this locally before we resort to querying federation respState := t.lookupStateAfterEventLocally(ctx, eventID) @@ -376,8 +378,8 @@ func (t *missingStateReq) cacheAndReturn(ev *gomatrixserverlib.Event) *gomatrixs } func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, eventID string) *parsedRespState { - span, ctx := opentracing.StartSpanFromContext(ctx, "lookupStateAfterEventLocally") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "lookupStateAfterEventLocally") + defer trace.EndRegion() var res parsedRespState roomState := state.NewStateResolution(t.db, t.roomInfo) @@ -449,16 +451,16 @@ func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, even // the server supports. func (t *missingStateReq) lookupStateBeforeEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) ( *parsedRespState, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "lookupStateBeforeEvent") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "lookupStateBeforeEvent") + defer trace.EndRegion() // Attempt to fetch the missing state using /state_ids and /events return t.lookupMissingStateViaStateIDs(ctx, roomID, eventID, roomVersion) } func (t *missingStateReq) resolveStatesAndCheck(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, states []*parsedRespState, backwardsExtremity *gomatrixserverlib.Event) (*parsedRespState, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "resolveStatesAndCheck") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "resolveStatesAndCheck") + defer trace.EndRegion() var authEventList []*gomatrixserverlib.Event var stateEventList []*gomatrixserverlib.Event @@ -503,8 +505,8 @@ retryAllowedState: // get missing events for `e`. If `isGapFilled`=true then `newEvents` contains all the events to inject, // without `e`. If `isGapFilled=false` then `newEvents` contains the response to /get_missing_events func (t *missingStateReq) getMissingEvents(ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion) (newEvents []*gomatrixserverlib.Event, isGapFilled, prevStateKnown bool, err error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "getMissingEvents") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "getMissingEvents") + defer trace.EndRegion() logger := t.log.WithField("event_id", e.EventID()).WithField("room_id", e.RoomID()) latest, _, _, err := t.db.LatestEventIDs(ctx, t.roomInfo.RoomNID) @@ -517,10 +519,10 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e *gomatrixserve t.hadEvent(ev.EventID) } - var missingResp *gomatrixserverlib.RespMissingEvents + var missingResp *fclient.RespMissingEvents for _, server := range t.servers { - var m gomatrixserverlib.RespMissingEvents - if m, err = t.federation.LookupMissingEvents(ctx, t.virtualHost, server, e.RoomID(), gomatrixserverlib.MissingEvents{ + var m fclient.RespMissingEvents + if m, err = t.federation.LookupMissingEvents(ctx, t.virtualHost, server, e.RoomID(), fclient.MissingEvents{ Limit: 20, // The latest event IDs that the sender already has. These are skipped when retrieving the previous events of latest_events. EarliestEvents: latestEvents, @@ -596,7 +598,7 @@ Event: // If we retrieved back to the beginning of the room then there's nothing else // to do - we closed the gap. - if len(earliestNewEvent.PrevEventIDs()) == 0 && earliestNewEvent.Type() == gomatrixserverlib.MRoomCreate && earliestNewEvent.StateKeyEquals("") { + if len(earliestNewEvent.PrevEventIDs()) == 0 && earliestNewEvent.Type() == spec.MRoomCreate && earliestNewEvent.StateKeyEquals("") { return newEvents, true, t.isPrevStateKnown(ctx, e), nil } @@ -633,44 +635,36 @@ func (t *missingStateReq) isPrevStateKnown(ctx context.Context, e *gomatrixserve func (t *missingStateReq) lookupMissingStateViaState( ctx context.Context, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, ) (respState *parsedRespState, err error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "lookupMissingStateViaState") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "lookupMissingStateViaState") + defer trace.EndRegion() state, err := t.federation.LookupState(ctx, t.virtualHost, t.origin, roomID, eventID, roomVersion) if err != nil { return nil, err } + // Check that the returned state is valid. - authEvents, stateEvents, err := state.Check(ctx, roomVersion, t.keys, nil) + authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse(ctx, &fclient.RespState{ + StateEvents: state.GetStateEvents(), + AuthEvents: state.GetAuthEvents(), + }, roomVersion, t.keys, nil) if err != nil { return nil, err } - parsedState := &parsedRespState{ + return &parsedRespState{ AuthEvents: authEvents, StateEvents: stateEvents, - } - // Cache the results of this state lookup and deduplicate anything we already - // have in the cache, freeing up memory. - // We load these as trusted as we called state.Check before which loaded them as untrusted. - for i, evJSON := range state.AuthEvents { - ev, _ := gomatrixserverlib.NewEventFromTrustedJSON(evJSON, false, roomVersion) - parsedState.AuthEvents[i] = t.cacheAndReturn(ev) - } - for i, evJSON := range state.StateEvents { - ev, _ := gomatrixserverlib.NewEventFromTrustedJSON(evJSON, false, roomVersion) - parsedState.StateEvents[i] = t.cacheAndReturn(ev) - } - return parsedState, nil + }, nil } func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( *parsedRespState, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "lookupMissingStateViaStateIDs") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "lookupMissingStateViaStateIDs") + defer trace.EndRegion() t.log.Infof("lookupMissingStateViaStateIDs %s", eventID) // fetch the state event IDs at the time of the event - var stateIDs gomatrixserverlib.RespStateIDs + var stateIDs gomatrixserverlib.StateIDResponse var err error count := 0 totalctx, totalcancel := context.WithTimeout(ctx, time.Minute*5) @@ -688,7 +682,7 @@ func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roo return nil, fmt.Errorf("t.federation.LookupStateIDs tried %d server(s), last error: %w", count, err) } // work out which auth/state IDs are missing - wantIDs := append(stateIDs.StateEventIDs, stateIDs.AuthEventIDs...) + wantIDs := append(stateIDs.GetStateEventIDs(), stateIDs.GetAuthEventIDs()...) missing := make(map[string]bool) var missingEventList []string t.haveEventsMutex.Lock() @@ -730,8 +724,8 @@ func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roo t.log.WithFields(logrus.Fields{ "missing": missingCount, "event_id": eventID, - "total_state": len(stateIDs.StateEventIDs), - "total_auth_events": len(stateIDs.AuthEventIDs), + "total_state": len(stateIDs.GetStateEventIDs()), + "total_auth_events": len(stateIDs.GetAuthEventIDs()), }).Debug("Fetching all state at event") return t.lookupMissingStateViaState(ctx, roomID, eventID, roomVersion) } @@ -740,8 +734,8 @@ func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roo t.log.WithFields(logrus.Fields{ "missing": missingCount, "event_id": eventID, - "total_state": len(stateIDs.StateEventIDs), - "total_auth_events": len(stateIDs.AuthEventIDs), + "total_state": len(stateIDs.GetStateEventIDs()), + "total_auth_events": len(stateIDs.GetAuthEventIDs()), "concurrent_requests": concurrentRequests, }).Debug("Fetching missing state at event") @@ -808,7 +802,7 @@ func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roo } func (t *missingStateReq) createRespStateFromStateIDs( - stateIDs gomatrixserverlib.RespStateIDs, + stateIDs gomatrixserverlib.StateIDResponse, ) (*parsedRespState, error) { // nolint:unparam t.haveEventsMutex.Lock() defer t.haveEventsMutex.Unlock() @@ -816,18 +810,20 @@ func (t *missingStateReq) createRespStateFromStateIDs( // create a RespState response using the response to /state_ids as a guide respState := parsedRespState{} - for i := range stateIDs.StateEventIDs { - ev, ok := t.haveEvents[stateIDs.StateEventIDs[i]] + stateEventIDs := stateIDs.GetStateEventIDs() + authEventIDs := stateIDs.GetAuthEventIDs() + for i := range stateEventIDs { + ev, ok := t.haveEvents[stateEventIDs[i]] if !ok { - logrus.Tracef("Missing state event in createRespStateFromStateIDs: %s", stateIDs.StateEventIDs[i]) + logrus.Tracef("Missing state event in createRespStateFromStateIDs: %s", stateEventIDs[i]) continue } respState.StateEvents = append(respState.StateEvents, ev) } - for i := range stateIDs.AuthEventIDs { - ev, ok := t.haveEvents[stateIDs.AuthEventIDs[i]] + for i := range authEventIDs { + ev, ok := t.haveEvents[authEventIDs[i]] if !ok { - logrus.Tracef("Missing auth event in createRespStateFromStateIDs: %s", stateIDs.AuthEventIDs[i]) + logrus.Tracef("Missing auth event in createRespStateFromStateIDs: %s", authEventIDs[i]) continue } respState.AuthEvents = append(respState.AuthEvents, ev) @@ -839,8 +835,8 @@ func (t *missingStateReq) createRespStateFromStateIDs( } func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, _, missingEventID string, localFirst bool) (*gomatrixserverlib.Event, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "lookupEvent") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "lookupEvent") + defer trace.EndRegion() if localFirst { // fetch from the roomserver diff --git a/roomserver/internal/input/input_test.go b/roomserver/internal/input/input_test.go index 4708560ac..51c50c37a 100644 --- a/roomserver/internal/input/input_test.go +++ b/roomserver/internal/input/input_test.go @@ -2,82 +2,69 @@ package input_test import ( "context" - "os" "testing" "time" "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/internal/input" - "github.com/matrix-org/dendrite/roomserver/storage" - "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/gomatrixserverlib" - "github.com/nats-io/nats.go" ) -var js nats.JetStreamContext -var jc *nats.Conn - -func TestMain(m *testing.M) { - var b *base.BaseDendrite - b, js, jc = testrig.Base(nil) - code := m.Run() - b.ShutdownDendrite() - b.WaitForComponentsToFinish() - os.Exit(code) -} - func TestSingleTransactionOnInput(t *testing.T) { - deadline, _ := t.Deadline() - if max := time.Now().Add(time.Second * 3); deadline.After(max) { - deadline = max - } - ctx, cancel := context.WithDeadline(context.Background(), deadline) - defer cancel() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + defer close() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) - event, err := gomatrixserverlib.NewEventFromTrustedJSON( - []byte(`{"auth_events":[],"content":{"creator":"@neilalexander:dendrite.matrix.org","room_version":"6"},"depth":1,"hashes":{"sha256":"jqOqdNEH5r0NiN3xJtj0u5XUVmRqq9YvGbki1wxxuuM"},"origin":"dendrite.matrix.org","origin_server_ts":1644595362726,"prev_events":[],"prev_state":[],"room_id":"!jSZZRknA6GkTBXNP:dendrite.matrix.org","sender":"@neilalexander:dendrite.matrix.org","signatures":{"dendrite.matrix.org":{"ed25519:6jB2aB":"bsQXO1wketf1OSe9xlndDIWe71W9KIundc6rBw4KEZdGPW7x4Tv4zDWWvbxDsG64sS2IPWfIm+J0OOozbrWIDw"}},"state_key":"","type":"m.room.create"}`), - false, gomatrixserverlib.RoomVersionV6, - ) - if err != nil { - t.Fatal(err) - } - in := api.InputRoomEvent{ - Kind: api.KindOutlier, // don't panic if we generate an output event - Event: event.Headered(gomatrixserverlib.RoomVersionV6), - } - db, err := storage.Open( - nil, - &config.DatabaseOptions{ - ConnectionString: "", - MaxOpenConnections: 1, - MaxIdleConnections: 1, - }, - caching.NewRistrettoCache(8*1024*1024, time.Hour, false), - ) - if err != nil { - t.Logf("PostgreSQL not available (%s), skipping", err) - t.SkipNow() - } - inputter := &input.Inputer{ - DB: db, - JetStream: js, - NATSClient: jc, - } - res := &api.InputRoomEventsResponse{} - inputter.InputRoomEvents( - ctx, - &api.InputRoomEventsRequest{ - InputRoomEvents: []api.InputRoomEvent{in}, - Asynchronous: false, - }, - res, - ) - // If we fail here then it's because we've hit the test deadline, - // so we probably deadlocked - if err := res.Err(); err != nil { - t.Fatal(err) - } + natsInstance := &jetstream.NATSInstance{} + js, jc := natsInstance.Prepare(processCtx, &cfg.Global.JetStream) + caches := caching.NewRistrettoCache(8*1024*1024, time.Hour, caching.DisableMetrics) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, natsInstance, caches, caching.DisableMetrics) + rsAPI.SetFederationAPI(nil, nil) + + deadline, _ := t.Deadline() + if max := time.Now().Add(time.Second * 3); deadline.Before(max) { + deadline = max + } + ctx, cancel := context.WithDeadline(processCtx.Context(), deadline) + defer cancel() + + event, err := gomatrixserverlib.NewEventFromTrustedJSON( + []byte(`{"auth_events":[],"content":{"creator":"@neilalexander:dendrite.matrix.org","room_version":"6"},"depth":1,"hashes":{"sha256":"jqOqdNEH5r0NiN3xJtj0u5XUVmRqq9YvGbki1wxxuuM"},"origin":"dendrite.matrix.org","origin_server_ts":1644595362726,"prev_events":[],"prev_state":[],"room_id":"!jSZZRknA6GkTBXNP:dendrite.matrix.org","sender":"@neilalexander:dendrite.matrix.org","signatures":{"dendrite.matrix.org":{"ed25519:6jB2aB":"bsQXO1wketf1OSe9xlndDIWe71W9KIundc6rBw4KEZdGPW7x4Tv4zDWWvbxDsG64sS2IPWfIm+J0OOozbrWIDw"}},"state_key":"","type":"m.room.create"}`), + false, gomatrixserverlib.RoomVersionV6, + ) + if err != nil { + t.Fatal(err) + } + in := api.InputRoomEvent{ + Kind: api.KindOutlier, // don't panic if we generate an output event + Event: event.Headered(gomatrixserverlib.RoomVersionV6), + } + + inputter := &input.Inputer{ + JetStream: js, + NATSClient: jc, + Cfg: &cfg.RoomServer, + } + res := &api.InputRoomEventsResponse{} + inputter.InputRoomEvents( + ctx, + &api.InputRoomEventsRequest{ + InputRoomEvents: []api.InputRoomEvent{in}, + Asynchronous: false, + }, + res, + ) + // If we fail here then it's because we've hit the test deadline, + // so we probably deadlocked + if err := res.Err(); err != nil { + t.Fatal(err) + } + }) } diff --git a/roomserver/internal/perform/perform_admin.go b/roomserver/internal/perform/perform_admin.go index 45089bdd1..e7b9db1ff 100644 --- a/roomserver/internal/perform/perform_admin.go +++ b/roomserver/internal/perform/perform_admin.go @@ -28,6 +28,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/sirupsen/logrus" ) @@ -107,12 +108,12 @@ func (r *Admin) PerformAdminEvacuateRoom( } return nil } - memberContent.Membership = gomatrixserverlib.Leave + memberContent.Membership = spec.Leave stateKey := *memberEvent.StateKey() fledglingEvent := &gomatrixserverlib.EventBuilder{ RoomID: req.RoomID, - Type: gomatrixserverlib.MRoomMember, + Type: spec.MRoomMember, StateKey: &stateKey, Sender: stateKey, PrevEvents: prevEvents, @@ -195,7 +196,7 @@ func (r *Admin) PerformAdminEvacuateUser( return nil } - roomIDs, err := r.DB.GetRoomsByMembership(ctx, req.UserID, gomatrixserverlib.Join) + roomIDs, err := r.DB.GetRoomsByMembership(ctx, req.UserID, spec.Join) if err != nil && err != sql.ErrNoRows { res.Error = &api.PerformError{ Code: api.PerformErrorBadRequest, @@ -204,7 +205,7 @@ func (r *Admin) PerformAdminEvacuateUser( return nil } - inviteRoomIDs, err := r.DB.GetRoomsByMembership(ctx, req.UserID, gomatrixserverlib.Invite) + inviteRoomIDs, err := r.DB.GetRoomsByMembership(ctx, req.UserID, spec.Invite) if err != nil && err != sql.ErrNoRows { res.Error = &api.PerformError{ Code: api.PerformErrorBadRequest, @@ -227,6 +228,7 @@ func (r *Admin) PerformAdminEvacuateUser( } return nil } + res.Affected = append(res.Affected, roomID) if len(outputEvents) == 0 { continue } @@ -237,8 +239,6 @@ func (r *Admin) PerformAdminEvacuateUser( } return nil } - - res.Affected = append(res.Affected, roomID) } return nil } @@ -323,7 +323,7 @@ func (r *Admin) PerformAdminDownloadState( stateEventMap := map[string]*gomatrixserverlib.Event{} for _, fwdExtremity := range fwdExtremities { - var state gomatrixserverlib.RespState + var state gomatrixserverlib.StateResponse state, err = r.Inputer.FSAPI.LookupState(ctx, r.Inputer.ServerName, req.ServerName, req.RoomID, fwdExtremity.EventID, roomInfo.RoomVersion) if err != nil { res.Error = &api.PerformError{ @@ -332,13 +332,13 @@ func (r *Admin) PerformAdminDownloadState( } return nil } - for _, authEvent := range state.AuthEvents.UntrustedEvents(roomInfo.RoomVersion) { + for _, authEvent := range state.GetAuthEvents().UntrustedEvents(roomInfo.RoomVersion) { if err = authEvent.VerifyEventSignatures(ctx, r.Inputer.KeyRing); err != nil { continue } authEventMap[authEvent.EventID()] = authEvent } - for _, stateEvent := range state.StateEvents.UntrustedEvents(roomInfo.RoomVersion) { + for _, stateEvent := range state.GetStateEvents().UntrustedEvents(roomInfo.RoomVersion) { if err = stateEvent.VerifyEventSignatures(ctx, r.Inputer.KeyRing); err != nil { continue } @@ -362,7 +362,7 @@ func (r *Admin) PerformAdminDownloadState( Type: "org.matrix.dendrite.state_download", Sender: req.UserID, RoomID: req.RoomID, - Content: gomatrixserverlib.RawJSON("{}"), + Content: spec.RawJSON("{}"), } eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder) diff --git a/roomserver/internal/perform/perform_backfill.go b/roomserver/internal/perform/perform_backfill.go index 07ad28b82..d9a2394a8 100644 --- a/roomserver/internal/perform/perform_backfill.go +++ b/roomserver/internal/perform/perform_backfill.go @@ -19,6 +19,7 @@ import ( "fmt" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/sirupsen/logrus" @@ -37,13 +38,13 @@ import ( const maxBackfillServers = 5 type Backfiller struct { - IsLocalServerName func(gomatrixserverlib.ServerName) bool + IsLocalServerName func(spec.ServerName) bool DB storage.Database FSAPI federationAPI.RoomserverFederationAPI KeyRing gomatrixserverlib.JSONVerifier // The servers which should be preferred above other servers when backfilling - PreferServers []gomatrixserverlib.ServerName + PreferServers []spec.ServerName } // PerformBackfill implements api.RoomServerQueryAPI @@ -133,9 +134,6 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform // persist these new events - auth checks have already been done roomNID, backfilledEventMap := persistEvents(ctx, r.DB, events) - if err != nil { - return err - } for _, ev := range backfilledEventMap { // now add state for these events @@ -178,7 +176,7 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform // fetchAndStoreMissingEvents does a best-effort fetch and store of missing events specified in stateIDs. Returns no error as it is just // best effort. func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gomatrixserverlib.RoomVersion, - backfillRequester *backfillRequester, stateIDs []string, virtualHost gomatrixserverlib.ServerName) { + backfillRequester *backfillRequester, stateIDs []string, virtualHost spec.ServerName) { servers := backfillRequester.servers @@ -248,13 +246,13 @@ func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gom type backfillRequester struct { db storage.Database fsAPI federationAPI.RoomserverFederationAPI - virtualHost gomatrixserverlib.ServerName - isLocalServerName func(gomatrixserverlib.ServerName) bool - preferServer map[gomatrixserverlib.ServerName]bool + virtualHost spec.ServerName + isLocalServerName func(spec.ServerName) bool + preferServer map[spec.ServerName]bool bwExtrems map[string][]string // per-request state - servers []gomatrixserverlib.ServerName + servers []spec.ServerName eventIDToBeforeStateIDs map[string][]string eventIDMap map[string]*gomatrixserverlib.Event historyVisiblity gomatrixserverlib.HistoryVisibility @@ -263,11 +261,11 @@ type backfillRequester struct { func newBackfillRequester( db storage.Database, fsAPI federationAPI.RoomserverFederationAPI, - virtualHost gomatrixserverlib.ServerName, - isLocalServerName func(gomatrixserverlib.ServerName) bool, - bwExtrems map[string][]string, preferServers []gomatrixserverlib.ServerName, + virtualHost spec.ServerName, + isLocalServerName func(spec.ServerName) bool, + bwExtrems map[string][]string, preferServers []spec.ServerName, ) *backfillRequester { - preferServer := make(map[gomatrixserverlib.ServerName]bool) + preferServer := make(map[spec.ServerName]bool) for _, p := range preferServers { preferServer[p] = true } @@ -418,7 +416,7 @@ func (b *backfillRequester) StateBeforeEvent(ctx context.Context, roomVer gomatr // It returns a list of servers which can be queried for backfill requests. These servers // will be servers that are in the room already. The entries at the beginning are preferred servers // and will be tried first. An empty list will fail the request. -func (b *backfillRequester) ServersAtEvent(ctx context.Context, roomID, eventID string) []gomatrixserverlib.ServerName { +func (b *backfillRequester) ServersAtEvent(ctx context.Context, roomID, eventID string) []spec.ServerName { // eventID will be a prev_event ID of a backwards extremity, meaning we will not have a database entry for it. Instead, use // its successor, so look it up. successor := "" @@ -481,19 +479,19 @@ FindSuccessor: memberEvents = append(memberEvents, memberEventsFromVis...) // Store the server names in a temporary map to avoid duplicates. - serverSet := make(map[gomatrixserverlib.ServerName]bool) + serverSet := make(map[spec.ServerName]bool) for _, event := range memberEvents { if _, senderDomain, err := gomatrixserverlib.SplitID('@', event.Sender()); err == nil { serverSet[senderDomain] = true } } - var servers []gomatrixserverlib.ServerName + var servers []spec.ServerName for server := range serverSet { if b.isLocalServerName(server) { continue } if b.preferServer[server] { // insert at the front - servers = append([]gomatrixserverlib.ServerName{server}, servers...) + servers = append([]spec.ServerName{server}, servers...) } else { // insert at the back servers = append(servers, server) } @@ -508,7 +506,7 @@ FindSuccessor: // Backfill performs a backfill request to the given server. // https://matrix.org/docs/spec/server_server/latest#get-matrix-federation-v1-backfill-roomid -func (b *backfillRequester) Backfill(ctx context.Context, origin, server gomatrixserverlib.ServerName, roomID string, +func (b *backfillRequester) Backfill(ctx context.Context, origin, server spec.ServerName, roomID string, limit int, fromEventIDs []string) (gomatrixserverlib.Transaction, error) { tx, err := b.fsAPI.Backfill(ctx, origin, server, roomID, limit, fromEventIDs) @@ -550,7 +548,7 @@ func (b *backfillRequester) ProvideEvents(roomVer gomatrixserverlib.RoomVersion, // pull all events and then filter by that table. func joinEventsFromHistoryVisibility( ctx context.Context, db storage.RoomDatabase, roomInfo *types.RoomInfo, stateEntries []types.StateEntry, - thisServer gomatrixserverlib.ServerName) ([]types.Event, gomatrixserverlib.HistoryVisibility, error) { + thisServer spec.ServerName) ([]types.Event, gomatrixserverlib.HistoryVisibility, error) { var eventNIDs []types.EventNID for _, entry := range stateEntries { @@ -611,6 +609,7 @@ func persistEvents(ctx context.Context, db storage.Database, events []*gomatrixs logrus.WithError(err).Error("failed to get or create roomNID") continue } + roomNID = roomInfo.RoomNID eventTypeNID, err := db.GetOrCreateEventTypeNID(ctx, ev.Type()) if err != nil { diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go index 13d13f7b5..a23cdea10 100644 --- a/roomserver/internal/perform/perform_invite.go +++ b/roomserver/internal/perform/perform_invite.go @@ -28,6 +28,8 @@ import ( "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" log "github.com/sirupsen/logrus" ) @@ -95,7 +97,7 @@ func (r *Inviter) PerformInvite( inviteState := req.InviteRoomState if len(inviteState) == 0 && info != nil { - var is []gomatrixserverlib.InviteV2StrippedState + var is []fclient.InviteV2StrippedState if is, err = buildInviteStrippedState(ctx, r.DB, info, req); err == nil { inviteState = is } @@ -266,14 +268,14 @@ func buildInviteStrippedState( db storage.Database, info *types.RoomInfo, input *api.PerformInviteRequest, -) ([]gomatrixserverlib.InviteV2StrippedState, error) { +) ([]fclient.InviteV2StrippedState, error) { stateWanted := []gomatrixserverlib.StateKeyTuple{} // "If they are set on the room, at least the state for m.room.avatar, m.room.canonical_alias, m.room.join_rules, and m.room.name SHOULD be included." // https://matrix.org/docs/spec/client_server/r0.6.0#m-room-member for _, t := range []string{ - gomatrixserverlib.MRoomName, gomatrixserverlib.MRoomCanonicalAlias, - gomatrixserverlib.MRoomJoinRules, gomatrixserverlib.MRoomAvatar, - gomatrixserverlib.MRoomEncryption, gomatrixserverlib.MRoomCreate, + spec.MRoomName, spec.MRoomCanonicalAlias, + spec.MRoomJoinRules, spec.MRoomAvatar, + spec.MRoomEncryption, spec.MRoomCreate, } { stateWanted = append(stateWanted, gomatrixserverlib.StateKeyTuple{ EventType: t, @@ -295,12 +297,12 @@ func buildInviteStrippedState( if err != nil { return nil, err } - inviteState := []gomatrixserverlib.InviteV2StrippedState{ - gomatrixserverlib.NewInviteV2StrippedState(input.Event.Event), + inviteState := []fclient.InviteV2StrippedState{ + fclient.NewInviteV2StrippedState(input.Event.Event), } stateEvents = append(stateEvents, types.Event{Event: input.Event.Unwrap()}) for _, event := range stateEvents { - inviteState = append(inviteState, gomatrixserverlib.NewInviteV2StrippedState(event.Event)) + inviteState = append(inviteState, fclient.NewInviteV2StrippedState(event.Event)) } return inviteState, nil } diff --git a/roomserver/internal/perform/perform_join.go b/roomserver/internal/perform/perform_join.go index fc7ba940c..9b0895e9f 100644 --- a/roomserver/internal/perform/perform_join.go +++ b/roomserver/internal/perform/perform_join.go @@ -24,6 +24,7 @@ import ( "github.com/getsentry/sentry-go" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/sirupsen/logrus" "github.com/tidwall/gjson" @@ -35,7 +36,6 @@ import ( "github.com/matrix-org/dendrite/roomserver/internal/input" "github.com/matrix-org/dendrite/roomserver/internal/query" "github.com/matrix-org/dendrite/roomserver/storage" - "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup/config" ) @@ -84,7 +84,7 @@ func (r *Joiner) PerformJoin( func (r *Joiner) performJoin( ctx context.Context, req *rsAPI.PerformJoinRequest, -) (string, gomatrixserverlib.ServerName, error) { +) (string, spec.ServerName, error) { _, domain, err := gomatrixserverlib.SplitID('@', req.UserID) if err != nil { return "", "", &rsAPI.PerformError{ @@ -113,7 +113,7 @@ func (r *Joiner) performJoin( func (r *Joiner) performJoinRoomByAlias( ctx context.Context, req *rsAPI.PerformJoinRequest, -) (string, gomatrixserverlib.ServerName, error) { +) (string, spec.ServerName, error) { // Get the domain part of the room alias. _, domain, err := gomatrixserverlib.SplitID('#', req.RoomIDOrAlias) if err != nil { @@ -168,7 +168,7 @@ func (r *Joiner) performJoinRoomByAlias( func (r *Joiner) performJoinRoomByID( ctx context.Context, req *rsAPI.PerformJoinRequest, -) (string, gomatrixserverlib.ServerName, error) { +) (string, spec.ServerName, error) { // The original client request ?server_name=... may include this HS so filter that out so we // don't attempt to make_join with ourselves for i := 0; i < len(req.ServerNames); i++ { @@ -205,7 +205,7 @@ func (r *Joiner) performJoinRoomByID( } } eb := gomatrixserverlib.EventBuilder{ - Type: gomatrixserverlib.MRoomMember, + Type: spec.MRoomMember, Sender: userID, StateKey: &userID, RoomID: req.RoomIDOrAlias, @@ -221,7 +221,7 @@ func (r *Joiner) performJoinRoomByID( if req.Content == nil { req.Content = map[string]interface{}{} } - req.Content["membership"] = gomatrixserverlib.Join + req.Content["membership"] = spec.Join if authorisedVia, aerr := r.populateAuthorisedViaUserForRestrictedJoin(ctx, req); aerr != nil { return "", "", aerr } else if authorisedVia != "" { @@ -275,7 +275,7 @@ func (r *Joiner) performJoinRoomByID( if req.IsGuest { var guestAccessEvent *gomatrixserverlib.HeaderedEvent guestAccess := "forbidden" - guestAccessEvent, err = r.DB.GetStateEvent(ctx, req.RoomIDOrAlias, gomatrixserverlib.MRoomGuestAccess, "") + guestAccessEvent, err = r.DB.GetStateEvent(ctx, req.RoomIDOrAlias, spec.MRoomGuestAccess, "") if (err != nil && !errors.Is(err, sql.ErrNoRows)) || guestAccessEvent == nil { logrus.WithError(err).Warn("unable to get m.room.guest_access event, defaulting to 'forbidden'") } @@ -294,7 +294,7 @@ func (r *Joiner) performJoinRoomByID( } // If we should do a forced federated join then do that. - var joinedVia gomatrixserverlib.ServerName + var joinedVia spec.ServerName if forceFederatedJoin { joinedVia, err = r.performFederatedJoinRoomByID(ctx, req) return req.RoomIDOrAlias, joinedVia, err @@ -305,7 +305,13 @@ func (r *Joiner) performJoinRoomByID( // locally on the homeserver. // TODO: Check what happens if the room exists on the server // but everyone has since left. I suspect it does the wrong thing. - event, buildRes, err := buildEvent(ctx, r.DB, r.Cfg.Matrix, userDomain, &eb) + + var buildRes rsAPI.QueryLatestEventsAndStateResponse + identity, err := r.Cfg.Matrix.SigningIdentityFor(userDomain) + if err != nil { + return "", "", fmt.Errorf("error joining local room: %q", err) + } + event, err := eventutil.QueryAndBuildEvent(ctx, &eb, r.Cfg.Matrix, identity, time.Now(), r.RSAPI, &buildRes) switch err { case nil: @@ -383,7 +389,7 @@ func (r *Joiner) performJoinRoomByID( func (r *Joiner) performFederatedJoinRoomByID( ctx context.Context, req *rsAPI.PerformJoinRequest, -) (gomatrixserverlib.ServerName, error) { +) (spec.ServerName, error) { // Try joining by all of the supplied server names. fedReq := fsAPI.PerformJoinRequest{ RoomID: req.RoomIDOrAlias, // the room ID to try and join @@ -430,46 +436,3 @@ func (r *Joiner) populateAuthorisedViaUserForRestrictedJoin( } return res.AuthorisedVia, nil } - -func buildEvent( - ctx context.Context, db storage.Database, cfg *config.Global, - senderDomain gomatrixserverlib.ServerName, - builder *gomatrixserverlib.EventBuilder, -) (*gomatrixserverlib.HeaderedEvent, *rsAPI.QueryLatestEventsAndStateResponse, error) { - eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder) - if err != nil { - return nil, nil, fmt.Errorf("gomatrixserverlib.StateNeededForEventBuilder: %w", err) - } - - if len(eventsNeeded.Tuples()) == 0 { - return nil, nil, errors.New("expecting state tuples for event builder, got none") - } - - var queryRes rsAPI.QueryLatestEventsAndStateResponse - err = helpers.QueryLatestEventsAndState(ctx, db, &rsAPI.QueryLatestEventsAndStateRequest{ - RoomID: builder.RoomID, - StateToFetch: eventsNeeded.Tuples(), - }, &queryRes) - if err != nil { - switch err.(type) { - case types.MissingStateError: - // We know something about the room but the state seems to be - // insufficient to actually build a new event, so in effect we - // had might as well treat the room as if it doesn't exist. - return nil, nil, eventutil.ErrRoomNoExists - default: - return nil, nil, fmt.Errorf("QueryLatestEventsAndState: %w", err) - } - } - - identity, err := cfg.SigningIdentityFor(senderDomain) - if err != nil { - return nil, nil, err - } - - ev, err := eventutil.BuildEvent(ctx, builder, cfg, identity, time.Now(), &eventsNeeded, &queryRes) - if err != nil { - return nil, nil, err - } - return ev, &queryRes, nil -} diff --git a/roomserver/internal/perform/perform_leave.go b/roomserver/internal/perform/perform_leave.go index 86f1dfaee..4dfb522a2 100644 --- a/roomserver/internal/perform/perform_leave.go +++ b/roomserver/internal/perform/perform_leave.go @@ -19,15 +19,19 @@ import ( "encoding/json" "fmt" "strings" + "time" "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/sirupsen/logrus" fsAPI "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/roomserver/api" + rsAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/internal/helpers" "github.com/matrix-org/dendrite/roomserver/internal/input" "github.com/matrix-org/dendrite/roomserver/storage" @@ -39,6 +43,7 @@ type Leaver struct { Cfg *config.RoomServer DB storage.Database FSAPI fsAPI.RoomserverFederationAPI + RSAPI rsAPI.RoomserverInternalAPI UserAPI userapi.RoomserverUserAPI Inputer *input.Inputer } @@ -122,7 +127,7 @@ func (r *Leaver) performLeaveRoomByID( RoomID: req.RoomID, StateToFetch: []gomatrixserverlib.StateKeyTuple{ { - EventType: gomatrixserverlib.MRoomMember, + EventType: spec.MRoomMember, StateKey: req.UserID, }, }, @@ -143,14 +148,14 @@ func (r *Leaver) performLeaveRoomByID( if err != nil { return nil, fmt.Errorf("error getting membership: %w", err) } - if membership != gomatrixserverlib.Join && membership != gomatrixserverlib.Invite { + if membership != spec.Join && membership != spec.Invite { return nil, fmt.Errorf("user %q is not joined to the room (membership is %q)", req.UserID, membership) } // Prepare the template for the leave event. userID := req.UserID eb := gomatrixserverlib.EventBuilder{ - Type: gomatrixserverlib.MRoomMember, + Type: spec.MRoomMember, Sender: userID, StateKey: &userID, RoomID: req.RoomID, @@ -173,9 +178,15 @@ func (r *Leaver) performLeaveRoomByID( // a leave event. // TODO: Check what happens if the room exists on the server // but everyone has since left. I suspect it does the wrong thing. - event, buildRes, err := buildEvent(ctx, r.DB, r.Cfg.Matrix, senderDomain, &eb) + + var buildRes rsAPI.QueryLatestEventsAndStateResponse + identity, err := r.Cfg.Matrix.SigningIdentityFor(senderDomain) if err != nil { - return nil, fmt.Errorf("eventutil.BuildEvent: %w", err) + return nil, fmt.Errorf("SigningIdentityFor: %w", err) + } + event, err := eventutil.QueryAndBuildEvent(ctx, &eb, r.Cfg.Matrix, identity, time.Now(), r.RSAPI, &buildRes) + if err != nil { + return nil, fmt.Errorf("eventutil.QueryAndBuildEvent: %w", err) } // Give our leave event to the roomserver input stream. The @@ -217,7 +228,7 @@ func (r *Leaver) performFederatedRejectInvite( leaveReq := fsAPI.PerformLeaveRequest{ RoomID: req.RoomID, UserID: req.UserID, - ServerNames: []gomatrixserverlib.ServerName{domain}, + ServerNames: []spec.ServerName{domain}, } leaveRes := fsAPI.PerformLeaveResponse{} if err = r.FSAPI.PerformLeave(ctx, &leaveReq, &leaveRes); err != nil { diff --git a/roomserver/internal/perform/perform_peek.go b/roomserver/internal/perform/perform_peek.go index 436d137ff..2f39050dd 100644 --- a/roomserver/internal/perform/perform_peek.go +++ b/roomserver/internal/perform/perform_peek.go @@ -26,12 +26,13 @@ import ( "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/sirupsen/logrus" ) type Peeker struct { - ServerName gomatrixserverlib.ServerName + ServerName spec.ServerName Cfg *config.RoomServer FSAPI fsAPI.RoomserverFederationAPI DB storage.Database diff --git a/roomserver/internal/perform/perform_unpeek.go b/roomserver/internal/perform/perform_unpeek.go index 4d714be66..28486fa14 100644 --- a/roomserver/internal/perform/perform_unpeek.go +++ b/roomserver/internal/perform/perform_unpeek.go @@ -24,10 +24,11 @@ import ( "github.com/matrix-org/dendrite/roomserver/internal/input" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) type Unpeeker struct { - ServerName gomatrixserverlib.ServerName + ServerName spec.ServerName Cfg *config.RoomServer FSAPI fsAPI.RoomserverFederationAPI Inputer *input.Inputer diff --git a/roomserver/internal/perform/perform_upgrade.go b/roomserver/internal/perform/perform_upgrade.go index 02a19911c..ed57abf26 100644 --- a/roomserver/internal/perform/perform_upgrade.go +++ b/roomserver/internal/perform/perform_upgrade.go @@ -24,6 +24,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/sirupsen/logrus" ) @@ -148,7 +149,7 @@ func (r *Upgrader) performRoomUpgrade( func (r *Upgrader) getRoomPowerLevels(ctx context.Context, roomID string) (*gomatrixserverlib.PowerLevelContent, *api.PerformError) { oldPowerLevelsEvent := api.GetStateEvent(ctx, r.URSAPI, roomID, gomatrixserverlib.StateKeyTuple{ - EventType: gomatrixserverlib.MRoomPowerLevels, + EventType: spec.MRoomPowerLevels, StateKey: "", }) powerLevelContent, err := oldPowerLevelsEvent.PowerLevels() @@ -161,7 +162,7 @@ func (r *Upgrader) getRoomPowerLevels(ctx context.Context, roomID string) (*goma return powerLevelContent, nil } -func (r *Upgrader) restrictOldRoomPowerLevels(ctx context.Context, evTime time.Time, userID string, userDomain gomatrixserverlib.ServerName, roomID string) *api.PerformError { +func (r *Upgrader) restrictOldRoomPowerLevels(ctx context.Context, evTime time.Time, userID string, userDomain spec.ServerName, roomID string) *api.PerformError { restrictedPowerLevelContent, pErr := r.getRoomPowerLevels(ctx, roomID) if pErr != nil { return pErr @@ -179,7 +180,7 @@ func (r *Upgrader) restrictOldRoomPowerLevels(ctx context.Context, evTime time.T restrictedPowerLevelContent.Invite = restrictedDefaultPowerLevel restrictedPowerLevelsHeadered, resErr := r.makeHeaderedEvent(ctx, evTime, userID, roomID, fledglingEvent{ - Type: gomatrixserverlib.MRoomPowerLevels, + Type: spec.MRoomPowerLevels, StateKey: "", Content: restrictedPowerLevelContent, }) @@ -230,9 +231,9 @@ func moveLocalAliases(ctx context.Context, return nil } -func (r *Upgrader) clearOldCanonicalAliasEvent(ctx context.Context, oldRoom *api.QueryLatestEventsAndStateResponse, evTime time.Time, userID string, userDomain gomatrixserverlib.ServerName, roomID string) *api.PerformError { +func (r *Upgrader) clearOldCanonicalAliasEvent(ctx context.Context, oldRoom *api.QueryLatestEventsAndStateResponse, evTime time.Time, userID string, userDomain spec.ServerName, roomID string) *api.PerformError { for _, event := range oldRoom.StateEvents { - if event.Type() != gomatrixserverlib.MRoomCanonicalAlias || !event.StateKeyEquals("") { + if event.Type() != spec.MRoomCanonicalAlias || !event.StateKeyEquals("") { continue } var aliasContent struct { @@ -251,7 +252,7 @@ func (r *Upgrader) clearOldCanonicalAliasEvent(ctx context.Context, oldRoom *api } emptyCanonicalAliasEvent, resErr := r.makeHeaderedEvent(ctx, evTime, userID, roomID, fledglingEvent{ - Type: gomatrixserverlib.MRoomCanonicalAlias, + Type: spec.MRoomCanonicalAlias, Content: map[string]interface{}{}, }) if resErr != nil { @@ -332,7 +333,7 @@ func (r *Upgrader) validateRoomExists(ctx context.Context, roomID string) error func (r *Upgrader) userIsAuthorized(ctx context.Context, userID, roomID string, ) bool { plEvent := api.GetStateEvent(ctx, r.URSAPI, roomID, gomatrixserverlib.StateKeyTuple{ - EventType: gomatrixserverlib.MRoomPowerLevels, + EventType: spec.MRoomPowerLevels, StateKey: "", }) if plEvent == nil { @@ -355,7 +356,7 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query // This shouldn't ever happen, but better to be safe than sorry. continue } - if event.Type() == gomatrixserverlib.MRoomMember && !event.StateKeyEquals(userID) { + if event.Type() == spec.MRoomMember && !event.StateKeyEquals(userID) { // With the exception of bans and invites which we do want to copy, we // should ignore membership events that aren't our own, as event auth will // prevent us from being able to create membership events on behalf of other @@ -365,8 +366,8 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query continue } switch membership { - case gomatrixserverlib.Ban: - case gomatrixserverlib.Invite: + case spec.Ban: + case spec.Invite: default: continue } @@ -377,10 +378,10 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query // The following events are ones that we are going to override manually // in the following section. override := map[gomatrixserverlib.StateKeyTuple]struct{}{ - {EventType: gomatrixserverlib.MRoomCreate, StateKey: ""}: {}, - {EventType: gomatrixserverlib.MRoomMember, StateKey: userID}: {}, - {EventType: gomatrixserverlib.MRoomPowerLevels, StateKey: ""}: {}, - {EventType: gomatrixserverlib.MRoomJoinRules, StateKey: ""}: {}, + {EventType: spec.MRoomCreate, StateKey: ""}: {}, + {EventType: spec.MRoomMember, StateKey: userID}: {}, + {EventType: spec.MRoomPowerLevels, StateKey: ""}: {}, + {EventType: spec.MRoomJoinRules, StateKey: ""}: {}, } // The overridden events are essential events that must be present in the @@ -393,10 +394,10 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query } } - oldCreateEvent := state[gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomCreate, StateKey: ""}] - oldMembershipEvent := state[gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomMember, StateKey: userID}] - oldPowerLevelsEvent := state[gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomPowerLevels, StateKey: ""}] - oldJoinRulesEvent := state[gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomJoinRules, StateKey: ""}] + oldCreateEvent := state[gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomCreate, StateKey: ""}] + oldMembershipEvent := state[gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomMember, StateKey: userID}] + oldPowerLevelsEvent := state[gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomPowerLevels, StateKey: ""}] + oldJoinRulesEvent := state[gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomJoinRules, StateKey: ""}] // Create the new room create event. Using a map here instead of CreateContent // means that we preserve any other interesting fields that might be present @@ -410,7 +411,7 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query RoomID: roomID, } newCreateEvent := fledglingEvent{ - Type: gomatrixserverlib.MRoomCreate, + Type: spec.MRoomCreate, StateKey: "", Content: newCreateContent, } @@ -421,9 +422,9 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query // the events after it. newMembershipContent := map[string]interface{}{} _ = json.Unmarshal(oldMembershipEvent.Content(), &newMembershipContent) - newMembershipContent["membership"] = gomatrixserverlib.Join + newMembershipContent["membership"] = spec.Join newMembershipEvent := fledglingEvent{ - Type: gomatrixserverlib.MRoomMember, + Type: spec.MRoomMember, StateKey: userID, Content: newMembershipContent, } @@ -447,11 +448,11 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query // existing join rules contains garbage, the room can still be // upgraded. newJoinRulesContent := map[string]interface{}{ - "join_rule": gomatrixserverlib.Invite, // sane default + "join_rule": spec.Invite, // sane default } _ = json.Unmarshal(oldJoinRulesEvent.Content(), &newJoinRulesContent) newJoinRulesEvent := fledglingEvent{ - Type: gomatrixserverlib.MRoomJoinRules, + Type: spec.MRoomJoinRules, StateKey: "", Content: newJoinRulesContent, } @@ -464,9 +465,9 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query // For some reason Sytest expects there to be a guest access event. // Create one if it doesn't exist. - if _, ok := state[gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomGuestAccess, StateKey: ""}]; !ok { + if _, ok := state[gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomGuestAccess, StateKey: ""}]; !ok { eventsToMake = append(eventsToMake, fledglingEvent{ - Type: gomatrixserverlib.MRoomGuestAccess, + Type: spec.MRoomGuestAccess, Content: map[string]string{ "guest_access": "forbidden", }, @@ -495,14 +496,14 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query // override that now by restoring the original power levels. if powerLevelsOverridden { eventsToMake = append(eventsToMake, fledglingEvent{ - Type: gomatrixserverlib.MRoomPowerLevels, + Type: spec.MRoomPowerLevels, Content: powerLevelContent, }) } return eventsToMake, nil } -func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, userID string, userDomain gomatrixserverlib.ServerName, newRoomID, newVersion string, eventsToMake []fledglingEvent) *api.PerformError { +func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, userID string, userDomain spec.ServerName, newRoomID, newVersion string, eventsToMake []fledglingEvent) *api.PerformError { var err error var builtEvents []*gomatrixserverlib.HeaderedEvent authEvents := gomatrixserverlib.NewAuthEvents(nil) @@ -526,7 +527,7 @@ func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, user builder.PrevEvents = []gomatrixserverlib.EventReference{builtEvents[i-1].EventReference()} } var event *gomatrixserverlib.Event - event, err = r.buildEvent(&builder, userDomain, &authEvents, evTime, gomatrixserverlib.RoomVersion(newVersion)) + event, err = builder.AddAuthEventsAndBuild(userDomain, &authEvents, evTime, gomatrixserverlib.RoomVersion(newVersion), r.Cfg.Matrix.KeyID, r.Cfg.Matrix.PrivateKey) if err != nil { return &api.PerformError{ Msg: fmt.Sprintf("Failed to build new %q event: %s", builder.Type, err), @@ -681,14 +682,14 @@ func createTemporaryPowerLevels(powerLevelContent *gomatrixserverlib.PowerLevelC // Then return the temporary power levels event. return fledglingEvent{ - Type: gomatrixserverlib.MRoomPowerLevels, + Type: spec.MRoomPowerLevels, Content: tempPowerLevelContent, }, powerLevelsOverridden } func (r *Upgrader) sendHeaderedEvent( ctx context.Context, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, headeredEvent *gomatrixserverlib.HeaderedEvent, sendAsServer string, ) *api.PerformError { @@ -707,29 +708,3 @@ func (r *Upgrader) sendHeaderedEvent( return nil } - -func (r *Upgrader) buildEvent( - builder *gomatrixserverlib.EventBuilder, - serverName gomatrixserverlib.ServerName, - provider gomatrixserverlib.AuthEventProvider, - evTime time.Time, - roomVersion gomatrixserverlib.RoomVersion, -) (*gomatrixserverlib.Event, error) { - eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder) - if err != nil { - return nil, err - } - refs, err := eventsNeeded.AuthEventReferences(provider) - if err != nil { - return nil, err - } - builder.AuthEvents = refs - event, err := builder.Build( - evTime, serverName, r.Cfg.Matrix.KeyID, - r.Cfg.Matrix.PrivateKey, roomVersion, - ) - if err != nil { - return nil, err - } - return event, nil -} diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index c5b74422f..2c85d3ebb 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -22,10 +22,12 @@ import ( "fmt" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/sirupsen/logrus" "github.com/matrix-org/dendrite/roomserver/storage/tables" + "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/internal/caching" @@ -35,13 +37,12 @@ import ( "github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/types" - "github.com/matrix-org/dendrite/roomserver/version" ) type Queryer struct { DB storage.Database Cache caching.RoomServerCaches - IsLocalServerName func(gomatrixserverlib.ServerName) bool + IsLocalServerName func(spec.ServerName) bool ServerACLs *acls.ServerACLs } @@ -305,7 +306,7 @@ func (r *Queryer) QueryMembershipAtEvent( // a given event, overwrite any other existing membership events. for i := range memberships { ev := memberships[i] - if ev.Type() == gomatrixserverlib.MRoomMember && ev.StateKeyEquals(request.UserID) { + if ev.Type() == spec.MRoomMember && ev.StateKeyEquals(request.UserID) { response.Membership[eventID] = ev.Event.Headered(info.RoomVersion) } } @@ -346,7 +347,7 @@ func (r *Queryer) QueryMembershipsForRoom( return fmt.Errorf("r.DB.Events: %w", err) } for _, event := range events { - clientEvent := gomatrixserverlib.ToClientEvent(event.Event, gomatrixserverlib.FormatAll) + clientEvent := synctypes.ToClientEvent(event.Event, synctypes.FormatAll) response.JoinEvents = append(response.JoinEvents, clientEvent) } return nil @@ -366,7 +367,7 @@ func (r *Queryer) QueryMembershipsForRoom( } response.HasBeenInRoom = true - response.JoinEvents = []gomatrixserverlib.ClientEvent{} + response.JoinEvents = []synctypes.ClientEvent{} var events []types.Event var stateEntries []types.StateEntry @@ -395,7 +396,7 @@ func (r *Queryer) QueryMembershipsForRoom( } for _, event := range events { - clientEvent := gomatrixserverlib.ToClientEvent(event.Event, gomatrixserverlib.FormatAll) + clientEvent := synctypes.ToClientEvent(event.Event, synctypes.FormatAll) response.JoinEvents = append(response.JoinEvents, clientEvent) } @@ -435,7 +436,7 @@ func (r *Queryer) QueryServerJoinedToRoom( // QueryServerAllowedToSeeEvent implements api.RoomserverInternalAPI func (r *Queryer) QueryServerAllowedToSeeEvent( ctx context.Context, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, eventID string, ) (allowed bool, err error) { events, err := r.DB.EventNIDs(ctx, []string{eventID}) @@ -694,25 +695,7 @@ func GetAuthChain( return authEvents, nil } -// QueryRoomVersionCapabilities implements api.RoomserverInternalAPI -func (r *Queryer) QueryRoomVersionCapabilities( - ctx context.Context, - request *api.QueryRoomVersionCapabilitiesRequest, - response *api.QueryRoomVersionCapabilitiesResponse, -) error { - response.DefaultRoomVersion = version.DefaultRoomVersion() - response.AvailableRoomVersions = make(map[gomatrixserverlib.RoomVersion]string) - for v, desc := range version.SupportedRoomVersions() { - if desc.Stable { - response.AvailableRoomVersions[v] = "stable" - } else { - response.AvailableRoomVersions[v] = "unstable" - } - } - return nil -} - -// QueryRoomVersionCapabilities implements api.RoomserverInternalAPI +// QueryRoomVersionForRoom implements api.RoomserverInternalAPI func (r *Queryer) QueryRoomVersionForRoom( ctx context.Context, request *api.QueryRoomVersionForRoomRequest, @@ -914,7 +897,7 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, req *api.Query // the flag. res.Resident = true // Get the join rules to work out if the join rule is "restricted". - joinRulesEvent, err := r.DB.GetStateEvent(ctx, req.RoomID, gomatrixserverlib.MRoomJoinRules, "") + joinRulesEvent, err := r.DB.GetStateEvent(ctx, req.RoomID, spec.MRoomJoinRules, "") if err != nil { return fmt.Errorf("r.DB.GetStateEvent: %w", err) } @@ -926,7 +909,7 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, req *api.Query return fmt.Errorf("json.Unmarshal: %w", err) } // If the join rule isn't "restricted" then there's nothing more to do. - res.Restricted = joinRules.JoinRule == gomatrixserverlib.Restricted + res.Restricted = joinRules.JoinRule == spec.Restricted if !res.Restricted { return nil } @@ -943,7 +926,7 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, req *api.Query // We need to get the power levels content so that we can determine which // users in the room are entitled to issue invites. We need to use one of // these users as the authorising user. - powerLevelsEvent, err := r.DB.GetStateEvent(ctx, req.RoomID, gomatrixserverlib.MRoomPowerLevels, "") + powerLevelsEvent, err := r.DB.GetStateEvent(ctx, req.RoomID, spec.MRoomPowerLevels, "") if err != nil { return fmt.Errorf("r.DB.GetStateEvent: %w", err) } @@ -955,7 +938,7 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, req *api.Query for _, rule := range joinRules.Allow { // We only understand "m.room_membership" rules at this point in // time, so skip any rule that doesn't match those. - if rule.Type != gomatrixserverlib.MRoomMembership { + if rule.Type != spec.MRoomMembership { continue } // See if the room exists. If it doesn't exist or if it's a stub @@ -1002,7 +985,7 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, req *api.Query continue } event := events[0] - if event.Type() != gomatrixserverlib.MRoomMember || event.StateKey() == nil { + if event.Type() != spec.MRoomMember || event.StateKey() == nil { continue // shouldn't happen } // Only users that have the power to invite should be chosen. diff --git a/roomserver/roomserver.go b/roomserver/roomserver.go index 5a8d8b570..4685f474f 100644 --- a/roomserver/roomserver.go +++ b/roomserver/roomserver.go @@ -15,28 +15,35 @@ package roomserver import ( + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/setup/process" "github.com/sirupsen/logrus" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/internal" "github.com/matrix-org/dendrite/roomserver/storage" - "github.com/matrix-org/dendrite/setup/base" ) // NewInternalAPI returns a concrete implementation of the internal API. func NewInternalAPI( - base *base.BaseDendrite, + processContext *process.ProcessContext, + cfg *config.Dendrite, + cm sqlutil.Connections, + natsInstance *jetstream.NATSInstance, + caches caching.RoomServerCaches, + enableMetrics bool, ) api.RoomserverInternalAPI { - cfg := &base.Cfg.RoomServer - - roomserverDB, err := storage.Open(base, &cfg.Database, base.Caches) + roomserverDB, err := storage.Open(processContext.Context(), cm, &cfg.RoomServer.Database, caches) if err != nil { logrus.WithError(err).Panicf("failed to connect to room server db") } - js, nc := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) + js, nc := natsInstance.Prepare(processContext, &cfg.Global.JetStream) return internal.NewRoomserverAPI( - base, roomserverDB, js, nc, + processContext, cfg, roomserverDB, js, nc, caches, enableMetrics, ) } diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go index a3ca5909e..67d6db46f 100644 --- a/roomserver/roomserver_test.go +++ b/roomserver/roomserver_test.go @@ -7,11 +7,14 @@ import ( "testing" "time" + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/stretchr/testify/assert" "github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/types" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/userapi" userAPI "github.com/matrix-org/dendrite/userapi/api" @@ -29,21 +32,14 @@ import ( "github.com/matrix-org/dendrite/test/testrig" ) -func mustCreateDatabase(t *testing.T, dbType test.DBType) (*base.BaseDendrite, storage.Database, func()) { - t.Helper() - base, close := testrig.CreateBaseDendrite(t, dbType) - db, err := storage.Open(base, &base.Cfg.RoomServer.Database, base.Caches) - if err != nil { - t.Fatalf("failed to create Database: %v", err) - } - return base, db, close -} - func TestUsers(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - base, close := testrig.CreateBaseDendrite(t, dbType) + cfg, processCtx, close := testrig.CreateConfig(t, dbType) defer close() - rsAPI := roomserver.NewInternalAPI(base) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + natsInstance := jetstream.NATSInstance{} + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) // SetFederationAPI starts the room event input consumer rsAPI.SetFederationAPI(nil, nil) @@ -52,7 +48,7 @@ func TestUsers(t *testing.T) { }) t.Run("kick users", func(t *testing.T) { - usrAPI := userapi.NewInternalAPI(base, rsAPI, nil) + usrAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) rsAPI.SetUserAPI(usrAPI) testKickUsers(t, rsAPI, usrAPI) }) @@ -66,10 +62,10 @@ func testSharedUsers(t *testing.T, rsAPI api.RoomserverInternalAPI) { room := test.NewRoom(t, alice, test.RoomPreset(test.PresetTrustedPrivateChat)) // Invite and join Bob - room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, alice, spec.MRoomMember, map[string]interface{}{ "membership": "invite", }, test.WithStateKey(bob.ID)) - room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, bob, spec.MRoomMember, map[string]interface{}{ "membership": "join", }, test.WithStateKey(bob.ID)) @@ -107,7 +103,7 @@ func testKickUsers(t *testing.T, rsAPI api.RoomserverInternalAPI, usrAPI userAPI room := test.NewRoom(t, alice, test.RoomPreset(test.PresetPublicChat), test.GuestsCanJoin(true)) // Join with the guest user - room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, bob, spec.MRoomMember, map[string]interface{}{ "membership": "join", }, test.WithStateKey(bob.ID)) @@ -139,7 +135,7 @@ func testKickUsers(t *testing.T, rsAPI api.RoomserverInternalAPI, usrAPI userAPI } // revoke guest access - revokeEvent := room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomGuestAccess, map[string]string{"guest_access": "forbidden"}, test.WithStateKey("")) + revokeEvent := room.CreateAndInsert(t, alice, spec.MRoomGuestAccess, map[string]string{"guest_access": "forbidden"}, test.WithStateKey("")) if err := api.SendEvents(ctx, rsAPI, api.KindNew, []*gomatrixserverlib.HeaderedEvent{revokeEvent}, "test", "test", "test", nil, false); err != nil { t.Errorf("failed to send events: %v", err) } @@ -169,19 +165,22 @@ func Test_QueryLeftUsers(t *testing.T) { room := test.NewRoom(t, alice, test.RoomPreset(test.PresetTrustedPrivateChat)) // Invite and join Bob - room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, alice, spec.MRoomMember, map[string]interface{}{ "membership": "invite", }, test.WithStateKey(bob.ID)) - room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, bob, spec.MRoomMember, map[string]interface{}{ "membership": "join", }, test.WithStateKey(bob.ID)) ctx := context.Background() test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - base, _, close := mustCreateDatabase(t, dbType) + cfg, processCtx, close := testrig.CreateConfig(t, dbType) defer close() - rsAPI := roomserver.NewInternalAPI(base) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + natsInstance := jetstream.NATSInstance{} + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) // SetFederationAPI starts the room event input consumer rsAPI.SetFederationAPI(nil, nil) // Create the room @@ -218,36 +217,42 @@ func TestPurgeRoom(t *testing.T) { room := test.NewRoom(t, alice, test.RoomPreset(test.PresetTrustedPrivateChat)) // Invite Bob - inviteEvent := room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + inviteEvent := room.CreateAndInsert(t, alice, spec.MRoomMember, map[string]interface{}{ "membership": "invite", }, test.WithStateKey(bob.ID)) ctx := context.Background() test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - base, db, close := mustCreateDatabase(t, dbType) + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + natsInstance := jetstream.NATSInstance{} defer close() + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + db, err := storage.Open(processCtx.Context(), cm, &cfg.RoomServer.Database, caches) + if err != nil { + t.Fatal(err) + } + jsCtx, _ := natsInstance.Prepare(processCtx, &cfg.Global.JetStream) + defer jetstream.DeleteAllStreams(jsCtx, &cfg.Global.JetStream) - jsCtx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) - defer jetstream.DeleteAllStreams(jsCtx, &base.Cfg.Global.JetStream) - - fedClient := base.CreateFederationClient() - rsAPI := roomserver.NewInternalAPI(base) - userAPI := userapi.NewInternalAPI(base, rsAPI, nil) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) // this starts the JetStream consumers - syncapi.AddPublicRoutes(base, userAPI, rsAPI) - federationapi.NewInternalAPI(base, fedClient, rsAPI, base.Caches, nil, true) + syncapi.AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, userAPI, rsAPI, caches, caching.DisableMetrics) + federationapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, nil, rsAPI, caches, nil, true) rsAPI.SetFederationAPI(nil, nil) // Create the room - if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { + if err = api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { t.Fatalf("failed to send events: %v", err) } // some dummy entries to validate after purging publishResp := &api.PerformPublishResponse{} - if err := rsAPI.PerformPublish(ctx, &api.PerformPublishRequest{RoomID: room.ID, Visibility: "public"}, publishResp); err != nil { + if err = rsAPI.PerformPublish(ctx, &api.PerformPublishRequest{RoomID: room.ID, Visibility: "public"}, publishResp); err != nil { t.Fatal(err) } if publishResp.Error != nil { @@ -335,7 +340,7 @@ func TestPurgeRoom(t *testing.T) { t.Fatalf("test timed out after %s", timeout) } sum = 0 - consumerCh := jsCtx.Consumers(base.Cfg.Global.JetStream.Prefixed(jetstream.OutputRoomEvent)) + consumerCh := jsCtx.Consumers(cfg.Global.JetStream.Prefixed(jetstream.OutputRoomEvent)) for x := range consumerCh { sum += x.NumAckPending } @@ -439,7 +444,7 @@ func TestRedaction(t *testing.T) { redactedEvent := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": "hello world"}) builderEv := mustCreateEvent(t, fledglingEvent{ - Type: gomatrixserverlib.MRoomRedaction, + Type: spec.MRoomRedaction, Sender: alice.ID, RoomID: room.ID, Redacts: redactedEvent.EventID(), @@ -456,7 +461,7 @@ func TestRedaction(t *testing.T) { redactedEvent := room.CreateAndInsert(t, bob, "m.room.message", map[string]interface{}{"body": "hello world"}) builderEv := mustCreateEvent(t, fledglingEvent{ - Type: gomatrixserverlib.MRoomRedaction, + Type: spec.MRoomRedaction, Sender: alice.ID, RoomID: room.ID, Redacts: redactedEvent.EventID(), @@ -473,7 +478,7 @@ func TestRedaction(t *testing.T) { redactedEvent := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": "hello world"}) builderEv := mustCreateEvent(t, fledglingEvent{ - Type: gomatrixserverlib.MRoomRedaction, + Type: spec.MRoomRedaction, Sender: bob.ID, RoomID: room.ID, Redacts: redactedEvent.EventID(), @@ -489,7 +494,7 @@ func TestRedaction(t *testing.T) { redactedEvent := room.CreateAndInsert(t, bob, "m.room.message", map[string]interface{}{"body": "hello world"}) builderEv := mustCreateEvent(t, fledglingEvent{ - Type: gomatrixserverlib.MRoomRedaction, + Type: spec.MRoomRedaction, Sender: charlie.ID, RoomID: room.ID, Redacts: redactedEvent.EventID(), @@ -503,8 +508,14 @@ func TestRedaction(t *testing.T) { ctx := context.Background() test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - _, db, close := mustCreateDatabase(t, dbType) + cfg, processCtx, close := testrig.CreateConfig(t, dbType) defer close() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + db, err := storage.Open(processCtx.Context(), cm, &cfg.RoomServer.Database, caches) + if err != nil { + t.Fatal(err) + } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { @@ -513,10 +524,10 @@ func TestRedaction(t *testing.T) { var err error room := test.NewRoom(t, alice, test.RoomPreset(test.PresetPublicChat)) - room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, bob, spec.MRoomMember, map[string]interface{}{ "membership": "join", }, test.WithStateKey(bob.ID)) - room.CreateAndInsert(t, charlie, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, charlie, spec.MRoomMember, map[string]interface{}{ "membership": "join", }, test.WithStateKey(charlie.ID)) @@ -558,7 +569,7 @@ func TestRedaction(t *testing.T) { if redactedEvent != nil { assert.Equal(t, ev.Redacts(), redactedEvent.EventID()) } - if ev.Type() == gomatrixserverlib.MRoomRedaction { + if ev.Type() == spec.MRoomRedaction { nids, err := db.EventNIDs(ctx, []string{ev.Redacts()}) assert.NoError(t, err) evs, err := db.Events(ctx, roomInfo, []types.EventNID{nids[ev.Redacts()].EventNID}) diff --git a/roomserver/state/state.go b/roomserver/state/state.go index 47e1488df..c3842784e 100644 --- a/roomserver/state/state.go +++ b/roomserver/state/state.go @@ -25,9 +25,9 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" - "github.com/opentracing/opentracing-go" "github.com/prometheus/client_golang/prometheus" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -106,8 +106,8 @@ func (p *StateResolution) Resolve(ctx context.Context, eventID string) (*gomatri func (v *StateResolution) LoadStateAtSnapshot( ctx context.Context, stateNID types.StateSnapshotNID, ) ([]types.StateEntry, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.LoadStateAtSnapshot") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "StateResolution.LoadStateAtSnapshot") + defer trace.EndRegion() stateBlockNIDLists, err := v.db.StateBlockNIDs(ctx, []types.StateSnapshotNID{stateNID}) if err != nil { @@ -147,8 +147,8 @@ func (v *StateResolution) LoadStateAtSnapshot( func (v *StateResolution) LoadStateAtEvent( ctx context.Context, eventID string, ) ([]types.StateEntry, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.LoadStateAtEvent") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "StateResolution.LoadStateAtEvent") + defer trace.EndRegion() snapshotNID, err := v.db.SnapshotNIDFromEventID(ctx, eventID) if err != nil { @@ -169,8 +169,8 @@ func (v *StateResolution) LoadStateAtEvent( func (v *StateResolution) LoadMembershipAtEvent( ctx context.Context, eventIDs []string, stateKeyNID types.EventStateKeyNID, ) (map[string][]types.StateEntry, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.LoadMembershipAtEvent") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "StateResolution.LoadMembershipAtEvent") + defer trace.EndRegion() // Get a mapping from snapshotNID -> eventIDs snapshotNIDMap, err := v.db.BulkSelectSnapshotsFromEventIDs(ctx, eventIDs) @@ -238,8 +238,8 @@ func (v *StateResolution) LoadMembershipAtEvent( func (v *StateResolution) LoadStateAtEventForHistoryVisibility( ctx context.Context, eventID string, ) ([]types.StateEntry, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.LoadStateAtEvent") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "StateResolution.LoadStateAtEventForHistoryVisibility") + defer trace.EndRegion() snapshotNID, err := v.db.SnapshotNIDFromEventID(ctx, eventID) if err != nil { @@ -263,8 +263,8 @@ func (v *StateResolution) LoadStateAtEventForHistoryVisibility( func (v *StateResolution) LoadCombinedStateAfterEvents( ctx context.Context, prevStates []types.StateAtEvent, ) ([]types.StateEntry, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.LoadCombinedStateAfterEvents") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "StateResolution.LoadCombinedStateAfterEvents") + defer trace.EndRegion() stateNIDs := make([]types.StateSnapshotNID, len(prevStates)) for i, state := range prevStates { @@ -338,8 +338,8 @@ func (v *StateResolution) LoadCombinedStateAfterEvents( func (v *StateResolution) DifferenceBetweeenStateSnapshots( ctx context.Context, oldStateNID, newStateNID types.StateSnapshotNID, ) (removed, added []types.StateEntry, err error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.DifferenceBetweeenStateSnapshots") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "StateResolution.DifferenceBetweeenStateSnapshots") + defer trace.EndRegion() if oldStateNID == newStateNID { // If the snapshot NIDs are the same then nothing has changed @@ -402,8 +402,8 @@ func (v *StateResolution) LoadStateAtSnapshotForStringTuples( stateNID types.StateSnapshotNID, stateKeyTuples []gomatrixserverlib.StateKeyTuple, ) ([]types.StateEntry, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.LoadStateAtSnapshotForStringTuples") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "StateResolution.LoadStateAtSnapshotForStringTuples") + defer trace.EndRegion() numericTuples, err := v.stringTuplesToNumericTuples(ctx, stateKeyTuples) if err != nil { @@ -419,8 +419,8 @@ func (v *StateResolution) stringTuplesToNumericTuples( ctx context.Context, stringTuples []gomatrixserverlib.StateKeyTuple, ) ([]types.StateKeyTuple, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.stringTuplesToNumericTuples") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "StateResolution.stringTuplesToNumericTuples") + defer trace.EndRegion() eventTypes := make([]string, len(stringTuples)) stateKeys := make([]string, len(stringTuples)) @@ -464,8 +464,8 @@ func (v *StateResolution) loadStateAtSnapshotForNumericTuples( stateNID types.StateSnapshotNID, stateKeyTuples []types.StateKeyTuple, ) ([]types.StateEntry, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.loadStateAtSnapshotForNumericTuples") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "StateResolution.loadStateAtSnapshotForNumericTuples") + defer trace.EndRegion() stateBlockNIDLists, err := v.db.StateBlockNIDs(ctx, []types.StateSnapshotNID{stateNID}) if err != nil { @@ -515,8 +515,8 @@ func (v *StateResolution) LoadStateAfterEventsForStringTuples( prevStates []types.StateAtEvent, stateKeyTuples []gomatrixserverlib.StateKeyTuple, ) ([]types.StateEntry, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.LoadStateAfterEventsForStringTuples") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "StateResolution.LoadStateAfterEventsForStringTuples") + defer trace.EndRegion() numericTuples, err := v.stringTuplesToNumericTuples(ctx, stateKeyTuples) if err != nil { @@ -530,8 +530,8 @@ func (v *StateResolution) loadStateAfterEventsForNumericTuples( prevStates []types.StateAtEvent, stateKeyTuples []types.StateKeyTuple, ) ([]types.StateEntry, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.loadStateAfterEventsForNumericTuples") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "StateResolution.loadStateAfterEventsForNumericTuples") + defer trace.EndRegion() if len(prevStates) == 1 { // Fast path for a single event. @@ -705,8 +705,8 @@ func (v *StateResolution) CalculateAndStoreStateBeforeEvent( event *gomatrixserverlib.Event, isRejected bool, ) (types.StateSnapshotNID, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.CalculateAndStoreStateBeforeEvent") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "StateResolution.CalculateAndStoreStateBeforeEvent") + defer trace.EndRegion() // Load the state at the prev events. prevStates, err := v.db.StateAtEventIDs(ctx, event.PrevEventIDs()) @@ -724,8 +724,8 @@ func (v *StateResolution) CalculateAndStoreStateAfterEvents( ctx context.Context, prevStates []types.StateAtEvent, ) (types.StateSnapshotNID, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.CalculateAndStoreStateAfterEvents") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "StateResolution.CalculateAndStoreStateAfterEvents") + defer trace.EndRegion() metrics := calculateStateMetrics{startTime: time.Now(), prevEventLength: len(prevStates)} @@ -799,8 +799,8 @@ func (v *StateResolution) calculateAndStoreStateAfterManyEvents( prevStates []types.StateAtEvent, metrics calculateStateMetrics, ) (types.StateSnapshotNID, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.calculateAndStoreStateAfterManyEvents") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "StateResolution.calculateAndStoreStateAfterManyEvents") + defer trace.EndRegion() state, algorithm, conflictLength, err := v.calculateStateAfterManyEvents(ctx, v.roomInfo.RoomVersion, prevStates) @@ -820,8 +820,8 @@ func (v *StateResolution) calculateStateAfterManyEvents( ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, prevStates []types.StateAtEvent, ) (state []types.StateEntry, algorithm string, conflictLength int, err error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.calculateStateAfterManyEvents") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "StateResolution.calculateStateAfterManyEvents") + defer trace.EndRegion() var combined []types.StateEntry // Conflict resolution. @@ -875,8 +875,8 @@ func (v *StateResolution) resolveConflicts( ctx context.Context, version gomatrixserverlib.RoomVersion, notConflicted, conflicted []types.StateEntry, ) ([]types.StateEntry, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.resolveConflicts") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "StateResolution.resolveConflicts") + defer trace.EndRegion() stateResAlgo, err := version.StateResAlgorithm() if err != nil { @@ -902,8 +902,8 @@ func (v *StateResolution) resolveConflictsV1( ctx context.Context, notConflicted, conflicted []types.StateEntry, ) ([]types.StateEntry, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.resolveConflictsV1") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "StateResolution.resolveConflictsV1") + defer trace.EndRegion() // Load the conflicted events conflictedEvents, eventIDMap, err := v.loadStateEvents(ctx, conflicted) @@ -967,8 +967,8 @@ func (v *StateResolution) resolveConflictsV2( ctx context.Context, notConflicted, conflicted []types.StateEntry, ) ([]types.StateEntry, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.resolveConflictsV2") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "StateResolution.resolveConflictsV2") + defer trace.EndRegion() estimate := len(conflicted) + len(notConflicted) eventIDMap := make(map[string]types.StateEntry, estimate) @@ -1000,8 +1000,8 @@ func (v *StateResolution) resolveConflictsV2( // For each conflicted event, let's try and get the needed auth events. if err = func() error { - span, sctx := opentracing.StartSpanFromContext(ctx, "StateResolution.loadAuthEvents") - defer span.Finish() + loadAuthEventsTrace, sctx := internal.StartRegion(ctx, "StateResolution.loadAuthEvents") + defer loadAuthEventsTrace.EndRegion() loader := authEventLoader{ v: v, @@ -1045,8 +1045,8 @@ func (v *StateResolution) resolveConflictsV2( // Resolve the conflicts. resolvedEvents := func() []*gomatrixserverlib.Event { - span, _ := opentracing.StartSpanFromContext(ctx, "gomatrixserverlib.ResolveStateConflictsV2") - defer span.Finish() + resolvedTrace, _ := internal.StartRegion(ctx, "StateResolution.ResolveStateConflictsV2") + defer resolvedTrace.EndRegion() return gomatrixserverlib.ResolveStateConflictsV2( conflictedEvents, @@ -1118,8 +1118,8 @@ func (v *StateResolution) stateKeyTuplesNeeded(stateKeyNIDMap map[string]types.E func (v *StateResolution) loadStateEvents( ctx context.Context, entries []types.StateEntry, ) ([]*gomatrixserverlib.Event, map[string]types.StateEntry, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.loadStateEvents") - defer span.Finish() + trace, ctx := internal.StartRegion(ctx, "StateResolution.loadStateEvents") + defer trace.EndRegion() result := make([]*gomatrixserverlib.Event, 0, len(entries)) eventEntries := make([]types.StateEntry, 0, len(entries)) diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 5b90f8b33..b80184f9a 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -18,6 +18,7 @@ import ( "context" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/storage/shared" @@ -162,7 +163,7 @@ type Database interface { // GetLocalServerInRoom returns true if we think we're in a given room or false otherwise. GetLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) // GetServerInRoom returns true if we think a server is in a given room or false otherwise. - GetServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) + GetServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName spec.ServerName) (bool, error) // GetKnownUsers searches all users that userID knows about. GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) // GetKnownRooms returns a list of all rooms we know about. diff --git a/roomserver/storage/postgres/membership_table.go b/roomserver/storage/postgres/membership_table.go index d774b7892..835a43b2d 100644 --- a/roomserver/storage/postgres/membership_table.go +++ b/roomserver/storage/postgres/membership_table.go @@ -21,13 +21,13 @@ import ( "fmt" "github.com/lib/pq" - "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/postgres/deltas" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib/spec" ) const membershipSchema = ` @@ -450,7 +450,7 @@ func (s *membershipStatements) SelectLocalServerInRoom( func (s *membershipStatements) SelectServerInRoom( ctx context.Context, txn *sql.Tx, - roomNID types.RoomNID, serverName gomatrixserverlib.ServerName, + roomNID types.RoomNID, serverName spec.ServerName, ) (bool, error) { var nid types.RoomNID stmt := sqlutil.TxStmt(txn, s.selectServerInRoomStmt) diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index d98a5cf97..19cde5410 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -28,7 +28,6 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/postgres/deltas" "github.com/matrix-org/dendrite/roomserver/storage/shared" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" ) @@ -38,10 +37,10 @@ type Database struct { } // Open a postgres database. -func Open(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (*Database, error) { +func Open(ctx context.Context, conMan sqlutil.Connections, dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (*Database, error) { var d Database var err error - db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()) + db, writer, err := conMan.Connection(dbProperties) if err != nil { return nil, fmt.Errorf("sqlutil.Open: %w", err) } @@ -53,7 +52,7 @@ func Open(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache c // Special case, since this migration uses several tables, so it needs to // be sure that all tables are created first. - if err = executeMigration(base.Context(), db); err != nil { + if err = executeMigration(ctx, db); err != nil { return nil, err } diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 106a8244e..170957096 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -8,6 +8,7 @@ import ( "sort" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/tidwall/gjson" @@ -567,6 +568,7 @@ func (d *EventDatabase) events( if !redactionsArePermanent { d.applyRedactions(results) } + return results, nil } eventJSONs, err := d.EventJSONTable.BulkSelectEventJSON(ctx, txn, eventNIDs) if err != nil { @@ -578,8 +580,9 @@ func (d *EventDatabase) events( } for _, eventJSON := range eventJSONs { + redacted := gjson.GetBytes(eventJSON.EventJSON, "unsigned.redacted_because").Exists() events[eventJSON.EventNID], err = gomatrixserverlib.NewEventFromTrustedJSONWithEventID( - eventIDs[eventJSON.EventNID], eventJSON.EventJSON, false, roomInfo.RoomVersion, + eventIDs[eventJSON.EventNID], eventJSON.EventJSON, redacted, roomInfo.RoomVersion, ) if err != nil { return nil, err @@ -903,7 +906,7 @@ func extractRoomVersionFromCreateEvent(event *gomatrixserverlib.Event) ( var err error var roomVersion gomatrixserverlib.RoomVersion // Look for m.room.create events. - if event.Type() != gomatrixserverlib.MRoomCreate { + if event.Type() != spec.MRoomCreate { return gomatrixserverlib.RoomVersion(""), nil } roomVersion = gomatrixserverlib.RoomVersionV1 @@ -947,7 +950,7 @@ func (d *EventDatabase) MaybeRedactEvent( ) wErr := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - isRedactionEvent := event.Type() == gomatrixserverlib.MRoomRedaction && event.StateKey() == nil + isRedactionEvent := event.Type() == spec.MRoomRedaction && event.StateKey() == nil if isRedactionEvent { // an event which redacts itself should be ignored if event.EventID() == event.Redacts() { @@ -1020,7 +1023,9 @@ func (d *EventDatabase) MaybeRedactEvent( return fmt.Errorf("d.RedactionsTable.MarkRedactionValidated: %w", err) } - d.Cache.StoreRoomServerEvent(redactedEvent.EventNID, redactedEvent.Event) + // We remove the entry from the cache, as if we just "StoreRoomServerEvent", we can't be + // certain that the cached entry actually is updated, since ristretto is eventual-persistent. + d.Cache.InvalidateRoomServerEvent(redactedEvent.EventNID) return nil }) @@ -1040,7 +1045,7 @@ func (d *EventDatabase) loadRedactionPair( var redactionEvent, redactedEvent *types.Event var info *tables.RedactionInfo var err error - isRedactionEvent := event.Type() == gomatrixserverlib.MRoomRedaction && event.StateKey() == nil + isRedactionEvent := event.Type() == spec.MRoomRedaction && event.StateKey() == nil var eventBeingRedacted string if isRedactionEvent { @@ -1465,7 +1470,7 @@ func (d *Database) GetLocalServerInRoom(ctx context.Context, roomNID types.RoomN } // GetServerInRoom returns true if we think a server is in a given room or false otherwise. -func (d *Database) GetServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) { +func (d *Database) GetServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName spec.ServerName) (bool, error) { return d.MembershipTable.SelectServerInRoom(ctx, nil, roomNID, serverName) } diff --git a/roomserver/storage/shared/storage_test.go b/roomserver/storage/shared/storage_test.go index 684e80b8f..941e84802 100644 --- a/roomserver/storage/shared/storage_test.go +++ b/roomserver/storage/shared/storage_test.go @@ -15,14 +15,12 @@ import ( "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" - "github.com/matrix-org/dendrite/test/testrig" ) func mustCreateRoomserverDatabase(t *testing.T, dbType test.DBType) (*shared.Database, func()) { t.Helper() connStr, clearDB := test.PrepareDBConnectionString(t, dbType) - base, _, _ := testrig.Base(nil) dbOpts := &config.DatabaseOptions{ConnectionString: config.DataSource(connStr)} db, err := sqlutil.Open(dbOpts, sqlutil.NewExclusiveWriter()) @@ -61,8 +59,6 @@ func mustCreateRoomserverDatabase(t *testing.T, dbType test.DBType) (*shared.Dat Writer: sqlutil.NewExclusiveWriter(), Cache: cache, }, func() { - err := base.Close() - assert.NoError(t, err) clearDB() err = db.Close() assert.NoError(t, err) diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go index 8a60b359f..977788d50 100644 --- a/roomserver/storage/sqlite3/membership_table.go +++ b/roomserver/storage/sqlite3/membership_table.go @@ -21,13 +21,12 @@ import ( "fmt" "strings" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/sqlite3/deltas" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib/spec" ) const membershipSchema = ` @@ -398,7 +397,7 @@ func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, txn return found, nil } -func (s *membershipStatements) SelectServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) { +func (s *membershipStatements) SelectServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, serverName spec.ServerName) (bool, error) { var nid types.RoomNID stmt := sqlutil.TxStmt(txn, s.selectServerInRoomStmt) err := stmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID, serverName).Scan(&nid) diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index 2adedd2d8..89e16fc14 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -28,7 +28,6 @@ import ( "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/sqlite3/deltas" "github.com/matrix-org/dendrite/roomserver/types" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" ) @@ -38,10 +37,10 @@ type Database struct { } // Open a sqlite database. -func Open(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (*Database, error) { +func Open(ctx context.Context, conMan sqlutil.Connections, dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (*Database, error) { var d Database var err error - db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()) + db, writer, err := conMan.Connection(dbProperties) if err != nil { return nil, fmt.Errorf("sqlutil.Open: %w", err) } @@ -62,7 +61,7 @@ func Open(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache c // Special case, since this migration uses several tables, so it needs to // be sure that all tables are created first. - if err = executeMigration(base.Context(), db); err != nil { + if err = executeMigration(ctx, db); err != nil { return nil, err } diff --git a/roomserver/storage/storage.go b/roomserver/storage/storage.go index 8a87b7d7c..2b3b3bd85 100644 --- a/roomserver/storage/storage.go +++ b/roomserver/storage/storage.go @@ -18,22 +18,23 @@ package storage import ( + "context" "fmt" "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/postgres" "github.com/matrix-org/dendrite/roomserver/storage/sqlite3" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" ) // Open opens a database connection. -func Open(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (Database, error) { +func Open(ctx context.Context, conMan sqlutil.Connections, dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.Open(base, dbProperties, cache) + return sqlite3.Open(ctx, conMan, dbProperties, cache) case dbProperties.ConnectionString.IsPostgres(): - return postgres.Open(base, dbProperties, cache) + return postgres.Open(ctx, conMan, dbProperties, cache) default: return nil, fmt.Errorf("unexpected database type") } diff --git a/roomserver/storage/storage_wasm.go b/roomserver/storage/storage_wasm.go index df5a56ac3..817f9304c 100644 --- a/roomserver/storage/storage_wasm.go +++ b/roomserver/storage/storage_wasm.go @@ -15,19 +15,20 @@ package storage import ( + "context" "fmt" "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/sqlite3" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" ) // NewPublicRoomsServerDatabase opens a database connection. -func Open(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (Database, error) { +func Open(ctx context.Context, conMan sqlutil.Connections, dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.Open(base, dbProperties, cache) + return sqlite3.Open(ctx, conMan, dbProperties, cache) case dbProperties.ConnectionString.IsPostgres(): return nil, fmt.Errorf("can't use Postgres implementation") default: diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index 4ce2a9c4e..45dc0fc29 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -6,6 +6,7 @@ import ( "errors" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/tidwall/gjson" "github.com/matrix-org/dendrite/roomserver/types" @@ -147,7 +148,7 @@ type Membership interface { SelectKnownUsers(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) UpdateForgetMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool) error SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error) - SelectServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) + SelectServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, serverName spec.ServerName) (bool, error) DeleteMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) error SelectJoinedUsers(ctx context.Context, txn *sql.Tx, targetUserNIDs []types.EventStateKeyNID) ([]types.EventStateKeyNID, error) } @@ -199,17 +200,17 @@ func ExtractContentValue(ev *gomatrixserverlib.HeaderedEvent) string { content := ev.Content() key := "" switch ev.Type() { - case gomatrixserverlib.MRoomCreate: + case spec.MRoomCreate: key = "creator" - case gomatrixserverlib.MRoomCanonicalAlias: + case spec.MRoomCanonicalAlias: key = "alias" - case gomatrixserverlib.MRoomHistoryVisibility: + case spec.MRoomHistoryVisibility: key = "history_visibility" - case gomatrixserverlib.MRoomJoinRules: + case spec.MRoomJoinRules: key = "join_rule" - case gomatrixserverlib.MRoomMember: + case spec.MRoomMember: key = "membership" - case gomatrixserverlib.MRoomName: + case spec.MRoomName: key = "name" case "m.room.avatar": key = "url" diff --git a/setup/base/base.go b/setup/base/base.go index dfe48ff3c..d6c350109 100644 --- a/setup/base/base.go +++ b/setup/base/base.go @@ -17,319 +17,91 @@ package base import ( "bytes" "context" - "database/sql" "embed" "encoding/json" "errors" "fmt" "html/template" - "io" "io/fs" "net" "net/http" _ "net/http/pprof" "os" "os/signal" - "sync" "syscall" "time" - "github.com/getsentry/sentry-go" sentryhttp "github.com/getsentry/sentry-go/http" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" "github.com/prometheus/client_golang/prometheus/promhttp" "go.uber.org/atomic" - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/caching" - "github.com/matrix-org/dendrite/internal/fulltext" - "github.com/matrix-org/dendrite/internal/httputil" - "github.com/matrix-org/dendrite/internal/pushgateway" - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/gorilla/mux" "github.com/kardianos/minwinsvc" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/httputil" "github.com/sirupsen/logrus" "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" ) //go:embed static/*.gotmpl var staticContent embed.FS -// BaseDendrite is a base for creating new instances of dendrite. It parses -// command line flags and config, and exposes methods for creating various -// resources. All errors are handled by logging then exiting, so all methods -// should only be used during start up. -// Must be closed when shutting down. -type BaseDendrite struct { - *process.ProcessContext - tracerCloser io.Closer - PublicClientAPIMux *mux.Router - PublicFederationAPIMux *mux.Router - PublicKeyAPIMux *mux.Router - PublicMediaAPIMux *mux.Router - PublicWellKnownAPIMux *mux.Router - PublicStaticMux *mux.Router - DendriteAdminMux *mux.Router - SynapseAdminMux *mux.Router - NATS *jetstream.NATSInstance - Cfg *config.Dendrite - Caches *caching.Caches - DNSCache *gomatrixserverlib.DNSCache - Database *sql.DB - DatabaseWriter sqlutil.Writer - EnableMetrics bool - Fulltext *fulltext.Search - startupLock sync.Mutex -} - const HTTPServerTimeout = time.Minute * 5 -type BaseDendriteOptions int - -const ( - DisableMetrics BaseDendriteOptions = iota -) - -// NewBaseDendrite creates a new instance to be used by a component. -func NewBaseDendrite(cfg *config.Dendrite, options ...BaseDendriteOptions) *BaseDendrite { - platformSanityChecks() - enableMetrics := true - for _, opt := range options { - switch opt { - case DisableMetrics: - enableMetrics = false - } - } - - configErrors := &config.ConfigErrors{} - cfg.Verify(configErrors) - if len(*configErrors) > 0 { - for _, err := range *configErrors { - logrus.Errorf("Configuration error: %s", err) - } - logrus.Fatalf("Failed to start due to configuration errors") - } - - internal.SetupStdLogging() - internal.SetupHookLogging(cfg.Logging) - internal.SetupPprof() - - logrus.Infof("Dendrite version %s", internal.VersionString()) - - if !cfg.ClientAPI.RegistrationDisabled && cfg.ClientAPI.OpenRegistrationWithoutVerificationEnabled { - logrus.Warn("Open registration is enabled") - } - - closer, err := cfg.SetupTracing() - if err != nil { - logrus.WithError(err).Panicf("failed to start opentracing") - } - - var fts *fulltext.Search - if cfg.SyncAPI.Fulltext.Enabled { - fts, err = fulltext.New(cfg.SyncAPI.Fulltext) - if err != nil { - logrus.WithError(err).Panicf("failed to create full text") - } - } - - if cfg.Global.Sentry.Enabled { - logrus.Info("Setting up Sentry for debugging...") - err = sentry.Init(sentry.ClientOptions{ - Dsn: cfg.Global.Sentry.DSN, - Environment: cfg.Global.Sentry.Environment, - Debug: true, - ServerName: string(cfg.Global.ServerName), - Release: "dendrite@" + internal.VersionString(), - AttachStacktrace: true, - }) - if err != nil { - logrus.WithError(err).Panic("failed to start Sentry") - } - } - - var dnsCache *gomatrixserverlib.DNSCache - if cfg.Global.DNSCache.Enabled { - dnsCache = gomatrixserverlib.NewDNSCache( - cfg.Global.DNSCache.CacheSize, - cfg.Global.DNSCache.CacheLifetime, - ) - logrus.Infof( - "DNS cache enabled (size %d, lifetime %s)", - cfg.Global.DNSCache.CacheSize, - cfg.Global.DNSCache.CacheLifetime, - ) - } - - // If we're in monolith mode, we'll set up a global pool of database - // connections. A component is welcome to use this pool if they don't - // have a separate database config of their own. - var db *sql.DB - var writer sqlutil.Writer - if cfg.Global.DatabaseOptions.ConnectionString != "" { - if cfg.Global.DatabaseOptions.ConnectionString.IsSQLite() { - logrus.Panic("Using a global database connection pool is not supported with SQLite databases") - } - writer = sqlutil.NewDummyWriter() - if db, err = sqlutil.Open(&cfg.Global.DatabaseOptions, writer); err != nil { - logrus.WithError(err).Panic("Failed to set up global database connections") - } - logrus.Debug("Using global database connection pool") - } - - // Ideally we would only use SkipClean on routes which we know can allow '/' but due to - // https://github.com/gorilla/mux/issues/460 we have to attach this at the top router. - // When used in conjunction with UseEncodedPath() we get the behaviour we want when parsing - // path parameters: - // /foo/bar%2Fbaz == [foo, bar%2Fbaz] (from UseEncodedPath) - // /foo/bar%2F%2Fbaz == [foo, bar%2F%2Fbaz] (from SkipClean) - // In particular, rooms v3 event IDs are not urlsafe and can include '/' and because they - // are randomly generated it results in flakey tests. - // We need to be careful with media APIs if they read from a filesystem to make sure they - // are not inadvertently reading paths without cleaning, else this could introduce a - // directory traversal attack e.g /../../../etc/passwd - - return &BaseDendrite{ - ProcessContext: process.NewProcessContext(), - tracerCloser: closer, - Cfg: cfg, - Caches: caching.NewRistrettoCache(cfg.Global.Cache.EstimatedMaxSize, cfg.Global.Cache.MaxAge, enableMetrics), - DNSCache: dnsCache, - PublicClientAPIMux: mux.NewRouter().SkipClean(true).PathPrefix(httputil.PublicClientPathPrefix).Subrouter().UseEncodedPath(), - PublicFederationAPIMux: mux.NewRouter().SkipClean(true).PathPrefix(httputil.PublicFederationPathPrefix).Subrouter().UseEncodedPath(), - PublicKeyAPIMux: mux.NewRouter().SkipClean(true).PathPrefix(httputil.PublicKeyPathPrefix).Subrouter().UseEncodedPath(), - PublicMediaAPIMux: mux.NewRouter().SkipClean(true).PathPrefix(httputil.PublicMediaPathPrefix).Subrouter().UseEncodedPath(), - PublicWellKnownAPIMux: mux.NewRouter().SkipClean(true).PathPrefix(httputil.PublicWellKnownPrefix).Subrouter().UseEncodedPath(), - PublicStaticMux: mux.NewRouter().SkipClean(true).PathPrefix(httputil.PublicStaticPath).Subrouter().UseEncodedPath(), - DendriteAdminMux: mux.NewRouter().SkipClean(true).PathPrefix(httputil.DendriteAdminPathPrefix).Subrouter().UseEncodedPath(), - SynapseAdminMux: mux.NewRouter().SkipClean(true).PathPrefix(httputil.SynapseAdminPathPrefix).Subrouter().UseEncodedPath(), - NATS: &jetstream.NATSInstance{}, - Database: db, // set if monolith with global connection pool only - DatabaseWriter: writer, // set if monolith with global connection pool only - EnableMetrics: enableMetrics, - Fulltext: fts, - } -} - -// Close implements io.Closer -func (b *BaseDendrite) Close() error { - b.ProcessContext.ShutdownDendrite() - b.ProcessContext.WaitForShutdown() - return b.tracerCloser.Close() -} - -// DatabaseConnection assists in setting up a database connection. It accepts -// the database properties and a new writer for the given component. If we're -// running in monolith mode with a global connection pool configured then we -// will return that connection, along with the global writer, effectively -// ignoring the options provided. Otherwise we'll open a new database connection -// using the supplied options and writer. Note that it's possible for the pointer -// receiver to be nil here – that's deliberate as some of the unit tests don't -// have a BaseDendrite and just want a connection with the supplied config -// without any pooling stuff. -func (b *BaseDendrite) DatabaseConnection(dbProperties *config.DatabaseOptions, writer sqlutil.Writer) (*sql.DB, sqlutil.Writer, error) { - if dbProperties.ConnectionString != "" || b == nil { - // Open a new database connection using the supplied config. - db, err := sqlutil.Open(dbProperties, writer) - return db, writer, err - } - if b.Database != nil && b.DatabaseWriter != nil { - // Ignore the supplied config and return the global pool and - // writer. - return b.Database, b.DatabaseWriter, nil - } - return nil, nil, fmt.Errorf("no database connections configured") -} - -// PushGatewayHTTPClient returns a new client for interacting with (external) Push Gateways. -func (b *BaseDendrite) PushGatewayHTTPClient() pushgateway.Client { - return pushgateway.NewHTTPClient(b.Cfg.UserAPI.PushGatewayDisableTLSValidation) -} - // CreateClient creates a new client (normally used for media fetch requests). // Should only be called once per component. -func (b *BaseDendrite) CreateClient() *gomatrixserverlib.Client { - if b.Cfg.Global.DisableFederation { - return gomatrixserverlib.NewClient( - gomatrixserverlib.WithTransport(noOpHTTPTransport), +func CreateClient(cfg *config.Dendrite, dnsCache *fclient.DNSCache) *fclient.Client { + if cfg.Global.DisableFederation { + return fclient.NewClient( + fclient.WithTransport(noOpHTTPTransport), ) } - opts := []gomatrixserverlib.ClientOption{ - gomatrixserverlib.WithSkipVerify(b.Cfg.FederationAPI.DisableTLSValidation), - gomatrixserverlib.WithWellKnownSRVLookups(true), + opts := []fclient.ClientOption{ + fclient.WithSkipVerify(cfg.FederationAPI.DisableTLSValidation), + fclient.WithWellKnownSRVLookups(true), } - if b.Cfg.Global.DNSCache.Enabled { - opts = append(opts, gomatrixserverlib.WithDNSCache(b.DNSCache)) + if cfg.Global.DNSCache.Enabled && dnsCache != nil { + opts = append(opts, fclient.WithDNSCache(dnsCache)) } - client := gomatrixserverlib.NewClient(opts...) + client := fclient.NewClient(opts...) client.SetUserAgent(fmt.Sprintf("Dendrite/%s", internal.VersionString())) return client } // CreateFederationClient creates a new federation client. Should only be called // once per component. -func (b *BaseDendrite) CreateFederationClient() *gomatrixserverlib.FederationClient { - identities := b.Cfg.Global.SigningIdentities() - if b.Cfg.Global.DisableFederation { - return gomatrixserverlib.NewFederationClient( - identities, gomatrixserverlib.WithTransport(noOpHTTPTransport), +func CreateFederationClient(cfg *config.Dendrite, dnsCache *fclient.DNSCache) *fclient.FederationClient { + identities := cfg.Global.SigningIdentities() + if cfg.Global.DisableFederation { + return fclient.NewFederationClient( + identities, fclient.WithTransport(noOpHTTPTransport), ) } - opts := []gomatrixserverlib.ClientOption{ - gomatrixserverlib.WithTimeout(time.Minute * 5), - gomatrixserverlib.WithSkipVerify(b.Cfg.FederationAPI.DisableTLSValidation), - gomatrixserverlib.WithKeepAlives(!b.Cfg.FederationAPI.DisableHTTPKeepalives), + opts := []fclient.ClientOption{ + fclient.WithTimeout(time.Minute * 5), + fclient.WithSkipVerify(cfg.FederationAPI.DisableTLSValidation), + fclient.WithKeepAlives(!cfg.FederationAPI.DisableHTTPKeepalives), } - if b.Cfg.Global.DNSCache.Enabled { - opts = append(opts, gomatrixserverlib.WithDNSCache(b.DNSCache)) + if cfg.Global.DNSCache.Enabled { + opts = append(opts, fclient.WithDNSCache(dnsCache)) } - client := gomatrixserverlib.NewFederationClient( + client := fclient.NewFederationClient( identities, opts..., ) client.SetUserAgent(fmt.Sprintf("Dendrite/%s", internal.VersionString())) return client } -func (b *BaseDendrite) configureHTTPErrors() { - notAllowedHandler := func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusMethodNotAllowed) - _, _ = w.Write([]byte(fmt.Sprintf("405 %s not allowed on this endpoint", r.Method))) - } - - clientNotFoundHandler := func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusNotFound) - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{"errcode":"M_UNRECOGNIZED","error":"Unrecognized request"}`)) // nolint:misspell - } - - notFoundCORSHandler := httputil.WrapHandlerInCORS(http.NotFoundHandler()) - notAllowedCORSHandler := httputil.WrapHandlerInCORS(http.HandlerFunc(notAllowedHandler)) - - for _, router := range []*mux.Router{ - b.PublicMediaAPIMux, b.DendriteAdminMux, - b.SynapseAdminMux, b.PublicWellKnownAPIMux, - b.PublicStaticMux, - } { - router.NotFoundHandler = notFoundCORSHandler - router.MethodNotAllowedHandler = notAllowedCORSHandler - } - - // Special case so that we don't upset clients on the CS API. - b.PublicClientAPIMux.NotFoundHandler = http.HandlerFunc(clientNotFoundHandler) - b.PublicClientAPIMux.MethodNotAllowedHandler = http.HandlerFunc(clientNotFoundHandler) -} - -func (b *BaseDendrite) ConfigureAdminEndpoints() { - b.DendriteAdminMux.HandleFunc("/monitor/up", func(w http.ResponseWriter, r *http.Request) { +func ConfigureAdminEndpoints(processContext *process.ProcessContext, routers httputil.Routers) { + routers.DendriteAdmin.HandleFunc("/monitor/up", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(200) }) - b.DendriteAdminMux.HandleFunc("/monitor/health", func(w http.ResponseWriter, r *http.Request) { - if isDegraded, reasons := b.ProcessContext.IsDegraded(); isDegraded { + routers.DendriteAdmin.HandleFunc("/monitor/health", func(w http.ResponseWriter, r *http.Request) { + if isDegraded, reasons := processContext.IsDegraded(); isDegraded { w.WriteHeader(503) _ = json.NewEncoder(w).Encode(struct { Warnings []string `json:"warnings"` @@ -344,14 +116,13 @@ func (b *BaseDendrite) ConfigureAdminEndpoints() { // SetupAndServeHTTP sets up the HTTP server to serve client & federation APIs // and adds a prometheus handler under /_dendrite/metrics. -func (b *BaseDendrite) SetupAndServeHTTP( +func SetupAndServeHTTP( + processContext *process.ProcessContext, + cfg *config.Dendrite, + routers httputil.Routers, externalHTTPAddr config.ServerAddress, certFile, keyFile *string, ) { - // Manually unlocked right before actually serving requests, - // as we don't return from this method (defer doesn't work). - b.startupLock.Lock() - externalRouter := mux.NewRouter().SkipClean(true).UseEncodedPath() externalServ := &http.Server{ @@ -359,22 +130,20 @@ func (b *BaseDendrite) SetupAndServeHTTP( WriteTimeout: HTTPServerTimeout, Handler: externalRouter, BaseContext: func(_ net.Listener) context.Context { - return b.ProcessContext.Context() + return processContext.Context() }, } - b.configureHTTPErrors() - //Redirect for Landing Page externalRouter.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, httputil.PublicStaticPath, http.StatusFound) }) - if b.Cfg.Global.Metrics.Enabled { - externalRouter.Handle("/metrics", httputil.WrapHandlerInBasicAuth(promhttp.Handler(), b.Cfg.Global.Metrics.BasicAuth)) + if cfg.Global.Metrics.Enabled { + externalRouter.Handle("/metrics", httputil.WrapHandlerInBasicAuth(promhttp.Handler(), cfg.Global.Metrics.BasicAuth)) } - b.ConfigureAdminEndpoints() + ConfigureAdminEndpoints(processContext, routers) // Parse and execute the landing page template tmpl := template.Must(template.ParseFS(staticContent, "static/*.gotmpl")) @@ -385,47 +154,48 @@ func (b *BaseDendrite) SetupAndServeHTTP( logrus.WithError(err).Fatal("failed to execute landing page template") } - b.PublicStaticMux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + routers.Static.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { _, _ = w.Write(landingPage.Bytes()) }) var clientHandler http.Handler - clientHandler = b.PublicClientAPIMux - if b.Cfg.Global.Sentry.Enabled { + clientHandler = routers.Client + if cfg.Global.Sentry.Enabled { sentryHandler := sentryhttp.New(sentryhttp.Options{ Repanic: true, }) - clientHandler = sentryHandler.Handle(b.PublicClientAPIMux) + clientHandler = sentryHandler.Handle(routers.Client) } var federationHandler http.Handler - federationHandler = b.PublicFederationAPIMux - if b.Cfg.Global.Sentry.Enabled { + federationHandler = routers.Federation + if cfg.Global.Sentry.Enabled { sentryHandler := sentryhttp.New(sentryhttp.Options{ Repanic: true, }) - federationHandler = sentryHandler.Handle(b.PublicFederationAPIMux) + federationHandler = sentryHandler.Handle(routers.Federation) } - externalRouter.PathPrefix(httputil.DendriteAdminPathPrefix).Handler(b.DendriteAdminMux) + externalRouter.PathPrefix(httputil.DendriteAdminPathPrefix).Handler(routers.DendriteAdmin) externalRouter.PathPrefix(httputil.PublicClientPathPrefix).Handler(clientHandler) - if !b.Cfg.Global.DisableFederation { - externalRouter.PathPrefix(httputil.PublicKeyPathPrefix).Handler(b.PublicKeyAPIMux) + if !cfg.Global.DisableFederation { + externalRouter.PathPrefix(httputil.PublicKeyPathPrefix).Handler(routers.Keys) externalRouter.PathPrefix(httputil.PublicFederationPathPrefix).Handler(federationHandler) } - externalRouter.PathPrefix(httputil.SynapseAdminPathPrefix).Handler(b.SynapseAdminMux) - externalRouter.PathPrefix(httputil.PublicMediaPathPrefix).Handler(b.PublicMediaAPIMux) - externalRouter.PathPrefix(httputil.PublicWellKnownPrefix).Handler(b.PublicWellKnownAPIMux) - externalRouter.PathPrefix(httputil.PublicStaticPath).Handler(b.PublicStaticMux) + externalRouter.PathPrefix(httputil.SynapseAdminPathPrefix).Handler(routers.SynapseAdmin) + externalRouter.PathPrefix(httputil.PublicMediaPathPrefix).Handler(routers.Media) + externalRouter.PathPrefix(httputil.PublicWellKnownPrefix).Handler(routers.WellKnown) + externalRouter.PathPrefix(httputil.PublicStaticPath).Handler(routers.Static) - b.startupLock.Unlock() + externalRouter.NotFoundHandler = httputil.NotFoundCORSHandler + externalRouter.MethodNotAllowedHandler = httputil.NotAllowedHandler if externalHTTPAddr.Enabled() { go func() { var externalShutdown atomic.Bool // RegisterOnShutdown can be called more than once logrus.Infof("Starting external listener on %s", externalServ.Addr) - b.ProcessContext.ComponentStarted() + processContext.ComponentStarted() externalServ.RegisterOnShutdown(func() { if externalShutdown.CompareAndSwap(false, true) { - b.ProcessContext.ComponentFinished() + processContext.ComponentFinished() logrus.Infof("Stopped external HTTP listener") } }) @@ -467,38 +237,27 @@ func (b *BaseDendrite) SetupAndServeHTTP( }() } - minwinsvc.SetOnExit(b.ProcessContext.ShutdownDendrite) - <-b.ProcessContext.WaitForShutdown() + minwinsvc.SetOnExit(processContext.ShutdownDendrite) + <-processContext.WaitForShutdown() logrus.Infof("Stopping HTTP listeners") _ = externalServ.Shutdown(context.Background()) logrus.Infof("Stopped HTTP listeners") } -func (b *BaseDendrite) WaitForShutdown() { +func WaitForShutdown(processCtx *process.ProcessContext) { sigs := make(chan os.Signal, 1) signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) select { case <-sigs: - case <-b.ProcessContext.WaitForShutdown(): + case <-processCtx.WaitForShutdown(): } signal.Reset(syscall.SIGINT, syscall.SIGTERM) logrus.Warnf("Shutdown signal received") - b.ProcessContext.ShutdownDendrite() - b.ProcessContext.WaitForComponentsToFinish() - if b.Cfg.Global.Sentry.Enabled { - if !sentry.Flush(time.Second * 5) { - logrus.Warnf("failed to flush all Sentry events!") - } - } - if b.Fulltext != nil { - err := b.Fulltext.Close() - if err != nil { - logrus.Warnf("failed to close full text search!") - } - } + processCtx.ShutdownDendrite() + processCtx.WaitForComponentsToFinish() logrus.Warnf("Dendrite is exiting now") } diff --git a/setup/base/base_test.go b/setup/base/base_test.go index 658dc5b03..bba967b94 100644 --- a/setup/base/base_test.go +++ b/setup/base/base_test.go @@ -13,8 +13,10 @@ import ( "time" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/httputil" + basepkg "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/test/testrig" + "github.com/matrix-org/dendrite/setup/process" "github.com/stretchr/testify/assert" ) @@ -30,8 +32,10 @@ func TestLandingPage_Tcp(t *testing.T) { }) assert.NoError(t, err) - b, _, _ := testrig.Base(nil) - defer b.Close() + processCtx := process.NewProcessContext() + routers := httputil.NewRouters() + cfg := config.Dendrite{} + cfg.Defaults(config.DefaultOpts{Generate: true, SingleDatabase: true}) // hack: create a server and close it immediately, just to get a random port assigned s := httptest.NewServer(nil) @@ -40,7 +44,7 @@ func TestLandingPage_Tcp(t *testing.T) { // start base with the listener and wait for it to be started address, err := config.HTTPAddress(s.URL) assert.NoError(t, err) - go b.SetupAndServeHTTP(address, nil, nil) + go basepkg.SetupAndServeHTTP(processCtx, &cfg, routers, address, nil, nil) time.Sleep(time.Millisecond * 10) // When hitting /, we should be redirected to /_matrix/static, which should contain the landing page @@ -70,15 +74,17 @@ func TestLandingPage_UnixSocket(t *testing.T) { }) assert.NoError(t, err) - b, _, _ := testrig.Base(nil) - defer b.Close() + processCtx := process.NewProcessContext() + routers := httputil.NewRouters() + cfg := config.Dendrite{} + cfg.Defaults(config.DefaultOpts{Generate: true, SingleDatabase: true}) tempDir := t.TempDir() socket := path.Join(tempDir, "socket") // start base with the listener and wait for it to be started - address := config.UnixSocketAddress(socket, 0755) + address, err := config.UnixSocketAddress(socket, "755") assert.NoError(t, err) - go b.SetupAndServeHTTP(address, nil, nil) + go basepkg.SetupAndServeHTTP(processCtx, &cfg, routers, address, nil, nil) time.Sleep(time.Millisecond * 100) client := &http.Client{ diff --git a/setup/base/sanity_other.go b/setup/base/sanity_other.go index 48fe6e1f8..d35c2e872 100644 --- a/setup/base/sanity_other.go +++ b/setup/base/sanity_other.go @@ -3,6 +3,6 @@ package base -func platformSanityChecks() { +func PlatformSanityChecks() { // Nothing to do yet. } diff --git a/setup/base/sanity_unix.go b/setup/base/sanity_unix.go index c630d3f19..0403df1a8 100644 --- a/setup/base/sanity_unix.go +++ b/setup/base/sanity_unix.go @@ -9,7 +9,7 @@ import ( "github.com/sirupsen/logrus" ) -func platformSanityChecks() { +func PlatformSanityChecks() { // Dendrite needs a relatively high number of file descriptors in order // to function properly, particularly when federating with lots of servers. // If we run out of file descriptors, we might run into problems accessing diff --git a/setup/config/config.go b/setup/config/config.go index 1a25f71eb..41396ae36 100644 --- a/setup/config/config.go +++ b/setup/config/config.go @@ -26,6 +26,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/sirupsen/logrus" "golang.org/x/crypto/ed25519" "gopkg.in/yaml.v2" @@ -239,7 +240,7 @@ func loadConfig( key.KeyID = keyID key.PrivateKey = privateKey - key.PublicKey = gomatrixserverlib.Base64Bytes(privateKey.Public().(ed25519.PublicKey)) + key.PublicKey = spec.Base64Bytes(privateKey.Public().(ed25519.PublicKey)) case key.KeyID == "": return nil, fmt.Errorf("'key_id' must be specified if 'public_key' is specified") diff --git a/setup/config/config_address.go b/setup/config/config_address.go index 0e4f0296f..a35cc3f96 100644 --- a/setup/config/config_address.go +++ b/setup/config/config_address.go @@ -3,6 +3,7 @@ package config import ( "io/fs" "net/url" + "strconv" ) const ( @@ -32,8 +33,12 @@ func (s ServerAddress) Network() string { } } -func UnixSocketAddress(path string, perm fs.FileMode) ServerAddress { - return ServerAddress{Address: path, Scheme: NetworkUnix, UnixSocketPermission: perm} +func UnixSocketAddress(path string, perm string) (ServerAddress, error) { + permission, err := strconv.ParseInt(perm, 8, 32) + if err != nil { + return ServerAddress{}, err + } + return ServerAddress{Address: path, Scheme: NetworkUnix, UnixSocketPermission: fs.FileMode(permission)}, nil } func HTTPAddress(urlAddress string) (ServerAddress, error) { diff --git a/setup/config/config_address_test.go b/setup/config/config_address_test.go index 1be484fd5..38c96ab7f 100644 --- a/setup/config/config_address_test.go +++ b/setup/config/config_address_test.go @@ -20,6 +20,24 @@ func TestHttpAddress_ParseBad(t *testing.T) { } func TestUnixSocketAddress_Network(t *testing.T) { - address := UnixSocketAddress("/tmp", fs.FileMode(0755)) + address, err := UnixSocketAddress("/tmp", "0755") + assert.NoError(t, err) assert.Equal(t, "unix", address.Network()) } + +func TestUnixSocketAddress_Permission_LeadingZero_Ok(t *testing.T) { + address, err := UnixSocketAddress("/tmp", "0755") + assert.NoError(t, err) + assert.Equal(t, fs.FileMode(0755), address.UnixSocketPermission) +} + +func TestUnixSocketAddress_Permission_NoLeadingZero_Ok(t *testing.T) { + address, err := UnixSocketAddress("/tmp", "755") + assert.NoError(t, err) + assert.Equal(t, fs.FileMode(0755), address.UnixSocketPermission) +} + +func TestUnixSocketAddress_Permission_NonOctal_Bad(t *testing.T) { + _, err := UnixSocketAddress("/tmp", "855") + assert.Error(t, err) +} diff --git a/setup/config/config_appservice.go b/setup/config/config_appservice.go index 37e20a978..ef10649d2 100644 --- a/setup/config/config_appservice.go +++ b/setup/config/config_appservice.go @@ -15,16 +15,23 @@ package config import ( + "context" + "crypto/tls" "fmt" + "net" + "net/http" "os" "path/filepath" "regexp" "strings" + "time" log "github.com/sirupsen/logrus" "gopkg.in/yaml.v2" ) +const UnixSocketPrefix = "unix://" + type AppServiceAPI struct { Matrix *Global `yaml:"-"` Derived *Derived `yaml:"-"` // TODO: Nuke Derived from orbit @@ -80,7 +87,41 @@ type ApplicationService struct { // Whether rate limiting is applied to each application service user RateLimited bool `yaml:"rate_limited"` // Any custom protocols that this application service provides (e.g. IRC) - Protocols []string `yaml:"protocols"` + Protocols []string `yaml:"protocols"` + HTTPClient *http.Client + isUnixSocket bool + unixSocket string +} + +func (a *ApplicationService) CreateHTTPClient(insecureSkipVerify bool) { + client := &http.Client{ + Timeout: time.Second * 30, + Transport: &http.Transport{ + DisableKeepAlives: true, + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: insecureSkipVerify, + }, + Proxy: http.ProxyFromEnvironment, + }, + } + if strings.HasPrefix(a.URL, UnixSocketPrefix) { + a.isUnixSocket = true + a.unixSocket = "http://unix" + client.Transport = &http.Transport{ + DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { + return net.Dial("unix", strings.TrimPrefix(a.URL, UnixSocketPrefix)) + }, + } + } + a.HTTPClient = client +} + +func (a *ApplicationService) RequestUrl() string { + if a.isUnixSocket { + return a.unixSocket + } else { + return a.URL + } } // IsInterestedInRoomID returns a bool on whether an application service's @@ -152,7 +193,7 @@ func (a *ApplicationService) IsInterestedInRoomAlias( func loadAppServices(config *AppServiceAPI, derived *Derived) error { for _, configPath := range config.ConfigFiles { // Create a new application service with default options - appservice := ApplicationService{ + appservice := &ApplicationService{ RateLimited: true, } @@ -169,13 +210,13 @@ func loadAppServices(config *AppServiceAPI, derived *Derived) error { } // Load the config data into our struct - if err = yaml.Unmarshal(configData, &appservice); err != nil { + if err = yaml.Unmarshal(configData, appservice); err != nil { return err } - + appservice.CreateHTTPClient(config.DisableTLSValidation) // Append the parsed application service to the global config derived.ApplicationServices = append( - derived.ApplicationServices, appservice, + derived.ApplicationServices, *appservice, ) } diff --git a/setup/config/config_federationapi.go b/setup/config/config_federationapi.go index 8c1540b57..a72eee369 100644 --- a/setup/config/config_federationapi.go +++ b/setup/config/config_federationapi.go @@ -2,6 +2,7 @@ package config import ( "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) type FederationAPI struct { @@ -101,7 +102,7 @@ type KeyPerspectives []KeyPerspective type KeyPerspective struct { // The server name of the perspective key server - ServerName gomatrixserverlib.ServerName `yaml:"server_name"` + ServerName spec.ServerName `yaml:"server_name"` // Server keys for the perspective user, used to verify the // keys have been signed by the perspective server Keys []KeyPerspectiveTrustKey `yaml:"keys"` diff --git a/setup/config/config_global.go b/setup/config/config_global.go index ed980afaa..1622bf357 100644 --- a/setup/config/config_global.go +++ b/setup/config/config_global.go @@ -8,13 +8,15 @@ import ( "time" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "golang.org/x/crypto/ed25519" ) type Global struct { // Signing identity contains the server name, private key and key ID of // the deployment. - gomatrixserverlib.SigningIdentity `yaml:",inline"` + fclient.SigningIdentity `yaml:",inline"` // The secondary server names, used for virtual hosting. VirtualHosts []*VirtualHost `yaml:"-"` @@ -121,7 +123,7 @@ func (c *Global) Verify(configErrs *ConfigErrors) { c.Cache.Verify(configErrs) } -func (c *Global) IsLocalServerName(serverName gomatrixserverlib.ServerName) bool { +func (c *Global) IsLocalServerName(serverName spec.ServerName) bool { if c.ServerName == serverName { return true } @@ -133,7 +135,7 @@ func (c *Global) IsLocalServerName(serverName gomatrixserverlib.ServerName) bool return false } -func (c *Global) SplitLocalID(sigil byte, id string) (string, gomatrixserverlib.ServerName, error) { +func (c *Global) SplitLocalID(sigil byte, id string) (string, spec.ServerName, error) { u, s, err := gomatrixserverlib.SplitID(sigil, id) if err != nil { return u, s, err @@ -144,7 +146,7 @@ func (c *Global) SplitLocalID(sigil byte, id string) (string, gomatrixserverlib. return u, s, nil } -func (c *Global) VirtualHost(serverName gomatrixserverlib.ServerName) *VirtualHost { +func (c *Global) VirtualHost(serverName spec.ServerName) *VirtualHost { for _, v := range c.VirtualHosts { if v.ServerName == serverName { return v @@ -153,7 +155,7 @@ func (c *Global) VirtualHost(serverName gomatrixserverlib.ServerName) *VirtualHo return nil } -func (c *Global) VirtualHostForHTTPHost(serverName gomatrixserverlib.ServerName) *VirtualHost { +func (c *Global) VirtualHostForHTTPHost(serverName spec.ServerName) *VirtualHost { for _, v := range c.VirtualHosts { if v.ServerName == serverName { return v @@ -167,7 +169,7 @@ func (c *Global) VirtualHostForHTTPHost(serverName gomatrixserverlib.ServerName) return nil } -func (c *Global) SigningIdentityFor(serverName gomatrixserverlib.ServerName) (*gomatrixserverlib.SigningIdentity, error) { +func (c *Global) SigningIdentityFor(serverName spec.ServerName) (*fclient.SigningIdentity, error) { for _, id := range c.SigningIdentities() { if id.ServerName == serverName { return id, nil @@ -176,8 +178,8 @@ func (c *Global) SigningIdentityFor(serverName gomatrixserverlib.ServerName) (*g return nil, fmt.Errorf("no signing identity for %q", serverName) } -func (c *Global) SigningIdentities() []*gomatrixserverlib.SigningIdentity { - identities := make([]*gomatrixserverlib.SigningIdentity, 0, len(c.VirtualHosts)+1) +func (c *Global) SigningIdentities() []*fclient.SigningIdentity { + identities := make([]*fclient.SigningIdentity, 0, len(c.VirtualHosts)+1) identities = append(identities, &c.SigningIdentity) for _, v := range c.VirtualHosts { identities = append(identities, &v.SigningIdentity) @@ -188,7 +190,7 @@ func (c *Global) SigningIdentities() []*gomatrixserverlib.SigningIdentity { type VirtualHost struct { // Signing identity contains the server name, private key and key ID of // the virtual host. - gomatrixserverlib.SigningIdentity `yaml:",inline"` + fclient.SigningIdentity `yaml:",inline"` // Path to the private key. If not specified, the default global private key // will be used instead. @@ -204,7 +206,7 @@ type VirtualHost struct { // Match these HTTP Host headers on the `/key/v2/server` endpoint, this needs // to match all delegated names, likely including the port number too if // the well-known delegation includes that also. - MatchHTTPHosts []gomatrixserverlib.ServerName `yaml:"match_http_hosts"` + MatchHTTPHosts []spec.ServerName `yaml:"match_http_hosts"` // Is registration enabled on this virtual host? AllowRegistration bool `yaml:"allow_registration"` @@ -235,14 +237,14 @@ type OldVerifyKeys struct { PrivateKey ed25519.PrivateKey `yaml:"-"` // The public key, in case only that part is known. - PublicKey gomatrixserverlib.Base64Bytes `yaml:"public_key"` + PublicKey spec.Base64Bytes `yaml:"public_key"` // The key ID of the private key. KeyID gomatrixserverlib.KeyID `yaml:"key_id"` // When the private key was designed as "expired", as a UNIX timestamp // in millisecond precision. - ExpiredAt gomatrixserverlib.Timestamp `yaml:"expired_at"` + ExpiredAt spec.Timestamp `yaml:"expired_at"` } // The configuration to use for Prometheus metrics @@ -319,10 +321,15 @@ type ReportStats struct { func (c *ReportStats) Defaults() { c.Enabled = false - c.Endpoint = "https://matrix.org/report-usage-stats/push" + c.Endpoint = "https://panopticon.matrix.org/push" } func (c *ReportStats) Verify(configErrs *ConfigErrors) { + // We prefer to hit panopticon (https://github.com/matrix-org/panopticon) directly over + // the "old" matrix.org endpoint. + if c.Endpoint == "https://matrix.org/report-usage-stats/push" { + c.Endpoint = "https://panopticon.matrix.org/push" + } if c.Enabled { checkNotEmpty(configErrs, "global.report_stats.endpoint", c.Endpoint) } diff --git a/setup/config/config_test.go b/setup/config/config_test.go index 79407f30d..8a65c990f 100644 --- a/setup/config/config_test.go +++ b/setup/config/config_test.go @@ -19,7 +19,8 @@ import ( "reflect" "testing" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/sirupsen/logrus" "gopkg.in/yaml.v2" ) @@ -274,8 +275,8 @@ func Test_SigningIdentityFor(t *testing.T) { tests := []struct { name string virtualHosts []*VirtualHost - serverName gomatrixserverlib.ServerName - want *gomatrixserverlib.SigningIdentity + serverName spec.ServerName + want *fclient.SigningIdentity wantErr bool }{ { @@ -284,29 +285,29 @@ func Test_SigningIdentityFor(t *testing.T) { }, { name: "no identity found", - serverName: gomatrixserverlib.ServerName("doesnotexist"), + serverName: spec.ServerName("doesnotexist"), wantErr: true, }, { name: "found identity", - serverName: gomatrixserverlib.ServerName("main"), - want: &gomatrixserverlib.SigningIdentity{ServerName: "main"}, + serverName: spec.ServerName("main"), + want: &fclient.SigningIdentity{ServerName: "main"}, }, { name: "identity found on virtual hosts", - serverName: gomatrixserverlib.ServerName("vh2"), + serverName: spec.ServerName("vh2"), virtualHosts: []*VirtualHost{ - {SigningIdentity: gomatrixserverlib.SigningIdentity{ServerName: "vh1"}}, - {SigningIdentity: gomatrixserverlib.SigningIdentity{ServerName: "vh2"}}, + {SigningIdentity: fclient.SigningIdentity{ServerName: "vh1"}}, + {SigningIdentity: fclient.SigningIdentity{ServerName: "vh2"}}, }, - want: &gomatrixserverlib.SigningIdentity{ServerName: "vh2"}, + want: &fclient.SigningIdentity{ServerName: "vh2"}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := &Global{ VirtualHosts: tt.virtualHosts, - SigningIdentity: gomatrixserverlib.SigningIdentity{ + SigningIdentity: fclient.SigningIdentity{ ServerName: "main", }, } diff --git a/setup/jetstream/nats.go b/setup/jetstream/nats.go index 01fec9ad6..06a58d542 100644 --- a/setup/jetstream/nats.go +++ b/setup/jetstream/nats.go @@ -20,6 +20,8 @@ import ( type NATSInstance struct { *natsserver.Server + nc *natsclient.Conn + js natsclient.JetStreamContext } var natsLock sync.Mutex @@ -54,7 +56,9 @@ func (s *NATSInstance) Prepare(process *process.ProcessContext, cfg *config.JetS if err != nil { panic(err) } - s.SetLogger(NewLogAdapter(), opts.Debug, opts.Trace) + if !cfg.NoLog { + s.SetLogger(NewLogAdapter(), opts.Debug, opts.Trace) + } go func() { process.ComponentStarted() s.Start() @@ -69,11 +73,18 @@ func (s *NATSInstance) Prepare(process *process.ProcessContext, cfg *config.JetS if !s.ReadyForConnections(time.Second * 10) { logrus.Fatalln("NATS did not start in time") } + // reuse existing connections + if s.nc != nil { + return s.js, s.nc + } nc, err := natsclient.Connect("", natsclient.InProcessServer(s)) if err != nil { logrus.Fatalln("Failed to create NATS client") } - return setupNATS(process, cfg, nc) + js, _ := setupNATS(process, cfg, nc) + s.js = js + s.nc = nc + return js, nc } func setupNATS(process *process.ProcessContext, cfg *config.JetStream, nc *natsclient.Conn) (natsclient.JetStreamContext, *natsclient.Conn) { diff --git a/setup/monolith.go b/setup/monolith.go index d8c652234..e5af69853 100644 --- a/setup/monolith.go +++ b/setup/monolith.go @@ -20,16 +20,21 @@ import ( "github.com/matrix-org/dendrite/clientapi/api" "github.com/matrix-org/dendrite/federationapi" federationAPI "github.com/matrix-org/dendrite/federationapi/api" + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/transactions" "github.com/matrix-org/dendrite/mediaapi" "github.com/matrix-org/dendrite/relayapi" relayAPI "github.com/matrix-org/dendrite/relayapi/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/syncapi" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" ) // Monolith represents an instantiation of all dependencies required to build @@ -37,8 +42,8 @@ import ( type Monolith struct { Config *config.Dendrite KeyRing *gomatrixserverlib.KeyRing - Client *gomatrixserverlib.Client - FedClient *gomatrixserverlib.FederationClient + Client *fclient.Client + FedClient *fclient.FederationClient AppserviceAPI appserviceAPI.AppServiceInternalAPI FederationAPI federationAPI.FederationInternalAPI @@ -52,23 +57,31 @@ type Monolith struct { } // AddAllPublicRoutes attaches all public paths to the given router -func (m *Monolith) AddAllPublicRoutes(base *base.BaseDendrite) { +func (m *Monolith) AddAllPublicRoutes( + processCtx *process.ProcessContext, + cfg *config.Dendrite, + routers httputil.Routers, + cm sqlutil.Connections, + natsInstance *jetstream.NATSInstance, + caches *caching.Caches, + enableMetrics bool, +) { userDirectoryProvider := m.ExtUserDirectoryProvider if userDirectoryProvider == nil { userDirectoryProvider = m.UserAPI } clientapi.AddPublicRoutes( - base, m.FedClient, m.RoomserverAPI, m.AppserviceAPI, transactions.New(), + processCtx, routers, cfg, natsInstance, m.FedClient, m.RoomserverAPI, m.AppserviceAPI, transactions.New(), m.FederationAPI, m.UserAPI, userDirectoryProvider, - m.ExtPublicRoomsProvider, + m.ExtPublicRoomsProvider, enableMetrics, ) federationapi.AddPublicRoutes( - base, m.UserAPI, m.FedClient, m.KeyRing, m.RoomserverAPI, m.FederationAPI, nil, + processCtx, routers, cfg, natsInstance, m.UserAPI, m.FedClient, m.KeyRing, m.RoomserverAPI, m.FederationAPI, nil, enableMetrics, ) - mediaapi.AddPublicRoutes(base, m.UserAPI, m.Client) - syncapi.AddPublicRoutes(base, m.UserAPI, m.RoomserverAPI) + mediaapi.AddPublicRoutes(routers.Media, cm, cfg, m.UserAPI, m.Client) + syncapi.AddPublicRoutes(processCtx, routers, cfg, cm, natsInstance, m.UserAPI, m.RoomserverAPI, caches, enableMetrics) if m.RelayAPI != nil { - relayapi.AddPublicRoutes(base, m.KeyRing, m.RelayAPI) + relayapi.AddPublicRoutes(routers, cfg, m.KeyRing, m.RelayAPI) } } diff --git a/setup/mscs/msc2836/msc2836.go b/setup/mscs/msc2836/msc2836.go index 4bb6a5eee..8982e6e37 100644 --- a/setup/mscs/msc2836/msc2836.go +++ b/setup/mscs/msc2836/msc2836.go @@ -31,10 +31,14 @@ import ( fs "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal/hooks" "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/internal/sqlutil" roomserver "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/dendrite/setup/base" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/syncapi/synctypes" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -77,20 +81,20 @@ func (r *EventRelationshipRequest) Defaults() { } type EventRelationshipResponse struct { - Events []gomatrixserverlib.ClientEvent `json:"events"` - NextBatch string `json:"next_batch"` - Limited bool `json:"limited"` + Events []synctypes.ClientEvent `json:"events"` + NextBatch string `json:"next_batch"` + Limited bool `json:"limited"` } type MSC2836EventRelationshipsResponse struct { - gomatrixserverlib.MSC2836EventRelationshipsResponse + fclient.MSC2836EventRelationshipsResponse ParsedEvents []*gomatrixserverlib.Event ParsedAuthChain []*gomatrixserverlib.Event } func toClientResponse(res *MSC2836EventRelationshipsResponse) *EventRelationshipResponse { out := &EventRelationshipResponse{ - Events: gomatrixserverlib.ToClientEvents(res.ParsedEvents, gomatrixserverlib.FormatAll), + Events: synctypes.ToClientEvents(res.ParsedEvents, synctypes.FormatAll), Limited: res.Limited, NextBatch: res.NextBatch, } @@ -99,10 +103,10 @@ func toClientResponse(res *MSC2836EventRelationshipsResponse) *EventRelationship // Enable this MSC func Enable( - base *base.BaseDendrite, rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationInternalAPI, + cfg *config.Dendrite, cm sqlutil.Connections, routers httputil.Routers, rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationInternalAPI, userAPI userapi.UserInternalAPI, keyRing gomatrixserverlib.JSONVerifier, ) error { - db, err := NewDatabase(base, &base.Cfg.MSCs.Database) + db, err := NewDatabase(cm, &cfg.MSCs.Database) if err != nil { return fmt.Errorf("cannot enable MSC2836: %w", err) } @@ -125,14 +129,14 @@ func Enable( } }) - base.PublicClientAPIMux.Handle("/unstable/event_relationships", + routers.Client.Handle("/unstable/event_relationships", httputil.MakeAuthAPI("eventRelationships", userAPI, eventRelationshipHandler(db, rsAPI, fsAPI)), ).Methods(http.MethodPost, http.MethodOptions) - base.PublicFederationAPIMux.Handle("/unstable/event_relationships", httputil.MakeExternalAPI( + routers.Federation.Handle("/unstable/event_relationships", httputil.MakeExternalAPI( "msc2836_event_relationships", func(req *http.Request) util.JSONResponse { - fedReq, errResp := gomatrixserverlib.VerifyHTTPRequest( - req, time.Now(), base.Cfg.Global.ServerName, base.Cfg.Global.IsLocalServerName, keyRing, + fedReq, errResp := fclient.VerifyHTTPRequest( + req, time.Now(), cfg.Global.ServerName, cfg.Global.IsLocalServerName, keyRing, ) if fedReq == nil { return errResp @@ -153,7 +157,7 @@ type reqCtx struct { // federated request args isFederatedRequest bool - serverName gomatrixserverlib.ServerName + serverName spec.ServerName fsAPI fs.FederationInternalAPI } @@ -189,7 +193,7 @@ func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAP } func federatedEventRelationship( - ctx context.Context, fedReq *gomatrixserverlib.FederationRequest, db Database, rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationInternalAPI, + ctx context.Context, fedReq *fclient.FederationRequest, db Database, rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationInternalAPI, ) util.JSONResponse { relation, err := NewEventRelationshipRequest(bytes.NewBuffer(fedReq.Content())) if err != nil { @@ -397,7 +401,7 @@ func (rc *reqCtx) includeChildren(db Database, parentID string, limit int, recen serversToQuery := rc.getServersForEventID(parentID) var result *MSC2836EventRelationshipsResponse for _, srv := range serversToQuery { - res, err := rc.fsAPI.MSC2836EventRelationships(rc.ctx, rc.serverName, srv, gomatrixserverlib.MSC2836EventRelationshipsRequest{ + res, err := rc.fsAPI.MSC2836EventRelationships(rc.ctx, rc.serverName, srv, fclient.MSC2836EventRelationshipsRequest{ EventID: parentID, Direction: "down", Limit: 100, @@ -483,8 +487,8 @@ func walkThread( } // MSC2836EventRelationships performs an /event_relationships request to a remote server -func (rc *reqCtx) MSC2836EventRelationships(eventID string, srv gomatrixserverlib.ServerName, ver gomatrixserverlib.RoomVersion) (*MSC2836EventRelationshipsResponse, error) { - res, err := rc.fsAPI.MSC2836EventRelationships(rc.ctx, rc.serverName, srv, gomatrixserverlib.MSC2836EventRelationshipsRequest{ +func (rc *reqCtx) MSC2836EventRelationships(eventID string, srv spec.ServerName, ver gomatrixserverlib.RoomVersion) (*MSC2836EventRelationshipsResponse, error) { + res, err := rc.fsAPI.MSC2836EventRelationships(rc.ctx, rc.serverName, srv, fclient.MSC2836EventRelationshipsRequest{ EventID: eventID, DepthFirst: rc.req.DepthFirst, Direction: rc.req.Direction, @@ -542,7 +546,7 @@ func (rc *reqCtx) authorisedToSeeEvent(event *gomatrixserverlib.HeaderedEvent) b return queryMembershipRes.IsInRoom } -func (rc *reqCtx) getServersForEventID(eventID string) []gomatrixserverlib.ServerName { +func (rc *reqCtx) getServersForEventID(eventID string) []spec.ServerName { if rc.req.RoomID == "" { util.GetLogger(rc.ctx).WithField("event_id", eventID).Error( "getServersForEventID: event exists in unknown room", @@ -651,11 +655,11 @@ func (rc *reqCtx) injectResponseToRoomserver(res *MSC2836EventRelationshipsRespo messageEvents = append(messageEvents, ev) } } - respState := gomatrixserverlib.RespState{ + respState := &fclient.RespState{ AuthEvents: res.AuthChain, StateEvents: stateEvents, } - eventsInOrder := respState.Events(rc.roomVersion) + eventsInOrder := gomatrixserverlib.LineariseStateResponse(rc.roomVersion, respState) // everything gets sent as an outlier because auth chain events may be disjoint from the DAG // as may the threaded events. var ires []roomserver.InputRoomEvent @@ -686,7 +690,7 @@ func (rc *reqCtx) addChildMetadata(ev *gomatrixserverlib.HeaderedEvent) { if count == 0 { return } - err := ev.SetUnsignedField("children_hash", gomatrixserverlib.Base64Bytes(hash)) + err := ev.SetUnsignedField("children_hash", spec.Base64Bytes(hash)) if err != nil { util.GetLogger(rc.ctx).WithError(err).Warn("Failed to set children_hash") } diff --git a/setup/mscs/msc2836/msc2836_test.go b/setup/mscs/msc2836/msc2836_test.go index f12fbbfcb..a06b7211f 100644 --- a/setup/mscs/msc2836/msc2836_test.go +++ b/setup/mscs/msc2836/msc2836_test.go @@ -15,12 +15,15 @@ import ( "time" "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/internal/hooks" "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/internal/sqlutil" roomserver "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/mscs/msc2836" userapi "github.com/matrix-org/dendrite/userapi/api" @@ -461,7 +464,7 @@ func assertContains(t *testing.T, result *msc2836.EventRelationshipResponse, wan } } -func assertUnsignedChildren(t *testing.T, ev gomatrixserverlib.ClientEvent, relType string, wantCount int, childrenEventIDs []string) { +func assertUnsignedChildren(t *testing.T, ev synctypes.ClientEvent, relType string, wantCount int, childrenEventIDs []string) { t.Helper() unsigned := struct { Children map[string]int `json:"children"` @@ -554,20 +557,18 @@ func injectEvents(t *testing.T, userAPI userapi.UserInternalAPI, rsAPI roomserve cfg.Global.ServerName = "localhost" cfg.MSCs.Database.ConnectionString = "file:msc2836_test.db" cfg.MSCs.MSCs = []string{"msc2836"} - base := &base.BaseDendrite{ - Cfg: cfg, - PublicClientAPIMux: mux.NewRouter().PathPrefix(httputil.PublicClientPathPrefix).Subrouter(), - PublicFederationAPIMux: mux.NewRouter().PathPrefix(httputil.PublicFederationPathPrefix).Subrouter(), - } - err := msc2836.Enable(base, rsAPI, nil, userAPI, nil) + processCtx := process.NewProcessContext() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + routers := httputil.NewRouters() + err := msc2836.Enable(cfg, cm, routers, rsAPI, nil, userAPI, nil) if err != nil { t.Fatalf("failed to enable MSC2836: %s", err) } for _, ev := range events { hooks.Run(hooks.KindNewEventPersisted, ev) } - return base.PublicClientAPIMux + return routers.Client } type fledglingEvent struct { @@ -596,7 +597,7 @@ func mustCreateEvent(t *testing.T, ev fledglingEvent) (result *gomatrixserverlib } // make sure the origin_server_ts changes so we can test recency time.Sleep(1 * time.Millisecond) - signedEvent, err := eb.Build(time.Now(), gomatrixserverlib.ServerName("localhost"), "ed25519:test", key, roomVer) + signedEvent, err := eb.Build(time.Now(), spec.ServerName("localhost"), "ed25519:test", key, roomVer) if err != nil { t.Fatalf("mustCreateEvent: failed to sign event: %s", err) } diff --git a/setup/mscs/msc2836/storage.go b/setup/mscs/msc2836/storage.go index 827e82f70..4f348192c 100644 --- a/setup/mscs/msc2836/storage.go +++ b/setup/mscs/msc2836/storage.go @@ -8,15 +8,15 @@ import ( "encoding/json" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) type eventInfo struct { EventID string - OriginServerTS gomatrixserverlib.Timestamp + OriginServerTS spec.Timestamp RoomID string } @@ -59,17 +59,17 @@ type DB struct { } // NewDatabase loads the database for msc2836 -func NewDatabase(base *base.BaseDendrite, dbOpts *config.DatabaseOptions) (Database, error) { +func NewDatabase(conMan sqlutil.Connections, dbOpts *config.DatabaseOptions) (Database, error) { if dbOpts.ConnectionString.IsPostgres() { - return newPostgresDatabase(base, dbOpts) + return newPostgresDatabase(conMan, dbOpts) } - return newSQLiteDatabase(base, dbOpts) + return newSQLiteDatabase(conMan, dbOpts) } -func newPostgresDatabase(base *base.BaseDendrite, dbOpts *config.DatabaseOptions) (Database, error) { +func newPostgresDatabase(conMan sqlutil.Connections, dbOpts *config.DatabaseOptions) (Database, error) { d := DB{} var err error - if d.db, d.writer, err = base.DatabaseConnection(dbOpts, sqlutil.NewDummyWriter()); err != nil { + if d.db, d.writer, err = conMan.Connection(dbOpts); err != nil { return nil, err } _, err = d.db.Exec(` @@ -144,10 +144,10 @@ func newPostgresDatabase(base *base.BaseDendrite, dbOpts *config.DatabaseOptions return &d, err } -func newSQLiteDatabase(base *base.BaseDendrite, dbOpts *config.DatabaseOptions) (Database, error) { +func newSQLiteDatabase(conMan sqlutil.Connections, dbOpts *config.DatabaseOptions) (Database, error) { d := DB{} var err error - if d.db, d.writer, err = base.DatabaseConnection(dbOpts, sqlutil.NewExclusiveWriter()); err != nil { + if d.db, d.writer, err = conMan.Connection(dbOpts); err != nil { return nil, err } _, err = d.db.Exec(` @@ -351,8 +351,8 @@ func roomIDAndServers(ev *gomatrixserverlib.HeaderedEvent) (roomID string, serve func extractChildMetadata(ev *gomatrixserverlib.HeaderedEvent) (count int, hash []byte) { unsigned := struct { - Counts map[string]int `json:"children"` - Hash gomatrixserverlib.Base64Bytes `json:"children_hash"` + Counts map[string]int `json:"children"` + Hash spec.Base64Bytes `json:"children_hash"` }{} if err := json.Unmarshal(ev.Unsigned(), &unsigned); err != nil { // expected if there is no unsigned field at all diff --git a/setup/mscs/msc2946/msc2946.go b/setup/mscs/msc2946/msc2946.go index 56c063598..80d6c5980 100644 --- a/setup/mscs/msc2946/msc2946.go +++ b/setup/mscs/msc2946/msc2946.go @@ -33,9 +33,11 @@ import ( "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/httputil" roomserver "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/dendrite/setup/base" + "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/tidwall/gjson" ) @@ -48,23 +50,23 @@ const ( ) type MSC2946ClientResponse struct { - Rooms []gomatrixserverlib.MSC2946Room `json:"rooms"` - NextBatch string `json:"next_batch,omitempty"` + Rooms []fclient.MSC2946Room `json:"rooms"` + NextBatch string `json:"next_batch,omitempty"` } // Enable this MSC func Enable( - base *base.BaseDendrite, rsAPI roomserver.RoomserverInternalAPI, userAPI userapi.UserInternalAPI, + cfg *config.Dendrite, routers httputil.Routers, rsAPI roomserver.RoomserverInternalAPI, userAPI userapi.UserInternalAPI, fsAPI fs.FederationInternalAPI, keyRing gomatrixserverlib.JSONVerifier, cache caching.SpaceSummaryRoomsCache, ) error { - clientAPI := httputil.MakeAuthAPI("spaces", userAPI, spacesHandler(rsAPI, fsAPI, cache, base.Cfg.Global.ServerName), httputil.WithAllowGuests()) - base.PublicClientAPIMux.Handle("/v1/rooms/{roomID}/hierarchy", clientAPI).Methods(http.MethodGet, http.MethodOptions) - base.PublicClientAPIMux.Handle("/unstable/org.matrix.msc2946/rooms/{roomID}/hierarchy", clientAPI).Methods(http.MethodGet, http.MethodOptions) + clientAPI := httputil.MakeAuthAPI("spaces", userAPI, spacesHandler(rsAPI, fsAPI, cache, cfg.Global.ServerName), httputil.WithAllowGuests()) + routers.Client.Handle("/v1/rooms/{roomID}/hierarchy", clientAPI).Methods(http.MethodGet, http.MethodOptions) + routers.Client.Handle("/unstable/org.matrix.msc2946/rooms/{roomID}/hierarchy", clientAPI).Methods(http.MethodGet, http.MethodOptions) fedAPI := httputil.MakeExternalAPI( "msc2946_fed_spaces", func(req *http.Request) util.JSONResponse { - fedReq, errResp := gomatrixserverlib.VerifyHTTPRequest( - req, time.Now(), base.Cfg.Global.ServerName, base.Cfg.Global.IsLocalServerName, keyRing, + fedReq, errResp := fclient.VerifyHTTPRequest( + req, time.Now(), cfg.Global.ServerName, cfg.Global.IsLocalServerName, keyRing, ) if fedReq == nil { return errResp @@ -75,19 +77,19 @@ func Enable( return util.ErrorResponse(err) } roomID := params["roomID"] - return federatedSpacesHandler(req.Context(), fedReq, roomID, cache, rsAPI, fsAPI, base.Cfg.Global.ServerName) + return federatedSpacesHandler(req.Context(), fedReq, roomID, cache, rsAPI, fsAPI, cfg.Global.ServerName) }, ) - base.PublicFederationAPIMux.Handle("/unstable/org.matrix.msc2946/hierarchy/{roomID}", fedAPI).Methods(http.MethodGet) - base.PublicFederationAPIMux.Handle("/v1/hierarchy/{roomID}", fedAPI).Methods(http.MethodGet) + routers.Federation.Handle("/unstable/org.matrix.msc2946/hierarchy/{roomID}", fedAPI).Methods(http.MethodGet) + routers.Federation.Handle("/v1/hierarchy/{roomID}", fedAPI).Methods(http.MethodGet) return nil } func federatedSpacesHandler( - ctx context.Context, fedReq *gomatrixserverlib.FederationRequest, roomID string, + ctx context.Context, fedReq *fclient.FederationRequest, roomID string, cache caching.SpaceSummaryRoomsCache, rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationInternalAPI, - thisServer gomatrixserverlib.ServerName, + thisServer spec.ServerName, ) util.JSONResponse { u, err := url.Parse(fedReq.RequestURI()) if err != nil { @@ -121,7 +123,7 @@ func spacesHandler( rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationInternalAPI, cache caching.SpaceSummaryRoomsCache, - thisServer gomatrixserverlib.ServerName, + thisServer spec.ServerName, ) func(*http.Request, *userapi.Device) util.JSONResponse { // declared outside the returned handler so it persists between calls // TODO: clear based on... time? @@ -161,8 +163,8 @@ type paginationInfo struct { type walker struct { rootRoomID string caller *userapi.Device - serverName gomatrixserverlib.ServerName - thisServer gomatrixserverlib.ServerName + serverName spec.ServerName + thisServer spec.ServerName rsAPI roomserver.RoomserverInternalAPI fsAPI fs.FederationInternalAPI ctx context.Context @@ -222,7 +224,7 @@ func (w *walker) walk() util.JSONResponse { } } - var discoveredRooms []gomatrixserverlib.MSC2946Room + var discoveredRooms []fclient.MSC2946Room var cache *paginationInfo if w.paginationToken != "" { @@ -268,14 +270,14 @@ func (w *walker) walk() util.JSONResponse { // if this room is not a space room, skip. var roomType string - create := w.stateEvent(rv.roomID, gomatrixserverlib.MRoomCreate, "") + create := w.stateEvent(rv.roomID, spec.MRoomCreate, "") if create != nil { // escape the `.`s so gjson doesn't think it's nested roomType = gjson.GetBytes(create.Content(), strings.ReplaceAll(ConstCreateEventContentKey, ".", `\.`)).Str } // Collect rooms/events to send back (either locally or fetched via federation) - var discoveredChildEvents []gomatrixserverlib.MSC2946StrippedEvent + var discoveredChildEvents []fclient.MSC2946StrippedEvent // If we know about this room and the caller is authorised (joined/world_readable) then pull // events locally @@ -306,7 +308,7 @@ func (w *walker) walk() util.JSONResponse { pubRoom := w.publicRoomsChunk(rv.roomID) - discoveredRooms = append(discoveredRooms, gomatrixserverlib.MSC2946Room{ + discoveredRooms = append(discoveredRooms, fclient.MSC2946Room{ PublicRoom: *pubRoom, RoomType: roomType, ChildrenState: events, @@ -379,7 +381,7 @@ func (w *walker) walk() util.JSONResponse { } return util.JSONResponse{ Code: 200, - JSON: gomatrixserverlib.MSC2946SpacesResponse{ + JSON: fclient.MSC2946SpacesResponse{ Room: discoveredRooms[0], Children: discoveredRooms[1:], }, @@ -402,7 +404,7 @@ func (w *walker) stateEvent(roomID, evType, stateKey string) *gomatrixserverlib. return queryRes.StateEvents[tuple] } -func (w *walker) publicRoomsChunk(roomID string) *gomatrixserverlib.PublicRoom { +func (w *walker) publicRoomsChunk(roomID string) *fclient.PublicRoom { pubRooms, err := roomserver.PopulatePublicRooms(w.ctx, []string{roomID}, w.rsAPI) if err != nil { util.GetLogger(w.ctx).WithError(err).Error("failed to PopulatePublicRooms") @@ -416,7 +418,7 @@ func (w *walker) publicRoomsChunk(roomID string) *gomatrixserverlib.PublicRoom { // federatedRoomInfo returns more of the spaces graph from another server. Returns nil if this was // unsuccessful. -func (w *walker) federatedRoomInfo(roomID string, vias []string) *gomatrixserverlib.MSC2946SpacesResponse { +func (w *walker) federatedRoomInfo(roomID string, vias []string) *fclient.MSC2946SpacesResponse { // only do federated requests for client requests if w.caller == nil { return nil @@ -433,19 +435,19 @@ func (w *walker) federatedRoomInfo(roomID string, vias []string) *gomatrixserver if serverName == string(w.thisServer) { continue } - res, err := w.fsAPI.MSC2946Spaces(ctx, w.thisServer, gomatrixserverlib.ServerName(serverName), roomID, w.suggestedOnly) + res, err := w.fsAPI.MSC2946Spaces(ctx, w.thisServer, spec.ServerName(serverName), roomID, w.suggestedOnly) if err != nil { util.GetLogger(w.ctx).WithError(err).Warnf("failed to call MSC2946Spaces on server %s", serverName) continue } // ensure nil slices are empty as we send this to the client sometimes if res.Room.ChildrenState == nil { - res.Room.ChildrenState = []gomatrixserverlib.MSC2946StrippedEvent{} + res.Room.ChildrenState = []fclient.MSC2946StrippedEvent{} } for i := 0; i < len(res.Children); i++ { child := res.Children[i] if child.ChildrenState == nil { - child.ChildrenState = []gomatrixserverlib.MSC2946StrippedEvent{} + child.ChildrenState = []fclient.MSC2946StrippedEvent{} } res.Children[i] = child } @@ -483,11 +485,11 @@ func (w *walker) authorised(roomID, parentRoomID string) (authed, isJoinedOrInvi func (w *walker) authorisedServer(roomID string) bool { // Check history visibility / join rules first hisVisTuple := gomatrixserverlib.StateKeyTuple{ - EventType: gomatrixserverlib.MRoomHistoryVisibility, + EventType: spec.MRoomHistoryVisibility, StateKey: "", } joinRuleTuple := gomatrixserverlib.StateKeyTuple{ - EventType: gomatrixserverlib.MRoomJoinRules, + EventType: spec.MRoomJoinRules, StateKey: "", } var queryRoomRes roomserver.QueryCurrentStateResponse @@ -521,11 +523,11 @@ func (w *walker) authorisedServer(roomID string) bool { return false } - if rule == gomatrixserverlib.Public || rule == gomatrixserverlib.Knock { + if rule == spec.Public || rule == spec.Knock { return true } - if rule == gomatrixserverlib.Restricted { + if rule == spec.Restricted { allowJoinedToRoomIDs = append(allowJoinedToRoomIDs, w.restrictedJoinRuleAllowedRooms(joinRuleEv, "m.room_membership")...) } } @@ -555,15 +557,15 @@ func (w *walker) authorisedServer(roomID string) bool { // Failing that, if the room has a restricted join rule and belongs to the space parent listed, it will return true. func (w *walker) authorisedUser(roomID, parentRoomID string) (authed bool, isJoinedOrInvited bool) { hisVisTuple := gomatrixserverlib.StateKeyTuple{ - EventType: gomatrixserverlib.MRoomHistoryVisibility, + EventType: spec.MRoomHistoryVisibility, StateKey: "", } joinRuleTuple := gomatrixserverlib.StateKeyTuple{ - EventType: gomatrixserverlib.MRoomJoinRules, + EventType: spec.MRoomJoinRules, StateKey: "", } roomMemberTuple := gomatrixserverlib.StateKeyTuple{ - EventType: gomatrixserverlib.MRoomMember, + EventType: spec.MRoomMember, StateKey: w.caller.UserID, } var queryRes roomserver.QueryCurrentStateResponse @@ -580,7 +582,7 @@ func (w *walker) authorisedUser(roomID, parentRoomID string) (authed bool, isJoi memberEv := queryRes.StateEvents[roomMemberTuple] if memberEv != nil { membership, _ := memberEv.Membership() - if membership == gomatrixserverlib.Join || membership == gomatrixserverlib.Invite { + if membership == spec.Join || membership == spec.Invite { return true, true } } @@ -597,9 +599,9 @@ func (w *walker) authorisedUser(roomID, parentRoomID string) (authed bool, isJoi rule, ruleErr := joinRuleEv.JoinRule() if ruleErr != nil { util.GetLogger(w.ctx).WithError(ruleErr).WithField("parent_room_id", parentRoomID).Warn("failed to get join rule") - } else if rule == gomatrixserverlib.Public || rule == gomatrixserverlib.Knock { + } else if rule == spec.Public || rule == spec.Knock { allowed = true - } else if rule == gomatrixserverlib.Restricted { + } else if rule == spec.Restricted { allowedRoomIDs := w.restrictedJoinRuleAllowedRooms(joinRuleEv, "m.room_membership") // check parent is in the allowed set for _, a := range allowedRoomIDs { @@ -624,7 +626,7 @@ func (w *walker) authorisedUser(roomID, parentRoomID string) (authed bool, isJoi memberEv = queryRes2.StateEvents[roomMemberTuple] if memberEv != nil { membership, _ := memberEv.Membership() - if membership == gomatrixserverlib.Join { + if membership == spec.Join { return true, false } } @@ -636,7 +638,7 @@ func (w *walker) authorisedUser(roomID, parentRoomID string) (authed bool, isJoi func (w *walker) restrictedJoinRuleAllowedRooms(joinRuleEv *gomatrixserverlib.HeaderedEvent, allowType string) (allows []string) { rule, _ := joinRuleEv.JoinRule() - if rule != gomatrixserverlib.Restricted { + if rule != spec.Restricted { return nil } var jrContent gomatrixserverlib.JoinRuleContent @@ -653,9 +655,9 @@ func (w *walker) restrictedJoinRuleAllowedRooms(joinRuleEv *gomatrixserverlib.He } // references returns all child references pointing to or from this room. -func (w *walker) childReferences(roomID string) ([]gomatrixserverlib.MSC2946StrippedEvent, error) { +func (w *walker) childReferences(roomID string) ([]fclient.MSC2946StrippedEvent, error) { createTuple := gomatrixserverlib.StateKeyTuple{ - EventType: gomatrixserverlib.MRoomCreate, + EventType: spec.MRoomCreate, StateKey: "", } var res roomserver.QueryCurrentStateResponse @@ -678,12 +680,12 @@ func (w *walker) childReferences(roomID string) ([]gomatrixserverlib.MSC2946Stri // escape the `.`s so gjson doesn't think it's nested roomType := gjson.GetBytes(res.StateEvents[createTuple].Content(), strings.ReplaceAll(ConstCreateEventContentKey, ".", `\.`)).Str if roomType != ConstCreateEventContentValueSpace { - return []gomatrixserverlib.MSC2946StrippedEvent{}, nil + return []fclient.MSC2946StrippedEvent{}, nil } } delete(res.StateEvents, createTuple) - el := make([]gomatrixserverlib.MSC2946StrippedEvent, 0, len(res.StateEvents)) + el := make([]fclient.MSC2946StrippedEvent, 0, len(res.StateEvents)) for _, ev := range res.StateEvents { content := gjson.ParseBytes(ev.Content()) // only return events that have a `via` key as per MSC1772 @@ -720,11 +722,11 @@ func (s set) isSet(val string) bool { return ok } -func stripped(ev *gomatrixserverlib.Event) *gomatrixserverlib.MSC2946StrippedEvent { +func stripped(ev *gomatrixserverlib.Event) *fclient.MSC2946StrippedEvent { if ev.StateKey() == nil { return nil } - return &gomatrixserverlib.MSC2946StrippedEvent{ + return &fclient.MSC2946StrippedEvent{ Type: ev.Type(), StateKey: *ev.StateKey(), Content: ev.Content(), diff --git a/setup/mscs/mscs.go b/setup/mscs/mscs.go index 35b7bba3b..9cd5eed1c 100644 --- a/setup/mscs/mscs.go +++ b/setup/mscs/mscs.go @@ -19,30 +19,33 @@ import ( "context" "fmt" + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup" - "github.com/matrix-org/dendrite/setup/base" + "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/mscs/msc2836" "github.com/matrix-org/dendrite/setup/mscs/msc2946" "github.com/matrix-org/util" ) // Enable MSCs - returns an error on unknown MSCs -func Enable(base *base.BaseDendrite, monolith *setup.Monolith) error { - for _, msc := range base.Cfg.MSCs.MSCs { +func Enable(cfg *config.Dendrite, cm sqlutil.Connections, routers httputil.Routers, monolith *setup.Monolith, caches *caching.Caches) error { + for _, msc := range cfg.MSCs.MSCs { util.GetLogger(context.Background()).WithField("msc", msc).Info("Enabling MSC") - if err := EnableMSC(base, monolith, msc); err != nil { + if err := EnableMSC(cfg, cm, routers, monolith, msc, caches); err != nil { return err } } return nil } -func EnableMSC(base *base.BaseDendrite, monolith *setup.Monolith, msc string) error { +func EnableMSC(cfg *config.Dendrite, cm sqlutil.Connections, routers httputil.Routers, monolith *setup.Monolith, msc string, caches *caching.Caches) error { switch msc { case "msc2836": - return msc2836.Enable(base, monolith.RoomserverAPI, monolith.FederationAPI, monolith.UserAPI, monolith.KeyRing) + return msc2836.Enable(cfg, cm, routers, monolith.RoomserverAPI, monolith.FederationAPI, monolith.UserAPI, monolith.KeyRing) case "msc2946": - return msc2946.Enable(base, monolith.RoomserverAPI, monolith.UserAPI, monolith.FederationAPI, monolith.KeyRing, base.Caches) + return msc2946.Enable(cfg, routers, monolith.RoomserverAPI, monolith.UserAPI, monolith.FederationAPI, monolith.KeyRing, caches) case "msc2444": // enabled inside federationapi case "msc2753": // enabled inside clientapi default: diff --git a/syncapi/consumers/clientapi.go b/syncapi/consumers/clientapi.go index 735f6718c..3ed455e9f 100644 --- a/syncapi/consumers/clientapi.go +++ b/syncapi/consumers/clientapi.go @@ -21,7 +21,7 @@ import ( "time" "github.com/getsentry/sentry-go" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/nats-io/nats.go" "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus" @@ -49,8 +49,8 @@ type OutputClientDataConsumer struct { db storage.Database stream streams.StreamProvider notifier *notifier.Notifier - serverName gomatrixserverlib.ServerName - fts *fulltext.Search + serverName spec.ServerName + fts fulltext.Indexer cfg *config.SyncAPI } @@ -121,9 +121,9 @@ func (s *OutputClientDataConsumer) Start() error { switch ev.Type() { case "m.room.message": e.Content = gjson.GetBytes(ev.Content(), "body").String() - case gomatrixserverlib.MRoomName: + case spec.MRoomName: e.Content = gjson.GetBytes(ev.Content(), "name").String() - case gomatrixserverlib.MRoomTopic: + case spec.MRoomTopic: e.Content = gjson.GetBytes(ev.Content(), "topic").String() default: continue diff --git a/syncapi/consumers/presence.go b/syncapi/consumers/presence.go index 6e3150c29..c7c0866ba 100644 --- a/syncapi/consumers/presence.go +++ b/syncapi/consumers/presence.go @@ -26,7 +26,7 @@ import ( "github.com/matrix-org/dendrite/syncapi/streams" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/nats-io/nats.go" "github.com/sirupsen/logrus" ) @@ -108,7 +108,7 @@ func (s *PresenceConsumer) Start() error { for i := range deviceRes.Devices { if int64(presence.LastActiveTS) < deviceRes.Devices[i].LastSeenTS { - presence.LastActiveTS = gomatrixserverlib.Timestamp(deviceRes.Devices[i].LastSeenTS) + presence.LastActiveTS = spec.Timestamp(deviceRes.Devices[i].LastSeenTS) } } @@ -161,11 +161,11 @@ func (s *PresenceConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool // already checked, so no need to check error p, _ := types.PresenceFromString(presence) - s.EmitPresence(ctx, userID, p, statusMsg, gomatrixserverlib.Timestamp(ts), fromSync) + s.EmitPresence(ctx, userID, p, statusMsg, spec.Timestamp(ts), fromSync) return true } -func (s *PresenceConsumer) EmitPresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, ts gomatrixserverlib.Timestamp, fromSync bool) { +func (s *PresenceConsumer) EmitPresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, ts spec.Timestamp, fromSync bool) { pos, err := s.db.UpdatePresence(ctx, userID, presence, statusMsg, ts, fromSync) if err != nil { logrus.WithError(err).WithField("user", userID).WithField("presence", presence).Warn("failed to updated presence for user") diff --git a/syncapi/consumers/receipts.go b/syncapi/consumers/receipts.go index e39d43f94..69571f90c 100644 --- a/syncapi/consumers/receipts.go +++ b/syncapi/consumers/receipts.go @@ -19,7 +19,6 @@ import ( "strconv" "github.com/getsentry/sentry-go" - "github.com/matrix-org/gomatrixserverlib" "github.com/nats-io/nats.go" log "github.com/sirupsen/logrus" @@ -30,6 +29,7 @@ import ( "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/streams" "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/gomatrixserverlib/spec" ) // OutputReceiptEventConsumer consumes events that originated in the EDU server. @@ -89,7 +89,7 @@ func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msgs []*nats return true } - output.Timestamp = gomatrixserverlib.Timestamp(timestamp) + output.Timestamp = spec.Timestamp(timestamp) streamPos, err := s.db.StoreReceipt( s.ctx, diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index a8d4d2b2c..c842c0c38 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -22,6 +22,7 @@ import ( "github.com/getsentry/sentry-go" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/nats-io/nats.go" "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus" @@ -51,7 +52,7 @@ type OutputRoomEventConsumer struct { pduStream streams.StreamProvider inviteStream streams.StreamProvider notifier *notifier.Notifier - fts *fulltext.Search + fts fulltext.Indexer } // NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call Start() to begin consuming from room servers. @@ -108,7 +109,7 @@ func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Ms // Ignore redaction events. We will add them to the database when they are // validated (when we receive OutputTypeRedactedEvent) event := output.NewRoomEvent.Event - if event.Type() == gomatrixserverlib.MRoomRedaction && event.StateKey() == nil { + if event.Type() == spec.MRoomRedaction && event.StateKey() == nil { // in the special case where the event redacts itself, just pass the message through because // we will never see the other part of the pair if event.Redacts() != event.EventID() { @@ -362,7 +363,7 @@ func (s *OutputRoomEventConsumer) onOldRoomEvent( } func (s *OutputRoomEventConsumer) notifyJoinedPeeks(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent, sp types.StreamPosition) (types.StreamPosition, error) { - if ev.Type() != gomatrixserverlib.MRoomMember { + if ev.Type() != spec.MRoomMember { return sp, nil } membership, err := ev.Membership() @@ -370,7 +371,7 @@ func (s *OutputRoomEventConsumer) notifyJoinedPeeks(ctx context.Context, ev *gom return sp, fmt.Errorf("ev.Membership: %w", err) } // TODO: check that it's a join and not a profile change (means unmarshalling prev_content) - if membership == gomatrixserverlib.Join { + if membership == spec.Join { // check it's a local join if _, _, err := s.cfg.Matrix.SplitLocalID('@', *ev.StateKey()); err != nil { return sp, nil @@ -433,7 +434,7 @@ func (s *OutputRoomEventConsumer) onRetireInviteEvent( // Only notify clients about retired invite events, if the user didn't accept the invite. // The PDU stream will also receive an event about accepting the invitation, so there should // be a "smooth" transition from invite -> join, and not invite -> leave -> join - if msg.Membership == gomatrixserverlib.Join { + if msg.Membership == spec.Join { return } @@ -544,11 +545,11 @@ func (s *OutputRoomEventConsumer) writeFTS(ev *gomatrixserverlib.HeaderedEvent, switch ev.Type() { case "m.room.message": e.Content = gjson.GetBytes(ev.Content(), "body").String() - case gomatrixserverlib.MRoomName: + case spec.MRoomName: e.Content = gjson.GetBytes(ev.Content(), "name").String() - case gomatrixserverlib.MRoomTopic: + case spec.MRoomTopic: e.Content = gjson.GetBytes(ev.Content(), "topic").String() - case gomatrixserverlib.MRoomRedaction: + case spec.MRoomRedaction: log.Tracef("Redacting event: %s", ev.Redacts()) if err := s.fts.Delete(ev.Redacts()); err != nil { return fmt.Errorf("failed to delete entry from fulltext index: %w", err) diff --git a/syncapi/consumers/sendtodevice.go b/syncapi/consumers/sendtodevice.go index 32208c585..7f387dc09 100644 --- a/syncapi/consumers/sendtodevice.go +++ b/syncapi/consumers/sendtodevice.go @@ -20,6 +20,7 @@ import ( "github.com/getsentry/sentry-go" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/nats-io/nats.go" log "github.com/sirupsen/logrus" @@ -43,7 +44,7 @@ type OutputSendToDeviceEventConsumer struct { topic string db storage.Database userAPI api.SyncKeyAPI - isLocalServerName func(gomatrixserverlib.ServerName) bool + isLocalServerName func(spec.ServerName) bool stream streams.StreamProvider notifier *notifier.Notifier } diff --git a/syncapi/internal/history_visibility.go b/syncapi/internal/history_visibility.go index ee695f0f5..10d7383b2 100644 --- a/syncapi/internal/history_visibility.go +++ b/syncapi/internal/history_visibility.go @@ -20,6 +20,7 @@ import ( "time" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/prometheus/client_golang/prometheus" "github.com/sirupsen/logrus" "github.com/tidwall/gjson" @@ -50,7 +51,7 @@ var calculateHistoryVisibilityDuration = prometheus.NewHistogramVec( ) var historyVisibilityPriority = map[gomatrixserverlib.HistoryVisibility]uint8{ - gomatrixserverlib.WorldReadable: 0, + spec.WorldReadable: 0, gomatrixserverlib.HistoryVisibilityShared: 1, gomatrixserverlib.HistoryVisibilityInvited: 2, gomatrixserverlib.HistoryVisibilityJoined: 3, @@ -72,23 +73,23 @@ func (ev eventVisibility) allowed() (allowed bool) { return true case gomatrixserverlib.HistoryVisibilityJoined: // If the user’s membership was join, allow. - if ev.membershipAtEvent == gomatrixserverlib.Join { + if ev.membershipAtEvent == spec.Join { return true } return false case gomatrixserverlib.HistoryVisibilityShared: // If the user’s membership was join, allow. // If history_visibility was set to shared, and the user joined the room at any point after the event was sent, allow. - if ev.membershipAtEvent == gomatrixserverlib.Join || ev.membershipCurrent == gomatrixserverlib.Join { + if ev.membershipAtEvent == spec.Join || ev.membershipCurrent == spec.Join { return true } return false case gomatrixserverlib.HistoryVisibilityInvited: // If the user’s membership was join, allow. - if ev.membershipAtEvent == gomatrixserverlib.Join { + if ev.membershipAtEvent == spec.Join { return true } - if ev.membershipAtEvent == gomatrixserverlib.Invite { + if ev.membershipAtEvent == spec.Invite { return true } return false @@ -132,7 +133,7 @@ func ApplyHistoryVisibilityFilter( } } // NOTSPEC: Always allow user to see their own membership events (spec contains more "rules") - if ev.Type() == gomatrixserverlib.MRoomMember && ev.StateKeyEquals(userID) { + if ev.Type() == spec.MRoomMember && ev.StateKeyEquals(userID) { eventsFiltered = append(eventsFiltered, ev) continue } @@ -195,7 +196,7 @@ func visibilityForEvents( for _, event := range events { eventID := event.EventID() vis := eventVisibility{ - membershipAtEvent: gomatrixserverlib.Leave, // default to leave, to not expose events by accident + membershipAtEvent: spec.Leave, // default to leave, to not expose events by accident visibility: event.Visibility, } ev, ok := membershipResp.Membership[eventID] diff --git a/syncapi/internal/keychange.go b/syncapi/internal/keychange.go index e7f677c85..ad5935cdc 100644 --- a/syncapi/internal/keychange.go +++ b/syncapi/internal/keychange.go @@ -20,12 +20,14 @@ import ( keytypes "github.com/matrix-org/dendrite/userapi/types" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/sirupsen/logrus" "github.com/tidwall/gjson" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/storage" + "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/userapi/api" ) @@ -158,7 +160,7 @@ func TrackChangedUsers( RoomIDs: newlyLeftRooms, StateTuples: []gomatrixserverlib.StateKeyTuple{ { - EventType: gomatrixserverlib.MRoomMember, + EventType: spec.MRoomMember, StateKey: "*", }, }, @@ -169,7 +171,7 @@ func TrackChangedUsers( } for _, state := range stateRes.Rooms { for tuple, membership := range state { - if membership != gomatrixserverlib.Join { + if membership != spec.Join { continue } queryRes.UserIDsToCount[tuple.StateKey]-- @@ -200,7 +202,7 @@ func TrackChangedUsers( RoomIDs: newlyJoinedRooms, StateTuples: []gomatrixserverlib.StateKeyTuple{ { - EventType: gomatrixserverlib.MRoomMember, + EventType: spec.MRoomMember, StateKey: "*", }, }, @@ -211,7 +213,7 @@ func TrackChangedUsers( } for _, state := range stateRes.Rooms { for tuple, membership := range state { - if membership != gomatrixserverlib.Join { + if membership != spec.Join { continue } // new user who we weren't previously sharing rooms with @@ -278,11 +280,11 @@ func leftRooms(res *types.Response) []string { return roomIDs } -func membershipEventPresent(events []gomatrixserverlib.ClientEvent, userID string) bool { +func membershipEventPresent(events []synctypes.ClientEvent, userID string) bool { for _, ev := range events { // it's enough to know that we have our member event here, don't need to check membership content // as it's implied by being in the respective section of the sync response. - if ev.Type == gomatrixserverlib.MRoomMember && ev.StateKey != nil && *ev.StateKey == userID { + if ev.Type == spec.MRoomMember && ev.StateKey != nil && *ev.StateKey == userID { // ignore e.g. join -> join changes if gjson.GetBytes(ev.Unsigned, "prev_content.membership").Str == gjson.GetBytes(ev.Content, "membership").Str { continue @@ -301,7 +303,7 @@ func membershipEventPresent(events []gomatrixserverlib.ClientEvent, userID strin func membershipEvents(res *types.Response) (joinUserIDs, leaveUserIDs []string) { for _, room := range res.Rooms.Join { for _, ev := range room.Timeline.Events { - if ev.Type == gomatrixserverlib.MRoomMember && ev.StateKey != nil { + if ev.Type == spec.MRoomMember && ev.StateKey != nil { if strings.Contains(string(ev.Content), `"join"`) { joinUserIDs = append(joinUserIDs, *ev.StateKey) } else if strings.Contains(string(ev.Content), `"invite"`) { diff --git a/syncapi/internal/keychange_test.go b/syncapi/internal/keychange_test.go index 4bb851668..119549045 100644 --- a/syncapi/internal/keychange_test.go +++ b/syncapi/internal/keychange_test.go @@ -7,9 +7,11 @@ import ( "testing" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" ) @@ -75,12 +77,12 @@ func (s *mockRoomserverAPI) QueryRoomsForUser(ctx context.Context, req *api.Quer // QueryBulkStateContent does a bulk query for state event content in the given rooms. func (s *mockRoomserverAPI) QueryBulkStateContent(ctx context.Context, req *api.QueryBulkStateContentRequest, res *api.QueryBulkStateContentResponse) error { res.Rooms = make(map[string]map[gomatrixserverlib.StateKeyTuple]string) - if req.AllowWildcards && len(req.StateTuples) == 1 && req.StateTuples[0].EventType == gomatrixserverlib.MRoomMember && req.StateTuples[0].StateKey == "*" { + if req.AllowWildcards && len(req.StateTuples) == 1 && req.StateTuples[0].EventType == spec.MRoomMember && req.StateTuples[0].StateKey == "*" { for _, roomID := range req.RoomIDs { res.Rooms[roomID] = make(map[gomatrixserverlib.StateKeyTuple]string) for _, userID := range s.roomIDToJoinedMembers[roomID] { res.Rooms[roomID][gomatrixserverlib.StateKeyTuple{ - EventType: gomatrixserverlib.MRoomMember, + EventType: spec.MRoomMember, StateKey: userID, }] = "join" } @@ -159,7 +161,7 @@ func assertCatchup(t *testing.T, hasNew bool, syncResponse *types.Response, want func joinResponseWithRooms(syncResponse *types.Response, userID string, roomIDs []string) *types.Response { for _, roomID := range roomIDs { - roomEvents := []gomatrixserverlib.ClientEvent{ + roomEvents := []synctypes.ClientEvent{ { Type: "m.room.member", StateKey: &userID, @@ -182,7 +184,7 @@ func joinResponseWithRooms(syncResponse *types.Response, userID string, roomIDs func leaveResponseWithRooms(syncResponse *types.Response, userID string, roomIDs []string) *types.Response { for _, roomID := range roomIDs { - roomEvents := []gomatrixserverlib.ClientEvent{ + roomEvents := []synctypes.ClientEvent{ { Type: "m.room.member", StateKey: &userID, @@ -299,7 +301,7 @@ func TestKeyChangeCatchupNoNewJoinsButMessages(t *testing.T) { roomID := "!TestKeyChangeCatchupNoNewJoinsButMessages:bar" syncResponse := types.NewResponse() empty := "" - roomStateEvents := []gomatrixserverlib.ClientEvent{ + roomStateEvents := []synctypes.ClientEvent{ { Type: "m.room.name", StateKey: &empty, @@ -309,7 +311,7 @@ func TestKeyChangeCatchupNoNewJoinsButMessages(t *testing.T) { Content: []byte(`{"name":"The Room Name"}`), }, } - roomTimelineEvents := []gomatrixserverlib.ClientEvent{ + roomTimelineEvents := []synctypes.ClientEvent{ { Type: "m.room.message", EventID: "$something1:here", @@ -402,7 +404,7 @@ func TestKeyChangeCatchupChangeAndLeftSameRoom(t *testing.T) { newShareUser2 := "@bobby:localhost" roomID := "!join:bar" syncResponse := types.NewResponse() - roomEvents := []gomatrixserverlib.ClientEvent{ + roomEvents := []synctypes.ClientEvent{ { Type: "m.room.member", StateKey: &syncingUser, diff --git a/syncapi/notifier/notifier.go b/syncapi/notifier/notifier.go index 27f7c37ba..81b0e7f9e 100644 --- a/syncapi/notifier/notifier.go +++ b/syncapi/notifier/notifier.go @@ -23,6 +23,7 @@ import ( "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" log "github.com/sirupsen/logrus" ) @@ -112,16 +113,16 @@ func (n *Notifier) OnNewEvent( } else { // Keep the joined user map up-to-date switch membership { - case gomatrixserverlib.Invite: + case spec.Invite: usersToNotify = append(usersToNotify, targetUserID) - case gomatrixserverlib.Join: + case spec.Join: // Manually append the new user's ID so they get notified // along all members in the room usersToNotify = append(usersToNotify, targetUserID) n._addJoinedUser(ev.RoomID(), targetUserID) - case gomatrixserverlib.Leave: + case spec.Leave: fallthrough - case gomatrixserverlib.Ban: + case spec.Ban: n._removeJoinedUser(ev.RoomID(), targetUserID) } } diff --git a/syncapi/producers/federationapi_presence.go b/syncapi/producers/federationapi_presence.go index dc03457e3..eab1b0b25 100644 --- a/syncapi/producers/federationapi_presence.go +++ b/syncapi/producers/federationapi_presence.go @@ -20,7 +20,7 @@ import ( "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/nats-io/nats.go" ) @@ -37,7 +37,7 @@ func (f *FederationAPIPresenceProducer) SendPresence( msg.Header.Set(jetstream.UserID, userID) msg.Header.Set("presence", presence.String()) msg.Header.Set("from_sync", "true") // only update last_active_ts and presence - msg.Header.Set("last_active_ts", strconv.Itoa(int(gomatrixserverlib.AsTimestamp(time.Now())))) + msg.Header.Set("last_active_ts", strconv.Itoa(int(spec.AsTimestamp(time.Now())))) if statusMsg != nil { msg.Header.Set("status_msg", *statusMsg) diff --git a/syncapi/routing/context.go b/syncapi/routing/context.go index 76f003671..afd61eae8 100644 --- a/syncapi/routing/context.go +++ b/syncapi/routing/context.go @@ -29,20 +29,22 @@ import ( roomserver "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/internal" "github.com/matrix-org/dendrite/syncapi/storage" + "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/sirupsen/logrus" ) type ContextRespsonse struct { - End string `json:"end"` - Event *gomatrixserverlib.ClientEvent `json:"event,omitempty"` - EventsAfter []gomatrixserverlib.ClientEvent `json:"events_after,omitempty"` - EventsBefore []gomatrixserverlib.ClientEvent `json:"events_before,omitempty"` - Start string `json:"start"` - State []gomatrixserverlib.ClientEvent `json:"state,omitempty"` + End string `json:"end"` + Event *synctypes.ClientEvent `json:"event,omitempty"` + EventsAfter []synctypes.ClientEvent `json:"events_after,omitempty"` + EventsBefore []synctypes.ClientEvent `json:"events_before,omitempty"` + Start string `json:"start"` + State []synctypes.ClientEvent `json:"state,omitempty"` } func Context( @@ -94,7 +96,7 @@ func Context( } } - stateFilter := gomatrixserverlib.StateFilter{ + stateFilter := synctypes.StateFilter{ NotSenders: filter.NotSenders, NotTypes: filter.NotTypes, Senders: filter.Senders, @@ -167,14 +169,14 @@ func Context( return jsonerror.InternalServerError() } - eventsBeforeClient := gomatrixserverlib.HeaderedToClientEvents(eventsBeforeFiltered, gomatrixserverlib.FormatAll) - eventsAfterClient := gomatrixserverlib.HeaderedToClientEvents(eventsAfterFiltered, gomatrixserverlib.FormatAll) + eventsBeforeClient := synctypes.HeaderedToClientEvents(eventsBeforeFiltered, synctypes.FormatAll) + eventsAfterClient := synctypes.HeaderedToClientEvents(eventsAfterFiltered, synctypes.FormatAll) newState := state if filter.LazyLoadMembers { allEvents := append(eventsBeforeFiltered, eventsAfterFiltered...) allEvents = append(allEvents, &requestedEvent) - evs := gomatrixserverlib.HeaderedToClientEvents(allEvents, gomatrixserverlib.FormatAll) + evs := synctypes.HeaderedToClientEvents(allEvents, synctypes.FormatAll) newState, err = applyLazyLoadMembers(ctx, device, snapshot, roomID, evs, lazyLoadCache) if err != nil { logrus.WithError(err).Error("unable to load membership events") @@ -182,12 +184,12 @@ func Context( } } - ev := gomatrixserverlib.HeaderedToClientEvent(&requestedEvent, gomatrixserverlib.FormatAll) + ev := synctypes.HeaderedToClientEvent(&requestedEvent, synctypes.FormatAll) response := ContextRespsonse{ Event: &ev, EventsAfter: eventsAfterClient, EventsBefore: eventsBeforeClient, - State: gomatrixserverlib.HeaderedToClientEvents(newState, gomatrixserverlib.FormatAll), + State: synctypes.HeaderedToClientEvents(newState, synctypes.FormatAll), } if len(response.State) > filter.Limit { @@ -261,7 +263,7 @@ func applyLazyLoadMembers( device *userapi.Device, snapshot storage.DatabaseTransaction, roomID string, - events []gomatrixserverlib.ClientEvent, + events []synctypes.ClientEvent, lazyLoadCache caching.LazyLoadCache, ) ([]*gomatrixserverlib.HeaderedEvent, error) { eventSenders := make(map[string]struct{}) @@ -280,9 +282,9 @@ func applyLazyLoadMembers( } // Query missing membership events - filter := gomatrixserverlib.DefaultStateFilter() + filter := synctypes.DefaultStateFilter() filter.Senders = &wantUsers - filter.Types = &[]string{gomatrixserverlib.MRoomMember} + filter.Types = &[]string{spec.MRoomMember} memberships, err := snapshot.GetStateEventsForRoom(ctx, roomID, &filter) if err != nil { return nil, err @@ -296,9 +298,9 @@ func applyLazyLoadMembers( return memberships, nil } -func parseRoomEventFilter(req *http.Request) (*gomatrixserverlib.RoomEventFilter, error) { +func parseRoomEventFilter(req *http.Request) (*synctypes.RoomEventFilter, error) { // Default room filter - filter := &gomatrixserverlib.RoomEventFilter{Limit: 10} + filter := &synctypes.RoomEventFilter{Limit: 10} l := req.URL.Query().Get("limit") f := req.URL.Query().Get("filter") diff --git a/syncapi/routing/context_test.go b/syncapi/routing/context_test.go index e79a5d5f1..5c6682bd1 100644 --- a/syncapi/routing/context_test.go +++ b/syncapi/routing/context_test.go @@ -5,7 +5,7 @@ import ( "reflect" "testing" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/syncapi/synctypes" ) func Test_parseContextParams(t *testing.T) { @@ -19,28 +19,28 @@ func Test_parseContextParams(t *testing.T) { tests := []struct { name string req *http.Request - wantFilter *gomatrixserverlib.RoomEventFilter + wantFilter *synctypes.RoomEventFilter wantErr bool }{ { name: "no params set", req: noParamsReq, - wantFilter: &gomatrixserverlib.RoomEventFilter{Limit: 10}, + wantFilter: &synctypes.RoomEventFilter{Limit: 10}, }, { name: "limit 2 param set", req: limit2Req, - wantFilter: &gomatrixserverlib.RoomEventFilter{Limit: 2}, + wantFilter: &synctypes.RoomEventFilter{Limit: 2}, }, { name: "limit 10000 param set", req: limit10000Req, - wantFilter: &gomatrixserverlib.RoomEventFilter{Limit: 100}, + wantFilter: &synctypes.RoomEventFilter{Limit: 100}, }, { name: "filter lazy_load_members param set", req: lazyLoadReq, - wantFilter: &gomatrixserverlib.RoomEventFilter{Limit: 2, LazyLoadMembers: true}, + wantFilter: &synctypes.RoomEventFilter{Limit: 2, LazyLoadMembers: true}, }, { name: "invalid limit req", diff --git a/syncapi/routing/filter.go b/syncapi/routing/filter.go index f5acdbde3..266ad4adc 100644 --- a/syncapi/routing/filter.go +++ b/syncapi/routing/filter.go @@ -26,6 +26,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/sync" + "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/userapi/api" ) @@ -45,7 +46,7 @@ func GetFilter( return jsonerror.InternalServerError() } - filter := gomatrixserverlib.DefaultFilter() + filter := synctypes.DefaultFilter() if err := syncDB.GetFilter(req.Context(), &filter, localpart, filterID); err != nil { //TODO better error handling. This error message is *probably* right, // but if there are obscure db errors, this will also be returned, @@ -85,7 +86,7 @@ func PutFilter( return jsonerror.InternalServerError() } - var filter gomatrixserverlib.Filter + var filter synctypes.Filter defer req.Body.Close() // nolint:errcheck body, err := io.ReadAll(req.Body) diff --git a/syncapi/routing/getevent.go b/syncapi/routing/getevent.go index d2cdc1b5f..84986d3b3 100644 --- a/syncapi/routing/getevent.go +++ b/syncapi/routing/getevent.go @@ -17,7 +17,6 @@ package routing import ( "net/http" - "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/sirupsen/logrus" @@ -26,6 +25,7 @@ import ( "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/syncapi/internal" "github.com/matrix-org/dendrite/syncapi/storage" + "github.com/matrix-org/dendrite/syncapi/synctypes" userapi "github.com/matrix-org/dendrite/userapi/api" ) @@ -97,6 +97,6 @@ func GetEvent( return util.JSONResponse{ Code: http.StatusOK, - JSON: gomatrixserverlib.HeaderedToClientEvent(events[0], gomatrixserverlib.FormatAll), + JSON: synctypes.HeaderedToClientEvent(events[0], synctypes.FormatAll), } } diff --git a/syncapi/routing/memberships.go b/syncapi/routing/memberships.go index 8efd77cef..9ea660f59 100644 --- a/syncapi/routing/memberships.go +++ b/syncapi/routing/memberships.go @@ -22,14 +22,14 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/storage" + "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) type getMembershipResponse struct { - Chunk []gomatrixserverlib.ClientEvent `json:"chunk"` + Chunk []synctypes.ClientEvent `json:"chunk"` } // https://matrix.org/docs/spec/client_server/r0.6.0#get-matrix-client-r0-rooms-roomid-joined-members @@ -134,6 +134,6 @@ func GetMemberships( } return util.JSONResponse{ Code: http.StatusOK, - JSON: getMembershipResponse{gomatrixserverlib.HeaderedToClientEvents(result, gomatrixserverlib.FormatAll)}, + JSON: getMembershipResponse{synctypes.HeaderedToClientEvents(result, synctypes.FormatAll)}, } } diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go index 02d8fcc7e..6e508a92c 100644 --- a/syncapi/routing/messages.go +++ b/syncapi/routing/messages.go @@ -23,6 +23,7 @@ import ( "time" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/sirupsen/logrus" @@ -34,6 +35,7 @@ import ( "github.com/matrix-org/dendrite/syncapi/internal" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/sync" + "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" ) @@ -50,15 +52,15 @@ type messagesReq struct { device *userapi.Device wasToProvided bool backwardOrdering bool - filter *gomatrixserverlib.RoomEventFilter + filter *synctypes.RoomEventFilter } type messagesResp struct { - Start string `json:"start"` - StartStream string `json:"start_stream,omitempty"` // NOTSPEC: used by Cerulean, so clients can hit /messages then immediately /sync with a latest sync token - End string `json:"end,omitempty"` - Chunk []gomatrixserverlib.ClientEvent `json:"chunk"` - State []gomatrixserverlib.ClientEvent `json:"state,omitempty"` + Start string `json:"start"` + StartStream string `json:"start_stream,omitempty"` // NOTSPEC: used by Cerulean, so clients can hit /messages then immediately /sync with a latest sync token + End string `json:"end,omitempty"` + Chunk []synctypes.ClientEvent `json:"chunk"` + State []synctypes.ClientEvent `json:"state,omitempty"` } // OnIncomingMessagesRequest implements the /messages endpoint from the @@ -199,7 +201,7 @@ func OnIncomingMessagesRequest( } // If the user already left the room, grep events from before that - if membershipResp.Membership == gomatrixserverlib.Leave { + if membershipResp.Membership == spec.Leave { var token types.TopologyToken token, err = snapshot.EventPositionInTopology(req.Context(), membershipResp.EventID) if err != nil { @@ -253,7 +255,7 @@ func OnIncomingMessagesRequest( util.GetLogger(req.Context()).WithError(err).Error("failed to apply lazy loading") return jsonerror.InternalServerError() } - res.State = append(res.State, gomatrixserverlib.HeaderedToClientEvents(membershipEvents, gomatrixserverlib.FormatAll)...) + res.State = append(res.State, synctypes.HeaderedToClientEvents(membershipEvents, synctypes.FormatAll)...) } // If we didn't return any events, set the end to an empty string, so it will be omitted @@ -291,7 +293,7 @@ func getMembershipForUser(ctx context.Context, roomID, userID string, rsAPI api. // Returns an error if there was an issue talking to the database or with the // remote homeserver. func (r *messagesReq) retrieveEvents() ( - clientEvents []gomatrixserverlib.ClientEvent, start, + clientEvents []synctypes.ClientEvent, start, end types.TopologyToken, err error, ) { // Retrieve the events from the local database. @@ -323,7 +325,7 @@ func (r *messagesReq) retrieveEvents() ( // If we didn't get any event, we don't need to proceed any further. if len(events) == 0 { - return []gomatrixserverlib.ClientEvent{}, *r.from, *r.to, nil + return []synctypes.ClientEvent{}, *r.from, *r.to, nil } // Get the position of the first and the last event in the room's topology. @@ -334,7 +336,7 @@ func (r *messagesReq) retrieveEvents() ( // only have to change it in one place, i.e. the database. start, end, err = r.getStartEnd(events) if err != nil { - return []gomatrixserverlib.ClientEvent{}, *r.from, *r.to, err + return []synctypes.ClientEvent{}, *r.from, *r.to, err } // Sort the events to ensure we send them in the right order. @@ -350,7 +352,7 @@ func (r *messagesReq) retrieveEvents() ( events = reversed(events) } if len(events) == 0 { - return []gomatrixserverlib.ClientEvent{}, *r.from, *r.to, nil + return []synctypes.ClientEvent{}, *r.from, *r.to, nil } // Apply room history visibility filter @@ -362,13 +364,13 @@ func (r *messagesReq) retrieveEvents() ( "events_before": len(events), "events_after": len(filteredEvents), }).Debug("applied history visibility (messages)") - return gomatrixserverlib.HeaderedToClientEvents(filteredEvents, gomatrixserverlib.FormatAll), start, end, err + return synctypes.HeaderedToClientEvents(filteredEvents, synctypes.FormatAll), start, end, err } func (r *messagesReq) getStartEnd(events []*gomatrixserverlib.HeaderedEvent) (start, end types.TopologyToken, err error) { if r.backwardOrdering { start = *r.from - if events[len(events)-1].Type() == gomatrixserverlib.MRoomCreate { + if events[len(events)-1].Type() == spec.MRoomCreate { // NOTSPEC: We've hit the beginning of the room so there's really nowhere // else to go. This seems to fix Element iOS from looping on /messages endlessly. end = types.TopologyToken{} diff --git a/syncapi/routing/relations.go b/syncapi/routing/relations.go index fee61b0df..79533883f 100644 --- a/syncapi/routing/relations.go +++ b/syncapi/routing/relations.go @@ -27,14 +27,15 @@ import ( "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/internal" "github.com/matrix-org/dendrite/syncapi/storage" + "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" ) type RelationsResponse struct { - Chunk []gomatrixserverlib.ClientEvent `json:"chunk"` - NextBatch string `json:"next_batch,omitempty"` - PrevBatch string `json:"prev_batch,omitempty"` + Chunk []synctypes.ClientEvent `json:"chunk"` + NextBatch string `json:"next_batch,omitempty"` + PrevBatch string `json:"prev_batch,omitempty"` } // nolint:gocyclo @@ -85,7 +86,7 @@ func Relations( defer sqlutil.EndTransactionWithCheck(snapshot, &succeeded, &err) res := &RelationsResponse{ - Chunk: []gomatrixserverlib.ClientEvent{}, + Chunk: []synctypes.ClientEvent{}, } var events []types.StreamEvent events, res.PrevBatch, res.NextBatch, err = snapshot.RelationsFor( @@ -108,11 +109,11 @@ func Relations( // Convert the events into client events, and optionally filter based on the event // type if it was specified. - res.Chunk = make([]gomatrixserverlib.ClientEvent, 0, len(filteredEvents)) + res.Chunk = make([]synctypes.ClientEvent, 0, len(filteredEvents)) for _, event := range filteredEvents { res.Chunk = append( res.Chunk, - gomatrixserverlib.ToClientEvent(event.Event, gomatrixserverlib.FormatAll), + synctypes.ToClientEvent(event.Event, synctypes.FormatAll), ) } diff --git a/syncapi/routing/routing.go b/syncapi/routing/routing.go index 4cc1a6a85..9607aa325 100644 --- a/syncapi/routing/routing.go +++ b/syncapi/routing/routing.go @@ -18,7 +18,7 @@ import ( "net/http" "github.com/gorilla/mux" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/matrix-org/dendrite/clientapi/jsonerror" @@ -43,7 +43,7 @@ func Setup( rsAPI api.SyncRoomserverAPI, cfg *config.SyncAPI, lazyLoadCache caching.LazyLoadCache, - fts *fulltext.Search, + fts fulltext.Indexer, ) { v1unstablemux := csMux.PathPrefix("/{apiversion:(?:v1|unstable)}/").Subrouter() v3mux := csMux.PathPrefix("/{apiversion:(?:r0|v3)}/").Subrouter() @@ -96,7 +96,7 @@ func Setup( }, httputil.WithAllowGuests())).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/rooms/{roomId}/context/{eventId}", - httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("context", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -112,7 +112,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v1unstablemux.Handle("/rooms/{roomId}/relations/{eventId}", - httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("relations", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -126,7 +126,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v1unstablemux.Handle("/rooms/{roomId}/relations/{eventId}/{relType}", - httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("relation_type", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -140,7 +140,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v1unstablemux.Handle("/rooms/{roomId}/relations/{eventId}/{relType}/{eventType}", - httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("relation_type_event", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -201,7 +201,7 @@ func Setup( return util.ErrorResponse(err) } at := req.URL.Query().Get("at") - membership := gomatrixserverlib.Join + membership := spec.Join return GetMemberships(req, device, vars["roomID"], syncDB, rsAPI, true, &membership, nil, at) }), ).Methods(http.MethodGet, http.MethodOptions) diff --git a/syncapi/routing/search.go b/syncapi/routing/search.go index 081ec6cb1..224c68072 100644 --- a/syncapi/routing/search.go +++ b/syncapi/routing/search.go @@ -19,11 +19,11 @@ import ( "net/http" "sort" "strconv" - "strings" "time" "github.com/blevesearch/bleve/v2/search" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/sirupsen/logrus" "github.com/tidwall/gjson" @@ -33,11 +33,12 @@ import ( "github.com/matrix-org/dendrite/internal/fulltext" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage" + "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/userapi/api" ) // nolint:gocyclo -func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts *fulltext.Search, from *string) util.JSONResponse { +func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts fulltext.Indexer, from *string) util.JSONResponse { start := time.Now() var ( searchReq SearchRequest @@ -123,8 +124,8 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts return util.JSONResponse{ Code: http.StatusOK, JSON: SearchResponse{ - SearchCategories: SearchCategories{ - RoomEvents: RoomEvents{ + SearchCategories: SearchCategoriesResponse{ + RoomEvents: RoomEventsResponse{ Count: int(result.Total), NextBatch: nil, }, @@ -146,7 +147,7 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts // Filter on m.room.message, as otherwise we also get events like m.reaction // which "breaks" displaying results in Element Web. types := []string{"m.room.message"} - roomFilter := &gomatrixserverlib.RoomEventFilter{ + roomFilter := &synctypes.RoomEventFilter{ Rooms: &rooms, Types: &types, } @@ -158,7 +159,7 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts } groups := make(map[string]RoomResult) - knownUsersProfiles := make(map[string]ProfileInfo) + knownUsersProfiles := make(map[string]ProfileInfoResponse) // Sort the events by depth, as the returned values aren't ordered if orderByTime { @@ -167,7 +168,7 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts }) } - stateForRooms := make(map[string][]gomatrixserverlib.ClientEvent) + stateForRooms := make(map[string][]synctypes.ClientEvent) for _, event := range evs { eventsBefore, eventsAfter, err := contextEvents(ctx, snapshot, event, roomFilter, searchReq) if err != nil { @@ -180,11 +181,11 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts return jsonerror.InternalServerError() } - profileInfos := make(map[string]ProfileInfo) + profileInfos := make(map[string]ProfileInfoResponse) for _, ev := range append(eventsBefore, eventsAfter...) { profile, ok := knownUsersProfiles[event.Sender()] if !ok { - stateEvent, err := snapshot.GetStateEvent(ctx, ev.RoomID(), gomatrixserverlib.MRoomMember, ev.Sender()) + stateEvent, err := snapshot.GetStateEvent(ctx, ev.RoomID(), spec.MRoomMember, ev.Sender()) if err != nil { logrus.WithError(err).WithField("user_id", event.Sender()).Warn("failed to query userprofile") continue @@ -192,7 +193,7 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts if stateEvent == nil { continue } - profile = ProfileInfo{ + profile = ProfileInfoResponse{ AvatarURL: gjson.GetBytes(stateEvent.Content(), "avatar_url").Str, DisplayName: gjson.GetBytes(stateEvent.Content(), "displayname").Str, } @@ -205,24 +206,24 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts Context: SearchContextResponse{ Start: startToken.String(), End: endToken.String(), - EventsAfter: gomatrixserverlib.HeaderedToClientEvents(eventsAfter, gomatrixserverlib.FormatSync), - EventsBefore: gomatrixserverlib.HeaderedToClientEvents(eventsBefore, gomatrixserverlib.FormatSync), + EventsAfter: synctypes.HeaderedToClientEvents(eventsAfter, synctypes.FormatSync), + EventsBefore: synctypes.HeaderedToClientEvents(eventsBefore, synctypes.FormatSync), ProfileInfo: profileInfos, }, Rank: eventScore[event.EventID()].Score, - Result: gomatrixserverlib.HeaderedToClientEvent(event, gomatrixserverlib.FormatAll), + Result: synctypes.HeaderedToClientEvent(event, synctypes.FormatAll), }) roomGroup := groups[event.RoomID()] roomGroup.Results = append(roomGroup.Results, event.EventID()) groups[event.RoomID()] = roomGroup if _, ok := stateForRooms[event.RoomID()]; searchReq.SearchCategories.RoomEvents.IncludeState && !ok { - stateFilter := gomatrixserverlib.DefaultStateFilter() + stateFilter := synctypes.DefaultStateFilter() state, err := snapshot.CurrentState(ctx, event.RoomID(), &stateFilter, nil) if err != nil { logrus.WithError(err).Error("unable to get current state") return jsonerror.InternalServerError() } - stateForRooms[event.RoomID()] = gomatrixserverlib.HeaderedToClientEvents(state, gomatrixserverlib.FormatSync) + stateForRooms[event.RoomID()] = synctypes.HeaderedToClientEvents(state, synctypes.FormatSync) } } @@ -237,13 +238,13 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts } res := SearchResponse{ - SearchCategories: SearchCategories{ - RoomEvents: RoomEvents{ + SearchCategories: SearchCategoriesResponse{ + RoomEvents: RoomEventsResponse{ Count: int(result.Total), Groups: Groups{RoomID: groups}, Results: results, NextBatch: nextBatchResult, - Highlights: strings.Split(searchReq.SearchCategories.RoomEvents.SearchTerm, " "), + Highlights: fts.GetHighlights(result), State: stateForRooms, }, }, @@ -263,7 +264,7 @@ func contextEvents( ctx context.Context, snapshot storage.DatabaseTransaction, event *gomatrixserverlib.HeaderedEvent, - roomFilter *gomatrixserverlib.RoomEventFilter, + roomFilter *synctypes.RoomEventFilter, searchReq SearchRequest, ) ([]*gomatrixserverlib.HeaderedEvent, []*gomatrixserverlib.HeaderedEvent, error) { id, _, err := snapshot.SelectContextEvent(ctx, event.RoomID(), event.EventID()) @@ -286,30 +287,40 @@ func contextEvents( return eventsBefore, eventsAfter, err } +type EventContext struct { + AfterLimit int `json:"after_limit,omitempty"` + BeforeLimit int `json:"before_limit,omitempty"` + IncludeProfile bool `json:"include_profile,omitempty"` +} + +type GroupBy struct { + Key string `json:"key"` +} + +type Groupings struct { + GroupBy []GroupBy `json:"group_by"` +} + +type RoomEvents struct { + EventContext EventContext `json:"event_context"` + Filter synctypes.RoomEventFilter `json:"filter"` + Groupings Groupings `json:"groupings"` + IncludeState bool `json:"include_state"` + Keys []string `json:"keys"` + OrderBy string `json:"order_by"` + SearchTerm string `json:"search_term"` +} + +type SearchCategories struct { + RoomEvents RoomEvents `json:"room_events"` +} + type SearchRequest struct { - SearchCategories struct { - RoomEvents struct { - EventContext struct { - AfterLimit int `json:"after_limit,omitempty"` - BeforeLimit int `json:"before_limit,omitempty"` - IncludeProfile bool `json:"include_profile,omitempty"` - } `json:"event_context"` - Filter gomatrixserverlib.RoomEventFilter `json:"filter"` - Groupings struct { - GroupBy []struct { - Key string `json:"key"` - } `json:"group_by"` - } `json:"groupings"` - IncludeState bool `json:"include_state"` - Keys []string `json:"keys"` - OrderBy string `json:"order_by"` - SearchTerm string `json:"search_term"` - } `json:"room_events"` - } `json:"search_categories"` + SearchCategories SearchCategories `json:"search_categories"` } type SearchResponse struct { - SearchCategories SearchCategories `json:"search_categories"` + SearchCategories SearchCategoriesResponse `json:"search_categories"` } type RoomResult struct { NextBatch *string `json:"next_batch,omitempty"` @@ -322,32 +333,32 @@ type Groups struct { } type Result struct { - Context SearchContextResponse `json:"context"` - Rank float64 `json:"rank"` - Result gomatrixserverlib.ClientEvent `json:"result"` + Context SearchContextResponse `json:"context"` + Rank float64 `json:"rank"` + Result synctypes.ClientEvent `json:"result"` } type SearchContextResponse struct { - End string `json:"end"` - EventsAfter []gomatrixserverlib.ClientEvent `json:"events_after"` - EventsBefore []gomatrixserverlib.ClientEvent `json:"events_before"` - Start string `json:"start"` - ProfileInfo map[string]ProfileInfo `json:"profile_info"` + End string `json:"end"` + EventsAfter []synctypes.ClientEvent `json:"events_after"` + EventsBefore []synctypes.ClientEvent `json:"events_before"` + Start string `json:"start"` + ProfileInfo map[string]ProfileInfoResponse `json:"profile_info"` } -type ProfileInfo struct { +type ProfileInfoResponse struct { AvatarURL string `json:"avatar_url"` DisplayName string `json:"display_name"` } -type RoomEvents struct { - Count int `json:"count"` - Groups Groups `json:"groups"` - Highlights []string `json:"highlights"` - NextBatch *string `json:"next_batch,omitempty"` - Results []Result `json:"results"` - State map[string][]gomatrixserverlib.ClientEvent `json:"state,omitempty"` +type RoomEventsResponse struct { + Count int `json:"count"` + Groups Groups `json:"groups"` + Highlights []string `json:"highlights"` + NextBatch *string `json:"next_batch,omitempty"` + Results []Result `json:"results"` + State map[string][]synctypes.ClientEvent `json:"state,omitempty"` } -type SearchCategories struct { - RoomEvents RoomEvents `json:"room_events"` +type SearchCategoriesResponse struct { + RoomEvents RoomEventsResponse `json:"room_events"` } diff --git a/syncapi/routing/search_test.go b/syncapi/routing/search_test.go new file mode 100644 index 000000000..c8b9c604c --- /dev/null +++ b/syncapi/routing/search_test.go @@ -0,0 +1,266 @@ +package routing + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/matrix-org/dendrite/internal/fulltext" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/syncapi/storage" + "github.com/matrix-org/dendrite/syncapi/synctypes" + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" + "github.com/stretchr/testify/assert" +) + +func TestSearch(t *testing.T) { + alice := test.NewUser(t) + aliceDevice := userapi.Device{UserID: alice.ID} + room := test.NewRoom(t, alice) + room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": "context before"}) + room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": "hello world3!"}) + room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": "context after"}) + + roomsFilter := []string{room.ID} + roomsFilterUnknown := []string{"!unknown"} + + emptyFromString := "" + fromStringValid := "1" + fromStringInvalid := "iCantBeParsed" + + testCases := []struct { + name string + wantOK bool + searchReq SearchRequest + device *userapi.Device + wantResponseCount int + from *string + }{ + { + name: "no user ID", + searchReq: SearchRequest{}, + device: &userapi.Device{}, + }, + { + name: "with alice ID", + wantOK: true, + searchReq: SearchRequest{}, + device: &aliceDevice, + }, + { + name: "searchTerm specified, found at the beginning", + wantOK: true, + searchReq: SearchRequest{ + SearchCategories: SearchCategories{RoomEvents: RoomEvents{SearchTerm: "hello"}}, + }, + device: &aliceDevice, + wantResponseCount: 1, + }, + { + name: "searchTerm specified, found at the end", + wantOK: true, + searchReq: SearchRequest{ + SearchCategories: SearchCategories{RoomEvents: RoomEvents{SearchTerm: "world3"}}, + }, + device: &aliceDevice, + wantResponseCount: 1, + }, + /* the following would need matchQuery.SetFuzziness(1) in bleve.go + { + name: "searchTerm fuzzy search", + wantOK: true, + searchReq: SearchRequest{ + SearchCategories: SearchCategories{RoomEvents: RoomEvents{SearchTerm: "hell"}}, // this still should find hello world + }, + device: &aliceDevice, + wantResponseCount: 1, + }, + */ + { + name: "searchTerm specified but no result", + wantOK: true, + searchReq: SearchRequest{ + SearchCategories: SearchCategories{RoomEvents: RoomEvents{SearchTerm: "i don't match"}}, + }, + device: &aliceDevice, + }, + { + name: "filter on room", + wantOK: true, + searchReq: SearchRequest{ + SearchCategories: SearchCategories{ + RoomEvents: RoomEvents{ + SearchTerm: "hello", + Filter: synctypes.RoomEventFilter{ + Rooms: &roomsFilter, + }, + }, + }, + }, + device: &aliceDevice, + wantResponseCount: 1, + }, + { + name: "filter on unknown room", + searchReq: SearchRequest{ + SearchCategories: SearchCategories{ + RoomEvents: RoomEvents{ + SearchTerm: "hello", + Filter: synctypes.RoomEventFilter{ + Rooms: &roomsFilterUnknown, + }, + }, + }, + }, + device: &aliceDevice, + }, + { + name: "include state", + wantOK: true, + searchReq: SearchRequest{ + SearchCategories: SearchCategories{ + RoomEvents: RoomEvents{ + SearchTerm: "hello", + Filter: synctypes.RoomEventFilter{ + Rooms: &roomsFilter, + }, + IncludeState: true, + }, + }, + }, + device: &aliceDevice, + wantResponseCount: 1, + }, + { + name: "empty from does not error", + wantOK: true, + searchReq: SearchRequest{ + SearchCategories: SearchCategories{ + RoomEvents: RoomEvents{ + SearchTerm: "hello", + Filter: synctypes.RoomEventFilter{ + Rooms: &roomsFilter, + }, + }, + }, + }, + wantResponseCount: 1, + device: &aliceDevice, + from: &emptyFromString, + }, + { + name: "valid from does not error", + wantOK: true, + searchReq: SearchRequest{ + SearchCategories: SearchCategories{ + RoomEvents: RoomEvents{ + SearchTerm: "hello", + Filter: synctypes.RoomEventFilter{ + Rooms: &roomsFilter, + }, + }, + }, + }, + wantResponseCount: 1, + device: &aliceDevice, + from: &fromStringValid, + }, + { + name: "invalid from does error", + searchReq: SearchRequest{ + SearchCategories: SearchCategories{ + RoomEvents: RoomEvents{ + SearchTerm: "hello", + Filter: synctypes.RoomEventFilter{ + Rooms: &roomsFilter, + }, + }, + }, + }, + device: &aliceDevice, + from: &fromStringInvalid, + }, + { + name: "order by stream position", + wantOK: true, + searchReq: SearchRequest{ + SearchCategories: SearchCategories{RoomEvents: RoomEvents{SearchTerm: "hello", OrderBy: "recent"}}, + }, + device: &aliceDevice, + wantResponseCount: 1, + }, + } + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, closeDB := testrig.CreateConfig(t, dbType) + defer closeDB() + + // create requisites + fts, err := fulltext.New(processCtx, cfg.SyncAPI.Fulltext) + assert.NoError(t, err) + assert.NotNil(t, fts) + + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + db, err := storage.NewSyncServerDatasource(processCtx.Context(), cm, &cfg.SyncAPI.Database) + assert.NoError(t, err) + + elements := []fulltext.IndexElement{} + // store the events in the database + var sp types.StreamPosition + for _, x := range room.Events() { + var stateEvents []*gomatrixserverlib.HeaderedEvent + var stateEventIDs []string + if x.Type() == spec.MRoomMember { + stateEvents = append(stateEvents, x) + stateEventIDs = append(stateEventIDs, x.EventID()) + } + sp, err = db.WriteEvent(processCtx.Context(), x, stateEvents, stateEventIDs, nil, nil, false, gomatrixserverlib.HistoryVisibilityShared) + assert.NoError(t, err) + if x.Type() != "m.room.message" { + continue + } + elements = append(elements, fulltext.IndexElement{ + EventID: x.EventID(), + RoomID: x.RoomID(), + Content: string(x.Content()), + ContentType: x.Type(), + StreamPosition: int64(sp), + }) + } + // Index the events + err = fts.Index(elements...) + assert.NoError(t, err) + + // run the tests + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + reqBody := &bytes.Buffer{} + err = json.NewEncoder(reqBody).Encode(tc.searchReq) + assert.NoError(t, err) + req := httptest.NewRequest(http.MethodPost, "/", reqBody) + + res := Search(req, tc.device, db, fts, tc.from) + if !tc.wantOK && !res.Is2xx() { + return + } + resp, ok := res.JSON.(SearchResponse) + if !ok && !tc.wantOK { + t.Fatalf("not a SearchResponse: %T: %s", res.JSON, res.JSON) + } + assert.Equal(t, tc.wantResponseCount, resp.SearchCategories.RoomEvents.Count) + + // if we requested state, it should not be empty + if tc.searchReq.SearchCategories.RoomEvents.IncludeState { + assert.NotEmpty(t, resp.SearchCategories.RoomEvents.State) + } + }) + } + }) +} diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index 04c2020a0..f5c1223a7 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -18,11 +18,13 @@ import ( "context" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/storage/shared" + "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" ) @@ -40,13 +42,13 @@ type DatabaseTransaction interface { MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForRelations(ctx context.Context) (types.StreamPosition, error) - CurrentState(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter, excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) - GetStateDeltasForFullStateSync(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error) - GetStateDeltas(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error) + CurrentState(ctx context.Context, roomID string, stateFilterPart *synctypes.StateFilter, excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) + GetStateDeltasForFullStateSync(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *synctypes.StateFilter) ([]types.StateDelta, []string, error) + GetStateDeltas(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *synctypes.StateFilter) ([]types.StateDelta, []string, error) RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error) MembershipCount(ctx context.Context, roomID, membership string, pos types.StreamPosition) (int, error) GetRoomSummary(ctx context.Context, roomID, userID string) (summary *types.Summary, err error) - RecentEvents(ctx context.Context, roomIDs []string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) (map[string]types.RecentEvents, error) + RecentEvents(ctx context.Context, roomIDs []string, r types.Range, eventFilter *synctypes.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) (map[string]types.RecentEvents, error) GetBackwardTopologyPos(ctx context.Context, events []*gomatrixserverlib.HeaderedEvent) (types.TopologyToken, error) PositionInTopology(ctx context.Context, eventID string) (pos types.StreamPosition, spos types.StreamPosition, err error) InviteEventsInRange(ctx context.Context, targetUserID string, r types.Range) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, types.StreamPosition, error) @@ -71,15 +73,15 @@ type DatabaseTransaction interface { // GetStateEventsForRoom fetches the state events for a given room. // Returns an empty slice if no state events could be found for this room. // Returns an error if there was an issue with the retrieval. - GetStateEventsForRoom(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter) (stateEvents []*gomatrixserverlib.HeaderedEvent, err error) + GetStateEventsForRoom(ctx context.Context, roomID string, stateFilterPart *synctypes.StateFilter) (stateEvents []*gomatrixserverlib.HeaderedEvent, err error) // GetAccountDataInRange returns all account data for a given user inserted or // updated between two given positions // Returns a map following the format data[roomID] = []dataTypes // If no data is retrieved, returns an empty map // If there was an issue with the retrieval, returns an error - GetAccountDataInRange(ctx context.Context, userID string, r types.Range, accountDataFilterPart *gomatrixserverlib.EventFilter) (map[string][]string, types.StreamPosition, error) + GetAccountDataInRange(ctx context.Context, userID string, r types.Range, accountDataFilterPart *synctypes.EventFilter) (map[string][]string, types.StreamPosition, error) // GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit. If backwardsOrdering is true, the most recent event must be first, else last. - GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, filter *gomatrixserverlib.RoomEventFilter, backwardOrdering bool) (events []types.StreamEvent, err error) + GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, filter *synctypes.RoomEventFilter, backwardOrdering bool) (events []types.StreamEvent, err error) // EventPositionInTopology returns the depth and stream position of the given event. EventPositionInTopology(ctx context.Context, eventID string) (types.TopologyToken, error) // BackwardExtremitiesForRoom returns a map of backwards extremity event ID to a list of its prev_events. @@ -94,8 +96,8 @@ type DatabaseTransaction interface { // GetRoomReceipts gets all receipts for a given roomID GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]types.OutputReceiptEvent, error) SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) - SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error) - SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error) + SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *synctypes.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error) + SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *synctypes.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error) StreamToTopologicalPosition(ctx context.Context, roomID string, streamPos types.StreamPosition, backwardOrdering bool) (types.TopologyToken, error) IgnoresForUser(ctx context.Context, userID string) (*types.IgnoredUsers, error) // SelectMembershipForUser returns the membership of the user before and including the given position. If no membership can be found @@ -105,7 +107,7 @@ type DatabaseTransaction interface { // getUserUnreadNotificationCountsForRooms returns the unread notifications for the given rooms GetUserUnreadNotificationCountsForRooms(ctx context.Context, userID string, roomIDs map[string]string) (map[string]*eventutil.NotificationData, error) GetPresences(ctx context.Context, userID []string) ([]*types.PresenceInternal, error) - PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) + PresenceAfter(ctx context.Context, after types.StreamPosition, filter synctypes.EventFilter) (map[string]*types.PresenceInternal, error) RelationsFor(ctx context.Context, roomID, eventID, relType, eventType string, from, to types.StreamPosition, backwards bool, limit int) (events []types.StreamEvent, prevBatch, nextBatch string, err error) } @@ -165,15 +167,15 @@ type Database interface { // GetFilter looks up the filter associated with a given local user and filter ID // and populates the target filter. Otherwise returns an error if no such filter exists // or if there was an error talking to the database. - GetFilter(ctx context.Context, target *gomatrixserverlib.Filter, localpart string, filterID string) error + GetFilter(ctx context.Context, target *synctypes.Filter, localpart string, filterID string) error // PutFilter puts the passed filter into the database. // Returns the filterID as a string. Otherwise returns an error if something // goes wrong. - PutFilter(ctx context.Context, localpart string, filter *gomatrixserverlib.Filter) (string, error) + PutFilter(ctx context.Context, localpart string, filter *synctypes.Filter) (string, error) // RedactEvent wipes an event in the database and sets the unsigned.redacted_because key to the redaction event RedactEvent(ctx context.Context, redactedEventID string, redactedBecause *gomatrixserverlib.HeaderedEvent) error // StoreReceipt stores new receipt events - StoreReceipt(ctx context.Context, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error) + StoreReceipt(ctx context.Context, roomId, receiptType, userId, eventId string, timestamp spec.Timestamp) (pos types.StreamPosition, err error) UpdateIgnoresForUser(ctx context.Context, userID string, ignores *types.IgnoredUsers) error ReIndex(ctx context.Context, limit, afterID int64) (map[int64]gomatrixserverlib.HeaderedEvent, error) UpdateRelations(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error @@ -187,7 +189,7 @@ type Database interface { type Presence interface { GetPresences(ctx context.Context, userIDs []string) ([]*types.PresenceInternal, error) - UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error) + UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS spec.Timestamp, fromSync bool) (types.StreamPosition, error) } type SharedUsers interface { diff --git a/syncapi/storage/postgres/account_data_table.go b/syncapi/storage/postgres/account_data_table.go index aa54cb08f..44d735bb4 100644 --- a/syncapi/storage/postgres/account_data_table.go +++ b/syncapi/storage/postgres/account_data_table.go @@ -23,8 +23,8 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" + "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrixserverlib" ) const accountDataSchema = ` @@ -78,16 +78,11 @@ func NewPostgresAccountDataTable(db *sql.DB) (tables.AccountData, error) { if err != nil { return nil, err } - if s.insertAccountDataStmt, err = db.Prepare(insertAccountDataSQL); err != nil { - return nil, err - } - if s.selectAccountDataInRangeStmt, err = db.Prepare(selectAccountDataInRangeSQL); err != nil { - return nil, err - } - if s.selectMaxAccountDataIDStmt, err = db.Prepare(selectMaxAccountDataIDSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.insertAccountDataStmt, insertAccountDataSQL}, + {&s.selectAccountDataInRangeStmt, selectAccountDataInRangeSQL}, + {&s.selectMaxAccountDataIDStmt, selectMaxAccountDataIDSQL}, + }.Prepare(db) } func (s *accountDataStatements) InsertAccountData( @@ -102,7 +97,7 @@ func (s *accountDataStatements) SelectAccountDataInRange( ctx context.Context, txn *sql.Tx, userID string, r types.Range, - accountDataEventFilter *gomatrixserverlib.EventFilter, + accountDataEventFilter *synctypes.EventFilter, ) (data map[string][]string, pos types.StreamPosition, err error) { data = make(map[string][]string) diff --git a/syncapi/storage/postgres/current_room_state_table.go b/syncapi/storage/postgres/current_room_state_table.go index 0d607b7c0..c2a870c07 100644 --- a/syncapi/storage/postgres/current_room_state_table.go +++ b/syncapi/storage/postgres/current_room_state_table.go @@ -26,8 +26,10 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/postgres/deltas" "github.com/matrix-org/dendrite/syncapi/storage/tables" + "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const currentRoomStateSchema = ` @@ -270,16 +272,16 @@ func (s *currentRoomStateStatements) SelectRoomIDsWithAnyMembership( // SelectCurrentState returns all the current state events for the given room. func (s *currentRoomStateStatements) SelectCurrentState( ctx context.Context, txn *sql.Tx, roomID string, - stateFilter *gomatrixserverlib.StateFilter, + stateFilter *synctypes.StateFilter, excludeEventIDs []string, ) ([]*gomatrixserverlib.HeaderedEvent, error) { stmt := sqlutil.TxStmt(txn, s.selectCurrentStateStmt) senders, notSenders := getSendersStateFilterFilter(stateFilter) // We're going to query members later, so remove them from this request if stateFilter.LazyLoadMembers && !stateFilter.IncludeRedundantMembers { - notTypes := &[]string{gomatrixserverlib.MRoomMember} + notTypes := &[]string{spec.MRoomMember} if stateFilter.NotTypes != nil { - *stateFilter.NotTypes = append(*stateFilter.NotTypes, gomatrixserverlib.MRoomMember) + *stateFilter.NotTypes = append(*stateFilter.NotTypes, spec.MRoomMember) } else { stateFilter.NotTypes = notTypes } diff --git a/syncapi/storage/postgres/filter_table.go b/syncapi/storage/postgres/filter_table.go index 86cec3625..089382b89 100644 --- a/syncapi/storage/postgres/filter_table.go +++ b/syncapi/storage/postgres/filter_table.go @@ -21,6 +21,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" + "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/gomatrixserverlib" ) @@ -61,20 +62,15 @@ func NewPostgresFilterTable(db *sql.DB) (tables.Filter, error) { return nil, err } s := &filterStatements{} - if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil { - return nil, err - } - if s.selectFilterIDByContentStmt, err = db.Prepare(selectFilterIDByContentSQL); err != nil { - return nil, err - } - if s.insertFilterStmt, err = db.Prepare(insertFilterSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.selectFilterStmt, selectFilterSQL}, + {&s.selectFilterIDByContentStmt, selectFilterIDByContentSQL}, + {&s.insertFilterStmt, insertFilterSQL}, + }.Prepare(db) } func (s *filterStatements) SelectFilter( - ctx context.Context, txn *sql.Tx, target *gomatrixserverlib.Filter, localpart string, filterID string, + ctx context.Context, txn *sql.Tx, target *synctypes.Filter, localpart string, filterID string, ) error { // Retrieve filter from database (stored as canonical JSON) var filterData []byte @@ -91,7 +87,7 @@ func (s *filterStatements) SelectFilter( } func (s *filterStatements) InsertFilter( - ctx context.Context, txn *sql.Tx, filter *gomatrixserverlib.Filter, localpart string, + ctx context.Context, txn *sql.Tx, filter *synctypes.Filter, localpart string, ) (filterID string, err error) { var existingFilterID string diff --git a/syncapi/storage/postgres/filtering.go b/syncapi/storage/postgres/filtering.go index a2ca42156..39ef11086 100644 --- a/syncapi/storage/postgres/filtering.go +++ b/syncapi/storage/postgres/filtering.go @@ -17,7 +17,7 @@ package postgres import ( "strings" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/syncapi/synctypes" ) // filterConvertWildcardToSQL converts wildcards as defined in @@ -39,7 +39,7 @@ func filterConvertTypeWildcardToSQL(values *[]string) []string { } // TODO: Replace when Dendrite uses Go 1.18 -func getSendersRoomEventFilter(filter *gomatrixserverlib.RoomEventFilter) (senders []string, notSenders []string) { +func getSendersRoomEventFilter(filter *synctypes.RoomEventFilter) (senders []string, notSenders []string) { if filter.Senders != nil { senders = *filter.Senders } @@ -49,7 +49,7 @@ func getSendersRoomEventFilter(filter *gomatrixserverlib.RoomEventFilter) (sende return senders, notSenders } -func getSendersStateFilterFilter(filter *gomatrixserverlib.StateFilter) (senders []string, notSenders []string) { +func getSendersStateFilterFilter(filter *synctypes.StateFilter) (senders []string, notSenders []string) { if filter.Senders != nil { senders = *filter.Senders } diff --git a/syncapi/storage/postgres/ignores_table.go b/syncapi/storage/postgres/ignores_table.go index 97660725c..a511c747c 100644 --- a/syncapi/storage/postgres/ignores_table.go +++ b/syncapi/storage/postgres/ignores_table.go @@ -52,13 +52,11 @@ func NewPostgresIgnoresTable(db *sql.DB) (tables.Ignores, error) { return nil, err } s := &ignoresStatements{} - if s.selectIgnoresStmt, err = db.Prepare(selectIgnoresSQL); err != nil { - return nil, err - } - if s.upsertIgnoresStmt, err = db.Prepare(upsertIgnoresSQL); err != nil { - return nil, err - } - return s, nil + + return s, sqlutil.StatementList{ + {&s.selectIgnoresStmt, selectIgnoresSQL}, + {&s.upsertIgnoresStmt, upsertIgnoresSQL}, + }.Prepare(db) } func (s *ignoresStatements) SelectIgnores( diff --git a/syncapi/storage/postgres/output_room_events_table.go b/syncapi/storage/postgres/output_room_events_table.go index 59fb99aa3..8ee5098c7 100644 --- a/syncapi/storage/postgres/output_room_events_table.go +++ b/syncapi/storage/postgres/output_room_events_table.go @@ -28,6 +28,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/storage/postgres/deltas" "github.com/matrix-org/dendrite/syncapi/storage/tables" + "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -135,15 +136,6 @@ FROM room_ids, ) AS x ` -const selectEarlyEventsSQL = "" + - "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id, history_visibility FROM syncapi_output_room_events" + - " WHERE room_id = $1 AND id > $2 AND id <= $3" + - " AND ( $4::text[] IS NULL OR sender = ANY($4) )" + - " AND ( $5::text[] IS NULL OR NOT(sender = ANY($5)) )" + - " AND ( $6::text[] IS NULL OR type LIKE ANY($6) )" + - " AND ( $7::text[] IS NULL OR NOT(type LIKE ANY($7)) )" + - " ORDER BY id ASC LIMIT $8" - const selectMaxEventIDSQL = "" + "SELECT MAX(id) FROM syncapi_output_room_events" @@ -205,7 +197,6 @@ type outputRoomEventsStatements struct { selectMaxEventIDStmt *sql.Stmt selectRecentEventsStmt *sql.Stmt selectRecentEventsForSyncStmt *sql.Stmt - selectEarlyEventsStmt *sql.Stmt selectStateInRangeFilteredStmt *sql.Stmt selectStateInRangeStmt *sql.Stmt updateEventJSONStmt *sql.Stmt @@ -261,7 +252,6 @@ func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) { {&s.selectMaxEventIDStmt, selectMaxEventIDSQL}, {&s.selectRecentEventsStmt, selectRecentEventsSQL}, {&s.selectRecentEventsForSyncStmt, selectRecentEventsForSyncSQL}, - {&s.selectEarlyEventsStmt, selectEarlyEventsSQL}, {&s.selectStateInRangeFilteredStmt, selectStateInRangeFilteredSQL}, {&s.selectStateInRangeStmt, selectStateInRangeSQL}, {&s.updateEventJSONStmt, updateEventJSONSQL}, @@ -288,7 +278,7 @@ func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, txn *s // two positions, only the most recent state is returned. func (s *outputRoomEventsStatements) SelectStateInRange( ctx context.Context, txn *sql.Tx, r types.Range, - stateFilter *gomatrixserverlib.StateFilter, roomIDs []string, + stateFilter *synctypes.StateFilter, roomIDs []string, ) (map[string]map[string]bool, map[string]types.StreamEvent, error) { var rows *sql.Rows var err error @@ -433,7 +423,7 @@ func (s *outputRoomEventsStatements) InsertEvent( // from sync. func (s *outputRoomEventsStatements) SelectRecentEvents( ctx context.Context, txn *sql.Tx, - roomIDs []string, ra types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, + roomIDs []string, ra types.Range, eventFilter *synctypes.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool, ) (map[string]types.RecentEvents, error) { var stmt *sql.Stmt @@ -529,43 +519,10 @@ func (s *outputRoomEventsStatements) SelectRecentEvents( return result, rows.Err() } -// selectEarlyEvents returns the earliest events in the given room, starting -// from a given position, up to a maximum of 'limit'. -func (s *outputRoomEventsStatements) SelectEarlyEvents( - ctx context.Context, txn *sql.Tx, - roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, -) ([]types.StreamEvent, error) { - senders, notSenders := getSendersRoomEventFilter(eventFilter) - stmt := sqlutil.TxStmt(txn, s.selectEarlyEventsStmt) - rows, err := stmt.QueryContext( - ctx, roomID, r.Low(), r.High(), - pq.StringArray(senders), - pq.StringArray(notSenders), - pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.Types)), - pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.NotTypes)), - eventFilter.Limit, - ) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "selectEarlyEvents: rows.close() failed") - events, err := rowsToStreamEvents(rows) - if err != nil { - return nil, err - } - // The events need to be returned from oldest to latest, which isn't - // necessarily the way the SQL query returns them, so a sort is necessary to - // ensure the events are in the right order in the slice. - sort.SliceStable(events, func(i int, j int) bool { - return events[i].StreamPosition < events[j].StreamPosition - }) - return events, nil -} - // selectEvents returns the events for the given event IDs. If an event is // missing from the database, it will be omitted. func (s *outputRoomEventsStatements) SelectEvents( - ctx context.Context, txn *sql.Tx, eventIDs []string, filter *gomatrixserverlib.RoomEventFilter, preserveOrder bool, + ctx context.Context, txn *sql.Tx, eventIDs []string, filter *synctypes.RoomEventFilter, preserveOrder bool, ) ([]types.StreamEvent, error) { var ( stmt *sql.Stmt @@ -637,7 +594,7 @@ func (s *outputRoomEventsStatements) SelectContextEvent(ctx context.Context, txn } func (s *outputRoomEventsStatements) SelectContextBeforeEvent( - ctx context.Context, txn *sql.Tx, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter, + ctx context.Context, txn *sql.Tx, id int, roomID string, filter *synctypes.RoomEventFilter, ) (evts []*gomatrixserverlib.HeaderedEvent, err error) { senders, notSenders := getSendersRoomEventFilter(filter) rows, err := sqlutil.TxStmt(txn, s.selectContextBeforeEventStmt).QueryContext( @@ -672,7 +629,7 @@ func (s *outputRoomEventsStatements) SelectContextBeforeEvent( } func (s *outputRoomEventsStatements) SelectContextAfterEvent( - ctx context.Context, txn *sql.Tx, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter, + ctx context.Context, txn *sql.Tx, id int, roomID string, filter *synctypes.RoomEventFilter, ) (lastID int, evts []*gomatrixserverlib.HeaderedEvent, err error) { senders, notSenders := getSendersRoomEventFilter(filter) rows, err := sqlutil.TxStmt(txn, s.selectContextAfterEventStmt).QueryContext( diff --git a/syncapi/storage/postgres/presence_table.go b/syncapi/storage/postgres/presence_table.go index a3f7c5213..f37b5331e 100644 --- a/syncapi/storage/postgres/presence_table.go +++ b/syncapi/storage/postgres/presence_table.go @@ -20,10 +20,11 @@ import ( "time" "github.com/lib/pq" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/types" ) @@ -108,7 +109,7 @@ func (p *presenceStatements) UpsertPresence( userID string, statusMsg *string, presence types.Presence, - lastActiveTS gomatrixserverlib.Timestamp, + lastActiveTS spec.Timestamp, fromSync bool, ) (pos types.StreamPosition, err error) { if fromSync { @@ -156,11 +157,11 @@ func (p *presenceStatements) GetMaxPresenceID(ctx context.Context, txn *sql.Tx) func (p *presenceStatements) GetPresenceAfter( ctx context.Context, txn *sql.Tx, after types.StreamPosition, - filter gomatrixserverlib.EventFilter, + filter synctypes.EventFilter, ) (presences map[string]*types.PresenceInternal, err error) { presences = make(map[string]*types.PresenceInternal) stmt := sqlutil.TxStmt(txn, p.selectPresenceAfterStmt) - afterTS := gomatrixserverlib.AsTimestamp(time.Now().Add(time.Minute * -5)) + afterTS := spec.AsTimestamp(time.Now().Add(time.Minute * -5)) rows, err := stmt.QueryContext(ctx, after, afterTS, filter.Limit) if err != nil { return nil, err diff --git a/syncapi/storage/postgres/receipt_table.go b/syncapi/storage/postgres/receipt_table.go index 0fcbebfcb..9ab8eece0 100644 --- a/syncapi/storage/postgres/receipt_table.go +++ b/syncapi/storage/postgres/receipt_table.go @@ -26,7 +26,7 @@ import ( "github.com/matrix-org/dendrite/syncapi/storage/postgres/deltas" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const receiptsSchema = ` @@ -98,7 +98,7 @@ func NewPostgresReceiptsTable(db *sql.DB) (tables.Receipts, error) { }.Prepare(db) } -func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error) { +func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, roomId, receiptType, userId, eventId string, timestamp spec.Timestamp) (pos types.StreamPosition, err error) { stmt := sqlutil.TxStmt(txn, r.upsertReceipt) err = stmt.QueryRowContext(ctx, roomId, receiptType, userId, eventId, timestamp).Scan(&pos) return diff --git a/syncapi/storage/postgres/send_to_device_table.go b/syncapi/storage/postgres/send_to_device_table.go index 6ab1f0f48..88b86ef7b 100644 --- a/syncapi/storage/postgres/send_to_device_table.go +++ b/syncapi/storage/postgres/send_to_device_table.go @@ -88,19 +88,12 @@ func NewPostgresSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) { if err != nil { return nil, err } - if s.insertSendToDeviceMessageStmt, err = db.Prepare(insertSendToDeviceMessageSQL); err != nil { - return nil, err - } - if s.selectSendToDeviceMessagesStmt, err = db.Prepare(selectSendToDeviceMessagesSQL); err != nil { - return nil, err - } - if s.deleteSendToDeviceMessagesStmt, err = db.Prepare(deleteSendToDeviceMessagesSQL); err != nil { - return nil, err - } - if s.selectMaxSendToDeviceIDStmt, err = db.Prepare(selectMaxSendToDeviceIDSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.insertSendToDeviceMessageStmt, insertSendToDeviceMessageSQL}, + {&s.selectSendToDeviceMessagesStmt, selectSendToDeviceMessagesSQL}, + {&s.deleteSendToDeviceMessagesStmt, deleteSendToDeviceMessagesSQL}, + {&s.selectMaxSendToDeviceIDStmt, selectMaxSendToDeviceIDSQL}, + }.Prepare(db) } func (s *sendToDeviceStatements) InsertSendToDeviceMessage( diff --git a/syncapi/storage/postgres/syncserver.go b/syncapi/storage/postgres/syncserver.go index 850d24a07..9f9de28d9 100644 --- a/syncapi/storage/postgres/syncserver.go +++ b/syncapi/storage/postgres/syncserver.go @@ -16,12 +16,12 @@ package postgres import ( + "context" "database/sql" // Import the postgres database driver. _ "github.com/lib/pq" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/syncapi/storage/postgres/deltas" "github.com/matrix-org/dendrite/syncapi/storage/shared" @@ -36,10 +36,10 @@ type SyncServerDatasource struct { } // NewDatabase creates a new sync server database -func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*SyncServerDatasource, error) { +func NewDatabase(ctx context.Context, cm sqlutil.Connections, dbProperties *config.DatabaseOptions) (*SyncServerDatasource, error) { var d SyncServerDatasource var err error - if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()); err != nil { + if d.db, d.writer, err = cm.Connection(dbProperties); err != nil { return nil, err } accountData, err := NewPostgresAccountDataTable(d.db) @@ -111,7 +111,7 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) Up: deltas.UpSetHistoryVisibility, // Requires current_room_state and output_room_events to be created. }, ) - err = m.Up(base.Context()) + err = m.Up(ctx) if err != nil { return nil, err } diff --git a/syncapi/storage/shared/storage_consumer.go b/syncapi/storage/shared/storage_consumer.go index aeeebb1d2..47490fb03 100644 --- a/syncapi/storage/shared/storage_consumer.go +++ b/syncapi/storage/shared/storage_consumer.go @@ -23,6 +23,7 @@ import ( "github.com/tidwall/gjson" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib" "github.com/sirupsen/logrus" @@ -31,6 +32,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/storage/tables" + "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/types" ) @@ -323,13 +325,13 @@ func (d *Database) updateRoomState( } func (d *Database) GetFilter( - ctx context.Context, target *gomatrixserverlib.Filter, localpart string, filterID string, + ctx context.Context, target *synctypes.Filter, localpart string, filterID string, ) error { return d.Filter.SelectFilter(ctx, nil, target, localpart, filterID) } func (d *Database) PutFilter( - ctx context.Context, localpart string, filter *gomatrixserverlib.Filter, + ctx context.Context, localpart string, filter *synctypes.Filter, ) (string, error) { var filterID string var err error @@ -503,7 +505,7 @@ func getMembershipFromEvent(ev *gomatrixserverlib.Event, userID string) (string, } // StoreReceipt stores user receipts -func (d *Database) StoreReceipt(ctx context.Context, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error) { +func (d *Database) StoreReceipt(ctx context.Context, roomId, receiptType, userId, eventId string, timestamp spec.Timestamp) (pos types.StreamPosition, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { pos, err = d.Receipts.UpsertReceipt(ctx, txn, roomId, receiptType, userId, eventId, timestamp) return err @@ -523,10 +525,10 @@ func (d *Database) SelectContextEvent(ctx context.Context, roomID, eventID strin return d.OutputEvents.SelectContextEvent(ctx, nil, roomID, eventID) } -func (d *Database) SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error) { +func (d *Database) SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *synctypes.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error) { return d.OutputEvents.SelectContextBeforeEvent(ctx, nil, id, roomID, filter) } -func (d *Database) SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error) { +func (d *Database) SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *synctypes.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error) { return d.OutputEvents.SelectContextAfterEvent(ctx, nil, id, roomID, filter) } @@ -540,7 +542,7 @@ func (d *Database) UpdateIgnoresForUser(ctx context.Context, userID string, igno }) } -func (d *Database) UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error) { +func (d *Database) UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS spec.Timestamp, fromSync bool) (types.StreamPosition, error) { var pos types.StreamPosition var err error _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { @@ -560,16 +562,21 @@ func (d *Database) SelectMembershipForUser(ctx context.Context, roomID, userID s func (d *Database) ReIndex(ctx context.Context, limit, afterID int64) (map[int64]gomatrixserverlib.HeaderedEvent, error) { return d.OutputEvents.ReIndex(ctx, nil, limit, afterID, []string{ - gomatrixserverlib.MRoomName, - gomatrixserverlib.MRoomTopic, + spec.MRoomName, + spec.MRoomTopic, "m.room.message", }) } func (d *Database) UpdateRelations(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error { + // No need to unmarshal if the event is a redaction + if event.Type() == spec.MRoomRedaction { + return nil + } var content gomatrixserverlib.RelationContent if err := json.Unmarshal(event.Content(), &content); err != nil { - return fmt.Errorf("json.Unmarshal: %w", err) + logrus.WithError(err).Error("unable to unmarshal relation content") + return nil } switch { case content.Relations == nil: @@ -578,8 +585,6 @@ func (d *Database) UpdateRelations(ctx context.Context, event *gomatrixserverlib return nil case content.Relations.RelationType == "": return nil - case event.Type() == gomatrixserverlib.MRoomRedaction: - return nil default: return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Relations.InsertRelation( diff --git a/syncapi/storage/shared/storage_consumer_test.go b/syncapi/storage/shared/storage_consumer_test.go new file mode 100644 index 000000000..e5f734c96 --- /dev/null +++ b/syncapi/storage/shared/storage_consumer_test.go @@ -0,0 +1,103 @@ +package shared_test + +import ( + "context" + "reflect" + "testing" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/syncapi/storage" + "github.com/matrix-org/dendrite/syncapi/synctypes" + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" +) + +func newSyncDB(t *testing.T, dbType test.DBType) (storage.Database, func()) { + t.Helper() + + cfg, processCtx, closeDB := testrig.CreateConfig(t, dbType) + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + syncDB, err := storage.NewSyncServerDatasource(processCtx.Context(), cm, &cfg.SyncAPI.Database) + if err != nil { + t.Fatalf("failed to create sync DB: %s", err) + } + + return syncDB, closeDB +} + +func TestFilterTable(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + tab, closeDB := newSyncDB(t, dbType) + defer closeDB() + + // initially create a filter + filter := &synctypes.Filter{} + filterID, err := tab.PutFilter(context.Background(), "alice", filter) + if err != nil { + t.Fatal(err) + } + + // create the same filter again, we should receive the existing filter + secondFilterID, err := tab.PutFilter(context.Background(), "alice", filter) + if err != nil { + t.Fatal(err) + } + + if secondFilterID != filterID { + t.Fatalf("expected second filter to be the same as the first: %s vs %s", filterID, secondFilterID) + } + + // query the filter again + targetFilter := &synctypes.Filter{} + if err = tab.GetFilter(context.Background(), targetFilter, "alice", filterID); err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(filter, targetFilter) { + t.Fatalf("%#v vs %#v", filter, targetFilter) + } + + // query non-existent filter + if err = tab.GetFilter(context.Background(), targetFilter, "bob", filterID); err == nil { + t.Fatalf("expected filter to not exist, but it does exist: %v", targetFilter) + } + }) +} + +func TestIgnores(t *testing.T) { + alice := test.NewUser(t) + bob := test.NewUser(t) + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + syncDB, closeDB := newSyncDB(t, dbType) + defer closeDB() + + tab, err := syncDB.NewDatabaseTransaction(context.Background()) + if err != nil { + t.Fatal(err) + } + defer tab.Rollback() // nolint: errcheck + + ignoredUsers := &types.IgnoredUsers{List: map[string]interface{}{ + bob.ID: "", + }} + if err = tab.UpdateIgnoresForUser(context.Background(), alice.ID, ignoredUsers); err != nil { + t.Fatal(err) + } + + gotIgnoredUsers, err := tab.IgnoresForUser(context.Background(), alice.ID) + if err != nil { + t.Fatal(err) + } + + // verify the ignored users matches those we stored + if !reflect.DeepEqual(gotIgnoredUsers, ignoredUsers) { + t.Fatalf("%#v vs %#v", gotIgnoredUsers, ignoredUsers) + } + + // Bob doesn't have any ignored users, so should receive sql.ErrNoRows + if _, err = tab.IgnoresForUser(context.Background(), bob.ID); err == nil { + t.Fatalf("expected an error but got none") + } + }) +} diff --git a/syncapi/storage/shared/storage_sync.go b/syncapi/storage/shared/storage_sync.go index 931bc9e23..f11cbb574 100644 --- a/syncapi/storage/shared/storage_sync.go +++ b/syncapi/storage/shared/storage_sync.go @@ -7,9 +7,11 @@ import ( "math" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/tidwall/gjson" "github.com/matrix-org/dendrite/internal/eventutil" + "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" ) @@ -82,7 +84,7 @@ func (d *DatabaseTransaction) MaxStreamPositionForNotificationData(ctx context.C return types.StreamPosition(id), nil } -func (d *DatabaseTransaction) CurrentState(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter, excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) { +func (d *DatabaseTransaction) CurrentState(ctx context.Context, roomID string, stateFilterPart *synctypes.StateFilter, excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) { return d.CurrentRoomState.SelectCurrentState(ctx, d.txn, roomID, stateFilterPart, excludeEventIDs) } @@ -97,11 +99,11 @@ func (d *DatabaseTransaction) MembershipCount(ctx context.Context, roomID, membe func (d *DatabaseTransaction) GetRoomSummary(ctx context.Context, roomID, userID string) (*types.Summary, error) { summary := &types.Summary{Heroes: []string{}} - joinCount, err := d.CurrentRoomState.SelectMembershipCount(ctx, d.txn, roomID, gomatrixserverlib.Join) + joinCount, err := d.CurrentRoomState.SelectMembershipCount(ctx, d.txn, roomID, spec.Join) if err != nil { return summary, err } - inviteCount, err := d.CurrentRoomState.SelectMembershipCount(ctx, d.txn, roomID, gomatrixserverlib.Invite) + inviteCount, err := d.CurrentRoomState.SelectMembershipCount(ctx, d.txn, roomID, spec.Invite) if err != nil { return summary, err } @@ -109,8 +111,8 @@ func (d *DatabaseTransaction) GetRoomSummary(ctx context.Context, roomID, userID summary.JoinedMemberCount = &joinCount // Get the room name and canonical alias, if any - filter := gomatrixserverlib.DefaultStateFilter() - filterTypes := []string{gomatrixserverlib.MRoomName, gomatrixserverlib.MRoomCanonicalAlias} + filter := synctypes.DefaultStateFilter() + filterTypes := []string{spec.MRoomName, spec.MRoomCanonicalAlias} filterRooms := []string{roomID} filter.Types = &filterTypes @@ -122,11 +124,11 @@ func (d *DatabaseTransaction) GetRoomSummary(ctx context.Context, roomID, userID for _, ev := range evs { switch ev.Type() { - case gomatrixserverlib.MRoomName: + case spec.MRoomName: if gjson.GetBytes(ev.Content(), "name").Str != "" { return summary, nil } - case gomatrixserverlib.MRoomCanonicalAlias: + case spec.MRoomCanonicalAlias: if gjson.GetBytes(ev.Content(), "alias").Str != "" { return summary, nil } @@ -134,14 +136,14 @@ func (d *DatabaseTransaction) GetRoomSummary(ctx context.Context, roomID, userID } // If there's no room name or canonical alias, get the room heroes, excluding the user - heroes, err := d.CurrentRoomState.SelectRoomHeroes(ctx, d.txn, roomID, userID, []string{gomatrixserverlib.Join, gomatrixserverlib.Invite}) + heroes, err := d.CurrentRoomState.SelectRoomHeroes(ctx, d.txn, roomID, userID, []string{spec.Join, spec.Invite}) if err != nil { return summary, err } // "When no joined or invited members are available, this should consist of the banned and left users" if len(heroes) == 0 { - heroes, err = d.CurrentRoomState.SelectRoomHeroes(ctx, d.txn, roomID, userID, []string{gomatrixserverlib.Leave, gomatrixserverlib.Ban}) + heroes, err = d.CurrentRoomState.SelectRoomHeroes(ctx, d.txn, roomID, userID, []string{spec.Leave, spec.Ban}) if err != nil { return summary, err } @@ -151,7 +153,7 @@ func (d *DatabaseTransaction) GetRoomSummary(ctx context.Context, roomID, userID return summary, nil } -func (d *DatabaseTransaction) RecentEvents(ctx context.Context, roomIDs []string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) (map[string]types.RecentEvents, error) { +func (d *DatabaseTransaction) RecentEvents(ctx context.Context, roomIDs []string, r types.Range, eventFilter *synctypes.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) (map[string]types.RecentEvents, error) { return d.OutputEvents.SelectRecentEvents(ctx, d.txn, roomIDs, r, eventFilter, chronologicalOrder, onlySyncEvents) } @@ -210,7 +212,7 @@ func (d *DatabaseTransaction) GetStateEvent( } func (d *DatabaseTransaction) GetStateEventsForRoom( - ctx context.Context, roomID string, stateFilter *gomatrixserverlib.StateFilter, + ctx context.Context, roomID string, stateFilter *synctypes.StateFilter, ) (stateEvents []*gomatrixserverlib.HeaderedEvent, err error) { stateEvents, err = d.CurrentRoomState.SelectCurrentState(ctx, d.txn, roomID, stateFilter, nil) return @@ -223,7 +225,7 @@ func (d *DatabaseTransaction) GetStateEventsForRoom( // If there was an issue with the retrieval, returns an error func (d *DatabaseTransaction) GetAccountDataInRange( ctx context.Context, userID string, r types.Range, - accountDataFilterPart *gomatrixserverlib.EventFilter, + accountDataFilterPart *synctypes.EventFilter, ) (map[string][]string, types.StreamPosition, error) { return d.AccountData.SelectAccountDataInRange(ctx, d.txn, userID, r, accountDataFilterPart) } @@ -232,7 +234,7 @@ func (d *DatabaseTransaction) GetEventsInTopologicalRange( ctx context.Context, from, to *types.TopologyToken, roomID string, - filter *gomatrixserverlib.RoomEventFilter, + filter *synctypes.RoomEventFilter, backwardOrdering bool, ) (events []types.StreamEvent, err error) { var minDepth, maxDepth, maxStreamPosForMaxDepth types.StreamPosition @@ -323,7 +325,7 @@ func (d *DatabaseTransaction) GetBackwardTopologyPos( func (d *DatabaseTransaction) GetStateDeltas( ctx context.Context, device *userapi.Device, r types.Range, userID string, - stateFilter *gomatrixserverlib.StateFilter, + stateFilter *synctypes.StateFilter, ) (deltas []types.StateDelta, joinedRoomsIDs []string, err error) { // Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821 // - Get membership list changes for this user in this sync response @@ -348,7 +350,7 @@ func (d *DatabaseTransaction) GetStateDeltas( joinedRoomIDs := make([]string, 0, len(memberships)) for roomID, membership := range memberships { allRoomIDs = append(allRoomIDs, roomID) - if membership == gomatrixserverlib.Join { + if membership == spec.Join { joinedRoomIDs = append(joinedRoomIDs, roomID) } } @@ -414,7 +416,7 @@ func (d *DatabaseTransaction) GetStateDeltas( } if !peek.Deleted { deltas = append(deltas, types.StateDelta{ - Membership: gomatrixserverlib.Peek, + Membership: spec.Peek, StateEvents: d.StreamEventsToEvents(device, state[peek.RoomID]), RoomID: peek.RoomID, }) @@ -432,7 +434,7 @@ func (d *DatabaseTransaction) GetStateDeltas( continue } - if membership == gomatrixserverlib.Join { + if membership == spec.Join { // If our membership is now join but the previous membership wasn't // then this is a "join transition", so we'll insert this room. if prevMembership != membership { @@ -471,7 +473,7 @@ func (d *DatabaseTransaction) GetStateDeltas( // join transitions above. for _, joinedRoomID := range joinedRoomIDs { deltas = append(deltas, types.StateDelta{ - Membership: gomatrixserverlib.Join, + Membership: spec.Join, StateEvents: d.StreamEventsToEvents(device, stateFiltered[joinedRoomID]), RoomID: joinedRoomID, NewlyJoined: newlyJoinedRooms[joinedRoomID], @@ -488,7 +490,7 @@ func (d *DatabaseTransaction) GetStateDeltas( func (d *DatabaseTransaction) GetStateDeltasForFullStateSync( ctx context.Context, device *userapi.Device, r types.Range, userID string, - stateFilter *gomatrixserverlib.StateFilter, + stateFilter *synctypes.StateFilter, ) ([]types.StateDelta, []string, error) { // Look up all memberships for the user. We only care about rooms that a // user has ever interacted with — joined to, kicked/banned from, left. @@ -504,7 +506,7 @@ func (d *DatabaseTransaction) GetStateDeltasForFullStateSync( joinedRoomIDs := make([]string, 0, len(memberships)) for roomID, membership := range memberships { allRoomIDs = append(allRoomIDs, roomID) - if membership == gomatrixserverlib.Join { + if membership == spec.Join { joinedRoomIDs = append(joinedRoomIDs, roomID) } } @@ -528,7 +530,7 @@ func (d *DatabaseTransaction) GetStateDeltasForFullStateSync( return nil, nil, stateErr } deltas[peek.RoomID] = types.StateDelta{ - Membership: gomatrixserverlib.Peek, + Membership: spec.Peek, StateEvents: d.StreamEventsToEvents(device, s), RoomID: peek.RoomID, } @@ -554,7 +556,7 @@ func (d *DatabaseTransaction) GetStateDeltasForFullStateSync( for roomID, stateStreamEvents := range state { for _, ev := range stateStreamEvents { if membership, _ := getMembershipFromEvent(ev.Event, userID); membership != "" { - if membership != gomatrixserverlib.Join { // We've already added full state for all joined rooms above. + if membership != spec.Join { // We've already added full state for all joined rooms above. deltas[roomID] = types.StateDelta{ Membership: membership, MembershipPos: ev.StreamPosition, @@ -578,7 +580,7 @@ func (d *DatabaseTransaction) GetStateDeltasForFullStateSync( return nil, nil, stateErr } deltas[joinedRoomID] = types.StateDelta{ - Membership: gomatrixserverlib.Join, + Membership: spec.Join, StateEvents: d.StreamEventsToEvents(device, s), RoomID: joinedRoomID, } @@ -597,7 +599,7 @@ func (d *DatabaseTransaction) GetStateDeltasForFullStateSync( func (d *DatabaseTransaction) currentStateStreamEventsForRoom( ctx context.Context, roomID string, - stateFilter *gomatrixserverlib.StateFilter, + stateFilter *synctypes.StateFilter, ) ([]types.StreamEvent, error) { allState, err := d.CurrentRoomState.SelectCurrentState(ctx, d.txn, roomID, stateFilter, nil) if err != nil { @@ -635,7 +637,7 @@ func (d *DatabaseTransaction) GetRoomReceipts(ctx context.Context, roomIDs []str func (d *DatabaseTransaction) GetUserUnreadNotificationCountsForRooms(ctx context.Context, userID string, rooms map[string]string) (map[string]*eventutil.NotificationData, error) { roomIDs := make([]string, 0, len(rooms)) for roomID, membership := range rooms { - if membership != gomatrixserverlib.Join { + if membership != spec.Join { continue } roomIDs = append(roomIDs, roomID) @@ -647,7 +649,7 @@ func (d *DatabaseTransaction) GetPresences(ctx context.Context, userIDs []string return d.Presence.GetPresenceForUsers(ctx, d.txn, userIDs) } -func (d *DatabaseTransaction) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) { +func (d *DatabaseTransaction) PresenceAfter(ctx context.Context, after types.StreamPosition, filter synctypes.EventFilter) (map[string]*types.PresenceInternal, error) { return d.Presence.GetPresenceAfter(ctx, d.txn, after, filter) } @@ -707,7 +709,7 @@ func (d *DatabaseTransaction) MaxStreamPositionForRelations(ctx context.Context) return types.StreamPosition(id), err } -func isStatefilterEmpty(filter *gomatrixserverlib.StateFilter) bool { +func isStatefilterEmpty(filter *synctypes.StateFilter) bool { if filter == nil { return true } diff --git a/syncapi/storage/shared/storage_sync_test.go b/syncapi/storage/shared/storage_sync_test.go index c56720db7..4468a7728 100644 --- a/syncapi/storage/shared/storage_sync_test.go +++ b/syncapi/storage/shared/storage_sync_test.go @@ -3,7 +3,7 @@ package shared import ( "testing" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/syncapi/synctypes" ) func Test_isStatefilterEmpty(t *testing.T) { @@ -12,7 +12,7 @@ func Test_isStatefilterEmpty(t *testing.T) { tests := []struct { name string - filter *gomatrixserverlib.StateFilter + filter *synctypes.StateFilter want bool }{ { @@ -22,42 +22,42 @@ func Test_isStatefilterEmpty(t *testing.T) { }, { name: "Empty filter is empty", - filter: &gomatrixserverlib.StateFilter{}, + filter: &synctypes.StateFilter{}, want: true, }, { name: "NotTypes is set", - filter: &gomatrixserverlib.StateFilter{ + filter: &synctypes.StateFilter{ NotTypes: &filterSet, }, }, { name: "Types is set", - filter: &gomatrixserverlib.StateFilter{ + filter: &synctypes.StateFilter{ Types: &filterSet, }, }, { name: "Senders is set", - filter: &gomatrixserverlib.StateFilter{ + filter: &synctypes.StateFilter{ Senders: &filterSet, }, }, { name: "NotSenders is set", - filter: &gomatrixserverlib.StateFilter{ + filter: &synctypes.StateFilter{ NotSenders: &filterSet, }, }, { name: "NotRooms is set", - filter: &gomatrixserverlib.StateFilter{ + filter: &synctypes.StateFilter{ NotRooms: &filterSet, }, }, { name: "ContainsURL is set", - filter: &gomatrixserverlib.StateFilter{ + filter: &synctypes.StateFilter{ ContainsURL: &boolValue, }, }, diff --git a/syncapi/storage/sqlite3/account_data_table.go b/syncapi/storage/sqlite3/account_data_table.go index de0e72dbd..b49c2f701 100644 --- a/syncapi/storage/sqlite3/account_data_table.go +++ b/syncapi/storage/sqlite3/account_data_table.go @@ -22,8 +22,8 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" + "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrixserverlib" ) const accountDataSchema = ` @@ -66,16 +66,11 @@ func NewSqliteAccountDataTable(db *sql.DB, streamID *StreamIDStatements) (tables if err != nil { return nil, err } - if s.insertAccountDataStmt, err = db.Prepare(insertAccountDataSQL); err != nil { - return nil, err - } - if s.selectMaxAccountDataIDStmt, err = db.Prepare(selectMaxAccountDataIDSQL); err != nil { - return nil, err - } - if s.selectAccountDataInRangeStmt, err = db.Prepare(selectAccountDataInRangeSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.insertAccountDataStmt, insertAccountDataSQL}, + {&s.selectMaxAccountDataIDStmt, selectMaxAccountDataIDSQL}, + {&s.selectAccountDataInRangeStmt, selectAccountDataInRangeSQL}, + }.Prepare(db) } func (s *accountDataStatements) InsertAccountData( @@ -94,7 +89,7 @@ func (s *accountDataStatements) SelectAccountDataInRange( ctx context.Context, txn *sql.Tx, userID string, r types.Range, - filter *gomatrixserverlib.EventFilter, + filter *synctypes.EventFilter, ) (data map[string][]string, pos types.StreamPosition, err error) { data = make(map[string][]string) stmt, params, err := prepareWithFilters( diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go index 35b746c5c..01dcff693 100644 --- a/syncapi/storage/sqlite3/current_room_state_table.go +++ b/syncapi/storage/sqlite3/current_room_state_table.go @@ -27,8 +27,10 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/sqlite3/deltas" "github.com/matrix-org/dendrite/syncapi/storage/tables" + "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const currentRoomStateSchema = ` @@ -264,14 +266,14 @@ func (s *currentRoomStateStatements) SelectRoomIDsWithAnyMembership( // CurrentState returns all the current state events for the given room. func (s *currentRoomStateStatements) SelectCurrentState( ctx context.Context, txn *sql.Tx, roomID string, - stateFilter *gomatrixserverlib.StateFilter, + stateFilter *synctypes.StateFilter, excludeEventIDs []string, ) ([]*gomatrixserverlib.HeaderedEvent, error) { // We're going to query members later, so remove them from this request if stateFilter.LazyLoadMembers && !stateFilter.IncludeRedundantMembers { - notTypes := &[]string{gomatrixserverlib.MRoomMember} + notTypes := &[]string{spec.MRoomMember} if stateFilter.NotTypes != nil { - *stateFilter.NotTypes = append(*stateFilter.NotTypes, gomatrixserverlib.MRoomMember) + *stateFilter.NotTypes = append(*stateFilter.NotTypes, spec.MRoomMember) } else { stateFilter.NotTypes = notTypes } diff --git a/syncapi/storage/sqlite3/filter_table.go b/syncapi/storage/sqlite3/filter_table.go index 5f1e980eb..8da4f299c 100644 --- a/syncapi/storage/sqlite3/filter_table.go +++ b/syncapi/storage/sqlite3/filter_table.go @@ -22,6 +22,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" + "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/gomatrixserverlib" ) @@ -65,20 +66,15 @@ func NewSqliteFilterTable(db *sql.DB) (tables.Filter, error) { s := &filterStatements{ db: db, } - if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil { - return nil, err - } - if s.selectFilterIDByContentStmt, err = db.Prepare(selectFilterIDByContentSQL); err != nil { - return nil, err - } - if s.insertFilterStmt, err = db.Prepare(insertFilterSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.selectFilterStmt, selectFilterSQL}, + {&s.selectFilterIDByContentStmt, selectFilterIDByContentSQL}, + {&s.insertFilterStmt, insertFilterSQL}, + }.Prepare(db) } func (s *filterStatements) SelectFilter( - ctx context.Context, txn *sql.Tx, target *gomatrixserverlib.Filter, localpart string, filterID string, + ctx context.Context, txn *sql.Tx, target *synctypes.Filter, localpart string, filterID string, ) error { // Retrieve filter from database (stored as canonical JSON) var filterData []byte @@ -95,7 +91,7 @@ func (s *filterStatements) SelectFilter( } func (s *filterStatements) InsertFilter( - ctx context.Context, txn *sql.Tx, filter *gomatrixserverlib.Filter, localpart string, + ctx context.Context, txn *sql.Tx, filter *synctypes.Filter, localpart string, ) (filterID string, err error) { var existingFilterID string diff --git a/syncapi/storage/sqlite3/ignores_table.go b/syncapi/storage/sqlite3/ignores_table.go index 5ee1a9fa0..bff5780b0 100644 --- a/syncapi/storage/sqlite3/ignores_table.go +++ b/syncapi/storage/sqlite3/ignores_table.go @@ -52,13 +52,10 @@ func NewSqliteIgnoresTable(db *sql.DB) (tables.Ignores, error) { return nil, err } s := &ignoresStatements{} - if s.selectIgnoresStmt, err = db.Prepare(selectIgnoresSQL); err != nil { - return nil, err - } - if s.upsertIgnoresStmt, err = db.Prepare(upsertIgnoresSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.selectIgnoresStmt, selectIgnoresSQL}, + {&s.upsertIgnoresStmt, upsertIgnoresSQL}, + }.Prepare(db) } func (s *ignoresStatements) SelectIgnores( diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go index 23bc68a41..b5b6ea883 100644 --- a/syncapi/storage/sqlite3/output_room_events_table.go +++ b/syncapi/storage/sqlite3/output_room_events_table.go @@ -27,8 +27,8 @@ import ( "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/storage/sqlite3/deltas" "github.com/matrix-org/dendrite/syncapi/storage/tables" + "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/internal/sqlutil" @@ -81,12 +81,6 @@ const selectRecentEventsForSyncSQL = "" + // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters -const selectEarlyEventsSQL = "" + - "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id, history_visibility FROM syncapi_output_room_events" + - " WHERE room_id = $1 AND id > $2 AND id <= $3" - -// WHEN, ORDER BY and LIMIT are appended by prepareWithFilters - const selectMaxEventIDSQL = "" + "SELECT MAX(id) FROM syncapi_output_room_events" @@ -118,7 +112,7 @@ const selectContextAfterEventSQL = "" + // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters -const selectSearchSQL = "SELECT id, event_id, headered_event_json FROM syncapi_output_room_events WHERE type IN ($1) AND id > $2 LIMIT $3 ORDER BY id ASC" +const selectSearchSQL = "SELECT id, event_id, headered_event_json FROM syncapi_output_room_events WHERE id > $1 AND type IN ($2)" const purgeEventsSQL = "" + "DELETE FROM syncapi_output_room_events WHERE room_id = $1" @@ -186,7 +180,7 @@ func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, txn *s // two positions, only the most recent state is returned. func (s *outputRoomEventsStatements) SelectStateInRange( ctx context.Context, txn *sql.Tx, r types.Range, - stateFilter *gomatrixserverlib.StateFilter, roomIDs []string, + stateFilter *synctypes.StateFilter, roomIDs []string, ) (map[string]map[string]bool, map[string]types.StreamEvent, error) { stmtSQL := strings.Replace(selectStateInRangeSQL, "($3)", sqlutil.QueryVariadicOffset(len(roomIDs), 2), 1) inputParams := []interface{}{ @@ -368,7 +362,7 @@ func (s *outputRoomEventsStatements) InsertEvent( func (s *outputRoomEventsStatements) SelectRecentEvents( ctx context.Context, txn *sql.Tx, - roomIDs []string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, + roomIDs []string, r types.Range, eventFilter *synctypes.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool, ) (map[string]types.RecentEvents, error) { var query string @@ -429,46 +423,10 @@ func (s *outputRoomEventsStatements) SelectRecentEvents( return result, nil } -func (s *outputRoomEventsStatements) SelectEarlyEvents( - ctx context.Context, txn *sql.Tx, - roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, -) ([]types.StreamEvent, error) { - stmt, params, err := prepareWithFilters( - s.db, txn, selectEarlyEventsSQL, - []interface{}{ - roomID, r.Low(), r.High(), - }, - eventFilter.Senders, eventFilter.NotSenders, - eventFilter.Types, eventFilter.NotTypes, - nil, eventFilter.ContainsURL, eventFilter.Limit, FilterOrderAsc, - ) - if err != nil { - return nil, fmt.Errorf("s.prepareWithFilters: %w", err) - } - defer internal.CloseAndLogIfError(ctx, stmt, "SelectEarlyEvents: stmt.close() failed") - - rows, err := stmt.QueryContext(ctx, params...) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "selectEarlyEvents: rows.close() failed") - events, err := rowsToStreamEvents(rows) - if err != nil { - return nil, err - } - // The events need to be returned from oldest to latest, which isn't - // necessarily the way the SQL query returns them, so a sort is necessary to - // ensure the events are in the right order in the slice. - sort.SliceStable(events, func(i int, j int) bool { - return events[i].StreamPosition < events[j].StreamPosition - }) - return events, nil -} - // selectEvents returns the events for the given event IDs. If an event is // missing from the database, it will be omitted. func (s *outputRoomEventsStatements) SelectEvents( - ctx context.Context, txn *sql.Tx, eventIDs []string, filter *gomatrixserverlib.RoomEventFilter, preserveOrder bool, + ctx context.Context, txn *sql.Tx, eventIDs []string, filter *synctypes.RoomEventFilter, preserveOrder bool, ) ([]types.StreamEvent, error) { iEventIDs := make([]interface{}, len(eventIDs)) for i := range eventIDs { @@ -477,7 +435,7 @@ func (s *outputRoomEventsStatements) SelectEvents( selectSQL := strings.Replace(selectEventsSQL, "($1)", sqlutil.QueryVariadic(len(eventIDs)), 1) if filter == nil { - filter = &gomatrixserverlib.RoomEventFilter{Limit: 20} + filter = &synctypes.RoomEventFilter{Limit: 20} } stmt, params, err := prepareWithFilters( s.db, txn, selectSQL, iEventIDs, @@ -581,7 +539,7 @@ func (s *outputRoomEventsStatements) SelectContextEvent( } func (s *outputRoomEventsStatements) SelectContextBeforeEvent( - ctx context.Context, txn *sql.Tx, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter, + ctx context.Context, txn *sql.Tx, id int, roomID string, filter *synctypes.RoomEventFilter, ) (evts []*gomatrixserverlib.HeaderedEvent, err error) { stmt, params, err := prepareWithFilters( s.db, txn, selectContextBeforeEventSQL, @@ -623,7 +581,7 @@ func (s *outputRoomEventsStatements) SelectContextBeforeEvent( } func (s *outputRoomEventsStatements) SelectContextAfterEvent( - ctx context.Context, txn *sql.Tx, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter, + ctx context.Context, txn *sql.Tx, id int, roomID string, filter *synctypes.RoomEventFilter, ) (lastID int, evts []*gomatrixserverlib.HeaderedEvent, err error) { stmt, params, err := prepareWithFilters( s.db, txn, selectContextAfterEventSQL, @@ -685,18 +643,18 @@ func (s *outputRoomEventsStatements) PurgeEvents( } func (s *outputRoomEventsStatements) ReIndex(ctx context.Context, txn *sql.Tx, limit, afterID int64, types []string) (map[int64]gomatrixserverlib.HeaderedEvent, error) { - params := make([]interface{}, len(types)) + params := make([]interface{}, len(types)+1) + params[0] = afterID for i := range types { - params[i] = types[i] + params[i+1] = types[i] } - params = append(params, afterID) - params = append(params, limit) - selectSQL := strings.Replace(selectSearchSQL, "($1)", sqlutil.QueryVariadic(len(types)), 1) - stmt, err := s.db.Prepare(selectSQL) + selectSQL := strings.Replace(selectSearchSQL, "($2)", sqlutil.QueryVariadicOffset(len(types), 1), 1) + stmt, params, err := prepareWithFilters(s.db, txn, selectSQL, params, nil, nil, nil, nil, nil, nil, int(limit), FilterOrderAsc) if err != nil { return nil, err } + defer internal.CloseAndLogIfError(ctx, stmt, "selectEvents: stmt.close() failed") rows, err := sqlutil.TxStmt(txn, stmt).QueryContext(ctx, params...) if err != nil { diff --git a/syncapi/storage/sqlite3/presence_table.go b/syncapi/storage/sqlite3/presence_table.go index 7641de92f..573fbad6c 100644 --- a/syncapi/storage/sqlite3/presence_table.go +++ b/syncapi/storage/sqlite3/presence_table.go @@ -20,10 +20,11 @@ import ( "strings" "time" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/types" ) @@ -112,7 +113,7 @@ func (p *presenceStatements) UpsertPresence( userID string, statusMsg *string, presence types.Presence, - lastActiveTS gomatrixserverlib.Timestamp, + lastActiveTS spec.Timestamp, fromSync bool, ) (pos types.StreamPosition, err error) { pos, err = p.streamIDStatements.nextPresenceID(ctx, txn) @@ -180,11 +181,11 @@ func (p *presenceStatements) GetMaxPresenceID(ctx context.Context, txn *sql.Tx) // GetPresenceAfter returns the changes presences after a given stream id func (p *presenceStatements) GetPresenceAfter( ctx context.Context, txn *sql.Tx, - after types.StreamPosition, filter gomatrixserverlib.EventFilter, + after types.StreamPosition, filter synctypes.EventFilter, ) (presences map[string]*types.PresenceInternal, err error) { presences = make(map[string]*types.PresenceInternal) stmt := sqlutil.TxStmt(txn, p.selectPresenceAfterStmt) - afterTS := gomatrixserverlib.AsTimestamp(time.Now().Add(time.Minute * -5)) + afterTS := spec.AsTimestamp(time.Now().Add(time.Minute * -5)) rows, err := stmt.QueryContext(ctx, after, afterTS, filter.Limit) if err != nil { return nil, err diff --git a/syncapi/storage/sqlite3/receipt_table.go b/syncapi/storage/sqlite3/receipt_table.go index ca3d80fb4..b973903bd 100644 --- a/syncapi/storage/sqlite3/receipt_table.go +++ b/syncapi/storage/sqlite3/receipt_table.go @@ -25,7 +25,7 @@ import ( "github.com/matrix-org/dendrite/syncapi/storage/sqlite3/deltas" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const receiptsSchema = ` @@ -97,7 +97,7 @@ func NewSqliteReceiptsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Re } // UpsertReceipt creates new user receipts -func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error) { +func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, roomId, receiptType, userId, eventId string, timestamp spec.Timestamp) (pos types.StreamPosition, err error) { pos, err = r.streamIDStatements.nextReceiptID(ctx, txn) if err != nil { return diff --git a/syncapi/storage/sqlite3/send_to_device_table.go b/syncapi/storage/sqlite3/send_to_device_table.go index 0da06506c..998a063cd 100644 --- a/syncapi/storage/sqlite3/send_to_device_table.go +++ b/syncapi/storage/sqlite3/send_to_device_table.go @@ -88,19 +88,12 @@ func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) { if err != nil { return nil, err } - if s.insertSendToDeviceMessageStmt, err = db.Prepare(insertSendToDeviceMessageSQL); err != nil { - return nil, err - } - if s.selectSendToDeviceMessagesStmt, err = db.Prepare(selectSendToDeviceMessagesSQL); err != nil { - return nil, err - } - if s.deleteSendToDeviceMessagesStmt, err = db.Prepare(deleteSendToDeviceMessagesSQL); err != nil { - return nil, err - } - if s.selectMaxSendToDeviceIDStmt, err = db.Prepare(selectMaxSendToDeviceIDSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.insertSendToDeviceMessageStmt, insertSendToDeviceMessageSQL}, + {&s.selectSendToDeviceMessagesStmt, selectSendToDeviceMessagesSQL}, + {&s.deleteSendToDeviceMessagesStmt, deleteSendToDeviceMessagesSQL}, + {&s.selectMaxSendToDeviceIDStmt, selectMaxSendToDeviceIDSQL}, + }.Prepare(db) } func (s *sendToDeviceStatements) InsertSendToDeviceMessage( diff --git a/syncapi/storage/sqlite3/stream_id_table.go b/syncapi/storage/sqlite3/stream_id_table.go index a4bba508e..b51eccf55 100644 --- a/syncapi/storage/sqlite3/stream_id_table.go +++ b/syncapi/storage/sqlite3/stream_id_table.go @@ -47,10 +47,9 @@ func (s *StreamIDStatements) Prepare(db *sql.DB) (err error) { if err != nil { return } - if s.increaseStreamIDStmt, err = db.Prepare(increaseStreamIDStmt); err != nil { - return - } - return + return sqlutil.StatementList{ + {&s.increaseStreamIDStmt, increaseStreamIDStmt}, + }.Prepare(db) } func (s *StreamIDStatements) nextPDUID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { diff --git a/syncapi/storage/sqlite3/syncserver.go b/syncapi/storage/sqlite3/syncserver.go index 510546909..3f1ca355e 100644 --- a/syncapi/storage/sqlite3/syncserver.go +++ b/syncapi/storage/sqlite3/syncserver.go @@ -20,7 +20,6 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/syncapi/storage/shared" "github.com/matrix-org/dendrite/syncapi/storage/sqlite3/deltas" @@ -37,13 +36,14 @@ type SyncServerDatasource struct { // NewDatabase creates a new sync server database // nolint: gocyclo -func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*SyncServerDatasource, error) { +func NewDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperties *config.DatabaseOptions) (*SyncServerDatasource, error) { var d SyncServerDatasource var err error - if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()); err != nil { + + if d.db, d.writer, err = conMan.Connection(dbProperties); err != nil { return nil, err } - if err = d.prepare(base.Context()); err != nil { + if err = d.prepare(ctx); err != nil { return nil, err } return &d, nil diff --git a/syncapi/storage/storage.go b/syncapi/storage/storage.go index 5b20c6cc2..8714ec5e2 100644 --- a/syncapi/storage/storage.go +++ b/syncapi/storage/storage.go @@ -18,21 +18,22 @@ package storage import ( + "context" "fmt" - "github.com/matrix-org/dendrite/setup/base" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/syncapi/storage/postgres" "github.com/matrix-org/dendrite/syncapi/storage/sqlite3" ) // NewSyncServerDatasource opens a database connection. -func NewSyncServerDatasource(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (Database, error) { +func NewSyncServerDatasource(ctx context.Context, conMan sqlutil.Connections, dbProperties *config.DatabaseOptions) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(base, dbProperties) + return sqlite3.NewDatabase(ctx, conMan, dbProperties) case dbProperties.ConnectionString.IsPostgres(): - return postgres.NewDatabase(base, dbProperties) + return postgres.NewDatabase(ctx, conMan, dbProperties) default: return nil, fmt.Errorf("unexpected database type") } diff --git a/syncapi/storage/storage_test.go b/syncapi/storage/storage_test.go index 05d498bc2..2cc1378bf 100644 --- a/syncapi/storage/storage_test.go +++ b/syncapi/storage/storage_test.go @@ -9,27 +9,29 @@ import ( "reflect" "testing" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/syncapi/storage" + "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/test" - "github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/stretchr/testify/assert" ) var ctx = context.Background() -func MustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func(), func()) { +func MustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { connStr, close := test.PrepareDBConnectionString(t, dbType) - base, closeBase := testrig.CreateBaseDendrite(t, dbType) - db, err := storage.NewSyncServerDatasource(base, &config.DatabaseOptions{ + cm := sqlutil.NewConnectionManager(nil, config.DatabaseOptions{}) + db, err := storage.NewSyncServerDatasource(context.Background(), cm, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }) if err != nil { t.Fatalf("NewSyncServerDatasource returned %s", err) } - return db, close, closeBase + return db, close } func MustWriteEvents(t *testing.T, db storage.Database, events []*gomatrixserverlib.HeaderedEvent) (positions []types.StreamPosition) { @@ -55,9 +57,8 @@ func TestWriteEvents(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { alice := test.NewUser(t) r := test.NewRoom(t, alice) - db, close, closeBase := MustCreateDatabase(t, dbType) + db, close := MustCreateDatabase(t, dbType) defer close() - defer closeBase() MustWriteEvents(t, db, r.Events()) }) } @@ -76,9 +77,8 @@ func WithSnapshot(t *testing.T, db storage.Database, f func(snapshot storage.Dat // These tests assert basic functionality of RecentEvents for PDUs func TestRecentEventsPDU(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close, closeBase := MustCreateDatabase(t, dbType) + db, close := MustCreateDatabase(t, dbType) defer close() - defer closeBase() alice := test.NewUser(t) // dummy room to make sure SQL queries are filtering on room ID MustWriteEvents(t, db, test.NewRoom(t, alice).Events()) @@ -155,7 +155,7 @@ func TestRecentEventsPDU(t *testing.T) { for i := range testCases { tc := testCases[i] t.Run(tc.Name, func(st *testing.T) { - var filter gomatrixserverlib.RoomEventFilter + var filter synctypes.RoomEventFilter var gotEvents map[string]types.RecentEvents var limited bool filter.Limit = tc.Limit @@ -191,9 +191,8 @@ func TestRecentEventsPDU(t *testing.T) { // The purpose of this test is to ensure that backfill does indeed go backwards, using a topology token func TestGetEventsInRangeWithTopologyToken(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close, closeBase := MustCreateDatabase(t, dbType) + db, close := MustCreateDatabase(t, dbType) defer close() - defer closeBase() alice := test.NewUser(t) r := test.NewRoom(t, alice) for i := 0; i < 10; i++ { @@ -209,7 +208,7 @@ func TestGetEventsInRangeWithTopologyToken(t *testing.T) { to := types.TopologyToken{} // backpaginate 5 messages starting at the latest position. - filter := &gomatrixserverlib.RoomEventFilter{Limit: 5} + filter := &synctypes.RoomEventFilter{Limit: 5} paginatedEvents, err := snapshot.GetEventsInTopologicalRange(ctx, &from, &to, r.ID, filter, true) if err != nil { t.Fatalf("GetEventsInTopologicalRange returned an error: %s", err) @@ -276,9 +275,8 @@ func TestStreamToTopologicalPosition(t *testing.T) { } test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close, closeBase := MustCreateDatabase(t, dbType) + db, close := MustCreateDatabase(t, dbType) defer close() - defer closeBase() txn, err := db.NewDatabaseTransaction(ctx) if err != nil { @@ -514,9 +512,8 @@ func TestSendToDeviceBehaviour(t *testing.T) { bob := test.NewUser(t) deviceID := "one" test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close, closeBase := MustCreateDatabase(t, dbType) + db, close := MustCreateDatabase(t, dbType) defer close() - defer closeBase() // At this point there should be no messages. We haven't sent anything // yet. @@ -781,7 +778,7 @@ func TestRoomSummary(t *testing.T) { name: "invited user", wantSummary: &types.Summary{JoinedMemberCount: pointer(1), InvitedMemberCount: pointer(1), Heroes: []string{bob.ID}}, additionalEvents: func(t *testing.T, room *test.Room) { - room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, alice, spec.MRoomMember, map[string]interface{}{ "membership": "invite", }, test.WithStateKey(bob.ID)) }, @@ -790,10 +787,10 @@ func TestRoomSummary(t *testing.T) { name: "invited user, but declined", wantSummary: &types.Summary{JoinedMemberCount: pointer(1), InvitedMemberCount: pointer(0), Heroes: []string{bob.ID}}, additionalEvents: func(t *testing.T, room *test.Room) { - room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, alice, spec.MRoomMember, map[string]interface{}{ "membership": "invite", }, test.WithStateKey(bob.ID)) - room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, bob, spec.MRoomMember, map[string]interface{}{ "membership": "leave", }, test.WithStateKey(bob.ID)) }, @@ -802,10 +799,10 @@ func TestRoomSummary(t *testing.T) { name: "joined user after invitation", wantSummary: &types.Summary{JoinedMemberCount: pointer(2), InvitedMemberCount: pointer(0), Heroes: []string{bob.ID}}, additionalEvents: func(t *testing.T, room *test.Room) { - room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, alice, spec.MRoomMember, map[string]interface{}{ "membership": "invite", }, test.WithStateKey(bob.ID)) - room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, bob, spec.MRoomMember, map[string]interface{}{ "membership": "join", }, test.WithStateKey(bob.ID)) }, @@ -814,10 +811,10 @@ func TestRoomSummary(t *testing.T) { name: "multiple joined user", wantSummary: &types.Summary{JoinedMemberCount: pointer(3), InvitedMemberCount: pointer(0), Heroes: []string{charlie.ID, bob.ID}}, additionalEvents: func(t *testing.T, room *test.Room) { - room.CreateAndInsert(t, charlie, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, charlie, spec.MRoomMember, map[string]interface{}{ "membership": "join", }, test.WithStateKey(charlie.ID)) - room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, bob, spec.MRoomMember, map[string]interface{}{ "membership": "join", }, test.WithStateKey(bob.ID)) }, @@ -826,10 +823,10 @@ func TestRoomSummary(t *testing.T) { name: "multiple joined/invited user", wantSummary: &types.Summary{JoinedMemberCount: pointer(2), InvitedMemberCount: pointer(1), Heroes: []string{charlie.ID, bob.ID}}, additionalEvents: func(t *testing.T, room *test.Room) { - room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, alice, spec.MRoomMember, map[string]interface{}{ "membership": "invite", }, test.WithStateKey(charlie.ID)) - room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, bob, spec.MRoomMember, map[string]interface{}{ "membership": "join", }, test.WithStateKey(bob.ID)) }, @@ -838,13 +835,13 @@ func TestRoomSummary(t *testing.T) { name: "multiple joined/invited/left user", wantSummary: &types.Summary{JoinedMemberCount: pointer(1), InvitedMemberCount: pointer(1), Heroes: []string{charlie.ID}}, additionalEvents: func(t *testing.T, room *test.Room) { - room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, alice, spec.MRoomMember, map[string]interface{}{ "membership": "invite", }, test.WithStateKey(charlie.ID)) - room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, bob, spec.MRoomMember, map[string]interface{}{ "membership": "join", }, test.WithStateKey(bob.ID)) - room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, bob, spec.MRoomMember, map[string]interface{}{ "membership": "leave", }, test.WithStateKey(bob.ID)) }, @@ -853,10 +850,10 @@ func TestRoomSummary(t *testing.T) { name: "leaving user after joining", wantSummary: &types.Summary{JoinedMemberCount: pointer(1), InvitedMemberCount: pointer(0), Heroes: []string{bob.ID}}, additionalEvents: func(t *testing.T, room *test.Room) { - room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, bob, spec.MRoomMember, map[string]interface{}{ "membership": "join", }, test.WithStateKey(bob.ID)) - room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, bob, spec.MRoomMember, map[string]interface{}{ "membership": "leave", }, test.WithStateKey(bob.ID)) }, @@ -866,7 +863,7 @@ func TestRoomSummary(t *testing.T) { wantSummary: &types.Summary{JoinedMemberCount: pointer(len(moreUserIDs) + 1), InvitedMemberCount: pointer(0), Heroes: moreUserIDs[:5]}, additionalEvents: func(t *testing.T, room *test.Room) { for _, x := range moreUsers { - room.CreateAndInsert(t, x, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, x, spec.MRoomMember, map[string]interface{}{ "membership": "join", }, test.WithStateKey(x.ID)) } @@ -876,10 +873,10 @@ func TestRoomSummary(t *testing.T) { name: "canonical alias set", wantSummary: &types.Summary{JoinedMemberCount: pointer(2), InvitedMemberCount: pointer(0), Heroes: []string{}}, additionalEvents: func(t *testing.T, room *test.Room) { - room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, bob, spec.MRoomMember, map[string]interface{}{ "membership": "join", }, test.WithStateKey(bob.ID)) - room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomCanonicalAlias, map[string]interface{}{ + room.CreateAndInsert(t, alice, spec.MRoomCanonicalAlias, map[string]interface{}{ "alias": "myalias", }, test.WithStateKey("")) }, @@ -888,10 +885,10 @@ func TestRoomSummary(t *testing.T) { name: "room name set", wantSummary: &types.Summary{JoinedMemberCount: pointer(2), InvitedMemberCount: pointer(0), Heroes: []string{}}, additionalEvents: func(t *testing.T, room *test.Room) { - room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, bob, spec.MRoomMember, map[string]interface{}{ "membership": "join", }, test.WithStateKey(bob.ID)) - room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomName, map[string]interface{}{ + room.CreateAndInsert(t, alice, spec.MRoomName, map[string]interface{}{ "name": "my room name", }, test.WithStateKey("")) }, @@ -899,9 +896,8 @@ func TestRoomSummary(t *testing.T) { } test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close, closeBase := MustCreateDatabase(t, dbType) + db, close := MustCreateDatabase(t, dbType) defer close() - defer closeBase() for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { @@ -938,12 +934,9 @@ func TestRecentEvents(t *testing.T) { } test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - filter := gomatrixserverlib.DefaultRoomEventFilter() - db, close, closeBase := MustCreateDatabase(t, dbType) - t.Cleanup(func() { - close() - closeBase() - }) + filter := synctypes.DefaultRoomEventFilter() + db, close := MustCreateDatabase(t, dbType) + t.Cleanup(close) MustWriteEvents(t, db, room1.Events()) MustWriteEvents(t, db, room2.Events()) diff --git a/syncapi/storage/storage_wasm.go b/syncapi/storage/storage_wasm.go index c15444743..db0b173bb 100644 --- a/syncapi/storage/storage_wasm.go +++ b/syncapi/storage/storage_wasm.go @@ -15,18 +15,19 @@ package storage import ( + "context" "fmt" - "github.com/matrix-org/dendrite/setup/base" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/syncapi/storage/sqlite3" ) // NewPublicRoomsServerDatabase opens a database connection. -func NewSyncServerDatasource(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (Database, error) { +func NewSyncServerDatasource(ctx context.Context, conMan sqlutil.Connections, dbProperties *config.DatabaseOptions) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(base, dbProperties) + return sqlite3.NewDatabase(ctx, conMan, dbProperties) case dbProperties.ConnectionString.IsPostgres(): return nil, fmt.Errorf("can't use Postgres implementation") default: diff --git a/syncapi/storage/tables/current_room_state_test.go b/syncapi/storage/tables/current_room_state_test.go index c7af4f977..7d4ec812c 100644 --- a/syncapi/storage/tables/current_room_state_test.go +++ b/syncapi/storage/tables/current_room_state_test.go @@ -11,9 +11,10 @@ import ( "github.com/matrix-org/dendrite/syncapi/storage/postgres" "github.com/matrix-org/dendrite/syncapi/storage/sqlite3" "github.com/matrix-org/dendrite/syncapi/storage/tables" + "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/test" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) func newCurrentRoomStateTable(t *testing.T, dbType test.DBType) (tables.CurrentRoomState, *sql.DB, func()) { @@ -94,7 +95,7 @@ func TestCurrentRoomStateTable(t *testing.T) { func testCurrentState(t *testing.T, ctx context.Context, txn *sql.Tx, tab tables.CurrentRoomState, room *test.Room) { t.Run("test currentState", func(t *testing.T) { // returns the complete state of the room with a default filter - filter := gomatrixserverlib.DefaultStateFilter() + filter := synctypes.DefaultStateFilter() evs, err := tab.SelectCurrentState(ctx, txn, room.ID, &filter, nil) if err != nil { t.Fatal(err) @@ -114,7 +115,7 @@ func testCurrentState(t *testing.T, ctx context.Context, txn *sql.Tx, tab tables t.Fatalf("expected %d state events, got %d", expectCount, gotCount) } // same as above, but with existing NotTypes defined - notTypes := []string{gomatrixserverlib.MRoomMember} + notTypes := []string{spec.MRoomMember} filter.NotTypes = ¬Types evs, err = tab.SelectCurrentState(ctx, txn, room.ID, &filter, nil) if err != nil { diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index 727d6bf2c..6384fb914 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -19,16 +19,18 @@ import ( "database/sql" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/types" ) type AccountData interface { InsertAccountData(ctx context.Context, txn *sql.Tx, userID, roomID, dataType string) (pos types.StreamPosition, err error) // SelectAccountDataInRange returns a map of room ID to a list of `dataType`. - SelectAccountDataInRange(ctx context.Context, txn *sql.Tx, userID string, r types.Range, accountDataEventFilter *gomatrixserverlib.EventFilter) (data map[string][]string, pos types.StreamPosition, err error) + SelectAccountDataInRange(ctx context.Context, txn *sql.Tx, userID string, r types.Range, accountDataEventFilter *synctypes.EventFilter) (data map[string][]string, pos types.StreamPosition, err error) SelectMaxAccountDataID(ctx context.Context, txn *sql.Tx) (id int64, err error) } @@ -53,7 +55,7 @@ type Peeks interface { } type Events interface { - SelectStateInRange(ctx context.Context, txn *sql.Tx, r types.Range, stateFilter *gomatrixserverlib.StateFilter, roomIDs []string) (map[string]map[string]bool, map[string]types.StreamEvent, error) + SelectStateInRange(ctx context.Context, txn *sql.Tx, r types.Range, stateFilter *synctypes.StateFilter, roomIDs []string) (map[string]map[string]bool, map[string]types.StreamEvent, error) SelectMaxEventID(ctx context.Context, txn *sql.Tx) (id int64, err error) InsertEvent( ctx context.Context, txn *sql.Tx, @@ -66,17 +68,15 @@ type Events interface { // SelectRecentEvents returns events between the two stream positions: exclusive of low and inclusive of high. // If onlySyncEvents has a value of true, only returns the events that aren't marked as to exclude from sync. // Returns up to `limit` events. Returns `limited=true` if there are more events in this range but we hit the `limit`. - SelectRecentEvents(ctx context.Context, txn *sql.Tx, roomIDs []string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) (map[string]types.RecentEvents, error) - // SelectEarlyEvents returns the earliest events in the given room. - SelectEarlyEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter) ([]types.StreamEvent, error) - SelectEvents(ctx context.Context, txn *sql.Tx, eventIDs []string, filter *gomatrixserverlib.RoomEventFilter, preserveOrder bool) ([]types.StreamEvent, error) + SelectRecentEvents(ctx context.Context, txn *sql.Tx, roomIDs []string, r types.Range, eventFilter *synctypes.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) (map[string]types.RecentEvents, error) + SelectEvents(ctx context.Context, txn *sql.Tx, eventIDs []string, filter *synctypes.RoomEventFilter, preserveOrder bool) ([]types.StreamEvent, error) UpdateEventJSON(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent) error // DeleteEventsForRoom removes all event information for a room. This should only be done when removing the room entirely. DeleteEventsForRoom(ctx context.Context, txn *sql.Tx, roomID string) (err error) SelectContextEvent(ctx context.Context, txn *sql.Tx, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) - SelectContextBeforeEvent(ctx context.Context, txn *sql.Tx, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error) - SelectContextAfterEvent(ctx context.Context, txn *sql.Tx, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error) + SelectContextBeforeEvent(ctx context.Context, txn *sql.Tx, id int, roomID string, filter *synctypes.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error) + SelectContextAfterEvent(ctx context.Context, txn *sql.Tx, id int, roomID string, filter *synctypes.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error) PurgeEvents(ctx context.Context, txn *sql.Tx, roomID string) error ReIndex(ctx context.Context, txn *sql.Tx, limit, offset int64, types []string) (map[int64]gomatrixserverlib.HeaderedEvent, error) @@ -107,7 +107,7 @@ type CurrentRoomState interface { DeleteRoomStateByEventID(ctx context.Context, txn *sql.Tx, eventID string) error DeleteRoomStateForRoom(ctx context.Context, txn *sql.Tx, roomID string) error // SelectCurrentState returns all the current state events for the given room. - SelectCurrentState(ctx context.Context, txn *sql.Tx, roomID string, stateFilter *gomatrixserverlib.StateFilter, excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) + SelectCurrentState(ctx context.Context, txn *sql.Tx, roomID string, stateFilter *synctypes.StateFilter, excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) // SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state. SelectRoomIDsWithMembership(ctx context.Context, txn *sql.Tx, userID string, membership string) ([]string, error) // SelectRoomIDsWithAnyMembership returns a map of all memberships for the given user. @@ -179,12 +179,12 @@ type SendToDevice interface { } type Filter interface { - SelectFilter(ctx context.Context, txn *sql.Tx, target *gomatrixserverlib.Filter, localpart string, filterID string) error - InsertFilter(ctx context.Context, txn *sql.Tx, filter *gomatrixserverlib.Filter, localpart string) (filterID string, err error) + SelectFilter(ctx context.Context, txn *sql.Tx, target *synctypes.Filter, localpart string, filterID string) error + InsertFilter(ctx context.Context, txn *sql.Tx, filter *synctypes.Filter, localpart string) (filterID string, err error) } type Receipts interface { - UpsertReceipt(ctx context.Context, txn *sql.Tx, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error) + UpsertReceipt(ctx context.Context, txn *sql.Tx, roomId, receiptType, userId, eventId string, timestamp spec.Timestamp) (pos types.StreamPosition, err error) SelectRoomReceiptsAfter(ctx context.Context, txn *sql.Tx, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error) SelectMaxReceiptID(ctx context.Context, txn *sql.Tx) (id int64, err error) PurgeReceipts(ctx context.Context, txn *sql.Tx, roomID string) error @@ -215,10 +215,10 @@ type Ignores interface { } type Presence interface { - UpsertPresence(ctx context.Context, txn *sql.Tx, userID string, statusMsg *string, presence types.Presence, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (pos types.StreamPosition, err error) + UpsertPresence(ctx context.Context, txn *sql.Tx, userID string, statusMsg *string, presence types.Presence, lastActiveTS spec.Timestamp, fromSync bool) (pos types.StreamPosition, err error) GetPresenceForUsers(ctx context.Context, txn *sql.Tx, userIDs []string) (presence []*types.PresenceInternal, err error) GetMaxPresenceID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) - GetPresenceAfter(ctx context.Context, txn *sql.Tx, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (presences map[string]*types.PresenceInternal, err error) + GetPresenceAfter(ctx context.Context, txn *sql.Tx, after types.StreamPosition, filter synctypes.EventFilter) (presences map[string]*types.PresenceInternal, err error) } type Relations interface { diff --git a/syncapi/storage/tables/memberships_test.go b/syncapi/storage/tables/memberships_test.go index df593ae78..2e8375c3d 100644 --- a/syncapi/storage/tables/memberships_test.go +++ b/syncapi/storage/tables/memberships_test.go @@ -7,6 +7,7 @@ import ( "time" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" @@ -65,7 +66,7 @@ func TestMembershipsTable(t *testing.T) { u := test.NewUser(t) users = append(users, u.ID) - ev := room.CreateAndInsert(t, u, gomatrixserverlib.MRoomMember, map[string]interface{}{ + ev := room.CreateAndInsert(t, u, spec.MRoomMember, map[string]interface{}{ "membership": "join", }, test.WithStateKey(u.ID)) userEvents = append(userEvents, ev) @@ -92,7 +93,7 @@ func TestMembershipsTable(t *testing.T) { func testMembershipCount(t *testing.T, ctx context.Context, table tables.Memberships, room *test.Room) { t.Run("membership counts are correct", func(t *testing.T) { // After 10 events, we should have 6 users (5 create related [incl. one member event], 5 member events = 6 users) - count, err := table.SelectMembershipCount(ctx, nil, room.ID, gomatrixserverlib.Join, 10) + count, err := table.SelectMembershipCount(ctx, nil, room.ID, spec.Join, 10) if err != nil { t.Fatalf("failed to get membership count: %s", err) } @@ -102,7 +103,7 @@ func testMembershipCount(t *testing.T, ctx context.Context, table tables.Members } // After 100 events, we should have all 11 users - count, err = table.SelectMembershipCount(ctx, nil, room.ID, gomatrixserverlib.Join, 100) + count, err = table.SelectMembershipCount(ctx, nil, room.ID, spec.Join, 100) if err != nil { t.Fatalf("failed to get membership count: %s", err) } @@ -126,12 +127,12 @@ func testUpsert(t *testing.T, ctx context.Context, table tables.Memberships, mem if pos != expectedPos { t.Fatalf("expected pos to be %d, got %d", expectedPos, pos) } - if membership != gomatrixserverlib.Join { + if membership != spec.Join { t.Fatalf("expected membership to be join, got %s", membership) } // Create a new event which gets upserted and should not cause issues - ev := room.CreateAndInsert(t, user, gomatrixserverlib.MRoomMember, map[string]interface{}{ - "membership": gomatrixserverlib.Join, + ev := room.CreateAndInsert(t, user, spec.MRoomMember, map[string]interface{}{ + "membership": spec.Join, }, test.WithStateKey(user.ID)) // Insert the same event again, but with different positions, which should get updated if err = table.UpsertMembership(ctx, nil, ev, 2, 2); err != nil { @@ -147,7 +148,7 @@ func testUpsert(t *testing.T, ctx context.Context, table tables.Memberships, mem if pos != expectedPos { t.Fatalf("expected pos to be %d, got %d", expectedPos, pos) } - if membership != gomatrixserverlib.Join { + if membership != spec.Join { t.Fatalf("expected membership to be join, got %s", membership) } @@ -155,7 +156,7 @@ func testUpsert(t *testing.T, ctx context.Context, table tables.Memberships, mem if membership, _, err = table.SelectMembershipForUser(ctx, nil, room.ID, user.ID, 1); err != nil { t.Fatalf("failed to select membership: %s", err) } - if membership != gomatrixserverlib.Leave { + if membership != spec.Leave { t.Fatalf("expected membership to be leave, got %s", membership) } }) diff --git a/syncapi/storage/tables/output_room_events_test.go b/syncapi/storage/tables/output_room_events_test.go index bdb17ae20..9b755dc85 100644 --- a/syncapi/storage/tables/output_room_events_test.go +++ b/syncapi/storage/tables/output_room_events_test.go @@ -12,8 +12,10 @@ import ( "github.com/matrix-org/dendrite/syncapi/storage/postgres" "github.com/matrix-org/dendrite/syncapi/storage/sqlite3" "github.com/matrix-org/dendrite/syncapi/storage/tables" + "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) func newOutputRoomEventsTable(t *testing.T, dbType test.DBType) (tables.Events, *sql.DB, func()) { @@ -84,7 +86,7 @@ func TestOutputRoomEventsTable(t *testing.T) { } wantEventID := []string{urlEv.EventID()} t := true - gotEvents, err = tab.SelectEvents(ctx, txn, wantEventID, &gomatrixserverlib.RoomEventFilter{Limit: 1, ContainsURL: &t}, true) + gotEvents, err = tab.SelectEvents(ctx, txn, wantEventID, &synctypes.RoomEventFilter{Limit: 1, ContainsURL: &t}, true) if err != nil { return fmt.Errorf("failed to SelectEvents: %s", err) } @@ -103,3 +105,53 @@ func TestOutputRoomEventsTable(t *testing.T) { } }) } + +func TestReindex(t *testing.T) { + ctx := context.Background() + alice := test.NewUser(t) + room := test.NewRoom(t, alice) + + room.CreateAndInsert(t, alice, spec.MRoomName, map[string]interface{}{ + "name": "my new room name", + }, test.WithStateKey("")) + + room.CreateAndInsert(t, alice, spec.MRoomTopic, map[string]interface{}{ + "topic": "my new room topic", + }, test.WithStateKey("")) + + room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{ + "msgbody": "my room message", + "type": "m.text", + }) + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + tab, db, close := newOutputRoomEventsTable(t, dbType) + defer close() + err := sqlutil.WithTransaction(db, func(txn *sql.Tx) error { + for _, ev := range room.Events() { + _, err := tab.InsertEvent(ctx, txn, ev, nil, nil, nil, false, gomatrixserverlib.HistoryVisibilityShared) + if err != nil { + return fmt.Errorf("failed to InsertEvent: %s", err) + } + } + + return nil + }) + if err != nil { + t.Fatalf("err: %s", err) + } + + events, err := tab.ReIndex(ctx, nil, 10, 0, []string{ + spec.MRoomName, + spec.MRoomTopic, + "m.room.message"}) + if err != nil { + t.Fatal(err) + } + + wantEventCount := 3 + if len(events) != wantEventCount { + t.Fatalf("expected %d events, got %d", wantEventCount, len(events)) + } + }) +} diff --git a/syncapi/storage/tables/presence_table_test.go b/syncapi/storage/tables/presence_table_test.go index dce0c695a..d8161836b 100644 --- a/syncapi/storage/tables/presence_table_test.go +++ b/syncapi/storage/tables/presence_table_test.go @@ -7,15 +7,15 @@ import ( "testing" "time" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/syncapi/storage/postgres" "github.com/matrix-org/dendrite/syncapi/storage/sqlite3" "github.com/matrix-org/dendrite/syncapi/storage/tables" + "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib/spec" ) func mustPresenceTable(t *testing.T, dbType test.DBType) (tables.Presence, func()) { @@ -51,7 +51,7 @@ func TestPresence(t *testing.T) { ctx := context.Background() statusMsg := "Hello World!" - timestamp := gomatrixserverlib.AsTimestamp(time.Now()) + timestamp := spec.AsTimestamp(time.Now()) var txn *sql.Tx test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { @@ -96,7 +96,7 @@ func TestPresence(t *testing.T) { } // This should return only Bobs status - presences, err := tab.GetPresenceAfter(ctx, txn, maxPos, gomatrixserverlib.EventFilter{Limit: 10}) + presences, err := tab.GetPresenceAfter(ctx, txn, maxPos, synctypes.EventFilter{Limit: 10}) if err != nil { t.Error(err) } diff --git a/syncapi/streams/stream_accountdata.go b/syncapi/streams/stream_accountdata.go index 3593a6563..51f2a3d30 100644 --- a/syncapi/streams/stream_accountdata.go +++ b/syncapi/streams/stream_accountdata.go @@ -3,11 +3,11 @@ package streams import ( "context" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/dendrite/syncapi/storage" + "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" ) type AccountDataStreamProvider struct { @@ -82,9 +82,9 @@ func (p *AccountDataStreamProvider) IncrementalSync( if globalData, ok := dataRes.GlobalAccountData[dataType]; ok { req.Response.AccountData.Events = append( req.Response.AccountData.Events, - gomatrixserverlib.ClientEvent{ + synctypes.ClientEvent{ Type: dataType, - Content: gomatrixserverlib.RawJSON(globalData), + Content: spec.RawJSON(globalData), }, ) } @@ -96,9 +96,9 @@ func (p *AccountDataStreamProvider) IncrementalSync( } joinData.AccountData.Events = append( joinData.AccountData.Events, - gomatrixserverlib.ClientEvent{ + synctypes.ClientEvent{ Type: dataType, - Content: gomatrixserverlib.RawJSON(roomData), + Content: spec.RawJSON(roomData), }, ) req.Response.Rooms.Join[roomID] = joinData diff --git a/syncapi/streams/stream_invite.go b/syncapi/streams/stream_invite.go index e4de30e1c..becd863a9 100644 --- a/syncapi/streams/stream_invite.go +++ b/syncapi/streams/stream_invite.go @@ -8,9 +8,10 @@ import ( "strconv" "time" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/syncapi/storage" + "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/types" ) @@ -78,22 +79,22 @@ func (p *InviteStreamProvider) IncrementalSync( membership, _, err := snapshot.SelectMembershipForUser(ctx, roomID, req.Device.UserID, math.MaxInt64) // Skip if the user is an existing member of the room. // Otherwise, the NewLeaveResponse will eject the user from the room unintentionally - if membership == gomatrixserverlib.Join || + if membership == spec.Join || err != nil { continue } lr := types.NewLeaveResponse() h := sha256.Sum256(append([]byte(roomID), []byte(strconv.FormatInt(int64(to), 10))...)) - lr.Timeline.Events = append(lr.Timeline.Events, gomatrixserverlib.ClientEvent{ + lr.Timeline.Events = append(lr.Timeline.Events, synctypes.ClientEvent{ // fake event ID which muxes in the to position EventID: "$" + base64.RawURLEncoding.EncodeToString(h[:]), - OriginServerTS: gomatrixserverlib.AsTimestamp(time.Now()), + OriginServerTS: spec.AsTimestamp(time.Now()), RoomID: roomID, Sender: req.Device.UserID, StateKey: &req.Device.UserID, Type: "m.room.member", - Content: gomatrixserverlib.RawJSON(`{"membership":"leave"}`), + Content: spec.RawJSON(`{"membership":"leave"}`), }) req.Response.Rooms.Leave[roomID] = lr } diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go index 6af25c028..41c584812 100644 --- a/syncapi/streams/stream_pdu.go +++ b/syncapi/streams/stream_pdu.go @@ -10,8 +10,10 @@ import ( roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/internal" "github.com/matrix-org/dendrite/syncapi/storage" + "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/gomatrixserverlib" @@ -69,7 +71,7 @@ func (p *PDUStreamProvider) CompleteSync( } // Extract room state and recent events for all rooms the user is joined to. - joinedRoomIDs, err := snapshot.RoomIDsWithMembership(ctx, req.Device.UserID, gomatrixserverlib.Join) + joinedRoomIDs, err := snapshot.RoomIDsWithMembership(ctx, req.Device.UserID, spec.Join) if err != nil { req.Log.WithError(err).Error("p.DB.RoomIDsWithMembership failed") return from @@ -109,7 +111,7 @@ func (p *PDUStreamProvider) CompleteSync( continue } req.Response.Rooms.Join[roomID] = jr - req.Rooms[roomID] = gomatrixserverlib.Join + req.Rooms[roomID] = spec.Join } // Add peeked rooms. @@ -184,7 +186,7 @@ func (p *PDUStreamProvider) IncrementalSync( } for _, roomID := range syncJoinedRooms { - req.Rooms[roomID] = gomatrixserverlib.Join + req.Rooms[roomID] = spec.Join } if len(stateDeltas) == 0 { @@ -242,8 +244,8 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( device *userapi.Device, r types.Range, delta types.StateDelta, - eventFilter *gomatrixserverlib.RoomEventFilter, - stateFilter *gomatrixserverlib.StateFilter, + eventFilter *synctypes.RoomEventFilter, + stateFilter *synctypes.StateFilter, req *types.SyncRequest, ) (types.StreamPosition, error) { var err error @@ -311,8 +313,8 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( hasMembershipChange := false for _, recentEvent := range recentStreamEvents { - if recentEvent.Type() == gomatrixserverlib.MRoomMember && recentEvent.StateKey() != nil { - if membership, _ := recentEvent.Membership(); membership == gomatrixserverlib.Join { + if recentEvent.Type() == spec.MRoomMember && recentEvent.StateKey() != nil { + if membership, _ := recentEvent.Membership(); membership == spec.Join { req.MembershipChanges[*recentEvent.StateKey()] = struct{}{} } hasMembershipChange = true @@ -356,7 +358,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( } switch delta.Membership { - case gomatrixserverlib.Join: + case spec.Join: jr := types.NewJoinResponse() if hasMembershipChange { jr.Summary, err = snapshot.GetRoomSummary(ctx, delta.RoomID, device.UserID) @@ -365,33 +367,33 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( } } jr.Timeline.PrevBatch = &prevBatch - jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatSync) + jr.Timeline.Events = synctypes.HeaderedToClientEvents(events, synctypes.FormatSync) // If we are limited by the filter AND the history visibility filter // didn't "remove" events, return that the response is limited. jr.Timeline.Limited = (limited && len(events) == len(recentEvents)) || delta.NewlyJoined - jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.StateEvents, gomatrixserverlib.FormatSync) + jr.State.Events = synctypes.HeaderedToClientEvents(delta.StateEvents, synctypes.FormatSync) req.Response.Rooms.Join[delta.RoomID] = jr - case gomatrixserverlib.Peek: + case spec.Peek: jr := types.NewJoinResponse() jr.Timeline.PrevBatch = &prevBatch // TODO: Apply history visibility on peeked rooms - jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) + jr.Timeline.Events = synctypes.HeaderedToClientEvents(recentEvents, synctypes.FormatSync) jr.Timeline.Limited = limited - jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.StateEvents, gomatrixserverlib.FormatSync) + jr.State.Events = synctypes.HeaderedToClientEvents(delta.StateEvents, synctypes.FormatSync) req.Response.Rooms.Peek[delta.RoomID] = jr - case gomatrixserverlib.Leave: + case spec.Leave: fallthrough // transitions to leave are the same as ban - case gomatrixserverlib.Ban: + case spec.Ban: lr := types.NewLeaveResponse() lr.Timeline.PrevBatch = &prevBatch - lr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatSync) + lr.Timeline.Events = synctypes.HeaderedToClientEvents(events, synctypes.FormatSync) // If we are limited by the filter AND the history visibility filter // didn't "remove" events, return that the response is limited. lr.Timeline.Limited = limited && len(events) == len(recentEvents) - lr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.StateEvents, gomatrixserverlib.FormatSync) + lr.State.Events = synctypes.HeaderedToClientEvents(delta.StateEvents, synctypes.FormatSync) req.Response.Rooms.Leave[delta.RoomID] = lr } @@ -420,7 +422,7 @@ func applyHistoryVisibilityFilter( // Only get the state again if there are state events in the timeline if len(stateTypes) > 0 { - filter := gomatrixserverlib.DefaultStateFilter() + filter := synctypes.DefaultStateFilter() filter.Types = &stateTypes filter.Senders = &senders stateEvents, err := snapshot.CurrentState(ctx, roomID, &filter, nil) @@ -451,7 +453,7 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( ctx context.Context, snapshot storage.DatabaseTransaction, roomID string, - stateFilter *gomatrixserverlib.StateFilter, + stateFilter *synctypes.StateFilter, wantFullState bool, device *userapi.Device, isPeek bool, @@ -526,7 +528,7 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( event := events[0] // If this is the beginning of the room, we can't go back further. We're going to return // the TopologyToken from the last event instead. (Synapse returns the /sync next_Batch) - if event.Type() == gomatrixserverlib.MRoomCreate && event.StateKeyEquals("") { + if event.Type() == spec.MRoomCreate && event.StateKeyEquals("") { event = events[len(events)-1] } backwardTopologyPos, backwardStreamPos, err = snapshot.PositionInTopology(ctx, event.EventID()) @@ -541,17 +543,17 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( } jr.Timeline.PrevBatch = prevBatch - jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatSync) + jr.Timeline.Events = synctypes.HeaderedToClientEvents(events, synctypes.FormatSync) // If we are limited by the filter AND the history visibility filter // didn't "remove" events, return that the response is limited. jr.Timeline.Limited = limited && len(events) == len(recentEvents) - jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(stateEvents, gomatrixserverlib.FormatSync) + jr.State.Events = synctypes.HeaderedToClientEvents(stateEvents, synctypes.FormatSync) return jr, nil } func (p *PDUStreamProvider) lazyLoadMembers( ctx context.Context, snapshot storage.DatabaseTransaction, roomID string, - incremental, limited bool, stateFilter *gomatrixserverlib.StateFilter, + incremental, limited bool, stateFilter *synctypes.StateFilter, device *userapi.Device, timelineEvents, stateEvents []*gomatrixserverlib.HeaderedEvent, ) ([]*gomatrixserverlib.HeaderedEvent, error) { @@ -574,7 +576,7 @@ func (p *PDUStreamProvider) lazyLoadMembers( newStateEvents := make([]*gomatrixserverlib.HeaderedEvent, 0, len(stateEvents)) // Remove existing membership events we don't care about, e.g. users not in the timeline.events for _, event := range stateEvents { - if event.Type() == gomatrixserverlib.MRoomMember && event.StateKey() != nil { + if event.Type() == spec.MRoomMember && event.StateKey() != nil { // If this is a gapped incremental sync, we still want this membership isGappedIncremental := limited && incremental // We want this users membership event, keep it in the list @@ -595,9 +597,9 @@ func (p *PDUStreamProvider) lazyLoadMembers( wantUsers = append(wantUsers, userID) } // Query missing membership events - filter := gomatrixserverlib.DefaultStateFilter() + filter := synctypes.DefaultStateFilter() filter.Senders = &wantUsers - filter.Types = &[]string{gomatrixserverlib.MRoomMember} + filter.Types = &[]string{spec.MRoomMember} memberships, err := snapshot.GetStateEventsForRoom(ctx, roomID, &filter) if err != nil { return stateEvents, err @@ -612,7 +614,7 @@ func (p *PDUStreamProvider) lazyLoadMembers( // addIgnoredUsersToFilter adds ignored users to the eventfilter and // the syncreq itself for further use in streams. -func (p *PDUStreamProvider) addIgnoredUsersToFilter(ctx context.Context, snapshot storage.DatabaseTransaction, req *types.SyncRequest, eventFilter *gomatrixserverlib.RoomEventFilter) error { +func (p *PDUStreamProvider) addIgnoredUsersToFilter(ctx context.Context, snapshot storage.DatabaseTransaction, req *types.SyncRequest, eventFilter *synctypes.RoomEventFilter) error { ignores, err := snapshot.IgnoresForUser(ctx, req.Device.UserID) if err != nil { if err == sql.ErrNoRows { diff --git a/syncapi/streams/stream_presence.go b/syncapi/streams/stream_presence.go index 445e46b3a..324240667 100644 --- a/syncapi/streams/stream_presence.go +++ b/syncapi/streams/stream_presence.go @@ -20,12 +20,13 @@ import ( "fmt" "sync" - "github.com/matrix-org/gomatrixserverlib" "github.com/tidwall/gjson" "github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/storage" + "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/gomatrixserverlib/spec" ) type PresenceStreamProvider struct { @@ -65,7 +66,7 @@ func (p *PresenceStreamProvider) IncrementalSync( from, to types.StreamPosition, ) types.StreamPosition { // We pull out a larger number than the filter asks for, since we're filtering out events later - presences, err := snapshot.PresenceAfter(ctx, from, gomatrixserverlib.EventFilter{Limit: 1000}) + presences, err := snapshot.PresenceAfter(ctx, from, synctypes.EventFilter{Limit: 1000}) if err != nil { req.Log.WithError(err).Error("p.DB.PresenceAfter failed") return from @@ -130,10 +131,10 @@ func (p *PresenceStreamProvider) IncrementalSync( return from } - req.Response.Presence.Events = append(req.Response.Presence.Events, gomatrixserverlib.ClientEvent{ + req.Response.Presence.Events = append(req.Response.Presence.Events, synctypes.ClientEvent{ Content: content, Sender: presence.UserID, - Type: gomatrixserverlib.MPresence, + Type: spec.MPresence, }) if presence.StreamPos > lastPos { lastPos = presence.StreamPos @@ -202,11 +203,11 @@ func joinedRooms(res *types.Response, userID string) []string { return roomIDs } -func membershipEventPresent(events []gomatrixserverlib.ClientEvent, userID string) bool { +func membershipEventPresent(events []synctypes.ClientEvent, userID string) bool { for _, ev := range events { // it's enough to know that we have our member event here, don't need to check membership content // as it's implied by being in the respective section of the sync response. - if ev.Type == gomatrixserverlib.MRoomMember && ev.StateKey != nil && *ev.StateKey == userID { + if ev.Type == spec.MRoomMember && ev.StateKey != nil && *ev.StateKey == userID { // ignore e.g. join -> join changes if gjson.GetBytes(ev.Unsigned, "prev_content.membership").Str == gjson.GetBytes(ev.Content, "membership").Str { continue diff --git a/syncapi/streams/stream_receipt.go b/syncapi/streams/stream_receipt.go index 16a81e833..ed52dc5c7 100644 --- a/syncapi/streams/stream_receipt.go +++ b/syncapi/streams/stream_receipt.go @@ -4,9 +4,10 @@ import ( "context" "encoding/json" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/syncapi/storage" + "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/types" ) @@ -45,7 +46,7 @@ func (p *ReceiptStreamProvider) IncrementalSync( ) types.StreamPosition { var joinedRooms []string for roomID, membership := range req.Rooms { - if membership == gomatrixserverlib.Join { + if membership == spec.Join { joinedRooms = append(joinedRooms, roomID) } } @@ -86,8 +87,8 @@ func (p *ReceiptStreamProvider) IncrementalSync( jr = types.NewJoinResponse() } - ev := gomatrixserverlib.ClientEvent{ - Type: gomatrixserverlib.MReceipt, + ev := synctypes.ClientEvent{ + Type: spec.MReceipt, } content := make(map[string]ReceiptMRead) for _, receipt := range receipts { @@ -118,5 +119,5 @@ type ReceiptMRead struct { } type ReceiptTS struct { - TS gomatrixserverlib.Timestamp `json:"ts"` + TS spec.Timestamp `json:"ts"` } diff --git a/syncapi/streams/stream_typing.go b/syncapi/streams/stream_typing.go index 84c199b39..15500a470 100644 --- a/syncapi/streams/stream_typing.go +++ b/syncapi/streams/stream_typing.go @@ -4,11 +4,11 @@ import ( "context" "encoding/json" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/syncapi/storage" + "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/gomatrixserverlib/spec" ) type TypingStreamProvider struct { @@ -32,7 +32,7 @@ func (p *TypingStreamProvider) IncrementalSync( ) types.StreamPosition { var err error for roomID, membership := range req.Rooms { - if membership != gomatrixserverlib.Join { + if membership != spec.Join { continue } @@ -51,8 +51,8 @@ func (p *TypingStreamProvider) IncrementalSync( typingUsers = append(typingUsers, users[i]) } } - ev := gomatrixserverlib.ClientEvent{ - Type: gomatrixserverlib.MTyping, + ev := synctypes.ClientEvent{ + Type: spec.MTyping, } ev.Content, err = json.Marshal(map[string]interface{}{ "user_ids": typingUsers, diff --git a/syncapi/sync/request.go b/syncapi/sync/request.go index e5e5fdb5b..617342331 100644 --- a/syncapi/sync/request.go +++ b/syncapi/sync/request.go @@ -28,6 +28,7 @@ import ( "github.com/sirupsen/logrus" "github.com/matrix-org/dendrite/syncapi/storage" + "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" ) @@ -49,7 +50,7 @@ func newSyncRequest(req *http.Request, device userapi.Device, syncDB storage.Dat } // Create a default filter and apply a stored filter on top of it (if specified) - filter := gomatrixserverlib.DefaultFilter() + filter := synctypes.DefaultFilter() filterQuery := req.URL.Query().Get("filter") if filterQuery != "" { if filterQuery[0] == '{' { diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index 68f913871..6baaff3c8 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -25,7 +25,7 @@ import ( "sync" "time" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/prometheus/client_golang/prometheus" "github.com/sirupsen/logrus" @@ -61,7 +61,7 @@ type PresencePublisher interface { } type PresenceConsumer interface { - EmitPresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, ts gomatrixserverlib.Timestamp, fromSync bool) + EmitPresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, ts spec.Timestamp, fromSync bool) } // NewRequestPool makes a new RequestPool @@ -138,7 +138,7 @@ func (rp *RequestPool) updatePresence(db storage.Presence, presence string, user newPresence := types.PresenceInternal{ Presence: presenceID, UserID: userID, - LastActiveTS: gomatrixserverlib.AsTimestamp(time.Now()), + LastActiveTS: spec.AsTimestamp(time.Now()), } // ensure we also send the current status_msg to federated servers and not nil @@ -170,7 +170,7 @@ func (rp *RequestPool) updatePresence(db storage.Presence, presence string, user // the /sync response else we may not return presence: online immediately. rp.consumer.EmitPresence( context.Background(), userID, presenceID, newPresence.ClientFields.StatusMsg, - gomatrixserverlib.AsTimestamp(time.Now()), true, + spec.AsTimestamp(time.Now()), true, ) } diff --git a/syncapi/sync/requestpool_test.go b/syncapi/sync/requestpool_test.go index faa0b49c6..93be46d01 100644 --- a/syncapi/sync/requestpool_test.go +++ b/syncapi/sync/requestpool_test.go @@ -7,8 +7,9 @@ import ( "time" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) type dummyPublisher struct { @@ -25,7 +26,7 @@ func (d *dummyPublisher) SendPresence(userID string, presence types.Presence, st type dummyDB struct{} -func (d dummyDB) UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error) { +func (d dummyDB) UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS spec.Timestamp, fromSync bool) (types.StreamPosition, error) { return 0, nil } @@ -33,7 +34,7 @@ func (d dummyDB) GetPresences(ctx context.Context, userID []string) ([]*types.Pr return []*types.PresenceInternal{}, nil } -func (d dummyDB) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) { +func (d dummyDB) PresenceAfter(ctx context.Context, after types.StreamPosition, filter synctypes.EventFilter) (map[string]*types.PresenceInternal, error) { return map[string]*types.PresenceInternal{}, nil } @@ -43,7 +44,7 @@ func (d dummyDB) MaxStreamPositionForPresence(ctx context.Context) (types.Stream type dummyConsumer struct{} -func (d dummyConsumer) EmitPresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, ts gomatrixserverlib.Timestamp, fromSync bool) { +func (d dummyConsumer) EmitPresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, ts spec.Timestamp, fromSync bool) { } diff --git a/syncapi/syncapi.go b/syncapi/syncapi.go index 153f7af5d..ecbe05dd8 100644 --- a/syncapi/syncapi.go +++ b/syncapi/syncapi.go @@ -17,12 +17,16 @@ package syncapi import ( "context" + "github.com/matrix-org/dendrite/internal/fulltext" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/process" "github.com/sirupsen/logrus" "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/jetstream" userapi "github.com/matrix-org/dendrite/userapi/api" @@ -38,45 +42,57 @@ import ( // AddPublicRoutes sets up and registers HTTP handlers for the SyncAPI // component. func AddPublicRoutes( - base *base.BaseDendrite, + processContext *process.ProcessContext, + routers httputil.Routers, + dendriteCfg *config.Dendrite, + cm sqlutil.Connections, + natsInstance *jetstream.NATSInstance, userAPI userapi.SyncUserAPI, rsAPI api.SyncRoomserverAPI, + caches caching.LazyLoadCache, + enableMetrics bool, ) { - cfg := &base.Cfg.SyncAPI + js, natsClient := natsInstance.Prepare(processContext, &dendriteCfg.Global.JetStream) - js, natsClient := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) - - syncDB, err := storage.NewSyncServerDatasource(base, &cfg.Database) + syncDB, err := storage.NewSyncServerDatasource(processContext.Context(), cm, &dendriteCfg.SyncAPI.Database) if err != nil { logrus.WithError(err).Panicf("failed to connect to sync db") } eduCache := caching.NewTypingCache() notifier := notifier.NewNotifier() - streams := streams.NewSyncStreamProviders(syncDB, userAPI, rsAPI, eduCache, base.Caches, notifier) + streams := streams.NewSyncStreamProviders(syncDB, userAPI, rsAPI, eduCache, caches, notifier) notifier.SetCurrentPosition(streams.Latest(context.Background())) if err = notifier.Load(context.Background(), syncDB); err != nil { logrus.WithError(err).Panicf("failed to load notifier ") } + var fts *fulltext.Search + if dendriteCfg.SyncAPI.Fulltext.Enabled { + fts, err = fulltext.New(processContext, dendriteCfg.SyncAPI.Fulltext) + if err != nil { + logrus.WithError(err).Panicf("failed to create full text") + } + } + federationPresenceProducer := &producers.FederationAPIPresenceProducer{ - Topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputPresenceEvent), + Topic: dendriteCfg.Global.JetStream.Prefixed(jetstream.OutputPresenceEvent), JetStream: js, } presenceConsumer := consumers.NewPresenceConsumer( - base.ProcessContext, cfg, js, natsClient, syncDB, + processContext, &dendriteCfg.SyncAPI, js, natsClient, syncDB, notifier, streams.PresenceStreamProvider, userAPI, ) - requestPool := sync.NewRequestPool(syncDB, cfg, userAPI, rsAPI, streams, notifier, federationPresenceProducer, presenceConsumer, base.EnableMetrics) + requestPool := sync.NewRequestPool(syncDB, &dendriteCfg.SyncAPI, userAPI, rsAPI, streams, notifier, federationPresenceProducer, presenceConsumer, enableMetrics) if err = presenceConsumer.Start(); err != nil { logrus.WithError(err).Panicf("failed to start presence consumer") } keyChangeConsumer := consumers.NewOutputKeyChangeEventConsumer( - base.ProcessContext, cfg, cfg.Matrix.JetStream.Prefixed(jetstream.OutputKeyChangeEvent), + processContext, &dendriteCfg.SyncAPI, dendriteCfg.Global.JetStream.Prefixed(jetstream.OutputKeyChangeEvent), js, rsAPI, syncDB, notifier, streams.DeviceListStreamProvider, ) @@ -85,51 +101,51 @@ func AddPublicRoutes( } roomConsumer := consumers.NewOutputRoomEventConsumer( - base.ProcessContext, cfg, js, syncDB, notifier, streams.PDUStreamProvider, - streams.InviteStreamProvider, rsAPI, base.Fulltext, + processContext, &dendriteCfg.SyncAPI, js, syncDB, notifier, streams.PDUStreamProvider, + streams.InviteStreamProvider, rsAPI, fts, ) if err = roomConsumer.Start(); err != nil { logrus.WithError(err).Panicf("failed to start room server consumer") } clientConsumer := consumers.NewOutputClientDataConsumer( - base.ProcessContext, cfg, js, natsClient, syncDB, notifier, - streams.AccountDataStreamProvider, base.Fulltext, + processContext, &dendriteCfg.SyncAPI, js, natsClient, syncDB, notifier, + streams.AccountDataStreamProvider, fts, ) if err = clientConsumer.Start(); err != nil { logrus.WithError(err).Panicf("failed to start client data consumer") } notificationConsumer := consumers.NewOutputNotificationDataConsumer( - base.ProcessContext, cfg, js, syncDB, notifier, streams.NotificationDataStreamProvider, + processContext, &dendriteCfg.SyncAPI, js, syncDB, notifier, streams.NotificationDataStreamProvider, ) if err = notificationConsumer.Start(); err != nil { logrus.WithError(err).Panicf("failed to start notification data consumer") } typingConsumer := consumers.NewOutputTypingEventConsumer( - base.ProcessContext, cfg, js, eduCache, notifier, streams.TypingStreamProvider, + processContext, &dendriteCfg.SyncAPI, js, eduCache, notifier, streams.TypingStreamProvider, ) if err = typingConsumer.Start(); err != nil { logrus.WithError(err).Panicf("failed to start typing consumer") } sendToDeviceConsumer := consumers.NewOutputSendToDeviceEventConsumer( - base.ProcessContext, cfg, js, syncDB, userAPI, notifier, streams.SendToDeviceStreamProvider, + processContext, &dendriteCfg.SyncAPI, js, syncDB, userAPI, notifier, streams.SendToDeviceStreamProvider, ) if err = sendToDeviceConsumer.Start(); err != nil { logrus.WithError(err).Panicf("failed to start send-to-device consumer") } receiptConsumer := consumers.NewOutputReceiptEventConsumer( - base.ProcessContext, cfg, js, syncDB, notifier, streams.ReceiptStreamProvider, + processContext, &dendriteCfg.SyncAPI, js, syncDB, notifier, streams.ReceiptStreamProvider, ) if err = receiptConsumer.Start(); err != nil { logrus.WithError(err).Panicf("failed to start receipts consumer") } routing.Setup( - base.PublicClientAPIMux, requestPool, syncDB, userAPI, - rsAPI, cfg, base.Caches, base.Fulltext, + routers.Client, requestPool, syncDB, userAPI, + rsAPI, &dendriteCfg.SyncAPI, caches, fts, ) } diff --git a/syncapi/syncapi_test.go b/syncapi/syncapi_test.go index e748660f7..c120f8828 100644 --- a/syncapi/syncapi_test.go +++ b/syncapi/syncapi_test.go @@ -10,16 +10,23 @@ import ( "testing" "time" - "github.com/matrix-org/dendrite/syncapi/routing" + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/nats-io/nats.go" "github.com/tidwall/gjson" + "github.com/matrix-org/dendrite/syncapi/routing" + "github.com/matrix-org/dendrite/syncapi/storage" + "github.com/matrix-org/dendrite/syncapi/synctypes" + "github.com/matrix-org/dendrite/clientapi/producers" "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/roomserver/api" rsapi "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/test" @@ -111,13 +118,17 @@ func testSyncAccessTokens(t *testing.T, dbType test.DBType) { AccountType: userapi.AccountTypeUser, } - base, close := testrig.CreateBaseDendrite(t, dbType) + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + natsInstance := jetstream.NATSInstance{} defer close() - jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) - defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream) - msgs := toNATSMsgs(t, base, room.Events()...) - AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}) + jsctx, _ := natsInstance.Prepare(processCtx, &cfg.Global.JetStream) + defer jetstream.DeleteAllStreams(jsctx, &cfg.Global.JetStream) + msgs := toNATSMsgs(t, cfg, room.Events()...) + AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, caches, caching.DisableMetrics) testrig.MustPublishMsgs(t, jsctx, msgs...) testCases := []struct { @@ -152,7 +163,7 @@ func testSyncAccessTokens(t *testing.T, dbType test.DBType) { }, } - syncUntil(t, base, alice.AccessToken, false, func(syncBody string) bool { + syncUntil(t, routers, alice.AccessToken, false, func(syncBody string) bool { // wait for the last sent eventID to come down sync path := fmt.Sprintf(`rooms.join.%s.timeline.events.#(event_id=="%s")`, room.ID, room.Events()[len(room.Events())-1].EventID()) return gjson.Get(syncBody, path).Exists() @@ -160,7 +171,7 @@ func testSyncAccessTokens(t *testing.T, dbType test.DBType) { for _, tc := range testCases { w := httptest.NewRecorder() - base.PublicClientAPIMux.ServeHTTP(w, tc.req) + routers.Client.ServeHTTP(w, tc.req) if w.Code != tc.wantCode { t.Fatalf("%s: got HTTP %d want %d", tc.name, w.Code, tc.wantCode) } @@ -203,25 +214,29 @@ func testSyncAPICreateRoomSyncEarly(t *testing.T, dbType test.DBType) { AccountType: userapi.AccountTypeUser, } - base, close := testrig.CreateBaseDendrite(t, dbType) + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) defer close() + natsInstance := jetstream.NATSInstance{} - jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) - defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream) + jsctx, _ := natsInstance.Prepare(processCtx, &cfg.Global.JetStream) + defer jetstream.DeleteAllStreams(jsctx, &cfg.Global.JetStream) // order is: // m.room.create // m.room.member // m.room.power_levels // m.room.join_rules // m.room.history_visibility - msgs := toNATSMsgs(t, base, room.Events()...) + msgs := toNATSMsgs(t, cfg, room.Events()...) sinceTokens := make([]string, len(msgs)) - AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}) + AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, caches, caching.DisableMetrics) for i, msg := range msgs { testrig.MustPublishMsgs(t, jsctx, msg) time.Sleep(100 * time.Millisecond) w := httptest.NewRecorder() - base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", "/_matrix/client/v3/sync", test.WithQueryParams(map[string]string{ + routers.Client.ServeHTTP(w, test.NewRequest(t, "GET", "/_matrix/client/v3/sync", test.WithQueryParams(map[string]string{ "access_token": alice.AccessToken, "timeout": "0", }))) @@ -251,7 +266,7 @@ func testSyncAPICreateRoomSyncEarly(t *testing.T, dbType test.DBType) { t.Logf("waited for events to be consumed; syncing with %v", sinceTokens) for i, since := range sinceTokens { w := httptest.NewRecorder() - base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", "/_matrix/client/v3/sync", test.WithQueryParams(map[string]string{ + routers.Client.ServeHTTP(w, test.NewRequest(t, "GET", "/_matrix/client/v3/sync", test.WithQueryParams(map[string]string{ "access_token": alice.AccessToken, "timeout": "0", "since": since, @@ -293,16 +308,20 @@ func testSyncAPIUpdatePresenceImmediately(t *testing.T, dbType test.DBType) { AccountType: userapi.AccountTypeUser, } - base, close := testrig.CreateBaseDendrite(t, dbType) - base.Cfg.Global.Presence.EnableOutbound = true - base.Cfg.Global.Presence.EnableInbound = true + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + cfg.Global.Presence.EnableOutbound = true + cfg.Global.Presence.EnableInbound = true defer close() + natsInstance := jetstream.NATSInstance{} - jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) - defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream) - AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{}) + jsctx, _ := natsInstance.Prepare(processCtx, &cfg.Global.JetStream) + defer jetstream.DeleteAllStreams(jsctx, &cfg.Global.JetStream) + AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{}, caches, caching.DisableMetrics) w := httptest.NewRecorder() - base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", "/_matrix/client/v3/sync", test.WithQueryParams(map[string]string{ + routers.Client.ServeHTTP(w, test.NewRequest(t, "GET", "/_matrix/client/v3/sync", test.WithQueryParams(map[string]string{ "access_token": alice.AccessToken, "timeout": "0", "set_presence": "online", @@ -408,17 +427,20 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) { userType = "real user" } - base, close := testrig.CreateBaseDendrite(t, dbType) + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) defer close() + natsInstance := jetstream.NATSInstance{} - jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) - defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream) + jsctx, _ := natsInstance.Prepare(processCtx, &cfg.Global.JetStream) + defer jetstream.DeleteAllStreams(jsctx, &cfg.Global.JetStream) // Use the actual internal roomserver API - rsAPI := roomserver.NewInternalAPI(base) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) - - AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{aliceDev, bobDev}}, rsAPI) + AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{aliceDev, bobDev}}, rsAPI, caches, caching.DisableMetrics) for _, tc := range testCases { testname := fmt.Sprintf("%s - %s", tc.historyVisibility, userType) @@ -433,7 +455,7 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) { if err := api.SendEvents(ctx, rsAPI, api.KindNew, eventsToSend, "test", "test", "test", nil, false); err != nil { t.Fatalf("failed to send events: %v", err) } - syncUntil(t, base, aliceDev.AccessToken, false, + syncUntil(t, routers, aliceDev.AccessToken, false, func(syncBody string) bool { path := fmt.Sprintf(`rooms.join.%s.timeline.events.#(content.body=="%s")`, room.ID, beforeJoinBody) return gjson.Get(syncBody, path).Exists() @@ -442,7 +464,7 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) { // There is only one event, we expect only to be able to see this, if the room is world_readable w := httptest.NewRecorder() - base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/messages", room.ID), test.WithQueryParams(map[string]string{ + routers.Client.ServeHTTP(w, test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/messages", room.ID), test.WithQueryParams(map[string]string{ "access_token": bobDev.AccessToken, "dir": "b", "filter": `{"lazy_load_members":true}`, // check that lazy loading doesn't break history visibility @@ -453,7 +475,7 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) { } // We only care about the returned events at this point var res struct { - Chunk []gomatrixserverlib.ClientEvent `json:"chunk"` + Chunk []synctypes.ClientEvent `json:"chunk"` } if err := json.NewDecoder(w.Body).Decode(&res); err != nil { t.Errorf("failed to decode response body: %s", err) @@ -473,7 +495,7 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) { if err := api.SendEvents(ctx, rsAPI, api.KindNew, eventsToSend, "test", "test", "test", nil, false); err != nil { t.Fatalf("failed to send events: %v", err) } - syncUntil(t, base, aliceDev.AccessToken, false, + syncUntil(t, routers, aliceDev.AccessToken, false, func(syncBody string) bool { path := fmt.Sprintf(`rooms.join.%s.timeline.events.#(content.body=="%s")`, room.ID, afterJoinBody) return gjson.Get(syncBody, path).Exists() @@ -482,7 +504,7 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) { // Verify the messages after/before invite are visible or not w = httptest.NewRecorder() - base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/messages", room.ID), test.WithQueryParams(map[string]string{ + routers.Client.ServeHTTP(w, test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/messages", room.ID), test.WithQueryParams(map[string]string{ "access_token": bobDev.AccessToken, "dir": "b", }))) @@ -501,7 +523,7 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) { } } -func verifyEventVisible(t *testing.T, wantVisible bool, wantVisibleEvent *gomatrixserverlib.HeaderedEvent, chunk []gomatrixserverlib.ClientEvent) { +func verifyEventVisible(t *testing.T, wantVisible bool, wantVisibleEvent *gomatrixserverlib.HeaderedEvent, chunk []synctypes.ClientEvent) { t.Helper() if wantVisible { for _, ev := range chunk { @@ -591,10 +613,10 @@ func TestGetMembership(t *testing.T) { })) }, additionalEvents: func(t *testing.T, room *test.Room) { - room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, alice, spec.MRoomMember, map[string]interface{}{ "membership": "leave", }, test.WithStateKey(alice.ID)) - room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, bob, spec.MRoomMember, map[string]interface{}{ "membership": "join", }, test.WithStateKey(bob.ID)) }, @@ -610,10 +632,10 @@ func TestGetMembership(t *testing.T) { })) }, additionalEvents: func(t *testing.T, room *test.Room) { - room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, bob, spec.MRoomMember, map[string]interface{}{ "membership": "join", }, test.WithStateKey(bob.ID)) - room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, alice, spec.MRoomMember, map[string]interface{}{ "membership": "leave", }, test.WithStateKey(alice.ID)) }, @@ -629,7 +651,7 @@ func TestGetMembership(t *testing.T) { })) }, additionalEvents: func(t *testing.T, room *test.Room) { - room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, alice, spec.MRoomMember, map[string]interface{}{ "membership": "leave", }, test.WithStateKey(alice.ID)) }, @@ -645,7 +667,7 @@ func TestGetMembership(t *testing.T) { })) }, additionalEvents: func(t *testing.T, room *test.Room) { - room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, bob, spec.MRoomMember, map[string]interface{}{ "membership": "join", }, test.WithStateKey(bob.ID)) }, @@ -685,10 +707,10 @@ func TestGetMembership(t *testing.T) { })) }, additionalEvents: func(t *testing.T, room *test.Room) { - room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, bob, spec.MRoomMember, map[string]interface{}{ "membership": "join", }, test.WithStateKey(bob.ID)) - room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, bob, spec.MRoomMember, map[string]interface{}{ "membership": "leave", }, test.WithStateKey(bob.ID)) }, @@ -708,17 +730,20 @@ func TestGetMembership(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - base, close := testrig.CreateBaseDendrite(t, dbType) + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) defer close() - - jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) - defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream) + natsInstance := jetstream.NATSInstance{} + jsctx, _ := natsInstance.Prepare(processCtx, &cfg.Global.JetStream) + defer jetstream.DeleteAllStreams(jsctx, &cfg.Global.JetStream) // Use an actual roomserver for this - rsAPI := roomserver.NewInternalAPI(base) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) - AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{aliceDev, bobDev}}, rsAPI) + AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{aliceDev, bobDev}}, rsAPI, caches, caching.DisableMetrics) for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { @@ -738,7 +763,7 @@ func TestGetMembership(t *testing.T) { if tc.useSleep { time.Sleep(time.Millisecond * 100) } else { - syncUntil(t, base, aliceDev.AccessToken, false, func(syncBody string) bool { + syncUntil(t, routers, aliceDev.AccessToken, false, func(syncBody string) bool { // wait for the last sent eventID to come down sync path := fmt.Sprintf(`rooms.join.%s.timeline.events.#(event_id=="%s")`, room.ID, room.Events()[len(room.Events())-1].EventID()) return gjson.Get(syncBody, path).Exists() @@ -746,7 +771,7 @@ func TestGetMembership(t *testing.T) { } w := httptest.NewRecorder() - base.PublicClientAPIMux.ServeHTTP(w, tc.request(t, room)) + routers.Client.ServeHTTP(w, tc.request(t, room)) if w.Code != 200 && tc.wantOK { t.Logf("%s", w.Body.String()) t.Fatalf("got HTTP %d want %d", w.Code, 200) @@ -779,16 +804,19 @@ func testSendToDevice(t *testing.T, dbType test.DBType) { AccountType: userapi.AccountTypeUser, } - base, baseClose := testrig.CreateBaseDendrite(t, dbType) - defer baseClose() + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + defer close() + natsInstance := jetstream.NATSInstance{} - jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) - defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream) - - AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{}) + jsctx, _ := natsInstance.Prepare(processCtx, &cfg.Global.JetStream) + defer jetstream.DeleteAllStreams(jsctx, &cfg.Global.JetStream) + AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{}, caches, caching.DisableMetrics) producer := producers.SyncAPIProducer{ - TopicSendToDeviceEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent), + TopicSendToDeviceEvent: cfg.Global.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent), JetStream: jsctx, } @@ -874,7 +902,7 @@ func testSendToDevice(t *testing.T, dbType test.DBType) { } } - syncUntil(t, base, alice.AccessToken, + syncUntil(t, routers, alice.AccessToken, len(tc.want) == 0, func(body string) bool { return gjson.Get(body, fmt.Sprintf(`to_device.events.#(content.dummy=="message %d")`, msgCounter)).Exists() @@ -883,7 +911,7 @@ func testSendToDevice(t *testing.T, dbType test.DBType) { // Execute a /sync request, recording the response w := httptest.NewRecorder() - base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", "/_matrix/client/v3/sync", test.WithQueryParams(map[string]string{ + routers.Client.ServeHTTP(w, test.NewRequest(t, "GET", "/_matrix/client/v3/sync", test.WithQueryParams(map[string]string{ "access_token": alice.AccessToken, "since": tc.since, }))) @@ -997,14 +1025,18 @@ func testContext(t *testing.T, dbType test.DBType) { AccountType: userapi.AccountTypeUser, } - base, baseClose := testrig.CreateBaseDendrite(t, dbType) - defer baseClose() + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + defer close() // Use an actual roomserver for this - rsAPI := roomserver.NewInternalAPI(base) + natsInstance := jetstream.NATSInstance{} + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) - AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, rsAPI) + AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, rsAPI, caches, caching.DisableMetrics) room := test.NewRoom(t, user) @@ -1017,10 +1049,10 @@ func testContext(t *testing.T, dbType test.DBType) { t.Fatalf("failed to send events: %v", err) } - jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) - defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream) + jsctx, _ := natsInstance.Prepare(processCtx, &cfg.Global.JetStream) + defer jetstream.DeleteAllStreams(jsctx, &cfg.Global.JetStream) - syncUntil(t, base, alice.AccessToken, false, func(syncBody string) bool { + syncUntil(t, routers, alice.AccessToken, false, func(syncBody string) bool { // wait for the last sent eventID to come down sync path := fmt.Sprintf(`rooms.join.%s.timeline.events.#(event_id=="%s")`, room.ID, thirdMsg.EventID()) return gjson.Get(syncBody, path).Exists() @@ -1047,7 +1079,7 @@ func testContext(t *testing.T, dbType test.DBType) { params[k] = v } } - base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", requestPath, test.WithQueryParams(params))) + routers.Client.ServeHTTP(w, test.NewRequest(t, "GET", requestPath, test.WithQueryParams(params))) if tc.wantError && w.Code == 200 { t.Fatalf("Expected an error, but got none") @@ -1074,11 +1106,97 @@ func testContext(t *testing.T, dbType test.DBType) { } } +func TestUpdateRelations(t *testing.T) { + testCases := []struct { + name string + eventContent map[string]interface{} + eventType string + }{ + { + name: "empty event content should not error", + }, + { + name: "unable to unmarshal event should not error", + eventContent: map[string]interface{}{ + "m.relates_to": map[string]interface{}{ + "event_id": map[string]interface{}{}, // this should be a string and not struct + }, + }, + }, + { + name: "empty event ID is ignored", + eventContent: map[string]interface{}{ + "m.relates_to": map[string]interface{}{ + "event_id": "", + }, + }, + }, + { + name: "empty rel_type is ignored", + eventContent: map[string]interface{}{ + "m.relates_to": map[string]interface{}{ + "event_id": "$randomEventID", + "rel_type": "", + }, + }, + }, + { + name: "redactions are ignored", + eventType: spec.MRoomRedaction, + eventContent: map[string]interface{}{ + "m.relates_to": map[string]interface{}{ + "event_id": "$randomEventID", + "rel_type": "m.replace", + }, + }, + }, + { + name: "valid event is correctly written", + eventContent: map[string]interface{}{ + "m.relates_to": map[string]interface{}{ + "event_id": "$randomEventID", + "rel_type": "m.replace", + }, + }, + }, + } + + ctx := context.Background() + + alice := test.NewUser(t) + room := test.NewRoom(t, alice) + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + t.Cleanup(close) + db, err := storage.NewSyncServerDatasource(processCtx.Context(), cm, &cfg.SyncAPI.Database) + if err != nil { + t.Fatal(err) + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + evType := "m.room.message" + if tc.eventType != "" { + evType = tc.eventType + } + ev := room.CreateEvent(t, alice, evType, tc.eventContent) + err = db.UpdateRelations(ctx, ev) + if err != nil { + t.Fatal(err) + } + }) + } + }) +} + func syncUntil(t *testing.T, - base *base.BaseDendrite, accessToken string, + routers httputil.Routers, accessToken string, skip bool, checkFunc func(syncBody string) bool, ) { + t.Helper() if checkFunc == nil { t.Fatalf("No checkFunc defined") } @@ -1092,7 +1210,7 @@ func syncUntil(t *testing.T, go func() { for { w := httptest.NewRecorder() - base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", "/_matrix/client/v3/sync", test.WithQueryParams(map[string]string{ + routers.Client.ServeHTTP(w, test.NewRequest(t, "GET", "/_matrix/client/v3/sync", test.WithQueryParams(map[string]string{ "access_token": accessToken, "timeout": "1000", }))) @@ -1110,14 +1228,14 @@ func syncUntil(t *testing.T, } } -func toNATSMsgs(t *testing.T, base *base.BaseDendrite, input ...*gomatrixserverlib.HeaderedEvent) []*nats.Msg { +func toNATSMsgs(t *testing.T, cfg *config.Dendrite, input ...*gomatrixserverlib.HeaderedEvent) []*nats.Msg { result := make([]*nats.Msg, len(input)) for i, ev := range input { var addsStateIDs []string if ev.StateKey() != nil { addsStateIDs = append(addsStateIDs, ev.EventID()) } - result[i] = testrig.NewOutputEventMsg(t, base, ev.RoomID(), api.OutputEvent{ + result[i] = testrig.NewOutputEventMsg(t, cfg, ev.RoomID(), api.OutputEvent{ Type: rsapi.OutputTypeNewRoomEvent, NewRoomEvent: &rsapi.OutputNewRoomEvent{ Event: ev, diff --git a/syncapi/synctypes/clientevent.go b/syncapi/synctypes/clientevent.go new file mode 100644 index 000000000..0bf04cdb9 --- /dev/null +++ b/syncapi/synctypes/clientevent.go @@ -0,0 +1,91 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package synctypes + +import ( + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" +) + +type ClientEventFormat int + +const ( + // FormatAll will include all client event keys + FormatAll ClientEventFormat = iota + // FormatSync will include only the event keys required by the /sync API. Notably, this + // means the 'room_id' will be missing from the events. + FormatSync +) + +// ClientEvent is an event which is fit for consumption by clients, in accordance with the specification. +type ClientEvent struct { + Content spec.RawJSON `json:"content"` + EventID string `json:"event_id,omitempty"` // EventID is omitted on receipt events + OriginServerTS spec.Timestamp `json:"origin_server_ts,omitempty"` // OriginServerTS is omitted on receipt events + RoomID string `json:"room_id,omitempty"` // RoomID is omitted on /sync responses + Sender string `json:"sender,omitempty"` // Sender is omitted on receipt events + StateKey *string `json:"state_key,omitempty"` + Type string `json:"type"` + Unsigned spec.RawJSON `json:"unsigned,omitempty"` + Redacts string `json:"redacts,omitempty"` +} + +// ToClientEvents converts server events to client events. +func ToClientEvents(serverEvs []*gomatrixserverlib.Event, format ClientEventFormat) []ClientEvent { + evs := make([]ClientEvent, 0, len(serverEvs)) + for _, se := range serverEvs { + if se == nil { + continue // TODO: shouldn't happen? + } + evs = append(evs, ToClientEvent(se, format)) + } + return evs +} + +// HeaderedToClientEvents converts headered server events to client events. +func HeaderedToClientEvents(serverEvs []*gomatrixserverlib.HeaderedEvent, format ClientEventFormat) []ClientEvent { + evs := make([]ClientEvent, 0, len(serverEvs)) + for _, se := range serverEvs { + if se == nil { + continue // TODO: shouldn't happen? + } + evs = append(evs, HeaderedToClientEvent(se, format)) + } + return evs +} + +// ToClientEvent converts a single server event to a client event. +func ToClientEvent(se *gomatrixserverlib.Event, format ClientEventFormat) ClientEvent { + ce := ClientEvent{ + Content: spec.RawJSON(se.Content()), + Sender: se.Sender(), + Type: se.Type(), + StateKey: se.StateKey(), + Unsigned: spec.RawJSON(se.Unsigned()), + OriginServerTS: se.OriginServerTS(), + EventID: se.EventID(), + Redacts: se.Redacts(), + } + if format == FormatAll { + ce.RoomID = se.RoomID() + } + return ce +} + +// HeaderedToClientEvent converts a single headered server event to a client event. +func HeaderedToClientEvent(se *gomatrixserverlib.HeaderedEvent, format ClientEventFormat) ClientEvent { + return ToClientEvent(se.Event, format) +} diff --git a/syncapi/synctypes/clientevent_test.go b/syncapi/synctypes/clientevent_test.go new file mode 100644 index 000000000..ac07917ab --- /dev/null +++ b/syncapi/synctypes/clientevent_test.go @@ -0,0 +1,105 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package synctypes + +import ( + "bytes" + "encoding/json" + "testing" + + "github.com/matrix-org/gomatrixserverlib" +) + +func TestToClientEvent(t *testing.T) { // nolint: gocyclo + ev, err := gomatrixserverlib.NewEventFromTrustedJSON([]byte(`{ + "type": "m.room.name", + "state_key": "", + "event_id": "$test:localhost", + "room_id": "!test:localhost", + "sender": "@test:localhost", + "content": { + "name": "Hello World" + }, + "origin_server_ts": 123456, + "unsigned": { + "prev_content": { + "name": "Goodbye World" + } + } + }`), false, gomatrixserverlib.RoomVersionV1) + if err != nil { + t.Fatalf("failed to create Event: %s", err) + } + ce := ToClientEvent(ev, FormatAll) + if ce.EventID != ev.EventID() { + t.Errorf("ClientEvent.EventID: wanted %s, got %s", ev.EventID(), ce.EventID) + } + if ce.OriginServerTS != ev.OriginServerTS() { + t.Errorf("ClientEvent.OriginServerTS: wanted %d, got %d", ev.OriginServerTS(), ce.OriginServerTS) + } + if ce.StateKey == nil || *ce.StateKey != "" { + t.Errorf("ClientEvent.StateKey: wanted '', got %v", ce.StateKey) + } + if ce.Type != ev.Type() { + t.Errorf("ClientEvent.Type: wanted %s, got %s", ev.Type(), ce.Type) + } + if !bytes.Equal(ce.Content, ev.Content()) { + t.Errorf("ClientEvent.Content: wanted %s, got %s", string(ev.Content()), string(ce.Content)) + } + if !bytes.Equal(ce.Unsigned, ev.Unsigned()) { + t.Errorf("ClientEvent.Unsigned: wanted %s, got %s", string(ev.Unsigned()), string(ce.Unsigned)) + } + if ce.Sender != ev.Sender() { + t.Errorf("ClientEvent.Sender: wanted %s, got %s", ev.Sender(), ce.Sender) + } + j, err := json.Marshal(ce) + if err != nil { + t.Fatalf("failed to Marshal ClientEvent: %s", err) + } + // Marshal sorts keys in structs by the order they are defined in the struct, which is alphabetical + out := `{"content":{"name":"Hello World"},"event_id":"$test:localhost","origin_server_ts":123456,` + + `"room_id":"!test:localhost","sender":"@test:localhost","state_key":"","type":"m.room.name",` + + `"unsigned":{"prev_content":{"name":"Goodbye World"}}}` + if !bytes.Equal([]byte(out), j) { + t.Errorf("ClientEvent marshalled to wrong bytes: wanted %s, got %s", out, string(j)) + } +} + +func TestToClientFormatSync(t *testing.T) { + ev, err := gomatrixserverlib.NewEventFromTrustedJSON([]byte(`{ + "type": "m.room.name", + "state_key": "", + "event_id": "$test:localhost", + "room_id": "!test:localhost", + "sender": "@test:localhost", + "content": { + "name": "Hello World" + }, + "origin_server_ts": 123456, + "unsigned": { + "prev_content": { + "name": "Goodbye World" + } + } + }`), false, gomatrixserverlib.RoomVersionV1) + if err != nil { + t.Fatalf("failed to create Event: %s", err) + } + ce := ToClientEvent(ev, FormatSync) + if ce.RoomID != "" { + t.Errorf("ClientEvent.RoomID: wanted '', got %s", ce.RoomID) + } +} diff --git a/syncapi/synctypes/filter.go b/syncapi/synctypes/filter.go new file mode 100644 index 000000000..c994ddb96 --- /dev/null +++ b/syncapi/synctypes/filter.go @@ -0,0 +1,152 @@ +// Copyright 2017 Jan Christian Grünhage +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package synctypes + +import ( + "errors" +) + +// Filter is used by clients to specify how the server should filter responses to e.g. sync requests +// Specified by: https://spec.matrix.org/v1.6/client-server-api/#filtering +type Filter struct { + EventFields []string `json:"event_fields,omitempty"` + EventFormat string `json:"event_format,omitempty"` + Presence EventFilter `json:"presence,omitempty"` + AccountData EventFilter `json:"account_data,omitempty"` + Room RoomFilter `json:"room,omitempty"` +} + +// EventFilter is used to define filtering rules for events +type EventFilter struct { + Limit int `json:"limit,omitempty"` + NotSenders *[]string `json:"not_senders,omitempty"` + NotTypes *[]string `json:"not_types,omitempty"` + Senders *[]string `json:"senders,omitempty"` + Types *[]string `json:"types,omitempty"` +} + +// RoomFilter is used to define filtering rules for room-related events +type RoomFilter struct { + NotRooms *[]string `json:"not_rooms,omitempty"` + Rooms *[]string `json:"rooms,omitempty"` + Ephemeral RoomEventFilter `json:"ephemeral,omitempty"` + IncludeLeave bool `json:"include_leave,omitempty"` + State StateFilter `json:"state,omitempty"` + Timeline RoomEventFilter `json:"timeline,omitempty"` + AccountData RoomEventFilter `json:"account_data,omitempty"` +} + +// StateFilter is used to define filtering rules for state events +type StateFilter struct { + NotSenders *[]string `json:"not_senders,omitempty"` + NotTypes *[]string `json:"not_types,omitempty"` + Senders *[]string `json:"senders,omitempty"` + Types *[]string `json:"types,omitempty"` + LazyLoadMembers bool `json:"lazy_load_members,omitempty"` + IncludeRedundantMembers bool `json:"include_redundant_members,omitempty"` + NotRooms *[]string `json:"not_rooms,omitempty"` + Rooms *[]string `json:"rooms,omitempty"` + Limit int `json:"limit,omitempty"` + UnreadThreadNotifications bool `json:"unread_thread_notifications,omitempty"` + ContainsURL *bool `json:"contains_url,omitempty"` +} + +// RoomEventFilter is used to define filtering rules for events in rooms +type RoomEventFilter struct { + Limit int `json:"limit,omitempty"` + NotSenders *[]string `json:"not_senders,omitempty"` + NotTypes *[]string `json:"not_types,omitempty"` + Senders *[]string `json:"senders,omitempty"` + Types *[]string `json:"types,omitempty"` + LazyLoadMembers bool `json:"lazy_load_members,omitempty"` + IncludeRedundantMembers bool `json:"include_redundant_members,omitempty"` + NotRooms *[]string `json:"not_rooms,omitempty"` + Rooms *[]string `json:"rooms,omitempty"` + UnreadThreadNotifications bool `json:"unread_thread_notifications,omitempty"` + ContainsURL *bool `json:"contains_url,omitempty"` +} + +// Validate checks if the filter contains valid property values +func (filter *Filter) Validate() error { + if filter.EventFormat != "" && filter.EventFormat != "client" && filter.EventFormat != "federation" { + return errors.New("Bad event_format value. Must be one of [\"client\", \"federation\"]") + } + return nil +} + +// DefaultFilter returns the default filter used by the Matrix server if no filter is provided in +// the request +func DefaultFilter() Filter { + return Filter{ + AccountData: DefaultEventFilter(), + EventFields: nil, + EventFormat: "client", + Presence: DefaultEventFilter(), + Room: RoomFilter{ + AccountData: DefaultRoomEventFilter(), + Ephemeral: DefaultRoomEventFilter(), + IncludeLeave: false, + NotRooms: nil, + Rooms: nil, + State: DefaultStateFilter(), + Timeline: DefaultRoomEventFilter(), + }, + } +} + +// DefaultEventFilter returns the default event filter used by the Matrix server if no filter is +// provided in the request +func DefaultEventFilter() EventFilter { + return EventFilter{ + // parity with synapse: https://github.com/matrix-org/synapse/blob/v1.80.0/synapse/api/filtering.py#L336 + Limit: 10, + NotSenders: nil, + NotTypes: nil, + Senders: nil, + Types: nil, + } +} + +// DefaultStateFilter returns the default state event filter used by the Matrix server if no filter +// is provided in the request +func DefaultStateFilter() StateFilter { + return StateFilter{ + NotSenders: nil, + NotTypes: nil, + Senders: nil, + Types: nil, + LazyLoadMembers: false, + IncludeRedundantMembers: false, + NotRooms: nil, + Rooms: nil, + ContainsURL: nil, + } +} + +// DefaultRoomEventFilter returns the default room event filter used by the Matrix server if no +// filter is provided in the request +func DefaultRoomEventFilter() RoomEventFilter { + return RoomEventFilter{ + // parity with synapse: https://github.com/matrix-org/synapse/blob/v1.80.0/synapse/api/filtering.py#L336 + Limit: 10, + NotSenders: nil, + NotTypes: nil, + Senders: nil, + Types: nil, + NotRooms: nil, + Rooms: nil, + ContainsURL: nil, + } +} diff --git a/syncapi/synctypes/filter_test.go b/syncapi/synctypes/filter_test.go new file mode 100644 index 000000000..d00ee69d6 --- /dev/null +++ b/syncapi/synctypes/filter_test.go @@ -0,0 +1,60 @@ +package synctypes + +import ( + "encoding/json" + "reflect" + "testing" +) + +func Test_Filter(t *testing.T) { + tests := []struct { + name string + input []byte + want RoomEventFilter + }{ + { + name: "empty types filter", + input: []byte(`{ "types": [] }`), + want: RoomEventFilter{ + Limit: 0, + NotSenders: nil, + NotTypes: nil, + Senders: nil, + Types: &[]string{}, + LazyLoadMembers: false, + IncludeRedundantMembers: false, + NotRooms: nil, + Rooms: nil, + ContainsURL: nil, + }, + }, + { + name: "absent types filter", + input: []byte(`{}`), + want: RoomEventFilter{ + Limit: 0, + NotSenders: nil, + NotTypes: nil, + Senders: nil, + Types: nil, + LazyLoadMembers: false, + IncludeRedundantMembers: false, + NotRooms: nil, + Rooms: nil, + ContainsURL: nil, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var f RoomEventFilter + if err := json.Unmarshal(tt.input, &f); err != nil { + t.Fatalf("unable to parse filter: %v", err) + } + if !reflect.DeepEqual(f, tt.want) { + t.Fatalf("Expected %+v\ngot %+v", tt.want, f) + } + }) + } + +} diff --git a/syncapi/types/presence.go b/syncapi/types/presence.go index 30e025b9f..a9c5e3a42 100644 --- a/syncapi/types/presence.go +++ b/syncapi/types/presence.go @@ -18,7 +18,7 @@ import ( "strings" "time" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) type Presence uint8 @@ -60,10 +60,10 @@ func PresenceFromString(input string) (Presence, bool) { type PresenceInternal struct { ClientFields PresenceClientResponse - StreamPos StreamPosition `json:"-"` - UserID string `json:"-"` - LastActiveTS gomatrixserverlib.Timestamp `json:"-"` - Presence Presence `json:"-"` + StreamPos StreamPosition `json:"-"` + UserID string `json:"-"` + LastActiveTS spec.Timestamp `json:"-"` + Presence Presence `json:"-"` } // Equals compares p1 with p2. diff --git a/syncapi/types/provider.go b/syncapi/types/provider.go index 9a533002b..a0fcec0f6 100644 --- a/syncapi/types/provider.go +++ b/syncapi/types/provider.go @@ -4,10 +4,11 @@ import ( "context" "time" - "github.com/matrix-org/gomatrixserverlib" "github.com/sirupsen/logrus" + "github.com/matrix-org/dendrite/syncapi/synctypes" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" ) type SyncRequest struct { @@ -15,7 +16,7 @@ type SyncRequest struct { Log *logrus.Entry Device *userapi.Device Response *Response - Filter gomatrixserverlib.Filter + Filter synctypes.Filter Since StreamingToken Timeout time.Duration WantFullState bool @@ -34,11 +35,11 @@ func (r *SyncRequest) IsRoomPresent(roomID string) bool { return false } switch membership { - case gomatrixserverlib.Join: + case spec.Join: return true - case gomatrixserverlib.Invite: + case spec.Invite: return true - case gomatrixserverlib.Peek: + case spec.Peek: return true default: return false diff --git a/syncapi/types/types.go b/syncapi/types/types.go index 6495dd535..ed17eb560 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -22,9 +22,11 @@ import ( "strings" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/tidwall/gjson" "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/syncapi/synctypes" ) var ( @@ -451,13 +453,13 @@ type UnreadNotifications struct { } type ClientEvents struct { - Events []gomatrixserverlib.ClientEvent `json:"events,omitempty"` + Events []synctypes.ClientEvent `json:"events,omitempty"` } type Timeline struct { - Events []gomatrixserverlib.ClientEvent `json:"events"` - Limited bool `json:"limited"` - PrevBatch *TopologyToken `json:"prev_batch,omitempty"` + Events []synctypes.ClientEvent `json:"events"` + Limited bool `json:"limited"` + PrevBatch *TopologyToken `json:"prev_batch,omitempty"` } type Summary struct { @@ -549,7 +551,7 @@ func NewInviteResponse(event *gomatrixserverlib.HeaderedEvent) *InviteResponse { // Then we'll see if we can create a partial of the invite event itself. // This is needed for clients to work out *who* sent the invite. - inviteEvent := gomatrixserverlib.ToClientEvent(event.Unwrap(), gomatrixserverlib.FormatSync) + inviteEvent := synctypes.ToClientEvent(event.Unwrap(), synctypes.FormatSync) inviteEvent.Unsigned = nil if ev, err := json.Marshal(inviteEvent); err == nil { res.InviteState.Events = append(res.InviteState.Events, ev) @@ -605,11 +607,11 @@ type Peek struct { // OutputReceiptEvent is an entry in the receipt output kafka log type OutputReceiptEvent struct { - UserID string `json:"user_id"` - RoomID string `json:"room_id"` - EventID string `json:"event_id"` - Type string `json:"type"` - Timestamp gomatrixserverlib.Timestamp `json:"timestamp"` + UserID string `json:"user_id"` + RoomID string `json:"room_id"` + EventID string `json:"event_id"` + Type string `json:"type"` + Timestamp spec.Timestamp `json:"timestamp"` } // OutputSendToDeviceEvent is an entry in the send-to-device output kafka log. diff --git a/syncapi/types/types_test.go b/syncapi/types/types_test.go index 74246d964..8bb887ffe 100644 --- a/syncapi/types/types_test.go +++ b/syncapi/types/types_test.go @@ -5,6 +5,7 @@ import ( "reflect" "testing" + "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/gomatrixserverlib" ) @@ -125,7 +126,7 @@ func TestJoinResponse_MarshalJSON(t *testing.T) { { name: "unread notifications are NOT removed, if state is set", fields: fields{ - State: &ClientEvents{Events: []gomatrixserverlib.ClientEvent{{Content: []byte("{}")}}}, + State: &ClientEvents{Events: []synctypes.ClientEvent{{Content: []byte("{}")}}}, UnreadNotifications: &UnreadNotifications{NotificationCount: 1}, }, want: []byte(`{"state":{"events":[{"content":{},"type":""}]},"unread_notifications":{"highlight_count":0,"notification_count":1}}`), @@ -134,7 +135,7 @@ func TestJoinResponse_MarshalJSON(t *testing.T) { name: "roomID is removed from EDUs", fields: fields{ Ephemeral: &ClientEvents{ - Events: []gomatrixserverlib.ClientEvent{ + Events: []synctypes.ClientEvent{ {RoomID: "!someRandomRoomID:test", Content: []byte("{}")}, }, }, diff --git a/test/event.go b/test/event.go index 0c7bf4355..77c44700c 100644 --- a/test/event.go +++ b/test/event.go @@ -21,11 +21,12 @@ import ( "time" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) type eventMods struct { originServerTS time.Time - origin gomatrixserverlib.ServerName + origin spec.ServerName stateKey *string unsigned interface{} keyID gomatrixserverlib.KeyID @@ -71,7 +72,7 @@ func WithPrivateKey(pkey ed25519.PrivateKey) eventModifier { } } -func WithOrigin(origin gomatrixserverlib.ServerName) eventModifier { +func WithOrigin(origin spec.ServerName) eventModifier { return func(e *eventMods) { e.origin = origin } diff --git a/test/memory_federation_db.go b/test/memory_federation_db.go index de0dc54eb..c29162fe1 100644 --- a/test/memory_federation_db.go +++ b/test/memory_federation_db.go @@ -24,6 +24,7 @@ import ( "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" "github.com/matrix-org/dendrite/federationapi/types" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) var nidMutex sync.Mutex @@ -31,28 +32,28 @@ var nid = int64(0) type InMemoryFederationDatabase struct { dbMutex sync.Mutex - pendingPDUServers map[gomatrixserverlib.ServerName]struct{} - pendingEDUServers map[gomatrixserverlib.ServerName]struct{} - blacklistedServers map[gomatrixserverlib.ServerName]struct{} - assumedOffline map[gomatrixserverlib.ServerName]struct{} + pendingPDUServers map[spec.ServerName]struct{} + pendingEDUServers map[spec.ServerName]struct{} + blacklistedServers map[spec.ServerName]struct{} + assumedOffline map[spec.ServerName]struct{} pendingPDUs map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent pendingEDUs map[*receipt.Receipt]*gomatrixserverlib.EDU - associatedPDUs map[gomatrixserverlib.ServerName]map[*receipt.Receipt]struct{} - associatedEDUs map[gomatrixserverlib.ServerName]map[*receipt.Receipt]struct{} - relayServers map[gomatrixserverlib.ServerName][]gomatrixserverlib.ServerName + associatedPDUs map[spec.ServerName]map[*receipt.Receipt]struct{} + associatedEDUs map[spec.ServerName]map[*receipt.Receipt]struct{} + relayServers map[spec.ServerName][]spec.ServerName } func NewInMemoryFederationDatabase() *InMemoryFederationDatabase { return &InMemoryFederationDatabase{ - pendingPDUServers: make(map[gomatrixserverlib.ServerName]struct{}), - pendingEDUServers: make(map[gomatrixserverlib.ServerName]struct{}), - blacklistedServers: make(map[gomatrixserverlib.ServerName]struct{}), - assumedOffline: make(map[gomatrixserverlib.ServerName]struct{}), + pendingPDUServers: make(map[spec.ServerName]struct{}), + pendingEDUServers: make(map[spec.ServerName]struct{}), + blacklistedServers: make(map[spec.ServerName]struct{}), + assumedOffline: make(map[spec.ServerName]struct{}), pendingPDUs: make(map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent), pendingEDUs: make(map[*receipt.Receipt]*gomatrixserverlib.EDU), - associatedPDUs: make(map[gomatrixserverlib.ServerName]map[*receipt.Receipt]struct{}), - associatedEDUs: make(map[gomatrixserverlib.ServerName]map[*receipt.Receipt]struct{}), - relayServers: make(map[gomatrixserverlib.ServerName][]gomatrixserverlib.ServerName), + associatedPDUs: make(map[spec.ServerName]map[*receipt.Receipt]struct{}), + associatedEDUs: make(map[spec.ServerName]map[*receipt.Receipt]struct{}), + relayServers: make(map[spec.ServerName][]spec.ServerName), } } @@ -88,7 +89,7 @@ func (d *InMemoryFederationDatabase) StoreJSON( func (d *InMemoryFederationDatabase) GetPendingPDUs( ctx context.Context, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, limit int, ) (pdus map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent, err error) { d.dbMutex.Lock() @@ -112,7 +113,7 @@ func (d *InMemoryFederationDatabase) GetPendingPDUs( func (d *InMemoryFederationDatabase) GetPendingEDUs( ctx context.Context, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, limit int, ) (edus map[*receipt.Receipt]*gomatrixserverlib.EDU, err error) { d.dbMutex.Lock() @@ -136,7 +137,7 @@ func (d *InMemoryFederationDatabase) GetPendingEDUs( func (d *InMemoryFederationDatabase) AssociatePDUWithDestinations( ctx context.Context, - destinations map[gomatrixserverlib.ServerName]struct{}, + destinations map[spec.ServerName]struct{}, dbReceipt *receipt.Receipt, ) error { d.dbMutex.Lock() @@ -158,7 +159,7 @@ func (d *InMemoryFederationDatabase) AssociatePDUWithDestinations( func (d *InMemoryFederationDatabase) AssociateEDUWithDestinations( ctx context.Context, - destinations map[gomatrixserverlib.ServerName]struct{}, + destinations map[spec.ServerName]struct{}, dbReceipt *receipt.Receipt, eduType string, expireEDUTypes map[string]time.Duration, @@ -182,7 +183,7 @@ func (d *InMemoryFederationDatabase) AssociateEDUWithDestinations( func (d *InMemoryFederationDatabase) CleanPDUs( ctx context.Context, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, receipts []*receipt.Receipt, ) error { d.dbMutex.Lock() @@ -199,7 +200,7 @@ func (d *InMemoryFederationDatabase) CleanPDUs( func (d *InMemoryFederationDatabase) CleanEDUs( ctx context.Context, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, receipts []*receipt.Receipt, ) error { d.dbMutex.Lock() @@ -216,7 +217,7 @@ func (d *InMemoryFederationDatabase) CleanEDUs( func (d *InMemoryFederationDatabase) GetPendingPDUCount( ctx context.Context, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, ) (int64, error) { d.dbMutex.Lock() defer d.dbMutex.Unlock() @@ -230,7 +231,7 @@ func (d *InMemoryFederationDatabase) GetPendingPDUCount( func (d *InMemoryFederationDatabase) GetPendingEDUCount( ctx context.Context, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, ) (int64, error) { d.dbMutex.Lock() defer d.dbMutex.Unlock() @@ -244,11 +245,11 @@ func (d *InMemoryFederationDatabase) GetPendingEDUCount( func (d *InMemoryFederationDatabase) GetPendingPDUServerNames( ctx context.Context, -) ([]gomatrixserverlib.ServerName, error) { +) ([]spec.ServerName, error) { d.dbMutex.Lock() defer d.dbMutex.Unlock() - servers := []gomatrixserverlib.ServerName{} + servers := []spec.ServerName{} for server := range d.pendingPDUServers { servers = append(servers, server) } @@ -257,11 +258,11 @@ func (d *InMemoryFederationDatabase) GetPendingPDUServerNames( func (d *InMemoryFederationDatabase) GetPendingEDUServerNames( ctx context.Context, -) ([]gomatrixserverlib.ServerName, error) { +) ([]spec.ServerName, error) { d.dbMutex.Lock() defer d.dbMutex.Unlock() - servers := []gomatrixserverlib.ServerName{} + servers := []spec.ServerName{} for server := range d.pendingEDUServers { servers = append(servers, server) } @@ -269,7 +270,7 @@ func (d *InMemoryFederationDatabase) GetPendingEDUServerNames( } func (d *InMemoryFederationDatabase) AddServerToBlacklist( - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, ) error { d.dbMutex.Lock() defer d.dbMutex.Unlock() @@ -279,7 +280,7 @@ func (d *InMemoryFederationDatabase) AddServerToBlacklist( } func (d *InMemoryFederationDatabase) RemoveServerFromBlacklist( - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, ) error { d.dbMutex.Lock() defer d.dbMutex.Unlock() @@ -292,12 +293,12 @@ func (d *InMemoryFederationDatabase) RemoveAllServersFromBlacklist() error { d.dbMutex.Lock() defer d.dbMutex.Unlock() - d.blacklistedServers = make(map[gomatrixserverlib.ServerName]struct{}) + d.blacklistedServers = make(map[spec.ServerName]struct{}) return nil } func (d *InMemoryFederationDatabase) IsServerBlacklisted( - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, ) (bool, error) { d.dbMutex.Lock() defer d.dbMutex.Unlock() @@ -312,7 +313,7 @@ func (d *InMemoryFederationDatabase) IsServerBlacklisted( func (d *InMemoryFederationDatabase) SetServerAssumedOffline( ctx context.Context, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, ) error { d.dbMutex.Lock() defer d.dbMutex.Unlock() @@ -323,7 +324,7 @@ func (d *InMemoryFederationDatabase) SetServerAssumedOffline( func (d *InMemoryFederationDatabase) RemoveServerAssumedOffline( ctx context.Context, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, ) error { d.dbMutex.Lock() defer d.dbMutex.Unlock() @@ -338,13 +339,13 @@ func (d *InMemoryFederationDatabase) RemoveAllServersAssumedOffine( d.dbMutex.Lock() defer d.dbMutex.Unlock() - d.assumedOffline = make(map[gomatrixserverlib.ServerName]struct{}) + d.assumedOffline = make(map[spec.ServerName]struct{}) return nil } func (d *InMemoryFederationDatabase) IsServerAssumedOffline( ctx context.Context, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, ) (bool, error) { d.dbMutex.Lock() defer d.dbMutex.Unlock() @@ -359,12 +360,12 @@ func (d *InMemoryFederationDatabase) IsServerAssumedOffline( func (d *InMemoryFederationDatabase) P2PGetRelayServersForServer( ctx context.Context, - serverName gomatrixserverlib.ServerName, -) ([]gomatrixserverlib.ServerName, error) { + serverName spec.ServerName, +) ([]spec.ServerName, error) { d.dbMutex.Lock() defer d.dbMutex.Unlock() - knownRelayServers := []gomatrixserverlib.ServerName{} + knownRelayServers := []spec.ServerName{} if relayServers, ok := d.relayServers[serverName]; ok { knownRelayServers = relayServers } @@ -374,8 +375,8 @@ func (d *InMemoryFederationDatabase) P2PGetRelayServersForServer( func (d *InMemoryFederationDatabase) P2PAddRelayServersForServer( ctx context.Context, - serverName gomatrixserverlib.ServerName, - relayServers []gomatrixserverlib.ServerName, + serverName spec.ServerName, + relayServers []spec.ServerName, ) error { d.dbMutex.Lock() defer d.dbMutex.Unlock() @@ -401,8 +402,8 @@ func (d *InMemoryFederationDatabase) P2PAddRelayServersForServer( func (d *InMemoryFederationDatabase) P2PRemoveRelayServersForServer( ctx context.Context, - serverName gomatrixserverlib.ServerName, - relayServers []gomatrixserverlib.ServerName, + serverName spec.ServerName, + relayServers []spec.ServerName, ) error { d.dbMutex.Lock() defer d.dbMutex.Unlock() @@ -426,7 +427,7 @@ func (d *InMemoryFederationDatabase) P2PRemoveRelayServersForServer( return nil } -func (d *InMemoryFederationDatabase) FetchKeys(ctx context.Context, requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) { +func (d *InMemoryFederationDatabase) FetchKeys(ctx context.Context, requests map[gomatrixserverlib.PublicKeyLookupRequest]spec.Timestamp) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) { return nil, nil } @@ -446,11 +447,11 @@ func (d *InMemoryFederationDatabase) GetJoinedHosts(ctx context.Context, roomID return nil, nil } -func (d *InMemoryFederationDatabase) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) { +func (d *InMemoryFederationDatabase) GetAllJoinedHosts(ctx context.Context) ([]spec.ServerName, error) { return nil, nil } -func (d *InMemoryFederationDatabase) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf, excludeBlacklisted bool) ([]gomatrixserverlib.ServerName, error) { +func (d *InMemoryFederationDatabase) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf, excludeBlacklisted bool) ([]spec.ServerName, error) { return nil, nil } @@ -458,19 +459,19 @@ func (d *InMemoryFederationDatabase) RemoveAllServersAssumedOffline(ctx context. return nil } -func (d *InMemoryFederationDatabase) P2PRemoveAllRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName) error { +func (d *InMemoryFederationDatabase) P2PRemoveAllRelayServersForServer(ctx context.Context, serverName spec.ServerName) error { return nil } -func (d *InMemoryFederationDatabase) AddOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error { +func (d *InMemoryFederationDatabase) AddOutboundPeek(ctx context.Context, serverName spec.ServerName, roomID, peekID string, renewalInterval int64) error { return nil } -func (d *InMemoryFederationDatabase) RenewOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error { +func (d *InMemoryFederationDatabase) RenewOutboundPeek(ctx context.Context, serverName spec.ServerName, roomID, peekID string, renewalInterval int64) error { return nil } -func (d *InMemoryFederationDatabase) GetOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.OutboundPeek, error) { +func (d *InMemoryFederationDatabase) GetOutboundPeek(ctx context.Context, serverName spec.ServerName, roomID, peekID string) (*types.OutboundPeek, error) { return nil, nil } @@ -478,15 +479,15 @@ func (d *InMemoryFederationDatabase) GetOutboundPeeks(ctx context.Context, roomI return nil, nil } -func (d *InMemoryFederationDatabase) AddInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error { +func (d *InMemoryFederationDatabase) AddInboundPeek(ctx context.Context, serverName spec.ServerName, roomID, peekID string, renewalInterval int64) error { return nil } -func (d *InMemoryFederationDatabase) RenewInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error { +func (d *InMemoryFederationDatabase) RenewInboundPeek(ctx context.Context, serverName spec.ServerName, roomID, peekID string, renewalInterval int64) error { return nil } -func (d *InMemoryFederationDatabase) GetInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.InboundPeek, error) { +func (d *InMemoryFederationDatabase) GetInboundPeek(ctx context.Context, serverName spec.ServerName, roomID, peekID string) (*types.InboundPeek, error) { return nil, nil } @@ -494,11 +495,11 @@ func (d *InMemoryFederationDatabase) GetInboundPeeks(ctx context.Context, roomID return nil, nil } -func (d *InMemoryFederationDatabase) UpdateNotaryKeys(ctx context.Context, serverName gomatrixserverlib.ServerName, serverKeys gomatrixserverlib.ServerKeys) error { +func (d *InMemoryFederationDatabase) UpdateNotaryKeys(ctx context.Context, serverName spec.ServerName, serverKeys gomatrixserverlib.ServerKeys) error { return nil } -func (d *InMemoryFederationDatabase) GetNotaryKeys(ctx context.Context, serverName gomatrixserverlib.ServerName, optKeyIDs []gomatrixserverlib.KeyID) ([]gomatrixserverlib.ServerKeys, error) { +func (d *InMemoryFederationDatabase) GetNotaryKeys(ctx context.Context, serverName spec.ServerName, optKeyIDs []gomatrixserverlib.KeyID) ([]gomatrixserverlib.ServerKeys, error) { return nil, nil } diff --git a/test/memory_relay_db.go b/test/memory_relay_db.go index db93919df..eecc23fe7 100644 --- a/test/memory_relay_db.go +++ b/test/memory_relay_db.go @@ -21,13 +21,14 @@ import ( "sync" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) type InMemoryRelayDatabase struct { nid int64 nidMutex sync.Mutex transactions map[int64]json.RawMessage - associations map[gomatrixserverlib.ServerName][]int64 + associations map[spec.ServerName][]int64 } func NewInMemoryRelayDatabase() *InMemoryRelayDatabase { @@ -35,7 +36,7 @@ func NewInMemoryRelayDatabase() *InMemoryRelayDatabase { nid: 1, nidMutex: sync.Mutex{}, transactions: make(map[int64]json.RawMessage), - associations: make(map[gomatrixserverlib.ServerName][]int64), + associations: make(map[spec.ServerName][]int64), } } @@ -43,7 +44,7 @@ func (d *InMemoryRelayDatabase) InsertQueueEntry( ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, nid int64, ) error { if _, ok := d.associations[serverName]; !ok { @@ -56,7 +57,7 @@ func (d *InMemoryRelayDatabase) InsertQueueEntry( func (d *InMemoryRelayDatabase) DeleteQueueEntries( ctx context.Context, txn *sql.Tx, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, jsonNIDs []int64, ) error { for _, nid := range jsonNIDs { @@ -72,7 +73,7 @@ func (d *InMemoryRelayDatabase) DeleteQueueEntries( func (d *InMemoryRelayDatabase) SelectQueueEntries( ctx context.Context, - txn *sql.Tx, serverName gomatrixserverlib.ServerName, + txn *sql.Tx, serverName spec.ServerName, limit int, ) ([]int64, error) { results := []int64{} @@ -92,7 +93,7 @@ func (d *InMemoryRelayDatabase) SelectQueueEntries( func (d *InMemoryRelayDatabase) SelectQueueEntryCount( ctx context.Context, txn *sql.Tx, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, ) (int64, error) { return int64(len(d.associations[serverName])), nil } diff --git a/test/room.go b/test/room.go index 685876cb0..0d80a5940 100644 --- a/test/room.go +++ b/test/room.go @@ -22,6 +22,7 @@ import ( "time" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/internal/eventutil" ) @@ -111,18 +112,18 @@ func (r *Room) insertCreateEvents(t *testing.T) { hisVis.HistoryVisibility = r.visibility } - r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomCreate, map[string]interface{}{ + r.CreateAndInsert(t, r.creator, spec.MRoomCreate, map[string]interface{}{ "creator": r.creator.ID, "room_version": r.Version, }, WithStateKey("")) - r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomMember, map[string]interface{}{ + r.CreateAndInsert(t, r.creator, spec.MRoomMember, map[string]interface{}{ "membership": "join", }, WithStateKey(r.creator.ID)) - r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomPowerLevels, plContent, WithStateKey("")) - r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomJoinRules, joinRule, WithStateKey("")) - r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomHistoryVisibility, hisVis, WithStateKey("")) + r.CreateAndInsert(t, r.creator, spec.MRoomPowerLevels, plContent, WithStateKey("")) + r.CreateAndInsert(t, r.creator, spec.MRoomJoinRules, joinRule, WithStateKey("")) + r.CreateAndInsert(t, r.creator, spec.MRoomHistoryVisibility, hisVis, WithStateKey("")) if r.guestCanJoin { - r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomGuestAccess, map[string]string{ + r.CreateAndInsert(t, r.creator, spec.MRoomGuestAccess, map[string]string{ "guest_access": "can_join", }, WithStateKey("")) } @@ -152,7 +153,7 @@ func (r *Room) CreateEvent(t *testing.T, creator *User, eventType string, conten mod.origin = creator.srvName } - var unsigned gomatrixserverlib.RawJSON + var unsigned spec.RawJSON var err error if mod.unsigned != nil { unsigned, err = json.Marshal(mod.unsigned) diff --git a/test/testrig/base.go b/test/testrig/base.go index bb8fded21..953704595 100644 --- a/test/testrig/base.go +++ b/test/testrig/base.go @@ -19,13 +19,12 @@ import ( "path/filepath" "testing" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/test" - "github.com/nats-io/nats.go" ) -func CreateBaseDendrite(t *testing.T, dbType test.DBType) (*base.BaseDendrite, func()) { +func CreateConfig(t *testing.T, dbType test.DBType) (*config.Dendrite, *process.ProcessContext, func()) { var cfg config.Dendrite cfg.Defaults(config.DefaultOpts{ Generate: false, @@ -33,6 +32,7 @@ func CreateBaseDendrite(t *testing.T, dbType test.DBType) (*base.BaseDendrite, f }) cfg.Global.JetStream.InMemory = true cfg.FederationAPI.KeyPerspectives = nil + ctx := process.NewProcessContext() switch dbType { case test.DBTypePostgres: cfg.Global.Defaults(config.DefaultOpts{ // autogen a signing key @@ -51,18 +51,19 @@ func CreateBaseDendrite(t *testing.T, dbType test.DBType) (*base.BaseDendrite, f // use a distinct prefix else concurrent postgres/sqlite runs will clash since NATS will use // the file system event with InMemory=true :( cfg.Global.JetStream.TopicPrefix = fmt.Sprintf("Test_%d_", dbType) - connStr, close := test.PrepareDBConnectionString(t, dbType) + cfg.SyncAPI.Fulltext.InMemory = true + + connStr, closeDb := test.PrepareDBConnectionString(t, dbType) cfg.Global.DatabaseOptions = config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), MaxOpenConnections: 10, MaxIdleConnections: 2, ConnMaxLifetimeSeconds: 60, } - base := base.NewBaseDendrite(&cfg, base.DisableMetrics) - return base, func() { - base.ShutdownDendrite() - base.WaitForShutdown() - close() + return &cfg, ctx, func() { + ctx.ShutdownDendrite() + ctx.WaitForShutdown() + closeDb() } case test.DBTypeSQLite: cfg.Defaults(config.DefaultOpts{ @@ -70,7 +71,7 @@ func CreateBaseDendrite(t *testing.T, dbType test.DBType) (*base.BaseDendrite, f SingleDatabase: false, }) cfg.Global.ServerName = "test" - + cfg.SyncAPI.Fulltext.InMemory = true // use a distinct prefix else concurrent postgres/sqlite runs will clash since NATS will use // the file system event with InMemory=true :( cfg.Global.JetStream.TopicPrefix = fmt.Sprintf("Test_%d_", dbType) @@ -86,30 +87,13 @@ func CreateBaseDendrite(t *testing.T, dbType test.DBType) (*base.BaseDendrite, f cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(filepath.Join("file://", tempDir, "userapi.db")) cfg.RelayAPI.Database.ConnectionString = config.DataSource(filepath.Join("file://", tempDir, "relayapi.db")) - base := base.NewBaseDendrite(&cfg, base.DisableMetrics) - return base, func() { - base.ShutdownDendrite() - base.WaitForShutdown() + return &cfg, ctx, func() { + ctx.ShutdownDendrite() + ctx.WaitForShutdown() t.Cleanup(func() {}) // removes t.TempDir, where all database files are created } default: t.Fatalf("unknown db type: %v", dbType) } - return nil, nil -} - -func Base(cfg *config.Dendrite) (*base.BaseDendrite, nats.JetStreamContext, *nats.Conn) { - if cfg == nil { - cfg = &config.Dendrite{} - cfg.Defaults(config.DefaultOpts{ - Generate: true, - SingleDatabase: false, - }) - } - cfg.Global.JetStream.InMemory = true - cfg.SyncAPI.Fulltext.InMemory = true - cfg.FederationAPI.KeyPerspectives = nil - base := base.NewBaseDendrite(cfg, base.DisableMetrics) - js, jc := base.NATS.Prepare(base.ProcessContext, &cfg.Global.JetStream) - return base, js, jc + return &config.Dendrite{}, nil, func() {} } diff --git a/test/testrig/jetstream.go b/test/testrig/jetstream.go index b880eea43..5f15cfb3e 100644 --- a/test/testrig/jetstream.go +++ b/test/testrig/jetstream.go @@ -4,10 +4,10 @@ import ( "encoding/json" "testing" + "github.com/matrix-org/dendrite/setup/config" "github.com/nats-io/nats.go" "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/jetstream" ) @@ -20,9 +20,9 @@ func MustPublishMsgs(t *testing.T, jsctx nats.JetStreamContext, msgs ...*nats.Ms } } -func NewOutputEventMsg(t *testing.T, base *base.BaseDendrite, roomID string, update api.OutputEvent) *nats.Msg { +func NewOutputEventMsg(t *testing.T, cfg *config.Dendrite, roomID string, update api.OutputEvent) *nats.Msg { t.Helper() - msg := nats.NewMsg(base.Cfg.Global.JetStream.Prefixed(jetstream.OutputRoomEvent)) + msg := nats.NewMsg(cfg.Global.JetStream.Prefixed(jetstream.OutputRoomEvent)) msg.Header.Set(jetstream.RoomEventType, string(update.Type)) msg.Header.Set(jetstream.RoomID, roomID) var err error diff --git a/test/user.go b/test/user.go index 95a8f83e6..9509b95a6 100644 --- a/test/user.go +++ b/test/user.go @@ -17,17 +17,19 @@ package test import ( "crypto/ed25519" "fmt" + "strconv" "sync/atomic" "testing" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) var ( userIDCounter = int64(0) - serverName = gomatrixserverlib.ServerName("test") + serverName = spec.ServerName("test") keyID = gomatrixserverlib.KeyID("ed25519:test") privateKey = ed25519.NewKeyFromSeed([]byte{ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, @@ -47,16 +49,17 @@ var ( type User struct { ID string + Localpart string AccountType api.AccountType // key ID and private key of the server who has this user, if known. keyID gomatrixserverlib.KeyID privKey ed25519.PrivateKey - srvName gomatrixserverlib.ServerName + srvName spec.ServerName } type UserOpt func(*User) -func WithSigningServer(srvName gomatrixserverlib.ServerName, keyID gomatrixserverlib.KeyID, privKey ed25519.PrivateKey) UserOpt { +func WithSigningServer(srvName spec.ServerName, keyID gomatrixserverlib.KeyID, privKey ed25519.PrivateKey) UserOpt { return func(u *User) { u.keyID = keyID u.privKey = privKey @@ -81,6 +84,7 @@ func NewUser(t *testing.T, opts ...UserOpt) *User { WithSigningServer(serverName, keyID, privateKey)(&u) } u.ID = fmt.Sprintf("@%d:%s", counter, u.srvName) + u.Localpart = strconv.Itoa(int(counter)) t.Logf("NewUser: created user %s", u.ID) return &u } diff --git a/userapi/api/api.go b/userapi/api/api.go index fa297f773..3fe6a383d 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -21,8 +21,11 @@ import ( "strings" "time" + "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/userapi/types" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/internal/pushrules" @@ -58,7 +61,7 @@ type MediaUserAPI interface { type FederationUserAPI interface { UploadDeviceKeysAPI QueryOpenIDToken(ctx context.Context, req *QueryOpenIDTokenRequest, res *QueryOpenIDTokenResponse) error - QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error + QueryProfile(ctx context.Context, userID string) (*authtypes.Profile, error) QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) error QuerySignatures(ctx context.Context, req *QuerySignaturesRequest, res *QuerySignaturesResponse) error @@ -83,12 +86,12 @@ type ClientUserAPI interface { LoginTokenInternalAPI UserLoginAPI ClientKeyAPI + ProfileAPI QueryNumericLocalpart(ctx context.Context, req *QueryNumericLocalpartRequest, res *QueryNumericLocalpartResponse) error QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error - QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error QueryPushers(ctx context.Context, req *QueryPushersRequest, res *QueryPushersResponse) error - QueryPushRules(ctx context.Context, req *QueryPushRulesRequest, res *QueryPushRulesResponse) error + QueryPushRules(ctx context.Context, userID string) (*pushrules.AccountRuleSets, error) QueryAccountAvailability(ctx context.Context, req *QueryAccountAvailabilityRequest, res *QueryAccountAvailabilityResponse) error PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error @@ -97,11 +100,9 @@ type ClientUserAPI interface { PerformPasswordUpdate(ctx context.Context, req *PerformPasswordUpdateRequest, res *PerformPasswordUpdateResponse) error PerformPusherDeletion(ctx context.Context, req *PerformPusherDeletionRequest, res *struct{}) error PerformPusherSet(ctx context.Context, req *PerformPusherSetRequest, res *struct{}) error - PerformPushRulesPut(ctx context.Context, req *PerformPushRulesPutRequest, res *struct{}) error + PerformPushRulesPut(ctx context.Context, userID string, ruleSets *pushrules.AccountRuleSets) error PerformAccountDeactivation(ctx context.Context, req *PerformAccountDeactivationRequest, res *PerformAccountDeactivationResponse) error PerformOpenIDTokenCreation(ctx context.Context, req *PerformOpenIDTokenCreationRequest, res *PerformOpenIDTokenCreationResponse) error - SetAvatarURL(ctx context.Context, req *PerformSetAvatarURLRequest, res *PerformSetAvatarURLResponse) error - SetDisplayName(ctx context.Context, req *PerformUpdateDisplayNameRequest, res *PerformUpdateDisplayNameResponse) error QueryNotifications(ctx context.Context, req *QueryNotificationsRequest, res *QueryNotificationsResponse) error InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error PerformKeyBackup(ctx context.Context, req *PerformKeyBackupRequest, res *PerformKeyBackupResponse) error @@ -113,6 +114,12 @@ type ClientUserAPI interface { PerformSaveThreePIDAssociation(ctx context.Context, req *PerformSaveThreePIDAssociationRequest, res *struct{}) error } +type ProfileAPI interface { + QueryProfile(ctx context.Context, userID string) (*authtypes.Profile, error) + SetAvatarURL(ctx context.Context, localpart string, serverName spec.ServerName, avatarURL string) (*authtypes.Profile, bool, error) + SetDisplayName(ctx context.Context, localpart string, serverName spec.ServerName, displayName string) (*authtypes.Profile, bool, error) +} + // custom api functions required by pinecone / p2p demos type QuerySearchProfilesAPI interface { QuerySearchProfiles(ctx context.Context, req *QuerySearchProfilesRequest, res *QuerySearchProfilesResponse) error @@ -224,7 +231,6 @@ type PerformDeviceUpdateRequest struct { } type PerformDeviceUpdateResponse struct { DeviceExists bool - Forbidden bool } type PerformDeviceDeletionRequest struct { @@ -291,22 +297,6 @@ type QueryDevicesResponse struct { Devices []Device } -// QueryProfileRequest is the request for QueryProfile -type QueryProfileRequest struct { - // The user ID to query - UserID string -} - -// QueryProfileResponse is the response for QueryProfile -type QueryProfileResponse struct { - // True if the user exists. Querying for a profile does not create them. - UserExists bool - // The current display name if set. - DisplayName string - // The current avatar URL if set. - AvatarURL string -} - // QuerySearchProfilesRequest is the request for QueryProfile type QuerySearchProfilesRequest struct { // The search string to match @@ -323,9 +313,9 @@ type QuerySearchProfilesResponse struct { // PerformAccountCreationRequest is the request for PerformAccountCreation type PerformAccountCreationRequest struct { - AccountType AccountType // Required: whether this is a guest or user account - Localpart string // Required: The localpart for this account. Ignored if account type is guest. - ServerName gomatrixserverlib.ServerName // optional: if not specified, default server name used instead + AccountType AccountType // Required: whether this is a guest or user account + Localpart string // Required: The localpart for this account. Ignored if account type is guest. + ServerName spec.ServerName // optional: if not specified, default server name used instead AppServiceID string // optional: the application service ID (not user ID) creating this account, if any. Password string // optional: if missing then this account will be a passwordless account @@ -340,10 +330,10 @@ type PerformAccountCreationResponse struct { // PerformAccountCreationRequest is the request for PerformAccountCreation type PerformPasswordUpdateRequest struct { - Localpart string // Required: The localpart for this account. - ServerName gomatrixserverlib.ServerName // Required: The domain for this account. - Password string // Required: The new password to set. - LogoutDevices bool // Optional: Whether to log out all user devices. + Localpart string // Required: The localpart for this account. + ServerName spec.ServerName // Required: The domain for this account. + Password string // Required: The new password to set. + LogoutDevices bool // Optional: Whether to log out all user devices. } // PerformAccountCreationResponse is the response for PerformAccountCreation @@ -367,8 +357,8 @@ type PerformLastSeenUpdateResponse struct { // PerformDeviceCreationRequest is the request for PerformDeviceCreation type PerformDeviceCreationRequest struct { Localpart string - ServerName gomatrixserverlib.ServerName // optional: if blank, default server name used - AccessToken string // optional: if blank one will be made on your behalf + ServerName spec.ServerName // optional: if blank, default server name used + AccessToken string // optional: if blank one will be made on your behalf // optional: if nil an ID is generated for you. If set, replaces any existing device session, // which will generate a new access token and invalidate the old one. DeviceID *string @@ -393,7 +383,7 @@ type PerformDeviceCreationResponse struct { // PerformAccountDeactivationRequest is the request for PerformAccountDeactivation type PerformAccountDeactivationRequest struct { Localpart string - ServerName gomatrixserverlib.ServerName // optional: if blank, default server name used + ServerName spec.ServerName // optional: if blank, default server name used } // PerformAccountDeactivationResponse is the response for PerformAccountDeactivation @@ -443,7 +433,7 @@ type Device struct { AccountType AccountType } -func (d *Device) UserDomain() gomatrixserverlib.ServerName { +func (d *Device) UserDomain() spec.ServerName { _, domain, err := gomatrixserverlib.SplitID('@', d.UserID) if err != nil { // This really is catastrophic because it means that someone @@ -459,7 +449,7 @@ func (d *Device) UserDomain() gomatrixserverlib.ServerName { type Account struct { UserID string Localpart string - ServerName gomatrixserverlib.ServerName + ServerName spec.ServerName AppServiceID string AccountType AccountType // TODO: Associations (e.g. with application services) @@ -525,7 +515,7 @@ const ( type QueryPushersRequest struct { Localpart string - ServerName gomatrixserverlib.ServerName + ServerName spec.ServerName } type QueryPushersResponse struct { @@ -535,13 +525,13 @@ type QueryPushersResponse struct { type PerformPusherSetRequest struct { Pusher // Anonymous field because that's how clientapi unmarshals it. Localpart string - ServerName gomatrixserverlib.ServerName + ServerName spec.ServerName Append bool `json:"append"` } type PerformPusherDeletionRequest struct { Localpart string - ServerName gomatrixserverlib.ServerName + ServerName spec.ServerName SessionID int64 } @@ -566,25 +556,12 @@ const ( HTTPKind PusherKind = "http" ) -type PerformPushRulesPutRequest struct { - UserID string `json:"user_id"` - RuleSets *pushrules.AccountRuleSets `json:"rule_sets"` -} - -type QueryPushRulesRequest struct { - UserID string `json:"user_id"` -} - -type QueryPushRulesResponse struct { - RuleSets *pushrules.AccountRuleSets `json:"rule_sets"` -} - type QueryNotificationsRequest struct { - Localpart string `json:"localpart"` // Required. - ServerName gomatrixserverlib.ServerName `json:"server_name"` // Required. - From string `json:"from,omitempty"` - Limit int `json:"limit,omitempty"` - Only string `json:"only,omitempty"` + Localpart string `json:"localpart"` // Required. + ServerName spec.ServerName `json:"server_name"` // Required. + From string `json:"from,omitempty"` + Limit int `json:"limit,omitempty"` + Only string `json:"only,omitempty"` } type QueryNotificationsResponse struct { @@ -593,26 +570,16 @@ type QueryNotificationsResponse struct { } type Notification struct { - Actions []*pushrules.Action `json:"actions"` // Required. - Event gomatrixserverlib.ClientEvent `json:"event"` // Required. - ProfileTag string `json:"profile_tag"` // Required by Sytest, but actually optional. - Read bool `json:"read"` // Required. - RoomID string `json:"room_id"` // Required. - TS gomatrixserverlib.Timestamp `json:"ts"` // Required. -} - -type PerformSetAvatarURLRequest struct { - Localpart string - ServerName gomatrixserverlib.ServerName - AvatarURL string -} -type PerformSetAvatarURLResponse struct { - Profile *authtypes.Profile `json:"profile"` - Changed bool `json:"changed"` + Actions []*pushrules.Action `json:"actions"` // Required. + Event synctypes.ClientEvent `json:"event"` // Required. + ProfileTag string `json:"profile_tag"` // Required by Sytest, but actually optional. + Read bool `json:"read"` // Required. + RoomID string `json:"room_id"` // Required. + TS spec.Timestamp `json:"ts"` // Required. } type QueryNumericLocalpartRequest struct { - ServerName gomatrixserverlib.ServerName + ServerName spec.ServerName } type QueryNumericLocalpartResponse struct { @@ -621,7 +588,7 @@ type QueryNumericLocalpartResponse struct { type QueryAccountAvailabilityRequest struct { Localpart string - ServerName gomatrixserverlib.ServerName + ServerName spec.ServerName } type QueryAccountAvailabilityResponse struct { @@ -630,7 +597,7 @@ type QueryAccountAvailabilityResponse struct { type QueryAccountByPasswordRequest struct { Localpart string - ServerName gomatrixserverlib.ServerName + ServerName spec.ServerName PlaintextPassword string } @@ -639,29 +606,18 @@ type QueryAccountByPasswordResponse struct { Exists bool } -type PerformUpdateDisplayNameRequest struct { - Localpart string - ServerName gomatrixserverlib.ServerName - DisplayName string -} - -type PerformUpdateDisplayNameResponse struct { - Profile *authtypes.Profile `json:"profile"` - Changed bool `json:"changed"` -} - type QueryLocalpartForThreePIDRequest struct { ThreePID, Medium string } type QueryLocalpartForThreePIDResponse struct { Localpart string - ServerName gomatrixserverlib.ServerName + ServerName spec.ServerName } type QueryThreePIDsForLocalpartRequest struct { Localpart string - ServerName gomatrixserverlib.ServerName + ServerName spec.ServerName } type QueryThreePIDsForLocalpartResponse struct { @@ -673,13 +629,13 @@ type PerformForgetThreePIDRequest QueryLocalpartForThreePIDRequest type PerformSaveThreePIDAssociationRequest struct { ThreePID string Localpart string - ServerName gomatrixserverlib.ServerName + ServerName spec.ServerName Medium string } type QueryAccountByLocalpartRequest struct { Localpart string - ServerName gomatrixserverlib.ServerName + ServerName spec.ServerName } type QueryAccountByLocalpartResponse struct { @@ -752,9 +708,9 @@ type OutputCrossSigningKeyUpdate struct { } type CrossSigningKeyUpdate struct { - MasterKey *gomatrixserverlib.CrossSigningKey `json:"master_key,omitempty"` - SelfSigningKey *gomatrixserverlib.CrossSigningKey `json:"self_signing_key,omitempty"` - UserID string `json:"user_id"` + MasterKey *fclient.CrossSigningKey `json:"master_key,omitempty"` + SelfSigningKey *fclient.CrossSigningKey `json:"self_signing_key,omitempty"` + UserID string `json:"user_id"` } // DeviceKeysEqual returns true if the device keys updates contain the @@ -887,7 +843,7 @@ type PerformClaimKeysResponse struct { } type PerformUploadDeviceKeysRequest struct { - gomatrixserverlib.CrossSigningKeys + fclient.CrossSigningKeys // The user that uploaded the key, should be populated by the clientapi. UserID string } @@ -897,7 +853,7 @@ type PerformUploadDeviceKeysResponse struct { } type PerformUploadDeviceSignaturesRequest struct { - Signatures map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice + Signatures map[string]map[gomatrixserverlib.KeyID]fclient.CrossSigningForKeyOrDevice // The user that uploaded the sig, should be populated by the clientapi. UserID string } @@ -921,9 +877,9 @@ type QueryKeysResponse struct { // Map of user_id to device_id to device_key DeviceKeys map[string]map[string]json.RawMessage // Maps of user_id to cross signing key - MasterKeys map[string]gomatrixserverlib.CrossSigningKey - SelfSigningKeys map[string]gomatrixserverlib.CrossSigningKey - UserSigningKeys map[string]gomatrixserverlib.CrossSigningKey + MasterKeys map[string]fclient.CrossSigningKey + SelfSigningKeys map[string]fclient.CrossSigningKey + UserSigningKeys map[string]fclient.CrossSigningKey // Set if there was a fatal error processing this query Error *KeyError } @@ -978,17 +934,17 @@ type QuerySignaturesResponse struct { // A map of target user ID -> target key/device ID -> origin user ID -> origin key/device ID -> signatures Signatures map[string]map[gomatrixserverlib.KeyID]types.CrossSigningSigMap // A map of target user ID -> cross-signing master key - MasterKeys map[string]gomatrixserverlib.CrossSigningKey + MasterKeys map[string]fclient.CrossSigningKey // A map of target user ID -> cross-signing self-signing key - SelfSigningKeys map[string]gomatrixserverlib.CrossSigningKey + SelfSigningKeys map[string]fclient.CrossSigningKey // A map of target user ID -> cross-signing user-signing key - UserSigningKeys map[string]gomatrixserverlib.CrossSigningKey + UserSigningKeys map[string]fclient.CrossSigningKey // The request error, if any Error *KeyError } type PerformMarkAsStaleRequest struct { UserID string - Domain gomatrixserverlib.ServerName + Domain spec.ServerName DeviceID string } diff --git a/userapi/consumers/clientapi.go b/userapi/consumers/clientapi.go index 51bd2753a..ba72ff350 100644 --- a/userapi/consumers/clientapi.go +++ b/userapi/consumers/clientapi.go @@ -18,6 +18,7 @@ import ( "context" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/nats-io/nats.go" log "github.com/sirupsen/logrus" @@ -38,7 +39,7 @@ type OutputReceiptEventConsumer struct { durable string topic string db storage.UserDatabase - serverName gomatrixserverlib.ServerName + serverName spec.ServerName syncProducer *producers.SyncAPI pgClient pushgateway.Client } @@ -104,7 +105,7 @@ func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msgs []*nats return false } - updated, err := s.db.SetNotificationsRead(ctx, localpart, domain, roomID, uint64(gomatrixserverlib.AsTimestamp(metadata.Timestamp)), true) + updated, err := s.db.SetNotificationsRead(ctx, localpart, domain, roomID, uint64(spec.AsTimestamp(metadata.Timestamp)), true) if err != nil { log.WithError(err).Error("userapi EDU consumer") return false diff --git a/userapi/consumers/devicelistupdate.go b/userapi/consumers/devicelistupdate.go index a65889fcc..3389bb808 100644 --- a/userapi/consumers/devicelistupdate.go +++ b/userapi/consumers/devicelistupdate.go @@ -20,6 +20,7 @@ import ( "github.com/matrix-org/dendrite/userapi/internal" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/nats-io/nats.go" "github.com/sirupsen/logrus" @@ -35,7 +36,7 @@ type DeviceListUpdateConsumer struct { durable string topic string updater *internal.DeviceListUpdater - isLocalServerName func(gomatrixserverlib.ServerName) bool + isLocalServerName func(spec.ServerName) bool } // NewDeviceListUpdateConsumer creates a new DeviceListConsumer. Call Start() to begin consuming from key servers. @@ -72,7 +73,7 @@ func (t *DeviceListUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.M logrus.WithError(err).Errorf("Failed to read from device list update input topic") return true } - origin := gomatrixserverlib.ServerName(msg.Header.Get("origin")) + origin := spec.ServerName(msg.Header.Get("origin")) if _, serverName, err := gomatrixserverlib.SplitID('@', m.UserID); err != nil { return true } else if t.isLocalServerName(serverName) { diff --git a/userapi/consumers/roomserver.go b/userapi/consumers/roomserver.go index 47d330959..834964ce8 100644 --- a/userapi/consumers/roomserver.go +++ b/userapi/consumers/roomserver.go @@ -13,6 +13,7 @@ import ( "github.com/tidwall/gjson" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/nats-io/nats.go" log "github.com/sirupsen/logrus" @@ -23,6 +24,7 @@ import ( "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/producers" @@ -42,11 +44,11 @@ type OutputRoomEventConsumer struct { topic string pgClient pushgateway.Client syncProducer *producers.SyncAPI - msgCounts map[gomatrixserverlib.ServerName]userAPITypes.MessageStats - roomCounts map[gomatrixserverlib.ServerName]map[string]bool // map from serverName to map from rommID to "isEncrypted" + msgCounts map[spec.ServerName]userAPITypes.MessageStats + roomCounts map[spec.ServerName]map[string]bool // map from serverName to map from rommID to "isEncrypted" lastUpdate time.Time countsLock sync.Mutex - serverName gomatrixserverlib.ServerName + serverName spec.ServerName } func NewOutputRoomEventConsumer( @@ -68,8 +70,8 @@ func NewOutputRoomEventConsumer( pgClient: pgClient, rsAPI: rsAPI, syncProducer: syncProducer, - msgCounts: map[gomatrixserverlib.ServerName]userAPITypes.MessageStats{}, - roomCounts: map[gomatrixserverlib.ServerName]map[string]bool{}, + msgCounts: map[spec.ServerName]userAPITypes.MessageStats{}, + roomCounts: map[spec.ServerName]map[string]bool{}, lastUpdate: time.Now(), countsLock: sync.Mutex{}, serverName: cfg.Matrix.ServerName, @@ -118,7 +120,7 @@ func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Ms return true } - if err := s.processMessage(ctx, event, uint64(gomatrixserverlib.AsTimestamp(metadata.Timestamp))); err != nil { + if err := s.processMessage(ctx, event, uint64(spec.AsTimestamp(metadata.Timestamp))); err != nil { log.WithFields(log.Fields{ "event_id": event.EventID(), }).WithError(err).Errorf("userapi consumer: process room event failure") @@ -209,7 +211,7 @@ func (s *OutputRoomEventConsumer) handleRoomUpgrade(ctx context.Context, oldRoom return nil } -func (s *OutputRoomEventConsumer) copyPushrules(ctx context.Context, oldRoomID, newRoomID string, localpart string, serverName gomatrixserverlib.ServerName) error { +func (s *OutputRoomEventConsumer) copyPushrules(ctx context.Context, oldRoomID, newRoomID string, localpart string, serverName spec.ServerName) error { pushRules, err := s.db.QueryPushRules(ctx, localpart, serverName) if err != nil { return fmt.Errorf("failed to query pushrules for user: %w", err) @@ -237,7 +239,7 @@ func (s *OutputRoomEventConsumer) copyPushrules(ctx context.Context, oldRoomID, } // updateMDirect copies the "is_direct" flag from oldRoomID to newROomID -func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID, newRoomID, localpart string, serverName gomatrixserverlib.ServerName, roomSize int) error { +func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID, newRoomID, localpart string, serverName spec.ServerName, roomSize int) error { // this is most likely not a DM, so skip updating m.direct state if roomSize > 2 { return nil @@ -279,7 +281,7 @@ func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID, return nil } -func (s *OutputRoomEventConsumer) copyTags(ctx context.Context, oldRoomID, newRoomID, localpart string, serverName gomatrixserverlib.ServerName) error { +func (s *OutputRoomEventConsumer) copyTags(ctx context.Context, oldRoomID, newRoomID, localpart string, serverName spec.ServerName) error { tag, err := s.db.GetAccountDataByType(ctx, localpart, serverName, oldRoomID, "m.tag") if err != nil && !errors.Is(err, sql.ErrNoRows) { return err @@ -297,14 +299,14 @@ func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *gom } switch { - case event.Type() == gomatrixserverlib.MRoomMember: - cevent := gomatrixserverlib.HeaderedToClientEvent(event, gomatrixserverlib.FormatAll) + case event.Type() == spec.MRoomMember: + cevent := synctypes.HeaderedToClientEvent(event, synctypes.FormatAll) var member *localMembership member, err = newLocalMembership(&cevent) if err != nil { return fmt.Errorf("newLocalMembership: %w", err) } - if member.Membership == gomatrixserverlib.Invite && member.Domain == s.cfg.Matrix.ServerName { + if member.Membership == spec.Invite && member.Domain == s.cfg.Matrix.ServerName { // localRoomMembers only adds joined members. An invite // should also be pushed to the target user. members = append(members, member) @@ -355,10 +357,10 @@ type localMembership struct { gomatrixserverlib.MemberContent UserID string Localpart string - Domain gomatrixserverlib.ServerName + Domain spec.ServerName } -func newLocalMembership(event *gomatrixserverlib.ClientEvent) (*localMembership, error) { +func newLocalMembership(event *synctypes.ClientEvent) (*localMembership, error) { if event.StateKey == nil { return nil, fmt.Errorf("missing state_key") } @@ -417,7 +419,7 @@ func (s *OutputRoomEventConsumer) localRoomMembers(ctx context.Context, roomID s log.WithError(err).Errorf("Parsing MemberContent") continue } - if member.Membership != gomatrixserverlib.Join { + if member.Membership != spec.Join { continue } if member.Domain != s.cfg.Matrix.ServerName { @@ -435,7 +437,7 @@ func (s *OutputRoomEventConsumer) localRoomMembers(ctx context.Context, roomID s // m.room.canonical_alias is consulted. Returns an empty string if the // room has no name. func (s *OutputRoomEventConsumer) roomName(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) (string, error) { - if event.Type() == gomatrixserverlib.MRoomName { + if event.Type() == spec.MRoomName { name, err := unmarshalRoomName(event) if err != nil { return "", err @@ -460,7 +462,7 @@ func (s *OutputRoomEventConsumer) roomName(ctx context.Context, event *gomatrixs return unmarshalRoomName(eventS) } - if event.Type() == gomatrixserverlib.MRoomCanonicalAlias { + if event.Type() == spec.MRoomCanonicalAlias { alias, err := unmarshalCanonicalAlias(event) if err != nil { return "", err @@ -479,8 +481,8 @@ func (s *OutputRoomEventConsumer) roomName(ctx context.Context, event *gomatrixs } var ( - canonicalAliasTuple = gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomCanonicalAlias} - roomNameTuple = gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomName} + canonicalAliasTuple = gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomCanonicalAlias} + roomNameTuple = gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomName} ) func unmarshalRoomName(event *gomatrixserverlib.HeaderedEvent) (string, error) { @@ -531,14 +533,14 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatr // UNSPEC: the spec doesn't say this is a ClientEvent, but the // fields seem to match. room_id should be missing, which // matches the behaviour of FormatSync. - Event: gomatrixserverlib.HeaderedToClientEvent(event, gomatrixserverlib.FormatSync), + Event: synctypes.HeaderedToClientEvent(event, synctypes.FormatSync), // TODO: this is per-device, but it's not part of the primary // key. So inserting one notification per profile tag doesn't // make sense. What is this supposed to be? Sytests require it // to "work", but they only use a single device. ProfileTag: profileTag, RoomID: event.RoomID(), - TS: gomatrixserverlib.AsTimestamp(time.Now()), + TS: spec.AsTimestamp(time.Now()), } if err = s.db.InsertNotification(ctx, mem.Localpart, mem.Domain, event.EventID(), streamPos, tweaks, n); err != nil { return fmt.Errorf("s.db.InsertNotification: %w", err) @@ -683,7 +685,7 @@ func (rse *ruleSetEvalContext) HasPowerLevel(userID, levelKey string) (bool, err req := &rsapi.QueryLatestEventsAndStateRequest{ RoomID: rse.roomID, StateToFetch: []gomatrixserverlib.StateKeyTuple{ - {EventType: gomatrixserverlib.MRoomPowerLevels}, + {EventType: spec.MRoomPowerLevels}, }, } var res rsapi.QueryLatestEventsAndStateResponse @@ -691,7 +693,7 @@ func (rse *ruleSetEvalContext) HasPowerLevel(userID, levelKey string) (bool, err return false, err } for _, ev := range res.StateEvents { - if ev.Type() != gomatrixserverlib.MRoomPowerLevels { + if ev.Type() != spec.MRoomPowerLevels { continue } @@ -706,7 +708,7 @@ func (rse *ruleSetEvalContext) HasPowerLevel(userID, levelKey string) (bool, err // localPushDevices pushes to the configured devices of a local // user. The map keys are [url][format]. -func (s *OutputRoomEventConsumer) localPushDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, tweaks map[string]interface{}) (map[string]map[string][]*pushgateway.Device, string, error) { +func (s *OutputRoomEventConsumer) localPushDevices(ctx context.Context, localpart string, serverName spec.ServerName, tweaks map[string]interface{}) (map[string]map[string][]*pushgateway.Device, string, error) { pusherDevices, err := util.GetPushDevices(ctx, localpart, serverName, tweaks, s.db) if err != nil { return nil, "", fmt.Errorf("util.GetPushDevices: %w", err) @@ -804,7 +806,7 @@ func (s *OutputRoomEventConsumer) notifyHTTP(ctx context.Context, event *gomatri } // deleteRejectedPushers deletes the pushers associated with the given devices. -func (s *OutputRoomEventConsumer) deleteRejectedPushers(ctx context.Context, devices []*pushgateway.Device, localpart string, serverName gomatrixserverlib.ServerName) { +func (s *OutputRoomEventConsumer) deleteRejectedPushers(ctx context.Context, devices []*pushgateway.Device, localpart string, serverName spec.ServerName) { log.WithFields(log.Fields{ "localpart": localpart, "app_id0": devices[0].AppID, diff --git a/userapi/consumers/roomserver_test.go b/userapi/consumers/roomserver_test.go index bc5ae652d..e233bc66e 100644 --- a/userapi/consumers/roomserver_test.go +++ b/userapi/consumers/roomserver_test.go @@ -7,22 +7,23 @@ import ( "testing" "time" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/stretchr/testify/assert" "github.com/matrix-org/dendrite/internal/pushrules" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" - "github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/dendrite/userapi/storage" userAPITypes "github.com/matrix-org/dendrite/userapi/types" ) func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.UserDatabase, func()) { - base, baseclose := testrig.CreateBaseDendrite(t, dbType) t.Helper() connStr, close := test.PrepareDBConnectionString(t, dbType) - db, err := storage.NewUserDatabase(base, &config.DatabaseOptions{ + cm := sqlutil.NewConnectionManager(nil, config.DatabaseOptions{}) + db, err := storage.NewUserDatabase(context.Background(), cm, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }, "", 4, 0, 0, "") if err != nil { @@ -30,7 +31,6 @@ func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.UserDatabase, } return db, func() { close() - baseclose() } } @@ -140,9 +140,9 @@ func TestMessageStats(t *testing.T) { tests := []struct { name string args args - ourServer gomatrixserverlib.ServerName + ourServer spec.ServerName lastUpdate time.Time - initRoomCounts map[gomatrixserverlib.ServerName]map[string]bool + initRoomCounts map[spec.ServerName]map[string]bool wantStats userAPITypes.MessageStats }{ { @@ -198,7 +198,7 @@ func TestMessageStats(t *testing.T) { name: "day change creates a new room map", ourServer: "localhost", lastUpdate: time.Now().Add(-time.Hour * 24), - initRoomCounts: map[gomatrixserverlib.ServerName]map[string]bool{ + initRoomCounts: map[spec.ServerName]map[string]bool{ "localhost": {"encryptedRoom": true}, }, args: args{ @@ -220,11 +220,11 @@ func TestMessageStats(t *testing.T) { tt.lastUpdate = time.Now() } if tt.initRoomCounts == nil { - tt.initRoomCounts = map[gomatrixserverlib.ServerName]map[string]bool{} + tt.initRoomCounts = map[spec.ServerName]map[string]bool{} } s := &OutputRoomEventConsumer{ db: db, - msgCounts: map[gomatrixserverlib.ServerName]userAPITypes.MessageStats{}, + msgCounts: map[spec.ServerName]userAPITypes.MessageStats{}, roomCounts: tt.initRoomCounts, countsLock: sync.Mutex{}, lastUpdate: tt.lastUpdate, diff --git a/userapi/consumers/signingkeyupdate.go b/userapi/consumers/signingkeyupdate.go index f4ff017db..457a61838 100644 --- a/userapi/consumers/signingkeyupdate.go +++ b/userapi/consumers/signingkeyupdate.go @@ -19,6 +19,8 @@ import ( "encoding/json" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/nats-io/nats.go" "github.com/sirupsen/logrus" @@ -36,7 +38,7 @@ type SigningKeyUpdateConsumer struct { topic string userAPI api.UploadDeviceKeysAPI cfg *config.UserAPI - isLocalServerName func(gomatrixserverlib.ServerName) bool + isLocalServerName func(spec.ServerName) bool } // NewSigningKeyUpdateConsumer creates a new SigningKeyUpdateConsumer. Call Start() to begin consuming from key servers. @@ -74,7 +76,7 @@ func (t *SigningKeyUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.M logrus.WithError(err).Errorf("Failed to read from signing key update input topic") return true } - origin := gomatrixserverlib.ServerName(msg.Header.Get("origin")) + origin := spec.ServerName(msg.Header.Get("origin")) if _, serverName, err := gomatrixserverlib.SplitID('@', updatePayload.UserID); err != nil { logrus.WithError(err).Error("failed to split user id") return true @@ -86,7 +88,7 @@ func (t *SigningKeyUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.M return true } - keys := gomatrixserverlib.CrossSigningKeys{} + keys := fclient.CrossSigningKeys{} if updatePayload.MasterKey != nil { keys.MasterKey = *updatePayload.MasterKey } diff --git a/userapi/internal/cross_signing.go b/userapi/internal/cross_signing.go index 8b9704d1b..ea7b84f6b 100644 --- a/userapi/internal/cross_signing.go +++ b/userapi/internal/cross_signing.go @@ -25,11 +25,13 @@ import ( "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/types" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/sirupsen/logrus" "golang.org/x/crypto/curve25519" ) -func sanityCheckKey(key gomatrixserverlib.CrossSigningKey, userID string, purpose gomatrixserverlib.CrossSigningKeyPurpose) error { +func sanityCheckKey(key fclient.CrossSigningKey, userID string, purpose fclient.CrossSigningKeyPurpose) error { // Is there exactly one key? if len(key.Keys) != 1 { return fmt.Errorf("should contain exactly one key") @@ -105,12 +107,12 @@ func sanityCheckKey(key gomatrixserverlib.CrossSigningKey, userID string, purpos // nolint:gocyclo func (a *UserInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) error { // Find the keys to store. - byPurpose := map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey{} + byPurpose := map[fclient.CrossSigningKeyPurpose]fclient.CrossSigningKey{} toStore := types.CrossSigningKeyMap{} hasMasterKey := false if len(req.MasterKey.Keys) > 0 { - if err := sanityCheckKey(req.MasterKey, req.UserID, gomatrixserverlib.CrossSigningKeyPurposeMaster); err != nil { + if err := sanityCheckKey(req.MasterKey, req.UserID, fclient.CrossSigningKeyPurposeMaster); err != nil { res.Error = &api.KeyError{ Err: "Master key sanity check failed: " + err.Error(), IsInvalidParam: true, @@ -118,15 +120,15 @@ func (a *UserInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api. return nil } - byPurpose[gomatrixserverlib.CrossSigningKeyPurposeMaster] = req.MasterKey + byPurpose[fclient.CrossSigningKeyPurposeMaster] = req.MasterKey for _, key := range req.MasterKey.Keys { // iterates once, see sanityCheckKey - toStore[gomatrixserverlib.CrossSigningKeyPurposeMaster] = key + toStore[fclient.CrossSigningKeyPurposeMaster] = key } hasMasterKey = true } if len(req.SelfSigningKey.Keys) > 0 { - if err := sanityCheckKey(req.SelfSigningKey, req.UserID, gomatrixserverlib.CrossSigningKeyPurposeSelfSigning); err != nil { + if err := sanityCheckKey(req.SelfSigningKey, req.UserID, fclient.CrossSigningKeyPurposeSelfSigning); err != nil { res.Error = &api.KeyError{ Err: "Self-signing key sanity check failed: " + err.Error(), IsInvalidParam: true, @@ -134,14 +136,14 @@ func (a *UserInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api. return nil } - byPurpose[gomatrixserverlib.CrossSigningKeyPurposeSelfSigning] = req.SelfSigningKey + byPurpose[fclient.CrossSigningKeyPurposeSelfSigning] = req.SelfSigningKey for _, key := range req.SelfSigningKey.Keys { // iterates once, see sanityCheckKey - toStore[gomatrixserverlib.CrossSigningKeyPurposeSelfSigning] = key + toStore[fclient.CrossSigningKeyPurposeSelfSigning] = key } } if len(req.UserSigningKey.Keys) > 0 { - if err := sanityCheckKey(req.UserSigningKey, req.UserID, gomatrixserverlib.CrossSigningKeyPurposeUserSigning); err != nil { + if err := sanityCheckKey(req.UserSigningKey, req.UserID, fclient.CrossSigningKeyPurposeUserSigning); err != nil { res.Error = &api.KeyError{ Err: "User-signing key sanity check failed: " + err.Error(), IsInvalidParam: true, @@ -149,9 +151,9 @@ func (a *UserInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api. return nil } - byPurpose[gomatrixserverlib.CrossSigningKeyPurposeUserSigning] = req.UserSigningKey + byPurpose[fclient.CrossSigningKeyPurposeUserSigning] = req.UserSigningKey for _, key := range req.UserSigningKey.Keys { // iterates once, see sanityCheckKey - toStore[gomatrixserverlib.CrossSigningKeyPurposeUserSigning] = key + toStore[fclient.CrossSigningKeyPurposeUserSigning] = key } } @@ -180,7 +182,7 @@ func (a *UserInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api. // If we still can't find a master key for the user then stop the upload. // This satisfies the "Fails to upload self-signing key without master key" test. if !hasMasterKey { - if _, hasMasterKey = existingKeys[gomatrixserverlib.CrossSigningKeyPurposeMaster]; !hasMasterKey { + if _, hasMasterKey = existingKeys[fclient.CrossSigningKeyPurposeMaster]; !hasMasterKey { res.Error = &api.KeyError{ Err: "No master key was found", IsMissingParam: true, @@ -191,10 +193,10 @@ func (a *UserInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api. // Check if anything actually changed compared to what we have in the database. changed := false - for _, purpose := range []gomatrixserverlib.CrossSigningKeyPurpose{ - gomatrixserverlib.CrossSigningKeyPurposeMaster, - gomatrixserverlib.CrossSigningKeyPurposeSelfSigning, - gomatrixserverlib.CrossSigningKeyPurposeUserSigning, + for _, purpose := range []fclient.CrossSigningKeyPurpose{ + fclient.CrossSigningKeyPurposeMaster, + fclient.CrossSigningKeyPurposeSelfSigning, + fclient.CrossSigningKeyPurposeUserSigning, } { old, gotOld := existingKeys[purpose] new, gotNew := toStore[purpose] @@ -248,10 +250,10 @@ func (a *UserInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api. update := api.CrossSigningKeyUpdate{ UserID: req.UserID, } - if mk, ok := byPurpose[gomatrixserverlib.CrossSigningKeyPurposeMaster]; ok { + if mk, ok := byPurpose[fclient.CrossSigningKeyPurposeMaster]; ok { update.MasterKey = &mk } - if ssk, ok := byPurpose[gomatrixserverlib.CrossSigningKeyPurposeSelfSigning]; ok { + if ssk, ok := byPurpose[fclient.CrossSigningKeyPurposeSelfSigning]; ok { update.SelfSigningKey = &ssk } if update.MasterKey == nil && update.SelfSigningKey == nil { @@ -279,36 +281,36 @@ func (a *UserInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req } _ = a.QueryKeys(ctx, queryReq, queryRes) - selfSignatures := map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{} - otherSignatures := map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{} + selfSignatures := map[string]map[gomatrixserverlib.KeyID]fclient.CrossSigningForKeyOrDevice{} + otherSignatures := map[string]map[gomatrixserverlib.KeyID]fclient.CrossSigningForKeyOrDevice{} // Sort signatures into two groups: one where people have signed their own // keys and one where people have signed someone elses for userID, forUserID := range req.Signatures { for keyID, keyOrDevice := range forUserID { switch key := keyOrDevice.CrossSigningBody.(type) { - case *gomatrixserverlib.CrossSigningKey: + case *fclient.CrossSigningKey: if key.UserID == req.UserID { if _, ok := selfSignatures[userID]; !ok { - selfSignatures[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{} + selfSignatures[userID] = map[gomatrixserverlib.KeyID]fclient.CrossSigningForKeyOrDevice{} } selfSignatures[userID][keyID] = keyOrDevice } else { if _, ok := otherSignatures[userID]; !ok { - otherSignatures[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{} + otherSignatures[userID] = map[gomatrixserverlib.KeyID]fclient.CrossSigningForKeyOrDevice{} } otherSignatures[userID][keyID] = keyOrDevice } - case *gomatrixserverlib.DeviceKeys: + case *fclient.DeviceKeys: if key.UserID == req.UserID { if _, ok := selfSignatures[userID]; !ok { - selfSignatures[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{} + selfSignatures[userID] = map[gomatrixserverlib.KeyID]fclient.CrossSigningForKeyOrDevice{} } selfSignatures[userID][keyID] = keyOrDevice } else { if _, ok := otherSignatures[userID]; !ok { - otherSignatures[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{} + otherSignatures[userID] = map[gomatrixserverlib.KeyID]fclient.CrossSigningForKeyOrDevice{} } otherSignatures[userID][keyID] = keyOrDevice } @@ -354,7 +356,7 @@ func (a *UserInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req func (a *UserInternalAPI) processSelfSignatures( ctx context.Context, - signatures map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice, + signatures map[string]map[gomatrixserverlib.KeyID]fclient.CrossSigningForKeyOrDevice, ) error { // Here we will process: // * The user signing their own devices using their self-signing key @@ -363,7 +365,7 @@ func (a *UserInternalAPI) processSelfSignatures( for targetUserID, forTargetUserID := range signatures { for targetKeyID, signature := range forTargetUserID { switch sig := signature.CrossSigningBody.(type) { - case *gomatrixserverlib.CrossSigningKey: + case *fclient.CrossSigningKey: for keyID := range sig.Keys { split := strings.SplitN(string(keyID), ":", 2) if len(split) > 1 && gomatrixserverlib.KeyID(split[1]) == targetKeyID { @@ -381,7 +383,7 @@ func (a *UserInternalAPI) processSelfSignatures( } } - case *gomatrixserverlib.DeviceKeys: + case *fclient.DeviceKeys: for originUserID, forOriginUserID := range sig.Signatures { for originKeyID, originSig := range forOriginUserID { if err := a.KeyDatabase.StoreCrossSigningSigsForTarget( @@ -403,7 +405,7 @@ func (a *UserInternalAPI) processSelfSignatures( func (a *UserInternalAPI) processOtherSignatures( ctx context.Context, userID string, queryRes *api.QueryKeysResponse, - signatures map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice, + signatures map[string]map[gomatrixserverlib.KeyID]fclient.CrossSigningForKeyOrDevice, ) error { // Here we will process: // * A user signing someone else's master keys using their user-signing keys @@ -411,7 +413,7 @@ func (a *UserInternalAPI) processOtherSignatures( for targetUserID, forTargetUserID := range signatures { for _, signature := range forTargetUserID { switch sig := signature.CrossSigningBody.(type) { - case *gomatrixserverlib.CrossSigningKey: + case *fclient.CrossSigningKey: // Find the local copy of the master key. We'll use this to be // sure that the supplied stanza matches the key that we think it // should be. @@ -484,12 +486,12 @@ func (a *UserInternalAPI) crossSigningKeysFromDatabase( continue } - appendSignature := func(originUserID string, originKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) { + appendSignature := func(originUserID string, originKeyID gomatrixserverlib.KeyID, signature spec.Base64Bytes) { if key.Signatures == nil { key.Signatures = types.CrossSigningSigMap{} } if _, ok := key.Signatures[originUserID]; !ok { - key.Signatures[originUserID] = make(map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes) + key.Signatures[originUserID] = make(map[gomatrixserverlib.KeyID]spec.Base64Bytes) } key.Signatures[originUserID][originKeyID] = signature } @@ -509,13 +511,13 @@ func (a *UserInternalAPI) crossSigningKeysFromDatabase( } switch keyType { - case gomatrixserverlib.CrossSigningKeyPurposeMaster: + case fclient.CrossSigningKeyPurposeMaster: res.MasterKeys[targetUserID] = key - case gomatrixserverlib.CrossSigningKeyPurposeSelfSigning: + case fclient.CrossSigningKeyPurposeSelfSigning: res.SelfSigningKeys[targetUserID] = key - case gomatrixserverlib.CrossSigningKeyPurposeUserSigning: + case fclient.CrossSigningKeyPurposeUserSigning: res.UserSigningKeys[targetUserID] = key } } @@ -534,21 +536,21 @@ func (a *UserInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySig for targetPurpose, targetKey := range keyMap { switch targetPurpose { - case gomatrixserverlib.CrossSigningKeyPurposeMaster: + case fclient.CrossSigningKeyPurposeMaster: if res.MasterKeys == nil { - res.MasterKeys = map[string]gomatrixserverlib.CrossSigningKey{} + res.MasterKeys = map[string]fclient.CrossSigningKey{} } res.MasterKeys[targetUserID] = targetKey - case gomatrixserverlib.CrossSigningKeyPurposeSelfSigning: + case fclient.CrossSigningKeyPurposeSelfSigning: if res.SelfSigningKeys == nil { - res.SelfSigningKeys = map[string]gomatrixserverlib.CrossSigningKey{} + res.SelfSigningKeys = map[string]fclient.CrossSigningKey{} } res.SelfSigningKeys[targetUserID] = targetKey - case gomatrixserverlib.CrossSigningKeyPurposeUserSigning: + case fclient.CrossSigningKeyPurposeUserSigning: if res.UserSigningKeys == nil { - res.UserSigningKeys = map[string]gomatrixserverlib.CrossSigningKey{} + res.UserSigningKeys = map[string]fclient.CrossSigningKey{} } res.UserSigningKeys[targetUserID] = targetKey } @@ -576,7 +578,7 @@ func (a *UserInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySig res.Signatures[targetUserID][targetKeyID] = types.CrossSigningSigMap{} } if _, ok := res.Signatures[targetUserID][targetKeyID][sourceUserID]; !ok { - res.Signatures[targetUserID][targetKeyID][sourceUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} + res.Signatures[targetUserID][targetKeyID][sourceUserID] = map[gomatrixserverlib.KeyID]spec.Base64Bytes{} } res.Signatures[targetUserID][targetKeyID][sourceUserID][sourceKeyID] = sourceSig } diff --git a/userapi/internal/device_list_update.go b/userapi/internal/device_list_update.go index 3b4dcf98e..d60e522e8 100644 --- a/userapi/internal/device_list_update.go +++ b/userapi/internal/device_list_update.go @@ -25,6 +25,8 @@ import ( "time" rsapi "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrixserverlib" @@ -97,8 +99,8 @@ type DeviceListUpdater struct { api DeviceListUpdaterAPI producer KeyChangeProducer fedClient fedsenderapi.KeyserverFederationAPI - workerChans []chan gomatrixserverlib.ServerName - thisServer gomatrixserverlib.ServerName + workerChans []chan spec.ServerName + thisServer spec.ServerName // When device lists are stale for a user, they get inserted into this map with a channel which `Update` will // block on or timeout via a select. @@ -112,7 +114,7 @@ type DeviceListUpdater struct { type DeviceListUpdaterDatabase interface { // StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. // If no domains are given, all user IDs with stale device lists are returned. - StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) + StaleDeviceLists(ctx context.Context, domains []spec.ServerName) ([]string, error) // MarkDeviceListStale sets the stale bit for this user to isStale. MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error @@ -145,7 +147,7 @@ func NewDeviceListUpdater( process *process.ProcessContext, db DeviceListUpdaterDatabase, api DeviceListUpdaterAPI, producer KeyChangeProducer, fedClient fedsenderapi.KeyserverFederationAPI, numWorkers int, - rsAPI rsapi.KeyserverRoomserverAPI, thisServer gomatrixserverlib.ServerName, + rsAPI rsapi.KeyserverRoomserverAPI, thisServer spec.ServerName, ) *DeviceListUpdater { return &DeviceListUpdater{ process: process, @@ -156,7 +158,7 @@ func NewDeviceListUpdater( producer: producer, fedClient: fedClient, thisServer: thisServer, - workerChans: make([]chan gomatrixserverlib.ServerName, numWorkers), + workerChans: make([]chan spec.ServerName, numWorkers), userIDToChan: make(map[string]chan bool), userIDToChanMu: &sync.Mutex{}, rsAPI: rsAPI, @@ -169,12 +171,12 @@ func (u *DeviceListUpdater) Start() error { // Allocate a small buffer per channel. // If the buffer limit is reached, backpressure will cause the processing of EDUs // to stop (in this transaction) until key requests can be made. - ch := make(chan gomatrixserverlib.ServerName, 10) + ch := make(chan spec.ServerName, 10) u.workerChans[i] = ch go u.worker(ch) } - staleLists, err := u.db.StaleDeviceLists(u.process.Context(), []gomatrixserverlib.ServerName{}) + staleLists, err := u.db.StaleDeviceLists(u.process.Context(), []spec.ServerName{}) if err != nil { return err } @@ -194,7 +196,7 @@ func (u *DeviceListUpdater) Start() error { // CleanUp removes stale device entries for users we don't share a room with anymore func (u *DeviceListUpdater) CleanUp() error { - staleUsers, err := u.db.StaleDeviceLists(u.process.Context(), []gomatrixserverlib.ServerName{}) + staleUsers, err := u.db.StaleDeviceLists(u.process.Context(), []spec.ServerName{}) if err != nil { return err } @@ -222,7 +224,7 @@ func (u *DeviceListUpdater) mutex(userID string) *sync.Mutex { // ManualUpdate invalidates the device list for the given user and fetches the latest and tracks it. // Blocks until the device list is synced or the timeout is reached. -func (u *DeviceListUpdater) ManualUpdate(ctx context.Context, serverName gomatrixserverlib.ServerName, userID string) error { +func (u *DeviceListUpdater) ManualUpdate(ctx context.Context, serverName spec.ServerName, userID string) error { mu := u.mutex(userID) mu.Lock() err := u.db.MarkDeviceListStale(ctx, userID, true) @@ -368,12 +370,12 @@ func (u *DeviceListUpdater) clearChannel(userID string) { } } -func (u *DeviceListUpdater) worker(ch chan gomatrixserverlib.ServerName) { - retries := make(map[gomatrixserverlib.ServerName]time.Time) +func (u *DeviceListUpdater) worker(ch chan spec.ServerName) { + retries := make(map[spec.ServerName]time.Time) retriesMu := &sync.Mutex{} // restarter goroutine which will inject failed servers into ch when it is time go func() { - var serversToRetry []gomatrixserverlib.ServerName + var serversToRetry []spec.ServerName for { serversToRetry = serversToRetry[:0] // reuse memory time.Sleep(time.Second) @@ -412,7 +414,7 @@ func (u *DeviceListUpdater) worker(ch chan gomatrixserverlib.ServerName) { } } -func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerName) (time.Duration, bool) { +func (u *DeviceListUpdater) processServer(serverName spec.ServerName) (time.Duration, bool) { ctx := u.process.Context() logger := util.GetLogger(ctx).WithField("server_name", serverName) deviceListUpdateCount.WithLabelValues(string(serverName)).Inc() @@ -420,7 +422,7 @@ func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerNam waitTime := defaultWaitTime // How long should we wait to try again? successCount := 0 // How many user requests failed? - userIDs, err := u.db.StaleDeviceLists(ctx, []gomatrixserverlib.ServerName{serverName}) + userIDs, err := u.db.StaleDeviceLists(ctx, []spec.ServerName{serverName}) if err != nil { logger.WithError(err).Error("Failed to load stale device lists") return waitTime, true @@ -456,7 +458,7 @@ func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerNam return waitTime, !allUsersSucceeded } -func (u *DeviceListUpdater) processServerUser(ctx context.Context, serverName gomatrixserverlib.ServerName, userID string) (time.Duration, error) { +func (u *DeviceListUpdater) processServerUser(ctx context.Context, serverName spec.ServerName, userID string) (time.Duration, error) { ctx, cancel := context.WithTimeout(ctx, requestTimeout) defer cancel() logger := util.GetLogger(ctx).WithFields(logrus.Fields{ @@ -508,12 +510,12 @@ func (u *DeviceListUpdater) processServerUser(ctx context.Context, serverName go } uploadRes := &api.PerformUploadDeviceKeysResponse{} if res.MasterKey != nil { - if err = sanityCheckKey(*res.MasterKey, userID, gomatrixserverlib.CrossSigningKeyPurposeMaster); err == nil { + if err = sanityCheckKey(*res.MasterKey, userID, fclient.CrossSigningKeyPurposeMaster); err == nil { uploadReq.MasterKey = *res.MasterKey } } if res.SelfSigningKey != nil { - if err = sanityCheckKey(*res.SelfSigningKey, userID, gomatrixserverlib.CrossSigningKeyPurposeSelfSigning); err == nil { + if err = sanityCheckKey(*res.SelfSigningKey, userID, fclient.CrossSigningKeyPurposeSelfSigning); err == nil { uploadReq.SelfSigningKey = *res.SelfSigningKey } } @@ -527,7 +529,7 @@ func (u *DeviceListUpdater) processServerUser(ctx context.Context, serverName go return defaultWaitTime, nil } -func (u *DeviceListUpdater) updateDeviceList(res *gomatrixserverlib.RespUserDevices) error { +func (u *DeviceListUpdater) updateDeviceList(res *fclient.RespUserDevices) error { ctx := context.Background() // we've got the keys, don't time out when persisting them to the database. keys := make([]api.DeviceMessage, len(res.Devices)) existingKeys := make([]api.DeviceMessage, len(res.Devices)) diff --git a/userapi/internal/device_list_update_test.go b/userapi/internal/device_list_update_test.go index 868fc9be8..3a269801c 100644 --- a/userapi/internal/device_list_update_test.go +++ b/userapi/internal/device_list_update_test.go @@ -27,13 +27,15 @@ import ( "testing" "time" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" roomserver "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/test" - "github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage" ) @@ -64,7 +66,7 @@ func (d *mockDeviceListUpdaterDatabase) DeleteStaleDeviceLists(ctx context.Conte // StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. // If no domains are given, all user IDs with stale device lists are returned. -func (d *mockDeviceListUpdaterDatabase) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { +func (d *mockDeviceListUpdaterDatabase) StaleDeviceLists(ctx context.Context, domains []spec.ServerName) ([]string, error) { d.mu.Lock() defer d.mu.Unlock() var result []string @@ -135,19 +137,19 @@ func (t *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) { return t.fn(req) } -func newFedClient(tripper func(*http.Request) (*http.Response, error)) *gomatrixserverlib.FederationClient { +func newFedClient(tripper func(*http.Request) (*http.Response, error)) *fclient.FederationClient { _, pkey, _ := ed25519.GenerateKey(nil) - fedClient := gomatrixserverlib.NewFederationClient( - []*gomatrixserverlib.SigningIdentity{ + fedClient := fclient.NewFederationClient( + []*fclient.SigningIdentity{ { - ServerName: gomatrixserverlib.ServerName("example.test"), + ServerName: spec.ServerName("example.test"), KeyID: gomatrixserverlib.KeyID("ed25519:test"), PrivateKey: pkey, }, }, ) - fedClient.Client = *gomatrixserverlib.NewClient( - gomatrixserverlib.WithTransport(&roundTripper{tripper}), + fedClient.Client = *fclient.NewClient( + fclient.WithTransport(&roundTripper{tripper}), ) return fedClient } @@ -293,7 +295,7 @@ func TestDebounce(t *testing.T) { ap := &mockDeviceListUpdaterAPI{} producer := &mockKeyChangeProducer{} fedCh := make(chan *http.Response, 1) - srv := gomatrixserverlib.ServerName("example.com") + srv := spec.ServerName("example.com") userID := "@alice:example.com" keyJSON := `{"user_id":"` + userID + `","device_id":"JLAFKJWSCS","algorithms":["m.olm.v1.curve25519-aes-sha2","m.megolm.v1.aes-sha2"],"keys":{"curve25519:JLAFKJWSCS":"3C5BFWi2Y8MaVvjM8M22DBmh24PmgR0nPvJOIArzgyI","ed25519:JLAFKJWSCS":"lEuiRJBit0IG6nUf5pUzWTUEsRVVe/HJkoKuEww9ULI"},"signatures":{"` + userID + `":{"ed25519:JLAFKJWSCS":"dSO80A01XiigH3uBiDVx/EjzaoycHcjq9lfQX0uWsqxl2giMIiSPR8a4d291W1ihKJL/a+myXS367WT6NAIcBA"}}}` incomingFedReq := make(chan struct{}) @@ -363,9 +365,9 @@ func TestDebounce(t *testing.T) { func mustCreateKeyserverDB(t *testing.T, dbType test.DBType) (storage.KeyDatabase, func()) { t.Helper() - base, _, _ := testrig.Base(nil) connStr, clearDB := test.PrepareDBConnectionString(t, dbType) - db, err := storage.NewKeyDatabase(base, &config.DatabaseOptions{ConnectionString: config.DataSource(connStr)}) + cm := sqlutil.NewConnectionManager(nil, config.DatabaseOptions{}) + db, err := storage.NewKeyDatabase(cm, &config.DatabaseOptions{ConnectionString: config.DataSource(connStr)}) if err != nil { t.Fatal(err) } @@ -413,7 +415,7 @@ func TestDeviceListUpdater_CleanUp(t *testing.T) { } // check that we still have Alice in our stale list - staleUsers, err := db.StaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"}) + staleUsers, err := db.StaleDeviceLists(ctx, []spec.ServerName{"test"}) if err != nil { t.Error(err) } diff --git a/userapi/internal/key_api.go b/userapi/internal/key_api.go index be816fe5d..0b188b091 100644 --- a/userapi/internal/key_api.go +++ b/userapi/internal/key_api.go @@ -24,6 +24,8 @@ import ( "time" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/sirupsen/logrus" "github.com/tidwall/gjson" @@ -79,7 +81,7 @@ func (a *UserInternalAPI) PerformClaimKeys(ctx context.Context, req *api.Perform domainToDeviceKeys[string(serverName)] = nested } for domain, local := range domainToDeviceKeys { - if !a.Config.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) { + if !a.Config.Matrix.IsLocalServerName(spec.ServerName(domain)) { continue } // claim local keys @@ -128,7 +130,7 @@ func (a *UserInternalAPI) claimRemoteKeys( defer cancel() defer wg.Done() - claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, a.Config.Matrix.ServerName, gomatrixserverlib.ServerName(domain), keysToClaim) + claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, a.Config.Matrix.ServerName, spec.ServerName(domain), keysToClaim) mu.Lock() defer mu.Unlock() @@ -229,9 +231,9 @@ func (a *UserInternalAPI) PerformMarkAsStaleIfNeeded(ctx context.Context, req *a func (a *UserInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) error { var respMu sync.Mutex res.DeviceKeys = make(map[string]map[string]json.RawMessage) - res.MasterKeys = make(map[string]gomatrixserverlib.CrossSigningKey) - res.SelfSigningKeys = make(map[string]gomatrixserverlib.CrossSigningKey) - res.UserSigningKeys = make(map[string]gomatrixserverlib.CrossSigningKey) + res.MasterKeys = make(map[string]fclient.CrossSigningKey) + res.SelfSigningKeys = make(map[string]fclient.CrossSigningKey) + res.UserSigningKeys = make(map[string]fclient.CrossSigningKey) res.Failures = make(map[string]interface{}) // make a map from domain to device keys @@ -320,7 +322,7 @@ func (a *UserInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReque for targetUserID, masterKey := range res.MasterKeys { if masterKey.Signatures == nil { - masterKey.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} + masterKey.Signatures = map[string]map[gomatrixserverlib.KeyID]spec.Base64Bytes{} } for targetKeyID := range masterKey.Keys { sigMap, err := a.KeyDatabase.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, targetKeyID) @@ -339,7 +341,7 @@ func (a *UserInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReque for sourceUserID, forSourceUser := range sigMap { for sourceKeyID, sourceSig := range forSourceUser { if _, ok := masterKey.Signatures[sourceUserID]; !ok { - masterKey.Signatures[sourceUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} + masterKey.Signatures[sourceUserID] = map[gomatrixserverlib.KeyID]spec.Base64Bytes{} } masterKey.Signatures[sourceUserID][sourceKeyID] = sourceSig } @@ -362,17 +364,17 @@ func (a *UserInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReque if len(sigMap) == 0 { continue } - var deviceKey gomatrixserverlib.DeviceKeys + var deviceKey fclient.DeviceKeys if err = json.Unmarshal(key, &deviceKey); err != nil { continue } if deviceKey.Signatures == nil { - deviceKey.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} + deviceKey.Signatures = map[string]map[gomatrixserverlib.KeyID]spec.Base64Bytes{} } for sourceUserID, forSourceUser := range sigMap { for sourceKeyID, sourceSig := range forSourceUser { if _, ok := deviceKey.Signatures[sourceUserID]; !ok { - deviceKey.Signatures[sourceUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} + deviceKey.Signatures[sourceUserID] = map[gomatrixserverlib.KeyID]spec.Base64Bytes{} } deviceKey.Signatures[sourceUserID][sourceKeyID] = sourceSig } @@ -415,7 +417,7 @@ func (a *UserInternalAPI) queryRemoteKeys( ctx context.Context, timeout time.Duration, res *api.QueryKeysResponse, domainToDeviceKeys map[string]map[string][]string, domainToCrossSigningKeys map[string]map[string]struct{}, ) { - resultCh := make(chan *gomatrixserverlib.RespQueryKeys, len(domainToDeviceKeys)) + resultCh := make(chan *fclient.RespQueryKeys, len(domainToDeviceKeys)) // allows us to wait until all federation servers have been poked var wg sync.WaitGroup // mutex for writing directly to res (e.g failures) @@ -423,13 +425,13 @@ func (a *UserInternalAPI) queryRemoteKeys( domains := map[string]struct{}{} for domain := range domainToDeviceKeys { - if a.Config.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) { + if a.Config.Matrix.IsLocalServerName(spec.ServerName(domain)) { continue } domains[domain] = struct{}{} } for domain := range domainToCrossSigningKeys { - if a.Config.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) { + if a.Config.Matrix.IsLocalServerName(spec.ServerName(domain)) { continue } domains[domain] = struct{}{} @@ -450,7 +452,7 @@ func (a *UserInternalAPI) queryRemoteKeys( close(resultCh) }() - processResult := func(result *gomatrixserverlib.RespQueryKeys) { + processResult := func(result *fclient.RespQueryKeys) { respMu.Lock() defer respMu.Unlock() for userID, nest := range result.DeviceKeys { @@ -483,7 +485,7 @@ func (a *UserInternalAPI) queryRemoteKeys( func (a *UserInternalAPI) queryRemoteKeysOnServer( ctx context.Context, serverName string, devKeys map[string][]string, crossSigningKeys map[string]struct{}, - wg *sync.WaitGroup, respMu *sync.Mutex, timeout time.Duration, resultCh chan<- *gomatrixserverlib.RespQueryKeys, + wg *sync.WaitGroup, respMu *sync.Mutex, timeout time.Duration, resultCh chan<- *fclient.RespQueryKeys, res *api.QueryKeysResponse, ) { defer wg.Done() @@ -513,7 +515,7 @@ func (a *UserInternalAPI) queryRemoteKeysOnServer( } } for userID := range userIDsForAllDevices { - err := a.Updater.ManualUpdate(context.Background(), gomatrixserverlib.ServerName(serverName), userID) + err := a.Updater.ManualUpdate(context.Background(), spec.ServerName(serverName), userID) if err != nil { logrus.WithFields(logrus.Fields{ logrus.ErrorKey: err, @@ -541,7 +543,7 @@ func (a *UserInternalAPI) queryRemoteKeysOnServer( if len(devKeys) == 0 { return } - queryKeysResp, err := a.FedClient.QueryKeys(fedCtx, a.Config.Matrix.ServerName, gomatrixserverlib.ServerName(serverName), devKeys) + queryKeysResp, err := a.FedClient.QueryKeys(fedCtx, a.Config.Matrix.ServerName, spec.ServerName(serverName), devKeys) if err == nil { resultCh <- &queryKeysResp return @@ -670,7 +672,7 @@ func (a *UserInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Pe } else { // assert that the user ID / device ID are not lying for each key for _, key := range req.DeviceKeys { - var serverName gomatrixserverlib.ServerName + var serverName spec.ServerName _, serverName, err = gomatrixserverlib.SplitID('@', key.UserID) if err != nil { continue // ignore invalid users diff --git a/userapi/internal/key_api_test.go b/userapi/internal/key_api_test.go index fc7e7e0df..de2a6d2c8 100644 --- a/userapi/internal/key_api_test.go +++ b/userapi/internal/key_api_test.go @@ -5,9 +5,9 @@ import ( "reflect" "testing" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" - "github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/internal" "github.com/matrix-org/dendrite/userapi/storage" @@ -16,15 +16,14 @@ import ( func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.KeyDatabase, func()) { t.Helper() connStr, close := test.PrepareDBConnectionString(t, dbType) - base, _, _ := testrig.Base(nil) - db, err := storage.NewKeyDatabase(base, &config.DatabaseOptions{ + cm := sqlutil.NewConnectionManager(nil, config.DatabaseOptions{}) + db, err := storage.NewKeyDatabase(cm, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }) if err != nil { t.Fatalf("failed to create new user db: %v", err) } return db, func() { - base.Close() close() } } diff --git a/userapi/internal/user_api.go b/userapi/internal/user_api.go index 8977697b5..1b6a4ebfa 100644 --- a/userapi/internal/user_api.go +++ b/userapi/internal/user_api.go @@ -23,8 +23,12 @@ import ( "strconv" "time" + appserviceAPI "github.com/matrix-org/dendrite/appservice/api" + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" fedsenderapi "github.com/matrix-org/dendrite/federationapi/api" + "github.com/matrix-org/dendrite/internal/pushrules" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/sirupsen/logrus" "golang.org/x/crypto/bcrypt" @@ -110,7 +114,7 @@ func (a *UserInternalAPI) setFullyRead(ctx context.Context, req *api.InputAccoun return nil } - deleted, err := a.DB.DeleteNotificationsUpTo(ctx, localpart, domain, req.RoomID, uint64(gomatrixserverlib.AsTimestamp(time.Now()))) + deleted, err := a.DB.DeleteNotificationsUpTo(ctx, localpart, domain, req.RoomID, uint64(spec.AsTimestamp(time.Now()))) if err != nil { logrus.WithError(err).Errorf("UserInternalAPI.setFullyRead: DeleteNotificationsUpTo failed") return err @@ -386,11 +390,6 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf } res.DeviceExists = true - if dev.UserID != req.RequestingUserID { - res.Forbidden = true - return nil - } - err = a.DB.UpdateDevice(ctx, localpart, domain, req.DeviceID, req.DisplayName) if err != nil { util.GetLogger(ctx).WithError(err).Error("deviceDB.UpdateDevice failed") @@ -423,25 +422,26 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf return nil } -func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfileRequest, res *api.QueryProfileResponse) error { - local, domain, err := gomatrixserverlib.SplitID('@', req.UserID) +var ( + ErrIsRemoteServer = errors.New("cannot query profile of remote users") +) + +func (a *UserInternalAPI) QueryProfile(ctx context.Context, userID string) (*authtypes.Profile, error) { + local, domain, err := gomatrixserverlib.SplitID('@', userID) if err != nil { - return err + return nil, err } if !a.Config.Matrix.IsLocalServerName(domain) { - return fmt.Errorf("cannot query profile of remote users (server name %s)", domain) + return nil, ErrIsRemoteServer } prof, err := a.DB.GetProfileByLocalpart(ctx, local, domain) if err != nil { if err == sql.ErrNoRows { - return nil + return nil, appserviceAPI.ErrProfileNotExists } - return err + return nil, err } - res.UserExists = true - res.AvatarURL = prof.AvatarURL - res.DisplayName = prof.DisplayName - return nil + return prof, nil } func (a *UserInternalAPI) QuerySearchProfiles(ctx context.Context, req *api.QuerySearchProfilesRequest, res *api.QuerySearchProfilesResponse) error { @@ -874,43 +874,32 @@ func (a *UserInternalAPI) QueryPushers(ctx context.Context, req *api.QueryPusher func (a *UserInternalAPI) PerformPushRulesPut( ctx context.Context, - req *api.PerformPushRulesPutRequest, - _ *struct{}, + userID string, + ruleSets *pushrules.AccountRuleSets, ) error { - bs, err := json.Marshal(&req.RuleSets) + bs, err := json.Marshal(ruleSets) if err != nil { return err } userReq := api.InputAccountDataRequest{ - UserID: req.UserID, + UserID: userID, DataType: pushRulesAccountDataType, AccountData: json.RawMessage(bs), } var userRes api.InputAccountDataResponse // empty - if err := a.InputAccountData(ctx, &userReq, &userRes); err != nil { - return err - } - return nil + return a.InputAccountData(ctx, &userReq, &userRes) } -func (a *UserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPushRulesRequest, res *api.QueryPushRulesResponse) error { - localpart, domain, err := gomatrixserverlib.SplitID('@', req.UserID) +func (a *UserInternalAPI) QueryPushRules(ctx context.Context, userID string) (*pushrules.AccountRuleSets, error) { + localpart, domain, err := gomatrixserverlib.SplitID('@', userID) if err != nil { - return fmt.Errorf("failed to split user ID %q for push rules", req.UserID) + return nil, fmt.Errorf("failed to split user ID %q for push rules", userID) } - pushRules, err := a.DB.QueryPushRules(ctx, localpart, domain) - if err != nil { - return fmt.Errorf("failed to query push rules: %w", err) - } - res.RuleSets = pushRules - return nil + return a.DB.QueryPushRules(ctx, localpart, domain) } -func (a *UserInternalAPI) SetAvatarURL(ctx context.Context, req *api.PerformSetAvatarURLRequest, res *api.PerformSetAvatarURLResponse) error { - profile, changed, err := a.DB.SetAvatarURL(ctx, req.Localpart, req.ServerName, req.AvatarURL) - res.Profile = profile - res.Changed = changed - return err +func (a *UserInternalAPI) SetAvatarURL(ctx context.Context, localpart string, serverName spec.ServerName, avatarURL string) (*authtypes.Profile, bool, error) { + return a.DB.SetAvatarURL(ctx, localpart, serverName, avatarURL) } func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, req *api.QueryNumericLocalpartRequest, res *api.QueryNumericLocalpartResponse) error { @@ -944,11 +933,8 @@ func (a *UserInternalAPI) QueryAccountByPassword(ctx context.Context, req *api.Q } } -func (a *UserInternalAPI) SetDisplayName(ctx context.Context, req *api.PerformUpdateDisplayNameRequest, res *api.PerformUpdateDisplayNameResponse) error { - profile, changed, err := a.DB.SetDisplayName(ctx, req.Localpart, req.ServerName, req.DisplayName) - res.Profile = profile - res.Changed = changed - return err +func (a *UserInternalAPI) SetDisplayName(ctx context.Context, localpart string, serverName spec.ServerName, displayName string) (*authtypes.Profile, bool, error) { + return a.DB.SetDisplayName(ctx, localpart, serverName, displayName) } func (a *UserInternalAPI) QueryLocalpartForThreePID(ctx context.Context, req *api.QueryLocalpartForThreePIDRequest, res *api.QueryLocalpartForThreePIDResponse) error { diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index 278378861..4f5e99a8a 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -20,6 +20,8 @@ import ( "errors" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/internal/pushrules" @@ -29,40 +31,40 @@ import ( ) type Profile interface { - GetProfileByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*authtypes.Profile, error) + GetProfileByLocalpart(ctx context.Context, localpart string, serverName spec.ServerName) (*authtypes.Profile, error) SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error) - SetAvatarURL(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, avatarURL string) (*authtypes.Profile, bool, error) - SetDisplayName(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, displayName string) (*authtypes.Profile, bool, error) + SetAvatarURL(ctx context.Context, localpart string, serverName spec.ServerName, avatarURL string) (*authtypes.Profile, bool, error) + SetDisplayName(ctx context.Context, localpart string, serverName spec.ServerName, displayName string) (*authtypes.Profile, bool, error) } type Account interface { // CreateAccount makes a new account with the given login name and password, and creates an empty profile // for this account. If no password is supplied, the account will be a passwordless account. If the // account already exists, it will return nil, ErrUserExists. - CreateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, plaintextPassword string, appserviceID string, accountType api.AccountType) (*api.Account, error) - GetAccountByPassword(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, plaintextPassword string) (*api.Account, error) - GetNewNumericLocalpart(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error) - CheckAccountAvailability(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (bool, error) - GetAccountByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*api.Account, error) - DeactivateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (err error) - SetPassword(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, plaintextPassword string) error + CreateAccount(ctx context.Context, localpart string, serverName spec.ServerName, plaintextPassword string, appserviceID string, accountType api.AccountType) (*api.Account, error) + GetAccountByPassword(ctx context.Context, localpart string, serverName spec.ServerName, plaintextPassword string) (*api.Account, error) + GetNewNumericLocalpart(ctx context.Context, serverName spec.ServerName) (int64, error) + CheckAccountAvailability(ctx context.Context, localpart string, serverName spec.ServerName) (bool, error) + GetAccountByLocalpart(ctx context.Context, localpart string, serverName spec.ServerName) (*api.Account, error) + DeactivateAccount(ctx context.Context, localpart string, serverName spec.ServerName) (err error) + SetPassword(ctx context.Context, localpart string, serverName spec.ServerName, plaintextPassword string) error } type AccountData interface { - SaveAccountData(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID, dataType string, content json.RawMessage) error - GetAccountData(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error) + SaveAccountData(ctx context.Context, localpart string, serverName spec.ServerName, roomID, dataType string, content json.RawMessage) error + GetAccountData(ctx context.Context, localpart string, serverName spec.ServerName) (global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error) // GetAccountDataByType returns account data matching a given // localpart, room ID and type. // If no account data could be found, returns nil // Returns an error if there was an issue with the retrieval - GetAccountDataByType(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID, dataType string) (data json.RawMessage, err error) - QueryPushRules(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*pushrules.AccountRuleSets, error) + GetAccountDataByType(ctx context.Context, localpart string, serverName spec.ServerName, roomID, dataType string) (data json.RawMessage, err error) + QueryPushRules(ctx context.Context, localpart string, serverName spec.ServerName) (*pushrules.AccountRuleSets, error) } type Device interface { GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error) - GetDeviceByID(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID string) (*api.Device, error) - GetDevicesByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) ([]api.Device, error) + GetDeviceByID(ctx context.Context, localpart string, serverName spec.ServerName, deviceID string) (*api.Device, error) + GetDevicesByLocalpart(ctx context.Context, localpart string, serverName spec.ServerName) ([]api.Device, error) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) // CreateDevice makes a new device associated with the given user ID localpart. // If there is already a device with the same device ID for this user, that access token will be revoked @@ -70,12 +72,12 @@ type Device interface { // an error will be returned. // If no device ID is given one is generated. // Returns the device on success. - CreateDevice(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error) - UpdateDevice(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID string, displayName *string) error - UpdateDeviceLastSeen(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error - RemoveDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, devices []string) error + CreateDevice(ctx context.Context, localpart string, serverName spec.ServerName, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error) + UpdateDevice(ctx context.Context, localpart string, serverName spec.ServerName, deviceID string, displayName *string) error + UpdateDeviceLastSeen(ctx context.Context, localpart string, serverName spec.ServerName, deviceID, ipAddr, userAgent string) error + RemoveDevices(ctx context.Context, localpart string, serverName spec.ServerName, devices []string) error // RemoveAllDevices deleted all devices for this user. Returns the devices deleted. - RemoveAllDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, exceptDeviceID string) (devices []api.Device, err error) + RemoveAllDevices(ctx context.Context, localpart string, serverName spec.ServerName, exceptDeviceID string) (devices []api.Device, err error) } type KeyBackup interface { @@ -107,26 +109,26 @@ type OpenID interface { } type Pusher interface { - UpsertPusher(ctx context.Context, p api.Pusher, localpart string, serverName gomatrixserverlib.ServerName) error - GetPushers(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) ([]api.Pusher, error) - RemovePusher(ctx context.Context, appid, pushkey, localpart string, serverName gomatrixserverlib.ServerName) error + UpsertPusher(ctx context.Context, p api.Pusher, localpart string, serverName spec.ServerName) error + GetPushers(ctx context.Context, localpart string, serverName spec.ServerName) ([]api.Pusher, error) + RemovePusher(ctx context.Context, appid, pushkey, localpart string, serverName spec.ServerName) error RemovePushers(ctx context.Context, appid, pushkey string) error } type ThreePID interface { - SaveThreePIDAssociation(ctx context.Context, threepid, localpart string, serverName gomatrixserverlib.ServerName, medium string) (err error) + SaveThreePIDAssociation(ctx context.Context, threepid, localpart string, serverName spec.ServerName, medium string) (err error) RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error) - GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, serverName gomatrixserverlib.ServerName, err error) - GetThreePIDsForLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (threepids []authtypes.ThreePID, err error) + GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, serverName spec.ServerName, err error) + GetThreePIDsForLocalpart(ctx context.Context, localpart string, serverName spec.ServerName) (threepids []authtypes.ThreePID, err error) } type Notification interface { - InsertNotification(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error - DeleteNotificationsUpTo(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, err error) - SetNotificationsRead(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, read bool) (affected bool, err error) - GetNotifications(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) - GetNotificationCount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, filter tables.NotificationFilter) (int64, error) - GetRoomNotificationCounts(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, _ error) + InsertNotification(ctx context.Context, localpart string, serverName spec.ServerName, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error + DeleteNotificationsUpTo(ctx context.Context, localpart string, serverName spec.ServerName, roomID string, pos uint64) (affected bool, err error) + SetNotificationsRead(ctx context.Context, localpart string, serverName spec.ServerName, roomID string, pos uint64, read bool) (affected bool, err error) + GetNotifications(ctx context.Context, localpart string, serverName spec.ServerName, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) + GetNotificationCount(ctx context.Context, localpart string, serverName spec.ServerName, filter tables.NotificationFilter) (int64, error) + GetRoomNotificationCounts(ctx context.Context, localpart string, serverName spec.ServerName, roomID string) (total int64, highlight int64, _ error) DeleteOldNotifications(ctx context.Context) error } @@ -198,17 +200,17 @@ type KeyDatabase interface { // StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. // If no domains are given, all user IDs with stale device lists are returned. - StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) + StaleDeviceLists(ctx context.Context, domains []spec.ServerName) ([]string, error) // MarkDeviceListStale sets the stale bit for this user to isStale. MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error - CrossSigningKeysForUser(ctx context.Context, userID string) (map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey, error) + CrossSigningKeysForUser(ctx context.Context, userID string) (map[fclient.CrossSigningKeyPurpose]fclient.CrossSigningKey, error) CrossSigningKeysDataForUser(ctx context.Context, userID string) (types.CrossSigningKeyMap, error) CrossSigningSigsForTarget(ctx context.Context, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error) StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error - StoreCrossSigningSigsForTarget(ctx context.Context, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error + StoreCrossSigningSigsForTarget(ctx context.Context, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature spec.Base64Bytes) error DeleteStaleDeviceLists( ctx context.Context, @@ -218,8 +220,8 @@ type KeyDatabase interface { type Statistics interface { UserStatistics(ctx context.Context) (*types.UserStatistics, *types.DatabaseEngine, error) - DailyRoomsMessages(ctx context.Context, serverName gomatrixserverlib.ServerName) (stats types.MessageStats, activeRooms, activeE2EERooms int64, err error) - UpsertDailyRoomsMessages(ctx context.Context, serverName gomatrixserverlib.ServerName, stats types.MessageStats, activeRooms, activeE2EERooms int64) error + DailyRoomsMessages(ctx context.Context, serverName spec.ServerName) (stats types.MessageStats, activeRooms, activeE2EERooms int64, err error) + UpsertDailyRoomsMessages(ctx context.Context, serverName spec.ServerName, stats types.MessageStats, activeRooms, activeE2EERooms int64) error } // Err3PIDInUse is the error returned when trying to save an association involving diff --git a/userapi/storage/postgres/account_data_table.go b/userapi/storage/postgres/account_data_table.go index 057160374..6ffda340e 100644 --- a/userapi/storage/postgres/account_data_table.go +++ b/userapi/storage/postgres/account_data_table.go @@ -22,7 +22,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/storage/tables" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const accountDataSchema = ` @@ -74,7 +74,7 @@ func NewPostgresAccountDataTable(db *sql.DB) (tables.AccountDataTable, error) { func (s *accountDataStatements) InsertAccountData( ctx context.Context, txn *sql.Tx, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, roomID, dataType string, content json.RawMessage, ) (err error) { stmt := sqlutil.TxStmt(txn, s.insertAccountDataStmt) @@ -90,7 +90,7 @@ func (s *accountDataStatements) InsertAccountData( func (s *accountDataStatements) SelectAccountData( ctx context.Context, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, ) ( /* global */ map[string]json.RawMessage, /* rooms */ map[string]map[string]json.RawMessage, @@ -129,7 +129,7 @@ func (s *accountDataStatements) SelectAccountData( func (s *accountDataStatements) SelectAccountDataByType( ctx context.Context, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, roomID, dataType string, ) (data json.RawMessage, err error) { var bytes []byte diff --git a/userapi/storage/postgres/accounts_table.go b/userapi/storage/postgres/accounts_table.go index 31a996527..5b38c5f4b 100644 --- a/userapi/storage/postgres/accounts_table.go +++ b/userapi/storage/postgres/accounts_table.go @@ -20,13 +20,12 @@ import ( "fmt" "time" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/postgres/deltas" "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib/spec" log "github.com/sirupsen/logrus" ) @@ -79,10 +78,10 @@ type accountsStatements struct { selectAccountByLocalpartStmt *sql.Stmt selectPasswordHashStmt *sql.Stmt selectNewNumericLocalpartStmt *sql.Stmt - serverName gomatrixserverlib.ServerName + serverName spec.ServerName } -func NewPostgresAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.AccountsTable, error) { +func NewPostgresAccountsTable(db *sql.DB, serverName spec.ServerName) (tables.AccountsTable, error) { s := &accountsStatements{ serverName: serverName, } @@ -122,7 +121,7 @@ func NewPostgresAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerNam // on success. func (s *accountsStatements) InsertAccount( ctx context.Context, txn *sql.Tx, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, hash, appserviceID string, accountType api.AccountType, ) (*api.Account, error) { createdTimeMS := time.Now().UnixNano() / 1000000 @@ -148,7 +147,7 @@ func (s *accountsStatements) InsertAccount( } func (s *accountsStatements) UpdatePassword( - ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, + ctx context.Context, localpart string, serverName spec.ServerName, passwordHash string, ) (err error) { _, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart, serverName) @@ -156,21 +155,21 @@ func (s *accountsStatements) UpdatePassword( } func (s *accountsStatements) DeactivateAccount( - ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, + ctx context.Context, localpart string, serverName spec.ServerName, ) (err error) { _, err = s.deactivateAccountStmt.ExecContext(ctx, localpart, serverName) return } func (s *accountsStatements) SelectPasswordHash( - ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, + ctx context.Context, localpart string, serverName spec.ServerName, ) (hash string, err error) { err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart, serverName).Scan(&hash) return } func (s *accountsStatements) SelectAccountByLocalpart( - ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, + ctx context.Context, localpart string, serverName spec.ServerName, ) (*api.Account, error) { var appserviceIDPtr sql.NullString var acc api.Account @@ -192,7 +191,7 @@ func (s *accountsStatements) SelectAccountByLocalpart( } func (s *accountsStatements) SelectNewNumericLocalpart( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, ) (id int64, err error) { stmt := s.selectNewNumericLocalpartStmt if txn != nil { diff --git a/userapi/storage/postgres/cross_signing_keys_table.go b/userapi/storage/postgres/cross_signing_keys_table.go index c0ecbd303..138b629d7 100644 --- a/userapi/storage/postgres/cross_signing_keys_table.go +++ b/userapi/storage/postgres/cross_signing_keys_table.go @@ -23,7 +23,8 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/types" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" ) var crossSigningKeysSchema = ` @@ -75,7 +76,7 @@ func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser( r = types.CrossSigningKeyMap{} for rows.Next() { var keyTypeInt int16 - var keyData gomatrixserverlib.Base64Bytes + var keyData spec.Base64Bytes if err := rows.Scan(&keyTypeInt, &keyData); err != nil { return nil, err } @@ -89,7 +90,7 @@ func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser( } func (s *crossSigningKeysStatements) UpsertCrossSigningKeysForUser( - ctx context.Context, txn *sql.Tx, userID string, keyType gomatrixserverlib.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes, + ctx context.Context, txn *sql.Tx, userID string, keyType fclient.CrossSigningKeyPurpose, keyData spec.Base64Bytes, ) error { keyTypeInt, ok := types.KeyTypePurposeToInt[keyType] if !ok { diff --git a/userapi/storage/postgres/cross_signing_sigs_table.go b/userapi/storage/postgres/cross_signing_sigs_table.go index b0117145c..61a381184 100644 --- a/userapi/storage/postgres/cross_signing_sigs_table.go +++ b/userapi/storage/postgres/cross_signing_sigs_table.go @@ -25,6 +25,7 @@ import ( "github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/types" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) var crossSigningSigsSchema = ` @@ -96,12 +97,12 @@ func (s *crossSigningSigsStatements) SelectCrossSigningSigsForTarget( for rows.Next() { var userID string var keyID gomatrixserverlib.KeyID - var signature gomatrixserverlib.Base64Bytes + var signature spec.Base64Bytes if err := rows.Scan(&userID, &keyID, &signature); err != nil { return nil, err } if _, ok := r[userID]; !ok { - r[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} + r[userID] = map[gomatrixserverlib.KeyID]spec.Base64Bytes{} } r[userID][keyID] = signature } @@ -112,7 +113,7 @@ func (s *crossSigningSigsStatements) UpsertCrossSigningSigsForTarget( ctx context.Context, txn *sql.Tx, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, - signature gomatrixserverlib.Base64Bytes, + signature spec.Base64Bytes, ) error { if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningSigsForTargetStmt).ExecContext(ctx, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil { return fmt.Errorf("s.upsertCrossSigningSigsForTargetStmt: %w", err) diff --git a/userapi/storage/postgres/deltas/2022110411000000_server_names.go b/userapi/storage/postgres/deltas/2022110411000000_server_names.go index 279e1e5f1..e9d39d062 100644 --- a/userapi/storage/postgres/deltas/2022110411000000_server_names.go +++ b/userapi/storage/postgres/deltas/2022110411000000_server_names.go @@ -6,7 +6,7 @@ import ( "fmt" "github.com/lib/pq" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) var serverNamesTables = []string{ @@ -42,7 +42,7 @@ var serverNamesDropIndex = []string{ // PostgreSQL doesn't expect the table name to be specified as a substituted // argument in that way so it results in a syntax error in the query. -func UpServerNames(ctx context.Context, tx *sql.Tx, serverName gomatrixserverlib.ServerName) error { +func UpServerNames(ctx context.Context, tx *sql.Tx, serverName spec.ServerName) error { for _, table := range serverNamesTables { q := fmt.Sprintf( "ALTER TABLE IF EXISTS %s ADD COLUMN IF NOT EXISTS server_name TEXT NOT NULL DEFAULT '';", diff --git a/userapi/storage/postgres/deltas/2022110411000001_server_names.go b/userapi/storage/postgres/deltas/2022110411000001_server_names.go index 04a47fa7b..f83859dfa 100644 --- a/userapi/storage/postgres/deltas/2022110411000001_server_names.go +++ b/userapi/storage/postgres/deltas/2022110411000001_server_names.go @@ -6,7 +6,7 @@ import ( "fmt" "github.com/lib/pq" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) // I know what you're thinking: you're wondering "why doesn't this use $1 @@ -14,7 +14,7 @@ import ( // PostgreSQL doesn't expect the table name to be specified as a substituted // argument in that way so it results in a syntax error in the query. -func UpServerNamesPopulate(ctx context.Context, tx *sql.Tx, serverName gomatrixserverlib.ServerName) error { +func UpServerNamesPopulate(ctx context.Context, tx *sql.Tx, serverName spec.ServerName) error { for _, table := range serverNamesTables { q := fmt.Sprintf( "UPDATE %s SET server_name = %s WHERE server_name = '';", diff --git a/userapi/storage/postgres/devices_table.go b/userapi/storage/postgres/devices_table.go index 88f8839c5..0335f8266 100644 --- a/userapi/storage/postgres/devices_table.go +++ b/userapi/storage/postgres/devices_table.go @@ -27,7 +27,7 @@ import ( "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/postgres/deltas" "github.com/matrix-org/dendrite/userapi/storage/tables" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const devicesSchema = ` @@ -112,10 +112,10 @@ type devicesStatements struct { deleteDeviceStmt *sql.Stmt deleteDevicesByLocalpartStmt *sql.Stmt deleteDevicesStmt *sql.Stmt - serverName gomatrixserverlib.ServerName + serverName spec.ServerName } -func NewPostgresDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.DevicesTable, error) { +func NewPostgresDevicesTable(db *sql.DB, serverName spec.ServerName) (tables.DevicesTable, error) { s := &devicesStatements{ serverName: serverName, } @@ -151,7 +151,7 @@ func NewPostgresDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName // Returns the device on success. func (s *devicesStatements) InsertDevice( ctx context.Context, txn *sql.Tx, id string, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, accessToken string, displayName *string, ipAddr, userAgent string, ) (*api.Device, error) { createdTimeMS := time.Now().UnixNano() / 1000000 @@ -176,7 +176,7 @@ func (s *devicesStatements) InsertDevice( } func (s *devicesStatements) InsertDeviceWithSessionID(ctx context.Context, txn *sql.Tx, id, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, accessToken string, displayName *string, ipAddr, userAgent string, sessionID int64, ) (*api.Device, error) { @@ -186,7 +186,7 @@ func (s *devicesStatements) InsertDeviceWithSessionID(ctx context.Context, txn * // deleteDevice removes a single device by id and user localpart. func (s *devicesStatements) DeleteDevice( ctx context.Context, txn *sql.Tx, id string, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, ) error { stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt) _, err := stmt.ExecContext(ctx, id, localpart, serverName) @@ -197,7 +197,7 @@ func (s *devicesStatements) DeleteDevice( // Returns an error if the execution failed. func (s *devicesStatements) DeleteDevices( ctx context.Context, txn *sql.Tx, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, devices []string, ) error { stmt := sqlutil.TxStmt(txn, s.deleteDevicesStmt) @@ -209,7 +209,7 @@ func (s *devicesStatements) DeleteDevices( // given user localpart. func (s *devicesStatements) DeleteDevicesByLocalpart( ctx context.Context, txn *sql.Tx, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, exceptDeviceID string, ) error { stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) @@ -219,7 +219,7 @@ func (s *devicesStatements) DeleteDevicesByLocalpart( func (s *devicesStatements) UpdateDeviceName( ctx context.Context, txn *sql.Tx, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, deviceID string, displayName *string, ) error { stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt) @@ -232,7 +232,7 @@ func (s *devicesStatements) SelectDeviceByToken( ) (*api.Device, error) { var dev api.Device var localpart string - var serverName gomatrixserverlib.ServerName + var serverName spec.ServerName stmt := s.selectDeviceByTokenStmt err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart, &serverName) if err == nil { @@ -246,7 +246,7 @@ func (s *devicesStatements) SelectDeviceByToken( // localpart and deviceID func (s *devicesStatements) SelectDeviceByID( ctx context.Context, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, deviceID string, ) (*api.Device, error) { var dev api.Device @@ -279,7 +279,7 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s var devices []api.Device var dev api.Device var localpart string - var serverName gomatrixserverlib.ServerName + var serverName spec.ServerName var lastseents sql.NullInt64 var displayName sql.NullString for rows.Next() { @@ -300,7 +300,7 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s func (s *devicesStatements) SelectDevicesByLocalpart( ctx context.Context, txn *sql.Tx, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, exceptDeviceID string, ) ([]api.Device, error) { devices := []api.Device{} @@ -342,7 +342,7 @@ func (s *devicesStatements) SelectDevicesByLocalpart( return devices, rows.Err() } -func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error { +func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, deviceID, ipAddr, userAgent string) error { lastSeenTs := time.Now().UnixNano() / 1000000 stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt) _, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, userAgent, localpart, serverName, deviceID) diff --git a/userapi/storage/postgres/notifications_table.go b/userapi/storage/postgres/notifications_table.go index dc64b1e79..acb9e42bc 100644 --- a/userapi/storage/postgres/notifications_table.go +++ b/userapi/storage/postgres/notifications_table.go @@ -20,13 +20,13 @@ import ( "encoding/json" "time" - "github.com/matrix-org/gomatrixserverlib" log "github.com/sirupsen/logrus" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib/spec" ) type notificationsStatements struct { @@ -112,7 +112,7 @@ func (s *notificationsStatements) Clean(ctx context.Context, txn *sql.Tx) error } // Insert inserts a notification into the database. -func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, highlight bool, n *api.Notification) error { +func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, eventID string, pos uint64, highlight bool, n *api.Notification) error { roomID, tsMS := n.RoomID, n.TS nn := *n // Clears out fields that have their own columns to (1) shrink the @@ -128,7 +128,7 @@ func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, local } // DeleteUpTo deletes all previous notifications, up to and including the event. -func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, _ error) { +func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, roomID string, pos uint64) (affected bool, _ error) { res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, serverName, roomID, pos) if err != nil { return false, err @@ -142,7 +142,7 @@ func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, l } // UpdateRead updates the "read" value for an event. -func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, v bool) (affected bool, _ error) { +func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, roomID string, pos uint64, v bool) (affected bool, _ error) { res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, serverName, roomID, pos) if err != nil { return false, err @@ -155,7 +155,7 @@ func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, l return nrows > 0, nil } -func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) { +func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) { rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, serverName, fromID, uint32(filter), limit) if err != nil { @@ -168,7 +168,7 @@ func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, local for rows.Next() { var id int64 var roomID string - var ts gomatrixserverlib.Timestamp + var ts spec.Timestamp var read bool var jsonStr string err = rows.Scan( @@ -198,12 +198,12 @@ func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, local return notifs, maxID, rows.Err() } -func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, filter tables.NotificationFilter) (count int64, err error) { +func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, filter tables.NotificationFilter) (count int64, err error) { err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, serverName, uint32(filter)).Scan(&count) return } -func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, err error) { +func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, roomID string) (total int64, highlight int64, err error) { err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, serverName, roomID).Scan(&total, &highlight) return } diff --git a/userapi/storage/postgres/openid_table.go b/userapi/storage/postgres/openid_table.go index 68d87f007..345877d11 100644 --- a/userapi/storage/postgres/openid_table.go +++ b/userapi/storage/postgres/openid_table.go @@ -8,7 +8,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/tables" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" log "github.com/sirupsen/logrus" ) @@ -34,10 +34,10 @@ const selectOpenIDTokenSQL = "" + type openIDTokenStatements struct { insertTokenStmt *sql.Stmt selectTokenStmt *sql.Stmt - serverName gomatrixserverlib.ServerName + serverName spec.ServerName } -func NewPostgresOpenIDTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.OpenIDTable, error) { +func NewPostgresOpenIDTable(db *sql.DB, serverName spec.ServerName) (tables.OpenIDTable, error) { s := &openIDTokenStatements{ serverName: serverName, } @@ -56,7 +56,7 @@ func NewPostgresOpenIDTable(db *sql.DB, serverName gomatrixserverlib.ServerName) func (s *openIDTokenStatements) InsertOpenIDToken( ctx context.Context, txn *sql.Tx, - token, localpart string, serverName gomatrixserverlib.ServerName, + token, localpart string, serverName spec.ServerName, expiresAtMS int64, ) (err error) { stmt := sqlutil.TxStmt(txn, s.insertTokenStmt) @@ -72,7 +72,7 @@ func (s *openIDTokenStatements) SelectOpenIDTokenAtrributes( ) (*api.OpenIDTokenAttributes, error) { var openIDTokenAttrs api.OpenIDTokenAttributes var localpart string - var serverName gomatrixserverlib.ServerName + var serverName spec.ServerName err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan( &localpart, &serverName, &openIDTokenAttrs.ExpiresAtMS, diff --git a/userapi/storage/postgres/profile_table.go b/userapi/storage/postgres/profile_table.go index df4e0db63..e404c32f2 100644 --- a/userapi/storage/postgres/profile_table.go +++ b/userapi/storage/postgres/profile_table.go @@ -23,7 +23,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/storage/tables" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const profilesSchema = ` @@ -92,7 +92,7 @@ func NewPostgresProfilesTable(db *sql.DB, serverNoticesLocalpart string) (tables func (s *profilesStatements) InsertProfile( ctx context.Context, txn *sql.Tx, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, ) (err error) { _, err = sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, serverName, "", "") return @@ -100,7 +100,7 @@ func (s *profilesStatements) InsertProfile( func (s *profilesStatements) SelectProfileByLocalpart( ctx context.Context, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, ) (*authtypes.Profile, error) { var profile authtypes.Profile err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart, serverName).Scan( @@ -114,7 +114,7 @@ func (s *profilesStatements) SelectProfileByLocalpart( func (s *profilesStatements) SetAvatarURL( ctx context.Context, txn *sql.Tx, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, avatarURL string, ) (*authtypes.Profile, bool, error) { profile := &authtypes.Profile{ @@ -130,7 +130,7 @@ func (s *profilesStatements) SetAvatarURL( func (s *profilesStatements) SetDisplayName( ctx context.Context, txn *sql.Tx, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, displayName string, ) (*authtypes.Profile, bool, error) { profile := &authtypes.Profile{ diff --git a/userapi/storage/postgres/pusher_table.go b/userapi/storage/postgres/pusher_table.go index 707b3bd2b..2e88aa8e9 100644 --- a/userapi/storage/postgres/pusher_table.go +++ b/userapi/storage/postgres/pusher_table.go @@ -25,7 +25,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/tables" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) // See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers @@ -98,7 +98,7 @@ type pushersStatements struct { func (s *pushersStatements) InsertPusher( ctx context.Context, txn *sql.Tx, session_id int64, pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, ) error { _, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, serverName, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data) return err @@ -106,7 +106,7 @@ func (s *pushersStatements) InsertPusher( func (s *pushersStatements) SelectPushers( ctx context.Context, txn *sql.Tx, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, ) ([]api.Pusher, error) { pushers := []api.Pusher{} rows, err := sqlutil.TxStmt(txn, s.selectPushersStmt).QueryContext(ctx, localpart, serverName) @@ -147,7 +147,7 @@ func (s *pushersStatements) SelectPushers( // deletePusher removes a single pusher by pushkey and user localpart. func (s *pushersStatements) DeletePusher( ctx context.Context, txn *sql.Tx, appid, pushkey, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, ) error { _, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart, serverName) return err diff --git a/userapi/storage/postgres/stale_device_lists.go b/userapi/storage/postgres/stale_device_lists.go index c823b58c6..e2086dc99 100644 --- a/userapi/storage/postgres/stale_device_lists.go +++ b/userapi/storage/postgres/stale_device_lists.go @@ -22,6 +22,7 @@ import ( "github.com/lib/pq" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/userapi/storage/tables" @@ -81,11 +82,11 @@ func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, if err != nil { return err } - _, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, gomatrixserverlib.AsTimestamp(time.Now())) + _, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, spec.AsTimestamp(time.Now())) return err } -func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { +func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []spec.ServerName) ([]string, error) { // we only query for 1 domain or all domains so optimise for those use cases if len(domains) == 0 { rows, err := s.selectStaleDeviceListsStmt.QueryContext(ctx, true) diff --git a/userapi/storage/postgres/stats_table.go b/userapi/storage/postgres/stats_table.go index f62467fa4..a7949e4ba 100644 --- a/userapi/storage/postgres/stats_table.go +++ b/userapi/storage/postgres/stats_table.go @@ -20,7 +20,7 @@ import ( "time" "github.com/lib/pq" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/sirupsen/logrus" "github.com/matrix-org/dendrite/internal" @@ -191,7 +191,7 @@ ON CONFLICT (localpart, device_id, timestamp) DO NOTHING const queryDBEngineVersion = "SHOW server_version;" type statsStatements struct { - serverName gomatrixserverlib.ServerName + serverName spec.ServerName lastUpdate time.Time countUsersLastSeenAfterStmt *sql.Stmt countR30UsersStmt *sql.Stmt @@ -204,7 +204,7 @@ type statsStatements struct { selectDailyMessagesStmt *sql.Stmt } -func NewPostgresStatsTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.StatsTable, error) { +func NewPostgresStatsTable(db *sql.DB, serverName spec.ServerName) (tables.StatsTable, error) { s := &statsStatements{ serverName: serverName, lastUpdate: time.Now(), @@ -280,7 +280,7 @@ func (s *statsStatements) registeredUserByType(ctx context.Context, txn *sql.Tx) int64(api.AccountTypeAppService), }, api.AccountTypeGuest, - gomatrixserverlib.AsTimestamp(registeredAfter), + spec.AsTimestamp(registeredAfter), ) if err != nil { return nil, err @@ -304,7 +304,7 @@ func (s *statsStatements) dailyUsers(ctx context.Context, txn *sql.Tx) (result i stmt := sqlutil.TxStmt(txn, s.countUsersLastSeenAfterStmt) lastSeenAfter := time.Now().AddDate(0, 0, -1) err = stmt.QueryRowContext(ctx, - gomatrixserverlib.AsTimestamp(lastSeenAfter), + spec.AsTimestamp(lastSeenAfter), ).Scan(&result) return } @@ -313,7 +313,7 @@ func (s *statsStatements) monthlyUsers(ctx context.Context, txn *sql.Tx) (result stmt := sqlutil.TxStmt(txn, s.countUsersLastSeenAfterStmt) lastSeenAfter := time.Now().AddDate(0, 0, -30) err = stmt.QueryRowContext(ctx, - gomatrixserverlib.AsTimestamp(lastSeenAfter), + spec.AsTimestamp(lastSeenAfter), ).Scan(&result) return } @@ -330,7 +330,7 @@ func (s *statsStatements) r30Users(ctx context.Context, txn *sql.Tx) (map[string diff := time.Hour * 24 * 30 rows, err := stmt.QueryContext(ctx, - gomatrixserverlib.AsTimestamp(lastSeenAfter), + spec.AsTimestamp(lastSeenAfter), diff.Milliseconds(), ) if err != nil { @@ -367,8 +367,8 @@ func (s *statsStatements) r30UsersV2(ctx context.Context, txn *sql.Tx) (map[stri tomorrow := time.Now().Add(time.Hour * 24) rows, err := stmt.QueryContext(ctx, - gomatrixserverlib.AsTimestamp(sixtyDaysAgo), - gomatrixserverlib.AsTimestamp(tomorrow), + spec.AsTimestamp(sixtyDaysAgo), + spec.AsTimestamp(tomorrow), diff.Milliseconds(), ) if err != nil { @@ -464,9 +464,9 @@ func (s *statsStatements) UpdateUserDailyVisits( startTime = startTime.AddDate(0, 0, -1) } _, err := stmt.ExecContext(ctx, - gomatrixserverlib.AsTimestamp(startTime), - gomatrixserverlib.AsTimestamp(lastUpdate), - gomatrixserverlib.AsTimestamp(time.Now()), + spec.AsTimestamp(startTime), + spec.AsTimestamp(lastUpdate), + spec.AsTimestamp(time.Now()), ) if err == nil { s.lastUpdate = time.Now() @@ -476,13 +476,13 @@ func (s *statsStatements) UpdateUserDailyVisits( func (s *statsStatements) UpsertDailyStats( ctx context.Context, txn *sql.Tx, - serverName gomatrixserverlib.ServerName, stats types.MessageStats, + serverName spec.ServerName, stats types.MessageStats, activeRooms, activeE2EERooms int64, ) error { stmt := sqlutil.TxStmt(txn, s.upsertMessagesStmt) timestamp := time.Now().Truncate(time.Hour * 24) _, err := stmt.ExecContext(ctx, - gomatrixserverlib.AsTimestamp(timestamp), + spec.AsTimestamp(timestamp), serverName, stats.Messages, stats.SentMessages, stats.MessagesE2EE, stats.SentMessagesE2EE, activeRooms, activeE2EERooms, @@ -492,12 +492,12 @@ func (s *statsStatements) UpsertDailyStats( func (s *statsStatements) DailyRoomsMessages( ctx context.Context, txn *sql.Tx, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, ) (msgStats types.MessageStats, activeRooms, activeE2EERooms int64, err error) { stmt := sqlutil.TxStmt(txn, s.selectDailyMessagesStmt) timestamp := time.Now().Truncate(time.Hour * 24) - err = stmt.QueryRowContext(ctx, serverName, gomatrixserverlib.AsTimestamp(timestamp)). + err = stmt.QueryRowContext(ctx, serverName, spec.AsTimestamp(timestamp)). Scan(&msgStats.Messages, &msgStats.SentMessages, &msgStats.MessagesE2EE, &msgStats.SentMessagesE2EE, &activeRooms, &activeE2EERooms) if err != nil && err != sql.ErrNoRows { return msgStats, 0, 0, err diff --git a/userapi/storage/postgres/storage.go b/userapi/storage/postgres/storage.go index 673d123ba..72e7c9cd9 100644 --- a/userapi/storage/postgres/storage.go +++ b/userapi/storage/postgres/storage.go @@ -20,21 +20,19 @@ import ( "fmt" "time" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/storage/postgres/deltas" "github.com/matrix-org/dendrite/userapi/storage/shared" + "github.com/matrix-org/gomatrixserverlib/spec" // Import the postgres database driver. _ "github.com/lib/pq" ) // NewDatabase creates a new accounts and profiles database -func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (*shared.Database, error) { - db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()) +func NewDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperties *config.DatabaseOptions, serverName spec.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (*shared.Database, error) { + db, writer, err := conMan.Connection(dbProperties) if err != nil { return nil, err } @@ -51,7 +49,7 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, return deltas.UpServerNames(ctx, txn, serverName) }, }) - if err = m.Up(base.Context()); err != nil { + if err = m.Up(ctx); err != nil { return nil, err } @@ -111,7 +109,7 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, return deltas.UpServerNamesPopulate(ctx, txn, serverName) }, }) - if err = m.Up(base.Context()); err != nil { + if err = m.Up(ctx); err != nil { return nil, err } @@ -137,8 +135,8 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, }, nil } -func NewKeyDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*shared.KeyDatabase, error) { - db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()) +func NewKeyDatabase(conMan sqlutil.Connections, dbProperties *config.DatabaseOptions) (*shared.KeyDatabase, error) { + db, writer, err := conMan.Connection(dbProperties) if err != nil { return nil, err } diff --git a/userapi/storage/postgres/threepid_table.go b/userapi/storage/postgres/threepid_table.go index f41c43122..15b42a0a6 100644 --- a/userapi/storage/postgres/threepid_table.go +++ b/userapi/storage/postgres/threepid_table.go @@ -20,7 +20,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/storage/tables" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" ) @@ -77,7 +77,7 @@ func NewPostgresThreePIDTable(db *sql.DB) (tables.ThreePIDTable, error) { func (s *threepidStatements) SelectLocalpartForThreePID( ctx context.Context, txn *sql.Tx, threepid string, medium string, -) (localpart string, serverName gomatrixserverlib.ServerName, err error) { +) (localpart string, serverName spec.ServerName, err error) { stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt) err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart, &serverName) if err == sql.ErrNoRows { @@ -88,7 +88,7 @@ func (s *threepidStatements) SelectLocalpartForThreePID( func (s *threepidStatements) SelectThreePIDsForLocalpart( ctx context.Context, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, ) (threepids []authtypes.ThreePID, err error) { rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart, serverName) if err != nil { @@ -113,7 +113,7 @@ func (s *threepidStatements) SelectThreePIDsForLocalpart( func (s *threepidStatements) InsertThreePID( ctx context.Context, txn *sql.Tx, threepid, medium, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, ) (err error) { stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt) _, err = stmt.ExecContext(ctx, threepid, medium, localpart, serverName) diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index d3272a032..705707571 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -27,6 +27,8 @@ import ( "time" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "golang.org/x/crypto/bcrypt" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" @@ -54,7 +56,7 @@ type Database struct { Pushers tables.PusherTable Stats tables.StatsTable LoginTokenLifetime time.Duration - ServerName gomatrixserverlib.ServerName + ServerName spec.ServerName BcryptCost int OpenIDTokenLifetimeMS int64 } @@ -79,7 +81,7 @@ const ( // GetAccountByPassword returns the account associated with the given localpart and password. // Returns sql.ErrNoRows if no account exists which matches the given localpart. func (d *Database) GetAccountByPassword( - ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, + ctx context.Context, localpart string, serverName spec.ServerName, plaintextPassword string, ) (*api.Account, error) { hash, err := d.Accounts.SelectPasswordHash(ctx, localpart, serverName) @@ -99,7 +101,7 @@ func (d *Database) GetAccountByPassword( // Returns sql.ErrNoRows if no profile exists which matches the given localpart. func (d *Database) GetProfileByLocalpart( ctx context.Context, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, ) (*authtypes.Profile, error) { return d.Profiles.SelectProfileByLocalpart(ctx, localpart, serverName) } @@ -108,7 +110,7 @@ func (d *Database) GetProfileByLocalpart( // localpart. Returns an error if something went wrong with the SQL query func (d *Database) SetAvatarURL( ctx context.Context, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, avatarURL string, ) (profile *authtypes.Profile, changed bool, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { @@ -122,7 +124,7 @@ func (d *Database) SetAvatarURL( // localpart. Returns an error if something went wrong with the SQL query func (d *Database) SetDisplayName( ctx context.Context, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, displayName string, ) (profile *authtypes.Profile, changed bool, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { @@ -134,7 +136,7 @@ func (d *Database) SetDisplayName( // SetPassword sets the account password to the given hash. func (d *Database) SetPassword( - ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, + ctx context.Context, localpart string, serverName spec.ServerName, plaintextPassword string, ) error { hash, err := d.hashPassword(plaintextPassword) @@ -150,7 +152,7 @@ func (d *Database) SetPassword( // for this account. If no password is supplied, the account will be a passwordless account. If the // account already exists, it will return nil, ErrUserExists. func (d *Database) CreateAccount( - ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, + ctx context.Context, localpart string, serverName spec.ServerName, plaintextPassword, appserviceID string, accountType api.AccountType, ) (acc *api.Account, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { @@ -175,7 +177,7 @@ func (d *Database) CreateAccount( // been taken out by the caller (e.g. CreateAccount or CreateGuestAccount). func (d *Database) createAccount( ctx context.Context, txn *sql.Tx, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, plaintextPassword, appserviceID string, accountType api.AccountType, ) (*api.Account, error) { var err error @@ -207,7 +209,7 @@ func (d *Database) createAccount( func (d *Database) QueryPushRules( ctx context.Context, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, ) (*pushrules.AccountRuleSets, error) { data, err := d.AccountDatas.SelectAccountDataByType(ctx, localpart, serverName, "", "m.push_rules") if err != nil { @@ -246,7 +248,7 @@ func (d *Database) QueryPushRules( // update the corresponding row with the new content // Returns a SQL error if there was an issue with the insertion/update func (d *Database) SaveAccountData( - ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, + ctx context.Context, localpart string, serverName spec.ServerName, roomID, dataType string, content json.RawMessage, ) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { @@ -257,7 +259,7 @@ func (d *Database) SaveAccountData( // GetAccountData returns account data related to a given localpart // If no account data could be found, returns an empty arrays // Returns an error if there was an issue with the retrieval -func (d *Database) GetAccountData(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) ( +func (d *Database) GetAccountData(ctx context.Context, localpart string, serverName spec.ServerName) ( global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error, @@ -270,7 +272,7 @@ func (d *Database) GetAccountData(ctx context.Context, localpart string, serverN // If no account data could be found, returns nil // Returns an error if there was an issue with the retrieval func (d *Database) GetAccountDataByType( - ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, + ctx context.Context, localpart string, serverName spec.ServerName, roomID, dataType string, ) (data json.RawMessage, err error) { return d.AccountDatas.SelectAccountDataByType( @@ -280,7 +282,7 @@ func (d *Database) GetAccountDataByType( // GetNewNumericLocalpart generates and returns a new unused numeric localpart func (d *Database) GetNewNumericLocalpart( - ctx context.Context, serverName gomatrixserverlib.ServerName, + ctx context.Context, serverName spec.ServerName, ) (int64, error) { return d.Accounts.SelectNewNumericLocalpart(ctx, nil, serverName) } @@ -300,7 +302,7 @@ var Err3PIDInUse = errors.New("this third-party identifier is already in use") // Returns an error if there was a problem talking to the database. func (d *Database) SaveThreePIDAssociation( ctx context.Context, threepid string, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, medium string, ) (err error) { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { @@ -338,7 +340,7 @@ func (d *Database) RemoveThreePIDAssociation( // Returns an error if there was a problem talking to the database. func (d *Database) GetLocalpartForThreePID( ctx context.Context, threepid string, medium string, -) (localpart string, serverName gomatrixserverlib.ServerName, err error) { +) (localpart string, serverName spec.ServerName, err error) { return d.ThreePIDs.SelectLocalpartForThreePID(ctx, nil, threepid, medium) } @@ -348,7 +350,7 @@ func (d *Database) GetLocalpartForThreePID( // Returns an error if there was an issue talking to the database. func (d *Database) GetThreePIDsForLocalpart( ctx context.Context, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, ) (threepids []authtypes.ThreePID, err error) { return d.ThreePIDs.SelectThreePIDsForLocalpart(ctx, localpart, serverName) } @@ -356,7 +358,7 @@ func (d *Database) GetThreePIDsForLocalpart( // CheckAccountAvailability checks if the username/localpart is already present // in the database. // If the DB returns sql.ErrNoRows the Localpart isn't taken. -func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (bool, error) { +func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string, serverName spec.ServerName) (bool, error) { _, err := d.Accounts.SelectAccountByLocalpart(ctx, localpart, serverName) if err == sql.ErrNoRows { return true, nil @@ -367,7 +369,7 @@ func (d *Database) CheckAccountAvailability(ctx context.Context, localpart strin // GetAccountByLocalpart returns the account associated with the given localpart. // This function assumes the request is authenticated or the account data is used only internally. // Returns sql.ErrNoRows if no account exists which matches the given localpart. -func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, +func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string, serverName spec.ServerName, ) (*api.Account, error) { // try to get the account with lowercase localpart (majority) acc, err := d.Accounts.SelectAccountByLocalpart(ctx, strings.ToLower(localpart), serverName) @@ -385,7 +387,7 @@ func (d *Database) SearchProfiles(ctx context.Context, searchString string, limi } // DeactivateAccount deactivates the user's account, removing all ability for the user to login again. -func (d *Database) DeactivateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (err error) { +func (d *Database) DeactivateAccount(ctx context.Context, localpart string, serverName spec.ServerName) (err error) { return d.Writer.Do(nil, nil, func(txn *sql.Tx) error { return d.Accounts.DeactivateAccount(ctx, localpart, serverName) }) @@ -570,7 +572,7 @@ func (d *Database) GetDeviceByAccessToken( // Returns sql.ErrNoRows if no matching device was found. func (d *Database) GetDeviceByID( ctx context.Context, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, deviceID string, ) (*api.Device, error) { return d.Devices.SelectDeviceByID(ctx, localpart, serverName, deviceID) @@ -579,7 +581,7 @@ func (d *Database) GetDeviceByID( // GetDevicesByLocalpart returns the devices matching the given localpart. func (d *Database) GetDevicesByLocalpart( ctx context.Context, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, ) ([]api.Device, error) { return d.Devices.SelectDevicesByLocalpart(ctx, nil, localpart, serverName, "") } @@ -595,7 +597,7 @@ func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]ap // If no device ID is given one is generated. // Returns the device on success. func (d *Database) CreateDevice( - ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, + ctx context.Context, localpart string, serverName spec.ServerName, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string, ) (dev *api.Device, returnErr error) { if deviceID != nil { @@ -674,7 +676,7 @@ func generateDeviceID() (string, error) { // Returns SQL error if there are problems and nil on success. func (d *Database) UpdateDevice( ctx context.Context, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, deviceID string, displayName *string, ) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { @@ -688,7 +690,7 @@ func (d *Database) UpdateDevice( // If something went wrong during the deletion, it will return the SQL error. func (d *Database) RemoveDevices( ctx context.Context, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, devices []string, ) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { @@ -704,7 +706,7 @@ func (d *Database) RemoveDevices( // If something went wrong during the deletion, it will return the SQL error. func (d *Database) RemoveAllDevices( ctx context.Context, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, exceptDeviceID string, ) (devices []api.Device, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { @@ -721,7 +723,7 @@ func (d *Database) RemoveAllDevices( } // UpdateDeviceLastSeen updates a last seen timestamp and the ip address. -func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error { +func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart string, serverName spec.ServerName, deviceID, ipAddr, userAgent string) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Devices.UpdateDeviceLastSeen(ctx, txn, localpart, serverName, deviceID, ipAddr, userAgent) }) @@ -771,13 +773,13 @@ func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) ( return d.LoginTokens.SelectLoginToken(ctx, token) } -func (d *Database) InsertNotification(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error { +func (d *Database) InsertNotification(ctx context.Context, localpart string, serverName spec.ServerName, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Notifications.Insert(ctx, txn, localpart, serverName, eventID, pos, pushrules.BoolTweakOr(tweaks, pushrules.HighlightTweak, false), n) }) } -func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, err error) { +func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart string, serverName spec.ServerName, roomID string, pos uint64) (affected bool, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { affected, err = d.Notifications.DeleteUpTo(ctx, txn, localpart, serverName, roomID, pos) return err @@ -785,7 +787,7 @@ func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart string return } -func (d *Database) SetNotificationsRead(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, b bool) (affected bool, err error) { +func (d *Database) SetNotificationsRead(ctx context.Context, localpart string, serverName spec.ServerName, roomID string, pos uint64, b bool) (affected bool, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { affected, err = d.Notifications.UpdateRead(ctx, txn, localpart, serverName, roomID, pos, b) return err @@ -793,15 +795,15 @@ func (d *Database) SetNotificationsRead(ctx context.Context, localpart string, s return } -func (d *Database) GetNotifications(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) { +func (d *Database) GetNotifications(ctx context.Context, localpart string, serverName spec.ServerName, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) { return d.Notifications.Select(ctx, nil, localpart, serverName, fromID, limit, filter) } -func (d *Database) GetNotificationCount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, filter tables.NotificationFilter) (int64, error) { +func (d *Database) GetNotificationCount(ctx context.Context, localpart string, serverName spec.ServerName, filter tables.NotificationFilter) (int64, error) { return d.Notifications.SelectCount(ctx, nil, localpart, serverName, filter) } -func (d *Database) GetRoomNotificationCounts(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, _ error) { +func (d *Database) GetRoomNotificationCounts(ctx context.Context, localpart string, serverName spec.ServerName, roomID string) (total int64, highlight int64, _ error) { return d.Notifications.SelectRoomCounts(ctx, nil, localpart, serverName, roomID) } @@ -813,7 +815,7 @@ func (d *Database) DeleteOldNotifications(ctx context.Context) error { func (d *Database) UpsertPusher( ctx context.Context, p api.Pusher, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, ) error { data, err := json.Marshal(p.Data) if err != nil { @@ -839,7 +841,7 @@ func (d *Database) UpsertPusher( // GetPushers returns the pushers matching the given localpart. func (d *Database) GetPushers( - ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, + ctx context.Context, localpart string, serverName spec.ServerName, ) ([]api.Pusher, error) { return d.Pushers.SelectPushers(ctx, nil, localpart, serverName) } @@ -848,7 +850,7 @@ func (d *Database) GetPushers( // Invoked when `append` is true and `kind` is null in // https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-pushers-set func (d *Database) RemovePusher( - ctx context.Context, appid, pushkey, localpart string, serverName gomatrixserverlib.ServerName, + ctx context.Context, appid, pushkey, localpart string, serverName spec.ServerName, ) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err := d.Pushers.DeletePusher(ctx, txn, appid, pushkey, localpart, serverName) @@ -875,14 +877,14 @@ func (d *Database) UserStatistics(ctx context.Context) (*types.UserStatistics, * return d.Stats.UserStatistics(ctx, nil) } -func (d *Database) UpsertDailyRoomsMessages(ctx context.Context, serverName gomatrixserverlib.ServerName, stats types.MessageStats, activeRooms, activeE2EERooms int64) error { +func (d *Database) UpsertDailyRoomsMessages(ctx context.Context, serverName spec.ServerName, stats types.MessageStats, activeRooms, activeE2EERooms int64) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Stats.UpsertDailyStats(ctx, txn, serverName, stats, activeRooms, activeE2EERooms) }) } func (d *Database) DailyRoomsMessages( - ctx context.Context, serverName gomatrixserverlib.ServerName, + ctx context.Context, serverName spec.ServerName, ) (stats types.MessageStats, activeRooms, activeE2EERooms int64, err error) { return d.Stats.DailyRoomsMessages(ctx, nil, serverName) } @@ -995,7 +997,7 @@ func (d *KeyDatabase) KeyChanges(ctx context.Context, fromOffset, toOffset int64 // StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. // If no domains are given, all user IDs with stale device lists are returned. -func (d *KeyDatabase) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { +func (d *KeyDatabase) StaleDeviceLists(ctx context.Context, domains []spec.ServerName) ([]string, error) { return d.StaleDeviceListsTable.SelectUserIDsWithStaleDeviceLists(ctx, domains) } @@ -1026,18 +1028,18 @@ func (d *KeyDatabase) DeleteDeviceKeys(ctx context.Context, userID string, devic } // CrossSigningKeysForUser returns the latest known cross-signing keys for a user, if any. -func (d *KeyDatabase) CrossSigningKeysForUser(ctx context.Context, userID string) (map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey, error) { +func (d *KeyDatabase) CrossSigningKeysForUser(ctx context.Context, userID string) (map[fclient.CrossSigningKeyPurpose]fclient.CrossSigningKey, error) { keyMap, err := d.CrossSigningKeysTable.SelectCrossSigningKeysForUser(ctx, nil, userID) if err != nil { return nil, fmt.Errorf("d.CrossSigningKeysTable.SelectCrossSigningKeysForUser: %w", err) } - results := map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey{} + results := map[fclient.CrossSigningKeyPurpose]fclient.CrossSigningKey{} for purpose, key := range keyMap { keyID := gomatrixserverlib.KeyID("ed25519:" + key.Encode()) - result := gomatrixserverlib.CrossSigningKey{ + result := fclient.CrossSigningKey{ UserID: userID, - Usage: []gomatrixserverlib.CrossSigningKeyPurpose{purpose}, - Keys: map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{ + Usage: []fclient.CrossSigningKeyPurpose{purpose}, + Keys: map[gomatrixserverlib.KeyID]spec.Base64Bytes{ keyID: key, }, } @@ -1050,10 +1052,10 @@ func (d *KeyDatabase) CrossSigningKeysForUser(ctx context.Context, userID string continue } if result.Signatures == nil { - result.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} + result.Signatures = map[string]map[gomatrixserverlib.KeyID]spec.Base64Bytes{} } if _, ok := result.Signatures[sigUserID]; !ok { - result.Signatures[sigUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} + result.Signatures[sigUserID] = map[gomatrixserverlib.KeyID]spec.Base64Bytes{} } for sigKeyID, sigBytes := range forSigUserID { result.Signatures[sigUserID][sigKeyID] = sigBytes @@ -1091,7 +1093,7 @@ func (d *KeyDatabase) StoreCrossSigningSigsForTarget( ctx context.Context, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, - signature gomatrixserverlib.Base64Bytes, + signature spec.Base64Bytes, ) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { if err := d.CrossSigningSigsTable.UpsertCrossSigningSigsForTarget(ctx, nil, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil { diff --git a/userapi/storage/sqlite3/account_data_table.go b/userapi/storage/sqlite3/account_data_table.go index 2fbdc5732..3a6367c45 100644 --- a/userapi/storage/sqlite3/account_data_table.go +++ b/userapi/storage/sqlite3/account_data_table.go @@ -21,7 +21,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/storage/tables" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const accountDataSchema = ` @@ -76,7 +76,7 @@ func NewSQLiteAccountDataTable(db *sql.DB) (tables.AccountDataTable, error) { func (s *accountDataStatements) InsertAccountData( ctx context.Context, txn *sql.Tx, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, roomID, dataType string, content json.RawMessage, ) error { _, err := sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, localpart, serverName, roomID, dataType, content) @@ -85,7 +85,7 @@ func (s *accountDataStatements) InsertAccountData( func (s *accountDataStatements) SelectAccountData( ctx context.Context, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, ) ( /* global */ map[string]json.RawMessage, /* rooms */ map[string]map[string]json.RawMessage, @@ -123,7 +123,7 @@ func (s *accountDataStatements) SelectAccountData( func (s *accountDataStatements) SelectAccountDataByType( ctx context.Context, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, roomID, dataType string, ) (data json.RawMessage, err error) { var bytes []byte diff --git a/userapi/storage/sqlite3/accounts_table.go b/userapi/storage/sqlite3/accounts_table.go index f4ebe2158..d01915a76 100644 --- a/userapi/storage/sqlite3/accounts_table.go +++ b/userapi/storage/sqlite3/accounts_table.go @@ -19,13 +19,12 @@ import ( "database/sql" "time" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas" "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib/spec" log "github.com/sirupsen/logrus" ) @@ -79,10 +78,10 @@ type accountsStatements struct { selectAccountByLocalpartStmt *sql.Stmt selectPasswordHashStmt *sql.Stmt selectNewNumericLocalpartStmt *sql.Stmt - serverName gomatrixserverlib.ServerName + serverName spec.ServerName } -func NewSQLiteAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.AccountsTable, error) { +func NewSQLiteAccountsTable(db *sql.DB, serverName spec.ServerName) (tables.AccountsTable, error) { s := &accountsStatements{ db: db, serverName: serverName, @@ -122,7 +121,7 @@ func NewSQLiteAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName) // this account will be passwordless. Returns an error if this account already exists. Returns the account // on success. func (s *accountsStatements) InsertAccount( - ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, + ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, hash, appserviceID string, accountType api.AccountType, ) (*api.Account, error) { createdTimeMS := time.Now().UnixNano() / 1000000 @@ -148,7 +147,7 @@ func (s *accountsStatements) InsertAccount( } func (s *accountsStatements) UpdatePassword( - ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, + ctx context.Context, localpart string, serverName spec.ServerName, passwordHash string, ) (err error) { _, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart, serverName) @@ -156,21 +155,21 @@ func (s *accountsStatements) UpdatePassword( } func (s *accountsStatements) DeactivateAccount( - ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, + ctx context.Context, localpart string, serverName spec.ServerName, ) (err error) { _, err = s.deactivateAccountStmt.ExecContext(ctx, localpart, serverName) return } func (s *accountsStatements) SelectPasswordHash( - ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, + ctx context.Context, localpart string, serverName spec.ServerName, ) (hash string, err error) { err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart, serverName).Scan(&hash) return } func (s *accountsStatements) SelectAccountByLocalpart( - ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, + ctx context.Context, localpart string, serverName spec.ServerName, ) (*api.Account, error) { var appserviceIDPtr sql.NullString var acc api.Account @@ -192,7 +191,7 @@ func (s *accountsStatements) SelectAccountByLocalpart( } func (s *accountsStatements) SelectNewNumericLocalpart( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, ) (id int64, err error) { stmt := s.selectNewNumericLocalpartStmt if txn != nil { diff --git a/userapi/storage/sqlite3/cross_signing_keys_table.go b/userapi/storage/sqlite3/cross_signing_keys_table.go index 10721fcc8..5c2ce7039 100644 --- a/userapi/storage/sqlite3/cross_signing_keys_table.go +++ b/userapi/storage/sqlite3/cross_signing_keys_table.go @@ -23,7 +23,8 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/types" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" ) var crossSigningKeysSchema = ` @@ -74,7 +75,7 @@ func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser( r = types.CrossSigningKeyMap{} for rows.Next() { var keyTypeInt int16 - var keyData gomatrixserverlib.Base64Bytes + var keyData spec.Base64Bytes if err := rows.Scan(&keyTypeInt, &keyData); err != nil { return nil, err } @@ -88,7 +89,7 @@ func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser( } func (s *crossSigningKeysStatements) UpsertCrossSigningKeysForUser( - ctx context.Context, txn *sql.Tx, userID string, keyType gomatrixserverlib.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes, + ctx context.Context, txn *sql.Tx, userID string, keyType fclient.CrossSigningKeyPurpose, keyData spec.Base64Bytes, ) error { keyTypeInt, ok := types.KeyTypePurposeToInt[keyType] if !ok { diff --git a/userapi/storage/sqlite3/cross_signing_sigs_table.go b/userapi/storage/sqlite3/cross_signing_sigs_table.go index 2be00c9c1..657264115 100644 --- a/userapi/storage/sqlite3/cross_signing_sigs_table.go +++ b/userapi/storage/sqlite3/cross_signing_sigs_table.go @@ -25,6 +25,7 @@ import ( "github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/types" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) var crossSigningSigsSchema = ` @@ -94,12 +95,12 @@ func (s *crossSigningSigsStatements) SelectCrossSigningSigsForTarget( for rows.Next() { var userID string var keyID gomatrixserverlib.KeyID - var signature gomatrixserverlib.Base64Bytes + var signature spec.Base64Bytes if err := rows.Scan(&userID, &keyID, &signature); err != nil { return nil, err } if _, ok := r[userID]; !ok { - r[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} + r[userID] = map[gomatrixserverlib.KeyID]spec.Base64Bytes{} } r[userID][keyID] = signature } @@ -110,7 +111,7 @@ func (s *crossSigningSigsStatements) UpsertCrossSigningSigsForTarget( ctx context.Context, txn *sql.Tx, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, - signature gomatrixserverlib.Base64Bytes, + signature spec.Base64Bytes, ) error { if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningSigsForTargetStmt).ExecContext(ctx, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil { return fmt.Errorf("s.upsertCrossSigningSigsForTargetStmt: %w", err) diff --git a/userapi/storage/sqlite3/deltas/2022110411000000_server_names.go b/userapi/storage/sqlite3/deltas/2022110411000000_server_names.go index c11ea6844..76f39a908 100644 --- a/userapi/storage/sqlite3/deltas/2022110411000000_server_names.go +++ b/userapi/storage/sqlite3/deltas/2022110411000000_server_names.go @@ -7,7 +7,7 @@ import ( "strings" "github.com/lib/pq" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/sirupsen/logrus" ) @@ -42,7 +42,7 @@ var serverNamesDropIndex = []string{ // PostgreSQL doesn't expect the table name to be specified as a substituted // argument in that way so it results in a syntax error in the query. -func UpServerNames(ctx context.Context, tx *sql.Tx, serverName gomatrixserverlib.ServerName) error { +func UpServerNames(ctx context.Context, tx *sql.Tx, serverName spec.ServerName) error { for _, table := range serverNamesTables { q := fmt.Sprintf( "SELECT COUNT(name) FROM sqlite_schema WHERE type='table' AND name=%s;", diff --git a/userapi/storage/sqlite3/deltas/2022110411000001_server_names.go b/userapi/storage/sqlite3/deltas/2022110411000001_server_names.go index 04a47fa7b..f83859dfa 100644 --- a/userapi/storage/sqlite3/deltas/2022110411000001_server_names.go +++ b/userapi/storage/sqlite3/deltas/2022110411000001_server_names.go @@ -6,7 +6,7 @@ import ( "fmt" "github.com/lib/pq" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) // I know what you're thinking: you're wondering "why doesn't this use $1 @@ -14,7 +14,7 @@ import ( // PostgreSQL doesn't expect the table name to be specified as a substituted // argument in that way so it results in a syntax error in the query. -func UpServerNamesPopulate(ctx context.Context, tx *sql.Tx, serverName gomatrixserverlib.ServerName) error { +func UpServerNamesPopulate(ctx context.Context, tx *sql.Tx, serverName spec.ServerName) error { for _, table := range serverNamesTables { q := fmt.Sprintf( "UPDATE %s SET server_name = %s WHERE server_name = '';", diff --git a/userapi/storage/sqlite3/devices_table.go b/userapi/storage/sqlite3/devices_table.go index 65e17527d..23e823116 100644 --- a/userapi/storage/sqlite3/devices_table.go +++ b/userapi/storage/sqlite3/devices_table.go @@ -25,9 +25,9 @@ import ( "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas" "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/clientapi/userutil" - "github.com/matrix-org/gomatrixserverlib" ) const devicesSchema = ` @@ -97,10 +97,10 @@ type devicesStatements struct { updateDeviceLastSeenStmt *sql.Stmt deleteDeviceStmt *sql.Stmt deleteDevicesByLocalpartStmt *sql.Stmt - serverName gomatrixserverlib.ServerName + serverName spec.ServerName } -func NewSQLiteDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.DevicesTable, error) { +func NewSQLiteDevicesTable(db *sql.DB, serverName spec.ServerName) (tables.DevicesTable, error) { s := &devicesStatements{ db: db, serverName: serverName, @@ -137,7 +137,7 @@ func NewSQLiteDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName) // Returns the device on success. func (s *devicesStatements) InsertDevice( ctx context.Context, txn *sql.Tx, id string, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, accessToken string, displayName *string, ipAddr, userAgent string, ) (*api.Device, error) { createdTimeMS := time.Now().UnixNano() / 1000000 @@ -167,7 +167,7 @@ func (s *devicesStatements) InsertDevice( } func (s *devicesStatements) InsertDeviceWithSessionID(ctx context.Context, txn *sql.Tx, id, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, accessToken string, displayName *string, ipAddr, userAgent string, sessionID int64, ) (*api.Device, error) { @@ -193,7 +193,7 @@ func (s *devicesStatements) InsertDeviceWithSessionID(ctx context.Context, txn * func (s *devicesStatements) DeleteDevice( ctx context.Context, txn *sql.Tx, id string, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, ) error { stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt) _, err := stmt.ExecContext(ctx, id, localpart, serverName) @@ -202,7 +202,7 @@ func (s *devicesStatements) DeleteDevice( func (s *devicesStatements) DeleteDevices( ctx context.Context, txn *sql.Tx, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, devices []string, ) error { orig := strings.Replace(deleteDevicesSQL, "($3)", sqlutil.QueryVariadicOffset(len(devices), 2), 1) @@ -224,7 +224,7 @@ func (s *devicesStatements) DeleteDevices( func (s *devicesStatements) DeleteDevicesByLocalpart( ctx context.Context, txn *sql.Tx, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, exceptDeviceID string, ) error { stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) @@ -234,7 +234,7 @@ func (s *devicesStatements) DeleteDevicesByLocalpart( func (s *devicesStatements) UpdateDeviceName( ctx context.Context, txn *sql.Tx, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, deviceID string, displayName *string, ) error { stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt) @@ -247,7 +247,7 @@ func (s *devicesStatements) SelectDeviceByToken( ) (*api.Device, error) { var dev api.Device var localpart string - var serverName gomatrixserverlib.ServerName + var serverName spec.ServerName stmt := s.selectDeviceByTokenStmt err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart, &serverName) if err == nil { @@ -261,7 +261,7 @@ func (s *devicesStatements) SelectDeviceByToken( // localpart and deviceID func (s *devicesStatements) SelectDeviceByID( ctx context.Context, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, deviceID string, ) (*api.Device, error) { var dev api.Device @@ -287,7 +287,7 @@ func (s *devicesStatements) SelectDeviceByID( func (s *devicesStatements) SelectDevicesByLocalpart( ctx context.Context, txn *sql.Tx, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, exceptDeviceID string, ) ([]api.Device, error) { devices := []api.Device{} @@ -343,7 +343,7 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s var devices []api.Device var dev api.Device var localpart string - var serverName gomatrixserverlib.ServerName + var serverName spec.ServerName var displayName sql.NullString var lastseents sql.NullInt64 for rows.Next() { @@ -362,7 +362,7 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s return devices, rows.Err() } -func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error { +func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, deviceID, ipAddr, userAgent string) error { lastSeenTs := time.Now().UnixNano() / 1000000 stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt) _, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, userAgent, localpart, serverName, deviceID) diff --git a/userapi/storage/sqlite3/notifications_table.go b/userapi/storage/sqlite3/notifications_table.go index ef39d027c..94d6c7294 100644 --- a/userapi/storage/sqlite3/notifications_table.go +++ b/userapi/storage/sqlite3/notifications_table.go @@ -20,13 +20,13 @@ import ( "encoding/json" "time" - "github.com/matrix-org/gomatrixserverlib" log "github.com/sirupsen/logrus" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib/spec" ) type notificationsStatements struct { @@ -112,7 +112,7 @@ func (s *notificationsStatements) Clean(ctx context.Context, txn *sql.Tx) error } // Insert inserts a notification into the database. -func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, highlight bool, n *api.Notification) error { +func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, eventID string, pos uint64, highlight bool, n *api.Notification) error { roomID, tsMS := n.RoomID, n.TS nn := *n // Clears out fields that have their own columns to (1) shrink the @@ -128,7 +128,7 @@ func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, local } // DeleteUpTo deletes all previous notifications, up to and including the event. -func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, _ error) { +func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, roomID string, pos uint64) (affected bool, _ error) { res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, serverName, roomID, pos) if err != nil { return false, err @@ -142,7 +142,7 @@ func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, l } // UpdateRead updates the "read" value for an event. -func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, v bool) (affected bool, _ error) { +func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, roomID string, pos uint64, v bool) (affected bool, _ error) { res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, serverName, roomID, pos) if err != nil { return false, err @@ -155,7 +155,7 @@ func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, l return nrows > 0, nil } -func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) { +func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) { rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, serverName, fromID, uint32(filter), limit) if err != nil { @@ -168,7 +168,7 @@ func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, local for rows.Next() { var id int64 var roomID string - var ts gomatrixserverlib.Timestamp + var ts spec.Timestamp var read bool var jsonStr string err = rows.Scan( @@ -198,12 +198,12 @@ func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, local return notifs, maxID, rows.Err() } -func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, filter tables.NotificationFilter) (count int64, err error) { +func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, filter tables.NotificationFilter) (count int64, err error) { err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, serverName, uint32(filter)).Scan(&count) return } -func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, err error) { +func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, roomID string) (total int64, highlight int64, err error) { err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, serverName, roomID).Scan(&total, &highlight) return } diff --git a/userapi/storage/sqlite3/openid_table.go b/userapi/storage/sqlite3/openid_table.go index f06429741..def0074d2 100644 --- a/userapi/storage/sqlite3/openid_table.go +++ b/userapi/storage/sqlite3/openid_table.go @@ -8,7 +8,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/tables" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" log "github.com/sirupsen/logrus" ) @@ -35,10 +35,10 @@ type openIDTokenStatements struct { db *sql.DB insertTokenStmt *sql.Stmt selectTokenStmt *sql.Stmt - serverName gomatrixserverlib.ServerName + serverName spec.ServerName } -func NewSQLiteOpenIDTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.OpenIDTable, error) { +func NewSQLiteOpenIDTable(db *sql.DB, serverName spec.ServerName) (tables.OpenIDTable, error) { s := &openIDTokenStatements{ db: db, serverName: serverName, @@ -58,7 +58,7 @@ func NewSQLiteOpenIDTable(db *sql.DB, serverName gomatrixserverlib.ServerName) ( func (s *openIDTokenStatements) InsertOpenIDToken( ctx context.Context, txn *sql.Tx, - token, localpart string, serverName gomatrixserverlib.ServerName, + token, localpart string, serverName spec.ServerName, expiresAtMS int64, ) (err error) { stmt := sqlutil.TxStmt(txn, s.insertTokenStmt) @@ -74,7 +74,7 @@ func (s *openIDTokenStatements) SelectOpenIDTokenAtrributes( ) (*api.OpenIDTokenAttributes, error) { var openIDTokenAttrs api.OpenIDTokenAttributes var localpart string - var serverName gomatrixserverlib.ServerName + var serverName spec.ServerName err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan( &localpart, &serverName, &openIDTokenAttrs.ExpiresAtMS, diff --git a/userapi/storage/sqlite3/profile_table.go b/userapi/storage/sqlite3/profile_table.go index 867026d7a..a20d7e848 100644 --- a/userapi/storage/sqlite3/profile_table.go +++ b/userapi/storage/sqlite3/profile_table.go @@ -23,7 +23,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/storage/tables" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const profilesSchema = ` @@ -88,7 +88,7 @@ func NewSQLiteProfilesTable(db *sql.DB, serverNoticesLocalpart string) (tables.P func (s *profilesStatements) InsertProfile( ctx context.Context, txn *sql.Tx, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, ) error { _, err := sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, serverName, "", "") return err @@ -96,7 +96,7 @@ func (s *profilesStatements) InsertProfile( func (s *profilesStatements) SelectProfileByLocalpart( ctx context.Context, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, ) (*authtypes.Profile, error) { var profile authtypes.Profile err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart, serverName).Scan( @@ -110,7 +110,7 @@ func (s *profilesStatements) SelectProfileByLocalpart( func (s *profilesStatements) SetAvatarURL( ctx context.Context, txn *sql.Tx, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, avatarURL string, ) (*authtypes.Profile, bool, error) { profile := &authtypes.Profile{ @@ -132,7 +132,7 @@ func (s *profilesStatements) SetAvatarURL( func (s *profilesStatements) SetDisplayName( ctx context.Context, txn *sql.Tx, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, displayName string, ) (*authtypes.Profile, bool, error) { profile := &authtypes.Profile{ diff --git a/userapi/storage/sqlite3/pusher_table.go b/userapi/storage/sqlite3/pusher_table.go index c9d451dc5..e09f9c78f 100644 --- a/userapi/storage/sqlite3/pusher_table.go +++ b/userapi/storage/sqlite3/pusher_table.go @@ -25,7 +25,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/tables" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) // See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers @@ -98,7 +98,7 @@ type pushersStatements struct { func (s *pushersStatements) InsertPusher( ctx context.Context, txn *sql.Tx, session_id int64, pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, ) error { _, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, serverName, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data) return err @@ -106,7 +106,7 @@ func (s *pushersStatements) InsertPusher( func (s *pushersStatements) SelectPushers( ctx context.Context, txn *sql.Tx, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, ) ([]api.Pusher, error) { pushers := []api.Pusher{} rows, err := s.selectPushersStmt.QueryContext(ctx, localpart, serverName) @@ -147,7 +147,7 @@ func (s *pushersStatements) SelectPushers( // deletePusher removes a single pusher by pushkey and user localpart. func (s *pushersStatements) DeletePusher( ctx context.Context, txn *sql.Tx, appid, pushkey, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, ) error { _, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart, serverName) return err diff --git a/userapi/storage/sqlite3/stale_device_lists.go b/userapi/storage/sqlite3/stale_device_lists.go index f078fc99f..5302899f4 100644 --- a/userapi/storage/sqlite3/stale_device_lists.go +++ b/userapi/storage/sqlite3/stale_device_lists.go @@ -21,6 +21,7 @@ import ( "time" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/userapi/storage/tables" @@ -83,11 +84,11 @@ func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, if err != nil { return err } - _, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, gomatrixserverlib.AsTimestamp(time.Now())) + _, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, spec.AsTimestamp(time.Now())) return err } -func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { +func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []spec.ServerName) ([]string, error) { // we only query for 1 domain or all domains so optimise for those use cases if len(domains) == 0 { rows, err := s.selectStaleDeviceListsStmt.QueryContext(ctx, true) diff --git a/userapi/storage/sqlite3/stats_table.go b/userapi/storage/sqlite3/stats_table.go index 72b3ba49d..71d80d4d4 100644 --- a/userapi/storage/sqlite3/stats_table.go +++ b/userapi/storage/sqlite3/stats_table.go @@ -20,7 +20,7 @@ import ( "strings" "time" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/sirupsen/logrus" "github.com/matrix-org/dendrite/internal" @@ -195,7 +195,7 @@ ON CONFLICT (localpart, device_id, timestamp) DO NOTHING const queryDBEngineVersion = "select sqlite_version();" type statsStatements struct { - serverName gomatrixserverlib.ServerName + serverName spec.ServerName db *sql.DB lastUpdate time.Time countUsersLastSeenAfterStmt *sql.Stmt @@ -209,7 +209,7 @@ type statsStatements struct { selectDailyMessagesStmt *sql.Stmt } -func NewSQLiteStatsTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.StatsTable, error) { +func NewSQLiteStatsTable(db *sql.DB, serverName spec.ServerName) (tables.StatsTable, error) { s := &statsStatements{ serverName: serverName, lastUpdate: time.Now(), @@ -298,8 +298,8 @@ func (s *statsStatements) registeredUserByType(ctx context.Context, txn *sql.Tx) params[i] = v // i: 0 1 2 => ($1, $2, $3) params[i+1+len(nonGuests)] = v // i: 4 5 6 => ($5, $6, $7) } - params[3] = api.AccountTypeGuest // $4 - params[7] = gomatrixserverlib.AsTimestamp(registeredAfter) // $8 + params[3] = api.AccountTypeGuest // $4 + params[7] = spec.AsTimestamp(registeredAfter) // $8 rows, err := stmt.QueryContext(ctx, params...) if err != nil { @@ -324,7 +324,7 @@ func (s *statsStatements) dailyUsers(ctx context.Context, txn *sql.Tx) (result i stmt := sqlutil.TxStmt(txn, s.countUsersLastSeenAfterStmt) lastSeenAfter := time.Now().AddDate(0, 0, -1) err = stmt.QueryRowContext(ctx, - gomatrixserverlib.AsTimestamp(lastSeenAfter), + spec.AsTimestamp(lastSeenAfter), ).Scan(&result) return } @@ -333,7 +333,7 @@ func (s *statsStatements) monthlyUsers(ctx context.Context, txn *sql.Tx) (result stmt := sqlutil.TxStmt(txn, s.countUsersLastSeenAfterStmt) lastSeenAfter := time.Now().AddDate(0, 0, -30) err = stmt.QueryRowContext(ctx, - gomatrixserverlib.AsTimestamp(lastSeenAfter), + spec.AsTimestamp(lastSeenAfter), ).Scan(&result) return } @@ -348,8 +348,8 @@ func (s *statsStatements) r30Users(ctx context.Context, txn *sql.Tx) (map[string diff := time.Hour * 24 * 30 rows, err := stmt.QueryContext(ctx, - gomatrixserverlib.AsTimestamp(lastSeenAfter), - gomatrixserverlib.AsTimestamp(lastSeenAfter), + spec.AsTimestamp(lastSeenAfter), + spec.AsTimestamp(lastSeenAfter), diff.Milliseconds(), ) if err != nil { @@ -386,8 +386,8 @@ func (s *statsStatements) r30UsersV2(ctx context.Context, txn *sql.Tx) (map[stri tomorrow := time.Now().Add(time.Hour * 24) rows, err := stmt.QueryContext(ctx, - gomatrixserverlib.AsTimestamp(sixtyDaysAgo), - gomatrixserverlib.AsTimestamp(tomorrow), + spec.AsTimestamp(sixtyDaysAgo), + spec.AsTimestamp(tomorrow), diff.Milliseconds(), ) if err != nil { @@ -482,9 +482,9 @@ func (s *statsStatements) UpdateUserDailyVisits( startTime = startTime.AddDate(0, 0, -1) } _, err := stmt.ExecContext(ctx, - gomatrixserverlib.AsTimestamp(startTime), - gomatrixserverlib.AsTimestamp(lastUpdate), - gomatrixserverlib.AsTimestamp(time.Now()), + spec.AsTimestamp(startTime), + spec.AsTimestamp(lastUpdate), + spec.AsTimestamp(time.Now()), ) if err == nil { s.lastUpdate = time.Now() @@ -494,13 +494,13 @@ func (s *statsStatements) UpdateUserDailyVisits( func (s *statsStatements) UpsertDailyStats( ctx context.Context, txn *sql.Tx, - serverName gomatrixserverlib.ServerName, stats types.MessageStats, + serverName spec.ServerName, stats types.MessageStats, activeRooms, activeE2EERooms int64, ) error { stmt := sqlutil.TxStmt(txn, s.upsertMessagesStmt) timestamp := time.Now().Truncate(time.Hour * 24) _, err := stmt.ExecContext(ctx, - gomatrixserverlib.AsTimestamp(timestamp), + spec.AsTimestamp(timestamp), serverName, stats.Messages, stats.SentMessages, stats.MessagesE2EE, stats.SentMessagesE2EE, activeRooms, activeE2EERooms, @@ -510,12 +510,12 @@ func (s *statsStatements) UpsertDailyStats( func (s *statsStatements) DailyRoomsMessages( ctx context.Context, txn *sql.Tx, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, ) (msgStats types.MessageStats, activeRooms, activeE2EERooms int64, err error) { stmt := sqlutil.TxStmt(txn, s.selectDailyMessagesStmt) timestamp := time.Now().Truncate(time.Hour * 24) - err = stmt.QueryRowContext(ctx, serverName, gomatrixserverlib.AsTimestamp(timestamp)). + err = stmt.QueryRowContext(ctx, serverName, spec.AsTimestamp(timestamp)). Scan(&msgStats.Messages, &msgStats.SentMessages, &msgStats.MessagesE2EE, &msgStats.SentMessagesE2EE, &activeRooms, &activeE2EERooms) if err != nil && err != sql.ErrNoRows { return msgStats, 0, 0, err diff --git a/userapi/storage/sqlite3/storage.go b/userapi/storage/sqlite3/storage.go index 0f3eeed1b..acd9678f2 100644 --- a/userapi/storage/sqlite3/storage.go +++ b/userapi/storage/sqlite3/storage.go @@ -20,19 +20,17 @@ import ( "fmt" "time" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/userapi/storage/shared" "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas" ) // NewUserDatabase creates a new accounts and profiles database -func NewUserDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (*shared.Database, error) { - db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()) +func NewUserDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperties *config.DatabaseOptions, serverName spec.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (*shared.Database, error) { + db, writer, err := conMan.Connection(dbProperties) if err != nil { return nil, err } @@ -49,7 +47,7 @@ func NewUserDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptio return deltas.UpServerNames(ctx, txn, serverName) }, }) - if err = m.Up(base.Context()); err != nil { + if err = m.Up(ctx); err != nil { return nil, err } @@ -109,7 +107,7 @@ func NewUserDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptio return deltas.UpServerNamesPopulate(ctx, txn, serverName) }, }) - if err = m.Up(base.Context()); err != nil { + if err = m.Up(ctx); err != nil { return nil, err } @@ -135,8 +133,8 @@ func NewUserDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptio }, nil } -func NewKeyDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*shared.KeyDatabase, error) { - db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()) +func NewKeyDatabase(conMan sqlutil.Connections, dbProperties *config.DatabaseOptions) (*shared.KeyDatabase, error) { + db, writer, err := conMan.Connection(dbProperties) if err != nil { return nil, err } diff --git a/userapi/storage/sqlite3/threepid_table.go b/userapi/storage/sqlite3/threepid_table.go index 2db7d5887..a83f80423 100644 --- a/userapi/storage/sqlite3/threepid_table.go +++ b/userapi/storage/sqlite3/threepid_table.go @@ -21,7 +21,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/storage/tables" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" ) @@ -81,7 +81,7 @@ func NewSQLiteThreePIDTable(db *sql.DB) (tables.ThreePIDTable, error) { func (s *threepidStatements) SelectLocalpartForThreePID( ctx context.Context, txn *sql.Tx, threepid string, medium string, -) (localpart string, serverName gomatrixserverlib.ServerName, err error) { +) (localpart string, serverName spec.ServerName, err error) { stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt) err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart, &serverName) if err == sql.ErrNoRows { @@ -92,7 +92,7 @@ func (s *threepidStatements) SelectLocalpartForThreePID( func (s *threepidStatements) SelectThreePIDsForLocalpart( ctx context.Context, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, ) (threepids []authtypes.ThreePID, err error) { rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart, serverName) if err != nil { @@ -117,7 +117,7 @@ func (s *threepidStatements) SelectThreePIDsForLocalpart( func (s *threepidStatements) InsertThreePID( ctx context.Context, txn *sql.Tx, threepid, medium, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, ) (err error) { stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt) _, err = stmt.ExecContext(ctx, threepid, medium, localpart, serverName) diff --git a/userapi/storage/storage.go b/userapi/storage/storage.go index 0329fb46a..39231b224 100644 --- a/userapi/storage/storage.go +++ b/userapi/storage/storage.go @@ -18,12 +18,13 @@ package storage import ( + "context" "fmt" "time" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib/spec" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/storage/postgres" "github.com/matrix-org/dendrite/userapi/storage/sqlite3" @@ -32,9 +33,10 @@ import ( // NewUserDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme) // and sets postgres connection parameters func NewUserDatabase( - base *base.BaseDendrite, + ctx context.Context, + conMan sqlutil.Connections, dbProperties *config.DatabaseOptions, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, @@ -42,9 +44,9 @@ func NewUserDatabase( ) (UserDatabase, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewUserDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart) + return sqlite3.NewUserDatabase(ctx, conMan, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart) case dbProperties.ConnectionString.IsPostgres(): - return postgres.NewDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart) + return postgres.NewDatabase(ctx, conMan, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart) default: return nil, fmt.Errorf("unexpected database type") } @@ -52,12 +54,12 @@ func NewUserDatabase( // NewKeyDatabase opens a new Postgres or Sqlite database (base on dataSourceName) scheme) // and sets postgres connection parameters. -func NewKeyDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (KeyDatabase, error) { +func NewKeyDatabase(conMan sqlutil.Connections, dbProperties *config.DatabaseOptions) (KeyDatabase, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewKeyDatabase(base, dbProperties) + return sqlite3.NewKeyDatabase(conMan, dbProperties) case dbProperties.ConnectionString.IsPostgres(): - return postgres.NewKeyDatabase(base, dbProperties) + return postgres.NewKeyDatabase(conMan, dbProperties) default: return nil, fmt.Errorf("unexpected database type") } diff --git a/userapi/storage/storage_test.go b/userapi/storage/storage_test.go index f52e7e17d..a46ee9ebb 100644 --- a/userapi/storage/storage_test.go +++ b/userapi/storage/storage_test.go @@ -9,8 +9,11 @@ import ( "testing" "time" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/userapi/types" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/stretchr/testify/assert" "golang.org/x/crypto/bcrypt" @@ -33,9 +36,9 @@ var ( ) func mustCreateUserDatabase(t *testing.T, dbType test.DBType) (storage.UserDatabase, func()) { - base, baseclose := testrig.CreateBaseDendrite(t, dbType) connStr, close := test.PrepareDBConnectionString(t, dbType) - db, err := storage.NewUserDatabase(base, &config.DatabaseOptions{ + cm := sqlutil.NewConnectionManager(nil, config.DatabaseOptions{}) + db, err := storage.NewUserDatabase(context.Background(), cm, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }, "localhost", bcrypt.MinCost, openIDLifetimeMS, loginTokenLifetime, "_server") if err != nil { @@ -43,7 +46,6 @@ func mustCreateUserDatabase(t *testing.T, dbType test.DBType) (storage.UserDatab } return db, func() { close() - baseclose() } } @@ -356,12 +358,12 @@ func Test_OpenID(t *testing.T) { expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + openIDLifetimeMS expires, err := db.CreateOpenIDToken(ctx, token, alice.ID) assert.NoError(t, err, "unable to create OpenID token") - assert.Equal(t, expiresAtMS, expires) + assert.InDelta(t, expiresAtMS, expires, 2) // 2ms leeway attributes, err := db.GetOpenIDTokenAttributes(ctx, token) assert.NoError(t, err, "unable to get OpenID token attributes") assert.Equal(t, alice.ID, attributes.UserID) - assert.Equal(t, expiresAtMS, attributes.ExpiresAtMS) + assert.InDelta(t, expiresAtMS, attributes.ExpiresAtMS, 2) // 2ms leeway }) } @@ -526,12 +528,12 @@ func Test_Notification(t *testing.T) { Actions: []*pushrules.Action{ {}, }, - Event: gomatrixserverlib.ClientEvent{ - Content: gomatrixserverlib.RawJSON("{}"), + Event: synctypes.ClientEvent{ + Content: spec.RawJSON("{}"), }, Read: false, RoomID: roomID, - TS: gomatrixserverlib.AsTimestamp(ts), + TS: spec.AsTimestamp(ts), } err = db.InsertNotification(ctx, aliceLocalpart, aliceDomain, eventID, uint64(i+1), nil, notification) assert.NoError(t, err, "unable to insert notification") @@ -576,8 +578,9 @@ func Test_Notification(t *testing.T) { } func mustCreateKeyDatabase(t *testing.T, dbType test.DBType) (storage.KeyDatabase, func()) { - base, close := testrig.CreateBaseDendrite(t, dbType) - db, err := storage.NewKeyDatabase(base, &base.Cfg.KeyServer.Database) + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + db, err := storage.NewKeyDatabase(cm, &cfg.KeyServer.Database) if err != nil { t.Fatalf("failed to create new database: %v", err) } diff --git a/userapi/storage/storage_wasm.go b/userapi/storage/storage_wasm.go index 163e3e173..cbadd98e9 100644 --- a/userapi/storage/storage_wasm.go +++ b/userapi/storage/storage_wasm.go @@ -15,19 +15,21 @@ package storage import ( + "context" "fmt" "time" - "github.com/matrix-org/dendrite/setup/base" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/storage/sqlite3" "github.com/matrix-org/gomatrixserverlib" ) -func NewUserAPIDatabase( - base *base.BaseDendrite, +func NewUserDatabase( + ctx context.Context, + conMan sqlutil.Connections, dbProperties *config.DatabaseOptions, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, @@ -35,7 +37,20 @@ func NewUserAPIDatabase( ) (UserDatabase, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewUserDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart) + return sqlite3.NewUserDatabase(ctx, conMan, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart) + case dbProperties.ConnectionString.IsPostgres(): + return nil, fmt.Errorf("can't use Postgres implementation") + default: + return nil, fmt.Errorf("unexpected database type") + } +} + +// NewKeyDatabase opens a new Postgres or Sqlite database (base on dataSourceName) scheme) +// and sets postgres connection parameters. +func NewKeyDatabase(conMan sqlutil.Connections, dbProperties *config.DatabaseOptions) (KeyDatabase, error) { + switch { + case dbProperties.ConnectionString.IsSQLite(): + return sqlite3.NewKeyDatabase(conMan, dbProperties) case dbProperties.ConnectionString.IsPostgres(): return nil, fmt.Errorf("can't use Postgres implementation") default: diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go index 693e73038..3c6214e7c 100644 --- a/userapi/storage/tables/interface.go +++ b/userapi/storage/tables/interface.go @@ -22,38 +22,40 @@ import ( "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/userapi/types" ) type AccountDataTable interface { - InsertAccountData(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID, dataType string, content json.RawMessage) error - SelectAccountData(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (map[string]json.RawMessage, map[string]map[string]json.RawMessage, error) - SelectAccountDataByType(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID, dataType string) (data json.RawMessage, err error) + InsertAccountData(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, roomID, dataType string, content json.RawMessage) error + SelectAccountData(ctx context.Context, localpart string, serverName spec.ServerName) (map[string]json.RawMessage, map[string]map[string]json.RawMessage, error) + SelectAccountDataByType(ctx context.Context, localpart string, serverName spec.ServerName, roomID, dataType string) (data json.RawMessage, err error) } type AccountsTable interface { - InsertAccount(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, hash, appserviceID string, accountType api.AccountType) (*api.Account, error) - UpdatePassword(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, passwordHash string) (err error) - DeactivateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (err error) - SelectPasswordHash(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (hash string, err error) - SelectAccountByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*api.Account, error) - SelectNewNumericLocalpart(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (id int64, err error) + InsertAccount(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, hash, appserviceID string, accountType api.AccountType) (*api.Account, error) + UpdatePassword(ctx context.Context, localpart string, serverName spec.ServerName, passwordHash string) (err error) + DeactivateAccount(ctx context.Context, localpart string, serverName spec.ServerName) (err error) + SelectPasswordHash(ctx context.Context, localpart string, serverName spec.ServerName) (hash string, err error) + SelectAccountByLocalpart(ctx context.Context, localpart string, serverName spec.ServerName) (*api.Account, error) + SelectNewNumericLocalpart(ctx context.Context, txn *sql.Tx, serverName spec.ServerName) (id int64, err error) } type DevicesTable interface { - InsertDevice(ctx context.Context, txn *sql.Tx, id, localpart string, serverName gomatrixserverlib.ServerName, accessToken string, displayName *string, ipAddr, userAgent string) (*api.Device, error) - InsertDeviceWithSessionID(ctx context.Context, txn *sql.Tx, id, localpart string, serverName gomatrixserverlib.ServerName, accessToken string, displayName *string, ipAddr, userAgent string, sessionID int64) (*api.Device, error) - DeleteDevice(ctx context.Context, txn *sql.Tx, id, localpart string, serverName gomatrixserverlib.ServerName) error - DeleteDevices(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, devices []string) error - DeleteDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, exceptDeviceID string) error - UpdateDeviceName(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, deviceID string, displayName *string) error + InsertDevice(ctx context.Context, txn *sql.Tx, id, localpart string, serverName spec.ServerName, accessToken string, displayName *string, ipAddr, userAgent string) (*api.Device, error) + InsertDeviceWithSessionID(ctx context.Context, txn *sql.Tx, id, localpart string, serverName spec.ServerName, accessToken string, displayName *string, ipAddr, userAgent string, sessionID int64) (*api.Device, error) + DeleteDevice(ctx context.Context, txn *sql.Tx, id, localpart string, serverName spec.ServerName) error + DeleteDevices(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, devices []string) error + DeleteDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, exceptDeviceID string) error + UpdateDeviceName(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, deviceID string, displayName *string) error SelectDeviceByToken(ctx context.Context, accessToken string) (*api.Device, error) - SelectDeviceByID(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID string) (*api.Device, error) - SelectDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, exceptDeviceID string) ([]api.Device, error) + SelectDeviceByID(ctx context.Context, localpart string, serverName spec.ServerName, deviceID string) (*api.Device, error) + SelectDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, exceptDeviceID string) ([]api.Device, error) SelectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) - UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error + UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, deviceID, ipAddr, userAgent string) error } type KeyBackupTable interface { @@ -80,47 +82,47 @@ type LoginTokenTable interface { } type OpenIDTable interface { - InsertOpenIDToken(ctx context.Context, txn *sql.Tx, token, localpart string, serverName gomatrixserverlib.ServerName, expiresAtMS int64) (err error) + InsertOpenIDToken(ctx context.Context, txn *sql.Tx, token, localpart string, serverName spec.ServerName, expiresAtMS int64) (err error) SelectOpenIDTokenAtrributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error) } type ProfileTable interface { - InsertProfile(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName) error - SelectProfileByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*authtypes.Profile, error) - SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, avatarURL string) (*authtypes.Profile, bool, error) - SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, displayName string) (*authtypes.Profile, bool, error) + InsertProfile(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName) error + SelectProfileByLocalpart(ctx context.Context, localpart string, serverName spec.ServerName) (*authtypes.Profile, error) + SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, avatarURL string) (*authtypes.Profile, bool, error) + SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, displayName string) (*authtypes.Profile, bool, error) SelectProfilesBySearch(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error) } type ThreePIDTable interface { - SelectLocalpartForThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (localpart string, serverName gomatrixserverlib.ServerName, err error) - SelectThreePIDsForLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (threepids []authtypes.ThreePID, err error) - InsertThreePID(ctx context.Context, txn *sql.Tx, threepid, medium, localpart string, serverName gomatrixserverlib.ServerName) (err error) + SelectLocalpartForThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (localpart string, serverName spec.ServerName, err error) + SelectThreePIDsForLocalpart(ctx context.Context, localpart string, serverName spec.ServerName) (threepids []authtypes.ThreePID, err error) + InsertThreePID(ctx context.Context, txn *sql.Tx, threepid, medium, localpart string, serverName spec.ServerName) (err error) DeleteThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error) } type PusherTable interface { - InsertPusher(ctx context.Context, txn *sql.Tx, session_id int64, pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string, serverName gomatrixserverlib.ServerName) error - SelectPushers(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName) ([]api.Pusher, error) - DeletePusher(ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string, serverName gomatrixserverlib.ServerName) error + InsertPusher(ctx context.Context, txn *sql.Tx, session_id int64, pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string, serverName spec.ServerName) error + SelectPushers(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName) ([]api.Pusher, error) + DeletePusher(ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string, serverName spec.ServerName) error DeletePushers(ctx context.Context, txn *sql.Tx, appid, pushkey string) error } type NotificationTable interface { Clean(ctx context.Context, txn *sql.Tx) error - Insert(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, highlight bool, n *api.Notification) error - DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, _ error) - UpdateRead(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, v bool) (affected bool, _ error) - Select(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter NotificationFilter) ([]*api.Notification, int64, error) - SelectCount(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, filter NotificationFilter) (int64, error) - SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, _ error) + Insert(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, eventID string, pos uint64, highlight bool, n *api.Notification) error + DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, roomID string, pos uint64) (affected bool, _ error) + UpdateRead(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, roomID string, pos uint64, v bool) (affected bool, _ error) + Select(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, fromID int64, limit int, filter NotificationFilter) ([]*api.Notification, int64, error) + SelectCount(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, filter NotificationFilter) (int64, error) + SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, roomID string) (total int64, highlight int64, _ error) } type StatsTable interface { UserStatistics(ctx context.Context, txn *sql.Tx) (*types.UserStatistics, *types.DatabaseEngine, error) - DailyRoomsMessages(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (msgStats types.MessageStats, activeRooms, activeE2EERooms int64, err error) + DailyRoomsMessages(ctx context.Context, txn *sql.Tx, serverName spec.ServerName) (msgStats types.MessageStats, activeRooms, activeE2EERooms int64, err error) UpdateUserDailyVisits(ctx context.Context, txn *sql.Tx, startTime, lastUpdate time.Time) error - UpsertDailyStats(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, stats types.MessageStats, activeRooms, activeE2EERooms int64) error + UpsertDailyStats(ctx context.Context, txn *sql.Tx, serverName spec.ServerName, stats types.MessageStats, activeRooms, activeE2EERooms int64) error } type NotificationFilter uint32 @@ -175,17 +177,17 @@ type KeyChanges interface { type StaleDeviceLists interface { InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error - SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) + SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []spec.ServerName) ([]string, error) DeleteStaleDeviceLists(ctx context.Context, txn *sql.Tx, userIDs []string) error } type CrossSigningKeys interface { SelectCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string) (r types.CrossSigningKeyMap, err error) - UpsertCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string, keyType gomatrixserverlib.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes) error + UpsertCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string, keyType fclient.CrossSigningKeyPurpose, keyData spec.Base64Bytes) error } type CrossSigningSigs interface { SelectCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (r types.CrossSigningSigMap, err error) - UpsertCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error + UpsertCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature spec.Base64Bytes) error DeleteCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, targetUserID string, targetKeyID gomatrixserverlib.KeyID) error } diff --git a/userapi/storage/tables/stale_device_lists_test.go b/userapi/storage/tables/stale_device_lists_test.go index b9bdafdaa..09924eb08 100644 --- a/userapi/storage/tables/stale_device_lists_test.go +++ b/userapi/storage/tables/stale_device_lists_test.go @@ -6,7 +6,7 @@ import ( "github.com/matrix-org/dendrite/userapi/storage/postgres" "github.com/matrix-org/dendrite/userapi/storage/sqlite3" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" @@ -57,7 +57,7 @@ func TestStaleDeviceLists(t *testing.T) { // Query one server wantStaleUsers := []string{alice.ID, bob.ID} - gotStaleUsers, err := tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"}) + gotStaleUsers, err := tab.SelectUserIDsWithStaleDeviceLists(ctx, []spec.ServerName{"test"}) if err != nil { t.Fatalf("failed to query stale device lists: %s", err) } @@ -67,7 +67,7 @@ func TestStaleDeviceLists(t *testing.T) { // Query all servers wantStaleUsers = []string{alice.ID, bob.ID, charlie} - gotStaleUsers, err = tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{}) + gotStaleUsers, err = tab.SelectUserIDsWithStaleDeviceLists(ctx, []spec.ServerName{}) if err != nil { t.Fatalf("failed to query stale device lists: %s", err) } @@ -82,7 +82,7 @@ func TestStaleDeviceLists(t *testing.T) { } // Verify we don't get anything back after deleting - gotStaleUsers, err = tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"}) + gotStaleUsers, err = tab.SelectUserIDsWithStaleDeviceLists(ctx, []spec.ServerName{"test"}) if err != nil { t.Fatalf("failed to query stale device lists: %s", err) } diff --git a/userapi/storage/tables/stats_table_test.go b/userapi/storage/tables/stats_table_test.go index 969bc5303..61fe02663 100644 --- a/userapi/storage/tables/stats_table_test.go +++ b/userapi/storage/tables/stats_table_test.go @@ -8,7 +8,7 @@ import ( "testing" "time" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/matrix-org/dendrite/internal/sqlutil" @@ -79,7 +79,7 @@ func mustMakeAccountAndDevice( accDB tables.AccountsTable, devDB tables.DevicesTable, localpart string, - serverName gomatrixserverlib.ServerName, // nolint:unparam + serverName spec.ServerName, // nolint:unparam accType api.AccountType, userAgent string, ) { @@ -108,7 +108,7 @@ func mustUpdateDeviceLastSeen( timestamp time.Time, ) { t.Helper() - _, err := db.ExecContext(ctx, "UPDATE userapi_devices SET last_seen_ts = $1 WHERE localpart = $2", gomatrixserverlib.AsTimestamp(timestamp), localpart) + _, err := db.ExecContext(ctx, "UPDATE userapi_devices SET last_seen_ts = $1 WHERE localpart = $2", spec.AsTimestamp(timestamp), localpart) if err != nil { t.Fatalf("unable to update device last seen") } @@ -121,7 +121,7 @@ func mustUserUpdateRegistered( localpart string, timestamp time.Time, ) { - _, err := db.ExecContext(ctx, "UPDATE userapi_accounts SET created_ts = $1 WHERE localpart = $2", gomatrixserverlib.AsTimestamp(timestamp), localpart) + _, err := db.ExecContext(ctx, "UPDATE userapi_accounts SET created_ts = $1 WHERE localpart = $2", spec.AsTimestamp(timestamp), localpart) if err != nil { t.Fatalf("unable to update device last seen") } diff --git a/userapi/types/storage.go b/userapi/types/storage.go index 7fb90454e..2c918847d 100644 --- a/userapi/types/storage.go +++ b/userapi/types/storage.go @@ -18,6 +18,8 @@ import ( "math" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" ) const ( @@ -29,22 +31,22 @@ const ( // KeyTypePurposeToInt maps a purpose to an integer, which is used in the // database to reduce the amount of space taken up by this column. -var KeyTypePurposeToInt = map[gomatrixserverlib.CrossSigningKeyPurpose]int16{ - gomatrixserverlib.CrossSigningKeyPurposeMaster: 1, - gomatrixserverlib.CrossSigningKeyPurposeSelfSigning: 2, - gomatrixserverlib.CrossSigningKeyPurposeUserSigning: 3, +var KeyTypePurposeToInt = map[fclient.CrossSigningKeyPurpose]int16{ + fclient.CrossSigningKeyPurposeMaster: 1, + fclient.CrossSigningKeyPurposeSelfSigning: 2, + fclient.CrossSigningKeyPurposeUserSigning: 3, } // KeyTypeIntToPurpose maps an integer to a purpose, which is used in the // database to reduce the amount of space taken up by this column. -var KeyTypeIntToPurpose = map[int16]gomatrixserverlib.CrossSigningKeyPurpose{ - 1: gomatrixserverlib.CrossSigningKeyPurposeMaster, - 2: gomatrixserverlib.CrossSigningKeyPurposeSelfSigning, - 3: gomatrixserverlib.CrossSigningKeyPurposeUserSigning, +var KeyTypeIntToPurpose = map[int16]fclient.CrossSigningKeyPurpose{ + 1: fclient.CrossSigningKeyPurposeMaster, + 2: fclient.CrossSigningKeyPurposeSelfSigning, + 3: fclient.CrossSigningKeyPurposeUserSigning, } // Map of purpose -> public key -type CrossSigningKeyMap map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.Base64Bytes +type CrossSigningKeyMap map[fclient.CrossSigningKeyPurpose]spec.Base64Bytes // Map of user ID -> key ID -> signature -type CrossSigningSigMap map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes +type CrossSigningSigMap map[string]map[gomatrixserverlib.KeyID]spec.Base64Bytes diff --git a/userapi/userapi.go b/userapi/userapi.go index 826bd7213..6dcbc121f 100644 --- a/userapi/userapi.go +++ b/userapi/userapi.go @@ -18,10 +18,13 @@ import ( "time" fedsenderapi "github.com/matrix-org/dendrite/federationapi/api" + "github.com/matrix-org/dendrite/internal/pushgateway" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/process" "github.com/sirupsen/logrus" rsapi "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/consumers" @@ -34,30 +37,33 @@ import ( // NewInternalAPI returns a concrete implementation of the internal API. Callers // can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. func NewInternalAPI( - base *base.BaseDendrite, + processContext *process.ProcessContext, + dendriteCfg *config.Dendrite, + cm sqlutil.Connections, + natsInstance *jetstream.NATSInstance, rsAPI rsapi.UserRoomserverAPI, fedClient fedsenderapi.KeyserverFederationAPI, ) *internal.UserInternalAPI { - cfg := &base.Cfg.UserAPI - js, _ := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) - appServices := base.Cfg.Derived.ApplicationServices + js, _ := natsInstance.Prepare(processContext, &dendriteCfg.Global.JetStream) + appServices := dendriteCfg.Derived.ApplicationServices - pgClient := base.PushGatewayHTTPClient() + pgClient := pushgateway.NewHTTPClient(dendriteCfg.UserAPI.PushGatewayDisableTLSValidation) db, err := storage.NewUserDatabase( - base, - &cfg.AccountDatabase, - cfg.Matrix.ServerName, - cfg.BCryptCost, - cfg.OpenIDTokenLifetimeMS, + processContext.Context(), + cm, + &dendriteCfg.UserAPI.AccountDatabase, + dendriteCfg.Global.ServerName, + dendriteCfg.UserAPI.BCryptCost, + dendriteCfg.UserAPI.OpenIDTokenLifetimeMS, api.DefaultLoginTokenLifetime, - cfg.Matrix.ServerNotices.LocalPart, + dendriteCfg.UserAPI.Matrix.ServerNotices.LocalPart, ) if err != nil { logrus.WithError(err).Panicf("failed to connect to accounts db") } - keyDB, err := storage.NewKeyDatabase(base, &base.Cfg.KeyServer.Database) + keyDB, err := storage.NewKeyDatabase(cm, &dendriteCfg.KeyServer.Database) if err != nil { logrus.WithError(err).Panicf("failed to connect to key db") } @@ -68,11 +74,11 @@ func NewInternalAPI( // it's handled by clientapi, and hence uses its topic. When user // API handles it for all account data, we can remove it from // here. - cfg.Matrix.JetStream.Prefixed(jetstream.OutputClientData), - cfg.Matrix.JetStream.Prefixed(jetstream.OutputNotificationData), + dendriteCfg.Global.JetStream.Prefixed(jetstream.OutputClientData), + dendriteCfg.Global.JetStream.Prefixed(jetstream.OutputNotificationData), ) keyChangeProducer := &producers.KeyChange{ - Topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputKeyChangeEvent), + Topic: dendriteCfg.Global.JetStream.Prefixed(jetstream.OutputKeyChangeEvent), JetStream: js, DB: keyDB, } @@ -82,15 +88,15 @@ func NewInternalAPI( KeyDatabase: keyDB, SyncProducer: syncProducer, KeyChangeProducer: keyChangeProducer, - Config: cfg, + Config: &dendriteCfg.UserAPI, AppServices: appServices, RSAPI: rsAPI, - DisableTLSValidation: cfg.PushGatewayDisableTLSValidation, + DisableTLSValidation: dendriteCfg.UserAPI.PushGatewayDisableTLSValidation, PgClient: pgClient, FedClient: fedClient, } - updater := internal.NewDeviceListUpdater(base.ProcessContext, keyDB, userAPI, keyChangeProducer, fedClient, 8, rsAPI, cfg.Matrix.ServerName) // 8 workers TODO: configurable + updater := internal.NewDeviceListUpdater(processContext, keyDB, userAPI, keyChangeProducer, fedClient, 8, rsAPI, dendriteCfg.Global.ServerName) // 8 workers TODO: configurable userAPI.Updater = updater // Remove users which we don't share a room with anymore if err := updater.CleanUp(); err != nil { @@ -104,28 +110,28 @@ func NewInternalAPI( }() dlConsumer := consumers.NewDeviceListUpdateConsumer( - base.ProcessContext, cfg, js, updater, + processContext, &dendriteCfg.UserAPI, js, updater, ) if err := dlConsumer.Start(); err != nil { logrus.WithError(err).Panic("failed to start device list consumer") } sigConsumer := consumers.NewSigningKeyUpdateConsumer( - base.ProcessContext, cfg, js, userAPI, + processContext, &dendriteCfg.UserAPI, js, userAPI, ) if err := sigConsumer.Start(); err != nil { logrus.WithError(err).Panic("failed to start signing key consumer") } receiptConsumer := consumers.NewOutputReceiptEventConsumer( - base.ProcessContext, cfg, js, db, syncProducer, pgClient, + processContext, &dendriteCfg.UserAPI, js, db, syncProducer, pgClient, ) if err := receiptConsumer.Start(); err != nil { logrus.WithError(err).Panic("failed to start user API receipt consumer") } eventConsumer := consumers.NewOutputRoomEventConsumer( - base.ProcessContext, cfg, js, db, pgClient, rsAPI, syncProducer, + processContext, &dendriteCfg.UserAPI, js, db, pgClient, rsAPI, syncProducer, ) if err := eventConsumer.Start(); err != nil { logrus.WithError(err).Panic("failed to start user API streamed event consumer") @@ -134,15 +140,15 @@ func NewInternalAPI( var cleanOldNotifs func() cleanOldNotifs = func() { logrus.Infof("Cleaning old notifications") - if err := db.DeleteOldNotifications(base.Context()); err != nil { + if err := db.DeleteOldNotifications(processContext.Context()); err != nil { logrus.WithError(err).Error("Failed to clean old notifications") } time.AfterFunc(time.Hour, cleanOldNotifs) } time.AfterFunc(time.Minute, cleanOldNotifs) - if base.Cfg.Global.ReportStats.Enabled { - go util.StartPhoneHomeCollector(time.Now(), base.Cfg, db) + if dendriteCfg.Global.ReportStats.Enabled { + go util.StartPhoneHomeCollector(time.Now(), dendriteCfg, db) } return userAPI diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index 01e491cb6..45762a7d8 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -22,8 +22,13 @@ import ( "testing" "time" + api2 "github.com/matrix-org/dendrite/appservice/api" + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/producers" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/nats-io/nats.go" "golang.org/x/crypto/bcrypt" @@ -37,7 +42,7 @@ import ( ) const ( - serverName = gomatrixserverlib.ServerName("example.com") + serverName = spec.ServerName("example.com") ) type apiTestOpts struct { @@ -67,32 +72,25 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType, pub if opts.loginTokenLifetime == 0 { opts.loginTokenLifetime = api.DefaultLoginTokenLifetime * time.Millisecond } - base, baseclose := testrig.CreateBaseDendrite(t, dbType) - connStr, close := test.PrepareDBConnectionString(t, dbType) + cfg, ctx, close := testrig.CreateConfig(t, dbType) sName := serverName if opts.serverName != "" { - sName = gomatrixserverlib.ServerName(opts.serverName) + sName = spec.ServerName(opts.serverName) } - accountDB, err := storage.NewUserDatabase(base, &config.DatabaseOptions{ - ConnectionString: config.DataSource(connStr), - }, sName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS, opts.loginTokenLifetime, "") + cm := sqlutil.NewConnectionManager(ctx, cfg.Global.DatabaseOptions) + + accountDB, err := storage.NewUserDatabase(ctx.Context(), cm, &cfg.UserAPI.AccountDatabase, sName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS, opts.loginTokenLifetime, "") if err != nil { t.Fatalf("failed to create account DB: %s", err) } - keyDB, err := storage.NewKeyDatabase(base, &config.DatabaseOptions{ - ConnectionString: config.DataSource(connStr), - }) + keyDB, err := storage.NewKeyDatabase(cm, &cfg.KeyServer.Database) if err != nil { t.Fatalf("failed to create key DB: %s", err) } - cfg := &config.UserAPI{ - Matrix: &config.Global{ - SigningIdentity: gomatrixserverlib.SigningIdentity{ - ServerName: sName, - }, - }, + cfg.Global.SigningIdentity = fclient.SigningIdentity{ + ServerName: sName, } if publisher == nil { @@ -104,12 +102,11 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType, pub return &internal.UserInternalAPI{ DB: accountDB, KeyDatabase: keyDB, - Config: cfg, + Config: &cfg.UserAPI, SyncProducer: syncProducer, KeyChangeProducer: keyChangeProducer, }, accountDB, func() { close() - baseclose() } } @@ -118,33 +115,26 @@ func TestQueryProfile(t *testing.T) { aliceDisplayName := "Alice" testCases := []struct { - req api.QueryProfileRequest - wantRes api.QueryProfileResponse + userID string + wantRes *authtypes.Profile wantErr error }{ { - req: api.QueryProfileRequest{ - UserID: fmt.Sprintf("@alice:%s", serverName), - }, - wantRes: api.QueryProfileResponse{ - UserExists: true, - AvatarURL: aliceAvatarURL, + userID: fmt.Sprintf("@alice:%s", serverName), + wantRes: &authtypes.Profile{ + Localpart: "alice", DisplayName: aliceDisplayName, + AvatarURL: aliceAvatarURL, + ServerName: string(serverName), }, }, { - req: api.QueryProfileRequest{ - UserID: fmt.Sprintf("@bob:%s", serverName), - }, - wantRes: api.QueryProfileResponse{ - UserExists: false, - }, + userID: fmt.Sprintf("@bob:%s", serverName), + wantErr: api2.ErrProfileNotExists, }, { - req: api.QueryProfileRequest{ - UserID: "@alice:wrongdomain.com", - }, - wantErr: fmt.Errorf("wrong domain"), + userID: "@alice:wrongdomain.com", + wantErr: api2.ErrProfileNotExists, }, } @@ -154,14 +144,14 @@ func TestQueryProfile(t *testing.T) { mode = "HTTP" } for _, tc := range testCases { - var gotRes api.QueryProfileResponse - gotErr := testAPI.QueryProfile(context.TODO(), &tc.req, &gotRes) + + profile, gotErr := testAPI.QueryProfile(context.TODO(), tc.userID) if tc.wantErr == nil && gotErr != nil || tc.wantErr != nil && gotErr == nil { t.Errorf("QueryProfile %s error, got %s want %s", mode, gotErr, tc.wantErr) continue } - if !reflect.DeepEqual(tc.wantRes, gotRes) { - t.Errorf("QueryProfile %s response got %+v want %+v", mode, gotRes, tc.wantRes) + if !reflect.DeepEqual(tc.wantRes, profile) { + t.Errorf("QueryProfile %s response got %+v want %+v", mode, profile, tc.wantRes) } } } diff --git a/userapi/util/devices.go b/userapi/util/devices.go index 31617d8c1..117da08ea 100644 --- a/userapi/util/devices.go +++ b/userapi/util/devices.go @@ -7,7 +7,7 @@ import ( "github.com/matrix-org/dendrite/internal/pushgateway" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" log "github.com/sirupsen/logrus" ) @@ -19,7 +19,7 @@ type PusherDevice struct { } // GetPushDevices pushes to the configured devices of a local user. -func GetPushDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, tweaks map[string]interface{}, db storage.UserDatabase) ([]*PusherDevice, error) { +func GetPushDevices(ctx context.Context, localpart string, serverName spec.ServerName, tweaks map[string]interface{}, db storage.UserDatabase) ([]*PusherDevice, error) { pushers, err := db.GetPushers(ctx, localpart, serverName) if err != nil { return nil, fmt.Errorf("db.GetPushers: %w", err) diff --git a/userapi/util/notify.go b/userapi/util/notify.go index 08d1371d6..45d37525c 100644 --- a/userapi/util/notify.go +++ b/userapi/util/notify.go @@ -8,7 +8,7 @@ import ( "github.com/matrix-org/dendrite/internal/pushgateway" "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/dendrite/userapi/storage/tables" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" log "github.com/sirupsen/logrus" ) @@ -17,7 +17,7 @@ import ( // a single goroutine is started when talking to the Push // gateways. There is no way to know when the background goroutine has // finished. -func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, localpart string, serverName gomatrixserverlib.ServerName, db storage.UserDatabase) error { +func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, localpart string, serverName spec.ServerName, db storage.UserDatabase) error { pusherDevices, err := GetPushDevices(ctx, localpart, serverName, nil, db) if err != nil { return err diff --git a/userapi/util/notify_test.go b/userapi/util/notify_test.go index 421852d3f..d6cbad7db 100644 --- a/userapi/util/notify_test.go +++ b/userapi/util/notify_test.go @@ -8,6 +8,8 @@ import ( "testing" "time" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "golang.org/x/crypto/bcrypt" @@ -15,7 +17,6 @@ import ( "github.com/matrix-org/dendrite/internal/pushgateway" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" - "github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage" userUtil "github.com/matrix-org/dendrite/userapi/util" @@ -77,9 +78,8 @@ func TestNotifyUserCountsAsync(t *testing.T) { // Create DB and Dendrite base connStr, close := test.PrepareDBConnectionString(t, dbType) defer close() - base, _, _ := testrig.Base(nil) - defer base.Close() - db, err := storage.NewUserDatabase(base, &config.DatabaseOptions{ + cm := sqlutil.NewConnectionManager(nil, config.DatabaseOptions{}) + db, err := storage.NewUserDatabase(ctx, cm, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }, "test", bcrypt.MinCost, 0, 0, "") if err != nil { @@ -100,7 +100,7 @@ func TestNotifyUserCountsAsync(t *testing.T) { // Insert a dummy event if err := db.InsertNotification(ctx, aliceLocalpart, serverName, dummyEvent.EventID(), 0, nil, &api.Notification{ - Event: gomatrixserverlib.HeaderedToClientEvent(dummyEvent, gomatrixserverlib.FormatAll), + Event: synctypes.HeaderedToClientEvent(dummyEvent, synctypes.FormatAll), }); err != nil { t.Error(err) } diff --git a/userapi/util/phonehomestats.go b/userapi/util/phonehomestats.go index 21035e045..4bf9a5d88 100644 --- a/userapi/util/phonehomestats.go +++ b/userapi/util/phonehomestats.go @@ -24,18 +24,18 @@ import ( "syscall" "time" - "github.com/matrix-org/gomatrixserverlib" "github.com/sirupsen/logrus" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/storage" + "github.com/matrix-org/gomatrixserverlib/spec" ) type phoneHomeStats struct { prevData timestampToRUUsage stats map[string]interface{} - serverName gomatrixserverlib.ServerName + serverName spec.ServerName startTime time.Time cfg *config.Dendrite db storage.Statistics diff --git a/userapi/util/phonehomestats_test.go b/userapi/util/phonehomestats_test.go index 5f626b5bc..191a35c04 100644 --- a/userapi/util/phonehomestats_test.go +++ b/userapi/util/phonehomestats_test.go @@ -7,10 +7,10 @@ import ( "testing" "time" + "github.com/matrix-org/dendrite/internal/sqlutil" "golang.org/x/crypto/bcrypt" "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/dendrite/userapi/storage" @@ -18,12 +18,10 @@ import ( func TestCollect(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - b, _, _ := testrig.Base(nil) - connStr, closeDB := test.PrepareDBConnectionString(t, dbType) + cfg, processCtx, closeDB := testrig.CreateConfig(t, dbType) defer closeDB() - db, err := storage.NewUserDatabase(b, &config.DatabaseOptions{ - ConnectionString: config.DataSource(connStr), - }, "localhost", bcrypt.MinCost, 1000, 1000, "") + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + db, err := storage.NewUserDatabase(processCtx.Context(), cm, &cfg.UserAPI.AccountDatabase, "localhost", bcrypt.MinCost, 1000, 1000, "") if err != nil { t.Error(err) } @@ -62,12 +60,12 @@ func TestCollect(t *testing.T) { })) defer srv.Close() - b.Cfg.Global.ReportStats.Endpoint = srv.URL + cfg.Global.ReportStats.Endpoint = srv.URL stats := phoneHomeStats{ prevData: timestampToRUUsage{}, serverName: "localhost", startTime: time.Now(), - cfg: b.Cfg, + cfg: cfg, db: db, isMonolith: false, client: &http.Client{Timeout: time.Second},