diff --git a/src/github.com/matrix-org/dendrite/mediaapi/writers/download.go b/src/github.com/matrix-org/dendrite/mediaapi/writers/download.go
index f90a876d1..dd2ed966a 100644
--- a/src/github.com/matrix-org/dendrite/mediaapi/writers/download.go
+++ b/src/github.com/matrix-org/dendrite/mediaapi/writers/download.go
@@ -15,6 +15,7 @@ package writers
 
 import (
+	"bufio"
 	"database/sql"
 	"encoding/json"
 	"fmt"
@@ -277,6 +278,159 @@ func createRemoteRequest(mediaMetadata *types.MediaMetadata, logger *log.Entry)
 	return resp, nil
 }
 
+// FIXME: move to utils and use in upload as well
+func createTempFileWriter(basePath types.Path, logger *log.Entry) (*bufio.Writer, *os.File, types.Path, *util.JSONResponse) {
+	tmpDir, err := createTempDir(basePath)
+	if err != nil {
+		logger.Infof("Failed to create temp dir %q\n", err)
+		return nil, nil, "", &util.JSONResponse{
+			Code: 400,
+			JSON: jsonerror.Unknown(fmt.Sprintf("Failed to upload: %q", err)),
+		}
+	}
+	tmpFile, writer, err := createFileWriter(tmpDir, "content")
+	if err != nil {
+		logger.Infof("Failed to create file writer %q\n", err)
+		return nil, nil, "", &util.JSONResponse{
+			Code: 400,
+			JSON: jsonerror.Unknown(fmt.Sprintf("Failed to upload: %q", err)),
+		}
+	}
+	return writer, tmpFile, tmpDir, nil
+}
+
+// copyToActiveAndPassive works like io.Copy except it copies from the reader to both of the writers.
+// An error from the reader or the active writer aborts the copy and is returned.
+// An error from the passive writer is non-critical: copying to the active writer continues.
+// maxFileSize limits the amount of data written to the passive writer.
+func copyToActiveAndPassive(r io.Reader, wActive io.Writer, wPassive io.Writer, maxFileSize types.ContentLength, mediaMetadata *types.MediaMetadata, logger *log.Entry) (int64, int64, error) {
+	var bytesResponded, bytesWritten int64 = 0, 0
+	var fetchError error
+	// Note: the buffer size is the same as is used in io.Copy()
+	buffer := make([]byte, 32*1024)
+	for {
+		// read from remote request's response body
+		bytesRead, readErr := r.Read(buffer)
+		if bytesRead > 0 {
+			// write to client request's response body
+			bytesTemp, respErr := wActive.Write(buffer[:bytesRead])
+			if bytesTemp != bytesRead || (respErr != nil && respErr != io.EOF) {
+				logger.Errorf("bytesTemp %v != bytesRead %v : %v", bytesTemp, bytesRead, respErr)
+				fetchError = errResponse
+				break
+			}
+			bytesResponded += int64(bytesTemp)
+			if fetchError == nil || (fetchError != errFileIsTooLarge && fetchError != errWrite) {
+				// if larger than cfg.MaxFileSize then stop writing to disk and discard cached file
+				if bytesWritten+int64(len(buffer)) > int64(maxFileSize) {
+					fetchError = errFileIsTooLarge
+				} else {
+					// write to disk
+					bytesTemp, writeErr := wPassive.Write(buffer[:bytesRead])
+					if writeErr != nil && writeErr != io.EOF {
+						fetchError = errWrite
+					} else {
+						bytesWritten += int64(bytesTemp)
+					}
+				}
+			}
+		}
+		if readErr != nil {
+			if readErr != io.EOF {
+				fetchError = errRead
+			}
+			break
+		}
+	}
+
+	if fetchError != nil {
+		logFields := log.Fields{
+			"MediaID": mediaMetadata.MediaID,
+			"Origin":  mediaMetadata.Origin,
+		}
+		if fetchError == errFileIsTooLarge {
+			logFields["MaxFileSize"] = maxFileSize
+		}
+		logger.WithFields(logFields).Warnln(fetchError)
+	}
+
+	return bytesResponded, bytesWritten, fetchError
+}
+
+func closeConnection(w http.ResponseWriter, logger *log.Entry) {
+	logger.Println("Attempting to close the connection.")
+	hijacker, ok := w.(http.Hijacker)
+	if ok {
+		connection, _, hijackErr := hijacker.Hijack()
+		if hijackErr == nil {
+			logger.Println("Closing")
+			connection.Close()
+		} else {
+			logger.Printf("Error trying to hijack: %v", hijackErr)
+		}
+	}
+}
+
+func completeRemoteRequest(activeRemoteRequests *types.ActiveRemoteRequests, mxcURL string) {
+	if activeRemoteRequestCondition, ok := activeRemoteRequests.Set[mxcURL]; ok {
+		activeRemoteRequestCondition.Broadcast()
+	}
+	delete(activeRemoteRequests.Set, mxcURL)
+	activeRemoteRequests.Unlock()
+}
+
+func commitFileAndMetadata(tmpDir types.Path, basePath types.Path, mediaMetadata *types.MediaMetadata, activeRemoteRequests *types.ActiveRemoteRequests, db *storage.Database, mxcURL string, logger *log.Entry) bool {
+	updateActiveRemoteRequests := true
+
+	logger.WithFields(log.Fields{
+		"MediaID":             mediaMetadata.MediaID,
+		"Origin":              mediaMetadata.Origin,
+		"UploadName":          mediaMetadata.UploadName,
+		"Content-Length":      mediaMetadata.ContentLength,
+		"Content-Type":        mediaMetadata.ContentType,
+		"Content-Disposition": mediaMetadata.ContentDisposition,
+	}).Infof("Storing file metadata to media repository database")
+
+	// The database is the source of truth so we need to have moved the file first
+	err := moveFile(
+		types.Path(path.Join(string(tmpDir), "content")),
+		types.Path(getPathFromMediaMetadata(mediaMetadata, basePath)),
+	)
+	if err != nil {
+		tmpDirErr := os.RemoveAll(string(tmpDir))
+		if tmpDirErr != nil {
+			logger.Warnf("Failed to remove tmpDir (%v): %q\n", tmpDir, tmpDirErr)
+		}
+		return updateActiveRemoteRequests
+	}
+
+	// Writing the metadata to the media repository database and removing the mxcURL from activeRemoteRequests needs to be atomic.
+	// If it were not atomic, a new request for the same file could come in, in routine A, and check the database before the INSERT.
+	// Routine B, which was fetching, could then have its INSERT complete and remove the mxcURL from the activeRemoteRequests.
+	// If routine A then checked the activeRemoteRequests it would think it needed to fetch the file when it's already in the database.
+	// The locking below mitigates this situation.
+	updateActiveRemoteRequests = false
+	activeRemoteRequests.Lock()
+	// FIXME: unlock after timeout of db request
+	// if written to disk, add to db
+	err = db.StoreMediaMetadata(mediaMetadata)
+	if err != nil {
+		finalDir := path.Dir(getPathFromMediaMetadata(mediaMetadata, basePath))
+		finalDirErr := os.RemoveAll(finalDir)
+		if finalDirErr != nil {
+			logger.Warnf("Failed to remove finalDir (%v): %q\n", finalDir, finalDirErr)
+		}
+		completeRemoteRequest(activeRemoteRequests, mxcURL)
+		return updateActiveRemoteRequests
+	}
+	logger.WithFields(log.Fields{
+		"Origin":  mediaMetadata.Origin,
+		"MediaID": mediaMetadata.MediaID,
+	}).Infof("Signalling other goroutines waiting for us to fetch the file.")
+	completeRemoteRequest(activeRemoteRequests, mxcURL)
+	return updateActiveRemoteRequests
+}
+
 func respondFromRemoteFile(w http.ResponseWriter, logger *log.Entry, mediaMetadata *types.MediaMetadata, cfg config.MediaAPI, db *storage.Database, activeRemoteRequests *types.ActiveRemoteRequests) {
 	logger.WithFields(log.Fields{
 		"MediaID": mediaMetadata.MediaID,
@@ -293,14 +447,12 @@ func respondFromRemoteFile(w http.ResponseWriter, logger *log.Entry, mediaMetada
 	defer func() {
 		if updateActiveRemoteRequests {
 			activeRemoteRequests.Lock()
-			if activeRemoteRequestCondition, ok := activeRemoteRequests.Set[mxcURL]; ok {
-				activeRemoteRequestCondition.Broadcast()
-			}
-			delete(activeRemoteRequests.Set, mxcURL)
-			activeRemoteRequests.Unlock()
+			// Note that completeRemoteRequest unlocks activeRemoteRequests
+			completeRemoteRequest(activeRemoteRequests, mxcURL)
 		}
 	}()
 
+	// create request for remote file
 	resp, errorResponse := createRemoteRequest(mediaMetadata, logger)
 	if errorResponse != nil {
 		jsonErrorResponse(w, *errorResponse, logger)
@@ -335,22 +487,9 @@ func respondFromRemoteFile(w http.ResponseWriter, logger *log.Entry, mediaMetada
 	w.Header().Set("Content-Security-Policy", contentSecurityPolicy)
 
 	// create the temporary file writer
-	tmpDir, err := createTempDir(cfg.BasePath)
-	if err != nil {
-		logger.Infof("Failed to create temp dir %q\n", err)
-		jsonErrorResponse(w, util.JSONResponse{
-			Code: 400,
-			JSON: jsonerror.Unknown(fmt.Sprintf("Failed to upload: %q", err)),
-		}, logger)
-		return
-	}
-	tmpFile, writer, err := createFileWriter(tmpDir, "content")
-	if err != nil {
-		logger.Infof("Failed to create file writer %q\n", err)
-		jsonErrorResponse(w, util.JSONResponse{
-			Code: 400,
-			JSON: jsonerror.Unknown(fmt.Sprintf("Failed to upload: %q", err)),
-		}, logger)
+	tmpFileWriter, tmpFile, tmpDir, errorResponse := createTempFileWriter(cfg.BasePath, logger)
+	if errorResponse != nil {
+		jsonErrorResponse(w, *errorResponse, logger)
 		return
 	}
 	defer tmpFile.Close()
@@ -364,56 +503,9 @@ func respondFromRemoteFile(w http.ResponseWriter, logger *log.Entry, mediaMetada
 
 	// bytesResponded is the total number of bytes written to the response to the client request
 	// bytesWritten is the total number of bytes written to disk
-	var bytesResponded, bytesWritten int64 = 0, 0
-	var fetchError error
-	// Note: the buffer size is the same as is used in io.Copy()
-	buffer := make([]byte, 32*1024)
-	for {
-		// read from remote request's response body
-		bytesRead, readErr := resp.Body.Read(buffer)
-		if bytesRead > 0 {
-			// write to client request's response body
-			bytesTemp, respErr := w.Write(buffer[:bytesRead])
-			if bytesTemp != bytesRead || (respErr != nil && respErr != io.EOF) {
-				logger.Errorf("bytesTemp %v != bytesRead %v : %v", bytesTemp, bytesRead, respErr)
-				fetchError = errResponse
-				break
-			}
-			bytesResponded += int64(bytesTemp)
-			if fetchError == nil || (fetchError != errFileIsTooLarge && fetchError != errWrite) {
-				// if larger than cfg.MaxFileSize then stop writing to disk and discard cached file
-				if bytesWritten+int64(len(buffer)) > int64(cfg.MaxFileSize) {
-					fetchError = errFileIsTooLarge
-				} else {
-					// write to disk
-					bytesTemp, writeErr := writer.Write(buffer[:bytesRead])
-					if writeErr != nil && writeErr != io.EOF {
-						fetchError = errWrite
-					} else {
-						bytesWritten += int64(bytesTemp)
-					}
-				}
-			}
-		}
-		if readErr != nil {
-			if readErr != io.EOF {
-				fetchError = errRead
-			}
-			break
-		}
-	}
-
-	writer.Flush()
-
+	bytesResponded, bytesWritten, fetchError := copyToActiveAndPassive(resp.Body, w, tmpFileWriter, cfg.MaxFileSize, mediaMetadata, logger)
+	tmpFileWriter.Flush()
 	if fetchError != nil {
-		logFields := log.Fields{
-			"MediaID": mediaMetadata.MediaID,
-			"Origin":  mediaMetadata.Origin,
-		}
-		if fetchError == errFileIsTooLarge {
-			logFields["MaxFileSize"] = cfg.MaxFileSize
-		}
-		logger.WithFields(logFields).Warnln(fetchError)
 		tmpDirErr := os.RemoveAll(string(tmpDir))
 		if tmpDirErr != nil {
 			logger.Warnf("Failed to remove tmpDir (%v): %q\n", tmpDir, tmpDirErr)
 		}
@@ -427,17 +519,7 @@ func respondFromRemoteFile(w http.ResponseWriter, logger *log.Entry, mediaMetada
 	} else {
 		// We attempt to bluntly close the connection because that is the
 		// best thing we can do after we've sent a 200 OK
-		logger.Println("Attempting to close the connection.")
-		hijacker, ok := w.(http.Hijacker)
-		if ok {
-			connection, _, hijackErr := hijacker.Hijack()
-			if hijackErr == nil {
-				logger.Println("Closing")
-				connection.Close()
-			} else {
-				logger.Printf("Error trying to hijack: %v", hijackErr)
-			}
-		}
+		closeConnection(w, logger)
 	}
 	return
 }
@@ -446,6 +528,7 @@ func respondFromRemoteFile(w http.ResponseWriter, logger *log.Entry, mediaMetada
 
 	// Note: After this point we have responded to the client's request and are just dealing with local caching.
 	// As we have responded with 200 OK, any errors are ineffectual to the client request and so we just log and return.
+	// FIXME: Does continuing to do work here that is ineffectual to the client have any bad side effects? Could we fire off the remainder in a separate goroutine to mitigate that?
 
 	// It's possible the bytesWritten to the temporary file is different to the reported Content-Length from the remote
 	// request's response. bytesWritten is therefore used as it is what would be sent to clients when reading from the local
@@ -453,60 +536,7 @@
 	mediaMetadata.ContentLength = types.ContentLength(bytesWritten)
 	mediaMetadata.UserID = types.MatrixUserID("@:" + string(mediaMetadata.Origin))
 
-	logger.WithFields(log.Fields{
-		"MediaID":             mediaMetadata.MediaID,
-		"Origin":              mediaMetadata.Origin,
-		"UploadName":          mediaMetadata.UploadName,
-		"Content-Length":      mediaMetadata.ContentLength,
-		"Content-Type":        mediaMetadata.ContentType,
-		"Content-Disposition": mediaMetadata.ContentDisposition,
-	}).Infof("Storing file metadata to media repository database")
-
-	// The database is the source of truth so we need to have moved the file first
-	err = moveFile(
-		types.Path(path.Join(string(tmpDir), "content")),
-		types.Path(getPathFromMediaMetadata(mediaMetadata, cfg.BasePath)),
-	)
-	if err != nil {
-		tmpDirErr := os.RemoveAll(string(tmpDir))
-		if tmpDirErr != nil {
-			logger.Warnf("Failed to remove tmpDir (%v): %q\n", tmpDir, tmpDirErr)
-		}
-		return
-	}
-
-	// Writing the metadata to the media repository database and removing the mxcURL from activeRemoteRequests needs to be atomic.
-	// If it were not atomic, a new request for the same file could come in in routine A and check the database before the INSERT.
-	// Routine B which was fetching could then have its INSERT complete and remove the mxcURL from the activeRemoteRequests.
-	// If routine A then checked the activeRemoteRequests it would think it needed to fetch the file when it's already in the database.
-	// The locking below mitigates this situation.
-	updateActiveRemoteRequests = false
-	activeRemoteRequests.Lock()
-	// FIXME: unlock after timeout of db request
-	// if written to disk, add to db
-	err = db.StoreMediaMetadata(mediaMetadata)
-	if err != nil {
-		finalDir := path.Dir(getPathFromMediaMetadata(mediaMetadata, cfg.BasePath))
-		finalDirErr := os.RemoveAll(finalDir)
-		if finalDirErr != nil {
-			logger.Warnf("Failed to remove finalDir (%v): %q\n", finalDir, finalDirErr)
-		}
-		if activeRemoteRequestCondition, ok := activeRemoteRequests.Set[mxcURL]; ok {
-			activeRemoteRequestCondition.Broadcast()
-		}
-		delete(activeRemoteRequests.Set, mxcURL)
-		activeRemoteRequests.Unlock()
-		return
-	}
-	logger.WithFields(log.Fields{
-		"Origin":  mediaMetadata.Origin,
-		"MediaID": mediaMetadata.MediaID,
-	}).Infof("Signalling other goroutines waiting for us to fetch the file.")
-	if activeRemoteRequestCondition, ok := activeRemoteRequests.Set[mxcURL]; ok {
-		activeRemoteRequestCondition.Broadcast()
-	}
-	delete(activeRemoteRequests.Set, mxcURL)
-	activeRemoteRequests.Unlock()
+	updateActiveRemoteRequests = commitFileAndMetadata(tmpDir, cfg.BasePath, mediaMetadata, activeRemoteRequests, db, mxcURL, logger)
 
 	// TODO: generate thumbnails
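
A few of the helpers above encode patterns worth spelling out with standalone sketches; none of the code below is part of the patch. First, copyToActiveAndPassive exists because io.MultiWriter would not do: io.MultiWriter aborts the whole copy as soon as any writer fails, while here the passive (disk cache) writer must be allowed to fail, or hit the size cap, without interrupting the response that is already streaming to the client. A minimal sketch of that asymmetric tee, under hypothetical names (teeCopy, errTooLarge):

package main

import (
	"bytes"
	"errors"
	"fmt"
	"io"
	"os"
	"strings"
)

var errTooLarge = errors.New("passive writer: size limit exceeded")

// teeCopy streams r to active and mirrors the data to passive until maxPassive
// bytes would be exceeded; after that, passive writes stop but the copy to
// active carries on. It returns the bytes sent to each writer.
func teeCopy(r io.Reader, active, passive io.Writer, maxPassive int64) (int64, int64, error) {
	buf := make([]byte, 32*1024) // the same buffer size io.Copy uses internally
	var sent, cached int64
	var passiveErr error
	for {
		n, readErr := r.Read(buf)
		if n > 0 {
			wn, werr := active.Write(buf[:n])
			sent += int64(wn)
			if werr != nil {
				return sent, cached, werr // active failures abort the copy
			}
			if passiveErr == nil {
				if cached+int64(n) > maxPassive {
					passiveErr = errTooLarge // stop caching, keep streaming
				} else if _, werr := passive.Write(buf[:n]); werr != nil {
					passiveErr = werr // non-fatal: remember it and carry on
				} else {
					cached += int64(n)
				}
			}
		}
		if readErr == io.EOF {
			return sent, cached, passiveErr
		}
		if readErr != nil {
			return sent, cached, readErr
		}
	}
}

func main() {
	var cache bytes.Buffer
	// The 8-byte cap forces the passive path to give up while the client
	// (stdout here) still receives the full body.
	sent, cached, err := teeCopy(strings.NewReader("hello, world\n"), os.Stdout, &cache, 8)
	fmt.Printf("sent=%d cached=%d err=%v\n", sent, cached, err)
}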
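
closeConnection relies on http.Hijacker. Once the 200 OK and part of the body are on the wire, the handler can no longer change the status code, so the only remaining way to signal failure is to drop the TCP connection and let the client observe a truncated body. A sketch of the mechanism in an ordinary handler (hypothetical, not from the patch):

package main

import (
	"fmt"
	"log"
	"net/http"
)

func handler(w http.ResponseWriter, r *http.Request) {
	// Writing the body implicitly sends the 200 OK status line first.
	fmt.Fprint(w, "partial body...")
	if f, ok := w.(http.Flusher); ok {
		f.Flush() // push what we have onto the wire
	}
	// Suppose the upstream read fails here: the status cannot be changed
	// any more, so bluntly close the connection instead.
	hj, ok := w.(http.Hijacker)
	if !ok {
		return // e.g. HTTP/2 ResponseWriters do not support hijacking
	}
	conn, _, err := hj.Hijack()
	if err != nil {
		log.Printf("hijack failed: %v", err)
		return
	}
	conn.Close() // the client sees an abruptly truncated response
}

func main() {
	http.HandleFunc("/", handler)
	log.Fatal(http.ListenAndServe("127.0.0.1:8080", nil))
}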
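
The activeRemoteRequests set together with completeRemoteRequest is request deduplication built on sync.Cond: the first goroutine to ask for an mxc URL registers a condition variable and fetches; later goroutines wait on it and are woken by Broadcast. Doing the Broadcast and the map deletion under the same lock as the database INSERT is what makes the check-database-then-check-in-flight sequence race-free, as the comment in commitFileAndMetadata explains. A self-contained sketch of the pattern with simplified, hypothetical types (activeRequests, begin, complete); note that unlike the patch's completeRemoteRequest, which expects the caller to already hold the lock, complete here takes the lock itself:

package main

import (
	"fmt"
	"sync"
	"time"
)

type activeRequests struct {
	mu  sync.Mutex
	set map[string]*sync.Cond
}

// begin reports whether the caller should perform the fetch (true) or has
// waited for another goroutine's in-flight fetch to finish (false).
func (a *activeRequests) begin(url string) bool {
	a.mu.Lock()
	defer a.mu.Unlock()
	if cond, ok := a.set[url]; ok {
		// Wait atomically releases a.mu while blocked and reacquires it on
		// wakeup; the loop guards against spurious wakeups.
		for a.set[url] == cond {
			cond.Wait()
		}
		return false
	}
	a.set[url] = sync.NewCond(&a.mu)
	return true
}

// complete wakes all waiters and removes the in-flight marker.
func (a *activeRequests) complete(url string) {
	a.mu.Lock()
	defer a.mu.Unlock()
	if cond, ok := a.set[url]; ok {
		cond.Broadcast()
		delete(a.set, url)
	}
}

func main() {
	a := &activeRequests{set: make(map[string]*sync.Cond)}
	var wg sync.WaitGroup
	for i := 0; i < 3; i++ {
		wg.Add(1)
		go func(id int) {
			defer wg.Done()
			if a.begin("mxc://example.org/abc") {
				fmt.Printf("goroutine %d: fetching\n", id)
				time.Sleep(50 * time.Millisecond) // simulate the remote fetch
				a.complete("mxc://example.org/abc")
				return
			}
			// A goroutine scheduled after complete() would fetch again;
			// the real code avoids that by checking the database first.
			fmt.Printf("goroutine %d: waited for in-flight fetch\n", id)
		}(i)
	}
	wg.Wait()
}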
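
Finally, createTempFileWriter, moveFile, and commitFileAndMetadata together implement a write-to-temp-then-rename commit: the download lands in a private temp directory and is only moved into its final, database-visible location after it has been fully flushed and closed, so the database (the source of truth) never references a partial file. A sketch of that commit discipline, again with a hypothetical name (commitFile) rather than the patch's actual helpers, which are not shown in this diff:

package main

import (
	"bufio"
	"log"
	"os"
	"path/filepath"
)

// commitFile writes via a buffered writer into a private temp directory and
// renames the result into finalPath only after everything has been flushed
// and closed, so readers never observe a partially written file.
func commitFile(basePath, finalPath string, write func(*bufio.Writer) error) error {
	tmpDir, err := os.MkdirTemp(basePath, "tmp")
	if err != nil {
		return err
	}
	// After a successful rename this removes only the now-empty temp dir;
	// after a failure it discards the partial download.
	defer os.RemoveAll(tmpDir)

	tmpFile, err := os.Create(filepath.Join(tmpDir, "content"))
	if err != nil {
		return err
	}
	w := bufio.NewWriter(tmpFile)
	if err := write(w); err != nil {
		tmpFile.Close()
		return err
	}
	if err := w.Flush(); err != nil {
		tmpFile.Close()
		return err
	}
	if err := tmpFile.Close(); err != nil {
		return err
	}
	if err := os.MkdirAll(filepath.Dir(finalPath), 0o770); err != nil {
		return err
	}
	// os.Rename is atomic when source and destination are on the same
	// filesystem, which is why the temp dir lives under the media base path.
	return os.Rename(filepath.Join(tmpDir, "content"), finalPath)
}

func main() {
	base, err := os.MkdirTemp("", "media")
	if err != nil {
		log.Fatal(err)
	}
	final := filepath.Join(base, "remote", "example.org", "mediaID")
	err = commitFile(base, final, func(w *bufio.Writer) error {
		_, err := w.WriteString("file contents\n")
		return err
	})
	if err != nil {
		log.Fatal(err)
	}
	log.Println("committed", final)
}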