Don't use more than 999 variables in SQLite querys.

Solve this problem in a more general and reusable way.
Also fix #1369
Add some unit tests.

Signed-off-by: Henrik Sölver <henrik.solver@gmail.com>
This commit is contained in:
Henrik Sölver 2020-09-05 14:32:23 +02:00
parent 913020e4b7
commit 68ef563b3c
5 changed files with 231 additions and 33 deletions

2
go.mod
View file

@ -1,6 +1,7 @@
module github.com/matrix-org/dendrite module github.com/matrix-org/dendrite
require ( require (
github.com/DATA-DOG/go-sqlmock v1.5.0
github.com/Shopify/sarama v1.27.0 github.com/Shopify/sarama v1.27.0
github.com/codahale/hdrhistogram v0.0.0-20161010025455-3a0bb77429bd // indirect github.com/codahale/hdrhistogram v0.0.0-20161010025455-3a0bb77429bd // indirect
github.com/gologme/log v1.2.0 github.com/gologme/log v1.2.0
@ -32,6 +33,7 @@ require (
github.com/pressly/goose v2.7.0-rc5+incompatible github.com/pressly/goose v2.7.0-rc5+incompatible
github.com/prometheus/client_golang v1.7.1 github.com/prometheus/client_golang v1.7.1
github.com/sirupsen/logrus v1.6.0 github.com/sirupsen/logrus v1.6.0
github.com/stretchr/testify v1.6.1
github.com/tidwall/gjson v1.6.1 github.com/tidwall/gjson v1.6.1
github.com/tidwall/sjson v1.1.1 github.com/tidwall/sjson v1.1.1
github.com/uber/jaeger-client-go v2.25.0+incompatible github.com/uber/jaeger-client-go v2.25.0+incompatible

2
go.sum
View file

@ -13,6 +13,8 @@ github.com/Arceliar/phony v0.0.0-20191006174943-d0c68492aca0 h1:p3puK8Sl2xK+2Fnn
github.com/Arceliar/phony v0.0.0-20191006174943-d0c68492aca0/go.mod h1:6Lkn+/zJilRMsKmbmG1RPoamiArC6HS73xbwRyp3UyI= github.com/Arceliar/phony v0.0.0-20191006174943-d0c68492aca0/go.mod h1:6Lkn+/zJilRMsKmbmG1RPoamiArC6HS73xbwRyp3UyI=
github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60=
github.com/DATA-DOG/go-sqlmock v1.5.0/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM=
github.com/Kubuxu/go-os-helper v0.0.1/go.mod h1:N8B+I7vPCT80IcP58r50u4+gEEcsZETFUpAzWW2ep1Y= github.com/Kubuxu/go-os-helper v0.0.1/go.mod h1:N8B+I7vPCT80IcP58r50u4+gEEcsZETFUpAzWW2ep1Y=
github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU=
github.com/PuerkitoBio/goquery v1.5.1/go.mod h1:GsLWisAFVj4WgDibEWF4pvYnkVQBpKBKeU+7zCJoLcc= github.com/PuerkitoBio/goquery v1.5.1/go.mod h1:GsLWisAFVj4WgDibEWF4pvYnkVQBpKBKeU+7zCJoLcc=

View file

@ -15,10 +15,14 @@
package sqlutil package sqlutil
import ( import (
"context"
"database/sql" "database/sql"
"errors" "errors"
"fmt" "fmt"
"runtime" "runtime"
"strings"
"github.com/matrix-org/util"
) )
// ErrUserExists is returned if a username already exists in the database. // ErrUserExists is returned if a username already exists in the database.
@ -107,3 +111,43 @@ func SQLiteDriverName() string {
} }
return "sqlite3" return "sqlite3"
} }
func minOfInts(a, b int) int {
if a <= b {
return a
}
return b
}
// QueryProvider defines the interface for querys used by RunLimitedVariablesQuery.
type QueryProvider interface {
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
}
// SQLite3MaxVariables is the default maximum number of host parameters in a single SQL statement
// SQLlite can handle. See https://www.sqlite.org/limits.html for more information.
const SQLite3MaxVariables = 999
// RunLimitedVariablesQuery split up a query with more variables than the used database can handle in multiple queries.
func RunLimitedVariablesQuery(ctx context.Context, query string, qp QueryProvider, rowHandler func(*sql.Rows) error, variables []interface{}, limit uint) error {
var start int
for start < len(variables) {
n := minOfInts(len(variables)-start, int(limit))
query := strings.Replace(query, "($1)", QueryVariadic(n), 1)
rows, err := qp.QueryContext(ctx, query, variables[start:start+n]...)
if err != nil {
return err
}
err = rowHandler(rows)
if err := rows.Close(); err != nil {
util.GetLogger(ctx).WithError(err).Error(err.Error())
return err
}
if err != nil {
util.GetLogger(ctx).WithError(err).Error(err.Error())
return err
}
start = start + n
}
return nil
}

View file

@ -0,0 +1,150 @@
package sqlutil
import (
"context"
"database/sql"
"testing"
sqlmock "github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
)
func TestShouldReturnCorrectAmountOfResulstIfFewerVariablesThanLimit(t *testing.T) {
db, mock, err := sqlmock.New()
assert.NoError(t, err)
limit := uint(4)
r := mock.NewRows([]string{"id"}).
AddRow(1).
AddRow(2).
AddRow(3)
mock.ExpectQuery("SELECT id WHERE id IN \\((\\$[0-9]{1,4},?\\s?){3}\\)").WillReturnRows(r)
q := "SELECT id WHERE id IN ($1)"
v := []int{1, 2, 3}
iKeyIDs := make([]interface{}, len(v))
for i, d := range v {
iKeyIDs[i] = d
}
ctx := context.Background()
var result = make([]int, 0)
err = RunLimitedVariablesQuery(ctx, q, db, func(rows *sql.Rows) error {
for rows.Next() {
var id int
err = rows.Scan(&id)
assert.NoError(t, err, "rows.Scan returned an error")
result = append(result, id)
}
return nil
}, iKeyIDs, limit)
assert.NoError(t, err, "Call returned an error")
assert.Len(t, result, len(v), "Result should be 3 long")
}
func TestShouldReturnCorrectAmountOfResulstIfEqualVariablesAsLimit(t *testing.T) {
db, mock, err := sqlmock.New()
assert.NoError(t, err)
limit := uint(4)
r := mock.NewRows([]string{"id"}).
AddRow(1).
AddRow(2).
AddRow(3).
AddRow(4)
mock.ExpectQuery("SELECT id WHERE id IN \\((\\$[0-9]{1,4},?\\s?){4}\\)").WillReturnRows(r)
q := "SELECT id WHERE id IN ($1)"
v := []int{1, 2, 3, 4}
iKeyIDs := make([]interface{}, len(v))
for i, d := range v {
iKeyIDs[i] = d
}
ctx := context.Background()
var result = make([]int, 0)
err = RunLimitedVariablesQuery(ctx, q, db, func(rows *sql.Rows) error {
for rows.Next() {
var id int
err = rows.Scan(&id)
assert.NoError(t, err, "rows.Scan returned an error")
result = append(result, id)
}
return nil
}, iKeyIDs, limit)
assert.NoError(t, err, "Call returned an error")
assert.Len(t, result, len(v), "Result should be 3 long")
}
func TestShouldReturnCorrectAmountOfResultsIfMoreVariablesThanLimit(t *testing.T) {
db, mock, err := sqlmock.New()
assert.NoError(t, err)
limit := uint(4)
r1 := mock.NewRows([]string{"id"}).
AddRow(1).
AddRow(2).
AddRow(3).
AddRow(4)
r2 := mock.NewRows([]string{"id"}).
AddRow(5)
mock.ExpectQuery("SELECT id WHERE id IN \\((\\$[0-9]{1,4},?\\s?){4}\\)").WillReturnRows(r1)
mock.ExpectQuery("SELECT id WHERE id IN \\((\\$[0-9]{1,4},?\\s?){1}\\)").WillReturnRows(r2)
q := "SELECT id WHERE id IN ($1)"
v := []int{1, 2, 3, 4, 5}
iKeyIDs := make([]interface{}, len(v))
for i, d := range v {
iKeyIDs[i] = d
}
ctx := context.Background()
var result = make([]int, 0)
err = RunLimitedVariablesQuery(ctx, q, db, func(rows *sql.Rows) error {
for rows.Next() {
var id int
err = rows.Scan(&id)
assert.NoError(t, err, "rows.Scan returned an error")
result = append(result, id)
}
return nil
}, iKeyIDs, limit)
assert.NoError(t, err, "Call returned an error")
assert.Equal(t, v, result, "Result is not as expected")
assert.Len(t, result, len(v), "Result should be 3 long")
}
func TestShouldREturnErrorIfRowsScanReturnsError(t *testing.T) {
db, mock, err := sqlmock.New()
assert.NoError(t, err)
limit := uint(4)
r := mock.NewRows([]string{"id"}).
AddRow("hej").
AddRow(2).
AddRow(3)
mock.ExpectQuery("SELECT id WHERE id IN \\((\\$[0-9]{1,4},?\\s?){3}\\)").WillReturnRows(r)
q := "SELECT id WHERE id IN ($1)"
v := []int{-1, -2, 3}
iKeyIDs := make([]interface{}, len(v))
for i, d := range v {
iKeyIDs[i] = d
}
ctx := context.Background()
var result = make([]uint, 0)
err = RunLimitedVariablesQuery(ctx, q, db, func(rows *sql.Rows) error {
for rows.Next() {
var id uint
err = rows.Scan(&id)
if err != nil {
return err
}
result = append(result, id)
}
return nil
}, iKeyIDs, limit)
assert.Error(t, err, "Call did not return an error")
}

View file

@ -18,9 +18,8 @@ package sqlite3
import ( import (
"context" "context"
"database/sql" "database/sql"
"strings" "fmt"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -88,41 +87,35 @@ func (s *serverKeyStatements) bulkSelectServerKeys(
ctx context.Context, ctx context.Context,
requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp,
) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) { ) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) {
var nameAndKeyIDs []string nameAndKeyIDs := make([]string, 0, len(requests))
for request := range requests { for request := range requests {
nameAndKeyIDs = append(nameAndKeyIDs, nameAndKeyID(request)) nameAndKeyIDs = append(nameAndKeyIDs, nameAndKeyID(request))
} }
results := make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, len(requests))
query := strings.Replace(bulkSelectServerKeysSQL, "($1)", sqlutil.QueryVariadic(len(nameAndKeyIDs)), 1)
iKeyIDs := make([]interface{}, len(nameAndKeyIDs)) iKeyIDs := make([]interface{}, len(nameAndKeyIDs))
for i, v := range nameAndKeyIDs { for i, v := range nameAndKeyIDs {
iKeyIDs[i] = v iKeyIDs[i] = v
} }
rows, err := s.db.QueryContext(ctx, query, iKeyIDs...) err := sqlutil.RunLimitedVariablesQuery(ctx, bulkSelectServerKeysSQL, s.db,
if err != nil { func(rows *sql.Rows) error {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectServerKeys: rows.close() failed")
results := map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult{}
for rows.Next() { for rows.Next() {
var serverName string var serverName string
var keyID string var keyID string
var key string var key string
var validUntilTS int64 var validUntilTS int64
var expiredTS int64 var expiredTS int64
if err = rows.Scan(&serverName, &keyID, &validUntilTS, &expiredTS, &key); err != nil { if err := rows.Scan(&serverName, &keyID, &validUntilTS, &expiredTS, &key); err != nil {
return nil, err return fmt.Errorf("bulkSelectServerKeys: %v", err)
} }
r := gomatrixserverlib.PublicKeyLookupRequest{ r := gomatrixserverlib.PublicKeyLookupRequest{
ServerName: gomatrixserverlib.ServerName(serverName), ServerName: gomatrixserverlib.ServerName(serverName),
KeyID: gomatrixserverlib.KeyID(keyID), KeyID: gomatrixserverlib.KeyID(keyID),
} }
vk := gomatrixserverlib.VerifyKey{} vk := gomatrixserverlib.VerifyKey{}
err = vk.Key.Decode(key) err := vk.Key.Decode(key)
if err != nil { if err != nil {
return nil, err return fmt.Errorf("bulkSelectServerKeys: %v", err)
} }
results[r] = gomatrixserverlib.PublicKeyLookupResult{ results[r] = gomatrixserverlib.PublicKeyLookupResult{
VerifyKey: vk, VerifyKey: vk,
@ -130,6 +123,13 @@ func (s *serverKeyStatements) bulkSelectServerKeys(
ExpiredTS: gomatrixserverlib.Timestamp(expiredTS), ExpiredTS: gomatrixserverlib.Timestamp(expiredTS),
} }
} }
return nil
},
iKeyIDs, sqlutil.SQLite3MaxVariables)
if err != nil {
return nil, err
}
return results, nil return results, nil
} }