diff --git a/src/github.com/matrix-org/dendrite/clientapi/storage/output_room_events_table.go b/src/github.com/matrix-org/dendrite/clientapi/storage/output_room_events_table.go index dbbc7b228..26cf1cdb8 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/storage/output_room_events_table.go +++ b/src/github.com/matrix-org/dendrite/clientapi/storage/output_room_events_table.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/lib/pq" + "github.com/matrix-org/gomatrixserverlib" ) const outputRoomEventsSchema = ` diff --git a/src/github.com/matrix-org/dendrite/clientapi/storage/syncserver.go b/src/github.com/matrix-org/dendrite/clientapi/storage/syncserver.go index e1cc82421..6a85f703a 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/storage/syncserver.go +++ b/src/github.com/matrix-org/dendrite/clientapi/storage/syncserver.go @@ -13,6 +13,7 @@ type SyncServerDatabase struct { db *sql.DB partitions common.PartitionOffsetStatements events outputRoomEventsStatements + roomstate currentRoomStateStatements } // NewSyncServerDatabase creates a new sync server database @@ -30,7 +31,11 @@ func NewSyncServerDatabase(dataSourceName string) (*SyncServerDatabase, error) { if err = events.prepare(db); err != nil { return nil, err } - return &SyncServerDatabase{db, partitions, events}, nil + state := currentRoomStateStatements{} + if err := state.prepare(db); err != nil { + return nil, err + } + return &SyncServerDatabase{db, partitions, events, state}, nil } // WriteEvent into the database. It is not safe to call this function from multiple goroutines, as it would create races @@ -69,3 +74,22 @@ func (d *SyncServerDatabase) PartitionOffsets(topic string) ([]common.PartitionO func (d *SyncServerDatabase) SetPartitionOffset(topic string, partition int32, offset int64) error { return d.partitions.UpsertPartitionOffset(topic, partition, offset) } + +func runTransaction(db *sql.DB, fn func(txn *sql.Tx) error) (err error) { + txn, err := db.Begin() + if err != nil { + return + } + defer func() { + if r := recover(); r != nil { + txn.Rollback() + panic(r) + } else if err != nil { + txn.Rollback() + } else { + err = txn.Commit() + } + }() + err = fn(txn) + return +}