From a8b7130745657710b7fd9328de10ea76e06d4b90 Mon Sep 17 00:00:00 2001 From: Robert Swain Date: Wed, 17 May 2017 16:39:01 +0200 Subject: [PATCH] mediaapi/writers/download: Clean up copyToActiveAndPassive --- .../dendrite/mediaapi/writers/download.go | 40 +++++++++---------- 1 file changed, 18 insertions(+), 22 deletions(-) diff --git a/src/github.com/matrix-org/dendrite/mediaapi/writers/download.go b/src/github.com/matrix-org/dendrite/mediaapi/writers/download.go index 732824078..9a60401af 100644 --- a/src/github.com/matrix-org/dendrite/mediaapi/writers/download.go +++ b/src/github.com/matrix-org/dendrite/mediaapi/writers/download.go @@ -281,9 +281,9 @@ func (r *downloadRequest) createRemoteRequest() (*http.Response, *util.JSONRespo // If there is an error with the reader or the active writer, that is considered an error // If there is an error with the passive writer, that is non-critical and copying continues // maxFileSize limits the amount of data written to the passive writer -func copyToActiveAndPassive(r io.Reader, wActive io.Writer, wPassive io.Writer, maxFileSize types.ContentLength, mediaMetadata *types.MediaMetadata, logger *log.Entry) (int64, int64, error) { +func copyToActiveAndPassive(r io.Reader, wActive io.Writer, wPassive io.Writer, maxFileSize types.ContentLength, mediaMetadata *types.MediaMetadata) (int64, int64, error) { var bytesResponded, bytesWritten int64 = 0, 0 - var fetchError error + var copyError error // Note: the buffer size is the same as is used in io.Copy() buffer := make([]byte, 32*1024) for { @@ -293,20 +293,19 @@ func copyToActiveAndPassive(r io.Reader, wActive io.Writer, wPassive io.Writer, // write to client request's response body bytesTemp, respErr := wActive.Write(buffer[:bytesRead]) if bytesTemp != bytesRead || (respErr != nil && respErr != io.EOF) { - logger.Errorf("bytesTemp %v != bytesRead %v : %v", bytesTemp, bytesRead, respErr) - fetchError = errResponse + copyError = errResponse break } bytesResponded += int64(bytesTemp) - if fetchError == nil || (fetchError != errFileIsTooLarge && fetchError != errWrite) { - // if larger than cfg.MaxFileSize then stop writing to disk and discard cached file + if copyError == nil || (copyError != errFileIsTooLarge && copyError != errWrite) { + // if larger than maxFileSize then stop writing to disk and discard cached file if bytesWritten+int64(len(buffer)) > int64(maxFileSize) { - fetchError = errFileIsTooLarge + copyError = errFileIsTooLarge } else { // write to disk bytesTemp, writeErr := wPassive.Write(buffer[:bytesRead]) if writeErr != nil && writeErr != io.EOF { - fetchError = errWrite + copyError = errWrite } else { bytesWritten += int64(bytesTemp) } @@ -315,24 +314,13 @@ func copyToActiveAndPassive(r io.Reader, wActive io.Writer, wPassive io.Writer, } if readErr != nil { if readErr != io.EOF { - fetchError = errRead + copyError = errRead } break } } - if fetchError != nil { - logFields := log.Fields{ - "MediaID": mediaMetadata.MediaID, - "Origin": mediaMetadata.Origin, - } - if fetchError == errFileIsTooLarge { - logFields["MaxFileSize"] = maxFileSize - } - logger.WithFields(logFields).Warnln(fetchError) - } - - return bytesResponded, bytesWritten, fetchError + return bytesResponded, bytesWritten, copyError } func (r *downloadRequest) closeConnection(w http.ResponseWriter) { @@ -481,9 +469,17 @@ func (r *downloadRequest) respondFromRemoteFile(w http.ResponseWriter, basePath // bytesResponded is the total number of bytes written to the response to the client request // bytesWritten is the total number of bytes written to disk - bytesResponded, bytesWritten, fetchError := copyToActiveAndPassive(resp.Body, w, tmpFileWriter, maxFileSize, r.MediaMetadata, r.Logger) + bytesResponded, bytesWritten, fetchError := copyToActiveAndPassive(resp.Body, w, tmpFileWriter, maxFileSize, r.MediaMetadata) tmpFileWriter.Flush() if fetchError != nil { + logFields := log.Fields{ + "MediaID": r.MediaMetadata.MediaID, + "Origin": r.MediaMetadata.Origin, + } + if fetchError == errFileIsTooLarge { + logFields["MaxFileSize"] = maxFileSize + } + r.Logger.WithFields(logFields).Warnln(fetchError) tmpDirErr := os.RemoveAll(string(tmpDir)) if tmpDirErr != nil { r.Logger.Warnf("Failed to remove tmpDir (%v): %q\n", tmpDir, tmpDirErr)