diff --git a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/filter_table.go b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/filter_table.go index c50cd1fd9..ee57fbf38 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/filter_table.go +++ b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/filter_table.go @@ -38,12 +38,16 @@ CREATE INDEX IF NOT EXISTS account_filter_localpart ON account_filter(localpart) const selectFilterSQL = "" + "SELECT filter FROM account_filter WHERE localpart = $1 AND id = $2" +const selectFilterByContentSQL = "" + + "SELECT filter FROM account_filter WHERE localpart = $1 AND filter = $2" + const insertFilterSQL = "" + "INSERT INTO account_filter (filter, id, localpart) VALUES ($1, DEFAULT, $2) RETURNING id" type filterStatements struct { - selectFilterStmt *sql.Stmt - insertFilterStmt *sql.Stmt + selectFilterStmt *sql.Stmt + selectFilterByContentStmt *sql.Stmt + insertFilterStmt *sql.Stmt } func (s *filterStatements) prepare(db *sql.DB) (err error) { @@ -54,6 +58,9 @@ func (s *filterStatements) prepare(db *sql.DB) (err error) { if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil { return } + if s.selectFilterByContentStmt, err = db.Prepare(selectFilterByContentSQL); err != nil { + return + } if s.insertFilterStmt, err = db.Prepare(insertFilterSQL); err != nil { return } @@ -70,6 +77,15 @@ func (s *filterStatements) selectFilter( func (s *filterStatements) insertFilter( ctx context.Context, filter string, localpart string, ) (pos string, err error) { + var existingFilter string + + // Check if filter already exists in the database + err = s.selectFilterByContentStmt.QueryRowContext(ctx, + localpart, filter).Scan(&existingFilter) + if existingFilter != "" { + return existingFilter, err + } + err = s.insertFilterStmt.QueryRowContext(ctx, filter, localpart).Scan(&pos) return }