Merge branch 'master' into neilalexander/sarama-import

This commit is contained in:
Neil Alexander 2020-04-22 14:19:27 +01:00
commit 29ef6e0679
282 changed files with 19177 additions and 2381 deletions

6
.gitignore vendored
View file

@ -43,3 +43,9 @@ _testmain.go
# Default configuration file # Default configuration file
dendrite.yaml dendrite.yaml
# Database files
*.db
# Log files
*.log*

View file

@ -102,7 +102,7 @@ linters-settings:
#local-prefixes: github.com/org/project #local-prefixes: github.com/org/project
gocyclo: gocyclo:
# minimal code complexity to report, 30 by default (but we recommend 10-20) # minimal code complexity to report, 30 by default (but we recommend 10-20)
min-complexity: 12 min-complexity: 13
maligned: maligned:
# print struct with more effective memory layout or not, false by default # print struct with more effective memory layout or not, false by default
suggest-new: true suggest-new: true

View file

@ -20,34 +20,40 @@ should pick up any unit test and run it). There are also [scripts](scripts) for
[linting](scripts/find-lint.sh) and doing a [build/test/lint [linting](scripts/find-lint.sh) and doing a [build/test/lint
run](scripts/build-test-lint.sh). run](scripts/build-test-lint.sh).
As of February 2020, we are deprecating support for Go 1.11 and Go 1.12 and are
now targeting Go 1.13 or later. Please ensure that you are using at least Go
1.13 when developing for Dendrite - our CI will lint and run tests against this
version.
## Continuous Integration ## Continuous Integration
When a Pull Request is submitted, continuous integration jobs are run When a Pull Request is submitted, continuous integration jobs are run
automatically to ensure the code builds and is relatively well-written. The automatically to ensure the code builds and is relatively well-written. The jobs
jobs are run on [Buildkite](https://buildkite.com/matrix-dot-org/dendrite/), are run on [Buildkite](https://buildkite.com/matrix-dot-org/dendrite/), and the
and the Buildkite pipeline configuration can be found in Matrix.org's Buildkite pipeline configuration can be found in Matrix.org's [pipelines
[pipelines repository](https://github.com/matrix-org/pipelines). repository](https://github.com/matrix-org/pipelines).
If a job fails, click the "details" button and you should be taken to the job's If a job fails, click the "details" button and you should be taken to the job's
logs. logs.
![Click the details button on the failing build step](docs/images/details-button-location.jpg) ![Click the details button on the failing build
step](docs/images/details-button-location.jpg)
Scroll down to the failing step and you should see some log output. Scan Scroll down to the failing step and you should see some log output. Scan the
the logs until you find what it's complaining about, fix it, submit a new logs until you find what it's complaining about, fix it, submit a new commit,
commit, then rinse and repeat until CI passes. then rinse and repeat until CI passes.
### Running CI Tests Locally ### Running CI Tests Locally
To save waiting for CI to finish after every commit, it is ideal to run the To save waiting for CI to finish after every commit, it is ideal to run the
checks locally before pushing, fixing errors first. This also saves other checks locally before pushing, fixing errors first. This also saves other people
people time as only so many PRs can be tested at a given time. time as only so many PRs can be tested at a given time.
To execute what Buildkite tests, first run `./scripts/build-test-lint.sh`; To execute what Buildkite tests, first run `./scripts/build-test-lint.sh`; this
this script will build the code, lint it, and run `go test ./...` with race script will build the code, lint it, and run `go test ./...` with race condition
condition checking enabled. If something needs to be changed, fix it and then checking enabled. If something needs to be changed, fix it and then run the
run the script again until it no longer complains. Be warned that the linting script again until it no longer complains. Be warned that the linting can take a
can take a significant amount of CPU and RAM. significant amount of CPU and RAM.
Once the code builds, run [Sytest](https://github.com/matrix-org/sytest) Once the code builds, run [Sytest](https://github.com/matrix-org/sytest)
according to the guide in according to the guide in
@ -61,16 +67,18 @@ tests.
## Picking Things To Do ## Picking Things To Do
If you're new then feel free to pick up an issue labelled [good first issue](https://github.com/matrix-org/dendrite/labels/good%20first%20issue). If you're new then feel free to pick up an issue labelled [good first
issue](https://github.com/matrix-org/dendrite/labels/good%20first%20issue).
These should be well-contained, small pieces of work that can be picked up to These should be well-contained, small pieces of work that can be picked up to
help you get familiar with the code base. help you get familiar with the code base.
Once you're comfortable with hacking on Dendrite there are issues lablled as Once you're comfortable with hacking on Dendrite there are issues lablled as
[help wanted](https://github.com/matrix-org/dendrite/labels/help%20wanted), these [help wanted](https://github.com/matrix-org/dendrite/labels/help%20wanted),
are often slightly larger or more complicated pieces of work but are hopefully these are often slightly larger or more complicated pieces of work but are
nonetheless fairly well-contained. hopefully nonetheless fairly well-contained.
We ask people who are familiar with Dendrite to leave the [good first issue](https://github.com/matrix-org/dendrite/labels/good%20first%20issue) We ask people who are familiar with Dendrite to leave the [good first
issue](https://github.com/matrix-org/dendrite/labels/good%20first%20issue)
issues so that there is always a way for new people to come and get involved. issues so that there is always a way for new people to come and get involved.
## Getting Help ## Getting Help
@ -79,9 +87,11 @@ For questions related to developing on Dendrite we have a dedicated room on
Matrix [#dendrite-dev:matrix.org](https://matrix.to/#/#dendrite-dev:matrix.org) Matrix [#dendrite-dev:matrix.org](https://matrix.to/#/#dendrite-dev:matrix.org)
where we're happy to help. where we're happy to help.
For more general questions please use [#dendrite:matrix.org](https://matrix.to/#/#dendrite:matrix.org). For more general questions please use
[#dendrite:matrix.org](https://matrix.to/#/#dendrite:matrix.org).
## Sign off ## Sign off
We ask that everyone who contributes to the project signs off their We ask that everyone who contributes to the project signs off their
contributions, in accordance with the [DCO](https://github.com/matrix-org/matrix-doc/blob/master/CONTRIBUTING.rst#sign-off). contributions, in accordance with the
[DCO](https://github.com/matrix-org/matrix-doc/blob/master/CONTRIBUTING.rst#sign-off).

View file

@ -12,7 +12,7 @@ Dendrite can be run in one of two configurations:
## Requirements ## Requirements
- Go 1.11+ - Go 1.13+
- Postgres 9.5+ - Postgres 9.5+
- For Kafka (optional if using the monolith server): - For Kafka (optional if using the monolith server):
- Unix-based system (https://kafka.apache.org/documentation/#os) - Unix-based system (https://kafka.apache.org/documentation/#os)
@ -22,7 +22,7 @@ Dendrite can be run in one of two configurations:
## Setting up a development environment ## Setting up a development environment
Assumes Go 1.10+ and JDK 1.8+ are already installed and are on PATH. Assumes Go 1.13+ and JDK 1.8+ are already installed and are on PATH.
```bash ```bash
# Get the code # Get the code
@ -101,7 +101,7 @@ Create config file, based on `dendrite-config.yaml`. Call it `dendrite.yaml`. Th
It is possible to use 'naffka' as an in-process replacement to Kafka when using It is possible to use 'naffka' as an in-process replacement to Kafka when using
the monolith server. To do this, set `use_naffka: true` in `dendrite.yaml` and uncomment the monolith server. To do this, set `use_naffka: true` in `dendrite.yaml` and uncomment
the necessary line related to naffka in the `database` section. Be sure to update the the necessary line related to naffka in the `database` section. Be sure to update the
database username and password if needed. database username and password if needed.
The monolith server can be started as shown below. By default it listens for The monolith server can be started as shown below. By default it listens for
@ -255,7 +255,7 @@ you want to support federation.
./bin/dendrite-federation-sender-server --config dendrite.yaml ./bin/dendrite-federation-sender-server --config dendrite.yaml
``` ```
### Run an appservice server ### Run an appservice server
This sends events from the network to [application This sends events from the network to [application
services](https://matrix.org/docs/spec/application_service/unstable.html) services](https://matrix.org/docs/spec/application_service/unstable.html)

View file

@ -1,27 +1,31 @@
# Dendrite [![Build Status](https://badge.buildkite.com/4be40938ab19f2bbc4a6c6724517353ee3ec1422e279faf374.svg?branch=master)](https://buildkite.com/matrix-dot-org/dendrite) [![Dendrite Dev on Matrix](https://img.shields.io/matrix/dendrite-dev:matrix.org.svg?label=%23dendrite-dev%3Amatrix.org&logo=matrix&server_fqdn=matrix.org)](https://matrix.to/#/#dendrite-dev:matrix.org) [![Dendrite on Matrix](https://img.shields.io/matrix/dendrite:matrix.org.svg?label=%23dendrite%3Amatrix.org&logo=matrix&server_fqdn=matrix.org)](https://matrix.to/#/#dendrite:matrix.org) # Dendrite [![Build Status](https://badge.buildkite.com/4be40938ab19f2bbc4a6c6724517353ee3ec1422e279faf374.svg?branch=master)](https://buildkite.com/matrix-dot-org/dendrite) [![Dendrite Dev on Matrix](https://img.shields.io/matrix/dendrite-dev:matrix.org.svg?label=%23dendrite-dev%3Amatrix.org&logo=matrix&server_fqdn=matrix.org)](https://matrix.to/#/#dendrite-dev:matrix.org) [![Dendrite on Matrix](https://img.shields.io/matrix/dendrite:matrix.org.svg?label=%23dendrite%3Amatrix.org&logo=matrix&server_fqdn=matrix.org)](https://matrix.to/#/#dendrite:matrix.org)
Dendrite will be a matrix homeserver written in go. Dendrite will be a second-generation Matrix homeserver written in Go.
It's still very much a work in progress, but installation instructions can It's still very much a work in progress, but installation instructions can be
be found in [INSTALL.md](INSTALL.md) found in [INSTALL.md](INSTALL.md). It is not recommended to use Dendrite as a
production homeserver at this time.
An overview of the design can be found in [DESIGN.md](DESIGN.md) An overview of the design can be found in [DESIGN.md](DESIGN.md).
# Contributing # Contributing
Everyone is welcome to help out and contribute! See [CONTRIBUTING.md](CONTRIBUTING.md) Everyone is welcome to help out and contribute! See
to get started! [CONTRIBUTING.md](CONTRIBUTING.md) to get started!
We aim to try and make it as easy as possible to jump in. Please note that, as of February 2020, Dendrite now only targets Go 1.13 or
later. Please ensure that you are using at least Go 1.13 when developing for
Dendrite.
# Discussion # Discussion
For questions about Dendrite we have a dedicated room on Matrix For questions about Dendrite we have a dedicated room on Matrix
[#dendrite:matrix.org](https://matrix.to/#/#dendrite:matrix.org). [#dendrite:matrix.org](https://matrix.to/#/#dendrite:matrix.org). Development
Development discussion should happen in discussion should happen in
[#dendrite-dev:matrix.org](https://matrix.to/#/#dendrite-dev:matrix.org). [#dendrite-dev:matrix.org](https://matrix.to/#/#dendrite-dev:matrix.org).
# Progress # Progress
There's plenty still to do to make Dendrite usable! We're tracking progress in There's plenty still to do to make Dendrite usable! We're tracking progress in a
a [project board](https://github.com/matrix-org/dendrite/projects/2). [project board](https://github.com/matrix-org/dendrite/projects/2).

View file

@ -72,7 +72,7 @@ Diagram:
| | | | | | | |
| | +---+ | | | | +---+ | |
| | +----------| S | | | | | +----------| S | | |
| | | Typing +---+ | | | | | EDU +---+ | |
| |>=========================================>| Server |>=====================>| | | |>=========================================>| Server |>=====================>| |
+------------+ | | +----------+ +------------+ | | +----------+
+---+ | | +---+ | |
@ -156,7 +156,7 @@ choke-point to implement ratelimiting and backoff correctly.
* It may be impossible to implement without folding it into the Room Server * It may be impossible to implement without folding it into the Room Server
forever coupling the components together. forever coupling the components together.
## Typing Server ## EDU Server
* Reads new updates to typing from the logs written by the FS and CTS. * Reads new updates to typing from the logs written by the FS and CTS.
* Updates the current list of people typing in a room. * Updates the current list of people typing in a room.
@ -179,7 +179,7 @@ choke-point to implement ratelimiting and backoff correctly.
* Reads new events and the current state of the rooms from logs written by the Room Server. * Reads new events and the current state of the rooms from logs written by the Room Server.
* Reads new receipts positions from the logs written by the Receipts Server. * Reads new receipts positions from the logs written by the Receipts Server.
* Reads changes to presence from the logs written by the Presence Server. * Reads changes to presence from the logs written by the Presence Server.
* Reads changes to typing from the logs written by the Typing Server. * Reads changes to typing from the logs written by the EDU Server.
* Writes when a client starts and stops syncing to the logs. * Writes when a client starts and stops syncing to the logs.
## Client Search ## Client Search

View file

@ -20,6 +20,7 @@ package api
import ( import (
"context" "context"
"database/sql" "database/sql"
"errors"
"net/http" "net/http"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
@ -97,15 +98,15 @@ type httpAppServiceQueryAPI struct {
// NewAppServiceQueryAPIHTTP creates a AppServiceQueryAPI implemented by talking // NewAppServiceQueryAPIHTTP creates a AppServiceQueryAPI implemented by talking
// to a HTTP POST API. // to a HTTP POST API.
// If httpClient is nil then it uses http.DefaultClient // If httpClient is nil an error is returned
func NewAppServiceQueryAPIHTTP( func NewAppServiceQueryAPIHTTP(
appserviceURL string, appserviceURL string,
httpClient *http.Client, httpClient *http.Client,
) AppServiceQueryAPI { ) (AppServiceQueryAPI, error) {
if httpClient == nil { if httpClient == nil {
httpClient = http.DefaultClient return nil, errors.New("NewRoomserverAliasAPIHTTP: httpClient is <nil>")
} }
return &httpAppServiceQueryAPI{appserviceURL, httpClient} return &httpAppServiceQueryAPI{appserviceURL, httpClient}, nil
} }
// RoomAliasExists implements AppServiceQueryAPI // RoomAliasExists implements AppServiceQueryAPI
@ -140,7 +141,7 @@ func RetrieveUserProfile(
ctx context.Context, ctx context.Context,
userID string, userID string,
asAPI AppServiceQueryAPI, asAPI AppServiceQueryAPI,
accountDB *accounts.Database, accountDB accounts.Database,
) (*authtypes.Profile, error) { ) (*authtypes.Profile, error) {
localpart, _, err := gomatrixserverlib.SplitID('@', userID) localpart, _, err := gomatrixserverlib.SplitID('@', userID)
if err != nil { if err != nil {

View file

@ -41,8 +41,8 @@ import (
// component. // component.
func SetupAppServiceAPIComponent( func SetupAppServiceAPIComponent(
base *basecomponent.BaseDendrite, base *basecomponent.BaseDendrite,
accountsDB *accounts.Database, accountsDB accounts.Database,
deviceDB *devices.Database, deviceDB devices.Database,
federation *gomatrixserverlib.FederationClient, federation *gomatrixserverlib.FederationClient,
roomserverAliasAPI roomserverAPI.RoomserverAliasAPI, roomserverAliasAPI roomserverAPI.RoomserverAliasAPI,
roomserverQueryAPI roomserverAPI.RoomserverQueryAPI, roomserverQueryAPI roomserverAPI.RoomserverQueryAPI,
@ -100,7 +100,7 @@ func SetupAppServiceAPIComponent(
// Set up HTTP Endpoints // Set up HTTP Endpoints
routing.Setup( routing.Setup(
base.APIMux, *base.Cfg, roomserverQueryAPI, roomserverAliasAPI, base.APIMux, base.Cfg, roomserverQueryAPI, roomserverAliasAPI,
accountsDB, federation, transactionsCache, accountsDB, federation, transactionsCache,
) )
@ -111,8 +111,8 @@ func SetupAppServiceAPIComponent(
// `sender_localpart` field of each application service if it doesn't // `sender_localpart` field of each application service if it doesn't
// exist already // exist already
func generateAppServiceAccount( func generateAppServiceAccount(
accountsDB *accounts.Database, accountsDB accounts.Database,
deviceDB *devices.Database, deviceDB devices.Database,
as config.ApplicationService, as config.ApplicationService,
) error { ) error {
ctx := context.Background() ctx := context.Background()

View file

@ -33,8 +33,8 @@ import (
// OutputRoomEventConsumer consumes events that originated in the room server. // OutputRoomEventConsumer consumes events that originated in the room server.
type OutputRoomEventConsumer struct { type OutputRoomEventConsumer struct {
roomServerConsumer *common.ContinualConsumer roomServerConsumer *common.ContinualConsumer
db *accounts.Database db accounts.Database
asDB *storage.Database asDB storage.Database
query api.RoomserverQueryAPI query api.RoomserverQueryAPI
alias api.RoomserverAliasAPI alias api.RoomserverAliasAPI
serverName string serverName string
@ -46,8 +46,8 @@ type OutputRoomEventConsumer struct {
func NewOutputRoomEventConsumer( func NewOutputRoomEventConsumer(
cfg *config.Dendrite, cfg *config.Dendrite,
kafkaConsumer sarama.Consumer, kafkaConsumer sarama.Consumer,
store *accounts.Database, store accounts.Database,
appserviceDB *storage.Database, appserviceDB storage.Database,
queryAPI api.RoomserverQueryAPI, queryAPI api.RoomserverQueryAPI,
aliasAPI api.RoomserverAliasAPI, aliasAPI api.RoomserverAliasAPI,
workerStates []types.ApplicationServiceWorkerState, workerStates []types.ApplicationServiceWorkerState,
@ -114,19 +114,19 @@ func (s *OutputRoomEventConsumer) onMessage(msg *sarama.ConsumerMessage) error {
// lookupMissingStateEvents looks up the state events that are added by a new event, // lookupMissingStateEvents looks up the state events that are added by a new event,
// and returns any not already present. // and returns any not already present.
func (s *OutputRoomEventConsumer) lookupMissingStateEvents( func (s *OutputRoomEventConsumer) lookupMissingStateEvents(
addsStateEventIDs []string, event gomatrixserverlib.Event, addsStateEventIDs []string, event gomatrixserverlib.HeaderedEvent,
) ([]gomatrixserverlib.Event, error) { ) ([]gomatrixserverlib.HeaderedEvent, error) {
// Fast path if there aren't any new state events. // Fast path if there aren't any new state events.
if len(addsStateEventIDs) == 0 { if len(addsStateEventIDs) == 0 {
return []gomatrixserverlib.Event{}, nil return []gomatrixserverlib.HeaderedEvent{}, nil
} }
// Fast path if the only state event added is the event itself. // Fast path if the only state event added is the event itself.
if len(addsStateEventIDs) == 1 && addsStateEventIDs[0] == event.EventID() { if len(addsStateEventIDs) == 1 && addsStateEventIDs[0] == event.EventID() {
return []gomatrixserverlib.Event{}, nil return []gomatrixserverlib.HeaderedEvent{}, nil
} }
result := []gomatrixserverlib.Event{} result := []gomatrixserverlib.HeaderedEvent{}
missing := []string{} missing := []string{}
for _, id := range addsStateEventIDs { for _, id := range addsStateEventIDs {
if id != event.EventID() { if id != event.EventID() {
@ -155,7 +155,7 @@ func (s *OutputRoomEventConsumer) lookupMissingStateEvents(
// application service. // application service.
func (s *OutputRoomEventConsumer) filterRoomserverEvents( func (s *OutputRoomEventConsumer) filterRoomserverEvents(
ctx context.Context, ctx context.Context,
events []gomatrixserverlib.Event, events []gomatrixserverlib.HeaderedEvent,
) error { ) error {
for _, ws := range s.workerStates { for _, ws := range s.workerStates {
for _, event := range events { for _, event := range events {
@ -178,7 +178,7 @@ func (s *OutputRoomEventConsumer) filterRoomserverEvents(
// appserviceIsInterestedInEvent returns a boolean depending on whether a given // appserviceIsInterestedInEvent returns a boolean depending on whether a given
// event falls within one of a given application service's namespaces. // event falls within one of a given application service's namespaces.
func (s *OutputRoomEventConsumer) appserviceIsInterestedInEvent(ctx context.Context, event gomatrixserverlib.Event, appservice config.ApplicationService) bool { func (s *OutputRoomEventConsumer) appserviceIsInterestedInEvent(ctx context.Context, event gomatrixserverlib.HeaderedEvent, appservice config.ApplicationService) bool {
// No reason to queue events if they'll never be sent to the application // No reason to queue events if they'll never be sent to the application
// service // service
if appservice.URL == "" { if appservice.URL == "" {
@ -191,6 +191,12 @@ func (s *OutputRoomEventConsumer) appserviceIsInterestedInEvent(ctx context.Cont
return true return true
} }
if event.Type() == gomatrixserverlib.MRoomMember && event.StateKey() != nil {
if appservice.IsInterestedInUserID(*event.StateKey()) {
return true
}
}
// Check all known room aliases of the room the event came from // Check all known room aliases of the room the event came from
queryReq := api.GetAliasesForRoomIDRequest{RoomID: event.RoomID()} queryReq := api.GetAliasesForRoomIDRequest{RoomID: event.RoomID()}
var queryRes api.GetAliasesForRoomIDResponse var queryRes api.GetAliasesForRoomIDResponse

View file

@ -36,9 +36,9 @@ const pathPrefixApp = "/_matrix/app/v1"
// applied: // applied:
// nolint: gocyclo // nolint: gocyclo
func Setup( func Setup(
apiMux *mux.Router, cfg config.Dendrite, // nolint: unparam apiMux *mux.Router, cfg *config.Dendrite, // nolint: unparam
queryAPI api.RoomserverQueryAPI, aliasAPI api.RoomserverAliasAPI, // nolint: unparam queryAPI api.RoomserverQueryAPI, aliasAPI api.RoomserverAliasAPI, // nolint: unparam
accountDB *accounts.Database, // nolint: unparam accountDB accounts.Database, // nolint: unparam
federation *gomatrixserverlib.FederationClient, // nolint: unparam federation *gomatrixserverlib.FederationClient, // nolint: unparam
transactionsCache *transactions.Cache, // nolint: unparam transactionsCache *transactions.Cache, // nolint: unparam
) { ) {

View file

@ -0,0 +1,30 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package storage
import (
"context"
"github.com/matrix-org/gomatrixserverlib"
)
type Database interface {
StoreEvent(ctx context.Context, appServiceID string, event *gomatrixserverlib.HeaderedEvent) error
GetEventsWithAppServiceID(ctx context.Context, appServiceID string, limit int) (int, int, []gomatrixserverlib.HeaderedEvent, bool, error)
CountEventsWithAppServiceID(ctx context.Context, appServiceID string) (int, error)
UpdateTxnIDForEvents(ctx context.Context, appserviceID string, maxID, txnID int) error
RemoveEventsBeforeAndIncludingID(ctx context.Context, appserviceID string, eventTableID int) error
GetLatestTxnID(ctx context.Context) (int, error)
}

View file

@ -1,4 +1,5 @@
// Copyright 2018 New Vector Ltd // Copyright 2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
@ -12,7 +13,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package storage package postgres
import ( import (
"context" "context"
@ -32,7 +33,7 @@ CREATE TABLE IF NOT EXISTS appservice_events (
-- The ID of the application service the event will be sent to -- The ID of the application service the event will be sent to
as_id TEXT NOT NULL, as_id TEXT NOT NULL,
-- JSON representation of the event -- JSON representation of the event
event_json TEXT NOT NULL, headered_event_json TEXT NOT NULL,
-- The ID of the transaction that this event is a part of -- The ID of the transaction that this event is a part of
txn_id BIGINT NOT NULL txn_id BIGINT NOT NULL
); );
@ -41,14 +42,14 @@ CREATE INDEX IF NOT EXISTS appservice_events_as_id ON appservice_events(as_id);
` `
const selectEventsByApplicationServiceIDSQL = "" + const selectEventsByApplicationServiceIDSQL = "" +
"SELECT id, event_json, txn_id " + "SELECT id, headered_event_json, txn_id " +
"FROM appservice_events WHERE as_id = $1 ORDER BY txn_id DESC, id ASC" "FROM appservice_events WHERE as_id = $1 ORDER BY txn_id DESC, id ASC"
const countEventsByApplicationServiceIDSQL = "" + const countEventsByApplicationServiceIDSQL = "" +
"SELECT COUNT(id) FROM appservice_events WHERE as_id = $1" "SELECT COUNT(id) FROM appservice_events WHERE as_id = $1"
const insertEventSQL = "" + const insertEventSQL = "" +
"INSERT INTO appservice_events(as_id, event_json, txn_id) " + "INSERT INTO appservice_events(as_id, headered_event_json, txn_id) " +
"VALUES ($1, $2, $3)" "VALUES ($1, $2, $3)"
const updateTxnIDForEventsSQL = "" + const updateTxnIDForEventsSQL = "" +
@ -106,7 +107,7 @@ func (s *eventsStatements) selectEventsByApplicationServiceID(
limit int, limit int,
) ( ) (
txnID, maxID int, txnID, maxID int,
events []gomatrixserverlib.Event, events []gomatrixserverlib.HeaderedEvent,
eventsRemaining bool, eventsRemaining bool,
err error, err error,
) { ) {
@ -131,7 +132,7 @@ func (s *eventsStatements) selectEventsByApplicationServiceID(
return return
} }
func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.Event, maxID, txnID int, eventsRemaining bool, err error) { func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.HeaderedEvent, maxID, txnID int, eventsRemaining bool, err error) {
// Get current time for use in calculating event age // Get current time for use in calculating event age
nowMilli := time.Now().UnixNano() / int64(time.Millisecond) nowMilli := time.Now().UnixNano() / int64(time.Millisecond)
@ -140,7 +141,7 @@ func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.
// new ones. Send back those events first. // new ones. Send back those events first.
lastTxnID := invalidTxnID lastTxnID := invalidTxnID
for eventsProcessed := 0; eventRows.Next(); { for eventsProcessed := 0; eventRows.Next(); {
var event gomatrixserverlib.Event var event gomatrixserverlib.HeaderedEvent
var eventJSON []byte var eventJSON []byte
var id int var id int
err = eventRows.Scan( err = eventRows.Scan(
@ -208,7 +209,7 @@ func (s *eventsStatements) countEventsByApplicationServiceID(
func (s *eventsStatements) insertEvent( func (s *eventsStatements) insertEvent(
ctx context.Context, ctx context.Context,
appServiceID string, appServiceID string,
event *gomatrixserverlib.Event, event *gomatrixserverlib.HeaderedEvent,
) (err error) { ) (err error) {
// Convert event to JSON before inserting // Convert event to JSON before inserting
eventJSON, err := json.Marshal(event) eventJSON, err := json.Marshal(event)

View file

@ -0,0 +1,112 @@
// Copyright 2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
// Import postgres database driver
_ "github.com/lib/pq"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
)
// Database stores events intended to be later sent to application services
type Database struct {
events eventsStatements
txnID txnStatements
db *sql.DB
}
// NewDatabase opens a new database
func NewDatabase(dataSourceName string) (*Database, error) {
var result Database
var err error
if result.db, err = sqlutil.Open("postgres", dataSourceName); err != nil {
return nil, err
}
if err = result.prepare(); err != nil {
return nil, err
}
return &result, nil
}
func (d *Database) prepare() error {
if err := d.events.prepare(d.db); err != nil {
return err
}
return d.txnID.prepare(d.db)
}
// StoreEvent takes in a gomatrixserverlib.HeaderedEvent and stores it in the database
// for a transaction worker to pull and later send to an application service.
func (d *Database) StoreEvent(
ctx context.Context,
appServiceID string,
event *gomatrixserverlib.HeaderedEvent,
) error {
return d.events.insertEvent(ctx, appServiceID, event)
}
// GetEventsWithAppServiceID returns a slice of events and their IDs intended to
// be sent to an application service given its ID.
func (d *Database) GetEventsWithAppServiceID(
ctx context.Context,
appServiceID string,
limit int,
) (int, int, []gomatrixserverlib.HeaderedEvent, bool, error) {
return d.events.selectEventsByApplicationServiceID(ctx, appServiceID, limit)
}
// CountEventsWithAppServiceID returns the number of events destined for an
// application service given its ID.
func (d *Database) CountEventsWithAppServiceID(
ctx context.Context,
appServiceID string,
) (int, error) {
return d.events.countEventsByApplicationServiceID(ctx, appServiceID)
}
// UpdateTxnIDForEvents takes in an application service ID and a
// and stores them in the DB, unless the pair already exists, in
// which case it updates them.
func (d *Database) UpdateTxnIDForEvents(
ctx context.Context,
appserviceID string,
maxID, txnID int,
) error {
return d.events.updateTxnIDForEvents(ctx, appserviceID, maxID, txnID)
}
// RemoveEventsBeforeAndIncludingID removes all events from the database that
// are less than or equal to a given maximum ID. IDs here are implemented as a
// serial, thus this should always delete events in chronological order.
func (d *Database) RemoveEventsBeforeAndIncludingID(
ctx context.Context,
appserviceID string,
eventTableID int,
) error {
return d.events.deleteEventsBeforeAndIncludingID(ctx, appserviceID, eventTableID)
}
// GetLatestTxnID returns the latest available transaction id
func (d *Database) GetLatestTxnID(
ctx context.Context,
) (int, error) {
return d.txnID.selectTxnID(ctx)
}

View file

@ -1,4 +1,5 @@
// Copyright 2018 New Vector Ltd // Copyright 2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
@ -12,7 +13,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package storage package postgres
import ( import (
"context" "context"

View file

@ -0,0 +1,249 @@
// Copyright 2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"encoding/json"
"time"
"github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus"
)
const appserviceEventsSchema = `
-- Stores events to be sent to application services
CREATE TABLE IF NOT EXISTS appservice_events (
-- An auto-incrementing id unique to each event in the table
id INTEGER PRIMARY KEY AUTOINCREMENT,
-- The ID of the application service the event will be sent to
as_id TEXT NOT NULL,
-- JSON representation of the event
headered_event_json TEXT NOT NULL,
-- The ID of the transaction that this event is a part of
txn_id INTEGER NOT NULL
);
CREATE INDEX IF NOT EXISTS appservice_events_as_id ON appservice_events(as_id);
`
const selectEventsByApplicationServiceIDSQL = "" +
"SELECT id, headered_event_json, txn_id " +
"FROM appservice_events WHERE as_id = $1 ORDER BY txn_id DESC, id ASC"
const countEventsByApplicationServiceIDSQL = "" +
"SELECT COUNT(id) FROM appservice_events WHERE as_id = $1"
const insertEventSQL = "" +
"INSERT INTO appservice_events(as_id, headered_event_json, txn_id) " +
"VALUES ($1, $2, $3)"
const updateTxnIDForEventsSQL = "" +
"UPDATE appservice_events SET txn_id = $1 WHERE as_id = $2 AND id <= $3"
const deleteEventsBeforeAndIncludingIDSQL = "" +
"DELETE FROM appservice_events WHERE as_id = $1 AND id <= $2"
const (
// A transaction ID number that no transaction should ever have. Used for
// checking again the default value.
invalidTxnID = -2
)
type eventsStatements struct {
selectEventsByApplicationServiceIDStmt *sql.Stmt
countEventsByApplicationServiceIDStmt *sql.Stmt
insertEventStmt *sql.Stmt
updateTxnIDForEventsStmt *sql.Stmt
deleteEventsBeforeAndIncludingIDStmt *sql.Stmt
}
func (s *eventsStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(appserviceEventsSchema)
if err != nil {
return
}
if s.selectEventsByApplicationServiceIDStmt, err = db.Prepare(selectEventsByApplicationServiceIDSQL); err != nil {
return
}
if s.countEventsByApplicationServiceIDStmt, err = db.Prepare(countEventsByApplicationServiceIDSQL); err != nil {
return
}
if s.insertEventStmt, err = db.Prepare(insertEventSQL); err != nil {
return
}
if s.updateTxnIDForEventsStmt, err = db.Prepare(updateTxnIDForEventsSQL); err != nil {
return
}
if s.deleteEventsBeforeAndIncludingIDStmt, err = db.Prepare(deleteEventsBeforeAndIncludingIDSQL); err != nil {
return
}
return
}
// selectEventsByApplicationServiceID takes in an application service ID and
// returns a slice of events that need to be sent to that application service,
// as well as an int later used to remove these same events from the database
// once successfully sent to an application service.
func (s *eventsStatements) selectEventsByApplicationServiceID(
ctx context.Context,
applicationServiceID string,
limit int,
) (
txnID, maxID int,
events []gomatrixserverlib.HeaderedEvent,
eventsRemaining bool,
err error,
) {
// Retrieve events from the database. Unsuccessfully sent events first
eventRows, err := s.selectEventsByApplicationServiceIDStmt.QueryContext(ctx, applicationServiceID)
if err != nil {
return
}
defer func() {
err = eventRows.Close()
if err != nil {
log.WithFields(log.Fields{
"appservice": applicationServiceID,
}).WithError(err).Fatalf("appservice unable to select new events to send")
}
}()
events, maxID, txnID, eventsRemaining, err = retrieveEvents(eventRows, limit)
if err != nil {
return
}
return
}
func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.HeaderedEvent, maxID, txnID int, eventsRemaining bool, err error) {
// Get current time for use in calculating event age
nowMilli := time.Now().UnixNano() / int64(time.Millisecond)
// Iterate through each row and store event contents
// If txn_id changes dramatically, we've switched from collecting old events to
// new ones. Send back those events first.
lastTxnID := invalidTxnID
for eventsProcessed := 0; eventRows.Next(); {
var event gomatrixserverlib.HeaderedEvent
var eventJSON []byte
var id int
err = eventRows.Scan(
&id,
&eventJSON,
&txnID,
)
if err != nil {
return nil, 0, 0, false, err
}
// Unmarshal eventJSON
if err = json.Unmarshal(eventJSON, &event); err != nil {
return nil, 0, 0, false, err
}
// If txnID has changed on this event from the previous event, then we've
// reached the end of a transaction's events. Return only those events.
if lastTxnID > invalidTxnID && lastTxnID != txnID {
return events, maxID, lastTxnID, true, nil
}
lastTxnID = txnID
// Limit events that aren't part of an old transaction
if txnID == -1 {
// Return if we've hit the limit
if eventsProcessed++; eventsProcessed > limit {
return events, maxID, lastTxnID, true, nil
}
}
if id > maxID {
maxID = id
}
// Portion of the event that is unsigned due to rapid change
// TODO: Consider removing age as not many app services use it
if err = event.SetUnsignedField("age", nowMilli-int64(event.OriginServerTS())); err != nil {
return nil, 0, 0, false, err
}
events = append(events, event)
}
return
}
// countEventsByApplicationServiceID inserts an event mapped to its corresponding application service
// IDs into the db.
func (s *eventsStatements) countEventsByApplicationServiceID(
ctx context.Context,
appServiceID string,
) (int, error) {
var count int
err := s.countEventsByApplicationServiceIDStmt.QueryRowContext(ctx, appServiceID).Scan(&count)
if err != nil && err != sql.ErrNoRows {
return 0, err
}
return count, nil
}
// insertEvent inserts an event mapped to its corresponding application service
// IDs into the db.
func (s *eventsStatements) insertEvent(
ctx context.Context,
appServiceID string,
event *gomatrixserverlib.HeaderedEvent,
) (err error) {
// Convert event to JSON before inserting
eventJSON, err := json.Marshal(event)
if err != nil {
return err
}
_, err = s.insertEventStmt.ExecContext(
ctx,
appServiceID,
eventJSON,
-1, // No transaction ID yet
)
return
}
// updateTxnIDForEvents sets the transactionID for a collection of events. Done
// before sending them to an AppService. Referenced before sending to make sure
// we aren't constructing multiple transactions with the same events.
func (s *eventsStatements) updateTxnIDForEvents(
ctx context.Context,
appserviceID string,
maxID, txnID int,
) (err error) {
_, err = s.updateTxnIDForEventsStmt.ExecContext(ctx, txnID, appserviceID, maxID)
return
}
// deleteEventsBeforeAndIncludingID removes events matching given IDs from the database.
func (s *eventsStatements) deleteEventsBeforeAndIncludingID(
ctx context.Context,
appserviceID string,
eventTableID int,
) (err error) {
_, err = s.deleteEventsBeforeAndIncludingIDStmt.ExecContext(ctx, appserviceID, eventTableID)
return
}

View file

@ -0,0 +1,113 @@
// Copyright 2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
// Import SQLite database driver
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
_ "github.com/mattn/go-sqlite3"
)
// Database stores events intended to be later sent to application services
type Database struct {
events eventsStatements
txnID txnStatements
db *sql.DB
}
// NewDatabase opens a new database
func NewDatabase(dataSourceName string) (*Database, error) {
var result Database
var err error
if result.db, err = sqlutil.Open(common.SQLiteDriverName(), dataSourceName); err != nil {
return nil, err
}
if err = result.prepare(); err != nil {
return nil, err
}
return &result, nil
}
func (d *Database) prepare() error {
if err := d.events.prepare(d.db); err != nil {
return err
}
return d.txnID.prepare(d.db)
}
// StoreEvent takes in a gomatrixserverlib.HeaderedEvent and stores it in the database
// for a transaction worker to pull and later send to an application service.
func (d *Database) StoreEvent(
ctx context.Context,
appServiceID string,
event *gomatrixserverlib.HeaderedEvent,
) error {
return d.events.insertEvent(ctx, appServiceID, event)
}
// GetEventsWithAppServiceID returns a slice of events and their IDs intended to
// be sent to an application service given its ID.
func (d *Database) GetEventsWithAppServiceID(
ctx context.Context,
appServiceID string,
limit int,
) (int, int, []gomatrixserverlib.HeaderedEvent, bool, error) {
return d.events.selectEventsByApplicationServiceID(ctx, appServiceID, limit)
}
// CountEventsWithAppServiceID returns the number of events destined for an
// application service given its ID.
func (d *Database) CountEventsWithAppServiceID(
ctx context.Context,
appServiceID string,
) (int, error) {
return d.events.countEventsByApplicationServiceID(ctx, appServiceID)
}
// UpdateTxnIDForEvents takes in an application service ID and a
// and stores them in the DB, unless the pair already exists, in
// which case it updates them.
func (d *Database) UpdateTxnIDForEvents(
ctx context.Context,
appserviceID string,
maxID, txnID int,
) error {
return d.events.updateTxnIDForEvents(ctx, appserviceID, maxID, txnID)
}
// RemoveEventsBeforeAndIncludingID removes all events from the database that
// are less than or equal to a given maximum ID. IDs here are implemented as a
// serial, thus this should always delete events in chronological order.
func (d *Database) RemoveEventsBeforeAndIncludingID(
ctx context.Context,
appserviceID string,
eventTableID int,
) error {
return d.events.deleteEventsBeforeAndIncludingID(ctx, appserviceID, eventTableID)
}
// GetLatestTxnID returns the latest available transaction id
func (d *Database) GetLatestTxnID(
ctx context.Context,
) (int, error) {
return d.txnID.selectTxnID(ctx)
}

View file

@ -0,0 +1,60 @@
// Copyright 2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
)
const txnIDSchema = `
-- Keeps a count of the current transaction ID
CREATE TABLE IF NOT EXISTS appservice_counters (
name TEXT PRIMARY KEY NOT NULL,
last_id INTEGER DEFAULT 1
);
INSERT OR IGNORE INTO appservice_counters (name, last_id) VALUES('txn_id', 1);
`
const selectTxnIDSQL = `
SELECT last_id FROM appservice_counters WHERE name='txn_id';
UPDATE appservice_counters SET last_id=last_id+1 WHERE name='txn_id';
`
type txnStatements struct {
selectTxnIDStmt *sql.Stmt
}
func (s *txnStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(txnIDSchema)
if err != nil {
return
}
if s.selectTxnIDStmt, err = db.Prepare(selectTxnIDSQL); err != nil {
return
}
return
}
// selectTxnID selects the latest ascending transaction ID
func (s *txnStatements) selectTxnID(
ctx context.Context,
) (txnID int, err error) {
err = s.selectTxnIDStmt.QueryRowContext(ctx).Scan(&txnID)
return
}

View file

@ -1,4 +1,4 @@
// Copyright 2018 New Vector Ltd // Copyright 2020 The Matrix.org Foundation C.I.C.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
@ -12,99 +12,28 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// +build !wasm
package storage package storage
import ( import (
"context" "net/url"
"database/sql"
// Import postgres database driver "github.com/matrix-org/dendrite/appservice/storage/postgres"
_ "github.com/lib/pq" "github.com/matrix-org/dendrite/appservice/storage/sqlite3"
"github.com/matrix-org/gomatrixserverlib"
) )
// Database stores events intended to be later sent to application services func NewDatabase(dataSourceName string) (Database, error) {
type Database struct { uri, err := url.Parse(dataSourceName)
events eventsStatements if err != nil {
txnID txnStatements return postgres.NewDatabase(dataSourceName)
db *sql.DB
}
// NewDatabase opens a new database
func NewDatabase(dataSourceName string) (*Database, error) {
var result Database
var err error
if result.db, err = sql.Open("postgres", dataSourceName); err != nil {
return nil, err
} }
if err = result.prepare(); err != nil { switch uri.Scheme {
return nil, err case "postgres":
return postgres.NewDatabase(dataSourceName)
case "file":
return sqlite3.NewDatabase(dataSourceName)
default:
return postgres.NewDatabase(dataSourceName)
} }
return &result, nil
}
func (d *Database) prepare() error {
if err := d.events.prepare(d.db); err != nil {
return err
}
return d.txnID.prepare(d.db)
}
// StoreEvent takes in a gomatrixserverlib.Event and stores it in the database
// for a transaction worker to pull and later send to an application service.
func (d *Database) StoreEvent(
ctx context.Context,
appServiceID string,
event *gomatrixserverlib.Event,
) error {
return d.events.insertEvent(ctx, appServiceID, event)
}
// GetEventsWithAppServiceID returns a slice of events and their IDs intended to
// be sent to an application service given its ID.
func (d *Database) GetEventsWithAppServiceID(
ctx context.Context,
appServiceID string,
limit int,
) (int, int, []gomatrixserverlib.Event, bool, error) {
return d.events.selectEventsByApplicationServiceID(ctx, appServiceID, limit)
}
// CountEventsWithAppServiceID returns the number of events destined for an
// application service given its ID.
func (d *Database) CountEventsWithAppServiceID(
ctx context.Context,
appServiceID string,
) (int, error) {
return d.events.countEventsByApplicationServiceID(ctx, appServiceID)
}
// UpdateTxnIDForEvents takes in an application service ID and a
// and stores them in the DB, unless the pair already exists, in
// which case it updates them.
func (d *Database) UpdateTxnIDForEvents(
ctx context.Context,
appserviceID string,
maxID, txnID int,
) error {
return d.events.updateTxnIDForEvents(ctx, appserviceID, maxID, txnID)
}
// RemoveEventsBeforeAndIncludingID removes all events from the database that
// are less than or equal to a given maximum ID. IDs here are implemented as a
// serial, thus this should always delete events in chronological order.
func (d *Database) RemoveEventsBeforeAndIncludingID(
ctx context.Context,
appserviceID string,
eventTableID int,
) error {
return d.events.deleteEventsBeforeAndIncludingID(ctx, appserviceID, eventTableID)
}
// GetLatestTxnID returns the latest available transaction id
func (d *Database) GetLatestTxnID(
ctx context.Context,
) (int, error) {
return d.txnID.selectTxnID(ctx)
} }

View file

@ -0,0 +1,37 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package storage
import (
"fmt"
"net/url"
"github.com/matrix-org/dendrite/appservice/storage/sqlite3"
)
func NewDatabase(dataSourceName string) (Database, error) {
uri, err := url.Parse(dataSourceName)
if err != nil {
return nil, fmt.Errorf("Cannot use postgres implementation")
}
switch uri.Scheme {
case "postgres":
return nil, fmt.Errorf("Cannot use postgres implementation")
case "file":
return sqlite3.NewDatabase(dataSourceName)
default:
return nil, fmt.Errorf("Cannot use postgres implementation")
}
}

View file

@ -43,7 +43,7 @@ var (
// size), then send that off to the AS's /transactions/{txnID} endpoint. It also // size), then send that off to the AS's /transactions/{txnID} endpoint. It also
// handles exponentially backing off in case the AS isn't currently available. // handles exponentially backing off in case the AS isn't currently available.
func SetupTransactionWorkers( func SetupTransactionWorkers(
appserviceDB *storage.Database, appserviceDB storage.Database,
workerStates []types.ApplicationServiceWorkerState, workerStates []types.ApplicationServiceWorkerState,
) error { ) error {
// Create a worker that handles transmitting events to a single homeserver // Create a worker that handles transmitting events to a single homeserver
@ -58,7 +58,7 @@ func SetupTransactionWorkers(
// worker is a goroutine that sends any queued events to the application service // worker is a goroutine that sends any queued events to the application service
// it is given. // it is given.
func worker(db *storage.Database, ws types.ApplicationServiceWorkerState) { func worker(db storage.Database, ws types.ApplicationServiceWorkerState) {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"appservice": ws.AppService.ID, "appservice": ws.AppService.ID,
}).Info("starting application service") }).Info("starting application service")
@ -149,7 +149,7 @@ func backoff(ws *types.ApplicationServiceWorkerState, err error) {
// transaction, and JSON-encodes the results. // transaction, and JSON-encodes the results.
func createTransaction( func createTransaction(
ctx context.Context, ctx context.Context,
db *storage.Database, db storage.Database,
appserviceID string, appserviceID string,
) ( ) (
transactionJSON []byte, transactionJSON []byte,
@ -181,9 +181,14 @@ func createTransaction(
} }
} }
var ev []gomatrixserverlib.Event
for _, e := range events {
ev = append(ev, e.Event)
}
// Create a transaction and store the events inside // Create a transaction and store the events inside
transaction := gomatrixserverlib.ApplicationServiceTransaction{ transaction := gomatrixserverlib.ApplicationServiceTransaction{
Events: events, Events: ev,
} }
transactionJSON, err = json.Marshal(transaction) transactionJSON, err = json.Marshal(transaction)

830
are-we-synapse-yet.list Normal file
View file

@ -0,0 +1,830 @@
reg GET /register yields a set of flows
reg POST /register can create a user
reg POST /register downcases capitals in usernames
reg POST /register returns the same device_id as that in the request
reg POST /register rejects registration of usernames with '!'
reg POST /register rejects registration of usernames with '"'
reg POST /register rejects registration of usernames with ':'
reg POST /register rejects registration of usernames with '?'
reg POST /register rejects registration of usernames with '\'
reg POST /register rejects registration of usernames with '@'
reg POST /register rejects registration of usernames with '['
reg POST /register rejects registration of usernames with ']'
reg POST /register rejects registration of usernames with '{'
reg POST /register rejects registration of usernames with '|'
reg POST /register rejects registration of usernames with '}'
reg POST /register rejects registration of usernames with '£'
reg POST /register rejects registration of usernames with 'é'
reg POST /register rejects registration of usernames with '\n'
reg POST /register rejects registration of usernames with '''
reg POST /r0/admin/register with shared secret
reg POST /r0/admin/register admin with shared secret
reg POST /r0/admin/register with shared secret downcases capitals
reg POST /r0/admin/register with shared secret disallows symbols
reg POST rejects invalid utf-8 in JSON
log GET /login yields a set of flows
log POST /login can log in as a user
log POST /login returns the same device_id as that in the request
log POST /login can log in as a user with just the local part of the id
log POST /login as non-existing user is rejected
log POST /login wrong password is rejected
v1s GET /events initially
v1s GET /initialSync initially
csa Version responds 200 OK with valid structure
pro PUT /profile/:user_id/displayname sets my name
pro GET /profile/:user_id/displayname publicly accessible
pro PUT /profile/:user_id/avatar_url sets my avatar
pro GET /profile/:user_id/avatar_url publicly accessible
dev GET /device/{deviceId}
dev GET /device/{deviceId} gives a 404 for unknown devices
dev GET /devices
dev PUT /device/{deviceId} updates device fields
dev PUT /device/{deviceId} gives a 404 for unknown devices
dev DELETE /device/{deviceId}
dev DELETE /device/{deviceId} requires UI auth user to match device owner
dev DELETE /device/{deviceId} with no body gives a 401
dev The deleted device must be consistent through an interactive auth session
pre GET /presence/:user_id/status fetches initial status
pre PUT /presence/:user_id/status updates my presence
crm POST /createRoom makes a public room
crm POST /createRoom makes a private room
crm POST /createRoom makes a private room with invites
crm POST /createRoom makes a room with a name
crm POST /createRoom makes a room with a topic
syn Can /sync newly created room
crm POST /createRoom creates a room with the given version
crm POST /createRoom rejects attempts to create rooms with numeric versions
crm POST /createRoom rejects attempts to create rooms with unknown versions
crm POST /createRoom ignores attempts to set the room version via creation_content
mem GET /rooms/:room_id/state/m.room.member/:user_id fetches my membership
mem GET /rooms/:room_id/state/m.room.member/:user_id?format=event fetches my membership event
rst GET /rooms/:room_id/state/m.room.power_levels fetches powerlevels
mem GET /rooms/:room_id/joined_members fetches my membership
v1s GET /rooms/:room_id/initialSync fetches initial sync state
pub GET /publicRooms lists newly-created room
ali GET /directory/room/:room_alias yields room ID
mem GET /joined_rooms lists newly-created room
rst POST /rooms/:room_id/state/m.room.name sets name
rst GET /rooms/:room_id/state/m.room.name gets name
rst POST /rooms/:room_id/state/m.room.topic sets topic
rst GET /rooms/:room_id/state/m.room.topic gets topic
rst GET /rooms/:room_id/state fetches entire room state
crm POST /createRoom with creation content
ali PUT /directory/room/:room_alias creates alias
nsp GET /rooms/:room_id/aliases lists aliases
jon POST /rooms/:room_id/join can join a room
jon POST /join/:room_alias can join a room
jon POST /join/:room_id can join a room
jon POST /join/:room_id can join a room with custom content
jon POST /join/:room_alias can join a room with custom content
lev POST /rooms/:room_id/leave can leave a room
inv POST /rooms/:room_id/invite can send an invite
ban POST /rooms/:room_id/ban can ban a user
snd POST /rooms/:room_id/send/:event_type sends a message
snd PUT /rooms/:room_id/send/:event_type/:txn_id sends a message
snd PUT /rooms/:room_id/send/:event_type/:txn_id deduplicates the same txn id
get GET /rooms/:room_id/messages returns a message
get GET /rooms/:room_id/messages lazy loads members correctly
typ PUT /rooms/:room_id/typing/:user_id sets typing notification
rst GET /rooms/:room_id/state/m.room.power_levels can fetch levels
rst PUT /rooms/:room_id/state/m.room.power_levels can set levels
rst PUT power_levels should not explode if the old power levels were empty
rst Both GET and PUT work
rct POST /rooms/:room_id/receipt can create receipts
red POST /rooms/:room_id/read_markers can create read marker
med POST /media/v1/upload can create an upload
med GET /media/v1/download can fetch the value again
cap GET /capabilities is present and well formed for registered user
cap GET /r0/capabilities is not public
reg Register with a recaptcha
reg registration is idempotent, without username specified
reg registration is idempotent, with username specified
reg registration remembers parameters
reg registration accepts non-ascii passwords
reg registration with inhibit_login inhibits login
reg User signups are forbidden from starting with '_'
reg Can register using an email address
log Can login with 3pid and password using m.login.password
log login types include SSO
log /login/cas/redirect redirects if the old m.login.cas login type is listed
log Can login with new user via CAS
lox Can logout current device
lox Can logout all devices
lox Request to logout with invalid an access token is rejected
lox Request to logout without an access token is rejected
log After changing password, can't log in with old password
log After changing password, can log in with new password
log After changing password, existing session still works
log After changing password, a different session no longer works by default
log After changing password, different sessions can optionally be kept
psh Pushers created with a different access token are deleted on password change
psh Pushers created with a the same access token are not deleted on password change
acc Can deactivate account
acc Can't deactivate account with wrong password
acc After deactivating account, can't log in with password
acc After deactivating account, can't log in with an email
v1s initialSync sees my presence status
pre Presence change reports an event to myself
pre Friends presence changes reports events
crm Room creation reports m.room.create to myself
crm Room creation reports m.room.member to myself
rst Setting room topic reports m.room.topic to myself
v1s Global initialSync
v1s Global initialSync with limit=0 gives no messages
v1s Room initialSync
v1s Room initialSync with limit=0 gives no messages
rst Setting state twice is idempotent
jon Joining room twice is idempotent
syn New room members see their own join event
v1s New room members see existing users' presence in room initialSync
syn Existing members see new members' join events
syn Existing members see new members' presence
v1s All room members see all room members' presence in global initialSync
f,jon Remote users can join room by alias
syn New room members see their own join event
v1s New room members see existing members' presence in room initialSync
syn Existing members see new members' join events
syn Existing members see new member's presence
v1s New room members see first user's profile information in global initialSync
v1s New room members see first user's profile information in per-room initialSync
f,jon Remote users may not join unfederated rooms
syn Local room members see posted message events
v1s Fetching eventstream a second time doesn't yield the message again
syn Local non-members don't see posted message events
get Local room members can get room messages
f,syn Remote room members also see posted message events
f,get Remote room members can get room messages
get Message history can be paginated
f,get Message history can be paginated over federation
eph Ephemeral messages received from clients are correctly expired
ali Room aliases can contain Unicode
f,ali Remote room alias queries can handle Unicode
ali Canonical alias can be set
ali Canonical alias can include alt_aliases
ali Regular users can add and delete aliases in the default room configuration
ali Regular users can add and delete aliases when m.room.aliases is restricted
ali Deleting a non-existent alias should return a 404
ali Users can't delete other's aliases
ali Users with sufficient power-level can delete other's aliases
ali Can delete canonical alias
ali Alias creators can delete alias with no ops
ali Alias creators can delete canonical alias with no ops
ali Only room members can list aliases of a room
inv Can invite users to invite-only rooms
inv Uninvited users cannot join the room
inv Invited user can reject invite
f,inv Invited user can reject invite over federation
f,inv Invited user can reject invite over federation several times
inv Invited user can reject invite for empty room
f,inv Invited user can reject invite over federation for empty room
inv Invited user can reject local invite after originator leaves
inv Invited user can see room metadata
f,inv Remote invited user can see room metadata
inv Users cannot invite themselves to a room
inv Users cannot invite a user that is already in the room
ban Banned user is kicked and may not rejoin until unbanned
f,ban Remote banned user is kicked and may not rejoin until unbanned
ban 'ban' event respects room powerlevel
plv setting 'm.room.name' respects room powerlevel
plv setting 'm.room.power_levels' respects room powerlevel (2 subtests)
plv Unprivileged users can set m.room.topic if it only needs level 0
plv Users cannot set ban powerlevel higher than their own (2 subtests)
plv Users cannot set kick powerlevel higher than their own (2 subtests)
plv Users cannot set redact powerlevel higher than their own (2 subtests)
v1s Check that event streams started after a client joined a room work (SYT-1)
v1s Event stream catches up fully after many messages
xxx POST /rooms/:room_id/redact/:event_id as power user redacts message
xxx POST /rooms/:room_id/redact/:event_id as original message sender redacts message
xxx POST /rooms/:room_id/redact/:event_id as random user does not redact message
xxx POST /redact disallows redaction of event in different room
xxx Redaction of a redaction redacts the redaction reason
v1s A departed room is still included in /initialSync (SPEC-216)
v1s Can get rooms/{roomId}/initialSync for a departed room (SPEC-216)
rst Can get rooms/{roomId}/state for a departed room (SPEC-216)
mem Can get rooms/{roomId}/members for a departed room (SPEC-216)
get Can get rooms/{roomId}/messages for a departed room (SPEC-216)
rst Can get 'm.room.name' state for a departed room (SPEC-216)
syn Getting messages going forward is limited for a departed room (SPEC-216)
3pd Can invite existing 3pid
3pd Can invite existing 3pid with no ops into a private room
3pd Can invite existing 3pid in createRoom
3pd Can invite unbound 3pid
f,3pd Can invite unbound 3pid over federation
3pd Can invite unbound 3pid with no ops into a private room
f,3pd Can invite unbound 3pid over federation with no ops into a private room
f,3pd Can invite unbound 3pid over federation with users from both servers
3pd Can accept unbound 3pid invite after inviter leaves
3pd Can accept third party invite with /join
3pd 3pid invite join with wrong but valid signature are rejected
3pd 3pid invite join valid signature but revoked keys are rejected
3pd 3pid invite join valid signature but unreachable ID server are rejected
gst Guest user cannot call /events globally
gst Guest users can join guest_access rooms
gst Guest users can send messages to guest_access rooms if joined
gst Guest user calling /events doesn't tightloop
gst Guest users are kicked from guest_access rooms on revocation of guest_access
gst Guest user can set display names
gst Guest users are kicked from guest_access rooms on revocation of guest_access over federation
gst Guest user can upgrade to fully featured user
gst Guest user cannot upgrade other users
pub GET /publicRooms lists rooms
pub GET /publicRooms includes avatar URLs
gst Guest users can accept invites to private rooms over federation
gst Guest users denied access over federation if guest access prohibited
mem Room members can override their displayname on a room-specific basis
mem Room members can join a room with an overridden displayname
mem Users cannot kick users from a room they are not in
mem Users cannot kick users who have already left a room
typ Typing notification sent to local room members
f,typ Typing notifications also sent to remote room members
typ Typing can be explicitly stopped
rct Read receipts are visible to /initialSync
rct Read receipts are sent as events
rct Receipts must be m.read
pro displayname updates affect room member events
pro avatar_url updates affect room member events
gst m.room.history_visibility == "world_readable" allows/forbids appropriately for Guest users
gst m.room.history_visibility == "shared" allows/forbids appropriately for Guest users
gst m.room.history_visibility == "invited" allows/forbids appropriately for Guest users
gst m.room.history_visibility == "joined" allows/forbids appropriately for Guest users
gst m.room.history_visibility == "default" allows/forbids appropriately for Guest users
gst Guest non-joined user cannot call /events on shared room
gst Guest non-joined user cannot call /events on invited room
gst Guest non-joined user cannot call /events on joined room
gst Guest non-joined user cannot call /events on default room
gst Guest non-joined user can call /events on world_readable room
gst Guest non-joined users can get state for world_readable rooms
gst Guest non-joined users can get individual state for world_readable rooms
gst Guest non-joined users cannot room initalSync for non-world_readable rooms
gst Guest non-joined users can room initialSync for world_readable rooms
gst Guest non-joined users can get individual state for world_readable rooms after leaving
gst Guest non-joined users cannot send messages to guest_access rooms if not joined
gst Guest users can sync from world_readable guest_access rooms if joined
gst Guest users can sync from shared guest_access rooms if joined
gst Guest users can sync from invited guest_access rooms if joined
gst Guest users can sync from joined guest_access rooms if joined
gst Guest users can sync from default guest_access rooms if joined
ath m.room.history_visibility == "world_readable" allows/forbids appropriately for Real users
ath m.room.history_visibility == "shared" allows/forbids appropriately for Real users
ath m.room.history_visibility == "invited" allows/forbids appropriately for Real users
ath m.room.history_visibility == "joined" allows/forbids appropriately for Real users
ath m.room.history_visibility == "default" allows/forbids appropriately for Real users
ath Real non-joined user cannot call /events on shared room
ath Real non-joined user cannot call /events on invited room
ath Real non-joined user cannot call /events on joined room
ath Real non-joined user cannot call /events on default room
ath Real non-joined user can call /events on world_readable room
ath Real non-joined users can get state for world_readable rooms
ath Real non-joined users can get individual state for world_readable rooms
ath Real non-joined users cannot room initalSync for non-world_readable rooms
ath Real non-joined users can room initialSync for world_readable rooms
ath Real non-joined users can get individual state for world_readable rooms after leaving
ath Real non-joined users cannot send messages to guest_access rooms if not joined
ath Real users can sync from world_readable guest_access rooms if joined
ath Real users can sync from shared guest_access rooms if joined
ath Real users can sync from invited guest_access rooms if joined
ath Real users can sync from joined guest_access rooms if joined
ath Real users can sync from default guest_access rooms if joined
ath Only see history_visibility changes on boundaries
f,ath Backfill works correctly with history visibility set to joined
fgt Forgotten room messages cannot be paginated
fgt Forgetting room does not show up in v2 /sync
fgt Can forget room you've been kicked from
fgt Can't forget room you're still in
mem Can re-join room if re-invited
ath Only original members of the room can see messages from erased users
mem /joined_rooms returns only joined rooms
mem /joined_members return joined members
ctx /context/ on joined room works
ctx /context/ on non world readable room does not work
ctx /context/ returns correct number of events
ctx /context/ with lazy_load_members filter works
get /event/ on joined room works
get /event/ on non world readable room does not work
get /event/ does not allow access to events before the user joined
mem Can get rooms/{roomId}/members
mem Can get rooms/{roomId}/members at a given point
mem Can filter rooms/{roomId}/members
upg /upgrade creates a new room
upg /upgrade should preserve room visibility for public rooms
upg /upgrade should preserve room visibility for private rooms
upg /upgrade copies >100 power levels to the new room
upg /upgrade copies the power levels to the new room
upg /upgrade preserves the power level of the upgrading user in old and new rooms
upg /upgrade copies important state to the new room
upg /upgrade copies ban events to the new room
upg local user has push rules copied to upgraded room
f,upg remote user has push rules copied to upgraded room
upg /upgrade moves aliases to the new room
upg /upgrade moves remote aliases to the new room
upg /upgrade preserves direct room state
upg /upgrade preserves room federation ability
upg /upgrade restricts power levels in the old room
upg /upgrade restricts power levels in the old room when the old PLs are unusual
upg /upgrade to an unknown version is rejected
upg /upgrade is rejected if the user can't send state events
upg /upgrade of a bogus room fails gracefully
upg Cannot send tombstone event that points to the same room
f,upg Local and remote users' homeservers remove a room from their public directory on upgrade
rst Name/topic keys are correct
f,pub Can get remote public room list
pub Can paginate public room list
pub Can search public room list
syn Can create filter
syn Can download filter
syn Can sync
syn Can sync a joined room
syn Full state sync includes joined rooms
syn Newly joined room is included in an incremental sync
syn Newly joined room has correct timeline in incremental sync
syn Newly joined room includes presence in incremental sync
syn Get presence for newly joined members in incremental sync
syn Can sync a room with a single message
syn Can sync a room with a message with a transaction id
syn A message sent after an initial sync appears in the timeline of an incremental sync.
syn A filtered timeline reaches its limit
syn Syncing a new room with a large timeline limit isn't limited
syn A full_state incremental update returns only recent timeline
syn A prev_batch token can be used in the v1 messages API
syn A next_batch token can be used in the v1 messages API
syn User sees their own presence in a sync
syn User is offline if they set_presence=offline in their sync
syn User sees updates to presence from other users in the incremental sync.
syn State is included in the timeline in the initial sync
f,syn State from remote users is included in the state in the initial sync
syn Changes to state are included in an incremental sync
syn Changes to state are included in an gapped incremental sync
f,syn State from remote users is included in the timeline in an incremental sync
syn A full_state incremental update returns all state
syn When user joins a room the state is included in the next sync
syn A change to displayname should not result in a full state sync
syn A change to displayname should appear in incremental /sync
syn When user joins a room the state is included in a gapped sync
syn When user joins and leaves a room in the same batch, the full state is still included in the next sync
syn Current state appears in timeline in private history
syn Current state appears in timeline in private history with many messages before
syn Current state appears in timeline in private history with many messages after
syn Rooms a user is invited to appear in an initial sync
syn Rooms a user is invited to appear in an incremental sync
syn Newly joined room is included in an incremental sync after invite
syn Sync can be polled for updates
syn Sync is woken up for leaves
syn Left rooms appear in the leave section of sync
syn Newly left rooms appear in the leave section of incremental sync
syn We should see our own leave event, even if history_visibility is restricted (SYN-662)
syn We should see our own leave event when rejecting an invite, even if history_visibility is restricted (riot-web/3462)
syn Newly left rooms appear in the leave section of gapped sync
syn Previously left rooms don't appear in the leave section of sync
syn Left rooms appear in the leave section of full state sync
syn Archived rooms only contain history from before the user left
syn Banned rooms appear in the leave section of sync
syn Newly banned rooms appear in the leave section of incremental sync
syn Newly banned rooms appear in the leave section of incremental sync
syn Typing events appear in initial sync
syn Typing events appear in incremental sync
syn Typing events appear in gapped sync
syn Read receipts appear in initial v2 /sync
syn New read receipts appear in incremental v2 /sync
syn Can pass a JSON filter as a query parameter
syn Can request federation format via the filter
syn Read markers appear in incremental v2 /sync
syn Read markers appear in initial v2 /sync
syn Read markers can be updated
syn Lazy loading parameters in the filter are strictly boolean
syn The only membership state included in an initial sync is for all the senders in the timeline
syn The only membership state included in an incremental sync is for senders in the timeline
syn The only membership state included in a gapped incremental sync is for senders in the timeline
syn Gapped incremental syncs include all state changes
syn Old leaves are present in gapped incremental syncs
syn Leaves are present in non-gapped incremental syncs
syn Old members are included in gappy incr LL sync if they start speaking
syn Members from the gap are included in gappy incr LL sync
syn We don't send redundant membership state across incremental syncs by default
syn We do send redundant membership state across incremental syncs if asked
syn Unnamed room comes with a name summary
syn Named room comes with just joined member count summary
syn Room summary only has 5 heroes
syn Room summary counts change when membership changes
rmv User can create and send/receive messages in a room with version 1
rmv User can create and send/receive messages in a room with version 1 (2 subtests)
rmv local user can join room with version 1
rmv User can invite local user to room with version 1
rmv remote user can join room with version 1
rmv User can invite remote user to room with version 1
rmv Remote user can backfill in a room with version 1
rmv Can reject invites over federation for rooms with version 1
rmv Can receive redactions from regular users over federation in room version 1
rmv User can create and send/receive messages in a room with version 2
rmv User can create and send/receive messages in a room with version 2 (2 subtests)
rmv local user can join room with version 2
rmv User can invite local user to room with version 2
rmv remote user can join room with version 2
rmv User can invite remote user to room with version 2
rmv Remote user can backfill in a room with version 2
rmv Can reject invites over federation for rooms with version 2
rmv Can receive redactions from regular users over federation in room version 2
rmv User can create and send/receive messages in a room with version 3
rmv User can create and send/receive messages in a room with version 3 (2 subtests)
rmv local user can join room with version 3
rmv User can invite local user to room with version 3
rmv remote user can join room with version 3
rmv User can invite remote user to room with version 3
rmv Remote user can backfill in a room with version 3
rmv Can reject invites over federation for rooms with version 3
rmv Can receive redactions from regular users over federation in room version 3
rmv User can create and send/receive messages in a room with version 4
rmv User can create and send/receive messages in a room with version 4 (2 subtests)
rmv local user can join room with version 4
rmv User can invite local user to room with version 4
rmv remote user can join room with version 4
rmv User can invite remote user to room with version 4
rmv Remote user can backfill in a room with version 4
rmv Can reject invites over federation for rooms with version 4
rmv Can receive redactions from regular users over federation in room version 4
rmv User can create and send/receive messages in a room with version 5
rmv User can create and send/receive messages in a room with version 5 (2 subtests)
rmv local user can join room with version 5
rmv User can invite local user to room with version 5
rmv remote user can join room with version 5
rmv User can invite remote user to room with version 5
rmv Remote user can backfill in a room with version 5
rmv Can reject invites over federation for rooms with version 5
rmv Can receive redactions from regular users over federation in room version 5
pre Presence changes are reported to local room members
f,pre Presence changes are also reported to remote room members
pre Presence changes to UNAVAILABLE are reported to local room members
f,pre Presence changes to UNAVAILABLE are reported to remote room members
v1s Newly created users see their own presence in /initialSync (SYT-34)
dvk Can upload device keys
dvk Should reject keys claiming to belong to a different user
dvk Can query device keys using POST
dvk Can query specific device keys using POST
dvk query for user with no keys returns empty key dict
dvk Can claim one time key using POST
f,dvk Can query remote device keys using POST
f,dvk Can claim remote one time key using POST
dvk Local device key changes appear in v2 /sync
dvk Local new device changes appear in v2 /sync
dvk Local delete device changes appear in v2 /sync
dvk Local update device changes appear in v2 /sync
dvk Can query remote device keys using POST after notification
f,dev Device deletion propagates over federation
f,dev If remote user leaves room, changes device and rejoins we see update in sync
f,dev If remote user leaves room we no longer receive device updates
dvk Local device key changes appear in /keys/changes
dvk New users appear in /keys/changes
f,dvk If remote user leaves room, changes device and rejoins we see update in /keys/changes
dvk Get left notifs in sync and /keys/changes when other user leaves
dvk Get left notifs for other users in sync and /keys/changes when user leaves
f,dvk If user leaves room, remote user changes device and rejoins we see update in /sync and /keys/changes
dvk Can create backup version
dvk Can update backup version
dvk Responds correctly when backup is empty
dvk Can backup keys
dvk Can update keys with better versions
dvk Will not update keys with worse versions
dvk Will not back up to an old backup version
dvk Can delete backup
dvk Deleted & recreated backups are empty
dvk Can create more than 10 backup versions
dvk Can upload self-signing keys
dvk Fails to upload self-signing keys with no auth
dvk Fails to upload self-signing key without master key
dvk Changing master key notifies local users
dvk Changing user-signing key notifies local users
f,dvk can fetch self-signing keys over federation
f,dvk uploading self-signing key notifies over federation
f,dvk uploading signed devices gets propagated over federation
tag Can add tag
tag Can remove tag
tag Can list tags for a room
v1s Tags appear in the v1 /events stream
v1s Tags appear in the v1 /initalSync
v1s Tags appear in the v1 room initial sync
tag Tags appear in an initial v2 /sync
tag Newly updated tags appear in an incremental v2 /sync
tag Deleted tags appear in an incremental v2 /sync
tag local user has tags copied to the new room
f,tag remote user has tags copied to the new room
sch Can search for an event by body
sch Can get context around search results
sch Can back-paginate search results
sch Search works across an upgraded room and its predecessor
sch Search results with rank ordering do not include redacted events
sch Search results with recent ordering do not include redacted events
acc Can add account data
acc Can add account data to room
acc Can get account data without syncing
acc Can get room account data without syncing
v1s Latest account data comes down in /initialSync
v1s Latest account data comes down in room initialSync
v1s Account data appears in v1 /events stream
v1s Room account data appears in v1 /events stream
acc Latest account data appears in v2 /sync
acc New account data appears in incremental v2 /sync
oid Can generate a openid access_token that can be exchanged for information about a user
oid Invalid openid access tokens are rejected
oid Requests to userinfo without access tokens are rejected
std Can send a message directly to a device using PUT /sendToDevice
std Can recv a device message using /sync
std Can recv device messages until they are acknowledged
std Device messages with the same txn_id are deduplicated
std Device messages wake up /sync
std Can recv device messages over federation
std Device messages over federation wake up /sync
std Can send messages with a wildcard device id
std Can send messages with a wildcard device id to two devices
std Wildcard device messages wake up /sync
std Wildcard device messages over federation wake up /sync
adm /whois
nsp /purge_history
nsp /purge_history by ts
nsp Can backfill purged history
nsp Shutdown room
ign Ignore user in existing room
ign Ignore invite in full sync
ign Ignore invite in incremental sync
fky Checking local federation server
fky Federation key API allows unsigned requests for keys
fky Federation key API can act as a notary server via a GET request
fky Federation key API can act as a notary server via a POST request
fky Key notary server should return an expired key if it can't find any others
fky Key notary server must not overwrite a valid key with a spurious result from the origin server
fqu Non-numeric ports in server names are rejected
fqu Outbound federation can query profile data
fqu Inbound federation can query profile data
fqu Outbound federation can query room alias directory
fqu Inbound federation can query room alias directory
fsj Outbound federation can query v1 /send_join
fsj Outbound federation can query v2 /send_join
fmj Outbound federation passes make_join failures through to the client
fsj Inbound federation can receive v1 /send_join
fsj Inbound federation can receive v2 /send_join
fmj Inbound /v1/make_join rejects remote attempts to join local users to rooms
fsj Inbound /v1/send_join rejects incorrectly-signed joins
fsj Inbound /v1/send_join rejects joins from other servers
fau Inbound federation rejects remote attempts to kick local users to rooms
frv Inbound federation rejects attempts to join v1 rooms from servers without v1 support
frv Inbound federation rejects attempts to join v2 rooms from servers lacking version support
frv Inbound federation rejects attempts to join v2 rooms from servers only supporting v1
frv Inbound federation accepts attempts to join v2 rooms from servers with support
frv Outbound federation correctly handles unsupported room versions
frv A pair of servers can establish a join in a v2 room
fsj Outbound federation rejects send_join responses with no m.room.create event
frv Outbound federation rejects m.room.create events with an unknown room version
fsj Event with an invalid signature in the send_join response should not cause room join to fail
fed Outbound federation can send events
fed Inbound federation can receive events
fed Inbound federation can receive redacted events
fed Ephemeral messages received from servers are correctly expired
fed Events whose auth_events are in the wrong room do not mess up the room state
fed Inbound federation can return events
fed Inbound federation redacts events from erased users
fme Outbound federation can request missing events
fme Inbound federation can return missing events for world_readable visibility
fme Inbound federation can return missing events for shared visibility
fme Inbound federation can return missing events for invite visibility
fme Inbound federation can return missing events for joined visibility
fme outliers whose auth_events are in a different room are correctly rejected
fbk Outbound federation can backfill events
fbk Inbound federation can backfill events
fbk Backfill checks the events requested belong to the room
fbk Backfilled events whose prev_events are in a different room do not allow cross-room back-pagination
fiv Outbound federation can send invites via v1 API
fiv Outbound federation can send invites via v2 API
fiv Inbound federation can receive invites via v1 API
fiv Inbound federation can receive invites via v2 API
fiv Inbound federation can receive invite and reject when remote replies with a 403
fiv Inbound federation can receive invite and reject when remote replies with a 500
fiv Inbound federation can receive invite and reject when remote is unreachable
fiv Inbound federation rejects invites which are not signed by the sender
fiv Inbound federation can receive invite rejections
fiv Inbound federation rejects incorrectly-signed invite rejections
fsl Inbound /v1/send_leave rejects leaves from other servers
fst Inbound federation can get state for a room
fst Inbound federation of state requires event_id as a mandatory paramater
fst Inbound federation can get state_ids for a room
fst Inbound federation of state_ids requires event_id as a mandatory paramater
fst Federation rejects inbound events where the prev_events cannot be found
fst Room state at a rejected message event is the same as its predecessor
fst Room state at a rejected state event is the same as its predecessor
fst Outbound federation requests missing prev_events and then asks for /state_ids and resolves the state
fst Federation handles empty auth_events in state_ids sanely
fst Getting state checks the events requested belong to the room
fst Getting state IDs checks the events requested belong to the room
fst Should not be able to take over the room by pretending there is no PL event
fpb Inbound federation can get public room list
fed Outbound federation sends receipts
fed Inbound federation rejects receipts from wrong remote
fed Inbound federation ignores redactions from invalid servers room > v3
fed An event which redacts an event in a different room should be ignored
fed An event which redacts itself should be ignored
fed A pair of events which redact each other should be ignored
fdk Local device key changes get to remote servers
fdk Server correctly handles incoming m.device_list_update
fdk Server correctly resyncs when client query keys and there is no remote cache
fdk Server correctly resyncs when server leaves and rejoins a room
fdk Local device key changes get to remote servers with correct prev_id
fdk Device list doesn't change if remote server is down
fdk If a device list update goes missing, the server resyncs on the next one
fst Name/topic keys are correct
fau Remote servers cannot set power levels in rooms without existing powerlevels
fau Remote servers should reject attempts by non-creators to set the power levels
fau Inbound federation rejects typing notifications from wrong remote
fed Forward extremities remain so even after the next events are populated as outliers
fau Banned servers cannot send events
fau Banned servers cannot /make_join
fau Banned servers cannot /send_join
fau Banned servers cannot /make_leave
fau Banned servers cannot /send_leave
fau Banned servers cannot /invite
fau Banned servers cannot get room state
fau Banned servers cannot get room state ids
fau Banned servers cannot backfill
fau Banned servers cannot /event_auth
fau Banned servers cannot get missing events
fau Server correctly handles transactions that break edu limits
fau Inbound federation correctly soft fails events
fau Inbound federation accepts a second soft-failed event
fau Inbound federation correctly handles soft failed events as extremities
med Can upload with Unicode file name
med Can download with Unicode file name locally
f,med Can download with Unicode file name over federation
med Alternative server names do not cause a routing loop
med Can download specifying a different Unicode file name
med Can upload without a file name
med Can download without a file name locally
f,med Can download without a file name over federation
med Can upload with ASCII file name
med Can download file 'ascii'
med Can download file 'name with spaces'
med Can download file 'name;with;semicolons'
med Can download specifying a different ASCII file name
med Can send image in room message
med Can fetch images in room
med POSTed media can be thumbnailed
f,med Remote media can be thumbnailed
med Test URL preview
med Can read configuration endpoint
nsp Can quarantine media in rooms
udr User appears in user directory
udr User in private room doesn't appear in user directory
udr User joining then leaving public room appears and dissappears from directory
udr Users appear/disappear from directory when join_rules are changed
udr Users appear/disappear from directory when history_visibility are changed
udr Users stay in directory when join_rules are changed but history_visibility is world_readable
f,udr User in remote room doesn't appear in user directory after server left room
udr User directory correctly update on display name change
udr User in shared private room does appear in user directory
udr User in shared private room does appear in user directory until leave
udr User in dir while user still shares private rooms
nsp Create group
nsp Add group rooms
nsp Remove group rooms
nsp Get local group profile
nsp Get local group users
nsp Add/remove local group rooms
nsp Get local group summary
nsp Get remote group profile
nsp Get remote group users
nsp Add/remove remote group rooms
nsp Get remote group summary
nsp Add local group users
nsp Remove self from local group
nsp Remove other from local group
nsp Add remote group users
nsp Remove self from remote group
nsp Listing invited users of a remote group when not a member returns a 403
nsp Add group category
nsp Remove group category
nsp Get group categories
nsp Add group role
nsp Remove group role
nsp Get group roles
nsp Add room to group summary
nsp Adding room to group summary keeps room_id when fetching rooms in group
nsp Adding multiple rooms to group summary have correct order
nsp Remove room from group summary
nsp Add room to group summary with category
nsp Remove room from group summary with category
nsp Add user to group summary
nsp Adding multiple users to group summary have correct order
nsp Remove user from group summary
nsp Add user to group summary with role
nsp Remove user from group summary with role
nsp Local group invites come down sync
nsp Group creator sees group in sync
nsp Group creator sees group in initial sync
nsp Get/set local group publicity
nsp Bulk get group publicity
nsp Joinability comes down summary
nsp Set group joinable and join it
nsp Group is not joinable by default
nsp Group is joinable over federation
nsp Room is transitioned on local and remote groups upon room upgrade
3pd Can bind 3PID via home server
3pd Can bind and unbind 3PID via homeserver
3pd Can unbind 3PID via homeserver when bound out of band
3pd 3PIDs are unbound after account deactivation
3pd Can bind and unbind 3PID via /unbind by specifying the identity server
3pd Can bind and unbind 3PID via /unbind without specifying the identity server
app AS can create a user
app AS can create a user with an underscore
app AS can create a user with inhibit_login
app AS cannot create users outside its own namespace
app Regular users cannot register within the AS namespace
app AS can make room aliases
app Regular users cannot create room aliases within the AS namespace
app AS-ghosted users can use rooms via AS
app AS-ghosted users can use rooms themselves
app Ghost user must register before joining room
app AS can set avatar for ghosted users
app AS can set displayname for ghosted users
app AS can't set displayname for random users
app Inviting an AS-hosted user asks the AS server
app Accesing an AS-hosted room alias asks the AS server
app Events in rooms with AS-hosted room aliases are sent to AS server
app AS user (not ghost) can join room without registering
app AS user (not ghost) can join room without registering, with user_id query param
app HS provides query metadata
app HS can provide query metadata on a single protocol
app HS will proxy request for 3PU mapping
app HS will proxy request for 3PL mapping
app AS can publish rooms in their own list
app AS and main public room lists are separate
app AS can deactivate a user
psh Test that a message is pushed
psh Invites are pushed
psh Rooms with names are correctly named in pushed
psh Rooms with canonical alias are correctly named in pushed
psh Rooms with many users are correctly pushed
psh Don't get pushed for rooms you've muted
psh Rejected events are not pushed
psh Can add global push rule for room
psh Can add global push rule for sender
psh Can add global push rule for content
psh Can add global push rule for override
psh Can add global push rule for underride
psh Can add global push rule for content
psh New rules appear before old rules by default
psh Can add global push rule before an existing rule
psh Can add global push rule after an existing rule
psh Can delete a push rule
psh Can disable a push rule
psh Adding the same push rule twice is idempotent
psh Messages that notify from another user increment unread notification count
psh Messages that highlight from another user increment unread highlight count
psh Can change the actions of default rules
psh Changing the actions of an unknown default rule fails with 404
psh Can change the actions of a user specified rule
psh Changing the actions of an unknown rule fails with 404
psh Can fetch a user's pushers
psh Push rules come down in an initial /sync
psh Adding a push rule wakes up an incremental /sync
psh Disabling a push rule wakes up an incremental /sync
psh Enabling a push rule wakes up an incremental /sync
psh Setting actions for a push rule wakes up an incremental /sync
psh Can enable/disable default rules
psh Enabling an unknown default rule fails with 404
psh Test that rejected pushers are removed.
psh Notifications can be viewed with GET /notifications
psh Trying to add push rule with no scope fails with 400
psh Trying to add push rule with invalid scope fails with 400
psh Trying to add push rule with missing template fails with 400
psh Trying to add push rule with missing rule_id fails with 400
psh Trying to add push rule with empty rule_id fails with 400
psh Trying to add push rule with invalid template fails with 400
psh Trying to add push rule with rule_id with slashes fails with 400
psh Trying to add push rule with override rule without conditions fails with 400
psh Trying to add push rule with underride rule without conditions fails with 400
psh Trying to add push rule with condition without kind fails with 400
psh Trying to add push rule with content rule without pattern fails with 400
psh Trying to add push rule with no actions fails with 400
psh Trying to add push rule with invalid action fails with 400
psh Trying to add push rule with invalid attr fails with 400
psh Trying to add push rule with invalid value for enabled fails with 400
psh Trying to get push rules with no trailing slash fails with 400
psh Trying to get push rules with scope without trailing slash fails with 400
psh Trying to get push rules with template without tailing slash fails with 400
psh Trying to get push rules with unknown scope fails with 400
psh Trying to get push rules with unknown template fails with 400
psh Trying to get push rules with unknown attribute fails with 400
psh Trying to get push rules with unknown rule_id fails with 404
v1s GET /initialSync with non-numeric 'limit'
v1s GET /events with non-numeric 'limit'
v1s GET /events with negative 'limit'
v1s GET /events with non-numeric 'timeout'
ath Event size limits
syn Check creating invalid filters returns 4xx
f,pre New federated private chats get full presence information (SYN-115)
pre Left room members do not cause problems for presence
crm Rooms can be created with an initial invite list (SYN-205)
typ Typing notifications don't leak
ban Non-present room members cannot ban others
psh Getting push rules doesn't corrupt the cache SYN-390
inv Test that we can be reinvited to a room we created
syn Multiple calls to /sync should not cause 500 errors
gst Guest user can call /events on another world_readable room (SYN-606)
gst Real user can call /events on another world_readable room (SYN-606)
gst Events come down the correct room
pub Asking for a remote rooms list, but supplying the local server's name, returns the local rooms list

252
are-we-synapse-yet.py Executable file
View file

@ -0,0 +1,252 @@
#!/usr/bin/env python3
from __future__ import division
import argparse
import re
import sys
# Usage: $ ./are-we-synapse-yet.py [-v] results.tap
# This script scans a results.tap file from Dendrite's CI process and spits out
# a rating of how close we are to Synapse parity, based purely on SyTests.
# The main complexity is grouping tests sensibly into features like 'Registration'
# and 'Federation'. Then it just checks the ones which are passing and calculates
# percentages for each group. Produces results like:
#
# Client-Server APIs: 29% (196/666 tests)
# -------------------
# Registration : 62% (20/32 tests)
# Login : 7% (1/15 tests)
# V1 CS APIs : 10% (3/30 tests)
# ...
#
# or in verbose mode:
#
# Client-Server APIs: 29% (196/666 tests)
# -------------------
# Registration : 62% (20/32 tests)
# ✓ GET /register yields a set of flows
# ✓ POST /register can create a user
# ✓ POST /register downcases capitals in usernames
# ...
#
# You can also tack `-v` on to see exactly which tests each category falls under.
test_mappings = {
"nsp": "Non-Spec API",
"f": "Federation", # flag to mark test involves federation
"federation_apis": {
"fky": "Key API",
"fsj": "send_join API",
"fmj": "make_join API",
"fsl": "send_leave API",
"fiv": "Invite API",
"fqu": "Query API",
"frv": "room versions",
"fau": "Auth",
"fbk": "Backfill API",
"fme": "get_missing_events API",
"fst": "State APIs",
"fpb": "Public Room API",
"fdk": "Device Key APIs",
"fed": "Federation API",
},
"client_apis": {
"reg": "Registration",
"log": "Login",
"lox": "Logout",
"v1s": "V1 CS APIs",
"csa": "Misc CS APIs",
"pro": "Profile",
"dev": "Devices",
"dvk": "Device Keys",
"pre": "Presence",
"crm": "Create Room",
"syn": "Sync API",
"rmv": "Room Versions",
"rst": "Room State APIs",
"pub": "Public Room APIs",
"mem": "Room Membership",
"ali": "Room Aliases",
"jon": "Joining Rooms",
"lev": "Leaving Rooms",
"inv": "Inviting users to Rooms",
"ban": "Banning users",
"snd": "Sending events",
"get": "Getting events for Rooms",
"rct": "Receipts",
"red": "Read markers",
"med": "Media APIs",
"cap": "Capabilities API",
"typ": "Typing API",
"psh": "Push APIs",
"acc": "Account APIs",
"eph": "Ephemeral Events",
"plv": "Power Levels",
"xxx": "Redaction",
"3pd": "Third-Party ID APIs",
"gst": "Guest APIs",
"ath": "Room Auth",
"fgt": "Forget APIs",
"ctx": "Context APIs",
"upg": "Room Upgrade APIs",
"tag": "Tagging APIs",
"sch": "Search APIs",
"oid": "OpenID API",
"std": "Send-to-Device APIs",
"adm": "Server Admin API",
"ign": "Ignore Users",
"udr": "User Directory APIs",
"app": "Application Services API",
},
}
# optional 'not ' with test number then anything but '#'
re_testname = re.compile(r"^(not )?ok [0-9]+ ([^#]+)")
# Parses lines like the following:
#
# SUCCESS: ok 3 POST /register downcases capitals in usernames
# FAIL: not ok 54 (expected fail) POST /createRoom creates a room with the given version
# SKIP: ok 821 Multiple calls to /sync should not cause 500 errors # skip lack of can_post_room_receipts
# EXPECT FAIL: not ok 822 (expected fail) Guest user can call /events on another world_readable room (SYN-606) # TODO expected fail
#
# Only SUCCESS lines are treated as success, the rest are not implemented.
#
# Returns a dict like:
# { name: "...", ok: True }
def parse_test_line(line):
if not line.startswith("ok ") and not line.startswith("not ok "):
return
re_match = re_testname.match(line)
test_name = re_match.groups()[1].replace("(expected fail) ", "").strip()
test_pass = False
if line.startswith("ok ") and not "# skip " in line:
test_pass = True
return {
"name": test_name,
"ok": test_pass,
}
# Prints the stats for a complete section.
# header_name => "Client-Server APIs"
# gid_to_tests => { gid: { <name>: True|False }}
# gid_to_name => { gid: "Group Name" }
# verbose => True|False
# Produces:
# Client-Server APIs: 29% (196/666 tests)
# -------------------
# Registration : 62% (20/32 tests)
# Login : 7% (1/15 tests)
# V1 CS APIs : 10% (3/30 tests)
# ...
# or in verbose mode:
# Client-Server APIs: 29% (196/666 tests)
# -------------------
# Registration : 62% (20/32 tests)
# ✓ GET /register yields a set of flows
# ✓ POST /register can create a user
# ✓ POST /register downcases capitals in usernames
# ...
def print_stats(header_name, gid_to_tests, gid_to_name, verbose):
subsections = [] # Registration: 100% (13/13 tests)
subsection_test_names = {} # 'subsection name': ["✓ Test 1", "✓ Test 2", "× Test 3"]
total_passing = 0
total_tests = 0
for gid, tests in gid_to_tests.items():
group_total = len(tests)
group_passing = 0
test_names_and_marks = []
for name, passing in tests.items():
if passing:
group_passing += 1
test_names_and_marks.append(f"{'' if passing else '×'} {name}")
total_tests += group_total
total_passing += group_passing
pct = "{0:.0f}%".format(group_passing/group_total * 100)
line = "%s: %s (%d/%d tests)" % (gid_to_name[gid].ljust(25, ' '), pct.rjust(4, ' '), group_passing, group_total)
subsections.append(line)
subsection_test_names[line] = test_names_and_marks
pct = "{0:.0f}%".format(total_passing/total_tests * 100)
print("%s: %s (%d/%d tests)" % (header_name, pct, total_passing, total_tests))
print("-" * (len(header_name)+1))
for line in subsections:
print(" %s" % (line,))
if verbose:
for test_name_and_pass_mark in subsection_test_names[line]:
print(" %s" % (test_name_and_pass_mark,))
print("")
print("")
def main(results_tap_path, verbose):
# Load up test mappings
test_name_to_group_id = {}
fed_tests = set()
client_tests = set()
with open("./are-we-synapse-yet.list", "r") as f:
for line in f.readlines():
test_name = " ".join(line.split(" ")[1:]).strip()
groups = line.split(" ")[0].split(",")
for gid in groups:
if gid == "f" or gid in test_mappings["federation_apis"]:
fed_tests.add(test_name)
else:
client_tests.add(test_name)
if gid == "f":
continue # we expect another group ID
test_name_to_group_id[test_name] = gid
# parse results.tap
summary = {
"client": {
# gid: {
# test_name: OK
# }
},
"federation": {
# gid: {
# test_name: OK
# }
},
"nonspec": {
"nsp": {}
},
}
with open(results_tap_path, "r") as f:
for line in f.readlines():
test_result = parse_test_line(line)
if not test_result:
continue
name = test_result["name"]
group_id = test_name_to_group_id.get(name)
if not group_id:
raise Exception("The test '%s' doesn't have a group" % (name,))
if group_id == "nsp":
summary["nonspec"]["nsp"][name] = test_result["ok"]
elif group_id in test_mappings["federation_apis"]:
group = summary["federation"].get(group_id, {})
group[name] = test_result["ok"]
summary["federation"][group_id] = group
elif group_id in test_mappings["client_apis"]:
group = summary["client"].get(group_id, {})
group[name] = test_result["ok"]
summary["client"][group_id] = group
print("Are We Synapse Yet?")
print("===================")
print("")
print_stats("Non-Spec APIs", summary["nonspec"], test_mappings, verbose)
print_stats("Client-Server APIs", summary["client"], test_mappings["client_apis"], verbose)
print_stats("Federation APIs", summary["federation"], test_mappings["federation_apis"], verbose)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("tap_file", help="path to results.tap")
parser.add_argument("-v", action="store_true", help="show individual test names in output")
args = parser.parse_args()
main(args.tap_file, args.v)

View file

@ -1,3 +1,6 @@
#!/bin/sh #!/bin/bash -eu
GOBIN=$PWD/`dirname $0`/bin go install -v $PWD/`dirname $0`/cmd/... # Put installed packages into ./bin
export GOBIN=$PWD/`dirname $0`/bin
go install -v $PWD/`dirname $0`/cmd/...

View file

@ -26,7 +26,6 @@ import (
"github.com/matrix-org/dendrite/appservice/types" "github.com/matrix-org/dendrite/appservice/types"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/common/config" "github.com/matrix-org/dendrite/common/config"
@ -166,7 +165,8 @@ func verifyAccessToken(req *http.Request, deviceDB DeviceDatabase) (device *auth
JSON: jsonerror.UnknownToken("Unknown token"), JSON: jsonerror.UnknownToken("Unknown token"),
} }
} else { } else {
jsonErr := httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("deviceDB.GetDeviceByAccessToken failed")
jsonErr := jsonerror.InternalServerError()
resErr = &jsonErr resErr = &jsonErr
} }
} }

View file

@ -26,4 +26,5 @@ type Device struct {
// associated with access tokens. // associated with access tokens.
SessionID int64 SessionID int64
// TODO: display name, last used timestamp, keys, etc // TODO: display name, last used timestamp, keys, etc
DisplayName string
} }

View file

@ -0,0 +1,54 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package accounts
import (
"context"
"errors"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/gomatrixserverlib"
)
type Database interface {
common.PartitionStorer
GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*authtypes.Account, error)
GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
SetAvatarURL(ctx context.Context, localpart string, avatarURL string) error
SetDisplayName(ctx context.Context, localpart string, displayName string) error
CreateAccount(ctx context.Context, localpart, plaintextPassword, appserviceID string) (*authtypes.Account, error)
CreateGuestAccount(ctx context.Context) (*authtypes.Account, error)
UpdateMemberships(ctx context.Context, eventsToAdd []gomatrixserverlib.Event, idsToRemove []string) error
GetMembershipInRoomByLocalpart(ctx context.Context, localpart, roomID string) (authtypes.Membership, error)
GetRoomIDsByLocalPart(ctx context.Context, localpart string) ([]string, error)
GetMembershipsByLocalpart(ctx context.Context, localpart string) (memberships []authtypes.Membership, err error)
SaveAccountData(ctx context.Context, localpart, roomID, dataType, content string) error
GetAccountData(ctx context.Context, localpart string) (global []gomatrixserverlib.ClientEvent, rooms map[string][]gomatrixserverlib.ClientEvent, err error)
GetAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data *gomatrixserverlib.ClientEvent, err error)
GetNewNumericLocalpart(ctx context.Context) (int64, error)
SaveThreePIDAssociation(ctx context.Context, threepid, localpart, medium string) (err error)
RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error)
GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, err error)
GetThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error)
GetFilter(ctx context.Context, localpart string, filterID string) (*gomatrixserverlib.Filter, error)
PutFilter(ctx context.Context, localpart string, filter *gomatrixserverlib.Filter) (string, error)
CheckAccountAvailability(ctx context.Context, localpart string) (bool, error)
GetAccountByLocalpart(ctx context.Context, localpart string) (*authtypes.Account, error)
}
// Err3PIDInUse is the error returned when trying to save an association involving
// a third-party identifier which is already associated to a local user.
var Err3PIDInUse = errors.New("This third-party identifier is already in use")

View file

@ -0,0 +1,143 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/gomatrixserverlib"
)
const accountDataSchema = `
-- Stores data about accounts data.
CREATE TABLE IF NOT EXISTS account_data (
-- The Matrix user ID localpart for this account
localpart TEXT NOT NULL,
-- The room ID for this data (empty string if not specific to a room)
room_id TEXT,
-- The account data type
type TEXT NOT NULL,
-- The account data content
content TEXT NOT NULL,
PRIMARY KEY(localpart, room_id, type)
);
`
const insertAccountDataSQL = `
INSERT INTO account_data(localpart, room_id, type, content) VALUES($1, $2, $3, $4)
ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = EXCLUDED.content
`
const selectAccountDataSQL = "" +
"SELECT room_id, type, content FROM account_data WHERE localpart = $1"
const selectAccountDataByTypeSQL = "" +
"SELECT content FROM account_data WHERE localpart = $1 AND room_id = $2 AND type = $3"
type accountDataStatements struct {
insertAccountDataStmt *sql.Stmt
selectAccountDataStmt *sql.Stmt
selectAccountDataByTypeStmt *sql.Stmt
}
func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(accountDataSchema)
if err != nil {
return
}
if s.insertAccountDataStmt, err = db.Prepare(insertAccountDataSQL); err != nil {
return
}
if s.selectAccountDataStmt, err = db.Prepare(selectAccountDataSQL); err != nil {
return
}
if s.selectAccountDataByTypeStmt, err = db.Prepare(selectAccountDataByTypeSQL); err != nil {
return
}
return
}
func (s *accountDataStatements) insertAccountData(
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType, content string,
) (err error) {
stmt := txn.Stmt(s.insertAccountDataStmt)
_, err = stmt.ExecContext(ctx, localpart, roomID, dataType, content)
return
}
func (s *accountDataStatements) selectAccountData(
ctx context.Context, localpart string,
) (
global []gomatrixserverlib.ClientEvent,
rooms map[string][]gomatrixserverlib.ClientEvent,
err error,
) {
rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart)
if err != nil {
return
}
defer common.CloseAndLogIfError(ctx, rows, "selectAccountData: rows.close() failed")
global = []gomatrixserverlib.ClientEvent{}
rooms = make(map[string][]gomatrixserverlib.ClientEvent)
for rows.Next() {
var roomID string
var dataType string
var content []byte
if err = rows.Scan(&roomID, &dataType, &content); err != nil {
return
}
ac := gomatrixserverlib.ClientEvent{
Type: dataType,
Content: content,
}
if len(roomID) > 0 {
rooms[roomID] = append(rooms[roomID], ac)
} else {
global = append(global, ac)
}
}
return global, rooms, rows.Err()
}
func (s *accountDataStatements) selectAccountDataByType(
ctx context.Context, localpart, roomID, dataType string,
) (data *gomatrixserverlib.ClientEvent, err error) {
stmt := s.selectAccountDataByTypeStmt
var content []byte
if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&content); err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return
}
data = &gomatrixserverlib.ClientEvent{
Type: dataType,
Content: content,
}
return
}

View file

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package accounts package postgres
import ( import (
"context" "context"
@ -91,10 +91,10 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
// this account will be passwordless. Returns an error if this account already exists. Returns the account // this account will be passwordless. Returns an error if this account already exists. Returns the account
// on success. // on success.
func (s *accountsStatements) insertAccount( func (s *accountsStatements) insertAccount(
ctx context.Context, localpart, hash, appserviceID string, ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string,
) (*authtypes.Account, error) { ) (*authtypes.Account, error) {
createdTimeMS := time.Now().UnixNano() / 1000000 createdTimeMS := time.Now().UnixNano() / 1000000
stmt := s.insertAccountStmt stmt := txn.Stmt(s.insertAccountStmt)
var err error var err error
if appserviceID == "" { if appserviceID == "" {
@ -146,8 +146,12 @@ func (s *accountsStatements) selectAccountByLocalpart(
} }
func (s *accountsStatements) selectNewNumericLocalpart( func (s *accountsStatements) selectNewNumericLocalpart(
ctx context.Context, ctx context.Context, txn *sql.Tx,
) (id int64, err error) { ) (id int64, err error) {
err = s.selectNewNumericLocalpartStmt.QueryRowContext(ctx).Scan(&id) stmt := s.selectNewNumericLocalpartStmt
if txn != nil {
stmt = txn.Stmt(stmt)
}
err = stmt.QueryRowContext(ctx).Scan(&id)
return return
} }

View file

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package accounts package postgres
import ( import (
"context" "context"

View file

@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package accounts package postgres
import ( import (
"context" "context"
"database/sql" "database/sql"
"github.com/matrix-org/dendrite/common"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
) )
@ -51,6 +53,9 @@ const selectMembershipsByLocalpartSQL = "" +
const selectMembershipInRoomByLocalpartSQL = "" + const selectMembershipInRoomByLocalpartSQL = "" +
"SELECT event_id FROM account_memberships WHERE localpart = $1 AND room_id = $2" "SELECT event_id FROM account_memberships WHERE localpart = $1 AND room_id = $2"
const selectRoomIDsByLocalPartSQL = "" +
"SELECT room_id FROM account_memberships WHERE localpart = $1"
const deleteMembershipsByEventIDsSQL = "" + const deleteMembershipsByEventIDsSQL = "" +
"DELETE FROM account_memberships WHERE event_id = ANY($1)" "DELETE FROM account_memberships WHERE event_id = ANY($1)"
@ -59,6 +64,7 @@ type membershipStatements struct {
insertMembershipStmt *sql.Stmt insertMembershipStmt *sql.Stmt
selectMembershipInRoomByLocalpartStmt *sql.Stmt selectMembershipInRoomByLocalpartStmt *sql.Stmt
selectMembershipsByLocalpartStmt *sql.Stmt selectMembershipsByLocalpartStmt *sql.Stmt
selectRoomIDsByLocalPartStmt *sql.Stmt
} }
func (s *membershipStatements) prepare(db *sql.DB) (err error) { func (s *membershipStatements) prepare(db *sql.DB) (err error) {
@ -78,6 +84,9 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) {
if s.selectMembershipsByLocalpartStmt, err = db.Prepare(selectMembershipsByLocalpartSQL); err != nil { if s.selectMembershipsByLocalpartStmt, err = db.Prepare(selectMembershipsByLocalpartSQL); err != nil {
return return
} }
if s.selectRoomIDsByLocalPartStmt, err = db.Prepare(selectRoomIDsByLocalPartSQL); err != nil {
return
}
return return
} }
@ -118,15 +127,34 @@ func (s *membershipStatements) selectMembershipsByLocalpart(
memberships = []authtypes.Membership{} memberships = []authtypes.Membership{}
defer rows.Close() // nolint: errcheck defer common.CloseAndLogIfError(ctx, rows, "selectMembershipsByLocalpart: rows.close() failed")
for rows.Next() { for rows.Next() {
var m authtypes.Membership var m authtypes.Membership
m.Localpart = localpart m.Localpart = localpart
if err := rows.Scan(&m.RoomID, &m.EventID); err != nil { if err = rows.Scan(&m.RoomID, &m.EventID); err != nil {
return nil, err return
} }
memberships = append(memberships, m) memberships = append(memberships, m)
} }
return memberships, rows.Err()
return }
func (s *membershipStatements) selectRoomIDsByLocalPart(
ctx context.Context, localPart string,
) ([]string, error) {
stmt := s.selectRoomIDsByLocalPartStmt
rows, err := stmt.QueryContext(ctx, localPart)
if err != nil {
return nil, err
}
roomIDs := []string{}
defer rows.Close() // nolint: errcheck
for rows.Next() {
var roomID string
if err = rows.Scan(&roomID); err != nil {
return nil, err
}
roomIDs = append(roomIDs, roomID)
}
return roomIDs, rows.Err()
} }

View file

@ -0,0 +1,107 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
)
const profilesSchema = `
-- Stores data about accounts profiles.
CREATE TABLE IF NOT EXISTS account_profiles (
-- The Matrix user ID localpart for this account
localpart TEXT NOT NULL PRIMARY KEY,
-- The display name for this account
display_name TEXT,
-- The URL of the avatar for this account
avatar_url TEXT
);
`
const insertProfileSQL = "" +
"INSERT INTO account_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)"
const selectProfileByLocalpartSQL = "" +
"SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart = $1"
const setAvatarURLSQL = "" +
"UPDATE account_profiles SET avatar_url = $1 WHERE localpart = $2"
const setDisplayNameSQL = "" +
"UPDATE account_profiles SET display_name = $1 WHERE localpart = $2"
type profilesStatements struct {
insertProfileStmt *sql.Stmt
selectProfileByLocalpartStmt *sql.Stmt
setAvatarURLStmt *sql.Stmt
setDisplayNameStmt *sql.Stmt
}
func (s *profilesStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(profilesSchema)
if err != nil {
return
}
if s.insertProfileStmt, err = db.Prepare(insertProfileSQL); err != nil {
return
}
if s.selectProfileByLocalpartStmt, err = db.Prepare(selectProfileByLocalpartSQL); err != nil {
return
}
if s.setAvatarURLStmt, err = db.Prepare(setAvatarURLSQL); err != nil {
return
}
if s.setDisplayNameStmt, err = db.Prepare(setDisplayNameSQL); err != nil {
return
}
return
}
func (s *profilesStatements) insertProfile(
ctx context.Context, txn *sql.Tx, localpart string,
) (err error) {
_, err = txn.Stmt(s.insertProfileStmt).ExecContext(ctx, localpart, "", "")
return
}
func (s *profilesStatements) selectProfileByLocalpart(
ctx context.Context, localpart string,
) (*authtypes.Profile, error) {
var profile authtypes.Profile
err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart).Scan(
&profile.Localpart, &profile.DisplayName, &profile.AvatarURL,
)
if err != nil {
return nil, err
}
return &profile, nil
}
func (s *profilesStatements) setAvatarURL(
ctx context.Context, localpart string, avatarURL string,
) (err error) {
_, err = s.setAvatarURLStmt.ExecContext(ctx, avatarURL, localpart)
return
}
func (s *profilesStatements) setDisplayName(
ctx context.Context, localpart string, displayName string,
) (err error) {
_, err = s.setDisplayNameStmt.ExecContext(ctx, displayName, localpart)
return
}

View file

@ -0,0 +1,433 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"errors"
"strconv"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
"golang.org/x/crypto/bcrypt"
// Import the postgres database driver.
_ "github.com/lib/pq"
)
// Database represents an account database
type Database struct {
db *sql.DB
common.PartitionOffsetStatements
accounts accountsStatements
profiles profilesStatements
memberships membershipStatements
accountDatas accountDataStatements
threepids threepidStatements
filter filterStatements
serverName gomatrixserverlib.ServerName
}
// NewDatabase creates a new accounts and profiles database
func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (*Database, error) {
var db *sql.DB
var err error
if db, err = sqlutil.Open("postgres", dataSourceName); err != nil {
return nil, err
}
partitions := common.PartitionOffsetStatements{}
if err = partitions.Prepare(db, "account"); err != nil {
return nil, err
}
a := accountsStatements{}
if err = a.prepare(db, serverName); err != nil {
return nil, err
}
p := profilesStatements{}
if err = p.prepare(db); err != nil {
return nil, err
}
m := membershipStatements{}
if err = m.prepare(db); err != nil {
return nil, err
}
ac := accountDataStatements{}
if err = ac.prepare(db); err != nil {
return nil, err
}
t := threepidStatements{}
if err = t.prepare(db); err != nil {
return nil, err
}
f := filterStatements{}
if err = f.prepare(db); err != nil {
return nil, err
}
return &Database{db, partitions, a, p, m, ac, t, f, serverName}, nil
}
// GetAccountByPassword returns the account associated with the given localpart and password.
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
func (d *Database) GetAccountByPassword(
ctx context.Context, localpart, plaintextPassword string,
) (*authtypes.Account, error) {
hash, err := d.accounts.selectPasswordHash(ctx, localpart)
if err != nil {
return nil, err
}
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil {
return nil, err
}
return d.accounts.selectAccountByLocalpart(ctx, localpart)
}
// GetProfileByLocalpart returns the profile associated with the given localpart.
// Returns sql.ErrNoRows if no profile exists which matches the given localpart.
func (d *Database) GetProfileByLocalpart(
ctx context.Context, localpart string,
) (*authtypes.Profile, error) {
return d.profiles.selectProfileByLocalpart(ctx, localpart)
}
// SetAvatarURL updates the avatar URL of the profile associated with the given
// localpart. Returns an error if something went wrong with the SQL query
func (d *Database) SetAvatarURL(
ctx context.Context, localpart string, avatarURL string,
) error {
return d.profiles.setAvatarURL(ctx, localpart, avatarURL)
}
// SetDisplayName updates the display name of the profile associated with the given
// localpart. Returns an error if something went wrong with the SQL query
func (d *Database) SetDisplayName(
ctx context.Context, localpart string, displayName string,
) error {
return d.profiles.setDisplayName(ctx, localpart, displayName)
}
// CreateGuestAccount makes a new guest account and creates an empty profile
// for this account.
func (d *Database) CreateGuestAccount(ctx context.Context) (acc *authtypes.Account, err error) {
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
var numLocalpart int64
numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx, txn)
if err != nil {
return err
}
localpart := strconv.FormatInt(numLocalpart, 10)
acc, err = d.createAccount(ctx, txn, localpart, "", "")
return err
})
return acc, err
}
// CreateAccount makes a new account with the given login name and password, and creates an empty profile
// for this account. If no password is supplied, the account will be a passwordless account. If the
// account already exists, it will return nil, nil.
func (d *Database) CreateAccount(
ctx context.Context, localpart, plaintextPassword, appserviceID string,
) (acc *authtypes.Account, err error) {
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID)
return err
})
return
}
func (d *Database) createAccount(
ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string,
) (*authtypes.Account, error) {
var err error
// Generate a password hash if this is not a password-less user
hash := ""
if plaintextPassword != "" {
hash, err = hashPassword(plaintextPassword)
if err != nil {
return nil, err
}
}
if err := d.profiles.insertProfile(ctx, txn, localpart); err != nil {
if common.IsUniqueConstraintViolationErr(err) {
return nil, nil
}
return nil, err
}
if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", `{
"global": {
"content": [],
"override": [],
"room": [],
"sender": [],
"underride": []
}
}`); err != nil {
return nil, err
}
return d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID)
}
// SaveMembership saves the user matching a given localpart as a member of a given
// room. It also stores the ID of the membership event.
// If a membership already exists between the user and the room, or if the
// insert fails, returns the SQL error
func (d *Database) saveMembership(
ctx context.Context, txn *sql.Tx, localpart, roomID, eventID string,
) error {
return d.memberships.insertMembership(ctx, txn, localpart, roomID, eventID)
}
// removeMembershipsByEventIDs removes the memberships corresponding to the
// `join` membership events IDs in the eventIDs slice.
// If the removal fails, or if there is no membership to remove, returns an error
func (d *Database) removeMembershipsByEventIDs(
ctx context.Context, txn *sql.Tx, eventIDs []string,
) error {
return d.memberships.deleteMembershipsByEventIDs(ctx, txn, eventIDs)
}
// UpdateMemberships adds the "join" membership events included in a given state
// events array, and removes those which ID is included in a given array of events
// IDs. All of the process is run in a transaction, which commits only once/if every
// insertion and deletion has been successfully processed.
// Returns a SQL error if there was an issue with any part of the process
func (d *Database) UpdateMemberships(
ctx context.Context, eventsToAdd []gomatrixserverlib.Event, idsToRemove []string,
) error {
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
if err := d.removeMembershipsByEventIDs(ctx, txn, idsToRemove); err != nil {
return err
}
for _, event := range eventsToAdd {
if err := d.newMembership(ctx, txn, event); err != nil {
return err
}
}
return nil
})
}
// GetMembershipInRoomByLocalpart returns the membership for an user
// matching the given localpart if he is a member of the room matching roomID,
// if not sql.ErrNoRows is returned.
// If there was an issue during the retrieval, returns the SQL error
func (d *Database) GetMembershipInRoomByLocalpart(
ctx context.Context, localpart, roomID string,
) (authtypes.Membership, error) {
return d.memberships.selectMembershipInRoomByLocalpart(ctx, localpart, roomID)
}
// GetRoomIDsByLocalPart returns an array containing the room ids of all
// the rooms a user matching a given localpart is a member of
// If no membership match the given localpart, returns an empty array
// If there was an issue during the retrieval, returns the SQL error
func (d *Database) GetRoomIDsByLocalPart(
ctx context.Context, localpart string,
) ([]string, error) {
return d.memberships.selectRoomIDsByLocalPart(ctx, localpart)
}
// GetMembershipsByLocalpart returns an array containing the memberships for all
// the rooms a user matching a given localpart is a member of
// If no membership match the given localpart, returns an empty array
// If there was an issue during the retrieval, returns the SQL error
func (d *Database) GetMembershipsByLocalpart(
ctx context.Context, localpart string,
) (memberships []authtypes.Membership, err error) {
return d.memberships.selectMembershipsByLocalpart(ctx, localpart)
}
// newMembership saves a new membership in the database.
// If the event isn't a valid m.room.member event with type `join`, does nothing.
// If an error occurred, returns the SQL error
func (d *Database) newMembership(
ctx context.Context, txn *sql.Tx, ev gomatrixserverlib.Event,
) error {
if ev.Type() == "m.room.member" && ev.StateKey() != nil {
localpart, serverName, err := gomatrixserverlib.SplitID('@', *ev.StateKey())
if err != nil {
return err
}
// We only want state events from local users
if string(serverName) != string(d.serverName) {
return nil
}
eventID := ev.EventID()
roomID := ev.RoomID()
membership, err := ev.Membership()
if err != nil {
return err
}
// Only "join" membership events can be considered as new memberships
if membership == gomatrixserverlib.Join {
if err := d.saveMembership(ctx, txn, localpart, roomID, eventID); err != nil {
return err
}
}
}
return nil
}
// SaveAccountData saves new account data for a given user and a given room.
// If the account data is not specific to a room, the room ID should be an empty string
// If an account data already exists for a given set (user, room, data type), it will
// update the corresponding row with the new content
// Returns a SQL error if there was an issue with the insertion/update
func (d *Database) SaveAccountData(
ctx context.Context, localpart, roomID, dataType, content string,
) error {
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content)
})
}
// GetAccountData returns account data related to a given localpart
// If no account data could be found, returns an empty arrays
// Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountData(ctx context.Context, localpart string) (
global []gomatrixserverlib.ClientEvent,
rooms map[string][]gomatrixserverlib.ClientEvent,
err error,
) {
return d.accountDatas.selectAccountData(ctx, localpart)
}
// GetAccountDataByType returns account data matching a given
// localpart, room ID and type.
// If no account data could be found, returns nil
// Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountDataByType(
ctx context.Context, localpart, roomID, dataType string,
) (data *gomatrixserverlib.ClientEvent, err error) {
return d.accountDatas.selectAccountDataByType(
ctx, localpart, roomID, dataType,
)
}
// GetNewNumericLocalpart generates and returns a new unused numeric localpart
func (d *Database) GetNewNumericLocalpart(
ctx context.Context,
) (int64, error) {
return d.accounts.selectNewNumericLocalpart(ctx, nil)
}
func hashPassword(plaintext string) (hash string, err error) {
hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), bcrypt.DefaultCost)
return string(hashBytes), err
}
// Err3PIDInUse is the error returned when trying to save an association involving
// a third-party identifier which is already associated to a local user.
var Err3PIDInUse = errors.New("This third-party identifier is already in use")
// SaveThreePIDAssociation saves the association between a third party identifier
// and a local Matrix user (identified by the user's ID's local part).
// If the third-party identifier is already part of an association, returns Err3PIDInUse.
// Returns an error if there was a problem talking to the database.
func (d *Database) SaveThreePIDAssociation(
ctx context.Context, threepid, localpart, medium string,
) (err error) {
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
user, err := d.threepids.selectLocalpartForThreePID(
ctx, txn, threepid, medium,
)
if err != nil {
return err
}
if len(user) > 0 {
return Err3PIDInUse
}
return d.threepids.insertThreePID(ctx, txn, threepid, medium, localpart)
})
}
// RemoveThreePIDAssociation removes the association involving a given third-party
// identifier.
// If no association exists involving this third-party identifier, returns nothing.
// If there was a problem talking to the database, returns an error.
func (d *Database) RemoveThreePIDAssociation(
ctx context.Context, threepid string, medium string,
) (err error) {
return d.threepids.deleteThreePID(ctx, threepid, medium)
}
// GetLocalpartForThreePID looks up the localpart associated with a given third-party
// identifier.
// If no association involves the given third-party idenfitier, returns an empty
// string.
// Returns an error if there was a problem talking to the database.
func (d *Database) GetLocalpartForThreePID(
ctx context.Context, threepid string, medium string,
) (localpart string, err error) {
return d.threepids.selectLocalpartForThreePID(ctx, nil, threepid, medium)
}
// GetThreePIDsForLocalpart looks up the third-party identifiers associated with
// a given local user.
// If no association is known for this user, returns an empty slice.
// Returns an error if there was an issue talking to the database.
func (d *Database) GetThreePIDsForLocalpart(
ctx context.Context, localpart string,
) (threepids []authtypes.ThreePID, err error) {
return d.threepids.selectThreePIDsForLocalpart(ctx, localpart)
}
// GetFilter looks up the filter associated with a given local user and filter ID.
// Returns a filter structure. Otherwise returns an error if no such filter exists
// or if there was an error talking to the database.
func (d *Database) GetFilter(
ctx context.Context, localpart string, filterID string,
) (*gomatrixserverlib.Filter, error) {
return d.filter.selectFilter(ctx, localpart, filterID)
}
// PutFilter puts the passed filter into the database.
// Returns the filterID as a string. Otherwise returns an error if something
// goes wrong.
func (d *Database) PutFilter(
ctx context.Context, localpart string, filter *gomatrixserverlib.Filter,
) (string, error) {
return d.filter.insertFilter(ctx, filter, localpart)
}
// CheckAccountAvailability checks if the username/localpart is already present
// in the database.
// If the DB returns sql.ErrNoRows the Localpart isn't taken.
func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) {
_, err := d.accounts.selectAccountByLocalpart(ctx, localpart)
if err == sql.ErrNoRows {
return true, nil
}
return false, err
}
// GetAccountByLocalpart returns the account associated with the given localpart.
// This function assumes the request is authenticated or the account data is used only internally.
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string,
) (*authtypes.Account, error) {
return d.accounts.selectAccountByLocalpart(ctx, localpart)
}

View file

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package accounts package postgres
import ( import (
"context" "context"

View file

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package accounts package sqlite3
import ( import (
"context" "context"
@ -39,7 +39,7 @@ CREATE TABLE IF NOT EXISTS account_data (
const insertAccountDataSQL = ` const insertAccountDataSQL = `
INSERT INTO account_data(localpart, room_id, type, content) VALUES($1, $2, $3, $4) INSERT INTO account_data(localpart, room_id, type, content) VALUES($1, $2, $3, $4)
ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = EXCLUDED.content ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = $4
` `
const selectAccountDataSQL = "" + const selectAccountDataSQL = "" +
@ -72,10 +72,9 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
} }
func (s *accountDataStatements) insertAccountData( func (s *accountDataStatements) insertAccountData(
ctx context.Context, localpart, roomID, dataType, content string, ctx context.Context, txn *sql.Tx, localpart, roomID, dataType, content string,
) (err error) { ) (err error) {
stmt := s.insertAccountDataStmt _, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content)
_, err = stmt.ExecContext(ctx, localpart, roomID, dataType, content)
return return
} }

View file

@ -0,0 +1,155 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"time"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus"
)
const accountsSchema = `
-- Stores data about accounts.
CREATE TABLE IF NOT EXISTS account_accounts (
-- The Matrix user ID localpart for this account
localpart TEXT NOT NULL PRIMARY KEY,
-- When this account was first created, as a unix timestamp (ms resolution).
created_ts BIGINT NOT NULL,
-- The password hash for this account. Can be NULL if this is a passwordless account.
password_hash TEXT,
-- Identifies which application service this account belongs to, if any.
appservice_id TEXT
-- TODO:
-- is_guest, is_admin, upgraded_ts, devices, any email reset stuff?
);
`
const insertAccountSQL = "" +
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)"
const selectAccountByLocalpartSQL = "" +
"SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1"
const selectPasswordHashSQL = "" +
"SELECT password_hash FROM account_accounts WHERE localpart = $1"
const selectNewNumericLocalpartSQL = "" +
"SELECT COUNT(localpart) FROM account_accounts"
// TODO: Update password
type accountsStatements struct {
insertAccountStmt *sql.Stmt
selectAccountByLocalpartStmt *sql.Stmt
selectPasswordHashStmt *sql.Stmt
selectNewNumericLocalpartStmt *sql.Stmt
serverName gomatrixserverlib.ServerName
}
func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
_, err = db.Exec(accountsSchema)
if err != nil {
return
}
if s.insertAccountStmt, err = db.Prepare(insertAccountSQL); err != nil {
return
}
if s.selectAccountByLocalpartStmt, err = db.Prepare(selectAccountByLocalpartSQL); err != nil {
return
}
if s.selectPasswordHashStmt, err = db.Prepare(selectPasswordHashSQL); err != nil {
return
}
if s.selectNewNumericLocalpartStmt, err = db.Prepare(selectNewNumericLocalpartSQL); err != nil {
return
}
s.serverName = server
return
}
// insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing,
// this account will be passwordless. Returns an error if this account already exists. Returns the account
// on success.
func (s *accountsStatements) insertAccount(
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string,
) (*authtypes.Account, error) {
createdTimeMS := time.Now().UnixNano() / 1000000
stmt := s.insertAccountStmt
var err error
if appserviceID == "" {
_, err = txn.Stmt(stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil)
} else {
_, err = txn.Stmt(stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID)
}
if err != nil {
return nil, err
}
return &authtypes.Account{
Localpart: localpart,
UserID: userutil.MakeUserID(localpart, s.serverName),
ServerName: s.serverName,
AppServiceID: appserviceID,
}, nil
}
func (s *accountsStatements) selectPasswordHash(
ctx context.Context, localpart string,
) (hash string, err error) {
err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash)
return
}
func (s *accountsStatements) selectAccountByLocalpart(
ctx context.Context, localpart string,
) (*authtypes.Account, error) {
var appserviceIDPtr sql.NullString
var acc authtypes.Account
stmt := s.selectAccountByLocalpartStmt
err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr)
if err != nil {
if err != sql.ErrNoRows {
log.WithError(err).Error("Unable to retrieve user from the db")
}
return nil, err
}
if appserviceIDPtr.Valid {
acc.AppServiceID = appserviceIDPtr.String
}
acc.UserID = userutil.MakeUserID(localpart, s.serverName)
acc.ServerName = s.serverName
return &acc, nil
}
func (s *accountsStatements) selectNewNumericLocalpart(
ctx context.Context, txn *sql.Tx,
) (id int64, err error) {
stmt := s.selectNewNumericLocalpartStmt
if txn != nil {
stmt = txn.Stmt(stmt)
}
err = stmt.QueryRowContext(ctx).Scan(&id)
return
}

View file

@ -0,0 +1,135 @@
// Copyright 2017 Jan Christian Grünhage
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"github.com/matrix-org/gomatrixserverlib"
)
const filterSchema = `
-- Stores data about filters
CREATE TABLE IF NOT EXISTS account_filter (
-- The filter
filter TEXT NOT NULL,
-- The ID
id INTEGER PRIMARY KEY AUTOINCREMENT,
-- The localpart of the Matrix user ID associated to this filter
localpart TEXT NOT NULL,
UNIQUE (id, localpart)
);
CREATE INDEX IF NOT EXISTS account_filter_localpart ON account_filter(localpart);
`
const selectFilterSQL = "" +
"SELECT filter FROM account_filter WHERE localpart = $1 AND id = $2"
const selectFilterIDByContentSQL = "" +
"SELECT id FROM account_filter WHERE localpart = $1 AND filter = $2"
const insertFilterSQL = "" +
"INSERT INTO account_filter (filter, localpart) VALUES ($1, $2)"
type filterStatements struct {
selectFilterStmt *sql.Stmt
selectFilterIDByContentStmt *sql.Stmt
insertFilterStmt *sql.Stmt
}
func (s *filterStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(filterSchema)
if err != nil {
return
}
if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil {
return
}
if s.selectFilterIDByContentStmt, err = db.Prepare(selectFilterIDByContentSQL); err != nil {
return
}
if s.insertFilterStmt, err = db.Prepare(insertFilterSQL); err != nil {
return
}
return
}
func (s *filterStatements) selectFilter(
ctx context.Context, localpart string, filterID string,
) (*gomatrixserverlib.Filter, error) {
// Retrieve filter from database (stored as canonical JSON)
var filterData []byte
err := s.selectFilterStmt.QueryRowContext(ctx, localpart, filterID).Scan(&filterData)
if err != nil {
return nil, err
}
// Unmarshal JSON into Filter struct
var filter gomatrixserverlib.Filter
if err = json.Unmarshal(filterData, &filter); err != nil {
return nil, err
}
return &filter, nil
}
func (s *filterStatements) insertFilter(
ctx context.Context, filter *gomatrixserverlib.Filter, localpart string,
) (filterID string, err error) {
var existingFilterID string
// Serialise json
filterJSON, err := json.Marshal(filter)
if err != nil {
return "", err
}
// Remove whitespaces and sort JSON data
// needed to prevent from inserting the same filter multiple times
filterJSON, err = gomatrixserverlib.CanonicalJSON(filterJSON)
if err != nil {
return "", err
}
// Check if filter already exists in the database using its localpart and content
//
// This can result in a race condition when two clients try to insert the
// same filter and localpart at the same time, however this is not a
// problem as both calls will result in the same filterID
err = s.selectFilterIDByContentStmt.QueryRowContext(ctx,
localpart, filterJSON).Scan(&existingFilterID)
if err != nil && err != sql.ErrNoRows {
return "", err
}
// If it does, return the existing ID
if existingFilterID != "" {
return existingFilterID, err
}
// Otherwise insert the filter and return the new ID
res, err := s.insertFilterStmt.ExecContext(ctx, filterJSON, localpart)
if err != nil {
return "", err
}
rowid, err := res.LastInsertId()
if err != nil {
return "", err
}
filterID = fmt.Sprintf("%d", rowid)
return
}

View file

@ -0,0 +1,158 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"strings"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/common"
)
const membershipSchema = `
-- Stores data about users memberships to rooms.
CREATE TABLE IF NOT EXISTS account_memberships (
-- The Matrix user ID localpart for the member
localpart TEXT NOT NULL,
-- The room this user is a member of
room_id TEXT NOT NULL,
-- The ID of the join membership event
event_id TEXT NOT NULL,
-- A user can only be member of a room once
PRIMARY KEY (localpart, room_id),
UNIQUE (event_id)
);
`
const insertMembershipSQL = `
INSERT INTO account_memberships(localpart, room_id, event_id) VALUES ($1, $2, $3)
ON CONFLICT (localpart, room_id) DO UPDATE SET event_id = EXCLUDED.event_id
`
const selectMembershipsByLocalpartSQL = "" +
"SELECT room_id, event_id FROM account_memberships WHERE localpart = $1"
const selectMembershipInRoomByLocalpartSQL = "" +
"SELECT event_id FROM account_memberships WHERE localpart = $1 AND room_id = $2"
const selectRoomIDsByLocalPartSQL = "" +
"SELECT room_id FROM account_memberships WHERE localpart = $1"
const deleteMembershipsByEventIDsSQL = "" +
"DELETE FROM account_memberships WHERE event_id IN ($1)"
type membershipStatements struct {
insertMembershipStmt *sql.Stmt
selectMembershipInRoomByLocalpartStmt *sql.Stmt
selectMembershipsByLocalpartStmt *sql.Stmt
selectRoomIDsByLocalPartStmt *sql.Stmt
}
func (s *membershipStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(membershipSchema)
if err != nil {
return
}
if s.insertMembershipStmt, err = db.Prepare(insertMembershipSQL); err != nil {
return
}
if s.selectMembershipInRoomByLocalpartStmt, err = db.Prepare(selectMembershipInRoomByLocalpartSQL); err != nil {
return
}
if s.selectMembershipsByLocalpartStmt, err = db.Prepare(selectMembershipsByLocalpartSQL); err != nil {
return
}
if s.selectRoomIDsByLocalPartStmt, err = db.Prepare(selectRoomIDsByLocalPartSQL); err != nil {
return
}
return
}
func (s *membershipStatements) insertMembership(
ctx context.Context, txn *sql.Tx, localpart, roomID, eventID string,
) (err error) {
stmt := txn.Stmt(s.insertMembershipStmt)
_, err = stmt.ExecContext(ctx, localpart, roomID, eventID)
return
}
func (s *membershipStatements) deleteMembershipsByEventIDs(
ctx context.Context, txn *sql.Tx, eventIDs []string,
) (err error) {
sqlStr := strings.Replace(deleteMembershipsByEventIDsSQL, "($1)", common.QueryVariadic(len(eventIDs)), 1)
iEventIDs := make([]interface{}, len(eventIDs))
for i, e := range eventIDs {
iEventIDs[i] = e
}
_, err = txn.ExecContext(ctx, sqlStr, iEventIDs...)
return
}
func (s *membershipStatements) selectMembershipInRoomByLocalpart(
ctx context.Context, localpart, roomID string,
) (authtypes.Membership, error) {
membership := authtypes.Membership{Localpart: localpart, RoomID: roomID}
stmt := s.selectMembershipInRoomByLocalpartStmt
err := stmt.QueryRowContext(ctx, localpart, roomID).Scan(&membership.EventID)
return membership, err
}
func (s *membershipStatements) selectMembershipsByLocalpart(
ctx context.Context, localpart string,
) (memberships []authtypes.Membership, err error) {
stmt := s.selectMembershipsByLocalpartStmt
rows, err := stmt.QueryContext(ctx, localpart)
if err != nil {
return
}
memberships = []authtypes.Membership{}
defer common.CloseAndLogIfError(ctx, rows, "selectMembershipsByLocalpart: rows.close() failed")
for rows.Next() {
var m authtypes.Membership
m.Localpart = localpart
if err := rows.Scan(&m.RoomID, &m.EventID); err != nil {
return nil, err
}
memberships = append(memberships, m)
}
return
}
func (s *membershipStatements) selectRoomIDsByLocalPart(
ctx context.Context, localPart string,
) ([]string, error) {
stmt := s.selectRoomIDsByLocalPartStmt
rows, err := stmt.QueryContext(ctx, localPart)
if err != nil {
return nil, err
}
roomIDs := []string{}
defer rows.Close() // nolint: errcheck
for rows.Next() {
var roomID string
if err = rows.Scan(&roomID); err != nil {
return nil, err
}
roomIDs = append(roomIDs, roomID)
}
return roomIDs, rows.Err()
}

View file

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package accounts package sqlite3
import ( import (
"context" "context"
@ -73,9 +73,9 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) {
} }
func (s *profilesStatements) insertProfile( func (s *profilesStatements) insertProfile(
ctx context.Context, localpart string, ctx context.Context, txn *sql.Tx, localpart string,
) (err error) { ) (err error) {
_, err = s.insertProfileStmt.ExecContext(ctx, localpart, "", "") _, err = txn.Stmt(s.insertProfileStmt).ExecContext(ctx, localpart, "", "")
return return
} }

View file

@ -0,0 +1,442 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"errors"
"strconv"
"sync"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
"golang.org/x/crypto/bcrypt"
// Import the postgres database driver.
_ "github.com/mattn/go-sqlite3"
)
// Database represents an account database
type Database struct {
db *sql.DB
common.PartitionOffsetStatements
accounts accountsStatements
profiles profilesStatements
memberships membershipStatements
accountDatas accountDataStatements
threepids threepidStatements
filter filterStatements
serverName gomatrixserverlib.ServerName
createGuestAccountMu sync.Mutex
}
// NewDatabase creates a new accounts and profiles database
func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (*Database, error) {
var db *sql.DB
var err error
if db, err = sqlutil.Open(common.SQLiteDriverName(), dataSourceName); err != nil {
return nil, err
}
partitions := common.PartitionOffsetStatements{}
if err = partitions.Prepare(db, "account"); err != nil {
return nil, err
}
a := accountsStatements{}
if err = a.prepare(db, serverName); err != nil {
return nil, err
}
p := profilesStatements{}
if err = p.prepare(db); err != nil {
return nil, err
}
m := membershipStatements{}
if err = m.prepare(db); err != nil {
return nil, err
}
ac := accountDataStatements{}
if err = ac.prepare(db); err != nil {
return nil, err
}
t := threepidStatements{}
if err = t.prepare(db); err != nil {
return nil, err
}
f := filterStatements{}
if err = f.prepare(db); err != nil {
return nil, err
}
return &Database{db, partitions, a, p, m, ac, t, f, serverName, sync.Mutex{}}, nil
}
// GetAccountByPassword returns the account associated with the given localpart and password.
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
func (d *Database) GetAccountByPassword(
ctx context.Context, localpart, plaintextPassword string,
) (*authtypes.Account, error) {
hash, err := d.accounts.selectPasswordHash(ctx, localpart)
if err != nil {
return nil, err
}
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil {
return nil, err
}
return d.accounts.selectAccountByLocalpart(ctx, localpart)
}
// GetProfileByLocalpart returns the profile associated with the given localpart.
// Returns sql.ErrNoRows if no profile exists which matches the given localpart.
func (d *Database) GetProfileByLocalpart(
ctx context.Context, localpart string,
) (*authtypes.Profile, error) {
return d.profiles.selectProfileByLocalpart(ctx, localpart)
}
// SetAvatarURL updates the avatar URL of the profile associated with the given
// localpart. Returns an error if something went wrong with the SQL query
func (d *Database) SetAvatarURL(
ctx context.Context, localpart string, avatarURL string,
) error {
return d.profiles.setAvatarURL(ctx, localpart, avatarURL)
}
// SetDisplayName updates the display name of the profile associated with the given
// localpart. Returns an error if something went wrong with the SQL query
func (d *Database) SetDisplayName(
ctx context.Context, localpart string, displayName string,
) error {
return d.profiles.setDisplayName(ctx, localpart, displayName)
}
// CreateGuestAccount makes a new guest account and creates an empty profile
// for this account.
func (d *Database) CreateGuestAccount(ctx context.Context) (acc *authtypes.Account, err error) {
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
// We need to lock so we sequentially create numeric localparts. If we don't, two calls to
// this function will cause the same number to be selected and one will fail with 'database is locked'
// when the first txn upgrades to a write txn.
// We know we'll be the only process since this is sqlite ;) so a lock here will be all that is needed.
d.createGuestAccountMu.Lock()
defer d.createGuestAccountMu.Unlock()
var numLocalpart int64
numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx, txn)
if err != nil {
return err
}
localpart := strconv.FormatInt(numLocalpart, 10)
acc, err = d.createAccount(ctx, txn, localpart, "", "")
return err
})
return acc, err
}
// CreateAccount makes a new account with the given login name and password, and creates an empty profile
// for this account. If no password is supplied, the account will be a passwordless account. If the
// account already exists, it will return nil, nil.
func (d *Database) CreateAccount(
ctx context.Context, localpart, plaintextPassword, appserviceID string,
) (acc *authtypes.Account, err error) {
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID)
return err
})
return
}
func (d *Database) createAccount(
ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string,
) (*authtypes.Account, error) {
var err error
// Generate a password hash if this is not a password-less user
hash := ""
if plaintextPassword != "" {
hash, err = hashPassword(plaintextPassword)
if err != nil {
return nil, err
}
}
if err := d.profiles.insertProfile(ctx, txn, localpart); err != nil {
if common.IsUniqueConstraintViolationErr(err) {
return nil, nil
}
return nil, err
}
if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", `{
"global": {
"content": [],
"override": [],
"room": [],
"sender": [],
"underride": []
}
}`); err != nil {
return nil, err
}
return d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID)
}
// SaveMembership saves the user matching a given localpart as a member of a given
// room. It also stores the ID of the membership event.
// If a membership already exists between the user and the room, or if the
// insert fails, returns the SQL error
func (d *Database) saveMembership(
ctx context.Context, txn *sql.Tx, localpart, roomID, eventID string,
) error {
return d.memberships.insertMembership(ctx, txn, localpart, roomID, eventID)
}
// removeMembershipsByEventIDs removes the memberships corresponding to the
// `join` membership events IDs in the eventIDs slice.
// If the removal fails, or if there is no membership to remove, returns an error
func (d *Database) removeMembershipsByEventIDs(
ctx context.Context, txn *sql.Tx, eventIDs []string,
) error {
return d.memberships.deleteMembershipsByEventIDs(ctx, txn, eventIDs)
}
// UpdateMemberships adds the "join" membership events included in a given state
// events array, and removes those which ID is included in a given array of events
// IDs. All of the process is run in a transaction, which commits only once/if every
// insertion and deletion has been successfully processed.
// Returns a SQL error if there was an issue with any part of the process
func (d *Database) UpdateMemberships(
ctx context.Context, eventsToAdd []gomatrixserverlib.Event, idsToRemove []string,
) error {
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
if err := d.removeMembershipsByEventIDs(ctx, txn, idsToRemove); err != nil {
return err
}
for _, event := range eventsToAdd {
if err := d.newMembership(ctx, txn, event); err != nil {
return err
}
}
return nil
})
}
// GetMembershipInRoomByLocalpart returns the membership for an user
// matching the given localpart if he is a member of the room matching roomID,
// if not sql.ErrNoRows is returned.
// If there was an issue during the retrieval, returns the SQL error
func (d *Database) GetMembershipInRoomByLocalpart(
ctx context.Context, localpart, roomID string,
) (authtypes.Membership, error) {
return d.memberships.selectMembershipInRoomByLocalpart(ctx, localpart, roomID)
}
// GetMembershipsByLocalpart returns an array containing the memberships for all
// the rooms a user matching a given localpart is a member of
// If no membership match the given localpart, returns an empty array
// If there was an issue during the retrieval, returns the SQL error
func (d *Database) GetMembershipsByLocalpart(
ctx context.Context, localpart string,
) (memberships []authtypes.Membership, err error) {
return d.memberships.selectMembershipsByLocalpart(ctx, localpart)
}
// GetRoomIDsByLocalPart returns an array containing the room ids of all
// the rooms a user matching a given localpart is a member of
// If no membership match the given localpart, returns an empty array
// If there was an issue during the retrieval, returns the SQL error
func (d *Database) GetRoomIDsByLocalPart(
ctx context.Context, localpart string,
) ([]string, error) {
return d.memberships.selectRoomIDsByLocalPart(ctx, localpart)
}
// newMembership saves a new membership in the database.
// If the event isn't a valid m.room.member event with type `join`, does nothing.
// If an error occurred, returns the SQL error
func (d *Database) newMembership(
ctx context.Context, txn *sql.Tx, ev gomatrixserverlib.Event,
) error {
if ev.Type() == "m.room.member" && ev.StateKey() != nil {
localpart, serverName, err := gomatrixserverlib.SplitID('@', *ev.StateKey())
if err != nil {
return err
}
// We only want state events from local users
if string(serverName) != string(d.serverName) {
return nil
}
eventID := ev.EventID()
roomID := ev.RoomID()
membership, err := ev.Membership()
if err != nil {
return err
}
// Only "join" membership events can be considered as new memberships
if membership == gomatrixserverlib.Join {
if err := d.saveMembership(ctx, txn, localpart, roomID, eventID); err != nil {
return err
}
}
}
return nil
}
// SaveAccountData saves new account data for a given user and a given room.
// If the account data is not specific to a room, the room ID should be an empty string
// If an account data already exists for a given set (user, room, data type), it will
// update the corresponding row with the new content
// Returns a SQL error if there was an issue with the insertion/update
func (d *Database) SaveAccountData(
ctx context.Context, localpart, roomID, dataType, content string,
) error {
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content)
})
}
// GetAccountData returns account data related to a given localpart
// If no account data could be found, returns an empty arrays
// Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountData(ctx context.Context, localpart string) (
global []gomatrixserverlib.ClientEvent,
rooms map[string][]gomatrixserverlib.ClientEvent,
err error,
) {
return d.accountDatas.selectAccountData(ctx, localpart)
}
// GetAccountDataByType returns account data matching a given
// localpart, room ID and type.
// If no account data could be found, returns nil
// Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountDataByType(
ctx context.Context, localpart, roomID, dataType string,
) (data *gomatrixserverlib.ClientEvent, err error) {
return d.accountDatas.selectAccountDataByType(
ctx, localpart, roomID, dataType,
)
}
// GetNewNumericLocalpart generates and returns a new unused numeric localpart
func (d *Database) GetNewNumericLocalpart(
ctx context.Context,
) (int64, error) {
return d.accounts.selectNewNumericLocalpart(ctx, nil)
}
func hashPassword(plaintext string) (hash string, err error) {
hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), bcrypt.DefaultCost)
return string(hashBytes), err
}
// Err3PIDInUse is the error returned when trying to save an association involving
// a third-party identifier which is already associated to a local user.
var Err3PIDInUse = errors.New("This third-party identifier is already in use")
// SaveThreePIDAssociation saves the association between a third party identifier
// and a local Matrix user (identified by the user's ID's local part).
// If the third-party identifier is already part of an association, returns Err3PIDInUse.
// Returns an error if there was a problem talking to the database.
func (d *Database) SaveThreePIDAssociation(
ctx context.Context, threepid, localpart, medium string,
) (err error) {
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
user, err := d.threepids.selectLocalpartForThreePID(
ctx, txn, threepid, medium,
)
if err != nil {
return err
}
if len(user) > 0 {
return Err3PIDInUse
}
return d.threepids.insertThreePID(ctx, txn, threepid, medium, localpart)
})
}
// RemoveThreePIDAssociation removes the association involving a given third-party
// identifier.
// If no association exists involving this third-party identifier, returns nothing.
// If there was a problem talking to the database, returns an error.
func (d *Database) RemoveThreePIDAssociation(
ctx context.Context, threepid string, medium string,
) (err error) {
return d.threepids.deleteThreePID(ctx, threepid, medium)
}
// GetLocalpartForThreePID looks up the localpart associated with a given third-party
// identifier.
// If no association involves the given third-party idenfitier, returns an empty
// string.
// Returns an error if there was a problem talking to the database.
func (d *Database) GetLocalpartForThreePID(
ctx context.Context, threepid string, medium string,
) (localpart string, err error) {
return d.threepids.selectLocalpartForThreePID(ctx, nil, threepid, medium)
}
// GetThreePIDsForLocalpart looks up the third-party identifiers associated with
// a given local user.
// If no association is known for this user, returns an empty slice.
// Returns an error if there was an issue talking to the database.
func (d *Database) GetThreePIDsForLocalpart(
ctx context.Context, localpart string,
) (threepids []authtypes.ThreePID, err error) {
return d.threepids.selectThreePIDsForLocalpart(ctx, localpart)
}
// GetFilter looks up the filter associated with a given local user and filter ID.
// Returns a filter structure. Otherwise returns an error if no such filter exists
// or if there was an error talking to the database.
func (d *Database) GetFilter(
ctx context.Context, localpart string, filterID string,
) (*gomatrixserverlib.Filter, error) {
return d.filter.selectFilter(ctx, localpart, filterID)
}
// PutFilter puts the passed filter into the database.
// Returns the filterID as a string. Otherwise returns an error if something
// goes wrong.
func (d *Database) PutFilter(
ctx context.Context, localpart string, filter *gomatrixserverlib.Filter,
) (string, error) {
return d.filter.insertFilter(ctx, filter, localpart)
}
// CheckAccountAvailability checks if the username/localpart is already present
// in the database.
// If the DB returns sql.ErrNoRows the Localpart isn't taken.
func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) {
_, err := d.accounts.selectAccountByLocalpart(ctx, localpart)
if err == sql.ErrNoRows {
return true, nil
}
return false, err
}
// GetAccountByLocalpart returns the account associated with the given localpart.
// This function assumes the request is authenticated or the account data is used only internally.
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string,
) (*authtypes.Account, error) {
return d.accounts.selectAccountByLocalpart(ctx, localpart)
}

View file

@ -0,0 +1,129 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
)
const threepidSchema = `
-- Stores data about third party identifiers
CREATE TABLE IF NOT EXISTS account_threepid (
-- The third party identifier
threepid TEXT NOT NULL,
-- The 3PID medium
medium TEXT NOT NULL DEFAULT 'email',
-- The localpart of the Matrix user ID associated to this 3PID
localpart TEXT NOT NULL,
PRIMARY KEY(threepid, medium)
);
CREATE INDEX IF NOT EXISTS account_threepid_localpart ON account_threepid(localpart);
`
const selectLocalpartForThreePIDSQL = "" +
"SELECT localpart FROM account_threepid WHERE threepid = $1 AND medium = $2"
const selectThreePIDsForLocalpartSQL = "" +
"SELECT threepid, medium FROM account_threepid WHERE localpart = $1"
const insertThreePIDSQL = "" +
"INSERT INTO account_threepid (threepid, medium, localpart) VALUES ($1, $2, $3)"
const deleteThreePIDSQL = "" +
"DELETE FROM account_threepid WHERE threepid = $1 AND medium = $2"
type threepidStatements struct {
selectLocalpartForThreePIDStmt *sql.Stmt
selectThreePIDsForLocalpartStmt *sql.Stmt
insertThreePIDStmt *sql.Stmt
deleteThreePIDStmt *sql.Stmt
}
func (s *threepidStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(threepidSchema)
if err != nil {
return
}
if s.selectLocalpartForThreePIDStmt, err = db.Prepare(selectLocalpartForThreePIDSQL); err != nil {
return
}
if s.selectThreePIDsForLocalpartStmt, err = db.Prepare(selectThreePIDsForLocalpartSQL); err != nil {
return
}
if s.insertThreePIDStmt, err = db.Prepare(insertThreePIDSQL); err != nil {
return
}
if s.deleteThreePIDStmt, err = db.Prepare(deleteThreePIDSQL); err != nil {
return
}
return
}
func (s *threepidStatements) selectLocalpartForThreePID(
ctx context.Context, txn *sql.Tx, threepid string, medium string,
) (localpart string, err error) {
stmt := common.TxStmt(txn, s.selectLocalpartForThreePIDStmt)
err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart)
if err == sql.ErrNoRows {
return "", nil
}
return
}
func (s *threepidStatements) selectThreePIDsForLocalpart(
ctx context.Context, localpart string,
) (threepids []authtypes.ThreePID, err error) {
rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart)
if err != nil {
return
}
defer common.CloseAndLogIfError(ctx, rows, "selectThreePIDsForLocalpart: rows.close() failed")
threepids = []authtypes.ThreePID{}
for rows.Next() {
var threepid string
var medium string
if err = rows.Scan(&threepid, &medium); err != nil {
return
}
threepids = append(threepids, authtypes.ThreePID{
Address: threepid,
Medium: medium,
})
}
return threepids, rows.Err()
}
func (s *threepidStatements) insertThreePID(
ctx context.Context, txn *sql.Tx, threepid, medium, localpart string,
) (err error) {
stmt := common.TxStmt(txn, s.insertThreePIDStmt)
_, err = stmt.ExecContext(ctx, threepid, medium, localpart)
return
}
func (s *threepidStatements) deleteThreePID(
ctx context.Context, threepid string, medium string) (err error) {
_, err = s.deleteThreePIDStmt.ExecContext(ctx, threepid, medium)
return
}

View file

@ -1,4 +1,4 @@
// Copyright 2017 Vector Creations Ltd // Copyright 2020 The Matrix.org Foundation C.I.C.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
@ -12,370 +12,29 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// +build !wasm
package accounts package accounts
import ( import (
"context" "net/url"
"database/sql"
"errors"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/postgres"
"github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/sqlite3"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"golang.org/x/crypto/bcrypt"
// Import the postgres database driver.
_ "github.com/lib/pq"
) )
// Database represents an account database func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (Database, error) {
type Database struct { uri, err := url.Parse(dataSourceName)
db *sql.DB
common.PartitionOffsetStatements
accounts accountsStatements
profiles profilesStatements
memberships membershipStatements
accountDatas accountDataStatements
threepids threepidStatements
filter filterStatements
serverName gomatrixserverlib.ServerName
}
// NewDatabase creates a new accounts and profiles database
func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (*Database, error) {
var db *sql.DB
var err error
if db, err = sql.Open("postgres", dataSourceName); err != nil {
return nil, err
}
partitions := common.PartitionOffsetStatements{}
if err = partitions.Prepare(db, "account"); err != nil {
return nil, err
}
a := accountsStatements{}
if err = a.prepare(db, serverName); err != nil {
return nil, err
}
p := profilesStatements{}
if err = p.prepare(db); err != nil {
return nil, err
}
m := membershipStatements{}
if err = m.prepare(db); err != nil {
return nil, err
}
ac := accountDataStatements{}
if err = ac.prepare(db); err != nil {
return nil, err
}
t := threepidStatements{}
if err = t.prepare(db); err != nil {
return nil, err
}
f := filterStatements{}
if err = f.prepare(db); err != nil {
return nil, err
}
return &Database{db, partitions, a, p, m, ac, t, f, serverName}, nil
}
// GetAccountByPassword returns the account associated with the given localpart and password.
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
func (d *Database) GetAccountByPassword(
ctx context.Context, localpart, plaintextPassword string,
) (*authtypes.Account, error) {
hash, err := d.accounts.selectPasswordHash(ctx, localpart)
if err != nil { if err != nil {
return nil, err return postgres.NewDatabase(dataSourceName, serverName)
} }
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil { switch uri.Scheme {
return nil, err case "postgres":
return postgres.NewDatabase(dataSourceName, serverName)
case "file":
return sqlite3.NewDatabase(dataSourceName, serverName)
default:
return postgres.NewDatabase(dataSourceName, serverName)
} }
return d.accounts.selectAccountByLocalpart(ctx, localpart)
}
// GetProfileByLocalpart returns the profile associated with the given localpart.
// Returns sql.ErrNoRows if no profile exists which matches the given localpart.
func (d *Database) GetProfileByLocalpart(
ctx context.Context, localpart string,
) (*authtypes.Profile, error) {
return d.profiles.selectProfileByLocalpart(ctx, localpart)
}
// SetAvatarURL updates the avatar URL of the profile associated with the given
// localpart. Returns an error if something went wrong with the SQL query
func (d *Database) SetAvatarURL(
ctx context.Context, localpart string, avatarURL string,
) error {
return d.profiles.setAvatarURL(ctx, localpart, avatarURL)
}
// SetDisplayName updates the display name of the profile associated with the given
// localpart. Returns an error if something went wrong with the SQL query
func (d *Database) SetDisplayName(
ctx context.Context, localpart string, displayName string,
) error {
return d.profiles.setDisplayName(ctx, localpart, displayName)
}
// CreateAccount makes a new account with the given login name and password, and creates an empty profile
// for this account. If no password is supplied, the account will be a passwordless account. If the
// account already exists, it will return nil, nil.
func (d *Database) CreateAccount(
ctx context.Context, localpart, plaintextPassword, appserviceID string,
) (*authtypes.Account, error) {
var err error
// Generate a password hash if this is not a password-less user
hash := ""
if plaintextPassword != "" {
hash, err = hashPassword(plaintextPassword)
if err != nil {
return nil, err
}
}
if err := d.profiles.insertProfile(ctx, localpart); err != nil {
if common.IsUniqueConstraintViolationErr(err) {
return nil, nil
}
return nil, err
}
return d.accounts.insertAccount(ctx, localpart, hash, appserviceID)
}
// SaveMembership saves the user matching a given localpart as a member of a given
// room. It also stores the ID of the membership event.
// If a membership already exists between the user and the room, or if the
// insert fails, returns the SQL error
func (d *Database) saveMembership(
ctx context.Context, txn *sql.Tx, localpart, roomID, eventID string,
) error {
return d.memberships.insertMembership(ctx, txn, localpart, roomID, eventID)
}
// removeMembershipsByEventIDs removes the memberships corresponding to the
// `join` membership events IDs in the eventIDs slice.
// If the removal fails, or if there is no membership to remove, returns an error
func (d *Database) removeMembershipsByEventIDs(
ctx context.Context, txn *sql.Tx, eventIDs []string,
) error {
return d.memberships.deleteMembershipsByEventIDs(ctx, txn, eventIDs)
}
// UpdateMemberships adds the "join" membership events included in a given state
// events array, and removes those which ID is included in a given array of events
// IDs. All of the process is run in a transaction, which commits only once/if every
// insertion and deletion has been successfully processed.
// Returns a SQL error if there was an issue with any part of the process
func (d *Database) UpdateMemberships(
ctx context.Context, eventsToAdd []gomatrixserverlib.Event, idsToRemove []string,
) error {
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
if err := d.removeMembershipsByEventIDs(ctx, txn, idsToRemove); err != nil {
return err
}
for _, event := range eventsToAdd {
if err := d.newMembership(ctx, txn, event); err != nil {
return err
}
}
return nil
})
}
// GetMembershipInRoomByLocalpart returns the membership for an user
// matching the given localpart if he is a member of the room matching roomID,
// if not sql.ErrNoRows is returned.
// If there was an issue during the retrieval, returns the SQL error
func (d *Database) GetMembershipInRoomByLocalpart(
ctx context.Context, localpart, roomID string,
) (authtypes.Membership, error) {
return d.memberships.selectMembershipInRoomByLocalpart(ctx, localpart, roomID)
}
// GetMembershipsByLocalpart returns an array containing the memberships for all
// the rooms a user matching a given localpart is a member of
// If no membership match the given localpart, returns an empty array
// If there was an issue during the retrieval, returns the SQL error
func (d *Database) GetMembershipsByLocalpart(
ctx context.Context, localpart string,
) (memberships []authtypes.Membership, err error) {
return d.memberships.selectMembershipsByLocalpart(ctx, localpart)
}
// newMembership saves a new membership in the database.
// If the event isn't a valid m.room.member event with type `join`, does nothing.
// If an error occurred, returns the SQL error
func (d *Database) newMembership(
ctx context.Context, txn *sql.Tx, ev gomatrixserverlib.Event,
) error {
if ev.Type() == "m.room.member" && ev.StateKey() != nil {
localpart, serverName, err := gomatrixserverlib.SplitID('@', *ev.StateKey())
if err != nil {
return err
}
// We only want state events from local users
if string(serverName) != string(d.serverName) {
return nil
}
eventID := ev.EventID()
roomID := ev.RoomID()
membership, err := ev.Membership()
if err != nil {
return err
}
// Only "join" membership events can be considered as new memberships
if membership == gomatrixserverlib.Join {
if err := d.saveMembership(ctx, txn, localpart, roomID, eventID); err != nil {
return err
}
}
}
return nil
}
// SaveAccountData saves new account data for a given user and a given room.
// If the account data is not specific to a room, the room ID should be an empty string
// If an account data already exists for a given set (user, room, data type), it will
// update the corresponding row with the new content
// Returns a SQL error if there was an issue with the insertion/update
func (d *Database) SaveAccountData(
ctx context.Context, localpart, roomID, dataType, content string,
) error {
return d.accountDatas.insertAccountData(ctx, localpart, roomID, dataType, content)
}
// GetAccountData returns account data related to a given localpart
// If no account data could be found, returns an empty arrays
// Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountData(ctx context.Context, localpart string) (
global []gomatrixserverlib.ClientEvent,
rooms map[string][]gomatrixserverlib.ClientEvent,
err error,
) {
return d.accountDatas.selectAccountData(ctx, localpart)
}
// GetAccountDataByType returns account data matching a given
// localpart, room ID and type.
// If no account data could be found, returns nil
// Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountDataByType(
ctx context.Context, localpart, roomID, dataType string,
) (data *gomatrixserverlib.ClientEvent, err error) {
return d.accountDatas.selectAccountDataByType(
ctx, localpart, roomID, dataType,
)
}
// GetNewNumericLocalpart generates and returns a new unused numeric localpart
func (d *Database) GetNewNumericLocalpart(
ctx context.Context,
) (int64, error) {
return d.accounts.selectNewNumericLocalpart(ctx)
}
func hashPassword(plaintext string) (hash string, err error) {
hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), bcrypt.DefaultCost)
return string(hashBytes), err
}
// Err3PIDInUse is the error returned when trying to save an association involving
// a third-party identifier which is already associated to a local user.
var Err3PIDInUse = errors.New("This third-party identifier is already in use")
// SaveThreePIDAssociation saves the association between a third party identifier
// and a local Matrix user (identified by the user's ID's local part).
// If the third-party identifier is already part of an association, returns Err3PIDInUse.
// Returns an error if there was a problem talking to the database.
func (d *Database) SaveThreePIDAssociation(
ctx context.Context, threepid, localpart, medium string,
) (err error) {
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
user, err := d.threepids.selectLocalpartForThreePID(
ctx, txn, threepid, medium,
)
if err != nil {
return err
}
if len(user) > 0 {
return Err3PIDInUse
}
return d.threepids.insertThreePID(ctx, txn, threepid, medium, localpart)
})
}
// RemoveThreePIDAssociation removes the association involving a given third-party
// identifier.
// If no association exists involving this third-party identifier, returns nothing.
// If there was a problem talking to the database, returns an error.
func (d *Database) RemoveThreePIDAssociation(
ctx context.Context, threepid string, medium string,
) (err error) {
return d.threepids.deleteThreePID(ctx, threepid, medium)
}
// GetLocalpartForThreePID looks up the localpart associated with a given third-party
// identifier.
// If no association involves the given third-party idenfitier, returns an empty
// string.
// Returns an error if there was a problem talking to the database.
func (d *Database) GetLocalpartForThreePID(
ctx context.Context, threepid string, medium string,
) (localpart string, err error) {
return d.threepids.selectLocalpartForThreePID(ctx, nil, threepid, medium)
}
// GetThreePIDsForLocalpart looks up the third-party identifiers associated with
// a given local user.
// If no association is known for this user, returns an empty slice.
// Returns an error if there was an issue talking to the database.
func (d *Database) GetThreePIDsForLocalpart(
ctx context.Context, localpart string,
) (threepids []authtypes.ThreePID, err error) {
return d.threepids.selectThreePIDsForLocalpart(ctx, localpart)
}
// GetFilter looks up the filter associated with a given local user and filter ID.
// Returns a filter structure. Otherwise returns an error if no such filter exists
// or if there was an error talking to the database.
func (d *Database) GetFilter(
ctx context.Context, localpart string, filterID string,
) (*gomatrixserverlib.Filter, error) {
return d.filter.selectFilter(ctx, localpart, filterID)
}
// PutFilter puts the passed filter into the database.
// Returns the filterID as a string. Otherwise returns an error if something
// goes wrong.
func (d *Database) PutFilter(
ctx context.Context, localpart string, filter *gomatrixserverlib.Filter,
) (string, error) {
return d.filter.insertFilter(ctx, filter, localpart)
}
// CheckAccountAvailability checks if the username/localpart is already present
// in the database.
// If the DB returns sql.ErrNoRows the Localpart isn't taken.
func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) {
_, err := d.accounts.selectAccountByLocalpart(ctx, localpart)
if err == sql.ErrNoRows {
return true, nil
}
return false, err
}
// GetAccountByLocalpart returns the account associated with the given localpart.
// This function assumes the request is authenticated or the account data is used only internally.
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string,
) (*authtypes.Account, error) {
return d.accounts.selectAccountByLocalpart(ctx, localpart)
} }

View file

@ -0,0 +1,38 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package accounts
import (
"fmt"
"net/url"
"github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/sqlite3"
"github.com/matrix-org/gomatrixserverlib"
)
func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (Database, error) {
uri, err := url.Parse(dataSourceName)
if err != nil {
return nil, fmt.Errorf("Cannot use postgres implementation")
}
switch uri.Scheme {
case "postgres":
return nil, fmt.Errorf("Cannot use postgres implementation")
case "file":
return sqlite3.NewDatabase(dataSourceName, serverName)
default:
return nil, fmt.Errorf("Cannot use postgres implementation")
}
}

View file

@ -0,0 +1,32 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package devices
import (
"context"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
)
type Database interface {
GetDeviceByAccessToken(ctx context.Context, token string) (*authtypes.Device, error)
GetDeviceByID(ctx context.Context, localpart, deviceID string) (*authtypes.Device, error)
GetDevicesByLocalpart(ctx context.Context, localpart string) ([]authtypes.Device, error)
CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string) (dev *authtypes.Device, returnErr error)
UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error
RemoveDevice(ctx context.Context, deviceID, localpart string) error
RemoveDevices(ctx context.Context, localpart string, devices []string) error
RemoveAllDevices(ctx context.Context, localpart string) error
}

View file

@ -12,17 +12,17 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package devices package postgres
import ( import (
"context" "context"
"database/sql" "database/sql"
"time" "time"
"github.com/matrix-org/dendrite/common" "github.com/lib/pq"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -80,6 +80,9 @@ const deleteDeviceSQL = "" +
const deleteDevicesByLocalpartSQL = "" + const deleteDevicesByLocalpartSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1" "DELETE FROM device_devices WHERE localpart = $1"
const deleteDevicesSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1 AND device_id = ANY($2)"
type devicesStatements struct { type devicesStatements struct {
insertDeviceStmt *sql.Stmt insertDeviceStmt *sql.Stmt
selectDeviceByTokenStmt *sql.Stmt selectDeviceByTokenStmt *sql.Stmt
@ -88,6 +91,7 @@ type devicesStatements struct {
updateDeviceNameStmt *sql.Stmt updateDeviceNameStmt *sql.Stmt
deleteDeviceStmt *sql.Stmt deleteDeviceStmt *sql.Stmt
deleteDevicesByLocalpartStmt *sql.Stmt deleteDevicesByLocalpartStmt *sql.Stmt
deleteDevicesStmt *sql.Stmt
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
} }
@ -117,6 +121,9 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN
if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil { if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil {
return return
} }
if s.deleteDevicesStmt, err = db.Prepare(deleteDevicesSQL); err != nil {
return
}
s.serverName = server s.serverName = server
return return
} }
@ -142,6 +149,7 @@ func (s *devicesStatements) insertDevice(
}, nil }, nil
} }
// deleteDevice removes a single device by id and user localpart.
func (s *devicesStatements) deleteDevice( func (s *devicesStatements) deleteDevice(
ctx context.Context, txn *sql.Tx, id, localpart string, ctx context.Context, txn *sql.Tx, id, localpart string,
) error { ) error {
@ -150,6 +158,18 @@ func (s *devicesStatements) deleteDevice(
return err return err
} }
// deleteDevices removes a single or multiple devices by ids and user localpart.
// Returns an error if the execution failed.
func (s *devicesStatements) deleteDevices(
ctx context.Context, txn *sql.Tx, localpart string, devices []string,
) error {
stmt := common.TxStmt(txn, s.deleteDevicesStmt)
_, err := stmt.ExecContext(ctx, localpart, pq.Array(devices))
return err
}
// deleteDevicesByLocalpart removes all devices for the
// given user localpart.
func (s *devicesStatements) deleteDevicesByLocalpart( func (s *devicesStatements) deleteDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart string, ctx context.Context, txn *sql.Tx, localpart string,
) error { ) error {
@ -206,10 +226,11 @@ func (s *devicesStatements) selectDevicesByLocalpart(
if err != nil { if err != nil {
return devices, err return devices, err
} }
defer common.CloseAndLogIfError(ctx, rows, "selectDevicesByLocalpart: rows.close() failed")
for rows.Next() { for rows.Next() {
var dev authtypes.Device var dev authtypes.Device
err = rows.Scan(&dev.ID) err = rows.Scan(&dev.ID, &dev.DisplayName)
if err != nil { if err != nil {
return devices, err return devices, err
} }
@ -217,5 +238,5 @@ func (s *devicesStatements) selectDevicesByLocalpart(
devices = append(devices, dev) devices = append(devices, dev)
} }
return devices, nil return devices, rows.Err()
} }

View file

@ -0,0 +1,183 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"crypto/rand"
"database/sql"
"encoding/base64"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
)
// The length of generated device IDs
var deviceIDByteLength = 6
// Database represents a device database.
type Database struct {
db *sql.DB
devices devicesStatements
}
// NewDatabase creates a new device database
func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (*Database, error) {
var db *sql.DB
var err error
if db, err = sqlutil.Open("postgres", dataSourceName); err != nil {
return nil, err
}
d := devicesStatements{}
if err = d.prepare(db, serverName); err != nil {
return nil, err
}
return &Database{db, d}, nil
}
// GetDeviceByAccessToken returns the device matching the given access token.
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByAccessToken(
ctx context.Context, token string,
) (*authtypes.Device, error) {
return d.devices.selectDeviceByToken(ctx, token)
}
// GetDeviceByID returns the device matching the given ID.
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByID(
ctx context.Context, localpart, deviceID string,
) (*authtypes.Device, error) {
return d.devices.selectDeviceByID(ctx, localpart, deviceID)
}
// GetDevicesByLocalpart returns the devices matching the given localpart.
func (d *Database) GetDevicesByLocalpart(
ctx context.Context, localpart string,
) ([]authtypes.Device, error) {
return d.devices.selectDevicesByLocalpart(ctx, localpart)
}
// CreateDevice makes a new device associated with the given user ID localpart.
// If there is already a device with the same device ID for this user, that access token will be revoked
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
// an error will be returned.
// If no device ID is given one is generated.
// Returns the device on success.
func (d *Database) CreateDevice(
ctx context.Context, localpart string, deviceID *string, accessToken string,
displayName *string,
) (dev *authtypes.Device, returnErr error) {
if deviceID != nil {
returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
var err error
// Revoke existing tokens for this device
if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
return err
}
dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName)
return err
})
} else {
// We generate device IDs in a loop in case its already taken.
// We cap this at going round 5 times to ensure we don't spin forever
var newDeviceID string
for i := 1; i <= 5; i++ {
newDeviceID, returnErr = generateDeviceID()
if returnErr != nil {
return
}
returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
var err error
dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName)
return err
})
if returnErr == nil {
return
}
}
}
return
}
// generateDeviceID creates a new device id. Returns an error if failed to generate
// random bytes.
func generateDeviceID() (string, error) {
b := make([]byte, deviceIDByteLength)
_, err := rand.Read(b)
if err != nil {
return "", err
}
// url-safe no padding
return base64.RawURLEncoding.EncodeToString(b), nil
}
// UpdateDevice updates the given device with the display name.
// Returns SQL error if there are problems and nil on success.
func (d *Database) UpdateDevice(
ctx context.Context, localpart, deviceID string, displayName *string,
) error {
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName)
})
}
// RemoveDevice revokes a device by deleting the entry in the database
// matching with the given device ID and user ID localpart.
// If the device doesn't exist, it will not return an error
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveDevice(
ctx context.Context, deviceID, localpart string,
) error {
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
return err
}
return nil
})
}
// RemoveDevices revokes one or more devices by deleting the entry in the database
// matching with the given device IDs and user ID localpart.
// If the devices don't exist, it will not return an error
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveDevices(
ctx context.Context, localpart string, devices []string,
) error {
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
return err
}
return nil
})
}
// RemoveAllDevices revokes devices by deleting the entry in the
// database matching the given user ID localpart.
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveAllDevices(
ctx context.Context, localpart string,
) error {
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows {
return err
}
return nil
})
}

View file

@ -0,0 +1,243 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"strings"
"time"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/gomatrixserverlib"
)
const devicesSchema = `
-- This sequence is used for automatic allocation of session_id.
-- CREATE SEQUENCE IF NOT EXISTS device_session_id_seq START 1;
-- Stores data about devices.
CREATE TABLE IF NOT EXISTS device_devices (
access_token TEXT PRIMARY KEY,
session_id INTEGER,
device_id TEXT ,
localpart TEXT ,
created_ts BIGINT,
display_name TEXT,
UNIQUE (localpart, device_id)
);
`
const insertDeviceSQL = "" +
"INSERT INTO device_devices (device_id, localpart, access_token, created_ts, display_name, session_id)" +
" VALUES ($1, $2, $3, $4, $5, $6)"
const selectDevicesCountSQL = "" +
"SELECT COUNT(access_token) FROM device_devices"
const selectDeviceByTokenSQL = "" +
"SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1"
const selectDeviceByIDSQL = "" +
"SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
const selectDevicesByLocalpartSQL = "" +
"SELECT device_id, display_name FROM device_devices WHERE localpart = $1"
const updateDeviceNameSQL = "" +
"UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
const deleteDeviceSQL = "" +
"DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2"
const deleteDevicesByLocalpartSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1"
const deleteDevicesSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)"
type devicesStatements struct {
db *sql.DB
insertDeviceStmt *sql.Stmt
selectDevicesCountStmt *sql.Stmt
selectDeviceByTokenStmt *sql.Stmt
selectDeviceByIDStmt *sql.Stmt
selectDevicesByLocalpartStmt *sql.Stmt
updateDeviceNameStmt *sql.Stmt
deleteDeviceStmt *sql.Stmt
deleteDevicesByLocalpartStmt *sql.Stmt
serverName gomatrixserverlib.ServerName
}
func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
s.db = db
_, err = db.Exec(devicesSchema)
if err != nil {
return
}
if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil {
return
}
if s.selectDevicesCountStmt, err = db.Prepare(selectDevicesCountSQL); err != nil {
return
}
if s.selectDeviceByTokenStmt, err = db.Prepare(selectDeviceByTokenSQL); err != nil {
return
}
if s.selectDeviceByIDStmt, err = db.Prepare(selectDeviceByIDSQL); err != nil {
return
}
if s.selectDevicesByLocalpartStmt, err = db.Prepare(selectDevicesByLocalpartSQL); err != nil {
return
}
if s.updateDeviceNameStmt, err = db.Prepare(updateDeviceNameSQL); err != nil {
return
}
if s.deleteDeviceStmt, err = db.Prepare(deleteDeviceSQL); err != nil {
return
}
if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil {
return
}
s.serverName = server
return
}
// insertDevice creates a new device. Returns an error if any device with the same access token already exists.
// Returns an error if the user already has a device with the given device ID.
// Returns the device on success.
func (s *devicesStatements) insertDevice(
ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
displayName *string,
) (*authtypes.Device, error) {
createdTimeMS := time.Now().UnixNano() / 1000000
var sessionID int64
countStmt := common.TxStmt(txn, s.selectDevicesCountStmt)
insertStmt := common.TxStmt(txn, s.insertDeviceStmt)
if err := countStmt.QueryRowContext(ctx).Scan(&sessionID); err != nil {
return nil, err
}
sessionID++
if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID); err != nil {
return nil, err
}
return &authtypes.Device{
ID: id,
UserID: userutil.MakeUserID(localpart, s.serverName),
AccessToken: accessToken,
SessionID: sessionID,
}, nil
}
func (s *devicesStatements) deleteDevice(
ctx context.Context, txn *sql.Tx, id, localpart string,
) error {
stmt := common.TxStmt(txn, s.deleteDeviceStmt)
_, err := stmt.ExecContext(ctx, id, localpart)
return err
}
func (s *devicesStatements) deleteDevices(
ctx context.Context, txn *sql.Tx, localpart string, devices []string,
) error {
orig := strings.Replace(deleteDevicesSQL, "($1)", common.QueryVariadic(len(devices)), 1)
prep, err := s.db.Prepare(orig)
if err != nil {
return err
}
stmt := common.TxStmt(txn, prep)
params := make([]interface{}, len(devices)+1)
params[0] = localpart
for i, v := range devices {
params[i+1] = v
}
params = append(params, params...)
_, err = stmt.ExecContext(ctx, params...)
return err
}
func (s *devicesStatements) deleteDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart string,
) error {
stmt := common.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
_, err := stmt.ExecContext(ctx, localpart)
return err
}
func (s *devicesStatements) updateDeviceName(
ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string,
) error {
stmt := common.TxStmt(txn, s.updateDeviceNameStmt)
_, err := stmt.ExecContext(ctx, displayName, localpart, deviceID)
return err
}
func (s *devicesStatements) selectDeviceByToken(
ctx context.Context, accessToken string,
) (*authtypes.Device, error) {
var dev authtypes.Device
var localpart string
stmt := s.selectDeviceByTokenStmt
err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart)
if err == nil {
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
dev.AccessToken = accessToken
}
return &dev, err
}
// selectDeviceByID retrieves a device from the database with the given user
// localpart and deviceID
func (s *devicesStatements) selectDeviceByID(
ctx context.Context, localpart, deviceID string,
) (*authtypes.Device, error) {
var dev authtypes.Device
var created sql.NullInt64
stmt := s.selectDeviceByIDStmt
err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&created)
if err == nil {
dev.ID = deviceID
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
}
return &dev, err
}
func (s *devicesStatements) selectDevicesByLocalpart(
ctx context.Context, localpart string,
) ([]authtypes.Device, error) {
devices := []authtypes.Device{}
rows, err := s.selectDevicesByLocalpartStmt.QueryContext(ctx, localpart)
if err != nil {
return devices, err
}
for rows.Next() {
var dev authtypes.Device
err = rows.Scan(&dev.ID, &dev.DisplayName)
if err != nil {
return devices, err
}
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
devices = append(devices, dev)
}
return devices, nil
}

View file

@ -0,0 +1,185 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"crypto/rand"
"database/sql"
"encoding/base64"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
_ "github.com/mattn/go-sqlite3"
)
// The length of generated device IDs
var deviceIDByteLength = 6
// Database represents a device database.
type Database struct {
db *sql.DB
devices devicesStatements
}
// NewDatabase creates a new device database
func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (*Database, error) {
var db *sql.DB
var err error
if db, err = sqlutil.Open(common.SQLiteDriverName(), dataSourceName); err != nil {
return nil, err
}
d := devicesStatements{}
if err = d.prepare(db, serverName); err != nil {
return nil, err
}
return &Database{db, d}, nil
}
// GetDeviceByAccessToken returns the device matching the given access token.
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByAccessToken(
ctx context.Context, token string,
) (*authtypes.Device, error) {
return d.devices.selectDeviceByToken(ctx, token)
}
// GetDeviceByID returns the device matching the given ID.
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByID(
ctx context.Context, localpart, deviceID string,
) (*authtypes.Device, error) {
return d.devices.selectDeviceByID(ctx, localpart, deviceID)
}
// GetDevicesByLocalpart returns the devices matching the given localpart.
func (d *Database) GetDevicesByLocalpart(
ctx context.Context, localpart string,
) ([]authtypes.Device, error) {
return d.devices.selectDevicesByLocalpart(ctx, localpart)
}
// CreateDevice makes a new device associated with the given user ID localpart.
// If there is already a device with the same device ID for this user, that access token will be revoked
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
// an error will be returned.
// If no device ID is given one is generated.
// Returns the device on success.
func (d *Database) CreateDevice(
ctx context.Context, localpart string, deviceID *string, accessToken string,
displayName *string,
) (dev *authtypes.Device, returnErr error) {
if deviceID != nil {
returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
var err error
// Revoke existing tokens for this device
if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
return err
}
dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName)
return err
})
} else {
// We generate device IDs in a loop in case its already taken.
// We cap this at going round 5 times to ensure we don't spin forever
var newDeviceID string
for i := 1; i <= 5; i++ {
newDeviceID, returnErr = generateDeviceID()
if returnErr != nil {
return
}
returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
var err error
dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName)
return err
})
if returnErr == nil {
return
}
}
}
return
}
// generateDeviceID creates a new device id. Returns an error if failed to generate
// random bytes.
func generateDeviceID() (string, error) {
b := make([]byte, deviceIDByteLength)
_, err := rand.Read(b)
if err != nil {
return "", err
}
// url-safe no padding
return base64.RawURLEncoding.EncodeToString(b), nil
}
// UpdateDevice updates the given device with the display name.
// Returns SQL error if there are problems and nil on success.
func (d *Database) UpdateDevice(
ctx context.Context, localpart, deviceID string, displayName *string,
) error {
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName)
})
}
// RemoveDevice revokes a device by deleting the entry in the database
// matching with the given device ID and user ID localpart.
// If the device doesn't exist, it will not return an error
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveDevice(
ctx context.Context, deviceID, localpart string,
) error {
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
return err
}
return nil
})
}
// RemoveDevices revokes one or more devices by deleting the entry in the database
// matching with the given device IDs and user ID localpart.
// If the devices don't exist, it will not return an error
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveDevices(
ctx context.Context, localpart string, devices []string,
) error {
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
return err
}
return nil
})
}
// RemoveAllDevices revokes devices by deleting the entry in the
// database matching the given user ID localpart.
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveAllDevices(
ctx context.Context, localpart string,
) error {
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows {
return err
}
return nil
})
}

View file

@ -1,4 +1,4 @@
// Copyright 2017 Vector Creations Ltd // Copyright 2020 The Matrix.org Foundation C.I.C.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
@ -12,156 +12,29 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// +build !wasm
package devices package devices
import ( import (
"context" "net/url"
"crypto/rand"
"database/sql"
"encoding/base64"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/storage/devices/postgres"
"github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/clientapi/auth/storage/devices/sqlite3"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
// The length of generated device IDs func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (Database, error) {
var deviceIDByteLength = 6 uri, err := url.Parse(dataSourceName)
// Database represents a device database.
type Database struct {
db *sql.DB
devices devicesStatements
}
// NewDatabase creates a new device database
func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (*Database, error) {
var db *sql.DB
var err error
if db, err = sql.Open("postgres", dataSourceName); err != nil {
return nil, err
}
d := devicesStatements{}
if err = d.prepare(db, serverName); err != nil {
return nil, err
}
return &Database{db, d}, nil
}
// GetDeviceByAccessToken returns the device matching the given access token.
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByAccessToken(
ctx context.Context, token string,
) (*authtypes.Device, error) {
return d.devices.selectDeviceByToken(ctx, token)
}
// GetDeviceByID returns the device matching the given ID.
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByID(
ctx context.Context, localpart, deviceID string,
) (*authtypes.Device, error) {
return d.devices.selectDeviceByID(ctx, localpart, deviceID)
}
// GetDevicesByLocalpart returns the devices matching the given localpart.
func (d *Database) GetDevicesByLocalpart(
ctx context.Context, localpart string,
) ([]authtypes.Device, error) {
return d.devices.selectDevicesByLocalpart(ctx, localpart)
}
// CreateDevice makes a new device associated with the given user ID localpart.
// If there is already a device with the same device ID for this user, that access token will be revoked
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
// an error will be returned.
// If no device ID is given one is generated.
// Returns the device on success.
func (d *Database) CreateDevice(
ctx context.Context, localpart string, deviceID *string, accessToken string,
displayName *string,
) (dev *authtypes.Device, returnErr error) {
if deviceID != nil {
returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
var err error
// Revoke existing tokens for this device
if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
return err
}
dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName)
return err
})
} else {
// We generate device IDs in a loop in case its already taken.
// We cap this at going round 5 times to ensure we don't spin forever
var newDeviceID string
for i := 1; i <= 5; i++ {
newDeviceID, returnErr = generateDeviceID()
if returnErr != nil {
return
}
returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
var err error
dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName)
return err
})
if returnErr == nil {
return
}
}
}
return
}
// generateDeviceID creates a new device id. Returns an error if failed to generate
// random bytes.
func generateDeviceID() (string, error) {
b := make([]byte, deviceIDByteLength)
_, err := rand.Read(b)
if err != nil { if err != nil {
return "", err return postgres.NewDatabase(dataSourceName, serverName)
}
switch uri.Scheme {
case "postgres":
return postgres.NewDatabase(dataSourceName, serverName)
case "file":
return sqlite3.NewDatabase(dataSourceName, serverName)
default:
return postgres.NewDatabase(dataSourceName, serverName)
} }
// url-safe no padding
return base64.RawURLEncoding.EncodeToString(b), nil
}
// UpdateDevice updates the given device with the display name.
// Returns SQL error if there are problems and nil on success.
func (d *Database) UpdateDevice(
ctx context.Context, localpart, deviceID string, displayName *string,
) error {
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName)
})
}
// RemoveDevice revokes a device by deleting the entry in the database
// matching with the given device ID and user ID localpart.
// If the device doesn't exist, it will not return an error
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveDevice(
ctx context.Context, deviceID, localpart string,
) error {
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
return err
}
return nil
})
}
// RemoveAllDevices revokes devices by deleting the entry in the
// database matching the given user ID localpart.
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveAllDevices(
ctx context.Context, localpart string,
) error {
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows {
return err
}
return nil
})
} }

View file

@ -0,0 +1,38 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package devices
import (
"fmt"
"net/url"
"github.com/matrix-org/dendrite/clientapi/auth/storage/devices/sqlite3"
"github.com/matrix-org/gomatrixserverlib"
)
func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (Database, error) {
uri, err := url.Parse(dataSourceName)
if err != nil {
return nil, fmt.Errorf("Cannot use postgres implementation")
}
switch uri.Scheme {
case "postgres":
return nil, fmt.Errorf("Cannot use postgres implementation")
case "file":
return sqlite3.NewDatabase(dataSourceName, serverName)
default:
return nil, fmt.Errorf("Cannot use postgres implementation")
}
}

View file

@ -23,9 +23,9 @@ import (
"github.com/matrix-org/dendrite/clientapi/routing" "github.com/matrix-org/dendrite/clientapi/routing"
"github.com/matrix-org/dendrite/common/basecomponent" "github.com/matrix-org/dendrite/common/basecomponent"
"github.com/matrix-org/dendrite/common/transactions" "github.com/matrix-org/dendrite/common/transactions"
eduServerAPI "github.com/matrix-org/dendrite/eduserver/api"
federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api" federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
typingServerAPI "github.com/matrix-org/dendrite/typingserver/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
@ -34,20 +34,20 @@ import (
// component. // component.
func SetupClientAPIComponent( func SetupClientAPIComponent(
base *basecomponent.BaseDendrite, base *basecomponent.BaseDendrite,
deviceDB *devices.Database, deviceDB devices.Database,
accountsDB *accounts.Database, accountsDB accounts.Database,
federation *gomatrixserverlib.FederationClient, federation *gomatrixserverlib.FederationClient,
keyRing *gomatrixserverlib.KeyRing, keyRing *gomatrixserverlib.KeyRing,
aliasAPI roomserverAPI.RoomserverAliasAPI, aliasAPI roomserverAPI.RoomserverAliasAPI,
inputAPI roomserverAPI.RoomserverInputAPI, inputAPI roomserverAPI.RoomserverInputAPI,
queryAPI roomserverAPI.RoomserverQueryAPI, queryAPI roomserverAPI.RoomserverQueryAPI,
typingInputAPI typingServerAPI.TypingServerInputAPI, eduInputAPI eduServerAPI.EDUServerInputAPI,
asAPI appserviceAPI.AppServiceQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI,
transactionsCache *transactions.Cache, transactionsCache *transactions.Cache,
fedSenderAPI federationSenderAPI.FederationSenderQueryAPI, fedSenderAPI federationSenderAPI.FederationSenderQueryAPI,
) { ) {
roomserverProducer := producers.NewRoomserverProducer(inputAPI) roomserverProducer := producers.NewRoomserverProducer(inputAPI, queryAPI)
typingProducer := producers.NewTypingServerProducer(typingInputAPI) eduProducer := producers.NewEDUServerProducer(eduInputAPI)
userUpdateProducer := &producers.UserUpdateProducer{ userUpdateProducer := &producers.UserUpdateProducer{
Producer: base.KafkaProducer, Producer: base.KafkaProducer,
@ -67,8 +67,8 @@ func SetupClientAPIComponent(
} }
routing.Setup( routing.Setup(
base.APIMux, *base.Cfg, roomserverProducer, queryAPI, aliasAPI, asAPI, base.APIMux, base.Cfg, roomserverProducer, queryAPI, aliasAPI, asAPI,
accountsDB, deviceDB, federation, *keyRing, userUpdateProducer, accountsDB, deviceDB, federation, *keyRing, userUpdateProducer,
syncProducer, typingProducer, transactionsCache, fedSenderAPI, syncProducer, eduProducer, transactionsCache, fedSenderAPI,
) )
} }

View file

@ -31,7 +31,7 @@ import (
// OutputRoomEventConsumer consumes events that originated in the room server. // OutputRoomEventConsumer consumes events that originated in the room server.
type OutputRoomEventConsumer struct { type OutputRoomEventConsumer struct {
roomServerConsumer *common.ContinualConsumer roomServerConsumer *common.ContinualConsumer
db *accounts.Database db accounts.Database
query api.RoomserverQueryAPI query api.RoomserverQueryAPI
serverName string serverName string
} }
@ -40,7 +40,7 @@ type OutputRoomEventConsumer struct {
func NewOutputRoomEventConsumer( func NewOutputRoomEventConsumer(
cfg *config.Dendrite, cfg *config.Dendrite,
kafkaConsumer sarama.Consumer, kafkaConsumer sarama.Consumer,
store *accounts.Database, store accounts.Database,
queryAPI api.RoomserverQueryAPI, queryAPI api.RoomserverQueryAPI,
) *OutputRoomEventConsumer { ) *OutputRoomEventConsumer {
@ -91,7 +91,7 @@ func (s *OutputRoomEventConsumer) onMessage(msg *sarama.ConsumerMessage) error {
"type": ev.Type(), "type": ev.Type(),
}).Info("received event from roomserver") }).Info("received event from roomserver")
events, err := s.lookupStateEvents(output.NewRoomEvent.AddsStateEventIDs, ev) events, err := s.lookupStateEvents(output.NewRoomEvent.AddsStateEventIDs, ev.Event)
if err != nil { if err != nil {
return err return err
} }
@ -138,7 +138,9 @@ func (s *OutputRoomEventConsumer) lookupStateEvents(
return nil, err return nil, err
} }
result = append(result, eventResp.Events...) for _, headeredEvent := range eventResp.Events {
result = append(result, headeredEvent.Event)
}
return result, nil return result, nil
} }

View file

@ -36,11 +36,3 @@ func UnmarshalJSONRequest(req *http.Request, iface interface{}) *util.JSONRespon
} }
return nil return nil
} }
// LogThenError logs the given error then returns a matrix-compliant 500 internal server error response.
// This should be used to log fatal errors which require investigation. It should not be used
// to log client validation errors, etc.
func LogThenError(req *http.Request, err error) util.JSONResponse {
util.GetLogger(req.Context()).WithError(err).Error("request failed")
return jsonerror.InternalServerError()
}

View file

@ -124,6 +124,12 @@ func GuestAccessForbidden(msg string) *MatrixError {
return &MatrixError{"M_GUEST_ACCESS_FORBIDDEN", msg} return &MatrixError{"M_GUEST_ACCESS_FORBIDDEN", msg}
} }
// UnsupportedRoomVersion is an error which is returned when the client
// requests a room with a version that is unsupported.
func UnsupportedRoomVersion(msg string) *MatrixError {
return &MatrixError{"M_UNSUPPORTED_ROOM_VERSION", msg}
}
// LimitExceededError is a rate-limiting error. // LimitExceededError is a rate-limiting error.
type LimitExceededError struct { type LimitExceededError struct {
MatrixError MatrixError

View file

@ -16,32 +16,32 @@ import (
"context" "context"
"time" "time"
"github.com/matrix-org/dendrite/typingserver/api" "github.com/matrix-org/dendrite/eduserver/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
// TypingServerProducer produces events for the typing server to consume // EDUServerProducer produces events for the EDU server to consume
type TypingServerProducer struct { type EDUServerProducer struct {
InputAPI api.TypingServerInputAPI InputAPI api.EDUServerInputAPI
} }
// NewTypingServerProducer creates a new TypingServerProducer // NewEDUServerProducer creates a new EDUServerProducer
func NewTypingServerProducer(inputAPI api.TypingServerInputAPI) *TypingServerProducer { func NewEDUServerProducer(inputAPI api.EDUServerInputAPI) *EDUServerProducer {
return &TypingServerProducer{ return &EDUServerProducer{
InputAPI: inputAPI, InputAPI: inputAPI,
} }
} }
// Send typing event to typing server // SendTyping sends a typing event to EDU server
func (p *TypingServerProducer) Send( func (p *EDUServerProducer) SendTyping(
ctx context.Context, userID, roomID string, ctx context.Context, userID, roomID string,
typing bool, timeout int64, typing bool, timeoutMS int64,
) error { ) error {
requestData := api.InputTypingEvent{ requestData := api.InputTypingEvent{
UserID: userID, UserID: userID,
RoomID: roomID, RoomID: roomID,
Typing: typing, Typing: typing,
Timeout: timeout, TimeoutMS: timeoutMS,
OriginServerTS: gomatrixserverlib.AsTimestamp(time.Now()), OriginServerTS: gomatrixserverlib.AsTimestamp(time.Now()),
} }

View file

@ -24,18 +24,20 @@ import (
// RoomserverProducer produces events for the roomserver to consume. // RoomserverProducer produces events for the roomserver to consume.
type RoomserverProducer struct { type RoomserverProducer struct {
InputAPI api.RoomserverInputAPI InputAPI api.RoomserverInputAPI
QueryAPI api.RoomserverQueryAPI
} }
// NewRoomserverProducer creates a new RoomserverProducer // NewRoomserverProducer creates a new RoomserverProducer
func NewRoomserverProducer(inputAPI api.RoomserverInputAPI) *RoomserverProducer { func NewRoomserverProducer(inputAPI api.RoomserverInputAPI, queryAPI api.RoomserverQueryAPI) *RoomserverProducer {
return &RoomserverProducer{ return &RoomserverProducer{
InputAPI: inputAPI, InputAPI: inputAPI,
QueryAPI: queryAPI,
} }
} }
// SendEvents writes the given events to the roomserver input log. The events are written with KindNew. // SendEvents writes the given events to the roomserver input log. The events are written with KindNew.
func (c *RoomserverProducer) SendEvents( func (c *RoomserverProducer) SendEvents(
ctx context.Context, events []gomatrixserverlib.Event, sendAsServer gomatrixserverlib.ServerName, ctx context.Context, events []gomatrixserverlib.HeaderedEvent, sendAsServer gomatrixserverlib.ServerName,
txnID *api.TransactionID, txnID *api.TransactionID,
) (string, error) { ) (string, error) {
ires := make([]api.InputRoomEvent, len(events)) ires := make([]api.InputRoomEvent, len(events))
@ -54,20 +56,20 @@ func (c *RoomserverProducer) SendEvents(
// SendEventWithState writes an event with KindNew to the roomserver input log // SendEventWithState writes an event with KindNew to the roomserver input log
// with the state at the event as KindOutlier before it. // with the state at the event as KindOutlier before it.
func (c *RoomserverProducer) SendEventWithState( func (c *RoomserverProducer) SendEventWithState(
ctx context.Context, state gomatrixserverlib.RespState, event gomatrixserverlib.Event, ctx context.Context, state gomatrixserverlib.RespState, event gomatrixserverlib.HeaderedEvent,
) error { ) error {
outliers, err := state.Events() outliers, err := state.Events()
if err != nil { if err != nil {
return err return err
} }
ires := make([]api.InputRoomEvent, len(outliers)+1) var ires []api.InputRoomEvent
for i, outlier := range outliers { for _, outlier := range outliers {
ires[i] = api.InputRoomEvent{ ires = append(ires, api.InputRoomEvent{
Kind: api.KindOutlier, Kind: api.KindOutlier,
Event: outlier, Event: outlier.Headered(event.RoomVersion),
AuthEventIDs: outlier.AuthEventIDs(), AuthEventIDs: outlier.AuthEventIDs(),
} })
} }
stateEventIDs := make([]string, len(state.StateEvents)) stateEventIDs := make([]string, len(state.StateEvents))
@ -75,13 +77,13 @@ func (c *RoomserverProducer) SendEventWithState(
stateEventIDs[i] = state.StateEvents[i].EventID() stateEventIDs[i] = state.StateEvents[i].EventID()
} }
ires[len(outliers)] = api.InputRoomEvent{ ires = append(ires, api.InputRoomEvent{
Kind: api.KindNew, Kind: api.KindNew,
Event: event, Event: event,
AuthEventIDs: event.AuthEventIDs(), AuthEventIDs: event.AuthEventIDs(),
HasState: true, HasState: true,
StateEventIDs: stateEventIDs, StateEventIDs: stateEventIDs,
} })
_, err = c.SendInputRoomEvents(ctx, ires) _, err = c.SendInputRoomEvents(ctx, ires)
return err return err
@ -102,10 +104,15 @@ func (c *RoomserverProducer) SendInputRoomEvents(
// This should only be needed for invite events that occur outside of a known room. // This should only be needed for invite events that occur outside of a known room.
// If we are in the room then the event should be sent using the SendEvents method. // If we are in the room then the event should be sent using the SendEvents method.
func (c *RoomserverProducer) SendInvite( func (c *RoomserverProducer) SendInvite(
ctx context.Context, inviteEvent gomatrixserverlib.Event, ctx context.Context, inviteEvent gomatrixserverlib.HeaderedEvent,
inviteRoomState []gomatrixserverlib.InviteV2StrippedState,
) error { ) error {
request := api.InputRoomEventsRequest{ request := api.InputRoomEventsRequest{
InputInviteEvents: []api.InputInviteEvent{{Event: inviteEvent}}, InputInviteEvents: []api.InputInviteEvent{{
Event: inviteEvent,
InviteRoomState: inviteRoomState,
RoomVersion: inviteEvent.RoomVersion,
}},
} }
var response api.InputRoomEventsResponse var response api.InputRoomEventsResponse
return c.InputAPI.InputRoomEvents(ctx, &request, &response) return c.InputAPI.InputRoomEvents(ctx, &request, &response)

View file

@ -15,12 +15,12 @@
package routing package routing
import ( import (
"encoding/json"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/clientapi/producers" "github.com/matrix-org/dendrite/clientapi/producers"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -28,9 +28,42 @@ import (
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
// GetAccountData implements GET /user/{userId}/[rooms/{roomid}/]account_data/{type}
func GetAccountData(
req *http.Request, accountDB accounts.Database, device *authtypes.Device,
userID string, roomID string, dataType string,
) util.JSONResponse {
if userID != device.UserID {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("userID does not match the current user"),
}
}
localpart, _, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError()
}
if data, err := accountDB.GetAccountDataByType(
req.Context(), localpart, roomID, dataType,
); err == nil {
return util.JSONResponse{
Code: http.StatusOK,
JSON: data,
}
}
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: jsonerror.Forbidden("data not found"),
}
}
// SaveAccountData implements PUT /user/{userId}/[rooms/{roomId}/]account_data/{type} // SaveAccountData implements PUT /user/{userId}/[rooms/{roomId}/]account_data/{type}
func SaveAccountData( func SaveAccountData(
req *http.Request, accountDB *accounts.Database, device *authtypes.Device, req *http.Request, accountDB accounts.Database, device *authtypes.Device,
userID string, roomID string, dataType string, syncProducer *producers.SyncAPIProducer, userID string, roomID string, dataType string, syncProducer *producers.SyncAPIProducer,
) util.JSONResponse { ) util.JSONResponse {
if userID != device.UserID { if userID != device.UserID {
@ -42,24 +75,42 @@ func SaveAccountData(
localpart, _, err := gomatrixserverlib.SplitID('@', userID) localpart, _, err := gomatrixserverlib.SplitID('@', userID)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError()
} }
defer req.Body.Close() // nolint: errcheck defer req.Body.Close() // nolint: errcheck
if req.Body == http.NoBody {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.NotJSON("Content not JSON"),
}
}
body, err := ioutil.ReadAll(req.Body) body, err := ioutil.ReadAll(req.Body)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("ioutil.ReadAll failed")
return jsonerror.InternalServerError()
}
if !json.Valid(body) {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON("Bad JSON content"),
}
} }
if err := accountDB.SaveAccountData( if err := accountDB.SaveAccountData(
req.Context(), localpart, roomID, dataType, string(body), req.Context(), localpart, roomID, dataType, string(body),
); err != nil { ); err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("accountDB.SaveAccountData failed")
return jsonerror.InternalServerError()
} }
if err := syncProducer.SendData(userID, roomID, dataType); err != nil { if err := syncProducer.SendData(userID, roomID, dataType); err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("syncProducer.SendData failed")
return jsonerror.InternalServerError()
} }
return util.JSONResponse{ return util.JSONResponse{

View file

@ -19,7 +19,6 @@ import (
"net/http" "net/http"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/common/config" "github.com/matrix-org/dendrite/common/config"
"github.com/matrix-org/util" "github.com/matrix-org/util"
@ -102,7 +101,7 @@ func serveTemplate(w http.ResponseWriter, templateHTML string, data map[string]s
// AuthFallback implements GET and POST /auth/{authType}/fallback/web?session={sessionID} // AuthFallback implements GET and POST /auth/{authType}/fallback/web?session={sessionID}
func AuthFallback( func AuthFallback(
w http.ResponseWriter, req *http.Request, authType string, w http.ResponseWriter, req *http.Request, authType string,
cfg config.Dendrite, cfg *config.Dendrite,
) *util.JSONResponse { ) *util.JSONResponse {
sessionID := req.URL.Query().Get("session") sessionID := req.URL.Query().Get("session")
@ -130,7 +129,7 @@ func AuthFallback(
if req.Method == http.MethodGet { if req.Method == http.MethodGet {
// Handle Recaptcha // Handle Recaptcha
if authType == authtypes.LoginTypeRecaptcha { if authType == authtypes.LoginTypeRecaptcha {
if err := checkRecaptchaEnabled(&cfg, w, req); err != nil { if err := checkRecaptchaEnabled(cfg, w, req); err != nil {
return err return err
} }
@ -144,19 +143,20 @@ func AuthFallback(
} else if req.Method == http.MethodPost { } else if req.Method == http.MethodPost {
// Handle Recaptcha // Handle Recaptcha
if authType == authtypes.LoginTypeRecaptcha { if authType == authtypes.LoginTypeRecaptcha {
if err := checkRecaptchaEnabled(&cfg, w, req); err != nil { if err := checkRecaptchaEnabled(cfg, w, req); err != nil {
return err return err
} }
clientIP := req.RemoteAddr clientIP := req.RemoteAddr
err := req.ParseForm() err := req.ParseForm()
if err != nil { if err != nil {
res := httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("req.ParseForm failed")
res := jsonerror.InternalServerError()
return &res return &res
} }
response := req.Form.Get("g-recaptcha-response") response := req.Form.Get("g-recaptcha-response")
if err := validateRecaptcha(&cfg, response, clientIP); err != nil { if err := validateRecaptcha(cfg, response, clientIP); err != nil {
util.GetLogger(req.Context()).Error(err) util.GetLogger(req.Context()).Error(err)
return err return err
} }
@ -203,7 +203,8 @@ func writeHTTPMessage(
w.WriteHeader(header) w.WriteHeader(header)
_, err := w.Write([]byte(message)) _, err := w.Write([]byte(message))
if err != nil { if err != nil {
res := httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("w.Write failed")
res := jsonerror.InternalServerError()
return &res return &res
} }
return nil return nil

View file

@ -0,0 +1,52 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package routing
import (
"net/http"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/util"
)
// SendMembership implements PUT /rooms/{roomID}/(join|kick|ban|unban|leave|invite)
// by building a m.room.member event then sending it to the room server
func GetCapabilities(
req *http.Request, queryAPI roomserverAPI.RoomserverQueryAPI,
) util.JSONResponse {
roomVersionsQueryReq := roomserverAPI.QueryRoomVersionCapabilitiesRequest{}
roomVersionsQueryRes := roomserverAPI.QueryRoomVersionCapabilitiesResponse{}
if err := queryAPI.QueryRoomVersionCapabilities(
req.Context(),
&roomVersionsQueryReq,
&roomVersionsQueryRes,
); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("queryAPI.QueryRoomVersionCapabilities failed")
return jsonerror.InternalServerError()
}
response := map[string]interface{}{
"capabilities": map[string]interface{}{
"m.room_versions": roomVersionsQueryRes,
},
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: response,
}
}

View file

@ -23,6 +23,7 @@ import (
appserviceAPI "github.com/matrix-org/dendrite/appservice/api" appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
roomserverVersion "github.com/matrix-org/dendrite/roomserver/version"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts"
@ -38,15 +39,16 @@ import (
// https://matrix.org/docs/spec/client_server/r0.2.0.html#post-matrix-client-r0-createroom // https://matrix.org/docs/spec/client_server/r0.2.0.html#post-matrix-client-r0-createroom
type createRoomRequest struct { type createRoomRequest struct {
Invite []string `json:"invite"` Invite []string `json:"invite"`
Name string `json:"name"` Name string `json:"name"`
Visibility string `json:"visibility"` Visibility string `json:"visibility"`
Topic string `json:"topic"` Topic string `json:"topic"`
Preset string `json:"preset"` Preset string `json:"preset"`
CreationContent map[string]interface{} `json:"creation_content"` CreationContent map[string]interface{} `json:"creation_content"`
InitialState []fledglingEvent `json:"initial_state"` InitialState []fledglingEvent `json:"initial_state"`
RoomAliasName string `json:"room_alias_name"` RoomAliasName string `json:"room_alias_name"`
GuestCanJoin bool `json:"guest_can_join"` GuestCanJoin bool `json:"guest_can_join"`
RoomVersion gomatrixserverlib.RoomVersion `json:"room_version"`
} }
const ( const (
@ -134,8 +136,8 @@ type fledglingEvent struct {
// CreateRoom implements /createRoom // CreateRoom implements /createRoom
func CreateRoom( func CreateRoom(
req *http.Request, device *authtypes.Device, req *http.Request, device *authtypes.Device,
cfg config.Dendrite, producer *producers.RoomserverProducer, cfg *config.Dendrite, producer *producers.RoomserverProducer,
accountDB *accounts.Database, aliasAPI roomserverAPI.RoomserverAliasAPI, accountDB accounts.Database, aliasAPI roomserverAPI.RoomserverAliasAPI,
asAPI appserviceAPI.AppServiceQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI,
) util.JSONResponse { ) util.JSONResponse {
// TODO (#267): Check room ID doesn't clash with an existing one, and we // TODO (#267): Check room ID doesn't clash with an existing one, and we
@ -148,8 +150,8 @@ func CreateRoom(
// nolint: gocyclo // nolint: gocyclo
func createRoom( func createRoom(
req *http.Request, device *authtypes.Device, req *http.Request, device *authtypes.Device,
cfg config.Dendrite, roomID string, producer *producers.RoomserverProducer, cfg *config.Dendrite, roomID string, producer *producers.RoomserverProducer,
accountDB *accounts.Database, aliasAPI roomserverAPI.RoomserverAliasAPI, accountDB accounts.Database, aliasAPI roomserverAPI.RoomserverAliasAPI,
asAPI appserviceAPI.AppServiceQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI,
) util.JSONResponse { ) util.JSONResponse {
logger := util.GetLogger(req.Context()) logger := util.GetLogger(req.Context())
@ -180,20 +182,34 @@ func createRoom(
} }
r.CreationContent["creator"] = userID r.CreationContent["creator"] = userID
r.CreationContent["room_version"] = "1" // TODO: We set this to 1 before we support Room versioning roomVersion := roomserverVersion.DefaultRoomVersion()
if r.RoomVersion != "" {
candidateVersion := gomatrixserverlib.RoomVersion(r.RoomVersion)
_, roomVersionError := roomserverVersion.SupportedRoomVersion(candidateVersion)
if roomVersionError != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.UnsupportedRoomVersion(roomVersionError.Error()),
}
}
roomVersion = candidateVersion
}
r.CreationContent["room_version"] = roomVersion
// TODO: visibility/presets/raw initial state // TODO: visibility/presets/raw initial state
// TODO: Create room alias association // TODO: Create room alias association
// Make sure this doesn't fall into an application service's namespace though! // Make sure this doesn't fall into an application service's namespace though!
logger.WithFields(log.Fields{ logger.WithFields(log.Fields{
"userID": userID, "userID": userID,
"roomID": roomID, "roomID": roomID,
"roomVersion": r.CreationContent["room_version"],
}).Info("Creating new room") }).Info("Creating new room")
profile, err := appserviceAPI.RetrieveUserProfile(req.Context(), userID, asAPI, accountDB) profile, err := appserviceAPI.RetrieveUserProfile(req.Context(), userID, asAPI, accountDB)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("appserviceAPI.RetrieveUserProfile failed")
return jsonerror.InternalServerError()
} }
membershipContent := gomatrixserverlib.MemberContent{ membershipContent := gomatrixserverlib.MemberContent{
@ -221,7 +237,7 @@ func createRoom(
historyVisibility = historyVisibilityShared historyVisibility = historyVisibilityShared
} }
var builtEvents []gomatrixserverlib.Event var builtEvents []gomatrixserverlib.HeaderedEvent
// send events into the room in order of: // send events into the room in order of:
// 1- m.room.create // 1- m.room.create
@ -276,33 +292,38 @@ func createRoom(
} }
err = builder.SetContent(e.Content) err = builder.SetContent(e.Content)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("builder.SetContent failed")
return jsonerror.InternalServerError()
} }
if i > 0 { if i > 0 {
builder.PrevEvents = []gomatrixserverlib.EventReference{builtEvents[i-1].EventReference()} builder.PrevEvents = []gomatrixserverlib.EventReference{builtEvents[i-1].EventReference()}
} }
var ev *gomatrixserverlib.Event var ev *gomatrixserverlib.Event
ev, err = buildEvent(&builder, &authEvents, cfg, evTime) ev, err = buildEvent(&builder, &authEvents, cfg, evTime, roomVersion)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("buildEvent failed")
return jsonerror.InternalServerError()
} }
if err = gomatrixserverlib.Allowed(*ev, &authEvents); err != nil { if err = gomatrixserverlib.Allowed(*ev, &authEvents); err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.Allowed failed")
return jsonerror.InternalServerError()
} }
// Add the event to the list of auth events // Add the event to the list of auth events
builtEvents = append(builtEvents, *ev) builtEvents = append(builtEvents, (*ev).Headered(roomVersion))
err = authEvents.AddEvent(ev) err = authEvents.AddEvent(ev)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("authEvents.AddEvent failed")
return jsonerror.InternalServerError()
} }
} }
// send events to the room server // send events to the room server
_, err = producer.SendEvents(req.Context(), builtEvents, cfg.Matrix.ServerName, nil) _, err = producer.SendEvents(req.Context(), builtEvents, cfg.Matrix.ServerName, nil)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("producer.SendEvents failed")
return jsonerror.InternalServerError()
} }
// TODO(#269): Reserve room alias while we create the room. This stops us // TODO(#269): Reserve room alias while we create the room. This stops us
@ -321,7 +342,8 @@ func createRoom(
var aliasResp roomserverAPI.SetRoomAliasResponse var aliasResp roomserverAPI.SetRoomAliasResponse
err = aliasAPI.SetRoomAlias(req.Context(), &aliasReq, &aliasResp) err = aliasAPI.SetRoomAlias(req.Context(), &aliasReq, &aliasResp)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("aliasAPI.SetRoomAlias failed")
return jsonerror.InternalServerError()
} }
if aliasResp.AliasExists { if aliasResp.AliasExists {
@ -344,8 +366,9 @@ func createRoom(
func buildEvent( func buildEvent(
builder *gomatrixserverlib.EventBuilder, builder *gomatrixserverlib.EventBuilder,
provider gomatrixserverlib.AuthEventProvider, provider gomatrixserverlib.AuthEventProvider,
cfg config.Dendrite, cfg *config.Dendrite,
evTime time.Time, evTime time.Time,
roomVersion gomatrixserverlib.RoomVersion,
) (*gomatrixserverlib.Event, error) { ) (*gomatrixserverlib.Event, error) {
eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder) eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder)
if err != nil { if err != nil {
@ -356,10 +379,12 @@ func buildEvent(
return nil, err return nil, err
} }
builder.AuthEvents = refs builder.AuthEvents = refs
eventID := fmt.Sprintf("$%s:%s", util.RandomString(16), cfg.Matrix.ServerName) event, err := builder.Build(
event, err := builder.Build(eventID, evTime, cfg.Matrix.ServerName, cfg.Matrix.KeyID, cfg.Matrix.PrivateKey) evTime, cfg.Matrix.ServerName, cfg.Matrix.KeyID,
cfg.Matrix.PrivateKey, roomVersion,
)
if err != nil { if err != nil {
return nil, fmt.Errorf("cannot build event %s : Builder failed to build. %s", builder.Type, err) return nil, fmt.Errorf("cannot build event %s : Builder failed to build. %w", builder.Type, err)
} }
return &event, nil return &event, nil
} }

View file

@ -21,7 +21,6 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/auth/storage/devices" "github.com/matrix-org/dendrite/clientapi/auth/storage/devices"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
@ -40,14 +39,19 @@ type deviceUpdateJSON struct {
DisplayName *string `json:"display_name"` DisplayName *string `json:"display_name"`
} }
type devicesDeleteJSON struct {
Devices []string `json:"devices"`
}
// GetDeviceByID handles /devices/{deviceID} // GetDeviceByID handles /devices/{deviceID}
func GetDeviceByID( func GetDeviceByID(
req *http.Request, deviceDB *devices.Database, device *authtypes.Device, req *http.Request, deviceDB devices.Database, device *authtypes.Device,
deviceID string, deviceID string,
) util.JSONResponse { ) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError()
} }
ctx := req.Context() ctx := req.Context()
@ -58,7 +62,8 @@ func GetDeviceByID(
JSON: jsonerror.NotFound("Unknown device"), JSON: jsonerror.NotFound("Unknown device"),
} }
} else if err != nil { } else if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("deviceDB.GetDeviceByID failed")
return jsonerror.InternalServerError()
} }
return util.JSONResponse{ return util.JSONResponse{
@ -72,18 +77,20 @@ func GetDeviceByID(
// GetDevicesByLocalpart handles /devices // GetDevicesByLocalpart handles /devices
func GetDevicesByLocalpart( func GetDevicesByLocalpart(
req *http.Request, deviceDB *devices.Database, device *authtypes.Device, req *http.Request, deviceDB devices.Database, device *authtypes.Device,
) util.JSONResponse { ) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError()
} }
ctx := req.Context() ctx := req.Context()
deviceList, err := deviceDB.GetDevicesByLocalpart(ctx, localpart) deviceList, err := deviceDB.GetDevicesByLocalpart(ctx, localpart)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("deviceDB.GetDevicesByLocalpart failed")
return jsonerror.InternalServerError()
} }
res := devicesJSON{} res := devicesJSON{}
@ -103,12 +110,13 @@ func GetDevicesByLocalpart(
// UpdateDeviceByID handles PUT on /devices/{deviceID} // UpdateDeviceByID handles PUT on /devices/{deviceID}
func UpdateDeviceByID( func UpdateDeviceByID(
req *http.Request, deviceDB *devices.Database, device *authtypes.Device, req *http.Request, deviceDB devices.Database, device *authtypes.Device,
deviceID string, deviceID string,
) util.JSONResponse { ) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError()
} }
ctx := req.Context() ctx := req.Context()
@ -119,7 +127,8 @@ func UpdateDeviceByID(
JSON: jsonerror.NotFound("Unknown device"), JSON: jsonerror.NotFound("Unknown device"),
} }
} else if err != nil { } else if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("deviceDB.GetDeviceByID failed")
return jsonerror.InternalServerError()
} }
if dev.UserID != device.UserID { if dev.UserID != device.UserID {
@ -134,11 +143,69 @@ func UpdateDeviceByID(
payload := deviceUpdateJSON{} payload := deviceUpdateJSON{}
if err := json.NewDecoder(req.Body).Decode(&payload); err != nil { if err := json.NewDecoder(req.Body).Decode(&payload); err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("json.NewDecoder.Decode failed")
return jsonerror.InternalServerError()
} }
if err := deviceDB.UpdateDevice(ctx, localpart, deviceID, payload.DisplayName); err != nil { if err := deviceDB.UpdateDevice(ctx, localpart, deviceID, payload.DisplayName); err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("deviceDB.UpdateDevice failed")
return jsonerror.InternalServerError()
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: struct{}{},
}
}
// DeleteDeviceById handles DELETE requests to /devices/{deviceId}
func DeleteDeviceById(
req *http.Request, deviceDB devices.Database, device *authtypes.Device,
deviceID string,
) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError()
}
ctx := req.Context()
defer req.Body.Close() // nolint: errcheck
if err := deviceDB.RemoveDevice(ctx, deviceID, localpart); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("deviceDB.RemoveDevice failed")
return jsonerror.InternalServerError()
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: struct{}{},
}
}
// DeleteDevices handles POST requests to /delete_devices
func DeleteDevices(
req *http.Request, deviceDB devices.Database, device *authtypes.Device,
) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError()
}
ctx := req.Context()
payload := devicesDeleteJSON{}
if err := json.NewDecoder(req.Body).Decode(&payload); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("json.NewDecoder.Decode failed")
return jsonerror.InternalServerError()
}
defer req.Body.Close() // nolint: errcheck
if err := deviceDB.RemoveDevices(ctx, localpart, payload.Devices); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("deviceDB.RemoveDevices failed")
return jsonerror.InternalServerError()
} }
return util.JSONResponse{ return util.JSONResponse{

View file

@ -63,7 +63,8 @@ func DirectoryRoom(
queryReq := roomserverAPI.GetRoomIDForAliasRequest{Alias: roomAlias} queryReq := roomserverAPI.GetRoomIDForAliasRequest{Alias: roomAlias}
var queryRes roomserverAPI.GetRoomIDForAliasResponse var queryRes roomserverAPI.GetRoomIDForAliasResponse
if err = rsAPI.GetRoomIDForAlias(req.Context(), &queryReq, &queryRes); err != nil { if err = rsAPI.GetRoomIDForAlias(req.Context(), &queryReq, &queryRes); err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("rsAPI.GetRoomIDForAlias failed")
return jsonerror.InternalServerError()
} }
res.RoomID = queryRes.RoomID res.RoomID = queryRes.RoomID
@ -76,7 +77,8 @@ func DirectoryRoom(
if fedErr != nil { if fedErr != nil {
// TODO: Return 502 if the remote server errored. // TODO: Return 502 if the remote server errored.
// TODO: Return 504 if the remote server timed out. // TODO: Return 504 if the remote server timed out.
return httputil.LogThenError(req, fedErr) util.GetLogger(req.Context()).WithError(err).Error("federation.LookupRoomAlias failed")
return jsonerror.InternalServerError()
} }
res.RoomID = fedRes.RoomID res.RoomID = fedRes.RoomID
res.fillServers(fedRes.Servers) res.fillServers(fedRes.Servers)
@ -94,7 +96,8 @@ func DirectoryRoom(
joinedHostsReq := federationSenderAPI.QueryJoinedHostServerNamesInRoomRequest{RoomID: res.RoomID} joinedHostsReq := federationSenderAPI.QueryJoinedHostServerNamesInRoomRequest{RoomID: res.RoomID}
var joinedHostsRes federationSenderAPI.QueryJoinedHostServerNamesInRoomResponse var joinedHostsRes federationSenderAPI.QueryJoinedHostServerNamesInRoomResponse
if err = fedSenderAPI.QueryJoinedHostServerNamesInRoom(req.Context(), &joinedHostsReq, &joinedHostsRes); err != nil { if err = fedSenderAPI.QueryJoinedHostServerNamesInRoom(req.Context(), &joinedHostsReq, &joinedHostsRes); err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("fedSenderAPI.QueryJoinedHostServerNamesInRoom failed")
return jsonerror.InternalServerError()
} }
res.fillServers(joinedHostsRes.ServerNames) res.fillServers(joinedHostsRes.ServerNames)
} }
@ -165,7 +168,8 @@ func SetLocalAlias(
} }
var queryRes roomserverAPI.SetRoomAliasResponse var queryRes roomserverAPI.SetRoomAliasResponse
if err := aliasAPI.SetRoomAlias(req.Context(), &queryReq, &queryRes); err != nil { if err := aliasAPI.SetRoomAlias(req.Context(), &queryReq, &queryRes); err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("aliasAPI.SetRoomAlias failed")
return jsonerror.InternalServerError()
} }
if queryRes.AliasExists { if queryRes.AliasExists {
@ -194,7 +198,8 @@ func RemoveLocalAlias(
} }
var creatorQueryRes roomserverAPI.GetCreatorIDForAliasResponse var creatorQueryRes roomserverAPI.GetCreatorIDForAliasResponse
if err := aliasAPI.GetCreatorIDForAlias(req.Context(), &creatorQueryReq, &creatorQueryRes); err != nil { if err := aliasAPI.GetCreatorIDForAlias(req.Context(), &creatorQueryReq, &creatorQueryRes); err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("aliasAPI.GetCreatorIDForAlias failed")
return jsonerror.InternalServerError()
} }
if creatorQueryRes.UserID == "" { if creatorQueryRes.UserID == "" {
@ -218,7 +223,8 @@ func RemoveLocalAlias(
} }
var queryRes roomserverAPI.RemoveRoomAliasResponse var queryRes roomserverAPI.RemoveRoomAliasResponse
if err := aliasAPI.RemoveRoomAlias(req.Context(), &queryReq, &queryRes); err != nil { if err := aliasAPI.RemoveRoomAlias(req.Context(), &queryReq, &queryRes); err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("aliasAPI.RemoveRoomAlias failed")
return jsonerror.InternalServerError()
} }
return util.JSONResponse{ return util.JSONResponse{

View file

@ -27,7 +27,7 @@ import (
// GetFilter implements GET /_matrix/client/r0/user/{userId}/filter/{filterId} // GetFilter implements GET /_matrix/client/r0/user/{userId}/filter/{filterId}
func GetFilter( func GetFilter(
req *http.Request, device *authtypes.Device, accountDB *accounts.Database, userID string, filterID string, req *http.Request, device *authtypes.Device, accountDB accounts.Database, userID string, filterID string,
) util.JSONResponse { ) util.JSONResponse {
if userID != device.UserID { if userID != device.UserID {
return util.JSONResponse{ return util.JSONResponse{
@ -37,7 +37,8 @@ func GetFilter(
} }
localpart, _, err := gomatrixserverlib.SplitID('@', userID) localpart, _, err := gomatrixserverlib.SplitID('@', userID)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError()
} }
filter, err := accountDB.GetFilter(req.Context(), localpart, filterID) filter, err := accountDB.GetFilter(req.Context(), localpart, filterID)
@ -63,7 +64,7 @@ type filterResponse struct {
//PutFilter implements POST /_matrix/client/r0/user/{userId}/filter //PutFilter implements POST /_matrix/client/r0/user/{userId}/filter
func PutFilter( func PutFilter(
req *http.Request, device *authtypes.Device, accountDB *accounts.Database, userID string, req *http.Request, device *authtypes.Device, accountDB accounts.Database, userID string,
) util.JSONResponse { ) util.JSONResponse {
if userID != device.UserID { if userID != device.UserID {
return util.JSONResponse{ return util.JSONResponse{
@ -74,7 +75,8 @@ func PutFilter(
localpart, _, err := gomatrixserverlib.SplitID('@', userID) localpart, _, err := gomatrixserverlib.SplitID('@', userID)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError()
} }
var filter gomatrixserverlib.Filter var filter gomatrixserverlib.Filter
@ -93,7 +95,8 @@ func PutFilter(
filterID, err := accountDB.PutFilter(req.Context(), localpart, &filter) filterID, err := accountDB.PutFilter(req.Context(), localpart, &filter)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("accountDB.PutFilter failed")
return jsonerror.InternalServerError()
} }
return util.JSONResponse{ return util.JSONResponse{

View file

@ -18,7 +18,6 @@ import (
"net/http" "net/http"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/common/config" "github.com/matrix-org/dendrite/common/config"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
@ -31,7 +30,7 @@ type getEventRequest struct {
device *authtypes.Device device *authtypes.Device
roomID string roomID string
eventID string eventID string
cfg config.Dendrite cfg *config.Dendrite
federation *gomatrixserverlib.FederationClient federation *gomatrixserverlib.FederationClient
keyRing gomatrixserverlib.KeyRing keyRing gomatrixserverlib.KeyRing
requestedEvent gomatrixserverlib.Event requestedEvent gomatrixserverlib.Event
@ -44,7 +43,7 @@ func GetEvent(
device *authtypes.Device, device *authtypes.Device,
roomID string, roomID string,
eventID string, eventID string,
cfg config.Dendrite, cfg *config.Dendrite,
queryAPI api.RoomserverQueryAPI, queryAPI api.RoomserverQueryAPI,
federation *gomatrixserverlib.FederationClient, federation *gomatrixserverlib.FederationClient,
keyRing gomatrixserverlib.KeyRing, keyRing gomatrixserverlib.KeyRing,
@ -55,7 +54,8 @@ func GetEvent(
var eventsResp api.QueryEventsByIDResponse var eventsResp api.QueryEventsByIDResponse
err := queryAPI.QueryEventsByID(req.Context(), &eventsReq, &eventsResp) err := queryAPI.QueryEventsByID(req.Context(), &eventsReq, &eventsResp)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("queryAPI.QueryEventsByID failed")
return jsonerror.InternalServerError()
} }
if len(eventsResp.Events) == 0 { if len(eventsResp.Events) == 0 {
@ -66,7 +66,7 @@ func GetEvent(
} }
} }
requestedEvent := eventsResp.Events[0] requestedEvent := eventsResp.Events[0].Event
r := getEventRequest{ r := getEventRequest{
req: req, req: req,
@ -89,7 +89,8 @@ func GetEvent(
} }
var stateResp api.QueryStateAfterEventsResponse var stateResp api.QueryStateAfterEventsResponse
if err := queryAPI.QueryStateAfterEvents(req.Context(), &stateReq, &stateResp); err != nil { if err := queryAPI.QueryStateAfterEvents(req.Context(), &stateReq, &stateResp); err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("queryAPI.QueryStateAfterEvents failed")
return jsonerror.InternalServerError()
} }
if !stateResp.RoomExists { if !stateResp.RoomExists {
@ -109,7 +110,8 @@ func GetEvent(
if stateEvent.StateKeyEquals(r.device.UserID) { if stateEvent.StateKeyEquals(r.device.UserID) {
membership, err := stateEvent.Membership() membership, err := stateEvent.Membership()
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("stateEvent.Membership failed")
return jsonerror.InternalServerError()
} }
if membership == gomatrixserverlib.Join { if membership == gomatrixserverlib.Join {
return util.JSONResponse{ return util.JSONResponse{

View file

@ -27,10 +27,12 @@ import (
"github.com/matrix-org/dendrite/clientapi/producers" "github.com/matrix-org/dendrite/clientapi/producers"
"github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/common/config" "github.com/matrix-org/dendrite/common/config"
"github.com/matrix-org/dendrite/roomserver/api"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/sirupsen/logrus"
) )
// JoinRoomByIDOrAlias implements the "/join/{roomIDOrAlias}" API. // JoinRoomByIDOrAlias implements the "/join/{roomIDOrAlias}" API.
@ -39,13 +41,13 @@ func JoinRoomByIDOrAlias(
req *http.Request, req *http.Request,
device *authtypes.Device, device *authtypes.Device,
roomIDOrAlias string, roomIDOrAlias string,
cfg config.Dendrite, cfg *config.Dendrite,
federation *gomatrixserverlib.FederationClient, federation *gomatrixserverlib.FederationClient,
producer *producers.RoomserverProducer, producer *producers.RoomserverProducer,
queryAPI roomserverAPI.RoomserverQueryAPI, queryAPI roomserverAPI.RoomserverQueryAPI,
aliasAPI roomserverAPI.RoomserverAliasAPI, aliasAPI roomserverAPI.RoomserverAliasAPI,
keyRing gomatrixserverlib.KeyRing, keyRing gomatrixserverlib.KeyRing,
accountDB *accounts.Database, accountDB accounts.Database,
) util.JSONResponse { ) util.JSONResponse {
var content map[string]interface{} // must be a JSON object var content map[string]interface{} // must be a JSON object
if resErr := httputil.UnmarshalJSONRequest(req, &content); resErr != nil { if resErr := httputil.UnmarshalJSONRequest(req, &content); resErr != nil {
@ -62,12 +64,14 @@ func JoinRoomByIDOrAlias(
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError()
} }
profile, err := accountDB.GetProfileByLocalpart(req.Context(), localpart) profile, err := accountDB.GetProfileByLocalpart(req.Context(), localpart)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("accountDB.GetProfileByLocalpart failed")
return jsonerror.InternalServerError()
} }
content["membership"] = gomatrixserverlib.Join content["membership"] = gomatrixserverlib.Join
@ -98,7 +102,7 @@ type joinRoomReq struct {
evTime time.Time evTime time.Time
content map[string]interface{} content map[string]interface{}
userID string userID string
cfg config.Dendrite cfg *config.Dendrite
federation *gomatrixserverlib.FederationClient federation *gomatrixserverlib.FederationClient
producer *producers.RoomserverProducer producer *producers.RoomserverProducer
queryAPI roomserverAPI.RoomserverQueryAPI queryAPI roomserverAPI.RoomserverQueryAPI
@ -119,7 +123,8 @@ func (r joinRoomReq) joinRoomByID(roomID string) util.JSONResponse {
} }
var queryRes roomserverAPI.QueryInvitesForUserResponse var queryRes roomserverAPI.QueryInvitesForUserResponse
if err := r.queryAPI.QueryInvitesForUser(r.req.Context(), &queryReq, &queryRes); err != nil { if err := r.queryAPI.QueryInvitesForUser(r.req.Context(), &queryReq, &queryRes); err != nil {
return httputil.LogThenError(r.req, err) util.GetLogger(r.req.Context()).WithError(err).Error("r.queryAPI.QueryInvitesForUser failed")
return jsonerror.InternalServerError()
} }
servers := []gomatrixserverlib.ServerName{} servers := []gomatrixserverlib.ServerName{}
@ -127,7 +132,8 @@ func (r joinRoomReq) joinRoomByID(roomID string) util.JSONResponse {
for _, userID := range queryRes.InviteSenderUserIDs { for _, userID := range queryRes.InviteSenderUserIDs {
_, domain, err := gomatrixserverlib.SplitID('@', userID) _, domain, err := gomatrixserverlib.SplitID('@', userID)
if err != nil { if err != nil {
return httputil.LogThenError(r.req, err) util.GetLogger(r.req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError()
} }
if !seenInInviterIDs[domain] { if !seenInInviterIDs[domain] {
servers = append(servers, domain) servers = append(servers, domain)
@ -141,7 +147,8 @@ func (r joinRoomReq) joinRoomByID(roomID string) util.JSONResponse {
// Note: It's no guarantee we'll succeed because a room isn't bound to the domain in its ID // Note: It's no guarantee we'll succeed because a room isn't bound to the domain in its ID
_, domain, err := gomatrixserverlib.SplitID('!', roomID) _, domain, err := gomatrixserverlib.SplitID('!', roomID)
if err != nil { if err != nil {
return httputil.LogThenError(r.req, err) util.GetLogger(r.req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError()
} }
if domain != r.cfg.Matrix.ServerName && !seenInInviterIDs[domain] { if domain != r.cfg.Matrix.ServerName && !seenInInviterIDs[domain] {
servers = append(servers, domain) servers = append(servers, domain)
@ -164,7 +171,8 @@ func (r joinRoomReq) joinRoomByAlias(roomAlias string) util.JSONResponse {
queryReq := roomserverAPI.GetRoomIDForAliasRequest{Alias: roomAlias} queryReq := roomserverAPI.GetRoomIDForAliasRequest{Alias: roomAlias}
var queryRes roomserverAPI.GetRoomIDForAliasResponse var queryRes roomserverAPI.GetRoomIDForAliasResponse
if err = r.aliasAPI.GetRoomIDForAlias(r.req.Context(), &queryReq, &queryRes); err != nil { if err = r.aliasAPI.GetRoomIDForAlias(r.req.Context(), &queryReq, &queryRes); err != nil {
return httputil.LogThenError(r.req, err) util.GetLogger(r.req.Context()).WithError(err).Error("r.aliasAPI.GetRoomIDForAlias failed")
return jsonerror.InternalServerError()
} }
if len(queryRes.RoomID) > 0 { if len(queryRes.RoomID) > 0 {
@ -194,7 +202,8 @@ func (r joinRoomReq) joinRoomByRemoteAlias(
} }
} }
} }
return httputil.LogThenError(r.req, err) util.GetLogger(r.req.Context()).WithError(err).Error("r.federation.LookupRoomAlias failed")
return jsonerror.InternalServerError()
} }
return r.joinRoomUsingServers(resp.RoomID, resp.Servers) return r.joinRoomUsingServers(resp.RoomID, resp.Servers)
@ -227,14 +236,26 @@ func (r joinRoomReq) joinRoomUsingServers(
var eb gomatrixserverlib.EventBuilder var eb gomatrixserverlib.EventBuilder
err := r.writeToBuilder(&eb, roomID) err := r.writeToBuilder(&eb, roomID)
if err != nil { if err != nil {
return httputil.LogThenError(r.req, err) util.GetLogger(r.req.Context()).WithError(err).Error("r.writeToBuilder failed")
return jsonerror.InternalServerError()
} }
var queryRes roomserverAPI.QueryLatestEventsAndStateResponse queryRes := roomserverAPI.QueryLatestEventsAndStateResponse{}
event, err := common.BuildEvent(r.req.Context(), &eb, r.cfg, r.evTime, r.queryAPI, &queryRes) event, err := common.BuildEvent(r.req.Context(), &eb, r.cfg, r.evTime, r.queryAPI, &queryRes)
if err == nil { if err == nil {
if _, err = r.producer.SendEvents(r.req.Context(), []gomatrixserverlib.Event{*event}, r.cfg.Matrix.ServerName, nil); err != nil { // If we have successfully built an event at this point then we can
return httputil.LogThenError(r.req, err) // assert that the room is a local room, as BuildEvent was able to
// add prev_events etc successfully.
if _, err = r.producer.SendEvents(
r.req.Context(),
[]gomatrixserverlib.HeaderedEvent{
(*event).Headered(queryRes.RoomVersion),
},
r.cfg.Matrix.ServerName,
nil,
); err != nil {
util.GetLogger(r.req.Context()).WithError(err).Error("r.producer.SendEvents failed")
return jsonerror.InternalServerError()
} }
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
@ -243,8 +264,16 @@ func (r joinRoomReq) joinRoomUsingServers(
}{roomID}, }{roomID},
} }
} }
// Otherwise, if we've reached here, then we haven't been able to populate
// prev_events etc for the room, therefore the room is probably federated.
// TODO: This needs to be re-thought, as in the case of an invite, the room
// will exist in the database in roomserver_rooms but won't have any state
// events, therefore this below check fails.
if err != common.ErrRoomNoExists { if err != common.ErrRoomNoExists {
return httputil.LogThenError(r.req, err) util.GetLogger(r.req.Context()).WithError(err).Error("common.BuildEvent failed")
return jsonerror.InternalServerError()
} }
if len(servers) == 0 { if len(servers) == 0 {
@ -262,7 +291,15 @@ func (r joinRoomReq) joinRoomUsingServers(
// There was a problem talking to one of the servers. // There was a problem talking to one of the servers.
util.GetLogger(r.req.Context()).WithError(lastErr).WithField("server", server).Warn("Failed to join room using server") util.GetLogger(r.req.Context()).WithError(lastErr).WithField("server", server).Warn("Failed to join room using server")
// Try the next server. // Try the next server.
continue if r.req.Context().Err() != nil {
// The request context has expired so don't bother trying any
// more servers - they will immediately fail due to the expired
// context.
break
} else {
// The request context hasn't expired yet so try the next server.
continue
}
} }
return *response return *response
} }
@ -280,7 +317,8 @@ func (r joinRoomReq) joinRoomUsingServers(
// 4) We couldn't fetch the public keys needed to verify the // 4) We couldn't fetch the public keys needed to verify the
// signatures on the state events. // signatures on the state events.
// 5) ... // 5) ...
return httputil.LogThenError(r.req, lastErr) util.GetLogger(r.req.Context()).WithError(lastErr).Error("failed to join through any server")
return jsonerror.InternalServerError()
} }
// joinRoomUsingServer tries to join a remote room using a given matrix server. // joinRoomUsingServer tries to join a remote room using a given matrix server.
@ -288,42 +326,70 @@ func (r joinRoomReq) joinRoomUsingServers(
// server was invalid this returns an error. // server was invalid this returns an error.
// Otherwise this returns a JSONResponse. // Otherwise this returns a JSONResponse.
func (r joinRoomReq) joinRoomUsingServer(roomID string, server gomatrixserverlib.ServerName) (*util.JSONResponse, error) { func (r joinRoomReq) joinRoomUsingServer(roomID string, server gomatrixserverlib.ServerName) (*util.JSONResponse, error) {
respMakeJoin, err := r.federation.MakeJoin(r.req.Context(), server, roomID, r.userID) // Ask the room server for information about room versions.
var request api.QueryRoomVersionCapabilitiesRequest
var response api.QueryRoomVersionCapabilitiesResponse
if err := r.queryAPI.QueryRoomVersionCapabilities(r.req.Context(), &request, &response); err != nil {
return nil, err
}
var supportedVersions []gomatrixserverlib.RoomVersion
for version := range response.AvailableRoomVersions {
supportedVersions = append(supportedVersions, version)
}
respMakeJoin, err := r.federation.MakeJoin(r.req.Context(), server, roomID, r.userID, supportedVersions)
if err != nil { if err != nil {
// TODO: Check if the user was not allowed to join the room. // TODO: Check if the user was not allowed to join the room.
return nil, err return nil, fmt.Errorf("r.federation.MakeJoin: %w", err)
} }
// Set all the fields to be what they should be, this should be a no-op // Set all the fields to be what they should be, this should be a no-op
// but it's possible that the remote server returned us something "odd" // but it's possible that the remote server returned us something "odd"
err = r.writeToBuilder(&respMakeJoin.JoinEvent, roomID) err = r.writeToBuilder(&respMakeJoin.JoinEvent, roomID)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("r.writeToBuilder: %w", err)
}
if respMakeJoin.RoomVersion == "" {
respMakeJoin.RoomVersion = gomatrixserverlib.RoomVersionV1
}
if _, err = respMakeJoin.RoomVersion.EventFormat(); err != nil {
return &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.UnsupportedRoomVersion(
fmt.Sprintf("Room version '%s' is not supported", respMakeJoin.RoomVersion),
),
}, nil
} }
eventID := fmt.Sprintf("$%s:%s", util.RandomString(16), r.cfg.Matrix.ServerName)
event, err := respMakeJoin.JoinEvent.Build( event, err := respMakeJoin.JoinEvent.Build(
eventID, r.evTime, r.cfg.Matrix.ServerName, r.cfg.Matrix.KeyID, r.cfg.Matrix.PrivateKey, r.evTime, r.cfg.Matrix.ServerName, r.cfg.Matrix.KeyID,
r.cfg.Matrix.PrivateKey, respMakeJoin.RoomVersion,
) )
if err != nil { if err != nil {
res := httputil.LogThenError(r.req, err) return nil, fmt.Errorf("respMakeJoin.JoinEvent.Build: %w", err)
return &res, nil
} }
respSendJoin, err := r.federation.SendJoin(r.req.Context(), server, event) respSendJoin, err := r.federation.SendJoin(r.req.Context(), server, event, respMakeJoin.RoomVersion)
if err != nil { if err != nil {
return nil, fmt.Errorf("r.federation.SendJoin: %w", err)
}
if err = r.checkSendJoinResponse(event, server, respMakeJoin, respSendJoin); err != nil {
return nil, err return nil, err
} }
if err = respSendJoin.Check(r.req.Context(), r.keyRing, event); err != nil { util.GetLogger(r.req.Context()).WithFields(logrus.Fields{
return nil, err "room_id": roomID,
} "num_auth_events": len(respSendJoin.AuthEvents),
"num_state_events": len(respSendJoin.StateEvents),
}).Info("Room join signature and auth verification passed")
if err = r.producer.SendEventWithState( if err = r.producer.SendEventWithState(
r.req.Context(), gomatrixserverlib.RespState(respSendJoin.RespState), event, r.req.Context(),
gomatrixserverlib.RespState(respSendJoin.RespState),
event.Headered(respMakeJoin.RoomVersion),
); err != nil { ); err != nil {
res := httputil.LogThenError(r.req, err) util.GetLogger(r.req.Context()).WithError(err).Error("r.producer.SendEventWithState")
return &res, nil
} }
return &util.JSONResponse{ return &util.JSONResponse{
@ -334,3 +400,49 @@ func (r joinRoomReq) joinRoomUsingServer(roomID string, server gomatrixserverlib
}{roomID}, }{roomID},
}, nil }, nil
} }
// checkSendJoinResponse checks that all of the signatures are correct
// and that the join is allowed by the supplied state.
func (r joinRoomReq) checkSendJoinResponse(
event gomatrixserverlib.Event,
server gomatrixserverlib.ServerName,
respMakeJoin gomatrixserverlib.RespMakeJoin,
respSendJoin gomatrixserverlib.RespSendJoin,
) error {
// A list of events that we have retried, if they were not included in
// the auth events supplied in the send_join.
retries := map[string]bool{}
retryCheck:
// TODO: Can we expand Check here to return a list of missing auth
// events rather than failing one at a time?
if err := respSendJoin.Check(r.req.Context(), r.keyRing, event); err != nil {
switch e := err.(type) {
case gomatrixserverlib.MissingAuthEventError:
// Check that we haven't already retried for this event, prevents
// us from ending up in endless loops
if !retries[e.AuthEventID] {
// Ask the server that we're talking to right now for the event
tx, txerr := r.federation.GetEvent(r.req.Context(), server, e.AuthEventID)
if txerr != nil {
return fmt.Errorf("r.federation.GetEvent: %w", txerr)
}
// For each event returned, add it to the auth events.
for _, pdu := range tx.PDUs {
ev, everr := gomatrixserverlib.NewEventFromUntrustedJSON(pdu, respMakeJoin.RoomVersion)
if everr != nil {
return fmt.Errorf("gomatrixserverlib.NewEventFromUntrustedJSON: %w", everr)
}
respSendJoin.AuthEvents = append(respSendJoin.AuthEvents, ev)
}
// Mark the event as retried and then give the check another go.
retries[e.AuthEventID] = true
goto retryCheck
}
return fmt.Errorf("respSendJoin (after retries): %w", e)
default:
return fmt.Errorf("respSendJoin: %w", err)
}
}
return nil
}

View file

@ -70,8 +70,8 @@ func passwordLogin() loginFlows {
// Login implements GET and POST /login // Login implements GET and POST /login
func Login( func Login(
req *http.Request, accountDB *accounts.Database, deviceDB *devices.Database, req *http.Request, accountDB accounts.Database, deviceDB devices.Database,
cfg config.Dendrite, cfg *config.Dendrite,
) util.JSONResponse { ) util.JSONResponse {
if req.Method == http.MethodGet { // TODO: support other forms of login other than password, depending on config options if req.Method == http.MethodGet { // TODO: support other forms of login other than password, depending on config options
return util.JSONResponse{ return util.JSONResponse{
@ -122,7 +122,8 @@ func Login(
token, err := auth.GenerateAccessToken() token, err := auth.GenerateAccessToken()
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("auth.GenerateAccessToken failed")
return jsonerror.InternalServerError()
} }
dev, err := getDevice(req.Context(), r, deviceDB, acc, token) dev, err := getDevice(req.Context(), r, deviceDB, acc, token)
@ -153,7 +154,7 @@ func Login(
func getDevice( func getDevice(
ctx context.Context, ctx context.Context,
r passwordRequest, r passwordRequest,
deviceDB *devices.Database, deviceDB devices.Database,
acc *authtypes.Account, acc *authtypes.Account,
token string, token string,
) (dev *authtypes.Device, err error) { ) (dev *authtypes.Device, err error) {

View file

@ -19,22 +19,24 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/auth/storage/devices" "github.com/matrix-org/dendrite/clientapi/auth/storage/devices"
"github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
// Logout handles POST /logout // Logout handles POST /logout
func Logout( func Logout(
req *http.Request, deviceDB *devices.Database, device *authtypes.Device, req *http.Request, deviceDB devices.Database, device *authtypes.Device,
) util.JSONResponse { ) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError()
} }
if err := deviceDB.RemoveDevice(req.Context(), device.ID, localpart); err != nil { if err := deviceDB.RemoveDevice(req.Context(), device.ID, localpart); err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("deviceDB.RemoveDevice failed")
return jsonerror.InternalServerError()
} }
return util.JSONResponse{ return util.JSONResponse{
@ -45,15 +47,17 @@ func Logout(
// LogoutAll handles POST /logout/all // LogoutAll handles POST /logout/all
func LogoutAll( func LogoutAll(
req *http.Request, deviceDB *devices.Database, device *authtypes.Device, req *http.Request, deviceDB devices.Database, device *authtypes.Device,
) util.JSONResponse { ) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError()
} }
if err := deviceDB.RemoveAllDevices(req.Context(), localpart); err != nil { if err := deviceDB.RemoveAllDevices(req.Context(), localpart); err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("deviceDB.RemoveAllDevices failed")
return jsonerror.InternalServerError()
} }
return util.JSONResponse{ return util.JSONResponse{

View file

@ -29,6 +29,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/threepid" "github.com/matrix-org/dendrite/clientapi/threepid"
"github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/common/config" "github.com/matrix-org/dendrite/common/config"
"github.com/matrix-org/dendrite/roomserver/api"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -40,11 +41,20 @@ var errMissingUserID = errors.New("'user_id' must be supplied")
// SendMembership implements PUT /rooms/{roomID}/(join|kick|ban|unban|leave|invite) // SendMembership implements PUT /rooms/{roomID}/(join|kick|ban|unban|leave|invite)
// by building a m.room.member event then sending it to the room server // by building a m.room.member event then sending it to the room server
func SendMembership( func SendMembership(
req *http.Request, accountDB *accounts.Database, device *authtypes.Device, req *http.Request, accountDB accounts.Database, device *authtypes.Device,
roomID string, membership string, cfg config.Dendrite, roomID string, membership string, cfg *config.Dendrite,
queryAPI roomserverAPI.RoomserverQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI, queryAPI roomserverAPI.RoomserverQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI,
producer *producers.RoomserverProducer, producer *producers.RoomserverProducer,
) util.JSONResponse { ) util.JSONResponse {
verReq := api.QueryRoomVersionForRoomRequest{RoomID: roomID}
verRes := api.QueryRoomVersionForRoomResponse{}
if err := queryAPI.QueryRoomVersionForRoom(req.Context(), &verReq, &verRes); err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.UnsupportedRoomVersion(err.Error()),
}
}
var body threepid.MembershipRequest var body threepid.MembershipRequest
if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil { if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil {
return *reqErr return *reqErr
@ -90,13 +100,18 @@ func SendMembership(
JSON: jsonerror.NotFound(err.Error()), JSON: jsonerror.NotFound(err.Error()),
} }
} else if err != nil { } else if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("buildMembershipEvent failed")
return jsonerror.InternalServerError()
} }
if _, err := producer.SendEvents( if _, err := producer.SendEvents(
req.Context(), []gomatrixserverlib.Event{*event}, cfg.Matrix.ServerName, nil, req.Context(),
[]gomatrixserverlib.HeaderedEvent{(*event).Headered(verRes.RoomVersion)},
cfg.Matrix.ServerName,
nil,
); err != nil { ); err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("producer.SendEvents failed")
return jsonerror.InternalServerError()
} }
var returnData interface{} = struct{}{} var returnData interface{} = struct{}{}
@ -116,10 +131,10 @@ func SendMembership(
func buildMembershipEvent( func buildMembershipEvent(
ctx context.Context, ctx context.Context,
body threepid.MembershipRequest, accountDB *accounts.Database, body threepid.MembershipRequest, accountDB accounts.Database,
device *authtypes.Device, device *authtypes.Device,
membership, roomID string, membership, roomID string,
cfg config.Dendrite, evTime time.Time, cfg *config.Dendrite, evTime time.Time,
queryAPI roomserverAPI.RoomserverQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI, queryAPI roomserverAPI.RoomserverQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI,
) (*gomatrixserverlib.Event, error) { ) (*gomatrixserverlib.Event, error) {
stateKey, reason, err := getMembershipStateKey(body, device, membership) stateKey, reason, err := getMembershipStateKey(body, device, membership)
@ -165,8 +180,8 @@ func buildMembershipEvent(
func loadProfile( func loadProfile(
ctx context.Context, ctx context.Context,
userID string, userID string,
cfg config.Dendrite, cfg *config.Dendrite,
accountDB *accounts.Database, accountDB accounts.Database,
asAPI appserviceAPI.AppServiceQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI,
) (*authtypes.Profile, error) { ) (*authtypes.Profile, error) {
_, serverName, err := gomatrixserverlib.SplitID('@', userID) _, serverName, err := gomatrixserverlib.SplitID('@', userID)
@ -214,9 +229,9 @@ func checkAndProcessThreepid(
req *http.Request, req *http.Request,
device *authtypes.Device, device *authtypes.Device,
body *threepid.MembershipRequest, body *threepid.MembershipRequest,
cfg config.Dendrite, cfg *config.Dendrite,
queryAPI roomserverAPI.RoomserverQueryAPI, queryAPI roomserverAPI.RoomserverQueryAPI,
accountDB *accounts.Database, accountDB accounts.Database,
producer *producers.RoomserverProducer, producer *producers.RoomserverProducer,
membership, roomID string, membership, roomID string,
evTime time.Time, evTime time.Time,
@ -242,7 +257,8 @@ func checkAndProcessThreepid(
JSON: jsonerror.NotFound(err.Error()), JSON: jsonerror.NotFound(err.Error()),
} }
} else if err != nil { } else if err != nil {
er := httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("threepid.CheckAndProcessInvite failed")
er := jsonerror.InternalServerError()
return inviteStored, &er return inviteStored, &er
} }
return return

View file

@ -17,8 +17,9 @@ package routing
import ( import (
"net/http" "net/http"
"github.com/matrix-org/dendrite/clientapi/auth/storage/accounts"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/common/config" "github.com/matrix-org/dendrite/common/config"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
@ -26,14 +27,18 @@ import (
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
type response struct { type getMembershipResponse struct {
Chunk []gomatrixserverlib.ClientEvent `json:"chunk"` Chunk []gomatrixserverlib.ClientEvent `json:"chunk"`
} }
type getJoinedRoomsResponse struct {
JoinedRooms []string `json:"joined_rooms"`
}
// GetMemberships implements GET /rooms/{roomId}/members // GetMemberships implements GET /rooms/{roomId}/members
func GetMemberships( func GetMemberships(
req *http.Request, device *authtypes.Device, roomID string, joinedOnly bool, req *http.Request, device *authtypes.Device, roomID string, joinedOnly bool,
_ config.Dendrite, _ *config.Dendrite,
queryAPI api.RoomserverQueryAPI, queryAPI api.RoomserverQueryAPI,
) util.JSONResponse { ) util.JSONResponse {
queryReq := api.QueryMembershipsForRoomRequest{ queryReq := api.QueryMembershipsForRoomRequest{
@ -43,7 +48,8 @@ func GetMemberships(
} }
var queryRes api.QueryMembershipsForRoomResponse var queryRes api.QueryMembershipsForRoomResponse
if err := queryAPI.QueryMembershipsForRoom(req.Context(), &queryReq, &queryRes); err != nil { if err := queryAPI.QueryMembershipsForRoom(req.Context(), &queryReq, &queryRes); err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("queryAPI.QueryMembershipsForRoom failed")
return jsonerror.InternalServerError()
} }
if !queryRes.HasBeenInRoom { if !queryRes.HasBeenInRoom {
@ -55,6 +61,27 @@ func GetMemberships(
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
JSON: response{queryRes.JoinEvents}, JSON: getMembershipResponse{queryRes.JoinEvents},
}
}
func GetJoinedRooms(
req *http.Request,
device *authtypes.Device,
accountsDB accounts.Database,
) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError()
}
joinedRooms, err := accountsDB.GetRoomIDsByLocalPart(req.Context(), localpart)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("accountsDB.GetRoomIDsByLocalPart failed")
return jsonerror.InternalServerError()
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: getJoinedRoomsResponse{joinedRooms},
} }
} }

View file

@ -36,7 +36,7 @@ import (
// GetProfile implements GET /profile/{userID} // GetProfile implements GET /profile/{userID}
func GetProfile( func GetProfile(
req *http.Request, accountDB *accounts.Database, cfg *config.Dendrite, req *http.Request, accountDB accounts.Database, cfg *config.Dendrite,
userID string, userID string,
asAPI appserviceAPI.AppServiceQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI,
federation *gomatrixserverlib.FederationClient, federation *gomatrixserverlib.FederationClient,
@ -50,7 +50,8 @@ func GetProfile(
} }
} }
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("getProfile failed")
return jsonerror.InternalServerError()
} }
return util.JSONResponse{ return util.JSONResponse{
@ -64,7 +65,7 @@ func GetProfile(
// GetAvatarURL implements GET /profile/{userID}/avatar_url // GetAvatarURL implements GET /profile/{userID}/avatar_url
func GetAvatarURL( func GetAvatarURL(
req *http.Request, accountDB *accounts.Database, cfg *config.Dendrite, req *http.Request, accountDB accounts.Database, cfg *config.Dendrite,
userID string, asAPI appserviceAPI.AppServiceQueryAPI, userID string, asAPI appserviceAPI.AppServiceQueryAPI,
federation *gomatrixserverlib.FederationClient, federation *gomatrixserverlib.FederationClient,
) util.JSONResponse { ) util.JSONResponse {
@ -77,7 +78,8 @@ func GetAvatarURL(
} }
} }
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("getProfile failed")
return jsonerror.InternalServerError()
} }
return util.JSONResponse{ return util.JSONResponse{
@ -90,7 +92,7 @@ func GetAvatarURL(
// SetAvatarURL implements PUT /profile/{userID}/avatar_url // SetAvatarURL implements PUT /profile/{userID}/avatar_url
func SetAvatarURL( func SetAvatarURL(
req *http.Request, accountDB *accounts.Database, device *authtypes.Device, req *http.Request, accountDB accounts.Database, device *authtypes.Device,
userID string, producer *producers.UserUpdateProducer, cfg *config.Dendrite, userID string, producer *producers.UserUpdateProducer, cfg *config.Dendrite,
rsProducer *producers.RoomserverProducer, queryAPI api.RoomserverQueryAPI, rsProducer *producers.RoomserverProducer, queryAPI api.RoomserverQueryAPI,
) util.JSONResponse { ) util.JSONResponse {
@ -116,7 +118,8 @@ func SetAvatarURL(
localpart, _, err := gomatrixserverlib.SplitID('@', userID) localpart, _, err := gomatrixserverlib.SplitID('@', userID)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError()
} }
evTime, err := httputil.ParseTSParam(req) evTime, err := httputil.ParseTSParam(req)
@ -129,16 +132,19 @@ func SetAvatarURL(
oldProfile, err := accountDB.GetProfileByLocalpart(req.Context(), localpart) oldProfile, err := accountDB.GetProfileByLocalpart(req.Context(), localpart)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("accountDB.GetProfileByLocalpart failed")
return jsonerror.InternalServerError()
} }
if err = accountDB.SetAvatarURL(req.Context(), localpart, r.AvatarURL); err != nil { if err = accountDB.SetAvatarURL(req.Context(), localpart, r.AvatarURL); err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("accountDB.SetAvatarURL failed")
return jsonerror.InternalServerError()
} }
memberships, err := accountDB.GetMembershipsByLocalpart(req.Context(), localpart) memberships, err := accountDB.GetMembershipsByLocalpart(req.Context(), localpart)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("accountDB.GetMembershipsByLocalpart failed")
return jsonerror.InternalServerError()
} }
newProfile := authtypes.Profile{ newProfile := authtypes.Profile{
@ -151,15 +157,18 @@ func SetAvatarURL(
req.Context(), memberships, newProfile, userID, cfg, evTime, queryAPI, req.Context(), memberships, newProfile, userID, cfg, evTime, queryAPI,
) )
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("buildMembershipEvents failed")
return jsonerror.InternalServerError()
} }
if _, err := rsProducer.SendEvents(req.Context(), events, cfg.Matrix.ServerName, nil); err != nil { if _, err := rsProducer.SendEvents(req.Context(), events, cfg.Matrix.ServerName, nil); err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("rsProducer.SendEvents failed")
return jsonerror.InternalServerError()
} }
if err := producer.SendUpdate(userID, changedKey, oldProfile.AvatarURL, r.AvatarURL); err != nil { if err := producer.SendUpdate(userID, changedKey, oldProfile.AvatarURL, r.AvatarURL); err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("producer.SendUpdate failed")
return jsonerror.InternalServerError()
} }
return util.JSONResponse{ return util.JSONResponse{
@ -170,7 +179,7 @@ func SetAvatarURL(
// GetDisplayName implements GET /profile/{userID}/displayname // GetDisplayName implements GET /profile/{userID}/displayname
func GetDisplayName( func GetDisplayName(
req *http.Request, accountDB *accounts.Database, cfg *config.Dendrite, req *http.Request, accountDB accounts.Database, cfg *config.Dendrite,
userID string, asAPI appserviceAPI.AppServiceQueryAPI, userID string, asAPI appserviceAPI.AppServiceQueryAPI,
federation *gomatrixserverlib.FederationClient, federation *gomatrixserverlib.FederationClient,
) util.JSONResponse { ) util.JSONResponse {
@ -183,7 +192,8 @@ func GetDisplayName(
} }
} }
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("getProfile failed")
return jsonerror.InternalServerError()
} }
return util.JSONResponse{ return util.JSONResponse{
@ -196,7 +206,7 @@ func GetDisplayName(
// SetDisplayName implements PUT /profile/{userID}/displayname // SetDisplayName implements PUT /profile/{userID}/displayname
func SetDisplayName( func SetDisplayName(
req *http.Request, accountDB *accounts.Database, device *authtypes.Device, req *http.Request, accountDB accounts.Database, device *authtypes.Device,
userID string, producer *producers.UserUpdateProducer, cfg *config.Dendrite, userID string, producer *producers.UserUpdateProducer, cfg *config.Dendrite,
rsProducer *producers.RoomserverProducer, queryAPI api.RoomserverQueryAPI, rsProducer *producers.RoomserverProducer, queryAPI api.RoomserverQueryAPI,
) util.JSONResponse { ) util.JSONResponse {
@ -222,7 +232,8 @@ func SetDisplayName(
localpart, _, err := gomatrixserverlib.SplitID('@', userID) localpart, _, err := gomatrixserverlib.SplitID('@', userID)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError()
} }
evTime, err := httputil.ParseTSParam(req) evTime, err := httputil.ParseTSParam(req)
@ -235,16 +246,19 @@ func SetDisplayName(
oldProfile, err := accountDB.GetProfileByLocalpart(req.Context(), localpart) oldProfile, err := accountDB.GetProfileByLocalpart(req.Context(), localpart)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("accountDB.GetProfileByLocalpart failed")
return jsonerror.InternalServerError()
} }
if err = accountDB.SetDisplayName(req.Context(), localpart, r.DisplayName); err != nil { if err = accountDB.SetDisplayName(req.Context(), localpart, r.DisplayName); err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("accountDB.SetDisplayName failed")
return jsonerror.InternalServerError()
} }
memberships, err := accountDB.GetMembershipsByLocalpart(req.Context(), localpart) memberships, err := accountDB.GetMembershipsByLocalpart(req.Context(), localpart)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("accountDB.GetMembershipsByLocalpart failed")
return jsonerror.InternalServerError()
} }
newProfile := authtypes.Profile{ newProfile := authtypes.Profile{
@ -257,15 +271,18 @@ func SetDisplayName(
req.Context(), memberships, newProfile, userID, cfg, evTime, queryAPI, req.Context(), memberships, newProfile, userID, cfg, evTime, queryAPI,
) )
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("buildMembershipEvents failed")
return jsonerror.InternalServerError()
} }
if _, err := rsProducer.SendEvents(req.Context(), events, cfg.Matrix.ServerName, nil); err != nil { if _, err := rsProducer.SendEvents(req.Context(), events, cfg.Matrix.ServerName, nil); err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("rsProducer.SendEvents failed")
return jsonerror.InternalServerError()
} }
if err := producer.SendUpdate(userID, changedKey, oldProfile.DisplayName, r.DisplayName); err != nil { if err := producer.SendUpdate(userID, changedKey, oldProfile.DisplayName, r.DisplayName); err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("producer.SendUpdate failed")
return jsonerror.InternalServerError()
} }
return util.JSONResponse{ return util.JSONResponse{
@ -279,7 +296,7 @@ func SetDisplayName(
// Returns an error when something goes wrong or specifically // Returns an error when something goes wrong or specifically
// common.ErrProfileNoExists when the profile doesn't exist. // common.ErrProfileNoExists when the profile doesn't exist.
func getProfile( func getProfile(
ctx context.Context, accountDB *accounts.Database, cfg *config.Dendrite, ctx context.Context, accountDB accounts.Database, cfg *config.Dendrite,
userID string, userID string,
asAPI appserviceAPI.AppServiceQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI,
federation *gomatrixserverlib.FederationClient, federation *gomatrixserverlib.FederationClient,
@ -321,10 +338,16 @@ func buildMembershipEvents(
memberships []authtypes.Membership, memberships []authtypes.Membership,
newProfile authtypes.Profile, userID string, cfg *config.Dendrite, newProfile authtypes.Profile, userID string, cfg *config.Dendrite,
evTime time.Time, queryAPI api.RoomserverQueryAPI, evTime time.Time, queryAPI api.RoomserverQueryAPI,
) ([]gomatrixserverlib.Event, error) { ) ([]gomatrixserverlib.HeaderedEvent, error) {
evs := []gomatrixserverlib.Event{} evs := []gomatrixserverlib.HeaderedEvent{}
for _, membership := range memberships { for _, membership := range memberships {
verReq := api.QueryRoomVersionForRoomRequest{RoomID: membership.RoomID}
verRes := api.QueryRoomVersionForRoomResponse{}
if err := queryAPI.QueryRoomVersionForRoom(ctx, &verReq, &verRes); err != nil {
return []gomatrixserverlib.HeaderedEvent{}, err
}
builder := gomatrixserverlib.EventBuilder{ builder := gomatrixserverlib.EventBuilder{
Sender: userID, Sender: userID,
RoomID: membership.RoomID, RoomID: membership.RoomID,
@ -343,12 +366,12 @@ func buildMembershipEvents(
return nil, err return nil, err
} }
event, err := common.BuildEvent(ctx, &builder, *cfg, evTime, queryAPI, nil) event, err := common.BuildEvent(ctx, &builder, cfg, evTime, queryAPI, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
evs = append(evs, *event) evs = append(evs, (*event).Headered(verRes.RoomVersion))
} }
return evs, nil return evs, nil

View file

@ -43,6 +43,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/common"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/tokens"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -361,7 +362,7 @@ func UsernameMatchesMultipleExclusiveNamespaces(
// Check namespaces and see if more than one match // Check namespaces and see if more than one match
matchCount := 0 matchCount := 0
for _, appservice := range cfg.Derived.ApplicationServices { for _, appservice := range cfg.Derived.ApplicationServices {
if appservice.IsInterestedInUserID(userID) { if appservice.OwnsNamespaceCoveringUserId(userID) {
if matchCount++; matchCount > 1 { if matchCount++; matchCount > 1 {
return true return true
} }
@ -439,16 +440,18 @@ func validateApplicationService(
// http://matrix.org/speculator/spec/HEAD/client_server/unstable.html#post-matrix-client-unstable-register // http://matrix.org/speculator/spec/HEAD/client_server/unstable.html#post-matrix-client-unstable-register
func Register( func Register(
req *http.Request, req *http.Request,
accountDB *accounts.Database, accountDB accounts.Database,
deviceDB *devices.Database, deviceDB devices.Database,
cfg *config.Dendrite, cfg *config.Dendrite,
) util.JSONResponse { ) util.JSONResponse {
var r registerRequest var r registerRequest
resErr := httputil.UnmarshalJSONRequest(req, &r) resErr := httputil.UnmarshalJSONRequest(req, &r)
if resErr != nil { if resErr != nil {
return *resErr return *resErr
} }
if req.URL.Query().Get("kind") == "guest" {
return handleGuestRegistration(req, r, cfg, accountDB, deviceDB)
}
// Retrieve or generate the sessionID // Retrieve or generate the sessionID
sessionID := r.Auth.Session sessionID := r.Auth.Session
@ -468,7 +471,8 @@ func Register(
if r.Username == "" { if r.Username == "" {
id, err := accountDB.GetNewNumericLocalpart(req.Context()) id, err := accountDB.GetNewNumericLocalpart(req.Context())
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("accountDB.GetNewNumericLocalpart failed")
return jsonerror.InternalServerError()
} }
r.Username = strconv.FormatInt(id, 10) r.Username = strconv.FormatInt(id, 10)
@ -505,6 +509,51 @@ func Register(
return handleRegistrationFlow(req, r, sessionID, cfg, accountDB, deviceDB) return handleRegistrationFlow(req, r, sessionID, cfg, accountDB, deviceDB)
} }
func handleGuestRegistration(
req *http.Request,
r registerRequest,
cfg *config.Dendrite,
accountDB accounts.Database,
deviceDB devices.Database,
) util.JSONResponse {
acc, err := accountDB.CreateGuestAccount(req.Context())
if err != nil {
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: jsonerror.Unknown("failed to create account: " + err.Error()),
}
}
token, err := tokens.GenerateLoginToken(tokens.TokenOptions{
ServerPrivateKey: cfg.Matrix.PrivateKey.Seed(),
ServerName: string(acc.ServerName),
UserID: acc.UserID,
})
if err != nil {
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: jsonerror.Unknown("Failed to generate access token"),
}
}
//we don't allow guests to specify their own device_id
dev, err := deviceDB.CreateDevice(req.Context(), acc.Localpart, nil, token, r.InitialDisplayName)
if err != nil {
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: jsonerror.Unknown("failed to create device: " + err.Error()),
}
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: registerResponse{
UserID: dev.UserID,
AccessToken: dev.AccessToken,
HomeServer: acc.ServerName,
DeviceID: dev.ID,
},
}
}
// handleRegistrationFlow will direct and complete registration flow stages // handleRegistrationFlow will direct and complete registration flow stages
// that the client has requested. // that the client has requested.
// nolint: gocyclo // nolint: gocyclo
@ -513,8 +562,8 @@ func handleRegistrationFlow(
r registerRequest, r registerRequest,
sessionID string, sessionID string,
cfg *config.Dendrite, cfg *config.Dendrite,
accountDB *accounts.Database, accountDB accounts.Database,
deviceDB *devices.Database, deviceDB devices.Database,
) util.JSONResponse { ) util.JSONResponse {
// TODO: Shared secret registration (create new user scripts) // TODO: Shared secret registration (create new user scripts)
// TODO: Enable registration config flag // TODO: Enable registration config flag
@ -545,7 +594,8 @@ func handleRegistrationFlow(
valid, err := isValidMacLogin(cfg, r.Username, r.Password, r.Admin, r.Auth.Mac) valid, err := isValidMacLogin(cfg, r.Username, r.Password, r.Admin, r.Auth.Mac)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("isValidMacLogin failed")
return jsonerror.InternalServerError()
} else if !valid { } else if !valid {
return util.MessageResponse(http.StatusForbidden, "HMAC incorrect") return util.MessageResponse(http.StatusForbidden, "HMAC incorrect")
} }
@ -611,8 +661,8 @@ func handleApplicationServiceRegistration(
req *http.Request, req *http.Request,
r registerRequest, r registerRequest,
cfg *config.Dendrite, cfg *config.Dendrite,
accountDB *accounts.Database, accountDB accounts.Database,
deviceDB *devices.Database, deviceDB devices.Database,
) util.JSONResponse { ) util.JSONResponse {
// Check if we previously had issues extracting the access token from the // Check if we previously had issues extracting the access token from the
// request. // request.
@ -650,8 +700,8 @@ func checkAndCompleteFlow(
r registerRequest, r registerRequest,
sessionID string, sessionID string,
cfg *config.Dendrite, cfg *config.Dendrite,
accountDB *accounts.Database, accountDB accounts.Database,
deviceDB *devices.Database, deviceDB devices.Database,
) util.JSONResponse { ) util.JSONResponse {
if checkFlowCompleted(flow, cfg.Derived.Registration.Flows) { if checkFlowCompleted(flow, cfg.Derived.Registration.Flows) {
// This flow was completed, registration can continue // This flow was completed, registration can continue
@ -673,8 +723,8 @@ func checkAndCompleteFlow(
// LegacyRegister process register requests from the legacy v1 API // LegacyRegister process register requests from the legacy v1 API
func LegacyRegister( func LegacyRegister(
req *http.Request, req *http.Request,
accountDB *accounts.Database, accountDB accounts.Database,
deviceDB *devices.Database, deviceDB devices.Database,
cfg *config.Dendrite, cfg *config.Dendrite,
) util.JSONResponse { ) util.JSONResponse {
var r legacyRegisterRequest var r legacyRegisterRequest
@ -701,7 +751,8 @@ func LegacyRegister(
valid, err := isValidMacLogin(cfg, r.Username, r.Password, r.Admin, r.Mac) valid, err := isValidMacLogin(cfg, r.Username, r.Password, r.Admin, r.Mac)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("isValidMacLogin failed")
return jsonerror.InternalServerError()
} }
if !valid { if !valid {
@ -757,8 +808,8 @@ func parseAndValidateLegacyLogin(req *http.Request, r *legacyRegisterRequest) *u
// not all // not all
func completeRegistration( func completeRegistration(
ctx context.Context, ctx context.Context,
accountDB *accounts.Database, accountDB accounts.Database,
deviceDB *devices.Database, deviceDB devices.Database,
username, password, appserviceID string, username, password, appserviceID string,
inhibitLogin common.WeakBoolean, inhibitLogin common.WeakBoolean,
displayName, deviceID *string, displayName, deviceID *string,
@ -934,8 +985,8 @@ type availableResponse struct {
// RegisterAvailable checks if the username is already taken or invalid. // RegisterAvailable checks if the username is already taken or invalid.
func RegisterAvailable( func RegisterAvailable(
req *http.Request, req *http.Request,
cfg config.Dendrite, cfg *config.Dendrite,
accountDB *accounts.Database, accountDB accounts.Database,
) util.JSONResponse { ) util.JSONResponse {
username := req.URL.Query().Get("username") username := req.URL.Query().Get("username")
@ -949,7 +1000,7 @@ func RegisterAvailable(
// Check if this username is reserved by an application service // Check if this username is reserved by an application service
userID := userutil.MakeUserID(username, cfg.Matrix.ServerName) userID := userutil.MakeUserID(username, cfg.Matrix.ServerName)
for _, appservice := range cfg.Derived.ApplicationServices { for _, appservice := range cfg.Derived.ApplicationServices {
if appservice.IsInterestedInUserID(userID) { if appservice.OwnsNamespaceCoveringUserId(userID) {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: jsonerror.UserInUse("Desired user ID is reserved by an application service."), JSON: jsonerror.UserInUse("Desired user ID is reserved by an application service."),

View file

@ -40,7 +40,7 @@ func newTag() gomatrix.TagContent {
// GetTags implements GET /_matrix/client/r0/user/{userID}/rooms/{roomID}/tags // GetTags implements GET /_matrix/client/r0/user/{userID}/rooms/{roomID}/tags
func GetTags( func GetTags(
req *http.Request, req *http.Request,
accountDB *accounts.Database, accountDB accounts.Database,
device *authtypes.Device, device *authtypes.Device,
userID string, userID string,
roomID string, roomID string,
@ -56,7 +56,8 @@ func GetTags(
_, data, err := obtainSavedTags(req, userID, roomID, accountDB) _, data, err := obtainSavedTags(req, userID, roomID, accountDB)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("obtainSavedTags failed")
return jsonerror.InternalServerError()
} }
if data == nil { if data == nil {
@ -77,7 +78,7 @@ func GetTags(
// the tag to the "map" and saving the new "map" to the DB // the tag to the "map" and saving the new "map" to the DB
func PutTag( func PutTag(
req *http.Request, req *http.Request,
accountDB *accounts.Database, accountDB accounts.Database,
device *authtypes.Device, device *authtypes.Device,
userID string, userID string,
roomID string, roomID string,
@ -99,20 +100,23 @@ func PutTag(
localpart, data, err := obtainSavedTags(req, userID, roomID, accountDB) localpart, data, err := obtainSavedTags(req, userID, roomID, accountDB)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("obtainSavedTags failed")
return jsonerror.InternalServerError()
} }
var tagContent gomatrix.TagContent var tagContent gomatrix.TagContent
if data != nil { if data != nil {
if err = json.Unmarshal(data.Content, &tagContent); err != nil { if err = json.Unmarshal(data.Content, &tagContent); err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("json.Unmarshal failed")
return jsonerror.InternalServerError()
} }
} else { } else {
tagContent = newTag() tagContent = newTag()
} }
tagContent.Tags[tag] = properties tagContent.Tags[tag] = properties
if err = saveTagData(req, localpart, roomID, accountDB, tagContent); err != nil { if err = saveTagData(req, localpart, roomID, accountDB, tagContent); err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("saveTagData failed")
return jsonerror.InternalServerError()
} }
// Send data to syncProducer in order to inform clients of changes // Send data to syncProducer in order to inform clients of changes
@ -134,7 +138,7 @@ func PutTag(
// the "map" and then saving the new "map" in the DB // the "map" and then saving the new "map" in the DB
func DeleteTag( func DeleteTag(
req *http.Request, req *http.Request,
accountDB *accounts.Database, accountDB accounts.Database,
device *authtypes.Device, device *authtypes.Device,
userID string, userID string,
roomID string, roomID string,
@ -151,7 +155,8 @@ func DeleteTag(
localpart, data, err := obtainSavedTags(req, userID, roomID, accountDB) localpart, data, err := obtainSavedTags(req, userID, roomID, accountDB)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("obtainSavedTags failed")
return jsonerror.InternalServerError()
} }
// If there are no tags in the database, exit // If there are no tags in the database, exit
@ -166,7 +171,8 @@ func DeleteTag(
var tagContent gomatrix.TagContent var tagContent gomatrix.TagContent
err = json.Unmarshal(data.Content, &tagContent) err = json.Unmarshal(data.Content, &tagContent)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("json.Unmarshal failed")
return jsonerror.InternalServerError()
} }
// Check whether the tag to be deleted exists // Check whether the tag to be deleted exists
@ -180,7 +186,8 @@ func DeleteTag(
} }
} }
if err = saveTagData(req, localpart, roomID, accountDB, tagContent); err != nil { if err = saveTagData(req, localpart, roomID, accountDB, tagContent); err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("saveTagData failed")
return jsonerror.InternalServerError()
} }
// Send data to syncProducer in order to inform clients of changes // Send data to syncProducer in order to inform clients of changes
@ -203,7 +210,7 @@ func obtainSavedTags(
req *http.Request, req *http.Request,
userID string, userID string,
roomID string, roomID string,
accountDB *accounts.Database, accountDB accounts.Database,
) (string, *gomatrixserverlib.ClientEvent, error) { ) (string, *gomatrixserverlib.ClientEvent, error) {
localpart, _, err := gomatrixserverlib.SplitID('@', userID) localpart, _, err := gomatrixserverlib.SplitID('@', userID)
if err != nil { if err != nil {
@ -222,7 +229,7 @@ func saveTagData(
req *http.Request, req *http.Request,
localpart string, localpart string,
roomID string, roomID string,
accountDB *accounts.Database, accountDB accounts.Database,
Tag gomatrix.TagContent, Tag gomatrix.TagContent,
) error { ) error {
newTagData, err := json.Marshal(Tag) newTagData, err := json.Marshal(Tag)

View file

@ -47,18 +47,18 @@ const pathPrefixUnstable = "/_matrix/client/unstable"
// applied: // applied:
// nolint: gocyclo // nolint: gocyclo
func Setup( func Setup(
apiMux *mux.Router, cfg config.Dendrite, apiMux *mux.Router, cfg *config.Dendrite,
producer *producers.RoomserverProducer, producer *producers.RoomserverProducer,
queryAPI roomserverAPI.RoomserverQueryAPI, queryAPI roomserverAPI.RoomserverQueryAPI,
aliasAPI roomserverAPI.RoomserverAliasAPI, aliasAPI roomserverAPI.RoomserverAliasAPI,
asAPI appserviceAPI.AppServiceQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI,
accountDB *accounts.Database, accountDB accounts.Database,
deviceDB *devices.Database, deviceDB devices.Database,
federation *gomatrixserverlib.FederationClient, federation *gomatrixserverlib.FederationClient,
keyRing gomatrixserverlib.KeyRing, keyRing gomatrixserverlib.KeyRing,
userUpdateProducer *producers.UserUpdateProducer, userUpdateProducer *producers.UserUpdateProducer,
syncProducer *producers.SyncAPIProducer, syncProducer *producers.SyncAPIProducer,
typingProducer *producers.TypingServerProducer, eduProducer *producers.EDUServerProducer,
transactionsCache *transactions.Cache, transactionsCache *transactions.Cache,
federationSender federationSenderAPI.FederationSenderQueryAPI, federationSender federationSenderAPI.FederationSenderQueryAPI,
) { ) {
@ -105,6 +105,12 @@ func Setup(
) )
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/joined_rooms",
common.MakeAuthAPI("joined_rooms", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
return GetJoinedRooms(req, device, accountDB)
}),
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/{membership:(?:join|kick|ban|unban|leave|invite)}", r0mux.Handle("/rooms/{roomID}/{membership:(?:join|kick|ban|unban|leave|invite)}",
common.MakeAuthAPI("membership", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { common.MakeAuthAPI("membership", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars, err := common.URLDecodeMapValues(mux.Vars(req)) vars, err := common.URLDecodeMapValues(mux.Vars(req))
@ -143,6 +149,31 @@ func Setup(
return GetEvent(req, device, vars["roomID"], vars["eventID"], cfg, queryAPI, federation, keyRing) return GetEvent(req, device, vars["roomID"], vars["eventID"], cfg, queryAPI, federation, keyRing)
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/state", common.MakeAuthAPI("room_state", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars, err := common.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
return OnIncomingStateRequest(req.Context(), queryAPI, vars["roomID"])
})).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/state/{type}", common.MakeAuthAPI("room_state", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars, err := common.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
return OnIncomingStateTypeRequest(req.Context(), queryAPI, vars["roomID"], vars["type"], "")
})).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/state/{type}/{stateKey}", common.MakeAuthAPI("room_state", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars, err := common.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
return OnIncomingStateTypeRequest(req.Context(), queryAPI, vars["roomID"], vars["type"], vars["stateKey"])
})).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/state/{eventType:[^/]+/?}", r0mux.Handle("/rooms/{roomID}/state/{eventType:[^/]+/?}",
common.MakeAuthAPI("send_message", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { common.MakeAuthAPI("send_message", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars, err := common.URLDecodeMapValues(mux.Vars(req)) vars, err := common.URLDecodeMapValues(mux.Vars(req))
@ -158,6 +189,7 @@ func Setup(
return SendEvent(req, device, vars["roomID"], eventType, nil, &emptyString, cfg, queryAPI, producer, nil) return SendEvent(req, device, vars["roomID"], eventType, nil, &emptyString, cfg, queryAPI, producer, nil)
}), }),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/state/{eventType}/{stateKey}", r0mux.Handle("/rooms/{roomID}/state/{eventType}/{stateKey}",
common.MakeAuthAPI("send_message", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { common.MakeAuthAPI("send_message", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars, err := common.URLDecodeMapValues(mux.Vars(req)) vars, err := common.URLDecodeMapValues(mux.Vars(req))
@ -170,11 +202,11 @@ func Setup(
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/register", common.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse { r0mux.Handle("/register", common.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse {
return Register(req, accountDB, deviceDB, &cfg) return Register(req, accountDB, deviceDB, cfg)
})).Methods(http.MethodPost, http.MethodOptions) })).Methods(http.MethodPost, http.MethodOptions)
v1mux.Handle("/register", common.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse { v1mux.Handle("/register", common.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse {
return LegacyRegister(req, accountDB, deviceDB, &cfg) return LegacyRegister(req, accountDB, deviceDB, cfg)
})).Methods(http.MethodPost, http.MethodOptions) })).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/register/available", common.MakeExternalAPI("registerAvailable", func(req *http.Request) util.JSONResponse { r0mux.Handle("/register/available", common.MakeExternalAPI("registerAvailable", func(req *http.Request) util.JSONResponse {
@ -187,7 +219,7 @@ func Setup(
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return DirectoryRoom(req, vars["roomAlias"], federation, &cfg, aliasAPI, federationSender) return DirectoryRoom(req, vars["roomAlias"], federation, cfg, aliasAPI, federationSender)
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
@ -197,7 +229,7 @@ func Setup(
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return SetLocalAlias(req, device, vars["roomAlias"], &cfg, aliasAPI) return SetLocalAlias(req, device, vars["roomAlias"], cfg, aliasAPI)
}), }),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
@ -229,7 +261,7 @@ func Setup(
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return SendTyping(req, device, vars["roomID"], vars["userID"], accountDB, typingProducer) return SendTyping(req, device, vars["roomID"], vars["userID"], accountDB, eduProducer)
}), }),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
@ -301,7 +333,7 @@ func Setup(
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return GetProfile(req, accountDB, &cfg, vars["userID"], asAPI, federation) return GetProfile(req, accountDB, cfg, vars["userID"], asAPI, federation)
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
@ -311,7 +343,7 @@ func Setup(
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return GetAvatarURL(req, accountDB, &cfg, vars["userID"], asAPI, federation) return GetAvatarURL(req, accountDB, cfg, vars["userID"], asAPI, federation)
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
@ -321,7 +353,7 @@ func Setup(
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return SetAvatarURL(req, accountDB, device, vars["userID"], userUpdateProducer, &cfg, producer, queryAPI) return SetAvatarURL(req, accountDB, device, vars["userID"], userUpdateProducer, cfg, producer, queryAPI)
}), }),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
// Browsers use the OPTIONS HTTP method to check if the CORS policy allows // Browsers use the OPTIONS HTTP method to check if the CORS policy allows
@ -333,7 +365,7 @@ func Setup(
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return GetDisplayName(req, accountDB, &cfg, vars["userID"], asAPI, federation) return GetDisplayName(req, accountDB, cfg, vars["userID"], asAPI, federation)
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
@ -343,7 +375,7 @@ func Setup(
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return SetDisplayName(req, accountDB, device, vars["userID"], userUpdateProducer, &cfg, producer, queryAPI) return SetDisplayName(req, accountDB, device, vars["userID"], userUpdateProducer, cfg, producer, queryAPI)
}), }),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
// Browsers use the OPTIONS HTTP method to check if the CORS policy allows // Browsers use the OPTIONS HTTP method to check if the CORS policy allows
@ -390,7 +422,7 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
unstableMux.Handle("/thirdparty/protocols", r0mux.Handle("/thirdparty/protocols",
common.MakeExternalAPI("thirdparty_protocols", func(req *http.Request) util.JSONResponse { common.MakeExternalAPI("thirdparty_protocols", func(req *http.Request) util.JSONResponse {
// TODO: Return the third party protcols // TODO: Return the third party protcols
return util.JSONResponse{ return util.JSONResponse{
@ -430,6 +462,26 @@ func Setup(
}), }),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/user/{userID}/account_data/{type}",
common.MakeAuthAPI("user_account_data", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars, err := common.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
return GetAccountData(req, accountDB, device, vars["userID"], "", vars["type"])
}),
).Methods(http.MethodGet)
r0mux.Handle("/user/{userID}/rooms/{roomID}/account_data/{type}",
common.MakeAuthAPI("user_account_data", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars, err := common.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
return GetAccountData(req, accountDB, device, vars["userID"], vars["roomID"], vars["type"])
}),
).Methods(http.MethodGet)
r0mux.Handle("/rooms/{roomID}/members", r0mux.Handle("/rooms/{roomID}/members",
common.MakeAuthAPI("rooms_members", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { common.MakeAuthAPI("rooms_members", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars, err := common.URLDecodeMapValues(mux.Vars(req)) vars, err := common.URLDecodeMapValues(mux.Vars(req))
@ -483,6 +535,22 @@ func Setup(
}), }),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/devices/{deviceID}",
common.MakeAuthAPI("delete_device", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars, err := common.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
return DeleteDeviceById(req, deviceDB, device, vars["deviceID"])
}),
).Methods(http.MethodDelete, http.MethodOptions)
r0mux.Handle("/delete_devices",
common.MakeAuthAPI("delete_devices", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
return DeleteDevices(req, deviceDB, device)
}),
).Methods(http.MethodPost, http.MethodOptions)
// Stub implementations for sytest // Stub implementations for sytest
r0mux.Handle("/events", r0mux.Handle("/events",
common.MakeExternalAPI("events", func(req *http.Request) util.JSONResponse { common.MakeExternalAPI("events", func(req *http.Request) util.JSONResponse {
@ -531,4 +599,10 @@ func Setup(
return DeleteTag(req, accountDB, device, vars["userId"], vars["roomId"], vars["tag"], syncProducer) return DeleteTag(req, accountDB, device, vars["userId"], vars["roomId"], vars["tag"], syncProducer)
}), }),
).Methods(http.MethodDelete, http.MethodOptions) ).Methods(http.MethodDelete, http.MethodOptions)
r0mux.Handle("/capabilities",
common.MakeAuthAPI("capabilities", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
return GetCapabilities(req, queryAPI)
}),
).Methods(http.MethodGet)
} }

View file

@ -27,6 +27,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/sirupsen/logrus"
) )
// http://matrix.org/docs/spec/client_server/r0.2.0.html#put-matrix-client-r0-rooms-roomid-send-eventtype-txnid // http://matrix.org/docs/spec/client_server/r0.2.0.html#put-matrix-client-r0-rooms-roomid-send-eventtype-txnid
@ -43,11 +44,20 @@ func SendEvent(
req *http.Request, req *http.Request,
device *authtypes.Device, device *authtypes.Device,
roomID, eventType string, txnID, stateKey *string, roomID, eventType string, txnID, stateKey *string,
cfg config.Dendrite, cfg *config.Dendrite,
queryAPI api.RoomserverQueryAPI, queryAPI api.RoomserverQueryAPI,
producer *producers.RoomserverProducer, producer *producers.RoomserverProducer,
txnCache *transactions.Cache, txnCache *transactions.Cache,
) util.JSONResponse { ) util.JSONResponse {
verReq := api.QueryRoomVersionForRoomRequest{RoomID: roomID}
verRes := api.QueryRoomVersionForRoomResponse{}
if err := queryAPI.QueryRoomVersionForRoom(req.Context(), &verReq, &verRes); err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.UnsupportedRoomVersion(err.Error()),
}
}
if txnID != nil { if txnID != nil {
// Try to fetch response from transactionsCache // Try to fetch response from transactionsCache
if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID); ok { if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID); ok {
@ -71,11 +81,22 @@ func SendEvent(
// pass the new event to the roomserver and receive the correct event ID // pass the new event to the roomserver and receive the correct event ID
// event ID in case of duplicate transaction is discarded // event ID in case of duplicate transaction is discarded
eventID, err := producer.SendEvents( eventID, err := producer.SendEvents(
req.Context(), []gomatrixserverlib.Event{*e}, cfg.Matrix.ServerName, txnAndSessionID, req.Context(),
[]gomatrixserverlib.HeaderedEvent{
e.Headered(verRes.RoomVersion),
},
cfg.Matrix.ServerName,
txnAndSessionID,
) )
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("producer.SendEvents failed")
return jsonerror.InternalServerError()
} }
util.GetLogger(req.Context()).WithFields(logrus.Fields{
"event_id": eventID,
"room_id": roomID,
"room_version": verRes.RoomVersion,
}).Info("Sent event to roomserver")
res := util.JSONResponse{ res := util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
@ -93,7 +114,7 @@ func generateSendEvent(
req *http.Request, req *http.Request,
device *authtypes.Device, device *authtypes.Device,
roomID, eventType string, stateKey *string, roomID, eventType string, stateKey *string,
cfg config.Dendrite, cfg *config.Dendrite,
queryAPI api.RoomserverQueryAPI, queryAPI api.RoomserverQueryAPI,
) (*gomatrixserverlib.Event, *util.JSONResponse) { ) (*gomatrixserverlib.Event, *util.JSONResponse) {
// parse the incoming http request // parse the incoming http request
@ -121,7 +142,8 @@ func generateSendEvent(
} }
err = builder.SetContent(r) err = builder.SetContent(r)
if err != nil { if err != nil {
resErr := httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("builder.SetContent failed")
resErr := jsonerror.InternalServerError()
return nil, &resErr return nil, &resErr
} }
@ -133,14 +155,15 @@ func generateSendEvent(
JSON: jsonerror.NotFound("Room does not exist"), JSON: jsonerror.NotFound("Room does not exist"),
} }
} else if err != nil { } else if err != nil {
resErr := httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("common.BuildEvent failed")
resErr := jsonerror.InternalServerError()
return nil, &resErr return nil, &resErr
} }
// check to see if this user can perform this operation // check to see if this user can perform this operation
stateEvents := make([]*gomatrixserverlib.Event, len(queryRes.StateEvents)) stateEvents := make([]*gomatrixserverlib.Event, len(queryRes.StateEvents))
for i := range queryRes.StateEvents { for i := range queryRes.StateEvents {
stateEvents[i] = &queryRes.StateEvents[i] stateEvents[i] = &queryRes.StateEvents[i].Event
} }
provider := gomatrixserverlib.NewAuthEvents(stateEvents) provider := gomatrixserverlib.NewAuthEvents(stateEvents)
if err = gomatrixserverlib.Allowed(*e, &provider); err != nil { if err = gomatrixserverlib.Allowed(*e, &provider); err != nil {

View file

@ -34,8 +34,8 @@ type typingContentJSON struct {
// sends the typing events to client API typingProducer // sends the typing events to client API typingProducer
func SendTyping( func SendTyping(
req *http.Request, device *authtypes.Device, roomID string, req *http.Request, device *authtypes.Device, roomID string,
userID string, accountDB *accounts.Database, userID string, accountDB accounts.Database,
typingProducer *producers.TypingServerProducer, eduProducer *producers.EDUServerProducer,
) util.JSONResponse { ) util.JSONResponse {
if device.UserID != userID { if device.UserID != userID {
return util.JSONResponse{ return util.JSONResponse{
@ -46,7 +46,8 @@ func SendTyping(
localpart, err := userutil.ParseUsernameParam(userID, nil) localpart, err := userutil.ParseUsernameParam(userID, nil)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("userutil.ParseUsernameParam failed")
return jsonerror.InternalServerError()
} }
// Verify that the user is a member of this room // Verify that the user is a member of this room
@ -57,7 +58,8 @@ func SendTyping(
JSON: jsonerror.Forbidden("User not in this room"), JSON: jsonerror.Forbidden("User not in this room"),
} }
} else if err != nil { } else if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("accountDB.GetMembershipInRoomByLocalPart failed")
return jsonerror.InternalServerError()
} }
// parse the incoming http request // parse the incoming http request
@ -67,10 +69,11 @@ func SendTyping(
return *resErr return *resErr
} }
if err = typingProducer.Send( if err = eduProducer.SendTyping(
req.Context(), userID, roomID, r.Typing, r.Timeout, req.Context(), userID, roomID, r.Typing, r.Timeout,
); err != nil { ); err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("eduProducer.Send failed")
return jsonerror.InternalServerError()
} }
return util.JSONResponse{ return util.JSONResponse{

View file

@ -15,14 +15,13 @@
package routing package routing
import ( import (
"context"
"encoding/json" "encoding/json"
"net/http" "net/http"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -41,28 +40,39 @@ type stateEventInStateResp struct {
// TODO: Check if the user is in the room. If not, check if the room's history // TODO: Check if the user is in the room. If not, check if the room's history
// is publicly visible. Current behaviour is returning an empty array if the // is publicly visible. Current behaviour is returning an empty array if the
// user cannot see the room's history. // user cannot see the room's history.
func OnIncomingStateRequest(req *http.Request, db storage.Database, roomID string) util.JSONResponse { func OnIncomingStateRequest(ctx context.Context, queryAPI api.RoomserverQueryAPI, roomID string) util.JSONResponse {
// TODO(#287): Auth request and handle the case where the user has left (where // TODO(#287): Auth request and handle the case where the user has left (where
// we should return the state at the poin they left) // we should return the state at the poin they left)
stateReq := api.QueryLatestEventsAndStateRequest{
RoomID: roomID,
}
stateRes := api.QueryLatestEventsAndStateResponse{}
stateFilterPart := gomatrix.DefaultFilterPart() if err := queryAPI.QueryLatestEventsAndState(ctx, &stateReq, &stateRes); err != nil {
// TODO: stateFilterPart should not limit the number of state events (or only limits abusive number of events) util.GetLogger(ctx).WithError(err).Error("queryAPI.QueryLatestEventsAndState failed")
return jsonerror.InternalServerError()
}
stateEvents, err := db.GetStateEventsForRoom(req.Context(), roomID, &stateFilterPart) if len(stateRes.StateEvents) == 0 {
if err != nil { return util.JSONResponse{
return httputil.LogThenError(req, err) Code: http.StatusNotFound,
JSON: jsonerror.NotFound("cannot find state"),
}
} }
resp := []stateEventInStateResp{} resp := []stateEventInStateResp{}
// Fill the prev_content and replaces_state keys if necessary // Fill the prev_content and replaces_state keys if necessary
for _, event := range stateEvents { for _, event := range stateRes.StateEvents {
stateEvent := stateEventInStateResp{ stateEvent := stateEventInStateResp{
ClientEvent: gomatrixserverlib.ToClientEvent(event, gomatrixserverlib.FormatAll), ClientEvent: gomatrixserverlib.HeaderedToClientEvents(
[]gomatrixserverlib.HeaderedEvent{event}, gomatrixserverlib.FormatAll,
)[0],
} }
var prevEventRef types.PrevEventRef var prevEventRef types.PrevEventRef
if len(event.Unsigned()) > 0 { if len(event.Unsigned()) > 0 {
if err := json.Unmarshal(event.Unsigned(), &prevEventRef); err != nil { if err := json.Unmarshal(event.Unsigned(), &prevEventRef); err != nil {
return httputil.LogThenError(req, err) util.GetLogger(ctx).WithError(err).Error("json.Unmarshal failed")
return jsonerror.InternalServerError()
} }
// Fills the previous state event ID if the state event replaces another // Fills the previous state event ID if the state event replaces another
// state event // state event
@ -88,23 +98,32 @@ func OnIncomingStateRequest(req *http.Request, db storage.Database, roomID strin
// /rooms/{roomID}/state/{type}/{statekey} request. It will look in current // /rooms/{roomID}/state/{type}/{statekey} request. It will look in current
// state to see if there is an event with that type and state key, if there // state to see if there is an event with that type and state key, if there
// is then (by default) we return the content, otherwise a 404. // is then (by default) we return the content, otherwise a 404.
func OnIncomingStateTypeRequest(req *http.Request, db storage.Database, roomID string, evType, stateKey string) util.JSONResponse { func OnIncomingStateTypeRequest(ctx context.Context, queryAPI api.RoomserverQueryAPI, roomID string, evType, stateKey string) util.JSONResponse {
// TODO(#287): Auth request and handle the case where the user has left (where // TODO(#287): Auth request and handle the case where the user has left (where
// we should return the state at the poin they left) // we should return the state at the poin they left)
util.GetLogger(ctx).WithFields(log.Fields{
logger := util.GetLogger(req.Context())
logger.WithFields(log.Fields{
"roomID": roomID, "roomID": roomID,
"evType": evType, "evType": evType,
"stateKey": stateKey, "stateKey": stateKey,
}).Info("Fetching state") }).Info("Fetching state")
event, err := db.GetStateEvent(req.Context(), roomID, evType, stateKey) stateReq := api.QueryLatestEventsAndStateRequest{
if err != nil { RoomID: roomID,
return httputil.LogThenError(req, err) StateToFetch: []gomatrixserverlib.StateKeyTuple{
gomatrixserverlib.StateKeyTuple{
EventType: evType,
StateKey: stateKey,
},
},
}
stateRes := api.QueryLatestEventsAndStateResponse{}
if err := queryAPI.QueryLatestEventsAndState(ctx, &stateReq, &stateRes); err != nil {
util.GetLogger(ctx).WithError(err).Error("queryAPI.QueryLatestEventsAndState failed")
return jsonerror.InternalServerError()
} }
if event == nil { if len(stateRes.StateEvents) == 0 {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusNotFound, Code: http.StatusNotFound,
JSON: jsonerror.NotFound("cannot find state"), JSON: jsonerror.NotFound("cannot find state"),
@ -112,7 +131,7 @@ func OnIncomingStateTypeRequest(req *http.Request, db storage.Database, roomID s
} }
stateEvent := stateEventInStateResp{ stateEvent := stateEventInStateResp{
ClientEvent: gomatrixserverlib.ToClientEvent(*event, gomatrixserverlib.FormatAll), ClientEvent: gomatrixserverlib.HeaderedToClientEvent(stateRes.StateEvents[0], gomatrixserverlib.FormatAll),
} }
return util.JSONResponse{ return util.JSONResponse{

View file

@ -39,7 +39,7 @@ type threePIDsResponse struct {
// RequestEmailToken implements: // RequestEmailToken implements:
// POST /account/3pid/email/requestToken // POST /account/3pid/email/requestToken
// POST /register/email/requestToken // POST /register/email/requestToken
func RequestEmailToken(req *http.Request, accountDB *accounts.Database, cfg config.Dendrite) util.JSONResponse { func RequestEmailToken(req *http.Request, accountDB accounts.Database, cfg *config.Dendrite) util.JSONResponse {
var body threepid.EmailAssociationRequest var body threepid.EmailAssociationRequest
if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil { if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil {
return *reqErr return *reqErr
@ -51,7 +51,8 @@ func RequestEmailToken(req *http.Request, accountDB *accounts.Database, cfg conf
// Check if the 3PID is already in use locally // Check if the 3PID is already in use locally
localpart, err := accountDB.GetLocalpartForThreePID(req.Context(), body.Email, "email") localpart, err := accountDB.GetLocalpartForThreePID(req.Context(), body.Email, "email")
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("accountDB.GetLocalpartForThreePID failed")
return jsonerror.InternalServerError()
} }
if len(localpart) > 0 { if len(localpart) > 0 {
@ -71,7 +72,8 @@ func RequestEmailToken(req *http.Request, accountDB *accounts.Database, cfg conf
JSON: jsonerror.NotTrusted(body.IDServer), JSON: jsonerror.NotTrusted(body.IDServer),
} }
} else if err != nil { } else if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("threepid.CreateSession failed")
return jsonerror.InternalServerError()
} }
return util.JSONResponse{ return util.JSONResponse{
@ -82,8 +84,8 @@ func RequestEmailToken(req *http.Request, accountDB *accounts.Database, cfg conf
// CheckAndSave3PIDAssociation implements POST /account/3pid // CheckAndSave3PIDAssociation implements POST /account/3pid
func CheckAndSave3PIDAssociation( func CheckAndSave3PIDAssociation(
req *http.Request, accountDB *accounts.Database, device *authtypes.Device, req *http.Request, accountDB accounts.Database, device *authtypes.Device,
cfg config.Dendrite, cfg *config.Dendrite,
) util.JSONResponse { ) util.JSONResponse {
var body threepid.EmailAssociationCheckRequest var body threepid.EmailAssociationCheckRequest
if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil { if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil {
@ -98,7 +100,8 @@ func CheckAndSave3PIDAssociation(
JSON: jsonerror.NotTrusted(body.Creds.IDServer), JSON: jsonerror.NotTrusted(body.Creds.IDServer),
} }
} else if err != nil { } else if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("threepid.CheckAssociation failed")
return jsonerror.InternalServerError()
} }
if !verified { if !verified {
@ -120,18 +123,21 @@ func CheckAndSave3PIDAssociation(
JSON: jsonerror.NotTrusted(body.Creds.IDServer), JSON: jsonerror.NotTrusted(body.Creds.IDServer),
} }
} else if err != nil { } else if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("threepid.PublishAssociation failed")
return jsonerror.InternalServerError()
} }
} }
// Save the association in the database // Save the association in the database
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError()
} }
if err = accountDB.SaveThreePIDAssociation(req.Context(), address, localpart, medium); err != nil { if err = accountDB.SaveThreePIDAssociation(req.Context(), address, localpart, medium); err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("accountsDB.SaveThreePIDAssociation failed")
return jsonerror.InternalServerError()
} }
return util.JSONResponse{ return util.JSONResponse{
@ -142,16 +148,18 @@ func CheckAndSave3PIDAssociation(
// GetAssociated3PIDs implements GET /account/3pid // GetAssociated3PIDs implements GET /account/3pid
func GetAssociated3PIDs( func GetAssociated3PIDs(
req *http.Request, accountDB *accounts.Database, device *authtypes.Device, req *http.Request, accountDB accounts.Database, device *authtypes.Device,
) util.JSONResponse { ) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError()
} }
threepids, err := accountDB.GetThreePIDsForLocalpart(req.Context(), localpart) threepids, err := accountDB.GetThreePIDsForLocalpart(req.Context(), localpart)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("accountDB.GetThreePIDsForLocalpart failed")
return jsonerror.InternalServerError()
} }
return util.JSONResponse{ return util.JSONResponse{
@ -161,14 +169,15 @@ func GetAssociated3PIDs(
} }
// Forget3PID implements POST /account/3pid/delete // Forget3PID implements POST /account/3pid/delete
func Forget3PID(req *http.Request, accountDB *accounts.Database) util.JSONResponse { func Forget3PID(req *http.Request, accountDB accounts.Database) util.JSONResponse {
var body authtypes.ThreePID var body authtypes.ThreePID
if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil { if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil {
return *reqErr return *reqErr
} }
if err := accountDB.RemoveThreePIDAssociation(req.Context(), body.Address, body.Medium); err != nil { if err := accountDB.RemoveThreePIDAssociation(req.Context(), body.Address, body.Medium); err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("accountDB.RemoveThreePIDAssociation failed")
return jsonerror.InternalServerError()
} }
return util.JSONResponse{ return util.JSONResponse{

View file

@ -23,7 +23,7 @@ import (
"time" "time"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/common/config" "github.com/matrix-org/dendrite/common/config"
"github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrix"
"github.com/matrix-org/util" "github.com/matrix-org/util"
@ -31,7 +31,7 @@ import (
// RequestTurnServer implements: // RequestTurnServer implements:
// GET /voip/turnServer // GET /voip/turnServer
func RequestTurnServer(req *http.Request, device *authtypes.Device, cfg config.Dendrite) util.JSONResponse { func RequestTurnServer(req *http.Request, device *authtypes.Device, cfg *config.Dendrite) util.JSONResponse {
turnConfig := cfg.TURN turnConfig := cfg.TURN
// TODO Guest Support // TODO Guest Support
@ -56,7 +56,8 @@ func RequestTurnServer(req *http.Request, device *authtypes.Device, cfg config.D
_, err := mac.Write([]byte(resp.Username)) _, err := mac.Write([]byte(resp.Username))
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("mac.Write failed")
return jsonerror.InternalServerError()
} }
resp.Username = fmt.Sprintf("%d:%s", expiry, device.UserID) resp.Username = fmt.Sprintf("%d:%s", expiry, device.UserID)

View file

@ -86,8 +86,8 @@ var (
// can be emitted. // can be emitted.
func CheckAndProcessInvite( func CheckAndProcessInvite(
ctx context.Context, ctx context.Context,
device *authtypes.Device, body *MembershipRequest, cfg config.Dendrite, device *authtypes.Device, body *MembershipRequest, cfg *config.Dendrite,
queryAPI api.RoomserverQueryAPI, db *accounts.Database, queryAPI api.RoomserverQueryAPI, db accounts.Database,
producer *producers.RoomserverProducer, membership string, roomID string, producer *producers.RoomserverProducer, membership string, roomID string,
evTime time.Time, evTime time.Time,
) (inviteStoredOnIDServer bool, err error) { ) (inviteStoredOnIDServer bool, err error) {
@ -137,7 +137,7 @@ func CheckAndProcessInvite(
// Returns an error if a check or a request failed. // Returns an error if a check or a request failed.
func queryIDServer( func queryIDServer(
ctx context.Context, ctx context.Context,
db *accounts.Database, cfg config.Dendrite, device *authtypes.Device, db accounts.Database, cfg *config.Dendrite, device *authtypes.Device,
body *MembershipRequest, roomID string, body *MembershipRequest, roomID string,
) (lookupRes *idServerLookupResponse, storeInviteRes *idServerStoreInviteResponse, err error) { ) (lookupRes *idServerLookupResponse, storeInviteRes *idServerStoreInviteResponse, err error) {
if err = isTrusted(body.IDServer, cfg); err != nil { if err = isTrusted(body.IDServer, cfg); err != nil {
@ -206,7 +206,7 @@ func queryIDServerLookup(ctx context.Context, body *MembershipRequest) (*idServe
// Returns an error if the request failed to send or if the response couldn't be parsed. // Returns an error if the request failed to send or if the response couldn't be parsed.
func queryIDServerStoreInvite( func queryIDServerStoreInvite(
ctx context.Context, ctx context.Context,
db *accounts.Database, cfg config.Dendrite, device *authtypes.Device, db accounts.Database, cfg *config.Dendrite, device *authtypes.Device,
body *MembershipRequest, roomID string, body *MembershipRequest, roomID string,
) (*idServerStoreInviteResponse, error) { ) (*idServerStoreInviteResponse, error) {
// Retrieve the sender's profile to get their display name // Retrieve the sender's profile to get their display name
@ -330,7 +330,7 @@ func checkIDServerSignatures(
func emit3PIDInviteEvent( func emit3PIDInviteEvent(
ctx context.Context, ctx context.Context,
body *MembershipRequest, res *idServerStoreInviteResponse, body *MembershipRequest, res *idServerStoreInviteResponse,
device *authtypes.Device, roomID string, cfg config.Dendrite, device *authtypes.Device, roomID string, cfg *config.Dendrite,
queryAPI api.RoomserverQueryAPI, producer *producers.RoomserverProducer, queryAPI api.RoomserverQueryAPI, producer *producers.RoomserverProducer,
evTime time.Time, evTime time.Time,
) error { ) error {
@ -353,12 +353,19 @@ func emit3PIDInviteEvent(
return err return err
} }
var queryRes *api.QueryLatestEventsAndStateResponse queryRes := api.QueryLatestEventsAndStateResponse{}
event, err := common.BuildEvent(ctx, builder, cfg, evTime, queryAPI, queryRes) event, err := common.BuildEvent(ctx, builder, cfg, evTime, queryAPI, &queryRes)
if err != nil { if err != nil {
return err return err
} }
_, err = producer.SendEvents(ctx, []gomatrixserverlib.Event{*event}, cfg.Matrix.ServerName, nil) _, err = producer.SendEvents(
ctx,
[]gomatrixserverlib.HeaderedEvent{
(*event).Headered(queryRes.RoomVersion),
},
cfg.Matrix.ServerName,
nil,
)
return err return err
} }

View file

@ -53,7 +53,7 @@ type Credentials struct {
// Returns an error if there was a problem sending the request or decoding the // Returns an error if there was a problem sending the request or decoding the
// response, or if the identity server responded with a non-OK status. // response, or if the identity server responded with a non-OK status.
func CreateSession( func CreateSession(
ctx context.Context, req EmailAssociationRequest, cfg config.Dendrite, ctx context.Context, req EmailAssociationRequest, cfg *config.Dendrite,
) (string, error) { ) (string, error) {
if err := isTrusted(req.IDServer, cfg); err != nil { if err := isTrusted(req.IDServer, cfg); err != nil {
return "", err return "", err
@ -101,7 +101,7 @@ func CreateSession(
// Returns an error if there was a problem sending the request or decoding the // Returns an error if there was a problem sending the request or decoding the
// response, or if the identity server responded with a non-OK status. // response, or if the identity server responded with a non-OK status.
func CheckAssociation( func CheckAssociation(
ctx context.Context, creds Credentials, cfg config.Dendrite, ctx context.Context, creds Credentials, cfg *config.Dendrite,
) (bool, string, string, error) { ) (bool, string, string, error) {
if err := isTrusted(creds.IDServer, cfg); err != nil { if err := isTrusted(creds.IDServer, cfg); err != nil {
return false, "", "", err return false, "", "", err
@ -142,7 +142,7 @@ func CheckAssociation(
// identifier and a Matrix ID. // identifier and a Matrix ID.
// Returns an error if there was a problem sending the request or decoding the // Returns an error if there was a problem sending the request or decoding the
// response, or if the identity server responded with a non-OK status. // response, or if the identity server responded with a non-OK status.
func PublishAssociation(creds Credentials, userID string, cfg config.Dendrite) error { func PublishAssociation(creds Credentials, userID string, cfg *config.Dendrite) error {
if err := isTrusted(creds.IDServer, cfg); err != nil { if err := isTrusted(creds.IDServer, cfg); err != nil {
return err return err
} }
@ -177,7 +177,7 @@ func PublishAssociation(creds Credentials, userID string, cfg config.Dendrite) e
// isTrusted checks if a given identity server is part of the list of trusted // isTrusted checks if a given identity server is part of the list of trusted
// identity servers in the configuration file. // identity servers in the configuration file.
// Returns an error if the server isn't trusted. // Returns an error if the server isn't trusted.
func isTrusted(idServer string, cfg config.Dendrite) error { func isTrusted(idServer string, cfg *config.Dendrite) error {
for _, server := range cfg.Matrix.TrustedIDServers { for _, server := range cfg.Matrix.TrustedIDServers {
if idServer == server { if idServer == server {
return nil return nil

View file

@ -103,12 +103,14 @@ func main() {
// Build an event and write the event to the output. // Build an event and write the event to the output.
func buildAndOutput() gomatrixserverlib.EventReference { func buildAndOutput() gomatrixserverlib.EventReference {
eventID++ eventID++
id := fmt.Sprintf("$%d:%s", eventID, *serverName)
now = time.Unix(0, 0) now = time.Unix(0, 0)
name := gomatrixserverlib.ServerName(*serverName) name := gomatrixserverlib.ServerName(*serverName)
key := gomatrixserverlib.KeyID(*keyID) key := gomatrixserverlib.KeyID(*keyID)
event, err := b.Build(id, now, name, key, privateKey) event, err := b.Build(
now, name, key, privateKey,
gomatrixserverlib.RoomVersionV1,
)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -125,9 +127,9 @@ func writeEvent(event gomatrixserverlib.Event) {
if *format == "InputRoomEvent" { if *format == "InputRoomEvent" {
var ire api.InputRoomEvent var ire api.InputRoomEvent
ire.Kind = api.KindNew ire.Kind = api.KindNew
ire.Event = event ire.Event = event.Headered(gomatrixserverlib.RoomVersionV1)
authEventIDs := []string{} authEventIDs := []string{}
for _, ref := range b.AuthEvents { for _, ref := range b.AuthEvents.([]gomatrixserverlib.EventReference) {
authEventIDs = append(authEventIDs, ref.EventID) authEventIDs = append(authEventIDs, ref.EventID)
} }
ire.AuthEventIDs = authEventIDs ire.AuthEventIDs = authEventIDs

View file

@ -19,8 +19,8 @@ import (
"github.com/matrix-org/dendrite/common/basecomponent" "github.com/matrix-org/dendrite/common/basecomponent"
"github.com/matrix-org/dendrite/common/keydb" "github.com/matrix-org/dendrite/common/keydb"
"github.com/matrix-org/dendrite/common/transactions" "github.com/matrix-org/dendrite/common/transactions"
"github.com/matrix-org/dendrite/typingserver" "github.com/matrix-org/dendrite/eduserver"
"github.com/matrix-org/dendrite/typingserver/cache" "github.com/matrix-org/dendrite/eduserver/cache"
) )
func main() { func main() {
@ -33,16 +33,16 @@ func main() {
deviceDB := base.CreateDeviceDB() deviceDB := base.CreateDeviceDB()
keyDB := base.CreateKeyDB() keyDB := base.CreateKeyDB()
federation := base.CreateFederationClient() federation := base.CreateFederationClient()
keyRing := keydb.CreateKeyRing(federation.Client, keyDB) keyRing := keydb.CreateKeyRing(federation.Client, keyDB, cfg.Matrix.KeyPerspectives)
asQuery := base.CreateHTTPAppServiceAPIs() asQuery := base.CreateHTTPAppServiceAPIs()
alias, input, query := base.CreateHTTPRoomserverAPIs() alias, input, query := base.CreateHTTPRoomserverAPIs()
fedSenderAPI := base.CreateHTTPFederationSenderAPIs() fedSenderAPI := base.CreateHTTPFederationSenderAPIs()
typingInputAPI := typingserver.SetupTypingServerComponent(base, cache.NewTypingCache()) eduInputAPI := eduserver.SetupEDUServerComponent(base, cache.New())
clientapi.SetupClientAPIComponent( clientapi.SetupClientAPIComponent(
base, deviceDB, accountDB, federation, &keyRing, base, deviceDB, accountDB, federation, &keyRing,
alias, input, query, typingInputAPI, asQuery, transactions.New(), fedSenderAPI, alias, input, query, eduInputAPI, asQuery, transactions.New(), fedSenderAPI,
) )
base.SetupAndServeHTTP(string(base.Cfg.Bind.ClientAPI), string(base.Cfg.Listen.ClientAPI)) base.SetupAndServeHTTP(string(base.Cfg.Bind.ClientAPI), string(base.Cfg.Listen.ClientAPI))

View file

@ -0,0 +1,203 @@
// Copyright 2017 Vector Creations Ltd
// Copyright 2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main
import (
"crypto/ed25519"
"flag"
"fmt"
"io/ioutil"
"net/http"
"os"
"time"
gostream "github.com/libp2p/go-libp2p-gostream"
p2phttp "github.com/libp2p/go-libp2p-http"
p2pdisc "github.com/libp2p/go-libp2p/p2p/discovery"
"github.com/matrix-org/dendrite/appservice"
"github.com/matrix-org/dendrite/clientapi"
"github.com/matrix-org/dendrite/clientapi/producers"
"github.com/matrix-org/dendrite/cmd/dendrite-demo-libp2p/storage"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/common/config"
"github.com/matrix-org/dendrite/common/keydb"
"github.com/matrix-org/dendrite/common/transactions"
"github.com/matrix-org/dendrite/eduserver"
"github.com/matrix-org/dendrite/federationapi"
"github.com/matrix-org/dendrite/federationsender"
"github.com/matrix-org/dendrite/mediaapi"
"github.com/matrix-org/dendrite/publicroomsapi"
"github.com/matrix-org/dendrite/roomserver"
"github.com/matrix-org/dendrite/syncapi"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/eduserver/cache"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/sirupsen/logrus"
)
func createKeyDB(
base *P2PDendrite,
) keydb.Database {
db, err := keydb.NewDatabase(
string(base.Base.Cfg.Database.ServerKey),
base.Base.Cfg.Matrix.ServerName,
base.Base.Cfg.Matrix.PrivateKey.Public().(ed25519.PublicKey),
base.Base.Cfg.Matrix.KeyID,
)
if err != nil {
logrus.WithError(err).Panicf("failed to connect to keys db")
}
mdns := mDNSListener{
host: base.LibP2P,
keydb: db,
}
serv, err := p2pdisc.NewMdnsService(
base.LibP2PContext,
base.LibP2P,
time.Second*10,
"_matrix-dendrite-p2p._tcp",
)
if err != nil {
panic(err)
}
serv.RegisterNotifee(&mdns)
return db
}
func createFederationClient(
base *P2PDendrite,
) *gomatrixserverlib.FederationClient {
fmt.Println("Running in libp2p federation mode")
fmt.Println("Warning: Federation with non-libp2p homeservers will not work in this mode yet!")
tr := &http.Transport{}
tr.RegisterProtocol(
"matrix",
p2phttp.NewTransport(base.LibP2P, p2phttp.ProtocolOption("/matrix")),
)
return gomatrixserverlib.NewFederationClientWithTransport(
base.Base.Cfg.Matrix.ServerName, base.Base.Cfg.Matrix.KeyID, base.Base.Cfg.Matrix.PrivateKey, tr,
)
}
func main() {
instanceName := flag.String("name", "dendrite-p2p", "the name of this P2P demo instance")
instancePort := flag.Int("port", 8080, "the port that the client API will listen on")
flag.Parse()
filename := fmt.Sprintf("%s-private.key", *instanceName)
_, err := os.Stat(filename)
var privKey ed25519.PrivateKey
if os.IsNotExist(err) {
_, privKey, _ = ed25519.GenerateKey(nil)
if err = ioutil.WriteFile(filename, privKey, 0600); err != nil {
fmt.Printf("Couldn't write private key to file '%s': %s\n", filename, err)
}
} else {
privKey, err = ioutil.ReadFile(filename)
if err != nil {
fmt.Printf("Couldn't read private key from file '%s': %s\n", filename, err)
_, privKey, _ = ed25519.GenerateKey(nil)
}
}
cfg := config.Dendrite{}
cfg.Matrix.ServerName = "p2p"
cfg.Matrix.PrivateKey = privKey
cfg.Matrix.KeyID = gomatrixserverlib.KeyID(fmt.Sprintf("ed25519:%s", *instanceName))
cfg.Kafka.UseNaffka = true
cfg.Kafka.Topics.OutputRoomEvent = "roomserverOutput"
cfg.Kafka.Topics.OutputClientData = "clientapiOutput"
cfg.Kafka.Topics.OutputTypingEvent = "typingServerOutput"
cfg.Kafka.Topics.UserUpdates = "userUpdates"
cfg.Database.Account = config.DataSource(fmt.Sprintf("file:%s-account.db", *instanceName))
cfg.Database.Device = config.DataSource(fmt.Sprintf("file:%s-device.db", *instanceName))
cfg.Database.MediaAPI = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", *instanceName))
cfg.Database.SyncAPI = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", *instanceName))
cfg.Database.RoomServer = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName))
cfg.Database.ServerKey = config.DataSource(fmt.Sprintf("file:%s-serverkey.db", *instanceName))
cfg.Database.FederationSender = config.DataSource(fmt.Sprintf("file:%s-federationsender.db", *instanceName))
cfg.Database.AppService = config.DataSource(fmt.Sprintf("file:%s-appservice.db", *instanceName))
cfg.Database.PublicRoomsAPI = config.DataSource(fmt.Sprintf("file:%s-publicroomsa.db", *instanceName))
cfg.Database.Naffka = config.DataSource(fmt.Sprintf("file:%s-naffka.db", *instanceName))
if err = cfg.Derive(); err != nil {
panic(err)
}
base := NewP2PDendrite(&cfg, "Monolith")
defer base.Base.Close() // nolint: errcheck
accountDB := base.Base.CreateAccountsDB()
deviceDB := base.Base.CreateDeviceDB()
keyDB := createKeyDB(base)
federation := createFederationClient(base)
keyRing := keydb.CreateKeyRing(federation.Client, keyDB, cfg.Matrix.KeyPerspectives)
alias, input, query := roomserver.SetupRoomServerComponent(&base.Base)
eduInputAPI := eduserver.SetupEDUServerComponent(&base.Base, cache.New())
asQuery := appservice.SetupAppServiceAPIComponent(
&base.Base, accountDB, deviceDB, federation, alias, query, transactions.New(),
)
fedSenderAPI := federationsender.SetupFederationSenderComponent(&base.Base, federation, query)
clientapi.SetupClientAPIComponent(
&base.Base, deviceDB, accountDB,
federation, &keyRing, alias, input, query,
eduInputAPI, asQuery, transactions.New(), fedSenderAPI,
)
eduProducer := producers.NewEDUServerProducer(eduInputAPI)
federationapi.SetupFederationAPIComponent(&base.Base, accountDB, deviceDB, federation, &keyRing, alias, input, query, asQuery, fedSenderAPI, eduProducer)
mediaapi.SetupMediaAPIComponent(&base.Base, deviceDB)
publicRoomsDB, err := storage.NewPublicRoomsServerDatabaseWithPubSub(string(base.Base.Cfg.Database.PublicRoomsAPI), base.LibP2PPubsub)
if err != nil {
logrus.WithError(err).Panicf("failed to connect to public rooms db")
}
publicroomsapi.SetupPublicRoomsAPIComponent(&base.Base, deviceDB, publicRoomsDB, query, federation, nil) // Check this later
syncapi.SetupSyncAPIComponent(&base.Base, deviceDB, accountDB, query, federation, &cfg)
httpHandler := common.WrapHandlerInCORS(base.Base.APIMux)
// Set up the API endpoints we handle. /metrics is for prometheus, and is
// not wrapped by CORS, while everything else is
http.Handle("/metrics", promhttp.Handler())
http.Handle("/", httpHandler)
// Expose the matrix APIs directly rather than putting them under a /api path.
go func() {
httpBindAddr := fmt.Sprintf(":%d", *instancePort)
logrus.Info("Listening on ", httpBindAddr)
logrus.Fatal(http.ListenAndServe(httpBindAddr, nil))
}()
// Expose the matrix APIs also via libp2p
if base.LibP2P != nil {
go func() {
logrus.Info("Listening on libp2p host ID ", base.LibP2P.ID())
listener, err := gostream.Listen(base.LibP2P, "/matrix")
if err != nil {
panic(err)
}
defer func() {
logrus.Fatal(listener.Close())
}()
logrus.Fatal(http.Serve(listener, nil))
}()
}
// We want to block forever to let the HTTP and HTTPS handler serve the APIs
select {}
}

View file

@ -0,0 +1,63 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main
import (
"context"
"fmt"
"math"
"github.com/libp2p/go-libp2p-core/host"
"github.com/libp2p/go-libp2p-core/peer"
"github.com/matrix-org/dendrite/common/keydb"
"github.com/matrix-org/gomatrixserverlib"
)
type mDNSListener struct {
keydb keydb.Database
host host.Host
}
func (n *mDNSListener) HandlePeerFound(p peer.AddrInfo) {
if err := n.host.Connect(context.Background(), p); err != nil {
fmt.Println("Error adding peer", p.ID.String(), "via mDNS:", err)
}
if pubkey, err := p.ID.ExtractPublicKey(); err == nil {
raw, _ := pubkey.Raw()
if err := n.keydb.StoreKeys(
context.Background(),
map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult{
{
ServerName: gomatrixserverlib.ServerName(p.ID.String()),
KeyID: "ed25519:p2pdemo",
}: {
VerifyKey: gomatrixserverlib.VerifyKey{
Key: gomatrixserverlib.Base64String(raw),
},
ValidUntilTS: math.MaxUint64 >> 1,
ExpiredTS: gomatrixserverlib.PublicKeyNotExpired,
},
},
); err != nil {
fmt.Println("Failed to store keys:", err)
}
}
fmt.Println("Discovered", len(n.host.Peerstore().Peers())-1, "other libp2p peer(s):")
for _, peer := range n.host.Peerstore().Peers() {
if peer != n.host.ID() {
fmt.Println("-", peer)
}
}
}

View file

@ -0,0 +1,126 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main
import (
"context"
"fmt"
"errors"
pstore "github.com/libp2p/go-libp2p-core/peerstore"
record "github.com/libp2p/go-libp2p-record"
"github.com/matrix-org/dendrite/common/basecomponent"
"github.com/libp2p/go-libp2p"
circuit "github.com/libp2p/go-libp2p-circuit"
crypto "github.com/libp2p/go-libp2p-core/crypto"
routing "github.com/libp2p/go-libp2p-core/routing"
host "github.com/libp2p/go-libp2p-core/host"
dht "github.com/libp2p/go-libp2p-kad-dht"
pubsub "github.com/libp2p/go-libp2p-pubsub"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/common/config"
)
// P2PDendrite is a Peer-to-Peer variant of BaseDendrite.
type P2PDendrite struct {
Base basecomponent.BaseDendrite
// Store our libp2p object so that we can make outgoing connections from it
// later
LibP2P host.Host
LibP2PContext context.Context
LibP2PCancel context.CancelFunc
LibP2PDHT *dht.IpfsDHT
LibP2PPubsub *pubsub.PubSub
}
// NewP2PDendrite creates a new instance to be used by a component.
// The componentName is used for logging purposes, and should be a friendly name
// of the component running, e.g. SyncAPI.
func NewP2PDendrite(cfg *config.Dendrite, componentName string) *P2PDendrite {
baseDendrite := basecomponent.NewBaseDendrite(cfg, componentName)
ctx, cancel := context.WithCancel(context.Background())
privKey, err := crypto.UnmarshalEd25519PrivateKey(cfg.Matrix.PrivateKey[:])
if err != nil {
panic(err)
}
//defaultIP6ListenAddr, _ := multiaddr.NewMultiaddr("/ip6/::/tcp/0")
var libp2pdht *dht.IpfsDHT
libp2p, err := libp2p.New(ctx,
libp2p.Identity(privKey),
libp2p.DefaultListenAddrs,
//libp2p.ListenAddrs(defaultIP6ListenAddr),
libp2p.DefaultTransports,
libp2p.Routing(func(h host.Host) (r routing.PeerRouting, err error) {
libp2pdht, err = dht.New(ctx, h)
if err != nil {
return nil, err
}
libp2pdht.Validator = libP2PValidator{}
r = libp2pdht
return
}),
libp2p.EnableAutoRelay(),
libp2p.EnableRelay(circuit.OptHop),
)
if err != nil {
panic(err)
}
libp2ppubsub, err := pubsub.NewFloodSub(context.Background(), libp2p, []pubsub.Option{
pubsub.WithMessageSigning(true),
}...)
if err != nil {
panic(err)
}
fmt.Println("Our public key:", privKey.GetPublic())
fmt.Println("Our node ID:", libp2p.ID())
fmt.Println("Our addresses:", libp2p.Addrs())
cfg.Matrix.ServerName = gomatrixserverlib.ServerName(libp2p.ID().String())
return &P2PDendrite{
Base: *baseDendrite,
LibP2P: libp2p,
LibP2PContext: ctx,
LibP2PCancel: cancel,
LibP2PDHT: libp2pdht,
LibP2PPubsub: libp2ppubsub,
}
}
type libP2PValidator struct {
KeyBook pstore.KeyBook
}
func (v libP2PValidator) Validate(key string, value []byte) error {
ns, _, err := record.SplitKey(key)
if err != nil || ns != "matrix" {
return errors.New("not Matrix path")
}
return nil
}
func (v libP2PValidator) Select(k string, vals [][]byte) (int, error) {
return 0, nil
}

View file

@ -0,0 +1,164 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgreswithdht
import (
"context"
"encoding/json"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/matrix-org/dendrite/publicroomsapi/storage/postgres"
"github.com/matrix-org/gomatrixserverlib"
dht "github.com/libp2p/go-libp2p-kad-dht"
)
const DHTInterval = time.Second * 10
// PublicRoomsServerDatabase represents a public rooms server database.
type PublicRoomsServerDatabase struct {
dht *dht.IpfsDHT
postgres.PublicRoomsServerDatabase
ourRoomsContext context.Context // our current value in the DHT
ourRoomsCancel context.CancelFunc // cancel when we want to expire our value
foundRooms map[string]gomatrixserverlib.PublicRoom // additional rooms we have learned about from the DHT
foundRoomsMutex sync.RWMutex // protects foundRooms
maintenanceTimer *time.Timer //
roomsAdvertised atomic.Value // stores int
roomsDiscovered atomic.Value // stores int
}
// NewPublicRoomsServerDatabase creates a new public rooms server database.
func NewPublicRoomsServerDatabase(dataSourceName string, dht *dht.IpfsDHT) (*PublicRoomsServerDatabase, error) {
pg, err := postgres.NewPublicRoomsServerDatabase(dataSourceName)
if err != nil {
return nil, err
}
provider := PublicRoomsServerDatabase{
dht: dht,
PublicRoomsServerDatabase: *pg,
}
go provider.ResetDHTMaintenance()
provider.roomsAdvertised.Store(0)
provider.roomsDiscovered.Store(0)
return &provider, nil
}
func (d *PublicRoomsServerDatabase) GetRoomVisibility(ctx context.Context, roomID string) (bool, error) {
return d.PublicRoomsServerDatabase.GetRoomVisibility(ctx, roomID)
}
func (d *PublicRoomsServerDatabase) SetRoomVisibility(ctx context.Context, visible bool, roomID string) error {
d.ResetDHTMaintenance()
return d.PublicRoomsServerDatabase.SetRoomVisibility(ctx, visible, roomID)
}
func (d *PublicRoomsServerDatabase) CountPublicRooms(ctx context.Context) (int64, error) {
count, err := d.PublicRoomsServerDatabase.CountPublicRooms(ctx)
if err != nil {
return 0, err
}
d.foundRoomsMutex.RLock()
defer d.foundRoomsMutex.RUnlock()
return count + int64(len(d.foundRooms)), nil
}
func (d *PublicRoomsServerDatabase) GetPublicRooms(ctx context.Context, offset int64, limit int16, filter string) ([]gomatrixserverlib.PublicRoom, error) {
realfilter := filter
if realfilter == "__local__" {
realfilter = ""
}
rooms, err := d.PublicRoomsServerDatabase.GetPublicRooms(ctx, offset, limit, realfilter)
if err != nil {
return []gomatrixserverlib.PublicRoom{}, err
}
if filter != "__local__" {
d.foundRoomsMutex.RLock()
defer d.foundRoomsMutex.RUnlock()
for _, room := range d.foundRooms {
rooms = append(rooms, room)
}
}
return rooms, nil
}
func (d *PublicRoomsServerDatabase) UpdateRoomFromEvents(ctx context.Context, eventsToAdd []gomatrixserverlib.Event, eventsToRemove []gomatrixserverlib.Event) error {
return d.PublicRoomsServerDatabase.UpdateRoomFromEvents(ctx, eventsToAdd, eventsToRemove)
}
func (d *PublicRoomsServerDatabase) UpdateRoomFromEvent(ctx context.Context, event gomatrixserverlib.Event) error {
return d.PublicRoomsServerDatabase.UpdateRoomFromEvent(ctx, event)
}
func (d *PublicRoomsServerDatabase) ResetDHTMaintenance() {
if d.maintenanceTimer != nil && !d.maintenanceTimer.Stop() {
<-d.maintenanceTimer.C
}
d.Interval()
}
func (d *PublicRoomsServerDatabase) Interval() {
if err := d.AdvertiseRoomsIntoDHT(); err != nil {
// fmt.Println("Failed to advertise room in DHT:", err)
}
if err := d.FindRoomsInDHT(); err != nil {
// fmt.Println("Failed to find rooms in DHT:", err)
}
fmt.Println("Found", d.roomsDiscovered.Load(), "room(s), advertised", d.roomsAdvertised.Load(), "room(s)")
d.maintenanceTimer = time.AfterFunc(DHTInterval, d.Interval)
}
func (d *PublicRoomsServerDatabase) AdvertiseRoomsIntoDHT() error {
dbCtx, dbCancel := context.WithTimeout(context.Background(), 3*time.Second)
_ = dbCancel
ourRooms, err := d.GetPublicRooms(dbCtx, 0, 1024, "__local__")
if err != nil {
return err
}
if j, err := json.Marshal(ourRooms); err == nil {
d.roomsAdvertised.Store(len(ourRooms))
d.ourRoomsContext, d.ourRoomsCancel = context.WithCancel(context.Background())
if err := d.dht.PutValue(d.ourRoomsContext, "/matrix/publicRooms", j); err != nil {
return err
}
}
return nil
}
func (d *PublicRoomsServerDatabase) FindRoomsInDHT() error {
d.foundRoomsMutex.Lock()
searchCtx, searchCancel := context.WithTimeout(context.Background(), 10*time.Second)
defer searchCancel()
defer d.foundRoomsMutex.Unlock()
results, err := d.dht.GetValues(searchCtx, "/matrix/publicRooms", 1024)
if err != nil {
return err
}
d.foundRooms = make(map[string]gomatrixserverlib.PublicRoom)
for _, result := range results {
var received []gomatrixserverlib.PublicRoom
if err := json.Unmarshal(result.Val, &received); err != nil {
return err
}
for _, room := range received {
d.foundRooms[room.RoomID] = room
}
}
d.roomsDiscovered.Store(len(d.foundRooms))
return nil
}

View file

@ -0,0 +1,179 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgreswithpubsub
import (
"context"
"encoding/json"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/matrix-org/dendrite/publicroomsapi/storage/postgres"
"github.com/matrix-org/gomatrixserverlib"
pubsub "github.com/libp2p/go-libp2p-pubsub"
)
const MaintenanceInterval = time.Second * 10
type discoveredRoom struct {
time time.Time
room gomatrixserverlib.PublicRoom
}
// PublicRoomsServerDatabase represents a public rooms server database.
type PublicRoomsServerDatabase struct {
postgres.PublicRoomsServerDatabase //
pubsub *pubsub.PubSub //
subscription *pubsub.Subscription //
foundRooms map[string]discoveredRoom // additional rooms we have learned about from the DHT
foundRoomsMutex sync.RWMutex // protects foundRooms
maintenanceTimer *time.Timer //
roomsAdvertised atomic.Value // stores int
}
// NewPublicRoomsServerDatabase creates a new public rooms server database.
func NewPublicRoomsServerDatabase(dataSourceName string, pubsub *pubsub.PubSub) (*PublicRoomsServerDatabase, error) {
pg, err := postgres.NewPublicRoomsServerDatabase(dataSourceName)
if err != nil {
return nil, err
}
provider := PublicRoomsServerDatabase{
pubsub: pubsub,
PublicRoomsServerDatabase: *pg,
foundRooms: make(map[string]discoveredRoom),
}
if topic, err := pubsub.Join("/matrix/publicRooms"); err != nil {
return nil, err
} else if sub, err := topic.Subscribe(); err == nil {
provider.subscription = sub
go provider.MaintenanceTimer()
go provider.FindRooms()
provider.roomsAdvertised.Store(0)
return &provider, nil
} else {
return nil, err
}
}
func (d *PublicRoomsServerDatabase) GetRoomVisibility(ctx context.Context, roomID string) (bool, error) {
return d.PublicRoomsServerDatabase.GetRoomVisibility(ctx, roomID)
}
func (d *PublicRoomsServerDatabase) SetRoomVisibility(ctx context.Context, visible bool, roomID string) error {
d.MaintenanceTimer()
return d.PublicRoomsServerDatabase.SetRoomVisibility(ctx, visible, roomID)
}
func (d *PublicRoomsServerDatabase) CountPublicRooms(ctx context.Context) (int64, error) {
d.foundRoomsMutex.RLock()
defer d.foundRoomsMutex.RUnlock()
return int64(len(d.foundRooms)), nil
}
func (d *PublicRoomsServerDatabase) GetPublicRooms(ctx context.Context, offset int64, limit int16, filter string) ([]gomatrixserverlib.PublicRoom, error) {
var rooms []gomatrixserverlib.PublicRoom
if filter == "__local__" {
if r, err := d.PublicRoomsServerDatabase.GetPublicRooms(ctx, offset, limit, ""); err == nil {
rooms = append(rooms, r...)
} else {
return []gomatrixserverlib.PublicRoom{}, err
}
} else {
d.foundRoomsMutex.RLock()
defer d.foundRoomsMutex.RUnlock()
for _, room := range d.foundRooms {
rooms = append(rooms, room.room)
}
}
return rooms, nil
}
func (d *PublicRoomsServerDatabase) UpdateRoomFromEvents(ctx context.Context, eventsToAdd []gomatrixserverlib.Event, eventsToRemove []gomatrixserverlib.Event) error {
return d.PublicRoomsServerDatabase.UpdateRoomFromEvents(ctx, eventsToAdd, eventsToRemove)
}
func (d *PublicRoomsServerDatabase) UpdateRoomFromEvent(ctx context.Context, event gomatrixserverlib.Event) error {
return d.PublicRoomsServerDatabase.UpdateRoomFromEvent(ctx, event)
}
func (d *PublicRoomsServerDatabase) MaintenanceTimer() {
if d.maintenanceTimer != nil && !d.maintenanceTimer.Stop() {
<-d.maintenanceTimer.C
}
d.Interval()
}
func (d *PublicRoomsServerDatabase) Interval() {
d.foundRoomsMutex.Lock()
for k, v := range d.foundRooms {
if time.Since(v.time) > time.Minute {
delete(d.foundRooms, k)
}
}
d.foundRoomsMutex.Unlock()
if err := d.AdvertiseRooms(); err != nil {
fmt.Println("Failed to advertise room in DHT:", err)
}
d.foundRoomsMutex.RLock()
defer d.foundRoomsMutex.RUnlock()
fmt.Println("Found", len(d.foundRooms), "room(s), advertised", d.roomsAdvertised.Load(), "room(s)")
d.maintenanceTimer = time.AfterFunc(MaintenanceInterval, d.Interval)
}
func (d *PublicRoomsServerDatabase) AdvertiseRooms() error {
dbCtx, dbCancel := context.WithTimeout(context.Background(), 3*time.Second)
_ = dbCancel
ourRooms, err := d.GetPublicRooms(dbCtx, 0, 1024, "__local__")
if err != nil {
return err
}
advertised := 0
for _, room := range ourRooms {
if j, err := json.Marshal(room); err == nil {
if topic, err := d.pubsub.Join("/matrix/publicRooms"); err != nil {
fmt.Println("Failed to subscribe to topic:", err)
} else if err := topic.Publish(context.TODO(), j); err != nil {
fmt.Println("Failed to publish public room:", err)
} else {
advertised++
}
}
}
d.roomsAdvertised.Store(advertised)
return nil
}
func (d *PublicRoomsServerDatabase) FindRooms() {
for {
msg, err := d.subscription.Next(context.Background())
if err != nil {
continue
}
received := discoveredRoom{
time: time.Now(),
}
if err := json.Unmarshal(msg.Data, &received.room); err != nil {
fmt.Println("Unmarshal error:", err)
continue
}
d.foundRoomsMutex.Lock()
d.foundRooms[received.room.RoomID] = received
d.foundRoomsMutex.Unlock()
}
}

View file

@ -0,0 +1,61 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package storage
import (
"net/url"
dht "github.com/libp2p/go-libp2p-kad-dht"
pubsub "github.com/libp2p/go-libp2p-pubsub"
"github.com/matrix-org/dendrite/cmd/dendrite-demo-libp2p/storage/postgreswithdht"
"github.com/matrix-org/dendrite/cmd/dendrite-demo-libp2p/storage/postgreswithpubsub"
"github.com/matrix-org/dendrite/publicroomsapi/storage"
"github.com/matrix-org/dendrite/publicroomsapi/storage/sqlite3"
)
const schemePostgres = "postgres"
const schemeFile = "file"
// NewPublicRoomsServerDatabase opens a database connection.
func NewPublicRoomsServerDatabaseWithDHT(dataSourceName string, dht *dht.IpfsDHT) (storage.Database, error) {
uri, err := url.Parse(dataSourceName)
if err != nil {
return postgreswithdht.NewPublicRoomsServerDatabase(dataSourceName, dht)
}
switch uri.Scheme {
case schemePostgres:
return postgreswithdht.NewPublicRoomsServerDatabase(dataSourceName, dht)
case schemeFile:
return sqlite3.NewPublicRoomsServerDatabase(dataSourceName)
default:
return postgreswithdht.NewPublicRoomsServerDatabase(dataSourceName, dht)
}
}
// NewPublicRoomsServerDatabase opens a database connection.
func NewPublicRoomsServerDatabaseWithPubSub(dataSourceName string, pubsub *pubsub.PubSub) (storage.Database, error) {
uri, err := url.Parse(dataSourceName)
if err != nil {
return postgreswithpubsub.NewPublicRoomsServerDatabase(dataSourceName, pubsub)
}
switch uri.Scheme {
case schemePostgres:
return postgreswithpubsub.NewPublicRoomsServerDatabase(dataSourceName, pubsub)
case schemeFile:
return sqlite3.NewPublicRoomsServerDatabase(dataSourceName)
default:
return postgreswithpubsub.NewPublicRoomsServerDatabase(dataSourceName, pubsub)
}
}

View file

@ -16,22 +16,22 @@ import (
_ "net/http/pprof" _ "net/http/pprof"
"github.com/matrix-org/dendrite/common/basecomponent" "github.com/matrix-org/dendrite/common/basecomponent"
"github.com/matrix-org/dendrite/typingserver" "github.com/matrix-org/dendrite/eduserver"
"github.com/matrix-org/dendrite/typingserver/cache" "github.com/matrix-org/dendrite/eduserver/cache"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
func main() { func main() {
cfg := basecomponent.ParseFlags() cfg := basecomponent.ParseFlags()
base := basecomponent.NewBaseDendrite(cfg, "TypingServerAPI") base := basecomponent.NewBaseDendrite(cfg, "EDUServerAPI")
defer func() { defer func() {
if err := base.Close(); err != nil { if err := base.Close(); err != nil {
logrus.WithError(err).Warn("BaseDendrite close failed") logrus.WithError(err).Warn("BaseDendrite close failed")
} }
}() }()
typingserver.SetupTypingServerComponent(base, cache.NewTypingCache()) eduserver.SetupEDUServerComponent(base, cache.New())
base.SetupAndServeHTTP(string(base.Cfg.Bind.TypingServer), string(base.Cfg.Listen.TypingServer)) base.SetupAndServeHTTP(string(base.Cfg.Bind.EDUServer), string(base.Cfg.Listen.EDUServer))
} }

View file

@ -15,8 +15,11 @@
package main package main
import ( import (
"github.com/matrix-org/dendrite/clientapi/producers"
"github.com/matrix-org/dendrite/common/basecomponent" "github.com/matrix-org/dendrite/common/basecomponent"
"github.com/matrix-org/dendrite/common/keydb" "github.com/matrix-org/dendrite/common/keydb"
"github.com/matrix-org/dendrite/eduserver"
"github.com/matrix-org/dendrite/eduserver/cache"
"github.com/matrix-org/dendrite/federationapi" "github.com/matrix-org/dendrite/federationapi"
) )
@ -30,14 +33,16 @@ func main() {
keyDB := base.CreateKeyDB() keyDB := base.CreateKeyDB()
federation := base.CreateFederationClient() federation := base.CreateFederationClient()
federationSender := base.CreateHTTPFederationSenderAPIs() federationSender := base.CreateHTTPFederationSenderAPIs()
keyRing := keydb.CreateKeyRing(federation.Client, keyDB) keyRing := keydb.CreateKeyRing(federation.Client, keyDB, cfg.Matrix.KeyPerspectives)
alias, input, query := base.CreateHTTPRoomserverAPIs() alias, input, query := base.CreateHTTPRoomserverAPIs()
asQuery := base.CreateHTTPAppServiceAPIs() asQuery := base.CreateHTTPAppServiceAPIs()
eduInputAPI := eduserver.SetupEDUServerComponent(base, cache.New())
eduProducer := producers.NewEDUServerProducer(eduInputAPI)
federationapi.SetupFederationAPIComponent( federationapi.SetupFederationAPIComponent(
base, accountDB, deviceDB, federation, &keyRing, base, accountDB, deviceDB, federation, &keyRing,
alias, input, query, asQuery, federationSender, alias, input, query, asQuery, federationSender, eduProducer,
) )
base.SetupAndServeHTTP(string(base.Cfg.Bind.FederationAPI), string(base.Cfg.Listen.FederationAPI)) base.SetupAndServeHTTP(string(base.Cfg.Bind.FederationAPI), string(base.Cfg.Listen.FederationAPI))

View file

@ -20,20 +20,22 @@ import (
"github.com/matrix-org/dendrite/appservice" "github.com/matrix-org/dendrite/appservice"
"github.com/matrix-org/dendrite/clientapi" "github.com/matrix-org/dendrite/clientapi"
"github.com/matrix-org/dendrite/clientapi/producers"
"github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/common/basecomponent" "github.com/matrix-org/dendrite/common/basecomponent"
"github.com/matrix-org/dendrite/common/keydb" "github.com/matrix-org/dendrite/common/keydb"
"github.com/matrix-org/dendrite/common/transactions" "github.com/matrix-org/dendrite/common/transactions"
"github.com/matrix-org/dendrite/eduserver"
"github.com/matrix-org/dendrite/eduserver/cache"
"github.com/matrix-org/dendrite/federationapi" "github.com/matrix-org/dendrite/federationapi"
"github.com/matrix-org/dendrite/federationsender" "github.com/matrix-org/dendrite/federationsender"
"github.com/matrix-org/dendrite/mediaapi" "github.com/matrix-org/dendrite/mediaapi"
"github.com/matrix-org/dendrite/publicroomsapi" "github.com/matrix-org/dendrite/publicroomsapi"
"github.com/matrix-org/dendrite/publicroomsapi/storage"
"github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/roomserver"
"github.com/matrix-org/dendrite/syncapi" "github.com/matrix-org/dendrite/syncapi"
"github.com/matrix-org/dendrite/typingserver"
"github.com/matrix-org/dendrite/typingserver/cache"
"github.com/prometheus/client_golang/prometheus/promhttp" "github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
@ -53,10 +55,10 @@ func main() {
deviceDB := base.CreateDeviceDB() deviceDB := base.CreateDeviceDB()
keyDB := base.CreateKeyDB() keyDB := base.CreateKeyDB()
federation := base.CreateFederationClient() federation := base.CreateFederationClient()
keyRing := keydb.CreateKeyRing(federation.Client, keyDB) keyRing := keydb.CreateKeyRing(federation.Client, keyDB, cfg.Matrix.KeyPerspectives)
alias, input, query := roomserver.SetupRoomServerComponent(base) alias, input, query := roomserver.SetupRoomServerComponent(base)
typingInputAPI := typingserver.SetupTypingServerComponent(base, cache.NewTypingCache()) eduInputAPI := eduserver.SetupEDUServerComponent(base, cache.New())
asQuery := appservice.SetupAppServiceAPIComponent( asQuery := appservice.SetupAppServiceAPIComponent(
base, accountDB, deviceDB, federation, alias, query, transactions.New(), base, accountDB, deviceDB, federation, alias, query, transactions.New(),
) )
@ -65,32 +67,49 @@ func main() {
clientapi.SetupClientAPIComponent( clientapi.SetupClientAPIComponent(
base, deviceDB, accountDB, base, deviceDB, accountDB,
federation, &keyRing, alias, input, query, federation, &keyRing, alias, input, query,
typingInputAPI, asQuery, transactions.New(), fedSenderAPI, eduInputAPI, asQuery, transactions.New(), fedSenderAPI,
) )
federationapi.SetupFederationAPIComponent(base, accountDB, deviceDB, federation, &keyRing, alias, input, query, asQuery, fedSenderAPI) eduProducer := producers.NewEDUServerProducer(eduInputAPI)
federationapi.SetupFederationAPIComponent(base, accountDB, deviceDB, federation, &keyRing, alias, input, query, asQuery, fedSenderAPI, eduProducer)
mediaapi.SetupMediaAPIComponent(base, deviceDB) mediaapi.SetupMediaAPIComponent(base, deviceDB)
publicroomsapi.SetupPublicRoomsAPIComponent(base, deviceDB, query) publicRoomsDB, err := storage.NewPublicRoomsServerDatabase(string(base.Cfg.Database.PublicRoomsAPI))
if err != nil {
logrus.WithError(err).Panicf("failed to connect to public rooms db")
}
publicroomsapi.SetupPublicRoomsAPIComponent(base, deviceDB, publicRoomsDB, query, federation, nil)
syncapi.SetupSyncAPIComponent(base, deviceDB, accountDB, query, federation, cfg) syncapi.SetupSyncAPIComponent(base, deviceDB, accountDB, query, federation, cfg)
httpHandler := common.WrapHandlerInCORS(base.APIMux) httpHandler := common.WrapHandlerInCORS(base.APIMux)
// Set up the API endpoints we handle. /metrics is for prometheus, and is // Set up the API endpoints we handle. /metrics is for prometheus, and is
// not wrapped by CORS, while everything else is // not wrapped by CORS, while everything else is
http.Handle("/metrics", promhttp.Handler()) if cfg.Metrics.Enabled {
http.Handle("/metrics", common.WrapHandlerInBasicAuth(promhttp.Handler(), cfg.Metrics.BasicAuth))
}
http.Handle("/", httpHandler) http.Handle("/", httpHandler)
// Expose the matrix APIs directly rather than putting them under a /api path. // Expose the matrix APIs directly rather than putting them under a /api path.
go func() { go func() {
logrus.Info("Listening on ", *httpBindAddr) serv := http.Server{
logrus.Fatal(http.ListenAndServe(*httpBindAddr, nil)) Addr: *httpBindAddr,
WriteTimeout: basecomponent.HTTPServerTimeout,
}
logrus.Info("Listening on ", serv.Addr)
logrus.Fatal(serv.ListenAndServe())
}() }()
// Handle HTTPS if certificate and key are provided // Handle HTTPS if certificate and key are provided
go func() { if *certFile != "" && *keyFile != "" {
if *certFile != "" && *keyFile != "" { go func() {
logrus.Info("Listening on ", *httpsBindAddr) serv := http.Server{
logrus.Fatal(http.ListenAndServeTLS(*httpsBindAddr, *certFile, *keyFile, nil)) Addr: *httpsBindAddr,
} WriteTimeout: basecomponent.HTTPServerTimeout,
}() }
logrus.Info("Listening on ", serv.Addr)
logrus.Fatal(serv.ListenAndServeTLS(*certFile, *keyFile))
}()
}
// We want to block forever to let the HTTP and HTTPS handler serve the APIs // We want to block forever to let the HTTP and HTTPS handler serve the APIs
select {} select {}

View file

@ -17,6 +17,8 @@ package main
import ( import (
"github.com/matrix-org/dendrite/common/basecomponent" "github.com/matrix-org/dendrite/common/basecomponent"
"github.com/matrix-org/dendrite/publicroomsapi" "github.com/matrix-org/dendrite/publicroomsapi"
"github.com/matrix-org/dendrite/publicroomsapi/storage"
"github.com/sirupsen/logrus"
) )
func main() { func main() {
@ -27,8 +29,11 @@ func main() {
deviceDB := base.CreateDeviceDB() deviceDB := base.CreateDeviceDB()
_, _, query := base.CreateHTTPRoomserverAPIs() _, _, query := base.CreateHTTPRoomserverAPIs()
publicRoomsDB, err := storage.NewPublicRoomsServerDatabase(string(base.Cfg.Database.PublicRoomsAPI))
publicroomsapi.SetupPublicRoomsAPIComponent(base, deviceDB, query) if err != nil {
logrus.WithError(err).Panicf("failed to connect to public rooms db")
}
publicroomsapi.SetupPublicRoomsAPIComponent(base, deviceDB, publicRoomsDB, query, nil, nil)
base.SetupAndServeHTTP(string(base.Cfg.Bind.PublicRoomsAPI), string(base.Cfg.Listen.PublicRoomsAPI)) base.SetupAndServeHTTP(string(base.Cfg.Bind.PublicRoomsAPI), string(base.Cfg.Listen.PublicRoomsAPI))

104
cmd/dendritejs/jsServer.go Normal file
View file

@ -0,0 +1,104 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// +build wasm
package main
import (
"bufio"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"syscall/js"
)
// JSServer exposes an HTTP-like server interface which allows JS to 'send' requests to it.
type JSServer struct {
// The router which will service requests
Mux *http.ServeMux
}
// OnRequestFromJS is the function that JS will invoke when there is a new request.
// The JS function signature is:
// function(reqString: string): Promise<{result: string, error: string}>
// Usage is like:
// const res = await global._go_js_server.fetch(reqString);
// if (res.error) {
// // handle error: this is a 'network' error, not a non-2xx error.
// }
// const rawHttpResponse = res.result;
func (h *JSServer) OnRequestFromJS(this js.Value, args []js.Value) interface{} {
// we HAVE to spawn a new goroutine and return immediately or else Go will deadlock
// if this request blocks at all e.g for /sync calls
httpStr := args[0].String()
promise := js.Global().Get("Promise").New(js.FuncOf(func(pthis js.Value, pargs []js.Value) interface{} {
// The initial callback code for new Promise() is also called on the critical path, which is why
// we need to put this in an immediately invoked goroutine.
go func() {
resolve := pargs[0]
fmt.Println("Received request:")
fmt.Printf("%s\n", httpStr)
resStr, err := h.handle(httpStr)
errStr := ""
if err != nil {
errStr = err.Error()
}
fmt.Println("Sending response:")
fmt.Printf("%s\n", resStr)
resolve.Invoke(map[string]interface{}{
"result": resStr,
"error": errStr,
})
}()
return nil
}))
return promise
}
// handle invokes the http.ServeMux for this request and returns the raw HTTP response.
func (h *JSServer) handle(httpStr string) (resStr string, err error) {
req, err := http.ReadRequest(bufio.NewReader(strings.NewReader(httpStr)))
if err != nil {
return
}
w := httptest.NewRecorder()
h.Mux.ServeHTTP(w, req)
res := w.Result()
var resBuffer strings.Builder
err = res.Write(&resBuffer)
return resBuffer.String(), err
}
// ListenAndServe registers a variable in JS-land with the given namespace. This variable is
// a function which JS-land can call to 'send' HTTP requests. The function is attached to
// a global object called "_go_js_server". See OnRequestFromJS for more info.
func (h *JSServer) ListenAndServe(namespace string) {
globalName := "_go_js_server"
// register a hook in JS-land for it to invoke stuff
server := js.Global().Get(globalName)
if !server.Truthy() {
server = js.Global().Get("Object").New()
js.Global().Set(globalName, server)
}
server.Set(namespace, js.FuncOf(h.OnRequestFromJS))
fmt.Printf("Listening for requests from JS on function %s.%s\n", globalName, namespace)
// Block forever to mimic http.ListenAndServe
select {}
}

View file

@ -0,0 +1,84 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// +build wasm
package main
import (
"context"
"fmt"
"time"
"github.com/libp2p/go-libp2p-core/peer"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
const libp2pMatrixKeyID = "ed25519:libp2p-dendrite"
type libp2pKeyFetcher struct {
}
// FetchKeys looks up a batch of public keys.
// Takes a map from (server name, key ID) pairs to timestamp.
// The timestamp is when the keys need to be vaild up to.
// Returns a map from (server name, key ID) pairs to server key objects for
// that server name containing that key ID
// The result may have fewer (server name, key ID) pairs than were in the request.
// The result may have more (server name, key ID) pairs than were in the request.
// Returns an error if there was a problem fetching the keys.
func (f *libp2pKeyFetcher) FetchKeys(
ctx context.Context,
requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp,
) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) {
res := make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult)
for req := range requests {
if req.KeyID != libp2pMatrixKeyID {
return nil, fmt.Errorf("FetchKeys: cannot fetch key with ID %s, should be %s", req.KeyID, libp2pMatrixKeyID)
}
// The server name is a libp2p peer ID
peerIDStr := string(req.ServerName)
peerID, err := peer.Decode(peerIDStr)
if err != nil {
return nil, fmt.Errorf("Failed to decode peer ID from server name '%s': %w", peerIDStr, err)
}
pubKey, err := peerID.ExtractPublicKey()
if err != nil {
return nil, fmt.Errorf("Failed to extract public key from peer ID: %w", err)
}
pubKeyBytes, err := pubKey.Raw()
if err != nil {
return nil, fmt.Errorf("Failed to extract raw bytes from public key: %w", err)
}
util.GetLogger(ctx).Info("libp2pKeyFetcher.FetchKeys: Using public key %v for server name %s", pubKeyBytes, req.ServerName)
b64Key := gomatrixserverlib.Base64String(pubKeyBytes)
res[req] = gomatrixserverlib.PublicKeyLookupResult{
VerifyKey: gomatrixserverlib.VerifyKey{
Key: b64Key,
},
ExpiredTS: gomatrixserverlib.PublicKeyNotExpired,
ValidUntilTS: gomatrixserverlib.AsTimestamp(time.Now().Add(24 * time.Hour * 365)),
}
}
return res, nil
}
// FetcherName returns the name of this fetcher, which can then be used for
// logging errors etc.
func (f *libp2pKeyFetcher) FetcherName() string {
return "libp2pKeyFetcher"
}

174
cmd/dendritejs/main.go Normal file
View file

@ -0,0 +1,174 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// +build wasm
package main
import (
"crypto/ed25519"
"fmt"
"net/http"
"github.com/matrix-org/dendrite/appservice"
"github.com/matrix-org/dendrite/clientapi"
"github.com/matrix-org/dendrite/clientapi/producers"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/common/basecomponent"
"github.com/matrix-org/dendrite/common/config"
"github.com/matrix-org/dendrite/common/transactions"
"github.com/matrix-org/dendrite/eduserver"
"github.com/matrix-org/dendrite/eduserver/cache"
"github.com/matrix-org/dendrite/federationapi"
"github.com/matrix-org/dendrite/federationsender"
"github.com/matrix-org/dendrite/mediaapi"
"github.com/matrix-org/dendrite/publicroomsapi"
"github.com/matrix-org/dendrite/publicroomsapi/storage"
"github.com/matrix-org/dendrite/roomserver"
"github.com/matrix-org/dendrite/syncapi"
"github.com/matrix-org/go-http-js-libp2p/go_http_js_libp2p"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
_ "github.com/matrix-org/go-sqlite3-js"
)
func init() {
fmt.Println("dendrite.js starting...")
}
func generateKey() ed25519.PrivateKey {
_, priv, err := ed25519.GenerateKey(nil)
if err != nil {
logrus.Fatalf("Failed to generate ed25519 key: %s", err)
}
return priv
}
func createFederationClient(cfg *config.Dendrite, node *go_http_js_libp2p.P2pLocalNode) *gomatrixserverlib.FederationClient {
fmt.Println("Running in js-libp2p federation mode")
fmt.Println("Warning: Federation with non-libp2p homeservers will not work in this mode yet!")
tr := go_http_js_libp2p.NewP2pTransport(node)
fed := gomatrixserverlib.NewFederationClient(
cfg.Matrix.ServerName, cfg.Matrix.KeyID, cfg.Matrix.PrivateKey,
)
fed.Client = *gomatrixserverlib.NewClientWithTransport(tr)
return fed
}
func createP2PNode(privKey ed25519.PrivateKey) (serverName string, node *go_http_js_libp2p.P2pLocalNode) {
hosted := "/dns4/rendezvous.matrix.org/tcp/8443/wss/p2p-websocket-star/"
node = go_http_js_libp2p.NewP2pLocalNode("org.matrix.p2p.experiment", privKey.Seed(), []string{hosted})
serverName = node.Id
fmt.Println("p2p assigned ServerName: ", serverName)
return
}
func main() {
cfg := &config.Dendrite{}
cfg.SetDefaults()
cfg.Kafka.UseNaffka = true
cfg.Database.Account = "file:dendritejs_account.db"
cfg.Database.AppService = "file:dendritejs_appservice.db"
cfg.Database.Device = "file:dendritejs_device.db"
cfg.Database.FederationSender = "file:dendritejs_fedsender.db"
cfg.Database.MediaAPI = "file:dendritejs_mediaapi.db"
cfg.Database.Naffka = "file:dendritejs_naffka.db"
cfg.Database.PublicRoomsAPI = "file:dendritejs_publicrooms.db"
cfg.Database.RoomServer = "file:dendritejs_roomserver.db"
cfg.Database.ServerKey = "file:dendritejs_serverkey.db"
cfg.Database.SyncAPI = "file:dendritejs_syncapi.db"
cfg.Kafka.Topics.UserUpdates = "user_updates"
cfg.Kafka.Topics.OutputTypingEvent = "output_typing_event"
cfg.Kafka.Topics.OutputClientData = "output_client_data"
cfg.Kafka.Topics.OutputRoomEvent = "output_room_event"
cfg.Matrix.TrustedIDServers = []string{
"matrix.org", "vector.im",
}
cfg.Matrix.KeyID = libp2pMatrixKeyID
cfg.Matrix.PrivateKey = generateKey()
serverName, node := createP2PNode(cfg.Matrix.PrivateKey)
cfg.Matrix.ServerName = gomatrixserverlib.ServerName(serverName)
if err := cfg.Derive(); err != nil {
logrus.Fatalf("Failed to derive values from config: %s", err)
}
base := basecomponent.NewBaseDendrite(cfg, "Monolith")
defer base.Close() // nolint: errcheck
accountDB := base.CreateAccountsDB()
deviceDB := base.CreateDeviceDB()
keyDB := base.CreateKeyDB()
federation := createFederationClient(cfg, node)
keyRing := gomatrixserverlib.KeyRing{
KeyFetchers: []gomatrixserverlib.KeyFetcher{
&libp2pKeyFetcher{},
},
KeyDatabase: keyDB,
}
p2pPublicRoomProvider := NewLibP2PPublicRoomsProvider(node)
alias, input, query := roomserver.SetupRoomServerComponent(base)
eduInputAPI := eduserver.SetupEDUServerComponent(base, cache.New())
asQuery := appservice.SetupAppServiceAPIComponent(
base, accountDB, deviceDB, federation, alias, query, transactions.New(),
)
fedSenderAPI := federationsender.SetupFederationSenderComponent(base, federation, query)
clientapi.SetupClientAPIComponent(
base, deviceDB, accountDB,
federation, &keyRing, alias, input, query,
eduInputAPI, asQuery, transactions.New(), fedSenderAPI,
)
eduProducer := producers.NewEDUServerProducer(eduInputAPI)
federationapi.SetupFederationAPIComponent(base, accountDB, deviceDB, federation, &keyRing, alias, input, query, asQuery, fedSenderAPI, eduProducer)
mediaapi.SetupMediaAPIComponent(base, deviceDB)
publicRoomsDB, err := storage.NewPublicRoomsServerDatabase(string(base.Cfg.Database.PublicRoomsAPI))
if err != nil {
logrus.WithError(err).Panicf("failed to connect to public rooms db")
}
publicroomsapi.SetupPublicRoomsAPIComponent(base, deviceDB, publicRoomsDB, query, federation, p2pPublicRoomProvider)
syncapi.SetupSyncAPIComponent(base, deviceDB, accountDB, query, federation, cfg)
httpHandler := common.WrapHandlerInCORS(base.APIMux)
http.Handle("/", httpHandler)
// Expose the matrix APIs via libp2p-js - for federation traffic
if node != nil {
go func() {
logrus.Info("Listening on libp2p-js host ID ", node.Id)
s := JSServer{
Mux: http.DefaultServeMux,
}
s.ListenAndServe("p2p")
}()
}
// Expose the matrix APIs via fetch - for local traffic
go func() {
logrus.Info("Listening for service-worker fetch traffic")
s := JSServer{
Mux: http.DefaultServeMux,
}
s.ListenAndServe("fetch")
}()
// We want to block forever to let the fetch and libp2p handler serve the APIs
select {}
}

View file

@ -0,0 +1,23 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// +build !wasm
package main
import "fmt"
func main() {
fmt.Println("dendritejs: no-op when not compiling for WebAssembly")
}

View file

@ -0,0 +1,46 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// +build wasm
package main
import (
"github.com/matrix-org/go-http-js-libp2p/go_http_js_libp2p"
)
type libp2pPublicRoomsProvider struct {
node *go_http_js_libp2p.P2pLocalNode
providers []go_http_js_libp2p.PeerInfo
}
func NewLibP2PPublicRoomsProvider(node *go_http_js_libp2p.P2pLocalNode) *libp2pPublicRoomsProvider {
p := &libp2pPublicRoomsProvider{
node: node,
}
node.RegisterFoundProviders(p.foundProviders)
return p
}
func (p *libp2pPublicRoomsProvider) foundProviders(peerInfos []go_http_js_libp2p.PeerInfo) {
p.providers = peerInfos
}
func (p *libp2pPublicRoomsProvider) Homeservers() []string {
result := make([]string, len(p.providers))
for i := range p.providers {
result[i] = p.providers[i].Id
}
return result
}

View file

@ -21,7 +21,7 @@ import (
"os" "os"
"strings" "strings"
"github.com/Shopify/sarama" sarama "gopkg.in/Shopify/sarama.v1"
) )
const usage = `Usage: %s const usage = `Usage: %s

View file

@ -28,6 +28,7 @@ import (
"net/http" "net/http"
"github.com/matrix-org/dendrite/common/caching"
"github.com/matrix-org/dendrite/common/test" "github.com/matrix-org/dendrite/common/test"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -44,6 +45,8 @@ var (
// This needs to be high enough to account for the time it takes to create // This needs to be high enough to account for the time it takes to create
// the postgres database tables which can take a while on travis. // the postgres database tables which can take a while on travis.
timeoutString = defaulting(os.Getenv("TIMEOUT"), "60s") timeoutString = defaulting(os.Getenv("TIMEOUT"), "60s")
// Timeout for http client
timeoutHTTPClient = defaulting(os.Getenv("TIMEOUT_HTTP"), "30s")
// The name of maintenance database to connect to in order to create the test database. // The name of maintenance database to connect to in order to create the test database.
postgresDatabase = defaulting(os.Getenv("POSTGRES_DATABASE"), "postgres") postgresDatabase = defaulting(os.Getenv("POSTGRES_DATABASE"), "postgres")
// The name of the test database to create. // The name of the test database to create.
@ -68,7 +71,10 @@ func defaulting(value, defaultValue string) string {
return value return value
} }
var timeout time.Duration var (
timeout time.Duration
timeoutHTTP time.Duration
)
func init() { func init() {
var err error var err error
@ -76,6 +82,10 @@ func init() {
if err != nil { if err != nil {
panic(err) panic(err)
} }
timeoutHTTP, err = time.ParseDuration(timeoutHTTPClient)
if err != nil {
panic(err)
}
} }
func createDatabase(database string) error { func createDatabase(database string) error {
@ -199,7 +209,10 @@ func writeToRoomServer(input []string, roomserverURL string) error {
return err return err
} }
} }
x := api.NewRoomserverInputAPIHTTP(roomserverURL, nil) x, err := api.NewRoomserverInputAPIHTTP(roomserverURL, &http.Client{Timeout: timeoutHTTP})
if err != nil {
return err
}
return x.InputRoomEvents(context.Background(), &request, &response) return x.InputRoomEvents(context.Background(), &request, &response)
} }
@ -241,6 +254,11 @@ func testRoomserver(input []string, wantOutput []string, checkQueries func(api.R
panic(err) panic(err)
} }
cache, err := caching.NewImmutableInMemoryLRUCache()
if err != nil {
panic(err)
}
doInput := func() { doInput := func() {
fmt.Printf("Roomserver is ready to receive input, sending %d events\n", len(input)) fmt.Printf("Roomserver is ready to receive input, sending %d events\n", len(input))
if err = writeToRoomServer(input, cfg.RoomServerURL()); err != nil { if err = writeToRoomServer(input, cfg.RoomServerURL()); err != nil {
@ -258,7 +276,7 @@ func testRoomserver(input []string, wantOutput []string, checkQueries func(api.R
cmd.Args = []string{"dendrite-room-server", "--config", filepath.Join(dir, test.ConfigFile)} cmd.Args = []string{"dendrite-room-server", "--config", filepath.Join(dir, test.ConfigFile)}
gotOutput, err := runAndReadFromTopic(cmd, cfg.RoomServerURL()+"/metrics", doInput, outputTopic, len(wantOutput), func() { gotOutput, err := runAndReadFromTopic(cmd, cfg.RoomServerURL()+"/metrics", doInput, outputTopic, len(wantOutput), func() {
queryAPI := api.NewRoomserverQueryAPIHTTP("http://"+string(cfg.Listen.RoomServer), nil) queryAPI, _ := api.NewRoomserverQueryAPIHTTP("http://"+string(cfg.Listen.RoomServer), &http.Client{Timeout: timeoutHTTP}, cache)
checkQueries(queryAPI) checkQueries(queryAPI)
}) })
if err != nil { if err != nil {

View file

@ -104,7 +104,7 @@ func clientEventJSONForOutputRoomEvent(outputRoomEvent string) string {
panic("failed to unmarshal output room event: " + err.Error()) panic("failed to unmarshal output room event: " + err.Error())
} }
clientEvs := gomatrixserverlib.ToClientEvents([]gomatrixserverlib.Event{ clientEvs := gomatrixserverlib.ToClientEvents([]gomatrixserverlib.Event{
out.NewRoomEvent.Event, out.NewRoomEvent.Event.Event,
}, gomatrixserverlib.FormatSync) }, gomatrixserverlib.FormatSync)
b, err := json.Marshal(clientEvs[0]) b, err := json.Marshal(clientEvs[0])
if err != nil { if err != nil {

View file

@ -18,10 +18,14 @@ import (
"database/sql" "database/sql"
"io" "io"
"net/http" "net/http"
"net/url"
"time"
"golang.org/x/crypto/ed25519" "golang.org/x/crypto/ed25519"
"github.com/matrix-org/dendrite/common/caching"
"github.com/matrix-org/dendrite/common/keydb" "github.com/matrix-org/dendrite/common/keydb"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/naffka" "github.com/matrix-org/naffka"
@ -34,9 +38,9 @@ import (
appserviceAPI "github.com/matrix-org/dendrite/appservice/api" appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
"github.com/matrix-org/dendrite/common/config" "github.com/matrix-org/dendrite/common/config"
eduServerAPI "github.com/matrix-org/dendrite/eduserver/api"
federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api" federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
typingServerAPI "github.com/matrix-org/dendrite/typingserver/api"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
@ -50,12 +54,17 @@ type BaseDendrite struct {
tracerCloser io.Closer tracerCloser io.Closer
// APIMux should be used to register new public matrix api endpoints // APIMux should be used to register new public matrix api endpoints
APIMux *mux.Router APIMux *mux.Router
Cfg *config.Dendrite httpClient *http.Client
KafkaConsumer sarama.Consumer Cfg *config.Dendrite
KafkaProducer sarama.SyncProducer ImmutableCache caching.ImmutableCache
KafkaConsumer sarama.Consumer
KafkaProducer sarama.SyncProducer
} }
const HTTPServerTimeout = time.Minute * 5
const HTTPClientTimeout = time.Second * 30
// NewBaseDendrite creates a new instance to be used by a component. // NewBaseDendrite creates a new instance to be used by a component.
// The componentName is used for logging purposes, and should be a friendly name // The componentName is used for logging purposes, and should be a friendly name
// of the compontent running, e.g. "SyncAPI" // of the compontent running, e.g. "SyncAPI"
@ -68,15 +77,28 @@ func NewBaseDendrite(cfg *config.Dendrite, componentName string) *BaseDendrite {
logrus.WithError(err).Panicf("failed to start opentracing") logrus.WithError(err).Panicf("failed to start opentracing")
} }
kafkaConsumer, kafkaProducer := setupKafka(cfg) var kafkaConsumer sarama.Consumer
var kafkaProducer sarama.SyncProducer
if cfg.Kafka.UseNaffka {
kafkaConsumer, kafkaProducer = setupNaffka(cfg)
} else {
kafkaConsumer, kafkaProducer = setupKafka(cfg)
}
cache, err := caching.NewImmutableInMemoryLRUCache()
if err != nil {
logrus.WithError(err).Warnf("Failed to create cache")
}
return &BaseDendrite{ return &BaseDendrite{
componentName: componentName, componentName: componentName,
tracerCloser: closer, tracerCloser: closer,
Cfg: cfg, Cfg: cfg,
APIMux: mux.NewRouter().UseEncodedPath(), ImmutableCache: cache,
KafkaConsumer: kafkaConsumer, APIMux: mux.NewRouter().UseEncodedPath(),
KafkaProducer: kafkaProducer, httpClient: &http.Client{Timeout: HTTPClientTimeout},
KafkaConsumer: kafkaConsumer,
KafkaProducer: kafkaProducer,
} }
} }
@ -88,7 +110,11 @@ func (b *BaseDendrite) Close() error {
// CreateHTTPAppServiceAPIs returns the QueryAPI for hitting the appservice // CreateHTTPAppServiceAPIs returns the QueryAPI for hitting the appservice
// component over HTTP. // component over HTTP.
func (b *BaseDendrite) CreateHTTPAppServiceAPIs() appserviceAPI.AppServiceQueryAPI { func (b *BaseDendrite) CreateHTTPAppServiceAPIs() appserviceAPI.AppServiceQueryAPI {
return appserviceAPI.NewAppServiceQueryAPIHTTP(b.Cfg.AppServiceURL(), nil) a, err := appserviceAPI.NewAppServiceQueryAPIHTTP(b.Cfg.AppServiceURL(), b.httpClient)
if err != nil {
logrus.WithError(err).Panic("CreateHTTPAppServiceAPIs failed")
}
return a
} }
// CreateHTTPRoomserverAPIs returns the AliasAPI, InputAPI and QueryAPI for hitting // CreateHTTPRoomserverAPIs returns the AliasAPI, InputAPI and QueryAPI for hitting
@ -98,27 +124,44 @@ func (b *BaseDendrite) CreateHTTPRoomserverAPIs() (
roomserverAPI.RoomserverInputAPI, roomserverAPI.RoomserverInputAPI,
roomserverAPI.RoomserverQueryAPI, roomserverAPI.RoomserverQueryAPI,
) { ) {
alias := roomserverAPI.NewRoomserverAliasAPIHTTP(b.Cfg.RoomServerURL(), nil) alias, err := roomserverAPI.NewRoomserverAliasAPIHTTP(b.Cfg.RoomServerURL(), b.httpClient)
input := roomserverAPI.NewRoomserverInputAPIHTTP(b.Cfg.RoomServerURL(), nil) if err != nil {
query := roomserverAPI.NewRoomserverQueryAPIHTTP(b.Cfg.RoomServerURL(), nil) logrus.WithError(err).Panic("NewRoomserverAliasAPIHTTP failed")
}
input, err := roomserverAPI.NewRoomserverInputAPIHTTP(b.Cfg.RoomServerURL(), b.httpClient)
if err != nil {
logrus.WithError(err).Panic("NewRoomserverInputAPIHTTP failed", b.httpClient)
}
query, err := roomserverAPI.NewRoomserverQueryAPIHTTP(b.Cfg.RoomServerURL(), b.httpClient, b.ImmutableCache)
if err != nil {
logrus.WithError(err).Panic("NewRoomserverQueryAPIHTTP failed", b.httpClient)
}
return alias, input, query return alias, input, query
} }
// CreateHTTPTypingServerAPIs returns typingInputAPI for hitting the typing // CreateHTTPEDUServerAPIs returns eduInputAPI for hitting the EDU
// server over HTTP // server over HTTP
func (b *BaseDendrite) CreateHTTPTypingServerAPIs() typingServerAPI.TypingServerInputAPI { func (b *BaseDendrite) CreateHTTPEDUServerAPIs() eduServerAPI.EDUServerInputAPI {
return typingServerAPI.NewTypingServerInputAPIHTTP(b.Cfg.TypingServerURL(), nil) e, err := eduServerAPI.NewEDUServerInputAPIHTTP(b.Cfg.EDUServerURL(), b.httpClient)
if err != nil {
logrus.WithError(err).Panic("NewEDUServerInputAPIHTTP failed", b.httpClient)
}
return e
} }
// CreateHTTPFederationSenderAPIs returns FederationSenderQueryAPI for hitting // CreateHTTPFederationSenderAPIs returns FederationSenderQueryAPI for hitting
// the federation sender over HTTP // the federation sender over HTTP
func (b *BaseDendrite) CreateHTTPFederationSenderAPIs() federationSenderAPI.FederationSenderQueryAPI { func (b *BaseDendrite) CreateHTTPFederationSenderAPIs() federationSenderAPI.FederationSenderQueryAPI {
return federationSenderAPI.NewFederationSenderQueryAPIHTTP(b.Cfg.FederationSenderURL(), nil) f, err := federationSenderAPI.NewFederationSenderQueryAPIHTTP(b.Cfg.FederationSenderURL(), b.httpClient)
if err != nil {
logrus.WithError(err).Panic("NewFederationSenderQueryAPIHTTP failed", b.httpClient)
}
return f
} }
// CreateDeviceDB creates a new instance of the device database. Should only be // CreateDeviceDB creates a new instance of the device database. Should only be
// called once per component. // called once per component.
func (b *BaseDendrite) CreateDeviceDB() *devices.Database { func (b *BaseDendrite) CreateDeviceDB() devices.Database {
db, err := devices.NewDatabase(string(b.Cfg.Database.Device), b.Cfg.Matrix.ServerName) db, err := devices.NewDatabase(string(b.Cfg.Database.Device), b.Cfg.Matrix.ServerName)
if err != nil { if err != nil {
logrus.WithError(err).Panicf("failed to connect to devices db") logrus.WithError(err).Panicf("failed to connect to devices db")
@ -129,7 +172,7 @@ func (b *BaseDendrite) CreateDeviceDB() *devices.Database {
// CreateAccountsDB creates a new instance of the accounts database. Should only // CreateAccountsDB creates a new instance of the accounts database. Should only
// be called once per component. // be called once per component.
func (b *BaseDendrite) CreateAccountsDB() *accounts.Database { func (b *BaseDendrite) CreateAccountsDB() accounts.Database {
db, err := accounts.NewDatabase(string(b.Cfg.Database.Account), b.Cfg.Matrix.ServerName) db, err := accounts.NewDatabase(string(b.Cfg.Database.Account), b.Cfg.Matrix.ServerName)
if err != nil { if err != nil {
logrus.WithError(err).Panicf("failed to connect to accounts db") logrus.WithError(err).Panicf("failed to connect to accounts db")
@ -174,40 +217,24 @@ func (b *BaseDendrite) SetupAndServeHTTP(bindaddr string, listenaddr string) {
addr = listenaddr addr = listenaddr
} }
common.SetupHTTPAPI(http.DefaultServeMux, common.WrapHandlerInCORS(b.APIMux)) serv := http.Server{
logrus.Infof("Starting %s server on %s", b.componentName, addr) Addr: addr,
WriteTimeout: HTTPServerTimeout,
}
err := http.ListenAndServe(addr, nil) common.SetupHTTPAPI(http.DefaultServeMux, common.WrapHandlerInCORS(b.APIMux), b.Cfg)
logrus.Infof("Starting %s server on %s", b.componentName, serv.Addr)
err := serv.ListenAndServe()
if err != nil { if err != nil {
logrus.WithError(err).Fatal("failed to serve http") logrus.WithError(err).Fatal("failed to serve http")
} }
logrus.Infof("Stopped %s server on %s", b.componentName, addr) logrus.Infof("Stopped %s server on %s", b.componentName, serv.Addr)
} }
// setupKafka creates kafka consumer/producer pair from the config. Checks if // setupKafka creates kafka consumer/producer pair from the config.
// should use naffka.
func setupKafka(cfg *config.Dendrite) (sarama.Consumer, sarama.SyncProducer) { func setupKafka(cfg *config.Dendrite) (sarama.Consumer, sarama.SyncProducer) {
if cfg.Kafka.UseNaffka {
db, err := sql.Open("postgres", string(cfg.Database.Naffka))
if err != nil {
logrus.WithError(err).Panic("Failed to open naffka database")
}
naffkaDB, err := naffka.NewPostgresqlDatabase(db)
if err != nil {
logrus.WithError(err).Panic("Failed to setup naffka database")
}
naff, err := naffka.New(naffkaDB)
if err != nil {
logrus.WithError(err).Panic("Failed to setup naffka")
}
return naff, naff
}
consumer, err := sarama.NewConsumer(cfg.Kafka.Addresses, nil) consumer, err := sarama.NewConsumer(cfg.Kafka.Addresses, nil)
if err != nil { if err != nil {
logrus.WithError(err).Panic("failed to start kafka consumer") logrus.WithError(err).Panic("failed to start kafka consumer")
@ -220,3 +247,44 @@ func setupKafka(cfg *config.Dendrite) (sarama.Consumer, sarama.SyncProducer) {
return consumer, producer return consumer, producer
} }
// setupNaffka creates kafka consumer/producer pair from the config.
func setupNaffka(cfg *config.Dendrite) (sarama.Consumer, sarama.SyncProducer) {
var err error
var db *sql.DB
var naffkaDB *naffka.DatabaseImpl
uri, err := url.Parse(string(cfg.Database.Naffka))
if err != nil || uri.Scheme == "file" {
db, err = sqlutil.Open(common.SQLiteDriverName(), string(cfg.Database.Naffka))
if err != nil {
logrus.WithError(err).Panic("Failed to open naffka database")
}
naffkaDB, err = naffka.NewSqliteDatabase(db)
if err != nil {
logrus.WithError(err).Panic("Failed to setup naffka database")
}
} else {
db, err = sqlutil.Open("postgres", string(cfg.Database.Naffka))
if err != nil {
logrus.WithError(err).Panic("Failed to open naffka database")
}
naffkaDB, err = naffka.NewPostgresqlDatabase(db)
if err != nil {
logrus.WithError(err).Panic("Failed to setup naffka database")
}
}
if naffkaDB == nil {
panic("naffka connection string not understood")
}
naff, err := naffka.New(naffkaDB)
if err != nil {
logrus.WithError(err).Panic("Failed to setup naffka")
}
return naff, naff
}

Some files were not shown because too many files have changed in this diff Show more