mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-07 06:53:09 -06:00
Merge branch 'master' into kegan/http-auth
This commit is contained in:
commit
2a86ae1833
15
.travis.yml
15
.travis.yml
|
|
@ -1,13 +1,26 @@
|
||||||
language: go
|
language: go
|
||||||
go:
|
go:
|
||||||
- 1.7
|
- 1.7
|
||||||
|
|
||||||
|
sudo: false
|
||||||
|
|
||||||
|
# Use trusty for postgres 9.5 support
|
||||||
|
dist: trusty
|
||||||
|
|
||||||
|
addons:
|
||||||
|
postgresql: "9.5"
|
||||||
|
|
||||||
|
services:
|
||||||
|
- postgresql
|
||||||
|
|
||||||
install:
|
install:
|
||||||
- go get github.com/constabulary/gb/...
|
- go get github.com/constabulary/gb/...
|
||||||
- go get github.com/golang/lint/golint
|
- go get github.com/golang/lint/golint
|
||||||
- go get github.com/fzipp/gocyclo
|
- go get github.com/fzipp/gocyclo
|
||||||
|
- ./travis-install-kafka.sh
|
||||||
|
|
||||||
script:
|
script:
|
||||||
- gb build github.com/matrix-org/dendrite/roomserver/roomserver && ./hooks/pre-commit
|
- ./travis-test.sh
|
||||||
|
|
||||||
notifications:
|
notifications:
|
||||||
webhooks:
|
webhooks:
|
||||||
|
|
|
||||||
|
|
@ -37,7 +37,7 @@ complications for this model.
|
||||||
3) A client can query the current state of the room from a reader.
|
3) A client can query the current state of the room from a reader.
|
||||||
|
|
||||||
The writers and readers cannot extract the necessary information directly from
|
The writers and readers cannot extract the necessary information directly from
|
||||||
the event logs because it would take to long to extract the information as the
|
the event logs because it would take too long to extract the information as the
|
||||||
state is built up by collecting individual state events from the event history.
|
state is built up by collecting individual state events from the event history.
|
||||||
|
|
||||||
The writers and readers therefore need access to something that stores copies
|
The writers and readers therefore need access to something that stores copies
|
||||||
|
|
|
||||||
85
src/github.com/matrix-org/dendrite/roomserver/api/output.go
Normal file
85
src/github.com/matrix-org/dendrite/roomserver/api/output.go
Normal file
|
|
@ -0,0 +1,85 @@
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
)
|
||||||
|
|
||||||
|
// An OutputRoomEvent is written when the roomserver receives a new event.
|
||||||
|
type OutputRoomEvent struct {
|
||||||
|
// The JSON bytes of the event.
|
||||||
|
Event []byte
|
||||||
|
// The state event IDs needed to determine who can see this event.
|
||||||
|
// This can be used to tell which users to send the event to.
|
||||||
|
VisibilityEventIDs []string
|
||||||
|
// The latest events in the room after this event.
|
||||||
|
// This can be used to set the prev events for new events in the room.
|
||||||
|
// This also can be used to get the full current state after this event.
|
||||||
|
LatestEventIDs []string
|
||||||
|
// The state event IDs that were added to the state of the room by this event.
|
||||||
|
// Together with RemovesStateEventIDs this allows the receiver to keep an up to date
|
||||||
|
// view of the current state of the room.
|
||||||
|
AddsStateEventIDs []string
|
||||||
|
// The state event IDs that were removed from the state of the room by this event.
|
||||||
|
RemovesStateEventIDs []string
|
||||||
|
// The ID of the event that was output before this event.
|
||||||
|
// Or the empty string if this is the first event output for this room.
|
||||||
|
// This is used by consumers to check if they can safely update their
|
||||||
|
// current state using the delta supplied in AddsStateEventIDs and
|
||||||
|
// RemovesStateEventIDs.
|
||||||
|
// If the LastSentEventID doesn't match what they were expecting it to be
|
||||||
|
// they can use the LatestEventIDs to request the full current state.
|
||||||
|
LastSentEventID string
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements json.Unmarshaller
|
||||||
|
func (ore *OutputRoomEvent) UnmarshalJSON(data []byte) error {
|
||||||
|
// Create a struct rather than unmarshalling directly into the OutputRoomEvent
|
||||||
|
// so that we can use json.RawMessage.
|
||||||
|
// We use json.RawMessage so that the event JSON is sent as JSON rather than
|
||||||
|
// being base64 encoded which is the default for []byte.
|
||||||
|
var content struct {
|
||||||
|
Event *json.RawMessage
|
||||||
|
VisibilityEventIDs []string
|
||||||
|
LatestEventIDs []string
|
||||||
|
AddsStateEventIDs []string
|
||||||
|
RemovesStateEventIDs []string
|
||||||
|
LastSentEventID string
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &content); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if content.Event != nil {
|
||||||
|
ore.Event = []byte(*content.Event)
|
||||||
|
}
|
||||||
|
ore.VisibilityEventIDs = content.VisibilityEventIDs
|
||||||
|
ore.LatestEventIDs = content.LatestEventIDs
|
||||||
|
ore.AddsStateEventIDs = content.AddsStateEventIDs
|
||||||
|
ore.RemovesStateEventIDs = content.RemovesStateEventIDs
|
||||||
|
ore.LastSentEventID = content.LastSentEventID
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements json.Marshaller
|
||||||
|
func (ore OutputRoomEvent) MarshalJSON() ([]byte, error) {
|
||||||
|
// Create a struct rather than marshalling directly from the OutputRoomEvent
|
||||||
|
// so that we can use json.RawMessage.
|
||||||
|
// We use json.RawMessage so that the event JSON is sent as JSON rather than
|
||||||
|
// being base64 encoded which is the default for []byte.
|
||||||
|
event := json.RawMessage(ore.Event)
|
||||||
|
content := struct {
|
||||||
|
Event *json.RawMessage
|
||||||
|
VisibilityEventIDs []string
|
||||||
|
LatestEventIDs []string
|
||||||
|
AddsStateEventIDs []string
|
||||||
|
RemovesStateEventIDs []string
|
||||||
|
LastSentEventID string
|
||||||
|
}{
|
||||||
|
Event: &event,
|
||||||
|
VisibilityEventIDs: ore.VisibilityEventIDs,
|
||||||
|
LatestEventIDs: ore.LatestEventIDs,
|
||||||
|
AddsStateEventIDs: ore.AddsStateEventIDs,
|
||||||
|
RemovesStateEventIDs: ore.RemovesStateEventIDs,
|
||||||
|
LastSentEventID: ore.LastSentEventID,
|
||||||
|
}
|
||||||
|
return json.Marshal(&content)
|
||||||
|
}
|
||||||
102
src/github.com/matrix-org/dendrite/roomserver/api/query.go
Normal file
102
src/github.com/matrix-org/dendrite/roomserver/api/query.go
Normal file
|
|
@ -0,0 +1,102 @@
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
// StateKeyTuple is a pair of an event type and state_key.
|
||||||
|
// This is used when requesting parts of the state of a room.
|
||||||
|
type StateKeyTuple struct {
|
||||||
|
// The "type" key
|
||||||
|
EventType string
|
||||||
|
// The "state_key" of a matrix event.
|
||||||
|
// The empty string is a legitimate value for the "state_key" in matrix
|
||||||
|
// so take care to initialise this field lest you accidentally request a
|
||||||
|
// "state_key" with the go default of the empty string.
|
||||||
|
EventStateKey string
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryLatestEventsAndStateRequest is a request to QueryLatestEventsAndState
|
||||||
|
type QueryLatestEventsAndStateRequest struct {
|
||||||
|
// The roomID to query the latest events for.
|
||||||
|
RoomID string
|
||||||
|
// The state key tuples to fetch from the room current state.
|
||||||
|
// If this list is empty or nil then no state events are returned.
|
||||||
|
StateToFetch []StateKeyTuple
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryLatestEventsAndStateResponse is a response to QueryLatestEventsAndState
|
||||||
|
type QueryLatestEventsAndStateResponse struct {
|
||||||
|
// Copy of the request for debugging.
|
||||||
|
QueryLatestEventsAndStateRequest
|
||||||
|
// Does the room exist?
|
||||||
|
// If the room doesn't exist this will be false and LatestEvents will be empty.
|
||||||
|
RoomExists bool
|
||||||
|
// The latest events in the room.
|
||||||
|
LatestEvents []gomatrixserverlib.EventReference
|
||||||
|
// The state events requested.
|
||||||
|
StateEvents []gomatrixserverlib.Event
|
||||||
|
}
|
||||||
|
|
||||||
|
// RoomserverQueryAPI is used to query information from the room server.
|
||||||
|
type RoomserverQueryAPI interface {
|
||||||
|
// Query the latest events and state for a room from the room server.
|
||||||
|
QueryLatestEventsAndState(
|
||||||
|
request *QueryLatestEventsAndStateRequest,
|
||||||
|
response *QueryLatestEventsAndStateResponse,
|
||||||
|
) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// RoomserverQueryLatestEventsAndStatePath is the HTTP path for the QueryLatestEventsAndState API.
|
||||||
|
const RoomserverQueryLatestEventsAndStatePath = "/api/roomserver/QueryLatestEventsAndState"
|
||||||
|
|
||||||
|
// NewRoomserverQueryAPIHTTP creates a RoomserverQueryAPI implemented by talking to a HTTP POST API.
|
||||||
|
// If httpClient is nil then it uses the http.DefaultClient
|
||||||
|
func NewRoomserverQueryAPIHTTP(roomserverURL string, httpClient *http.Client) RoomserverQueryAPI {
|
||||||
|
if httpClient == nil {
|
||||||
|
httpClient = http.DefaultClient
|
||||||
|
}
|
||||||
|
return &httpRoomserverQueryAPI{roomserverURL, *httpClient}
|
||||||
|
}
|
||||||
|
|
||||||
|
type httpRoomserverQueryAPI struct {
|
||||||
|
roomserverURL string
|
||||||
|
httpClient http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryLatestEventsAndState implements RoomserverQueryAPI
|
||||||
|
func (h *httpRoomserverQueryAPI) QueryLatestEventsAndState(
|
||||||
|
request *QueryLatestEventsAndStateRequest,
|
||||||
|
response *QueryLatestEventsAndStateResponse,
|
||||||
|
) error {
|
||||||
|
apiURL := h.roomserverURL + RoomserverQueryLatestEventsAndStatePath
|
||||||
|
return postJSON(h.httpClient, apiURL, request, response)
|
||||||
|
}
|
||||||
|
|
||||||
|
func postJSON(httpClient http.Client, apiURL string, request, response interface{}) error {
|
||||||
|
jsonBytes, err := json.Marshal(request)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
res, err := httpClient.Post(apiURL, "application/json", bytes.NewReader(jsonBytes))
|
||||||
|
if res != nil {
|
||||||
|
defer res.Body.Close()
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if res.StatusCode != 200 {
|
||||||
|
var errorBody struct {
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
if err = json.NewDecoder(res.Body).Decode(&errorBody); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return fmt.Errorf("api: %d: %s", res.StatusCode, errorBody.Message)
|
||||||
|
}
|
||||||
|
return json.NewDecoder(res.Body).Decode(response)
|
||||||
|
}
|
||||||
|
|
@ -34,15 +34,33 @@ type Consumer struct {
|
||||||
// But any equivalent event streaming protocol could be made to implement the same interface.
|
// But any equivalent event streaming protocol could be made to implement the same interface.
|
||||||
Consumer sarama.Consumer
|
Consumer sarama.Consumer
|
||||||
// The database used to store the room events.
|
// The database used to store the room events.
|
||||||
DB ConsumerDatabase
|
DB ConsumerDatabase
|
||||||
|
Producer sarama.SyncProducer
|
||||||
// The kafkaesque topic to consume room events from.
|
// The kafkaesque topic to consume room events from.
|
||||||
// This is the name used in kafka to identify the stream to consume events from.
|
// This is the name used in kafka to identify the stream to consume events from.
|
||||||
RoomEventTopic string
|
InputRoomEventTopic string
|
||||||
|
// The kafkaesque topic to output new room events to.
|
||||||
|
// This is the name used in kafka to identify the stream to write events to.
|
||||||
|
OutputRoomEventTopic string
|
||||||
// The ErrorLogger for this consumer.
|
// The ErrorLogger for this consumer.
|
||||||
// If left as nil then the consumer will panic when it encounters an error
|
// If left as nil then the consumer will panic when it encounters an error
|
||||||
ErrorLogger ErrorLogger
|
ErrorLogger ErrorLogger
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WriteOutputRoomEvent implements OutputRoomEventWriter
|
||||||
|
func (c *Consumer) WriteOutputRoomEvent(output api.OutputRoomEvent) error {
|
||||||
|
var m sarama.ProducerMessage
|
||||||
|
value, err := json.Marshal(output)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
m.Topic = c.OutputRoomEventTopic
|
||||||
|
m.Key = sarama.StringEncoder("")
|
||||||
|
m.Value = sarama.ByteEncoder(value)
|
||||||
|
_, _, err = c.Producer.SendMessage(&m)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// Start starts the consumer consuming.
|
// Start starts the consumer consuming.
|
||||||
// Starts up a goroutine for each partition in the kafka stream.
|
// Starts up a goroutine for each partition in the kafka stream.
|
||||||
// Returns nil once all the goroutines are started.
|
// Returns nil once all the goroutines are started.
|
||||||
|
|
@ -50,7 +68,7 @@ type Consumer struct {
|
||||||
func (c *Consumer) Start() error {
|
func (c *Consumer) Start() error {
|
||||||
offsets := map[int32]int64{}
|
offsets := map[int32]int64{}
|
||||||
|
|
||||||
partitions, err := c.Consumer.Partitions(c.RoomEventTopic)
|
partitions, err := c.Consumer.Partitions(c.InputRoomEventTopic)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
@ -59,7 +77,7 @@ func (c *Consumer) Start() error {
|
||||||
offsets[partition] = sarama.OffsetOldest
|
offsets[partition] = sarama.OffsetOldest
|
||||||
}
|
}
|
||||||
|
|
||||||
storedOffsets, err := c.DB.PartitionOffsets(c.RoomEventTopic)
|
storedOffsets, err := c.DB.PartitionOffsets(c.InputRoomEventTopic)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
@ -70,7 +88,7 @@ func (c *Consumer) Start() error {
|
||||||
|
|
||||||
var partitionConsumers []sarama.PartitionConsumer
|
var partitionConsumers []sarama.PartitionConsumer
|
||||||
for partition, offset := range offsets {
|
for partition, offset := range offsets {
|
||||||
pc, err := c.Consumer.ConsumePartition(c.RoomEventTopic, partition, offset)
|
pc, err := c.Consumer.ConsumePartition(c.InputRoomEventTopic, partition, offset)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
for _, p := range partitionConsumers {
|
for _, p := range partitionConsumers {
|
||||||
p.Close()
|
p.Close()
|
||||||
|
|
@ -95,7 +113,7 @@ func (c *Consumer) consumePartition(pc sarama.PartitionConsumer) {
|
||||||
// If the message is invalid then log it and move onto the next message in the stream.
|
// If the message is invalid then log it and move onto the next message in the stream.
|
||||||
c.logError(message, err)
|
c.logError(message, err)
|
||||||
} else {
|
} else {
|
||||||
if err := processRoomEvent(c.DB, input); err != nil {
|
if err := processRoomEvent(c.DB, c, input); err != nil {
|
||||||
// If there was an error processing the message then log it and
|
// If there was an error processing the message then log it and
|
||||||
// move onto the next message in the stream.
|
// move onto the next message in the stream.
|
||||||
// TODO: If the error was due to a problem talking to the database
|
// TODO: If the error was due to a problem talking to the database
|
||||||
|
|
@ -105,7 +123,7 @@ func (c *Consumer) consumePartition(pc sarama.PartitionConsumer) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Advance our position in the stream so that we will start at the right position after a restart.
|
// Advance our position in the stream so that we will start at the right position after a restart.
|
||||||
if err := c.DB.SetPartitionOffset(c.RoomEventTopic, message.Partition, message.Offset); err != nil {
|
if err := c.DB.SetPartitionOffset(c.InputRoomEventTopic, message.Partition, message.Offset); err != nil {
|
||||||
c.logError(message, err)
|
c.logError(message, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -36,11 +36,20 @@ type RoomEventDatabase interface {
|
||||||
SetState(eventNID types.EventNID, stateNID types.StateSnapshotNID) error
|
SetState(eventNID types.EventNID, stateNID types.StateSnapshotNID) error
|
||||||
// Lookup the latest events in a room in preparation for an update.
|
// Lookup the latest events in a room in preparation for an update.
|
||||||
// The RoomRecentEventsUpdater must have Commit or Rollback called on it if this doesn't return an error.
|
// The RoomRecentEventsUpdater must have Commit or Rollback called on it if this doesn't return an error.
|
||||||
|
// Returns the latest events in the room and the last eventID sent to the log along with an updater.
|
||||||
// If this returns an error then no further action is required.
|
// If this returns an error then no further action is required.
|
||||||
GetLatestEventsForUpdate(roomNID types.RoomNID) ([]types.StateAtEventAndReference, types.RoomRecentEventsUpdater, error)
|
GetLatestEventsForUpdate(roomNID types.RoomNID) (updater types.RoomRecentEventsUpdater, err error)
|
||||||
|
// Lookup the string event IDs for a list of numeric event IDs
|
||||||
|
EventIDs(eventNIDs []types.EventNID) (map[types.EventNID]string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func processRoomEvent(db RoomEventDatabase, input api.InputRoomEvent) error {
|
// OutputRoomEventWriter has the APIs needed to write an event to the output logs.
|
||||||
|
type OutputRoomEventWriter interface {
|
||||||
|
// Write an event.
|
||||||
|
WriteOutputRoomEvent(output api.OutputRoomEvent) error
|
||||||
|
}
|
||||||
|
|
||||||
|
func processRoomEvent(db RoomEventDatabase, ow OutputRoomEventWriter, input api.InputRoomEvent) error {
|
||||||
// Parse and validate the event JSON
|
// Parse and validate the event JSON
|
||||||
event, err := gomatrixserverlib.NewEventFromUntrustedJSON(input.Event)
|
event, err := gomatrixserverlib.NewEventFromUntrustedJSON(input.Event)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -82,7 +91,7 @@ func processRoomEvent(db RoomEventDatabase, input api.InputRoomEvent) error {
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// We haven't been told what the state at the event is so we need to calculate it from the prev_events
|
// We haven't been told what the state at the event is so we need to calculate it from the prev_events
|
||||||
if stateAtEvent.BeforeStateSnapshotNID, err = calculateAndStoreState(db, event, roomNID); err != nil {
|
if stateAtEvent.BeforeStateSnapshotNID, err = calculateAndStoreStateBeforeEvent(db, event, roomNID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -95,7 +104,7 @@ func processRoomEvent(db RoomEventDatabase, input api.InputRoomEvent) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update the extremities of the event graph for the room
|
// Update the extremities of the event graph for the room
|
||||||
if err := updateLatestEvents(db, roomNID, stateAtEvent, event); err != nil {
|
if err := updateLatestEvents(db, ow, roomNID, stateAtEvent, event); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,11 +2,13 @@ package input
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
// updateLatestEvents updates the list of latest events for this room.
|
// updateLatestEvents updates the list of latest events for this room in the database and writes the
|
||||||
|
// event to the output log.
|
||||||
// The latest events are the events that aren't referenced by another event in the database:
|
// The latest events are the events that aren't referenced by another event in the database:
|
||||||
//
|
//
|
||||||
// Time goes down the page. 1 is the m.room.create event (root).
|
// Time goes down the page. 1 is the m.room.create event (root).
|
||||||
|
|
@ -22,9 +24,9 @@ import (
|
||||||
// 7 <----- latest
|
// 7 <----- latest
|
||||||
//
|
//
|
||||||
func updateLatestEvents(
|
func updateLatestEvents(
|
||||||
db RoomEventDatabase, roomNID types.RoomNID, stateAtEvent types.StateAtEvent, event gomatrixserverlib.Event,
|
db RoomEventDatabase, ow OutputRoomEventWriter, roomNID types.RoomNID, stateAtEvent types.StateAtEvent, event gomatrixserverlib.Event,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
oldLatest, updater, err := db.GetLatestEventsForUpdate(roomNID)
|
updater, err := db.GetLatestEventsForUpdate(roomNID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
@ -42,22 +44,81 @@ func updateLatestEvents(
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
err = doUpdateLatestEvents(updater, oldLatest, roomNID, stateAtEvent, event)
|
err = doUpdateLatestEvents(db, updater, ow, roomNID, stateAtEvent, event)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func doUpdateLatestEvents(
|
func doUpdateLatestEvents(
|
||||||
updater types.RoomRecentEventsUpdater, oldLatest []types.StateAtEventAndReference, roomNID types.RoomNID, stateAtEvent types.StateAtEvent, event gomatrixserverlib.Event,
|
db RoomEventDatabase, updater types.RoomRecentEventsUpdater, ow OutputRoomEventWriter, roomNID types.RoomNID, stateAtEvent types.StateAtEvent, event gomatrixserverlib.Event,
|
||||||
) error {
|
) error {
|
||||||
var err error
|
var err error
|
||||||
var prevEvents []gomatrixserverlib.EventReference
|
var prevEvents []gomatrixserverlib.EventReference
|
||||||
prevEvents = event.PrevEvents()
|
prevEvents = event.PrevEvents()
|
||||||
|
oldLatest := updater.LatestEvents()
|
||||||
|
lastEventIDSent := updater.LastEventIDSent()
|
||||||
|
oldStateNID := updater.CurrentStateSnapshotNID()
|
||||||
|
|
||||||
|
if hasBeenSent, err := updater.HasEventBeenSent(stateAtEvent.EventNID); err != nil {
|
||||||
|
return err
|
||||||
|
} else if hasBeenSent {
|
||||||
|
// Already sent this event so we can stop processing
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
if err = updater.StorePreviousEvents(stateAtEvent.EventNID, prevEvents); err != nil {
|
if err = updater.StorePreviousEvents(stateAtEvent.EventNID, prevEvents); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if this event references any of the latest events in the room.
|
eventReference := event.EventReference()
|
||||||
|
// Check if this event is already referenced by another event in the room.
|
||||||
|
var alreadyReferenced bool
|
||||||
|
if alreadyReferenced, err = updater.IsReferenced(eventReference); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
newLatest := calculateLatest(oldLatest, alreadyReferenced, prevEvents, types.StateAtEventAndReference{
|
||||||
|
EventReference: eventReference,
|
||||||
|
StateAtEvent: stateAtEvent,
|
||||||
|
})
|
||||||
|
|
||||||
|
latestStateAtEvents := make([]types.StateAtEvent, len(newLatest))
|
||||||
|
for i := range newLatest {
|
||||||
|
latestStateAtEvents[i] = newLatest[i].StateAtEvent
|
||||||
|
}
|
||||||
|
newStateNID, err := calculateAndStoreStateAfterEvents(db, roomNID, latestStateAtEvents)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
removed, added, err := differenceBetweeenStateSnapshots(db, oldStateNID, newStateNID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send the event to the output logs.
|
||||||
|
// We do this inside the database transaction to ensure that we only mark an event as sent if we sent it.
|
||||||
|
// (n.b. this means that it's possible that the same event will be sent twice if the transaction fails but
|
||||||
|
// the write to the output log succeeds)
|
||||||
|
// TODO: This assumes that writing the event to the output log is synchronous. It should be possible to
|
||||||
|
// send the event asynchronously but we would need to ensure that 1) the events are written to the log in
|
||||||
|
// the correct order, 2) that pending writes are resent across restarts. In order to avoid writing all the
|
||||||
|
// necessary bookkeeping we'll keep the event sending synchronous for now.
|
||||||
|
if err = writeEvent(db, ow, lastEventIDSent, event, newLatest, removed, added); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = updater.SetLatestEvents(roomNID, newLatest, stateAtEvent.EventNID, newStateNID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = updater.MarkEventAsSent(stateAtEvent.EventNID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func calculateLatest(oldLatest []types.StateAtEventAndReference, alreadyReferenced bool, prevEvents []gomatrixserverlib.EventReference, newEvent types.StateAtEventAndReference) []types.StateAtEventAndReference {
|
||||||
var alreadyInLatest bool
|
var alreadyInLatest bool
|
||||||
var newLatest []types.StateAtEventAndReference
|
var newLatest []types.StateAtEventAndReference
|
||||||
for _, l := range oldLatest {
|
for _, l := range oldLatest {
|
||||||
|
|
@ -71,7 +132,7 @@ func doUpdateLatestEvents(
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if l.EventNID == stateAtEvent.EventNID {
|
if l.EventNID == newEvent.EventNID {
|
||||||
alreadyInLatest = true
|
alreadyInLatest = true
|
||||||
}
|
}
|
||||||
if keep {
|
if keep {
|
||||||
|
|
@ -80,26 +141,51 @@ func doUpdateLatestEvents(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
eventReference := event.EventReference()
|
|
||||||
// Check if this event is already referenced by another event in the room.
|
|
||||||
var alreadyReferenced bool
|
|
||||||
if alreadyReferenced, err = updater.IsReferenced(eventReference); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !alreadyReferenced && !alreadyInLatest {
|
if !alreadyReferenced && !alreadyInLatest {
|
||||||
// This event is not referenced by any of the events in the room
|
// This event is not referenced by any of the events in the room
|
||||||
// and the event is not already in the latest events.
|
// and the event is not already in the latest events.
|
||||||
// Add it to the latest events
|
// Add it to the latest events
|
||||||
newLatest = append(newLatest, types.StateAtEventAndReference{
|
newLatest = append(newLatest, newEvent)
|
||||||
StateAtEvent: stateAtEvent,
|
|
||||||
EventReference: eventReference,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = updater.SetLatestEvents(roomNID, newLatest); err != nil {
|
return newLatest
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeEvent(
|
||||||
|
db RoomEventDatabase, ow OutputRoomEventWriter, lastEventIDSent string,
|
||||||
|
event gomatrixserverlib.Event, latest []types.StateAtEventAndReference,
|
||||||
|
removed, added []types.StateEntry,
|
||||||
|
) error {
|
||||||
|
|
||||||
|
latestEventIDs := make([]string, len(latest))
|
||||||
|
for i := range latest {
|
||||||
|
latestEventIDs[i] = latest[i].EventID
|
||||||
|
}
|
||||||
|
|
||||||
|
ore := api.OutputRoomEvent{
|
||||||
|
Event: event.JSON(),
|
||||||
|
LastSentEventID: lastEventIDSent,
|
||||||
|
LatestEventIDs: latestEventIDs,
|
||||||
|
}
|
||||||
|
|
||||||
|
var stateEventNIDs []types.EventNID
|
||||||
|
for _, entry := range added {
|
||||||
|
stateEventNIDs = append(stateEventNIDs, entry.EventNID)
|
||||||
|
}
|
||||||
|
for _, entry := range removed {
|
||||||
|
stateEventNIDs = append(stateEventNIDs, entry.EventNID)
|
||||||
|
}
|
||||||
|
eventIDMap, err := db.EventIDs(stateEventNIDs)
|
||||||
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
for _, entry := range added {
|
||||||
|
ore.AddsStateEventIDs = append(ore.AddsStateEventIDs, eventIDMap[entry.EventNID])
|
||||||
|
}
|
||||||
|
for _, entry := range removed {
|
||||||
|
ore.RemovesStateEventIDs = append(ore.RemovesStateEventIDs, eventIDMap[entry.EventNID])
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
// TODO: Fill out VisibilityStateIDs
|
||||||
|
return ow.WriteOutputRoomEvent(ore)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -9,8 +9,8 @@ import (
|
||||||
|
|
||||||
// calculateAndStoreState calculates a snapshot of the state of a room before an event.
|
// calculateAndStoreState calculates a snapshot of the state of a room before an event.
|
||||||
// Stores the snapshot of the state in the database.
|
// Stores the snapshot of the state in the database.
|
||||||
// Returns a numeric ID for that snapshot.
|
// Returns a numeric ID for the snapshot of the state before the event.
|
||||||
func calculateAndStoreState(
|
func calculateAndStoreStateBeforeEvent(
|
||||||
db RoomEventDatabase, event gomatrixserverlib.Event, roomNID types.RoomNID,
|
db RoomEventDatabase, event gomatrixserverlib.Event, roomNID types.RoomNID,
|
||||||
) (types.StateSnapshotNID, error) {
|
) (types.StateSnapshotNID, error) {
|
||||||
// Load the state at the prev events.
|
// Load the state at the prev events.
|
||||||
|
|
@ -25,6 +25,13 @@ func calculateAndStoreState(
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// The state before this event will be the state after the events that came before it.
|
||||||
|
return calculateAndStoreStateAfterEvents(db, roomNID, prevStates)
|
||||||
|
}
|
||||||
|
|
||||||
|
// calculateAndStoreStateAfterEvents finds the room state after the given events.
|
||||||
|
// Stores the resulting state in the database and returns a numeric ID for that snapshot.
|
||||||
|
func calculateAndStoreStateAfterEvents(db RoomEventDatabase, roomNID types.RoomNID, prevStates []types.StateAtEvent) (types.StateSnapshotNID, error) {
|
||||||
if len(prevStates) == 0 {
|
if len(prevStates) == 0 {
|
||||||
// 2) There weren't any prev_events for this event so the state is
|
// 2) There weren't any prev_events for this event so the state is
|
||||||
// empty.
|
// empty.
|
||||||
|
|
@ -55,9 +62,9 @@ func calculateAndStoreState(
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
// If there are too many deltas then we need to calculate the full state
|
// If there are too many deltas then we need to calculate the full state
|
||||||
// So fall through to calculateAndStoreStateMany
|
// So fall through to calculateAndStoreStateAfterManyEvents
|
||||||
}
|
}
|
||||||
return calculateAndStoreStateMany(db, roomNID, prevStates)
|
return calculateAndStoreStateAfterManyEvents(db, roomNID, prevStates)
|
||||||
}
|
}
|
||||||
|
|
||||||
// maxStateBlockNIDs is the maximum number of state data blocks to use to encode a snapshot of room state.
|
// maxStateBlockNIDs is the maximum number of state data blocks to use to encode a snapshot of room state.
|
||||||
|
|
@ -67,10 +74,10 @@ func calculateAndStoreState(
|
||||||
// TODO: Tune this to get the right balance between size and lookup performance.
|
// TODO: Tune this to get the right balance between size and lookup performance.
|
||||||
const maxStateBlockNIDs = 64
|
const maxStateBlockNIDs = 64
|
||||||
|
|
||||||
// calculateAndStoreStateMany calculates the state of the room before an event
|
// calculateAndStoreStateAfterManyEvents finds the room state after the given events.
|
||||||
// using the states at each of the event's prev events.
|
// This handles the slow path of calculateAndStoreStateAfterEvents for when there is more than one event.
|
||||||
// Stores the resulting state and returns a numeric ID for the snapshot.
|
// Stores the resulting state and returns a numeric ID for the snapshot.
|
||||||
func calculateAndStoreStateMany(db RoomEventDatabase, roomNID types.RoomNID, prevStates []types.StateAtEvent) (types.StateSnapshotNID, error) {
|
func calculateAndStoreStateAfterManyEvents(db RoomEventDatabase, roomNID types.RoomNID, prevStates []types.StateAtEvent) (types.StateSnapshotNID, error) {
|
||||||
// Conflict resolution.
|
// Conflict resolution.
|
||||||
// First stage: load the state after each of the prev events.
|
// First stage: load the state after each of the prev events.
|
||||||
combined, err := loadCombinedStateAfterEvents(db, prevStates)
|
combined, err := loadCombinedStateAfterEvents(db, prevStates)
|
||||||
|
|
@ -107,6 +114,98 @@ func calculateAndStoreStateMany(db RoomEventDatabase, roomNID types.RoomNID, pre
|
||||||
return db.AddState(roomNID, nil, state)
|
return db.AddState(roomNID, nil, state)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// differenceBetweeenStateSnapshots works out which state entries have been added and removed between two snapshots.
|
||||||
|
func differenceBetweeenStateSnapshots(db RoomEventDatabase, oldStateNID, newStateNID types.StateSnapshotNID) (
|
||||||
|
removed, added []types.StateEntry, err error,
|
||||||
|
) {
|
||||||
|
if oldStateNID == newStateNID {
|
||||||
|
// If the snapshot NIDs are the same then nothing has changed
|
||||||
|
return nil, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var oldEntries []types.StateEntry
|
||||||
|
var newEntries []types.StateEntry
|
||||||
|
if oldStateNID != 0 {
|
||||||
|
oldEntries, err = loadStateAtSnapshot(db, oldStateNID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if newStateNID != 0 {
|
||||||
|
newEntries, err = loadStateAtSnapshot(db, newStateNID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var oldI int
|
||||||
|
var newI int
|
||||||
|
for {
|
||||||
|
switch {
|
||||||
|
case oldI == len(oldEntries):
|
||||||
|
// We've reached the end of the old entries.
|
||||||
|
// The rest of the new list must have been newly added.
|
||||||
|
added = append(added, newEntries[newI:]...)
|
||||||
|
return
|
||||||
|
case newI == len(newEntries):
|
||||||
|
// We've reached the end of the new entries.
|
||||||
|
// The rest of the old list must be have been removed.
|
||||||
|
removed = append(removed, oldEntries[oldI:]...)
|
||||||
|
return
|
||||||
|
case oldEntries[oldI] == newEntries[newI]:
|
||||||
|
// The entry is in both lists so skip over it.
|
||||||
|
oldI++
|
||||||
|
newI++
|
||||||
|
case oldEntries[oldI].LessThan(newEntries[newI]):
|
||||||
|
// The lists are sorted so the old entry being less than the new entry means that it only appears in the old list.
|
||||||
|
removed = append(removed, oldEntries[oldI])
|
||||||
|
oldI++
|
||||||
|
default:
|
||||||
|
// Reaching the default case implies that the new entry is less than the old entry.
|
||||||
|
// Since the lists are sorted this means that it only appears in the new list.
|
||||||
|
added = append(added, newEntries[newI])
|
||||||
|
newI++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadStateAtSnapshot loads the full state of a room at a particular snapshot.
|
||||||
|
// This is typically the state before an event or the current state of a room.
|
||||||
|
// Returns a sorted list of state entries or an error if there was a problem talking to the database.
|
||||||
|
func loadStateAtSnapshot(db RoomEventDatabase, stateNID types.StateSnapshotNID) ([]types.StateEntry, error) {
|
||||||
|
stateBlockNIDLists, err := db.StateBlockNIDs([]types.StateSnapshotNID{stateNID})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
stateBlockNIDList := stateBlockNIDLists[0]
|
||||||
|
|
||||||
|
stateEntryLists, err := db.StateEntries(stateBlockNIDList.StateBlockNIDs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
stateEntriesMap := stateEntryListMap(stateEntryLists)
|
||||||
|
|
||||||
|
// Combined all the state entries for this snapshot.
|
||||||
|
// The order of state block NIDs in the list tells us the order to combine them in.
|
||||||
|
var fullState []types.StateEntry
|
||||||
|
for _, stateBlockNID := range stateBlockNIDList.StateBlockNIDs {
|
||||||
|
entries, ok := stateEntriesMap.lookup(stateBlockNID)
|
||||||
|
if !ok {
|
||||||
|
// This should only get hit if the database is corrupt.
|
||||||
|
// It should be impossible for an event to reference a NID that doesn't exist
|
||||||
|
panic(fmt.Errorf("Corrupt DB: Missing state block numeric ID %d", stateBlockNID))
|
||||||
|
}
|
||||||
|
fullState = append(fullState, entries...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stable sort so that the most recent entry for each state key stays
|
||||||
|
// remains later in the list than the older entries for the same state key.
|
||||||
|
sort.Stable(stateEntryByStateKeySorter(fullState))
|
||||||
|
// Unique returns the last entry and hence the most recent entry for each state key.
|
||||||
|
fullState = fullState[:unique(stateEntryByStateKeySorter(fullState))]
|
||||||
|
return fullState, nil
|
||||||
|
}
|
||||||
|
|
||||||
// loadCombinedStateAfterEvents loads a snapshot of the state after each of the events
|
// loadCombinedStateAfterEvents loads a snapshot of the state after each of the events
|
||||||
// and combines those snapshots together into a single list.
|
// and combines those snapshots together into a single list.
|
||||||
func loadCombinedStateAfterEvents(db RoomEventDatabase, prevStates []types.StateAtEvent) ([]types.StateEntry, error) {
|
func loadCombinedStateAfterEvents(db RoomEventDatabase, prevStates []types.StateAtEvent) ([]types.StateEntry, error) {
|
||||||
|
|
@ -146,18 +245,18 @@ func loadCombinedStateAfterEvents(db RoomEventDatabase, prevStates []types.State
|
||||||
if !ok {
|
if !ok {
|
||||||
// This should only get hit if the database is corrupt.
|
// This should only get hit if the database is corrupt.
|
||||||
// It should be impossible for an event to reference a NID that doesn't exist
|
// It should be impossible for an event to reference a NID that doesn't exist
|
||||||
panic(fmt.Errorf("Corrupt DB: Missing state numeric ID %d", prevState.BeforeStateSnapshotNID))
|
panic(fmt.Errorf("Corrupt DB: Missing state snapshot numeric ID %d", prevState.BeforeStateSnapshotNID))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Combined all the state entries for this snapshot.
|
// Combined all the state entries for this snapshot.
|
||||||
// The order of state data NIDs in the list tells us the order to combine them in.
|
// The order of state block NIDs in the list tells us the order to combine them in.
|
||||||
var fullState []types.StateEntry
|
var fullState []types.StateEntry
|
||||||
for _, stateBlockNID := range stateBlockNIDs {
|
for _, stateBlockNID := range stateBlockNIDs {
|
||||||
entries, ok := stateEntriesMap.lookup(stateBlockNID)
|
entries, ok := stateEntriesMap.lookup(stateBlockNID)
|
||||||
if !ok {
|
if !ok {
|
||||||
// This should only get hit if the database is corrupt.
|
// This should only get hit if the database is corrupt.
|
||||||
// It should be impossible for an event to reference a NID that doesn't exist
|
// It should be impossible for an event to reference a NID that doesn't exist
|
||||||
panic(fmt.Errorf("Corrupt DB: Missing state numeric ID %d", prevState.BeforeStateSnapshotNID))
|
panic(fmt.Errorf("Corrupt DB: Missing state block numeric ID %d", stateBlockNID))
|
||||||
}
|
}
|
||||||
fullState = append(fullState, entries...)
|
fullState = append(fullState, entries...)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
68
src/github.com/matrix-org/dendrite/roomserver/query/query.go
Normal file
68
src/github.com/matrix-org/dendrite/roomserver/query/query.go
Normal file
|
|
@ -0,0 +1,68 @@
|
||||||
|
package query
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/matrix-org/util"
|
||||||
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RoomserverQueryAPIDatabase has the storage APIs needed to implement the query API.
|
||||||
|
type RoomserverQueryAPIDatabase interface {
|
||||||
|
// Lookup the numeric ID for the room.
|
||||||
|
// Returns 0 if the room doesn't exists.
|
||||||
|
// Returns an error if there was a problem talking to the database.
|
||||||
|
RoomNID(roomID string) (types.RoomNID, error)
|
||||||
|
// Lookup event references for the latest events in the room.
|
||||||
|
// Returns an error if there was a problem talking to the database.
|
||||||
|
LatestEventIDs(roomNID types.RoomNID) ([]gomatrixserverlib.EventReference, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RoomserverQueryAPI is an implementation of RoomserverQueryAPI
|
||||||
|
type RoomserverQueryAPI struct {
|
||||||
|
DB RoomserverQueryAPIDatabase
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryLatestEventsAndState implements api.RoomserverQueryAPI
|
||||||
|
func (r *RoomserverQueryAPI) QueryLatestEventsAndState(
|
||||||
|
request *api.QueryLatestEventsAndStateRequest,
|
||||||
|
response *api.QueryLatestEventsAndStateResponse,
|
||||||
|
) (err error) {
|
||||||
|
response.QueryLatestEventsAndStateRequest = *request
|
||||||
|
roomNID, err := r.DB.RoomNID(request.RoomID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if roomNID == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
response.RoomExists = true
|
||||||
|
response.LatestEvents, err = r.DB.LatestEventIDs(roomNID)
|
||||||
|
// TODO: look up the current state.
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetupHTTP adds the RoomserverQueryAPI handlers to the http.ServeMux.
|
||||||
|
func (r *RoomserverQueryAPI) SetupHTTP(servMux *http.ServeMux) {
|
||||||
|
servMux.Handle(
|
||||||
|
api.RoomserverQueryLatestEventsAndStatePath,
|
||||||
|
makeAPI("query_latest_events_and_state", func(req *http.Request) util.JSONResponse {
|
||||||
|
var request api.QueryLatestEventsAndStateRequest
|
||||||
|
var response api.QueryLatestEventsAndStateResponse
|
||||||
|
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
if err := r.QueryLatestEventsAndState(&request, &response); err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
return util.JSONResponse{Code: 200, JSON: &response}
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeAPI(metric string, apiFunc func(req *http.Request) util.JSONResponse) http.Handler {
|
||||||
|
return prometheus.InstrumentHandler(metric, util.MakeJSONAPI(util.NewJSONRequestHandler(apiFunc)))
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,382 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// Path to where kafka is installed.
|
||||||
|
kafkaDir = defaulting(os.Getenv("KAFKA_DIR"), "kafka")
|
||||||
|
// The URI the kafka zookeeper is listening on.
|
||||||
|
zookeeperURI = defaulting(os.Getenv("ZOOKEEPER_URI"), "localhost:2181")
|
||||||
|
// The URI the kafka server is listening on.
|
||||||
|
kafkaURI = defaulting(os.Getenv("KAFKA_URIS"), "localhost:9092")
|
||||||
|
// The address the roomserver should listen on.
|
||||||
|
roomserverAddr = defaulting(os.Getenv("ROOMSERVER_URI"), "localhost:9876")
|
||||||
|
// How long to wait for the roomserver to write the expected output messages.
|
||||||
|
// This needs to be high enough to account for the time it takes to create
|
||||||
|
// the postgres database tables which can take a while on travis.
|
||||||
|
timeoutString = defaulting(os.Getenv("TIMEOUT"), "60s")
|
||||||
|
// The name of maintenance database to connect to in order to create the test database.
|
||||||
|
postgresDatabase = defaulting(os.Getenv("POSTGRES_DATABASE"), "postgres")
|
||||||
|
// The name of the test database to create.
|
||||||
|
testDatabaseName = defaulting(os.Getenv("DATABASE_NAME"), "roomserver_test")
|
||||||
|
// The postgres connection config for connecting to the test database.
|
||||||
|
testDatabase = defaulting(os.Getenv("DATABASE"), fmt.Sprintf("dbname=%s binary_parameters=yes", testDatabaseName))
|
||||||
|
)
|
||||||
|
|
||||||
|
func defaulting(value, defaultValue string) string {
|
||||||
|
if value == "" {
|
||||||
|
value = defaultValue
|
||||||
|
}
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
|
var timeout time.Duration
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
var err error
|
||||||
|
timeout, err = time.ParseDuration(timeoutString)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func createDatabase(database string) error {
|
||||||
|
cmd := exec.Command("psql", postgresDatabase)
|
||||||
|
cmd.Stdin = strings.NewReader(
|
||||||
|
fmt.Sprintf("DROP DATABASE IF EXISTS %s; CREATE DATABASE %s;", database, database),
|
||||||
|
)
|
||||||
|
// Send stdout and stderr to our stderr so that we see error messages from
|
||||||
|
// the psql process
|
||||||
|
cmd.Stdout = os.Stderr
|
||||||
|
cmd.Stderr = os.Stderr
|
||||||
|
return cmd.Run()
|
||||||
|
}
|
||||||
|
|
||||||
|
func createTopic(topic string) error {
|
||||||
|
cmd := exec.Command(
|
||||||
|
filepath.Join(kafkaDir, "bin", "kafka-topics.sh"),
|
||||||
|
"--create",
|
||||||
|
"--zookeeper", zookeeperURI,
|
||||||
|
"--replication-factor", "1",
|
||||||
|
"--partitions", "1",
|
||||||
|
"--topic", topic,
|
||||||
|
)
|
||||||
|
// Send stdout and stderr to our stderr so that we see error messages from
|
||||||
|
// the kafka process.
|
||||||
|
cmd.Stdout = os.Stderr
|
||||||
|
cmd.Stderr = os.Stderr
|
||||||
|
return cmd.Run()
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeToTopic(topic string, data []string) error {
|
||||||
|
cmd := exec.Command(
|
||||||
|
filepath.Join(kafkaDir, "bin", "kafka-console-producer.sh"),
|
||||||
|
"--broker-list", kafkaURI,
|
||||||
|
"--topic", topic,
|
||||||
|
)
|
||||||
|
// Send stdout and stderr to our stderr so that we see error messages from
|
||||||
|
// the kafka process.
|
||||||
|
cmd.Stdout = os.Stderr
|
||||||
|
cmd.Stderr = os.Stderr
|
||||||
|
cmd.Stdin = strings.NewReader(strings.Join(data, "\n"))
|
||||||
|
return cmd.Run()
|
||||||
|
}
|
||||||
|
|
||||||
|
// runAndReadFromTopic runs a command and waits for a number of messages to be
|
||||||
|
// written to a kafka topic. It returns if the command exits, the number of
|
||||||
|
// messages is reached or after a timeout. It kills the command before it returns.
|
||||||
|
// It returns a list of the messages read from the command on success or an error
|
||||||
|
// on failure.
|
||||||
|
func runAndReadFromTopic(runCmd *exec.Cmd, topic string, count int, checkQueryAPI func()) ([]string, error) {
|
||||||
|
type result struct {
|
||||||
|
// data holds all of stdout on success.
|
||||||
|
data []byte
|
||||||
|
// err is set on failure.
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
done := make(chan result)
|
||||||
|
readCmd := exec.Command(
|
||||||
|
filepath.Join(kafkaDir, "bin", "kafka-console-consumer.sh"),
|
||||||
|
"--bootstrap-server", kafkaURI,
|
||||||
|
"--topic", topic,
|
||||||
|
"--from-beginning",
|
||||||
|
"--max-messages", fmt.Sprintf("%d", count),
|
||||||
|
)
|
||||||
|
// Send stderr to our stderr so the user can see any error messages.
|
||||||
|
readCmd.Stderr = os.Stderr
|
||||||
|
// Run the command, read the messages and wait for a timeout in parallel.
|
||||||
|
go func() {
|
||||||
|
// Read all of stdout.
|
||||||
|
defer func() {
|
||||||
|
if err := recover(); err != nil {
|
||||||
|
if errv, ok := err.(error); ok {
|
||||||
|
done <- result{nil, errv}
|
||||||
|
} else {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
data, err := readCmd.Output()
|
||||||
|
checkQueryAPI()
|
||||||
|
done <- result{data, err}
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
err := runCmd.Run()
|
||||||
|
done <- result{nil, err}
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
time.Sleep(timeout)
|
||||||
|
done <- result{nil, fmt.Errorf("Timeout reading %d messages from topic %q", count, topic)}
|
||||||
|
}()
|
||||||
|
// Wait for one of the tasks to finsh.
|
||||||
|
r := <-done
|
||||||
|
|
||||||
|
// Kill both processes. We don't check if the processes are running and
|
||||||
|
// we ignore failures since we are just trying to clean up before returning.
|
||||||
|
runCmd.Process.Kill()
|
||||||
|
readCmd.Process.Kill()
|
||||||
|
|
||||||
|
if r.err != nil {
|
||||||
|
return nil, r.err
|
||||||
|
}
|
||||||
|
|
||||||
|
// The kafka console consumer writes a newline character after each message.
|
||||||
|
// So we split on newline characters
|
||||||
|
lines := strings.Split(string(r.data), "\n")
|
||||||
|
if len(lines) > 0 {
|
||||||
|
// Remove the blank line at the end of the data.
|
||||||
|
lines = lines[:len(lines)-1]
|
||||||
|
}
|
||||||
|
return lines, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func deleteTopic(topic string) error {
|
||||||
|
cmd := exec.Command(
|
||||||
|
filepath.Join(kafkaDir, "bin", "kafka-topics.sh"),
|
||||||
|
"--delete",
|
||||||
|
"--if-exists",
|
||||||
|
"--zookeeper", zookeeperURI,
|
||||||
|
"--topic", topic,
|
||||||
|
)
|
||||||
|
cmd.Stderr = os.Stderr
|
||||||
|
cmd.Stdout = os.Stderr
|
||||||
|
return cmd.Run()
|
||||||
|
}
|
||||||
|
|
||||||
|
// testRoomserver is used to run integration tests against a single roomserver.
|
||||||
|
// It creates new kafka topics for the input and output of the roomserver.
|
||||||
|
// It writes the input messages to the input kafka topic, formatting each message
|
||||||
|
// as canonical JSON so that it fits on a single line.
|
||||||
|
// It then runs the roomserver and waits for a number of messages to be written
|
||||||
|
// to the output topic.
|
||||||
|
// Once those messages have been written it runs the checkQueries function passing
|
||||||
|
// a api.RoomserverQueryAPI client. The caller can use this function to check the
|
||||||
|
// behaviour of the query API.
|
||||||
|
func testRoomserver(input []string, wantOutput []string, checkQueries func(api.RoomserverQueryAPI)) {
|
||||||
|
const (
|
||||||
|
inputTopic = "roomserverInput"
|
||||||
|
outputTopic = "roomserverOutput"
|
||||||
|
)
|
||||||
|
deleteTopic(inputTopic)
|
||||||
|
if err := createTopic(inputTopic); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
deleteTopic(outputTopic)
|
||||||
|
if err := createTopic(outputTopic); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := writeToTopic(inputTopic, canonicalJSONInput(input)); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := createDatabase(testDatabaseName); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := exec.Command(filepath.Join(filepath.Dir(os.Args[0]), "roomserver"))
|
||||||
|
|
||||||
|
// Append the roomserver config to the existing environment.
|
||||||
|
// We append to the environment rather than replacing so that any additional
|
||||||
|
// postgres and golang environment variables such as PGHOST are passed to
|
||||||
|
// the roomserver process.
|
||||||
|
cmd.Env = append(
|
||||||
|
os.Environ(),
|
||||||
|
fmt.Sprintf("DATABASE=%s", testDatabase),
|
||||||
|
fmt.Sprintf("KAFKA_URIS=%s", kafkaURI),
|
||||||
|
fmt.Sprintf("TOPIC_INPUT_ROOM_EVENT=%s", inputTopic),
|
||||||
|
fmt.Sprintf("TOPIC_OUTPUT_ROOM_EVENT=%s", outputTopic),
|
||||||
|
fmt.Sprintf("BIND_ADDRESS=%s", roomserverAddr),
|
||||||
|
)
|
||||||
|
cmd.Stderr = os.Stderr
|
||||||
|
|
||||||
|
gotOutput, err := runAndReadFromTopic(cmd, outputTopic, len(wantOutput), func() {
|
||||||
|
queryAPI := api.NewRoomserverQueryAPIHTTP("http://"+roomserverAddr, nil)
|
||||||
|
checkQueries(queryAPI)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(wantOutput) != len(gotOutput) {
|
||||||
|
panic(fmt.Errorf("Wanted %d lines of output got %d lines", len(wantOutput), len(gotOutput)))
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range wantOutput {
|
||||||
|
if !equalJSON(wantOutput[i], gotOutput[i]) {
|
||||||
|
panic(fmt.Errorf("Wanted %q at index %d got %q", wantOutput[i], i, gotOutput[i]))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func canonicalJSONInput(jsonData []string) []string {
|
||||||
|
for i := range jsonData {
|
||||||
|
jsonBytes, err := gomatrixserverlib.CanonicalJSON([]byte(jsonData[i]))
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
jsonData[i] = string(jsonBytes)
|
||||||
|
}
|
||||||
|
return jsonData
|
||||||
|
}
|
||||||
|
|
||||||
|
func equalJSON(a, b string) bool {
|
||||||
|
canonicalA, err := gomatrixserverlib.CanonicalJSON([]byte(a))
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
canonicalB, err := gomatrixserverlib.CanonicalJSON([]byte(b))
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return string(canonicalA) == string(canonicalB)
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
fmt.Println("==TESTING==", os.Args[0])
|
||||||
|
|
||||||
|
input := []string{
|
||||||
|
`{
|
||||||
|
"AuthEventIDs": [],
|
||||||
|
"Kind": 1,
|
||||||
|
"Event": {
|
||||||
|
"origin": "matrix.org",
|
||||||
|
"signatures": {
|
||||||
|
"matrix.org": {
|
||||||
|
"ed25519:auto": "3kXGwNtdj+zqEXlI8PWLiB76xtrQ7SxcvPuXAEVCTo+QPoBoUvLi1RkHs6O5mDz7UzIowK5bi1seAN4vOh0OBA"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"origin_server_ts": 1463671337837,
|
||||||
|
"sender": "@richvdh:matrix.org",
|
||||||
|
"event_id": "$1463671337126266wrSBX:matrix.org",
|
||||||
|
"prev_events": [],
|
||||||
|
"state_key": "",
|
||||||
|
"content": {"creator": "@richvdh:matrix.org"},
|
||||||
|
"depth": 1,
|
||||||
|
"prev_state": [],
|
||||||
|
"room_id": "!HCXfdvrfksxuYnIFiJ:matrix.org",
|
||||||
|
"auth_events": [],
|
||||||
|
"hashes": {"sha256": "Q05VLC8nztN2tguy+KnHxxhitI95wK9NelnsDaXRqeo"},
|
||||||
|
"type": "m.room.create"}
|
||||||
|
}`, `{
|
||||||
|
"AuthEventIDs": ["$1463671337126266wrSBX:matrix.org"],
|
||||||
|
"Kind": 2,
|
||||||
|
"StateEventIDs": ["$1463671337126266wrSBX:matrix.org"],
|
||||||
|
"Event": {
|
||||||
|
"origin": "matrix.org",
|
||||||
|
"signatures": {
|
||||||
|
"matrix.org": {
|
||||||
|
"ed25519:auto": "a2b3xXYVPPFeG1sHCU3hmZnAaKqZFgzGZozijRGblG5Y//ewRPAn1A2mCrI2UM5I+0zqr70cNpHgF8bmNFu4BA"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"origin_server_ts": 1463671339844,
|
||||||
|
"sender": "@richvdh:matrix.org",
|
||||||
|
"event_id": "$1463671339126270PnVwC:matrix.org",
|
||||||
|
"prev_events": [[
|
||||||
|
"$1463671337126266wrSBX:matrix.org", {"sha256": "h/VS07u8KlMwT3Ee8JhpkC7sa1WUs0Srgs+l3iBv6c0"}
|
||||||
|
]],
|
||||||
|
"membership": "join",
|
||||||
|
"state_key": "@richvdh:matrix.org",
|
||||||
|
"content": {
|
||||||
|
"membership": "join",
|
||||||
|
"avatar_url": "mxc://matrix.org/ZafPzsxMJtLaSaJXloBEKiws",
|
||||||
|
"displayname": "richvdh"
|
||||||
|
},
|
||||||
|
"depth": 2,
|
||||||
|
"prev_state": [],
|
||||||
|
"room_id": "!HCXfdvrfksxuYnIFiJ:matrix.org",
|
||||||
|
"auth_events": [[
|
||||||
|
"$1463671337126266wrSBX:matrix.org", {"sha256": "h/VS07u8KlMwT3Ee8JhpkC7sa1WUs0Srgs+l3iBv6c0"}
|
||||||
|
]],
|
||||||
|
"hashes": {"sha256": "t9t3sZV1Eu0P9Jyrs7pge6UTa1zuTbRdVxeUHnrQVH0"},
|
||||||
|
"type": "m.room.member"},
|
||||||
|
"HasState": true
|
||||||
|
}`,
|
||||||
|
}
|
||||||
|
|
||||||
|
want := []string{
|
||||||
|
`{
|
||||||
|
"Event":{
|
||||||
|
"auth_events":[[
|
||||||
|
"$1463671337126266wrSBX:matrix.org",{"sha256":"h/VS07u8KlMwT3Ee8JhpkC7sa1WUs0Srgs+l3iBv6c0"}
|
||||||
|
]],
|
||||||
|
"content":{
|
||||||
|
"avatar_url":"mxc://matrix.org/ZafPzsxMJtLaSaJXloBEKiws",
|
||||||
|
"displayname":"richvdh",
|
||||||
|
"membership":"join"
|
||||||
|
},
|
||||||
|
"depth": 2,
|
||||||
|
"event_id": "$1463671339126270PnVwC:matrix.org",
|
||||||
|
"hashes": {"sha256":"t9t3sZV1Eu0P9Jyrs7pge6UTa1zuTbRdVxeUHnrQVH0"},
|
||||||
|
"membership": "join",
|
||||||
|
"origin": "matrix.org",
|
||||||
|
"origin_server_ts": 1463671339844,
|
||||||
|
"prev_events": [[
|
||||||
|
"$1463671337126266wrSBX:matrix.org",{"sha256":"h/VS07u8KlMwT3Ee8JhpkC7sa1WUs0Srgs+l3iBv6c0"}
|
||||||
|
]],
|
||||||
|
"prev_state":[],
|
||||||
|
"room_id":"!HCXfdvrfksxuYnIFiJ:matrix.org",
|
||||||
|
"sender":"@richvdh:matrix.org",
|
||||||
|
"signatures":{
|
||||||
|
"matrix.org":{
|
||||||
|
"ed25519:auto":"a2b3xXYVPPFeG1sHCU3hmZnAaKqZFgzGZozijRGblG5Y//ewRPAn1A2mCrI2UM5I+0zqr70cNpHgF8bmNFu4BA"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"state_key":"@richvdh:matrix.org",
|
||||||
|
"type":"m.room.member"
|
||||||
|
},
|
||||||
|
"VisibilityEventIDs":null,
|
||||||
|
"LatestEventIDs":["$1463671339126270PnVwC:matrix.org"],
|
||||||
|
"AddsStateEventIDs":["$1463671337126266wrSBX:matrix.org", "$1463671339126270PnVwC:matrix.org"],
|
||||||
|
"RemovesStateEventIDs":null,
|
||||||
|
"LastSentEventID":""
|
||||||
|
}`,
|
||||||
|
}
|
||||||
|
|
||||||
|
testRoomserver(input, want, func(q api.RoomserverQueryAPI) {
|
||||||
|
var response api.QueryLatestEventsAndStateResponse
|
||||||
|
if err := q.QueryLatestEventsAndState(
|
||||||
|
&api.QueryLatestEventsAndStateRequest{RoomID: "!HCXfdvrfksxuYnIFiJ:matrix.org"},
|
||||||
|
&response,
|
||||||
|
); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
if !response.RoomExists {
|
||||||
|
panic(fmt.Errorf(`Wanted room "!HCXfdvrfksxuYnIFiJ:matrix.org" to exist`))
|
||||||
|
}
|
||||||
|
if len(response.LatestEvents) != 1 || response.LatestEvents[0].EventID != "$1463671339126270PnVwC:matrix.org" {
|
||||||
|
panic(fmt.Errorf(`Wanted "$1463671339126270PnVwC:matrix.org" to be the latest event got %#v`, response.LatestEvents))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
fmt.Println("==PASSED==", os.Args[0])
|
||||||
|
}
|
||||||
|
|
@ -3,16 +3,20 @@ package main
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/matrix-org/dendrite/roomserver/input"
|
"github.com/matrix-org/dendrite/roomserver/input"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/query"
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage"
|
"github.com/matrix-org/dendrite/roomserver/storage"
|
||||||
sarama "gopkg.in/Shopify/sarama.v1"
|
sarama "gopkg.in/Shopify/sarama.v1"
|
||||||
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
database = os.Getenv("DATABASE")
|
database = os.Getenv("DATABASE")
|
||||||
kafkaURIs = strings.Split(os.Getenv("KAFKA_URIS"), ",")
|
kafkaURIs = strings.Split(os.Getenv("KAFKA_URIS"), ",")
|
||||||
roomEventTopic = os.Getenv("TOPIC_ROOM_EVENT")
|
inputRoomEventTopic = os.Getenv("TOPIC_INPUT_ROOM_EVENT")
|
||||||
|
outputRoomEventTopic = os.Getenv("TOPIC_OUTPUT_ROOM_EVENT")
|
||||||
|
bindAddr = os.Getenv("BIND_ADDRESS")
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
|
@ -26,19 +30,31 @@ func main() {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
kafkaProducer, err := sarama.NewSyncProducer(kafkaURIs, nil)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
consumer := input.Consumer{
|
consumer := input.Consumer{
|
||||||
Consumer: kafkaConsumer,
|
Consumer: kafkaConsumer,
|
||||||
DB: db,
|
DB: db,
|
||||||
RoomEventTopic: roomEventTopic,
|
Producer: kafkaProducer,
|
||||||
|
InputRoomEventTopic: inputRoomEventTopic,
|
||||||
|
OutputRoomEventTopic: outputRoomEventTopic,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = consumer.Start(); err != nil {
|
if err = consumer.Start(); err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
queryAPI := query.RoomserverQueryAPI{
|
||||||
|
DB: db,
|
||||||
|
}
|
||||||
|
|
||||||
|
queryAPI.SetupHTTP(http.DefaultServeMux)
|
||||||
|
|
||||||
fmt.Println("Started roomserver")
|
fmt.Println("Started roomserver")
|
||||||
|
|
||||||
// Wait forever.
|
|
||||||
// TODO: Implement clean shutdown.
|
// TODO: Implement clean shutdown.
|
||||||
select {}
|
http.ListenAndServe(bindAddr, nil)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,6 @@ package storage
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"github.com/lib/pq"
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -45,13 +44,10 @@ func (s *eventJSONStatements) prepare(db *sql.DB) (err error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if s.insertEventJSONStmt, err = db.Prepare(insertEventJSONSQL); err != nil {
|
return statementList{
|
||||||
return
|
{&s.insertEventJSONStmt, insertEventJSONSQL},
|
||||||
}
|
{&s.bulkSelectEventJSONStmt, bulkSelectEventJSONSQL},
|
||||||
if s.bulkSelectEventJSONStmt, err = db.Prepare(bulkSelectEventJSONSQL); err != nil {
|
}.prepare(db)
|
||||||
return
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *eventJSONStatements) insertEventJSON(eventNID types.EventNID, eventJSON []byte) error {
|
func (s *eventJSONStatements) insertEventJSON(eventNID types.EventNID, eventJSON []byte) error {
|
||||||
|
|
@ -65,11 +61,7 @@ type eventJSONPair struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *eventJSONStatements) bulkSelectEventJSON(eventNIDs []types.EventNID) ([]eventJSONPair, error) {
|
func (s *eventJSONStatements) bulkSelectEventJSON(eventNIDs []types.EventNID) ([]eventJSONPair, error) {
|
||||||
nids := make([]int64, len(eventNIDs))
|
rows, err := s.bulkSelectEventJSONStmt.Query(eventNIDsAsArray(eventNIDs))
|
||||||
for i := range eventNIDs {
|
|
||||||
nids[i] = int64(eventNIDs[i])
|
|
||||||
}
|
|
||||||
rows, err := s.bulkSelectEventJSONStmt.Query(pq.Int64Array(nids))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -53,16 +53,11 @@ func (s *eventStateKeyStatements) prepare(db *sql.DB) (err error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if s.insertEventStateKeyNIDStmt, err = db.Prepare(insertEventStateKeyNIDSQL); err != nil {
|
return statementList{
|
||||||
return
|
{&s.insertEventStateKeyNIDStmt, insertEventStateKeyNIDSQL},
|
||||||
}
|
{&s.selectEventStateKeyNIDStmt, selectEventStateKeyNIDSQL},
|
||||||
if s.selectEventStateKeyNIDStmt, err = db.Prepare(selectEventStateKeyNIDSQL); err != nil {
|
{&s.bulkSelectEventStateKeyNIDStmt, bulkSelectEventStateKeyNIDSQL},
|
||||||
return
|
}.prepare(db)
|
||||||
}
|
|
||||||
if s.bulkSelectEventStateKeyNIDStmt, err = db.Prepare(bulkSelectEventStateKeyNIDSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *eventStateKeyStatements) insertEventStateKeyNID(eventStateKey string) (types.EventStateKeyNID, error) {
|
func (s *eventStateKeyStatements) insertEventStateKeyNID(eventStateKey string) (types.EventStateKeyNID, error) {
|
||||||
|
|
|
||||||
|
|
@ -76,13 +76,11 @@ func (s *eventTypeStatements) prepare(db *sql.DB) (err error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if s.insertEventTypeNIDStmt, err = db.Prepare(insertEventTypeNIDSQL); err != nil {
|
|
||||||
return
|
return statementList{
|
||||||
}
|
{&s.insertEventTypeNIDStmt, insertEventTypeNIDSQL},
|
||||||
if s.selectEventTypeNIDStmt, err = db.Prepare(selectEventTypeNIDSQL); err != nil {
|
{&s.selectEventTypeNIDStmt, selectEventTypeNIDSQL},
|
||||||
return
|
}.prepare(db)
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *eventTypeStatements) insertEventTypeNID(eventType string) (types.EventTypeNID, error) {
|
func (s *eventTypeStatements) insertEventTypeNID(eventType string) (types.EventTypeNID, error) {
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/lib/pq"
|
"github.com/lib/pq"
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
const eventsSchema = `
|
const eventsSchema = `
|
||||||
|
|
@ -23,6 +24,8 @@ CREATE TABLE IF NOT EXISTS events (
|
||||||
-- Local numeric ID for the state_key of the event
|
-- Local numeric ID for the state_key of the event
|
||||||
-- This is 0 if the event is not a state event.
|
-- This is 0 if the event is not a state event.
|
||||||
event_state_key_nid BIGINT NOT NULL,
|
event_state_key_nid BIGINT NOT NULL,
|
||||||
|
-- Whether the event has been written to the output log.
|
||||||
|
sent_to_output BOOLEAN NOT NULL DEFAULT FALSE,
|
||||||
-- Local numeric ID for the state at the event.
|
-- Local numeric ID for the state at the event.
|
||||||
-- This is 0 if we don't know the state at the event.
|
-- This is 0 if we don't know the state at the event.
|
||||||
-- If the state is not 0 then this event is part of the contiguous
|
-- If the state is not 0 then this event is part of the contiguous
|
||||||
|
|
@ -68,17 +71,37 @@ const bulkSelectStateAtEventByIDSQL = "" +
|
||||||
const updateEventStateSQL = "" +
|
const updateEventStateSQL = "" +
|
||||||
"UPDATE events SET state_snapshot_nid = $2 WHERE event_nid = $1"
|
"UPDATE events SET state_snapshot_nid = $2 WHERE event_nid = $1"
|
||||||
|
|
||||||
|
const selectEventSentToOutputSQL = "" +
|
||||||
|
"SELECT sent_to_output FROM events WHERE event_nid = $1"
|
||||||
|
|
||||||
|
const updateEventSentToOutputSQL = "" +
|
||||||
|
"UPDATE events SET sent_to_output = TRUE WHERE event_nid = $1"
|
||||||
|
|
||||||
|
const selectEventIDSQL = "" +
|
||||||
|
"SELECT event_id FROM events WHERE event_nid = $1"
|
||||||
|
|
||||||
const bulkSelectStateAtEventAndReferenceSQL = "" +
|
const bulkSelectStateAtEventAndReferenceSQL = "" +
|
||||||
"SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, event_id, reference_sha256" +
|
"SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, event_id, reference_sha256" +
|
||||||
" FROM events WHERE event_nid = ANY($1)"
|
" FROM events WHERE event_nid = ANY($1)"
|
||||||
|
|
||||||
|
const bulkSelectEventReferenceSQL = "" +
|
||||||
|
"SELECT event_id, reference_sha256 FROM events WHERE event_nid = ANY($1)"
|
||||||
|
|
||||||
|
const bulkSelectEventIDSQL = "" +
|
||||||
|
"SELECT event_nid, event_id FROM events WHERE event_nid = ANY($1)"
|
||||||
|
|
||||||
type eventStatements struct {
|
type eventStatements struct {
|
||||||
insertEventStmt *sql.Stmt
|
insertEventStmt *sql.Stmt
|
||||||
selectEventStmt *sql.Stmt
|
selectEventStmt *sql.Stmt
|
||||||
bulkSelectStateEventByIDStmt *sql.Stmt
|
bulkSelectStateEventByIDStmt *sql.Stmt
|
||||||
bulkSelectStateAtEventByIDStmt *sql.Stmt
|
bulkSelectStateAtEventByIDStmt *sql.Stmt
|
||||||
updateEventStateStmt *sql.Stmt
|
updateEventStateStmt *sql.Stmt
|
||||||
|
selectEventSentToOutputStmt *sql.Stmt
|
||||||
|
updateEventSentToOutputStmt *sql.Stmt
|
||||||
|
selectEventIDStmt *sql.Stmt
|
||||||
bulkSelectStateAtEventAndReferenceStmt *sql.Stmt
|
bulkSelectStateAtEventAndReferenceStmt *sql.Stmt
|
||||||
|
bulkSelectEventReferenceStmt *sql.Stmt
|
||||||
|
bulkSelectEventIDStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *eventStatements) prepare(db *sql.DB) (err error) {
|
func (s *eventStatements) prepare(db *sql.DB) (err error) {
|
||||||
|
|
@ -86,25 +109,20 @@ func (s *eventStatements) prepare(db *sql.DB) (err error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if s.insertEventStmt, err = db.Prepare(insertEventSQL); err != nil {
|
|
||||||
return
|
return statementList{
|
||||||
}
|
{&s.insertEventStmt, insertEventSQL},
|
||||||
if s.selectEventStmt, err = db.Prepare(selectEventSQL); err != nil {
|
{&s.selectEventStmt, selectEventSQL},
|
||||||
return
|
{&s.bulkSelectStateEventByIDStmt, bulkSelectStateEventByIDSQL},
|
||||||
}
|
{&s.bulkSelectStateAtEventByIDStmt, bulkSelectStateAtEventByIDSQL},
|
||||||
if s.bulkSelectStateEventByIDStmt, err = db.Prepare(bulkSelectStateEventByIDSQL); err != nil {
|
{&s.updateEventStateStmt, updateEventStateSQL},
|
||||||
return
|
{&s.updateEventSentToOutputStmt, updateEventSentToOutputSQL},
|
||||||
}
|
{&s.selectEventSentToOutputStmt, selectEventSentToOutputSQL},
|
||||||
if s.bulkSelectStateAtEventByIDStmt, err = db.Prepare(bulkSelectStateAtEventByIDSQL); err != nil {
|
{&s.selectEventIDStmt, selectEventIDSQL},
|
||||||
return
|
{&s.bulkSelectStateAtEventAndReferenceStmt, bulkSelectStateAtEventAndReferenceSQL},
|
||||||
}
|
{&s.bulkSelectEventReferenceStmt, bulkSelectEventReferenceSQL},
|
||||||
if s.updateEventStateStmt, err = db.Prepare(updateEventStateSQL); err != nil {
|
{&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL},
|
||||||
return
|
}.prepare(db)
|
||||||
}
|
|
||||||
if s.bulkSelectStateAtEventAndReferenceStmt, err = db.Prepare(bulkSelectStateAtEventAndReferenceSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *eventStatements) insertEvent(
|
func (s *eventStatements) insertEvent(
|
||||||
|
|
@ -113,15 +131,11 @@ func (s *eventStatements) insertEvent(
|
||||||
referenceSHA256 []byte,
|
referenceSHA256 []byte,
|
||||||
authEventNIDs []types.EventNID,
|
authEventNIDs []types.EventNID,
|
||||||
) (types.EventNID, types.StateSnapshotNID, error) {
|
) (types.EventNID, types.StateSnapshotNID, error) {
|
||||||
nids := make([]int64, len(authEventNIDs))
|
|
||||||
for i := range authEventNIDs {
|
|
||||||
nids[i] = int64(authEventNIDs[i])
|
|
||||||
}
|
|
||||||
var eventNID int64
|
var eventNID int64
|
||||||
var stateNID int64
|
var stateNID int64
|
||||||
err := s.insertEventStmt.QueryRow(
|
err := s.insertEventStmt.QueryRow(
|
||||||
int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID), eventID, referenceSHA256,
|
int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID), eventID, referenceSHA256,
|
||||||
pq.Int64Array(nids),
|
eventNIDsAsArray(authEventNIDs),
|
||||||
).Scan(&eventNID, &stateNID)
|
).Scan(&eventNID, &stateNID)
|
||||||
return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err
|
return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err
|
||||||
}
|
}
|
||||||
|
|
@ -199,12 +213,23 @@ func (s *eventStatements) updateEventState(eventNID types.EventNID, stateNID typ
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *eventStatements) selectEventSentToOutput(txn *sql.Tx, eventNID types.EventNID) (sentToOutput bool, err error) {
|
||||||
|
err = txn.Stmt(s.selectEventSentToOutputStmt).QueryRow(int64(eventNID)).Scan(&sentToOutput)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *eventStatements) updateEventSentToOutput(txn *sql.Tx, eventNID types.EventNID) error {
|
||||||
|
_, err := txn.Stmt(s.updateEventSentToOutputStmt).Exec(int64(eventNID))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *eventStatements) selectEventID(txn *sql.Tx, eventNID types.EventNID) (eventID string, err error) {
|
||||||
|
err = txn.Stmt(s.selectEventIDStmt).QueryRow(int64(eventNID)).Scan(&eventID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func (s *eventStatements) bulkSelectStateAtEventAndReference(txn *sql.Tx, eventNIDs []types.EventNID) ([]types.StateAtEventAndReference, error) {
|
func (s *eventStatements) bulkSelectStateAtEventAndReference(txn *sql.Tx, eventNIDs []types.EventNID) ([]types.StateAtEventAndReference, error) {
|
||||||
nids := make([]int64, len(eventNIDs))
|
rows, err := txn.Stmt(s.bulkSelectStateAtEventAndReferenceStmt).Query(eventNIDsAsArray(eventNIDs))
|
||||||
for i := range eventNIDs {
|
|
||||||
nids[i] = int64(eventNIDs[i])
|
|
||||||
}
|
|
||||||
rows, err := txn.Stmt(s.bulkSelectStateAtEventAndReferenceStmt).Query(pq.Int64Array(nids))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
@ -238,3 +263,54 @@ func (s *eventStatements) bulkSelectStateAtEventAndReference(txn *sql.Tx, eventN
|
||||||
}
|
}
|
||||||
return results, nil
|
return results, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *eventStatements) bulkSelectEventReference(eventNIDs []types.EventNID) ([]gomatrixserverlib.EventReference, error) {
|
||||||
|
rows, err := s.bulkSelectEventReferenceStmt.Query(eventNIDsAsArray(eventNIDs))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
results := make([]gomatrixserverlib.EventReference, len(eventNIDs))
|
||||||
|
i := 0
|
||||||
|
for ; rows.Next(); i++ {
|
||||||
|
result := &results[i]
|
||||||
|
if err = rows.Scan(&result.EventID, &result.EventSHA256); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if i != len(eventNIDs) {
|
||||||
|
return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs))
|
||||||
|
}
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// bulkSelectEventID returns a map from numeric event ID to string event ID.
|
||||||
|
func (s *eventStatements) bulkSelectEventID(eventNIDs []types.EventNID) (map[types.EventNID]string, error) {
|
||||||
|
rows, err := s.bulkSelectEventIDStmt.Query(eventNIDsAsArray(eventNIDs))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
results := make(map[types.EventNID]string, len(eventNIDs))
|
||||||
|
i := 0
|
||||||
|
for ; rows.Next(); i++ {
|
||||||
|
var eventNID int64
|
||||||
|
var eventID string
|
||||||
|
if err = rows.Scan(&eventNID, &eventID); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
results[types.EventNID(eventNID)] = eventID
|
||||||
|
}
|
||||||
|
if i != len(eventNIDs) {
|
||||||
|
return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs))
|
||||||
|
}
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func eventNIDsAsArray(eventNIDs []types.EventNID) pq.Int64Array {
|
||||||
|
nids := make([]int64, len(eventNIDs))
|
||||||
|
for i := range eventNIDs {
|
||||||
|
nids[i] = int64(eventNIDs[i])
|
||||||
|
}
|
||||||
|
return nids
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -36,13 +36,11 @@ func (s *partitionOffsetStatements) prepare(db *sql.DB) (err error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if s.selectPartitionOffsetsStmt, err = db.Prepare(selectPartitionOffsetsSQL); err != nil {
|
|
||||||
return
|
return statementList{
|
||||||
}
|
{&s.selectPartitionOffsetsStmt, selectPartitionOffsetsSQL},
|
||||||
if s.upsertPartitionOffsetStmt, err = db.Prepare(upsertPartitionOffsetsSQL); err != nil {
|
{&s.upsertPartitionOffsetStmt, upsertPartitionOffsetsSQL},
|
||||||
return
|
}.prepare(db)
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *partitionOffsetStatements) selectPartitionOffsets(topic string) ([]types.PartitionOffset, error) {
|
func (s *partitionOffsetStatements) selectPartitionOffsets(topic string) ([]types.PartitionOffset, error) {
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,21 @@
|
||||||
|
package storage
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
)
|
||||||
|
|
||||||
|
// a statementList is a list of SQL statements to prepare and a pointer to where to store the resulting prepared statement.
|
||||||
|
type statementList []struct {
|
||||||
|
statement **sql.Stmt
|
||||||
|
sql string
|
||||||
|
}
|
||||||
|
|
||||||
|
// prepare the SQL for each statement in the list and assign the result to the prepared statement.
|
||||||
|
func (s statementList) prepare(db *sql.DB) (err error) {
|
||||||
|
for _, statement := range s {
|
||||||
|
if *statement.statement, err = db.Prepare(statement.sql); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
@ -50,13 +50,11 @@ func (s *previousEventStatements) prepare(db *sql.DB) (err error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if s.insertPreviousEventStmt, err = db.Prepare(insertPreviousEventSQL); err != nil {
|
|
||||||
return
|
return statementList{
|
||||||
}
|
{&s.insertPreviousEventStmt, insertPreviousEventSQL},
|
||||||
if s.selectPreviousEventExistsStmt, err = db.Prepare(selectPreviousEventExistsSQL); err != nil {
|
{&s.selectPreviousEventExistsStmt, selectPreviousEventExistsSQL},
|
||||||
return
|
}.prepare(db)
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *previousEventStatements) insertPreviousEvent(txn *sql.Tx, previousEventID string, previousEventReferenceSHA256 []byte, eventNID types.EventNID) error {
|
func (s *previousEventStatements) insertPreviousEvent(txn *sql.Tx, previousEventID string, previousEventReferenceSHA256 []byte, eventNID types.EventNID) error {
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,12 @@ CREATE TABLE IF NOT EXISTS rooms (
|
||||||
-- The most recent events in the room that aren't referenced by another event.
|
-- The most recent events in the room that aren't referenced by another event.
|
||||||
-- This list may empty if the server hasn't joined the room yet.
|
-- This list may empty if the server hasn't joined the room yet.
|
||||||
-- (The server will be in that state while it stores the events for the initial state of the room)
|
-- (The server will be in that state while it stores the events for the initial state of the room)
|
||||||
latest_event_nids BIGINT[] NOT NULL DEFAULT '{}'::BIGINT[]
|
latest_event_nids BIGINT[] NOT NULL DEFAULT '{}'::BIGINT[],
|
||||||
|
-- The last event written to the output log for this room.
|
||||||
|
last_event_sent_nid BIGINT NOT NULL DEFAULT 0,
|
||||||
|
-- The state of the room after the current set of latest events.
|
||||||
|
-- This will be 0 if there are no latest events in the room.
|
||||||
|
state_snapshot_nid BIGINT NOT NULL DEFAULT 0
|
||||||
);
|
);
|
||||||
`
|
`
|
||||||
|
|
||||||
|
|
@ -30,16 +35,20 @@ const selectRoomNIDSQL = "" +
|
||||||
"SELECT room_nid FROM rooms WHERE room_id = $1"
|
"SELECT room_nid FROM rooms WHERE room_id = $1"
|
||||||
|
|
||||||
const selectLatestEventNIDsSQL = "" +
|
const selectLatestEventNIDsSQL = "" +
|
||||||
"SELECT latest_event_nids FROM rooms WHERE room_nid = $1 FOR UPDATE"
|
"SELECT latest_event_nids FROM rooms WHERE room_nid = $1"
|
||||||
|
|
||||||
|
const selectLatestEventNIDsForUpdateSQL = "" +
|
||||||
|
"SELECT latest_event_nids, last_event_sent_nid, state_snapshot_nid FROM rooms WHERE room_nid = $1 FOR UPDATE"
|
||||||
|
|
||||||
const updateLatestEventNIDsSQL = "" +
|
const updateLatestEventNIDsSQL = "" +
|
||||||
"UPDATE rooms SET latest_event_nids = $2 WHERE room_nid = $1"
|
"UPDATE rooms SET latest_event_nids = $2, last_event_sent_nid = $3, state_snapshot_nid = $4 WHERE room_nid = $1"
|
||||||
|
|
||||||
type roomStatements struct {
|
type roomStatements struct {
|
||||||
insertRoomNIDStmt *sql.Stmt
|
insertRoomNIDStmt *sql.Stmt
|
||||||
selectRoomNIDStmt *sql.Stmt
|
selectRoomNIDStmt *sql.Stmt
|
||||||
selectLatestEventNIDsStmt *sql.Stmt
|
selectLatestEventNIDsStmt *sql.Stmt
|
||||||
updateLatestEventNIDsStmt *sql.Stmt
|
selectLatestEventNIDsForUpdateStmt *sql.Stmt
|
||||||
|
updateLatestEventNIDsStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *roomStatements) prepare(db *sql.DB) (err error) {
|
func (s *roomStatements) prepare(db *sql.DB) (err error) {
|
||||||
|
|
@ -47,19 +56,13 @@ func (s *roomStatements) prepare(db *sql.DB) (err error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if s.insertRoomNIDStmt, err = db.Prepare(insertRoomNIDSQL); err != nil {
|
return statementList{
|
||||||
return
|
{&s.insertRoomNIDStmt, insertRoomNIDSQL},
|
||||||
}
|
{&s.selectRoomNIDStmt, selectRoomNIDSQL},
|
||||||
if s.selectRoomNIDStmt, err = db.Prepare(selectRoomNIDSQL); err != nil {
|
{&s.selectLatestEventNIDsStmt, selectLatestEventNIDsSQL},
|
||||||
return
|
{&s.selectLatestEventNIDsForUpdateStmt, selectLatestEventNIDsForUpdateSQL},
|
||||||
}
|
{&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL},
|
||||||
if s.selectLatestEventNIDsStmt, err = db.Prepare(selectLatestEventNIDsSQL); err != nil {
|
}.prepare(db)
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.updateLatestEventNIDsStmt, err = db.Prepare(updateLatestEventNIDsSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *roomStatements) insertRoomNID(roomID string) (types.RoomNID, error) {
|
func (s *roomStatements) insertRoomNID(roomID string) (types.RoomNID, error) {
|
||||||
|
|
@ -74,9 +77,9 @@ func (s *roomStatements) selectRoomNID(roomID string) (types.RoomNID, error) {
|
||||||
return types.RoomNID(roomNID), err
|
return types.RoomNID(roomNID), err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *roomStatements) selectLatestEventsNIDsForUpdate(txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, error) {
|
func (s *roomStatements) selectLatestEventNIDs(roomNID types.RoomNID) ([]types.EventNID, error) {
|
||||||
var nids pq.Int64Array
|
var nids pq.Int64Array
|
||||||
err := txn.Stmt(s.selectLatestEventNIDsStmt).QueryRow(int64(roomNID)).Scan(&nids)
|
err := s.selectLatestEventNIDsStmt.QueryRow(int64(roomNID)).Scan(&nids)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
@ -87,11 +90,29 @@ func (s *roomStatements) selectLatestEventsNIDsForUpdate(txn *sql.Tx, roomNID ty
|
||||||
return eventNIDs, nil
|
return eventNIDs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *roomStatements) updateLatestEventNIDs(txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID) error {
|
func (s *roomStatements) selectLatestEventsNIDsForUpdate(txn *sql.Tx, roomNID types.RoomNID) (
|
||||||
nids := make([]int64, len(eventNIDs))
|
[]types.EventNID, types.EventNID, types.StateSnapshotNID, error,
|
||||||
for i := range eventNIDs {
|
) {
|
||||||
nids[i] = int64(eventNIDs[i])
|
var nids pq.Int64Array
|
||||||
|
var lastEventSentNID int64
|
||||||
|
var stateSnapshotNID int64
|
||||||
|
err := txn.Stmt(s.selectLatestEventNIDsForUpdateStmt).QueryRow(int64(roomNID)).Scan(&nids, &lastEventSentNID, &stateSnapshotNID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, 0, err
|
||||||
}
|
}
|
||||||
_, err := txn.Stmt(s.updateLatestEventNIDsStmt).Exec(roomNID, pq.Int64Array(nids))
|
eventNIDs := make([]types.EventNID, len(nids))
|
||||||
|
for i := range nids {
|
||||||
|
eventNIDs[i] = types.EventNID(nids[i])
|
||||||
|
}
|
||||||
|
return eventNIDs, types.EventNID(lastEventSentNID), types.StateSnapshotNID(stateSnapshotNID), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *roomStatements) updateLatestEventNIDs(
|
||||||
|
txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID,
|
||||||
|
stateSnapshotNID types.StateSnapshotNID,
|
||||||
|
) error {
|
||||||
|
_, err := txn.Stmt(s.updateLatestEventNIDsStmt).Exec(
|
||||||
|
roomNID, eventNIDsAsArray(eventNIDs), int64(lastEventSentNID), int64(stateSnapshotNID),
|
||||||
|
)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -57,16 +57,12 @@ func (s *stateBlockStatements) prepare(db *sql.DB) (err error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if s.insertStateDataStmt, err = db.Prepare(insertStateDataSQL); err != nil {
|
|
||||||
return
|
return statementList{
|
||||||
}
|
{&s.insertStateDataStmt, insertStateDataSQL},
|
||||||
if s.selectNextStateBlockNIDStmt, err = db.Prepare(selectNextStateBlockNIDSQL); err != nil {
|
{&s.selectNextStateBlockNIDStmt, selectNextStateBlockNIDSQL},
|
||||||
return
|
{&s.bulkSelectStateDataEntriesStmt, bulkSelectStateDataEntriesSQL},
|
||||||
}
|
}.prepare(db)
|
||||||
if s.bulkSelectStateDataEntriesStmt, err = db.Prepare(bulkSelectStateDataEntriesSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *stateBlockStatements) bulkInsertStateData(stateBlockNID types.StateBlockNID, entries []types.StateEntry) error {
|
func (s *stateBlockStatements) bulkInsertStateData(stateBlockNID types.StateBlockNID, entries []types.StateEntry) error {
|
||||||
|
|
|
||||||
|
|
@ -52,13 +52,11 @@ func (s *stateSnapshotStatements) prepare(db *sql.DB) (err error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if s.insertStateStmt, err = db.Prepare(insertStateSQL); err != nil {
|
|
||||||
return
|
return statementList{
|
||||||
}
|
{&s.insertStateStmt, insertStateSQL},
|
||||||
if s.bulkSelectStateBlockNIDsStmt, err = db.Prepare(bulkSelectStateBlockNIDsSQL); err != nil {
|
{&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL},
|
||||||
return
|
}.prepare(db)
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *stateSnapshotStatements) insertState(roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID) (stateNID types.StateSnapshotNID, err error) {
|
func (s *stateSnapshotStatements) insertState(roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID) (stateNID types.StateSnapshotNID, err error) {
|
||||||
|
|
|
||||||
|
|
@ -205,30 +205,62 @@ func (d *Database) StateEntries(stateBlockNIDs []types.StateBlockNID) ([]types.S
|
||||||
return d.statements.bulkSelectStateDataEntries(stateBlockNIDs)
|
return d.statements.bulkSelectStateDataEntries(stateBlockNIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// EventIDs implements input.RoomEventDatabase
|
||||||
|
func (d *Database) EventIDs(eventNIDs []types.EventNID) (map[types.EventNID]string, error) {
|
||||||
|
return d.statements.bulkSelectEventID(eventNIDs)
|
||||||
|
}
|
||||||
|
|
||||||
// GetLatestEventsForUpdate implements input.EventDatabase
|
// GetLatestEventsForUpdate implements input.EventDatabase
|
||||||
func (d *Database) GetLatestEventsForUpdate(roomNID types.RoomNID) ([]types.StateAtEventAndReference, types.RoomRecentEventsUpdater, error) {
|
func (d *Database) GetLatestEventsForUpdate(roomNID types.RoomNID) (types.RoomRecentEventsUpdater, error) {
|
||||||
txn, err := d.db.Begin()
|
txn, err := d.db.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
eventNIDs, err := d.statements.selectLatestEventsNIDsForUpdate(txn, roomNID)
|
eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err := d.statements.selectLatestEventsNIDsForUpdate(txn, roomNID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
txn.Rollback()
|
txn.Rollback()
|
||||||
return nil, nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
stateAndRefs, err := d.statements.bulkSelectStateAtEventAndReference(txn, eventNIDs)
|
stateAndRefs, err := d.statements.bulkSelectStateAtEventAndReference(txn, eventNIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
txn.Rollback()
|
txn.Rollback()
|
||||||
return nil, nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return stateAndRefs, &roomRecentEventsUpdater{txn, d}, nil
|
var lastEventIDSent string
|
||||||
|
if lastEventNIDSent != 0 {
|
||||||
|
lastEventIDSent, err = d.statements.selectEventID(txn, lastEventNIDSent)
|
||||||
|
if err != nil {
|
||||||
|
txn.Rollback()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &roomRecentEventsUpdater{txn, d, stateAndRefs, lastEventIDSent, currentStateSnapshotNID}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type roomRecentEventsUpdater struct {
|
type roomRecentEventsUpdater struct {
|
||||||
txn *sql.Tx
|
txn *sql.Tx
|
||||||
d *Database
|
d *Database
|
||||||
|
latestEvents []types.StateAtEventAndReference
|
||||||
|
lastEventIDSent string
|
||||||
|
currentStateSnapshotNID types.StateSnapshotNID
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LatestEvents implements types.RoomRecentEventsUpdater
|
||||||
|
func (u *roomRecentEventsUpdater) LatestEvents() []types.StateAtEventAndReference {
|
||||||
|
return u.latestEvents
|
||||||
|
}
|
||||||
|
|
||||||
|
// LastEventIDSent implements types.RoomRecentEventsUpdater
|
||||||
|
func (u *roomRecentEventsUpdater) LastEventIDSent() string {
|
||||||
|
return u.lastEventIDSent
|
||||||
|
}
|
||||||
|
|
||||||
|
// CurrentStateSnapshotNID implements types.RoomRecentEventsUpdater
|
||||||
|
func (u *roomRecentEventsUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID {
|
||||||
|
return u.currentStateSnapshotNID
|
||||||
|
}
|
||||||
|
|
||||||
|
// StorePreviousEvents implements types.RoomRecentEventsUpdater
|
||||||
func (u *roomRecentEventsUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error {
|
func (u *roomRecentEventsUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error {
|
||||||
for _, ref := range previousEventReferences {
|
for _, ref := range previousEventReferences {
|
||||||
if err := u.d.statements.insertPreviousEvent(u.txn, ref.EventID, ref.EventSHA256, eventNID); err != nil {
|
if err := u.d.statements.insertPreviousEvent(u.txn, ref.EventID, ref.EventSHA256, eventNID); err != nil {
|
||||||
|
|
@ -238,6 +270,7 @@ func (u *roomRecentEventsUpdater) StorePreviousEvents(eventNID types.EventNID, p
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsReferenced implements types.RoomRecentEventsUpdater
|
||||||
func (u *roomRecentEventsUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) {
|
func (u *roomRecentEventsUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) {
|
||||||
err := u.d.statements.selectPreviousEventExists(u.txn, eventReference.EventID, eventReference.EventSHA256)
|
err := u.d.statements.selectPreviousEventExists(u.txn, eventReference.EventID, eventReference.EventSHA256)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
|
@ -249,18 +282,52 @@ func (u *roomRecentEventsUpdater) IsReferenced(eventReference gomatrixserverlib.
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *roomRecentEventsUpdater) SetLatestEvents(roomNID types.RoomNID, latest []types.StateAtEventAndReference) error {
|
// SetLatestEvents implements types.RoomRecentEventsUpdater
|
||||||
|
func (u *roomRecentEventsUpdater) SetLatestEvents(
|
||||||
|
roomNID types.RoomNID, latest []types.StateAtEventAndReference, lastEventNIDSent types.EventNID,
|
||||||
|
currentStateSnapshotNID types.StateSnapshotNID,
|
||||||
|
) error {
|
||||||
eventNIDs := make([]types.EventNID, len(latest))
|
eventNIDs := make([]types.EventNID, len(latest))
|
||||||
for i := range latest {
|
for i := range latest {
|
||||||
eventNIDs[i] = latest[i].EventNID
|
eventNIDs[i] = latest[i].EventNID
|
||||||
}
|
}
|
||||||
return u.d.statements.updateLatestEventNIDs(u.txn, roomNID, eventNIDs)
|
return u.d.statements.updateLatestEventNIDs(u.txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// HasEventBeenSent implements types.RoomRecentEventsUpdater
|
||||||
|
func (u *roomRecentEventsUpdater) HasEventBeenSent(eventNID types.EventNID) (bool, error) {
|
||||||
|
return u.d.statements.selectEventSentToOutput(u.txn, eventNID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkEventAsSent implements types.RoomRecentEventsUpdater
|
||||||
|
func (u *roomRecentEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error {
|
||||||
|
return u.d.statements.updateEventSentToOutput(u.txn, eventNID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Commit implements types.RoomRecentEventsUpdater
|
||||||
func (u *roomRecentEventsUpdater) Commit() error {
|
func (u *roomRecentEventsUpdater) Commit() error {
|
||||||
return u.txn.Commit()
|
return u.txn.Commit()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Rollback implements types.RoomRecentEventsUpdater
|
||||||
func (u *roomRecentEventsUpdater) Rollback() error {
|
func (u *roomRecentEventsUpdater) Rollback() error {
|
||||||
return u.txn.Rollback()
|
return u.txn.Rollback()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RoomNID implements query.RoomserverQueryAPIDB
|
||||||
|
func (d *Database) RoomNID(roomID string) (types.RoomNID, error) {
|
||||||
|
roomNID, err := d.statements.selectRoomNID(roomID)
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
return roomNID, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// LatestEventIDs implements query.RoomserverQueryAPIDB
|
||||||
|
func (d *Database) LatestEventIDs(roomNID types.RoomNID) ([]gomatrixserverlib.EventReference, error) {
|
||||||
|
eventNIDs, err := d.statements.selectLatestEventNIDs(roomNID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return d.statements.bulkSelectEventReference(eventNIDs)
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -133,6 +133,12 @@ type StateEntryList struct {
|
||||||
// (On postgresql this wraps a database transaction that holds a "FOR UPDATE"
|
// (On postgresql this wraps a database transaction that holds a "FOR UPDATE"
|
||||||
// lock on the row holding the latest events for the room.)
|
// lock on the row holding the latest events for the room.)
|
||||||
type RoomRecentEventsUpdater interface {
|
type RoomRecentEventsUpdater interface {
|
||||||
|
// The latest event IDs and state in the room.
|
||||||
|
LatestEvents() []StateAtEventAndReference
|
||||||
|
// The event ID of the latest event written to the output log in the room.
|
||||||
|
LastEventIDSent() string
|
||||||
|
// The current state of the room.
|
||||||
|
CurrentStateSnapshotNID() StateSnapshotNID
|
||||||
// Store the previous events referenced by an event.
|
// Store the previous events referenced by an event.
|
||||||
// This adds the event NID to an entry in the database for each of the previous events.
|
// This adds the event NID to an entry in the database for each of the previous events.
|
||||||
// If there isn't an entry for one of previous events then an entry is created.
|
// If there isn't an entry for one of previous events then an entry is created.
|
||||||
|
|
@ -143,7 +149,14 @@ type RoomRecentEventsUpdater interface {
|
||||||
IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error)
|
IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error)
|
||||||
// Set the list of latest events for the room.
|
// Set the list of latest events for the room.
|
||||||
// This replaces the current list stored in the database with the given list
|
// This replaces the current list stored in the database with the given list
|
||||||
SetLatestEvents(roomNID RoomNID, latest []StateAtEventAndReference) error
|
SetLatestEvents(
|
||||||
|
roomNID RoomNID, latest []StateAtEventAndReference, lastEventNIDSent EventNID,
|
||||||
|
currentStateSnapshotNID StateSnapshotNID,
|
||||||
|
) error
|
||||||
|
// Check if the event has already be written to the output logs.
|
||||||
|
HasEventBeenSent(eventNID EventNID) (bool, error)
|
||||||
|
// Mark the event as having been sent to the output logs.
|
||||||
|
MarkEventAsSent(eventNID EventNID) error
|
||||||
// Commit the transaction
|
// Commit the transaction
|
||||||
Commit() error
|
Commit() error
|
||||||
// Rollback the transaction.
|
// Rollback the transaction.
|
||||||
|
|
|
||||||
22
travis-install-kafka.sh
Executable file
22
travis-install-kafka.sh
Executable file
|
|
@ -0,0 +1,22 @@
|
||||||
|
# /bin/bash
|
||||||
|
|
||||||
|
set -eu
|
||||||
|
|
||||||
|
# The mirror to download kafka from is picked from the list of mirrors at
|
||||||
|
# https://www.apache.org/dyn/closer.cgi?path=/kafka/0.10.2.0/kafka_2.11-0.10.2.0.tgz
|
||||||
|
# TODO: Check the signature since we are downloading over HTTP.
|
||||||
|
MIRROR=http://mirror.ox.ac.uk/sites/rsync.apache.org/kafka/0.10.2.0/kafka_2.11-0.10.2.0.tgz
|
||||||
|
|
||||||
|
# Only download the kafka if it isn't already downloaded.
|
||||||
|
test -f kafka.tgz || wget $MIRROR -O kafka.tgz
|
||||||
|
# Unpack the kafka over the top of any existing installation
|
||||||
|
mkdir -p kafka && tar xzf kafka.tgz -C kafka --strip-components 1
|
||||||
|
# Start the zookeeper running in the background.
|
||||||
|
# By default the zookeeper listens on localhost:2181
|
||||||
|
kafka/bin/zookeeper-server-start.sh -daemon kafka/config/zookeeper.properties
|
||||||
|
# Enable topic deletion so that the integration tests can create a fresh topic
|
||||||
|
# for each test run.
|
||||||
|
echo "delete.topic.enable=true" >> kafka/config/server.properties
|
||||||
|
# Start the kafka server running in the background.
|
||||||
|
# By default the kafka listens on localhost:9092
|
||||||
|
kafka/bin/kafka-server-start.sh -daemon kafka/config/server.properties
|
||||||
13
travis-test.sh
Executable file
13
travis-test.sh
Executable file
|
|
@ -0,0 +1,13 @@
|
||||||
|
#! /bin/bash
|
||||||
|
|
||||||
|
set -eu
|
||||||
|
|
||||||
|
# Check that the servers build
|
||||||
|
gb build github.com/matrix-org/dendrite/roomserver/roomserver
|
||||||
|
gb build github.com/matrix-org/dendrite/roomserver/roomserver-integration-tests
|
||||||
|
|
||||||
|
# Run the pre commit hooks
|
||||||
|
./hooks/pre-commit
|
||||||
|
|
||||||
|
# Run the integration tests
|
||||||
|
bin/roomserver-integration-tests
|
||||||
4
vendor/manifest
vendored
4
vendor/manifest
vendored
|
|
@ -98,7 +98,7 @@
|
||||||
{
|
{
|
||||||
"importpath": "github.com/matrix-org/util",
|
"importpath": "github.com/matrix-org/util",
|
||||||
"repository": "https://github.com/matrix-org/util",
|
"repository": "https://github.com/matrix-org/util",
|
||||||
"revision": "ccef6dc7c24a7c896d96b433a9107b7c47ecf828",
|
"revision": "28bd7491c8aafbf346ca23821664f0f9911ef52b",
|
||||||
"branch": "master"
|
"branch": "master"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|
@ -206,4 +206,4 @@
|
||||||
"branch": "master"
|
"branch": "master"
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
@ -25,11 +25,13 @@ func GetRequestID(ctx context.Context) string {
|
||||||
// ctxValueLogger is the key to extract the logrus Logger.
|
// ctxValueLogger is the key to extract the logrus Logger.
|
||||||
const ctxValueLogger = contextKeys("logger")
|
const ctxValueLogger = contextKeys("logger")
|
||||||
|
|
||||||
// GetLogger retrieves the logrus logger from the supplied context. Returns nil if there is no logger.
|
// GetLogger retrieves the logrus logger from the supplied context. Always returns a logger,
|
||||||
|
// even if there wasn't one originally supplied.
|
||||||
func GetLogger(ctx context.Context) *log.Entry {
|
func GetLogger(ctx context.Context) *log.Entry {
|
||||||
l := ctx.Value(ctxValueLogger)
|
l := ctx.Value(ctxValueLogger)
|
||||||
if l == nil {
|
if l == nil {
|
||||||
return nil
|
// Always return a logger so callers don't need to constantly nil check.
|
||||||
|
return log.WithField("context", "missing")
|
||||||
}
|
}
|
||||||
return l.(*log.Entry)
|
return l.(*log.Entry)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
15
vendor/src/github.com/matrix-org/util/json.go
vendored
15
vendor/src/github.com/matrix-org/util/json.go
vendored
|
|
@ -58,6 +58,21 @@ type JSONRequestHandler interface {
|
||||||
OnIncomingRequest(req *http.Request) JSONResponse
|
OnIncomingRequest(req *http.Request) JSONResponse
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// jsonRequestHandlerWrapper is a wrapper to allow in-line functions to conform to util.JSONRequestHandler
|
||||||
|
type jsonRequestHandlerWrapper struct {
|
||||||
|
function func(req *http.Request) JSONResponse
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnIncomingRequest implements util.JSONRequestHandler
|
||||||
|
func (r *jsonRequestHandlerWrapper) OnIncomingRequest(req *http.Request) JSONResponse {
|
||||||
|
return r.function(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewJSONRequestHandler converts the given OnIncomingRequest function into a JSONRequestHandler
|
||||||
|
func NewJSONRequestHandler(f func(req *http.Request) JSONResponse) JSONRequestHandler {
|
||||||
|
return &jsonRequestHandlerWrapper{f}
|
||||||
|
}
|
||||||
|
|
||||||
// Protect panicking HTTP requests from taking down the entire process, and log them using
|
// Protect panicking HTTP requests from taking down the entire process, and log them using
|
||||||
// the correct logger, returning a 500 with a JSON response rather than abruptly closing the
|
// the correct logger, returning a 500 with a JSON response rather than abruptly closing the
|
||||||
// connection. The http.Request MUST have a ctxValueLogger.
|
// connection. The http.Request MUST have a ctxValueLogger.
|
||||||
|
|
|
||||||
|
|
@ -164,8 +164,8 @@ func TestGetLogger(t *testing.T) {
|
||||||
|
|
||||||
noLoggerInReq, _ := http.NewRequest("GET", "http://example.com/foo", nil)
|
noLoggerInReq, _ := http.NewRequest("GET", "http://example.com/foo", nil)
|
||||||
ctxLogger = GetLogger(noLoggerInReq.Context())
|
ctxLogger = GetLogger(noLoggerInReq.Context())
|
||||||
if ctxLogger != nil {
|
if ctxLogger == nil {
|
||||||
t.Errorf("TestGetLogger wanted nil logger, got '%v'", ctxLogger)
|
t.Errorf("TestGetLogger wanted logger, got nil")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue