mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-16 11:23:11 -06:00
Use LastInsertId because empirically it works over the SELECT form (though I don't know why that is)
This commit is contained in:
parent
5ee98db1a5
commit
bc37106e14
|
|
@ -18,6 +18,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
@ -47,14 +48,10 @@ const selectFilterIDByContentSQL = "" +
|
||||||
const insertFilterSQL = "" +
|
const insertFilterSQL = "" +
|
||||||
"INSERT INTO account_filter (filter, localpart) VALUES ($1, $2)"
|
"INSERT INTO account_filter (filter, localpart) VALUES ($1, $2)"
|
||||||
|
|
||||||
const selectLastInsertedFilterIDSQL = "" +
|
|
||||||
"SELECT id FROM account_filter WHERE rowid = last_insert_rowid()"
|
|
||||||
|
|
||||||
type filterStatements struct {
|
type filterStatements struct {
|
||||||
selectFilterStmt *sql.Stmt
|
selectFilterStmt *sql.Stmt
|
||||||
selectLastInsertedFilterIDStmt *sql.Stmt
|
selectFilterIDByContentStmt *sql.Stmt
|
||||||
selectFilterIDByContentStmt *sql.Stmt
|
insertFilterStmt *sql.Stmt
|
||||||
insertFilterStmt *sql.Stmt
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *filterStatements) prepare(db *sql.DB) (err error) {
|
func (s *filterStatements) prepare(db *sql.DB) (err error) {
|
||||||
|
|
@ -65,9 +62,6 @@ func (s *filterStatements) prepare(db *sql.DB) (err error) {
|
||||||
if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil {
|
if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if s.selectLastInsertedFilterIDStmt, err = db.Prepare(selectLastInsertedFilterIDSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.selectFilterIDByContentStmt, err = db.Prepare(selectFilterIDByContentSQL); err != nil {
|
if s.selectFilterIDByContentStmt, err = db.Prepare(selectFilterIDByContentSQL); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
@ -128,12 +122,14 @@ func (s *filterStatements) insertFilter(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Otherwise insert the filter and return the new ID
|
// Otherwise insert the filter and return the new ID
|
||||||
if _, err = s.insertFilterStmt.ExecContext(ctx, filterJSON, localpart); err != nil {
|
res, err := s.insertFilterStmt.ExecContext(ctx, filterJSON, localpart)
|
||||||
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
row := s.selectLastInsertedFilterIDStmt.QueryRowContext(ctx)
|
rowid, err := res.LastInsertId()
|
||||||
if err := row.Scan(&filterID); err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
filterID = fmt.Sprintf("%d", rowid)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue