mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-21 05:43:09 -06:00
Merge branch 'master' into kegan/fedsender-as-fedclient-subset
This commit is contained in:
commit
2331326392
|
|
@ -67,7 +67,7 @@ const (
|
||||||
|
|
||||||
type eventsStatements struct {
|
type eventsStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer *sqlutil.TransactionWriter
|
writer sqlutil.TransactionWriter
|
||||||
selectEventsByApplicationServiceIDStmt *sql.Stmt
|
selectEventsByApplicationServiceIDStmt *sql.Stmt
|
||||||
countEventsByApplicationServiceIDStmt *sql.Stmt
|
countEventsByApplicationServiceIDStmt *sql.Stmt
|
||||||
insertEventStmt *sql.Stmt
|
insertEventStmt *sql.Stmt
|
||||||
|
|
|
||||||
|
|
@ -38,7 +38,7 @@ const selectTxnIDSQL = `
|
||||||
|
|
||||||
type txnStatements struct {
|
type txnStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer *sqlutil.TransactionWriter
|
writer sqlutil.TransactionWriter
|
||||||
selectTxnIDStmt *sql.Stmt
|
selectTxnIDStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,7 @@ package routing
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/httputil"
|
"github.com/matrix-org/dendrite/clientapi/httputil"
|
||||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||||
|
|
@ -35,6 +36,10 @@ type sendEventResponse struct {
|
||||||
EventID string `json:"event_id"`
|
EventID string `json:"event_id"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
userRoomSendMutexes sync.Map // (roomID+userID) -> mutex. mutexes to ensure correct ordering of sendEvents
|
||||||
|
)
|
||||||
|
|
||||||
// SendEvent implements:
|
// SendEvent implements:
|
||||||
// /rooms/{roomID}/send/{eventType}
|
// /rooms/{roomID}/send/{eventType}
|
||||||
// /rooms/{roomID}/send/{eventType}/{txnID}
|
// /rooms/{roomID}/send/{eventType}/{txnID}
|
||||||
|
|
@ -63,6 +68,13 @@ func SendEvent(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// create a mutex for the specific user in the specific room
|
||||||
|
// this avoids a situation where events that are received in quick succession are sent to the roomserver in a jumbled order
|
||||||
|
userID := device.UserID
|
||||||
|
mutex, _ := userRoomSendMutexes.LoadOrStore(roomID+userID, &sync.Mutex{})
|
||||||
|
mutex.(*sync.Mutex).Lock()
|
||||||
|
defer mutex.(*sync.Mutex).Unlock()
|
||||||
|
|
||||||
e, resErr := generateSendEvent(req, device, roomID, eventType, stateKey, cfg, rsAPI)
|
e, resErr := generateSendEvent(req, device, roomID, eventType, stateKey, cfg, rsAPI)
|
||||||
if resErr != nil {
|
if resErr != nil {
|
||||||
return *resErr
|
return *resErr
|
||||||
|
|
|
||||||
|
|
@ -83,7 +83,7 @@ const selectKnownUsersSQL = "" +
|
||||||
|
|
||||||
type currentRoomStateStatements struct {
|
type currentRoomStateStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer *sqlutil.TransactionWriter
|
writer sqlutil.TransactionWriter
|
||||||
upsertRoomStateStmt *sql.Stmt
|
upsertRoomStateStmt *sql.Stmt
|
||||||
deleteRoomStateByEventIDStmt *sql.Stmt
|
deleteRoomStateByEventIDStmt *sql.Stmt
|
||||||
selectRoomIDsWithMembershipStmt *sql.Stmt
|
selectRoomIDsWithMembershipStmt *sql.Stmt
|
||||||
|
|
|
||||||
|
|
@ -66,10 +66,16 @@ type ServerStatistics struct {
|
||||||
serverName gomatrixserverlib.ServerName //
|
serverName gomatrixserverlib.ServerName //
|
||||||
blacklisted atomic.Bool // is the node blacklisted
|
blacklisted atomic.Bool // is the node blacklisted
|
||||||
backoffStarted atomic.Bool // is the backoff started
|
backoffStarted atomic.Bool // is the backoff started
|
||||||
|
backoffUntil atomic.Value // time.Time until this backoff interval ends
|
||||||
backoffCount atomic.Uint32 // number of times BackoffDuration has been called
|
backoffCount atomic.Uint32 // number of times BackoffDuration has been called
|
||||||
successCounter atomic.Uint32 // how many times have we succeeded?
|
successCounter atomic.Uint32 // how many times have we succeeded?
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// duration returns how long the next backoff interval should be.
|
||||||
|
func (s *ServerStatistics) duration(count uint32) time.Duration {
|
||||||
|
return time.Second * time.Duration(math.Exp2(float64(count)))
|
||||||
|
}
|
||||||
|
|
||||||
// Success updates the server statistics with a new successful
|
// Success updates the server statistics with a new successful
|
||||||
// attempt, which increases the sent counter and resets the idle and
|
// attempt, which increases the sent counter and resets the idle and
|
||||||
// failure counters. If a host was blacklisted at this point then
|
// failure counters. If a host was blacklisted at this point then
|
||||||
|
|
@ -88,11 +94,36 @@ func (s *ServerStatistics) Success() {
|
||||||
|
|
||||||
// Failure marks a failure and starts backing off if needed.
|
// Failure marks a failure and starts backing off if needed.
|
||||||
// The next call to BackoffIfRequired will do the right thing
|
// The next call to BackoffIfRequired will do the right thing
|
||||||
// after this.
|
// after this. It will return the time that the current failure
|
||||||
func (s *ServerStatistics) Failure() {
|
// will result in backoff waiting until, and a bool signalling
|
||||||
|
// whether we have blacklisted and therefore to give up.
|
||||||
|
func (s *ServerStatistics) Failure() (time.Time, bool) {
|
||||||
|
// If we aren't already backing off, this call will start
|
||||||
|
// a new backoff period. Reset the counter to 0 so that
|
||||||
|
// we backoff only for short periods of time to start with.
|
||||||
if s.backoffStarted.CAS(false, true) {
|
if s.backoffStarted.CAS(false, true) {
|
||||||
s.backoffCount.Store(0)
|
s.backoffCount.Store(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check if we have blacklisted this node.
|
||||||
|
if s.blacklisted.Load() {
|
||||||
|
return time.Now(), true
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we're already backing off and we haven't yet surpassed
|
||||||
|
// the deadline then return that. Repeated calls to Failure
|
||||||
|
// within a single backoff interval will have no side effects.
|
||||||
|
if until, ok := s.backoffUntil.Load().(time.Time); ok && !time.Now().After(until) {
|
||||||
|
return until, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// We're either backing off and have passed the deadline, or
|
||||||
|
// we aren't backing off, so work out what the next interval
|
||||||
|
// will be.
|
||||||
|
count := s.backoffCount.Load()
|
||||||
|
until := time.Now().Add(s.duration(count))
|
||||||
|
s.backoffUntil.Store(until)
|
||||||
|
return until, false
|
||||||
}
|
}
|
||||||
|
|
||||||
// BackoffIfRequired will block for as long as the current
|
// BackoffIfRequired will block for as long as the current
|
||||||
|
|
@ -103,21 +134,8 @@ func (s *ServerStatistics) BackoffIfRequired(backingOff atomic.Bool, interrupt <
|
||||||
return 0, false
|
return 0, false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if the interrupt channel is already closed. If it is
|
|
||||||
// then this call should have no side effects but still return
|
|
||||||
// the current duration.
|
|
||||||
select {
|
|
||||||
case <-interrupt:
|
|
||||||
count := s.backoffCount.Load()
|
|
||||||
return time.Second * time.Duration(math.Exp2(float64(count))), false
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
// Work out how many times we've backed off so far.
|
|
||||||
count := s.backoffCount.Inc()
|
|
||||||
duration := time.Second * time.Duration(math.Exp2(float64(count)))
|
|
||||||
|
|
||||||
// Work out if we should be blacklisting at this point.
|
// Work out if we should be blacklisting at this point.
|
||||||
|
count := s.backoffCount.Inc()
|
||||||
if count >= s.statistics.FailuresUntilBlacklist {
|
if count >= s.statistics.FailuresUntilBlacklist {
|
||||||
// We've exceeded the maximum amount of times we're willing
|
// We've exceeded the maximum amount of times we're willing
|
||||||
// to back off, which is probably in the region of hours by
|
// to back off, which is probably in the region of hours by
|
||||||
|
|
@ -129,9 +147,14 @@ func (s *ServerStatistics) BackoffIfRequired(backingOff atomic.Bool, interrupt <
|
||||||
logrus.WithError(err).Errorf("Failed to add %q to blacklist", s.serverName)
|
logrus.WithError(err).Errorf("Failed to add %q to blacklist", s.serverName)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return duration, true
|
return 0, true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Work out when we should wait until.
|
||||||
|
duration := s.duration(count)
|
||||||
|
until := time.Now().Add(duration)
|
||||||
|
s.backoffUntil.Store(until)
|
||||||
|
|
||||||
// Notify the destination queue that we're backing off now.
|
// Notify the destination queue that we're backing off now.
|
||||||
backingOff.Store(true)
|
backingOff.Store(true)
|
||||||
defer backingOff.Store(false)
|
defer backingOff.Store(false)
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ import (
|
||||||
|
|
||||||
func TestBackoff(t *testing.T) {
|
func TestBackoff(t *testing.T) {
|
||||||
stats := Statistics{
|
stats := Statistics{
|
||||||
FailuresUntilBlacklist: 5,
|
FailuresUntilBlacklist: 7,
|
||||||
}
|
}
|
||||||
server := ServerStatistics{
|
server := ServerStatistics{
|
||||||
statistics: &stats,
|
statistics: &stats,
|
||||||
|
|
@ -44,10 +44,20 @@ func TestBackoff(t *testing.T) {
|
||||||
// Get the duration.
|
// Get the duration.
|
||||||
duration, blacklist := server.BackoffIfRequired(backingOff, interrupt)
|
duration, blacklist := server.BackoffIfRequired(backingOff, interrupt)
|
||||||
|
|
||||||
|
// Register another failure for good measure. This should have no
|
||||||
|
// side effects since a backoff is already in progress. If it does
|
||||||
|
// then we'll fail.
|
||||||
|
until, blacklisted := server.Failure()
|
||||||
|
if time.Until(until) > duration {
|
||||||
|
t.Fatal("Failure produced unexpected side effect when it shouldn't have")
|
||||||
|
}
|
||||||
|
|
||||||
// Check if we should be blacklisted by now.
|
// Check if we should be blacklisted by now.
|
||||||
if i > stats.FailuresUntilBlacklist {
|
if i >= stats.FailuresUntilBlacklist {
|
||||||
if !blacklist {
|
if !blacklist {
|
||||||
t.Fatalf("Backoff %d should have resulted in blacklist but didn't", i)
|
t.Fatalf("Backoff %d should have resulted in blacklist but didn't", i)
|
||||||
|
} else if blacklist != blacklisted {
|
||||||
|
t.Fatalf("BackoffIfRequired and Failure returned different blacklist values")
|
||||||
} else {
|
} else {
|
||||||
t.Logf("Backoff %d is blacklisted as expected", i)
|
t.Logf("Backoff %d is blacklisted as expected", i)
|
||||||
continue
|
continue
|
||||||
|
|
|
||||||
|
|
@ -42,7 +42,6 @@ const deleteBlacklistSQL = "" +
|
||||||
|
|
||||||
type blacklistStatements struct {
|
type blacklistStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer *sqlutil.TransactionWriter
|
|
||||||
insertBlacklistStmt *sql.Stmt
|
insertBlacklistStmt *sql.Stmt
|
||||||
selectBlacklistStmt *sql.Stmt
|
selectBlacklistStmt *sql.Stmt
|
||||||
deleteBlacklistStmt *sql.Stmt
|
deleteBlacklistStmt *sql.Stmt
|
||||||
|
|
@ -50,8 +49,7 @@ type blacklistStatements struct {
|
||||||
|
|
||||||
func NewPostgresBlacklistTable(db *sql.DB) (s *blacklistStatements, err error) {
|
func NewPostgresBlacklistTable(db *sql.DB) (s *blacklistStatements, err error) {
|
||||||
s = &blacklistStatements{
|
s = &blacklistStatements{
|
||||||
db: db,
|
db: db,
|
||||||
writer: sqlutil.NewTransactionWriter(),
|
|
||||||
}
|
}
|
||||||
_, err = db.Exec(blacklistSchema)
|
_, err = db.Exec(blacklistSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -75,11 +73,9 @@ func NewPostgresBlacklistTable(db *sql.DB) (s *blacklistStatements, err error) {
|
||||||
func (s *blacklistStatements) InsertBlacklist(
|
func (s *blacklistStatements) InsertBlacklist(
|
||||||
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
|
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
|
||||||
) error {
|
) error {
|
||||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
stmt := sqlutil.TxStmt(txn, s.insertBlacklistStmt)
|
||||||
stmt := sqlutil.TxStmt(txn, s.insertBlacklistStmt)
|
_, err := stmt.ExecContext(ctx, serverName)
|
||||||
_, err := stmt.ExecContext(ctx, serverName)
|
return err
|
||||||
return err
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// selectRoomForUpdate locks the row for the room and returns the last_event_id.
|
// selectRoomForUpdate locks the row for the room and returns the last_event_id.
|
||||||
|
|
@ -105,9 +101,7 @@ func (s *blacklistStatements) SelectBlacklist(
|
||||||
func (s *blacklistStatements) DeleteBlacklist(
|
func (s *blacklistStatements) DeleteBlacklist(
|
||||||
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
|
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
|
||||||
) error {
|
) error {
|
||||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
stmt := sqlutil.TxStmt(txn, s.deleteBlacklistStmt)
|
||||||
stmt := sqlutil.TxStmt(txn, s.deleteBlacklistStmt)
|
_, err := stmt.ExecContext(ctx, serverName)
|
||||||
_, err := stmt.ExecContext(ctx, serverName)
|
return err
|
||||||
return err
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -42,7 +42,7 @@ const deleteBlacklistSQL = "" +
|
||||||
|
|
||||||
type blacklistStatements struct {
|
type blacklistStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer *sqlutil.TransactionWriter
|
writer sqlutil.TransactionWriter
|
||||||
insertBlacklistStmt *sql.Stmt
|
insertBlacklistStmt *sql.Stmt
|
||||||
selectBlacklistStmt *sql.Stmt
|
selectBlacklistStmt *sql.Stmt
|
||||||
deleteBlacklistStmt *sql.Stmt
|
deleteBlacklistStmt *sql.Stmt
|
||||||
|
|
|
||||||
|
|
@ -65,7 +65,7 @@ const selectJoinedHostsForRoomsSQL = "" +
|
||||||
|
|
||||||
type joinedHostsStatements struct {
|
type joinedHostsStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer *sqlutil.TransactionWriter
|
writer sqlutil.TransactionWriter
|
||||||
insertJoinedHostsStmt *sql.Stmt
|
insertJoinedHostsStmt *sql.Stmt
|
||||||
deleteJoinedHostsStmt *sql.Stmt
|
deleteJoinedHostsStmt *sql.Stmt
|
||||||
selectJoinedHostsStmt *sql.Stmt
|
selectJoinedHostsStmt *sql.Stmt
|
||||||
|
|
|
||||||
|
|
@ -64,7 +64,7 @@ const selectQueueServerNamesSQL = "" +
|
||||||
|
|
||||||
type queueEDUsStatements struct {
|
type queueEDUsStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer *sqlutil.TransactionWriter
|
writer sqlutil.TransactionWriter
|
||||||
insertQueueEDUStmt *sql.Stmt
|
insertQueueEDUStmt *sql.Stmt
|
||||||
selectQueueEDUStmt *sql.Stmt
|
selectQueueEDUStmt *sql.Stmt
|
||||||
selectQueueEDUReferenceJSONCountStmt *sql.Stmt
|
selectQueueEDUReferenceJSONCountStmt *sql.Stmt
|
||||||
|
|
|
||||||
|
|
@ -50,7 +50,7 @@ const selectJSONSQL = "" +
|
||||||
|
|
||||||
type queueJSONStatements struct {
|
type queueJSONStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer *sqlutil.TransactionWriter
|
writer sqlutil.TransactionWriter
|
||||||
insertJSONStmt *sql.Stmt
|
insertJSONStmt *sql.Stmt
|
||||||
//deleteJSONStmt *sql.Stmt - prepared at runtime due to variadic
|
//deleteJSONStmt *sql.Stmt - prepared at runtime due to variadic
|
||||||
//selectJSONStmt *sql.Stmt - prepared at runtime due to variadic
|
//selectJSONStmt *sql.Stmt - prepared at runtime due to variadic
|
||||||
|
|
|
||||||
|
|
@ -71,7 +71,7 @@ const selectQueuePDUsServerNamesSQL = "" +
|
||||||
|
|
||||||
type queuePDUsStatements struct {
|
type queuePDUsStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer *sqlutil.TransactionWriter
|
writer sqlutil.TransactionWriter
|
||||||
insertQueuePDUStmt *sql.Stmt
|
insertQueuePDUStmt *sql.Stmt
|
||||||
selectQueueNextTransactionIDStmt *sql.Stmt
|
selectQueueNextTransactionIDStmt *sql.Stmt
|
||||||
selectQueuePDUsByTransactionStmt *sql.Stmt
|
selectQueuePDUsByTransactionStmt *sql.Stmt
|
||||||
|
|
|
||||||
|
|
@ -44,7 +44,7 @@ const updateRoomSQL = "" +
|
||||||
|
|
||||||
type roomStatements struct {
|
type roomStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer *sqlutil.TransactionWriter
|
writer sqlutil.TransactionWriter
|
||||||
insertRoomStmt *sql.Stmt
|
insertRoomStmt *sql.Stmt
|
||||||
selectRoomForUpdateStmt *sql.Stmt
|
selectRoomForUpdateStmt *sql.Stmt
|
||||||
updateRoomStmt *sql.Stmt
|
updateRoomStmt *sql.Stmt
|
||||||
|
|
|
||||||
|
|
@ -19,8 +19,6 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
|
||||||
"go.uber.org/atomic"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// ErrUserExists is returned if a username already exists in the database.
|
// ErrUserExists is returned if a username already exists in the database.
|
||||||
|
|
@ -52,7 +50,7 @@ func EndTransaction(txn Transaction, succeeded *bool) error {
|
||||||
func WithTransaction(db *sql.DB, fn func(txn *sql.Tx) error) (err error) {
|
func WithTransaction(db *sql.DB, fn func(txn *sql.Tx) error) (err error) {
|
||||||
txn, err := db.Begin()
|
txn, err := db.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return fmt.Errorf("sqlutil.WithTransaction.Begin: %w", err)
|
||||||
}
|
}
|
||||||
succeeded := false
|
succeeded := false
|
||||||
defer func() {
|
defer func() {
|
||||||
|
|
@ -106,69 +104,6 @@ func SQLiteDriverName() string {
|
||||||
return "sqlite3"
|
return "sqlite3"
|
||||||
}
|
}
|
||||||
|
|
||||||
// TransactionWriter allows queuing database writes so that you don't
|
type TransactionWriter interface {
|
||||||
// contend on database locks in, e.g. SQLite. Only one task will run
|
Do(db *sql.DB, txn *sql.Tx, f func(txn *sql.Tx) error) error
|
||||||
// at a time on a given TransactionWriter.
|
|
||||||
type TransactionWriter struct {
|
|
||||||
running atomic.Bool
|
|
||||||
todo chan transactionWriterTask
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewTransactionWriter() *TransactionWriter {
|
|
||||||
return &TransactionWriter{
|
|
||||||
todo: make(chan transactionWriterTask),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// transactionWriterTask represents a specific task.
|
|
||||||
type transactionWriterTask struct {
|
|
||||||
db *sql.DB
|
|
||||||
txn *sql.Tx
|
|
||||||
f func(txn *sql.Tx) error
|
|
||||||
wait chan error
|
|
||||||
}
|
|
||||||
|
|
||||||
// Do queues a task to be run by a TransactionWriter. The function
|
|
||||||
// provided will be ran within a transaction as supplied by the
|
|
||||||
// txn parameter if one is supplied, and if not, will take out a
|
|
||||||
// new transaction from the database supplied in the database
|
|
||||||
// parameter. Either way, this will block until the task is done.
|
|
||||||
func (w *TransactionWriter) Do(db *sql.DB, txn *sql.Tx, f func(txn *sql.Tx) error) error {
|
|
||||||
if w.todo == nil {
|
|
||||||
return errors.New("not initialised")
|
|
||||||
}
|
|
||||||
if !w.running.Load() {
|
|
||||||
go w.run()
|
|
||||||
}
|
|
||||||
task := transactionWriterTask{
|
|
||||||
db: db,
|
|
||||||
txn: txn,
|
|
||||||
f: f,
|
|
||||||
wait: make(chan error, 1),
|
|
||||||
}
|
|
||||||
w.todo <- task
|
|
||||||
return <-task.wait
|
|
||||||
}
|
|
||||||
|
|
||||||
// run processes the tasks for a given transaction writer. Only one
|
|
||||||
// of these goroutines will run at a time. A transaction will be
|
|
||||||
// opened using the database object from the task and then this will
|
|
||||||
// be passed as a parameter to the task function.
|
|
||||||
func (w *TransactionWriter) run() {
|
|
||||||
if !w.running.CAS(false, true) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer w.running.Store(false)
|
|
||||||
for task := range w.todo {
|
|
||||||
if task.txn != nil {
|
|
||||||
task.wait <- task.f(task.txn)
|
|
||||||
} else if task.db != nil {
|
|
||||||
task.wait <- WithTransaction(task.db, func(txn *sql.Tx) error {
|
|
||||||
return task.f(txn)
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
panic("expected database or transaction but got neither")
|
|
||||||
}
|
|
||||||
close(task.wait)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
22
internal/sqlutil/writer_dummy.go
Normal file
22
internal/sqlutil/writer_dummy.go
Normal file
|
|
@ -0,0 +1,22 @@
|
||||||
|
package sqlutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
)
|
||||||
|
|
||||||
|
type DummyTransactionWriter struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewDummyTransactionWriter() TransactionWriter {
|
||||||
|
return &DummyTransactionWriter{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *DummyTransactionWriter) Do(db *sql.DB, txn *sql.Tx, f func(txn *sql.Tx) error) error {
|
||||||
|
if txn == nil {
|
||||||
|
return WithTransaction(db, func(txn *sql.Tx) error {
|
||||||
|
return f(txn)
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
return f(txn)
|
||||||
|
}
|
||||||
|
}
|
||||||
75
internal/sqlutil/writer_exclusive.go
Normal file
75
internal/sqlutil/writer_exclusive.go
Normal file
|
|
@ -0,0 +1,75 @@
|
||||||
|
package sqlutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"go.uber.org/atomic"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ExclusiveTransactionWriter allows queuing database writes so that you don't
|
||||||
|
// contend on database locks in, e.g. SQLite. Only one task will run
|
||||||
|
// at a time on a given ExclusiveTransactionWriter.
|
||||||
|
type ExclusiveTransactionWriter struct {
|
||||||
|
running atomic.Bool
|
||||||
|
todo chan transactionWriterTask
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTransactionWriter() TransactionWriter {
|
||||||
|
return &ExclusiveTransactionWriter{
|
||||||
|
todo: make(chan transactionWriterTask),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// transactionWriterTask represents a specific task.
|
||||||
|
type transactionWriterTask struct {
|
||||||
|
db *sql.DB
|
||||||
|
txn *sql.Tx
|
||||||
|
f func(txn *sql.Tx) error
|
||||||
|
wait chan error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do queues a task to be run by a TransactionWriter. The function
|
||||||
|
// provided will be ran within a transaction as supplied by the
|
||||||
|
// txn parameter if one is supplied, and if not, will take out a
|
||||||
|
// new transaction from the database supplied in the database
|
||||||
|
// parameter. Either way, this will block until the task is done.
|
||||||
|
func (w *ExclusiveTransactionWriter) Do(db *sql.DB, txn *sql.Tx, f func(txn *sql.Tx) error) error {
|
||||||
|
if w.todo == nil {
|
||||||
|
return errors.New("not initialised")
|
||||||
|
}
|
||||||
|
if !w.running.Load() {
|
||||||
|
go w.run()
|
||||||
|
}
|
||||||
|
task := transactionWriterTask{
|
||||||
|
db: db,
|
||||||
|
txn: txn,
|
||||||
|
f: f,
|
||||||
|
wait: make(chan error, 1),
|
||||||
|
}
|
||||||
|
w.todo <- task
|
||||||
|
return <-task.wait
|
||||||
|
}
|
||||||
|
|
||||||
|
// run processes the tasks for a given transaction writer. Only one
|
||||||
|
// of these goroutines will run at a time. A transaction will be
|
||||||
|
// opened using the database object from the task and then this will
|
||||||
|
// be passed as a parameter to the task function.
|
||||||
|
func (w *ExclusiveTransactionWriter) run() {
|
||||||
|
if !w.running.CAS(false, true) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer w.running.Store(false)
|
||||||
|
for task := range w.todo {
|
||||||
|
if task.txn != nil {
|
||||||
|
task.wait <- task.f(task.txn)
|
||||||
|
} else if task.db != nil {
|
||||||
|
task.wait <- WithTransaction(task.db, func(txn *sql.Tx) error {
|
||||||
|
return task.f(txn)
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
panic("expected database or transaction but got neither")
|
||||||
|
}
|
||||||
|
close(task.wait)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -63,7 +63,7 @@ const deleteAllDeviceKeysSQL = "" +
|
||||||
|
|
||||||
type deviceKeysStatements struct {
|
type deviceKeysStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer *sqlutil.TransactionWriter
|
writer sqlutil.TransactionWriter
|
||||||
upsertDeviceKeysStmt *sql.Stmt
|
upsertDeviceKeysStmt *sql.Stmt
|
||||||
selectDeviceKeysStmt *sql.Stmt
|
selectDeviceKeysStmt *sql.Stmt
|
||||||
selectBatchDeviceKeysStmt *sql.Stmt
|
selectBatchDeviceKeysStmt *sql.Stmt
|
||||||
|
|
|
||||||
|
|
@ -52,7 +52,7 @@ const selectKeyChangesSQL = "" +
|
||||||
|
|
||||||
type keyChangesStatements struct {
|
type keyChangesStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer *sqlutil.TransactionWriter
|
writer sqlutil.TransactionWriter
|
||||||
upsertKeyChangeStmt *sql.Stmt
|
upsertKeyChangeStmt *sql.Stmt
|
||||||
selectKeyChangesStmt *sql.Stmt
|
selectKeyChangesStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -60,7 +60,7 @@ const selectKeyByAlgorithmSQL = "" +
|
||||||
|
|
||||||
type oneTimeKeysStatements struct {
|
type oneTimeKeysStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer *sqlutil.TransactionWriter
|
writer sqlutil.TransactionWriter
|
||||||
upsertKeysStmt *sql.Stmt
|
upsertKeysStmt *sql.Stmt
|
||||||
selectKeysStmt *sql.Stmt
|
selectKeysStmt *sql.Stmt
|
||||||
selectKeysCountStmt *sql.Stmt
|
selectKeysCountStmt *sql.Stmt
|
||||||
|
|
|
||||||
|
|
@ -62,7 +62,7 @@ SELECT content_type, file_size_bytes, creation_ts, upload_name, base64hash, user
|
||||||
|
|
||||||
type mediaStatements struct {
|
type mediaStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer *sqlutil.TransactionWriter
|
writer sqlutil.TransactionWriter
|
||||||
insertMediaStmt *sql.Stmt
|
insertMediaStmt *sql.Stmt
|
||||||
selectMediaStmt *sql.Stmt
|
selectMediaStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -24,6 +24,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/roomserver/api"
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
"github.com/matrix-org/dendrite/roomserver/state"
|
"github.com/matrix-org/dendrite/roomserver/state"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
|
|
@ -56,7 +57,7 @@ func (r *RoomserverInternalAPI) updateLatestEvents(
|
||||||
) (err error) {
|
) (err error) {
|
||||||
updater, err := r.DB.GetLatestEventsForUpdate(ctx, roomNID)
|
updater, err := r.DB.GetLatestEventsForUpdate(ctx, roomNID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return fmt.Errorf("r.DB.GetLatestEventsForUpdate: %w", err)
|
||||||
}
|
}
|
||||||
succeeded := false
|
succeeded := false
|
||||||
defer func() {
|
defer func() {
|
||||||
|
|
@ -78,7 +79,7 @@ func (r *RoomserverInternalAPI) updateLatestEvents(
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = u.doUpdateLatestEvents(); err != nil {
|
if err = u.doUpdateLatestEvents(); err != nil {
|
||||||
return err
|
return fmt.Errorf("u.doUpdateLatestEvents: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
succeeded = true
|
succeeded = true
|
||||||
|
|
@ -92,7 +93,7 @@ func (r *RoomserverInternalAPI) updateLatestEvents(
|
||||||
type latestEventsUpdater struct {
|
type latestEventsUpdater struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
api *RoomserverInternalAPI
|
api *RoomserverInternalAPI
|
||||||
updater types.RoomRecentEventsUpdater
|
updater *shared.LatestEventsUpdater
|
||||||
roomNID types.RoomNID
|
roomNID types.RoomNID
|
||||||
stateAtEvent types.StateAtEvent
|
stateAtEvent types.StateAtEvent
|
||||||
event gomatrixserverlib.Event
|
event gomatrixserverlib.Event
|
||||||
|
|
@ -136,7 +137,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error {
|
||||||
// don't need to do anything, as we've handled it already.
|
// don't need to do anything, as we've handled it already.
|
||||||
hasBeenSent, err := u.updater.HasEventBeenSent(u.stateAtEvent.EventNID)
|
hasBeenSent, err := u.updater.HasEventBeenSent(u.stateAtEvent.EventNID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("u.updater.HasEventBeenSent: %w", err)
|
||||||
} else if hasBeenSent {
|
} else if hasBeenSent {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
@ -144,7 +145,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error {
|
||||||
// Update the roomserver_previous_events table with references. This
|
// Update the roomserver_previous_events table with references. This
|
||||||
// is effectively tracking the structure of the DAG.
|
// is effectively tracking the structure of the DAG.
|
||||||
if err = u.updater.StorePreviousEvents(u.stateAtEvent.EventNID, prevEvents); err != nil {
|
if err = u.updater.StorePreviousEvents(u.stateAtEvent.EventNID, prevEvents); err != nil {
|
||||||
return err
|
return fmt.Errorf("u.updater.StorePreviousEvents: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the event reference for our new event. This will be used when
|
// Get the event reference for our new event. This will be used when
|
||||||
|
|
@ -155,7 +156,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error {
|
||||||
// in the room. If it is then it isn't a latest event.
|
// in the room. If it is then it isn't a latest event.
|
||||||
alreadyReferenced, err := u.updater.IsReferenced(eventReference)
|
alreadyReferenced, err := u.updater.IsReferenced(eventReference)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("u.updater.IsReferenced: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Work out what the latest events are.
|
// Work out what the latest events are.
|
||||||
|
|
@ -172,19 +173,19 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error {
|
||||||
// Now that we know what the latest events are, it's time to get the
|
// Now that we know what the latest events are, it's time to get the
|
||||||
// latest state.
|
// latest state.
|
||||||
if err = u.latestState(); err != nil {
|
if err = u.latestState(); err != nil {
|
||||||
return err
|
return fmt.Errorf("u.latestState: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we need to generate any output events then here's where we do it.
|
// If we need to generate any output events then here's where we do it.
|
||||||
// TODO: Move this!
|
// TODO: Move this!
|
||||||
updates, err := u.api.updateMemberships(u.ctx, u.updater, u.removed, u.added)
|
updates, err := u.api.updateMemberships(u.ctx, u.updater, u.removed, u.added)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("u.api.updateMemberships: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
update, err := u.makeOutputNewRoomEvent()
|
update, err := u.makeOutputNewRoomEvent()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("u.makeOutputNewRoomEvent: %w", err)
|
||||||
}
|
}
|
||||||
updates = append(updates, *update)
|
updates = append(updates, *update)
|
||||||
|
|
||||||
|
|
@ -197,14 +198,18 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error {
|
||||||
// the correct order, 2) that pending writes are resent across restarts. In order to avoid writing all the
|
// the correct order, 2) that pending writes are resent across restarts. In order to avoid writing all the
|
||||||
// necessary bookkeeping we'll keep the event sending synchronous for now.
|
// necessary bookkeeping we'll keep the event sending synchronous for now.
|
||||||
if err = u.api.WriteOutputEvents(u.event.RoomID(), updates); err != nil {
|
if err = u.api.WriteOutputEvents(u.event.RoomID(), updates); err != nil {
|
||||||
return err
|
return fmt.Errorf("u.api.WriteOutputEvents: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = u.updater.SetLatestEvents(u.roomNID, u.latest, u.stateAtEvent.EventNID, u.newStateNID); err != nil {
|
if err = u.updater.SetLatestEvents(u.roomNID, u.latest, u.stateAtEvent.EventNID, u.newStateNID); err != nil {
|
||||||
return err
|
return fmt.Errorf("u.updater.SetLatestEvents: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return u.updater.MarkEventAsSent(u.stateAtEvent.EventNID)
|
if err = u.updater.MarkEventAsSent(u.stateAtEvent.EventNID); err != nil {
|
||||||
|
return fmt.Errorf("u.updater.MarkEventAsSent: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *latestEventsUpdater) latestState() error {
|
func (u *latestEventsUpdater) latestState() error {
|
||||||
|
|
@ -224,7 +229,7 @@ func (u *latestEventsUpdater) latestState() error {
|
||||||
u.ctx, u.roomNID, latestStateAtEvents,
|
u.ctx, u.roomNID, latestStateAtEvents,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("roomState.CalculateAndStoreStateAfterEvents: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we are overwriting the state then we should make sure that we
|
// If we are overwriting the state then we should make sure that we
|
||||||
|
|
@ -243,7 +248,7 @@ func (u *latestEventsUpdater) latestState() error {
|
||||||
u.ctx, u.oldStateNID, u.newStateNID,
|
u.ctx, u.oldStateNID, u.newStateNID,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("roomState.DifferenceBetweenStateSnapshots: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Also work out the state before the event removes and the event
|
// Also work out the state before the event removes and the event
|
||||||
|
|
@ -251,7 +256,11 @@ func (u *latestEventsUpdater) latestState() error {
|
||||||
u.stateBeforeEventRemoves, u.stateBeforeEventAdds, err = roomState.DifferenceBetweeenStateSnapshots(
|
u.stateBeforeEventRemoves, u.stateBeforeEventAdds, err = roomState.DifferenceBetweeenStateSnapshots(
|
||||||
u.ctx, u.newStateNID, u.stateAtEvent.BeforeStateSnapshotNID,
|
u.ctx, u.newStateNID, u.stateAtEvent.BeforeStateSnapshotNID,
|
||||||
)
|
)
|
||||||
return err
|
if err != nil {
|
||||||
|
return fmt.Errorf("roomState.DifferenceBetweeenStateSnapshots: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func calculateLatest(
|
func calculateLatest(
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/api"
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
@ -29,7 +30,7 @@ import (
|
||||||
// consumers about the invites added or retired by the change in current state.
|
// consumers about the invites added or retired by the change in current state.
|
||||||
func (r *RoomserverInternalAPI) updateMemberships(
|
func (r *RoomserverInternalAPI) updateMemberships(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
updater types.RoomRecentEventsUpdater,
|
updater *shared.LatestEventsUpdater,
|
||||||
removed, added []types.StateEntry,
|
removed, added []types.StateEntry,
|
||||||
) ([]api.OutputEvent, error) {
|
) ([]api.OutputEvent, error) {
|
||||||
changes := membershipChanges(removed, added)
|
changes := membershipChanges(removed, added)
|
||||||
|
|
@ -77,7 +78,7 @@ func (r *RoomserverInternalAPI) updateMemberships(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *RoomserverInternalAPI) updateMembership(
|
func (r *RoomserverInternalAPI) updateMembership(
|
||||||
updater types.RoomRecentEventsUpdater,
|
updater *shared.LatestEventsUpdater,
|
||||||
targetUserNID types.EventStateKeyNID,
|
targetUserNID types.EventStateKeyNID,
|
||||||
remove, add *gomatrixserverlib.Event,
|
remove, add *gomatrixserverlib.Event,
|
||||||
updates []api.OutputEvent,
|
updates []api.OutputEvent,
|
||||||
|
|
@ -141,7 +142,7 @@ func (r *RoomserverInternalAPI) isLocalTarget(event *gomatrixserverlib.Event) bo
|
||||||
}
|
}
|
||||||
|
|
||||||
func updateToInviteMembership(
|
func updateToInviteMembership(
|
||||||
mu types.MembershipUpdater, add *gomatrixserverlib.Event, updates []api.OutputEvent,
|
mu *shared.MembershipUpdater, add *gomatrixserverlib.Event, updates []api.OutputEvent,
|
||||||
roomVersion gomatrixserverlib.RoomVersion,
|
roomVersion gomatrixserverlib.RoomVersion,
|
||||||
) ([]api.OutputEvent, error) {
|
) ([]api.OutputEvent, error) {
|
||||||
// We may have already sent the invite to the user, either because we are
|
// We may have already sent the invite to the user, either because we are
|
||||||
|
|
@ -171,7 +172,7 @@ func updateToInviteMembership(
|
||||||
}
|
}
|
||||||
|
|
||||||
func updateToJoinMembership(
|
func updateToJoinMembership(
|
||||||
mu types.MembershipUpdater, add *gomatrixserverlib.Event, updates []api.OutputEvent,
|
mu *shared.MembershipUpdater, add *gomatrixserverlib.Event, updates []api.OutputEvent,
|
||||||
) ([]api.OutputEvent, error) {
|
) ([]api.OutputEvent, error) {
|
||||||
// If the user is already marked as being joined, we call SetToJoin to update
|
// If the user is already marked as being joined, we call SetToJoin to update
|
||||||
// the event ID then we can return immediately. Retired is ignored as there
|
// the event ID then we can return immediately. Retired is ignored as there
|
||||||
|
|
@ -207,7 +208,7 @@ func updateToJoinMembership(
|
||||||
}
|
}
|
||||||
|
|
||||||
func updateToLeaveMembership(
|
func updateToLeaveMembership(
|
||||||
mu types.MembershipUpdater, add *gomatrixserverlib.Event,
|
mu *shared.MembershipUpdater, add *gomatrixserverlib.Event,
|
||||||
newMembership string, updates []api.OutputEvent,
|
newMembership string, updates []api.OutputEvent,
|
||||||
) ([]api.OutputEvent, error) {
|
) ([]api.OutputEvent, error) {
|
||||||
// If the user is already neither joined, nor invited to the room then we
|
// If the user is already neither joined, nor invited to the room then we
|
||||||
|
|
|
||||||
|
|
@ -558,7 +558,11 @@ func (v StateResolution) CalculateAndStoreStateAfterEvents(
|
||||||
// 2) There weren't any prev_events for this event so the state is
|
// 2) There weren't any prev_events for this event so the state is
|
||||||
// empty.
|
// empty.
|
||||||
metrics.algorithm = "empty_state"
|
metrics.algorithm = "empty_state"
|
||||||
return metrics.stop(v.db.AddState(ctx, roomNID, nil, nil))
|
stateNID, err := v.db.AddState(ctx, roomNID, nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("v.db.AddState: %w", err)
|
||||||
|
}
|
||||||
|
return metrics.stop(stateNID, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(prevStates) == 1 {
|
if len(prevStates) == 1 {
|
||||||
|
|
@ -578,22 +582,30 @@ func (v StateResolution) CalculateAndStoreStateAfterEvents(
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
metrics.algorithm = "_load_state_blocks"
|
metrics.algorithm = "_load_state_blocks"
|
||||||
return metrics.stop(0, err)
|
return metrics.stop(0, fmt.Errorf("v.db.StateBlockNIDs: %w", err))
|
||||||
}
|
}
|
||||||
stateBlockNIDs := stateBlockNIDLists[0].StateBlockNIDs
|
stateBlockNIDs := stateBlockNIDLists[0].StateBlockNIDs
|
||||||
if len(stateBlockNIDs) < maxStateBlockNIDs {
|
if len(stateBlockNIDs) < maxStateBlockNIDs {
|
||||||
// 4) The number of state data blocks is small enough that we can just
|
// 4) The number of state data blocks is small enough that we can just
|
||||||
// add the state event as a block of size one to the end of the blocks.
|
// add the state event as a block of size one to the end of the blocks.
|
||||||
metrics.algorithm = "single_delta"
|
metrics.algorithm = "single_delta"
|
||||||
return metrics.stop(v.db.AddState(
|
stateNID, err := v.db.AddState(
|
||||||
ctx, roomNID, stateBlockNIDs, []types.StateEntry{prevState.StateEntry},
|
ctx, roomNID, stateBlockNIDs, []types.StateEntry{prevState.StateEntry},
|
||||||
))
|
)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("v.db.AddState: %w", err)
|
||||||
|
}
|
||||||
|
return metrics.stop(stateNID, err)
|
||||||
}
|
}
|
||||||
// If there are too many deltas then we need to calculate the full state
|
// If there are too many deltas then we need to calculate the full state
|
||||||
// So fall through to calculateAndStoreStateAfterManyEvents
|
// So fall through to calculateAndStoreStateAfterManyEvents
|
||||||
}
|
}
|
||||||
|
|
||||||
return v.calculateAndStoreStateAfterManyEvents(ctx, roomNID, prevStates, metrics)
|
stateNID, err := v.calculateAndStoreStateAfterManyEvents(ctx, roomNID, prevStates, metrics)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("v.calculateAndStoreStateAfterManyEvents: %w", err)
|
||||||
|
}
|
||||||
|
return stateNID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// maxStateBlockNIDs is the maximum number of state data blocks to use to encode a snapshot of room state.
|
// maxStateBlockNIDs is the maximum number of state data blocks to use to encode a snapshot of room state.
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/api"
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
@ -86,7 +87,7 @@ type Database interface {
|
||||||
// The RoomRecentEventsUpdater must have Commit or Rollback called on it if this doesn't return an error.
|
// The RoomRecentEventsUpdater must have Commit or Rollback called on it if this doesn't return an error.
|
||||||
// Returns the latest events in the room and the last eventID sent to the log along with an updater.
|
// Returns the latest events in the room and the last eventID sent to the log along with an updater.
|
||||||
// If this returns an error then no further action is required.
|
// If this returns an error then no further action is required.
|
||||||
GetLatestEventsForUpdate(ctx context.Context, roomNID types.RoomNID) (types.RoomRecentEventsUpdater, error)
|
GetLatestEventsForUpdate(ctx context.Context, roomNID types.RoomNID) (*shared.LatestEventsUpdater, error)
|
||||||
// Look up event ID by transaction's info.
|
// Look up event ID by transaction's info.
|
||||||
// This is used to determine if the room event is processed/processing already.
|
// This is used to determine if the room event is processed/processing already.
|
||||||
// Returns an empty string if no such event exists.
|
// Returns an empty string if no such event exists.
|
||||||
|
|
@ -123,7 +124,7 @@ type Database interface {
|
||||||
// Returns an error if there was a problem talking to the database.
|
// Returns an error if there was a problem talking to the database.
|
||||||
RemoveRoomAlias(ctx context.Context, alias string) error
|
RemoveRoomAlias(ctx context.Context, alias string) error
|
||||||
// Build a membership updater for the target user in a room.
|
// Build a membership updater for the target user in a room.
|
||||||
MembershipUpdater(ctx context.Context, roomID, targetUserID string, targetLocal bool, roomVersion gomatrixserverlib.RoomVersion) (types.MembershipUpdater, error)
|
MembershipUpdater(ctx context.Context, roomID, targetUserID string, targetLocal bool, roomVersion gomatrixserverlib.RoomVersion) (*shared.MembershipUpdater, error)
|
||||||
// Lookup the membership of a given user in a given room.
|
// Lookup the membership of a given user in a given room.
|
||||||
// Returns the numeric ID of the latest membership event sent from this user
|
// Returns the numeric ID of the latest membership event sent from this user
|
||||||
// in this room, along a boolean set to true if the user is still in this room,
|
// in this room, along a boolean set to true if the user is still in this room,
|
||||||
|
|
|
||||||
|
|
@ -98,6 +98,7 @@ func Open(dbProperties *config.DatabaseOptions) (*Database, error) {
|
||||||
}
|
}
|
||||||
d.Database = shared.Database{
|
d.Database = shared.Database{
|
||||||
DB: db,
|
DB: db,
|
||||||
|
Writer: sqlutil.NewDummyTransactionWriter(),
|
||||||
EventTypesTable: eventTypes,
|
EventTypesTable: eventTypes,
|
||||||
EventStateKeysTable: eventStateKeys,
|
EventStateKeysTable: eventStateKeys,
|
||||||
EventJSONTable: eventJSON,
|
EventJSONTable: eventJSON,
|
||||||
|
|
|
||||||
|
|
@ -3,12 +3,13 @@ package shared
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
type roomRecentEventsUpdater struct {
|
type LatestEventsUpdater struct {
|
||||||
transaction
|
transaction
|
||||||
d *Database
|
d *Database
|
||||||
roomNID types.RoomNID
|
roomNID types.RoomNID
|
||||||
|
|
@ -17,11 +18,7 @@ type roomRecentEventsUpdater struct {
|
||||||
currentStateSnapshotNID types.StateSnapshotNID
|
currentStateSnapshotNID types.StateSnapshotNID
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewRoomRecentEventsUpdater(d *Database, ctx context.Context, roomNID types.RoomNID, useTxns bool) (types.RoomRecentEventsUpdater, error) {
|
func NewLatestEventsUpdater(ctx context.Context, d *Database, txn *sql.Tx, roomNID types.RoomNID) (*LatestEventsUpdater, error) {
|
||||||
txn, err := d.DB.Begin()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err :=
|
eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err :=
|
||||||
d.RoomsTable.SelectLatestEventsNIDsForUpdate(ctx, txn, roomNID)
|
d.RoomsTable.SelectLatestEventsNIDsForUpdate(ctx, txn, roomNID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -41,48 +38,46 @@ func NewRoomRecentEventsUpdater(d *Database, ctx context.Context, roomNID types.
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if !useTxns {
|
return &LatestEventsUpdater{
|
||||||
txn.Commit() // nolint: errcheck
|
|
||||||
txn = nil
|
|
||||||
}
|
|
||||||
return &roomRecentEventsUpdater{
|
|
||||||
transaction{ctx, txn}, d, roomNID, stateAndRefs, lastEventIDSent, currentStateSnapshotNID,
|
transaction{ctx, txn}, d, roomNID, stateAndRefs, lastEventIDSent, currentStateSnapshotNID,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// RoomVersion implements types.RoomRecentEventsUpdater
|
// RoomVersion implements types.RoomRecentEventsUpdater
|
||||||
func (u *roomRecentEventsUpdater) RoomVersion() (version gomatrixserverlib.RoomVersion) {
|
func (u *LatestEventsUpdater) RoomVersion() (version gomatrixserverlib.RoomVersion) {
|
||||||
version, _ = u.d.GetRoomVersionForRoomNID(u.ctx, u.roomNID)
|
version, _ = u.d.GetRoomVersionForRoomNID(u.ctx, u.roomNID)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// LatestEvents implements types.RoomRecentEventsUpdater
|
// LatestEvents implements types.RoomRecentEventsUpdater
|
||||||
func (u *roomRecentEventsUpdater) LatestEvents() []types.StateAtEventAndReference {
|
func (u *LatestEventsUpdater) LatestEvents() []types.StateAtEventAndReference {
|
||||||
return u.latestEvents
|
return u.latestEvents
|
||||||
}
|
}
|
||||||
|
|
||||||
// LastEventIDSent implements types.RoomRecentEventsUpdater
|
// LastEventIDSent implements types.RoomRecentEventsUpdater
|
||||||
func (u *roomRecentEventsUpdater) LastEventIDSent() string {
|
func (u *LatestEventsUpdater) LastEventIDSent() string {
|
||||||
return u.lastEventIDSent
|
return u.lastEventIDSent
|
||||||
}
|
}
|
||||||
|
|
||||||
// CurrentStateSnapshotNID implements types.RoomRecentEventsUpdater
|
// CurrentStateSnapshotNID implements types.RoomRecentEventsUpdater
|
||||||
func (u *roomRecentEventsUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID {
|
func (u *LatestEventsUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID {
|
||||||
return u.currentStateSnapshotNID
|
return u.currentStateSnapshotNID
|
||||||
}
|
}
|
||||||
|
|
||||||
// StorePreviousEvents implements types.RoomRecentEventsUpdater
|
// StorePreviousEvents implements types.RoomRecentEventsUpdater
|
||||||
func (u *roomRecentEventsUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error {
|
func (u *LatestEventsUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error {
|
||||||
for _, ref := range previousEventReferences {
|
return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
|
||||||
if err := u.d.PrevEventsTable.InsertPreviousEvent(u.ctx, u.txn, ref.EventID, ref.EventSHA256, eventNID); err != nil {
|
for _, ref := range previousEventReferences {
|
||||||
return err
|
if err := u.d.PrevEventsTable.InsertPreviousEvent(u.ctx, txn, ref.EventID, ref.EventSHA256, eventNID); err != nil {
|
||||||
|
return fmt.Errorf("u.d.PrevEventsTable.InsertPreviousEvent: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
return nil
|
||||||
return nil
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsReferenced implements types.RoomRecentEventsUpdater
|
// IsReferenced implements types.RoomRecentEventsUpdater
|
||||||
func (u *roomRecentEventsUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) {
|
func (u *LatestEventsUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) {
|
||||||
err := u.d.PrevEventsTable.SelectPreviousEventExists(u.ctx, u.txn, eventReference.EventID, eventReference.EventSHA256)
|
err := u.d.PrevEventsTable.SelectPreviousEventExists(u.ctx, u.txn, eventReference.EventID, eventReference.EventSHA256)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return true, nil
|
return true, nil
|
||||||
|
|
@ -90,11 +85,11 @@ func (u *roomRecentEventsUpdater) IsReferenced(eventReference gomatrixserverlib.
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
return false, err
|
return false, fmt.Errorf("u.d.PrevEventsTable.SelectPreviousEventExists: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetLatestEvents implements types.RoomRecentEventsUpdater
|
// SetLatestEvents implements types.RoomRecentEventsUpdater
|
||||||
func (u *roomRecentEventsUpdater) SetLatestEvents(
|
func (u *LatestEventsUpdater) SetLatestEvents(
|
||||||
roomNID types.RoomNID, latest []types.StateAtEventAndReference, lastEventNIDSent types.EventNID,
|
roomNID types.RoomNID, latest []types.StateAtEventAndReference, lastEventNIDSent types.EventNID,
|
||||||
currentStateSnapshotNID types.StateSnapshotNID,
|
currentStateSnapshotNID types.StateSnapshotNID,
|
||||||
) error {
|
) error {
|
||||||
|
|
@ -102,19 +97,26 @@ func (u *roomRecentEventsUpdater) SetLatestEvents(
|
||||||
for i := range latest {
|
for i := range latest {
|
||||||
eventNIDs[i] = latest[i].EventNID
|
eventNIDs[i] = latest[i].EventNID
|
||||||
}
|
}
|
||||||
return u.d.RoomsTable.UpdateLatestEventNIDs(u.ctx, u.txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID)
|
return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
|
||||||
|
if err := u.d.RoomsTable.UpdateLatestEventNIDs(u.ctx, txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID); err != nil {
|
||||||
|
return fmt.Errorf("u.d.RoomsTable.updateLatestEventNIDs: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// HasEventBeenSent implements types.RoomRecentEventsUpdater
|
// HasEventBeenSent implements types.RoomRecentEventsUpdater
|
||||||
func (u *roomRecentEventsUpdater) HasEventBeenSent(eventNID types.EventNID) (bool, error) {
|
func (u *LatestEventsUpdater) HasEventBeenSent(eventNID types.EventNID) (bool, error) {
|
||||||
return u.d.EventsTable.SelectEventSentToOutput(u.ctx, u.txn, eventNID)
|
return u.d.EventsTable.SelectEventSentToOutput(u.ctx, u.txn, eventNID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarkEventAsSent implements types.RoomRecentEventsUpdater
|
// MarkEventAsSent implements types.RoomRecentEventsUpdater
|
||||||
func (u *roomRecentEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error {
|
func (u *LatestEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error {
|
||||||
return u.d.EventsTable.UpdateEventSentToOutput(u.ctx, u.txn, eventNID)
|
return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
|
||||||
|
return u.d.EventsTable.UpdateEventSentToOutput(u.ctx, txn, eventNID)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (types.MembershipUpdater, error) {
|
func (u *LatestEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (*MembershipUpdater, error) {
|
||||||
return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomNID, targetUserNID, targetLocal)
|
return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomNID, targetUserNID, targetLocal)
|
||||||
}
|
}
|
||||||
|
|
@ -3,13 +3,14 @@ package shared
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
type membershipUpdater struct {
|
type MembershipUpdater struct {
|
||||||
transaction
|
transaction
|
||||||
d *Database
|
d *Database
|
||||||
roomNID types.RoomNID
|
roomNID types.RoomNID
|
||||||
|
|
@ -18,21 +19,9 @@ type membershipUpdater struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewMembershipUpdater(
|
func NewMembershipUpdater(
|
||||||
ctx context.Context, d *Database, roomID, targetUserID string,
|
ctx context.Context, d *Database, txn *sql.Tx, roomID, targetUserID string,
|
||||||
targetLocal bool, roomVersion gomatrixserverlib.RoomVersion,
|
targetLocal bool, roomVersion gomatrixserverlib.RoomVersion,
|
||||||
useTxns bool,
|
) (*MembershipUpdater, error) {
|
||||||
) (types.MembershipUpdater, error) {
|
|
||||||
txn, err := d.DB.Begin()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
succeeded := false
|
|
||||||
defer func() {
|
|
||||||
if !succeeded {
|
|
||||||
txn.Rollback() // nolint: errcheck
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
roomNID, err := d.assignRoomNID(ctx, txn, roomID, roomVersion)
|
roomNID, err := d.assignRoomNID(ctx, txn, roomID, roomVersion)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
@ -43,17 +32,7 @@ func NewMembershipUpdater(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
updater, err := d.membershipUpdaterTxn(ctx, txn, roomNID, targetUserNID, targetLocal)
|
return d.membershipUpdaterTxn(ctx, txn, roomNID, targetUserNID, targetLocal)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
succeeded = true
|
|
||||||
if !useTxns {
|
|
||||||
txn.Commit() // nolint: errcheck
|
|
||||||
updater.transaction.txn = nil
|
|
||||||
}
|
|
||||||
return updater, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) membershipUpdaterTxn(
|
func (d *Database) membershipUpdaterTxn(
|
||||||
|
|
@ -62,10 +41,15 @@ func (d *Database) membershipUpdaterTxn(
|
||||||
roomNID types.RoomNID,
|
roomNID types.RoomNID,
|
||||||
targetUserNID types.EventStateKeyNID,
|
targetUserNID types.EventStateKeyNID,
|
||||||
targetLocal bool,
|
targetLocal bool,
|
||||||
) (*membershipUpdater, error) {
|
) (*MembershipUpdater, error) {
|
||||||
|
err := d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error {
|
||||||
if err := d.MembershipTable.InsertMembership(ctx, txn, roomNID, targetUserNID, targetLocal); err != nil {
|
if err := d.MembershipTable.InsertMembership(ctx, txn, roomNID, targetUserNID, targetLocal); err != nil {
|
||||||
return nil, err
|
return fmt.Errorf("d.MembershipTable.InsertMembership: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("u.d.Writer.Do: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
membership, err := d.MembershipTable.SelectMembershipForUpdate(ctx, txn, roomNID, targetUserNID)
|
membership, err := d.MembershipTable.SelectMembershipForUpdate(ctx, txn, roomNID, targetUserNID)
|
||||||
|
|
@ -73,55 +57,55 @@ func (d *Database) membershipUpdaterTxn(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &membershipUpdater{
|
return &MembershipUpdater{
|
||||||
transaction{ctx, txn}, d, roomNID, targetUserNID, membership,
|
transaction{ctx, txn}, d, roomNID, targetUserNID, membership,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsInvite implements types.MembershipUpdater
|
// IsInvite implements types.MembershipUpdater
|
||||||
func (u *membershipUpdater) IsInvite() bool {
|
func (u *MembershipUpdater) IsInvite() bool {
|
||||||
return u.membership == tables.MembershipStateInvite
|
return u.membership == tables.MembershipStateInvite
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsJoin implements types.MembershipUpdater
|
// IsJoin implements types.MembershipUpdater
|
||||||
func (u *membershipUpdater) IsJoin() bool {
|
func (u *MembershipUpdater) IsJoin() bool {
|
||||||
return u.membership == tables.MembershipStateJoin
|
return u.membership == tables.MembershipStateJoin
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsLeave implements types.MembershipUpdater
|
// IsLeave implements types.MembershipUpdater
|
||||||
func (u *membershipUpdater) IsLeave() bool {
|
func (u *MembershipUpdater) IsLeave() bool {
|
||||||
return u.membership == tables.MembershipStateLeaveOrBan
|
return u.membership == tables.MembershipStateLeaveOrBan
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetToInvite implements types.MembershipUpdater
|
// SetToInvite implements types.MembershipUpdater
|
||||||
func (u *membershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, error) {
|
func (u *MembershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, error) {
|
||||||
senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, event.Sender())
|
senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, event.Sender())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, fmt.Errorf("u.d.AssignStateKeyNID: %w", err)
|
||||||
}
|
}
|
||||||
inserted, err := u.d.InvitesTable.InsertInviteEvent(
|
inserted, err := u.d.InvitesTable.InsertInviteEvent(
|
||||||
u.ctx, u.txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(),
|
u.ctx, u.txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, fmt.Errorf("u.d.InvitesTable.InsertInviteEvent: %w", err)
|
||||||
}
|
}
|
||||||
if u.membership != tables.MembershipStateInvite {
|
if u.membership != tables.MembershipStateInvite {
|
||||||
if err = u.d.MembershipTable.UpdateMembership(
|
if err = u.d.MembershipTable.UpdateMembership(
|
||||||
u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateInvite, 0,
|
u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateInvite, 0,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return false, err
|
return false, fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return inserted, nil
|
return inserted, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetToJoin implements types.MembershipUpdater
|
// SetToJoin implements types.MembershipUpdater
|
||||||
func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpdate bool) ([]string, error) {
|
func (u *MembershipUpdater) SetToJoin(senderUserID string, eventID string, isUpdate bool) ([]string, error) {
|
||||||
var inviteEventIDs []string
|
var inviteEventIDs []string
|
||||||
|
|
||||||
senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID)
|
senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("u.d.AssignStateKeyNID: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// If this is a join event update, there is no invite to update
|
// If this is a join event update, there is no invite to update
|
||||||
|
|
@ -130,14 +114,14 @@ func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd
|
||||||
u.ctx, u.txn, u.roomNID, u.targetUserNID,
|
u.ctx, u.txn, u.roomNID, u.targetUserNID,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("u.d.InvitesTables.UpdateInviteRetired: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Look up the NID of the new join event
|
// Look up the NID of the new join event
|
||||||
nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID})
|
nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("u.d.EventNIDs: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if u.membership != tables.MembershipStateJoin || isUpdate {
|
if u.membership != tables.MembershipStateJoin || isUpdate {
|
||||||
|
|
@ -145,7 +129,7 @@ func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd
|
||||||
u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID,
|
u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID,
|
||||||
tables.MembershipStateJoin, nIDs[eventID],
|
tables.MembershipStateJoin, nIDs[eventID],
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -153,22 +137,22 @@ func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetToLeave implements types.MembershipUpdater
|
// SetToLeave implements types.MembershipUpdater
|
||||||
func (u *membershipUpdater) SetToLeave(senderUserID string, eventID string) ([]string, error) {
|
func (u *MembershipUpdater) SetToLeave(senderUserID string, eventID string) ([]string, error) {
|
||||||
senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID)
|
senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("u.d.AssignStateKeyNID: %w", err)
|
||||||
}
|
}
|
||||||
inviteEventIDs, err := u.d.InvitesTable.UpdateInviteRetired(
|
inviteEventIDs, err := u.d.InvitesTable.UpdateInviteRetired(
|
||||||
u.ctx, u.txn, u.roomNID, u.targetUserNID,
|
u.ctx, u.txn, u.roomNID, u.targetUserNID,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("u.d.InvitesTable.updateInviteRetired: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Look up the NID of the new leave event
|
// Look up the NID of the new leave event
|
||||||
nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID})
|
nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("u.d.EventNIDs: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if u.membership != tables.MembershipStateLeaveOrBan {
|
if u.membership != tables.MembershipStateLeaveOrBan {
|
||||||
|
|
@ -176,7 +160,7 @@ func (u *membershipUpdater) SetToLeave(senderUserID string, eventID string) ([]s
|
||||||
u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID,
|
u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID,
|
||||||
tables.MembershipStateLeaveOrBan, nIDs[eventID],
|
tables.MembershipStateLeaveOrBan, nIDs[eventID],
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return inviteEventIDs, nil
|
return inviteEventIDs, nil
|
||||||
|
|
|
||||||
|
|
@ -27,6 +27,7 @@ const redactionsArePermanent = false
|
||||||
|
|
||||||
type Database struct {
|
type Database struct {
|
||||||
DB *sql.DB
|
DB *sql.DB
|
||||||
|
Writer sqlutil.TransactionWriter
|
||||||
EventsTable tables.Events
|
EventsTable tables.Events
|
||||||
EventJSONTable tables.EventJSON
|
EventJSONTable tables.EventJSON
|
||||||
EventTypesTable tables.EventTypes
|
EventTypesTable tables.EventTypes
|
||||||
|
|
@ -83,20 +84,23 @@ func (d *Database) AddState(
|
||||||
stateBlockNIDs []types.StateBlockNID,
|
stateBlockNIDs []types.StateBlockNID,
|
||||||
state []types.StateEntry,
|
state []types.StateEntry,
|
||||||
) (stateNID types.StateSnapshotNID, err error) {
|
) (stateNID types.StateSnapshotNID, err error) {
|
||||||
err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error {
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
if len(state) > 0 {
|
if len(state) > 0 {
|
||||||
var stateBlockNID types.StateBlockNID
|
var stateBlockNID types.StateBlockNID
|
||||||
stateBlockNID, err = d.StateBlockTable.BulkInsertStateData(ctx, txn, state)
|
stateBlockNID, err = d.StateBlockTable.BulkInsertStateData(ctx, txn, state)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("d.StateBlockTable.BulkInsertStateData: %w", err)
|
||||||
}
|
}
|
||||||
stateBlockNIDs = append(stateBlockNIDs[:len(stateBlockNIDs):len(stateBlockNIDs)], stateBlockNID)
|
stateBlockNIDs = append(stateBlockNIDs[:len(stateBlockNIDs):len(stateBlockNIDs)], stateBlockNID)
|
||||||
}
|
}
|
||||||
stateNID, err = d.StateSnapshotTable.InsertState(ctx, txn, roomNID, stateBlockNIDs)
|
stateNID, err = d.StateSnapshotTable.InsertState(ctx, txn, roomNID, stateBlockNIDs)
|
||||||
return err
|
if err != nil {
|
||||||
|
return fmt.Errorf("d.StateSnapshotTable.InsertState: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, fmt.Errorf("d.Writer.Do: %w", err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
@ -110,7 +114,9 @@ func (d *Database) EventNIDs(
|
||||||
func (d *Database) SetState(
|
func (d *Database) SetState(
|
||||||
ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID,
|
ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID,
|
||||||
) error {
|
) error {
|
||||||
return d.EventsTable.UpdateEventState(ctx, eventNID, stateNID)
|
return d.Writer.Do(d.DB, nil, func(_ *sql.Tx) error {
|
||||||
|
return d.EventsTable.UpdateEventState(ctx, eventNID, stateNID)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) StateAtEventIDs(
|
func (d *Database) StateAtEventIDs(
|
||||||
|
|
@ -173,22 +179,19 @@ func (d *Database) RoomNIDExcludingStubs(ctx context.Context, roomID string) (ro
|
||||||
func (d *Database) LatestEventIDs(
|
func (d *Database) LatestEventIDs(
|
||||||
ctx context.Context, roomNID types.RoomNID,
|
ctx context.Context, roomNID types.RoomNID,
|
||||||
) (references []gomatrixserverlib.EventReference, currentStateSnapshotNID types.StateSnapshotNID, depth int64, err error) {
|
) (references []gomatrixserverlib.EventReference, currentStateSnapshotNID types.StateSnapshotNID, depth int64, err error) {
|
||||||
err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error {
|
var eventNIDs []types.EventNID
|
||||||
var eventNIDs []types.EventNID
|
eventNIDs, currentStateSnapshotNID, err = d.RoomsTable.SelectLatestEventNIDs(ctx, nil, roomNID)
|
||||||
eventNIDs, currentStateSnapshotNID, err = d.RoomsTable.SelectLatestEventNIDs(ctx, txn, roomNID)
|
if err != nil {
|
||||||
if err != nil {
|
return
|
||||||
return err
|
}
|
||||||
}
|
references, err = d.EventsTable.BulkSelectEventReference(ctx, nil, eventNIDs)
|
||||||
references, err = d.EventsTable.BulkSelectEventReference(ctx, txn, eventNIDs)
|
if err != nil {
|
||||||
if err != nil {
|
return
|
||||||
return err
|
}
|
||||||
}
|
depth, err = d.EventsTable.SelectMaxEventDepth(ctx, nil, eventNIDs)
|
||||||
depth, err = d.EventsTable.SelectMaxEventDepth(ctx, txn, eventNIDs)
|
if err != nil {
|
||||||
if err != nil {
|
return
|
||||||
return err
|
}
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -221,7 +224,9 @@ func (d *Database) GetRoomVersionForRoomNID(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error {
|
func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error {
|
||||||
return d.RoomAliasesTable.InsertRoomAlias(ctx, alias, roomID, creatorUserID)
|
return d.Writer.Do(d.DB, nil, func(_ *sql.Tx) error {
|
||||||
|
return d.RoomAliasesTable.InsertRoomAlias(ctx, alias, roomID, creatorUserID)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) GetRoomIDForAlias(ctx context.Context, alias string) (string, error) {
|
func (d *Database) GetRoomIDForAlias(ctx context.Context, alias string) (string, error) {
|
||||||
|
|
@ -239,15 +244,21 @@ func (d *Database) GetCreatorIDForAlias(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error {
|
func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error {
|
||||||
return d.RoomAliasesTable.DeleteRoomAlias(ctx, alias)
|
return d.Writer.Do(d.DB, nil, func(_ *sql.Tx) error {
|
||||||
|
return d.RoomAliasesTable.DeleteRoomAlias(ctx, alias)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) GetMembership(
|
func (d *Database) GetMembership(
|
||||||
ctx context.Context, roomNID types.RoomNID, requestSenderUserID string,
|
ctx context.Context, roomNID types.RoomNID, requestSenderUserID string,
|
||||||
) (membershipEventNID types.EventNID, stillInRoom bool, err error) {
|
) (membershipEventNID types.EventNID, stillInRoom bool, err error) {
|
||||||
requestSenderUserNID, err := d.assignStateKeyNID(ctx, nil, requestSenderUserID)
|
var requestSenderUserNID types.EventStateKeyNID
|
||||||
|
err = d.Writer.Do(d.DB, nil, func(_ *sql.Tx) error {
|
||||||
|
requestSenderUserNID, err = d.assignStateKeyNID(ctx, nil, requestSenderUserID)
|
||||||
|
return err
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return 0, false, fmt.Errorf("d.assignStateKeyNID: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
senderMembershipEventNID, senderMembership, err :=
|
senderMembershipEventNID, senderMembership, err :=
|
||||||
|
|
@ -332,16 +343,35 @@ func (d *Database) GetTransactionEventID(
|
||||||
func (d *Database) MembershipUpdater(
|
func (d *Database) MembershipUpdater(
|
||||||
ctx context.Context, roomID, targetUserID string,
|
ctx context.Context, roomID, targetUserID string,
|
||||||
targetLocal bool, roomVersion gomatrixserverlib.RoomVersion,
|
targetLocal bool, roomVersion gomatrixserverlib.RoomVersion,
|
||||||
) (types.MembershipUpdater, error) {
|
) (*MembershipUpdater, error) {
|
||||||
return NewMembershipUpdater(ctx, d, roomID, targetUserID, targetLocal, roomVersion, true)
|
txn, err := d.DB.Begin()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var updater *MembershipUpdater
|
||||||
|
_ = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error {
|
||||||
|
updater, err = NewMembershipUpdater(ctx, d, txn, roomID, targetUserID, targetLocal, roomVersion)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
return updater, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) GetLatestEventsForUpdate(
|
func (d *Database) GetLatestEventsForUpdate(
|
||||||
ctx context.Context, roomNID types.RoomNID,
|
ctx context.Context, roomNID types.RoomNID,
|
||||||
) (types.RoomRecentEventsUpdater, error) {
|
) (*LatestEventsUpdater, error) {
|
||||||
return NewRoomRecentEventsUpdater(d, ctx, roomNID, true)
|
txn, err := d.DB.Begin()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var updater *LatestEventsUpdater
|
||||||
|
_ = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error {
|
||||||
|
updater, err = NewLatestEventsUpdater(ctx, d, txn, roomNID)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
return updater, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// nolint:gocyclo
|
||||||
func (d *Database) StoreEvent(
|
func (d *Database) StoreEvent(
|
||||||
ctx context.Context, event gomatrixserverlib.Event,
|
ctx context.Context, event gomatrixserverlib.Event,
|
||||||
txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID,
|
txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID,
|
||||||
|
|
@ -357,7 +387,7 @@ func (d *Database) StoreEvent(
|
||||||
err error
|
err error
|
||||||
)
|
)
|
||||||
|
|
||||||
err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error {
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
if txnAndSessionID != nil {
|
if txnAndSessionID != nil {
|
||||||
if err = d.TransactionsTable.InsertTransaction(
|
if err = d.TransactionsTable.InsertTransaction(
|
||||||
ctx, txn, txnAndSessionID.TransactionID,
|
ctx, txn, txnAndSessionID.TransactionID,
|
||||||
|
|
@ -425,7 +455,7 @@ func (d *Database) StoreEvent(
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, types.StateAtEvent{}, nil, "", err
|
return 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.Writer.Do: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return roomNID, types.StateAtEvent{
|
return roomNID, types.StateAtEvent{
|
||||||
|
|
@ -441,7 +471,9 @@ func (d *Database) StoreEvent(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) PublishRoom(ctx context.Context, roomID string, publish bool) error {
|
func (d *Database) PublishRoom(ctx context.Context, roomID string, publish bool) error {
|
||||||
return d.PublishedTable.UpsertRoomPublished(ctx, roomID, publish)
|
return d.Writer.Do(d.DB, nil, func(_ *sql.Tx) error {
|
||||||
|
return d.PublishedTable.UpsertRoomPublished(ctx, roomID, publish)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) GetPublishedRooms(ctx context.Context) ([]string, error) {
|
func (d *Database) GetPublishedRooms(ctx context.Context) ([]string, error) {
|
||||||
|
|
|
||||||
|
|
@ -49,15 +49,13 @@ const bulkSelectEventJSONSQL = `
|
||||||
|
|
||||||
type eventJSONStatements struct {
|
type eventJSONStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer *sqlutil.TransactionWriter
|
|
||||||
insertEventJSONStmt *sql.Stmt
|
insertEventJSONStmt *sql.Stmt
|
||||||
bulkSelectEventJSONStmt *sql.Stmt
|
bulkSelectEventJSONStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSqliteEventJSONTable(db *sql.DB) (tables.EventJSON, error) {
|
func NewSqliteEventJSONTable(db *sql.DB) (tables.EventJSON, error) {
|
||||||
s := &eventJSONStatements{
|
s := &eventJSONStatements{
|
||||||
db: db,
|
db: db,
|
||||||
writer: sqlutil.NewTransactionWriter(),
|
|
||||||
}
|
}
|
||||||
_, err := db.Exec(eventJSONSchema)
|
_, err := db.Exec(eventJSONSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -72,10 +70,8 @@ func NewSqliteEventJSONTable(db *sql.DB) (tables.EventJSON, error) {
|
||||||
func (s *eventJSONStatements) InsertEventJSON(
|
func (s *eventJSONStatements) InsertEventJSON(
|
||||||
ctx context.Context, txn *sql.Tx, eventNID types.EventNID, eventJSON []byte,
|
ctx context.Context, txn *sql.Tx, eventNID types.EventNID, eventJSON []byte,
|
||||||
) error {
|
) error {
|
||||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
_, err := sqlutil.TxStmt(txn, s.insertEventJSONStmt).ExecContext(ctx, int64(eventNID), eventJSON)
|
||||||
_, err := sqlutil.TxStmt(txn, s.insertEventJSONStmt).ExecContext(ctx, int64(eventNID), eventJSON)
|
return err
|
||||||
return err
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *eventJSONStatements) BulkSelectEventJSON(
|
func (s *eventJSONStatements) BulkSelectEventJSON(
|
||||||
|
|
|
||||||
|
|
@ -64,7 +64,6 @@ const bulkSelectEventStateKeyNIDSQL = `
|
||||||
|
|
||||||
type eventStateKeyStatements struct {
|
type eventStateKeyStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer *sqlutil.TransactionWriter
|
|
||||||
insertEventStateKeyNIDStmt *sql.Stmt
|
insertEventStateKeyNIDStmt *sql.Stmt
|
||||||
selectEventStateKeyNIDStmt *sql.Stmt
|
selectEventStateKeyNIDStmt *sql.Stmt
|
||||||
bulkSelectEventStateKeyNIDStmt *sql.Stmt
|
bulkSelectEventStateKeyNIDStmt *sql.Stmt
|
||||||
|
|
@ -73,8 +72,7 @@ type eventStateKeyStatements struct {
|
||||||
|
|
||||||
func NewSqliteEventStateKeysTable(db *sql.DB) (tables.EventStateKeys, error) {
|
func NewSqliteEventStateKeysTable(db *sql.DB) (tables.EventStateKeys, error) {
|
||||||
s := &eventStateKeyStatements{
|
s := &eventStateKeyStatements{
|
||||||
db: db,
|
db: db,
|
||||||
writer: sqlutil.NewTransactionWriter(),
|
|
||||||
}
|
}
|
||||||
_, err := db.Exec(eventStateKeysSchema)
|
_, err := db.Exec(eventStateKeysSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -91,19 +89,15 @@ func NewSqliteEventStateKeysTable(db *sql.DB) (tables.EventStateKeys, error) {
|
||||||
func (s *eventStateKeyStatements) InsertEventStateKeyNID(
|
func (s *eventStateKeyStatements) InsertEventStateKeyNID(
|
||||||
ctx context.Context, txn *sql.Tx, eventStateKey string,
|
ctx context.Context, txn *sql.Tx, eventStateKey string,
|
||||||
) (types.EventStateKeyNID, error) {
|
) (types.EventStateKeyNID, error) {
|
||||||
var eventStateKeyNID int64
|
insertStmt := sqlutil.TxStmt(txn, s.insertEventStateKeyNIDStmt)
|
||||||
err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
res, err := insertStmt.ExecContext(ctx, eventStateKey)
|
||||||
insertStmt := sqlutil.TxStmt(txn, s.insertEventStateKeyNIDStmt)
|
if err != nil {
|
||||||
res, err := insertStmt.ExecContext(ctx, eventStateKey)
|
return 0, err
|
||||||
if err != nil {
|
}
|
||||||
return err
|
eventStateKeyNID, err := res.LastInsertId()
|
||||||
}
|
if err != nil {
|
||||||
eventStateKeyNID, err = res.LastInsertId()
|
return 0, err
|
||||||
if err != nil {
|
}
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
return types.EventStateKeyNID(eventStateKeyNID), err
|
return types.EventStateKeyNID(eventStateKeyNID), err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,7 @@ package sqlite3
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
|
|
@ -78,7 +79,6 @@ const bulkSelectEventTypeNIDSQL = `
|
||||||
|
|
||||||
type eventTypeStatements struct {
|
type eventTypeStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer *sqlutil.TransactionWriter
|
|
||||||
insertEventTypeNIDStmt *sql.Stmt
|
insertEventTypeNIDStmt *sql.Stmt
|
||||||
insertEventTypeNIDResultStmt *sql.Stmt
|
insertEventTypeNIDResultStmt *sql.Stmt
|
||||||
selectEventTypeNIDStmt *sql.Stmt
|
selectEventTypeNIDStmt *sql.Stmt
|
||||||
|
|
@ -87,8 +87,7 @@ type eventTypeStatements struct {
|
||||||
|
|
||||||
func NewSqliteEventTypesTable(db *sql.DB) (tables.EventTypes, error) {
|
func NewSqliteEventTypesTable(db *sql.DB) (tables.EventTypes, error) {
|
||||||
s := &eventTypeStatements{
|
s := &eventTypeStatements{
|
||||||
db: db,
|
db: db,
|
||||||
writer: sqlutil.NewTransactionWriter(),
|
|
||||||
}
|
}
|
||||||
_, err := db.Exec(eventTypesSchema)
|
_, err := db.Exec(eventTypesSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -104,18 +103,18 @@ func NewSqliteEventTypesTable(db *sql.DB) (tables.EventTypes, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *eventTypeStatements) InsertEventTypeNID(
|
func (s *eventTypeStatements) InsertEventTypeNID(
|
||||||
ctx context.Context, tx *sql.Tx, eventType string,
|
ctx context.Context, txn *sql.Tx, eventType string,
|
||||||
) (types.EventTypeNID, error) {
|
) (types.EventTypeNID, error) {
|
||||||
var eventTypeNID int64
|
var eventTypeNID int64
|
||||||
err := s.writer.Do(s.db, tx, func(tx *sql.Tx) error {
|
insertStmt := sqlutil.TxStmt(txn, s.insertEventTypeNIDStmt)
|
||||||
insertStmt := sqlutil.TxStmt(tx, s.insertEventTypeNIDStmt)
|
resultStmt := sqlutil.TxStmt(txn, s.insertEventTypeNIDResultStmt)
|
||||||
resultStmt := sqlutil.TxStmt(tx, s.insertEventTypeNIDResultStmt)
|
_, err := insertStmt.ExecContext(ctx, eventType)
|
||||||
_, err := insertStmt.ExecContext(ctx, eventType)
|
if err != nil {
|
||||||
if err != nil {
|
return 0, fmt.Errorf("insertStmt.ExecContext: %w", err)
|
||||||
return err
|
}
|
||||||
}
|
if err = resultStmt.QueryRowContext(ctx).Scan(&eventTypeNID); err != nil {
|
||||||
return resultStmt.QueryRowContext(ctx).Scan(&eventTypeNID)
|
return 0, fmt.Errorf("resultStmt.QueryRowContext.Scan: %w", err)
|
||||||
})
|
}
|
||||||
return types.EventTypeNID(eventTypeNID), err
|
return types.EventTypeNID(eventTypeNID), err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -99,7 +99,6 @@ const selectRoomNIDForEventNIDSQL = "" +
|
||||||
|
|
||||||
type eventStatements struct {
|
type eventStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer *sqlutil.TransactionWriter
|
|
||||||
insertEventStmt *sql.Stmt
|
insertEventStmt *sql.Stmt
|
||||||
selectEventStmt *sql.Stmt
|
selectEventStmt *sql.Stmt
|
||||||
bulkSelectStateEventByIDStmt *sql.Stmt
|
bulkSelectStateEventByIDStmt *sql.Stmt
|
||||||
|
|
@ -117,8 +116,7 @@ type eventStatements struct {
|
||||||
|
|
||||||
func NewSqliteEventsTable(db *sql.DB) (tables.Events, error) {
|
func NewSqliteEventsTable(db *sql.DB) (tables.Events, error) {
|
||||||
s := &eventStatements{
|
s := &eventStatements{
|
||||||
db: db,
|
db: db,
|
||||||
writer: sqlutil.NewTransactionWriter(),
|
|
||||||
}
|
}
|
||||||
_, err := db.Exec(eventsSchema)
|
_, err := db.Exec(eventsSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -155,22 +153,19 @@ func (s *eventStatements) InsertEvent(
|
||||||
) (types.EventNID, types.StateSnapshotNID, error) {
|
) (types.EventNID, types.StateSnapshotNID, error) {
|
||||||
// attempt to insert: the last_row_id is the event NID
|
// attempt to insert: the last_row_id is the event NID
|
||||||
var eventNID int64
|
var eventNID int64
|
||||||
err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt)
|
||||||
insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt)
|
result, err := insertStmt.ExecContext(
|
||||||
result, err := insertStmt.ExecContext(
|
ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID),
|
||||||
ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID),
|
eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth,
|
||||||
eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth,
|
)
|
||||||
)
|
if err != nil {
|
||||||
if err != nil {
|
return 0, 0, err
|
||||||
return err
|
}
|
||||||
}
|
modified, err := result.RowsAffected()
|
||||||
modified, err := result.RowsAffected()
|
if modified == 0 && err == nil {
|
||||||
if modified == 0 && err == nil {
|
return 0, 0, sql.ErrNoRows
|
||||||
return sql.ErrNoRows
|
}
|
||||||
}
|
eventNID, err = result.LastInsertId()
|
||||||
eventNID, err = result.LastInsertId()
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
return types.EventNID(eventNID), 0, err
|
return types.EventNID(eventNID), 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -286,11 +281,8 @@ func (s *eventStatements) BulkSelectStateAtEventByID(
|
||||||
func (s *eventStatements) UpdateEventState(
|
func (s *eventStatements) UpdateEventState(
|
||||||
ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID,
|
ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID,
|
||||||
) error {
|
) error {
|
||||||
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
|
_, err := s.updateEventStateStmt.ExecContext(ctx, int64(stateNID), int64(eventNID))
|
||||||
stmt := sqlutil.TxStmt(txn, s.updateEventStateStmt)
|
return err
|
||||||
_, err := stmt.ExecContext(ctx, int64(stateNID), int64(eventNID))
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *eventStatements) SelectEventSentToOutput(
|
func (s *eventStatements) SelectEventSentToOutput(
|
||||||
|
|
@ -302,11 +294,9 @@ func (s *eventStatements) SelectEventSentToOutput(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *eventStatements) UpdateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error {
|
func (s *eventStatements) UpdateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error {
|
||||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
updateStmt := sqlutil.TxStmt(txn, s.updateEventSentToOutputStmt)
|
||||||
updateStmt := sqlutil.TxStmt(txn, s.updateEventSentToOutputStmt)
|
_, err := updateStmt.ExecContext(ctx, int64(eventNID))
|
||||||
_, err := updateStmt.ExecContext(ctx, int64(eventNID))
|
return err
|
||||||
return err
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *eventStatements) SelectEventID(
|
func (s *eventStatements) SelectEventID(
|
||||||
|
|
@ -326,12 +316,16 @@ func (s *eventStatements) BulkSelectStateAtEventAndReference(
|
||||||
iEventNIDs[k] = v
|
iEventNIDs[k] = v
|
||||||
}
|
}
|
||||||
selectOrig := strings.Replace(bulkSelectStateAtEventAndReferenceSQL, "($1)", sqlutil.QueryVariadic(len(iEventNIDs)), 1)
|
selectOrig := strings.Replace(bulkSelectStateAtEventAndReferenceSQL, "($1)", sqlutil.QueryVariadic(len(iEventNIDs)), 1)
|
||||||
//////////////
|
selectPrep, err := s.db.Prepare(selectOrig)
|
||||||
|
|
||||||
rows, err := txn.QueryContext(ctx, selectOrig, iEventNIDs...)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
//////////////
|
||||||
|
|
||||||
|
rows, err := sqlutil.TxStmt(txn, selectPrep).QueryContext(ctx, iEventNIDs...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("sqlutil.TxStmt.QueryContext: %w", err)
|
||||||
|
}
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateAtEventAndReference: rows.close() failed")
|
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateAtEventAndReference: rows.close() failed")
|
||||||
results := make([]types.StateAtEventAndReference, len(eventNIDs))
|
results := make([]types.StateAtEventAndReference, len(eventNIDs))
|
||||||
i := 0
|
i := 0
|
||||||
|
|
@ -372,7 +366,7 @@ func (s *eventStatements) BulkSelectEventReference(
|
||||||
iEventNIDs[k] = v
|
iEventNIDs[k] = v
|
||||||
}
|
}
|
||||||
selectOrig := strings.Replace(bulkSelectEventReferenceSQL, "($1)", sqlutil.QueryVariadic(len(iEventNIDs)), 1)
|
selectOrig := strings.Replace(bulkSelectEventReferenceSQL, "($1)", sqlutil.QueryVariadic(len(iEventNIDs)), 1)
|
||||||
selectPrep, err := txn.Prepare(selectOrig)
|
selectPrep, err := s.db.Prepare(selectOrig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
@ -471,10 +465,14 @@ func (s *eventStatements) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx,
|
||||||
iEventIDs[i] = v
|
iEventIDs[i] = v
|
||||||
}
|
}
|
||||||
sqlStr := strings.Replace(selectMaxEventDepthSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1)
|
sqlStr := strings.Replace(selectMaxEventDepthSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1)
|
||||||
err := txn.QueryRowContext(ctx, sqlStr, iEventIDs...).Scan(&result)
|
sqlPrep, err := s.db.Prepare(sqlStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
err = sqlutil.TxStmt(txn, sqlPrep).QueryRowContext(ctx, iEventIDs...).Scan(&result)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("sqlutil.TxStmt.QueryRowContext: %w", err)
|
||||||
|
}
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -64,7 +64,6 @@ SELECT invite_event_id FROM roomserver_invites WHERE room_nid = $1 AND target_ni
|
||||||
|
|
||||||
type inviteStatements struct {
|
type inviteStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer *sqlutil.TransactionWriter
|
|
||||||
insertInviteEventStmt *sql.Stmt
|
insertInviteEventStmt *sql.Stmt
|
||||||
selectInviteActiveForUserInRoomStmt *sql.Stmt
|
selectInviteActiveForUserInRoomStmt *sql.Stmt
|
||||||
updateInviteRetiredStmt *sql.Stmt
|
updateInviteRetiredStmt *sql.Stmt
|
||||||
|
|
@ -73,8 +72,7 @@ type inviteStatements struct {
|
||||||
|
|
||||||
func NewSqliteInvitesTable(db *sql.DB) (tables.Invites, error) {
|
func NewSqliteInvitesTable(db *sql.DB) (tables.Invites, error) {
|
||||||
s := &inviteStatements{
|
s := &inviteStatements{
|
||||||
db: db,
|
db: db,
|
||||||
writer: sqlutil.NewTransactionWriter(),
|
|
||||||
}
|
}
|
||||||
_, err := db.Exec(inviteSchema)
|
_, err := db.Exec(inviteSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -96,20 +94,17 @@ func (s *inviteStatements) InsertInviteEvent(
|
||||||
inviteEventJSON []byte,
|
inviteEventJSON []byte,
|
||||||
) (bool, error) {
|
) (bool, error) {
|
||||||
var count int64
|
var count int64
|
||||||
err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
stmt := sqlutil.TxStmt(txn, s.insertInviteEventStmt)
|
||||||
stmt := sqlutil.TxStmt(txn, s.insertInviteEventStmt)
|
result, err := stmt.ExecContext(
|
||||||
result, err := stmt.ExecContext(
|
ctx, inviteEventID, roomNID, targetUserNID, senderUserNID, inviteEventJSON,
|
||||||
ctx, inviteEventID, roomNID, targetUserNID, senderUserNID, inviteEventJSON,
|
)
|
||||||
)
|
if err != nil {
|
||||||
if err != nil {
|
return false, err
|
||||||
return err
|
}
|
||||||
}
|
count, err = result.RowsAffected()
|
||||||
count, err = result.RowsAffected()
|
if err != nil {
|
||||||
if err != nil {
|
return false, err
|
||||||
return err
|
}
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
return count != 0, err
|
return count != 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -117,26 +112,23 @@ func (s *inviteStatements) UpdateInviteRetired(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||||
) (eventIDs []string, err error) {
|
) (eventIDs []string, err error) {
|
||||||
err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
// gather all the event IDs we will retire
|
||||||
// gather all the event IDs we will retire
|
stmt := sqlutil.TxStmt(txn, s.selectInvitesAboutToRetireStmt)
|
||||||
stmt := sqlutil.TxStmt(txn, s.selectInvitesAboutToRetireStmt)
|
rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID)
|
||||||
rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID)
|
if err != nil {
|
||||||
if err != nil {
|
return
|
||||||
return err
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "UpdateInviteRetired: rows.close() failed")
|
||||||
|
for rows.Next() {
|
||||||
|
var inviteEventID string
|
||||||
|
if err = rows.Scan(&inviteEventID); err != nil {
|
||||||
|
return
|
||||||
}
|
}
|
||||||
defer (func() { err = rows.Close() })()
|
eventIDs = append(eventIDs, inviteEventID)
|
||||||
for rows.Next() {
|
}
|
||||||
var inviteEventID string
|
// now retire the invites
|
||||||
if err = rows.Scan(&inviteEventID); err != nil {
|
stmt = sqlutil.TxStmt(txn, s.updateInviteRetiredStmt)
|
||||||
return err
|
_, err = stmt.ExecContext(ctx, roomNID, targetUserNID)
|
||||||
}
|
|
||||||
eventIDs = append(eventIDs, inviteEventID)
|
|
||||||
}
|
|
||||||
// now retire the invites
|
|
||||||
stmt = sqlutil.TxStmt(txn, s.updateInviteRetiredStmt)
|
|
||||||
_, err = stmt.ExecContext(ctx, roomNID, targetUserNID)
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -77,7 +77,6 @@ const updateMembershipSQL = "" +
|
||||||
|
|
||||||
type membershipStatements struct {
|
type membershipStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer *sqlutil.TransactionWriter
|
|
||||||
insertMembershipStmt *sql.Stmt
|
insertMembershipStmt *sql.Stmt
|
||||||
selectMembershipForUpdateStmt *sql.Stmt
|
selectMembershipForUpdateStmt *sql.Stmt
|
||||||
selectMembershipFromRoomAndTargetStmt *sql.Stmt
|
selectMembershipFromRoomAndTargetStmt *sql.Stmt
|
||||||
|
|
@ -90,8 +89,7 @@ type membershipStatements struct {
|
||||||
|
|
||||||
func NewSqliteMembershipTable(db *sql.DB) (tables.Membership, error) {
|
func NewSqliteMembershipTable(db *sql.DB) (tables.Membership, error) {
|
||||||
s := &membershipStatements{
|
s := &membershipStatements{
|
||||||
db: db,
|
db: db,
|
||||||
writer: sqlutil.NewTransactionWriter(),
|
|
||||||
}
|
}
|
||||||
_, err := db.Exec(membershipSchema)
|
_, err := db.Exec(membershipSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -115,11 +113,9 @@ func (s *membershipStatements) InsertMembership(
|
||||||
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||||
localTarget bool,
|
localTarget bool,
|
||||||
) error {
|
) error {
|
||||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
stmt := sqlutil.TxStmt(txn, s.insertMembershipStmt)
|
||||||
stmt := sqlutil.TxStmt(txn, s.insertMembershipStmt)
|
_, err := stmt.ExecContext(ctx, roomNID, targetUserNID, localTarget)
|
||||||
_, err := stmt.ExecContext(ctx, roomNID, targetUserNID, localTarget)
|
return err
|
||||||
return err
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *membershipStatements) SelectMembershipForUpdate(
|
func (s *membershipStatements) SelectMembershipForUpdate(
|
||||||
|
|
@ -201,11 +197,9 @@ func (s *membershipStatements) UpdateMembership(
|
||||||
senderUserNID types.EventStateKeyNID, membership tables.MembershipState,
|
senderUserNID types.EventStateKeyNID, membership tables.MembershipState,
|
||||||
eventNID types.EventNID,
|
eventNID types.EventNID,
|
||||||
) error {
|
) error {
|
||||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
stmt := sqlutil.TxStmt(txn, s.updateMembershipStmt)
|
||||||
stmt := sqlutil.TxStmt(txn, s.updateMembershipStmt)
|
_, err := stmt.ExecContext(
|
||||||
_, err := stmt.ExecContext(
|
ctx, senderUserNID, membership, eventNID, roomNID, targetUserNID,
|
||||||
ctx, senderUserNID, membership, eventNID, roomNID, targetUserNID,
|
)
|
||||||
)
|
return err
|
||||||
return err
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -54,15 +54,13 @@ const selectPreviousEventExistsSQL = `
|
||||||
|
|
||||||
type previousEventStatements struct {
|
type previousEventStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer *sqlutil.TransactionWriter
|
|
||||||
insertPreviousEventStmt *sql.Stmt
|
insertPreviousEventStmt *sql.Stmt
|
||||||
selectPreviousEventExistsStmt *sql.Stmt
|
selectPreviousEventExistsStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSqlitePrevEventsTable(db *sql.DB) (tables.PreviousEvents, error) {
|
func NewSqlitePrevEventsTable(db *sql.DB) (tables.PreviousEvents, error) {
|
||||||
s := &previousEventStatements{
|
s := &previousEventStatements{
|
||||||
db: db,
|
db: db,
|
||||||
writer: sqlutil.NewTransactionWriter(),
|
|
||||||
}
|
}
|
||||||
_, err := db.Exec(previousEventSchema)
|
_, err := db.Exec(previousEventSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -82,13 +80,11 @@ func (s *previousEventStatements) InsertPreviousEvent(
|
||||||
previousEventReferenceSHA256 []byte,
|
previousEventReferenceSHA256 []byte,
|
||||||
eventNID types.EventNID,
|
eventNID types.EventNID,
|
||||||
) error {
|
) error {
|
||||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
stmt := sqlutil.TxStmt(txn, s.insertPreviousEventStmt)
|
||||||
stmt := sqlutil.TxStmt(txn, s.insertPreviousEventStmt)
|
_, err := stmt.ExecContext(
|
||||||
_, err := stmt.ExecContext(
|
ctx, previousEventID, previousEventReferenceSHA256, int64(eventNID),
|
||||||
ctx, previousEventID, previousEventReferenceSHA256, int64(eventNID),
|
)
|
||||||
)
|
return err
|
||||||
return err
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if the event reference exists
|
// Check if the event reference exists
|
||||||
|
|
|
||||||
|
|
@ -19,7 +19,6 @@ import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||||
)
|
)
|
||||||
|
|
@ -45,7 +44,6 @@ const selectPublishedSQL = "" +
|
||||||
|
|
||||||
type publishedStatements struct {
|
type publishedStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer *sqlutil.TransactionWriter
|
|
||||||
upsertPublishedStmt *sql.Stmt
|
upsertPublishedStmt *sql.Stmt
|
||||||
selectAllPublishedStmt *sql.Stmt
|
selectAllPublishedStmt *sql.Stmt
|
||||||
selectPublishedStmt *sql.Stmt
|
selectPublishedStmt *sql.Stmt
|
||||||
|
|
@ -53,8 +51,7 @@ type publishedStatements struct {
|
||||||
|
|
||||||
func NewSqlitePublishedTable(db *sql.DB) (tables.Published, error) {
|
func NewSqlitePublishedTable(db *sql.DB) (tables.Published, error) {
|
||||||
s := &publishedStatements{
|
s := &publishedStatements{
|
||||||
db: db,
|
db: db,
|
||||||
writer: sqlutil.NewTransactionWriter(),
|
|
||||||
}
|
}
|
||||||
_, err := db.Exec(publishedSchema)
|
_, err := db.Exec(publishedSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -69,12 +66,9 @@ func NewSqlitePublishedTable(db *sql.DB) (tables.Published, error) {
|
||||||
|
|
||||||
func (s *publishedStatements) UpsertRoomPublished(
|
func (s *publishedStatements) UpsertRoomPublished(
|
||||||
ctx context.Context, roomID string, published bool,
|
ctx context.Context, roomID string, published bool,
|
||||||
) (err error) {
|
) error {
|
||||||
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
|
_, err := s.upsertPublishedStmt.ExecContext(ctx, roomID, published)
|
||||||
stmt := sqlutil.TxStmt(txn, s.upsertPublishedStmt)
|
return err
|
||||||
_, err := stmt.ExecContext(ctx, roomID, published)
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *publishedStatements) SelectPublishedFromRoomID(
|
func (s *publishedStatements) SelectPublishedFromRoomID(
|
||||||
|
|
|
||||||
|
|
@ -53,7 +53,6 @@ const markRedactionValidatedSQL = "" +
|
||||||
|
|
||||||
type redactionStatements struct {
|
type redactionStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer *sqlutil.TransactionWriter
|
|
||||||
insertRedactionStmt *sql.Stmt
|
insertRedactionStmt *sql.Stmt
|
||||||
selectRedactionInfoByRedactionEventIDStmt *sql.Stmt
|
selectRedactionInfoByRedactionEventIDStmt *sql.Stmt
|
||||||
selectRedactionInfoByEventBeingRedactedStmt *sql.Stmt
|
selectRedactionInfoByEventBeingRedactedStmt *sql.Stmt
|
||||||
|
|
@ -62,8 +61,7 @@ type redactionStatements struct {
|
||||||
|
|
||||||
func NewSqliteRedactionsTable(db *sql.DB) (tables.Redactions, error) {
|
func NewSqliteRedactionsTable(db *sql.DB) (tables.Redactions, error) {
|
||||||
s := &redactionStatements{
|
s := &redactionStatements{
|
||||||
db: db,
|
db: db,
|
||||||
writer: sqlutil.NewTransactionWriter(),
|
|
||||||
}
|
}
|
||||||
_, err := db.Exec(redactionsSchema)
|
_, err := db.Exec(redactionsSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -81,11 +79,9 @@ func NewSqliteRedactionsTable(db *sql.DB) (tables.Redactions, error) {
|
||||||
func (s *redactionStatements) InsertRedaction(
|
func (s *redactionStatements) InsertRedaction(
|
||||||
ctx context.Context, txn *sql.Tx, info tables.RedactionInfo,
|
ctx context.Context, txn *sql.Tx, info tables.RedactionInfo,
|
||||||
) error {
|
) error {
|
||||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
stmt := sqlutil.TxStmt(txn, s.insertRedactionStmt)
|
||||||
stmt := sqlutil.TxStmt(txn, s.insertRedactionStmt)
|
_, err := stmt.ExecContext(ctx, info.RedactionEventID, info.RedactsEventID, info.Validated)
|
||||||
_, err := stmt.ExecContext(ctx, info.RedactionEventID, info.RedactsEventID, info.Validated)
|
return err
|
||||||
return err
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *redactionStatements) SelectRedactionInfoByRedactionEventID(
|
func (s *redactionStatements) SelectRedactionInfoByRedactionEventID(
|
||||||
|
|
@ -121,9 +117,7 @@ func (s *redactionStatements) SelectRedactionInfoByEventBeingRedacted(
|
||||||
func (s *redactionStatements) MarkRedactionValidated(
|
func (s *redactionStatements) MarkRedactionValidated(
|
||||||
ctx context.Context, txn *sql.Tx, redactionEventID string, validated bool,
|
ctx context.Context, txn *sql.Tx, redactionEventID string, validated bool,
|
||||||
) error {
|
) error {
|
||||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
stmt := sqlutil.TxStmt(txn, s.markRedactionValidatedStmt)
|
||||||
stmt := sqlutil.TxStmt(txn, s.markRedactionValidatedStmt)
|
_, err := stmt.ExecContext(ctx, redactionEventID, validated)
|
||||||
_, err := stmt.ExecContext(ctx, redactionEventID, validated)
|
return err
|
||||||
return err
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -20,7 +20,6 @@ import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||||
)
|
)
|
||||||
|
|
@ -57,7 +56,6 @@ const deleteRoomAliasSQL = `
|
||||||
|
|
||||||
type roomAliasesStatements struct {
|
type roomAliasesStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer *sqlutil.TransactionWriter
|
|
||||||
insertRoomAliasStmt *sql.Stmt
|
insertRoomAliasStmt *sql.Stmt
|
||||||
selectRoomIDFromAliasStmt *sql.Stmt
|
selectRoomIDFromAliasStmt *sql.Stmt
|
||||||
selectAliasesFromRoomIDStmt *sql.Stmt
|
selectAliasesFromRoomIDStmt *sql.Stmt
|
||||||
|
|
@ -67,8 +65,7 @@ type roomAliasesStatements struct {
|
||||||
|
|
||||||
func NewSqliteRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) {
|
func NewSqliteRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) {
|
||||||
s := &roomAliasesStatements{
|
s := &roomAliasesStatements{
|
||||||
db: db,
|
db: db,
|
||||||
writer: sqlutil.NewTransactionWriter(),
|
|
||||||
}
|
}
|
||||||
_, err := db.Exec(roomAliasesSchema)
|
_, err := db.Exec(roomAliasesSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -85,12 +82,9 @@ func NewSqliteRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) {
|
||||||
|
|
||||||
func (s *roomAliasesStatements) InsertRoomAlias(
|
func (s *roomAliasesStatements) InsertRoomAlias(
|
||||||
ctx context.Context, alias string, roomID string, creatorUserID string,
|
ctx context.Context, alias string, roomID string, creatorUserID string,
|
||||||
) (err error) {
|
) error {
|
||||||
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
|
_, err := s.insertRoomAliasStmt.ExecContext(ctx, alias, roomID, creatorUserID)
|
||||||
stmt := sqlutil.TxStmt(txn, s.insertRoomAliasStmt)
|
return err
|
||||||
_, err := stmt.ExecContext(ctx, alias, roomID, creatorUserID)
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *roomAliasesStatements) SelectRoomIDFromAlias(
|
func (s *roomAliasesStatements) SelectRoomIDFromAlias(
|
||||||
|
|
@ -138,10 +132,7 @@ func (s *roomAliasesStatements) SelectCreatorIDFromAlias(
|
||||||
|
|
||||||
func (s *roomAliasesStatements) DeleteRoomAlias(
|
func (s *roomAliasesStatements) DeleteRoomAlias(
|
||||||
ctx context.Context, alias string,
|
ctx context.Context, alias string,
|
||||||
) (err error) {
|
) error {
|
||||||
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
|
_, err := s.deleteRoomAliasStmt.ExecContext(ctx, alias)
|
||||||
stmt := sqlutil.TxStmt(txn, s.deleteRoomAliasStmt)
|
return err
|
||||||
_, err := stmt.ExecContext(ctx, alias)
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -66,7 +66,6 @@ const selectRoomVersionForRoomNIDSQL = "" +
|
||||||
|
|
||||||
type roomStatements struct {
|
type roomStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer *sqlutil.TransactionWriter
|
|
||||||
insertRoomNIDStmt *sql.Stmt
|
insertRoomNIDStmt *sql.Stmt
|
||||||
selectRoomNIDStmt *sql.Stmt
|
selectRoomNIDStmt *sql.Stmt
|
||||||
selectLatestEventNIDsStmt *sql.Stmt
|
selectLatestEventNIDsStmt *sql.Stmt
|
||||||
|
|
@ -78,8 +77,7 @@ type roomStatements struct {
|
||||||
|
|
||||||
func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) {
|
func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) {
|
||||||
s := &roomStatements{
|
s := &roomStatements{
|
||||||
db: db,
|
db: db,
|
||||||
writer: sqlutil.NewTransactionWriter(),
|
|
||||||
}
|
}
|
||||||
_, err := db.Exec(roomsSchema)
|
_, err := db.Exec(roomsSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -100,20 +98,14 @@ func (s *roomStatements) InsertRoomNID(
|
||||||
ctx context.Context, txn *sql.Tx,
|
ctx context.Context, txn *sql.Tx,
|
||||||
roomID string, roomVersion gomatrixserverlib.RoomVersion,
|
roomID string, roomVersion gomatrixserverlib.RoomVersion,
|
||||||
) (roomNID types.RoomNID, err error) {
|
) (roomNID types.RoomNID, err error) {
|
||||||
err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
insertStmt := sqlutil.TxStmt(txn, s.insertRoomNIDStmt)
|
||||||
insertStmt := sqlutil.TxStmt(txn, s.insertRoomNIDStmt)
|
_, err = insertStmt.ExecContext(ctx, roomID, roomVersion)
|
||||||
_, err = insertStmt.ExecContext(ctx, roomID, roomVersion)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("insertStmt.ExecContext: %w", err)
|
|
||||||
}
|
|
||||||
roomNID, err = s.SelectRoomNID(ctx, txn, roomID)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("s.SelectRoomNID: %w", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.RoomNID(0), err
|
return 0, fmt.Errorf("insertStmt.ExecContext: %w", err)
|
||||||
|
}
|
||||||
|
roomNID, err = s.SelectRoomNID(ctx, txn, roomID)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("s.SelectRoomNID: %w", err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
@ -170,17 +162,15 @@ func (s *roomStatements) UpdateLatestEventNIDs(
|
||||||
lastEventSentNID types.EventNID,
|
lastEventSentNID types.EventNID,
|
||||||
stateSnapshotNID types.StateSnapshotNID,
|
stateSnapshotNID types.StateSnapshotNID,
|
||||||
) error {
|
) error {
|
||||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
stmt := sqlutil.TxStmt(txn, s.updateLatestEventNIDsStmt)
|
||||||
stmt := sqlutil.TxStmt(txn, s.updateLatestEventNIDsStmt)
|
_, err := stmt.ExecContext(
|
||||||
_, err := stmt.ExecContext(
|
ctx,
|
||||||
ctx,
|
eventNIDsAsArray(eventNIDs),
|
||||||
eventNIDsAsArray(eventNIDs),
|
int64(lastEventSentNID),
|
||||||
int64(lastEventSentNID),
|
int64(stateSnapshotNID),
|
||||||
int64(stateSnapshotNID),
|
roomNID,
|
||||||
roomNID,
|
)
|
||||||
)
|
return err
|
||||||
return err
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *roomStatements) SelectRoomVersionForRoomID(
|
func (s *roomStatements) SelectRoomVersionForRoomID(
|
||||||
|
|
|
||||||
|
|
@ -74,7 +74,6 @@ const bulkSelectFilteredStateBlockEntriesSQL = "" +
|
||||||
|
|
||||||
type stateBlockStatements struct {
|
type stateBlockStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer *sqlutil.TransactionWriter
|
|
||||||
insertStateDataStmt *sql.Stmt
|
insertStateDataStmt *sql.Stmt
|
||||||
selectNextStateBlockNIDStmt *sql.Stmt
|
selectNextStateBlockNIDStmt *sql.Stmt
|
||||||
bulkSelectStateBlockEntriesStmt *sql.Stmt
|
bulkSelectStateBlockEntriesStmt *sql.Stmt
|
||||||
|
|
@ -83,8 +82,7 @@ type stateBlockStatements struct {
|
||||||
|
|
||||||
func NewSqliteStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
|
func NewSqliteStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
|
||||||
s := &stateBlockStatements{
|
s := &stateBlockStatements{
|
||||||
db: db,
|
db: db,
|
||||||
writer: sqlutil.NewTransactionWriter(),
|
|
||||||
}
|
}
|
||||||
_, err := db.Exec(stateDataSchema)
|
_, err := db.Exec(stateDataSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -107,25 +105,22 @@ func (s *stateBlockStatements) BulkInsertStateData(
|
||||||
return 0, nil
|
return 0, nil
|
||||||
}
|
}
|
||||||
var stateBlockNID types.StateBlockNID
|
var stateBlockNID types.StateBlockNID
|
||||||
err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
err := txn.Stmt(s.selectNextStateBlockNIDStmt).QueryRowContext(ctx).Scan(&stateBlockNID)
|
||||||
err := txn.Stmt(s.selectNextStateBlockNIDStmt).QueryRowContext(ctx).Scan(&stateBlockNID)
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
for _, entry := range entries {
|
||||||
|
_, err = txn.Stmt(s.insertStateDataStmt).ExecContext(
|
||||||
|
ctx,
|
||||||
|
int64(stateBlockNID),
|
||||||
|
int64(entry.EventTypeNID),
|
||||||
|
int64(entry.EventStateKeyNID),
|
||||||
|
int64(entry.EventNID),
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return 0, err
|
||||||
}
|
}
|
||||||
for _, entry := range entries {
|
}
|
||||||
_, err := txn.Stmt(s.insertStateDataStmt).ExecContext(
|
|
||||||
ctx,
|
|
||||||
int64(stateBlockNID),
|
|
||||||
int64(entry.EventTypeNID),
|
|
||||||
int64(entry.EventStateKeyNID),
|
|
||||||
int64(entry.EventNID),
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
return stateBlockNID, err
|
return stateBlockNID, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -50,15 +50,13 @@ const bulkSelectStateBlockNIDsSQL = "" +
|
||||||
|
|
||||||
type stateSnapshotStatements struct {
|
type stateSnapshotStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer *sqlutil.TransactionWriter
|
|
||||||
insertStateStmt *sql.Stmt
|
insertStateStmt *sql.Stmt
|
||||||
bulkSelectStateBlockNIDsStmt *sql.Stmt
|
bulkSelectStateBlockNIDsStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSqliteStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
|
func NewSqliteStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
|
||||||
s := &stateSnapshotStatements{
|
s := &stateSnapshotStatements{
|
||||||
db: db,
|
db: db,
|
||||||
writer: sqlutil.NewTransactionWriter(),
|
|
||||||
}
|
}
|
||||||
_, err := db.Exec(stateSnapshotSchema)
|
_, err := db.Exec(stateSnapshotSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -78,19 +76,16 @@ func (s *stateSnapshotStatements) InsertState(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
insertStmt := txn.Stmt(s.insertStateStmt)
|
||||||
insertStmt := txn.Stmt(s.insertStateStmt)
|
res, err := insertStmt.ExecContext(ctx, int64(roomNID), string(stateBlockNIDsJSON))
|
||||||
res, err := insertStmt.ExecContext(ctx, int64(roomNID), string(stateBlockNIDsJSON))
|
if err != nil {
|
||||||
if err != nil {
|
return 0, err
|
||||||
return err
|
}
|
||||||
}
|
lastRowID, err := res.LastInsertId()
|
||||||
lastRowID, err := res.LastInsertId()
|
if err != nil {
|
||||||
if err != nil {
|
return 0, err
|
||||||
return err
|
}
|
||||||
}
|
stateNID = types.StateSnapshotNID(lastRowID)
|
||||||
stateNID = types.StateSnapshotNID(lastRowID)
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -41,6 +41,7 @@ type Database struct {
|
||||||
invites tables.Invites
|
invites tables.Invites
|
||||||
membership tables.Membership
|
membership tables.Membership
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
|
writer sqlutil.TransactionWriter
|
||||||
}
|
}
|
||||||
|
|
||||||
// Open a sqlite database.
|
// Open a sqlite database.
|
||||||
|
|
@ -51,6 +52,7 @@ func Open(dbProperties *config.DatabaseOptions) (*Database, error) {
|
||||||
if d.db, err = sqlutil.Open(dbProperties); err != nil {
|
if d.db, err = sqlutil.Open(dbProperties); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
d.writer = sqlutil.NewTransactionWriter()
|
||||||
//d.db.Exec("PRAGMA journal_mode=WAL;")
|
//d.db.Exec("PRAGMA journal_mode=WAL;")
|
||||||
//d.db.Exec("PRAGMA read_uncommitted = true;")
|
//d.db.Exec("PRAGMA read_uncommitted = true;")
|
||||||
|
|
||||||
|
|
@ -118,6 +120,7 @@ func Open(dbProperties *config.DatabaseOptions) (*Database, error) {
|
||||||
}
|
}
|
||||||
d.Database = shared.Database{
|
d.Database = shared.Database{
|
||||||
DB: d.db,
|
DB: d.db,
|
||||||
|
Writer: sqlutil.NewTransactionWriter(),
|
||||||
EventsTable: d.events,
|
EventsTable: d.events,
|
||||||
EventTypesTable: d.eventTypes,
|
EventTypesTable: d.eventTypes,
|
||||||
EventStateKeysTable: d.eventStateKeys,
|
EventStateKeysTable: d.eventStateKeys,
|
||||||
|
|
@ -138,25 +141,25 @@ func Open(dbProperties *config.DatabaseOptions) (*Database, error) {
|
||||||
|
|
||||||
func (d *Database) GetLatestEventsForUpdate(
|
func (d *Database) GetLatestEventsForUpdate(
|
||||||
ctx context.Context, roomNID types.RoomNID,
|
ctx context.Context, roomNID types.RoomNID,
|
||||||
) (types.RoomRecentEventsUpdater, error) {
|
) (*shared.LatestEventsUpdater, error) {
|
||||||
// TODO: Do not use transactions. We should be holding open this transaction but we cannot have
|
// TODO: Do not use transactions. We should be holding open this transaction but we cannot have
|
||||||
// multiple write transactions on sqlite. The code will perform additional
|
// multiple write transactions on sqlite. The code will perform additional
|
||||||
// write transactions independent of this one which will consistently cause
|
// write transactions independent of this one which will consistently cause
|
||||||
// 'database is locked' errors. As sqlite doesn't support multi-process on the
|
// 'database is locked' errors. As sqlite doesn't support multi-process on the
|
||||||
// same DB anyway, and we only execute updates sequentially, the only worries
|
// same DB anyway, and we only execute updates sequentially, the only worries
|
||||||
// are for rolling back when things go wrong. (atomicity)
|
// are for rolling back when things go wrong. (atomicity)
|
||||||
return shared.NewRoomRecentEventsUpdater(&d.Database, ctx, roomNID, false)
|
return shared.NewLatestEventsUpdater(ctx, &d.Database, nil, roomNID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) MembershipUpdater(
|
func (d *Database) MembershipUpdater(
|
||||||
ctx context.Context, roomID, targetUserID string,
|
ctx context.Context, roomID, targetUserID string,
|
||||||
targetLocal bool, roomVersion gomatrixserverlib.RoomVersion,
|
targetLocal bool, roomVersion gomatrixserverlib.RoomVersion,
|
||||||
) (updater types.MembershipUpdater, err error) {
|
) (*shared.MembershipUpdater, error) {
|
||||||
// TODO: Do not use transactions. We should be holding open this transaction but we cannot have
|
// TODO: Do not use transactions. We should be holding open this transaction but we cannot have
|
||||||
// multiple write transactions on sqlite. The code will perform additional
|
// multiple write transactions on sqlite. The code will perform additional
|
||||||
// write transactions independent of this one which will consistently cause
|
// write transactions independent of this one which will consistently cause
|
||||||
// 'database is locked' errors. As sqlite doesn't support multi-process on the
|
// 'database is locked' errors. As sqlite doesn't support multi-process on the
|
||||||
// same DB anyway, and we only execute updates sequentially, the only worries
|
// same DB anyway, and we only execute updates sequentially, the only worries
|
||||||
// are for rolling back when things go wrong. (atomicity)
|
// are for rolling back when things go wrong. (atomicity)
|
||||||
return shared.NewMembershipUpdater(ctx, &d.Database, roomID, targetUserID, targetLocal, roomVersion, false)
|
return shared.NewMembershipUpdater(ctx, &d.Database, nil, roomID, targetUserID, targetLocal, roomVersion)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -45,15 +45,13 @@ const selectTransactionEventIDSQL = `
|
||||||
|
|
||||||
type transactionStatements struct {
|
type transactionStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer *sqlutil.TransactionWriter
|
|
||||||
insertTransactionStmt *sql.Stmt
|
insertTransactionStmt *sql.Stmt
|
||||||
selectTransactionEventIDStmt *sql.Stmt
|
selectTransactionEventIDStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSqliteTransactionsTable(db *sql.DB) (tables.Transactions, error) {
|
func NewSqliteTransactionsTable(db *sql.DB) (tables.Transactions, error) {
|
||||||
s := &transactionStatements{
|
s := &transactionStatements{
|
||||||
db: db,
|
db: db,
|
||||||
writer: sqlutil.NewTransactionWriter(),
|
|
||||||
}
|
}
|
||||||
_, err := db.Exec(transactionsSchema)
|
_, err := db.Exec(transactionsSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -72,14 +70,12 @@ func (s *transactionStatements) InsertTransaction(
|
||||||
sessionID int64,
|
sessionID int64,
|
||||||
userID string,
|
userID string,
|
||||||
eventID string,
|
eventID string,
|
||||||
) (err error) {
|
) error {
|
||||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
stmt := sqlutil.TxStmt(txn, s.insertTransactionStmt)
|
||||||
stmt := sqlutil.TxStmt(txn, s.insertTransactionStmt)
|
_, err := stmt.ExecContext(
|
||||||
_, err := stmt.ExecContext(
|
ctx, transactionID, sessionID, userID, eventID,
|
||||||
ctx, transactionID, sessionID, userID, eventID,
|
)
|
||||||
)
|
return err
|
||||||
return err
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *transactionStatements) SelectTransactionEventID(
|
func (s *transactionStatements) SelectTransactionEventID(
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,6 @@
|
||||||
package types
|
package types
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -140,68 +139,6 @@ type StateEntryList struct {
|
||||||
StateEntries []StateEntry
|
StateEntries []StateEntry
|
||||||
}
|
}
|
||||||
|
|
||||||
// A RoomRecentEventsUpdater is used to update the recent events in a room.
|
|
||||||
// (On postgresql this wraps a database transaction that holds a "FOR UPDATE"
|
|
||||||
// lock on the row in the rooms table holding the latest events for the room.)
|
|
||||||
type RoomRecentEventsUpdater interface {
|
|
||||||
// The room version of the room.
|
|
||||||
RoomVersion() gomatrixserverlib.RoomVersion
|
|
||||||
// The latest event IDs and state in the room.
|
|
||||||
LatestEvents() []StateAtEventAndReference
|
|
||||||
// The event ID of the latest event written to the output log in the room.
|
|
||||||
LastEventIDSent() string
|
|
||||||
// The current state of the room.
|
|
||||||
CurrentStateSnapshotNID() StateSnapshotNID
|
|
||||||
// Store the previous events referenced by an event.
|
|
||||||
// This adds the event NID to an entry in the database for each of the previous events.
|
|
||||||
// If there isn't an entry for one of previous events then an entry is created.
|
|
||||||
// If the entry already lists the event NID as a referrer then the entry unmodified.
|
|
||||||
// (i.e. the operation is idempotent)
|
|
||||||
StorePreviousEvents(eventNID EventNID, previousEventReferences []gomatrixserverlib.EventReference) error
|
|
||||||
// Check whether the eventReference is already referenced by another matrix event.
|
|
||||||
IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error)
|
|
||||||
// Set the list of latest events for the room.
|
|
||||||
// This replaces the current list stored in the database with the given list
|
|
||||||
SetLatestEvents(
|
|
||||||
roomNID RoomNID, latest []StateAtEventAndReference, lastEventNIDSent EventNID,
|
|
||||||
currentStateSnapshotNID StateSnapshotNID,
|
|
||||||
) error
|
|
||||||
// Check if the event has already be written to the output logs.
|
|
||||||
HasEventBeenSent(eventNID EventNID) (bool, error)
|
|
||||||
// Mark the event as having been sent to the output logs.
|
|
||||||
MarkEventAsSent(eventNID EventNID) error
|
|
||||||
// Build a membership updater for the target user in this room.
|
|
||||||
// It will share the same transaction as this updater.
|
|
||||||
MembershipUpdater(targetUserNID EventStateKeyNID, isTargetLocalUser bool) (MembershipUpdater, error)
|
|
||||||
// Implements Transaction so it can be committed or rolledback
|
|
||||||
sqlutil.Transaction
|
|
||||||
}
|
|
||||||
|
|
||||||
// A MembershipUpdater is used to update the membership of a user in a room.
|
|
||||||
// (On postgresql this wraps a database transaction that holds a "FOR UPDATE"
|
|
||||||
// lock on the row in the membership table for this user in the room)
|
|
||||||
// The caller should call one of SetToInvite, SetToJoin or SetToLeave once to
|
|
||||||
// make the update, or none of them if no update is required.
|
|
||||||
type MembershipUpdater interface {
|
|
||||||
// True if the target user is invited to the room before updating.
|
|
||||||
IsInvite() bool
|
|
||||||
// True if the target user is joined to the room before updating.
|
|
||||||
IsJoin() bool
|
|
||||||
// True if the target user is not invited or joined to the room before updating.
|
|
||||||
IsLeave() bool
|
|
||||||
// Set the state to invite.
|
|
||||||
// Returns whether this invite needs to be sent
|
|
||||||
SetToInvite(event gomatrixserverlib.Event) (needsSending bool, err error)
|
|
||||||
// Set the state to join or updates the event ID in the database.
|
|
||||||
// Returns a list of invite event IDs that this state change retired.
|
|
||||||
SetToJoin(senderUserID string, eventID string, isUpdate bool) (inviteEventIDs []string, err error)
|
|
||||||
// Set the state to leave.
|
|
||||||
// Returns a list of invite event IDs that this state change retired.
|
|
||||||
SetToLeave(senderUserID string, eventID string) (inviteEventIDs []string, err error)
|
|
||||||
// Implements Transaction so it can be committed or rolledback.
|
|
||||||
sqlutil.Transaction
|
|
||||||
}
|
|
||||||
|
|
||||||
// A MissingEventError is an error that happened because the roomserver was
|
// A MissingEventError is an error that happened because the roomserver was
|
||||||
// missing requested events from its database.
|
// missing requested events from its database.
|
||||||
type MissingEventError string
|
type MissingEventError string
|
||||||
|
|
|
||||||
|
|
@ -63,7 +63,7 @@ const upsertServerKeysSQL = "" +
|
||||||
|
|
||||||
type serverKeyStatements struct {
|
type serverKeyStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer *sqlutil.TransactionWriter
|
writer sqlutil.TransactionWriter
|
||||||
bulkSelectServerKeysStmt *sql.Stmt
|
bulkSelectServerKeysStmt *sql.Stmt
|
||||||
upsertServerKeysStmt *sql.Stmt
|
upsertServerKeysStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -45,7 +45,7 @@ type Database struct {
|
||||||
BackwardExtremities tables.BackwardsExtremities
|
BackwardExtremities tables.BackwardsExtremities
|
||||||
SendToDevice tables.SendToDevice
|
SendToDevice tables.SendToDevice
|
||||||
Filter tables.Filter
|
Filter tables.Filter
|
||||||
SendToDeviceWriter *sqlutil.TransactionWriter
|
SendToDeviceWriter sqlutil.TransactionWriter
|
||||||
EDUCache *cache.EDUCache
|
EDUCache *cache.EDUCache
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -51,7 +51,7 @@ const selectMaxAccountDataIDSQL = "" +
|
||||||
|
|
||||||
type accountDataStatements struct {
|
type accountDataStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer *sqlutil.TransactionWriter
|
writer sqlutil.TransactionWriter
|
||||||
streamIDStatements *streamIDStatements
|
streamIDStatements *streamIDStatements
|
||||||
insertAccountDataStmt *sql.Stmt
|
insertAccountDataStmt *sql.Stmt
|
||||||
selectMaxAccountDataIDStmt *sql.Stmt
|
selectMaxAccountDataIDStmt *sql.Stmt
|
||||||
|
|
|
||||||
|
|
@ -49,7 +49,7 @@ const deleteBackwardExtremitySQL = "" +
|
||||||
|
|
||||||
type backwardExtremitiesStatements struct {
|
type backwardExtremitiesStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer *sqlutil.TransactionWriter
|
writer sqlutil.TransactionWriter
|
||||||
insertBackwardExtremityStmt *sql.Stmt
|
insertBackwardExtremityStmt *sql.Stmt
|
||||||
selectBackwardExtremitiesForRoomStmt *sql.Stmt
|
selectBackwardExtremitiesForRoomStmt *sql.Stmt
|
||||||
deleteBackwardExtremityStmt *sql.Stmt
|
deleteBackwardExtremityStmt *sql.Stmt
|
||||||
|
|
|
||||||
|
|
@ -85,7 +85,7 @@ const selectEventsWithEventIDsSQL = "" +
|
||||||
|
|
||||||
type currentRoomStateStatements struct {
|
type currentRoomStateStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer *sqlutil.TransactionWriter
|
writer sqlutil.TransactionWriter
|
||||||
streamIDStatements *streamIDStatements
|
streamIDStatements *streamIDStatements
|
||||||
upsertRoomStateStmt *sql.Stmt
|
upsertRoomStateStmt *sql.Stmt
|
||||||
deleteRoomStateByEventIDStmt *sql.Stmt
|
deleteRoomStateByEventIDStmt *sql.Stmt
|
||||||
|
|
|
||||||
|
|
@ -52,7 +52,7 @@ const insertFilterSQL = "" +
|
||||||
|
|
||||||
type filterStatements struct {
|
type filterStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer *sqlutil.TransactionWriter
|
writer sqlutil.TransactionWriter
|
||||||
selectFilterStmt *sql.Stmt
|
selectFilterStmt *sql.Stmt
|
||||||
selectFilterIDByContentStmt *sql.Stmt
|
selectFilterIDByContentStmt *sql.Stmt
|
||||||
insertFilterStmt *sql.Stmt
|
insertFilterStmt *sql.Stmt
|
||||||
|
|
|
||||||
|
|
@ -59,7 +59,7 @@ const selectMaxInviteIDSQL = "" +
|
||||||
|
|
||||||
type inviteEventsStatements struct {
|
type inviteEventsStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer *sqlutil.TransactionWriter
|
writer sqlutil.TransactionWriter
|
||||||
streamIDStatements *streamIDStatements
|
streamIDStatements *streamIDStatements
|
||||||
insertInviteEventStmt *sql.Stmt
|
insertInviteEventStmt *sql.Stmt
|
||||||
selectInviteEventsInRangeStmt *sql.Stmt
|
selectInviteEventsInRangeStmt *sql.Stmt
|
||||||
|
|
|
||||||
|
|
@ -105,7 +105,7 @@ const selectStateInRangeSQL = "" +
|
||||||
|
|
||||||
type outputRoomEventsStatements struct {
|
type outputRoomEventsStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer *sqlutil.TransactionWriter
|
writer sqlutil.TransactionWriter
|
||||||
streamIDStatements *streamIDStatements
|
streamIDStatements *streamIDStatements
|
||||||
insertEventStmt *sql.Stmt
|
insertEventStmt *sql.Stmt
|
||||||
selectEventsStmt *sql.Stmt
|
selectEventsStmt *sql.Stmt
|
||||||
|
|
|
||||||
|
|
@ -67,7 +67,7 @@ const selectMaxPositionInTopologySQL = "" +
|
||||||
|
|
||||||
type outputRoomEventsTopologyStatements struct {
|
type outputRoomEventsTopologyStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer *sqlutil.TransactionWriter
|
writer sqlutil.TransactionWriter
|
||||||
insertEventInTopologyStmt *sql.Stmt
|
insertEventInTopologyStmt *sql.Stmt
|
||||||
selectEventIDsInRangeASCStmt *sql.Stmt
|
selectEventIDsInRangeASCStmt *sql.Stmt
|
||||||
selectEventIDsInRangeDESCStmt *sql.Stmt
|
selectEventIDsInRangeDESCStmt *sql.Stmt
|
||||||
|
|
|
||||||
|
|
@ -73,7 +73,7 @@ const deleteSendToDeviceMessagesSQL = `
|
||||||
|
|
||||||
type sendToDeviceStatements struct {
|
type sendToDeviceStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer *sqlutil.TransactionWriter
|
writer sqlutil.TransactionWriter
|
||||||
insertSendToDeviceMessageStmt *sql.Stmt
|
insertSendToDeviceMessageStmt *sql.Stmt
|
||||||
selectSendToDeviceMessagesStmt *sql.Stmt
|
selectSendToDeviceMessagesStmt *sql.Stmt
|
||||||
countSendToDeviceMessagesStmt *sql.Stmt
|
countSendToDeviceMessagesStmt *sql.Stmt
|
||||||
|
|
|
||||||
|
|
@ -28,7 +28,7 @@ const selectStreamIDStmt = "" +
|
||||||
|
|
||||||
type streamIDStatements struct {
|
type streamIDStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer *sqlutil.TransactionWriter
|
writer sqlutil.TransactionWriter
|
||||||
increaseStreamIDStmt *sql.Stmt
|
increaseStreamIDStmt *sql.Stmt
|
||||||
selectStreamIDStmt *sql.Stmt
|
selectStreamIDStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -51,7 +51,7 @@ const selectAccountDataByTypeSQL = "" +
|
||||||
|
|
||||||
type accountDataStatements struct {
|
type accountDataStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer *sqlutil.TransactionWriter
|
writer sqlutil.TransactionWriter
|
||||||
insertAccountDataStmt *sql.Stmt
|
insertAccountDataStmt *sql.Stmt
|
||||||
selectAccountDataStmt *sql.Stmt
|
selectAccountDataStmt *sql.Stmt
|
||||||
selectAccountDataByTypeStmt *sql.Stmt
|
selectAccountDataByTypeStmt *sql.Stmt
|
||||||
|
|
|
||||||
|
|
@ -59,7 +59,7 @@ const selectNewNumericLocalpartSQL = "" +
|
||||||
|
|
||||||
type accountsStatements struct {
|
type accountsStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer *sqlutil.TransactionWriter
|
writer sqlutil.TransactionWriter
|
||||||
insertAccountStmt *sql.Stmt
|
insertAccountStmt *sql.Stmt
|
||||||
selectAccountByLocalpartStmt *sql.Stmt
|
selectAccountByLocalpartStmt *sql.Stmt
|
||||||
selectPasswordHashStmt *sql.Stmt
|
selectPasswordHashStmt *sql.Stmt
|
||||||
|
|
|
||||||
|
|
@ -53,7 +53,7 @@ const selectProfilesBySearchSQL = "" +
|
||||||
|
|
||||||
type profilesStatements struct {
|
type profilesStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer *sqlutil.TransactionWriter
|
writer sqlutil.TransactionWriter
|
||||||
insertProfileStmt *sql.Stmt
|
insertProfileStmt *sql.Stmt
|
||||||
selectProfileByLocalpartStmt *sql.Stmt
|
selectProfileByLocalpartStmt *sql.Stmt
|
||||||
setAvatarURLStmt *sql.Stmt
|
setAvatarURLStmt *sql.Stmt
|
||||||
|
|
|
||||||
|
|
@ -54,7 +54,7 @@ const deleteThreePIDSQL = "" +
|
||||||
|
|
||||||
type threepidStatements struct {
|
type threepidStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer *sqlutil.TransactionWriter
|
writer sqlutil.TransactionWriter
|
||||||
selectLocalpartForThreePIDStmt *sql.Stmt
|
selectLocalpartForThreePIDStmt *sql.Stmt
|
||||||
selectThreePIDsForLocalpartStmt *sql.Stmt
|
selectThreePIDsForLocalpartStmt *sql.Stmt
|
||||||
insertThreePIDStmt *sql.Stmt
|
insertThreePIDStmt *sql.Stmt
|
||||||
|
|
|
||||||
|
|
@ -78,7 +78,7 @@ const selectDevicesByIDSQL = "" +
|
||||||
|
|
||||||
type devicesStatements struct {
|
type devicesStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer *sqlutil.TransactionWriter
|
writer sqlutil.TransactionWriter
|
||||||
insertDeviceStmt *sql.Stmt
|
insertDeviceStmt *sql.Stmt
|
||||||
selectDevicesCountStmt *sql.Stmt
|
selectDevicesCountStmt *sql.Stmt
|
||||||
selectDeviceByTokenStmt *sql.Stmt
|
selectDeviceByTokenStmt *sql.Stmt
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue