diff --git a/.github/workflows/dendrite.yml b/.github/workflows/dendrite.yml index 2f615a6a4..033c5864b 100644 --- a/.github/workflows/dendrite.yml +++ b/.github/workflows/dendrite.yml @@ -4,7 +4,13 @@ on: push: branches: - main + paths: + - '**.go' # only execute on changes to go files + - '.github/workflows/**' # or workflow changes pull_request: + paths: + - '**.go' + - '.github/workflows/**' release: types: [published] workflow_dispatch: diff --git a/README.md b/README.md index ba6960f35..295203eb4 100644 --- a/README.md +++ b/README.md @@ -71,10 +71,10 @@ $ ./bin/generate-keys --tls-cert server.crt --tls-key server.key # Copy and modify the config file - you'll need to set a server name and paths to the keys # at the very least, along with setting up the database connection strings. -$ cp dendrite-sample.monolith.yaml dendrite.yaml +$ cp dendrite-sample.yaml dendrite.yaml # Build and run the server: -$ ./bin/dendrite-monolith-server --tls-cert server.crt --tls-key server.key --config dendrite.yaml +$ ./bin/dendrite --tls-cert server.crt --tls-key server.key --config dendrite.yaml # Create an user account (add -admin for an admin user). # Specify the localpart only, e.g. 'alice' for '@alice:domain.com' diff --git a/appservice/consumers/roomserver.go b/appservice/consumers/roomserver.go index ac68f4bd4..528de63e8 100644 --- a/appservice/consumers/roomserver.go +++ b/appservice/consumers/roomserver.go @@ -122,6 +122,7 @@ func (s *OutputRoomEventConsumer) onMessage( if len(output.NewRoomEvent.AddsStateEventIDs) > 0 { newEventID := output.NewRoomEvent.Event.EventID() eventsReq := &api.QueryEventsByIDRequest{ + RoomID: output.NewRoomEvent.Event.RoomID(), EventIDs: make([]string, 0, len(output.NewRoomEvent.AddsStateEventIDs)), } eventsRes := &api.QueryEventsByIDResponse{} diff --git a/clientapi/routing/redaction.go b/clientapi/routing/redaction.go index 7841b3b07..f86bbc8fd 100644 --- a/clientapi/routing/redaction.go +++ b/clientapi/routing/redaction.go @@ -57,7 +57,7 @@ func SendRedaction( } } - ev := roomserverAPI.GetEvent(req.Context(), rsAPI, eventID) + ev := roomserverAPI.GetEvent(req.Context(), rsAPI, roomID, eventID) if ev == nil { return util.JSONResponse{ Code: 400, diff --git a/cmd/dendrite/main.go b/cmd/dendrite/main.go index e8ff0a478..1ae348cfa 100644 --- a/cmd/dendrite/main.go +++ b/cmd/dendrite/main.go @@ -16,6 +16,7 @@ package main import ( "flag" + "io/fs" "github.com/sirupsen/logrus" @@ -30,6 +31,12 @@ import ( ) var ( + unixSocket = flag.String("unix-socket", "", + "EXPERIMENTAL(unstable): The HTTP listening unix socket for the server (disables http[s]-bind-address feature)", + ) + unixSocketPermission = flag.Int("unix-socket-permission", 0755, + "EXPERIMENTAL(unstable): The HTTP listening unix socket permission for the server", + ) httpBindAddr = flag.String("http-bind-address", ":8008", "The HTTP listening port for the server") httpsBindAddr = flag.String("https-bind-address", ":8448", "The HTTPS listening port for the server") certFile = flag.String("tls-cert", "", "The PEM formatted X509 certificate to use for TLS") @@ -38,8 +45,23 @@ var ( func main() { cfg := setup.ParseFlags(true) - httpAddr := config.HTTPAddress("http://" + *httpBindAddr) - httpsAddr := config.HTTPAddress("https://" + *httpsBindAddr) + httpAddr := config.ServerAddress{} + httpsAddr := config.ServerAddress{} + if *unixSocket == "" { + http, err := config.HTTPAddress("http://" + *httpBindAddr) + if err != nil { + logrus.WithError(err).Fatalf("Failed to parse http address") + } + httpAddr = http + https, err := config.HTTPAddress("https://" + *httpsBindAddr) + if err != nil { + logrus.WithError(err).Fatalf("Failed to parse https address") + } + httpsAddr = https + } else { + httpAddr = config.UnixSocketAddress(*unixSocket, fs.FileMode(*unixSocketPermission)) + } + options := []basepkg.BaseDendriteOptions{} base := basepkg.NewBaseDendrite(cfg, options...) @@ -92,7 +114,7 @@ func main() { base.SetupAndServeHTTP(httpAddr, nil, nil) }() // Handle HTTPS if certificate and key are provided - if *certFile != "" && *keyFile != "" { + if *unixSocket == "" && *certFile != "" && *keyFile != "" { go func() { base.SetupAndServeHTTP(httpsAddr, certFile, keyFile) }() diff --git a/cmd/dendrite/main_test.go b/cmd/dendrite/main_test.go index efa1a926c..d51bc7434 100644 --- a/cmd/dendrite/main_test.go +++ b/cmd/dendrite/main_test.go @@ -9,7 +9,7 @@ import ( ) // This is an instrumented main, used when running integration tests (sytest) with code coverage. -// Compile: go test -c -race -cover -covermode=atomic -o monolith.debug -coverpkg "github.com/matrix-org/..." ./cmd/dendrite-monolith-server +// Compile: go test -c -race -cover -covermode=atomic -o monolith.debug -coverpkg "github.com/matrix-org/..." ./cmd/dendrite // Run the monolith: ./monolith.debug -test.coverprofile=/somewhere/to/dump/integrationcover.out DEVEL --config dendrite.yaml // Generate HTML with coverage: go tool cover -html=/somewhere/where/there/is/integrationcover.out -o cover.html // Source: https://dzone.com/articles/measuring-integration-test-coverage-rate-in-pouchc diff --git a/cmd/resolve-state/main.go b/cmd/resolve-state/main.go index e3840bbcf..a9cc80cb7 100644 --- a/cmd/resolve-state/main.go +++ b/cmd/resolve-state/main.go @@ -62,9 +62,10 @@ func main() { panic(err) } - stateres := state.NewStateResolution(roomserverDB, &types.RoomInfo{ + roomInfo := &types.RoomInfo{ RoomVersion: gomatrixserverlib.RoomVersion(*roomVersion), - }) + } + stateres := state.NewStateResolution(roomserverDB, roomInfo) if *difference { if len(snapshotNIDs) != 2 { @@ -87,7 +88,7 @@ func main() { } var eventEntries []types.Event - eventEntries, err = roomserverDB.Events(ctx, 0, eventNIDs) + eventEntries, err = roomserverDB.Events(ctx, roomInfo, eventNIDs) if err != nil { panic(err) } @@ -145,7 +146,7 @@ func main() { } fmt.Println("Fetching", len(eventNIDMap), "state events") - eventEntries, err := roomserverDB.Events(ctx, 0, eventNIDs) + eventEntries, err := roomserverDB.Events(ctx, roomInfo, eventNIDs) if err != nil { panic(err) } @@ -165,7 +166,7 @@ func main() { } fmt.Println("Fetching", len(authEventIDs), "auth events") - authEventEntries, err := roomserverDB.EventsFromIDs(ctx, 0, authEventIDs) + authEventEntries, err := roomserverDB.EventsFromIDs(ctx, roomInfo, authEventIDs) if err != nil { panic(err) } diff --git a/docs/development/PROFILING.md b/docs/development/PROFILING.md index f3b573472..57c37a900 100644 --- a/docs/development/PROFILING.md +++ b/docs/development/PROFILING.md @@ -15,7 +15,7 @@ Dendrite contains an embedded profiler called `pprof`, which is a part of the st To enable the profiler, start Dendrite with the `PPROFLISTEN` environment variable. This variable specifies which address and port to listen on, e.g. ``` -PPROFLISTEN=localhost:65432 ./bin/dendrite-monolith-server ... +PPROFLISTEN=localhost:65432 ./bin/dendrite ... ``` If pprof has been enabled successfully, a log line at startup will show that pprof is listening: diff --git a/docs/development/coverage.md b/docs/development/coverage.md index f3e39ddd7..c4a8a1174 100644 --- a/docs/development/coverage.md +++ b/docs/development/coverage.md @@ -14,8 +14,8 @@ index 8f0e209c..ad057e52 100644 $output->diag( "Starting monolith server" ); my @command = ( -- $self->{bindir} . '/dendrite-monolith-server', -+ $self->{bindir} . '/dendrite-monolith-server', '--test.coverprofile=' . $self->{hs_dir} . '/integrationcover.log', "DEVEL", +- $self->{bindir} . '/dendrite', ++ $self->{bindir} . '/dendrite', '--test.coverprofile=' . $self->{hs_dir} . '/integrationcover.log', "DEVEL", '--config', $self->{paths}{config}, '--http-bind-address', $self->{bind_host} . ':' . $self->unsecure_port, '--https-bind-address', $self->{bind_host} . ':' . $self->secure_port, @@ -27,9 +27,9 @@ index f009332b..7ea79869 100755 echo >&2 "--- Building dendrite from source" cd /src mkdir -p $GOBIN --go install -v ./cmd/dendrite-monolith-server -+# go install -v ./cmd/dendrite-monolith-server -+go test -c -cover -covermode=atomic -o $GOBIN/dendrite-monolith-server -coverpkg "github.com/matrix-org/..." ./cmd/dendrite-monolith-server +-go install -v ./cmd/dendrite ++# go install -v ./cmd/dendrite ++go test -c -cover -covermode=atomic -o $GOBIN/dendrite -coverpkg "github.com/matrix-org/..." ./cmd/dendrite go install -v ./cmd/generate-keys cd - ``` diff --git a/docs/development/tracing/setup.md b/docs/development/tracing/setup.md index a9e90c643..cef1089e4 100644 --- a/docs/development/tracing/setup.md +++ b/docs/development/tracing/setup.md @@ -49,7 +49,7 @@ tracing: then run the monolith server: ``` -./dendrite-monolith-server --tls-cert server.crt --tls-key server.key --config dendrite.yaml +./dendrite --tls-cert server.crt --tls-key server.key --config dendrite.yaml ``` ## Checking traces diff --git a/docs/installation/3_build.md b/docs/installation/3_build.md index aed2080db..824c81d37 100644 --- a/docs/installation/3_build.md +++ b/docs/installation/3_build.md @@ -28,11 +28,11 @@ The resulting binaries will be placed in the `bin` subfolder. You can install the Dendrite monolith binary into `$GOPATH/bin` by using `go install`: ```sh -go install ./cmd/dendrite-monolith-server +go install ./cmd/dendrite ``` Alternatively, you can specify a custom path for the binary to be written to using `go build`: ```sh -go build -o /usr/local/bin/ ./cmd/dendrite-monolith-server +go build -o /usr/local/bin/ ./cmd/dendrite ``` diff --git a/docs/installation/5_install_monolith.md b/docs/installation/5_install_monolith.md index 7de066cf7..901975a65 100644 --- a/docs/installation/5_install_monolith.md +++ b/docs/installation/5_install_monolith.md @@ -11,11 +11,11 @@ permalink: /installation/install/monolith You can install the Dendrite monolith binary into `$GOPATH/bin` by using `go install`: ```sh -go install ./cmd/dendrite-monolith-server +go install ./cmd/dendrite ``` Alternatively, you can specify a custom path for the binary to be written to using `go build`: ```sh -go build -o /usr/local/bin/ ./cmd/dendrite-monolith-server +go build -o /usr/local/bin/ ./cmd/dendrite ``` diff --git a/docs/installation/9_starting_monolith.md b/docs/installation/9_starting_monolith.md index 124477e73..d7e8c0b8b 100644 --- a/docs/installation/9_starting_monolith.md +++ b/docs/installation/9_starting_monolith.md @@ -9,10 +9,10 @@ permalink: /installation/start/monolith # Starting the monolith Once you have completed all of the preparation and installation steps, -you can start your Dendrite monolith deployment by starting the `dendrite-monolith-server`: +you can start your Dendrite monolith deployment by starting `dendrite`: ```bash -./dendrite-monolith-server -config /path/to/dendrite.yaml +./dendrite -config /path/to/dendrite.yaml ``` By default, Dendrite will listen HTTP on port 8008. If you want to change the addresses @@ -20,7 +20,7 @@ or ports that Dendrite listens on, you can use the `-http-bind-address` and `-https-bind-address` command line arguments: ```bash -./dendrite-monolith-server -config /path/to/dendrite.yaml \ +./dendrite -config /path/to/dendrite.yaml \ -http-bind-address 1.2.3.4:12345 \ -https-bind-address 1.2.3.4:54321 ``` diff --git a/docs/systemd/monolith-example.service b/docs/systemd/monolith-example.service index 237120ffb..8a948a3fa 100644 --- a/docs/systemd/monolith-example.service +++ b/docs/systemd/monolith-example.service @@ -11,7 +11,7 @@ Type=simple User=dendrite Group=dendrite WorkingDirectory=/opt/dendrite/ -ExecStart=/opt/dendrite/bin/dendrite-monolith-server +ExecStart=/opt/dendrite/bin/dendrite Restart=always LimitNOFILE=65535 diff --git a/federationapi/consumers/roomserver.go b/federationapi/consumers/roomserver.go index 82a4db3f7..378b96ba0 100644 --- a/federationapi/consumers/roomserver.go +++ b/federationapi/consumers/roomserver.go @@ -173,6 +173,7 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent, rew // Finally, work out if there are any more events missing. if len(missingEventIDs) > 0 { eventsReq := &api.QueryEventsByIDRequest{ + RoomID: ore.Event.RoomID(), EventIDs: missingEventIDs, } eventsRes := &api.QueryEventsByIDResponse{} @@ -483,7 +484,7 @@ func (s *OutputRoomEventConsumer) lookupStateEvents( // At this point the missing events are neither the event itself nor are // they present in our local database. Our only option is to fetch them // from the roomserver using the query API. - eventReq := api.QueryEventsByIDRequest{EventIDs: missing} + eventReq := api.QueryEventsByIDRequest{EventIDs: missing, RoomID: event.RoomID()} var eventResp api.QueryEventsByIDResponse if err := s.rsAPI.QueryEventsByID(s.ctx, &eventReq, &eventResp); err != nil { return nil, err diff --git a/federationapi/routing/eventauth.go b/federationapi/routing/eventauth.go index 868785a9b..2f1f3baf6 100644 --- a/federationapi/routing/eventauth.go +++ b/federationapi/routing/eventauth.go @@ -36,7 +36,7 @@ func GetEventAuth( return *err } - event, resErr := fetchEvent(ctx, rsAPI, eventID) + event, resErr := fetchEvent(ctx, rsAPI, roomID, eventID) if resErr != nil { return *resErr } diff --git a/federationapi/routing/events.go b/federationapi/routing/events.go index 6168912bd..b41292415 100644 --- a/federationapi/routing/events.go +++ b/federationapi/routing/events.go @@ -20,10 +20,11 @@ import ( "net/http" "time" - "github.com/matrix-org/dendrite/clientapi/jsonerror" - "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" + + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/roomserver/api" ) // GetEvent returns the requested event @@ -38,7 +39,9 @@ func GetEvent( if err != nil { return *err } - event, err := fetchEvent(ctx, rsAPI, eventID) + // /_matrix/federation/v1/event/{eventId} doesn't have a roomID, we use an empty string, + // which results in `QueryEventsByID` to first get the event and use that to determine the roomID. + event, err := fetchEvent(ctx, rsAPI, "", eventID) if err != nil { return *err } @@ -60,21 +63,13 @@ func allowedToSeeEvent( rsAPI api.FederationRoomserverAPI, eventID string, ) *util.JSONResponse { - var authResponse api.QueryServerAllowedToSeeEventResponse - err := rsAPI.QueryServerAllowedToSeeEvent( - ctx, - &api.QueryServerAllowedToSeeEventRequest{ - EventID: eventID, - ServerName: origin, - }, - &authResponse, - ) + allowed, err := rsAPI.QueryServerAllowedToSeeEvent(ctx, origin, eventID) if err != nil { resErr := util.ErrorResponse(err) return &resErr } - if !authResponse.AllowedToSeeEvent { + if !allowed { resErr := util.MessageResponse(http.StatusForbidden, "server not allowed to see event") return &resErr } @@ -83,11 +78,11 @@ func allowedToSeeEvent( } // fetchEvent fetches the event without auth checks. Returns an error if the event cannot be found. -func fetchEvent(ctx context.Context, rsAPI api.FederationRoomserverAPI, eventID string) (*gomatrixserverlib.Event, *util.JSONResponse) { +func fetchEvent(ctx context.Context, rsAPI api.FederationRoomserverAPI, roomID, eventID string) (*gomatrixserverlib.Event, *util.JSONResponse) { var eventsResponse api.QueryEventsByIDResponse err := rsAPI.QueryEventsByID( ctx, - &api.QueryEventsByIDRequest{EventIDs: []string{eventID}}, + &api.QueryEventsByIDRequest{EventIDs: []string{eventID}, RoomID: roomID}, &eventsResponse, ) if err != nil { diff --git a/federationapi/routing/state.go b/federationapi/routing/state.go index 1d08d0a82..1120cf260 100644 --- a/federationapi/routing/state.go +++ b/federationapi/routing/state.go @@ -107,7 +107,7 @@ func getState( return nil, nil, err } - event, resErr := fetchEvent(ctx, rsAPI, eventID) + event, resErr := fetchEvent(ctx, rsAPI, roomID, eventID) if resErr != nil { return nil, nil, resErr } diff --git a/internal/hooks/hooks.go b/internal/hooks/hooks.go index 223282a25..d6c79e989 100644 --- a/internal/hooks/hooks.go +++ b/internal/hooks/hooks.go @@ -16,7 +16,9 @@ // Hooks can only be run in monolith mode. package hooks -import "sync" +import ( + "sync" +) const ( // KindNewEventPersisted is a hook which is called with *gomatrixserverlib.HeaderedEvent diff --git a/roomserver/api/api.go b/roomserver/api/api.go index 73732ae32..f6d003a44 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -54,7 +54,8 @@ type QueryBulkStateContentAPI interface { } type QueryEventsAPI interface { - // Query a list of events by event ID. + // QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine + // which room to use by querying the first events roomID. QueryEventsByID( ctx context.Context, req *QueryEventsByIDRequest, @@ -71,7 +72,8 @@ type SyncRoomserverAPI interface { QueryBulkStateContentAPI // QuerySharedUsers returns a list of users who share at least 1 room in common with the given user. QuerySharedUsers(ctx context.Context, req *QuerySharedUsersRequest, res *QuerySharedUsersResponse) error - // Query a list of events by event ID. + // QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine + // which room to use by querying the first events roomID. QueryEventsByID( ctx context.Context, req *QueryEventsByIDRequest, @@ -108,7 +110,8 @@ type SyncRoomserverAPI interface { } type AppserviceRoomserverAPI interface { - // Query a list of events by event ID. + // QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine + // which room to use by querying the first events roomID. QueryEventsByID( ctx context.Context, req *QueryEventsByIDRequest, @@ -182,6 +185,8 @@ type FederationRoomserverAPI interface { QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error QueryRoomVersionForRoom(ctx context.Context, req *QueryRoomVersionForRoomRequest, res *QueryRoomVersionForRoomResponse) error GetRoomIDForAlias(ctx context.Context, req *GetRoomIDForAliasRequest, res *GetRoomIDForAliasResponse) error + // QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine + // which room to use by querying the first events roomID. QueryEventsByID(ctx context.Context, req *QueryEventsByIDRequest, res *QueryEventsByIDResponse) error // Query to get state and auth chain for a (potentially hypothetical) event. // Takes lists of PrevEventIDs and AuthEventsIDs and uses them to calculate @@ -193,7 +198,7 @@ type FederationRoomserverAPI interface { // Query missing events for a room from roomserver QueryMissingEvents(ctx context.Context, req *QueryMissingEventsRequest, res *QueryMissingEventsResponse) error // Query whether a server is allowed to see an event - QueryServerAllowedToSeeEvent(ctx context.Context, req *QueryServerAllowedToSeeEventRequest, res *QueryServerAllowedToSeeEventResponse) error + QueryServerAllowedToSeeEvent(ctx context.Context, serverName gomatrixserverlib.ServerName, eventID string) (allowed bool, err error) QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error QueryRestrictedJoinAllowed(ctx context.Context, req *QueryRestrictedJoinAllowedRequest, res *QueryRestrictedJoinAllowedResponse) error PerformInboundPeek(ctx context.Context, req *PerformInboundPeekRequest, res *PerformInboundPeekResponse) error diff --git a/roomserver/api/query.go b/roomserver/api/query.go index 4ef548e19..24722db0b 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -86,6 +86,9 @@ type QueryStateAfterEventsResponse struct { // QueryEventsByIDRequest is a request to QueryEventsByID type QueryEventsByIDRequest struct { + // The roomID to query events for. If this is empty, we first try to fetch the roomID from the database + // as this is needed for further processing/parsing events. + RoomID string `json:"room_id"` // The event IDs to look up. EventIDs []string `json:"event_ids"` } diff --git a/roomserver/api/wrapper.go b/roomserver/api/wrapper.go index 252be557f..f220560ed 100644 --- a/roomserver/api/wrapper.go +++ b/roomserver/api/wrapper.go @@ -108,9 +108,10 @@ func SendInputRoomEvents( } // GetEvent returns the event or nil, even on errors. -func GetEvent(ctx context.Context, rsAPI QueryEventsAPI, eventID string) *gomatrixserverlib.HeaderedEvent { +func GetEvent(ctx context.Context, rsAPI QueryEventsAPI, roomID, eventID string) *gomatrixserverlib.HeaderedEvent { var res QueryEventsByIDResponse err := rsAPI.QueryEventsByID(ctx, &QueryEventsByIDRequest{ + RoomID: roomID, EventIDs: []string{eventID}, }, &res) if err != nil { diff --git a/roomserver/internal/helpers/auth.go b/roomserver/internal/helpers/auth.go index 27c8dd8fa..9defe7945 100644 --- a/roomserver/internal/helpers/auth.go +++ b/roomserver/internal/helpers/auth.go @@ -67,7 +67,7 @@ func CheckForSoftFail( stateNeeded := gomatrixserverlib.StateNeededForAuth([]*gomatrixserverlib.Event{event.Unwrap()}) // Load the actual auth events from the database. - authEvents, err := loadAuthEvents(ctx, db, roomInfo.RoomNID, stateNeeded, authStateEntries) + authEvents, err := loadAuthEvents(ctx, db, roomInfo, stateNeeded, authStateEntries) if err != nil { return true, fmt.Errorf("loadAuthEvents: %w", err) } @@ -85,7 +85,7 @@ func CheckForSoftFail( func CheckAuthEvents( ctx context.Context, db storage.RoomDatabase, - roomNID types.RoomNID, + roomInfo *types.RoomInfo, event *gomatrixserverlib.HeaderedEvent, authEventIDs []string, ) ([]types.EventNID, error) { @@ -100,7 +100,7 @@ func CheckAuthEvents( stateNeeded := gomatrixserverlib.StateNeededForAuth([]*gomatrixserverlib.Event{event.Unwrap()}) // Load the actual auth events from the database. - authEvents, err := loadAuthEvents(ctx, db, roomNID, stateNeeded, authStateEntries) + authEvents, err := loadAuthEvents(ctx, db, roomInfo, stateNeeded, authStateEntries) if err != nil { return nil, fmt.Errorf("loadAuthEvents: %w", err) } @@ -193,7 +193,7 @@ func (ae *authEvents) lookupEvent(typeNID types.EventTypeNID, stateKey string) * func loadAuthEvents( ctx context.Context, db state.StateResolutionStorage, - roomNID types.RoomNID, + roomInfo *types.RoomInfo, needed gomatrixserverlib.StateNeeded, state []types.StateEntry, ) (result authEvents, err error) { @@ -216,7 +216,7 @@ func loadAuthEvents( eventNIDs = append(eventNIDs, eventNID) } } - if result.events, err = db.Events(ctx, roomNID, eventNIDs); err != nil { + if result.events, err = db.Events(ctx, roomInfo, eventNIDs); err != nil { return } roomID := "" diff --git a/roomserver/internal/helpers/helpers.go b/roomserver/internal/helpers/helpers.go index ee1610cf2..9a70bcc9c 100644 --- a/roomserver/internal/helpers/helpers.go +++ b/roomserver/internal/helpers/helpers.go @@ -85,7 +85,7 @@ func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverNam return false, err } - events, err := db.Events(ctx, info.RoomNID, eventNIDs) + events, err := db.Events(ctx, info, eventNIDs) if err != nil { return false, err } @@ -157,7 +157,7 @@ func IsInvitePending( // only keep the "m.room.member" events with a "join" membership. These events are returned. // Returns an error if there was an issue fetching the events. func GetMembershipsAtState( - ctx context.Context, db storage.RoomDatabase, roomNID types.RoomNID, stateEntries []types.StateEntry, joinedOnly bool, + ctx context.Context, db storage.RoomDatabase, roomInfo *types.RoomInfo, stateEntries []types.StateEntry, joinedOnly bool, ) ([]types.Event, error) { var eventNIDs types.EventNIDs @@ -177,7 +177,7 @@ func GetMembershipsAtState( util.Unique(eventNIDs) // Get all of the events in this state - stateEvents, err := db.Events(ctx, roomNID, eventNIDs) + stateEvents, err := db.Events(ctx, roomInfo, eventNIDs) if err != nil { return nil, err } @@ -227,9 +227,9 @@ func MembershipAtEvent(ctx context.Context, db storage.RoomDatabase, info *types } func LoadEvents( - ctx context.Context, db storage.RoomDatabase, roomNID types.RoomNID, eventNIDs []types.EventNID, + ctx context.Context, db storage.RoomDatabase, roomInfo *types.RoomInfo, eventNIDs []types.EventNID, ) ([]*gomatrixserverlib.Event, error) { - stateEvents, err := db.Events(ctx, roomNID, eventNIDs) + stateEvents, err := db.Events(ctx, roomInfo, eventNIDs) if err != nil { return nil, err } @@ -242,13 +242,13 @@ func LoadEvents( } func LoadStateEvents( - ctx context.Context, db storage.RoomDatabase, roomNID types.RoomNID, stateEntries []types.StateEntry, + ctx context.Context, db storage.RoomDatabase, roomInfo *types.RoomInfo, stateEntries []types.StateEntry, ) ([]*gomatrixserverlib.Event, error) { eventNIDs := make([]types.EventNID, len(stateEntries)) for i := range stateEntries { eventNIDs[i] = stateEntries[i].EventNID } - return LoadEvents(ctx, db, roomNID, eventNIDs) + return LoadEvents(ctx, db, roomInfo, eventNIDs) } func CheckServerAllowedToSeeEvent( @@ -326,7 +326,7 @@ func slowGetHistoryVisibilityState( return nil, nil } - return LoadStateEvents(ctx, db, info.RoomNID, filteredEntries) + return LoadStateEvents(ctx, db, info, filteredEntries) } // TODO: Remove this when we have tests to assert correctness of this function @@ -366,7 +366,7 @@ BFSLoop: next = make([]string, 0) } // Retrieve the events to process from the database. - events, err = db.EventsFromIDs(ctx, info.RoomNID, front) + events, err = db.EventsFromIDs(ctx, info, front) if err != nil { return resultNIDs, redactEventIDs, err } @@ -467,7 +467,7 @@ func QueryLatestEventsAndState( return err } - stateEvents, err := LoadStateEvents(ctx, db, roomInfo.RoomNID, stateEntries) + stateEvents, err := LoadStateEvents(ctx, db, roomInfo, stateEntries) if err != nil { return err } diff --git a/roomserver/internal/helpers/helpers_test.go b/roomserver/internal/helpers/helpers_test.go index 62730df1f..c056e704c 100644 --- a/roomserver/internal/helpers/helpers_test.go +++ b/roomserver/internal/helpers/helpers_test.go @@ -4,9 +4,10 @@ import ( "context" "testing" - "github.com/matrix-org/dendrite/roomserver/types" "github.com/stretchr/testify/assert" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/test" @@ -38,9 +39,9 @@ func TestIsInvitePendingWithoutNID(t *testing.T) { var authNIDs []types.EventNID for _, x := range room.Events() { - roomNID, err := db.GetOrCreateRoomNID(context.Background(), x.Unwrap()) + roomInfo, err := db.GetOrCreateRoomInfo(context.Background(), x.Unwrap()) assert.NoError(t, err) - assert.Greater(t, roomNID, types.RoomNID(0)) + assert.NotNil(t, roomInfo) eventTypeNID, err := db.GetOrCreateEventTypeNID(context.Background(), x.Type()) assert.NoError(t, err) @@ -49,7 +50,7 @@ func TestIsInvitePendingWithoutNID(t *testing.T) { eventStateKeyNID, err := db.GetOrCreateEventStateKeyNID(context.Background(), x.StateKey()) assert.NoError(t, err) - evNID, _, _, _, err := db.StoreEvent(context.Background(), x.Event, roomNID, eventTypeNID, eventStateKeyNID, authNIDs, false) + evNID, _, err := db.StoreEvent(context.Background(), x.Event, roomInfo, eventTypeNID, eventStateKeyNID, authNIDs, false) assert.NoError(t, err) authNIDs = append(authNIDs, evNID) } diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index fe35efb27..ede345a93 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -24,9 +24,10 @@ import ( "fmt" "time" - "github.com/matrix-org/dendrite/roomserver/internal/helpers" "github.com/tidwall/gjson" + "github.com/matrix-org/dendrite/roomserver/internal/helpers" + "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/opentracing/opentracing-go" @@ -274,8 +275,10 @@ func (r *Inputer) processRoomEvent( // Check if the event is allowed by its auth events. If it isn't then // we consider the event to be "rejected" — it will still be persisted. + redactAllowed := true if err = gomatrixserverlib.Allowed(event, &authEvents); err != nil { isRejected = true + redactAllowed = false rejectionErr = err logger.WithError(rejectionErr).Warnf("Event %s not allowed by auth events", event.EventID()) } @@ -323,7 +326,7 @@ func (r *Inputer) processRoomEvent( // burning CPU time. historyVisibility := gomatrixserverlib.HistoryVisibilityShared // Default to shared. if input.Kind != api.KindOutlier && rejectionErr == nil && !isRejected && !isCreateEvent { - historyVisibility, rejectionErr, err = r.processStateBefore(ctx, roomInfo.RoomNID, input, missingPrev) + historyVisibility, rejectionErr, err = r.processStateBefore(ctx, roomInfo, input, missingPrev) if err != nil { return fmt.Errorf("r.processStateBefore: %w", err) } @@ -332,9 +335,11 @@ func (r *Inputer) processRoomEvent( } } - roomNID, err := r.DB.GetOrCreateRoomNID(ctx, event) - if err != nil { - return fmt.Errorf("r.DB.GetOrCreateRoomNID: %w", err) + if roomInfo == nil { + roomInfo, err = r.DB.GetOrCreateRoomInfo(ctx, event) + if err != nil { + return fmt.Errorf("r.DB.GetOrCreateRoomInfo: %w", err) + } } eventTypeNID, err := r.DB.GetOrCreateEventTypeNID(ctx, event.Type()) @@ -348,15 +353,24 @@ func (r *Inputer) processRoomEvent( } // Store the event. - _, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, roomNID, eventTypeNID, eventStateKeyNID, authEventNIDs, isRejected) + eventNID, stateAtEvent, err := r.DB.StoreEvent(ctx, event, roomInfo, eventTypeNID, eventStateKeyNID, authEventNIDs, isRejected) if err != nil { return fmt.Errorf("updater.StoreEvent: %w", err) } // if storing this event results in it being redacted then do so. - if !isRejected && redactedEventID == event.EventID() { - if err = eventutil.RedactEvent(redactionEvent, event); err != nil { - return fmt.Errorf("eventutil.RedactEvent: %w", rerr) + var ( + redactedEventID string + redactionEvent *gomatrixserverlib.Event + redactedEvent *gomatrixserverlib.Event + ) + if !isRejected && !isCreateEvent { + redactionEvent, redactedEvent, err = r.DB.MaybeRedactEvent(ctx, roomInfo, eventNID, event, redactAllowed) + if err != nil { + return err + } + if redactedEvent != nil { + redactedEventID = redactedEvent.EventID() } } @@ -489,7 +503,7 @@ func (r *Inputer) handleRemoteRoomUpgrade(ctx context.Context, event *gomatrixse // nolint:nakedret func (r *Inputer) processStateBefore( ctx context.Context, - roomNID types.RoomNID, + roomInfo *types.RoomInfo, input *api.InputRoomEvent, missingPrev bool, ) (historyVisibility gomatrixserverlib.HistoryVisibility, rejectionErr error, err error) { @@ -505,7 +519,7 @@ func (r *Inputer) processStateBefore( case input.HasState: // If we're overriding the state then we need to go and retrieve // them from the database. It's a hard error if they are missing. - stateEvents, err := r.DB.EventsFromIDs(ctx, roomNID, input.StateEventIDs) + stateEvents, err := r.DB.EventsFromIDs(ctx, roomInfo, input.StateEventIDs) if err != nil { return "", nil, fmt.Errorf("r.DB.EventsFromIDs: %w", err) } @@ -604,7 +618,7 @@ func (r *Inputer) fetchAuthEvents( } for _, authEventID := range authEventIDs { - authEvents, err := r.DB.EventsFromIDs(ctx, roomInfo.RoomNID, []string{authEventID}) + authEvents, err := r.DB.EventsFromIDs(ctx, roomInfo, []string{authEventID}) if err != nil || len(authEvents) == 0 || authEvents[0].Event == nil { unknown[authEventID] = struct{}{} continue @@ -690,9 +704,11 @@ nextAuthEvent: logger.WithError(err).Warnf("Auth event %s rejected", authEvent.EventID()) } - roomNID, err := r.DB.GetOrCreateRoomNID(ctx, authEvent) - if err != nil { - return fmt.Errorf("r.DB.GetOrCreateRoomNID: %w", err) + if roomInfo == nil { + roomInfo, err = r.DB.GetOrCreateRoomInfo(ctx, authEvent) + if err != nil { + return fmt.Errorf("r.DB.GetOrCreateRoomInfo: %w", err) + } } eventTypeNID, err := r.DB.GetOrCreateEventTypeNID(ctx, authEvent.Type()) @@ -706,7 +722,7 @@ nextAuthEvent: } // Finally, store the event in the database. - eventNID, _, _, _, err := r.DB.StoreEvent(ctx, authEvent, roomNID, eventTypeNID, eventStateKeyNID, authEventNIDs, isRejected) + eventNID, _, err := r.DB.StoreEvent(ctx, authEvent, roomInfo, eventTypeNID, eventStateKeyNID, authEventNIDs, isRejected) if err != nil { return fmt.Errorf("updater.StoreEvent: %w", err) } @@ -782,7 +798,7 @@ func (r *Inputer) kickGuests(ctx context.Context, event *gomatrixserverlib.Event return err } - memberEvents, err := r.DB.Events(ctx, roomInfo.RoomNID, membershipNIDs) + memberEvents, err := r.DB.Events(ctx, roomInfo, membershipNIDs) if err != nil { return err } diff --git a/roomserver/internal/input/input_membership.go b/roomserver/internal/input/input_membership.go index 99a012551..e1dfa6cfa 100644 --- a/roomserver/internal/input/input_membership.go +++ b/roomserver/internal/input/input_membership.go @@ -53,7 +53,7 @@ func (r *Inputer) updateMemberships( // Load the event JSON so we can look up the "membership" key. // TODO: Maybe add a membership key to the events table so we can load that // key without having to load the entire event JSON? - events, err := updater.Events(ctx, 0, eventNIDs) + events, err := updater.Events(ctx, nil, eventNIDs) if err != nil { return nil, err } diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go index c8b7d31dd..9627f15ac 100644 --- a/roomserver/internal/input/input_missing.go +++ b/roomserver/internal/input/input_missing.go @@ -395,7 +395,7 @@ func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, even for _, entry := range stateEntries { stateEventNIDs = append(stateEventNIDs, entry.EventNID) } - stateEvents, err := t.db.Events(ctx, t.roomInfo.RoomNID, stateEventNIDs) + stateEvents, err := t.db.Events(ctx, t.roomInfo, stateEventNIDs) if err != nil { t.log.WithError(err).Warnf("failed to load state events locally") return nil @@ -432,7 +432,7 @@ func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, even missingEventList = append(missingEventList, evID) } t.log.WithField("count", len(missingEventList)).Debugf("Fetching missing auth events") - events, err := t.db.EventsFromIDs(ctx, t.roomInfo.RoomNID, missingEventList) + events, err := t.db.EventsFromIDs(ctx, t.roomInfo, missingEventList) if err != nil { return nil } @@ -702,7 +702,7 @@ func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roo } t.haveEventsMutex.Unlock() - events, err := t.db.EventsFromIDs(ctx, t.roomInfo.RoomNID, missingEventList) + events, err := t.db.EventsFromIDs(ctx, t.roomInfo, missingEventList) if err != nil { return nil, fmt.Errorf("t.db.EventsFromIDs: %w", err) } @@ -844,7 +844,7 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs if localFirst { // fetch from the roomserver - events, err := t.db.EventsFromIDs(ctx, t.roomInfo.RoomNID, []string{missingEventID}) + events, err := t.db.EventsFromIDs(ctx, t.roomInfo, []string{missingEventID}) if err != nil { t.log.Warnf("Failed to query roomserver for missing event %s: %s - falling back to remote", missingEventID, err) } else if len(events) == 1 { diff --git a/roomserver/internal/perform/perform_admin.go b/roomserver/internal/perform/perform_admin.go index 2efe2255f..45089bdd1 100644 --- a/roomserver/internal/perform/perform_admin.go +++ b/roomserver/internal/perform/perform_admin.go @@ -70,7 +70,7 @@ func (r *Admin) PerformAdminEvacuateRoom( return nil } - memberEvents, err := r.DB.Events(ctx, roomInfo.RoomNID, memberNIDs) + memberEvents, err := r.DB.Events(ctx, roomInfo, memberNIDs) if err != nil { res.Error = &api.PerformError{ Code: api.PerformErrorBadRequest, diff --git a/roomserver/internal/perform/perform_backfill.go b/roomserver/internal/perform/perform_backfill.go index 3a3a049db..411f4202b 100644 --- a/roomserver/internal/perform/perform_backfill.go +++ b/roomserver/internal/perform/perform_backfill.go @@ -23,7 +23,6 @@ import ( "github.com/sirupsen/logrus" federationAPI "github.com/matrix-org/dendrite/federationapi/api" - "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/auth" "github.com/matrix-org/dendrite/roomserver/internal/helpers" @@ -86,7 +85,7 @@ func (r *Backfiller) PerformBackfill( // Retrieve events from the list that was filled previously. If we fail to get // events from the database then attempt once to get them from federation instead. var loadedEvents []*gomatrixserverlib.Event - loadedEvents, err = helpers.LoadEvents(ctx, r.DB, info.RoomNID, resultNIDs) + loadedEvents, err = helpers.LoadEvents(ctx, r.DB, info, resultNIDs) if err != nil { if _, ok := err.(types.MissingEventError); ok { return r.backfillViaFederation(ctx, request, response) @@ -473,7 +472,7 @@ FindSuccessor: // Retrieve all "m.room.member" state events of "join" membership, which // contains the list of users in the room before the event, therefore all // the servers in it at that moment. - memberEvents, err := helpers.GetMembershipsAtState(ctx, b.db, info.RoomNID, stateEntries, true) + memberEvents, err := helpers.GetMembershipsAtState(ctx, b.db, info, stateEntries, true) if err != nil { logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to get memberships before event") return nil @@ -532,7 +531,7 @@ func (b *backfillRequester) ProvideEvents(roomVer gomatrixserverlib.RoomVersion, roomNID = nid.RoomNID } } - eventsWithNids, err := b.db.Events(ctx, roomNID, eventNIDs) + eventsWithNids, err := b.db.Events(ctx, &b.roomInfo, eventNIDs) if err != nil { logrus.WithError(err).WithField("event_nids", eventNIDs).Error("Failed to load events") return nil, err @@ -562,7 +561,7 @@ func joinEventsFromHistoryVisibility( } // Get all of the events in this state - stateEvents, err := db.Events(ctx, roomInfo.RoomNID, eventNIDs) + stateEvents, err := db.Events(ctx, roomInfo, eventNIDs) if err != nil { // even though the default should be shared, restricting the visibility to joined // feels more secure here. @@ -585,7 +584,7 @@ func joinEventsFromHistoryVisibility( if err != nil { return nil, visibility, err } - evs, err := db.Events(ctx, roomInfo.RoomNID, joinEventNIDs) + evs, err := db.Events(ctx, roomInfo, joinEventNIDs) return evs, visibility, err } @@ -606,7 +605,7 @@ func persistEvents(ctx context.Context, db storage.Database, events []*gomatrixs i++ } - roomNID, err = db.GetOrCreateRoomNID(ctx, ev.Unwrap()) + roomInfo, err := db.GetOrCreateRoomInfo(ctx, ev.Unwrap()) if err != nil { logrus.WithError(err).Error("failed to get or create roomNID") continue @@ -624,23 +623,22 @@ func persistEvents(ctx context.Context, db storage.Database, events []*gomatrixs continue } - var redactedEventID string - var redactionEvent *gomatrixserverlib.Event - eventNID, _, redactionEvent, redactedEventID, err = db.StoreEvent(ctx, ev.Unwrap(), roomNID, eventTypeNID, eventStateKeyNID, authNids, false) + eventNID, _, err = db.StoreEvent(ctx, ev.Unwrap(), roomInfo, eventTypeNID, eventStateKeyNID, authNids, false) if err != nil { logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to persist event") continue } + + _, redactedEvent, err := db.MaybeRedactEvent(ctx, roomInfo, eventNID, ev.Unwrap(), true) + if err != nil { + logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to redact event") + continue + } // If storing this event results in it being redacted, then do so. // It's also possible for this event to be a redaction which results in another event being // redacted, which we don't care about since we aren't returning it in this backfill. - if redactedEventID == ev.EventID() { - eventToRedact := ev.Unwrap() - if err := eventutil.RedactEvent(redactionEvent, eventToRedact); err != nil { - logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to redact event") - continue - } - ev = eventToRedact.Headered(ev.RoomVersion) + if redactedEvent != nil && redactedEvent.EventID() == ev.EventID() { + ev = redactedEvent.Headered(ev.RoomVersion) events[j] = ev } backfilledEventMap[ev.EventID()] = types.Event{ diff --git a/roomserver/internal/perform/perform_inbound_peek.go b/roomserver/internal/perform/perform_inbound_peek.go index 9ac9edc4c..1fb6eb43a 100644 --- a/roomserver/internal/perform/perform_inbound_peek.go +++ b/roomserver/internal/perform/perform_inbound_peek.go @@ -64,7 +64,7 @@ func (r *InboundPeeker) PerformInboundPeek( if err != nil { return err } - latestEvents, err := r.DB.EventsFromIDs(ctx, info.RoomNID, []string{latestEventRefs[0].EventID}) + latestEvents, err := r.DB.EventsFromIDs(ctx, info, []string{latestEventRefs[0].EventID}) if err != nil { return err } @@ -88,7 +88,7 @@ func (r *InboundPeeker) PerformInboundPeek( if err != nil { return err } - stateEvents, err = helpers.LoadStateEvents(ctx, r.DB, info.RoomNID, stateEntries) + stateEvents, err = helpers.LoadStateEvents(ctx, r.DB, info, stateEntries) if err != nil { return err } @@ -100,7 +100,7 @@ func (r *InboundPeeker) PerformInboundPeek( } authEventIDs = util.UniqueStrings(authEventIDs) // de-dupe - authEvents, err := query.GetAuthChain(ctx, r.DB.EventsFromIDs, authEventIDs) + authEvents, err := query.GetAuthChain(ctx, r.DB.EventsFromIDs, info, authEventIDs) if err != nil { return err } diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go index 118e1b879..13d13f7b5 100644 --- a/roomserver/internal/perform/perform_invite.go +++ b/roomserver/internal/perform/perform_invite.go @@ -194,7 +194,7 @@ func (r *Inviter) PerformInvite( // try and see if the user is allowed to make this invite. We can't do // this for invites coming in over federation - we have to take those on // trust. - _, err = helpers.CheckAuthEvents(ctx, r.DB, info.RoomNID, event, event.AuthEventIDs()) + _, err = helpers.CheckAuthEvents(ctx, r.DB, info, event, event.AuthEventIDs()) if err != nil { logger.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", event.AuthEventIDs()).Error( "processInviteEvent.checkAuthEvents failed for event", @@ -291,7 +291,7 @@ func buildInviteStrippedState( for _, stateNID := range stateEntries { stateNIDs = append(stateNIDs, stateNID.EventNID) } - stateEvents, err := db.Events(ctx, info.RoomNID, stateNIDs) + stateEvents, err := db.Events(ctx, info, stateNIDs) if err != nil { return nil, err } diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index ac34e0ff0..c5b74422f 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -21,11 +21,12 @@ import ( "errors" "fmt" - "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/sirupsen/logrus" + "github.com/matrix-org/dendrite/roomserver/storage/tables" + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/roomserver/acls" @@ -102,7 +103,7 @@ func (r *Queryer) QueryStateAfterEvents( return err } - stateEvents, err := helpers.LoadStateEvents(ctx, r.DB, info.RoomNID, stateEntries) + stateEvents, err := helpers.LoadStateEvents(ctx, r.DB, info, stateEntries) if err != nil { return err } @@ -114,7 +115,7 @@ func (r *Queryer) QueryStateAfterEvents( } authEventIDs = util.UniqueStrings(authEventIDs) - authEvents, err := GetAuthChain(ctx, r.DB.EventsFromIDs, authEventIDs) + authEvents, err := GetAuthChain(ctx, r.DB.EventsFromIDs, info, authEventIDs) if err != nil { return fmt.Errorf("getAuthChain: %w", err) } @@ -132,24 +133,46 @@ func (r *Queryer) QueryStateAfterEvents( return nil } -// QueryEventsByID implements api.RoomserverInternalAPI +// QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine +// which room to use by querying the first events roomID. func (r *Queryer) QueryEventsByID( ctx context.Context, request *api.QueryEventsByIDRequest, response *api.QueryEventsByIDResponse, ) error { - events, err := r.DB.EventsFromIDs(ctx, 0, request.EventIDs) + if len(request.EventIDs) == 0 { + return nil + } + var err error + // We didn't receive a room ID, we need to fetch it first before we can continue. + // This happens for e.g. ` /_matrix/federation/v1/event/{eventId}` + var roomInfo *types.RoomInfo + if request.RoomID == "" { + var eventNIDs map[string]types.EventMetadata + eventNIDs, err = r.DB.EventNIDs(ctx, []string{request.EventIDs[0]}) + if err != nil { + return err + } + if len(eventNIDs) == 0 { + return nil + } + roomInfo, err = r.DB.RoomInfoByNID(ctx, eventNIDs[request.EventIDs[0]].RoomNID) + } else { + roomInfo, err = r.DB.RoomInfo(ctx, request.RoomID) + } + if err != nil { + return err + } + if roomInfo == nil { + return nil + } + events, err := r.DB.EventsFromIDs(ctx, roomInfo, request.EventIDs) if err != nil { return err } for _, event := range events { - roomVersion, verr := r.roomVersion(event.RoomID()) - if verr != nil { - return verr - } - - response.Events = append(response.Events, event.Headered(roomVersion)) + response.Events = append(response.Events, event.Headered(roomInfo.RoomVersion)) } return nil @@ -186,7 +209,7 @@ func (r *Queryer) QueryMembershipForUser( response.IsInRoom = stillInRoom response.HasBeenInRoom = true - evs, err := r.DB.Events(ctx, info.RoomNID, []types.EventNID{membershipEventNID}) + evs, err := r.DB.Events(ctx, info, []types.EventNID{membershipEventNID}) if err != nil { return err } @@ -268,10 +291,10 @@ func (r *Queryer) QueryMembershipAtEvent( // once. If we have more than one membership event, we need to get the state for each state entry. if canShortCircuit { if len(memberships) == 0 { - memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, info.RoomNID, stateEntry, false) + memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, info, stateEntry, false) } } else { - memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, info.RoomNID, stateEntry, false) + memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, info, stateEntry, false) } if err != nil { return fmt.Errorf("unable to get memberships at state: %w", err) @@ -318,7 +341,7 @@ func (r *Queryer) QueryMembershipsForRoom( } return fmt.Errorf("r.DB.GetMembershipEventNIDsForRoom: %w", err) } - events, err = r.DB.Events(ctx, info.RoomNID, eventNIDs) + events, err = r.DB.Events(ctx, info, eventNIDs) if err != nil { return fmt.Errorf("r.DB.Events: %w", err) } @@ -357,14 +380,14 @@ func (r *Queryer) QueryMembershipsForRoom( return err } - events, err = r.DB.Events(ctx, info.RoomNID, eventNIDs) + events, err = r.DB.Events(ctx, info, eventNIDs) } else { stateEntries, err = helpers.StateBeforeEvent(ctx, r.DB, info, membershipEventNID) if err != nil { logrus.WithField("membership_event_nid", membershipEventNID).WithError(err).Error("failed to load state before event") return err } - events, err = helpers.GetMembershipsAtState(ctx, r.DB, info.RoomNID, stateEntries, request.JoinedOnly) + events, err = helpers.GetMembershipsAtState(ctx, r.DB, info, stateEntries, request.JoinedOnly) } if err != nil { @@ -412,39 +435,39 @@ func (r *Queryer) QueryServerJoinedToRoom( // QueryServerAllowedToSeeEvent implements api.RoomserverInternalAPI func (r *Queryer) QueryServerAllowedToSeeEvent( ctx context.Context, - request *api.QueryServerAllowedToSeeEventRequest, - response *api.QueryServerAllowedToSeeEventResponse, -) (err error) { - events, err := r.DB.EventsFromIDs(ctx, 0, []string{request.EventID}) + serverName gomatrixserverlib.ServerName, + eventID string, +) (allowed bool, err error) { + events, err := r.DB.EventNIDs(ctx, []string{eventID}) if err != nil { return } if len(events) == 0 { - response.AllowedToSeeEvent = false // event doesn't exist so not allowed to see - return + return allowed, nil } - roomID := events[0].RoomID() - - inRoomReq := &api.QueryServerJoinedToRoomRequest{ - RoomID: roomID, - ServerName: request.ServerName, - } - inRoomRes := &api.QueryServerJoinedToRoomResponse{} - if err = r.QueryServerJoinedToRoom(ctx, inRoomReq, inRoomRes); err != nil { - return fmt.Errorf("r.Queryer.QueryServerJoinedToRoom: %w", err) - } - - info, err := r.DB.RoomInfo(ctx, roomID) + info, err := r.DB.RoomInfoByNID(ctx, events[eventID].RoomNID) if err != nil { - return err + return allowed, err } if info == nil || info.IsStub() { - return nil + return allowed, nil } - response.AllowedToSeeEvent, err = helpers.CheckServerAllowedToSeeEvent( - ctx, r.DB, info, request.EventID, request.ServerName, inRoomRes.IsInRoom, + var isInRoom bool + if r.IsLocalServerName(serverName) || serverName == "" { + isInRoom, err = r.DB.GetLocalServerInRoom(ctx, info.RoomNID) + if err != nil { + return allowed, fmt.Errorf("r.DB.GetLocalServerInRoom: %w", err) + } + } else { + isInRoom, err = r.DB.GetServerInRoom(ctx, info.RoomNID, serverName) + if err != nil { + return allowed, fmt.Errorf("r.DB.GetServerInRoom: %w", err) + } + } + + return helpers.CheckServerAllowedToSeeEvent( + ctx, r.DB, info, eventID, serverName, isInRoom, ) - return } // QueryMissingEvents implements api.RoomserverInternalAPI @@ -466,19 +489,22 @@ func (r *Queryer) QueryMissingEvents( eventsToFilter[id] = true } } - events, err := r.DB.EventsFromIDs(ctx, 0, front) + if len(front) == 0 { + return nil // no events to query, give up. + } + events, err := r.DB.EventNIDs(ctx, []string{front[0]}) if err != nil { return err } if len(events) == 0 { return nil // we are missing the events being asked to search from, give up. } - info, err := r.DB.RoomInfo(ctx, events[0].RoomID()) + info, err := r.DB.RoomInfoByNID(ctx, events[front[0]].RoomNID) if err != nil { return err } if info == nil || info.IsStub() { - return fmt.Errorf("missing RoomInfo for room %s", events[0].RoomID()) + return fmt.Errorf("missing RoomInfo for room %d", events[front[0]].RoomNID) } resultNIDs, redactEventIDs, err := helpers.ScanEventTree(ctx, r.DB, info, front, visited, request.Limit, request.ServerName) @@ -486,7 +512,7 @@ func (r *Queryer) QueryMissingEvents( return err } - loadedEvents, err := helpers.LoadEvents(ctx, r.DB, info.RoomNID, resultNIDs) + loadedEvents, err := helpers.LoadEvents(ctx, r.DB, info, resultNIDs) if err != nil { return err } @@ -529,7 +555,7 @@ func (r *Queryer) QueryStateAndAuthChain( // TODO: this probably means it should be a different query operation... if request.OnlyFetchAuthChain { var authEvents []*gomatrixserverlib.Event - authEvents, err = GetAuthChain(ctx, r.DB.EventsFromIDs, request.AuthEventIDs) + authEvents, err = GetAuthChain(ctx, r.DB.EventsFromIDs, info, request.AuthEventIDs) if err != nil { return err } @@ -556,7 +582,7 @@ func (r *Queryer) QueryStateAndAuthChain( } authEventIDs = util.UniqueStrings(authEventIDs) // de-dupe - authEvents, err := GetAuthChain(ctx, r.DB.EventsFromIDs, authEventIDs) + authEvents, err := GetAuthChain(ctx, r.DB.EventsFromIDs, info, authEventIDs) if err != nil { return err } @@ -611,18 +637,18 @@ func (r *Queryer) loadStateAtEventIDs(ctx context.Context, roomInfo *types.RoomI return nil, rejected, false, err } - events, err := helpers.LoadStateEvents(ctx, r.DB, roomInfo.RoomNID, stateEntries) + events, err := helpers.LoadStateEvents(ctx, r.DB, roomInfo, stateEntries) return events, rejected, false, err } -type eventsFromIDs func(context.Context, types.RoomNID, []string) ([]types.Event, error) +type eventsFromIDs func(context.Context, *types.RoomInfo, []string) ([]types.Event, error) // GetAuthChain fetches the auth chain for the given auth events. An auth chain // is the list of all events that are referenced in the auth_events section, and // all their auth_events, recursively. The returned set of events contain the // given events. Will *not* error if we don't have all auth events. func GetAuthChain( - ctx context.Context, fn eventsFromIDs, authEventIDs []string, + ctx context.Context, fn eventsFromIDs, roomInfo *types.RoomInfo, authEventIDs []string, ) ([]*gomatrixserverlib.Event, error) { // List of event IDs to fetch. On each pass, these events will be requested // from the database and the `eventsToFetch` will be updated with any new @@ -633,7 +659,7 @@ func GetAuthChain( for len(eventsToFetch) > 0 { // Try to retrieve the events from the database. - events, err := fn(ctx, 0, eventsToFetch) + events, err := fn(ctx, roomInfo, eventsToFetch) if err != nil { return nil, err } @@ -852,7 +878,7 @@ func (r *Queryer) QueryServerBannedFromRoom(ctx context.Context, req *api.QueryS } func (r *Queryer) QueryAuthChain(ctx context.Context, req *api.QueryAuthChainRequest, res *api.QueryAuthChainResponse) error { - chain, err := GetAuthChain(ctx, r.DB.EventsFromIDs, req.EventIDs) + chain, err := GetAuthChain(ctx, r.DB.EventsFromIDs, nil, req.EventIDs) if err != nil { return err } @@ -971,7 +997,7 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, req *api.Query // For each of the joined users, let's see if we can get a valid // membership event. for _, joinNID := range joinNIDs { - events, err := r.DB.Events(ctx, roomInfo.RoomNID, []types.EventNID{joinNID}) + events, err := r.DB.Events(ctx, roomInfo, []types.EventNID{joinNID}) if err != nil || len(events) != 1 { continue } diff --git a/roomserver/internal/query/query_test.go b/roomserver/internal/query/query_test.go index 167611575..265f326d4 100644 --- a/roomserver/internal/query/query_test.go +++ b/roomserver/internal/query/query_test.go @@ -80,7 +80,7 @@ func (db *getEventDB) addFakeEvents(graph map[string][]string) error { } // EventsFromIDs implements RoomserverInternalAPIEventDB -func (db *getEventDB) EventsFromIDs(ctx context.Context, roomNID types.RoomNID, eventIDs []string) (res []types.Event, err error) { +func (db *getEventDB) EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) (res []types.Event, err error) { for _, evID := range eventIDs { res = append(res, types.Event{ EventNID: 0, @@ -106,7 +106,7 @@ func TestGetAuthChainSingle(t *testing.T) { t.Fatalf("Failed to add events to db: %v", err) } - result, err := GetAuthChain(context.TODO(), db.EventsFromIDs, []string{"e"}) + result, err := GetAuthChain(context.TODO(), db.EventsFromIDs, nil, []string{"e"}) if err != nil { t.Fatalf("getAuthChain failed: %v", err) } @@ -139,7 +139,7 @@ func TestGetAuthChainMultiple(t *testing.T) { t.Fatalf("Failed to add events to db: %v", err) } - result, err := GetAuthChain(context.TODO(), db.EventsFromIDs, []string{"e", "f"}) + result, err := GetAuthChain(context.TODO(), db.EventsFromIDs, nil, []string{"e", "f"}) if err != nil { t.Fatalf("getAuthChain failed: %v", err) } diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go index 304311c4c..cfa27e541 100644 --- a/roomserver/roomserver_test.go +++ b/roomserver/roomserver_test.go @@ -278,6 +278,16 @@ func TestPurgeRoom(t *testing.T) { if roomInfo == nil { t.Fatalf("room does not exist") } + + // + roomInfo2, err := db.RoomInfoByNID(ctx, roomInfo.RoomNID) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(roomInfo, roomInfo2) { + t.Fatalf("expected roomInfos to be the same, but they aren't") + } + // remember the roomInfo before purging existingRoomInfo := roomInfo @@ -333,6 +343,10 @@ func TestPurgeRoom(t *testing.T) { if roomInfo != nil { t.Fatalf("room should not exist after purging: %+v", roomInfo) } + roomInfo2, err = db.RoomInfoByNID(ctx, existingRoomInfo.RoomNID) + if err == nil { + t.Fatalf("expected room to not exist, but it does: %#v", roomInfo2) + } // validation below diff --git a/roomserver/state/state.go b/roomserver/state/state.go index cec542d7e..9af2f8577 100644 --- a/roomserver/state/state.go +++ b/roomserver/state/state.go @@ -41,8 +41,8 @@ type StateResolutionStorage interface { StateEntriesForTuples(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error) StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error) AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error) - Events(ctx context.Context, roomNID types.RoomNID, eventNIDs []types.EventNID) ([]types.Event, error) - EventsFromIDs(ctx context.Context, roomNID types.RoomNID, eventIDs []string) ([]types.Event, error) + Events(ctx context.Context, roomInfo *types.RoomInfo, eventNIDs []types.EventNID) ([]types.Event, error) + EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error) } type StateResolution struct { @@ -975,7 +975,7 @@ func (v *StateResolution) resolveConflictsV2( // Store the newly found auth events in the auth set for this event. var authEventMap map[string]types.StateEntry - authSets[key], authEventMap, err = loader.loadAuthEvents(sctx, v.roomInfo.RoomNID, conflictedEvent, knownAuthEvents) + authSets[key], authEventMap, err = loader.loadAuthEvents(sctx, v.roomInfo, conflictedEvent, knownAuthEvents) if err != nil { return err } @@ -1091,7 +1091,7 @@ func (v *StateResolution) loadStateEvents( eventNIDs = append(eventNIDs, entry.EventNID) } } - events, err := v.db.Events(ctx, v.roomInfo.RoomNID, eventNIDs) + events, err := v.db.Events(ctx, v.roomInfo, eventNIDs) if err != nil { return nil, nil, err } @@ -1120,7 +1120,7 @@ type authEventLoader struct { // loadAuthEvents loads all of the auth events for a given event recursively, // along with a map that contains state entries for all of the auth events. func (l *authEventLoader) loadAuthEvents( - ctx context.Context, roomNID types.RoomNID, event *gomatrixserverlib.Event, eventMap map[string]types.Event, + ctx context.Context, roomInfo *types.RoomInfo, event *gomatrixserverlib.Event, eventMap map[string]types.Event, ) ([]*gomatrixserverlib.Event, map[string]types.StateEntry, error) { l.Lock() defer l.Unlock() @@ -1155,7 +1155,7 @@ func (l *authEventLoader) loadAuthEvents( // If we need to get events from the database, go and fetch // those now. if len(l.lookupFromDB) > 0 { - eventsFromDB, err := l.v.db.EventsFromIDs(ctx, roomNID, l.lookupFromDB) + eventsFromDB, err := l.v.db.EventsFromIDs(ctx, roomInfo, l.lookupFromDB) if err != nil { return nil, nil, fmt.Errorf("v.db.EventsFromIDs: %w", err) } diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 88ec56670..a41a8a9b4 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -29,6 +29,7 @@ type Database interface { SupportsConcurrentRoomInputs() bool // RoomInfo returns room information for the given room ID, or nil if there is no room. RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) + RoomInfoByNID(ctx context.Context, roomNID types.RoomNID) (*types.RoomInfo, error) // Store the room state at an event in the database AddState( ctx context.Context, @@ -69,12 +70,12 @@ type Database interface { ) ([]types.StateEntryList, error) // Look up the Events for a list of numeric event IDs. // Returns a sorted list of events. - Events(ctx context.Context, roomNID types.RoomNID, eventNIDs []types.EventNID) ([]types.Event, error) + Events(ctx context.Context, roomInfo *types.RoomInfo, eventNIDs []types.EventNID) ([]types.Event, error) // Look up snapshot NID for an event ID string SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error) BulkSelectSnapshotsFromEventIDs(ctx context.Context, eventIDs []string) (map[types.StateSnapshotNID][]string, error) - // Stores a matrix room event in the database. Returns the room NID, the state snapshot and the redacted event ID if any, or an error. - StoreEvent(ctx context.Context, event *gomatrixserverlib.Event, roomNID types.RoomNID, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, authEventNIDs []types.EventNID, isRejected bool) (types.EventNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) + // Stores a matrix room event in the database. Returns the room NID, the state snapshot or an error. + StoreEvent(ctx context.Context, event *gomatrixserverlib.Event, roomInfo *types.RoomInfo, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, authEventNIDs []types.EventNID, isRejected bool) (types.EventNID, types.StateAtEvent, error) // Look up the state entries for a list of string event IDs // Returns an error if the there is an error talking to the database // Returns a types.MissingEventError if the event IDs aren't in the database. @@ -135,7 +136,7 @@ type Database interface { // EventsFromIDs looks up the Events for a list of event IDs. Does not error if event was // not found. // Returns an error if the retrieval went wrong. - EventsFromIDs(ctx context.Context, roomNID types.RoomNID, eventIDs []string) ([]types.Event, error) + EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error) // Publish or unpublish a room from the room directory. PublishRoom(ctx context.Context, roomID, appserviceID, networkID string, publish bool) error // Returns a list of room IDs for rooms which are published. @@ -179,36 +180,53 @@ type Database interface { GetMembershipForHistoryVisibility( ctx context.Context, userNID types.EventStateKeyNID, info *types.RoomInfo, eventIDs ...string, ) (map[string]*gomatrixserverlib.HeaderedEvent, error) - GetOrCreateRoomNID(ctx context.Context, event *gomatrixserverlib.Event) (types.RoomNID, error) + GetOrCreateRoomInfo(ctx context.Context, event *gomatrixserverlib.Event) (*types.RoomInfo, error) GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error) GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error) + MaybeRedactEvent( + ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event *gomatrixserverlib.Event, redactAllowed bool, + ) (*gomatrixserverlib.Event, *gomatrixserverlib.Event, error) } type RoomDatabase interface { + EventDatabase // RoomInfo returns room information for the given room ID, or nil if there is no room. RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) + RoomInfoByNID(ctx context.Context, roomNID types.RoomNID) (*types.RoomInfo, error) // IsEventRejected returns true if the event is known and rejected. IsEventRejected(ctx context.Context, roomNID types.RoomNID, eventID string) (rejected bool, err error) MissingAuthPrevEvents(ctx context.Context, e *gomatrixserverlib.Event) (missingAuth, missingPrev []string, err error) - // Stores a matrix room event in the database. Returns the room NID, the state snapshot and the redacted event ID if any, or an error. - StoreEvent(ctx context.Context, event *gomatrixserverlib.Event, roomNID types.RoomNID, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, authEventNIDs []types.EventNID, isRejected bool) (types.EventNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventSender string) error GetRoomUpdater(ctx context.Context, roomInfo *types.RoomInfo) (*shared.RoomUpdater, error) - StateEntriesForEventIDs(ctx context.Context, eventIDs []string, excludeRejected bool) ([]types.StateEntry, error) GetMembershipEventNIDsForRoom(ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool) ([]types.EventNID, error) StateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) StateEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error) - SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error) BulkSelectSnapshotsFromEventIDs(ctx context.Context, eventIDs []string) (map[types.StateSnapshotNID][]string, error) StateEntriesForTuples(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error) - StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error) AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error) - Events(ctx context.Context, roomNID types.RoomNID, eventNIDs []types.EventNID) ([]types.Event, error) - EventsFromIDs(ctx context.Context, roomNID types.RoomNID, eventIDs []string) ([]types.Event, error) LatestEventIDs(ctx context.Context, roomNID types.RoomNID) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error) - EventTypeNIDs(ctx context.Context, eventTypes []string) (map[string]types.EventTypeNID, error) - EventStateKeyNIDs(ctx context.Context, eventStateKeys []string) (map[string]types.EventStateKeyNID, error) - GetOrCreateRoomNID(ctx context.Context, event *gomatrixserverlib.Event) (types.RoomNID, error) + GetOrCreateRoomInfo(ctx context.Context, event *gomatrixserverlib.Event) (*types.RoomInfo, error) GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error) GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error) + GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) +} + +type EventDatabase interface { + EventTypeNIDs(ctx context.Context, eventTypes []string) (map[string]types.EventTypeNID, error) + EventStateKeys(ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error) + EventStateKeyNIDs(ctx context.Context, eventStateKeys []string) (map[string]types.EventStateKeyNID, error) + StateEntriesForEventIDs(ctx context.Context, eventIDs []string, excludeRejected bool) ([]types.StateEntry, error) + EventNIDs(ctx context.Context, eventIDs []string) (map[string]types.EventMetadata, error) + SetState(ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID) error + StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error) + SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error) + EventIDs(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) + EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error) + Events(ctx context.Context, roomInfo *types.RoomInfo, eventNIDs []types.EventNID) ([]types.Event, error) + // MaybeRedactEvent returns the redaction event and the redacted event if this call resulted in a redaction, else an error + // (nil if there was nothing to do) + MaybeRedactEvent( + ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event *gomatrixserverlib.Event, redactAllowed bool, + ) (*gomatrixserverlib.Event, *gomatrixserverlib.Event, error) + StoreEvent(ctx context.Context, event *gomatrixserverlib.Event, roomInfo *types.RoomInfo, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, authEventNIDs []types.EventNID, isRejected bool) (types.EventNID, types.StateAtEvent, error) } diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index 872084383..d98a5cf97 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -194,23 +194,28 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room return err } d.Database = shared.Database{ - DB: db, - Cache: cache, - Writer: writer, - EventTypesTable: eventTypes, - EventStateKeysTable: eventStateKeys, - EventJSONTable: eventJSON, - EventsTable: events, - RoomsTable: rooms, - StateBlockTable: stateBlock, - StateSnapshotTable: stateSnapshot, - PrevEventsTable: prevEvents, - RoomAliasesTable: roomAliases, - InvitesTable: invites, - MembershipTable: membership, - PublishedTable: published, - RedactionsTable: redactions, - Purge: purge, + DB: db, + EventDatabase: shared.EventDatabase{ + DB: db, + Cache: cache, + Writer: writer, + EventsTable: events, + EventJSONTable: eventJSON, + EventTypesTable: eventTypes, + EventStateKeysTable: eventStateKeys, + PrevEventsTable: prevEvents, + RedactionsTable: redactions, + }, + Cache: cache, + Writer: writer, + RoomsTable: rooms, + StateBlockTable: stateBlock, + StateSnapshotTable: stateSnapshot, + RoomAliasesTable: roomAliases, + InvitesTable: invites, + MembershipTable: membership, + PublishedTable: published, + Purge: purge, } return nil } diff --git a/roomserver/storage/shared/room_updater.go b/roomserver/storage/shared/room_updater.go index 5006c3c55..dc1db0825 100644 --- a/roomserver/storage/shared/room_updater.go +++ b/roomserver/storage/shared/room_updater.go @@ -116,8 +116,8 @@ func (u *RoomUpdater) StorePreviousEvents(eventNID types.EventNID, previousEvent }) } -func (u *RoomUpdater) Events(ctx context.Context, _ types.RoomNID, eventNIDs []types.EventNID) ([]types.Event, error) { - return u.d.events(ctx, u.txn, u.roomInfo.RoomNID, eventNIDs) +func (u *RoomUpdater) Events(ctx context.Context, _ *types.RoomInfo, eventNIDs []types.EventNID) ([]types.Event, error) { + return u.d.events(ctx, u.txn, u.roomInfo, eventNIDs) } func (u *RoomUpdater) SnapshotNIDFromEventID( @@ -195,8 +195,8 @@ func (u *RoomUpdater) StateAtEventIDs( return u.d.EventsTable.BulkSelectStateAtEventByID(ctx, u.txn, eventIDs) } -func (u *RoomUpdater) EventsFromIDs(ctx context.Context, roomNID types.RoomNID, eventIDs []string) ([]types.Event, error) { - return u.d.eventsFromIDs(ctx, u.txn, roomNID, eventIDs, NoFilter) +func (u *RoomUpdater) EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error) { + return u.d.eventsFromIDs(ctx, u.txn, u.roomInfo, eventIDs, NoFilter) } // IsReferenced implements types.RoomRecentEventsUpdater diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index aac5bc365..be3f228d7 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -9,7 +9,6 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" - "github.com/sirupsen/logrus" "github.com/tidwall/gjson" "github.com/matrix-org/dendrite/internal/caching" @@ -28,6 +27,23 @@ import ( const redactionsArePermanent = true type Database struct { + DB *sql.DB + EventDatabase + Cache caching.RoomServerCaches + Writer sqlutil.Writer + RoomsTable tables.Rooms + StateSnapshotTable tables.StateSnapshot + StateBlockTable tables.StateBlock + RoomAliasesTable tables.RoomAliases + InvitesTable tables.Invites + MembershipTable tables.Membership + PublishedTable tables.Published + Purge tables.Purge + GetRoomUpdaterFn func(ctx context.Context, roomInfo *types.RoomInfo) (*RoomUpdater, error) +} + +// EventDatabase contains all tables needed to work with events +type EventDatabase struct { DB *sql.DB Cache caching.RoomServerCaches Writer sqlutil.Writer @@ -35,17 +51,8 @@ type Database struct { EventJSONTable tables.EventJSON EventTypesTable tables.EventTypes EventStateKeysTable tables.EventStateKeys - RoomsTable tables.Rooms - StateSnapshotTable tables.StateSnapshot - StateBlockTable tables.StateBlock - RoomAliasesTable tables.RoomAliases PrevEventsTable tables.PreviousEvents - InvitesTable tables.Invites - MembershipTable tables.Membership - PublishedTable tables.Published RedactionsTable tables.Redactions - Purge tables.Purge - GetRoomUpdaterFn func(ctx context.Context, roomInfo *types.RoomInfo) (*RoomUpdater, error) } func (d *Database) SupportsConcurrentRoomInputs() bool { @@ -58,13 +65,13 @@ func (d *Database) GetMembershipForHistoryVisibility( return d.StateSnapshotTable.BulkSelectMembershipForHistoryVisibility(ctx, nil, userNID, roomInfo, eventIDs...) } -func (d *Database) EventTypeNIDs( +func (d *EventDatabase) EventTypeNIDs( ctx context.Context, eventTypes []string, ) (map[string]types.EventTypeNID, error) { return d.eventTypeNIDs(ctx, nil, eventTypes) } -func (d *Database) eventTypeNIDs( +func (d *EventDatabase) eventTypeNIDs( ctx context.Context, txn *sql.Tx, eventTypes []string, ) (map[string]types.EventTypeNID, error) { result := make(map[string]types.EventTypeNID) @@ -91,7 +98,7 @@ func (d *Database) eventTypeNIDs( return result, nil } -func (d *Database) EventStateKeys( +func (d *EventDatabase) EventStateKeys( ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID, ) (map[types.EventStateKeyNID]string, error) { result := make(map[types.EventStateKeyNID]string, len(eventStateKeyNIDs)) @@ -116,13 +123,13 @@ func (d *Database) EventStateKeys( return result, nil } -func (d *Database) EventStateKeyNIDs( +func (d *EventDatabase) EventStateKeyNIDs( ctx context.Context, eventStateKeys []string, ) (map[string]types.EventStateKeyNID, error) { return d.eventStateKeyNIDs(ctx, nil, eventStateKeys) } -func (d *Database) eventStateKeyNIDs( +func (d *EventDatabase) eventStateKeyNIDs( ctx context.Context, txn *sql.Tx, eventStateKeys []string, ) (map[string]types.EventStateKeyNID, error) { result := make(map[string]types.EventStateKeyNID) @@ -174,7 +181,7 @@ func (d *Database) eventStateKeyNIDs( return result, nil } -func (d *Database) StateEntriesForEventIDs( +func (d *EventDatabase) StateEntriesForEventIDs( ctx context.Context, eventIDs []string, excludeRejected bool, ) ([]types.StateEntry, error) { return d.EventsTable.BulkSelectStateEventByID(ctx, nil, eventIDs, excludeRejected) @@ -213,6 +220,17 @@ func (d *Database) stateEntriesForTuples( return lists, nil } +func (d *Database) RoomInfoByNID(ctx context.Context, roomNID types.RoomNID) (*types.RoomInfo, error) { + roomIDs, err := d.RoomsTable.BulkSelectRoomIDs(ctx, nil, []types.RoomNID{roomNID}) + if err != nil { + return nil, err + } + if len(roomIDs) == 0 { + return nil, fmt.Errorf("room does not exist") + } + return d.roomInfo(ctx, nil, roomIDs[0]) +} + func (d *Database) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) { return d.roomInfo(ctx, nil, roomID) } @@ -292,7 +310,7 @@ func (d *Database) addState( return } -func (d *Database) EventNIDs( +func (d *EventDatabase) EventNIDs( ctx context.Context, eventIDs []string, ) (map[string]types.EventMetadata, error) { return d.eventNIDs(ctx, nil, eventIDs, NoFilter) @@ -305,7 +323,7 @@ const ( FilterUnsentOnly UnsentFilter = true ) -func (d *Database) eventNIDs( +func (d *EventDatabase) eventNIDs( ctx context.Context, txn *sql.Tx, eventIDs []string, filter UnsentFilter, ) (map[string]types.EventMetadata, error) { switch filter { @@ -318,7 +336,7 @@ func (d *Database) eventNIDs( } } -func (d *Database) SetState( +func (d *EventDatabase) SetState( ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID, ) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { @@ -326,19 +344,19 @@ func (d *Database) SetState( }) } -func (d *Database) StateAtEventIDs( +func (d *EventDatabase) StateAtEventIDs( ctx context.Context, eventIDs []string, ) ([]types.StateAtEvent, error) { return d.EventsTable.BulkSelectStateAtEventByID(ctx, nil, eventIDs) } -func (d *Database) SnapshotNIDFromEventID( +func (d *EventDatabase) SnapshotNIDFromEventID( ctx context.Context, eventID string, ) (types.StateSnapshotNID, error) { return d.snapshotNIDFromEventID(ctx, nil, eventID) } -func (d *Database) snapshotNIDFromEventID( +func (d *EventDatabase) snapshotNIDFromEventID( ctx context.Context, txn *sql.Tx, eventID string, ) (types.StateSnapshotNID, error) { _, stateNID, err := d.EventsTable.SelectEvent(ctx, txn, eventID) @@ -351,17 +369,17 @@ func (d *Database) snapshotNIDFromEventID( return stateNID, err } -func (d *Database) EventIDs( +func (d *EventDatabase) EventIDs( ctx context.Context, eventNIDs []types.EventNID, ) (map[types.EventNID]string, error) { return d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs) } -func (d *Database) EventsFromIDs(ctx context.Context, roomNID types.RoomNID, eventIDs []string) ([]types.Event, error) { - return d.eventsFromIDs(ctx, nil, roomNID, eventIDs, NoFilter) +func (d *EventDatabase) EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error) { + return d.eventsFromIDs(ctx, nil, roomInfo, eventIDs, NoFilter) } -func (d *Database) eventsFromIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventIDs []string, filter UnsentFilter) ([]types.Event, error) { +func (d *EventDatabase) eventsFromIDs(ctx context.Context, txn *sql.Tx, roomInfo *types.RoomInfo, eventIDs []string, filter UnsentFilter) ([]types.Event, error) { nidMap, err := d.eventNIDs(ctx, txn, eventIDs, filter) if err != nil { return nil, err @@ -370,15 +388,9 @@ func (d *Database) eventsFromIDs(ctx context.Context, txn *sql.Tx, roomNID types var nids []types.EventNID for _, nid := range nidMap { nids = append(nids, nid.EventNID) - if roomNID != 0 && roomNID != nid.RoomNID { - logrus.Errorf("expected events from room %d, but also found %d", roomNID, nid.RoomNID) - } - if roomNID == 0 { - roomNID = nid.RoomNID - } } - return d.events(ctx, txn, roomNID, nids) + return d.events(ctx, txn, roomInfo, nids) } func (d *Database) LatestEventIDs( @@ -517,19 +529,17 @@ func (d *Database) GetInvitesForUser( return d.InvitesTable.SelectInviteActiveForUserInRoom(ctx, nil, targetUserNID, roomNID) } -func (d *Database) Events( - ctx context.Context, roomNID types.RoomNID, eventNIDs []types.EventNID, -) ([]types.Event, error) { - return d.events(ctx, nil, roomNID, eventNIDs) +func (d *EventDatabase) Events(ctx context.Context, roomInfo *types.RoomInfo, eventNIDs []types.EventNID) ([]types.Event, error) { + return d.events(ctx, nil, roomInfo, eventNIDs) } -func (d *Database) events( - ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, inputEventNIDs types.EventNIDs, +func (d *EventDatabase) events( + ctx context.Context, txn *sql.Tx, roomInfo *types.RoomInfo, inputEventNIDs types.EventNIDs, ) ([]types.Event, error) { - if roomNID == 0 { - // No need to go further, as we won't find any events for this room. - return nil, nil + if roomInfo == nil { // this should never happen + return nil, fmt.Errorf("unable to parse events without roomInfo") } + sort.Sort(inputEventNIDs) events := make(map[types.EventNID]*gomatrixserverlib.Event, len(inputEventNIDs)) eventNIDs := make([]types.EventNID, 0, len(inputEventNIDs)) @@ -566,31 +576,9 @@ func (d *Database) events( eventIDs = map[types.EventNID]string{} } - var roomVersion gomatrixserverlib.RoomVersion - var fetchRoomVersion bool - var ok bool - var roomID string - if roomID, ok = d.Cache.GetRoomServerRoomID(roomNID); ok { - roomVersion, ok = d.Cache.GetRoomVersion(roomID) - if !ok { - fetchRoomVersion = true - } - } - - if roomVersion == "" || fetchRoomVersion { - var dbRoomVersions map[types.RoomNID]gomatrixserverlib.RoomVersion - dbRoomVersions, err = d.RoomsTable.SelectRoomVersionsForRoomNIDs(ctx, txn, []types.RoomNID{roomNID}) - if err != nil { - return nil, err - } - if roomVersion, ok = dbRoomVersions[roomNID]; !ok { - return nil, fmt.Errorf("unable to find roomversion for room %d", roomNID) - } - } - for _, eventJSON := range eventJSONs { events[eventJSON.EventNID], err = gomatrixserverlib.NewEventFromTrustedJSONWithEventID( - eventIDs[eventJSON.EventNID], eventJSON.EventJSON, false, roomVersion, + eventIDs[eventJSON.EventNID], eventJSON.EventJSON, false, roomInfo.RoomVersion, ) if err != nil { return nil, err @@ -660,8 +648,8 @@ func (d *Database) IsEventRejected(ctx context.Context, roomNID types.RoomNID, e return d.EventsTable.SelectEventRejected(ctx, nil, roomNID, eventID) } -// GetOrCreateRoomNID gets or creates a new roomNID for the given event -func (d *Database) GetOrCreateRoomNID(ctx context.Context, event *gomatrixserverlib.Event) (roomNID types.RoomNID, err error) { +// GetOrCreateRoomInfo gets or creates a new RoomInfo, which is only safe to use with functions only needing a roomVersion or roomNID. +func (d *Database) GetOrCreateRoomInfo(ctx context.Context, event *gomatrixserverlib.Event) (roomInfo *types.RoomInfo, err error) { // Get the default room version. If the client doesn't supply a room_version // then we will use our configured default to create the room. // https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-createroom @@ -670,8 +658,9 @@ func (d *Database) GetOrCreateRoomNID(ctx context.Context, event *gomatrixserver // room. var roomVersion gomatrixserverlib.RoomVersion if roomVersion, err = extractRoomVersionFromCreateEvent(event); err != nil { - return 0, fmt.Errorf("extractRoomVersionFromCreateEvent: %w", err) + return nil, fmt.Errorf("extractRoomVersionFromCreateEvent: %w", err) } + var roomNID types.RoomNID err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { roomNID, err = d.assignRoomNID(ctx, txn, event.RoomID(), roomVersion) if err != nil { @@ -679,7 +668,10 @@ func (d *Database) GetOrCreateRoomNID(ctx context.Context, event *gomatrixserver } return nil }) - return roomNID, err + return &types.RoomInfo{ + RoomVersion: roomVersion, + RoomNID: roomNID, + }, err } func (d *Database) GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error) { @@ -710,25 +702,22 @@ func (d *Database) GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKe return eventStateKeyNID, nil } -func (d *Database) StoreEvent( +func (d *EventDatabase) StoreEvent( ctx context.Context, event *gomatrixserverlib.Event, - roomNID types.RoomNID, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, + roomInfo *types.RoomInfo, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, authEventNIDs []types.EventNID, isRejected bool, -) (types.EventNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) { +) (types.EventNID, types.StateAtEvent, error) { var ( - eventNID types.EventNID - stateNID types.StateSnapshotNID - redactionEvent *gomatrixserverlib.Event - redactedEventID string - err error + eventNID types.EventNID + stateNID types.StateSnapshotNID + err error ) - // Second writer is using the database-provided transaction, probably from the - // room updater, for easy roll-back if required. + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { if eventNID, stateNID, err = d.EventsTable.InsertEvent( ctx, txn, - roomNID, + roomInfo.RoomNID, eventTypeNID, eventStateKeyNID, event.EventID(), @@ -751,16 +740,26 @@ func (d *Database) StoreEvent( if err = d.EventJSONTable.InsertEventJSON(ctx, txn, eventNID, event.JSON()); err != nil { return fmt.Errorf("d.EventJSONTable.InsertEventJSON: %w", err) } - if !isRejected { // ignore rejected redaction events - redactionEvent, redactedEventID, err = d.handleRedactions(ctx, txn, roomNID, eventNID, event) - if err != nil { - return fmt.Errorf("d.handleRedactions: %w", err) + + if prevEvents := event.PrevEvents(); len(prevEvents) > 0 { + // Create an updater - NB: on sqlite this WILL create a txn as we are directly calling the shared DB form of + // GetLatestEventsForUpdate - not via the SQLiteDatabase form which has `nil` txns. This + // function only does SELECTs though so the created txn (at this point) is just a read txn like + // any other so this is fine. If we ever update GetLatestEventsForUpdate or NewLatestEventsUpdater + // to do writes however then this will need to go inside `Writer.Do`. + + // The following is a copy of RoomUpdater.StorePreviousEvents + for _, ref := range prevEvents { + if err = d.PrevEventsTable.InsertPreviousEvent(ctx, txn, ref.EventID, ref.EventSHA256, eventNID); err != nil { + return fmt.Errorf("u.d.PrevEventsTable.InsertPreviousEvent: %w", err) + } } } + return nil }) if err != nil { - return 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.Writer.Do: %w", err) + return 0, types.StateAtEvent{}, fmt.Errorf("d.Writer.Do: %w", err) } // We should attempt to update the previous events table with any @@ -768,33 +767,6 @@ func (d *Database) StoreEvent( // events updater because it somewhat works as a mutex, ensuring // that there's a row-level lock on the latest room events (well, // on Postgres at least). - if prevEvents := event.PrevEvents(); len(prevEvents) > 0 { - // Create an updater - NB: on sqlite this WILL create a txn as we are directly calling the shared DB form of - // GetLatestEventsForUpdate - not via the SQLiteDatabase form which has `nil` txns. This - // function only does SELECTs though so the created txn (at this point) is just a read txn like - // any other so this is fine. If we ever update GetLatestEventsForUpdate or NewLatestEventsUpdater - // to do writes however then this will need to go inside `Writer.Do`. - succeeded := false - var roomInfo *types.RoomInfo - roomInfo, err = d.roomInfo(ctx, nil, event.RoomID()) - if err != nil { - return 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.RoomInfo: %w", err) - } - if roomInfo == nil && len(prevEvents) > 0 { - return 0, types.StateAtEvent{}, nil, "", fmt.Errorf("expected room %q to exist", event.RoomID()) - } - var updater *RoomUpdater - updater, err = d.GetRoomUpdater(ctx, roomInfo) - if err != nil { - return 0, types.StateAtEvent{}, nil, "", fmt.Errorf("GetRoomUpdater: %w", err) - } - defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err) - - if err = updater.StorePreviousEvents(eventNID, prevEvents); err != nil { - return 0, types.StateAtEvent{}, nil, "", fmt.Errorf("updater.StorePreviousEvents: %w", err) - } - succeeded = true - } return eventNID, types.StateAtEvent{ BeforeStateSnapshotNID: stateNID, @@ -805,7 +777,7 @@ func (d *Database) StoreEvent( }, EventNID: eventNID, }, - }, redactionEvent, redactedEventID, err + }, err } func (d *Database) PublishRoom(ctx context.Context, roomID, appserviceID, networkID string, publish bool) error { @@ -893,7 +865,7 @@ func (d *Database) assignEventTypeNID( return eventTypeNID, nil } -func (d *Database) assignStateKeyNID( +func (d *EventDatabase) assignStateKeyNID( ctx context.Context, txn *sql.Tx, eventStateKey string, ) (types.EventStateKeyNID, error) { eventStateKeyNID, ok := d.Cache.GetEventStateKeyNID(eventStateKey) @@ -937,7 +909,7 @@ func extractRoomVersionFromCreateEvent(event *gomatrixserverlib.Event) ( return roomVersion, err } -// handleRedactions manages the redacted status of events. There's two cases to consider in order to comply with the spec: +// MaybeRedactEvent manages the redacted status of events. There's two cases to consider in order to comply with the spec: // "servers should not apply or send redactions to clients until both the redaction event and original event have been seen, and are valid." // https://matrix.org/docs/spec/rooms/v3#authorization-rules-for-events // These cases are: @@ -952,95 +924,95 @@ func extractRoomVersionFromCreateEvent(event *gomatrixserverlib.Event) ( // when loading events to determine whether to apply redactions. This keeps the hot-path of reading events quick as we don't need // to cross-reference with other tables when loading. // -// Returns the redaction event and the event ID of the redacted event if this call resulted in a redaction. -func (d *Database) handleRedactions( - ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNID types.EventNID, event *gomatrixserverlib.Event, -) (*gomatrixserverlib.Event, string, error) { - var err error - isRedactionEvent := event.Type() == gomatrixserverlib.MRoomRedaction && event.StateKey() == nil - if isRedactionEvent { - // an event which redacts itself should be ignored - if event.EventID() == event.Redacts() { - return nil, "", nil +// Returns the redaction event and the redacted event if this call resulted in a redaction. +func (d *EventDatabase) MaybeRedactEvent( + ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event *gomatrixserverlib.Event, redactAllowed bool, +) (*gomatrixserverlib.Event, *gomatrixserverlib.Event, error) { + var ( + redactionEvent, redactedEvent *types.Event + err error + validated bool + ignoreRedaction bool + ) + + wErr := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + isRedactionEvent := event.Type() == gomatrixserverlib.MRoomRedaction && event.StateKey() == nil + if isRedactionEvent { + // an event which redacts itself should be ignored + if event.EventID() == event.Redacts() { + return nil + } + + err = d.RedactionsTable.InsertRedaction(ctx, txn, tables.RedactionInfo{ + Validated: false, + RedactionEventID: event.EventID(), + RedactsEventID: event.Redacts(), + }) + if err != nil { + return fmt.Errorf("d.RedactionsTable.InsertRedaction: %w", err) + } } - err = d.RedactionsTable.InsertRedaction(ctx, txn, tables.RedactionInfo{ - Validated: false, - RedactionEventID: event.EventID(), - RedactsEventID: event.Redacts(), - }) + redactionEvent, redactedEvent, validated, err = d.loadRedactionPair(ctx, txn, roomInfo, eventNID, event) + switch { + case err != nil: + return fmt.Errorf("d.loadRedactionPair: %w", err) + case validated || redactedEvent == nil || redactionEvent == nil: + // we've seen this redaction before or there is nothing to redact + return nil + case redactedEvent.RoomID() != redactionEvent.RoomID(): + // redactions across rooms aren't allowed + ignoreRedaction = true + return nil + } + + // 1. The power level of the redaction event’s sender is greater than or equal to the redact level. (redactAllowed) + // 2. The domain of the redaction event’s sender matches that of the original event’s sender. + _, sender1, _ := gomatrixserverlib.SplitID('@', redactedEvent.Sender()) + _, sender2, _ := gomatrixserverlib.SplitID('@', redactionEvent.Sender()) + if !redactAllowed || sender1 != sender2 { + ignoreRedaction = true + return nil + } + + // mark the event as redacted + if redactionsArePermanent { + redactedEvent.Redact() + } + + err = redactedEvent.SetUnsignedField("redacted_because", redactionEvent) if err != nil { - return nil, "", fmt.Errorf("d.RedactionsTable.InsertRedaction: %w", err) + return fmt.Errorf("redactedEvent.SetUnsignedField: %w", err) + } + // NOTSPEC: sytest relies on this unspecced field existing :( + err = redactedEvent.SetUnsignedField("redacted_by", redactionEvent.EventID()) + if err != nil { + return fmt.Errorf("redactedEvent.SetUnsignedField: %w", err) + } + // overwrite the eventJSON table + err = d.EventJSONTable.InsertEventJSON(ctx, txn, redactedEvent.EventNID, redactedEvent.JSON()) + if err != nil { + return fmt.Errorf("d.EventJSONTable.InsertEventJSON: %w", err) } - } - redactionEvent, redactedEvent, validated, err := d.loadRedactionPair(ctx, txn, roomNID, eventNID, event) - if err != nil { - return nil, "", fmt.Errorf("d.loadRedactionPair: %w", err) + err = d.RedactionsTable.MarkRedactionValidated(ctx, txn, redactionEvent.EventID(), true) + if err != nil { + return fmt.Errorf("d.RedactionsTable.MarkRedactionValidated: %w", err) + } + return nil + }) + if wErr != nil { + return nil, nil, err } - if validated || redactedEvent == nil || redactionEvent == nil { - // we've seen this redaction before or there is nothing to redact - return nil, "", nil + if ignoreRedaction || redactionEvent == nil || redactedEvent == nil { + return nil, nil, nil } - if redactedEvent.RoomID() != redactionEvent.RoomID() { - // redactions across rooms aren't allowed - return nil, "", nil - } - - // Get the power level from the database, so we can verify the user is allowed to redact the event - powerLevels, err := d.GetStateEvent(ctx, event.RoomID(), gomatrixserverlib.MRoomPowerLevels, "") - if err != nil { - return nil, "", fmt.Errorf("d.GetStateEvent: %w", err) - } - if powerLevels == nil { - return nil, "", fmt.Errorf("unable to fetch m.room.power_levels event from database for room %s", event.RoomID()) - } - pl, err := powerLevels.PowerLevels() - if err != nil { - return nil, "", fmt.Errorf("unable to get powerlevels for room: %w", err) - } - - redactUser := pl.UserLevel(redactionEvent.Sender()) - switch { - case redactUser >= pl.Redact: - // The power level of the redaction event’s sender is greater than or equal to the redact level. - case redactedEvent.Sender() == redactionEvent.Sender(): - // The domain of the redaction event’s sender matches that of the original event’s sender. - default: - return nil, "", nil - } - - // mark the event as redacted - if redactionsArePermanent { - redactedEvent.Redact() - } - - err = redactedEvent.SetUnsignedField("redacted_because", redactionEvent) - if err != nil { - return nil, "", fmt.Errorf("redactedEvent.SetUnsignedField: %w", err) - } - // NOTSPEC: sytest relies on this unspecced field existing :( - err = redactedEvent.SetUnsignedField("redacted_by", redactionEvent.EventID()) - if err != nil { - return nil, "", fmt.Errorf("redactedEvent.SetUnsignedField: %w", err) - } - // overwrite the eventJSON table - err = d.EventJSONTable.InsertEventJSON(ctx, txn, redactedEvent.EventNID, redactedEvent.JSON()) - if err != nil { - return nil, "", fmt.Errorf("d.EventJSONTable.InsertEventJSON: %w", err) - } - - err = d.RedactionsTable.MarkRedactionValidated(ctx, txn, redactionEvent.EventID(), true) - if err != nil { - err = fmt.Errorf("d.RedactionsTable.MarkRedactionValidated: %w", err) - } - - return redactionEvent.Event, redactedEvent.EventID(), err + return redactionEvent.Event, redactedEvent.Event, nil } // loadRedactionPair returns both the redaction event and the redacted event, else nil. -func (d *Database) loadRedactionPair( - ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNID types.EventNID, event *gomatrixserverlib.Event, +func (d *EventDatabase) loadRedactionPair( + ctx context.Context, txn *sql.Tx, roomInfo *types.RoomInfo, eventNID types.EventNID, event *gomatrixserverlib.Event, ) (*types.Event, *types.Event, bool, error) { var redactionEvent, redactedEvent *types.Event var info *tables.RedactionInfo @@ -1072,16 +1044,16 @@ func (d *Database) loadRedactionPair( } if isRedactionEvent { - redactedEvent = d.loadEvent(ctx, roomNID, info.RedactsEventID) + redactedEvent = d.loadEvent(ctx, roomInfo, info.RedactsEventID) } else { - redactionEvent = d.loadEvent(ctx, roomNID, info.RedactionEventID) + redactionEvent = d.loadEvent(ctx, roomInfo, info.RedactionEventID) } return redactionEvent, redactedEvent, info.Validated, nil } // applyRedactions will redact events that have an `unsigned.redacted_because` field. -func (d *Database) applyRedactions(events []types.Event) { +func (d *EventDatabase) applyRedactions(events []types.Event) { for i := range events { if result := gjson.GetBytes(events[i].Unsigned(), "redacted_because"); result.Exists() { events[i].Redact() @@ -1090,7 +1062,7 @@ func (d *Database) applyRedactions(events []types.Event) { } // loadEvent loads a single event or returns nil on any problems/missing event -func (d *Database) loadEvent(ctx context.Context, roomNID types.RoomNID, eventID string) *types.Event { +func (d *EventDatabase) loadEvent(ctx context.Context, roomInfo *types.RoomInfo, eventID string) *types.Event { nids, err := d.EventNIDs(ctx, []string{eventID}) if err != nil { return nil @@ -1098,7 +1070,7 @@ func (d *Database) loadEvent(ctx context.Context, roomNID types.RoomNID, eventID if len(nids) == 0 { return nil } - evs, err := d.Events(ctx, roomNID, []types.EventNID{nids[eventID].EventNID}) + evs, err := d.Events(ctx, roomInfo, []types.EventNID{nids[eventID].EventNID}) if err != nil { return nil } @@ -1144,7 +1116,7 @@ func (d *Database) GetHistoryVisibilityState(ctx context.Context, roomInfo *type // If no event could be found, returns nil // If there was an issue during the retrieval, returns an error func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) { - roomInfo, err := d.RoomInfo(ctx, roomID) + roomInfo, err := d.roomInfo(ctx, nil, roomID) if err != nil { return nil, err } @@ -1209,7 +1181,7 @@ func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey s // Same as GetStateEvent but returns all matching state events with this event type. Returns no error // if there are no events with this event type. func (d *Database) GetStateEventsWithEventType(ctx context.Context, roomID, evType string) ([]*gomatrixserverlib.HeaderedEvent, error) { - roomInfo, err := d.RoomInfo(ctx, roomID) + roomInfo, err := d.roomInfo(ctx, nil, roomID) if err != nil { return nil, err } @@ -1340,7 +1312,7 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu eventNIDToVer := make(map[types.EventNID]gomatrixserverlib.RoomVersion) // TODO: This feels like this is going to be really slow... for _, roomID := range roomIDs { - roomInfo, err2 := d.RoomInfo(ctx, roomID) + roomInfo, err2 := d.roomInfo(ctx, nil, roomID) if err2 != nil { return nil, fmt.Errorf("GetBulkStateContent: failed to load room info for room %s : %w", roomID, err2) } diff --git a/roomserver/storage/shared/storage_test.go b/roomserver/storage/shared/storage_test.go index 3acb55a3a..684e80b8f 100644 --- a/roomserver/storage/shared/storage_test.go +++ b/roomserver/storage/shared/storage_test.go @@ -52,12 +52,14 @@ func mustCreateRoomserverDatabase(t *testing.T, dbType test.DBType) (*shared.Dat cache := caching.NewRistrettoCache(8*1024*1024, time.Hour, false) + evDb := shared.EventDatabase{EventStateKeysTable: stateKeyTable, Cache: cache} + return &shared.Database{ - DB: db, - EventStateKeysTable: stateKeyTable, - MembershipTable: membershipTable, - Writer: sqlutil.NewExclusiveWriter(), - Cache: cache, + DB: db, + EventDatabase: evDb, + MembershipTable: membershipTable, + Writer: sqlutil.NewExclusiveWriter(), + Cache: cache, }, func() { err := base.Close() assert.NoError(t, err) diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index 392edd289..2adedd2d8 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -203,24 +203,29 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room } d.Database = shared.Database{ - DB: db, - Cache: cache, - Writer: writer, - EventsTable: events, - EventTypesTable: eventTypes, - EventStateKeysTable: eventStateKeys, - EventJSONTable: eventJSON, - RoomsTable: rooms, - StateBlockTable: stateBlock, - StateSnapshotTable: stateSnapshot, - PrevEventsTable: prevEvents, - RoomAliasesTable: roomAliases, - InvitesTable: invites, - MembershipTable: membership, - PublishedTable: published, - RedactionsTable: redactions, - GetRoomUpdaterFn: d.GetRoomUpdater, - Purge: purge, + DB: db, + EventDatabase: shared.EventDatabase{ + DB: db, + Cache: cache, + Writer: writer, + EventsTable: events, + EventTypesTable: eventTypes, + EventStateKeysTable: eventStateKeys, + EventJSONTable: eventJSON, + PrevEventsTable: prevEvents, + RedactionsTable: redactions, + }, + Cache: cache, + Writer: writer, + RoomsTable: rooms, + StateBlockTable: stateBlock, + StateSnapshotTable: stateSnapshot, + RoomAliasesTable: roomAliases, + InvitesTable: invites, + MembershipTable: membership, + PublishedTable: published, + GetRoomUpdaterFn: d.GetRoomUpdater, + Purge: purge, } return nil } diff --git a/setup/base/base.go b/setup/base/base.go index aabdd7937..dfe48ff3c 100644 --- a/setup/base/base.go +++ b/setup/base/base.go @@ -20,9 +20,11 @@ import ( "database/sql" "embed" "encoding/json" + "errors" "fmt" "html/template" "io" + "io/fs" "net" "net/http" _ "net/http/pprof" @@ -85,8 +87,6 @@ type BaseDendrite struct { startupLock sync.Mutex } -const NoListener = "" - const HTTPServerTimeout = time.Minute * 5 type BaseDendriteOptions int @@ -345,18 +345,17 @@ func (b *BaseDendrite) ConfigureAdminEndpoints() { // SetupAndServeHTTP sets up the HTTP server to serve client & federation APIs // and adds a prometheus handler under /_dendrite/metrics. func (b *BaseDendrite) SetupAndServeHTTP( - externalHTTPAddr config.HTTPAddress, + externalHTTPAddr config.ServerAddress, certFile, keyFile *string, ) { // Manually unlocked right before actually serving requests, // as we don't return from this method (defer doesn't work). b.startupLock.Lock() - externalAddr, _ := externalHTTPAddr.Address() externalRouter := mux.NewRouter().SkipClean(true).UseEncodedPath() externalServ := &http.Server{ - Addr: string(externalAddr), + Addr: externalHTTPAddr.Address, WriteTimeout: HTTPServerTimeout, Handler: externalRouter, BaseContext: func(_ net.Listener) context.Context { @@ -419,7 +418,7 @@ func (b *BaseDendrite) SetupAndServeHTTP( b.startupLock.Unlock() - if externalAddr != NoListener { + if externalHTTPAddr.Enabled() { go func() { var externalShutdown atomic.Bool // RegisterOnShutdown can be called more than once logrus.Infof("Starting external listener on %s", externalServ.Addr) @@ -437,9 +436,30 @@ func (b *BaseDendrite) SetupAndServeHTTP( } } } else { - if err := externalServ.ListenAndServe(); err != nil { - if err != http.ErrServerClosed { - logrus.WithError(err).Fatal("failed to serve HTTP") + if externalHTTPAddr.IsUnixSocket() { + err := os.Remove(externalHTTPAddr.Address) + if err != nil && !errors.Is(err, fs.ErrNotExist) { + logrus.WithError(err).Fatal("failed to remove existing unix socket") + } + listener, err := net.Listen(externalHTTPAddr.Network(), externalHTTPAddr.Address) + if err != nil { + logrus.WithError(err).Fatal("failed to serve unix socket") + } + err = os.Chmod(externalHTTPAddr.Address, externalHTTPAddr.UnixSocketPermission) + if err != nil { + logrus.WithError(err).Fatal("failed to set unix socket permissions") + } + if err := externalServ.Serve(listener); err != nil { + if err != http.ErrServerClosed { + logrus.WithError(err).Fatal("failed to serve unix socket") + } + } + + } else { + if err := externalServ.ListenAndServe(); err != nil { + if err != http.ErrServerClosed { + logrus.WithError(err).Fatal("failed to serve HTTP") + } } } } diff --git a/setup/base/base_test.go b/setup/base/base_test.go index d906294c0..658dc5b03 100644 --- a/setup/base/base_test.go +++ b/setup/base/base_test.go @@ -2,10 +2,13 @@ package base_test import ( "bytes" + "context" "embed" "html/template" + "net" "net/http" "net/http/httptest" + "path" "testing" "time" @@ -18,7 +21,7 @@ import ( //go:embed static/*.gotmpl var staticContent embed.FS -func TestLandingPage(t *testing.T) { +func TestLandingPage_Tcp(t *testing.T) { // generate the expected result tmpl := template.Must(template.ParseFS(staticContent, "static/*.gotmpl")) expectedRes := &bytes.Buffer{} @@ -35,7 +38,9 @@ func TestLandingPage(t *testing.T) { s.Close() // start base with the listener and wait for it to be started - go b.SetupAndServeHTTP(config.HTTPAddress(s.URL), nil, nil) + address, err := config.HTTPAddress(s.URL) + assert.NoError(t, err) + go b.SetupAndServeHTTP(address, nil, nil) time.Sleep(time.Millisecond * 10) // When hitting /, we should be redirected to /_matrix/static, which should contain the landing page @@ -55,3 +60,43 @@ func TestLandingPage(t *testing.T) { // Using .String() for user friendly output assert.Equal(t, expectedRes.String(), buf.String(), "response mismatch") } + +func TestLandingPage_UnixSocket(t *testing.T) { + // generate the expected result + tmpl := template.Must(template.ParseFS(staticContent, "static/*.gotmpl")) + expectedRes := &bytes.Buffer{} + err := tmpl.ExecuteTemplate(expectedRes, "index.gotmpl", map[string]string{ + "Version": internal.VersionString(), + }) + assert.NoError(t, err) + + b, _, _ := testrig.Base(nil) + defer b.Close() + + tempDir := t.TempDir() + socket := path.Join(tempDir, "socket") + // start base with the listener and wait for it to be started + address := config.UnixSocketAddress(socket, 0755) + assert.NoError(t, err) + go b.SetupAndServeHTTP(address, nil, nil) + time.Sleep(time.Millisecond * 100) + + client := &http.Client{ + Transport: &http.Transport{ + DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { + return net.Dial("unix", socket) + }, + }, + } + resp, err := client.Get("http://unix/") + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + + // read the response + buf := &bytes.Buffer{} + _, err = buf.ReadFrom(resp.Body) + assert.NoError(t, err) + + // Using .String() for user friendly output + assert.Equal(t, expectedRes.String(), buf.String(), "response mismatch") +} diff --git a/setup/config/config.go b/setup/config/config.go index 848766162..1a25f71eb 100644 --- a/setup/config/config.go +++ b/setup/config/config.go @@ -19,7 +19,6 @@ import ( "encoding/pem" "fmt" "io" - "net/url" "os" "path/filepath" "regexp" @@ -131,20 +130,6 @@ func (d DataSource) IsPostgres() bool { // A Topic in kafka. type Topic string -// An Address to listen on. -type Address string - -// An HTTPAddress to listen on, starting with either http:// or https://. -type HTTPAddress string - -func (h HTTPAddress) Address() (Address, error) { - url, err := url.Parse(string(h)) - if err != nil { - return "", err - } - return Address(url.Host), nil -} - // FileSizeBytes is a file size in bytes type FileSizeBytes int64 diff --git a/setup/config/config_address.go b/setup/config/config_address.go new file mode 100644 index 000000000..0e4f0296f --- /dev/null +++ b/setup/config/config_address.go @@ -0,0 +1,45 @@ +package config + +import ( + "io/fs" + "net/url" +) + +const ( + NetworkTCP = "tcp" + NetworkUnix = "unix" +) + +type ServerAddress struct { + Address string + Scheme string + UnixSocketPermission fs.FileMode +} + +func (s ServerAddress) Enabled() bool { + return s.Address != "" +} + +func (s ServerAddress) IsUnixSocket() bool { + return s.Scheme == NetworkUnix +} + +func (s ServerAddress) Network() string { + if s.Scheme == NetworkUnix { + return NetworkUnix + } else { + return NetworkTCP + } +} + +func UnixSocketAddress(path string, perm fs.FileMode) ServerAddress { + return ServerAddress{Address: path, Scheme: NetworkUnix, UnixSocketPermission: perm} +} + +func HTTPAddress(urlAddress string) (ServerAddress, error) { + parsedUrl, err := url.Parse(urlAddress) + if err != nil { + return ServerAddress{}, err + } + return ServerAddress{parsedUrl.Host, parsedUrl.Scheme, 0}, nil +} diff --git a/setup/config/config_address_test.go b/setup/config/config_address_test.go new file mode 100644 index 000000000..1be484fd5 --- /dev/null +++ b/setup/config/config_address_test.go @@ -0,0 +1,25 @@ +package config + +import ( + "io/fs" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestHttpAddress_ParseGood(t *testing.T) { + address, err := HTTPAddress("http://localhost:123") + assert.NoError(t, err) + assert.Equal(t, "localhost:123", address.Address) + assert.Equal(t, "tcp", address.Network()) +} + +func TestHttpAddress_ParseBad(t *testing.T) { + _, err := HTTPAddress(":") + assert.Error(t, err) +} + +func TestUnixSocketAddress_Network(t *testing.T) { + address := UnixSocketAddress("/tmp", fs.FileMode(0755)) + assert.Equal(t, "unix", address.Network()) +} diff --git a/setup/mscs/msc2836/msc2836.go b/setup/mscs/msc2836/msc2836.go index bc369c166..4bb6a5eee 100644 --- a/setup/mscs/msc2836/msc2836.go +++ b/setup/mscs/msc2836/msc2836.go @@ -253,7 +253,7 @@ func (rc *reqCtx) process() (*MSC2836EventRelationshipsResponse, *util.JSONRespo var res MSC2836EventRelationshipsResponse var returnEvents []*gomatrixserverlib.HeaderedEvent // Can the user see (according to history visibility) event_id? If no, reject the request, else continue. - event := rc.getLocalEvent(rc.req.EventID) + event := rc.getLocalEvent(rc.req.RoomID, rc.req.EventID) if event == nil { event = rc.fetchUnknownEvent(rc.req.EventID, rc.req.RoomID) } @@ -592,7 +592,7 @@ func (rc *reqCtx) remoteEventRelationships(eventID string) *MSC2836EventRelation // lookForEvent returns the event for the event ID given, by trying to query remote servers // if the event ID is unknown via /event_relationships. func (rc *reqCtx) lookForEvent(eventID string) *gomatrixserverlib.HeaderedEvent { - event := rc.getLocalEvent(eventID) + event := rc.getLocalEvent(rc.req.RoomID, eventID) if event == nil { queryRes := rc.remoteEventRelationships(eventID) if queryRes != nil { @@ -622,9 +622,10 @@ func (rc *reqCtx) lookForEvent(eventID string) *gomatrixserverlib.HeaderedEvent return nil } -func (rc *reqCtx) getLocalEvent(eventID string) *gomatrixserverlib.HeaderedEvent { +func (rc *reqCtx) getLocalEvent(roomID, eventID string) *gomatrixserverlib.HeaderedEvent { var queryEventsRes roomserver.QueryEventsByIDResponse err := rc.rsAPI.QueryEventsByID(rc.ctx, &roomserver.QueryEventsByIDRequest{ + RoomID: roomID, EventIDs: []string{eventID}, }, &queryEventsRes) if err != nil { diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index 21838039a..a8d4d2b2c 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -212,6 +212,7 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent( // Finally, work out if there are any more events missing. if len(missingEventIDs) > 0 { eventsReq := &api.QueryEventsByIDRequest{ + RoomID: ev.RoomID(), EventIDs: missingEventIDs, } eventsRes := &api.QueryEventsByIDResponse{} diff --git a/syncapi/routing/memberships.go b/syncapi/routing/memberships.go index 9ffdf513f..8efd77cef 100644 --- a/syncapi/routing/memberships.go +++ b/syncapi/routing/memberships.go @@ -109,7 +109,7 @@ func GetMemberships( } qryRes := &api.QueryEventsByIDResponse{} - if err := rsAPI.QueryEventsByID(req.Context(), &api.QueryEventsByIDRequest{EventIDs: eventIDs}, qryRes); err != nil { + if err := rsAPI.QueryEventsByID(req.Context(), &api.QueryEventsByIDRequest{EventIDs: eventIDs, RoomID: roomID}, qryRes); err != nil { util.GetLogger(req.Context()).WithError(err).Error("rsAPI.QueryEventsByID failed") return jsonerror.InternalServerError() } diff --git a/userapi/storage/tables/stats_table_test.go b/userapi/storage/tables/stats_table_test.go index b088d15cd..969bc5303 100644 --- a/userapi/storage/tables/stats_table_test.go +++ b/userapi/storage/tables/stats_table_test.go @@ -187,8 +187,8 @@ func Test_UserStatistics(t *testing.T) { }) t.Run("Users not active for one/two month", func(t *testing.T) { - mustUpdateDeviceLastSeen(t, ctx, db, "user1", time.Now().AddDate(0, -2, 0)) - mustUpdateDeviceLastSeen(t, ctx, db, "user2", time.Now().AddDate(0, -1, 0)) + mustUpdateDeviceLastSeen(t, ctx, db, "user1", time.Now().AddDate(0, 0, -60)) + mustUpdateDeviceLastSeen(t, ctx, db, "user2", time.Now().AddDate(0, 0, -30)) gotStats, _, err := statsDB.UserStatistics(ctx, nil) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -224,9 +224,9 @@ func Test_UserStatistics(t *testing.T) { - Where account creation and last_seen are > 30 days apart */ t.Run("R30Users tests", func(t *testing.T) { - mustUserUpdateRegistered(t, ctx, db, "user1", time.Now().AddDate(0, -2, 0)) + mustUserUpdateRegistered(t, ctx, db, "user1", time.Now().AddDate(0, 0, -60)) mustUpdateDeviceLastSeen(t, ctx, db, "user1", time.Now()) - mustUserUpdateRegistered(t, ctx, db, "user4", time.Now().AddDate(0, -2, 0)) + mustUserUpdateRegistered(t, ctx, db, "user4", time.Now().AddDate(0, 0, -60)) mustUpdateDeviceLastSeen(t, ctx, db, "user4", time.Now()) startTime := time.Now().AddDate(0, 0, -2) err := statsDB.UpdateUserDailyVisits(ctx, nil, startTime, startTime.Truncate(time.Hour*24))