From 9c29a31e7e07b961898a7509388bfe740ede5f72 Mon Sep 17 00:00:00 2001 From: Robert Swain Date: Thu, 11 May 2017 09:19:34 +0200 Subject: [PATCH] mediaapi/writers/download: Factor out respondFromRemoteFile --- .../dendrite/mediaapi/writers/download.go | 512 +++++++++--------- 1 file changed, 271 insertions(+), 241 deletions(-) diff --git a/src/github.com/matrix-org/dendrite/mediaapi/writers/download.go b/src/github.com/matrix-org/dendrite/mediaapi/writers/download.go index 4a578af5f..f70608292 100644 --- a/src/github.com/matrix-org/dendrite/mediaapi/writers/download.go +++ b/src/github.com/matrix-org/dendrite/mediaapi/writers/download.go @@ -164,247 +164,7 @@ func Download(w http.ResponseWriter, req *http.Request, origin types.ServerName, } } - // FIXME: Only request once (would race if multiple requests for the same remote file) - // Use a hash set based on the origin and media ID (the request URL should be fine...) and synchronise adding / removing members - urls := getMatrixUrls(r.MediaMetadata.Origin) - - logger.Printf("Connecting to remote %q\n", urls[0]) - - remoteReqAddr := urls[0] + "/_matrix/media/v1/download/" + string(r.MediaMetadata.Origin) + "/" + string(r.MediaMetadata.MediaID) - remoteReq, err := http.NewRequest("GET", remoteReqAddr, nil) - if err != nil { - jsonErrorResponse(w, util.JSONResponse{ - Code: 500, - JSON: jsonerror.Unknown(fmt.Sprintf("File with media ID %q could not be downloaded from %q", r.MediaMetadata.MediaID, r.MediaMetadata.Origin)), - }, logger) - return - } - - remoteReq.Header.Set("Host", string(r.MediaMetadata.Origin)) - - client := http.Client{} - resp, err := client.Do(remoteReq) - if err != nil { - jsonErrorResponse(w, util.JSONResponse{ - Code: 502, - JSON: jsonerror.Unknown(fmt.Sprintf("File with media ID %q could not be downloaded from %q", r.MediaMetadata.MediaID, r.MediaMetadata.Origin)), - }, logger) - return - } - defer resp.Body.Close() - - if resp.StatusCode != 200 { - logger.Printf("Server responded with %d\n", resp.StatusCode) - if resp.StatusCode == 404 { - jsonErrorResponse(w, util.JSONResponse{ - Code: 404, - JSON: jsonerror.NotFound(fmt.Sprintf("File with media ID %q does not exist", r.MediaMetadata.MediaID)), - }, logger) - return - } - jsonErrorResponse(w, util.JSONResponse{ - Code: 502, - JSON: jsonerror.Unknown(fmt.Sprintf("File with media ID %q could not be downloaded from %q", r.MediaMetadata.MediaID, r.MediaMetadata.Origin)), - }, logger) - return - } - - contentLength, err := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64) - if err != nil { - logger.Warn("Failed to parse content length") - } - r.MediaMetadata.ContentLength = types.ContentLength(contentLength) - - r.MediaMetadata.ContentType = types.ContentType(resp.Header.Get("Content-Type")) - r.MediaMetadata.ContentDisposition = types.ContentDisposition(resp.Header.Get("Content-Disposition")) - // FIXME: parse from Content-Disposition header if possible, else fall back - //r.MediaMetadata.UploadName = types.Filename() - - logger.WithFields(log.Fields{ - "MediaID": r.MediaMetadata.MediaID, - "Origin": r.MediaMetadata.Origin, - }).Infof("Connected to remote") - - w.Header().Set("Content-Type", string(r.MediaMetadata.ContentType)) - w.Header().Set("Content-Length", strconv.FormatInt(int64(r.MediaMetadata.ContentLength), 10)) - contentSecurityPolicy := "default-src 'none';" + - " script-src 'none';" + - " plugin-types application/pdf;" + - " style-src 'unsafe-inline';" + - " object-src 'self';" - w.Header().Set("Content-Security-Policy", contentSecurityPolicy) - - tmpDir, err := createTempDir(cfg.BasePath) - if err != nil { - logger.Infof("Failed to create temp dir %q\n", err) - jsonErrorResponse(w, util.JSONResponse{ - Code: 400, - JSON: jsonerror.Unknown(fmt.Sprintf("Failed to upload: %q", err)), - }, logger) - return - } - tmpFile, writer, err := createFileWriter(tmpDir, "content") - if err != nil { - logger.Infof("Failed to create file writer %q\n", err) - jsonErrorResponse(w, util.JSONResponse{ - Code: 400, - JSON: jsonerror.Unknown(fmt.Sprintf("Failed to upload: %q", err)), - }, logger) - return - } - defer tmpFile.Close() - - logger.WithFields(log.Fields{ - "MediaID": r.MediaMetadata.MediaID, - "Origin": r.MediaMetadata.Origin, - }).Infof("Proxying and caching remote file") - - // bytesResponded is the total number of bytes written to the response to the client request - // bytesWritten is the total number of bytes written to disk - var bytesResponded, bytesWritten int64 = 0, 0 - var fetchError error - // Note: the buffer size is the same as is used in io.Copy() - buffer := make([]byte, 32*1024) - for { - // read from remote request's response body - bytesRead, readErr := resp.Body.Read(buffer) - if bytesRead > 0 { - // write to client request's response body - bytesTemp, respErr := w.Write(buffer[:bytesRead]) - if bytesTemp != bytesRead || (respErr != nil && respErr != io.EOF) { - logger.Errorf("bytesTemp %v != bytesRead %v : %v", bytesTemp, bytesRead, respErr) - fetchError = errResponse - break - } - bytesResponded += int64(bytesTemp) - if fetchError == nil || (fetchError != errFileIsTooLarge && fetchError != errWrite) { - // if larger than cfg.MaxFileSize then stop writing to disk and discard cached file - if bytesWritten+int64(len(buffer)) > int64(cfg.MaxFileSize) { - fetchError = errFileIsTooLarge - } else { - // write to disk - bytesTemp, writeErr := writer.Write(buffer[:bytesRead]) - if writeErr != nil && writeErr != io.EOF { - fetchError = errWrite - } else { - bytesWritten += int64(bytesTemp) - } - } - } - } - if readErr != nil { - if readErr != io.EOF { - fetchError = errRead - } - break - } - } - - writer.Flush() - - if fetchError != nil { - logFields := log.Fields{ - "MediaID": r.MediaMetadata.MediaID, - "Origin": r.MediaMetadata.Origin, - } - if fetchError == errFileIsTooLarge { - logFields["MaxFileSize"] = cfg.MaxFileSize - } - logger.WithFields(logFields).Warnln(fetchError) - tmpDirErr := os.RemoveAll(string(tmpDir)) - if tmpDirErr != nil { - logger.Warnf("Failed to remove tmpDir (%v): %q\n", tmpDir, tmpDirErr) - } - // Note: if we have responded with any data in the body at all then we have already sent 200 OK and we can only abort at this point - if bytesResponded < 1 { - jsonErrorResponse(w, util.JSONResponse{ - Code: 502, - JSON: jsonerror.Unknown(fmt.Sprintf("File with media ID %q could not be downloaded from %q", r.MediaMetadata.MediaID, r.MediaMetadata.Origin)), - }, logger) - } else { - // We attempt to bluntly close the connection because that is the - // best thing we can do after we've sent a 200 OK - logger.Println("Attempting to close the connection.") - hijacker, ok := w.(http.Hijacker) - if ok { - connection, _, hijackErr := hijacker.Hijack() - if hijackErr == nil { - logger.Println("Closing") - connection.Close() - } else { - logger.Printf("Error trying to hijack: %v", hijackErr) - } - } - } - return - } - - // Note: After this point we have responded to the client's request and are just dealing with local caching. - // As we have responded with 200 OK, any errors are ineffectual to the client request and so we just log and return. - - r.MediaMetadata.ContentLength = types.ContentLength(bytesWritten) - r.MediaMetadata.UserID = types.MatrixUserID("@:" + string(r.MediaMetadata.Origin)) - - logger.WithFields(log.Fields{ - "MediaID": r.MediaMetadata.MediaID, - "Origin": r.MediaMetadata.Origin, - "UploadName": r.MediaMetadata.UploadName, - "Content-Length": r.MediaMetadata.ContentLength, - "Content-Type": r.MediaMetadata.ContentType, - "Content-Disposition": r.MediaMetadata.ContentDisposition, - }).Infof("Storing file metadata to media repository database") - - // The database is the source of truth so we need to have moved the file first - err = moveFile( - types.Path(path.Join(string(tmpDir), "content")), - types.Path(getPathFromMediaMetadata(r.MediaMetadata, cfg.BasePath)), - ) - if err != nil { - tmpDirErr := os.RemoveAll(string(tmpDir)) - if tmpDirErr != nil { - logger.Warnf("Failed to remove tmpDir (%v): %q\n", tmpDir, tmpDirErr) - } - return - } - - // Writing the metadata to the media repository database and removing the mxcURL from activeRemoteRequests needs to be atomic. - // If it were not atomic, a new request for the same file could come in in routine A and check the database before the INSERT. - // Routine B which was fetching could then have its INSERT complete and remove the mxcURL from the activeRemoteRequests. - // If routine A then checked the activeRemoteRequests it would think it needed to fetch the file when it's already in the database. - // The locking below mitigates this situation. - activeRemoteRequests.Lock() - // FIXME: unlock after timeout of db request - // if written to disk, add to db - err = db.StoreMediaMetadata(r.MediaMetadata) - if err != nil { - finalDir := path.Dir(getPathFromMediaMetadata(r.MediaMetadata, cfg.BasePath)) - finalDirErr := os.RemoveAll(finalDir) - if finalDirErr != nil { - logger.Warnf("Failed to remove finalDir (%v): %q\n", finalDir, finalDirErr) - } - delete(activeRemoteRequests.Set, mxcURL) - activeRemoteRequests.Unlock() - return - } - activeRemoteRequestCondition, _ := activeRemoteRequests.Set[mxcURL] - logger.WithFields(log.Fields{ - "Origin": r.MediaMetadata.Origin, - "MediaID": r.MediaMetadata.MediaID, - }).Infof("Signalling other goroutines waiting for us to fetch the file.") - activeRemoteRequestCondition.Broadcast() - delete(activeRemoteRequests.Set, mxcURL) - activeRemoteRequests.Unlock() - - // TODO: generate thumbnails - - logger.WithFields(log.Fields{ - "MediaID": r.MediaMetadata.MediaID, - "Origin": r.MediaMetadata.Origin, - "UploadName": r.MediaMetadata.UploadName, - "Content-Length": r.MediaMetadata.ContentLength, - "Content-Type": r.MediaMetadata.ContentType, - "Content-Disposition": r.MediaMetadata.ContentDisposition, - }).Infof("Remote file cached") + respondFromRemoteFile(w, logger, r.MediaMetadata, cfg, db, activeRemoteRequests) } else { // TODO: If we do not have a record and the origin is local, or if we have another error from the database, the file is not found jsonErrorResponse(w, util.JSONResponse{ @@ -472,6 +232,276 @@ func respondFromLocalFile(w http.ResponseWriter, logger *log.Entry, mediaMetadat } } +func respondFromRemoteFile(w http.ResponseWriter, logger *log.Entry, mediaMetadata *types.MediaMetadata, cfg config.MediaAPI, db *storage.Database, activeRemoteRequests *types.ActiveRemoteRequests) { + logger.WithFields(log.Fields{ + "MediaID": mediaMetadata.MediaID, + "Origin": mediaMetadata.Origin, + }).Infof("Fetching remote file") + + mxcURL := "mxc://" + string(mediaMetadata.Origin) + "/" + string(mediaMetadata.MediaID) + + // If we hit an error and we return early, we need to lock, broadcast on the condition, delete the condition and unlock. + // If we return normally we have slightly different locking around the storage of metadata to the database and deletion of the condition. + // As such, this deferred cleanup of the sync.Cond is conditional. + // This approach seems safer than potentially missing this cleanup in error cases. + updateActiveRemoteRequests := true + defer func() { + if updateActiveRemoteRequests { + activeRemoteRequests.Lock() + if activeRemoteRequestCondition, ok := activeRemoteRequests.Set[mxcURL]; ok { + activeRemoteRequestCondition.Broadcast() + } + delete(activeRemoteRequests.Set, mxcURL) + activeRemoteRequests.Unlock() + } + }() + + urls := getMatrixUrls(mediaMetadata.Origin) + + logger.Printf("Connecting to remote %q\n", urls[0]) + + remoteReqAddr := urls[0] + "/_matrix/media/v1/download/" + string(mediaMetadata.Origin) + "/" + string(mediaMetadata.MediaID) + remoteReq, err := http.NewRequest("GET", remoteReqAddr, nil) + if err != nil { + jsonErrorResponse(w, util.JSONResponse{ + Code: 500, + JSON: jsonerror.Unknown(fmt.Sprintf("File with media ID %q could not be downloaded from %q", mediaMetadata.MediaID, mediaMetadata.Origin)), + }, logger) + return + } + + remoteReq.Header.Set("Host", string(mediaMetadata.Origin)) + + client := http.Client{} + resp, err := client.Do(remoteReq) + if err != nil { + jsonErrorResponse(w, util.JSONResponse{ + Code: 502, + JSON: jsonerror.Unknown(fmt.Sprintf("File with media ID %q could not be downloaded from %q", mediaMetadata.MediaID, mediaMetadata.Origin)), + }, logger) + return + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + logger.Printf("Server responded with %d\n", resp.StatusCode) + if resp.StatusCode == 404 { + jsonErrorResponse(w, util.JSONResponse{ + Code: 404, + JSON: jsonerror.NotFound(fmt.Sprintf("File with media ID %q does not exist", mediaMetadata.MediaID)), + }, logger) + return + } + jsonErrorResponse(w, util.JSONResponse{ + Code: 502, + JSON: jsonerror.Unknown(fmt.Sprintf("File with media ID %q could not be downloaded from %q", mediaMetadata.MediaID, mediaMetadata.Origin)), + }, logger) + return + } + + contentLength, err := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64) + if err != nil { + logger.Warn("Failed to parse content length") + } + mediaMetadata.ContentLength = types.ContentLength(contentLength) + + mediaMetadata.ContentType = types.ContentType(resp.Header.Get("Content-Type")) + mediaMetadata.ContentDisposition = types.ContentDisposition(resp.Header.Get("Content-Disposition")) + // FIXME: parse from Content-Disposition header if possible, else fall back + //mediaMetadata.UploadName = types.Filename() + + logger.WithFields(log.Fields{ + "MediaID": mediaMetadata.MediaID, + "Origin": mediaMetadata.Origin, + }).Infof("Connected to remote") + + w.Header().Set("Content-Type", string(mediaMetadata.ContentType)) + w.Header().Set("Content-Length", strconv.FormatInt(int64(mediaMetadata.ContentLength), 10)) + contentSecurityPolicy := "default-src 'none';" + + " script-src 'none';" + + " plugin-types application/pdf;" + + " style-src 'unsafe-inline';" + + " object-src 'self';" + w.Header().Set("Content-Security-Policy", contentSecurityPolicy) + + tmpDir, err := createTempDir(cfg.BasePath) + if err != nil { + logger.Infof("Failed to create temp dir %q\n", err) + jsonErrorResponse(w, util.JSONResponse{ + Code: 400, + JSON: jsonerror.Unknown(fmt.Sprintf("Failed to upload: %q", err)), + }, logger) + return + } + tmpFile, writer, err := createFileWriter(tmpDir, "content") + if err != nil { + logger.Infof("Failed to create file writer %q\n", err) + jsonErrorResponse(w, util.JSONResponse{ + Code: 400, + JSON: jsonerror.Unknown(fmt.Sprintf("Failed to upload: %q", err)), + }, logger) + return + } + defer tmpFile.Close() + + logger.WithFields(log.Fields{ + "MediaID": mediaMetadata.MediaID, + "Origin": mediaMetadata.Origin, + }).Infof("Proxying and caching remote file") + + // bytesResponded is the total number of bytes written to the response to the client request + // bytesWritten is the total number of bytes written to disk + var bytesResponded, bytesWritten int64 = 0, 0 + var fetchError error + // Note: the buffer size is the same as is used in io.Copy() + buffer := make([]byte, 32*1024) + for { + // read from remote request's response body + bytesRead, readErr := resp.Body.Read(buffer) + if bytesRead > 0 { + // write to client request's response body + bytesTemp, respErr := w.Write(buffer[:bytesRead]) + if bytesTemp != bytesRead || (respErr != nil && respErr != io.EOF) { + logger.Errorf("bytesTemp %v != bytesRead %v : %v", bytesTemp, bytesRead, respErr) + fetchError = errResponse + break + } + bytesResponded += int64(bytesTemp) + if fetchError == nil || (fetchError != errFileIsTooLarge && fetchError != errWrite) { + // if larger than cfg.MaxFileSize then stop writing to disk and discard cached file + if bytesWritten+int64(len(buffer)) > int64(cfg.MaxFileSize) { + fetchError = errFileIsTooLarge + } else { + // write to disk + bytesTemp, writeErr := writer.Write(buffer[:bytesRead]) + if writeErr != nil && writeErr != io.EOF { + fetchError = errWrite + } else { + bytesWritten += int64(bytesTemp) + } + } + } + } + if readErr != nil { + if readErr != io.EOF { + fetchError = errRead + } + break + } + } + + writer.Flush() + + if fetchError != nil { + logFields := log.Fields{ + "MediaID": mediaMetadata.MediaID, + "Origin": mediaMetadata.Origin, + } + if fetchError == errFileIsTooLarge { + logFields["MaxFileSize"] = cfg.MaxFileSize + } + logger.WithFields(logFields).Warnln(fetchError) + tmpDirErr := os.RemoveAll(string(tmpDir)) + if tmpDirErr != nil { + logger.Warnf("Failed to remove tmpDir (%v): %q\n", tmpDir, tmpDirErr) + } + // Note: if we have responded with any data in the body at all then we have already sent 200 OK and we can only abort at this point + if bytesResponded < 1 { + jsonErrorResponse(w, util.JSONResponse{ + Code: 502, + JSON: jsonerror.Unknown(fmt.Sprintf("File with media ID %q could not be downloaded from %q", mediaMetadata.MediaID, mediaMetadata.Origin)), + }, logger) + } else { + // We attempt to bluntly close the connection because that is the + // best thing we can do after we've sent a 200 OK + logger.Println("Attempting to close the connection.") + hijacker, ok := w.(http.Hijacker) + if ok { + connection, _, hijackErr := hijacker.Hijack() + if hijackErr == nil { + logger.Println("Closing") + connection.Close() + } else { + logger.Printf("Error trying to hijack: %v", hijackErr) + } + } + } + return + } + + // Note: After this point we have responded to the client's request and are just dealing with local caching. + // As we have responded with 200 OK, any errors are ineffectual to the client request and so we just log and return. + + mediaMetadata.ContentLength = types.ContentLength(bytesWritten) + mediaMetadata.UserID = types.MatrixUserID("@:" + string(mediaMetadata.Origin)) + + logger.WithFields(log.Fields{ + "MediaID": mediaMetadata.MediaID, + "Origin": mediaMetadata.Origin, + "UploadName": mediaMetadata.UploadName, + "Content-Length": mediaMetadata.ContentLength, + "Content-Type": mediaMetadata.ContentType, + "Content-Disposition": mediaMetadata.ContentDisposition, + }).Infof("Storing file metadata to media repository database") + + // The database is the source of truth so we need to have moved the file first + err = moveFile( + types.Path(path.Join(string(tmpDir), "content")), + types.Path(getPathFromMediaMetadata(mediaMetadata, cfg.BasePath)), + ) + if err != nil { + tmpDirErr := os.RemoveAll(string(tmpDir)) + if tmpDirErr != nil { + logger.Warnf("Failed to remove tmpDir (%v): %q\n", tmpDir, tmpDirErr) + } + return + } + + // Writing the metadata to the media repository database and removing the mxcURL from activeRemoteRequests needs to be atomic. + // If it were not atomic, a new request for the same file could come in in routine A and check the database before the INSERT. + // Routine B which was fetching could then have its INSERT complete and remove the mxcURL from the activeRemoteRequests. + // If routine A then checked the activeRemoteRequests it would think it needed to fetch the file when it's already in the database. + // The locking below mitigates this situation. + updateActiveRemoteRequests = false + activeRemoteRequests.Lock() + // FIXME: unlock after timeout of db request + // if written to disk, add to db + err = db.StoreMediaMetadata(mediaMetadata) + if err != nil { + finalDir := path.Dir(getPathFromMediaMetadata(mediaMetadata, cfg.BasePath)) + finalDirErr := os.RemoveAll(finalDir) + if finalDirErr != nil { + logger.Warnf("Failed to remove finalDir (%v): %q\n", finalDir, finalDirErr) + } + if activeRemoteRequestCondition, ok := activeRemoteRequests.Set[mxcURL]; ok { + activeRemoteRequestCondition.Broadcast() + } + delete(activeRemoteRequests.Set, mxcURL) + activeRemoteRequests.Unlock() + return + } + logger.WithFields(log.Fields{ + "Origin": mediaMetadata.Origin, + "MediaID": mediaMetadata.MediaID, + }).Infof("Signalling other goroutines waiting for us to fetch the file.") + if activeRemoteRequestCondition, ok := activeRemoteRequests.Set[mxcURL]; ok { + activeRemoteRequestCondition.Broadcast() + } + delete(activeRemoteRequests.Set, mxcURL) + activeRemoteRequests.Unlock() + + // TODO: generate thumbnails + + logger.WithFields(log.Fields{ + "MediaID": mediaMetadata.MediaID, + "Origin": mediaMetadata.Origin, + "UploadName": mediaMetadata.UploadName, + "Content-Length": mediaMetadata.ContentLength, + "Content-Type": mediaMetadata.ContentType, + "Content-Disposition": mediaMetadata.ContentDisposition, + }).Infof("Remote file cached") +} + // Given a matrix server name, attempt to discover URLs to contact the server // on. func getMatrixUrls(serverName types.ServerName) []string {