Merge branch 'main' of github.com:matrix-org/dendrite into gh-pages

This commit is contained in:
Till Faelligen 2023-03-03 10:22:28 +01:00
commit bb838a17c2
No known key found for this signature in database
GPG key ID: 3DF82D8AB9211D4E
51 changed files with 687 additions and 471 deletions

View file

@ -4,7 +4,13 @@ on:
push: push:
branches: branches:
- main - main
paths:
- '**.go' # only execute on changes to go files
- '.github/workflows/**' # or workflow changes
pull_request: pull_request:
paths:
- '**.go'
- '.github/workflows/**'
release: release:
types: [published] types: [published]
workflow_dispatch: workflow_dispatch:

View file

@ -71,10 +71,10 @@ $ ./bin/generate-keys --tls-cert server.crt --tls-key server.key
# Copy and modify the config file - you'll need to set a server name and paths to the keys # Copy and modify the config file - you'll need to set a server name and paths to the keys
# at the very least, along with setting up the database connection strings. # at the very least, along with setting up the database connection strings.
$ cp dendrite-sample.monolith.yaml dendrite.yaml $ cp dendrite-sample.yaml dendrite.yaml
# Build and run the server: # Build and run the server:
$ ./bin/dendrite-monolith-server --tls-cert server.crt --tls-key server.key --config dendrite.yaml $ ./bin/dendrite --tls-cert server.crt --tls-key server.key --config dendrite.yaml
# Create an user account (add -admin for an admin user). # Create an user account (add -admin for an admin user).
# Specify the localpart only, e.g. 'alice' for '@alice:domain.com' # Specify the localpart only, e.g. 'alice' for '@alice:domain.com'

View file

@ -122,6 +122,7 @@ func (s *OutputRoomEventConsumer) onMessage(
if len(output.NewRoomEvent.AddsStateEventIDs) > 0 { if len(output.NewRoomEvent.AddsStateEventIDs) > 0 {
newEventID := output.NewRoomEvent.Event.EventID() newEventID := output.NewRoomEvent.Event.EventID()
eventsReq := &api.QueryEventsByIDRequest{ eventsReq := &api.QueryEventsByIDRequest{
RoomID: output.NewRoomEvent.Event.RoomID(),
EventIDs: make([]string, 0, len(output.NewRoomEvent.AddsStateEventIDs)), EventIDs: make([]string, 0, len(output.NewRoomEvent.AddsStateEventIDs)),
} }
eventsRes := &api.QueryEventsByIDResponse{} eventsRes := &api.QueryEventsByIDResponse{}

View file

@ -57,7 +57,7 @@ func SendRedaction(
} }
} }
ev := roomserverAPI.GetEvent(req.Context(), rsAPI, eventID) ev := roomserverAPI.GetEvent(req.Context(), rsAPI, roomID, eventID)
if ev == nil { if ev == nil {
return util.JSONResponse{ return util.JSONResponse{
Code: 400, Code: 400,

View file

@ -16,6 +16,7 @@ package main
import ( import (
"flag" "flag"
"io/fs"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -30,6 +31,12 @@ import (
) )
var ( var (
unixSocket = flag.String("unix-socket", "",
"EXPERIMENTAL(unstable): The HTTP listening unix socket for the server (disables http[s]-bind-address feature)",
)
unixSocketPermission = flag.Int("unix-socket-permission", 0755,
"EXPERIMENTAL(unstable): The HTTP listening unix socket permission for the server",
)
httpBindAddr = flag.String("http-bind-address", ":8008", "The HTTP listening port for the server") httpBindAddr = flag.String("http-bind-address", ":8008", "The HTTP listening port for the server")
httpsBindAddr = flag.String("https-bind-address", ":8448", "The HTTPS listening port for the server") httpsBindAddr = flag.String("https-bind-address", ":8448", "The HTTPS listening port for the server")
certFile = flag.String("tls-cert", "", "The PEM formatted X509 certificate to use for TLS") certFile = flag.String("tls-cert", "", "The PEM formatted X509 certificate to use for TLS")
@ -38,8 +45,23 @@ var (
func main() { func main() {
cfg := setup.ParseFlags(true) cfg := setup.ParseFlags(true)
httpAddr := config.HTTPAddress("http://" + *httpBindAddr) httpAddr := config.ServerAddress{}
httpsAddr := config.HTTPAddress("https://" + *httpsBindAddr) httpsAddr := config.ServerAddress{}
if *unixSocket == "" {
http, err := config.HTTPAddress("http://" + *httpBindAddr)
if err != nil {
logrus.WithError(err).Fatalf("Failed to parse http address")
}
httpAddr = http
https, err := config.HTTPAddress("https://" + *httpsBindAddr)
if err != nil {
logrus.WithError(err).Fatalf("Failed to parse https address")
}
httpsAddr = https
} else {
httpAddr = config.UnixSocketAddress(*unixSocket, fs.FileMode(*unixSocketPermission))
}
options := []basepkg.BaseDendriteOptions{} options := []basepkg.BaseDendriteOptions{}
base := basepkg.NewBaseDendrite(cfg, options...) base := basepkg.NewBaseDendrite(cfg, options...)
@ -92,7 +114,7 @@ func main() {
base.SetupAndServeHTTP(httpAddr, nil, nil) base.SetupAndServeHTTP(httpAddr, nil, nil)
}() }()
// Handle HTTPS if certificate and key are provided // Handle HTTPS if certificate and key are provided
if *certFile != "" && *keyFile != "" { if *unixSocket == "" && *certFile != "" && *keyFile != "" {
go func() { go func() {
base.SetupAndServeHTTP(httpsAddr, certFile, keyFile) base.SetupAndServeHTTP(httpsAddr, certFile, keyFile)
}() }()

View file

@ -9,7 +9,7 @@ import (
) )
// This is an instrumented main, used when running integration tests (sytest) with code coverage. // This is an instrumented main, used when running integration tests (sytest) with code coverage.
// Compile: go test -c -race -cover -covermode=atomic -o monolith.debug -coverpkg "github.com/matrix-org/..." ./cmd/dendrite-monolith-server // Compile: go test -c -race -cover -covermode=atomic -o monolith.debug -coverpkg "github.com/matrix-org/..." ./cmd/dendrite
// Run the monolith: ./monolith.debug -test.coverprofile=/somewhere/to/dump/integrationcover.out DEVEL --config dendrite.yaml // Run the monolith: ./monolith.debug -test.coverprofile=/somewhere/to/dump/integrationcover.out DEVEL --config dendrite.yaml
// Generate HTML with coverage: go tool cover -html=/somewhere/where/there/is/integrationcover.out -o cover.html // Generate HTML with coverage: go tool cover -html=/somewhere/where/there/is/integrationcover.out -o cover.html
// Source: https://dzone.com/articles/measuring-integration-test-coverage-rate-in-pouchc // Source: https://dzone.com/articles/measuring-integration-test-coverage-rate-in-pouchc

View file

@ -62,9 +62,10 @@ func main() {
panic(err) panic(err)
} }
stateres := state.NewStateResolution(roomserverDB, &types.RoomInfo{ roomInfo := &types.RoomInfo{
RoomVersion: gomatrixserverlib.RoomVersion(*roomVersion), RoomVersion: gomatrixserverlib.RoomVersion(*roomVersion),
}) }
stateres := state.NewStateResolution(roomserverDB, roomInfo)
if *difference { if *difference {
if len(snapshotNIDs) != 2 { if len(snapshotNIDs) != 2 {
@ -87,7 +88,7 @@ func main() {
} }
var eventEntries []types.Event var eventEntries []types.Event
eventEntries, err = roomserverDB.Events(ctx, 0, eventNIDs) eventEntries, err = roomserverDB.Events(ctx, roomInfo, eventNIDs)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -145,7 +146,7 @@ func main() {
} }
fmt.Println("Fetching", len(eventNIDMap), "state events") fmt.Println("Fetching", len(eventNIDMap), "state events")
eventEntries, err := roomserverDB.Events(ctx, 0, eventNIDs) eventEntries, err := roomserverDB.Events(ctx, roomInfo, eventNIDs)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -165,7 +166,7 @@ func main() {
} }
fmt.Println("Fetching", len(authEventIDs), "auth events") fmt.Println("Fetching", len(authEventIDs), "auth events")
authEventEntries, err := roomserverDB.EventsFromIDs(ctx, 0, authEventIDs) authEventEntries, err := roomserverDB.EventsFromIDs(ctx, roomInfo, authEventIDs)
if err != nil { if err != nil {
panic(err) panic(err)
} }

View file

@ -15,7 +15,7 @@ Dendrite contains an embedded profiler called `pprof`, which is a part of the st
To enable the profiler, start Dendrite with the `PPROFLISTEN` environment variable. This variable specifies which address and port to listen on, e.g. To enable the profiler, start Dendrite with the `PPROFLISTEN` environment variable. This variable specifies which address and port to listen on, e.g.
``` ```
PPROFLISTEN=localhost:65432 ./bin/dendrite-monolith-server ... PPROFLISTEN=localhost:65432 ./bin/dendrite ...
``` ```
If pprof has been enabled successfully, a log line at startup will show that pprof is listening: If pprof has been enabled successfully, a log line at startup will show that pprof is listening:

View file

@ -14,8 +14,8 @@ index 8f0e209c..ad057e52 100644
$output->diag( "Starting monolith server" ); $output->diag( "Starting monolith server" );
my @command = ( my @command = (
- $self->{bindir} . '/dendrite-monolith-server', - $self->{bindir} . '/dendrite',
+ $self->{bindir} . '/dendrite-monolith-server', '--test.coverprofile=' . $self->{hs_dir} . '/integrationcover.log', "DEVEL", + $self->{bindir} . '/dendrite', '--test.coverprofile=' . $self->{hs_dir} . '/integrationcover.log', "DEVEL",
'--config', $self->{paths}{config}, '--config', $self->{paths}{config},
'--http-bind-address', $self->{bind_host} . ':' . $self->unsecure_port, '--http-bind-address', $self->{bind_host} . ':' . $self->unsecure_port,
'--https-bind-address', $self->{bind_host} . ':' . $self->secure_port, '--https-bind-address', $self->{bind_host} . ':' . $self->secure_port,
@ -27,9 +27,9 @@ index f009332b..7ea79869 100755
echo >&2 "--- Building dendrite from source" echo >&2 "--- Building dendrite from source"
cd /src cd /src
mkdir -p $GOBIN mkdir -p $GOBIN
-go install -v ./cmd/dendrite-monolith-server -go install -v ./cmd/dendrite
+# go install -v ./cmd/dendrite-monolith-server +# go install -v ./cmd/dendrite
+go test -c -cover -covermode=atomic -o $GOBIN/dendrite-monolith-server -coverpkg "github.com/matrix-org/..." ./cmd/dendrite-monolith-server +go test -c -cover -covermode=atomic -o $GOBIN/dendrite -coverpkg "github.com/matrix-org/..." ./cmd/dendrite
go install -v ./cmd/generate-keys go install -v ./cmd/generate-keys
cd - cd -
``` ```

View file

@ -49,7 +49,7 @@ tracing:
then run the monolith server: then run the monolith server:
``` ```
./dendrite-monolith-server --tls-cert server.crt --tls-key server.key --config dendrite.yaml ./dendrite --tls-cert server.crt --tls-key server.key --config dendrite.yaml
``` ```
## Checking traces ## Checking traces

View file

@ -28,11 +28,11 @@ The resulting binaries will be placed in the `bin` subfolder.
You can install the Dendrite monolith binary into `$GOPATH/bin` by using `go install`: You can install the Dendrite monolith binary into `$GOPATH/bin` by using `go install`:
```sh ```sh
go install ./cmd/dendrite-monolith-server go install ./cmd/dendrite
``` ```
Alternatively, you can specify a custom path for the binary to be written to using `go build`: Alternatively, you can specify a custom path for the binary to be written to using `go build`:
```sh ```sh
go build -o /usr/local/bin/ ./cmd/dendrite-monolith-server go build -o /usr/local/bin/ ./cmd/dendrite
``` ```

View file

@ -11,11 +11,11 @@ permalink: /installation/install/monolith
You can install the Dendrite monolith binary into `$GOPATH/bin` by using `go install`: You can install the Dendrite monolith binary into `$GOPATH/bin` by using `go install`:
```sh ```sh
go install ./cmd/dendrite-monolith-server go install ./cmd/dendrite
``` ```
Alternatively, you can specify a custom path for the binary to be written to using `go build`: Alternatively, you can specify a custom path for the binary to be written to using `go build`:
```sh ```sh
go build -o /usr/local/bin/ ./cmd/dendrite-monolith-server go build -o /usr/local/bin/ ./cmd/dendrite
``` ```

View file

@ -9,10 +9,10 @@ permalink: /installation/start/monolith
# Starting the monolith # Starting the monolith
Once you have completed all of the preparation and installation steps, Once you have completed all of the preparation and installation steps,
you can start your Dendrite monolith deployment by starting the `dendrite-monolith-server`: you can start your Dendrite monolith deployment by starting `dendrite`:
```bash ```bash
./dendrite-monolith-server -config /path/to/dendrite.yaml ./dendrite -config /path/to/dendrite.yaml
``` ```
By default, Dendrite will listen HTTP on port 8008. If you want to change the addresses By default, Dendrite will listen HTTP on port 8008. If you want to change the addresses
@ -20,7 +20,7 @@ or ports that Dendrite listens on, you can use the `-http-bind-address` and
`-https-bind-address` command line arguments: `-https-bind-address` command line arguments:
```bash ```bash
./dendrite-monolith-server -config /path/to/dendrite.yaml \ ./dendrite -config /path/to/dendrite.yaml \
-http-bind-address 1.2.3.4:12345 \ -http-bind-address 1.2.3.4:12345 \
-https-bind-address 1.2.3.4:54321 -https-bind-address 1.2.3.4:54321
``` ```

View file

@ -11,7 +11,7 @@ Type=simple
User=dendrite User=dendrite
Group=dendrite Group=dendrite
WorkingDirectory=/opt/dendrite/ WorkingDirectory=/opt/dendrite/
ExecStart=/opt/dendrite/bin/dendrite-monolith-server ExecStart=/opt/dendrite/bin/dendrite
Restart=always Restart=always
LimitNOFILE=65535 LimitNOFILE=65535

View file

@ -173,6 +173,7 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent, rew
// Finally, work out if there are any more events missing. // Finally, work out if there are any more events missing.
if len(missingEventIDs) > 0 { if len(missingEventIDs) > 0 {
eventsReq := &api.QueryEventsByIDRequest{ eventsReq := &api.QueryEventsByIDRequest{
RoomID: ore.Event.RoomID(),
EventIDs: missingEventIDs, EventIDs: missingEventIDs,
} }
eventsRes := &api.QueryEventsByIDResponse{} eventsRes := &api.QueryEventsByIDResponse{}
@ -483,7 +484,7 @@ func (s *OutputRoomEventConsumer) lookupStateEvents(
// At this point the missing events are neither the event itself nor are // At this point the missing events are neither the event itself nor are
// they present in our local database. Our only option is to fetch them // they present in our local database. Our only option is to fetch them
// from the roomserver using the query API. // from the roomserver using the query API.
eventReq := api.QueryEventsByIDRequest{EventIDs: missing} eventReq := api.QueryEventsByIDRequest{EventIDs: missing, RoomID: event.RoomID()}
var eventResp api.QueryEventsByIDResponse var eventResp api.QueryEventsByIDResponse
if err := s.rsAPI.QueryEventsByID(s.ctx, &eventReq, &eventResp); err != nil { if err := s.rsAPI.QueryEventsByID(s.ctx, &eventReq, &eventResp); err != nil {
return nil, err return nil, err

View file

@ -36,7 +36,7 @@ func GetEventAuth(
return *err return *err
} }
event, resErr := fetchEvent(ctx, rsAPI, eventID) event, resErr := fetchEvent(ctx, rsAPI, roomID, eventID)
if resErr != nil { if resErr != nil {
return *resErr return *resErr
} }

View file

@ -20,10 +20,11 @@ import (
"net/http" "net/http"
"time" "time"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/roomserver/api"
) )
// GetEvent returns the requested event // GetEvent returns the requested event
@ -38,7 +39,9 @@ func GetEvent(
if err != nil { if err != nil {
return *err return *err
} }
event, err := fetchEvent(ctx, rsAPI, eventID) // /_matrix/federation/v1/event/{eventId} doesn't have a roomID, we use an empty string,
// which results in `QueryEventsByID` to first get the event and use that to determine the roomID.
event, err := fetchEvent(ctx, rsAPI, "", eventID)
if err != nil { if err != nil {
return *err return *err
} }
@ -60,21 +63,13 @@ func allowedToSeeEvent(
rsAPI api.FederationRoomserverAPI, rsAPI api.FederationRoomserverAPI,
eventID string, eventID string,
) *util.JSONResponse { ) *util.JSONResponse {
var authResponse api.QueryServerAllowedToSeeEventResponse allowed, err := rsAPI.QueryServerAllowedToSeeEvent(ctx, origin, eventID)
err := rsAPI.QueryServerAllowedToSeeEvent(
ctx,
&api.QueryServerAllowedToSeeEventRequest{
EventID: eventID,
ServerName: origin,
},
&authResponse,
)
if err != nil { if err != nil {
resErr := util.ErrorResponse(err) resErr := util.ErrorResponse(err)
return &resErr return &resErr
} }
if !authResponse.AllowedToSeeEvent { if !allowed {
resErr := util.MessageResponse(http.StatusForbidden, "server not allowed to see event") resErr := util.MessageResponse(http.StatusForbidden, "server not allowed to see event")
return &resErr return &resErr
} }
@ -83,11 +78,11 @@ func allowedToSeeEvent(
} }
// fetchEvent fetches the event without auth checks. Returns an error if the event cannot be found. // fetchEvent fetches the event without auth checks. Returns an error if the event cannot be found.
func fetchEvent(ctx context.Context, rsAPI api.FederationRoomserverAPI, eventID string) (*gomatrixserverlib.Event, *util.JSONResponse) { func fetchEvent(ctx context.Context, rsAPI api.FederationRoomserverAPI, roomID, eventID string) (*gomatrixserverlib.Event, *util.JSONResponse) {
var eventsResponse api.QueryEventsByIDResponse var eventsResponse api.QueryEventsByIDResponse
err := rsAPI.QueryEventsByID( err := rsAPI.QueryEventsByID(
ctx, ctx,
&api.QueryEventsByIDRequest{EventIDs: []string{eventID}}, &api.QueryEventsByIDRequest{EventIDs: []string{eventID}, RoomID: roomID},
&eventsResponse, &eventsResponse,
) )
if err != nil { if err != nil {

View file

@ -107,7 +107,7 @@ func getState(
return nil, nil, err return nil, nil, err
} }
event, resErr := fetchEvent(ctx, rsAPI, eventID) event, resErr := fetchEvent(ctx, rsAPI, roomID, eventID)
if resErr != nil { if resErr != nil {
return nil, nil, resErr return nil, nil, resErr
} }

View file

@ -16,7 +16,9 @@
// Hooks can only be run in monolith mode. // Hooks can only be run in monolith mode.
package hooks package hooks
import "sync" import (
"sync"
)
const ( const (
// KindNewEventPersisted is a hook which is called with *gomatrixserverlib.HeaderedEvent // KindNewEventPersisted is a hook which is called with *gomatrixserverlib.HeaderedEvent

View file

@ -54,7 +54,8 @@ type QueryBulkStateContentAPI interface {
} }
type QueryEventsAPI interface { type QueryEventsAPI interface {
// Query a list of events by event ID. // QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine
// which room to use by querying the first events roomID.
QueryEventsByID( QueryEventsByID(
ctx context.Context, ctx context.Context,
req *QueryEventsByIDRequest, req *QueryEventsByIDRequest,
@ -71,7 +72,8 @@ type SyncRoomserverAPI interface {
QueryBulkStateContentAPI QueryBulkStateContentAPI
// QuerySharedUsers returns a list of users who share at least 1 room in common with the given user. // QuerySharedUsers returns a list of users who share at least 1 room in common with the given user.
QuerySharedUsers(ctx context.Context, req *QuerySharedUsersRequest, res *QuerySharedUsersResponse) error QuerySharedUsers(ctx context.Context, req *QuerySharedUsersRequest, res *QuerySharedUsersResponse) error
// Query a list of events by event ID. // QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine
// which room to use by querying the first events roomID.
QueryEventsByID( QueryEventsByID(
ctx context.Context, ctx context.Context,
req *QueryEventsByIDRequest, req *QueryEventsByIDRequest,
@ -108,7 +110,8 @@ type SyncRoomserverAPI interface {
} }
type AppserviceRoomserverAPI interface { type AppserviceRoomserverAPI interface {
// Query a list of events by event ID. // QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine
// which room to use by querying the first events roomID.
QueryEventsByID( QueryEventsByID(
ctx context.Context, ctx context.Context,
req *QueryEventsByIDRequest, req *QueryEventsByIDRequest,
@ -182,6 +185,8 @@ type FederationRoomserverAPI interface {
QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error
QueryRoomVersionForRoom(ctx context.Context, req *QueryRoomVersionForRoomRequest, res *QueryRoomVersionForRoomResponse) error QueryRoomVersionForRoom(ctx context.Context, req *QueryRoomVersionForRoomRequest, res *QueryRoomVersionForRoomResponse) error
GetRoomIDForAlias(ctx context.Context, req *GetRoomIDForAliasRequest, res *GetRoomIDForAliasResponse) error GetRoomIDForAlias(ctx context.Context, req *GetRoomIDForAliasRequest, res *GetRoomIDForAliasResponse) error
// QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine
// which room to use by querying the first events roomID.
QueryEventsByID(ctx context.Context, req *QueryEventsByIDRequest, res *QueryEventsByIDResponse) error QueryEventsByID(ctx context.Context, req *QueryEventsByIDRequest, res *QueryEventsByIDResponse) error
// Query to get state and auth chain for a (potentially hypothetical) event. // Query to get state and auth chain for a (potentially hypothetical) event.
// Takes lists of PrevEventIDs and AuthEventsIDs and uses them to calculate // Takes lists of PrevEventIDs and AuthEventsIDs and uses them to calculate
@ -193,7 +198,7 @@ type FederationRoomserverAPI interface {
// Query missing events for a room from roomserver // Query missing events for a room from roomserver
QueryMissingEvents(ctx context.Context, req *QueryMissingEventsRequest, res *QueryMissingEventsResponse) error QueryMissingEvents(ctx context.Context, req *QueryMissingEventsRequest, res *QueryMissingEventsResponse) error
// Query whether a server is allowed to see an event // Query whether a server is allowed to see an event
QueryServerAllowedToSeeEvent(ctx context.Context, req *QueryServerAllowedToSeeEventRequest, res *QueryServerAllowedToSeeEventResponse) error QueryServerAllowedToSeeEvent(ctx context.Context, serverName gomatrixserverlib.ServerName, eventID string) (allowed bool, err error)
QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error
QueryRestrictedJoinAllowed(ctx context.Context, req *QueryRestrictedJoinAllowedRequest, res *QueryRestrictedJoinAllowedResponse) error QueryRestrictedJoinAllowed(ctx context.Context, req *QueryRestrictedJoinAllowedRequest, res *QueryRestrictedJoinAllowedResponse) error
PerformInboundPeek(ctx context.Context, req *PerformInboundPeekRequest, res *PerformInboundPeekResponse) error PerformInboundPeek(ctx context.Context, req *PerformInboundPeekRequest, res *PerformInboundPeekResponse) error

View file

@ -86,6 +86,9 @@ type QueryStateAfterEventsResponse struct {
// QueryEventsByIDRequest is a request to QueryEventsByID // QueryEventsByIDRequest is a request to QueryEventsByID
type QueryEventsByIDRequest struct { type QueryEventsByIDRequest struct {
// The roomID to query events for. If this is empty, we first try to fetch the roomID from the database
// as this is needed for further processing/parsing events.
RoomID string `json:"room_id"`
// The event IDs to look up. // The event IDs to look up.
EventIDs []string `json:"event_ids"` EventIDs []string `json:"event_ids"`
} }

View file

@ -108,9 +108,10 @@ func SendInputRoomEvents(
} }
// GetEvent returns the event or nil, even on errors. // GetEvent returns the event or nil, even on errors.
func GetEvent(ctx context.Context, rsAPI QueryEventsAPI, eventID string) *gomatrixserverlib.HeaderedEvent { func GetEvent(ctx context.Context, rsAPI QueryEventsAPI, roomID, eventID string) *gomatrixserverlib.HeaderedEvent {
var res QueryEventsByIDResponse var res QueryEventsByIDResponse
err := rsAPI.QueryEventsByID(ctx, &QueryEventsByIDRequest{ err := rsAPI.QueryEventsByID(ctx, &QueryEventsByIDRequest{
RoomID: roomID,
EventIDs: []string{eventID}, EventIDs: []string{eventID},
}, &res) }, &res)
if err != nil { if err != nil {

View file

@ -67,7 +67,7 @@ func CheckForSoftFail(
stateNeeded := gomatrixserverlib.StateNeededForAuth([]*gomatrixserverlib.Event{event.Unwrap()}) stateNeeded := gomatrixserverlib.StateNeededForAuth([]*gomatrixserverlib.Event{event.Unwrap()})
// Load the actual auth events from the database. // Load the actual auth events from the database.
authEvents, err := loadAuthEvents(ctx, db, roomInfo.RoomNID, stateNeeded, authStateEntries) authEvents, err := loadAuthEvents(ctx, db, roomInfo, stateNeeded, authStateEntries)
if err != nil { if err != nil {
return true, fmt.Errorf("loadAuthEvents: %w", err) return true, fmt.Errorf("loadAuthEvents: %w", err)
} }
@ -85,7 +85,7 @@ func CheckForSoftFail(
func CheckAuthEvents( func CheckAuthEvents(
ctx context.Context, ctx context.Context,
db storage.RoomDatabase, db storage.RoomDatabase,
roomNID types.RoomNID, roomInfo *types.RoomInfo,
event *gomatrixserverlib.HeaderedEvent, event *gomatrixserverlib.HeaderedEvent,
authEventIDs []string, authEventIDs []string,
) ([]types.EventNID, error) { ) ([]types.EventNID, error) {
@ -100,7 +100,7 @@ func CheckAuthEvents(
stateNeeded := gomatrixserverlib.StateNeededForAuth([]*gomatrixserverlib.Event{event.Unwrap()}) stateNeeded := gomatrixserverlib.StateNeededForAuth([]*gomatrixserverlib.Event{event.Unwrap()})
// Load the actual auth events from the database. // Load the actual auth events from the database.
authEvents, err := loadAuthEvents(ctx, db, roomNID, stateNeeded, authStateEntries) authEvents, err := loadAuthEvents(ctx, db, roomInfo, stateNeeded, authStateEntries)
if err != nil { if err != nil {
return nil, fmt.Errorf("loadAuthEvents: %w", err) return nil, fmt.Errorf("loadAuthEvents: %w", err)
} }
@ -193,7 +193,7 @@ func (ae *authEvents) lookupEvent(typeNID types.EventTypeNID, stateKey string) *
func loadAuthEvents( func loadAuthEvents(
ctx context.Context, ctx context.Context,
db state.StateResolutionStorage, db state.StateResolutionStorage,
roomNID types.RoomNID, roomInfo *types.RoomInfo,
needed gomatrixserverlib.StateNeeded, needed gomatrixserverlib.StateNeeded,
state []types.StateEntry, state []types.StateEntry,
) (result authEvents, err error) { ) (result authEvents, err error) {
@ -216,7 +216,7 @@ func loadAuthEvents(
eventNIDs = append(eventNIDs, eventNID) eventNIDs = append(eventNIDs, eventNID)
} }
} }
if result.events, err = db.Events(ctx, roomNID, eventNIDs); err != nil { if result.events, err = db.Events(ctx, roomInfo, eventNIDs); err != nil {
return return
} }
roomID := "" roomID := ""

View file

@ -85,7 +85,7 @@ func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverNam
return false, err return false, err
} }
events, err := db.Events(ctx, info.RoomNID, eventNIDs) events, err := db.Events(ctx, info, eventNIDs)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -157,7 +157,7 @@ func IsInvitePending(
// only keep the "m.room.member" events with a "join" membership. These events are returned. // only keep the "m.room.member" events with a "join" membership. These events are returned.
// Returns an error if there was an issue fetching the events. // Returns an error if there was an issue fetching the events.
func GetMembershipsAtState( func GetMembershipsAtState(
ctx context.Context, db storage.RoomDatabase, roomNID types.RoomNID, stateEntries []types.StateEntry, joinedOnly bool, ctx context.Context, db storage.RoomDatabase, roomInfo *types.RoomInfo, stateEntries []types.StateEntry, joinedOnly bool,
) ([]types.Event, error) { ) ([]types.Event, error) {
var eventNIDs types.EventNIDs var eventNIDs types.EventNIDs
@ -177,7 +177,7 @@ func GetMembershipsAtState(
util.Unique(eventNIDs) util.Unique(eventNIDs)
// Get all of the events in this state // Get all of the events in this state
stateEvents, err := db.Events(ctx, roomNID, eventNIDs) stateEvents, err := db.Events(ctx, roomInfo, eventNIDs)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -227,9 +227,9 @@ func MembershipAtEvent(ctx context.Context, db storage.RoomDatabase, info *types
} }
func LoadEvents( func LoadEvents(
ctx context.Context, db storage.RoomDatabase, roomNID types.RoomNID, eventNIDs []types.EventNID, ctx context.Context, db storage.RoomDatabase, roomInfo *types.RoomInfo, eventNIDs []types.EventNID,
) ([]*gomatrixserverlib.Event, error) { ) ([]*gomatrixserverlib.Event, error) {
stateEvents, err := db.Events(ctx, roomNID, eventNIDs) stateEvents, err := db.Events(ctx, roomInfo, eventNIDs)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -242,13 +242,13 @@ func LoadEvents(
} }
func LoadStateEvents( func LoadStateEvents(
ctx context.Context, db storage.RoomDatabase, roomNID types.RoomNID, stateEntries []types.StateEntry, ctx context.Context, db storage.RoomDatabase, roomInfo *types.RoomInfo, stateEntries []types.StateEntry,
) ([]*gomatrixserverlib.Event, error) { ) ([]*gomatrixserverlib.Event, error) {
eventNIDs := make([]types.EventNID, len(stateEntries)) eventNIDs := make([]types.EventNID, len(stateEntries))
for i := range stateEntries { for i := range stateEntries {
eventNIDs[i] = stateEntries[i].EventNID eventNIDs[i] = stateEntries[i].EventNID
} }
return LoadEvents(ctx, db, roomNID, eventNIDs) return LoadEvents(ctx, db, roomInfo, eventNIDs)
} }
func CheckServerAllowedToSeeEvent( func CheckServerAllowedToSeeEvent(
@ -326,7 +326,7 @@ func slowGetHistoryVisibilityState(
return nil, nil return nil, nil
} }
return LoadStateEvents(ctx, db, info.RoomNID, filteredEntries) return LoadStateEvents(ctx, db, info, filteredEntries)
} }
// TODO: Remove this when we have tests to assert correctness of this function // TODO: Remove this when we have tests to assert correctness of this function
@ -366,7 +366,7 @@ BFSLoop:
next = make([]string, 0) next = make([]string, 0)
} }
// Retrieve the events to process from the database. // Retrieve the events to process from the database.
events, err = db.EventsFromIDs(ctx, info.RoomNID, front) events, err = db.EventsFromIDs(ctx, info, front)
if err != nil { if err != nil {
return resultNIDs, redactEventIDs, err return resultNIDs, redactEventIDs, err
} }
@ -467,7 +467,7 @@ func QueryLatestEventsAndState(
return err return err
} }
stateEvents, err := LoadStateEvents(ctx, db, roomInfo.RoomNID, stateEntries) stateEvents, err := LoadStateEvents(ctx, db, roomInfo, stateEntries)
if err != nil { if err != nil {
return err return err
} }

View file

@ -4,9 +4,10 @@ import (
"context" "context"
"testing" "testing"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test"
@ -38,9 +39,9 @@ func TestIsInvitePendingWithoutNID(t *testing.T) {
var authNIDs []types.EventNID var authNIDs []types.EventNID
for _, x := range room.Events() { for _, x := range room.Events() {
roomNID, err := db.GetOrCreateRoomNID(context.Background(), x.Unwrap()) roomInfo, err := db.GetOrCreateRoomInfo(context.Background(), x.Unwrap())
assert.NoError(t, err) assert.NoError(t, err)
assert.Greater(t, roomNID, types.RoomNID(0)) assert.NotNil(t, roomInfo)
eventTypeNID, err := db.GetOrCreateEventTypeNID(context.Background(), x.Type()) eventTypeNID, err := db.GetOrCreateEventTypeNID(context.Background(), x.Type())
assert.NoError(t, err) assert.NoError(t, err)
@ -49,7 +50,7 @@ func TestIsInvitePendingWithoutNID(t *testing.T) {
eventStateKeyNID, err := db.GetOrCreateEventStateKeyNID(context.Background(), x.StateKey()) eventStateKeyNID, err := db.GetOrCreateEventStateKeyNID(context.Background(), x.StateKey())
assert.NoError(t, err) assert.NoError(t, err)
evNID, _, _, _, err := db.StoreEvent(context.Background(), x.Event, roomNID, eventTypeNID, eventStateKeyNID, authNIDs, false) evNID, _, err := db.StoreEvent(context.Background(), x.Event, roomInfo, eventTypeNID, eventStateKeyNID, authNIDs, false)
assert.NoError(t, err) assert.NoError(t, err)
authNIDs = append(authNIDs, evNID) authNIDs = append(authNIDs, evNID)
} }

View file

@ -24,9 +24,10 @@ import (
"fmt" "fmt"
"time" "time"
"github.com/matrix-org/dendrite/roomserver/internal/helpers"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/matrix-org/dendrite/roomserver/internal/helpers"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/opentracing/opentracing-go" "github.com/opentracing/opentracing-go"
@ -274,8 +275,10 @@ func (r *Inputer) processRoomEvent(
// Check if the event is allowed by its auth events. If it isn't then // Check if the event is allowed by its auth events. If it isn't then
// we consider the event to be "rejected" — it will still be persisted. // we consider the event to be "rejected" — it will still be persisted.
redactAllowed := true
if err = gomatrixserverlib.Allowed(event, &authEvents); err != nil { if err = gomatrixserverlib.Allowed(event, &authEvents); err != nil {
isRejected = true isRejected = true
redactAllowed = false
rejectionErr = err rejectionErr = err
logger.WithError(rejectionErr).Warnf("Event %s not allowed by auth events", event.EventID()) logger.WithError(rejectionErr).Warnf("Event %s not allowed by auth events", event.EventID())
} }
@ -323,7 +326,7 @@ func (r *Inputer) processRoomEvent(
// burning CPU time. // burning CPU time.
historyVisibility := gomatrixserverlib.HistoryVisibilityShared // Default to shared. historyVisibility := gomatrixserverlib.HistoryVisibilityShared // Default to shared.
if input.Kind != api.KindOutlier && rejectionErr == nil && !isRejected && !isCreateEvent { if input.Kind != api.KindOutlier && rejectionErr == nil && !isRejected && !isCreateEvent {
historyVisibility, rejectionErr, err = r.processStateBefore(ctx, roomInfo.RoomNID, input, missingPrev) historyVisibility, rejectionErr, err = r.processStateBefore(ctx, roomInfo, input, missingPrev)
if err != nil { if err != nil {
return fmt.Errorf("r.processStateBefore: %w", err) return fmt.Errorf("r.processStateBefore: %w", err)
} }
@ -332,9 +335,11 @@ func (r *Inputer) processRoomEvent(
} }
} }
roomNID, err := r.DB.GetOrCreateRoomNID(ctx, event) if roomInfo == nil {
roomInfo, err = r.DB.GetOrCreateRoomInfo(ctx, event)
if err != nil { if err != nil {
return fmt.Errorf("r.DB.GetOrCreateRoomNID: %w", err) return fmt.Errorf("r.DB.GetOrCreateRoomInfo: %w", err)
}
} }
eventTypeNID, err := r.DB.GetOrCreateEventTypeNID(ctx, event.Type()) eventTypeNID, err := r.DB.GetOrCreateEventTypeNID(ctx, event.Type())
@ -348,15 +353,24 @@ func (r *Inputer) processRoomEvent(
} }
// Store the event. // Store the event.
_, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, roomNID, eventTypeNID, eventStateKeyNID, authEventNIDs, isRejected) eventNID, stateAtEvent, err := r.DB.StoreEvent(ctx, event, roomInfo, eventTypeNID, eventStateKeyNID, authEventNIDs, isRejected)
if err != nil { if err != nil {
return fmt.Errorf("updater.StoreEvent: %w", err) return fmt.Errorf("updater.StoreEvent: %w", err)
} }
// if storing this event results in it being redacted then do so. // if storing this event results in it being redacted then do so.
if !isRejected && redactedEventID == event.EventID() { var (
if err = eventutil.RedactEvent(redactionEvent, event); err != nil { redactedEventID string
return fmt.Errorf("eventutil.RedactEvent: %w", rerr) redactionEvent *gomatrixserverlib.Event
redactedEvent *gomatrixserverlib.Event
)
if !isRejected && !isCreateEvent {
redactionEvent, redactedEvent, err = r.DB.MaybeRedactEvent(ctx, roomInfo, eventNID, event, redactAllowed)
if err != nil {
return err
}
if redactedEvent != nil {
redactedEventID = redactedEvent.EventID()
} }
} }
@ -489,7 +503,7 @@ func (r *Inputer) handleRemoteRoomUpgrade(ctx context.Context, event *gomatrixse
// nolint:nakedret // nolint:nakedret
func (r *Inputer) processStateBefore( func (r *Inputer) processStateBefore(
ctx context.Context, ctx context.Context,
roomNID types.RoomNID, roomInfo *types.RoomInfo,
input *api.InputRoomEvent, input *api.InputRoomEvent,
missingPrev bool, missingPrev bool,
) (historyVisibility gomatrixserverlib.HistoryVisibility, rejectionErr error, err error) { ) (historyVisibility gomatrixserverlib.HistoryVisibility, rejectionErr error, err error) {
@ -505,7 +519,7 @@ func (r *Inputer) processStateBefore(
case input.HasState: case input.HasState:
// If we're overriding the state then we need to go and retrieve // If we're overriding the state then we need to go and retrieve
// them from the database. It's a hard error if they are missing. // them from the database. It's a hard error if they are missing.
stateEvents, err := r.DB.EventsFromIDs(ctx, roomNID, input.StateEventIDs) stateEvents, err := r.DB.EventsFromIDs(ctx, roomInfo, input.StateEventIDs)
if err != nil { if err != nil {
return "", nil, fmt.Errorf("r.DB.EventsFromIDs: %w", err) return "", nil, fmt.Errorf("r.DB.EventsFromIDs: %w", err)
} }
@ -604,7 +618,7 @@ func (r *Inputer) fetchAuthEvents(
} }
for _, authEventID := range authEventIDs { for _, authEventID := range authEventIDs {
authEvents, err := r.DB.EventsFromIDs(ctx, roomInfo.RoomNID, []string{authEventID}) authEvents, err := r.DB.EventsFromIDs(ctx, roomInfo, []string{authEventID})
if err != nil || len(authEvents) == 0 || authEvents[0].Event == nil { if err != nil || len(authEvents) == 0 || authEvents[0].Event == nil {
unknown[authEventID] = struct{}{} unknown[authEventID] = struct{}{}
continue continue
@ -690,9 +704,11 @@ nextAuthEvent:
logger.WithError(err).Warnf("Auth event %s rejected", authEvent.EventID()) logger.WithError(err).Warnf("Auth event %s rejected", authEvent.EventID())
} }
roomNID, err := r.DB.GetOrCreateRoomNID(ctx, authEvent) if roomInfo == nil {
roomInfo, err = r.DB.GetOrCreateRoomInfo(ctx, authEvent)
if err != nil { if err != nil {
return fmt.Errorf("r.DB.GetOrCreateRoomNID: %w", err) return fmt.Errorf("r.DB.GetOrCreateRoomInfo: %w", err)
}
} }
eventTypeNID, err := r.DB.GetOrCreateEventTypeNID(ctx, authEvent.Type()) eventTypeNID, err := r.DB.GetOrCreateEventTypeNID(ctx, authEvent.Type())
@ -706,7 +722,7 @@ nextAuthEvent:
} }
// Finally, store the event in the database. // Finally, store the event in the database.
eventNID, _, _, _, err := r.DB.StoreEvent(ctx, authEvent, roomNID, eventTypeNID, eventStateKeyNID, authEventNIDs, isRejected) eventNID, _, err := r.DB.StoreEvent(ctx, authEvent, roomInfo, eventTypeNID, eventStateKeyNID, authEventNIDs, isRejected)
if err != nil { if err != nil {
return fmt.Errorf("updater.StoreEvent: %w", err) return fmt.Errorf("updater.StoreEvent: %w", err)
} }
@ -782,7 +798,7 @@ func (r *Inputer) kickGuests(ctx context.Context, event *gomatrixserverlib.Event
return err return err
} }
memberEvents, err := r.DB.Events(ctx, roomInfo.RoomNID, membershipNIDs) memberEvents, err := r.DB.Events(ctx, roomInfo, membershipNIDs)
if err != nil { if err != nil {
return err return err
} }

View file

@ -53,7 +53,7 @@ func (r *Inputer) updateMemberships(
// Load the event JSON so we can look up the "membership" key. // Load the event JSON so we can look up the "membership" key.
// TODO: Maybe add a membership key to the events table so we can load that // TODO: Maybe add a membership key to the events table so we can load that
// key without having to load the entire event JSON? // key without having to load the entire event JSON?
events, err := updater.Events(ctx, 0, eventNIDs) events, err := updater.Events(ctx, nil, eventNIDs)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -395,7 +395,7 @@ func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, even
for _, entry := range stateEntries { for _, entry := range stateEntries {
stateEventNIDs = append(stateEventNIDs, entry.EventNID) stateEventNIDs = append(stateEventNIDs, entry.EventNID)
} }
stateEvents, err := t.db.Events(ctx, t.roomInfo.RoomNID, stateEventNIDs) stateEvents, err := t.db.Events(ctx, t.roomInfo, stateEventNIDs)
if err != nil { if err != nil {
t.log.WithError(err).Warnf("failed to load state events locally") t.log.WithError(err).Warnf("failed to load state events locally")
return nil return nil
@ -432,7 +432,7 @@ func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, even
missingEventList = append(missingEventList, evID) missingEventList = append(missingEventList, evID)
} }
t.log.WithField("count", len(missingEventList)).Debugf("Fetching missing auth events") t.log.WithField("count", len(missingEventList)).Debugf("Fetching missing auth events")
events, err := t.db.EventsFromIDs(ctx, t.roomInfo.RoomNID, missingEventList) events, err := t.db.EventsFromIDs(ctx, t.roomInfo, missingEventList)
if err != nil { if err != nil {
return nil return nil
} }
@ -702,7 +702,7 @@ func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roo
} }
t.haveEventsMutex.Unlock() t.haveEventsMutex.Unlock()
events, err := t.db.EventsFromIDs(ctx, t.roomInfo.RoomNID, missingEventList) events, err := t.db.EventsFromIDs(ctx, t.roomInfo, missingEventList)
if err != nil { if err != nil {
return nil, fmt.Errorf("t.db.EventsFromIDs: %w", err) return nil, fmt.Errorf("t.db.EventsFromIDs: %w", err)
} }
@ -844,7 +844,7 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs
if localFirst { if localFirst {
// fetch from the roomserver // fetch from the roomserver
events, err := t.db.EventsFromIDs(ctx, t.roomInfo.RoomNID, []string{missingEventID}) events, err := t.db.EventsFromIDs(ctx, t.roomInfo, []string{missingEventID})
if err != nil { if err != nil {
t.log.Warnf("Failed to query roomserver for missing event %s: %s - falling back to remote", missingEventID, err) t.log.Warnf("Failed to query roomserver for missing event %s: %s - falling back to remote", missingEventID, err)
} else if len(events) == 1 { } else if len(events) == 1 {

View file

@ -70,7 +70,7 @@ func (r *Admin) PerformAdminEvacuateRoom(
return nil return nil
} }
memberEvents, err := r.DB.Events(ctx, roomInfo.RoomNID, memberNIDs) memberEvents, err := r.DB.Events(ctx, roomInfo, memberNIDs)
if err != nil { if err != nil {
res.Error = &api.PerformError{ res.Error = &api.PerformError{
Code: api.PerformErrorBadRequest, Code: api.PerformErrorBadRequest,

View file

@ -23,7 +23,6 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
federationAPI "github.com/matrix-org/dendrite/federationapi/api" federationAPI "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/auth" "github.com/matrix-org/dendrite/roomserver/auth"
"github.com/matrix-org/dendrite/roomserver/internal/helpers" "github.com/matrix-org/dendrite/roomserver/internal/helpers"
@ -86,7 +85,7 @@ func (r *Backfiller) PerformBackfill(
// Retrieve events from the list that was filled previously. If we fail to get // Retrieve events from the list that was filled previously. If we fail to get
// events from the database then attempt once to get them from federation instead. // events from the database then attempt once to get them from federation instead.
var loadedEvents []*gomatrixserverlib.Event var loadedEvents []*gomatrixserverlib.Event
loadedEvents, err = helpers.LoadEvents(ctx, r.DB, info.RoomNID, resultNIDs) loadedEvents, err = helpers.LoadEvents(ctx, r.DB, info, resultNIDs)
if err != nil { if err != nil {
if _, ok := err.(types.MissingEventError); ok { if _, ok := err.(types.MissingEventError); ok {
return r.backfillViaFederation(ctx, request, response) return r.backfillViaFederation(ctx, request, response)
@ -473,7 +472,7 @@ FindSuccessor:
// Retrieve all "m.room.member" state events of "join" membership, which // Retrieve all "m.room.member" state events of "join" membership, which
// contains the list of users in the room before the event, therefore all // contains the list of users in the room before the event, therefore all
// the servers in it at that moment. // the servers in it at that moment.
memberEvents, err := helpers.GetMembershipsAtState(ctx, b.db, info.RoomNID, stateEntries, true) memberEvents, err := helpers.GetMembershipsAtState(ctx, b.db, info, stateEntries, true)
if err != nil { if err != nil {
logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to get memberships before event") logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to get memberships before event")
return nil return nil
@ -532,7 +531,7 @@ func (b *backfillRequester) ProvideEvents(roomVer gomatrixserverlib.RoomVersion,
roomNID = nid.RoomNID roomNID = nid.RoomNID
} }
} }
eventsWithNids, err := b.db.Events(ctx, roomNID, eventNIDs) eventsWithNids, err := b.db.Events(ctx, &b.roomInfo, eventNIDs)
if err != nil { if err != nil {
logrus.WithError(err).WithField("event_nids", eventNIDs).Error("Failed to load events") logrus.WithError(err).WithField("event_nids", eventNIDs).Error("Failed to load events")
return nil, err return nil, err
@ -562,7 +561,7 @@ func joinEventsFromHistoryVisibility(
} }
// Get all of the events in this state // Get all of the events in this state
stateEvents, err := db.Events(ctx, roomInfo.RoomNID, eventNIDs) stateEvents, err := db.Events(ctx, roomInfo, eventNIDs)
if err != nil { if err != nil {
// even though the default should be shared, restricting the visibility to joined // even though the default should be shared, restricting the visibility to joined
// feels more secure here. // feels more secure here.
@ -585,7 +584,7 @@ func joinEventsFromHistoryVisibility(
if err != nil { if err != nil {
return nil, visibility, err return nil, visibility, err
} }
evs, err := db.Events(ctx, roomInfo.RoomNID, joinEventNIDs) evs, err := db.Events(ctx, roomInfo, joinEventNIDs)
return evs, visibility, err return evs, visibility, err
} }
@ -606,7 +605,7 @@ func persistEvents(ctx context.Context, db storage.Database, events []*gomatrixs
i++ i++
} }
roomNID, err = db.GetOrCreateRoomNID(ctx, ev.Unwrap()) roomInfo, err := db.GetOrCreateRoomInfo(ctx, ev.Unwrap())
if err != nil { if err != nil {
logrus.WithError(err).Error("failed to get or create roomNID") logrus.WithError(err).Error("failed to get or create roomNID")
continue continue
@ -624,23 +623,22 @@ func persistEvents(ctx context.Context, db storage.Database, events []*gomatrixs
continue continue
} }
var redactedEventID string eventNID, _, err = db.StoreEvent(ctx, ev.Unwrap(), roomInfo, eventTypeNID, eventStateKeyNID, authNids, false)
var redactionEvent *gomatrixserverlib.Event
eventNID, _, redactionEvent, redactedEventID, err = db.StoreEvent(ctx, ev.Unwrap(), roomNID, eventTypeNID, eventStateKeyNID, authNids, false)
if err != nil { if err != nil {
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to persist event") logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to persist event")
continue continue
} }
_, redactedEvent, err := db.MaybeRedactEvent(ctx, roomInfo, eventNID, ev.Unwrap(), true)
if err != nil {
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to redact event")
continue
}
// If storing this event results in it being redacted, then do so. // If storing this event results in it being redacted, then do so.
// It's also possible for this event to be a redaction which results in another event being // It's also possible for this event to be a redaction which results in another event being
// redacted, which we don't care about since we aren't returning it in this backfill. // redacted, which we don't care about since we aren't returning it in this backfill.
if redactedEventID == ev.EventID() { if redactedEvent != nil && redactedEvent.EventID() == ev.EventID() {
eventToRedact := ev.Unwrap() ev = redactedEvent.Headered(ev.RoomVersion)
if err := eventutil.RedactEvent(redactionEvent, eventToRedact); err != nil {
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to redact event")
continue
}
ev = eventToRedact.Headered(ev.RoomVersion)
events[j] = ev events[j] = ev
} }
backfilledEventMap[ev.EventID()] = types.Event{ backfilledEventMap[ev.EventID()] = types.Event{

View file

@ -64,7 +64,7 @@ func (r *InboundPeeker) PerformInboundPeek(
if err != nil { if err != nil {
return err return err
} }
latestEvents, err := r.DB.EventsFromIDs(ctx, info.RoomNID, []string{latestEventRefs[0].EventID}) latestEvents, err := r.DB.EventsFromIDs(ctx, info, []string{latestEventRefs[0].EventID})
if err != nil { if err != nil {
return err return err
} }
@ -88,7 +88,7 @@ func (r *InboundPeeker) PerformInboundPeek(
if err != nil { if err != nil {
return err return err
} }
stateEvents, err = helpers.LoadStateEvents(ctx, r.DB, info.RoomNID, stateEntries) stateEvents, err = helpers.LoadStateEvents(ctx, r.DB, info, stateEntries)
if err != nil { if err != nil {
return err return err
} }
@ -100,7 +100,7 @@ func (r *InboundPeeker) PerformInboundPeek(
} }
authEventIDs = util.UniqueStrings(authEventIDs) // de-dupe authEventIDs = util.UniqueStrings(authEventIDs) // de-dupe
authEvents, err := query.GetAuthChain(ctx, r.DB.EventsFromIDs, authEventIDs) authEvents, err := query.GetAuthChain(ctx, r.DB.EventsFromIDs, info, authEventIDs)
if err != nil { if err != nil {
return err return err
} }

View file

@ -194,7 +194,7 @@ func (r *Inviter) PerformInvite(
// try and see if the user is allowed to make this invite. We can't do // try and see if the user is allowed to make this invite. We can't do
// this for invites coming in over federation - we have to take those on // this for invites coming in over federation - we have to take those on
// trust. // trust.
_, err = helpers.CheckAuthEvents(ctx, r.DB, info.RoomNID, event, event.AuthEventIDs()) _, err = helpers.CheckAuthEvents(ctx, r.DB, info, event, event.AuthEventIDs())
if err != nil { if err != nil {
logger.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", event.AuthEventIDs()).Error( logger.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", event.AuthEventIDs()).Error(
"processInviteEvent.checkAuthEvents failed for event", "processInviteEvent.checkAuthEvents failed for event",
@ -291,7 +291,7 @@ func buildInviteStrippedState(
for _, stateNID := range stateEntries { for _, stateNID := range stateEntries {
stateNIDs = append(stateNIDs, stateNID.EventNID) stateNIDs = append(stateNIDs, stateNID.EventNID)
} }
stateEvents, err := db.Events(ctx, info.RoomNID, stateNIDs) stateEvents, err := db.Events(ctx, info, stateNIDs)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -21,11 +21,12 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/roomserver/acls" "github.com/matrix-org/dendrite/roomserver/acls"
@ -102,7 +103,7 @@ func (r *Queryer) QueryStateAfterEvents(
return err return err
} }
stateEvents, err := helpers.LoadStateEvents(ctx, r.DB, info.RoomNID, stateEntries) stateEvents, err := helpers.LoadStateEvents(ctx, r.DB, info, stateEntries)
if err != nil { if err != nil {
return err return err
} }
@ -114,7 +115,7 @@ func (r *Queryer) QueryStateAfterEvents(
} }
authEventIDs = util.UniqueStrings(authEventIDs) authEventIDs = util.UniqueStrings(authEventIDs)
authEvents, err := GetAuthChain(ctx, r.DB.EventsFromIDs, authEventIDs) authEvents, err := GetAuthChain(ctx, r.DB.EventsFromIDs, info, authEventIDs)
if err != nil { if err != nil {
return fmt.Errorf("getAuthChain: %w", err) return fmt.Errorf("getAuthChain: %w", err)
} }
@ -132,24 +133,46 @@ func (r *Queryer) QueryStateAfterEvents(
return nil return nil
} }
// QueryEventsByID implements api.RoomserverInternalAPI // QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine
// which room to use by querying the first events roomID.
func (r *Queryer) QueryEventsByID( func (r *Queryer) QueryEventsByID(
ctx context.Context, ctx context.Context,
request *api.QueryEventsByIDRequest, request *api.QueryEventsByIDRequest,
response *api.QueryEventsByIDResponse, response *api.QueryEventsByIDResponse,
) error { ) error {
events, err := r.DB.EventsFromIDs(ctx, 0, request.EventIDs) if len(request.EventIDs) == 0 {
return nil
}
var err error
// We didn't receive a room ID, we need to fetch it first before we can continue.
// This happens for e.g. ` /_matrix/federation/v1/event/{eventId}`
var roomInfo *types.RoomInfo
if request.RoomID == "" {
var eventNIDs map[string]types.EventMetadata
eventNIDs, err = r.DB.EventNIDs(ctx, []string{request.EventIDs[0]})
if err != nil {
return err
}
if len(eventNIDs) == 0 {
return nil
}
roomInfo, err = r.DB.RoomInfoByNID(ctx, eventNIDs[request.EventIDs[0]].RoomNID)
} else {
roomInfo, err = r.DB.RoomInfo(ctx, request.RoomID)
}
if err != nil {
return err
}
if roomInfo == nil {
return nil
}
events, err := r.DB.EventsFromIDs(ctx, roomInfo, request.EventIDs)
if err != nil { if err != nil {
return err return err
} }
for _, event := range events { for _, event := range events {
roomVersion, verr := r.roomVersion(event.RoomID()) response.Events = append(response.Events, event.Headered(roomInfo.RoomVersion))
if verr != nil {
return verr
}
response.Events = append(response.Events, event.Headered(roomVersion))
} }
return nil return nil
@ -186,7 +209,7 @@ func (r *Queryer) QueryMembershipForUser(
response.IsInRoom = stillInRoom response.IsInRoom = stillInRoom
response.HasBeenInRoom = true response.HasBeenInRoom = true
evs, err := r.DB.Events(ctx, info.RoomNID, []types.EventNID{membershipEventNID}) evs, err := r.DB.Events(ctx, info, []types.EventNID{membershipEventNID})
if err != nil { if err != nil {
return err return err
} }
@ -268,10 +291,10 @@ func (r *Queryer) QueryMembershipAtEvent(
// once. If we have more than one membership event, we need to get the state for each state entry. // once. If we have more than one membership event, we need to get the state for each state entry.
if canShortCircuit { if canShortCircuit {
if len(memberships) == 0 { if len(memberships) == 0 {
memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, info.RoomNID, stateEntry, false) memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, info, stateEntry, false)
} }
} else { } else {
memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, info.RoomNID, stateEntry, false) memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, info, stateEntry, false)
} }
if err != nil { if err != nil {
return fmt.Errorf("unable to get memberships at state: %w", err) return fmt.Errorf("unable to get memberships at state: %w", err)
@ -318,7 +341,7 @@ func (r *Queryer) QueryMembershipsForRoom(
} }
return fmt.Errorf("r.DB.GetMembershipEventNIDsForRoom: %w", err) return fmt.Errorf("r.DB.GetMembershipEventNIDsForRoom: %w", err)
} }
events, err = r.DB.Events(ctx, info.RoomNID, eventNIDs) events, err = r.DB.Events(ctx, info, eventNIDs)
if err != nil { if err != nil {
return fmt.Errorf("r.DB.Events: %w", err) return fmt.Errorf("r.DB.Events: %w", err)
} }
@ -357,14 +380,14 @@ func (r *Queryer) QueryMembershipsForRoom(
return err return err
} }
events, err = r.DB.Events(ctx, info.RoomNID, eventNIDs) events, err = r.DB.Events(ctx, info, eventNIDs)
} else { } else {
stateEntries, err = helpers.StateBeforeEvent(ctx, r.DB, info, membershipEventNID) stateEntries, err = helpers.StateBeforeEvent(ctx, r.DB, info, membershipEventNID)
if err != nil { if err != nil {
logrus.WithField("membership_event_nid", membershipEventNID).WithError(err).Error("failed to load state before event") logrus.WithField("membership_event_nid", membershipEventNID).WithError(err).Error("failed to load state before event")
return err return err
} }
events, err = helpers.GetMembershipsAtState(ctx, r.DB, info.RoomNID, stateEntries, request.JoinedOnly) events, err = helpers.GetMembershipsAtState(ctx, r.DB, info, stateEntries, request.JoinedOnly)
} }
if err != nil { if err != nil {
@ -412,39 +435,39 @@ func (r *Queryer) QueryServerJoinedToRoom(
// QueryServerAllowedToSeeEvent implements api.RoomserverInternalAPI // QueryServerAllowedToSeeEvent implements api.RoomserverInternalAPI
func (r *Queryer) QueryServerAllowedToSeeEvent( func (r *Queryer) QueryServerAllowedToSeeEvent(
ctx context.Context, ctx context.Context,
request *api.QueryServerAllowedToSeeEventRequest, serverName gomatrixserverlib.ServerName,
response *api.QueryServerAllowedToSeeEventResponse, eventID string,
) (err error) { ) (allowed bool, err error) {
events, err := r.DB.EventsFromIDs(ctx, 0, []string{request.EventID}) events, err := r.DB.EventNIDs(ctx, []string{eventID})
if err != nil { if err != nil {
return return
} }
if len(events) == 0 { if len(events) == 0 {
response.AllowedToSeeEvent = false // event doesn't exist so not allowed to see return allowed, nil
return
} }
roomID := events[0].RoomID() info, err := r.DB.RoomInfoByNID(ctx, events[eventID].RoomNID)
inRoomReq := &api.QueryServerJoinedToRoomRequest{
RoomID: roomID,
ServerName: request.ServerName,
}
inRoomRes := &api.QueryServerJoinedToRoomResponse{}
if err = r.QueryServerJoinedToRoom(ctx, inRoomReq, inRoomRes); err != nil {
return fmt.Errorf("r.Queryer.QueryServerJoinedToRoom: %w", err)
}
info, err := r.DB.RoomInfo(ctx, roomID)
if err != nil { if err != nil {
return err return allowed, err
} }
if info == nil || info.IsStub() { if info == nil || info.IsStub() {
return nil return allowed, nil
} }
response.AllowedToSeeEvent, err = helpers.CheckServerAllowedToSeeEvent( var isInRoom bool
ctx, r.DB, info, request.EventID, request.ServerName, inRoomRes.IsInRoom, if r.IsLocalServerName(serverName) || serverName == "" {
isInRoom, err = r.DB.GetLocalServerInRoom(ctx, info.RoomNID)
if err != nil {
return allowed, fmt.Errorf("r.DB.GetLocalServerInRoom: %w", err)
}
} else {
isInRoom, err = r.DB.GetServerInRoom(ctx, info.RoomNID, serverName)
if err != nil {
return allowed, fmt.Errorf("r.DB.GetServerInRoom: %w", err)
}
}
return helpers.CheckServerAllowedToSeeEvent(
ctx, r.DB, info, eventID, serverName, isInRoom,
) )
return
} }
// QueryMissingEvents implements api.RoomserverInternalAPI // QueryMissingEvents implements api.RoomserverInternalAPI
@ -466,19 +489,22 @@ func (r *Queryer) QueryMissingEvents(
eventsToFilter[id] = true eventsToFilter[id] = true
} }
} }
events, err := r.DB.EventsFromIDs(ctx, 0, front) if len(front) == 0 {
return nil // no events to query, give up.
}
events, err := r.DB.EventNIDs(ctx, []string{front[0]})
if err != nil { if err != nil {
return err return err
} }
if len(events) == 0 { if len(events) == 0 {
return nil // we are missing the events being asked to search from, give up. return nil // we are missing the events being asked to search from, give up.
} }
info, err := r.DB.RoomInfo(ctx, events[0].RoomID()) info, err := r.DB.RoomInfoByNID(ctx, events[front[0]].RoomNID)
if err != nil { if err != nil {
return err return err
} }
if info == nil || info.IsStub() { if info == nil || info.IsStub() {
return fmt.Errorf("missing RoomInfo for room %s", events[0].RoomID()) return fmt.Errorf("missing RoomInfo for room %d", events[front[0]].RoomNID)
} }
resultNIDs, redactEventIDs, err := helpers.ScanEventTree(ctx, r.DB, info, front, visited, request.Limit, request.ServerName) resultNIDs, redactEventIDs, err := helpers.ScanEventTree(ctx, r.DB, info, front, visited, request.Limit, request.ServerName)
@ -486,7 +512,7 @@ func (r *Queryer) QueryMissingEvents(
return err return err
} }
loadedEvents, err := helpers.LoadEvents(ctx, r.DB, info.RoomNID, resultNIDs) loadedEvents, err := helpers.LoadEvents(ctx, r.DB, info, resultNIDs)
if err != nil { if err != nil {
return err return err
} }
@ -529,7 +555,7 @@ func (r *Queryer) QueryStateAndAuthChain(
// TODO: this probably means it should be a different query operation... // TODO: this probably means it should be a different query operation...
if request.OnlyFetchAuthChain { if request.OnlyFetchAuthChain {
var authEvents []*gomatrixserverlib.Event var authEvents []*gomatrixserverlib.Event
authEvents, err = GetAuthChain(ctx, r.DB.EventsFromIDs, request.AuthEventIDs) authEvents, err = GetAuthChain(ctx, r.DB.EventsFromIDs, info, request.AuthEventIDs)
if err != nil { if err != nil {
return err return err
} }
@ -556,7 +582,7 @@ func (r *Queryer) QueryStateAndAuthChain(
} }
authEventIDs = util.UniqueStrings(authEventIDs) // de-dupe authEventIDs = util.UniqueStrings(authEventIDs) // de-dupe
authEvents, err := GetAuthChain(ctx, r.DB.EventsFromIDs, authEventIDs) authEvents, err := GetAuthChain(ctx, r.DB.EventsFromIDs, info, authEventIDs)
if err != nil { if err != nil {
return err return err
} }
@ -611,18 +637,18 @@ func (r *Queryer) loadStateAtEventIDs(ctx context.Context, roomInfo *types.RoomI
return nil, rejected, false, err return nil, rejected, false, err
} }
events, err := helpers.LoadStateEvents(ctx, r.DB, roomInfo.RoomNID, stateEntries) events, err := helpers.LoadStateEvents(ctx, r.DB, roomInfo, stateEntries)
return events, rejected, false, err return events, rejected, false, err
} }
type eventsFromIDs func(context.Context, types.RoomNID, []string) ([]types.Event, error) type eventsFromIDs func(context.Context, *types.RoomInfo, []string) ([]types.Event, error)
// GetAuthChain fetches the auth chain for the given auth events. An auth chain // GetAuthChain fetches the auth chain for the given auth events. An auth chain
// is the list of all events that are referenced in the auth_events section, and // is the list of all events that are referenced in the auth_events section, and
// all their auth_events, recursively. The returned set of events contain the // all their auth_events, recursively. The returned set of events contain the
// given events. Will *not* error if we don't have all auth events. // given events. Will *not* error if we don't have all auth events.
func GetAuthChain( func GetAuthChain(
ctx context.Context, fn eventsFromIDs, authEventIDs []string, ctx context.Context, fn eventsFromIDs, roomInfo *types.RoomInfo, authEventIDs []string,
) ([]*gomatrixserverlib.Event, error) { ) ([]*gomatrixserverlib.Event, error) {
// List of event IDs to fetch. On each pass, these events will be requested // List of event IDs to fetch. On each pass, these events will be requested
// from the database and the `eventsToFetch` will be updated with any new // from the database and the `eventsToFetch` will be updated with any new
@ -633,7 +659,7 @@ func GetAuthChain(
for len(eventsToFetch) > 0 { for len(eventsToFetch) > 0 {
// Try to retrieve the events from the database. // Try to retrieve the events from the database.
events, err := fn(ctx, 0, eventsToFetch) events, err := fn(ctx, roomInfo, eventsToFetch)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -852,7 +878,7 @@ func (r *Queryer) QueryServerBannedFromRoom(ctx context.Context, req *api.QueryS
} }
func (r *Queryer) QueryAuthChain(ctx context.Context, req *api.QueryAuthChainRequest, res *api.QueryAuthChainResponse) error { func (r *Queryer) QueryAuthChain(ctx context.Context, req *api.QueryAuthChainRequest, res *api.QueryAuthChainResponse) error {
chain, err := GetAuthChain(ctx, r.DB.EventsFromIDs, req.EventIDs) chain, err := GetAuthChain(ctx, r.DB.EventsFromIDs, nil, req.EventIDs)
if err != nil { if err != nil {
return err return err
} }
@ -971,7 +997,7 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, req *api.Query
// For each of the joined users, let's see if we can get a valid // For each of the joined users, let's see if we can get a valid
// membership event. // membership event.
for _, joinNID := range joinNIDs { for _, joinNID := range joinNIDs {
events, err := r.DB.Events(ctx, roomInfo.RoomNID, []types.EventNID{joinNID}) events, err := r.DB.Events(ctx, roomInfo, []types.EventNID{joinNID})
if err != nil || len(events) != 1 { if err != nil || len(events) != 1 {
continue continue
} }

View file

@ -80,7 +80,7 @@ func (db *getEventDB) addFakeEvents(graph map[string][]string) error {
} }
// EventsFromIDs implements RoomserverInternalAPIEventDB // EventsFromIDs implements RoomserverInternalAPIEventDB
func (db *getEventDB) EventsFromIDs(ctx context.Context, roomNID types.RoomNID, eventIDs []string) (res []types.Event, err error) { func (db *getEventDB) EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) (res []types.Event, err error) {
for _, evID := range eventIDs { for _, evID := range eventIDs {
res = append(res, types.Event{ res = append(res, types.Event{
EventNID: 0, EventNID: 0,
@ -106,7 +106,7 @@ func TestGetAuthChainSingle(t *testing.T) {
t.Fatalf("Failed to add events to db: %v", err) t.Fatalf("Failed to add events to db: %v", err)
} }
result, err := GetAuthChain(context.TODO(), db.EventsFromIDs, []string{"e"}) result, err := GetAuthChain(context.TODO(), db.EventsFromIDs, nil, []string{"e"})
if err != nil { if err != nil {
t.Fatalf("getAuthChain failed: %v", err) t.Fatalf("getAuthChain failed: %v", err)
} }
@ -139,7 +139,7 @@ func TestGetAuthChainMultiple(t *testing.T) {
t.Fatalf("Failed to add events to db: %v", err) t.Fatalf("Failed to add events to db: %v", err)
} }
result, err := GetAuthChain(context.TODO(), db.EventsFromIDs, []string{"e", "f"}) result, err := GetAuthChain(context.TODO(), db.EventsFromIDs, nil, []string{"e", "f"})
if err != nil { if err != nil {
t.Fatalf("getAuthChain failed: %v", err) t.Fatalf("getAuthChain failed: %v", err)
} }

View file

@ -278,6 +278,16 @@ func TestPurgeRoom(t *testing.T) {
if roomInfo == nil { if roomInfo == nil {
t.Fatalf("room does not exist") t.Fatalf("room does not exist")
} }
//
roomInfo2, err := db.RoomInfoByNID(ctx, roomInfo.RoomNID)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(roomInfo, roomInfo2) {
t.Fatalf("expected roomInfos to be the same, but they aren't")
}
// remember the roomInfo before purging // remember the roomInfo before purging
existingRoomInfo := roomInfo existingRoomInfo := roomInfo
@ -333,6 +343,10 @@ func TestPurgeRoom(t *testing.T) {
if roomInfo != nil { if roomInfo != nil {
t.Fatalf("room should not exist after purging: %+v", roomInfo) t.Fatalf("room should not exist after purging: %+v", roomInfo)
} }
roomInfo2, err = db.RoomInfoByNID(ctx, existingRoomInfo.RoomNID)
if err == nil {
t.Fatalf("expected room to not exist, but it does: %#v", roomInfo2)
}
// validation below // validation below

View file

@ -41,8 +41,8 @@ type StateResolutionStorage interface {
StateEntriesForTuples(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error) StateEntriesForTuples(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error)
StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error) StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error)
AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error) AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error)
Events(ctx context.Context, roomNID types.RoomNID, eventNIDs []types.EventNID) ([]types.Event, error) Events(ctx context.Context, roomInfo *types.RoomInfo, eventNIDs []types.EventNID) ([]types.Event, error)
EventsFromIDs(ctx context.Context, roomNID types.RoomNID, eventIDs []string) ([]types.Event, error) EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error)
} }
type StateResolution struct { type StateResolution struct {
@ -975,7 +975,7 @@ func (v *StateResolution) resolveConflictsV2(
// Store the newly found auth events in the auth set for this event. // Store the newly found auth events in the auth set for this event.
var authEventMap map[string]types.StateEntry var authEventMap map[string]types.StateEntry
authSets[key], authEventMap, err = loader.loadAuthEvents(sctx, v.roomInfo.RoomNID, conflictedEvent, knownAuthEvents) authSets[key], authEventMap, err = loader.loadAuthEvents(sctx, v.roomInfo, conflictedEvent, knownAuthEvents)
if err != nil { if err != nil {
return err return err
} }
@ -1091,7 +1091,7 @@ func (v *StateResolution) loadStateEvents(
eventNIDs = append(eventNIDs, entry.EventNID) eventNIDs = append(eventNIDs, entry.EventNID)
} }
} }
events, err := v.db.Events(ctx, v.roomInfo.RoomNID, eventNIDs) events, err := v.db.Events(ctx, v.roomInfo, eventNIDs)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -1120,7 +1120,7 @@ type authEventLoader struct {
// loadAuthEvents loads all of the auth events for a given event recursively, // loadAuthEvents loads all of the auth events for a given event recursively,
// along with a map that contains state entries for all of the auth events. // along with a map that contains state entries for all of the auth events.
func (l *authEventLoader) loadAuthEvents( func (l *authEventLoader) loadAuthEvents(
ctx context.Context, roomNID types.RoomNID, event *gomatrixserverlib.Event, eventMap map[string]types.Event, ctx context.Context, roomInfo *types.RoomInfo, event *gomatrixserverlib.Event, eventMap map[string]types.Event,
) ([]*gomatrixserverlib.Event, map[string]types.StateEntry, error) { ) ([]*gomatrixserverlib.Event, map[string]types.StateEntry, error) {
l.Lock() l.Lock()
defer l.Unlock() defer l.Unlock()
@ -1155,7 +1155,7 @@ func (l *authEventLoader) loadAuthEvents(
// If we need to get events from the database, go and fetch // If we need to get events from the database, go and fetch
// those now. // those now.
if len(l.lookupFromDB) > 0 { if len(l.lookupFromDB) > 0 {
eventsFromDB, err := l.v.db.EventsFromIDs(ctx, roomNID, l.lookupFromDB) eventsFromDB, err := l.v.db.EventsFromIDs(ctx, roomInfo, l.lookupFromDB)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("v.db.EventsFromIDs: %w", err) return nil, nil, fmt.Errorf("v.db.EventsFromIDs: %w", err)
} }

View file

@ -29,6 +29,7 @@ type Database interface {
SupportsConcurrentRoomInputs() bool SupportsConcurrentRoomInputs() bool
// RoomInfo returns room information for the given room ID, or nil if there is no room. // RoomInfo returns room information for the given room ID, or nil if there is no room.
RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error)
RoomInfoByNID(ctx context.Context, roomNID types.RoomNID) (*types.RoomInfo, error)
// Store the room state at an event in the database // Store the room state at an event in the database
AddState( AddState(
ctx context.Context, ctx context.Context,
@ -69,12 +70,12 @@ type Database interface {
) ([]types.StateEntryList, error) ) ([]types.StateEntryList, error)
// Look up the Events for a list of numeric event IDs. // Look up the Events for a list of numeric event IDs.
// Returns a sorted list of events. // Returns a sorted list of events.
Events(ctx context.Context, roomNID types.RoomNID, eventNIDs []types.EventNID) ([]types.Event, error) Events(ctx context.Context, roomInfo *types.RoomInfo, eventNIDs []types.EventNID) ([]types.Event, error)
// Look up snapshot NID for an event ID string // Look up snapshot NID for an event ID string
SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error) SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error)
BulkSelectSnapshotsFromEventIDs(ctx context.Context, eventIDs []string) (map[types.StateSnapshotNID][]string, error) BulkSelectSnapshotsFromEventIDs(ctx context.Context, eventIDs []string) (map[types.StateSnapshotNID][]string, error)
// Stores a matrix room event in the database. Returns the room NID, the state snapshot and the redacted event ID if any, or an error. // Stores a matrix room event in the database. Returns the room NID, the state snapshot or an error.
StoreEvent(ctx context.Context, event *gomatrixserverlib.Event, roomNID types.RoomNID, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, authEventNIDs []types.EventNID, isRejected bool) (types.EventNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) StoreEvent(ctx context.Context, event *gomatrixserverlib.Event, roomInfo *types.RoomInfo, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, authEventNIDs []types.EventNID, isRejected bool) (types.EventNID, types.StateAtEvent, error)
// Look up the state entries for a list of string event IDs // Look up the state entries for a list of string event IDs
// Returns an error if the there is an error talking to the database // Returns an error if the there is an error talking to the database
// Returns a types.MissingEventError if the event IDs aren't in the database. // Returns a types.MissingEventError if the event IDs aren't in the database.
@ -135,7 +136,7 @@ type Database interface {
// EventsFromIDs looks up the Events for a list of event IDs. Does not error if event was // EventsFromIDs looks up the Events for a list of event IDs. Does not error if event was
// not found. // not found.
// Returns an error if the retrieval went wrong. // Returns an error if the retrieval went wrong.
EventsFromIDs(ctx context.Context, roomNID types.RoomNID, eventIDs []string) ([]types.Event, error) EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error)
// Publish or unpublish a room from the room directory. // Publish or unpublish a room from the room directory.
PublishRoom(ctx context.Context, roomID, appserviceID, networkID string, publish bool) error PublishRoom(ctx context.Context, roomID, appserviceID, networkID string, publish bool) error
// Returns a list of room IDs for rooms which are published. // Returns a list of room IDs for rooms which are published.
@ -179,36 +180,53 @@ type Database interface {
GetMembershipForHistoryVisibility( GetMembershipForHistoryVisibility(
ctx context.Context, userNID types.EventStateKeyNID, info *types.RoomInfo, eventIDs ...string, ctx context.Context, userNID types.EventStateKeyNID, info *types.RoomInfo, eventIDs ...string,
) (map[string]*gomatrixserverlib.HeaderedEvent, error) ) (map[string]*gomatrixserverlib.HeaderedEvent, error)
GetOrCreateRoomNID(ctx context.Context, event *gomatrixserverlib.Event) (types.RoomNID, error) GetOrCreateRoomInfo(ctx context.Context, event *gomatrixserverlib.Event) (*types.RoomInfo, error)
GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error) GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error)
GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error) GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error)
MaybeRedactEvent(
ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event *gomatrixserverlib.Event, redactAllowed bool,
) (*gomatrixserverlib.Event, *gomatrixserverlib.Event, error)
} }
type RoomDatabase interface { type RoomDatabase interface {
EventDatabase
// RoomInfo returns room information for the given room ID, or nil if there is no room. // RoomInfo returns room information for the given room ID, or nil if there is no room.
RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error)
RoomInfoByNID(ctx context.Context, roomNID types.RoomNID) (*types.RoomInfo, error)
// IsEventRejected returns true if the event is known and rejected. // IsEventRejected returns true if the event is known and rejected.
IsEventRejected(ctx context.Context, roomNID types.RoomNID, eventID string) (rejected bool, err error) IsEventRejected(ctx context.Context, roomNID types.RoomNID, eventID string) (rejected bool, err error)
MissingAuthPrevEvents(ctx context.Context, e *gomatrixserverlib.Event) (missingAuth, missingPrev []string, err error) MissingAuthPrevEvents(ctx context.Context, e *gomatrixserverlib.Event) (missingAuth, missingPrev []string, err error)
// Stores a matrix room event in the database. Returns the room NID, the state snapshot and the redacted event ID if any, or an error.
StoreEvent(ctx context.Context, event *gomatrixserverlib.Event, roomNID types.RoomNID, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, authEventNIDs []types.EventNID, isRejected bool) (types.EventNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error)
UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventSender string) error UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventSender string) error
GetRoomUpdater(ctx context.Context, roomInfo *types.RoomInfo) (*shared.RoomUpdater, error) GetRoomUpdater(ctx context.Context, roomInfo *types.RoomInfo) (*shared.RoomUpdater, error)
StateEntriesForEventIDs(ctx context.Context, eventIDs []string, excludeRejected bool) ([]types.StateEntry, error)
GetMembershipEventNIDsForRoom(ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool) ([]types.EventNID, error) GetMembershipEventNIDsForRoom(ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool) ([]types.EventNID, error)
StateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) StateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error)
StateEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error) StateEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error)
SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error)
BulkSelectSnapshotsFromEventIDs(ctx context.Context, eventIDs []string) (map[types.StateSnapshotNID][]string, error) BulkSelectSnapshotsFromEventIDs(ctx context.Context, eventIDs []string) (map[types.StateSnapshotNID][]string, error)
StateEntriesForTuples(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error) StateEntriesForTuples(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error)
StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error)
AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error) AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error)
Events(ctx context.Context, roomNID types.RoomNID, eventNIDs []types.EventNID) ([]types.Event, error)
EventsFromIDs(ctx context.Context, roomNID types.RoomNID, eventIDs []string) ([]types.Event, error)
LatestEventIDs(ctx context.Context, roomNID types.RoomNID) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error) LatestEventIDs(ctx context.Context, roomNID types.RoomNID) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error)
EventTypeNIDs(ctx context.Context, eventTypes []string) (map[string]types.EventTypeNID, error) GetOrCreateRoomInfo(ctx context.Context, event *gomatrixserverlib.Event) (*types.RoomInfo, error)
EventStateKeyNIDs(ctx context.Context, eventStateKeys []string) (map[string]types.EventStateKeyNID, error)
GetOrCreateRoomNID(ctx context.Context, event *gomatrixserverlib.Event) (types.RoomNID, error)
GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error) GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error)
GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error) GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error)
GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error)
}
type EventDatabase interface {
EventTypeNIDs(ctx context.Context, eventTypes []string) (map[string]types.EventTypeNID, error)
EventStateKeys(ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error)
EventStateKeyNIDs(ctx context.Context, eventStateKeys []string) (map[string]types.EventStateKeyNID, error)
StateEntriesForEventIDs(ctx context.Context, eventIDs []string, excludeRejected bool) ([]types.StateEntry, error)
EventNIDs(ctx context.Context, eventIDs []string) (map[string]types.EventMetadata, error)
SetState(ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID) error
StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error)
SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error)
EventIDs(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error)
EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error)
Events(ctx context.Context, roomInfo *types.RoomInfo, eventNIDs []types.EventNID) ([]types.Event, error)
// MaybeRedactEvent returns the redaction event and the redacted event if this call resulted in a redaction, else an error
// (nil if there was nothing to do)
MaybeRedactEvent(
ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event *gomatrixserverlib.Event, redactAllowed bool,
) (*gomatrixserverlib.Event, *gomatrixserverlib.Event, error)
StoreEvent(ctx context.Context, event *gomatrixserverlib.Event, roomInfo *types.RoomInfo, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, authEventNIDs []types.EventNID, isRejected bool) (types.EventNID, types.StateAtEvent, error)
} }

View file

@ -194,22 +194,27 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room
return err return err
} }
d.Database = shared.Database{ d.Database = shared.Database{
DB: db,
EventDatabase: shared.EventDatabase{
DB: db, DB: db,
Cache: cache, Cache: cache,
Writer: writer, Writer: writer,
EventsTable: events,
EventJSONTable: eventJSON,
EventTypesTable: eventTypes, EventTypesTable: eventTypes,
EventStateKeysTable: eventStateKeys, EventStateKeysTable: eventStateKeys,
EventJSONTable: eventJSON, PrevEventsTable: prevEvents,
EventsTable: events, RedactionsTable: redactions,
},
Cache: cache,
Writer: writer,
RoomsTable: rooms, RoomsTable: rooms,
StateBlockTable: stateBlock, StateBlockTable: stateBlock,
StateSnapshotTable: stateSnapshot, StateSnapshotTable: stateSnapshot,
PrevEventsTable: prevEvents,
RoomAliasesTable: roomAliases, RoomAliasesTable: roomAliases,
InvitesTable: invites, InvitesTable: invites,
MembershipTable: membership, MembershipTable: membership,
PublishedTable: published, PublishedTable: published,
RedactionsTable: redactions,
Purge: purge, Purge: purge,
} }
return nil return nil

View file

@ -116,8 +116,8 @@ func (u *RoomUpdater) StorePreviousEvents(eventNID types.EventNID, previousEvent
}) })
} }
func (u *RoomUpdater) Events(ctx context.Context, _ types.RoomNID, eventNIDs []types.EventNID) ([]types.Event, error) { func (u *RoomUpdater) Events(ctx context.Context, _ *types.RoomInfo, eventNIDs []types.EventNID) ([]types.Event, error) {
return u.d.events(ctx, u.txn, u.roomInfo.RoomNID, eventNIDs) return u.d.events(ctx, u.txn, u.roomInfo, eventNIDs)
} }
func (u *RoomUpdater) SnapshotNIDFromEventID( func (u *RoomUpdater) SnapshotNIDFromEventID(
@ -195,8 +195,8 @@ func (u *RoomUpdater) StateAtEventIDs(
return u.d.EventsTable.BulkSelectStateAtEventByID(ctx, u.txn, eventIDs) return u.d.EventsTable.BulkSelectStateAtEventByID(ctx, u.txn, eventIDs)
} }
func (u *RoomUpdater) EventsFromIDs(ctx context.Context, roomNID types.RoomNID, eventIDs []string) ([]types.Event, error) { func (u *RoomUpdater) EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error) {
return u.d.eventsFromIDs(ctx, u.txn, roomNID, eventIDs, NoFilter) return u.d.eventsFromIDs(ctx, u.txn, u.roomInfo, eventIDs, NoFilter)
} }
// IsReferenced implements types.RoomRecentEventsUpdater // IsReferenced implements types.RoomRecentEventsUpdater

View file

@ -9,7 +9,6 @@ import (
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/sirupsen/logrus"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/caching"
@ -28,6 +27,23 @@ import (
const redactionsArePermanent = true const redactionsArePermanent = true
type Database struct { type Database struct {
DB *sql.DB
EventDatabase
Cache caching.RoomServerCaches
Writer sqlutil.Writer
RoomsTable tables.Rooms
StateSnapshotTable tables.StateSnapshot
StateBlockTable tables.StateBlock
RoomAliasesTable tables.RoomAliases
InvitesTable tables.Invites
MembershipTable tables.Membership
PublishedTable tables.Published
Purge tables.Purge
GetRoomUpdaterFn func(ctx context.Context, roomInfo *types.RoomInfo) (*RoomUpdater, error)
}
// EventDatabase contains all tables needed to work with events
type EventDatabase struct {
DB *sql.DB DB *sql.DB
Cache caching.RoomServerCaches Cache caching.RoomServerCaches
Writer sqlutil.Writer Writer sqlutil.Writer
@ -35,17 +51,8 @@ type Database struct {
EventJSONTable tables.EventJSON EventJSONTable tables.EventJSON
EventTypesTable tables.EventTypes EventTypesTable tables.EventTypes
EventStateKeysTable tables.EventStateKeys EventStateKeysTable tables.EventStateKeys
RoomsTable tables.Rooms
StateSnapshotTable tables.StateSnapshot
StateBlockTable tables.StateBlock
RoomAliasesTable tables.RoomAliases
PrevEventsTable tables.PreviousEvents PrevEventsTable tables.PreviousEvents
InvitesTable tables.Invites
MembershipTable tables.Membership
PublishedTable tables.Published
RedactionsTable tables.Redactions RedactionsTable tables.Redactions
Purge tables.Purge
GetRoomUpdaterFn func(ctx context.Context, roomInfo *types.RoomInfo) (*RoomUpdater, error)
} }
func (d *Database) SupportsConcurrentRoomInputs() bool { func (d *Database) SupportsConcurrentRoomInputs() bool {
@ -58,13 +65,13 @@ func (d *Database) GetMembershipForHistoryVisibility(
return d.StateSnapshotTable.BulkSelectMembershipForHistoryVisibility(ctx, nil, userNID, roomInfo, eventIDs...) return d.StateSnapshotTable.BulkSelectMembershipForHistoryVisibility(ctx, nil, userNID, roomInfo, eventIDs...)
} }
func (d *Database) EventTypeNIDs( func (d *EventDatabase) EventTypeNIDs(
ctx context.Context, eventTypes []string, ctx context.Context, eventTypes []string,
) (map[string]types.EventTypeNID, error) { ) (map[string]types.EventTypeNID, error) {
return d.eventTypeNIDs(ctx, nil, eventTypes) return d.eventTypeNIDs(ctx, nil, eventTypes)
} }
func (d *Database) eventTypeNIDs( func (d *EventDatabase) eventTypeNIDs(
ctx context.Context, txn *sql.Tx, eventTypes []string, ctx context.Context, txn *sql.Tx, eventTypes []string,
) (map[string]types.EventTypeNID, error) { ) (map[string]types.EventTypeNID, error) {
result := make(map[string]types.EventTypeNID) result := make(map[string]types.EventTypeNID)
@ -91,7 +98,7 @@ func (d *Database) eventTypeNIDs(
return result, nil return result, nil
} }
func (d *Database) EventStateKeys( func (d *EventDatabase) EventStateKeys(
ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID, ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID,
) (map[types.EventStateKeyNID]string, error) { ) (map[types.EventStateKeyNID]string, error) {
result := make(map[types.EventStateKeyNID]string, len(eventStateKeyNIDs)) result := make(map[types.EventStateKeyNID]string, len(eventStateKeyNIDs))
@ -116,13 +123,13 @@ func (d *Database) EventStateKeys(
return result, nil return result, nil
} }
func (d *Database) EventStateKeyNIDs( func (d *EventDatabase) EventStateKeyNIDs(
ctx context.Context, eventStateKeys []string, ctx context.Context, eventStateKeys []string,
) (map[string]types.EventStateKeyNID, error) { ) (map[string]types.EventStateKeyNID, error) {
return d.eventStateKeyNIDs(ctx, nil, eventStateKeys) return d.eventStateKeyNIDs(ctx, nil, eventStateKeys)
} }
func (d *Database) eventStateKeyNIDs( func (d *EventDatabase) eventStateKeyNIDs(
ctx context.Context, txn *sql.Tx, eventStateKeys []string, ctx context.Context, txn *sql.Tx, eventStateKeys []string,
) (map[string]types.EventStateKeyNID, error) { ) (map[string]types.EventStateKeyNID, error) {
result := make(map[string]types.EventStateKeyNID) result := make(map[string]types.EventStateKeyNID)
@ -174,7 +181,7 @@ func (d *Database) eventStateKeyNIDs(
return result, nil return result, nil
} }
func (d *Database) StateEntriesForEventIDs( func (d *EventDatabase) StateEntriesForEventIDs(
ctx context.Context, eventIDs []string, excludeRejected bool, ctx context.Context, eventIDs []string, excludeRejected bool,
) ([]types.StateEntry, error) { ) ([]types.StateEntry, error) {
return d.EventsTable.BulkSelectStateEventByID(ctx, nil, eventIDs, excludeRejected) return d.EventsTable.BulkSelectStateEventByID(ctx, nil, eventIDs, excludeRejected)
@ -213,6 +220,17 @@ func (d *Database) stateEntriesForTuples(
return lists, nil return lists, nil
} }
func (d *Database) RoomInfoByNID(ctx context.Context, roomNID types.RoomNID) (*types.RoomInfo, error) {
roomIDs, err := d.RoomsTable.BulkSelectRoomIDs(ctx, nil, []types.RoomNID{roomNID})
if err != nil {
return nil, err
}
if len(roomIDs) == 0 {
return nil, fmt.Errorf("room does not exist")
}
return d.roomInfo(ctx, nil, roomIDs[0])
}
func (d *Database) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) { func (d *Database) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) {
return d.roomInfo(ctx, nil, roomID) return d.roomInfo(ctx, nil, roomID)
} }
@ -292,7 +310,7 @@ func (d *Database) addState(
return return
} }
func (d *Database) EventNIDs( func (d *EventDatabase) EventNIDs(
ctx context.Context, eventIDs []string, ctx context.Context, eventIDs []string,
) (map[string]types.EventMetadata, error) { ) (map[string]types.EventMetadata, error) {
return d.eventNIDs(ctx, nil, eventIDs, NoFilter) return d.eventNIDs(ctx, nil, eventIDs, NoFilter)
@ -305,7 +323,7 @@ const (
FilterUnsentOnly UnsentFilter = true FilterUnsentOnly UnsentFilter = true
) )
func (d *Database) eventNIDs( func (d *EventDatabase) eventNIDs(
ctx context.Context, txn *sql.Tx, eventIDs []string, filter UnsentFilter, ctx context.Context, txn *sql.Tx, eventIDs []string, filter UnsentFilter,
) (map[string]types.EventMetadata, error) { ) (map[string]types.EventMetadata, error) {
switch filter { switch filter {
@ -318,7 +336,7 @@ func (d *Database) eventNIDs(
} }
} }
func (d *Database) SetState( func (d *EventDatabase) SetState(
ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID, ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID,
) error { ) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
@ -326,19 +344,19 @@ func (d *Database) SetState(
}) })
} }
func (d *Database) StateAtEventIDs( func (d *EventDatabase) StateAtEventIDs(
ctx context.Context, eventIDs []string, ctx context.Context, eventIDs []string,
) ([]types.StateAtEvent, error) { ) ([]types.StateAtEvent, error) {
return d.EventsTable.BulkSelectStateAtEventByID(ctx, nil, eventIDs) return d.EventsTable.BulkSelectStateAtEventByID(ctx, nil, eventIDs)
} }
func (d *Database) SnapshotNIDFromEventID( func (d *EventDatabase) SnapshotNIDFromEventID(
ctx context.Context, eventID string, ctx context.Context, eventID string,
) (types.StateSnapshotNID, error) { ) (types.StateSnapshotNID, error) {
return d.snapshotNIDFromEventID(ctx, nil, eventID) return d.snapshotNIDFromEventID(ctx, nil, eventID)
} }
func (d *Database) snapshotNIDFromEventID( func (d *EventDatabase) snapshotNIDFromEventID(
ctx context.Context, txn *sql.Tx, eventID string, ctx context.Context, txn *sql.Tx, eventID string,
) (types.StateSnapshotNID, error) { ) (types.StateSnapshotNID, error) {
_, stateNID, err := d.EventsTable.SelectEvent(ctx, txn, eventID) _, stateNID, err := d.EventsTable.SelectEvent(ctx, txn, eventID)
@ -351,17 +369,17 @@ func (d *Database) snapshotNIDFromEventID(
return stateNID, err return stateNID, err
} }
func (d *Database) EventIDs( func (d *EventDatabase) EventIDs(
ctx context.Context, eventNIDs []types.EventNID, ctx context.Context, eventNIDs []types.EventNID,
) (map[types.EventNID]string, error) { ) (map[types.EventNID]string, error) {
return d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs) return d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs)
} }
func (d *Database) EventsFromIDs(ctx context.Context, roomNID types.RoomNID, eventIDs []string) ([]types.Event, error) { func (d *EventDatabase) EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error) {
return d.eventsFromIDs(ctx, nil, roomNID, eventIDs, NoFilter) return d.eventsFromIDs(ctx, nil, roomInfo, eventIDs, NoFilter)
} }
func (d *Database) eventsFromIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventIDs []string, filter UnsentFilter) ([]types.Event, error) { func (d *EventDatabase) eventsFromIDs(ctx context.Context, txn *sql.Tx, roomInfo *types.RoomInfo, eventIDs []string, filter UnsentFilter) ([]types.Event, error) {
nidMap, err := d.eventNIDs(ctx, txn, eventIDs, filter) nidMap, err := d.eventNIDs(ctx, txn, eventIDs, filter)
if err != nil { if err != nil {
return nil, err return nil, err
@ -370,15 +388,9 @@ func (d *Database) eventsFromIDs(ctx context.Context, txn *sql.Tx, roomNID types
var nids []types.EventNID var nids []types.EventNID
for _, nid := range nidMap { for _, nid := range nidMap {
nids = append(nids, nid.EventNID) nids = append(nids, nid.EventNID)
if roomNID != 0 && roomNID != nid.RoomNID {
logrus.Errorf("expected events from room %d, but also found %d", roomNID, nid.RoomNID)
}
if roomNID == 0 {
roomNID = nid.RoomNID
}
} }
return d.events(ctx, txn, roomNID, nids) return d.events(ctx, txn, roomInfo, nids)
} }
func (d *Database) LatestEventIDs( func (d *Database) LatestEventIDs(
@ -517,19 +529,17 @@ func (d *Database) GetInvitesForUser(
return d.InvitesTable.SelectInviteActiveForUserInRoom(ctx, nil, targetUserNID, roomNID) return d.InvitesTable.SelectInviteActiveForUserInRoom(ctx, nil, targetUserNID, roomNID)
} }
func (d *Database) Events( func (d *EventDatabase) Events(ctx context.Context, roomInfo *types.RoomInfo, eventNIDs []types.EventNID) ([]types.Event, error) {
ctx context.Context, roomNID types.RoomNID, eventNIDs []types.EventNID, return d.events(ctx, nil, roomInfo, eventNIDs)
) ([]types.Event, error) {
return d.events(ctx, nil, roomNID, eventNIDs)
} }
func (d *Database) events( func (d *EventDatabase) events(
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, inputEventNIDs types.EventNIDs, ctx context.Context, txn *sql.Tx, roomInfo *types.RoomInfo, inputEventNIDs types.EventNIDs,
) ([]types.Event, error) { ) ([]types.Event, error) {
if roomNID == 0 { if roomInfo == nil { // this should never happen
// No need to go further, as we won't find any events for this room. return nil, fmt.Errorf("unable to parse events without roomInfo")
return nil, nil
} }
sort.Sort(inputEventNIDs) sort.Sort(inputEventNIDs)
events := make(map[types.EventNID]*gomatrixserverlib.Event, len(inputEventNIDs)) events := make(map[types.EventNID]*gomatrixserverlib.Event, len(inputEventNIDs))
eventNIDs := make([]types.EventNID, 0, len(inputEventNIDs)) eventNIDs := make([]types.EventNID, 0, len(inputEventNIDs))
@ -566,31 +576,9 @@ func (d *Database) events(
eventIDs = map[types.EventNID]string{} eventIDs = map[types.EventNID]string{}
} }
var roomVersion gomatrixserverlib.RoomVersion
var fetchRoomVersion bool
var ok bool
var roomID string
if roomID, ok = d.Cache.GetRoomServerRoomID(roomNID); ok {
roomVersion, ok = d.Cache.GetRoomVersion(roomID)
if !ok {
fetchRoomVersion = true
}
}
if roomVersion == "" || fetchRoomVersion {
var dbRoomVersions map[types.RoomNID]gomatrixserverlib.RoomVersion
dbRoomVersions, err = d.RoomsTable.SelectRoomVersionsForRoomNIDs(ctx, txn, []types.RoomNID{roomNID})
if err != nil {
return nil, err
}
if roomVersion, ok = dbRoomVersions[roomNID]; !ok {
return nil, fmt.Errorf("unable to find roomversion for room %d", roomNID)
}
}
for _, eventJSON := range eventJSONs { for _, eventJSON := range eventJSONs {
events[eventJSON.EventNID], err = gomatrixserverlib.NewEventFromTrustedJSONWithEventID( events[eventJSON.EventNID], err = gomatrixserverlib.NewEventFromTrustedJSONWithEventID(
eventIDs[eventJSON.EventNID], eventJSON.EventJSON, false, roomVersion, eventIDs[eventJSON.EventNID], eventJSON.EventJSON, false, roomInfo.RoomVersion,
) )
if err != nil { if err != nil {
return nil, err return nil, err
@ -660,8 +648,8 @@ func (d *Database) IsEventRejected(ctx context.Context, roomNID types.RoomNID, e
return d.EventsTable.SelectEventRejected(ctx, nil, roomNID, eventID) return d.EventsTable.SelectEventRejected(ctx, nil, roomNID, eventID)
} }
// GetOrCreateRoomNID gets or creates a new roomNID for the given event // GetOrCreateRoomInfo gets or creates a new RoomInfo, which is only safe to use with functions only needing a roomVersion or roomNID.
func (d *Database) GetOrCreateRoomNID(ctx context.Context, event *gomatrixserverlib.Event) (roomNID types.RoomNID, err error) { func (d *Database) GetOrCreateRoomInfo(ctx context.Context, event *gomatrixserverlib.Event) (roomInfo *types.RoomInfo, err error) {
// Get the default room version. If the client doesn't supply a room_version // Get the default room version. If the client doesn't supply a room_version
// then we will use our configured default to create the room. // then we will use our configured default to create the room.
// https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-createroom // https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-createroom
@ -670,8 +658,9 @@ func (d *Database) GetOrCreateRoomNID(ctx context.Context, event *gomatrixserver
// room. // room.
var roomVersion gomatrixserverlib.RoomVersion var roomVersion gomatrixserverlib.RoomVersion
if roomVersion, err = extractRoomVersionFromCreateEvent(event); err != nil { if roomVersion, err = extractRoomVersionFromCreateEvent(event); err != nil {
return 0, fmt.Errorf("extractRoomVersionFromCreateEvent: %w", err) return nil, fmt.Errorf("extractRoomVersionFromCreateEvent: %w", err)
} }
var roomNID types.RoomNID
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
roomNID, err = d.assignRoomNID(ctx, txn, event.RoomID(), roomVersion) roomNID, err = d.assignRoomNID(ctx, txn, event.RoomID(), roomVersion)
if err != nil { if err != nil {
@ -679,7 +668,10 @@ func (d *Database) GetOrCreateRoomNID(ctx context.Context, event *gomatrixserver
} }
return nil return nil
}) })
return roomNID, err return &types.RoomInfo{
RoomVersion: roomVersion,
RoomNID: roomNID,
}, err
} }
func (d *Database) GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error) { func (d *Database) GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error) {
@ -710,25 +702,22 @@ func (d *Database) GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKe
return eventStateKeyNID, nil return eventStateKeyNID, nil
} }
func (d *Database) StoreEvent( func (d *EventDatabase) StoreEvent(
ctx context.Context, event *gomatrixserverlib.Event, ctx context.Context, event *gomatrixserverlib.Event,
roomNID types.RoomNID, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, roomInfo *types.RoomInfo, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID,
authEventNIDs []types.EventNID, isRejected bool, authEventNIDs []types.EventNID, isRejected bool,
) (types.EventNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) { ) (types.EventNID, types.StateAtEvent, error) {
var ( var (
eventNID types.EventNID eventNID types.EventNID
stateNID types.StateSnapshotNID stateNID types.StateSnapshotNID
redactionEvent *gomatrixserverlib.Event
redactedEventID string
err error err error
) )
// Second writer is using the database-provided transaction, probably from the
// room updater, for easy roll-back if required.
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
if eventNID, stateNID, err = d.EventsTable.InsertEvent( if eventNID, stateNID, err = d.EventsTable.InsertEvent(
ctx, ctx,
txn, txn,
roomNID, roomInfo.RoomNID,
eventTypeNID, eventTypeNID,
eventStateKeyNID, eventStateKeyNID,
event.EventID(), event.EventID(),
@ -751,16 +740,26 @@ func (d *Database) StoreEvent(
if err = d.EventJSONTable.InsertEventJSON(ctx, txn, eventNID, event.JSON()); err != nil { if err = d.EventJSONTable.InsertEventJSON(ctx, txn, eventNID, event.JSON()); err != nil {
return fmt.Errorf("d.EventJSONTable.InsertEventJSON: %w", err) return fmt.Errorf("d.EventJSONTable.InsertEventJSON: %w", err)
} }
if !isRejected { // ignore rejected redaction events
redactionEvent, redactedEventID, err = d.handleRedactions(ctx, txn, roomNID, eventNID, event) if prevEvents := event.PrevEvents(); len(prevEvents) > 0 {
if err != nil { // Create an updater - NB: on sqlite this WILL create a txn as we are directly calling the shared DB form of
return fmt.Errorf("d.handleRedactions: %w", err) // GetLatestEventsForUpdate - not via the SQLiteDatabase form which has `nil` txns. This
// function only does SELECTs though so the created txn (at this point) is just a read txn like
// any other so this is fine. If we ever update GetLatestEventsForUpdate or NewLatestEventsUpdater
// to do writes however then this will need to go inside `Writer.Do`.
// The following is a copy of RoomUpdater.StorePreviousEvents
for _, ref := range prevEvents {
if err = d.PrevEventsTable.InsertPreviousEvent(ctx, txn, ref.EventID, ref.EventSHA256, eventNID); err != nil {
return fmt.Errorf("u.d.PrevEventsTable.InsertPreviousEvent: %w", err)
} }
} }
}
return nil return nil
}) })
if err != nil { if err != nil {
return 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.Writer.Do: %w", err) return 0, types.StateAtEvent{}, fmt.Errorf("d.Writer.Do: %w", err)
} }
// We should attempt to update the previous events table with any // We should attempt to update the previous events table with any
@ -768,33 +767,6 @@ func (d *Database) StoreEvent(
// events updater because it somewhat works as a mutex, ensuring // events updater because it somewhat works as a mutex, ensuring
// that there's a row-level lock on the latest room events (well, // that there's a row-level lock on the latest room events (well,
// on Postgres at least). // on Postgres at least).
if prevEvents := event.PrevEvents(); len(prevEvents) > 0 {
// Create an updater - NB: on sqlite this WILL create a txn as we are directly calling the shared DB form of
// GetLatestEventsForUpdate - not via the SQLiteDatabase form which has `nil` txns. This
// function only does SELECTs though so the created txn (at this point) is just a read txn like
// any other so this is fine. If we ever update GetLatestEventsForUpdate or NewLatestEventsUpdater
// to do writes however then this will need to go inside `Writer.Do`.
succeeded := false
var roomInfo *types.RoomInfo
roomInfo, err = d.roomInfo(ctx, nil, event.RoomID())
if err != nil {
return 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.RoomInfo: %w", err)
}
if roomInfo == nil && len(prevEvents) > 0 {
return 0, types.StateAtEvent{}, nil, "", fmt.Errorf("expected room %q to exist", event.RoomID())
}
var updater *RoomUpdater
updater, err = d.GetRoomUpdater(ctx, roomInfo)
if err != nil {
return 0, types.StateAtEvent{}, nil, "", fmt.Errorf("GetRoomUpdater: %w", err)
}
defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err)
if err = updater.StorePreviousEvents(eventNID, prevEvents); err != nil {
return 0, types.StateAtEvent{}, nil, "", fmt.Errorf("updater.StorePreviousEvents: %w", err)
}
succeeded = true
}
return eventNID, types.StateAtEvent{ return eventNID, types.StateAtEvent{
BeforeStateSnapshotNID: stateNID, BeforeStateSnapshotNID: stateNID,
@ -805,7 +777,7 @@ func (d *Database) StoreEvent(
}, },
EventNID: eventNID, EventNID: eventNID,
}, },
}, redactionEvent, redactedEventID, err }, err
} }
func (d *Database) PublishRoom(ctx context.Context, roomID, appserviceID, networkID string, publish bool) error { func (d *Database) PublishRoom(ctx context.Context, roomID, appserviceID, networkID string, publish bool) error {
@ -893,7 +865,7 @@ func (d *Database) assignEventTypeNID(
return eventTypeNID, nil return eventTypeNID, nil
} }
func (d *Database) assignStateKeyNID( func (d *EventDatabase) assignStateKeyNID(
ctx context.Context, txn *sql.Tx, eventStateKey string, ctx context.Context, txn *sql.Tx, eventStateKey string,
) (types.EventStateKeyNID, error) { ) (types.EventStateKeyNID, error) {
eventStateKeyNID, ok := d.Cache.GetEventStateKeyNID(eventStateKey) eventStateKeyNID, ok := d.Cache.GetEventStateKeyNID(eventStateKey)
@ -937,7 +909,7 @@ func extractRoomVersionFromCreateEvent(event *gomatrixserverlib.Event) (
return roomVersion, err return roomVersion, err
} }
// handleRedactions manages the redacted status of events. There's two cases to consider in order to comply with the spec: // MaybeRedactEvent manages the redacted status of events. There's two cases to consider in order to comply with the spec:
// "servers should not apply or send redactions to clients until both the redaction event and original event have been seen, and are valid." // "servers should not apply or send redactions to clients until both the redaction event and original event have been seen, and are valid."
// https://matrix.org/docs/spec/rooms/v3#authorization-rules-for-events // https://matrix.org/docs/spec/rooms/v3#authorization-rules-for-events
// These cases are: // These cases are:
@ -952,16 +924,23 @@ func extractRoomVersionFromCreateEvent(event *gomatrixserverlib.Event) (
// when loading events to determine whether to apply redactions. This keeps the hot-path of reading events quick as we don't need // when loading events to determine whether to apply redactions. This keeps the hot-path of reading events quick as we don't need
// to cross-reference with other tables when loading. // to cross-reference with other tables when loading.
// //
// Returns the redaction event and the event ID of the redacted event if this call resulted in a redaction. // Returns the redaction event and the redacted event if this call resulted in a redaction.
func (d *Database) handleRedactions( func (d *EventDatabase) MaybeRedactEvent(
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNID types.EventNID, event *gomatrixserverlib.Event, ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event *gomatrixserverlib.Event, redactAllowed bool,
) (*gomatrixserverlib.Event, string, error) { ) (*gomatrixserverlib.Event, *gomatrixserverlib.Event, error) {
var err error var (
redactionEvent, redactedEvent *types.Event
err error
validated bool
ignoreRedaction bool
)
wErr := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
isRedactionEvent := event.Type() == gomatrixserverlib.MRoomRedaction && event.StateKey() == nil isRedactionEvent := event.Type() == gomatrixserverlib.MRoomRedaction && event.StateKey() == nil
if isRedactionEvent { if isRedactionEvent {
// an event which redacts itself should be ignored // an event which redacts itself should be ignored
if event.EventID() == event.Redacts() { if event.EventID() == event.Redacts() {
return nil, "", nil return nil
} }
err = d.RedactionsTable.InsertRedaction(ctx, txn, tables.RedactionInfo{ err = d.RedactionsTable.InsertRedaction(ctx, txn, tables.RedactionInfo{
@ -970,44 +949,30 @@ func (d *Database) handleRedactions(
RedactsEventID: event.Redacts(), RedactsEventID: event.Redacts(),
}) })
if err != nil { if err != nil {
return nil, "", fmt.Errorf("d.RedactionsTable.InsertRedaction: %w", err) return fmt.Errorf("d.RedactionsTable.InsertRedaction: %w", err)
} }
} }
redactionEvent, redactedEvent, validated, err := d.loadRedactionPair(ctx, txn, roomNID, eventNID, event) redactionEvent, redactedEvent, validated, err = d.loadRedactionPair(ctx, txn, roomInfo, eventNID, event)
if err != nil {
return nil, "", fmt.Errorf("d.loadRedactionPair: %w", err)
}
if validated || redactedEvent == nil || redactionEvent == nil {
// we've seen this redaction before or there is nothing to redact
return nil, "", nil
}
if redactedEvent.RoomID() != redactionEvent.RoomID() {
// redactions across rooms aren't allowed
return nil, "", nil
}
// Get the power level from the database, so we can verify the user is allowed to redact the event
powerLevels, err := d.GetStateEvent(ctx, event.RoomID(), gomatrixserverlib.MRoomPowerLevels, "")
if err != nil {
return nil, "", fmt.Errorf("d.GetStateEvent: %w", err)
}
if powerLevels == nil {
return nil, "", fmt.Errorf("unable to fetch m.room.power_levels event from database for room %s", event.RoomID())
}
pl, err := powerLevels.PowerLevels()
if err != nil {
return nil, "", fmt.Errorf("unable to get powerlevels for room: %w", err)
}
redactUser := pl.UserLevel(redactionEvent.Sender())
switch { switch {
case redactUser >= pl.Redact: case err != nil:
// The power level of the redaction events sender is greater than or equal to the redact level. return fmt.Errorf("d.loadRedactionPair: %w", err)
case redactedEvent.Sender() == redactionEvent.Sender(): case validated || redactedEvent == nil || redactionEvent == nil:
// The domain of the redaction events sender matches that of the original events sender. // we've seen this redaction before or there is nothing to redact
default: return nil
return nil, "", nil case redactedEvent.RoomID() != redactionEvent.RoomID():
// redactions across rooms aren't allowed
ignoreRedaction = true
return nil
}
// 1. The power level of the redaction events sender is greater than or equal to the redact level. (redactAllowed)
// 2. The domain of the redaction events sender matches that of the original events sender.
_, sender1, _ := gomatrixserverlib.SplitID('@', redactedEvent.Sender())
_, sender2, _ := gomatrixserverlib.SplitID('@', redactionEvent.Sender())
if !redactAllowed || sender1 != sender2 {
ignoreRedaction = true
return nil
} }
// mark the event as redacted // mark the event as redacted
@ -1017,30 +982,37 @@ func (d *Database) handleRedactions(
err = redactedEvent.SetUnsignedField("redacted_because", redactionEvent) err = redactedEvent.SetUnsignedField("redacted_because", redactionEvent)
if err != nil { if err != nil {
return nil, "", fmt.Errorf("redactedEvent.SetUnsignedField: %w", err) return fmt.Errorf("redactedEvent.SetUnsignedField: %w", err)
} }
// NOTSPEC: sytest relies on this unspecced field existing :( // NOTSPEC: sytest relies on this unspecced field existing :(
err = redactedEvent.SetUnsignedField("redacted_by", redactionEvent.EventID()) err = redactedEvent.SetUnsignedField("redacted_by", redactionEvent.EventID())
if err != nil { if err != nil {
return nil, "", fmt.Errorf("redactedEvent.SetUnsignedField: %w", err) return fmt.Errorf("redactedEvent.SetUnsignedField: %w", err)
} }
// overwrite the eventJSON table // overwrite the eventJSON table
err = d.EventJSONTable.InsertEventJSON(ctx, txn, redactedEvent.EventNID, redactedEvent.JSON()) err = d.EventJSONTable.InsertEventJSON(ctx, txn, redactedEvent.EventNID, redactedEvent.JSON())
if err != nil { if err != nil {
return nil, "", fmt.Errorf("d.EventJSONTable.InsertEventJSON: %w", err) return fmt.Errorf("d.EventJSONTable.InsertEventJSON: %w", err)
} }
err = d.RedactionsTable.MarkRedactionValidated(ctx, txn, redactionEvent.EventID(), true) err = d.RedactionsTable.MarkRedactionValidated(ctx, txn, redactionEvent.EventID(), true)
if err != nil { if err != nil {
err = fmt.Errorf("d.RedactionsTable.MarkRedactionValidated: %w", err) return fmt.Errorf("d.RedactionsTable.MarkRedactionValidated: %w", err)
} }
return nil
return redactionEvent.Event, redactedEvent.EventID(), err })
if wErr != nil {
return nil, nil, err
}
if ignoreRedaction || redactionEvent == nil || redactedEvent == nil {
return nil, nil, nil
}
return redactionEvent.Event, redactedEvent.Event, nil
} }
// loadRedactionPair returns both the redaction event and the redacted event, else nil. // loadRedactionPair returns both the redaction event and the redacted event, else nil.
func (d *Database) loadRedactionPair( func (d *EventDatabase) loadRedactionPair(
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNID types.EventNID, event *gomatrixserverlib.Event, ctx context.Context, txn *sql.Tx, roomInfo *types.RoomInfo, eventNID types.EventNID, event *gomatrixserverlib.Event,
) (*types.Event, *types.Event, bool, error) { ) (*types.Event, *types.Event, bool, error) {
var redactionEvent, redactedEvent *types.Event var redactionEvent, redactedEvent *types.Event
var info *tables.RedactionInfo var info *tables.RedactionInfo
@ -1072,16 +1044,16 @@ func (d *Database) loadRedactionPair(
} }
if isRedactionEvent { if isRedactionEvent {
redactedEvent = d.loadEvent(ctx, roomNID, info.RedactsEventID) redactedEvent = d.loadEvent(ctx, roomInfo, info.RedactsEventID)
} else { } else {
redactionEvent = d.loadEvent(ctx, roomNID, info.RedactionEventID) redactionEvent = d.loadEvent(ctx, roomInfo, info.RedactionEventID)
} }
return redactionEvent, redactedEvent, info.Validated, nil return redactionEvent, redactedEvent, info.Validated, nil
} }
// applyRedactions will redact events that have an `unsigned.redacted_because` field. // applyRedactions will redact events that have an `unsigned.redacted_because` field.
func (d *Database) applyRedactions(events []types.Event) { func (d *EventDatabase) applyRedactions(events []types.Event) {
for i := range events { for i := range events {
if result := gjson.GetBytes(events[i].Unsigned(), "redacted_because"); result.Exists() { if result := gjson.GetBytes(events[i].Unsigned(), "redacted_because"); result.Exists() {
events[i].Redact() events[i].Redact()
@ -1090,7 +1062,7 @@ func (d *Database) applyRedactions(events []types.Event) {
} }
// loadEvent loads a single event or returns nil on any problems/missing event // loadEvent loads a single event or returns nil on any problems/missing event
func (d *Database) loadEvent(ctx context.Context, roomNID types.RoomNID, eventID string) *types.Event { func (d *EventDatabase) loadEvent(ctx context.Context, roomInfo *types.RoomInfo, eventID string) *types.Event {
nids, err := d.EventNIDs(ctx, []string{eventID}) nids, err := d.EventNIDs(ctx, []string{eventID})
if err != nil { if err != nil {
return nil return nil
@ -1098,7 +1070,7 @@ func (d *Database) loadEvent(ctx context.Context, roomNID types.RoomNID, eventID
if len(nids) == 0 { if len(nids) == 0 {
return nil return nil
} }
evs, err := d.Events(ctx, roomNID, []types.EventNID{nids[eventID].EventNID}) evs, err := d.Events(ctx, roomInfo, []types.EventNID{nids[eventID].EventNID})
if err != nil { if err != nil {
return nil return nil
} }
@ -1144,7 +1116,7 @@ func (d *Database) GetHistoryVisibilityState(ctx context.Context, roomInfo *type
// If no event could be found, returns nil // If no event could be found, returns nil
// If there was an issue during the retrieval, returns an error // If there was an issue during the retrieval, returns an error
func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) { func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) {
roomInfo, err := d.RoomInfo(ctx, roomID) roomInfo, err := d.roomInfo(ctx, nil, roomID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -1209,7 +1181,7 @@ func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey s
// Same as GetStateEvent but returns all matching state events with this event type. Returns no error // Same as GetStateEvent but returns all matching state events with this event type. Returns no error
// if there are no events with this event type. // if there are no events with this event type.
func (d *Database) GetStateEventsWithEventType(ctx context.Context, roomID, evType string) ([]*gomatrixserverlib.HeaderedEvent, error) { func (d *Database) GetStateEventsWithEventType(ctx context.Context, roomID, evType string) ([]*gomatrixserverlib.HeaderedEvent, error) {
roomInfo, err := d.RoomInfo(ctx, roomID) roomInfo, err := d.roomInfo(ctx, nil, roomID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -1340,7 +1312,7 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu
eventNIDToVer := make(map[types.EventNID]gomatrixserverlib.RoomVersion) eventNIDToVer := make(map[types.EventNID]gomatrixserverlib.RoomVersion)
// TODO: This feels like this is going to be really slow... // TODO: This feels like this is going to be really slow...
for _, roomID := range roomIDs { for _, roomID := range roomIDs {
roomInfo, err2 := d.RoomInfo(ctx, roomID) roomInfo, err2 := d.roomInfo(ctx, nil, roomID)
if err2 != nil { if err2 != nil {
return nil, fmt.Errorf("GetBulkStateContent: failed to load room info for room %s : %w", roomID, err2) return nil, fmt.Errorf("GetBulkStateContent: failed to load room info for room %s : %w", roomID, err2)
} }

View file

@ -52,9 +52,11 @@ func mustCreateRoomserverDatabase(t *testing.T, dbType test.DBType) (*shared.Dat
cache := caching.NewRistrettoCache(8*1024*1024, time.Hour, false) cache := caching.NewRistrettoCache(8*1024*1024, time.Hour, false)
evDb := shared.EventDatabase{EventStateKeysTable: stateKeyTable, Cache: cache}
return &shared.Database{ return &shared.Database{
DB: db, DB: db,
EventStateKeysTable: stateKeyTable, EventDatabase: evDb,
MembershipTable: membershipTable, MembershipTable: membershipTable,
Writer: sqlutil.NewExclusiveWriter(), Writer: sqlutil.NewExclusiveWriter(),
Cache: cache, Cache: cache,

View file

@ -203,6 +203,8 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room
} }
d.Database = shared.Database{ d.Database = shared.Database{
DB: db,
EventDatabase: shared.EventDatabase{
DB: db, DB: db,
Cache: cache, Cache: cache,
Writer: writer, Writer: writer,
@ -210,15 +212,18 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room
EventTypesTable: eventTypes, EventTypesTable: eventTypes,
EventStateKeysTable: eventStateKeys, EventStateKeysTable: eventStateKeys,
EventJSONTable: eventJSON, EventJSONTable: eventJSON,
PrevEventsTable: prevEvents,
RedactionsTable: redactions,
},
Cache: cache,
Writer: writer,
RoomsTable: rooms, RoomsTable: rooms,
StateBlockTable: stateBlock, StateBlockTable: stateBlock,
StateSnapshotTable: stateSnapshot, StateSnapshotTable: stateSnapshot,
PrevEventsTable: prevEvents,
RoomAliasesTable: roomAliases, RoomAliasesTable: roomAliases,
InvitesTable: invites, InvitesTable: invites,
MembershipTable: membership, MembershipTable: membership,
PublishedTable: published, PublishedTable: published,
RedactionsTable: redactions,
GetRoomUpdaterFn: d.GetRoomUpdater, GetRoomUpdaterFn: d.GetRoomUpdater,
Purge: purge, Purge: purge,
} }

View file

@ -20,9 +20,11 @@ import (
"database/sql" "database/sql"
"embed" "embed"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"html/template" "html/template"
"io" "io"
"io/fs"
"net" "net"
"net/http" "net/http"
_ "net/http/pprof" _ "net/http/pprof"
@ -85,8 +87,6 @@ type BaseDendrite struct {
startupLock sync.Mutex startupLock sync.Mutex
} }
const NoListener = ""
const HTTPServerTimeout = time.Minute * 5 const HTTPServerTimeout = time.Minute * 5
type BaseDendriteOptions int type BaseDendriteOptions int
@ -345,18 +345,17 @@ func (b *BaseDendrite) ConfigureAdminEndpoints() {
// SetupAndServeHTTP sets up the HTTP server to serve client & federation APIs // SetupAndServeHTTP sets up the HTTP server to serve client & federation APIs
// and adds a prometheus handler under /_dendrite/metrics. // and adds a prometheus handler under /_dendrite/metrics.
func (b *BaseDendrite) SetupAndServeHTTP( func (b *BaseDendrite) SetupAndServeHTTP(
externalHTTPAddr config.HTTPAddress, externalHTTPAddr config.ServerAddress,
certFile, keyFile *string, certFile, keyFile *string,
) { ) {
// Manually unlocked right before actually serving requests, // Manually unlocked right before actually serving requests,
// as we don't return from this method (defer doesn't work). // as we don't return from this method (defer doesn't work).
b.startupLock.Lock() b.startupLock.Lock()
externalAddr, _ := externalHTTPAddr.Address()
externalRouter := mux.NewRouter().SkipClean(true).UseEncodedPath() externalRouter := mux.NewRouter().SkipClean(true).UseEncodedPath()
externalServ := &http.Server{ externalServ := &http.Server{
Addr: string(externalAddr), Addr: externalHTTPAddr.Address,
WriteTimeout: HTTPServerTimeout, WriteTimeout: HTTPServerTimeout,
Handler: externalRouter, Handler: externalRouter,
BaseContext: func(_ net.Listener) context.Context { BaseContext: func(_ net.Listener) context.Context {
@ -419,7 +418,7 @@ func (b *BaseDendrite) SetupAndServeHTTP(
b.startupLock.Unlock() b.startupLock.Unlock()
if externalAddr != NoListener { if externalHTTPAddr.Enabled() {
go func() { go func() {
var externalShutdown atomic.Bool // RegisterOnShutdown can be called more than once var externalShutdown atomic.Bool // RegisterOnShutdown can be called more than once
logrus.Infof("Starting external listener on %s", externalServ.Addr) logrus.Infof("Starting external listener on %s", externalServ.Addr)
@ -436,6 +435,26 @@ func (b *BaseDendrite) SetupAndServeHTTP(
logrus.WithError(err).Fatal("failed to serve HTTPS") logrus.WithError(err).Fatal("failed to serve HTTPS")
} }
} }
} else {
if externalHTTPAddr.IsUnixSocket() {
err := os.Remove(externalHTTPAddr.Address)
if err != nil && !errors.Is(err, fs.ErrNotExist) {
logrus.WithError(err).Fatal("failed to remove existing unix socket")
}
listener, err := net.Listen(externalHTTPAddr.Network(), externalHTTPAddr.Address)
if err != nil {
logrus.WithError(err).Fatal("failed to serve unix socket")
}
err = os.Chmod(externalHTTPAddr.Address, externalHTTPAddr.UnixSocketPermission)
if err != nil {
logrus.WithError(err).Fatal("failed to set unix socket permissions")
}
if err := externalServ.Serve(listener); err != nil {
if err != http.ErrServerClosed {
logrus.WithError(err).Fatal("failed to serve unix socket")
}
}
} else { } else {
if err := externalServ.ListenAndServe(); err != nil { if err := externalServ.ListenAndServe(); err != nil {
if err != http.ErrServerClosed { if err != http.ErrServerClosed {
@ -443,6 +462,7 @@ func (b *BaseDendrite) SetupAndServeHTTP(
} }
} }
} }
}
logrus.Infof("Stopped external listener on %s", externalServ.Addr) logrus.Infof("Stopped external listener on %s", externalServ.Addr)
}() }()
} }

View file

@ -2,10 +2,13 @@ package base_test
import ( import (
"bytes" "bytes"
"context"
"embed" "embed"
"html/template" "html/template"
"net"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"path"
"testing" "testing"
"time" "time"
@ -18,7 +21,7 @@ import (
//go:embed static/*.gotmpl //go:embed static/*.gotmpl
var staticContent embed.FS var staticContent embed.FS
func TestLandingPage(t *testing.T) { func TestLandingPage_Tcp(t *testing.T) {
// generate the expected result // generate the expected result
tmpl := template.Must(template.ParseFS(staticContent, "static/*.gotmpl")) tmpl := template.Must(template.ParseFS(staticContent, "static/*.gotmpl"))
expectedRes := &bytes.Buffer{} expectedRes := &bytes.Buffer{}
@ -35,7 +38,9 @@ func TestLandingPage(t *testing.T) {
s.Close() s.Close()
// start base with the listener and wait for it to be started // start base with the listener and wait for it to be started
go b.SetupAndServeHTTP(config.HTTPAddress(s.URL), nil, nil) address, err := config.HTTPAddress(s.URL)
assert.NoError(t, err)
go b.SetupAndServeHTTP(address, nil, nil)
time.Sleep(time.Millisecond * 10) time.Sleep(time.Millisecond * 10)
// When hitting /, we should be redirected to /_matrix/static, which should contain the landing page // When hitting /, we should be redirected to /_matrix/static, which should contain the landing page
@ -55,3 +60,43 @@ func TestLandingPage(t *testing.T) {
// Using .String() for user friendly output // Using .String() for user friendly output
assert.Equal(t, expectedRes.String(), buf.String(), "response mismatch") assert.Equal(t, expectedRes.String(), buf.String(), "response mismatch")
} }
func TestLandingPage_UnixSocket(t *testing.T) {
// generate the expected result
tmpl := template.Must(template.ParseFS(staticContent, "static/*.gotmpl"))
expectedRes := &bytes.Buffer{}
err := tmpl.ExecuteTemplate(expectedRes, "index.gotmpl", map[string]string{
"Version": internal.VersionString(),
})
assert.NoError(t, err)
b, _, _ := testrig.Base(nil)
defer b.Close()
tempDir := t.TempDir()
socket := path.Join(tempDir, "socket")
// start base with the listener and wait for it to be started
address := config.UnixSocketAddress(socket, 0755)
assert.NoError(t, err)
go b.SetupAndServeHTTP(address, nil, nil)
time.Sleep(time.Millisecond * 100)
client := &http.Client{
Transport: &http.Transport{
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
return net.Dial("unix", socket)
},
},
}
resp, err := client.Get("http://unix/")
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
// read the response
buf := &bytes.Buffer{}
_, err = buf.ReadFrom(resp.Body)
assert.NoError(t, err)
// Using .String() for user friendly output
assert.Equal(t, expectedRes.String(), buf.String(), "response mismatch")
}

View file

@ -19,7 +19,6 @@ import (
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"io" "io"
"net/url"
"os" "os"
"path/filepath" "path/filepath"
"regexp" "regexp"
@ -131,20 +130,6 @@ func (d DataSource) IsPostgres() bool {
// A Topic in kafka. // A Topic in kafka.
type Topic string type Topic string
// An Address to listen on.
type Address string
// An HTTPAddress to listen on, starting with either http:// or https://.
type HTTPAddress string
func (h HTTPAddress) Address() (Address, error) {
url, err := url.Parse(string(h))
if err != nil {
return "", err
}
return Address(url.Host), nil
}
// FileSizeBytes is a file size in bytes // FileSizeBytes is a file size in bytes
type FileSizeBytes int64 type FileSizeBytes int64

View file

@ -0,0 +1,45 @@
package config
import (
"io/fs"
"net/url"
)
const (
NetworkTCP = "tcp"
NetworkUnix = "unix"
)
type ServerAddress struct {
Address string
Scheme string
UnixSocketPermission fs.FileMode
}
func (s ServerAddress) Enabled() bool {
return s.Address != ""
}
func (s ServerAddress) IsUnixSocket() bool {
return s.Scheme == NetworkUnix
}
func (s ServerAddress) Network() string {
if s.Scheme == NetworkUnix {
return NetworkUnix
} else {
return NetworkTCP
}
}
func UnixSocketAddress(path string, perm fs.FileMode) ServerAddress {
return ServerAddress{Address: path, Scheme: NetworkUnix, UnixSocketPermission: perm}
}
func HTTPAddress(urlAddress string) (ServerAddress, error) {
parsedUrl, err := url.Parse(urlAddress)
if err != nil {
return ServerAddress{}, err
}
return ServerAddress{parsedUrl.Host, parsedUrl.Scheme, 0}, nil
}

View file

@ -0,0 +1,25 @@
package config
import (
"io/fs"
"testing"
"github.com/stretchr/testify/assert"
)
func TestHttpAddress_ParseGood(t *testing.T) {
address, err := HTTPAddress("http://localhost:123")
assert.NoError(t, err)
assert.Equal(t, "localhost:123", address.Address)
assert.Equal(t, "tcp", address.Network())
}
func TestHttpAddress_ParseBad(t *testing.T) {
_, err := HTTPAddress(":")
assert.Error(t, err)
}
func TestUnixSocketAddress_Network(t *testing.T) {
address := UnixSocketAddress("/tmp", fs.FileMode(0755))
assert.Equal(t, "unix", address.Network())
}

View file

@ -253,7 +253,7 @@ func (rc *reqCtx) process() (*MSC2836EventRelationshipsResponse, *util.JSONRespo
var res MSC2836EventRelationshipsResponse var res MSC2836EventRelationshipsResponse
var returnEvents []*gomatrixserverlib.HeaderedEvent var returnEvents []*gomatrixserverlib.HeaderedEvent
// Can the user see (according to history visibility) event_id? If no, reject the request, else continue. // Can the user see (according to history visibility) event_id? If no, reject the request, else continue.
event := rc.getLocalEvent(rc.req.EventID) event := rc.getLocalEvent(rc.req.RoomID, rc.req.EventID)
if event == nil { if event == nil {
event = rc.fetchUnknownEvent(rc.req.EventID, rc.req.RoomID) event = rc.fetchUnknownEvent(rc.req.EventID, rc.req.RoomID)
} }
@ -592,7 +592,7 @@ func (rc *reqCtx) remoteEventRelationships(eventID string) *MSC2836EventRelation
// lookForEvent returns the event for the event ID given, by trying to query remote servers // lookForEvent returns the event for the event ID given, by trying to query remote servers
// if the event ID is unknown via /event_relationships. // if the event ID is unknown via /event_relationships.
func (rc *reqCtx) lookForEvent(eventID string) *gomatrixserverlib.HeaderedEvent { func (rc *reqCtx) lookForEvent(eventID string) *gomatrixserverlib.HeaderedEvent {
event := rc.getLocalEvent(eventID) event := rc.getLocalEvent(rc.req.RoomID, eventID)
if event == nil { if event == nil {
queryRes := rc.remoteEventRelationships(eventID) queryRes := rc.remoteEventRelationships(eventID)
if queryRes != nil { if queryRes != nil {
@ -622,9 +622,10 @@ func (rc *reqCtx) lookForEvent(eventID string) *gomatrixserverlib.HeaderedEvent
return nil return nil
} }
func (rc *reqCtx) getLocalEvent(eventID string) *gomatrixserverlib.HeaderedEvent { func (rc *reqCtx) getLocalEvent(roomID, eventID string) *gomatrixserverlib.HeaderedEvent {
var queryEventsRes roomserver.QueryEventsByIDResponse var queryEventsRes roomserver.QueryEventsByIDResponse
err := rc.rsAPI.QueryEventsByID(rc.ctx, &roomserver.QueryEventsByIDRequest{ err := rc.rsAPI.QueryEventsByID(rc.ctx, &roomserver.QueryEventsByIDRequest{
RoomID: roomID,
EventIDs: []string{eventID}, EventIDs: []string{eventID},
}, &queryEventsRes) }, &queryEventsRes)
if err != nil { if err != nil {

View file

@ -212,6 +212,7 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent(
// Finally, work out if there are any more events missing. // Finally, work out if there are any more events missing.
if len(missingEventIDs) > 0 { if len(missingEventIDs) > 0 {
eventsReq := &api.QueryEventsByIDRequest{ eventsReq := &api.QueryEventsByIDRequest{
RoomID: ev.RoomID(),
EventIDs: missingEventIDs, EventIDs: missingEventIDs,
} }
eventsRes := &api.QueryEventsByIDResponse{} eventsRes := &api.QueryEventsByIDResponse{}

View file

@ -109,7 +109,7 @@ func GetMemberships(
} }
qryRes := &api.QueryEventsByIDResponse{} qryRes := &api.QueryEventsByIDResponse{}
if err := rsAPI.QueryEventsByID(req.Context(), &api.QueryEventsByIDRequest{EventIDs: eventIDs}, qryRes); err != nil { if err := rsAPI.QueryEventsByID(req.Context(), &api.QueryEventsByIDRequest{EventIDs: eventIDs, RoomID: roomID}, qryRes); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("rsAPI.QueryEventsByID failed") util.GetLogger(req.Context()).WithError(err).Error("rsAPI.QueryEventsByID failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }

View file

@ -187,8 +187,8 @@ func Test_UserStatistics(t *testing.T) {
}) })
t.Run("Users not active for one/two month", func(t *testing.T) { t.Run("Users not active for one/two month", func(t *testing.T) {
mustUpdateDeviceLastSeen(t, ctx, db, "user1", time.Now().AddDate(0, -2, 0)) mustUpdateDeviceLastSeen(t, ctx, db, "user1", time.Now().AddDate(0, 0, -60))
mustUpdateDeviceLastSeen(t, ctx, db, "user2", time.Now().AddDate(0, -1, 0)) mustUpdateDeviceLastSeen(t, ctx, db, "user2", time.Now().AddDate(0, 0, -30))
gotStats, _, err := statsDB.UserStatistics(ctx, nil) gotStats, _, err := statsDB.UserStatistics(ctx, nil)
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
@ -224,9 +224,9 @@ func Test_UserStatistics(t *testing.T) {
- Where account creation and last_seen are > 30 days apart - Where account creation and last_seen are > 30 days apart
*/ */
t.Run("R30Users tests", func(t *testing.T) { t.Run("R30Users tests", func(t *testing.T) {
mustUserUpdateRegistered(t, ctx, db, "user1", time.Now().AddDate(0, -2, 0)) mustUserUpdateRegistered(t, ctx, db, "user1", time.Now().AddDate(0, 0, -60))
mustUpdateDeviceLastSeen(t, ctx, db, "user1", time.Now()) mustUpdateDeviceLastSeen(t, ctx, db, "user1", time.Now())
mustUserUpdateRegistered(t, ctx, db, "user4", time.Now().AddDate(0, -2, 0)) mustUserUpdateRegistered(t, ctx, db, "user4", time.Now().AddDate(0, 0, -60))
mustUpdateDeviceLastSeen(t, ctx, db, "user4", time.Now()) mustUpdateDeviceLastSeen(t, ctx, db, "user4", time.Now())
startTime := time.Now().AddDate(0, 0, -2) startTime := time.Now().AddDate(0, 0, -2)
err := statsDB.UpdateUserDailyVisits(ctx, nil, startTime, startTime.Truncate(time.Hour*24)) err := statsDB.UpdateUserDailyVisits(ctx, nil, startTime, startTime.Truncate(time.Hour*24))