Fix concurrent map write in key server

This commit is contained in:
Neil Alexander 2022-10-19 12:03:12 +01:00
parent f3dae0e749
commit c1463db6c9
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944

View file

@ -250,6 +250,7 @@ func (a *KeyInternalAPI) PerformMarkAsStaleIfNeeded(ctx context.Context, req *ap
// nolint:gocyclo
func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) error {
var respMu sync.Mutex
res.DeviceKeys = make(map[string]map[string]json.RawMessage)
res.MasterKeys = make(map[string]gomatrixserverlib.CrossSigningKey)
res.SelfSigningKeys = make(map[string]gomatrixserverlib.CrossSigningKey)
@ -329,7 +330,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
}
// attempt to satisfy key queries from the local database first as we should get device updates pushed to us
domainToDeviceKeys = a.remoteKeysFromDatabase(ctx, res, domainToDeviceKeys)
domainToDeviceKeys = a.remoteKeysFromDatabase(ctx, res, &respMu, domainToDeviceKeys)
if len(domainToDeviceKeys) > 0 || len(domainToCrossSigningKeys) > 0 {
// perform key queries for remote devices
a.queryRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys, domainToCrossSigningKeys)
@ -407,7 +408,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
}
func (a *KeyInternalAPI) remoteKeysFromDatabase(
ctx context.Context, res *api.QueryKeysResponse, domainToDeviceKeys map[string]map[string][]string,
ctx context.Context, res *api.QueryKeysResponse, respMu *sync.Mutex, domainToDeviceKeys map[string]map[string][]string,
) map[string]map[string][]string {
fetchRemote := make(map[string]map[string][]string)
for domain, userToDeviceMap := range domainToDeviceKeys {
@ -415,7 +416,7 @@ func (a *KeyInternalAPI) remoteKeysFromDatabase(
// we can't safely return keys from the db when all devices are requested as we don't
// know if one has just been added.
if len(deviceIDs) > 0 {
err := a.populateResponseWithDeviceKeysFromDatabase(ctx, res, userID, deviceIDs)
err := a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, deviceIDs)
if err == nil {
continue
}
@ -542,7 +543,7 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer(
// refresh entries from DB: unlike remoteKeysFromDatabase we know we previously had no device info for this
// user so the fact that we're populating all devices here isn't a problem so long as we have devices.
respMu.Lock()
err = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, userID, nil)
err = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, nil)
respMu.Unlock()
if err != nil {
logrus.WithFields(logrus.Fields{
@ -573,7 +574,7 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer(
// inspecting the failures map though so they can know it's a cached response.
for userID, dkeys := range devKeys {
// drop the error as it's already a failure at this point
_ = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, userID, dkeys)
_ = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, dkeys)
}
// Sytest expects no failures, if we still could retrieve keys, e.g. from local cache
@ -585,7 +586,7 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer(
}
func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
ctx context.Context, res *api.QueryKeysResponse, userID string, deviceIDs []string,
ctx context.Context, res *api.QueryKeysResponse, respMu *sync.Mutex, userID string, deviceIDs []string,
) error {
keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false)
// if we can't query the db or there are fewer keys than requested, fetch from remote.
@ -598,9 +599,11 @@ func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
if len(deviceIDs) == 0 && len(keys) == 0 {
return fmt.Errorf("DeviceKeysForUser %s returned no keys but wanted all keys, falling back to remote", userID)
}
respMu.Lock()
if res.DeviceKeys[userID] == nil {
res.DeviceKeys[userID] = make(map[string]json.RawMessage)
}
respMu.Unlock()
for _, key := range keys {
if len(key.KeyJSON) == 0 {
@ -610,7 +613,9 @@ func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
key.KeyJSON, _ = sjson.SetBytes(key.KeyJSON, "unsigned", struct {
DisplayName string `json:"device_display_name,omitempty"`
}{key.DisplayName})
respMu.Lock()
res.DeviceKeys[userID][key.DeviceID] = key.KeyJSON
respMu.Unlock()
}
return nil
}