Merge branch 'master' into issue-1392

This commit is contained in:
Sam 2020-09-14 17:50:22 +02:00 committed by GitHub
commit 96ce9aeb5a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 254 additions and 33 deletions

1
go.mod
View file

@ -1,6 +1,7 @@
module github.com/matrix-org/dendrite module github.com/matrix-org/dendrite
require ( require (
github.com/DATA-DOG/go-sqlmock v1.5.0
github.com/Shopify/sarama v1.27.0 github.com/Shopify/sarama v1.27.0
github.com/codahale/hdrhistogram v0.0.0-20161010025455-3a0bb77429bd // indirect github.com/codahale/hdrhistogram v0.0.0-20161010025455-3a0bb77429bd // indirect
github.com/gologme/log v1.2.0 github.com/gologme/log v1.2.0

2
go.sum
View file

@ -13,6 +13,8 @@ github.com/Arceliar/phony v0.0.0-20191006174943-d0c68492aca0 h1:p3puK8Sl2xK+2Fnn
github.com/Arceliar/phony v0.0.0-20191006174943-d0c68492aca0/go.mod h1:6Lkn+/zJilRMsKmbmG1RPoamiArC6HS73xbwRyp3UyI= github.com/Arceliar/phony v0.0.0-20191006174943-d0c68492aca0/go.mod h1:6Lkn+/zJilRMsKmbmG1RPoamiArC6HS73xbwRyp3UyI=
github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60=
github.com/DATA-DOG/go-sqlmock v1.5.0/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM=
github.com/Kubuxu/go-os-helper v0.0.1/go.mod h1:N8B+I7vPCT80IcP58r50u4+gEEcsZETFUpAzWW2ep1Y= github.com/Kubuxu/go-os-helper v0.0.1/go.mod h1:N8B+I7vPCT80IcP58r50u4+gEEcsZETFUpAzWW2ep1Y=
github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU=
github.com/PuerkitoBio/goquery v1.5.1/go.mod h1:GsLWisAFVj4WgDibEWF4pvYnkVQBpKBKeU+7zCJoLcc= github.com/PuerkitoBio/goquery v1.5.1/go.mod h1:GsLWisAFVj4WgDibEWF4pvYnkVQBpKBKeU+7zCJoLcc=

View file

@ -20,6 +20,9 @@ import (
"errors" "errors"
"fmt" "fmt"
"runtime" "runtime"
"strings"
"github.com/matrix-org/util"
) )
// ErrUserExists is returned if a username already exists in the database. // ErrUserExists is returned if a username already exists in the database.
@ -116,3 +119,44 @@ func SQLiteDriverName() string {
} }
return "sqlite3" return "sqlite3"
} }
func minOfInts(a, b int) int {
if a <= b {
return a
}
return b
}
// QueryProvider defines the interface for querys used by RunLimitedVariablesQuery.
type QueryProvider interface {
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
}
// SQLite3MaxVariables is the default maximum number of host parameters in a single SQL statement
// SQLlite can handle. See https://www.sqlite.org/limits.html for more information.
const SQLite3MaxVariables = 999
// RunLimitedVariablesQuery split up a query with more variables than the used database can handle in multiple queries.
func RunLimitedVariablesQuery(ctx context.Context, query string, qp QueryProvider, variables []interface{}, limit uint, rowHandler func(*sql.Rows) error) error {
var start int
for start < len(variables) {
n := minOfInts(len(variables)-start, int(limit))
nextQuery := strings.Replace(query, "($1)", QueryVariadic(n), 1)
rows, err := qp.QueryContext(ctx, nextQuery, variables[start:start+n]...)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("QueryContext returned an error")
return err
}
err = rowHandler(rows)
if closeErr := rows.Close(); closeErr != nil {
util.GetLogger(ctx).WithError(closeErr).Error("RunLimitedVariablesQuery: failed to close rows")
return err
}
if err != nil {
util.GetLogger(ctx).WithError(err).Error("RunLimitedVariablesQuery: rowHandler returned error")
return err
}
start = start + n
}
return nil
}

View file

@ -0,0 +1,173 @@
package sqlutil
import (
"context"
"database/sql"
"reflect"
"testing"
sqlmock "github.com/DATA-DOG/go-sqlmock"
)
func TestShouldReturnCorrectAmountOfResulstIfFewerVariablesThanLimit(t *testing.T) {
db, mock, err := sqlmock.New()
assertNoError(t, err, "Failed to make DB")
limit := uint(4)
r := mock.NewRows([]string{"id"}).
AddRow(1).
AddRow(2).
AddRow(3)
mock.ExpectQuery(`SELECT id WHERE id IN \(\$1, \$2, \$3\)`).WillReturnRows(r)
// nolint:goconst
q := "SELECT id WHERE id IN ($1)"
v := []int{1, 2, 3}
iKeyIDs := make([]interface{}, len(v))
for i, d := range v {
iKeyIDs[i] = d
}
ctx := context.Background()
var result = make([]int, 0)
err = RunLimitedVariablesQuery(ctx, q, db, iKeyIDs, limit, func(rows *sql.Rows) error {
for rows.Next() {
var id int
err = rows.Scan(&id)
assertNoError(t, err, "rows.Scan returned an error")
result = append(result, id)
}
return nil
})
assertNoError(t, err, "Call returned an error")
if len(result) != len(v) {
t.Fatalf("Result should be 3 long")
}
}
func TestShouldReturnCorrectAmountOfResulstIfEqualVariablesAsLimit(t *testing.T) {
db, mock, err := sqlmock.New()
assertNoError(t, err, "Failed to make DB")
limit := uint(4)
r := mock.NewRows([]string{"id"}).
AddRow(1).
AddRow(2).
AddRow(3).
AddRow(4)
mock.ExpectQuery(`SELECT id WHERE id IN \(\$1, \$2, \$3, \$4\)`).WillReturnRows(r)
// nolint:goconst
q := "SELECT id WHERE id IN ($1)"
v := []int{1, 2, 3, 4}
iKeyIDs := make([]interface{}, len(v))
for i, d := range v {
iKeyIDs[i] = d
}
ctx := context.Background()
var result = make([]int, 0)
err = RunLimitedVariablesQuery(ctx, q, db, iKeyIDs, limit, func(rows *sql.Rows) error {
for rows.Next() {
var id int
err = rows.Scan(&id)
assertNoError(t, err, "rows.Scan returned an error")
result = append(result, id)
}
return nil
})
assertNoError(t, err, "Call returned an error")
if len(result) != len(v) {
t.Fatalf("Result should be 4 long")
}
}
func TestShouldReturnCorrectAmountOfResultsIfMoreVariablesThanLimit(t *testing.T) {
db, mock, err := sqlmock.New()
assertNoError(t, err, "Failed to make DB")
limit := uint(4)
r1 := mock.NewRows([]string{"id"}).
AddRow(1).
AddRow(2).
AddRow(3).
AddRow(4)
r2 := mock.NewRows([]string{"id"}).
AddRow(5)
mock.ExpectQuery(`SELECT id WHERE id IN \(\$1, \$2, \$3, \$4\)`).WillReturnRows(r1)
mock.ExpectQuery(`SELECT id WHERE id IN \(\$1\)`).WillReturnRows(r2)
// nolint:goconst
q := "SELECT id WHERE id IN ($1)"
v := []int{1, 2, 3, 4, 5}
iKeyIDs := make([]interface{}, len(v))
for i, d := range v {
iKeyIDs[i] = d
}
ctx := context.Background()
var result = make([]int, 0)
err = RunLimitedVariablesQuery(ctx, q, db, iKeyIDs, limit, func(rows *sql.Rows) error {
for rows.Next() {
var id int
err = rows.Scan(&id)
assertNoError(t, err, "rows.Scan returned an error")
result = append(result, id)
}
return nil
})
assertNoError(t, err, "Call returned an error")
if len(result) != len(v) {
t.Fatalf("Result should be 5 long")
}
if !reflect.DeepEqual(v, result) {
t.Fatalf("Result is not as expected: got %v want %v", v, result)
}
}
func TestShouldReturnErrorIfRowsScanReturnsError(t *testing.T) {
db, mock, err := sqlmock.New()
assertNoError(t, err, "Failed to make DB")
limit := uint(4)
// adding a string ID should result in rows.Scan returning an error
r := mock.NewRows([]string{"id"}).
AddRow("hej").
AddRow(2).
AddRow(3)
mock.ExpectQuery(`SELECT id WHERE id IN \(\$1, \$2, \$3\)`).WillReturnRows(r)
// nolint:goconst
q := "SELECT id WHERE id IN ($1)"
v := []int{-1, -2, 3}
iKeyIDs := make([]interface{}, len(v))
for i, d := range v {
iKeyIDs[i] = d
}
ctx := context.Background()
var result = make([]uint, 0)
err = RunLimitedVariablesQuery(ctx, q, db, iKeyIDs, limit, func(rows *sql.Rows) error {
for rows.Next() {
var id uint
err = rows.Scan(&id)
if err != nil {
return err
}
result = append(result, id)
}
return nil
})
if err == nil {
t.Fatalf("Call did not return an error")
}
}
func assertNoError(t *testing.T, err error, msg string) {
t.Helper()
if err == nil {
return
}
t.Fatalf(msg)
}

View file

@ -18,9 +18,8 @@ package sqlite3
import ( import (
"context" "context"
"database/sql" "database/sql"
"strings" "fmt"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -88,41 +87,36 @@ func (s *serverKeyStatements) bulkSelectServerKeys(
ctx context.Context, ctx context.Context,
requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp,
) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) { ) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) {
var nameAndKeyIDs []string nameAndKeyIDs := make([]string, 0, len(requests))
for request := range requests { for request := range requests {
nameAndKeyIDs = append(nameAndKeyIDs, nameAndKeyID(request)) nameAndKeyIDs = append(nameAndKeyIDs, nameAndKeyID(request))
} }
results := make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, len(requests))
query := strings.Replace(bulkSelectServerKeysSQL, "($1)", sqlutil.QueryVariadic(len(nameAndKeyIDs)), 1)
iKeyIDs := make([]interface{}, len(nameAndKeyIDs)) iKeyIDs := make([]interface{}, len(nameAndKeyIDs))
for i, v := range nameAndKeyIDs { for i, v := range nameAndKeyIDs {
iKeyIDs[i] = v iKeyIDs[i] = v
} }
rows, err := s.db.QueryContext(ctx, query, iKeyIDs...) err := sqlutil.RunLimitedVariablesQuery(
if err != nil { ctx, bulkSelectServerKeysSQL, s.db, iKeyIDs, sqlutil.SQLite3MaxVariables,
return nil, err func(rows *sql.Rows) error {
}
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectServerKeys: rows.close() failed")
results := map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult{}
for rows.Next() { for rows.Next() {
var serverName string var serverName string
var keyID string var keyID string
var key string var key string
var validUntilTS int64 var validUntilTS int64
var expiredTS int64 var expiredTS int64
if err = rows.Scan(&serverName, &keyID, &validUntilTS, &expiredTS, &key); err != nil { if err := rows.Scan(&serverName, &keyID, &validUntilTS, &expiredTS, &key); err != nil {
return nil, err return fmt.Errorf("bulkSelectServerKeys: %v", err)
} }
r := gomatrixserverlib.PublicKeyLookupRequest{ r := gomatrixserverlib.PublicKeyLookupRequest{
ServerName: gomatrixserverlib.ServerName(serverName), ServerName: gomatrixserverlib.ServerName(serverName),
KeyID: gomatrixserverlib.KeyID(keyID), KeyID: gomatrixserverlib.KeyID(keyID),
} }
vk := gomatrixserverlib.VerifyKey{} vk := gomatrixserverlib.VerifyKey{}
err = vk.Key.Decode(key) err := vk.Key.Decode(key)
if err != nil { if err != nil {
return nil, err return fmt.Errorf("bulkSelectServerKeys: %v", err)
} }
results[r] = gomatrixserverlib.PublicKeyLookupResult{ results[r] = gomatrixserverlib.PublicKeyLookupResult{
VerifyKey: vk, VerifyKey: vk,
@ -130,6 +124,13 @@ func (s *serverKeyStatements) bulkSelectServerKeys(
ExpiredTS: gomatrixserverlib.Timestamp(expiredTS), ExpiredTS: gomatrixserverlib.Timestamp(expiredTS),
} }
} }
return nil
},
)
if err != nil {
return nil, err
}
return results, nil return results, nil
} }