Try to optimize SelectOneTimeKeys (#1851)

* Try to optimize SelectOneTimeKeys

Signed-off-by: Till Faelligen <tfaelligen@gmail.com>

* Use pg.Array when using ANY...

Co-authored-by: Kegsay <kegan@matrix.org>
This commit is contained in:
S7evinK 2021-06-07 10:17:46 +02:00 committed by GitHub
parent 8b22c4270d
commit 89a6787fdb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -20,6 +20,7 @@ import (
"encoding/json" "encoding/json"
"time" "time"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/api"
@ -47,7 +48,7 @@ const upsertKeysSQL = "" +
" DO UPDATE SET key_json = $6" " DO UPDATE SET key_json = $6"
const selectKeysSQL = "" + const selectKeysSQL = "" +
"SELECT key_id, algorithm, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2" "SELECT concat(algorithm, ':', key_id) as algorithmwithid, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 AND concat(algorithm, ':', key_id) = ANY($3);"
const selectKeysCountSQL = "" + const selectKeysCountSQL = "" +
"SELECT algorithm, COUNT(key_id) FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 GROUP BY algorithm" "SELECT algorithm, COUNT(key_id) FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 GROUP BY algorithm"
@ -94,29 +95,22 @@ func NewPostgresOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) {
} }
func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) { func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
rows, err := s.selectKeysStmt.QueryContext(ctx, userID, deviceID) rows, err := s.selectKeysStmt.QueryContext(ctx, userID, deviceID, pq.Array(keyIDsWithAlgorithms))
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysStmt: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "selectKeysStmt: rows.close() failed")
wantSet := make(map[string]bool, len(keyIDsWithAlgorithms))
for _, ka := range keyIDsWithAlgorithms {
wantSet[ka] = true
}
result := make(map[string]json.RawMessage) result := make(map[string]json.RawMessage)
var (
algorithmWithID string
keyJSONStr string
)
for rows.Next() { for rows.Next() {
var keyID string if err := rows.Scan(&algorithmWithID, &keyJSONStr); err != nil {
var algorithm string
var keyJSONStr string
if err := rows.Scan(&keyID, &algorithm, &keyJSONStr); err != nil {
return nil, err return nil, err
} }
keyIDWithAlgo := algorithm + ":" + keyID result[algorithmWithID] = json.RawMessage(keyJSONStr)
if wantSet[keyIDWithAlgo] {
result[keyIDWithAlgo] = json.RawMessage(keyJSONStr)
}
} }
return result, rows.Err() return result, rows.Err()
} }