Fix concurrent map write in key server
This commit is contained in:
parent
f3dae0e749
commit
c1463db6c9
|
@ -250,6 +250,7 @@ func (a *KeyInternalAPI) PerformMarkAsStaleIfNeeded(ctx context.Context, req *ap
|
||||||
|
|
||||||
// nolint:gocyclo
|
// nolint:gocyclo
|
||||||
func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) error {
|
func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) error {
|
||||||
|
var respMu sync.Mutex
|
||||||
res.DeviceKeys = make(map[string]map[string]json.RawMessage)
|
res.DeviceKeys = make(map[string]map[string]json.RawMessage)
|
||||||
res.MasterKeys = make(map[string]gomatrixserverlib.CrossSigningKey)
|
res.MasterKeys = make(map[string]gomatrixserverlib.CrossSigningKey)
|
||||||
res.SelfSigningKeys = make(map[string]gomatrixserverlib.CrossSigningKey)
|
res.SelfSigningKeys = make(map[string]gomatrixserverlib.CrossSigningKey)
|
||||||
|
@ -329,7 +330,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
|
||||||
}
|
}
|
||||||
|
|
||||||
// attempt to satisfy key queries from the local database first as we should get device updates pushed to us
|
// attempt to satisfy key queries from the local database first as we should get device updates pushed to us
|
||||||
domainToDeviceKeys = a.remoteKeysFromDatabase(ctx, res, domainToDeviceKeys)
|
domainToDeviceKeys = a.remoteKeysFromDatabase(ctx, res, &respMu, domainToDeviceKeys)
|
||||||
if len(domainToDeviceKeys) > 0 || len(domainToCrossSigningKeys) > 0 {
|
if len(domainToDeviceKeys) > 0 || len(domainToCrossSigningKeys) > 0 {
|
||||||
// perform key queries for remote devices
|
// perform key queries for remote devices
|
||||||
a.queryRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys, domainToCrossSigningKeys)
|
a.queryRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys, domainToCrossSigningKeys)
|
||||||
|
@ -407,7 +408,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *KeyInternalAPI) remoteKeysFromDatabase(
|
func (a *KeyInternalAPI) remoteKeysFromDatabase(
|
||||||
ctx context.Context, res *api.QueryKeysResponse, domainToDeviceKeys map[string]map[string][]string,
|
ctx context.Context, res *api.QueryKeysResponse, respMu *sync.Mutex, domainToDeviceKeys map[string]map[string][]string,
|
||||||
) map[string]map[string][]string {
|
) map[string]map[string][]string {
|
||||||
fetchRemote := make(map[string]map[string][]string)
|
fetchRemote := make(map[string]map[string][]string)
|
||||||
for domain, userToDeviceMap := range domainToDeviceKeys {
|
for domain, userToDeviceMap := range domainToDeviceKeys {
|
||||||
|
@ -415,7 +416,7 @@ func (a *KeyInternalAPI) remoteKeysFromDatabase(
|
||||||
// we can't safely return keys from the db when all devices are requested as we don't
|
// we can't safely return keys from the db when all devices are requested as we don't
|
||||||
// know if one has just been added.
|
// know if one has just been added.
|
||||||
if len(deviceIDs) > 0 {
|
if len(deviceIDs) > 0 {
|
||||||
err := a.populateResponseWithDeviceKeysFromDatabase(ctx, res, userID, deviceIDs)
|
err := a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, deviceIDs)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
@ -542,7 +543,7 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer(
|
||||||
// refresh entries from DB: unlike remoteKeysFromDatabase we know we previously had no device info for this
|
// refresh entries from DB: unlike remoteKeysFromDatabase we know we previously had no device info for this
|
||||||
// user so the fact that we're populating all devices here isn't a problem so long as we have devices.
|
// user so the fact that we're populating all devices here isn't a problem so long as we have devices.
|
||||||
respMu.Lock()
|
respMu.Lock()
|
||||||
err = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, userID, nil)
|
err = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, nil)
|
||||||
respMu.Unlock()
|
respMu.Unlock()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithFields(logrus.Fields{
|
logrus.WithFields(logrus.Fields{
|
||||||
|
@ -573,7 +574,7 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer(
|
||||||
// inspecting the failures map though so they can know it's a cached response.
|
// inspecting the failures map though so they can know it's a cached response.
|
||||||
for userID, dkeys := range devKeys {
|
for userID, dkeys := range devKeys {
|
||||||
// drop the error as it's already a failure at this point
|
// drop the error as it's already a failure at this point
|
||||||
_ = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, userID, dkeys)
|
_ = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, dkeys)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sytest expects no failures, if we still could retrieve keys, e.g. from local cache
|
// Sytest expects no failures, if we still could retrieve keys, e.g. from local cache
|
||||||
|
@ -585,7 +586,7 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
|
func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
|
||||||
ctx context.Context, res *api.QueryKeysResponse, userID string, deviceIDs []string,
|
ctx context.Context, res *api.QueryKeysResponse, respMu *sync.Mutex, userID string, deviceIDs []string,
|
||||||
) error {
|
) error {
|
||||||
keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false)
|
keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false)
|
||||||
// if we can't query the db or there are fewer keys than requested, fetch from remote.
|
// if we can't query the db or there are fewer keys than requested, fetch from remote.
|
||||||
|
@ -598,9 +599,11 @@ func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
|
||||||
if len(deviceIDs) == 0 && len(keys) == 0 {
|
if len(deviceIDs) == 0 && len(keys) == 0 {
|
||||||
return fmt.Errorf("DeviceKeysForUser %s returned no keys but wanted all keys, falling back to remote", userID)
|
return fmt.Errorf("DeviceKeysForUser %s returned no keys but wanted all keys, falling back to remote", userID)
|
||||||
}
|
}
|
||||||
|
respMu.Lock()
|
||||||
if res.DeviceKeys[userID] == nil {
|
if res.DeviceKeys[userID] == nil {
|
||||||
res.DeviceKeys[userID] = make(map[string]json.RawMessage)
|
res.DeviceKeys[userID] = make(map[string]json.RawMessage)
|
||||||
}
|
}
|
||||||
|
respMu.Unlock()
|
||||||
|
|
||||||
for _, key := range keys {
|
for _, key := range keys {
|
||||||
if len(key.KeyJSON) == 0 {
|
if len(key.KeyJSON) == 0 {
|
||||||
|
@ -610,7 +613,9 @@ func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
|
||||||
key.KeyJSON, _ = sjson.SetBytes(key.KeyJSON, "unsigned", struct {
|
key.KeyJSON, _ = sjson.SetBytes(key.KeyJSON, "unsigned", struct {
|
||||||
DisplayName string `json:"device_display_name,omitempty"`
|
DisplayName string `json:"device_display_name,omitempty"`
|
||||||
}{key.DisplayName})
|
}{key.DisplayName})
|
||||||
|
respMu.Lock()
|
||||||
res.DeviceKeys[userID][key.DeviceID] = key.KeyJSON
|
res.DeviceKeys[userID][key.DeviceID] = key.KeyJSON
|
||||||
|
respMu.Unlock()
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue