mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-01-18 09:54:27 -06:00
Update all usages of tx.Stmt to sqlutil.TxStmt (#1423)
* Replace all usages of txn.Stmt with sqlutil.TxStmt Signed-off-by: Sam Day <me@samcday.com> * Fix sign off link in PR template. Signed-off-by: Sam Day <me@samcday.com> Co-authored-by: Neil Alexander <neilalexander@users.noreply.github.com>
This commit is contained in:
parent
60524f4b99
commit
a6700331ce
2
.github/PULL_REQUEST_TEMPLATE.md
vendored
2
.github/PULL_REQUEST_TEMPLATE.md
vendored
|
@ -3,4 +3,4 @@
|
|||
<!-- Please read CONTRIBUTING.md before submitting your pull request -->
|
||||
|
||||
* [ ] I have added any new tests that need to pass to `testfile` as specified in [docs/sytest.md](https://github.com/matrix-org/dendrite/blob/master/docs/sytest.md)
|
||||
* [ ] Pull request includes a [sign off](https://github.com/matrix-org/dendrite/blob/master/CONTRIBUTING.md#sign-off)
|
||||
* [ ] Pull request includes a [sign off](https://github.com/matrix-org/dendrite/blob/master/docs/CONTRIBUTING.md#sign-off)
|
||||
|
|
|
@ -88,6 +88,14 @@ func TxStmt(transaction *sql.Tx, statement *sql.Stmt) *sql.Stmt {
|
|||
return statement
|
||||
}
|
||||
|
||||
// TxStmtContext behaves similarly to TxStmt, with support for also passing context.
|
||||
func TxStmtContext(context context.Context, transaction *sql.Tx, statement *sql.Stmt) *sql.Stmt {
|
||||
if transaction != nil {
|
||||
statement = transaction.StmtContext(context, statement)
|
||||
}
|
||||
return statement
|
||||
}
|
||||
|
||||
// Hack of the century
|
||||
func QueryVariadic(count int) string {
|
||||
return QueryVariadicOffset(count, 0)
|
||||
|
|
|
@ -21,6 +21,7 @@ import (
|
|||
|
||||
"github.com/lib/pq"
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/keyserver/api"
|
||||
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
||||
)
|
||||
|
@ -125,7 +126,7 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []
|
|||
func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) {
|
||||
// nullable if there are no results
|
||||
var nullStream sql.NullInt32
|
||||
err = txn.Stmt(s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
|
||||
err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
|
||||
if err == sql.ErrNoRows {
|
||||
err = nil
|
||||
}
|
||||
|
@ -151,7 +152,7 @@ func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID
|
|||
func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error {
|
||||
for _, key := range keys {
|
||||
now := time.Now().Unix()
|
||||
_, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext(
|
||||
_, err := sqlutil.TxStmt(txn, s.upsertDeviceKeysStmt).ExecContext(
|
||||
ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName,
|
||||
)
|
||||
if err != nil {
|
||||
|
@ -162,7 +163,7 @@ func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx
|
|||
}
|
||||
|
||||
func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error {
|
||||
_, err := txn.Stmt(s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID)
|
||||
_, err := sqlutil.TxStmt(txn, s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID)
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
|
@ -21,6 +21,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/keyserver/api"
|
||||
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
||||
)
|
||||
|
@ -151,14 +152,14 @@ func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, txn *sql.
|
|||
}
|
||||
for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
|
||||
algo, keyID := keys.Split(keyIDWithAlgo)
|
||||
_, err := txn.Stmt(s.upsertKeysStmt).ExecContext(
|
||||
_, err := sqlutil.TxStmt(txn, s.upsertKeysStmt).ExecContext(
|
||||
ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
rows, err := txn.Stmt(s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID)
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -180,14 +181,14 @@ func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey(
|
|||
) (map[string]json.RawMessage, error) {
|
||||
var keyID string
|
||||
var keyJSON string
|
||||
err := txn.StmtContext(ctx, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON)
|
||||
err := sqlutil.TxStmtContext(ctx, txn, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
_, err = txn.StmtContext(ctx, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
|
||||
_, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
|
||||
return map[string]json.RawMessage{
|
||||
algorithm + ":" + keyID: json.RawMessage(keyJSON),
|
||||
}, err
|
||||
|
|
|
@ -97,7 +97,7 @@ func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
|
|||
}
|
||||
|
||||
func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error {
|
||||
_, err := txn.Stmt(s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID)
|
||||
_, err := sqlutil.TxStmt(txn, s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID)
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -156,7 +156,7 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []
|
|||
func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) {
|
||||
// nullable if there are no results
|
||||
var nullStream sql.NullInt32
|
||||
err = txn.Stmt(s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
|
||||
err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
|
||||
if err == sql.ErrNoRows {
|
||||
err = nil
|
||||
}
|
||||
|
@ -188,7 +188,7 @@ func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID
|
|||
func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error {
|
||||
for _, key := range keys {
|
||||
now := time.Now().Unix()
|
||||
_, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext(
|
||||
_, err := sqlutil.TxStmt(txn, s.upsertDeviceKeysStmt).ExecContext(
|
||||
ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName,
|
||||
)
|
||||
if err != nil {
|
||||
|
|
|
@ -21,6 +21,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/keyserver/api"
|
||||
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
||||
)
|
||||
|
@ -153,14 +154,14 @@ func (s *oneTimeKeysStatements) InsertOneTimeKeys(
|
|||
}
|
||||
for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
|
||||
algo, keyID := keys.Split(keyIDWithAlgo)
|
||||
_, err := txn.Stmt(s.upsertKeysStmt).ExecContext(
|
||||
_, err := sqlutil.TxStmt(txn, s.upsertKeysStmt).ExecContext(
|
||||
ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
rows, err := txn.Stmt(s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID)
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -182,14 +183,14 @@ func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey(
|
|||
) (map[string]json.RawMessage, error) {
|
||||
var keyID string
|
||||
var keyJSON string
|
||||
err := txn.StmtContext(ctx, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON)
|
||||
err := sqlutil.TxStmtContext(ctx, txn, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
_, err = txn.StmtContext(ctx, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
|
||||
_, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -21,6 +21,7 @@ import (
|
|||
"fmt"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
|
@ -86,7 +87,7 @@ func (s *stateSnapshotStatements) InsertState(
|
|||
for i := range stateBlockNIDs {
|
||||
nids[i] = int64(stateBlockNIDs[i])
|
||||
}
|
||||
err = txn.Stmt(s.insertStateStmt).QueryRowContext(ctx, int64(roomNID), pq.Int64Array(nids)).Scan(&stateNID)
|
||||
err = sqlutil.TxStmt(txn, s.insertStateStmt).QueryRowContext(ctx, int64(roomNID), pq.Int64Array(nids)).Scan(&stateNID)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -105,12 +105,12 @@ func (s *stateBlockStatements) BulkInsertStateData(
|
|||
return 0, nil
|
||||
}
|
||||
var stateBlockNID types.StateBlockNID
|
||||
err := txn.Stmt(s.selectNextStateBlockNIDStmt).QueryRowContext(ctx).Scan(&stateBlockNID)
|
||||
err := sqlutil.TxStmt(txn, s.selectNextStateBlockNIDStmt).QueryRowContext(ctx).Scan(&stateBlockNID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
for _, entry := range entries {
|
||||
_, err = txn.Stmt(s.insertStateDataStmt).ExecContext(
|
||||
_, err = sqlutil.TxStmt(txn, s.insertStateDataStmt).ExecContext(
|
||||
ctx,
|
||||
int64(stateBlockNID),
|
||||
int64(entry.EventTypeNID),
|
||||
|
|
|
@ -76,7 +76,7 @@ func (s *stateSnapshotStatements) InsertState(
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
insertStmt := txn.Stmt(s.insertStateStmt)
|
||||
insertStmt := sqlutil.TxStmt(txn, s.insertStateStmt)
|
||||
res, err := insertStmt.ExecContext(ctx, int64(roomNID), string(stateBlockNIDsJSON))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
|
|
|
@ -81,7 +81,7 @@ func NewPostgresBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremiti
|
|||
func (s *backwardExtremitiesStatements) InsertsBackwardExtremity(
|
||||
ctx context.Context, txn *sql.Tx, roomID, eventID string, prevEventID string,
|
||||
) (err error) {
|
||||
_, err = txn.Stmt(s.insertBackwardExtremityStmt).ExecContext(ctx, roomID, eventID, prevEventID)
|
||||
_, err = sqlutil.TxStmt(txn, s.insertBackwardExtremityStmt).ExecContext(ctx, roomID, eventID, prevEventID)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -110,7 +110,7 @@ func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom(
|
|||
func (s *backwardExtremitiesStatements) DeleteBackwardExtremity(
|
||||
ctx context.Context, txn *sql.Tx, roomID, knownEventID string,
|
||||
) (err error) {
|
||||
_, err = txn.Stmt(s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID)
|
||||
_, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -160,13 +160,13 @@ func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
|
|||
func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages(
|
||||
ctx context.Context, txn *sql.Tx, token string, nids []types.SendToDeviceNID,
|
||||
) (err error) {
|
||||
_, err = txn.Stmt(s.updateSentSendToDeviceMessagesStmt).ExecContext(ctx, token, pq.Array(nids))
|
||||
_, err = sqlutil.TxStmt(txn, s.updateSentSendToDeviceMessagesStmt).ExecContext(ctx, token, pq.Array(nids))
|
||||
return
|
||||
}
|
||||
|
||||
func (s *sendToDeviceStatements) DeleteSendToDeviceMessages(
|
||||
ctx context.Context, txn *sql.Tx, nids []types.SendToDeviceNID,
|
||||
) (err error) {
|
||||
_, err = txn.Stmt(s.deleteSendToDeviceMessagesStmt).ExecContext(ctx, pq.Array(nids))
|
||||
_, err = sqlutil.TxStmt(txn, s.deleteSendToDeviceMessagesStmt).ExecContext(ctx, pq.Array(nids))
|
||||
return
|
||||
}
|
||||
|
|
|
@ -20,6 +20,7 @@ import (
|
|||
"database/sql"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
||||
"github.com/matrix-org/dendrite/syncapi/types"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
@ -85,7 +86,7 @@ func (s *accountDataStatements) InsertAccountData(
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, pos, userID, roomID, dataType)
|
||||
_, err = sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, pos, userID, roomID, dataType)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -147,7 +148,7 @@ func (s *accountDataStatements) SelectMaxAccountDataID(
|
|||
ctx context.Context, txn *sql.Tx,
|
||||
) (id int64, err error) {
|
||||
var nullableID sql.NullInt64
|
||||
err = txn.Stmt(s.selectMaxAccountDataIDStmt).QueryRowContext(ctx).Scan(&nullableID)
|
||||
err = sqlutil.TxStmt(txn, s.selectMaxAccountDataIDStmt).QueryRowContext(ctx).Scan(&nullableID)
|
||||
if nullableID.Valid {
|
||||
id = nullableID.Int64
|
||||
}
|
||||
|
|
|
@ -84,7 +84,7 @@ func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities
|
|||
func (s *backwardExtremitiesStatements) InsertsBackwardExtremity(
|
||||
ctx context.Context, txn *sql.Tx, roomID, eventID string, prevEventID string,
|
||||
) (err error) {
|
||||
_, err = txn.Stmt(s.insertBackwardExtremityStmt).ExecContext(ctx, roomID, eventID, prevEventID)
|
||||
_, err = sqlutil.TxStmt(txn, s.insertBackwardExtremityStmt).ExecContext(ctx, roomID, eventID, prevEventID)
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -113,7 +113,7 @@ func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom(
|
|||
func (s *backwardExtremitiesStatements) DeleteBackwardExtremity(
|
||||
ctx context.Context, txn *sql.Tx, roomID, knownEventID string,
|
||||
) (err error) {
|
||||
_, err = txn.Stmt(s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID)
|
||||
_, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID)
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@ import (
|
|||
"encoding/json"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
)
|
||||
|
||||
const accountDataSchema = `
|
||||
|
@ -75,7 +76,7 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
|
|||
func (s *accountDataStatements) insertAccountData(
|
||||
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage,
|
||||
) (err error) {
|
||||
stmt := txn.Stmt(s.insertAccountDataStmt)
|
||||
stmt := sqlutil.TxStmt(txn, s.insertAccountDataStmt)
|
||||
_, err = stmt.ExecContext(ctx, localpart, roomID, dataType, content)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -20,6 +20,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/userutil"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
||||
|
@ -99,7 +100,7 @@ func (s *accountsStatements) insertAccount(
|
|||
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string,
|
||||
) (*api.Account, error) {
|
||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||
stmt := txn.Stmt(s.insertAccountStmt)
|
||||
stmt := sqlutil.TxStmt(txn, s.insertAccountStmt)
|
||||
|
||||
var err error
|
||||
if appserviceID == "" {
|
||||
|
@ -162,7 +163,7 @@ func (s *accountsStatements) selectNewNumericLocalpart(
|
|||
) (id int64, err error) {
|
||||
stmt := s.selectNewNumericLocalpartStmt
|
||||
if txn != nil {
|
||||
stmt = txn.Stmt(stmt)
|
||||
stmt = sqlutil.TxStmt(txn, stmt)
|
||||
}
|
||||
err = stmt.QueryRowContext(ctx).Scan(&id)
|
||||
return
|
||||
|
|
|
@ -21,6 +21,7 @@ import (
|
|||
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
)
|
||||
|
||||
const profilesSchema = `
|
||||
|
@ -84,7 +85,7 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) {
|
|||
func (s *profilesStatements) insertProfile(
|
||||
ctx context.Context, txn *sql.Tx, localpart string,
|
||||
) (err error) {
|
||||
_, err = txn.Stmt(s.insertProfileStmt).ExecContext(ctx, localpart, "", "")
|
||||
_, err = sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "")
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -18,6 +18,8 @@ import (
|
|||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
)
|
||||
|
||||
const accountDataSchema = `
|
||||
|
@ -75,7 +77,7 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
|
|||
func (s *accountDataStatements) insertAccountData(
|
||||
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage,
|
||||
) error {
|
||||
_, err := txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content)
|
||||
_, err := sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content)
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/userutil"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
||||
|
@ -104,9 +105,9 @@ func (s *accountsStatements) insertAccount(
|
|||
|
||||
var err error
|
||||
if appserviceID == "" {
|
||||
_, err = txn.Stmt(stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil)
|
||||
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil)
|
||||
} else {
|
||||
_, err = txn.Stmt(stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID)
|
||||
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -163,7 +164,7 @@ func (s *accountsStatements) selectNewNumericLocalpart(
|
|||
) (id int64, err error) {
|
||||
stmt := s.selectNewNumericLocalpartStmt
|
||||
if txn != nil {
|
||||
stmt = txn.Stmt(stmt)
|
||||
stmt = sqlutil.TxStmt(txn, stmt)
|
||||
}
|
||||
err = stmt.QueryRowContext(ctx).Scan(&id)
|
||||
return
|
||||
|
|
|
@ -87,7 +87,7 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) {
|
|||
func (s *profilesStatements) insertProfile(
|
||||
ctx context.Context, txn *sql.Tx, localpart string,
|
||||
) error {
|
||||
_, err := txn.Stmt(s.insertProfileStmt).ExecContext(ctx, localpart, "", "")
|
||||
_, err := sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "")
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue