Don't exhaust memory for large files, don't limit more than necessary

Neil Alexander 2021-02-17 12:28:09 +00:00
parent 15e1f4afc9
commit b12f4c2726
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
3 changed files with 29 additions and 28 deletions
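The change below replaces the old approach, in which a missing Content-Length header caused the whole remote file to be read into memory with ioutil.ReadAll and WriteTempFile applied its own io.LimitReader, with a streaming one: the body is wrapped in a limited reader once, at the HTTP layer, and copied straight to a temporary file. The following is a minimal, self-contained sketch of that pattern; the names (fetchToDisk, maxBytes) are illustrative only and are not Dendrite identifiers.

package main

import (
    "fmt"
    "io"
    "io/ioutil"
    "net/http"
    "os"
)

// fetchToDisk streams an HTTP response body to a temporary file, reading at
// most maxBytes, so memory usage stays flat no matter how large the remote
// file is. Illustrative sketch only.
func fetchToDisk(url string, maxBytes int64) (string, int64, error) {
    resp, err := http.Get(url)
    if err != nil {
        return "", 0, err
    }
    defer resp.Body.Close() // nolint: errcheck

    tmpFile, err := ioutil.TempFile("", "media")
    if err != nil {
        return "", 0, err
    }
    defer tmpFile.Close() // nolint: errcheck

    // io.LimitReader caps how much we will ever read; io.Copy streams in
    // small chunks rather than buffering the whole body.
    written, err := io.Copy(tmpFile, io.LimitReader(resp.Body, maxBytes))
    if err != nil {
        os.Remove(tmpFile.Name()) // nolint: errcheck
        return "", 0, err
    }
    return tmpFile.Name(), written, nil
}

func main() {
    path, n, err := fetchToDisk("https://example.org/", 10*1024*1024)
    fmt.Println(path, n, err)
}

Compared with buffering, the only thing lost is knowing the exact size up front, which is why the download path below is happy to report a zero content length when the header is missing.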

Changed file 1 of 3

@@ -109,7 +109,7 @@ func RemoveDir(dir types.Path, logger *log.Entry) {
 // WriteTempFile writes to a new temporary file.
 // The file is deleted if there was an error while writing.
 func WriteTempFile(
-    ctx context.Context, reqReader io.Reader, maxFileSizeBytes config.FileSizeBytes, absBasePath config.Path,
+    ctx context.Context, reqReader io.Reader, absBasePath config.Path,
 ) (hash types.Base64Hash, size types.FileSizeBytes, path types.Path, err error) {
     size = -1
     logger := util.GetLogger(ctx)
@@ -124,18 +124,11 @@ func WriteTempFile(
         }
     }()
-    // If the max_file_size_bytes configuration option is set to a positive
-    // number then limit the upload to that size. Otherwise, just read the
-    // whole file.
-    limitedReader := reqReader
-    if maxFileSizeBytes > 0 {
-        limitedReader = io.LimitReader(reqReader, int64(maxFileSizeBytes))
-    }
     // Hash the file data. The hash will be returned. The hash is useful as a
     // method of deduplicating files to save storage, as well as a way to conduct
     // integrity checks on the file data in the repository.
     hasher := sha256.New()
-    teeReader := io.TeeReader(limitedReader, hasher)
+    teeReader := io.TeeReader(reqReader, hasher)
     bytesWritten, err := io.Copy(tmpFileWriter, teeReader)
     if err != nil && err != io.EOF {
         RemoveDir(tmpDir, logger)
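With the maxFileSizeBytes parameter gone, WriteTempFile simply hashes and writes whatever the supplied reader yields, so any size cap now has to be applied by the caller before the reader is passed in. Below is a minimal, self-contained sketch (not Dendrite code, names are illustrative) of the io.TeeReader pattern the function keeps using: each byte copied to the temporary file is also fed to the SHA-256 hasher, so the content is hashed while it streams to disk. The base64 encoding of the digest is likewise illustrative.

package main

import (
    "crypto/sha256"
    "encoding/base64"
    "fmt"
    "io"
    "io/ioutil"
    "os"
    "strings"
)

// writeAndHash copies src to dst while computing its SHA-256 digest in the
// same pass. Sketch only.
func writeAndHash(src io.Reader, dst io.Writer) (string, int64, error) {
    hasher := sha256.New()
    // Everything read through the tee reader is also written into hasher.
    n, err := io.Copy(dst, io.TeeReader(src, hasher))
    if err != nil {
        return "", 0, err
    }
    return base64.RawURLEncoding.EncodeToString(hasher.Sum(nil)), n, nil
}

func main() {
    tmpFile, err := ioutil.TempFile("", "content")
    if err != nil {
        panic(err)
    }
    defer os.Remove(tmpFile.Name()) // nolint: errcheck
    defer tmpFile.Close()           // nolint: errcheck

    // A caller that still wants a size cap can wrap the source itself,
    // e.g. io.LimitReader(src, maxBytes), before passing it in.
    hash, size, err := writeAndHash(strings.NewReader("example payload"), tmpFile)
    fmt.Println(hash, size, err)
}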

Changed file 2 of 3

@@ -15,7 +15,6 @@
 package routing
 
 import (
-    "bytes"
     "context"
     "encoding/json"
     "fmt"
@@ -675,34 +674,41 @@ func (r *downloadRequest) fetchRemoteFileAndStoreMetadata(
     return nil
 }
 
-func (r *downloadRequest) GetContentLengthAndReader(contentLengthHeader string, body *io.ReadCloser) (int64, io.Reader, error) {
-    var reader io.Reader
+func (r *downloadRequest) GetContentLengthAndReader(contentLengthHeader string, body *io.ReadCloser, maxFileSizeBytes config.FileSizeBytes) (int64, io.Reader, error) {
+    reader := *body
     var contentLength int64
     if contentLengthHeader != "" {
-        // A Content-Length header is provided. Let's try to parse it.
         parsedLength, parseErr := strconv.ParseInt(contentLengthHeader, 10, 64)
         if parseErr != nil {
             r.Logger.WithError(parseErr).Warn("Failed to parse content length")
             return 0, nil, errors.Wrap(parseErr, "invalid response from remote server")
         }
-        reader = *body
-        contentLength = parsedLength
-    } else {
-        // Content-Length header is missing, we need to read the whole body to get its length
-        bodyBytes, readAllErr := ioutil.ReadAll(*body)
-        if readAllErr != nil {
-            r.Logger.WithError(readAllErr).Warn("Could not read response body from remote server")
-            return 0, nil, errors.Wrap(readAllErr, "invalid response from remote server")
-        }
-        contentLength = int64(len(bodyBytes))
-        reader = ioutil.NopCloser(bytes.NewBuffer(bodyBytes))
+        if parsedLength > int64(maxFileSizeBytes) {
+            return 0, nil, fmt.Errorf(
+                "remote file size (%d bytes) exceeds locally configured max media size (%d bytes)",
+                parsedLength, maxFileSizeBytes,
+            )
+        }
+
+        // We successfully parsed the Content-Length, so we'll return a limited
+        // reader that restricts us to reading only up to this size.
+        reader = ioutil.NopCloser(io.LimitReader(*body, parsedLength))
+        contentLength = parsedLength
+    } else {
+        // Content-Length header is missing. If we have a maximum file size
+        // configured then we'll just make sure that the reader is limited to
+        // that size. We'll return a zero content length, but that's OK, since
+        // ultimately it will get rewritten later when the temp file is written
+        // to disk.
+        if maxFileSizeBytes > 0 {
+            reader = ioutil.NopCloser(io.LimitReader(*body, int64(maxFileSizeBytes)))
+        }
+        contentLength = 0
     }
     return contentLength, reader, nil
 }
 
 func (r *downloadRequest) fetchRemoteFile(
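One consequence of the branch without a Content-Length header is worth noting: io.LimitReader stops at its limit without returning an error, so an oversized body is silently truncated rather than rejected, which the "Data is truncated to maxFileSizeBytes" comment further down acknowledges. The sketch below shows one common way a caller could detect truncation if it wanted to (limit to one byte more than the cap and compare the count afterwards). It is self-contained and illustrative only, not a claim about how Dendrite handles this case.

package main

import (
    "fmt"
    "io"
    "io/ioutil"
    "strings"
)

// copyWithLimit copies at most maxBytes+1 bytes and reports an error if the
// source turned out to be larger than maxBytes. Illustrative sketch only.
func copyWithLimit(dst io.Writer, src io.Reader, maxBytes int64) (int64, error) {
    n, err := io.Copy(dst, io.LimitReader(src, maxBytes+1))
    if err != nil {
        return n, err
    }
    if n > maxBytes {
        return n, fmt.Errorf("body exceeds configured maximum of %d bytes", maxBytes)
    }
    return n, nil
}

func main() {
    n, err := copyWithLimit(ioutil.Discard, strings.NewReader("0123456789"), 4)
    fmt.Println(n, err) // 5 bytes copied, error reported
}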
@@ -724,7 +730,9 @@ func (r *downloadRequest) fetchRemoteFile(
     }
     defer resp.Body.Close() // nolint: errcheck
 
-    contentLength, reader, parseErr := r.GetContentLengthAndReader(resp.Header.Get("Content-Length"), &resp.Body)
+    // The reader returned here will be limited either by the Content-Length
+    // and/or the configured maximum media size.
+    contentLength, reader, parseErr := r.GetContentLengthAndReader(resp.Header.Get("Content-Length"), &resp.Body, maxFileSizeBytes)
     if parseErr != nil {
         return "", false, parseErr
     }
@@ -760,7 +768,7 @@ func (r *downloadRequest) fetchRemoteFile(
     // method of deduplicating files to save storage, as well as a way to conduct
     // integrity checks on the file data in the repository.
     // Data is truncated to maxFileSizeBytes. Content-Length was reported as 0 < Content-Length <= maxFileSizeBytes so this is OK.
-    hash, bytesWritten, tmpDir, err := fileutils.WriteTempFile(ctx, reader, maxFileSizeBytes, absBasePath)
+    hash, bytesWritten, tmpDir, err := fileutils.WriteTempFile(ctx, reader, absBasePath)
     if err != nil {
         r.Logger.WithError(err).WithFields(log.Fields{
             "MaxFileSizeBytes": maxFileSizeBytes,

Changed file 3 of 3

@@ -147,7 +147,7 @@ func (r *uploadRequest) doUpload(
     // r.storeFileAndMetadata(ctx, tmpDir, ...)
     // before you return from doUpload else we will leak a temp file. We could make this nicer with a `WithTransaction` style of
     // nested function to guarantee either storage or cleanup.
-    hash, bytesWritten, tmpDir, err := fileutils.WriteTempFile(ctx, reqReader, *cfg.MaxFileSizeBytes, cfg.AbsBasePath)
+    hash, bytesWritten, tmpDir, err := fileutils.WriteTempFile(ctx, reqReader, cfg.AbsBasePath)
     if err != nil {
         r.Logger.WithError(err).WithFields(log.Fields{
             "MaxFileSizeBytes": *cfg.MaxFileSizeBytes,