diff --git a/federationapi/routing/devices.go b/federationapi/routing/devices.go index 0c4d2c20e..b397c5460 100644 --- a/federationapi/routing/devices.go +++ b/federationapi/routing/devices.go @@ -52,6 +52,13 @@ func GetUserDevices( Devices: []gomatrixserverlib.RespUserDevice{}, } + if masterKey, ok := sigRes.MasterKeys[userID]; ok { + response.MasterKey = &masterKey + } + if selfSigningKey, ok := sigRes.SelfSigningKeys[userID]; ok { + response.SelfSigningKey = &selfSigningKey + } + for _, dev := range res.Devices { var key gomatrixserverlib.RespUserDeviceKeys err := json.Unmarshal(dev.DeviceKeys.KeyJSON, &key) diff --git a/keyserver/api/api.go b/keyserver/api/api.go index c379fcf24..72bb6576f 100644 --- a/keyserver/api/api.go +++ b/keyserver/api/api.go @@ -252,7 +252,12 @@ type QuerySignaturesRequest struct { type QuerySignaturesResponse struct { // A map of target user ID -> target key/device ID -> origin user ID -> origin key/device ID -> signatures Signatures map[string]map[gomatrixserverlib.KeyID]types.CrossSigningSigMap - Error *KeyError + // A map of target user ID -> cross-signing master key + MasterKeys map[string]gomatrixserverlib.CrossSigningKey + // A map of target user ID -> cross-signing self-signing key + SelfSigningKeys map[string]gomatrixserverlib.CrossSigningKey + // The request error, if any + Error *KeyError } type InputDeviceListUpdateRequest struct { diff --git a/keyserver/internal/cross_signing.go b/keyserver/internal/cross_signing.go index 6fa5f1f87..b2703a49b 100644 --- a/keyserver/internal/cross_signing.go +++ b/keyserver/internal/cross_signing.go @@ -105,7 +105,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P // If the user hasn't given a new master key, then let's go and get their // existing keys from the database. if !hasMasterKey { - existingKeys, err := a.DB.CrossSigningKeysForUser(ctx, req.UserID) + existingKeys, err := a.DB.CrossSigningKeysDataForUser(ctx, req.UserID) if err != nil { res.Error = &api.KeyError{ Err: "Retrieving cross-signing keys from database failed: " + err.Error(), @@ -405,17 +405,11 @@ func (a *KeyInternalAPI) crossSigningKeysFromDatabase( continue } - for keyType, keyData := range keys { - b64 := keyData.Encode() - keyID := gomatrixserverlib.KeyID("ed25519:" + b64) - key := gomatrixserverlib.CrossSigningKey{ - UserID: userID, - Usage: []gomatrixserverlib.CrossSigningKeyPurpose{ - keyType, - }, - Keys: map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{ - keyID: keyData, - }, + for keyType, key := range keys { + var keyID gomatrixserverlib.KeyID + for id := range key.Keys { + keyID = id + break } sigs, err := a.DB.CrossSigningSigsForTarget(ctx, userID, keyID) @@ -465,7 +459,26 @@ func (a *KeyInternalAPI) crossSigningKeysFromDatabase( func (a *KeyInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySignaturesRequest, res *api.QuerySignaturesResponse) { for targetUserID, forTargetUser := range req.TargetIDs { for _, targetKeyID := range forTargetUser { - keyMap, err := a.DB.CrossSigningSigsForTarget(ctx, targetUserID, targetKeyID) + keyMap, err := a.DB.CrossSigningKeysForUser(ctx, targetUserID) + if err != nil { + if err == sql.ErrNoRows { + continue + } + res.Error = &api.KeyError{ + Err: fmt.Sprintf("a.DB.CrossSigningKeysForUser: %s", err), + } + } + + for targetPurpose, targetKey := range keyMap { + switch targetPurpose { + case gomatrixserverlib.CrossSigningKeyPurposeMaster: + res.MasterKeys[targetUserID] = targetKey + case gomatrixserverlib.CrossSigningKeyPurposeSelfSigning: + res.SelfSigningKeys[targetUserID] = targetKey + } + } + + sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, targetUserID, targetKeyID) if err != nil { if err == sql.ErrNoRows { continue @@ -476,7 +489,7 @@ func (a *KeyInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySign return } - for sourceUserID, forSourceUser := range keyMap { + for sourceUserID, forSourceUser := range sigMap { for sourceKeyID, sourceSig := range forSourceUser { if res.Signatures == nil { res.Signatures = map[string]map[gomatrixserverlib.KeyID]types.CrossSigningSigMap{} diff --git a/keyserver/storage/interface.go b/keyserver/storage/interface.go index 756dc32ad..b9db81ad6 100644 --- a/keyserver/storage/interface.go +++ b/keyserver/storage/interface.go @@ -78,7 +78,8 @@ type Database interface { // MarkDeviceListStale sets the stale bit for this user to isStale. MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error - CrossSigningKeysForUser(ctx context.Context, userID string) (types.CrossSigningKeyMap, error) + CrossSigningKeysForUser(ctx context.Context, userID string) (map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey, error) + CrossSigningKeysDataForUser(ctx context.Context, userID string) (types.CrossSigningKeyMap, error) CrossSigningSigsForTarget(ctx context.Context, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error) StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error diff --git a/keyserver/storage/shared/storage.go b/keyserver/storage/shared/storage.go index 767242950..64ce53ef1 100644 --- a/keyserver/storage/shared/storage.go +++ b/keyserver/storage/shared/storage.go @@ -159,7 +159,46 @@ func (d *Database) MarkDeviceListStale(ctx context.Context, userID string, isSta } // CrossSigningKeysForUser returns the latest known cross-signing keys for a user, if any. -func (d *Database) CrossSigningKeysForUser(ctx context.Context, userID string) (types.CrossSigningKeyMap, error) { +func (d *Database) CrossSigningKeysForUser(ctx context.Context, userID string) (map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey, error) { + keyMap, err := d.CrossSigningKeysTable.SelectCrossSigningKeysForUser(ctx, nil, userID) + if err != nil { + return nil, fmt.Errorf("d.CrossSigningKeysTable.SelectCrossSigningKeysForUser: %w", err) + } + results := map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey{} + for purpose, key := range keyMap { + keyID := gomatrixserverlib.KeyID("ed25519:" + key.Encode()) + result := gomatrixserverlib.CrossSigningKey{ + UserID: userID, + Usage: []gomatrixserverlib.CrossSigningKeyPurpose{purpose}, + Keys: map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{ + keyID: key, + }, + } + sigMap, err := d.CrossSigningSigsTable.SelectCrossSigningSigsForTarget(ctx, nil, userID, keyID) + if err != nil { + continue + } + for sigUserID, forSigUserID := range sigMap { + if userID != sigUserID { + continue + } + if result.Signatures == nil { + result.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} + } + if _, ok := result.Signatures[sigUserID]; !ok { + result.Signatures[sigUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} + } + for sigKeyID, sigBytes := range forSigUserID { + result.Signatures[sigUserID][sigKeyID] = sigBytes + } + } + results[purpose] = result + } + return results, nil +} + +// CrossSigningKeysForUser returns the latest known cross-signing keys for a user, if any. +func (d *Database) CrossSigningKeysDataForUser(ctx context.Context, userID string) (types.CrossSigningKeyMap, error) { return d.CrossSigningKeysTable.SelectCrossSigningKeysForUser(ctx, nil, userID) } diff --git a/syncapi/internal/keychange_test.go b/syncapi/internal/keychange_test.go index 38be34f65..0c567a962 100644 --- a/syncapi/internal/keychange_test.go +++ b/syncapi/internal/keychange_test.go @@ -49,6 +49,8 @@ func (k *mockKeyAPI) QueryDeviceMessages(ctx context.Context, req *keyapi.QueryD } func (k *mockKeyAPI) InputDeviceListUpdate(ctx context.Context, req *keyapi.InputDeviceListUpdateRequest, res *keyapi.InputDeviceListUpdateResponse) { +} +func (k *mockKeyAPI) QuerySignatures(ctx context.Context, req *keyapi.QuerySignaturesRequest, res *keyapi.QuerySignaturesResponse) { } type mockRoomserverAPI struct {