From b5c8fc922c26f96695be8234ad15ec6ed092514f Mon Sep 17 00:00:00 2001 From: Till Faelligen Date: Thu, 10 Feb 2022 18:09:23 +0100 Subject: [PATCH] Add missing AccountType to return value --- userapi/internal/api.go | 1 + userapi/storage/accounts/postgres/accounts_table.go | 5 ++--- userapi/storage/accounts/sqlite3/accounts_table.go | 4 +--- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 46e5ff200..f96d4804c 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -78,6 +78,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P Localpart: req.Localpart, ServerName: a.ServerName, UserID: fmt.Sprintf("@%s:%s", req.Localpart, a.ServerName), + AccountType: req.AccountType, } return nil } diff --git a/userapi/storage/accounts/postgres/accounts_table.go b/userapi/storage/accounts/postgres/accounts_table.go index 636451b5f..e6687b136 100644 --- a/userapi/storage/accounts/postgres/accounts_table.go +++ b/userapi/storage/accounts/postgres/accounts_table.go @@ -119,6 +119,7 @@ func (s *accountsStatements) insertAccount( UserID: userutil.MakeUserID(localpart, s.serverName), ServerName: s.serverName, AppServiceID: appserviceID, + AccountType: accountType, }, nil } @@ -148,10 +149,9 @@ func (s *accountsStatements) selectAccountByLocalpart( ) (*api.Account, error) { var appserviceIDPtr sql.NullString var acc api.Account - var accType api.AccountType stmt := s.selectAccountByLocalpartStmt - err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr, &accType) + err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr, &acc.AccountType) if err != nil { if err != sql.ErrNoRows { log.WithError(err).Error("Unable to retrieve user from the db") @@ -164,7 +164,6 @@ func (s *accountsStatements) selectAccountByLocalpart( acc.UserID = userutil.MakeUserID(localpart, s.serverName) acc.ServerName = s.serverName - acc.AccountType = accType return &acc, nil } diff --git a/userapi/storage/accounts/sqlite3/accounts_table.go b/userapi/storage/accounts/sqlite3/accounts_table.go index 6f63a8023..05af80018 100644 --- a/userapi/storage/accounts/sqlite3/accounts_table.go +++ b/userapi/storage/accounts/sqlite3/accounts_table.go @@ -148,10 +148,9 @@ func (s *accountsStatements) selectAccountByLocalpart( ) (*api.Account, error) { var appserviceIDPtr sql.NullString var acc api.Account - var accType api.AccountType stmt := s.selectAccountByLocalpartStmt - err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr, &accType) + err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr, &acc.AccountType) if err != nil { if err != sql.ErrNoRows { log.WithError(err).Error("Unable to retrieve user from the db") @@ -164,7 +163,6 @@ func (s *accountsStatements) selectAccountByLocalpart( acc.UserID = userutil.MakeUserID(localpart, s.serverName) acc.ServerName = s.serverName - acc.AccountType = accType return &acc, nil }