Use LastInsertId because empirically it works over the SELECT form (though I don't know why that is)

This commit is contained in:
Kegan Dougal 2020-02-18 16:24:06 +00:00
parent 5ee98db1a5
commit bc37106e14

View file

@ -18,6 +18,7 @@ import (
"context" "context"
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"fmt"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -47,14 +48,10 @@ const selectFilterIDByContentSQL = "" +
const insertFilterSQL = "" + const insertFilterSQL = "" +
"INSERT INTO account_filter (filter, localpart) VALUES ($1, $2)" "INSERT INTO account_filter (filter, localpart) VALUES ($1, $2)"
const selectLastInsertedFilterIDSQL = "" +
"SELECT id FROM account_filter WHERE rowid = last_insert_rowid()"
type filterStatements struct { type filterStatements struct {
selectFilterStmt *sql.Stmt selectFilterStmt *sql.Stmt
selectLastInsertedFilterIDStmt *sql.Stmt selectFilterIDByContentStmt *sql.Stmt
selectFilterIDByContentStmt *sql.Stmt insertFilterStmt *sql.Stmt
insertFilterStmt *sql.Stmt
} }
func (s *filterStatements) prepare(db *sql.DB) (err error) { func (s *filterStatements) prepare(db *sql.DB) (err error) {
@ -65,9 +62,6 @@ func (s *filterStatements) prepare(db *sql.DB) (err error) {
if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil { if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil {
return return
} }
if s.selectLastInsertedFilterIDStmt, err = db.Prepare(selectLastInsertedFilterIDSQL); err != nil {
return
}
if s.selectFilterIDByContentStmt, err = db.Prepare(selectFilterIDByContentSQL); err != nil { if s.selectFilterIDByContentStmt, err = db.Prepare(selectFilterIDByContentSQL); err != nil {
return return
} }
@ -128,12 +122,14 @@ func (s *filterStatements) insertFilter(
} }
// Otherwise insert the filter and return the new ID // Otherwise insert the filter and return the new ID
if _, err = s.insertFilterStmt.ExecContext(ctx, filterJSON, localpart); err != nil { res, err := s.insertFilterStmt.ExecContext(ctx, filterJSON, localpart)
if err != nil {
return "", err return "", err
} }
row := s.selectLastInsertedFilterIDStmt.QueryRowContext(ctx) rowid, err := res.LastInsertId()
if err := row.Scan(&filterID); err != nil { if err != nil {
return "", err return "", err
} }
filterID = fmt.Sprintf("%d", rowid)
return return
} }