Use TransactionWriter in SQLite keyserver

This commit is contained in:
Neil Alexander 2020-08-05 10:06:01 +01:00
parent 22f028e141
commit 8d063627f5
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
3 changed files with 37 additions and 22 deletions

View file

@ -20,6 +20,7 @@ import (
"time" "time"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/storage/tables" "github.com/matrix-org/dendrite/keyserver/storage/tables"
) )
@ -54,6 +55,7 @@ const selectMaxStreamForUserSQL = "" +
type deviceKeysStatements struct { type deviceKeysStatements struct {
db *sql.DB db *sql.DB
writer *sqlutil.TransactionWriter
upsertDeviceKeysStmt *sql.Stmt upsertDeviceKeysStmt *sql.Stmt
selectDeviceKeysStmt *sql.Stmt selectDeviceKeysStmt *sql.Stmt
selectBatchDeviceKeysStmt *sql.Stmt selectBatchDeviceKeysStmt *sql.Stmt
@ -62,7 +64,8 @@ type deviceKeysStatements struct {
func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
s := &deviceKeysStatements{ s := &deviceKeysStatements{
db: db, db: db,
writer: sqlutil.NewTransactionWriter(),
} }
_, err := db.Exec(deviceKeysSchema) _, err := db.Exec(deviceKeysSchema)
if err != nil { if err != nil {
@ -141,14 +144,16 @@ func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn
} }
func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error { func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error {
for _, key := range keys { return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
now := time.Now().Unix() for _, key := range keys {
_, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext( now := time.Now().Unix()
ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, _, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext(
) ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID,
if err != nil { )
return err if err != nil {
return err
}
} }
} return nil
return nil })
} }

View file

@ -21,6 +21,7 @@ import (
"github.com/Shopify/sarama" "github.com/Shopify/sarama"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/storage/tables" "github.com/matrix-org/dendrite/keyserver/storage/tables"
) )
@ -51,13 +52,15 @@ const selectKeyChangesSQL = "" +
type keyChangesStatements struct { type keyChangesStatements struct {
db *sql.DB db *sql.DB
writer *sqlutil.TransactionWriter
upsertKeyChangeStmt *sql.Stmt upsertKeyChangeStmt *sql.Stmt
selectKeyChangesStmt *sql.Stmt selectKeyChangesStmt *sql.Stmt
} }
func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) { func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) {
s := &keyChangesStatements{ s := &keyChangesStatements{
db: db, db: db,
writer: sqlutil.NewTransactionWriter(),
} }
_, err := db.Exec(keyChangesSchema) _, err := db.Exec(keyChangesSchema)
if err != nil { if err != nil {
@ -73,8 +76,10 @@ func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) {
} }
func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, partition int32, offset int64, userID string) error { func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, partition int32, offset int64, userID string) error {
_, err := s.upsertKeyChangeStmt.ExecContext(ctx, partition, offset, userID) return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
return err _, err := s.upsertKeyChangeStmt.ExecContext(ctx, partition, offset, userID)
return err
})
} }
func (s *keyChangesStatements) SelectKeyChanges( func (s *keyChangesStatements) SelectKeyChanges(

View file

@ -60,6 +60,7 @@ const selectKeyByAlgorithmSQL = "" +
type oneTimeKeysStatements struct { type oneTimeKeysStatements struct {
db *sql.DB db *sql.DB
writer *sqlutil.TransactionWriter
upsertKeysStmt *sql.Stmt upsertKeysStmt *sql.Stmt
selectKeysStmt *sql.Stmt selectKeysStmt *sql.Stmt
selectKeysCountStmt *sql.Stmt selectKeysCountStmt *sql.Stmt
@ -69,7 +70,8 @@ type oneTimeKeysStatements struct {
func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) { func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) {
s := &oneTimeKeysStatements{ s := &oneTimeKeysStatements{
db: db, db: db,
writer: sqlutil.NewTransactionWriter(),
} }
_, err := db.Exec(oneTimeKeysSchema) _, err := db.Exec(oneTimeKeysSchema)
if err != nil { if err != nil {
@ -150,7 +152,7 @@ func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, keys api.
UserID: keys.UserID, UserID: keys.UserID,
KeyCount: make(map[string]int), KeyCount: make(map[string]int),
} }
return counts, sqlutil.WithTransaction(s.db, func(txn *sql.Tx) error { return counts, s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
for keyIDWithAlgo, keyJSON := range keys.KeyJSON { for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
algo, keyID := keys.Split(keyIDWithAlgo) algo, keyID := keys.Split(keyIDWithAlgo)
_, err := txn.Stmt(s.upsertKeysStmt).ExecContext( _, err := txn.Stmt(s.upsertKeysStmt).ExecContext(
@ -183,14 +185,17 @@ func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey(
) (map[string]json.RawMessage, error) { ) (map[string]json.RawMessage, error) {
var keyID string var keyID string
var keyJSON string var keyJSON string
err := txn.StmtContext(ctx, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON) err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
if err != nil { err := txn.StmtContext(ctx, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON)
if err == sql.ErrNoRows { if err != nil {
return nil, nil if err == sql.ErrNoRows {
return nil
}
return err
} }
return nil, err _, err = txn.StmtContext(ctx, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
} return err
_, err = txn.StmtContext(ctx, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID) })
return map[string]json.RawMessage{ return map[string]json.RawMessage{
algorithm + ":" + keyID: json.RawMessage(keyJSON), algorithm + ":" + keyID: json.RawMessage(keyJSON),
}, err }, err