mirror of
https://github.com/matrix-org/dendrite.git
synced 2024-11-27 00:31:55 -06:00
Merge branch 'main' of github.com:matrix-org/dendrite into gh-pages
This commit is contained in:
commit
bb838a17c2
6
.github/workflows/dendrite.yml
vendored
6
.github/workflows/dendrite.yml
vendored
|
@ -4,7 +4,13 @@ on:
|
||||||
push:
|
push:
|
||||||
branches:
|
branches:
|
||||||
- main
|
- main
|
||||||
|
paths:
|
||||||
|
- '**.go' # only execute on changes to go files
|
||||||
|
- '.github/workflows/**' # or workflow changes
|
||||||
pull_request:
|
pull_request:
|
||||||
|
paths:
|
||||||
|
- '**.go'
|
||||||
|
- '.github/workflows/**'
|
||||||
release:
|
release:
|
||||||
types: [published]
|
types: [published]
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
|
|
|
@ -71,10 +71,10 @@ $ ./bin/generate-keys --tls-cert server.crt --tls-key server.key
|
||||||
|
|
||||||
# Copy and modify the config file - you'll need to set a server name and paths to the keys
|
# Copy and modify the config file - you'll need to set a server name and paths to the keys
|
||||||
# at the very least, along with setting up the database connection strings.
|
# at the very least, along with setting up the database connection strings.
|
||||||
$ cp dendrite-sample.monolith.yaml dendrite.yaml
|
$ cp dendrite-sample.yaml dendrite.yaml
|
||||||
|
|
||||||
# Build and run the server:
|
# Build and run the server:
|
||||||
$ ./bin/dendrite-monolith-server --tls-cert server.crt --tls-key server.key --config dendrite.yaml
|
$ ./bin/dendrite --tls-cert server.crt --tls-key server.key --config dendrite.yaml
|
||||||
|
|
||||||
# Create an user account (add -admin for an admin user).
|
# Create an user account (add -admin for an admin user).
|
||||||
# Specify the localpart only, e.g. 'alice' for '@alice:domain.com'
|
# Specify the localpart only, e.g. 'alice' for '@alice:domain.com'
|
||||||
|
|
|
@ -122,6 +122,7 @@ func (s *OutputRoomEventConsumer) onMessage(
|
||||||
if len(output.NewRoomEvent.AddsStateEventIDs) > 0 {
|
if len(output.NewRoomEvent.AddsStateEventIDs) > 0 {
|
||||||
newEventID := output.NewRoomEvent.Event.EventID()
|
newEventID := output.NewRoomEvent.Event.EventID()
|
||||||
eventsReq := &api.QueryEventsByIDRequest{
|
eventsReq := &api.QueryEventsByIDRequest{
|
||||||
|
RoomID: output.NewRoomEvent.Event.RoomID(),
|
||||||
EventIDs: make([]string, 0, len(output.NewRoomEvent.AddsStateEventIDs)),
|
EventIDs: make([]string, 0, len(output.NewRoomEvent.AddsStateEventIDs)),
|
||||||
}
|
}
|
||||||
eventsRes := &api.QueryEventsByIDResponse{}
|
eventsRes := &api.QueryEventsByIDResponse{}
|
||||||
|
|
|
@ -57,7 +57,7 @@ func SendRedaction(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ev := roomserverAPI.GetEvent(req.Context(), rsAPI, eventID)
|
ev := roomserverAPI.GetEvent(req.Context(), rsAPI, roomID, eventID)
|
||||||
if ev == nil {
|
if ev == nil {
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: 400,
|
Code: 400,
|
||||||
|
|
|
@ -16,6 +16,7 @@ package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"flag"
|
"flag"
|
||||||
|
"io/fs"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
@ -30,6 +31,12 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
unixSocket = flag.String("unix-socket", "",
|
||||||
|
"EXPERIMENTAL(unstable): The HTTP listening unix socket for the server (disables http[s]-bind-address feature)",
|
||||||
|
)
|
||||||
|
unixSocketPermission = flag.Int("unix-socket-permission", 0755,
|
||||||
|
"EXPERIMENTAL(unstable): The HTTP listening unix socket permission for the server",
|
||||||
|
)
|
||||||
httpBindAddr = flag.String("http-bind-address", ":8008", "The HTTP listening port for the server")
|
httpBindAddr = flag.String("http-bind-address", ":8008", "The HTTP listening port for the server")
|
||||||
httpsBindAddr = flag.String("https-bind-address", ":8448", "The HTTPS listening port for the server")
|
httpsBindAddr = flag.String("https-bind-address", ":8448", "The HTTPS listening port for the server")
|
||||||
certFile = flag.String("tls-cert", "", "The PEM formatted X509 certificate to use for TLS")
|
certFile = flag.String("tls-cert", "", "The PEM formatted X509 certificate to use for TLS")
|
||||||
|
@ -38,8 +45,23 @@ var (
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
cfg := setup.ParseFlags(true)
|
cfg := setup.ParseFlags(true)
|
||||||
httpAddr := config.HTTPAddress("http://" + *httpBindAddr)
|
httpAddr := config.ServerAddress{}
|
||||||
httpsAddr := config.HTTPAddress("https://" + *httpsBindAddr)
|
httpsAddr := config.ServerAddress{}
|
||||||
|
if *unixSocket == "" {
|
||||||
|
http, err := config.HTTPAddress("http://" + *httpBindAddr)
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithError(err).Fatalf("Failed to parse http address")
|
||||||
|
}
|
||||||
|
httpAddr = http
|
||||||
|
https, err := config.HTTPAddress("https://" + *httpsBindAddr)
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithError(err).Fatalf("Failed to parse https address")
|
||||||
|
}
|
||||||
|
httpsAddr = https
|
||||||
|
} else {
|
||||||
|
httpAddr = config.UnixSocketAddress(*unixSocket, fs.FileMode(*unixSocketPermission))
|
||||||
|
}
|
||||||
|
|
||||||
options := []basepkg.BaseDendriteOptions{}
|
options := []basepkg.BaseDendriteOptions{}
|
||||||
|
|
||||||
base := basepkg.NewBaseDendrite(cfg, options...)
|
base := basepkg.NewBaseDendrite(cfg, options...)
|
||||||
|
@ -92,7 +114,7 @@ func main() {
|
||||||
base.SetupAndServeHTTP(httpAddr, nil, nil)
|
base.SetupAndServeHTTP(httpAddr, nil, nil)
|
||||||
}()
|
}()
|
||||||
// Handle HTTPS if certificate and key are provided
|
// Handle HTTPS if certificate and key are provided
|
||||||
if *certFile != "" && *keyFile != "" {
|
if *unixSocket == "" && *certFile != "" && *keyFile != "" {
|
||||||
go func() {
|
go func() {
|
||||||
base.SetupAndServeHTTP(httpsAddr, certFile, keyFile)
|
base.SetupAndServeHTTP(httpsAddr, certFile, keyFile)
|
||||||
}()
|
}()
|
||||||
|
|
|
@ -9,7 +9,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
// This is an instrumented main, used when running integration tests (sytest) with code coverage.
|
// This is an instrumented main, used when running integration tests (sytest) with code coverage.
|
||||||
// Compile: go test -c -race -cover -covermode=atomic -o monolith.debug -coverpkg "github.com/matrix-org/..." ./cmd/dendrite-monolith-server
|
// Compile: go test -c -race -cover -covermode=atomic -o monolith.debug -coverpkg "github.com/matrix-org/..." ./cmd/dendrite
|
||||||
// Run the monolith: ./monolith.debug -test.coverprofile=/somewhere/to/dump/integrationcover.out DEVEL --config dendrite.yaml
|
// Run the monolith: ./monolith.debug -test.coverprofile=/somewhere/to/dump/integrationcover.out DEVEL --config dendrite.yaml
|
||||||
// Generate HTML with coverage: go tool cover -html=/somewhere/where/there/is/integrationcover.out -o cover.html
|
// Generate HTML with coverage: go tool cover -html=/somewhere/where/there/is/integrationcover.out -o cover.html
|
||||||
// Source: https://dzone.com/articles/measuring-integration-test-coverage-rate-in-pouchc
|
// Source: https://dzone.com/articles/measuring-integration-test-coverage-rate-in-pouchc
|
||||||
|
|
|
@ -62,9 +62,10 @@ func main() {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
stateres := state.NewStateResolution(roomserverDB, &types.RoomInfo{
|
roomInfo := &types.RoomInfo{
|
||||||
RoomVersion: gomatrixserverlib.RoomVersion(*roomVersion),
|
RoomVersion: gomatrixserverlib.RoomVersion(*roomVersion),
|
||||||
})
|
}
|
||||||
|
stateres := state.NewStateResolution(roomserverDB, roomInfo)
|
||||||
|
|
||||||
if *difference {
|
if *difference {
|
||||||
if len(snapshotNIDs) != 2 {
|
if len(snapshotNIDs) != 2 {
|
||||||
|
@ -87,7 +88,7 @@ func main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
var eventEntries []types.Event
|
var eventEntries []types.Event
|
||||||
eventEntries, err = roomserverDB.Events(ctx, 0, eventNIDs)
|
eventEntries, err = roomserverDB.Events(ctx, roomInfo, eventNIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
@ -145,7 +146,7 @@ func main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Println("Fetching", len(eventNIDMap), "state events")
|
fmt.Println("Fetching", len(eventNIDMap), "state events")
|
||||||
eventEntries, err := roomserverDB.Events(ctx, 0, eventNIDs)
|
eventEntries, err := roomserverDB.Events(ctx, roomInfo, eventNIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
@ -165,7 +166,7 @@ func main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Println("Fetching", len(authEventIDs), "auth events")
|
fmt.Println("Fetching", len(authEventIDs), "auth events")
|
||||||
authEventEntries, err := roomserverDB.EventsFromIDs(ctx, 0, authEventIDs)
|
authEventEntries, err := roomserverDB.EventsFromIDs(ctx, roomInfo, authEventIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,7 +15,7 @@ Dendrite contains an embedded profiler called `pprof`, which is a part of the st
|
||||||
To enable the profiler, start Dendrite with the `PPROFLISTEN` environment variable. This variable specifies which address and port to listen on, e.g.
|
To enable the profiler, start Dendrite with the `PPROFLISTEN` environment variable. This variable specifies which address and port to listen on, e.g.
|
||||||
|
|
||||||
```
|
```
|
||||||
PPROFLISTEN=localhost:65432 ./bin/dendrite-monolith-server ...
|
PPROFLISTEN=localhost:65432 ./bin/dendrite ...
|
||||||
```
|
```
|
||||||
|
|
||||||
If pprof has been enabled successfully, a log line at startup will show that pprof is listening:
|
If pprof has been enabled successfully, a log line at startup will show that pprof is listening:
|
||||||
|
|
|
@ -14,8 +14,8 @@ index 8f0e209c..ad057e52 100644
|
||||||
|
|
||||||
$output->diag( "Starting monolith server" );
|
$output->diag( "Starting monolith server" );
|
||||||
my @command = (
|
my @command = (
|
||||||
- $self->{bindir} . '/dendrite-monolith-server',
|
- $self->{bindir} . '/dendrite',
|
||||||
+ $self->{bindir} . '/dendrite-monolith-server', '--test.coverprofile=' . $self->{hs_dir} . '/integrationcover.log', "DEVEL",
|
+ $self->{bindir} . '/dendrite', '--test.coverprofile=' . $self->{hs_dir} . '/integrationcover.log', "DEVEL",
|
||||||
'--config', $self->{paths}{config},
|
'--config', $self->{paths}{config},
|
||||||
'--http-bind-address', $self->{bind_host} . ':' . $self->unsecure_port,
|
'--http-bind-address', $self->{bind_host} . ':' . $self->unsecure_port,
|
||||||
'--https-bind-address', $self->{bind_host} . ':' . $self->secure_port,
|
'--https-bind-address', $self->{bind_host} . ':' . $self->secure_port,
|
||||||
|
@ -27,9 +27,9 @@ index f009332b..7ea79869 100755
|
||||||
echo >&2 "--- Building dendrite from source"
|
echo >&2 "--- Building dendrite from source"
|
||||||
cd /src
|
cd /src
|
||||||
mkdir -p $GOBIN
|
mkdir -p $GOBIN
|
||||||
-go install -v ./cmd/dendrite-monolith-server
|
-go install -v ./cmd/dendrite
|
||||||
+# go install -v ./cmd/dendrite-monolith-server
|
+# go install -v ./cmd/dendrite
|
||||||
+go test -c -cover -covermode=atomic -o $GOBIN/dendrite-monolith-server -coverpkg "github.com/matrix-org/..." ./cmd/dendrite-monolith-server
|
+go test -c -cover -covermode=atomic -o $GOBIN/dendrite -coverpkg "github.com/matrix-org/..." ./cmd/dendrite
|
||||||
go install -v ./cmd/generate-keys
|
go install -v ./cmd/generate-keys
|
||||||
cd -
|
cd -
|
||||||
```
|
```
|
||||||
|
|
|
@ -49,7 +49,7 @@ tracing:
|
||||||
then run the monolith server:
|
then run the monolith server:
|
||||||
|
|
||||||
```
|
```
|
||||||
./dendrite-monolith-server --tls-cert server.crt --tls-key server.key --config dendrite.yaml
|
./dendrite --tls-cert server.crt --tls-key server.key --config dendrite.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
## Checking traces
|
## Checking traces
|
||||||
|
|
|
@ -28,11 +28,11 @@ The resulting binaries will be placed in the `bin` subfolder.
|
||||||
You can install the Dendrite monolith binary into `$GOPATH/bin` by using `go install`:
|
You can install the Dendrite monolith binary into `$GOPATH/bin` by using `go install`:
|
||||||
|
|
||||||
```sh
|
```sh
|
||||||
go install ./cmd/dendrite-monolith-server
|
go install ./cmd/dendrite
|
||||||
```
|
```
|
||||||
|
|
||||||
Alternatively, you can specify a custom path for the binary to be written to using `go build`:
|
Alternatively, you can specify a custom path for the binary to be written to using `go build`:
|
||||||
|
|
||||||
```sh
|
```sh
|
||||||
go build -o /usr/local/bin/ ./cmd/dendrite-monolith-server
|
go build -o /usr/local/bin/ ./cmd/dendrite
|
||||||
```
|
```
|
||||||
|
|
|
@ -11,11 +11,11 @@ permalink: /installation/install/monolith
|
||||||
You can install the Dendrite monolith binary into `$GOPATH/bin` by using `go install`:
|
You can install the Dendrite monolith binary into `$GOPATH/bin` by using `go install`:
|
||||||
|
|
||||||
```sh
|
```sh
|
||||||
go install ./cmd/dendrite-monolith-server
|
go install ./cmd/dendrite
|
||||||
```
|
```
|
||||||
|
|
||||||
Alternatively, you can specify a custom path for the binary to be written to using `go build`:
|
Alternatively, you can specify a custom path for the binary to be written to using `go build`:
|
||||||
|
|
||||||
```sh
|
```sh
|
||||||
go build -o /usr/local/bin/ ./cmd/dendrite-monolith-server
|
go build -o /usr/local/bin/ ./cmd/dendrite
|
||||||
```
|
```
|
||||||
|
|
|
@ -9,10 +9,10 @@ permalink: /installation/start/monolith
|
||||||
# Starting the monolith
|
# Starting the monolith
|
||||||
|
|
||||||
Once you have completed all of the preparation and installation steps,
|
Once you have completed all of the preparation and installation steps,
|
||||||
you can start your Dendrite monolith deployment by starting the `dendrite-monolith-server`:
|
you can start your Dendrite monolith deployment by starting `dendrite`:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
./dendrite-monolith-server -config /path/to/dendrite.yaml
|
./dendrite -config /path/to/dendrite.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
By default, Dendrite will listen HTTP on port 8008. If you want to change the addresses
|
By default, Dendrite will listen HTTP on port 8008. If you want to change the addresses
|
||||||
|
@ -20,7 +20,7 @@ or ports that Dendrite listens on, you can use the `-http-bind-address` and
|
||||||
`-https-bind-address` command line arguments:
|
`-https-bind-address` command line arguments:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
./dendrite-monolith-server -config /path/to/dendrite.yaml \
|
./dendrite -config /path/to/dendrite.yaml \
|
||||||
-http-bind-address 1.2.3.4:12345 \
|
-http-bind-address 1.2.3.4:12345 \
|
||||||
-https-bind-address 1.2.3.4:54321
|
-https-bind-address 1.2.3.4:54321
|
||||||
```
|
```
|
||||||
|
|
|
@ -11,7 +11,7 @@ Type=simple
|
||||||
User=dendrite
|
User=dendrite
|
||||||
Group=dendrite
|
Group=dendrite
|
||||||
WorkingDirectory=/opt/dendrite/
|
WorkingDirectory=/opt/dendrite/
|
||||||
ExecStart=/opt/dendrite/bin/dendrite-monolith-server
|
ExecStart=/opt/dendrite/bin/dendrite
|
||||||
Restart=always
|
Restart=always
|
||||||
LimitNOFILE=65535
|
LimitNOFILE=65535
|
||||||
|
|
||||||
|
|
|
@ -173,6 +173,7 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent, rew
|
||||||
// Finally, work out if there are any more events missing.
|
// Finally, work out if there are any more events missing.
|
||||||
if len(missingEventIDs) > 0 {
|
if len(missingEventIDs) > 0 {
|
||||||
eventsReq := &api.QueryEventsByIDRequest{
|
eventsReq := &api.QueryEventsByIDRequest{
|
||||||
|
RoomID: ore.Event.RoomID(),
|
||||||
EventIDs: missingEventIDs,
|
EventIDs: missingEventIDs,
|
||||||
}
|
}
|
||||||
eventsRes := &api.QueryEventsByIDResponse{}
|
eventsRes := &api.QueryEventsByIDResponse{}
|
||||||
|
@ -483,7 +484,7 @@ func (s *OutputRoomEventConsumer) lookupStateEvents(
|
||||||
// At this point the missing events are neither the event itself nor are
|
// At this point the missing events are neither the event itself nor are
|
||||||
// they present in our local database. Our only option is to fetch them
|
// they present in our local database. Our only option is to fetch them
|
||||||
// from the roomserver using the query API.
|
// from the roomserver using the query API.
|
||||||
eventReq := api.QueryEventsByIDRequest{EventIDs: missing}
|
eventReq := api.QueryEventsByIDRequest{EventIDs: missing, RoomID: event.RoomID()}
|
||||||
var eventResp api.QueryEventsByIDResponse
|
var eventResp api.QueryEventsByIDResponse
|
||||||
if err := s.rsAPI.QueryEventsByID(s.ctx, &eventReq, &eventResp); err != nil {
|
if err := s.rsAPI.QueryEventsByID(s.ctx, &eventReq, &eventResp); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
|
@ -36,7 +36,7 @@ func GetEventAuth(
|
||||||
return *err
|
return *err
|
||||||
}
|
}
|
||||||
|
|
||||||
event, resErr := fetchEvent(ctx, rsAPI, eventID)
|
event, resErr := fetchEvent(ctx, rsAPI, roomID, eventID)
|
||||||
if resErr != nil {
|
if resErr != nil {
|
||||||
return *resErr
|
return *resErr
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,10 +20,11 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/api"
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
// GetEvent returns the requested event
|
// GetEvent returns the requested event
|
||||||
|
@ -38,7 +39,9 @@ func GetEvent(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return *err
|
return *err
|
||||||
}
|
}
|
||||||
event, err := fetchEvent(ctx, rsAPI, eventID)
|
// /_matrix/federation/v1/event/{eventId} doesn't have a roomID, we use an empty string,
|
||||||
|
// which results in `QueryEventsByID` to first get the event and use that to determine the roomID.
|
||||||
|
event, err := fetchEvent(ctx, rsAPI, "", eventID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return *err
|
return *err
|
||||||
}
|
}
|
||||||
|
@ -60,21 +63,13 @@ func allowedToSeeEvent(
|
||||||
rsAPI api.FederationRoomserverAPI,
|
rsAPI api.FederationRoomserverAPI,
|
||||||
eventID string,
|
eventID string,
|
||||||
) *util.JSONResponse {
|
) *util.JSONResponse {
|
||||||
var authResponse api.QueryServerAllowedToSeeEventResponse
|
allowed, err := rsAPI.QueryServerAllowedToSeeEvent(ctx, origin, eventID)
|
||||||
err := rsAPI.QueryServerAllowedToSeeEvent(
|
|
||||||
ctx,
|
|
||||||
&api.QueryServerAllowedToSeeEventRequest{
|
|
||||||
EventID: eventID,
|
|
||||||
ServerName: origin,
|
|
||||||
},
|
|
||||||
&authResponse,
|
|
||||||
)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resErr := util.ErrorResponse(err)
|
resErr := util.ErrorResponse(err)
|
||||||
return &resErr
|
return &resErr
|
||||||
}
|
}
|
||||||
|
|
||||||
if !authResponse.AllowedToSeeEvent {
|
if !allowed {
|
||||||
resErr := util.MessageResponse(http.StatusForbidden, "server not allowed to see event")
|
resErr := util.MessageResponse(http.StatusForbidden, "server not allowed to see event")
|
||||||
return &resErr
|
return &resErr
|
||||||
}
|
}
|
||||||
|
@ -83,11 +78,11 @@ func allowedToSeeEvent(
|
||||||
}
|
}
|
||||||
|
|
||||||
// fetchEvent fetches the event without auth checks. Returns an error if the event cannot be found.
|
// fetchEvent fetches the event without auth checks. Returns an error if the event cannot be found.
|
||||||
func fetchEvent(ctx context.Context, rsAPI api.FederationRoomserverAPI, eventID string) (*gomatrixserverlib.Event, *util.JSONResponse) {
|
func fetchEvent(ctx context.Context, rsAPI api.FederationRoomserverAPI, roomID, eventID string) (*gomatrixserverlib.Event, *util.JSONResponse) {
|
||||||
var eventsResponse api.QueryEventsByIDResponse
|
var eventsResponse api.QueryEventsByIDResponse
|
||||||
err := rsAPI.QueryEventsByID(
|
err := rsAPI.QueryEventsByID(
|
||||||
ctx,
|
ctx,
|
||||||
&api.QueryEventsByIDRequest{EventIDs: []string{eventID}},
|
&api.QueryEventsByIDRequest{EventIDs: []string{eventID}, RoomID: roomID},
|
||||||
&eventsResponse,
|
&eventsResponse,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -107,7 +107,7 @@ func getState(
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
event, resErr := fetchEvent(ctx, rsAPI, eventID)
|
event, resErr := fetchEvent(ctx, rsAPI, roomID, eventID)
|
||||||
if resErr != nil {
|
if resErr != nil {
|
||||||
return nil, nil, resErr
|
return nil, nil, resErr
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,7 +16,9 @@
|
||||||
// Hooks can only be run in monolith mode.
|
// Hooks can only be run in monolith mode.
|
||||||
package hooks
|
package hooks
|
||||||
|
|
||||||
import "sync"
|
import (
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// KindNewEventPersisted is a hook which is called with *gomatrixserverlib.HeaderedEvent
|
// KindNewEventPersisted is a hook which is called with *gomatrixserverlib.HeaderedEvent
|
||||||
|
|
|
@ -54,7 +54,8 @@ type QueryBulkStateContentAPI interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
type QueryEventsAPI interface {
|
type QueryEventsAPI interface {
|
||||||
// Query a list of events by event ID.
|
// QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine
|
||||||
|
// which room to use by querying the first events roomID.
|
||||||
QueryEventsByID(
|
QueryEventsByID(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
req *QueryEventsByIDRequest,
|
req *QueryEventsByIDRequest,
|
||||||
|
@ -71,7 +72,8 @@ type SyncRoomserverAPI interface {
|
||||||
QueryBulkStateContentAPI
|
QueryBulkStateContentAPI
|
||||||
// QuerySharedUsers returns a list of users who share at least 1 room in common with the given user.
|
// QuerySharedUsers returns a list of users who share at least 1 room in common with the given user.
|
||||||
QuerySharedUsers(ctx context.Context, req *QuerySharedUsersRequest, res *QuerySharedUsersResponse) error
|
QuerySharedUsers(ctx context.Context, req *QuerySharedUsersRequest, res *QuerySharedUsersResponse) error
|
||||||
// Query a list of events by event ID.
|
// QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine
|
||||||
|
// which room to use by querying the first events roomID.
|
||||||
QueryEventsByID(
|
QueryEventsByID(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
req *QueryEventsByIDRequest,
|
req *QueryEventsByIDRequest,
|
||||||
|
@ -108,7 +110,8 @@ type SyncRoomserverAPI interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
type AppserviceRoomserverAPI interface {
|
type AppserviceRoomserverAPI interface {
|
||||||
// Query a list of events by event ID.
|
// QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine
|
||||||
|
// which room to use by querying the first events roomID.
|
||||||
QueryEventsByID(
|
QueryEventsByID(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
req *QueryEventsByIDRequest,
|
req *QueryEventsByIDRequest,
|
||||||
|
@ -182,6 +185,8 @@ type FederationRoomserverAPI interface {
|
||||||
QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error
|
QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error
|
||||||
QueryRoomVersionForRoom(ctx context.Context, req *QueryRoomVersionForRoomRequest, res *QueryRoomVersionForRoomResponse) error
|
QueryRoomVersionForRoom(ctx context.Context, req *QueryRoomVersionForRoomRequest, res *QueryRoomVersionForRoomResponse) error
|
||||||
GetRoomIDForAlias(ctx context.Context, req *GetRoomIDForAliasRequest, res *GetRoomIDForAliasResponse) error
|
GetRoomIDForAlias(ctx context.Context, req *GetRoomIDForAliasRequest, res *GetRoomIDForAliasResponse) error
|
||||||
|
// QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine
|
||||||
|
// which room to use by querying the first events roomID.
|
||||||
QueryEventsByID(ctx context.Context, req *QueryEventsByIDRequest, res *QueryEventsByIDResponse) error
|
QueryEventsByID(ctx context.Context, req *QueryEventsByIDRequest, res *QueryEventsByIDResponse) error
|
||||||
// Query to get state and auth chain for a (potentially hypothetical) event.
|
// Query to get state and auth chain for a (potentially hypothetical) event.
|
||||||
// Takes lists of PrevEventIDs and AuthEventsIDs and uses them to calculate
|
// Takes lists of PrevEventIDs and AuthEventsIDs and uses them to calculate
|
||||||
|
@ -193,7 +198,7 @@ type FederationRoomserverAPI interface {
|
||||||
// Query missing events for a room from roomserver
|
// Query missing events for a room from roomserver
|
||||||
QueryMissingEvents(ctx context.Context, req *QueryMissingEventsRequest, res *QueryMissingEventsResponse) error
|
QueryMissingEvents(ctx context.Context, req *QueryMissingEventsRequest, res *QueryMissingEventsResponse) error
|
||||||
// Query whether a server is allowed to see an event
|
// Query whether a server is allowed to see an event
|
||||||
QueryServerAllowedToSeeEvent(ctx context.Context, req *QueryServerAllowedToSeeEventRequest, res *QueryServerAllowedToSeeEventResponse) error
|
QueryServerAllowedToSeeEvent(ctx context.Context, serverName gomatrixserverlib.ServerName, eventID string) (allowed bool, err error)
|
||||||
QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error
|
QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error
|
||||||
QueryRestrictedJoinAllowed(ctx context.Context, req *QueryRestrictedJoinAllowedRequest, res *QueryRestrictedJoinAllowedResponse) error
|
QueryRestrictedJoinAllowed(ctx context.Context, req *QueryRestrictedJoinAllowedRequest, res *QueryRestrictedJoinAllowedResponse) error
|
||||||
PerformInboundPeek(ctx context.Context, req *PerformInboundPeekRequest, res *PerformInboundPeekResponse) error
|
PerformInboundPeek(ctx context.Context, req *PerformInboundPeekRequest, res *PerformInboundPeekResponse) error
|
||||||
|
|
|
@ -86,6 +86,9 @@ type QueryStateAfterEventsResponse struct {
|
||||||
|
|
||||||
// QueryEventsByIDRequest is a request to QueryEventsByID
|
// QueryEventsByIDRequest is a request to QueryEventsByID
|
||||||
type QueryEventsByIDRequest struct {
|
type QueryEventsByIDRequest struct {
|
||||||
|
// The roomID to query events for. If this is empty, we first try to fetch the roomID from the database
|
||||||
|
// as this is needed for further processing/parsing events.
|
||||||
|
RoomID string `json:"room_id"`
|
||||||
// The event IDs to look up.
|
// The event IDs to look up.
|
||||||
EventIDs []string `json:"event_ids"`
|
EventIDs []string `json:"event_ids"`
|
||||||
}
|
}
|
||||||
|
|
|
@ -108,9 +108,10 @@ func SendInputRoomEvents(
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetEvent returns the event or nil, even on errors.
|
// GetEvent returns the event or nil, even on errors.
|
||||||
func GetEvent(ctx context.Context, rsAPI QueryEventsAPI, eventID string) *gomatrixserverlib.HeaderedEvent {
|
func GetEvent(ctx context.Context, rsAPI QueryEventsAPI, roomID, eventID string) *gomatrixserverlib.HeaderedEvent {
|
||||||
var res QueryEventsByIDResponse
|
var res QueryEventsByIDResponse
|
||||||
err := rsAPI.QueryEventsByID(ctx, &QueryEventsByIDRequest{
|
err := rsAPI.QueryEventsByID(ctx, &QueryEventsByIDRequest{
|
||||||
|
RoomID: roomID,
|
||||||
EventIDs: []string{eventID},
|
EventIDs: []string{eventID},
|
||||||
}, &res)
|
}, &res)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -67,7 +67,7 @@ func CheckForSoftFail(
|
||||||
stateNeeded := gomatrixserverlib.StateNeededForAuth([]*gomatrixserverlib.Event{event.Unwrap()})
|
stateNeeded := gomatrixserverlib.StateNeededForAuth([]*gomatrixserverlib.Event{event.Unwrap()})
|
||||||
|
|
||||||
// Load the actual auth events from the database.
|
// Load the actual auth events from the database.
|
||||||
authEvents, err := loadAuthEvents(ctx, db, roomInfo.RoomNID, stateNeeded, authStateEntries)
|
authEvents, err := loadAuthEvents(ctx, db, roomInfo, stateNeeded, authStateEntries)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return true, fmt.Errorf("loadAuthEvents: %w", err)
|
return true, fmt.Errorf("loadAuthEvents: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -85,7 +85,7 @@ func CheckForSoftFail(
|
||||||
func CheckAuthEvents(
|
func CheckAuthEvents(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
db storage.RoomDatabase,
|
db storage.RoomDatabase,
|
||||||
roomNID types.RoomNID,
|
roomInfo *types.RoomInfo,
|
||||||
event *gomatrixserverlib.HeaderedEvent,
|
event *gomatrixserverlib.HeaderedEvent,
|
||||||
authEventIDs []string,
|
authEventIDs []string,
|
||||||
) ([]types.EventNID, error) {
|
) ([]types.EventNID, error) {
|
||||||
|
@ -100,7 +100,7 @@ func CheckAuthEvents(
|
||||||
stateNeeded := gomatrixserverlib.StateNeededForAuth([]*gomatrixserverlib.Event{event.Unwrap()})
|
stateNeeded := gomatrixserverlib.StateNeededForAuth([]*gomatrixserverlib.Event{event.Unwrap()})
|
||||||
|
|
||||||
// Load the actual auth events from the database.
|
// Load the actual auth events from the database.
|
||||||
authEvents, err := loadAuthEvents(ctx, db, roomNID, stateNeeded, authStateEntries)
|
authEvents, err := loadAuthEvents(ctx, db, roomInfo, stateNeeded, authStateEntries)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("loadAuthEvents: %w", err)
|
return nil, fmt.Errorf("loadAuthEvents: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -193,7 +193,7 @@ func (ae *authEvents) lookupEvent(typeNID types.EventTypeNID, stateKey string) *
|
||||||
func loadAuthEvents(
|
func loadAuthEvents(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
db state.StateResolutionStorage,
|
db state.StateResolutionStorage,
|
||||||
roomNID types.RoomNID,
|
roomInfo *types.RoomInfo,
|
||||||
needed gomatrixserverlib.StateNeeded,
|
needed gomatrixserverlib.StateNeeded,
|
||||||
state []types.StateEntry,
|
state []types.StateEntry,
|
||||||
) (result authEvents, err error) {
|
) (result authEvents, err error) {
|
||||||
|
@ -216,7 +216,7 @@ func loadAuthEvents(
|
||||||
eventNIDs = append(eventNIDs, eventNID)
|
eventNIDs = append(eventNIDs, eventNID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if result.events, err = db.Events(ctx, roomNID, eventNIDs); err != nil {
|
if result.events, err = db.Events(ctx, roomInfo, eventNIDs); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
roomID := ""
|
roomID := ""
|
||||||
|
|
|
@ -85,7 +85,7 @@ func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverNam
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
events, err := db.Events(ctx, info.RoomNID, eventNIDs)
|
events, err := db.Events(ctx, info, eventNIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
@ -157,7 +157,7 @@ func IsInvitePending(
|
||||||
// only keep the "m.room.member" events with a "join" membership. These events are returned.
|
// only keep the "m.room.member" events with a "join" membership. These events are returned.
|
||||||
// Returns an error if there was an issue fetching the events.
|
// Returns an error if there was an issue fetching the events.
|
||||||
func GetMembershipsAtState(
|
func GetMembershipsAtState(
|
||||||
ctx context.Context, db storage.RoomDatabase, roomNID types.RoomNID, stateEntries []types.StateEntry, joinedOnly bool,
|
ctx context.Context, db storage.RoomDatabase, roomInfo *types.RoomInfo, stateEntries []types.StateEntry, joinedOnly bool,
|
||||||
) ([]types.Event, error) {
|
) ([]types.Event, error) {
|
||||||
|
|
||||||
var eventNIDs types.EventNIDs
|
var eventNIDs types.EventNIDs
|
||||||
|
@ -177,7 +177,7 @@ func GetMembershipsAtState(
|
||||||
util.Unique(eventNIDs)
|
util.Unique(eventNIDs)
|
||||||
|
|
||||||
// Get all of the events in this state
|
// Get all of the events in this state
|
||||||
stateEvents, err := db.Events(ctx, roomNID, eventNIDs)
|
stateEvents, err := db.Events(ctx, roomInfo, eventNIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -227,9 +227,9 @@ func MembershipAtEvent(ctx context.Context, db storage.RoomDatabase, info *types
|
||||||
}
|
}
|
||||||
|
|
||||||
func LoadEvents(
|
func LoadEvents(
|
||||||
ctx context.Context, db storage.RoomDatabase, roomNID types.RoomNID, eventNIDs []types.EventNID,
|
ctx context.Context, db storage.RoomDatabase, roomInfo *types.RoomInfo, eventNIDs []types.EventNID,
|
||||||
) ([]*gomatrixserverlib.Event, error) {
|
) ([]*gomatrixserverlib.Event, error) {
|
||||||
stateEvents, err := db.Events(ctx, roomNID, eventNIDs)
|
stateEvents, err := db.Events(ctx, roomInfo, eventNIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -242,13 +242,13 @@ func LoadEvents(
|
||||||
}
|
}
|
||||||
|
|
||||||
func LoadStateEvents(
|
func LoadStateEvents(
|
||||||
ctx context.Context, db storage.RoomDatabase, roomNID types.RoomNID, stateEntries []types.StateEntry,
|
ctx context.Context, db storage.RoomDatabase, roomInfo *types.RoomInfo, stateEntries []types.StateEntry,
|
||||||
) ([]*gomatrixserverlib.Event, error) {
|
) ([]*gomatrixserverlib.Event, error) {
|
||||||
eventNIDs := make([]types.EventNID, len(stateEntries))
|
eventNIDs := make([]types.EventNID, len(stateEntries))
|
||||||
for i := range stateEntries {
|
for i := range stateEntries {
|
||||||
eventNIDs[i] = stateEntries[i].EventNID
|
eventNIDs[i] = stateEntries[i].EventNID
|
||||||
}
|
}
|
||||||
return LoadEvents(ctx, db, roomNID, eventNIDs)
|
return LoadEvents(ctx, db, roomInfo, eventNIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
func CheckServerAllowedToSeeEvent(
|
func CheckServerAllowedToSeeEvent(
|
||||||
|
@ -326,7 +326,7 @@ func slowGetHistoryVisibilityState(
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return LoadStateEvents(ctx, db, info.RoomNID, filteredEntries)
|
return LoadStateEvents(ctx, db, info, filteredEntries)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Remove this when we have tests to assert correctness of this function
|
// TODO: Remove this when we have tests to assert correctness of this function
|
||||||
|
@ -366,7 +366,7 @@ BFSLoop:
|
||||||
next = make([]string, 0)
|
next = make([]string, 0)
|
||||||
}
|
}
|
||||||
// Retrieve the events to process from the database.
|
// Retrieve the events to process from the database.
|
||||||
events, err = db.EventsFromIDs(ctx, info.RoomNID, front)
|
events, err = db.EventsFromIDs(ctx, info, front)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return resultNIDs, redactEventIDs, err
|
return resultNIDs, redactEventIDs, err
|
||||||
}
|
}
|
||||||
|
@ -467,7 +467,7 @@ func QueryLatestEventsAndState(
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
stateEvents, err := LoadStateEvents(ctx, db, roomInfo.RoomNID, stateEntries)
|
stateEvents, err := LoadStateEvents(ctx, db, roomInfo, stateEntries)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,9 +4,10 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/setup/base"
|
"github.com/matrix-org/dendrite/setup/base"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/test"
|
"github.com/matrix-org/dendrite/test"
|
||||||
|
@ -38,9 +39,9 @@ func TestIsInvitePendingWithoutNID(t *testing.T) {
|
||||||
var authNIDs []types.EventNID
|
var authNIDs []types.EventNID
|
||||||
for _, x := range room.Events() {
|
for _, x := range room.Events() {
|
||||||
|
|
||||||
roomNID, err := db.GetOrCreateRoomNID(context.Background(), x.Unwrap())
|
roomInfo, err := db.GetOrCreateRoomInfo(context.Background(), x.Unwrap())
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Greater(t, roomNID, types.RoomNID(0))
|
assert.NotNil(t, roomInfo)
|
||||||
|
|
||||||
eventTypeNID, err := db.GetOrCreateEventTypeNID(context.Background(), x.Type())
|
eventTypeNID, err := db.GetOrCreateEventTypeNID(context.Background(), x.Type())
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
@ -49,7 +50,7 @@ func TestIsInvitePendingWithoutNID(t *testing.T) {
|
||||||
eventStateKeyNID, err := db.GetOrCreateEventStateKeyNID(context.Background(), x.StateKey())
|
eventStateKeyNID, err := db.GetOrCreateEventStateKeyNID(context.Background(), x.StateKey())
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
evNID, _, _, _, err := db.StoreEvent(context.Background(), x.Event, roomNID, eventTypeNID, eventStateKeyNID, authNIDs, false)
|
evNID, _, err := db.StoreEvent(context.Background(), x.Event, roomInfo, eventTypeNID, eventStateKeyNID, authNIDs, false)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
authNIDs = append(authNIDs, evNID)
|
authNIDs = append(authNIDs, evNID)
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,9 +24,10 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/internal/helpers"
|
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/internal/helpers"
|
||||||
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
"github.com/opentracing/opentracing-go"
|
"github.com/opentracing/opentracing-go"
|
||||||
|
@ -274,8 +275,10 @@ func (r *Inputer) processRoomEvent(
|
||||||
|
|
||||||
// Check if the event is allowed by its auth events. If it isn't then
|
// Check if the event is allowed by its auth events. If it isn't then
|
||||||
// we consider the event to be "rejected" — it will still be persisted.
|
// we consider the event to be "rejected" — it will still be persisted.
|
||||||
|
redactAllowed := true
|
||||||
if err = gomatrixserverlib.Allowed(event, &authEvents); err != nil {
|
if err = gomatrixserverlib.Allowed(event, &authEvents); err != nil {
|
||||||
isRejected = true
|
isRejected = true
|
||||||
|
redactAllowed = false
|
||||||
rejectionErr = err
|
rejectionErr = err
|
||||||
logger.WithError(rejectionErr).Warnf("Event %s not allowed by auth events", event.EventID())
|
logger.WithError(rejectionErr).Warnf("Event %s not allowed by auth events", event.EventID())
|
||||||
}
|
}
|
||||||
|
@ -323,7 +326,7 @@ func (r *Inputer) processRoomEvent(
|
||||||
// burning CPU time.
|
// burning CPU time.
|
||||||
historyVisibility := gomatrixserverlib.HistoryVisibilityShared // Default to shared.
|
historyVisibility := gomatrixserverlib.HistoryVisibilityShared // Default to shared.
|
||||||
if input.Kind != api.KindOutlier && rejectionErr == nil && !isRejected && !isCreateEvent {
|
if input.Kind != api.KindOutlier && rejectionErr == nil && !isRejected && !isCreateEvent {
|
||||||
historyVisibility, rejectionErr, err = r.processStateBefore(ctx, roomInfo.RoomNID, input, missingPrev)
|
historyVisibility, rejectionErr, err = r.processStateBefore(ctx, roomInfo, input, missingPrev)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("r.processStateBefore: %w", err)
|
return fmt.Errorf("r.processStateBefore: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -332,9 +335,11 @@ func (r *Inputer) processRoomEvent(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
roomNID, err := r.DB.GetOrCreateRoomNID(ctx, event)
|
if roomInfo == nil {
|
||||||
if err != nil {
|
roomInfo, err = r.DB.GetOrCreateRoomInfo(ctx, event)
|
||||||
return fmt.Errorf("r.DB.GetOrCreateRoomNID: %w", err)
|
if err != nil {
|
||||||
|
return fmt.Errorf("r.DB.GetOrCreateRoomInfo: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
eventTypeNID, err := r.DB.GetOrCreateEventTypeNID(ctx, event.Type())
|
eventTypeNID, err := r.DB.GetOrCreateEventTypeNID(ctx, event.Type())
|
||||||
|
@ -348,15 +353,24 @@ func (r *Inputer) processRoomEvent(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Store the event.
|
// Store the event.
|
||||||
_, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, roomNID, eventTypeNID, eventStateKeyNID, authEventNIDs, isRejected)
|
eventNID, stateAtEvent, err := r.DB.StoreEvent(ctx, event, roomInfo, eventTypeNID, eventStateKeyNID, authEventNIDs, isRejected)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("updater.StoreEvent: %w", err)
|
return fmt.Errorf("updater.StoreEvent: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// if storing this event results in it being redacted then do so.
|
// if storing this event results in it being redacted then do so.
|
||||||
if !isRejected && redactedEventID == event.EventID() {
|
var (
|
||||||
if err = eventutil.RedactEvent(redactionEvent, event); err != nil {
|
redactedEventID string
|
||||||
return fmt.Errorf("eventutil.RedactEvent: %w", rerr)
|
redactionEvent *gomatrixserverlib.Event
|
||||||
|
redactedEvent *gomatrixserverlib.Event
|
||||||
|
)
|
||||||
|
if !isRejected && !isCreateEvent {
|
||||||
|
redactionEvent, redactedEvent, err = r.DB.MaybeRedactEvent(ctx, roomInfo, eventNID, event, redactAllowed)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if redactedEvent != nil {
|
||||||
|
redactedEventID = redactedEvent.EventID()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -489,7 +503,7 @@ func (r *Inputer) handleRemoteRoomUpgrade(ctx context.Context, event *gomatrixse
|
||||||
// nolint:nakedret
|
// nolint:nakedret
|
||||||
func (r *Inputer) processStateBefore(
|
func (r *Inputer) processStateBefore(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
roomNID types.RoomNID,
|
roomInfo *types.RoomInfo,
|
||||||
input *api.InputRoomEvent,
|
input *api.InputRoomEvent,
|
||||||
missingPrev bool,
|
missingPrev bool,
|
||||||
) (historyVisibility gomatrixserverlib.HistoryVisibility, rejectionErr error, err error) {
|
) (historyVisibility gomatrixserverlib.HistoryVisibility, rejectionErr error, err error) {
|
||||||
|
@ -505,7 +519,7 @@ func (r *Inputer) processStateBefore(
|
||||||
case input.HasState:
|
case input.HasState:
|
||||||
// If we're overriding the state then we need to go and retrieve
|
// If we're overriding the state then we need to go and retrieve
|
||||||
// them from the database. It's a hard error if they are missing.
|
// them from the database. It's a hard error if they are missing.
|
||||||
stateEvents, err := r.DB.EventsFromIDs(ctx, roomNID, input.StateEventIDs)
|
stateEvents, err := r.DB.EventsFromIDs(ctx, roomInfo, input.StateEventIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", nil, fmt.Errorf("r.DB.EventsFromIDs: %w", err)
|
return "", nil, fmt.Errorf("r.DB.EventsFromIDs: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -604,7 +618,7 @@ func (r *Inputer) fetchAuthEvents(
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, authEventID := range authEventIDs {
|
for _, authEventID := range authEventIDs {
|
||||||
authEvents, err := r.DB.EventsFromIDs(ctx, roomInfo.RoomNID, []string{authEventID})
|
authEvents, err := r.DB.EventsFromIDs(ctx, roomInfo, []string{authEventID})
|
||||||
if err != nil || len(authEvents) == 0 || authEvents[0].Event == nil {
|
if err != nil || len(authEvents) == 0 || authEvents[0].Event == nil {
|
||||||
unknown[authEventID] = struct{}{}
|
unknown[authEventID] = struct{}{}
|
||||||
continue
|
continue
|
||||||
|
@ -690,9 +704,11 @@ nextAuthEvent:
|
||||||
logger.WithError(err).Warnf("Auth event %s rejected", authEvent.EventID())
|
logger.WithError(err).Warnf("Auth event %s rejected", authEvent.EventID())
|
||||||
}
|
}
|
||||||
|
|
||||||
roomNID, err := r.DB.GetOrCreateRoomNID(ctx, authEvent)
|
if roomInfo == nil {
|
||||||
if err != nil {
|
roomInfo, err = r.DB.GetOrCreateRoomInfo(ctx, authEvent)
|
||||||
return fmt.Errorf("r.DB.GetOrCreateRoomNID: %w", err)
|
if err != nil {
|
||||||
|
return fmt.Errorf("r.DB.GetOrCreateRoomInfo: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
eventTypeNID, err := r.DB.GetOrCreateEventTypeNID(ctx, authEvent.Type())
|
eventTypeNID, err := r.DB.GetOrCreateEventTypeNID(ctx, authEvent.Type())
|
||||||
|
@ -706,7 +722,7 @@ nextAuthEvent:
|
||||||
}
|
}
|
||||||
|
|
||||||
// Finally, store the event in the database.
|
// Finally, store the event in the database.
|
||||||
eventNID, _, _, _, err := r.DB.StoreEvent(ctx, authEvent, roomNID, eventTypeNID, eventStateKeyNID, authEventNIDs, isRejected)
|
eventNID, _, err := r.DB.StoreEvent(ctx, authEvent, roomInfo, eventTypeNID, eventStateKeyNID, authEventNIDs, isRejected)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("updater.StoreEvent: %w", err)
|
return fmt.Errorf("updater.StoreEvent: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -782,7 +798,7 @@ func (r *Inputer) kickGuests(ctx context.Context, event *gomatrixserverlib.Event
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
memberEvents, err := r.DB.Events(ctx, roomInfo.RoomNID, membershipNIDs)
|
memberEvents, err := r.DB.Events(ctx, roomInfo, membershipNIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -53,7 +53,7 @@ func (r *Inputer) updateMemberships(
|
||||||
// Load the event JSON so we can look up the "membership" key.
|
// Load the event JSON so we can look up the "membership" key.
|
||||||
// TODO: Maybe add a membership key to the events table so we can load that
|
// TODO: Maybe add a membership key to the events table so we can load that
|
||||||
// key without having to load the entire event JSON?
|
// key without having to load the entire event JSON?
|
||||||
events, err := updater.Events(ctx, 0, eventNIDs)
|
events, err := updater.Events(ctx, nil, eventNIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -395,7 +395,7 @@ func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, even
|
||||||
for _, entry := range stateEntries {
|
for _, entry := range stateEntries {
|
||||||
stateEventNIDs = append(stateEventNIDs, entry.EventNID)
|
stateEventNIDs = append(stateEventNIDs, entry.EventNID)
|
||||||
}
|
}
|
||||||
stateEvents, err := t.db.Events(ctx, t.roomInfo.RoomNID, stateEventNIDs)
|
stateEvents, err := t.db.Events(ctx, t.roomInfo, stateEventNIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.log.WithError(err).Warnf("failed to load state events locally")
|
t.log.WithError(err).Warnf("failed to load state events locally")
|
||||||
return nil
|
return nil
|
||||||
|
@ -432,7 +432,7 @@ func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, even
|
||||||
missingEventList = append(missingEventList, evID)
|
missingEventList = append(missingEventList, evID)
|
||||||
}
|
}
|
||||||
t.log.WithField("count", len(missingEventList)).Debugf("Fetching missing auth events")
|
t.log.WithField("count", len(missingEventList)).Debugf("Fetching missing auth events")
|
||||||
events, err := t.db.EventsFromIDs(ctx, t.roomInfo.RoomNID, missingEventList)
|
events, err := t.db.EventsFromIDs(ctx, t.roomInfo, missingEventList)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -702,7 +702,7 @@ func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roo
|
||||||
}
|
}
|
||||||
t.haveEventsMutex.Unlock()
|
t.haveEventsMutex.Unlock()
|
||||||
|
|
||||||
events, err := t.db.EventsFromIDs(ctx, t.roomInfo.RoomNID, missingEventList)
|
events, err := t.db.EventsFromIDs(ctx, t.roomInfo, missingEventList)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("t.db.EventsFromIDs: %w", err)
|
return nil, fmt.Errorf("t.db.EventsFromIDs: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -844,7 +844,7 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs
|
||||||
|
|
||||||
if localFirst {
|
if localFirst {
|
||||||
// fetch from the roomserver
|
// fetch from the roomserver
|
||||||
events, err := t.db.EventsFromIDs(ctx, t.roomInfo.RoomNID, []string{missingEventID})
|
events, err := t.db.EventsFromIDs(ctx, t.roomInfo, []string{missingEventID})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.log.Warnf("Failed to query roomserver for missing event %s: %s - falling back to remote", missingEventID, err)
|
t.log.Warnf("Failed to query roomserver for missing event %s: %s - falling back to remote", missingEventID, err)
|
||||||
} else if len(events) == 1 {
|
} else if len(events) == 1 {
|
||||||
|
|
|
@ -70,7 +70,7 @@ func (r *Admin) PerformAdminEvacuateRoom(
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
memberEvents, err := r.DB.Events(ctx, roomInfo.RoomNID, memberNIDs)
|
memberEvents, err := r.DB.Events(ctx, roomInfo, memberNIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
res.Error = &api.PerformError{
|
res.Error = &api.PerformError{
|
||||||
Code: api.PerformErrorBadRequest,
|
Code: api.PerformErrorBadRequest,
|
||||||
|
|
|
@ -23,7 +23,6 @@ import (
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
|
||||||
federationAPI "github.com/matrix-org/dendrite/federationapi/api"
|
federationAPI "github.com/matrix-org/dendrite/federationapi/api"
|
||||||
"github.com/matrix-org/dendrite/internal/eventutil"
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/api"
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
"github.com/matrix-org/dendrite/roomserver/auth"
|
"github.com/matrix-org/dendrite/roomserver/auth"
|
||||||
"github.com/matrix-org/dendrite/roomserver/internal/helpers"
|
"github.com/matrix-org/dendrite/roomserver/internal/helpers"
|
||||||
|
@ -86,7 +85,7 @@ func (r *Backfiller) PerformBackfill(
|
||||||
// Retrieve events from the list that was filled previously. If we fail to get
|
// Retrieve events from the list that was filled previously. If we fail to get
|
||||||
// events from the database then attempt once to get them from federation instead.
|
// events from the database then attempt once to get them from federation instead.
|
||||||
var loadedEvents []*gomatrixserverlib.Event
|
var loadedEvents []*gomatrixserverlib.Event
|
||||||
loadedEvents, err = helpers.LoadEvents(ctx, r.DB, info.RoomNID, resultNIDs)
|
loadedEvents, err = helpers.LoadEvents(ctx, r.DB, info, resultNIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if _, ok := err.(types.MissingEventError); ok {
|
if _, ok := err.(types.MissingEventError); ok {
|
||||||
return r.backfillViaFederation(ctx, request, response)
|
return r.backfillViaFederation(ctx, request, response)
|
||||||
|
@ -473,7 +472,7 @@ FindSuccessor:
|
||||||
// Retrieve all "m.room.member" state events of "join" membership, which
|
// Retrieve all "m.room.member" state events of "join" membership, which
|
||||||
// contains the list of users in the room before the event, therefore all
|
// contains the list of users in the room before the event, therefore all
|
||||||
// the servers in it at that moment.
|
// the servers in it at that moment.
|
||||||
memberEvents, err := helpers.GetMembershipsAtState(ctx, b.db, info.RoomNID, stateEntries, true)
|
memberEvents, err := helpers.GetMembershipsAtState(ctx, b.db, info, stateEntries, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to get memberships before event")
|
logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to get memberships before event")
|
||||||
return nil
|
return nil
|
||||||
|
@ -532,7 +531,7 @@ func (b *backfillRequester) ProvideEvents(roomVer gomatrixserverlib.RoomVersion,
|
||||||
roomNID = nid.RoomNID
|
roomNID = nid.RoomNID
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
eventsWithNids, err := b.db.Events(ctx, roomNID, eventNIDs)
|
eventsWithNids, err := b.db.Events(ctx, &b.roomInfo, eventNIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithError(err).WithField("event_nids", eventNIDs).Error("Failed to load events")
|
logrus.WithError(err).WithField("event_nids", eventNIDs).Error("Failed to load events")
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -562,7 +561,7 @@ func joinEventsFromHistoryVisibility(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get all of the events in this state
|
// Get all of the events in this state
|
||||||
stateEvents, err := db.Events(ctx, roomInfo.RoomNID, eventNIDs)
|
stateEvents, err := db.Events(ctx, roomInfo, eventNIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// even though the default should be shared, restricting the visibility to joined
|
// even though the default should be shared, restricting the visibility to joined
|
||||||
// feels more secure here.
|
// feels more secure here.
|
||||||
|
@ -585,7 +584,7 @@ func joinEventsFromHistoryVisibility(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, visibility, err
|
return nil, visibility, err
|
||||||
}
|
}
|
||||||
evs, err := db.Events(ctx, roomInfo.RoomNID, joinEventNIDs)
|
evs, err := db.Events(ctx, roomInfo, joinEventNIDs)
|
||||||
return evs, visibility, err
|
return evs, visibility, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -606,7 +605,7 @@ func persistEvents(ctx context.Context, db storage.Database, events []*gomatrixs
|
||||||
i++
|
i++
|
||||||
}
|
}
|
||||||
|
|
||||||
roomNID, err = db.GetOrCreateRoomNID(ctx, ev.Unwrap())
|
roomInfo, err := db.GetOrCreateRoomInfo(ctx, ev.Unwrap())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithError(err).Error("failed to get or create roomNID")
|
logrus.WithError(err).Error("failed to get or create roomNID")
|
||||||
continue
|
continue
|
||||||
|
@ -624,23 +623,22 @@ func persistEvents(ctx context.Context, db storage.Database, events []*gomatrixs
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
var redactedEventID string
|
eventNID, _, err = db.StoreEvent(ctx, ev.Unwrap(), roomInfo, eventTypeNID, eventStateKeyNID, authNids, false)
|
||||||
var redactionEvent *gomatrixserverlib.Event
|
|
||||||
eventNID, _, redactionEvent, redactedEventID, err = db.StoreEvent(ctx, ev.Unwrap(), roomNID, eventTypeNID, eventStateKeyNID, authNids, false)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to persist event")
|
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to persist event")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
_, redactedEvent, err := db.MaybeRedactEvent(ctx, roomInfo, eventNID, ev.Unwrap(), true)
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to redact event")
|
||||||
|
continue
|
||||||
|
}
|
||||||
// If storing this event results in it being redacted, then do so.
|
// If storing this event results in it being redacted, then do so.
|
||||||
// It's also possible for this event to be a redaction which results in another event being
|
// It's also possible for this event to be a redaction which results in another event being
|
||||||
// redacted, which we don't care about since we aren't returning it in this backfill.
|
// redacted, which we don't care about since we aren't returning it in this backfill.
|
||||||
if redactedEventID == ev.EventID() {
|
if redactedEvent != nil && redactedEvent.EventID() == ev.EventID() {
|
||||||
eventToRedact := ev.Unwrap()
|
ev = redactedEvent.Headered(ev.RoomVersion)
|
||||||
if err := eventutil.RedactEvent(redactionEvent, eventToRedact); err != nil {
|
|
||||||
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to redact event")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
ev = eventToRedact.Headered(ev.RoomVersion)
|
|
||||||
events[j] = ev
|
events[j] = ev
|
||||||
}
|
}
|
||||||
backfilledEventMap[ev.EventID()] = types.Event{
|
backfilledEventMap[ev.EventID()] = types.Event{
|
||||||
|
|
|
@ -64,7 +64,7 @@ func (r *InboundPeeker) PerformInboundPeek(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
latestEvents, err := r.DB.EventsFromIDs(ctx, info.RoomNID, []string{latestEventRefs[0].EventID})
|
latestEvents, err := r.DB.EventsFromIDs(ctx, info, []string{latestEventRefs[0].EventID})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -88,7 +88,7 @@ func (r *InboundPeeker) PerformInboundPeek(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
stateEvents, err = helpers.LoadStateEvents(ctx, r.DB, info.RoomNID, stateEntries)
|
stateEvents, err = helpers.LoadStateEvents(ctx, r.DB, info, stateEntries)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -100,7 +100,7 @@ func (r *InboundPeeker) PerformInboundPeek(
|
||||||
}
|
}
|
||||||
authEventIDs = util.UniqueStrings(authEventIDs) // de-dupe
|
authEventIDs = util.UniqueStrings(authEventIDs) // de-dupe
|
||||||
|
|
||||||
authEvents, err := query.GetAuthChain(ctx, r.DB.EventsFromIDs, authEventIDs)
|
authEvents, err := query.GetAuthChain(ctx, r.DB.EventsFromIDs, info, authEventIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -194,7 +194,7 @@ func (r *Inviter) PerformInvite(
|
||||||
// try and see if the user is allowed to make this invite. We can't do
|
// try and see if the user is allowed to make this invite. We can't do
|
||||||
// this for invites coming in over federation - we have to take those on
|
// this for invites coming in over federation - we have to take those on
|
||||||
// trust.
|
// trust.
|
||||||
_, err = helpers.CheckAuthEvents(ctx, r.DB, info.RoomNID, event, event.AuthEventIDs())
|
_, err = helpers.CheckAuthEvents(ctx, r.DB, info, event, event.AuthEventIDs())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", event.AuthEventIDs()).Error(
|
logger.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", event.AuthEventIDs()).Error(
|
||||||
"processInviteEvent.checkAuthEvents failed for event",
|
"processInviteEvent.checkAuthEvents failed for event",
|
||||||
|
@ -291,7 +291,7 @@ func buildInviteStrippedState(
|
||||||
for _, stateNID := range stateEntries {
|
for _, stateNID := range stateEntries {
|
||||||
stateNIDs = append(stateNIDs, stateNID.EventNID)
|
stateNIDs = append(stateNIDs, stateNID.EventNID)
|
||||||
}
|
}
|
||||||
stateEvents, err := db.Events(ctx, info.RoomNID, stateNIDs)
|
stateEvents, err := db.Events(ctx, info, stateNIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,11 +21,12 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
"github.com/matrix-org/dendrite/internal/caching"
|
"github.com/matrix-org/dendrite/internal/caching"
|
||||||
"github.com/matrix-org/dendrite/roomserver/acls"
|
"github.com/matrix-org/dendrite/roomserver/acls"
|
||||||
|
@ -102,7 +103,7 @@ func (r *Queryer) QueryStateAfterEvents(
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
stateEvents, err := helpers.LoadStateEvents(ctx, r.DB, info.RoomNID, stateEntries)
|
stateEvents, err := helpers.LoadStateEvents(ctx, r.DB, info, stateEntries)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -114,7 +115,7 @@ func (r *Queryer) QueryStateAfterEvents(
|
||||||
}
|
}
|
||||||
authEventIDs = util.UniqueStrings(authEventIDs)
|
authEventIDs = util.UniqueStrings(authEventIDs)
|
||||||
|
|
||||||
authEvents, err := GetAuthChain(ctx, r.DB.EventsFromIDs, authEventIDs)
|
authEvents, err := GetAuthChain(ctx, r.DB.EventsFromIDs, info, authEventIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("getAuthChain: %w", err)
|
return fmt.Errorf("getAuthChain: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -132,24 +133,46 @@ func (r *Queryer) QueryStateAfterEvents(
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// QueryEventsByID implements api.RoomserverInternalAPI
|
// QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine
|
||||||
|
// which room to use by querying the first events roomID.
|
||||||
func (r *Queryer) QueryEventsByID(
|
func (r *Queryer) QueryEventsByID(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *api.QueryEventsByIDRequest,
|
request *api.QueryEventsByIDRequest,
|
||||||
response *api.QueryEventsByIDResponse,
|
response *api.QueryEventsByIDResponse,
|
||||||
) error {
|
) error {
|
||||||
events, err := r.DB.EventsFromIDs(ctx, 0, request.EventIDs)
|
if len(request.EventIDs) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var err error
|
||||||
|
// We didn't receive a room ID, we need to fetch it first before we can continue.
|
||||||
|
// This happens for e.g. ` /_matrix/federation/v1/event/{eventId}`
|
||||||
|
var roomInfo *types.RoomInfo
|
||||||
|
if request.RoomID == "" {
|
||||||
|
var eventNIDs map[string]types.EventMetadata
|
||||||
|
eventNIDs, err = r.DB.EventNIDs(ctx, []string{request.EventIDs[0]})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if len(eventNIDs) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
roomInfo, err = r.DB.RoomInfoByNID(ctx, eventNIDs[request.EventIDs[0]].RoomNID)
|
||||||
|
} else {
|
||||||
|
roomInfo, err = r.DB.RoomInfo(ctx, request.RoomID)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if roomInfo == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
events, err := r.DB.EventsFromIDs(ctx, roomInfo, request.EventIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, event := range events {
|
for _, event := range events {
|
||||||
roomVersion, verr := r.roomVersion(event.RoomID())
|
response.Events = append(response.Events, event.Headered(roomInfo.RoomVersion))
|
||||||
if verr != nil {
|
|
||||||
return verr
|
|
||||||
}
|
|
||||||
|
|
||||||
response.Events = append(response.Events, event.Headered(roomVersion))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -186,7 +209,7 @@ func (r *Queryer) QueryMembershipForUser(
|
||||||
response.IsInRoom = stillInRoom
|
response.IsInRoom = stillInRoom
|
||||||
response.HasBeenInRoom = true
|
response.HasBeenInRoom = true
|
||||||
|
|
||||||
evs, err := r.DB.Events(ctx, info.RoomNID, []types.EventNID{membershipEventNID})
|
evs, err := r.DB.Events(ctx, info, []types.EventNID{membershipEventNID})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -268,10 +291,10 @@ func (r *Queryer) QueryMembershipAtEvent(
|
||||||
// once. If we have more than one membership event, we need to get the state for each state entry.
|
// once. If we have more than one membership event, we need to get the state for each state entry.
|
||||||
if canShortCircuit {
|
if canShortCircuit {
|
||||||
if len(memberships) == 0 {
|
if len(memberships) == 0 {
|
||||||
memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, info.RoomNID, stateEntry, false)
|
memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, info, stateEntry, false)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, info.RoomNID, stateEntry, false)
|
memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, info, stateEntry, false)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unable to get memberships at state: %w", err)
|
return fmt.Errorf("unable to get memberships at state: %w", err)
|
||||||
|
@ -318,7 +341,7 @@ func (r *Queryer) QueryMembershipsForRoom(
|
||||||
}
|
}
|
||||||
return fmt.Errorf("r.DB.GetMembershipEventNIDsForRoom: %w", err)
|
return fmt.Errorf("r.DB.GetMembershipEventNIDsForRoom: %w", err)
|
||||||
}
|
}
|
||||||
events, err = r.DB.Events(ctx, info.RoomNID, eventNIDs)
|
events, err = r.DB.Events(ctx, info, eventNIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("r.DB.Events: %w", err)
|
return fmt.Errorf("r.DB.Events: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -357,14 +380,14 @@ func (r *Queryer) QueryMembershipsForRoom(
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
events, err = r.DB.Events(ctx, info.RoomNID, eventNIDs)
|
events, err = r.DB.Events(ctx, info, eventNIDs)
|
||||||
} else {
|
} else {
|
||||||
stateEntries, err = helpers.StateBeforeEvent(ctx, r.DB, info, membershipEventNID)
|
stateEntries, err = helpers.StateBeforeEvent(ctx, r.DB, info, membershipEventNID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithField("membership_event_nid", membershipEventNID).WithError(err).Error("failed to load state before event")
|
logrus.WithField("membership_event_nid", membershipEventNID).WithError(err).Error("failed to load state before event")
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
events, err = helpers.GetMembershipsAtState(ctx, r.DB, info.RoomNID, stateEntries, request.JoinedOnly)
|
events, err = helpers.GetMembershipsAtState(ctx, r.DB, info, stateEntries, request.JoinedOnly)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -412,39 +435,39 @@ func (r *Queryer) QueryServerJoinedToRoom(
|
||||||
// QueryServerAllowedToSeeEvent implements api.RoomserverInternalAPI
|
// QueryServerAllowedToSeeEvent implements api.RoomserverInternalAPI
|
||||||
func (r *Queryer) QueryServerAllowedToSeeEvent(
|
func (r *Queryer) QueryServerAllowedToSeeEvent(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *api.QueryServerAllowedToSeeEventRequest,
|
serverName gomatrixserverlib.ServerName,
|
||||||
response *api.QueryServerAllowedToSeeEventResponse,
|
eventID string,
|
||||||
) (err error) {
|
) (allowed bool, err error) {
|
||||||
events, err := r.DB.EventsFromIDs(ctx, 0, []string{request.EventID})
|
events, err := r.DB.EventNIDs(ctx, []string{eventID})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if len(events) == 0 {
|
if len(events) == 0 {
|
||||||
response.AllowedToSeeEvent = false // event doesn't exist so not allowed to see
|
return allowed, nil
|
||||||
return
|
|
||||||
}
|
}
|
||||||
roomID := events[0].RoomID()
|
info, err := r.DB.RoomInfoByNID(ctx, events[eventID].RoomNID)
|
||||||
|
|
||||||
inRoomReq := &api.QueryServerJoinedToRoomRequest{
|
|
||||||
RoomID: roomID,
|
|
||||||
ServerName: request.ServerName,
|
|
||||||
}
|
|
||||||
inRoomRes := &api.QueryServerJoinedToRoomResponse{}
|
|
||||||
if err = r.QueryServerJoinedToRoom(ctx, inRoomReq, inRoomRes); err != nil {
|
|
||||||
return fmt.Errorf("r.Queryer.QueryServerJoinedToRoom: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
info, err := r.DB.RoomInfo(ctx, roomID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return allowed, err
|
||||||
}
|
}
|
||||||
if info == nil || info.IsStub() {
|
if info == nil || info.IsStub() {
|
||||||
return nil
|
return allowed, nil
|
||||||
}
|
}
|
||||||
response.AllowedToSeeEvent, err = helpers.CheckServerAllowedToSeeEvent(
|
var isInRoom bool
|
||||||
ctx, r.DB, info, request.EventID, request.ServerName, inRoomRes.IsInRoom,
|
if r.IsLocalServerName(serverName) || serverName == "" {
|
||||||
|
isInRoom, err = r.DB.GetLocalServerInRoom(ctx, info.RoomNID)
|
||||||
|
if err != nil {
|
||||||
|
return allowed, fmt.Errorf("r.DB.GetLocalServerInRoom: %w", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
isInRoom, err = r.DB.GetServerInRoom(ctx, info.RoomNID, serverName)
|
||||||
|
if err != nil {
|
||||||
|
return allowed, fmt.Errorf("r.DB.GetServerInRoom: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return helpers.CheckServerAllowedToSeeEvent(
|
||||||
|
ctx, r.DB, info, eventID, serverName, isInRoom,
|
||||||
)
|
)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// QueryMissingEvents implements api.RoomserverInternalAPI
|
// QueryMissingEvents implements api.RoomserverInternalAPI
|
||||||
|
@ -466,19 +489,22 @@ func (r *Queryer) QueryMissingEvents(
|
||||||
eventsToFilter[id] = true
|
eventsToFilter[id] = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
events, err := r.DB.EventsFromIDs(ctx, 0, front)
|
if len(front) == 0 {
|
||||||
|
return nil // no events to query, give up.
|
||||||
|
}
|
||||||
|
events, err := r.DB.EventNIDs(ctx, []string{front[0]})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if len(events) == 0 {
|
if len(events) == 0 {
|
||||||
return nil // we are missing the events being asked to search from, give up.
|
return nil // we are missing the events being asked to search from, give up.
|
||||||
}
|
}
|
||||||
info, err := r.DB.RoomInfo(ctx, events[0].RoomID())
|
info, err := r.DB.RoomInfoByNID(ctx, events[front[0]].RoomNID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if info == nil || info.IsStub() {
|
if info == nil || info.IsStub() {
|
||||||
return fmt.Errorf("missing RoomInfo for room %s", events[0].RoomID())
|
return fmt.Errorf("missing RoomInfo for room %d", events[front[0]].RoomNID)
|
||||||
}
|
}
|
||||||
|
|
||||||
resultNIDs, redactEventIDs, err := helpers.ScanEventTree(ctx, r.DB, info, front, visited, request.Limit, request.ServerName)
|
resultNIDs, redactEventIDs, err := helpers.ScanEventTree(ctx, r.DB, info, front, visited, request.Limit, request.ServerName)
|
||||||
|
@ -486,7 +512,7 @@ func (r *Queryer) QueryMissingEvents(
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
loadedEvents, err := helpers.LoadEvents(ctx, r.DB, info.RoomNID, resultNIDs)
|
loadedEvents, err := helpers.LoadEvents(ctx, r.DB, info, resultNIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -529,7 +555,7 @@ func (r *Queryer) QueryStateAndAuthChain(
|
||||||
// TODO: this probably means it should be a different query operation...
|
// TODO: this probably means it should be a different query operation...
|
||||||
if request.OnlyFetchAuthChain {
|
if request.OnlyFetchAuthChain {
|
||||||
var authEvents []*gomatrixserverlib.Event
|
var authEvents []*gomatrixserverlib.Event
|
||||||
authEvents, err = GetAuthChain(ctx, r.DB.EventsFromIDs, request.AuthEventIDs)
|
authEvents, err = GetAuthChain(ctx, r.DB.EventsFromIDs, info, request.AuthEventIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -556,7 +582,7 @@ func (r *Queryer) QueryStateAndAuthChain(
|
||||||
}
|
}
|
||||||
authEventIDs = util.UniqueStrings(authEventIDs) // de-dupe
|
authEventIDs = util.UniqueStrings(authEventIDs) // de-dupe
|
||||||
|
|
||||||
authEvents, err := GetAuthChain(ctx, r.DB.EventsFromIDs, authEventIDs)
|
authEvents, err := GetAuthChain(ctx, r.DB.EventsFromIDs, info, authEventIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -611,18 +637,18 @@ func (r *Queryer) loadStateAtEventIDs(ctx context.Context, roomInfo *types.RoomI
|
||||||
return nil, rejected, false, err
|
return nil, rejected, false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
events, err := helpers.LoadStateEvents(ctx, r.DB, roomInfo.RoomNID, stateEntries)
|
events, err := helpers.LoadStateEvents(ctx, r.DB, roomInfo, stateEntries)
|
||||||
return events, rejected, false, err
|
return events, rejected, false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
type eventsFromIDs func(context.Context, types.RoomNID, []string) ([]types.Event, error)
|
type eventsFromIDs func(context.Context, *types.RoomInfo, []string) ([]types.Event, error)
|
||||||
|
|
||||||
// GetAuthChain fetches the auth chain for the given auth events. An auth chain
|
// GetAuthChain fetches the auth chain for the given auth events. An auth chain
|
||||||
// is the list of all events that are referenced in the auth_events section, and
|
// is the list of all events that are referenced in the auth_events section, and
|
||||||
// all their auth_events, recursively. The returned set of events contain the
|
// all their auth_events, recursively. The returned set of events contain the
|
||||||
// given events. Will *not* error if we don't have all auth events.
|
// given events. Will *not* error if we don't have all auth events.
|
||||||
func GetAuthChain(
|
func GetAuthChain(
|
||||||
ctx context.Context, fn eventsFromIDs, authEventIDs []string,
|
ctx context.Context, fn eventsFromIDs, roomInfo *types.RoomInfo, authEventIDs []string,
|
||||||
) ([]*gomatrixserverlib.Event, error) {
|
) ([]*gomatrixserverlib.Event, error) {
|
||||||
// List of event IDs to fetch. On each pass, these events will be requested
|
// List of event IDs to fetch. On each pass, these events will be requested
|
||||||
// from the database and the `eventsToFetch` will be updated with any new
|
// from the database and the `eventsToFetch` will be updated with any new
|
||||||
|
@ -633,7 +659,7 @@ func GetAuthChain(
|
||||||
|
|
||||||
for len(eventsToFetch) > 0 {
|
for len(eventsToFetch) > 0 {
|
||||||
// Try to retrieve the events from the database.
|
// Try to retrieve the events from the database.
|
||||||
events, err := fn(ctx, 0, eventsToFetch)
|
events, err := fn(ctx, roomInfo, eventsToFetch)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -852,7 +878,7 @@ func (r *Queryer) QueryServerBannedFromRoom(ctx context.Context, req *api.QueryS
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Queryer) QueryAuthChain(ctx context.Context, req *api.QueryAuthChainRequest, res *api.QueryAuthChainResponse) error {
|
func (r *Queryer) QueryAuthChain(ctx context.Context, req *api.QueryAuthChainRequest, res *api.QueryAuthChainResponse) error {
|
||||||
chain, err := GetAuthChain(ctx, r.DB.EventsFromIDs, req.EventIDs)
|
chain, err := GetAuthChain(ctx, r.DB.EventsFromIDs, nil, req.EventIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -971,7 +997,7 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, req *api.Query
|
||||||
// For each of the joined users, let's see if we can get a valid
|
// For each of the joined users, let's see if we can get a valid
|
||||||
// membership event.
|
// membership event.
|
||||||
for _, joinNID := range joinNIDs {
|
for _, joinNID := range joinNIDs {
|
||||||
events, err := r.DB.Events(ctx, roomInfo.RoomNID, []types.EventNID{joinNID})
|
events, err := r.DB.Events(ctx, roomInfo, []types.EventNID{joinNID})
|
||||||
if err != nil || len(events) != 1 {
|
if err != nil || len(events) != 1 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
|
@ -80,7 +80,7 @@ func (db *getEventDB) addFakeEvents(graph map[string][]string) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// EventsFromIDs implements RoomserverInternalAPIEventDB
|
// EventsFromIDs implements RoomserverInternalAPIEventDB
|
||||||
func (db *getEventDB) EventsFromIDs(ctx context.Context, roomNID types.RoomNID, eventIDs []string) (res []types.Event, err error) {
|
func (db *getEventDB) EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) (res []types.Event, err error) {
|
||||||
for _, evID := range eventIDs {
|
for _, evID := range eventIDs {
|
||||||
res = append(res, types.Event{
|
res = append(res, types.Event{
|
||||||
EventNID: 0,
|
EventNID: 0,
|
||||||
|
@ -106,7 +106,7 @@ func TestGetAuthChainSingle(t *testing.T) {
|
||||||
t.Fatalf("Failed to add events to db: %v", err)
|
t.Fatalf("Failed to add events to db: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := GetAuthChain(context.TODO(), db.EventsFromIDs, []string{"e"})
|
result, err := GetAuthChain(context.TODO(), db.EventsFromIDs, nil, []string{"e"})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("getAuthChain failed: %v", err)
|
t.Fatalf("getAuthChain failed: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -139,7 +139,7 @@ func TestGetAuthChainMultiple(t *testing.T) {
|
||||||
t.Fatalf("Failed to add events to db: %v", err)
|
t.Fatalf("Failed to add events to db: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := GetAuthChain(context.TODO(), db.EventsFromIDs, []string{"e", "f"})
|
result, err := GetAuthChain(context.TODO(), db.EventsFromIDs, nil, []string{"e", "f"})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("getAuthChain failed: %v", err)
|
t.Fatalf("getAuthChain failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -278,6 +278,16 @@ func TestPurgeRoom(t *testing.T) {
|
||||||
if roomInfo == nil {
|
if roomInfo == nil {
|
||||||
t.Fatalf("room does not exist")
|
t.Fatalf("room does not exist")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
roomInfo2, err := db.RoomInfoByNID(ctx, roomInfo.RoomNID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(roomInfo, roomInfo2) {
|
||||||
|
t.Fatalf("expected roomInfos to be the same, but they aren't")
|
||||||
|
}
|
||||||
|
|
||||||
// remember the roomInfo before purging
|
// remember the roomInfo before purging
|
||||||
existingRoomInfo := roomInfo
|
existingRoomInfo := roomInfo
|
||||||
|
|
||||||
|
@ -333,6 +343,10 @@ func TestPurgeRoom(t *testing.T) {
|
||||||
if roomInfo != nil {
|
if roomInfo != nil {
|
||||||
t.Fatalf("room should not exist after purging: %+v", roomInfo)
|
t.Fatalf("room should not exist after purging: %+v", roomInfo)
|
||||||
}
|
}
|
||||||
|
roomInfo2, err = db.RoomInfoByNID(ctx, existingRoomInfo.RoomNID)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("expected room to not exist, but it does: %#v", roomInfo2)
|
||||||
|
}
|
||||||
|
|
||||||
// validation below
|
// validation below
|
||||||
|
|
||||||
|
|
|
@ -41,8 +41,8 @@ type StateResolutionStorage interface {
|
||||||
StateEntriesForTuples(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error)
|
StateEntriesForTuples(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error)
|
||||||
StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error)
|
StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error)
|
||||||
AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error)
|
AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error)
|
||||||
Events(ctx context.Context, roomNID types.RoomNID, eventNIDs []types.EventNID) ([]types.Event, error)
|
Events(ctx context.Context, roomInfo *types.RoomInfo, eventNIDs []types.EventNID) ([]types.Event, error)
|
||||||
EventsFromIDs(ctx context.Context, roomNID types.RoomNID, eventIDs []string) ([]types.Event, error)
|
EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type StateResolution struct {
|
type StateResolution struct {
|
||||||
|
@ -975,7 +975,7 @@ func (v *StateResolution) resolveConflictsV2(
|
||||||
|
|
||||||
// Store the newly found auth events in the auth set for this event.
|
// Store the newly found auth events in the auth set for this event.
|
||||||
var authEventMap map[string]types.StateEntry
|
var authEventMap map[string]types.StateEntry
|
||||||
authSets[key], authEventMap, err = loader.loadAuthEvents(sctx, v.roomInfo.RoomNID, conflictedEvent, knownAuthEvents)
|
authSets[key], authEventMap, err = loader.loadAuthEvents(sctx, v.roomInfo, conflictedEvent, knownAuthEvents)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -1091,7 +1091,7 @@ func (v *StateResolution) loadStateEvents(
|
||||||
eventNIDs = append(eventNIDs, entry.EventNID)
|
eventNIDs = append(eventNIDs, entry.EventNID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
events, err := v.db.Events(ctx, v.roomInfo.RoomNID, eventNIDs)
|
events, err := v.db.Events(ctx, v.roomInfo, eventNIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
@ -1120,7 +1120,7 @@ type authEventLoader struct {
|
||||||
// loadAuthEvents loads all of the auth events for a given event recursively,
|
// loadAuthEvents loads all of the auth events for a given event recursively,
|
||||||
// along with a map that contains state entries for all of the auth events.
|
// along with a map that contains state entries for all of the auth events.
|
||||||
func (l *authEventLoader) loadAuthEvents(
|
func (l *authEventLoader) loadAuthEvents(
|
||||||
ctx context.Context, roomNID types.RoomNID, event *gomatrixserverlib.Event, eventMap map[string]types.Event,
|
ctx context.Context, roomInfo *types.RoomInfo, event *gomatrixserverlib.Event, eventMap map[string]types.Event,
|
||||||
) ([]*gomatrixserverlib.Event, map[string]types.StateEntry, error) {
|
) ([]*gomatrixserverlib.Event, map[string]types.StateEntry, error) {
|
||||||
l.Lock()
|
l.Lock()
|
||||||
defer l.Unlock()
|
defer l.Unlock()
|
||||||
|
@ -1155,7 +1155,7 @@ func (l *authEventLoader) loadAuthEvents(
|
||||||
// If we need to get events from the database, go and fetch
|
// If we need to get events from the database, go and fetch
|
||||||
// those now.
|
// those now.
|
||||||
if len(l.lookupFromDB) > 0 {
|
if len(l.lookupFromDB) > 0 {
|
||||||
eventsFromDB, err := l.v.db.EventsFromIDs(ctx, roomNID, l.lookupFromDB)
|
eventsFromDB, err := l.v.db.EventsFromIDs(ctx, roomInfo, l.lookupFromDB)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("v.db.EventsFromIDs: %w", err)
|
return nil, nil, fmt.Errorf("v.db.EventsFromIDs: %w", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -29,6 +29,7 @@ type Database interface {
|
||||||
SupportsConcurrentRoomInputs() bool
|
SupportsConcurrentRoomInputs() bool
|
||||||
// RoomInfo returns room information for the given room ID, or nil if there is no room.
|
// RoomInfo returns room information for the given room ID, or nil if there is no room.
|
||||||
RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error)
|
RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error)
|
||||||
|
RoomInfoByNID(ctx context.Context, roomNID types.RoomNID) (*types.RoomInfo, error)
|
||||||
// Store the room state at an event in the database
|
// Store the room state at an event in the database
|
||||||
AddState(
|
AddState(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
|
@ -69,12 +70,12 @@ type Database interface {
|
||||||
) ([]types.StateEntryList, error)
|
) ([]types.StateEntryList, error)
|
||||||
// Look up the Events for a list of numeric event IDs.
|
// Look up the Events for a list of numeric event IDs.
|
||||||
// Returns a sorted list of events.
|
// Returns a sorted list of events.
|
||||||
Events(ctx context.Context, roomNID types.RoomNID, eventNIDs []types.EventNID) ([]types.Event, error)
|
Events(ctx context.Context, roomInfo *types.RoomInfo, eventNIDs []types.EventNID) ([]types.Event, error)
|
||||||
// Look up snapshot NID for an event ID string
|
// Look up snapshot NID for an event ID string
|
||||||
SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error)
|
SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error)
|
||||||
BulkSelectSnapshotsFromEventIDs(ctx context.Context, eventIDs []string) (map[types.StateSnapshotNID][]string, error)
|
BulkSelectSnapshotsFromEventIDs(ctx context.Context, eventIDs []string) (map[types.StateSnapshotNID][]string, error)
|
||||||
// Stores a matrix room event in the database. Returns the room NID, the state snapshot and the redacted event ID if any, or an error.
|
// Stores a matrix room event in the database. Returns the room NID, the state snapshot or an error.
|
||||||
StoreEvent(ctx context.Context, event *gomatrixserverlib.Event, roomNID types.RoomNID, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, authEventNIDs []types.EventNID, isRejected bool) (types.EventNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error)
|
StoreEvent(ctx context.Context, event *gomatrixserverlib.Event, roomInfo *types.RoomInfo, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, authEventNIDs []types.EventNID, isRejected bool) (types.EventNID, types.StateAtEvent, error)
|
||||||
// Look up the state entries for a list of string event IDs
|
// Look up the state entries for a list of string event IDs
|
||||||
// Returns an error if the there is an error talking to the database
|
// Returns an error if the there is an error talking to the database
|
||||||
// Returns a types.MissingEventError if the event IDs aren't in the database.
|
// Returns a types.MissingEventError if the event IDs aren't in the database.
|
||||||
|
@ -135,7 +136,7 @@ type Database interface {
|
||||||
// EventsFromIDs looks up the Events for a list of event IDs. Does not error if event was
|
// EventsFromIDs looks up the Events for a list of event IDs. Does not error if event was
|
||||||
// not found.
|
// not found.
|
||||||
// Returns an error if the retrieval went wrong.
|
// Returns an error if the retrieval went wrong.
|
||||||
EventsFromIDs(ctx context.Context, roomNID types.RoomNID, eventIDs []string) ([]types.Event, error)
|
EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error)
|
||||||
// Publish or unpublish a room from the room directory.
|
// Publish or unpublish a room from the room directory.
|
||||||
PublishRoom(ctx context.Context, roomID, appserviceID, networkID string, publish bool) error
|
PublishRoom(ctx context.Context, roomID, appserviceID, networkID string, publish bool) error
|
||||||
// Returns a list of room IDs for rooms which are published.
|
// Returns a list of room IDs for rooms which are published.
|
||||||
|
@ -179,36 +180,53 @@ type Database interface {
|
||||||
GetMembershipForHistoryVisibility(
|
GetMembershipForHistoryVisibility(
|
||||||
ctx context.Context, userNID types.EventStateKeyNID, info *types.RoomInfo, eventIDs ...string,
|
ctx context.Context, userNID types.EventStateKeyNID, info *types.RoomInfo, eventIDs ...string,
|
||||||
) (map[string]*gomatrixserverlib.HeaderedEvent, error)
|
) (map[string]*gomatrixserverlib.HeaderedEvent, error)
|
||||||
GetOrCreateRoomNID(ctx context.Context, event *gomatrixserverlib.Event) (types.RoomNID, error)
|
GetOrCreateRoomInfo(ctx context.Context, event *gomatrixserverlib.Event) (*types.RoomInfo, error)
|
||||||
GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error)
|
GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error)
|
||||||
GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error)
|
GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error)
|
||||||
|
MaybeRedactEvent(
|
||||||
|
ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event *gomatrixserverlib.Event, redactAllowed bool,
|
||||||
|
) (*gomatrixserverlib.Event, *gomatrixserverlib.Event, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type RoomDatabase interface {
|
type RoomDatabase interface {
|
||||||
|
EventDatabase
|
||||||
// RoomInfo returns room information for the given room ID, or nil if there is no room.
|
// RoomInfo returns room information for the given room ID, or nil if there is no room.
|
||||||
RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error)
|
RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error)
|
||||||
|
RoomInfoByNID(ctx context.Context, roomNID types.RoomNID) (*types.RoomInfo, error)
|
||||||
// IsEventRejected returns true if the event is known and rejected.
|
// IsEventRejected returns true if the event is known and rejected.
|
||||||
IsEventRejected(ctx context.Context, roomNID types.RoomNID, eventID string) (rejected bool, err error)
|
IsEventRejected(ctx context.Context, roomNID types.RoomNID, eventID string) (rejected bool, err error)
|
||||||
MissingAuthPrevEvents(ctx context.Context, e *gomatrixserverlib.Event) (missingAuth, missingPrev []string, err error)
|
MissingAuthPrevEvents(ctx context.Context, e *gomatrixserverlib.Event) (missingAuth, missingPrev []string, err error)
|
||||||
// Stores a matrix room event in the database. Returns the room NID, the state snapshot and the redacted event ID if any, or an error.
|
|
||||||
StoreEvent(ctx context.Context, event *gomatrixserverlib.Event, roomNID types.RoomNID, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, authEventNIDs []types.EventNID, isRejected bool) (types.EventNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error)
|
|
||||||
UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventSender string) error
|
UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventSender string) error
|
||||||
GetRoomUpdater(ctx context.Context, roomInfo *types.RoomInfo) (*shared.RoomUpdater, error)
|
GetRoomUpdater(ctx context.Context, roomInfo *types.RoomInfo) (*shared.RoomUpdater, error)
|
||||||
StateEntriesForEventIDs(ctx context.Context, eventIDs []string, excludeRejected bool) ([]types.StateEntry, error)
|
|
||||||
GetMembershipEventNIDsForRoom(ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool) ([]types.EventNID, error)
|
GetMembershipEventNIDsForRoom(ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool) ([]types.EventNID, error)
|
||||||
StateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error)
|
StateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error)
|
||||||
StateEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error)
|
StateEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error)
|
||||||
SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error)
|
|
||||||
BulkSelectSnapshotsFromEventIDs(ctx context.Context, eventIDs []string) (map[types.StateSnapshotNID][]string, error)
|
BulkSelectSnapshotsFromEventIDs(ctx context.Context, eventIDs []string) (map[types.StateSnapshotNID][]string, error)
|
||||||
StateEntriesForTuples(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error)
|
StateEntriesForTuples(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error)
|
||||||
StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error)
|
|
||||||
AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error)
|
AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error)
|
||||||
Events(ctx context.Context, roomNID types.RoomNID, eventNIDs []types.EventNID) ([]types.Event, error)
|
|
||||||
EventsFromIDs(ctx context.Context, roomNID types.RoomNID, eventIDs []string) ([]types.Event, error)
|
|
||||||
LatestEventIDs(ctx context.Context, roomNID types.RoomNID) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error)
|
LatestEventIDs(ctx context.Context, roomNID types.RoomNID) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error)
|
||||||
EventTypeNIDs(ctx context.Context, eventTypes []string) (map[string]types.EventTypeNID, error)
|
GetOrCreateRoomInfo(ctx context.Context, event *gomatrixserverlib.Event) (*types.RoomInfo, error)
|
||||||
EventStateKeyNIDs(ctx context.Context, eventStateKeys []string) (map[string]types.EventStateKeyNID, error)
|
|
||||||
GetOrCreateRoomNID(ctx context.Context, event *gomatrixserverlib.Event) (types.RoomNID, error)
|
|
||||||
GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error)
|
GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error)
|
||||||
GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error)
|
GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error)
|
||||||
|
GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type EventDatabase interface {
|
||||||
|
EventTypeNIDs(ctx context.Context, eventTypes []string) (map[string]types.EventTypeNID, error)
|
||||||
|
EventStateKeys(ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error)
|
||||||
|
EventStateKeyNIDs(ctx context.Context, eventStateKeys []string) (map[string]types.EventStateKeyNID, error)
|
||||||
|
StateEntriesForEventIDs(ctx context.Context, eventIDs []string, excludeRejected bool) ([]types.StateEntry, error)
|
||||||
|
EventNIDs(ctx context.Context, eventIDs []string) (map[string]types.EventMetadata, error)
|
||||||
|
SetState(ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID) error
|
||||||
|
StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error)
|
||||||
|
SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error)
|
||||||
|
EventIDs(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error)
|
||||||
|
EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error)
|
||||||
|
Events(ctx context.Context, roomInfo *types.RoomInfo, eventNIDs []types.EventNID) ([]types.Event, error)
|
||||||
|
// MaybeRedactEvent returns the redaction event and the redacted event if this call resulted in a redaction, else an error
|
||||||
|
// (nil if there was nothing to do)
|
||||||
|
MaybeRedactEvent(
|
||||||
|
ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event *gomatrixserverlib.Event, redactAllowed bool,
|
||||||
|
) (*gomatrixserverlib.Event, *gomatrixserverlib.Event, error)
|
||||||
|
StoreEvent(ctx context.Context, event *gomatrixserverlib.Event, roomInfo *types.RoomInfo, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, authEventNIDs []types.EventNID, isRejected bool) (types.EventNID, types.StateAtEvent, error)
|
||||||
}
|
}
|
||||||
|
|
|
@ -194,23 +194,28 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
d.Database = shared.Database{
|
d.Database = shared.Database{
|
||||||
DB: db,
|
DB: db,
|
||||||
Cache: cache,
|
EventDatabase: shared.EventDatabase{
|
||||||
Writer: writer,
|
DB: db,
|
||||||
EventTypesTable: eventTypes,
|
Cache: cache,
|
||||||
EventStateKeysTable: eventStateKeys,
|
Writer: writer,
|
||||||
EventJSONTable: eventJSON,
|
EventsTable: events,
|
||||||
EventsTable: events,
|
EventJSONTable: eventJSON,
|
||||||
RoomsTable: rooms,
|
EventTypesTable: eventTypes,
|
||||||
StateBlockTable: stateBlock,
|
EventStateKeysTable: eventStateKeys,
|
||||||
StateSnapshotTable: stateSnapshot,
|
PrevEventsTable: prevEvents,
|
||||||
PrevEventsTable: prevEvents,
|
RedactionsTable: redactions,
|
||||||
RoomAliasesTable: roomAliases,
|
},
|
||||||
InvitesTable: invites,
|
Cache: cache,
|
||||||
MembershipTable: membership,
|
Writer: writer,
|
||||||
PublishedTable: published,
|
RoomsTable: rooms,
|
||||||
RedactionsTable: redactions,
|
StateBlockTable: stateBlock,
|
||||||
Purge: purge,
|
StateSnapshotTable: stateSnapshot,
|
||||||
|
RoomAliasesTable: roomAliases,
|
||||||
|
InvitesTable: invites,
|
||||||
|
MembershipTable: membership,
|
||||||
|
PublishedTable: published,
|
||||||
|
Purge: purge,
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -116,8 +116,8 @@ func (u *RoomUpdater) StorePreviousEvents(eventNID types.EventNID, previousEvent
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *RoomUpdater) Events(ctx context.Context, _ types.RoomNID, eventNIDs []types.EventNID) ([]types.Event, error) {
|
func (u *RoomUpdater) Events(ctx context.Context, _ *types.RoomInfo, eventNIDs []types.EventNID) ([]types.Event, error) {
|
||||||
return u.d.events(ctx, u.txn, u.roomInfo.RoomNID, eventNIDs)
|
return u.d.events(ctx, u.txn, u.roomInfo, eventNIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *RoomUpdater) SnapshotNIDFromEventID(
|
func (u *RoomUpdater) SnapshotNIDFromEventID(
|
||||||
|
@ -195,8 +195,8 @@ func (u *RoomUpdater) StateAtEventIDs(
|
||||||
return u.d.EventsTable.BulkSelectStateAtEventByID(ctx, u.txn, eventIDs)
|
return u.d.EventsTable.BulkSelectStateAtEventByID(ctx, u.txn, eventIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *RoomUpdater) EventsFromIDs(ctx context.Context, roomNID types.RoomNID, eventIDs []string) ([]types.Event, error) {
|
func (u *RoomUpdater) EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error) {
|
||||||
return u.d.eventsFromIDs(ctx, u.txn, roomNID, eventIDs, NoFilter)
|
return u.d.eventsFromIDs(ctx, u.txn, u.roomInfo, eventIDs, NoFilter)
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsReferenced implements types.RoomRecentEventsUpdater
|
// IsReferenced implements types.RoomRecentEventsUpdater
|
||||||
|
|
|
@ -9,7 +9,6 @@ import (
|
||||||
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/caching"
|
"github.com/matrix-org/dendrite/internal/caching"
|
||||||
|
@ -28,6 +27,23 @@ import (
|
||||||
const redactionsArePermanent = true
|
const redactionsArePermanent = true
|
||||||
|
|
||||||
type Database struct {
|
type Database struct {
|
||||||
|
DB *sql.DB
|
||||||
|
EventDatabase
|
||||||
|
Cache caching.RoomServerCaches
|
||||||
|
Writer sqlutil.Writer
|
||||||
|
RoomsTable tables.Rooms
|
||||||
|
StateSnapshotTable tables.StateSnapshot
|
||||||
|
StateBlockTable tables.StateBlock
|
||||||
|
RoomAliasesTable tables.RoomAliases
|
||||||
|
InvitesTable tables.Invites
|
||||||
|
MembershipTable tables.Membership
|
||||||
|
PublishedTable tables.Published
|
||||||
|
Purge tables.Purge
|
||||||
|
GetRoomUpdaterFn func(ctx context.Context, roomInfo *types.RoomInfo) (*RoomUpdater, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// EventDatabase contains all tables needed to work with events
|
||||||
|
type EventDatabase struct {
|
||||||
DB *sql.DB
|
DB *sql.DB
|
||||||
Cache caching.RoomServerCaches
|
Cache caching.RoomServerCaches
|
||||||
Writer sqlutil.Writer
|
Writer sqlutil.Writer
|
||||||
|
@ -35,17 +51,8 @@ type Database struct {
|
||||||
EventJSONTable tables.EventJSON
|
EventJSONTable tables.EventJSON
|
||||||
EventTypesTable tables.EventTypes
|
EventTypesTable tables.EventTypes
|
||||||
EventStateKeysTable tables.EventStateKeys
|
EventStateKeysTable tables.EventStateKeys
|
||||||
RoomsTable tables.Rooms
|
|
||||||
StateSnapshotTable tables.StateSnapshot
|
|
||||||
StateBlockTable tables.StateBlock
|
|
||||||
RoomAliasesTable tables.RoomAliases
|
|
||||||
PrevEventsTable tables.PreviousEvents
|
PrevEventsTable tables.PreviousEvents
|
||||||
InvitesTable tables.Invites
|
|
||||||
MembershipTable tables.Membership
|
|
||||||
PublishedTable tables.Published
|
|
||||||
RedactionsTable tables.Redactions
|
RedactionsTable tables.Redactions
|
||||||
Purge tables.Purge
|
|
||||||
GetRoomUpdaterFn func(ctx context.Context, roomInfo *types.RoomInfo) (*RoomUpdater, error)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) SupportsConcurrentRoomInputs() bool {
|
func (d *Database) SupportsConcurrentRoomInputs() bool {
|
||||||
|
@ -58,13 +65,13 @@ func (d *Database) GetMembershipForHistoryVisibility(
|
||||||
return d.StateSnapshotTable.BulkSelectMembershipForHistoryVisibility(ctx, nil, userNID, roomInfo, eventIDs...)
|
return d.StateSnapshotTable.BulkSelectMembershipForHistoryVisibility(ctx, nil, userNID, roomInfo, eventIDs...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) EventTypeNIDs(
|
func (d *EventDatabase) EventTypeNIDs(
|
||||||
ctx context.Context, eventTypes []string,
|
ctx context.Context, eventTypes []string,
|
||||||
) (map[string]types.EventTypeNID, error) {
|
) (map[string]types.EventTypeNID, error) {
|
||||||
return d.eventTypeNIDs(ctx, nil, eventTypes)
|
return d.eventTypeNIDs(ctx, nil, eventTypes)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) eventTypeNIDs(
|
func (d *EventDatabase) eventTypeNIDs(
|
||||||
ctx context.Context, txn *sql.Tx, eventTypes []string,
|
ctx context.Context, txn *sql.Tx, eventTypes []string,
|
||||||
) (map[string]types.EventTypeNID, error) {
|
) (map[string]types.EventTypeNID, error) {
|
||||||
result := make(map[string]types.EventTypeNID)
|
result := make(map[string]types.EventTypeNID)
|
||||||
|
@ -91,7 +98,7 @@ func (d *Database) eventTypeNIDs(
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) EventStateKeys(
|
func (d *EventDatabase) EventStateKeys(
|
||||||
ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID,
|
ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID,
|
||||||
) (map[types.EventStateKeyNID]string, error) {
|
) (map[types.EventStateKeyNID]string, error) {
|
||||||
result := make(map[types.EventStateKeyNID]string, len(eventStateKeyNIDs))
|
result := make(map[types.EventStateKeyNID]string, len(eventStateKeyNIDs))
|
||||||
|
@ -116,13 +123,13 @@ func (d *Database) EventStateKeys(
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) EventStateKeyNIDs(
|
func (d *EventDatabase) EventStateKeyNIDs(
|
||||||
ctx context.Context, eventStateKeys []string,
|
ctx context.Context, eventStateKeys []string,
|
||||||
) (map[string]types.EventStateKeyNID, error) {
|
) (map[string]types.EventStateKeyNID, error) {
|
||||||
return d.eventStateKeyNIDs(ctx, nil, eventStateKeys)
|
return d.eventStateKeyNIDs(ctx, nil, eventStateKeys)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) eventStateKeyNIDs(
|
func (d *EventDatabase) eventStateKeyNIDs(
|
||||||
ctx context.Context, txn *sql.Tx, eventStateKeys []string,
|
ctx context.Context, txn *sql.Tx, eventStateKeys []string,
|
||||||
) (map[string]types.EventStateKeyNID, error) {
|
) (map[string]types.EventStateKeyNID, error) {
|
||||||
result := make(map[string]types.EventStateKeyNID)
|
result := make(map[string]types.EventStateKeyNID)
|
||||||
|
@ -174,7 +181,7 @@ func (d *Database) eventStateKeyNIDs(
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) StateEntriesForEventIDs(
|
func (d *EventDatabase) StateEntriesForEventIDs(
|
||||||
ctx context.Context, eventIDs []string, excludeRejected bool,
|
ctx context.Context, eventIDs []string, excludeRejected bool,
|
||||||
) ([]types.StateEntry, error) {
|
) ([]types.StateEntry, error) {
|
||||||
return d.EventsTable.BulkSelectStateEventByID(ctx, nil, eventIDs, excludeRejected)
|
return d.EventsTable.BulkSelectStateEventByID(ctx, nil, eventIDs, excludeRejected)
|
||||||
|
@ -213,6 +220,17 @@ func (d *Database) stateEntriesForTuples(
|
||||||
return lists, nil
|
return lists, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *Database) RoomInfoByNID(ctx context.Context, roomNID types.RoomNID) (*types.RoomInfo, error) {
|
||||||
|
roomIDs, err := d.RoomsTable.BulkSelectRoomIDs(ctx, nil, []types.RoomNID{roomNID})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(roomIDs) == 0 {
|
||||||
|
return nil, fmt.Errorf("room does not exist")
|
||||||
|
}
|
||||||
|
return d.roomInfo(ctx, nil, roomIDs[0])
|
||||||
|
}
|
||||||
|
|
||||||
func (d *Database) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) {
|
func (d *Database) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) {
|
||||||
return d.roomInfo(ctx, nil, roomID)
|
return d.roomInfo(ctx, nil, roomID)
|
||||||
}
|
}
|
||||||
|
@ -292,7 +310,7 @@ func (d *Database) addState(
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) EventNIDs(
|
func (d *EventDatabase) EventNIDs(
|
||||||
ctx context.Context, eventIDs []string,
|
ctx context.Context, eventIDs []string,
|
||||||
) (map[string]types.EventMetadata, error) {
|
) (map[string]types.EventMetadata, error) {
|
||||||
return d.eventNIDs(ctx, nil, eventIDs, NoFilter)
|
return d.eventNIDs(ctx, nil, eventIDs, NoFilter)
|
||||||
|
@ -305,7 +323,7 @@ const (
|
||||||
FilterUnsentOnly UnsentFilter = true
|
FilterUnsentOnly UnsentFilter = true
|
||||||
)
|
)
|
||||||
|
|
||||||
func (d *Database) eventNIDs(
|
func (d *EventDatabase) eventNIDs(
|
||||||
ctx context.Context, txn *sql.Tx, eventIDs []string, filter UnsentFilter,
|
ctx context.Context, txn *sql.Tx, eventIDs []string, filter UnsentFilter,
|
||||||
) (map[string]types.EventMetadata, error) {
|
) (map[string]types.EventMetadata, error) {
|
||||||
switch filter {
|
switch filter {
|
||||||
|
@ -318,7 +336,7 @@ func (d *Database) eventNIDs(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) SetState(
|
func (d *EventDatabase) SetState(
|
||||||
ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID,
|
ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID,
|
||||||
) error {
|
) error {
|
||||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
@ -326,19 +344,19 @@ func (d *Database) SetState(
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) StateAtEventIDs(
|
func (d *EventDatabase) StateAtEventIDs(
|
||||||
ctx context.Context, eventIDs []string,
|
ctx context.Context, eventIDs []string,
|
||||||
) ([]types.StateAtEvent, error) {
|
) ([]types.StateAtEvent, error) {
|
||||||
return d.EventsTable.BulkSelectStateAtEventByID(ctx, nil, eventIDs)
|
return d.EventsTable.BulkSelectStateAtEventByID(ctx, nil, eventIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) SnapshotNIDFromEventID(
|
func (d *EventDatabase) SnapshotNIDFromEventID(
|
||||||
ctx context.Context, eventID string,
|
ctx context.Context, eventID string,
|
||||||
) (types.StateSnapshotNID, error) {
|
) (types.StateSnapshotNID, error) {
|
||||||
return d.snapshotNIDFromEventID(ctx, nil, eventID)
|
return d.snapshotNIDFromEventID(ctx, nil, eventID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) snapshotNIDFromEventID(
|
func (d *EventDatabase) snapshotNIDFromEventID(
|
||||||
ctx context.Context, txn *sql.Tx, eventID string,
|
ctx context.Context, txn *sql.Tx, eventID string,
|
||||||
) (types.StateSnapshotNID, error) {
|
) (types.StateSnapshotNID, error) {
|
||||||
_, stateNID, err := d.EventsTable.SelectEvent(ctx, txn, eventID)
|
_, stateNID, err := d.EventsTable.SelectEvent(ctx, txn, eventID)
|
||||||
|
@ -351,17 +369,17 @@ func (d *Database) snapshotNIDFromEventID(
|
||||||
return stateNID, err
|
return stateNID, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) EventIDs(
|
func (d *EventDatabase) EventIDs(
|
||||||
ctx context.Context, eventNIDs []types.EventNID,
|
ctx context.Context, eventNIDs []types.EventNID,
|
||||||
) (map[types.EventNID]string, error) {
|
) (map[types.EventNID]string, error) {
|
||||||
return d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs)
|
return d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) EventsFromIDs(ctx context.Context, roomNID types.RoomNID, eventIDs []string) ([]types.Event, error) {
|
func (d *EventDatabase) EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error) {
|
||||||
return d.eventsFromIDs(ctx, nil, roomNID, eventIDs, NoFilter)
|
return d.eventsFromIDs(ctx, nil, roomInfo, eventIDs, NoFilter)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) eventsFromIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventIDs []string, filter UnsentFilter) ([]types.Event, error) {
|
func (d *EventDatabase) eventsFromIDs(ctx context.Context, txn *sql.Tx, roomInfo *types.RoomInfo, eventIDs []string, filter UnsentFilter) ([]types.Event, error) {
|
||||||
nidMap, err := d.eventNIDs(ctx, txn, eventIDs, filter)
|
nidMap, err := d.eventNIDs(ctx, txn, eventIDs, filter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -370,15 +388,9 @@ func (d *Database) eventsFromIDs(ctx context.Context, txn *sql.Tx, roomNID types
|
||||||
var nids []types.EventNID
|
var nids []types.EventNID
|
||||||
for _, nid := range nidMap {
|
for _, nid := range nidMap {
|
||||||
nids = append(nids, nid.EventNID)
|
nids = append(nids, nid.EventNID)
|
||||||
if roomNID != 0 && roomNID != nid.RoomNID {
|
|
||||||
logrus.Errorf("expected events from room %d, but also found %d", roomNID, nid.RoomNID)
|
|
||||||
}
|
|
||||||
if roomNID == 0 {
|
|
||||||
roomNID = nid.RoomNID
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return d.events(ctx, txn, roomNID, nids)
|
return d.events(ctx, txn, roomInfo, nids)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) LatestEventIDs(
|
func (d *Database) LatestEventIDs(
|
||||||
|
@ -517,19 +529,17 @@ func (d *Database) GetInvitesForUser(
|
||||||
return d.InvitesTable.SelectInviteActiveForUserInRoom(ctx, nil, targetUserNID, roomNID)
|
return d.InvitesTable.SelectInviteActiveForUserInRoom(ctx, nil, targetUserNID, roomNID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) Events(
|
func (d *EventDatabase) Events(ctx context.Context, roomInfo *types.RoomInfo, eventNIDs []types.EventNID) ([]types.Event, error) {
|
||||||
ctx context.Context, roomNID types.RoomNID, eventNIDs []types.EventNID,
|
return d.events(ctx, nil, roomInfo, eventNIDs)
|
||||||
) ([]types.Event, error) {
|
|
||||||
return d.events(ctx, nil, roomNID, eventNIDs)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) events(
|
func (d *EventDatabase) events(
|
||||||
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, inputEventNIDs types.EventNIDs,
|
ctx context.Context, txn *sql.Tx, roomInfo *types.RoomInfo, inputEventNIDs types.EventNIDs,
|
||||||
) ([]types.Event, error) {
|
) ([]types.Event, error) {
|
||||||
if roomNID == 0 {
|
if roomInfo == nil { // this should never happen
|
||||||
// No need to go further, as we won't find any events for this room.
|
return nil, fmt.Errorf("unable to parse events without roomInfo")
|
||||||
return nil, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
sort.Sort(inputEventNIDs)
|
sort.Sort(inputEventNIDs)
|
||||||
events := make(map[types.EventNID]*gomatrixserverlib.Event, len(inputEventNIDs))
|
events := make(map[types.EventNID]*gomatrixserverlib.Event, len(inputEventNIDs))
|
||||||
eventNIDs := make([]types.EventNID, 0, len(inputEventNIDs))
|
eventNIDs := make([]types.EventNID, 0, len(inputEventNIDs))
|
||||||
|
@ -566,31 +576,9 @@ func (d *Database) events(
|
||||||
eventIDs = map[types.EventNID]string{}
|
eventIDs = map[types.EventNID]string{}
|
||||||
}
|
}
|
||||||
|
|
||||||
var roomVersion gomatrixserverlib.RoomVersion
|
|
||||||
var fetchRoomVersion bool
|
|
||||||
var ok bool
|
|
||||||
var roomID string
|
|
||||||
if roomID, ok = d.Cache.GetRoomServerRoomID(roomNID); ok {
|
|
||||||
roomVersion, ok = d.Cache.GetRoomVersion(roomID)
|
|
||||||
if !ok {
|
|
||||||
fetchRoomVersion = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if roomVersion == "" || fetchRoomVersion {
|
|
||||||
var dbRoomVersions map[types.RoomNID]gomatrixserverlib.RoomVersion
|
|
||||||
dbRoomVersions, err = d.RoomsTable.SelectRoomVersionsForRoomNIDs(ctx, txn, []types.RoomNID{roomNID})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if roomVersion, ok = dbRoomVersions[roomNID]; !ok {
|
|
||||||
return nil, fmt.Errorf("unable to find roomversion for room %d", roomNID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, eventJSON := range eventJSONs {
|
for _, eventJSON := range eventJSONs {
|
||||||
events[eventJSON.EventNID], err = gomatrixserverlib.NewEventFromTrustedJSONWithEventID(
|
events[eventJSON.EventNID], err = gomatrixserverlib.NewEventFromTrustedJSONWithEventID(
|
||||||
eventIDs[eventJSON.EventNID], eventJSON.EventJSON, false, roomVersion,
|
eventIDs[eventJSON.EventNID], eventJSON.EventJSON, false, roomInfo.RoomVersion,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -660,8 +648,8 @@ func (d *Database) IsEventRejected(ctx context.Context, roomNID types.RoomNID, e
|
||||||
return d.EventsTable.SelectEventRejected(ctx, nil, roomNID, eventID)
|
return d.EventsTable.SelectEventRejected(ctx, nil, roomNID, eventID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetOrCreateRoomNID gets or creates a new roomNID for the given event
|
// GetOrCreateRoomInfo gets or creates a new RoomInfo, which is only safe to use with functions only needing a roomVersion or roomNID.
|
||||||
func (d *Database) GetOrCreateRoomNID(ctx context.Context, event *gomatrixserverlib.Event) (roomNID types.RoomNID, err error) {
|
func (d *Database) GetOrCreateRoomInfo(ctx context.Context, event *gomatrixserverlib.Event) (roomInfo *types.RoomInfo, err error) {
|
||||||
// Get the default room version. If the client doesn't supply a room_version
|
// Get the default room version. If the client doesn't supply a room_version
|
||||||
// then we will use our configured default to create the room.
|
// then we will use our configured default to create the room.
|
||||||
// https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-createroom
|
// https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-createroom
|
||||||
|
@ -670,8 +658,9 @@ func (d *Database) GetOrCreateRoomNID(ctx context.Context, event *gomatrixserver
|
||||||
// room.
|
// room.
|
||||||
var roomVersion gomatrixserverlib.RoomVersion
|
var roomVersion gomatrixserverlib.RoomVersion
|
||||||
if roomVersion, err = extractRoomVersionFromCreateEvent(event); err != nil {
|
if roomVersion, err = extractRoomVersionFromCreateEvent(event); err != nil {
|
||||||
return 0, fmt.Errorf("extractRoomVersionFromCreateEvent: %w", err)
|
return nil, fmt.Errorf("extractRoomVersionFromCreateEvent: %w", err)
|
||||||
}
|
}
|
||||||
|
var roomNID types.RoomNID
|
||||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
roomNID, err = d.assignRoomNID(ctx, txn, event.RoomID(), roomVersion)
|
roomNID, err = d.assignRoomNID(ctx, txn, event.RoomID(), roomVersion)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -679,7 +668,10 @@ func (d *Database) GetOrCreateRoomNID(ctx context.Context, event *gomatrixserver
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
return roomNID, err
|
return &types.RoomInfo{
|
||||||
|
RoomVersion: roomVersion,
|
||||||
|
RoomNID: roomNID,
|
||||||
|
}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error) {
|
func (d *Database) GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error) {
|
||||||
|
@ -710,25 +702,22 @@ func (d *Database) GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKe
|
||||||
return eventStateKeyNID, nil
|
return eventStateKeyNID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) StoreEvent(
|
func (d *EventDatabase) StoreEvent(
|
||||||
ctx context.Context, event *gomatrixserverlib.Event,
|
ctx context.Context, event *gomatrixserverlib.Event,
|
||||||
roomNID types.RoomNID, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID,
|
roomInfo *types.RoomInfo, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID,
|
||||||
authEventNIDs []types.EventNID, isRejected bool,
|
authEventNIDs []types.EventNID, isRejected bool,
|
||||||
) (types.EventNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) {
|
) (types.EventNID, types.StateAtEvent, error) {
|
||||||
var (
|
var (
|
||||||
eventNID types.EventNID
|
eventNID types.EventNID
|
||||||
stateNID types.StateSnapshotNID
|
stateNID types.StateSnapshotNID
|
||||||
redactionEvent *gomatrixserverlib.Event
|
err error
|
||||||
redactedEventID string
|
|
||||||
err error
|
|
||||||
)
|
)
|
||||||
// Second writer is using the database-provided transaction, probably from the
|
|
||||||
// room updater, for easy roll-back if required.
|
|
||||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
if eventNID, stateNID, err = d.EventsTable.InsertEvent(
|
if eventNID, stateNID, err = d.EventsTable.InsertEvent(
|
||||||
ctx,
|
ctx,
|
||||||
txn,
|
txn,
|
||||||
roomNID,
|
roomInfo.RoomNID,
|
||||||
eventTypeNID,
|
eventTypeNID,
|
||||||
eventStateKeyNID,
|
eventStateKeyNID,
|
||||||
event.EventID(),
|
event.EventID(),
|
||||||
|
@ -751,16 +740,26 @@ func (d *Database) StoreEvent(
|
||||||
if err = d.EventJSONTable.InsertEventJSON(ctx, txn, eventNID, event.JSON()); err != nil {
|
if err = d.EventJSONTable.InsertEventJSON(ctx, txn, eventNID, event.JSON()); err != nil {
|
||||||
return fmt.Errorf("d.EventJSONTable.InsertEventJSON: %w", err)
|
return fmt.Errorf("d.EventJSONTable.InsertEventJSON: %w", err)
|
||||||
}
|
}
|
||||||
if !isRejected { // ignore rejected redaction events
|
|
||||||
redactionEvent, redactedEventID, err = d.handleRedactions(ctx, txn, roomNID, eventNID, event)
|
if prevEvents := event.PrevEvents(); len(prevEvents) > 0 {
|
||||||
if err != nil {
|
// Create an updater - NB: on sqlite this WILL create a txn as we are directly calling the shared DB form of
|
||||||
return fmt.Errorf("d.handleRedactions: %w", err)
|
// GetLatestEventsForUpdate - not via the SQLiteDatabase form which has `nil` txns. This
|
||||||
|
// function only does SELECTs though so the created txn (at this point) is just a read txn like
|
||||||
|
// any other so this is fine. If we ever update GetLatestEventsForUpdate or NewLatestEventsUpdater
|
||||||
|
// to do writes however then this will need to go inside `Writer.Do`.
|
||||||
|
|
||||||
|
// The following is a copy of RoomUpdater.StorePreviousEvents
|
||||||
|
for _, ref := range prevEvents {
|
||||||
|
if err = d.PrevEventsTable.InsertPreviousEvent(ctx, txn, ref.EventID, ref.EventSHA256, eventNID); err != nil {
|
||||||
|
return fmt.Errorf("u.d.PrevEventsTable.InsertPreviousEvent: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.Writer.Do: %w", err)
|
return 0, types.StateAtEvent{}, fmt.Errorf("d.Writer.Do: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// We should attempt to update the previous events table with any
|
// We should attempt to update the previous events table with any
|
||||||
|
@ -768,33 +767,6 @@ func (d *Database) StoreEvent(
|
||||||
// events updater because it somewhat works as a mutex, ensuring
|
// events updater because it somewhat works as a mutex, ensuring
|
||||||
// that there's a row-level lock on the latest room events (well,
|
// that there's a row-level lock on the latest room events (well,
|
||||||
// on Postgres at least).
|
// on Postgres at least).
|
||||||
if prevEvents := event.PrevEvents(); len(prevEvents) > 0 {
|
|
||||||
// Create an updater - NB: on sqlite this WILL create a txn as we are directly calling the shared DB form of
|
|
||||||
// GetLatestEventsForUpdate - not via the SQLiteDatabase form which has `nil` txns. This
|
|
||||||
// function only does SELECTs though so the created txn (at this point) is just a read txn like
|
|
||||||
// any other so this is fine. If we ever update GetLatestEventsForUpdate or NewLatestEventsUpdater
|
|
||||||
// to do writes however then this will need to go inside `Writer.Do`.
|
|
||||||
succeeded := false
|
|
||||||
var roomInfo *types.RoomInfo
|
|
||||||
roomInfo, err = d.roomInfo(ctx, nil, event.RoomID())
|
|
||||||
if err != nil {
|
|
||||||
return 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.RoomInfo: %w", err)
|
|
||||||
}
|
|
||||||
if roomInfo == nil && len(prevEvents) > 0 {
|
|
||||||
return 0, types.StateAtEvent{}, nil, "", fmt.Errorf("expected room %q to exist", event.RoomID())
|
|
||||||
}
|
|
||||||
var updater *RoomUpdater
|
|
||||||
updater, err = d.GetRoomUpdater(ctx, roomInfo)
|
|
||||||
if err != nil {
|
|
||||||
return 0, types.StateAtEvent{}, nil, "", fmt.Errorf("GetRoomUpdater: %w", err)
|
|
||||||
}
|
|
||||||
defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err)
|
|
||||||
|
|
||||||
if err = updater.StorePreviousEvents(eventNID, prevEvents); err != nil {
|
|
||||||
return 0, types.StateAtEvent{}, nil, "", fmt.Errorf("updater.StorePreviousEvents: %w", err)
|
|
||||||
}
|
|
||||||
succeeded = true
|
|
||||||
}
|
|
||||||
|
|
||||||
return eventNID, types.StateAtEvent{
|
return eventNID, types.StateAtEvent{
|
||||||
BeforeStateSnapshotNID: stateNID,
|
BeforeStateSnapshotNID: stateNID,
|
||||||
|
@ -805,7 +777,7 @@ func (d *Database) StoreEvent(
|
||||||
},
|
},
|
||||||
EventNID: eventNID,
|
EventNID: eventNID,
|
||||||
},
|
},
|
||||||
}, redactionEvent, redactedEventID, err
|
}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) PublishRoom(ctx context.Context, roomID, appserviceID, networkID string, publish bool) error {
|
func (d *Database) PublishRoom(ctx context.Context, roomID, appserviceID, networkID string, publish bool) error {
|
||||||
|
@ -893,7 +865,7 @@ func (d *Database) assignEventTypeNID(
|
||||||
return eventTypeNID, nil
|
return eventTypeNID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) assignStateKeyNID(
|
func (d *EventDatabase) assignStateKeyNID(
|
||||||
ctx context.Context, txn *sql.Tx, eventStateKey string,
|
ctx context.Context, txn *sql.Tx, eventStateKey string,
|
||||||
) (types.EventStateKeyNID, error) {
|
) (types.EventStateKeyNID, error) {
|
||||||
eventStateKeyNID, ok := d.Cache.GetEventStateKeyNID(eventStateKey)
|
eventStateKeyNID, ok := d.Cache.GetEventStateKeyNID(eventStateKey)
|
||||||
|
@ -937,7 +909,7 @@ func extractRoomVersionFromCreateEvent(event *gomatrixserverlib.Event) (
|
||||||
return roomVersion, err
|
return roomVersion, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleRedactions manages the redacted status of events. There's two cases to consider in order to comply with the spec:
|
// MaybeRedactEvent manages the redacted status of events. There's two cases to consider in order to comply with the spec:
|
||||||
// "servers should not apply or send redactions to clients until both the redaction event and original event have been seen, and are valid."
|
// "servers should not apply or send redactions to clients until both the redaction event and original event have been seen, and are valid."
|
||||||
// https://matrix.org/docs/spec/rooms/v3#authorization-rules-for-events
|
// https://matrix.org/docs/spec/rooms/v3#authorization-rules-for-events
|
||||||
// These cases are:
|
// These cases are:
|
||||||
|
@ -952,95 +924,95 @@ func extractRoomVersionFromCreateEvent(event *gomatrixserverlib.Event) (
|
||||||
// when loading events to determine whether to apply redactions. This keeps the hot-path of reading events quick as we don't need
|
// when loading events to determine whether to apply redactions. This keeps the hot-path of reading events quick as we don't need
|
||||||
// to cross-reference with other tables when loading.
|
// to cross-reference with other tables when loading.
|
||||||
//
|
//
|
||||||
// Returns the redaction event and the event ID of the redacted event if this call resulted in a redaction.
|
// Returns the redaction event and the redacted event if this call resulted in a redaction.
|
||||||
func (d *Database) handleRedactions(
|
func (d *EventDatabase) MaybeRedactEvent(
|
||||||
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNID types.EventNID, event *gomatrixserverlib.Event,
|
ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event *gomatrixserverlib.Event, redactAllowed bool,
|
||||||
) (*gomatrixserverlib.Event, string, error) {
|
) (*gomatrixserverlib.Event, *gomatrixserverlib.Event, error) {
|
||||||
var err error
|
var (
|
||||||
isRedactionEvent := event.Type() == gomatrixserverlib.MRoomRedaction && event.StateKey() == nil
|
redactionEvent, redactedEvent *types.Event
|
||||||
if isRedactionEvent {
|
err error
|
||||||
// an event which redacts itself should be ignored
|
validated bool
|
||||||
if event.EventID() == event.Redacts() {
|
ignoreRedaction bool
|
||||||
return nil, "", nil
|
)
|
||||||
|
|
||||||
|
wErr := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
isRedactionEvent := event.Type() == gomatrixserverlib.MRoomRedaction && event.StateKey() == nil
|
||||||
|
if isRedactionEvent {
|
||||||
|
// an event which redacts itself should be ignored
|
||||||
|
if event.EventID() == event.Redacts() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err = d.RedactionsTable.InsertRedaction(ctx, txn, tables.RedactionInfo{
|
||||||
|
Validated: false,
|
||||||
|
RedactionEventID: event.EventID(),
|
||||||
|
RedactsEventID: event.Redacts(),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("d.RedactionsTable.InsertRedaction: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = d.RedactionsTable.InsertRedaction(ctx, txn, tables.RedactionInfo{
|
redactionEvent, redactedEvent, validated, err = d.loadRedactionPair(ctx, txn, roomInfo, eventNID, event)
|
||||||
Validated: false,
|
switch {
|
||||||
RedactionEventID: event.EventID(),
|
case err != nil:
|
||||||
RedactsEventID: event.Redacts(),
|
return fmt.Errorf("d.loadRedactionPair: %w", err)
|
||||||
})
|
case validated || redactedEvent == nil || redactionEvent == nil:
|
||||||
|
// we've seen this redaction before or there is nothing to redact
|
||||||
|
return nil
|
||||||
|
case redactedEvent.RoomID() != redactionEvent.RoomID():
|
||||||
|
// redactions across rooms aren't allowed
|
||||||
|
ignoreRedaction = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 1. The power level of the redaction event’s sender is greater than or equal to the redact level. (redactAllowed)
|
||||||
|
// 2. The domain of the redaction event’s sender matches that of the original event’s sender.
|
||||||
|
_, sender1, _ := gomatrixserverlib.SplitID('@', redactedEvent.Sender())
|
||||||
|
_, sender2, _ := gomatrixserverlib.SplitID('@', redactionEvent.Sender())
|
||||||
|
if !redactAllowed || sender1 != sender2 {
|
||||||
|
ignoreRedaction = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// mark the event as redacted
|
||||||
|
if redactionsArePermanent {
|
||||||
|
redactedEvent.Redact()
|
||||||
|
}
|
||||||
|
|
||||||
|
err = redactedEvent.SetUnsignedField("redacted_because", redactionEvent)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", fmt.Errorf("d.RedactionsTable.InsertRedaction: %w", err)
|
return fmt.Errorf("redactedEvent.SetUnsignedField: %w", err)
|
||||||
|
}
|
||||||
|
// NOTSPEC: sytest relies on this unspecced field existing :(
|
||||||
|
err = redactedEvent.SetUnsignedField("redacted_by", redactionEvent.EventID())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("redactedEvent.SetUnsignedField: %w", err)
|
||||||
|
}
|
||||||
|
// overwrite the eventJSON table
|
||||||
|
err = d.EventJSONTable.InsertEventJSON(ctx, txn, redactedEvent.EventNID, redactedEvent.JSON())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("d.EventJSONTable.InsertEventJSON: %w", err)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
redactionEvent, redactedEvent, validated, err := d.loadRedactionPair(ctx, txn, roomNID, eventNID, event)
|
err = d.RedactionsTable.MarkRedactionValidated(ctx, txn, redactionEvent.EventID(), true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", fmt.Errorf("d.loadRedactionPair: %w", err)
|
return fmt.Errorf("d.RedactionsTable.MarkRedactionValidated: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if wErr != nil {
|
||||||
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
if validated || redactedEvent == nil || redactionEvent == nil {
|
if ignoreRedaction || redactionEvent == nil || redactedEvent == nil {
|
||||||
// we've seen this redaction before or there is nothing to redact
|
return nil, nil, nil
|
||||||
return nil, "", nil
|
|
||||||
}
|
}
|
||||||
if redactedEvent.RoomID() != redactionEvent.RoomID() {
|
return redactionEvent.Event, redactedEvent.Event, nil
|
||||||
// redactions across rooms aren't allowed
|
|
||||||
return nil, "", nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the power level from the database, so we can verify the user is allowed to redact the event
|
|
||||||
powerLevels, err := d.GetStateEvent(ctx, event.RoomID(), gomatrixserverlib.MRoomPowerLevels, "")
|
|
||||||
if err != nil {
|
|
||||||
return nil, "", fmt.Errorf("d.GetStateEvent: %w", err)
|
|
||||||
}
|
|
||||||
if powerLevels == nil {
|
|
||||||
return nil, "", fmt.Errorf("unable to fetch m.room.power_levels event from database for room %s", event.RoomID())
|
|
||||||
}
|
|
||||||
pl, err := powerLevels.PowerLevels()
|
|
||||||
if err != nil {
|
|
||||||
return nil, "", fmt.Errorf("unable to get powerlevels for room: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
redactUser := pl.UserLevel(redactionEvent.Sender())
|
|
||||||
switch {
|
|
||||||
case redactUser >= pl.Redact:
|
|
||||||
// The power level of the redaction event’s sender is greater than or equal to the redact level.
|
|
||||||
case redactedEvent.Sender() == redactionEvent.Sender():
|
|
||||||
// The domain of the redaction event’s sender matches that of the original event’s sender.
|
|
||||||
default:
|
|
||||||
return nil, "", nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// mark the event as redacted
|
|
||||||
if redactionsArePermanent {
|
|
||||||
redactedEvent.Redact()
|
|
||||||
}
|
|
||||||
|
|
||||||
err = redactedEvent.SetUnsignedField("redacted_because", redactionEvent)
|
|
||||||
if err != nil {
|
|
||||||
return nil, "", fmt.Errorf("redactedEvent.SetUnsignedField: %w", err)
|
|
||||||
}
|
|
||||||
// NOTSPEC: sytest relies on this unspecced field existing :(
|
|
||||||
err = redactedEvent.SetUnsignedField("redacted_by", redactionEvent.EventID())
|
|
||||||
if err != nil {
|
|
||||||
return nil, "", fmt.Errorf("redactedEvent.SetUnsignedField: %w", err)
|
|
||||||
}
|
|
||||||
// overwrite the eventJSON table
|
|
||||||
err = d.EventJSONTable.InsertEventJSON(ctx, txn, redactedEvent.EventNID, redactedEvent.JSON())
|
|
||||||
if err != nil {
|
|
||||||
return nil, "", fmt.Errorf("d.EventJSONTable.InsertEventJSON: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = d.RedactionsTable.MarkRedactionValidated(ctx, txn, redactionEvent.EventID(), true)
|
|
||||||
if err != nil {
|
|
||||||
err = fmt.Errorf("d.RedactionsTable.MarkRedactionValidated: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return redactionEvent.Event, redactedEvent.EventID(), err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// loadRedactionPair returns both the redaction event and the redacted event, else nil.
|
// loadRedactionPair returns both the redaction event and the redacted event, else nil.
|
||||||
func (d *Database) loadRedactionPair(
|
func (d *EventDatabase) loadRedactionPair(
|
||||||
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNID types.EventNID, event *gomatrixserverlib.Event,
|
ctx context.Context, txn *sql.Tx, roomInfo *types.RoomInfo, eventNID types.EventNID, event *gomatrixserverlib.Event,
|
||||||
) (*types.Event, *types.Event, bool, error) {
|
) (*types.Event, *types.Event, bool, error) {
|
||||||
var redactionEvent, redactedEvent *types.Event
|
var redactionEvent, redactedEvent *types.Event
|
||||||
var info *tables.RedactionInfo
|
var info *tables.RedactionInfo
|
||||||
|
@ -1072,16 +1044,16 @@ func (d *Database) loadRedactionPair(
|
||||||
}
|
}
|
||||||
|
|
||||||
if isRedactionEvent {
|
if isRedactionEvent {
|
||||||
redactedEvent = d.loadEvent(ctx, roomNID, info.RedactsEventID)
|
redactedEvent = d.loadEvent(ctx, roomInfo, info.RedactsEventID)
|
||||||
} else {
|
} else {
|
||||||
redactionEvent = d.loadEvent(ctx, roomNID, info.RedactionEventID)
|
redactionEvent = d.loadEvent(ctx, roomInfo, info.RedactionEventID)
|
||||||
}
|
}
|
||||||
|
|
||||||
return redactionEvent, redactedEvent, info.Validated, nil
|
return redactionEvent, redactedEvent, info.Validated, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// applyRedactions will redact events that have an `unsigned.redacted_because` field.
|
// applyRedactions will redact events that have an `unsigned.redacted_because` field.
|
||||||
func (d *Database) applyRedactions(events []types.Event) {
|
func (d *EventDatabase) applyRedactions(events []types.Event) {
|
||||||
for i := range events {
|
for i := range events {
|
||||||
if result := gjson.GetBytes(events[i].Unsigned(), "redacted_because"); result.Exists() {
|
if result := gjson.GetBytes(events[i].Unsigned(), "redacted_because"); result.Exists() {
|
||||||
events[i].Redact()
|
events[i].Redact()
|
||||||
|
@ -1090,7 +1062,7 @@ func (d *Database) applyRedactions(events []types.Event) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// loadEvent loads a single event or returns nil on any problems/missing event
|
// loadEvent loads a single event or returns nil on any problems/missing event
|
||||||
func (d *Database) loadEvent(ctx context.Context, roomNID types.RoomNID, eventID string) *types.Event {
|
func (d *EventDatabase) loadEvent(ctx context.Context, roomInfo *types.RoomInfo, eventID string) *types.Event {
|
||||||
nids, err := d.EventNIDs(ctx, []string{eventID})
|
nids, err := d.EventNIDs(ctx, []string{eventID})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil
|
||||||
|
@ -1098,7 +1070,7 @@ func (d *Database) loadEvent(ctx context.Context, roomNID types.RoomNID, eventID
|
||||||
if len(nids) == 0 {
|
if len(nids) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
evs, err := d.Events(ctx, roomNID, []types.EventNID{nids[eventID].EventNID})
|
evs, err := d.Events(ctx, roomInfo, []types.EventNID{nids[eventID].EventNID})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -1144,7 +1116,7 @@ func (d *Database) GetHistoryVisibilityState(ctx context.Context, roomInfo *type
|
||||||
// If no event could be found, returns nil
|
// If no event could be found, returns nil
|
||||||
// If there was an issue during the retrieval, returns an error
|
// If there was an issue during the retrieval, returns an error
|
||||||
func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) {
|
func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) {
|
||||||
roomInfo, err := d.RoomInfo(ctx, roomID)
|
roomInfo, err := d.roomInfo(ctx, nil, roomID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -1209,7 +1181,7 @@ func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey s
|
||||||
// Same as GetStateEvent but returns all matching state events with this event type. Returns no error
|
// Same as GetStateEvent but returns all matching state events with this event type. Returns no error
|
||||||
// if there are no events with this event type.
|
// if there are no events with this event type.
|
||||||
func (d *Database) GetStateEventsWithEventType(ctx context.Context, roomID, evType string) ([]*gomatrixserverlib.HeaderedEvent, error) {
|
func (d *Database) GetStateEventsWithEventType(ctx context.Context, roomID, evType string) ([]*gomatrixserverlib.HeaderedEvent, error) {
|
||||||
roomInfo, err := d.RoomInfo(ctx, roomID)
|
roomInfo, err := d.roomInfo(ctx, nil, roomID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -1340,7 +1312,7 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu
|
||||||
eventNIDToVer := make(map[types.EventNID]gomatrixserverlib.RoomVersion)
|
eventNIDToVer := make(map[types.EventNID]gomatrixserverlib.RoomVersion)
|
||||||
// TODO: This feels like this is going to be really slow...
|
// TODO: This feels like this is going to be really slow...
|
||||||
for _, roomID := range roomIDs {
|
for _, roomID := range roomIDs {
|
||||||
roomInfo, err2 := d.RoomInfo(ctx, roomID)
|
roomInfo, err2 := d.roomInfo(ctx, nil, roomID)
|
||||||
if err2 != nil {
|
if err2 != nil {
|
||||||
return nil, fmt.Errorf("GetBulkStateContent: failed to load room info for room %s : %w", roomID, err2)
|
return nil, fmt.Errorf("GetBulkStateContent: failed to load room info for room %s : %w", roomID, err2)
|
||||||
}
|
}
|
||||||
|
|
|
@ -52,12 +52,14 @@ func mustCreateRoomserverDatabase(t *testing.T, dbType test.DBType) (*shared.Dat
|
||||||
|
|
||||||
cache := caching.NewRistrettoCache(8*1024*1024, time.Hour, false)
|
cache := caching.NewRistrettoCache(8*1024*1024, time.Hour, false)
|
||||||
|
|
||||||
|
evDb := shared.EventDatabase{EventStateKeysTable: stateKeyTable, Cache: cache}
|
||||||
|
|
||||||
return &shared.Database{
|
return &shared.Database{
|
||||||
DB: db,
|
DB: db,
|
||||||
EventStateKeysTable: stateKeyTable,
|
EventDatabase: evDb,
|
||||||
MembershipTable: membershipTable,
|
MembershipTable: membershipTable,
|
||||||
Writer: sqlutil.NewExclusiveWriter(),
|
Writer: sqlutil.NewExclusiveWriter(),
|
||||||
Cache: cache,
|
Cache: cache,
|
||||||
}, func() {
|
}, func() {
|
||||||
err := base.Close()
|
err := base.Close()
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
|
@ -203,24 +203,29 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room
|
||||||
}
|
}
|
||||||
|
|
||||||
d.Database = shared.Database{
|
d.Database = shared.Database{
|
||||||
DB: db,
|
DB: db,
|
||||||
Cache: cache,
|
EventDatabase: shared.EventDatabase{
|
||||||
Writer: writer,
|
DB: db,
|
||||||
EventsTable: events,
|
Cache: cache,
|
||||||
EventTypesTable: eventTypes,
|
Writer: writer,
|
||||||
EventStateKeysTable: eventStateKeys,
|
EventsTable: events,
|
||||||
EventJSONTable: eventJSON,
|
EventTypesTable: eventTypes,
|
||||||
RoomsTable: rooms,
|
EventStateKeysTable: eventStateKeys,
|
||||||
StateBlockTable: stateBlock,
|
EventJSONTable: eventJSON,
|
||||||
StateSnapshotTable: stateSnapshot,
|
PrevEventsTable: prevEvents,
|
||||||
PrevEventsTable: prevEvents,
|
RedactionsTable: redactions,
|
||||||
RoomAliasesTable: roomAliases,
|
},
|
||||||
InvitesTable: invites,
|
Cache: cache,
|
||||||
MembershipTable: membership,
|
Writer: writer,
|
||||||
PublishedTable: published,
|
RoomsTable: rooms,
|
||||||
RedactionsTable: redactions,
|
StateBlockTable: stateBlock,
|
||||||
GetRoomUpdaterFn: d.GetRoomUpdater,
|
StateSnapshotTable: stateSnapshot,
|
||||||
Purge: purge,
|
RoomAliasesTable: roomAliases,
|
||||||
|
InvitesTable: invites,
|
||||||
|
MembershipTable: membership,
|
||||||
|
PublishedTable: published,
|
||||||
|
GetRoomUpdaterFn: d.GetRoomUpdater,
|
||||||
|
Purge: purge,
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,9 +20,11 @@ import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"embed"
|
"embed"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"html/template"
|
"html/template"
|
||||||
"io"
|
"io"
|
||||||
|
"io/fs"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
_ "net/http/pprof"
|
_ "net/http/pprof"
|
||||||
|
@ -85,8 +87,6 @@ type BaseDendrite struct {
|
||||||
startupLock sync.Mutex
|
startupLock sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
const NoListener = ""
|
|
||||||
|
|
||||||
const HTTPServerTimeout = time.Minute * 5
|
const HTTPServerTimeout = time.Minute * 5
|
||||||
|
|
||||||
type BaseDendriteOptions int
|
type BaseDendriteOptions int
|
||||||
|
@ -345,18 +345,17 @@ func (b *BaseDendrite) ConfigureAdminEndpoints() {
|
||||||
// SetupAndServeHTTP sets up the HTTP server to serve client & federation APIs
|
// SetupAndServeHTTP sets up the HTTP server to serve client & federation APIs
|
||||||
// and adds a prometheus handler under /_dendrite/metrics.
|
// and adds a prometheus handler under /_dendrite/metrics.
|
||||||
func (b *BaseDendrite) SetupAndServeHTTP(
|
func (b *BaseDendrite) SetupAndServeHTTP(
|
||||||
externalHTTPAddr config.HTTPAddress,
|
externalHTTPAddr config.ServerAddress,
|
||||||
certFile, keyFile *string,
|
certFile, keyFile *string,
|
||||||
) {
|
) {
|
||||||
// Manually unlocked right before actually serving requests,
|
// Manually unlocked right before actually serving requests,
|
||||||
// as we don't return from this method (defer doesn't work).
|
// as we don't return from this method (defer doesn't work).
|
||||||
b.startupLock.Lock()
|
b.startupLock.Lock()
|
||||||
externalAddr, _ := externalHTTPAddr.Address()
|
|
||||||
|
|
||||||
externalRouter := mux.NewRouter().SkipClean(true).UseEncodedPath()
|
externalRouter := mux.NewRouter().SkipClean(true).UseEncodedPath()
|
||||||
|
|
||||||
externalServ := &http.Server{
|
externalServ := &http.Server{
|
||||||
Addr: string(externalAddr),
|
Addr: externalHTTPAddr.Address,
|
||||||
WriteTimeout: HTTPServerTimeout,
|
WriteTimeout: HTTPServerTimeout,
|
||||||
Handler: externalRouter,
|
Handler: externalRouter,
|
||||||
BaseContext: func(_ net.Listener) context.Context {
|
BaseContext: func(_ net.Listener) context.Context {
|
||||||
|
@ -419,7 +418,7 @@ func (b *BaseDendrite) SetupAndServeHTTP(
|
||||||
|
|
||||||
b.startupLock.Unlock()
|
b.startupLock.Unlock()
|
||||||
|
|
||||||
if externalAddr != NoListener {
|
if externalHTTPAddr.Enabled() {
|
||||||
go func() {
|
go func() {
|
||||||
var externalShutdown atomic.Bool // RegisterOnShutdown can be called more than once
|
var externalShutdown atomic.Bool // RegisterOnShutdown can be called more than once
|
||||||
logrus.Infof("Starting external listener on %s", externalServ.Addr)
|
logrus.Infof("Starting external listener on %s", externalServ.Addr)
|
||||||
|
@ -437,9 +436,30 @@ func (b *BaseDendrite) SetupAndServeHTTP(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if err := externalServ.ListenAndServe(); err != nil {
|
if externalHTTPAddr.IsUnixSocket() {
|
||||||
if err != http.ErrServerClosed {
|
err := os.Remove(externalHTTPAddr.Address)
|
||||||
logrus.WithError(err).Fatal("failed to serve HTTP")
|
if err != nil && !errors.Is(err, fs.ErrNotExist) {
|
||||||
|
logrus.WithError(err).Fatal("failed to remove existing unix socket")
|
||||||
|
}
|
||||||
|
listener, err := net.Listen(externalHTTPAddr.Network(), externalHTTPAddr.Address)
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithError(err).Fatal("failed to serve unix socket")
|
||||||
|
}
|
||||||
|
err = os.Chmod(externalHTTPAddr.Address, externalHTTPAddr.UnixSocketPermission)
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithError(err).Fatal("failed to set unix socket permissions")
|
||||||
|
}
|
||||||
|
if err := externalServ.Serve(listener); err != nil {
|
||||||
|
if err != http.ErrServerClosed {
|
||||||
|
logrus.WithError(err).Fatal("failed to serve unix socket")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} else {
|
||||||
|
if err := externalServ.ListenAndServe(); err != nil {
|
||||||
|
if err != http.ErrServerClosed {
|
||||||
|
logrus.WithError(err).Fatal("failed to serve HTTP")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,10 +2,13 @@ package base_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"embed"
|
"embed"
|
||||||
"html/template"
|
"html/template"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"path"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -18,7 +21,7 @@ import (
|
||||||
//go:embed static/*.gotmpl
|
//go:embed static/*.gotmpl
|
||||||
var staticContent embed.FS
|
var staticContent embed.FS
|
||||||
|
|
||||||
func TestLandingPage(t *testing.T) {
|
func TestLandingPage_Tcp(t *testing.T) {
|
||||||
// generate the expected result
|
// generate the expected result
|
||||||
tmpl := template.Must(template.ParseFS(staticContent, "static/*.gotmpl"))
|
tmpl := template.Must(template.ParseFS(staticContent, "static/*.gotmpl"))
|
||||||
expectedRes := &bytes.Buffer{}
|
expectedRes := &bytes.Buffer{}
|
||||||
|
@ -35,7 +38,9 @@ func TestLandingPage(t *testing.T) {
|
||||||
s.Close()
|
s.Close()
|
||||||
|
|
||||||
// start base with the listener and wait for it to be started
|
// start base with the listener and wait for it to be started
|
||||||
go b.SetupAndServeHTTP(config.HTTPAddress(s.URL), nil, nil)
|
address, err := config.HTTPAddress(s.URL)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
go b.SetupAndServeHTTP(address, nil, nil)
|
||||||
time.Sleep(time.Millisecond * 10)
|
time.Sleep(time.Millisecond * 10)
|
||||||
|
|
||||||
// When hitting /, we should be redirected to /_matrix/static, which should contain the landing page
|
// When hitting /, we should be redirected to /_matrix/static, which should contain the landing page
|
||||||
|
@ -55,3 +60,43 @@ func TestLandingPage(t *testing.T) {
|
||||||
// Using .String() for user friendly output
|
// Using .String() for user friendly output
|
||||||
assert.Equal(t, expectedRes.String(), buf.String(), "response mismatch")
|
assert.Equal(t, expectedRes.String(), buf.String(), "response mismatch")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestLandingPage_UnixSocket(t *testing.T) {
|
||||||
|
// generate the expected result
|
||||||
|
tmpl := template.Must(template.ParseFS(staticContent, "static/*.gotmpl"))
|
||||||
|
expectedRes := &bytes.Buffer{}
|
||||||
|
err := tmpl.ExecuteTemplate(expectedRes, "index.gotmpl", map[string]string{
|
||||||
|
"Version": internal.VersionString(),
|
||||||
|
})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
b, _, _ := testrig.Base(nil)
|
||||||
|
defer b.Close()
|
||||||
|
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
socket := path.Join(tempDir, "socket")
|
||||||
|
// start base with the listener and wait for it to be started
|
||||||
|
address := config.UnixSocketAddress(socket, 0755)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
go b.SetupAndServeHTTP(address, nil, nil)
|
||||||
|
time.Sleep(time.Millisecond * 100)
|
||||||
|
|
||||||
|
client := &http.Client{
|
||||||
|
Transport: &http.Transport{
|
||||||
|
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
|
||||||
|
return net.Dial("unix", socket)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
resp, err := client.Get("http://unix/")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
|
// read the response
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
_, err = buf.ReadFrom(resp.Body)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Using .String() for user friendly output
|
||||||
|
assert.Equal(t, expectedRes.String(), buf.String(), "response mismatch")
|
||||||
|
}
|
||||||
|
|
|
@ -19,7 +19,6 @@ import (
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/url"
|
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
@ -131,20 +130,6 @@ func (d DataSource) IsPostgres() bool {
|
||||||
// A Topic in kafka.
|
// A Topic in kafka.
|
||||||
type Topic string
|
type Topic string
|
||||||
|
|
||||||
// An Address to listen on.
|
|
||||||
type Address string
|
|
||||||
|
|
||||||
// An HTTPAddress to listen on, starting with either http:// or https://.
|
|
||||||
type HTTPAddress string
|
|
||||||
|
|
||||||
func (h HTTPAddress) Address() (Address, error) {
|
|
||||||
url, err := url.Parse(string(h))
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return Address(url.Host), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// FileSizeBytes is a file size in bytes
|
// FileSizeBytes is a file size in bytes
|
||||||
type FileSizeBytes int64
|
type FileSizeBytes int64
|
||||||
|
|
||||||
|
|
45
setup/config/config_address.go
Normal file
45
setup/config/config_address.go
Normal file
|
@ -0,0 +1,45 @@
|
||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io/fs"
|
||||||
|
"net/url"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
NetworkTCP = "tcp"
|
||||||
|
NetworkUnix = "unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ServerAddress struct {
|
||||||
|
Address string
|
||||||
|
Scheme string
|
||||||
|
UnixSocketPermission fs.FileMode
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s ServerAddress) Enabled() bool {
|
||||||
|
return s.Address != ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s ServerAddress) IsUnixSocket() bool {
|
||||||
|
return s.Scheme == NetworkUnix
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s ServerAddress) Network() string {
|
||||||
|
if s.Scheme == NetworkUnix {
|
||||||
|
return NetworkUnix
|
||||||
|
} else {
|
||||||
|
return NetworkTCP
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func UnixSocketAddress(path string, perm fs.FileMode) ServerAddress {
|
||||||
|
return ServerAddress{Address: path, Scheme: NetworkUnix, UnixSocketPermission: perm}
|
||||||
|
}
|
||||||
|
|
||||||
|
func HTTPAddress(urlAddress string) (ServerAddress, error) {
|
||||||
|
parsedUrl, err := url.Parse(urlAddress)
|
||||||
|
if err != nil {
|
||||||
|
return ServerAddress{}, err
|
||||||
|
}
|
||||||
|
return ServerAddress{parsedUrl.Host, parsedUrl.Scheme, 0}, nil
|
||||||
|
}
|
25
setup/config/config_address_test.go
Normal file
25
setup/config/config_address_test.go
Normal file
|
@ -0,0 +1,25 @@
|
||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io/fs"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestHttpAddress_ParseGood(t *testing.T) {
|
||||||
|
address, err := HTTPAddress("http://localhost:123")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "localhost:123", address.Address)
|
||||||
|
assert.Equal(t, "tcp", address.Network())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHttpAddress_ParseBad(t *testing.T) {
|
||||||
|
_, err := HTTPAddress(":")
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnixSocketAddress_Network(t *testing.T) {
|
||||||
|
address := UnixSocketAddress("/tmp", fs.FileMode(0755))
|
||||||
|
assert.Equal(t, "unix", address.Network())
|
||||||
|
}
|
|
@ -253,7 +253,7 @@ func (rc *reqCtx) process() (*MSC2836EventRelationshipsResponse, *util.JSONRespo
|
||||||
var res MSC2836EventRelationshipsResponse
|
var res MSC2836EventRelationshipsResponse
|
||||||
var returnEvents []*gomatrixserverlib.HeaderedEvent
|
var returnEvents []*gomatrixserverlib.HeaderedEvent
|
||||||
// Can the user see (according to history visibility) event_id? If no, reject the request, else continue.
|
// Can the user see (according to history visibility) event_id? If no, reject the request, else continue.
|
||||||
event := rc.getLocalEvent(rc.req.EventID)
|
event := rc.getLocalEvent(rc.req.RoomID, rc.req.EventID)
|
||||||
if event == nil {
|
if event == nil {
|
||||||
event = rc.fetchUnknownEvent(rc.req.EventID, rc.req.RoomID)
|
event = rc.fetchUnknownEvent(rc.req.EventID, rc.req.RoomID)
|
||||||
}
|
}
|
||||||
|
@ -592,7 +592,7 @@ func (rc *reqCtx) remoteEventRelationships(eventID string) *MSC2836EventRelation
|
||||||
// lookForEvent returns the event for the event ID given, by trying to query remote servers
|
// lookForEvent returns the event for the event ID given, by trying to query remote servers
|
||||||
// if the event ID is unknown via /event_relationships.
|
// if the event ID is unknown via /event_relationships.
|
||||||
func (rc *reqCtx) lookForEvent(eventID string) *gomatrixserverlib.HeaderedEvent {
|
func (rc *reqCtx) lookForEvent(eventID string) *gomatrixserverlib.HeaderedEvent {
|
||||||
event := rc.getLocalEvent(eventID)
|
event := rc.getLocalEvent(rc.req.RoomID, eventID)
|
||||||
if event == nil {
|
if event == nil {
|
||||||
queryRes := rc.remoteEventRelationships(eventID)
|
queryRes := rc.remoteEventRelationships(eventID)
|
||||||
if queryRes != nil {
|
if queryRes != nil {
|
||||||
|
@ -622,9 +622,10 @@ func (rc *reqCtx) lookForEvent(eventID string) *gomatrixserverlib.HeaderedEvent
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rc *reqCtx) getLocalEvent(eventID string) *gomatrixserverlib.HeaderedEvent {
|
func (rc *reqCtx) getLocalEvent(roomID, eventID string) *gomatrixserverlib.HeaderedEvent {
|
||||||
var queryEventsRes roomserver.QueryEventsByIDResponse
|
var queryEventsRes roomserver.QueryEventsByIDResponse
|
||||||
err := rc.rsAPI.QueryEventsByID(rc.ctx, &roomserver.QueryEventsByIDRequest{
|
err := rc.rsAPI.QueryEventsByID(rc.ctx, &roomserver.QueryEventsByIDRequest{
|
||||||
|
RoomID: roomID,
|
||||||
EventIDs: []string{eventID},
|
EventIDs: []string{eventID},
|
||||||
}, &queryEventsRes)
|
}, &queryEventsRes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -212,6 +212,7 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent(
|
||||||
// Finally, work out if there are any more events missing.
|
// Finally, work out if there are any more events missing.
|
||||||
if len(missingEventIDs) > 0 {
|
if len(missingEventIDs) > 0 {
|
||||||
eventsReq := &api.QueryEventsByIDRequest{
|
eventsReq := &api.QueryEventsByIDRequest{
|
||||||
|
RoomID: ev.RoomID(),
|
||||||
EventIDs: missingEventIDs,
|
EventIDs: missingEventIDs,
|
||||||
}
|
}
|
||||||
eventsRes := &api.QueryEventsByIDResponse{}
|
eventsRes := &api.QueryEventsByIDResponse{}
|
||||||
|
|
|
@ -109,7 +109,7 @@ func GetMemberships(
|
||||||
}
|
}
|
||||||
|
|
||||||
qryRes := &api.QueryEventsByIDResponse{}
|
qryRes := &api.QueryEventsByIDResponse{}
|
||||||
if err := rsAPI.QueryEventsByID(req.Context(), &api.QueryEventsByIDRequest{EventIDs: eventIDs}, qryRes); err != nil {
|
if err := rsAPI.QueryEventsByID(req.Context(), &api.QueryEventsByIDRequest{EventIDs: eventIDs, RoomID: roomID}, qryRes); err != nil {
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("rsAPI.QueryEventsByID failed")
|
util.GetLogger(req.Context()).WithError(err).Error("rsAPI.QueryEventsByID failed")
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
}
|
}
|
||||||
|
|
|
@ -187,8 +187,8 @@ func Test_UserStatistics(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("Users not active for one/two month", func(t *testing.T) {
|
t.Run("Users not active for one/two month", func(t *testing.T) {
|
||||||
mustUpdateDeviceLastSeen(t, ctx, db, "user1", time.Now().AddDate(0, -2, 0))
|
mustUpdateDeviceLastSeen(t, ctx, db, "user1", time.Now().AddDate(0, 0, -60))
|
||||||
mustUpdateDeviceLastSeen(t, ctx, db, "user2", time.Now().AddDate(0, -1, 0))
|
mustUpdateDeviceLastSeen(t, ctx, db, "user2", time.Now().AddDate(0, 0, -30))
|
||||||
gotStats, _, err := statsDB.UserStatistics(ctx, nil)
|
gotStats, _, err := statsDB.UserStatistics(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
@ -224,9 +224,9 @@ func Test_UserStatistics(t *testing.T) {
|
||||||
- Where account creation and last_seen are > 30 days apart
|
- Where account creation and last_seen are > 30 days apart
|
||||||
*/
|
*/
|
||||||
t.Run("R30Users tests", func(t *testing.T) {
|
t.Run("R30Users tests", func(t *testing.T) {
|
||||||
mustUserUpdateRegistered(t, ctx, db, "user1", time.Now().AddDate(0, -2, 0))
|
mustUserUpdateRegistered(t, ctx, db, "user1", time.Now().AddDate(0, 0, -60))
|
||||||
mustUpdateDeviceLastSeen(t, ctx, db, "user1", time.Now())
|
mustUpdateDeviceLastSeen(t, ctx, db, "user1", time.Now())
|
||||||
mustUserUpdateRegistered(t, ctx, db, "user4", time.Now().AddDate(0, -2, 0))
|
mustUserUpdateRegistered(t, ctx, db, "user4", time.Now().AddDate(0, 0, -60))
|
||||||
mustUpdateDeviceLastSeen(t, ctx, db, "user4", time.Now())
|
mustUpdateDeviceLastSeen(t, ctx, db, "user4", time.Now())
|
||||||
startTime := time.Now().AddDate(0, 0, -2)
|
startTime := time.Now().AddDate(0, 0, -2)
|
||||||
err := statsDB.UpdateUserDailyVisits(ctx, nil, startTime, startTime.Truncate(time.Hour*24))
|
err := statsDB.UpdateUserDailyVisits(ctx, nil, startTime, startTime.Truncate(time.Hour*24))
|
||||||
|
|
Loading…
Reference in a new issue