diff --git a/mediaapi/routing/download.go b/mediaapi/routing/download.go index ada9098f8..c812b9d65 100644 --- a/mediaapi/routing/download.go +++ b/mediaapi/routing/download.go @@ -389,47 +389,54 @@ func (r *downloadRequest) respondFromLocalFile( return nil, fmt.Errorf("io.Copy: %w", err) } } else { - // Update the header to be multipart/mixed; boundary=$randomBoundary - boundary := uuid.NewString() - w.Header().Set("Content-Type", "multipart/mixed; boundary="+boundary) - - w.Header().Del("Content-Length") // let Go handle the content length - mw := multipart.NewWriter(w) - defer func() { - if err = mw.Close(); err != nil { - r.Logger.WithError(err).Error("Failed to close multipart writer") - } - }() - - if err = mw.SetBoundary(boundary); err != nil { - return nil, fmt.Errorf("failed to set multipart boundary: %w", err) - } - - // JSON object part - jsonWriter, err := mw.CreatePart(textproto.MIMEHeader{ - "Content-Type": {"application/json"}, - }) + var written int64 + written, err = multipartResponse(w, r, string(responseMetadata.ContentType), responseFile) if err != nil { - return nil, fmt.Errorf("failed to create json writer: %w", err) - } - if _, err = jsonWriter.Write([]byte("{}")); err != nil { - return nil, fmt.Errorf("failed to write to json writer: %w", err) - } - - // media part - mediaWriter, err := mw.CreatePart(textproto.MIMEHeader{ - "Content-Type": {string(responseMetadata.ContentType)}, - }) - if err != nil { - return nil, fmt.Errorf("failed to create media writer: %w", err) - } - if _, err = io.Copy(mediaWriter, responseFile); err != nil { - return nil, fmt.Errorf("failed to write to media writer: %w", err) + return nil, err } + responseMetadata.FileSizeBytes = types.FileSizeBytes(written) } return responseMetadata, nil } +func multipartResponse(w http.ResponseWriter, r *downloadRequest, contentType string, responseFile io.Reader) (int64, error) { + // Update the header to be multipart/mixed; boundary=$randomBoundary + boundary := uuid.NewString() + w.Header().Set("Content-Type", "multipart/mixed; boundary="+boundary) + + w.Header().Del("Content-Length") // let Go handle the content length + mw := multipart.NewWriter(w) + defer func() { + if err := mw.Close(); err != nil { + r.Logger.WithError(err).Error("Failed to close multipart writer") + } + }() + + if err := mw.SetBoundary(boundary); err != nil { + return 0, fmt.Errorf("failed to set multipart boundary: %w", err) + } + + // JSON object part + jsonWriter, err := mw.CreatePart(textproto.MIMEHeader{ + "Content-Type": {"application/json"}, + }) + if err != nil { + return 0, fmt.Errorf("failed to create json writer: %w", err) + } + if _, err = jsonWriter.Write([]byte("{}")); err != nil { + return 0, fmt.Errorf("failed to write to json writer: %w", err) + } + + // media part + mediaWriter, err := mw.CreatePart(textproto.MIMEHeader{ + "Content-Type": {contentType}, + }) + if err != nil { + return 0, fmt.Errorf("failed to create media writer: %w", err) + } + return io.Copy(mediaWriter, responseFile) +} + func (r *downloadRequest) addDownloadFilenameToHeaders( w http.ResponseWriter, responseMetadata *types.MediaMetadata, @@ -851,47 +858,7 @@ func (r *downloadRequest) fetchRemoteFile( var reader io.Reader var parseErr error if isAuthed { - r.Logger.Debug("Downloaded file using authenticated endpoint") - var params map[string]string - _, params, err = mime.ParseMediaType(resp.Header.Get("Content-Type")) - if err != nil { - return "", false, err - } - if params["boundary"] == "" { - return "", false, fmt.Errorf("no boundary header found on media %s from %s", r.MediaMetadata.MediaID, r.MediaMetadata.Origin) - } - mr := multipart.NewReader(resp.Body, params["boundary"]) - - // Get the first, JSON, part - p, multipartErr := mr.NextPart() - if multipartErr != nil { - return "", false, multipartErr - } - - if p.Header.Get("Content-Type") != "application/json" { - return "", false, fmt.Errorf("first part of the response must be application/json") - } - // Try to parse media meta information - meta := mediaMeta{} - if err = json.NewDecoder(p).Decode(&meta); err != nil { - return "", false, err - } - defer p.Close() // nolint: errcheck - - // Get the actual media content - p, multipartErr = mr.NextPart() - if multipartErr != nil { - return "", false, multipartErr - } - - redirect := p.Header.Get("Location") - if redirect != "" { - return "", false, fmt.Errorf("Location header is not yet supported") - } else { - contentLength, reader, parseErr = r.GetContentLengthAndReader(p.Header.Get("Content-Length"), p, maxFileSizeBytes) - // For multipart requests, we need to get the Content-Type of the second part, which is the actual media - r.MediaMetadata.ContentType = types.ContentType(p.Header.Get("Content-Type")) - } + parseErr, contentLength, reader = parseMultipartResponse(r, resp, maxFileSizeBytes) } else { // The reader returned here will be limited either by the Content-Length // and/or the configured maximum media size. @@ -961,6 +928,50 @@ func (r *downloadRequest) fetchRemoteFile( return types.Path(finalPath), duplicate, nil } +func parseMultipartResponse(r *downloadRequest, resp *http.Response, maxFileSizeBytes config.FileSizeBytes) (error, int64, io.Reader) { + _, params, err := mime.ParseMediaType(resp.Header.Get("Content-Type")) + if err != nil { + return err, 0, nil + } + if params["boundary"] == "" { + return fmt.Errorf("no boundary header found on media %s from %s", r.MediaMetadata.MediaID, r.MediaMetadata.Origin), 0, nil + } + mr := multipart.NewReader(resp.Body, params["boundary"]) + + // Get the first, JSON, part + p, err := mr.NextPart() + if err != nil { + return err, 0, nil + } + defer p.Close() // nolint: errcheck + + if p.Header.Get("Content-Type") != "application/json" { + return fmt.Errorf("first part of the response must be application/json"), 0, nil + } + // Try to parse media meta information + meta := mediaMeta{} + if err = json.NewDecoder(p).Decode(&meta); err != nil { + return err, 0, nil + } + defer p.Close() // nolint: errcheck + + // Get the actual media content + p, err = mr.NextPart() + if err != nil { + return err, 0, nil + } + + redirect := p.Header.Get("Location") + if redirect != "" { + return fmt.Errorf("Location header is not yet supported"), 0, nil + } + + contentLength, reader, err := r.GetContentLengthAndReader(p.Header.Get("Content-Length"), p, maxFileSizeBytes) + // For multipart requests, we need to get the Content-Type of the second part, which is the actual media + r.MediaMetadata.ContentType = types.ContentType(p.Header.Get("Content-Type")) + return err, contentLength, reader +} + // contentDispositionFor returns the Content-Disposition for a given // content type. func contentDispositionFor(contentType types.ContentType) string { diff --git a/mediaapi/routing/download_test.go b/mediaapi/routing/download_test.go index 21f6bfc2c..11368919a 100644 --- a/mediaapi/routing/download_test.go +++ b/mediaapi/routing/download_test.go @@ -1,8 +1,13 @@ package routing import ( + "bytes" + "io" + "net/http" + "net/http/httptest" "testing" + "github.com/matrix-org/dendrite/mediaapi/types" "github.com/stretchr/testify/assert" ) @@ -11,3 +16,28 @@ func Test_dispositionFor(t *testing.T) { assert.Equal(t, "attachment", contentDispositionFor("image/svg"), "image/svg") assert.Equal(t, "inline", contentDispositionFor("image/jpeg"), "image/jpg") } + +func Test_Multipart(t *testing.T) { + r := &downloadRequest{ + MediaMetadata: &types.MediaMetadata{}, + } + data := bytes.Buffer{} + responseBody := "This media is plain text. Maybe somebody used it as a paste bin." + data.WriteString(responseBody) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, err := multipartResponse(w, r, "text/plain", &data) + assert.NoError(t, err) + })) + defer srv.Close() + + resp, err := srv.Client().Get(srv.URL) + assert.NoError(t, err) + defer resp.Body.Close() + // contentLength is always 0, since there's no Content-Length header on the multipart part. + err, _, reader := parseMultipartResponse(r, resp, 1000) + assert.NoError(t, err) + gotResponse, err := io.ReadAll(reader) + assert.NoError(t, err) + assert.Equal(t, responseBody, string(gotResponse)) +} diff --git a/mediaapi/routing/routing.go b/mediaapi/routing/routing.go index 20417d1cc..d567d9abd 100644 --- a/mediaapi/routing/routing.go +++ b/mediaapi/routing/routing.go @@ -138,7 +138,7 @@ var thumbnailSize = promauto.NewHistogramVec( Namespace: "dendrite", Subsystem: "mediaapi", Name: "thumbnail_size_bytes", - Help: "Total number of media_api requests for thumbnails", + Help: "Total size of media_api requests for thumbnails", Buckets: []float64{50, 100, 200, 500, 900, 1500, 3000, 6000}, }, []string{"code", "type"}, @@ -149,7 +149,7 @@ var downloadCounter = promauto.NewCounterVec( Namespace: "dendrite", Subsystem: "mediaapi", Name: "download", - Help: "Total number of media_api requests for full downloads", + Help: "Total size of media_api requests for full downloads", }, []string{"code", "type"}, ) @@ -159,8 +159,8 @@ var downloadSize = promauto.NewHistogramVec( Namespace: "dendrite", Subsystem: "mediaapi", Name: "download_size_bytes", - Help: "Total number of media_api requests for full downloads", - Buckets: []float64{200, 500, 900, 1500, 3000, 6000, 10_000, 50_000, 100_000}, + Help: "Total size of media_api requests for full downloads", + Buckets: []float64{1500, 3000, 6000, 10_000, 50_000, 100_000}, }, []string{"code", "type"}, ) @@ -181,7 +181,10 @@ func makeDownloadAPI( var requestType string if cfg.Matrix.Metrics.Enabled { split := strings.Split(name, "_") + // The first part of the split is either "download" or "thumbnail" name = split[0] + // The remainder of the split is something like "authed_download" or "unauthed_thumbnail", etc. + // This is used to curry the metrics with the given types. requestType = strings.Join(split[1:], "_") counterVec = thumbnailCounter