diff --git a/clientapi/auth/storage/accounts/sqlite3/filter_table.go b/clientapi/auth/storage/accounts/sqlite3/filter_table.go index 691ead775..7f1a0c249 100644 --- a/clientapi/auth/storage/accounts/sqlite3/filter_table.go +++ b/clientapi/auth/storage/accounts/sqlite3/filter_table.go @@ -18,6 +18,7 @@ import ( "context" "database/sql" "encoding/json" + "fmt" "github.com/matrix-org/gomatrixserverlib" ) @@ -47,14 +48,10 @@ const selectFilterIDByContentSQL = "" + const insertFilterSQL = "" + "INSERT INTO account_filter (filter, localpart) VALUES ($1, $2)" -const selectLastInsertedFilterIDSQL = "" + - "SELECT id FROM account_filter WHERE rowid = last_insert_rowid()" - type filterStatements struct { - selectFilterStmt *sql.Stmt - selectLastInsertedFilterIDStmt *sql.Stmt - selectFilterIDByContentStmt *sql.Stmt - insertFilterStmt *sql.Stmt + selectFilterStmt *sql.Stmt + selectFilterIDByContentStmt *sql.Stmt + insertFilterStmt *sql.Stmt } func (s *filterStatements) prepare(db *sql.DB) (err error) { @@ -65,9 +62,6 @@ func (s *filterStatements) prepare(db *sql.DB) (err error) { if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil { return } - if s.selectLastInsertedFilterIDStmt, err = db.Prepare(selectLastInsertedFilterIDSQL); err != nil { - return - } if s.selectFilterIDByContentStmt, err = db.Prepare(selectFilterIDByContentSQL); err != nil { return } @@ -128,12 +122,14 @@ func (s *filterStatements) insertFilter( } // Otherwise insert the filter and return the new ID - if _, err = s.insertFilterStmt.ExecContext(ctx, filterJSON, localpart); err != nil { + res, err := s.insertFilterStmt.ExecContext(ctx, filterJSON, localpart) + if err != nil { return "", err } - row := s.selectLastInsertedFilterIDStmt.QueryRowContext(ctx) - if err := row.Scan(&filterID); err != nil { + rowid, err := res.LastInsertId() + if err != nil { return "", err } + filterID = fmt.Sprintf("%d", rowid) return }