mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-21 13:53:09 -06:00
Refactor TransactionWriter in user API
This commit is contained in:
parent
66d0134e2a
commit
2a9bb642ef
|
|
@ -57,9 +57,9 @@ type accountDataStatements struct {
|
||||||
selectAccountDataByTypeStmt *sql.Stmt
|
selectAccountDataByTypeStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
|
func (s *accountDataStatements) prepare(db *sql.DB, writer sqlutil.TransactionWriter) (err error) {
|
||||||
s.db = db
|
s.db = db
|
||||||
s.writer = sqlutil.NewTransactionWriter()
|
s.writer = writer
|
||||||
_, err = db.Exec(accountDataSchema)
|
_, err = db.Exec(accountDataSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
|
|
|
||||||
|
|
@ -67,9 +67,9 @@ type accountsStatements struct {
|
||||||
serverName gomatrixserverlib.ServerName
|
serverName gomatrixserverlib.ServerName
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
|
func (s *accountsStatements) prepare(db *sql.DB, writer sqlutil.TransactionWriter, server gomatrixserverlib.ServerName) (err error) {
|
||||||
s.db = db
|
s.db = db
|
||||||
s.writer = sqlutil.NewTransactionWriter()
|
s.writer = writer
|
||||||
_, err = db.Exec(accountsSchema)
|
_, err = db.Exec(accountsSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
|
|
|
||||||
|
|
@ -61,9 +61,9 @@ type profilesStatements struct {
|
||||||
selectProfilesBySearchStmt *sql.Stmt
|
selectProfilesBySearchStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *profilesStatements) prepare(db *sql.DB) (err error) {
|
func (s *profilesStatements) prepare(db *sql.DB, writer sqlutil.TransactionWriter) (err error) {
|
||||||
s.db = db
|
s.db = db
|
||||||
s.writer = sqlutil.NewTransactionWriter()
|
s.writer = writer
|
||||||
_, err = db.Exec(profilesSchema)
|
_, err = db.Exec(profilesSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
|
|
|
||||||
|
|
@ -64,16 +64,16 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
|
||||||
if err = partitions.Prepare(db, d.writer, "account"); err != nil {
|
if err = partitions.Prepare(db, d.writer, "account"); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if err = d.accounts.prepare(db, serverName); err != nil {
|
if err = d.accounts.prepare(db, d.writer, serverName); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if err = d.profiles.prepare(db); err != nil {
|
if err = d.profiles.prepare(db, d.writer); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if err = d.accountDatas.prepare(db); err != nil {
|
if err = d.accountDatas.prepare(db, d.writer); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if err = d.threepids.prepare(db); err != nil {
|
if err = d.threepids.prepare(db, d.writer); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return d, nil
|
return d, nil
|
||||||
|
|
|
||||||
|
|
@ -61,9 +61,9 @@ type threepidStatements struct {
|
||||||
deleteThreePIDStmt *sql.Stmt
|
deleteThreePIDStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *threepidStatements) prepare(db *sql.DB) (err error) {
|
func (s *threepidStatements) prepare(db *sql.DB, writer sqlutil.TransactionWriter) (err error) {
|
||||||
s.db = db
|
s.db = db
|
||||||
s.writer = sqlutil.NewTransactionWriter()
|
s.writer = writer
|
||||||
_, err = db.Exec(threepidSchema)
|
_, err = db.Exec(threepidSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
|
|
|
||||||
|
|
@ -91,9 +91,9 @@ type devicesStatements struct {
|
||||||
serverName gomatrixserverlib.ServerName
|
serverName gomatrixserverlib.ServerName
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
|
func (s *devicesStatements) prepare(db *sql.DB, writer sqlutil.TransactionWriter, server gomatrixserverlib.ServerName) (err error) {
|
||||||
s.db = db
|
s.db = db
|
||||||
s.writer = sqlutil.NewTransactionWriter()
|
s.writer = writer
|
||||||
_, err = db.Exec(devicesSchema)
|
_, err = db.Exec(devicesSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
|
|
|
||||||
|
|
@ -34,6 +34,7 @@ var deviceIDByteLength = 6
|
||||||
// Database represents a device database.
|
// Database represents a device database.
|
||||||
type Database struct {
|
type Database struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
|
writer sqlutil.TransactionWriter
|
||||||
devices devicesStatements
|
devices devicesStatements
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -43,11 +44,12 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
writer := sqlutil.NewTransactionWriter()
|
||||||
d := devicesStatements{}
|
d := devicesStatements{}
|
||||||
if err = d.prepare(db, serverName); err != nil {
|
if err = d.prepare(db, writer, serverName); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &Database{db, d}, nil
|
return &Database{db, writer, d}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetDeviceByAccessToken returns the device matching the given access token.
|
// GetDeviceByAccessToken returns the device matching the given access token.
|
||||||
|
|
@ -88,7 +90,7 @@ func (d *Database) CreateDevice(
|
||||||
displayName *string,
|
displayName *string,
|
||||||
) (dev *api.Device, returnErr error) {
|
) (dev *api.Device, returnErr error) {
|
||||||
if deviceID != nil {
|
if deviceID != nil {
|
||||||
returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||||
var err error
|
var err error
|
||||||
// Revoke existing tokens for this device
|
// Revoke existing tokens for this device
|
||||||
if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
|
if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
|
||||||
|
|
@ -108,7 +110,7 @@ func (d *Database) CreateDevice(
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||||
var err error
|
var err error
|
||||||
dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName)
|
dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName)
|
||||||
return err
|
return err
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue