key_count extension for /sync

This commit is contained in:
terrill 2018-07-27 13:02:54 +08:00
parent fcf90dc524
commit 8b4b3c6fc4
10 changed files with 129 additions and 17 deletions

View file

@ -22,7 +22,6 @@ import (
"github.com/matrix-org/dendrite/common/transactions"
"github.com/matrix-org/dendrite/typingserver"
"github.com/matrix-org/dendrite/appservice"
"github.com/matrix-org/dendrite/clientapi"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/common/basecomponent"
@ -59,7 +58,7 @@ func main() {
alias, input, query := roomserver.SetupRoomServerComponent(base)
typingInputAPI := typingserver.SetupTypingServerComponent(base)
encryptoapi.SetupEcryptoapi(base, deviceDB)
encryptDB := encryptoapi.SetupEcryptoapi(base, deviceDB)
clientapi.SetupClientAPIComponent(
base, deviceDB, accountDB,
@ -70,8 +69,8 @@ func main() {
federationsender.SetupFederationSenderComponent(base, federation, query)
mediaapi.SetupMediaAPIComponent(base, deviceDB)
publicroomsapi.SetupPublicRoomsAPIComponent(base, deviceDB)
syncapi.SetupSyncAPIComponent(base, deviceDB, accountDB, query)
appservice.SetupAppServiceAPIComponent(base, accountDB, deviceDB, federation, alias, query, transactions.New())
syncapi.SetupSyncAPIComponent(base, deviceDB, accountDB, query, encryptDB)
//appservice.SetupAppServiceAPIComponent(base, accountDB, deviceDB, federation, alias, query, transactions.New())
httpHandler := common.WrapHandlerInCORS(base.APIMux)

View file

@ -29,7 +29,7 @@ func main() {
_, _, query := base.CreateHTTPRoomserverAPIs()
syncapi.SetupSyncAPIComponent(base, deviceDB, accountDB, query)
syncapi.SetupSyncAPIComponent(base, deviceDB, accountDB, query, nil)
base.SetupAndServeHTTP(string(base.Cfg.Listen.SyncAPI))
}

View file

@ -31,7 +31,7 @@ import (
func SetupEcryptoapi(
base *basecomponent.BaseDendrite,
deviceDB *devices.Database,
) {
) *storage.Database {
encryptionDB, err := storage.NewDatabase(string(base.Cfg.Database.EncryptAPI))
if err != nil {
logrus.WithError(err).Panicf("failed to connect to encryption db")
@ -42,4 +42,5 @@ func SetupEcryptoapi(
deviceDB,
)
routing.InitNotifier(base)
return encryptionDB
}

View file

@ -59,14 +59,19 @@ const selectAllkeysSQL = `
SELECT user_id, device_id, key_id, key_type, key_info, algorithm, signature FROM encrypt_keys
WHERE user_id = $1 AND key_type = $2
`
const selectCountOneTimeKey = `
SELECT algorithm, COUNT(algorithm) FROM encrypt_keys WHERE user_id = $1 AND device_id = $2 AND key_type = 'one_time_key'
GROUP BY algorithm
`
type keyStatements struct {
insertKeyStmt *sql.Stmt
selectKeyStmt *sql.Stmt
selectInKeysStmt *sql.Stmt
selectAllKeyStmt *sql.Stmt
selectSingleKeyStmt *sql.Stmt
deleteSingleKeyStmt *sql.Stmt
insertKeyStmt *sql.Stmt
selectKeyStmt *sql.Stmt
selectInKeysStmt *sql.Stmt
selectAllKeyStmt *sql.Stmt
selectSingleKeyStmt *sql.Stmt
deleteSingleKeyStmt *sql.Stmt
selectCountOneTimeKeyStmt *sql.Stmt
}
func (s *keyStatements) prepare(db *sql.DB) (err error) {
@ -92,6 +97,9 @@ func (s *keyStatements) prepare(db *sql.DB) (err error) {
if s.selectSingleKeyStmt, err = db.Prepare(deleteSinglekeySQL); err != nil {
return
}
if s.selectCountOneTimeKeyStmt, err = db.Prepare(selectCountOneTimeKey); err != nil {
return
}
return
}
@ -216,3 +224,28 @@ func injectKeyHolder(rows *sql.Rows, keyHolder []types.KeyHolder) (holders []typ
holders = keyHolder
return
}
// select by user and device
func (s *keyStatements) selectOneTimeKeyCount(
ctx context.Context,
userID, deviceID string,
) (map[string]int, error) {
holders := make(map[string]int)
rows, err := s.selectCountOneTimeKeyStmt.QueryContext(ctx, userID, deviceID)
if err != nil {
return nil, err
}
for rows.Next() {
var singleKey string
var singleVal int
if err = rows.Scan(
&singleKey,
&singleVal,
); err != nil {
return nil, err
}
holders[singleKey] = singleVal
}
err = rows.Close()
return holders, err
}

View file

@ -120,3 +120,12 @@ func (d *Database) SelectOneTimeKeySingle(
holder, err = d.keyStatements.selectSingleKey(ctx, userID, deviceID, algorithm)
return
}
// SyncOneTimeCount for sync device_one_time_keys_count extension
func (d *Database) SyncOneTimeCount(
ctx context.Context,
userID, deviceID string,
) (holder map[string]int, err error) {
holder, err = d.keyStatements.selectOneTimeKeyCount(ctx, userID, deviceID)
return
}

View file

@ -22,6 +22,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/auth/storage/devices"
"github.com/matrix-org/dendrite/common"
encryptoapi "github.com/matrix-org/dendrite/encryptoapi/storage"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/sync"
"github.com/matrix-org/util"
@ -31,7 +32,7 @@ const pathPrefixR0 = "/_matrix/client/r0"
const pathPrefixUnstable = "/_matrix/client/unstable"
// Setup configures the given mux with sync-server listeners
func Setup(apiMux *mux.Router, srp *sync.RequestPool, syncDB *storage.SyncServerDatabase, deviceDB *devices.Database, notifier *sync.Notifier) {
func Setup(apiMux *mux.Router, srp *sync.RequestPool, syncDB *storage.SyncServerDatabase, deviceDB *devices.Database, notifier *sync.Notifier, encryptDB *encryptoapi.Database) {
r0mux := apiMux.PathPrefix(pathPrefixR0).Subrouter()
unstablemux := apiMux.PathPrefix(pathPrefixUnstable).Subrouter()
@ -39,7 +40,7 @@ func Setup(apiMux *mux.Router, srp *sync.RequestPool, syncDB *storage.SyncServer
// TODO: Add AS support for all handlers below.
r0mux.Handle("/sync", common.MakeAuthAPI("sync", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
return srp.OnIncomingSyncRequest(req, device)
return srp.OnIncomingSyncRequest(req, device, encryptDB)
})).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/state", common.MakeAuthAPI("room_state", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {

View file

@ -0,0 +1,64 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sync
import (
"context"
encryptoapi "github.com/matrix-org/dendrite/encryptoapi/storage"
"github.com/matrix-org/dendrite/syncapi/types"
"sync"
)
type keyCounter struct {
sync.RWMutex
m map[string]map[string]int
}
var counter = keyCounter{
m: make(map[string]map[string]int),
}
// CounterRead returns uid to countMap
func CounterRead(uid string) map[string]int {
counter.RLock()
defer counter.RUnlock()
return counter.m[uid]
}
// CounterWrite write count map to share for all response
func CounterWrite(uid string, m map[string]int) {
counter.Lock()
defer counter.Unlock()
counter.m[uid] = m
}
// KeyCountEXT key count extension process
func KeyCountEXT(
ctx context.Context,
encryptionDB *encryptoapi.Database,
respIn types.Response,
userID, deviceID string,
) (respOut *types.Response) {
respOut = &respIn
// when extension works at the very beginning
resp, err := encryptionDB.SyncOneTimeCount(ctx, userID, deviceID)
CounterWrite(userID, resp)
if err != nil {
return
}
respOut.SignNum = resp
return
}

View file

@ -22,6 +22,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/storage/accounts"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
encryptoapi "github.com/matrix-org/dendrite/encryptoapi/storage"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
@ -44,7 +45,7 @@ func NewRequestPool(db *storage.SyncServerDatabase, n *Notifier, adb *accounts.D
// OnIncomingSyncRequest is called when a client makes a /sync request. This function MUST be
// called in a dedicated goroutine for this request. This function will block the goroutine
// until a response is ready, or it times out.
func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtypes.Device) util.JSONResponse {
func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtypes.Device, encryptDB *encryptoapi.Database) util.JSONResponse {
// Extract values from request
logger := util.GetLogger(req.Context())
userID := device.UserID
@ -109,6 +110,7 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype
// std extension consideration
syncData = storage.StdEXT(syncReq.ctx, rp.db, *syncData, syncReq.device.UserID, syncReq.device.ID, int64(currPos))
syncData = KeyCountEXT(syncReq.ctx, encryptDB, *syncData, syncReq.device.UserID, syncReq.device.ID)
if err != nil {
return httputil.LogThenError(req, err)

View file

@ -24,6 +24,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/clientapi/auth/storage/devices"
encryptoapi "github.com/matrix-org/dendrite/encryptoapi/storage"
"github.com/matrix-org/dendrite/syncapi/consumers"
"github.com/matrix-org/dendrite/syncapi/routing"
"github.com/matrix-org/dendrite/syncapi/storage"
@ -38,6 +39,7 @@ func SetupSyncAPIComponent(
deviceDB *devices.Database,
accountsDB *accounts.Database,
queryAPI api.RoomserverQueryAPI,
encryptDB *encryptoapi.Database,
) {
syncDB, err := storage.NewSyncServerDatabase(string(base.Cfg.Database.SyncAPI))
if err != nil {
@ -71,5 +73,5 @@ func SetupSyncAPIComponent(
logrus.WithError(err).Panicf("failed to start client data consumer")
}
routing.Setup(base.APIMux, requestPool, syncDB, deviceDB, notifier)
routing.Setup(base.APIMux, requestPool, syncDB, deviceDB, notifier, encryptDB)
}

View file

@ -50,7 +50,8 @@ type Response struct {
Invite map[string]InviteResponse `json:"invite"`
Leave map[string]LeaveResponse `json:"leave"`
} `json:"rooms"`
ToDevice ToDevice `json:"to_device"`
ToDevice ToDevice `json:"to_device"`
SignNum map[string]int `json:"device_one_time_keys_count"`
}
// StdHolder represents send to device response from db