diff --git a/keyserver/storage/postgres/device_keys_table.go b/keyserver/storage/postgres/device_keys_table.go index dfaf41bb1..d915246c7 100644 --- a/keyserver/storage/postgres/device_keys_table.go +++ b/keyserver/storage/postgres/device_keys_table.go @@ -19,6 +19,7 @@ import ( "database/sql" "time" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/storage/tables" @@ -108,6 +109,7 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID if err != nil { return nil, err } + defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed") deviceIDMap := make(map[string]bool) for _, d := range deviceIDs { deviceIDMap[d] = true diff --git a/keyserver/storage/sqlite3/device_keys_table.go b/keyserver/storage/sqlite3/device_keys_table.go index b4d0e50fb..69fe7a6e4 100644 --- a/keyserver/storage/sqlite3/device_keys_table.go +++ b/keyserver/storage/sqlite3/device_keys_table.go @@ -17,9 +17,9 @@ package sqlite3 import ( "context" "database/sql" - "strings" "time" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/storage/tables" @@ -47,7 +47,7 @@ const selectDeviceKeysSQL = "" + "SELECT key_json FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" const selectBatchDeviceKeysSQL = "" + - "SELECT device_id, key_json FROM keyserver_device_keys WHERE user_id=$1 AND device_id IN ($2)" + "SELECT device_id, key_json FROM keyserver_device_keys WHERE user_id=$1" type deviceKeysStatements struct { db *sql.DB @@ -81,16 +81,11 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID for _, d := range deviceIDs { deviceIDMap[d] = true } - iDeviceIDs := make([]interface{}, len(deviceIDs)+1) - iDeviceIDs[0] = userID - for i := range deviceIDs { - iDeviceIDs[i+1] = deviceIDs[i] - } - querySQL := strings.Replace(selectBatchDeviceKeysSQL, "($2)", sqlutil.QueryVariadic(len(deviceIDs)), 1) - rows, err := s.db.QueryContext(ctx, querySQL, iDeviceIDs...) + rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID) if err != nil { return nil, err } + defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed") var result []api.DeviceKeys for rows.Next() { var dk api.DeviceKeys