mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-26 08:13:09 -06:00
Don't exhaust memory for large files, don't limit more than necessary
This commit is contained in:
parent
15e1f4afc9
commit
b12f4c2726
|
|
@ -109,7 +109,7 @@ func RemoveDir(dir types.Path, logger *log.Entry) {
|
|||
// WriteTempFile writes to a new temporary file.
|
||||
// The file is deleted if there was an error while writing.
|
||||
func WriteTempFile(
|
||||
ctx context.Context, reqReader io.Reader, maxFileSizeBytes config.FileSizeBytes, absBasePath config.Path,
|
||||
ctx context.Context, reqReader io.Reader, absBasePath config.Path,
|
||||
) (hash types.Base64Hash, size types.FileSizeBytes, path types.Path, err error) {
|
||||
size = -1
|
||||
logger := util.GetLogger(ctx)
|
||||
|
|
@ -124,18 +124,11 @@ func WriteTempFile(
|
|||
}
|
||||
}()
|
||||
|
||||
// If the max_file_size_bytes configuration option is set to a positive
|
||||
// number then limit the upload to that size. Otherwise, just read the
|
||||
// whole file.
|
||||
limitedReader := reqReader
|
||||
if maxFileSizeBytes > 0 {
|
||||
limitedReader = io.LimitReader(reqReader, int64(maxFileSizeBytes))
|
||||
}
|
||||
// Hash the file data. The hash will be returned. The hash is useful as a
|
||||
// method of deduplicating files to save storage, as well as a way to conduct
|
||||
// integrity checks on the file data in the repository.
|
||||
hasher := sha256.New()
|
||||
teeReader := io.TeeReader(limitedReader, hasher)
|
||||
teeReader := io.TeeReader(reqReader, hasher)
|
||||
bytesWritten, err := io.Copy(tmpFileWriter, teeReader)
|
||||
if err != nil && err != io.EOF {
|
||||
RemoveDir(tmpDir, logger)
|
||||
|
|
|
|||
|
|
@ -15,7 +15,6 @@
|
|||
package routing
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
|
@ -675,34 +674,41 @@ func (r *downloadRequest) fetchRemoteFileAndStoreMetadata(
|
|||
return nil
|
||||
}
|
||||
|
||||
func (r *downloadRequest) GetContentLengthAndReader(contentLengthHeader string, body *io.ReadCloser) (int64, io.Reader, error) {
|
||||
var reader io.Reader
|
||||
func (r *downloadRequest) GetContentLengthAndReader(contentLengthHeader string, body *io.ReadCloser, maxFileSizeBytes config.FileSizeBytes) (int64, io.Reader, error) {
|
||||
reader := *body
|
||||
var contentLength int64
|
||||
|
||||
if contentLengthHeader != "" {
|
||||
// A Content-Length header is provided. Let's try to parse it.
|
||||
parsedLength, parseErr := strconv.ParseInt(contentLengthHeader, 10, 64)
|
||||
|
||||
if parseErr != nil {
|
||||
r.Logger.WithError(parseErr).Warn("Failed to parse content length")
|
||||
return 0, nil, errors.Wrap(parseErr, "invalid response from remote server")
|
||||
}
|
||||
|
||||
reader = *body
|
||||
contentLength = parsedLength
|
||||
} else {
|
||||
// Content-Length header is missing, we need to read the whole body to get its length
|
||||
bodyBytes, readAllErr := ioutil.ReadAll(*body)
|
||||
if readAllErr != nil {
|
||||
r.Logger.WithError(readAllErr).Warn("Could not read response body from remote server")
|
||||
return 0, nil, errors.Wrap(readAllErr, "invalid response from remote server")
|
||||
if parsedLength > int64(maxFileSizeBytes) {
|
||||
return 0, nil, fmt.Errorf(
|
||||
"remote file size (%d bytes) exceeds locally configured max media size (%d bytes)",
|
||||
parsedLength, maxFileSizeBytes,
|
||||
)
|
||||
}
|
||||
|
||||
contentLength = int64(len(bodyBytes))
|
||||
reader = ioutil.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
// We successfully parsed the Content-Length, so we'll return a limited
|
||||
// reader that restricts us to reading only up to this size.
|
||||
reader = ioutil.NopCloser(io.LimitReader(*body, parsedLength))
|
||||
contentLength = parsedLength
|
||||
} else {
|
||||
// Content-Length header is missing. If we have a maximum file size
|
||||
// configured then we'll just make sure that the reader is limited to
|
||||
// that size. We'll return a zero content length, but that's OK, since
|
||||
// ultimately it will get rewritten later when the temp file is written
|
||||
// to disk.
|
||||
if maxFileSizeBytes > 0 {
|
||||
reader = ioutil.NopCloser(io.LimitReader(*body, int64(maxFileSizeBytes)))
|
||||
}
|
||||
contentLength = 0
|
||||
}
|
||||
|
||||
return contentLength, reader, nil
|
||||
|
||||
}
|
||||
|
||||
func (r *downloadRequest) fetchRemoteFile(
|
||||
|
|
@ -724,7 +730,9 @@ func (r *downloadRequest) fetchRemoteFile(
|
|||
}
|
||||
defer resp.Body.Close() // nolint: errcheck
|
||||
|
||||
contentLength, reader, parseErr := r.GetContentLengthAndReader(resp.Header.Get("Content-Length"), &resp.Body)
|
||||
// The reader returned here will be limited by the Content-Length
|
||||
// and/or the configured maximum media size.
|
||||
contentLength, reader, parseErr := r.GetContentLengthAndReader(resp.Header.Get("Content-Length"), &resp.Body, maxFileSizeBytes)
|
||||
if parseErr != nil {
|
||||
return "", false, parseErr
|
||||
}
|
||||
|
|
@ -760,7 +768,7 @@ func (r *downloadRequest) fetchRemoteFile(
|
|||
// method of deduplicating files to save storage, as well as a way to conduct
|
||||
// integrity checks on the file data in the repository.
|
||||
// Data is truncated by the limited reader: to the reported Content-Length when one was provided (already checked against maxFileSizeBytes), otherwise to maxFileSizeBytes when it is set.
|
||||
hash, bytesWritten, tmpDir, err := fileutils.WriteTempFile(ctx, reader, maxFileSizeBytes, absBasePath)
|
||||
hash, bytesWritten, tmpDir, err := fileutils.WriteTempFile(ctx, reader, absBasePath)
|
||||
if err != nil {
|
||||
r.Logger.WithError(err).WithFields(log.Fields{
|
||||
"MaxFileSizeBytes": maxFileSizeBytes,
|
||||
|
|
|
|||
|
|
@ -147,7 +147,7 @@ func (r *uploadRequest) doUpload(
|
|||
// r.storeFileAndMetadata(ctx, tmpDir, ...)
|
||||
// before you return from doUpload else we will leak a temp file. We could make this nicer with a `WithTransaction` style of
|
||||
// nested function to guarantee either storage or cleanup.
|
||||
hash, bytesWritten, tmpDir, err := fileutils.WriteTempFile(ctx, reqReader, *cfg.MaxFileSizeBytes, cfg.AbsBasePath)
|
||||
hash, bytesWritten, tmpDir, err := fileutils.WriteTempFile(ctx, reqReader, cfg.AbsBasePath)
|
||||
if err != nil {
|
||||
r.Logger.WithError(err).WithFields(log.Fields{
|
||||
"MaxFileSizeBytes": *cfg.MaxFileSizeBytes,
|
||||
|
|
|
|||
Loading…
Reference in a new issue