From 59d7edaf6af7818ed4b083b5d84d35e236c56550 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 18 Jun 2020 17:42:51 +0100 Subject: [PATCH] Various tweaks, update tag behaviour --- clientapi/routing/account_data.go | 16 +-- clientapi/routing/room_tagging.go | 127 +++++++----------- clientapi/routing/routing.go | 6 +- syncapi/sync/requestpool.go | 12 +- userapi/api/api.go | 2 +- userapi/internal/api.go | 4 +- .../accounts/postgres/account_data_table.go | 14 +- .../accounts/sqlite3/account_data_table.go | 16 +-- 8 files changed, 83 insertions(+), 114 deletions(-) diff --git a/clientapi/routing/account_data.go b/clientapi/routing/account_data.go index 5e31a5899..d5fafedb1 100644 --- a/clientapi/routing/account_data.go +++ b/clientapi/routing/account_data.go @@ -104,11 +104,6 @@ func SaveAccountData( } } - if err := syncProducer.SendData(userID, roomID, dataType); err != nil { - util.GetLogger(req.Context()).WithError(err).Error("syncProducer.SendData failed") - return jsonerror.InternalServerError() - } - dataReq := api.InputAccountDataRequest{ UserID: userID, DataType: dataType, @@ -118,10 +113,13 @@ func SaveAccountData( dataRes := api.InputAccountDataResponse{} if err := userAPI.InputAccountData(req.Context(), &dataReq, &dataRes); err != nil { util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryAccountData failed") - return util.JSONResponse{ - Code: http.StatusNotFound, - JSON: jsonerror.Forbidden("data not found"), - } + return util.ErrorResponse(err) + } + + // TODO: user API should do this since it's account data + if err := syncProducer.SendData(userID, roomID, dataType); err != nil { + util.GetLogger(req.Context()).WithError(err).Error("syncProducer.SendData failed") + return jsonerror.InternalServerError() } return util.JSONResponse{ diff --git a/clientapi/routing/room_tagging.go b/clientapi/routing/room_tagging.go index e6271ab16..c683cc949 100644 --- a/clientapi/routing/room_tagging.go +++ b/clientapi/routing/room_tagging.go @@ -24,23 +24,14 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/producers" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" "github.com/matrix-org/gomatrix" - "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) -// newTag creates and returns a new gomatrix.TagContent -func newTag() gomatrix.TagContent { - return gomatrix.TagContent{ - Tags: make(map[string]gomatrix.TagProperties), - } -} - // GetTags implements GET /_matrix/client/r0/user/{userID}/rooms/{roomID}/tags func GetTags( req *http.Request, - accountDB accounts.Database, + userAPI api.UserInternalAPI, device *api.Device, userID string, roomID string, @@ -54,22 +45,15 @@ func GetTags( } } - _, data, err := obtainSavedTags(req, userID, roomID, accountDB) + tagContent, err := obtainSavedTags(req, userID, roomID, userAPI) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("obtainSavedTags failed") return jsonerror.InternalServerError() } - if data == nil { - return util.JSONResponse{ - Code: http.StatusOK, - JSON: struct{}{}, - } - } - return util.JSONResponse{ Code: http.StatusOK, - JSON: data, + JSON: tagContent, } } @@ -78,7 +62,7 @@ func GetTags( // the tag to the "map" and saving the new "map" to the DB func PutTag( req *http.Request, - accountDB accounts.Database, + userAPI api.UserInternalAPI, device *api.Device, userID string, roomID string, @@ -98,34 +82,25 @@ func PutTag( return *reqErr } - localpart, data, err := obtainSavedTags(req, userID, roomID, accountDB) + tagContent, err := obtainSavedTags(req, userID, roomID, userAPI) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("obtainSavedTags failed") return jsonerror.InternalServerError() } - var tagContent gomatrix.TagContent - if data != nil { - if err = json.Unmarshal(data, &tagContent); err != nil { - util.GetLogger(req.Context()).WithError(err).Error("json.Unmarshal failed") - return jsonerror.InternalServerError() - } - } else { - tagContent = newTag() + if tagContent.Tags == nil { + tagContent.Tags = make(map[string]gomatrix.TagProperties) } tagContent.Tags[tag] = properties - if err = saveTagData(req, localpart, roomID, accountDB, tagContent); err != nil { + + if err = saveTagData(req, userID, roomID, userAPI, tagContent); err != nil { util.GetLogger(req.Context()).WithError(err).Error("saveTagData failed") return jsonerror.InternalServerError() } - // Send data to syncProducer in order to inform clients of changes - // Run in a goroutine in order to prevent blocking the tag request response - go func() { - if err := syncProducer.SendData(userID, roomID, "m.tag"); err != nil { - logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi") - } - }() + if err = syncProducer.SendData(userID, roomID, "m.tag"); err != nil { + logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi") + } return util.JSONResponse{ Code: http.StatusOK, @@ -138,7 +113,7 @@ func PutTag( // the "map" and then saving the new "map" in the DB func DeleteTag( req *http.Request, - accountDB accounts.Database, + userAPI api.UserInternalAPI, device *api.Device, userID string, roomID string, @@ -153,28 +128,12 @@ func DeleteTag( } } - localpart, data, err := obtainSavedTags(req, userID, roomID, accountDB) + tagContent, err := obtainSavedTags(req, userID, roomID, userAPI) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("obtainSavedTags failed") return jsonerror.InternalServerError() } - // If there are no tags in the database, exit - if data == nil { - // Spec only defines 200 responses for this endpoint so we don't return anything else. - return util.JSONResponse{ - Code: http.StatusOK, - JSON: struct{}{}, - } - } - - var tagContent gomatrix.TagContent - err = json.Unmarshal(data, &tagContent) - if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("json.Unmarshal failed") - return jsonerror.InternalServerError() - } - // Check whether the tag to be deleted exists if _, ok := tagContent.Tags[tag]; ok { delete(tagContent.Tags, tag) @@ -185,18 +144,16 @@ func DeleteTag( JSON: struct{}{}, } } - if err = saveTagData(req, localpart, roomID, accountDB, tagContent); err != nil { + + if err = saveTagData(req, userID, roomID, userAPI, tagContent); err != nil { util.GetLogger(req.Context()).WithError(err).Error("saveTagData failed") return jsonerror.InternalServerError() } - // Send data to syncProducer in order to inform clients of changes - // Run in a goroutine in order to prevent blocking the tag request response - go func() { - if err := syncProducer.SendData(userID, roomID, "m.tag"); err != nil { - logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi") - } - }() + // TODO: user API should do this since it's account data + if err := syncProducer.SendData(userID, roomID, "m.tag"); err != nil { + logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi") + } return util.JSONResponse{ Code: http.StatusOK, @@ -210,32 +167,46 @@ func obtainSavedTags( req *http.Request, userID string, roomID string, - accountDB accounts.Database, -) (string, json.RawMessage, error) { - localpart, _, err := gomatrixserverlib.SplitID('@', userID) - if err != nil { - return "", nil, err + userAPI api.UserInternalAPI, +) (tags gomatrix.TagContent, err error) { + dataReq := api.QueryAccountDataRequest{ + UserID: userID, + RoomID: roomID, + DataType: "m.tag", } - - data, err := accountDB.GetAccountDataByType( - req.Context(), localpart, roomID, "m.tag", - ) - - return localpart, data, err + dataRes := api.QueryAccountDataResponse{} + err = userAPI.QueryAccountData(req.Context(), &dataReq, &dataRes) + if err != nil { + return + } + data, ok := dataRes.RoomAccountData[roomID]["m.tag"] + if !ok { + return + } + if err = json.Unmarshal(data, &tags); err != nil { + return + } + return tags, nil } // saveTagData saves the provided tag data into the database func saveTagData( req *http.Request, - localpart string, + userID string, roomID string, - accountDB accounts.Database, + userAPI api.UserInternalAPI, Tag gomatrix.TagContent, ) error { newTagData, err := json.Marshal(Tag) if err != nil { return err } - - return accountDB.SaveAccountData(req.Context(), localpart, roomID, "m.tag", json.RawMessage(newTagData)) + dataReq := api.InputAccountDataRequest{ + UserID: userID, + RoomID: roomID, + DataType: "m.tag", + AccountData: json.RawMessage(newTagData), + } + dataRes := api.InputAccountDataResponse{} + return userAPI.InputAccountData(req.Context(), &dataReq, &dataRes) } diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index b478f10af..e91b07ac7 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -604,7 +604,7 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return GetTags(req, accountDB, device, vars["userId"], vars["roomId"], syncProducer) + return GetTags(req, userAPI, device, vars["userId"], vars["roomId"], syncProducer) }), ).Methods(http.MethodGet, http.MethodOptions) @@ -614,7 +614,7 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return PutTag(req, accountDB, device, vars["userId"], vars["roomId"], vars["tag"], syncProducer) + return PutTag(req, userAPI, device, vars["userId"], vars["roomId"], vars["tag"], syncProducer) }), ).Methods(http.MethodPut, http.MethodOptions) @@ -624,7 +624,7 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return DeleteTag(req, accountDB, device, vars["userId"], vars["roomId"], vars["tag"], syncProducer) + return DeleteTag(req, userAPI, device, vars["userId"], vars["roomId"], vars["tag"], syncProducer) }), ).Methods(http.MethodDelete, http.MethodOptions) diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index 09cf9b065..1a7ff6e53 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -205,14 +205,14 @@ func (rp *RequestPool) appendAccountData( if req.since == nil { // If this is the initial sync, we don't need to check if a data has // already been sent. Instead, we send the whole batch. - var res userapi.QueryAccountDataResponse - err := rp.userAPI.QueryAccountData(req.ctx, &userapi.QueryAccountDataRequest{ + dataReq := &userapi.QueryAccountDataRequest{ UserID: userID, - }, &res) - if err != nil { + } + dataRes := &userapi.QueryAccountDataResponse{} + if err := rp.userAPI.QueryAccountData(req.ctx, dataReq, dataRes); err != nil { return nil, err } - for datatype, databody := range res.GlobalAccountData { + for datatype, databody := range dataRes.GlobalAccountData { data.AccountData.Events = append( data.AccountData.Events, gomatrixserverlib.ClientEvent{ @@ -222,7 +222,7 @@ func (rp *RequestPool) appendAccountData( ) } for r, j := range data.Rooms.Join { - for datatype, databody := range res.RoomAccountData[r] { + for datatype, databody := range dataRes.RoomAccountData[r] { j.AccountData.Events = append( j.AccountData.Events, gomatrixserverlib.ClientEvent{ diff --git a/userapi/api/api.go b/userapi/api/api.go index cf0f05633..a80adf2d8 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -36,7 +36,7 @@ type UserInternalAPI interface { type InputAccountDataRequest struct { UserID string // required: the user to set account data for RoomID string // optional: the room to associate the account data with - DataType string // required: the data type of the data + DataType string // optional: the data type of the data AccountData json.RawMessage // required: the message content } diff --git a/userapi/internal/api.go b/userapi/internal/api.go index b970a62c8..b081eca49 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -137,8 +137,6 @@ func (a *UserInternalAPI) QueryDevices(ctx context.Context, req *api.QueryDevice } func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAccountDataRequest, res *api.QueryAccountDataResponse) error { - res.GlobalAccountData = make(map[string]json.RawMessage) - res.RoomAccountData = make(map[string]map[string]json.RawMessage) local, domain, err := gomatrixserverlib.SplitID('@', req.UserID) if err != nil { return err @@ -152,6 +150,8 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc if err != nil { return err } + res.RoomAccountData = make(map[string]map[string]json.RawMessage) + res.GlobalAccountData = make(map[string]json.RawMessage) if data != nil { if req.RoomID != "" { if _, ok := res.RoomAccountData[req.RoomID]; !ok { diff --git a/userapi/storage/accounts/postgres/account_data_table.go b/userapi/storage/accounts/postgres/account_data_table.go index 5fabadc67..90c79e878 100644 --- a/userapi/storage/accounts/postgres/account_data_table.go +++ b/userapi/storage/accounts/postgres/account_data_table.go @@ -83,18 +83,18 @@ func (s *accountDataStatements) insertAccountData( func (s *accountDataStatements) selectAccountData( ctx context.Context, localpart string, ) ( - global map[string]json.RawMessage, - rooms map[string]map[string]json.RawMessage, - err error, + /* global */ map[string]json.RawMessage, + /* rooms */ map[string]map[string]json.RawMessage, + error, ) { rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart) if err != nil { - return + return nil, nil, err } defer internal.CloseAndLogIfError(ctx, rows, "selectAccountData: rows.close() failed") - global = map[string]json.RawMessage{} - rooms = map[string]map[string]json.RawMessage{} + global := map[string]json.RawMessage{} + rooms := map[string]map[string]json.RawMessage{} for rows.Next() { var roomID string @@ -102,7 +102,7 @@ func (s *accountDataStatements) selectAccountData( var content []byte if err = rows.Scan(&roomID, &dataType, &content); err != nil { - return + return nil, nil, err } if roomID != "" { diff --git a/userapi/storage/accounts/sqlite3/account_data_table.go b/userapi/storage/accounts/sqlite3/account_data_table.go index 78f03a3fa..d048dbd19 100644 --- a/userapi/storage/accounts/sqlite3/account_data_table.go +++ b/userapi/storage/accounts/sqlite3/account_data_table.go @@ -80,17 +80,17 @@ func (s *accountDataStatements) insertAccountData( func (s *accountDataStatements) selectAccountData( ctx context.Context, localpart string, ) ( - global map[string]json.RawMessage, - rooms map[string]map[string]json.RawMessage, - err error, + /* global */ map[string]json.RawMessage, + /* rooms */ map[string]map[string]json.RawMessage, + error, ) { rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart) if err != nil { - return + return nil, nil, err } - global = map[string]json.RawMessage{} - rooms = map[string]map[string]json.RawMessage{} + global := map[string]json.RawMessage{} + rooms := map[string]map[string]json.RawMessage{} for rows.Next() { var roomID string @@ -98,7 +98,7 @@ func (s *accountDataStatements) selectAccountData( var content []byte if err = rows.Scan(&roomID, &dataType, &content); err != nil { - return + return nil, nil, err } if roomID != "" { @@ -111,7 +111,7 @@ func (s *accountDataStatements) selectAccountData( } } - return + return global, rooms, nil } func (s *accountDataStatements) selectAccountDataByType(