This commit is contained in:
Neil Alexander 2020-06-18 16:38:28 +01:00
parent 1a24e38034
commit be235028b9
3 changed files with 40 additions and 26 deletions

View file

@ -16,6 +16,7 @@ package routing
import ( import (
"encoding/json" "encoding/json"
"fmt"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
@ -46,15 +47,26 @@ func GetAccountData(
dataRes := api.QueryAccountDataResponse{} dataRes := api.QueryAccountDataResponse{}
if err := userAPI.QueryAccountData(req.Context(), &dataReq, &dataRes); err != nil { if err := userAPI.QueryAccountData(req.Context(), &dataReq, &dataRes); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryAccountData failed") util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryAccountData failed")
return util.ErrorResponse(fmt.Errorf("userAPI.QueryAccountData: %w", err))
}
var data json.RawMessage
var ok bool
if roomID != "" {
data, ok = dataRes.RoomAccountData[roomID][dataType]
} else {
data, ok = dataRes.GlobalAccountData[dataType]
}
if ok {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusNotFound, Code: http.StatusOK,
JSON: jsonerror.Forbidden("data not found"), JSON: data,
} }
} }
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusNotFound,
JSON: dataRes.RoomAccountData, JSON: jsonerror.Forbidden("data not found"),
} }
} }

View file

@ -221,7 +221,6 @@ func (rp *RequestPool) appendAccountData(
}, },
) )
} }
for r, j := range data.Rooms.Join { for r, j := range data.Rooms.Join {
for datatype, databody := range res.RoomAccountData[r] { for datatype, databody := range res.RoomAccountData[r] {
j.AccountData.Events = append( j.AccountData.Events = append(
@ -233,7 +232,6 @@ func (rp *RequestPool) appendAccountData(
) )
} }
} }
return data, nil return data, nil
} }
@ -264,35 +262,37 @@ func (rp *RequestPool) appendAccountData(
for roomID, dataTypes := range dataTypes { for roomID, dataTypes := range dataTypes {
// Request the missing data from the database // Request the missing data from the database
for _, dataType := range dataTypes { for _, dataType := range dataTypes {
var res userapi.QueryAccountDataResponse dataReq := userapi.QueryAccountDataRequest{
err = rp.userAPI.QueryAccountData(req.ctx, &userapi.QueryAccountDataRequest{
UserID: userID, UserID: userID,
RoomID: roomID, RoomID: roomID,
DataType: dataType, DataType: dataType,
}, &res) }
dataRes := userapi.QueryAccountDataResponse{}
err = rp.userAPI.QueryAccountData(req.ctx, &dataReq, &dataRes)
if err != nil { if err != nil {
return nil, err continue
} }
for t, d := range res.GlobalAccountData { if roomID == "" {
data.AccountData.Events = append( if globalData, ok := dataRes.GlobalAccountData[dataType]; ok {
data.AccountData.Events, data.AccountData.Events = append(
gomatrixserverlib.ClientEvent{ data.AccountData.Events,
Type: t, gomatrixserverlib.ClientEvent{
Content: gomatrixserverlib.RawJSON(d), Type: dataType,
}, Content: gomatrixserverlib.RawJSON(globalData),
) },
} )
for r, byRoom := range res.RoomAccountData { }
for t, d := range byRoom { } else {
joinData := data.Rooms.Join[r] if roomData, ok := dataRes.RoomAccountData[roomID][dataType]; ok {
joinData := data.Rooms.Join[roomID]
joinData.AccountData.Events = append( joinData.AccountData.Events = append(
joinData.AccountData.Events, joinData.AccountData.Events,
gomatrixserverlib.ClientEvent{ gomatrixserverlib.ClientEvent{
Type: t, Type: dataType,
Content: gomatrixserverlib.RawJSON(d), Content: gomatrixserverlib.RawJSON(roomData),
}, },
) )
data.Rooms.Join[r] = joinData data.Rooms.Join[roomID] = joinData
} }
} }
} }

View file

@ -137,6 +137,8 @@ func (a *UserInternalAPI) QueryDevices(ctx context.Context, req *api.QueryDevice
} }
func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAccountDataRequest, res *api.QueryAccountDataResponse) error { func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAccountDataRequest, res *api.QueryAccountDataResponse) error {
res.GlobalAccountData = make(map[string]json.RawMessage)
res.RoomAccountData = make(map[string]map[string]json.RawMessage)
local, domain, err := gomatrixserverlib.SplitID('@', req.UserID) local, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
if err != nil { if err != nil {
return err return err
@ -153,7 +155,7 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc
if data != nil { if data != nil {
if req.RoomID != "" { if req.RoomID != "" {
if _, ok := res.RoomAccountData[req.RoomID]; !ok { if _, ok := res.RoomAccountData[req.RoomID]; !ok {
res.RoomAccountData[req.RoomID] = map[string]json.RawMessage{} res.RoomAccountData[req.RoomID] = make(map[string]json.RawMessage)
} }
res.RoomAccountData[req.RoomID][req.DataType] = data res.RoomAccountData[req.RoomID][req.DataType] = data
} else { } else {