From 89d64c6dca99474111883502ffad83c041fd15aa Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Tue, 18 Jun 2024 15:15:56 +0200 Subject: [PATCH] Add support for getting authenticated federation media requests --- mediaapi/mediaapi.go | 3 +- mediaapi/routing/download.go | 70 +++++++++++++++++++++++++++++------- mediaapi/routing/routing.go | 13 ++++--- setup/monolith.go | 2 +- 4 files changed, 69 insertions(+), 19 deletions(-) diff --git a/mediaapi/mediaapi.go b/mediaapi/mediaapi.go index adeb89ab2..8b843e907 100644 --- a/mediaapi/mediaapi.go +++ b/mediaapi/mediaapi.go @@ -33,6 +33,7 @@ func AddPublicRoutes( cfg *config.Dendrite, userAPI userapi.MediaUserAPI, client *fclient.Client, + fedClient fclient.FederationClient, keyRing gomatrixserverlib.JSONVerifier, ) { mediaDB, err := storage.NewMediaAPIDatasource(cm, &cfg.MediaAPI.Database) @@ -41,6 +42,6 @@ func AddPublicRoutes( } routing.Setup( - routers, cfg, mediaDB, userAPI, client, keyRing, + routers, cfg, mediaDB, userAPI, client, fedClient, keyRing, ) } diff --git a/mediaapi/routing/download.go b/mediaapi/routing/download.go index 75140cffa..93040bcba 100644 --- a/mediaapi/routing/download.go +++ b/mediaapi/routing/download.go @@ -65,6 +65,8 @@ type downloadRequest struct { Logger *log.Entry DownloadFilename string forFederation bool // whether we need to return a multipart/mixed response + fedClient fclient.FederationClient + origin spec.ServerName } // Taken from: https://github.com/matrix-org/synapse/blob/c3627d0f99ed5a23479305dc2bd0e71ca25ce2b1/synapse/media/_base.py#L53C1-L84 @@ -115,6 +117,7 @@ func Download( cfg *config.MediaAPI, db storage.Database, client *fclient.Client, + fedClient fclient.FederationClient, activeRemoteRequests *types.ActiveRemoteRequests, activeThumbnailGeneration *types.ActiveThumbnailGeneration, isThumbnailRequest bool, @@ -137,6 +140,8 @@ func Download( }), DownloadFilename: customFilename, forFederation: forFederation, + origin: cfg.Matrix.ServerName, + fedClient: fedClient, } if dReq.IsThumbnailRequest { @@ -773,8 +778,7 @@ func (r *downloadRequest) fetchRemoteFileAndStoreMetadata( return nil } -func (r *downloadRequest) GetContentLengthAndReader(contentLengthHeader string, body *io.ReadCloser, maxFileSizeBytes config.FileSizeBytes) (int64, io.Reader, error) { - reader := *body +func (r *downloadRequest) GetContentLengthAndReader(contentLengthHeader string, reader io.ReadCloser, maxFileSizeBytes config.FileSizeBytes) (int64, io.Reader, error) { var contentLength int64 if contentLengthHeader != "" { @@ -793,7 +797,7 @@ func (r *downloadRequest) GetContentLengthAndReader(contentLengthHeader string, // We successfully parsed the Content-Length, so we'll return a limited // reader that restricts us to reading only up to this size. - reader = io.NopCloser(io.LimitReader(*body, parsedLength)) + reader = io.NopCloser(io.LimitReader(reader, parsedLength)) contentLength = parsedLength } else { // Content-Length header is missing. If we have a maximum file size @@ -802,7 +806,7 @@ func (r *downloadRequest) GetContentLengthAndReader(contentLengthHeader string, // ultimately it will get rewritten later when the temp file is written // to disk. if maxFileSizeBytes > 0 { - reader = io.NopCloser(io.LimitReader(*body, int64(maxFileSizeBytes))) + reader = io.NopCloser(io.LimitReader(reader, int64(maxFileSizeBytes))) } contentLength = 0 } @@ -818,19 +822,61 @@ func (r *downloadRequest) fetchRemoteFile( ) (types.Path, bool, error) { r.Logger.Debug("Fetching remote file") - // create request for remote file - resp, err := client.CreateMediaDownloadRequest(ctx, r.MediaMetadata.Origin, string(r.MediaMetadata.MediaID)) + // Attempt to download via authenticated media endpoint + isMultiPart := true + resp, err := r.fedClient.DownloadMedia(ctx, r.origin, r.MediaMetadata.Origin, string(r.MediaMetadata.MediaID)) if err != nil || (resp != nil && resp.StatusCode != http.StatusOK) { - if resp != nil && resp.StatusCode == http.StatusNotFound { - return "", false, fmt.Errorf("File with media ID %q does not exist on %s", r.MediaMetadata.MediaID, r.MediaMetadata.Origin) + isMultiPart = false + // try again on the unauthed endpoint + // create request for remote file + resp, err = client.CreateMediaDownloadRequest(ctx, r.MediaMetadata.Origin, string(r.MediaMetadata.MediaID)) + if err != nil || (resp != nil && resp.StatusCode != http.StatusOK) { + if resp != nil && resp.StatusCode == http.StatusNotFound { + return "", false, fmt.Errorf("File with media ID %q does not exist on %s", r.MediaMetadata.MediaID, r.MediaMetadata.Origin) + } + return "", false, fmt.Errorf("file with media ID %q could not be downloaded from %s", r.MediaMetadata.MediaID, r.MediaMetadata.Origin) } - return "", false, fmt.Errorf("file with media ID %q could not be downloaded from %s", r.MediaMetadata.MediaID, r.MediaMetadata.Origin) } defer resp.Body.Close() // nolint: errcheck - // The reader returned here will be limited either by the Content-Length - // and/or the configured maximum media size. - contentLength, reader, parseErr := r.GetContentLengthAndReader(resp.Header.Get("Content-Length"), &resp.Body, maxFileSizeBytes) + var contentLength int64 + var reader io.Reader + var parseErr error + if isMultiPart { + r.Logger.Info("Downloaded file using authenticated endpoint") + _, params, err := mime.ParseMediaType(resp.Header.Get("Content-Type")) + if err != nil { + panic(err) + } + if params["boundary"] == "" { + return "", false, fmt.Errorf("no boundary header found on %s", r.MediaMetadata.Origin) + } + mr := multipart.NewReader(resp.Body, params["boundary"]) + + first := true + for { + p, err := mr.NextPart() + if err == io.EOF { + break + } + if err != nil { + return "", false, err + } + + if !first { + readCloser := io.NopCloser(p) + contentLength, reader, parseErr = r.GetContentLengthAndReader(p.Header.Get("Content-Length"), readCloser, maxFileSizeBytes) + break + } + + first = false + } + } else { + // The reader returned here will be limited either by the Content-Length + // and/or the configured maximum media size. + contentLength, reader, parseErr = r.GetContentLengthAndReader(resp.Header.Get("Content-Length"), resp.Body, maxFileSizeBytes) + } + if parseErr != nil { return "", false, parseErr } diff --git a/mediaapi/routing/routing.go b/mediaapi/routing/routing.go index b5998d686..82d397e00 100644 --- a/mediaapi/routing/routing.go +++ b/mediaapi/routing/routing.go @@ -52,6 +52,7 @@ func Setup( db storage.Database, userAPI userapi.MediaUserAPI, client *fclient.Client, + federationClient fclient.FederationClient, keyRing gomatrixserverlib.JSONVerifier, ) { rateLimits := httputil.NewRateLimits(&cfg.ClientAPI.RateLimiting) @@ -95,12 +96,12 @@ func Setup( MXCToResult: map[string]*types.RemoteRequestResult{}, } - downloadHandler := makeDownloadAPI("download_unauthed", &cfg.MediaAPI, rateLimits, db, client, activeRemoteRequests, activeThumbnailGeneration, false) + downloadHandler := makeDownloadAPI("download_unauthed", &cfg.MediaAPI, rateLimits, db, client, federationClient, activeRemoteRequests, activeThumbnailGeneration, false) v3mux.Handle("/download/{serverName}/{mediaId}", downloadHandler).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/download/{serverName}/{mediaId}/{downloadName}", downloadHandler).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/thumbnail/{serverName}/{mediaId}", - makeDownloadAPI("thumbnail_unauthed", &cfg.MediaAPI, rateLimits, db, client, activeRemoteRequests, activeThumbnailGeneration, false), + makeDownloadAPI("thumbnail_unauthed", &cfg.MediaAPI, rateLimits, db, client, federationClient, activeRemoteRequests, activeThumbnailGeneration, false), ).Methods(http.MethodGet, http.MethodOptions) // v1 client endpoints requiring auth @@ -110,15 +111,15 @@ func Setup( v1mux.Handle("/download/{serverName}/{mediaId}/{downloadName}", downloadHandlerAuthed).Methods(http.MethodGet, http.MethodOptions) v1mux.Handle("/thumbnail/{serverName}/{mediaId}", - httputil.MakeHTMLAPI("thumbnail", userAPI, cfg.Global.Metrics.Enabled, makeDownloadAPI("thumbnail_authed_client", &cfg.MediaAPI, rateLimits, db, client, activeRemoteRequests, activeThumbnailGeneration, false), httputil.WithAuth()), + httputil.MakeHTMLAPI("thumbnail", userAPI, cfg.Global.Metrics.Enabled, makeDownloadAPI("thumbnail_authed_client", &cfg.MediaAPI, rateLimits, db, client, federationClient, activeRemoteRequests, activeThumbnailGeneration, false), httputil.WithAuth()), ).Methods(http.MethodGet, http.MethodOptions) // same, but for federation v1fedMux.Handle("/download/{mediaId}", routing.MakeFedAPIHTML(cfg.Global.ServerName, cfg.Global.IsLocalServerName, keyRing, - makeDownloadAPI("download_authed_federation", &cfg.MediaAPI, rateLimits, db, client, activeRemoteRequests, activeThumbnailGeneration, true), + makeDownloadAPI("download_authed_federation", &cfg.MediaAPI, rateLimits, db, client, federationClient, activeRemoteRequests, activeThumbnailGeneration, true), )).Methods(http.MethodGet, http.MethodOptions) v1fedMux.Handle("/thumbnail/{mediaId}", routing.MakeFedAPIHTML(cfg.Global.ServerName, cfg.Global.IsLocalServerName, keyRing, - makeDownloadAPI("thumbnail_authed_federation", &cfg.MediaAPI, rateLimits, db, client, activeRemoteRequests, activeThumbnailGeneration, true), + makeDownloadAPI("thumbnail_authed_federation", &cfg.MediaAPI, rateLimits, db, client, federationClient, activeRemoteRequests, activeThumbnailGeneration, true), )).Methods(http.MethodGet, http.MethodOptions) } @@ -170,6 +171,7 @@ func makeDownloadAPI( rateLimits *httputil.RateLimits, db storage.Database, client *fclient.Client, + fedClient fclient.FederationClient, activeRemoteRequests *types.ActiveRemoteRequests, activeThumbnailGeneration *types.ActiveThumbnailGeneration, forFederation bool, @@ -235,6 +237,7 @@ func makeDownloadAPI( cfg, db, client, + fedClient, activeRemoteRequests, activeThumbnailGeneration, strings.HasPrefix(name, "thumbnail"), diff --git a/setup/monolith.go b/setup/monolith.go index 11cc59e27..72750354b 100644 --- a/setup/monolith.go +++ b/setup/monolith.go @@ -78,7 +78,7 @@ func (m *Monolith) AddAllPublicRoutes( federationapi.AddPublicRoutes( processCtx, routers, cfg, natsInstance, m.UserAPI, m.FedClient, m.KeyRing, m.RoomserverAPI, m.FederationAPI, enableMetrics, ) - mediaapi.AddPublicRoutes(routers, cm, cfg, m.UserAPI, m.Client, m.KeyRing) + mediaapi.AddPublicRoutes(routers, cm, cfg, m.UserAPI, m.Client, m.FedClient, m.KeyRing) syncapi.AddPublicRoutes(processCtx, routers, cfg, cm, natsInstance, m.UserAPI, m.RoomserverAPI, caches, enableMetrics) if m.RelayAPI != nil {