diff --git a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/account_data_table.go b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/account_data_table.go index 19f3b3ba3..af0c5ea9e 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/account_data_table.go +++ b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/account_data_table.go @@ -42,7 +42,7 @@ const insertAccountDataSQL = ` ` const selectAccountDataSQL = "" + - "SELECT type, content FROM account_data WHERE localpart = $1 AND room_id = $2" + "SELECT room_id, type, content FROM account_data WHERE localpart = $1" const deleteAccountDataSQL = "" + "DELETE FROM account_data WHERE localpart = $1 AND room_id = $2 AND type = $3" @@ -71,28 +71,38 @@ func (s *accountDataStatements) insertAccountData(localpart string, roomID strin return } -func (s *accountDataStatements) selectAccountData(localpart string, roomID string) ([]gomatrixserverlib.ClientEvent, error) { - events := []gomatrixserverlib.ClientEvent{} - - rows, err := s.selectAccountDataStmt.Query(localpart, roomID) +func (s *accountDataStatements) selectAccountData(localpart string) ( + global []gomatrixserverlib.ClientEvent, + rooms map[string][]gomatrixserverlib.ClientEvent, + err error, +) { + rows, err := s.selectAccountDataStmt.Query(localpart) if err != nil { - return events, err + return } + rooms = make(map[string][]gomatrixserverlib.ClientEvent) + for rows.Next() { + var roomID string var dataType string var content []byte - if err := rows.Scan(&dataType, &content); err != nil && err != sql.ErrNoRows { - return events, err + if err = rows.Scan(&roomID, &dataType, &content); err != nil && err != sql.ErrNoRows { + return } ac := gomatrixserverlib.ClientEvent{ Type: dataType, Content: content, } - events = append(events, ac) + + if len(roomID) > 0 { + rooms[roomID] = append(rooms[roomID], ac) + } else { + global = append(global, ac) + } } - return events, nil + return } diff --git a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/storage.go b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/storage.go index 53b5e53f8..b70c685b9 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/storage.go +++ b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/storage.go @@ -199,14 +199,14 @@ func (d *Database) SaveAccountData(localpart string, roomID string, dataType str } // GetAccountData returns account data related to a given localpart -// If a non-empty string is provided as the room ID, returns all account data -// related to the room -// If an empty string is provided as the room ID, returns all account data that -// aren't related to any room // If no account data could be found, returns an empty arrays // Returns an error if there was an issue with the retrieval -func (d *Database) GetAccountData(localpart string, roomID string) ([]gomatrixserverlib.ClientEvent, error) { - return d.accountDatas.selectAccountData(localpart, roomID) +func (d *Database) GetAccountData(localpart string) ( + global []gomatrixserverlib.ClientEvent, + rooms map[string][]gomatrixserverlib.ClientEvent, + err error, +) { + return d.accountDatas.selectAccountData(localpart) } func hashPassword(plaintext string) (hash string, err error) { diff --git a/src/github.com/matrix-org/dendrite/syncapi/sync/requestpool.go b/src/github.com/matrix-org/dendrite/syncapi/sync/requestpool.go index f2dd4e586..953e5f4f6 100644 --- a/src/github.com/matrix-org/dendrite/syncapi/sync/requestpool.go +++ b/src/github.com/matrix-org/dendrite/syncapi/sync/requestpool.go @@ -114,24 +114,24 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, currentPos types.Stre } func (rp *RequestPool) appendAccountData(data *types.Response, userID string) (*types.Response, error) { + // TODO: We currently send all account data on every sync response, we should instead send data + // that has changed on incremental sync responses localpart, _, err := gomatrixserverlib.SplitID('@', userID) if err != nil { return nil, err } - events, err := rp.accountDB.GetAccountData(localpart, "") + global, rooms, err := rp.accountDB.GetAccountData(localpart) if err != nil { return nil, err } - data.AccountData.Events = events + data.AccountData.Events = global for r, j := range data.Rooms.Join { - events, err := rp.accountDB.GetAccountData(localpart, r) - if err != nil { - return nil, err + if len(rooms[r]) > 0 { + j.AccountData.Events = rooms[r] + data.Rooms.Join[r] = j } - j.AccountData.Events = events - data.Rooms.Join[r] = j } return data, nil