selectAccountDataByType return ClientEvent pointer instead of slice of ClientEvent (#798)

This pull request is an attempt to fix #773.

Signed-off-by: Kouame Behouba Manassé behouba@gmail.com
This commit is contained in:
Behouba Manassé 2019-09-30 19:25:04 +03:00 committed by Andrew Morgan
parent 7b454bdd27
commit 49fd47c863
4 changed files with 18 additions and 29 deletions

View file

@ -120,28 +120,17 @@ func (s *accountDataStatements) selectAccountData(
func (s *accountDataStatements) selectAccountDataByType( func (s *accountDataStatements) selectAccountDataByType(
ctx context.Context, localpart, roomID, dataType string, ctx context.Context, localpart, roomID, dataType string,
) (data []gomatrixserverlib.ClientEvent, err error) { ) (data *gomatrixserverlib.ClientEvent, err error) {
data = []gomatrixserverlib.ClientEvent{}
stmt := s.selectAccountDataByTypeStmt stmt := s.selectAccountDataByTypeStmt
rows, err := stmt.QueryContext(ctx, localpart, roomID, dataType) var content []byte
if err != nil {
if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&content); err != nil {
return return
} }
for rows.Next() { data = &gomatrixserverlib.ClientEvent{
var content []byte Type: dataType,
Content: content,
if err = rows.Scan(&content); err != nil {
return
}
ac := gomatrixserverlib.ClientEvent{
Type: dataType,
Content: content,
}
data = append(data, ac)
} }
return return

View file

@ -263,11 +263,11 @@ func (d *Database) GetAccountData(ctx context.Context, localpart string) (
// GetAccountDataByType returns account data matching a given // GetAccountDataByType returns account data matching a given
// localpart, room ID and type. // localpart, room ID and type.
// If no account data could be found, returns an empty array // If no account data could be found, returns nil
// Returns an error if there was an issue with the retrieval // Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountDataByType( func (d *Database) GetAccountDataByType(
ctx context.Context, localpart, roomID, dataType string, ctx context.Context, localpart, roomID, dataType string,
) (data []gomatrixserverlib.ClientEvent, err error) { ) (data *gomatrixserverlib.ClientEvent, err error) {
return d.accountDatas.selectAccountDataByType( return d.accountDatas.selectAccountDataByType(
ctx, localpart, roomID, dataType, ctx, localpart, roomID, dataType,
) )

View file

@ -59,7 +59,7 @@ func GetTags(
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
} }
if len(data) == 0 { if data == nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
JSON: struct{}{}, JSON: struct{}{},
@ -68,7 +68,7 @@ func GetTags(
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
JSON: data[0].Content, JSON: data.Content,
} }
} }
@ -103,8 +103,8 @@ func PutTag(
} }
var tagContent gomatrix.TagContent var tagContent gomatrix.TagContent
if len(data) > 0 { if data != nil {
if err = json.Unmarshal(data[0].Content, &tagContent); err != nil { if err = json.Unmarshal(data.Content, &tagContent); err != nil {
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
} }
} else { } else {
@ -155,7 +155,7 @@ func DeleteTag(
} }
// If there are no tags in the database, exit // If there are no tags in the database, exit
if len(data) == 0 { if data == nil {
// Spec only defines 200 responses for this endpoint so we don't return anything else. // Spec only defines 200 responses for this endpoint so we don't return anything else.
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
@ -164,7 +164,7 @@ func DeleteTag(
} }
var tagContent gomatrix.TagContent var tagContent gomatrix.TagContent
err = json.Unmarshal(data[0].Content, &tagContent) err = json.Unmarshal(data.Content, &tagContent)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
} }
@ -204,7 +204,7 @@ func obtainSavedTags(
userID string, userID string,
roomID string, roomID string,
accountDB *accounts.Database, accountDB *accounts.Database,
) (string, []gomatrixserverlib.ClientEvent, error) { ) (string, *gomatrixserverlib.ClientEvent, error) {
localpart, _, err := gomatrixserverlib.SplitID('@', userID) localpart, _, err := gomatrixserverlib.SplitID('@', userID)
if err != nil { if err != nil {
return "", nil, err return "", nil, err

View file

@ -196,13 +196,13 @@ func (rp *RequestPool) appendAccountData(
events := []gomatrixserverlib.ClientEvent{} events := []gomatrixserverlib.ClientEvent{}
// Request the missing data from the database // Request the missing data from the database
for _, dataType := range dataTypes { for _, dataType := range dataTypes {
evs, err := rp.accountDB.GetAccountDataByType( event, err := rp.accountDB.GetAccountDataByType(
req.ctx, localpart, roomID, dataType, req.ctx, localpart, roomID, dataType,
) )
if err != nil { if err != nil {
return nil, err return nil, err
} }
events = append(events, evs...) events = append(events, *event)
} }
// Append the data to the response // Append the data to the response