Merge branch 'main' into neilalexander/cfg

This commit is contained in:
Neil Alexander 2022-07-25 15:03:31 +01:00 committed by GitHub
commit 506206eee4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
111 changed files with 1654 additions and 1306 deletions

View file

@ -223,6 +223,31 @@ jobs:
- name: Test upgrade - name: Test upgrade
run: ./dendrite-upgrade-tests --head . run: ./dendrite-upgrade-tests --head .
# run database upgrade tests, skipping over one version
upgrade_test_direct:
name: Upgrade tests from HEAD-2
timeout-minutes: 20
needs: initial-tests-done
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Setup go
uses: actions/setup-go@v2
with:
go-version: "1.18"
- uses: actions/cache@v3
with:
path: |
~/.cache/go-build
~/go/pkg/mod
key: ${{ runner.os }}-go-upgrade-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go-upgrade
- name: Build upgrade-tests
run: go build ./cmd/dendrite-upgrade-tests
- name: Test upgrade
run: ./dendrite-upgrade-tests -direct -from HEAD-2 --head .
# run Sytest in different variations # run Sytest in different variations
sytest: sytest:
timeout-minutes: 20 timeout-minutes: 20
@ -359,7 +384,7 @@ jobs:
integration-tests-done: integration-tests-done:
name: Integration tests passed name: Integration tests passed
needs: [initial-tests-done, upgrade_test, sytest, complement] needs: [initial-tests-done, upgrade_test, upgrade_test_direct, sytest, complement]
runs-on: ubuntu-latest runs-on: ubuntu-latest
if: ${{ !cancelled() }} # Run this even if prior jobs were skipped if: ${{ !cancelled() }} # Run this even if prior jobs were skipped
steps: steps:

View file

@ -8,7 +8,6 @@ COPY . /build
RUN mkdir -p bin RUN mkdir -p bin
RUN go build -trimpath -o bin/ ./cmd/dendrite-monolith-server RUN go build -trimpath -o bin/ ./cmd/dendrite-monolith-server
RUN go build -trimpath -o bin/ ./cmd/goose
RUN go build -trimpath -o bin/ ./cmd/create-account RUN go build -trimpath -o bin/ ./cmd/create-account
RUN go build -trimpath -o bin/ ./cmd/generate-keys RUN go build -trimpath -o bin/ ./cmd/generate-keys

View file

@ -8,7 +8,6 @@ COPY . /build
RUN mkdir -p bin RUN mkdir -p bin
RUN go build -trimpath -o bin/ ./cmd/dendrite-polylith-multi RUN go build -trimpath -o bin/ ./cmd/dendrite-polylith-multi
RUN go build -trimpath -o bin/ ./cmd/goose
RUN go build -trimpath -o bin/ ./cmd/create-account RUN go build -trimpath -o bin/ ./cmd/create-account
RUN go build -trimpath -o bin/ ./cmd/generate-keys RUN go build -trimpath -o bin/ ./cmd/generate-keys

View file

@ -59,6 +59,7 @@ func AddPublicRoutes(
routing.Setup( routing.Setup(
base.PublicClientAPIMux, base.PublicClientAPIMux,
base.PublicWellKnownAPIMux,
base.SynapseAdminMux, base.SynapseAdminMux,
base.DendriteAdminMux, base.DendriteAdminMux,
cfg, rsAPI, asAPI, cfg, rsAPI, asAPI,

View file

@ -48,7 +48,7 @@ import (
// applied: // applied:
// nolint: gocyclo // nolint: gocyclo
func Setup( func Setup(
publicAPIMux, synapseAdminRouter, dendriteAdminRouter *mux.Router, publicAPIMux, wkMux, synapseAdminRouter, dendriteAdminRouter *mux.Router,
cfg *config.ClientAPI, cfg *config.ClientAPI,
rsAPI roomserverAPI.ClientRoomserverAPI, rsAPI roomserverAPI.ClientRoomserverAPI,
asAPI appserviceAPI.AppServiceInternalAPI, asAPI appserviceAPI.AppServiceInternalAPI,
@ -74,6 +74,26 @@ func Setup(
unstableFeatures["org.matrix."+msc] = true unstableFeatures["org.matrix."+msc] = true
} }
if cfg.Matrix.WellKnownClientName != "" {
logrus.Infof("Setting m.homeserver base_url as %s at /.well-known/matrix/client", cfg.Matrix.WellKnownClientName)
wkMux.Handle("/client", httputil.MakeExternalAPI("wellknown", func(r *http.Request) util.JSONResponse {
return util.JSONResponse{
Code: http.StatusOK,
JSON: struct {
HomeserverName struct {
BaseUrl string `json:"base_url"`
} `json:"m.homeserver"`
}{
HomeserverName: struct {
BaseUrl string `json:"base_url"`
}{
BaseUrl: cfg.Matrix.WellKnownClientName,
},
},
}
})).Methods(http.MethodGet, http.MethodOptions)
}
publicAPIMux.Handle("/versions", publicAPIMux.Handle("/versions",
httputil.MakeExternalAPI("versions", func(req *http.Request) util.JSONResponse { httputil.MakeExternalAPI("versions", func(req *http.Request) util.JSONResponse {
return util.JSONResponse{ return util.JSONResponse{

View file

@ -37,6 +37,7 @@ var (
flagBuildConcurrency = flag.Int("build-concurrency", runtime.NumCPU(), "The amount of build concurrency when building images") flagBuildConcurrency = flag.Int("build-concurrency", runtime.NumCPU(), "The amount of build concurrency when building images")
flagHead = flag.String("head", "", "Location to a dendrite repository to treat as HEAD instead of Github") flagHead = flag.String("head", "", "Location to a dendrite repository to treat as HEAD instead of Github")
flagDockerHost = flag.String("docker-host", "localhost", "The hostname of the docker client. 'localhost' if running locally, 'host.docker.internal' if running in Docker.") flagDockerHost = flag.String("docker-host", "localhost", "The hostname of the docker client. 'localhost' if running locally, 'host.docker.internal' if running in Docker.")
flagDirect = flag.Bool("direct", false, "If a direct upgrade from the defined FROM version to TO should be done")
alphaNumerics = regexp.MustCompile("[^a-zA-Z0-9]+") alphaNumerics = regexp.MustCompile("[^a-zA-Z0-9]+")
) )
@ -229,7 +230,7 @@ func getAndSortVersionsFromGithub(httpClient *http.Client) (semVers []*semver.Ve
return semVers, nil return semVers, nil
} }
func calculateVersions(cli *http.Client, from, to string) []string { func calculateVersions(cli *http.Client, from, to string, direct bool) []string {
semvers, err := getAndSortVersionsFromGithub(cli) semvers, err := getAndSortVersionsFromGithub(cli)
if err != nil { if err != nil {
log.Fatalf("failed to collect semvers from github: %s", err) log.Fatalf("failed to collect semvers from github: %s", err)
@ -284,6 +285,9 @@ func calculateVersions(cli *http.Client, from, to string) []string {
if to == HEAD { if to == HEAD {
versions = append(versions, HEAD) versions = append(versions, HEAD)
} }
if direct {
versions = []string{versions[0], versions[len(versions)-1]}
}
return versions return versions
} }
@ -461,7 +465,7 @@ func main() {
os.Exit(1) os.Exit(1)
} }
cleanup(dockerClient) cleanup(dockerClient)
versions := calculateVersions(httpClient, *flagFrom, *flagTo) versions := calculateVersions(httpClient, *flagFrom, *flagTo, *flagDirect)
log.Printf("Testing dendrite versions: %v\n", versions) log.Printf("Testing dendrite versions: %v\n", versions)
branchToImageID := buildDendriteImages(httpClient, dockerClient, *flagTempDir, *flagBuildConcurrency, versions) branchToImageID := buildDendriteImages(httpClient, dockerClient, *flagTempDir, *flagBuildConcurrency, versions)

View file

@ -1,109 +0,0 @@
## Database migrations
We use [goose](https://github.com/pressly/goose) to handle database migrations. This allows us to execute
both SQL deltas (e.g `ALTER TABLE ...`) as well as manipulate data in the database in Go using Go functions.
To run a migration, the `goose` binary in this directory needs to be built:
```
$ go build ./cmd/goose
```
This binary allows Dendrite databases to be upgraded and downgraded. Sample usage for upgrading the roomserver database:
```
# for sqlite
$ ./goose -dir roomserver/storage/sqlite3/deltas sqlite3 ./roomserver.db up
# for postgres
$ ./goose -dir roomserver/storage/postgres/deltas postgres "user=dendrite dbname=dendrite sslmode=disable" up
```
For a full list of options, including rollbacks, see https://github.com/pressly/goose or use `goose` with no args.
### Rationale
Dendrite creates tables on startup using `CREATE TABLE IF NOT EXISTS`, so you might think that we should also
apply version upgrades on startup as well. This is convenient and doesn't involve an additional binary to run
which complicates upgrades. However, combining the upgrade mechanism and the server binary makes it difficult
to handle rollbacks. Firstly, how do you specify you wish to rollback? We would have to add additional flags
to the main server binary to say "rollback to version X". Secondly, if you roll back the server binary from
version 5 to version 4, the version 4 binary doesn't know how to rollback the database from version 5 to
version 4! For these reasons, we prefer to have a separate "upgrade" binary which is run for database upgrades.
Rather than roll-our-own migration tool, we decided to use [goose](https://github.com/pressly/goose) as it supports
complex migrations in Go code in addition to just executing SQL deltas. Other alternatives like
`github.com/golang-migrate/migrate` [do not support](https://github.com/golang-migrate/migrate/issues/15) these
kinds of complex migrations.
### Adding new deltas
You can add `.sql` or `.go` files manually or you can use goose to create them for you.
If you only want to add a SQL delta then run:
```
$ ./goose -dir serverkeyapi/storage/sqlite3/deltas sqlite3 ./foo.db create new_col sql
2020/09/09 14:37:43 Created new file: serverkeyapi/storage/sqlite3/deltas/20200909143743_new_col.sql
```
In this case, the version number is `20200909143743`. The important thing is that it is always increasing.
Then add up/downgrade SQL commands to the created file which looks like:
```sql
-- +goose Up
-- +goose StatementBegin
SELECT 'up SQL query';
-- +goose StatementEnd
-- +goose Down
-- +goose StatementBegin
SELECT 'down SQL query';
-- +goose StatementEnd
```
You __must__ keep the `+goose` annotations. You'll need to repeat this process for Postgres.
For complex Go migrations:
```
$ ./goose -dir serverkeyapi/storage/sqlite3/deltas sqlite3 ./foo.db create complex_update go
2020/09/09 14:40:38 Created new file: serverkeyapi/storage/sqlite3/deltas/20200909144038_complex_update.go
```
Then modify the created `.go` file which looks like:
```go
package migrations
import (
"database/sql"
"fmt"
"github.com/pressly/goose"
)
func init() {
goose.AddMigration(upComplexUpdate, downComplexUpdate)
}
func upComplexUpdate(tx *sql.Tx) error {
// This code is executed when the migration is applied.
return nil
}
func downComplexUpdate(tx *sql.Tx) error {
// This code is executed when the migration is rolled back.
return nil
}
```
You __must__ import the package in `/cmd/goose/main.go` so `func init()` gets called.
#### Database limitations
- SQLite3 does NOT support `ALTER TABLE table_name DROP COLUMN` - you would have to rename the column or drop the table
entirely and recreate it. ([example](https://github.com/matrix-org/dendrite/blob/master/userapi/storage/accounts/sqlite3/deltas/20200929203058_is_active.sql))
More information: [sqlite.org](https://www.sqlite.org/lang_altertable.html)

View file

@ -1,154 +0,0 @@
// This is custom goose binary
package main
import (
"flag"
"fmt"
"log"
"os"
"github.com/pressly/goose"
pgusers "github.com/matrix-org/dendrite/userapi/storage/postgres/deltas"
slusers "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas"
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
)
const (
AppService = "appservice"
FederationSender = "federationapi"
KeyServer = "keyserver"
MediaAPI = "mediaapi"
RoomServer = "roomserver"
SigningKeyServer = "signingkeyserver"
SyncAPI = "syncapi"
UserAPI = "userapi"
)
var (
dir = flags.String("dir", "", "directory with migration files")
flags = flag.NewFlagSet("goose", flag.ExitOnError)
component = flags.String("component", "", "dendrite component name")
knownDBs = []string{
AppService, FederationSender, KeyServer, MediaAPI, RoomServer, SigningKeyServer, SyncAPI, UserAPI,
}
)
// nolint: gocyclo
func main() {
err := flags.Parse(os.Args[1:])
if err != nil {
panic(err.Error())
}
args := flags.Args()
if len(args) < 3 {
fmt.Println(
`Usage: goose [OPTIONS] DRIVER DBSTRING COMMAND
Drivers:
postgres
sqlite3
Examples:
goose -component roomserver sqlite3 ./roomserver.db status
goose -component roomserver sqlite3 ./roomserver.db up
goose -component roomserver postgres "user=dendrite dbname=dendrite sslmode=disable" status
Options:
-component string
Dendrite component name e.g roomserver, signingkeyserver, clientapi, syncapi
-table string
migrations table name (default "goose_db_version")
-h print help
-v enable verbose mode
-dir string
directory with migration files, only relevant when creating new migrations.
-version
print version
Commands:
up Migrate the DB to the most recent version available
up-by-one Migrate the DB up by 1
up-to VERSION Migrate the DB to a specific VERSION
down Roll back the version by 1
down-to VERSION Roll back to a specific VERSION
redo Re-run the latest migration
reset Roll back all migrations
status Dump the migration status for the current DB
version Print the current version of the database
create NAME [sql|go] Creates new migration file with the current timestamp
fix Apply sequential ordering to migrations`,
)
return
}
engine := args[0]
if engine != "sqlite3" && engine != "postgres" {
fmt.Println("engine must be one of 'sqlite3' or 'postgres'")
return
}
knownComponent := false
for _, c := range knownDBs {
if c == *component {
knownComponent = true
break
}
}
if !knownComponent {
fmt.Printf("component must be one of %v\n", knownDBs)
return
}
if engine == "sqlite3" {
loadSQLiteDeltas(*component)
} else {
loadPostgresDeltas(*component)
}
dbstring, command := args[1], args[2]
db, err := goose.OpenDBWithDriver(engine, dbstring)
if err != nil {
log.Fatalf("goose: failed to open DB: %v\n", err)
}
defer func() {
if err := db.Close(); err != nil {
log.Fatalf("goose: failed to close DB: %v\n", err)
}
}()
arguments := []string{}
if len(args) > 3 {
arguments = append(arguments, args[3:]...)
}
// goose demands a directory even though we don't use it for upgrades
d := *dir
if d == "" {
d = os.TempDir()
}
if err := goose.Run(command, db, d, arguments...); err != nil {
log.Fatalf("goose %v: %v", command, err)
}
}
func loadSQLiteDeltas(component string) {
switch component {
case UserAPI:
slusers.LoadFromGoose()
}
}
func loadPostgresDeltas(component string) {
switch component {
case UserAPI:
pgusers.LoadFromGoose()
}
}

View file

@ -64,6 +64,10 @@ global:
# e.g. localhost:443 # e.g. localhost:443
well_known_server_name: "" well_known_server_name: ""
# The server name to delegate client-server communications to, with optional port
# e.g. localhost:443
well_known_client_name: ""
# Lists of domains that the server will trust as identity servers to verify third # Lists of domains that the server will trust as identity servers to verify third
# party identifiers such as phone numbers and email addresses. # party identifiers such as phone numbers and email addresses.
trusted_third_party_id_servers: trusted_third_party_id_servers:

View file

@ -54,6 +54,10 @@ global:
# e.g. localhost:443 # e.g. localhost:443
well_known_server_name: "" well_known_server_name: ""
# The server name to delegate client-server communications to, with optional port
# e.g. localhost:443
well_known_client_name: ""
# Lists of domains that the server will trust as identity servers to verify third # Lists of domains that the server will trust as identity servers to verify third
# party identifiers such as phone numbers and email addresses. # party identifiers such as phone numbers and email addresses.
trusted_third_party_id_servers: trusted_third_party_id_servers:
@ -125,7 +129,7 @@ app_service_api:
# Database configuration for this component. # Database configuration for this component.
database: database:
connection_string: postgresql://username@password:hostname/dendrite_appservice?sslmode=disable connection_string: postgresql://username:password@hostname/dendrite_appservice?sslmode=disable
max_open_conns: 10 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
@ -199,7 +203,7 @@ federation_api:
external_api: external_api:
listen: http://[::]:8072 listen: http://[::]:8072
database: database:
connection_string: postgresql://username@password:hostname/dendrite_federationapi?sslmode=disable connection_string: postgresql://username:password@hostname/dendrite_federationapi?sslmode=disable
max_open_conns: 10 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
@ -236,7 +240,7 @@ key_server:
listen: http://[::]:7779 # The listen address for incoming API requests listen: http://[::]:7779 # The listen address for incoming API requests
connect: http://key_server:7779 # The connect address for other components to use connect: http://key_server:7779 # The connect address for other components to use
database: database:
connection_string: postgresql://username@password:hostname/dendrite_keyserver?sslmode=disable connection_string: postgresql://username:password@hostname/dendrite_keyserver?sslmode=disable
max_open_conns: 10 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
@ -249,7 +253,7 @@ media_api:
external_api: external_api:
listen: http://[::]:8074 listen: http://[::]:8074
database: database:
connection_string: postgresql://username@password:hostname/dendrite_mediaapi?sslmode=disable connection_string: postgresql://username:password@hostname/dendrite_mediaapi?sslmode=disable
max_open_conns: 5 max_open_conns: 5
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
@ -286,7 +290,7 @@ mscs:
# - msc2836 # (Threading, see https://github.com/matrix-org/matrix-doc/pull/2836) # - msc2836 # (Threading, see https://github.com/matrix-org/matrix-doc/pull/2836)
# - msc2946 # (Spaces Summary, see https://github.com/matrix-org/matrix-doc/pull/2946) # - msc2946 # (Spaces Summary, see https://github.com/matrix-org/matrix-doc/pull/2946)
database: database:
connection_string: postgresql://username@password:hostname/dendrite_mscs?sslmode=disable connection_string: postgresql://username:password@hostname/dendrite_mscs?sslmode=disable
max_open_conns: 5 max_open_conns: 5
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
@ -297,7 +301,7 @@ room_server:
listen: http://[::]:7770 # The listen address for incoming API requests listen: http://[::]:7770 # The listen address for incoming API requests
connect: http://room_server:7770 # The connect address for other components to use connect: http://room_server:7770 # The connect address for other components to use
database: database:
connection_string: postgresql://username@password:hostname/dendrite_roomserver?sslmode=disable connection_string: postgresql://username:password@hostname/dendrite_roomserver?sslmode=disable
max_open_conns: 10 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
@ -310,7 +314,7 @@ sync_api:
external_api: external_api:
listen: http://[::]:8073 listen: http://[::]:8073
database: database:
connection_string: postgresql://username@password:hostname/dendrite_syncapi?sslmode=disable connection_string: postgresql://username:password@hostname/dendrite_syncapi?sslmode=disable
max_open_conns: 10 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
@ -326,7 +330,7 @@ user_api:
listen: http://[::]:7781 # The listen address for incoming API requests listen: http://[::]:7781 # The listen address for incoming API requests
connect: http://user_api:7781 # The connect address for other components to use connect: http://user_api:7781 # The connect address for other components to use
account_database: account_database:
connection_string: postgresql://username@password:hostname/dendrite_userapi?sslmode=disable connection_string: postgresql://username:password@hostname/dendrite_userapi?sslmode=disable
max_open_conns: 10 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1

View file

@ -24,7 +24,7 @@ Unfortunately we can't accept contributions without it.
## Getting up and running ## Getting up and running
See the [Installation](INSTALL.md) section for information on how to build an See the [Installation](installation) section for information on how to build an
instance of Dendrite. You will likely need this in order to test your changes. instance of Dendrite. You will likely need this in order to test your changes.
## Code style ## Code style

View file

@ -86,9 +86,12 @@ would be a huge help too, as that will help us to understand where the memory us
You may need to revisit the connection limit of your PostgreSQL server and/or make changes to the `max_connections` lines in your Dendrite configuration. Be aware that each Dendrite component opens its own database connections and has its own connection limit, even in monolith mode! You may need to revisit the connection limit of your PostgreSQL server and/or make changes to the `max_connections` lines in your Dendrite configuration. Be aware that each Dendrite component opens its own database connections and has its own connection limit, even in monolith mode!
## What is being reported when enabling anonymous stats? ## What is being reported when enabling phone-home statistics?
If anonymous stats reporting is enabled, the following data is send to the defined endpoint. Phone-home statistics contain your server's domain name, some configuration information about
your deployment and aggregated information about active users on your deployment. They are sent
to the endpoint URL configured in your Dendrite configuration file only. The following is an
example of the data that is sent:
```json ```json
{ {
@ -106,7 +109,7 @@ If anonymous stats reporting is enabled, the following data is send to the defin
"go_arch": "amd64", "go_arch": "amd64",
"go_os": "linux", "go_os": "linux",
"go_version": "go1.16.13", "go_version": "go1.16.13",
"homeserver": "localhost:8800", "homeserver": "my.domain.com",
"log_level": "trace", "log_level": "trace",
"memory_rss": 93452, "memory_rss": 93452,
"monolith": true, "monolith": true,

View file

@ -233,6 +233,8 @@ GEM
multipart-post (2.1.1) multipart-post (2.1.1)
nokogiri (1.13.6-arm64-darwin) nokogiri (1.13.6-arm64-darwin)
racc (~> 1.4) racc (~> 1.4)
nokogiri (1.13.6-x86_64-linux)
racc (~> 1.4)
octokit (4.22.0) octokit (4.22.0)
faraday (>= 0.9) faraday (>= 0.9)
sawyer (~> 0.8.0, >= 0.5.3) sawyer (~> 0.8.0, >= 0.5.3)
@ -263,7 +265,7 @@ GEM
thread_safe (0.3.6) thread_safe (0.3.6)
typhoeus (1.4.0) typhoeus (1.4.0)
ethon (>= 0.9.0) ethon (>= 0.9.0)
tzinfo (1.2.9) tzinfo (1.2.10)
thread_safe (~> 0.1) thread_safe (~> 0.1)
unf (0.1.4) unf (0.1.4)
unf_ext unf_ext
@ -273,11 +275,11 @@ GEM
PLATFORMS PLATFORMS
arm64-darwin-21 arm64-darwin-21
x86_64-linux
DEPENDENCIES DEPENDENCIES
github-pages (~> 226) github-pages (~> 226)
jekyll-feed (~> 0.15.1) jekyll-feed (~> 0.15.1)
minima (~> 2.5.1)
BUNDLED WITH BUNDLED WITH
2.3.7 2.3.7

View file

@ -32,6 +32,15 @@ To create a new **admin account**, add the `-admin` flag:
./bin/create-account -config /path/to/dendrite.yaml -username USERNAME -admin ./bin/create-account -config /path/to/dendrite.yaml -username USERNAME -admin
``` ```
An example of using `create-account` when running in **Docker**, having found the `CONTAINERNAME` from `docker ps`:
```bash
docker exec -it CONTAINERNAME /usr/bin/create-account -config /path/to/dendrite.yaml -username USERNAME
```
```bash
docker exec -it CONTAINERNAME /usr/bin/create-account -config /path/to/dendrite.yaml -username USERNAME -admin
```
## Using shared secret registration ## Using shared secret registration
Dendrite supports the Synapse-compatible shared secret registration endpoint. Dendrite supports the Synapse-compatible shared secret registration endpoint.

View file

@ -1,68 +0,0 @@
{
# debug
admin off
email example@example.com
default_sni example.com
# Debug endpoint
# acme_ca https://acme-staging-v02.api.letsencrypt.org/directory
}
#######################################################################
# Snippets
#______________________________________________________________________
(handle_errors_maintenance) {
handle_errors {
@maintenance expression {http.error.status_code} == 502
rewrite @maintenance maintenance.html
root * "/path/to/service/pages"
file_server
}
}
(matrix-well-known-header) {
# Headers
header Access-Control-Allow-Origin "*"
header Access-Control-Allow-Methods "GET, POST, PUT, DELETE, OPTIONS"
header Access-Control-Allow-Headers "Origin, X-Requested-With, Content-Type, Accept, Authorization"
header Content-Type "application/json"
}
#######################################################################
example.com {
# ...
handle /.well-known/matrix/server {
import matrix-well-known-header
respond `{ "m.server": "matrix.example.com:443" }` 200
}
handle /.well-known/matrix/client {
import matrix-well-known-header
respond `{ "m.homeserver": { "base_url": "https://matrix.example.com" } }` 200
}
import handle_errors_maintenance
}
example.com:8448 {
# server<->server HTTPS traffic
reverse_proxy http://dendrite-host:8008
}
matrix.example.com {
handle /_matrix/* {
# client<->server HTTPS traffic
reverse_proxy http://dendrite-host:8008
}
handle_path /* {
# Client webapp (Element SPA or ...)
file_server {
root /path/to/www/example.com/matrix-web-client/
}
}
}

View file

@ -0,0 +1,57 @@
# Sample Caddyfile for using Caddy in front of Dendrite.
#
# Customize email address and domain names.
# Optional settings commented out.
#
# BE SURE YOUR DOMAINS ARE POINTED AT YOUR SERVER FIRST.
# Documentation: https://caddyserver.com/docs/
#
# Bonus tip: If your IP address changes, use Caddy's
# dynamic DNS plugin to update your DNS records to
# point to your new IP automatically:
# https://github.com/mholt/caddy-dynamicdns
#
# Global options block
{
# In case there is a problem with your certificates.
# email example@example.com
# Turn off the admin endpoint if you don't need graceful config
# changes and/or are running untrusted code on your machine.
# admin off
# Enable this if your clients don't send ServerName in TLS handshakes.
# default_sni example.com
# Enable debug mode for verbose logging.
# debug
# Use Let's Encrypt's staging endpoint for testing.
# acme_ca https://acme-staging-v02.api.letsencrypt.org/directory
# If you're port-forwarding HTTP/HTTPS ports from 80/443 to something
# else, enable these and put the alternate port numbers here.
# http_port 8080
# https_port 8443
}
# The server name of your matrix homeserver. This example shows
# "well-known delegation" from the registered domain to a subdomain,
# which is only needed if your server_name doesn't match your Matrix
# homeserver URL (i.e. you can show users a vanity domain that looks
# nice and is easy to remember but still have your Matrix server on
# its own subdomain or hosted service).
example.com {
header /.well-known/matrix/* Content-Type application/json
header /.well-known/matrix/* Access-Control-Allow-Origin *
respond /.well-known/matrix/server `{"m.server": "matrix.example.com:443"}`
respond /.well-known/matrix/client `{"m.homeserver": {"base_url": "https://matrix.example.com"}}`
}
# The actual domain name whereby your Matrix server is accessed.
matrix.example.com {
# Set localhost:8008 to the address of your Dendrite server, if different
reverse_proxy /_matrix/* localhost:8008
}

View file

@ -0,0 +1,66 @@
# Sample Caddyfile for using Caddy in front of Dendrite.
#
# Customize email address and domain names.
# Optional settings commented out.
#
# BE SURE YOUR DOMAINS ARE POINTED AT YOUR SERVER FIRST.
# Documentation: https://caddyserver.com/docs/
#
# Bonus tip: If your IP address changes, use Caddy's
# dynamic DNS plugin to update your DNS records to
# point to your new IP automatically:
# https://github.com/mholt/caddy-dynamicdns
#
# Global options block
{
# In case there is a problem with your certificates.
# email example@example.com
# Turn off the admin endpoint if you don't need graceful config
# changes and/or are running untrusted code on your machine.
# admin off
# Enable this if your clients don't send ServerName in TLS handshakes.
# default_sni example.com
# Enable debug mode for verbose logging.
# debug
# Use Let's Encrypt's staging endpoint for testing.
# acme_ca https://acme-staging-v02.api.letsencrypt.org/directory
# If you're port-forwarding HTTP/HTTPS ports from 80/443 to something
# else, enable these and put the alternate port numbers here.
# http_port 8080
# https_port 8443
}
# The server name of your matrix homeserver. This example shows
# "well-known delegation" from the registered domain to a subdomain,
# which is only needed if your server_name doesn't match your Matrix
# homeserver URL (i.e. you can show users a vanity domain that looks
# nice and is easy to remember but still have your Matrix server on
# its own subdomain or hosted service).
example.com {
header /.well-known/matrix/* Content-Type application/json
header /.well-known/matrix/* Access-Control-Allow-Origin *
respond /.well-known/matrix/server `{"m.server": "matrix.example.com:443"}`
respond /.well-known/matrix/client `{"m.homeserver": {"base_url": "https://matrix.example.com"}}`
}
# The actual domain name whereby your Matrix server is accessed.
matrix.example.com {
# Change the end of each reverse_proxy line to the correct
# address for your various services.
@sync_api {
path_regexp /_matrix/client/.*?/(sync|user/.*?/filter/?.*|keys/changes|rooms/.*?/messages)$
}
reverse_proxy @sync_api sync_api:8073
reverse_proxy /_matrix/client* client_api:8071
reverse_proxy /_matrix/federation* federation_api:8071
reverse_proxy /_matrix/key* federation_api:8071
reverse_proxy /_matrix/media* media_api:8071
}

View file

@ -2,7 +2,7 @@
title: Starting the polylith title: Starting the polylith
parent: Installation parent: Installation
has_toc: true has_toc: true
nav_order: 9 nav_order: 10
permalink: /installation/start/polylith permalink: /installation/start/polylith
--- ---

View file

@ -2,7 +2,7 @@
title: Optimise your installation title: Optimise your installation
parent: Installation parent: Installation
has_toc: true has_toc: true
nav_order: 10 nav_order: 11
permalink: /installation/start/optimisation permalink: /installation/start/optimisation
--- ---

View file

@ -95,12 +95,13 @@ enabled.
To do so, follow the [NATS Server installation instructions](https://docs.nats.io/running-a-nats-service/introduction/installation) and then [start your NATS deployment](https://docs.nats.io/running-a-nats-service/introduction/running). JetStream must be enabled, either by passing the `-js` flag to `nats-server`, To do so, follow the [NATS Server installation instructions](https://docs.nats.io/running-a-nats-service/introduction/installation) and then [start your NATS deployment](https://docs.nats.io/running-a-nats-service/introduction/running). JetStream must be enabled, either by passing the `-js` flag to `nats-server`,
or by specifying the `store_dir` option in the the `jetstream` configuration. or by specifying the `store_dir` option in the the `jetstream` configuration.
### Reverse proxy (polylith deployments) ### Reverse proxy
Polylith deployments require a reverse proxy, such as [NGINX](https://www.nginx.com) or A reverse proxy such as [Caddy](https://caddyserver.com), [NGINX](https://www.nginx.com) or
[HAProxy](http://www.haproxy.org). Configuring those is not covered in this documentation, [HAProxy](http://www.haproxy.org) is required for polylith deployments and is useful for monolith
although a [sample configuration for NGINX](https://github.com/matrix-org/dendrite/blob/main/docs/nginx/polylith-sample.conf) deployments. Configuring those is not covered in this documentation, although sample configurations
is provided. for [Caddy](https://github.com/matrix-org/dendrite/blob/main/docs/caddy) and
[NGINX](https://github.com/matrix-org/dendrite/blob/main/docs/nginx) are provided.
### Windows ### Windows

View file

@ -14,27 +14,38 @@ that take the format `@user:example.com`.
For federation to work, the server name must be resolvable by other homeservers on the internet For federation to work, the server name must be resolvable by other homeservers on the internet
— that is, the domain must be registered and properly configured with the relevant DNS records. — that is, the domain must be registered and properly configured with the relevant DNS records.
Matrix servers discover each other when federating using the following methods: Matrix servers usually discover each other when federating using the following methods:
1. If a well-known delegation exists on `example.com`, use the path server from the 1. If a well-known delegation exists on `example.com`, use the domain and port from the
well-known file to connect to the remote homeserver; well-known file to connect to the remote homeserver;
2. If a DNS SRV delegation exists on `example.com`, use the hostname and port from the DNS SRV 2. If a DNS SRV delegation exists on `example.com`, use the IP address and port from the DNS SRV
record to connect to the remote homeserver; record to connect to the remote homeserver;
3. If neither well-known or DNS SRV delegation are configured, attempt to connect to the remote 3. If neither well-known or DNS SRV delegation are configured, attempt to connect to the remote
homeserver by connecting to `example.com` port TCP/8448 using HTTPS. homeserver by connecting to `example.com` port TCP/8448 using HTTPS.
The exact details of how server name resolution works can be found in
[the spec](https://spec.matrix.org/v1.3/server-server-api/#resolving-server-names).
## TLS certificates ## TLS certificates
Matrix federation requires that valid TLS certificates are present on the domain. You must Matrix federation requires that valid TLS certificates are present on the domain. You must
obtain certificates from a publicly accepted Certificate Authority (CA). [LetsEncrypt](https://letsencrypt.org) obtain certificates from a publicly-trusted certificate authority (CA). [Let's Encrypt](https://letsencrypt.org)
is an example of such a CA that can be used. Self-signed certificates are not suitable for is a popular choice of CA because the certificates are publicly-trusted, free, and automated
federation and will typically not be accepted by other homeservers. via the ACME protocol. (Self-signed certificates are not suitable for federation and will typically
not be accepted by other homeservers.)
A common practice to help ease the management of certificates is to install a reverse proxy in Automating the renewal of TLS certificates is best practice. There are many tools for this,
front of Dendrite which manages the TLS certificates and HTTPS proxying itself. Software such as but the simplest way to achieve TLS automation is to have your reverse proxy do it for you.
[NGINX](https://www.nginx.com) and [HAProxy](http://www.haproxy.org) can be used for the task. [Caddy](https://caddyserver.com) is recommended as a production-grade reverse proxy with
Although the finer details of configuring these are not described here, you must reverse proxy automatic TLS which is commonly used in front of Dendrite. It obtains and renews TLS certificates
all `/_matrix` paths to your Dendrite server. automatically and by default as long as your domain name is pointed at your server first.
Although the finer details of [configuring Caddy](https://caddyserver.com/docs/) is not described
here, in general, you must reverse proxy all `/_matrix` paths to your Dendrite server. For example,
with Caddy:
```
reverse_proxy /_matrix/* localhost:8008
```
It is possible for the reverse proxy to listen on the standard HTTPS port TCP/443 so long as your It is possible for the reverse proxy to listen on the standard HTTPS port TCP/443 so long as your
domain delegation is configured to point to port TCP/443. domain delegation is configured to point to port TCP/443.
@ -51,17 +62,12 @@ you will be able to delegate from `example.com` to `matrix.example.com` so that
Delegation can be performed in one of two ways: Delegation can be performed in one of two ways:
* **Well-known delegation**: A well-known text file is served over HTTPS on the domain name * **Well-known delegation (preferred)**: A well-known text file is served over HTTPS on the domain
that you want to use, pointing to your server on `matrix.example.com` port 8448; name that you want to use, pointing to your server on `matrix.example.com` port 8448;
* **DNS SRV delegation**: A DNS SRV record is created on the domain name that you want to * **DNS SRV delegation (not recommended)**: See the SRV delegation section below for details.
use, pointing to your server on `matrix.example.com` port TCP/8448.
If you are using a reverse proxy to forward `/_matrix` to Dendrite, your well-known or DNS SRV If you are using a reverse proxy to forward `/_matrix` to Dendrite, your well-known or delegation
delegation must refer to the hostname and port that the reverse proxy is listening on instead. must refer to the hostname and port that the reverse proxy is listening on instead.
Well-known delegation is typically easier to set up and usually preferred. However, you can use
either or both methods to delegate. If you configure both methods of delegation, it is important
that they both agree and refer to the same hostname and port.
## Well-known delegation ## Well-known delegation
@ -74,20 +80,46 @@ and contain the following JSON document:
```json ```json
{ {
"m.server": "https://matrix.example.com:8448" "m.server": "matrix.example.com:8448"
} }
``` ```
For example, this can be done with the following Caddy config:
```
handle /.well-known/matrix/client {
header Content-Type application/json
header Access-Control-Allow-Origin *
respond `{"m.homeserver": {"base_url": "https://matrix.example.com:8448"}}`
}
```
You can also serve `.well-known` with Dendrite itself by setting the `well_known_server_name` config
option to the value you want for `m.server`. This is primarily useful if Dendrite is exposed on
`example.com:443` and you don't want to set up a separate webserver just for serving the `.well-known`
file.
```yaml
global:
...
well_known_server_name: "example.com:443"
```
## DNS SRV delegation ## DNS SRV delegation
Using DNS SRV delegation requires creating DNS SRV records on the `example.com` zone which This method is not recommended, as the behavior of SRV records in Matrix is rather unintuitive:
refer to your Dendrite installation. SRV records will only change the IP address and port that other servers connect to, they won't
affect the domain name. In technical terms, the `Host` header and TLS SNI of federation requests
will still be `example.com` even if the SRV record points at `matrix.example.com`.
Assuming that your Dendrite installation is listening for HTTPS connections at `matrix.example.com` In practice, this means that the server must be configured with valid TLS certificates for
port 8448, the DNS SRV record must have the following fields: `example.com`, rather than `matrix.example.com` as one might intuitively expect. If there's a
reverse proxy in between, the proxy configuration must be written as if it's `example.com`, as the
proxy will never see the name `matrix.example.com` in incoming requests.
* Name: `@` (or whichever term your DNS provider uses to signal the root) This behavior also means that if `example.com` and `matrix.example.com` point at the same IP
* Service: `_matrix` address, there is no reason to have a SRV record pointing at `matrix.example.com`. It can still
* Protocol: `_tcp` be used to change the port number, but it won't do anything else.
* Port: `8448`
* Target: `matrix.example.com` If you understand how SRV records work and still want to use them, the service name is `_matrix` and
the protocol is `_tcp`.

View file

@ -0,0 +1,38 @@
---
title: Building Dendrite
parent: Installation
has_toc: true
nav_order: 3
permalink: /installation/build
---
# Build all Dendrite commands
Dendrite has numerous utility commands in addition to the actual server binaries.
Build them all from the root of the source repo with `build.sh` (Linux/Mac):
```sh
./build.sh
```
or `build.cmd` (Windows):
```powershell
build.cmd
```
The resulting binaries will be placed in the `bin` subfolder.
# Installing as a monolith
You can install the Dendrite monolith binary into `$GOPATH/bin` by using `go install`:
```sh
go install ./cmd/dendrite-monolith-server
```
Alternatively, you can specify a custom path for the binary to be written to using `go build`:
```sh
go build -o /usr/local/bin/ ./cmd/dendrite-monolith-server
```

View file

@ -17,7 +17,9 @@ filenames in the Dendrite configuration file and start Dendrite. The databases w
and populated automatically. and populated automatically.
Note that Dendrite **cannot share a single SQLite database across multiple components**. Each Note that Dendrite **cannot share a single SQLite database across multiple components**. Each
component must be configured with its own SQLite database filename. component must be configured with its own SQLite database filename. You will have to remove
the `global.database` section from your Dendrite config and add it to each individual section
instead in order to use SQLite.
### Connection strings ### Connection strings

View file

@ -29,5 +29,6 @@ Polylith deployments require a reverse proxy in order to ensure that requests ar
sent to the correct endpoint. You must ensure that a suitable reverse proxy is installed sent to the correct endpoint. You must ensure that a suitable reverse proxy is installed
and configured. and configured.
A [sample configuration file](https://github.com/matrix-org/dendrite/blob/main/docs/nginx/polylith-sample.conf) Sample configurations are provided
is provided for [NGINX](https://www.nginx.com). for [Caddy](https://github.com/matrix-org/dendrite/blob/main/docs/caddy/polylith/Caddyfile)
and [NGINX](https://github.com/matrix-org/dendrite/blob/main/docs/nginx/polylith-sample.conf).

View file

@ -1,13 +1,13 @@
--- ---
title: Populate the configuration title: Configuring Dendrite
parent: Installation parent: Installation
nav_order: 7 nav_order: 7
permalink: /installation/configuration permalink: /installation/configuration
--- ---
# Populate the configuration # Configuring Dendrite
The configuration file is used to configure Dendrite. Sample configuration files are A YAML configuration file is used to configure Dendrite. Sample configuration files are
present in the top level of the Dendrite repository: present in the top level of the Dendrite repository:
* [`dendrite-sample.monolith.yaml`](https://github.com/matrix-org/dendrite/blob/main/dendrite-sample.monolith.yaml) * [`dendrite-sample.monolith.yaml`](https://github.com/matrix-org/dendrite/blob/main/dendrite-sample.monolith.yaml)

View file

@ -1,7 +1,7 @@
--- ---
title: Generating signing keys title: Generating signing keys
parent: Installation parent: Installation
nav_order: 4 nav_order: 8
permalink: /installation/signingkeys permalink: /installation/signingkeys
--- ---

View file

@ -15,8 +15,9 @@ you can start your Dendrite monolith deployment by starting the `dendrite-monoli
./dendrite-monolith-server -config /path/to/dendrite.yaml ./dendrite-monolith-server -config /path/to/dendrite.yaml
``` ```
If you want to change the addresses or ports that Dendrite listens on, you By default, Dendrite will listen HTTP on port 8008. If you want to change the addresses
can use the `-http-bind-address` and `-https-bind-address` command line arguments: or ports that Dendrite listens on, you can use the `-http-bind-address` and
`-https-bind-address` command line arguments:
```bash ```bash
./dendrite-monolith-server -config /path/to/dendrite.yaml \ ./dendrite-monolith-server -config /path/to/dendrite.yaml \

View file

@ -26,7 +26,6 @@ import (
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/sirupsen/logrus"
) )
// InviteV2 implements /_matrix/federation/v2/invite/{roomID}/{eventID} // InviteV2 implements /_matrix/federation/v2/invite/{roomID}/{eventID}
@ -144,7 +143,6 @@ func processInvite(
// Check that the event is signed by the server sending the request. // Check that the event is signed by the server sending the request.
redacted, err := gomatrixserverlib.RedactEventJSON(event.JSON(), event.Version()) redacted, err := gomatrixserverlib.RedactEventJSON(event.JSON(), event.Version())
if err != nil { if err != nil {
logrus.WithError(err).Errorf("XXX: invite.go")
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON("The event JSON could not be redacted"), JSON: jsonerror.BadJSON("The event JSON could not be redacted"),

View file

@ -15,23 +15,13 @@
package deltas package deltas
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/pressly/goose"
) )
func LoadFromGoose() { func UpRemoveRoomsTable(ctx context.Context, tx *sql.Tx) error {
goose.AddMigration(UpRemoveRoomsTable, DownRemoveRoomsTable) _, err := tx.ExecContext(ctx, `
}
func LoadRemoveRoomsTable(m *sqlutil.Migrations) {
m.AddMigration(UpRemoveRoomsTable, DownRemoveRoomsTable)
}
func UpRemoveRoomsTable(tx *sql.Tx) error {
_, err := tx.Exec(`
DROP TABLE IF EXISTS federationsender_rooms; DROP TABLE IF EXISTS federationsender_rooms;
`) `)
if err != nil { if err != nil {

View file

@ -82,9 +82,13 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
if err != nil { if err != nil {
return nil, err return nil, err
} }
m := sqlutil.NewMigrations() m := sqlutil.NewMigrator(d.db)
deltas.LoadRemoveRoomsTable(m) m.AddMigrations(sqlutil.Migration{
if err = m.RunDeltas(d.db, dbProperties); err != nil { Version: "federationsender: drop federationsender_rooms",
Up: deltas.UpRemoveRoomsTable,
})
err = m.Up(base.Context())
if err != nil {
return nil, err return nil, err
} }
d.Database = shared.Database{ d.Database = shared.Database{

View file

@ -15,23 +15,13 @@
package deltas package deltas
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/pressly/goose"
) )
func LoadFromGoose() { func UpRemoveRoomsTable(ctx context.Context, tx *sql.Tx) error {
goose.AddMigration(UpRemoveRoomsTable, DownRemoveRoomsTable) _, err := tx.ExecContext(ctx, `
}
func LoadRemoveRoomsTable(m *sqlutil.Migrations) {
m.AddMigration(UpRemoveRoomsTable, DownRemoveRoomsTable)
}
func UpRemoveRoomsTable(tx *sql.Tx) error {
_, err := tx.Exec(`
DROP TABLE IF EXISTS federationsender_rooms; DROP TABLE IF EXISTS federationsender_rooms;
`) `)
if err != nil { if err != nil {

View file

@ -81,9 +81,13 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
if err != nil { if err != nil {
return nil, err return nil, err
} }
m := sqlutil.NewMigrations() m := sqlutil.NewMigrator(d.db)
deltas.LoadRemoveRoomsTable(m) m.AddMigrations(sqlutil.Migration{
if err = m.RunDeltas(d.db, dbProperties); err != nil { Version: "federationsender: drop federationsender_rooms",
Up: deltas.UpRemoveRoomsTable,
})
err = m.Up(base.Context())
if err != nil {
return nil, err return nil, err
} }
d.Database = shared.Database{ d.Database = shared.Database{

3
go.mod
View file

@ -25,7 +25,7 @@ require (
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16
github.com/matrix-org/gomatrixserverlib v0.0.0-20220713083127-fc2ea1e62e46 github.com/matrix-org/gomatrixserverlib v0.0.0-20220725104114-b6003e522771
github.com/matrix-org/pinecone v0.0.0-20220708135211-1ce778fcde6a github.com/matrix-org/pinecone v0.0.0-20220708135211-1ce778fcde6a
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
github.com/mattn/go-sqlite3 v1.14.13 github.com/mattn/go-sqlite3 v1.14.13
@ -37,7 +37,6 @@ require (
github.com/opentracing/opentracing-go v1.2.0 github.com/opentracing/opentracing-go v1.2.0
github.com/patrickmn/go-cache v2.1.0+incompatible github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/pkg/errors v0.9.1 github.com/pkg/errors v0.9.1
github.com/pressly/goose v2.7.0+incompatible
github.com/prometheus/client_golang v1.12.2 github.com/prometheus/client_golang v1.12.2
github.com/sirupsen/logrus v1.8.1 github.com/sirupsen/logrus v1.8.1
github.com/stretchr/testify v1.7.1 github.com/stretchr/testify v1.7.1

6
go.sum
View file

@ -341,8 +341,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0=
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4=
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
github.com/matrix-org/gomatrixserverlib v0.0.0-20220713083127-fc2ea1e62e46 h1:5X/kXY3nwqKOwwrE9tnMKrjbsi3PHigQYvrvDBSntO8= github.com/matrix-org/gomatrixserverlib v0.0.0-20220725104114-b6003e522771 h1:ZIPHFIPNDS9dmEbPEiJbNmyCGJtn9exfpLC7JOcn/bE=
github.com/matrix-org/gomatrixserverlib v0.0.0-20220713083127-fc2ea1e62e46/go.mod h1:jX38yp3SSLJNftBg3PXU1ayd0PCLIiDHQ4xAc9DIixk= github.com/matrix-org/gomatrixserverlib v0.0.0-20220725104114-b6003e522771/go.mod h1:jX38yp3SSLJNftBg3PXU1ayd0PCLIiDHQ4xAc9DIixk=
github.com/matrix-org/pinecone v0.0.0-20220708135211-1ce778fcde6a h1:DdG8vXMlZ65EAtc4V+3t7zHZ2Gqs24pSnyXS+4BRHUs= github.com/matrix-org/pinecone v0.0.0-20220708135211-1ce778fcde6a h1:DdG8vXMlZ65EAtc4V+3t7zHZ2Gqs24pSnyXS+4BRHUs=
github.com/matrix-org/pinecone v0.0.0-20220708135211-1ce778fcde6a/go.mod h1:ulJzsVOTssIVp1j/m5eI//4VpAGDkMt5NrRuAVX7wpc= github.com/matrix-org/pinecone v0.0.0-20220708135211-1ce778fcde6a/go.mod h1:ulJzsVOTssIVp1j/m5eI//4VpAGDkMt5NrRuAVX7wpc=
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U=
@ -432,8 +432,6 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pressly/goose v2.7.0+incompatible h1:PWejVEv07LCerQEzMMeAtjuyCKbyprZ/LBa6K5P0OCQ=
github.com/pressly/goose v2.7.0+incompatible/go.mod h1:m+QHWCqxR3k8D9l7qfzuC/djtlfzxr34mozWDYEu1z8=
github.com/prometheus/client_golang v0.8.0/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= github.com/prometheus/client_golang v0.8.0/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw=
github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw=
github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo=

View file

@ -0,0 +1,18 @@
package caching
import "github.com/matrix-org/dendrite/roomserver/types"
// EventStateKeyCache contains the subset of functions needed for
// a room event state key cache.
type EventStateKeyCache interface {
GetEventStateKey(eventStateKeyNID types.EventStateKeyNID) (string, bool)
StoreEventStateKey(eventStateKeyNID types.EventStateKeyNID, eventStateKey string)
}
func (c Caches) GetEventStateKey(eventStateKeyNID types.EventStateKeyNID) (string, bool) {
return c.RoomServerStateKeys.Get(eventStateKeyNID)
}
func (c Caches) StoreEventStateKey(eventStateKeyNID types.EventStateKeyNID, eventStateKey string) {
c.RoomServerStateKeys.Set(eventStateKeyNID, eventStateKey)
}

View file

@ -9,6 +9,7 @@ type RoomServerCaches interface {
RoomVersionCache RoomVersionCache
RoomInfoCache RoomInfoCache
RoomServerEventsCache RoomServerEventsCache
EventStateKeyCache
} }
// RoomServerNIDsCache contains the subset of functions needed for // RoomServerNIDsCache contains the subset of functions needed for
@ -19,9 +20,9 @@ type RoomServerNIDsCache interface {
} }
func (c Caches) GetRoomServerRoomID(roomNID types.RoomNID) (string, bool) { func (c Caches) GetRoomServerRoomID(roomNID types.RoomNID) (string, bool) {
return c.RoomServerRoomIDs.Get(int64(roomNID)) return c.RoomServerRoomIDs.Get(roomNID)
} }
func (c Caches) StoreRoomServerRoomID(roomNID types.RoomNID, roomID string) { func (c Caches) StoreRoomServerRoomID(roomNID types.RoomNID, roomID string) {
c.RoomServerRoomIDs.Set(int64(roomNID), roomID) c.RoomServerRoomIDs.Set(roomNID, roomID)
} }

View file

@ -23,16 +23,17 @@ import (
// different implementations as long as they satisfy the Cache // different implementations as long as they satisfy the Cache
// interface. // interface.
type Caches struct { type Caches struct {
RoomVersions Cache[string, gomatrixserverlib.RoomVersion] // room ID -> room version RoomVersions Cache[string, gomatrixserverlib.RoomVersion] // room ID -> room version
ServerKeys Cache[string, gomatrixserverlib.PublicKeyLookupResult] // server name -> server keys ServerKeys Cache[string, gomatrixserverlib.PublicKeyLookupResult] // server name -> server keys
RoomServerRoomNIDs Cache[string, types.RoomNID] // room ID -> room NID RoomServerRoomNIDs Cache[string, types.RoomNID] // room ID -> room NID
RoomServerRoomIDs Cache[int64, string] // room NID -> room ID RoomServerRoomIDs Cache[types.RoomNID, string] // room NID -> room ID
RoomServerEvents Cache[int64, *gomatrixserverlib.Event] // event NID -> event RoomServerEvents Cache[int64, *gomatrixserverlib.Event] // event NID -> event
RoomInfos Cache[string, *types.RoomInfo] // room ID -> room info RoomServerStateKeys Cache[types.EventStateKeyNID, string] // event NID -> event state key
FederationPDUs Cache[int64, *gomatrixserverlib.HeaderedEvent] // queue NID -> PDU RoomInfos Cache[string, *types.RoomInfo] // room ID -> room info
FederationEDUs Cache[int64, *gomatrixserverlib.EDU] // queue NID -> EDU FederationPDUs Cache[int64, *gomatrixserverlib.HeaderedEvent] // queue NID -> PDU
SpaceSummaryRooms Cache[string, gomatrixserverlib.MSC2946SpacesResponse] // room ID -> space response FederationEDUs Cache[int64, *gomatrixserverlib.EDU] // queue NID -> EDU
LazyLoading Cache[lazyLoadingCacheKey, string] // composite key -> event ID SpaceSummaryRooms Cache[string, gomatrixserverlib.MSC2946SpacesResponse] // room ID -> space response
LazyLoading Cache[lazyLoadingCacheKey, string] // composite key -> event ID
} }
// Cache is the interface that an implementation must satisfy. // Cache is the interface that an implementation must satisfy.
@ -44,7 +45,7 @@ type Cache[K keyable, T any] interface {
type keyable interface { type keyable interface {
// from https://github.com/dgraph-io/ristretto/blob/8e850b710d6df0383c375ec6a7beae4ce48fc8d5/z/z.go#L34 // from https://github.com/dgraph-io/ristretto/blob/8e850b710d6df0383c375ec6a7beae4ce48fc8d5/z/z.go#L34
uint64 | string | []byte | byte | int | int32 | uint32 | int64 | lazyLoadingCacheKey ~uint64 | ~string | []byte | byte | ~int | ~int32 | ~uint32 | ~int64 | lazyLoadingCacheKey
} }
type costable interface { type costable interface {

View file

@ -40,13 +40,14 @@ const (
federationEDUsCache federationEDUsCache
spaceSummaryRoomsCache spaceSummaryRoomsCache
lazyLoadingCache lazyLoadingCache
eventStateKeyCache
) )
func NewRistrettoCache(maxCost config.DataUnit, maxAge time.Duration, enablePrometheus bool) *Caches { func NewRistrettoCache(maxCost config.DataUnit, maxAge time.Duration, enablePrometheus bool) *Caches {
cache, err := ristretto.NewCache(&ristretto.Config{ cache, err := ristretto.NewCache(&ristretto.Config{
NumCounters: 1e5, // 10x number of expected cache items, affects bloom filter size, gives us room for 10,000 currently NumCounters: int64((maxCost / 1024) * 10), // 10 counters per 1KB data, affects bloom filter size
BufferItems: 64, // recommended by the ristretto godocs as a sane buffer size value BufferItems: 64, // recommended by the ristretto godocs as a sane buffer size value
MaxCost: int64(maxCost), MaxCost: int64(maxCost), // max cost is in bytes, as per the Dendrite config
Metrics: true, Metrics: true,
KeyToHash: func(key interface{}) (uint64, uint64) { KeyToHash: func(key interface{}) (uint64, uint64) {
return z.KeyToHash(key) return z.KeyToHash(key)
@ -88,7 +89,7 @@ func NewRistrettoCache(maxCost config.DataUnit, maxAge time.Duration, enableProm
Prefix: roomNIDsCache, Prefix: roomNIDsCache,
MaxAge: maxAge, MaxAge: maxAge,
}, },
RoomServerRoomIDs: &RistrettoCachePartition[int64, string]{ // room NID -> room ID RoomServerRoomIDs: &RistrettoCachePartition[types.RoomNID, string]{ // room NID -> room ID
cache: cache, cache: cache,
Prefix: roomIDsCache, Prefix: roomIDsCache,
MaxAge: maxAge, MaxAge: maxAge,
@ -100,6 +101,11 @@ func NewRistrettoCache(maxCost config.DataUnit, maxAge time.Duration, enableProm
MaxAge: maxAge, MaxAge: maxAge,
}, },
}, },
RoomServerStateKeys: &RistrettoCachePartition[types.EventStateKeyNID, string]{ // event NID -> event state key
cache: cache,
Prefix: eventStateKeyCache,
MaxAge: maxAge,
},
RoomInfos: &RistrettoCachePartition[string, *types.RoomInfo]{ // room ID -> room info RoomInfos: &RistrettoCachePartition[string, *types.RoomInfo]{ // room ID -> room info
cache: cache, cache: cache,
Prefix: roomInfosCache, Prefix: roomInfosCache,

View file

@ -1,130 +1,142 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlutil package sqlutil
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"runtime" "sync"
"sort" "time"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/internal"
"github.com/pressly/goose" "github.com/sirupsen/logrus"
) )
type Migrations struct { const createDBMigrationsSQL = "" +
registeredGoMigrations map[int64]*goose.Migration "CREATE TABLE IF NOT EXISTS db_migrations (" +
" version TEXT PRIMARY KEY NOT NULL," +
" time TEXT NOT NULL," +
" dendrite_version TEXT NOT NULL" +
");"
const insertVersionSQL = "" +
"INSERT INTO db_migrations (version, time, dendrite_version)" +
" VALUES ($1, $2, $3)"
const selectDBMigrationsSQL = "SELECT version FROM db_migrations"
// Migration defines a migration to be run.
type Migration struct {
// Version is a simple description/name of this migration.
Version string
// Up defines the function to execute for an upgrade.
Up func(ctx context.Context, txn *sql.Tx) error
// Down defines the function to execute for a downgrade (not implemented yet).
Down func(ctx context.Context, txn *sql.Tx) error
} }
func NewMigrations() *Migrations { // Migrator
return &Migrations{ type Migrator struct {
registeredGoMigrations: make(map[int64]*goose.Migration), db *sql.DB
migrations []Migration
knownMigrations map[string]struct{}
mutex *sync.Mutex
}
// NewMigrator creates a new DB migrator.
func NewMigrator(db *sql.DB) *Migrator {
return &Migrator{
db: db,
migrations: []Migration{},
knownMigrations: make(map[string]struct{}),
mutex: &sync.Mutex{},
} }
} }
// Copy-pasted from goose directly to store migrations into a map we control // AddMigrations appends migrations to the list of migrations. Migrations are executed
// in the order they are added to the list. De-duplicates migrations using their Version field.
// AddMigration adds a migration. func (m *Migrator) AddMigrations(migrations ...Migration) {
func (m *Migrations) AddMigration(up func(*sql.Tx) error, down func(*sql.Tx) error) { m.mutex.Lock()
_, filename, _, _ := runtime.Caller(1) defer m.mutex.Unlock()
m.AddNamedMigration(filename, up, down) for _, mig := range migrations {
} if _, ok := m.knownMigrations[mig.Version]; !ok {
m.migrations = append(m.migrations, mig)
// AddNamedMigration : Add a named migration. m.knownMigrations[mig.Version] = struct{}{}
func (m *Migrations) AddNamedMigration(filename string, up func(*sql.Tx) error, down func(*sql.Tx) error) { }
v, _ := goose.NumericComponent(filename)
migration := &goose.Migration{Version: v, Next: -1, Previous: -1, Registered: true, UpFn: up, DownFn: down, Source: filename}
if existing, ok := m.registeredGoMigrations[v]; ok {
panic(fmt.Sprintf("failed to add migration %q: version conflicts with %q", filename, existing.Source))
} }
m.registeredGoMigrations[v] = migration
} }
// RunDeltas up to the latest version. // Up executes all migrations in order they were added.
func (m *Migrations) RunDeltas(db *sql.DB, props *config.DatabaseOptions) error { func (m *Migrator) Up(ctx context.Context) error {
maxVer := goose.MaxVersion var (
minVer := int64(0) err error
migrations, err := m.collect(minVer, maxVer) dendriteVersion = internal.VersionString()
)
// ensure there is a table for known migrations
executedMigrations, err := m.ExecutedMigrations(ctx)
if err != nil { if err != nil {
return fmt.Errorf("runDeltas: Failed to collect migrations: %w", err) return fmt.Errorf("unable to create/get migrations: %w", err)
} }
if props.ConnectionString.IsPostgres() {
if err = goose.SetDialect("postgres"); err != nil {
return err
}
} else if props.ConnectionString.IsSQLite() {
if err = goose.SetDialect("sqlite3"); err != nil {
return err
}
} else {
return fmt.Errorf("unknown connection string: %s", props.ConnectionString)
}
for {
current, err := goose.EnsureDBVersion(db)
if err != nil {
return fmt.Errorf("runDeltas: Failed to EnsureDBVersion: %w", err)
}
next, err := migrations.Next(current) return WithTransaction(m.db, func(txn *sql.Tx) error {
if err != nil { for i := range m.migrations {
if err == goose.ErrNoNextVersion { now := time.Now().UTC().Format(time.RFC3339)
return nil migration := m.migrations[i]
logrus.Debugf("Executing database migration '%s'", migration.Version)
// Skip migration if it was already executed
if _, ok := executedMigrations[migration.Version]; ok {
continue
}
err = migration.Up(ctx, txn)
if err != nil {
return fmt.Errorf("unable to execute migration '%s': %w", migration.Version, err)
}
_, err = txn.ExecContext(ctx, insertVersionSQL,
migration.Version,
now,
dendriteVersion,
)
if err != nil {
return fmt.Errorf("unable to insert executed migrations: %w", err)
} }
return fmt.Errorf("runDeltas: Failed to load next migration to %+v : %w", next, err)
} }
return nil
if err = next.Up(db); err != nil { })
return fmt.Errorf("runDeltas: Failed run migration: %w", err)
}
}
} }
func (m *Migrations) collect(current, target int64) (goose.Migrations, error) { // ExecutedMigrations returns a map with already executed migrations in addition to creating the
var migrations goose.Migrations // migrations table, if it doesn't exist.
func (m *Migrator) ExecutedMigrations(ctx context.Context) (map[string]struct{}, error) {
// Go migrations registered via goose.AddMigration(). result := make(map[string]struct{})
for _, migration := range m.registeredGoMigrations { _, err := m.db.ExecContext(ctx, createDBMigrationsSQL)
v, err := goose.NumericComponent(migration.Source) if err != nil {
if err != nil { return nil, fmt.Errorf("unable to create db_migrations: %w", err)
return nil, err }
} rows, err := m.db.QueryContext(ctx, selectDBMigrationsSQL)
if versionFilter(v, current, target) { if err != nil {
migrations = append(migrations, migration) return nil, fmt.Errorf("unable to query db_migrations: %w", err)
}
defer internal.CloseAndLogIfError(ctx, rows, "ExecutedMigrations: rows.close() failed")
var version string
for rows.Next() {
if err = rows.Scan(&version); err != nil {
return nil, fmt.Errorf("unable to scan version: %w", err)
} }
result[version] = struct{}{}
} }
migrations = sortAndConnectMigrations(migrations) return result, rows.Err()
return migrations, nil
}
func sortAndConnectMigrations(migrations goose.Migrations) goose.Migrations {
sort.Sort(migrations)
// now that we're sorted in the appropriate direction,
// populate next and previous for each migration
for i, m := range migrations {
prev := int64(-1)
if i > 0 {
prev = migrations[i-1].Version
migrations[i-1].Next = m.Version
}
migrations[i].Previous = prev
}
return migrations
}
func versionFilter(v, current, target int64) bool {
if target > current {
return v > current && v <= target
}
if target < current {
return v <= current && v > target
}
return false
} }

View file

@ -0,0 +1,112 @@
package sqlutil_test
import (
"context"
"database/sql"
"fmt"
"reflect"
"testing"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/test"
_ "github.com/mattn/go-sqlite3"
)
var dummyMigrations = []sqlutil.Migration{
{
Version: "init",
Up: func(ctx context.Context, txn *sql.Tx) error {
_, err := txn.ExecContext(ctx, "CREATE TABLE IF NOT EXISTS dummy ( test TEXT );")
return err
},
},
{
Version: "v2",
Up: func(ctx context.Context, txn *sql.Tx) error {
_, err := txn.ExecContext(ctx, "ALTER TABLE dummy ADD COLUMN test2 TEXT;")
return err
},
},
{
Version: "v2", // duplicate, this migration will be skipped
Up: func(ctx context.Context, txn *sql.Tx) error {
_, err := txn.ExecContext(ctx, "ALTER TABLE dummy ADD COLUMN test2 TEXT;")
return err
},
},
{
Version: "multiple execs",
Up: func(ctx context.Context, txn *sql.Tx) error {
_, err := txn.ExecContext(ctx, "ALTER TABLE dummy ADD COLUMN test3 TEXT;")
if err != nil {
return err
}
_, err = txn.ExecContext(ctx, "ALTER TABLE dummy ADD COLUMN test4 TEXT;")
return err
},
},
}
var failMigration = sqlutil.Migration{
Version: "iFail",
Up: func(ctx context.Context, txn *sql.Tx) error {
return fmt.Errorf("iFail")
},
Down: nil,
}
func Test_migrations_Up(t *testing.T) {
withFail := append(dummyMigrations, failMigration)
tests := []struct {
name string
migrations []sqlutil.Migration
wantResult map[string]struct{}
wantErr bool
}{
{
name: "dummy migration",
migrations: dummyMigrations,
wantResult: map[string]struct{}{
"init": {},
"v2": {},
"multiple execs": {},
},
},
{
name: "with fail",
migrations: withFail,
wantErr: true,
},
}
ctx := context.Background()
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
conStr, close := test.PrepareDBConnectionString(t, dbType)
defer close()
driverName := "sqlite3"
if dbType == test.DBTypePostgres {
driverName = "postgres"
}
db, err := sql.Open(driverName, conStr)
if err != nil {
t.Errorf("unable to open database: %v", err)
}
m := sqlutil.NewMigrator(db)
m.AddMigrations(tt.migrations...)
if err = m.Up(ctx); (err != nil) != tt.wantErr {
t.Errorf("Up() error = %v, wantErr %v", err, tt.wantErr)
}
result, err := m.ExecutedMigrations(ctx)
if err != nil {
t.Errorf("unable to get executed migrations: %v", err)
}
if !tt.wantErr && !reflect.DeepEqual(result, tt.wantResult) {
t.Errorf("expected: %+v, got %v", tt.wantResult, result)
}
})
})
}
}

View file

@ -21,6 +21,7 @@ import (
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/storage/postgres/deltas"
"github.com/matrix-org/dendrite/keyserver/storage/tables" "github.com/matrix-org/dendrite/keyserver/storage/tables"
"github.com/matrix-org/dendrite/keyserver/types" "github.com/matrix-org/dendrite/keyserver/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -66,6 +67,16 @@ func NewPostgresCrossSigningSigsTable(db *sql.DB) (tables.CrossSigningSigs, erro
if err != nil { if err != nil {
return nil, err return nil, err
} }
m := sqlutil.NewMigrator(db)
m.AddMigrations(sqlutil.Migration{
Version: "keyserver: cross signing signature indexes",
Up: deltas.UpFixCrossSigningSignatureIndexes,
})
if err = m.Up(context.Background()); err != nil {
return nil, err
}
return s, sqlutil.StatementList{ return s, sqlutil.StatementList{
{&s.selectCrossSigningSigsForTargetStmt, selectCrossSigningSigsForTargetSQL}, {&s.selectCrossSigningSigsForTargetStmt, selectCrossSigningSigsForTargetSQL},
{&s.upsertCrossSigningSigsForTargetStmt, upsertCrossSigningSigsForTargetSQL}, {&s.upsertCrossSigningSigsForTargetStmt, upsertCrossSigningSigsForTargetSQL},

View file

@ -15,37 +15,27 @@
package deltas package deltas
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/pressly/goose"
) )
func LoadFromGoose() { func UpRefactorKeyChanges(ctx context.Context, tx *sql.Tx) error {
goose.AddMigration(UpRefactorKeyChanges, DownRefactorKeyChanges)
}
func LoadRefactorKeyChanges(m *sqlutil.Migrations) {
m.AddMigration(UpRefactorKeyChanges, DownRefactorKeyChanges)
}
func UpRefactorKeyChanges(tx *sql.Tx) error {
// start counting from the last max offset, else 0. We need to do a count(*) first to see if there // start counting from the last max offset, else 0. We need to do a count(*) first to see if there
// even are entries in this table to know if we can query for log_offset. Without the count then // even are entries in this table to know if we can query for log_offset. Without the count then
// the query to SELECT the max log offset fails on new Dendrite instances as log_offset doesn't // the query to SELECT the max log offset fails on new Dendrite instances as log_offset doesn't
// exist on that table. Even though we discard the error, the txn is tainted and gets aborted :/ // exist on that table. Even though we discard the error, the txn is tainted and gets aborted :/
var count int var count int
_ = tx.QueryRow(`SELECT count(*) FROM keyserver_key_changes`).Scan(&count) _ = tx.QueryRowContext(ctx, `SELECT count(*) FROM keyserver_key_changes`).Scan(&count)
if count > 0 { if count > 0 {
var maxOffset int64 var maxOffset int64
_ = tx.QueryRow(`SELECT coalesce(MAX(log_offset), 0) AS offset FROM keyserver_key_changes`).Scan(&maxOffset) _ = tx.QueryRowContext(ctx, `SELECT coalesce(MAX(log_offset), 0) AS offset FROM keyserver_key_changes`).Scan(&maxOffset)
if _, err := tx.Exec(fmt.Sprintf(`CREATE SEQUENCE IF NOT EXISTS keyserver_key_changes_seq START %d`, maxOffset)); err != nil { if _, err := tx.ExecContext(ctx, fmt.Sprintf(`CREATE SEQUENCE IF NOT EXISTS keyserver_key_changes_seq START %d`, maxOffset)); err != nil {
return fmt.Errorf("failed to CREATE SEQUENCE for key changes, starting at %d: %s", maxOffset, err) return fmt.Errorf("failed to CREATE SEQUENCE for key changes, starting at %d: %s", maxOffset, err)
} }
} }
_, err := tx.Exec(` _, err := tx.ExecContext(ctx, `
-- make the new table -- make the new table
DROP TABLE IF EXISTS keyserver_key_changes; DROP TABLE IF EXISTS keyserver_key_changes;
CREATE TABLE IF NOT EXISTS keyserver_key_changes ( CREATE TABLE IF NOT EXISTS keyserver_key_changes (
@ -60,8 +50,8 @@ func UpRefactorKeyChanges(tx *sql.Tx) error {
return nil return nil
} }
func DownRefactorKeyChanges(tx *sql.Tx) error { func DownRefactorKeyChanges(ctx context.Context, tx *sql.Tx) error {
_, err := tx.Exec(` _, err := tx.ExecContext(ctx, `
-- Drop all data and revert back, we can't keep the data as Kafka offsets determine the numbers -- Drop all data and revert back, we can't keep the data as Kafka offsets determine the numbers
DROP SEQUENCE IF EXISTS keyserver_key_changes_seq; DROP SEQUENCE IF EXISTS keyserver_key_changes_seq;
DROP TABLE IF EXISTS keyserver_key_changes; DROP TABLE IF EXISTS keyserver_key_changes;

View file

@ -15,18 +15,13 @@
package deltas package deltas
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
) )
func LoadFixCrossSigningSignatureIndexes(m *sqlutil.Migrations) { func UpFixCrossSigningSignatureIndexes(ctx context.Context, tx *sql.Tx) error {
m.AddMigration(UpFixCrossSigningSignatureIndexes, DownFixCrossSigningSignatureIndexes) _, err := tx.ExecContext(ctx, `
}
func UpFixCrossSigningSignatureIndexes(tx *sql.Tx) error {
_, err := tx.Exec(`
ALTER TABLE keyserver_cross_signing_sigs DROP CONSTRAINT keyserver_cross_signing_sigs_pkey; ALTER TABLE keyserver_cross_signing_sigs DROP CONSTRAINT keyserver_cross_signing_sigs_pkey;
ALTER TABLE keyserver_cross_signing_sigs ADD PRIMARY KEY (origin_user_id, origin_key_id, target_user_id, target_key_id); ALTER TABLE keyserver_cross_signing_sigs ADD PRIMARY KEY (origin_user_id, origin_key_id, target_user_id, target_key_id);
@ -38,8 +33,8 @@ func UpFixCrossSigningSignatureIndexes(tx *sql.Tx) error {
return nil return nil
} }
func DownFixCrossSigningSignatureIndexes(tx *sql.Tx) error { func DownFixCrossSigningSignatureIndexes(ctx context.Context, tx *sql.Tx) error {
_, err := tx.Exec(` _, err := tx.ExecContext(ctx, `
ALTER TABLE keyserver_cross_signing_sigs DROP CONSTRAINT keyserver_cross_signing_sigs_pkey; ALTER TABLE keyserver_cross_signing_sigs DROP CONSTRAINT keyserver_cross_signing_sigs_pkey;
ALTER TABLE keyserver_cross_signing_sigs ADD PRIMARY KEY (origin_user_id, target_user_id, target_key_id); ALTER TABLE keyserver_cross_signing_sigs ADD PRIMARY KEY (origin_user_id, target_user_id, target_key_id);

View file

@ -19,6 +19,8 @@ import (
"database/sql" "database/sql"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/storage/postgres/deltas"
"github.com/matrix-org/dendrite/keyserver/storage/tables" "github.com/matrix-org/dendrite/keyserver/storage/tables"
) )
@ -55,7 +57,23 @@ func NewPostgresKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) {
db: db, db: db,
} }
_, err := db.Exec(keyChangesSchema) _, err := db.Exec(keyChangesSchema)
return s, err if err != nil {
return s, err
}
// TODO: Remove when we are sure we are not having goose artefacts in the db
// This forces an error, which indicates the migration is already applied, since the
// column partition was removed from the table
err = db.QueryRow("SELECT partition FROM keyserver_key_changes LIMIT 1;").Scan()
if err == nil {
m := sqlutil.NewMigrator(db)
m.AddMigrations(sqlutil.Migration{
Version: "keyserver: refactor key changes",
Up: deltas.UpRefactorKeyChanges,
})
return s, m.Up(context.Background())
}
return s, nil
} }
func (s *keyChangesStatements) Prepare() (err error) { func (s *keyChangesStatements) Prepare() (err error) {

View file

@ -16,7 +16,6 @@ package postgres
import ( import (
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/storage/postgres/deltas"
"github.com/matrix-org/dendrite/keyserver/storage/shared" "github.com/matrix-org/dendrite/keyserver/storage/shared"
"github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
@ -53,12 +52,6 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions)
if err != nil { if err != nil {
return nil, err return nil, err
} }
m := sqlutil.NewMigrations()
deltas.LoadRefactorKeyChanges(m)
deltas.LoadFixCrossSigningSignatureIndexes(m)
if err = m.RunDeltas(db, dbProperties); err != nil {
return nil, err
}
if err = kc.Prepare(); err != nil { if err = kc.Prepare(); err != nil {
return nil, err return nil, err
} }

View file

@ -21,6 +21,7 @@ import (
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/storage/sqlite3/deltas"
"github.com/matrix-org/dendrite/keyserver/storage/tables" "github.com/matrix-org/dendrite/keyserver/storage/tables"
"github.com/matrix-org/dendrite/keyserver/types" "github.com/matrix-org/dendrite/keyserver/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -65,6 +66,15 @@ func NewSqliteCrossSigningSigsTable(db *sql.DB) (tables.CrossSigningSigs, error)
if err != nil { if err != nil {
return nil, err return nil, err
} }
m := sqlutil.NewMigrator(db)
m.AddMigrations(sqlutil.Migration{
Version: "keyserver: cross signing signature indexes",
Up: deltas.UpFixCrossSigningSignatureIndexes,
})
if err = m.Up(context.Background()); err != nil {
return nil, err
}
return s, sqlutil.StatementList{ return s, sqlutil.StatementList{
{&s.selectCrossSigningSigsForTargetStmt, selectCrossSigningSigsForTargetSQL}, {&s.selectCrossSigningSigsForTargetStmt, selectCrossSigningSigsForTargetSQL},
{&s.upsertCrossSigningSigsForTargetStmt, upsertCrossSigningSigsForTargetSQL}, {&s.upsertCrossSigningSigsForTargetStmt, upsertCrossSigningSigsForTargetSQL},

View file

@ -15,28 +15,18 @@
package deltas package deltas
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/pressly/goose"
) )
func LoadFromGoose() { func UpRefactorKeyChanges(ctx context.Context, tx *sql.Tx) error {
goose.AddMigration(UpRefactorKeyChanges, DownRefactorKeyChanges)
}
func LoadRefactorKeyChanges(m *sqlutil.Migrations) {
m.AddMigration(UpRefactorKeyChanges, DownRefactorKeyChanges)
}
func UpRefactorKeyChanges(tx *sql.Tx) error {
// start counting from the last max offset, else 0. // start counting from the last max offset, else 0.
var maxOffset int64 var maxOffset int64
var userID string var userID string
_ = tx.QueryRow(`SELECT user_id, MAX(log_offset) FROM keyserver_key_changes GROUP BY user_id`).Scan(&userID, &maxOffset) _ = tx.QueryRowContext(ctx, `SELECT user_id, MAX(log_offset) FROM keyserver_key_changes GROUP BY user_id`).Scan(&userID, &maxOffset)
_, err := tx.Exec(` _, err := tx.ExecContext(ctx, `
-- make the new table -- make the new table
DROP TABLE IF EXISTS keyserver_key_changes; DROP TABLE IF EXISTS keyserver_key_changes;
CREATE TABLE IF NOT EXISTS keyserver_key_changes ( CREATE TABLE IF NOT EXISTS keyserver_key_changes (
@ -51,14 +41,14 @@ func UpRefactorKeyChanges(tx *sql.Tx) error {
} }
// to start counting from maxOffset, insert a row with that value // to start counting from maxOffset, insert a row with that value
if userID != "" { if userID != "" {
_, err = tx.Exec(`INSERT INTO keyserver_key_changes(change_id, user_id) VALUES($1, $2)`, maxOffset, userID) _, err = tx.ExecContext(ctx, `INSERT INTO keyserver_key_changes(change_id, user_id) VALUES($1, $2)`, maxOffset, userID)
return err return err
} }
return nil return nil
} }
func DownRefactorKeyChanges(tx *sql.Tx) error { func DownRefactorKeyChanges(ctx context.Context, tx *sql.Tx) error {
_, err := tx.Exec(` _, err := tx.ExecContext(ctx, `
-- Drop all data and revert back, we can't keep the data as Kafka offsets determine the numbers -- Drop all data and revert back, we can't keep the data as Kafka offsets determine the numbers
DROP TABLE IF EXISTS keyserver_key_changes; DROP TABLE IF EXISTS keyserver_key_changes;
CREATE TABLE IF NOT EXISTS keyserver_key_changes ( CREATE TABLE IF NOT EXISTS keyserver_key_changes (

View file

@ -15,18 +15,13 @@
package deltas package deltas
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
) )
func LoadFixCrossSigningSignatureIndexes(m *sqlutil.Migrations) { func UpFixCrossSigningSignatureIndexes(ctx context.Context, tx *sql.Tx) error {
m.AddMigration(UpFixCrossSigningSignatureIndexes, DownFixCrossSigningSignatureIndexes) _, err := tx.ExecContext(ctx, `
}
func UpFixCrossSigningSignatureIndexes(tx *sql.Tx) error {
_, err := tx.Exec(`
CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs_tmp ( CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs_tmp (
origin_user_id TEXT NOT NULL, origin_user_id TEXT NOT NULL,
origin_key_id TEXT NOT NULL, origin_key_id TEXT NOT NULL,
@ -50,8 +45,8 @@ func UpFixCrossSigningSignatureIndexes(tx *sql.Tx) error {
return nil return nil
} }
func DownFixCrossSigningSignatureIndexes(tx *sql.Tx) error { func DownFixCrossSigningSignatureIndexes(ctx context.Context, tx *sql.Tx) error {
_, err := tx.Exec(` _, err := tx.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs_tmp ( CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs_tmp (
origin_user_id TEXT NOT NULL, origin_user_id TEXT NOT NULL,
origin_key_id TEXT NOT NULL, origin_key_id TEXT NOT NULL,

View file

@ -19,6 +19,8 @@ import (
"database/sql" "database/sql"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/storage/sqlite3/deltas"
"github.com/matrix-org/dendrite/keyserver/storage/tables" "github.com/matrix-org/dendrite/keyserver/storage/tables"
) )
@ -53,7 +55,23 @@ func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) {
db: db, db: db,
} }
_, err := db.Exec(keyChangesSchema) _, err := db.Exec(keyChangesSchema)
return s, err if err != nil {
return s, err
}
// TODO: Remove when we are sure we are not having goose artefacts in the db
// This forces an error, which indicates the migration is already applied, since the
// column partition was removed from the table
err = db.QueryRow("SELECT partition FROM keyserver_key_changes LIMIT 1;").Scan()
if err == nil {
m := sqlutil.NewMigrator(db)
m.AddMigrations(sqlutil.Migration{
Version: "keyserver: refactor key changes",
Up: deltas.UpRefactorKeyChanges,
})
return s, m.Up(context.Background())
}
return s, nil
} }
func (s *keyChangesStatements) Prepare() (err error) { func (s *keyChangesStatements) Prepare() (err error) {

View file

@ -17,7 +17,6 @@ package sqlite3
import ( import (
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/storage/shared" "github.com/matrix-org/dendrite/keyserver/storage/shared"
"github.com/matrix-org/dendrite/keyserver/storage/sqlite3/deltas"
"github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
) )
@ -52,12 +51,6 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions)
return nil, err return nil, err
} }
m := sqlutil.NewMigrations()
deltas.LoadRefactorKeyChanges(m)
deltas.LoadFixCrossSigningSignatureIndexes(m)
if err = m.RunDeltas(db, dbProperties); err != nil {
return nil, err
}
if err = kc.Prepare(); err != nil { if err = kc.Prepare(); err != nil {
return nil, err return nil, err
} }

View file

@ -12,6 +12,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/state"
"github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
@ -21,14 +22,14 @@ import (
// Move these to a more sensible place. // Move these to a more sensible place.
func UpdateToInviteMembership( func UpdateToInviteMembership(
mu *shared.MembershipUpdater, add *gomatrixserverlib.Event, updates []api.OutputEvent, mu *shared.MembershipUpdater, add *types.Event, updates []api.OutputEvent,
roomVersion gomatrixserverlib.RoomVersion, roomVersion gomatrixserverlib.RoomVersion,
) ([]api.OutputEvent, error) { ) ([]api.OutputEvent, error) {
// We may have already sent the invite to the user, either because we are // We may have already sent the invite to the user, either because we are
// reprocessing this event, or because the we received this invite from a // reprocessing this event, or because the we received this invite from a
// remote server via the federation invite API. In those cases we don't need // remote server via the federation invite API. In those cases we don't need
// to send the event. // to send the event.
needsSending, err := mu.SetToInvite(add) needsSending, retired, err := mu.Update(tables.MembershipStateInvite, add)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -38,13 +39,23 @@ func UpdateToInviteMembership(
// room event stream. This ensures that the consumers only have to // room event stream. This ensures that the consumers only have to
// consider a single stream of events when determining whether a user // consider a single stream of events when determining whether a user
// is invited, rather than having to combine multiple streams themselves. // is invited, rather than having to combine multiple streams themselves.
onie := api.OutputNewInviteEvent{
Event: add.Headered(roomVersion),
RoomVersion: roomVersion,
}
updates = append(updates, api.OutputEvent{ updates = append(updates, api.OutputEvent{
Type: api.OutputTypeNewInviteEvent, Type: api.OutputTypeNewInviteEvent,
NewInviteEvent: &onie, NewInviteEvent: &api.OutputNewInviteEvent{
Event: add.Headered(roomVersion),
RoomVersion: roomVersion,
},
})
}
for _, eventID := range retired {
updates = append(updates, api.OutputEvent{
Type: api.OutputTypeRetireInviteEvent,
RetireInviteEvent: &api.OutputRetireInviteEvent{
EventID: eventID,
Membership: gomatrixserverlib.Join,
RetiredByEventID: add.EventID(),
TargetUserID: *add.StateKey(),
},
}) })
} }
return updates, nil return updates, nil

View file

@ -21,6 +21,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/internal/helpers" "github.com/matrix-org/dendrite/roomserver/internal/helpers"
"github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/opentracing/opentracing-go" "github.com/opentracing/opentracing-go"
@ -60,20 +61,14 @@ func (r *Inputer) updateMemberships(
var updates []api.OutputEvent var updates []api.OutputEvent
for _, change := range changes { for _, change := range changes {
var ae *gomatrixserverlib.Event var ae *types.Event
var re *gomatrixserverlib.Event var re *types.Event
targetUserNID := change.EventStateKeyNID targetUserNID := change.EventStateKeyNID
if change.removedEventNID != 0 { if change.removedEventNID != 0 {
ev, _ := helpers.EventMap(events).Lookup(change.removedEventNID) re, _ = helpers.EventMap(events).Lookup(change.removedEventNID)
if ev != nil {
re = ev.Event
}
} }
if change.addedEventNID != 0 { if change.addedEventNID != 0 {
ev, _ := helpers.EventMap(events).Lookup(change.addedEventNID) ae, _ = helpers.EventMap(events).Lookup(change.addedEventNID)
if ev != nil {
ae = ev.Event
}
} }
if updates, err = r.updateMembership(updater, targetUserNID, re, ae, updates); err != nil { if updates, err = r.updateMembership(updater, targetUserNID, re, ae, updates); err != nil {
return nil, err return nil, err
@ -85,30 +80,27 @@ func (r *Inputer) updateMemberships(
func (r *Inputer) updateMembership( func (r *Inputer) updateMembership(
updater *shared.RoomUpdater, updater *shared.RoomUpdater,
targetUserNID types.EventStateKeyNID, targetUserNID types.EventStateKeyNID,
remove, add *gomatrixserverlib.Event, remove, add *types.Event,
updates []api.OutputEvent, updates []api.OutputEvent,
) ([]api.OutputEvent, error) { ) ([]api.OutputEvent, error) {
var err error var err error
// Default the membership to Leave if no event was added or removed. // Default the membership to Leave if no event was added or removed.
oldMembership := gomatrixserverlib.Leave
newMembership := gomatrixserverlib.Leave newMembership := gomatrixserverlib.Leave
if remove != nil {
oldMembership, err = remove.Membership()
if err != nil {
return nil, err
}
}
if add != nil { if add != nil {
newMembership, err = add.Membership() newMembership, err = add.Membership()
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
if oldMembership == newMembership && newMembership != gomatrixserverlib.Join {
// If the membership is the same then nothing changed and we can return var targetLocal bool
// immediately, unless it's a Join update (e.g. profile update). if add != nil {
return updates, nil targetLocal = r.isLocalTarget(add)
}
mu, err := updater.MembershipUpdater(targetUserNID, targetLocal)
if err != nil {
return nil, err
} }
// In an ideal world, we shouldn't ever have "add" be nil and "remove" be // In an ideal world, we shouldn't ever have "add" be nil and "remove" be
@ -120,17 +112,10 @@ func (r *Inputer) updateMembership(
// after a state reset, often thinking that the user was still joined to // after a state reset, often thinking that the user was still joined to
// the room even though the room state said otherwise, and this would prevent // the room even though the room state said otherwise, and this would prevent
// the user from being able to attempt to rejoin the room without modifying // the user from being able to attempt to rejoin the room without modifying
// the database. So instead what we'll do is we'll just update the membership // the database. So instead we're going to remove the membership from the
// table to say that the user is "leave" and we'll use the old event to // database altogether, so that it doesn't create future problems.
// avoid nil pointer exceptions on the code path that follows. if add == nil && remove != nil {
if add == nil { return nil, mu.Delete()
add = remove
newMembership = gomatrixserverlib.Leave
}
mu, err := updater.MembershipUpdater(targetUserNID, r.isLocalTarget(add))
if err != nil {
return nil, err
} }
switch newMembership { switch newMembership {
@ -149,7 +134,7 @@ func (r *Inputer) updateMembership(
} }
} }
func (r *Inputer) isLocalTarget(event *gomatrixserverlib.Event) bool { func (r *Inputer) isLocalTarget(event *types.Event) bool {
isTargetLocalUser := false isTargetLocalUser := false
if statekey := event.StateKey(); statekey != nil { if statekey := event.StateKey(); statekey != nil {
_, domain, _ := gomatrixserverlib.SplitID('@', *statekey) _, domain, _ := gomatrixserverlib.SplitID('@', *statekey)
@ -159,81 +144,61 @@ func (r *Inputer) isLocalTarget(event *gomatrixserverlib.Event) bool {
} }
func updateToJoinMembership( func updateToJoinMembership(
mu *shared.MembershipUpdater, add *gomatrixserverlib.Event, updates []api.OutputEvent, mu *shared.MembershipUpdater, add *types.Event, updates []api.OutputEvent,
) ([]api.OutputEvent, error) { ) ([]api.OutputEvent, error) {
// If the user is already marked as being joined, we call SetToJoin to update
// the event ID then we can return immediately. Retired is ignored as there
// is no invite event to retire.
if mu.IsJoin() {
_, err := mu.SetToJoin(add.Sender(), add.EventID(), true)
if err != nil {
return nil, err
}
return updates, nil
}
// When we mark a user as being joined we will invalidate any invites that // When we mark a user as being joined we will invalidate any invites that
// are active for that user. We notify the consumers that the invites have // are active for that user. We notify the consumers that the invites have
// been retired using a special event, even though they could infer this // been retired using a special event, even though they could infer this
// by studying the state changes in the room event stream. // by studying the state changes in the room event stream.
retired, err := mu.SetToJoin(add.Sender(), add.EventID(), false) _, retired, err := mu.Update(tables.MembershipStateJoin, add)
if err != nil { if err != nil {
return nil, err return nil, err
} }
for _, eventID := range retired { for _, eventID := range retired {
orie := api.OutputRetireInviteEvent{
EventID: eventID,
Membership: gomatrixserverlib.Join,
RetiredByEventID: add.EventID(),
TargetUserID: *add.StateKey(),
}
updates = append(updates, api.OutputEvent{ updates = append(updates, api.OutputEvent{
Type: api.OutputTypeRetireInviteEvent, Type: api.OutputTypeRetireInviteEvent,
RetireInviteEvent: &orie, RetireInviteEvent: &api.OutputRetireInviteEvent{
EventID: eventID,
Membership: gomatrixserverlib.Join,
RetiredByEventID: add.EventID(),
TargetUserID: *add.StateKey(),
},
}) })
} }
return updates, nil return updates, nil
} }
func updateToLeaveMembership( func updateToLeaveMembership(
mu *shared.MembershipUpdater, add *gomatrixserverlib.Event, mu *shared.MembershipUpdater, add *types.Event,
newMembership string, updates []api.OutputEvent, newMembership string, updates []api.OutputEvent,
) ([]api.OutputEvent, error) { ) ([]api.OutputEvent, error) {
// If the user is already neither joined, nor invited to the room then we
// can return immediately.
if mu.IsLeave() {
return updates, nil
}
// When we mark a user as having left we will invalidate any invites that // When we mark a user as having left we will invalidate any invites that
// are active for that user. We notify the consumers that the invites have // are active for that user. We notify the consumers that the invites have
// been retired using a special event, even though they could infer this // been retired using a special event, even though they could infer this
// by studying the state changes in the room event stream. // by studying the state changes in the room event stream.
retired, err := mu.SetToLeave(add.Sender(), add.EventID()) _, retired, err := mu.Update(tables.MembershipStateLeaveOrBan, add)
if err != nil { if err != nil {
return nil, err return nil, err
} }
for _, eventID := range retired { for _, eventID := range retired {
orie := api.OutputRetireInviteEvent{
EventID: eventID,
Membership: newMembership,
RetiredByEventID: add.EventID(),
TargetUserID: *add.StateKey(),
}
updates = append(updates, api.OutputEvent{ updates = append(updates, api.OutputEvent{
Type: api.OutputTypeRetireInviteEvent, Type: api.OutputTypeRetireInviteEvent,
RetireInviteEvent: &orie, RetireInviteEvent: &api.OutputRetireInviteEvent{
EventID: eventID,
Membership: newMembership,
RetiredByEventID: add.EventID(),
TargetUserID: *add.StateKey(),
},
}) })
} }
return updates, nil return updates, nil
} }
func updateToKnockMembership( func updateToKnockMembership(
mu *shared.MembershipUpdater, add *gomatrixserverlib.Event, updates []api.OutputEvent, mu *shared.MembershipUpdater, add *types.Event, updates []api.OutputEvent,
) ([]api.OutputEvent, error) { ) ([]api.OutputEvent, error) {
if mu.IsLeave() { if _, _, err := mu.Update(tables.MembershipStateKnock, add); err != nil {
_, err := mu.SetToKnock(add) return nil, err
if err != nil {
return nil, err
}
} }
return updates, nil return updates, nil
} }

View file

@ -39,11 +39,13 @@ type Inviter struct {
Inputer *input.Inputer Inputer *input.Inputer
} }
// nolint:gocyclo
func (r *Inviter) PerformInvite( func (r *Inviter) PerformInvite(
ctx context.Context, ctx context.Context,
req *api.PerformInviteRequest, req *api.PerformInviteRequest,
res *api.PerformInviteResponse, res *api.PerformInviteResponse,
) ([]api.OutputEvent, error) { ) ([]api.OutputEvent, error) {
var outputUpdates []api.OutputEvent
event := req.Event event := req.Event
if event.StateKey() == nil { if event.StateKey() == nil {
return nil, fmt.Errorf("invite must be a state event") return nil, fmt.Errorf("invite must be a state event")
@ -66,6 +68,13 @@ func (r *Inviter) PerformInvite(
} }
isTargetLocal := domain == r.Cfg.Matrix.ServerName isTargetLocal := domain == r.Cfg.Matrix.ServerName
isOriginLocal := event.Origin() == r.Cfg.Matrix.ServerName isOriginLocal := event.Origin() == r.Cfg.Matrix.ServerName
if !isOriginLocal && !isTargetLocal {
res.Error = &api.PerformError{
Code: api.PerformErrorBadRequest,
Msg: "The invite must be either from or to a local user",
}
return nil, nil
}
logger := util.GetLogger(ctx).WithFields(map[string]interface{}{ logger := util.GetLogger(ctx).WithFields(map[string]interface{}{
"inviter": event.Sender(), "inviter": event.Sender(),
@ -97,6 +106,34 @@ func (r *Inviter) PerformInvite(
} }
} }
updateMembershipTableManually := func() ([]api.OutputEvent, error) {
var updater *shared.MembershipUpdater
if updater, err = r.DB.MembershipUpdater(ctx, roomID, targetUserID, isTargetLocal, req.RoomVersion); err != nil {
return nil, fmt.Errorf("r.DB.MembershipUpdater: %w", err)
}
outputUpdates, err = helpers.UpdateToInviteMembership(updater, &types.Event{
EventNID: 0,
Event: event.Unwrap(),
}, outputUpdates, req.Event.RoomVersion)
if err != nil {
return nil, fmt.Errorf("updateToInviteMembership: %w", err)
}
if err = updater.Commit(); err != nil {
return nil, fmt.Errorf("updater.Commit: %w", err)
}
logger.Debugf("updated membership to invite and sending invite OutputEvent")
return outputUpdates, nil
}
if (info == nil || info.IsStub) && !isOriginLocal && isTargetLocal {
// The invite came in over federation for a room that we don't know about
// yet. We need to handle this a bit differently to most invites because
// we don't know the room state, therefore the roomserver can't process
// an input event. Instead we will update the membership table with the
// new invite and generate an output event.
return updateMembershipTableManually()
}
var isAlreadyJoined bool var isAlreadyJoined bool
if info != nil { if info != nil {
_, isAlreadyJoined, _, err = r.DB.GetMembership(ctx, info.RoomNID, *event.StateKey()) _, isAlreadyJoined, _, err = r.DB.GetMembership(ctx, info.RoomNID, *event.StateKey())
@ -140,31 +177,13 @@ func (r *Inviter) PerformInvite(
return nil, nil return nil, nil
} }
// If the invite originated remotely then we can't send an
// InputRoomEvent for the invite as it will never pass auth checks
// due to lacking room state, but we still need to tell the client
// about the invite so we can accept it, hence we return an output
// event to send to the Sync API.
if !isOriginLocal { if !isOriginLocal {
// The invite originated over federation. Process the membership return updateMembershipTableManually()
// update, which will notify the sync API etc about the incoming
// invite. We do NOT send an InputRoomEvent for the invite as it
// will never pass auth checks due to lacking room state, but we
// still need to tell the client about the invite so we can accept
// it, hence we return an output event to send to the sync api.
var updater *shared.MembershipUpdater
updater, err = r.DB.MembershipUpdater(ctx, roomID, targetUserID, isTargetLocal, req.RoomVersion)
if err != nil {
return nil, fmt.Errorf("r.DB.MembershipUpdater: %w", err)
}
unwrapped := event.Unwrap()
var outputUpdates []api.OutputEvent
outputUpdates, err = helpers.UpdateToInviteMembership(updater, unwrapped, nil, req.Event.RoomVersion)
if err != nil {
return nil, fmt.Errorf("updateToInviteMembership: %w", err)
}
if err = updater.Commit(); err != nil {
return nil, fmt.Errorf("updater.Commit: %w", err)
}
logger.Debugf("updated membership to invite and sending invite OutputEvent")
return outputUpdates, nil
} }
// The invite originated locally. Therefore we have a responsibility to // The invite originated locally. Therefore we have a responsibility to
@ -229,12 +248,11 @@ func (r *Inviter) PerformInvite(
Code: api.PerformErrorNotAllowed, Code: api.PerformErrorNotAllowed,
} }
logger.WithError(err).WithField("event_id", event.EventID()).Error("r.InputRoomEvents failed") logger.WithError(err).WithField("event_id", event.EventID()).Error("r.InputRoomEvents failed")
return nil, nil
} }
// Don't notify the sync api of this event in the same way as a federated invite so the invitee // Don't notify the sync api of this event in the same way as a federated invite so the invitee
// gets the invite, as the roomserver will do this when it processes the m.room.member invite. // gets the invite, as the roomserver will do this when it processes the m.room.member invite.
return nil, nil return outputUpdates, nil
} }
func buildInviteStrippedState( func buildInviteStrippedState(

View file

@ -268,21 +268,19 @@ func (r *Joiner) performJoinRoomByID(
case nil: case nil:
// The room join is local. Send the new join event into the // The room join is local. Send the new join event into the
// roomserver. First of all check that the user isn't already // roomserver. First of all check that the user isn't already
// a member of the room. // a member of the room. This is best-effort (as in we won't
alreadyJoined := false // fail if we can't find the existing membership) because there
for _, se := range buildRes.StateEvents { // is really no harm in just sending another membership event.
if !se.StateKeyEquals(userID) { membershipReq := &api.QueryMembershipForUserRequest{
continue RoomID: req.RoomIDOrAlias,
} UserID: userID,
if membership, merr := se.Membership(); merr == nil {
alreadyJoined = (membership == gomatrixserverlib.Join)
break
}
} }
membershipRes := &api.QueryMembershipForUserResponse{}
_ = r.Queryer.QueryMembershipForUser(ctx, membershipReq, membershipRes)
// If we haven't already joined the room then send an event // If we haven't already joined the room then send an event
// into the room changing our membership status. // into the room changing our membership status.
if !alreadyJoined { if !membershipRes.RoomExists || !membershipRes.IsInRoom {
inputReq := rsAPI.InputRoomEventsRequest{ inputReq := rsAPI.InputRoomEventsRequest{
InputRoomEvents: []rsAPI.InputRoomEvent{ InputRoomEvents: []rsAPI.InputRoomEvent{
{ {

View file

@ -228,14 +228,14 @@ func (r *Leaver) performFederatedRejectInvite(
util.GetLogger(ctx).WithError(err).Errorf("failed to get MembershipUpdater, still retiring invite event") util.GetLogger(ctx).WithError(err).Errorf("failed to get MembershipUpdater, still retiring invite event")
} }
if updater != nil { if updater != nil {
if _, err = updater.SetToLeave(req.UserID, eventID); err != nil { if err = updater.Delete(); err != nil {
util.GetLogger(ctx).WithError(err).Errorf("failed to set membership to leave, still retiring invite event") util.GetLogger(ctx).WithError(err).Errorf("failed to delete membership, still retiring invite event")
if err = updater.Rollback(); err != nil { if err = updater.Rollback(); err != nil {
util.GetLogger(ctx).WithError(err).Errorf("failed to rollback membership leave, still retiring invite event") util.GetLogger(ctx).WithError(err).Errorf("failed to rollback deleting membership, still retiring invite event")
} }
} else { } else {
if err = updater.Commit(); err != nil { if err = updater.Commit(); err != nil {
util.GetLogger(ctx).WithError(err).Errorf("failed to commit membership update, still retiring invite event") util.GetLogger(ctx).WithError(err).Errorf("failed to commit deleting membership, still retiring invite event")
} }
} }
} }

View file

@ -16,6 +16,7 @@ package query
import ( import (
"context" "context"
"database/sql"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -225,6 +226,9 @@ func (r *Queryer) QueryMembershipsForRoom(
var eventNIDs []types.EventNID var eventNIDs []types.EventNID
eventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, request.JoinedOnly, request.LocalOnly) eventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, request.JoinedOnly, request.LocalOnly)
if err != nil { if err != nil {
if err == sql.ErrNoRows {
return nil
}
return fmt.Errorf("r.DB.GetMembershipEventNIDsForRoom: %w", err) return fmt.Errorf("r.DB.GetMembershipEventNIDsForRoom: %w", err)
} }
events, err = r.DB.Events(ctx, eventNIDs) events, err = r.DB.Events(ctx, eventNIDs)
@ -260,6 +264,9 @@ func (r *Queryer) QueryMembershipsForRoom(
var eventNIDs []types.EventNID var eventNIDs []types.EventNID
eventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, request.JoinedOnly, false) eventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, request.JoinedOnly, false)
if err != nil { if err != nil {
if err == sql.ErrNoRows {
return nil
}
return err return err
} }

View file

@ -15,32 +15,21 @@
package deltas package deltas
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/pressly/goose"
) )
func LoadFromGoose() { func UpAddForgottenColumn(ctx context.Context, tx *sql.Tx) error {
goose.AddMigration(UpAddForgottenColumn, DownAddForgottenColumn) _, err := tx.ExecContext(ctx, `ALTER TABLE roomserver_membership ADD COLUMN IF NOT EXISTS forgotten BOOLEAN NOT NULL DEFAULT false;`)
goose.AddMigration(UpStateBlocksRefactor, DownStateBlocksRefactor)
}
func LoadAddForgottenColumn(m *sqlutil.Migrations) {
m.AddMigration(UpAddForgottenColumn, DownAddForgottenColumn)
}
func UpAddForgottenColumn(tx *sql.Tx) error {
_, err := tx.Exec(`ALTER TABLE roomserver_membership ADD COLUMN IF NOT EXISTS forgotten BOOLEAN NOT NULL DEFAULT false;`)
if err != nil { if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err) return fmt.Errorf("failed to execute upgrade: %w", err)
} }
return nil return nil
} }
func DownAddForgottenColumn(tx *sql.Tx) error { func DownAddForgottenColumn(ctx context.Context, tx *sql.Tx) error {
_, err := tx.Exec(`ALTER TABLE roomserver_membership DROP COLUMN IF EXISTS forgotten;`) _, err := tx.ExecContext(ctx, `ALTER TABLE roomserver_membership DROP COLUMN IF EXISTS forgotten;`)
if err != nil { if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err) return fmt.Errorf("failed to execute downgrade: %w", err)
} }

View file

@ -15,11 +15,11 @@
package deltas package deltas
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -36,48 +36,44 @@ type stateBlockData struct {
EventNIDs types.EventNIDs EventNIDs types.EventNIDs
} }
func LoadStateBlocksRefactor(m *sqlutil.Migrations) {
m.AddMigration(UpStateBlocksRefactor, DownStateBlocksRefactor)
}
// nolint:gocyclo // nolint:gocyclo
func UpStateBlocksRefactor(tx *sql.Tx) error { func UpStateBlocksRefactor(ctx context.Context, tx *sql.Tx) error {
logrus.Warn("Performing state storage upgrade. Please wait, this may take some time!") logrus.Warn("Performing state storage upgrade. Please wait, this may take some time!")
defer logrus.Warn("State storage upgrade complete") defer logrus.Warn("State storage upgrade complete")
var snapshotcount int var snapshotcount int
var maxsnapshotid int var maxsnapshotid int
var maxblockid int var maxblockid int
if err := tx.QueryRow(`SELECT COUNT(DISTINCT state_snapshot_nid) FROM roomserver_state_snapshots;`).Scan(&snapshotcount); err != nil { if err := tx.QueryRowContext(ctx, `SELECT COUNT(DISTINCT state_snapshot_nid) FROM roomserver_state_snapshots;`).Scan(&snapshotcount); err != nil {
return fmt.Errorf("tx.QueryRow.Scan (count snapshots): %w", err) return fmt.Errorf("tx.QueryRowContext.Scan (count snapshots): %w", err)
} }
if err := tx.QueryRow(`SELECT COALESCE(MAX(state_snapshot_nid),0) FROM roomserver_state_snapshots;`).Scan(&maxsnapshotid); err != nil { if err := tx.QueryRowContext(ctx, `SELECT COALESCE(MAX(state_snapshot_nid),0) FROM roomserver_state_snapshots;`).Scan(&maxsnapshotid); err != nil {
return fmt.Errorf("tx.QueryRow.Scan (count snapshots): %w", err) return fmt.Errorf("tx.QueryRowContext.Scan (count snapshots): %w", err)
} }
if err := tx.QueryRow(`SELECT COALESCE(MAX(state_block_nid),0) FROM roomserver_state_block;`).Scan(&maxblockid); err != nil { if err := tx.QueryRowContext(ctx, `SELECT COALESCE(MAX(state_block_nid),0) FROM roomserver_state_block;`).Scan(&maxblockid); err != nil {
return fmt.Errorf("tx.QueryRow.Scan (count snapshots): %w", err) return fmt.Errorf("tx.QueryRowContext.Scan (count snapshots): %w", err)
} }
maxsnapshotid++ maxsnapshotid++
maxblockid++ maxblockid++
if _, err := tx.Exec(`ALTER TABLE roomserver_state_block RENAME TO _roomserver_state_block;`); err != nil { if _, err := tx.ExecContext(ctx, `ALTER TABLE roomserver_state_block RENAME TO _roomserver_state_block;`); err != nil {
return fmt.Errorf("tx.Exec: %w", err) return fmt.Errorf("tx.ExecContext: %w", err)
} }
if _, err := tx.Exec(`ALTER TABLE roomserver_state_snapshots RENAME TO _roomserver_state_snapshots;`); err != nil { if _, err := tx.ExecContext(ctx, `ALTER TABLE roomserver_state_snapshots RENAME TO _roomserver_state_snapshots;`); err != nil {
return fmt.Errorf("tx.Exec: %w", err) return fmt.Errorf("tx.ExecContext: %w", err)
} }
// We create new sequences starting with the maximum state snapshot and block NIDs. // We create new sequences starting with the maximum state snapshot and block NIDs.
// This means that all newly created snapshots and blocks by the migration will have // This means that all newly created snapshots and blocks by the migration will have
// NIDs higher than these values, so that when we come to update the references to // NIDs higher than these values, so that when we come to update the references to
// these NIDs using UPDATE statements, we can guarantee we are only ever updating old // these NIDs using UPDATE statements, we can guarantee we are only ever updating old
// values and not accidentally overwriting new ones. // values and not accidentally overwriting new ones.
if _, err := tx.Exec(fmt.Sprintf(`CREATE SEQUENCE roomserver_state_block_nid_sequence START WITH %d;`, maxblockid)); err != nil { if _, err := tx.ExecContext(ctx, fmt.Sprintf(`CREATE SEQUENCE roomserver_state_block_nid_sequence START WITH %d;`, maxblockid)); err != nil {
return fmt.Errorf("tx.Exec: %w", err) return fmt.Errorf("tx.ExecContext: %w", err)
} }
if _, err := tx.Exec(fmt.Sprintf(`CREATE SEQUENCE roomserver_state_snapshot_nid_sequence START WITH %d;`, maxsnapshotid)); err != nil { if _, err := tx.ExecContext(ctx, fmt.Sprintf(`CREATE SEQUENCE roomserver_state_snapshot_nid_sequence START WITH %d;`, maxsnapshotid)); err != nil {
return fmt.Errorf("tx.Exec: %w", err) return fmt.Errorf("tx.ExecContext: %w", err)
} }
_, err := tx.Exec(` _, err := tx.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS roomserver_state_block ( CREATE TABLE IF NOT EXISTS roomserver_state_block (
state_block_nid bigint PRIMARY KEY DEFAULT nextval('roomserver_state_block_nid_sequence'), state_block_nid bigint PRIMARY KEY DEFAULT nextval('roomserver_state_block_nid_sequence'),
state_block_hash BYTEA UNIQUE, state_block_hash BYTEA UNIQUE,
@ -87,7 +83,7 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
if err != nil { if err != nil {
return fmt.Errorf("tx.Exec (create blocks table): %w", err) return fmt.Errorf("tx.Exec (create blocks table): %w", err)
} }
_, err = tx.Exec(` _, err = tx.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS roomserver_state_snapshots ( CREATE TABLE IF NOT EXISTS roomserver_state_snapshots (
state_snapshot_nid bigint PRIMARY KEY DEFAULT nextval('roomserver_state_snapshot_nid_sequence'), state_snapshot_nid bigint PRIMARY KEY DEFAULT nextval('roomserver_state_snapshot_nid_sequence'),
state_snapshot_hash BYTEA UNIQUE, state_snapshot_hash BYTEA UNIQUE,
@ -104,7 +100,7 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
// in question a state snapshot NID of 0 to indicate 'no snapshot'. // in question a state snapshot NID of 0 to indicate 'no snapshot'.
// If we don't do this, we'll fail the assertions later on which try to ensure we didn't forget // If we don't do this, we'll fail the assertions later on which try to ensure we didn't forget
// any snapshots. // any snapshots.
_, err = tx.Exec( _, err = tx.ExecContext(ctx,
`UPDATE roomserver_events SET state_snapshot_nid = 0 WHERE event_type_nid = $1 AND event_state_key_nid = $2`, `UPDATE roomserver_events SET state_snapshot_nid = 0 WHERE event_type_nid = $1 AND event_state_key_nid = $2`,
types.MRoomCreateNID, types.EmptyStateKeyNID, types.MRoomCreateNID, types.EmptyStateKeyNID,
) )
@ -115,7 +111,7 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
batchsize := 100 batchsize := 100
for batchoffset := 0; batchoffset < snapshotcount; batchoffset += batchsize { for batchoffset := 0; batchoffset < snapshotcount; batchoffset += batchsize {
var snapshotrows *sql.Rows var snapshotrows *sql.Rows
snapshotrows, err = tx.Query(` snapshotrows, err = tx.QueryContext(ctx, `
SELECT SELECT
state_snapshot_nid, state_snapshot_nid,
room_nid, room_nid,
@ -146,7 +142,7 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
state_block_nid; state_block_nid;
`, batchsize, batchoffset) `, batchsize, batchoffset)
if err != nil { if err != nil {
return fmt.Errorf("tx.Query: %w", err) return fmt.Errorf("tx.QueryContext: %w", err)
} }
logrus.Warnf("Rewriting snapshots %d-%d of %d...", batchoffset, batchoffset+batchsize, snapshotcount) logrus.Warnf("Rewriting snapshots %d-%d of %d...", batchoffset, batchoffset+batchsize, snapshotcount)
@ -183,7 +179,7 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
// fill in bad create snapshots // fill in bad create snapshots
for _, s := range badCreateSnapshots { for _, s := range badCreateSnapshots {
var createEventNID types.EventNID var createEventNID types.EventNID
err = tx.QueryRow( err = tx.QueryRowContext(ctx,
`SELECT event_nid FROM roomserver_events WHERE state_snapshot_nid = $1 AND event_type_nid = 1`, s.StateSnapshotNID, `SELECT event_nid FROM roomserver_events WHERE state_snapshot_nid = $1 AND event_type_nid = 1`, s.StateSnapshotNID,
).Scan(&createEventNID) ).Scan(&createEventNID)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
@ -208,7 +204,7 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
} }
var blocknid types.StateBlockNID var blocknid types.StateBlockNID
err = tx.QueryRow(` err = tx.QueryRowContext(ctx, `
INSERT INTO roomserver_state_block (state_block_hash, event_nids) INSERT INTO roomserver_state_block (state_block_hash, event_nids)
VALUES ($1, $2) VALUES ($1, $2)
ON CONFLICT (state_block_hash) DO UPDATE SET event_nids=$2 ON CONFLICT (state_block_hash) DO UPDATE SET event_nids=$2
@ -227,7 +223,7 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
} }
var newNID types.StateSnapshotNID var newNID types.StateSnapshotNID
err = tx.QueryRow(` err = tx.QueryRowContext(ctx, `
INSERT INTO roomserver_state_snapshots (state_snapshot_hash, room_nid, state_block_nids) INSERT INTO roomserver_state_snapshots (state_snapshot_hash, room_nid, state_block_nids)
VALUES ($1, $2, $3) VALUES ($1, $2, $3)
ON CONFLICT (state_snapshot_hash) DO UPDATE SET room_nid=$2 ON CONFLICT (state_snapshot_hash) DO UPDATE SET room_nid=$2
@ -237,12 +233,12 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
return fmt.Errorf("tx.QueryRow.Scan (insert new snapshot): %w", err) return fmt.Errorf("tx.QueryRow.Scan (insert new snapshot): %w", err)
} }
if _, err = tx.Exec(`UPDATE roomserver_events SET state_snapshot_nid=$1 WHERE state_snapshot_nid=$2 AND state_snapshot_nid<$3`, newNID, snapshotdata.StateSnapshotNID, maxsnapshotid); err != nil { if _, err = tx.ExecContext(ctx, `UPDATE roomserver_events SET state_snapshot_nid=$1 WHERE state_snapshot_nid=$2 AND state_snapshot_nid<$3`, newNID, snapshotdata.StateSnapshotNID, maxsnapshotid); err != nil {
return fmt.Errorf("tx.Exec (update events): %w", err) return fmt.Errorf("tx.ExecContext (update events): %w", err)
} }
if _, err = tx.Exec(`UPDATE roomserver_rooms SET state_snapshot_nid=$1 WHERE state_snapshot_nid=$2 AND state_snapshot_nid<$3`, newNID, snapshotdata.StateSnapshotNID, maxsnapshotid); err != nil { if _, err = tx.ExecContext(ctx, `UPDATE roomserver_rooms SET state_snapshot_nid=$1 WHERE state_snapshot_nid=$2 AND state_snapshot_nid<$3`, newNID, snapshotdata.StateSnapshotNID, maxsnapshotid); err != nil {
return fmt.Errorf("tx.Exec (update rooms): %w", err) return fmt.Errorf("tx.ExecContext (update rooms): %w", err)
} }
} }
} }
@ -252,13 +248,13 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
// in roomserver_state_snapshots // in roomserver_state_snapshots
var count int64 var count int64
if err = tx.QueryRow(`SELECT COUNT(*) FROM roomserver_events WHERE state_snapshot_nid < $1 AND state_snapshot_nid != 0`, maxsnapshotid).Scan(&count); err != nil { if err = tx.QueryRowContext(ctx, `SELECT COUNT(*) FROM roomserver_events WHERE state_snapshot_nid < $1 AND state_snapshot_nid != 0`, maxsnapshotid).Scan(&count); err != nil {
return fmt.Errorf("assertion query failed: %s", err) return fmt.Errorf("assertion query failed: %s", err)
} }
if count > 0 { if count > 0 {
var res sql.Result var res sql.Result
var c int64 var c int64
res, err = tx.Exec(`UPDATE roomserver_events SET state_snapshot_nid = 0 WHERE state_snapshot_nid < $1 AND state_snapshot_nid != 0`, maxsnapshotid) res, err = tx.ExecContext(ctx, `UPDATE roomserver_events SET state_snapshot_nid = 0 WHERE state_snapshot_nid < $1 AND state_snapshot_nid != 0`, maxsnapshotid)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
return fmt.Errorf("failed to reset invalid state snapshots: %w", err) return fmt.Errorf("failed to reset invalid state snapshots: %w", err)
} }
@ -268,13 +264,13 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
return fmt.Errorf("expected to reset %d event(s) but only updated %d event(s)", count, c) return fmt.Errorf("expected to reset %d event(s) but only updated %d event(s)", count, c)
} }
} }
if err = tx.QueryRow(`SELECT COUNT(*) FROM roomserver_rooms WHERE state_snapshot_nid < $1 AND state_snapshot_nid != 0`, maxsnapshotid).Scan(&count); err != nil { if err = tx.QueryRowContext(ctx, `SELECT COUNT(*) FROM roomserver_rooms WHERE state_snapshot_nid < $1 AND state_snapshot_nid != 0`, maxsnapshotid).Scan(&count); err != nil {
return fmt.Errorf("assertion query failed: %s", err) return fmt.Errorf("assertion query failed: %s", err)
} }
if count > 0 { if count > 0 {
var debugRoomID string var debugRoomID string
var debugSnapNID, debugLastEventNID int64 var debugSnapNID, debugLastEventNID int64
err = tx.QueryRow( err = tx.QueryRowContext(ctx,
`SELECT room_id, state_snapshot_nid, last_event_sent_nid FROM roomserver_rooms WHERE state_snapshot_nid < $1 AND state_snapshot_nid != 0`, maxsnapshotid, `SELECT room_id, state_snapshot_nid, last_event_sent_nid FROM roomserver_rooms WHERE state_snapshot_nid < $1 AND state_snapshot_nid != 0`, maxsnapshotid,
).Scan(&debugRoomID, &debugSnapNID, &debugLastEventNID) ).Scan(&debugRoomID, &debugSnapNID, &debugLastEventNID)
if err != nil { if err != nil {
@ -291,13 +287,13 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
return fmt.Errorf("%d rooms exist in roomserver_rooms which have not been converted to a new state_snapshot_nid; this is a bug, please report", count) return fmt.Errorf("%d rooms exist in roomserver_rooms which have not been converted to a new state_snapshot_nid; this is a bug, please report", count)
} }
if _, err = tx.Exec(` if _, err = tx.ExecContext(ctx, `
DROP TABLE _roomserver_state_snapshots; DROP TABLE _roomserver_state_snapshots;
DROP SEQUENCE roomserver_state_snapshot_nid_seq; DROP SEQUENCE roomserver_state_snapshot_nid_seq;
`); err != nil { `); err != nil {
return fmt.Errorf("tx.Exec (delete old snapshot table): %w", err) return fmt.Errorf("tx.Exec (delete old snapshot table): %w", err)
} }
if _, err = tx.Exec(` if _, err = tx.ExecContext(ctx, `
DROP TABLE _roomserver_state_block; DROP TABLE _roomserver_state_block;
DROP SEQUENCE roomserver_state_block_nid_seq; DROP SEQUENCE roomserver_state_block_nid_seq;
`); err != nil { `); err != nil {
@ -307,6 +303,6 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
return nil return nil
} }
func DownStateBlocksRefactor(tx *sql.Tx) error { func DownStateBlocksRefactor(ctx context.Context, tx *sql.Tx) error {
panic("Downgrading state storage is not supported") panic("Downgrading state storage is not supported")
} }

View file

@ -23,6 +23,7 @@ import (
"github.com/lib/pq" "github.com/lib/pq"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/postgres/deltas"
"github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -86,24 +87,24 @@ const insertMembershipSQL = "" +
const selectMembershipFromRoomAndTargetSQL = "" + const selectMembershipFromRoomAndTargetSQL = "" +
"SELECT membership_nid, event_nid, forgotten FROM roomserver_membership" + "SELECT membership_nid, event_nid, forgotten FROM roomserver_membership" +
" WHERE room_nid = $1 AND target_nid = $2" " WHERE room_nid = $1 AND event_nid != 0 AND target_nid = $2"
const selectMembershipsFromRoomAndMembershipSQL = "" + const selectMembershipsFromRoomAndMembershipSQL = "" +
"SELECT event_nid FROM roomserver_membership" + "SELECT event_nid FROM roomserver_membership" +
" WHERE room_nid = $1 AND membership_nid = $2 and forgotten = false" " WHERE room_nid = $1 AND event_nid != 0 AND membership_nid = $2 and forgotten = false"
const selectLocalMembershipsFromRoomAndMembershipSQL = "" + const selectLocalMembershipsFromRoomAndMembershipSQL = "" +
"SELECT event_nid FROM roomserver_membership" + "SELECT event_nid FROM roomserver_membership" +
" WHERE room_nid = $1 AND membership_nid = $2" + " WHERE room_nid = $1 AND event_nid != 0 AND membership_nid = $2" +
" AND target_local = true and forgotten = false" " AND target_local = true and forgotten = false"
const selectMembershipsFromRoomSQL = "" + const selectMembershipsFromRoomSQL = "" +
"SELECT event_nid FROM roomserver_membership" + "SELECT event_nid FROM roomserver_membership" +
" WHERE room_nid = $1 and forgotten = false" " WHERE room_nid = $1 AND event_nid != 0 and forgotten = false"
const selectLocalMembershipsFromRoomSQL = "" + const selectLocalMembershipsFromRoomSQL = "" +
"SELECT event_nid FROM roomserver_membership" + "SELECT event_nid FROM roomserver_membership" +
" WHERE room_nid = $1" + " WHERE room_nid = $1 AND event_nid != 0" +
" AND target_local = true and forgotten = false" " AND target_local = true and forgotten = false"
const selectMembershipForUpdateSQL = "" + const selectMembershipForUpdateSQL = "" +
@ -118,6 +119,9 @@ const updateMembershipForgetRoom = "" +
"UPDATE roomserver_membership SET forgotten = $3" + "UPDATE roomserver_membership SET forgotten = $3" +
" WHERE room_nid = $1 AND target_nid = $2" " WHERE room_nid = $1 AND target_nid = $2"
const deleteMembershipSQL = "" +
"DELETE FROM roomserver_membership WHERE room_nid = $1 AND target_nid = $2"
const selectRoomsWithMembershipSQL = "" + const selectRoomsWithMembershipSQL = "" +
"SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2 and forgotten = false" "SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2 and forgotten = false"
@ -165,11 +169,20 @@ type membershipStatements struct {
updateMembershipForgetRoomStmt *sql.Stmt updateMembershipForgetRoomStmt *sql.Stmt
selectLocalServerInRoomStmt *sql.Stmt selectLocalServerInRoomStmt *sql.Stmt
selectServerInRoomStmt *sql.Stmt selectServerInRoomStmt *sql.Stmt
deleteMembershipStmt *sql.Stmt
} }
func CreateMembershipTable(db *sql.DB) error { func CreateMembershipTable(db *sql.DB) error {
_, err := db.Exec(membershipSchema) _, err := db.Exec(membershipSchema)
return err if err != nil {
return err
}
m := sqlutil.NewMigrator(db)
m.AddMigrations(sqlutil.Migration{
Version: "roomserver: add forgotten column",
Up: deltas.UpAddForgottenColumn,
})
return m.Up(context.Background())
} }
func PrepareMembershipTable(db *sql.DB) (tables.Membership, error) { func PrepareMembershipTable(db *sql.DB) (tables.Membership, error) {
@ -191,6 +204,7 @@ func PrepareMembershipTable(db *sql.DB) (tables.Membership, error) {
{&s.updateMembershipForgetRoomStmt, updateMembershipForgetRoom}, {&s.updateMembershipForgetRoomStmt, updateMembershipForgetRoom},
{&s.selectLocalServerInRoomStmt, selectLocalServerInRoomSQL}, {&s.selectLocalServerInRoomStmt, selectLocalServerInRoomSQL},
{&s.selectServerInRoomStmt, selectServerInRoomSQL}, {&s.selectServerInRoomStmt, selectServerInRoomSQL},
{&s.deleteMembershipStmt, deleteMembershipSQL},
}.Prepare(db) }.Prepare(db)
} }
@ -412,3 +426,13 @@ func (s *membershipStatements) SelectServerInRoom(
} }
return roomNID == nid, nil return roomNID == nid, nil
} }
func (s *membershipStatements) DeleteMembership(
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) error {
_, err := sqlutil.TxStmt(txn, s.deleteMembershipStmt).ExecContext(
ctx, roomNID, targetUserNID,
)
return err
}

View file

@ -21,7 +21,6 @@ import (
// Import the postgres database driver. // Import the postgres database driver.
_ "github.com/lib/pq" _ "github.com/lib/pq"
"github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/postgres/deltas" "github.com/matrix-org/dendrite/roomserver/storage/postgres/deltas"
@ -45,17 +44,25 @@ func Open(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache c
} }
// Create the tables. // Create the tables.
if err := d.create(db); err != nil { if err = d.create(db); err != nil {
return nil, err return nil, err
} }
// Then execute the migrations. By this point the tables are created with the latest // Special case, since this migration uses several tables, so it needs to
// schemas. // be sure that all tables are created first.
m := sqlutil.NewMigrations() // TODO: Remove when we are sure we are not having goose artefacts in the db
deltas.LoadAddForgottenColumn(m) // This forces an error, which indicates the migration is already applied, since the
deltas.LoadStateBlocksRefactor(m) // column event_nid was removed from the table
if err := m.RunDeltas(db, dbProperties); err != nil { err = db.QueryRow("SELECT event_nid FROM roomserver_state_block LIMIT 1;").Scan()
return nil, err if err == nil {
m := sqlutil.NewMigrator(db)
m.AddMigrations(sqlutil.Migration{
Version: "roomserver: state blocks refactor",
Up: deltas.UpStateBlocksRefactor,
})
if err := m.Up(base.Context()); err != nil {
return nil, err
}
} }
// Then prepare the statements. Now that the migrations have run, any columns referred // Then prepare the statements. Now that the migrations have run, any columns referred

View file

@ -15,7 +15,7 @@ type MembershipUpdater struct {
d *Database d *Database
roomNID types.RoomNID roomNID types.RoomNID
targetUserNID types.EventStateKeyNID targetUserNID types.EventStateKeyNID
membership tables.MembershipState oldMembership tables.MembershipState
} }
func NewMembershipUpdater( func NewMembershipUpdater(
@ -30,7 +30,6 @@ func NewMembershipUpdater(
if err != nil { if err != nil {
return err return err
} }
targetUserNID, err = d.assignStateKeyNID(ctx, targetUserID) targetUserNID, err = d.assignStateKeyNID(ctx, targetUserID)
if err != nil { if err != nil {
return err return err
@ -73,139 +72,62 @@ func (d *Database) membershipUpdaterTxn(
// IsInvite implements types.MembershipUpdater // IsInvite implements types.MembershipUpdater
func (u *MembershipUpdater) IsInvite() bool { func (u *MembershipUpdater) IsInvite() bool {
return u.membership == tables.MembershipStateInvite return u.oldMembership == tables.MembershipStateInvite
} }
// IsJoin implements types.MembershipUpdater // IsJoin implements types.MembershipUpdater
func (u *MembershipUpdater) IsJoin() bool { func (u *MembershipUpdater) IsJoin() bool {
return u.membership == tables.MembershipStateJoin return u.oldMembership == tables.MembershipStateJoin
} }
// IsLeave implements types.MembershipUpdater // IsLeave implements types.MembershipUpdater
func (u *MembershipUpdater) IsLeave() bool { func (u *MembershipUpdater) IsLeave() bool {
return u.membership == tables.MembershipStateLeaveOrBan return u.oldMembership == tables.MembershipStateLeaveOrBan
} }
// IsKnock implements types.MembershipUpdater // IsKnock implements types.MembershipUpdater
func (u *MembershipUpdater) IsKnock() bool { func (u *MembershipUpdater) IsKnock() bool {
return u.membership == tables.MembershipStateKnock return u.oldMembership == tables.MembershipStateKnock
} }
// SetToInvite implements types.MembershipUpdater func (u *MembershipUpdater) Delete() error {
func (u *MembershipUpdater) SetToInvite(event *gomatrixserverlib.Event) (bool, error) { if _, err := u.d.InvitesTable.UpdateInviteRetired(u.ctx, u.txn, u.roomNID, u.targetUserNID); err != nil {
var inserted bool return err
err := u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { }
return u.d.MembershipTable.DeleteMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID)
}
func (u *MembershipUpdater) Update(newMembership tables.MembershipState, event *types.Event) (bool, []string, error) {
var inserted bool // Did the query result in a membership change?
var retired []string // Did we retire any updates in the process?
return inserted, retired, u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
senderUserNID, err := u.d.assignStateKeyNID(u.ctx, event.Sender()) senderUserNID, err := u.d.assignStateKeyNID(u.ctx, event.Sender())
if err != nil { if err != nil {
return fmt.Errorf("u.d.AssignStateKeyNID: %w", err) return fmt.Errorf("u.d.AssignStateKeyNID: %w", err)
} }
inserted, err = u.d.InvitesTable.InsertInviteEvent( inserted, err = u.d.MembershipTable.UpdateMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, newMembership, event.EventNID, false)
u.ctx, u.txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(),
)
if err != nil { if err != nil {
return fmt.Errorf("u.d.InvitesTable.InsertInviteEvent: %w", err) return fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err)
} }
if u.membership != tables.MembershipStateInvite { if !inserted {
if inserted, err = u.d.MembershipTable.UpdateMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateInvite, 0, false); err != nil { return nil
return fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err) }
switch {
case u.oldMembership != tables.MembershipStateInvite && newMembership == tables.MembershipStateInvite:
inserted, err = u.d.InvitesTable.InsertInviteEvent(
u.ctx, u.txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(),
)
if err != nil {
return fmt.Errorf("u.d.InvitesTable.InsertInviteEvent: %w", err)
} }
} case u.oldMembership == tables.MembershipStateInvite && newMembership != tables.MembershipStateInvite:
return nil retired, err = u.d.InvitesTable.UpdateInviteRetired(
})
return inserted, err
}
// SetToJoin implements types.MembershipUpdater
func (u *MembershipUpdater) SetToJoin(senderUserID string, eventID string, isUpdate bool) ([]string, error) {
var inviteEventIDs []string
err := u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
senderUserNID, err := u.d.assignStateKeyNID(u.ctx, senderUserID)
if err != nil {
return fmt.Errorf("u.d.AssignStateKeyNID: %w", err)
}
// If this is a join event update, there is no invite to update
if !isUpdate {
inviteEventIDs, err = u.d.InvitesTable.UpdateInviteRetired(
u.ctx, u.txn, u.roomNID, u.targetUserNID, u.ctx, u.txn, u.roomNID, u.targetUserNID,
) )
if err != nil { if err != nil {
return fmt.Errorf("u.d.InvitesTables.UpdateInviteRetired: %w", err) return fmt.Errorf("u.d.InvitesTables.UpdateInviteRetired: %w", err)
} }
} }
// Look up the NID of the new join event
nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{eventID}, false)
if err != nil {
return fmt.Errorf("u.d.EventNIDs: %w", err)
}
if u.membership != tables.MembershipStateJoin || isUpdate {
if _, err = u.d.MembershipTable.UpdateMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateJoin, nIDs[eventID], false); err != nil {
return fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err)
}
}
return nil return nil
}) })
return inviteEventIDs, err
}
// SetToLeave implements types.MembershipUpdater
func (u *MembershipUpdater) SetToLeave(senderUserID string, eventID string) ([]string, error) {
var inviteEventIDs []string
err := u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
senderUserNID, err := u.d.assignStateKeyNID(u.ctx, senderUserID)
if err != nil {
return fmt.Errorf("u.d.AssignStateKeyNID: %w", err)
}
inviteEventIDs, err = u.d.InvitesTable.UpdateInviteRetired(
u.ctx, u.txn, u.roomNID, u.targetUserNID,
)
if err != nil {
return fmt.Errorf("u.d.InvitesTable.updateInviteRetired: %w", err)
}
// Look up the NID of the new leave event
nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{eventID}, false)
if err != nil {
return fmt.Errorf("u.d.EventNIDs: %w", err)
}
if u.membership != tables.MembershipStateLeaveOrBan {
if _, err = u.d.MembershipTable.UpdateMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateLeaveOrBan, nIDs[eventID], false); err != nil {
return fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err)
}
}
return nil
})
return inviteEventIDs, err
}
// SetToKnock implements types.MembershipUpdater
func (u *MembershipUpdater) SetToKnock(event *gomatrixserverlib.Event) (bool, error) {
var inserted bool
err := u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
senderUserNID, err := u.d.assignStateKeyNID(u.ctx, event.Sender())
if err != nil {
return fmt.Errorf("u.d.AssignStateKeyNID: %w", err)
}
if u.membership != tables.MembershipStateKnock {
// Look up the NID of the new knock event
nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{event.EventID()}, false)
if err != nil {
return fmt.Errorf("u.d.EventNIDs: %w", err)
}
if inserted, err = u.d.MembershipTable.UpdateMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateKnock, nIDs[event.EventID()], false); err != nil {
return fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err)
}
}
return nil
})
return inserted, err
} }

View file

@ -72,7 +72,24 @@ func (d *Database) eventTypeNIDs(
func (d *Database) EventStateKeys( func (d *Database) EventStateKeys(
ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID, ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID,
) (map[types.EventStateKeyNID]string, error) { ) (map[types.EventStateKeyNID]string, error) {
return d.EventStateKeysTable.BulkSelectEventStateKey(ctx, nil, eventStateKeyNIDs) result := make(map[types.EventStateKeyNID]string, len(eventStateKeyNIDs))
fetch := make([]types.EventStateKeyNID, 0, len(eventStateKeyNIDs))
for _, nid := range eventStateKeyNIDs {
if key, ok := d.Cache.GetEventStateKey(nid); ok {
result[nid] = key
} else {
fetch = append(fetch, nid)
}
}
fromDB, err := d.EventStateKeysTable.BulkSelectEventStateKey(ctx, nil, fetch)
if err != nil {
return nil, err
}
for nid, key := range fromDB {
result[nid] = key
d.Cache.StoreEventStateKey(nid, key)
}
return result, nil
} }
func (d *Database) EventStateKeyNIDs( func (d *Database) EventStateKeyNIDs(

View file

@ -15,24 +15,13 @@
package deltas package deltas
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/pressly/goose"
) )
func LoadFromGoose() { func UpAddForgottenColumn(ctx context.Context, tx *sql.Tx) error {
goose.AddMigration(UpAddForgottenColumn, DownAddForgottenColumn) _, err := tx.ExecContext(ctx, ` ALTER TABLE roomserver_membership RENAME TO roomserver_membership_tmp;
goose.AddMigration(UpStateBlocksRefactor, DownStateBlocksRefactor)
}
func LoadAddForgottenColumn(m *sqlutil.Migrations) {
m.AddMigration(UpAddForgottenColumn, DownAddForgottenColumn)
}
func UpAddForgottenColumn(tx *sql.Tx) error {
_, err := tx.Exec(` ALTER TABLE roomserver_membership RENAME TO roomserver_membership_tmp;
CREATE TABLE IF NOT EXISTS roomserver_membership ( CREATE TABLE IF NOT EXISTS roomserver_membership (
room_nid INTEGER NOT NULL, room_nid INTEGER NOT NULL,
target_nid INTEGER NOT NULL, target_nid INTEGER NOT NULL,
@ -57,8 +46,8 @@ DROP TABLE roomserver_membership_tmp;`)
return nil return nil
} }
func DownAddForgottenColumn(tx *sql.Tx) error { func DownAddForgottenColumn(ctx context.Context, tx *sql.Tx) error {
_, err := tx.Exec(` ALTER TABLE roomserver_membership RENAME TO roomserver_membership_tmp; _, err := tx.ExecContext(ctx, ` ALTER TABLE roomserver_membership RENAME TO roomserver_membership_tmp;
CREATE TABLE IF NOT EXISTS roomserver_membership ( CREATE TABLE IF NOT EXISTS roomserver_membership (
room_nid INTEGER NOT NULL, room_nid INTEGER NOT NULL,
target_nid INTEGER NOT NULL, target_nid INTEGER NOT NULL,

View file

@ -21,40 +21,35 @@ import (
"fmt" "fmt"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
func LoadStateBlocksRefactor(m *sqlutil.Migrations) {
m.AddMigration(UpStateBlocksRefactor, DownStateBlocksRefactor)
}
// nolint:gocyclo // nolint:gocyclo
func UpStateBlocksRefactor(tx *sql.Tx) error { func UpStateBlocksRefactor(ctx context.Context, tx *sql.Tx) error {
logrus.Warn("Performing state storage upgrade. Please wait, this may take some time!") logrus.Warn("Performing state storage upgrade. Please wait, this may take some time!")
defer logrus.Warn("State storage upgrade complete") defer logrus.Warn("State storage upgrade complete")
var maxsnapshotid int var maxsnapshotid int
var maxblockid int var maxblockid int
if err := tx.QueryRow(`SELECT IFNULL(MAX(state_snapshot_nid),0) FROM roomserver_state_snapshots;`).Scan(&maxsnapshotid); err != nil { if err := tx.QueryRowContext(ctx, `SELECT IFNULL(MAX(state_snapshot_nid),0) FROM roomserver_state_snapshots;`).Scan(&maxsnapshotid); err != nil {
return fmt.Errorf("tx.QueryRow.Scan (count snapshots): %w", err) return fmt.Errorf("tx.QueryRowContext.Scan (count snapshots): %w", err)
} }
if err := tx.QueryRow(`SELECT IFNULL(MAX(state_block_nid),0) FROM roomserver_state_block;`).Scan(&maxblockid); err != nil { if err := tx.QueryRowContext(ctx, `SELECT IFNULL(MAX(state_block_nid),0) FROM roomserver_state_block;`).Scan(&maxblockid); err != nil {
return fmt.Errorf("tx.QueryRow.Scan (count snapshots): %w", err) return fmt.Errorf("tx.QueryRowContext.Scan (count snapshots): %w", err)
} }
maxsnapshotid++ maxsnapshotid++
maxblockid++ maxblockid++
oldMaxSnapshotID := maxsnapshotid oldMaxSnapshotID := maxsnapshotid
if _, err := tx.Exec(`ALTER TABLE roomserver_state_block RENAME TO _roomserver_state_block;`); err != nil { if _, err := tx.ExecContext(ctx, `ALTER TABLE roomserver_state_block RENAME TO _roomserver_state_block;`); err != nil {
return fmt.Errorf("tx.Exec: %w", err) return fmt.Errorf("tx.ExecContext: %w", err)
} }
if _, err := tx.Exec(`ALTER TABLE roomserver_state_snapshots RENAME TO _roomserver_state_snapshots;`); err != nil { if _, err := tx.ExecContext(ctx, `ALTER TABLE roomserver_state_snapshots RENAME TO _roomserver_state_snapshots;`); err != nil {
return fmt.Errorf("tx.Exec: %w", err) return fmt.Errorf("tx.ExecContext: %w", err)
} }
_, err := tx.Exec(` _, err := tx.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS roomserver_state_block ( CREATE TABLE IF NOT EXISTS roomserver_state_block (
state_block_nid INTEGER PRIMARY KEY AUTOINCREMENT, state_block_nid INTEGER PRIMARY KEY AUTOINCREMENT,
state_block_hash BLOB UNIQUE, state_block_hash BLOB UNIQUE,
@ -62,9 +57,9 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
); );
`) `)
if err != nil { if err != nil {
return fmt.Errorf("tx.Exec: %w", err) return fmt.Errorf("tx.ExecContext: %w", err)
} }
_, err = tx.Exec(` _, err = tx.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS roomserver_state_snapshots ( CREATE TABLE IF NOT EXISTS roomserver_state_snapshots (
state_snapshot_nid INTEGER PRIMARY KEY AUTOINCREMENT, state_snapshot_nid INTEGER PRIMARY KEY AUTOINCREMENT,
state_snapshot_hash BLOB UNIQUE, state_snapshot_hash BLOB UNIQUE,
@ -73,11 +68,11 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
); );
`) `)
if err != nil { if err != nil {
return fmt.Errorf("tx.Exec: %w", err) return fmt.Errorf("tx.ExecContext: %w", err)
} }
snapshotrows, err := tx.Query(`SELECT state_snapshot_nid, room_nid, state_block_nids FROM _roomserver_state_snapshots;`) snapshotrows, err := tx.QueryContext(ctx, `SELECT state_snapshot_nid, room_nid, state_block_nids FROM _roomserver_state_snapshots;`)
if err != nil { if err != nil {
return fmt.Errorf("tx.Query: %w", err) return fmt.Errorf("tx.QueryContext: %w", err)
} }
defer internal.CloseAndLogIfError(context.TODO(), snapshotrows, "rows.close() failed") defer internal.CloseAndLogIfError(context.TODO(), snapshotrows, "rows.close() failed")
for snapshotrows.Next() { for snapshotrows.Next() {
@ -99,7 +94,7 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
// in question a state snapshot NID of 0 to indicate 'no snapshot'. // in question a state snapshot NID of 0 to indicate 'no snapshot'.
// If we don't do this, we'll fail the assertions later on which try to ensure we didn't forget // If we don't do this, we'll fail the assertions later on which try to ensure we didn't forget
// any snapshots. // any snapshots.
_, err = tx.Exec( _, err = tx.ExecContext(ctx,
`UPDATE roomserver_events SET state_snapshot_nid = 0 WHERE event_type_nid = $1 AND event_state_key_nid = $2 AND state_snapshot_nid = $3`, `UPDATE roomserver_events SET state_snapshot_nid = 0 WHERE event_type_nid = $1 AND event_state_key_nid = $2 AND state_snapshot_nid = $3`,
types.MRoomCreateNID, types.EmptyStateKeyNID, snapshot, types.MRoomCreateNID, types.EmptyStateKeyNID, snapshot,
) )
@ -109,9 +104,9 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
} }
for _, block := range blocks { for _, block := range blocks {
if err = func() error { if err = func() error {
blockrows, berr := tx.Query(`SELECT event_nid FROM _roomserver_state_block WHERE state_block_nid = $1`, block) blockrows, berr := tx.QueryContext(ctx, `SELECT event_nid FROM _roomserver_state_block WHERE state_block_nid = $1`, block)
if berr != nil { if berr != nil {
return fmt.Errorf("tx.Query (event nids from old block): %w", berr) return fmt.Errorf("tx.QueryContext (event nids from old block): %w", berr)
} }
defer internal.CloseAndLogIfError(context.TODO(), blockrows, "rows.close() failed") defer internal.CloseAndLogIfError(context.TODO(), blockrows, "rows.close() failed")
events := types.EventNIDs{} events := types.EventNIDs{}
@ -129,14 +124,14 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
} }
var blocknid types.StateBlockNID var blocknid types.StateBlockNID
err = tx.QueryRow(` err = tx.QueryRowContext(ctx, `
INSERT INTO roomserver_state_block (state_block_nid, state_block_hash, event_nids) INSERT INTO roomserver_state_block (state_block_nid, state_block_hash, event_nids)
VALUES ($1, $2, $3) VALUES ($1, $2, $3)
ON CONFLICT (state_block_hash) DO UPDATE SET event_nids=$3 ON CONFLICT (state_block_hash) DO UPDATE SET event_nids=$3
RETURNING state_block_nid RETURNING state_block_nid
`, maxblockid, events.Hash(), eventjson).Scan(&blocknid) `, maxblockid, events.Hash(), eventjson).Scan(&blocknid)
if err != nil { if err != nil {
return fmt.Errorf("tx.QueryRow.Scan (insert new block): %w", err) return fmt.Errorf("tx.QueryRowContext.Scan (insert new block): %w", err)
} }
maxblockid++ maxblockid++
newblocks = append(newblocks, blocknid) newblocks = append(newblocks, blocknid)
@ -151,22 +146,22 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
} }
var newsnapshot types.StateSnapshotNID var newsnapshot types.StateSnapshotNID
err = tx.QueryRow(` err = tx.QueryRowContext(ctx, `
INSERT INTO roomserver_state_snapshots (state_snapshot_nid, state_snapshot_hash, room_nid, state_block_nids) INSERT INTO roomserver_state_snapshots (state_snapshot_nid, state_snapshot_hash, room_nid, state_block_nids)
VALUES ($1, $2, $3, $4) VALUES ($1, $2, $3, $4)
ON CONFLICT (state_snapshot_hash) DO UPDATE SET room_nid=$3 ON CONFLICT (state_snapshot_hash) DO UPDATE SET room_nid=$3
RETURNING state_snapshot_nid RETURNING state_snapshot_nid
`, maxsnapshotid, newblocks.Hash(), room, newblocksjson).Scan(&newsnapshot) `, maxsnapshotid, newblocks.Hash(), room, newblocksjson).Scan(&newsnapshot)
if err != nil { if err != nil {
return fmt.Errorf("tx.QueryRow.Scan (insert new snapshot): %w", err) return fmt.Errorf("tx.QueryRowContext.Scan (insert new snapshot): %w", err)
} }
maxsnapshotid++ maxsnapshotid++
_, err = tx.Exec(`UPDATE roomserver_events SET state_snapshot_nid=$1 WHERE state_snapshot_nid=$2 AND state_snapshot_nid<$3`, newsnapshot, snapshot, maxsnapshotid) _, err = tx.ExecContext(ctx, `UPDATE roomserver_events SET state_snapshot_nid=$1 WHERE state_snapshot_nid=$2 AND state_snapshot_nid<$3`, newsnapshot, snapshot, maxsnapshotid)
if err != nil { if err != nil {
return fmt.Errorf("tx.Exec (update events): %w", err) return fmt.Errorf("tx.ExecContext (update events): %w", err)
} }
if _, err = tx.Exec(`UPDATE roomserver_rooms SET state_snapshot_nid=$1 WHERE state_snapshot_nid=$2 AND state_snapshot_nid<$3`, newsnapshot, snapshot, maxsnapshotid); err != nil { if _, err = tx.ExecContext(ctx, `UPDATE roomserver_rooms SET state_snapshot_nid=$1 WHERE state_snapshot_nid=$2 AND state_snapshot_nid<$3`, newsnapshot, snapshot, maxsnapshotid); err != nil {
return fmt.Errorf("tx.Exec (update rooms): %w", err) return fmt.Errorf("tx.ExecContext (update rooms): %w", err)
} }
} }
} }
@ -175,13 +170,13 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
// If we do, this is a problem if Dendrite tries to load the snapshot as it will not exist // If we do, this is a problem if Dendrite tries to load the snapshot as it will not exist
// in roomserver_state_snapshots // in roomserver_state_snapshots
var count int64 var count int64
if err = tx.QueryRow(`SELECT COUNT(*) FROM roomserver_events WHERE state_snapshot_nid < $1 AND state_snapshot_nid != 0`, oldMaxSnapshotID).Scan(&count); err != nil { if err = tx.QueryRowContext(ctx, `SELECT COUNT(*) FROM roomserver_events WHERE state_snapshot_nid < $1 AND state_snapshot_nid != 0`, oldMaxSnapshotID).Scan(&count); err != nil {
return fmt.Errorf("assertion query failed: %s", err) return fmt.Errorf("assertion query failed: %s", err)
} }
if count > 0 { if count > 0 {
var res sql.Result var res sql.Result
var c int64 var c int64
res, err = tx.Exec(`UPDATE roomserver_events SET state_snapshot_nid = 0 WHERE state_snapshot_nid < $1 AND state_snapshot_nid != 0`, oldMaxSnapshotID) res, err = tx.ExecContext(ctx, `UPDATE roomserver_events SET state_snapshot_nid = 0 WHERE state_snapshot_nid < $1 AND state_snapshot_nid != 0`, oldMaxSnapshotID)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
return fmt.Errorf("failed to reset invalid state snapshots: %w", err) return fmt.Errorf("failed to reset invalid state snapshots: %w", err)
} }
@ -191,23 +186,23 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
return fmt.Errorf("expected to reset %d event(s) but only updated %d event(s)", count, c) return fmt.Errorf("expected to reset %d event(s) but only updated %d event(s)", count, c)
} }
} }
if err = tx.QueryRow(`SELECT COUNT(*) FROM roomserver_rooms WHERE state_snapshot_nid < $1 AND state_snapshot_nid != 0`, oldMaxSnapshotID).Scan(&count); err != nil { if err = tx.QueryRowContext(ctx, `SELECT COUNT(*) FROM roomserver_rooms WHERE state_snapshot_nid < $1 AND state_snapshot_nid != 0`, oldMaxSnapshotID).Scan(&count); err != nil {
return fmt.Errorf("assertion query failed: %s", err) return fmt.Errorf("assertion query failed: %s", err)
} }
if count > 0 { if count > 0 {
return fmt.Errorf("%d rooms exist in roomserver_rooms which have not been converted to a new state_snapshot_nid; this is a bug, please report", count) return fmt.Errorf("%d rooms exist in roomserver_rooms which have not been converted to a new state_snapshot_nid; this is a bug, please report", count)
} }
if _, err = tx.Exec(`DROP TABLE _roomserver_state_snapshots;`); err != nil { if _, err = tx.ExecContext(ctx, `DROP TABLE _roomserver_state_snapshots;`); err != nil {
return fmt.Errorf("tx.Exec (delete old snapshot table): %w", err) return fmt.Errorf("tx.Exec (delete old snapshot table): %w", err)
} }
if _, err = tx.Exec(`DROP TABLE _roomserver_state_block;`); err != nil { if _, err = tx.ExecContext(ctx, `DROP TABLE _roomserver_state_block;`); err != nil {
return fmt.Errorf("tx.Exec (delete old block table): %w", err) return fmt.Errorf("tx.Exec (delete old block table): %w", err)
} }
return nil return nil
} }
func DownStateBlocksRefactor(tx *sql.Tx) error { func DownStateBlocksRefactor(ctx context.Context, tx *sql.Tx) error {
panic("Downgrading state storage is not supported") panic("Downgrading state storage is not supported")
} }

View file

@ -23,6 +23,7 @@ import (
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/sqlite3/deltas"
"github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -62,24 +63,24 @@ const insertMembershipSQL = "" +
const selectMembershipFromRoomAndTargetSQL = "" + const selectMembershipFromRoomAndTargetSQL = "" +
"SELECT membership_nid, event_nid, forgotten FROM roomserver_membership" + "SELECT membership_nid, event_nid, forgotten FROM roomserver_membership" +
" WHERE room_nid = $1 AND target_nid = $2" " WHERE room_nid = $1 AND event_nid != 0 AND target_nid = $2"
const selectMembershipsFromRoomAndMembershipSQL = "" + const selectMembershipsFromRoomAndMembershipSQL = "" +
"SELECT event_nid FROM roomserver_membership" + "SELECT event_nid FROM roomserver_membership" +
" WHERE room_nid = $1 AND membership_nid = $2 and forgotten = false" " WHERE room_nid = $1 AND event_nid != 0 AND membership_nid = $2 and forgotten = false"
const selectLocalMembershipsFromRoomAndMembershipSQL = "" + const selectLocalMembershipsFromRoomAndMembershipSQL = "" +
"SELECT event_nid FROM roomserver_membership" + "SELECT event_nid FROM roomserver_membership" +
" WHERE room_nid = $1 AND membership_nid = $2" + " WHERE room_nid = $1 AND event_nid != 0 AND membership_nid = $2" +
" AND target_local = true and forgotten = false" " AND target_local = true and forgotten = false"
const selectMembershipsFromRoomSQL = "" + const selectMembershipsFromRoomSQL = "" +
"SELECT event_nid FROM roomserver_membership" + "SELECT event_nid FROM roomserver_membership" +
" WHERE room_nid = $1 and forgotten = false" " WHERE room_nid = $1 AND event_nid != 0 and forgotten = false"
const selectLocalMembershipsFromRoomSQL = "" + const selectLocalMembershipsFromRoomSQL = "" +
"SELECT event_nid FROM roomserver_membership" + "SELECT event_nid FROM roomserver_membership" +
" WHERE room_nid = $1" + " WHERE room_nid = $1 AND event_nid != 0" +
" AND target_local = true and forgotten = false" " AND target_local = true and forgotten = false"
const selectMembershipForUpdateSQL = "" + const selectMembershipForUpdateSQL = "" +
@ -125,6 +126,9 @@ const selectServerInRoomSQL = "" +
" JOIN roomserver_event_state_keys ON roomserver_membership.target_nid = roomserver_event_state_keys.event_state_key_nid" + " JOIN roomserver_event_state_keys ON roomserver_membership.target_nid = roomserver_event_state_keys.event_state_key_nid" +
" WHERE membership_nid = $1 AND room_nid = $2 AND event_state_key LIKE '%:' || $3 LIMIT 1" " WHERE membership_nid = $1 AND room_nid = $2 AND event_state_key LIKE '%:' || $3 LIMIT 1"
const deleteMembershipSQL = "" +
"DELETE FROM roomserver_membership WHERE room_nid = $1 AND target_nid = $2"
type membershipStatements struct { type membershipStatements struct {
db *sql.DB db *sql.DB
insertMembershipStmt *sql.Stmt insertMembershipStmt *sql.Stmt
@ -140,11 +144,20 @@ type membershipStatements struct {
updateMembershipForgetRoomStmt *sql.Stmt updateMembershipForgetRoomStmt *sql.Stmt
selectLocalServerInRoomStmt *sql.Stmt selectLocalServerInRoomStmt *sql.Stmt
selectServerInRoomStmt *sql.Stmt selectServerInRoomStmt *sql.Stmt
deleteMembershipStmt *sql.Stmt
} }
func CreateMembershipTable(db *sql.DB) error { func CreateMembershipTable(db *sql.DB) error {
_, err := db.Exec(membershipSchema) _, err := db.Exec(membershipSchema)
return err if err != nil {
return err
}
m := sqlutil.NewMigrator(db)
m.AddMigrations(sqlutil.Migration{
Version: "roomserver: add forgotten column",
Up: deltas.UpAddForgottenColumn,
})
return m.Up(context.Background())
} }
func PrepareMembershipTable(db *sql.DB) (tables.Membership, error) { func PrepareMembershipTable(db *sql.DB) (tables.Membership, error) {
@ -166,6 +179,7 @@ func PrepareMembershipTable(db *sql.DB) (tables.Membership, error) {
{&s.updateMembershipForgetRoomStmt, updateMembershipForgetRoom}, {&s.updateMembershipForgetRoomStmt, updateMembershipForgetRoom},
{&s.selectLocalServerInRoomStmt, selectLocalServerInRoomSQL}, {&s.selectLocalServerInRoomStmt, selectLocalServerInRoomSQL},
{&s.selectServerInRoomStmt, selectServerInRoomSQL}, {&s.selectServerInRoomStmt, selectServerInRoomSQL},
{&s.deleteMembershipStmt, deleteMembershipSQL},
}.Prepare(db) }.Prepare(db)
} }
@ -383,3 +397,13 @@ func (s *membershipStatements) SelectServerInRoom(ctx context.Context, txn *sql.
} }
return roomNID == nid, nil return roomNID == nid, nil
} }
func (s *membershipStatements) DeleteMembership(
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) error {
_, err := sqlutil.TxStmt(txn, s.deleteMembershipStmt).ExecContext(
ctx, roomNID, targetUserNID,
)
return err
}

View file

@ -54,17 +54,25 @@ func Open(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache c
// db.SetMaxOpenConns(20) // db.SetMaxOpenConns(20)
// Create the tables. // Create the tables.
if err := d.create(db); err != nil { if err = d.create(db); err != nil {
return nil, err return nil, err
} }
// Then execute the migrations. By this point the tables are created with the latest // Special case, since this migration uses several tables, so it needs to
// schemas. // be sure that all tables are created first.
m := sqlutil.NewMigrations() // TODO: Remove when we are sure we are not having goose artefacts in the db
deltas.LoadAddForgottenColumn(m) // This forces an error, which indicates the migration is already applied, since the
deltas.LoadStateBlocksRefactor(m) // column event_nid was removed from the table
if err := m.RunDeltas(db, dbProperties); err != nil { err = db.QueryRow("SELECT event_nid FROM roomserver_state_block LIMIT 1;").Scan()
return nil, err if err == nil {
m := sqlutil.NewMigrator(db)
m.AddMigrations(sqlutil.Migration{
Version: "roomserver: state blocks refactor",
Up: deltas.UpStateBlocksRefactor,
})
if err := m.Up(base.Context()); err != nil {
return nil, err
}
} }
// Then prepare the statements. Now that the migrations have run, any columns referred // Then prepare the statements. Now that the migrations have run, any columns referred

View file

@ -133,6 +133,7 @@ type Membership interface {
UpdateForgetMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool) error UpdateForgetMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool) error
SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error) SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error)
SelectServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) SelectServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error)
DeleteMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) error
} }
type Published interface { type Published interface {

View file

@ -60,6 +60,9 @@ func TestMembershipTable(t *testing.T) {
// This inserts a left user to the room // This inserts a left user to the room
err = tab.InsertMembership(ctx, nil, 1, stateKeyNID, true) err = tab.InsertMembership(ctx, nil, 1, stateKeyNID, true)
assert.NoError(t, err) assert.NoError(t, err)
// We must update the membership with a non-zero event NID or it will get filtered out in later queries
_, err = tab.UpdateMembership(ctx, nil, 1, stateKeyNID, userNIDs[0], tables.MembershipStateLeaveOrBan, 1, false)
assert.NoError(t, err)
} }
// ... so this should be false // ... so this should be false

View file

@ -46,6 +46,9 @@ type Global struct {
// The server name to delegate server-server communications to, with optional port // The server name to delegate server-server communications to, with optional port
WellKnownServerName string `yaml:"well_known_server_name"` WellKnownServerName string `yaml:"well_known_server_name"`
// The server name to delegate client-server communications to, with optional port
WellKnownClientName string `yaml:"well_known_client_name"`
// Disables federation. Dendrite will not be able to make any outbound HTTP requests // Disables federation. Dendrite will not be able to make any outbound HTTP requests
// to other servers and the federation API will not be exposed. // to other servers and the federation API will not be exposed.
DisableFederation bool `yaml:"disable_federation"` DisableFederation bool `yaml:"disable_federation"`
@ -73,7 +76,7 @@ type Global struct {
// ServerNotices configuration used for sending server notices // ServerNotices configuration used for sending server notices
ServerNotices ServerNotices `yaml:"server_notices"` ServerNotices ServerNotices `yaml:"server_notices"`
// ReportStats configures opt-in anonymous stats reporting. // ReportStats configures opt-in phone-home statistics reporting.
ReportStats ReportStats `yaml:"report_stats"` ReportStats ReportStats `yaml:"report_stats"`
// Configuration for the caches. // Configuration for the caches.
@ -195,9 +198,9 @@ func (c *Cache) Verify(errors *ConfigErrors, isMonolith bool) {
checkPositive(errors, "max_size_estimated", int64(c.EstimatedMaxSize)) checkPositive(errors, "max_size_estimated", int64(c.EstimatedMaxSize))
} }
// ReportStats configures opt-in anonymous stats reporting. // ReportStats configures opt-in phone-home statistics reporting.
type ReportStats struct { type ReportStats struct {
// Enabled configures anonymous usage stats of the server // Enabled configures phone-home statistics of the server
Enabled bool `yaml:"enabled"` Enabled bool `yaml:"enabled"`
// Endpoint the endpoint to report stats to // Endpoint the endpoint to report stats to

View file

@ -42,6 +42,7 @@ global:
key_id: ed25519:auto key_id: ed25519:auto
key_validity_period: 168h0m0s key_validity_period: 168h0m0s
well_known_server_name: "localhost:443" well_known_server_name: "localhost:443"
well_known_client_name: "localhost:443"
trusted_third_party_id_servers: trusted_third_party_id_servers:
- matrix.org - matrix.org
- vector.im - vector.im

View file

@ -708,7 +708,6 @@ func stripped(ev *gomatrixserverlib.Event) *gomatrixserverlib.MSC2946StrippedEve
StateKey: *ev.StateKey(), StateKey: *ev.StateKey(),
Content: ev.Content(), Content: ev.Content(),
Sender: ev.Sender(), Sender: ev.Sender(),
RoomID: ev.RoomID(),
OriginServerTS: ev.OriginServerTS(), OriginServerTS: ev.OriginServerTS(),
} }
} }

View file

@ -240,6 +240,7 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent(
msg.RemovesStateEventIDs, msg.RemovesStateEventIDs,
msg.TransactionID, msg.TransactionID,
false, false,
msg.HistoryVisibility,
) )
if err != nil { if err != nil {
// panic rather than continue with an inconsistent database // panic rather than continue with an inconsistent database
@ -289,7 +290,8 @@ func (s *OutputRoomEventConsumer) onOldRoomEvent(
[]string{}, // adds no state []string{}, // adds no state
[]string{}, // removes no state []string{}, // removes no state
nil, // no transaction nil, // no transaction
ev.StateKey() != nil, // exclude from sync? ev.StateKey() != nil, // exclude from sync?,
msg.HistoryVisibility,
) )
if err != nil { if err != nil {
// panic rather than continue with an inconsistent database // panic rather than continue with an inconsistent database
@ -363,7 +365,7 @@ func (s *OutputRoomEventConsumer) onNewInviteEvent(
"event": string(msg.Event.JSON()), "event": string(msg.Event.JSON()),
"pdupos": pduPos, "pdupos": pduPos,
log.ErrorKey: err, log.ErrorKey: err,
}).Panicf("roomserver output log: write invite failure") }).Errorf("roomserver output log: write invite failure")
return return
} }
@ -383,7 +385,7 @@ func (s *OutputRoomEventConsumer) onRetireInviteEvent(
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"event_id": msg.EventID, "event_id": msg.EventID,
log.ErrorKey: err, log.ErrorKey: err,
}).Panicf("roomserver output log: remove invite failure") }).Errorf("roomserver output log: remove invite failure")
return return
} }
@ -401,7 +403,7 @@ func (s *OutputRoomEventConsumer) onNewPeek(
// panic rather than continue with an inconsistent database // panic rather than continue with an inconsistent database
log.WithFields(log.Fields{ log.WithFields(log.Fields{
log.ErrorKey: err, log.ErrorKey: err,
}).Panicf("roomserver output log: write peek failure") }).Errorf("roomserver output log: write peek failure")
return return
} }
@ -420,7 +422,7 @@ func (s *OutputRoomEventConsumer) onRetirePeek(
// panic rather than continue with an inconsistent database // panic rather than continue with an inconsistent database
log.WithFields(log.Fields{ log.WithFields(log.Fields{
log.ErrorKey: err, log.ErrorKey: err,
}).Panicf("roomserver output log: write peek failure") }).Errorf("roomserver output log: write peek failure")
return return
} }

View file

@ -21,6 +21,7 @@ import (
keyapi "github.com/matrix-org/dendrite/keyserver/api" keyapi "github.com/matrix-org/dendrite/keyserver/api"
keytypes "github.com/matrix-org/dendrite/keyserver/types" keytypes "github.com/matrix-org/dendrite/keyserver/types"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
@ -46,7 +47,7 @@ func DeviceOTKCounts(ctx context.Context, keyAPI keyapi.SyncKeyAPI, userID, devi
// was filled in, else false if there are no new device list changes because there is nothing to catch up on. The response MUST // was filled in, else false if there are no new device list changes because there is nothing to catch up on. The response MUST
// be already filled in with join/leave information. // be already filled in with join/leave information.
func DeviceListCatchup( func DeviceListCatchup(
ctx context.Context, keyAPI keyapi.SyncKeyAPI, rsAPI roomserverAPI.SyncRoomserverAPI, ctx context.Context, db storage.SharedUsers, keyAPI keyapi.SyncKeyAPI, rsAPI roomserverAPI.SyncRoomserverAPI,
userID string, res *types.Response, from, to types.StreamPosition, userID string, res *types.Response, from, to types.StreamPosition,
) (newPos types.StreamPosition, hasNew bool, err error) { ) (newPos types.StreamPosition, hasNew bool, err error) {
@ -93,7 +94,7 @@ func DeviceListCatchup(
queryRes.UserIDs = append(queryRes.UserIDs, leaveUserIDs...) queryRes.UserIDs = append(queryRes.UserIDs, leaveUserIDs...)
queryRes.UserIDs = util.UniqueStrings(queryRes.UserIDs) queryRes.UserIDs = util.UniqueStrings(queryRes.UserIDs)
var sharedUsersMap map[string]int var sharedUsersMap map[string]int
sharedUsersMap, queryRes.UserIDs = filterSharedUsers(ctx, rsAPI, userID, queryRes.UserIDs) sharedUsersMap, queryRes.UserIDs = filterSharedUsers(ctx, db, userID, queryRes.UserIDs)
util.GetLogger(ctx).Debugf( util.GetLogger(ctx).Debugf(
"QueryKeyChanges request off=%d,to=%d response off=%d uids=%v", "QueryKeyChanges request off=%d,to=%d response off=%d uids=%v",
offset, toOffset, queryRes.Offset, queryRes.UserIDs, offset, toOffset, queryRes.Offset, queryRes.UserIDs,
@ -215,30 +216,28 @@ func TrackChangedUsers(
return changed, left, nil return changed, left, nil
} }
// filterSharedUsers takes a list of remote users whose keys have changed and filters
// it down to include only users who the requesting user shares a room with.
func filterSharedUsers( func filterSharedUsers(
ctx context.Context, rsAPI roomserverAPI.SyncRoomserverAPI, userID string, usersWithChangedKeys []string, ctx context.Context, db storage.SharedUsers, userID string, usersWithChangedKeys []string,
) (map[string]int, []string) { ) (map[string]int, []string) {
var result []string sharedUsersMap := make(map[string]int, len(usersWithChangedKeys))
var sharedUsersRes roomserverAPI.QuerySharedUsersResponse for _, userID := range usersWithChangedKeys {
err := rsAPI.QuerySharedUsers(ctx, &roomserverAPI.QuerySharedUsersRequest{ sharedUsersMap[userID] = 0
UserID: userID, }
OtherUserIDs: usersWithChangedKeys, sharedUsers, err := db.SharedUsers(ctx, userID, usersWithChangedKeys)
}, &sharedUsersRes)
if err != nil { if err != nil {
// default to all users so we do needless queries rather than miss some important device update // default to all users so we do needless queries rather than miss some important device update
return nil, usersWithChangedKeys return nil, usersWithChangedKeys
} }
for _, userID := range sharedUsers {
sharedUsersMap[userID]++
}
// We forcibly put ourselves in this list because we should be notified about our own device updates // We forcibly put ourselves in this list because we should be notified about our own device updates
// and if we are in 0 rooms then we don't technically share any room with ourselves so we wouldn't // and if we are in 0 rooms then we don't technically share any room with ourselves so we wouldn't
// be notified about key changes. // be notified about key changes.
sharedUsersRes.UserIDsToCount[userID] = 1 sharedUsersMap[userID] = 1
return sharedUsersMap, sharedUsers
for _, uid := range usersWithChangedKeys {
if sharedUsersRes.UserIDsToCount[uid] > 0 {
result = append(result, uid)
}
}
return sharedUsersRes.UserIDsToCount, result
} }
func joinedRooms(res *types.Response, userID string) []string { func joinedRooms(res *types.Response, userID string) []string {

View file

@ -11,6 +11,7 @@ import (
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
) )
var ( var (
@ -105,6 +106,22 @@ func (s *mockRoomserverAPI) QuerySharedUsers(ctx context.Context, req *api.Query
return nil return nil
} }
// This is actually a database function, but seeing as we track the state inside the
// *mockRoomserverAPI, we'll just comply with the interface here instead.
func (s *mockRoomserverAPI) SharedUsers(ctx context.Context, userID string, otherUserIDs []string) ([]string, error) {
commonUsers := []string{}
for _, members := range s.roomIDToJoinedMembers {
for _, member := range members {
for _, userID := range otherUserIDs {
if member == userID {
commonUsers = append(commonUsers, userID)
}
}
}
}
return util.UniqueStrings(commonUsers), nil
}
type wantCatchup struct { type wantCatchup struct {
hasNew bool hasNew bool
changed []string changed []string
@ -178,7 +195,7 @@ func TestKeyChangeCatchupOnJoinShareNewUser(t *testing.T) {
"!another:room": {syncingUser}, "!another:room": {syncingUser},
}, },
} }
_, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken) _, hasNew, err := DeviceListCatchup(context.Background(), rsAPI, &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken)
if err != nil { if err != nil {
t.Fatalf("DeviceListCatchup returned an error: %s", err) t.Fatalf("DeviceListCatchup returned an error: %s", err)
} }
@ -201,7 +218,7 @@ func TestKeyChangeCatchupOnLeaveShareLeftUser(t *testing.T) {
"!another:room": {syncingUser}, "!another:room": {syncingUser},
}, },
} }
_, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken) _, hasNew, err := DeviceListCatchup(context.Background(), rsAPI, &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken)
if err != nil { if err != nil {
t.Fatalf("DeviceListCatchup returned an error: %s", err) t.Fatalf("DeviceListCatchup returned an error: %s", err)
} }
@ -224,7 +241,7 @@ func TestKeyChangeCatchupOnJoinShareNoNewUsers(t *testing.T) {
"!another:room": {syncingUser, existingUser}, "!another:room": {syncingUser, existingUser},
}, },
} }
_, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken) _, hasNew, err := DeviceListCatchup(context.Background(), rsAPI, &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken)
if err != nil { if err != nil {
t.Fatalf("Catchup returned an error: %s", err) t.Fatalf("Catchup returned an error: %s", err)
} }
@ -246,7 +263,7 @@ func TestKeyChangeCatchupOnLeaveShareNoUsers(t *testing.T) {
"!another:room": {syncingUser, existingUser}, "!another:room": {syncingUser, existingUser},
}, },
} }
_, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken) _, hasNew, err := DeviceListCatchup(context.Background(), rsAPI, &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken)
if err != nil { if err != nil {
t.Fatalf("DeviceListCatchup returned an error: %s", err) t.Fatalf("DeviceListCatchup returned an error: %s", err)
} }
@ -305,7 +322,7 @@ func TestKeyChangeCatchupNoNewJoinsButMessages(t *testing.T) {
roomID: {syncingUser, existingUser}, roomID: {syncingUser, existingUser},
}, },
} }
_, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken) _, hasNew, err := DeviceListCatchup(context.Background(), rsAPI, &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken)
if err != nil { if err != nil {
t.Fatalf("DeviceListCatchup returned an error: %s", err) t.Fatalf("DeviceListCatchup returned an error: %s", err)
} }
@ -333,7 +350,7 @@ func TestKeyChangeCatchupChangeAndLeft(t *testing.T) {
"!another:room": {syncingUser}, "!another:room": {syncingUser},
}, },
} }
_, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken) _, hasNew, err := DeviceListCatchup(context.Background(), rsAPI, &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken)
if err != nil { if err != nil {
t.Fatalf("Catchup returned an error: %s", err) t.Fatalf("Catchup returned an error: %s", err)
} }
@ -419,7 +436,7 @@ func TestKeyChangeCatchupChangeAndLeftSameRoom(t *testing.T) {
}, },
} }
_, hasNew, err := DeviceListCatchup( _, hasNew, err := DeviceListCatchup(
context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken, context.Background(), rsAPI, &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken,
) )
if err != nil { if err != nil {
t.Fatalf("DeviceListCatchup returned an error: %s", err) t.Fatalf("DeviceListCatchup returned an error: %s", err)

View file

@ -594,6 +594,7 @@ func (r *messagesReq) backfill(roomID string, backwardsExtremities map[string][]
[]string{}, []string{},
[]string{}, []string{},
nil, true, nil, true,
gomatrixserverlib.HistoryVisibilityShared,
) )
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -27,6 +27,8 @@ import (
type Database interface { type Database interface {
Presence Presence
SharedUsers
MaxStreamPositionForPDUs(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForPDUs(ctx context.Context) (types.StreamPosition, error)
MaxStreamPositionForReceipts(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForReceipts(ctx context.Context) (types.StreamPosition, error)
MaxStreamPositionForInvites(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForInvites(ctx context.Context) (types.StreamPosition, error)
@ -67,7 +69,9 @@ type Database interface {
// when generating the sync stream position for this event. Returns the sync stream position for the inserted event. // when generating the sync stream position for this event. Returns the sync stream position for the inserted event.
// Returns an error if there was a problem inserting this event. // Returns an error if there was a problem inserting this event.
WriteEvent(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent, addStateEvents []*gomatrixserverlib.HeaderedEvent, WriteEvent(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent, addStateEvents []*gomatrixserverlib.HeaderedEvent,
addStateEventIDs []string, removeStateEventIDs []string, transactionID *api.TransactionID, excludeFromSync bool) (types.StreamPosition, error) addStateEventIDs []string, removeStateEventIDs []string, transactionID *api.TransactionID, excludeFromSync bool,
historyVisibility gomatrixserverlib.HistoryVisibility,
) (types.StreamPosition, error)
// PurgeRoomState completely purges room state from the sync API. This is done when // PurgeRoomState completely purges room state from the sync API. This is done when
// receiving an output event that completely resets the state. // receiving an output event that completely resets the state.
PurgeRoomState(ctx context.Context, roomID string) error PurgeRoomState(ctx context.Context, roomID string) error
@ -165,3 +169,8 @@ type Presence interface {
PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error)
MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error)
} }
type SharedUsers interface {
// SharedUsers returns a subset of otherUserIDs that share a room with userID.
SharedUsers(ctx context.Context, userID string, otherUserIDs []string) ([]string, error)
}

View file

@ -23,6 +23,7 @@ import (
"github.com/lib/pq" "github.com/lib/pq"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/postgres/deltas"
"github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -51,6 +52,7 @@ CREATE TABLE IF NOT EXISTS syncapi_current_room_state (
-- The serial ID of the output_room_events table when this event became -- The serial ID of the output_room_events table when this event became
-- part of the current state of the room. -- part of the current state of the room.
added_at BIGINT, added_at BIGINT,
history_visibility SMALLINT NOT NULL DEFAULT 2,
-- Clobber based on 3-uple of room_id, type and state_key -- Clobber based on 3-uple of room_id, type and state_key
CONSTRAINT syncapi_room_state_unique UNIQUE (room_id, type, state_key) CONSTRAINT syncapi_room_state_unique UNIQUE (room_id, type, state_key)
); );
@ -63,8 +65,8 @@ CREATE UNIQUE INDEX IF NOT EXISTS syncapi_current_room_state_eventid_idx ON sync
` `
const upsertRoomStateSQL = "" + const upsertRoomStateSQL = "" +
"INSERT INTO syncapi_current_room_state (room_id, event_id, type, sender, contains_url, state_key, headered_event_json, membership, added_at)" + "INSERT INTO syncapi_current_room_state (room_id, event_id, type, sender, contains_url, state_key, headered_event_json, membership, added_at, history_visibility)" +
" VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)" + " VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)" +
" ON CONFLICT ON CONSTRAINT syncapi_room_state_unique" + " ON CONFLICT ON CONSTRAINT syncapi_room_state_unique" +
" DO UPDATE SET event_id = $2, sender=$4, contains_url=$5, headered_event_json = $7, membership = $8, added_at = $9" " DO UPDATE SET event_id = $2, sender=$4, contains_url=$5, headered_event_json = $7, membership = $8, added_at = $9"
@ -100,13 +102,18 @@ const selectStateEventSQL = "" +
"SELECT headered_event_json FROM syncapi_current_room_state WHERE room_id = $1 AND type = $2 AND state_key = $3" "SELECT headered_event_json FROM syncapi_current_room_state WHERE room_id = $1 AND type = $2 AND state_key = $3"
const selectEventsWithEventIDsSQL = "" + const selectEventsWithEventIDsSQL = "" +
// TODO: The session_id and transaction_id blanks are here because otherwise // TODO: The session_id and transaction_id blanks are here because
// the rowsToStreamEvents expects there to be exactly six columns. We need to // the rowsToStreamEvents expects there to be exactly seven columns. We need to
// figure out if these really need to be in the DB, and if so, we need a // figure out if these really need to be in the DB, and if so, we need a
// better permanent fix for this. - neilalexander, 2 Jan 2020 // better permanent fix for this. - neilalexander, 2 Jan 2020
"SELECT event_id, added_at, headered_event_json, 0 AS session_id, false AS exclude_from_sync, '' AS transaction_id" + "SELECT event_id, added_at, headered_event_json, 0 AS session_id, false AS exclude_from_sync, '' AS transaction_id, history_visibility" +
" FROM syncapi_current_room_state WHERE event_id = ANY($1)" " FROM syncapi_current_room_state WHERE event_id = ANY($1)"
const selectSharedUsersSQL = "" +
"SELECT state_key FROM syncapi_current_room_state WHERE room_id = ANY(" +
" SELECT room_id FROM syncapi_current_room_state WHERE state_key = $1 AND membership='join'" +
") AND state_key = ANY($2) AND membership='join';"
type currentRoomStateStatements struct { type currentRoomStateStatements struct {
upsertRoomStateStmt *sql.Stmt upsertRoomStateStmt *sql.Stmt
deleteRoomStateByEventIDStmt *sql.Stmt deleteRoomStateByEventIDStmt *sql.Stmt
@ -118,6 +125,7 @@ type currentRoomStateStatements struct {
selectJoinedUsersInRoomStmt *sql.Stmt selectJoinedUsersInRoomStmt *sql.Stmt
selectEventsWithEventIDsStmt *sql.Stmt selectEventsWithEventIDsStmt *sql.Stmt
selectStateEventStmt *sql.Stmt selectStateEventStmt *sql.Stmt
selectSharedUsersStmt *sql.Stmt
} }
func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) { func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) {
@ -126,6 +134,17 @@ func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, erro
if err != nil { if err != nil {
return nil, err return nil, err
} }
m := sqlutil.NewMigrator(db)
m.AddMigrations(sqlutil.Migration{
Version: "syncapi: add history visibility column (current_room_state)",
Up: deltas.UpAddHistoryVisibilityColumnCurrentRoomState,
})
err = m.Up(context.Background())
if err != nil {
return nil, err
}
if s.upsertRoomStateStmt, err = db.Prepare(upsertRoomStateSQL); err != nil { if s.upsertRoomStateStmt, err = db.Prepare(upsertRoomStateSQL); err != nil {
return nil, err return nil, err
} }
@ -156,6 +175,9 @@ func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, erro
if s.selectStateEventStmt, err = db.Prepare(selectStateEventSQL); err != nil { if s.selectStateEventStmt, err = db.Prepare(selectStateEventSQL); err != nil {
return nil, err return nil, err
} }
if s.selectSharedUsersStmt, err = db.Prepare(selectSharedUsersSQL); err != nil {
return nil, err
}
return s, nil return s, nil
} }
@ -327,6 +349,7 @@ func (s *currentRoomStateStatements) UpsertRoomState(
headeredJSON, headeredJSON,
membership, membership,
addedAt, addedAt,
event.Visibility,
) )
return err return err
} }
@ -379,3 +402,24 @@ func (s *currentRoomStateStatements) SelectStateEvent(
} }
return &ev, err return &ev, err
} }
func (s *currentRoomStateStatements) SelectSharedUsers(
ctx context.Context, txn *sql.Tx, userID string, otherUserIDs []string,
) ([]string, error) {
stmt := sqlutil.TxStmt(txn, s.selectSharedUsersStmt)
rows, err := stmt.QueryContext(ctx, userID, otherUserIDs)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectSharedUsersStmt: rows.close() failed")
var stateKey string
result := make([]string, 0, len(otherUserIDs))
for rows.Next() {
if err := rows.Scan(&stateKey); err != nil {
return nil, err
}
result = append(result, stateKey)
}
return result, rows.Err()
}

View file

@ -15,24 +15,13 @@
package deltas package deltas
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/pressly/goose"
) )
func LoadFromGoose() { func UpFixSequences(ctx context.Context, tx *sql.Tx) error {
goose.AddMigration(UpFixSequences, DownFixSequences) _, err := tx.ExecContext(ctx, `
goose.AddMigration(UpRemoveSendToDeviceSentColumn, DownRemoveSendToDeviceSentColumn)
}
func LoadFixSequences(m *sqlutil.Migrations) {
m.AddMigration(UpFixSequences, DownFixSequences)
}
func UpFixSequences(tx *sql.Tx) error {
_, err := tx.Exec(`
-- We need to delete all of the existing receipts because the indexes -- We need to delete all of the existing receipts because the indexes
-- will be wrong, and we'll get primary key violations if we try to -- will be wrong, and we'll get primary key violations if we try to
-- reuse existing stream IDs from a different sequence. -- reuse existing stream IDs from a different sequence.
@ -49,8 +38,8 @@ func UpFixSequences(tx *sql.Tx) error {
return nil return nil
} }
func DownFixSequences(tx *sql.Tx) error { func DownFixSequences(ctx context.Context, tx *sql.Tx) error {
_, err := tx.Exec(` _, err := tx.ExecContext(ctx, `
-- We need to delete all of the existing receipts because the indexes -- We need to delete all of the existing receipts because the indexes
-- will be wrong, and we'll get primary key violations if we try to -- will be wrong, and we'll get primary key violations if we try to
-- reuse existing stream IDs from a different sequence. -- reuse existing stream IDs from a different sequence.

View file

@ -15,18 +15,13 @@
package deltas package deltas
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
) )
func LoadRemoveSendToDeviceSentColumn(m *sqlutil.Migrations) { func UpRemoveSendToDeviceSentColumn(ctx context.Context, tx *sql.Tx) error {
m.AddMigration(UpRemoveSendToDeviceSentColumn, DownRemoveSendToDeviceSentColumn) _, err := tx.ExecContext(ctx, `
}
func UpRemoveSendToDeviceSentColumn(tx *sql.Tx) error {
_, err := tx.Exec(`
ALTER TABLE syncapi_send_to_device ALTER TABLE syncapi_send_to_device
DROP COLUMN IF EXISTS sent_by_token; DROP COLUMN IF EXISTS sent_by_token;
`) `)
@ -36,8 +31,8 @@ func UpRemoveSendToDeviceSentColumn(tx *sql.Tx) error {
return nil return nil
} }
func DownRemoveSendToDeviceSentColumn(tx *sql.Tx) error { func DownRemoveSendToDeviceSentColumn(ctx context.Context, tx *sql.Tx) error {
_, err := tx.Exec(` _, err := tx.ExecContext(ctx, `
ALTER TABLE syncapi_send_to_device ALTER TABLE syncapi_send_to_device
ADD COLUMN IF NOT EXISTS sent_by_token TEXT; ADD COLUMN IF NOT EXISTS sent_by_token TEXT;
`) `)

View file

@ -0,0 +1,54 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package deltas
import (
"context"
"database/sql"
"fmt"
)
func UpAddHistoryVisibilityColumnOutputRoomEvents(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, `
ALTER TABLE syncapi_output_room_events ADD COLUMN IF NOT EXISTS history_visibility SMALLINT NOT NULL DEFAULT 2;
UPDATE syncapi_output_room_events SET history_visibility = 4 WHERE type IN ('m.room.message', 'm.room.encrypted');
`)
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil
}
func UpAddHistoryVisibilityColumnCurrentRoomState(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, `
ALTER TABLE syncapi_current_room_state ADD COLUMN IF NOT EXISTS history_visibility SMALLINT NOT NULL DEFAULT 2;
UPDATE syncapi_current_room_state SET history_visibility = 4 WHERE type IN ('m.room.message', 'm.room.encrypted');
`)
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil
}
func DownAddHistoryVisibilityColumn(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, `
ALTER TABLE syncapi_output_room_events DROP COLUMN IF EXISTS history_visibility;
ALTER TABLE syncapi_current_room_state DROP COLUMN IF EXISTS history_visibility;
`)
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}
return nil
}

View file

@ -23,6 +23,7 @@ import (
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/syncapi/storage/postgres/deltas"
"github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
@ -67,7 +68,9 @@ CREATE TABLE IF NOT EXISTS syncapi_output_room_events (
-- events retrieved through backfilling that have a position in the stream -- events retrieved through backfilling that have a position in the stream
-- that relates to the moment these were retrieved rather than the moment these -- that relates to the moment these were retrieved rather than the moment these
-- were emitted. -- were emitted.
exclude_from_sync BOOL DEFAULT FALSE exclude_from_sync BOOL DEFAULT FALSE,
-- The history visibility before this event (1 - world_readable; 2 - shared; 3 - invited; 4 - joined)
history_visibility SMALLINT NOT NULL DEFAULT 2
); );
CREATE INDEX IF NOT EXISTS syncapi_output_room_events_type_idx ON syncapi_output_room_events (type); CREATE INDEX IF NOT EXISTS syncapi_output_room_events_type_idx ON syncapi_output_room_events (type);
@ -78,16 +81,16 @@ CREATE INDEX IF NOT EXISTS syncapi_output_room_events_exclude_from_sync_idx ON s
const insertEventSQL = "" + const insertEventSQL = "" +
"INSERT INTO syncapi_output_room_events (" + "INSERT INTO syncapi_output_room_events (" +
"room_id, event_id, headered_event_json, type, sender, contains_url, add_state_ids, remove_state_ids, session_id, transaction_id, exclude_from_sync" + "room_id, event_id, headered_event_json, type, sender, contains_url, add_state_ids, remove_state_ids, session_id, transaction_id, exclude_from_sync, history_visibility" +
") VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) " + ") VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) " +
"ON CONFLICT ON CONSTRAINT syncapi_event_id_idx DO UPDATE SET exclude_from_sync = (excluded.exclude_from_sync AND $11) " + "ON CONFLICT ON CONSTRAINT syncapi_event_id_idx DO UPDATE SET exclude_from_sync = (excluded.exclude_from_sync AND $11) " +
"RETURNING id" "RETURNING id"
const selectEventsSQL = "" + const selectEventsSQL = "" +
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = ANY($1)" "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id, history_visibility FROM syncapi_output_room_events WHERE event_id = ANY($1)"
const selectEventsWithFilterSQL = "" + const selectEventsWithFilterSQL = "" +
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = ANY($1)" + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id, history_visibility FROM syncapi_output_room_events WHERE event_id = ANY($1)" +
" AND ( $2::text[] IS NULL OR sender = ANY($2) )" + " AND ( $2::text[] IS NULL OR sender = ANY($2) )" +
" AND ( $3::text[] IS NULL OR NOT(sender = ANY($3)) )" + " AND ( $3::text[] IS NULL OR NOT(sender = ANY($3)) )" +
" AND ( $4::text[] IS NULL OR type LIKE ANY($4) )" + " AND ( $4::text[] IS NULL OR type LIKE ANY($4) )" +
@ -96,7 +99,7 @@ const selectEventsWithFilterSQL = "" +
" LIMIT $7" " LIMIT $7"
const selectRecentEventsSQL = "" + const selectRecentEventsSQL = "" +
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id, history_visibility FROM syncapi_output_room_events" +
" WHERE room_id = $1 AND id > $2 AND id <= $3" + " WHERE room_id = $1 AND id > $2 AND id <= $3" +
" AND ( $4::text[] IS NULL OR sender = ANY($4) )" + " AND ( $4::text[] IS NULL OR sender = ANY($4) )" +
" AND ( $5::text[] IS NULL OR NOT(sender = ANY($5)) )" + " AND ( $5::text[] IS NULL OR NOT(sender = ANY($5)) )" +
@ -105,7 +108,7 @@ const selectRecentEventsSQL = "" +
" ORDER BY id DESC LIMIT $8" " ORDER BY id DESC LIMIT $8"
const selectRecentEventsForSyncSQL = "" + const selectRecentEventsForSyncSQL = "" +
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id, history_visibility FROM syncapi_output_room_events" +
" WHERE room_id = $1 AND id > $2 AND id <= $3 AND exclude_from_sync = FALSE" + " WHERE room_id = $1 AND id > $2 AND id <= $3 AND exclude_from_sync = FALSE" +
" AND ( $4::text[] IS NULL OR sender = ANY($4) )" + " AND ( $4::text[] IS NULL OR sender = ANY($4) )" +
" AND ( $5::text[] IS NULL OR NOT(sender = ANY($5)) )" + " AND ( $5::text[] IS NULL OR NOT(sender = ANY($5)) )" +
@ -114,7 +117,7 @@ const selectRecentEventsForSyncSQL = "" +
" ORDER BY id DESC LIMIT $8" " ORDER BY id DESC LIMIT $8"
const selectEarlyEventsSQL = "" + const selectEarlyEventsSQL = "" +
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id, history_visibility FROM syncapi_output_room_events" +
" WHERE room_id = $1 AND id > $2 AND id <= $3" + " WHERE room_id = $1 AND id > $2 AND id <= $3" +
" AND ( $4::text[] IS NULL OR sender = ANY($4) )" + " AND ( $4::text[] IS NULL OR sender = ANY($4) )" +
" AND ( $5::text[] IS NULL OR NOT(sender = ANY($5)) )" + " AND ( $5::text[] IS NULL OR NOT(sender = ANY($5)) )" +
@ -130,7 +133,7 @@ const updateEventJSONSQL = "" +
// In order for us to apply the state updates correctly, rows need to be ordered in the order they were received (id). // In order for us to apply the state updates correctly, rows need to be ordered in the order they were received (id).
const selectStateInRangeSQL = "" + const selectStateInRangeSQL = "" +
"SELECT event_id, id, headered_event_json, exclude_from_sync, add_state_ids, remove_state_ids" + "SELECT event_id, id, headered_event_json, exclude_from_sync, add_state_ids, remove_state_ids, history_visibility" +
" FROM syncapi_output_room_events" + " FROM syncapi_output_room_events" +
" WHERE (id > $1 AND id <= $2) AND (add_state_ids IS NOT NULL OR remove_state_ids IS NOT NULL)" + " WHERE (id > $1 AND id <= $2) AND (add_state_ids IS NOT NULL OR remove_state_ids IS NOT NULL)" +
" AND room_id = ANY($3)" + " AND room_id = ANY($3)" +
@ -146,10 +149,10 @@ const deleteEventsForRoomSQL = "" +
"DELETE FROM syncapi_output_room_events WHERE room_id = $1" "DELETE FROM syncapi_output_room_events WHERE room_id = $1"
const selectContextEventSQL = "" + const selectContextEventSQL = "" +
"SELECT id, headered_event_json FROM syncapi_output_room_events WHERE room_id = $1 AND event_id = $2" "SELECT id, headered_event_json, history_visibility FROM syncapi_output_room_events WHERE room_id = $1 AND event_id = $2"
const selectContextBeforeEventSQL = "" + const selectContextBeforeEventSQL = "" +
"SELECT headered_event_json FROM syncapi_output_room_events WHERE room_id = $1 AND id < $2" + "SELECT headered_event_json, history_visibility FROM syncapi_output_room_events WHERE room_id = $1 AND id < $2" +
" AND ( $4::text[] IS NULL OR sender = ANY($4) )" + " AND ( $4::text[] IS NULL OR sender = ANY($4) )" +
" AND ( $5::text[] IS NULL OR NOT(sender = ANY($5)) )" + " AND ( $5::text[] IS NULL OR NOT(sender = ANY($5)) )" +
" AND ( $6::text[] IS NULL OR type LIKE ANY($6) )" + " AND ( $6::text[] IS NULL OR type LIKE ANY($6) )" +
@ -157,7 +160,7 @@ const selectContextBeforeEventSQL = "" +
" ORDER BY id DESC LIMIT $3" " ORDER BY id DESC LIMIT $3"
const selectContextAfterEventSQL = "" + const selectContextAfterEventSQL = "" +
"SELECT id, headered_event_json FROM syncapi_output_room_events WHERE room_id = $1 AND id > $2" + "SELECT id, headered_event_json, history_visibility FROM syncapi_output_room_events WHERE room_id = $1 AND id > $2" +
" AND ( $4::text[] IS NULL OR sender = ANY($4) )" + " AND ( $4::text[] IS NULL OR sender = ANY($4) )" +
" AND ( $5::text[] IS NULL OR NOT(sender = ANY($5)) )" + " AND ( $5::text[] IS NULL OR NOT(sender = ANY($5)) )" +
" AND ( $6::text[] IS NULL OR type LIKE ANY($6) )" + " AND ( $6::text[] IS NULL OR type LIKE ANY($6) )" +
@ -186,6 +189,17 @@ func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
m := sqlutil.NewMigrator(db)
m.AddMigrations(sqlutil.Migration{
Version: "syncapi: add history visibility column (output_room_events)",
Up: deltas.UpAddHistoryVisibilityColumnOutputRoomEvents,
})
err = m.Up(context.Background())
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{ return s, sqlutil.StatementList{
{&s.insertEventStmt, insertEventSQL}, {&s.insertEventStmt, insertEventSQL},
{&s.selectEventsStmt, selectEventsSQL}, {&s.selectEventsStmt, selectEventsSQL},
@ -246,14 +260,15 @@ func (s *outputRoomEventsStatements) SelectStateInRange(
for rows.Next() { for rows.Next() {
var ( var (
eventID string eventID string
streamPos types.StreamPosition streamPos types.StreamPosition
eventBytes []byte eventBytes []byte
excludeFromSync bool excludeFromSync bool
addIDs pq.StringArray addIDs pq.StringArray
delIDs pq.StringArray delIDs pq.StringArray
historyVisibility gomatrixserverlib.HistoryVisibility
) )
if err := rows.Scan(&eventID, &streamPos, &eventBytes, &excludeFromSync, &addIDs, &delIDs); err != nil { if err := rows.Scan(&eventID, &streamPos, &eventBytes, &excludeFromSync, &addIDs, &delIDs, &historyVisibility); err != nil {
return nil, nil, err return nil, nil, err
} }
// Sanity check for deleted state and whine if we see it. We don't need to do anything // Sanity check for deleted state and whine if we see it. We don't need to do anything
@ -283,6 +298,7 @@ func (s *outputRoomEventsStatements) SelectStateInRange(
needSet[id] = true needSet[id] = true
} }
stateNeeded[ev.RoomID()] = needSet stateNeeded[ev.RoomID()] = needSet
ev.Visibility = historyVisibility
eventIDToEvent[eventID] = types.StreamEvent{ eventIDToEvent[eventID] = types.StreamEvent{
HeaderedEvent: &ev, HeaderedEvent: &ev,
@ -314,7 +330,7 @@ func (s *outputRoomEventsStatements) SelectMaxEventID(
func (s *outputRoomEventsStatements) InsertEvent( func (s *outputRoomEventsStatements) InsertEvent(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
event *gomatrixserverlib.HeaderedEvent, addState, removeState []string, event *gomatrixserverlib.HeaderedEvent, addState, removeState []string,
transactionID *api.TransactionID, excludeFromSync bool, transactionID *api.TransactionID, excludeFromSync bool, historyVisibility gomatrixserverlib.HistoryVisibility,
) (streamPos types.StreamPosition, err error) { ) (streamPos types.StreamPosition, err error) {
var txnID *string var txnID *string
var sessionID *int64 var sessionID *int64
@ -351,6 +367,7 @@ func (s *outputRoomEventsStatements) InsertEvent(
sessionID, sessionID,
txnID, txnID,
excludeFromSync, excludeFromSync,
historyVisibility,
).Scan(&streamPos) ).Scan(&streamPos)
return return
} }
@ -504,13 +521,15 @@ func (s *outputRoomEventsStatements) SelectContextEvent(ctx context.Context, txn
row := sqlutil.TxStmt(txn, s.selectContextEventStmt).QueryRowContext(ctx, roomID, eventID) row := sqlutil.TxStmt(txn, s.selectContextEventStmt).QueryRowContext(ctx, roomID, eventID)
var eventAsString string var eventAsString string
if err = row.Scan(&id, &eventAsString); err != nil { var historyVisibility gomatrixserverlib.HistoryVisibility
if err = row.Scan(&id, &eventAsString, &historyVisibility); err != nil {
return 0, evt, err return 0, evt, err
} }
if err = json.Unmarshal([]byte(eventAsString), &evt); err != nil { if err = json.Unmarshal([]byte(eventAsString), &evt); err != nil {
return 0, evt, err return 0, evt, err
} }
evt.Visibility = historyVisibility
return id, evt, nil return id, evt, nil
} }
@ -532,15 +551,17 @@ func (s *outputRoomEventsStatements) SelectContextBeforeEvent(
for rows.Next() { for rows.Next() {
var ( var (
eventBytes []byte eventBytes []byte
evt *gomatrixserverlib.HeaderedEvent evt *gomatrixserverlib.HeaderedEvent
historyVisibility gomatrixserverlib.HistoryVisibility
) )
if err = rows.Scan(&eventBytes); err != nil { if err = rows.Scan(&eventBytes, &historyVisibility); err != nil {
return evts, err return evts, err
} }
if err = json.Unmarshal(eventBytes, &evt); err != nil { if err = json.Unmarshal(eventBytes, &evt); err != nil {
return evts, err return evts, err
} }
evt.Visibility = historyVisibility
evts = append(evts, evt) evts = append(evts, evt)
} }
@ -565,15 +586,17 @@ func (s *outputRoomEventsStatements) SelectContextAfterEvent(
for rows.Next() { for rows.Next() {
var ( var (
eventBytes []byte eventBytes []byte
evt *gomatrixserverlib.HeaderedEvent evt *gomatrixserverlib.HeaderedEvent
historyVisibility gomatrixserverlib.HistoryVisibility
) )
if err = rows.Scan(&lastID, &eventBytes); err != nil { if err = rows.Scan(&lastID, &eventBytes, &historyVisibility); err != nil {
return 0, evts, err return 0, evts, err
} }
if err = json.Unmarshal(eventBytes, &evt); err != nil { if err = json.Unmarshal(eventBytes, &evt); err != nil {
return 0, evts, err return 0, evts, err
} }
evt.Visibility = historyVisibility
evts = append(evts, evt) evts = append(evts, evt)
} }
@ -584,15 +607,16 @@ func rowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) {
var result []types.StreamEvent var result []types.StreamEvent
for rows.Next() { for rows.Next() {
var ( var (
eventID string eventID string
streamPos types.StreamPosition streamPos types.StreamPosition
eventBytes []byte eventBytes []byte
excludeFromSync bool excludeFromSync bool
sessionID *int64 sessionID *int64
txnID *string txnID *string
transactionID *api.TransactionID transactionID *api.TransactionID
historyVisibility gomatrixserverlib.HistoryVisibility
) )
if err := rows.Scan(&eventID, &streamPos, &eventBytes, &sessionID, &excludeFromSync, &txnID); err != nil { if err := rows.Scan(&eventID, &streamPos, &eventBytes, &sessionID, &excludeFromSync, &txnID, &historyVisibility); err != nil {
return nil, err return nil, err
} }
// TODO: Handle redacted events // TODO: Handle redacted events
@ -607,7 +631,7 @@ func rowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) {
TransactionID: *txnID, TransactionID: *txnID,
} }
} }
ev.Visibility = historyVisibility
result = append(result, types.StreamEvent{ result = append(result, types.StreamEvent{
HeaderedEvent: &ev, HeaderedEvent: &ev,
StreamPosition: streamPos, StreamPosition: streamPos,

View file

@ -23,6 +23,7 @@ import (
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/postgres/deltas"
"github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -73,6 +74,15 @@ func NewPostgresReceiptsTable(db *sql.DB) (tables.Receipts, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
m := sqlutil.NewMigrator(db)
m.AddMigrations(sqlutil.Migration{
Version: "syncapi: fix sequences",
Up: deltas.UpFixSequences,
})
err = m.Up(context.Background())
if err != nil {
return nil, err
}
r := &receiptStatements{ r := &receiptStatements{
db: db, db: db,
} }

View file

@ -21,6 +21,7 @@ import (
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/postgres/deltas"
"github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -76,6 +77,15 @@ func NewPostgresSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
m := sqlutil.NewMigrator(db)
m.AddMigrations(sqlutil.Migration{
Version: "syncapi: drop sent_by_token",
Up: deltas.UpRemoveSendToDeviceSentColumn,
})
err = m.Up(context.Background())
if err != nil {
return nil, err
}
if s.insertSendToDeviceMessageStmt, err = db.Prepare(insertSendToDeviceMessageSQL); err != nil { if s.insertSendToDeviceMessageStmt, err = db.Prepare(insertSendToDeviceMessageSQL); err != nil {
return nil, err return nil, err
} }

View file

@ -23,7 +23,6 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/syncapi/storage/postgres/deltas"
"github.com/matrix-org/dendrite/syncapi/storage/shared" "github.com/matrix-org/dendrite/syncapi/storage/shared"
) )
@ -98,12 +97,6 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions)
if err != nil { if err != nil {
return nil, err return nil, err
} }
m := sqlutil.NewMigrations()
deltas.LoadFixSequences(m)
deltas.LoadRemoveSendToDeviceSentColumn(m)
if err = m.RunDeltas(d.db, dbProperties); err != nil {
return nil, err
}
d.Database = shared.Database{ d.Database = shared.Database{
DB: d.db, DB: d.db,
Writer: d.writer, Writer: d.writer,

View file

@ -176,6 +176,10 @@ func (d *Database) AllPeekingDevicesInRooms(ctx context.Context) (map[string][]t
return d.Peeks.SelectPeekingDevices(ctx) return d.Peeks.SelectPeekingDevices(ctx)
} }
func (d *Database) SharedUsers(ctx context.Context, userID string, otherUserIDs []string) ([]string, error) {
return d.CurrentRoomState.SelectSharedUsers(ctx, nil, userID, otherUserIDs)
}
func (d *Database) GetStateEvent( func (d *Database) GetStateEvent(
ctx context.Context, roomID, evType, stateKey string, ctx context.Context, roomID, evType, stateKey string,
) (*gomatrixserverlib.HeaderedEvent, error) { ) (*gomatrixserverlib.HeaderedEvent, error) {
@ -364,11 +368,12 @@ func (d *Database) WriteEvent(
addStateEvents []*gomatrixserverlib.HeaderedEvent, addStateEvents []*gomatrixserverlib.HeaderedEvent,
addStateEventIDs, removeStateEventIDs []string, addStateEventIDs, removeStateEventIDs []string,
transactionID *api.TransactionID, excludeFromSync bool, transactionID *api.TransactionID, excludeFromSync bool,
historyVisibility gomatrixserverlib.HistoryVisibility,
) (pduPosition types.StreamPosition, returnErr error) { ) (pduPosition types.StreamPosition, returnErr error) {
returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
var err error var err error
pos, err := d.OutputEvents.InsertEvent( pos, err := d.OutputEvents.InsertEvent(
ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID, excludeFromSync, ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID, excludeFromSync, historyVisibility,
) )
if err != nil { if err != nil {
return fmt.Errorf("d.OutputEvents.InsertEvent: %w", err) return fmt.Errorf("d.OutputEvents.InsertEvent: %w", err)
@ -387,7 +392,9 @@ func (d *Database) WriteEvent(
// Nothing to do, the event may have just been a message event. // Nothing to do, the event may have just been a message event.
return nil return nil
} }
for i := range addStateEvents {
addStateEvents[i].Visibility = historyVisibility
}
return d.updateRoomState(ctx, txn, removeStateEventIDs, addStateEvents, pduPosition, topoPosition) return d.updateRoomState(ctx, txn, removeStateEventIDs, addStateEvents, pduPosition, topoPosition)
}) })

View file

@ -24,6 +24,7 @@ import (
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/sqlite3/deltas"
"github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -41,6 +42,7 @@ CREATE TABLE IF NOT EXISTS syncapi_current_room_state (
headered_event_json TEXT NOT NULL, headered_event_json TEXT NOT NULL,
membership TEXT, membership TEXT,
added_at BIGINT, added_at BIGINT,
history_visibility SMALLINT NOT NULL DEFAULT 2, -- The history visibility before this event (1 - world_readable; 2 - shared; 3 - invited; 4 - joined)
UNIQUE (room_id, type, state_key) UNIQUE (room_id, type, state_key)
); );
-- for event deletion -- for event deletion
@ -52,8 +54,8 @@ CREATE UNIQUE INDEX IF NOT EXISTS syncapi_current_room_state_eventid_idx ON sync
` `
const upsertRoomStateSQL = "" + const upsertRoomStateSQL = "" +
"INSERT INTO syncapi_current_room_state (room_id, event_id, type, sender, contains_url, state_key, headered_event_json, membership, added_at)" + "INSERT INTO syncapi_current_room_state (room_id, event_id, type, sender, contains_url, state_key, headered_event_json, membership, added_at, history_visibility)" +
" VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)" + " VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)" +
" ON CONFLICT (room_id, type, state_key)" + " ON CONFLICT (room_id, type, state_key)" +
" DO UPDATE SET event_id = $2, sender=$4, contains_url=$5, headered_event_json = $7, membership = $8, added_at = $9" " DO UPDATE SET event_id = $2, sender=$4, contains_url=$5, headered_event_json = $7, membership = $8, added_at = $9"
@ -84,13 +86,18 @@ const selectStateEventSQL = "" +
"SELECT headered_event_json FROM syncapi_current_room_state WHERE room_id = $1 AND type = $2 AND state_key = $3" "SELECT headered_event_json FROM syncapi_current_room_state WHERE room_id = $1 AND type = $2 AND state_key = $3"
const selectEventsWithEventIDsSQL = "" + const selectEventsWithEventIDsSQL = "" +
// TODO: The session_id and transaction_id blanks are here because otherwise // TODO: The session_id and transaction_id blanks are here because
// the rowsToStreamEvents expects there to be exactly six columns. We need to // the rowsToStreamEvents expects there to be exactly seven columns. We need to
// figure out if these really need to be in the DB, and if so, we need a // figure out if these really need to be in the DB, and if so, we need a
// better permanent fix for this. - neilalexander, 2 Jan 2020 // better permanent fix for this. - neilalexander, 2 Jan 2020
"SELECT event_id, added_at, headered_event_json, 0 AS session_id, false AS exclude_from_sync, '' AS transaction_id" + "SELECT event_id, added_at, headered_event_json, 0 AS session_id, false AS exclude_from_sync, '' AS transaction_id, history_visibility" +
" FROM syncapi_current_room_state WHERE event_id IN ($1)" " FROM syncapi_current_room_state WHERE event_id IN ($1)"
const selectSharedUsersSQL = "" +
"SELECT state_key FROM syncapi_current_room_state WHERE room_id = ANY(" +
" SELECT room_id FROM syncapi_current_room_state WHERE state_key = $1 AND membership='join'" +
") AND state_key IN ($2) AND membership='join';"
type currentRoomStateStatements struct { type currentRoomStateStatements struct {
db *sql.DB db *sql.DB
streamIDStatements *StreamIDStatements streamIDStatements *StreamIDStatements
@ -100,8 +107,9 @@ type currentRoomStateStatements struct {
selectRoomIDsWithMembershipStmt *sql.Stmt selectRoomIDsWithMembershipStmt *sql.Stmt
selectRoomIDsWithAnyMembershipStmt *sql.Stmt selectRoomIDsWithAnyMembershipStmt *sql.Stmt
selectJoinedUsersStmt *sql.Stmt selectJoinedUsersStmt *sql.Stmt
//selectJoinedUsersInRoomStmt *sql.Stmt - prepared at runtime due to variadic //selectJoinedUsersInRoomStmt *sql.Stmt - prepared at runtime due to variadic
selectStateEventStmt *sql.Stmt selectStateEventStmt *sql.Stmt
//selectSharedUsersSQL *sql.Stmt - prepared at runtime due to variadic
} }
func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (tables.CurrentRoomState, error) { func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (tables.CurrentRoomState, error) {
@ -113,6 +121,17 @@ func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (t
if err != nil { if err != nil {
return nil, err return nil, err
} }
m := sqlutil.NewMigrator(db)
m.AddMigrations(sqlutil.Migration{
Version: "syncapi: add history visibility column (current_room_state)",
Up: deltas.UpAddHistoryVisibilityColumnCurrentRoomState,
})
err = m.Up(context.Background())
if err != nil {
return nil, err
}
if s.upsertRoomStateStmt, err = db.Prepare(upsertRoomStateSQL); err != nil { if s.upsertRoomStateStmt, err = db.Prepare(upsertRoomStateSQL); err != nil {
return nil, err return nil, err
} }
@ -322,6 +341,7 @@ func (s *currentRoomStateStatements) UpsertRoomState(
headeredJSON, headeredJSON,
membership, membership,
addedAt, addedAt,
event.Visibility,
) )
return err return err
} }
@ -396,3 +416,29 @@ func (s *currentRoomStateStatements) SelectStateEvent(
} }
return &ev, err return &ev, err
} }
func (s *currentRoomStateStatements) SelectSharedUsers(
ctx context.Context, txn *sql.Tx, userID string, otherUserIDs []string,
) ([]string, error) {
query := strings.Replace(selectSharedUsersSQL, "($2)", sqlutil.QueryVariadicOffset(len(otherUserIDs), 1), 1)
stmt, err := s.db.Prepare(query)
if err != nil {
return nil, fmt.Errorf("SelectSharedUsers s.db.Prepare: %w", err)
}
defer internal.CloseAndLogIfError(ctx, stmt, "SelectSharedUsers: stmt.close() failed")
rows, err := sqlutil.TxStmt(txn, stmt).QueryContext(ctx, userID, otherUserIDs)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectSharedUsersStmt: rows.close() failed")
var stateKey string
result := make([]string, 0, len(otherUserIDs))
for rows.Next() {
if err := rows.Scan(&stateKey); err != nil {
return nil, err
}
result = append(result, stateKey)
}
return result, rows.Err()
}

View file

@ -15,24 +15,13 @@
package deltas package deltas
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/pressly/goose"
) )
func LoadFromGoose() { func UpFixSequences(ctx context.Context, tx *sql.Tx) error {
goose.AddMigration(UpFixSequences, DownFixSequences) _, err := tx.ExecContext(ctx, `
goose.AddMigration(UpRemoveSendToDeviceSentColumn, DownRemoveSendToDeviceSentColumn)
}
func LoadFixSequences(m *sqlutil.Migrations) {
m.AddMigration(UpFixSequences, DownFixSequences)
}
func UpFixSequences(tx *sql.Tx) error {
_, err := tx.Exec(`
-- We need to delete all of the existing receipts because the indexes -- We need to delete all of the existing receipts because the indexes
-- will be wrong, and we'll get primary key violations if we try to -- will be wrong, and we'll get primary key violations if we try to
-- reuse existing stream IDs from a different sequence. -- reuse existing stream IDs from a different sequence.
@ -45,8 +34,8 @@ func UpFixSequences(tx *sql.Tx) error {
return nil return nil
} }
func DownFixSequences(tx *sql.Tx) error { func DownFixSequences(ctx context.Context, tx *sql.Tx) error {
_, err := tx.Exec(` _, err := tx.ExecContext(ctx, `
-- We need to delete all of the existing receipts because the indexes -- We need to delete all of the existing receipts because the indexes
-- will be wrong, and we'll get primary key violations if we try to -- will be wrong, and we'll get primary key violations if we try to
-- reuse existing stream IDs from a different sequence. -- reuse existing stream IDs from a different sequence.

View file

@ -15,18 +15,13 @@
package deltas package deltas
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
) )
func LoadRemoveSendToDeviceSentColumn(m *sqlutil.Migrations) { func UpRemoveSendToDeviceSentColumn(ctx context.Context, tx *sql.Tx) error {
m.AddMigration(UpRemoveSendToDeviceSentColumn, DownRemoveSendToDeviceSentColumn) _, err := tx.ExecContext(ctx, `
}
func UpRemoveSendToDeviceSentColumn(tx *sql.Tx) error {
_, err := tx.Exec(`
CREATE TEMPORARY TABLE syncapi_send_to_device_backup(id, user_id, device_id, content); CREATE TEMPORARY TABLE syncapi_send_to_device_backup(id, user_id, device_id, content);
INSERT INTO syncapi_send_to_device_backup SELECT id, user_id, device_id, content FROM syncapi_send_to_device; INSERT INTO syncapi_send_to_device_backup SELECT id, user_id, device_id, content FROM syncapi_send_to_device;
DROP TABLE syncapi_send_to_device; DROP TABLE syncapi_send_to_device;
@ -45,8 +40,8 @@ func UpRemoveSendToDeviceSentColumn(tx *sql.Tx) error {
return nil return nil
} }
func DownRemoveSendToDeviceSentColumn(tx *sql.Tx) error { func DownRemoveSendToDeviceSentColumn(ctx context.Context, tx *sql.Tx) error {
_, err := tx.Exec(` _, err := tx.ExecContext(ctx, `
CREATE TEMPORARY TABLE syncapi_send_to_device_backup(id, user_id, device_id, content); CREATE TEMPORARY TABLE syncapi_send_to_device_backup(id, user_id, device_id, content);
INSERT INTO syncapi_send_to_device_backup SELECT id, user_id, device_id, content FROM syncapi_send_to_device; INSERT INTO syncapi_send_to_device_backup SELECT id, user_id, device_id, content FROM syncapi_send_to_device;
DROP TABLE syncapi_send_to_device; DROP TABLE syncapi_send_to_device;

View file

@ -0,0 +1,82 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package deltas
import (
"context"
"database/sql"
"fmt"
)
func UpAddHistoryVisibilityColumnOutputRoomEvents(ctx context.Context, tx *sql.Tx) error {
// SQLite doesn't have "if exists", so check if the column exists. If the query doesn't return an error, it already exists.
// Required for unit tests, as otherwise a duplicate column error will show up.
_, err := tx.QueryContext(ctx, "SELECT history_visibility FROM syncapi_output_room_events LIMIT 1")
if err == nil {
return nil
}
_, err = tx.ExecContext(ctx, `
ALTER TABLE syncapi_output_room_events ADD COLUMN history_visibility SMALLINT NOT NULL DEFAULT 2;
UPDATE syncapi_output_room_events SET history_visibility = 4 WHERE type IN ('m.room.message', 'm.room.encrypted');
`)
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil
}
func UpAddHistoryVisibilityColumnCurrentRoomState(ctx context.Context, tx *sql.Tx) error {
// SQLite doesn't have "if exists", so check if the column exists. If the query doesn't return an error, it already exists.
// Required for unit tests, as otherwise a duplicate column error will show up.
_, err := tx.QueryContext(ctx, "SELECT history_visibility FROM syncapi_current_room_state LIMIT 1")
if err == nil {
return nil
}
_, err = tx.ExecContext(ctx, `
ALTER TABLE syncapi_current_room_state ADD COLUMN history_visibility SMALLINT NOT NULL DEFAULT 2;
UPDATE syncapi_current_room_state SET history_visibility = 4 WHERE type IN ('m.room.message', 'm.room.encrypted');
`)
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil
}
func DownAddHistoryVisibilityColumn(ctx context.Context, tx *sql.Tx) error {
// SQLite doesn't have "if exists", so check if the column exists.
_, err := tx.QueryContext(ctx, "SELECT history_visibility FROM syncapi_output_room_events LIMIT 1")
if err != nil {
// The column probably doesn't exist
return nil
}
_, err = tx.ExecContext(ctx, `
ALTER TABLE syncapi_output_room_events DROP COLUMN history_visibility;
`)
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}
_, err = tx.QueryContext(ctx, "SELECT history_visibility FROM syncapi_current_room_state LIMIT 1")
if err != nil {
// The column probably doesn't exist
return nil
}
_, err = tx.ExecContext(ctx, `
ALTER TABLE syncapi_current_room_state DROP COLUMN history_visibility;
`)
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}
return nil
}

View file

@ -25,6 +25,7 @@ import (
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/syncapi/storage/sqlite3/deltas"
"github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
@ -47,7 +48,8 @@ CREATE TABLE IF NOT EXISTS syncapi_output_room_events (
remove_state_ids TEXT, -- JSON encoded string array remove_state_ids TEXT, -- JSON encoded string array
session_id BIGINT, session_id BIGINT,
transaction_id TEXT, transaction_id TEXT,
exclude_from_sync BOOL NOT NULL DEFAULT FALSE exclude_from_sync BOOL NOT NULL DEFAULT FALSE,
history_visibility SMALLINT NOT NULL DEFAULT 2 -- The history visibility before this event (1 - world_readable; 2 - shared; 3 - invited; 4 - joined)
); );
CREATE INDEX IF NOT EXISTS syncapi_output_room_events_type_idx ON syncapi_output_room_events (type); CREATE INDEX IF NOT EXISTS syncapi_output_room_events_type_idx ON syncapi_output_room_events (type);
@ -58,27 +60,27 @@ CREATE INDEX IF NOT EXISTS syncapi_output_room_events_exclude_from_sync_idx ON s
const insertEventSQL = "" + const insertEventSQL = "" +
"INSERT INTO syncapi_output_room_events (" + "INSERT INTO syncapi_output_room_events (" +
"id, room_id, event_id, headered_event_json, type, sender, contains_url, add_state_ids, remove_state_ids, session_id, transaction_id, exclude_from_sync" + "id, room_id, event_id, headered_event_json, type, sender, contains_url, add_state_ids, remove_state_ids, session_id, transaction_id, exclude_from_sync, history_visibility" +
") VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) " + ") VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) " +
"ON CONFLICT (event_id) DO UPDATE SET exclude_from_sync = (excluded.exclude_from_sync AND $13)" "ON CONFLICT (event_id) DO UPDATE SET exclude_from_sync = (excluded.exclude_from_sync AND $14)"
const selectEventsSQL = "" + const selectEventsSQL = "" +
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id IN ($1)" "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id, history_visibility FROM syncapi_output_room_events WHERE event_id IN ($1)"
const selectRecentEventsSQL = "" + const selectRecentEventsSQL = "" +
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id, history_visibility FROM syncapi_output_room_events" +
" WHERE room_id = $1 AND id > $2 AND id <= $3" " WHERE room_id = $1 AND id > $2 AND id <= $3"
// WHEN, ORDER BY and LIMIT are appended by prepareWithFilters // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters
const selectRecentEventsForSyncSQL = "" + const selectRecentEventsForSyncSQL = "" +
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id, history_visibility FROM syncapi_output_room_events" +
" WHERE room_id = $1 AND id > $2 AND id <= $3 AND exclude_from_sync = FALSE" " WHERE room_id = $1 AND id > $2 AND id <= $3 AND exclude_from_sync = FALSE"
// WHEN, ORDER BY and LIMIT are appended by prepareWithFilters // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters
const selectEarlyEventsSQL = "" + const selectEarlyEventsSQL = "" +
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id, history_visibility FROM syncapi_output_room_events" +
" WHERE room_id = $1 AND id > $2 AND id <= $3" " WHERE room_id = $1 AND id > $2 AND id <= $3"
// WHEN, ORDER BY and LIMIT are appended by prepareWithFilters // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters
@ -90,7 +92,7 @@ const updateEventJSONSQL = "" +
"UPDATE syncapi_output_room_events SET headered_event_json=$1 WHERE event_id=$2" "UPDATE syncapi_output_room_events SET headered_event_json=$1 WHERE event_id=$2"
const selectStateInRangeSQL = "" + const selectStateInRangeSQL = "" +
"SELECT event_id, id, headered_event_json, exclude_from_sync, add_state_ids, remove_state_ids" + "SELECT event_id, id, headered_event_json, exclude_from_sync, add_state_ids, remove_state_ids, history_visibility" +
" FROM syncapi_output_room_events" + " FROM syncapi_output_room_events" +
" WHERE (id > $1 AND id <= $2)" + " WHERE (id > $1 AND id <= $2)" +
" AND room_id IN ($3)" + " AND room_id IN ($3)" +
@ -102,15 +104,15 @@ const deleteEventsForRoomSQL = "" +
"DELETE FROM syncapi_output_room_events WHERE room_id = $1" "DELETE FROM syncapi_output_room_events WHERE room_id = $1"
const selectContextEventSQL = "" + const selectContextEventSQL = "" +
"SELECT id, headered_event_json FROM syncapi_output_room_events WHERE room_id = $1 AND event_id = $2" "SELECT id, headered_event_json, history_visibility FROM syncapi_output_room_events WHERE room_id = $1 AND event_id = $2"
const selectContextBeforeEventSQL = "" + const selectContextBeforeEventSQL = "" +
"SELECT headered_event_json FROM syncapi_output_room_events WHERE room_id = $1 AND id < $2" "SELECT headered_event_json, history_visibility FROM syncapi_output_room_events WHERE room_id = $1 AND id < $2"
// WHEN, ORDER BY and LIMIT are appended by prepareWithFilters // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters
const selectContextAfterEventSQL = "" + const selectContextAfterEventSQL = "" +
"SELECT id, headered_event_json FROM syncapi_output_room_events WHERE room_id = $1 AND id > $2" "SELECT id, headered_event_json, history_visibility FROM syncapi_output_room_events WHERE room_id = $1 AND id > $2"
// WHEN, ORDER BY and LIMIT are appended by prepareWithFilters // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters
@ -135,6 +137,17 @@ func NewSqliteEventsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Even
if err != nil { if err != nil {
return nil, err return nil, err
} }
m := sqlutil.NewMigrator(db)
m.AddMigrations(sqlutil.Migration{
Version: "syncapi: add history visibility column (output_room_events)",
Up: deltas.UpAddHistoryVisibilityColumnOutputRoomEvents,
})
err = m.Up(context.Background())
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{ return s, sqlutil.StatementList{
{&s.insertEventStmt, insertEventSQL}, {&s.insertEventStmt, insertEventSQL},
{&s.selectMaxEventIDStmt, selectMaxEventIDSQL}, {&s.selectMaxEventIDStmt, selectMaxEventIDSQL},
@ -196,14 +209,15 @@ func (s *outputRoomEventsStatements) SelectStateInRange(
for rows.Next() { for rows.Next() {
var ( var (
eventID string eventID string
streamPos types.StreamPosition streamPos types.StreamPosition
eventBytes []byte eventBytes []byte
excludeFromSync bool excludeFromSync bool
addIDsJSON string addIDsJSON string
delIDsJSON string delIDsJSON string
historyVisibility gomatrixserverlib.HistoryVisibility
) )
if err := rows.Scan(&eventID, &streamPos, &eventBytes, &excludeFromSync, &addIDsJSON, &delIDsJSON); err != nil { if err := rows.Scan(&eventID, &streamPos, &eventBytes, &excludeFromSync, &addIDsJSON, &delIDsJSON, &historyVisibility); err != nil {
return nil, nil, err return nil, nil, err
} }
@ -239,6 +253,7 @@ func (s *outputRoomEventsStatements) SelectStateInRange(
needSet[id] = true needSet[id] = true
} }
stateNeeded[ev.RoomID()] = needSet stateNeeded[ev.RoomID()] = needSet
ev.Visibility = historyVisibility
eventIDToEvent[eventID] = types.StreamEvent{ eventIDToEvent[eventID] = types.StreamEvent{
HeaderedEvent: &ev, HeaderedEvent: &ev,
@ -270,7 +285,7 @@ func (s *outputRoomEventsStatements) SelectMaxEventID(
func (s *outputRoomEventsStatements) InsertEvent( func (s *outputRoomEventsStatements) InsertEvent(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
event *gomatrixserverlib.HeaderedEvent, addState, removeState []string, event *gomatrixserverlib.HeaderedEvent, addState, removeState []string,
transactionID *api.TransactionID, excludeFromSync bool, transactionID *api.TransactionID, excludeFromSync bool, historyVisibility gomatrixserverlib.HistoryVisibility,
) (types.StreamPosition, error) { ) (types.StreamPosition, error) {
var txnID *string var txnID *string
var sessionID *int64 var sessionID *int64
@ -326,6 +341,7 @@ func (s *outputRoomEventsStatements) InsertEvent(
sessionID, sessionID,
txnID, txnID,
excludeFromSync, excludeFromSync,
historyVisibility,
excludeFromSync, excludeFromSync,
) )
return streamPos, err return streamPos, err
@ -481,15 +497,16 @@ func rowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) {
var result []types.StreamEvent var result []types.StreamEvent
for rows.Next() { for rows.Next() {
var ( var (
eventID string eventID string
streamPos types.StreamPosition streamPos types.StreamPosition
eventBytes []byte eventBytes []byte
excludeFromSync bool excludeFromSync bool
sessionID *int64 sessionID *int64
txnID *string txnID *string
transactionID *api.TransactionID transactionID *api.TransactionID
historyVisibility gomatrixserverlib.HistoryVisibility
) )
if err := rows.Scan(&eventID, &streamPos, &eventBytes, &sessionID, &excludeFromSync, &txnID); err != nil { if err := rows.Scan(&eventID, &streamPos, &eventBytes, &sessionID, &excludeFromSync, &txnID, &historyVisibility); err != nil {
return nil, err return nil, err
} }
// TODO: Handle redacted events // TODO: Handle redacted events
@ -505,6 +522,8 @@ func rowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) {
} }
} }
ev.Visibility = historyVisibility
result = append(result, types.StreamEvent{ result = append(result, types.StreamEvent{
HeaderedEvent: &ev, HeaderedEvent: &ev,
StreamPosition: streamPos, StreamPosition: streamPos,
@ -519,13 +538,15 @@ func (s *outputRoomEventsStatements) SelectContextEvent(
) (id int, evt gomatrixserverlib.HeaderedEvent, err error) { ) (id int, evt gomatrixserverlib.HeaderedEvent, err error) {
row := sqlutil.TxStmt(txn, s.selectContextEventStmt).QueryRowContext(ctx, roomID, eventID) row := sqlutil.TxStmt(txn, s.selectContextEventStmt).QueryRowContext(ctx, roomID, eventID)
var eventAsString string var eventAsString string
if err = row.Scan(&id, &eventAsString); err != nil { var historyVisibility gomatrixserverlib.HistoryVisibility
if err = row.Scan(&id, &eventAsString, &historyVisibility); err != nil {
return 0, evt, err return 0, evt, err
} }
if err = json.Unmarshal([]byte(eventAsString), &evt); err != nil { if err = json.Unmarshal([]byte(eventAsString), &evt); err != nil {
return 0, evt, err return 0, evt, err
} }
evt.Visibility = historyVisibility
return id, evt, nil return id, evt, nil
} }
@ -550,15 +571,17 @@ func (s *outputRoomEventsStatements) SelectContextBeforeEvent(
for rows.Next() { for rows.Next() {
var ( var (
eventBytes []byte eventBytes []byte
evt *gomatrixserverlib.HeaderedEvent evt *gomatrixserverlib.HeaderedEvent
historyVisibility gomatrixserverlib.HistoryVisibility
) )
if err = rows.Scan(&eventBytes); err != nil { if err = rows.Scan(&eventBytes, &historyVisibility); err != nil {
return evts, err return evts, err
} }
if err = json.Unmarshal(eventBytes, &evt); err != nil { if err = json.Unmarshal(eventBytes, &evt); err != nil {
return evts, err return evts, err
} }
evt.Visibility = historyVisibility
evts = append(evts, evt) evts = append(evts, evt)
} }
@ -586,15 +609,17 @@ func (s *outputRoomEventsStatements) SelectContextAfterEvent(
for rows.Next() { for rows.Next() {
var ( var (
eventBytes []byte eventBytes []byte
evt *gomatrixserverlib.HeaderedEvent evt *gomatrixserverlib.HeaderedEvent
historyVisibility gomatrixserverlib.HistoryVisibility
) )
if err = rows.Scan(&lastID, &eventBytes); err != nil { if err = rows.Scan(&lastID, &eventBytes, &historyVisibility); err != nil {
return 0, evts, err return 0, evts, err
} }
if err = json.Unmarshal(eventBytes, &evt); err != nil { if err = json.Unmarshal(eventBytes, &evt); err != nil {
return 0, evts, err return 0, evts, err
} }
evt.Visibility = historyVisibility
evts = append(evts, evt) evts = append(evts, evt)
} }
return lastID, evts, rows.Err() return lastID, evts, rows.Err()

View file

@ -22,6 +22,7 @@ import (
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/sqlite3/deltas"
"github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -70,6 +71,15 @@ func NewSqliteReceiptsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Re
if err != nil { if err != nil {
return nil, err return nil, err
} }
m := sqlutil.NewMigrator(db)
m.AddMigrations(sqlutil.Migration{
Version: "syncapi: fix sequences",
Up: deltas.UpFixSequences,
})
err = m.Up(context.Background())
if err != nil {
return nil, err
}
r := &receiptStatements{ r := &receiptStatements{
db: db, db: db,
streamIDStatements: streamID, streamIDStatements: streamID,

View file

@ -21,6 +21,7 @@ import (
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/sqlite3/deltas"
"github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -76,6 +77,15 @@ func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
m := sqlutil.NewMigrator(db)
m.AddMigrations(sqlutil.Migration{
Version: "syncapi: drop sent_by_token",
Up: deltas.UpRemoveSendToDeviceSentColumn,
})
err = m.Up(context.Background())
if err != nil {
return nil, err
}
if s.insertSendToDeviceMessageStmt, err = db.Prepare(insertSendToDeviceMessageSQL); err != nil { if s.insertSendToDeviceMessageStmt, err = db.Prepare(insertSendToDeviceMessageSQL); err != nil {
return nil, err return nil, err
} }

View file

@ -22,7 +22,6 @@ import (
"github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/syncapi/storage/shared" "github.com/matrix-org/dendrite/syncapi/storage/shared"
"github.com/matrix-org/dendrite/syncapi/storage/sqlite3/deltas"
) )
// SyncServerDatasource represents a sync server datasource which manages // SyncServerDatasource represents a sync server datasource which manages
@ -42,13 +41,13 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions)
if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()); err != nil { if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()); err != nil {
return nil, err return nil, err
} }
if err = d.prepare(dbProperties); err != nil { if err = d.prepare(); err != nil {
return nil, err return nil, err
} }
return &d, nil return &d, nil
} }
func (d *SyncServerDatasource) prepare(dbProperties *config.DatabaseOptions) (err error) { func (d *SyncServerDatasource) prepare() (err error) {
if err = d.streamID.Prepare(d.db); err != nil { if err = d.streamID.Prepare(d.db); err != nil {
return err return err
} }
@ -108,12 +107,6 @@ func (d *SyncServerDatasource) prepare(dbProperties *config.DatabaseOptions) (er
if err != nil { if err != nil {
return err return err
} }
m := sqlutil.NewMigrations()
deltas.LoadFixSequences(m)
deltas.LoadRemoveSendToDeviceSentColumn(m)
if err = m.RunDeltas(d.db, dbProperties); err != nil {
return err
}
d.Database = shared.Database{ d.Database = shared.Database{
DB: d.db, DB: d.db,
Writer: d.writer, Writer: d.writer,

View file

@ -37,7 +37,7 @@ func MustWriteEvents(t *testing.T, db storage.Database, events []*gomatrixserver
addStateEvents = append(addStateEvents, ev) addStateEvents = append(addStateEvents, ev)
addStateEventIDs = append(addStateEventIDs, ev.EventID()) addStateEventIDs = append(addStateEventIDs, ev.EventID())
} }
pos, err := db.WriteEvent(ctx, ev, addStateEvents, addStateEventIDs, removeStateEventIDs, nil, false) pos, err := db.WriteEvent(ctx, ev, addStateEvents, addStateEventIDs, removeStateEventIDs, nil, false, gomatrixserverlib.HistoryVisibilityShared)
if err != nil { if err != nil {
t.Fatalf("WriteEvent failed: %s", err) t.Fatalf("WriteEvent failed: %s", err)
} }

View file

@ -52,7 +52,14 @@ type Peeks interface {
type Events interface { type Events interface {
SelectStateInRange(ctx context.Context, txn *sql.Tx, r types.Range, stateFilter *gomatrixserverlib.StateFilter, roomIDs []string) (map[string]map[string]bool, map[string]types.StreamEvent, error) SelectStateInRange(ctx context.Context, txn *sql.Tx, r types.Range, stateFilter *gomatrixserverlib.StateFilter, roomIDs []string) (map[string]map[string]bool, map[string]types.StreamEvent, error)
SelectMaxEventID(ctx context.Context, txn *sql.Tx) (id int64, err error) SelectMaxEventID(ctx context.Context, txn *sql.Tx) (id int64, err error)
InsertEvent(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, addState, removeState []string, transactionID *api.TransactionID, excludeFromSync bool) (streamPos types.StreamPosition, err error) InsertEvent(
ctx context.Context, txn *sql.Tx,
event *gomatrixserverlib.HeaderedEvent,
addState, removeState []string,
transactionID *api.TransactionID,
excludeFromSync bool,
historyVisibility gomatrixserverlib.HistoryVisibility,
) (streamPos types.StreamPosition, err error)
// SelectRecentEvents returns events between the two stream positions: exclusive of low and inclusive of high. // SelectRecentEvents returns events between the two stream positions: exclusive of low and inclusive of high.
// If onlySyncEvents has a value of true, only returns the events that aren't marked as to exclude from sync. // If onlySyncEvents has a value of true, only returns the events that aren't marked as to exclude from sync.
// Returns up to `limit` events. Returns `limited=true` if there are more events in this range but we hit the `limit`. // Returns up to `limit` events. Returns `limited=true` if there are more events in this range but we hit the `limit`.
@ -104,6 +111,8 @@ type CurrentRoomState interface {
SelectJoinedUsers(ctx context.Context) (map[string][]string, error) SelectJoinedUsers(ctx context.Context) (map[string][]string, error)
// SelectJoinedUsersInRoom returns a map of room ID to a list of joined user IDs for a given room. // SelectJoinedUsersInRoom returns a map of room ID to a list of joined user IDs for a given room.
SelectJoinedUsersInRoom(ctx context.Context, roomIDs []string) (map[string][]string, error) SelectJoinedUsersInRoom(ctx context.Context, roomIDs []string) (map[string][]string, error)
// SelectSharedUsers returns a subset of otherUserIDs that share a room with userID.
SelectSharedUsers(ctx context.Context, txn *sql.Tx, userID string, otherUserIDs []string) ([]string, error)
} }
// BackwardsExtremities keeps track of backwards extremities for a room. // BackwardsExtremities keeps track of backwards extremities for a room.

View file

@ -53,7 +53,7 @@ func TestOutputRoomEventsTable(t *testing.T) {
events := room.Events() events := room.Events()
err := sqlutil.WithTransaction(db, func(txn *sql.Tx) error { err := sqlutil.WithTransaction(db, func(txn *sql.Tx) error {
for _, ev := range events { for _, ev := range events {
_, err := tab.InsertEvent(ctx, txn, ev, nil, nil, nil, false) _, err := tab.InsertEvent(ctx, txn, ev, nil, nil, nil, false, gomatrixserverlib.HistoryVisibilityShared)
if err != nil { if err != nil {
return fmt.Errorf("failed to InsertEvent: %s", err) return fmt.Errorf("failed to InsertEvent: %s", err)
} }
@ -79,7 +79,7 @@ func TestOutputRoomEventsTable(t *testing.T) {
"body": "test.txt", "body": "test.txt",
"url": "mxc://test.txt", "url": "mxc://test.txt",
}) })
if _, err = tab.InsertEvent(ctx, txn, urlEv, nil, nil, nil, false); err != nil { if _, err = tab.InsertEvent(ctx, txn, urlEv, nil, nil, nil, false, gomatrixserverlib.HistoryVisibilityShared); err != nil {
return fmt.Errorf("failed to InsertEvent: %s", err) return fmt.Errorf("failed to InsertEvent: %s", err)
} }
wantEventID := []string{urlEv.EventID()} wantEventID := []string{urlEv.EventID()}

View file

@ -28,7 +28,7 @@ func (p *DeviceListStreamProvider) IncrementalSync(
from, to types.StreamPosition, from, to types.StreamPosition,
) types.StreamPosition { ) types.StreamPosition {
var err error var err error
to, _, err = internal.DeviceListCatchup(context.Background(), p.keyAPI, p.rsAPI, req.Device.UserID, req.Response, from, to) to, _, err = internal.DeviceListCatchup(context.Background(), p.DB, p.keyAPI, p.rsAPI, req.Device.UserID, req.Response, from, to)
if err != nil { if err != nil {
req.Log.WithError(err).Error("internal.DeviceListCatchup failed") req.Log.WithError(err).Error("internal.DeviceListCatchup failed")
return from return from

View file

@ -429,7 +429,7 @@ func (rp *RequestPool) OnIncomingKeyChangeRequest(req *http.Request, device *use
} }
rp.streams.PDUStreamProvider.IncrementalSync(req.Context(), syncReq, fromToken.PDUPosition, toToken.PDUPosition) rp.streams.PDUStreamProvider.IncrementalSync(req.Context(), syncReq, fromToken.PDUPosition, toToken.PDUPosition)
_, _, err = internal.DeviceListCatchup( _, _, err = internal.DeviceListCatchup(
req.Context(), rp.keyAPI, rp.rsAPI, syncReq.Device.UserID, req.Context(), rp.db, rp.keyAPI, rp.rsAPI, syncReq.Device.UserID,
syncReq.Response, fromToken.DeviceListPosition, toToken.DeviceListPosition, syncReq.Response, fromToken.DeviceListPosition, toToken.DeviceListPosition,
) )
if err != nil { if err != nil {

View file

@ -24,6 +24,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/postgres/deltas"
"github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/storage/tables"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -85,6 +86,23 @@ func NewPostgresAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerNam
if err != nil { if err != nil {
return nil, err return nil, err
} }
m := sqlutil.NewMigrator(db)
m.AddMigrations([]sqlutil.Migration{
{
Version: "userapi: add is active",
Up: deltas.UpIsActive,
Down: deltas.DownIsActive,
},
{
Version: "userapi: add account type",
Up: deltas.UpAddAccountType,
Down: deltas.DownAddAccountType,
},
}...)
err = m.Up(context.Background())
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{ return s, sqlutil.StatementList{
{&s.insertAccountStmt, insertAccountSQL}, {&s.insertAccountStmt, insertAccountSQL},
{&s.updatePasswordStmt, updatePasswordSQL}, {&s.updatePasswordStmt, updatePasswordSQL},

View file

@ -1,33 +1,21 @@
package deltas package deltas
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/pressly/goose"
"github.com/matrix-org/dendrite/internal/sqlutil"
) )
func LoadFromGoose() { func UpIsActive(ctx context.Context, tx *sql.Tx) error {
goose.AddMigration(UpIsActive, DownIsActive) _, err := tx.ExecContext(ctx, "ALTER TABLE account_accounts ADD COLUMN IF NOT EXISTS is_deactivated BOOLEAN DEFAULT FALSE;")
goose.AddMigration(UpAddAccountType, DownAddAccountType)
}
func LoadIsActive(m *sqlutil.Migrations) {
m.AddMigration(UpIsActive, DownIsActive)
}
func UpIsActive(tx *sql.Tx) error {
_, err := tx.Exec("ALTER TABLE account_accounts ADD COLUMN IF NOT EXISTS is_deactivated BOOLEAN DEFAULT FALSE;")
if err != nil { if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err) return fmt.Errorf("failed to execute upgrade: %w", err)
} }
return nil return nil
} }
func DownIsActive(tx *sql.Tx) error { func DownIsActive(ctx context.Context, tx *sql.Tx) error {
_, err := tx.Exec("ALTER TABLE account_accounts DROP COLUMN is_deactivated;") _, err := tx.ExecContext(ctx, "ALTER TABLE account_accounts DROP COLUMN is_deactivated;")
if err != nil { if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err) return fmt.Errorf("failed to execute downgrade: %w", err)
} }

Some files were not shown because too many files have changed in this diff Show more