Merge branch 'master' of github.com:matrix-org/dendrite into sendJoin_response

This commit is contained in:
Andrew Morgan 2019-09-30 16:50:08 +01:00
commit 2aa42788ee
96 changed files with 2587 additions and 853 deletions

View file

@ -1,49 +0,0 @@
steps:
- command:
# https://github.com/golangci/golangci-lint#memory-usage-of-golangci-lint
- "GOGC=20 ./scripts/find-lint.sh"
label: "\U0001F9F9 Lint / :go: 1.12"
agents:
# Use a larger instance as linting takes a looot of memory
queue: "medium"
plugins:
- docker#v3.0.1:
image: "golang:1.12"
- wait
- command:
- "go build ./cmd/..."
label: "\U0001F528 Build / :go: 1.11"
plugins:
- docker#v3.0.1:
image: "golang:1.11"
retry:
automatic:
- exit_status: 128
limit: 3
- command:
- "go build ./cmd/..."
label: "\U0001F528 Build / :go: 1.12"
plugins:
- docker#v3.0.1:
image: "golang:1.12"
retry:
automatic:
- exit_status: 128
limit: 3
- command:
- "go test ./..."
label: "\U0001F9EA Unit tests / :go: 1.11"
plugins:
- docker#v3.0.1:
image: "golang:1.11"
- command:
- "go test ./..."
label: "\U0001F9EA Unit tests / :go: 1.12"
plugins:
- docker#v3.0.1:
image: "golang:1.12"

View file

@ -1,32 +0,0 @@
version: 2
jobs:
dendrite:
docker:
- image: matrixdotorg/sytest-dendrite
working_directory: /src
steps:
- checkout
# Set up dendrite
- run:
name: Build Dendrite
command: ./build.sh
- run:
name: Copy dummy keys to root
command: |
mv .circleci/matrix_key.pem .
mv .circleci/server.key .
- run:
name: Run sytest with whitelisted tests
command: /dendrite_sytest.sh
- store_artifacts:
path: /logs
destination: logs
- store_test_results:
path: /logs
workflows:
version: 2
build:
jobs:
- dendrite

View file

@ -1,5 +0,0 @@
-----BEGIN MATRIX PRIVATE KEY-----
Key-ID: ed25519:zXtB
jDyHsx0EXbAfvM32yBEKQfIy1FHrmwtB1uMAbm5INBg=
-----END MATRIX PRIVATE KEY-----

View file

@ -1,52 +0,0 @@
-----BEGIN PRIVATE KEY-----
MIIJQwIBADANBgkqhkiG9w0BAQEFAASCCS0wggkpAgEAAoICAQCanRCqP11MLIQh
nC26+A1oyBsFfH7auZ3pqE/WFDrCCIoc7ek7cF3fZU7q8OYI+Q9L5V8fobuLb6FB
iXD5zZ6pBAI0VNjAS8yi8VluXIv6pJKsVY3k2hGiU7xRoEhkzckZBaEiruspQbcX
ziNoWoueVBB1a4Eproqzy225cTcoprHsJIPXj0HpW/jKcmahmlM/OrqRAxTwxpb/
moI6MWIeN4n7h55N6dU1ScVvBS7gZpZQ28d8akuvG3m8kE8q1OPFYGvrNeowD4sp
qDPFijhbygwpzDQlAWriPcqV9KhuGRnYRGTGvuluOttmpgNhNFVxVAlwZJuMVAMU
Jhek66ntKsxWkF5LsO8ls20hmHyyAsL7+rb2ZjuRtEwE8SwOstU2AIIXoSTtqXjX
zC8Ew0VB9MCInJoJC/+iKTLoDqXRZeDKGFx1A2F3Y+Er+Z41HcwgqKRsPqZ066yR
6iKAb5rzJutnEARtbSrNipy9nHE5hIgKJzgOnggcegypcAj3nqbfFFCZA2CFNXoG
XFkmBHEpz38pPLI5z6HpeZRRySoIyahk9IfSwM3aB1aUi//8CcpAodGvYGNQkQ3W
HkrZmM4MtC25I5RyMpYJQWKFpx1cOVPf2ASqaJ+IX1JJTv9dSdYHY/rxsxaiXiry
+uI7UITRvUKgAOrExfSAXco73bgUFwIDAQABAoICAQCP9QX7PhxEPH6aPKxnlWYG
1aozJYOHa6QYVlpfXV6IIyNVZD7w1OLSiaU9IydL23nelKZI8XGJllpyhuHl9Qlx
HQZga0+VW/4hCM7X7tt2d50JUG9ZUaFxnr2M0swU73X6Ej/B51OVilZLl+dn1kaB
GIxqh7ovcRA774EuVLei5fJriGQpZH1eJgAznujoNqSkDq5/Lntk48LcIqR2Qly0
/ck/pTpEGSAnCZUGlbDbxyjWCIxozx/A3rguVb8ghi+9KtXQntZ6AT71fmMV3mgz
LqC8miFDA1rdY+MoVDAusrhZoPSkCEWYGL0HijNDYlLbvf874rDhq6diL0V8jOAd
PGOx5BY6VUWbSQAUtKpMuNSL6tidkOACGPwbuH7OIaG+yGZ0/Oiy3fureiAEg5VU
piyp6F7p1g0vgQEnj4CHiCQlX48bjC/mm8758DeaH8H5T++A8MOgRhgFVb9f01R+
NMzszMziuVNDYe01cwdY1TXUx5b0o+opsbPm6sNp/7afL9Hou1epP9zQC0I8ulfP
fgrKTddMwlNjoBuDMQ8GqoK275YU4wtyhUMfjr3xQ0JwP46cZbhhc4nh6qcRSNTf
yVuKv/pT/bJcSmg5JOCS8qdK0BQhAvUin9HvgSAV9QmZVpxzT/xhqwuRlLDKW+VR
XyPt996f3L4CTXI9h88AQQKCAQEAycBChu3/ZKl8a90anOlv9PwmaaXfLBKH9Rkw
aeZrMilxTJAb+LEsmtj35rF5KPeBP6ARpX5gmvKJVzCDHT9YgNs+6C3E+l2f1/3a
TcjZKPTukT2gJdCgejhEgTzAwEse322GSptuyidtNpY7NgbAxP4VdDMOmPYbzufb
5BqxmfiGsfXgdvQkj8/MzHuGhhft4SU6ED/Ax+EPUWVV7kBr2995kGDF5z5CuJkb
SJjmVxAJZP/kC2Z/iPnP51G0hiCxHp7+gPY4mvvkHvhJGnGH/vutjRjoe28BENlP
MgB68S1/U3NGSUzWv86pT1OdHd+qynWj/NzF7Gp/T/ju8VZBXwKCAQEAxDAMSOfF
dizsU7cJbf6vxi6XJHjhwWUWD2vMznKz1D4mkByeY8aSOc8kQZsE5nd4ZgwkYTaZ
gItjGjM5y5dpKurfKdqQ+dA6PS03h3p+tp1lZp9/dI9X/DfkTO/LUdrfkVVcbQhE
zqc6C35qO98rhJdsRwhOF28mOc/4bbs0XjC5dEoBGyFt7Fbn2mYoCo4FSHl7WIq6
TZR9pLAvxjqEZ6Dwrzpp9wtdLIQYPga+KVKcDT/DStThXDTCNt5PyDE9c8eImFww
u0T87Er5hSEQgodURxDOZh+9ktIfXzMtxiAJ3iDCEPc3NNnLCWfKMhwGsVTCCXj6
WuHTOe79tOaQSQKCAQEAqBN52PsRl4TzWNEcyLhZQxmFzuIXKJpPlctkX/VMPL/1
2bj89JR1+pLjA9e6fnyjuqPZz6uXQ77m2DJcKNOLId6Fa9wljAbPkZu0cLTw5YQX
8/wJHTfPWcLin2BDnG94yt5t0F3pUJTEEYPa1EmP8w1SRjn64Ue3JwpWUJREfWdk
n4GdfLwscXrGvVvzWGc7ECR5WOwj6OEAZ+kqS5BzyvtERRm6BcoCv9Mdvb9Tthhw
Gypri2vat/yWTbnt0QgPRtliYYG+6q8K/xoNnPAUQkLd9PxZQevaUXUY2yk3QxGK
T7VrSsmu5qB+wM2ByU9686xJ7/DlGu4mHjPerEQVtQKCAQBcM3iSitpyP4qRjWQR
HbDeIudFbMosaaWEedU28REynkLhV5HYsmnmYUNY0dHrvhoHW419YnuhveBFX+25
kN8MHHXk5aNcxE+akLWYJimHCVGueScdUIC5OEtDHS8guQx48PUPCOPNeyn8XNzw
ZmG9Xqy0dWK+AK6mXOcUKvbhjWSbEmySo5NVj0JHkdsfmr9A4Fbntcr4yuCBlYve
TYIMccark3hZci3HzgzWmbSlFv3f/Cd787A19VWRE8nK+9k1oIDBmhIM8M8s/c9m
kbOApLkm7O8Tb7dYWQgFZbgNdOEuU5bhAk4fuHuDYBPWmPVMQdkvOnvuWlM61ubF
LdaBAoIBACDpbb5AQIYsWWOnoXuuGh+YY4kmnaBFpsbgEYkZSy92AaLr4Ibf49WN
oqNDX73YaJlURaGPYMC9J2Huq7TZcewH3SwkVA3N5UmDoijkM4juRfADAfVIMxB5
+9paWeEfnYC/o377FTJIJ9hHJWIaWSoiJZLYDBmoYdxmk8DSHAJCeWsjYDzPybsH
7RyMPIa1u7lVdgOPEOBi1OIg7ASLxGKiHQtrYHq99GcaVvU/UxoNRMcSnPfY3G8R
pGah+EndSCb2F20ouDyvlKfOylAltH2BeNc3B4PeP7ZhlVr7bfyOAfC2Z7FNDm3J
+yaBExKfroZjsksctNAcAbgpuvhLLG8=
-----END PRIVATE KEY-----

View file

@ -20,6 +20,41 @@ should pick up any unit test and run it). There are also [scripts](scripts) for
[linting](scripts/find-lint.sh) and doing a [build/test/lint [linting](scripts/find-lint.sh) and doing a [build/test/lint
run](scripts/build-test-lint.sh). run](scripts/build-test-lint.sh).
## Continuous Integration
When a Pull Request is submitted, continuous integration jobs are run
automatically to ensure the code builds and is relatively well-written. Checks
are run on [Buildkite](https://buildkite.com/matrix-dot-org/dendrite/) and
[CircleCI](https://circleci.com/gh/matrix-org/dendrite/). The Buildkite
pipeline can be found in Matrix.org's [pipelines
repository](https://github.com/matrix-org/pipelines).
If a job fails, click the "details" button and you should be taken to the job's
logs.
![Click the details button on the failing build step](docs/images/details-button-location.jpg)
Scroll down to the failing step and you should see some log output. Scan
the logs until you find what it's complaining about, fix it, submit a new
commit, then rinse and repeat until CI passes.
### Running CI Tests Locally
To save waiting for CI to finish after every commit, it is ideal to run the
checks locally before pushing, fixing errors first. This also saves other
people time as only so many PRs can be tested at a given time.
To execute what Buildkite tests, simply run `./scripts/build-test-lint.sh`.
This script will build the code, lint it, and run `go test ./...` with race
condition checking enabled. If something needs to be changed, fix it and then
run the script again until it no longer complains. Be warned that the linting
can take a significant amount of CPU and RAM.
CircleCI simply runs [Sytest](https://github.com/matrix-org/sytest) with a test
whitelist. See
[docs/sytest.md](https://github.com/matrix-org/dendrite/blob/master/docs/sytest.md#using-a-sytest-docker-image)
for instructions on setting it up to run locally.
## Picking Things To Do ## Picking Things To Do

View file

@ -35,10 +35,10 @@ cd dendrite
If using Kafka, install and start it (c.f. [scripts/install-local-kafka.sh](scripts/install-local-kafka.sh)): If using Kafka, install and start it (c.f. [scripts/install-local-kafka.sh](scripts/install-local-kafka.sh)):
```bash ```bash
MIRROR=http://apache.mirror.anlx.net/kafka/0.10.2.0/kafka_2.11-0.10.2.0.tgz KAFKA_URL=http://archive.apache.org/dist/kafka/2.1.0/kafka_2.11-2.1.0.tgz
# Only download the kafka if it isn't already downloaded. # Only download the kafka if it isn't already downloaded.
test -f kafka.tgz || wget $MIRROR -O kafka.tgz test -f kafka.tgz || wget $KAFKA_URL -O kafka.tgz
# Unpack the kafka over the top of any existing installation # Unpack the kafka over the top of any existing installation
mkdir -p kafka && tar xzf kafka.tgz -C kafka --strip-components 1 mkdir -p kafka && tar xzf kafka.tgz -C kafka --strip-components 1

View file

@ -1,4 +1,4 @@
# Dendrite [![Build Status](https://badge.buildkite.com/4be40938ab19f2bbc4a6c6724517353ee3ec1422e279faf374.svg)](https://buildkite.com/matrix-dot-org/dendrite) [![CircleCI](https://circleci.com/gh/matrix-org/dendrite.svg?style=svg)](https://circleci.com/gh/matrix-org/dendrite) [![Dendrite Dev on Matrix](https://img.shields.io/matrix/dendrite-dev:matrix.org.svg?label=%23dendrite-dev%3Amatrix.org&logo=matrix&server_fqdn=matrix.org)](https://matrix.to/#/#dendrite-dev:matrix.org) [![Dendrite on Matrix](https://img.shields.io/matrix/dendrite:matrix.org.svg?label=%23dendrite%3Amatrix.org&logo=matrix&server_fqdn=matrix.org)](https://matrix.to/#/#dendrite:matrix.org) # Dendrite [![Build Status](https://badge.buildkite.com/4be40938ab19f2bbc4a6c6724517353ee3ec1422e279faf374.svg?branch=master)](https://buildkite.com/matrix-dot-org/dendrite) [![CircleCI](https://circleci.com/gh/matrix-org/dendrite.svg?style=svg)](https://circleci.com/gh/matrix-org/dendrite) [![Dendrite Dev on Matrix](https://img.shields.io/matrix/dendrite-dev:matrix.org.svg?label=%23dendrite-dev%3Amatrix.org&logo=matrix&server_fqdn=matrix.org)](https://matrix.to/#/#dendrite-dev:matrix.org) [![Dendrite on Matrix](https://img.shields.io/matrix/dendrite:matrix.org.svg?label=%23dendrite%3Amatrix.org&logo=matrix&server_fqdn=matrix.org)](https://matrix.to/#/#dendrite:matrix.org)
Dendrite will be a matrix homeserver written in go. Dendrite will be a matrix homeserver written in go.

View file

@ -20,13 +20,13 @@ package api
import ( import (
"context" "context"
"database/sql" "database/sql"
"errors"
"net/http" "net/http"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/common"
commonHTTP "github.com/matrix-org/dendrite/common/http" commonHTTP "github.com/matrix-org/dendrite/common/http"
opentracing "github.com/opentracing/opentracing-go" opentracing "github.com/opentracing/opentracing-go"
) )
@ -134,9 +134,9 @@ func (h *httpAppServiceQueryAPI) UserIDExists(
return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
} }
// RetreiveUserProfile is a wrapper that queries both the local database and // RetrieveUserProfile is a wrapper that queries both the local database and
// application services for a given user's profile // application services for a given user's profile
func RetreiveUserProfile( func RetrieveUserProfile(
ctx context.Context, ctx context.Context,
userID string, userID string,
asAPI AppServiceQueryAPI, asAPI AppServiceQueryAPI,
@ -164,7 +164,7 @@ func RetreiveUserProfile(
// If no user exists, return // If no user exists, return
if !userResp.UserIDExists { if !userResp.UserIDExists {
return nil, errors.New("no known profile for given user ID") return nil, common.ErrProfileNoExists
} }
// Try to query the user from the local database again // Try to query the user from the local database again

View file

@ -21,5 +21,9 @@ type Device struct {
// The access_token granted to this device. // The access_token granted to this device.
// This uniquely identifies the device from all other devices and clients. // This uniquely identifies the device from all other devices and clients.
AccessToken string AccessToken string
// The unique ID of the session identified by the access token.
// Can be used as a secure substitution in places where data needs to be
// associated with access tokens.
SessionID int64
// TODO: display name, last used timestamp, keys, etc // TODO: display name, last used timestamp, keys, etc
} }

View file

@ -14,7 +14,7 @@
package authtypes package authtypes
// Profile represents the profile for a Matrix account on this home server. // Profile represents the profile for a Matrix account.
type Profile struct { type Profile struct {
Localpart string Localpart string
DisplayName string DisplayName string

View file

@ -17,6 +17,7 @@ package accounts
import ( import (
"context" "context"
"database/sql" "database/sql"
"encoding/json"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -71,25 +72,44 @@ func (s *filterStatements) prepare(db *sql.DB) (err error) {
func (s *filterStatements) selectFilter( func (s *filterStatements) selectFilter(
ctx context.Context, localpart string, filterID string, ctx context.Context, localpart string, filterID string,
) (filter []byte, err error) { ) (*gomatrixserverlib.Filter, error) {
err = s.selectFilterStmt.QueryRowContext(ctx, localpart, filterID).Scan(&filter) // Retrieve filter from database (stored as canonical JSON)
return var filterData []byte
err := s.selectFilterStmt.QueryRowContext(ctx, localpart, filterID).Scan(&filterData)
if err != nil {
return nil, err
}
// Unmarshal JSON into Filter struct
var filter gomatrixserverlib.Filter
if err = json.Unmarshal(filterData, &filter); err != nil {
return nil, err
}
return &filter, nil
} }
func (s *filterStatements) insertFilter( func (s *filterStatements) insertFilter(
ctx context.Context, filter []byte, localpart string, ctx context.Context, filter *gomatrixserverlib.Filter, localpart string,
) (filterID string, err error) { ) (filterID string, err error) {
var existingFilterID string var existingFilterID string
// This can result in a race condition when two clients try to insert the // Serialise json
// same filter and localpart at the same time, however this is not a filterJSON, err := json.Marshal(filter)
// problem as both calls will result in the same filterID if err != nil {
filterJSON, err := gomatrixserverlib.CanonicalJSON(filter) return "", err
}
// Remove whitespaces and sort JSON data
// needed to prevent from inserting the same filter multiple times
filterJSON, err = gomatrixserverlib.CanonicalJSON(filterJSON)
if err != nil { if err != nil {
return "", err return "", err
} }
// Check if filter already exists in the database // Check if filter already exists in the database using its localpart and content
//
// This can result in a race condition when two clients try to insert the
// same filter and localpart at the same time, however this is not a
// problem as both calls will result in the same filterID
err = s.selectFilterIDByContentStmt.QueryRowContext(ctx, err = s.selectFilterIDByContentStmt.QueryRowContext(ctx,
localpart, filterJSON).Scan(&existingFilterID) localpart, filterJSON).Scan(&existingFilterID)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {

View file

@ -230,7 +230,7 @@ func (d *Database) newMembership(
} }
// Only "join" membership events can be considered as new memberships // Only "join" membership events can be considered as new memberships
if membership == "join" { if membership == gomatrixserverlib.Join {
if err := d.saveMembership(ctx, txn, localpart, roomID, eventID); err != nil { if err := d.saveMembership(ctx, txn, localpart, roomID, eventID); err != nil {
return err return err
} }
@ -344,11 +344,11 @@ func (d *Database) GetThreePIDsForLocalpart(
} }
// GetFilter looks up the filter associated with a given local user and filter ID. // GetFilter looks up the filter associated with a given local user and filter ID.
// Returns a filter represented as a byte slice. Otherwise returns an error if // Returns a filter structure. Otherwise returns an error if no such filter exists
// no such filter exists or if there was an error talking to the database. // or if there was an error talking to the database.
func (d *Database) GetFilter( func (d *Database) GetFilter(
ctx context.Context, localpart string, filterID string, ctx context.Context, localpart string, filterID string,
) ([]byte, error) { ) (*gomatrixserverlib.Filter, error) {
return d.filter.selectFilter(ctx, localpart, filterID) return d.filter.selectFilter(ctx, localpart, filterID)
} }
@ -356,7 +356,7 @@ func (d *Database) GetFilter(
// Returns the filterID as a string. Otherwise returns an error if something // Returns the filterID as a string. Otherwise returns an error if something
// goes wrong. // goes wrong.
func (d *Database) PutFilter( func (d *Database) PutFilter(
ctx context.Context, localpart string, filter []byte, ctx context.Context, localpart string, filter *gomatrixserverlib.Filter,
) (string, error) { ) (string, error) {
return d.filter.insertFilter(ctx, filter, localpart) return d.filter.insertFilter(ctx, filter, localpart)
} }

View file

@ -27,11 +27,19 @@ import (
) )
const devicesSchema = ` const devicesSchema = `
-- This sequence is used for automatic allocation of session_id.
CREATE SEQUENCE IF NOT EXISTS device_session_id_seq START 1;
-- Stores data about devices. -- Stores data about devices.
CREATE TABLE IF NOT EXISTS device_devices ( CREATE TABLE IF NOT EXISTS device_devices (
-- The access token granted to this device. This has to be the primary key -- The access token granted to this device. This has to be the primary key
-- so we can distinguish which device is making a given request. -- so we can distinguish which device is making a given request.
access_token TEXT NOT NULL PRIMARY KEY, access_token TEXT NOT NULL PRIMARY KEY,
-- The auto-allocated unique ID of the session identified by the access token.
-- This can be used as a secure substitution of the access token in situations
-- where data is associated with access tokens (e.g. transaction storage),
-- so we don't have to store users' access tokens everywhere.
session_id BIGINT NOT NULL DEFAULT nextval('device_session_id_seq'),
-- The device identifier. This only needs to uniquely identify a device for a given user, not globally. -- The device identifier. This only needs to uniquely identify a device for a given user, not globally.
-- access_tokens will be clobbered based on the device ID for a user. -- access_tokens will be clobbered based on the device ID for a user.
device_id TEXT NOT NULL, device_id TEXT NOT NULL,
@ -51,10 +59,11 @@ CREATE UNIQUE INDEX IF NOT EXISTS device_localpart_id_idx ON device_devices(loca
` `
const insertDeviceSQL = "" + const insertDeviceSQL = "" +
"INSERT INTO device_devices(device_id, localpart, access_token, created_ts, display_name) VALUES ($1, $2, $3, $4, $5)" "INSERT INTO device_devices(device_id, localpart, access_token, created_ts, display_name) VALUES ($1, $2, $3, $4, $5)" +
" RETURNING session_id"
const selectDeviceByTokenSQL = "" + const selectDeviceByTokenSQL = "" +
"SELECT device_id, localpart FROM device_devices WHERE access_token = $1" "SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1"
const selectDeviceByIDSQL = "" + const selectDeviceByIDSQL = "" +
"SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2" "SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
@ -120,14 +129,16 @@ func (s *devicesStatements) insertDevice(
displayName *string, displayName *string,
) (*authtypes.Device, error) { ) (*authtypes.Device, error) {
createdTimeMS := time.Now().UnixNano() / 1000000 createdTimeMS := time.Now().UnixNano() / 1000000
var sessionID int64
stmt := common.TxStmt(txn, s.insertDeviceStmt) stmt := common.TxStmt(txn, s.insertDeviceStmt)
if _, err := stmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName); err != nil { if err := stmt.QueryRowContext(ctx, id, localpart, accessToken, createdTimeMS, displayName).Scan(&sessionID); err != nil {
return nil, err return nil, err
} }
return &authtypes.Device{ return &authtypes.Device{
ID: id, ID: id,
UserID: userutil.MakeUserID(localpart, s.serverName), UserID: userutil.MakeUserID(localpart, s.serverName),
AccessToken: accessToken, AccessToken: accessToken,
SessionID: sessionID,
}, nil }, nil
} }
@ -161,7 +172,7 @@ func (s *devicesStatements) selectDeviceByToken(
var dev authtypes.Device var dev authtypes.Device
var localpart string var localpart string
stmt := s.selectDeviceByTokenStmt stmt := s.selectDeviceByTokenStmt
err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.ID, &localpart) err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart)
if err == nil { if err == nil {
dev.UserID = userutil.MakeUserID(localpart, s.serverName) dev.UserID = userutil.MakeUserID(localpart, s.serverName)
dev.AccessToken = accessToken dev.AccessToken = accessToken
@ -169,6 +180,8 @@ func (s *devicesStatements) selectDeviceByToken(
return &dev, err return &dev, err
} }
// selectDeviceByID retrieves a device from the database with the given user
// localpart and deviceID
func (s *devicesStatements) selectDeviceByID( func (s *devicesStatements) selectDeviceByID(
ctx context.Context, localpart, deviceID string, ctx context.Context, localpart, deviceID string,
) (*authtypes.Device, error) { ) (*authtypes.Device, error) {

View file

@ -84,7 +84,7 @@ func (d *Database) CreateDevice(
if deviceID != nil { if deviceID != nil {
returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error { returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
var err error var err error
// Revoke existing token for this device // Revoke existing tokens for this device
if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil { if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
return err return err
} }

View file

@ -33,13 +33,6 @@ func SaveAccountData(
req *http.Request, accountDB *accounts.Database, device *authtypes.Device, req *http.Request, accountDB *accounts.Database, device *authtypes.Device,
userID string, roomID string, dataType string, syncProducer *producers.SyncAPIProducer, userID string, roomID string, dataType string, syncProducer *producers.SyncAPIProducer,
) util.JSONResponse { ) util.JSONResponse {
if req.Method != http.MethodPut {
return util.JSONResponse{
Code: http.StatusMethodNotAllowed,
JSON: jsonerror.NotFound("Bad method"),
}
}
if userID != device.UserID { if userID != device.UserID {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,

View file

@ -0,0 +1,210 @@
// Copyright 2019 Parminder Singh <parmsingh129@gmail.com>
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package routing
import (
"html/template"
"net/http"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/common/config"
"github.com/matrix-org/util"
)
// recaptchaTemplate is an HTML webpage template for recaptcha auth
const recaptchaTemplate = `
<html>
<head>
<title>Authentication</title>
<meta name='viewport' content='width=device-width, initial-scale=1,
user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
<script src="https://www.google.com/recaptcha/api.js"
async defer></script>
<script src="//code.jquery.com/jquery-1.11.2.min.js"></script>
<script>
function captchaDone() {
$('#registrationForm').submit();
}
</script>
</head>
<body>
<form id="registrationForm" method="post" action="{{.myUrl}}">
<div>
<p>
Hello! We need to prevent computer programs and other automated
things from creating accounts on this server.
</p>
<p>
Please verify that you're not a robot.
</p>
<input type="hidden" name="session" value="{{.session}}" />
<div class="g-recaptcha"
data-sitekey="{{.siteKey}}"
data-callback="captchaDone">
</div>
<noscript>
<input type="submit" value="All Done" />
</noscript>
</div>
</div>
</form>
</body>
</html>
`
// successTemplate is an HTML template presented to the user after successful
// recaptcha completion
const successTemplate = `
<html>
<head>
<title>Success!</title>
<meta name='viewport' content='width=device-width, initial-scale=1,
user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
<script>
if (window.onAuthDone) {
window.onAuthDone();
} else if (window.opener && window.opener.postMessage) {
window.opener.postMessage("authDone", "*");
}
</script>
</head>
<body>
<div>
<p>Thank you!</p>
<p>You may now close this window and return to the application.</p>
</div>
</body>
</html>
`
// serveTemplate fills template data and serves it using http.ResponseWriter
func serveTemplate(w http.ResponseWriter, templateHTML string, data map[string]string) {
t := template.Must(template.New("response").Parse(templateHTML))
if err := t.Execute(w, data); err != nil {
panic(err)
}
}
// AuthFallback implements GET and POST /auth/{authType}/fallback/web?session={sessionID}
func AuthFallback(
w http.ResponseWriter, req *http.Request, authType string,
cfg config.Dendrite,
) *util.JSONResponse {
sessionID := req.URL.Query().Get("session")
if sessionID == "" {
return writeHTTPMessage(w, req,
"Session ID not provided",
http.StatusBadRequest,
)
}
serveRecaptcha := func() {
data := map[string]string{
"myUrl": req.URL.String(),
"session": sessionID,
"siteKey": cfg.Matrix.RecaptchaPublicKey,
}
serveTemplate(w, recaptchaTemplate, data)
}
serveSuccess := func() {
data := map[string]string{}
serveTemplate(w, successTemplate, data)
}
if req.Method == http.MethodGet {
// Handle Recaptcha
if authType == authtypes.LoginTypeRecaptcha {
if err := checkRecaptchaEnabled(&cfg, w, req); err != nil {
return err
}
serveRecaptcha()
return nil
}
return &util.JSONResponse{
Code: http.StatusNotFound,
JSON: jsonerror.NotFound("Unknown auth stage type"),
}
} else if req.Method == http.MethodPost {
// Handle Recaptcha
if authType == authtypes.LoginTypeRecaptcha {
if err := checkRecaptchaEnabled(&cfg, w, req); err != nil {
return err
}
clientIP := req.RemoteAddr
err := req.ParseForm()
if err != nil {
res := httputil.LogThenError(req, err)
return &res
}
response := req.Form.Get("g-recaptcha-response")
if err := validateRecaptcha(&cfg, response, clientIP); err != nil {
util.GetLogger(req.Context()).Error(err)
return err
}
// Success. Add recaptcha as a completed login flow
AddCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha)
serveSuccess()
return nil
}
return &util.JSONResponse{
Code: http.StatusNotFound,
JSON: jsonerror.NotFound("Unknown auth stage type"),
}
}
return &util.JSONResponse{
Code: http.StatusMethodNotAllowed,
JSON: jsonerror.NotFound("Bad method"),
}
}
// checkRecaptchaEnabled creates an error response if recaptcha is not usable on homeserver.
func checkRecaptchaEnabled(
cfg *config.Dendrite,
w http.ResponseWriter,
req *http.Request,
) *util.JSONResponse {
if !cfg.Matrix.RecaptchaEnabled {
return writeHTTPMessage(w, req,
"Recaptcha login is disabled on this Homeserver",
http.StatusBadRequest,
)
}
return nil
}
// writeHTTPMessage writes the given header and message to the HTTP response writer.
// Returns an error JSONResponse obtained through httputil.LogThenError if the writing failed, otherwise nil.
func writeHTTPMessage(
w http.ResponseWriter, req *http.Request,
message string, header int,
) *util.JSONResponse {
w.WriteHeader(header)
_, err := w.Write([]byte(message))
if err != nil {
res := httputil.LogThenError(req, err)
return &res
}
return nil
}

View file

@ -15,6 +15,7 @@
package routing package routing
import ( import (
"encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"strings" "strings"
@ -54,10 +55,6 @@ const (
presetPublicChat = "public_chat" presetPublicChat = "public_chat"
) )
const (
joinRulePublic = "public"
joinRuleInvite = "invite"
)
const ( const (
historyVisibilityShared = "shared" historyVisibilityShared = "shared"
// TODO: These should be implemented once history visibility is implemented // TODO: These should be implemented once history visibility is implemented
@ -97,6 +94,27 @@ func (r createRoomRequest) Validate() *util.JSONResponse {
} }
} }
// Validate creation_content fields defined in the spec by marshalling the
// creation_content map into bytes and then unmarshalling the bytes into
// common.CreateContent.
creationContentBytes, err := json.Marshal(r.CreationContent)
if err != nil {
return &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON("malformed creation_content"),
}
}
var CreationContent gomatrixserverlib.CreateContent
err = json.Unmarshal(creationContentBytes, &CreationContent)
if err != nil {
return &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON("malformed creation_content"),
}
}
return nil return nil
} }
@ -154,7 +172,17 @@ func createRoom(
JSON: jsonerror.InvalidArgumentValue(err.Error()), JSON: jsonerror.InvalidArgumentValue(err.Error()),
} }
} }
// TODO: visibility/presets/raw initial state/creation content
// Clobber keys: creator, room_version
if r.CreationContent == nil {
r.CreationContent = make(map[string]interface{}, 2)
}
r.CreationContent["creator"] = userID
r.CreationContent["room_version"] = "1" // TODO: We set this to 1 before we support Room versioning
// TODO: visibility/presets/raw initial state
// TODO: Create room alias association // TODO: Create room alias association
// Make sure this doesn't fall into an application service's namespace though! // Make sure this doesn't fall into an application service's namespace though!
@ -163,13 +191,13 @@ func createRoom(
"roomID": roomID, "roomID": roomID,
}).Info("Creating new room") }).Info("Creating new room")
profile, err := appserviceAPI.RetreiveUserProfile(req.Context(), userID, asAPI, accountDB) profile, err := appserviceAPI.RetrieveUserProfile(req.Context(), userID, asAPI, accountDB)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
} }
membershipContent := common.MemberContent{ membershipContent := gomatrixserverlib.MemberContent{
Membership: "join", Membership: gomatrixserverlib.Join,
DisplayName: profile.DisplayName, DisplayName: profile.DisplayName,
AvatarURL: profile.AvatarURL, AvatarURL: profile.AvatarURL,
} }
@ -177,19 +205,19 @@ func createRoom(
var joinRules, historyVisibility string var joinRules, historyVisibility string
switch r.Preset { switch r.Preset {
case presetPrivateChat: case presetPrivateChat:
joinRules = joinRuleInvite joinRules = gomatrixserverlib.Invite
historyVisibility = historyVisibilityShared historyVisibility = historyVisibilityShared
case presetTrustedPrivateChat: case presetTrustedPrivateChat:
joinRules = joinRuleInvite joinRules = gomatrixserverlib.Invite
historyVisibility = historyVisibilityShared historyVisibility = historyVisibilityShared
// TODO If trusted_private_chat, all invitees are given the same power level as the room creator. // TODO If trusted_private_chat, all invitees are given the same power level as the room creator.
case presetPublicChat: case presetPublicChat:
joinRules = joinRulePublic joinRules = gomatrixserverlib.Public
historyVisibility = historyVisibilityShared historyVisibility = historyVisibilityShared
default: default:
// Default room rules, r.Preset was previously checked for valid values so // Default room rules, r.Preset was previously checked for valid values so
// only a request with no preset should end up here. // only a request with no preset should end up here.
joinRules = joinRuleInvite joinRules = gomatrixserverlib.Invite
historyVisibility = historyVisibilityShared historyVisibility = historyVisibilityShared
} }
@ -214,11 +242,11 @@ func createRoom(
// harder to reason about, hence sticking to a strict static ordering. // harder to reason about, hence sticking to a strict static ordering.
// TODO: Synapse has txn/token ID on each event. Do we need to do this here? // TODO: Synapse has txn/token ID on each event. Do we need to do this here?
eventsToMake := []fledglingEvent{ eventsToMake := []fledglingEvent{
{"m.room.create", "", common.CreateContent{Creator: userID}}, {"m.room.create", "", r.CreationContent},
{"m.room.member", userID, membershipContent}, {"m.room.member", userID, membershipContent},
{"m.room.power_levels", "", common.InitialPowerLevelsContent(userID)}, {"m.room.power_levels", "", common.InitialPowerLevelsContent(userID)},
// TODO: m.room.canonical_alias // TODO: m.room.canonical_alias
{"m.room.join_rules", "", common.JoinRulesContent{JoinRule: joinRules}}, {"m.room.join_rules", "", gomatrixserverlib.JoinRuleContent{JoinRule: joinRules}},
{"m.room.history_visibility", "", common.HistoryVisibilityContent{HistoryVisibility: historyVisibility}}, {"m.room.history_visibility", "", common.HistoryVisibilityContent{HistoryVisibility: historyVisibility}},
} }
if r.GuestCanJoin { if r.GuestCanJoin {

View file

@ -106,13 +106,6 @@ func UpdateDeviceByID(
req *http.Request, deviceDB *devices.Database, device *authtypes.Device, req *http.Request, deviceDB *devices.Database, device *authtypes.Device,
deviceID string, deviceID string,
) util.JSONResponse { ) util.JSONResponse {
if req.Method != http.MethodPut {
return util.JSONResponse{
Code: http.StatusMethodNotAllowed,
JSON: jsonerror.NotFound("Bad Method"),
}
}
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)

View file

@ -117,6 +117,9 @@ func SetLocalAlias(
// 1. The new method for checking for things matching an AS's namespace // 1. The new method for checking for things matching an AS's namespace
// 2. Using an overall Regex object for all AS's just like we did for usernames // 2. Using an overall Regex object for all AS's just like we did for usernames
for _, appservice := range cfg.Derived.ApplicationServices { for _, appservice := range cfg.Derived.ApplicationServices {
// Don't prevent AS from creating aliases in its own namespace
// Note that Dendrite uses SenderLocalpart as UserID for AS users
if device.UserID != appservice.SenderLocalpart {
if aliasNamespaces, ok := appservice.NamespaceMap["aliases"]; ok { if aliasNamespaces, ok := appservice.NamespaceMap["aliases"]; ok {
for _, namespace := range aliasNamespaces { for _, namespace := range aliasNamespaces {
if namespace.Exclusive && namespace.RegexpObject.MatchString(alias) { if namespace.Exclusive && namespace.RegexpObject.MatchString(alias) {
@ -128,6 +131,7 @@ func SetLocalAlias(
} }
} }
} }
}
var r struct { var r struct {
RoomID string `json:"room_id"` RoomID string `json:"room_id"`
@ -160,13 +164,36 @@ func SetLocalAlias(
} }
// RemoveLocalAlias implements DELETE /directory/room/{roomAlias} // RemoveLocalAlias implements DELETE /directory/room/{roomAlias}
// TODO: Check if the user has the power level to remove an alias
func RemoveLocalAlias( func RemoveLocalAlias(
req *http.Request, req *http.Request,
device *authtypes.Device, device *authtypes.Device,
alias string, alias string,
aliasAPI roomserverAPI.RoomserverAliasAPI, aliasAPI roomserverAPI.RoomserverAliasAPI,
) util.JSONResponse { ) util.JSONResponse {
creatorQueryReq := roomserverAPI.GetCreatorIDForAliasRequest{
Alias: alias,
}
var creatorQueryRes roomserverAPI.GetCreatorIDForAliasResponse
if err := aliasAPI.GetCreatorIDForAlias(req.Context(), &creatorQueryReq, &creatorQueryRes); err != nil {
return httputil.LogThenError(req, err)
}
if creatorQueryRes.UserID == "" {
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: jsonerror.NotFound("Alias does not exist"),
}
}
if creatorQueryRes.UserID != device.UserID {
// TODO: Still allow deletion if user is admin
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("You do not have permission to delete this alias"),
}
}
queryReq := roomserverAPI.RemoveRoomAliasRequest{ queryReq := roomserverAPI.RemoveRoomAliasRequest{
Alias: alias, Alias: alias,
UserID: device.UserID, UserID: device.UserID,

View file

@ -17,13 +17,10 @@ package routing
import ( import (
"net/http" "net/http"
"encoding/json"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts"
"github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -32,12 +29,6 @@ import (
func GetFilter( func GetFilter(
req *http.Request, device *authtypes.Device, accountDB *accounts.Database, userID string, filterID string, req *http.Request, device *authtypes.Device, accountDB *accounts.Database, userID string, filterID string,
) util.JSONResponse { ) util.JSONResponse {
if req.Method != http.MethodGet {
return util.JSONResponse{
Code: http.StatusMethodNotAllowed,
JSON: jsonerror.NotFound("Bad method"),
}
}
if userID != device.UserID { if userID != device.UserID {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,
@ -49,7 +40,7 @@ func GetFilter(
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
} }
res, err := accountDB.GetFilter(req.Context(), localpart, filterID) filter, err := accountDB.GetFilter(req.Context(), localpart, filterID)
if err != nil { if err != nil {
//TODO better error handling. This error message is *probably* right, //TODO better error handling. This error message is *probably* right,
// but if there are obscure db errors, this will also be returned, // but if there are obscure db errors, this will also be returned,
@ -59,11 +50,6 @@ func GetFilter(
JSON: jsonerror.NotFound("No such filter"), JSON: jsonerror.NotFound("No such filter"),
} }
} }
filter := gomatrix.Filter{}
err = json.Unmarshal(res, &filter)
if err != nil {
return httputil.LogThenError(req, err)
}
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
@ -79,12 +65,6 @@ type filterResponse struct {
func PutFilter( func PutFilter(
req *http.Request, device *authtypes.Device, accountDB *accounts.Database, userID string, req *http.Request, device *authtypes.Device, accountDB *accounts.Database, userID string,
) util.JSONResponse { ) util.JSONResponse {
if req.Method != http.MethodPost {
return util.JSONResponse{
Code: http.StatusMethodNotAllowed,
JSON: jsonerror.NotFound("Bad method"),
}
}
if userID != device.UserID { if userID != device.UserID {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,
@ -97,21 +77,21 @@ func PutFilter(
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
} }
var filter gomatrix.Filter var filter gomatrixserverlib.Filter
if reqErr := httputil.UnmarshalJSONRequest(req, &filter); reqErr != nil { if reqErr := httputil.UnmarshalJSONRequest(req, &filter); reqErr != nil {
return *reqErr return *reqErr
} }
filterArray, err := json.Marshal(filter) // Validate generates a user-friendly error
if err != nil { if err = filter.Validate(); err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON("Filter is malformed"), JSON: jsonerror.BadJSON("Invalid filter: " + err.Error()),
} }
} }
filterID, err := accountDB.PutFilter(req.Context(), localpart, filterArray) filterID, err := accountDB.PutFilter(req.Context(), localpart, &filter)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
} }

View file

@ -0,0 +1,127 @@
// Copyright 2019 Alex Chen
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package routing
import (
"net/http"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/common/config"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
type getEventRequest struct {
req *http.Request
device *authtypes.Device
roomID string
eventID string
cfg config.Dendrite
federation *gomatrixserverlib.FederationClient
keyRing gomatrixserverlib.KeyRing
requestedEvent gomatrixserverlib.Event
}
// GetEvent implements GET /_matrix/client/r0/rooms/{roomId}/event/{eventId}
// https://matrix.org/docs/spec/client_server/r0.4.0.html#get-matrix-client-r0-rooms-roomid-event-eventid
func GetEvent(
req *http.Request,
device *authtypes.Device,
roomID string,
eventID string,
cfg config.Dendrite,
queryAPI api.RoomserverQueryAPI,
federation *gomatrixserverlib.FederationClient,
keyRing gomatrixserverlib.KeyRing,
) util.JSONResponse {
eventsReq := api.QueryEventsByIDRequest{
EventIDs: []string{eventID},
}
var eventsResp api.QueryEventsByIDResponse
err := queryAPI.QueryEventsByID(req.Context(), &eventsReq, &eventsResp)
if err != nil {
return httputil.LogThenError(req, err)
}
if len(eventsResp.Events) == 0 {
// Event not found locally
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: jsonerror.NotFound("The event was not found or you do not have permission to read this event"),
}
}
requestedEvent := eventsResp.Events[0]
r := getEventRequest{
req: req,
device: device,
roomID: roomID,
eventID: eventID,
cfg: cfg,
federation: federation,
keyRing: keyRing,
requestedEvent: requestedEvent,
}
stateReq := api.QueryStateAfterEventsRequest{
RoomID: r.requestedEvent.RoomID(),
PrevEventIDs: r.requestedEvent.PrevEventIDs(),
StateToFetch: []gomatrixserverlib.StateKeyTuple{{
EventType: gomatrixserverlib.MRoomMember,
StateKey: device.UserID,
}},
}
var stateResp api.QueryStateAfterEventsResponse
if err := queryAPI.QueryStateAfterEvents(req.Context(), &stateReq, &stateResp); err != nil {
return httputil.LogThenError(req, err)
}
if !stateResp.RoomExists {
util.GetLogger(req.Context()).Errorf("Expected to find room for event %s but failed", r.requestedEvent.EventID())
return jsonerror.InternalServerError()
}
if !stateResp.PrevEventsExist {
// Missing some events locally; stateResp.StateEvents unavailable.
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: jsonerror.NotFound("The event was not found or you do not have permission to read this event"),
}
}
for _, stateEvent := range stateResp.StateEvents {
if stateEvent.StateKeyEquals(r.device.UserID) {
membership, err := stateEvent.Membership()
if err != nil {
return httputil.LogThenError(req, err)
}
if membership == gomatrixserverlib.Join {
return util.JSONResponse{
Code: http.StatusOK,
JSON: gomatrixserverlib.ToClientEvent(r.requestedEvent, gomatrixserverlib.FormatAll),
}
}
}
}
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: jsonerror.NotFound("The event was not found or you do not have permission to read this event"),
}
}

View file

@ -70,7 +70,7 @@ func JoinRoomByIDOrAlias(
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
} }
content["membership"] = "join" content["membership"] = gomatrixserverlib.Join
content["displayname"] = profile.DisplayName content["displayname"] = profile.DisplayName
content["avatar_url"] = profile.AvatarURL content["avatar_url"] = profile.AvatarURL

View file

@ -18,7 +18,6 @@ import (
"net/http" "net/http"
"context" "context"
"database/sql"
"github.com/matrix-org/dendrite/clientapi/auth" "github.com/matrix-org/dendrite/clientapi/auth"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
@ -44,8 +43,10 @@ type flow struct {
type passwordRequest struct { type passwordRequest struct {
User string `json:"user"` User string `json:"user"`
Password string `json:"password"` Password string `json:"password"`
// Both DeviceID and InitialDisplayName can be omitted, or empty strings ("")
// Thus a pointer is needed to differentiate between the two
InitialDisplayName *string `json:"initial_device_display_name"` InitialDisplayName *string `json:"initial_device_display_name"`
DeviceID string `json:"device_id"` DeviceID *string `json:"device_id"`
} }
type loginResponse struct { type loginResponse struct {
@ -110,7 +111,7 @@ func Login(
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
} }
dev, err := getDevice(req.Context(), r, deviceDB, acc, localpart, token) dev, err := getDevice(req.Context(), r, deviceDB, acc, token)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,
@ -134,20 +135,16 @@ func Login(
} }
} }
// check if device exists else create one // getDevice returns a new or existing device
func getDevice( func getDevice(
ctx context.Context, ctx context.Context,
r passwordRequest, r passwordRequest,
deviceDB *devices.Database, deviceDB *devices.Database,
acc *authtypes.Account, acc *authtypes.Account,
localpart, token string, token string,
) (dev *authtypes.Device, err error) { ) (dev *authtypes.Device, err error) {
dev, err = deviceDB.GetDeviceByID(ctx, localpart, r.DeviceID)
if err == sql.ErrNoRows {
// device doesn't exist, create one
dev, err = deviceDB.CreateDevice( dev, err = deviceDB.CreateDevice(
ctx, acc.Localpart, nil, token, r.InitialDisplayName, ctx, acc.Localpart, r.DeviceID, token, r.InitialDisplayName,
) )
}
return return
} }

View file

@ -20,7 +20,6 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/auth/storage/devices" "github.com/matrix-org/dendrite/clientapi/auth/storage/devices"
"github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -29,13 +28,6 @@ import (
func Logout( func Logout(
req *http.Request, deviceDB *devices.Database, device *authtypes.Device, req *http.Request, deviceDB *devices.Database, device *authtypes.Device,
) util.JSONResponse { ) util.JSONResponse {
if req.Method != http.MethodPost {
return util.JSONResponse{
Code: http.StatusMethodNotAllowed,
JSON: jsonerror.NotFound("Bad method"),
}
}
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)

View file

@ -58,27 +58,12 @@ func SendMembership(
} }
} }
inviteStored, err := threepid.CheckAndProcessInvite( inviteStored, jsonErrResp := checkAndProcessThreepid(
req.Context(), device, &body, cfg, queryAPI, accountDB, producer, req, device, &body, cfg, queryAPI, accountDB, producer,
membership, roomID, evTime, membership, roomID, evTime,
) )
if err == threepid.ErrMissingParameter { if jsonErrResp != nil {
return util.JSONResponse{ return *jsonErrResp
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON(err.Error()),
}
} else if err == threepid.ErrNotTrusted {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.NotTrusted(body.IDServer),
}
} else if err == common.ErrRoomNoExists {
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: jsonerror.NotFound(err.Error()),
}
} else if err != nil {
return httputil.LogThenError(req, err)
} }
// If an invite has been stored on an identity server, it means that a // If an invite has been stored on an identity server, it means that a
@ -114,9 +99,18 @@ func SendMembership(
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
} }
var returnData interface{} = struct{}{}
// The join membership requires the room id to be sent in the response
if membership == gomatrixserverlib.Join {
returnData = struct {
RoomID string `json:"room_id"`
}{roomID}
}
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
JSON: struct{}{}, JSON: returnData,
} }
} }
@ -147,10 +141,10 @@ func buildMembershipEvent(
// "unban" or "kick" isn't a valid membership value, change it to "leave" // "unban" or "kick" isn't a valid membership value, change it to "leave"
if membership == "unban" || membership == "kick" { if membership == "unban" || membership == "kick" {
membership = "leave" membership = gomatrixserverlib.Leave
} }
content := common.MemberContent{ content := gomatrixserverlib.MemberContent{
Membership: membership, Membership: membership,
DisplayName: profile.DisplayName, DisplayName: profile.DisplayName,
AvatarURL: profile.AvatarURL, AvatarURL: profile.AvatarURL,
@ -182,7 +176,7 @@ func loadProfile(
var profile *authtypes.Profile var profile *authtypes.Profile
if serverName == cfg.Matrix.ServerName { if serverName == cfg.Matrix.ServerName {
profile, err = appserviceAPI.RetreiveUserProfile(ctx, userID, asAPI, accountDB) profile, err = appserviceAPI.RetrieveUserProfile(ctx, userID, asAPI, accountDB)
} else { } else {
profile = &authtypes.Profile{} profile = &authtypes.Profile{}
} }
@ -198,7 +192,7 @@ func loadProfile(
func getMembershipStateKey( func getMembershipStateKey(
body threepid.MembershipRequest, device *authtypes.Device, membership string, body threepid.MembershipRequest, device *authtypes.Device, membership string,
) (stateKey string, reason string, err error) { ) (stateKey string, reason string, err error) {
if membership == "ban" || membership == "unban" || membership == "kick" || membership == "invite" { if membership == gomatrixserverlib.Ban || membership == "unban" || membership == "kick" || membership == gomatrixserverlib.Invite {
// If we're in this case, the state key is contained in the request body, // If we're in this case, the state key is contained in the request body,
// possibly along with a reason (for "kick" and "ban") so we need to parse // possibly along with a reason (for "kick" and "ban") so we need to parse
// it // it
@ -215,3 +209,41 @@ func getMembershipStateKey(
return return
} }
func checkAndProcessThreepid(
req *http.Request,
device *authtypes.Device,
body *threepid.MembershipRequest,
cfg config.Dendrite,
queryAPI roomserverAPI.RoomserverQueryAPI,
accountDB *accounts.Database,
producer *producers.RoomserverProducer,
membership, roomID string,
evTime time.Time,
) (inviteStored bool, errRes *util.JSONResponse) {
inviteStored, err := threepid.CheckAndProcessInvite(
req.Context(), device, body, cfg, queryAPI, accountDB, producer,
membership, roomID, evTime,
)
if err == threepid.ErrMissingParameter {
return inviteStored, &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON(err.Error()),
}
} else if err == threepid.ErrNotTrusted {
return inviteStored, &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.NotTrusted(body.IDServer),
}
} else if err == common.ErrRoomNoExists {
return inviteStored, &util.JSONResponse{
Code: http.StatusNotFound,
JSON: jsonerror.NotFound(err.Error()),
}
} else if err != nil {
er := httputil.LogThenError(req, err)
return inviteStored, &er
}
return
}

View file

@ -30,49 +30,61 @@ import (
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrix"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
// GetProfile implements GET /profile/{userID} // GetProfile implements GET /profile/{userID}
func GetProfile( func GetProfile(
req *http.Request, accountDB *accounts.Database, userID string, asAPI appserviceAPI.AppServiceQueryAPI, req *http.Request, accountDB *accounts.Database, cfg *config.Dendrite,
userID string,
asAPI appserviceAPI.AppServiceQueryAPI,
federation *gomatrixserverlib.FederationClient,
) util.JSONResponse { ) util.JSONResponse {
if req.Method != http.MethodGet { profile, err := getProfile(req.Context(), accountDB, cfg, userID, asAPI, federation)
return util.JSONResponse{
Code: http.StatusMethodNotAllowed,
JSON: jsonerror.NotFound("Bad method"),
}
}
profile, err := appserviceAPI.RetreiveUserProfile(req.Context(), userID, asAPI, accountDB)
if err != nil { if err != nil {
if err == common.ErrProfileNoExists {
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: jsonerror.NotFound("The user does not exist or does not have a profile"),
}
}
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
} }
res := common.ProfileResponse{
AvatarURL: profile.AvatarURL,
DisplayName: profile.DisplayName,
}
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
JSON: res, JSON: common.ProfileResponse{
AvatarURL: profile.AvatarURL,
DisplayName: profile.DisplayName,
},
} }
} }
// GetAvatarURL implements GET /profile/{userID}/avatar_url // GetAvatarURL implements GET /profile/{userID}/avatar_url
func GetAvatarURL( func GetAvatarURL(
req *http.Request, accountDB *accounts.Database, userID string, asAPI appserviceAPI.AppServiceQueryAPI, req *http.Request, accountDB *accounts.Database, cfg *config.Dendrite,
userID string, asAPI appserviceAPI.AppServiceQueryAPI,
federation *gomatrixserverlib.FederationClient,
) util.JSONResponse { ) util.JSONResponse {
profile, err := appserviceAPI.RetreiveUserProfile(req.Context(), userID, asAPI, accountDB) profile, err := getProfile(req.Context(), accountDB, cfg, userID, asAPI, federation)
if err != nil { if err != nil {
if err == common.ErrProfileNoExists {
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: jsonerror.NotFound("The user does not exist or does not have a profile"),
}
}
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
} }
res := common.AvatarURL{
AvatarURL: profile.AvatarURL,
}
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
JSON: res, JSON: common.AvatarURL{
AvatarURL: profile.AvatarURL,
},
} }
} }
@ -158,18 +170,27 @@ func SetAvatarURL(
// GetDisplayName implements GET /profile/{userID}/displayname // GetDisplayName implements GET /profile/{userID}/displayname
func GetDisplayName( func GetDisplayName(
req *http.Request, accountDB *accounts.Database, userID string, asAPI appserviceAPI.AppServiceQueryAPI, req *http.Request, accountDB *accounts.Database, cfg *config.Dendrite,
userID string, asAPI appserviceAPI.AppServiceQueryAPI,
federation *gomatrixserverlib.FederationClient,
) util.JSONResponse { ) util.JSONResponse {
profile, err := appserviceAPI.RetreiveUserProfile(req.Context(), userID, asAPI, accountDB) profile, err := getProfile(req.Context(), accountDB, cfg, userID, asAPI, federation)
if err != nil { if err != nil {
if err == common.ErrProfileNoExists {
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: jsonerror.NotFound("The user does not exist or does not have a profile"),
}
}
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
} }
res := common.DisplayName{
DisplayName: profile.DisplayName,
}
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
JSON: res, JSON: common.DisplayName{
DisplayName: profile.DisplayName,
},
} }
} }
@ -253,6 +274,48 @@ func SetDisplayName(
} }
} }
// getProfile gets the full profile of a user by querying the database or a
// remote homeserver.
// Returns an error when something goes wrong or specifically
// common.ErrProfileNoExists when the profile doesn't exist.
func getProfile(
ctx context.Context, accountDB *accounts.Database, cfg *config.Dendrite,
userID string,
asAPI appserviceAPI.AppServiceQueryAPI,
federation *gomatrixserverlib.FederationClient,
) (*authtypes.Profile, error) {
localpart, domain, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
return nil, err
}
if domain != cfg.Matrix.ServerName {
profile, fedErr := federation.LookupProfile(ctx, domain, userID, "")
if fedErr != nil {
if x, ok := fedErr.(gomatrix.HTTPError); ok {
if x.Code == http.StatusNotFound {
return nil, common.ErrProfileNoExists
}
}
return nil, fedErr
}
return &authtypes.Profile{
Localpart: localpart,
DisplayName: profile.DisplayName,
AvatarURL: profile.AvatarURL,
}, nil
}
profile, err := appserviceAPI.RetrieveUserProfile(ctx, userID, asAPI, accountDB)
if err != nil {
return nil, err
}
return profile, nil
}
func buildMembershipEvents( func buildMembershipEvents(
ctx context.Context, ctx context.Context,
memberships []authtypes.Membership, memberships []authtypes.Membership,
@ -269,8 +332,8 @@ func buildMembershipEvents(
StateKey: &userID, StateKey: &userID,
} }
content := common.MemberContent{ content := gomatrixserverlib.MemberContent{
Membership: "join", Membership: gomatrixserverlib.Join,
} }
content.DisplayName = newProfile.DisplayName content.DisplayName = newProfile.DisplayName

View file

@ -29,6 +29,7 @@ import (
"sort" "sort"
"strconv" "strconv"
"strings" "strings"
"sync"
"time" "time"
"github.com/matrix-org/dendrite/common/config" "github.com/matrix-org/dendrite/common/config"
@ -70,12 +71,17 @@ func init() {
} }
// sessionsDict keeps track of completed auth stages for each session. // sessionsDict keeps track of completed auth stages for each session.
// It shouldn't be passed by value because it contains a mutex.
type sessionsDict struct { type sessionsDict struct {
sync.Mutex
sessions map[string][]authtypes.LoginType sessions map[string][]authtypes.LoginType
} }
// GetCompletedStages returns the completed stages for a session. // GetCompletedStages returns the completed stages for a session.
func (d sessionsDict) GetCompletedStages(sessionID string) []authtypes.LoginType { func (d *sessionsDict) GetCompletedStages(sessionID string) []authtypes.LoginType {
d.Lock()
defer d.Unlock()
if completedStages, ok := d.sessions[sessionID]; ok { if completedStages, ok := d.sessions[sessionID]; ok {
return completedStages return completedStages
} }
@ -83,17 +89,25 @@ func (d sessionsDict) GetCompletedStages(sessionID string) []authtypes.LoginType
return make([]authtypes.LoginType, 0) return make([]authtypes.LoginType, 0)
} }
// AddCompletedStage records that a session has completed an auth stage.
func (d *sessionsDict) AddCompletedStage(sessionID string, stage authtypes.LoginType) {
d.sessions[sessionID] = append(d.GetCompletedStages(sessionID), stage)
}
func newSessionsDict() *sessionsDict { func newSessionsDict() *sessionsDict {
return &sessionsDict{ return &sessionsDict{
sessions: make(map[string][]authtypes.LoginType), sessions: make(map[string][]authtypes.LoginType),
} }
} }
// AddCompletedSessionStage records that a session has completed an auth stage.
func AddCompletedSessionStage(sessionID string, stage authtypes.LoginType) {
sessions.Lock()
defer sessions.Unlock()
for _, completedStage := range sessions.sessions[sessionID] {
if completedStage == stage {
return
}
}
sessions.sessions[sessionID] = append(sessions.sessions[sessionID], stage)
}
var ( var (
// TODO: Remove old sessions. Need to do so on a session-specific timeout. // TODO: Remove old sessions. Need to do so on a session-specific timeout.
// sessions stores the completed flow stages for all sessions. Referenced using their sessionID. // sessions stores the completed flow stages for all sessions. Referenced using their sessionID.
@ -115,7 +129,10 @@ type registerRequest struct {
// user-interactive auth params // user-interactive auth params
Auth authDict `json:"auth"` Auth authDict `json:"auth"`
// Both DeviceID and InitialDisplayName can be omitted, or empty strings ("")
// Thus a pointer is needed to differentiate between the two
InitialDisplayName *string `json:"initial_device_display_name"` InitialDisplayName *string `json:"initial_device_display_name"`
DeviceID *string `json:"device_id"`
// Prevent this user from logging in // Prevent this user from logging in
InhibitLogin common.WeakBoolean `json:"inhibit_login"` InhibitLogin common.WeakBoolean `json:"inhibit_login"`
@ -521,7 +538,7 @@ func handleRegistrationFlow(
} }
// Add Recaptcha to the list of completed registration stages // Add Recaptcha to the list of completed registration stages
sessions.AddCompletedStage(sessionID, authtypes.LoginTypeRecaptcha) AddCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha)
case authtypes.LoginTypeSharedSecret: case authtypes.LoginTypeSharedSecret:
// Check shared secret against config // Check shared secret against config
@ -534,7 +551,7 @@ func handleRegistrationFlow(
} }
// Add SharedSecret to the list of completed registration stages // Add SharedSecret to the list of completed registration stages
sessions.AddCompletedStage(sessionID, authtypes.LoginTypeSharedSecret) AddCompletedSessionStage(sessionID, authtypes.LoginTypeSharedSecret)
case "": case "":
// Extract the access token from the request, if there's one to extract // Extract the access token from the request, if there's one to extract
@ -564,7 +581,7 @@ func handleRegistrationFlow(
case authtypes.LoginTypeDummy: case authtypes.LoginTypeDummy:
// there is nothing to do // there is nothing to do
// Add Dummy to the list of completed registration stages // Add Dummy to the list of completed registration stages
sessions.AddCompletedStage(sessionID, authtypes.LoginTypeDummy) AddCompletedSessionStage(sessionID, authtypes.LoginTypeDummy)
default: default:
return util.JSONResponse{ return util.JSONResponse{
@ -620,7 +637,7 @@ func handleApplicationServiceRegistration(
// application service registration is entirely separate. // application service registration is entirely separate.
return completeRegistration( return completeRegistration(
req.Context(), accountDB, deviceDB, r.Username, "", appserviceID, req.Context(), accountDB, deviceDB, r.Username, "", appserviceID,
r.InhibitLogin, r.InitialDisplayName, r.InhibitLogin, r.InitialDisplayName, r.DeviceID,
) )
} }
@ -640,7 +657,7 @@ func checkAndCompleteFlow(
// This flow was completed, registration can continue // This flow was completed, registration can continue
return completeRegistration( return completeRegistration(
req.Context(), accountDB, deviceDB, r.Username, r.Password, "", req.Context(), accountDB, deviceDB, r.Username, r.Password, "",
r.InhibitLogin, r.InitialDisplayName, r.InhibitLogin, r.InitialDisplayName, r.DeviceID,
) )
} }
@ -691,10 +708,10 @@ func LegacyRegister(
return util.MessageResponse(http.StatusForbidden, "HMAC incorrect") return util.MessageResponse(http.StatusForbidden, "HMAC incorrect")
} }
return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password, "", false, nil) return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password, "", false, nil, nil)
case authtypes.LoginTypeDummy: case authtypes.LoginTypeDummy:
// there is nothing to do // there is nothing to do
return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password, "", false, nil) return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password, "", false, nil, nil)
default: default:
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusNotImplemented, Code: http.StatusNotImplemented,
@ -732,13 +749,19 @@ func parseAndValidateLegacyLogin(req *http.Request, r *legacyRegisterRequest) *u
return nil return nil
} }
// completeRegistration runs some rudimentary checks against the submitted
// input, then if successful creates an account and a newly associated device
// We pass in each individual part of the request here instead of just passing a
// registerRequest, as this function serves requests encoded as both
// registerRequests and legacyRegisterRequests, which share some attributes but
// not all
func completeRegistration( func completeRegistration(
ctx context.Context, ctx context.Context,
accountDB *accounts.Database, accountDB *accounts.Database,
deviceDB *devices.Database, deviceDB *devices.Database,
username, password, appserviceID string, username, password, appserviceID string,
inhibitLogin common.WeakBoolean, inhibitLogin common.WeakBoolean,
displayName *string, displayName, deviceID *string,
) util.JSONResponse { ) util.JSONResponse {
if username == "" { if username == "" {
return util.JSONResponse{ return util.JSONResponse{
@ -767,6 +790,9 @@ func completeRegistration(
} }
} }
// Increment prometheus counter for created users
amtRegUsers.Inc()
// Check whether inhibit_login option is set. If so, don't create an access // Check whether inhibit_login option is set. If so, don't create an access
// token or a device for this user // token or a device for this user
if inhibitLogin { if inhibitLogin {
@ -787,8 +813,7 @@ func completeRegistration(
} }
} }
// TODO: Use the device ID in the request. dev, err := deviceDB.CreateDevice(ctx, username, deviceID, token, displayName)
dev, err := deviceDB.CreateDevice(ctx, username, nil, token, displayName)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,
@ -796,9 +821,6 @@ func completeRegistration(
} }
} }
// Increment prometheus counter for created users
amtRegUsers.Inc()
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
JSON: registerResponse{ JSON: registerResponse{

View file

@ -0,0 +1,234 @@
// Copyright 2019 Sumukha PK
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package routing
import (
"encoding/json"
"net/http"
"github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/auth/storage/accounts"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/clientapi/producers"
"github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
// newTag creates and returns a new gomatrix.TagContent
func newTag() gomatrix.TagContent {
return gomatrix.TagContent{
Tags: make(map[string]gomatrix.TagProperties),
}
}
// GetTags implements GET /_matrix/client/r0/user/{userID}/rooms/{roomID}/tags
func GetTags(
req *http.Request,
accountDB *accounts.Database,
device *authtypes.Device,
userID string,
roomID string,
syncProducer *producers.SyncAPIProducer,
) util.JSONResponse {
if device.UserID != userID {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("Cannot retrieve another user's tags"),
}
}
_, data, err := obtainSavedTags(req, userID, roomID, accountDB)
if err != nil {
return httputil.LogThenError(req, err)
}
if len(data) == 0 {
return util.JSONResponse{
Code: http.StatusOK,
JSON: struct{}{},
}
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: data[0].Content,
}
}
// PutTag implements PUT /_matrix/client/r0/user/{userID}/rooms/{roomID}/tags/{tag}
// Put functionality works by getting existing data from the DB (if any), adding
// the tag to the "map" and saving the new "map" to the DB
func PutTag(
req *http.Request,
accountDB *accounts.Database,
device *authtypes.Device,
userID string,
roomID string,
tag string,
syncProducer *producers.SyncAPIProducer,
) util.JSONResponse {
if device.UserID != userID {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("Cannot modify another user's tags"),
}
}
var properties gomatrix.TagProperties
if reqErr := httputil.UnmarshalJSONRequest(req, &properties); reqErr != nil {
return *reqErr
}
localpart, data, err := obtainSavedTags(req, userID, roomID, accountDB)
if err != nil {
return httputil.LogThenError(req, err)
}
var tagContent gomatrix.TagContent
if len(data) > 0 {
if err = json.Unmarshal(data[0].Content, &tagContent); err != nil {
return httputil.LogThenError(req, err)
}
} else {
tagContent = newTag()
}
tagContent.Tags[tag] = properties
if err = saveTagData(req, localpart, roomID, accountDB, tagContent); err != nil {
return httputil.LogThenError(req, err)
}
// Send data to syncProducer in order to inform clients of changes
// Run in a goroutine in order to prevent blocking the tag request response
go func() {
if err := syncProducer.SendData(userID, roomID, "m.tag"); err != nil {
logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi")
}
}()
return util.JSONResponse{
Code: http.StatusOK,
JSON: struct{}{},
}
}
// DeleteTag implements DELETE /_matrix/client/r0/user/{userID}/rooms/{roomID}/tags/{tag}
// Delete functionality works by obtaining the saved tags, removing the intended tag from
// the "map" and then saving the new "map" in the DB
func DeleteTag(
req *http.Request,
accountDB *accounts.Database,
device *authtypes.Device,
userID string,
roomID string,
tag string,
syncProducer *producers.SyncAPIProducer,
) util.JSONResponse {
if device.UserID != userID {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("Cannot modify another user's tags"),
}
}
localpart, data, err := obtainSavedTags(req, userID, roomID, accountDB)
if err != nil {
return httputil.LogThenError(req, err)
}
// If there are no tags in the database, exit
if len(data) == 0 {
// Spec only defines 200 responses for this endpoint so we don't return anything else.
return util.JSONResponse{
Code: http.StatusOK,
JSON: struct{}{},
}
}
var tagContent gomatrix.TagContent
err = json.Unmarshal(data[0].Content, &tagContent)
if err != nil {
return httputil.LogThenError(req, err)
}
// Check whether the tag to be deleted exists
if _, ok := tagContent.Tags[tag]; ok {
delete(tagContent.Tags, tag)
} else {
// Spec only defines 200 responses for this endpoint so we don't return anything else.
return util.JSONResponse{
Code: http.StatusOK,
JSON: struct{}{},
}
}
if err = saveTagData(req, localpart, roomID, accountDB, tagContent); err != nil {
return httputil.LogThenError(req, err)
}
// Send data to syncProducer in order to inform clients of changes
// Run in a goroutine in order to prevent blocking the tag request response
go func() {
if err := syncProducer.SendData(userID, roomID, "m.tag"); err != nil {
logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi")
}
}()
return util.JSONResponse{
Code: http.StatusOK,
JSON: struct{}{},
}
}
// obtainSavedTags gets all tags scoped to a userID and roomID
// from the database
func obtainSavedTags(
req *http.Request,
userID string,
roomID string,
accountDB *accounts.Database,
) (string, []gomatrixserverlib.ClientEvent, error) {
localpart, _, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
return "", nil, err
}
data, err := accountDB.GetAccountDataByType(
req.Context(), localpart, roomID, "m.tag",
)
return localpart, data, err
}
// saveTagData saves the provided tag data into the database
func saveTagData(
req *http.Request,
localpart string,
roomID string,
accountDB *accounts.Database,
Tag gomatrix.TagContent,
) error {
newTagData, err := json.Marshal(Tag)
if err != nil {
return err
}
return accountDB.SaveAccountData(req.Context(), localpart, roomID, "m.tag", string(newTagData))
}

View file

@ -93,7 +93,7 @@ func Setup(
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/join/{roomIDOrAlias}", r0mux.Handle("/join/{roomIDOrAlias}",
common.MakeAuthAPI("join", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { common.MakeAuthAPI(gomatrixserverlib.Join, authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars, err := common.URLDecodeMapValues(mux.Vars(req)) vars, err := common.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
@ -132,6 +132,15 @@ func Setup(
nil, cfg, queryAPI, producer, transactionsCache) nil, cfg, queryAPI, producer, transactionsCache)
}), }),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/event/{eventID}",
common.MakeAuthAPI("rooms_get_event", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars, err := common.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
return GetEvent(req, device, vars["roomID"], vars["eventID"], cfg, queryAPI, federation, keyRing)
}),
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/state/{eventType:[^/]+/?}", r0mux.Handle("/rooms/{roomID}/state/{eventType:[^/]+/?}",
common.MakeAuthAPI("send_message", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { common.MakeAuthAPI("send_message", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars, err := common.URLDecodeMapValues(mux.Vars(req)) vars, err := common.URLDecodeMapValues(mux.Vars(req))
@ -236,6 +245,13 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
r0mux.Handle("/auth/{authType}/fallback/web",
common.MakeHTMLAPI("auth_fallback", func(w http.ResponseWriter, req *http.Request) *util.JSONResponse {
vars := mux.Vars(req)
return AuthFallback(w, req, vars["authType"], cfg)
}),
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
r0mux.Handle("/pushrules/", r0mux.Handle("/pushrules/",
common.MakeExternalAPI("push_rules", func(req *http.Request) util.JSONResponse { common.MakeExternalAPI("push_rules", func(req *http.Request) util.JSONResponse {
// TODO: Implement push rules API // TODO: Implement push rules API
@ -283,7 +299,7 @@ func Setup(
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return GetProfile(req, accountDB, vars["userID"], asAPI) return GetProfile(req, accountDB, &cfg, vars["userID"], asAPI, federation)
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
@ -293,7 +309,7 @@ func Setup(
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return GetAvatarURL(req, accountDB, vars["userID"], asAPI) return GetAvatarURL(req, accountDB, &cfg, vars["userID"], asAPI, federation)
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
@ -315,7 +331,7 @@ func Setup(
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return GetDisplayName(req, accountDB, vars["userID"], asAPI) return GetDisplayName(req, accountDB, &cfg, vars["userID"], asAPI, federation)
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
@ -483,4 +499,34 @@ func Setup(
}} }}
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/user/{userId}/rooms/{roomId}/tags",
common.MakeAuthAPI("get_tags", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars, err := common.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
return GetTags(req, accountDB, device, vars["userId"], vars["roomId"], syncProducer)
}),
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/user/{userId}/rooms/{roomId}/tags/{tag}",
common.MakeAuthAPI("put_tag", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars, err := common.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
return PutTag(req, accountDB, device, vars["userId"], vars["roomId"], vars["tag"], syncProducer)
}),
).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/user/{userId}/rooms/{roomId}/tags/{tag}",
common.MakeAuthAPI("delete_tag", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars, err := common.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
return DeleteTag(req, accountDB, device, vars["userId"], vars["roomId"], vars["tag"], syncProducer)
}),
).Methods(http.MethodDelete, http.MethodOptions)
} }

View file

@ -50,7 +50,7 @@ func SendEvent(
) util.JSONResponse { ) util.JSONResponse {
if txnID != nil { if txnID != nil {
// Try to fetch response from transactionsCache // Try to fetch response from transactionsCache
if res, ok := txnCache.FetchTransaction(*txnID); ok { if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID); ok {
return *res return *res
} }
} }
@ -60,18 +60,18 @@ func SendEvent(
return *resErr return *resErr
} }
var txnAndDeviceID *api.TransactionID var txnAndSessionID *api.TransactionID
if txnID != nil { if txnID != nil {
txnAndDeviceID = &api.TransactionID{ txnAndSessionID = &api.TransactionID{
TransactionID: *txnID, TransactionID: *txnID,
DeviceID: device.ID, SessionID: device.SessionID,
} }
} }
// pass the new event to the roomserver and receive the correct event ID // pass the new event to the roomserver and receive the correct event ID
// event ID in case of duplicate transaction is discarded // event ID in case of duplicate transaction is discarded
eventID, err := producer.SendEvents( eventID, err := producer.SendEvents(
req.Context(), []gomatrixserverlib.Event{*e}, cfg.Matrix.ServerName, txnAndDeviceID, req.Context(), []gomatrixserverlib.Event{*e}, cfg.Matrix.ServerName, txnAndSessionID,
) )
if err != nil { if err != nil {
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
@ -83,7 +83,7 @@ func SendEvent(
} }
// Add response to transactionsCache // Add response to transactionsCache
if txnID != nil { if txnID != nil {
txnCache.AddTransaction(*txnID, &res) txnCache.AddTransaction(device.AccessToken, *txnID, &res)
} }
return res return res

View file

@ -59,7 +59,7 @@ type idServerStoreInviteResponse struct {
PublicKey string `json:"public_key"` PublicKey string `json:"public_key"`
Token string `json:"token"` Token string `json:"token"`
DisplayName string `json:"display_name"` DisplayName string `json:"display_name"`
PublicKeys []common.PublicKey `json:"public_keys"` PublicKeys []gomatrixserverlib.PublicKey `json:"public_keys"`
} }
var ( var (
@ -91,7 +91,7 @@ func CheckAndProcessInvite(
producer *producers.RoomserverProducer, membership string, roomID string, producer *producers.RoomserverProducer, membership string, roomID string,
evTime time.Time, evTime time.Time,
) (inviteStoredOnIDServer bool, err error) { ) (inviteStoredOnIDServer bool, err error) {
if membership != "invite" || (body.Address == "" && body.IDServer == "" && body.Medium == "") { if membership != gomatrixserverlib.Invite || (body.Address == "" && body.IDServer == "" && body.Medium == "") {
// If none of the 3PID-specific fields are supplied, it's a standard invite // If none of the 3PID-specific fields are supplied, it's a standard invite
// so return nil for it to be processed as such // so return nil for it to be processed as such
return return
@ -342,7 +342,7 @@ func emit3PIDInviteEvent(
} }
validityURL := fmt.Sprintf("https://%s/_matrix/identity/api/v1/pubkey/isvalid", body.IDServer) validityURL := fmt.Sprintf("https://%s/_matrix/identity/api/v1/pubkey/isvalid", body.IDServer)
content := common.ThirdPartyInviteContent{ content := gomatrixserverlib.ThirdPartyInviteContent{
DisplayName: res.DisplayName, DisplayName: res.DisplayName,
KeyValidityURL: validityURL, KeyValidityURL: validityURL,
PublicKey: res.PublicKey, PublicKey: res.PublicKey,

View file

@ -86,7 +86,7 @@ func main() {
// Build a m.room.member event. // Build a m.room.member event.
b.Type = "m.room.member" b.Type = "m.room.member"
b.StateKey = userID b.StateKey = userID
b.SetContent(map[string]string{"membership": "join"}) // nolint: errcheck b.SetContent(map[string]string{"membership": gomatrixserverlib.Join}) // nolint: errcheck
b.AuthEvents = []gomatrixserverlib.EventReference{create} b.AuthEvents = []gomatrixserverlib.EventReference{create}
member := buildAndOutput() member := buildAndOutput()

View file

@ -498,6 +498,11 @@ func (config *Dendrite) checkMatrix(configErrs *configErrors) {
checkNotEmpty(configErrs, "matrix.server_name", string(config.Matrix.ServerName)) checkNotEmpty(configErrs, "matrix.server_name", string(config.Matrix.ServerName))
checkNotEmpty(configErrs, "matrix.private_key", string(config.Matrix.PrivateKeyPath)) checkNotEmpty(configErrs, "matrix.private_key", string(config.Matrix.PrivateKeyPath))
checkNotZero(configErrs, "matrix.federation_certificates", int64(len(config.Matrix.FederationCertificatePaths))) checkNotZero(configErrs, "matrix.federation_certificates", int64(len(config.Matrix.FederationCertificatePaths)))
if config.Matrix.RecaptchaEnabled {
checkNotEmpty(configErrs, "matrix.recaptcha_public_key", string(config.Matrix.RecaptchaPublicKey))
checkNotEmpty(configErrs, "matrix.recaptcha_private_key", string(config.Matrix.RecaptchaPrivateKey))
checkNotEmpty(configErrs, "matrix.recaptcha_siteverify_api", string(config.Matrix.RecaptchaSiteVerifyAPI))
}
} }
// checkMedia verifies the parameters media.* are valid. // checkMedia verifies the parameters media.* are valid.

View file

@ -54,12 +54,14 @@ database:
server_key: "postgresql:///server_keys" server_key: "postgresql:///server_keys"
sync_api: "postgresql:///syn_api" sync_api: "postgresql:///syn_api"
room_server: "postgresql:///room_server" room_server: "postgresql:///room_server"
appservice: "postgresql:///appservice"
listen: listen:
room_server: "localhost:7770" room_server: "localhost:7770"
client_api: "localhost:7771" client_api: "localhost:7771"
federation_api: "localhost:7772" federation_api: "localhost:7772"
sync_api: "localhost:7773" sync_api: "localhost:7773"
media_api: "localhost:7774" media_api: "localhost:7774"
appservice_api: "localhost:7777"
typing_server: "localhost:7778" typing_server: "localhost:7778"
logging: logging:
- type: "file" - type: "file"

View file

@ -14,47 +14,7 @@
package common package common
// CreateContent is the event content for http://matrix.org/docs/spec/client_server/r0.2.0.html#m-room-create import "github.com/matrix-org/gomatrixserverlib"
type CreateContent struct {
Creator string `json:"creator"`
Federate *bool `json:"m.federate,omitempty"`
}
// MemberContent is the event content for http://matrix.org/docs/spec/client_server/r0.2.0.html#m-room-member
type MemberContent struct {
Membership string `json:"membership"`
DisplayName string `json:"displayname,omitempty"`
AvatarURL string `json:"avatar_url,omitempty"`
Reason string `json:"reason,omitempty"`
ThirdPartyInvite *TPInvite `json:"third_party_invite,omitempty"`
}
// TPInvite is the "Invite" structure defined at http://matrix.org/docs/spec/client_server/r0.2.0.html#m-room-member
type TPInvite struct {
DisplayName string `json:"display_name"`
Signed TPInviteSigned `json:"signed"`
}
// TPInviteSigned is the "signed" structure defined at http://matrix.org/docs/spec/client_server/r0.2.0.html#m-room-member
type TPInviteSigned struct {
MXID string `json:"mxid"`
Signatures map[string]map[string]string `json:"signatures"`
Token string `json:"token"`
}
// ThirdPartyInviteContent is the content event for https://matrix.org/docs/spec/client_server/r0.2.0.html#m-room-third-party-invite
type ThirdPartyInviteContent struct {
DisplayName string `json:"display_name"`
KeyValidityURL string `json:"key_validity_url"`
PublicKey string `json:"public_key"`
PublicKeys []PublicKey `json:"public_keys"`
}
// PublicKey is the PublicKeys structure in https://matrix.org/docs/spec/client_server/r0.2.0.html#m-room-third-party-invite
type PublicKey struct {
KeyValidityURL string `json:"key_validity_url"`
PublicKey string `json:"public_key"`
}
// NameContent is the event content for https://matrix.org/docs/spec/client_server/r0.2.0.html#m-room-name // NameContent is the event content for https://matrix.org/docs/spec/client_server/r0.2.0.html#m-room-name
type NameContent struct { type NameContent struct {
@ -71,51 +31,26 @@ type GuestAccessContent struct {
GuestAccess string `json:"guest_access"` GuestAccess string `json:"guest_access"`
} }
// JoinRulesContent is the event content for http://matrix.org/docs/spec/client_server/r0.2.0.html#m-room-join-rules
type JoinRulesContent struct {
JoinRule string `json:"join_rule"`
}
// HistoryVisibilityContent is the event content for http://matrix.org/docs/spec/client_server/r0.2.0.html#m-room-history-visibility // HistoryVisibilityContent is the event content for http://matrix.org/docs/spec/client_server/r0.2.0.html#m-room-history-visibility
type HistoryVisibilityContent struct { type HistoryVisibilityContent struct {
HistoryVisibility string `json:"history_visibility"` HistoryVisibility string `json:"history_visibility"`
} }
// PowerLevelContent is the event content for http://matrix.org/docs/spec/client_server/r0.2.0.html#m-room-power-levels
type PowerLevelContent struct {
EventsDefault int `json:"events_default"`
Invite int `json:"invite"`
StateDefault int `json:"state_default"`
Redact int `json:"redact"`
Ban int `json:"ban"`
UsersDefault int `json:"users_default"`
Events map[string]int `json:"events"`
Kick int `json:"kick"`
Users map[string]int `json:"users"`
}
// InitialPowerLevelsContent returns the initial values for m.room.power_levels on room creation // InitialPowerLevelsContent returns the initial values for m.room.power_levels on room creation
// if they have not been specified. // if they have not been specified.
// http://matrix.org/docs/spec/client_server/r0.2.0.html#m-room-power-levels // http://matrix.org/docs/spec/client_server/r0.2.0.html#m-room-power-levels
// https://github.com/matrix-org/synapse/blob/v0.19.2/synapse/handlers/room.py#L294 // https://github.com/matrix-org/synapse/blob/v0.19.2/synapse/handlers/room.py#L294
func InitialPowerLevelsContent(roomCreator string) PowerLevelContent { func InitialPowerLevelsContent(roomCreator string) (c gomatrixserverlib.PowerLevelContent) {
return PowerLevelContent{ c.Defaults()
EventsDefault: 0, c.Events = map[string]int64{
Invite: 0,
StateDefault: 50,
Redact: 50,
Ban: 50,
UsersDefault: 0,
Events: map[string]int{
"m.room.name": 50, "m.room.name": 50,
"m.room.power_levels": 100, "m.room.power_levels": 100,
"m.room.history_visibility": 100, "m.room.history_visibility": 100,
"m.room.canonical_alias": 50, "m.room.canonical_alias": 50,
"m.room.avatar": 50, "m.room.avatar": 50,
},
Kick: 50,
Users: map[string]int{roomCreator: 100},
} }
c.Users = map[string]int64{roomCreator: 100}
return c
} }
// AliasesContent is the event content for http://matrix.org/docs/spec/client_server/r0.2.0.html#m-room-aliases // AliasesContent is the event content for http://matrix.org/docs/spec/client_server/r0.2.0.html#m-room-aliases

View file

@ -10,6 +10,7 @@ import (
"github.com/matrix-org/util" "github.com/matrix-org/util"
opentracing "github.com/opentracing/opentracing-go" opentracing "github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go/ext" "github.com/opentracing/opentracing-go/ext"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp" "github.com/prometheus/client_golang/prometheus/promhttp"
) )
@ -43,6 +44,24 @@ func MakeExternalAPI(metricsName string, f func(*http.Request) util.JSONResponse
return http.HandlerFunc(withSpan) return http.HandlerFunc(withSpan)
} }
// MakeHTMLAPI adds Span metrics to the HTML Handler function
// This is used to serve HTML alongside JSON error messages
func MakeHTMLAPI(metricsName string, f func(http.ResponseWriter, *http.Request) *util.JSONResponse) http.Handler {
withSpan := func(w http.ResponseWriter, req *http.Request) {
span := opentracing.StartSpan(metricsName)
defer span.Finish()
req = req.WithContext(opentracing.ContextWithSpan(req.Context(), span))
if err := f(w, req); err != nil {
h := util.MakeJSONAPI(util.NewJSONRequestHandler(func(req *http.Request) util.JSONResponse {
return *err
}))
h.ServeHTTP(w, req)
}
}
return prometheus.InstrumentHandler(metricsName, http.HandlerFunc(withSpan))
}
// MakeInternalAPI turns a util.JSONRequestHandler function into an http.Handler. // MakeInternalAPI turns a util.JSONRequestHandler function into an http.Handler.
// This is used for APIs that are internal to dendrite. // This is used for APIs that are internal to dendrite.
// If we are passed a tracing context in the request headers then we use that // If we are passed a tracing context in the request headers then we use that

View file

@ -15,9 +15,12 @@
package common package common
import ( import (
"fmt"
"os" "os"
"path" "path"
"path/filepath" "path/filepath"
"runtime"
"strings"
"github.com/matrix-org/dendrite/common/config" "github.com/matrix-org/dendrite/common/config"
"github.com/matrix-org/dugong" "github.com/matrix-org/dugong"
@ -54,15 +57,35 @@ func (h *logLevelHook) Levels() []logrus.Level {
return levels return levels
} }
// callerPrettyfier is a function that given a runtime.Frame object, will
// extract the calling function's name and file, and return them in a nicely
// formatted way
func callerPrettyfier(f *runtime.Frame) (string, string) {
// Retrieve just the function name
s := strings.Split(f.Function, ".")
funcname := s[len(s)-1]
// Append a newline + tab to it to move the actual log content to its own line
funcname += "\n\t"
// Surround the filepath in brackets and append line number so IDEs can quickly
// navigate
filename := fmt.Sprintf(" [%s:%d]", f.File, f.Line)
return funcname, filename
}
// SetupStdLogging configures the logging format to standard output. Typically, it is called when the config is not yet loaded. // SetupStdLogging configures the logging format to standard output. Typically, it is called when the config is not yet loaded.
func SetupStdLogging() { func SetupStdLogging() {
logrus.SetReportCaller(true)
logrus.SetFormatter(&utcFormatter{ logrus.SetFormatter(&utcFormatter{
&logrus.TextFormatter{ &logrus.TextFormatter{
TimestampFormat: "2006-01-02T15:04:05.000000000Z07:00", TimestampFormat: "2006-01-02T15:04:05.000000000Z07:00",
FullTimestamp: true, FullTimestamp: true,
DisableColors: false, DisableColors: false,
DisableTimestamp: false, DisableTimestamp: false,
DisableSorting: false, QuoteEmptyFields: true,
CallerPrettyfier: callerPrettyfier,
}, },
}) })
} }
@ -71,8 +94,8 @@ func SetupStdLogging() {
// If something fails here it means that the logging was improperly configured, // If something fails here it means that the logging was improperly configured,
// so we just exit with the error // so we just exit with the error
func SetupHookLogging(hooks []config.LogrusHook, componentName string) { func SetupHookLogging(hooks []config.LogrusHook, componentName string) {
logrus.SetReportCaller(true)
for _, hook := range hooks { for _, hook := range hooks {
// Check we received a proper logging level // Check we received a proper logging level
level, err := logrus.ParseLevel(hook.Level) level, err := logrus.ParseLevel(hook.Level)
if err != nil { if err != nil {
@ -126,6 +149,7 @@ func setupFileHook(hook config.LogrusHook, level logrus.Level, componentName str
DisableColors: true, DisableColors: true,
DisableTimestamp: false, DisableTimestamp: false,
DisableSorting: false, DisableSorting: false,
QuoteEmptyFields: true,
}, },
}, },
&dugong.DailyRotationSchedule{GZip: true}, &dugong.DailyRotationSchedule{GZip: true},

View file

@ -22,7 +22,14 @@ import (
// DefaultCleanupPeriod represents the default time duration after which cacheCleanService runs. // DefaultCleanupPeriod represents the default time duration after which cacheCleanService runs.
const DefaultCleanupPeriod time.Duration = 30 * time.Minute const DefaultCleanupPeriod time.Duration = 30 * time.Minute
type txnsMap map[string]*util.JSONResponse type txnsMap map[CacheKey]*util.JSONResponse
// CacheKey is the type for the key in a transactions cache.
// This is needed because the spec requires transaction IDs to have a per-access token scope.
type CacheKey struct {
AccessToken string
TxnID string
}
// Cache represents a temporary store for response entries. // Cache represents a temporary store for response entries.
// Entries are evicted after a certain period, defined by cleanupPeriod. // Entries are evicted after a certain period, defined by cleanupPeriod.
@ -50,14 +57,14 @@ func NewWithCleanupPeriod(cleanupPeriod time.Duration) *Cache {
return &t return &t
} }
// FetchTransaction looks up an entry for txnID in Cache. // FetchTransaction looks up an entry for the (accessToken, txnID) tuple in Cache.
// Looks in both the txnMaps. // Looks in both the txnMaps.
// Returns (JSON response, true) if txnID is found, else the returned bool is false. // Returns (JSON response, true) if txnID is found, else the returned bool is false.
func (t *Cache) FetchTransaction(txnID string) (*util.JSONResponse, bool) { func (t *Cache) FetchTransaction(accessToken, txnID string) (*util.JSONResponse, bool) {
t.RLock() t.RLock()
defer t.RUnlock() defer t.RUnlock()
for _, txns := range t.txnsMaps { for _, txns := range t.txnsMaps {
res, ok := txns[txnID] res, ok := txns[CacheKey{accessToken, txnID}]
if ok { if ok {
return res, true return res, true
} }
@ -65,13 +72,13 @@ func (t *Cache) FetchTransaction(txnID string) (*util.JSONResponse, bool) {
return nil, false return nil, false
} }
// AddTransaction adds an entry for txnID in Cache for later access. // AddTransaction adds an entry for the (accessToken, txnID) tuple in Cache.
// Adds to the front txnMap. // Adds to the front txnMap.
func (t *Cache) AddTransaction(txnID string, res *util.JSONResponse) { func (t *Cache) AddTransaction(accessToken, txnID string, res *util.JSONResponse) {
t.Lock() t.Lock()
defer t.Unlock() defer t.Unlock()
t.txnsMaps[0][txnID] = res t.txnsMaps[0][CacheKey{accessToken, txnID}] = res
} }
// cacheCleanService is responsible for cleaning up entries after cleanupPeriod. // cacheCleanService is responsible for cleaning up entries after cleanupPeriod.

View file

@ -24,27 +24,54 @@ type fakeType struct {
} }
var ( var (
fakeAccessToken = "aRandomAccessToken"
fakeAccessToken2 = "anotherRandomAccessToken"
fakeTxnID = "aRandomTxnID" fakeTxnID = "aRandomTxnID"
fakeResponse = &util.JSONResponse{Code: http.StatusOK, JSON: fakeType{ID: "0"}} fakeResponse = &util.JSONResponse{
Code: http.StatusOK, JSON: fakeType{ID: "0"},
}
fakeResponse2 = &util.JSONResponse{
Code: http.StatusOK, JSON: fakeType{ID: "1"},
}
) )
// TestCache creates a New Cache and tests AddTransaction & FetchTransaction // TestCache creates a New Cache and tests AddTransaction & FetchTransaction
func TestCache(t *testing.T) { func TestCache(t *testing.T) {
fakeTxnCache := New() fakeTxnCache := New()
fakeTxnCache.AddTransaction(fakeTxnID, fakeResponse) fakeTxnCache.AddTransaction(fakeAccessToken, fakeTxnID, fakeResponse)
// Add entries for noise. // Add entries for noise.
for i := 1; i <= 100; i++ { for i := 1; i <= 100; i++ {
fakeTxnCache.AddTransaction( fakeTxnCache.AddTransaction(
fakeAccessToken,
fakeTxnID+string(i), fakeTxnID+string(i),
&util.JSONResponse{Code: http.StatusOK, JSON: fakeType{ID: string(i)}}, &util.JSONResponse{Code: http.StatusOK, JSON: fakeType{ID: string(i)}},
) )
} }
testResponse, ok := fakeTxnCache.FetchTransaction(fakeTxnID) testResponse, ok := fakeTxnCache.FetchTransaction(fakeAccessToken, fakeTxnID)
if !ok { if !ok {
t.Error("Failed to retrieve entry for txnID: ", fakeTxnID) t.Error("Failed to retrieve entry for txnID: ", fakeTxnID)
} else if testResponse.JSON != fakeResponse.JSON { } else if testResponse.JSON != fakeResponse.JSON {
t.Error("Fetched response incorrect. Expected: ", fakeResponse.JSON, " got: ", testResponse.JSON) t.Error("Fetched response incorrect. Expected: ", fakeResponse.JSON, " got: ", testResponse.JSON)
} }
} }
// TestCacheScope ensures transactions with the same transaction ID are not shared
// across multiple access tokens.
func TestCacheScope(t *testing.T) {
cache := New()
cache.AddTransaction(fakeAccessToken, fakeTxnID, fakeResponse)
cache.AddTransaction(fakeAccessToken2, fakeTxnID, fakeResponse2)
if res, ok := cache.FetchTransaction(fakeAccessToken, fakeTxnID); !ok {
t.Errorf("failed to retrieve entry for (%s, %s)", fakeAccessToken, fakeTxnID)
} else if res.JSON != fakeResponse.JSON {
t.Errorf("Wrong cache entry for (%s, %s). Expected: %v; got: %v", fakeAccessToken, fakeTxnID, fakeResponse.JSON, res.JSON)
}
if res, ok := cache.FetchTransaction(fakeAccessToken2, fakeTxnID); !ok {
t.Errorf("failed to retrieve entry for (%s, %s)", fakeAccessToken, fakeTxnID)
} else if res.JSON != fakeResponse2.JSON {
t.Errorf("Wrong cache entry for (%s, %s). Expected: %v; got: %v", fakeAccessToken, fakeTxnID, fakeResponse2.JSON, res.JSON)
}
}

View file

@ -15,9 +15,14 @@
package common package common
import ( import (
"errors"
"strconv" "strconv"
) )
// ErrProfileNoExists is returned when trying to lookup a user's profile that
// doesn't exist locally.
var ErrProfileNoExists = errors.New("no known profile for given user ID")
// AccountData represents account data sent from the client API server to the // AccountData represents account data sent from the client API server to the
// sync API server // sync API server
type AccountData struct { type AccountData struct {

View file

@ -58,7 +58,7 @@ docker-compose up kafka zookeeper postgres
and the following dendrite components and the following dendrite components
``` ```
docker-compose up client_api media_api sync_api room_server public_rooms_api docker-compose up client_api media_api sync_api room_server public_rooms_api typing_server
docker-compose up client_api_proxy docker-compose up client_api_proxy
``` ```

View file

@ -114,6 +114,7 @@ listen:
media_api: "media_api:7774" media_api: "media_api:7774"
public_rooms_api: "public_rooms_api:7775" public_rooms_api: "public_rooms_api:7775"
federation_sender: "federation_sender:7776" federation_sender: "federation_sender:7776"
typing_server: "typing_server:7777"
# The configuration for tracing the dendrite components. # The configuration for tracing the dendrite components.
tracing: tracing:

View file

@ -95,6 +95,16 @@ services:
networks: networks:
- internal - internal
typing_server:
container_name: dendrite_typing_server
hostname: typing_server
entrypoint: ["bash", "./docker/services/typing-server.sh"]
build: ./
volumes:
- ..:/build
networks:
- internal
federation_api_proxy: federation_api_proxy:
container_name: dendrite_federation_api_proxy container_name: dendrite_federation_api_proxy
hostname: federation_api_proxy hostname: federation_api_proxy

View file

@ -0,0 +1,5 @@
#!/bin/bash
bash ./docker/build.sh
./bin/dendrite-typing-server --config=dendrite.yaml

Binary file not shown.

After

Width:  |  Height:  |  Size: 25 KiB

View file

@ -58,7 +58,7 @@ func MakeJoin(
Type: "m.room.member", Type: "m.room.member",
StateKey: &userID, StateKey: &userID,
} }
err = builder.SetContent(map[string]interface{}{"membership": "join"}) err = builder.SetContent(map[string]interface{}{"membership": gomatrixserverlib.Join})
if err != nil { if err != nil {
return httputil.LogThenError(httpReq, err) return httputil.LogThenError(httpReq, err)
} }

View file

@ -56,7 +56,7 @@ func MakeLeave(
Type: "m.room.member", Type: "m.room.member",
StateKey: &userID, StateKey: &userID,
} }
err = builder.SetContent(map[string]interface{}{"membership": "leave"}) err = builder.SetContent(map[string]interface{}{"membership": gomatrixserverlib.Leave})
if err != nil { if err != nil {
return httputil.LogThenError(httpReq, err) return httputil.LogThenError(httpReq, err)
} }
@ -153,7 +153,7 @@ func SendLeave(
mem, err := event.Membership() mem, err := event.Membership()
if err != nil { if err != nil {
return httputil.LogThenError(httpReq, err) return httputil.LogThenError(httpReq, err)
} else if mem != "leave" { } else if mem != gomatrixserverlib.Leave {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON("The membership in the event content must be set to leave"), JSON: jsonerror.BadJSON("The membership in the event content must be set to leave"),

View file

@ -53,7 +53,7 @@ func GetProfile(
return httputil.LogThenError(httpReq, err) return httputil.LogThenError(httpReq, err)
} }
profile, err := appserviceAPI.RetreiveUserProfile(httpReq.Context(), userID, asAPI, accountDB) profile, err := appserviceAPI.RetrieveUserProfile(httpReq.Context(), userID, asAPI, accountDB)
if err != nil { if err != nil {
return httputil.LogThenError(httpReq, err) return httputil.LogThenError(httpReq, err)
} }

View file

@ -64,8 +64,9 @@ func Setup(
// {keyID} argument and always return a response containing all of the keys. // {keyID} argument and always return a response containing all of the keys.
v2keysmux.Handle("/server/{keyID}", localKeys).Methods(http.MethodGet) v2keysmux.Handle("/server/{keyID}", localKeys).Methods(http.MethodGet)
v2keysmux.Handle("/server/", localKeys).Methods(http.MethodGet) v2keysmux.Handle("/server/", localKeys).Methods(http.MethodGet)
v2keysmux.Handle("/server", localKeys).Methods(http.MethodGet)
v1fedmux.Handle("/send/{txnID}/", common.MakeFedAPI( v1fedmux.Handle("/send/{txnID}", common.MakeFedAPI(
"federation_send", cfg.Matrix.ServerName, keys, "federation_send", cfg.Matrix.ServerName, keys,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse { func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse {
vars, err := common.URLDecodeMapValues(mux.Vars(httpReq)) vars, err := common.URLDecodeMapValues(mux.Vars(httpReq))
@ -260,7 +261,7 @@ func Setup(
}, },
)).Methods(http.MethodPost) )).Methods(http.MethodPost)
v1fedmux.Handle("/backfill/{roomID}/", common.MakeFedAPI( v1fedmux.Handle("/backfill/{roomID}", common.MakeFedAPI(
"federation_backfill", cfg.Matrix.ServerName, keys, "federation_backfill", cfg.Matrix.ServerName, keys,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse { func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse {
vars, err := common.URLDecodeMapValues(mux.Vars(httpReq)) vars, err := common.URLDecodeMapValues(mux.Vars(httpReq))

View file

@ -27,7 +27,6 @@ import (
"github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/clientapi/producers" "github.com/matrix-org/dendrite/clientapi/producers"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/common/config" "github.com/matrix-org/dendrite/common/config"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
@ -42,7 +41,7 @@ type invite struct {
RoomID string `json:"room_id"` RoomID string `json:"room_id"`
Sender string `json:"sender"` Sender string `json:"sender"`
Token string `json:"token"` Token string `json:"token"`
Signed common.TPInviteSigned `json:"signed"` Signed gomatrixserverlib.MemberThirdPartyInviteSigned `json:"signed"`
} }
type invites struct { type invites struct {
@ -194,16 +193,16 @@ func createInviteFrom3PIDInvite(
StateKey: &inv.MXID, StateKey: &inv.MXID,
} }
profile, err := appserviceAPI.RetreiveUserProfile(ctx, inv.MXID, asAPI, accountDB) profile, err := appserviceAPI.RetrieveUserProfile(ctx, inv.MXID, asAPI, accountDB)
if err != nil { if err != nil {
return nil, err return nil, err
} }
content := common.MemberContent{ content := gomatrixserverlib.MemberContent{
AvatarURL: profile.AvatarURL, AvatarURL: profile.AvatarURL,
DisplayName: profile.DisplayName, DisplayName: profile.DisplayName,
Membership: "invite", Membership: gomatrixserverlib.Invite,
ThirdPartyInvite: &common.TPInvite{ ThirdPartyInvite: &gomatrixserverlib.MemberThirdPartyInvite{
Signed: inv.Signed, Signed: inv.Signed,
}, },
} }
@ -330,7 +329,7 @@ func sendToRemoteServer(
func fillDisplayName( func fillDisplayName(
builder *gomatrixserverlib.EventBuilder, authEvents gomatrixserverlib.AuthEvents, builder *gomatrixserverlib.EventBuilder, authEvents gomatrixserverlib.AuthEvents,
) error { ) error {
var content common.MemberContent var content gomatrixserverlib.MemberContent
if err := json.Unmarshal(builder.Content, &content); err != nil { if err := json.Unmarshal(builder.Content, &content); err != nil {
return err return err
} }
@ -343,7 +342,7 @@ func fillDisplayName(
return nil return nil
} }
var thirdPartyInviteContent common.ThirdPartyInviteContent var thirdPartyInviteContent gomatrixserverlib.ThirdPartyInviteContent
if err := json.Unmarshal(thirdPartyInviteEvent.Content(), &thirdPartyInviteContent); err != nil { if err := json.Unmarshal(thirdPartyInviteEvent.Content(), &thirdPartyInviteContent); err != nil {
return err return err
} }

View file

@ -0,0 +1,98 @@
package api
import (
"context"
"net/http"
commonHTTP "github.com/matrix-org/dendrite/common/http"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/federationsender/types"
"github.com/opentracing/opentracing-go"
)
// QueryJoinedHostsInRoomRequest is a request to QueryJoinedHostsInRoom
type QueryJoinedHostsInRoomRequest struct {
RoomID string `json:"room_id"`
}
// QueryJoinedHostsInRoomResponse is a response to QueryJoinedHostsInRoom
type QueryJoinedHostsInRoomResponse struct {
JoinedHosts []types.JoinedHost `json:"joined_hosts"`
}
// QueryJoinedHostServerNamesRequest is a request to QueryJoinedHostServerNames
type QueryJoinedHostServerNamesInRoomRequest struct {
RoomID string `json:"room_id"`
}
// QueryJoinedHostServerNamesResponse is a response to QueryJoinedHostServerNames
type QueryJoinedHostServerNamesInRoomResponse struct {
ServerNames []gomatrixserverlib.ServerName `json:"server_names"`
}
// FederationSenderQueryAPI is used to query information from the federation sender.
type FederationSenderQueryAPI interface {
// Query the joined hosts and the membership events accounting for their participation in a room.
// Note that if a server has multiple users in the room, it will have multiple entries in the returned slice.
// See `QueryJoinedHostServerNamesInRoom` for a de-duplicated version.
QueryJoinedHostsInRoom(
ctx context.Context,
request *QueryJoinedHostsInRoomRequest,
response *QueryJoinedHostsInRoomResponse,
) error
// Query the server names of the joined hosts in a room.
// Unlike QueryJoinedHostsInRoom, this function returns a de-duplicated slice
// containing only the server names (without information for membership events).
QueryJoinedHostServerNamesInRoom(
ctx context.Context,
request *QueryJoinedHostServerNamesInRoomRequest,
response *QueryJoinedHostServerNamesInRoomResponse,
) error
}
// FederationSenderQueryJoinedHostsInRoomPath is the HTTP path for the QueryJoinedHostsInRoom API.
const FederationSenderQueryJoinedHostsInRoomPath = "/api/federationsender/queryJoinedHostsInRoom"
// FederationSenderQueryJoinedHostServerNamesInRoomPath is the HTTP path for the QueryJoinedHostServerNamesInRoom API.
const FederationSenderQueryJoinedHostServerNamesInRoomPath = "/api/federationsender/queryJoinedHostServerNamesInRoom"
// NewFederationSenderQueryAPIHTTP creates a FederationSenderQueryAPI implemented by talking to a HTTP POST API.
// If httpClient is nil then it uses the http.DefaultClient
func NewFederationSenderQueryAPIHTTP(federationSenderURL string, httpClient *http.Client) FederationSenderQueryAPI {
if httpClient == nil {
httpClient = http.DefaultClient
}
return &httpFederationSenderQueryAPI{federationSenderURL, httpClient}
}
type httpFederationSenderQueryAPI struct {
federationSenderURL string
httpClient *http.Client
}
// QueryJoinedHostsInRoom implements FederationSenderQueryAPI
func (h *httpFederationSenderQueryAPI) QueryJoinedHostsInRoom(
ctx context.Context,
request *QueryJoinedHostsInRoomRequest,
response *QueryJoinedHostsInRoomResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryJoinedHostsInRoom")
defer span.Finish()
apiURL := h.federationSenderURL + FederationSenderQueryJoinedHostsInRoomPath
return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
}
// QueryJoinedHostServerNamesInRoom implements FederationSenderQueryAPI
func (h *httpFederationSenderQueryAPI) QueryJoinedHostServerNamesInRoom(
ctx context.Context,
request *QueryJoinedHostServerNamesInRoomRequest,
response *QueryJoinedHostServerNamesInRoomResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryJoinedHostServerNamesInRoom")
defer span.Finish()
apiURL := h.federationSenderURL + FederationSenderQueryJoinedHostServerNamesInRoomPath
return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
}

View file

@ -233,7 +233,7 @@ func joinedHostsFromEvents(evs []gomatrixserverlib.Event) ([]types.JoinedHost, e
if err != nil { if err != nil {
return nil, err return nil, err
} }
if membership != "join" { if membership != gomatrixserverlib.Join {
continue continue
} }
_, serverName, err := gomatrixserverlib.SplitID('@', *ev.StateKey()) _, serverName, err := gomatrixserverlib.SplitID('@', *ev.StateKey())

View file

@ -15,11 +15,15 @@
package federationsender package federationsender
import ( import (
"net/http"
"github.com/matrix-org/dendrite/common/basecomponent" "github.com/matrix-org/dendrite/common/basecomponent"
"github.com/matrix-org/dendrite/federationsender/api"
"github.com/matrix-org/dendrite/federationsender/consumers" "github.com/matrix-org/dendrite/federationsender/consumers"
"github.com/matrix-org/dendrite/federationsender/query"
"github.com/matrix-org/dendrite/federationsender/queue" "github.com/matrix-org/dendrite/federationsender/queue"
"github.com/matrix-org/dendrite/federationsender/storage" "github.com/matrix-org/dendrite/federationsender/storage"
"github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
@ -29,8 +33,8 @@ import (
func SetupFederationSenderComponent( func SetupFederationSenderComponent(
base *basecomponent.BaseDendrite, base *basecomponent.BaseDendrite,
federation *gomatrixserverlib.FederationClient, federation *gomatrixserverlib.FederationClient,
queryAPI api.RoomserverQueryAPI, rsQueryAPI roomserverAPI.RoomserverQueryAPI,
) { ) api.FederationSenderQueryAPI {
federationSenderDB, err := storage.NewDatabase(string(base.Cfg.Database.FederationSender)) federationSenderDB, err := storage.NewDatabase(string(base.Cfg.Database.FederationSender))
if err != nil { if err != nil {
logrus.WithError(err).Panic("failed to connect to federation sender db") logrus.WithError(err).Panic("failed to connect to federation sender db")
@ -40,7 +44,7 @@ func SetupFederationSenderComponent(
rsConsumer := consumers.NewOutputRoomEventConsumer( rsConsumer := consumers.NewOutputRoomEventConsumer(
base.Cfg, base.KafkaConsumer, queues, base.Cfg, base.KafkaConsumer, queues,
federationSenderDB, queryAPI, federationSenderDB, rsQueryAPI,
) )
if err = rsConsumer.Start(); err != nil { if err = rsConsumer.Start(); err != nil {
logrus.WithError(err).Panic("failed to start room server consumer") logrus.WithError(err).Panic("failed to start room server consumer")
@ -52,4 +56,11 @@ func SetupFederationSenderComponent(
if err := tsConsumer.Start(); err != nil { if err := tsConsumer.Start(); err != nil {
logrus.WithError(err).Panic("failed to start typing server consumer") logrus.WithError(err).Panic("failed to start typing server consumer")
} }
queryAPI := query.FederationSenderQueryAPI{
DB: federationSenderDB,
}
queryAPI.SetupHTTP(http.DefaultServeMux)
return &queryAPI
} }

View file

@ -0,0 +1,91 @@
package query
import (
"context"
"encoding/json"
"net/http"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/federationsender/api"
"github.com/matrix-org/dendrite/federationsender/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
// FederationSenderQueryDatabase has the APIs needed to implement the query API.
type FederationSenderQueryDatabase interface {
GetJoinedHosts(
ctx context.Context, roomID string,
) ([]types.JoinedHost, error)
}
// FederationSenderQueryAPI is an implementation of api.FederationSenderQueryAPI
type FederationSenderQueryAPI struct {
DB FederationSenderQueryDatabase
}
// QueryJoinedHostsInRoom implements api.FederationSenderQueryAPI
func (f *FederationSenderQueryAPI) QueryJoinedHostsInRoom(
ctx context.Context,
request *api.QueryJoinedHostsInRoomRequest,
response *api.QueryJoinedHostsInRoomResponse,
) (err error) {
response.JoinedHosts, err = f.DB.GetJoinedHosts(ctx, request.RoomID)
return
}
// QueryJoinedHostServerNamesInRoom implements api.FederationSenderQueryAPI
func (f *FederationSenderQueryAPI) QueryJoinedHostServerNamesInRoom(
ctx context.Context,
request *api.QueryJoinedHostServerNamesInRoomRequest,
response *api.QueryJoinedHostServerNamesInRoomResponse,
) (err error) {
joinedHosts, err := f.DB.GetJoinedHosts(ctx, request.RoomID)
if err != nil {
return
}
serverNamesSet := make(map[gomatrixserverlib.ServerName]bool, len(joinedHosts))
for _, host := range joinedHosts {
serverNamesSet[host.ServerName] = true
}
response.ServerNames = make([]gomatrixserverlib.ServerName, 0, len(serverNamesSet))
for name := range serverNamesSet {
response.ServerNames = append(response.ServerNames, name)
}
return
}
// SetupHTTP adds the FederationSenderQueryAPI handlers to the http.ServeMux.
func (f *FederationSenderQueryAPI) SetupHTTP(servMux *http.ServeMux) {
servMux.Handle(
api.FederationSenderQueryJoinedHostsInRoomPath,
common.MakeInternalAPI("QueryJoinedHostsInRoom", func(req *http.Request) util.JSONResponse {
var request api.QueryJoinedHostsInRoomRequest
var response api.QueryJoinedHostsInRoomResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := f.QueryJoinedHostsInRoom(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
servMux.Handle(
api.FederationSenderQueryJoinedHostServerNamesInRoomPath,
common.MakeInternalAPI("QueryJoinedHostServerNamesInRoom", func(req *http.Request) util.JSONResponse {
var request api.QueryJoinedHostServerNamesInRoomRequest
var response api.QueryJoinedHostServerNamesInRoomResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := f.QueryJoinedHostServerNamesInRoom(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
}

12
go.mod
View file

@ -20,10 +20,11 @@ require (
github.com/jaegertracing/jaeger-client-go v0.0.0-20170921145708-3ad49a1d839b github.com/jaegertracing/jaeger-client-go v0.0.0-20170921145708-3ad49a1d839b
github.com/jaegertracing/jaeger-lib v0.0.0-20170920222118-21a3da6d66fe github.com/jaegertracing/jaeger-lib v0.0.0-20170920222118-21a3da6d66fe
github.com/klauspost/crc32 v0.0.0-20161016154125-cb6bfca970f6 github.com/klauspost/crc32 v0.0.0-20161016154125-cb6bfca970f6
github.com/konsorten/go-windows-terminal-sequences v1.0.2 // indirect
github.com/lib/pq v0.0.0-20170918175043-23da1db4f16d github.com/lib/pq v0.0.0-20170918175043-23da1db4f16d
github.com/matrix-org/dugong v0.0.0-20171220115018-ea0a4690a0d5 github.com/matrix-org/dugong v0.0.0-20171220115018-ea0a4690a0d5
github.com/matrix-org/gomatrix v0.0.0-20190130130140-385f072fe9af github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26
github.com/matrix-org/gomatrixserverlib v0.0.0-20190619132215-178ed5e3b8e2 github.com/matrix-org/gomatrixserverlib v0.0.0-20190814163046-d6285a18401f
github.com/matrix-org/naffka v0.0.0-20171115094957-662bfd0841d0 github.com/matrix-org/naffka v0.0.0-20171115094957-662bfd0841d0
github.com/matrix-org/util v0.0.0-20171127121716-2e2df66af2f5 github.com/matrix-org/util v0.0.0-20171127121716-2e2df66af2f5
github.com/matttproud/golang_protobuf_extensions v1.0.1 github.com/matttproud/golang_protobuf_extensions v1.0.1
@ -40,8 +41,9 @@ require (
github.com/prometheus/common v0.0.0-20170108231212-dd2f054febf4 github.com/prometheus/common v0.0.0-20170108231212-dd2f054febf4
github.com/prometheus/procfs v0.0.0-20170128160123-1878d9fbb537 github.com/prometheus/procfs v0.0.0-20170128160123-1878d9fbb537
github.com/rcrowley/go-metrics v0.0.0-20161128210544-1f30fe9094a5 github.com/rcrowley/go-metrics v0.0.0-20161128210544-1f30fe9094a5
github.com/sirupsen/logrus v1.3.0 github.com/sirupsen/logrus v1.4.2
github.com/stretchr/testify v1.2.2 github.com/stretchr/objx v0.2.0 // indirect
github.com/stretchr/testify v1.3.0
github.com/tidwall/gjson v1.1.5 github.com/tidwall/gjson v1.1.5
github.com/tidwall/match v1.0.1 github.com/tidwall/match v1.0.1
github.com/tidwall/sjson v1.0.3 github.com/tidwall/sjson v1.0.3
@ -54,7 +56,7 @@ require (
go.uber.org/zap v1.7.1 go.uber.org/zap v1.7.1
golang.org/x/crypto v0.0.0-20190131182504-b8fe1690c613 golang.org/x/crypto v0.0.0-20190131182504-b8fe1690c613
golang.org/x/net v0.0.0-20190301231341-16b79f2e4e95 golang.org/x/net v0.0.0-20190301231341-16b79f2e4e95
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33 golang.org/x/sys v0.0.0-20190712062909-fae7ac547cb7
gopkg.in/Shopify/sarama.v1 v1.11.0 gopkg.in/Shopify/sarama.v1 v1.11.0
gopkg.in/airbrake/gobrake.v2 v2.0.9 gopkg.in/airbrake/gobrake.v2 v2.0.9
gopkg.in/alecthomas/kingpin.v3-unstable v3.0.0-20170727041045-23bcc3c4eae3 gopkg.in/alecthomas/kingpin.v3-unstable v3.0.0-20170727041045-23bcc3c4eae3

18
go.sum
View file

@ -36,6 +36,7 @@ github.com/jaegertracing/jaeger-lib v0.0.0-20170920222118-21a3da6d66fe/go.mod h1
github.com/klauspost/crc32 v0.0.0-20161016154125-cb6bfca970f6 h1:KAZ1BW2TCmT6PRihDPpocIy1QTtsAsrx6TneU/4+CMg= github.com/klauspost/crc32 v0.0.0-20161016154125-cb6bfca970f6 h1:KAZ1BW2TCmT6PRihDPpocIy1QTtsAsrx6TneU/4+CMg=
github.com/klauspost/crc32 v0.0.0-20161016154125-cb6bfca970f6/go.mod h1:+ZoRqAPRLkC4NPOvfYeR5KNOrY6TD+/sAC3HXPZgDYg= github.com/klauspost/crc32 v0.0.0-20161016154125-cb6bfca970f6/go.mod h1:+ZoRqAPRLkC4NPOvfYeR5KNOrY6TD+/sAC3HXPZgDYg=
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
@ -47,10 +48,18 @@ github.com/matrix-org/gomatrix v0.0.0-20171003113848-a7fc80c8060c h1:aZap604NyBG
github.com/matrix-org/gomatrix v0.0.0-20171003113848-a7fc80c8060c/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrix v0.0.0-20171003113848-a7fc80c8060c/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0=
github.com/matrix-org/gomatrix v0.0.0-20190130130140-385f072fe9af h1:piaIBNQGIHnni27xRB7VKkEwoWCgAmeuYf8pxAyG0bI= github.com/matrix-org/gomatrix v0.0.0-20190130130140-385f072fe9af h1:piaIBNQGIHnni27xRB7VKkEwoWCgAmeuYf8pxAyG0bI=
github.com/matrix-org/gomatrix v0.0.0-20190130130140-385f072fe9af/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrix v0.0.0-20190130130140-385f072fe9af/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0=
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26 h1:Hr3zjRsq2bhrnp3Ky1qgx/fzCtCALOoGYylh2tpS9K4=
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0=
github.com/matrix-org/gomatrixserverlib v0.0.0-20181109104322-1c2cbc0872f0 h1:3UzhmERBbis4ZaB3imEbZwtDjGz/oVRC2cLLEajCzJA= github.com/matrix-org/gomatrixserverlib v0.0.0-20181109104322-1c2cbc0872f0 h1:3UzhmERBbis4ZaB3imEbZwtDjGz/oVRC2cLLEajCzJA=
github.com/matrix-org/gomatrixserverlib v0.0.0-20181109104322-1c2cbc0872f0/go.mod h1:YHyhIQUmuXyKtoVfDUMk/DyU93Taamlu6nPZkij/JtA= github.com/matrix-org/gomatrixserverlib v0.0.0-20181109104322-1c2cbc0872f0/go.mod h1:YHyhIQUmuXyKtoVfDUMk/DyU93Taamlu6nPZkij/JtA=
github.com/matrix-org/gomatrixserverlib v0.0.0-20190619132215-178ed5e3b8e2 h1:pYajAEdi3sowj4iSunqctchhcMNW3rDjeeH0T4uDkMY= github.com/matrix-org/gomatrixserverlib v0.0.0-20190619132215-178ed5e3b8e2 h1:pYajAEdi3sowj4iSunqctchhcMNW3rDjeeH0T4uDkMY=
github.com/matrix-org/gomatrixserverlib v0.0.0-20190619132215-178ed5e3b8e2/go.mod h1:sf0RcKOdiwJeTti7A313xsaejNUGYDq02MQZ4JD4w/E= github.com/matrix-org/gomatrixserverlib v0.0.0-20190619132215-178ed5e3b8e2/go.mod h1:sf0RcKOdiwJeTti7A313xsaejNUGYDq02MQZ4JD4w/E=
github.com/matrix-org/gomatrixserverlib v0.0.0-20190724145009-a6df10ef35d6 h1:B8n1H5Wb1B5jwLzTylBpY0kJCMRqrofT7PmOw4aJFJA=
github.com/matrix-org/gomatrixserverlib v0.0.0-20190724145009-a6df10ef35d6/go.mod h1:sf0RcKOdiwJeTti7A313xsaejNUGYDq02MQZ4JD4w/E=
github.com/matrix-org/gomatrixserverlib v0.0.0-20190805173246-3a2199d5ecd6 h1:xr69Hk6QM3RIN6JSvx3RpDowBGpHpDDqhqXCeySwYow=
github.com/matrix-org/gomatrixserverlib v0.0.0-20190805173246-3a2199d5ecd6/go.mod h1:sf0RcKOdiwJeTti7A313xsaejNUGYDq02MQZ4JD4w/E=
github.com/matrix-org/gomatrixserverlib v0.0.0-20190814163046-d6285a18401f h1:20CZL7ApB7xgR7sZF9yD/qpsP51Sfx0TTgUJ3vKgnZQ=
github.com/matrix-org/gomatrixserverlib v0.0.0-20190814163046-d6285a18401f/go.mod h1:sf0RcKOdiwJeTti7A313xsaejNUGYDq02MQZ4JD4w/E=
github.com/matrix-org/naffka v0.0.0-20171115094957-662bfd0841d0 h1:p7WTwG+aXM86+yVrYAiCMW3ZHSmotVvuRbjtt3jC+4A= github.com/matrix-org/naffka v0.0.0-20171115094957-662bfd0841d0 h1:p7WTwG+aXM86+yVrYAiCMW3ZHSmotVvuRbjtt3jC+4A=
github.com/matrix-org/naffka v0.0.0-20171115094957-662bfd0841d0/go.mod h1:cXoYQIENbdWIQHt1SyCo6Bl3C3raHwJ0wgVrXHSqf+A= github.com/matrix-org/naffka v0.0.0-20171115094957-662bfd0841d0/go.mod h1:cXoYQIENbdWIQHt1SyCo6Bl3C3raHwJ0wgVrXHSqf+A=
github.com/matrix-org/util v0.0.0-20171013132526-8b1c8ab81986 h1:TiWl4hLvezAhRPM8tPcPDFTysZ7k4T/1J4GPp/iqlZo= github.com/matrix-org/util v0.0.0-20171013132526-8b1c8ab81986 h1:TiWl4hLvezAhRPM8tPcPDFTysZ7k4T/1J4GPp/iqlZo=
@ -88,9 +97,14 @@ github.com/sirupsen/logrus v0.0.0-20170822132746-89742aefa4b2 h1:+8J/sCAVv2Y9Ct1
github.com/sirupsen/logrus v0.0.0-20170822132746-89742aefa4b2/go.mod h1:pMByvHTf9Beacp5x1UXfOR9xyW/9antXMhjMPG0dEzc= github.com/sirupsen/logrus v0.0.0-20170822132746-89742aefa4b2/go.mod h1:pMByvHTf9Beacp5x1UXfOR9xyW/9antXMhjMPG0dEzc=
github.com/sirupsen/logrus v1.3.0 h1:hI/7Q+DtNZ2kINb6qt/lS+IyXnHQe9e90POfeewL/ME= github.com/sirupsen/logrus v1.3.0 h1:hI/7Q+DtNZ2kINb6qt/lS+IyXnHQe9e90POfeewL/ME=
github.com/sirupsen/logrus v1.3.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.3.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
github.com/sirupsen/logrus v1.4.2 h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4=
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE=
github.com/stretchr/testify v0.0.0-20170809224252-890a5c3458b4/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v0.0.0-20170809224252-890a5c3458b4/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/tidwall/gjson v1.0.2 h1:5BsM7kyEAHAUGEGDkEKO9Mdyiuw6QQ6TSDdarP0Nnmk= github.com/tidwall/gjson v1.0.2 h1:5BsM7kyEAHAUGEGDkEKO9Mdyiuw6QQ6TSDdarP0Nnmk=
github.com/tidwall/gjson v1.0.2/go.mod h1:c/nTNbUr0E0OrXEhq1pwa8iEgc2DOt4ZZqAt1HtCkPA= github.com/tidwall/gjson v1.0.2/go.mod h1:c/nTNbUr0E0OrXEhq1pwa8iEgc2DOt4ZZqAt1HtCkPA=
github.com/tidwall/gjson v1.1.5 h1:QysILxBeUEY3GTLA0fQVgkQG1zme8NxGvhh2SSqWNwI= github.com/tidwall/gjson v1.1.5 h1:QysILxBeUEY3GTLA0fQVgkQG1zme8NxGvhh2SSqWNwI=
@ -126,6 +140,9 @@ golang.org/x/sys v0.0.0-20171012164349-43eea11bc926 h1:PY6OU86NqbyZiOzaPnDw6oOjA
golang.org/x/sys v0.0.0-20171012164349-43eea11bc926/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20171012164349-43eea11bc926/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33 h1:I6FyU15t786LL7oL/hn43zqTuEGr4PN7F4XJ1p4E3Y8= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33 h1:I6FyU15t786LL7oL/hn43zqTuEGr4PN7F4XJ1p4E3Y8=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190712062909-fae7ac547cb7 h1:LepdCS8Gf/MVejFIt8lsiexZATdoGVyp5bcyS+rYoUI=
golang.org/x/sys v0.0.0-20190712062909-fae7ac547cb7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
gopkg.in/Shopify/sarama.v1 v1.11.0 h1:/3kaCyeYaPbr59IBjeqhIcUOB1vXlIVqXAYa5g5C5F0= gopkg.in/Shopify/sarama.v1 v1.11.0 h1:/3kaCyeYaPbr59IBjeqhIcUOB1vXlIVqXAYa5g5C5F0=
gopkg.in/Shopify/sarama.v1 v1.11.0/go.mod h1:AxnvoaevB2nBjNK17cG61A3LleFcWFwVBHBt+cot4Oc= gopkg.in/Shopify/sarama.v1 v1.11.0/go.mod h1:AxnvoaevB2nBjNK17cG61A3LleFcWFwVBHBt+cot4Oc=
gopkg.in/airbrake/gobrake.v2 v2.0.9/go.mod h1:/h5ZAUhDkGaJfjzjKLSjv6zCL6O0LLBxU4K+aSYdM/U= gopkg.in/airbrake/gobrake.v2 v2.0.9/go.mod h1:/h5ZAUhDkGaJfjzjKLSjv6zCL6O0LLBxU4K+aSYdM/U=
@ -140,4 +157,3 @@ gopkg.in/yaml.v2 v2.0.0-20171116090243-287cf08546ab h1:yZ6iByf7GKeJ3gsd1Dr/xaj1D
gopkg.in/yaml.v2 v2.0.0-20171116090243-287cf08546ab/go.mod h1:JAlM8MvJe8wmxCU4Bli9HhUf9+ttbYbLASfIpnQbh74= gopkg.in/yaml.v2 v2.0.0-20171116090243-287cf08546ab/go.mod h1:JAlM8MvJe8wmxCU4Bli9HhUf9+ttbYbLASfIpnQbh74=
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=

View file

@ -55,7 +55,7 @@ type downloadRequest struct {
Logger *log.Entry Logger *log.Entry
} }
// Download implements /download amd /thumbnail // Download implements GET /download and GET /thumbnail
// Files from this server (i.e. origin == cfg.ServerName) are served directly // Files from this server (i.e. origin == cfg.ServerName) are served directly
// Files from remote servers (i.e. origin != cfg.ServerName) are cached locally. // Files from remote servers (i.e. origin != cfg.ServerName) are cached locally.
// If they are present in the cache, they are served directly. // If they are present in the cache, they are served directly.
@ -107,14 +107,6 @@ func Download(
} }
// request validation // request validation
if req.Method != http.MethodGet {
dReq.jsonErrorResponse(w, util.JSONResponse{
Code: http.StatusMethodNotAllowed,
JSON: jsonerror.Unknown("request method must be GET"),
})
return
}
if resErr := dReq.Validate(); resErr != nil { if resErr := dReq.Validate(); resErr != nil {
dReq.jsonErrorResponse(w, *resErr) dReq.jsonErrorResponse(w, *resErr)
return return

View file

@ -48,7 +48,7 @@ type uploadResponse struct {
ContentURI string `json:"content_uri"` ContentURI string `json:"content_uri"`
} }
// Upload implements /upload // Upload implements POST /upload
// This endpoint involves uploading potentially significant amounts of data to the homeserver. // This endpoint involves uploading potentially significant amounts of data to the homeserver.
// This implementation supports a configurable maximum file size limit in bytes. If a user tries to upload more than this, they will receive an error that their upload is too large. // This implementation supports a configurable maximum file size limit in bytes. If a user tries to upload more than this, they will receive an error that their upload is too large.
// Uploaded files are processed piece-wise to avoid DoS attacks which would starve the server of memory. // Uploaded files are processed piece-wise to avoid DoS attacks which would starve the server of memory.
@ -75,13 +75,6 @@ func Upload(req *http.Request, cfg *config.Dendrite, db *storage.Database, activ
// all the metadata about the media being uploaded. // all the metadata about the media being uploaded.
// Returns either an uploadRequest or an error formatted as a util.JSONResponse // Returns either an uploadRequest or an error formatted as a util.JSONResponse
func parseAndValidateRequest(req *http.Request, cfg *config.Dendrite) (*uploadRequest, *util.JSONResponse) { func parseAndValidateRequest(req *http.Request, cfg *config.Dendrite) (*uploadRequest, *util.JSONResponse) {
if req.Method != http.MethodPost {
return nil, &util.JSONResponse{
Code: http.StatusMethodNotAllowed,
JSON: jsonerror.Unknown("HTTP request method must be POST."),
}
}
r := &uploadRequest{ r := &uploadRequest{
MediaMetadata: &types.MediaMetadata{ MediaMetadata: &types.MediaMetadata{
Origin: cfg.Matrix.ServerName, Origin: cfg.Matrix.ServerName,

View file

@ -19,6 +19,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/publicroomsapi/storage" "github.com/matrix-org/dendrite/publicroomsapi/storage"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -39,7 +40,7 @@ func GetVisibility(
var v roomVisibility var v roomVisibility
if isPublic { if isPublic {
v.Visibility = "public" v.Visibility = gomatrixserverlib.Public
} else { } else {
v.Visibility = "private" v.Visibility = "private"
} }
@ -61,7 +62,7 @@ func SetVisibility(
return *reqErr return *reqErr
} }
isPublic := v.Visibility == "public" isPublic := v.Visibility == gomatrixserverlib.Public
if err := publicRoomsDatabase.SetRoomVisibility(req.Context(), isPublic, roomID); err != nil { if err := publicRoomsDatabase.SetRoomVisibility(req.Context(), isPublic, roomID); err != nil {
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
} }

View file

@ -42,8 +42,8 @@ type publicRoomRes struct {
Estimate int64 `json:"total_room_count_estimate,omitempty"` Estimate int64 `json:"total_room_count_estimate,omitempty"`
} }
// GetPublicRooms implements GET /publicRooms // GetPostPublicRooms implements GET and POST /publicRooms
func GetPublicRooms( func GetPostPublicRooms(
req *http.Request, publicRoomDatabase *storage.PublicRoomsServerDatabase, req *http.Request, publicRoomDatabase *storage.PublicRoomsServerDatabase,
) util.JSONResponse { ) util.JSONResponse {
var limit int16 var limit int16
@ -89,6 +89,7 @@ func GetPublicRooms(
// fillPublicRoomsReq fills the Limit, Since and Filter attributes of a GET or POST request // fillPublicRoomsReq fills the Limit, Since and Filter attributes of a GET or POST request
// on /publicRooms by parsing the incoming HTTP request // on /publicRooms by parsing the incoming HTTP request
// Filter is only filled for POST requests
func fillPublicRoomsReq(httpReq *http.Request, request *publicRoomReq) *util.JSONResponse { func fillPublicRoomsReq(httpReq *http.Request, request *publicRoomReq) *util.JSONResponse {
if httpReq.Method == http.MethodGet { if httpReq.Method == http.MethodGet {
limit, err := strconv.Atoi(httpReq.FormValue("limit")) limit, err := strconv.Atoi(httpReq.FormValue("limit"))

View file

@ -64,7 +64,7 @@ func Setup(apiMux *mux.Router, deviceDB *devices.Database, publicRoomsDB *storag
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/publicRooms", r0mux.Handle("/publicRooms",
common.MakeExternalAPI("public_rooms", func(req *http.Request) util.JSONResponse { common.MakeExternalAPI("public_rooms", func(req *http.Request) util.JSONResponse {
return directory.GetPublicRooms(req, publicRoomsDB) return directory.GetPostPublicRooms(req, publicRoomsDB)
}), }),
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
} }

View file

@ -185,7 +185,7 @@ func (d *PublicRoomsServerDatabase) updateNumJoinedUsers(
return err return err
} }
if membership != "join" { if membership != gomatrixserverlib.Join {
return nil return nil
} }

View file

@ -33,13 +33,16 @@ import (
type RoomserverAliasAPIDatabase interface { type RoomserverAliasAPIDatabase interface {
// Save a given room alias with the room ID it refers to. // Save a given room alias with the room ID it refers to.
// Returns an error if there was a problem talking to the database. // Returns an error if there was a problem talking to the database.
SetRoomAlias(ctx context.Context, alias string, roomID string) error SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error
// Look up the room ID a given alias refers to. // Look up the room ID a given alias refers to.
// Returns an error if there was a problem talking to the database. // Returns an error if there was a problem talking to the database.
GetRoomIDForAlias(ctx context.Context, alias string) (string, error) GetRoomIDForAlias(ctx context.Context, alias string) (string, error)
// Look up all aliases referring to a given room ID. // Look up all aliases referring to a given room ID.
// Returns an error if there was a problem talking to the database. // Returns an error if there was a problem talking to the database.
GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error) GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error)
// Get the user ID of the creator of an alias.
// Returns an error if there was a problem talking to the database.
GetCreatorIDForAlias(ctx context.Context, alias string) (string, error)
// Remove a given room alias. // Remove a given room alias.
// Returns an error if there was a problem talking to the database. // Returns an error if there was a problem talking to the database.
RemoveRoomAlias(ctx context.Context, alias string) error RemoveRoomAlias(ctx context.Context, alias string) error
@ -73,7 +76,7 @@ func (r *RoomserverAliasAPI) SetRoomAlias(
response.AliasExists = false response.AliasExists = false
// Save the new alias // Save the new alias
if err := r.DB.SetRoomAlias(ctx, request.Alias, request.RoomID); err != nil { if err := r.DB.SetRoomAlias(ctx, request.Alias, request.RoomID, request.UserID); err != nil {
return err return err
} }
@ -133,6 +136,22 @@ func (r *RoomserverAliasAPI) GetAliasesForRoomID(
return nil return nil
} }
// GetCreatorIDForAlias implements alias.RoomserverAliasAPI
func (r *RoomserverAliasAPI) GetCreatorIDForAlias(
ctx context.Context,
request *roomserverAPI.GetCreatorIDForAliasRequest,
response *roomserverAPI.GetCreatorIDForAliasResponse,
) error {
// Look up the aliases in the database for the given RoomID
creatorID, err := r.DB.GetCreatorIDForAlias(ctx, request.Alias)
if err != nil {
return err
}
response.UserID = creatorID
return nil
}
// RemoveRoomAlias implements alias.RoomserverAliasAPI // RemoveRoomAlias implements alias.RoomserverAliasAPI
func (r *RoomserverAliasAPI) RemoveRoomAlias( func (r *RoomserverAliasAPI) RemoveRoomAlias(
ctx context.Context, ctx context.Context,
@ -277,6 +296,20 @@ func (r *RoomserverAliasAPI) SetupHTTP(servMux *http.ServeMux) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response} return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}), }),
) )
servMux.Handle(
roomserverAPI.RoomserverGetCreatorIDForAliasPath,
common.MakeInternalAPI("GetCreatorIDForAlias", func(req *http.Request) util.JSONResponse {
var request roomserverAPI.GetCreatorIDForAliasRequest
var response roomserverAPI.GetCreatorIDForAliasResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.GetCreatorIDForAlias(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
servMux.Handle( servMux.Handle(
roomserverAPI.RoomserverGetAliasesForRoomIDPath, roomserverAPI.RoomserverGetAliasesForRoomIDPath,
common.MakeInternalAPI("getAliasesForRoomID", func(req *http.Request) util.JSONResponse { common.MakeInternalAPI("getAliasesForRoomID", func(req *http.Request) util.JSONResponse {

View file

@ -30,7 +30,7 @@ type MockRoomserverAliasAPIDatabase struct {
} }
// These methods can be essentially noop // These methods can be essentially noop
func (db MockRoomserverAliasAPIDatabase) SetRoomAlias(ctx context.Context, alias string, roomID string) error { func (db MockRoomserverAliasAPIDatabase) SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error {
return nil return nil
} }
@ -43,6 +43,12 @@ func (db MockRoomserverAliasAPIDatabase) RemoveRoomAlias(ctx context.Context, al
return nil return nil
} }
func (db *MockRoomserverAliasAPIDatabase) GetCreatorIDForAlias(
ctx context.Context, alias string,
) (string, error) {
return "", nil
}
// This method needs to change depending on test case // This method needs to change depending on test case
func (db *MockRoomserverAliasAPIDatabase) GetRoomIDForAlias( func (db *MockRoomserverAliasAPIDatabase) GetRoomIDForAlias(
ctx context.Context, ctx context.Context,

View file

@ -62,6 +62,18 @@ type GetAliasesForRoomIDResponse struct {
Aliases []string `json:"aliases"` Aliases []string `json:"aliases"`
} }
// GetCreatorIDForAliasRequest is a request to GetCreatorIDForAlias
type GetCreatorIDForAliasRequest struct {
// The alias we want to find the creator of
Alias string `json:"alias"`
}
// GetCreatorIDForAliasResponse is a response to GetCreatorIDForAlias
type GetCreatorIDForAliasResponse struct {
// The user ID of the alias creator
UserID string `json:"user_id"`
}
// RemoveRoomAliasRequest is a request to RemoveRoomAlias // RemoveRoomAliasRequest is a request to RemoveRoomAlias
type RemoveRoomAliasRequest struct { type RemoveRoomAliasRequest struct {
// ID of the user removing the alias // ID of the user removing the alias
@ -96,6 +108,13 @@ type RoomserverAliasAPI interface {
response *GetAliasesForRoomIDResponse, response *GetAliasesForRoomIDResponse,
) error ) error
// Get the user ID of the creator of an alias
GetCreatorIDForAlias(
ctx context.Context,
req *GetCreatorIDForAliasRequest,
response *GetCreatorIDForAliasResponse,
) error
// Remove a room alias // Remove a room alias
RemoveRoomAlias( RemoveRoomAlias(
ctx context.Context, ctx context.Context,
@ -113,6 +132,9 @@ const RoomserverGetRoomIDForAliasPath = "/api/roomserver/GetRoomIDForAlias"
// RoomserverGetAliasesForRoomIDPath is the HTTP path for the GetAliasesForRoomID API. // RoomserverGetAliasesForRoomIDPath is the HTTP path for the GetAliasesForRoomID API.
const RoomserverGetAliasesForRoomIDPath = "/api/roomserver/GetAliasesForRoomID" const RoomserverGetAliasesForRoomIDPath = "/api/roomserver/GetAliasesForRoomID"
// RoomserverGetCreatorIDForAliasPath is the HTTP path for the GetCreatorIDForAlias API.
const RoomserverGetCreatorIDForAliasPath = "/api/roomserver/GetCreatorIDForAlias"
// RoomserverRemoveRoomAliasPath is the HTTP path for the RemoveRoomAlias API. // RoomserverRemoveRoomAliasPath is the HTTP path for the RemoveRoomAlias API.
const RoomserverRemoveRoomAliasPath = "/api/roomserver/removeRoomAlias" const RoomserverRemoveRoomAliasPath = "/api/roomserver/removeRoomAlias"
@ -169,6 +191,19 @@ func (h *httpRoomserverAliasAPI) GetAliasesForRoomID(
return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
} }
// GetCreatorIDForAlias implements RoomserverAliasAPI
func (h *httpRoomserverAliasAPI) GetCreatorIDForAlias(
ctx context.Context,
request *GetCreatorIDForAliasRequest,
response *GetCreatorIDForAliasResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "GetCreatorIDForAlias")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverGetCreatorIDForAliasPath
return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
}
// RemoveRoomAlias implements RoomserverAliasAPI // RemoveRoomAlias implements RoomserverAliasAPI
func (h *httpRoomserverAliasAPI) RemoveRoomAlias( func (h *httpRoomserverAliasAPI) RemoveRoomAlias(
ctx context.Context, ctx context.Context,

View file

@ -75,9 +75,9 @@ type InputRoomEvent struct {
} }
// TransactionID contains the transaction ID sent by a client when sending an // TransactionID contains the transaction ID sent by a client when sending an
// event, along with the ID of that device. // event, along with the ID of the client session.
type TransactionID struct { type TransactionID struct {
DeviceID string `json:"device_id"` SessionID int64 `json:"session_id"`
TransactionID string `json:"id"` TransactionID string `json:"id"`
} }

View file

@ -23,7 +23,7 @@ func IsServerAllowed(
) bool { ) bool {
for _, ev := range authEvents { for _, ev := range authEvents {
membership, err := ev.Membership() membership, err := ev.Membership()
if err != nil || membership != "join" { if err != nil || membership != gomatrixserverlib.Join {
continue continue
} }

View file

@ -32,7 +32,7 @@ type RoomEventDatabase interface {
StoreEvent( StoreEvent(
ctx context.Context, ctx context.Context,
event gomatrixserverlib.Event, event gomatrixserverlib.Event,
txnAndDeviceID *api.TransactionID, txnAndSessionID *api.TransactionID,
authEventNIDs []types.EventNID, authEventNIDs []types.EventNID,
) (types.RoomNID, types.StateAtEvent, error) ) (types.RoomNID, types.StateAtEvent, error)
// Look up the state entries for a list of string event IDs // Look up the state entries for a list of string event IDs
@ -67,7 +67,7 @@ type RoomEventDatabase interface {
// Returns an empty string if no such event exists. // Returns an empty string if no such event exists.
GetTransactionEventID( GetTransactionEventID(
ctx context.Context, transactionID string, ctx context.Context, transactionID string,
deviceID string, userID string, sessionID int64, userID string,
) (string, error) ) (string, error)
} }
@ -100,7 +100,7 @@ func processRoomEvent(
if input.TransactionID != nil { if input.TransactionID != nil {
tdID := input.TransactionID tdID := input.TransactionID
eventID, err = db.GetTransactionEventID( eventID, err = db.GetTransactionEventID(
ctx, tdID.TransactionID, tdID.DeviceID, input.Event.Sender(), ctx, tdID.TransactionID, tdID.SessionID, input.Event.Sender(),
) )
// On error OR event with the transaction already processed/processesing // On error OR event with the transaction already processed/processesing
if err != nil || eventID != "" { if err != nil || eventID != "" {

View file

@ -23,13 +23,6 @@ import (
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
// Membership values
// TODO: Factor these out somewhere sensible?
const join = "join"
const leave = "leave"
const invite = "invite"
const ban = "ban"
// updateMembership updates the current membership and the invites for each // updateMembership updates the current membership and the invites for each
// user affected by a change in the current state of the room. // user affected by a change in the current state of the room.
// Returns a list of output events to write to the kafka log to inform the // Returns a list of output events to write to the kafka log to inform the
@ -91,8 +84,8 @@ func updateMembership(
) ([]api.OutputEvent, error) { ) ([]api.OutputEvent, error) {
var err error var err error
// Default the membership to Leave if no event was added or removed. // Default the membership to Leave if no event was added or removed.
oldMembership := leave oldMembership := gomatrixserverlib.Leave
newMembership := leave newMembership := gomatrixserverlib.Leave
if remove != nil { if remove != nil {
oldMembership, err = remove.Membership() oldMembership, err = remove.Membership()
@ -106,7 +99,7 @@ func updateMembership(
return nil, err return nil, err
} }
} }
if oldMembership == newMembership && newMembership != join { if oldMembership == newMembership && newMembership != gomatrixserverlib.Join {
// If the membership is the same then nothing changed and we can return // If the membership is the same then nothing changed and we can return
// immediately, unless it's a Join update (e.g. profile update). // immediately, unless it's a Join update (e.g. profile update).
return updates, nil return updates, nil
@ -118,11 +111,11 @@ func updateMembership(
} }
switch newMembership { switch newMembership {
case invite: case gomatrixserverlib.Invite:
return updateToInviteMembership(mu, add, updates) return updateToInviteMembership(mu, add, updates)
case join: case gomatrixserverlib.Join:
return updateToJoinMembership(mu, add, updates) return updateToJoinMembership(mu, add, updates)
case leave, ban: case gomatrixserverlib.Leave, gomatrixserverlib.Ban:
return updateToLeaveMembership(mu, add, newMembership, updates) return updateToLeaveMembership(mu, add, newMembership, updates)
default: default:
panic(fmt.Errorf( panic(fmt.Errorf(
@ -183,7 +176,7 @@ func updateToJoinMembership(
for _, eventID := range retired { for _, eventID := range retired {
orie := api.OutputRetireInviteEvent{ orie := api.OutputRetireInviteEvent{
EventID: eventID, EventID: eventID,
Membership: join, Membership: gomatrixserverlib.Join,
RetiredByEventID: add.EventID(), RetiredByEventID: add.EventID(),
TargetUserID: *add.StateKey(), TargetUserID: *add.StateKey(),
} }

View file

@ -359,7 +359,7 @@ func (r *RoomserverQueryAPI) getMembershipsBeforeEventNID(
return nil, err return nil, err
} }
if membership == "join" { if membership == gomatrixserverlib.Join {
events = append(events, event) events = append(events, event)
} }
} }

View file

@ -25,14 +25,16 @@ CREATE TABLE IF NOT EXISTS roomserver_room_aliases (
-- Alias of the room -- Alias of the room
alias TEXT NOT NULL PRIMARY KEY, alias TEXT NOT NULL PRIMARY KEY,
-- Room ID the alias refers to -- Room ID the alias refers to
room_id TEXT NOT NULL room_id TEXT NOT NULL,
-- User ID of the creator of this alias
creator_id TEXT NOT NULL
); );
CREATE INDEX IF NOT EXISTS roomserver_room_id_idx ON roomserver_room_aliases(room_id); CREATE INDEX IF NOT EXISTS roomserver_room_id_idx ON roomserver_room_aliases(room_id);
` `
const insertRoomAliasSQL = "" + const insertRoomAliasSQL = "" +
"INSERT INTO roomserver_room_aliases (alias, room_id) VALUES ($1, $2)" "INSERT INTO roomserver_room_aliases (alias, room_id, creator_id) VALUES ($1, $2, $3)"
const selectRoomIDFromAliasSQL = "" + const selectRoomIDFromAliasSQL = "" +
"SELECT room_id FROM roomserver_room_aliases WHERE alias = $1" "SELECT room_id FROM roomserver_room_aliases WHERE alias = $1"
@ -40,6 +42,9 @@ const selectRoomIDFromAliasSQL = "" +
const selectAliasesFromRoomIDSQL = "" + const selectAliasesFromRoomIDSQL = "" +
"SELECT alias FROM roomserver_room_aliases WHERE room_id = $1" "SELECT alias FROM roomserver_room_aliases WHERE room_id = $1"
const selectCreatorIDFromAliasSQL = "" +
"SELECT creator_id FROM roomserver_room_aliases WHERE alias = $1"
const deleteRoomAliasSQL = "" + const deleteRoomAliasSQL = "" +
"DELETE FROM roomserver_room_aliases WHERE alias = $1" "DELETE FROM roomserver_room_aliases WHERE alias = $1"
@ -47,6 +52,7 @@ type roomAliasesStatements struct {
insertRoomAliasStmt *sql.Stmt insertRoomAliasStmt *sql.Stmt
selectRoomIDFromAliasStmt *sql.Stmt selectRoomIDFromAliasStmt *sql.Stmt
selectAliasesFromRoomIDStmt *sql.Stmt selectAliasesFromRoomIDStmt *sql.Stmt
selectCreatorIDFromAliasStmt *sql.Stmt
deleteRoomAliasStmt *sql.Stmt deleteRoomAliasStmt *sql.Stmt
} }
@ -59,14 +65,15 @@ func (s *roomAliasesStatements) prepare(db *sql.DB) (err error) {
{&s.insertRoomAliasStmt, insertRoomAliasSQL}, {&s.insertRoomAliasStmt, insertRoomAliasSQL},
{&s.selectRoomIDFromAliasStmt, selectRoomIDFromAliasSQL}, {&s.selectRoomIDFromAliasStmt, selectRoomIDFromAliasSQL},
{&s.selectAliasesFromRoomIDStmt, selectAliasesFromRoomIDSQL}, {&s.selectAliasesFromRoomIDStmt, selectAliasesFromRoomIDSQL},
{&s.selectCreatorIDFromAliasStmt, selectCreatorIDFromAliasSQL},
{&s.deleteRoomAliasStmt, deleteRoomAliasSQL}, {&s.deleteRoomAliasStmt, deleteRoomAliasSQL},
}.prepare(db) }.prepare(db)
} }
func (s *roomAliasesStatements) insertRoomAlias( func (s *roomAliasesStatements) insertRoomAlias(
ctx context.Context, alias string, roomID string, ctx context.Context, alias string, roomID string, creatorUserID string,
) (err error) { ) (err error) {
_, err = s.insertRoomAliasStmt.ExecContext(ctx, alias, roomID) _, err = s.insertRoomAliasStmt.ExecContext(ctx, alias, roomID, creatorUserID)
return return
} }
@ -101,6 +108,16 @@ func (s *roomAliasesStatements) selectAliasesFromRoomID(
return return
} }
func (s *roomAliasesStatements) selectCreatorIDFromAlias(
ctx context.Context, alias string,
) (creatorID string, err error) {
err = s.selectCreatorIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&creatorID)
if err == sql.ErrNoRows {
return "", nil
}
return
}
func (s *roomAliasesStatements) deleteRoomAlias( func (s *roomAliasesStatements) deleteRoomAlias(
ctx context.Context, alias string, ctx context.Context, alias string,
) (err error) { ) (err error) {

View file

@ -47,7 +47,7 @@ func Open(dataSourceName string) (*Database, error) {
// StoreEvent implements input.EventDatabase // StoreEvent implements input.EventDatabase
func (d *Database) StoreEvent( func (d *Database) StoreEvent(
ctx context.Context, event gomatrixserverlib.Event, ctx context.Context, event gomatrixserverlib.Event,
txnAndDeviceID *api.TransactionID, authEventNIDs []types.EventNID, txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID,
) (types.RoomNID, types.StateAtEvent, error) { ) (types.RoomNID, types.StateAtEvent, error) {
var ( var (
roomNID types.RoomNID roomNID types.RoomNID
@ -58,10 +58,10 @@ func (d *Database) StoreEvent(
err error err error
) )
if txnAndDeviceID != nil { if txnAndSessionID != nil {
if err = d.statements.insertTransaction( if err = d.statements.insertTransaction(
ctx, txnAndDeviceID.TransactionID, ctx, txnAndSessionID.TransactionID,
txnAndDeviceID.DeviceID, event.Sender(), event.EventID(), txnAndSessionID.SessionID, event.Sender(), event.EventID(),
); err != nil { ); err != nil {
return 0, types.StateAtEvent{}, err return 0, types.StateAtEvent{}, err
} }
@ -322,9 +322,9 @@ func (d *Database) GetLatestEventsForUpdate(
// GetTransactionEventID implements input.EventDatabase // GetTransactionEventID implements input.EventDatabase
func (d *Database) GetTransactionEventID( func (d *Database) GetTransactionEventID(
ctx context.Context, transactionID string, ctx context.Context, transactionID string,
deviceID string, userID string, sessionID int64, userID string,
) (string, error) { ) (string, error) {
eventID, err := d.statements.selectTransactionEventID(ctx, transactionID, deviceID, userID) eventID, err := d.statements.selectTransactionEventID(ctx, transactionID, sessionID, userID)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return "", nil return "", nil
} }
@ -441,8 +441,8 @@ func (d *Database) GetInvitesForUser(
} }
// SetRoomAlias implements alias.RoomserverAliasAPIDB // SetRoomAlias implements alias.RoomserverAliasAPIDB
func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string) error { func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error {
return d.statements.insertRoomAlias(ctx, alias, roomID) return d.statements.insertRoomAlias(ctx, alias, roomID, creatorUserID)
} }
// GetRoomIDForAlias implements alias.RoomserverAliasAPIDB // GetRoomIDForAlias implements alias.RoomserverAliasAPIDB
@ -455,6 +455,13 @@ func (d *Database) GetAliasesForRoomID(ctx context.Context, roomID string) ([]st
return d.statements.selectAliasesFromRoomID(ctx, roomID) return d.statements.selectAliasesFromRoomID(ctx, roomID)
} }
// GetCreatorIDForAlias implements alias.RoomserverAliasAPIDB
func (d *Database) GetCreatorIDForAlias(
ctx context.Context, alias string,
) (string, error) {
return d.statements.selectCreatorIDFromAlias(ctx, alias)
}
// RemoveRoomAlias implements alias.RoomserverAliasAPIDB // RemoveRoomAlias implements alias.RoomserverAliasAPIDB
func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error { func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error {
return d.statements.deleteRoomAlias(ctx, alias) return d.statements.deleteRoomAlias(ctx, alias)

View file

@ -23,8 +23,8 @@ const transactionsSchema = `
CREATE TABLE IF NOT EXISTS roomserver_transactions ( CREATE TABLE IF NOT EXISTS roomserver_transactions (
-- The transaction ID of the event. -- The transaction ID of the event.
transaction_id TEXT NOT NULL, transaction_id TEXT NOT NULL,
-- The device ID of the originating transaction. -- The session ID of the originating transaction.
device_id TEXT NOT NULL, session_id BIGINT NOT NULL,
-- User ID of the sender who authored the event -- User ID of the sender who authored the event
user_id TEXT NOT NULL, user_id TEXT NOT NULL,
-- Event ID corresponding to the transaction -- Event ID corresponding to the transaction
@ -32,16 +32,16 @@ CREATE TABLE IF NOT EXISTS roomserver_transactions (
event_id TEXT NOT NULL, event_id TEXT NOT NULL,
-- A transaction ID is unique for a user and device -- A transaction ID is unique for a user and device
-- This automatically creates an index. -- This automatically creates an index.
PRIMARY KEY (transaction_id, device_id, user_id) PRIMARY KEY (transaction_id, session_id, user_id)
); );
` `
const insertTransactionSQL = "" + const insertTransactionSQL = "" +
"INSERT INTO roomserver_transactions (transaction_id, device_id, user_id, event_id)" + "INSERT INTO roomserver_transactions (transaction_id, session_id, user_id, event_id)" +
" VALUES ($1, $2, $3, $4)" " VALUES ($1, $2, $3, $4)"
const selectTransactionEventIDSQL = "" + const selectTransactionEventIDSQL = "" +
"SELECT event_id FROM roomserver_transactions" + "SELECT event_id FROM roomserver_transactions" +
" WHERE transaction_id = $1 AND device_id = $2 AND user_id = $3" " WHERE transaction_id = $1 AND session_id = $2 AND user_id = $3"
type transactionStatements struct { type transactionStatements struct {
insertTransactionStmt *sql.Stmt insertTransactionStmt *sql.Stmt
@ -63,12 +63,12 @@ func (s *transactionStatements) prepare(db *sql.DB) (err error) {
func (s *transactionStatements) insertTransaction( func (s *transactionStatements) insertTransaction(
ctx context.Context, ctx context.Context,
transactionID string, transactionID string,
deviceID string, sessionID int64,
userID string, userID string,
eventID string, eventID string,
) (err error) { ) (err error) {
_, err = s.insertTransactionStmt.ExecContext( _, err = s.insertTransactionStmt.ExecContext(
ctx, transactionID, deviceID, userID, eventID, ctx, transactionID, sessionID, userID, eventID,
) )
return return
} }
@ -76,11 +76,11 @@ func (s *transactionStatements) insertTransaction(
func (s *transactionStatements) selectTransactionEventID( func (s *transactionStatements) selectTransactionEventID(
ctx context.Context, ctx context.Context,
transactionID string, transactionID string,
deviceID string, sessionID int64,
userID string, userID string,
) (eventID string, err error) { ) (eventID string, err error) {
err = s.selectTransactionEventIDStmt.QueryRowContext( err = s.selectTransactionEventIDStmt.QueryRowContext(
ctx, transactionID, deviceID, userID, ctx, transactionID, sessionID, userID,
).Scan(&eventID) ).Scan(&eventID)
return return
} }

View file

@ -22,7 +22,15 @@ then args="--fast"
fi fi
echo "Installing golangci-lint..." echo "Installing golangci-lint..."
# Make a backup of go.{mod,sum} first
# TODO: Once go 1.13 is out, use go get's -mod=readonly option
# https://github.com/golang/go/issues/30667
cp go.mod go.mod.bak && cp go.sum go.sum.bak
go get github.com/golangci/golangci-lint/cmd/golangci-lint go get github.com/golangci/golangci-lint/cmd/golangci-lint
echo "Looking for lint..." echo "Looking for lint..."
golangci-lint run $args golangci-lint run $args
# Restore go.{mod,sum}
mv go.mod.bak go.mod && mv go.sum.bak go.sum

View file

@ -22,6 +22,7 @@ import (
"github.com/matrix-org/dendrite/common/config" "github.com/matrix-org/dendrite/common/config"
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/sync" "github.com/matrix-org/dendrite/syncapi/sync"
"github.com/matrix-org/dendrite/syncapi/types"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
sarama "gopkg.in/Shopify/sarama.v1" sarama "gopkg.in/Shopify/sarama.v1"
) )
@ -29,7 +30,7 @@ import (
// OutputClientDataConsumer consumes events that originated in the client API server. // OutputClientDataConsumer consumes events that originated in the client API server.
type OutputClientDataConsumer struct { type OutputClientDataConsumer struct {
clientAPIConsumer *common.ContinualConsumer clientAPIConsumer *common.ContinualConsumer
db *storage.SyncServerDatabase db *storage.SyncServerDatasource
notifier *sync.Notifier notifier *sync.Notifier
} }
@ -38,7 +39,7 @@ func NewOutputClientDataConsumer(
cfg *config.Dendrite, cfg *config.Dendrite,
kafkaConsumer sarama.Consumer, kafkaConsumer sarama.Consumer,
n *sync.Notifier, n *sync.Notifier,
store *storage.SyncServerDatabase, store *storage.SyncServerDatasource,
) *OutputClientDataConsumer { ) *OutputClientDataConsumer {
consumer := common.ContinualConsumer{ consumer := common.ContinualConsumer{
@ -78,7 +79,7 @@ func (s *OutputClientDataConsumer) onMessage(msg *sarama.ConsumerMessage) error
"room_id": output.RoomID, "room_id": output.RoomID,
}).Info("received data from client API server") }).Info("received data from client API server")
syncStreamPos, err := s.db.UpsertAccountData( pduPos, err := s.db.UpsertAccountData(
context.TODO(), string(msg.Key), output.RoomID, output.Type, context.TODO(), string(msg.Key), output.RoomID, output.Type,
) )
if err != nil { if err != nil {
@ -89,7 +90,7 @@ func (s *OutputClientDataConsumer) onMessage(msg *sarama.ConsumerMessage) error
}).Panicf("could not save account data") }).Panicf("could not save account data")
} }
s.notifier.OnNewEvent(nil, string(msg.Key), syncStreamPos) s.notifier.OnNewEvent(nil, "", []string{string(msg.Key)}, types.SyncPosition{PDUPosition: pduPos})
return nil return nil
} }

View file

@ -33,7 +33,7 @@ import (
// OutputRoomEventConsumer consumes events that originated in the room server. // OutputRoomEventConsumer consumes events that originated in the room server.
type OutputRoomEventConsumer struct { type OutputRoomEventConsumer struct {
roomServerConsumer *common.ContinualConsumer roomServerConsumer *common.ContinualConsumer
db *storage.SyncServerDatabase db *storage.SyncServerDatasource
notifier *sync.Notifier notifier *sync.Notifier
query api.RoomserverQueryAPI query api.RoomserverQueryAPI
} }
@ -43,7 +43,7 @@ func NewOutputRoomEventConsumer(
cfg *config.Dendrite, cfg *config.Dendrite,
kafkaConsumer sarama.Consumer, kafkaConsumer sarama.Consumer,
n *sync.Notifier, n *sync.Notifier,
store *storage.SyncServerDatabase, store *storage.SyncServerDatasource,
queryAPI api.RoomserverQueryAPI, queryAPI api.RoomserverQueryAPI,
) *OutputRoomEventConsumer { ) *OutputRoomEventConsumer {
@ -126,7 +126,7 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent(
} }
} }
syncStreamPos, err := s.db.WriteEvent( pduPos, err := s.db.WriteEvent(
ctx, ctx,
&ev, &ev,
addsStateEvents, addsStateEvents,
@ -144,7 +144,7 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent(
}).Panicf("roomserver output log: write event failure") }).Panicf("roomserver output log: write event failure")
return nil return nil
} }
s.notifier.OnNewEvent(&ev, "", types.StreamPosition(syncStreamPos)) s.notifier.OnNewEvent(&ev, "", nil, types.SyncPosition{PDUPosition: pduPos})
return nil return nil
} }
@ -152,7 +152,7 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent(
func (s *OutputRoomEventConsumer) onNewInviteEvent( func (s *OutputRoomEventConsumer) onNewInviteEvent(
ctx context.Context, msg api.OutputNewInviteEvent, ctx context.Context, msg api.OutputNewInviteEvent,
) error { ) error {
syncStreamPos, err := s.db.AddInviteEvent(ctx, msg.Event) pduPos, err := s.db.AddInviteEvent(ctx, msg.Event)
if err != nil { if err != nil {
// panic rather than continue with an inconsistent database // panic rather than continue with an inconsistent database
log.WithFields(log.Fields{ log.WithFields(log.Fields{
@ -161,7 +161,7 @@ func (s *OutputRoomEventConsumer) onNewInviteEvent(
}).Panicf("roomserver output log: write invite failure") }).Panicf("roomserver output log: write invite failure")
return nil return nil
} }
s.notifier.OnNewEvent(&msg.Event, "", syncStreamPos) s.notifier.OnNewEvent(&msg.Event, "", nil, types.SyncPosition{PDUPosition: pduPos})
return nil return nil
} }

View file

@ -0,0 +1,96 @@
// Copyright 2019 Alex Chen
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package consumers
import (
"encoding/json"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/common/config"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/sync"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/dendrite/typingserver/api"
log "github.com/sirupsen/logrus"
sarama "gopkg.in/Shopify/sarama.v1"
)
// OutputTypingEventConsumer consumes events that originated in the typing server.
type OutputTypingEventConsumer struct {
typingConsumer *common.ContinualConsumer
db *storage.SyncServerDatasource
notifier *sync.Notifier
}
// NewOutputTypingEventConsumer creates a new OutputTypingEventConsumer.
// Call Start() to begin consuming from the typing server.
func NewOutputTypingEventConsumer(
cfg *config.Dendrite,
kafkaConsumer sarama.Consumer,
n *sync.Notifier,
store *storage.SyncServerDatasource,
) *OutputTypingEventConsumer {
consumer := common.ContinualConsumer{
Topic: string(cfg.Kafka.Topics.OutputTypingEvent),
Consumer: kafkaConsumer,
PartitionStore: store,
}
s := &OutputTypingEventConsumer{
typingConsumer: &consumer,
db: store,
notifier: n,
}
consumer.ProcessMessage = s.onMessage
return s
}
// Start consuming from typing api
func (s *OutputTypingEventConsumer) Start() error {
s.db.SetTypingTimeoutCallback(func(userID, roomID string, latestSyncPosition int64) {
s.notifier.OnNewEvent(nil, roomID, nil, types.SyncPosition{TypingPosition: latestSyncPosition})
})
return s.typingConsumer.Start()
}
func (s *OutputTypingEventConsumer) onMessage(msg *sarama.ConsumerMessage) error {
var output api.OutputTypingEvent
if err := json.Unmarshal(msg.Value, &output); err != nil {
// If the message was invalid, log it and move on to the next message in the stream
log.WithError(err).Errorf("typing server output log: message parse failure")
return nil
}
log.WithFields(log.Fields{
"room_id": output.Event.RoomID,
"user_id": output.Event.UserID,
"typing": output.Event.Typing,
}).Debug("received data from typing server")
var typingPos int64
typingEvent := output.Event
if typingEvent.Typing {
typingPos = s.db.AddTypingUser(typingEvent.UserID, typingEvent.RoomID, output.ExpireTime)
} else {
typingPos = s.db.RemoveTypingUser(typingEvent.UserID, typingEvent.RoomID)
}
s.notifier.OnNewEvent(nil, output.Event.RoomID, nil, types.SyncPosition{TypingPosition: typingPos})
return nil
}

View file

@ -34,7 +34,7 @@ const pathPrefixR0 = "/_matrix/client/r0"
// Due to Setup being used to call many other functions, a gocyclo nolint is // Due to Setup being used to call many other functions, a gocyclo nolint is
// applied: // applied:
// nolint: gocyclo // nolint: gocyclo
func Setup(apiMux *mux.Router, srp *sync.RequestPool, syncDB *storage.SyncServerDatabase, deviceDB *devices.Database) { func Setup(apiMux *mux.Router, srp *sync.RequestPool, syncDB *storage.SyncServerDatasource, deviceDB *devices.Database) {
r0mux := apiMux.PathPrefix(pathPrefixR0).Subrouter() r0mux := apiMux.PathPrefix(pathPrefixR0).Subrouter()
authData := auth.Data{ authData := auth.Data{

View file

@ -40,11 +40,14 @@ type stateEventInStateResp struct {
// TODO: Check if the user is in the room. If not, check if the room's history // TODO: Check if the user is in the room. If not, check if the room's history
// is publicly visible. Current behaviour is returning an empty array if the // is publicly visible. Current behaviour is returning an empty array if the
// user cannot see the room's history. // user cannot see the room's history.
func OnIncomingStateRequest(req *http.Request, db *storage.SyncServerDatabase, roomID string) util.JSONResponse { func OnIncomingStateRequest(req *http.Request, db *storage.SyncServerDatasource, roomID string) util.JSONResponse {
// TODO(#287): Auth request and handle the case where the user has left (where // TODO(#287): Auth request and handle the case where the user has left (where
// we should return the state at the poin they left) // we should return the state at the poin they left)
stateEvents, err := db.GetStateEventsForRoom(req.Context(), roomID) stateFilterPart := gomatrixserverlib.DefaultFilterPart()
// TODO: stateFilterPart should not limit the number of state events (or only limits abusive number of events)
stateEvents, err := db.GetStateEventsForRoom(req.Context(), roomID, &stateFilterPart)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
} }
@ -84,7 +87,7 @@ func OnIncomingStateRequest(req *http.Request, db *storage.SyncServerDatabase, r
// /rooms/{roomID}/state/{type}/{statekey} request. It will look in current // /rooms/{roomID}/state/{type}/{statekey} request. It will look in current
// state to see if there is an event with that type and state key, if there // state to see if there is an event with that type and state key, if there
// is then (by default) we return the content, otherwise a 404. // is then (by default) we return the content, otherwise a 404.
func OnIncomingStateTypeRequest(req *http.Request, db *storage.SyncServerDatabase, roomID string, evType, stateKey string) util.JSONResponse { func OnIncomingStateTypeRequest(req *http.Request, db *storage.SyncServerDatasource, roomID string, evType, stateKey string) util.JSONResponse {
// TODO(#287): Auth request and handle the case where the user has left (where // TODO(#287): Auth request and handle the case where the user has left (where
// we should return the state at the poin they left) // we should return the state at the poin they left)

View file

@ -18,9 +18,9 @@ import (
"context" "context"
"database/sql" "database/sql"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/common"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/syncapi/types"
) )
const accountDataSchema = ` const accountDataSchema = `
@ -43,7 +43,7 @@ CREATE TABLE IF NOT EXISTS syncapi_account_data_type (
CONSTRAINT syncapi_account_data_unique UNIQUE (user_id, room_id, type) CONSTRAINT syncapi_account_data_unique UNIQUE (user_id, room_id, type)
); );
CREATE UNIQUE INDEX IF NOT EXISTS syncapi_account_data_id_idx ON syncapi_account_data_type(id); CREATE UNIQUE INDEX IF NOT EXISTS syncapi_account_data_id_idx ON syncapi_account_data_type(id, type);
` `
const insertAccountDataSQL = "" + const insertAccountDataSQL = "" +
@ -55,7 +55,9 @@ const insertAccountDataSQL = "" +
const selectAccountDataInRangeSQL = "" + const selectAccountDataInRangeSQL = "" +
"SELECT room_id, type FROM syncapi_account_data_type" + "SELECT room_id, type FROM syncapi_account_data_type" +
" WHERE user_id = $1 AND id > $2 AND id <= $3" + " WHERE user_id = $1 AND id > $2 AND id <= $3" +
" ORDER BY id ASC" " AND ( $4::text[] IS NULL OR type LIKE ANY($4) )" +
" AND ( $5::text[] IS NULL OR NOT(type LIKE ANY($5)) )" +
" ORDER BY id ASC LIMIT $6"
const selectMaxAccountDataIDSQL = "" + const selectMaxAccountDataIDSQL = "" +
"SELECT MAX(id) FROM syncapi_account_data_type" "SELECT MAX(id) FROM syncapi_account_data_type"
@ -94,7 +96,8 @@ func (s *accountDataStatements) insertAccountData(
func (s *accountDataStatements) selectAccountDataInRange( func (s *accountDataStatements) selectAccountDataInRange(
ctx context.Context, ctx context.Context,
userID string, userID string,
oldPos, newPos types.StreamPosition, oldPos, newPos int64,
accountDataFilterPart *gomatrixserverlib.FilterPart,
) (data map[string][]string, err error) { ) (data map[string][]string, err error) {
data = make(map[string][]string) data = make(map[string][]string)
@ -105,7 +108,11 @@ func (s *accountDataStatements) selectAccountDataInRange(
oldPos-- oldPos--
} }
rows, err := s.selectAccountDataInRangeStmt.QueryContext(ctx, userID, oldPos, newPos) rows, err := s.selectAccountDataInRangeStmt.QueryContext(ctx, userID, oldPos, newPos,
pq.StringArray(filterConvertTypeWildcardToSQL(accountDataFilterPart.Types)),
pq.StringArray(filterConvertTypeWildcardToSQL(accountDataFilterPart.NotTypes)),
accountDataFilterPart.Limit,
)
if err != nil { if err != nil {
return return
} }

View file

@ -17,6 +17,7 @@ package storage
import ( import (
"context" "context"
"database/sql" "database/sql"
"encoding/json"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/common"
@ -32,6 +33,10 @@ CREATE TABLE IF NOT EXISTS syncapi_current_room_state (
event_id TEXT NOT NULL, event_id TEXT NOT NULL,
-- The state event type e.g 'm.room.member' -- The state event type e.g 'm.room.member'
type TEXT NOT NULL, type TEXT NOT NULL,
-- The 'sender' property of the event.
sender TEXT NOT NULL,
-- true if the event content contains a url key
contains_url BOOL NOT NULL,
-- The state_key value for this state event e.g '' -- The state_key value for this state event e.g ''
state_key TEXT NOT NULL, state_key TEXT NOT NULL,
-- The JSON for the event. Stored as TEXT because this should be valid UTF-8. -- The JSON for the event. Stored as TEXT because this should be valid UTF-8.
@ -46,16 +51,16 @@ CREATE TABLE IF NOT EXISTS syncapi_current_room_state (
CONSTRAINT syncapi_room_state_unique UNIQUE (room_id, type, state_key) CONSTRAINT syncapi_room_state_unique UNIQUE (room_id, type, state_key)
); );
-- for event deletion -- for event deletion
CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_id_idx ON syncapi_current_room_state(event_id); CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_id_idx ON syncapi_current_room_state(event_id, room_id, type, sender, contains_url);
-- for querying membership states of users -- for querying membership states of users
CREATE INDEX IF NOT EXISTS syncapi_membership_idx ON syncapi_current_room_state(type, state_key, membership) WHERE membership IS NOT NULL AND membership != 'leave'; CREATE INDEX IF NOT EXISTS syncapi_membership_idx ON syncapi_current_room_state(type, state_key, membership) WHERE membership IS NOT NULL AND membership != 'leave';
` `
const upsertRoomStateSQL = "" + const upsertRoomStateSQL = "" +
"INSERT INTO syncapi_current_room_state (room_id, event_id, type, state_key, event_json, membership, added_at)" + "INSERT INTO syncapi_current_room_state (room_id, event_id, type, sender, contains_url, state_key, event_json, membership, added_at)" +
" VALUES ($1, $2, $3, $4, $5, $6, $7)" + " VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)" +
" ON CONFLICT ON CONSTRAINT syncapi_room_state_unique" + " ON CONFLICT ON CONSTRAINT syncapi_room_state_unique" +
" DO UPDATE SET event_id = $2, event_json = $5, membership = $6, added_at = $7" " DO UPDATE SET event_id = $2, sender=$4, contains_url=$5, event_json = $7, membership = $8, added_at = $9"
const deleteRoomStateByEventIDSQL = "" + const deleteRoomStateByEventIDSQL = "" +
"DELETE FROM syncapi_current_room_state WHERE event_id = $1" "DELETE FROM syncapi_current_room_state WHERE event_id = $1"
@ -64,7 +69,13 @@ const selectRoomIDsWithMembershipSQL = "" +
"SELECT room_id FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1 AND membership = $2" "SELECT room_id FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1 AND membership = $2"
const selectCurrentStateSQL = "" + const selectCurrentStateSQL = "" +
"SELECT event_json FROM syncapi_current_room_state WHERE room_id = $1" "SELECT event_json FROM syncapi_current_room_state WHERE room_id = $1" +
" AND ( $2::text[] IS NULL OR sender = ANY($2) )" +
" AND ( $3::text[] IS NULL OR NOT(sender = ANY($3)) )" +
" AND ( $4::text[] IS NULL OR type LIKE ANY($4) )" +
" AND ( $5::text[] IS NULL OR NOT(type LIKE ANY($5)) )" +
" AND ( $6::bool IS NULL OR contains_url = $6 )" +
" LIMIT $7"
const selectJoinedUsersSQL = "" + const selectJoinedUsersSQL = "" +
"SELECT room_id, state_key FROM syncapi_current_room_state WHERE type = 'm.room.member' AND membership = 'join'" "SELECT room_id, state_key FROM syncapi_current_room_state WHERE type = 'm.room.member' AND membership = 'join'"
@ -166,9 +177,17 @@ func (s *currentRoomStateStatements) selectRoomIDsWithMembership(
// CurrentState returns all the current state events for the given room. // CurrentState returns all the current state events for the given room.
func (s *currentRoomStateStatements) selectCurrentState( func (s *currentRoomStateStatements) selectCurrentState(
ctx context.Context, txn *sql.Tx, roomID string, ctx context.Context, txn *sql.Tx, roomID string,
stateFilterPart *gomatrixserverlib.FilterPart,
) ([]gomatrixserverlib.Event, error) { ) ([]gomatrixserverlib.Event, error) {
stmt := common.TxStmt(txn, s.selectCurrentStateStmt) stmt := common.TxStmt(txn, s.selectCurrentStateStmt)
rows, err := stmt.QueryContext(ctx, roomID) rows, err := stmt.QueryContext(ctx, roomID,
pq.StringArray(stateFilterPart.Senders),
pq.StringArray(stateFilterPart.NotSenders),
pq.StringArray(filterConvertTypeWildcardToSQL(stateFilterPart.Types)),
pq.StringArray(filterConvertTypeWildcardToSQL(stateFilterPart.NotTypes)),
stateFilterPart.ContainsURL,
stateFilterPart.Limit,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -189,12 +208,23 @@ func (s *currentRoomStateStatements) upsertRoomState(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
event gomatrixserverlib.Event, membership *string, addedAt int64, event gomatrixserverlib.Event, membership *string, addedAt int64,
) error { ) error {
// Parse content as JSON and search for an "url" key
containsURL := false
var content map[string]interface{}
if json.Unmarshal(event.Content(), &content) != nil {
// Set containsURL to true if url is present
_, containsURL = content["url"]
}
// upsert state event
stmt := common.TxStmt(txn, s.upsertRoomStateStmt) stmt := common.TxStmt(txn, s.upsertRoomStateStmt)
_, err := stmt.ExecContext( _, err := stmt.ExecContext(
ctx, ctx,
event.RoomID(), event.RoomID(),
event.EventID(), event.EventID(),
event.Type(), event.Type(),
event.Sender(),
containsURL,
*event.StateKey(), *event.StateKey(),
event.JSON(), event.JSON(),
membership, membership,

View file

@ -0,0 +1,36 @@
// Copyright 2017 Thibaut CHARLES
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package storage
import (
"strings"
)
// filterConvertWildcardToSQL converts wildcards as defined in
// https://matrix.org/docs/spec/client_server/r0.3.0.html#post-matrix-client-r0-user-userid-filter
// to SQL wildcards that can be used with LIKE()
func filterConvertTypeWildcardToSQL(values []string) []string {
if values == nil {
// Return nil instead of []string{} so IS NULL can work correctly when
// the return value is passed into SQL queries
return nil
}
ret := make([]string, len(values))
for i := range values {
ret[i] = strings.Replace(values[i], "*", "%", -1)
}
return ret
}

View file

@ -23,7 +23,7 @@ CREATE INDEX IF NOT EXISTS syncapi_invites_target_user_id_idx
-- For deleting old invites -- For deleting old invites
CREATE INDEX IF NOT EXISTS syncapi_invites_event_id_idx CREATE INDEX IF NOT EXISTS syncapi_invites_event_id_idx
ON syncapi_invite_events(target_user_id, id); ON syncapi_invite_events (event_id);
` `
const insertInviteEventSQL = "" + const insertInviteEventSQL = "" +

View file

@ -17,13 +17,13 @@ package storage
import ( import (
"context" "context"
"database/sql" "database/sql"
"encoding/json"
"sort" "sort"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
@ -44,11 +44,17 @@ CREATE TABLE IF NOT EXISTS syncapi_output_room_events (
room_id TEXT NOT NULL, room_id TEXT NOT NULL,
-- The JSON for the event. Stored as TEXT because this should be valid UTF-8. -- The JSON for the event. Stored as TEXT because this should be valid UTF-8.
event_json TEXT NOT NULL, event_json TEXT NOT NULL,
-- The event type e.g 'm.room.member'.
type TEXT NOT NULL,
-- The 'sender' property of the event.
sender TEXT NOT NULL,
-- true if the event content contains a url key.
contains_url BOOL NOT NULL,
-- A list of event IDs which represent a delta of added/removed room state. This can be NULL -- A list of event IDs which represent a delta of added/removed room state. This can be NULL
-- if there is no delta. -- if there is no delta.
add_state_ids TEXT[], add_state_ids TEXT[],
remove_state_ids TEXT[], remove_state_ids TEXT[],
device_id TEXT, -- The local device that sent the event, if any session_id BIGINT, -- The client session that sent the event, if any
transaction_id TEXT -- The transaction id used to send the event, if any transaction_id TEXT -- The transaction id used to send the event, if any
); );
-- for event selection -- for event selection
@ -57,14 +63,14 @@ CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_id_idx ON syncapi_output_room_ev
const insertEventSQL = "" + const insertEventSQL = "" +
"INSERT INTO syncapi_output_room_events (" + "INSERT INTO syncapi_output_room_events (" +
" room_id, event_id, event_json, add_state_ids, remove_state_ids, device_id, transaction_id" + "room_id, event_id, event_json, type, sender, contains_url, add_state_ids, remove_state_ids, session_id, transaction_id" +
") VALUES ($1, $2, $3, $4, $5, $6, $7) RETURNING id" ") VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id"
const selectEventsSQL = "" + const selectEventsSQL = "" +
"SELECT id, event_json FROM syncapi_output_room_events WHERE event_id = ANY($1)" "SELECT id, event_json FROM syncapi_output_room_events WHERE event_id = ANY($1)"
const selectRecentEventsSQL = "" + const selectRecentEventsSQL = "" +
"SELECT id, event_json, device_id, transaction_id FROM syncapi_output_room_events" + "SELECT id, event_json, session_id, transaction_id FROM syncapi_output_room_events" +
" WHERE room_id = $1 AND id > $2 AND id <= $3" + " WHERE room_id = $1 AND id > $2 AND id <= $3" +
" ORDER BY id DESC LIMIT $4" " ORDER BY id DESC LIMIT $4"
@ -76,7 +82,13 @@ const selectStateInRangeSQL = "" +
"SELECT id, event_json, add_state_ids, remove_state_ids" + "SELECT id, event_json, add_state_ids, remove_state_ids" +
" FROM syncapi_output_room_events" + " FROM syncapi_output_room_events" +
" WHERE (id > $1 AND id <= $2) AND (add_state_ids IS NOT NULL OR remove_state_ids IS NOT NULL)" + " WHERE (id > $1 AND id <= $2) AND (add_state_ids IS NOT NULL OR remove_state_ids IS NOT NULL)" +
" ORDER BY id ASC" " AND ( $3::text[] IS NULL OR sender = ANY($3) )" +
" AND ( $4::text[] IS NULL OR NOT(sender = ANY($4)) )" +
" AND ( $5::text[] IS NULL OR type LIKE ANY($5) )" +
" AND ( $6::text[] IS NULL OR NOT(type LIKE ANY($6)) )" +
" AND ( $7::bool IS NULL OR contains_url = $7 )" +
" ORDER BY id ASC" +
" LIMIT $8"
type outputRoomEventsStatements struct { type outputRoomEventsStatements struct {
insertEventStmt *sql.Stmt insertEventStmt *sql.Stmt
@ -109,15 +121,24 @@ func (s *outputRoomEventsStatements) prepare(db *sql.DB) (err error) {
return return
} }
// selectStateInRange returns the state events between the two given stream positions, exclusive of oldPos, inclusive of newPos. // selectStateInRange returns the state events between the two given PDU stream positions, exclusive of oldPos, inclusive of newPos.
// Results are bucketed based on the room ID. If the same state is overwritten multiple times between the // Results are bucketed based on the room ID. If the same state is overwritten multiple times between the
// two positions, only the most recent state is returned. // two positions, only the most recent state is returned.
func (s *outputRoomEventsStatements) selectStateInRange( func (s *outputRoomEventsStatements) selectStateInRange(
ctx context.Context, txn *sql.Tx, oldPos, newPos types.StreamPosition, ctx context.Context, txn *sql.Tx, oldPos, newPos int64,
stateFilterPart *gomatrixserverlib.FilterPart,
) (map[string]map[string]bool, map[string]streamEvent, error) { ) (map[string]map[string]bool, map[string]streamEvent, error) {
stmt := common.TxStmt(txn, s.selectStateInRangeStmt) stmt := common.TxStmt(txn, s.selectStateInRangeStmt)
rows, err := stmt.QueryContext(ctx, oldPos, newPos) rows, err := stmt.QueryContext(
ctx, oldPos, newPos,
pq.StringArray(stateFilterPart.Senders),
pq.StringArray(stateFilterPart.NotSenders),
pq.StringArray(filterConvertTypeWildcardToSQL(stateFilterPart.Types)),
pq.StringArray(filterConvertTypeWildcardToSQL(stateFilterPart.NotTypes)),
stateFilterPart.ContainsURL,
stateFilterPart.Limit,
)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -171,7 +192,7 @@ func (s *outputRoomEventsStatements) selectStateInRange(
eventIDToEvent[ev.EventID()] = streamEvent{ eventIDToEvent[ev.EventID()] = streamEvent{
Event: ev, Event: ev,
streamPosition: types.StreamPosition(streamPos), streamPosition: streamPos,
} }
} }
@ -200,21 +221,33 @@ func (s *outputRoomEventsStatements) insertEvent(
event *gomatrixserverlib.Event, addState, removeState []string, event *gomatrixserverlib.Event, addState, removeState []string,
transactionID *api.TransactionID, transactionID *api.TransactionID,
) (streamPos int64, err error) { ) (streamPos int64, err error) {
var deviceID, txnID *string var txnID *string
var sessionID *int64
if transactionID != nil { if transactionID != nil {
deviceID = &transactionID.DeviceID sessionID = &transactionID.SessionID
txnID = &transactionID.TransactionID txnID = &transactionID.TransactionID
} }
// Parse content as JSON and search for an "url" key
containsURL := false
var content map[string]interface{}
if json.Unmarshal(event.Content(), &content) != nil {
// Set containsURL to true if url is present
_, containsURL = content["url"]
}
stmt := common.TxStmt(txn, s.insertEventStmt) stmt := common.TxStmt(txn, s.insertEventStmt)
err = stmt.QueryRowContext( err = stmt.QueryRowContext(
ctx, ctx,
event.RoomID(), event.RoomID(),
event.EventID(), event.EventID(),
event.JSON(), event.JSON(),
event.Type(),
event.Sender(),
containsURL,
pq.StringArray(addState), pq.StringArray(addState),
pq.StringArray(removeState), pq.StringArray(removeState),
deviceID, sessionID,
txnID, txnID,
).Scan(&streamPos) ).Scan(&streamPos)
return return
@ -223,7 +256,7 @@ func (s *outputRoomEventsStatements) insertEvent(
// RecentEventsInRoom returns the most recent events in the given room, up to a maximum of 'limit'. // RecentEventsInRoom returns the most recent events in the given room, up to a maximum of 'limit'.
func (s *outputRoomEventsStatements) selectRecentEvents( func (s *outputRoomEventsStatements) selectRecentEvents(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
roomID string, fromPos, toPos types.StreamPosition, limit int, roomID string, fromPos, toPos int64, limit int,
) ([]streamEvent, error) { ) ([]streamEvent, error) {
stmt := common.TxStmt(txn, s.selectRecentEventsStmt) stmt := common.TxStmt(txn, s.selectRecentEventsStmt)
rows, err := stmt.QueryContext(ctx, roomID, fromPos, toPos, limit) rows, err := stmt.QueryContext(ctx, roomID, fromPos, toPos, limit)
@ -236,7 +269,7 @@ func (s *outputRoomEventsStatements) selectRecentEvents(
return nil, err return nil, err
} }
// The events need to be returned from oldest to latest, which isn't // The events need to be returned from oldest to latest, which isn't
// necessary the way the SQL query returns them, so a sort is necessary to // necessarily the way the SQL query returns them, so a sort is necessary to
// ensure the events are in the right order in the slice. // ensure the events are in the right order in the slice.
sort.SliceStable(events, func(i int, j int) bool { sort.SliceStable(events, func(i int, j int) bool {
return events[i].streamPosition < events[j].streamPosition return events[i].streamPosition < events[j].streamPosition
@ -264,11 +297,11 @@ func rowsToStreamEvents(rows *sql.Rows) ([]streamEvent, error) {
var ( var (
streamPos int64 streamPos int64
eventBytes []byte eventBytes []byte
deviceID *string sessionID *int64
txnID *string txnID *string
transactionID *api.TransactionID transactionID *api.TransactionID
) )
if err := rows.Scan(&streamPos, &eventBytes, &deviceID, &txnID); err != nil { if err := rows.Scan(&streamPos, &eventBytes, &sessionID, &txnID); err != nil {
return nil, err return nil, err
} }
// TODO: Handle redacted events // TODO: Handle redacted events
@ -277,16 +310,16 @@ func rowsToStreamEvents(rows *sql.Rows) ([]streamEvent, error) {
return nil, err return nil, err
} }
if deviceID != nil && txnID != nil { if sessionID != nil && txnID != nil {
transactionID = &api.TransactionID{ transactionID = &api.TransactionID{
DeviceID: *deviceID, SessionID: *sessionID,
TransactionID: *txnID, TransactionID: *txnID,
} }
} }
result = append(result, streamEvent{ result = append(result, streamEvent{
Event: ev, Event: ev,
streamPosition: types.StreamPosition(streamPos), streamPosition: streamPos,
transactionID: transactionID, transactionID: transactionID,
}) })
} }

View file

@ -17,7 +17,10 @@ package storage
import ( import (
"context" "context"
"database/sql" "database/sql"
"encoding/json"
"fmt" "fmt"
"strconv"
"time"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -28,6 +31,7 @@ import (
_ "github.com/lib/pq" _ "github.com/lib/pq"
"github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/dendrite/typingserver/cache"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -35,33 +39,35 @@ type stateDelta struct {
roomID string roomID string
stateEvents []gomatrixserverlib.Event stateEvents []gomatrixserverlib.Event
membership string membership string
// The stream position of the latest membership event for this user, if applicable. // The PDU stream position of the latest membership event for this user, if applicable.
// Can be 0 if there is no membership event in this delta. // Can be 0 if there is no membership event in this delta.
membershipPos types.StreamPosition membershipPos int64
} }
// Same as gomatrixserverlib.Event but also has the stream position for this event. // Same as gomatrixserverlib.Event but also has the PDU stream position for this event.
type streamEvent struct { type streamEvent struct {
gomatrixserverlib.Event gomatrixserverlib.Event
streamPosition types.StreamPosition streamPosition int64
transactionID *api.TransactionID transactionID *api.TransactionID
} }
// SyncServerDatabase represents a sync server database // SyncServerDatabase represents a sync server datasource which manages
type SyncServerDatabase struct { // both the database for PDUs and caches for EDUs.
type SyncServerDatasource struct {
db *sql.DB db *sql.DB
common.PartitionOffsetStatements common.PartitionOffsetStatements
accountData accountDataStatements accountData accountDataStatements
events outputRoomEventsStatements events outputRoomEventsStatements
roomstate currentRoomStateStatements roomstate currentRoomStateStatements
invites inviteEventsStatements invites inviteEventsStatements
typingCache *cache.TypingCache
} }
// NewSyncServerDatabase creates a new sync server database // NewSyncServerDatabase creates a new sync server database
func NewSyncServerDatabase(dataSourceName string) (*SyncServerDatabase, error) { func NewSyncServerDatasource(dbDataSourceName string) (*SyncServerDatasource, error) {
var d SyncServerDatabase var d SyncServerDatasource
var err error var err error
if d.db, err = sql.Open("postgres", dataSourceName); err != nil { if d.db, err = sql.Open("postgres", dbDataSourceName); err != nil {
return nil, err return nil, err
} }
if err = d.PartitionOffsetStatements.Prepare(d.db, "syncapi"); err != nil { if err = d.PartitionOffsetStatements.Prepare(d.db, "syncapi"); err != nil {
@ -79,11 +85,12 @@ func NewSyncServerDatabase(dataSourceName string) (*SyncServerDatabase, error) {
if err := d.invites.prepare(d.db); err != nil { if err := d.invites.prepare(d.db); err != nil {
return nil, err return nil, err
} }
d.typingCache = cache.NewTypingCache()
return &d, nil return &d, nil
} }
// AllJoinedUsersInRooms returns a map of room ID to a list of all joined user IDs. // AllJoinedUsersInRooms returns a map of room ID to a list of all joined user IDs.
func (d *SyncServerDatabase) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) { func (d *SyncServerDatasource) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) {
return d.roomstate.selectJoinedUsers(ctx) return d.roomstate.selectJoinedUsers(ctx)
} }
@ -92,7 +99,7 @@ func (d *SyncServerDatabase) AllJoinedUsersInRooms(ctx context.Context) (map[str
// If an event is not found in the database then it will be omitted from the list. // If an event is not found in the database then it will be omitted from the list.
// Returns an error if there was a problem talking with the database. // Returns an error if there was a problem talking with the database.
// Does not include any transaction IDs in the returned events. // Does not include any transaction IDs in the returned events.
func (d *SyncServerDatabase) Events(ctx context.Context, eventIDs []string) ([]gomatrixserverlib.Event, error) { func (d *SyncServerDatasource) Events(ctx context.Context, eventIDs []string) ([]gomatrixserverlib.Event, error) {
streamEvents, err := d.events.selectEvents(ctx, nil, eventIDs) streamEvents, err := d.events.selectEvents(ctx, nil, eventIDs)
if err != nil { if err != nil {
return nil, err return nil, err
@ -104,38 +111,38 @@ func (d *SyncServerDatabase) Events(ctx context.Context, eventIDs []string) ([]g
} }
// WriteEvent into the database. It is not safe to call this function from multiple goroutines, as it would create races // WriteEvent into the database. It is not safe to call this function from multiple goroutines, as it would create races
// when generating the stream position for this event. Returns the sync stream position for the inserted event. // when generating the sync stream position for this event. Returns the sync stream position for the inserted event.
// Returns an error if there was a problem inserting this event. // Returns an error if there was a problem inserting this event.
func (d *SyncServerDatabase) WriteEvent( func (d *SyncServerDatasource) WriteEvent(
ctx context.Context, ctx context.Context,
ev *gomatrixserverlib.Event, ev *gomatrixserverlib.Event,
addStateEvents []gomatrixserverlib.Event, addStateEvents []gomatrixserverlib.Event,
addStateEventIDs, removeStateEventIDs []string, addStateEventIDs, removeStateEventIDs []string,
transactionID *api.TransactionID, transactionID *api.TransactionID,
) (streamPos types.StreamPosition, returnErr error) { ) (pduPosition int64, returnErr error) {
returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error { returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
var err error var err error
pos, err := d.events.insertEvent(ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID) pos, err := d.events.insertEvent(ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID)
if err != nil { if err != nil {
return err return err
} }
streamPos = types.StreamPosition(pos) pduPosition = pos
if len(addStateEvents) == 0 && len(removeStateEventIDs) == 0 { if len(addStateEvents) == 0 && len(removeStateEventIDs) == 0 {
// Nothing to do, the event may have just been a message event. // Nothing to do, the event may have just been a message event.
return nil return nil
} }
return d.updateRoomState(ctx, txn, removeStateEventIDs, addStateEvents, streamPos) return d.updateRoomState(ctx, txn, removeStateEventIDs, addStateEvents, pduPosition)
}) })
return return
} }
func (d *SyncServerDatabase) updateRoomState( func (d *SyncServerDatasource) updateRoomState(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
removedEventIDs []string, removedEventIDs []string,
addedEvents []gomatrixserverlib.Event, addedEvents []gomatrixserverlib.Event,
streamPos types.StreamPosition, pduPosition int64,
) error { ) error {
// remove first, then add, as we do not ever delete state, but do replace state which is a remove followed by an add. // remove first, then add, as we do not ever delete state, but do replace state which is a remove followed by an add.
for _, eventID := range removedEventIDs { for _, eventID := range removedEventIDs {
@ -157,7 +164,7 @@ func (d *SyncServerDatabase) updateRoomState(
} }
membership = &value membership = &value
} }
if err := d.roomstate.upsertRoomState(ctx, txn, event, membership, int64(streamPos)); err != nil { if err := d.roomstate.upsertRoomState(ctx, txn, event, membership, pduPosition); err != nil {
return err return err
} }
} }
@ -168,7 +175,7 @@ func (d *SyncServerDatabase) updateRoomState(
// GetStateEvent returns the Matrix state event of a given type for a given room with a given state key // GetStateEvent returns the Matrix state event of a given type for a given room with a given state key
// If no event could be found, returns nil // If no event could be found, returns nil
// If there was an issue during the retrieval, returns an error // If there was an issue during the retrieval, returns an error
func (d *SyncServerDatabase) GetStateEvent( func (d *SyncServerDatasource) GetStateEvent(
ctx context.Context, roomID, evType, stateKey string, ctx context.Context, roomID, evType, stateKey string,
) (*gomatrixserverlib.Event, error) { ) (*gomatrixserverlib.Event, error) {
return d.roomstate.selectStateEvent(ctx, roomID, evType, stateKey) return d.roomstate.selectStateEvent(ctx, roomID, evType, stateKey)
@ -177,56 +184,60 @@ func (d *SyncServerDatabase) GetStateEvent(
// GetStateEventsForRoom fetches the state events for a given room. // GetStateEventsForRoom fetches the state events for a given room.
// Returns an empty slice if no state events could be found for this room. // Returns an empty slice if no state events could be found for this room.
// Returns an error if there was an issue with the retrieval. // Returns an error if there was an issue with the retrieval.
func (d *SyncServerDatabase) GetStateEventsForRoom( func (d *SyncServerDatasource) GetStateEventsForRoom(
ctx context.Context, roomID string, ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.FilterPart,
) (stateEvents []gomatrixserverlib.Event, err error) { ) (stateEvents []gomatrixserverlib.Event, err error) {
err = common.WithTransaction(d.db, func(txn *sql.Tx) error { err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
stateEvents, err = d.roomstate.selectCurrentState(ctx, txn, roomID) stateEvents, err = d.roomstate.selectCurrentState(ctx, txn, roomID, stateFilterPart)
return err return err
}) })
return return
} }
// SyncStreamPosition returns the latest position in the sync stream. Returns 0 if there are no events yet. // SyncPosition returns the latest positions for syncing.
func (d *SyncServerDatabase) SyncStreamPosition(ctx context.Context) (types.StreamPosition, error) { func (d *SyncServerDatasource) SyncPosition(ctx context.Context) (types.SyncPosition, error) {
return d.syncStreamPositionTx(ctx, nil) return d.syncPositionTx(ctx, nil)
} }
func (d *SyncServerDatabase) syncStreamPositionTx( func (d *SyncServerDatasource) syncPositionTx(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
) (types.StreamPosition, error) { ) (sp types.SyncPosition, err error) {
maxID, err := d.events.selectMaxEventID(ctx, txn)
maxEventID, err := d.events.selectMaxEventID(ctx, txn)
if err != nil { if err != nil {
return 0, err return sp, err
} }
maxAccountDataID, err := d.accountData.selectMaxAccountDataID(ctx, txn) maxAccountDataID, err := d.accountData.selectMaxAccountDataID(ctx, txn)
if err != nil { if err != nil {
return 0, err return sp, err
} }
if maxAccountDataID > maxID { if maxAccountDataID > maxEventID {
maxID = maxAccountDataID maxEventID = maxAccountDataID
} }
maxInviteID, err := d.invites.selectMaxInviteID(ctx, txn) maxInviteID, err := d.invites.selectMaxInviteID(ctx, txn)
if err != nil { if err != nil {
return 0, err return sp, err
} }
if maxInviteID > maxID { if maxInviteID > maxEventID {
maxID = maxInviteID maxEventID = maxInviteID
} }
return types.StreamPosition(maxID), nil sp.PDUPosition = maxEventID
sp.TypingPosition = d.typingCache.GetLatestSyncPosition()
return
} }
// IncrementalSync returns all the data needed in order to create an incremental // addPDUDeltaToResponse adds all PDU deltas to a sync response.
// sync response for the given user. Events returned will include any client // IDs of all rooms the user joined are returned so EDU deltas can be added for them.
// transaction IDs associated with the given device. These transaction IDs come func (d *SyncServerDatasource) addPDUDeltaToResponse(
// from when the device sent the event via an API that included a transaction
// ID.
func (d *SyncServerDatabase) IncrementalSync(
ctx context.Context, ctx context.Context,
device authtypes.Device, device authtypes.Device,
fromPos, toPos types.StreamPosition, fromPos, toPos int64,
numRecentEventsPerRoom int, numRecentEventsPerRoom int,
) (*types.Response, error) { wantFullState bool,
res *types.Response,
) ([]string, error) {
txn, err := d.db.BeginTx(ctx, &txReadOnlySnapshot) txn, err := d.db.BeginTx(ctx, &txReadOnlySnapshot)
if err != nil { if err != nil {
return nil, err return nil, err
@ -234,16 +245,27 @@ func (d *SyncServerDatabase) IncrementalSync(
var succeeded bool var succeeded bool
defer common.EndTransaction(txn, &succeeded) defer common.EndTransaction(txn, &succeeded)
stateFilterPart := gomatrixserverlib.DefaultFilterPart() // TODO: use filter provided in request
// Work out which rooms to return in the response. This is done by getting not only the currently // Work out which rooms to return in the response. This is done by getting not only the currently
// joined rooms, but also which rooms have membership transitions for this user between the 2 stream positions. // joined rooms, but also which rooms have membership transitions for this user between the 2 PDU stream positions.
// This works out what the 'state' key should be for each room as well as which membership block // This works out what the 'state' key should be for each room as well as which membership block
// to put the room into. // to put the room into.
deltas, err := d.getStateDeltas(ctx, &device, txn, fromPos, toPos, device.UserID) var deltas []stateDelta
var joinedRoomIDs []string
if !wantFullState {
deltas, joinedRoomIDs, err = d.getStateDeltas(
ctx, &device, txn, fromPos, toPos, device.UserID, &stateFilterPart,
)
} else {
deltas, joinedRoomIDs, err = d.getStateDeltasForFullStateSync(
ctx, &device, txn, fromPos, toPos, device.UserID, &stateFilterPart,
)
}
if err != nil { if err != nil {
return nil, err return nil, err
} }
res := types.NewResponse(toPos)
for _, delta := range deltas { for _, delta := range deltas {
err = d.addRoomDeltaToResponse(ctx, &device, txn, fromPos, toPos, delta, numRecentEventsPerRoom, res) err = d.addRoomDeltaToResponse(ctx, &device, txn, fromPos, toPos, delta, numRecentEventsPerRoom, res)
if err != nil { if err != nil {
@ -257,52 +279,154 @@ func (d *SyncServerDatabase) IncrementalSync(
} }
succeeded = true succeeded = true
return joinedRoomIDs, nil
}
// addTypingDeltaToResponse adds all typing notifications to a sync response
// since the specified position.
func (d *SyncServerDatasource) addTypingDeltaToResponse(
since int64,
joinedRoomIDs []string,
res *types.Response,
) error {
var jr types.JoinResponse
var ok bool
var err error
for _, roomID := range joinedRoomIDs {
if typingUsers, updated := d.typingCache.GetTypingUsersIfUpdatedAfter(
roomID, since,
); updated {
ev := gomatrixserverlib.ClientEvent{
Type: gomatrixserverlib.MTyping,
}
ev.Content, err = json.Marshal(map[string]interface{}{
"user_ids": typingUsers,
})
if err != nil {
return err
}
if jr, ok = res.Rooms.Join[roomID]; !ok {
jr = *types.NewJoinResponse()
}
jr.Ephemeral.Events = append(jr.Ephemeral.Events, ev)
res.Rooms.Join[roomID] = jr
}
}
return nil
}
// addEDUDeltaToResponse adds updates for EDUs of each type since fromPos if
// the positions of that type are not equal in fromPos and toPos.
func (d *SyncServerDatasource) addEDUDeltaToResponse(
fromPos, toPos types.SyncPosition,
joinedRoomIDs []string,
res *types.Response,
) (err error) {
if fromPos.TypingPosition != toPos.TypingPosition {
err = d.addTypingDeltaToResponse(
fromPos.TypingPosition, joinedRoomIDs, res,
)
}
return
}
// IncrementalSync returns all the data needed in order to create an incremental
// sync response for the given user. Events returned will include any client
// transaction IDs associated with the given device. These transaction IDs come
// from when the device sent the event via an API that included a transaction
// ID.
func (d *SyncServerDatasource) IncrementalSync(
ctx context.Context,
device authtypes.Device,
fromPos, toPos types.SyncPosition,
numRecentEventsPerRoom int,
wantFullState bool,
) (*types.Response, error) {
nextBatchPos := fromPos.WithUpdates(toPos)
res := types.NewResponse(nextBatchPos)
var joinedRoomIDs []string
var err error
if fromPos.PDUPosition != toPos.PDUPosition || wantFullState {
joinedRoomIDs, err = d.addPDUDeltaToResponse(
ctx, device, fromPos.PDUPosition, toPos.PDUPosition, numRecentEventsPerRoom, wantFullState, res,
)
} else {
joinedRoomIDs, err = d.roomstate.selectRoomIDsWithMembership(
ctx, nil, device.UserID, gomatrixserverlib.Join,
)
}
if err != nil {
return nil, err
}
err = d.addEDUDeltaToResponse(
fromPos, toPos, joinedRoomIDs, res,
)
if err != nil {
return nil, err
}
return res, nil return res, nil
} }
// CompleteSync a complete /sync API response for the given user. // getResponseWithPDUsForCompleteSync creates a response and adds all PDUs needed
func (d *SyncServerDatabase) CompleteSync( // to it. It returns toPos and joinedRoomIDs for use of adding EDUs.
ctx context.Context, userID string, numRecentEventsPerRoom int, func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync(
) (*types.Response, error) { ctx context.Context,
userID string,
numRecentEventsPerRoom int,
) (
res *types.Response,
toPos types.SyncPosition,
joinedRoomIDs []string,
err error,
) {
// This needs to be all done in a transaction as we need to do multiple SELECTs, and we need to have // This needs to be all done in a transaction as we need to do multiple SELECTs, and we need to have
// a consistent view of the database throughout. This includes extracting the sync stream position. // a consistent view of the database throughout. This includes extracting the sync position.
// This does have the unfortunate side-effect that all the matrixy logic resides in this function, // This does have the unfortunate side-effect that all the matrixy logic resides in this function,
// but it's better to not hide the fact that this is being done in a transaction. // but it's better to not hide the fact that this is being done in a transaction.
txn, err := d.db.BeginTx(ctx, &txReadOnlySnapshot) txn, err := d.db.BeginTx(ctx, &txReadOnlySnapshot)
if err != nil { if err != nil {
return nil, err return
} }
var succeeded bool var succeeded bool
defer common.EndTransaction(txn, &succeeded) defer common.EndTransaction(txn, &succeeded)
// Get the current stream position which we will base the sync response on. // Get the current sync position which we will base the sync response on.
pos, err := d.syncStreamPositionTx(ctx, txn) toPos, err = d.syncPositionTx(ctx, txn)
if err != nil { if err != nil {
return nil, err return
} }
res = types.NewResponse(toPos)
// Extract room state and recent events for all rooms the user is joined to. // Extract room state and recent events for all rooms the user is joined to.
roomIDs, err := d.roomstate.selectRoomIDsWithMembership(ctx, txn, userID, "join") joinedRoomIDs, err = d.roomstate.selectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join)
if err != nil { if err != nil {
return nil, err return
} }
stateFilterPart := gomatrixserverlib.DefaultFilterPart() // TODO: use filter provided in request
// Build up a /sync response. Add joined rooms. // Build up a /sync response. Add joined rooms.
res := types.NewResponse(pos) for _, roomID := range joinedRoomIDs {
for _, roomID := range roomIDs {
var stateEvents []gomatrixserverlib.Event var stateEvents []gomatrixserverlib.Event
stateEvents, err = d.roomstate.selectCurrentState(ctx, txn, roomID) stateEvents, err = d.roomstate.selectCurrentState(ctx, txn, roomID, &stateFilterPart)
if err != nil { if err != nil {
return nil, err return
} }
// TODO: When filters are added, we may need to call this multiple times to get enough events. // TODO: When filters are added, we may need to call this multiple times to get enough events.
// See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316 // See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316
var recentStreamEvents []streamEvent var recentStreamEvents []streamEvent
recentStreamEvents, err = d.events.selectRecentEvents( recentStreamEvents, err = d.events.selectRecentEvents(
ctx, txn, roomID, types.StreamPosition(0), pos, numRecentEventsPerRoom, ctx, txn, roomID, 0, toPos.PDUPosition, numRecentEventsPerRoom,
) )
if err != nil { if err != nil {
return nil, err return
} }
// We don't include a device here as we don't need to send down // We don't include a device here as we don't need to send down
@ -311,10 +435,12 @@ func (d *SyncServerDatabase) CompleteSync(
stateEvents = removeDuplicates(stateEvents, recentEvents) stateEvents = removeDuplicates(stateEvents, recentEvents)
jr := types.NewJoinResponse() jr := types.NewJoinResponse()
if prevBatch := recentStreamEvents[0].streamPosition - 1; prevBatch > 0 { if prevPDUPos := recentStreamEvents[0].streamPosition - 1; prevPDUPos > 0 {
jr.Timeline.PrevBatch = types.StreamPosition(prevBatch).String() // Use the short form of batch token for prev_batch
jr.Timeline.PrevBatch = strconv.FormatInt(prevPDUPos, 10)
} else { } else {
jr.Timeline.PrevBatch = types.StreamPosition(1).String() // Use the short form of batch token for prev_batch
jr.Timeline.PrevBatch = "1"
} }
jr.Timeline.Events = gomatrixserverlib.ToClientEvents(recentEvents, gomatrixserverlib.FormatSync) jr.Timeline.Events = gomatrixserverlib.ToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
jr.Timeline.Limited = true jr.Timeline.Limited = true
@ -322,12 +448,34 @@ func (d *SyncServerDatabase) CompleteSync(
res.Rooms.Join[roomID] = *jr res.Rooms.Join[roomID] = *jr
} }
if err = d.addInvitesToResponse(ctx, txn, userID, 0, pos, res); err != nil { if err = d.addInvitesToResponse(ctx, txn, userID, 0, toPos.PDUPosition, res); err != nil {
return nil, err return
} }
succeeded = true succeeded = true
return res, err return res, toPos, joinedRoomIDs, err
}
// CompleteSync returns a complete /sync API response for the given user.
func (d *SyncServerDatasource) CompleteSync(
ctx context.Context, userID string, numRecentEventsPerRoom int,
) (*types.Response, error) {
res, toPos, joinedRoomIDs, err := d.getResponseWithPDUsForCompleteSync(
ctx, userID, numRecentEventsPerRoom,
)
if err != nil {
return nil, err
}
// Use a zero value SyncPosition for fromPos so all EDU states are added.
err = d.addEDUDeltaToResponse(
types.SyncPosition{}, toPos, joinedRoomIDs, res,
)
if err != nil {
return nil, err
}
return res, nil
} }
var txReadOnlySnapshot = sql.TxOptions{ var txReadOnlySnapshot = sql.TxOptions{
@ -345,10 +493,11 @@ var txReadOnlySnapshot = sql.TxOptions{
// Returns a map following the format data[roomID] = []dataTypes // Returns a map following the format data[roomID] = []dataTypes
// If no data is retrieved, returns an empty map // If no data is retrieved, returns an empty map
// If there was an issue with the retrieval, returns an error // If there was an issue with the retrieval, returns an error
func (d *SyncServerDatabase) GetAccountDataInRange( func (d *SyncServerDatasource) GetAccountDataInRange(
ctx context.Context, userID string, oldPos, newPos types.StreamPosition, ctx context.Context, userID string, oldPos, newPos int64,
accountDataFilterPart *gomatrixserverlib.FilterPart,
) (map[string][]string, error) { ) (map[string][]string, error) {
return d.accountData.selectAccountDataInRange(ctx, userID, oldPos, newPos) return d.accountData.selectAccountDataInRange(ctx, userID, oldPos, newPos, accountDataFilterPart)
} }
// UpsertAccountData keeps track of new or updated account data, by saving the type // UpsertAccountData keeps track of new or updated account data, by saving the type
@ -357,26 +506,24 @@ func (d *SyncServerDatabase) GetAccountDataInRange(
// If no data with the given type, user ID and room ID exists in the database, // If no data with the given type, user ID and room ID exists in the database,
// creates a new row, else update the existing one // creates a new row, else update the existing one
// Returns an error if there was an issue with the upsert // Returns an error if there was an issue with the upsert
func (d *SyncServerDatabase) UpsertAccountData( func (d *SyncServerDatasource) UpsertAccountData(
ctx context.Context, userID, roomID, dataType string, ctx context.Context, userID, roomID, dataType string,
) (types.StreamPosition, error) { ) (int64, error) {
pos, err := d.accountData.insertAccountData(ctx, userID, roomID, dataType) return d.accountData.insertAccountData(ctx, userID, roomID, dataType)
return types.StreamPosition(pos), err
} }
// AddInviteEvent stores a new invite event for a user. // AddInviteEvent stores a new invite event for a user.
// If the invite was successfully stored this returns the stream ID it was stored at. // If the invite was successfully stored this returns the stream ID it was stored at.
// Returns an error if there was a problem communicating with the database. // Returns an error if there was a problem communicating with the database.
func (d *SyncServerDatabase) AddInviteEvent( func (d *SyncServerDatasource) AddInviteEvent(
ctx context.Context, inviteEvent gomatrixserverlib.Event, ctx context.Context, inviteEvent gomatrixserverlib.Event,
) (types.StreamPosition, error) { ) (int64, error) {
pos, err := d.invites.insertInviteEvent(ctx, inviteEvent) return d.invites.insertInviteEvent(ctx, inviteEvent)
return types.StreamPosition(pos), err
} }
// RetireInviteEvent removes an old invite event from the database. // RetireInviteEvent removes an old invite event from the database.
// Returns an error if there was a problem communicating with the database. // Returns an error if there was a problem communicating with the database.
func (d *SyncServerDatabase) RetireInviteEvent( func (d *SyncServerDatasource) RetireInviteEvent(
ctx context.Context, inviteEventID string, ctx context.Context, inviteEventID string,
) error { ) error {
// TODO: Record that invite has been retired in a stream so that we can // TODO: Record that invite has been retired in a stream so that we can
@ -385,10 +532,30 @@ func (d *SyncServerDatabase) RetireInviteEvent(
return err return err
} }
func (d *SyncServerDatabase) addInvitesToResponse( func (d *SyncServerDatasource) SetTypingTimeoutCallback(fn cache.TimeoutCallbackFn) {
d.typingCache.SetTimeoutCallback(fn)
}
// AddTypingUser adds a typing user to the typing cache.
// Returns the newly calculated sync position for typing notifications.
func (d *SyncServerDatasource) AddTypingUser(
userID, roomID string, expireTime *time.Time,
) int64 {
return d.typingCache.AddTypingUser(userID, roomID, expireTime)
}
// RemoveTypingUser removes a typing user from the typing cache.
// Returns the newly calculated sync position for typing notifications.
func (d *SyncServerDatasource) RemoveTypingUser(
userID, roomID string,
) int64 {
return d.typingCache.RemoveUser(userID, roomID)
}
func (d *SyncServerDatasource) addInvitesToResponse(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
userID string, userID string,
fromPos, toPos types.StreamPosition, fromPos, toPos int64,
res *types.Response, res *types.Response,
) error { ) error {
invites, err := d.invites.selectInviteEventsInRange( invites, err := d.invites.selectInviteEventsInRange(
@ -409,17 +576,17 @@ func (d *SyncServerDatabase) addInvitesToResponse(
} }
// addRoomDeltaToResponse adds a room state delta to a sync response // addRoomDeltaToResponse adds a room state delta to a sync response
func (d *SyncServerDatabase) addRoomDeltaToResponse( func (d *SyncServerDatasource) addRoomDeltaToResponse(
ctx context.Context, ctx context.Context,
device *authtypes.Device, device *authtypes.Device,
txn *sql.Tx, txn *sql.Tx,
fromPos, toPos types.StreamPosition, fromPos, toPos int64,
delta stateDelta, delta stateDelta,
numRecentEventsPerRoom int, numRecentEventsPerRoom int,
res *types.Response, res *types.Response,
) error { ) error {
endPos := toPos endPos := toPos
if delta.membershipPos > 0 && delta.membership == "leave" { if delta.membershipPos > 0 && delta.membership == gomatrixserverlib.Leave {
// make sure we don't leak recent events after the leave event. // make sure we don't leak recent events after the leave event.
// TODO: History visibility makes this somewhat complex to handle correctly. For example: // TODO: History visibility makes this somewhat complex to handle correctly. For example:
// TODO: This doesn't work for join -> leave in a single /sync request (see events prior to join). // TODO: This doesn't work for join -> leave in a single /sync request (see events prior to join).
@ -437,34 +604,42 @@ func (d *SyncServerDatabase) addRoomDeltaToResponse(
recentEvents := streamEventsToEvents(device, recentStreamEvents) recentEvents := streamEventsToEvents(device, recentStreamEvents)
delta.stateEvents = removeDuplicates(delta.stateEvents, recentEvents) // roll back delta.stateEvents = removeDuplicates(delta.stateEvents, recentEvents) // roll back
var prevPDUPos int64
if len(recentEvents) == 0 {
if len(delta.stateEvents) == 0 {
// Don't bother appending empty room entries // Don't bother appending empty room entries
if len(recentEvents) == 0 && len(delta.stateEvents) == 0 {
return nil return nil
} }
switch delta.membership { // If full_state=true and since is already up to date, then we'll have
case "join": // state events but no recent events.
jr := types.NewJoinResponse() prevPDUPos = toPos - 1
if prevBatch := recentStreamEvents[0].streamPosition - 1; prevBatch > 0 {
jr.Timeline.PrevBatch = types.StreamPosition(prevBatch).String()
} else { } else {
jr.Timeline.PrevBatch = types.StreamPosition(1).String() prevPDUPos = recentStreamEvents[0].streamPosition - 1
} }
if prevPDUPos <= 0 {
prevPDUPos = 1
}
switch delta.membership {
case gomatrixserverlib.Join:
jr := types.NewJoinResponse()
// Use the short form of batch token for prev_batch
jr.Timeline.PrevBatch = strconv.FormatInt(prevPDUPos, 10)
jr.Timeline.Events = gomatrixserverlib.ToClientEvents(recentEvents, gomatrixserverlib.FormatSync) jr.Timeline.Events = gomatrixserverlib.ToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
jr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true jr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true
jr.State.Events = gomatrixserverlib.ToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync) jr.State.Events = gomatrixserverlib.ToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync)
res.Rooms.Join[delta.roomID] = *jr res.Rooms.Join[delta.roomID] = *jr
case "leave": case gomatrixserverlib.Leave:
fallthrough // transitions to leave are the same as ban fallthrough // transitions to leave are the same as ban
case "ban": case gomatrixserverlib.Ban:
// TODO: recentEvents may contain events that this user is not allowed to see because they are // TODO: recentEvents may contain events that this user is not allowed to see because they are
// no longer in the room. // no longer in the room.
lr := types.NewLeaveResponse() lr := types.NewLeaveResponse()
if prevBatch := recentStreamEvents[0].streamPosition - 1; prevBatch > 0 { // Use the short form of batch token for prev_batch
lr.Timeline.PrevBatch = types.StreamPosition(prevBatch).String() lr.Timeline.PrevBatch = strconv.FormatInt(prevPDUPos, 10)
} else {
lr.Timeline.PrevBatch = types.StreamPosition(1).String()
}
lr.Timeline.Events = gomatrixserverlib.ToClientEvents(recentEvents, gomatrixserverlib.FormatSync) lr.Timeline.Events = gomatrixserverlib.ToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
lr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true lr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true
lr.State.Events = gomatrixserverlib.ToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync) lr.State.Events = gomatrixserverlib.ToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync)
@ -476,7 +651,7 @@ func (d *SyncServerDatabase) addRoomDeltaToResponse(
// fetchStateEvents converts the set of event IDs into a set of events. It will fetch any which are missing from the database. // fetchStateEvents converts the set of event IDs into a set of events. It will fetch any which are missing from the database.
// Returns a map of room ID to list of events. // Returns a map of room ID to list of events.
func (d *SyncServerDatabase) fetchStateEvents( func (d *SyncServerDatasource) fetchStateEvents(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
roomIDToEventIDSet map[string]map[string]bool, roomIDToEventIDSet map[string]map[string]bool,
eventIDToEvent map[string]streamEvent, eventIDToEvent map[string]streamEvent,
@ -521,7 +696,7 @@ func (d *SyncServerDatabase) fetchStateEvents(
return stateBetween, nil return stateBetween, nil
} }
func (d *SyncServerDatabase) fetchMissingStateEvents( func (d *SyncServerDatasource) fetchMissingStateEvents(
ctx context.Context, txn *sql.Tx, eventIDs []string, ctx context.Context, txn *sql.Tx, eventIDs []string,
) ([]streamEvent, error) { ) ([]streamEvent, error) {
// Fetch from the events table first so we pick up the stream ID for the // Fetch from the events table first so we pick up the stream ID for the
@ -560,10 +735,15 @@ func (d *SyncServerDatabase) fetchMissingStateEvents(
return events, nil return events, nil
} }
func (d *SyncServerDatabase) getStateDeltas( // getStateDeltas returns the state deltas between fromPos and toPos,
// exclusive of oldPos, inclusive of newPos, for the rooms in which
// the user has new membership events.
// A list of joined room IDs is also returned in case the caller needs it.
func (d *SyncServerDatasource) getStateDeltas(
ctx context.Context, device *authtypes.Device, txn *sql.Tx, ctx context.Context, device *authtypes.Device, txn *sql.Tx,
fromPos, toPos types.StreamPosition, userID string, fromPos, toPos int64, userID string,
) ([]stateDelta, error) { stateFilterPart *gomatrixserverlib.FilterPart,
) ([]stateDelta, []string, error) {
// Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821 // Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821
// - Get membership list changes for this user in this sync response // - Get membership list changes for this user in this sync response
// - For each room which has membership list changes: // - For each room which has membership list changes:
@ -575,13 +755,13 @@ func (d *SyncServerDatabase) getStateDeltas(
var deltas []stateDelta var deltas []stateDelta
// get all the state events ever between these two positions // get all the state events ever between these two positions
stateNeeded, eventMap, err := d.events.selectStateInRange(ctx, txn, fromPos, toPos) stateNeeded, eventMap, err := d.events.selectStateInRange(ctx, txn, fromPos, toPos, stateFilterPart)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
state, err := d.fetchStateEvents(ctx, txn, stateNeeded, eventMap) state, err := d.fetchStateEvents(ctx, txn, stateNeeded, eventMap)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
for roomID, stateStreamEvents := range state { for roomID, stateStreamEvents := range state {
@ -592,16 +772,12 @@ func (d *SyncServerDatabase) getStateDeltas(
// the 'state' part of the response though, so is transparent modulo bandwidth concerns as it is not added to // the 'state' part of the response though, so is transparent modulo bandwidth concerns as it is not added to
// the timeline. // the timeline.
if membership := getMembershipFromEvent(&ev.Event, userID); membership != "" { if membership := getMembershipFromEvent(&ev.Event, userID); membership != "" {
if membership == "join" { if membership == gomatrixserverlib.Join {
// send full room state down instead of a delta // send full room state down instead of a delta
var allState []gomatrixserverlib.Event var s []streamEvent
allState, err = d.roomstate.selectCurrentState(ctx, txn, roomID) s, err = d.currentStateStreamEventsForRoom(ctx, txn, roomID, stateFilterPart)
if err != nil { if err != nil {
return nil, err return nil, nil, err
}
s := make([]streamEvent, len(allState))
for i := 0; i < len(s); i++ {
s[i] = streamEvent{Event: allState[i], streamPosition: types.StreamPosition(0)}
} }
state[roomID] = s state[roomID] = s
continue // we'll add this room in when we do joined rooms continue // we'll add this room in when we do joined rooms
@ -619,19 +795,94 @@ func (d *SyncServerDatabase) getStateDeltas(
} }
// Add in currently joined rooms // Add in currently joined rooms
joinedRoomIDs, err := d.roomstate.selectRoomIDsWithMembership(ctx, txn, userID, "join") joinedRoomIDs, err := d.roomstate.selectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
for _, joinedRoomID := range joinedRoomIDs { for _, joinedRoomID := range joinedRoomIDs {
deltas = append(deltas, stateDelta{ deltas = append(deltas, stateDelta{
membership: "join", membership: gomatrixserverlib.Join,
stateEvents: streamEventsToEvents(device, state[joinedRoomID]), stateEvents: streamEventsToEvents(device, state[joinedRoomID]),
roomID: joinedRoomID, roomID: joinedRoomID,
}) })
} }
return deltas, nil return deltas, joinedRoomIDs, nil
}
// getStateDeltasForFullStateSync is a variant of getStateDeltas used for /sync
// requests with full_state=true.
// Fetches full state for all joined rooms and uses selectStateInRange to get
// updates for other rooms.
func (d *SyncServerDatasource) getStateDeltasForFullStateSync(
ctx context.Context, device *authtypes.Device, txn *sql.Tx,
fromPos, toPos int64, userID string,
stateFilterPart *gomatrixserverlib.FilterPart,
) ([]stateDelta, []string, error) {
joinedRoomIDs, err := d.roomstate.selectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join)
if err != nil {
return nil, nil, err
}
// Use a reasonable initial capacity
deltas := make([]stateDelta, 0, len(joinedRoomIDs))
// Add full states for all joined rooms
for _, joinedRoomID := range joinedRoomIDs {
s, stateErr := d.currentStateStreamEventsForRoom(ctx, txn, joinedRoomID, stateFilterPart)
if stateErr != nil {
return nil, nil, stateErr
}
deltas = append(deltas, stateDelta{
membership: gomatrixserverlib.Join,
stateEvents: streamEventsToEvents(device, s),
roomID: joinedRoomID,
})
}
// Get all the state events ever between these two positions
stateNeeded, eventMap, err := d.events.selectStateInRange(ctx, txn, fromPos, toPos, stateFilterPart)
if err != nil {
return nil, nil, err
}
state, err := d.fetchStateEvents(ctx, txn, stateNeeded, eventMap)
if err != nil {
return nil, nil, err
}
for roomID, stateStreamEvents := range state {
for _, ev := range stateStreamEvents {
if membership := getMembershipFromEvent(&ev.Event, userID); membership != "" {
if membership != gomatrixserverlib.Join { // We've already added full state for all joined rooms above.
deltas = append(deltas, stateDelta{
membership: membership,
membershipPos: ev.streamPosition,
stateEvents: streamEventsToEvents(device, stateStreamEvents),
roomID: roomID,
})
}
break
}
}
}
return deltas, joinedRoomIDs, nil
}
func (d *SyncServerDatasource) currentStateStreamEventsForRoom(
ctx context.Context, txn *sql.Tx, roomID string,
stateFilterPart *gomatrixserverlib.FilterPart,
) ([]streamEvent, error) {
allState, err := d.roomstate.selectCurrentState(ctx, txn, roomID, stateFilterPart)
if err != nil {
return nil, err
}
s := make([]streamEvent, len(allState))
for i := 0; i < len(s); i++ {
s[i] = streamEvent{Event: allState[i], streamPosition: 0}
}
return s, nil
} }
// streamEventsToEvents converts streamEvent to Event. If device is non-nil and // streamEventsToEvents converts streamEvent to Event. If device is non-nil and
@ -642,7 +893,7 @@ func streamEventsToEvents(device *authtypes.Device, in []streamEvent) []gomatrix
for i := 0; i < len(in); i++ { for i := 0; i < len(in); i++ {
out[i] = in[i].Event out[i] = in[i].Event
if device != nil && in[i].transactionID != nil { if device != nil && in[i].transactionID != nil {
if device.UserID == in[i].Sender() && device.ID == in[i].transactionID.DeviceID { if device.UserID == in[i].Sender() && device.SessionID == in[i].transactionID.SessionID {
err := out[i].SetUnsignedField( err := out[i].SetUnsignedField(
"transaction_id", in[i].transactionID.TransactionID, "transaction_id", in[i].transactionID.TransactionID,
) )

View file

@ -26,7 +26,7 @@ import (
) )
// Notifier will wake up sleeping requests when there is some new data. // Notifier will wake up sleeping requests when there is some new data.
// It does not tell requests what that data is, only the stream position which // It does not tell requests what that data is, only the sync position which
// they can use to get at it. This is done to prevent races whereby we tell the caller // they can use to get at it. This is done to prevent races whereby we tell the caller
// the event, but the token has already advanced by the time they fetch it, resulting // the event, but the token has already advanced by the time they fetch it, resulting
// in missed events. // in missed events.
@ -35,18 +35,18 @@ type Notifier struct {
roomIDToJoinedUsers map[string]userIDSet roomIDToJoinedUsers map[string]userIDSet
// Protects currPos and userStreams. // Protects currPos and userStreams.
streamLock *sync.Mutex streamLock *sync.Mutex
// The latest sync stream position // The latest sync position
currPos types.StreamPosition currPos types.SyncPosition
// A map of user_id => UserStream which can be used to wake a given user's /sync request. // A map of user_id => UserStream which can be used to wake a given user's /sync request.
userStreams map[string]*UserStream userStreams map[string]*UserStream
// The last time we cleaned out stale entries from the userStreams map // The last time we cleaned out stale entries from the userStreams map
lastCleanUpTime time.Time lastCleanUpTime time.Time
} }
// NewNotifier creates a new notifier set to the given stream position. // NewNotifier creates a new notifier set to the given sync position.
// In order for this to be of any use, the Notifier needs to be told all rooms and // In order for this to be of any use, the Notifier needs to be told all rooms and
// the joined users within each of them by calling Notifier.Load(*storage.SyncServerDatabase). // the joined users within each of them by calling Notifier.Load(*storage.SyncServerDatabase).
func NewNotifier(pos types.StreamPosition) *Notifier { func NewNotifier(pos types.SyncPosition) *Notifier {
return &Notifier{ return &Notifier{
currPos: pos, currPos: pos,
roomIDToJoinedUsers: make(map[string]userIDSet), roomIDToJoinedUsers: make(map[string]userIDSet),
@ -58,20 +58,30 @@ func NewNotifier(pos types.StreamPosition) *Notifier {
// OnNewEvent is called when a new event is received from the room server. Must only be // OnNewEvent is called when a new event is received from the room server. Must only be
// called from a single goroutine, to avoid races between updates which could set the // called from a single goroutine, to avoid races between updates which could set the
// current position in the stream incorrectly. // current sync position incorrectly.
// Can be called either with a *gomatrixserverlib.Event, or with an user ID // Chooses which user sync streams to update by a provided *gomatrixserverlib.Event
func (n *Notifier) OnNewEvent(ev *gomatrixserverlib.Event, userID string, pos types.StreamPosition) { // (based on the users in the event's room),
// a roomID directly, or a list of user IDs, prioritised by parameter ordering.
// posUpdate contains the latest position(s) for one or more types of events.
// If a position in posUpdate is 0, it means no updates are available of that type.
// Typically a consumer supplies a posUpdate with the latest sync position for the
// event type it handles, leaving other fields as 0.
func (n *Notifier) OnNewEvent(
ev *gomatrixserverlib.Event, roomID string, userIDs []string,
posUpdate types.SyncPosition,
) {
// update the current position then notify relevant /sync streams. // update the current position then notify relevant /sync streams.
// This needs to be done PRIOR to waking up users as they will read this value. // This needs to be done PRIOR to waking up users as they will read this value.
n.streamLock.Lock() n.streamLock.Lock()
defer n.streamLock.Unlock() defer n.streamLock.Unlock()
n.currPos = pos latestPos := n.currPos.WithUpdates(posUpdate)
n.currPos = latestPos
n.removeEmptyUserStreams() n.removeEmptyUserStreams()
if ev != nil { if ev != nil {
// Map this event's room_id to a list of joined users, and wake them up. // Map this event's room_id to a list of joined users, and wake them up.
userIDs := n.joinedUsers(ev.RoomID()) usersToNotify := n.joinedUsers(ev.RoomID())
// If this is an invite, also add in the invitee to this list. // If this is an invite, also add in the invitee to this list.
if ev.Type() == "m.room.member" && ev.StateKey() != nil { if ev.Type() == "m.room.member" && ev.StateKey() != nil {
targetUserID := *ev.StateKey() targetUserID := *ev.StateKey()
@ -83,26 +93,30 @@ func (n *Notifier) OnNewEvent(ev *gomatrixserverlib.Event, userID string, pos ty
} else { } else {
// Keep the joined user map up-to-date // Keep the joined user map up-to-date
switch membership { switch membership {
case "invite": case gomatrixserverlib.Invite:
userIDs = append(userIDs, targetUserID) usersToNotify = append(usersToNotify, targetUserID)
case "join": case gomatrixserverlib.Join:
// Manually append the new user's ID so they get notified // Manually append the new user's ID so they get notified
// along all members in the room // along all members in the room
userIDs = append(userIDs, targetUserID) usersToNotify = append(usersToNotify, targetUserID)
n.addJoinedUser(ev.RoomID(), targetUserID) n.addJoinedUser(ev.RoomID(), targetUserID)
case "leave": case gomatrixserverlib.Leave:
fallthrough fallthrough
case "ban": case gomatrixserverlib.Ban:
n.removeJoinedUser(ev.RoomID(), targetUserID) n.removeJoinedUser(ev.RoomID(), targetUserID)
} }
} }
} }
for _, toNotifyUserID := range userIDs { n.wakeupUsers(usersToNotify, latestPos)
n.wakeupUser(toNotifyUserID, pos) } else if roomID != "" {
} n.wakeupUsers(n.joinedUsers(roomID), latestPos)
} else if len(userID) > 0 { } else if len(userIDs) > 0 {
n.wakeupUser(userID, pos) n.wakeupUsers(userIDs, latestPos)
} else {
log.WithFields(log.Fields{
"posUpdate": posUpdate.String,
}).Warn("Notifier.OnNewEvent called but caller supplied no user to wake up")
} }
} }
@ -127,7 +141,7 @@ func (n *Notifier) GetListener(req syncRequest) UserStreamListener {
} }
// Load the membership states required to notify users correctly. // Load the membership states required to notify users correctly.
func (n *Notifier) Load(ctx context.Context, db *storage.SyncServerDatabase) error { func (n *Notifier) Load(ctx context.Context, db *storage.SyncServerDatasource) error {
roomToUsers, err := db.AllJoinedUsersInRooms(ctx) roomToUsers, err := db.AllJoinedUsersInRooms(ctx)
if err != nil { if err != nil {
return err return err
@ -136,8 +150,11 @@ func (n *Notifier) Load(ctx context.Context, db *storage.SyncServerDatabase) err
return nil return nil
} }
// CurrentPosition returns the current stream position // CurrentPosition returns the current sync position
func (n *Notifier) CurrentPosition() types.StreamPosition { func (n *Notifier) CurrentPosition() types.SyncPosition {
n.streamLock.Lock()
defer n.streamLock.Unlock()
return n.currPos return n.currPos
} }
@ -156,17 +173,19 @@ func (n *Notifier) setUsersJoinedToRooms(roomIDToUserIDs map[string][]string) {
} }
} }
func (n *Notifier) wakeupUser(userID string, newPos types.StreamPosition) { func (n *Notifier) wakeupUsers(userIDs []string, newPos types.SyncPosition) {
for _, userID := range userIDs {
stream := n.fetchUserStream(userID, false) stream := n.fetchUserStream(userID, false)
if stream == nil { if stream != nil {
return stream.Broadcast(newPos) // wake up all goroutines Wait()ing on this stream
}
} }
stream.Broadcast(newPos) // wakeup all goroutines Wait()ing on this stream
} }
// fetchUserStream retrieves a stream unique to the given user. If makeIfNotExists is true, // fetchUserStream retrieves a stream unique to the given user. If makeIfNotExists is true,
// a stream will be made for this user if one doesn't exist and it will be returned. This // a stream will be made for this user if one doesn't exist and it will be returned. This
// function does not wait for data to be available on the stream. // function does not wait for data to be available on the stream.
// NB: Callers should have locked the mutex before calling this function.
func (n *Notifier) fetchUserStream(userID string, makeIfNotExists bool) *UserStream { func (n *Notifier) fetchUserStream(userID string, makeIfNotExists bool) *UserStream {
stream, ok := n.userStreams[userID] stream, ok := n.userStreams[userID]
if !ok && makeIfNotExists { if !ok && makeIfNotExists {

View file

@ -32,19 +32,40 @@ var (
randomMessageEvent gomatrixserverlib.Event randomMessageEvent gomatrixserverlib.Event
aliceInviteBobEvent gomatrixserverlib.Event aliceInviteBobEvent gomatrixserverlib.Event
bobLeaveEvent gomatrixserverlib.Event bobLeaveEvent gomatrixserverlib.Event
syncPositionVeryOld types.SyncPosition
syncPositionBefore types.SyncPosition
syncPositionAfter types.SyncPosition
syncPositionNewEDU types.SyncPosition
syncPositionAfter2 types.SyncPosition
) )
var ( var (
streamPositionVeryOld = types.StreamPosition(5)
streamPositionBefore = types.StreamPosition(11)
streamPositionAfter = types.StreamPosition(12)
streamPositionAfter2 = types.StreamPosition(13)
roomID = "!test:localhost" roomID = "!test:localhost"
alice = "@alice:localhost" alice = "@alice:localhost"
bob = "@bob:localhost" bob = "@bob:localhost"
) )
func init() { func init() {
baseSyncPos := types.SyncPosition{
PDUPosition: 0,
TypingPosition: 0,
}
syncPositionVeryOld = baseSyncPos
syncPositionVeryOld.PDUPosition = 5
syncPositionBefore = baseSyncPos
syncPositionBefore.PDUPosition = 11
syncPositionAfter = baseSyncPos
syncPositionAfter.PDUPosition = 12
syncPositionNewEDU = syncPositionAfter
syncPositionNewEDU.TypingPosition = 1
syncPositionAfter2 = baseSyncPos
syncPositionAfter2.PDUPosition = 13
var err error var err error
randomMessageEvent, err = gomatrixserverlib.NewEventFromTrustedJSON([]byte(`{ randomMessageEvent, err = gomatrixserverlib.NewEventFromTrustedJSON([]byte(`{
"type": "m.room.message", "type": "m.room.message",
@ -92,19 +113,19 @@ func init() {
// Test that the current position is returned if a request is already behind. // Test that the current position is returned if a request is already behind.
func TestImmediateNotification(t *testing.T) { func TestImmediateNotification(t *testing.T) {
n := NewNotifier(streamPositionBefore) n := NewNotifier(syncPositionBefore)
pos, err := waitForEvents(n, newTestSyncRequest(alice, streamPositionVeryOld)) pos, err := waitForEvents(n, newTestSyncRequest(alice, syncPositionVeryOld))
if err != nil { if err != nil {
t.Fatalf("TestImmediateNotification error: %s", err) t.Fatalf("TestImmediateNotification error: %s", err)
} }
if pos != streamPositionBefore { if pos != syncPositionBefore {
t.Fatalf("TestImmediateNotification want %d, got %d", streamPositionBefore, pos) t.Fatalf("TestImmediateNotification want %d, got %d", syncPositionBefore, pos)
} }
} }
// Test that new events to a joined room unblocks the request. // Test that new events to a joined room unblocks the request.
func TestNewEventAndJoinedToRoom(t *testing.T) { func TestNewEventAndJoinedToRoom(t *testing.T) {
n := NewNotifier(streamPositionBefore) n := NewNotifier(syncPositionBefore)
n.setUsersJoinedToRooms(map[string][]string{ n.setUsersJoinedToRooms(map[string][]string{
roomID: {alice, bob}, roomID: {alice, bob},
}) })
@ -112,27 +133,27 @@ func TestNewEventAndJoinedToRoom(t *testing.T) {
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(1) wg.Add(1)
go func() { go func() {
pos, err := waitForEvents(n, newTestSyncRequest(bob, streamPositionBefore)) pos, err := waitForEvents(n, newTestSyncRequest(bob, syncPositionBefore))
if err != nil { if err != nil {
t.Errorf("TestNewEventAndJoinedToRoom error: %s", err) t.Errorf("TestNewEventAndJoinedToRoom error: %s", err)
} }
if pos != streamPositionAfter { if pos != syncPositionAfter {
t.Errorf("TestNewEventAndJoinedToRoom want %d, got %d", streamPositionAfter, pos) t.Errorf("TestNewEventAndJoinedToRoom want %d, got %d", syncPositionAfter, pos)
} }
wg.Done() wg.Done()
}() }()
stream := n.fetchUserStream(bob, true) stream := lockedFetchUserStream(n, bob)
waitForBlocking(stream, 1) waitForBlocking(stream, 1)
n.OnNewEvent(&randomMessageEvent, "", streamPositionAfter) n.OnNewEvent(&randomMessageEvent, "", nil, syncPositionAfter)
wg.Wait() wg.Wait()
} }
// Test that an invite unblocks the request // Test that an invite unblocks the request
func TestNewInviteEventForUser(t *testing.T) { func TestNewInviteEventForUser(t *testing.T) {
n := NewNotifier(streamPositionBefore) n := NewNotifier(syncPositionBefore)
n.setUsersJoinedToRooms(map[string][]string{ n.setUsersJoinedToRooms(map[string][]string{
roomID: {alice, bob}, roomID: {alice, bob},
}) })
@ -140,27 +161,55 @@ func TestNewInviteEventForUser(t *testing.T) {
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(1) wg.Add(1)
go func() { go func() {
pos, err := waitForEvents(n, newTestSyncRequest(bob, streamPositionBefore)) pos, err := waitForEvents(n, newTestSyncRequest(bob, syncPositionBefore))
if err != nil { if err != nil {
t.Errorf("TestNewInviteEventForUser error: %s", err) t.Errorf("TestNewInviteEventForUser error: %s", err)
} }
if pos != streamPositionAfter { if pos != syncPositionAfter {
t.Errorf("TestNewInviteEventForUser want %d, got %d", streamPositionAfter, pos) t.Errorf("TestNewInviteEventForUser want %d, got %d", syncPositionAfter, pos)
} }
wg.Done() wg.Done()
}() }()
stream := n.fetchUserStream(bob, true) stream := lockedFetchUserStream(n, bob)
waitForBlocking(stream, 1) waitForBlocking(stream, 1)
n.OnNewEvent(&aliceInviteBobEvent, "", streamPositionAfter) n.OnNewEvent(&aliceInviteBobEvent, "", nil, syncPositionAfter)
wg.Wait()
}
// Test an EDU-only update wakes up the request.
func TestEDUWakeup(t *testing.T) {
n := NewNotifier(syncPositionAfter)
n.setUsersJoinedToRooms(map[string][]string{
roomID: {alice, bob},
})
var wg sync.WaitGroup
wg.Add(1)
go func() {
pos, err := waitForEvents(n, newTestSyncRequest(bob, syncPositionAfter))
if err != nil {
t.Errorf("TestNewInviteEventForUser error: %s", err)
}
if pos != syncPositionNewEDU {
t.Errorf("TestNewInviteEventForUser want %d, got %d", syncPositionNewEDU, pos)
}
wg.Done()
}()
stream := lockedFetchUserStream(n, bob)
waitForBlocking(stream, 1)
n.OnNewEvent(&aliceInviteBobEvent, "", nil, syncPositionNewEDU)
wg.Wait() wg.Wait()
} }
// Test that all blocked requests get woken up on a new event. // Test that all blocked requests get woken up on a new event.
func TestMultipleRequestWakeup(t *testing.T) { func TestMultipleRequestWakeup(t *testing.T) {
n := NewNotifier(streamPositionBefore) n := NewNotifier(syncPositionBefore)
n.setUsersJoinedToRooms(map[string][]string{ n.setUsersJoinedToRooms(map[string][]string{
roomID: {alice, bob}, roomID: {alice, bob},
}) })
@ -168,12 +217,12 @@ func TestMultipleRequestWakeup(t *testing.T) {
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(3) wg.Add(3)
poll := func() { poll := func() {
pos, err := waitForEvents(n, newTestSyncRequest(bob, streamPositionBefore)) pos, err := waitForEvents(n, newTestSyncRequest(bob, syncPositionBefore))
if err != nil { if err != nil {
t.Errorf("TestMultipleRequestWakeup error: %s", err) t.Errorf("TestMultipleRequestWakeup error: %s", err)
} }
if pos != streamPositionAfter { if pos != syncPositionAfter {
t.Errorf("TestMultipleRequestWakeup want %d, got %d", streamPositionAfter, pos) t.Errorf("TestMultipleRequestWakeup want %d, got %d", syncPositionAfter, pos)
} }
wg.Done() wg.Done()
} }
@ -181,10 +230,10 @@ func TestMultipleRequestWakeup(t *testing.T) {
go poll() go poll()
go poll() go poll()
stream := n.fetchUserStream(bob, true) stream := lockedFetchUserStream(n, bob)
waitForBlocking(stream, 3) waitForBlocking(stream, 3)
n.OnNewEvent(&randomMessageEvent, "", streamPositionAfter) n.OnNewEvent(&randomMessageEvent, "", nil, syncPositionAfter)
wg.Wait() wg.Wait()
@ -198,7 +247,7 @@ func TestMultipleRequestWakeup(t *testing.T) {
func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) { func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) {
// listen as bob. Make bob leave room. Make alice send event to room. // listen as bob. Make bob leave room. Make alice send event to room.
// Make sure alice gets woken up only and not bob as well. // Make sure alice gets woken up only and not bob as well.
n := NewNotifier(streamPositionBefore) n := NewNotifier(syncPositionBefore)
n.setUsersJoinedToRooms(map[string][]string{ n.setUsersJoinedToRooms(map[string][]string{
roomID: {alice, bob}, roomID: {alice, bob},
}) })
@ -208,38 +257,38 @@ func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) {
// Make bob leave the room // Make bob leave the room
leaveWG.Add(1) leaveWG.Add(1)
go func() { go func() {
pos, err := waitForEvents(n, newTestSyncRequest(bob, streamPositionBefore)) pos, err := waitForEvents(n, newTestSyncRequest(bob, syncPositionBefore))
if err != nil { if err != nil {
t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom error: %s", err) t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom error: %s", err)
} }
if pos != streamPositionAfter { if pos != syncPositionAfter {
t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom want %d, got %d", streamPositionAfter, pos) t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom want %d, got %d", syncPositionAfter, pos)
} }
leaveWG.Done() leaveWG.Done()
}() }()
bobStream := n.fetchUserStream(bob, true) bobStream := lockedFetchUserStream(n, bob)
waitForBlocking(bobStream, 1) waitForBlocking(bobStream, 1)
n.OnNewEvent(&bobLeaveEvent, "", streamPositionAfter) n.OnNewEvent(&bobLeaveEvent, "", nil, syncPositionAfter)
leaveWG.Wait() leaveWG.Wait()
// send an event into the room. Make sure alice gets it. Bob should not. // send an event into the room. Make sure alice gets it. Bob should not.
var aliceWG sync.WaitGroup var aliceWG sync.WaitGroup
aliceStream := n.fetchUserStream(alice, true) aliceStream := lockedFetchUserStream(n, alice)
aliceWG.Add(1) aliceWG.Add(1)
go func() { go func() {
pos, err := waitForEvents(n, newTestSyncRequest(alice, streamPositionAfter)) pos, err := waitForEvents(n, newTestSyncRequest(alice, syncPositionAfter))
if err != nil { if err != nil {
t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom error: %s", err) t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom error: %s", err)
} }
if pos != streamPositionAfter2 { if pos != syncPositionAfter2 {
t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom want %d, got %d", streamPositionAfter2, pos) t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom want %d, got %d", syncPositionAfter2, pos)
} }
aliceWG.Done() aliceWG.Done()
}() }()
go func() { go func() {
// this should timeout with an error (but the main goroutine won't wait for the timeout explicitly) // this should timeout with an error (but the main goroutine won't wait for the timeout explicitly)
_, err := waitForEvents(n, newTestSyncRequest(bob, streamPositionAfter)) _, err := waitForEvents(n, newTestSyncRequest(bob, syncPositionAfter))
if err == nil { if err == nil {
t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom expect error but got nil") t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom expect error but got nil")
} }
@ -248,7 +297,7 @@ func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) {
waitForBlocking(aliceStream, 1) waitForBlocking(aliceStream, 1)
waitForBlocking(bobStream, 1) waitForBlocking(bobStream, 1)
n.OnNewEvent(&randomMessageEvent, "", streamPositionAfter2) n.OnNewEvent(&randomMessageEvent, "", nil, syncPositionAfter2)
aliceWG.Wait() aliceWG.Wait()
// it's possible that at this point alice has been informed and bob is about to be informed, so wait // it's possible that at this point alice has been informed and bob is about to be informed, so wait
@ -256,18 +305,17 @@ func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) {
time.Sleep(1 * time.Millisecond) time.Sleep(1 * time.Millisecond)
} }
// same as Notifier.WaitForEvents but with a timeout. func waitForEvents(n *Notifier, req syncRequest) (types.SyncPosition, error) {
func waitForEvents(n *Notifier, req syncRequest) (types.StreamPosition, error) {
listener := n.GetListener(req) listener := n.GetListener(req)
defer listener.Close() defer listener.Close()
select { select {
case <-time.After(5 * time.Second): case <-time.After(5 * time.Second):
return types.StreamPosition(0), fmt.Errorf( return types.SyncPosition{}, fmt.Errorf(
"waitForEvents timed out waiting for %s (pos=%d)", req.device.UserID, req.since, "waitForEvents timed out waiting for %s (pos=%d)", req.device.UserID, req.since,
) )
case <-listener.GetNotifyChannel(*req.since): case <-listener.GetNotifyChannel(*req.since):
p := listener.GetStreamPosition() p := listener.GetSyncPosition()
return p, nil return p, nil
} }
} }
@ -280,7 +328,16 @@ func waitForBlocking(s *UserStream, numBlocking uint) {
} }
} }
func newTestSyncRequest(userID string, since types.StreamPosition) syncRequest { // lockedFetchUserStream invokes Notifier.fetchUserStream, respecting Notifier.streamLock.
// A new stream is made if it doesn't exist already.
func lockedFetchUserStream(n *Notifier, userID string) *UserStream {
n.streamLock.Lock()
defer n.streamLock.Unlock()
return n.fetchUserStream(userID, true)
}
func newTestSyncRequest(userID string, since types.SyncPosition) syncRequest {
return syncRequest{ return syncRequest{
device: authtypes.Device{UserID: userID}, device: authtypes.Device{UserID: userID},
timeout: 1 * time.Minute, timeout: 1 * time.Minute,

View file

@ -16,8 +16,10 @@ package sync
import ( import (
"context" "context"
"errors"
"net/http" "net/http"
"strconv" "strconv"
"strings"
"time" "time"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
@ -36,7 +38,7 @@ type syncRequest struct {
device authtypes.Device device authtypes.Device
limit int limit int
timeout time.Duration timeout time.Duration
since *types.StreamPosition // nil means that no since token was supplied since *types.SyncPosition // nil means that no since token was supplied
wantFullState bool wantFullState bool
log *log.Entry log *log.Entry
} }
@ -73,15 +75,41 @@ func getTimeout(timeoutMS string) time.Duration {
} }
// getSyncStreamPosition tries to parse a 'since' token taken from the API to a // getSyncStreamPosition tries to parse a 'since' token taken from the API to a
// stream position. If the string is empty then (nil, nil) is returned. // types.SyncPosition. If the string is empty then (nil, nil) is returned.
func getSyncStreamPosition(since string) (*types.StreamPosition, error) { // There are two forms of tokens: The full length form containing all PDU and EDU
// positions separated by "_", and the short form containing only the PDU
// position. Short form can be used for, e.g., `prev_batch` tokens.
func getSyncStreamPosition(since string) (*types.SyncPosition, error) {
if since == "" { if since == "" {
return nil, nil return nil, nil
} }
i, err := strconv.Atoi(since)
posStrings := strings.Split(since, "_")
if len(posStrings) != 2 && len(posStrings) != 1 {
// A token can either be full length or short (PDU-only).
return nil, errors.New("malformed batch token")
}
positions := make([]int64, len(posStrings))
for i, posString := range posStrings {
pos, err := strconv.ParseInt(posString, 10, 64)
if err != nil { if err != nil {
return nil, err return nil, err
} }
token := types.StreamPosition(i) positions[i] = pos
return &token, nil }
if len(positions) == 2 {
// Full length token; construct SyncPosition with every entry in
// `positions`. These entries must have the same order with the fields
// in struct SyncPosition, so we disable the govet check below.
return &types.SyncPosition{ //nolint:govet
positions[0], positions[1],
}, nil
} else {
// Token with PDU position only
return &types.SyncPosition{
PDUPosition: positions[0],
}, nil
}
} }

View file

@ -31,13 +31,13 @@ import (
// RequestPool manages HTTP long-poll connections for /sync // RequestPool manages HTTP long-poll connections for /sync
type RequestPool struct { type RequestPool struct {
db *storage.SyncServerDatabase db *storage.SyncServerDatasource
accountDB *accounts.Database accountDB *accounts.Database
notifier *Notifier notifier *Notifier
} }
// NewRequestPool makes a new RequestPool // NewRequestPool makes a new RequestPool
func NewRequestPool(db *storage.SyncServerDatabase, n *Notifier, adb *accounts.Database) *RequestPool { func NewRequestPool(db *storage.SyncServerDatasource, n *Notifier, adb *accounts.Database) *RequestPool {
return &RequestPool{db, adb, n} return &RequestPool{db, adb, n}
} }
@ -65,8 +65,7 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype
currPos := rp.notifier.CurrentPosition() currPos := rp.notifier.CurrentPosition()
// If this is an initial sync or timeout=0 we return immediately if shouldReturnImmediately(syncReq) {
if syncReq.since == nil || syncReq.timeout == 0 {
syncData, err = rp.currentSyncForUser(*syncReq, currPos) syncData, err = rp.currentSyncForUser(*syncReq, currPos)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
@ -92,11 +91,13 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype
// respond with, so we skip the return an go back to waiting for content to // respond with, so we skip the return an go back to waiting for content to
// be sent down or the request timing out. // be sent down or the request timing out.
var hasTimedOut bool var hasTimedOut bool
sincePos := *syncReq.since
for { for {
select { select {
// Wait for notifier to wake us up // Wait for notifier to wake us up
case <-userStreamListener.GetNotifyChannel(currPos): case <-userStreamListener.GetNotifyChannel(sincePos):
currPos = userStreamListener.GetStreamPosition() currPos = userStreamListener.GetSyncPosition()
sincePos = currPos
// Or for timeout to expire // Or for timeout to expire
case <-timer.C: case <-timer.C:
// We just need to ensure we get out of the select after reaching the // We just need to ensure we get out of the select after reaching the
@ -128,24 +129,26 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype
} }
} }
func (rp *RequestPool) currentSyncForUser(req syncRequest, currentPos types.StreamPosition) (res *types.Response, err error) { func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.SyncPosition) (res *types.Response, err error) {
// TODO: handle ignored users // TODO: handle ignored users
if req.since == nil { if req.since == nil {
res, err = rp.db.CompleteSync(req.ctx, req.device.UserID, req.limit) res, err = rp.db.CompleteSync(req.ctx, req.device.UserID, req.limit)
} else { } else {
res, err = rp.db.IncrementalSync(req.ctx, req.device, *req.since, currentPos, req.limit) res, err = rp.db.IncrementalSync(req.ctx, req.device, *req.since, latestPos, req.limit, req.wantFullState)
} }
if err != nil { if err != nil {
return return
} }
res, err = rp.appendAccountData(res, req.device.UserID, req, currentPos) accountDataFilter := gomatrixserverlib.DefaultFilterPart() // TODO: use filter provided in req instead
res, err = rp.appendAccountData(res, req.device.UserID, req, latestPos.PDUPosition, &accountDataFilter)
return return
} }
func (rp *RequestPool) appendAccountData( func (rp *RequestPool) appendAccountData(
data *types.Response, userID string, req syncRequest, currentPos types.StreamPosition, data *types.Response, userID string, req syncRequest, currentPos int64,
accountDataFilter *gomatrixserverlib.FilterPart,
) (*types.Response, error) { ) (*types.Response, error) {
// TODO: Account data doesn't have a sync position of its own, meaning that // TODO: Account data doesn't have a sync position of its own, meaning that
// account data might be sent multiple time to the client if multiple account // account data might be sent multiple time to the client if multiple account
@ -179,7 +182,7 @@ func (rp *RequestPool) appendAccountData(
} }
// Sync is not initial, get all account data since the latest sync // Sync is not initial, get all account data since the latest sync
dataTypes, err := rp.db.GetAccountDataInRange(req.ctx, userID, *req.since, currentPos) dataTypes, err := rp.db.GetAccountDataInRange(req.ctx, userID, req.since.PDUPosition, currentPos, accountDataFilter)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -214,3 +217,10 @@ func (rp *RequestPool) appendAccountData(
return data, nil return data, nil
} }
// shouldReturnImmediately returns whether the /sync request is an initial sync,
// or timeout=0, or full_state=true, in any of the cases the request should
// return immediately.
func shouldReturnImmediately(syncReq *syncRequest) bool {
return syncReq.since == nil || syncReq.timeout == 0 || syncReq.wantFullState
}

View file

@ -34,8 +34,8 @@ type UserStream struct {
lock sync.Mutex lock sync.Mutex
// Closed when there is an update. // Closed when there is an update.
signalChannel chan struct{} signalChannel chan struct{}
// The last stream position that there may have been an update for the suser // The last sync position that there may have been an update for the user
pos types.StreamPosition pos types.SyncPosition
// The last time when we had some listeners waiting // The last time when we had some listeners waiting
timeOfLastChannel time.Time timeOfLastChannel time.Time
// The number of listeners waiting // The number of listeners waiting
@ -51,7 +51,7 @@ type UserStreamListener struct {
} }
// NewUserStream creates a new user stream // NewUserStream creates a new user stream
func NewUserStream(userID string, currPos types.StreamPosition) *UserStream { func NewUserStream(userID string, currPos types.SyncPosition) *UserStream {
return &UserStream{ return &UserStream{
UserID: userID, UserID: userID,
timeOfLastChannel: time.Now(), timeOfLastChannel: time.Now(),
@ -84,8 +84,8 @@ func (s *UserStream) GetListener(ctx context.Context) UserStreamListener {
return listener return listener
} }
// Broadcast a new stream position for this user. // Broadcast a new sync position for this user.
func (s *UserStream) Broadcast(pos types.StreamPosition) { func (s *UserStream) Broadcast(pos types.SyncPosition) {
s.lock.Lock() s.lock.Lock()
defer s.lock.Unlock() defer s.lock.Unlock()
@ -118,9 +118,9 @@ func (s *UserStream) TimeOfLastNonEmpty() time.Time {
return s.timeOfLastChannel return s.timeOfLastChannel
} }
// GetStreamPosition returns last stream position which the UserStream was // GetStreamPosition returns last sync position which the UserStream was
// notified about // notified about
func (s *UserStreamListener) GetStreamPosition() types.StreamPosition { func (s *UserStreamListener) GetSyncPosition() types.SyncPosition {
s.userStream.lock.Lock() s.userStream.lock.Lock()
defer s.userStream.lock.Unlock() defer s.userStream.lock.Unlock()
@ -132,11 +132,11 @@ func (s *UserStreamListener) GetStreamPosition() types.StreamPosition {
// sincePos specifies from which point we want to be notified about. If there // sincePos specifies from which point we want to be notified about. If there
// has already been an update after sincePos we'll return a closed channel // has already been an update after sincePos we'll return a closed channel
// immediately. // immediately.
func (s *UserStreamListener) GetNotifyChannel(sincePos types.StreamPosition) <-chan struct{} { func (s *UserStreamListener) GetNotifyChannel(sincePos types.SyncPosition) <-chan struct{} {
s.userStream.lock.Lock() s.userStream.lock.Lock()
defer s.userStream.lock.Unlock() defer s.userStream.lock.Unlock()
if sincePos < s.userStream.pos { if s.userStream.pos.IsAfter(sincePos) {
// If the listener is behind, i.e. missed a potential update, then we // If the listener is behind, i.e. missed a potential update, then we
// want them to wake up immediately. We do this by returning a new // want them to wake up immediately. We do this by returning a new
// closed stream, which returns immediately when selected. // closed stream, which returns immediately when selected.

View file

@ -28,7 +28,6 @@ import (
"github.com/matrix-org/dendrite/syncapi/routing" "github.com/matrix-org/dendrite/syncapi/routing"
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/sync" "github.com/matrix-org/dendrite/syncapi/sync"
"github.com/matrix-org/dendrite/syncapi/types"
) )
// SetupSyncAPIComponent sets up and registers HTTP handlers for the SyncAPI // SetupSyncAPIComponent sets up and registers HTTP handlers for the SyncAPI
@ -39,17 +38,17 @@ func SetupSyncAPIComponent(
accountsDB *accounts.Database, accountsDB *accounts.Database,
queryAPI api.RoomserverQueryAPI, queryAPI api.RoomserverQueryAPI,
) { ) {
syncDB, err := storage.NewSyncServerDatabase(string(base.Cfg.Database.SyncAPI)) syncDB, err := storage.NewSyncServerDatasource(string(base.Cfg.Database.SyncAPI))
if err != nil { if err != nil {
logrus.WithError(err).Panicf("failed to connect to sync db") logrus.WithError(err).Panicf("failed to connect to sync db")
} }
pos, err := syncDB.SyncStreamPosition(context.Background()) pos, err := syncDB.SyncPosition(context.Background())
if err != nil { if err != nil {
logrus.WithError(err).Panicf("failed to get stream position") logrus.WithError(err).Panicf("failed to get sync position")
} }
notifier := sync.NewNotifier(types.StreamPosition(pos)) notifier := sync.NewNotifier(pos)
err = notifier.Load(context.Background(), syncDB) err = notifier.Load(context.Background(), syncDB)
if err != nil { if err != nil {
logrus.WithError(err).Panicf("failed to start notifier") logrus.WithError(err).Panicf("failed to start notifier")
@ -71,5 +70,12 @@ func SetupSyncAPIComponent(
logrus.WithError(err).Panicf("failed to start client data consumer") logrus.WithError(err).Panicf("failed to start client data consumer")
} }
typingConsumer := consumers.NewOutputTypingEventConsumer(
base.Cfg, base.KafkaConsumer, notifier, syncDB,
)
if err = typingConsumer.Start(); err != nil {
logrus.WithError(err).Panicf("failed to start typing server consumer")
}
routing.Setup(base.APIMux, requestPool, syncDB, deviceDB) routing.Setup(base.APIMux, requestPool, syncDB, deviceDB)
} }

View file

@ -21,12 +21,38 @@ import (
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
// StreamPosition represents the offset in the sync stream a client is at. // SyncPosition contains the PDU and EDU stream sync positions for a client.
type StreamPosition int64 type SyncPosition struct {
// PDUPosition is the stream position for PDUs the client is at.
PDUPosition int64
// TypingPosition is the client's position for typing notifications.
TypingPosition int64
}
// String implements the Stringer interface. // String implements the Stringer interface.
func (sp StreamPosition) String() string { func (sp SyncPosition) String() string {
return strconv.FormatInt(int64(sp), 10) return strconv.FormatInt(sp.PDUPosition, 10) + "_" +
strconv.FormatInt(sp.TypingPosition, 10)
}
// IsAfter returns whether one SyncPosition refers to states newer than another SyncPosition.
func (sp SyncPosition) IsAfter(other SyncPosition) bool {
return sp.PDUPosition > other.PDUPosition ||
sp.TypingPosition > other.TypingPosition
}
// WithUpdates returns a copy of the SyncPosition with updates applied from another SyncPosition.
// If the latter SyncPosition contains a field that is not 0, it is considered an update,
// and its value will replace the corresponding value in the SyncPosition on which WithUpdates is called.
func (sp SyncPosition) WithUpdates(other SyncPosition) SyncPosition {
ret := sp
if other.PDUPosition != 0 {
ret.PDUPosition = other.PDUPosition
}
if other.TypingPosition != 0 {
ret.TypingPosition = other.TypingPosition
}
return ret
} }
// PrevEventRef represents a reference to a previous event in a state event upgrade // PrevEventRef represents a reference to a previous event in a state event upgrade
@ -53,11 +79,10 @@ type Response struct {
} }
// NewResponse creates an empty response with initialised maps. // NewResponse creates an empty response with initialised maps.
func NewResponse(pos StreamPosition) *Response { func NewResponse(pos SyncPosition) *Response {
res := Response{} res := Response{
// Make sure we send the next_batch as a string. We don't want to confuse clients by sending this NextBatch: pos.String(),
// as an integer even though (at the moment) it is. }
res.NextBatch = pos.String()
// Pre-initialise the maps. Synapse will return {} even if there are no rooms under a specific section, // Pre-initialise the maps. Synapse will return {} even if there are no rooms under a specific section,
// so let's do the same thing. Bonus: this means we can't get dreaded 'assignment to entry in nil map' errors. // so let's do the same thing. Bonus: this means we can't get dreaded 'assignment to entry in nil map' errors.
res.Rooms.Join = make(map[string]JoinResponse) res.Rooms.Join = make(map[string]JoinResponse)

View file

@ -42,6 +42,7 @@ POST /join/:room_alias can join a room
POST /join/:room_id can join a room POST /join/:room_id can join a room
POST /join/:room_id can join a room with custom content POST /join/:room_id can join a room with custom content
POST /join/:room_alias can join a room with custom content POST /join/:room_alias can join a room with custom content
POST /rooms/:room_id/join can join a room
POST /rooms/:room_id/leave can leave a room POST /rooms/:room_id/leave can leave a room
POST /rooms/:room_id/invite can send an invite POST /rooms/:room_id/invite can send an invite
POST /rooms/:room_id/ban can ban a user POST /rooms/:room_id/ban can ban a user
@ -142,4 +143,32 @@ Trying to get push rules with unknown rule_id fails with 404
Events come down the correct room Events come down the correct room
local user can join room with version 5 local user can join room with version 5
User can invite local user to room with version 5 User can invite local user to room with version 5
Inbound federation can receive room-join requests Inbound federation can receive v1 room-join requests
Typing events appear in initial sync
Typing events appear in incremental sync
Typing events appear in gapped sync
Inbound federation of state requires event_id as a mandatory paramater
Inbound federation of state_ids requires event_id as a mandatory paramater
POST /register returns the same device_id as that in the request
POST /login returns the same device_id as that in the request
POST /createRoom with creation content
User can create and send/receive messages in a room with version 1
POST /createRoom ignores attempts to set the room version via creation_content
Inbound federation rejects remote attempts to join local users to rooms
Inbound federation rejects remote attempts to kick local users to rooms
An event which redacts itself should be ignored
A pair of events which redact each other should be ignored
Full state sync includes joined rooms
A message sent after an initial sync appears in the timeline of an incremental sync.
Can add tag
Can remove tag
Can list tags for a room
Tags appear in an initial v2 /sync
Newly updated tags appear in an incremental v2 /sync
Deleted tags appear in an incremental v2 /sync
/event/ on non world readable room does not work
Outbound federation can query profile data
/event/ on joined room works
/event/ does not allow access to events before the user joined
Federation key API allows unsigned requests for keys
Can paginate public room list

View file

@ -12,14 +12,17 @@
package api package api
import "time"
// OutputTypingEvent is an entry in typing server output kafka log. // OutputTypingEvent is an entry in typing server output kafka log.
// This contains the event with extra fields used to create 'm.typing' event // This contains the event with extra fields used to create 'm.typing' event
// in clientapi & federation. // in clientapi & federation.
type OutputTypingEvent struct { type OutputTypingEvent struct {
// The Event for the typing edu event. // The Event for the typing edu event.
Event TypingEvent `json:"event"` Event TypingEvent `json:"event"`
// Users typing in the room when the event was generated. // ExpireTime is the interval after which the user should no longer be
TypingUsers []string `json:"typing_users"` // considered typing. Only available if Event.Typing is true.
ExpireTime *time.Time
} }
// TypingEvent represents a matrix edu event of type 'm.typing'. // TypingEvent represents a matrix edu event of type 'm.typing'.

View file

@ -22,25 +22,66 @@ const defaultTypingTimeout = 10 * time.Second
// userSet is a map of user IDs to a timer, timer fires at expiry. // userSet is a map of user IDs to a timer, timer fires at expiry.
type userSet map[string]*time.Timer type userSet map[string]*time.Timer
// TimeoutCallbackFn is a function called right after the removal of a user
// from the typing user list due to timeout.
// latestSyncPosition is the typing sync position after the removal.
type TimeoutCallbackFn func(userID, roomID string, latestSyncPosition int64)
type roomData struct {
syncPosition int64
userSet userSet
}
// TypingCache maintains a list of users typing in each room. // TypingCache maintains a list of users typing in each room.
type TypingCache struct { type TypingCache struct {
sync.RWMutex sync.RWMutex
data map[string]userSet latestSyncPosition int64
data map[string]*roomData
timeoutCallback TimeoutCallbackFn
}
// Create a roomData with its sync position set to the latest sync position.
// Must only be called after locking the cache.
func (t *TypingCache) newRoomData() *roomData {
return &roomData{
syncPosition: t.latestSyncPosition,
userSet: make(userSet),
}
} }
// NewTypingCache returns a new TypingCache initialised for use. // NewTypingCache returns a new TypingCache initialised for use.
func NewTypingCache() *TypingCache { func NewTypingCache() *TypingCache {
return &TypingCache{data: make(map[string]userSet)} return &TypingCache{data: make(map[string]*roomData)}
}
// SetTimeoutCallback sets a callback function that is called right after
// a user is removed from the typing user list due to timeout.
func (t *TypingCache) SetTimeoutCallback(fn TimeoutCallbackFn) {
t.timeoutCallback = fn
} }
// GetTypingUsers returns the list of users typing in a room. // GetTypingUsers returns the list of users typing in a room.
func (t *TypingCache) GetTypingUsers(roomID string) (users []string) { func (t *TypingCache) GetTypingUsers(roomID string) []string {
users, _ := t.GetTypingUsersIfUpdatedAfter(roomID, 0)
// 0 should work above because the first position used will be 1.
return users
}
// GetTypingUsersIfUpdatedAfter returns all users typing in this room with
// updated == true if the typing sync position of the room is after the given
// position. Otherwise, returns an empty slice with updated == false.
func (t *TypingCache) GetTypingUsersIfUpdatedAfter(
roomID string, position int64,
) (users []string, updated bool) {
t.RLock() t.RLock()
usersMap, ok := t.data[roomID] defer t.RUnlock()
t.RUnlock()
if ok { roomData, ok := t.data[roomID]
users = make([]string, 0, len(usersMap)) if ok && roomData.syncPosition > position {
for userID := range usersMap { updated = true
userSet := roomData.userSet
users = make([]string, 0, len(userSet))
for userID := range userSet {
users = append(users, userID) users = append(users, userID)
} }
} }
@ -51,53 +92,84 @@ func (t *TypingCache) GetTypingUsers(roomID string) (users []string) {
// AddTypingUser sets an user as typing in a room. // AddTypingUser sets an user as typing in a room.
// expire is the time when the user typing should time out. // expire is the time when the user typing should time out.
// if expire is nil, defaultTypingTimeout is assumed. // if expire is nil, defaultTypingTimeout is assumed.
func (t *TypingCache) AddTypingUser(userID, roomID string, expire *time.Time) { // Returns the latest sync position for typing after update.
func (t *TypingCache) AddTypingUser(
userID, roomID string, expire *time.Time,
) int64 {
expireTime := getExpireTime(expire) expireTime := getExpireTime(expire)
if until := time.Until(expireTime); until > 0 { if until := time.Until(expireTime); until > 0 {
timer := time.AfterFunc(until, t.timeoutCallback(userID, roomID)) timer := time.AfterFunc(until, func() {
t.addUser(userID, roomID, timer) latestSyncPosition := t.RemoveUser(userID, roomID)
if t.timeoutCallback != nil {
t.timeoutCallback(userID, roomID, latestSyncPosition)
} }
})
return t.addUser(userID, roomID, timer)
}
return t.GetLatestSyncPosition()
} }
// addUser with mutex lock & replace the previous timer. // addUser with mutex lock & replace the previous timer.
func (t *TypingCache) addUser(userID, roomID string, expiryTimer *time.Timer) { // Returns the latest typing sync position after update.
func (t *TypingCache) addUser(
userID, roomID string, expiryTimer *time.Timer,
) int64 {
t.Lock() t.Lock()
defer t.Unlock() defer t.Unlock()
t.latestSyncPosition++
if t.data[roomID] == nil { if t.data[roomID] == nil {
t.data[roomID] = make(userSet) t.data[roomID] = t.newRoomData()
} else {
t.data[roomID].syncPosition = t.latestSyncPosition
} }
// Stop the timer to cancel the call to timeoutCallback // Stop the timer to cancel the call to timeoutCallback
if timer, ok := t.data[roomID][userID]; ok { if timer, ok := t.data[roomID].userSet[userID]; ok {
// It may happen that at this stage timer fires but now we have a lock on t. // It may happen that at this stage the timer fires, but we now have a lock on
// Hence the execution of timeoutCallback will happen after we unlock. // it. Hence the execution of timeoutCallback will happen after we unlock. So
// So we may lose a typing state, though this event is highly unlikely. // we may lose a typing state, though this is highly unlikely. This can be
// This can be mitigated by keeping another time.Time in the map and check against it // mitigated by keeping another time.Time in the map and checking against it
// before removing. This however is not required in most practical scenario. // before removing, but its occurrence is so infrequent it does not seem
// worthwhile.
timer.Stop() timer.Stop()
} }
t.data[roomID][userID] = expiryTimer t.data[roomID].userSet[userID] = expiryTimer
}
// Returns a function which is called after timeout happens. return t.latestSyncPosition
// This removes the user.
func (t *TypingCache) timeoutCallback(userID, roomID string) func() {
return func() {
t.RemoveUser(userID, roomID)
}
} }
// RemoveUser with mutex lock & stop the timer. // RemoveUser with mutex lock & stop the timer.
func (t *TypingCache) RemoveUser(userID, roomID string) { // Returns the latest sync position for typing after update.
func (t *TypingCache) RemoveUser(userID, roomID string) int64 {
t.Lock() t.Lock()
defer t.Unlock() defer t.Unlock()
if timer, ok := t.data[roomID][userID]; ok { roomData, ok := t.data[roomID]
timer.Stop() if !ok {
delete(t.data[roomID], userID) return t.latestSyncPosition
} }
timer, ok := roomData.userSet[userID]
if !ok {
return t.latestSyncPosition
}
timer.Stop()
delete(roomData.userSet, userID)
t.latestSyncPosition++
t.data[roomID].syncPosition = t.latestSyncPosition
return t.latestSyncPosition
}
func (t *TypingCache) GetLatestSyncPosition() int64 {
t.Lock()
defer t.Unlock()
return t.latestSyncPosition
} }
func getExpireTime(expire *time.Time) time.Time { func getExpireTime(expire *time.Time) time.Time {

View file

@ -57,15 +57,21 @@ func (t *TypingServerInputAPI) InputTypingEvent(
} }
func (t *TypingServerInputAPI) sendEvent(ite *api.InputTypingEvent) error { func (t *TypingServerInputAPI) sendEvent(ite *api.InputTypingEvent) error {
userIDs := t.Cache.GetTypingUsers(ite.RoomID)
ev := &api.TypingEvent{ ev := &api.TypingEvent{
Type: gomatrixserverlib.MTyping, Type: gomatrixserverlib.MTyping,
RoomID: ite.RoomID, RoomID: ite.RoomID,
UserID: ite.UserID, UserID: ite.UserID,
Typing: ite.Typing,
} }
ote := &api.OutputTypingEvent{ ote := &api.OutputTypingEvent{
Event: *ev, Event: *ev,
TypingUsers: userIDs, }
if ev.Typing {
expireTime := ite.OriginServerTS.Time().Add(
time.Duration(ite.Timeout) * time.Millisecond,
)
ote.ExpireTime = &expireTime
} }
eventJSON, err := json.Marshal(ote) eventJSON, err := json.Marshal(ote)