Merged latest changes from master

This commit is contained in:
Brendan Abolivier 2017-07-17 17:45:13 +01:00
commit 81fec3bd8d
14 changed files with 241 additions and 160 deletions

View file

@ -24,7 +24,6 @@ import (
"strings" "strings"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/auth/storage/devices"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -40,10 +39,16 @@ var UnknownDeviceID = "unknown-device"
// 32 bytes => 256 bits // 32 bytes => 256 bits
var tokenByteLength = 32 var tokenByteLength = 32
// DeviceDatabase represents a device database.
type DeviceDatabase interface {
// Lookup the device matching the given access token.
GetDeviceByAccessToken(token string) (*authtypes.Device, error)
}
// VerifyAccessToken verifies that an access token was supplied in the given HTTP request // VerifyAccessToken verifies that an access token was supplied in the given HTTP request
// and returns the device it corresponds to. Returns resErr (an error response which can be // and returns the device it corresponds to. Returns resErr (an error response which can be
// sent to the client) if the token is invalid or there was a problem querying the database. // sent to the client) if the token is invalid or there was a problem querying the database.
func VerifyAccessToken(req *http.Request, deviceDB *devices.Database) (device *authtypes.Device, resErr *util.JSONResponse) { func VerifyAccessToken(req *http.Request, deviceDB DeviceDatabase) (device *authtypes.Device, resErr *util.JSONResponse) {
token, err := extractAccessToken(req) token, err := extractAccessToken(req)
if err != nil { if err != nil {
resErr = &util.JSONResponse{ resErr = &util.JSONResponse{

View file

@ -18,6 +18,7 @@ import (
"database/sql" "database/sql"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -53,7 +54,7 @@ func (d *Database) GetDeviceByAccessToken(token string) (*authtypes.Device, erro
// an error will be returned. // an error will be returned.
// Returns the device on success. // Returns the device on success.
func (d *Database) CreateDevice(localpart, deviceID, accessToken string) (dev *authtypes.Device, returnErr error) { func (d *Database) CreateDevice(localpart, deviceID, accessToken string) (dev *authtypes.Device, returnErr error) {
returnErr = runTransaction(d.db, func(txn *sql.Tx) error { returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
var err error var err error
// Revoke existing token for this device // Revoke existing token for this device
if err = d.devices.deleteDevice(txn, deviceID, localpart); err != nil { if err = d.devices.deleteDevice(txn, deviceID, localpart); err != nil {
@ -74,30 +75,10 @@ func (d *Database) CreateDevice(localpart, deviceID, accessToken string) (dev *a
// If the device doesn't exist, it will not return an error // If the device doesn't exist, it will not return an error
// If something went wrong during the deletion, it will return the SQL error // If something went wrong during the deletion, it will return the SQL error
func (d *Database) RemoveDevice(deviceID string, localpart string) error { func (d *Database) RemoveDevice(deviceID string, localpart string) error {
return runTransaction(d.db, func(txn *sql.Tx) error { return common.WithTransaction(d.db, func(txn *sql.Tx) error {
if err := d.devices.deleteDevice(txn, deviceID, localpart); err != sql.ErrNoRows { if err := d.devices.deleteDevice(txn, deviceID, localpart); err != sql.ErrNoRows {
return err return err
} }
return nil return nil
}) })
} }
// TODO: factor out to common
func runTransaction(db *sql.DB, fn func(txn *sql.Tx) error) (err error) {
txn, err := db.Begin()
if err != nil {
return
}
defer func() {
if r := recover(); r != nil {
txn.Rollback()
panic(r)
} else if err != nil {
txn.Rollback()
} else {
err = txn.Commit()
}
}()
err = fn(txn)
return
}

View file

@ -48,7 +48,7 @@ func (c *RoomserverProducer) SendEvents(events []gomatrixserverlib.Event, sendAs
for i, event := range events { for i, event := range events {
ires[i] = api.InputRoomEvent{ ires[i] = api.InputRoomEvent{
Kind: api.KindNew, Kind: api.KindNew,
Event: event.JSON(), Event: event,
AuthEventIDs: event.AuthEventIDs(), AuthEventIDs: event.AuthEventIDs(),
SendAsServer: string(sendAsServer), SendAsServer: string(sendAsServer),
} }
@ -70,7 +70,7 @@ func (c *RoomserverProducer) SendEventWithState(state gomatrixserverlib.RespStat
for i, outlier := range outliers { for i, outlier := range outliers {
ires[i] = api.InputRoomEvent{ ires[i] = api.InputRoomEvent{
Kind: api.KindOutlier, Kind: api.KindOutlier,
Event: outlier.JSON(), Event: outlier,
AuthEventIDs: outlier.AuthEventIDs(), AuthEventIDs: outlier.AuthEventIDs(),
} }
eventIDs[i] = outlier.EventID() eventIDs[i] = outlier.EventID()
@ -83,7 +83,7 @@ func (c *RoomserverProducer) SendEventWithState(state gomatrixserverlib.RespStat
ires[len(outliers)] = api.InputRoomEvent{ ires[len(outliers)] = api.InputRoomEvent{
Kind: api.KindNew, Kind: api.KindNew,
Event: event.JSON(), Event: event,
AuthEventIDs: event.AuthEventIDs(), AuthEventIDs: event.AuthEventIDs(),
HasState: true, HasState: true,
StateEventIDs: stateEventIDs, StateEventIDs: stateEventIDs,

View file

@ -21,12 +21,13 @@ import (
"encoding/json" "encoding/json"
"flag" "flag"
"fmt" "fmt"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib"
"golang.org/x/crypto/ed25519"
"os" "os"
"strings" "strings"
"time" "time"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib"
"golang.org/x/crypto/ed25519"
) )
const usage = `Usage: %s const usage = `Usage: %s
@ -131,7 +132,7 @@ func writeEvent(event gomatrixserverlib.Event) {
if *format == "InputRoomEvent" { if *format == "InputRoomEvent" {
var ire api.InputRoomEvent var ire api.InputRoomEvent
ire.Kind = api.KindNew ire.Kind = api.KindNew
ire.Event = event.JSON() ire.Event = event
authEventIDs := []string{} authEventIDs := []string{}
for _, ref := range b.AuthEvents { for _, ref := range b.AuthEvents {
authEventIDs = append(authEventIDs, ref.EventID) authEventIDs = append(authEventIDs, ref.EventID)

View file

@ -252,9 +252,9 @@ func main() {
input := []string{ input := []string{
`{ `{
"AuthEventIDs": [], "auth_event_ids": [],
"Kind": 1, "kind": 1,
"Event": { "event": {
"origin": "matrix.org", "origin": "matrix.org",
"signatures": { "signatures": {
"matrix.org": { "matrix.org": {
@ -274,10 +274,10 @@ func main() {
"hashes": {"sha256": "Q05VLC8nztN2tguy+KnHxxhitI95wK9NelnsDaXRqeo"}, "hashes": {"sha256": "Q05VLC8nztN2tguy+KnHxxhitI95wK9NelnsDaXRqeo"},
"type": "m.room.create"} "type": "m.room.create"}
}`, `{ }`, `{
"AuthEventIDs": ["$1463671337126266wrSBX:matrix.org"], "auth_event_ids": ["$1463671337126266wrSBX:matrix.org"],
"Kind": 2, "kind": 2,
"StateEventIDs": ["$1463671337126266wrSBX:matrix.org"], "state_event_ids": ["$1463671337126266wrSBX:matrix.org"],
"Event": { "event": {
"origin": "matrix.org", "origin": "matrix.org",
"signatures": { "signatures": {
"matrix.org": { "matrix.org": {
@ -305,7 +305,7 @@ func main() {
]], ]],
"hashes": {"sha256": "t9t3sZV1Eu0P9Jyrs7pge6UTa1zuTbRdVxeUHnrQVH0"}, "hashes": {"sha256": "t9t3sZV1Eu0P9Jyrs7pge6UTa1zuTbRdVxeUHnrQVH0"},
"type": "m.room.member"}, "type": "m.room.member"},
"HasState": true "has_state": true
}`, }`,
} }

View file

@ -1,16 +1,16 @@
package common package common
import ( import (
"net/http"
"github.com/matrix-org/dendrite/clientapi/auth" "github.com/matrix-org/dendrite/clientapi/auth"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/auth/storage/devices"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"net/http"
) )
// MakeAuthAPI turns a util.JSONRequestHandler function into an http.Handler which checks the access token in the request. // MakeAuthAPI turns a util.JSONRequestHandler function into an http.Handler which checks the access token in the request.
func MakeAuthAPI(metricsName string, deviceDB *devices.Database, f func(*http.Request, *authtypes.Device) util.JSONResponse) http.Handler { func MakeAuthAPI(metricsName string, deviceDB auth.DeviceDatabase, f func(*http.Request, *authtypes.Device) util.JSONResponse) http.Handler {
h := util.NewJSONRequestHandler(func(req *http.Request) util.JSONResponse { h := util.NewJSONRequestHandler(func(req *http.Request) util.JSONResponse {
device, resErr := auth.VerifyAccessToken(req, deviceDB) device, resErr := auth.VerifyAccessToken(req, deviceDB)
if resErr != nil { if resErr != nil {

View file

@ -0,0 +1,41 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package common
import (
"database/sql"
)
// WithTransaction runs a block of code passing in an SQL transaction
// If the code returns an error or panics then the transactions is rolledback
// Otherwise the transaction is committed.
func WithTransaction(db *sql.DB, fn func(txn *sql.Tx) error) (err error) {
txn, err := db.Begin()
if err != nil {
return
}
defer func() {
if r := recover(); r != nil {
txn.Rollback()
panic(r)
} else if err != nil {
txn.Rollback()
} else {
err = txn.Commit()
}
}()
err = fn(txn)
return
}

View file

@ -77,7 +77,7 @@ func (d *Database) UpdateRoom(
addHosts []types.JoinedHost, addHosts []types.JoinedHost,
removeHosts []string, removeHosts []string,
) (joinedHosts []types.JoinedHost, err error) { ) (joinedHosts []types.JoinedHost, err error) {
err = runTransaction(d.db, func(txn *sql.Tx) error { err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
if err = d.insertRoom(txn, roomID); err != nil { if err = d.insertRoom(txn, roomID); err != nil {
return err return err
} }
@ -105,22 +105,3 @@ func (d *Database) UpdateRoom(
}) })
return return
} }
func runTransaction(db *sql.DB, fn func(txn *sql.Tx) error) (err error) {
txn, err := db.Begin()
if err != nil {
return
}
defer func() {
if r := recover(); r != nil {
txn.Rollback()
panic(r)
} else if err != nil {
txn.Rollback()
} else {
err = txn.Commit()
}
}()
err = fn(txn)
return
}

View file

@ -16,7 +16,9 @@
package api package api
import ( import (
"encoding/json" "net/http"
"github.com/matrix-org/gomatrixserverlib"
) )
const ( const (
@ -25,19 +27,14 @@ const (
// These events are state events used to authenticate other events. // These events are state events used to authenticate other events.
// They can become part of the contiguous event graph via backfill. // They can become part of the contiguous event graph via backfill.
KindOutlier = 1 KindOutlier = 1
// KindJoin event start a new contiguous event graph. The event must be a
// m.room.member event joining this server to the room. This must come with
// the state at the event. If the event is contiguous with the existing
// graph for the room then it is treated as a normal new event.
KindJoin = 2
// KindNew event extend the contiguous graph going forwards. // KindNew event extend the contiguous graph going forwards.
// They usually don't need state, but may include state if the // They usually don't need state, but may include state if the
// there was a new event that references an event that we don't // there was a new event that references an event that we don't
// have a copy of. // have a copy of.
KindNew = 3 KindNew = 2
// KindBackfill event extend the contiguous graph going backwards. // KindBackfill event extend the contiguous graph going backwards.
// They always have state. // They always have state.
KindBackfill = 4 KindBackfill = 3
) )
// DoNotSendToOtherServers tells us not to send the event to other matrix // DoNotSendToOtherServers tells us not to send the event to other matrix
@ -49,77 +46,66 @@ const DoNotSendToOtherServers = ""
type InputRoomEvent struct { type InputRoomEvent struct {
// Whether this event is new, backfilled or an outlier. // Whether this event is new, backfilled or an outlier.
// This controls how the event is processed. // This controls how the event is processed.
Kind int Kind int `json:"kind"`
// The event JSON for the event to add. // The event JSON for the event to add.
Event []byte Event gomatrixserverlib.Event `json:"event"`
// List of state event IDs that authenticate this event. // List of state event IDs that authenticate this event.
// These are likely derived from the "auth_events" JSON key of the event. // These are likely derived from the "auth_events" JSON key of the event.
// But can be different because the "auth_events" key can be incomplete or wrong. // But can be different because the "auth_events" key can be incomplete or wrong.
// For example many matrix events forget to reference the m.room.create event even though it is needed for auth. // For example many matrix events forget to reference the m.room.create event even though it is needed for auth.
// (since synapse allows this to happen we have to allow it as well.) // (since synapse allows this to happen we have to allow it as well.)
AuthEventIDs []string AuthEventIDs []string `json:"auth_event_ids"`
// Whether the state is supplied as a list of event IDs or whether it // Whether the state is supplied as a list of event IDs or whether it
// should be derived from the state at the previous events. // should be derived from the state at the previous events.
HasState bool HasState bool `json:"has_state"`
// Optional list of state event IDs forming the state before this event. // Optional list of state event IDs forming the state before this event.
// These state events must have already been persisted. // These state events must have already been persisted.
// These are only used if HasState is true. // These are only used if HasState is true.
// The list can be empty, for example when storing the first event in a room. // The list can be empty, for example when storing the first event in a room.
StateEventIDs []string StateEventIDs []string `json:"state_event_ids"`
// The server name to use to push this event to other servers. // The server name to use to push this event to other servers.
// Or empty if this event shouldn't be pushed to other servers. // Or empty if this event shouldn't be pushed to other servers.
SendAsServer string SendAsServer string `json:"send_as_server"`
} }
// UnmarshalJSON implements json.Unmarshaller // InputRoomEventsRequest is a request to InputRoomEvents
func (ire *InputRoomEvent) UnmarshalJSON(data []byte) error { type InputRoomEventsRequest struct {
// Create a struct rather than unmarshalling directly into the InputRoomEvent InputRoomEvents []InputRoomEvent `json:"input_room_events"`
// so that we can use json.RawMessage.
// We use json.RawMessage so that the event JSON is sent as JSON rather than
// being base64 encoded which is the default for []byte.
var content struct {
Kind int
Event *json.RawMessage
AuthEventIDs []string
StateEventIDs []string
HasState bool
SendAsServer string
}
if err := json.Unmarshal(data, &content); err != nil {
return err
}
ire.Kind = content.Kind
ire.AuthEventIDs = content.AuthEventIDs
ire.StateEventIDs = content.StateEventIDs
ire.HasState = content.HasState
ire.SendAsServer = content.SendAsServer
if content.Event != nil {
ire.Event = []byte(*content.Event)
}
return nil
} }
// MarshalJSON implements json.Marshaller // InputRoomEventsResponse is a response to InputRoomEvents
func (ire InputRoomEvent) MarshalJSON() ([]byte, error) { type InputRoomEventsResponse struct{}
// Create a struct rather than marshalling directly from the InputRoomEvent
// so that we can use json.RawMessage. // RoomserverInputAPI is used to write events to the room server.
// We use json.RawMessage so that the event JSON is sent as JSON rather than type RoomserverInputAPI interface {
// being base64 encoded which is the default for []byte. InputRoomEvents(
event := json.RawMessage(ire.Event) request *InputRoomEventsRequest,
content := struct { response *InputRoomEventsResponse,
Kind int ) error
Event *json.RawMessage }
AuthEventIDs []string
StateEventIDs []string // RoomserverInputRoomEventsPath is the HTTP path for the InputRoomEvents API.
HasState bool const RoomserverInputRoomEventsPath = "/api/roomserver/inputRoomEvents"
SendAsServer string
}{ // NewRoomserverInputAPIHTTP creates a RoomserverInputAPI implemented by talking to a HTTP POST API.
Kind: ire.Kind, // If httpClient is nil then it uses the http.DefaultClient
AuthEventIDs: ire.AuthEventIDs, func NewRoomserverInputAPIHTTP(roomserverURL string, httpClient *http.Client) RoomserverInputAPI {
StateEventIDs: ire.StateEventIDs, if httpClient == nil {
Event: &event, httpClient = http.DefaultClient
HasState: ire.HasState, }
SendAsServer: ire.SendAsServer, return &httpRoomserverInputAPI{roomserverURL, httpClient}
} }
return json.Marshal(&content)
type httpRoomserverInputAPI struct {
roomserverURL string
httpClient *http.Client
}
// InputRoomEvents implements RoomserverInputAPI
func (h *httpRoomserverInputAPI) InputRoomEvents(
request *InputRoomEventsRequest,
response *InputRoomEventsResponse,
) error {
apiURL := h.roomserverURL + RoomserverInputRoomEventsPath
return postJSON(h.httpClient, apiURL, request, response)
} }

View file

@ -136,12 +136,12 @@ func NewRoomserverQueryAPIHTTP(roomserverURL string, httpClient *http.Client) Ro
if httpClient == nil { if httpClient == nil {
httpClient = http.DefaultClient httpClient = http.DefaultClient
} }
return &httpRoomserverQueryAPI{roomserverURL, *httpClient} return &httpRoomserverQueryAPI{roomserverURL, httpClient}
} }
type httpRoomserverQueryAPI struct { type httpRoomserverQueryAPI struct {
roomserverURL string roomserverURL string
httpClient http.Client httpClient *http.Client
} }
// QueryLatestEventsAndState implements RoomserverQueryAPI // QueryLatestEventsAndState implements RoomserverQueryAPI
@ -171,7 +171,7 @@ func (h *httpRoomserverQueryAPI) QueryEventsByID(
return postJSON(h.httpClient, apiURL, request, response) return postJSON(h.httpClient, apiURL, request, response)
} }
func postJSON(httpClient http.Client, apiURL string, request, response interface{}) error { func postJSON(httpClient *http.Client, apiURL string, request, response interface{}) error {
jsonBytes, err := json.Marshal(request) jsonBytes, err := json.Marshal(request)
if err != nil { if err != nil {
return err return err

View file

@ -49,10 +49,7 @@ type OutputRoomEventWriter interface {
func processRoomEvent(db RoomEventDatabase, ow OutputRoomEventWriter, input api.InputRoomEvent) error { func processRoomEvent(db RoomEventDatabase, ow OutputRoomEventWriter, input api.InputRoomEvent) error {
// Parse and validate the event JSON // Parse and validate the event JSON
event, err := gomatrixserverlib.NewEventFromUntrustedJSON(input.Event) event := input.Event
if err != nil {
return err
}
// Check that the event passes authentication checks and work out the numeric IDs for the auth events. // Check that the event passes authentication checks and work out the numeric IDs for the auth events.
authEventNIDs, err := checkAuthEvents(db, event, input.AuthEventIDs) authEventNIDs, err := checkAuthEvents(db, event, input.AuthEventIDs)
@ -79,8 +76,8 @@ func processRoomEvent(db RoomEventDatabase, ow OutputRoomEventWriter, input api.
if input.HasState { if input.HasState {
// We've been told what the state at the event is so we don't need to calculate it. // We've been told what the state at the event is so we don't need to calculate it.
// Check that those state events are in the database and store the state. // Check that those state events are in the database and store the state.
entries, err := db.StateEntriesForEventIDs(input.StateEventIDs) var entries []types.StateEntry
if err != nil { if entries, err = db.StateEntriesForEventIDs(input.StateEventIDs); err != nil {
return err return err
} }

View file

@ -0,0 +1,107 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package input contains the code processes new room events
package input
import (
"encoding/json"
"fmt"
"sync/atomic"
"net/http"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/util"
sarama "gopkg.in/Shopify/sarama.v1"
)
// RoomserverInputAPI implements api.RoomserverInputAPI
type RoomserverInputAPI struct {
DB RoomEventDatabase
Producer sarama.SyncProducer
// The kafkaesque topic to output new room events to.
// This is the name used in kafka to identify the stream to write events to.
OutputRoomEventTopic string
// If non-nil then the API will stop processing messages after this
// many messages and will shutdown. Malformed messages are not in the count.
StopProcessingAfter *int64
// If not-nil then the API will call this to shutdown the server.
// If this is nil then the API will continue to process messsages even
// though StopProcessingAfter has been reached.
ShutdownCallback func(reason string)
// How many messages the consumer has processed.
processed int64
}
// WriteOutputRoomEvent implements OutputRoomEventWriter
func (r *RoomserverInputAPI) WriteOutputRoomEvent(output api.OutputNewRoomEvent) error {
var m sarama.ProducerMessage
oe := api.OutputEvent{
Type: api.OutputTypeNewRoomEvent,
NewRoomEvent: &output,
}
value, err := json.Marshal(oe)
if err != nil {
return err
}
m.Topic = r.OutputRoomEventTopic
m.Key = sarama.StringEncoder("")
m.Value = sarama.ByteEncoder(value)
_, _, err = r.Producer.SendMessage(&m)
return err
}
// InputRoomEvents implements api.RoomserverInputAPI
func (r *RoomserverInputAPI) InputRoomEvents(
request *api.InputRoomEventsRequest,
response *api.InputRoomEventsResponse,
) error {
for i := range request.InputRoomEvents {
if err := processRoomEvent(r.DB, r, request.InputRoomEvents[i]); err != nil {
return err
}
// Update the number of processed messages using atomic addition because it is accessed from multiple goroutines.
processed := atomic.AddInt64(&r.processed, 1)
// Check if we should stop processing.
// Note that since we have multiple goroutines it's quite likely that we'll overshoot by a few messages.
// If we try to stop processing after M message and we have N goroutines then we will process somewhere
// between M and (N + M) messages because the N goroutines could all try to process what they think will be the
// last message. We could be more careful here but this is good enough for getting rough benchmarks.
if r.StopProcessingAfter != nil && processed >= int64(*r.StopProcessingAfter) {
if r.ShutdownCallback != nil {
r.ShutdownCallback(fmt.Sprintf("Stopping processing after %d messages", r.processed))
}
}
}
return nil
}
// SetupHTTP adds the RoomserverInputAPI handlers to the http.ServeMux.
func (r *RoomserverInputAPI) SetupHTTP(servMux *http.ServeMux) {
servMux.Handle(api.RoomserverInputRoomEventsPath,
common.MakeAPI("inputRoomEvents", func(req *http.Request) util.JSONResponse {
var request api.InputRoomEventsRequest
var response api.InputRoomEventsResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(400, err.Error())
}
if err := r.InputRoomEvents(&request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: 200, JSON: &response}
}),
)
}

View file

@ -16,13 +16,14 @@ package query
import ( import (
"encoding/json" "encoding/json"
"net/http"
"github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/state"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"net/http"
) )
// RoomserverQueryAPIDatabase has the storage APIs needed to implement the query API. // RoomserverQueryAPIDatabase has the storage APIs needed to implement the query API.
@ -173,7 +174,7 @@ func (r *RoomserverQueryAPI) loadEvents(eventNIDs []types.EventNID) ([]gomatrixs
func (r *RoomserverQueryAPI) SetupHTTP(servMux *http.ServeMux) { func (r *RoomserverQueryAPI) SetupHTTP(servMux *http.ServeMux) {
servMux.Handle( servMux.Handle(
api.RoomserverQueryLatestEventsAndStatePath, api.RoomserverQueryLatestEventsAndStatePath,
common.MakeAPI("query_latest_events_and_state", func(req *http.Request) util.JSONResponse { common.MakeAPI("queryLatestEventsAndState", func(req *http.Request) util.JSONResponse {
var request api.QueryLatestEventsAndStateRequest var request api.QueryLatestEventsAndStateRequest
var response api.QueryLatestEventsAndStateResponse var response api.QueryLatestEventsAndStateResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil { if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
@ -187,7 +188,7 @@ func (r *RoomserverQueryAPI) SetupHTTP(servMux *http.ServeMux) {
) )
servMux.Handle( servMux.Handle(
api.RoomserverQueryStateAfterEventsPath, api.RoomserverQueryStateAfterEventsPath,
common.MakeAPI("query_state_after_events", func(req *http.Request) util.JSONResponse { common.MakeAPI("queryStateAfterEvents", func(req *http.Request) util.JSONResponse {
var request api.QueryStateAfterEventsRequest var request api.QueryStateAfterEventsRequest
var response api.QueryStateAfterEventsResponse var response api.QueryStateAfterEventsResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil { if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
@ -201,7 +202,7 @@ func (r *RoomserverQueryAPI) SetupHTTP(servMux *http.ServeMux) {
) )
servMux.Handle( servMux.Handle(
api.RoomserverQueryEventsByIDPath, api.RoomserverQueryEventsByIDPath,
common.MakeAPI("query_events_by_id", func(req *http.Request) util.JSONResponse { common.MakeAPI("queryEventsByID", func(req *http.Request) util.JSONResponse {
var request api.QueryEventsByIDRequest var request api.QueryEventsByIDRequest
var response api.QueryEventsByIDResponse var response api.QueryEventsByIDResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil { if err := json.NewDecoder(req.Body).Decode(&request); err != nil {

View file

@ -92,7 +92,7 @@ func (d *SyncServerDatabase) Events(eventIDs []string) ([]gomatrixserverlib.Even
func (d *SyncServerDatabase) WriteEvent( func (d *SyncServerDatabase) WriteEvent(
ev *gomatrixserverlib.Event, addStateEvents []gomatrixserverlib.Event, addStateEventIDs, removeStateEventIDs []string, ev *gomatrixserverlib.Event, addStateEvents []gomatrixserverlib.Event, addStateEventIDs, removeStateEventIDs []string,
) (streamPos types.StreamPosition, returnErr error) { ) (streamPos types.StreamPosition, returnErr error) {
returnErr = runTransaction(d.db, func(txn *sql.Tx) error { returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
var err error var err error
pos, err := d.events.insertEvent(txn, ev, addStateEventIDs, removeStateEventIDs) pos, err := d.events.insertEvent(txn, ev, addStateEventIDs, removeStateEventIDs)
if err != nil { if err != nil {
@ -162,7 +162,7 @@ func (d *SyncServerDatabase) SyncStreamPosition() (types.StreamPosition, error)
// IncrementalSync returns all the data needed in order to create an incremental sync response. // IncrementalSync returns all the data needed in order to create an incremental sync response.
func (d *SyncServerDatabase) IncrementalSync(userID string, fromPos, toPos types.StreamPosition, numRecentEventsPerRoom int) (res *types.Response, returnErr error) { func (d *SyncServerDatabase) IncrementalSync(userID string, fromPos, toPos types.StreamPosition, numRecentEventsPerRoom int) (res *types.Response, returnErr error) {
returnErr = runTransaction(d.db, func(txn *sql.Tx) error { returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
// Work out which rooms to return in the response. This is done by getting not only the currently // Work out which rooms to return in the response. This is done by getting not only the currently
// joined rooms, but also which rooms have membership transitions for this user between the 2 stream positions. // joined rooms, but also which rooms have membership transitions for this user between the 2 stream positions.
// This works out what the 'state' key should be for each room as well as which membership block // This works out what the 'state' key should be for each room as well as which membership block
@ -223,7 +223,7 @@ func (d *SyncServerDatabase) CompleteSync(userID string, numRecentEventsPerRoom
// a consistent view of the database throughout. This includes extracting the sync stream position. // a consistent view of the database throughout. This includes extracting the sync stream position.
// This does have the unfortunate side-effect that all the matrixy logic resides in this function, // This does have the unfortunate side-effect that all the matrixy logic resides in this function,
// but it's better to not hide the fact that this is being done in a transaction. // but it's better to not hide the fact that this is being done in a transaction.
returnErr = runTransaction(d.db, func(txn *sql.Tx) error { returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
// Get the current stream position which we will base the sync response on. // Get the current stream position which we will base the sync response on.
id, err := d.events.selectMaxID(txn) id, err := d.events.selectMaxID(txn)
if err != nil { if err != nil {
@ -479,22 +479,3 @@ func getMembershipFromEvent(ev *gomatrixserverlib.Event, userID string) string {
} }
return "" return ""
} }
func runTransaction(db *sql.DB, fn func(txn *sql.Tx) error) (err error) {
txn, err := db.Begin()
if err != nil {
return
}
defer func() {
if r := recover(); r != nil {
txn.Rollback()
panic(r)
} else if err != nil {
txn.Rollback()
} else {
err = txn.Commit()
}
}()
err = fn(txn)
return
}