Use LastInsertId because empirically it works over the SELECT form (though I don't know why that is)

This commit is contained in:
Kegan Dougal 2020-02-18 16:24:06 +00:00
parent 5ee98db1a5
commit bc37106e14

View file

@ -18,6 +18,7 @@ import (
"context"
"database/sql"
"encoding/json"
"fmt"
"github.com/matrix-org/gomatrixserverlib"
)
@ -47,14 +48,10 @@ const selectFilterIDByContentSQL = "" +
const insertFilterSQL = "" +
"INSERT INTO account_filter (filter, localpart) VALUES ($1, $2)"
const selectLastInsertedFilterIDSQL = "" +
"SELECT id FROM account_filter WHERE rowid = last_insert_rowid()"
type filterStatements struct {
selectFilterStmt *sql.Stmt
selectLastInsertedFilterIDStmt *sql.Stmt
selectFilterIDByContentStmt *sql.Stmt
insertFilterStmt *sql.Stmt
selectFilterStmt *sql.Stmt
selectFilterIDByContentStmt *sql.Stmt
insertFilterStmt *sql.Stmt
}
func (s *filterStatements) prepare(db *sql.DB) (err error) {
@ -65,9 +62,6 @@ func (s *filterStatements) prepare(db *sql.DB) (err error) {
if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil {
return
}
if s.selectLastInsertedFilterIDStmt, err = db.Prepare(selectLastInsertedFilterIDSQL); err != nil {
return
}
if s.selectFilterIDByContentStmt, err = db.Prepare(selectFilterIDByContentSQL); err != nil {
return
}
@ -128,12 +122,14 @@ func (s *filterStatements) insertFilter(
}
// Otherwise insert the filter and return the new ID
if _, err = s.insertFilterStmt.ExecContext(ctx, filterJSON, localpart); err != nil {
res, err := s.insertFilterStmt.ExecContext(ctx, filterJSON, localpart)
if err != nil {
return "", err
}
row := s.selectLastInsertedFilterIDStmt.QueryRowContext(ctx)
if err := row.Scan(&filterID); err != nil {
rowid, err := res.LastInsertId()
if err != nil {
return "", err
}
filterID = fmt.Sprintf("%d", rowid)
return
}