Save and retrieve account data (#166)
* Save function for account data * Fix upsert + add empty routes and function * Save account data * Retrieval functions * Implement retrieval in /sync * Fix arrays not correctly initialised * Merge account data retrieval functions * Request DB only once per request * Initialise array * Fix comment
This commit is contained in:
parent
6d073dcf9f
commit
3e394e9e21
|
@ -0,0 +1,109 @@
|
||||||
|
// Copyright 2017 Vector Creations Ltd
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package accounts
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
)
|
||||||
|
|
||||||
|
const accountDataSchema = `
|
||||||
|
-- Stores data about accounts data.
|
||||||
|
CREATE TABLE IF NOT EXISTS account_data (
|
||||||
|
-- The Matrix user ID localpart for this account
|
||||||
|
localpart TEXT NOT NULL,
|
||||||
|
-- The room ID for this data (empty string if not specific to a room)
|
||||||
|
room_id TEXT,
|
||||||
|
-- The account data type
|
||||||
|
type TEXT NOT NULL,
|
||||||
|
-- The account data content
|
||||||
|
content TEXT NOT NULL,
|
||||||
|
|
||||||
|
PRIMARY KEY(localpart, room_id, type)
|
||||||
|
);
|
||||||
|
`
|
||||||
|
|
||||||
|
const insertAccountDataSQL = `
|
||||||
|
INSERT INTO account_data(localpart, room_id, type, content) VALUES($1, $2, $3, $4)
|
||||||
|
ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = EXCLUDED.content
|
||||||
|
`
|
||||||
|
|
||||||
|
const selectAccountDataSQL = "" +
|
||||||
|
"SELECT room_id, type, content FROM account_data WHERE localpart = $1"
|
||||||
|
|
||||||
|
const deleteAccountDataSQL = "" +
|
||||||
|
"DELETE FROM account_data WHERE localpart = $1 AND room_id = $2 AND type = $3"
|
||||||
|
|
||||||
|
type accountDataStatements struct {
|
||||||
|
insertAccountDataStmt *sql.Stmt
|
||||||
|
selectAccountDataStmt *sql.Stmt
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
|
||||||
|
_, err = db.Exec(accountDataSchema)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.insertAccountDataStmt, err = db.Prepare(insertAccountDataSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.selectAccountDataStmt, err = db.Prepare(selectAccountDataSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *accountDataStatements) insertAccountData(localpart string, roomID string, dataType string, content string) (err error) {
|
||||||
|
_, err = s.insertAccountDataStmt.Exec(localpart, roomID, dataType, content)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *accountDataStatements) selectAccountData(localpart string) (
|
||||||
|
global []gomatrixserverlib.ClientEvent,
|
||||||
|
rooms map[string][]gomatrixserverlib.ClientEvent,
|
||||||
|
err error,
|
||||||
|
) {
|
||||||
|
rows, err := s.selectAccountDataStmt.Query(localpart)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
global = []gomatrixserverlib.ClientEvent{}
|
||||||
|
rooms = make(map[string][]gomatrixserverlib.ClientEvent)
|
||||||
|
|
||||||
|
for rows.Next() {
|
||||||
|
var roomID string
|
||||||
|
var dataType string
|
||||||
|
var content []byte
|
||||||
|
|
||||||
|
if err = rows.Scan(&roomID, &dataType, &content); err != nil && err != sql.ErrNoRows {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ac := gomatrixserverlib.ClientEvent{
|
||||||
|
Type: dataType,
|
||||||
|
Content: content,
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(roomID) > 0 {
|
||||||
|
rooms[roomID] = append(rooms[roomID], ac)
|
||||||
|
} else {
|
||||||
|
global = append(global, ac)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
|
@ -32,6 +32,7 @@ type Database struct {
|
||||||
accounts accountsStatements
|
accounts accountsStatements
|
||||||
profiles profilesStatements
|
profiles profilesStatements
|
||||||
memberships membershipStatements
|
memberships membershipStatements
|
||||||
|
accountDatas accountDataStatements
|
||||||
serverName gomatrixserverlib.ServerName
|
serverName gomatrixserverlib.ServerName
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -58,7 +59,11 @@ func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName)
|
||||||
if err = m.prepare(db); err != nil {
|
if err = m.prepare(db); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &Database{db, partitions, a, p, m, serverName}, nil
|
ac := accountDataStatements{}
|
||||||
|
if err = ac.prepare(db); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &Database{db, partitions, a, p, m, ac, serverName}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountByPassword returns the account associated with the given localpart and password.
|
// GetAccountByPassword returns the account associated with the given localpart and password.
|
||||||
|
@ -199,6 +204,26 @@ func (d *Database) newMembership(ev gomatrixserverlib.Event, txn *sql.Tx) error
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SaveAccountData saves new account data for a given user and a given room.
|
||||||
|
// If the account data is not specific to a room, the room ID should be an empty string
|
||||||
|
// If an account data already exists for a given set (user, room, data type), it will
|
||||||
|
// update the corresponding row with the new content
|
||||||
|
// Returns a SQL error if there was an issue with the insertion/update
|
||||||
|
func (d *Database) SaveAccountData(localpart string, roomID string, dataType string, content string) error {
|
||||||
|
return d.accountDatas.insertAccountData(localpart, roomID, dataType, content)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccountData returns account data related to a given localpart
|
||||||
|
// If no account data could be found, returns an empty arrays
|
||||||
|
// Returns an error if there was an issue with the retrieval
|
||||||
|
func (d *Database) GetAccountData(localpart string) (
|
||||||
|
global []gomatrixserverlib.ClientEvent,
|
||||||
|
rooms map[string][]gomatrixserverlib.ClientEvent,
|
||||||
|
err error,
|
||||||
|
) {
|
||||||
|
return d.accountDatas.selectAccountData(localpart)
|
||||||
|
}
|
||||||
|
|
||||||
func hashPassword(plaintext string) (hash string, err error) {
|
func hashPassword(plaintext string) (hash string, err error) {
|
||||||
hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), bcrypt.DefaultCost)
|
hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), bcrypt.DefaultCost)
|
||||||
return string(hashBytes), err
|
return string(hashBytes), err
|
||||||
|
|
|
@ -0,0 +1,69 @@
|
||||||
|
// Copyright 2017 Vector Creations Ltd
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package readers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/auth/storage/accounts"
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/httputil"
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
|
||||||
|
"github.com/matrix-org/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SaveAccountData implements PUT /user/{userId}/[rooms/{roomId}/]account_data/{type}
|
||||||
|
func SaveAccountData(
|
||||||
|
req *http.Request, accountDB *accounts.Database, device *authtypes.Device,
|
||||||
|
userID string, roomID string, dataType string,
|
||||||
|
) util.JSONResponse {
|
||||||
|
if req.Method != "PUT" {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: 405,
|
||||||
|
JSON: jsonerror.NotFound("Bad method"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if userID != device.UserID {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: 403,
|
||||||
|
JSON: jsonerror.Forbidden("userID does not match the current user"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
localpart, _, err := gomatrixserverlib.SplitID('@', userID)
|
||||||
|
if err != nil {
|
||||||
|
return httputil.LogThenError(req, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
defer req.Body.Close()
|
||||||
|
|
||||||
|
body, err := ioutil.ReadAll(req.Body)
|
||||||
|
if err != nil {
|
||||||
|
return httputil.LogThenError(req, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := accountDB.SaveAccountData(localpart, roomID, dataType, string(body)); err != nil {
|
||||||
|
return httputil.LogThenError(req, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: 200,
|
||||||
|
JSON: struct{}{},
|
||||||
|
}
|
||||||
|
}
|
|
@ -274,9 +274,16 @@ func Setup(
|
||||||
)
|
)
|
||||||
|
|
||||||
r0mux.Handle("/user/{userID}/account_data/{type}",
|
r0mux.Handle("/user/{userID}/account_data/{type}",
|
||||||
common.MakeAPI("user_account_data", func(req *http.Request) util.JSONResponse {
|
common.MakeAuthAPI("user_account_data", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
|
||||||
// TODO: Set and get the account_data
|
vars := mux.Vars(req)
|
||||||
return util.JSONResponse{Code: 200, JSON: struct{}{}}
|
return readers.SaveAccountData(req, accountDB, device, vars["userID"], "", vars["type"])
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
|
||||||
|
r0mux.Handle("/user/{userID}/rooms/{roomID}/account_data/{type}",
|
||||||
|
common.MakeAuthAPI("user_account_data", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
|
||||||
|
vars := mux.Vars(req)
|
||||||
|
return readers.SaveAccountData(req, accountDB, device, vars["userID"], vars["roomID"], vars["type"])
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -19,6 +19,7 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/auth/storage/accounts"
|
||||||
"github.com/matrix-org/dendrite/clientapi/auth/storage/devices"
|
"github.com/matrix-org/dendrite/clientapi/auth/storage/devices"
|
||||||
"github.com/matrix-org/dendrite/common"
|
"github.com/matrix-org/dendrite/common"
|
||||||
"github.com/matrix-org/dendrite/common/config"
|
"github.com/matrix-org/dendrite/common/config"
|
||||||
|
@ -58,6 +59,11 @@ func main() {
|
||||||
log.Panicf("startup: failed to create device database with data source %s : %s", cfg.Database.Device, err)
|
log.Panicf("startup: failed to create device database with data source %s : %s", cfg.Database.Device, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
adb, err := accounts.NewDatabase(string(cfg.Database.Account), cfg.Matrix.ServerName)
|
||||||
|
if err != nil {
|
||||||
|
log.Panicf("startup: failed to create account database with data source %s : %s", cfg.Database.Account, err)
|
||||||
|
}
|
||||||
|
|
||||||
pos, err := db.SyncStreamPosition()
|
pos, err := db.SyncStreamPosition()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Panicf("startup: failed to get latest sync stream position : %s", err)
|
log.Panicf("startup: failed to get latest sync stream position : %s", err)
|
||||||
|
@ -76,6 +82,6 @@ func main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Info("Starting sync server on ", cfg.Listen.SyncAPI)
|
log.Info("Starting sync server on ", cfg.Listen.SyncAPI)
|
||||||
routing.SetupSyncServerListeners(http.DefaultServeMux, http.DefaultClient, sync.NewRequestPool(db, n), deviceDB)
|
routing.SetupSyncServerListeners(http.DefaultServeMux, http.DefaultClient, sync.NewRequestPool(db, n, adb), deviceDB)
|
||||||
log.Fatal(http.ListenAndServe(string(cfg.Listen.SyncAPI), nil))
|
log.Fatal(http.ListenAndServe(string(cfg.Listen.SyncAPI), nil))
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,22 +20,25 @@ import (
|
||||||
|
|
||||||
log "github.com/Sirupsen/logrus"
|
log "github.com/Sirupsen/logrus"
|
||||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/auth/storage/accounts"
|
||||||
"github.com/matrix-org/dendrite/clientapi/httputil"
|
"github.com/matrix-org/dendrite/clientapi/httputil"
|
||||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||||
"github.com/matrix-org/dendrite/syncapi/storage"
|
"github.com/matrix-org/dendrite/syncapi/storage"
|
||||||
"github.com/matrix-org/dendrite/syncapi/types"
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
// RequestPool manages HTTP long-poll connections for /sync
|
// RequestPool manages HTTP long-poll connections for /sync
|
||||||
type RequestPool struct {
|
type RequestPool struct {
|
||||||
db *storage.SyncServerDatabase
|
db *storage.SyncServerDatabase
|
||||||
|
accountDB *accounts.Database
|
||||||
notifier *Notifier
|
notifier *Notifier
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewRequestPool makes a new RequestPool
|
// NewRequestPool makes a new RequestPool
|
||||||
func NewRequestPool(db *storage.SyncServerDatabase, n *Notifier) *RequestPool {
|
func NewRequestPool(db *storage.SyncServerDatabase, n *Notifier, adb *accounts.Database) *RequestPool {
|
||||||
return &RequestPool{db, n}
|
return &RequestPool{db, adb, n}
|
||||||
}
|
}
|
||||||
|
|
||||||
// OnIncomingSyncRequest is called when a client makes a /sync request. This function MUST be
|
// OnIncomingSyncRequest is called when a client makes a /sync request. This function MUST be
|
||||||
|
@ -74,6 +77,10 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype
|
||||||
timer.Stop()
|
timer.Stop()
|
||||||
syncData, err := rp.currentSyncForUser(*syncReq, currentPos)
|
syncData, err := rp.currentSyncForUser(*syncReq, currentPos)
|
||||||
var res util.JSONResponse
|
var res util.JSONResponse
|
||||||
|
if err != nil {
|
||||||
|
res = httputil.LogThenError(req, err)
|
||||||
|
} else {
|
||||||
|
syncData, err = rp.appendAccountData(syncData, device.UserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
res = httputil.LogThenError(req, err)
|
res = httputil.LogThenError(req, err)
|
||||||
} else {
|
} else {
|
||||||
|
@ -82,6 +89,7 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype
|
||||||
JSON: syncData,
|
JSON: syncData,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
done <- res
|
done <- res
|
||||||
close(done)
|
close(done)
|
||||||
}()
|
}()
|
||||||
|
@ -104,3 +112,27 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, currentPos types.Stre
|
||||||
}
|
}
|
||||||
return rp.db.IncrementalSync(req.userID, req.since, currentPos, req.limit)
|
return rp.db.IncrementalSync(req.userID, req.since, currentPos, req.limit)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (rp *RequestPool) appendAccountData(data *types.Response, userID string) (*types.Response, error) {
|
||||||
|
// TODO: We currently send all account data on every sync response, we should instead send data
|
||||||
|
// that has changed on incremental sync responses
|
||||||
|
localpart, _, err := gomatrixserverlib.SplitID('@', userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
global, rooms, err := rp.accountDB.GetAccountData(localpart)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
data.AccountData.Events = global
|
||||||
|
|
||||||
|
for r, j := range data.Rooms.Join {
|
||||||
|
if len(rooms[r]) > 0 {
|
||||||
|
j.AccountData.Events = rooms[r]
|
||||||
|
data.Rooms.Join[r] = j
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue