mediaapi/writers: Reuse same writer code for upload and download

This now calculates a hash for downloads from remote servers as well as
uploads to this server.
Robert Swain 2017-05-22 10:19:52 +02:00
parent 9af66a1963
commit 370cb74d2d
3 changed files with 144 additions and 133 deletions
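The upload and download paths now share one helper in the writers package. A minimal sketch of the reuse, showing the two call sites as they appear in the diff below (surrounding setup elided):

    // Download: proxy the remote response to the client via w while also
    // hashing it and caching it in a temporary file under absBasePath.
    hash, bytesResponded, bytesWritten, tmpDir, copyError := readAndHashAndWriteWithLimit(resp.Body, maxFileSizeBytes, absBasePath, w)

    // Upload: passing nil instead of an http.ResponseWriter makes the helper
    // only hash the request body and write it to disk, up to cfg.MaxFileSizeBytes.
    hash, _, bytesWritten, tmpDir, copyError := readAndHashAndWriteWithLimit(req.Body, cfg.MaxFileSizeBytes, cfg.AbsBasePath, nil)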

View file

@@ -78,11 +78,6 @@ func (r *downloadRequest) jsonErrorResponse(w http.ResponseWriter, res util.JSONResponse) {
     w.Write(resBytes)
 }
 
-var errFileIsTooLarge = fmt.Errorf("file is too large")
-var errRead = fmt.Errorf("failed to read response from remote server")
-var errResponse = fmt.Errorf("failed to write file data to response body")
-var errWrite = fmt.Errorf("failed to write file to disk")
-
 var nTries = 5
 
 // Download implements /download
@@ -300,55 +295,6 @@ func (r *downloadRequest) createRemoteRequest() (*http.Response, *util.JSONResponse) {
     return resp, nil
 }
 
-// copyToActiveAndPassive works like io.Copy except it copies from the reader to both of the writers
-// If there is an error with the reader or the active writer, that is considered an error
-// If there is an error with the passive writer, that is non-critical and copying continues
-// maxFileSizeBytes limits the amount of data written to the passive writer
-func copyToActiveAndPassive(r io.Reader, wActive io.Writer, wPassive io.Writer, maxFileSizeBytes types.ContentLength, mediaMetadata *types.MediaMetadata) (int64, int64, error) {
-    var bytesResponded, bytesWritten int64 = 0, 0
-    var copyError error
-    // Note: the buffer size is the same as is used in io.Copy()
-    buffer := make([]byte, 32*1024)
-    for {
-        // read from remote request's response body
-        bytesRead, readErr := r.Read(buffer)
-        if bytesRead > 0 {
-            // write to client request's response body
-            bytesTemp, respErr := wActive.Write(buffer[:bytesRead])
-            if bytesTemp != bytesRead || (respErr != nil && respErr != io.EOF) {
-                copyError = errResponse
-                break
-            }
-            bytesResponded += int64(bytesTemp)
-            if copyError == nil {
-                // Note: if we get here then copyError != errFileIsTooLarge && copyError != errWrite
-                // as if copyError == errResponse || copyError == errWrite then we would have broken
-                // out of the loop and there are no other cases
-                // if larger than maxFileSizeBytes then stop writing to disk and discard cached file
-                if bytesWritten+int64(len(buffer)) > int64(maxFileSizeBytes) {
-                    copyError = errFileIsTooLarge
-                } else {
-                    // write to disk
-                    bytesTemp, writeErr := wPassive.Write(buffer[:bytesRead])
-                    if writeErr != nil && writeErr != io.EOF {
-                        copyError = errWrite
-                    } else {
-                        bytesWritten += int64(bytesTemp)
-                    }
-                }
-            }
-        }
-        if readErr != nil {
-            if readErr != io.EOF {
-                copyError = errRead
-            }
-            break
-        }
-    }
-    return bytesResponded, bytesWritten, copyError
-}
-
 func (r *downloadRequest) closeConnection(w http.ResponseWriter) {
     r.Logger.WithFields(log.Fields{
         "Origin": r.MediaMetadata.Origin,
@@ -489,14 +435,6 @@ func (r *downloadRequest) respondFromRemoteFile(w http.ResponseWriter, absBasePa
         " object-src 'self';"
     w.Header().Set("Content-Security-Policy", contentSecurityPolicy)
 
-    // create the temporary file writer
-    tmpFileWriter, tmpFile, tmpDir, errorResponse := createTempFileWriter(absBasePath, r.Logger)
-    if errorResponse != nil {
-        r.jsonErrorResponse(w, *errorResponse)
-        return
-    }
-    defer tmpFile.Close()
-
     // read the remote request's response body
     // simultaneously write it to the incoming request's response body and the temporary file
     r.Logger.WithFields(log.Fields{
@@ -504,19 +442,22 @@ func (r *downloadRequest) respondFromRemoteFile(w http.ResponseWriter, absBasePa
         "Origin": r.MediaMetadata.Origin,
     }).Infof("Proxying and caching remote file")
 
+    // The file data is hashed but is NOT used as the MediaID, unlike in Upload. The hash is useful as a
+    // method of deduplicating files to save storage, as well as a way to conduct
+    // integrity checks on the file data in the repository.
     // bytesResponded is the total number of bytes written to the response to the client request
     // bytesWritten is the total number of bytes written to disk
-    bytesResponded, bytesWritten, fetchError := copyToActiveAndPassive(resp.Body, w, tmpFileWriter, maxFileSizeBytes, r.MediaMetadata)
-    tmpFileWriter.Flush()
-    if fetchError != nil {
+    hash, bytesResponded, bytesWritten, tmpDir, copyError := readAndHashAndWriteWithLimit(resp.Body, maxFileSizeBytes, absBasePath, w)
+    if copyError != nil {
         logFields := log.Fields{
             "MediaID": r.MediaMetadata.MediaID,
             "Origin":  r.MediaMetadata.Origin,
         }
-        if fetchError == errFileIsTooLarge {
+        if copyError == errFileIsTooLarge {
             logFields["MaxFileSizeBytes"] = maxFileSizeBytes
         }
-        r.Logger.WithError(fetchError).WithFields(logFields).Warn("Error while fetching file")
+        r.Logger.WithError(copyError).WithFields(logFields).Warn("Error while transferring file")
         removeDir(tmpDir, r.Logger)
         // Note: if we have responded with any data in the body at all then we have already sent 200 OK and we can only abort at this point
         if bytesResponded < 1 {

View file

@@ -16,17 +16,20 @@ package writers
 
 import (
     "bufio"
+    "crypto/sha256"
+    "encoding/base64"
     "fmt"
+    "hash"
     "io"
     "io/ioutil"
+    "net/http"
     "os"
     "path"
     "path/filepath"
     "strings"
 
     log "github.com/Sirupsen/logrus"
     "github.com/matrix-org/dendrite/clientapi/jsonerror"
     "github.com/matrix-org/dendrite/mediaapi/types"
     "github.com/matrix-org/util"
 )
 
 func removeDir(dir types.Path, logger *log.Entry) {
@@ -61,26 +64,126 @@ func createFileWriter(directory types.Path, filename types.Filename) (*bufio.Writer, *os.File, error) {
     return bufio.NewWriter(file), file, nil
 }
 
-func createTempFileWriter(absBasePath types.Path, logger *log.Entry) (*bufio.Writer, *os.File, types.Path, *util.JSONResponse) {
+func createTempFileWriter(absBasePath types.Path) (*bufio.Writer, *os.File, types.Path, error) {
     tmpDir, err := createTempDir(absBasePath)
     if err != nil {
-        logger.WithError(err).WithField("dir", tmpDir).Warn("Failed to create temp dir")
-        return nil, nil, "", &util.JSONResponse{
-            Code: 400,
-            JSON: jsonerror.Unknown(fmt.Sprintf("Failed to upload")),
-        }
+        return nil, nil, "", fmt.Errorf("Failed to create temp dir: %q", err)
     }
     writer, tmpFile, err := createFileWriter(tmpDir, "content")
     if err != nil {
-        logger.WithError(err).Warn("Failed to create file writer")
-        return nil, nil, "", &util.JSONResponse{
-            Code: 400,
-            JSON: jsonerror.Unknown(fmt.Sprintf("Failed to upload")),
-        }
+        return nil, nil, "", fmt.Errorf("Failed to create file writer: %q", err)
     }
     return writer, tmpFile, tmpDir, nil
 }
+var errFileIsTooLarge = fmt.Errorf("file is too large")
+var errRead = fmt.Errorf("failed to read response from remote server")
+var errResponse = fmt.Errorf("failed to write file data to response body")
+var errHash = fmt.Errorf("failed to hash file data")
+var errWrite = fmt.Errorf("failed to write file to disk")
+
+// writeToResponse takes bytesToWrite bytes from buffer and writes them to respWriter
+// Returns bytes written and an error. In case of error, or if there is no respWriter,
+// the number of bytes written will be 0.
+func writeToResponse(respWriter http.ResponseWriter, buffer []byte, bytesToWrite int) (int64, error) {
+    if respWriter != nil {
+        bytesWritten, respErr := respWriter.Write(buffer[:bytesToWrite])
+        if bytesWritten != bytesToWrite || (respErr != nil && respErr != io.EOF) {
+            return 0, errResponse
+        }
+        return int64(bytesWritten), nil
+    }
+    return 0, nil
+}
+
+// writeToDiskAndHasher takes bytesToWrite bytes from buffer and writes them to tmpFileWriter and hasher.
+// Returns bytes written and an error. In case of error, including if writing would exceed maxFileSizeBytes,
+// the number of bytes written will be 0.
+func writeToDiskAndHasher(tmpFileWriter *bufio.Writer, hasher hash.Hash, bytesWritten int64, maxFileSizeBytes types.ContentLength, buffer []byte, bytesToWrite int) (int64, error) {
+    // if larger than maxFileSizeBytes then stop writing to disk and discard cached file
+    if bytesWritten+int64(bytesToWrite) > int64(maxFileSizeBytes) {
+        return 0, errFileIsTooLarge
+    }
+    // write to hasher and to disk
+    bytesTemp, writeErr := tmpFileWriter.Write(buffer[:bytesToWrite])
+    bytesHashed, hashErr := hasher.Write(buffer[:bytesToWrite])
+    if writeErr != nil && writeErr != io.EOF || bytesTemp != bytesToWrite || bytesTemp != bytesHashed {
+        return 0, errWrite
+    } else if hashErr != nil && hashErr != io.EOF {
+        return 0, errHash
+    }
+    return int64(bytesTemp), nil
+}
+
+// readAndHashAndWriteWithLimit works like io.Copy except it copies from the reqReader to the
+// optionally-supplied respWriter and a temporary file named 'content' using a bufio.Writer.
+// The data written to disk is hashed using the SHA-256 algorithm.
+// If there is an error with the reqReader or the respWriter, that is considered an error.
+// If there is an error with the hasher or tmpFileWriter, that is non-critical and copying
+// to the respWriter continues.
+// maxFileSizeBytes limits the amount of data written to disk and the hasher.
+// If a respWriter is provided, all the data will be proxied from the reqReader to
+// the respWriter, regardless of errors or limits on writing to disk.
+// Returns the hash sum, the bytes written to the response, the bytes written to disk,
+// and the temporary directory path, or an error.
+func readAndHashAndWriteWithLimit(reqReader io.Reader, maxFileSizeBytes types.ContentLength, absBasePath types.Path, respWriter http.ResponseWriter) (types.Base64Hash, types.ContentLength, types.ContentLength, types.Path, error) {
+    // create the temporary file writer
+    tmpFileWriter, tmpFile, tmpDir, err := createTempFileWriter(absBasePath)
+    if err != nil {
+        return "", -1, -1, "", err
+    }
+    defer tmpFile.Close()
+
+    // The file data is hashed and the hash is returned. The hash is useful as a
+    // method of deduplicating files to save storage, as well as a way to conduct
+    // integrity checks on the file data in the repository. The hash gets used as
+    // the MediaID.
+    hasher := sha256.New()
+
+    // bytesResponded is the total number of bytes written to the response to the client request
+    // bytesWritten is the total number of bytes written to disk
+    var bytesResponded, bytesWritten int64 = 0, 0
+    var bytesTemp int64
+    var copyError error
+    // Note: the buffer size is the same as is used in io.Copy()
+    buffer := make([]byte, 32*1024)
+    for {
+        // read from remote request's response body
+        bytesRead, readErr := reqReader.Read(buffer)
+        if bytesRead > 0 {
+            // Note: This code allows proxying files larger than maxFileSizeBytes!
+            // write to client request's response body
+            bytesTemp, copyError = writeToResponse(respWriter, buffer, bytesRead)
+            bytesResponded += bytesTemp
+            if copyError == nil {
+                // Note: if we get here then copyError != errFileIsTooLarge && copyError != errWrite
+                // as if copyError == errResponse || copyError == errWrite then we would have broken
+                // out of the loop and there are no other cases
+                bytesTemp, copyError = writeToDiskAndHasher(tmpFileWriter, hasher, bytesWritten, maxFileSizeBytes, buffer, bytesRead)
+                bytesWritten += bytesTemp
+                // If we do not have a respWriter then we are only writing to the hasher and tmpFileWriter. In that case, if we get an error, we need to break.
+                if respWriter == nil && copyError != nil {
+                    break
+                }
+            }
+        }
+        if readErr != nil {
+            if readErr != io.EOF {
+                copyError = errRead
+            }
+            break
+        }
+    }
+
+    if copyError != nil {
+        return "", -1, -1, "", copyError
+    }
+
+    tmpFileWriter.Flush()
+
+    hash := hasher.Sum(nil)
+    return types.Base64Hash(base64.URLEncoding.EncodeToString(hash[:])), types.ContentLength(bytesResponded), types.ContentLength(bytesWritten), tmpDir, nil
+}
+
 // getPathFromMediaMetadata validates and constructs the on-disk path to the media
 // based on its origin and mediaID
 // If a mediaID is too short, which could happen for other homeserver implementations,

View file

@@ -15,14 +15,10 @@
 package writers
 
 import (
-    "crypto/sha256"
     "database/sql"
-    "encoding/base64"
     "fmt"
-    "io"
     "net/http"
     "net/url"
     "os"
     "path"
     "strings"
@@ -136,51 +132,6 @@ func parseAndValidateRequest(req *http.Request, cfg *config.MediaAPI) (*uploadRe
     return r, nil
 }
 
-// writeFileWithLimitAndHash reads data from an io.Reader and writes it to a temporary
-// file named 'content' in the returned temporary directory. It only reads up to a limit of
-// cfg.MaxFileSizeBytes from the io.Reader. The data written is hashed and the hashsum is
-// returned. If any errors occur, a util.JSONResponse error is returned.
-func writeFileWithLimitAndHash(r io.Reader, cfg *config.MediaAPI, logger *log.Entry, contentLength types.ContentLength) ([]byte, types.Path, *util.JSONResponse) {
-    writer, file, tmpDir, errorResponse := createTempFileWriter(cfg.AbsBasePath, logger)
-    if errorResponse != nil {
-        return nil, "", errorResponse
-    }
-    defer file.Close()
-
-    // The limited reader restricts how many bytes are read from the body to the specified maximum bytes
-    // Note: the golang HTTP server closes the request body
-    limitedBody := io.LimitReader(r, int64(cfg.MaxFileSizeBytes))
-    // The file data is hashed and the hash is returned. The hash is useful as a
-    // method of deduplicating files to save storage, as well as a way to conduct
-    // integrity checks on the file data in the repository. The hash gets used as
-    // the MediaID.
-    hasher := sha256.New()
-    // A TeeReader is used to allow us to read from the limitedBody and simultaneously
-    // write to the hasher here and to the http.ResponseWriter via the io.Copy call below.
-    reader := io.TeeReader(limitedBody, hasher)
-
-    bytesWritten, err := io.Copy(writer, reader)
-    if err != nil {
-        logger.WithError(err).Warn("Failed to copy")
-        removeDir(tmpDir, logger)
-        return nil, "", &util.JSONResponse{
-            Code: 400,
-            JSON: jsonerror.Unknown(fmt.Sprintf("Failed to upload")),
-        }
-    }
-    writer.Flush()
-
-    if bytesWritten != int64(contentLength) {
-        logger.WithFields(log.Fields{
-            "bytesWritten":  bytesWritten,
-            "contentLength": contentLength,
-        }).Warn("Fewer bytes written than expected")
-    }
-
-    return hasher.Sum(nil), tmpDir, nil
-}
-
 // storeFileAndMetadata first moves a temporary file named content from tmpDir to its
 // final path (see getPathFromMediaMetadata for details.) Once the file is moved, the
 // metadata about the file is written into the media repository database. This order
@@ -249,11 +200,27 @@ func Upload(req *http.Request, cfg *config.MediaAPI, db *storage.Database) util.JSONResponse {
     // The file data is hashed and the hash is used as the MediaID. The hash is useful as a
     // method of deduplicating files to save storage, as well as a way to conduct
     // integrity checks on the file data in the repository.
-    hash, tmpDir, resErr := writeFileWithLimitAndHash(req.Body, cfg, logger, r.MediaMetadata.ContentLength)
-    if resErr != nil {
-        return *resErr
+    // bytesWritten is the total number of bytes written to disk
+    hash, _, bytesWritten, tmpDir, copyError := readAndHashAndWriteWithLimit(req.Body, cfg.MaxFileSizeBytes, cfg.AbsBasePath, nil)
+    if copyError != nil {
+        logFields := log.Fields{
+            "Origin":  r.MediaMetadata.Origin,
+            "MediaID": r.MediaMetadata.MediaID,
+        }
-    r.MediaMetadata.MediaID = types.MediaID(base64.URLEncoding.EncodeToString(hash[:]))
+        if copyError == errFileIsTooLarge {
+            logFields["MaxFileSizeBytes"] = cfg.MaxFileSizeBytes
+        }
+        logger.WithError(copyError).WithFields(logFields).Warn("Error while transferring file")
+        removeDir(tmpDir, logger)
+        return util.JSONResponse{
+            Code: 400,
+            JSON: jsonerror.Unknown(fmt.Sprintf("Failed to upload")),
+        }
+    }
+    r.MediaMetadata.ContentLength = bytesWritten
+    r.MediaMetadata.MediaID = types.MediaID(hash)
 
     logger.WithFields(log.Fields{
         "MediaID": r.MediaMetadata.MediaID,