Fix retrieving cross-signing signatures in /user/devices/{userId} (#2368)

* Fix retrieving cross-signing signatures in `/user/devices/{userId}`

We need to know the target device IDs in order to get the signatures and we weren't populating those.

* Fix up signature retrieval

* Fix SQLite

* Always include the target's own signatures as well as the requesting user
This commit is contained in:
Neil Alexander 2022-04-22 14:58:24 +01:00 committed by GitHub
parent c07f347f00
commit 6d78c4d67d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 52 additions and 26 deletions

View file

@ -43,6 +43,9 @@ func GetUserDevices(
}, },
} }
sigRes := &keyapi.QuerySignaturesResponse{} sigRes := &keyapi.QuerySignaturesResponse{}
for _, dev := range res.Devices {
sigReq.TargetIDs[userID] = append(sigReq.TargetIDs[userID], gomatrixserverlib.KeyID(dev.DeviceID))
}
keyAPI.QuerySignatures(req.Context(), sigReq, sigRes) keyAPI.QuerySignatures(req.Context(), sigReq, sigRes)
response := gomatrixserverlib.RespUserDevices{ response := gomatrixserverlib.RespUserDevices{

View file

@ -455,10 +455,10 @@ func (a *KeyInternalAPI) processOtherSignatures(
func (a *KeyInternalAPI) crossSigningKeysFromDatabase( func (a *KeyInternalAPI) crossSigningKeysFromDatabase(
ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse, ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse,
) { ) {
for userID := range req.UserToDevices { for targetUserID := range req.UserToDevices {
keys, err := a.DB.CrossSigningKeysForUser(ctx, userID) keys, err := a.DB.CrossSigningKeysForUser(ctx, targetUserID)
if err != nil { if err != nil {
logrus.WithError(err).Errorf("Failed to get cross-signing keys for user %q", userID) logrus.WithError(err).Errorf("Failed to get cross-signing keys for user %q", targetUserID)
continue continue
} }
@ -469,9 +469,9 @@ func (a *KeyInternalAPI) crossSigningKeysFromDatabase(
break break
} }
sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, userID, keyID) sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, keyID)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
logrus.WithError(err).Errorf("Failed to get cross-signing signatures for user %q key %q", userID, keyID) logrus.WithError(err).Errorf("Failed to get cross-signing signatures for user %q key %q", targetUserID, keyID)
continue continue
} }
@ -491,7 +491,7 @@ func (a *KeyInternalAPI) crossSigningKeysFromDatabase(
case req.UserID != "" && originUserID == req.UserID: case req.UserID != "" && originUserID == req.UserID:
// Include signatures that we created // Include signatures that we created
appendSignature(originUserID, originKeyID, signature) appendSignature(originUserID, originKeyID, signature)
case originUserID == userID: case originUserID == targetUserID:
// Include signatures that were created by the person whose key // Include signatures that were created by the person whose key
// we are processing // we are processing
appendSignature(originUserID, originKeyID, signature) appendSignature(originUserID, originKeyID, signature)
@ -501,13 +501,13 @@ func (a *KeyInternalAPI) crossSigningKeysFromDatabase(
switch keyType { switch keyType {
case gomatrixserverlib.CrossSigningKeyPurposeMaster: case gomatrixserverlib.CrossSigningKeyPurposeMaster:
res.MasterKeys[userID] = key res.MasterKeys[targetUserID] = key
case gomatrixserverlib.CrossSigningKeyPurposeSelfSigning: case gomatrixserverlib.CrossSigningKeyPurposeSelfSigning:
res.SelfSigningKeys[userID] = key res.SelfSigningKeys[targetUserID] = key
case gomatrixserverlib.CrossSigningKeyPurposeUserSigning: case gomatrixserverlib.CrossSigningKeyPurposeUserSigning:
res.UserSigningKeys[userID] = key res.UserSigningKeys[targetUserID] = key
} }
} }
} }
@ -546,7 +546,8 @@ func (a *KeyInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySign
} }
for _, targetKeyID := range forTargetUser { for _, targetKeyID := range forTargetUser {
sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, targetUserID, targetKeyID) // Get own signatures only.
sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, targetUserID, targetUserID, targetKeyID)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
res.Error = &api.KeyError{ res.Error = &api.KeyError{
Err: fmt.Sprintf("a.DB.CrossSigningSigsForTarget: %s", err), Err: fmt.Sprintf("a.DB.CrossSigningSigsForTarget: %s", err),

View file

@ -313,9 +313,31 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
// Finally, append signatures that we know about // Finally, append signatures that we know about
// TODO: This is horrible because we need to round-trip the signature from // TODO: This is horrible because we need to round-trip the signature from
// JSON, add the signatures and marshal it again, for some reason? // JSON, add the signatures and marshal it again, for some reason?
for userID, forUserID := range res.DeviceKeys {
for keyID, key := range forUserID { for targetUserID, masterKey := range res.MasterKeys {
sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, userID, gomatrixserverlib.KeyID(keyID)) for targetKeyID := range masterKey.Keys {
sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, targetKeyID)
if err != nil {
logrus.WithError(err).Errorf("a.DB.CrossSigningSigsForTarget failed")
continue
}
if len(sigMap) == 0 {
continue
}
for sourceUserID, forSourceUser := range sigMap {
for sourceKeyID, sourceSig := range forSourceUser {
if _, ok := masterKey.Signatures[sourceUserID]; !ok {
masterKey.Signatures[sourceUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
}
masterKey.Signatures[sourceUserID][sourceKeyID] = sourceSig
}
}
}
}
for targetUserID, forUserID := range res.DeviceKeys {
for targetKeyID, key := range forUserID {
sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, gomatrixserverlib.KeyID(targetKeyID))
if err != nil { if err != nil {
logrus.WithError(err).Errorf("a.DB.CrossSigningSigsForTarget failed") logrus.WithError(err).Errorf("a.DB.CrossSigningSigsForTarget failed")
continue continue
@ -339,7 +361,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
} }
} }
if js, err := json.Marshal(deviceKey); err == nil { if js, err := json.Marshal(deviceKey); err == nil {
res.DeviceKeys[userID][keyID] = js res.DeviceKeys[targetUserID][targetKeyID] = js
} }
} }
} }

View file

@ -81,7 +81,7 @@ type Database interface {
CrossSigningKeysForUser(ctx context.Context, userID string) (map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey, error) CrossSigningKeysForUser(ctx context.Context, userID string) (map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey, error)
CrossSigningKeysDataForUser(ctx context.Context, userID string) (types.CrossSigningKeyMap, error) CrossSigningKeysDataForUser(ctx context.Context, userID string) (types.CrossSigningKeyMap, error)
CrossSigningSigsForTarget(ctx context.Context, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error) CrossSigningSigsForTarget(ctx context.Context, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error)
StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error
StoreCrossSigningSigsForTarget(ctx context.Context, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error StoreCrossSigningSigsForTarget(ctx context.Context, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error

View file

@ -39,7 +39,7 @@ CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs (
const selectCrossSigningSigsForTargetSQL = "" + const selectCrossSigningSigsForTargetSQL = "" +
"SELECT origin_user_id, origin_key_id, signature FROM keyserver_cross_signing_sigs" + "SELECT origin_user_id, origin_key_id, signature FROM keyserver_cross_signing_sigs" +
" WHERE target_user_id = $1 AND target_key_id = $2" " WHERE (origin_user_id = $1 OR origin_user_id = target_user_id) AND target_user_id = $2 AND target_key_id = $3"
const upsertCrossSigningSigsForTargetSQL = "" + const upsertCrossSigningSigsForTargetSQL = "" +
"INSERT INTO keyserver_cross_signing_sigs (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)" + "INSERT INTO keyserver_cross_signing_sigs (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)" +
@ -72,9 +72,9 @@ func NewPostgresCrossSigningSigsTable(db *sql.DB) (tables.CrossSigningSigs, erro
} }
func (s *crossSigningSigsStatements) SelectCrossSigningSigsForTarget( func (s *crossSigningSigsStatements) SelectCrossSigningSigsForTarget(
ctx context.Context, txn *sql.Tx, targetUserID string, targetKeyID gomatrixserverlib.KeyID, ctx context.Context, txn *sql.Tx, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID,
) (r types.CrossSigningSigMap, err error) { ) (r types.CrossSigningSigMap, err error) {
rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningSigsForTargetStmt).QueryContext(ctx, targetUserID, targetKeyID) rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningSigsForTargetStmt).QueryContext(ctx, originUserID, targetUserID, targetKeyID)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -190,7 +190,7 @@ func (d *Database) CrossSigningKeysForUser(ctx context.Context, userID string) (
keyID: key, keyID: key,
}, },
} }
sigMap, err := d.CrossSigningSigsTable.SelectCrossSigningSigsForTarget(ctx, nil, userID, keyID) sigMap, err := d.CrossSigningSigsTable.SelectCrossSigningSigsForTarget(ctx, nil, userID, userID, keyID)
if err != nil { if err != nil {
continue continue
} }
@ -219,8 +219,8 @@ func (d *Database) CrossSigningKeysDataForUser(ctx context.Context, userID strin
} }
// CrossSigningSigsForTarget returns the signatures for a given user's key ID, if any. // CrossSigningSigsForTarget returns the signatures for a given user's key ID, if any.
func (d *Database) CrossSigningSigsForTarget(ctx context.Context, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error) { func (d *Database) CrossSigningSigsForTarget(ctx context.Context, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error) {
return d.CrossSigningSigsTable.SelectCrossSigningSigsForTarget(ctx, nil, targetUserID, targetKeyID) return d.CrossSigningSigsTable.SelectCrossSigningSigsForTarget(ctx, nil, originUserID, targetUserID, targetKeyID)
} }
// StoreCrossSigningKeysForUser stores the latest known cross-signing keys for a user. // StoreCrossSigningKeysForUser stores the latest known cross-signing keys for a user.

View file

@ -39,7 +39,7 @@ CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs (
const selectCrossSigningSigsForTargetSQL = "" + const selectCrossSigningSigsForTargetSQL = "" +
"SELECT origin_user_id, origin_key_id, signature FROM keyserver_cross_signing_sigs" + "SELECT origin_user_id, origin_key_id, signature FROM keyserver_cross_signing_sigs" +
" WHERE target_user_id = $1 AND target_key_id = $2" " WHERE (origin_user_id = $1 OR origin_user_id = target_user_id) AND target_user_id = $2 AND target_key_id = $3"
const upsertCrossSigningSigsForTargetSQL = "" + const upsertCrossSigningSigsForTargetSQL = "" +
"INSERT OR REPLACE INTO keyserver_cross_signing_sigs (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)" + "INSERT OR REPLACE INTO keyserver_cross_signing_sigs (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)" +
@ -71,13 +71,13 @@ func NewSqliteCrossSigningSigsTable(db *sql.DB) (tables.CrossSigningSigs, error)
} }
func (s *crossSigningSigsStatements) SelectCrossSigningSigsForTarget( func (s *crossSigningSigsStatements) SelectCrossSigningSigsForTarget(
ctx context.Context, txn *sql.Tx, targetUserID string, targetKeyID gomatrixserverlib.KeyID, ctx context.Context, txn *sql.Tx, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID,
) (r types.CrossSigningSigMap, err error) { ) (r types.CrossSigningSigMap, err error) {
rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningSigsForTargetStmt).QueryContext(ctx, targetUserID, targetKeyID) rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningSigsForTargetStmt).QueryContext(ctx, originUserID, targetUserID, targetKeyID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningSigsForTargetStmt: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningSigsForOriginTargetStmt: rows.close() failed")
r = types.CrossSigningSigMap{} r = types.CrossSigningSigMap{}
for rows.Next() { for rows.Next() {
var userID string var userID string

View file

@ -64,7 +64,7 @@ type CrossSigningKeys interface {
} }
type CrossSigningSigs interface { type CrossSigningSigs interface {
SelectCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (r types.CrossSigningSigMap, err error) SelectCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (r types.CrossSigningSigMap, err error)
UpsertCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error UpsertCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error
DeleteCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, targetUserID string, targetKeyID gomatrixserverlib.KeyID) error DeleteCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, targetUserID string, targetKeyID gomatrixserverlib.KeyID) error
} }