Refactor account data

This commit is contained in:
Neil Alexander 2020-06-18 15:21:03 +01:00
parent 3547a1768c
commit b8ae5f5f81
12 changed files with 172 additions and 139 deletions

View file

@ -22,15 +22,13 @@ import (
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/clientapi/producers" "github.com/matrix-org/dendrite/clientapi/producers"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
// GetAccountData implements GET /user/{userId}/[rooms/{roomid}/]account_data/{type} // GetAccountData implements GET /user/{userId}/[rooms/{roomid}/]account_data/{type}
func GetAccountData( func GetAccountData(
req *http.Request, accountDB accounts.Database, device *api.Device, req *http.Request, userAPI api.UserInternalAPI, device *api.Device,
userID string, roomID string, dataType string, userID string, roomID string, dataType string,
) util.JSONResponse { ) util.JSONResponse {
if userID != device.UserID { if userID != device.UserID {
@ -40,30 +38,29 @@ func GetAccountData(
} }
} }
localpart, _, err := gomatrixserverlib.SplitID('@', userID) dataReq := api.QueryAccountDataRequest{
if err != nil { UserID: userID,
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") DataType: dataType,
return jsonerror.InternalServerError() RoomID: roomID,
} }
dataRes := api.QueryAccountDataResponse{}
if data, err := accountDB.GetAccountDataByType( if err := userAPI.QueryAccountData(req.Context(), &dataReq, &dataRes); err != nil {
req.Context(), localpart, roomID, dataType, util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryAccountData failed")
); err == nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusNotFound,
JSON: data.Content, JSON: jsonerror.Forbidden("data not found"),
} }
} }
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusNotFound, Code: http.StatusOK,
JSON: jsonerror.Forbidden("data not found"), JSON: dataRes.RoomAccountData,
} }
} }
// SaveAccountData implements PUT /user/{userId}/[rooms/{roomId}/]account_data/{type} // SaveAccountData implements PUT /user/{userId}/[rooms/{roomId}/]account_data/{type}
func SaveAccountData( func SaveAccountData(
req *http.Request, accountDB accounts.Database, device *api.Device, req *http.Request, userAPI api.UserInternalAPI, device *api.Device,
userID string, roomID string, dataType string, syncProducer *producers.SyncAPIProducer, userID string, roomID string, dataType string, syncProducer *producers.SyncAPIProducer,
) util.JSONResponse { ) util.JSONResponse {
if userID != device.UserID { if userID != device.UserID {
@ -73,12 +70,6 @@ func SaveAccountData(
} }
} }
localpart, _, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError()
}
defer req.Body.Close() // nolint: errcheck defer req.Body.Close() // nolint: errcheck
if req.Body == http.NoBody { if req.Body == http.NoBody {
@ -101,16 +92,19 @@ func SaveAccountData(
} }
} }
if err := accountDB.SaveAccountData( dataReq := api.InputAccountDataRequest{
req.Context(), localpart, roomID, dataType, string(body), UserID: userID,
); err != nil { DataType: dataType,
util.GetLogger(req.Context()).WithError(err).Error("accountDB.SaveAccountData failed") RoomID: roomID,
return jsonerror.InternalServerError() AccountData: json.RawMessage(body),
} }
dataRes := api.InputAccountDataResponse{}
if err := syncProducer.SendData(userID, roomID, dataType); err != nil { if err := userAPI.InputAccountData(req.Context(), &dataReq, &dataRes); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("syncProducer.SendData failed") util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryAccountData failed")
return jsonerror.InternalServerError() return util.JSONResponse{
Code: http.StatusNotFound,
JSON: jsonerror.Forbidden("data not found"),
}
} }
return util.JSONResponse{ return util.JSONResponse{

View file

@ -69,7 +69,7 @@ func GetTags(
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
JSON: data.Content, JSON: data,
} }
} }
@ -106,7 +106,7 @@ func PutTag(
var tagContent gomatrix.TagContent var tagContent gomatrix.TagContent
if data != nil { if data != nil {
if err = json.Unmarshal(data.Content, &tagContent); err != nil { if err = json.Unmarshal(data, &tagContent); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("json.Unmarshal failed") util.GetLogger(req.Context()).WithError(err).Error("json.Unmarshal failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
@ -169,7 +169,7 @@ func DeleteTag(
} }
var tagContent gomatrix.TagContent var tagContent gomatrix.TagContent
err = json.Unmarshal(data.Content, &tagContent) err = json.Unmarshal(data, &tagContent)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("json.Unmarshal failed") util.GetLogger(req.Context()).WithError(err).Error("json.Unmarshal failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
@ -211,7 +211,7 @@ func obtainSavedTags(
userID string, userID string,
roomID string, roomID string,
accountDB accounts.Database, accountDB accounts.Database,
) (string, *gomatrixserverlib.ClientEvent, error) { ) (string, json.RawMessage, error) {
localpart, _, err := gomatrixserverlib.SplitID('@', userID) localpart, _, err := gomatrixserverlib.SplitID('@', userID)
if err != nil { if err != nil {
return "", nil, err return "", nil, err
@ -237,5 +237,5 @@ func saveTagData(
return err return err
} }
return accountDB.SaveAccountData(req.Context(), localpart, roomID, "m.tag", string(newTagData)) return accountDB.SaveAccountData(req.Context(), localpart, roomID, "m.tag", json.RawMessage(newTagData))
} }

View file

@ -476,7 +476,7 @@ func Setup(
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return SaveAccountData(req, accountDB, device, vars["userID"], "", vars["type"], syncProducer) return SaveAccountData(req, userAPI, device, vars["userID"], "", vars["type"], syncProducer)
}), }),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
@ -486,7 +486,7 @@ func Setup(
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return SaveAccountData(req, accountDB, device, vars["userID"], vars["roomID"], vars["type"], syncProducer) return SaveAccountData(req, userAPI, device, vars["userID"], vars["roomID"], vars["type"], syncProducer)
}), }),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
@ -496,7 +496,7 @@ func Setup(
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return GetAccountData(req, accountDB, device, vars["userID"], "", vars["type"]) return GetAccountData(req, userAPI, device, vars["userID"], "", vars["type"])
}), }),
).Methods(http.MethodGet) ).Methods(http.MethodGet)
@ -506,7 +506,7 @@ func Setup(
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return GetAccountData(req, accountDB, device, vars["userID"], vars["roomID"], vars["type"]) return GetAccountData(req, userAPI, device, vars["userID"], vars["roomID"], vars["type"])
}), }),
).Methods(http.MethodGet) ).Methods(http.MethodGet)

View file

@ -212,12 +212,25 @@ func (rp *RequestPool) appendAccountData(
if err != nil { if err != nil {
return nil, err return nil, err
} }
data.AccountData.Events = res.GlobalAccountData for datatype, databody := range res.GlobalAccountData {
data.AccountData.Events = append(
data.AccountData.Events,
gomatrixserverlib.ClientEvent{
Type: datatype,
Content: gomatrixserverlib.RawJSON(databody),
},
)
}
for r, j := range data.Rooms.Join { for r, j := range data.Rooms.Join {
if len(res.RoomAccountData[r]) > 0 { for datatype, databody := range res.RoomAccountData[r] {
j.AccountData.Events = res.RoomAccountData[r] j.AccountData.Events = append(
data.Rooms.Join[r] = j j.AccountData.Events,
gomatrixserverlib.ClientEvent{
Type: datatype,
Content: gomatrixserverlib.RawJSON(databody),
},
)
} }
} }
@ -249,7 +262,6 @@ func (rp *RequestPool) appendAccountData(
// Iterate over the rooms // Iterate over the rooms
for roomID, dataTypes := range dataTypes { for roomID, dataTypes := range dataTypes {
events := []gomatrixserverlib.ClientEvent{}
// Request the missing data from the database // Request the missing data from the database
for _, dataType := range dataTypes { for _, dataType := range dataTypes {
var res userapi.QueryAccountDataResponse var res userapi.QueryAccountDataResponse
@ -261,20 +273,28 @@ func (rp *RequestPool) appendAccountData(
if err != nil { if err != nil {
return nil, err return nil, err
} }
if len(res.RoomAccountData[roomID]) > 0 { for t, d := range res.GlobalAccountData {
events = append(events, res.RoomAccountData[roomID]...) data.AccountData.Events = append(
} else if len(res.GlobalAccountData) > 0 { data.AccountData.Events,
events = append(events, res.GlobalAccountData...) gomatrixserverlib.ClientEvent{
Type: t,
Content: gomatrixserverlib.RawJSON(d),
},
)
}
for r, byRoom := range res.RoomAccountData {
for t, d := range byRoom {
joinData := data.Rooms.Join[r]
joinData.AccountData.Events = append(
joinData.AccountData.Events,
gomatrixserverlib.ClientEvent{
Type: t,
Content: gomatrixserverlib.RawJSON(d),
},
)
data.Rooms.Join[r] = joinData
}
} }
}
// Append the data to the response
if len(roomID) > 0 {
jr := data.Rooms.Join[roomID]
jr.AccountData.Events = events
data.Rooms.Join[roomID] = jr
} else {
data.AccountData.Events = events
} }
} }

View file

@ -16,12 +16,14 @@ package api
import ( import (
"context" "context"
"encoding/json"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
// UserInternalAPI is the internal API for information about users and devices. // UserInternalAPI is the internal API for information about users and devices.
type UserInternalAPI interface { type UserInternalAPI interface {
InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error
PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error
PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error
QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error
@ -30,6 +32,18 @@ type UserInternalAPI interface {
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
} }
// InputAccountDataRequest is the request for InputAccountData
type InputAccountDataRequest struct {
UserID string // required: the user to set account data for
RoomID string // optional: the room to associate the account data with
DataType string // required: the data type of the data
AccountData json.RawMessage // required: the message content
}
// InputAccountDataResponse is the response for InputAccountData
type InputAccountDataResponse struct {
}
// QueryAccessTokenRequest is the request for QueryAccessToken // QueryAccessTokenRequest is the request for QueryAccessToken
type QueryAccessTokenRequest struct { type QueryAccessTokenRequest struct {
AccessToken string AccessToken string
@ -46,18 +60,15 @@ type QueryAccessTokenResponse struct {
// QueryAccountDataRequest is the request for QueryAccountData // QueryAccountDataRequest is the request for QueryAccountData
type QueryAccountDataRequest struct { type QueryAccountDataRequest struct {
UserID string // required: the user to get account data for. UserID string // required: the user to get account data for.
// TODO: This is a terribly confusing API shape :/ RoomID string // optional: the room ID, or global account data if not specified.
DataType string // optional: if specified returns only a single event matching this data type. DataType string // optional: the data type, or all types if not specified.
// optional: Only used if DataType is set. If blank returns global account data matching the data type.
// If set, returns only room account data matching this data type.
RoomID string
} }
// QueryAccountDataResponse is the response for QueryAccountData // QueryAccountDataResponse is the response for QueryAccountData
type QueryAccountDataResponse struct { type QueryAccountDataResponse struct {
GlobalAccountData []gomatrixserverlib.ClientEvent GlobalAccountData map[string]json.RawMessage // type -> data
RoomAccountData map[string][]gomatrixserverlib.ClientEvent RoomAccountData map[string]map[string]json.RawMessage // room -> type -> data
} }
// QueryDevicesRequest is the request for QueryDevices // QueryDevicesRequest is the request for QueryDevices

View file

@ -17,6 +17,7 @@ package internal
import ( import (
"context" "context"
"database/sql" "database/sql"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -38,6 +39,20 @@ type UserInternalAPI struct {
AppServices []config.ApplicationService AppServices []config.ApplicationService
} }
func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error {
local, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
if err != nil {
return err
}
if domain != a.ServerName {
return fmt.Errorf("cannot query profile of remote users: got %s want %s", domain, a.ServerName)
}
if req.DataType == "" {
return fmt.Errorf("data type must not be empty")
}
return a.AccountDB.SaveAccountData(ctx, local, req.RoomID, req.DataType, req.AccountData)
}
func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error { func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error {
if req.AccountType == api.AccountTypeGuest { if req.AccountType == api.AccountTypeGuest {
acc, err := a.AccountDB.CreateGuestAccount(ctx) acc, err := a.AccountDB.CreateGuestAccount(ctx)
@ -130,17 +145,19 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc
return fmt.Errorf("cannot query account data of remote users: got %s want %s", domain, a.ServerName) return fmt.Errorf("cannot query account data of remote users: got %s want %s", domain, a.ServerName)
} }
if req.DataType != "" { if req.DataType != "" {
var event *gomatrixserverlib.ClientEvent var data json.RawMessage
event, err = a.AccountDB.GetAccountDataByType(ctx, local, req.RoomID, req.DataType) data, err = a.AccountDB.GetAccountDataByType(ctx, local, req.RoomID, req.DataType)
if err != nil { if err != nil {
return err return err
} }
if event != nil { if data != nil {
if req.RoomID != "" { if req.RoomID != "" {
res.RoomAccountData = make(map[string][]gomatrixserverlib.ClientEvent) if _, ok := res.RoomAccountData[req.RoomID]; !ok {
res.RoomAccountData[req.RoomID] = []gomatrixserverlib.ClientEvent{*event} res.RoomAccountData[req.RoomID] = map[string]json.RawMessage{}
}
res.RoomAccountData[req.RoomID][req.DataType] = data
} else { } else {
res.GlobalAccountData = append(res.GlobalAccountData, *event) res.GlobalAccountData[req.DataType] = data
} }
} }
return nil return nil

View file

@ -26,6 +26,8 @@ import (
// HTTP paths for the internal HTTP APIs // HTTP paths for the internal HTTP APIs
const ( const (
InputAccountDataPath = "/userapi/inputAccountData"
PerformDeviceCreationPath = "/userapi/performDeviceCreation" PerformDeviceCreationPath = "/userapi/performDeviceCreation"
PerformAccountCreationPath = "/userapi/performAccountCreation" PerformAccountCreationPath = "/userapi/performAccountCreation"
@ -55,6 +57,14 @@ type httpUserInternalAPI struct {
httpClient *http.Client httpClient *http.Client
} }
func (h *httpUserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "InputAccountData")
defer span.Finish()
apiURL := h.apiURL + InputAccountDataPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}
func (h *httpUserInternalAPI) PerformAccountCreation( func (h *httpUserInternalAPI) PerformAccountCreation(
ctx context.Context, ctx context.Context,
request *api.PerformAccountCreationRequest, request *api.PerformAccountCreationRequest,

View file

@ -16,6 +16,7 @@ package accounts
import ( import (
"context" "context"
"encoding/json"
"errors" "errors"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
@ -39,13 +40,13 @@ type Database interface {
GetMembershipInRoomByLocalpart(ctx context.Context, localpart, roomID string) (authtypes.Membership, error) GetMembershipInRoomByLocalpart(ctx context.Context, localpart, roomID string) (authtypes.Membership, error)
GetRoomIDsByLocalPart(ctx context.Context, localpart string) ([]string, error) GetRoomIDsByLocalPart(ctx context.Context, localpart string) ([]string, error)
GetMembershipsByLocalpart(ctx context.Context, localpart string) (memberships []authtypes.Membership, err error) GetMembershipsByLocalpart(ctx context.Context, localpart string) (memberships []authtypes.Membership, err error)
SaveAccountData(ctx context.Context, localpart, roomID, dataType, content string) error SaveAccountData(ctx context.Context, localpart, roomID, dataType string, content json.RawMessage) error
GetAccountData(ctx context.Context, localpart string) (global []gomatrixserverlib.ClientEvent, rooms map[string][]gomatrixserverlib.ClientEvent, err error) GetAccountData(ctx context.Context, localpart string) (global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error)
// GetAccountDataByType returns account data matching a given // GetAccountDataByType returns account data matching a given
// localpart, room ID and type. // localpart, room ID and type.
// If no account data could be found, returns nil // If no account data could be found, returns nil
// Returns an error if there was an issue with the retrieval // Returns an error if there was an issue with the retrieval
GetAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data *gomatrixserverlib.ClientEvent, err error) GetAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data json.RawMessage, err error)
GetNewNumericLocalpart(ctx context.Context) (int64, error) GetNewNumericLocalpart(ctx context.Context) (int64, error)
SaveThreePIDAssociation(ctx context.Context, threepid, localpart, medium string) (err error) SaveThreePIDAssociation(ctx context.Context, threepid, localpart, medium string) (err error)
RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error) RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error)

View file

@ -17,9 +17,9 @@ package postgres
import ( import (
"context" "context"
"database/sql" "database/sql"
"encoding/json"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/gomatrixserverlib"
) )
const accountDataSchema = ` const accountDataSchema = `
@ -73,7 +73,7 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
} }
func (s *accountDataStatements) insertAccountData( func (s *accountDataStatements) insertAccountData(
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType, content string, ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage,
) (err error) { ) (err error) {
stmt := txn.Stmt(s.insertAccountDataStmt) stmt := txn.Stmt(s.insertAccountDataStmt)
_, err = stmt.ExecContext(ctx, localpart, roomID, dataType, content) _, err = stmt.ExecContext(ctx, localpart, roomID, dataType, content)
@ -83,8 +83,8 @@ func (s *accountDataStatements) insertAccountData(
func (s *accountDataStatements) selectAccountData( func (s *accountDataStatements) selectAccountData(
ctx context.Context, localpart string, ctx context.Context, localpart string,
) ( ) (
global []gomatrixserverlib.ClientEvent, global map[string]json.RawMessage,
rooms map[string][]gomatrixserverlib.ClientEvent, rooms map[string]map[string]json.RawMessage,
err error, err error,
) { ) {
rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart) rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart)
@ -93,50 +93,40 @@ func (s *accountDataStatements) selectAccountData(
} }
defer internal.CloseAndLogIfError(ctx, rows, "selectAccountData: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "selectAccountData: rows.close() failed")
global = []gomatrixserverlib.ClientEvent{} global = map[string]json.RawMessage{}
rooms = make(map[string][]gomatrixserverlib.ClientEvent) rooms = map[string]map[string]json.RawMessage{}
for rows.Next() { for rows.Next() {
var roomID string var roomID string
var dataType string var dataType string
var content []byte var content json.RawMessage
if err = rows.Scan(&roomID, &dataType, &content); err != nil { if err = rows.Scan(&roomID, &dataType, &content); err != nil {
return return
} }
ac := gomatrixserverlib.ClientEvent{ if roomID != "" {
Type: dataType, if _, ok := rooms[roomID]; !ok {
Content: content, rooms[roomID] = map[string]json.RawMessage{}
} }
rooms[roomID][dataType] = content
if len(roomID) > 0 {
rooms[roomID] = append(rooms[roomID], ac)
} else { } else {
global = append(global, ac) global[dataType] = content
} }
} }
return global, rooms, rows.Err() return global, rooms, rows.Err()
} }
func (s *accountDataStatements) selectAccountDataByType( func (s *accountDataStatements) selectAccountDataByType(
ctx context.Context, localpart, roomID, dataType string, ctx context.Context, localpart, roomID, dataType string,
) (data *gomatrixserverlib.ClientEvent, err error) { ) (data json.RawMessage, err error) {
stmt := s.selectAccountDataByTypeStmt stmt := s.selectAccountDataByTypeStmt
var content []byte if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&data); err != nil {
if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&content); err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, nil return nil, nil
} }
return return
} }
data = &gomatrixserverlib.ClientEvent{
Type: dataType,
Content: content,
}
return return
} }

View file

@ -17,6 +17,7 @@ package postgres
import ( import (
"context" "context"
"database/sql" "database/sql"
"encoding/json"
"errors" "errors"
"strconv" "strconv"
@ -169,7 +170,7 @@ func (d *Database) createAccount(
return nil, err return nil, err
} }
if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", `{ if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{
"global": { "global": {
"content": [], "content": [],
"override": [], "override": [],
@ -177,7 +178,7 @@ func (d *Database) createAccount(
"sender": [], "sender": [],
"underride": [] "underride": []
} }
}`); err != nil { }`)); err != nil {
return nil, err return nil, err
} }
return d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID) return d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID)
@ -295,7 +296,7 @@ func (d *Database) newMembership(
// update the corresponding row with the new content // update the corresponding row with the new content
// Returns a SQL error if there was an issue with the insertion/update // Returns a SQL error if there was an issue with the insertion/update
func (d *Database) SaveAccountData( func (d *Database) SaveAccountData(
ctx context.Context, localpart, roomID, dataType, content string, ctx context.Context, localpart, roomID, dataType string, content json.RawMessage,
) error { ) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content) return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content)
@ -306,8 +307,8 @@ func (d *Database) SaveAccountData(
// If no account data could be found, returns an empty arrays // If no account data could be found, returns an empty arrays
// Returns an error if there was an issue with the retrieval // Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountData(ctx context.Context, localpart string) ( func (d *Database) GetAccountData(ctx context.Context, localpart string) (
global []gomatrixserverlib.ClientEvent, global map[string]json.RawMessage,
rooms map[string][]gomatrixserverlib.ClientEvent, rooms map[string]map[string]json.RawMessage,
err error, err error,
) { ) {
return d.accountDatas.selectAccountData(ctx, localpart) return d.accountDatas.selectAccountData(ctx, localpart)
@ -319,7 +320,7 @@ func (d *Database) GetAccountData(ctx context.Context, localpart string) (
// Returns an error if there was an issue with the retrieval // Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountDataByType( func (d *Database) GetAccountDataByType(
ctx context.Context, localpart, roomID, dataType string, ctx context.Context, localpart, roomID, dataType string,
) (data *gomatrixserverlib.ClientEvent, err error) { ) (data json.RawMessage, err error) {
return d.accountDatas.selectAccountDataByType( return d.accountDatas.selectAccountDataByType(
ctx, localpart, roomID, dataType, ctx, localpart, roomID, dataType,
) )

View file

@ -17,8 +17,7 @@ package sqlite3
import ( import (
"context" "context"
"database/sql" "database/sql"
"encoding/json"
"github.com/matrix-org/gomatrixserverlib"
) )
const accountDataSchema = ` const accountDataSchema = `
@ -72,7 +71,7 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
} }
func (s *accountDataStatements) insertAccountData( func (s *accountDataStatements) insertAccountData(
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType, content string, ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage,
) (err error) { ) (err error) {
_, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content) _, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content)
return return
@ -81,8 +80,8 @@ func (s *accountDataStatements) insertAccountData(
func (s *accountDataStatements) selectAccountData( func (s *accountDataStatements) selectAccountData(
ctx context.Context, localpart string, ctx context.Context, localpart string,
) ( ) (
global []gomatrixserverlib.ClientEvent, global map[string]json.RawMessage,
rooms map[string][]gomatrixserverlib.ClientEvent, rooms map[string]map[string]json.RawMessage,
err error, err error,
) { ) {
rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart) rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart)
@ -90,8 +89,8 @@ func (s *accountDataStatements) selectAccountData(
return return
} }
global = []gomatrixserverlib.ClientEvent{} global = map[string]json.RawMessage{}
rooms = make(map[string][]gomatrixserverlib.ClientEvent) rooms = map[string]map[string]json.RawMessage{}
for rows.Next() { for rows.Next() {
var roomID string var roomID string
@ -102,15 +101,13 @@ func (s *accountDataStatements) selectAccountData(
return return
} }
ac := gomatrixserverlib.ClientEvent{ if roomID != "" {
Type: dataType, if _, ok := rooms[roomID]; !ok {
Content: content, rooms[roomID] = map[string]json.RawMessage{}
} }
rooms[roomID][dataType] = content
if len(roomID) > 0 {
rooms[roomID] = append(rooms[roomID], ac)
} else { } else {
global = append(global, ac) global[dataType] = content
} }
} }
@ -119,22 +116,13 @@ func (s *accountDataStatements) selectAccountData(
func (s *accountDataStatements) selectAccountDataByType( func (s *accountDataStatements) selectAccountDataByType(
ctx context.Context, localpart, roomID, dataType string, ctx context.Context, localpart, roomID, dataType string,
) (data *gomatrixserverlib.ClientEvent, err error) { ) (data json.RawMessage, err error) {
stmt := s.selectAccountDataByTypeStmt stmt := s.selectAccountDataByTypeStmt
var content []byte if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&data); err != nil {
if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&content); err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, nil return nil, nil
} }
return return
} }
data = &gomatrixserverlib.ClientEvent{
Type: dataType,
Content: content,
}
return return
} }

View file

@ -17,6 +17,7 @@ package sqlite3
import ( import (
"context" "context"
"database/sql" "database/sql"
"encoding/json"
"errors" "errors"
"strconv" "strconv"
"sync" "sync"
@ -180,7 +181,7 @@ func (d *Database) createAccount(
return nil, err return nil, err
} }
if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", `{ if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{
"global": { "global": {
"content": [], "content": [],
"override": [], "override": [],
@ -188,7 +189,7 @@ func (d *Database) createAccount(
"sender": [], "sender": [],
"underride": [] "underride": []
} }
}`); err != nil { }`)); err != nil {
return nil, err return nil, err
} }
return d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID) return d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID)
@ -306,7 +307,7 @@ func (d *Database) newMembership(
// update the corresponding row with the new content // update the corresponding row with the new content
// Returns a SQL error if there was an issue with the insertion/update // Returns a SQL error if there was an issue with the insertion/update
func (d *Database) SaveAccountData( func (d *Database) SaveAccountData(
ctx context.Context, localpart, roomID, dataType, content string, ctx context.Context, localpart, roomID, dataType string, content json.RawMessage,
) error { ) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content) return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content)
@ -317,8 +318,8 @@ func (d *Database) SaveAccountData(
// If no account data could be found, returns an empty arrays // If no account data could be found, returns an empty arrays
// Returns an error if there was an issue with the retrieval // Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountData(ctx context.Context, localpart string) ( func (d *Database) GetAccountData(ctx context.Context, localpart string) (
global []gomatrixserverlib.ClientEvent, global map[string]json.RawMessage,
rooms map[string][]gomatrixserverlib.ClientEvent, rooms map[string]map[string]json.RawMessage,
err error, err error,
) { ) {
return d.accountDatas.selectAccountData(ctx, localpart) return d.accountDatas.selectAccountData(ctx, localpart)
@ -330,7 +331,7 @@ func (d *Database) GetAccountData(ctx context.Context, localpart string) (
// Returns an error if there was an issue with the retrieval // Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountDataByType( func (d *Database) GetAccountDataByType(
ctx context.Context, localpart, roomID, dataType string, ctx context.Context, localpart, roomID, dataType string,
) (data *gomatrixserverlib.ClientEvent, err error) { ) (data json.RawMessage, err error) {
return d.accountDatas.selectAccountDataByType( return d.accountDatas.selectAccountDataByType(
ctx, localpart, roomID, dataType, ctx, localpart, roomID, dataType,
) )