From b8ae5f5f81e3cb9e03b0fb4e1ac1a10066d39dd0 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 18 Jun 2020 15:21:03 +0100 Subject: [PATCH] Refactor account data --- clientapi/routing/account_data.go | 56 +++++++++---------- clientapi/routing/room_tagging.go | 10 ++-- clientapi/routing/routing.go | 8 +-- syncapi/sync/requestpool.go | 56 +++++++++++++------ userapi/api/api.go | 27 ++++++--- userapi/internal/api.go | 29 ++++++++-- userapi/inthttp/client.go | 10 ++++ userapi/storage/accounts/interface.go | 7 ++- .../accounts/postgres/account_data_table.go | 42 ++++++-------- userapi/storage/accounts/postgres/storage.go | 13 +++-- .../accounts/sqlite3/account_data_table.go | 40 +++++-------- userapi/storage/accounts/sqlite3/storage.go | 13 +++-- 12 files changed, 172 insertions(+), 139 deletions(-) diff --git a/clientapi/routing/account_data.go b/clientapi/routing/account_data.go index 68e0dc5da..c4f53ad7a 100644 --- a/clientapi/routing/account_data.go +++ b/clientapi/routing/account_data.go @@ -22,15 +22,13 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/producers" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" - "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) // GetAccountData implements GET /user/{userId}/[rooms/{roomid}/]account_data/{type} func GetAccountData( - req *http.Request, accountDB accounts.Database, device *api.Device, + req *http.Request, userAPI api.UserInternalAPI, device *api.Device, userID string, roomID string, dataType string, ) util.JSONResponse { if userID != device.UserID { @@ -40,30 +38,29 @@ func GetAccountData( } } - localpart, _, err := gomatrixserverlib.SplitID('@', userID) - if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") - return jsonerror.InternalServerError() + dataReq := api.QueryAccountDataRequest{ + UserID: userID, + DataType: dataType, + RoomID: roomID, } - - if data, err := accountDB.GetAccountDataByType( - req.Context(), localpart, roomID, dataType, - ); err == nil { + dataRes := api.QueryAccountDataResponse{} + if err := userAPI.QueryAccountData(req.Context(), &dataReq, &dataRes); err != nil { + util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryAccountData failed") return util.JSONResponse{ - Code: http.StatusOK, - JSON: data.Content, + Code: http.StatusNotFound, + JSON: jsonerror.Forbidden("data not found"), } } return util.JSONResponse{ - Code: http.StatusNotFound, - JSON: jsonerror.Forbidden("data not found"), + Code: http.StatusOK, + JSON: dataRes.RoomAccountData, } } // SaveAccountData implements PUT /user/{userId}/[rooms/{roomId}/]account_data/{type} func SaveAccountData( - req *http.Request, accountDB accounts.Database, device *api.Device, + req *http.Request, userAPI api.UserInternalAPI, device *api.Device, userID string, roomID string, dataType string, syncProducer *producers.SyncAPIProducer, ) util.JSONResponse { if userID != device.UserID { @@ -73,12 +70,6 @@ func SaveAccountData( } } - localpart, _, err := gomatrixserverlib.SplitID('@', userID) - if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") - return jsonerror.InternalServerError() - } - defer req.Body.Close() // nolint: errcheck if req.Body == http.NoBody { @@ -101,16 +92,19 @@ func SaveAccountData( } } - if err := accountDB.SaveAccountData( - req.Context(), localpart, roomID, dataType, string(body), - ); err != nil { - util.GetLogger(req.Context()).WithError(err).Error("accountDB.SaveAccountData failed") - return jsonerror.InternalServerError() + dataReq := api.InputAccountDataRequest{ + UserID: userID, + DataType: dataType, + RoomID: roomID, + AccountData: json.RawMessage(body), } - - if err := syncProducer.SendData(userID, roomID, dataType); err != nil { - util.GetLogger(req.Context()).WithError(err).Error("syncProducer.SendData failed") - return jsonerror.InternalServerError() + dataRes := api.InputAccountDataResponse{} + if err := userAPI.InputAccountData(req.Context(), &dataReq, &dataRes); err != nil { + util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryAccountData failed") + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: jsonerror.Forbidden("data not found"), + } } return util.JSONResponse{ diff --git a/clientapi/routing/room_tagging.go b/clientapi/routing/room_tagging.go index b1cfcca86..e6271ab16 100644 --- a/clientapi/routing/room_tagging.go +++ b/clientapi/routing/room_tagging.go @@ -69,7 +69,7 @@ func GetTags( return util.JSONResponse{ Code: http.StatusOK, - JSON: data.Content, + JSON: data, } } @@ -106,7 +106,7 @@ func PutTag( var tagContent gomatrix.TagContent if data != nil { - if err = json.Unmarshal(data.Content, &tagContent); err != nil { + if err = json.Unmarshal(data, &tagContent); err != nil { util.GetLogger(req.Context()).WithError(err).Error("json.Unmarshal failed") return jsonerror.InternalServerError() } @@ -169,7 +169,7 @@ func DeleteTag( } var tagContent gomatrix.TagContent - err = json.Unmarshal(data.Content, &tagContent) + err = json.Unmarshal(data, &tagContent) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("json.Unmarshal failed") return jsonerror.InternalServerError() @@ -211,7 +211,7 @@ func obtainSavedTags( userID string, roomID string, accountDB accounts.Database, -) (string, *gomatrixserverlib.ClientEvent, error) { +) (string, json.RawMessage, error) { localpart, _, err := gomatrixserverlib.SplitID('@', userID) if err != nil { return "", nil, err @@ -237,5 +237,5 @@ func saveTagData( return err } - return accountDB.SaveAccountData(req.Context(), localpart, roomID, "m.tag", string(newTagData)) + return accountDB.SaveAccountData(req.Context(), localpart, roomID, "m.tag", json.RawMessage(newTagData)) } diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 41c7fb18e..b478f10af 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -476,7 +476,7 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return SaveAccountData(req, accountDB, device, vars["userID"], "", vars["type"], syncProducer) + return SaveAccountData(req, userAPI, device, vars["userID"], "", vars["type"], syncProducer) }), ).Methods(http.MethodPut, http.MethodOptions) @@ -486,7 +486,7 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return SaveAccountData(req, accountDB, device, vars["userID"], vars["roomID"], vars["type"], syncProducer) + return SaveAccountData(req, userAPI, device, vars["userID"], vars["roomID"], vars["type"], syncProducer) }), ).Methods(http.MethodPut, http.MethodOptions) @@ -496,7 +496,7 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return GetAccountData(req, accountDB, device, vars["userID"], "", vars["type"]) + return GetAccountData(req, userAPI, device, vars["userID"], "", vars["type"]) }), ).Methods(http.MethodGet) @@ -506,7 +506,7 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return GetAccountData(req, accountDB, device, vars["userID"], vars["roomID"], vars["type"]) + return GetAccountData(req, userAPI, device, vars["userID"], vars["roomID"], vars["type"]) }), ).Methods(http.MethodGet) diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index 26b925eac..ebb83e17a 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -212,12 +212,25 @@ func (rp *RequestPool) appendAccountData( if err != nil { return nil, err } - data.AccountData.Events = res.GlobalAccountData + for datatype, databody := range res.GlobalAccountData { + data.AccountData.Events = append( + data.AccountData.Events, + gomatrixserverlib.ClientEvent{ + Type: datatype, + Content: gomatrixserverlib.RawJSON(databody), + }, + ) + } for r, j := range data.Rooms.Join { - if len(res.RoomAccountData[r]) > 0 { - j.AccountData.Events = res.RoomAccountData[r] - data.Rooms.Join[r] = j + for datatype, databody := range res.RoomAccountData[r] { + j.AccountData.Events = append( + j.AccountData.Events, + gomatrixserverlib.ClientEvent{ + Type: datatype, + Content: gomatrixserverlib.RawJSON(databody), + }, + ) } } @@ -249,7 +262,6 @@ func (rp *RequestPool) appendAccountData( // Iterate over the rooms for roomID, dataTypes := range dataTypes { - events := []gomatrixserverlib.ClientEvent{} // Request the missing data from the database for _, dataType := range dataTypes { var res userapi.QueryAccountDataResponse @@ -261,20 +273,28 @@ func (rp *RequestPool) appendAccountData( if err != nil { return nil, err } - if len(res.RoomAccountData[roomID]) > 0 { - events = append(events, res.RoomAccountData[roomID]...) - } else if len(res.GlobalAccountData) > 0 { - events = append(events, res.GlobalAccountData...) + for t, d := range res.GlobalAccountData { + data.AccountData.Events = append( + data.AccountData.Events, + gomatrixserverlib.ClientEvent{ + Type: t, + Content: gomatrixserverlib.RawJSON(d), + }, + ) + } + for r, byRoom := range res.RoomAccountData { + for t, d := range byRoom { + joinData := data.Rooms.Join[r] + joinData.AccountData.Events = append( + joinData.AccountData.Events, + gomatrixserverlib.ClientEvent{ + Type: t, + Content: gomatrixserverlib.RawJSON(d), + }, + ) + data.Rooms.Join[r] = joinData + } } - } - - // Append the data to the response - if len(roomID) > 0 { - jr := data.Rooms.Join[roomID] - jr.AccountData.Events = events - data.Rooms.Join[roomID] = jr - } else { - data.AccountData.Events = events } } diff --git a/userapi/api/api.go b/userapi/api/api.go index c953a5bac..cf0f05633 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -16,12 +16,14 @@ package api import ( "context" + "encoding/json" "github.com/matrix-org/gomatrixserverlib" ) // UserInternalAPI is the internal API for information about users and devices. type UserInternalAPI interface { + InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error @@ -30,6 +32,18 @@ type UserInternalAPI interface { QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error } +// InputAccountDataRequest is the request for InputAccountData +type InputAccountDataRequest struct { + UserID string // required: the user to set account data for + RoomID string // optional: the room to associate the account data with + DataType string // required: the data type of the data + AccountData json.RawMessage // required: the message content +} + +// InputAccountDataResponse is the response for InputAccountData +type InputAccountDataResponse struct { +} + // QueryAccessTokenRequest is the request for QueryAccessToken type QueryAccessTokenRequest struct { AccessToken string @@ -46,18 +60,15 @@ type QueryAccessTokenResponse struct { // QueryAccountDataRequest is the request for QueryAccountData type QueryAccountDataRequest struct { - UserID string // required: the user to get account data for. - // TODO: This is a terribly confusing API shape :/ - DataType string // optional: if specified returns only a single event matching this data type. - // optional: Only used if DataType is set. If blank returns global account data matching the data type. - // If set, returns only room account data matching this data type. - RoomID string + UserID string // required: the user to get account data for. + RoomID string // optional: the room ID, or global account data if not specified. + DataType string // optional: the data type, or all types if not specified. } // QueryAccountDataResponse is the response for QueryAccountData type QueryAccountDataResponse struct { - GlobalAccountData []gomatrixserverlib.ClientEvent - RoomAccountData map[string][]gomatrixserverlib.ClientEvent + GlobalAccountData map[string]json.RawMessage // type -> data + RoomAccountData map[string]map[string]json.RawMessage // room -> type -> data } // QueryDevicesRequest is the request for QueryDevices diff --git a/userapi/internal/api.go b/userapi/internal/api.go index ae021f575..493cd5680 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -17,6 +17,7 @@ package internal import ( "context" "database/sql" + "encoding/json" "errors" "fmt" @@ -38,6 +39,20 @@ type UserInternalAPI struct { AppServices []config.ApplicationService } +func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error { + local, domain, err := gomatrixserverlib.SplitID('@', req.UserID) + if err != nil { + return err + } + if domain != a.ServerName { + return fmt.Errorf("cannot query profile of remote users: got %s want %s", domain, a.ServerName) + } + if req.DataType == "" { + return fmt.Errorf("data type must not be empty") + } + return a.AccountDB.SaveAccountData(ctx, local, req.RoomID, req.DataType, req.AccountData) +} + func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error { if req.AccountType == api.AccountTypeGuest { acc, err := a.AccountDB.CreateGuestAccount(ctx) @@ -130,17 +145,19 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc return fmt.Errorf("cannot query account data of remote users: got %s want %s", domain, a.ServerName) } if req.DataType != "" { - var event *gomatrixserverlib.ClientEvent - event, err = a.AccountDB.GetAccountDataByType(ctx, local, req.RoomID, req.DataType) + var data json.RawMessage + data, err = a.AccountDB.GetAccountDataByType(ctx, local, req.RoomID, req.DataType) if err != nil { return err } - if event != nil { + if data != nil { if req.RoomID != "" { - res.RoomAccountData = make(map[string][]gomatrixserverlib.ClientEvent) - res.RoomAccountData[req.RoomID] = []gomatrixserverlib.ClientEvent{*event} + if _, ok := res.RoomAccountData[req.RoomID]; !ok { + res.RoomAccountData[req.RoomID] = map[string]json.RawMessage{} + } + res.RoomAccountData[req.RoomID][req.DataType] = data } else { - res.GlobalAccountData = append(res.GlobalAccountData, *event) + res.GlobalAccountData[req.DataType] = data } } return nil diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go index 0e9628c58..4ab0d690e 100644 --- a/userapi/inthttp/client.go +++ b/userapi/inthttp/client.go @@ -26,6 +26,8 @@ import ( // HTTP paths for the internal HTTP APIs const ( + InputAccountDataPath = "/userapi/inputAccountData" + PerformDeviceCreationPath = "/userapi/performDeviceCreation" PerformAccountCreationPath = "/userapi/performAccountCreation" @@ -55,6 +57,14 @@ type httpUserInternalAPI struct { httpClient *http.Client } +func (h *httpUserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "InputAccountData") + defer span.Finish() + + apiURL := h.apiURL + InputAccountDataPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +} + func (h *httpUserInternalAPI) PerformAccountCreation( ctx context.Context, request *api.PerformAccountCreationRequest, diff --git a/userapi/storage/accounts/interface.go b/userapi/storage/accounts/interface.go index 13e3e2895..c6692879b 100644 --- a/userapi/storage/accounts/interface.go +++ b/userapi/storage/accounts/interface.go @@ -16,6 +16,7 @@ package accounts import ( "context" + "encoding/json" "errors" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" @@ -39,13 +40,13 @@ type Database interface { GetMembershipInRoomByLocalpart(ctx context.Context, localpart, roomID string) (authtypes.Membership, error) GetRoomIDsByLocalPart(ctx context.Context, localpart string) ([]string, error) GetMembershipsByLocalpart(ctx context.Context, localpart string) (memberships []authtypes.Membership, err error) - SaveAccountData(ctx context.Context, localpart, roomID, dataType, content string) error - GetAccountData(ctx context.Context, localpart string) (global []gomatrixserverlib.ClientEvent, rooms map[string][]gomatrixserverlib.ClientEvent, err error) + SaveAccountData(ctx context.Context, localpart, roomID, dataType string, content json.RawMessage) error + GetAccountData(ctx context.Context, localpart string) (global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error) // GetAccountDataByType returns account data matching a given // localpart, room ID and type. // If no account data could be found, returns nil // Returns an error if there was an issue with the retrieval - GetAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data *gomatrixserverlib.ClientEvent, err error) + GetAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data json.RawMessage, err error) GetNewNumericLocalpart(ctx context.Context) (int64, error) SaveThreePIDAssociation(ctx context.Context, threepid, localpart, medium string) (err error) RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error) diff --git a/userapi/storage/accounts/postgres/account_data_table.go b/userapi/storage/accounts/postgres/account_data_table.go index 2f16c5c02..7eb5841be 100644 --- a/userapi/storage/accounts/postgres/account_data_table.go +++ b/userapi/storage/accounts/postgres/account_data_table.go @@ -17,9 +17,9 @@ package postgres import ( "context" "database/sql" + "encoding/json" "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/gomatrixserverlib" ) const accountDataSchema = ` @@ -73,7 +73,7 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) { } func (s *accountDataStatements) insertAccountData( - ctx context.Context, txn *sql.Tx, localpart, roomID, dataType, content string, + ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage, ) (err error) { stmt := txn.Stmt(s.insertAccountDataStmt) _, err = stmt.ExecContext(ctx, localpart, roomID, dataType, content) @@ -83,8 +83,8 @@ func (s *accountDataStatements) insertAccountData( func (s *accountDataStatements) selectAccountData( ctx context.Context, localpart string, ) ( - global []gomatrixserverlib.ClientEvent, - rooms map[string][]gomatrixserverlib.ClientEvent, + global map[string]json.RawMessage, + rooms map[string]map[string]json.RawMessage, err error, ) { rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart) @@ -93,50 +93,40 @@ func (s *accountDataStatements) selectAccountData( } defer internal.CloseAndLogIfError(ctx, rows, "selectAccountData: rows.close() failed") - global = []gomatrixserverlib.ClientEvent{} - rooms = make(map[string][]gomatrixserverlib.ClientEvent) + global = map[string]json.RawMessage{} + rooms = map[string]map[string]json.RawMessage{} for rows.Next() { var roomID string var dataType string - var content []byte + var content json.RawMessage if err = rows.Scan(&roomID, &dataType, &content); err != nil { return } - ac := gomatrixserverlib.ClientEvent{ - Type: dataType, - Content: content, - } - - if len(roomID) > 0 { - rooms[roomID] = append(rooms[roomID], ac) + if roomID != "" { + if _, ok := rooms[roomID]; !ok { + rooms[roomID] = map[string]json.RawMessage{} + } + rooms[roomID][dataType] = content } else { - global = append(global, ac) + global[dataType] = content } } + return global, rooms, rows.Err() } func (s *accountDataStatements) selectAccountDataByType( ctx context.Context, localpart, roomID, dataType string, -) (data *gomatrixserverlib.ClientEvent, err error) { +) (data json.RawMessage, err error) { stmt := s.selectAccountDataByTypeStmt - var content []byte - - if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&content); err != nil { + if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&data); err != nil { if err == sql.ErrNoRows { return nil, nil } - return } - - data = &gomatrixserverlib.ClientEvent{ - Type: dataType, - Content: content, - } - return } diff --git a/userapi/storage/accounts/postgres/storage.go b/userapi/storage/accounts/postgres/storage.go index 2b88cb70a..e55099800 100644 --- a/userapi/storage/accounts/postgres/storage.go +++ b/userapi/storage/accounts/postgres/storage.go @@ -17,6 +17,7 @@ package postgres import ( "context" "database/sql" + "encoding/json" "errors" "strconv" @@ -169,7 +170,7 @@ func (d *Database) createAccount( return nil, err } - if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", `{ + if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{ "global": { "content": [], "override": [], @@ -177,7 +178,7 @@ func (d *Database) createAccount( "sender": [], "underride": [] } - }`); err != nil { + }`)); err != nil { return nil, err } return d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID) @@ -295,7 +296,7 @@ func (d *Database) newMembership( // update the corresponding row with the new content // Returns a SQL error if there was an issue with the insertion/update func (d *Database) SaveAccountData( - ctx context.Context, localpart, roomID, dataType, content string, + ctx context.Context, localpart, roomID, dataType string, content json.RawMessage, ) error { return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content) @@ -306,8 +307,8 @@ func (d *Database) SaveAccountData( // If no account data could be found, returns an empty arrays // Returns an error if there was an issue with the retrieval func (d *Database) GetAccountData(ctx context.Context, localpart string) ( - global []gomatrixserverlib.ClientEvent, - rooms map[string][]gomatrixserverlib.ClientEvent, + global map[string]json.RawMessage, + rooms map[string]map[string]json.RawMessage, err error, ) { return d.accountDatas.selectAccountData(ctx, localpart) @@ -319,7 +320,7 @@ func (d *Database) GetAccountData(ctx context.Context, localpart string) ( // Returns an error if there was an issue with the retrieval func (d *Database) GetAccountDataByType( ctx context.Context, localpart, roomID, dataType string, -) (data *gomatrixserverlib.ClientEvent, err error) { +) (data json.RawMessage, err error) { return d.accountDatas.selectAccountDataByType( ctx, localpart, roomID, dataType, ) diff --git a/userapi/storage/accounts/sqlite3/account_data_table.go b/userapi/storage/accounts/sqlite3/account_data_table.go index b6bb63617..3227ee248 100644 --- a/userapi/storage/accounts/sqlite3/account_data_table.go +++ b/userapi/storage/accounts/sqlite3/account_data_table.go @@ -17,8 +17,7 @@ package sqlite3 import ( "context" "database/sql" - - "github.com/matrix-org/gomatrixserverlib" + "encoding/json" ) const accountDataSchema = ` @@ -72,7 +71,7 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) { } func (s *accountDataStatements) insertAccountData( - ctx context.Context, txn *sql.Tx, localpart, roomID, dataType, content string, + ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage, ) (err error) { _, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content) return @@ -81,8 +80,8 @@ func (s *accountDataStatements) insertAccountData( func (s *accountDataStatements) selectAccountData( ctx context.Context, localpart string, ) ( - global []gomatrixserverlib.ClientEvent, - rooms map[string][]gomatrixserverlib.ClientEvent, + global map[string]json.RawMessage, + rooms map[string]map[string]json.RawMessage, err error, ) { rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart) @@ -90,8 +89,8 @@ func (s *accountDataStatements) selectAccountData( return } - global = []gomatrixserverlib.ClientEvent{} - rooms = make(map[string][]gomatrixserverlib.ClientEvent) + global = map[string]json.RawMessage{} + rooms = map[string]map[string]json.RawMessage{} for rows.Next() { var roomID string @@ -102,15 +101,13 @@ func (s *accountDataStatements) selectAccountData( return } - ac := gomatrixserverlib.ClientEvent{ - Type: dataType, - Content: content, - } - - if len(roomID) > 0 { - rooms[roomID] = append(rooms[roomID], ac) + if roomID != "" { + if _, ok := rooms[roomID]; !ok { + rooms[roomID] = map[string]json.RawMessage{} + } + rooms[roomID][dataType] = content } else { - global = append(global, ac) + global[dataType] = content } } @@ -119,22 +116,13 @@ func (s *accountDataStatements) selectAccountData( func (s *accountDataStatements) selectAccountDataByType( ctx context.Context, localpart, roomID, dataType string, -) (data *gomatrixserverlib.ClientEvent, err error) { +) (data json.RawMessage, err error) { stmt := s.selectAccountDataByTypeStmt - var content []byte - - if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&content); err != nil { + if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&data); err != nil { if err == sql.ErrNoRows { return nil, nil } - return } - - data = &gomatrixserverlib.ClientEvent{ - Type: dataType, - Content: content, - } - return } diff --git a/userapi/storage/accounts/sqlite3/storage.go b/userapi/storage/accounts/sqlite3/storage.go index 4dd755a70..dbf6606c3 100644 --- a/userapi/storage/accounts/sqlite3/storage.go +++ b/userapi/storage/accounts/sqlite3/storage.go @@ -17,6 +17,7 @@ package sqlite3 import ( "context" "database/sql" + "encoding/json" "errors" "strconv" "sync" @@ -180,7 +181,7 @@ func (d *Database) createAccount( return nil, err } - if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", `{ + if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{ "global": { "content": [], "override": [], @@ -188,7 +189,7 @@ func (d *Database) createAccount( "sender": [], "underride": [] } - }`); err != nil { + }`)); err != nil { return nil, err } return d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID) @@ -306,7 +307,7 @@ func (d *Database) newMembership( // update the corresponding row with the new content // Returns a SQL error if there was an issue with the insertion/update func (d *Database) SaveAccountData( - ctx context.Context, localpart, roomID, dataType, content string, + ctx context.Context, localpart, roomID, dataType string, content json.RawMessage, ) error { return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content) @@ -317,8 +318,8 @@ func (d *Database) SaveAccountData( // If no account data could be found, returns an empty arrays // Returns an error if there was an issue with the retrieval func (d *Database) GetAccountData(ctx context.Context, localpart string) ( - global []gomatrixserverlib.ClientEvent, - rooms map[string][]gomatrixserverlib.ClientEvent, + global map[string]json.RawMessage, + rooms map[string]map[string]json.RawMessage, err error, ) { return d.accountDatas.selectAccountData(ctx, localpart) @@ -330,7 +331,7 @@ func (d *Database) GetAccountData(ctx context.Context, localpart string) ( // Returns an error if there was an issue with the retrieval func (d *Database) GetAccountDataByType( ctx context.Context, localpart, roomID, dataType string, -) (data *gomatrixserverlib.ClientEvent, err error) { +) (data json.RawMessage, err error) { return d.accountDatas.selectAccountDataByType( ctx, localpart, roomID, dataType, )