mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-21 05:43:09 -06:00
Process inbound device list updates from federation
- Persist the keys in the keyserver and produce key changes - Does not currently fetch keys from the remote server if the prev IDs are missing
This commit is contained in:
parent
ef79fe1987
commit
617862e40c
|
|
@ -38,15 +38,68 @@ type KeyInternalAPI struct {
|
|||
FedClient *gomatrixserverlib.FederationClient
|
||||
UserAPI userapi.UserInternalAPI
|
||||
Producer *producers.KeyChange
|
||||
// A map from user_id to a mutex. Used when we are missing prev IDs so we don't make more than 1
|
||||
// request to the remote server and race.
|
||||
// TODO: Put in an LRU cache to bound growth
|
||||
UserIDToMutex map[string]*sync.Mutex
|
||||
Mutex *sync.Mutex // protects UserIDToMutex
|
||||
}
|
||||
|
||||
func (a *KeyInternalAPI) SetUserAPI(i userapi.UserInternalAPI) {
|
||||
a.UserAPI = i
|
||||
}
|
||||
|
||||
func (a *KeyInternalAPI) mutex(userID string) *sync.Mutex {
|
||||
a.Mutex.Lock()
|
||||
defer a.Mutex.Unlock()
|
||||
if a.UserIDToMutex[userID] == nil {
|
||||
a.UserIDToMutex[userID] = &sync.Mutex{}
|
||||
}
|
||||
return a.UserIDToMutex[userID]
|
||||
}
|
||||
|
||||
func (a *KeyInternalAPI) InputDeviceListUpdate(
|
||||
ctx context.Context, req *api.InputDeviceListUpdateRequest, res *api.InputDeviceListUpdateResponse,
|
||||
) {
|
||||
mu := a.mutex(req.Event.UserID)
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
// check if we have the prev IDs
|
||||
exists, err := a.DB.PrevIDsExists(ctx, req.Event.UserID, req.Event.PrevID)
|
||||
if err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("failed to check if prev ids exist: %s", err),
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// if we haven't missed anything update the database and notify users
|
||||
if exists {
|
||||
keys := []api.DeviceMessage{
|
||||
{
|
||||
DeviceKeys: api.DeviceKeys{
|
||||
DeviceID: req.Event.DeviceID,
|
||||
DisplayName: req.Event.DeviceDisplayName,
|
||||
KeyJSON: req.Event.Keys,
|
||||
UserID: req.Event.UserID,
|
||||
},
|
||||
StreamID: req.Event.StreamID,
|
||||
},
|
||||
}
|
||||
err = a.DB.StoreRemoteDeviceKeys(ctx, keys)
|
||||
if err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("failed to store remote device keys: %s", err),
|
||||
}
|
||||
return
|
||||
}
|
||||
// ALWAYS emit key changes when we've been poked over federation just in case
|
||||
// this poke is important for something.
|
||||
a.Producer.ProduceKeyChanges(keys)
|
||||
return
|
||||
}
|
||||
|
||||
// if we're missing an ID go and fetch it from the remote HS
|
||||
|
||||
}
|
||||
|
||||
|
|
@ -357,7 +410,7 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per
|
|||
return
|
||||
}
|
||||
// store the device keys and emit changes
|
||||
err := a.DB.StoreDeviceKeys(ctx, keysToStore)
|
||||
err := a.DB.StoreLocalDeviceKeys(ctx, keysToStore)
|
||||
if err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("failed to store device keys: %s", err.Error()),
|
||||
|
|
|
|||
|
|
@ -15,6 +15,8 @@
|
|||
package keyserver
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/Shopify/sarama"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/matrix-org/dendrite/internal/config"
|
||||
|
|
@ -55,5 +57,7 @@ func NewInternalAPI(
|
|||
ThisServer: cfg.Matrix.ServerName,
|
||||
FedClient: fedClient,
|
||||
Producer: keyChangeProducer,
|
||||
Mutex: &sync.Mutex{},
|
||||
UserIDToMutex: make(map[string]*sync.Mutex),
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -35,11 +35,18 @@ type Database interface {
|
|||
// DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced.
|
||||
DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
|
||||
|
||||
// StoreDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
|
||||
// StoreLocalDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
|
||||
// for this (user, device).
|
||||
// The `StreamID` for each message is set on successful insertion. In the event the key already exists, the existing StreamID is set.
|
||||
// Returns an error if there was a problem storing the keys.
|
||||
StoreDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error
|
||||
StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error
|
||||
|
||||
// StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
|
||||
// for this (user, device). Does not modify the stream ID for keys.
|
||||
StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error
|
||||
|
||||
// PrevIDsExists returns true if all prev IDs exist for this user.
|
||||
PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error)
|
||||
|
||||
// DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected.
|
||||
// If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice.
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ import (
|
|||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/keyserver/api"
|
||||
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
||||
|
|
@ -56,12 +57,16 @@ const selectBatchDeviceKeysSQL = "" +
|
|||
const selectMaxStreamForUserSQL = "" +
|
||||
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
|
||||
|
||||
const countStreamIDsForUserSQL = "" +
|
||||
"SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id = ANY($2)"
|
||||
|
||||
type deviceKeysStatements struct {
|
||||
db *sql.DB
|
||||
upsertDeviceKeysStmt *sql.Stmt
|
||||
selectDeviceKeysStmt *sql.Stmt
|
||||
selectBatchDeviceKeysStmt *sql.Stmt
|
||||
selectMaxStreamForUserStmt *sql.Stmt
|
||||
countStreamIDsForUserStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
|
||||
|
|
@ -84,6 +89,9 @@ func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
|
|||
if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.countStreamIDsForUserStmt, err = db.Prepare(countStreamIDsForUserSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
|
|
@ -115,6 +123,19 @@ func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn
|
|||
return
|
||||
}
|
||||
|
||||
func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) {
|
||||
// nullable if there are no results
|
||||
var count sql.NullInt32
|
||||
err := s.countStreamIDsForUserStmt.QueryRowContext(ctx, userID, pq.Int64Array(streamIDs)).Scan(&count)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if count.Valid {
|
||||
return int(count.Int32), nil
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error {
|
||||
for _, key := range keys {
|
||||
now := time.Now().Unix()
|
||||
|
|
|
|||
|
|
@ -47,7 +47,25 @@ func (d *Database) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage)
|
|||
return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys)
|
||||
}
|
||||
|
||||
func (d *Database) StoreDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error {
|
||||
func (d *Database) PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error) {
|
||||
sids := make([]int64, len(prevIDs))
|
||||
for i := range prevIDs {
|
||||
sids[i] = int64(prevIDs[i])
|
||||
}
|
||||
count, err := d.DeviceKeysTable.CountStreamIDsForUser(ctx, userID, sids)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return count == len(prevIDs), nil
|
||||
}
|
||||
|
||||
func (d *Database) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error {
|
||||
return sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error {
|
||||
return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys)
|
||||
})
|
||||
}
|
||||
|
||||
func (d *Database) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error {
|
||||
// work out the latest stream IDs for each user
|
||||
userIDToStreamID := make(map[string]int)
|
||||
for _, k := range keys {
|
||||
|
|
|
|||
|
|
@ -17,9 +17,11 @@ package sqlite3
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/keyserver/api"
|
||||
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
||||
)
|
||||
|
|
@ -52,6 +54,9 @@ const selectBatchDeviceKeysSQL = "" +
|
|||
const selectMaxStreamForUserSQL = "" +
|
||||
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
|
||||
|
||||
const countStreamIDsForUserSQL = "" +
|
||||
"SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id IN ($2)"
|
||||
|
||||
type deviceKeysStatements struct {
|
||||
db *sql.DB
|
||||
upsertDeviceKeysStmt *sql.Stmt
|
||||
|
|
@ -140,6 +145,25 @@ func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn
|
|||
return
|
||||
}
|
||||
|
||||
func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) {
|
||||
iStreamIDs := make([]interface{}, len(streamIDs)+1)
|
||||
iStreamIDs[0] = userID
|
||||
for i := range streamIDs {
|
||||
iStreamIDs[i+1] = streamIDs[i]
|
||||
}
|
||||
query := strings.Replace(countStreamIDsForUserSQL, "($2)", sqlutil.QueryVariadicOffset(len(streamIDs), 1), 1)
|
||||
// nullable if there are no results
|
||||
var count sql.NullInt32
|
||||
err := s.db.QueryRowContext(ctx, query, iStreamIDs...).Scan(&count)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if count.Valid {
|
||||
return int(count.Int32), nil
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error {
|
||||
for _, key := range keys {
|
||||
now := time.Now().Unix()
|
||||
|
|
|
|||
|
|
@ -114,15 +114,15 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) {
|
|||
// StreamID: 2 as this is a 2nd device key
|
||||
},
|
||||
}
|
||||
MustNotError(t, db.StoreDeviceKeys(ctx, msgs))
|
||||
MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs))
|
||||
if msgs[0].StreamID != 1 {
|
||||
t.Fatalf("Expected StoreDeviceKeys to set StreamID=1 but got %d", msgs[0].StreamID)
|
||||
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 but got %d", msgs[0].StreamID)
|
||||
}
|
||||
if msgs[1].StreamID != 1 {
|
||||
t.Fatalf("Expected StoreDeviceKeys to set StreamID=1 (different user) but got %d", msgs[1].StreamID)
|
||||
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 (different user) but got %d", msgs[1].StreamID)
|
||||
}
|
||||
if msgs[2].StreamID != 2 {
|
||||
t.Fatalf("Expected StoreDeviceKeys to set StreamID=2 (another device) but got %d", msgs[2].StreamID)
|
||||
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=2 (another device) but got %d", msgs[2].StreamID)
|
||||
}
|
||||
|
||||
// updating a device sets the next stream ID for that user
|
||||
|
|
@ -136,9 +136,9 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) {
|
|||
// StreamID: 3
|
||||
},
|
||||
}
|
||||
MustNotError(t, db.StoreDeviceKeys(ctx, msgs))
|
||||
MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs))
|
||||
if msgs[0].StreamID != 3 {
|
||||
t.Fatalf("Expected StoreDeviceKeys to set StreamID=3 (new key same device) but got %d", msgs[0].StreamID)
|
||||
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=3 (new key same device) but got %d", msgs[0].StreamID)
|
||||
}
|
||||
|
||||
// Querying for device keys returns the latest stream IDs
|
||||
|
|
|
|||
|
|
@ -35,6 +35,7 @@ type DeviceKeys interface {
|
|||
SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
|
||||
InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error
|
||||
SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error)
|
||||
CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error)
|
||||
SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error)
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue