diff --git a/internal/sqlutil/sql.go b/internal/sqlutil/sql.go index 88bfe4d4d..90562ded3 100644 --- a/internal/sqlutil/sql.go +++ b/internal/sqlutil/sql.go @@ -129,22 +129,23 @@ type QueryProvider interface { const SQLite3MaxVariables = 999 // RunLimitedVariablesQuery split up a query with more variables than the used database can handle in multiple queries. -func RunLimitedVariablesQuery(ctx context.Context, query string, qp QueryProvider, rowHandler func(*sql.Rows) error, variables []interface{}, limit uint) error { +func RunLimitedVariablesQuery(ctx context.Context, query string, qp QueryProvider, variables []interface{}, limit uint, rowHandler func(*sql.Rows) error) error { var start int for start < len(variables) { n := minOfInts(len(variables)-start, int(limit)) - query := strings.Replace(query, "($1)", QueryVariadic(n), 1) - rows, err := qp.QueryContext(ctx, query, variables[start:start+n]...) + nextQuery := strings.Replace(query, "($1)", QueryVariadic(n), 1) + rows, err := qp.QueryContext(ctx, nextQuery, variables[start:start+n]...) if err != nil { + util.GetLogger(ctx).WithError(err).Error("QueryContext returned an error") return err } err = rowHandler(rows) - if err := rows.Close(); err != nil { - util.GetLogger(ctx).WithError(err).Error(err.Error()) + if closeErr := rows.Close(); closeErr != nil { + util.GetLogger(ctx).WithError(closeErr).Error("RunLimitedVariablesQuery: failed to close rows") return err } if err != nil { - util.GetLogger(ctx).WithError(err).Error(err.Error()) + util.GetLogger(ctx).WithError(err).Error("RunLimitedVariablesQuery: rowHandler returned error") return err } start = start + n diff --git a/internal/sqlutil/sqlutil_test.go b/internal/sqlutil/sqlutil_test.go index 880f6a898..79469cddc 100644 --- a/internal/sqlutil/sqlutil_test.go +++ b/internal/sqlutil/sqlutil_test.go @@ -19,7 +19,8 @@ func TestShouldReturnCorrectAmountOfResulstIfFewerVariablesThanLimit(t *testing. AddRow(2). AddRow(3) - mock.ExpectQuery("SELECT id WHERE id IN \\((\\$[0-9]{1,4},?\\s?){3}\\)").WillReturnRows(r) + mock.ExpectQuery(`SELECT id WHERE id IN \(\$1, \$2, \$3\)`).WillReturnRows(r) + // nolint:goconst q := "SELECT id WHERE id IN ($1)" v := []int{1, 2, 3} iKeyIDs := make([]interface{}, len(v)) @@ -29,7 +30,7 @@ func TestShouldReturnCorrectAmountOfResulstIfFewerVariablesThanLimit(t *testing. ctx := context.Background() var result = make([]int, 0) - err = RunLimitedVariablesQuery(ctx, q, db, func(rows *sql.Rows) error { + err = RunLimitedVariablesQuery(ctx, q, db, iKeyIDs, limit, func(rows *sql.Rows) error { for rows.Next() { var id int err = rows.Scan(&id) @@ -37,7 +38,7 @@ func TestShouldReturnCorrectAmountOfResulstIfFewerVariablesThanLimit(t *testing. result = append(result, id) } return nil - }, iKeyIDs, limit) + }) assertNoError(t, err, "Call returned an error") if len(result) != len(v) { t.Fatalf("Result should be 3 long") @@ -55,7 +56,8 @@ func TestShouldReturnCorrectAmountOfResulstIfEqualVariablesAsLimit(t *testing.T) AddRow(3). AddRow(4) - mock.ExpectQuery("SELECT id WHERE id IN \\((\\$[0-9]{1,4},?\\s?){4}\\)").WillReturnRows(r) + mock.ExpectQuery(`SELECT id WHERE id IN \(\$1, \$2, \$3, \$4\)`).WillReturnRows(r) + // nolint:goconst q := "SELECT id WHERE id IN ($1)" v := []int{1, 2, 3, 4} iKeyIDs := make([]interface{}, len(v)) @@ -65,7 +67,7 @@ func TestShouldReturnCorrectAmountOfResulstIfEqualVariablesAsLimit(t *testing.T) ctx := context.Background() var result = make([]int, 0) - err = RunLimitedVariablesQuery(ctx, q, db, func(rows *sql.Rows) error { + err = RunLimitedVariablesQuery(ctx, q, db, iKeyIDs, limit, func(rows *sql.Rows) error { for rows.Next() { var id int err = rows.Scan(&id) @@ -73,7 +75,7 @@ func TestShouldReturnCorrectAmountOfResulstIfEqualVariablesAsLimit(t *testing.T) result = append(result, id) } return nil - }, iKeyIDs, limit) + }) assertNoError(t, err, "Call returned an error") if len(result) != len(v) { t.Fatalf("Result should be 4 long") @@ -94,8 +96,9 @@ func TestShouldReturnCorrectAmountOfResultsIfMoreVariablesThanLimit(t *testing.T r2 := mock.NewRows([]string{"id"}). AddRow(5) - mock.ExpectQuery("SELECT id WHERE id IN \\((\\$[0-9]{1,4},?\\s?){4}\\)").WillReturnRows(r1) - mock.ExpectQuery("SELECT id WHERE id IN \\((\\$[0-9]{1,4},?\\s?){1}\\)").WillReturnRows(r2) + mock.ExpectQuery(`SELECT id WHERE id IN \(\$1, \$2, \$3, \$4\)`).WillReturnRows(r1) + mock.ExpectQuery(`SELECT id WHERE id IN \(\$1\)`).WillReturnRows(r2) + // nolint:goconst q := "SELECT id WHERE id IN ($1)" v := []int{1, 2, 3, 4, 5} iKeyIDs := make([]interface{}, len(v)) @@ -105,7 +108,7 @@ func TestShouldReturnCorrectAmountOfResultsIfMoreVariablesThanLimit(t *testing.T ctx := context.Background() var result = make([]int, 0) - err = RunLimitedVariablesQuery(ctx, q, db, func(rows *sql.Rows) error { + err = RunLimitedVariablesQuery(ctx, q, db, iKeyIDs, limit, func(rows *sql.Rows) error { for rows.Next() { var id int err = rows.Scan(&id) @@ -113,7 +116,7 @@ func TestShouldReturnCorrectAmountOfResultsIfMoreVariablesThanLimit(t *testing.T result = append(result, id) } return nil - }, iKeyIDs, limit) + }) assertNoError(t, err, "Call returned an error") if len(result) != len(v) { t.Fatalf("Result should be 5 long") @@ -123,7 +126,7 @@ func TestShouldReturnCorrectAmountOfResultsIfMoreVariablesThanLimit(t *testing.T } } -func TestShouldREturnErrorIfRowsScanReturnsError(t *testing.T) { +func TestShouldReturnErrorIfRowsScanReturnsError(t *testing.T) { db, mock, err := sqlmock.New() assertNoError(t, err, "Failed to make DB") limit := uint(4) @@ -134,7 +137,8 @@ func TestShouldREturnErrorIfRowsScanReturnsError(t *testing.T) { AddRow(2). AddRow(3) - mock.ExpectQuery("SELECT id WHERE id IN \\((\\$[0-9]{1,4},?\\s?){3}\\)").WillReturnRows(r) + mock.ExpectQuery(`SELECT id WHERE id IN \(\$1, \$2, \$3\)`).WillReturnRows(r) + // nolint:goconst q := "SELECT id WHERE id IN ($1)" v := []int{-1, -2, 3} iKeyIDs := make([]interface{}, len(v)) @@ -144,7 +148,7 @@ func TestShouldREturnErrorIfRowsScanReturnsError(t *testing.T) { ctx := context.Background() var result = make([]uint, 0) - err = RunLimitedVariablesQuery(ctx, q, db, func(rows *sql.Rows) error { + err = RunLimitedVariablesQuery(ctx, q, db, iKeyIDs, limit, func(rows *sql.Rows) error { for rows.Next() { var id uint err = rows.Scan(&id) @@ -154,7 +158,7 @@ func TestShouldREturnErrorIfRowsScanReturnsError(t *testing.T) { result = append(result, id) } return nil - }, iKeyIDs, limit) + }) if err == nil { t.Fatalf("Call did not return an error") } diff --git a/serverkeyapi/storage/sqlite3/server_key_table.go b/serverkeyapi/storage/sqlite3/server_key_table.go index efea9af7e..2484d6368 100644 --- a/serverkeyapi/storage/sqlite3/server_key_table.go +++ b/serverkeyapi/storage/sqlite3/server_key_table.go @@ -97,7 +97,8 @@ func (s *serverKeyStatements) bulkSelectServerKeys( iKeyIDs[i] = v } - err := sqlutil.RunLimitedVariablesQuery(ctx, bulkSelectServerKeysSQL, s.db, + err := sqlutil.RunLimitedVariablesQuery( + ctx, bulkSelectServerKeysSQL, s.db, iKeyIDs, sqlutil.SQLite3MaxVariables, func(rows *sql.Rows) error { for rows.Next() { var serverName string @@ -125,7 +126,7 @@ func (s *serverKeyStatements) bulkSelectServerKeys( } return nil }, - iKeyIDs, sqlutil.SQLite3MaxVariables) + ) if err != nil { return nil, err