diff --git a/go.mod b/go.mod index f1cb3c9be..e931eecdc 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,7 @@ module github.com/matrix-org/dendrite require ( + github.com/DATA-DOG/go-sqlmock v1.5.0 github.com/Shopify/sarama v1.27.0 github.com/codahale/hdrhistogram v0.0.0-20161010025455-3a0bb77429bd // indirect github.com/gologme/log v1.2.0 @@ -32,6 +33,7 @@ require ( github.com/pressly/goose v2.7.0-rc5+incompatible github.com/prometheus/client_golang v1.7.1 github.com/sirupsen/logrus v1.6.0 + github.com/stretchr/testify v1.6.1 github.com/tidwall/gjson v1.6.1 github.com/tidwall/sjson v1.1.1 github.com/uber/jaeger-client-go v2.25.0+incompatible diff --git a/go.sum b/go.sum index ac7827d9f..5c4f27a5d 100644 --- a/go.sum +++ b/go.sum @@ -13,6 +13,8 @@ github.com/Arceliar/phony v0.0.0-20191006174943-d0c68492aca0 h1:p3puK8Sl2xK+2Fnn github.com/Arceliar/phony v0.0.0-20191006174943-d0c68492aca0/go.mod h1:6Lkn+/zJilRMsKmbmG1RPoamiArC6HS73xbwRyp3UyI= github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60= +github.com/DATA-DOG/go-sqlmock v1.5.0/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM= github.com/Kubuxu/go-os-helper v0.0.1/go.mod h1:N8B+I7vPCT80IcP58r50u4+gEEcsZETFUpAzWW2ep1Y= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/PuerkitoBio/goquery v1.5.1/go.mod h1:GsLWisAFVj4WgDibEWF4pvYnkVQBpKBKeU+7zCJoLcc= diff --git a/internal/sqlutil/sql.go b/internal/sqlutil/sql.go index 1d2825d56..88bfe4d4d 100644 --- a/internal/sqlutil/sql.go +++ b/internal/sqlutil/sql.go @@ -15,10 +15,14 @@ package sqlutil import ( + "context" "database/sql" "errors" "fmt" "runtime" + "strings" + + "github.com/matrix-org/util" ) // ErrUserExists is returned if a username already exists in the database. @@ -107,3 +111,43 @@ func SQLiteDriverName() string { } return "sqlite3" } + +func minOfInts(a, b int) int { + if a <= b { + return a + } + return b +} + +// QueryProvider defines the interface for querys used by RunLimitedVariablesQuery. +type QueryProvider interface { + QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) +} + +// SQLite3MaxVariables is the default maximum number of host parameters in a single SQL statement +// SQLlite can handle. See https://www.sqlite.org/limits.html for more information. +const SQLite3MaxVariables = 999 + +// RunLimitedVariablesQuery split up a query with more variables than the used database can handle in multiple queries. +func RunLimitedVariablesQuery(ctx context.Context, query string, qp QueryProvider, rowHandler func(*sql.Rows) error, variables []interface{}, limit uint) error { + var start int + for start < len(variables) { + n := minOfInts(len(variables)-start, int(limit)) + query := strings.Replace(query, "($1)", QueryVariadic(n), 1) + rows, err := qp.QueryContext(ctx, query, variables[start:start+n]...) + if err != nil { + return err + } + err = rowHandler(rows) + if err := rows.Close(); err != nil { + util.GetLogger(ctx).WithError(err).Error(err.Error()) + return err + } + if err != nil { + util.GetLogger(ctx).WithError(err).Error(err.Error()) + return err + } + start = start + n + } + return nil +} diff --git a/internal/sqlutil/sqlutil_test.go b/internal/sqlutil/sqlutil_test.go new file mode 100644 index 000000000..1191d2b7e --- /dev/null +++ b/internal/sqlutil/sqlutil_test.go @@ -0,0 +1,150 @@ +package sqlutil + +import ( + "context" + "database/sql" + "testing" + + sqlmock "github.com/DATA-DOG/go-sqlmock" + "github.com/stretchr/testify/assert" +) + +func TestShouldReturnCorrectAmountOfResulstIfFewerVariablesThanLimit(t *testing.T) { + db, mock, err := sqlmock.New() + assert.NoError(t, err) + limit := uint(4) + + r := mock.NewRows([]string{"id"}). + AddRow(1). + AddRow(2). + AddRow(3) + + mock.ExpectQuery("SELECT id WHERE id IN \\((\\$[0-9]{1,4},?\\s?){3}\\)").WillReturnRows(r) + q := "SELECT id WHERE id IN ($1)" + v := []int{1, 2, 3} + iKeyIDs := make([]interface{}, len(v)) + for i, d := range v { + iKeyIDs[i] = d + } + + ctx := context.Background() + var result = make([]int, 0) + err = RunLimitedVariablesQuery(ctx, q, db, func(rows *sql.Rows) error { + for rows.Next() { + var id int + err = rows.Scan(&id) + assert.NoError(t, err, "rows.Scan returned an error") + result = append(result, id) + } + return nil + }, iKeyIDs, limit) + assert.NoError(t, err, "Call returned an error") + assert.Len(t, result, len(v), "Result should be 3 long") +} + +func TestShouldReturnCorrectAmountOfResulstIfEqualVariablesAsLimit(t *testing.T) { + db, mock, err := sqlmock.New() + assert.NoError(t, err) + limit := uint(4) + + r := mock.NewRows([]string{"id"}). + AddRow(1). + AddRow(2). + AddRow(3). + AddRow(4) + + mock.ExpectQuery("SELECT id WHERE id IN \\((\\$[0-9]{1,4},?\\s?){4}\\)").WillReturnRows(r) + q := "SELECT id WHERE id IN ($1)" + v := []int{1, 2, 3, 4} + iKeyIDs := make([]interface{}, len(v)) + for i, d := range v { + iKeyIDs[i] = d + } + + ctx := context.Background() + var result = make([]int, 0) + err = RunLimitedVariablesQuery(ctx, q, db, func(rows *sql.Rows) error { + for rows.Next() { + var id int + err = rows.Scan(&id) + assert.NoError(t, err, "rows.Scan returned an error") + result = append(result, id) + } + return nil + }, iKeyIDs, limit) + assert.NoError(t, err, "Call returned an error") + assert.Len(t, result, len(v), "Result should be 3 long") +} + +func TestShouldReturnCorrectAmountOfResultsIfMoreVariablesThanLimit(t *testing.T) { + db, mock, err := sqlmock.New() + assert.NoError(t, err) + limit := uint(4) + + r1 := mock.NewRows([]string{"id"}). + AddRow(1). + AddRow(2). + AddRow(3). + AddRow(4) + + r2 := mock.NewRows([]string{"id"}). + AddRow(5) + + mock.ExpectQuery("SELECT id WHERE id IN \\((\\$[0-9]{1,4},?\\s?){4}\\)").WillReturnRows(r1) + mock.ExpectQuery("SELECT id WHERE id IN \\((\\$[0-9]{1,4},?\\s?){1}\\)").WillReturnRows(r2) + q := "SELECT id WHERE id IN ($1)" + v := []int{1, 2, 3, 4, 5} + iKeyIDs := make([]interface{}, len(v)) + for i, d := range v { + iKeyIDs[i] = d + } + + ctx := context.Background() + var result = make([]int, 0) + err = RunLimitedVariablesQuery(ctx, q, db, func(rows *sql.Rows) error { + for rows.Next() { + var id int + err = rows.Scan(&id) + assert.NoError(t, err, "rows.Scan returned an error") + result = append(result, id) + } + return nil + }, iKeyIDs, limit) + assert.NoError(t, err, "Call returned an error") + assert.Equal(t, v, result, "Result is not as expected") + assert.Len(t, result, len(v), "Result should be 3 long") +} + +func TestShouldREturnErrorIfRowsScanReturnsError(t *testing.T) { + db, mock, err := sqlmock.New() + assert.NoError(t, err) + limit := uint(4) + + r := mock.NewRows([]string{"id"}). + AddRow("hej"). + AddRow(2). + AddRow(3) + + mock.ExpectQuery("SELECT id WHERE id IN \\((\\$[0-9]{1,4},?\\s?){3}\\)").WillReturnRows(r) + q := "SELECT id WHERE id IN ($1)" + v := []int{-1, -2, 3} + iKeyIDs := make([]interface{}, len(v)) + for i, d := range v { + iKeyIDs[i] = d + } + + ctx := context.Background() + var result = make([]uint, 0) + err = RunLimitedVariablesQuery(ctx, q, db, func(rows *sql.Rows) error { + for rows.Next() { + var id uint + err = rows.Scan(&id) + if err != nil { + return err + } + result = append(result, id) + } + return nil + }, iKeyIDs, limit) + assert.Error(t, err, "Call did not return an error") +} diff --git a/serverkeyapi/storage/sqlite3/server_key_table.go b/serverkeyapi/storage/sqlite3/server_key_table.go index f756ef5e3..efea9af7e 100644 --- a/serverkeyapi/storage/sqlite3/server_key_table.go +++ b/serverkeyapi/storage/sqlite3/server_key_table.go @@ -18,9 +18,8 @@ package sqlite3 import ( "context" "database/sql" - "strings" + "fmt" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" ) @@ -88,48 +87,49 @@ func (s *serverKeyStatements) bulkSelectServerKeys( ctx context.Context, requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, ) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) { - var nameAndKeyIDs []string + nameAndKeyIDs := make([]string, 0, len(requests)) for request := range requests { nameAndKeyIDs = append(nameAndKeyIDs, nameAndKeyID(request)) } - - query := strings.Replace(bulkSelectServerKeysSQL, "($1)", sqlutil.QueryVariadic(len(nameAndKeyIDs)), 1) - + results := make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, len(requests)) iKeyIDs := make([]interface{}, len(nameAndKeyIDs)) for i, v := range nameAndKeyIDs { iKeyIDs[i] = v } - rows, err := s.db.QueryContext(ctx, query, iKeyIDs...) + err := sqlutil.RunLimitedVariablesQuery(ctx, bulkSelectServerKeysSQL, s.db, + func(rows *sql.Rows) error { + for rows.Next() { + var serverName string + var keyID string + var key string + var validUntilTS int64 + var expiredTS int64 + if err := rows.Scan(&serverName, &keyID, &validUntilTS, &expiredTS, &key); err != nil { + return fmt.Errorf("bulkSelectServerKeys: %v", err) + } + r := gomatrixserverlib.PublicKeyLookupRequest{ + ServerName: gomatrixserverlib.ServerName(serverName), + KeyID: gomatrixserverlib.KeyID(keyID), + } + vk := gomatrixserverlib.VerifyKey{} + err := vk.Key.Decode(key) + if err != nil { + return fmt.Errorf("bulkSelectServerKeys: %v", err) + } + results[r] = gomatrixserverlib.PublicKeyLookupResult{ + VerifyKey: vk, + ValidUntilTS: gomatrixserverlib.Timestamp(validUntilTS), + ExpiredTS: gomatrixserverlib.Timestamp(expiredTS), + } + } + return nil + }, + iKeyIDs, sqlutil.SQLite3MaxVariables) + if err != nil { return nil, err } - defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectServerKeys: rows.close() failed") - results := map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult{} - for rows.Next() { - var serverName string - var keyID string - var key string - var validUntilTS int64 - var expiredTS int64 - if err = rows.Scan(&serverName, &keyID, &validUntilTS, &expiredTS, &key); err != nil { - return nil, err - } - r := gomatrixserverlib.PublicKeyLookupRequest{ - ServerName: gomatrixserverlib.ServerName(serverName), - KeyID: gomatrixserverlib.KeyID(keyID), - } - vk := gomatrixserverlib.VerifyKey{} - err = vk.Key.Decode(key) - if err != nil { - return nil, err - } - results[r] = gomatrixserverlib.PublicKeyLookupResult{ - VerifyKey: vk, - ValidUntilTS: gomatrixserverlib.Timestamp(validUntilTS), - ExpiredTS: gomatrixserverlib.Timestamp(expiredTS), - } - } return results, nil }