Add a component for sending event to remote matrix servers using federation

This commit is contained in:
Mark Haines 2017-06-27 15:52:56 +01:00
parent 54e7e3041b
commit 104551ecd0
8 changed files with 916 additions and 5 deletions

View file

@ -0,0 +1,74 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main
import (
"flag"
"net/http"
"os"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/common/config"
"github.com/matrix-org/dendrite/federationsender/consumers"
"github.com/matrix-org/dendrite/federationsender/queue"
"github.com/matrix-org/dendrite/federationsender/storage"
"github.com/matrix-org/gomatrixserverlib"
"github.com/prometheus/client_golang/prometheus"
log "github.com/Sirupsen/logrus"
)
var configPath = flag.String("config", "dendrite.yaml", "The path to the config file. For more information, see the config file in this repository.")
func main() {
common.SetupLogging(os.Getenv("LOG_DIR"))
flag.Parse()
if *configPath == "" {
log.Fatal("--config must be supplied")
}
cfg, err := config.Load(*configPath)
if err != nil {
log.Fatalf("Invalid config file: %s", err)
}
log.Info("config: ", cfg)
db, err := storage.NewDatabase(string(cfg.Database.FederationSender))
if err != nil {
log.Panicf("startup: failed to create federation sender database with data source %s : %s", cfg.Database.FederationSender, err)
}
federation := gomatrixserverlib.NewFederationClient(
cfg.Matrix.ServerName, cfg.Matrix.KeyID, cfg.Matrix.PrivateKey,
)
queues := queue.NewOutgoingQueues(cfg.Matrix.ServerName, federation)
consumer, err := consumers.NewOutputRoomEvent(cfg, queues, db)
if err != nil {
log.WithError(err).Panicf("startup: failed to create room server consumer")
}
if err = consumer.Start(); err != nil {
log.WithError(err).Panicf("startup: failed to start room server consumer")
}
http.DefaultServeMux.Handle("/metrics", prometheus.Handler())
if err := http.ListenAndServe(string(cfg.Listen.FederationSender), nil); err != nil {
panic(err)
}
}

View file

@ -122,16 +122,20 @@ type Dendrite struct {
// The RoomServer database stores information about matrix rooms.
// It is only accessed by the RoomServer.
RoomServer DataSource `yaml:"room_server"`
// The FederationSender database stores information used by the FederationSender
// It is only accessed by the FederationSender.
FederationSender DataSource `yaml:"federation_sender"`
} `yaml:"database"`
// The internal addresses the components will listen on.
// These should not be exposed externally as they expose metrics and debugging APIs.
Listen struct {
MediaAPI Address `yaml:"media_api"`
ClientAPI Address `yaml:"client_api"`
FederationAPI Address `yaml:"federation_api"`
SyncAPI Address `yaml:"sync_api"`
RoomServer Address `yaml:"room_server"`
MediaAPI Address `yaml:"media_api"`
ClientAPI Address `yaml:"client_api"`
FederationAPI Address `yaml:"federation_api"`
SyncAPI Address `yaml:"sync_api"`
RoomServer Address `yaml:"room_server"`
FederationSender Address `yaml:"federation_sender"`
} `yaml:"listen"`
}

View file

@ -0,0 +1,355 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package consumers
import (
"encoding/json"
"fmt"
"strings"
log "github.com/Sirupsen/logrus"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/common/config"
"github.com/matrix-org/dendrite/federationsender/queue"
"github.com/matrix-org/dendrite/federationsender/storage"
"github.com/matrix-org/dendrite/federationsender/types"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib"
sarama "gopkg.in/Shopify/sarama.v1"
)
// OutputRoomEvent consumes events that originated in the room server.
type OutputRoomEvent struct {
roomServerConsumer *common.ContinualConsumer
db *storage.Database
queues *queue.OutgoingQueues
query api.RoomserverQueryAPI
}
// NewOutputRoomEvent creates a new OutputRoomEvent consumer. Call Start() to begin consuming from room servers.
func NewOutputRoomEvent(cfg *config.Dendrite, queues *queue.OutgoingQueues, store *storage.Database) (*OutputRoomEvent, error) {
kafkaConsumer, err := sarama.NewConsumer(cfg.Kafka.Addresses, nil)
if err != nil {
return nil, err
}
roomServerURL := cfg.RoomServerURL()
consumer := common.ContinualConsumer{
Topic: string(cfg.Kafka.Topics.OutputRoomEvent),
Consumer: kafkaConsumer,
PartitionStore: store,
}
s := &OutputRoomEvent{
roomServerConsumer: &consumer,
db: store,
queues: queues,
query: api.NewRoomserverQueryAPIHTTP(roomServerURL, nil),
}
consumer.ProcessMessage = s.onMessage
return s, nil
}
// Start consuming from room servers
func (s *OutputRoomEvent) Start() error {
return s.roomServerConsumer.Start()
}
// onMessage is called when the sync server receives a new event from the room server output log.
// It is not safe for this function to be called from multiple goroutines, or else the
// sync stream position may race and be incorrectly calculated.
func (s *OutputRoomEvent) onMessage(msg *sarama.ConsumerMessage) error {
// Parse out the event JSON
var output api.OutputRoomEvent
if err := json.Unmarshal(msg.Value, &output); err != nil {
// If the message was invalid, log it and move on to the next message in the stream
log.WithError(err).Errorf("roomserver output log: message parse failure")
return nil
}
ev, err := gomatrixserverlib.NewEventFromTrustedJSON(output.Event, false)
if err != nil {
log.WithError(err).Errorf("roomserver output log: event parse failure")
return nil
}
log.WithFields(log.Fields{
"event_id": ev.EventID(),
"room_id": ev.RoomID(),
"send_as_server": output.SendAsServer,
}).Info("received event from roomserver")
err = s.processMessage(output, ev)
if err != nil {
// panic rather than continue with an inconsistent database
log.WithFields(log.Fields{
"event": string(ev.JSON()),
log.ErrorKey: err,
"add": output.AddsStateEventIDs,
"del": output.RemovesStateEventIDs,
}).Panicf("roomserver output log: write event failure")
return nil
}
return nil
}
func (s *OutputRoomEvent) processMessage(ore api.OutputRoomEvent, ev gomatrixserverlib.Event) error {
addsStateEvents, err := s.lookupStateEvents(ore.AddsStateEventIDs, ev)
if err != nil {
return err
}
addsJoinedHosts, err := joinedHostsFromEvents(addsStateEvents)
if err != nil {
return err
}
// Update our copy of the current state.
// We keep a copy of the current state because the state at each event is
// expressed as a delta against the current state.
// TODO: handle EventIDMismatchError and recover the current state by talking
// to the roomserver
oldJoinedHosts, err := s.db.UpdateRoom(
ev.RoomID(), ore.LastSentEventID, ev.EventID(),
addsJoinedHosts, ore.RemovesStateEventIDs,
)
if err != nil {
return err
}
if ore.SendAsServer == "" {
// Ignore event that we don't need to send anywhere.
return nil
}
joinedHosts, err := s.joinedHostsAtEvent(ore, ev, oldJoinedHosts)
if err != nil {
return err
}
if err = s.queues.SendEvent(
&ev, gomatrixserverlib.ServerName(ore.SendAsServer), joinedHosts,
); err != nil {
return err
}
// TODO: Add the event to the transaction queue
return nil
}
// joinedHostsAtEvent works out a list of matrix servers that were joined to
// the room at the event.
// It is important to use the state at the event for sending messages because:
// 1) We shouldn't send messages to servers that weren't in the room.
// 2) If a server is kicked from the rooms it should still be told about the
// kick event,
// Usually the list can be calculated locally, but sometimes it will need fetch
// events from the room server.
// Returns an error if there was a problem talking to the room server.
func (s *OutputRoomEvent) joinedHostsAtEvent(
ore api.OutputRoomEvent, ev gomatrixserverlib.Event, oldJoinedHosts []types.JoinedHost,
) ([]gomatrixserverlib.ServerName, error) {
combinedAdds, combinedRemoves := combineDeltas(
ore.AddsStateEventIDs, ore.RemovesStateEventIDs,
ore.StateBeforeAddsEventIDs, ore.StateBeforeRemovesEventIDs,
)
combinedAddsEvents, err := s.lookupStateEvents(combinedAdds, ev)
if err != nil {
return nil, err
}
combinedAddsJoinedHosts, err := joinedHostsFromEvents(combinedAddsEvents)
if err != nil {
return nil, err
}
removed := map[string]bool{}
for _, eventID := range combinedRemoves {
removed[eventID] = true
}
joined := map[gomatrixserverlib.ServerName]bool{}
for _, joinedHost := range oldJoinedHosts {
if removed[joinedHost.EventID] {
// This m.room.member event is part of the current state of the
// room, but not part of the state at the event we are processing
// Therefore we can't use it to tell whether the server was in
// the room at the event.
continue
}
joined[joinedHost.ServerName] = true
}
for _, joinedHost := range combinedAddsJoinedHosts {
// This m.room.member event was part of the state of the room at the
// event, but isn't part of the current state of the room now.
joined[joinedHost.ServerName] = true
}
var result []gomatrixserverlib.ServerName
for serverName, include := range joined {
if include {
result = append(result, serverName)
}
}
return result, nil
}
// joinedHostsFromEvents turns a list of state events into a list of joined hosts.
// This errors if one of the events was invalid.
// It should be impossible for an invalid event to get this far in the pipeline.
func joinedHostsFromEvents(evs []gomatrixserverlib.Event) ([]types.JoinedHost, error) {
var joinedHosts []types.JoinedHost
for _, ev := range evs {
if ev.Type() != "m.room.member" || ev.StateKey() == nil {
continue
}
var content struct {
Membership string `json:"membership"`
}
if err := json.Unmarshal(ev.Content(), &content); err != nil {
return nil, err
}
if content.Membership != "join" {
continue
}
serverName, err := domainFromID(*ev.StateKey())
if err != nil {
return nil, err
}
joinedHosts = append(joinedHosts, types.JoinedHost{
EventID: ev.EventID(), ServerName: serverName,
})
}
return joinedHosts, nil
}
// combineDeltas combines two deltas into a single delta.
func combineDeltas(adds1, removes1, adds2, removes2 []string) (adds, removes []string) {
addSet := map[string]bool{}
removeSet := map[string]bool{}
var ok bool
for _, value := range adds1 {
addSet[value] = true
}
for _, value := range removes1 {
removeSet[value] = true
}
for _, value := range adds2 {
if _, ok = removeSet[value]; ok {
removeSet[value] = false
} else {
addSet[value] = true
}
}
for _, value := range removes2 {
if _, ok = addSet[value]; ok {
addSet[value] = false
} else {
removeSet[value] = true
}
}
for value, include := range addSet {
if include {
adds = append(adds, value)
}
}
for value, include := range removeSet {
if include {
removes = append(removes, value)
}
}
return
}
// lookupStateEvents looks up the state events that are added by a new event.
func (s *OutputRoomEvent) lookupStateEvents(
addsStateEventIDs []string, event gomatrixserverlib.Event,
) ([]gomatrixserverlib.Event, error) {
// Fast path if there aren't any new state events.
if len(addsStateEventIDs) == 0 {
return nil, nil
}
// Fast path if the only state event added is the event itself.
if len(addsStateEventIDs) == 1 && addsStateEventIDs[0] == event.EventID() {
return []gomatrixserverlib.Event{event}, nil
}
missing := addsStateEventIDs
var result []gomatrixserverlib.Event
// Check if event itself is being added.
for _, eventID := range missing {
if eventID == event.EventID() {
result = append(result, event)
break
}
}
missing = missingEventsFrom(result, addsStateEventIDs)
if len(missing) == 0 {
return result, nil
}
// At this point the missing events are neither the event itself nor are
// they present in our local database. Our only option is to fetch them
// from the roomserver using the query API.
eventReq := api.QueryEventsByIDRequest{EventIDs: missing}
var eventResp api.QueryEventsByIDResponse
if err := s.query.QueryEventsByID(&eventReq, &eventResp); err != nil {
return nil, err
}
result = append(result, eventResp.Events...)
missing = missingEventsFrom(result, addsStateEventIDs)
if len(missing) != 0 {
return nil, fmt.Errorf(
"missing %d state events IDs at event %q", len(missing), event.EventID(),
)
}
return result, nil
}
func missingEventsFrom(events []gomatrixserverlib.Event, required []string) []string {
have := map[string]bool{}
for _, event := range events {
have[event.EventID()] = true
}
var missing []string
for _, eventID := range required {
if !have[eventID] {
missing = append(missing, eventID)
}
}
return missing
}
// domainFromID returns everything after the first ":" character to extract
// the domain part of a matrix ID.
// TODO: duplicated from gomatrixserverlib.
func domainFromID(id string) (gomatrixserverlib.ServerName, error) {
// IDs have the format: SIGIL LOCALPART ":" DOMAIN
// Split on the first ":" character since the domain can contain ":"
// characters.
parts := strings.SplitN(id, ":", 2)
if len(parts) != 2 {
// The ID must have a ":" character.
return "", fmt.Errorf("invalid ID: %q", id)
}
// Return everything after the first ":" character.
return gomatrixserverlib.ServerName(parts[1]), nil
}

View file

@ -0,0 +1,145 @@
package queue
import (
"fmt"
log "github.com/Sirupsen/logrus"
"github.com/matrix-org/gomatrixserverlib"
"sync"
"time"
)
// OutgoingQueues is a collection of queues for sending transactions to other
// matrix servers
type OutgoingQueues struct {
mutex sync.Mutex
queues map[gomatrixserverlib.ServerName]*outgoingQueue
origin gomatrixserverlib.ServerName
client *gomatrixserverlib.FederationClient
}
// NewOutgoingQueues makes a new OutgoingQueues
func NewOutgoingQueues(origin gomatrixserverlib.ServerName, client *gomatrixserverlib.FederationClient) *OutgoingQueues {
return &OutgoingQueues{
origin: origin,
client: client,
queues: map[gomatrixserverlib.ServerName]*outgoingQueue{},
}
}
// SendEvent sends an event to the destinations
func (oqs *OutgoingQueues) SendEvent(
ev *gomatrixserverlib.Event, origin gomatrixserverlib.ServerName,
destinations []gomatrixserverlib.ServerName,
) error {
if origin != oqs.origin {
return fmt.Errorf(
"sendevent: unexpected server to send as: got %q expected %q",
origin, oqs.origin,
)
}
// Remove our own server from the list of destinations.
destinations = filterDestinations(oqs.origin, destinations)
log.WithFields(log.Fields{
"destinations": destinations, "event": ev.EventID(),
}).Info("Sending event")
oqs.mutex.Lock()
defer oqs.mutex.Unlock()
for _, destination := range destinations {
if destination == oqs.origin {
continue
}
oq := oqs.queues[destination]
if oq == nil {
oq = &outgoingQueue{
origin: oqs.origin,
destination: destination,
client: oqs.client,
}
oqs.queues[destination] = oq
}
oq.sendEvent(ev)
}
return nil
}
func filterDestinations(origin gomatrixserverlib.ServerName, destinations []gomatrixserverlib.ServerName) []gomatrixserverlib.ServerName {
var result []gomatrixserverlib.ServerName
for _, destination := range destinations {
if destination == origin {
continue
}
result = append(result, destination)
}
return result
}
type outgoingQueue struct {
mutex sync.Mutex
client *gomatrixserverlib.FederationClient
origin gomatrixserverlib.ServerName
destination gomatrixserverlib.ServerName
running bool
sentCounter int
lastTransactionIDs []gomatrixserverlib.TransactionID
pendingEvents []*gomatrixserverlib.Event
}
func (oq *outgoingQueue) sendEvent(ev *gomatrixserverlib.Event) {
oq.mutex.Lock()
defer oq.mutex.Unlock()
oq.pendingEvents = append(oq.pendingEvents, ev)
if !oq.running {
go oq.backgroundSend()
}
}
func (oq *outgoingQueue) backgroundSend() {
for {
t := oq.next()
if t == nil {
// If the queue is empty then stop processing for this destination.
// TODO: Remove this destination from the queue map.
return
}
// TODO: handle retries.
// TODO: blacklist uncooperative servers.
_, err := oq.client.SendTransaction(*t)
if err != nil {
log.WithFields(log.Fields{
"destination": oq.destination,
log.ErrorKey: err,
}).Info("problem sending transaction")
}
}
}
func (oq *outgoingQueue) next() *gomatrixserverlib.Transaction {
oq.mutex.Lock()
defer oq.mutex.Unlock()
if len(oq.pendingEvents) == 0 {
oq.running = false
return nil
}
var t gomatrixserverlib.Transaction
now := gomatrixserverlib.AsTimestamp(time.Now())
t.TransactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, oq.sentCounter))
t.Origin = oq.origin
t.Destination = oq.destination
t.OriginServerTS = now
t.PreviousIDs = oq.lastTransactionIDs
if t.PreviousIDs == nil {
t.PreviousIDs = []gomatrixserverlib.TransactionID{}
}
oq.lastTransactionIDs = []gomatrixserverlib.TransactionID{t.TransactionID}
for _, pdu := range oq.pendingEvents {
t.PDUs = append(t.PDUs, *pdu)
}
oq.pendingEvents = nil
oq.sentCounter += len(t.PDUs)
return &t
}

View file

@ -0,0 +1,109 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package storage
import (
"database/sql"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/federationsender/types"
"github.com/matrix-org/gomatrixserverlib"
)
const joinedHostsSchema = `
-- The joined_hosts table stores a list of m.room.member event ids in the
-- current state for each room where the membership is "join".
CREATE TABLE IF NOT EXISTS joined_hosts (
-- The string ID of the room.
room_id TEXT NOT NULL,
-- The event ID of the m.room.member
event_id TEXT NOT NULL,
-- The domain part of the user ID the m.room.member event is for.
server_name TEXT NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS joined_hosts_event_id_idx
ON joined_hosts (event_id);
CREATE INDEX IF NOT EXISTS joined_hosts_room_id_idx
ON joined_hosts (room_id)
`
const insertJoinedHostsSQL = "" +
"INSERT INTO joined_hosts (room_id, event_id, server_name)" +
" VALUES ($1, $2, $3)"
const deleteJoinedHostsSQL = "" +
"DELETE FROM joined_hosts WHERE event_id = ANY($1)"
const selectJoinedHostsSQL = "" +
"SELECT event_id, server_name FROM joined_hosts" +
" WHERE room_id = $1"
type joinedHostsStatements struct {
insertJoinedHostsStmt *sql.Stmt
deleteJoinedHostsStmt *sql.Stmt
selectJoinedHostsStmt *sql.Stmt
}
func (s *joinedHostsStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(joinedHostsSchema)
if err != nil {
return
}
if s.insertJoinedHostsStmt, err = db.Prepare(insertJoinedHostsSQL); err != nil {
return
}
if s.deleteJoinedHostsStmt, err = db.Prepare(deleteJoinedHostsSQL); err != nil {
return
}
if s.selectJoinedHostsStmt, err = db.Prepare(selectJoinedHostsSQL); err != nil {
return
}
return
}
func (s *joinedHostsStatements) insertJoinedHosts(
txn *sql.Tx, roomID, eventID string, serverName gomatrixserverlib.ServerName,
) error {
_, err := txn.Stmt(s.insertJoinedHostsStmt).Exec(roomID, eventID, serverName)
return err
}
func (s *joinedHostsStatements) deleteJoinedHosts(txn *sql.Tx, eventIDs []string) error {
_, err := txn.Stmt(s.deleteJoinedHostsStmt).Exec(pq.StringArray(eventIDs))
return err
}
func (s *joinedHostsStatements) selectJoinedHosts(txn *sql.Tx, roomID string,
) ([]types.JoinedHost, error) {
rows, err := txn.Stmt(s.selectJoinedHostsStmt).Query(roomID)
if err != nil {
return nil, err
}
defer rows.Close()
var result []types.JoinedHost
for rows.Next() {
var eventID, serverName string
if err = rows.Scan(&eventID, &serverName); err != nil {
return nil, err
}
result = append(result, types.JoinedHost{
EventID: eventID,
ServerName: gomatrixserverlib.ServerName(serverName),
})
}
return result, nil
}

View file

@ -0,0 +1,84 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package storage
import (
"database/sql"
)
const roomSchema = `
CREATE TABLE IF NOT EXISTS rooms (
-- The string ID of the room
room_id TEXT NOT NULL CONSTRAINT room_id_unique UNIQUE,
-- The most recent event state by the room server.
-- We can use this to tell if our view of the room state has become
-- desynchronised.
last_event_id TEXT NOT NULL
);`
const insertRoomSQL = "" +
"INSERT INTO rooms (room_id, last_event_id)" +
" VALUES ($1, '')" +
" ON CONFLICT ON CONSTRAINT room_id_unique" +
" DO NOTHING"
const selectRoomForUpdateSQL = "" +
"SELECT last_event_id FROM rooms WHERE room_id = $1 FOR UPDATE"
const updateRoomSQL = "" +
"UPDATE rooms SET last_event_id = $2 WHERE room_id = $1"
type roomStatements struct {
insertRoomStmt *sql.Stmt
selectRoomForUpdateStmt *sql.Stmt
updateRoomStmt *sql.Stmt
}
func (s *roomStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(roomSchema)
if err != nil {
return
}
if s.insertRoomStmt, err = db.Prepare(insertRoomSQL); err != nil {
return
}
if s.selectRoomForUpdateStmt, err = db.Prepare(selectRoomForUpdateSQL); err != nil {
return
}
if s.updateRoomStmt, err = db.Prepare(updateRoomSQL); err != nil {
return
}
return
}
func (s *roomStatements) insertRoom(txn *sql.Tx, roomID string) error {
_, err := txn.Stmt(s.insertRoomStmt).Exec(roomID)
return err
}
func (s *roomStatements) selectRoomForUpdate(txn *sql.Tx, roomID string) (string, error) {
var lastEventID string
err := txn.Stmt(s.selectRoomForUpdateStmt).QueryRow(roomID).Scan(&lastEventID)
if err != nil {
return "", err
}
return lastEventID, nil
}
func (s *roomStatements) updateRoom(txn *sql.Tx, roomID, lastEventID string) error {
_, err := txn.Stmt(s.updateRoomStmt).Exec(roomID, lastEventID)
return err
}

View file

@ -0,0 +1,110 @@
package storage
import (
"database/sql"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/federationsender/types"
)
// Database stores information needed by the federation sender
type Database struct {
joinedHostsStatements
roomStatements
common.PartitionOffsetStatements
db *sql.DB
}
// NewDatabase opens a new database
func NewDatabase(dataSourceName string) (*Database, error) {
var result Database
var err error
if result.db, err = sql.Open("postgres", dataSourceName); err != nil {
return nil, err
}
if err = result.prepare(); err != nil {
return nil, err
}
return &result, nil
}
func (d *Database) prepare() error {
var err error
if err = d.joinedHostsStatements.prepare(d.db); err != nil {
return err
}
if err = d.roomStatements.prepare(d.db); err != nil {
return err
}
if err = d.PartitionOffsetStatements.Prepare(d.db); err != nil {
return err
}
return nil
}
// PartitionOffsets implements common.PartitionStorer
func (d *Database) PartitionOffsets(topic string) ([]common.PartitionOffset, error) {
return d.SelectPartitionOffsets(topic)
}
// SetPartitionOffset implements common.PartitionStorer
func (d *Database) SetPartitionOffset(topic string, partition int32, offset int64) error {
return d.UpsertPartitionOffset(topic, partition, offset)
}
// UpdateRoom updates the joined hosts for a room.
func (d *Database) UpdateRoom(
roomID, oldEventID, newEventID string,
addHosts []types.JoinedHost,
removeHosts []string,
) (joinedHosts []types.JoinedHost, err error) {
err = runTransaction(d.db, func(txn *sql.Tx) error {
if err = d.insertRoom(txn, roomID); err != nil {
return err
}
lastSentEventID, err := d.selectRoomForUpdate(txn, roomID)
if err != nil {
return err
}
if lastSentEventID != oldEventID {
return types.EventIDMismatchError{lastSentEventID, oldEventID}
}
joinedHosts, err = d.selectJoinedHosts(txn, roomID)
if err != nil {
return err
}
for _, add := range addHosts {
err = d.insertJoinedHosts(txn, roomID, add.EventID, add.ServerName)
if err != nil {
return err
}
}
if err = d.deleteJoinedHosts(txn, removeHosts); err != nil {
return err
}
return d.updateRoom(txn, roomID, newEventID)
})
return
}
func runTransaction(db *sql.DB, fn func(txn *sql.Tx) error) (err error) {
txn, err := db.Begin()
if err != nil {
return
}
defer func() {
if r := recover(); r != nil {
txn.Rollback()
panic(r)
} else if err != nil {
txn.Rollback()
} else {
err = txn.Commit()
}
}()
err = fn(txn)
return
}

View file

@ -0,0 +1,30 @@
package types
import (
"fmt"
"github.com/matrix-org/gomatrixserverlib"
)
// A JoinedHost is a server that is joined to a matrix room.
type JoinedHost struct {
// THe EventID of a m.room.member event that joins a server to a room.
EventID string
// The
ServerName gomatrixserverlib.ServerName
}
// A EventIDMismatchError indicates that we have got out of sync with the
// rooms erver.
type EventIDMismatchError struct {
// The event ID we have stored in our local database.
DatabaseID string
// The event ID received from the room server.
RoomServerID string
}
func (l EventIDMismatchError) Error() string {
return fmt.Sprintf(
"mismatched last sent event ID: had %q in database got %q from room server",
l.DatabaseID, l.RoomServerID,
)
}