Merge branch 'master' of https://github.com/matrix-org/dendrite into prateek

This commit is contained in:
Prateek Sachan 2020-03-28 12:00:43 +05:30
commit 415c3bdd77
206 changed files with 5196 additions and 2428 deletions

6
.gitignore vendored
View file

@ -43,3 +43,9 @@ _testmain.go
# Default configuration file # Default configuration file
dendrite.yaml dendrite.yaml
# Database files
*.db
# Log files
*.log*

View file

@ -102,7 +102,7 @@ linters-settings:
#local-prefixes: github.com/org/project #local-prefixes: github.com/org/project
gocyclo: gocyclo:
# minimal code complexity to report, 30 by default (but we recommend 10-20) # minimal code complexity to report, 30 by default (but we recommend 10-20)
min-complexity: 12 min-complexity: 13
maligned: maligned:
# print struct with more effective memory layout or not, false by default # print struct with more effective memory layout or not, false by default
suggest-new: true suggest-new: true

View file

@ -28,3 +28,4 @@ discussion should happen in
There's plenty still to do to make Dendrite usable! We're tracking progress in a There's plenty still to do to make Dendrite usable! We're tracking progress in a
[project board](https://github.com/matrix-org/dendrite/projects/2). [project board](https://github.com/matrix-org/dendrite/projects/2).

View file

@ -114,19 +114,19 @@ func (s *OutputRoomEventConsumer) onMessage(msg *sarama.ConsumerMessage) error {
// lookupMissingStateEvents looks up the state events that are added by a new event, // lookupMissingStateEvents looks up the state events that are added by a new event,
// and returns any not already present. // and returns any not already present.
func (s *OutputRoomEventConsumer) lookupMissingStateEvents( func (s *OutputRoomEventConsumer) lookupMissingStateEvents(
addsStateEventIDs []string, event gomatrixserverlib.Event, addsStateEventIDs []string, event gomatrixserverlib.HeaderedEvent,
) ([]gomatrixserverlib.Event, error) { ) ([]gomatrixserverlib.HeaderedEvent, error) {
// Fast path if there aren't any new state events. // Fast path if there aren't any new state events.
if len(addsStateEventIDs) == 0 { if len(addsStateEventIDs) == 0 {
return []gomatrixserverlib.Event{}, nil return []gomatrixserverlib.HeaderedEvent{}, nil
} }
// Fast path if the only state event added is the event itself. // Fast path if the only state event added is the event itself.
if len(addsStateEventIDs) == 1 && addsStateEventIDs[0] == event.EventID() { if len(addsStateEventIDs) == 1 && addsStateEventIDs[0] == event.EventID() {
return []gomatrixserverlib.Event{}, nil return []gomatrixserverlib.HeaderedEvent{}, nil
} }
result := []gomatrixserverlib.Event{} result := []gomatrixserverlib.HeaderedEvent{}
missing := []string{} missing := []string{}
for _, id := range addsStateEventIDs { for _, id := range addsStateEventIDs {
if id != event.EventID() { if id != event.EventID() {
@ -155,7 +155,7 @@ func (s *OutputRoomEventConsumer) lookupMissingStateEvents(
// application service. // application service.
func (s *OutputRoomEventConsumer) filterRoomserverEvents( func (s *OutputRoomEventConsumer) filterRoomserverEvents(
ctx context.Context, ctx context.Context,
events []gomatrixserverlib.Event, events []gomatrixserverlib.HeaderedEvent,
) error { ) error {
for _, ws := range s.workerStates { for _, ws := range s.workerStates {
for _, event := range events { for _, event := range events {
@ -178,7 +178,7 @@ func (s *OutputRoomEventConsumer) filterRoomserverEvents(
// appserviceIsInterestedInEvent returns a boolean depending on whether a given // appserviceIsInterestedInEvent returns a boolean depending on whether a given
// event falls within one of a given application service's namespaces. // event falls within one of a given application service's namespaces.
func (s *OutputRoomEventConsumer) appserviceIsInterestedInEvent(ctx context.Context, event gomatrixserverlib.Event, appservice config.ApplicationService) bool { func (s *OutputRoomEventConsumer) appserviceIsInterestedInEvent(ctx context.Context, event gomatrixserverlib.HeaderedEvent, appservice config.ApplicationService) bool {
// No reason to queue events if they'll never be sent to the application // No reason to queue events if they'll never be sent to the application
// service // service
if appservice.URL == "" { if appservice.URL == "" {
@ -191,6 +191,12 @@ func (s *OutputRoomEventConsumer) appserviceIsInterestedInEvent(ctx context.Cont
return true return true
} }
if event.Type() == gomatrixserverlib.MRoomMember && event.StateKey() != nil {
if appservice.IsInterestedInUserID(*event.StateKey()) {
return true
}
}
// Check all known room aliases of the room the event came from // Check all known room aliases of the room the event came from
queryReq := api.GetAliasesForRoomIDRequest{RoomID: event.RoomID()} queryReq := api.GetAliasesForRoomIDRequest{RoomID: event.RoomID()}
var queryRes api.GetAliasesForRoomIDResponse var queryRes api.GetAliasesForRoomIDResponse

View file

@ -0,0 +1,30 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package storage
import (
"context"
"github.com/matrix-org/gomatrixserverlib"
)
type Database interface {
StoreEvent(ctx context.Context, appServiceID string, event *gomatrixserverlib.HeaderedEvent) error
GetEventsWithAppServiceID(ctx context.Context, appServiceID string, limit int) (int, int, []gomatrixserverlib.HeaderedEvent, bool, error)
CountEventsWithAppServiceID(ctx context.Context, appServiceID string) (int, error)
UpdateTxnIDForEvents(ctx context.Context, appserviceID string, maxID, txnID int) error
RemoveEventsBeforeAndIncludingID(ctx context.Context, appserviceID string, eventTableID int) error
GetLatestTxnID(ctx context.Context) (int, error)
}

View file

@ -33,7 +33,7 @@ CREATE TABLE IF NOT EXISTS appservice_events (
-- The ID of the application service the event will be sent to -- The ID of the application service the event will be sent to
as_id TEXT NOT NULL, as_id TEXT NOT NULL,
-- JSON representation of the event -- JSON representation of the event
event_json TEXT NOT NULL, headered_event_json TEXT NOT NULL,
-- The ID of the transaction that this event is a part of -- The ID of the transaction that this event is a part of
txn_id BIGINT NOT NULL txn_id BIGINT NOT NULL
); );
@ -42,14 +42,14 @@ CREATE INDEX IF NOT EXISTS appservice_events_as_id ON appservice_events(as_id);
` `
const selectEventsByApplicationServiceIDSQL = "" + const selectEventsByApplicationServiceIDSQL = "" +
"SELECT id, event_json, txn_id " + "SELECT id, headered_event_json, txn_id " +
"FROM appservice_events WHERE as_id = $1 ORDER BY txn_id DESC, id ASC" "FROM appservice_events WHERE as_id = $1 ORDER BY txn_id DESC, id ASC"
const countEventsByApplicationServiceIDSQL = "" + const countEventsByApplicationServiceIDSQL = "" +
"SELECT COUNT(id) FROM appservice_events WHERE as_id = $1" "SELECT COUNT(id) FROM appservice_events WHERE as_id = $1"
const insertEventSQL = "" + const insertEventSQL = "" +
"INSERT INTO appservice_events(as_id, event_json, txn_id) " + "INSERT INTO appservice_events(as_id, headered_event_json, txn_id) " +
"VALUES ($1, $2, $3)" "VALUES ($1, $2, $3)"
const updateTxnIDForEventsSQL = "" + const updateTxnIDForEventsSQL = "" +
@ -107,7 +107,7 @@ func (s *eventsStatements) selectEventsByApplicationServiceID(
limit int, limit int,
) ( ) (
txnID, maxID int, txnID, maxID int,
events []gomatrixserverlib.Event, events []gomatrixserverlib.HeaderedEvent,
eventsRemaining bool, eventsRemaining bool,
err error, err error,
) { ) {
@ -132,7 +132,7 @@ func (s *eventsStatements) selectEventsByApplicationServiceID(
return return
} }
func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.Event, maxID, txnID int, eventsRemaining bool, err error) { func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.HeaderedEvent, maxID, txnID int, eventsRemaining bool, err error) {
// Get current time for use in calculating event age // Get current time for use in calculating event age
nowMilli := time.Now().UnixNano() / int64(time.Millisecond) nowMilli := time.Now().UnixNano() / int64(time.Millisecond)
@ -141,7 +141,7 @@ func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.
// new ones. Send back those events first. // new ones. Send back those events first.
lastTxnID := invalidTxnID lastTxnID := invalidTxnID
for eventsProcessed := 0; eventRows.Next(); { for eventsProcessed := 0; eventRows.Next(); {
var event gomatrixserverlib.Event var event gomatrixserverlib.HeaderedEvent
var eventJSON []byte var eventJSON []byte
var id int var id int
err = eventRows.Scan( err = eventRows.Scan(
@ -209,7 +209,7 @@ func (s *eventsStatements) countEventsByApplicationServiceID(
func (s *eventsStatements) insertEvent( func (s *eventsStatements) insertEvent(
ctx context.Context, ctx context.Context,
appServiceID string, appServiceID string,
event *gomatrixserverlib.Event, event *gomatrixserverlib.HeaderedEvent,
) (err error) { ) (err error) {
// Convert event to JSON before inserting // Convert event to JSON before inserting
eventJSON, err := json.Marshal(event) eventJSON, err := json.Marshal(event)

View file

@ -52,12 +52,12 @@ func (d *Database) prepare() error {
return d.txnID.prepare(d.db) return d.txnID.prepare(d.db)
} }
// StoreEvent takes in a gomatrixserverlib.Event and stores it in the database // StoreEvent takes in a gomatrixserverlib.HeaderedEvent and stores it in the database
// for a transaction worker to pull and later send to an application service. // for a transaction worker to pull and later send to an application service.
func (d *Database) StoreEvent( func (d *Database) StoreEvent(
ctx context.Context, ctx context.Context,
appServiceID string, appServiceID string,
event *gomatrixserverlib.Event, event *gomatrixserverlib.HeaderedEvent,
) error { ) error {
return d.events.insertEvent(ctx, appServiceID, event) return d.events.insertEvent(ctx, appServiceID, event)
} }
@ -68,7 +68,7 @@ func (d *Database) GetEventsWithAppServiceID(
ctx context.Context, ctx context.Context,
appServiceID string, appServiceID string,
limit int, limit int,
) (int, int, []gomatrixserverlib.Event, bool, error) { ) (int, int, []gomatrixserverlib.HeaderedEvent, bool, error) {
return d.events.selectEventsByApplicationServiceID(ctx, appServiceID, limit) return d.events.selectEventsByApplicationServiceID(ctx, appServiceID, limit)
} }

View file

@ -33,7 +33,7 @@ CREATE TABLE IF NOT EXISTS appservice_events (
-- The ID of the application service the event will be sent to -- The ID of the application service the event will be sent to
as_id TEXT NOT NULL, as_id TEXT NOT NULL,
-- JSON representation of the event -- JSON representation of the event
event_json TEXT NOT NULL, headered_event_json TEXT NOT NULL,
-- The ID of the transaction that this event is a part of -- The ID of the transaction that this event is a part of
txn_id INTEGER NOT NULL txn_id INTEGER NOT NULL
); );
@ -42,14 +42,14 @@ CREATE INDEX IF NOT EXISTS appservice_events_as_id ON appservice_events(as_id);
` `
const selectEventsByApplicationServiceIDSQL = "" + const selectEventsByApplicationServiceIDSQL = "" +
"SELECT id, event_json, txn_id " + "SELECT id, headered_event_json, txn_id " +
"FROM appservice_events WHERE as_id = $1 ORDER BY txn_id DESC, id ASC" "FROM appservice_events WHERE as_id = $1 ORDER BY txn_id DESC, id ASC"
const countEventsByApplicationServiceIDSQL = "" + const countEventsByApplicationServiceIDSQL = "" +
"SELECT COUNT(id) FROM appservice_events WHERE as_id = $1" "SELECT COUNT(id) FROM appservice_events WHERE as_id = $1"
const insertEventSQL = "" + const insertEventSQL = "" +
"INSERT INTO appservice_events(as_id, event_json, txn_id) " + "INSERT INTO appservice_events(as_id, headered_event_json, txn_id) " +
"VALUES ($1, $2, $3)" "VALUES ($1, $2, $3)"
const updateTxnIDForEventsSQL = "" + const updateTxnIDForEventsSQL = "" +
@ -107,7 +107,7 @@ func (s *eventsStatements) selectEventsByApplicationServiceID(
limit int, limit int,
) ( ) (
txnID, maxID int, txnID, maxID int,
events []gomatrixserverlib.Event, events []gomatrixserverlib.HeaderedEvent,
eventsRemaining bool, eventsRemaining bool,
err error, err error,
) { ) {
@ -132,7 +132,7 @@ func (s *eventsStatements) selectEventsByApplicationServiceID(
return return
} }
func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.Event, maxID, txnID int, eventsRemaining bool, err error) { func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.HeaderedEvent, maxID, txnID int, eventsRemaining bool, err error) {
// Get current time for use in calculating event age // Get current time for use in calculating event age
nowMilli := time.Now().UnixNano() / int64(time.Millisecond) nowMilli := time.Now().UnixNano() / int64(time.Millisecond)
@ -141,7 +141,7 @@ func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.
// new ones. Send back those events first. // new ones. Send back those events first.
lastTxnID := invalidTxnID lastTxnID := invalidTxnID
for eventsProcessed := 0; eventRows.Next(); { for eventsProcessed := 0; eventRows.Next(); {
var event gomatrixserverlib.Event var event gomatrixserverlib.HeaderedEvent
var eventJSON []byte var eventJSON []byte
var id int var id int
err = eventRows.Scan( err = eventRows.Scan(
@ -209,7 +209,7 @@ func (s *eventsStatements) countEventsByApplicationServiceID(
func (s *eventsStatements) insertEvent( func (s *eventsStatements) insertEvent(
ctx context.Context, ctx context.Context,
appServiceID string, appServiceID string,
event *gomatrixserverlib.Event, event *gomatrixserverlib.HeaderedEvent,
) (err error) { ) (err error) {
// Convert event to JSON before inserting // Convert event to JSON before inserting
eventJSON, err := json.Marshal(event) eventJSON, err := json.Marshal(event)

View file

@ -20,6 +20,7 @@ import (
"database/sql" "database/sql"
// Import SQLite database driver // Import SQLite database driver
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
) )
@ -35,7 +36,7 @@ type Database struct {
func NewDatabase(dataSourceName string) (*Database, error) { func NewDatabase(dataSourceName string) (*Database, error) {
var result Database var result Database
var err error var err error
if result.db, err = sql.Open("sqlite3", dataSourceName); err != nil { if result.db, err = sql.Open(common.SQLiteDriverName(), dataSourceName); err != nil {
return nil, err return nil, err
} }
if err = result.prepare(); err != nil { if err = result.prepare(); err != nil {
@ -52,12 +53,12 @@ func (d *Database) prepare() error {
return d.txnID.prepare(d.db) return d.txnID.prepare(d.db)
} }
// StoreEvent takes in a gomatrixserverlib.Event and stores it in the database // StoreEvent takes in a gomatrixserverlib.HeaderedEvent and stores it in the database
// for a transaction worker to pull and later send to an application service. // for a transaction worker to pull and later send to an application service.
func (d *Database) StoreEvent( func (d *Database) StoreEvent(
ctx context.Context, ctx context.Context,
appServiceID string, appServiceID string,
event *gomatrixserverlib.Event, event *gomatrixserverlib.HeaderedEvent,
) error { ) error {
return d.events.insertEvent(ctx, appServiceID, event) return d.events.insertEvent(ctx, appServiceID, event)
} }
@ -68,7 +69,7 @@ func (d *Database) GetEventsWithAppServiceID(
ctx context.Context, ctx context.Context,
appServiceID string, appServiceID string,
limit int, limit int,
) (int, int, []gomatrixserverlib.Event, bool, error) { ) (int, int, []gomatrixserverlib.HeaderedEvent, bool, error) {
return d.events.selectEventsByApplicationServiceID(ctx, appServiceID, limit) return d.events.selectEventsByApplicationServiceID(ctx, appServiceID, limit)
} }

View file

@ -12,26 +12,17 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// +build !wasm
package storage package storage
import ( import (
"context"
"net/url" "net/url"
"github.com/matrix-org/dendrite/appservice/storage/postgres" "github.com/matrix-org/dendrite/appservice/storage/postgres"
"github.com/matrix-org/dendrite/appservice/storage/sqlite3" "github.com/matrix-org/dendrite/appservice/storage/sqlite3"
"github.com/matrix-org/gomatrixserverlib"
) )
type Database interface {
StoreEvent(ctx context.Context, appServiceID string, event *gomatrixserverlib.Event) error
GetEventsWithAppServiceID(ctx context.Context, appServiceID string, limit int) (int, int, []gomatrixserverlib.Event, bool, error)
CountEventsWithAppServiceID(ctx context.Context, appServiceID string) (int, error)
UpdateTxnIDForEvents(ctx context.Context, appserviceID string, maxID, txnID int) error
RemoveEventsBeforeAndIncludingID(ctx context.Context, appserviceID string, eventTableID int) error
GetLatestTxnID(ctx context.Context) (int, error)
}
func NewDatabase(dataSourceName string) (Database, error) { func NewDatabase(dataSourceName string) (Database, error) {
uri, err := url.Parse(dataSourceName) uri, err := url.Parse(dataSourceName)
if err != nil { if err != nil {

View file

@ -0,0 +1,37 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package storage
import (
"fmt"
"net/url"
"github.com/matrix-org/dendrite/appservice/storage/sqlite3"
)
func NewDatabase(dataSourceName string) (Database, error) {
uri, err := url.Parse(dataSourceName)
if err != nil {
return nil, fmt.Errorf("Cannot use postgres implementation")
}
switch uri.Scheme {
case "postgres":
return nil, fmt.Errorf("Cannot use postgres implementation")
case "file":
return sqlite3.NewDatabase(dataSourceName)
default:
return nil, fmt.Errorf("Cannot use postgres implementation")
}
}

View file

@ -181,9 +181,14 @@ func createTransaction(
} }
} }
var ev []gomatrixserverlib.Event
for _, e := range events {
ev = append(ev, e.Event)
}
// Create a transaction and store the events inside // Create a transaction and store the events inside
transaction := gomatrixserverlib.ApplicationServiceTransaction{ transaction := gomatrixserverlib.ApplicationServiceTransaction{
Events: events, Events: ev,
} }
transactionJSON, err = json.Marshal(transaction) transactionJSON, err = json.Marshal(transaction)

View file

@ -1,3 +1,6 @@
#!/bin/sh #!/bin/bash -eu
GOBIN=$PWD/`dirname $0`/bin go install -v $PWD/`dirname $0`/cmd/... # Put installed packages into ./bin
export GOBIN=$PWD/`dirname $0`/bin
go install -v $PWD/`dirname $0`/cmd/...

View file

@ -26,7 +26,6 @@ import (
"github.com/matrix-org/dendrite/appservice/types" "github.com/matrix-org/dendrite/appservice/types"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/common/config" "github.com/matrix-org/dendrite/common/config"
@ -166,7 +165,8 @@ func verifyAccessToken(req *http.Request, deviceDB DeviceDatabase) (device *auth
JSON: jsonerror.UnknownToken("Unknown token"), JSON: jsonerror.UnknownToken("Unknown token"),
} }
} else { } else {
jsonErr := httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("deviceDB.GetDeviceByAccessToken failed")
jsonErr := jsonerror.InternalServerError()
resErr = &jsonErr resErr = &jsonErr
} }
} }

View file

@ -26,4 +26,5 @@ type Device struct {
// associated with access tokens. // associated with access tokens.
SessionID int64 SessionID int64
// TODO: display name, last used timestamp, keys, etc // TODO: display name, last used timestamp, keys, etc
DisplayName string
} }

View file

@ -0,0 +1,54 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package accounts
import (
"context"
"errors"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/gomatrixserverlib"
)
type Database interface {
common.PartitionStorer
GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*authtypes.Account, error)
GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
SetAvatarURL(ctx context.Context, localpart string, avatarURL string) error
SetDisplayName(ctx context.Context, localpart string, displayName string) error
CreateAccount(ctx context.Context, localpart, plaintextPassword, appserviceID string) (*authtypes.Account, error)
CreateGuestAccount(ctx context.Context) (*authtypes.Account, error)
UpdateMemberships(ctx context.Context, eventsToAdd []gomatrixserverlib.Event, idsToRemove []string) error
GetMembershipInRoomByLocalpart(ctx context.Context, localpart, roomID string) (authtypes.Membership, error)
GetRoomIDsByLocalPart(ctx context.Context, localpart string) ([]string, error)
GetMembershipsByLocalpart(ctx context.Context, localpart string) (memberships []authtypes.Membership, err error)
SaveAccountData(ctx context.Context, localpart, roomID, dataType, content string) error
GetAccountData(ctx context.Context, localpart string) (global []gomatrixserverlib.ClientEvent, rooms map[string][]gomatrixserverlib.ClientEvent, err error)
GetAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data *gomatrixserverlib.ClientEvent, err error)
GetNewNumericLocalpart(ctx context.Context) (int64, error)
SaveThreePIDAssociation(ctx context.Context, threepid, localpart, medium string) (err error)
RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error)
GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, err error)
GetThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error)
GetFilter(ctx context.Context, localpart string, filterID string) (*gomatrixserverlib.Filter, error)
PutFilter(ctx context.Context, localpart string, filter *gomatrixserverlib.Filter) (string, error)
CheckAccountAvailability(ctx context.Context, localpart string) (bool, error)
GetAccountByLocalpart(ctx context.Context, localpart string) (*authtypes.Account, error)
}
// Err3PIDInUse is the error returned when trying to save an association involving
// a third-party identifier which is already associated to a local user.
var Err3PIDInUse = errors.New("This third-party identifier is already in use")

View file

@ -18,6 +18,8 @@ import (
"context" "context"
"database/sql" "database/sql"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -72,9 +74,9 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
} }
func (s *accountDataStatements) insertAccountData( func (s *accountDataStatements) insertAccountData(
ctx context.Context, localpart, roomID, dataType, content string, ctx context.Context, txn *sql.Tx, localpart, roomID, dataType, content string,
) (err error) { ) (err error) {
stmt := s.insertAccountDataStmt stmt := txn.Stmt(s.insertAccountDataStmt)
_, err = stmt.ExecContext(ctx, localpart, roomID, dataType, content) _, err = stmt.ExecContext(ctx, localpart, roomID, dataType, content)
return return
} }
@ -90,7 +92,7 @@ func (s *accountDataStatements) selectAccountData(
if err != nil { if err != nil {
return return
} }
defer rows.Close() // nolint: errcheck defer common.CloseAndLogIfError(ctx, rows, "selectAccountData: rows.close() failed")
global = []gomatrixserverlib.ClientEvent{} global = []gomatrixserverlib.ClientEvent{}
rooms = make(map[string][]gomatrixserverlib.ClientEvent) rooms = make(map[string][]gomatrixserverlib.ClientEvent)

View file

@ -91,10 +91,10 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
// this account will be passwordless. Returns an error if this account already exists. Returns the account // this account will be passwordless. Returns an error if this account already exists. Returns the account
// on success. // on success.
func (s *accountsStatements) insertAccount( func (s *accountsStatements) insertAccount(
ctx context.Context, localpart, hash, appserviceID string, ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string,
) (*authtypes.Account, error) { ) (*authtypes.Account, error) {
createdTimeMS := time.Now().UnixNano() / 1000000 createdTimeMS := time.Now().UnixNano() / 1000000
stmt := s.insertAccountStmt stmt := txn.Stmt(s.insertAccountStmt)
var err error var err error
if appserviceID == "" { if appserviceID == "" {
@ -146,8 +146,12 @@ func (s *accountsStatements) selectAccountByLocalpart(
} }
func (s *accountsStatements) selectNewNumericLocalpart( func (s *accountsStatements) selectNewNumericLocalpart(
ctx context.Context, ctx context.Context, txn *sql.Tx,
) (id int64, err error) { ) (id int64, err error) {
err = s.selectNewNumericLocalpartStmt.QueryRowContext(ctx).Scan(&id) stmt := s.selectNewNumericLocalpartStmt
if txn != nil {
stmt = txn.Stmt(stmt)
}
err = stmt.QueryRowContext(ctx).Scan(&id)
return return
} }

View file

@ -18,6 +18,8 @@ import (
"context" "context"
"database/sql" "database/sql"
"github.com/matrix-org/dendrite/common"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
) )
@ -51,6 +53,9 @@ const selectMembershipsByLocalpartSQL = "" +
const selectMembershipInRoomByLocalpartSQL = "" + const selectMembershipInRoomByLocalpartSQL = "" +
"SELECT event_id FROM account_memberships WHERE localpart = $1 AND room_id = $2" "SELECT event_id FROM account_memberships WHERE localpart = $1 AND room_id = $2"
const selectRoomIDsByLocalPartSQL = "" +
"SELECT room_id FROM account_memberships WHERE localpart = $1"
const deleteMembershipsByEventIDsSQL = "" + const deleteMembershipsByEventIDsSQL = "" +
"DELETE FROM account_memberships WHERE event_id = ANY($1)" "DELETE FROM account_memberships WHERE event_id = ANY($1)"
@ -59,6 +64,7 @@ type membershipStatements struct {
insertMembershipStmt *sql.Stmt insertMembershipStmt *sql.Stmt
selectMembershipInRoomByLocalpartStmt *sql.Stmt selectMembershipInRoomByLocalpartStmt *sql.Stmt
selectMembershipsByLocalpartStmt *sql.Stmt selectMembershipsByLocalpartStmt *sql.Stmt
selectRoomIDsByLocalPartStmt *sql.Stmt
} }
func (s *membershipStatements) prepare(db *sql.DB) (err error) { func (s *membershipStatements) prepare(db *sql.DB) (err error) {
@ -78,6 +84,9 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) {
if s.selectMembershipsByLocalpartStmt, err = db.Prepare(selectMembershipsByLocalpartSQL); err != nil { if s.selectMembershipsByLocalpartStmt, err = db.Prepare(selectMembershipsByLocalpartSQL); err != nil {
return return
} }
if s.selectRoomIDsByLocalPartStmt, err = db.Prepare(selectRoomIDsByLocalPartSQL); err != nil {
return
}
return return
} }
@ -118,7 +127,7 @@ func (s *membershipStatements) selectMembershipsByLocalpart(
memberships = []authtypes.Membership{} memberships = []authtypes.Membership{}
defer rows.Close() // nolint: errcheck defer common.CloseAndLogIfError(ctx, rows, "selectMembershipsByLocalpart: rows.close() failed")
for rows.Next() { for rows.Next() {
var m authtypes.Membership var m authtypes.Membership
m.Localpart = localpart m.Localpart = localpart
@ -129,3 +138,23 @@ func (s *membershipStatements) selectMembershipsByLocalpart(
} }
return memberships, rows.Err() return memberships, rows.Err()
} }
func (s *membershipStatements) selectRoomIDsByLocalPart(
ctx context.Context, localPart string,
) ([]string, error) {
stmt := s.selectRoomIDsByLocalPartStmt
rows, err := stmt.QueryContext(ctx, localPart)
if err != nil {
return nil, err
}
roomIDs := []string{}
defer rows.Close() // nolint: errcheck
for rows.Next() {
var roomID string
if err = rows.Scan(&roomID); err != nil {
return nil, err
}
roomIDs = append(roomIDs, roomID)
}
return roomIDs, rows.Err()
}

View file

@ -73,9 +73,9 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) {
} }
func (s *profilesStatements) insertProfile( func (s *profilesStatements) insertProfile(
ctx context.Context, localpart string, ctx context.Context, txn *sql.Tx, localpart string,
) (err error) { ) (err error) {
_, err = s.insertProfileStmt.ExecContext(ctx, localpart, "", "") _, err = txn.Stmt(s.insertProfileStmt).ExecContext(ctx, localpart, "", "")
return return
} }

View file

@ -18,6 +18,7 @@ import (
"context" "context"
"database/sql" "database/sql"
"errors" "errors"
"strconv"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/common"
@ -118,11 +119,37 @@ func (d *Database) SetDisplayName(
return d.profiles.setDisplayName(ctx, localpart, displayName) return d.profiles.setDisplayName(ctx, localpart, displayName)
} }
// CreateGuestAccount makes a new guest account and creates an empty profile
// for this account.
func (d *Database) CreateGuestAccount(ctx context.Context) (acc *authtypes.Account, err error) {
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
var numLocalpart int64
numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx, txn)
if err != nil {
return err
}
localpart := strconv.FormatInt(numLocalpart, 10)
acc, err = d.createAccount(ctx, txn, localpart, "", "")
return err
})
return acc, err
}
// CreateAccount makes a new account with the given login name and password, and creates an empty profile // CreateAccount makes a new account with the given login name and password, and creates an empty profile
// for this account. If no password is supplied, the account will be a passwordless account. If the // for this account. If no password is supplied, the account will be a passwordless account. If the
// account already exists, it will return nil, nil. // account already exists, it will return nil, nil.
func (d *Database) CreateAccount( func (d *Database) CreateAccount(
ctx context.Context, localpart, plaintextPassword, appserviceID string, ctx context.Context, localpart, plaintextPassword, appserviceID string,
) (acc *authtypes.Account, err error) {
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID)
return err
})
return
}
func (d *Database) createAccount(
ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string,
) (*authtypes.Account, error) { ) (*authtypes.Account, error) {
var err error var err error
@ -134,13 +161,14 @@ func (d *Database) CreateAccount(
return nil, err return nil, err
} }
} }
if err := d.profiles.insertProfile(ctx, localpart); err != nil { if err := d.profiles.insertProfile(ctx, txn, localpart); err != nil {
if common.IsUniqueConstraintViolationErr(err) { if common.IsUniqueConstraintViolationErr(err) {
return nil, nil return nil, nil
} }
return nil, err return nil, err
} }
if err := d.SaveAccountData(ctx, localpart, "", "m.push_rules", `{
if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", `{
"global": { "global": {
"content": [], "content": [],
"override": [], "override": [],
@ -151,7 +179,7 @@ func (d *Database) CreateAccount(
}`); err != nil { }`); err != nil {
return nil, err return nil, err
} }
return d.accounts.insertAccount(ctx, localpart, hash, appserviceID) return d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID)
} }
// SaveMembership saves the user matching a given localpart as a member of a given // SaveMembership saves the user matching a given localpart as a member of a given
@ -206,6 +234,16 @@ func (d *Database) GetMembershipInRoomByLocalpart(
return d.memberships.selectMembershipInRoomByLocalpart(ctx, localpart, roomID) return d.memberships.selectMembershipInRoomByLocalpart(ctx, localpart, roomID)
} }
// GetRoomIDsByLocalPart returns an array containing the room ids of all
// the rooms a user matching a given localpart is a member of
// If no membership match the given localpart, returns an empty array
// If there was an issue during the retrieval, returns the SQL error
func (d *Database) GetRoomIDsByLocalPart(
ctx context.Context, localpart string,
) ([]string, error) {
return d.memberships.selectRoomIDsByLocalPart(ctx, localpart)
}
// GetMembershipsByLocalpart returns an array containing the memberships for all // GetMembershipsByLocalpart returns an array containing the memberships for all
// the rooms a user matching a given localpart is a member of // the rooms a user matching a given localpart is a member of
// If no membership match the given localpart, returns an empty array // If no membership match the given localpart, returns an empty array
@ -258,7 +296,9 @@ func (d *Database) newMembership(
func (d *Database) SaveAccountData( func (d *Database) SaveAccountData(
ctx context.Context, localpart, roomID, dataType, content string, ctx context.Context, localpart, roomID, dataType, content string,
) error { ) error {
return d.accountDatas.insertAccountData(ctx, localpart, roomID, dataType, content) return common.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content)
})
} }
// GetAccountData returns account data related to a given localpart // GetAccountData returns account data related to a given localpart
@ -288,7 +328,7 @@ func (d *Database) GetAccountDataByType(
func (d *Database) GetNewNumericLocalpart( func (d *Database) GetNewNumericLocalpart(
ctx context.Context, ctx context.Context,
) (int64, error) { ) (int64, error) {
return d.accounts.selectNewNumericLocalpart(ctx) return d.accounts.selectNewNumericLocalpart(ctx, nil)
} }
func hashPassword(plaintext string) (hash string, err error) { func hashPassword(plaintext string) (hash string, err error) {

View file

@ -72,10 +72,9 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
} }
func (s *accountDataStatements) insertAccountData( func (s *accountDataStatements) insertAccountData(
ctx context.Context, localpart, roomID, dataType, content string, ctx context.Context, txn *sql.Tx, localpart, roomID, dataType, content string,
) (err error) { ) (err error) {
stmt := s.insertAccountDataStmt _, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content)
_, err = stmt.ExecContext(ctx, localpart, roomID, dataType, content)
return return
} }

View file

@ -89,16 +89,16 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
// this account will be passwordless. Returns an error if this account already exists. Returns the account // this account will be passwordless. Returns an error if this account already exists. Returns the account
// on success. // on success.
func (s *accountsStatements) insertAccount( func (s *accountsStatements) insertAccount(
ctx context.Context, localpart, hash, appserviceID string, ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string,
) (*authtypes.Account, error) { ) (*authtypes.Account, error) {
createdTimeMS := time.Now().UnixNano() / 1000000 createdTimeMS := time.Now().UnixNano() / 1000000
stmt := s.insertAccountStmt stmt := s.insertAccountStmt
var err error var err error
if appserviceID == "" { if appserviceID == "" {
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil) _, err = txn.Stmt(stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil)
} else { } else {
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID) _, err = txn.Stmt(stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID)
} }
if err != nil { if err != nil {
return nil, err return nil, err
@ -144,8 +144,12 @@ func (s *accountsStatements) selectAccountByLocalpart(
} }
func (s *accountsStatements) selectNewNumericLocalpart( func (s *accountsStatements) selectNewNumericLocalpart(
ctx context.Context, ctx context.Context, txn *sql.Tx,
) (id int64, err error) { ) (id int64, err error) {
err = s.selectNewNumericLocalpartStmt.QueryRowContext(ctx).Scan(&id) stmt := s.selectNewNumericLocalpartStmt
if txn != nil {
stmt = txn.Stmt(stmt)
}
err = stmt.QueryRowContext(ctx).Scan(&id)
return return
} }

View file

@ -17,9 +17,10 @@ package sqlite3
import ( import (
"context" "context"
"database/sql" "database/sql"
"strings"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/common"
) )
const membershipSchema = ` const membershipSchema = `
@ -50,14 +51,17 @@ const selectMembershipsByLocalpartSQL = "" +
const selectMembershipInRoomByLocalpartSQL = "" + const selectMembershipInRoomByLocalpartSQL = "" +
"SELECT event_id FROM account_memberships WHERE localpart = $1 AND room_id = $2" "SELECT event_id FROM account_memberships WHERE localpart = $1 AND room_id = $2"
const selectRoomIDsByLocalPartSQL = "" +
"SELECT room_id FROM account_memberships WHERE localpart = $1"
const deleteMembershipsByEventIDsSQL = "" + const deleteMembershipsByEventIDsSQL = "" +
"DELETE FROM account_memberships WHERE event_id IN ($1)" "DELETE FROM account_memberships WHERE event_id IN ($1)"
type membershipStatements struct { type membershipStatements struct {
deleteMembershipsByEventIDsStmt *sql.Stmt
insertMembershipStmt *sql.Stmt insertMembershipStmt *sql.Stmt
selectMembershipInRoomByLocalpartStmt *sql.Stmt selectMembershipInRoomByLocalpartStmt *sql.Stmt
selectMembershipsByLocalpartStmt *sql.Stmt selectMembershipsByLocalpartStmt *sql.Stmt
selectRoomIDsByLocalPartStmt *sql.Stmt
} }
func (s *membershipStatements) prepare(db *sql.DB) (err error) { func (s *membershipStatements) prepare(db *sql.DB) (err error) {
@ -65,9 +69,6 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) {
if err != nil { if err != nil {
return return
} }
if s.deleteMembershipsByEventIDsStmt, err = db.Prepare(deleteMembershipsByEventIDsSQL); err != nil {
return
}
if s.insertMembershipStmt, err = db.Prepare(insertMembershipSQL); err != nil { if s.insertMembershipStmt, err = db.Prepare(insertMembershipSQL); err != nil {
return return
} }
@ -77,6 +78,9 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) {
if s.selectMembershipsByLocalpartStmt, err = db.Prepare(selectMembershipsByLocalpartSQL); err != nil { if s.selectMembershipsByLocalpartStmt, err = db.Prepare(selectMembershipsByLocalpartSQL); err != nil {
return return
} }
if s.selectRoomIDsByLocalPartStmt, err = db.Prepare(selectRoomIDsByLocalPartSQL); err != nil {
return
}
return return
} }
@ -91,8 +95,12 @@ func (s *membershipStatements) insertMembership(
func (s *membershipStatements) deleteMembershipsByEventIDs( func (s *membershipStatements) deleteMembershipsByEventIDs(
ctx context.Context, txn *sql.Tx, eventIDs []string, ctx context.Context, txn *sql.Tx, eventIDs []string,
) (err error) { ) (err error) {
stmt := txn.Stmt(s.deleteMembershipsByEventIDsStmt) sqlStr := strings.Replace(deleteMembershipsByEventIDsSQL, "($1)", common.QueryVariadic(len(eventIDs)), 1)
_, err = stmt.ExecContext(ctx, pq.StringArray(eventIDs)) iEventIDs := make([]interface{}, len(eventIDs))
for i, e := range eventIDs {
iEventIDs[i] = e
}
_, err = txn.ExecContext(ctx, sqlStr, iEventIDs...)
return return
} }
@ -117,7 +125,7 @@ func (s *membershipStatements) selectMembershipsByLocalpart(
memberships = []authtypes.Membership{} memberships = []authtypes.Membership{}
defer rows.Close() // nolint: errcheck defer common.CloseAndLogIfError(ctx, rows, "selectMembershipsByLocalpart: rows.close() failed")
for rows.Next() { for rows.Next() {
var m authtypes.Membership var m authtypes.Membership
m.Localpart = localpart m.Localpart = localpart
@ -129,3 +137,22 @@ func (s *membershipStatements) selectMembershipsByLocalpart(
return return
} }
func (s *membershipStatements) selectRoomIDsByLocalPart(
ctx context.Context, localPart string,
) ([]string, error) {
stmt := s.selectRoomIDsByLocalPartStmt
rows, err := stmt.QueryContext(ctx, localPart)
if err != nil {
return nil, err
}
roomIDs := []string{}
defer rows.Close() // nolint: errcheck
for rows.Next() {
var roomID string
if err = rows.Scan(&roomID); err != nil {
return nil, err
}
roomIDs = append(roomIDs, roomID)
}
return roomIDs, rows.Err()
}

View file

@ -73,9 +73,9 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) {
} }
func (s *profilesStatements) insertProfile( func (s *profilesStatements) insertProfile(
ctx context.Context, localpart string, ctx context.Context, txn *sql.Tx, localpart string,
) (err error) { ) (err error) {
_, err = s.insertProfileStmt.ExecContext(ctx, localpart, "", "") _, err = txn.Stmt(s.insertProfileStmt).ExecContext(ctx, localpart, "", "")
return return
} }

View file

@ -18,6 +18,8 @@ import (
"context" "context"
"database/sql" "database/sql"
"errors" "errors"
"strconv"
"sync"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/common"
@ -39,13 +41,15 @@ type Database struct {
threepids threepidStatements threepids threepidStatements
filter filterStatements filter filterStatements
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
createGuestAccountMu sync.Mutex
} }
// NewDatabase creates a new accounts and profiles database // NewDatabase creates a new accounts and profiles database
func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (*Database, error) { func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (*Database, error) {
var db *sql.DB var db *sql.DB
var err error var err error
if db, err = sql.Open("sqlite3", dataSourceName); err != nil { if db, err = sql.Open(common.SQLiteDriverName(), dataSourceName); err != nil {
return nil, err return nil, err
} }
partitions := common.PartitionOffsetStatements{} partitions := common.PartitionOffsetStatements{}
@ -76,7 +80,7 @@ func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName)
if err = f.prepare(db); err != nil { if err = f.prepare(db); err != nil {
return nil, err return nil, err
} }
return &Database{db, partitions, a, p, m, ac, t, f, serverName}, nil return &Database{db, partitions, a, p, m, ac, t, f, serverName, sync.Mutex{}}, nil
} }
// GetAccountByPassword returns the account associated with the given localpart and password. // GetAccountByPassword returns the account associated with the given localpart and password.
@ -118,14 +122,46 @@ func (d *Database) SetDisplayName(
return d.profiles.setDisplayName(ctx, localpart, displayName) return d.profiles.setDisplayName(ctx, localpart, displayName)
} }
// CreateGuestAccount makes a new guest account and creates an empty profile
// for this account.
func (d *Database) CreateGuestAccount(ctx context.Context) (acc *authtypes.Account, err error) {
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
// We need to lock so we sequentially create numeric localparts. If we don't, two calls to
// this function will cause the same number to be selected and one will fail with 'database is locked'
// when the first txn upgrades to a write txn.
// We know we'll be the only process since this is sqlite ;) so a lock here will be all that is needed.
d.createGuestAccountMu.Lock()
defer d.createGuestAccountMu.Unlock()
var numLocalpart int64
numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx, txn)
if err != nil {
return err
}
localpart := strconv.FormatInt(numLocalpart, 10)
acc, err = d.createAccount(ctx, txn, localpart, "", "")
return err
})
return acc, err
}
// CreateAccount makes a new account with the given login name and password, and creates an empty profile // CreateAccount makes a new account with the given login name and password, and creates an empty profile
// for this account. If no password is supplied, the account will be a passwordless account. If the // for this account. If no password is supplied, the account will be a passwordless account. If the
// account already exists, it will return nil, nil. // account already exists, it will return nil, nil.
func (d *Database) CreateAccount( func (d *Database) CreateAccount(
ctx context.Context, localpart, plaintextPassword, appserviceID string, ctx context.Context, localpart, plaintextPassword, appserviceID string,
) (acc *authtypes.Account, err error) {
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID)
return err
})
return
}
func (d *Database) createAccount(
ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string,
) (*authtypes.Account, error) { ) (*authtypes.Account, error) {
var err error var err error
// Generate a password hash if this is not a password-less user // Generate a password hash if this is not a password-less user
hash := "" hash := ""
if plaintextPassword != "" { if plaintextPassword != "" {
@ -134,13 +170,14 @@ func (d *Database) CreateAccount(
return nil, err return nil, err
} }
} }
if err := d.profiles.insertProfile(ctx, localpart); err != nil { if err := d.profiles.insertProfile(ctx, txn, localpart); err != nil {
if common.IsUniqueConstraintViolationErr(err) { if common.IsUniqueConstraintViolationErr(err) {
return nil, nil return nil, nil
} }
return nil, err return nil, err
} }
if err := d.SaveAccountData(ctx, localpart, "", "m.push_rules", `{
if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", `{
"global": { "global": {
"content": [], "content": [],
"override": [], "override": [],
@ -151,7 +188,7 @@ func (d *Database) CreateAccount(
}`); err != nil { }`); err != nil {
return nil, err return nil, err
} }
return d.accounts.insertAccount(ctx, localpart, hash, appserviceID) return d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID)
} }
// SaveMembership saves the user matching a given localpart as a member of a given // SaveMembership saves the user matching a given localpart as a member of a given
@ -216,6 +253,16 @@ func (d *Database) GetMembershipsByLocalpart(
return d.memberships.selectMembershipsByLocalpart(ctx, localpart) return d.memberships.selectMembershipsByLocalpart(ctx, localpart)
} }
// GetRoomIDsByLocalPart returns an array containing the room ids of all
// the rooms a user matching a given localpart is a member of
// If no membership match the given localpart, returns an empty array
// If there was an issue during the retrieval, returns the SQL error
func (d *Database) GetRoomIDsByLocalPart(
ctx context.Context, localpart string,
) ([]string, error) {
return d.memberships.selectRoomIDsByLocalPart(ctx, localpart)
}
// newMembership saves a new membership in the database. // newMembership saves a new membership in the database.
// If the event isn't a valid m.room.member event with type `join`, does nothing. // If the event isn't a valid m.room.member event with type `join`, does nothing.
// If an error occurred, returns the SQL error // If an error occurred, returns the SQL error
@ -258,7 +305,9 @@ func (d *Database) newMembership(
func (d *Database) SaveAccountData( func (d *Database) SaveAccountData(
ctx context.Context, localpart, roomID, dataType, content string, ctx context.Context, localpart, roomID, dataType, content string,
) error { ) error {
return d.accountDatas.insertAccountData(ctx, localpart, roomID, dataType, content) return common.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content)
})
} }
// GetAccountData returns account data related to a given localpart // GetAccountData returns account data related to a given localpart
@ -288,7 +337,7 @@ func (d *Database) GetAccountDataByType(
func (d *Database) GetNewNumericLocalpart( func (d *Database) GetNewNumericLocalpart(
ctx context.Context, ctx context.Context,
) (int64, error) { ) (int64, error) {
return d.accounts.selectNewNumericLocalpart(ctx) return d.accounts.selectNewNumericLocalpart(ctx, nil)
} }
func hashPassword(plaintext string) (hash string, err error) { func hashPassword(plaintext string) (hash string, err error) {

View file

@ -97,7 +97,7 @@ func (s *threepidStatements) selectThreePIDsForLocalpart(
if err != nil { if err != nil {
return return
} }
defer rows.Close() // nolint: errcheck defer common.CloseAndLogIfError(ctx, rows, "selectThreePIDsForLocalpart: rows.close() failed")
threepids = []authtypes.ThreePID{} threepids = []authtypes.ThreePID{}
for rows.Next() { for rows.Next() {

View file

@ -1,41 +1,29 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// +build !wasm
package accounts package accounts
import ( import (
"context"
"errors"
"net/url" "net/url"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/postgres" "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/postgres"
"github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/sqlite3" "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/sqlite3"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
type Database interface {
common.PartitionStorer
GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*authtypes.Account, error)
GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
SetAvatarURL(ctx context.Context, localpart string, avatarURL string) error
SetDisplayName(ctx context.Context, localpart string, displayName string) error
CreateAccount(ctx context.Context, localpart, plaintextPassword, appserviceID string) (*authtypes.Account, error)
UpdateMemberships(ctx context.Context, eventsToAdd []gomatrixserverlib.Event, idsToRemove []string) error
GetMembershipInRoomByLocalpart(ctx context.Context, localpart, roomID string) (authtypes.Membership, error)
GetMembershipsByLocalpart(ctx context.Context, localpart string) (memberships []authtypes.Membership, err error)
SaveAccountData(ctx context.Context, localpart, roomID, dataType, content string) error
GetAccountData(ctx context.Context, localpart string) (global []gomatrixserverlib.ClientEvent, rooms map[string][]gomatrixserverlib.ClientEvent, err error)
GetAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data *gomatrixserverlib.ClientEvent, err error)
GetNewNumericLocalpart(ctx context.Context) (int64, error)
SaveThreePIDAssociation(ctx context.Context, threepid, localpart, medium string) (err error)
RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error)
GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, err error)
GetThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error)
GetFilter(ctx context.Context, localpart string, filterID string) (*gomatrixserverlib.Filter, error)
PutFilter(ctx context.Context, localpart string, filter *gomatrixserverlib.Filter) (string, error)
CheckAccountAvailability(ctx context.Context, localpart string) (bool, error)
GetAccountByLocalpart(ctx context.Context, localpart string) (*authtypes.Account, error)
}
func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (Database, error) { func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (Database, error) {
uri, err := url.Parse(dataSourceName) uri, err := url.Parse(dataSourceName)
if err != nil { if err != nil {
@ -50,7 +38,3 @@ func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName)
return postgres.NewDatabase(dataSourceName, serverName) return postgres.NewDatabase(dataSourceName, serverName)
} }
} }
// Err3PIDInUse is the error returned when trying to save an association involving
// a third-party identifier which is already associated to a local user.
var Err3PIDInUse = errors.New("This third-party identifier is already in use")

View file

@ -0,0 +1,38 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package accounts
import (
"fmt"
"net/url"
"github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/sqlite3"
"github.com/matrix-org/gomatrixserverlib"
)
func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (Database, error) {
uri, err := url.Parse(dataSourceName)
if err != nil {
return nil, fmt.Errorf("Cannot use postgres implementation")
}
switch uri.Scheme {
case "postgres":
return nil, fmt.Errorf("Cannot use postgres implementation")
case "file":
return sqlite3.NewDatabase(dataSourceName, serverName)
default:
return nil, fmt.Errorf("Cannot use postgres implementation")
}
}

View file

@ -0,0 +1,32 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package devices
import (
"context"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
)
type Database interface {
GetDeviceByAccessToken(ctx context.Context, token string) (*authtypes.Device, error)
GetDeviceByID(ctx context.Context, localpart, deviceID string) (*authtypes.Device, error)
GetDevicesByLocalpart(ctx context.Context, localpart string) ([]authtypes.Device, error)
CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string) (dev *authtypes.Device, returnErr error)
UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error
RemoveDevice(ctx context.Context, deviceID, localpart string) error
RemoveDevices(ctx context.Context, localpart string, devices []string) error
RemoveAllDevices(ctx context.Context, localpart string) error
}

View file

@ -226,11 +226,11 @@ func (s *devicesStatements) selectDevicesByLocalpart(
if err != nil { if err != nil {
return devices, err return devices, err
} }
defer rows.Close() // nolint: errcheck defer common.CloseAndLogIfError(ctx, rows, "selectDevicesByLocalpart: rows.close() failed")
for rows.Next() { for rows.Next() {
var dev authtypes.Device var dev authtypes.Device
err = rows.Scan(&dev.ID) err = rows.Scan(&dev.ID, &dev.DisplayName)
if err != nil { if err != nil {
return devices, err return devices, err
} }

View file

@ -231,7 +231,7 @@ func (s *devicesStatements) selectDevicesByLocalpart(
for rows.Next() { for rows.Next() {
var dev authtypes.Device var dev authtypes.Device
err = rows.Scan(&dev.ID) err = rows.Scan(&dev.ID, &dev.DisplayName)
if err != nil { if err != nil {
return devices, err return devices, err
} }

View file

@ -40,7 +40,7 @@ type Database struct {
func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (*Database, error) { func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (*Database, error) {
var db *sql.DB var db *sql.DB
var err error var err error
if db, err = sql.Open("sqlite3", dataSourceName); err != nil { if db, err = sql.Open(common.SQLiteDriverName(), dataSourceName); err != nil {
return nil, err return nil, err
} }
d := devicesStatements{} d := devicesStatements{}

View file

@ -1,26 +1,29 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// +build !wasm
package devices package devices
import ( import (
"context"
"net/url" "net/url"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/auth/storage/devices/postgres" "github.com/matrix-org/dendrite/clientapi/auth/storage/devices/postgres"
"github.com/matrix-org/dendrite/clientapi/auth/storage/devices/sqlite3" "github.com/matrix-org/dendrite/clientapi/auth/storage/devices/sqlite3"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
type Database interface {
GetDeviceByAccessToken(ctx context.Context, token string) (*authtypes.Device, error)
GetDeviceByID(ctx context.Context, localpart, deviceID string) (*authtypes.Device, error)
GetDevicesByLocalpart(ctx context.Context, localpart string) ([]authtypes.Device, error)
CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string) (dev *authtypes.Device, returnErr error)
UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error
RemoveDevice(ctx context.Context, deviceID, localpart string) error
RemoveDevices(ctx context.Context, localpart string, devices []string) error
RemoveAllDevices(ctx context.Context, localpart string) error
}
func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (Database, error) { func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (Database, error) {
uri, err := url.Parse(dataSourceName) uri, err := url.Parse(dataSourceName)
if err != nil { if err != nil {

View file

@ -0,0 +1,38 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package devices
import (
"fmt"
"net/url"
"github.com/matrix-org/dendrite/clientapi/auth/storage/devices/sqlite3"
"github.com/matrix-org/gomatrixserverlib"
)
func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (Database, error) {
uri, err := url.Parse(dataSourceName)
if err != nil {
return nil, fmt.Errorf("Cannot use postgres implementation")
}
switch uri.Scheme {
case "postgres":
return nil, fmt.Errorf("Cannot use postgres implementation")
case "file":
return sqlite3.NewDatabase(dataSourceName, serverName)
default:
return nil, fmt.Errorf("Cannot use postgres implementation")
}
}

View file

@ -46,7 +46,7 @@ func SetupClientAPIComponent(
transactionsCache *transactions.Cache, transactionsCache *transactions.Cache,
fedSenderAPI federationSenderAPI.FederationSenderQueryAPI, fedSenderAPI federationSenderAPI.FederationSenderQueryAPI,
) { ) {
roomserverProducer := producers.NewRoomserverProducer(inputAPI) roomserverProducer := producers.NewRoomserverProducer(inputAPI, queryAPI)
typingProducer := producers.NewTypingServerProducer(typingInputAPI) typingProducer := producers.NewTypingServerProducer(typingInputAPI)
userUpdateProducer := &producers.UserUpdateProducer{ userUpdateProducer := &producers.UserUpdateProducer{

View file

@ -91,7 +91,7 @@ func (s *OutputRoomEventConsumer) onMessage(msg *sarama.ConsumerMessage) error {
"type": ev.Type(), "type": ev.Type(),
}).Info("received event from roomserver") }).Info("received event from roomserver")
events, err := s.lookupStateEvents(output.NewRoomEvent.AddsStateEventIDs, ev) events, err := s.lookupStateEvents(output.NewRoomEvent.AddsStateEventIDs, ev.Event)
if err != nil { if err != nil {
return err return err
} }
@ -138,7 +138,9 @@ func (s *OutputRoomEventConsumer) lookupStateEvents(
return nil, err return nil, err
} }
result = append(result, eventResp.Events...) for _, headeredEvent := range eventResp.Events {
result = append(result, headeredEvent.Event)
}
return result, nil return result, nil
} }

View file

@ -36,11 +36,3 @@ func UnmarshalJSONRequest(req *http.Request, iface interface{}) *util.JSONRespon
} }
return nil return nil
} }
// LogThenError logs the given error then returns a matrix-compliant 500 internal server error response.
// This should be used to log fatal errors which require investigation. It should not be used
// to log client validation errors, etc.
func LogThenError(req *http.Request, err error) util.JSONResponse {
util.GetLogger(req.Context()).WithError(err).Error("request failed")
return jsonerror.InternalServerError()
}

View file

@ -124,6 +124,12 @@ func GuestAccessForbidden(msg string) *MatrixError {
return &MatrixError{"M_GUEST_ACCESS_FORBIDDEN", msg} return &MatrixError{"M_GUEST_ACCESS_FORBIDDEN", msg}
} }
// UnsupportedRoomVersion is an error which is returned when the client
// requests a room with a version that is unsupported.
func UnsupportedRoomVersion(msg string) *MatrixError {
return &MatrixError{"M_UNSUPPORTED_ROOM_VERSION", msg}
}
// LimitExceededError is a rate-limiting error. // LimitExceededError is a rate-limiting error.
type LimitExceededError struct { type LimitExceededError struct {
MatrixError MatrixError

View file

@ -24,18 +24,20 @@ import (
// RoomserverProducer produces events for the roomserver to consume. // RoomserverProducer produces events for the roomserver to consume.
type RoomserverProducer struct { type RoomserverProducer struct {
InputAPI api.RoomserverInputAPI InputAPI api.RoomserverInputAPI
QueryAPI api.RoomserverQueryAPI
} }
// NewRoomserverProducer creates a new RoomserverProducer // NewRoomserverProducer creates a new RoomserverProducer
func NewRoomserverProducer(inputAPI api.RoomserverInputAPI) *RoomserverProducer { func NewRoomserverProducer(inputAPI api.RoomserverInputAPI, queryAPI api.RoomserverQueryAPI) *RoomserverProducer {
return &RoomserverProducer{ return &RoomserverProducer{
InputAPI: inputAPI, InputAPI: inputAPI,
QueryAPI: queryAPI,
} }
} }
// SendEvents writes the given events to the roomserver input log. The events are written with KindNew. // SendEvents writes the given events to the roomserver input log. The events are written with KindNew.
func (c *RoomserverProducer) SendEvents( func (c *RoomserverProducer) SendEvents(
ctx context.Context, events []gomatrixserverlib.Event, sendAsServer gomatrixserverlib.ServerName, ctx context.Context, events []gomatrixserverlib.HeaderedEvent, sendAsServer gomatrixserverlib.ServerName,
txnID *api.TransactionID, txnID *api.TransactionID,
) (string, error) { ) (string, error) {
ires := make([]api.InputRoomEvent, len(events)) ires := make([]api.InputRoomEvent, len(events))
@ -54,20 +56,20 @@ func (c *RoomserverProducer) SendEvents(
// SendEventWithState writes an event with KindNew to the roomserver input log // SendEventWithState writes an event with KindNew to the roomserver input log
// with the state at the event as KindOutlier before it. // with the state at the event as KindOutlier before it.
func (c *RoomserverProducer) SendEventWithState( func (c *RoomserverProducer) SendEventWithState(
ctx context.Context, state gomatrixserverlib.RespState, event gomatrixserverlib.Event, ctx context.Context, state gomatrixserverlib.RespState, event gomatrixserverlib.HeaderedEvent,
) error { ) error {
outliers, err := state.Events() outliers, err := state.Events()
if err != nil { if err != nil {
return err return err
} }
ires := make([]api.InputRoomEvent, len(outliers)+1) var ires []api.InputRoomEvent
for i, outlier := range outliers { for _, outlier := range outliers {
ires[i] = api.InputRoomEvent{ ires = append(ires, api.InputRoomEvent{
Kind: api.KindOutlier, Kind: api.KindOutlier,
Event: outlier, Event: outlier.Headered(event.RoomVersion),
AuthEventIDs: outlier.AuthEventIDs(), AuthEventIDs: outlier.AuthEventIDs(),
} })
} }
stateEventIDs := make([]string, len(state.StateEvents)) stateEventIDs := make([]string, len(state.StateEvents))
@ -75,13 +77,13 @@ func (c *RoomserverProducer) SendEventWithState(
stateEventIDs[i] = state.StateEvents[i].EventID() stateEventIDs[i] = state.StateEvents[i].EventID()
} }
ires[len(outliers)] = api.InputRoomEvent{ ires = append(ires, api.InputRoomEvent{
Kind: api.KindNew, Kind: api.KindNew,
Event: event, Event: event,
AuthEventIDs: event.AuthEventIDs(), AuthEventIDs: event.AuthEventIDs(),
HasState: true, HasState: true,
StateEventIDs: stateEventIDs, StateEventIDs: stateEventIDs,
} })
_, err = c.SendInputRoomEvents(ctx, ires) _, err = c.SendInputRoomEvents(ctx, ires)
return err return err
@ -104,8 +106,17 @@ func (c *RoomserverProducer) SendInputRoomEvents(
func (c *RoomserverProducer) SendInvite( func (c *RoomserverProducer) SendInvite(
ctx context.Context, inviteEvent gomatrixserverlib.Event, ctx context.Context, inviteEvent gomatrixserverlib.Event,
) error { ) error {
verReq := api.QueryRoomVersionForRoomRequest{RoomID: inviteEvent.RoomID()}
verRes := api.QueryRoomVersionForRoomResponse{}
err := c.QueryAPI.QueryRoomVersionForRoom(ctx, &verReq, &verRes)
if err != nil {
return err
}
request := api.InputRoomEventsRequest{ request := api.InputRoomEventsRequest{
InputInviteEvents: []api.InputInviteEvent{{Event: inviteEvent}}, InputInviteEvents: []api.InputInviteEvent{{
Event: inviteEvent.Headered(verRes.RoomVersion),
}},
} }
var response api.InputRoomEventsResponse var response api.InputRoomEventsResponse
return c.InputAPI.InputRoomEvents(ctx, &request, &response) return c.InputAPI.InputRoomEvents(ctx, &request, &response)

View file

@ -21,7 +21,6 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/clientapi/producers" "github.com/matrix-org/dendrite/clientapi/producers"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -43,7 +42,8 @@ func GetAccountData(
localpart, _, err := gomatrixserverlib.SplitID('@', userID) localpart, _, err := gomatrixserverlib.SplitID('@', userID)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError()
} }
if data, err := accountDB.GetAccountDataByType( if data, err := accountDB.GetAccountDataByType(
@ -75,7 +75,8 @@ func SaveAccountData(
localpart, _, err := gomatrixserverlib.SplitID('@', userID) localpart, _, err := gomatrixserverlib.SplitID('@', userID)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError()
} }
defer req.Body.Close() // nolint: errcheck defer req.Body.Close() // nolint: errcheck
@ -89,7 +90,8 @@ func SaveAccountData(
body, err := ioutil.ReadAll(req.Body) body, err := ioutil.ReadAll(req.Body)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("ioutil.ReadAll failed")
return jsonerror.InternalServerError()
} }
if !json.Valid(body) { if !json.Valid(body) {
@ -102,11 +104,13 @@ func SaveAccountData(
if err := accountDB.SaveAccountData( if err := accountDB.SaveAccountData(
req.Context(), localpart, roomID, dataType, string(body), req.Context(), localpart, roomID, dataType, string(body),
); err != nil { ); err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("accountDB.SaveAccountData failed")
return jsonerror.InternalServerError()
} }
if err := syncProducer.SendData(userID, roomID, dataType); err != nil { if err := syncProducer.SendData(userID, roomID, dataType); err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("syncProducer.SendData failed")
return jsonerror.InternalServerError()
} }
return util.JSONResponse{ return util.JSONResponse{

View file

@ -19,7 +19,6 @@ import (
"net/http" "net/http"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/common/config" "github.com/matrix-org/dendrite/common/config"
"github.com/matrix-org/util" "github.com/matrix-org/util"
@ -151,7 +150,8 @@ func AuthFallback(
clientIP := req.RemoteAddr clientIP := req.RemoteAddr
err := req.ParseForm() err := req.ParseForm()
if err != nil { if err != nil {
res := httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("req.ParseForm failed")
res := jsonerror.InternalServerError()
return &res return &res
} }
@ -203,7 +203,8 @@ func writeHTTPMessage(
w.WriteHeader(header) w.WriteHeader(header)
_, err := w.Write([]byte(message)) _, err := w.Write([]byte(message))
if err != nil { if err != nil {
res := httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("w.Write failed")
res := jsonerror.InternalServerError()
return &res return &res
} }
return nil return nil

View file

@ -17,7 +17,7 @@ package routing
import ( import (
"net/http" "net/http"
"github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/util" "github.com/matrix-org/util"
@ -35,7 +35,8 @@ func GetCapabilities(
&roomVersionsQueryReq, &roomVersionsQueryReq,
&roomVersionsQueryRes, &roomVersionsQueryRes,
); err != nil { ); err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("queryAPI.QueryRoomVersionCapabilities failed")
return jsonerror.InternalServerError()
} }
response := map[string]interface{}{ response := map[string]interface{}{

View file

@ -23,6 +23,7 @@ import (
appserviceAPI "github.com/matrix-org/dendrite/appservice/api" appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
roomserverVersion "github.com/matrix-org/dendrite/roomserver/version"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts"
@ -47,6 +48,7 @@ type createRoomRequest struct {
InitialState []fledglingEvent `json:"initial_state"` InitialState []fledglingEvent `json:"initial_state"`
RoomAliasName string `json:"room_alias_name"` RoomAliasName string `json:"room_alias_name"`
GuestCanJoin bool `json:"guest_can_join"` GuestCanJoin bool `json:"guest_can_join"`
RoomVersion gomatrixserverlib.RoomVersion `json:"room_version"`
} }
const ( const (
@ -180,7 +182,19 @@ func createRoom(
} }
r.CreationContent["creator"] = userID r.CreationContent["creator"] = userID
r.CreationContent["room_version"] = "1" // TODO: We set this to 1 before we support Room versioning roomVersion := roomserverVersion.DefaultRoomVersion()
if r.RoomVersion != "" {
candidateVersion := gomatrixserverlib.RoomVersion(r.RoomVersion)
_, roomVersionError := roomserverVersion.SupportedRoomVersion(candidateVersion)
if roomVersionError != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.UnsupportedRoomVersion(roomVersionError.Error()),
}
}
roomVersion = candidateVersion
}
r.CreationContent["room_version"] = roomVersion
// TODO: visibility/presets/raw initial state // TODO: visibility/presets/raw initial state
// TODO: Create room alias association // TODO: Create room alias association
@ -189,11 +203,13 @@ func createRoom(
logger.WithFields(log.Fields{ logger.WithFields(log.Fields{
"userID": userID, "userID": userID,
"roomID": roomID, "roomID": roomID,
"roomVersion": r.CreationContent["room_version"],
}).Info("Creating new room") }).Info("Creating new room")
profile, err := appserviceAPI.RetrieveUserProfile(req.Context(), userID, asAPI, accountDB) profile, err := appserviceAPI.RetrieveUserProfile(req.Context(), userID, asAPI, accountDB)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("appserviceAPI.RetrieveUserProfile failed")
return jsonerror.InternalServerError()
} }
membershipContent := gomatrixserverlib.MemberContent{ membershipContent := gomatrixserverlib.MemberContent{
@ -221,7 +237,7 @@ func createRoom(
historyVisibility = historyVisibilityShared historyVisibility = historyVisibilityShared
} }
var builtEvents []gomatrixserverlib.Event var builtEvents []gomatrixserverlib.HeaderedEvent
// send events into the room in order of: // send events into the room in order of:
// 1- m.room.create // 1- m.room.create
@ -276,33 +292,38 @@ func createRoom(
} }
err = builder.SetContent(e.Content) err = builder.SetContent(e.Content)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("builder.SetContent failed")
return jsonerror.InternalServerError()
} }
if i > 0 { if i > 0 {
builder.PrevEvents = []gomatrixserverlib.EventReference{builtEvents[i-1].EventReference()} builder.PrevEvents = []gomatrixserverlib.EventReference{builtEvents[i-1].EventReference()}
} }
var ev *gomatrixserverlib.Event var ev *gomatrixserverlib.Event
ev, err = buildEvent(&builder, &authEvents, cfg, evTime) ev, err = buildEvent(&builder, &authEvents, cfg, evTime, roomVersion)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("buildEvent failed")
return jsonerror.InternalServerError()
} }
if err = gomatrixserverlib.Allowed(*ev, &authEvents); err != nil { if err = gomatrixserverlib.Allowed(*ev, &authEvents); err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.Allowed failed")
return jsonerror.InternalServerError()
} }
// Add the event to the list of auth events // Add the event to the list of auth events
builtEvents = append(builtEvents, *ev) builtEvents = append(builtEvents, (*ev).Headered(roomVersion))
err = authEvents.AddEvent(ev) err = authEvents.AddEvent(ev)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("authEvents.AddEvent failed")
return jsonerror.InternalServerError()
} }
} }
// send events to the room server // send events to the room server
_, err = producer.SendEvents(req.Context(), builtEvents, cfg.Matrix.ServerName, nil) _, err = producer.SendEvents(req.Context(), builtEvents, cfg.Matrix.ServerName, nil)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("producer.SendEvents failed")
return jsonerror.InternalServerError()
} }
// TODO(#269): Reserve room alias while we create the room. This stops us // TODO(#269): Reserve room alias while we create the room. This stops us
@ -321,7 +342,8 @@ func createRoom(
var aliasResp roomserverAPI.SetRoomAliasResponse var aliasResp roomserverAPI.SetRoomAliasResponse
err = aliasAPI.SetRoomAlias(req.Context(), &aliasReq, &aliasResp) err = aliasAPI.SetRoomAlias(req.Context(), &aliasReq, &aliasResp)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("aliasAPI.SetRoomAlias failed")
return jsonerror.InternalServerError()
} }
if aliasResp.AliasExists { if aliasResp.AliasExists {
@ -346,6 +368,7 @@ func buildEvent(
provider gomatrixserverlib.AuthEventProvider, provider gomatrixserverlib.AuthEventProvider,
cfg *config.Dendrite, cfg *config.Dendrite,
evTime time.Time, evTime time.Time,
roomVersion gomatrixserverlib.RoomVersion,
) (*gomatrixserverlib.Event, error) { ) (*gomatrixserverlib.Event, error) {
eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder) eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder)
if err != nil { if err != nil {
@ -356,10 +379,12 @@ func buildEvent(
return nil, err return nil, err
} }
builder.AuthEvents = refs builder.AuthEvents = refs
eventID := fmt.Sprintf("$%s:%s", util.RandomString(16), cfg.Matrix.ServerName) event, err := builder.Build(
event, err := builder.Build(eventID, evTime, cfg.Matrix.ServerName, cfg.Matrix.KeyID, cfg.Matrix.PrivateKey) evTime, cfg.Matrix.ServerName, cfg.Matrix.KeyID,
cfg.Matrix.PrivateKey, roomVersion,
)
if err != nil { if err != nil {
return nil, fmt.Errorf("cannot build event %s : Builder failed to build. %s", builder.Type, err) return nil, fmt.Errorf("cannot build event %s : Builder failed to build. %w", builder.Type, err)
} }
return &event, nil return &event, nil
} }

View file

@ -21,7 +21,6 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/auth/storage/devices" "github.com/matrix-org/dendrite/clientapi/auth/storage/devices"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
@ -51,7 +50,8 @@ func GetDeviceByID(
) util.JSONResponse { ) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError()
} }
ctx := req.Context() ctx := req.Context()
@ -62,7 +62,8 @@ func GetDeviceByID(
JSON: jsonerror.NotFound("Unknown device"), JSON: jsonerror.NotFound("Unknown device"),
} }
} else if err != nil { } else if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("deviceDB.GetDeviceByID failed")
return jsonerror.InternalServerError()
} }
return util.JSONResponse{ return util.JSONResponse{
@ -80,14 +81,16 @@ func GetDevicesByLocalpart(
) util.JSONResponse { ) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError()
} }
ctx := req.Context() ctx := req.Context()
deviceList, err := deviceDB.GetDevicesByLocalpart(ctx, localpart) deviceList, err := deviceDB.GetDevicesByLocalpart(ctx, localpart)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("deviceDB.GetDevicesByLocalpart failed")
return jsonerror.InternalServerError()
} }
res := devicesJSON{} res := devicesJSON{}
@ -112,7 +115,8 @@ func UpdateDeviceByID(
) util.JSONResponse { ) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError()
} }
ctx := req.Context() ctx := req.Context()
@ -123,7 +127,8 @@ func UpdateDeviceByID(
JSON: jsonerror.NotFound("Unknown device"), JSON: jsonerror.NotFound("Unknown device"),
} }
} else if err != nil { } else if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("deviceDB.GetDeviceByID failed")
return jsonerror.InternalServerError()
} }
if dev.UserID != device.UserID { if dev.UserID != device.UserID {
@ -138,11 +143,13 @@ func UpdateDeviceByID(
payload := deviceUpdateJSON{} payload := deviceUpdateJSON{}
if err := json.NewDecoder(req.Body).Decode(&payload); err != nil { if err := json.NewDecoder(req.Body).Decode(&payload); err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("json.NewDecoder.Decode failed")
return jsonerror.InternalServerError()
} }
if err := deviceDB.UpdateDevice(ctx, localpart, deviceID, payload.DisplayName); err != nil { if err := deviceDB.UpdateDevice(ctx, localpart, deviceID, payload.DisplayName); err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("deviceDB.UpdateDevice failed")
return jsonerror.InternalServerError()
} }
return util.JSONResponse{ return util.JSONResponse{
@ -158,14 +165,16 @@ func DeleteDeviceById(
) util.JSONResponse { ) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError()
} }
ctx := req.Context() ctx := req.Context()
defer req.Body.Close() // nolint: errcheck defer req.Body.Close() // nolint: errcheck
if err := deviceDB.RemoveDevice(ctx, deviceID, localpart); err != nil { if err := deviceDB.RemoveDevice(ctx, deviceID, localpart); err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("deviceDB.RemoveDevice failed")
return jsonerror.InternalServerError()
} }
return util.JSONResponse{ return util.JSONResponse{
@ -180,20 +189,23 @@ func DeleteDevices(
) util.JSONResponse { ) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError()
} }
ctx := req.Context() ctx := req.Context()
payload := devicesDeleteJSON{} payload := devicesDeleteJSON{}
if err := json.NewDecoder(req.Body).Decode(&payload); err != nil { if err := json.NewDecoder(req.Body).Decode(&payload); err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("json.NewDecoder.Decode failed")
return jsonerror.InternalServerError()
} }
defer req.Body.Close() // nolint: errcheck defer req.Body.Close() // nolint: errcheck
if err := deviceDB.RemoveDevices(ctx, localpart, payload.Devices); err != nil { if err := deviceDB.RemoveDevices(ctx, localpart, payload.Devices); err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("deviceDB.RemoveDevices failed")
return jsonerror.InternalServerError()
} }
return util.JSONResponse{ return util.JSONResponse{

View file

@ -63,7 +63,8 @@ func DirectoryRoom(
queryReq := roomserverAPI.GetRoomIDForAliasRequest{Alias: roomAlias} queryReq := roomserverAPI.GetRoomIDForAliasRequest{Alias: roomAlias}
var queryRes roomserverAPI.GetRoomIDForAliasResponse var queryRes roomserverAPI.GetRoomIDForAliasResponse
if err = rsAPI.GetRoomIDForAlias(req.Context(), &queryReq, &queryRes); err != nil { if err = rsAPI.GetRoomIDForAlias(req.Context(), &queryReq, &queryRes); err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("rsAPI.GetRoomIDForAlias failed")
return jsonerror.InternalServerError()
} }
res.RoomID = queryRes.RoomID res.RoomID = queryRes.RoomID
@ -76,7 +77,8 @@ func DirectoryRoom(
if fedErr != nil { if fedErr != nil {
// TODO: Return 502 if the remote server errored. // TODO: Return 502 if the remote server errored.
// TODO: Return 504 if the remote server timed out. // TODO: Return 504 if the remote server timed out.
return httputil.LogThenError(req, fedErr) util.GetLogger(req.Context()).WithError(err).Error("federation.LookupRoomAlias failed")
return jsonerror.InternalServerError()
} }
res.RoomID = fedRes.RoomID res.RoomID = fedRes.RoomID
res.fillServers(fedRes.Servers) res.fillServers(fedRes.Servers)
@ -94,7 +96,8 @@ func DirectoryRoom(
joinedHostsReq := federationSenderAPI.QueryJoinedHostServerNamesInRoomRequest{RoomID: res.RoomID} joinedHostsReq := federationSenderAPI.QueryJoinedHostServerNamesInRoomRequest{RoomID: res.RoomID}
var joinedHostsRes federationSenderAPI.QueryJoinedHostServerNamesInRoomResponse var joinedHostsRes federationSenderAPI.QueryJoinedHostServerNamesInRoomResponse
if err = fedSenderAPI.QueryJoinedHostServerNamesInRoom(req.Context(), &joinedHostsReq, &joinedHostsRes); err != nil { if err = fedSenderAPI.QueryJoinedHostServerNamesInRoom(req.Context(), &joinedHostsReq, &joinedHostsRes); err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("fedSenderAPI.QueryJoinedHostServerNamesInRoom failed")
return jsonerror.InternalServerError()
} }
res.fillServers(joinedHostsRes.ServerNames) res.fillServers(joinedHostsRes.ServerNames)
} }
@ -165,7 +168,8 @@ func SetLocalAlias(
} }
var queryRes roomserverAPI.SetRoomAliasResponse var queryRes roomserverAPI.SetRoomAliasResponse
if err := aliasAPI.SetRoomAlias(req.Context(), &queryReq, &queryRes); err != nil { if err := aliasAPI.SetRoomAlias(req.Context(), &queryReq, &queryRes); err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("aliasAPI.SetRoomAlias failed")
return jsonerror.InternalServerError()
} }
if queryRes.AliasExists { if queryRes.AliasExists {
@ -194,7 +198,8 @@ func RemoveLocalAlias(
} }
var creatorQueryRes roomserverAPI.GetCreatorIDForAliasResponse var creatorQueryRes roomserverAPI.GetCreatorIDForAliasResponse
if err := aliasAPI.GetCreatorIDForAlias(req.Context(), &creatorQueryReq, &creatorQueryRes); err != nil { if err := aliasAPI.GetCreatorIDForAlias(req.Context(), &creatorQueryReq, &creatorQueryRes); err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("aliasAPI.GetCreatorIDForAlias failed")
return jsonerror.InternalServerError()
} }
if creatorQueryRes.UserID == "" { if creatorQueryRes.UserID == "" {
@ -218,7 +223,8 @@ func RemoveLocalAlias(
} }
var queryRes roomserverAPI.RemoveRoomAliasResponse var queryRes roomserverAPI.RemoveRoomAliasResponse
if err := aliasAPI.RemoveRoomAlias(req.Context(), &queryReq, &queryRes); err != nil { if err := aliasAPI.RemoveRoomAlias(req.Context(), &queryReq, &queryRes); err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("aliasAPI.RemoveRoomAlias failed")
return jsonerror.InternalServerError()
} }
return util.JSONResponse{ return util.JSONResponse{

View file

@ -37,7 +37,8 @@ func GetFilter(
} }
localpart, _, err := gomatrixserverlib.SplitID('@', userID) localpart, _, err := gomatrixserverlib.SplitID('@', userID)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError()
} }
filter, err := accountDB.GetFilter(req.Context(), localpart, filterID) filter, err := accountDB.GetFilter(req.Context(), localpart, filterID)
@ -74,7 +75,8 @@ func PutFilter(
localpart, _, err := gomatrixserverlib.SplitID('@', userID) localpart, _, err := gomatrixserverlib.SplitID('@', userID)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError()
} }
var filter gomatrixserverlib.Filter var filter gomatrixserverlib.Filter
@ -93,7 +95,8 @@ func PutFilter(
filterID, err := accountDB.PutFilter(req.Context(), localpart, &filter) filterID, err := accountDB.PutFilter(req.Context(), localpart, &filter)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("accountDB.PutFilter failed")
return jsonerror.InternalServerError()
} }
return util.JSONResponse{ return util.JSONResponse{

View file

@ -18,7 +18,6 @@ import (
"net/http" "net/http"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/common/config" "github.com/matrix-org/dendrite/common/config"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
@ -55,7 +54,8 @@ func GetEvent(
var eventsResp api.QueryEventsByIDResponse var eventsResp api.QueryEventsByIDResponse
err := queryAPI.QueryEventsByID(req.Context(), &eventsReq, &eventsResp) err := queryAPI.QueryEventsByID(req.Context(), &eventsReq, &eventsResp)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("queryAPI.QueryEventsByID failed")
return jsonerror.InternalServerError()
} }
if len(eventsResp.Events) == 0 { if len(eventsResp.Events) == 0 {
@ -66,7 +66,7 @@ func GetEvent(
} }
} }
requestedEvent := eventsResp.Events[0] requestedEvent := eventsResp.Events[0].Event
r := getEventRequest{ r := getEventRequest{
req: req, req: req,
@ -89,7 +89,8 @@ func GetEvent(
} }
var stateResp api.QueryStateAfterEventsResponse var stateResp api.QueryStateAfterEventsResponse
if err := queryAPI.QueryStateAfterEvents(req.Context(), &stateReq, &stateResp); err != nil { if err := queryAPI.QueryStateAfterEvents(req.Context(), &stateReq, &stateResp); err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("queryAPI.QueryStateAfterEvents failed")
return jsonerror.InternalServerError()
} }
if !stateResp.RoomExists { if !stateResp.RoomExists {
@ -109,7 +110,8 @@ func GetEvent(
if stateEvent.StateKeyEquals(r.device.UserID) { if stateEvent.StateKeyEquals(r.device.UserID) {
membership, err := stateEvent.Membership() membership, err := stateEvent.Membership()
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("stateEvent.Membership failed")
return jsonerror.InternalServerError()
} }
if membership == gomatrixserverlib.Join { if membership == gomatrixserverlib.Join {
return util.JSONResponse{ return util.JSONResponse{

View file

@ -27,6 +27,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/producers" "github.com/matrix-org/dendrite/clientapi/producers"
"github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/common/config" "github.com/matrix-org/dendrite/common/config"
"github.com/matrix-org/dendrite/roomserver/api"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -62,12 +63,14 @@ func JoinRoomByIDOrAlias(
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError()
} }
profile, err := accountDB.GetProfileByLocalpart(req.Context(), localpart) profile, err := accountDB.GetProfileByLocalpart(req.Context(), localpart)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("accountDB.GetProfileByLocalpart failed")
return jsonerror.InternalServerError()
} }
content["membership"] = gomatrixserverlib.Join content["membership"] = gomatrixserverlib.Join
@ -119,7 +122,8 @@ func (r joinRoomReq) joinRoomByID(roomID string) util.JSONResponse {
} }
var queryRes roomserverAPI.QueryInvitesForUserResponse var queryRes roomserverAPI.QueryInvitesForUserResponse
if err := r.queryAPI.QueryInvitesForUser(r.req.Context(), &queryReq, &queryRes); err != nil { if err := r.queryAPI.QueryInvitesForUser(r.req.Context(), &queryReq, &queryRes); err != nil {
return httputil.LogThenError(r.req, err) util.GetLogger(r.req.Context()).WithError(err).Error("r.queryAPI.QueryInvitesForUser failed")
return jsonerror.InternalServerError()
} }
servers := []gomatrixserverlib.ServerName{} servers := []gomatrixserverlib.ServerName{}
@ -127,7 +131,8 @@ func (r joinRoomReq) joinRoomByID(roomID string) util.JSONResponse {
for _, userID := range queryRes.InviteSenderUserIDs { for _, userID := range queryRes.InviteSenderUserIDs {
_, domain, err := gomatrixserverlib.SplitID('@', userID) _, domain, err := gomatrixserverlib.SplitID('@', userID)
if err != nil { if err != nil {
return httputil.LogThenError(r.req, err) util.GetLogger(r.req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError()
} }
if !seenInInviterIDs[domain] { if !seenInInviterIDs[domain] {
servers = append(servers, domain) servers = append(servers, domain)
@ -141,7 +146,8 @@ func (r joinRoomReq) joinRoomByID(roomID string) util.JSONResponse {
// Note: It's no guarantee we'll succeed because a room isn't bound to the domain in its ID // Note: It's no guarantee we'll succeed because a room isn't bound to the domain in its ID
_, domain, err := gomatrixserverlib.SplitID('!', roomID) _, domain, err := gomatrixserverlib.SplitID('!', roomID)
if err != nil { if err != nil {
return httputil.LogThenError(r.req, err) util.GetLogger(r.req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError()
} }
if domain != r.cfg.Matrix.ServerName && !seenInInviterIDs[domain] { if domain != r.cfg.Matrix.ServerName && !seenInInviterIDs[domain] {
servers = append(servers, domain) servers = append(servers, domain)
@ -164,7 +170,8 @@ func (r joinRoomReq) joinRoomByAlias(roomAlias string) util.JSONResponse {
queryReq := roomserverAPI.GetRoomIDForAliasRequest{Alias: roomAlias} queryReq := roomserverAPI.GetRoomIDForAliasRequest{Alias: roomAlias}
var queryRes roomserverAPI.GetRoomIDForAliasResponse var queryRes roomserverAPI.GetRoomIDForAliasResponse
if err = r.aliasAPI.GetRoomIDForAlias(r.req.Context(), &queryReq, &queryRes); err != nil { if err = r.aliasAPI.GetRoomIDForAlias(r.req.Context(), &queryReq, &queryRes); err != nil {
return httputil.LogThenError(r.req, err) util.GetLogger(r.req.Context()).WithError(err).Error("r.aliasAPI.GetRoomIDForAlias failed")
return jsonerror.InternalServerError()
} }
if len(queryRes.RoomID) > 0 { if len(queryRes.RoomID) > 0 {
@ -194,7 +201,8 @@ func (r joinRoomReq) joinRoomByRemoteAlias(
} }
} }
} }
return httputil.LogThenError(r.req, err) util.GetLogger(r.req.Context()).WithError(err).Error("r.federation.LookupRoomAlias failed")
return jsonerror.InternalServerError()
} }
return r.joinRoomUsingServers(resp.RoomID, resp.Servers) return r.joinRoomUsingServers(resp.RoomID, resp.Servers)
@ -227,14 +235,23 @@ func (r joinRoomReq) joinRoomUsingServers(
var eb gomatrixserverlib.EventBuilder var eb gomatrixserverlib.EventBuilder
err := r.writeToBuilder(&eb, roomID) err := r.writeToBuilder(&eb, roomID)
if err != nil { if err != nil {
return httputil.LogThenError(r.req, err) util.GetLogger(r.req.Context()).WithError(err).Error("r.writeToBuilder failed")
return jsonerror.InternalServerError()
} }
var queryRes roomserverAPI.QueryLatestEventsAndStateResponse queryRes := roomserverAPI.QueryLatestEventsAndStateResponse{}
event, err := common.BuildEvent(r.req.Context(), &eb, r.cfg, r.evTime, r.queryAPI, &queryRes) event, err := common.BuildEvent(r.req.Context(), &eb, r.cfg, r.evTime, r.queryAPI, &queryRes)
if err == nil { if err == nil {
if _, err = r.producer.SendEvents(r.req.Context(), []gomatrixserverlib.Event{*event}, r.cfg.Matrix.ServerName, nil); err != nil { if _, err = r.producer.SendEvents(
return httputil.LogThenError(r.req, err) r.req.Context(),
[]gomatrixserverlib.HeaderedEvent{
(*event).Headered(queryRes.RoomVersion),
},
r.cfg.Matrix.ServerName,
nil,
); err != nil {
util.GetLogger(r.req.Context()).WithError(err).Error("r.producer.SendEvents failed")
return jsonerror.InternalServerError()
} }
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
@ -244,7 +261,8 @@ func (r joinRoomReq) joinRoomUsingServers(
} }
} }
if err != common.ErrRoomNoExists { if err != common.ErrRoomNoExists {
return httputil.LogThenError(r.req, err) util.GetLogger(r.req.Context()).WithError(err).Error("common.BuildEvent failed")
return jsonerror.InternalServerError()
} }
if len(servers) == 0 { if len(servers) == 0 {
@ -280,7 +298,8 @@ func (r joinRoomReq) joinRoomUsingServers(
// 4) We couldn't fetch the public keys needed to verify the // 4) We couldn't fetch the public keys needed to verify the
// signatures on the state events. // signatures on the state events.
// 5) ... // 5) ...
return httputil.LogThenError(r.req, lastErr) util.GetLogger(r.req.Context()).WithError(lastErr).Error("failed to join through any server")
return jsonerror.InternalServerError()
} }
// joinRoomUsingServer tries to join a remote room using a given matrix server. // joinRoomUsingServer tries to join a remote room using a given matrix server.
@ -288,7 +307,17 @@ func (r joinRoomReq) joinRoomUsingServers(
// server was invalid this returns an error. // server was invalid this returns an error.
// Otherwise this returns a JSONResponse. // Otherwise this returns a JSONResponse.
func (r joinRoomReq) joinRoomUsingServer(roomID string, server gomatrixserverlib.ServerName) (*util.JSONResponse, error) { func (r joinRoomReq) joinRoomUsingServer(roomID string, server gomatrixserverlib.ServerName) (*util.JSONResponse, error) {
respMakeJoin, err := r.federation.MakeJoin(r.req.Context(), server, roomID, r.userID) // Ask the room server for information about room versions.
var request api.QueryRoomVersionCapabilitiesRequest
var response api.QueryRoomVersionCapabilitiesResponse
if err := r.queryAPI.QueryRoomVersionCapabilities(r.req.Context(), &request, &response); err != nil {
return nil, err
}
var supportedVersions []gomatrixserverlib.RoomVersion
for version := range response.AvailableRoomVersions {
supportedVersions = append(supportedVersions, version)
}
respMakeJoin, err := r.federation.MakeJoin(r.req.Context(), server, roomID, r.userID, supportedVersions)
if err != nil { if err != nil {
// TODO: Check if the user was not allowed to join the room. // TODO: Check if the user was not allowed to join the room.
return nil, err return nil, err
@ -301,16 +330,29 @@ func (r joinRoomReq) joinRoomUsingServer(roomID string, server gomatrixserverlib
return nil, err return nil, err
} }
eventID := fmt.Sprintf("$%s:%s", util.RandomString(16), r.cfg.Matrix.ServerName) if respMakeJoin.RoomVersion == "" {
respMakeJoin.RoomVersion = gomatrixserverlib.RoomVersionV1
}
if _, err = respMakeJoin.RoomVersion.EventFormat(); err != nil {
return &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.UnsupportedRoomVersion(
fmt.Sprintf("Room version '%s' is not supported", respMakeJoin.RoomVersion),
),
}, nil
}
event, err := respMakeJoin.JoinEvent.Build( event, err := respMakeJoin.JoinEvent.Build(
eventID, r.evTime, r.cfg.Matrix.ServerName, r.cfg.Matrix.KeyID, r.cfg.Matrix.PrivateKey, r.evTime, r.cfg.Matrix.ServerName, r.cfg.Matrix.KeyID,
r.cfg.Matrix.PrivateKey, respMakeJoin.RoomVersion,
) )
if err != nil { if err != nil {
res := httputil.LogThenError(r.req, err) util.GetLogger(r.req.Context()).WithError(err).Error("respMakeJoin.JoinEvent.Build failed")
res := jsonerror.InternalServerError()
return &res, nil return &res, nil
} }
respSendJoin, err := r.federation.SendJoin(r.req.Context(), server, event) respSendJoin, err := r.federation.SendJoin(r.req.Context(), server, event, respMakeJoin.RoomVersion)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -320,9 +362,12 @@ func (r joinRoomReq) joinRoomUsingServer(roomID string, server gomatrixserverlib
} }
if err = r.producer.SendEventWithState( if err = r.producer.SendEventWithState(
r.req.Context(), gomatrixserverlib.RespState(respSendJoin.RespState), event, r.req.Context(),
gomatrixserverlib.RespState(respSendJoin.RespState),
event.Headered(respMakeJoin.RoomVersion),
); err != nil { ); err != nil {
res := httputil.LogThenError(r.req, err) util.GetLogger(r.req.Context()).WithError(err).Error("gomatrixserverlib.RespState failed")
res := jsonerror.InternalServerError()
return &res, nil return &res, nil
} }

View file

@ -122,7 +122,8 @@ func Login(
token, err := auth.GenerateAccessToken() token, err := auth.GenerateAccessToken()
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("auth.GenerateAccessToken failed")
return jsonerror.InternalServerError()
} }
dev, err := getDevice(req.Context(), r, deviceDB, acc, token) dev, err := getDevice(req.Context(), r, deviceDB, acc, token)

View file

@ -19,7 +19,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/auth/storage/devices" "github.com/matrix-org/dendrite/clientapi/auth/storage/devices"
"github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -30,11 +30,13 @@ func Logout(
) util.JSONResponse { ) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError()
} }
if err := deviceDB.RemoveDevice(req.Context(), device.ID, localpart); err != nil { if err := deviceDB.RemoveDevice(req.Context(), device.ID, localpart); err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("deviceDB.RemoveDevice failed")
return jsonerror.InternalServerError()
} }
return util.JSONResponse{ return util.JSONResponse{
@ -49,11 +51,13 @@ func LogoutAll(
) util.JSONResponse { ) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError()
} }
if err := deviceDB.RemoveAllDevices(req.Context(), localpart); err != nil { if err := deviceDB.RemoveAllDevices(req.Context(), localpart); err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("deviceDB.RemoveAllDevices failed")
return jsonerror.InternalServerError()
} }
return util.JSONResponse{ return util.JSONResponse{

View file

@ -29,6 +29,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/threepid" "github.com/matrix-org/dendrite/clientapi/threepid"
"github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/common/config" "github.com/matrix-org/dendrite/common/config"
"github.com/matrix-org/dendrite/roomserver/api"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -45,6 +46,15 @@ func SendMembership(
queryAPI roomserverAPI.RoomserverQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI, queryAPI roomserverAPI.RoomserverQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI,
producer *producers.RoomserverProducer, producer *producers.RoomserverProducer,
) util.JSONResponse { ) util.JSONResponse {
verReq := api.QueryRoomVersionForRoomRequest{RoomID: roomID}
verRes := api.QueryRoomVersionForRoomResponse{}
if err := queryAPI.QueryRoomVersionForRoom(req.Context(), &verReq, &verRes); err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.UnsupportedRoomVersion(err.Error()),
}
}
var body threepid.MembershipRequest var body threepid.MembershipRequest
if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil { if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil {
return *reqErr return *reqErr
@ -90,13 +100,18 @@ func SendMembership(
JSON: jsonerror.NotFound(err.Error()), JSON: jsonerror.NotFound(err.Error()),
} }
} else if err != nil { } else if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("buildMembershipEvent failed")
return jsonerror.InternalServerError()
} }
if _, err := producer.SendEvents( if _, err := producer.SendEvents(
req.Context(), []gomatrixserverlib.Event{*event}, cfg.Matrix.ServerName, nil, req.Context(),
[]gomatrixserverlib.HeaderedEvent{(*event).Headered(verRes.RoomVersion)},
cfg.Matrix.ServerName,
nil,
); err != nil { ); err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("producer.SendEvents failed")
return jsonerror.InternalServerError()
} }
var returnData interface{} = struct{}{} var returnData interface{} = struct{}{}
@ -242,7 +257,8 @@ func checkAndProcessThreepid(
JSON: jsonerror.NotFound(err.Error()), JSON: jsonerror.NotFound(err.Error()),
} }
} else if err != nil { } else if err != nil {
er := httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("threepid.CheckAndProcessInvite failed")
er := jsonerror.InternalServerError()
return inviteStored, &er return inviteStored, &er
} }
return return

View file

@ -17,8 +17,9 @@ package routing
import ( import (
"net/http" "net/http"
"github.com/matrix-org/dendrite/clientapi/auth/storage/accounts"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/common/config" "github.com/matrix-org/dendrite/common/config"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
@ -26,10 +27,14 @@ import (
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
type response struct { type getMembershipResponse struct {
Chunk []gomatrixserverlib.ClientEvent `json:"chunk"` Chunk []gomatrixserverlib.ClientEvent `json:"chunk"`
} }
type getJoinedRoomsResponse struct {
JoinedRooms []string `json:"joined_rooms"`
}
// GetMemberships implements GET /rooms/{roomId}/members // GetMemberships implements GET /rooms/{roomId}/members
func GetMemberships( func GetMemberships(
req *http.Request, device *authtypes.Device, roomID string, joinedOnly bool, req *http.Request, device *authtypes.Device, roomID string, joinedOnly bool,
@ -43,7 +48,8 @@ func GetMemberships(
} }
var queryRes api.QueryMembershipsForRoomResponse var queryRes api.QueryMembershipsForRoomResponse
if err := queryAPI.QueryMembershipsForRoom(req.Context(), &queryReq, &queryRes); err != nil { if err := queryAPI.QueryMembershipsForRoom(req.Context(), &queryReq, &queryRes); err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("queryAPI.QueryMembershipsForRoom failed")
return jsonerror.InternalServerError()
} }
if !queryRes.HasBeenInRoom { if !queryRes.HasBeenInRoom {
@ -55,6 +61,27 @@ func GetMemberships(
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
JSON: response{queryRes.JoinEvents}, JSON: getMembershipResponse{queryRes.JoinEvents},
}
}
func GetJoinedRooms(
req *http.Request,
device *authtypes.Device,
accountsDB accounts.Database,
) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError()
}
joinedRooms, err := accountsDB.GetRoomIDsByLocalPart(req.Context(), localpart)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("accountsDB.GetRoomIDsByLocalPart failed")
return jsonerror.InternalServerError()
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: getJoinedRoomsResponse{joinedRooms},
} }
} }

View file

@ -50,7 +50,8 @@ func GetProfile(
} }
} }
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("getProfile failed")
return jsonerror.InternalServerError()
} }
return util.JSONResponse{ return util.JSONResponse{
@ -77,7 +78,8 @@ func GetAvatarURL(
} }
} }
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("getProfile failed")
return jsonerror.InternalServerError()
} }
return util.JSONResponse{ return util.JSONResponse{
@ -116,7 +118,8 @@ func SetAvatarURL(
localpart, _, err := gomatrixserverlib.SplitID('@', userID) localpart, _, err := gomatrixserverlib.SplitID('@', userID)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError()
} }
evTime, err := httputil.ParseTSParam(req) evTime, err := httputil.ParseTSParam(req)
@ -129,16 +132,19 @@ func SetAvatarURL(
oldProfile, err := accountDB.GetProfileByLocalpart(req.Context(), localpart) oldProfile, err := accountDB.GetProfileByLocalpart(req.Context(), localpart)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("accountDB.GetProfileByLocalpart failed")
return jsonerror.InternalServerError()
} }
if err = accountDB.SetAvatarURL(req.Context(), localpart, r.AvatarURL); err != nil { if err = accountDB.SetAvatarURL(req.Context(), localpart, r.AvatarURL); err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("accountDB.SetAvatarURL failed")
return jsonerror.InternalServerError()
} }
memberships, err := accountDB.GetMembershipsByLocalpart(req.Context(), localpart) memberships, err := accountDB.GetMembershipsByLocalpart(req.Context(), localpart)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("accountDB.GetMembershipsByLocalpart failed")
return jsonerror.InternalServerError()
} }
newProfile := authtypes.Profile{ newProfile := authtypes.Profile{
@ -151,15 +157,18 @@ func SetAvatarURL(
req.Context(), memberships, newProfile, userID, cfg, evTime, queryAPI, req.Context(), memberships, newProfile, userID, cfg, evTime, queryAPI,
) )
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("buildMembershipEvents failed")
return jsonerror.InternalServerError()
} }
if _, err := rsProducer.SendEvents(req.Context(), events, cfg.Matrix.ServerName, nil); err != nil { if _, err := rsProducer.SendEvents(req.Context(), events, cfg.Matrix.ServerName, nil); err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("rsProducer.SendEvents failed")
return jsonerror.InternalServerError()
} }
if err := producer.SendUpdate(userID, changedKey, oldProfile.AvatarURL, r.AvatarURL); err != nil { if err := producer.SendUpdate(userID, changedKey, oldProfile.AvatarURL, r.AvatarURL); err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("producer.SendUpdate failed")
return jsonerror.InternalServerError()
} }
return util.JSONResponse{ return util.JSONResponse{
@ -183,7 +192,8 @@ func GetDisplayName(
} }
} }
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("getProfile failed")
return jsonerror.InternalServerError()
} }
return util.JSONResponse{ return util.JSONResponse{
@ -222,7 +232,8 @@ func SetDisplayName(
localpart, _, err := gomatrixserverlib.SplitID('@', userID) localpart, _, err := gomatrixserverlib.SplitID('@', userID)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError()
} }
evTime, err := httputil.ParseTSParam(req) evTime, err := httputil.ParseTSParam(req)
@ -235,16 +246,19 @@ func SetDisplayName(
oldProfile, err := accountDB.GetProfileByLocalpart(req.Context(), localpart) oldProfile, err := accountDB.GetProfileByLocalpart(req.Context(), localpart)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("accountDB.GetProfileByLocalpart failed")
return jsonerror.InternalServerError()
} }
if err = accountDB.SetDisplayName(req.Context(), localpart, r.DisplayName); err != nil { if err = accountDB.SetDisplayName(req.Context(), localpart, r.DisplayName); err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("accountDB.SetDisplayName failed")
return jsonerror.InternalServerError()
} }
memberships, err := accountDB.GetMembershipsByLocalpart(req.Context(), localpart) memberships, err := accountDB.GetMembershipsByLocalpart(req.Context(), localpart)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("accountDB.GetMembershipsByLocalpart failed")
return jsonerror.InternalServerError()
} }
newProfile := authtypes.Profile{ newProfile := authtypes.Profile{
@ -257,15 +271,18 @@ func SetDisplayName(
req.Context(), memberships, newProfile, userID, cfg, evTime, queryAPI, req.Context(), memberships, newProfile, userID, cfg, evTime, queryAPI,
) )
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("buildMembershipEvents failed")
return jsonerror.InternalServerError()
} }
if _, err := rsProducer.SendEvents(req.Context(), events, cfg.Matrix.ServerName, nil); err != nil { if _, err := rsProducer.SendEvents(req.Context(), events, cfg.Matrix.ServerName, nil); err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("rsProducer.SendEvents failed")
return jsonerror.InternalServerError()
} }
if err := producer.SendUpdate(userID, changedKey, oldProfile.DisplayName, r.DisplayName); err != nil { if err := producer.SendUpdate(userID, changedKey, oldProfile.DisplayName, r.DisplayName); err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("producer.SendUpdate failed")
return jsonerror.InternalServerError()
} }
return util.JSONResponse{ return util.JSONResponse{
@ -321,10 +338,16 @@ func buildMembershipEvents(
memberships []authtypes.Membership, memberships []authtypes.Membership,
newProfile authtypes.Profile, userID string, cfg *config.Dendrite, newProfile authtypes.Profile, userID string, cfg *config.Dendrite,
evTime time.Time, queryAPI api.RoomserverQueryAPI, evTime time.Time, queryAPI api.RoomserverQueryAPI,
) ([]gomatrixserverlib.Event, error) { ) ([]gomatrixserverlib.HeaderedEvent, error) {
evs := []gomatrixserverlib.Event{} evs := []gomatrixserverlib.HeaderedEvent{}
for _, membership := range memberships { for _, membership := range memberships {
verReq := api.QueryRoomVersionForRoomRequest{RoomID: membership.RoomID}
verRes := api.QueryRoomVersionForRoomResponse{}
if err := queryAPI.QueryRoomVersionForRoom(ctx, &verReq, &verRes); err != nil {
return []gomatrixserverlib.HeaderedEvent{}, err
}
builder := gomatrixserverlib.EventBuilder{ builder := gomatrixserverlib.EventBuilder{
Sender: userID, Sender: userID,
RoomID: membership.RoomID, RoomID: membership.RoomID,
@ -348,7 +371,7 @@ func buildMembershipEvents(
return nil, err return nil, err
} }
evs = append(evs, *event) evs = append(evs, (*event).Headered(verRes.RoomVersion))
} }
return evs, nil return evs, nil

View file

@ -444,7 +444,6 @@ func Register(
deviceDB devices.Database, deviceDB devices.Database,
cfg *config.Dendrite, cfg *config.Dendrite,
) util.JSONResponse { ) util.JSONResponse {
var r registerRequest var r registerRequest
resErr := httputil.UnmarshalJSONRequest(req, &r) resErr := httputil.UnmarshalJSONRequest(req, &r)
if resErr != nil { if resErr != nil {
@ -472,7 +471,8 @@ func Register(
if r.Username == "" { if r.Username == "" {
id, err := accountDB.GetNewNumericLocalpart(req.Context()) id, err := accountDB.GetNewNumericLocalpart(req.Context())
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("accountDB.GetNewNumericLocalpart failed")
return jsonerror.InternalServerError()
} }
r.Username = strconv.FormatInt(id, 10) r.Username = strconv.FormatInt(id, 10)
@ -516,15 +516,7 @@ func handleGuestRegistration(
accountDB accounts.Database, accountDB accounts.Database,
deviceDB devices.Database, deviceDB devices.Database,
) util.JSONResponse { ) util.JSONResponse {
acc, err := accountDB.CreateGuestAccount(req.Context())
//Generate numeric local part for guest user
id, err := accountDB.GetNewNumericLocalpart(req.Context())
if err != nil {
return httputil.LogThenError(req, err)
}
localpart := strconv.FormatInt(id, 10)
acc, err := accountDB.CreateAccount(req.Context(), localpart, "", "")
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,
@ -602,7 +594,8 @@ func handleRegistrationFlow(
valid, err := isValidMacLogin(cfg, r.Username, r.Password, r.Admin, r.Auth.Mac) valid, err := isValidMacLogin(cfg, r.Username, r.Password, r.Admin, r.Auth.Mac)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("isValidMacLogin failed")
return jsonerror.InternalServerError()
} else if !valid { } else if !valid {
return util.MessageResponse(http.StatusForbidden, "HMAC incorrect") return util.MessageResponse(http.StatusForbidden, "HMAC incorrect")
} }
@ -758,7 +751,8 @@ func LegacyRegister(
valid, err := isValidMacLogin(cfg, r.Username, r.Password, r.Admin, r.Mac) valid, err := isValidMacLogin(cfg, r.Username, r.Password, r.Admin, r.Mac)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("isValidMacLogin failed")
return jsonerror.InternalServerError()
} }
if !valid { if !valid {

View file

@ -56,7 +56,8 @@ func GetTags(
_, data, err := obtainSavedTags(req, userID, roomID, accountDB) _, data, err := obtainSavedTags(req, userID, roomID, accountDB)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("obtainSavedTags failed")
return jsonerror.InternalServerError()
} }
if data == nil { if data == nil {
@ -99,20 +100,23 @@ func PutTag(
localpart, data, err := obtainSavedTags(req, userID, roomID, accountDB) localpart, data, err := obtainSavedTags(req, userID, roomID, accountDB)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("obtainSavedTags failed")
return jsonerror.InternalServerError()
} }
var tagContent gomatrix.TagContent var tagContent gomatrix.TagContent
if data != nil { if data != nil {
if err = json.Unmarshal(data.Content, &tagContent); err != nil { if err = json.Unmarshal(data.Content, &tagContent); err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("json.Unmarshal failed")
return jsonerror.InternalServerError()
} }
} else { } else {
tagContent = newTag() tagContent = newTag()
} }
tagContent.Tags[tag] = properties tagContent.Tags[tag] = properties
if err = saveTagData(req, localpart, roomID, accountDB, tagContent); err != nil { if err = saveTagData(req, localpart, roomID, accountDB, tagContent); err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("saveTagData failed")
return jsonerror.InternalServerError()
} }
// Send data to syncProducer in order to inform clients of changes // Send data to syncProducer in order to inform clients of changes
@ -151,7 +155,8 @@ func DeleteTag(
localpart, data, err := obtainSavedTags(req, userID, roomID, accountDB) localpart, data, err := obtainSavedTags(req, userID, roomID, accountDB)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("obtainSavedTags failed")
return jsonerror.InternalServerError()
} }
// If there are no tags in the database, exit // If there are no tags in the database, exit
@ -166,7 +171,8 @@ func DeleteTag(
var tagContent gomatrix.TagContent var tagContent gomatrix.TagContent
err = json.Unmarshal(data.Content, &tagContent) err = json.Unmarshal(data.Content, &tagContent)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("json.Unmarshal failed")
return jsonerror.InternalServerError()
} }
// Check whether the tag to be deleted exists // Check whether the tag to be deleted exists
@ -180,7 +186,8 @@ func DeleteTag(
} }
} }
if err = saveTagData(req, localpart, roomID, accountDB, tagContent); err != nil { if err = saveTagData(req, localpart, roomID, accountDB, tagContent); err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("saveTagData failed")
return jsonerror.InternalServerError()
} }
// Send data to syncProducer in order to inform clients of changes // Send data to syncProducer in order to inform clients of changes

View file

@ -105,6 +105,12 @@ func Setup(
) )
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/joined_rooms",
common.MakeAuthAPI("joined_rooms", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
return GetJoinedRooms(req, device, accountDB)
}),
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/{membership:(?:join|kick|ban|unban|leave|invite)}", r0mux.Handle("/rooms/{roomID}/{membership:(?:join|kick|ban|unban|leave|invite)}",
common.MakeAuthAPI("membership", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { common.MakeAuthAPI("membership", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars, err := common.URLDecodeMapValues(mux.Vars(req)) vars, err := common.URLDecodeMapValues(mux.Vars(req))
@ -390,7 +396,7 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
unstableMux.Handle("/thirdparty/protocols", r0mux.Handle("/thirdparty/protocols",
common.MakeExternalAPI("thirdparty_protocols", func(req *http.Request) util.JSONResponse { common.MakeExternalAPI("thirdparty_protocols", func(req *http.Request) util.JSONResponse {
// TODO: Return the third party protcols // TODO: Return the third party protcols
return util.JSONResponse{ return util.JSONResponse{

View file

@ -27,6 +27,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/sirupsen/logrus"
) )
// http://matrix.org/docs/spec/client_server/r0.2.0.html#put-matrix-client-r0-rooms-roomid-send-eventtype-txnid // http://matrix.org/docs/spec/client_server/r0.2.0.html#put-matrix-client-r0-rooms-roomid-send-eventtype-txnid
@ -48,6 +49,15 @@ func SendEvent(
producer *producers.RoomserverProducer, producer *producers.RoomserverProducer,
txnCache *transactions.Cache, txnCache *transactions.Cache,
) util.JSONResponse { ) util.JSONResponse {
verReq := api.QueryRoomVersionForRoomRequest{RoomID: roomID}
verRes := api.QueryRoomVersionForRoomResponse{}
if err := queryAPI.QueryRoomVersionForRoom(req.Context(), &verReq, &verRes); err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.UnsupportedRoomVersion(err.Error()),
}
}
if txnID != nil { if txnID != nil {
// Try to fetch response from transactionsCache // Try to fetch response from transactionsCache
if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID); ok { if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID); ok {
@ -71,11 +81,22 @@ func SendEvent(
// pass the new event to the roomserver and receive the correct event ID // pass the new event to the roomserver and receive the correct event ID
// event ID in case of duplicate transaction is discarded // event ID in case of duplicate transaction is discarded
eventID, err := producer.SendEvents( eventID, err := producer.SendEvents(
req.Context(), []gomatrixserverlib.Event{*e}, cfg.Matrix.ServerName, txnAndSessionID, req.Context(),
[]gomatrixserverlib.HeaderedEvent{
e.Headered(verRes.RoomVersion),
},
cfg.Matrix.ServerName,
txnAndSessionID,
) )
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("producer.SendEvents failed")
return jsonerror.InternalServerError()
} }
util.GetLogger(req.Context()).WithFields(logrus.Fields{
"event_id": eventID,
"room_id": roomID,
"room_version": verRes.RoomVersion,
}).Info("Sent event to roomserver")
res := util.JSONResponse{ res := util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
@ -121,7 +142,8 @@ func generateSendEvent(
} }
err = builder.SetContent(r) err = builder.SetContent(r)
if err != nil { if err != nil {
resErr := httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("builder.SetContent failed")
resErr := jsonerror.InternalServerError()
return nil, &resErr return nil, &resErr
} }
@ -133,14 +155,15 @@ func generateSendEvent(
JSON: jsonerror.NotFound("Room does not exist"), JSON: jsonerror.NotFound("Room does not exist"),
} }
} else if err != nil { } else if err != nil {
resErr := httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("common.BuildEvent failed")
resErr := jsonerror.InternalServerError()
return nil, &resErr return nil, &resErr
} }
// check to see if this user can perform this operation // check to see if this user can perform this operation
stateEvents := make([]*gomatrixserverlib.Event, len(queryRes.StateEvents)) stateEvents := make([]*gomatrixserverlib.Event, len(queryRes.StateEvents))
for i := range queryRes.StateEvents { for i := range queryRes.StateEvents {
stateEvents[i] = &queryRes.StateEvents[i] stateEvents[i] = &queryRes.StateEvents[i].Event
} }
provider := gomatrixserverlib.NewAuthEvents(stateEvents) provider := gomatrixserverlib.NewAuthEvents(stateEvents)
if err = gomatrixserverlib.Allowed(*e, &provider); err != nil { if err = gomatrixserverlib.Allowed(*e, &provider); err != nil {

View file

@ -46,7 +46,8 @@ func SendTyping(
localpart, err := userutil.ParseUsernameParam(userID, nil) localpart, err := userutil.ParseUsernameParam(userID, nil)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("userutil.ParseUsernameParam failed")
return jsonerror.InternalServerError()
} }
// Verify that the user is a member of this room // Verify that the user is a member of this room
@ -57,7 +58,8 @@ func SendTyping(
JSON: jsonerror.Forbidden("User not in this room"), JSON: jsonerror.Forbidden("User not in this room"),
} }
} else if err != nil { } else if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("accountDB.GetMembershipInRoomByLocalPart failed")
return jsonerror.InternalServerError()
} }
// parse the incoming http request // parse the incoming http request
@ -70,7 +72,8 @@ func SendTyping(
if err = typingProducer.Send( if err = typingProducer.Send(
req.Context(), userID, roomID, r.Typing, r.Timeout, req.Context(), userID, roomID, r.Typing, r.Timeout,
); err != nil { ); err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("typingProducer.Send failed")
return jsonerror.InternalServerError()
} }
return util.JSONResponse{ return util.JSONResponse{

View file

@ -51,7 +51,8 @@ func RequestEmailToken(req *http.Request, accountDB accounts.Database, cfg *conf
// Check if the 3PID is already in use locally // Check if the 3PID is already in use locally
localpart, err := accountDB.GetLocalpartForThreePID(req.Context(), body.Email, "email") localpart, err := accountDB.GetLocalpartForThreePID(req.Context(), body.Email, "email")
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("accountDB.GetLocalpartForThreePID failed")
return jsonerror.InternalServerError()
} }
if len(localpart) > 0 { if len(localpart) > 0 {
@ -71,7 +72,8 @@ func RequestEmailToken(req *http.Request, accountDB accounts.Database, cfg *conf
JSON: jsonerror.NotTrusted(body.IDServer), JSON: jsonerror.NotTrusted(body.IDServer),
} }
} else if err != nil { } else if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("threepid.CreateSession failed")
return jsonerror.InternalServerError()
} }
return util.JSONResponse{ return util.JSONResponse{
@ -98,7 +100,8 @@ func CheckAndSave3PIDAssociation(
JSON: jsonerror.NotTrusted(body.Creds.IDServer), JSON: jsonerror.NotTrusted(body.Creds.IDServer),
} }
} else if err != nil { } else if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("threepid.CheckAssociation failed")
return jsonerror.InternalServerError()
} }
if !verified { if !verified {
@ -120,18 +123,21 @@ func CheckAndSave3PIDAssociation(
JSON: jsonerror.NotTrusted(body.Creds.IDServer), JSON: jsonerror.NotTrusted(body.Creds.IDServer),
} }
} else if err != nil { } else if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("threepid.PublishAssociation failed")
return jsonerror.InternalServerError()
} }
} }
// Save the association in the database // Save the association in the database
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError()
} }
if err = accountDB.SaveThreePIDAssociation(req.Context(), address, localpart, medium); err != nil { if err = accountDB.SaveThreePIDAssociation(req.Context(), address, localpart, medium); err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("accountsDB.SaveThreePIDAssociation failed")
return jsonerror.InternalServerError()
} }
return util.JSONResponse{ return util.JSONResponse{
@ -146,12 +152,14 @@ func GetAssociated3PIDs(
) util.JSONResponse { ) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError()
} }
threepids, err := accountDB.GetThreePIDsForLocalpart(req.Context(), localpart) threepids, err := accountDB.GetThreePIDsForLocalpart(req.Context(), localpart)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("accountDB.GetThreePIDsForLocalpart failed")
return jsonerror.InternalServerError()
} }
return util.JSONResponse{ return util.JSONResponse{
@ -168,7 +176,8 @@ func Forget3PID(req *http.Request, accountDB accounts.Database) util.JSONRespons
} }
if err := accountDB.RemoveThreePIDAssociation(req.Context(), body.Address, body.Medium); err != nil { if err := accountDB.RemoveThreePIDAssociation(req.Context(), body.Address, body.Medium); err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("accountDB.RemoveThreePIDAssociation failed")
return jsonerror.InternalServerError()
} }
return util.JSONResponse{ return util.JSONResponse{

View file

@ -23,7 +23,7 @@ import (
"time" "time"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/common/config" "github.com/matrix-org/dendrite/common/config"
"github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrix"
"github.com/matrix-org/util" "github.com/matrix-org/util"
@ -56,7 +56,8 @@ func RequestTurnServer(req *http.Request, device *authtypes.Device, cfg *config.
_, err := mac.Write([]byte(resp.Username)) _, err := mac.Write([]byte(resp.Username))
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("mac.Write failed")
return jsonerror.InternalServerError()
} }
resp.Username = fmt.Sprintf("%d:%s", expiry, device.UserID) resp.Username = fmt.Sprintf("%d:%s", expiry, device.UserID)

View file

@ -353,12 +353,19 @@ func emit3PIDInviteEvent(
return err return err
} }
var queryRes *api.QueryLatestEventsAndStateResponse queryRes := api.QueryLatestEventsAndStateResponse{}
event, err := common.BuildEvent(ctx, builder, cfg, evTime, queryAPI, queryRes) event, err := common.BuildEvent(ctx, builder, cfg, evTime, queryAPI, &queryRes)
if err != nil { if err != nil {
return err return err
} }
_, err = producer.SendEvents(ctx, []gomatrixserverlib.Event{*event}, cfg.Matrix.ServerName, nil) _, err = producer.SendEvents(
ctx,
[]gomatrixserverlib.HeaderedEvent{
(*event).Headered(queryRes.RoomVersion),
},
cfg.Matrix.ServerName,
nil,
)
return err return err
} }

View file

@ -103,12 +103,14 @@ func main() {
// Build an event and write the event to the output. // Build an event and write the event to the output.
func buildAndOutput() gomatrixserverlib.EventReference { func buildAndOutput() gomatrixserverlib.EventReference {
eventID++ eventID++
id := fmt.Sprintf("$%d:%s", eventID, *serverName)
now = time.Unix(0, 0) now = time.Unix(0, 0)
name := gomatrixserverlib.ServerName(*serverName) name := gomatrixserverlib.ServerName(*serverName)
key := gomatrixserverlib.KeyID(*keyID) key := gomatrixserverlib.KeyID(*keyID)
event, err := b.Build(id, now, name, key, privateKey) event, err := b.Build(
now, name, key, privateKey,
gomatrixserverlib.RoomVersionV1,
)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -125,9 +127,9 @@ func writeEvent(event gomatrixserverlib.Event) {
if *format == "InputRoomEvent" { if *format == "InputRoomEvent" {
var ire api.InputRoomEvent var ire api.InputRoomEvent
ire.Kind = api.KindNew ire.Kind = api.KindNew
ire.Event = event ire.Event = event.Headered(gomatrixserverlib.RoomVersionV1)
authEventIDs := []string{} authEventIDs := []string{}
for _, ref := range b.AuthEvents { for _, ref := range b.AuthEvents.([]gomatrixserverlib.EventReference) {
authEventIDs = append(authEventIDs, ref.EventID) authEventIDs = append(authEventIDs, ref.EventID)
} }
ire.AuthEventIDs = authEventIDs ire.AuthEventIDs = authEventIDs

View file

@ -69,7 +69,7 @@ func main() {
) )
federationapi.SetupFederationAPIComponent(base, accountDB, deviceDB, federation, &keyRing, alias, input, query, asQuery, fedSenderAPI) federationapi.SetupFederationAPIComponent(base, accountDB, deviceDB, federation, &keyRing, alias, input, query, asQuery, fedSenderAPI)
mediaapi.SetupMediaAPIComponent(base, deviceDB) mediaapi.SetupMediaAPIComponent(base, deviceDB)
publicroomsapi.SetupPublicRoomsAPIComponent(base, deviceDB, query) publicroomsapi.SetupPublicRoomsAPIComponent(base, deviceDB, query, federation, nil)
syncapi.SetupSyncAPIComponent(base, deviceDB, accountDB, query, federation, cfg) syncapi.SetupSyncAPIComponent(base, deviceDB, accountDB, query, federation, cfg)
httpHandler := common.WrapHandlerInCORS(base.APIMux) httpHandler := common.WrapHandlerInCORS(base.APIMux)

View file

@ -28,7 +28,7 @@ func main() {
_, _, query := base.CreateHTTPRoomserverAPIs() _, _, query := base.CreateHTTPRoomserverAPIs()
publicroomsapi.SetupPublicRoomsAPIComponent(base, deviceDB, query) publicroomsapi.SetupPublicRoomsAPIComponent(base, deviceDB, query, nil, nil)
base.SetupAndServeHTTP(string(base.Cfg.Bind.PublicRoomsAPI), string(base.Cfg.Listen.PublicRoomsAPI)) base.SetupAndServeHTTP(string(base.Cfg.Bind.PublicRoomsAPI), string(base.Cfg.Listen.PublicRoomsAPI))

104
cmd/dendritejs/jsServer.go Normal file
View file

@ -0,0 +1,104 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// +build wasm
package main
import (
"bufio"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"syscall/js"
)
// JSServer exposes an HTTP-like server interface which allows JS to 'send' requests to it.
type JSServer struct {
// The router which will service requests
Mux *http.ServeMux
}
// OnRequestFromJS is the function that JS will invoke when there is a new request.
// The JS function signature is:
// function(reqString: string): Promise<{result: string, error: string}>
// Usage is like:
// const res = await global._go_js_server.fetch(reqString);
// if (res.error) {
// // handle error: this is a 'network' error, not a non-2xx error.
// }
// const rawHttpResponse = res.result;
func (h *JSServer) OnRequestFromJS(this js.Value, args []js.Value) interface{} {
// we HAVE to spawn a new goroutine and return immediately or else Go will deadlock
// if this request blocks at all e.g for /sync calls
httpStr := args[0].String()
promise := js.Global().Get("Promise").New(js.FuncOf(func(pthis js.Value, pargs []js.Value) interface{} {
// The initial callback code for new Promise() is also called on the critical path, which is why
// we need to put this in an immediately invoked goroutine.
go func() {
resolve := pargs[0]
fmt.Println("Received request:")
fmt.Printf("%s\n", httpStr)
resStr, err := h.handle(httpStr)
errStr := ""
if err != nil {
errStr = err.Error()
}
fmt.Println("Sending response:")
fmt.Printf("%s\n", resStr)
resolve.Invoke(map[string]interface{}{
"result": resStr,
"error": errStr,
})
}()
return nil
}))
return promise
}
// handle invokes the http.ServeMux for this request and returns the raw HTTP response.
func (h *JSServer) handle(httpStr string) (resStr string, err error) {
req, err := http.ReadRequest(bufio.NewReader(strings.NewReader(httpStr)))
if err != nil {
return
}
w := httptest.NewRecorder()
h.Mux.ServeHTTP(w, req)
res := w.Result()
var resBuffer strings.Builder
err = res.Write(&resBuffer)
return resBuffer.String(), err
}
// ListenAndServe registers a variable in JS-land with the given namespace. This variable is
// a function which JS-land can call to 'send' HTTP requests. The function is attached to
// a global object called "_go_js_server". See OnRequestFromJS for more info.
func (h *JSServer) ListenAndServe(namespace string) {
globalName := "_go_js_server"
// register a hook in JS-land for it to invoke stuff
server := js.Global().Get(globalName)
if !server.Truthy() {
server = js.Global().Get("Object").New()
js.Global().Set(globalName, server)
}
server.Set(namespace, js.FuncOf(h.OnRequestFromJS))
fmt.Printf("Listening for requests from JS on function %s.%s\n", globalName, namespace)
// Block forever to mimic http.ListenAndServe
select {}
}

View file

@ -0,0 +1,84 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// +build wasm
package main
import (
"context"
"fmt"
"time"
"github.com/libp2p/go-libp2p-core/peer"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
const libp2pMatrixKeyID = "ed25519:libp2p-dendrite"
type libp2pKeyFetcher struct {
}
// FetchKeys looks up a batch of public keys.
// Takes a map from (server name, key ID) pairs to timestamp.
// The timestamp is when the keys need to be vaild up to.
// Returns a map from (server name, key ID) pairs to server key objects for
// that server name containing that key ID
// The result may have fewer (server name, key ID) pairs than were in the request.
// The result may have more (server name, key ID) pairs than were in the request.
// Returns an error if there was a problem fetching the keys.
func (f *libp2pKeyFetcher) FetchKeys(
ctx context.Context,
requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp,
) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) {
res := make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult)
for req := range requests {
if req.KeyID != libp2pMatrixKeyID {
return nil, fmt.Errorf("FetchKeys: cannot fetch key with ID %s, should be %s", req.KeyID, libp2pMatrixKeyID)
}
// The server name is a libp2p peer ID
peerIDStr := string(req.ServerName)
peerID, err := peer.Decode(peerIDStr)
if err != nil {
return nil, fmt.Errorf("Failed to decode peer ID from server name '%s': %w", peerIDStr, err)
}
pubKey, err := peerID.ExtractPublicKey()
if err != nil {
return nil, fmt.Errorf("Failed to extract public key from peer ID: %w", err)
}
pubKeyBytes, err := pubKey.Raw()
if err != nil {
return nil, fmt.Errorf("Failed to extract raw bytes from public key: %w", err)
}
util.GetLogger(ctx).Info("libp2pKeyFetcher.FetchKeys: Using public key %v for server name %s", pubKeyBytes, req.ServerName)
b64Key := gomatrixserverlib.Base64String(pubKeyBytes)
res[req] = gomatrixserverlib.PublicKeyLookupResult{
VerifyKey: gomatrixserverlib.VerifyKey{
Key: b64Key,
},
ExpiredTS: gomatrixserverlib.PublicKeyNotExpired,
ValidUntilTS: gomatrixserverlib.AsTimestamp(time.Now().Add(24 * time.Hour * 365)),
}
}
return res, nil
}
// FetcherName returns the name of this fetcher, which can then be used for
// logging errors etc.
func (f *libp2pKeyFetcher) FetcherName() string {
return "libp2pKeyFetcher"
}

167
cmd/dendritejs/main.go Normal file
View file

@ -0,0 +1,167 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// +build wasm
package main
import (
"crypto/ed25519"
"fmt"
"net/http"
"github.com/matrix-org/dendrite/appservice"
"github.com/matrix-org/dendrite/clientapi"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/common/basecomponent"
"github.com/matrix-org/dendrite/common/config"
"github.com/matrix-org/dendrite/common/transactions"
"github.com/matrix-org/dendrite/federationapi"
"github.com/matrix-org/dendrite/federationsender"
"github.com/matrix-org/dendrite/mediaapi"
"github.com/matrix-org/dendrite/publicroomsapi"
"github.com/matrix-org/dendrite/roomserver"
"github.com/matrix-org/dendrite/syncapi"
"github.com/matrix-org/dendrite/typingserver"
"github.com/matrix-org/dendrite/typingserver/cache"
"github.com/matrix-org/go-http-js-libp2p/go_http_js_libp2p"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
_ "github.com/matrix-org/go-sqlite3-js"
)
func init() {
fmt.Println("dendrite.js starting...")
}
func generateKey() ed25519.PrivateKey {
_, priv, err := ed25519.GenerateKey(nil)
if err != nil {
logrus.Fatalf("Failed to generate ed25519 key: %s", err)
}
return priv
}
func createFederationClient(cfg *config.Dendrite, node *go_http_js_libp2p.P2pLocalNode) *gomatrixserverlib.FederationClient {
fmt.Println("Running in js-libp2p federation mode")
fmt.Println("Warning: Federation with non-libp2p homeservers will not work in this mode yet!")
tr := go_http_js_libp2p.NewP2pTransport(node)
fed := gomatrixserverlib.NewFederationClient(
cfg.Matrix.ServerName, cfg.Matrix.KeyID, cfg.Matrix.PrivateKey,
)
fed.Client = *gomatrixserverlib.NewClientWithTransport(tr)
return fed
}
func createP2PNode(privKey ed25519.PrivateKey) (serverName string, node *go_http_js_libp2p.P2pLocalNode) {
hosted := "/dns4/rendezvous.matrix.org/tcp/8443/wss/p2p-websocket-star/"
node = go_http_js_libp2p.NewP2pLocalNode("org.matrix.p2p.experiment", privKey.Seed(), []string{hosted})
serverName = node.Id
fmt.Println("p2p assigned ServerName: ", serverName)
return
}
func main() {
cfg := &config.Dendrite{}
cfg.SetDefaults()
cfg.Kafka.UseNaffka = true
cfg.Database.Account = "file:dendritejs_account.db"
cfg.Database.AppService = "file:dendritejs_appservice.db"
cfg.Database.Device = "file:dendritejs_device.db"
cfg.Database.FederationSender = "file:dendritejs_fedsender.db"
cfg.Database.MediaAPI = "file:dendritejs_mediaapi.db"
cfg.Database.Naffka = "file:dendritejs_naffka.db"
cfg.Database.PublicRoomsAPI = "file:dendritejs_publicrooms.db"
cfg.Database.RoomServer = "file:dendritejs_roomserver.db"
cfg.Database.ServerKey = "file:dendritejs_serverkey.db"
cfg.Database.SyncAPI = "file:dendritejs_syncapi.db"
cfg.Kafka.Topics.UserUpdates = "user_updates"
cfg.Kafka.Topics.OutputTypingEvent = "output_typing_event"
cfg.Kafka.Topics.OutputClientData = "output_client_data"
cfg.Kafka.Topics.OutputRoomEvent = "output_room_event"
cfg.Matrix.TrustedIDServers = []string{
"matrix.org", "vector.im",
}
cfg.Matrix.KeyID = libp2pMatrixKeyID
cfg.Matrix.PrivateKey = generateKey()
serverName, node := createP2PNode(cfg.Matrix.PrivateKey)
cfg.Matrix.ServerName = gomatrixserverlib.ServerName(serverName)
if err := cfg.Derive(); err != nil {
logrus.Fatalf("Failed to derive values from config: %s", err)
}
base := basecomponent.NewBaseDendrite(cfg, "Monolith")
defer base.Close() // nolint: errcheck
accountDB := base.CreateAccountsDB()
deviceDB := base.CreateDeviceDB()
keyDB := base.CreateKeyDB()
federation := createFederationClient(cfg, node)
keyRing := gomatrixserverlib.KeyRing{
KeyFetchers: []gomatrixserverlib.KeyFetcher{
&libp2pKeyFetcher{},
},
KeyDatabase: keyDB,
}
p2pPublicRoomProvider := NewLibP2PPublicRoomsProvider(node)
alias, input, query := roomserver.SetupRoomServerComponent(base)
typingInputAPI := typingserver.SetupTypingServerComponent(base, cache.NewTypingCache())
asQuery := appservice.SetupAppServiceAPIComponent(
base, accountDB, deviceDB, federation, alias, query, transactions.New(),
)
fedSenderAPI := federationsender.SetupFederationSenderComponent(base, federation, query)
clientapi.SetupClientAPIComponent(
base, deviceDB, accountDB,
federation, &keyRing, alias, input, query,
typingInputAPI, asQuery, transactions.New(), fedSenderAPI,
)
federationapi.SetupFederationAPIComponent(base, accountDB, deviceDB, federation, &keyRing, alias, input, query, asQuery, fedSenderAPI)
mediaapi.SetupMediaAPIComponent(base, deviceDB)
publicroomsapi.SetupPublicRoomsAPIComponent(base, deviceDB, query, federation, p2pPublicRoomProvider)
syncapi.SetupSyncAPIComponent(base, deviceDB, accountDB, query, federation, cfg)
httpHandler := common.WrapHandlerInCORS(base.APIMux)
http.Handle("/", httpHandler)
// Expose the matrix APIs via libp2p-js - for federation traffic
if node != nil {
go func() {
logrus.Info("Listening on libp2p-js host ID ", node.Id)
s := JSServer{
Mux: http.DefaultServeMux,
}
s.ListenAndServe("p2p")
}()
}
// Expose the matrix APIs via fetch - for local traffic
go func() {
logrus.Info("Listening for service-worker fetch traffic")
s := JSServer{
Mux: http.DefaultServeMux,
}
s.ListenAndServe("fetch")
}()
// We want to block forever to let the fetch and libp2p handler serve the APIs
select {}
}

View file

@ -0,0 +1,23 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// +build !wasm
package main
import "fmt"
func main() {
fmt.Println("dendritejs: no-op when not compiling for WebAssembly")
}

View file

@ -0,0 +1,46 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// +build wasm
package main
import (
"github.com/matrix-org/go-http-js-libp2p/go_http_js_libp2p"
)
type libp2pPublicRoomsProvider struct {
node *go_http_js_libp2p.P2pLocalNode
providers []go_http_js_libp2p.PeerInfo
}
func NewLibP2PPublicRoomsProvider(node *go_http_js_libp2p.P2pLocalNode) *libp2pPublicRoomsProvider {
p := &libp2pPublicRoomsProvider{
node: node,
}
node.RegisterFoundProviders(p.foundProviders)
return p
}
func (p *libp2pPublicRoomsProvider) foundProviders(peerInfos []go_http_js_libp2p.PeerInfo) {
p.providers = peerInfos
}
func (p *libp2pPublicRoomsProvider) Homeservers() []string {
result := make([]string, len(p.providers))
for i := range p.providers {
result[i] = p.providers[i].Id
}
return result
}

View file

@ -104,7 +104,7 @@ func clientEventJSONForOutputRoomEvent(outputRoomEvent string) string {
panic("failed to unmarshal output room event: " + err.Error()) panic("failed to unmarshal output room event: " + err.Error())
} }
clientEvs := gomatrixserverlib.ToClientEvents([]gomatrixserverlib.Event{ clientEvs := gomatrixserverlib.ToClientEvents([]gomatrixserverlib.Event{
out.NewRoomEvent.Event, out.NewRoomEvent.Event.Event,
}, gomatrixserverlib.FormatSync) }, gomatrixserverlib.FormatSync)
b, err := json.Marshal(clientEvs[0]) b, err := json.Marshal(clientEvs[0])
if err != nil { if err != nil {

View file

@ -216,7 +216,7 @@ func setupNaffka(cfg *config.Dendrite) (sarama.Consumer, sarama.SyncProducer) {
uri, err := url.Parse(string(cfg.Database.Naffka)) uri, err := url.Parse(string(cfg.Database.Naffka))
if err != nil || uri.Scheme == "file" { if err != nil || uri.Scheme == "file" {
db, err = sql.Open("sqlite3", string(cfg.Database.Naffka)) db, err = sql.Open(common.SQLiteDriverName(), string(cfg.Database.Naffka))
if err != nil { if err != nil {
logrus.WithError(err).Panic("Failed to open naffka database") logrus.WithError(err).Panic("Failed to open naffka database")
} }

View file

@ -224,6 +224,8 @@ type Dendrite struct {
// The config for tracing the dendrite servers. // The config for tracing the dendrite servers.
Tracing struct { Tracing struct {
// Set to true to enable tracer hooks. If false, no tracing is set up.
Enabled bool `yaml:"enabled"`
// The config for the jaeger opentracing reporter. // The config for the jaeger opentracing reporter.
Jaeger jaegerconfig.Configuration `yaml:"jaeger"` Jaeger jaegerconfig.Configuration `yaml:"jaeger"`
} `yaml:"tracing"` } `yaml:"tracing"`
@ -365,7 +367,7 @@ func loadConfig(
return nil, err return nil, err
} }
config.setDefaults() config.SetDefaults()
if err = config.check(monolithic); err != nil { if err = config.check(monolithic); err != nil {
return nil, err return nil, err
@ -398,7 +400,7 @@ func loadConfig(
config.Media.AbsBasePath = Path(absPath(basePath, config.Media.BasePath)) config.Media.AbsBasePath = Path(absPath(basePath, config.Media.BasePath))
// Generate data from config options // Generate data from config options
err = config.derive() err = config.Derive()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -406,9 +408,9 @@ func loadConfig(
return &config, nil return &config, nil
} }
// derive generates data that is derived from various values provided in // Derive generates data that is derived from various values provided in
// the config file. // the config file.
func (config *Dendrite) derive() error { func (config *Dendrite) Derive() error {
// Determine registrations flows based off config values // Determine registrations flows based off config values
config.Derived.Registration.Params = make(map[string]interface{}) config.Derived.Registration.Params = make(map[string]interface{})
@ -433,8 +435,8 @@ func (config *Dendrite) derive() error {
return nil return nil
} }
// setDefaults sets default config values if they are not explicitly set. // SetDefaults sets default config values if they are not explicitly set.
func (config *Dendrite) setDefaults() { func (config *Dendrite) SetDefaults() {
if config.Matrix.KeyValidityPeriod == 0 { if config.Matrix.KeyValidityPeriod == 0 {
config.Matrix.KeyValidityPeriod = 24 * time.Hour config.Matrix.KeyValidityPeriod = 24 * time.Hour
} }
@ -703,6 +705,9 @@ func (config *Dendrite) FederationSenderURL() string {
// SetupTracing configures the opentracing using the supplied configuration. // SetupTracing configures the opentracing using the supplied configuration.
func (config *Dendrite) SetupTracing(serviceName string) (closer io.Closer, err error) { func (config *Dendrite) SetupTracing(serviceName string) (closer io.Closer, err error) {
if !config.Tracing.Enabled {
return ioutil.NopCloser(bytes.NewReader([]byte{})), nil
}
return config.Tracing.Jaeger.InitGlobalTracer( return config.Tracing.Jaeger.InitGlobalTracer(
serviceName, serviceName,
jaegerconfig.Logger(logrusLogger{logrus.StandardLogger()}), jaegerconfig.Logger(logrusLogger{logrus.StandardLogger()}),

View file

@ -112,7 +112,7 @@ func (c *ContinualConsumer) consumePartition(pc sarama.PartitionConsumer) {
msgErr := c.ProcessMessage(message) msgErr := c.ProcessMessage(message)
// Advance our position in the stream so that we will start at the right position after a restart. // Advance our position in the stream so that we will start at the right position after a restart.
if err := c.PartitionStore.SetPartitionOffset(context.TODO(), c.Topic, message.Partition, message.Offset); err != nil { if err := c.PartitionStore.SetPartitionOffset(context.TODO(), c.Topic, message.Partition, message.Offset); err != nil {
panic(fmt.Errorf("the ContinualConsumer failed to SetPartitionOffset: %s", err)) panic(fmt.Errorf("the ContinualConsumer failed to SetPartitionOffset: %w", err))
} }
// Shutdown if we were told to do so. // Shutdown if we were told to do so.
if msgErr == ErrShutdown { if msgErr == ErrShutdown {

View file

@ -17,14 +17,12 @@ package common
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"time" "time"
"github.com/matrix-org/dendrite/common/config" "github.com/matrix-org/dendrite/common/config"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
) )
// ErrRoomNoExists is returned when trying to lookup the state of a room that // ErrRoomNoExists is returned when trying to lookup the state of a room that
@ -42,13 +40,19 @@ func BuildEvent(
builder *gomatrixserverlib.EventBuilder, cfg *config.Dendrite, evTime time.Time, builder *gomatrixserverlib.EventBuilder, cfg *config.Dendrite, evTime time.Time,
queryAPI api.RoomserverQueryAPI, queryRes *api.QueryLatestEventsAndStateResponse, queryAPI api.RoomserverQueryAPI, queryRes *api.QueryLatestEventsAndStateResponse,
) (*gomatrixserverlib.Event, error) { ) (*gomatrixserverlib.Event, error) {
if queryRes == nil {
queryRes = &api.QueryLatestEventsAndStateResponse{}
}
err := AddPrevEventsToEvent(ctx, builder, queryAPI, queryRes) err := AddPrevEventsToEvent(ctx, builder, queryAPI, queryRes)
if err != nil { if err != nil {
return nil, err return nil, err
} }
eventID := fmt.Sprintf("$%s:%s", util.RandomString(16), cfg.Matrix.ServerName) event, err := builder.Build(
event, err := builder.Build(eventID, evTime, cfg.Matrix.ServerName, cfg.Matrix.KeyID, cfg.Matrix.PrivateKey) evTime, cfg.Matrix.ServerName, cfg.Matrix.KeyID,
cfg.Matrix.PrivateKey, queryRes.RoomVersion,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -72,9 +76,6 @@ func AddPrevEventsToEvent(
RoomID: builder.RoomID, RoomID: builder.RoomID,
StateToFetch: eventsNeeded.Tuples(), StateToFetch: eventsNeeded.Tuples(),
} }
if queryRes == nil {
queryRes = &api.QueryLatestEventsAndStateResponse{}
}
if err = queryAPI.QueryLatestEventsAndState(ctx, &queryReq, queryRes); err != nil { if err = queryAPI.QueryLatestEventsAndState(ctx, &queryReq, queryRes); err != nil {
return err return err
} }
@ -83,13 +84,17 @@ func AddPrevEventsToEvent(
return ErrRoomNoExists return ErrRoomNoExists
} }
eventFormat, err := queryRes.RoomVersion.EventFormat()
if err != nil {
return err
}
builder.Depth = queryRes.Depth builder.Depth = queryRes.Depth
builder.PrevEvents = queryRes.LatestEvents
authEvents := gomatrixserverlib.NewAuthEvents(nil) authEvents := gomatrixserverlib.NewAuthEvents(nil)
for i := range queryRes.StateEvents { for i := range queryRes.StateEvents {
err = authEvents.AddEvent(&queryRes.StateEvents[i]) err = authEvents.AddEvent(&queryRes.StateEvents[i].Event)
if err != nil { if err != nil {
return err return err
} }
@ -99,7 +104,23 @@ func AddPrevEventsToEvent(
if err != nil { if err != nil {
return err return err
} }
switch eventFormat {
case gomatrixserverlib.EventFormatV1:
builder.AuthEvents = refs builder.AuthEvents = refs
builder.PrevEvents = queryRes.LatestEvents
case gomatrixserverlib.EventFormatV2:
v2AuthRefs := []string{}
v2PrevRefs := []string{}
for _, ref := range refs {
v2AuthRefs = append(v2AuthRefs, ref.EventID)
}
for _, ref := range queryRes.LatestEvents {
v2PrevRefs = append(v2PrevRefs, ref.EventID)
}
builder.AuthEvents = v2AuthRefs
builder.PrevEvents = v2PrevRefs
}
return nil return nil
} }

View file

@ -25,6 +25,10 @@ func MakeAuthAPI(
if err != nil { if err != nil {
return *err return *err
} }
// add the user ID to the logger
logger := util.GetLogger((req.Context()))
logger = logger.WithField("user_id", device.UserID)
req = req.WithContext(util.ContextWithLogger(req.Context(), logger))
return f(req, device) return f(req, device)
} }

13
common/keydb/interface.go Normal file
View file

@ -0,0 +1,13 @@
package keydb
import (
"context"
"github.com/matrix-org/gomatrixserverlib"
)
type Database interface {
FetcherName() string
FetchKeys(ctx context.Context, requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error)
StoreKeys(ctx context.Context, keyMap map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult) error
}

View file

@ -12,10 +12,11 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// +build !wasm
package keydb package keydb
import ( import (
"context"
"net/url" "net/url"
"golang.org/x/crypto/ed25519" "golang.org/x/crypto/ed25519"
@ -25,12 +26,6 @@ import (
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
type Database interface {
FetcherName() string
FetchKeys(ctx context.Context, requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error)
StoreKeys(ctx context.Context, keyMap map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult) error
}
// NewDatabase opens a database connection. // NewDatabase opens a database connection.
func NewDatabase( func NewDatabase(
dataSourceName string, dataSourceName string,

View file

@ -0,0 +1,46 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package keydb
import (
"fmt"
"net/url"
"golang.org/x/crypto/ed25519"
"github.com/matrix-org/dendrite/common/keydb/sqlite3"
"github.com/matrix-org/gomatrixserverlib"
)
// NewDatabase opens a database connection.
func NewDatabase(
dataSourceName string,
serverName gomatrixserverlib.ServerName,
serverKey ed25519.PublicKey,
serverKeyID gomatrixserverlib.KeyID,
) (Database, error) {
uri, err := url.Parse(dataSourceName)
if err != nil {
return nil, err
}
switch uri.Scheme {
case "postgres":
return nil, fmt.Errorf("Cannot use postgres implementation")
case "file":
return sqlite3.NewDatabase(dataSourceName, serverName, serverKey, serverKeyID)
default:
return nil, fmt.Errorf("Cannot use postgres implementation")
}
}

View file

@ -19,6 +19,8 @@ import (
"context" "context"
"database/sql" "database/sql"
"github.com/matrix-org/dendrite/common"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -91,7 +93,7 @@ func (s *serverKeyStatements) bulkSelectServerKeys(
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer rows.Close() // nolint: errcheck defer common.CloseAndLogIfError(ctx, rows, "bulkSelectServerKeys: rows.close() failed")
results := map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult{} results := map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult{}
for rows.Next() { for rows.Next() {
var serverName string var serverName string

View file

@ -22,6 +22,7 @@ import (
"golang.org/x/crypto/ed25519" "golang.org/x/crypto/ed25519"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
@ -43,7 +44,7 @@ func NewDatabase(
serverKey ed25519.PublicKey, serverKey ed25519.PublicKey,
serverKeyID gomatrixserverlib.KeyID, serverKeyID gomatrixserverlib.KeyID,
) (*Database, error) { ) (*Database, error) {
db, err := sql.Open("sqlite3", dataSourceName) db, err := sql.Open(common.SQLiteDriverName(), dataSourceName)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -18,9 +18,12 @@ package sqlite3
import ( import (
"context" "context"
"database/sql" "database/sql"
"strings"
"github.com/lib/pq" lru "github.com/hashicorp/golang-lru"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
) )
const serverKeysSchema = ` const serverKeysSchema = `
@ -60,11 +63,19 @@ const upsertServerKeysSQL = "" +
" DO UPDATE SET valid_until_ts = $4, expired_ts = $5, server_key = $6" " DO UPDATE SET valid_until_ts = $4, expired_ts = $5, server_key = $6"
type serverKeyStatements struct { type serverKeyStatements struct {
db *sql.DB
bulkSelectServerKeysStmt *sql.Stmt bulkSelectServerKeysStmt *sql.Stmt
upsertServerKeysStmt *sql.Stmt upsertServerKeysStmt *sql.Stmt
cache *lru.Cache // nameAndKeyID => gomatrixserverlib.PublicKeyLookupResult
} }
func (s *serverKeyStatements) prepare(db *sql.DB) (err error) { func (s *serverKeyStatements) prepare(db *sql.DB) (err error) {
s.db = db
s.cache, err = lru.New(64)
if err != nil {
return
}
_, err = db.Exec(serverKeysSchema) _, err = db.Exec(serverKeysSchema)
if err != nil { if err != nil {
return return
@ -86,12 +97,34 @@ func (s *serverKeyStatements) bulkSelectServerKeys(
for request := range requests { for request := range requests {
nameAndKeyIDs = append(nameAndKeyIDs, nameAndKeyID(request)) nameAndKeyIDs = append(nameAndKeyIDs, nameAndKeyID(request))
} }
stmt := s.bulkSelectServerKeysStmt
rows, err := stmt.QueryContext(ctx, pq.StringArray(nameAndKeyIDs)) // If we can satisfy all of the requests from the cache, do so. TODO: Allow partial matches with merges.
cacheResults := map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult{}
for request := range requests {
r, ok := s.cache.Get(nameAndKeyID(request))
if !ok {
break
}
cacheResult := r.(gomatrixserverlib.PublicKeyLookupResult)
cacheResults[request] = cacheResult
}
if len(cacheResults) == len(requests) {
util.GetLogger(ctx).Infof("KeyDB cache hit for %d keys", len(cacheResults))
return cacheResults, nil
}
query := strings.Replace(bulkSelectServerKeysSQL, "($1)", common.QueryVariadic(len(nameAndKeyIDs)), 1)
iKeyIDs := make([]interface{}, len(nameAndKeyIDs))
for i, v := range nameAndKeyIDs {
iKeyIDs[i] = v
}
rows, err := s.db.QueryContext(ctx, query, iKeyIDs...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer rows.Close() // nolint: errcheck defer common.CloseAndLogIfError(ctx, rows, "bulkSelectServerKeys: rows.close() failed")
results := map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult{} results := map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult{}
for rows.Next() { for rows.Next() {
var serverName string var serverName string
@ -125,6 +158,7 @@ func (s *serverKeyStatements) upsertServerKeys(
request gomatrixserverlib.PublicKeyLookupRequest, request gomatrixserverlib.PublicKeyLookupRequest,
key gomatrixserverlib.PublicKeyLookupResult, key gomatrixserverlib.PublicKeyLookupResult,
) error { ) error {
s.cache.Add(nameAndKeyID(request), key)
_, err := s.upsertServerKeysStmt.ExecContext( _, err := s.upsertServerKeysStmt.ExecContext(
ctx, ctx,
string(request.ServerName), string(request.ServerName),

View file

@ -15,13 +15,17 @@
package common package common
import ( import (
"context"
"fmt" "fmt"
"io"
"os" "os"
"path" "path"
"path/filepath" "path/filepath"
"runtime" "runtime"
"strings" "strings"
"github.com/matrix-org/util"
"github.com/matrix-org/dendrite/common/config" "github.com/matrix-org/dendrite/common/config"
"github.com/matrix-org/dugong" "github.com/matrix-org/dugong"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -156,3 +160,17 @@ func setupFileHook(hook config.LogrusHook, level logrus.Level, componentName str
), ),
}) })
} }
//CloseAndLogIfError Closes io.Closer and logs the error if any
func CloseAndLogIfError(ctx context.Context, closer io.Closer, message string) {
if closer == nil {
return
}
err := closer.Close()
if ctx == nil {
ctx = context.TODO()
}
if err != nil {
util.GetLogger(ctx).WithError(err).Error(message)
}
}

View file

@ -90,7 +90,7 @@ func (s *PartitionOffsetStatements) selectPartitionOffsets(
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer rows.Close() // nolint: errcheck defer CloseAndLogIfError(ctx, rows, "selectPartitionOffsets: rows.close() failed")
var results []PartitionOffset var results []PartitionOffset
for rows.Next() { for rows.Next() {
var offset PartitionOffset var offset PartitionOffset

25
common/postgres.go Normal file
View file

@ -0,0 +1,25 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// +build !wasm
package common
import "github.com/lib/pq"
// IsUniqueConstraintViolationErr returns true if the error is a postgresql unique_violation error
func IsUniqueConstraintViolationErr(err error) bool {
pqErr, ok := err.(*pq.Error)
return ok && pqErr.Code == "23505"
}

22
common/postgres_wasm.go Normal file
View file

@ -0,0 +1,22 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// +build wasm
package common
// IsUniqueConstraintViolationErr no-ops for this architecture
func IsUniqueConstraintViolationErr(err error) bool {
return false
}

View file

@ -17,8 +17,7 @@ package common
import ( import (
"database/sql" "database/sql"
"fmt" "fmt"
"runtime"
"github.com/lib/pq"
) )
// A Transaction is something that can be committed or rolledback. // A Transaction is something that can be committed or rolledback.
@ -77,12 +76,6 @@ func TxStmt(transaction *sql.Tx, statement *sql.Stmt) *sql.Stmt {
return statement return statement
} }
// IsUniqueConstraintViolationErr returns true if the error is a postgresql unique_violation error
func IsUniqueConstraintViolationErr(err error) bool {
pqErr, ok := err.(*pq.Error)
return ok && pqErr.Code == "23505"
}
// Hack of the century // Hack of the century
func QueryVariadic(count int) string { func QueryVariadic(count int) string {
return QueryVariadicOffset(count, 0) return QueryVariadicOffset(count, 0)
@ -99,3 +92,10 @@ func QueryVariadicOffset(count, offset int) string {
str += ")" str += ")"
return str return str
} }
func SQLiteDriverName() string {
if runtime.GOOS == "js" {
return "sqlite3_js"
}
return "sqlite3"
}

View file

@ -11,10 +11,10 @@ and start working on dendrite.
### Configuration ### Configuration
Copy the `dendrite-docker.yaml` file to the root of the project and rename it to Create a directory named `cfg` in the root of the project. Copy the
`dendrite.yaml`. It already contains the defaults used in `docker-compose` for `dendrite-docker.yaml` file into that directory and rename it to `dendrite.yaml`.
networking so you will only have to change things like the `server_name` or to It already contains the defaults used in `docker-compose` for networking so you will
toggle `naffka`. only have to change things like the `server_name` or to toggle `naffka`.
You can run the following `docker-compose` commands either from the top directory You can run the following `docker-compose` commands either from the top directory
specifying the `docker-compose` file specifying the `docker-compose` file

View file

@ -42,7 +42,7 @@ func SetupFederationAPIComponent(
asAPI appserviceAPI.AppServiceQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI,
federationSenderAPI federationSenderAPI.FederationSenderQueryAPI, federationSenderAPI federationSenderAPI.FederationSenderQueryAPI,
) { ) {
roomserverProducer := producers.NewRoomserverProducer(inputAPI) roomserverProducer := producers.NewRoomserverProducer(inputAPI, queryAPI)
routing.Setup( routing.Setup(
base.APIMux, base.Cfg, queryAPI, aliasAPI, asAPI, base.APIMux, base.Cfg, queryAPI, aliasAPI, asAPI,

View file

@ -15,11 +15,11 @@
package routing package routing
import ( import (
"encoding/json"
"net/http" "net/http"
"strconv" "strconv"
"time" "time"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/common/config" "github.com/matrix-org/dendrite/common/config"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
@ -72,27 +72,34 @@ func Backfill(
ServerName: request.Origin(), ServerName: request.Origin(),
} }
if req.Limit, err = strconv.Atoi(limit); err != nil { if req.Limit, err = strconv.Atoi(limit); err != nil {
return httputil.LogThenError(httpReq, err) util.GetLogger(httpReq.Context()).WithError(err).Error("strconv.Atoi failed")
return jsonerror.InternalServerError()
} }
// Query the roomserver. // Query the roomserver.
if err = query.QueryBackfill(httpReq.Context(), &req, &res); err != nil { if err = query.QueryBackfill(httpReq.Context(), &req, &res); err != nil {
return httputil.LogThenError(httpReq, err) util.GetLogger(httpReq.Context()).WithError(err).Error("query.QueryBackfill failed")
return jsonerror.InternalServerError()
} }
// Filter any event that's not from the requested room out. // Filter any event that's not from the requested room out.
evs := make([]gomatrixserverlib.Event, 0) evs := make([]gomatrixserverlib.Event, 0)
var ev gomatrixserverlib.Event var ev gomatrixserverlib.HeaderedEvent
for _, ev = range res.Events { for _, ev = range res.Events {
if ev.RoomID() == roomID { if ev.RoomID() == roomID {
evs = append(evs, ev) evs = append(evs, ev.Event)
} }
} }
var eventJSONs []json.RawMessage
for _, e := range evs {
eventJSONs = append(eventJSONs, e.JSON())
}
txn := gomatrixserverlib.Transaction{ txn := gomatrixserverlib.Transaction{
Origin: cfg.Matrix.ServerName, Origin: cfg.Matrix.ServerName,
PDUs: evs, PDUs: eventJSONs,
OriginServerTS: gomatrixserverlib.AsTimestamp(time.Now()), OriginServerTS: gomatrixserverlib.AsTimestamp(time.Now()),
} }

View file

@ -17,7 +17,6 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/auth/storage/devices" "github.com/matrix-org/dendrite/clientapi/auth/storage/devices"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/util" "github.com/matrix-org/util"
@ -43,7 +42,8 @@ func GetUserDevices(
devs, err := deviceDB.GetDevicesByLocalpart(req.Context(), localpart) devs, err := deviceDB.GetDevicesByLocalpart(req.Context(), localpart)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) util.GetLogger(req.Context()).WithError(err).Error("deviceDB.GetDevicesByLocalPart failed")
return jsonerror.InternalServerError()
} }
return util.JSONResponse{ return util.JSONResponse{

View file

@ -0,0 +1,43 @@
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package routing
import (
"context"
"net/http"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
// GetEventAuth returns event auth for the roomID and eventID
func GetEventAuth(
ctx context.Context,
request *gomatrixserverlib.FederationRequest,
query api.RoomserverQueryAPI,
roomID string,
eventID string,
) util.JSONResponse {
// TODO: Optimisation: we shouldn't be querying all the room state
// that is in state.StateEvents - we just ignore it.
state, err := getState(ctx, request, query, roomID, eventID)
if err != nil {
return *err
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: gomatrixserverlib.RespEventAuth{AuthEvents: state.AuthEvents},
}
}

View file

@ -80,5 +80,5 @@ func getEvent(
return nil, &util.JSONResponse{Code: http.StatusNotFound, JSON: nil} return nil, &util.JSONResponse{Code: http.StatusNotFound, JSON: nil}
} }
return &eventsResponse.Events[0], nil return &eventsResponse.Events[0].Event, nil
} }

View file

@ -15,13 +15,13 @@
package routing package routing
import ( import (
"encoding/json" "context"
"net/http" "net/http"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/clientapi/producers" "github.com/matrix-org/dendrite/clientapi/producers"
"github.com/matrix-org/dendrite/common/config" "github.com/matrix-org/dendrite/common/config"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -36,10 +36,19 @@ func Invite(
producer *producers.RoomserverProducer, producer *producers.RoomserverProducer,
keys gomatrixserverlib.KeyRing, keys gomatrixserverlib.KeyRing,
) util.JSONResponse { ) util.JSONResponse {
// Look up the room version for the room.
verReq := api.QueryRoomVersionForRoomRequest{RoomID: roomID}
verRes := api.QueryRoomVersionForRoomResponse{}
if err := producer.QueryAPI.QueryRoomVersionForRoom(context.Background(), &verReq, &verRes); err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.UnsupportedRoomVersion(err.Error()),
}
}
// Decode the event JSON from the request. // Decode the event JSON from the request.
var event gomatrixserverlib.Event event, err := gomatrixserverlib.NewEventFromUntrustedJSON(request.Content(), verRes.RoomVersion)
if err := json.Unmarshal(request.Content(), &event); err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: jsonerror.NotJSON("The request body could not be decoded into valid JSON. " + err.Error()), JSON: jsonerror.NotJSON("The request body could not be decoded into valid JSON. " + err.Error()),
@ -71,14 +80,16 @@ func Invite(
} }
// Check that the event is signed by the server sending the request. // Check that the event is signed by the server sending the request.
redacted := event.Redact()
verifyRequests := []gomatrixserverlib.VerifyJSONRequest{{ verifyRequests := []gomatrixserverlib.VerifyJSONRequest{{
ServerName: event.Origin(), ServerName: event.Origin(),
Message: event.Redact().JSON(), Message: redacted.JSON(),
AtTS: event.OriginServerTS(), AtTS: event.OriginServerTS(),
}} }}
verifyResults, err := keys.VerifyJSONs(httpReq.Context(), verifyRequests) verifyResults, err := keys.VerifyJSONs(httpReq.Context(), verifyRequests)
if err != nil { if err != nil {
return httputil.LogThenError(httpReq, err) util.GetLogger(httpReq.Context()).WithError(err).Error("keys.VerifyJSONs failed")
return jsonerror.InternalServerError()
} }
if verifyResults[0].Error != nil { if verifyResults[0].Error != nil {
return util.JSONResponse{ return util.JSONResponse{
@ -94,7 +105,8 @@ func Invite(
// Add the invite event to the roomserver. // Add the invite event to the roomserver.
if err = producer.SendInvite(httpReq.Context(), signedEvent); err != nil { if err = producer.SendInvite(httpReq.Context(), signedEvent); err != nil {
return httputil.LogThenError(httpReq, err) util.GetLogger(httpReq.Context()).WithError(err).Error("producer.SendInvite failed")
return jsonerror.InternalServerError()
} }
// Return the signed event to the originating server, it should then tell // Return the signed event to the originating server, it should then tell

View file

@ -15,11 +15,9 @@
package routing package routing
import ( import (
"encoding/json"
"net/http" "net/http"
"time" "time"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/clientapi/producers" "github.com/matrix-org/dendrite/clientapi/producers"
"github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/common"
@ -37,6 +35,15 @@ func MakeJoin(
query api.RoomserverQueryAPI, query api.RoomserverQueryAPI,
roomID, userID string, roomID, userID string,
) util.JSONResponse { ) util.JSONResponse {
verReq := api.QueryRoomVersionForRoomRequest{RoomID: roomID}
verRes := api.QueryRoomVersionForRoomResponse{}
if err := query.QueryRoomVersionForRoom(httpReq.Context(), &verReq, &verRes); err != nil {
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: jsonerror.InternalServerError(),
}
}
_, domain, err := gomatrixserverlib.SplitID('@', userID) _, domain, err := gomatrixserverlib.SplitID('@', userID)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
@ -60,10 +67,13 @@ func MakeJoin(
} }
err = builder.SetContent(map[string]interface{}{"membership": gomatrixserverlib.Join}) err = builder.SetContent(map[string]interface{}{"membership": gomatrixserverlib.Join})
if err != nil { if err != nil {
return httputil.LogThenError(httpReq, err) util.GetLogger(httpReq.Context()).WithError(err).Error("builder.SetContent failed")
return jsonerror.InternalServerError()
} }
var queryRes api.QueryLatestEventsAndStateResponse queryRes := api.QueryLatestEventsAndStateResponse{
RoomVersion: verRes.RoomVersion,
}
event, err := common.BuildEvent(httpReq.Context(), &builder, cfg, time.Now(), query, &queryRes) event, err := common.BuildEvent(httpReq.Context(), &builder, cfg, time.Now(), query, &queryRes)
if err == common.ErrRoomNoExists { if err == common.ErrRoomNoExists {
return util.JSONResponse{ return util.JSONResponse{
@ -71,14 +81,16 @@ func MakeJoin(
JSON: jsonerror.NotFound("Room does not exist"), JSON: jsonerror.NotFound("Room does not exist"),
} }
} else if err != nil { } else if err != nil {
return httputil.LogThenError(httpReq, err) util.GetLogger(httpReq.Context()).WithError(err).Error("common.BuildEvent failed")
return jsonerror.InternalServerError()
} }
// Check that the join is allowed or not // Check that the join is allowed or not
stateEvents := make([]*gomatrixserverlib.Event, len(queryRes.StateEvents)) stateEvents := make([]*gomatrixserverlib.Event, len(queryRes.StateEvents))
for i := range queryRes.StateEvents { for i := range queryRes.StateEvents {
stateEvents[i] = &queryRes.StateEvents[i] stateEvents[i] = &queryRes.StateEvents[i].Event
} }
provider := gomatrixserverlib.NewAuthEvents(stateEvents) provider := gomatrixserverlib.NewAuthEvents(stateEvents)
if err = gomatrixserverlib.Allowed(*event, &provider); err != nil { if err = gomatrixserverlib.Allowed(*event, &provider); err != nil {
return util.JSONResponse{ return util.JSONResponse{
@ -89,7 +101,10 @@ func MakeJoin(
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
JSON: map[string]interface{}{"event": builder}, JSON: map[string]interface{}{
"event": builder,
"room_version": verRes.RoomVersion,
},
} }
} }
@ -103,8 +118,18 @@ func SendJoin(
keys gomatrixserverlib.KeyRing, keys gomatrixserverlib.KeyRing,
roomID, eventID string, roomID, eventID string,
) util.JSONResponse { ) util.JSONResponse {
var event gomatrixserverlib.Event verReq := api.QueryRoomVersionForRoomRequest{RoomID: roomID}
if err := json.Unmarshal(request.Content(), &event); err != nil { verRes := api.QueryRoomVersionForRoomResponse{}
if err := query.QueryRoomVersionForRoom(httpReq.Context(), &verReq, &verRes); err != nil {
util.GetLogger(httpReq.Context()).WithError(err).Error("query.QueryRoomVersionForRoom failed")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: jsonerror.InternalServerError(),
}
}
event, err := gomatrixserverlib.NewEventFromUntrustedJSON(request.Content(), verRes.RoomVersion)
if err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: jsonerror.NotJSON("The request body could not be decoded into valid JSON. " + err.Error()), JSON: jsonerror.NotJSON("The request body could not be decoded into valid JSON. " + err.Error()),
@ -136,19 +161,21 @@ func SendJoin(
} }
// Check that the event is signed by the server sending the request. // Check that the event is signed by the server sending the request.
redacted := event.Redact()
verifyRequests := []gomatrixserverlib.VerifyJSONRequest{{ verifyRequests := []gomatrixserverlib.VerifyJSONRequest{{
ServerName: event.Origin(), ServerName: event.Origin(),
Message: event.Redact().JSON(), Message: redacted.JSON(),
AtTS: event.OriginServerTS(), AtTS: event.OriginServerTS(),
}} }}
verifyResults, err := keys.VerifyJSONs(httpReq.Context(), verifyRequests) verifyResults, err := keys.VerifyJSONs(httpReq.Context(), verifyRequests)
if err != nil { if err != nil {
return httputil.LogThenError(httpReq, err) util.GetLogger(httpReq.Context()).WithError(err).Error("keys.VerifyJSONs failed")
return jsonerror.InternalServerError()
} }
if verifyResults[0].Error != nil { if verifyResults[0].Error != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("The join must be signed by the server it originated on"), JSON: jsonerror.Forbidden("Signature check failed: " + verifyResults[0].Error.Error()),
} }
} }
@ -161,7 +188,8 @@ func SendJoin(
RoomID: roomID, RoomID: roomID,
}, &stateAndAuthChainResponse) }, &stateAndAuthChainResponse)
if err != nil { if err != nil {
return httputil.LogThenError(httpReq, err) util.GetLogger(httpReq.Context()).WithError(err).Error("query.QueryStateAndAuthChain failed")
return jsonerror.InternalServerError()
} }
if !stateAndAuthChainResponse.RoomExists { if !stateAndAuthChainResponse.RoomExists {
@ -175,17 +203,23 @@ func SendJoin(
// We are responsible for notifying other servers that the user has joined // We are responsible for notifying other servers that the user has joined
// the room, so set SendAsServer to cfg.Matrix.ServerName // the room, so set SendAsServer to cfg.Matrix.ServerName
_, err = producer.SendEvents( _, err = producer.SendEvents(
httpReq.Context(), []gomatrixserverlib.Event{event}, cfg.Matrix.ServerName, nil, httpReq.Context(),
[]gomatrixserverlib.HeaderedEvent{
event.Headered(stateAndAuthChainResponse.RoomVersion),
},
cfg.Matrix.ServerName,
nil,
) )
if err != nil { if err != nil {
return httputil.LogThenError(httpReq, err) util.GetLogger(httpReq.Context()).WithError(err).Error("producer.SendEvents failed")
return jsonerror.InternalServerError()
} }
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
JSON: map[string]interface{}{ JSON: map[string]interface{}{
"state": stateAndAuthChainResponse.StateEvents, "state": gomatrixserverlib.UnwrapEventHeaders(stateAndAuthChainResponse.StateEvents),
"auth_chain": stateAndAuthChainResponse.AuthChainEvents, "auth_chain": gomatrixserverlib.UnwrapEventHeaders(stateAndAuthChainResponse.AuthChainEvents),
}, },
} }
} }

View file

@ -13,11 +13,9 @@
package routing package routing
import ( import (
"encoding/json"
"net/http" "net/http"
"time" "time"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/clientapi/producers" "github.com/matrix-org/dendrite/clientapi/producers"
"github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/common"
@ -58,7 +56,8 @@ func MakeLeave(
} }
err = builder.SetContent(map[string]interface{}{"membership": gomatrixserverlib.Leave}) err = builder.SetContent(map[string]interface{}{"membership": gomatrixserverlib.Leave})
if err != nil { if err != nil {
return httputil.LogThenError(httpReq, err) util.GetLogger(httpReq.Context()).WithError(err).Error("builder.SetContent failed")
return jsonerror.InternalServerError()
} }
var queryRes api.QueryLatestEventsAndStateResponse var queryRes api.QueryLatestEventsAndStateResponse
@ -69,13 +68,14 @@ func MakeLeave(
JSON: jsonerror.NotFound("Room does not exist"), JSON: jsonerror.NotFound("Room does not exist"),
} }
} else if err != nil { } else if err != nil {
return httputil.LogThenError(httpReq, err) util.GetLogger(httpReq.Context()).WithError(err).Error("common.BuildEvent failed")
return jsonerror.InternalServerError()
} }
// Check that the leave is allowed or not // Check that the leave is allowed or not
stateEvents := make([]*gomatrixserverlib.Event, len(queryRes.StateEvents)) stateEvents := make([]*gomatrixserverlib.Event, len(queryRes.StateEvents))
for i := range queryRes.StateEvents { for i := range queryRes.StateEvents {
stateEvents[i] = &queryRes.StateEvents[i] stateEvents[i] = &queryRes.StateEvents[i].Event
} }
provider := gomatrixserverlib.NewAuthEvents(stateEvents) provider := gomatrixserverlib.NewAuthEvents(stateEvents)
if err = gomatrixserverlib.Allowed(*event, &provider); err != nil { if err = gomatrixserverlib.Allowed(*event, &provider); err != nil {
@ -100,8 +100,18 @@ func SendLeave(
keys gomatrixserverlib.KeyRing, keys gomatrixserverlib.KeyRing,
roomID, eventID string, roomID, eventID string,
) util.JSONResponse { ) util.JSONResponse {
var event gomatrixserverlib.Event verReq := api.QueryRoomVersionForRoomRequest{RoomID: roomID}
if err := json.Unmarshal(request.Content(), &event); err != nil { verRes := api.QueryRoomVersionForRoomResponse{}
if err := producer.QueryAPI.QueryRoomVersionForRoom(httpReq.Context(), &verReq, &verRes); err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.UnsupportedRoomVersion(err.Error()),
}
}
// Decode the event JSON from the request.
event, err := gomatrixserverlib.NewEventFromUntrustedJSON(request.Content(), verRes.RoomVersion)
if err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: jsonerror.NotJSON("The request body could not be decoded into valid JSON. " + err.Error()), JSON: jsonerror.NotJSON("The request body could not be decoded into valid JSON. " + err.Error()),
@ -133,14 +143,16 @@ func SendLeave(
} }
// Check that the event is signed by the server sending the request. // Check that the event is signed by the server sending the request.
redacted := event.Redact()
verifyRequests := []gomatrixserverlib.VerifyJSONRequest{{ verifyRequests := []gomatrixserverlib.VerifyJSONRequest{{
ServerName: event.Origin(), ServerName: event.Origin(),
Message: event.Redact().JSON(), Message: redacted.JSON(),
AtTS: event.OriginServerTS(), AtTS: event.OriginServerTS(),
}} }}
verifyResults, err := keys.VerifyJSONs(httpReq.Context(), verifyRequests) verifyResults, err := keys.VerifyJSONs(httpReq.Context(), verifyRequests)
if err != nil { if err != nil {
return httputil.LogThenError(httpReq, err) util.GetLogger(httpReq.Context()).WithError(err).Error("keys.VerifyJSONs failed")
return jsonerror.InternalServerError()
} }
if verifyResults[0].Error != nil { if verifyResults[0].Error != nil {
return util.JSONResponse{ return util.JSONResponse{
@ -152,7 +164,8 @@ func SendLeave(
// check membership is set to leave // check membership is set to leave
mem, err := event.Membership() mem, err := event.Membership()
if err != nil { if err != nil {
return httputil.LogThenError(httpReq, err) util.GetLogger(httpReq.Context()).WithError(err).Error("event.Membership failed")
return jsonerror.InternalServerError()
} else if mem != gomatrixserverlib.Leave { } else if mem != gomatrixserverlib.Leave {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
@ -163,9 +176,17 @@ func SendLeave(
// Send the events to the room server. // Send the events to the room server.
// We are responsible for notifying other servers that the user has left // We are responsible for notifying other servers that the user has left
// the room, so set SendAsServer to cfg.Matrix.ServerName // the room, so set SendAsServer to cfg.Matrix.ServerName
_, err = producer.SendEvents(httpReq.Context(), []gomatrixserverlib.Event{event}, cfg.Matrix.ServerName, nil) _, err = producer.SendEvents(
httpReq.Context(),
[]gomatrixserverlib.HeaderedEvent{
event.Headered(verRes.RoomVersion),
},
cfg.Matrix.ServerName,
nil,
)
if err != nil { if err != nil {
return httputil.LogThenError(httpReq, err) util.GetLogger(httpReq.Context()).WithError(err).Error("producer.SendEvents failed")
return jsonerror.InternalServerError()
} }
return util.JSONResponse{ return util.JSONResponse{

View file

@ -16,7 +16,6 @@ import (
"encoding/json" "encoding/json"
"net/http" "net/http"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -56,7 +55,8 @@ func GetMissingEvents(
}, },
&eventsResponse, &eventsResponse,
); err != nil { ); err != nil {
return httputil.LogThenError(httpReq, err) util.GetLogger(httpReq.Context()).WithError(err).Error("query.QueryMissingEvents failed")
return jsonerror.InternalServerError()
} }
eventsResponse.Events = filterEvents(eventsResponse.Events, gme.MinDepth, roomID) eventsResponse.Events = filterEvents(eventsResponse.Events, gme.MinDepth, roomID)
@ -68,8 +68,8 @@ func GetMissingEvents(
// filterEvents returns only those events with matching roomID and having depth greater than minDepth // filterEvents returns only those events with matching roomID and having depth greater than minDepth
func filterEvents( func filterEvents(
events []gomatrixserverlib.Event, minDepth int64, roomID string, events []gomatrixserverlib.HeaderedEvent, minDepth int64, roomID string,
) []gomatrixserverlib.Event { ) []gomatrixserverlib.HeaderedEvent {
ref := events[:0] ref := events[:0]
for _, ev := range events { for _, ev := range events {
if ev.Depth() >= minDepth && ev.RoomID() == roomID { if ev.Depth() >= minDepth && ev.RoomID() == roomID {

View file

@ -19,7 +19,6 @@ import (
appserviceAPI "github.com/matrix-org/dendrite/appservice/api" appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
"github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/common/config" "github.com/matrix-org/dendrite/common/config"
@ -46,16 +45,19 @@ func GetProfile(
_, domain, err := gomatrixserverlib.SplitID('@', userID) _, domain, err := gomatrixserverlib.SplitID('@', userID)
if err != nil { if err != nil {
return httputil.LogThenError(httpReq, err) util.GetLogger(httpReq.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError()
} }
if domain != cfg.Matrix.ServerName { if domain != cfg.Matrix.ServerName {
return httputil.LogThenError(httpReq, err) util.GetLogger(httpReq.Context()).WithError(err).Error("domain != cfg.Matrix.ServerName failed")
return jsonerror.InternalServerError()
} }
profile, err := appserviceAPI.RetrieveUserProfile(httpReq.Context(), userID, asAPI, accountDB) profile, err := appserviceAPI.RetrieveUserProfile(httpReq.Context(), userID, asAPI, accountDB)
if err != nil { if err != nil {
return httputil.LogThenError(httpReq, err) util.GetLogger(httpReq.Context()).WithError(err).Error("appserviceAPI.RetrieveUserProfile failed")
return jsonerror.InternalServerError()
} }
var res interface{} var res interface{}

View file

@ -18,7 +18,6 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/common/config" "github.com/matrix-org/dendrite/common/config"
federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api" federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api"
@ -57,14 +56,16 @@ func RoomAliasToID(
queryReq := roomserverAPI.GetRoomIDForAliasRequest{Alias: roomAlias} queryReq := roomserverAPI.GetRoomIDForAliasRequest{Alias: roomAlias}
var queryRes roomserverAPI.GetRoomIDForAliasResponse var queryRes roomserverAPI.GetRoomIDForAliasResponse
if err = aliasAPI.GetRoomIDForAlias(httpReq.Context(), &queryReq, &queryRes); err != nil { if err = aliasAPI.GetRoomIDForAlias(httpReq.Context(), &queryReq, &queryRes); err != nil {
return httputil.LogThenError(httpReq, err) util.GetLogger(httpReq.Context()).WithError(err).Error("aliasAPI.GetRoomIDForAlias failed")
return jsonerror.InternalServerError()
} }
if queryRes.RoomID != "" { if queryRes.RoomID != "" {
serverQueryReq := federationSenderAPI.QueryJoinedHostServerNamesInRoomRequest{RoomID: queryRes.RoomID} serverQueryReq := federationSenderAPI.QueryJoinedHostServerNamesInRoomRequest{RoomID: queryRes.RoomID}
var serverQueryRes federationSenderAPI.QueryJoinedHostServerNamesInRoomResponse var serverQueryRes federationSenderAPI.QueryJoinedHostServerNamesInRoomResponse
if err = senderAPI.QueryJoinedHostServerNamesInRoom(httpReq.Context(), &serverQueryReq, &serverQueryRes); err != nil { if err = senderAPI.QueryJoinedHostServerNamesInRoom(httpReq.Context(), &serverQueryReq, &serverQueryRes); err != nil {
return httputil.LogThenError(httpReq, err) util.GetLogger(httpReq.Context()).WithError(err).Error("senderAPI.QueryJoinedHostServerNamesInRoom failed")
return jsonerror.InternalServerError()
} }
resp = gomatrixserverlib.RespDirectory{ resp = gomatrixserverlib.RespDirectory{
@ -92,7 +93,8 @@ func RoomAliasToID(
} }
// TODO: Return 502 if the remote server errored. // TODO: Return 502 if the remote server errored.
// TODO: Return 504 if the remote server timed out. // TODO: Return 504 if the remote server timed out.
return httputil.LogThenError(httpReq, err) util.GetLogger(httpReq.Context()).WithError(err).Error("federation.LookupRoomAlias failed")
return jsonerror.InternalServerError()
} }
} }

View file

@ -131,7 +131,7 @@ func Setup(
)).Methods(http.MethodGet) )).Methods(http.MethodGet)
v1fedmux.Handle("/state/{roomID}", common.MakeFedAPI( v1fedmux.Handle("/state/{roomID}", common.MakeFedAPI(
"federation_get_event_auth", cfg.Matrix.ServerName, keys, "federation_get_state", cfg.Matrix.ServerName, keys,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse { func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse {
vars, err := common.URLDecodeMapValues(mux.Vars(httpReq)) vars, err := common.URLDecodeMapValues(mux.Vars(httpReq))
if err != nil { if err != nil {
@ -144,7 +144,7 @@ func Setup(
)).Methods(http.MethodGet) )).Methods(http.MethodGet)
v1fedmux.Handle("/state_ids/{roomID}", common.MakeFedAPI( v1fedmux.Handle("/state_ids/{roomID}", common.MakeFedAPI(
"federation_get_event_auth", cfg.Matrix.ServerName, keys, "federation_get_state_ids", cfg.Matrix.ServerName, keys,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse { func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse {
vars, err := common.URLDecodeMapValues(mux.Vars(httpReq)) vars, err := common.URLDecodeMapValues(mux.Vars(httpReq))
if err != nil { if err != nil {
@ -156,6 +156,16 @@ func Setup(
}, },
)).Methods(http.MethodGet) )).Methods(http.MethodGet)
v1fedmux.Handle("/event_auth/{roomID}/{eventID}", common.MakeFedAPI(
"federation_get_event_auth", cfg.Matrix.ServerName, keys,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse {
vars := mux.Vars(httpReq)
return GetEventAuth(
httpReq.Context(), request, query, vars["roomID"], vars["eventID"],
)
},
)).Methods(http.MethodGet)
v1fedmux.Handle("/query/directory", common.MakeFedAPI( v1fedmux.Handle("/query/directory", common.MakeFedAPI(
"federation_query_room_alias", cfg.Matrix.ServerName, keys, "federation_query_room_alias", cfg.Matrix.ServerName, keys,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse { func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse {

Some files were not shown because too many files have changed in this diff Show more