Fix concurrent map write in key server

This commit is contained in:
Neil Alexander 2022-10-19 12:03:12 +01:00
parent f3dae0e749
commit c1463db6c9
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944

View file

@ -250,6 +250,7 @@ func (a *KeyInternalAPI) PerformMarkAsStaleIfNeeded(ctx context.Context, req *ap
// nolint:gocyclo // nolint:gocyclo
func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) error { func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) error {
var respMu sync.Mutex
res.DeviceKeys = make(map[string]map[string]json.RawMessage) res.DeviceKeys = make(map[string]map[string]json.RawMessage)
res.MasterKeys = make(map[string]gomatrixserverlib.CrossSigningKey) res.MasterKeys = make(map[string]gomatrixserverlib.CrossSigningKey)
res.SelfSigningKeys = make(map[string]gomatrixserverlib.CrossSigningKey) res.SelfSigningKeys = make(map[string]gomatrixserverlib.CrossSigningKey)
@ -329,7 +330,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
} }
// attempt to satisfy key queries from the local database first as we should get device updates pushed to us // attempt to satisfy key queries from the local database first as we should get device updates pushed to us
domainToDeviceKeys = a.remoteKeysFromDatabase(ctx, res, domainToDeviceKeys) domainToDeviceKeys = a.remoteKeysFromDatabase(ctx, res, &respMu, domainToDeviceKeys)
if len(domainToDeviceKeys) > 0 || len(domainToCrossSigningKeys) > 0 { if len(domainToDeviceKeys) > 0 || len(domainToCrossSigningKeys) > 0 {
// perform key queries for remote devices // perform key queries for remote devices
a.queryRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys, domainToCrossSigningKeys) a.queryRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys, domainToCrossSigningKeys)
@ -407,7 +408,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
} }
func (a *KeyInternalAPI) remoteKeysFromDatabase( func (a *KeyInternalAPI) remoteKeysFromDatabase(
ctx context.Context, res *api.QueryKeysResponse, domainToDeviceKeys map[string]map[string][]string, ctx context.Context, res *api.QueryKeysResponse, respMu *sync.Mutex, domainToDeviceKeys map[string]map[string][]string,
) map[string]map[string][]string { ) map[string]map[string][]string {
fetchRemote := make(map[string]map[string][]string) fetchRemote := make(map[string]map[string][]string)
for domain, userToDeviceMap := range domainToDeviceKeys { for domain, userToDeviceMap := range domainToDeviceKeys {
@ -415,7 +416,7 @@ func (a *KeyInternalAPI) remoteKeysFromDatabase(
// we can't safely return keys from the db when all devices are requested as we don't // we can't safely return keys from the db when all devices are requested as we don't
// know if one has just been added. // know if one has just been added.
if len(deviceIDs) > 0 { if len(deviceIDs) > 0 {
err := a.populateResponseWithDeviceKeysFromDatabase(ctx, res, userID, deviceIDs) err := a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, deviceIDs)
if err == nil { if err == nil {
continue continue
} }
@ -542,7 +543,7 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer(
// refresh entries from DB: unlike remoteKeysFromDatabase we know we previously had no device info for this // refresh entries from DB: unlike remoteKeysFromDatabase we know we previously had no device info for this
// user so the fact that we're populating all devices here isn't a problem so long as we have devices. // user so the fact that we're populating all devices here isn't a problem so long as we have devices.
respMu.Lock() respMu.Lock()
err = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, userID, nil) err = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, nil)
respMu.Unlock() respMu.Unlock()
if err != nil { if err != nil {
logrus.WithFields(logrus.Fields{ logrus.WithFields(logrus.Fields{
@ -573,7 +574,7 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer(
// inspecting the failures map though so they can know it's a cached response. // inspecting the failures map though so they can know it's a cached response.
for userID, dkeys := range devKeys { for userID, dkeys := range devKeys {
// drop the error as it's already a failure at this point // drop the error as it's already a failure at this point
_ = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, userID, dkeys) _ = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, dkeys)
} }
// Sytest expects no failures, if we still could retrieve keys, e.g. from local cache // Sytest expects no failures, if we still could retrieve keys, e.g. from local cache
@ -585,7 +586,7 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer(
} }
func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase( func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
ctx context.Context, res *api.QueryKeysResponse, userID string, deviceIDs []string, ctx context.Context, res *api.QueryKeysResponse, respMu *sync.Mutex, userID string, deviceIDs []string,
) error { ) error {
keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false) keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false)
// if we can't query the db or there are fewer keys than requested, fetch from remote. // if we can't query the db or there are fewer keys than requested, fetch from remote.
@ -598,9 +599,11 @@ func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
if len(deviceIDs) == 0 && len(keys) == 0 { if len(deviceIDs) == 0 && len(keys) == 0 {
return fmt.Errorf("DeviceKeysForUser %s returned no keys but wanted all keys, falling back to remote", userID) return fmt.Errorf("DeviceKeysForUser %s returned no keys but wanted all keys, falling back to remote", userID)
} }
respMu.Lock()
if res.DeviceKeys[userID] == nil { if res.DeviceKeys[userID] == nil {
res.DeviceKeys[userID] = make(map[string]json.RawMessage) res.DeviceKeys[userID] = make(map[string]json.RawMessage)
} }
respMu.Unlock()
for _, key := range keys { for _, key := range keys {
if len(key.KeyJSON) == 0 { if len(key.KeyJSON) == 0 {
@ -610,7 +613,9 @@ func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
key.KeyJSON, _ = sjson.SetBytes(key.KeyJSON, "unsigned", struct { key.KeyJSON, _ = sjson.SetBytes(key.KeyJSON, "unsigned", struct {
DisplayName string `json:"device_display_name,omitempty"` DisplayName string `json:"device_display_name,omitempty"`
}{key.DisplayName}) }{key.DisplayName})
respMu.Lock()
res.DeviceKeys[userID][key.DeviceID] = key.KeyJSON res.DeviceKeys[userID][key.DeviceID] = key.KeyJSON
respMu.Unlock()
} }
return nil return nil
} }