From 1a24e380344b862bb67d5a6e2cce2389b060f5f1 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 18 Jun 2020 15:39:32 +0100 Subject: [PATCH] Tweak database fetching --- userapi/storage/accounts/postgres/account_data_table.go | 6 ++++-- userapi/storage/accounts/sqlite3/account_data_table.go | 4 +++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/userapi/storage/accounts/postgres/account_data_table.go b/userapi/storage/accounts/postgres/account_data_table.go index 7eb5841be..5fabadc67 100644 --- a/userapi/storage/accounts/postgres/account_data_table.go +++ b/userapi/storage/accounts/postgres/account_data_table.go @@ -99,7 +99,7 @@ func (s *accountDataStatements) selectAccountData( for rows.Next() { var roomID string var dataType string - var content json.RawMessage + var content []byte if err = rows.Scan(&roomID, &dataType, &content); err != nil { return @@ -121,12 +121,14 @@ func (s *accountDataStatements) selectAccountData( func (s *accountDataStatements) selectAccountDataByType( ctx context.Context, localpart, roomID, dataType string, ) (data json.RawMessage, err error) { + var bytes []byte stmt := s.selectAccountDataByTypeStmt - if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&data); err != nil { + if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&bytes); err != nil { if err == sql.ErrNoRows { return nil, nil } return } + data = json.RawMessage(bytes) return } diff --git a/userapi/storage/accounts/sqlite3/account_data_table.go b/userapi/storage/accounts/sqlite3/account_data_table.go index 3227ee248..78f03a3fa 100644 --- a/userapi/storage/accounts/sqlite3/account_data_table.go +++ b/userapi/storage/accounts/sqlite3/account_data_table.go @@ -117,12 +117,14 @@ func (s *accountDataStatements) selectAccountData( func (s *accountDataStatements) selectAccountDataByType( ctx context.Context, localpart, roomID, dataType string, ) (data json.RawMessage, err error) { + var bytes []byte stmt := s.selectAccountDataByTypeStmt - if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&data); err != nil { + if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&bytes); err != nil { if err == sql.ErrNoRows { return nil, nil } return } + data = json.RawMessage(bytes) return }