Fix concurrent map write in key server
This commit is contained in:
parent
f3dae0e749
commit
c1463db6c9
|
@ -250,6 +250,7 @@ func (a *KeyInternalAPI) PerformMarkAsStaleIfNeeded(ctx context.Context, req *ap
|
|||
|
||||
// nolint:gocyclo
|
||||
func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) error {
|
||||
var respMu sync.Mutex
|
||||
res.DeviceKeys = make(map[string]map[string]json.RawMessage)
|
||||
res.MasterKeys = make(map[string]gomatrixserverlib.CrossSigningKey)
|
||||
res.SelfSigningKeys = make(map[string]gomatrixserverlib.CrossSigningKey)
|
||||
|
@ -329,7 +330,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
|
|||
}
|
||||
|
||||
// attempt to satisfy key queries from the local database first as we should get device updates pushed to us
|
||||
domainToDeviceKeys = a.remoteKeysFromDatabase(ctx, res, domainToDeviceKeys)
|
||||
domainToDeviceKeys = a.remoteKeysFromDatabase(ctx, res, &respMu, domainToDeviceKeys)
|
||||
if len(domainToDeviceKeys) > 0 || len(domainToCrossSigningKeys) > 0 {
|
||||
// perform key queries for remote devices
|
||||
a.queryRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys, domainToCrossSigningKeys)
|
||||
|
@ -407,7 +408,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
|
|||
}
|
||||
|
||||
func (a *KeyInternalAPI) remoteKeysFromDatabase(
|
||||
ctx context.Context, res *api.QueryKeysResponse, domainToDeviceKeys map[string]map[string][]string,
|
||||
ctx context.Context, res *api.QueryKeysResponse, respMu *sync.Mutex, domainToDeviceKeys map[string]map[string][]string,
|
||||
) map[string]map[string][]string {
|
||||
fetchRemote := make(map[string]map[string][]string)
|
||||
for domain, userToDeviceMap := range domainToDeviceKeys {
|
||||
|
@ -415,7 +416,7 @@ func (a *KeyInternalAPI) remoteKeysFromDatabase(
|
|||
// we can't safely return keys from the db when all devices are requested as we don't
|
||||
// know if one has just been added.
|
||||
if len(deviceIDs) > 0 {
|
||||
err := a.populateResponseWithDeviceKeysFromDatabase(ctx, res, userID, deviceIDs)
|
||||
err := a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, deviceIDs)
|
||||
if err == nil {
|
||||
continue
|
||||
}
|
||||
|
@ -542,7 +543,7 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer(
|
|||
// refresh entries from DB: unlike remoteKeysFromDatabase we know we previously had no device info for this
|
||||
// user so the fact that we're populating all devices here isn't a problem so long as we have devices.
|
||||
respMu.Lock()
|
||||
err = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, userID, nil)
|
||||
err = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, nil)
|
||||
respMu.Unlock()
|
||||
if err != nil {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
|
@ -573,7 +574,7 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer(
|
|||
// inspecting the failures map though so they can know it's a cached response.
|
||||
for userID, dkeys := range devKeys {
|
||||
// drop the error as it's already a failure at this point
|
||||
_ = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, userID, dkeys)
|
||||
_ = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, dkeys)
|
||||
}
|
||||
|
||||
// Sytest expects no failures, if we still could retrieve keys, e.g. from local cache
|
||||
|
@ -585,7 +586,7 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer(
|
|||
}
|
||||
|
||||
func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
|
||||
ctx context.Context, res *api.QueryKeysResponse, userID string, deviceIDs []string,
|
||||
ctx context.Context, res *api.QueryKeysResponse, respMu *sync.Mutex, userID string, deviceIDs []string,
|
||||
) error {
|
||||
keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false)
|
||||
// if we can't query the db or there are fewer keys than requested, fetch from remote.
|
||||
|
@ -598,9 +599,11 @@ func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
|
|||
if len(deviceIDs) == 0 && len(keys) == 0 {
|
||||
return fmt.Errorf("DeviceKeysForUser %s returned no keys but wanted all keys, falling back to remote", userID)
|
||||
}
|
||||
respMu.Lock()
|
||||
if res.DeviceKeys[userID] == nil {
|
||||
res.DeviceKeys[userID] = make(map[string]json.RawMessage)
|
||||
}
|
||||
respMu.Unlock()
|
||||
|
||||
for _, key := range keys {
|
||||
if len(key.KeyJSON) == 0 {
|
||||
|
@ -610,7 +613,9 @@ func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
|
|||
key.KeyJSON, _ = sjson.SetBytes(key.KeyJSON, "unsigned", struct {
|
||||
DisplayName string `json:"device_display_name,omitempty"`
|
||||
}{key.DisplayName})
|
||||
respMu.Lock()
|
||||
res.DeviceKeys[userID][key.DeviceID] = key.KeyJSON
|
||||
respMu.Unlock()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue