mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-23 14:53:10 -06:00
Don't use more than 999 variables in SQLite querys.
Solve this problem in a more general and reusable way. Also fix #1369 Add some unit tests. Signed-off-by: Henrik Sölver <henrik.solver@gmail.com>
This commit is contained in:
parent
913020e4b7
commit
68ef563b3c
2
go.mod
2
go.mod
|
|
@ -1,6 +1,7 @@
|
|||
module github.com/matrix-org/dendrite
|
||||
|
||||
require (
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.0
|
||||
github.com/Shopify/sarama v1.27.0
|
||||
github.com/codahale/hdrhistogram v0.0.0-20161010025455-3a0bb77429bd // indirect
|
||||
github.com/gologme/log v1.2.0
|
||||
|
|
@ -32,6 +33,7 @@ require (
|
|||
github.com/pressly/goose v2.7.0-rc5+incompatible
|
||||
github.com/prometheus/client_golang v1.7.1
|
||||
github.com/sirupsen/logrus v1.6.0
|
||||
github.com/stretchr/testify v1.6.1
|
||||
github.com/tidwall/gjson v1.6.1
|
||||
github.com/tidwall/sjson v1.1.1
|
||||
github.com/uber/jaeger-client-go v2.25.0+incompatible
|
||||
|
|
|
|||
2
go.sum
2
go.sum
|
|
@ -13,6 +13,8 @@ github.com/Arceliar/phony v0.0.0-20191006174943-d0c68492aca0 h1:p3puK8Sl2xK+2Fnn
|
|||
github.com/Arceliar/phony v0.0.0-20191006174943-d0c68492aca0/go.mod h1:6Lkn+/zJilRMsKmbmG1RPoamiArC6HS73xbwRyp3UyI=
|
||||
github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ=
|
||||
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60=
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.0/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM=
|
||||
github.com/Kubuxu/go-os-helper v0.0.1/go.mod h1:N8B+I7vPCT80IcP58r50u4+gEEcsZETFUpAzWW2ep1Y=
|
||||
github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU=
|
||||
github.com/PuerkitoBio/goquery v1.5.1/go.mod h1:GsLWisAFVj4WgDibEWF4pvYnkVQBpKBKeU+7zCJoLcc=
|
||||
|
|
|
|||
|
|
@ -15,10 +15,14 @@
|
|||
package sqlutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
"github.com/matrix-org/util"
|
||||
)
|
||||
|
||||
// ErrUserExists is returned if a username already exists in the database.
|
||||
|
|
@ -107,3 +111,43 @@ func SQLiteDriverName() string {
|
|||
}
|
||||
return "sqlite3"
|
||||
}
|
||||
|
||||
func minOfInts(a, b int) int {
|
||||
if a <= b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// QueryProvider defines the interface for querys used by RunLimitedVariablesQuery.
|
||||
type QueryProvider interface {
|
||||
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
|
||||
}
|
||||
|
||||
// SQLite3MaxVariables is the default maximum number of host parameters in a single SQL statement
|
||||
// SQLlite can handle. See https://www.sqlite.org/limits.html for more information.
|
||||
const SQLite3MaxVariables = 999
|
||||
|
||||
// RunLimitedVariablesQuery split up a query with more variables than the used database can handle in multiple queries.
|
||||
func RunLimitedVariablesQuery(ctx context.Context, query string, qp QueryProvider, rowHandler func(*sql.Rows) error, variables []interface{}, limit uint) error {
|
||||
var start int
|
||||
for start < len(variables) {
|
||||
n := minOfInts(len(variables)-start, int(limit))
|
||||
query := strings.Replace(query, "($1)", QueryVariadic(n), 1)
|
||||
rows, err := qp.QueryContext(ctx, query, variables[start:start+n]...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = rowHandler(rows)
|
||||
if err := rows.Close(); err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error(err.Error())
|
||||
return err
|
||||
}
|
||||
if err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error(err.Error())
|
||||
return err
|
||||
}
|
||||
start = start + n
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
150
internal/sqlutil/sqlutil_test.go
Normal file
150
internal/sqlutil/sqlutil_test.go
Normal file
|
|
@ -0,0 +1,150 @@
|
|||
package sqlutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
sqlmock "github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestShouldReturnCorrectAmountOfResulstIfFewerVariablesThanLimit(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
assert.NoError(t, err)
|
||||
limit := uint(4)
|
||||
|
||||
r := mock.NewRows([]string{"id"}).
|
||||
AddRow(1).
|
||||
AddRow(2).
|
||||
AddRow(3)
|
||||
|
||||
mock.ExpectQuery("SELECT id WHERE id IN \\((\\$[0-9]{1,4},?\\s?){3}\\)").WillReturnRows(r)
|
||||
q := "SELECT id WHERE id IN ($1)"
|
||||
v := []int{1, 2, 3}
|
||||
iKeyIDs := make([]interface{}, len(v))
|
||||
for i, d := range v {
|
||||
iKeyIDs[i] = d
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
var result = make([]int, 0)
|
||||
err = RunLimitedVariablesQuery(ctx, q, db, func(rows *sql.Rows) error {
|
||||
for rows.Next() {
|
||||
var id int
|
||||
err = rows.Scan(&id)
|
||||
assert.NoError(t, err, "rows.Scan returned an error")
|
||||
result = append(result, id)
|
||||
}
|
||||
return nil
|
||||
}, iKeyIDs, limit)
|
||||
assert.NoError(t, err, "Call returned an error")
|
||||
assert.Len(t, result, len(v), "Result should be 3 long")
|
||||
}
|
||||
|
||||
func TestShouldReturnCorrectAmountOfResulstIfEqualVariablesAsLimit(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
assert.NoError(t, err)
|
||||
limit := uint(4)
|
||||
|
||||
r := mock.NewRows([]string{"id"}).
|
||||
AddRow(1).
|
||||
AddRow(2).
|
||||
AddRow(3).
|
||||
AddRow(4)
|
||||
|
||||
mock.ExpectQuery("SELECT id WHERE id IN \\((\\$[0-9]{1,4},?\\s?){4}\\)").WillReturnRows(r)
|
||||
q := "SELECT id WHERE id IN ($1)"
|
||||
v := []int{1, 2, 3, 4}
|
||||
iKeyIDs := make([]interface{}, len(v))
|
||||
for i, d := range v {
|
||||
iKeyIDs[i] = d
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
var result = make([]int, 0)
|
||||
err = RunLimitedVariablesQuery(ctx, q, db, func(rows *sql.Rows) error {
|
||||
for rows.Next() {
|
||||
var id int
|
||||
err = rows.Scan(&id)
|
||||
assert.NoError(t, err, "rows.Scan returned an error")
|
||||
result = append(result, id)
|
||||
}
|
||||
return nil
|
||||
}, iKeyIDs, limit)
|
||||
assert.NoError(t, err, "Call returned an error")
|
||||
assert.Len(t, result, len(v), "Result should be 3 long")
|
||||
}
|
||||
|
||||
func TestShouldReturnCorrectAmountOfResultsIfMoreVariablesThanLimit(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
assert.NoError(t, err)
|
||||
limit := uint(4)
|
||||
|
||||
r1 := mock.NewRows([]string{"id"}).
|
||||
AddRow(1).
|
||||
AddRow(2).
|
||||
AddRow(3).
|
||||
AddRow(4)
|
||||
|
||||
r2 := mock.NewRows([]string{"id"}).
|
||||
AddRow(5)
|
||||
|
||||
mock.ExpectQuery("SELECT id WHERE id IN \\((\\$[0-9]{1,4},?\\s?){4}\\)").WillReturnRows(r1)
|
||||
mock.ExpectQuery("SELECT id WHERE id IN \\((\\$[0-9]{1,4},?\\s?){1}\\)").WillReturnRows(r2)
|
||||
q := "SELECT id WHERE id IN ($1)"
|
||||
v := []int{1, 2, 3, 4, 5}
|
||||
iKeyIDs := make([]interface{}, len(v))
|
||||
for i, d := range v {
|
||||
iKeyIDs[i] = d
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
var result = make([]int, 0)
|
||||
err = RunLimitedVariablesQuery(ctx, q, db, func(rows *sql.Rows) error {
|
||||
for rows.Next() {
|
||||
var id int
|
||||
err = rows.Scan(&id)
|
||||
assert.NoError(t, err, "rows.Scan returned an error")
|
||||
result = append(result, id)
|
||||
}
|
||||
return nil
|
||||
}, iKeyIDs, limit)
|
||||
assert.NoError(t, err, "Call returned an error")
|
||||
assert.Equal(t, v, result, "Result is not as expected")
|
||||
assert.Len(t, result, len(v), "Result should be 3 long")
|
||||
}
|
||||
|
||||
func TestShouldREturnErrorIfRowsScanReturnsError(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
assert.NoError(t, err)
|
||||
limit := uint(4)
|
||||
|
||||
r := mock.NewRows([]string{"id"}).
|
||||
AddRow("hej").
|
||||
AddRow(2).
|
||||
AddRow(3)
|
||||
|
||||
mock.ExpectQuery("SELECT id WHERE id IN \\((\\$[0-9]{1,4},?\\s?){3}\\)").WillReturnRows(r)
|
||||
q := "SELECT id WHERE id IN ($1)"
|
||||
v := []int{-1, -2, 3}
|
||||
iKeyIDs := make([]interface{}, len(v))
|
||||
for i, d := range v {
|
||||
iKeyIDs[i] = d
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
var result = make([]uint, 0)
|
||||
err = RunLimitedVariablesQuery(ctx, q, db, func(rows *sql.Rows) error {
|
||||
for rows.Next() {
|
||||
var id uint
|
||||
err = rows.Scan(&id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
result = append(result, id)
|
||||
}
|
||||
return nil
|
||||
}, iKeyIDs, limit)
|
||||
assert.Error(t, err, "Call did not return an error")
|
||||
}
|
||||
|
|
@ -18,9 +18,8 @@ package sqlite3
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"strings"
|
||||
"fmt"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
|
@ -88,48 +87,49 @@ func (s *serverKeyStatements) bulkSelectServerKeys(
|
|||
ctx context.Context,
|
||||
requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp,
|
||||
) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) {
|
||||
var nameAndKeyIDs []string
|
||||
nameAndKeyIDs := make([]string, 0, len(requests))
|
||||
for request := range requests {
|
||||
nameAndKeyIDs = append(nameAndKeyIDs, nameAndKeyID(request))
|
||||
}
|
||||
|
||||
query := strings.Replace(bulkSelectServerKeysSQL, "($1)", sqlutil.QueryVariadic(len(nameAndKeyIDs)), 1)
|
||||
|
||||
results := make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, len(requests))
|
||||
iKeyIDs := make([]interface{}, len(nameAndKeyIDs))
|
||||
for i, v := range nameAndKeyIDs {
|
||||
iKeyIDs[i] = v
|
||||
}
|
||||
|
||||
rows, err := s.db.QueryContext(ctx, query, iKeyIDs...)
|
||||
err := sqlutil.RunLimitedVariablesQuery(ctx, bulkSelectServerKeysSQL, s.db,
|
||||
func(rows *sql.Rows) error {
|
||||
for rows.Next() {
|
||||
var serverName string
|
||||
var keyID string
|
||||
var key string
|
||||
var validUntilTS int64
|
||||
var expiredTS int64
|
||||
if err := rows.Scan(&serverName, &keyID, &validUntilTS, &expiredTS, &key); err != nil {
|
||||
return fmt.Errorf("bulkSelectServerKeys: %v", err)
|
||||
}
|
||||
r := gomatrixserverlib.PublicKeyLookupRequest{
|
||||
ServerName: gomatrixserverlib.ServerName(serverName),
|
||||
KeyID: gomatrixserverlib.KeyID(keyID),
|
||||
}
|
||||
vk := gomatrixserverlib.VerifyKey{}
|
||||
err := vk.Key.Decode(key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("bulkSelectServerKeys: %v", err)
|
||||
}
|
||||
results[r] = gomatrixserverlib.PublicKeyLookupResult{
|
||||
VerifyKey: vk,
|
||||
ValidUntilTS: gomatrixserverlib.Timestamp(validUntilTS),
|
||||
ExpiredTS: gomatrixserverlib.Timestamp(expiredTS),
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
iKeyIDs, sqlutil.SQLite3MaxVariables)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectServerKeys: rows.close() failed")
|
||||
results := map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult{}
|
||||
for rows.Next() {
|
||||
var serverName string
|
||||
var keyID string
|
||||
var key string
|
||||
var validUntilTS int64
|
||||
var expiredTS int64
|
||||
if err = rows.Scan(&serverName, &keyID, &validUntilTS, &expiredTS, &key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r := gomatrixserverlib.PublicKeyLookupRequest{
|
||||
ServerName: gomatrixserverlib.ServerName(serverName),
|
||||
KeyID: gomatrixserverlib.KeyID(keyID),
|
||||
}
|
||||
vk := gomatrixserverlib.VerifyKey{}
|
||||
err = vk.Key.Decode(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
results[r] = gomatrixserverlib.PublicKeyLookupResult{
|
||||
VerifyKey: vk,
|
||||
ValidUntilTS: gomatrixserverlib.Timestamp(validUntilTS),
|
||||
ExpiredTS: gomatrixserverlib.Timestamp(expiredTS),
|
||||
}
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue