Be more explicit when querying

This commit is contained in:
Till Faelligen 2022-04-07 10:33:17 +02:00
parent 27c76e3f89
commit 82a734ae30
2 changed files with 52 additions and 14 deletions

View file

@ -22,6 +22,7 @@ import (
"github.com/lib/pq" "github.com/lib/pq"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/dendrite/userapi/types" "github.com/matrix-org/dendrite/userapi/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -119,16 +120,18 @@ const countUserByAccountTypeSQL = `
SELECT COUNT(*) FROM account_accounts WHERE account_type = ANY($1) SELECT COUNT(*) FROM account_accounts WHERE account_type = ANY($1)
` `
// $1 = All non guest AccountType IDs
// $2 = Guest AccountType
const countRegisteredUserByTypeStmt = ` const countRegisteredUserByTypeStmt = `
SELECT user_type, COUNT(*) AS count FROM ( SELECT user_type, COUNT(*) AS count FROM (
SELECT SELECT
CASE CASE
WHEN account_type<>2 AND appservice_id IS NULL THEN 'native' WHEN account_type = ANY($1) AND appservice_id IS NULL THEN 'native'
WHEN account_type=2 AND appservice_id IS NULL THEN 'guest' WHEN account_type = $2 AND appservice_id IS NULL THEN 'guest'
WHEN account_type<>2 AND appservice_id IS NOT NULL THEN 'bridged' WHEN account_type = ANY($1) AND appservice_id IS NOT NULL THEN 'bridged'
END AS user_type END AS user_type
FROM account_accounts FROM account_accounts
WHERE created_ts > $1 WHERE created_ts > $3
) AS t GROUP BY user_type ) AS t GROUP BY user_type
` `
@ -202,7 +205,12 @@ func (s *statsStatements) startTimers() {
func (s *statsStatements) allUsers(ctx context.Context, txn *sql.Tx) (result int64, err error) { func (s *statsStatements) allUsers(ctx context.Context, txn *sql.Tx) (result int64, err error) {
stmt := sqlutil.TxStmt(txn, s.countUserByAccountTypeStmt) stmt := sqlutil.TxStmt(txn, s.countUserByAccountTypeStmt)
err = stmt.QueryRowContext(ctx, err = stmt.QueryRowContext(ctx,
pq.Int64Array{1, 2, 3, 4}, pq.Int64Array{
int64(api.AccountTypeUser),
int64(api.AccountTypeGuest),
int64(api.AccountTypeAdmin),
int64(api.AccountTypeAppService),
},
).Scan(&result) ).Scan(&result)
return return
} }
@ -210,7 +218,11 @@ func (s *statsStatements) allUsers(ctx context.Context, txn *sql.Tx) (result int
func (s *statsStatements) nonBridgedUsers(ctx context.Context, txn *sql.Tx) (result int64, err error) { func (s *statsStatements) nonBridgedUsers(ctx context.Context, txn *sql.Tx) (result int64, err error) {
stmt := sqlutil.TxStmt(txn, s.countUserByAccountTypeStmt) stmt := sqlutil.TxStmt(txn, s.countUserByAccountTypeStmt)
err = stmt.QueryRowContext(ctx, err = stmt.QueryRowContext(ctx,
pq.Int64Array{1, 2, 3}, pq.Int64Array{
int64(api.AccountTypeUser),
int64(api.AccountTypeGuest),
int64(api.AccountTypeAdmin),
},
).Scan(&result) ).Scan(&result)
return return
} }
@ -220,6 +232,12 @@ func (s *statsStatements) registeredUserByType(ctx context.Context, txn *sql.Tx)
registeredAfter := time.Now().AddDate(0, 0, -1) registeredAfter := time.Now().AddDate(0, 0, -1)
rows, err := stmt.QueryContext(ctx, rows, err := stmt.QueryContext(ctx,
pq.Int64Array{
int64(api.AccountTypeUser),
int64(api.AccountTypeAdmin),
int64(api.AccountTypeAppService),
},
api.AccountTypeGuest,
gomatrixserverlib.AsTimestamp(registeredAfter), gomatrixserverlib.AsTimestamp(registeredAfter),
) )
if err != nil { if err != nil {

View file

@ -22,6 +22,7 @@ import (
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/dendrite/userapi/types" "github.com/matrix-org/dendrite/userapi/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -119,16 +120,18 @@ const countUserByAccountTypeSQL = `
SELECT COUNT(*) FROM account_accounts WHERE account_type IN ($1) SELECT COUNT(*) FROM account_accounts WHERE account_type IN ($1)
` `
// $1 = Guest AccountType
// $3 & $4 = All non guest AccountType IDs
const countRegisteredUserByTypeStmt = ` const countRegisteredUserByTypeStmt = `
SELECT user_type, COUNT(*) AS count FROM ( SELECT user_type, COUNT(*) AS count FROM (
SELECT SELECT
CASE CASE
WHEN account_type<>2 AND appservice_id IS NULL THEN 'native' WHEN account_type IN ($3) AND appservice_id IS NULL THEN 'native'
WHEN account_type=2 AND appservice_id IS NULL THEN 'guest' WHEN account_type = $1 AND appservice_id IS NULL THEN 'guest'
WHEN account_type<>2 AND appservice_id IS NOT NULL THEN 'bridged' WHEN account_type IN ($4) AND appservice_id IS NOT NULL THEN 'bridged'
END AS user_type END AS user_type
FROM account_accounts FROM account_accounts
WHERE created_ts > $1 WHERE created_ts > $2
) AS t GROUP BY user_type ) AS t GROUP BY user_type
` `
@ -228,12 +231,29 @@ func (s *statsStatements) nonBridgedUsers(ctx context.Context, txn *sql.Tx) (res
} }
func (s *statsStatements) registeredUserByType(ctx context.Context, txn *sql.Tx) (map[string]int64, error) { func (s *statsStatements) registeredUserByType(ctx context.Context, txn *sql.Tx) (map[string]int64, error) {
stmt := sqlutil.TxStmt(txn, s.countRegisteredUserByTypeStmt) // $1 = Guest AccountType; $2 = timestamp
// $3 & $4 = All non guest AccountType IDs
nonGuests := []api.AccountType{api.AccountTypeUser, api.AccountTypeAdmin, api.AccountTypeAppService}
countSQL := strings.Replace(countRegisteredUserByTypeStmt, "($3)", sqlutil.QueryVariadicOffset(len(nonGuests), 2), 1)
countSQL = strings.Replace(countSQL, "($4)", sqlutil.QueryVariadicOffset(len(nonGuests), 2+len(nonGuests)), 1)
countRegisterdUserByType, err := txn.Prepare(countSQL)
if err != nil {
return nil, err
}
stmt := sqlutil.TxStmt(txn, countRegisterdUserByType)
registeredAfter := time.Now().AddDate(0, 0, -1) registeredAfter := time.Now().AddDate(0, 0, -1)
rows, err := stmt.QueryContext(ctx, params := make([]interface{}, len(nonGuests)*2+2)
gomatrixserverlib.AsTimestamp(registeredAfter), params[0] = api.AccountTypeGuest // $1
) params[1] = gomatrixserverlib.AsTimestamp(registeredAfter) // $2
// nonGuests is used twice
for i, v := range nonGuests {
params[i+2] = v // i: 2 3 4 => ($3, $4, $5)
params[i+2+len(nonGuests)] = v // i: 5 6 7 => ($6, $7, $8)
}
rows, err := stmt.QueryContext(ctx, params...)
if err != nil { if err != nil {
return nil, err return nil, err
} }