sqlite work

This commit is contained in:
Kegan Dougal 2020-02-12 11:55:31 +00:00
parent 0360f07110
commit fb9243b4ed
7 changed files with 299 additions and 234 deletions

View file

@ -30,11 +30,13 @@ type Transaction interface {
// EndTransaction ends a transaction. // EndTransaction ends a transaction.
// If the transaction succeeded then it is committed, otherwise it is rolledback. // If the transaction succeeded then it is committed, otherwise it is rolledback.
func EndTransaction(txn Transaction, succeeded *bool) { // You MUST check the error returned from this function to be sure that the transaction
// was applied correctly. For example, 'database is locked' errors in sqlite will happen here.
func EndTransaction(txn Transaction, succeeded *bool) error {
if *succeeded { if *succeeded {
txn.Commit() // nolint: errcheck return txn.Commit() // nolint: errcheck
} else { } else {
txn.Rollback() // nolint: errcheck return txn.Rollback() // nolint: errcheck
} }
} }
@ -47,7 +49,12 @@ func WithTransaction(db *sql.DB, fn func(txn *sql.Tx) error) (err error) {
return return
} }
succeeded := false succeeded := false
defer EndTransaction(txn, &succeeded) defer func() {
err2 := EndTransaction(txn, &succeeded)
if err == nil && err2 != nil { // failed to commit/rollback
err = err2
}
}()
err = fn(txn) err = fn(txn)
if err != nil { if err != nil {

View file

@ -1,13 +1,21 @@
version: "3.4" version: "3.4"
services: services:
riot:
image: vectorim/riot-web
networks:
- internal
ports:
- "8500:80"
monolith: monolith:
container_name: dendrite_monolith container_name: dendrite_monolith
hostname: monolith hostname: monolith
entrypoint: ["bash", "./docker/services/monolith.sh"] entrypoint: ["bash", "./docker/services/monolith.sh", "--config", "/etc/dendrite/dendrite.yaml"]
build: ./ build: ./
volumes: volumes:
- ..:/build - ..:/build
- ./build/bin:/build/bin - ./build/bin:/build/bin
- ../cfg:/etc/dendrite
networks: networks:
- internal - internal
depends_on: depends_on:

View file

@ -18,7 +18,6 @@ package sqlite3
import ( import (
"context" "context"
"database/sql" "database/sql"
"fmt"
"strings" "strings"
"github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/common"
@ -79,21 +78,14 @@ type eventJSONPair struct {
func (s *eventJSONStatements) bulkSelectEventJSON( func (s *eventJSONStatements) bulkSelectEventJSON(
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
) ([]eventJSONPair, error) { ) ([]eventJSONPair, error) {
///////////////
iEventNIDs := make([]interface{}, len(eventNIDs)) iEventNIDs := make([]interface{}, len(eventNIDs))
for k, v := range eventNIDs { for k, v := range eventNIDs {
iEventNIDs[k] = v iEventNIDs[k] = v
} }
selectOrig := strings.Replace(bulkSelectEventJSONSQL, "($1)", queryVariadic(len(iEventNIDs)), 1) selectOrig := strings.Replace(bulkSelectEventJSONSQL, "($1)", queryVariadic(len(iEventNIDs)), 1)
selectPrep, err := s.db.Prepare(selectOrig)
if err != nil {
return nil, err
}
///////////////
rows, err := common.TxStmt(txn, selectPrep).QueryContext(ctx, iEventNIDs...) rows, err := txn.QueryContext(ctx, selectOrig, iEventNIDs...)
if err != nil { if err != nil {
fmt.Println("bulkSelectEventJSON s.bulkSelectEventJSONStmt.QueryContext:", err)
return nil, err return nil, err
} }
defer rows.Close() // nolint: errcheck defer rows.Close() // nolint: errcheck
@ -108,7 +100,6 @@ func (s *eventJSONStatements) bulkSelectEventJSON(
result := &results[i] result := &results[i]
var eventNID int64 var eventNID int64
if err := rows.Scan(&eventNID, &result.EventJSON); err != nil { if err := rows.Scan(&eventNID, &result.EventJSON); err != nil {
fmt.Println("bulkSelectEventJSON rows.Scan:", err)
return nil, err return nil, err
} }
result.EventNID = types.EventNID(eventNID) result.EventNID = types.EventNID(eventNID)

View file

@ -21,7 +21,6 @@ import (
"fmt" "fmt"
"strings" "strings"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
) )
@ -41,11 +40,6 @@ const insertEventStateKeyNIDSQL = `
ON CONFLICT DO NOTHING; ON CONFLICT DO NOTHING;
` `
const insertEventStateKeyNIDResultSQL = `
SELECT event_state_key_nid FROM roomserver_event_state_keys
WHERE rowid = last_insert_rowid();
`
const selectEventStateKeyNIDSQL = ` const selectEventStateKeyNIDSQL = `
SELECT event_state_key_nid FROM roomserver_event_state_keys SELECT event_state_key_nid FROM roomserver_event_state_keys
WHERE event_state_key = $1 WHERE event_state_key = $1
@ -66,12 +60,11 @@ const bulkSelectEventStateKeySQL = `
` `
type eventStateKeyStatements struct { type eventStateKeyStatements struct {
db *sql.DB db *sql.DB
insertEventStateKeyNIDStmt *sql.Stmt insertEventStateKeyNIDStmt *sql.Stmt
insertEventStateKeyNIDResultStmt *sql.Stmt selectEventStateKeyNIDStmt *sql.Stmt
selectEventStateKeyNIDStmt *sql.Stmt bulkSelectEventStateKeyNIDStmt *sql.Stmt
bulkSelectEventStateKeyNIDStmt *sql.Stmt bulkSelectEventStateKeyStmt *sql.Stmt
bulkSelectEventStateKeyStmt *sql.Stmt
} }
func (s *eventStateKeyStatements) prepare(db *sql.DB) (err error) { func (s *eventStateKeyStatements) prepare(db *sql.DB) (err error) {
@ -82,7 +75,6 @@ func (s *eventStateKeyStatements) prepare(db *sql.DB) (err error) {
} }
return statementList{ return statementList{
{&s.insertEventStateKeyNIDStmt, insertEventStateKeyNIDSQL}, {&s.insertEventStateKeyNIDStmt, insertEventStateKeyNIDSQL},
{&s.insertEventStateKeyNIDResultStmt, insertEventStateKeyNIDResultSQL},
{&s.selectEventStateKeyNIDStmt, selectEventStateKeyNIDSQL}, {&s.selectEventStateKeyNIDStmt, selectEventStateKeyNIDSQL},
{&s.bulkSelectEventStateKeyNIDStmt, bulkSelectEventStateKeyNIDSQL}, {&s.bulkSelectEventStateKeyNIDStmt, bulkSelectEventStateKeyNIDSQL},
{&s.bulkSelectEventStateKeyStmt, bulkSelectEventStateKeySQL}, {&s.bulkSelectEventStateKeyStmt, bulkSelectEventStateKeySQL},
@ -94,13 +86,10 @@ func (s *eventStateKeyStatements) insertEventStateKeyNID(
) (types.EventStateKeyNID, error) { ) (types.EventStateKeyNID, error) {
var eventStateKeyNID int64 var eventStateKeyNID int64
var err error var err error
insertStmt := common.TxStmt(txn, s.insertEventStateKeyNIDStmt) var res sql.Result
selectStmt := common.TxStmt(txn, s.insertEventStateKeyNIDResultStmt) insertStmt := txn.Stmt(s.insertEventStateKeyNIDStmt)
if _, err = insertStmt.ExecContext(ctx, eventStateKey); err == nil { if res, err = insertStmt.ExecContext(ctx, eventStateKey); err == nil {
err = selectStmt.QueryRowContext(ctx).Scan(&eventStateKeyNID) eventStateKeyNID, err = res.LastInsertId()
if err != nil {
fmt.Println("insertEventStateKeyNID selectStmt.QueryRowContext:", err)
}
} else { } else {
fmt.Println("insertEventStateKeyNID insertStmt.ExecContext:", err) fmt.Println("insertEventStateKeyNID insertStmt.ExecContext:", err)
} }
@ -111,36 +100,22 @@ func (s *eventStateKeyStatements) selectEventStateKeyNID(
ctx context.Context, txn *sql.Tx, eventStateKey string, ctx context.Context, txn *sql.Tx, eventStateKey string,
) (types.EventStateKeyNID, error) { ) (types.EventStateKeyNID, error) {
var eventStateKeyNID int64 var eventStateKeyNID int64
stmt := common.TxStmt(txn, s.selectEventStateKeyNIDStmt) stmt := txn.Stmt(s.selectEventStateKeyNIDStmt)
fmt.Println("selectEventStateKeyNID for", eventStateKey)
err := stmt.QueryRowContext(ctx, eventStateKey).Scan(&eventStateKeyNID) err := stmt.QueryRowContext(ctx, eventStateKey).Scan(&eventStateKeyNID)
if err != nil {
fmt.Println("selectEventStateKeyNID stmt.QueryRowContext:", err)
}
fmt.Println("selectEventStateKeyNID returns", eventStateKeyNID)
return types.EventStateKeyNID(eventStateKeyNID), err return types.EventStateKeyNID(eventStateKeyNID), err
} }
func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID( func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID(
ctx context.Context, txn *sql.Tx, eventStateKeys []string, ctx context.Context, txn *sql.Tx, eventStateKeys []string,
) (map[string]types.EventStateKeyNID, error) { ) (map[string]types.EventStateKeyNID, error) {
///////////////
iEventStateKeys := make([]interface{}, len(eventStateKeys)) iEventStateKeys := make([]interface{}, len(eventStateKeys))
for k, v := range eventStateKeys { for k, v := range eventStateKeys {
iEventStateKeys[k] = v iEventStateKeys[k] = v
} }
selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", queryVariadic(len(eventStateKeys)), 1) selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", queryVariadic(len(eventStateKeys)), 1)
selectPrep, err := s.db.Prepare(selectOrig)
if err != nil {
return nil, err
}
///////////////
rows, err := common.TxStmt(txn, selectPrep).QueryContext( rows, err := txn.QueryContext(ctx, selectOrig, iEventStateKeys...)
ctx, iEventStateKeys...,
)
if err != nil { if err != nil {
fmt.Println("bulkSelectEventStateKeyNID s.bulkSelectEventStateKeyNIDStmt.QueryContext:", err)
return nil, err return nil, err
} }
defer rows.Close() // nolint: errcheck defer rows.Close() // nolint: errcheck
@ -149,7 +124,6 @@ func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID(
var stateKey string var stateKey string
var stateKeyNID int64 var stateKeyNID int64
if err := rows.Scan(&stateKey, &stateKeyNID); err != nil { if err := rows.Scan(&stateKey, &stateKeyNID); err != nil {
fmt.Println("bulkSelectEventStateKeyNID rows.Scan:", err)
return nil, err return nil, err
} }
result[stateKey] = types.EventStateKeyNID(stateKeyNID) result[stateKey] = types.EventStateKeyNID(stateKeyNID)
@ -160,25 +134,14 @@ func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID(
func (s *eventStateKeyStatements) bulkSelectEventStateKey( func (s *eventStateKeyStatements) bulkSelectEventStateKey(
ctx context.Context, txn *sql.Tx, eventStateKeyNIDs []types.EventStateKeyNID, ctx context.Context, txn *sql.Tx, eventStateKeyNIDs []types.EventStateKeyNID,
) (map[types.EventStateKeyNID]string, error) { ) (map[types.EventStateKeyNID]string, error) {
///////////////
iEventStateKeyNIDs := make([]interface{}, len(eventStateKeyNIDs)) iEventStateKeyNIDs := make([]interface{}, len(eventStateKeyNIDs))
for k, v := range eventStateKeyNIDs { for k, v := range eventStateKeyNIDs {
iEventStateKeyNIDs[k] = v iEventStateKeyNIDs[k] = v
} }
selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", queryVariadic(len(eventStateKeyNIDs)), 1) selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", queryVariadic(len(eventStateKeyNIDs)), 1)
selectPrep, err := s.db.Prepare(selectOrig)
if err != nil {
return nil, err
}
///////////////
/*nIDs := make(pq.Int64Array, len(eventStateKeyNIDs)) rows, err := txn.QueryContext(ctx, selectOrig, iEventStateKeyNIDs...)
for i := range eventStateKeyNIDs {
nIDs[i] = int64(eventStateKeyNIDs[i])
}*/
rows, err := common.TxStmt(txn, selectPrep).QueryContext(ctx, iEventStateKeyNIDs...)
if err != nil { if err != nil {
fmt.Println("bulkSelectEventStateKey s.bulkSelectEventStateKeyStmt.QueryContext:", err)
return nil, err return nil, err
} }
defer rows.Close() // nolint: errcheck defer rows.Close() // nolint: errcheck
@ -187,7 +150,6 @@ func (s *eventStateKeyStatements) bulkSelectEventStateKey(
var stateKey string var stateKey string
var stateKeyNID int64 var stateKeyNID int64
if err := rows.Scan(&stateKey, &stateKeyNID); err != nil { if err := rows.Scan(&stateKey, &stateKeyNID); err != nil {
fmt.Println("bulkSelectEventStateKey rows.Scan:", err)
return nil, err return nil, err
} }
result[types.EventStateKeyNID(stateKeyNID)] = stateKey result[types.EventStateKeyNID(stateKeyNID)] = stateKey

View file

@ -69,7 +69,7 @@ const bulkSelectStateAtEventByIDSQL = "" +
" WHERE event_id IN ($1)" " WHERE event_id IN ($1)"
const updateEventStateSQL = "" + const updateEventStateSQL = "" +
"UPDATE roomserver_events SET state_snapshot_nid = $2 WHERE event_nid = $1" "UPDATE roomserver_events SET state_snapshot_nid = $1 WHERE event_nid = $2"
const selectEventSentToOutputSQL = "" + const selectEventSentToOutputSQL = "" +
"SELECT sent_to_output FROM roomserver_events WHERE event_nid = $1" "SELECT sent_to_output FROM roomserver_events WHERE event_nid = $1"
@ -153,21 +153,18 @@ func (s *eventStatements) insertEvent(
var eventNID int64 var eventNID int64
var stateNID int64 var stateNID int64
var err error var err error
var res sql.Result
insertStmt := common.TxStmt(txn, s.insertEventStmt) insertStmt := common.TxStmt(txn, s.insertEventStmt)
resultStmt := common.TxStmt(txn, s.insertEventResultStmt) resultStmt := common.TxStmt(txn, s.insertEventResultStmt)
if _, err = insertStmt.ExecContext( if res, err = insertStmt.ExecContext(
ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID), ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID),
eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth, eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth,
); err == nil { ); err == nil {
a, b := res.LastInsertId()
fmt.Println("LastInsertId", a, b)
err = resultStmt.QueryRowContext(ctx).Scan(&eventNID, &stateNID) err = resultStmt.QueryRowContext(ctx).Scan(&eventNID, &stateNID)
if err != nil {
fmt.Println("insertEvent HAS FAILED!", err)
}
} else {
fmt.Println("insertEvent HAS GONE WRONG!", err)
} }
fmt.Println("Event NID:", eventNID) fmt.Println("INSERT event ID", eventID, "state snapshot NID:", stateNID, "event NID:", eventNID)
fmt.Println("State snapshot NID:", stateNID)
return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err
} }
@ -192,7 +189,7 @@ func (s *eventStatements) bulkSelectStateEventByID(
iEventIDs[k] = v iEventIDs[k] = v
} }
selectOrig := strings.Replace(bulkSelectStateEventByIDSQL, "($1)", queryVariadic(len(iEventIDs)), 1) selectOrig := strings.Replace(bulkSelectStateEventByIDSQL, "($1)", queryVariadic(len(iEventIDs)), 1)
selectPrep, err := s.db.Prepare(selectOrig) selectPrep, err := txn.Prepare(selectOrig)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -245,7 +242,7 @@ func (s *eventStatements) bulkSelectStateAtEventByID(
iEventIDs[k] = v iEventIDs[k] = v
} }
selectOrig := strings.Replace(bulkSelectStateAtEventByIDSQL, "($1)", queryVariadic(len(iEventIDs)), 1) selectOrig := strings.Replace(bulkSelectStateAtEventByIDSQL, "($1)", queryVariadic(len(iEventIDs)), 1)
selectPrep, err := s.db.Prepare(selectOrig) selectPrep, err := txn.Prepare(selectOrig)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -287,10 +284,14 @@ func (s *eventStatements) updateEventState(
ctx context.Context, txn *sql.Tx, eventNID types.EventNID, stateNID types.StateSnapshotNID, ctx context.Context, txn *sql.Tx, eventNID types.EventNID, stateNID types.StateSnapshotNID,
) error { ) error {
updateStmt := common.TxStmt(txn, s.updateEventStateStmt) updateStmt := common.TxStmt(txn, s.updateEventStateStmt)
_, err := updateStmt.ExecContext(ctx, int64(eventNID), int64(stateNID)) fmt.Println("=====================================")
fmt.Println(updateEventStateSQL, stateNID, eventNID)
res, err := updateStmt.ExecContext(ctx, int64(stateNID), int64(eventNID))
if err != nil { if err != nil {
fmt.Println("updateEventState s.updateEventStateStmt.ExecContext:", err) fmt.Println("updateEventState s.updateEventStateStmt.ExecContext:", err)
} }
a, b := res.RowsAffected()
fmt.Println("Rows affected:", a, b)
return err return err
} }
@ -336,14 +337,9 @@ func (s *eventStatements) bulkSelectStateAtEventAndReference(
iEventNIDs[k] = v iEventNIDs[k] = v
} }
selectOrig := strings.Replace(bulkSelectStateAtEventAndReferenceSQL, "($1)", queryVariadic(len(iEventNIDs)), 1) selectOrig := strings.Replace(bulkSelectStateAtEventAndReferenceSQL, "($1)", queryVariadic(len(iEventNIDs)), 1)
selectPrep, err := s.db.Prepare(selectOrig) //////////////
if err != nil {
return nil, err
}
///////////////
selectStmt := common.TxStmt(txn, selectPrep) rows, err := txn.QueryContext(ctx, selectOrig, iEventNIDs...)
rows, err := selectStmt.QueryContext(ctx, iEventNIDs...)
if err != nil { if err != nil {
fmt.Println("bulkSelectStateAtEventAndREference stmt.QueryContext:", err) fmt.Println("bulkSelectStateAtEventAndREference stmt.QueryContext:", err)
return nil, err return nil, err
@ -389,7 +385,7 @@ func (s *eventStatements) bulkSelectEventReference(
iEventNIDs[k] = v iEventNIDs[k] = v
} }
selectOrig := strings.Replace(bulkSelectEventReferenceSQL, "($1)", queryVariadic(len(iEventNIDs)), 1) selectOrig := strings.Replace(bulkSelectEventReferenceSQL, "($1)", queryVariadic(len(iEventNIDs)), 1)
selectPrep, err := s.db.Prepare(selectOrig) selectPrep, err := txn.Prepare(selectOrig)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -425,7 +421,7 @@ func (s *eventStatements) bulkSelectEventID(ctx context.Context, txn *sql.Tx, ev
iEventNIDs[k] = v iEventNIDs[k] = v
} }
selectOrig := strings.Replace(bulkSelectEventIDSQL, "($1)", queryVariadic(len(iEventNIDs)), 1) selectOrig := strings.Replace(bulkSelectEventIDSQL, "($1)", queryVariadic(len(iEventNIDs)), 1)
selectPrep, err := s.db.Prepare(selectOrig) selectPrep, err := txn.Prepare(selectOrig)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -464,7 +460,7 @@ func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, e
iEventIDs[k] = v iEventIDs[k] = v
} }
selectOrig := strings.Replace(bulkSelectEventNIDSQL, "($1)", queryVariadic(len(iEventIDs)), 1) selectOrig := strings.Replace(bulkSelectEventNIDSQL, "($1)", queryVariadic(len(iEventIDs)), 1)
selectPrep, err := s.db.Prepare(selectOrig) selectPrep, err := txn.Prepare(selectOrig)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -22,7 +22,6 @@ import (
"strings" "strings"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
) )
@ -36,13 +35,7 @@ const stateSnapshotSchema = `
const insertStateSQL = ` const insertStateSQL = `
INSERT INTO roomserver_state_snapshots (room_nid, state_block_nids) INSERT INTO roomserver_state_snapshots (room_nid, state_block_nids)
VALUES ($1, $2); VALUES ($1, $2);`
`
const insertStateResultSQL = `
SELECT state_snapshot_nid FROM roomserver_state_snapshots
WHERE rowid = last_insert_rowid();
`
// Bulk state data NID lookup. // Bulk state data NID lookup.
// Sorting by state_snapshot_nid means we can use binary search over the result // Sorting by state_snapshot_nid means we can use binary search over the result
@ -67,7 +60,6 @@ func (s *stateSnapshotStatements) prepare(db *sql.DB) (err error) {
return statementList{ return statementList{
{&s.insertStateStmt, insertStateSQL}, {&s.insertStateStmt, insertStateSQL},
{&s.insertStateResultStmt, insertStateResultSQL},
{&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL}, {&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL},
}.prepare(db) }.prepare(db)
} }
@ -79,15 +71,17 @@ func (s *stateSnapshotStatements) insertState(
for i := range stateBlockNIDs { for i := range stateBlockNIDs {
nids[i] = int64(stateBlockNIDs[i]) nids[i] = int64(stateBlockNIDs[i])
} }
insertStmt := common.TxStmt(txn, s.insertStateStmt) insertStmt := txn.Stmt(s.insertStateStmt)
resultStmt := common.TxStmt(txn, s.insertStateResultStmt) //resultStmt := txn.Stmt(s.insertStateResultStmt)
if _, err = insertStmt.ExecContext(ctx, int64(roomNID), pq.Int64Array(nids)); err == nil { fmt.Println(insertStateSQL, roomNID, nids)
err = resultStmt.QueryRowContext(ctx).Scan(&stateNID) if res, err2 := insertStmt.ExecContext(ctx, int64(roomNID), pq.Int64Array(nids)); err2 == nil {
if err != nil { lastRowID, err3 := res.LastInsertId()
fmt.Println("insertState s.insertStateResultStmt.QueryRowContext:", err) if err3 != nil {
err = err3
} }
stateNID = types.StateSnapshotNID(lastRowID)
} else { } else {
fmt.Println("insertState s.insertStateStmt.ExecContext:", err) fmt.Println("insertState s.insertStateStmt.ExecContext:", err2)
} }
return return
} }
@ -101,7 +95,7 @@ func (s *stateSnapshotStatements) bulkSelectStateBlockNIDs(
nids[k] = v nids[k] = v
} }
selectOrig := strings.Replace(bulkSelectStateBlockNIDsSQL, "($1)", queryVariadic(len(nids)), 1) selectOrig := strings.Replace(bulkSelectStateBlockNIDsSQL, "($1)", queryVariadic(len(nids)), 1)
selectPrep, err := s.db.Prepare(selectOrig) selectStmt, err := txn.Prepare(selectOrig)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -112,7 +106,6 @@ func (s *stateSnapshotStatements) bulkSelectStateBlockNIDs(
nids[i] = int64(stateNIDs[i]) nids[i] = int64(stateNIDs[i])
} }
*/ */
selectStmt := common.TxStmt(txn, selectPrep)
rows, err := selectStmt.QueryContext(ctx, nids...) rows, err := selectStmt.QueryContext(ctx, nids...)
if err != nil { if err != nil {
fmt.Println("bulkSelectStateBlockNIDs s.bulkSelectStateBlockNIDsStmt.QueryContext:", err) fmt.Println("bulkSelectStateBlockNIDs s.bulkSelectStateBlockNIDsStmt.QueryContext:", err)

View file

@ -21,6 +21,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"net/url" "net/url"
"time"
"github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
@ -54,8 +55,8 @@ func Open(dataSourceName string) (*Database, error) {
return nil, err return nil, err
} }
//d.db.Exec("PRAGMA journal_mode=WAL;") //d.db.Exec("PRAGMA journal_mode=WAL;")
//d.db.Exec("PRAGMA parser_trace = true;") //d.db.Exec("PRAGMA read_uncommitted = true;")
d.db.SetMaxOpenConns(1) d.db.SetMaxOpenConns(2)
if err = d.statements.prepare(d.db); err != nil { if err = d.statements.prepare(d.db); err != nil {
return nil, err return nil, err
} }
@ -196,36 +197,56 @@ func (d *Database) assignStateKeyNID(
// StateEntriesForEventIDs implements input.EventDatabase // StateEntriesForEventIDs implements input.EventDatabase
func (d *Database) StateEntriesForEventIDs( func (d *Database) StateEntriesForEventIDs(
ctx context.Context, eventIDs []string, ctx context.Context, eventIDs []string,
) ([]types.StateEntry, error) { ) (se []types.StateEntry, err error) {
return d.statements.bulkSelectStateEventByID(ctx, nil, eventIDs) err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
se, err = d.statements.bulkSelectStateEventByID(ctx, txn, eventIDs)
return err
})
return
} }
// EventTypeNIDs implements state.RoomStateDatabase // EventTypeNIDs implements state.RoomStateDatabase
func (d *Database) EventTypeNIDs( func (d *Database) EventTypeNIDs(
ctx context.Context, eventTypes []string, ctx context.Context, eventTypes []string,
) (map[string]types.EventTypeNID, error) { ) (etnids map[string]types.EventTypeNID, err error) {
return d.statements.bulkSelectEventTypeNID(ctx, nil, eventTypes) err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
etnids, err = d.statements.bulkSelectEventTypeNID(ctx, txn, eventTypes)
return err
})
return
} }
// EventStateKeyNIDs implements state.RoomStateDatabase // EventStateKeyNIDs implements state.RoomStateDatabase
func (d *Database) EventStateKeyNIDs( func (d *Database) EventStateKeyNIDs(
ctx context.Context, eventStateKeys []string, ctx context.Context, eventStateKeys []string,
) (map[string]types.EventStateKeyNID, error) { ) (esknids map[string]types.EventStateKeyNID, err error) {
return d.statements.bulkSelectEventStateKeyNID(ctx, nil, eventStateKeys) err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
esknids, err = d.statements.bulkSelectEventStateKeyNID(ctx, txn, eventStateKeys)
return err
})
return
} }
// EventStateKeys implements query.RoomserverQueryAPIDatabase // EventStateKeys implements query.RoomserverQueryAPIDatabase
func (d *Database) EventStateKeys( func (d *Database) EventStateKeys(
ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID, ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID,
) (map[types.EventStateKeyNID]string, error) { ) (out map[types.EventStateKeyNID]string, err error) {
return d.statements.bulkSelectEventStateKey(ctx, nil, eventStateKeyNIDs) err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
out, err = d.statements.bulkSelectEventStateKey(ctx, txn, eventStateKeyNIDs)
return err
})
return
} }
// EventNIDs implements query.RoomserverQueryAPIDatabase // EventNIDs implements query.RoomserverQueryAPIDatabase
func (d *Database) EventNIDs( func (d *Database) EventNIDs(
ctx context.Context, eventIDs []string, ctx context.Context, eventIDs []string,
) (map[string]types.EventNID, error) { ) (out map[string]types.EventNID, err error) {
return d.statements.bulkSelectEventNID(ctx, nil, eventIDs) err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
out, err = d.statements.bulkSelectEventNID(ctx, txn, eventIDs)
return err
})
return
} }
// Events implements input.EventDatabase // Events implements input.EventDatabase
@ -235,12 +256,15 @@ func (d *Database) Events(
var eventJSONs []eventJSONPair var eventJSONs []eventJSONPair
var err error var err error
results := make([]types.Event, len(eventNIDs)) results := make([]types.Event, len(eventNIDs))
fmt.Println("pre txn")
common.WithTransaction(d.db, func(txn *sql.Tx) error { common.WithTransaction(d.db, func(txn *sql.Tx) error {
fmt.Println("in txn", txn)
eventJSONs, err = d.statements.bulkSelectEventJSON(ctx, txn, eventNIDs) eventJSONs, err = d.statements.bulkSelectEventJSON(ctx, txn, eventNIDs)
if err != nil || len(eventJSONs) == 0 { if err != nil || len(eventJSONs) == 0 {
fmt.Println("d.statements.bulkSelectEventJSON:", err) fmt.Println("d.statements.bulkSelectEventJSON:", err)
return nil return nil
} }
fmt.Println("selected txn")
for i, eventJSON := range eventJSONs { for i, eventJSON := range eventJSONs {
result := &results[i] result := &results[i]
result.EventNID = eventJSON.EventNID result.EventNID = eventJSON.EventNID
@ -252,6 +276,7 @@ func (d *Database) Events(
} }
return nil return nil
}) })
fmt.Println("post txn")
if err != nil { if err != nil {
return []types.Event{}, err return []types.Event{}, err
} }
@ -265,7 +290,9 @@ func (d *Database) AddState(
stateBlockNIDs []types.StateBlockNID, stateBlockNIDs []types.StateBlockNID,
state []types.StateEntry, state []types.StateEntry,
) (stateNID types.StateSnapshotNID, err error) { ) (stateNID types.StateSnapshotNID, err error) {
common.WithTransaction(d.db, func(txn *sql.Tx) error { fmt.Println("AddState INSERT STATE START", stateBlockNIDs)
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
fmt.Println("insert state txn created")
if len(state) > 0 { if len(state) > 0 {
stateBlockNID, err := d.statements.selectNextStateBlockNID(ctx, txn) stateBlockNID, err := d.statements.selectNextStateBlockNID(ctx, txn)
if err != nil { if err != nil {
@ -277,8 +304,10 @@ func (d *Database) AddState(
stateBlockNIDs = append(stateBlockNIDs[:len(stateBlockNIDs):len(stateBlockNIDs)], stateBlockNID) stateBlockNIDs = append(stateBlockNIDs[:len(stateBlockNIDs):len(stateBlockNIDs)], stateBlockNID)
} }
stateNID, err = d.statements.insertState(ctx, txn, roomNID, stateBlockNIDs) stateNID, err = d.statements.insertState(ctx, txn, roomNID, stateBlockNIDs)
return nil fmt.Println("AddState: completing txn", time.Now(), "err=", err)
return err
}) })
fmt.Println("AddState INSERT STATE END pkey=", stateNID, time.Now(), "err=", err)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -289,49 +318,77 @@ func (d *Database) AddState(
func (d *Database) SetState( func (d *Database) SetState(
ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID, ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID,
) error { ) error {
return d.statements.updateEventState(ctx, nil, eventNID, stateNID) fmt.Println("SetState event NID:", eventNID, "state NID:", stateNID)
e := common.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.statements.updateEventState(ctx, txn, eventNID, stateNID)
})
fmt.Println("SetState finish", e)
return e
} }
// StateAtEventIDs implements input.EventDatabase // StateAtEventIDs implements input.EventDatabase
func (d *Database) StateAtEventIDs( func (d *Database) StateAtEventIDs(
ctx context.Context, eventIDs []string, ctx context.Context, eventIDs []string,
) ([]types.StateAtEvent, error) { ) (se []types.StateAtEvent, err error) {
return d.statements.bulkSelectStateAtEventByID(ctx, nil, eventIDs) common.WithTransaction(d.db, func(txn *sql.Tx) error {
se, err = d.statements.bulkSelectStateAtEventByID(ctx, txn, eventIDs)
return err
})
return
} }
// StateBlockNIDs implements state.RoomStateDatabase // StateBlockNIDs implements state.RoomStateDatabase
func (d *Database) StateBlockNIDs( func (d *Database) StateBlockNIDs(
ctx context.Context, stateNIDs []types.StateSnapshotNID, ctx context.Context, stateNIDs []types.StateSnapshotNID,
) ([]types.StateBlockNIDList, error) { ) (sl []types.StateBlockNIDList, err error) {
return d.statements.bulkSelectStateBlockNIDs(ctx, nil, stateNIDs) fmt.Println("StateBlockNIDs SELECT STATE START", stateNIDs)
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
fmt.Println(" in txn")
sl, err = d.statements.bulkSelectStateBlockNIDs(ctx, txn, stateNIDs)
return err
})
fmt.Println("StateBlockNIDs SELECT STATE END", sl)
return
} }
// StateEntries implements state.RoomStateDatabase // StateEntries implements state.RoomStateDatabase
func (d *Database) StateEntries( func (d *Database) StateEntries(
ctx context.Context, stateBlockNIDs []types.StateBlockNID, ctx context.Context, stateBlockNIDs []types.StateBlockNID,
) ([]types.StateEntryList, error) { ) (sel []types.StateEntryList, err error) {
return d.statements.bulkSelectStateBlockEntries(ctx, nil, stateBlockNIDs) err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
sel, err = d.statements.bulkSelectStateBlockEntries(ctx, txn, stateBlockNIDs)
return err
})
return
} }
// SnapshotNIDFromEventID implements state.RoomStateDatabase // SnapshotNIDFromEventID implements state.RoomStateDatabase
func (d *Database) SnapshotNIDFromEventID( func (d *Database) SnapshotNIDFromEventID(
ctx context.Context, eventID string, ctx context.Context, eventID string,
) (stateNID types.StateSnapshotNID, err error) { ) (stateNID types.StateSnapshotNID, err error) {
_, stateNID, err = d.statements.selectEvent(ctx, nil, eventID) err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
_, stateNID, err = d.statements.selectEvent(ctx, txn, eventID)
return err
})
return return
} }
// EventIDs implements input.RoomEventDatabase // EventIDs implements input.RoomEventDatabase
func (d *Database) EventIDs( func (d *Database) EventIDs(
ctx context.Context, eventNIDs []types.EventNID, ctx context.Context, eventNIDs []types.EventNID,
) (map[types.EventNID]string, error) { ) (out map[types.EventNID]string, err error) {
return d.statements.bulkSelectEventID(ctx, nil, eventNIDs) err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
out, err = d.statements.bulkSelectEventID(ctx, txn, eventNIDs)
return err
})
return
} }
// GetLatestEventsForUpdate implements input.EventDatabase // GetLatestEventsForUpdate implements input.EventDatabase
func (d *Database) GetLatestEventsForUpdate( func (d *Database) GetLatestEventsForUpdate(
ctx context.Context, roomNID types.RoomNID, ctx context.Context, roomNID types.RoomNID,
) (types.RoomRecentEventsUpdater, error) { ) (types.RoomRecentEventsUpdater, error) {
fmt.Println("=============== GetLatestEventsForUpdate BEGIN TXN")
txn, err := d.db.Begin() txn, err := d.db.Begin()
if err != nil { if err != nil {
return nil, err return nil, err
@ -355,8 +412,15 @@ func (d *Database) GetLatestEventsForUpdate(
return nil, err return nil, err
} }
} }
fmt.Println("GetLatestEventsForUpdate returning updater")
// FIXME: we probably want to support long-lived txns in sqlite somehow, but we don't because we get
// 'database is locked' errors caused by multiple write txns (one being the long-lived txn created here)
// so for now let's not use a long-lived txn at all, and just commit it here and set the txn to nil so
// we fail fast if someone tries to use the underlying txn object.
txn.Commit()
return &roomRecentEventsUpdater{ return &roomRecentEventsUpdater{
transaction{ctx, txn}, d, roomNID, stateAndRefs, lastEventIDSent, currentStateSnapshotNID, transaction{ctx, nil}, d, roomNID, stateAndRefs, lastEventIDSent, currentStateSnapshotNID,
}, nil }, nil
} }
@ -398,24 +462,33 @@ func (u *roomRecentEventsUpdater) CurrentStateSnapshotNID() types.StateSnapshotN
// StorePreviousEvents implements types.RoomRecentEventsUpdater // StorePreviousEvents implements types.RoomRecentEventsUpdater
func (u *roomRecentEventsUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error { func (u *roomRecentEventsUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error {
for _, ref := range previousEventReferences { err := common.WithTransaction(u.d.db, func(txn *sql.Tx) error {
if err := u.d.statements.insertPreviousEvent(u.ctx, u.txn, ref.EventID, ref.EventSHA256, eventNID); err != nil { for _, ref := range previousEventReferences {
return err if err := u.d.statements.insertPreviousEvent(u.ctx, txn, ref.EventID, ref.EventSHA256, eventNID); err != nil {
return err
}
} }
} return nil
return nil })
return err
} }
// IsReferenced implements types.RoomRecentEventsUpdater // IsReferenced implements types.RoomRecentEventsUpdater
func (u *roomRecentEventsUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) { func (u *roomRecentEventsUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (res bool, err error) {
err := u.d.statements.selectPreviousEventExists(u.ctx, u.txn, eventReference.EventID, eventReference.EventSHA256) fmt.Println("[[TXN]] IsReferenced")
if err == nil { err = common.WithTransaction(u.d.db, func(txn *sql.Tx) error {
return true, nil err := u.d.statements.selectPreviousEventExists(u.ctx, txn, eventReference.EventID, eventReference.EventSHA256)
} if err == nil {
if err == sql.ErrNoRows { res = true
return false, nil err = nil
} }
return false, err if err == sql.ErrNoRows {
res = false
err = nil
}
return err
})
return
} }
// SetLatestEvents implements types.RoomRecentEventsUpdater // SetLatestEvents implements types.RoomRecentEventsUpdater
@ -423,38 +496,58 @@ func (u *roomRecentEventsUpdater) SetLatestEvents(
roomNID types.RoomNID, latest []types.StateAtEventAndReference, lastEventNIDSent types.EventNID, roomNID types.RoomNID, latest []types.StateAtEventAndReference, lastEventNIDSent types.EventNID,
currentStateSnapshotNID types.StateSnapshotNID, currentStateSnapshotNID types.StateSnapshotNID,
) error { ) error {
eventNIDs := make([]types.EventNID, len(latest)) err := common.WithTransaction(u.d.db, func(txn *sql.Tx) error {
for i := range latest { eventNIDs := make([]types.EventNID, len(latest))
eventNIDs[i] = latest[i].EventNID for i := range latest {
} eventNIDs[i] = latest[i].EventNID
// TODO: transaction was removed here - is this wise? }
return u.d.statements.updateLatestEventNIDs(u.ctx, nil, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID) return u.d.statements.updateLatestEventNIDs(u.ctx, txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID)
})
return err
} }
// HasEventBeenSent implements types.RoomRecentEventsUpdater // HasEventBeenSent implements types.RoomRecentEventsUpdater
func (u *roomRecentEventsUpdater) HasEventBeenSent(eventNID types.EventNID) (bool, error) { func (u *roomRecentEventsUpdater) HasEventBeenSent(eventNID types.EventNID) (res bool, err error) {
// TODO: transaction was removed here - is this wise? // TODO: transaction was removed here - is this wise?
return u.d.statements.selectEventSentToOutput(u.ctx, nil, eventNID) fmt.Println("[[TXN]] HasEventBeenSent")
err = common.WithTransaction(u.d.db, func(txn *sql.Tx) error {
res, err = u.d.statements.selectEventSentToOutput(u.ctx, txn, eventNID)
return err
})
return
} }
// MarkEventAsSent implements types.RoomRecentEventsUpdater // MarkEventAsSent implements types.RoomRecentEventsUpdater
func (u *roomRecentEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error { func (u *roomRecentEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error {
// TODO: transaction was removed here - is this wise? // TODO: transaction was removed here - is this wise?
return u.d.statements.updateEventSentToOutput(u.ctx, nil, eventNID) fmt.Println("[[TXN]] updateEventSentToOutput")
err := common.WithTransaction(u.d.db, func(txn *sql.Tx) error {
return u.d.statements.updateEventSentToOutput(u.ctx, txn, eventNID)
})
return err
} }
func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID) (types.MembershipUpdater, error) { func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID) (mu types.MembershipUpdater, err error) {
// TODO: transaction was removed here - is this wise? // TODO: transaction was removed here - is this wise?
return u.d.membershipUpdaterTxn(u.ctx, nil, u.roomNID, targetUserNID) fmt.Println("[[TXN]] membershipUpdaterTxn")
err = common.WithTransaction(u.d.db, func(txn *sql.Tx) error {
mu, err = u.d.membershipUpdaterTxn(u.ctx, txn, u.roomNID, targetUserNID)
return err
})
return
} }
// RoomNID implements query.RoomserverQueryAPIDB // RoomNID implements query.RoomserverQueryAPIDB
func (d *Database) RoomNID(ctx context.Context, roomID string) (types.RoomNID, error) { func (d *Database) RoomNID(ctx context.Context, roomID string) (roomNID types.RoomNID, err error) {
roomNID, err := d.statements.selectRoomNID(ctx, nil, roomID) err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
if err == sql.ErrNoRows { roomNID, err = d.statements.selectRoomNID(ctx, txn, roomID)
return 0, nil if err == sql.ErrNoRows {
} roomNID = 0
return roomNID, err err = nil
}
return err
})
return
} }
// LatestEventIDs implements query.RoomserverQueryAPIDatabase // LatestEventIDs implements query.RoomserverQueryAPIDatabase
@ -532,12 +625,14 @@ func (d *Database) MembershipUpdater(
ctx context.Context, roomID, targetUserID string, ctx context.Context, roomID, targetUserID string,
) (types.MembershipUpdater, error) { ) (types.MembershipUpdater, error) {
txn, err := d.db.Begin() txn, err := d.db.Begin()
fmt.Println("=== UPDATER TXN START ====")
if err != nil { if err != nil {
return nil, err return nil, err
} }
succeeded := false succeeded := false
defer func() { defer func() {
if !succeeded { if !succeeded {
fmt.Println("=== UPDATER TXN ROLLBACK ====")
txn.Rollback() // nolint: errcheck txn.Rollback() // nolint: errcheck
} }
}() }()
@ -606,92 +701,99 @@ func (u *membershipUpdater) IsLeave() bool {
} }
// SetToInvite implements types.MembershipUpdater // SetToInvite implements types.MembershipUpdater
func (u *membershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, error) { func (u *membershipUpdater) SetToInvite(event gomatrixserverlib.Event) (inserted bool, err error) {
senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, event.Sender()) err = common.WithTransaction(u.d.db, func(txn *sql.Tx) error {
if err != nil { senderUserNID, err := u.d.assignStateKeyNID(u.ctx, txn, event.Sender())
return false, err if err != nil {
} return err
inserted, err := u.d.statements.insertInviteEvent(
u.ctx, u.txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(),
)
if err != nil {
return false, err
}
if u.membership != membershipStateInvite {
if err = u.d.statements.updateMembership(
u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, membershipStateInvite, 0,
); err != nil {
return false, err
} }
} inserted, err = u.d.statements.insertInviteEvent(
return inserted, nil u.ctx, txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(),
)
if err != nil {
return err
}
if u.membership != membershipStateInvite {
if err = u.d.statements.updateMembership(
u.ctx, txn, u.roomNID, u.targetUserNID, senderUserNID, membershipStateInvite, 0,
); err != nil {
return err
}
}
return nil
})
return
} }
// SetToJoin implements types.MembershipUpdater // SetToJoin implements types.MembershipUpdater
func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpdate bool) ([]string, error) { func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpdate bool) (inviteEventIDs []string, err error) {
var inviteEventIDs []string err = common.WithTransaction(u.d.db, func(txn *sql.Tx) error {
senderUserNID, err := u.d.assignStateKeyNID(u.ctx, txn, senderUserID)
senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID)
if err != nil {
return nil, err
}
// If this is a join event update, there is no invite to update
if !isUpdate {
inviteEventIDs, err = u.d.statements.updateInviteRetired(
u.ctx, u.txn, u.roomNID, u.targetUserNID,
)
if err != nil { if err != nil {
return nil, err return err
} }
}
// Look up the NID of the new join event // If this is a join event update, there is no invite to update
nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID}) if !isUpdate {
if err != nil { inviteEventIDs, err = u.d.statements.updateInviteRetired(
return nil, err u.ctx, txn, u.roomNID, u.targetUserNID,
} )
if err != nil {
if u.membership != membershipStateJoin || isUpdate { return err
if err = u.d.statements.updateMembership( }
u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID,
membershipStateJoin, nIDs[eventID],
); err != nil {
return nil, err
} }
}
return inviteEventIDs, nil // Look up the NID of the new join event
nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID})
if err != nil {
return err
}
if u.membership != membershipStateJoin || isUpdate {
if err = u.d.statements.updateMembership(
u.ctx, txn, u.roomNID, u.targetUserNID, senderUserNID,
membershipStateJoin, nIDs[eventID],
); err != nil {
return err
}
}
return nil
})
return
} }
// SetToLeave implements types.MembershipUpdater // SetToLeave implements types.MembershipUpdater
func (u *membershipUpdater) SetToLeave(senderUserID string, eventID string) ([]string, error) { func (u *membershipUpdater) SetToLeave(senderUserID string, eventID string) (inviteEventIDs []string, err error) {
senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID) err = common.WithTransaction(u.d.db, func(txn *sql.Tx) error {
if err != nil { senderUserNID, err := u.d.assignStateKeyNID(u.ctx, txn, senderUserID)
return nil, err if err != nil {
} return err
inviteEventIDs, err := u.d.statements.updateInviteRetired(
u.ctx, u.txn, u.roomNID, u.targetUserNID,
)
if err != nil {
return nil, err
}
// Look up the NID of the new leave event
nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID})
if err != nil {
return nil, err
}
if u.membership != membershipStateLeaveOrBan {
if err = u.d.statements.updateMembership(
u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID,
membershipStateLeaveOrBan, nIDs[eventID],
); err != nil {
return nil, err
} }
} inviteEventIDs, err = u.d.statements.updateInviteRetired(
return inviteEventIDs, nil u.ctx, txn, u.roomNID, u.targetUserNID,
)
if err != nil {
return err
}
// Look up the NID of the new leave event
nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID})
if err != nil {
return err
}
if u.membership != membershipStateLeaveOrBan {
if err = u.d.statements.updateMembership(
u.ctx, txn, u.roomNID, u.targetUserNID, senderUserNID,
membershipStateLeaveOrBan, nIDs[eventID],
); err != nil {
return err
}
}
return nil
})
return
} }
// GetMembership implements query.RoomserverQueryAPIDB // GetMembership implements query.RoomserverQueryAPIDB
@ -762,10 +864,16 @@ type transaction struct {
// Commit implements types.Transaction // Commit implements types.Transaction
func (t *transaction) Commit() error { func (t *transaction) Commit() error {
if t.txn == nil {
return nil
}
return t.txn.Commit() return t.txn.Commit()
} }
// Rollback implements types.Transaction // Rollback implements types.Transaction
func (t *transaction) Rollback() error { func (t *transaction) Rollback() error {
if t.txn == nil {
return nil
}
return t.txn.Rollback() return t.txn.Rollback()
} }