Request DB only once per request

This commit is contained in:
Brendan Abolivier 2017-07-25 16:37:02 +01:00
parent 620d11ea21
commit 0bb1995f4f
No known key found for this signature in database
GPG key ID: 8EF1500759F70623
3 changed files with 33 additions and 23 deletions

View file

@ -42,7 +42,7 @@ const insertAccountDataSQL = `
` `
const selectAccountDataSQL = "" + const selectAccountDataSQL = "" +
"SELECT type, content FROM account_data WHERE localpart = $1 AND room_id = $2" "SELECT room_id, type, content FROM account_data WHERE localpart = $1"
const deleteAccountDataSQL = "" + const deleteAccountDataSQL = "" +
"DELETE FROM account_data WHERE localpart = $1 AND room_id = $2 AND type = $3" "DELETE FROM account_data WHERE localpart = $1 AND room_id = $2 AND type = $3"
@ -71,28 +71,38 @@ func (s *accountDataStatements) insertAccountData(localpart string, roomID strin
return return
} }
func (s *accountDataStatements) selectAccountData(localpart string, roomID string) ([]gomatrixserverlib.ClientEvent, error) { func (s *accountDataStatements) selectAccountData(localpart string) (
events := []gomatrixserverlib.ClientEvent{} global []gomatrixserverlib.ClientEvent,
rooms map[string][]gomatrixserverlib.ClientEvent,
rows, err := s.selectAccountDataStmt.Query(localpart, roomID) err error,
) {
rows, err := s.selectAccountDataStmt.Query(localpart)
if err != nil { if err != nil {
return events, err return
} }
rooms = make(map[string][]gomatrixserverlib.ClientEvent)
for rows.Next() { for rows.Next() {
var roomID string
var dataType string var dataType string
var content []byte var content []byte
if err := rows.Scan(&dataType, &content); err != nil && err != sql.ErrNoRows { if err = rows.Scan(&roomID, &dataType, &content); err != nil && err != sql.ErrNoRows {
return events, err return
} }
ac := gomatrixserverlib.ClientEvent{ ac := gomatrixserverlib.ClientEvent{
Type: dataType, Type: dataType,
Content: content, Content: content,
} }
events = append(events, ac)
if len(roomID) > 0 {
rooms[roomID] = append(rooms[roomID], ac)
} else {
global = append(global, ac)
}
} }
return events, nil return
} }

View file

@ -199,14 +199,14 @@ func (d *Database) SaveAccountData(localpart string, roomID string, dataType str
} }
// GetAccountData returns account data related to a given localpart // GetAccountData returns account data related to a given localpart
// If a non-empty string is provided as the room ID, returns all account data
// related to the room
// If an empty string is provided as the room ID, returns all account data that
// aren't related to any room
// If no account data could be found, returns an empty arrays // If no account data could be found, returns an empty arrays
// Returns an error if there was an issue with the retrieval // Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountData(localpart string, roomID string) ([]gomatrixserverlib.ClientEvent, error) { func (d *Database) GetAccountData(localpart string) (
return d.accountDatas.selectAccountData(localpart, roomID) global []gomatrixserverlib.ClientEvent,
rooms map[string][]gomatrixserverlib.ClientEvent,
err error,
) {
return d.accountDatas.selectAccountData(localpart)
} }
func hashPassword(plaintext string) (hash string, err error) { func hashPassword(plaintext string) (hash string, err error) {

View file

@ -114,24 +114,24 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, currentPos types.Stre
} }
func (rp *RequestPool) appendAccountData(data *types.Response, userID string) (*types.Response, error) { func (rp *RequestPool) appendAccountData(data *types.Response, userID string) (*types.Response, error) {
// TODO: We currently send all account data on every sync response, we should instead send data
// that has changed on incremental sync responses
localpart, _, err := gomatrixserverlib.SplitID('@', userID) localpart, _, err := gomatrixserverlib.SplitID('@', userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
events, err := rp.accountDB.GetAccountData(localpart, "") global, rooms, err := rp.accountDB.GetAccountData(localpart)
if err != nil { if err != nil {
return nil, err return nil, err
} }
data.AccountData.Events = events data.AccountData.Events = global
for r, j := range data.Rooms.Join { for r, j := range data.Rooms.Join {
events, err := rp.accountDB.GetAccountData(localpart, r) if len(rooms[r]) > 0 {
if err != nil { j.AccountData.Events = rooms[r]
return nil, err data.Rooms.Join[r] = j
} }
j.AccountData.Events = events
data.Rooms.Join[r] = j
} }
return data, nil return data, nil