diff --git a/serverkeyapi/internal/api.go b/serverkeyapi/internal/api.go index ec3479a8c..2b42d2081 100644 --- a/serverkeyapi/internal/api.go +++ b/serverkeyapi/internal/api.go @@ -70,7 +70,6 @@ func (s *ServerKeyAPI) StoreKeys( return s.OurKeyRing.KeyDatabase.StoreKeys(ctx, results) } -// nolint:gocyclo func (s *ServerKeyAPI) FetchKeys( _ context.Context, requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, @@ -78,64 +77,24 @@ func (s *ServerKeyAPI) FetchKeys( // Run in a background context - we don't want to stop this work just // because the caller gives up waiting. ctx := context.Background() + now := gomatrixserverlib.AsTimestamp(time.Now()) results := map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult{} origRequests := map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp{} for k, v := range requests { origRequests[k] = v } - now := gomatrixserverlib.AsTimestamp(time.Now()) // First, check if any of these key checks are for our own keys. If // they are then we will satisfy them directly. - for req := range requests { - if req.ServerName == s.Cfg.Matrix.ServerName { - // We found a key request that is supposed to be for our own - // keys. Remove it from the request list so we don't hit the - // database or the fetchers for it. - delete(requests, req) - - // Look up our own keys. - request := &api.QueryLocalKeysRequest{} - response := &api.QueryLocalKeysResponse{} - if err := s.QueryLocalKeys(ctx, request, response); err != nil { - return nil, err - } - - // Depending on whether the key is expired or not, we'll need - // to write slightly different - if verifyKeys, ok := response.ServerKeys.VerifyKeys[req.KeyID]; ok { - // The key is current. - results[req] = gomatrixserverlib.PublicKeyLookupResult{ - VerifyKey: verifyKeys, - ExpiredTS: gomatrixserverlib.PublicKeyNotExpired, - ValidUntilTS: response.ServerKeys.ValidUntilTS, - } - } else if verifyKeys, ok := response.ServerKeys.OldVerifyKeys[req.KeyID]; ok { - // The key is expired. - results[req] = gomatrixserverlib.PublicKeyLookupResult{ - VerifyKey: verifyKeys.VerifyKey, - ExpiredTS: verifyKeys.ExpiredTS, - ValidUntilTS: gomatrixserverlib.PublicKeyNotValid, - } - } - } + if err := s.handleLocalKeys(ctx, requests, results); err != nil { + return nil, err } // Then consult our local database and see if we have the requested // keys. These might come from a cache, depending on the database // implementation used. - if dbResults, err := s.OurKeyRing.KeyDatabase.FetchKeys(ctx, requests); err == nil { - // We successfully got some keys. Add them to the results. - for req, res := range dbResults { - results[req] = res - - // If the key is valid right now then we can also remove it - // from the request list as we don't need to fetch it again - // in that case. - if res.WasValidAt(now, true) { - delete(requests, req) - } - } + if err := s.handleDatabaseKeys(ctx, now, requests, results); err != nil { + return nil, err } // For any key requests that we still have outstanding, next try to @@ -146,68 +105,13 @@ func (s *ServerKeyAPI) FetchKeys( if len(requests) == 0 { break } - fetcherCtx, fetcherCancel := context.WithTimeout(ctx, time.Second*30) - defer fetcherCancel() - logrus.WithFields(logrus.Fields{ - "fetcher_name": fetcher.FetcherName(), - }).Infof("Fetching %d key(s)", len(requests)) - if fetcherResults, err := fetcher.FetchKeys(fetcherCtx, requests); err == nil { - // Build a map of the results that we want to commit to the - // database. We do this in a separate map because otherwise we - // might end up trying to rewrite database entries. - storeResults := map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult{} - - // Now let's look at the results that we got from this fetcher. - for req, res := range fetcherResults { - if prev, ok := results[req]; ok { - // We've already got a previous entry for this request - // so let's see if the newly retrieved one contains a more - // up-to-date validity period. - if res.ValidUntilTS > prev.ValidUntilTS { - // This key is newer than the one we had so let's store - // it in the database. - if req.ServerName != s.Cfg.Matrix.ServerName { - storeResults[req] = res - } - } - } else { - // We didn't already have a previous entry for this request - // so store it in the database anyway for now. - if req.ServerName != s.Cfg.Matrix.ServerName { - storeResults[req] = res - } - } - - // Update the results map with this new result. If nothing - // else, we can try verifying against this key. - results[req] = res - - // If the key is valid right now then we can remove it from the - // request list as we won't need to re-fetch it. - if res.WasValidAt(now, true) { - delete(requests, req) - } - } - - // Store the keys from our store map. - if err = s.OurKeyRing.KeyDatabase.StoreKeys(ctx, storeResults); err != nil { - logrus.WithError(err).WithFields(logrus.Fields{ - "fetcher_name": fetcher.FetcherName(), - "database_name": s.OurKeyRing.KeyDatabase.FetcherName(), - }).Errorf("Failed to store keys in the database") - return nil, fmt.Errorf("server key API failed to store retrieved keys: %w", err) - } - - if len(storeResults) > 0 { - logrus.WithFields(logrus.Fields{ - "fetcher_name": fetcher.FetcherName(), - }).Infof("Updated %d of %d key(s) in database", len(storeResults), len(results)) - } - } else { + // Ask the fetcher to look up our keys. + if err := s.handleFetcherKeys(ctx, now, fetcher, requests, results); err != nil { logrus.WithError(err).WithFields(logrus.Fields{ "fetcher_name": fetcher.FetcherName(), }).Errorf("Failed to retrieve %d key(s)", len(requests)) + continue } } @@ -233,3 +137,153 @@ func (s *ServerKeyAPI) FetchKeys( func (s *ServerKeyAPI) FetcherName() string { return fmt.Sprintf("ServerKeyAPI (wrapping %q)", s.OurKeyRing.KeyDatabase.FetcherName()) } + +// handleLocalKeys handles cases where the key request contains +// a request for our own server keys. +func (s *ServerKeyAPI) handleLocalKeys( + ctx context.Context, + requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, + results map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, +) error { + for req := range requests { + if req.ServerName == s.Cfg.Matrix.ServerName { + // We found a key request that is supposed to be for our own + // keys. Remove it from the request list so we don't hit the + // database or the fetchers for it. + delete(requests, req) + + // Look up our own keys. + request := &api.QueryLocalKeysRequest{} + response := &api.QueryLocalKeysResponse{} + if err := s.QueryLocalKeys(ctx, request, response); err != nil { + return err + } + + // Depending on whether the key is expired or not, we'll need + // to write slightly different + if verifyKeys, ok := response.ServerKeys.VerifyKeys[req.KeyID]; ok { + // The key is current. + results[req] = gomatrixserverlib.PublicKeyLookupResult{ + VerifyKey: verifyKeys, + ExpiredTS: gomatrixserverlib.PublicKeyNotExpired, + ValidUntilTS: response.ServerKeys.ValidUntilTS, + } + } else if verifyKeys, ok := response.ServerKeys.OldVerifyKeys[req.KeyID]; ok { + // The key is expired. + results[req] = gomatrixserverlib.PublicKeyLookupResult{ + VerifyKey: verifyKeys.VerifyKey, + ExpiredTS: verifyKeys.ExpiredTS, + ValidUntilTS: gomatrixserverlib.PublicKeyNotValid, + } + } + } + } + + return nil +} + +// handleDatabaseKeys handles cases where the key requests can be +// satisfied from our local database/cache. +func (s *ServerKeyAPI) handleDatabaseKeys( + ctx context.Context, + now gomatrixserverlib.Timestamp, + requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, + results map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, +) error { + // Ask the database/cache for the keys. + dbResults, err := s.OurKeyRing.KeyDatabase.FetchKeys(ctx, requests) + if err != nil { + return err + } + + // We successfully got some keys. Add them to the results. + for req, res := range dbResults { + results[req] = res + + // If the key is valid right now then we can also remove it + // from the request list as we don't need to fetch it again + // in that case. + if res.WasValidAt(now, true) { + delete(requests, req) + } + } + return nil +} + +// handleDatabaseKeys handles cases where a fetcher can satisfy +// the remaining requests. +func (s *ServerKeyAPI) handleFetcherKeys( + ctx context.Context, + now gomatrixserverlib.Timestamp, + fetcher gomatrixserverlib.KeyFetcher, + requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, + results map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, +) error { + logrus.WithFields(logrus.Fields{ + "fetcher_name": fetcher.FetcherName(), + }).Infof("Fetching %d key(s)", len(requests)) + + // Create a context that limits our requests to 30 seconds. + fetcherCtx, fetcherCancel := context.WithTimeout(ctx, time.Second*30) + defer fetcherCancel() + + // Try to fetch the keys. + fetcherResults, err := fetcher.FetchKeys(fetcherCtx, requests) + if err != nil { + return err + } + + // Build a map of the results that we want to commit to the + // database. We do this in a separate map because otherwise we + // might end up trying to rewrite database entries. + storeResults := map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult{} + + // Now let's look at the results that we got from this fetcher. + for req, res := range fetcherResults { + if prev, ok := results[req]; ok { + // We've already got a previous entry for this request + // so let's see if the newly retrieved one contains a more + // up-to-date validity period. + if res.ValidUntilTS > prev.ValidUntilTS { + // This key is newer than the one we had so let's store + // it in the database. + if req.ServerName != s.Cfg.Matrix.ServerName { + storeResults[req] = res + } + } + } else { + // We didn't already have a previous entry for this request + // so store it in the database anyway for now. + if req.ServerName != s.Cfg.Matrix.ServerName { + storeResults[req] = res + } + } + + // Update the results map with this new result. If nothing + // else, we can try verifying against this key. + results[req] = res + + // If the key is valid right now then we can remove it from the + // request list as we won't need to re-fetch it. + if res.WasValidAt(now, true) { + delete(requests, req) + } + } + + // Store the keys from our store map. + if err = s.OurKeyRing.KeyDatabase.StoreKeys(ctx, storeResults); err != nil { + logrus.WithError(err).WithFields(logrus.Fields{ + "fetcher_name": fetcher.FetcherName(), + "database_name": s.OurKeyRing.KeyDatabase.FetcherName(), + }).Errorf("Failed to store keys in the database") + return fmt.Errorf("server key API failed to store retrieved keys: %w", err) + } + + if len(storeResults) > 0 { + logrus.WithFields(logrus.Fields{ + "fetcher_name": fetcher.FetcherName(), + }).Infof("Updated %d of %d key(s) in database", len(storeResults), len(results)) + } + + return nil +}