diff --git a/mediaapi/routing/download.go b/mediaapi/routing/download.go index 648e3f1e9..d74229356 100644 --- a/mediaapi/routing/download.go +++ b/mediaapi/routing/download.go @@ -15,12 +15,12 @@ package routing import ( + "bytes" "context" "encoding/json" "fmt" "io" "io/ioutil" - "bytes" "mime" "net/http" "net/url" @@ -675,6 +675,36 @@ func (r *downloadRequest) fetchRemoteFileAndStoreMetadata( return nil } +func (r *downloadRequest) GetContentLengthAndReader(contentLengthHeader string, body *io.ReadCloser) (int64, io.Reader, error) { + var reader io.Reader + var contentLength int64 + + if contentLengthHeader != "" { + parsedLength, parseErr := strconv.ParseInt(contentLengthHeader, 10, 64) + + if parseErr != nil { + r.Logger.WithError(parseErr).Warn("Failed to parse content length") + return 0, nil, errors.Wrap(parseErr, "invalid response from remote server") + } + + reader = *body + contentLength = parsedLength + } else { + // Content-Length header is missing, we need to read the whole body to get its length + bodyBytes, readAllErr := ioutil.ReadAll(*body) + if readAllErr != nil { + r.Logger.WithError(readAllErr).Warn("Could not read response body from remote server") + return 0, nil, errors.Wrap(readAllErr, "invalid response from remote server") + } + + contentLength = int64(len(bodyBytes)) + reader = ioutil.NopCloser(bytes.NewBuffer(bodyBytes)) + } + + return contentLength, reader, nil + +} + func (r *downloadRequest) fetchRemoteFile( ctx context.Context, client *gomatrixserverlib.Client, @@ -694,31 +724,11 @@ func (r *downloadRequest) fetchRemoteFile( } defer resp.Body.Close() // nolint: errcheck - var reader io.Reader - var contentLength int64 - - if resp.Header.Get("Content-Length") != "" { - parsedLength, err := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64) - - if err != nil { - r.Logger.WithError(err).Warn("Failed to parse content length") - return "", false, errors.Wrap(err, "invalid response from remote server") - } - - reader = resp.Body - contentLength = parsedLength - } else { - // Content-Length header is missing, we need to read the whole body to get its length - bodyBytes, err := ioutil.ReadAll(resp.Body) - if err != nil { - r.Logger.WithError(err).Warn("Could not read response body from remote server") - return "", false, fmt.Errorf("file could not be downloaded from remote server") - } - - contentLength = int64(len(bodyBytes)) - reader = ioutil.NopCloser(bytes.NewBuffer(bodyBytes)) + contentLength, reader, parseErr := r.GetContentLengthAndReader(resp.Header.Get("Content-Length"), &resp.Body) + if parseErr != nil { + return "", false, parseErr } - + if contentLength > int64(maxFileSizeBytes) { // TODO: Bubble up this as a 413 return "", false, fmt.Errorf("remote file is too large (%v > %v bytes)", contentLength, maxFileSizeBytes)