From 620d11ea21837b35cef867cfc9e5d77d1c86c24c Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Tue, 25 Jul 2017 14:43:34 +0100 Subject: [PATCH] Merge account data retrieval functions --- .../storage/accounts/account_data_table.go | 45 +++---------------- .../auth/storage/accounts/storage.go | 5 +-- 2 files changed, 7 insertions(+), 43 deletions(-) diff --git a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/account_data_table.go b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/account_data_table.go index 0bceb6778..19f3b3ba3 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/account_data_table.go +++ b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/account_data_table.go @@ -41,19 +41,15 @@ const insertAccountDataSQL = ` ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = EXCLUDED.content ` -const selectGlobalAccountDataSQL = "" + - "SELECT type, content FROM account_data WHERE localpart = $1 AND room_id = ''" - -const selectRoomAccountDataSQL = "" + +const selectAccountDataSQL = "" + "SELECT type, content FROM account_data WHERE localpart = $1 AND room_id = $2" const deleteAccountDataSQL = "" + "DELETE FROM account_data WHERE localpart = $1 AND room_id = $2 AND type = $3" type accountDataStatements struct { - insertAccountDataStmt *sql.Stmt - selectGlobalAccountDataStmt *sql.Stmt - selectRoomAccountDataStmt *sql.Stmt + insertAccountDataStmt *sql.Stmt + selectAccountDataStmt *sql.Stmt } func (s *accountDataStatements) prepare(db *sql.DB) (err error) { @@ -64,10 +60,7 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) { if s.insertAccountDataStmt, err = db.Prepare(insertAccountDataSQL); err != nil { return } - if s.selectGlobalAccountDataStmt, err = db.Prepare(selectGlobalAccountDataSQL); err != nil { - return - } - if s.selectRoomAccountDataStmt, err = db.Prepare(selectRoomAccountDataSQL); err != nil { + if s.selectAccountDataStmt, err = db.Prepare(selectAccountDataSQL); err != nil { return } return @@ -78,36 +71,10 @@ func (s *accountDataStatements) insertAccountData(localpart string, roomID strin return } -func (s *accountDataStatements) selectGlobalAccountData(localpart string) ([]gomatrixserverlib.ClientEvent, error) { +func (s *accountDataStatements) selectAccountData(localpart string, roomID string) ([]gomatrixserverlib.ClientEvent, error) { events := []gomatrixserverlib.ClientEvent{} - rows, err := s.selectGlobalAccountDataStmt.Query(localpart) - if err != nil { - return events, err - } - - for rows.Next() { - var dataType string - var content []byte - - if err := rows.Scan(&dataType, &content); err != nil && err != sql.ErrNoRows { - return events, err - } - - ac := gomatrixserverlib.ClientEvent{ - Type: dataType, - Content: content, - } - events = append(events, ac) - } - - return events, nil -} - -func (s *accountDataStatements) selectRoomAccountData(localpart string, roomID string) ([]gomatrixserverlib.ClientEvent, error) { - events := []gomatrixserverlib.ClientEvent{} - - rows, err := s.selectRoomAccountDataStmt.Query(localpart, roomID) + rows, err := s.selectAccountDataStmt.Query(localpart, roomID) if err != nil { return events, err } diff --git a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/storage.go b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/storage.go index 75e4ccc58..53b5e53f8 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/storage.go +++ b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/storage.go @@ -206,10 +206,7 @@ func (d *Database) SaveAccountData(localpart string, roomID string, dataType str // If no account data could be found, returns an empty arrays // Returns an error if there was an issue with the retrieval func (d *Database) GetAccountData(localpart string, roomID string) ([]gomatrixserverlib.ClientEvent, error) { - if len(roomID) > 0 { - return d.accountDatas.selectRoomAccountData(localpart, roomID) - } - return d.accountDatas.selectGlobalAccountData(localpart) + return d.accountDatas.selectAccountData(localpart, roomID) } func hashPassword(plaintext string) (hash string, err error) {