Various tweaks, update tag behaviour

This commit is contained in:
Neil Alexander 2020-06-18 17:42:51 +01:00
parent 35d5d00668
commit 59d7edaf6a
8 changed files with 83 additions and 114 deletions

View file

@ -104,11 +104,6 @@ func SaveAccountData(
} }
} }
if err := syncProducer.SendData(userID, roomID, dataType); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("syncProducer.SendData failed")
return jsonerror.InternalServerError()
}
dataReq := api.InputAccountDataRequest{ dataReq := api.InputAccountDataRequest{
UserID: userID, UserID: userID,
DataType: dataType, DataType: dataType,
@ -118,10 +113,13 @@ func SaveAccountData(
dataRes := api.InputAccountDataResponse{} dataRes := api.InputAccountDataResponse{}
if err := userAPI.InputAccountData(req.Context(), &dataReq, &dataRes); err != nil { if err := userAPI.InputAccountData(req.Context(), &dataReq, &dataRes); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryAccountData failed") util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryAccountData failed")
return util.JSONResponse{ return util.ErrorResponse(err)
Code: http.StatusNotFound,
JSON: jsonerror.Forbidden("data not found"),
} }
// TODO: user API should do this since it's account data
if err := syncProducer.SendData(userID, roomID, dataType); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("syncProducer.SendData failed")
return jsonerror.InternalServerError()
} }
return util.JSONResponse{ return util.JSONResponse{

View file

@ -24,23 +24,14 @@ import (
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/clientapi/producers" "github.com/matrix-org/dendrite/clientapi/producers"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts"
"github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
// newTag creates and returns a new gomatrix.TagContent
func newTag() gomatrix.TagContent {
return gomatrix.TagContent{
Tags: make(map[string]gomatrix.TagProperties),
}
}
// GetTags implements GET /_matrix/client/r0/user/{userID}/rooms/{roomID}/tags // GetTags implements GET /_matrix/client/r0/user/{userID}/rooms/{roomID}/tags
func GetTags( func GetTags(
req *http.Request, req *http.Request,
accountDB accounts.Database, userAPI api.UserInternalAPI,
device *api.Device, device *api.Device,
userID string, userID string,
roomID string, roomID string,
@ -54,22 +45,15 @@ func GetTags(
} }
} }
_, data, err := obtainSavedTags(req, userID, roomID, accountDB) tagContent, err := obtainSavedTags(req, userID, roomID, userAPI)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("obtainSavedTags failed") util.GetLogger(req.Context()).WithError(err).Error("obtainSavedTags failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
if data == nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
JSON: struct{}{}, JSON: tagContent,
}
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: data,
} }
} }
@ -78,7 +62,7 @@ func GetTags(
// the tag to the "map" and saving the new "map" to the DB // the tag to the "map" and saving the new "map" to the DB
func PutTag( func PutTag(
req *http.Request, req *http.Request,
accountDB accounts.Database, userAPI api.UserInternalAPI,
device *api.Device, device *api.Device,
userID string, userID string,
roomID string, roomID string,
@ -98,34 +82,25 @@ func PutTag(
return *reqErr return *reqErr
} }
localpart, data, err := obtainSavedTags(req, userID, roomID, accountDB) tagContent, err := obtainSavedTags(req, userID, roomID, userAPI)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("obtainSavedTags failed") util.GetLogger(req.Context()).WithError(err).Error("obtainSavedTags failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
var tagContent gomatrix.TagContent if tagContent.Tags == nil {
if data != nil { tagContent.Tags = make(map[string]gomatrix.TagProperties)
if err = json.Unmarshal(data, &tagContent); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("json.Unmarshal failed")
return jsonerror.InternalServerError()
}
} else {
tagContent = newTag()
} }
tagContent.Tags[tag] = properties tagContent.Tags[tag] = properties
if err = saveTagData(req, localpart, roomID, accountDB, tagContent); err != nil {
if err = saveTagData(req, userID, roomID, userAPI, tagContent); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("saveTagData failed") util.GetLogger(req.Context()).WithError(err).Error("saveTagData failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
// Send data to syncProducer in order to inform clients of changes if err = syncProducer.SendData(userID, roomID, "m.tag"); err != nil {
// Run in a goroutine in order to prevent blocking the tag request response
go func() {
if err := syncProducer.SendData(userID, roomID, "m.tag"); err != nil {
logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi") logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi")
} }
}()
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
@ -138,7 +113,7 @@ func PutTag(
// the "map" and then saving the new "map" in the DB // the "map" and then saving the new "map" in the DB
func DeleteTag( func DeleteTag(
req *http.Request, req *http.Request,
accountDB accounts.Database, userAPI api.UserInternalAPI,
device *api.Device, device *api.Device,
userID string, userID string,
roomID string, roomID string,
@ -153,28 +128,12 @@ func DeleteTag(
} }
} }
localpart, data, err := obtainSavedTags(req, userID, roomID, accountDB) tagContent, err := obtainSavedTags(req, userID, roomID, userAPI)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("obtainSavedTags failed") util.GetLogger(req.Context()).WithError(err).Error("obtainSavedTags failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
// If there are no tags in the database, exit
if data == nil {
// Spec only defines 200 responses for this endpoint so we don't return anything else.
return util.JSONResponse{
Code: http.StatusOK,
JSON: struct{}{},
}
}
var tagContent gomatrix.TagContent
err = json.Unmarshal(data, &tagContent)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("json.Unmarshal failed")
return jsonerror.InternalServerError()
}
// Check whether the tag to be deleted exists // Check whether the tag to be deleted exists
if _, ok := tagContent.Tags[tag]; ok { if _, ok := tagContent.Tags[tag]; ok {
delete(tagContent.Tags, tag) delete(tagContent.Tags, tag)
@ -185,18 +144,16 @@ func DeleteTag(
JSON: struct{}{}, JSON: struct{}{},
} }
} }
if err = saveTagData(req, localpart, roomID, accountDB, tagContent); err != nil {
if err = saveTagData(req, userID, roomID, userAPI, tagContent); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("saveTagData failed") util.GetLogger(req.Context()).WithError(err).Error("saveTagData failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
// Send data to syncProducer in order to inform clients of changes // TODO: user API should do this since it's account data
// Run in a goroutine in order to prevent blocking the tag request response
go func() {
if err := syncProducer.SendData(userID, roomID, "m.tag"); err != nil { if err := syncProducer.SendData(userID, roomID, "m.tag"); err != nil {
logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi") logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi")
} }
}()
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
@ -210,32 +167,46 @@ func obtainSavedTags(
req *http.Request, req *http.Request,
userID string, userID string,
roomID string, roomID string,
accountDB accounts.Database, userAPI api.UserInternalAPI,
) (string, json.RawMessage, error) { ) (tags gomatrix.TagContent, err error) {
localpart, _, err := gomatrixserverlib.SplitID('@', userID) dataReq := api.QueryAccountDataRequest{
if err != nil { UserID: userID,
return "", nil, err RoomID: roomID,
DataType: "m.tag",
} }
dataRes := api.QueryAccountDataResponse{}
data, err := accountDB.GetAccountDataByType( err = userAPI.QueryAccountData(req.Context(), &dataReq, &dataRes)
req.Context(), localpart, roomID, "m.tag", if err != nil {
) return
}
return localpart, data, err data, ok := dataRes.RoomAccountData[roomID]["m.tag"]
if !ok {
return
}
if err = json.Unmarshal(data, &tags); err != nil {
return
}
return tags, nil
} }
// saveTagData saves the provided tag data into the database // saveTagData saves the provided tag data into the database
func saveTagData( func saveTagData(
req *http.Request, req *http.Request,
localpart string, userID string,
roomID string, roomID string,
accountDB accounts.Database, userAPI api.UserInternalAPI,
Tag gomatrix.TagContent, Tag gomatrix.TagContent,
) error { ) error {
newTagData, err := json.Marshal(Tag) newTagData, err := json.Marshal(Tag)
if err != nil { if err != nil {
return err return err
} }
dataReq := api.InputAccountDataRequest{
return accountDB.SaveAccountData(req.Context(), localpart, roomID, "m.tag", json.RawMessage(newTagData)) UserID: userID,
RoomID: roomID,
DataType: "m.tag",
AccountData: json.RawMessage(newTagData),
}
dataRes := api.InputAccountDataResponse{}
return userAPI.InputAccountData(req.Context(), &dataReq, &dataRes)
} }

View file

@ -604,7 +604,7 @@ func Setup(
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return GetTags(req, accountDB, device, vars["userId"], vars["roomId"], syncProducer) return GetTags(req, userAPI, device, vars["userId"], vars["roomId"], syncProducer)
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
@ -614,7 +614,7 @@ func Setup(
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return PutTag(req, accountDB, device, vars["userId"], vars["roomId"], vars["tag"], syncProducer) return PutTag(req, userAPI, device, vars["userId"], vars["roomId"], vars["tag"], syncProducer)
}), }),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
@ -624,7 +624,7 @@ func Setup(
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return DeleteTag(req, accountDB, device, vars["userId"], vars["roomId"], vars["tag"], syncProducer) return DeleteTag(req, userAPI, device, vars["userId"], vars["roomId"], vars["tag"], syncProducer)
}), }),
).Methods(http.MethodDelete, http.MethodOptions) ).Methods(http.MethodDelete, http.MethodOptions)

View file

@ -205,14 +205,14 @@ func (rp *RequestPool) appendAccountData(
if req.since == nil { if req.since == nil {
// If this is the initial sync, we don't need to check if a data has // If this is the initial sync, we don't need to check if a data has
// already been sent. Instead, we send the whole batch. // already been sent. Instead, we send the whole batch.
var res userapi.QueryAccountDataResponse dataReq := &userapi.QueryAccountDataRequest{
err := rp.userAPI.QueryAccountData(req.ctx, &userapi.QueryAccountDataRequest{
UserID: userID, UserID: userID,
}, &res) }
if err != nil { dataRes := &userapi.QueryAccountDataResponse{}
if err := rp.userAPI.QueryAccountData(req.ctx, dataReq, dataRes); err != nil {
return nil, err return nil, err
} }
for datatype, databody := range res.GlobalAccountData { for datatype, databody := range dataRes.GlobalAccountData {
data.AccountData.Events = append( data.AccountData.Events = append(
data.AccountData.Events, data.AccountData.Events,
gomatrixserverlib.ClientEvent{ gomatrixserverlib.ClientEvent{
@ -222,7 +222,7 @@ func (rp *RequestPool) appendAccountData(
) )
} }
for r, j := range data.Rooms.Join { for r, j := range data.Rooms.Join {
for datatype, databody := range res.RoomAccountData[r] { for datatype, databody := range dataRes.RoomAccountData[r] {
j.AccountData.Events = append( j.AccountData.Events = append(
j.AccountData.Events, j.AccountData.Events,
gomatrixserverlib.ClientEvent{ gomatrixserverlib.ClientEvent{

View file

@ -36,7 +36,7 @@ type UserInternalAPI interface {
type InputAccountDataRequest struct { type InputAccountDataRequest struct {
UserID string // required: the user to set account data for UserID string // required: the user to set account data for
RoomID string // optional: the room to associate the account data with RoomID string // optional: the room to associate the account data with
DataType string // required: the data type of the data DataType string // optional: the data type of the data
AccountData json.RawMessage // required: the message content AccountData json.RawMessage // required: the message content
} }

View file

@ -137,8 +137,6 @@ func (a *UserInternalAPI) QueryDevices(ctx context.Context, req *api.QueryDevice
} }
func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAccountDataRequest, res *api.QueryAccountDataResponse) error { func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAccountDataRequest, res *api.QueryAccountDataResponse) error {
res.GlobalAccountData = make(map[string]json.RawMessage)
res.RoomAccountData = make(map[string]map[string]json.RawMessage)
local, domain, err := gomatrixserverlib.SplitID('@', req.UserID) local, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
if err != nil { if err != nil {
return err return err
@ -152,6 +150,8 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc
if err != nil { if err != nil {
return err return err
} }
res.RoomAccountData = make(map[string]map[string]json.RawMessage)
res.GlobalAccountData = make(map[string]json.RawMessage)
if data != nil { if data != nil {
if req.RoomID != "" { if req.RoomID != "" {
if _, ok := res.RoomAccountData[req.RoomID]; !ok { if _, ok := res.RoomAccountData[req.RoomID]; !ok {

View file

@ -83,18 +83,18 @@ func (s *accountDataStatements) insertAccountData(
func (s *accountDataStatements) selectAccountData( func (s *accountDataStatements) selectAccountData(
ctx context.Context, localpart string, ctx context.Context, localpart string,
) ( ) (
global map[string]json.RawMessage, /* global */ map[string]json.RawMessage,
rooms map[string]map[string]json.RawMessage, /* rooms */ map[string]map[string]json.RawMessage,
err error, error,
) { ) {
rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart) rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart)
if err != nil { if err != nil {
return return nil, nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "selectAccountData: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "selectAccountData: rows.close() failed")
global = map[string]json.RawMessage{} global := map[string]json.RawMessage{}
rooms = map[string]map[string]json.RawMessage{} rooms := map[string]map[string]json.RawMessage{}
for rows.Next() { for rows.Next() {
var roomID string var roomID string
@ -102,7 +102,7 @@ func (s *accountDataStatements) selectAccountData(
var content []byte var content []byte
if err = rows.Scan(&roomID, &dataType, &content); err != nil { if err = rows.Scan(&roomID, &dataType, &content); err != nil {
return return nil, nil, err
} }
if roomID != "" { if roomID != "" {

View file

@ -80,17 +80,17 @@ func (s *accountDataStatements) insertAccountData(
func (s *accountDataStatements) selectAccountData( func (s *accountDataStatements) selectAccountData(
ctx context.Context, localpart string, ctx context.Context, localpart string,
) ( ) (
global map[string]json.RawMessage, /* global */ map[string]json.RawMessage,
rooms map[string]map[string]json.RawMessage, /* rooms */ map[string]map[string]json.RawMessage,
err error, error,
) { ) {
rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart) rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart)
if err != nil { if err != nil {
return return nil, nil, err
} }
global = map[string]json.RawMessage{} global := map[string]json.RawMessage{}
rooms = map[string]map[string]json.RawMessage{} rooms := map[string]map[string]json.RawMessage{}
for rows.Next() { for rows.Next() {
var roomID string var roomID string
@ -98,7 +98,7 @@ func (s *accountDataStatements) selectAccountData(
var content []byte var content []byte
if err = rows.Scan(&roomID, &dataType, &content); err != nil { if err = rows.Scan(&roomID, &dataType, &content); err != nil {
return return nil, nil, err
} }
if roomID != "" { if roomID != "" {
@ -111,7 +111,7 @@ func (s *accountDataStatements) selectAccountData(
} }
} }
return return global, rooms, nil
} }
func (s *accountDataStatements) selectAccountDataByType( func (s *accountDataStatements) selectAccountDataByType(