Add support for getting authenticated federation media requests
This commit is contained in:
parent
29ee5401ee
commit
89d64c6dca
|
@ -33,6 +33,7 @@ func AddPublicRoutes(
|
|||
cfg *config.Dendrite,
|
||||
userAPI userapi.MediaUserAPI,
|
||||
client *fclient.Client,
|
||||
fedClient fclient.FederationClient,
|
||||
keyRing gomatrixserverlib.JSONVerifier,
|
||||
) {
|
||||
mediaDB, err := storage.NewMediaAPIDatasource(cm, &cfg.MediaAPI.Database)
|
||||
|
@ -41,6 +42,6 @@ func AddPublicRoutes(
|
|||
}
|
||||
|
||||
routing.Setup(
|
||||
routers, cfg, mediaDB, userAPI, client, keyRing,
|
||||
routers, cfg, mediaDB, userAPI, client, fedClient, keyRing,
|
||||
)
|
||||
}
|
||||
|
|
|
@ -65,6 +65,8 @@ type downloadRequest struct {
|
|||
Logger *log.Entry
|
||||
DownloadFilename string
|
||||
forFederation bool // whether we need to return a multipart/mixed response
|
||||
fedClient fclient.FederationClient
|
||||
origin spec.ServerName
|
||||
}
|
||||
|
||||
// Taken from: https://github.com/matrix-org/synapse/blob/c3627d0f99ed5a23479305dc2bd0e71ca25ce2b1/synapse/media/_base.py#L53C1-L84
|
||||
|
@ -115,6 +117,7 @@ func Download(
|
|||
cfg *config.MediaAPI,
|
||||
db storage.Database,
|
||||
client *fclient.Client,
|
||||
fedClient fclient.FederationClient,
|
||||
activeRemoteRequests *types.ActiveRemoteRequests,
|
||||
activeThumbnailGeneration *types.ActiveThumbnailGeneration,
|
||||
isThumbnailRequest bool,
|
||||
|
@ -137,6 +140,8 @@ func Download(
|
|||
}),
|
||||
DownloadFilename: customFilename,
|
||||
forFederation: forFederation,
|
||||
origin: cfg.Matrix.ServerName,
|
||||
fedClient: fedClient,
|
||||
}
|
||||
|
||||
if dReq.IsThumbnailRequest {
|
||||
|
@ -773,8 +778,7 @@ func (r *downloadRequest) fetchRemoteFileAndStoreMetadata(
|
|||
return nil
|
||||
}
|
||||
|
||||
func (r *downloadRequest) GetContentLengthAndReader(contentLengthHeader string, body *io.ReadCloser, maxFileSizeBytes config.FileSizeBytes) (int64, io.Reader, error) {
|
||||
reader := *body
|
||||
func (r *downloadRequest) GetContentLengthAndReader(contentLengthHeader string, reader io.ReadCloser, maxFileSizeBytes config.FileSizeBytes) (int64, io.Reader, error) {
|
||||
var contentLength int64
|
||||
|
||||
if contentLengthHeader != "" {
|
||||
|
@ -793,7 +797,7 @@ func (r *downloadRequest) GetContentLengthAndReader(contentLengthHeader string,
|
|||
|
||||
// We successfully parsed the Content-Length, so we'll return a limited
|
||||
// reader that restricts us to reading only up to this size.
|
||||
reader = io.NopCloser(io.LimitReader(*body, parsedLength))
|
||||
reader = io.NopCloser(io.LimitReader(reader, parsedLength))
|
||||
contentLength = parsedLength
|
||||
} else {
|
||||
// Content-Length header is missing. If we have a maximum file size
|
||||
|
@ -802,7 +806,7 @@ func (r *downloadRequest) GetContentLengthAndReader(contentLengthHeader string,
|
|||
// ultimately it will get rewritten later when the temp file is written
|
||||
// to disk.
|
||||
if maxFileSizeBytes > 0 {
|
||||
reader = io.NopCloser(io.LimitReader(*body, int64(maxFileSizeBytes)))
|
||||
reader = io.NopCloser(io.LimitReader(reader, int64(maxFileSizeBytes)))
|
||||
}
|
||||
contentLength = 0
|
||||
}
|
||||
|
@ -818,19 +822,61 @@ func (r *downloadRequest) fetchRemoteFile(
|
|||
) (types.Path, bool, error) {
|
||||
r.Logger.Debug("Fetching remote file")
|
||||
|
||||
// Attempt to download via authenticated media endpoint
|
||||
isMultiPart := true
|
||||
resp, err := r.fedClient.DownloadMedia(ctx, r.origin, r.MediaMetadata.Origin, string(r.MediaMetadata.MediaID))
|
||||
if err != nil || (resp != nil && resp.StatusCode != http.StatusOK) {
|
||||
isMultiPart = false
|
||||
// try again on the unauthed endpoint
|
||||
// create request for remote file
|
||||
resp, err := client.CreateMediaDownloadRequest(ctx, r.MediaMetadata.Origin, string(r.MediaMetadata.MediaID))
|
||||
resp, err = client.CreateMediaDownloadRequest(ctx, r.MediaMetadata.Origin, string(r.MediaMetadata.MediaID))
|
||||
if err != nil || (resp != nil && resp.StatusCode != http.StatusOK) {
|
||||
if resp != nil && resp.StatusCode == http.StatusNotFound {
|
||||
return "", false, fmt.Errorf("File with media ID %q does not exist on %s", r.MediaMetadata.MediaID, r.MediaMetadata.Origin)
|
||||
}
|
||||
return "", false, fmt.Errorf("file with media ID %q could not be downloaded from %s", r.MediaMetadata.MediaID, r.MediaMetadata.Origin)
|
||||
}
|
||||
}
|
||||
defer resp.Body.Close() // nolint: errcheck
|
||||
|
||||
var contentLength int64
|
||||
var reader io.Reader
|
||||
var parseErr error
|
||||
if isMultiPart {
|
||||
r.Logger.Info("Downloaded file using authenticated endpoint")
|
||||
_, params, err := mime.ParseMediaType(resp.Header.Get("Content-Type"))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if params["boundary"] == "" {
|
||||
return "", false, fmt.Errorf("no boundary header found on %s", r.MediaMetadata.Origin)
|
||||
}
|
||||
mr := multipart.NewReader(resp.Body, params["boundary"])
|
||||
|
||||
first := true
|
||||
for {
|
||||
p, err := mr.NextPart()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return "", false, err
|
||||
}
|
||||
|
||||
if !first {
|
||||
readCloser := io.NopCloser(p)
|
||||
contentLength, reader, parseErr = r.GetContentLengthAndReader(p.Header.Get("Content-Length"), readCloser, maxFileSizeBytes)
|
||||
break
|
||||
}
|
||||
|
||||
first = false
|
||||
}
|
||||
} else {
|
||||
// The reader returned here will be limited either by the Content-Length
|
||||
// and/or the configured maximum media size.
|
||||
contentLength, reader, parseErr := r.GetContentLengthAndReader(resp.Header.Get("Content-Length"), &resp.Body, maxFileSizeBytes)
|
||||
contentLength, reader, parseErr = r.GetContentLengthAndReader(resp.Header.Get("Content-Length"), resp.Body, maxFileSizeBytes)
|
||||
}
|
||||
|
||||
if parseErr != nil {
|
||||
return "", false, parseErr
|
||||
}
|
||||
|
|
|
@ -52,6 +52,7 @@ func Setup(
|
|||
db storage.Database,
|
||||
userAPI userapi.MediaUserAPI,
|
||||
client *fclient.Client,
|
||||
federationClient fclient.FederationClient,
|
||||
keyRing gomatrixserverlib.JSONVerifier,
|
||||
) {
|
||||
rateLimits := httputil.NewRateLimits(&cfg.ClientAPI.RateLimiting)
|
||||
|
@ -95,12 +96,12 @@ func Setup(
|
|||
MXCToResult: map[string]*types.RemoteRequestResult{},
|
||||
}
|
||||
|
||||
downloadHandler := makeDownloadAPI("download_unauthed", &cfg.MediaAPI, rateLimits, db, client, activeRemoteRequests, activeThumbnailGeneration, false)
|
||||
downloadHandler := makeDownloadAPI("download_unauthed", &cfg.MediaAPI, rateLimits, db, client, federationClient, activeRemoteRequests, activeThumbnailGeneration, false)
|
||||
v3mux.Handle("/download/{serverName}/{mediaId}", downloadHandler).Methods(http.MethodGet, http.MethodOptions)
|
||||
v3mux.Handle("/download/{serverName}/{mediaId}/{downloadName}", downloadHandler).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/thumbnail/{serverName}/{mediaId}",
|
||||
makeDownloadAPI("thumbnail_unauthed", &cfg.MediaAPI, rateLimits, db, client, activeRemoteRequests, activeThumbnailGeneration, false),
|
||||
makeDownloadAPI("thumbnail_unauthed", &cfg.MediaAPI, rateLimits, db, client, federationClient, activeRemoteRequests, activeThumbnailGeneration, false),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
// v1 client endpoints requiring auth
|
||||
|
@ -110,15 +111,15 @@ func Setup(
|
|||
v1mux.Handle("/download/{serverName}/{mediaId}/{downloadName}", downloadHandlerAuthed).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v1mux.Handle("/thumbnail/{serverName}/{mediaId}",
|
||||
httputil.MakeHTMLAPI("thumbnail", userAPI, cfg.Global.Metrics.Enabled, makeDownloadAPI("thumbnail_authed_client", &cfg.MediaAPI, rateLimits, db, client, activeRemoteRequests, activeThumbnailGeneration, false), httputil.WithAuth()),
|
||||
httputil.MakeHTMLAPI("thumbnail", userAPI, cfg.Global.Metrics.Enabled, makeDownloadAPI("thumbnail_authed_client", &cfg.MediaAPI, rateLimits, db, client, federationClient, activeRemoteRequests, activeThumbnailGeneration, false), httputil.WithAuth()),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
// same, but for federation
|
||||
v1fedMux.Handle("/download/{mediaId}", routing.MakeFedAPIHTML(cfg.Global.ServerName, cfg.Global.IsLocalServerName, keyRing,
|
||||
makeDownloadAPI("download_authed_federation", &cfg.MediaAPI, rateLimits, db, client, activeRemoteRequests, activeThumbnailGeneration, true),
|
||||
makeDownloadAPI("download_authed_federation", &cfg.MediaAPI, rateLimits, db, client, federationClient, activeRemoteRequests, activeThumbnailGeneration, true),
|
||||
)).Methods(http.MethodGet, http.MethodOptions)
|
||||
v1fedMux.Handle("/thumbnail/{mediaId}", routing.MakeFedAPIHTML(cfg.Global.ServerName, cfg.Global.IsLocalServerName, keyRing,
|
||||
makeDownloadAPI("thumbnail_authed_federation", &cfg.MediaAPI, rateLimits, db, client, activeRemoteRequests, activeThumbnailGeneration, true),
|
||||
makeDownloadAPI("thumbnail_authed_federation", &cfg.MediaAPI, rateLimits, db, client, federationClient, activeRemoteRequests, activeThumbnailGeneration, true),
|
||||
)).Methods(http.MethodGet, http.MethodOptions)
|
||||
}
|
||||
|
||||
|
@ -170,6 +171,7 @@ func makeDownloadAPI(
|
|||
rateLimits *httputil.RateLimits,
|
||||
db storage.Database,
|
||||
client *fclient.Client,
|
||||
fedClient fclient.FederationClient,
|
||||
activeRemoteRequests *types.ActiveRemoteRequests,
|
||||
activeThumbnailGeneration *types.ActiveThumbnailGeneration,
|
||||
forFederation bool,
|
||||
|
@ -235,6 +237,7 @@ func makeDownloadAPI(
|
|||
cfg,
|
||||
db,
|
||||
client,
|
||||
fedClient,
|
||||
activeRemoteRequests,
|
||||
activeThumbnailGeneration,
|
||||
strings.HasPrefix(name, "thumbnail"),
|
||||
|
|
|
@ -78,7 +78,7 @@ func (m *Monolith) AddAllPublicRoutes(
|
|||
federationapi.AddPublicRoutes(
|
||||
processCtx, routers, cfg, natsInstance, m.UserAPI, m.FedClient, m.KeyRing, m.RoomserverAPI, m.FederationAPI, enableMetrics,
|
||||
)
|
||||
mediaapi.AddPublicRoutes(routers, cm, cfg, m.UserAPI, m.Client, m.KeyRing)
|
||||
mediaapi.AddPublicRoutes(routers, cm, cfg, m.UserAPI, m.Client, m.FedClient, m.KeyRing)
|
||||
syncapi.AddPublicRoutes(processCtx, routers, cfg, cm, natsInstance, m.UserAPI, m.RoomserverAPI, caches, enableMetrics)
|
||||
|
||||
if m.RelayAPI != nil {
|
||||
|
|
Loading…
Reference in a new issue