mirror of
https://github.com/matrix-org/dendrite.git
synced 2024-11-26 08:11:55 -06:00
Process inbound device list updates from federation (#1240)
* Add InputDeviceListUpdate * Unbreak unit tests * Process inbound device list updates from federation - Persist the keys in the keyserver and produce key changes - Does not currently fetch keys from the remote server if the prev IDs are missing * Linting
This commit is contained in:
parent
15dc1f4d03
commit
642f9cb964
|
@ -82,7 +82,7 @@ func Setup(
|
||||||
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
|
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
|
||||||
return Send(
|
return Send(
|
||||||
httpReq, request, gomatrixserverlib.TransactionID(vars["txnID"]),
|
httpReq, request, gomatrixserverlib.TransactionID(vars["txnID"]),
|
||||||
cfg, rsAPI, eduAPI, keys, federation,
|
cfg, rsAPI, eduAPI, keyAPI, keys, federation,
|
||||||
)
|
)
|
||||||
},
|
},
|
||||||
)).Methods(http.MethodPut, http.MethodOptions)
|
)).Methods(http.MethodPut, http.MethodOptions)
|
||||||
|
|
|
@ -23,6 +23,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||||
eduserverAPI "github.com/matrix-org/dendrite/eduserver/api"
|
eduserverAPI "github.com/matrix-org/dendrite/eduserver/api"
|
||||||
"github.com/matrix-org/dendrite/internal/config"
|
"github.com/matrix-org/dendrite/internal/config"
|
||||||
|
keyapi "github.com/matrix-org/dendrite/keyserver/api"
|
||||||
"github.com/matrix-org/dendrite/roomserver/api"
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
|
@ -37,6 +38,7 @@ func Send(
|
||||||
cfg *config.Dendrite,
|
cfg *config.Dendrite,
|
||||||
rsAPI api.RoomserverInternalAPI,
|
rsAPI api.RoomserverInternalAPI,
|
||||||
eduAPI eduserverAPI.EDUServerInputAPI,
|
eduAPI eduserverAPI.EDUServerInputAPI,
|
||||||
|
keyAPI keyapi.KeyInternalAPI,
|
||||||
keys gomatrixserverlib.JSONVerifier,
|
keys gomatrixserverlib.JSONVerifier,
|
||||||
federation *gomatrixserverlib.FederationClient,
|
federation *gomatrixserverlib.FederationClient,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
|
@ -48,6 +50,7 @@ func Send(
|
||||||
federation: federation,
|
federation: federation,
|
||||||
haveEvents: make(map[string]*gomatrixserverlib.HeaderedEvent),
|
haveEvents: make(map[string]*gomatrixserverlib.HeaderedEvent),
|
||||||
newEvents: make(map[string]bool),
|
newEvents: make(map[string]bool),
|
||||||
|
keyAPI: keyAPI,
|
||||||
}
|
}
|
||||||
|
|
||||||
var txnEvents struct {
|
var txnEvents struct {
|
||||||
|
@ -100,6 +103,7 @@ type txnReq struct {
|
||||||
context context.Context
|
context context.Context
|
||||||
rsAPI api.RoomserverInternalAPI
|
rsAPI api.RoomserverInternalAPI
|
||||||
eduAPI eduserverAPI.EDUServerInputAPI
|
eduAPI eduserverAPI.EDUServerInputAPI
|
||||||
|
keyAPI keyapi.KeyInternalAPI
|
||||||
keys gomatrixserverlib.JSONVerifier
|
keys gomatrixserverlib.JSONVerifier
|
||||||
federation txnFederationClient
|
federation txnFederationClient
|
||||||
// local cache of events for auth checks, etc - this may include events
|
// local cache of events for auth checks, etc - this may include events
|
||||||
|
@ -308,12 +312,29 @@ func (t *txnReq) processEDUs(edus []gomatrixserverlib.EDU) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
case gomatrixserverlib.MDeviceListUpdate:
|
||||||
|
t.processDeviceListUpdate(e)
|
||||||
default:
|
default:
|
||||||
util.GetLogger(t.context).WithField("type", e.Type).Warn("unhandled edu")
|
util.GetLogger(t.context).WithField("type", e.Type).Warn("unhandled edu")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *txnReq) processDeviceListUpdate(e gomatrixserverlib.EDU) {
|
||||||
|
var payload gomatrixserverlib.DeviceListUpdateEvent
|
||||||
|
if err := json.Unmarshal(e.Content, &payload); err != nil {
|
||||||
|
util.GetLogger(t.context).WithError(err).Error("Failed to unmarshal device list update event")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var inputRes keyapi.InputDeviceListUpdateResponse
|
||||||
|
t.keyAPI.InputDeviceListUpdate(context.Background(), &keyapi.InputDeviceListUpdateRequest{
|
||||||
|
Event: payload,
|
||||||
|
}, &inputRes)
|
||||||
|
if inputRes.Error != nil {
|
||||||
|
util.GetLogger(t.context).WithError(inputRes.Error).WithField("user_id", payload.UserID).Error("failed to InputDeviceListUpdate")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (t *txnReq) processEvent(e gomatrixserverlib.Event, isInboundTxn bool) error {
|
func (t *txnReq) processEvent(e gomatrixserverlib.Event, isInboundTxn bool) error {
|
||||||
prevEventIDs := e.PrevEventIDs()
|
prevEventIDs := e.PrevEventIDs()
|
||||||
|
|
||||||
|
|
|
@ -21,11 +21,14 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
type KeyInternalAPI interface {
|
type KeyInternalAPI interface {
|
||||||
// SetUserAPI assigns a user API to query when extracting device names.
|
// SetUserAPI assigns a user API to query when extracting device names.
|
||||||
SetUserAPI(i userapi.UserInternalAPI)
|
SetUserAPI(i userapi.UserInternalAPI)
|
||||||
|
// InputDeviceListUpdate from a federated server EDU
|
||||||
|
InputDeviceListUpdate(ctx context.Context, req *InputDeviceListUpdateRequest, res *InputDeviceListUpdateResponse)
|
||||||
PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse)
|
PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse)
|
||||||
// PerformClaimKeys claims one-time keys for use in pre-key messages
|
// PerformClaimKeys claims one-time keys for use in pre-key messages
|
||||||
PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse)
|
PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse)
|
||||||
|
@ -200,3 +203,11 @@ type QueryDeviceMessagesResponse struct {
|
||||||
Devices []DeviceMessage
|
Devices []DeviceMessage
|
||||||
Error *KeyError
|
Error *KeyError
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type InputDeviceListUpdateRequest struct {
|
||||||
|
Event gomatrixserverlib.DeviceListUpdateEvent
|
||||||
|
}
|
||||||
|
|
||||||
|
type InputDeviceListUpdateResponse struct {
|
||||||
|
Error *KeyError
|
||||||
|
}
|
||||||
|
|
|
@ -38,12 +38,76 @@ type KeyInternalAPI struct {
|
||||||
FedClient *gomatrixserverlib.FederationClient
|
FedClient *gomatrixserverlib.FederationClient
|
||||||
UserAPI userapi.UserInternalAPI
|
UserAPI userapi.UserInternalAPI
|
||||||
Producer *producers.KeyChange
|
Producer *producers.KeyChange
|
||||||
|
// A map from user_id to a mutex. Used when we are missing prev IDs so we don't make more than 1
|
||||||
|
// request to the remote server and race.
|
||||||
|
// TODO: Put in an LRU cache to bound growth
|
||||||
|
UserIDToMutex map[string]*sync.Mutex
|
||||||
|
Mutex *sync.Mutex // protects UserIDToMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *KeyInternalAPI) SetUserAPI(i userapi.UserInternalAPI) {
|
func (a *KeyInternalAPI) SetUserAPI(i userapi.UserInternalAPI) {
|
||||||
a.UserAPI = i
|
a.UserAPI = i
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *KeyInternalAPI) mutex(userID string) *sync.Mutex {
|
||||||
|
a.Mutex.Lock()
|
||||||
|
defer a.Mutex.Unlock()
|
||||||
|
if a.UserIDToMutex[userID] == nil {
|
||||||
|
a.UserIDToMutex[userID] = &sync.Mutex{}
|
||||||
|
}
|
||||||
|
return a.UserIDToMutex[userID]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *KeyInternalAPI) InputDeviceListUpdate(
|
||||||
|
ctx context.Context, req *api.InputDeviceListUpdateRequest, res *api.InputDeviceListUpdateResponse,
|
||||||
|
) {
|
||||||
|
mu := a.mutex(req.Event.UserID)
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
// check if we have the prev IDs
|
||||||
|
exists, err := a.DB.PrevIDsExists(ctx, req.Event.UserID, req.Event.PrevID)
|
||||||
|
if err != nil {
|
||||||
|
res.Error = &api.KeyError{
|
||||||
|
Err: fmt.Sprintf("failed to check if prev ids exist: %s", err),
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// if we haven't missed anything update the database and notify users
|
||||||
|
if exists {
|
||||||
|
keys := []api.DeviceMessage{
|
||||||
|
{
|
||||||
|
DeviceKeys: api.DeviceKeys{
|
||||||
|
DeviceID: req.Event.DeviceID,
|
||||||
|
DisplayName: req.Event.DeviceDisplayName,
|
||||||
|
KeyJSON: req.Event.Keys,
|
||||||
|
UserID: req.Event.UserID,
|
||||||
|
},
|
||||||
|
StreamID: req.Event.StreamID,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
err = a.DB.StoreRemoteDeviceKeys(ctx, keys)
|
||||||
|
if err != nil {
|
||||||
|
res.Error = &api.KeyError{
|
||||||
|
Err: fmt.Sprintf("failed to store remote device keys: %s", err),
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// ALWAYS emit key changes when we've been poked over federation just in case
|
||||||
|
// this poke is important for something.
|
||||||
|
err = a.Producer.ProduceKeyChanges(keys)
|
||||||
|
if err != nil {
|
||||||
|
res.Error = &api.KeyError{
|
||||||
|
Err: fmt.Sprintf("failed to emit remote device key changes: %s", err),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// if we're missing an ID go and fetch it from the remote HS
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyChangesRequest, res *api.QueryKeyChangesResponse) {
|
func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyChangesRequest, res *api.QueryKeyChangesResponse) {
|
||||||
if req.Partition < 0 {
|
if req.Partition < 0 {
|
||||||
req.Partition = a.Producer.DefaultPartition()
|
req.Partition = a.Producer.DefaultPartition()
|
||||||
|
@ -351,7 +415,7 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// store the device keys and emit changes
|
// store the device keys and emit changes
|
||||||
err := a.DB.StoreDeviceKeys(ctx, keysToStore)
|
err := a.DB.StoreLocalDeviceKeys(ctx, keysToStore)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
res.Error = &api.KeyError{
|
res.Error = &api.KeyError{
|
||||||
Err: fmt.Sprintf("failed to store device keys: %s", err.Error()),
|
Err: fmt.Sprintf("failed to store device keys: %s", err.Error()),
|
||||||
|
|
|
@ -27,12 +27,13 @@ import (
|
||||||
|
|
||||||
// HTTP paths for the internal HTTP APIs
|
// HTTP paths for the internal HTTP APIs
|
||||||
const (
|
const (
|
||||||
PerformUploadKeysPath = "/keyserver/performUploadKeys"
|
InputDeviceListUpdatePath = "/keyserver/inputDeviceListUpdate"
|
||||||
PerformClaimKeysPath = "/keyserver/performClaimKeys"
|
PerformUploadKeysPath = "/keyserver/performUploadKeys"
|
||||||
QueryKeysPath = "/keyserver/queryKeys"
|
PerformClaimKeysPath = "/keyserver/performClaimKeys"
|
||||||
QueryKeyChangesPath = "/keyserver/queryKeyChanges"
|
QueryKeysPath = "/keyserver/queryKeys"
|
||||||
QueryOneTimeKeysPath = "/keyserver/queryOneTimeKeys"
|
QueryKeyChangesPath = "/keyserver/queryKeyChanges"
|
||||||
QueryDeviceMessagesPath = "/keyserver/queryDeviceMessages"
|
QueryOneTimeKeysPath = "/keyserver/queryOneTimeKeys"
|
||||||
|
QueryDeviceMessagesPath = "/keyserver/queryDeviceMessages"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewKeyServerClient creates a KeyInternalAPI implemented by talking to a HTTP POST API.
|
// NewKeyServerClient creates a KeyInternalAPI implemented by talking to a HTTP POST API.
|
||||||
|
@ -58,6 +59,20 @@ type httpKeyInternalAPI struct {
|
||||||
func (h *httpKeyInternalAPI) SetUserAPI(i userapi.UserInternalAPI) {
|
func (h *httpKeyInternalAPI) SetUserAPI(i userapi.UserInternalAPI) {
|
||||||
// no-op: doesn't need it
|
// no-op: doesn't need it
|
||||||
}
|
}
|
||||||
|
func (h *httpKeyInternalAPI) InputDeviceListUpdate(
|
||||||
|
ctx context.Context, req *api.InputDeviceListUpdateRequest, res *api.InputDeviceListUpdateResponse,
|
||||||
|
) {
|
||||||
|
span, ctx := opentracing.StartSpanFromContext(ctx, "InputDeviceListUpdate")
|
||||||
|
defer span.Finish()
|
||||||
|
|
||||||
|
apiURL := h.apiURL + InputDeviceListUpdatePath
|
||||||
|
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
||||||
|
if err != nil {
|
||||||
|
res.Error = &api.KeyError{
|
||||||
|
Err: err.Error(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (h *httpKeyInternalAPI) PerformClaimKeys(
|
func (h *httpKeyInternalAPI) PerformClaimKeys(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
|
|
|
@ -25,6 +25,17 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func AddRoutes(internalAPIMux *mux.Router, s api.KeyInternalAPI) {
|
func AddRoutes(internalAPIMux *mux.Router, s api.KeyInternalAPI) {
|
||||||
|
internalAPIMux.Handle(InputDeviceListUpdatePath,
|
||||||
|
httputil.MakeInternalAPI("inputDeviceListUpdate", func(req *http.Request) util.JSONResponse {
|
||||||
|
request := api.InputDeviceListUpdateRequest{}
|
||||||
|
response := api.InputDeviceListUpdateResponse{}
|
||||||
|
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||||
|
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||||
|
}
|
||||||
|
s.InputDeviceListUpdate(req.Context(), &request, &response)
|
||||||
|
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||||
|
}),
|
||||||
|
)
|
||||||
internalAPIMux.Handle(PerformClaimKeysPath,
|
internalAPIMux.Handle(PerformClaimKeysPath,
|
||||||
httputil.MakeInternalAPI("performClaimKeys", func(req *http.Request) util.JSONResponse {
|
httputil.MakeInternalAPI("performClaimKeys", func(req *http.Request) util.JSONResponse {
|
||||||
request := api.PerformClaimKeysRequest{}
|
request := api.PerformClaimKeysRequest{}
|
||||||
|
|
|
@ -15,6 +15,8 @@
|
||||||
package keyserver
|
package keyserver
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/Shopify/sarama"
|
"github.com/Shopify/sarama"
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
"github.com/matrix-org/dendrite/internal/config"
|
"github.com/matrix-org/dendrite/internal/config"
|
||||||
|
@ -51,9 +53,11 @@ func NewInternalAPI(
|
||||||
DB: db,
|
DB: db,
|
||||||
}
|
}
|
||||||
return &internal.KeyInternalAPI{
|
return &internal.KeyInternalAPI{
|
||||||
DB: db,
|
DB: db,
|
||||||
ThisServer: cfg.Matrix.ServerName,
|
ThisServer: cfg.Matrix.ServerName,
|
||||||
FedClient: fedClient,
|
FedClient: fedClient,
|
||||||
Producer: keyChangeProducer,
|
Producer: keyChangeProducer,
|
||||||
|
Mutex: &sync.Mutex{},
|
||||||
|
UserIDToMutex: make(map[string]*sync.Mutex),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -35,11 +35,18 @@ type Database interface {
|
||||||
// DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced.
|
// DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced.
|
||||||
DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
|
DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
|
||||||
|
|
||||||
// StoreDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
|
// StoreLocalDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
|
||||||
// for this (user, device).
|
// for this (user, device).
|
||||||
// The `StreamID` for each message is set on successful insertion. In the event the key already exists, the existing StreamID is set.
|
// The `StreamID` for each message is set on successful insertion. In the event the key already exists, the existing StreamID is set.
|
||||||
// Returns an error if there was a problem storing the keys.
|
// Returns an error if there was a problem storing the keys.
|
||||||
StoreDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error
|
StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error
|
||||||
|
|
||||||
|
// StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
|
||||||
|
// for this (user, device). Does not modify the stream ID for keys.
|
||||||
|
StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error
|
||||||
|
|
||||||
|
// PrevIDsExists returns true if all prev IDs exist for this user.
|
||||||
|
PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error)
|
||||||
|
|
||||||
// DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected.
|
// DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected.
|
||||||
// If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice.
|
// If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice.
|
||||||
|
|
|
@ -19,6 +19,7 @@ import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/lib/pq"
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/dendrite/keyserver/api"
|
"github.com/matrix-org/dendrite/keyserver/api"
|
||||||
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
||||||
|
@ -56,12 +57,16 @@ const selectBatchDeviceKeysSQL = "" +
|
||||||
const selectMaxStreamForUserSQL = "" +
|
const selectMaxStreamForUserSQL = "" +
|
||||||
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
|
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
|
||||||
|
|
||||||
|
const countStreamIDsForUserSQL = "" +
|
||||||
|
"SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id = ANY($2)"
|
||||||
|
|
||||||
type deviceKeysStatements struct {
|
type deviceKeysStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
upsertDeviceKeysStmt *sql.Stmt
|
upsertDeviceKeysStmt *sql.Stmt
|
||||||
selectDeviceKeysStmt *sql.Stmt
|
selectDeviceKeysStmt *sql.Stmt
|
||||||
selectBatchDeviceKeysStmt *sql.Stmt
|
selectBatchDeviceKeysStmt *sql.Stmt
|
||||||
selectMaxStreamForUserStmt *sql.Stmt
|
selectMaxStreamForUserStmt *sql.Stmt
|
||||||
|
countStreamIDsForUserStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
|
func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
|
||||||
|
@ -84,6 +89,9 @@ func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
|
||||||
if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
|
if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if s.countStreamIDsForUserStmt, err = db.Prepare(countStreamIDsForUserSQL); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -115,6 +123,19 @@ func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) {
|
||||||
|
// nullable if there are no results
|
||||||
|
var count sql.NullInt32
|
||||||
|
err := s.countStreamIDsForUserStmt.QueryRowContext(ctx, userID, pq.Int64Array(streamIDs)).Scan(&count)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if count.Valid {
|
||||||
|
return int(count.Int32), nil
|
||||||
|
}
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error {
|
func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error {
|
||||||
for _, key := range keys {
|
for _, key := range keys {
|
||||||
now := time.Now().Unix()
|
now := time.Now().Unix()
|
||||||
|
|
|
@ -47,7 +47,25 @@ func (d *Database) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage)
|
||||||
return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys)
|
return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) StoreDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error {
|
func (d *Database) PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error) {
|
||||||
|
sids := make([]int64, len(prevIDs))
|
||||||
|
for i := range prevIDs {
|
||||||
|
sids[i] = int64(prevIDs[i])
|
||||||
|
}
|
||||||
|
count, err := d.DeviceKeysTable.CountStreamIDsForUser(ctx, userID, sids)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return count == len(prevIDs), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error {
|
||||||
|
return sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error {
|
||||||
|
return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error {
|
||||||
// work out the latest stream IDs for each user
|
// work out the latest stream IDs for each user
|
||||||
userIDToStreamID := make(map[string]int)
|
userIDToStreamID := make(map[string]int)
|
||||||
for _, k := range keys {
|
for _, k := range keys {
|
||||||
|
|
|
@ -17,6 +17,7 @@ package sqlite3
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
|
@ -53,6 +54,9 @@ const selectBatchDeviceKeysSQL = "" +
|
||||||
const selectMaxStreamForUserSQL = "" +
|
const selectMaxStreamForUserSQL = "" +
|
||||||
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
|
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
|
||||||
|
|
||||||
|
const countStreamIDsForUserSQL = "" +
|
||||||
|
"SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id IN ($2)"
|
||||||
|
|
||||||
type deviceKeysStatements struct {
|
type deviceKeysStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer *sqlutil.TransactionWriter
|
writer *sqlutil.TransactionWriter
|
||||||
|
@ -143,6 +147,25 @@ func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) {
|
||||||
|
iStreamIDs := make([]interface{}, len(streamIDs)+1)
|
||||||
|
iStreamIDs[0] = userID
|
||||||
|
for i := range streamIDs {
|
||||||
|
iStreamIDs[i+1] = streamIDs[i]
|
||||||
|
}
|
||||||
|
query := strings.Replace(countStreamIDsForUserSQL, "($2)", sqlutil.QueryVariadicOffset(len(streamIDs), 1), 1)
|
||||||
|
// nullable if there are no results
|
||||||
|
var count sql.NullInt32
|
||||||
|
err := s.db.QueryRowContext(ctx, query, iStreamIDs...).Scan(&count)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if count.Valid {
|
||||||
|
return int(count.Int32), nil
|
||||||
|
}
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error {
|
func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error {
|
||||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||||
for _, key := range keys {
|
for _, key := range keys {
|
||||||
|
|
|
@ -126,15 +126,15 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) {
|
||||||
// StreamID: 2 as this is a 2nd device key
|
// StreamID: 2 as this is a 2nd device key
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
MustNotError(t, db.StoreDeviceKeys(ctx, msgs))
|
MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs))
|
||||||
if msgs[0].StreamID != 1 {
|
if msgs[0].StreamID != 1 {
|
||||||
t.Fatalf("Expected StoreDeviceKeys to set StreamID=1 but got %d", msgs[0].StreamID)
|
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 but got %d", msgs[0].StreamID)
|
||||||
}
|
}
|
||||||
if msgs[1].StreamID != 1 {
|
if msgs[1].StreamID != 1 {
|
||||||
t.Fatalf("Expected StoreDeviceKeys to set StreamID=1 (different user) but got %d", msgs[1].StreamID)
|
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 (different user) but got %d", msgs[1].StreamID)
|
||||||
}
|
}
|
||||||
if msgs[2].StreamID != 2 {
|
if msgs[2].StreamID != 2 {
|
||||||
t.Fatalf("Expected StoreDeviceKeys to set StreamID=2 (another device) but got %d", msgs[2].StreamID)
|
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=2 (another device) but got %d", msgs[2].StreamID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// updating a device sets the next stream ID for that user
|
// updating a device sets the next stream ID for that user
|
||||||
|
@ -148,9 +148,9 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) {
|
||||||
// StreamID: 3
|
// StreamID: 3
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
MustNotError(t, db.StoreDeviceKeys(ctx, msgs))
|
MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs))
|
||||||
if msgs[0].StreamID != 3 {
|
if msgs[0].StreamID != 3 {
|
||||||
t.Fatalf("Expected StoreDeviceKeys to set StreamID=3 (new key same device) but got %d", msgs[0].StreamID)
|
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=3 (new key same device) but got %d", msgs[0].StreamID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Querying for device keys returns the latest stream IDs
|
// Querying for device keys returns the latest stream IDs
|
||||||
|
|
|
@ -35,6 +35,7 @@ type DeviceKeys interface {
|
||||||
SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
|
SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
|
||||||
InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error
|
InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error
|
||||||
SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error)
|
SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error)
|
||||||
|
CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error)
|
||||||
SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error)
|
SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -44,6 +44,9 @@ func (k *mockKeyAPI) QueryOneTimeKeys(ctx context.Context, req *keyapi.QueryOneT
|
||||||
}
|
}
|
||||||
func (k *mockKeyAPI) QueryDeviceMessages(ctx context.Context, req *keyapi.QueryDeviceMessagesRequest, res *keyapi.QueryDeviceMessagesResponse) {
|
func (k *mockKeyAPI) QueryDeviceMessages(ctx context.Context, req *keyapi.QueryDeviceMessagesRequest, res *keyapi.QueryDeviceMessagesResponse) {
|
||||||
|
|
||||||
|
}
|
||||||
|
func (k *mockKeyAPI) InputDeviceListUpdate(ctx context.Context, req *keyapi.InputDeviceListUpdateRequest, res *keyapi.InputDeviceListUpdateResponse) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type mockCurrentStateAPI struct {
|
type mockCurrentStateAPI struct {
|
||||||
|
|
Loading…
Reference in a new issue