mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-07 06:53:09 -06:00
mediaapi/writers/download: Add logger to downloadRequest and use it
This commit is contained in:
parent
6e24fb86cb
commit
13b1051a3e
|
|
@ -39,10 +39,11 @@ import (
|
|||
// https://matrix.org/docs/spec/client_server/r0.2.0.html#post-matrix-media-r0-download
|
||||
type downloadRequest struct {
|
||||
MediaMetadata *types.MediaMetadata
|
||||
Logger *log.Entry
|
||||
}
|
||||
|
||||
// Validate validates the downloadRequest fields
|
||||
func (r downloadRequest) Validate() *util.JSONResponse {
|
||||
func (r *downloadRequest) Validate() *util.JSONResponse {
|
||||
// FIXME: the following errors aren't bad JSON, rather just a bad request path
|
||||
// maybe give the URL pattern in the routing, these are not even possible as the handler would not be hit...?
|
||||
if r.MediaMetadata.MediaID == "" {
|
||||
|
|
@ -60,11 +61,11 @@ func (r downloadRequest) Validate() *util.JSONResponse {
|
|||
return nil
|
||||
}
|
||||
|
||||
func jsonErrorResponse(w http.ResponseWriter, res util.JSONResponse, logger *log.Entry) {
|
||||
func (r *downloadRequest) jsonErrorResponse(w http.ResponseWriter, res util.JSONResponse) {
|
||||
// Marshal JSON response into raw bytes to send as the HTTP body
|
||||
resBytes, err := json.Marshal(res.JSON)
|
||||
if err != nil {
|
||||
logger.WithError(err).Error("Failed to marshal JSONResponse")
|
||||
r.Logger.WithError(err).Error("Failed to marshal JSONResponse")
|
||||
// this should never fail to be marshalled so drop err to the floor
|
||||
res = util.MessageResponse(500, "Internal Server Error")
|
||||
resBytes, _ = json.Marshal(res.JSON)
|
||||
|
|
@ -72,7 +73,7 @@ func jsonErrorResponse(w http.ResponseWriter, res util.JSONResponse, logger *log
|
|||
|
||||
// Set status code and write the body
|
||||
w.WriteHeader(res.Code)
|
||||
logger.WithField("code", res.Code).Infof("Responding (%d bytes)", len(resBytes))
|
||||
r.Logger.WithField("code", res.Code).Infof("Responding (%d bytes)", len(resBytes))
|
||||
w.Write(resBytes)
|
||||
}
|
||||
|
||||
|
|
@ -90,26 +91,25 @@ var nTries = 5
|
|||
// If they are not present in the cache, they are obtained from the remote server and
|
||||
// simultaneously served back to the client and written into the cache.
|
||||
func Download(w http.ResponseWriter, req *http.Request, origin types.ServerName, mediaID types.MediaID, cfg *config.MediaAPI, db *storage.Database, activeRemoteRequests *types.ActiveRemoteRequests) {
|
||||
logger := util.GetLogger(req.Context())
|
||||
|
||||
// request validation
|
||||
if req.Method != "GET" {
|
||||
jsonErrorResponse(w, util.JSONResponse{
|
||||
Code: 405,
|
||||
JSON: jsonerror.Unknown("request method must be GET"),
|
||||
}, logger)
|
||||
return
|
||||
}
|
||||
|
||||
r := &downloadRequest{
|
||||
MediaMetadata: &types.MediaMetadata{
|
||||
MediaID: mediaID,
|
||||
Origin: origin,
|
||||
},
|
||||
Logger: util.GetLogger(req.Context()),
|
||||
}
|
||||
|
||||
// request validation
|
||||
if req.Method != "GET" {
|
||||
r.jsonErrorResponse(w, util.JSONResponse{
|
||||
Code: 405,
|
||||
JSON: jsonerror.Unknown("request method must be GET"),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if resErr := r.Validate(); resErr != nil {
|
||||
jsonErrorResponse(w, *resErr, logger)
|
||||
r.jsonErrorResponse(w, *resErr)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -118,7 +118,7 @@ func Download(w http.ResponseWriter, req *http.Request, origin types.ServerName,
|
|||
|
||||
if err == nil {
|
||||
// If we have a record, we can respond from the local file
|
||||
r.respondFromLocalFile(w, logger, cfg)
|
||||
r.respondFromLocalFile(w, cfg)
|
||||
return
|
||||
} else if err == sql.ErrNoRows && r.MediaMetadata.Origin != cfg.ServerName {
|
||||
// If we do not have a record and the origin is remote, we need to fetch it and respond with that file
|
||||
|
|
@ -133,31 +133,31 @@ func Download(w http.ResponseWriter, req *http.Request, origin types.ServerName,
|
|||
err = db.GetMediaMetadata(r.MediaMetadata.MediaID, r.MediaMetadata.Origin, r.MediaMetadata)
|
||||
if err == nil {
|
||||
// If we have a record, we can respond from the local file
|
||||
r.respondFromLocalFile(w, logger, cfg)
|
||||
r.respondFromLocalFile(w, cfg)
|
||||
activeRemoteRequests.Unlock()
|
||||
return
|
||||
}
|
||||
if activeRemoteRequestCondition, ok := activeRemoteRequests.Set[mxcURL]; ok {
|
||||
if tries >= nTries {
|
||||
logger.WithFields(log.Fields{
|
||||
r.Logger.WithFields(log.Fields{
|
||||
"MediaID": r.MediaMetadata.MediaID,
|
||||
"Origin": r.MediaMetadata.Origin,
|
||||
}).Warnf("Other goroutines are trying to download the remote file and failing.")
|
||||
jsonErrorResponse(w, util.JSONResponse{
|
||||
r.jsonErrorResponse(w, util.JSONResponse{
|
||||
Code: 500,
|
||||
JSON: jsonerror.Unknown(fmt.Sprintf("File with media ID %q could not be downloaded from %q", r.MediaMetadata.MediaID, r.MediaMetadata.Origin)),
|
||||
}, logger)
|
||||
})
|
||||
activeRemoteRequests.Unlock()
|
||||
return
|
||||
}
|
||||
logger.WithFields(log.Fields{
|
||||
r.Logger.WithFields(log.Fields{
|
||||
"Origin": r.MediaMetadata.Origin,
|
||||
"MediaID": r.MediaMetadata.MediaID,
|
||||
}).Infof("Waiting for another goroutine to fetch the remote file.")
|
||||
activeRemoteRequestCondition.Wait()
|
||||
activeRemoteRequests.Unlock()
|
||||
} else {
|
||||
logger.WithFields(log.Fields{
|
||||
r.Logger.WithFields(log.Fields{
|
||||
"MediaID": r.MediaMetadata.MediaID,
|
||||
"Origin": r.MediaMetadata.Origin,
|
||||
}).Infof("Fetching remote file")
|
||||
|
|
@ -167,18 +167,18 @@ func Download(w http.ResponseWriter, req *http.Request, origin types.ServerName,
|
|||
}
|
||||
}
|
||||
|
||||
r.respondFromRemoteFile(w, logger, cfg, db, activeRemoteRequests)
|
||||
r.respondFromRemoteFile(w, cfg, db, activeRemoteRequests)
|
||||
} else {
|
||||
// If we do not have a record and the origin is local, or if we have another error from the database, the file is not found
|
||||
jsonErrorResponse(w, util.JSONResponse{
|
||||
r.jsonErrorResponse(w, util.JSONResponse{
|
||||
Code: 404,
|
||||
JSON: jsonerror.NotFound(fmt.Sprintf("File with media ID %q does not exist", r.MediaMetadata.MediaID)),
|
||||
}, logger)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func respondFromLocalFile(w http.ResponseWriter, logger *log.Entry, cfg config.MediaAPI) {
|
||||
logger.WithFields(log.Fields{
|
||||
func (r *downloadRequest) respondFromLocalFile(w http.ResponseWriter, cfg config.MediaAPI) {
|
||||
r.Logger.WithFields(log.Fields{
|
||||
"MediaID": r.MediaMetadata.MediaID,
|
||||
"Origin": r.MediaMetadata.Origin,
|
||||
"UploadName": r.MediaMetadata.UploadName,
|
||||
|
|
@ -191,7 +191,7 @@ func respondFromLocalFile(w http.ResponseWriter, logger *log.Entry, cfg config.M
|
|||
file, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
// FIXME: Remove erroneous file from database?
|
||||
jsonErrorResponse(w, util.JSONResponse{
|
||||
r.jsonErrorResponse(w, util.JSONResponse{
|
||||
Code: 404,
|
||||
JSON: jsonerror.NotFound(fmt.Sprintf("File with media ID %q does not exist", r.MediaMetadata.MediaID)),
|
||||
})
|
||||
|
|
@ -201,15 +201,15 @@ func respondFromLocalFile(w http.ResponseWriter, logger *log.Entry, cfg config.M
|
|||
stat, err := file.Stat()
|
||||
if err != nil {
|
||||
// FIXME: Remove erroneous file from database?
|
||||
jsonErrorResponse(w, util.JSONResponse{
|
||||
r.jsonErrorResponse(w, util.JSONResponse{
|
||||
Code: 404,
|
||||
JSON: jsonerror.NotFound(fmt.Sprintf("File with media ID %q does not exist", r.MediaMetadata.MediaID)),
|
||||
}, logger)
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if r.MediaMetadata.ContentLength > 0 && int64(r.MediaMetadata.ContentLength) != stat.Size() {
|
||||
logger.Warnf("File size in database (%v) and on disk (%v) differ.", r.MediaMetadata.ContentLength, stat.Size())
|
||||
r.Logger.Warnf("File size in database (%v) and on disk (%v) differ.", r.MediaMetadata.ContentLength, stat.Size())
|
||||
// FIXME: Remove erroneous file from database?
|
||||
}
|
||||
|
||||
|
|
@ -223,22 +223,22 @@ func respondFromLocalFile(w http.ResponseWriter, logger *log.Entry, cfg config.M
|
|||
w.Header().Set("Content-Security-Policy", contentSecurityPolicy)
|
||||
|
||||
if bytesResponded, err := io.Copy(w, file); err != nil {
|
||||
logger.Warnf("Failed to copy from cache %v\n", err)
|
||||
r.Logger.Warnf("Failed to copy from cache %v\n", err)
|
||||
if bytesResponded == 0 {
|
||||
jsonErrorResponse(w, util.JSONResponse{
|
||||
r.jsonErrorResponse(w, util.JSONResponse{
|
||||
Code: 500,
|
||||
JSON: jsonerror.NotFound(fmt.Sprintf("Failed to respond with file with media ID %q", r.MediaMetadata.MediaID)),
|
||||
}, logger)
|
||||
})
|
||||
}
|
||||
// If we have written any data then we have already responded with 200 OK and all we can do is close the connection
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (r *downloadRequest) createRemoteRequest(logger *log.Entry) (*http.Response, *util.JSONResponse) {
|
||||
func (r *downloadRequest) createRemoteRequest() (*http.Response, *util.JSONResponse) {
|
||||
urls := getMatrixUrls(r.MediaMetadata.Origin)
|
||||
|
||||
logger.Printf("Connecting to remote %q\n", urls[0])
|
||||
r.Logger.Printf("Connecting to remote %q\n", urls[0])
|
||||
|
||||
remoteReqAddr := urls[0] + "/_matrix/media/v1/download/" + string(r.MediaMetadata.Origin) + "/" + string(r.MediaMetadata.MediaID)
|
||||
remoteReq, err := http.NewRequest("GET", remoteReqAddr, nil)
|
||||
|
|
@ -261,7 +261,7 @@ func (r *downloadRequest) createRemoteRequest(logger *log.Entry) (*http.Response
|
|||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
logger.Printf("Server responded with %d\n", resp.StatusCode)
|
||||
r.Logger.Printf("Server responded with %d\n", resp.StatusCode)
|
||||
if resp.StatusCode == 404 {
|
||||
return nil, &util.JSONResponse{
|
||||
Code: 404,
|
||||
|
|
@ -335,16 +335,16 @@ func copyToActiveAndPassive(r io.Reader, wActive io.Writer, wPassive io.Writer,
|
|||
return bytesResponded, bytesWritten, fetchError
|
||||
}
|
||||
|
||||
func closeConnection(w http.ResponseWriter, logger *log.Entry) {
|
||||
logger.Println("Attempting to close the connection.")
|
||||
func (r *downloadRequest) closeConnection(w http.ResponseWriter) {
|
||||
r.Logger.Println("Attempting to close the connection.")
|
||||
hijacker, ok := w.(http.Hijacker)
|
||||
if ok {
|
||||
connection, _, hijackErr := hijacker.Hijack()
|
||||
if hijackErr == nil {
|
||||
logger.Println("Closing")
|
||||
r.Logger.Println("Closing")
|
||||
connection.Close()
|
||||
} else {
|
||||
logger.Printf("Error trying to hijack: %v", hijackErr)
|
||||
r.Logger.Printf("Error trying to hijack: %v", hijackErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -357,10 +357,10 @@ func completeRemoteRequest(activeRemoteRequests *types.ActiveRemoteRequests, mxc
|
|||
activeRemoteRequests.Unlock()
|
||||
}
|
||||
|
||||
func (r *downloadRequest) commitFileAndMetadata(tmpDir types.Path, basePath types.Path, activeRemoteRequests *types.ActiveRemoteRequests, db *storage.Database, mxcURL string, logger *log.Entry) bool {
|
||||
func (r *downloadRequest) commitFileAndMetadata(tmpDir types.Path, basePath types.Path, activeRemoteRequests *types.ActiveRemoteRequests, db *storage.Database, mxcURL string) bool {
|
||||
updateActiveRemoteRequests := true
|
||||
|
||||
logger.WithFields(log.Fields{
|
||||
r.Logger.WithFields(log.Fields{
|
||||
"MediaID": r.MediaMetadata.MediaID,
|
||||
"Origin": r.MediaMetadata.Origin,
|
||||
"UploadName": r.MediaMetadata.UploadName,
|
||||
|
|
@ -377,7 +377,7 @@ func (r *downloadRequest) commitFileAndMetadata(tmpDir types.Path, basePath type
|
|||
if err != nil {
|
||||
tmpDirErr := os.RemoveAll(string(tmpDir))
|
||||
if tmpDirErr != nil {
|
||||
logger.Warnf("Failed to remove tmpDir (%v): %q\n", tmpDir, tmpDirErr)
|
||||
r.Logger.Warnf("Failed to remove tmpDir (%v): %q\n", tmpDir, tmpDirErr)
|
||||
}
|
||||
return updateActiveRemoteRequests
|
||||
}
|
||||
|
|
@ -396,12 +396,12 @@ func (r *downloadRequest) commitFileAndMetadata(tmpDir types.Path, basePath type
|
|||
finalDir := path.Dir(getPathFromMediaMetadata(r.MediaMetadata, basePath))
|
||||
finalDirErr := os.RemoveAll(finalDir)
|
||||
if finalDirErr != nil {
|
||||
logger.Warnf("Failed to remove finalDir (%v): %q\n", finalDir, finalDirErr)
|
||||
r.Logger.Warnf("Failed to remove finalDir (%v): %q\n", finalDir, finalDirErr)
|
||||
}
|
||||
completeRemoteRequest(activeRemoteRequests, mxcURL)
|
||||
return updateActiveRemoteRequests
|
||||
}
|
||||
logger.WithFields(log.Fields{
|
||||
r.Logger.WithFields(log.Fields{
|
||||
"Origin": r.MediaMetadata.Origin,
|
||||
"MediaID": r.MediaMetadata.MediaID,
|
||||
}).Infof("Signalling other goroutines waiting for us to fetch the file.")
|
||||
|
|
@ -409,8 +409,8 @@ func (r *downloadRequest) commitFileAndMetadata(tmpDir types.Path, basePath type
|
|||
return updateActiveRemoteRequests
|
||||
}
|
||||
|
||||
func (r *downloadRequest) respondFromRemoteFile(w http.ResponseWriter, logger *log.Entry, cfg config.MediaAPI, db *storage.Database, activeRemoteRequests *types.ActiveRemoteRequests) {
|
||||
logger.WithFields(log.Fields{
|
||||
func (r *downloadRequest) respondFromRemoteFile(w http.ResponseWriter, cfg config.MediaAPI, db *storage.Database, activeRemoteRequests *types.ActiveRemoteRequests) {
|
||||
r.Logger.WithFields(log.Fields{
|
||||
"MediaID": r.MediaMetadata.MediaID,
|
||||
"Origin": r.MediaMetadata.Origin,
|
||||
}).Infof("Fetching remote file")
|
||||
|
|
@ -431,9 +431,9 @@ func (r *downloadRequest) respondFromRemoteFile(w http.ResponseWriter, logger *l
|
|||
}()
|
||||
|
||||
// create request for remote file
|
||||
resp, errorResponse := r.createRemoteRequest(logger)
|
||||
resp, errorResponse := r.createRemoteRequest()
|
||||
if errorResponse != nil {
|
||||
jsonErrorResponse(w, *errorResponse, logger)
|
||||
r.jsonErrorResponse(w, *errorResponse)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
|
@ -441,7 +441,7 @@ func (r *downloadRequest) respondFromRemoteFile(w http.ResponseWriter, logger *l
|
|||
// get metadata from request and set metadata on response
|
||||
contentLength, err := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
|
||||
if err != nil {
|
||||
logger.Warn("Failed to parse content length")
|
||||
r.Logger.Warn("Failed to parse content length")
|
||||
}
|
||||
r.MediaMetadata.ContentLength = types.ContentLength(contentLength)
|
||||
|
||||
|
|
@ -450,7 +450,7 @@ func (r *downloadRequest) respondFromRemoteFile(w http.ResponseWriter, logger *l
|
|||
// FIXME: parse from Content-Disposition header if possible, else fall back
|
||||
//r.MediaMetadata.UploadName = types.Filename()
|
||||
|
||||
logger.WithFields(log.Fields{
|
||||
r.Logger.WithFields(log.Fields{
|
||||
"MediaID": r.MediaMetadata.MediaID,
|
||||
"Origin": r.MediaMetadata.Origin,
|
||||
}).Infof("Connected to remote")
|
||||
|
|
@ -465,39 +465,39 @@ func (r *downloadRequest) respondFromRemoteFile(w http.ResponseWriter, logger *l
|
|||
w.Header().Set("Content-Security-Policy", contentSecurityPolicy)
|
||||
|
||||
// create the temporary file writer
|
||||
tmpFileWriter, tmpFile, tmpDir, errorResponse := createTempFileWriter(cfg.BasePath, logger)
|
||||
tmpFileWriter, tmpFile, tmpDir, errorResponse := createTempFileWriter(cfg.BasePath, r.Logger)
|
||||
if errorResponse != nil {
|
||||
jsonErrorResponse(w, *errorResponse, logger)
|
||||
r.jsonErrorResponse(w, *errorResponse)
|
||||
return
|
||||
}
|
||||
defer tmpFile.Close()
|
||||
|
||||
// read the remote request's response body
|
||||
// simultaneously write it to the incoming request's response body and the temporary file
|
||||
logger.WithFields(log.Fields{
|
||||
r.Logger.WithFields(log.Fields{
|
||||
"MediaID": r.MediaMetadata.MediaID,
|
||||
"Origin": r.MediaMetadata.Origin,
|
||||
}).Infof("Proxying and caching remote file")
|
||||
|
||||
// bytesResponded is the total number of bytes written to the response to the client request
|
||||
// bytesWritten is the total number of bytes written to disk
|
||||
bytesResponded, bytesWritten, fetchError := copyToActiveAndPassive(resp.Body, w, tmpFileWriter, cfg.MaxFileSize, r.MediaMetadata, logger)
|
||||
bytesResponded, bytesWritten, fetchError := copyToActiveAndPassive(resp.Body, w, tmpFileWriter, cfg.MaxFileSize, r.MediaMetadata, r.Logger)
|
||||
tmpFileWriter.Flush()
|
||||
if fetchError != nil {
|
||||
tmpDirErr := os.RemoveAll(string(tmpDir))
|
||||
if tmpDirErr != nil {
|
||||
logger.Warnf("Failed to remove tmpDir (%v): %q\n", tmpDir, tmpDirErr)
|
||||
r.Logger.Warnf("Failed to remove tmpDir (%v): %q\n", tmpDir, tmpDirErr)
|
||||
}
|
||||
// Note: if we have responded with any data in the body at all then we have already sent 200 OK and we can only abort at this point
|
||||
if bytesResponded < 1 {
|
||||
jsonErrorResponse(w, util.JSONResponse{
|
||||
r.jsonErrorResponse(w, util.JSONResponse{
|
||||
Code: 502,
|
||||
JSON: jsonerror.Unknown(fmt.Sprintf("File with media ID %q could not be downloaded from %q", r.MediaMetadata.MediaID, r.MediaMetadata.Origin)),
|
||||
}, logger)
|
||||
})
|
||||
} else {
|
||||
// We attempt to bluntly close the connection because that is the
|
||||
// best thing we can do after we've sent a 200 OK
|
||||
closeConnection(w, logger)
|
||||
r.closeConnection(w)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
@ -514,11 +514,11 @@ func (r *downloadRequest) respondFromRemoteFile(w http.ResponseWriter, logger *l
|
|||
r.MediaMetadata.ContentLength = types.ContentLength(bytesWritten)
|
||||
r.MediaMetadata.UserID = types.MatrixUserID("@:" + string(r.MediaMetadata.Origin))
|
||||
|
||||
updateActiveRemoteRequests = r.commitFileAndMetadata(tmpDir, cfg.BasePath, activeRemoteRequests, db, mxcURL, logger)
|
||||
updateActiveRemoteRequests = r.commitFileAndMetadata(tmpDir, cfg.BasePath, activeRemoteRequests, db, mxcURL)
|
||||
|
||||
// TODO: generate thumbnails
|
||||
|
||||
logger.WithFields(log.Fields{
|
||||
r.Logger.WithFields(log.Fields{
|
||||
"MediaID": r.MediaMetadata.MediaID,
|
||||
"Origin": r.MediaMetadata.Origin,
|
||||
"UploadName": r.MediaMetadata.UploadName,
|
||||
|
|
|
|||
Loading…
Reference in a new issue