send to device (STD) done, cooperate with key centre encryptoapi

This commit is contained in:
terrill 2018-07-18 10:53:27 +08:00
parent 7770664878
commit 57e6eb73dc
9 changed files with 433 additions and 17 deletions

View file

@ -60,7 +60,9 @@ const selectDeviceByIDSQL = "" +
"SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
const selectDevicesByLocalpartSQL = "" +
"SELECT device_id, display_name FROM device_devices WHERE localpart = $1"
// todo : display name still has a problem when value is null
//"SELECT device_id, display_name FROM device_devices WHERE localpart = $1"
"SELECT device_id FROM device_devices WHERE localpart = $1"
const updateDeviceNameSQL = "" +
"UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
@ -197,6 +199,8 @@ func (s *devicesStatements) selectDevicesByLocalpart(
for rows.Next() {
var dev authtypes.Device
err = rows.Scan(&dev.ID)
// todo: display name still has a problem when value is null
//err = rows.Scan(&dev.ID, &dev.DisplayName)
if err != nil {
return devices, err
}

View file

@ -167,20 +167,37 @@ func (s *keyStatements) selectInKeys(
ctx context.Context,
userID string,
arr []string,
) (holders []types.KeyHolder, err error) {
rows := sql.Rows{}
rowsP := &rows
) ([]types.KeyHolder, error) {
holders := []types.KeyHolder{}
stmt := s.selectAllKeyStmt
if len(arr) == 0 {
rowsP, err = stmt.QueryContext(ctx, userID, "device_key")
} else {
stmt = s.selectInKeysStmt
list := pq.Array(arr)
rowsP, err = stmt.QueryContext(ctx, userID, list)
// mapping for all device keys
rowsP, err := stmt.QueryContext(ctx, userID, "device_key")
if err != nil {
return nil, err
}
holders, err = injectKeyHolder(rowsP, holders)
if err != nil {
return nil, err
}
err = rowsP.Close()
return holders, err
}
stmt = s.selectInKeysStmt
list := pq.Array(arr)
rowsP, err := stmt.QueryContext(ctx, userID, list)
if err != nil {
return nil, err
}
holders, err = injectKeyHolder(rowsP, holders)
if err != nil {
return nil, err
}
err = rowsP.Close()
return holders, err
}
func injectKeyHolder(rows *sql.Rows, keyHolder []types.KeyHolder) (holders []types.KeyHolder, err error) {
for rows.Next() {
single := &types.KeyHolder{}
if err = rows.Scan(
@ -194,8 +211,8 @@ func (s *keyStatements) selectInKeys(
); err != nil {
return nil, err
}
holders = append(holders, *single)
keyHolder = append(keyHolder, *single)
}
err = rowsP.Close()
return holders, err
holders = keyHolder
return
}

View file

@ -27,10 +27,12 @@ import (
)
const pathPrefixR0 = "/_matrix/client/r0"
const pathPrefixUnstable = "/_matrix/client/unstable"
// Setup configures the given mux with sync-server listeners
func Setup(apiMux *mux.Router, srp *sync.RequestPool, syncDB *storage.SyncServerDatabase, deviceDB *devices.Database) {
func Setup(apiMux *mux.Router, srp *sync.RequestPool, syncDB *storage.SyncServerDatabase, deviceDB *devices.Database, notifier *sync.Notifier) {
r0mux := apiMux.PathPrefix(pathPrefixR0).Subrouter()
unstablemux := apiMux.PathPrefix(pathPrefixUnstable).Subrouter()
r0mux.Handle("/sync", common.MakeAuthAPI("sync", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
return srp.OnIncomingSyncRequest(req, device)
@ -50,4 +52,13 @@ func Setup(apiMux *mux.Router, srp *sync.RequestPool, syncDB *storage.SyncServer
vars := mux.Vars(req)
return OnIncomingStateTypeRequest(req, syncDB, vars["roomID"], vars["type"], vars["stateKey"])
})).Methods(http.MethodGet, http.MethodOptions)
unstablemux.Handle("/sendToDevice/{eventType}/{txnId}",
common.MakeAuthAPI("look up changes", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars := mux.Vars(req)
eventType := vars["eventType"]
txnID := vars["txnId"]
return SendToDevice(req, device.UserID, syncDB, deviceDB, eventType, txnID, notifier)
}),
).Methods(http.MethodPut, http.MethodOptions)
}

View file

@ -0,0 +1,89 @@
package routing
import (
"encoding/json"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/auth/storage/devices"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/sync"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"net/http"
)
// SendToDevice this is a function for calling process of send-to-device messages those bypassed DAG
func SendToDevice(
req *http.Request,
sender string,
syncDB *storage.SyncServerDatabase,
deviceDB *devices.Database,
eventType, txnID string,
notifier *sync.Notifier,
) util.JSONResponse {
ctx := req.Context()
stdRq := types.StdRequest{}
httputil.UnmarshalJSONRequest(req, &stdRq)
for uid, deviceMap := range stdRq.Sender {
// federation consideration todo:
// if uid is remote domain a fed process should go
if false {
// federation process
return util.JSONResponse{}
}
// uid is local domain
for device, cont := range deviceMap {
jsonBuffer, err := json.Marshal(cont)
if err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: struct{}{},
}
}
ev := types.StdHolder{
Sender: sender,
Event: jsonBuffer,
EventTyp: eventType,
}
var pos int64
// wildcard all devices
if device == "*" {
var deviceCollection []authtypes.Device
var localpart string
localpart, _, _ = gomatrixserverlib.SplitID('@', uid)
deviceCollection, err = deviceDB.GetDevicesByLocalpart(ctx, localpart)
for _, val := range deviceCollection {
pos, err = syncDB.InsertStdMessage(ctx, ev, txnID, uid, val.ID)
notifier.OnNewEvent(nil, uid, types.StreamPosition(pos))
}
if err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: struct{}{},
}
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: struct{}{},
}
}
pos, err = syncDB.InsertStdMessage(ctx, ev, txnID, uid, device)
if err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: struct{}{},
}
}
notifier.OnNewEvent(nil, uid, types.StreamPosition(pos))
}
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: struct{}{},
}
}

View file

@ -0,0 +1,162 @@
package storage
import (
"context"
"database/sql"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/syncapi/types"
)
// we treat send to device as abbrev as STD in the context below.
const sendToDeviceSchema = `
CREATE TABLE IF NOT EXISTS syncapi_send_to_device (
id BIGINT PRIMARY KEY DEFAULT nextval('syncapi_stream_id'),
txn_id TEXT NOT NULL,
sender TEXT NOT NULL,
event_type TEXT NOT NULL,
target_device_id TEXT NOT NULL,
target_user_id TEXT NOT NULL,
event_json TEXT NOT NULL,
del_read INTEGER DEFAULT 0,
max_read BIGINT DEFAULT currval('syncapi_stream_id') ,
CONSTRAINT syncapi_send_to_device_unique UNIQUE (txn_id, target_device_id, target_user_id)
);
`
const insertSTDSQL = "" +
"INSERT INTO syncapi_send_to_device (" +
" sender, event_type, target_user_id, target_device_id, txn_id, event_json" +
") VALUES ($1, $2, $3, $4, $5, $6) RETURNING id"
const deleteSTDSQL = "" +
"DELETE FROM syncapi_send_to_device WHERE target_user_id = $1 AND target_device_id = $2 AND max_read < $3 AND del_read = 1"
const selectSTDEventsInRangeSQL = "" +
"SELECT id, sender, event_type, event_json FROM syncapi_send_to_device" +
" WHERE target_user_id = $1 AND target_device_id = $2 AND id <= $3" +
" ORDER BY id LIMIT 100 "
const updateSTDEventSQL = "" +
"UPDATE syncapi_send_to_device SET del_read = 1 , max_read = $1 WHERE id = ANY($2)"
const selectMaxSTDIDSQL = "" +
"SELECT MAX(id) FROM syncapi_send_to_device"
type stdEventsStatements struct {
insertStdEventStmt *sql.Stmt
selectStdEventsInRangeStmt *sql.Stmt
deleteStdEventStmt *sql.Stmt
selectStdIDStmt *sql.Stmt
updateStdStmt *sql.Stmt
}
func (s *stdEventsStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(sendToDeviceSchema)
if err != nil {
return
}
if s.insertStdEventStmt, err = db.Prepare(insertSTDSQL); err != nil {
return
}
if s.selectStdEventsInRangeStmt, err = db.Prepare(selectSTDEventsInRangeSQL); err != nil {
return
}
if s.deleteStdEventStmt, err = db.Prepare(deleteSTDSQL); err != nil {
return
}
if s.selectStdIDStmt, err = db.Prepare(selectMaxSTDIDSQL); err != nil {
return
}
if s.updateStdStmt, err = db.Prepare(updateSTDEventSQL); err != nil {
return
}
return
}
func (s *stdEventsStatements) insertStdEvent(
ctx context.Context, stdEvent types.StdHolder,
transactionID string, targetUID, targetDevice string,
) (streamPos int64, err error) {
err = s.insertStdEventStmt.QueryRowContext(
ctx,
stdEvent.Sender,
stdEvent.EventTyp,
targetUID,
targetDevice,
transactionID,
stdEvent.Event,
).Scan(&streamPos)
return
}
func (s *stdEventsStatements) deleteStdEvent(
ctx context.Context, userID, deviceID string,
idUpBound int64,
) error {
_, err := s.deleteStdEventStmt.ExecContext(ctx, userID, deviceID, idUpBound)
return err
}
func (s *stdEventsStatements) selectStdEventsInRange(
ctx context.Context, txn *sql.Tx,
targetUserID, targetDeviceID string,
endPos int64,
) ([]types.StdHolder, error) {
stdHolder := []types.StdHolder{}
stmt := common.TxStmt(txn, s.selectStdEventsInRangeStmt)
rows, err := stmt.QueryContext(ctx, targetUserID, targetDeviceID, endPos)
if err != nil {
return nil, err
}
for rows.Next() {
holder := types.StdHolder{}
var (
id int64
sender string
eventType string
eventJSON []byte
)
if err = rows.Scan(&id, &sender, &eventType, &eventJSON); err != nil {
closeErr := rows.Close()
if closeErr != nil {
return nil, closeErr
}
return nil, err
}
holder.StreamID = id
holder.Sender = sender
holder.Event = eventJSON
holder.EventTyp = eventType
stdHolder = append(stdHolder, holder)
}
err = rows.Close()
if err != nil {
return nil, err
}
// update events with read mark
update := []int64{}
for _, val := range stdHolder {
update = append(update, val.StreamID)
}
updateStmt := common.TxStmt(txn, s.updateStdStmt)
_, err = updateStmt.ExecContext(ctx, endPos, pq.Array(update))
if err != nil {
return nil, err
}
return stdHolder, nil
}
func (s *stdEventsStatements) selectMaxStdID(
ctx context.Context, txn *sql.Tx,
) (id int64, err error) {
var nullableID sql.NullInt64
stmt := common.TxStmt(txn, s.selectStdIDStmt)
err = stmt.QueryRowContext(ctx).Scan(&nullableID)
if nullableID.Valid {
id = nullableID.Int64
}
return
}

View file

@ -23,8 +23,8 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/roomserver/api"
// Import the postgres database driver.
_ "github.com/lib/pq"
"encoding/json"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
@ -54,6 +54,7 @@ type SyncServerDatabase struct {
events outputRoomEventsStatements
roomstate currentRoomStateStatements
invites inviteEventsStatements
stdMsg stdEventsStatements
}
// NewSyncServerDatabase creates a new sync server database
@ -78,6 +79,9 @@ func NewSyncServerDatabase(dataSourceName string) (*SyncServerDatabase, error) {
if err := d.invites.prepare(d.db); err != nil {
return nil, err
}
if err := d.stdMsg.prepare(d.db); err != nil {
return nil, err
}
return &d, nil
}
@ -212,6 +216,13 @@ func (d *SyncServerDatabase) syncStreamPositionTx(
if maxInviteID > maxID {
maxID = maxInviteID
}
maxStdID, err := d.stdMsg.selectMaxStdID(ctx, txn)
if err != nil {
return 0, err
}
if maxStdID > maxID {
maxID = maxStdID
}
return types.StreamPosition(maxID), nil
}
@ -678,3 +689,94 @@ func getMembershipFromEvent(ev *gomatrixserverlib.Event, userID string) string {
}
return ""
}
/*
send to device messaging implementation
del / maxID / select in range / insert
*/
// DelStdMessage delete message for a given maxID, those below would be deleted
func (d *SyncServerDatabase) DelStdMessage(
ctx context.Context, targetUID, targetDevice string, maxID int64,
) (err error) {
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
err := d.stdMsg.deleteStdEvent(ctx, targetUID, targetDevice, maxID)
return err
})
return
}
// InsertStdMessage insert std message
func (d *SyncServerDatabase) InsertStdMessage(
ctx context.Context, stdEvent types.StdHolder, transactionID, targetUID, targetDevice string,
) (pos int64, err error) {
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
curPos, err := d.stdMsg.insertStdEvent(ctx, stdEvent, transactionID, targetUID, targetDevice)
pos = curPos
return err
})
return
}
// SelectMaxStdID select maximum id in std stream
func (d *SyncServerDatabase) SelectMaxStdID(
ctx context.Context,
) (maxID int64, err error) {
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
max, err := d.stdMsg.selectMaxStdID(ctx, txn)
maxID = max
return err
})
return
}
// SelectRangedStd select a range of std messages
func (d *SyncServerDatabase) SelectRangedStd(
ctx context.Context,
targetUserID, targetDeviceID string,
endPos int64,
) (holder []types.StdHolder, err error) {
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
list, err := d.stdMsg.selectStdEventsInRange(ctx, txn, targetUserID, targetDeviceID, endPos)
holder = list
return err
})
return
}
// StdEXT : send to device extension process
func StdEXT(
ctx context.Context,
syncDB *SyncServerDatabase,
respIn types.Response,
userID, deviceID string,
since int64,
) (respOut *types.Response) {
respOut = &respIn
// when extension works at the very beginning
err := syncDB.stdMsg.deleteStdEvent(ctx, userID, deviceID, since)
if err != nil {
return
}
// when err is nil, these before res should be tagged omitted,
// when next /sync is coming , and err is nil , all those omitted.
res, err := syncDB.SelectRangedStd(ctx, userID, deviceID, since)
if err != nil {
return
}
//toDevice := &types.ToDevice{}
mid := []types.StdEvent{}
//toDevice.StdEvent = mid
for _, val := range res {
ev := types.StdEvent{}
ev.Sender = val.Sender
ev.Type = val.EventTyp
err := json.Unmarshal(val.Event, &ev.Content)
if err != nil {
return
}
mid = append(mid, ev)
}
respOut.ToDevice.StdEvent = mid
return
}

View file

@ -106,6 +106,10 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype
// can respond
syncData, err := rp.currentSyncForUser(*syncReq, currPos)
// std extension consideration
syncData = storage.StdEXT(syncReq.ctx, rp.db, *syncData, syncReq.device.UserID, syncReq.device.ID, int64(currPos))
if err != nil {
return httputil.LogThenError(req, err)
}

View file

@ -71,5 +71,5 @@ func SetupSyncAPIComponent(
logrus.WithError(err).Panicf("failed to start client data consumer")
}
routing.Setup(base.APIMux, requestPool, syncDB, deviceDB)
routing.Setup(base.APIMux, requestPool, syncDB, deviceDB, notifier)
}

View file

@ -50,6 +50,32 @@ type Response struct {
Invite map[string]InviteResponse `json:"invite"`
Leave map[string]LeaveResponse `json:"leave"`
} `json:"rooms"`
ToDevice ToDevice `json:"to_device"`
}
// StdHolder represents send to device response from db
type StdHolder struct {
StreamID int64
Sender string
EventTyp string
Event []byte
}
// StdRequest represents send to device request format
type StdRequest struct {
Sender map[string]map[string]interface{} `json:"messages"`
}
// ToDevice represents a middleware for response send to device
type ToDevice struct {
StdEvent []StdEvent `json:"events"`
}
// StdEvent represents send to device event format
type StdEvent struct {
Sender string `json:"sender"`
Type string `json:"type"`
Content interface{} `json:"content"`
}
// NewResponse creates an empty response with initialised maps.
@ -81,7 +107,8 @@ func (r *Response) IsEmpty() bool {
len(r.Rooms.Invite) == 0 &&
len(r.Rooms.Leave) == 0 &&
len(r.AccountData.Events) == 0 &&
len(r.Presence.Events) == 0
len(r.Presence.Events) == 0 &&
len(r.ToDevice.StdEvent) == 0
}
// JoinResponse represents a /sync response for a room which is under the 'join' key.