diff --git a/userapi/storage/postgres/stats_table.go b/userapi/storage/postgres/stats_table.go index 5fec65fb2..a808b8283 100644 --- a/userapi/storage/postgres/stats_table.go +++ b/userapi/storage/postgres/stats_table.go @@ -22,6 +22,7 @@ import ( "github.com/lib/pq" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/types" "github.com/matrix-org/gomatrixserverlib" @@ -119,16 +120,18 @@ const countUserByAccountTypeSQL = ` SELECT COUNT(*) FROM account_accounts WHERE account_type = ANY($1) ` +// $1 = All non guest AccountType IDs +// $2 = Guest AccountType const countRegisteredUserByTypeStmt = ` SELECT user_type, COUNT(*) AS count FROM ( SELECT CASE - WHEN account_type<>2 AND appservice_id IS NULL THEN 'native' - WHEN account_type=2 AND appservice_id IS NULL THEN 'guest' - WHEN account_type<>2 AND appservice_id IS NOT NULL THEN 'bridged' + WHEN account_type = ANY($1) AND appservice_id IS NULL THEN 'native' + WHEN account_type = $2 AND appservice_id IS NULL THEN 'guest' + WHEN account_type = ANY($1) AND appservice_id IS NOT NULL THEN 'bridged' END AS user_type FROM account_accounts - WHERE created_ts > $1 + WHERE created_ts > $3 ) AS t GROUP BY user_type ` @@ -202,7 +205,12 @@ func (s *statsStatements) startTimers() { func (s *statsStatements) allUsers(ctx context.Context, txn *sql.Tx) (result int64, err error) { stmt := sqlutil.TxStmt(txn, s.countUserByAccountTypeStmt) err = stmt.QueryRowContext(ctx, - pq.Int64Array{1, 2, 3, 4}, + pq.Int64Array{ + int64(api.AccountTypeUser), + int64(api.AccountTypeGuest), + int64(api.AccountTypeAdmin), + int64(api.AccountTypeAppService), + }, ).Scan(&result) return } @@ -210,7 +218,11 @@ func (s *statsStatements) allUsers(ctx context.Context, txn *sql.Tx) (result int func (s *statsStatements) nonBridgedUsers(ctx context.Context, txn *sql.Tx) (result int64, err error) { stmt := sqlutil.TxStmt(txn, s.countUserByAccountTypeStmt) err = stmt.QueryRowContext(ctx, - pq.Int64Array{1, 2, 3}, + pq.Int64Array{ + int64(api.AccountTypeUser), + int64(api.AccountTypeGuest), + int64(api.AccountTypeAdmin), + }, ).Scan(&result) return } @@ -220,6 +232,12 @@ func (s *statsStatements) registeredUserByType(ctx context.Context, txn *sql.Tx) registeredAfter := time.Now().AddDate(0, 0, -1) rows, err := stmt.QueryContext(ctx, + pq.Int64Array{ + int64(api.AccountTypeUser), + int64(api.AccountTypeAdmin), + int64(api.AccountTypeAppService), + }, + api.AccountTypeGuest, gomatrixserverlib.AsTimestamp(registeredAfter), ) if err != nil { diff --git a/userapi/storage/sqlite3/stats_table.go b/userapi/storage/sqlite3/stats_table.go index 688f099b7..b6e5e3bd0 100644 --- a/userapi/storage/sqlite3/stats_table.go +++ b/userapi/storage/sqlite3/stats_table.go @@ -22,6 +22,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/types" "github.com/matrix-org/gomatrixserverlib" @@ -119,16 +120,18 @@ const countUserByAccountTypeSQL = ` SELECT COUNT(*) FROM account_accounts WHERE account_type IN ($1) ` +// $1 = Guest AccountType +// $3 & $4 = All non guest AccountType IDs const countRegisteredUserByTypeStmt = ` SELECT user_type, COUNT(*) AS count FROM ( SELECT CASE - WHEN account_type<>2 AND appservice_id IS NULL THEN 'native' - WHEN account_type=2 AND appservice_id IS NULL THEN 'guest' - WHEN account_type<>2 AND appservice_id IS NOT NULL THEN 'bridged' + WHEN account_type IN ($3) AND appservice_id IS NULL THEN 'native' + WHEN account_type = $1 AND appservice_id IS NULL THEN 'guest' + WHEN account_type IN ($4) AND appservice_id IS NOT NULL THEN 'bridged' END AS user_type FROM account_accounts - WHERE created_ts > $1 + WHERE created_ts > $2 ) AS t GROUP BY user_type ` @@ -228,12 +231,29 @@ func (s *statsStatements) nonBridgedUsers(ctx context.Context, txn *sql.Tx) (res } func (s *statsStatements) registeredUserByType(ctx context.Context, txn *sql.Tx) (map[string]int64, error) { - stmt := sqlutil.TxStmt(txn, s.countRegisteredUserByTypeStmt) + // $1 = Guest AccountType; $2 = timestamp + // $3 & $4 = All non guest AccountType IDs + nonGuests := []api.AccountType{api.AccountTypeUser, api.AccountTypeAdmin, api.AccountTypeAppService} + countSQL := strings.Replace(countRegisteredUserByTypeStmt, "($3)", sqlutil.QueryVariadicOffset(len(nonGuests), 2), 1) + countSQL = strings.Replace(countSQL, "($4)", sqlutil.QueryVariadicOffset(len(nonGuests), 2+len(nonGuests)), 1) + + countRegisterdUserByType, err := txn.Prepare(countSQL) + if err != nil { + return nil, err + } + stmt := sqlutil.TxStmt(txn, countRegisterdUserByType) registeredAfter := time.Now().AddDate(0, 0, -1) - rows, err := stmt.QueryContext(ctx, - gomatrixserverlib.AsTimestamp(registeredAfter), - ) + params := make([]interface{}, len(nonGuests)*2+2) + params[0] = api.AccountTypeGuest // $1 + params[1] = gomatrixserverlib.AsTimestamp(registeredAfter) // $2 + // nonGuests is used twice + for i, v := range nonGuests { + params[i+2] = v // i: 2 3 4 => ($3, $4, $5) + params[i+2+len(nonGuests)] = v // i: 5 6 7 => ($6, $7, $8) + } + + rows, err := stmt.QueryContext(ctx, params...) if err != nil { return nil, err }