This commit is contained in:
Neil Alexander 2020-06-18 16:38:28 +01:00
parent 1a24e38034
commit be235028b9
3 changed files with 40 additions and 26 deletions

View file

@ -16,6 +16,7 @@ package routing
import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
@ -46,15 +47,26 @@ func GetAccountData(
dataRes := api.QueryAccountDataResponse{}
if err := userAPI.QueryAccountData(req.Context(), &dataReq, &dataRes); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryAccountData failed")
return util.ErrorResponse(fmt.Errorf("userAPI.QueryAccountData: %w", err))
}
var data json.RawMessage
var ok bool
if roomID != "" {
data, ok = dataRes.RoomAccountData[roomID][dataType]
} else {
data, ok = dataRes.GlobalAccountData[dataType]
}
if ok {
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: jsonerror.Forbidden("data not found"),
Code: http.StatusOK,
JSON: data,
}
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: dataRes.RoomAccountData,
Code: http.StatusNotFound,
JSON: jsonerror.Forbidden("data not found"),
}
}

View file

@ -221,7 +221,6 @@ func (rp *RequestPool) appendAccountData(
},
)
}
for r, j := range data.Rooms.Join {
for datatype, databody := range res.RoomAccountData[r] {
j.AccountData.Events = append(
@ -233,7 +232,6 @@ func (rp *RequestPool) appendAccountData(
)
}
}
return data, nil
}
@ -264,35 +262,37 @@ func (rp *RequestPool) appendAccountData(
for roomID, dataTypes := range dataTypes {
// Request the missing data from the database
for _, dataType := range dataTypes {
var res userapi.QueryAccountDataResponse
err = rp.userAPI.QueryAccountData(req.ctx, &userapi.QueryAccountDataRequest{
dataReq := userapi.QueryAccountDataRequest{
UserID: userID,
RoomID: roomID,
DataType: dataType,
}, &res)
}
dataRes := userapi.QueryAccountDataResponse{}
err = rp.userAPI.QueryAccountData(req.ctx, &dataReq, &dataRes)
if err != nil {
return nil, err
continue
}
for t, d := range res.GlobalAccountData {
data.AccountData.Events = append(
data.AccountData.Events,
gomatrixserverlib.ClientEvent{
Type: t,
Content: gomatrixserverlib.RawJSON(d),
},
)
}
for r, byRoom := range res.RoomAccountData {
for t, d := range byRoom {
joinData := data.Rooms.Join[r]
if roomID == "" {
if globalData, ok := dataRes.GlobalAccountData[dataType]; ok {
data.AccountData.Events = append(
data.AccountData.Events,
gomatrixserverlib.ClientEvent{
Type: dataType,
Content: gomatrixserverlib.RawJSON(globalData),
},
)
}
} else {
if roomData, ok := dataRes.RoomAccountData[roomID][dataType]; ok {
joinData := data.Rooms.Join[roomID]
joinData.AccountData.Events = append(
joinData.AccountData.Events,
gomatrixserverlib.ClientEvent{
Type: t,
Content: gomatrixserverlib.RawJSON(d),
Type: dataType,
Content: gomatrixserverlib.RawJSON(roomData),
},
)
data.Rooms.Join[r] = joinData
data.Rooms.Join[roomID] = joinData
}
}
}

View file

@ -137,6 +137,8 @@ func (a *UserInternalAPI) QueryDevices(ctx context.Context, req *api.QueryDevice
}
func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAccountDataRequest, res *api.QueryAccountDataResponse) error {
res.GlobalAccountData = make(map[string]json.RawMessage)
res.RoomAccountData = make(map[string]map[string]json.RawMessage)
local, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
if err != nil {
return err
@ -153,7 +155,7 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc
if data != nil {
if req.RoomID != "" {
if _, ok := res.RoomAccountData[req.RoomID]; !ok {
res.RoomAccountData[req.RoomID] = map[string]json.RawMessage{}
res.RoomAccountData[req.RoomID] = make(map[string]json.RawMessage)
}
res.RoomAccountData[req.RoomID][req.DataType] = data
} else {