From 370cb74d2d330241186a8075c03dfea213373e56 Mon Sep 17 00:00:00 2001 From: Robert Swain Date: Mon, 22 May 2017 10:19:52 +0200 Subject: [PATCH] mediaapi/writers: Reuse same writer code for upload and download This now calculates a hash for downloads from remote servers as well as uploads to this server. --- .../dendrite/mediaapi/writers/download.go | 75 ++-------- .../dendrite/mediaapi/writers/fileutils.go | 129 ++++++++++++++++-- .../dendrite/mediaapi/writers/upload.go | 73 +++------- 3 files changed, 144 insertions(+), 133 deletions(-) diff --git a/src/github.com/matrix-org/dendrite/mediaapi/writers/download.go b/src/github.com/matrix-org/dendrite/mediaapi/writers/download.go index 30d262a46..58c9b3adc 100644 --- a/src/github.com/matrix-org/dendrite/mediaapi/writers/download.go +++ b/src/github.com/matrix-org/dendrite/mediaapi/writers/download.go @@ -78,11 +78,6 @@ func (r *downloadRequest) jsonErrorResponse(w http.ResponseWriter, res util.JSON w.Write(resBytes) } -var errFileIsTooLarge = fmt.Errorf("file is too large") -var errRead = fmt.Errorf("failed to read response from remote server") -var errResponse = fmt.Errorf("failed to write file data to response body") -var errWrite = fmt.Errorf("failed to write file to disk") - var nTries = 5 // Download implements /download @@ -300,55 +295,6 @@ func (r *downloadRequest) createRemoteRequest() (*http.Response, *util.JSONRespo return resp, nil } -// copyToActiveAndPassive works like io.Copy except it copies from the reader to both of the writers -// If there is an error with the reader or the active writer, that is considered an error -// If there is an error with the passive writer, that is non-critical and copying continues -// maxFileSizeBytes limits the amount of data written to the passive writer -func copyToActiveAndPassive(r io.Reader, wActive io.Writer, wPassive io.Writer, maxFileSizeBytes types.ContentLength, mediaMetadata *types.MediaMetadata) (int64, int64, error) { - var bytesResponded, bytesWritten int64 = 0, 0 - var copyError error - // Note: the buffer size is the same as is used in io.Copy() - buffer := make([]byte, 32*1024) - for { - // read from remote request's response body - bytesRead, readErr := r.Read(buffer) - if bytesRead > 0 { - // write to client request's response body - bytesTemp, respErr := wActive.Write(buffer[:bytesRead]) - if bytesTemp != bytesRead || (respErr != nil && respErr != io.EOF) { - copyError = errResponse - break - } - bytesResponded += int64(bytesTemp) - if copyError == nil { - // Note: if we get here then copyError != errFileIsTooLarge && copyError != errWrite - // as if copyError == errResponse || copyError == errWrite then we would have broken - // out of the loop and there are no other cases - // if larger than maxFileSizeBytes then stop writing to disk and discard cached file - if bytesWritten+int64(len(buffer)) > int64(maxFileSizeBytes) { - copyError = errFileIsTooLarge - } else { - // write to disk - bytesTemp, writeErr := wPassive.Write(buffer[:bytesRead]) - if writeErr != nil && writeErr != io.EOF { - copyError = errWrite - } else { - bytesWritten += int64(bytesTemp) - } - } - } - } - if readErr != nil { - if readErr != io.EOF { - copyError = errRead - } - break - } - } - - return bytesResponded, bytesWritten, copyError -} - func (r *downloadRequest) closeConnection(w http.ResponseWriter) { r.Logger.WithFields(log.Fields{ "Origin": r.MediaMetadata.Origin, @@ -489,14 +435,6 @@ func (r *downloadRequest) respondFromRemoteFile(w http.ResponseWriter, absBasePa " object-src 'self';" 
w.Header().Set("Content-Security-Policy", contentSecurityPolicy) - // create the temporary file writer - tmpFileWriter, tmpFile, tmpDir, errorResponse := createTempFileWriter(absBasePath, r.Logger) - if errorResponse != nil { - r.jsonErrorResponse(w, *errorResponse) - return - } - defer tmpFile.Close() - // read the remote request's response body // simultaneously write it to the incoming request's response body and the temporary file r.Logger.WithFields(log.Fields{ @@ -504,19 +442,22 @@ func (r *downloadRequest) respondFromRemoteFile(w http.ResponseWriter, absBasePa "Origin": r.MediaMetadata.Origin, }).Infof("Proxying and caching remote file") + // The file data is hashed but is NOT used as the MediaID, unlike in Upload. The hash is useful as a + // method of deduplicating files to save storage, as well as a way to conduct + // integrity checks on the file data in the repository. // bytesResponded is the total number of bytes written to the response to the client request // bytesWritten is the total number of bytes written to disk - bytesResponded, bytesWritten, fetchError := copyToActiveAndPassive(resp.Body, w, tmpFileWriter, maxFileSizeBytes, r.MediaMetadata) - tmpFileWriter.Flush() - if fetchError != nil { + hash, bytesResponded, bytesWritten, tmpDir, copyError := readAndHashAndWriteWithLimit(resp.Body, maxFileSizeBytes, absBasePath, w) + + if copyError != nil { logFields := log.Fields{ "MediaID": r.MediaMetadata.MediaID, "Origin": r.MediaMetadata.Origin, } - if fetchError == errFileIsTooLarge { + if copyError == errFileIsTooLarge { logFields["MaxFileSizeBytes"] = maxFileSizeBytes } - r.Logger.WithError(fetchError).WithFields(logFields).Warn("Error while fetching file") + r.Logger.WithError(copyError).WithFields(logFields).Warn("Error while transferring file") removeDir(tmpDir, r.Logger) // Note: if we have responded with any data in the body at all then we have already sent 200 OK and we can only abort at this point if bytesResponded < 1 { diff --git a/src/github.com/matrix-org/dendrite/mediaapi/writers/fileutils.go b/src/github.com/matrix-org/dendrite/mediaapi/writers/fileutils.go index d158a1029..a11176e97 100644 --- a/src/github.com/matrix-org/dendrite/mediaapi/writers/fileutils.go +++ b/src/github.com/matrix-org/dendrite/mediaapi/writers/fileutils.go @@ -16,17 +16,20 @@ package writers import ( "bufio" + "crypto/sha256" + "encoding/base64" "fmt" + "hash" + "io" "io/ioutil" + "net/http" "os" "path" "path/filepath" "strings" log "github.com/Sirupsen/logrus" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/mediaapi/types" - "github.com/matrix-org/util" ) func removeDir(dir types.Path, logger *log.Entry) { @@ -61,26 +64,126 @@ func createFileWriter(directory types.Path, filename types.Filename) (*bufio.Wri return bufio.NewWriter(file), file, nil } -func createTempFileWriter(absBasePath types.Path, logger *log.Entry) (*bufio.Writer, *os.File, types.Path, *util.JSONResponse) { +func createTempFileWriter(absBasePath types.Path) (*bufio.Writer, *os.File, types.Path, error) { tmpDir, err := createTempDir(absBasePath) if err != nil { - logger.WithError(err).WithField("dir", tmpDir).Warn("Failed to create temp dir") - return nil, nil, "", &util.JSONResponse{ - Code: 400, - JSON: jsonerror.Unknown(fmt.Sprintf("Failed to upload")), - } + return nil, nil, "", fmt.Errorf("Failed to create temp dir: %q", err) } writer, tmpFile, err := createFileWriter(tmpDir, "content") if err != nil { - logger.WithError(err).Warn("Failed to create file writer") - return nil, 
nil, "", &util.JSONResponse{ - Code: 400, - JSON: jsonerror.Unknown(fmt.Sprintf("Failed to upload")), - } + return nil, nil, "", fmt.Errorf("Failed to create file writer: %q", err) } return writer, tmpFile, tmpDir, nil } +var errFileIsTooLarge = fmt.Errorf("file is too large") +var errRead = fmt.Errorf("failed to read response from remote server") +var errResponse = fmt.Errorf("failed to write file data to response body") +var errHash = fmt.Errorf("failed to hash file data") +var errWrite = fmt.Errorf("failed to write file to disk") + +// writeToResponse takes bytesToWrite bytes from buffer and writes them to respWriter +// Returns bytes written and an error. In case of error, or if there is no respWriter, +// the number of bytes written will be 0. +func writeToResponse(respWriter http.ResponseWriter, buffer []byte, bytesToWrite int) (int64, error) { + if respWriter != nil { + bytesWritten, respErr := respWriter.Write(buffer[:bytesToWrite]) + if bytesWritten != bytesToWrite || (respErr != nil && respErr != io.EOF) { + return 0, errResponse + } + return int64(bytesWritten), nil + } + return 0, nil +} + +// writeToDiskAndHasher takes bytesToWrite bytes from buffer and writes them to tmpFileWriter and hasher. +// Returns bytes written and an error. In case of error, including if writing would exceed maxFileSizeBytes, +// the number of bytes written will be 0. +func writeToDiskAndHasher(tmpFileWriter *bufio.Writer, hasher hash.Hash, bytesWritten int64, maxFileSizeBytes types.ContentLength, buffer []byte, bytesToWrite int) (int64, error) { + // if larger than maxFileSizeBytes then stop writing to disk and discard cached file + if bytesWritten+int64(bytesToWrite) > int64(maxFileSizeBytes) { + return 0, errFileIsTooLarge + } + // write to hasher and to disk + bytesTemp, writeErr := tmpFileWriter.Write(buffer[:bytesToWrite]) + bytesHashed, hashErr := hasher.Write(buffer[:bytesToWrite]) + if writeErr != nil && writeErr != io.EOF || bytesTemp != bytesToWrite || bytesTemp != bytesHashed { + return 0, errWrite + } else if hashErr != nil && hashErr != io.EOF { + return 0, errHash + } + return int64(bytesTemp), nil +} + +// readAndHashAndWriteWithLimit works like io.Copy except it copies from the reqReader to the +// optionally-supplied respWriter and a temporary file named 'content' using a bufio.Writer. +// The data written to disk is hashed using the SHA-256 algorithm. +// If there is an error with the reqReader or the respWriter, that is considered an error. +// If there is an error with the hasher or tmpFileWriter, that is non-critical and copying +// to the respWriter continues. +// maxFileSizeBytes limits the amount of data written to disk and the hasher. +// If a respWriter is provided, all the data will be proxied from the reqReader to +// the respWriter, regardless of errors or limits on writing to disk. +// Returns all of the hash sum, bytes written to disk, and temporary directory path, or an error. +func readAndHashAndWriteWithLimit(reqReader io.Reader, maxFileSizeBytes types.ContentLength, absBasePath types.Path, respWriter http.ResponseWriter) (types.Base64Hash, types.ContentLength, types.ContentLength, types.Path, error) { + // create the temporary file writer + tmpFileWriter, tmpFile, tmpDir, err := createTempFileWriter(absBasePath) + if err != nil { + return "", -1, -1, "", err + } + defer tmpFile.Close() + + // The file data is hashed and the hash is returned. 
The hash is useful as a + // method of deduplicating files to save storage, as well as a way to conduct + // integrity checks on the file data in the repository. The hash gets used as + // the MediaID. + hasher := sha256.New() + + // bytesResponded is the total number of bytes written to the response to the client request + // bytesWritten is the total number of bytes written to disk + var bytesResponded, bytesWritten int64 = 0, 0 + var bytesTemp int64 + var copyError error + // Note: the buffer size is the same as is used in io.Copy() + buffer := make([]byte, 32*1024) + for { + // read from remote request's response body + bytesRead, readErr := reqReader.Read(buffer) + if bytesRead > 0 { + // Note: This code allows proxying files larger than maxFileSizeBytes! + // write to client request's response body + bytesTemp, copyError = writeToResponse(respWriter, buffer, bytesRead) + bytesResponded += bytesTemp + if copyError == nil { + // Note: if we get here then copyError != errFileIsTooLarge && copyError != errWrite + // as if copyError == errResponse || copyError == errWrite then we would have broken + // out of the loop and there are no other cases + bytesTemp, copyError = writeToDiskAndHasher(tmpFileWriter, hasher, bytesWritten, maxFileSizeBytes, buffer, (bytesRead)) + bytesWritten += bytesTemp + // If we do not have a respWriter then we are only writing to the hasher and tmpFileWriter. In that case, if we get an error, we need to break. + if respWriter == nil && copyError != nil { + break + } + } + } + if readErr != nil { + if readErr != io.EOF { + copyError = errRead + } + break + } + } + + if copyError != nil { + return "", -1, -1, "", copyError + } + + tmpFileWriter.Flush() + + hash := hasher.Sum(nil) + return types.Base64Hash(base64.URLEncoding.EncodeToString(hash[:])), types.ContentLength(bytesResponded), types.ContentLength(bytesWritten), tmpDir, nil +} + // getPathFromMediaMetadata validates and constructs the on-disk path to the media // based on its origin and mediaID // If a mediaID is too short, which could happen for other homeserver implementations, diff --git a/src/github.com/matrix-org/dendrite/mediaapi/writers/upload.go b/src/github.com/matrix-org/dendrite/mediaapi/writers/upload.go index 132af87fd..dc886353f 100644 --- a/src/github.com/matrix-org/dendrite/mediaapi/writers/upload.go +++ b/src/github.com/matrix-org/dendrite/mediaapi/writers/upload.go @@ -15,14 +15,10 @@ package writers import ( - "crypto/sha256" "database/sql" - "encoding/base64" "fmt" - "io" "net/http" "net/url" - "os" "path" "strings" @@ -136,51 +132,6 @@ func parseAndValidateRequest(req *http.Request, cfg *config.MediaAPI) (*uploadRe return r, nil } -// writeFileWithLimitAndHash reads data from an io.Reader and writes it to a temporary -// file named 'content' in the returned temporary directory. It only reads up to a limit of -// cfg.MaxFileSizeBytes from the io.Reader. The data written is hashed and the hashsum is -// returned. If any errors occur, a util.JSONResponse error is returned. 
-func writeFileWithLimitAndHash(r io.Reader, cfg *config.MediaAPI, logger *log.Entry, contentLength types.ContentLength) ([]byte, types.Path, *util.JSONResponse) { - writer, file, tmpDir, errorResponse := createTempFileWriter(cfg.AbsBasePath, logger) - if errorResponse != nil { - return nil, "", errorResponse - } - defer file.Close() - - // The limited reader restricts how many bytes are read from the body to the specified maximum bytes - // Note: the golang HTTP server closes the request body - limitedBody := io.LimitReader(r, int64(cfg.MaxFileSizeBytes)) - // The file data is hashed and the hash is returned. The hash is useful as a - // method of deduplicating files to save storage, as well as a way to conduct - // integrity checks on the file data in the repository. The hash gets used as - // the MediaID. - hasher := sha256.New() - // A TeeReader is used to allow us to read from the limitedBody and simultaneously - // write to the hasher here and to the http.ResponseWriter via the io.Copy call below. - reader := io.TeeReader(limitedBody, hasher) - - bytesWritten, err := io.Copy(writer, reader) - if err != nil { - logger.WithError(err).Warn("Failed to copy") - removeDir(tmpDir, logger) - return nil, "", &util.JSONResponse{ - Code: 400, - JSON: jsonerror.Unknown(fmt.Sprintf("Failed to upload")), - } - } - - writer.Flush() - - if bytesWritten != int64(contentLength) { - logger.WithFields(log.Fields{ - "bytesWritten": bytesWritten, - "contentLength": contentLength, - }).Warn("Fewer bytes written than expected") - } - - return hasher.Sum(nil), tmpDir, nil -} - // storeFileAndMetadata first moves a temporary file named content from tmpDir to its // final path (see getPathFromMediaMetadata for details.) Once the file is moved, the // metadata about the file is written into the media repository database. This order @@ -249,11 +200,27 @@ func Upload(req *http.Request, cfg *config.MediaAPI, db *storage.Database) util. // The file data is hashed and the hash is used as the MediaID. The hash is useful as a // method of deduplicating files to save storage, as well as a way to conduct // integrity checks on the file data in the repository. - hash, tmpDir, resErr := writeFileWithLimitAndHash(req.Body, cfg, logger, r.MediaMetadata.ContentLength) - if resErr != nil { - return *resErr + // bytesWritten is the total number of bytes written to disk + hash, _, bytesWritten, tmpDir, copyError := readAndHashAndWriteWithLimit(req.Body, cfg.MaxFileSizeBytes, cfg.AbsBasePath, nil) + + if copyError != nil { + logFields := log.Fields{ + "Origin": r.MediaMetadata.Origin, + "MediaID": r.MediaMetadata.MediaID, + } + if copyError == errFileIsTooLarge { + logFields["MaxFileSizeBytes"] = cfg.MaxFileSizeBytes + } + logger.WithError(copyError).WithFields(logFields).Warn("Error while transferring file") + removeDir(tmpDir, logger) + return util.JSONResponse{ + Code: 400, + JSON: jsonerror.Unknown(fmt.Sprintf("Failed to upload")), + } } - r.MediaMetadata.MediaID = types.MediaID(base64.URLEncoding.EncodeToString(hash[:])) + + r.MediaMetadata.ContentLength = bytesWritten + r.MediaMetadata.MediaID = types.MediaID(hash) logger.WithFields(log.Fields{ "MediaID": r.MediaMetadata.MediaID,
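
Usage note (illustrative, not part of the patch): the shared helper now serves both call
sites. The download path passes the client's http.ResponseWriter so the file is proxied
while being cached and hashed; the upload path passes nil, so a disk or hash error
(including errFileIsTooLarge) aborts the copy early:

	// download: proxy to the client while caching and hashing
	hash, bytesResponded, bytesWritten, tmpDir, err := readAndHashAndWriteWithLimit(resp.Body, maxFileSizeBytes, absBasePath, w)

	// upload: no response writer, so disk/hash errors stop the copy
	hash, _, bytesWritten, tmpDir, err := readAndHashAndWriteWithLimit(req.Body, cfg.MaxFileSizeBytes, cfg.AbsBasePath, nil)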
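
For comparison, here is a minimal standalone sketch of the strict single-pass
composition using io.TeeReader and io.MultiWriter, in the spirit of the removed
writeFileWithLimitAndHash. All names here (limitedWriter, readHashWrite, sink,
maxBytes) are illustrative and not part of dendrite. The patch deliberately avoids
this composition: io.TeeReader turns any tee-writer error into a read error and
aborts the whole copy, whereas readAndHashAndWriteWithLimit keeps proxying to the
client after a disk or hash failure or once maxFileSizeBytes is exceeded.

package main

import (
	"bytes"
	"crypto/sha256"
	"encoding/base64"
	"errors"
	"fmt"
	"io"
)

var errFileIsTooLarge = errors.New("file is too large")

// limitedWriter fails with errFileIsTooLarge once more than max bytes
// have been written, mirroring the disk-side limit in the patch.
type limitedWriter struct {
	w       io.Writer
	max     int64
	written int64
}

func (lw *limitedWriter) Write(p []byte) (int, error) {
	if lw.written+int64(len(p)) > lw.max {
		return 0, errFileIsTooLarge
	}
	n, err := lw.w.Write(p)
	lw.written += int64(n)
	return n, err
}

// readHashWrite copies r to respWriter while also hashing the data and writing
// it to sink, up to maxBytes on the sink side. Unlike the patch, any sink or
// hash error aborts the whole copy, because io.TeeReader surfaces a tee-writer
// error as a read error to io.Copy.
func readHashWrite(r io.Reader, respWriter, sink io.Writer, maxBytes int64) (string, int64, error) {
	hasher := sha256.New()
	limited := &limitedWriter{w: io.MultiWriter(sink, hasher), max: maxBytes}
	n, err := io.Copy(respWriter, io.TeeReader(r, limited))
	return base64.URLEncoding.EncodeToString(hasher.Sum(nil)), n, err
}

func main() {
	var resp, disk bytes.Buffer
	hash, n, err := readHashWrite(bytes.NewReader([]byte("hello world")), &resp, &disk, 32*1024)
	fmt.Println(hash, n, err)
}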