diff --git a/mediaapi/routing/download.go b/mediaapi/routing/download.go index c812b9d65..c3ac3cdc7 100644 --- a/mediaapi/routing/download.go +++ b/mediaapi/routing/download.go @@ -33,7 +33,6 @@ import ( "sync" "unicode" - "github.com/google/uuid" "github.com/matrix-org/dendrite/mediaapi/fileutils" "github.com/matrix-org/dendrite/mediaapi/storage" "github.com/matrix-org/dendrite/mediaapi/thumbnailer" @@ -400,22 +399,16 @@ func (r *downloadRequest) respondFromLocalFile( } func multipartResponse(w http.ResponseWriter, r *downloadRequest, contentType string, responseFile io.Reader) (int64, error) { - // Update the header to be multipart/mixed; boundary=$randomBoundary - boundary := uuid.NewString() - w.Header().Set("Content-Type", "multipart/mixed; boundary="+boundary) - - w.Header().Del("Content-Length") // let Go handle the content length mw := multipart.NewWriter(w) + // Update the header to be multipart/mixed; boundary=$randomBoundary + w.Header().Set("Content-Type", "multipart/mixed; boundary="+mw.Boundary()) + w.Header().Del("Content-Length") // let Go handle the content length defer func() { if err := mw.Close(); err != nil { r.Logger.WithError(err).Error("Failed to close multipart writer") } }() - if err := mw.SetBoundary(boundary); err != nil { - return 0, fmt.Errorf("failed to set multipart boundary: %w", err) - } - // JSON object part jsonWriter, err := mw.CreatePart(textproto.MIMEHeader{ "Content-Type": {"application/json"}, @@ -858,7 +851,7 @@ func (r *downloadRequest) fetchRemoteFile( var reader io.Reader var parseErr error if isAuthed { - parseErr, contentLength, reader = parseMultipartResponse(r, resp, maxFileSizeBytes) + contentLength, reader, parseErr = parseMultipartResponse(r, resp, maxFileSizeBytes) } else { // The reader returned here will be limited either by the Content-Length // and/or the configured maximum media size. @@ -928,48 +921,48 @@ func (r *downloadRequest) fetchRemoteFile( return types.Path(finalPath), duplicate, nil } -func parseMultipartResponse(r *downloadRequest, resp *http.Response, maxFileSizeBytes config.FileSizeBytes) (error, int64, io.Reader) { +func parseMultipartResponse(r *downloadRequest, resp *http.Response, maxFileSizeBytes config.FileSizeBytes) (int64, io.Reader, error) { _, params, err := mime.ParseMediaType(resp.Header.Get("Content-Type")) if err != nil { - return err, 0, nil + return 0, nil, err } if params["boundary"] == "" { - return fmt.Errorf("no boundary header found on media %s from %s", r.MediaMetadata.MediaID, r.MediaMetadata.Origin), 0, nil + return 0, nil, fmt.Errorf("no boundary header found on media %s from %s", r.MediaMetadata.MediaID, r.MediaMetadata.Origin) } mr := multipart.NewReader(resp.Body, params["boundary"]) // Get the first, JSON, part p, err := mr.NextPart() if err != nil { - return err, 0, nil + return 0, nil, err } defer p.Close() // nolint: errcheck if p.Header.Get("Content-Type") != "application/json" { - return fmt.Errorf("first part of the response must be application/json"), 0, nil + return 0, nil, fmt.Errorf("first part of the response must be application/json") } // Try to parse media meta information meta := mediaMeta{} if err = json.NewDecoder(p).Decode(&meta); err != nil { - return err, 0, nil + return 0, nil, err } defer p.Close() // nolint: errcheck // Get the actual media content p, err = mr.NextPart() if err != nil { - return err, 0, nil + return 0, nil, err } redirect := p.Header.Get("Location") if redirect != "" { - return fmt.Errorf("Location header is not yet supported"), 0, nil + return 0, nil, fmt.Errorf("Location header is not yet supported") } contentLength, reader, err := r.GetContentLengthAndReader(p.Header.Get("Content-Length"), p, maxFileSizeBytes) // For multipart requests, we need to get the Content-Type of the second part, which is the actual media r.MediaMetadata.ContentType = types.ContentType(p.Header.Get("Content-Type")) - return err, contentLength, reader + return contentLength, reader, err } // contentDispositionFor returns the Content-Disposition for a given diff --git a/mediaapi/routing/download_test.go b/mediaapi/routing/download_test.go index 11368919a..9654b7474 100644 --- a/mediaapi/routing/download_test.go +++ b/mediaapi/routing/download_test.go @@ -35,7 +35,7 @@ func Test_Multipart(t *testing.T) { assert.NoError(t, err) defer resp.Body.Close() // contentLength is always 0, since there's no Content-Length header on the multipart part. - err, _, reader := parseMultipartResponse(r, resp, 1000) + _, reader, err := parseMultipartResponse(r, resp, 1000) assert.NoError(t, err) gotResponse, err := io.ReadAll(reader) assert.NoError(t, err)