diff --git a/mediaapi/fileutils/fileutils.go b/mediaapi/fileutils/fileutils.go
index df19eee4a..7309cb882 100644
--- a/mediaapi/fileutils/fileutils.go
+++ b/mediaapi/fileutils/fileutils.go
@@ -109,7 +109,7 @@ func RemoveDir(dir types.Path, logger *log.Entry) {
 // WriteTempFile writes to a new temporary file.
 // The file is deleted if there was an error while writing.
 func WriteTempFile(
-	ctx context.Context, reqReader io.Reader, maxFileSizeBytes config.FileSizeBytes, absBasePath config.Path,
+	ctx context.Context, reqReader io.Reader, absBasePath config.Path,
 ) (hash types.Base64Hash, size types.FileSizeBytes, path types.Path, err error) {
 	size = -1
 	logger := util.GetLogger(ctx)
@@ -124,18 +124,11 @@ func WriteTempFile(
 		}
 	}()
 
-	// If the max_file_size_bytes configuration option is set to a positive
-	// number then limit the upload to that size. Otherwise, just read the
-	// whole file.
-	limitedReader := reqReader
-	if maxFileSizeBytes > 0 {
-		limitedReader = io.LimitReader(reqReader, int64(maxFileSizeBytes))
-	}
 	// Hash the file data. The hash will be returned. The hash is useful as a
 	// method of deduplicating files to save storage, as well as a way to conduct
 	// integrity checks on the file data in the repository.
 	hasher := sha256.New()
-	teeReader := io.TeeReader(limitedReader, hasher)
+	teeReader := io.TeeReader(reqReader, hasher)
 	bytesWritten, err := io.Copy(tmpFileWriter, teeReader)
 	if err != nil && err != io.EOF {
 		RemoveDir(tmpDir, logger)
diff --git a/mediaapi/routing/download.go b/mediaapi/routing/download.go
index 19a04b3c7..828e0b71f 100644
--- a/mediaapi/routing/download.go
+++ b/mediaapi/routing/download.go
@@ -15,7 +15,6 @@
 package routing
 
 import (
-	"bytes"
 	"context"
 	"encoding/json"
 	"fmt"
@@ -675,34 +674,41 @@ func (r *downloadRequest) fetchRemoteFileAndStoreMetadata(
 	return nil
 }
 
-func (r *downloadRequest) GetContentLengthAndReader(contentLengthHeader string, body *io.ReadCloser) (int64, io.Reader, error) {
-	var reader io.Reader
+func (r *downloadRequest) GetContentLengthAndReader(contentLengthHeader string, body *io.ReadCloser, maxFileSizeBytes config.FileSizeBytes) (int64, io.Reader, error) {
+	reader := *body
 	var contentLength int64
 
 	if contentLengthHeader != "" {
+		// A Content-Length header is provided. Let's try to parse it.
 		parsedLength, parseErr := strconv.ParseInt(contentLengthHeader, 10, 64)
-
 		if parseErr != nil {
 			r.Logger.WithError(parseErr).Warn("Failed to parse content length")
 			return 0, nil, errors.Wrap(parseErr, "invalid response from remote server")
 		}
-
-		reader = *body
-		contentLength = parsedLength
-	} else {
-		// Content-Length header is missing, we need to read the whole body to get its length
-		bodyBytes, readAllErr := ioutil.ReadAll(*body)
-		if readAllErr != nil {
-			r.Logger.WithError(readAllErr).Warn("Could not read response body from remote server")
-			return 0, nil, errors.Wrap(readAllErr, "invalid response from remote server")
+		if parsedLength > int64(maxFileSizeBytes) {
+			return 0, nil, fmt.Errorf(
+				"remote file size (%d bytes) exceeds locally configured max media size (%d bytes)",
+				parsedLength, maxFileSizeBytes,
+			)
 		}
-		contentLength = int64(len(bodyBytes))
-		reader = ioutil.NopCloser(bytes.NewBuffer(bodyBytes))
+
+		// We successfully parsed the Content-Length, so we'll return a limited
+		// reader that restricts us to reading only up to this size.
+		reader = ioutil.NopCloser(io.LimitReader(*body, parsedLength))
+		contentLength = parsedLength
+	} else {
+		// Content-Length header is missing. If we have a maximum file size
+		// configured then we'll just make sure that the reader is limited to
+		// that size. We'll return a zero content length, but that's OK, since
+		// ultimately it will get rewritten later when the temp file is written
+		// to disk.
+		if maxFileSizeBytes > 0 {
+			reader = ioutil.NopCloser(io.LimitReader(*body, int64(maxFileSizeBytes)))
+		}
+		contentLength = 0
 	}
 
 	return contentLength, reader, nil
-
 }
 
 func (r *downloadRequest) fetchRemoteFile(
@@ -724,7 +730,9 @@
 	}
 	defer resp.Body.Close() // nolint: errcheck
 
-	contentLength, reader, parseErr := r.GetContentLengthAndReader(resp.Header.Get("Content-Length"), &resp.Body)
+	// The reader returned here will be limited either by the Content-Length
+	// and/or the configured maximum media size.
+	contentLength, reader, parseErr := r.GetContentLengthAndReader(resp.Header.Get("Content-Length"), &resp.Body, maxFileSizeBytes)
 	if parseErr != nil {
 		return "", false, parseErr
 	}
@@ -760,7 +768,7 @@
 	// method of deduplicating files to save storage, as well as a way to conduct
 	// integrity checks on the file data in the repository.
 	// Data is truncated to maxFileSizeBytes. Content-Length was reported as 0 < Content-Length <= maxFileSizeBytes so this is OK.
-	hash, bytesWritten, tmpDir, err := fileutils.WriteTempFile(ctx, reader, maxFileSizeBytes, absBasePath)
+	hash, bytesWritten, tmpDir, err := fileutils.WriteTempFile(ctx, reader, absBasePath)
 	if err != nil {
 		r.Logger.WithError(err).WithFields(log.Fields{
 			"MaxFileSizeBytes": maxFileSizeBytes,
diff --git a/mediaapi/routing/upload.go b/mediaapi/routing/upload.go
index 1dcf4e17b..2c5753745 100644
--- a/mediaapi/routing/upload.go
+++ b/mediaapi/routing/upload.go
@@ -147,7 +147,7 @@ func (r *uploadRequest) doUpload(
 	// r.storeFileAndMetadata(ctx, tmpDir, ...)
 	// before you return from doUpload else we will leak a temp file. We could make this nicer with a `WithTransaction` style of
 	// nested function to guarantee either storage or cleanup.
-	hash, bytesWritten, tmpDir, err := fileutils.WriteTempFile(ctx, reqReader, *cfg.MaxFileSizeBytes, cfg.AbsBasePath)
+	hash, bytesWritten, tmpDir, err := fileutils.WriteTempFile(ctx, reqReader, cfg.AbsBasePath)
 	if err != nil {
 		r.Logger.WithError(err).WithFields(log.Fields{
 			"MaxFileSizeBytes": *cfg.MaxFileSizeBytes,
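
For reference, the pattern this change converges on is plain io.Reader composition: the caller wraps the body in io.LimitReader (to the parsed Content-Length or to the configured maximum) before handing it to WriteTempFile, which hashes exactly the bytes it copies by teeing the stream through SHA-256. The standalone sketch below illustrates that composition only; the writeLimited helper, the example limit, and the base64 encoding choice are assumptions for illustration, not code from this patch.

// Sketch only: io.LimitReader feeding an io.TeeReader plus a SHA-256 hasher,
// the same reader composition WriteTempFile and GetContentLengthAndReader use.
package main

import (
	"crypto/sha256"
	"encoding/base64"
	"fmt"
	"io"
	"os"
	"strings"
)

// writeLimited copies at most maxBytes from r into w, hashing exactly the
// bytes that are written. A maxBytes of 0 or less means "no limit".
func writeLimited(w io.Writer, r io.Reader, maxBytes int64) (string, int64, error) {
	if maxBytes > 0 {
		r = io.LimitReader(r, maxBytes)
	}
	hasher := sha256.New()
	n, err := io.Copy(w, io.TeeReader(r, hasher))
	if err != nil {
		return "", n, err
	}
	return base64.RawURLEncoding.EncodeToString(hasher.Sum(nil)), n, nil
}

func main() {
	body := strings.NewReader("example media payload")
	// Limit to 10 bytes: only the first 10 bytes are written and hashed.
	hash, n, err := writeLimited(os.Stdout, body, 10)
	if err != nil {
		fmt.Fprintln(os.Stderr, err)
		return
	}
	fmt.Printf("\nwrote %d bytes, hash %s\n", n, hash)
}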