From 22f028e141297bcd8b1230573e55d5790e0d67a4 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 5 Aug 2020 10:00:35 +0100 Subject: [PATCH] SelectJoinedHostsForRooms should use QueryVariadic on SQLite (#1238) * SelectJoinedHostsForRooms should use QueryVariadic on SQLite * Fix strings.Replace * Fix statement --- .../storage/sqlite3/joined_hosts_table.go | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/federationsender/storage/sqlite3/joined_hosts_table.go b/federationsender/storage/sqlite3/joined_hosts_table.go index 4ae980d7c..53736fa16 100644 --- a/federationsender/storage/sqlite3/joined_hosts_table.go +++ b/federationsender/storage/sqlite3/joined_hosts_table.go @@ -18,6 +18,7 @@ package sqlite3 import ( "context" "database/sql" + "strings" "github.com/matrix-org/dendrite/federationsender/types" "github.com/matrix-org/dendrite/internal" @@ -63,13 +64,13 @@ const selectJoinedHostsForRoomsSQL = "" + "SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id IN ($1)" type joinedHostsStatements struct { - db *sql.DB - writer *sqlutil.TransactionWriter - insertJoinedHostsStmt *sql.Stmt - deleteJoinedHostsStmt *sql.Stmt - selectJoinedHostsStmt *sql.Stmt - selectAllJoinedHostsStmt *sql.Stmt - selectJoinedHostsForRoomsStmt *sql.Stmt + db *sql.DB + writer *sqlutil.TransactionWriter + insertJoinedHostsStmt *sql.Stmt + deleteJoinedHostsStmt *sql.Stmt + selectJoinedHostsStmt *sql.Stmt + selectAllJoinedHostsStmt *sql.Stmt + // selectJoinedHostsForRoomsStmt *sql.Stmt - prepared at runtime due to variadic } func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) { @@ -93,9 +94,6 @@ func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) if s.selectAllJoinedHostsStmt, err = db.Prepare(selectAllJoinedHostsSQL); err != nil { return } - if s.selectJoinedHostsForRoomsStmt, err = db.Prepare(selectJoinedHostsForRoomsSQL); err != nil { - return - } return } @@ -168,7 +166,8 @@ func (s *joinedHostsStatements) SelectJoinedHostsForRooms( iRoomIDs[i] = roomIDs[i] } - rows, err := s.selectJoinedHostsForRoomsStmt.QueryContext(ctx, iRoomIDs...) + sql := strings.Replace(selectJoinedHostsForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomIDs)), 1) + rows, err := s.db.QueryContext(ctx, sql, iRoomIDs...) if err != nil { return nil, err }