mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-16 11:23:11 -06:00
sqlite work
This commit is contained in:
parent
0360f07110
commit
fb9243b4ed
|
|
@ -30,11 +30,13 @@ type Transaction interface {
|
||||||
|
|
||||||
// EndTransaction ends a transaction.
|
// EndTransaction ends a transaction.
|
||||||
// If the transaction succeeded then it is committed, otherwise it is rolledback.
|
// If the transaction succeeded then it is committed, otherwise it is rolledback.
|
||||||
func EndTransaction(txn Transaction, succeeded *bool) {
|
// You MUST check the error returned from this function to be sure that the transaction
|
||||||
|
// was applied correctly. For example, 'database is locked' errors in sqlite will happen here.
|
||||||
|
func EndTransaction(txn Transaction, succeeded *bool) error {
|
||||||
if *succeeded {
|
if *succeeded {
|
||||||
txn.Commit() // nolint: errcheck
|
return txn.Commit() // nolint: errcheck
|
||||||
} else {
|
} else {
|
||||||
txn.Rollback() // nolint: errcheck
|
return txn.Rollback() // nolint: errcheck
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -47,7 +49,12 @@ func WithTransaction(db *sql.DB, fn func(txn *sql.Tx) error) (err error) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
succeeded := false
|
succeeded := false
|
||||||
defer EndTransaction(txn, &succeeded)
|
defer func() {
|
||||||
|
err2 := EndTransaction(txn, &succeeded)
|
||||||
|
if err == nil && err2 != nil { // failed to commit/rollback
|
||||||
|
err = err2
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
err = fn(txn)
|
err = fn(txn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,21 @@
|
||||||
version: "3.4"
|
version: "3.4"
|
||||||
services:
|
services:
|
||||||
|
riot:
|
||||||
|
image: vectorim/riot-web
|
||||||
|
networks:
|
||||||
|
- internal
|
||||||
|
ports:
|
||||||
|
- "8500:80"
|
||||||
|
|
||||||
monolith:
|
monolith:
|
||||||
container_name: dendrite_monolith
|
container_name: dendrite_monolith
|
||||||
hostname: monolith
|
hostname: monolith
|
||||||
entrypoint: ["bash", "./docker/services/monolith.sh"]
|
entrypoint: ["bash", "./docker/services/monolith.sh", "--config", "/etc/dendrite/dendrite.yaml"]
|
||||||
build: ./
|
build: ./
|
||||||
volumes:
|
volumes:
|
||||||
- ..:/build
|
- ..:/build
|
||||||
- ./build/bin:/build/bin
|
- ./build/bin:/build/bin
|
||||||
|
- ../cfg:/etc/dendrite
|
||||||
networks:
|
networks:
|
||||||
- internal
|
- internal
|
||||||
depends_on:
|
depends_on:
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,6 @@ package sqlite3
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/common"
|
"github.com/matrix-org/dendrite/common"
|
||||||
|
|
@ -79,21 +78,14 @@ type eventJSONPair struct {
|
||||||
func (s *eventJSONStatements) bulkSelectEventJSON(
|
func (s *eventJSONStatements) bulkSelectEventJSON(
|
||||||
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
|
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
|
||||||
) ([]eventJSONPair, error) {
|
) ([]eventJSONPair, error) {
|
||||||
///////////////
|
|
||||||
iEventNIDs := make([]interface{}, len(eventNIDs))
|
iEventNIDs := make([]interface{}, len(eventNIDs))
|
||||||
for k, v := range eventNIDs {
|
for k, v := range eventNIDs {
|
||||||
iEventNIDs[k] = v
|
iEventNIDs[k] = v
|
||||||
}
|
}
|
||||||
selectOrig := strings.Replace(bulkSelectEventJSONSQL, "($1)", queryVariadic(len(iEventNIDs)), 1)
|
selectOrig := strings.Replace(bulkSelectEventJSONSQL, "($1)", queryVariadic(len(iEventNIDs)), 1)
|
||||||
selectPrep, err := s.db.Prepare(selectOrig)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
///////////////
|
|
||||||
|
|
||||||
rows, err := common.TxStmt(txn, selectPrep).QueryContext(ctx, iEventNIDs...)
|
rows, err := txn.QueryContext(ctx, selectOrig, iEventNIDs...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println("bulkSelectEventJSON s.bulkSelectEventJSONStmt.QueryContext:", err)
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer rows.Close() // nolint: errcheck
|
defer rows.Close() // nolint: errcheck
|
||||||
|
|
@ -108,7 +100,6 @@ func (s *eventJSONStatements) bulkSelectEventJSON(
|
||||||
result := &results[i]
|
result := &results[i]
|
||||||
var eventNID int64
|
var eventNID int64
|
||||||
if err := rows.Scan(&eventNID, &result.EventJSON); err != nil {
|
if err := rows.Scan(&eventNID, &result.EventJSON); err != nil {
|
||||||
fmt.Println("bulkSelectEventJSON rows.Scan:", err)
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
result.EventNID = types.EventNID(eventNID)
|
result.EventNID = types.EventNID(eventNID)
|
||||||
|
|
|
||||||
|
|
@ -21,7 +21,6 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/common"
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -41,11 +40,6 @@ const insertEventStateKeyNIDSQL = `
|
||||||
ON CONFLICT DO NOTHING;
|
ON CONFLICT DO NOTHING;
|
||||||
`
|
`
|
||||||
|
|
||||||
const insertEventStateKeyNIDResultSQL = `
|
|
||||||
SELECT event_state_key_nid FROM roomserver_event_state_keys
|
|
||||||
WHERE rowid = last_insert_rowid();
|
|
||||||
`
|
|
||||||
|
|
||||||
const selectEventStateKeyNIDSQL = `
|
const selectEventStateKeyNIDSQL = `
|
||||||
SELECT event_state_key_nid FROM roomserver_event_state_keys
|
SELECT event_state_key_nid FROM roomserver_event_state_keys
|
||||||
WHERE event_state_key = $1
|
WHERE event_state_key = $1
|
||||||
|
|
@ -66,12 +60,11 @@ const bulkSelectEventStateKeySQL = `
|
||||||
`
|
`
|
||||||
|
|
||||||
type eventStateKeyStatements struct {
|
type eventStateKeyStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
insertEventStateKeyNIDStmt *sql.Stmt
|
insertEventStateKeyNIDStmt *sql.Stmt
|
||||||
insertEventStateKeyNIDResultStmt *sql.Stmt
|
selectEventStateKeyNIDStmt *sql.Stmt
|
||||||
selectEventStateKeyNIDStmt *sql.Stmt
|
bulkSelectEventStateKeyNIDStmt *sql.Stmt
|
||||||
bulkSelectEventStateKeyNIDStmt *sql.Stmt
|
bulkSelectEventStateKeyStmt *sql.Stmt
|
||||||
bulkSelectEventStateKeyStmt *sql.Stmt
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *eventStateKeyStatements) prepare(db *sql.DB) (err error) {
|
func (s *eventStateKeyStatements) prepare(db *sql.DB) (err error) {
|
||||||
|
|
@ -82,7 +75,6 @@ func (s *eventStateKeyStatements) prepare(db *sql.DB) (err error) {
|
||||||
}
|
}
|
||||||
return statementList{
|
return statementList{
|
||||||
{&s.insertEventStateKeyNIDStmt, insertEventStateKeyNIDSQL},
|
{&s.insertEventStateKeyNIDStmt, insertEventStateKeyNIDSQL},
|
||||||
{&s.insertEventStateKeyNIDResultStmt, insertEventStateKeyNIDResultSQL},
|
|
||||||
{&s.selectEventStateKeyNIDStmt, selectEventStateKeyNIDSQL},
|
{&s.selectEventStateKeyNIDStmt, selectEventStateKeyNIDSQL},
|
||||||
{&s.bulkSelectEventStateKeyNIDStmt, bulkSelectEventStateKeyNIDSQL},
|
{&s.bulkSelectEventStateKeyNIDStmt, bulkSelectEventStateKeyNIDSQL},
|
||||||
{&s.bulkSelectEventStateKeyStmt, bulkSelectEventStateKeySQL},
|
{&s.bulkSelectEventStateKeyStmt, bulkSelectEventStateKeySQL},
|
||||||
|
|
@ -94,13 +86,10 @@ func (s *eventStateKeyStatements) insertEventStateKeyNID(
|
||||||
) (types.EventStateKeyNID, error) {
|
) (types.EventStateKeyNID, error) {
|
||||||
var eventStateKeyNID int64
|
var eventStateKeyNID int64
|
||||||
var err error
|
var err error
|
||||||
insertStmt := common.TxStmt(txn, s.insertEventStateKeyNIDStmt)
|
var res sql.Result
|
||||||
selectStmt := common.TxStmt(txn, s.insertEventStateKeyNIDResultStmt)
|
insertStmt := txn.Stmt(s.insertEventStateKeyNIDStmt)
|
||||||
if _, err = insertStmt.ExecContext(ctx, eventStateKey); err == nil {
|
if res, err = insertStmt.ExecContext(ctx, eventStateKey); err == nil {
|
||||||
err = selectStmt.QueryRowContext(ctx).Scan(&eventStateKeyNID)
|
eventStateKeyNID, err = res.LastInsertId()
|
||||||
if err != nil {
|
|
||||||
fmt.Println("insertEventStateKeyNID selectStmt.QueryRowContext:", err)
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
fmt.Println("insertEventStateKeyNID insertStmt.ExecContext:", err)
|
fmt.Println("insertEventStateKeyNID insertStmt.ExecContext:", err)
|
||||||
}
|
}
|
||||||
|
|
@ -111,36 +100,22 @@ func (s *eventStateKeyStatements) selectEventStateKeyNID(
|
||||||
ctx context.Context, txn *sql.Tx, eventStateKey string,
|
ctx context.Context, txn *sql.Tx, eventStateKey string,
|
||||||
) (types.EventStateKeyNID, error) {
|
) (types.EventStateKeyNID, error) {
|
||||||
var eventStateKeyNID int64
|
var eventStateKeyNID int64
|
||||||
stmt := common.TxStmt(txn, s.selectEventStateKeyNIDStmt)
|
stmt := txn.Stmt(s.selectEventStateKeyNIDStmt)
|
||||||
fmt.Println("selectEventStateKeyNID for", eventStateKey)
|
|
||||||
err := stmt.QueryRowContext(ctx, eventStateKey).Scan(&eventStateKeyNID)
|
err := stmt.QueryRowContext(ctx, eventStateKey).Scan(&eventStateKeyNID)
|
||||||
if err != nil {
|
|
||||||
fmt.Println("selectEventStateKeyNID stmt.QueryRowContext:", err)
|
|
||||||
}
|
|
||||||
fmt.Println("selectEventStateKeyNID returns", eventStateKeyNID)
|
|
||||||
return types.EventStateKeyNID(eventStateKeyNID), err
|
return types.EventStateKeyNID(eventStateKeyNID), err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID(
|
func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID(
|
||||||
ctx context.Context, txn *sql.Tx, eventStateKeys []string,
|
ctx context.Context, txn *sql.Tx, eventStateKeys []string,
|
||||||
) (map[string]types.EventStateKeyNID, error) {
|
) (map[string]types.EventStateKeyNID, error) {
|
||||||
///////////////
|
|
||||||
iEventStateKeys := make([]interface{}, len(eventStateKeys))
|
iEventStateKeys := make([]interface{}, len(eventStateKeys))
|
||||||
for k, v := range eventStateKeys {
|
for k, v := range eventStateKeys {
|
||||||
iEventStateKeys[k] = v
|
iEventStateKeys[k] = v
|
||||||
}
|
}
|
||||||
selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", queryVariadic(len(eventStateKeys)), 1)
|
selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", queryVariadic(len(eventStateKeys)), 1)
|
||||||
selectPrep, err := s.db.Prepare(selectOrig)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
///////////////
|
|
||||||
|
|
||||||
rows, err := common.TxStmt(txn, selectPrep).QueryContext(
|
rows, err := txn.QueryContext(ctx, selectOrig, iEventStateKeys...)
|
||||||
ctx, iEventStateKeys...,
|
|
||||||
)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println("bulkSelectEventStateKeyNID s.bulkSelectEventStateKeyNIDStmt.QueryContext:", err)
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer rows.Close() // nolint: errcheck
|
defer rows.Close() // nolint: errcheck
|
||||||
|
|
@ -149,7 +124,6 @@ func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID(
|
||||||
var stateKey string
|
var stateKey string
|
||||||
var stateKeyNID int64
|
var stateKeyNID int64
|
||||||
if err := rows.Scan(&stateKey, &stateKeyNID); err != nil {
|
if err := rows.Scan(&stateKey, &stateKeyNID); err != nil {
|
||||||
fmt.Println("bulkSelectEventStateKeyNID rows.Scan:", err)
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
result[stateKey] = types.EventStateKeyNID(stateKeyNID)
|
result[stateKey] = types.EventStateKeyNID(stateKeyNID)
|
||||||
|
|
@ -160,25 +134,14 @@ func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID(
|
||||||
func (s *eventStateKeyStatements) bulkSelectEventStateKey(
|
func (s *eventStateKeyStatements) bulkSelectEventStateKey(
|
||||||
ctx context.Context, txn *sql.Tx, eventStateKeyNIDs []types.EventStateKeyNID,
|
ctx context.Context, txn *sql.Tx, eventStateKeyNIDs []types.EventStateKeyNID,
|
||||||
) (map[types.EventStateKeyNID]string, error) {
|
) (map[types.EventStateKeyNID]string, error) {
|
||||||
///////////////
|
|
||||||
iEventStateKeyNIDs := make([]interface{}, len(eventStateKeyNIDs))
|
iEventStateKeyNIDs := make([]interface{}, len(eventStateKeyNIDs))
|
||||||
for k, v := range eventStateKeyNIDs {
|
for k, v := range eventStateKeyNIDs {
|
||||||
iEventStateKeyNIDs[k] = v
|
iEventStateKeyNIDs[k] = v
|
||||||
}
|
}
|
||||||
selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", queryVariadic(len(eventStateKeyNIDs)), 1)
|
selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", queryVariadic(len(eventStateKeyNIDs)), 1)
|
||||||
selectPrep, err := s.db.Prepare(selectOrig)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
///////////////
|
|
||||||
|
|
||||||
/*nIDs := make(pq.Int64Array, len(eventStateKeyNIDs))
|
rows, err := txn.QueryContext(ctx, selectOrig, iEventStateKeyNIDs...)
|
||||||
for i := range eventStateKeyNIDs {
|
|
||||||
nIDs[i] = int64(eventStateKeyNIDs[i])
|
|
||||||
}*/
|
|
||||||
rows, err := common.TxStmt(txn, selectPrep).QueryContext(ctx, iEventStateKeyNIDs...)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println("bulkSelectEventStateKey s.bulkSelectEventStateKeyStmt.QueryContext:", err)
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer rows.Close() // nolint: errcheck
|
defer rows.Close() // nolint: errcheck
|
||||||
|
|
@ -187,7 +150,6 @@ func (s *eventStateKeyStatements) bulkSelectEventStateKey(
|
||||||
var stateKey string
|
var stateKey string
|
||||||
var stateKeyNID int64
|
var stateKeyNID int64
|
||||||
if err := rows.Scan(&stateKey, &stateKeyNID); err != nil {
|
if err := rows.Scan(&stateKey, &stateKeyNID); err != nil {
|
||||||
fmt.Println("bulkSelectEventStateKey rows.Scan:", err)
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
result[types.EventStateKeyNID(stateKeyNID)] = stateKey
|
result[types.EventStateKeyNID(stateKeyNID)] = stateKey
|
||||||
|
|
|
||||||
|
|
@ -69,7 +69,7 @@ const bulkSelectStateAtEventByIDSQL = "" +
|
||||||
" WHERE event_id IN ($1)"
|
" WHERE event_id IN ($1)"
|
||||||
|
|
||||||
const updateEventStateSQL = "" +
|
const updateEventStateSQL = "" +
|
||||||
"UPDATE roomserver_events SET state_snapshot_nid = $2 WHERE event_nid = $1"
|
"UPDATE roomserver_events SET state_snapshot_nid = $1 WHERE event_nid = $2"
|
||||||
|
|
||||||
const selectEventSentToOutputSQL = "" +
|
const selectEventSentToOutputSQL = "" +
|
||||||
"SELECT sent_to_output FROM roomserver_events WHERE event_nid = $1"
|
"SELECT sent_to_output FROM roomserver_events WHERE event_nid = $1"
|
||||||
|
|
@ -153,21 +153,18 @@ func (s *eventStatements) insertEvent(
|
||||||
var eventNID int64
|
var eventNID int64
|
||||||
var stateNID int64
|
var stateNID int64
|
||||||
var err error
|
var err error
|
||||||
|
var res sql.Result
|
||||||
insertStmt := common.TxStmt(txn, s.insertEventStmt)
|
insertStmt := common.TxStmt(txn, s.insertEventStmt)
|
||||||
resultStmt := common.TxStmt(txn, s.insertEventResultStmt)
|
resultStmt := common.TxStmt(txn, s.insertEventResultStmt)
|
||||||
if _, err = insertStmt.ExecContext(
|
if res, err = insertStmt.ExecContext(
|
||||||
ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID),
|
ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID),
|
||||||
eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth,
|
eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth,
|
||||||
); err == nil {
|
); err == nil {
|
||||||
|
a, b := res.LastInsertId()
|
||||||
|
fmt.Println("LastInsertId", a, b)
|
||||||
err = resultStmt.QueryRowContext(ctx).Scan(&eventNID, &stateNID)
|
err = resultStmt.QueryRowContext(ctx).Scan(&eventNID, &stateNID)
|
||||||
if err != nil {
|
|
||||||
fmt.Println("insertEvent HAS FAILED!", err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
fmt.Println("insertEvent HAS GONE WRONG!", err)
|
|
||||||
}
|
}
|
||||||
fmt.Println("Event NID:", eventNID)
|
fmt.Println("INSERT event ID", eventID, "state snapshot NID:", stateNID, "event NID:", eventNID)
|
||||||
fmt.Println("State snapshot NID:", stateNID)
|
|
||||||
return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err
|
return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -192,7 +189,7 @@ func (s *eventStatements) bulkSelectStateEventByID(
|
||||||
iEventIDs[k] = v
|
iEventIDs[k] = v
|
||||||
}
|
}
|
||||||
selectOrig := strings.Replace(bulkSelectStateEventByIDSQL, "($1)", queryVariadic(len(iEventIDs)), 1)
|
selectOrig := strings.Replace(bulkSelectStateEventByIDSQL, "($1)", queryVariadic(len(iEventIDs)), 1)
|
||||||
selectPrep, err := s.db.Prepare(selectOrig)
|
selectPrep, err := txn.Prepare(selectOrig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
@ -245,7 +242,7 @@ func (s *eventStatements) bulkSelectStateAtEventByID(
|
||||||
iEventIDs[k] = v
|
iEventIDs[k] = v
|
||||||
}
|
}
|
||||||
selectOrig := strings.Replace(bulkSelectStateAtEventByIDSQL, "($1)", queryVariadic(len(iEventIDs)), 1)
|
selectOrig := strings.Replace(bulkSelectStateAtEventByIDSQL, "($1)", queryVariadic(len(iEventIDs)), 1)
|
||||||
selectPrep, err := s.db.Prepare(selectOrig)
|
selectPrep, err := txn.Prepare(selectOrig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
@ -287,10 +284,14 @@ func (s *eventStatements) updateEventState(
|
||||||
ctx context.Context, txn *sql.Tx, eventNID types.EventNID, stateNID types.StateSnapshotNID,
|
ctx context.Context, txn *sql.Tx, eventNID types.EventNID, stateNID types.StateSnapshotNID,
|
||||||
) error {
|
) error {
|
||||||
updateStmt := common.TxStmt(txn, s.updateEventStateStmt)
|
updateStmt := common.TxStmt(txn, s.updateEventStateStmt)
|
||||||
_, err := updateStmt.ExecContext(ctx, int64(eventNID), int64(stateNID))
|
fmt.Println("=====================================")
|
||||||
|
fmt.Println(updateEventStateSQL, stateNID, eventNID)
|
||||||
|
res, err := updateStmt.ExecContext(ctx, int64(stateNID), int64(eventNID))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println("updateEventState s.updateEventStateStmt.ExecContext:", err)
|
fmt.Println("updateEventState s.updateEventStateStmt.ExecContext:", err)
|
||||||
}
|
}
|
||||||
|
a, b := res.RowsAffected()
|
||||||
|
fmt.Println("Rows affected:", a, b)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -336,14 +337,9 @@ func (s *eventStatements) bulkSelectStateAtEventAndReference(
|
||||||
iEventNIDs[k] = v
|
iEventNIDs[k] = v
|
||||||
}
|
}
|
||||||
selectOrig := strings.Replace(bulkSelectStateAtEventAndReferenceSQL, "($1)", queryVariadic(len(iEventNIDs)), 1)
|
selectOrig := strings.Replace(bulkSelectStateAtEventAndReferenceSQL, "($1)", queryVariadic(len(iEventNIDs)), 1)
|
||||||
selectPrep, err := s.db.Prepare(selectOrig)
|
//////////////
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
///////////////
|
|
||||||
|
|
||||||
selectStmt := common.TxStmt(txn, selectPrep)
|
rows, err := txn.QueryContext(ctx, selectOrig, iEventNIDs...)
|
||||||
rows, err := selectStmt.QueryContext(ctx, iEventNIDs...)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println("bulkSelectStateAtEventAndREference stmt.QueryContext:", err)
|
fmt.Println("bulkSelectStateAtEventAndREference stmt.QueryContext:", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
@ -389,7 +385,7 @@ func (s *eventStatements) bulkSelectEventReference(
|
||||||
iEventNIDs[k] = v
|
iEventNIDs[k] = v
|
||||||
}
|
}
|
||||||
selectOrig := strings.Replace(bulkSelectEventReferenceSQL, "($1)", queryVariadic(len(iEventNIDs)), 1)
|
selectOrig := strings.Replace(bulkSelectEventReferenceSQL, "($1)", queryVariadic(len(iEventNIDs)), 1)
|
||||||
selectPrep, err := s.db.Prepare(selectOrig)
|
selectPrep, err := txn.Prepare(selectOrig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
@ -425,7 +421,7 @@ func (s *eventStatements) bulkSelectEventID(ctx context.Context, txn *sql.Tx, ev
|
||||||
iEventNIDs[k] = v
|
iEventNIDs[k] = v
|
||||||
}
|
}
|
||||||
selectOrig := strings.Replace(bulkSelectEventIDSQL, "($1)", queryVariadic(len(iEventNIDs)), 1)
|
selectOrig := strings.Replace(bulkSelectEventIDSQL, "($1)", queryVariadic(len(iEventNIDs)), 1)
|
||||||
selectPrep, err := s.db.Prepare(selectOrig)
|
selectPrep, err := txn.Prepare(selectOrig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
@ -464,7 +460,7 @@ func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, e
|
||||||
iEventIDs[k] = v
|
iEventIDs[k] = v
|
||||||
}
|
}
|
||||||
selectOrig := strings.Replace(bulkSelectEventNIDSQL, "($1)", queryVariadic(len(iEventIDs)), 1)
|
selectOrig := strings.Replace(bulkSelectEventNIDSQL, "($1)", queryVariadic(len(iEventIDs)), 1)
|
||||||
selectPrep, err := s.db.Prepare(selectOrig)
|
selectPrep, err := txn.Prepare(selectOrig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -22,7 +22,6 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/lib/pq"
|
"github.com/lib/pq"
|
||||||
"github.com/matrix-org/dendrite/common"
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -36,13 +35,7 @@ const stateSnapshotSchema = `
|
||||||
|
|
||||||
const insertStateSQL = `
|
const insertStateSQL = `
|
||||||
INSERT INTO roomserver_state_snapshots (room_nid, state_block_nids)
|
INSERT INTO roomserver_state_snapshots (room_nid, state_block_nids)
|
||||||
VALUES ($1, $2);
|
VALUES ($1, $2);`
|
||||||
`
|
|
||||||
|
|
||||||
const insertStateResultSQL = `
|
|
||||||
SELECT state_snapshot_nid FROM roomserver_state_snapshots
|
|
||||||
WHERE rowid = last_insert_rowid();
|
|
||||||
`
|
|
||||||
|
|
||||||
// Bulk state data NID lookup.
|
// Bulk state data NID lookup.
|
||||||
// Sorting by state_snapshot_nid means we can use binary search over the result
|
// Sorting by state_snapshot_nid means we can use binary search over the result
|
||||||
|
|
@ -67,7 +60,6 @@ func (s *stateSnapshotStatements) prepare(db *sql.DB) (err error) {
|
||||||
|
|
||||||
return statementList{
|
return statementList{
|
||||||
{&s.insertStateStmt, insertStateSQL},
|
{&s.insertStateStmt, insertStateSQL},
|
||||||
{&s.insertStateResultStmt, insertStateResultSQL},
|
|
||||||
{&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL},
|
{&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL},
|
||||||
}.prepare(db)
|
}.prepare(db)
|
||||||
}
|
}
|
||||||
|
|
@ -79,15 +71,17 @@ func (s *stateSnapshotStatements) insertState(
|
||||||
for i := range stateBlockNIDs {
|
for i := range stateBlockNIDs {
|
||||||
nids[i] = int64(stateBlockNIDs[i])
|
nids[i] = int64(stateBlockNIDs[i])
|
||||||
}
|
}
|
||||||
insertStmt := common.TxStmt(txn, s.insertStateStmt)
|
insertStmt := txn.Stmt(s.insertStateStmt)
|
||||||
resultStmt := common.TxStmt(txn, s.insertStateResultStmt)
|
//resultStmt := txn.Stmt(s.insertStateResultStmt)
|
||||||
if _, err = insertStmt.ExecContext(ctx, int64(roomNID), pq.Int64Array(nids)); err == nil {
|
fmt.Println(insertStateSQL, roomNID, nids)
|
||||||
err = resultStmt.QueryRowContext(ctx).Scan(&stateNID)
|
if res, err2 := insertStmt.ExecContext(ctx, int64(roomNID), pq.Int64Array(nids)); err2 == nil {
|
||||||
if err != nil {
|
lastRowID, err3 := res.LastInsertId()
|
||||||
fmt.Println("insertState s.insertStateResultStmt.QueryRowContext:", err)
|
if err3 != nil {
|
||||||
|
err = err3
|
||||||
}
|
}
|
||||||
|
stateNID = types.StateSnapshotNID(lastRowID)
|
||||||
} else {
|
} else {
|
||||||
fmt.Println("insertState s.insertStateStmt.ExecContext:", err)
|
fmt.Println("insertState s.insertStateStmt.ExecContext:", err2)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
@ -101,7 +95,7 @@ func (s *stateSnapshotStatements) bulkSelectStateBlockNIDs(
|
||||||
nids[k] = v
|
nids[k] = v
|
||||||
}
|
}
|
||||||
selectOrig := strings.Replace(bulkSelectStateBlockNIDsSQL, "($1)", queryVariadic(len(nids)), 1)
|
selectOrig := strings.Replace(bulkSelectStateBlockNIDsSQL, "($1)", queryVariadic(len(nids)), 1)
|
||||||
selectPrep, err := s.db.Prepare(selectOrig)
|
selectStmt, err := txn.Prepare(selectOrig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
@ -112,7 +106,6 @@ func (s *stateSnapshotStatements) bulkSelectStateBlockNIDs(
|
||||||
nids[i] = int64(stateNIDs[i])
|
nids[i] = int64(stateNIDs[i])
|
||||||
}
|
}
|
||||||
*/
|
*/
|
||||||
selectStmt := common.TxStmt(txn, selectPrep)
|
|
||||||
rows, err := selectStmt.QueryContext(ctx, nids...)
|
rows, err := selectStmt.QueryContext(ctx, nids...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println("bulkSelectStateBlockNIDs s.bulkSelectStateBlockNIDsStmt.QueryContext:", err)
|
fmt.Println("bulkSelectStateBlockNIDs s.bulkSelectStateBlockNIDsStmt.QueryContext:", err)
|
||||||
|
|
|
||||||
|
|
@ -21,6 +21,7 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/common"
|
"github.com/matrix-org/dendrite/common"
|
||||||
"github.com/matrix-org/dendrite/roomserver/api"
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
|
|
@ -54,8 +55,8 @@ func Open(dataSourceName string) (*Database, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
//d.db.Exec("PRAGMA journal_mode=WAL;")
|
//d.db.Exec("PRAGMA journal_mode=WAL;")
|
||||||
//d.db.Exec("PRAGMA parser_trace = true;")
|
//d.db.Exec("PRAGMA read_uncommitted = true;")
|
||||||
d.db.SetMaxOpenConns(1)
|
d.db.SetMaxOpenConns(2)
|
||||||
if err = d.statements.prepare(d.db); err != nil {
|
if err = d.statements.prepare(d.db); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
@ -196,36 +197,56 @@ func (d *Database) assignStateKeyNID(
|
||||||
// StateEntriesForEventIDs implements input.EventDatabase
|
// StateEntriesForEventIDs implements input.EventDatabase
|
||||||
func (d *Database) StateEntriesForEventIDs(
|
func (d *Database) StateEntriesForEventIDs(
|
||||||
ctx context.Context, eventIDs []string,
|
ctx context.Context, eventIDs []string,
|
||||||
) ([]types.StateEntry, error) {
|
) (se []types.StateEntry, err error) {
|
||||||
return d.statements.bulkSelectStateEventByID(ctx, nil, eventIDs)
|
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
|
se, err = d.statements.bulkSelectStateEventByID(ctx, txn, eventIDs)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// EventTypeNIDs implements state.RoomStateDatabase
|
// EventTypeNIDs implements state.RoomStateDatabase
|
||||||
func (d *Database) EventTypeNIDs(
|
func (d *Database) EventTypeNIDs(
|
||||||
ctx context.Context, eventTypes []string,
|
ctx context.Context, eventTypes []string,
|
||||||
) (map[string]types.EventTypeNID, error) {
|
) (etnids map[string]types.EventTypeNID, err error) {
|
||||||
return d.statements.bulkSelectEventTypeNID(ctx, nil, eventTypes)
|
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
|
etnids, err = d.statements.bulkSelectEventTypeNID(ctx, txn, eventTypes)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// EventStateKeyNIDs implements state.RoomStateDatabase
|
// EventStateKeyNIDs implements state.RoomStateDatabase
|
||||||
func (d *Database) EventStateKeyNIDs(
|
func (d *Database) EventStateKeyNIDs(
|
||||||
ctx context.Context, eventStateKeys []string,
|
ctx context.Context, eventStateKeys []string,
|
||||||
) (map[string]types.EventStateKeyNID, error) {
|
) (esknids map[string]types.EventStateKeyNID, err error) {
|
||||||
return d.statements.bulkSelectEventStateKeyNID(ctx, nil, eventStateKeys)
|
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
|
esknids, err = d.statements.bulkSelectEventStateKeyNID(ctx, txn, eventStateKeys)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// EventStateKeys implements query.RoomserverQueryAPIDatabase
|
// EventStateKeys implements query.RoomserverQueryAPIDatabase
|
||||||
func (d *Database) EventStateKeys(
|
func (d *Database) EventStateKeys(
|
||||||
ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID,
|
ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID,
|
||||||
) (map[types.EventStateKeyNID]string, error) {
|
) (out map[types.EventStateKeyNID]string, err error) {
|
||||||
return d.statements.bulkSelectEventStateKey(ctx, nil, eventStateKeyNIDs)
|
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
|
out, err = d.statements.bulkSelectEventStateKey(ctx, txn, eventStateKeyNIDs)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// EventNIDs implements query.RoomserverQueryAPIDatabase
|
// EventNIDs implements query.RoomserverQueryAPIDatabase
|
||||||
func (d *Database) EventNIDs(
|
func (d *Database) EventNIDs(
|
||||||
ctx context.Context, eventIDs []string,
|
ctx context.Context, eventIDs []string,
|
||||||
) (map[string]types.EventNID, error) {
|
) (out map[string]types.EventNID, err error) {
|
||||||
return d.statements.bulkSelectEventNID(ctx, nil, eventIDs)
|
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
|
out, err = d.statements.bulkSelectEventNID(ctx, txn, eventIDs)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Events implements input.EventDatabase
|
// Events implements input.EventDatabase
|
||||||
|
|
@ -235,12 +256,15 @@ func (d *Database) Events(
|
||||||
var eventJSONs []eventJSONPair
|
var eventJSONs []eventJSONPair
|
||||||
var err error
|
var err error
|
||||||
results := make([]types.Event, len(eventNIDs))
|
results := make([]types.Event, len(eventNIDs))
|
||||||
|
fmt.Println("pre txn")
|
||||||
common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
|
fmt.Println("in txn", txn)
|
||||||
eventJSONs, err = d.statements.bulkSelectEventJSON(ctx, txn, eventNIDs)
|
eventJSONs, err = d.statements.bulkSelectEventJSON(ctx, txn, eventNIDs)
|
||||||
if err != nil || len(eventJSONs) == 0 {
|
if err != nil || len(eventJSONs) == 0 {
|
||||||
fmt.Println("d.statements.bulkSelectEventJSON:", err)
|
fmt.Println("d.statements.bulkSelectEventJSON:", err)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
fmt.Println("selected txn")
|
||||||
for i, eventJSON := range eventJSONs {
|
for i, eventJSON := range eventJSONs {
|
||||||
result := &results[i]
|
result := &results[i]
|
||||||
result.EventNID = eventJSON.EventNID
|
result.EventNID = eventJSON.EventNID
|
||||||
|
|
@ -252,6 +276,7 @@ func (d *Database) Events(
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
fmt.Println("post txn")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return []types.Event{}, err
|
return []types.Event{}, err
|
||||||
}
|
}
|
||||||
|
|
@ -265,7 +290,9 @@ func (d *Database) AddState(
|
||||||
stateBlockNIDs []types.StateBlockNID,
|
stateBlockNIDs []types.StateBlockNID,
|
||||||
state []types.StateEntry,
|
state []types.StateEntry,
|
||||||
) (stateNID types.StateSnapshotNID, err error) {
|
) (stateNID types.StateSnapshotNID, err error) {
|
||||||
common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
fmt.Println("AddState INSERT STATE START", stateBlockNIDs)
|
||||||
|
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
|
fmt.Println("insert state txn created")
|
||||||
if len(state) > 0 {
|
if len(state) > 0 {
|
||||||
stateBlockNID, err := d.statements.selectNextStateBlockNID(ctx, txn)
|
stateBlockNID, err := d.statements.selectNextStateBlockNID(ctx, txn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -277,8 +304,10 @@ func (d *Database) AddState(
|
||||||
stateBlockNIDs = append(stateBlockNIDs[:len(stateBlockNIDs):len(stateBlockNIDs)], stateBlockNID)
|
stateBlockNIDs = append(stateBlockNIDs[:len(stateBlockNIDs):len(stateBlockNIDs)], stateBlockNID)
|
||||||
}
|
}
|
||||||
stateNID, err = d.statements.insertState(ctx, txn, roomNID, stateBlockNIDs)
|
stateNID, err = d.statements.insertState(ctx, txn, roomNID, stateBlockNIDs)
|
||||||
return nil
|
fmt.Println("AddState: completing txn", time.Now(), "err=", err)
|
||||||
|
return err
|
||||||
})
|
})
|
||||||
|
fmt.Println("AddState INSERT STATE END pkey=", stateNID, time.Now(), "err=", err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
@ -289,49 +318,77 @@ func (d *Database) AddState(
|
||||||
func (d *Database) SetState(
|
func (d *Database) SetState(
|
||||||
ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID,
|
ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID,
|
||||||
) error {
|
) error {
|
||||||
return d.statements.updateEventState(ctx, nil, eventNID, stateNID)
|
fmt.Println("SetState event NID:", eventNID, "state NID:", stateNID)
|
||||||
|
e := common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
|
return d.statements.updateEventState(ctx, txn, eventNID, stateNID)
|
||||||
|
})
|
||||||
|
fmt.Println("SetState finish", e)
|
||||||
|
return e
|
||||||
}
|
}
|
||||||
|
|
||||||
// StateAtEventIDs implements input.EventDatabase
|
// StateAtEventIDs implements input.EventDatabase
|
||||||
func (d *Database) StateAtEventIDs(
|
func (d *Database) StateAtEventIDs(
|
||||||
ctx context.Context, eventIDs []string,
|
ctx context.Context, eventIDs []string,
|
||||||
) ([]types.StateAtEvent, error) {
|
) (se []types.StateAtEvent, err error) {
|
||||||
return d.statements.bulkSelectStateAtEventByID(ctx, nil, eventIDs)
|
common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
|
se, err = d.statements.bulkSelectStateAtEventByID(ctx, txn, eventIDs)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// StateBlockNIDs implements state.RoomStateDatabase
|
// StateBlockNIDs implements state.RoomStateDatabase
|
||||||
func (d *Database) StateBlockNIDs(
|
func (d *Database) StateBlockNIDs(
|
||||||
ctx context.Context, stateNIDs []types.StateSnapshotNID,
|
ctx context.Context, stateNIDs []types.StateSnapshotNID,
|
||||||
) ([]types.StateBlockNIDList, error) {
|
) (sl []types.StateBlockNIDList, err error) {
|
||||||
return d.statements.bulkSelectStateBlockNIDs(ctx, nil, stateNIDs)
|
fmt.Println("StateBlockNIDs SELECT STATE START", stateNIDs)
|
||||||
|
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
|
fmt.Println(" in txn")
|
||||||
|
sl, err = d.statements.bulkSelectStateBlockNIDs(ctx, txn, stateNIDs)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
fmt.Println("StateBlockNIDs SELECT STATE END", sl)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// StateEntries implements state.RoomStateDatabase
|
// StateEntries implements state.RoomStateDatabase
|
||||||
func (d *Database) StateEntries(
|
func (d *Database) StateEntries(
|
||||||
ctx context.Context, stateBlockNIDs []types.StateBlockNID,
|
ctx context.Context, stateBlockNIDs []types.StateBlockNID,
|
||||||
) ([]types.StateEntryList, error) {
|
) (sel []types.StateEntryList, err error) {
|
||||||
return d.statements.bulkSelectStateBlockEntries(ctx, nil, stateBlockNIDs)
|
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
|
sel, err = d.statements.bulkSelectStateBlockEntries(ctx, txn, stateBlockNIDs)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// SnapshotNIDFromEventID implements state.RoomStateDatabase
|
// SnapshotNIDFromEventID implements state.RoomStateDatabase
|
||||||
func (d *Database) SnapshotNIDFromEventID(
|
func (d *Database) SnapshotNIDFromEventID(
|
||||||
ctx context.Context, eventID string,
|
ctx context.Context, eventID string,
|
||||||
) (stateNID types.StateSnapshotNID, err error) {
|
) (stateNID types.StateSnapshotNID, err error) {
|
||||||
_, stateNID, err = d.statements.selectEvent(ctx, nil, eventID)
|
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
|
_, stateNID, err = d.statements.selectEvent(ctx, txn, eventID)
|
||||||
|
return err
|
||||||
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// EventIDs implements input.RoomEventDatabase
|
// EventIDs implements input.RoomEventDatabase
|
||||||
func (d *Database) EventIDs(
|
func (d *Database) EventIDs(
|
||||||
ctx context.Context, eventNIDs []types.EventNID,
|
ctx context.Context, eventNIDs []types.EventNID,
|
||||||
) (map[types.EventNID]string, error) {
|
) (out map[types.EventNID]string, err error) {
|
||||||
return d.statements.bulkSelectEventID(ctx, nil, eventNIDs)
|
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
|
out, err = d.statements.bulkSelectEventID(ctx, txn, eventNIDs)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetLatestEventsForUpdate implements input.EventDatabase
|
// GetLatestEventsForUpdate implements input.EventDatabase
|
||||||
func (d *Database) GetLatestEventsForUpdate(
|
func (d *Database) GetLatestEventsForUpdate(
|
||||||
ctx context.Context, roomNID types.RoomNID,
|
ctx context.Context, roomNID types.RoomNID,
|
||||||
) (types.RoomRecentEventsUpdater, error) {
|
) (types.RoomRecentEventsUpdater, error) {
|
||||||
|
fmt.Println("=============== GetLatestEventsForUpdate BEGIN TXN")
|
||||||
txn, err := d.db.Begin()
|
txn, err := d.db.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
@ -355,8 +412,15 @@ func (d *Database) GetLatestEventsForUpdate(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
fmt.Println("GetLatestEventsForUpdate returning updater")
|
||||||
|
|
||||||
|
// FIXME: we probably want to support long-lived txns in sqlite somehow, but we don't because we get
|
||||||
|
// 'database is locked' errors caused by multiple write txns (one being the long-lived txn created here)
|
||||||
|
// so for now let's not use a long-lived txn at all, and just commit it here and set the txn to nil so
|
||||||
|
// we fail fast if someone tries to use the underlying txn object.
|
||||||
|
txn.Commit()
|
||||||
return &roomRecentEventsUpdater{
|
return &roomRecentEventsUpdater{
|
||||||
transaction{ctx, txn}, d, roomNID, stateAndRefs, lastEventIDSent, currentStateSnapshotNID,
|
transaction{ctx, nil}, d, roomNID, stateAndRefs, lastEventIDSent, currentStateSnapshotNID,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -398,24 +462,33 @@ func (u *roomRecentEventsUpdater) CurrentStateSnapshotNID() types.StateSnapshotN
|
||||||
|
|
||||||
// StorePreviousEvents implements types.RoomRecentEventsUpdater
|
// StorePreviousEvents implements types.RoomRecentEventsUpdater
|
||||||
func (u *roomRecentEventsUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error {
|
func (u *roomRecentEventsUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error {
|
||||||
for _, ref := range previousEventReferences {
|
err := common.WithTransaction(u.d.db, func(txn *sql.Tx) error {
|
||||||
if err := u.d.statements.insertPreviousEvent(u.ctx, u.txn, ref.EventID, ref.EventSHA256, eventNID); err != nil {
|
for _, ref := range previousEventReferences {
|
||||||
return err
|
if err := u.d.statements.insertPreviousEvent(u.ctx, txn, ref.EventID, ref.EventSHA256, eventNID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
return nil
|
||||||
return nil
|
})
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsReferenced implements types.RoomRecentEventsUpdater
|
// IsReferenced implements types.RoomRecentEventsUpdater
|
||||||
func (u *roomRecentEventsUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) {
|
func (u *roomRecentEventsUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (res bool, err error) {
|
||||||
err := u.d.statements.selectPreviousEventExists(u.ctx, u.txn, eventReference.EventID, eventReference.EventSHA256)
|
fmt.Println("[[TXN]] IsReferenced")
|
||||||
if err == nil {
|
err = common.WithTransaction(u.d.db, func(txn *sql.Tx) error {
|
||||||
return true, nil
|
err := u.d.statements.selectPreviousEventExists(u.ctx, txn, eventReference.EventID, eventReference.EventSHA256)
|
||||||
}
|
if err == nil {
|
||||||
if err == sql.ErrNoRows {
|
res = true
|
||||||
return false, nil
|
err = nil
|
||||||
}
|
}
|
||||||
return false, err
|
if err == sql.ErrNoRows {
|
||||||
|
res = false
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetLatestEvents implements types.RoomRecentEventsUpdater
|
// SetLatestEvents implements types.RoomRecentEventsUpdater
|
||||||
|
|
@ -423,38 +496,58 @@ func (u *roomRecentEventsUpdater) SetLatestEvents(
|
||||||
roomNID types.RoomNID, latest []types.StateAtEventAndReference, lastEventNIDSent types.EventNID,
|
roomNID types.RoomNID, latest []types.StateAtEventAndReference, lastEventNIDSent types.EventNID,
|
||||||
currentStateSnapshotNID types.StateSnapshotNID,
|
currentStateSnapshotNID types.StateSnapshotNID,
|
||||||
) error {
|
) error {
|
||||||
eventNIDs := make([]types.EventNID, len(latest))
|
err := common.WithTransaction(u.d.db, func(txn *sql.Tx) error {
|
||||||
for i := range latest {
|
eventNIDs := make([]types.EventNID, len(latest))
|
||||||
eventNIDs[i] = latest[i].EventNID
|
for i := range latest {
|
||||||
}
|
eventNIDs[i] = latest[i].EventNID
|
||||||
// TODO: transaction was removed here - is this wise?
|
}
|
||||||
return u.d.statements.updateLatestEventNIDs(u.ctx, nil, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID)
|
return u.d.statements.updateLatestEventNIDs(u.ctx, txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID)
|
||||||
|
})
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// HasEventBeenSent implements types.RoomRecentEventsUpdater
|
// HasEventBeenSent implements types.RoomRecentEventsUpdater
|
||||||
func (u *roomRecentEventsUpdater) HasEventBeenSent(eventNID types.EventNID) (bool, error) {
|
func (u *roomRecentEventsUpdater) HasEventBeenSent(eventNID types.EventNID) (res bool, err error) {
|
||||||
// TODO: transaction was removed here - is this wise?
|
// TODO: transaction was removed here - is this wise?
|
||||||
return u.d.statements.selectEventSentToOutput(u.ctx, nil, eventNID)
|
fmt.Println("[[TXN]] HasEventBeenSent")
|
||||||
|
err = common.WithTransaction(u.d.db, func(txn *sql.Tx) error {
|
||||||
|
res, err = u.d.statements.selectEventSentToOutput(u.ctx, txn, eventNID)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarkEventAsSent implements types.RoomRecentEventsUpdater
|
// MarkEventAsSent implements types.RoomRecentEventsUpdater
|
||||||
func (u *roomRecentEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error {
|
func (u *roomRecentEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error {
|
||||||
// TODO: transaction was removed here - is this wise?
|
// TODO: transaction was removed here - is this wise?
|
||||||
return u.d.statements.updateEventSentToOutput(u.ctx, nil, eventNID)
|
fmt.Println("[[TXN]] updateEventSentToOutput")
|
||||||
|
err := common.WithTransaction(u.d.db, func(txn *sql.Tx) error {
|
||||||
|
return u.d.statements.updateEventSentToOutput(u.ctx, txn, eventNID)
|
||||||
|
})
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID) (types.MembershipUpdater, error) {
|
func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID) (mu types.MembershipUpdater, err error) {
|
||||||
// TODO: transaction was removed here - is this wise?
|
// TODO: transaction was removed here - is this wise?
|
||||||
return u.d.membershipUpdaterTxn(u.ctx, nil, u.roomNID, targetUserNID)
|
fmt.Println("[[TXN]] membershipUpdaterTxn")
|
||||||
|
err = common.WithTransaction(u.d.db, func(txn *sql.Tx) error {
|
||||||
|
mu, err = u.d.membershipUpdaterTxn(u.ctx, txn, u.roomNID, targetUserNID)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// RoomNID implements query.RoomserverQueryAPIDB
|
// RoomNID implements query.RoomserverQueryAPIDB
|
||||||
func (d *Database) RoomNID(ctx context.Context, roomID string) (types.RoomNID, error) {
|
func (d *Database) RoomNID(ctx context.Context, roomID string) (roomNID types.RoomNID, err error) {
|
||||||
roomNID, err := d.statements.selectRoomNID(ctx, nil, roomID)
|
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
if err == sql.ErrNoRows {
|
roomNID, err = d.statements.selectRoomNID(ctx, txn, roomID)
|
||||||
return 0, nil
|
if err == sql.ErrNoRows {
|
||||||
}
|
roomNID = 0
|
||||||
return roomNID, err
|
err = nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// LatestEventIDs implements query.RoomserverQueryAPIDatabase
|
// LatestEventIDs implements query.RoomserverQueryAPIDatabase
|
||||||
|
|
@ -532,12 +625,14 @@ func (d *Database) MembershipUpdater(
|
||||||
ctx context.Context, roomID, targetUserID string,
|
ctx context.Context, roomID, targetUserID string,
|
||||||
) (types.MembershipUpdater, error) {
|
) (types.MembershipUpdater, error) {
|
||||||
txn, err := d.db.Begin()
|
txn, err := d.db.Begin()
|
||||||
|
fmt.Println("=== UPDATER TXN START ====")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
succeeded := false
|
succeeded := false
|
||||||
defer func() {
|
defer func() {
|
||||||
if !succeeded {
|
if !succeeded {
|
||||||
|
fmt.Println("=== UPDATER TXN ROLLBACK ====")
|
||||||
txn.Rollback() // nolint: errcheck
|
txn.Rollback() // nolint: errcheck
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
@ -606,92 +701,99 @@ func (u *membershipUpdater) IsLeave() bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetToInvite implements types.MembershipUpdater
|
// SetToInvite implements types.MembershipUpdater
|
||||||
func (u *membershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, error) {
|
func (u *membershipUpdater) SetToInvite(event gomatrixserverlib.Event) (inserted bool, err error) {
|
||||||
senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, event.Sender())
|
err = common.WithTransaction(u.d.db, func(txn *sql.Tx) error {
|
||||||
if err != nil {
|
senderUserNID, err := u.d.assignStateKeyNID(u.ctx, txn, event.Sender())
|
||||||
return false, err
|
if err != nil {
|
||||||
}
|
return err
|
||||||
inserted, err := u.d.statements.insertInviteEvent(
|
|
||||||
u.ctx, u.txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(),
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
if u.membership != membershipStateInvite {
|
|
||||||
if err = u.d.statements.updateMembership(
|
|
||||||
u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, membershipStateInvite, 0,
|
|
||||||
); err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
}
|
||||||
}
|
inserted, err = u.d.statements.insertInviteEvent(
|
||||||
return inserted, nil
|
u.ctx, txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if u.membership != membershipStateInvite {
|
||||||
|
if err = u.d.statements.updateMembership(
|
||||||
|
u.ctx, txn, u.roomNID, u.targetUserNID, senderUserNID, membershipStateInvite, 0,
|
||||||
|
); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetToJoin implements types.MembershipUpdater
|
// SetToJoin implements types.MembershipUpdater
|
||||||
func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpdate bool) ([]string, error) {
|
func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpdate bool) (inviteEventIDs []string, err error) {
|
||||||
var inviteEventIDs []string
|
err = common.WithTransaction(u.d.db, func(txn *sql.Tx) error {
|
||||||
|
senderUserNID, err := u.d.assignStateKeyNID(u.ctx, txn, senderUserID)
|
||||||
senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// If this is a join event update, there is no invite to update
|
|
||||||
if !isUpdate {
|
|
||||||
inviteEventIDs, err = u.d.statements.updateInviteRetired(
|
|
||||||
u.ctx, u.txn, u.roomNID, u.targetUserNID,
|
|
||||||
)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// Look up the NID of the new join event
|
// If this is a join event update, there is no invite to update
|
||||||
nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID})
|
if !isUpdate {
|
||||||
if err != nil {
|
inviteEventIDs, err = u.d.statements.updateInviteRetired(
|
||||||
return nil, err
|
u.ctx, txn, u.roomNID, u.targetUserNID,
|
||||||
}
|
)
|
||||||
|
if err != nil {
|
||||||
if u.membership != membershipStateJoin || isUpdate {
|
return err
|
||||||
if err = u.d.statements.updateMembership(
|
}
|
||||||
u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID,
|
|
||||||
membershipStateJoin, nIDs[eventID],
|
|
||||||
); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
return inviteEventIDs, nil
|
// Look up the NID of the new join event
|
||||||
|
nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if u.membership != membershipStateJoin || isUpdate {
|
||||||
|
if err = u.d.statements.updateMembership(
|
||||||
|
u.ctx, txn, u.roomNID, u.targetUserNID, senderUserNID,
|
||||||
|
membershipStateJoin, nIDs[eventID],
|
||||||
|
); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetToLeave implements types.MembershipUpdater
|
// SetToLeave implements types.MembershipUpdater
|
||||||
func (u *membershipUpdater) SetToLeave(senderUserID string, eventID string) ([]string, error) {
|
func (u *membershipUpdater) SetToLeave(senderUserID string, eventID string) (inviteEventIDs []string, err error) {
|
||||||
senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID)
|
err = common.WithTransaction(u.d.db, func(txn *sql.Tx) error {
|
||||||
if err != nil {
|
senderUserNID, err := u.d.assignStateKeyNID(u.ctx, txn, senderUserID)
|
||||||
return nil, err
|
if err != nil {
|
||||||
}
|
return err
|
||||||
inviteEventIDs, err := u.d.statements.updateInviteRetired(
|
|
||||||
u.ctx, u.txn, u.roomNID, u.targetUserNID,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Look up the NID of the new leave event
|
|
||||||
nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if u.membership != membershipStateLeaveOrBan {
|
|
||||||
if err = u.d.statements.updateMembership(
|
|
||||||
u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID,
|
|
||||||
membershipStateLeaveOrBan, nIDs[eventID],
|
|
||||||
); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
}
|
inviteEventIDs, err = u.d.statements.updateInviteRetired(
|
||||||
return inviteEventIDs, nil
|
u.ctx, txn, u.roomNID, u.targetUserNID,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Look up the NID of the new leave event
|
||||||
|
nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if u.membership != membershipStateLeaveOrBan {
|
||||||
|
if err = u.d.statements.updateMembership(
|
||||||
|
u.ctx, txn, u.roomNID, u.targetUserNID, senderUserNID,
|
||||||
|
membershipStateLeaveOrBan, nIDs[eventID],
|
||||||
|
); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetMembership implements query.RoomserverQueryAPIDB
|
// GetMembership implements query.RoomserverQueryAPIDB
|
||||||
|
|
@ -762,10 +864,16 @@ type transaction struct {
|
||||||
|
|
||||||
// Commit implements types.Transaction
|
// Commit implements types.Transaction
|
||||||
func (t *transaction) Commit() error {
|
func (t *transaction) Commit() error {
|
||||||
|
if t.txn == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
return t.txn.Commit()
|
return t.txn.Commit()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Rollback implements types.Transaction
|
// Rollback implements types.Transaction
|
||||||
func (t *transaction) Rollback() error {
|
func (t *transaction) Rollback() error {
|
||||||
|
if t.txn == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
return t.txn.Rollback()
|
return t.txn.Rollback()
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue