Breaking: add stream_id to keyserver_device_keys table

This commit is contained in:
Kegan Dougal 2020-08-03 16:19:49 +01:00
parent ffcb6d2ea1
commit cad8da67a9
9 changed files with 166 additions and 84 deletions

View file

@ -43,6 +43,13 @@ func (k *KeyError) Error() string {
return k.Err return k.Err
} }
// DeviceMessage represents the message produced into Kafka by the key server.
type DeviceMessage struct {
DeviceKeys
// A monotonically increasing number which represents device changes for this user.
StreamID int
}
// DeviceKeys represents a set of device keys for a single device // DeviceKeys represents a set of device keys for a single device
// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload // https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload
type DeviceKeys struct { type DeviceKeys struct {
@ -50,10 +57,20 @@ type DeviceKeys struct {
UserID string UserID string
// The device ID of this device // The device ID of this device
DeviceID string DeviceID string
// The device display name
DisplayName string
// The raw device key JSON // The raw device key JSON
KeyJSON []byte KeyJSON []byte
} }
// WithStreamID returns a copy of this device message with the given stream ID
func (k *DeviceKeys) WithStreamID(streamID int) DeviceMessage {
return DeviceMessage{
DeviceKeys: *k,
StreamID: streamID,
}
}
// OneTimeKeys represents a set of one-time keys for a single device // OneTimeKeys represents a set of one-time keys for a single device
// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload // https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload
type OneTimeKeys struct { type OneTimeKeys struct {

View file

@ -61,7 +61,7 @@ func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyC
func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
res.KeyErrors = make(map[string]map[string]*api.KeyError) res.KeyErrors = make(map[string]map[string]*api.KeyError)
a.uploadDeviceKeys(ctx, req, res) a.uploadLocalDeviceKeys(ctx, req, res)
a.uploadOneTimeKeys(ctx, req, res) a.uploadOneTimeKeys(ctx, req, res)
} }
@ -286,18 +286,25 @@ func (a *KeyInternalAPI) queryRemoteKeys(
} }
} }
func (a *KeyInternalAPI) uploadDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
var keysToStore []api.DeviceKeys var keysToStore []api.DeviceMessage
// assert that the user ID / device ID are not lying for each key // assert that the user ID / device ID are not lying for each key
for _, key := range req.DeviceKeys { for _, key := range req.DeviceKeys {
_, serverName, err := gomatrixserverlib.SplitID('@', key.UserID)
if err != nil {
continue // ignore invalid users
}
if serverName != a.ThisServer {
continue // ignore remote users
}
if len(key.KeyJSON) == 0 { if len(key.KeyJSON) == 0 {
keysToStore = append(keysToStore, key) keysToStore = append(keysToStore, key.WithStreamID(0))
continue // deleted keys don't need sanity checking continue // deleted keys don't need sanity checking
} }
gotUserID := gjson.GetBytes(key.KeyJSON, "user_id").Str gotUserID := gjson.GetBytes(key.KeyJSON, "user_id").Str
gotDeviceID := gjson.GetBytes(key.KeyJSON, "device_id").Str gotDeviceID := gjson.GetBytes(key.KeyJSON, "device_id").Str
if gotUserID == key.UserID && gotDeviceID == key.DeviceID { if gotUserID == key.UserID && gotDeviceID == key.DeviceID {
keysToStore = append(keysToStore, key) keysToStore = append(keysToStore, key.WithStreamID(0))
continue continue
} }
@ -310,11 +317,13 @@ func (a *KeyInternalAPI) uploadDeviceKeys(ctx context.Context, req *api.PerformU
} }
// get existing device keys so we can check for changes // get existing device keys so we can check for changes
existingKeys := make([]api.DeviceKeys, len(keysToStore)) existingKeys := make([]api.DeviceMessage, len(keysToStore))
for i := range keysToStore { for i := range keysToStore {
existingKeys[i] = api.DeviceKeys{ existingKeys[i] = api.DeviceMessage{
DeviceKeys: api.DeviceKeys{
UserID: keysToStore[i].UserID, UserID: keysToStore[i].UserID,
DeviceID: keysToStore[i].DeviceID, DeviceID: keysToStore[i].DeviceID,
},
} }
} }
if err := a.DB.DeviceKeysJSON(ctx, existingKeys); err != nil { if err := a.DB.DeviceKeysJSON(ctx, existingKeys); err != nil {
@ -324,13 +333,14 @@ func (a *KeyInternalAPI) uploadDeviceKeys(ctx context.Context, req *api.PerformU
return return
} }
// store the device keys and emit changes // store the device keys and emit changes
if err := a.DB.StoreDeviceKeys(ctx, keysToStore); err != nil { err := a.DB.StoreDeviceKeys(ctx, keysToStore)
if err != nil {
res.Error = &api.KeyError{ res.Error = &api.KeyError{
Err: fmt.Sprintf("failed to store device keys: %s", err.Error()), Err: fmt.Sprintf("failed to store device keys: %s", err.Error()),
} }
return return
} }
err := a.emitDeviceKeyChanges(existingKeys, keysToStore) err = a.emitDeviceKeyChanges(existingKeys, keysToStore)
if err != nil { if err != nil {
util.GetLogger(ctx).Errorf("Failed to emitDeviceKeyChanges: %s", err) util.GetLogger(ctx).Errorf("Failed to emitDeviceKeyChanges: %s", err)
} }
@ -375,9 +385,9 @@ func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perform
} }
func (a *KeyInternalAPI) emitDeviceKeyChanges(existing, new []api.DeviceKeys) error { func (a *KeyInternalAPI) emitDeviceKeyChanges(existing, new []api.DeviceMessage) error {
// find keys in new that are not in existing // find keys in new that are not in existing
var keysAdded []api.DeviceKeys var keysAdded []api.DeviceMessage
for _, newKey := range new { for _, newKey := range new {
exists := false exists := false
for _, existingKey := range existing { for _, existingKey := range existing {

View file

@ -41,7 +41,7 @@ func (p *KeyChange) DefaultPartition() int32 {
} }
// ProduceKeyChanges creates new change events for each key // ProduceKeyChanges creates new change events for each key
func (p *KeyChange) ProduceKeyChanges(keys []api.DeviceKeys) error { func (p *KeyChange) ProduceKeyChanges(keys []api.DeviceMessage) error {
for _, key := range keys { for _, key := range keys {
var m sarama.ProducerMessage var m sarama.ProducerMessage

View file

@ -32,17 +32,18 @@ type Database interface {
// OneTimeKeysCount returns a count of all OTKs for this device. // OneTimeKeysCount returns a count of all OTKs for this device.
OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error)
// DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` already then it will be replaced. // DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced.
DeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
// StoreDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key // StoreDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
// for this (user, device). // for this (user, device).
// The `StreamID` for each message is set on successful insertion. In the event the key already exists, the existing StreamID is set.
// Returns an error if there was a problem storing the keys. // Returns an error if there was a problem storing the keys.
StoreDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error StoreDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error
// DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected. // DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected.
// If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice. // If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice.
DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error)
// ClaimKeys based on the 3-uple of user_id, device_id and algorithm name. Returns the keys claimed. Returns no error if a key // ClaimKeys based on the 3-uple of user_id, device_id and algorithm name. Returns the keys claimed. Returns no error if a key
// cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice. // cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice.

View file

@ -20,7 +20,6 @@ import (
"time" "time"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/storage/tables" "github.com/matrix-org/dendrite/keyserver/storage/tables"
) )
@ -32,28 +31,37 @@ CREATE TABLE IF NOT EXISTS keyserver_device_keys (
device_id TEXT NOT NULL, device_id TEXT NOT NULL,
ts_added_secs BIGINT NOT NULL, ts_added_secs BIGINT NOT NULL,
key_json TEXT NOT NULL, key_json TEXT NOT NULL,
-- the stream ID of this key, scoped per-user. This gets updated when the device key changes.
-- This means we do not store an unbounded append-only log of device keys, which is not actually
-- required in the spec because in the event of a missed update the server fetches the entire
-- current set of keys rather than trying to 'fast-forward' or catchup missing stream IDs.
stream_id BIGINT NOT NULL,
-- Clobber based on tuple of user/device. -- Clobber based on tuple of user/device.
CONSTRAINT keyserver_device_keys_unique UNIQUE (user_id, device_id) CONSTRAINT keyserver_device_keys_unique UNIQUE (user_id, device_id)
); );
` `
const upsertDeviceKeysSQL = "" + const upsertDeviceKeysSQL = "" +
"INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json)" + "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id)" +
" VALUES ($1, $2, $3, $4)" + " VALUES ($1, $2, $3, $4, $5)" +
" ON CONFLICT ON CONSTRAINT keyserver_device_keys_unique" + " ON CONFLICT ON CONSTRAINT keyserver_device_keys_unique" +
" DO UPDATE SET key_json = $4" " DO UPDATE SET key_json = $4, stream_id = $5"
const selectDeviceKeysSQL = "" + const selectDeviceKeysSQL = "" +
"SELECT key_json FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" "SELECT key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
const selectBatchDeviceKeysSQL = "" + const selectBatchDeviceKeysSQL = "" +
"SELECT device_id, key_json FROM keyserver_device_keys WHERE user_id=$1" "SELECT device_id, key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1"
const selectMaxStreamForUserSQL = "" +
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
type deviceKeysStatements struct { type deviceKeysStatements struct {
db *sql.DB db *sql.DB
upsertDeviceKeysStmt *sql.Stmt upsertDeviceKeysStmt *sql.Stmt
selectDeviceKeysStmt *sql.Stmt selectDeviceKeysStmt *sql.Stmt
selectBatchDeviceKeysStmt *sql.Stmt selectBatchDeviceKeysStmt *sql.Stmt
selectMaxStreamForUserStmt *sql.Stmt
} }
func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
@ -73,38 +81,46 @@ func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil { if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil {
return nil, err return nil, err
} }
if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
return nil, err
}
return s, nil return s, nil
} }
func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error { func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
for i, key := range keys { for i, key := range keys {
var keyJSONStr string var keyJSONStr string
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr) var streamID int
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
return err return err
} }
// this will be '' when there is no device // this will be '' when there is no device
keys[i].KeyJSON = []byte(keyJSONStr) keys[i].KeyJSON = []byte(keyJSONStr)
keys[i].StreamID = streamID
} }
return nil return nil
} }
func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error { func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int, err error) {
now := time.Now().Unix() err = txn.Stmt(s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&streamID)
return sqlutil.WithTransaction(s.db, func(txn *sql.Tx) error { return
}
func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error {
for _, key := range keys { for _, key := range keys {
now := time.Now().Unix()
_, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext( _, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext(
ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID,
) )
if err != nil { if err != nil {
return err return err
} }
} }
return nil return nil
})
} }
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) { func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID) rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID)
if err != nil { if err != nil {
return nil, err return nil, err
@ -114,15 +130,17 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
for _, d := range deviceIDs { for _, d := range deviceIDs {
deviceIDMap[d] = true deviceIDMap[d] = true
} }
var result []api.DeviceKeys var result []api.DeviceMessage
for rows.Next() { for rows.Next() {
var dk api.DeviceKeys var dk api.DeviceMessage
dk.UserID = userID dk.UserID = userID
var keyJSON string var keyJSON string
if err := rows.Scan(&dk.DeviceID, &keyJSON); err != nil { var streamID int
if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID); err != nil {
return nil, err return nil, err
} }
dk.KeyJSON = []byte(keyJSON) dk.KeyJSON = []byte(keyJSON)
dk.StreamID = streamID
// include the key if we want all keys (no device) or it was asked // include the key if we want all keys (no device) or it was asked
if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 { if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
result = append(result, dk) result = append(result, dk)

View file

@ -43,15 +43,36 @@ func (d *Database) OneTimeKeysCount(ctx context.Context, userID, deviceID string
return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID) return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID)
} }
func (d *Database) DeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error { func (d *Database) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys) return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys)
} }
func (d *Database) StoreDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error { func (d *Database) StoreDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error {
return d.DeviceKeysTable.InsertDeviceKeys(ctx, keys) // work out the latest stream IDs for each user
userIDToStreamID := make(map[string]int)
for _, k := range keys {
userIDToStreamID[k.UserID] = 0
}
return sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error {
for userID := range userIDToStreamID {
streamID, err := d.DeviceKeysTable.SelectMaxStreamIDForUser(ctx, txn, userID)
if err != nil {
return err
}
userIDToStreamID[userID] = streamID
}
// set the stream IDs for each key
for i := range keys {
k := keys[i]
userIDToStreamID[k.UserID]++ // start stream from 1
k.StreamID = userIDToStreamID[k.UserID]
keys[i] = k
}
return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys)
})
} }
func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) { func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs) return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs)
} }

View file

@ -20,7 +20,6 @@ import (
"time" "time"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/storage/tables" "github.com/matrix-org/dendrite/keyserver/storage/tables"
) )
@ -32,28 +31,33 @@ CREATE TABLE IF NOT EXISTS keyserver_device_keys (
device_id TEXT NOT NULL, device_id TEXT NOT NULL,
ts_added_secs BIGINT NOT NULL, ts_added_secs BIGINT NOT NULL,
key_json TEXT NOT NULL, key_json TEXT NOT NULL,
stream_id BIGINT NOT NULL,
-- Clobber based on tuple of user/device. -- Clobber based on tuple of user/device.
UNIQUE (user_id, device_id) UNIQUE (user_id, device_id)
); );
` `
const upsertDeviceKeysSQL = "" + const upsertDeviceKeysSQL = "" +
"INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json)" + "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id)" +
" VALUES ($1, $2, $3, $4)" + " VALUES ($1, $2, $3, $4, $5)" +
" ON CONFLICT (user_id, device_id)" + " ON CONFLICT (user_id, device_id)" +
" DO UPDATE SET key_json = $4" " DO UPDATE SET key_json = $4, stream_id = $5"
const selectDeviceKeysSQL = "" + const selectDeviceKeysSQL = "" +
"SELECT key_json FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" "SELECT key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
const selectBatchDeviceKeysSQL = "" + const selectBatchDeviceKeysSQL = "" +
"SELECT device_id, key_json FROM keyserver_device_keys WHERE user_id=$1" "SELECT device_id, key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1"
const selectMaxStreamForUserSQL = "" +
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
type deviceKeysStatements struct { type deviceKeysStatements struct {
db *sql.DB db *sql.DB
upsertDeviceKeysStmt *sql.Stmt upsertDeviceKeysStmt *sql.Stmt
selectDeviceKeysStmt *sql.Stmt selectDeviceKeysStmt *sql.Stmt
selectBatchDeviceKeysStmt *sql.Stmt selectBatchDeviceKeysStmt *sql.Stmt
selectMaxStreamForUserStmt *sql.Stmt
} }
func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
@ -73,10 +77,13 @@ func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil { if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil {
return nil, err return nil, err
} }
if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
return nil, err
}
return s, nil return s, nil
} }
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) { func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
deviceIDMap := make(map[string]bool) deviceIDMap := make(map[string]bool)
for _, d := range deviceIDs { for _, d := range deviceIDs {
deviceIDMap[d] = true deviceIDMap[d] = true
@ -86,15 +93,17 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
return nil, err return nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed")
var result []api.DeviceKeys var result []api.DeviceMessage
for rows.Next() { for rows.Next() {
var dk api.DeviceKeys var dk api.DeviceMessage
dk.UserID = userID dk.UserID = userID
var keyJSON string var keyJSON string
if err := rows.Scan(&dk.DeviceID, &keyJSON); err != nil { var streamID int
if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID); err != nil {
return nil, err return nil, err
} }
dk.KeyJSON = []byte(keyJSON) dk.KeyJSON = []byte(keyJSON)
dk.StreamID = streamID
// include the key if we want all keys (no device) or it was asked // include the key if we want all keys (no device) or it was asked
if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 { if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
result = append(result, dk) result = append(result, dk)
@ -103,30 +112,35 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
return result, rows.Err() return result, rows.Err()
} }
func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error { func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
for i, key := range keys { for i, key := range keys {
var keyJSONStr string var keyJSONStr string
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr) var streamID int
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
return err return err
} }
// this will be '' when there is no device // this will be '' when there is no device
keys[i].KeyJSON = []byte(keyJSONStr) keys[i].KeyJSON = []byte(keyJSONStr)
keys[i].StreamID = streamID
} }
return nil return nil
} }
func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error { func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int, err error) {
now := time.Now().Unix() err = txn.Stmt(s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&streamID)
return sqlutil.WithTransaction(s.db, func(txn *sql.Tx) error { return
}
func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error {
for _, key := range keys { for _, key := range keys {
now := time.Now().Unix()
_, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext( _, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext(
ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID,
) )
if err != nil { if err != nil {
return err return err
} }
} }
return nil return nil
})
} }

View file

@ -32,9 +32,10 @@ type OneTimeKeys interface {
} }
type DeviceKeys interface { type DeviceKeys interface {
SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
InsertDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error
SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int, err error)
SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error)
} }
type KeyChanges interface { type KeyChanges interface {

View file

@ -98,7 +98,7 @@ func (s *OutputKeyChangeEventConsumer) onMessage(msg *sarama.ConsumerMessage) er
defer func() { defer func() {
s.updateOffset(msg) s.updateOffset(msg)
}() }()
var output api.DeviceKeys var output api.DeviceMessage
if err := json.Unmarshal(msg.Value, &output); err != nil { if err := json.Unmarshal(msg.Value, &output); err != nil {
// If the message was invalid, log it and move on to the next message in the stream // If the message was invalid, log it and move on to the next message in the stream
log.WithError(err).Error("syncapi: failed to unmarshal key change event from key server") log.WithError(err).Error("syncapi: failed to unmarshal key change event from key server")