Don't exhaust memory for large files, don't limit more than necessary

Neil Alexander 2021-02-17 12:28:09 +00:00
parent 15e1f4afc9
commit b12f4c2726
3 changed files with 29 additions and 28 deletions
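The idea behind the change, restated: instead of buffering an entire request or response body in memory to measure or cap it, wrap the body in an io.LimitReader and stream it through. A minimal runnable sketch of that pattern (maxFileSizeBytes here is just a local constant standing in for the configured max_file_size_bytes, not Dendrite's config type):

```go
package main

import (
	"fmt"
	"io"
	"strings"
)

func main() {
	const maxFileSizeBytes = 16 // stand-in for the configured max_file_size_bytes

	body := strings.NewReader("a response body much longer than the limit")

	// Rather than buffering the whole body in memory (ioutil.ReadAll) and
	// measuring it, wrap the body in a LimitReader and stream it through:
	// memory use stays constant no matter how large the remote file is.
	limited := io.LimitReader(body, maxFileSizeBytes)

	n, err := io.Copy(io.Discard, limited) // the real code copies into a temp file
	fmt.Println(n, err)                    // 16 <nil>: reading stops at the cap
}
```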

@@ -109,7 +109,7 @@ func RemoveDir(dir types.Path, logger *log.Entry) {
 // WriteTempFile writes to a new temporary file.
 // The file is deleted if there was an error while writing.
 func WriteTempFile(
-	ctx context.Context, reqReader io.Reader, maxFileSizeBytes config.FileSizeBytes, absBasePath config.Path,
+	ctx context.Context, reqReader io.Reader, absBasePath config.Path,
 ) (hash types.Base64Hash, size types.FileSizeBytes, path types.Path, err error) {
 	size = -1
 	logger := util.GetLogger(ctx)
@@ -124,18 +124,11 @@ func WriteTempFile(
 		}
 	}()
 
-	// If the max_file_size_bytes configuration option is set to a positive
-	// number then limit the upload to that size. Otherwise, just read the
-	// whole file.
-	limitedReader := reqReader
-	if maxFileSizeBytes > 0 {
-		limitedReader = io.LimitReader(reqReader, int64(maxFileSizeBytes))
-	}
 	// Hash the file data. The hash will be returned. The hash is useful as a
 	// method of deduplicating files to save storage, as well as a way to conduct
 	// integrity checks on the file data in the repository.
 	hasher := sha256.New()
-	teeReader := io.TeeReader(limitedReader, hasher)
+	teeReader := io.TeeReader(reqReader, hasher)
 	bytesWritten, err := io.Copy(tmpFileWriter, teeReader)
 	if err != nil && err != io.EOF {
 		RemoveDir(tmpDir, logger)
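For reference, the io.TeeReader pattern that WriteTempFile keeps using: every byte copied to the destination is also fed to the SHA-256 hasher, so the file is hashed in the same single pass that writes it out. A self-contained sketch, with a strings.Builder standing in for the buffered temp-file writer and an illustrative base64 encoding for the hash (Dendrite's exact hash encoding may differ):

```go
package main

import (
	"crypto/sha256"
	"encoding/base64"
	"fmt"
	"io"
	"strings"
)

func main() {
	src := strings.NewReader("file data arriving from the request body")

	// Everything io.Copy reads from the tee is also written into the hasher,
	// so hashing costs no second pass over the data.
	hasher := sha256.New()
	tee := io.TeeReader(src, hasher)

	var dst strings.Builder // stands in for the buffered temp-file writer
	bytesWritten, err := io.Copy(&dst, tee)
	if err != nil {
		panic(err)
	}

	fmt.Println(bytesWritten)
	fmt.Println(base64.RawURLEncoding.EncodeToString(hasher.Sum(nil)))
}
```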

@@ -15,7 +15,6 @@
 package routing
 
 import (
-	"bytes"
 	"context"
 	"encoding/json"
 	"fmt"
@@ -675,34 +674,41 @@ func (r *downloadRequest) fetchRemoteFileAndStoreMetadata(
 	return nil
 }
 
-func (r *downloadRequest) GetContentLengthAndReader(contentLengthHeader string, body *io.ReadCloser) (int64, io.Reader, error) {
-	var reader io.Reader
+func (r *downloadRequest) GetContentLengthAndReader(contentLengthHeader string, body *io.ReadCloser, maxFileSizeBytes config.FileSizeBytes) (int64, io.Reader, error) {
+	reader := *body
 	var contentLength int64
 	if contentLengthHeader != "" {
 		// A Content-Length header is provided. Let's try to parse it.
 		parsedLength, parseErr := strconv.ParseInt(contentLengthHeader, 10, 64)
 		if parseErr != nil {
 			r.Logger.WithError(parseErr).Warn("Failed to parse content length")
 			return 0, nil, errors.Wrap(parseErr, "invalid response from remote server")
 		}
-		reader = *body
-		contentLength = parsedLength
-	} else {
-		// Content-Length header is missing, we need to read the whole body to get its length
-		bodyBytes, readAllErr := ioutil.ReadAll(*body)
-		if readAllErr != nil {
-			r.Logger.WithError(readAllErr).Warn("Could not read response body from remote server")
-			return 0, nil, errors.Wrap(readAllErr, "invalid response from remote server")
+		if parsedLength > int64(maxFileSizeBytes) {
+			return 0, nil, fmt.Errorf(
+				"remote file size (%d bytes) exceeds locally configured max media size (%d bytes)",
+				parsedLength, maxFileSizeBytes,
+			)
 		}
-		contentLength = int64(len(bodyBytes))
-		reader = ioutil.NopCloser(bytes.NewBuffer(bodyBytes))
+		// We successfully parsed the Content-Length, so we'll return a limited
+		// reader that restricts us to reading only up to this size.
+		reader = ioutil.NopCloser(io.LimitReader(*body, parsedLength))
+		contentLength = parsedLength
+	} else {
+		// Content-Length header is missing. If we have a maximum file size
+		// configured then we'll just make sure that the reader is limited to
+		// that size. We'll return a zero content length, but that's OK, since
+		// ultimately it will get rewritten later when the temp file is written
+		// to disk.
+		if maxFileSizeBytes > 0 {
+			reader = ioutil.NopCloser(io.LimitReader(*body, int64(maxFileSizeBytes)))
+		}
+		contentLength = 0
 	}
 	return contentLength, reader, nil
 }
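The behaviour of the rewritten GetContentLengthAndReader, restated as a standalone sketch with hypothetical names and plain int64 in place of config.FileSizeBytes:

```go
package main

import (
	"fmt"
	"io"
	"strconv"
	"strings"
)

// contentLengthReader mirrors the two branches above with plain types:
// reject a declared size over the cap, otherwise hand back a reader that
// cannot be read past the smaller of Content-Length and the configured max.
func contentLengthReader(header string, body io.Reader, max int64) (int64, io.Reader, error) {
	if header == "" {
		// No Content-Length: limit to the configured maximum and report 0;
		// the true size is discovered when the temp file is written to disk.
		return 0, io.LimitReader(body, max), nil
	}
	parsed, err := strconv.ParseInt(header, 10, 64)
	if err != nil {
		return 0, nil, err
	}
	if parsed > max {
		return 0, nil, fmt.Errorf("declared size %d exceeds max %d", parsed, max)
	}
	return parsed, io.LimitReader(body, parsed), nil
}

func main() {
	length, r, err := contentLengthReader("4", strings.NewReader("0123456789"), 8)
	fmt.Println(length, err) // 4 <nil>
	data, _ := io.ReadAll(r)
	fmt.Printf("%q\n", data) // "0123": reading stops at the declared length
}
```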
@@ -724,7 +730,9 @@ func (r *downloadRequest) fetchRemoteFile(
 	}
 	defer resp.Body.Close() // nolint: errcheck
 
-	contentLength, reader, parseErr := r.GetContentLengthAndReader(resp.Header.Get("Content-Length"), &resp.Body)
+	// The reader returned here will be limited either by the Content-Length
+	// and/or the configured maximum media size.
+	contentLength, reader, parseErr := r.GetContentLengthAndReader(resp.Header.Get("Content-Length"), &resp.Body, maxFileSizeBytes)
 	if parseErr != nil {
 		return "", false, parseErr
 	}
@@ -760,7 +768,7 @@ func (r *downloadRequest) fetchRemoteFile(
 	// method of deduplicating files to save storage, as well as a way to conduct
 	// integrity checks on the file data in the repository.
 	// Data is truncated to maxFileSizeBytes. Content-Length was reported as 0 < Content-Length <= maxFileSizeBytes so this is OK.
-	hash, bytesWritten, tmpDir, err := fileutils.WriteTempFile(ctx, reader, maxFileSizeBytes, absBasePath)
+	hash, bytesWritten, tmpDir, err := fileutils.WriteTempFile(ctx, reader, absBasePath)
 	if err != nil {
 		r.Logger.WithError(err).WithFields(log.Fields{
 			"MaxFileSizeBytes": maxFileSizeBytes,

@@ -147,7 +147,7 @@ func (r *uploadRequest) doUpload(
 	// r.storeFileAndMetadata(ctx, tmpDir, ...)
 	// before you return from doUpload else we will leak a temp file. We could make this nicer with a `WithTransaction` style of
 	// nested function to guarantee either storage or cleanup.
-	hash, bytesWritten, tmpDir, err := fileutils.WriteTempFile(ctx, reqReader, *cfg.MaxFileSizeBytes, cfg.AbsBasePath)
+	hash, bytesWritten, tmpDir, err := fileutils.WriteTempFile(ctx, reqReader, cfg.AbsBasePath)
 	if err != nil {
 		r.Logger.WithError(err).WithFields(log.Fields{
 			"MaxFileSizeBytes": *cfg.MaxFileSizeBytes,