diff --git a/src/github.com/matrix-org/dendrite/mediaapi/writers/download.go b/src/github.com/matrix-org/dendrite/mediaapi/writers/download.go index 981e0cc0d..8481cd05c 100644 --- a/src/github.com/matrix-org/dendrite/mediaapi/writers/download.go +++ b/src/github.com/matrix-org/dendrite/mediaapi/writers/download.go @@ -39,10 +39,11 @@ import ( // https://matrix.org/docs/spec/client_server/r0.2.0.html#post-matrix-media-r0-download type downloadRequest struct { MediaMetadata *types.MediaMetadata + Logger *log.Entry } // Validate validates the downloadRequest fields -func (r downloadRequest) Validate() *util.JSONResponse { +func (r *downloadRequest) Validate() *util.JSONResponse { // FIXME: the following errors aren't bad JSON, rather just a bad request path // maybe give the URL pattern in the routing, these are not even possible as the handler would not be hit...? if r.MediaMetadata.MediaID == "" { @@ -60,11 +61,11 @@ func (r downloadRequest) Validate() *util.JSONResponse { return nil } -func jsonErrorResponse(w http.ResponseWriter, res util.JSONResponse, logger *log.Entry) { +func (r *downloadRequest) jsonErrorResponse(w http.ResponseWriter, res util.JSONResponse) { // Marshal JSON response into raw bytes to send as the HTTP body resBytes, err := json.Marshal(res.JSON) if err != nil { - logger.WithError(err).Error("Failed to marshal JSONResponse") + r.Logger.WithError(err).Error("Failed to marshal JSONResponse") // this should never fail to be marshalled so drop err to the floor res = util.MessageResponse(500, "Internal Server Error") resBytes, _ = json.Marshal(res.JSON) @@ -72,7 +73,7 @@ func jsonErrorResponse(w http.ResponseWriter, res util.JSONResponse, logger *log // Set status code and write the body w.WriteHeader(res.Code) - logger.WithField("code", res.Code).Infof("Responding (%d bytes)", len(resBytes)) + r.Logger.WithField("code", res.Code).Infof("Responding (%d bytes)", len(resBytes)) w.Write(resBytes) } @@ -90,26 +91,25 @@ var nTries = 5 // If they are not present in the cache, they are obtained from the remote server and // simultaneously served back to the client and written into the cache. func Download(w http.ResponseWriter, req *http.Request, origin types.ServerName, mediaID types.MediaID, cfg *config.MediaAPI, db *storage.Database, activeRemoteRequests *types.ActiveRemoteRequests) { - logger := util.GetLogger(req.Context()) - - // request validation - if req.Method != "GET" { - jsonErrorResponse(w, util.JSONResponse{ - Code: 405, - JSON: jsonerror.Unknown("request method must be GET"), - }, logger) - return - } - r := &downloadRequest{ MediaMetadata: &types.MediaMetadata{ MediaID: mediaID, Origin: origin, }, + Logger: util.GetLogger(req.Context()), + } + + // request validation + if req.Method != "GET" { + r.jsonErrorResponse(w, util.JSONResponse{ + Code: 405, + JSON: jsonerror.Unknown("request method must be GET"), + }) + return } if resErr := r.Validate(); resErr != nil { - jsonErrorResponse(w, *resErr, logger) + r.jsonErrorResponse(w, *resErr) return } @@ -118,7 +118,7 @@ func Download(w http.ResponseWriter, req *http.Request, origin types.ServerName, if err == nil { // If we have a record, we can respond from the local file - r.respondFromLocalFile(w, logger, cfg) + r.respondFromLocalFile(w, cfg) return } else if err == sql.ErrNoRows && r.MediaMetadata.Origin != cfg.ServerName { // If we do not have a record and the origin is remote, we need to fetch it and respond with that file @@ -133,31 +133,31 @@ func Download(w http.ResponseWriter, req *http.Request, origin types.ServerName, err = db.GetMediaMetadata(r.MediaMetadata.MediaID, r.MediaMetadata.Origin, r.MediaMetadata) if err == nil { // If we have a record, we can respond from the local file - r.respondFromLocalFile(w, logger, cfg) + r.respondFromLocalFile(w, cfg) activeRemoteRequests.Unlock() return } if activeRemoteRequestCondition, ok := activeRemoteRequests.Set[mxcURL]; ok { if tries >= nTries { - logger.WithFields(log.Fields{ + r.Logger.WithFields(log.Fields{ "MediaID": r.MediaMetadata.MediaID, "Origin": r.MediaMetadata.Origin, }).Warnf("Other goroutines are trying to download the remote file and failing.") - jsonErrorResponse(w, util.JSONResponse{ + r.jsonErrorResponse(w, util.JSONResponse{ Code: 500, JSON: jsonerror.Unknown(fmt.Sprintf("File with media ID %q could not be downloaded from %q", r.MediaMetadata.MediaID, r.MediaMetadata.Origin)), - }, logger) + }) activeRemoteRequests.Unlock() return } - logger.WithFields(log.Fields{ + r.Logger.WithFields(log.Fields{ "Origin": r.MediaMetadata.Origin, "MediaID": r.MediaMetadata.MediaID, }).Infof("Waiting for another goroutine to fetch the remote file.") activeRemoteRequestCondition.Wait() activeRemoteRequests.Unlock() } else { - logger.WithFields(log.Fields{ + r.Logger.WithFields(log.Fields{ "MediaID": r.MediaMetadata.MediaID, "Origin": r.MediaMetadata.Origin, }).Infof("Fetching remote file") @@ -167,18 +167,18 @@ func Download(w http.ResponseWriter, req *http.Request, origin types.ServerName, } } - r.respondFromRemoteFile(w, logger, cfg, db, activeRemoteRequests) + r.respondFromRemoteFile(w, cfg, db, activeRemoteRequests) } else { // If we do not have a record and the origin is local, or if we have another error from the database, the file is not found - jsonErrorResponse(w, util.JSONResponse{ + r.jsonErrorResponse(w, util.JSONResponse{ Code: 404, JSON: jsonerror.NotFound(fmt.Sprintf("File with media ID %q does not exist", r.MediaMetadata.MediaID)), - }, logger) + }) } } -func respondFromLocalFile(w http.ResponseWriter, logger *log.Entry, cfg config.MediaAPI) { - logger.WithFields(log.Fields{ +func (r *downloadRequest) respondFromLocalFile(w http.ResponseWriter, cfg config.MediaAPI) { + r.Logger.WithFields(log.Fields{ "MediaID": r.MediaMetadata.MediaID, "Origin": r.MediaMetadata.Origin, "UploadName": r.MediaMetadata.UploadName, @@ -191,7 +191,7 @@ func respondFromLocalFile(w http.ResponseWriter, logger *log.Entry, cfg config.M file, err := os.Open(filePath) if err != nil { // FIXME: Remove erroneous file from database? - jsonErrorResponse(w, util.JSONResponse{ + r.jsonErrorResponse(w, util.JSONResponse{ Code: 404, JSON: jsonerror.NotFound(fmt.Sprintf("File with media ID %q does not exist", r.MediaMetadata.MediaID)), }) @@ -201,15 +201,15 @@ func respondFromLocalFile(w http.ResponseWriter, logger *log.Entry, cfg config.M stat, err := file.Stat() if err != nil { // FIXME: Remove erroneous file from database? - jsonErrorResponse(w, util.JSONResponse{ + r.jsonErrorResponse(w, util.JSONResponse{ Code: 404, JSON: jsonerror.NotFound(fmt.Sprintf("File with media ID %q does not exist", r.MediaMetadata.MediaID)), - }, logger) + }) return } if r.MediaMetadata.ContentLength > 0 && int64(r.MediaMetadata.ContentLength) != stat.Size() { - logger.Warnf("File size in database (%v) and on disk (%v) differ.", r.MediaMetadata.ContentLength, stat.Size()) + r.Logger.Warnf("File size in database (%v) and on disk (%v) differ.", r.MediaMetadata.ContentLength, stat.Size()) // FIXME: Remove erroneous file from database? } @@ -223,22 +223,22 @@ func respondFromLocalFile(w http.ResponseWriter, logger *log.Entry, cfg config.M w.Header().Set("Content-Security-Policy", contentSecurityPolicy) if bytesResponded, err := io.Copy(w, file); err != nil { - logger.Warnf("Failed to copy from cache %v\n", err) + r.Logger.Warnf("Failed to copy from cache %v\n", err) if bytesResponded == 0 { - jsonErrorResponse(w, util.JSONResponse{ + r.jsonErrorResponse(w, util.JSONResponse{ Code: 500, JSON: jsonerror.NotFound(fmt.Sprintf("Failed to respond with file with media ID %q", r.MediaMetadata.MediaID)), - }, logger) + }) } // If we have written any data then we have already responded with 200 OK and all we can do is close the connection return } } -func (r *downloadRequest) createRemoteRequest(logger *log.Entry) (*http.Response, *util.JSONResponse) { +func (r *downloadRequest) createRemoteRequest() (*http.Response, *util.JSONResponse) { urls := getMatrixUrls(r.MediaMetadata.Origin) - logger.Printf("Connecting to remote %q\n", urls[0]) + r.Logger.Printf("Connecting to remote %q\n", urls[0]) remoteReqAddr := urls[0] + "/_matrix/media/v1/download/" + string(r.MediaMetadata.Origin) + "/" + string(r.MediaMetadata.MediaID) remoteReq, err := http.NewRequest("GET", remoteReqAddr, nil) @@ -261,7 +261,7 @@ func (r *downloadRequest) createRemoteRequest(logger *log.Entry) (*http.Response } if resp.StatusCode != 200 { - logger.Printf("Server responded with %d\n", resp.StatusCode) + r.Logger.Printf("Server responded with %d\n", resp.StatusCode) if resp.StatusCode == 404 { return nil, &util.JSONResponse{ Code: 404, @@ -335,16 +335,16 @@ func copyToActiveAndPassive(r io.Reader, wActive io.Writer, wPassive io.Writer, return bytesResponded, bytesWritten, fetchError } -func closeConnection(w http.ResponseWriter, logger *log.Entry) { - logger.Println("Attempting to close the connection.") +func (r *downloadRequest) closeConnection(w http.ResponseWriter) { + r.Logger.Println("Attempting to close the connection.") hijacker, ok := w.(http.Hijacker) if ok { connection, _, hijackErr := hijacker.Hijack() if hijackErr == nil { - logger.Println("Closing") + r.Logger.Println("Closing") connection.Close() } else { - logger.Printf("Error trying to hijack: %v", hijackErr) + r.Logger.Printf("Error trying to hijack: %v", hijackErr) } } } @@ -357,10 +357,10 @@ func completeRemoteRequest(activeRemoteRequests *types.ActiveRemoteRequests, mxc activeRemoteRequests.Unlock() } -func (r *downloadRequest) commitFileAndMetadata(tmpDir types.Path, basePath types.Path, activeRemoteRequests *types.ActiveRemoteRequests, db *storage.Database, mxcURL string, logger *log.Entry) bool { +func (r *downloadRequest) commitFileAndMetadata(tmpDir types.Path, basePath types.Path, activeRemoteRequests *types.ActiveRemoteRequests, db *storage.Database, mxcURL string) bool { updateActiveRemoteRequests := true - logger.WithFields(log.Fields{ + r.Logger.WithFields(log.Fields{ "MediaID": r.MediaMetadata.MediaID, "Origin": r.MediaMetadata.Origin, "UploadName": r.MediaMetadata.UploadName, @@ -377,7 +377,7 @@ func (r *downloadRequest) commitFileAndMetadata(tmpDir types.Path, basePath type if err != nil { tmpDirErr := os.RemoveAll(string(tmpDir)) if tmpDirErr != nil { - logger.Warnf("Failed to remove tmpDir (%v): %q\n", tmpDir, tmpDirErr) + r.Logger.Warnf("Failed to remove tmpDir (%v): %q\n", tmpDir, tmpDirErr) } return updateActiveRemoteRequests } @@ -396,12 +396,12 @@ func (r *downloadRequest) commitFileAndMetadata(tmpDir types.Path, basePath type finalDir := path.Dir(getPathFromMediaMetadata(r.MediaMetadata, basePath)) finalDirErr := os.RemoveAll(finalDir) if finalDirErr != nil { - logger.Warnf("Failed to remove finalDir (%v): %q\n", finalDir, finalDirErr) + r.Logger.Warnf("Failed to remove finalDir (%v): %q\n", finalDir, finalDirErr) } completeRemoteRequest(activeRemoteRequests, mxcURL) return updateActiveRemoteRequests } - logger.WithFields(log.Fields{ + r.Logger.WithFields(log.Fields{ "Origin": r.MediaMetadata.Origin, "MediaID": r.MediaMetadata.MediaID, }).Infof("Signalling other goroutines waiting for us to fetch the file.") @@ -409,8 +409,8 @@ func (r *downloadRequest) commitFileAndMetadata(tmpDir types.Path, basePath type return updateActiveRemoteRequests } -func (r *downloadRequest) respondFromRemoteFile(w http.ResponseWriter, logger *log.Entry, cfg config.MediaAPI, db *storage.Database, activeRemoteRequests *types.ActiveRemoteRequests) { - logger.WithFields(log.Fields{ +func (r *downloadRequest) respondFromRemoteFile(w http.ResponseWriter, cfg config.MediaAPI, db *storage.Database, activeRemoteRequests *types.ActiveRemoteRequests) { + r.Logger.WithFields(log.Fields{ "MediaID": r.MediaMetadata.MediaID, "Origin": r.MediaMetadata.Origin, }).Infof("Fetching remote file") @@ -431,9 +431,9 @@ func (r *downloadRequest) respondFromRemoteFile(w http.ResponseWriter, logger *l }() // create request for remote file - resp, errorResponse := r.createRemoteRequest(logger) + resp, errorResponse := r.createRemoteRequest() if errorResponse != nil { - jsonErrorResponse(w, *errorResponse, logger) + r.jsonErrorResponse(w, *errorResponse) return } defer resp.Body.Close() @@ -441,7 +441,7 @@ func (r *downloadRequest) respondFromRemoteFile(w http.ResponseWriter, logger *l // get metadata from request and set metadata on response contentLength, err := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64) if err != nil { - logger.Warn("Failed to parse content length") + r.Logger.Warn("Failed to parse content length") } r.MediaMetadata.ContentLength = types.ContentLength(contentLength) @@ -450,7 +450,7 @@ func (r *downloadRequest) respondFromRemoteFile(w http.ResponseWriter, logger *l // FIXME: parse from Content-Disposition header if possible, else fall back //r.MediaMetadata.UploadName = types.Filename() - logger.WithFields(log.Fields{ + r.Logger.WithFields(log.Fields{ "MediaID": r.MediaMetadata.MediaID, "Origin": r.MediaMetadata.Origin, }).Infof("Connected to remote") @@ -465,39 +465,39 @@ func (r *downloadRequest) respondFromRemoteFile(w http.ResponseWriter, logger *l w.Header().Set("Content-Security-Policy", contentSecurityPolicy) // create the temporary file writer - tmpFileWriter, tmpFile, tmpDir, errorResponse := createTempFileWriter(cfg.BasePath, logger) + tmpFileWriter, tmpFile, tmpDir, errorResponse := createTempFileWriter(cfg.BasePath, r.Logger) if errorResponse != nil { - jsonErrorResponse(w, *errorResponse, logger) + r.jsonErrorResponse(w, *errorResponse) return } defer tmpFile.Close() // read the remote request's response body // simultaneously write it to the incoming request's response body and the temporary file - logger.WithFields(log.Fields{ + r.Logger.WithFields(log.Fields{ "MediaID": r.MediaMetadata.MediaID, "Origin": r.MediaMetadata.Origin, }).Infof("Proxying and caching remote file") // bytesResponded is the total number of bytes written to the response to the client request // bytesWritten is the total number of bytes written to disk - bytesResponded, bytesWritten, fetchError := copyToActiveAndPassive(resp.Body, w, tmpFileWriter, cfg.MaxFileSize, r.MediaMetadata, logger) + bytesResponded, bytesWritten, fetchError := copyToActiveAndPassive(resp.Body, w, tmpFileWriter, cfg.MaxFileSize, r.MediaMetadata, r.Logger) tmpFileWriter.Flush() if fetchError != nil { tmpDirErr := os.RemoveAll(string(tmpDir)) if tmpDirErr != nil { - logger.Warnf("Failed to remove tmpDir (%v): %q\n", tmpDir, tmpDirErr) + r.Logger.Warnf("Failed to remove tmpDir (%v): %q\n", tmpDir, tmpDirErr) } // Note: if we have responded with any data in the body at all then we have already sent 200 OK and we can only abort at this point if bytesResponded < 1 { - jsonErrorResponse(w, util.JSONResponse{ + r.jsonErrorResponse(w, util.JSONResponse{ Code: 502, JSON: jsonerror.Unknown(fmt.Sprintf("File with media ID %q could not be downloaded from %q", r.MediaMetadata.MediaID, r.MediaMetadata.Origin)), - }, logger) + }) } else { // We attempt to bluntly close the connection because that is the // best thing we can do after we've sent a 200 OK - closeConnection(w, logger) + r.closeConnection(w) } return } @@ -514,11 +514,11 @@ func (r *downloadRequest) respondFromRemoteFile(w http.ResponseWriter, logger *l r.MediaMetadata.ContentLength = types.ContentLength(bytesWritten) r.MediaMetadata.UserID = types.MatrixUserID("@:" + string(r.MediaMetadata.Origin)) - updateActiveRemoteRequests = r.commitFileAndMetadata(tmpDir, cfg.BasePath, activeRemoteRequests, db, mxcURL, logger) + updateActiveRemoteRequests = r.commitFileAndMetadata(tmpDir, cfg.BasePath, activeRemoteRequests, db, mxcURL) // TODO: generate thumbnails - logger.WithFields(log.Fields{ + r.Logger.WithFields(log.Fields{ "MediaID": r.MediaMetadata.MediaID, "Origin": r.MediaMetadata.Origin, "UploadName": r.MediaMetadata.UploadName,