Add support for getting authenticated federation media requests
This commit is contained in:
parent
29ee5401ee
commit
89d64c6dca
|
@ -33,6 +33,7 @@ func AddPublicRoutes(
|
||||||
cfg *config.Dendrite,
|
cfg *config.Dendrite,
|
||||||
userAPI userapi.MediaUserAPI,
|
userAPI userapi.MediaUserAPI,
|
||||||
client *fclient.Client,
|
client *fclient.Client,
|
||||||
|
fedClient fclient.FederationClient,
|
||||||
keyRing gomatrixserverlib.JSONVerifier,
|
keyRing gomatrixserverlib.JSONVerifier,
|
||||||
) {
|
) {
|
||||||
mediaDB, err := storage.NewMediaAPIDatasource(cm, &cfg.MediaAPI.Database)
|
mediaDB, err := storage.NewMediaAPIDatasource(cm, &cfg.MediaAPI.Database)
|
||||||
|
@ -41,6 +42,6 @@ func AddPublicRoutes(
|
||||||
}
|
}
|
||||||
|
|
||||||
routing.Setup(
|
routing.Setup(
|
||||||
routers, cfg, mediaDB, userAPI, client, keyRing,
|
routers, cfg, mediaDB, userAPI, client, fedClient, keyRing,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
|
@ -65,6 +65,8 @@ type downloadRequest struct {
|
||||||
Logger *log.Entry
|
Logger *log.Entry
|
||||||
DownloadFilename string
|
DownloadFilename string
|
||||||
forFederation bool // whether we need to return a multipart/mixed response
|
forFederation bool // whether we need to return a multipart/mixed response
|
||||||
|
fedClient fclient.FederationClient
|
||||||
|
origin spec.ServerName
|
||||||
}
|
}
|
||||||
|
|
||||||
// Taken from: https://github.com/matrix-org/synapse/blob/c3627d0f99ed5a23479305dc2bd0e71ca25ce2b1/synapse/media/_base.py#L53C1-L84
|
// Taken from: https://github.com/matrix-org/synapse/blob/c3627d0f99ed5a23479305dc2bd0e71ca25ce2b1/synapse/media/_base.py#L53C1-L84
|
||||||
|
@ -115,6 +117,7 @@ func Download(
|
||||||
cfg *config.MediaAPI,
|
cfg *config.MediaAPI,
|
||||||
db storage.Database,
|
db storage.Database,
|
||||||
client *fclient.Client,
|
client *fclient.Client,
|
||||||
|
fedClient fclient.FederationClient,
|
||||||
activeRemoteRequests *types.ActiveRemoteRequests,
|
activeRemoteRequests *types.ActiveRemoteRequests,
|
||||||
activeThumbnailGeneration *types.ActiveThumbnailGeneration,
|
activeThumbnailGeneration *types.ActiveThumbnailGeneration,
|
||||||
isThumbnailRequest bool,
|
isThumbnailRequest bool,
|
||||||
|
@ -137,6 +140,8 @@ func Download(
|
||||||
}),
|
}),
|
||||||
DownloadFilename: customFilename,
|
DownloadFilename: customFilename,
|
||||||
forFederation: forFederation,
|
forFederation: forFederation,
|
||||||
|
origin: cfg.Matrix.ServerName,
|
||||||
|
fedClient: fedClient,
|
||||||
}
|
}
|
||||||
|
|
||||||
if dReq.IsThumbnailRequest {
|
if dReq.IsThumbnailRequest {
|
||||||
|
@ -773,8 +778,7 @@ func (r *downloadRequest) fetchRemoteFileAndStoreMetadata(
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *downloadRequest) GetContentLengthAndReader(contentLengthHeader string, body *io.ReadCloser, maxFileSizeBytes config.FileSizeBytes) (int64, io.Reader, error) {
|
func (r *downloadRequest) GetContentLengthAndReader(contentLengthHeader string, reader io.ReadCloser, maxFileSizeBytes config.FileSizeBytes) (int64, io.Reader, error) {
|
||||||
reader := *body
|
|
||||||
var contentLength int64
|
var contentLength int64
|
||||||
|
|
||||||
if contentLengthHeader != "" {
|
if contentLengthHeader != "" {
|
||||||
|
@ -793,7 +797,7 @@ func (r *downloadRequest) GetContentLengthAndReader(contentLengthHeader string,
|
||||||
|
|
||||||
// We successfully parsed the Content-Length, so we'll return a limited
|
// We successfully parsed the Content-Length, so we'll return a limited
|
||||||
// reader that restricts us to reading only up to this size.
|
// reader that restricts us to reading only up to this size.
|
||||||
reader = io.NopCloser(io.LimitReader(*body, parsedLength))
|
reader = io.NopCloser(io.LimitReader(reader, parsedLength))
|
||||||
contentLength = parsedLength
|
contentLength = parsedLength
|
||||||
} else {
|
} else {
|
||||||
// Content-Length header is missing. If we have a maximum file size
|
// Content-Length header is missing. If we have a maximum file size
|
||||||
|
@ -802,7 +806,7 @@ func (r *downloadRequest) GetContentLengthAndReader(contentLengthHeader string,
|
||||||
// ultimately it will get rewritten later when the temp file is written
|
// ultimately it will get rewritten later when the temp file is written
|
||||||
// to disk.
|
// to disk.
|
||||||
if maxFileSizeBytes > 0 {
|
if maxFileSizeBytes > 0 {
|
||||||
reader = io.NopCloser(io.LimitReader(*body, int64(maxFileSizeBytes)))
|
reader = io.NopCloser(io.LimitReader(reader, int64(maxFileSizeBytes)))
|
||||||
}
|
}
|
||||||
contentLength = 0
|
contentLength = 0
|
||||||
}
|
}
|
||||||
|
@ -818,19 +822,61 @@ func (r *downloadRequest) fetchRemoteFile(
|
||||||
) (types.Path, bool, error) {
|
) (types.Path, bool, error) {
|
||||||
r.Logger.Debug("Fetching remote file")
|
r.Logger.Debug("Fetching remote file")
|
||||||
|
|
||||||
|
// Attempt to download via authenticated media endpoint
|
||||||
|
isMultiPart := true
|
||||||
|
resp, err := r.fedClient.DownloadMedia(ctx, r.origin, r.MediaMetadata.Origin, string(r.MediaMetadata.MediaID))
|
||||||
|
if err != nil || (resp != nil && resp.StatusCode != http.StatusOK) {
|
||||||
|
isMultiPart = false
|
||||||
|
// try again on the unauthed endpoint
|
||||||
// create request for remote file
|
// create request for remote file
|
||||||
resp, err := client.CreateMediaDownloadRequest(ctx, r.MediaMetadata.Origin, string(r.MediaMetadata.MediaID))
|
resp, err = client.CreateMediaDownloadRequest(ctx, r.MediaMetadata.Origin, string(r.MediaMetadata.MediaID))
|
||||||
if err != nil || (resp != nil && resp.StatusCode != http.StatusOK) {
|
if err != nil || (resp != nil && resp.StatusCode != http.StatusOK) {
|
||||||
if resp != nil && resp.StatusCode == http.StatusNotFound {
|
if resp != nil && resp.StatusCode == http.StatusNotFound {
|
||||||
return "", false, fmt.Errorf("File with media ID %q does not exist on %s", r.MediaMetadata.MediaID, r.MediaMetadata.Origin)
|
return "", false, fmt.Errorf("File with media ID %q does not exist on %s", r.MediaMetadata.MediaID, r.MediaMetadata.Origin)
|
||||||
}
|
}
|
||||||
return "", false, fmt.Errorf("file with media ID %q could not be downloaded from %s", r.MediaMetadata.MediaID, r.MediaMetadata.Origin)
|
return "", false, fmt.Errorf("file with media ID %q could not be downloaded from %s", r.MediaMetadata.MediaID, r.MediaMetadata.Origin)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
defer resp.Body.Close() // nolint: errcheck
|
defer resp.Body.Close() // nolint: errcheck
|
||||||
|
|
||||||
|
var contentLength int64
|
||||||
|
var reader io.Reader
|
||||||
|
var parseErr error
|
||||||
|
if isMultiPart {
|
||||||
|
r.Logger.Info("Downloaded file using authenticated endpoint")
|
||||||
|
_, params, err := mime.ParseMediaType(resp.Header.Get("Content-Type"))
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
if params["boundary"] == "" {
|
||||||
|
return "", false, fmt.Errorf("no boundary header found on %s", r.MediaMetadata.Origin)
|
||||||
|
}
|
||||||
|
mr := multipart.NewReader(resp.Body, params["boundary"])
|
||||||
|
|
||||||
|
first := true
|
||||||
|
for {
|
||||||
|
p, err := mr.NextPart()
|
||||||
|
if err == io.EOF {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return "", false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !first {
|
||||||
|
readCloser := io.NopCloser(p)
|
||||||
|
contentLength, reader, parseErr = r.GetContentLengthAndReader(p.Header.Get("Content-Length"), readCloser, maxFileSizeBytes)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
first = false
|
||||||
|
}
|
||||||
|
} else {
|
||||||
// The reader returned here will be limited either by the Content-Length
|
// The reader returned here will be limited either by the Content-Length
|
||||||
// and/or the configured maximum media size.
|
// and/or the configured maximum media size.
|
||||||
contentLength, reader, parseErr := r.GetContentLengthAndReader(resp.Header.Get("Content-Length"), &resp.Body, maxFileSizeBytes)
|
contentLength, reader, parseErr = r.GetContentLengthAndReader(resp.Header.Get("Content-Length"), resp.Body, maxFileSizeBytes)
|
||||||
|
}
|
||||||
|
|
||||||
if parseErr != nil {
|
if parseErr != nil {
|
||||||
return "", false, parseErr
|
return "", false, parseErr
|
||||||
}
|
}
|
||||||
|
|
|
@ -52,6 +52,7 @@ func Setup(
|
||||||
db storage.Database,
|
db storage.Database,
|
||||||
userAPI userapi.MediaUserAPI,
|
userAPI userapi.MediaUserAPI,
|
||||||
client *fclient.Client,
|
client *fclient.Client,
|
||||||
|
federationClient fclient.FederationClient,
|
||||||
keyRing gomatrixserverlib.JSONVerifier,
|
keyRing gomatrixserverlib.JSONVerifier,
|
||||||
) {
|
) {
|
||||||
rateLimits := httputil.NewRateLimits(&cfg.ClientAPI.RateLimiting)
|
rateLimits := httputil.NewRateLimits(&cfg.ClientAPI.RateLimiting)
|
||||||
|
@ -95,12 +96,12 @@ func Setup(
|
||||||
MXCToResult: map[string]*types.RemoteRequestResult{},
|
MXCToResult: map[string]*types.RemoteRequestResult{},
|
||||||
}
|
}
|
||||||
|
|
||||||
downloadHandler := makeDownloadAPI("download_unauthed", &cfg.MediaAPI, rateLimits, db, client, activeRemoteRequests, activeThumbnailGeneration, false)
|
downloadHandler := makeDownloadAPI("download_unauthed", &cfg.MediaAPI, rateLimits, db, client, federationClient, activeRemoteRequests, activeThumbnailGeneration, false)
|
||||||
v3mux.Handle("/download/{serverName}/{mediaId}", downloadHandler).Methods(http.MethodGet, http.MethodOptions)
|
v3mux.Handle("/download/{serverName}/{mediaId}", downloadHandler).Methods(http.MethodGet, http.MethodOptions)
|
||||||
v3mux.Handle("/download/{serverName}/{mediaId}/{downloadName}", downloadHandler).Methods(http.MethodGet, http.MethodOptions)
|
v3mux.Handle("/download/{serverName}/{mediaId}/{downloadName}", downloadHandler).Methods(http.MethodGet, http.MethodOptions)
|
||||||
|
|
||||||
v3mux.Handle("/thumbnail/{serverName}/{mediaId}",
|
v3mux.Handle("/thumbnail/{serverName}/{mediaId}",
|
||||||
makeDownloadAPI("thumbnail_unauthed", &cfg.MediaAPI, rateLimits, db, client, activeRemoteRequests, activeThumbnailGeneration, false),
|
makeDownloadAPI("thumbnail_unauthed", &cfg.MediaAPI, rateLimits, db, client, federationClient, activeRemoteRequests, activeThumbnailGeneration, false),
|
||||||
).Methods(http.MethodGet, http.MethodOptions)
|
).Methods(http.MethodGet, http.MethodOptions)
|
||||||
|
|
||||||
// v1 client endpoints requiring auth
|
// v1 client endpoints requiring auth
|
||||||
|
@ -110,15 +111,15 @@ func Setup(
|
||||||
v1mux.Handle("/download/{serverName}/{mediaId}/{downloadName}", downloadHandlerAuthed).Methods(http.MethodGet, http.MethodOptions)
|
v1mux.Handle("/download/{serverName}/{mediaId}/{downloadName}", downloadHandlerAuthed).Methods(http.MethodGet, http.MethodOptions)
|
||||||
|
|
||||||
v1mux.Handle("/thumbnail/{serverName}/{mediaId}",
|
v1mux.Handle("/thumbnail/{serverName}/{mediaId}",
|
||||||
httputil.MakeHTMLAPI("thumbnail", userAPI, cfg.Global.Metrics.Enabled, makeDownloadAPI("thumbnail_authed_client", &cfg.MediaAPI, rateLimits, db, client, activeRemoteRequests, activeThumbnailGeneration, false), httputil.WithAuth()),
|
httputil.MakeHTMLAPI("thumbnail", userAPI, cfg.Global.Metrics.Enabled, makeDownloadAPI("thumbnail_authed_client", &cfg.MediaAPI, rateLimits, db, client, federationClient, activeRemoteRequests, activeThumbnailGeneration, false), httputil.WithAuth()),
|
||||||
).Methods(http.MethodGet, http.MethodOptions)
|
).Methods(http.MethodGet, http.MethodOptions)
|
||||||
|
|
||||||
// same, but for federation
|
// same, but for federation
|
||||||
v1fedMux.Handle("/download/{mediaId}", routing.MakeFedAPIHTML(cfg.Global.ServerName, cfg.Global.IsLocalServerName, keyRing,
|
v1fedMux.Handle("/download/{mediaId}", routing.MakeFedAPIHTML(cfg.Global.ServerName, cfg.Global.IsLocalServerName, keyRing,
|
||||||
makeDownloadAPI("download_authed_federation", &cfg.MediaAPI, rateLimits, db, client, activeRemoteRequests, activeThumbnailGeneration, true),
|
makeDownloadAPI("download_authed_federation", &cfg.MediaAPI, rateLimits, db, client, federationClient, activeRemoteRequests, activeThumbnailGeneration, true),
|
||||||
)).Methods(http.MethodGet, http.MethodOptions)
|
)).Methods(http.MethodGet, http.MethodOptions)
|
||||||
v1fedMux.Handle("/thumbnail/{mediaId}", routing.MakeFedAPIHTML(cfg.Global.ServerName, cfg.Global.IsLocalServerName, keyRing,
|
v1fedMux.Handle("/thumbnail/{mediaId}", routing.MakeFedAPIHTML(cfg.Global.ServerName, cfg.Global.IsLocalServerName, keyRing,
|
||||||
makeDownloadAPI("thumbnail_authed_federation", &cfg.MediaAPI, rateLimits, db, client, activeRemoteRequests, activeThumbnailGeneration, true),
|
makeDownloadAPI("thumbnail_authed_federation", &cfg.MediaAPI, rateLimits, db, client, federationClient, activeRemoteRequests, activeThumbnailGeneration, true),
|
||||||
)).Methods(http.MethodGet, http.MethodOptions)
|
)).Methods(http.MethodGet, http.MethodOptions)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -170,6 +171,7 @@ func makeDownloadAPI(
|
||||||
rateLimits *httputil.RateLimits,
|
rateLimits *httputil.RateLimits,
|
||||||
db storage.Database,
|
db storage.Database,
|
||||||
client *fclient.Client,
|
client *fclient.Client,
|
||||||
|
fedClient fclient.FederationClient,
|
||||||
activeRemoteRequests *types.ActiveRemoteRequests,
|
activeRemoteRequests *types.ActiveRemoteRequests,
|
||||||
activeThumbnailGeneration *types.ActiveThumbnailGeneration,
|
activeThumbnailGeneration *types.ActiveThumbnailGeneration,
|
||||||
forFederation bool,
|
forFederation bool,
|
||||||
|
@ -235,6 +237,7 @@ func makeDownloadAPI(
|
||||||
cfg,
|
cfg,
|
||||||
db,
|
db,
|
||||||
client,
|
client,
|
||||||
|
fedClient,
|
||||||
activeRemoteRequests,
|
activeRemoteRequests,
|
||||||
activeThumbnailGeneration,
|
activeThumbnailGeneration,
|
||||||
strings.HasPrefix(name, "thumbnail"),
|
strings.HasPrefix(name, "thumbnail"),
|
||||||
|
|
|
@ -78,7 +78,7 @@ func (m *Monolith) AddAllPublicRoutes(
|
||||||
federationapi.AddPublicRoutes(
|
federationapi.AddPublicRoutes(
|
||||||
processCtx, routers, cfg, natsInstance, m.UserAPI, m.FedClient, m.KeyRing, m.RoomserverAPI, m.FederationAPI, enableMetrics,
|
processCtx, routers, cfg, natsInstance, m.UserAPI, m.FedClient, m.KeyRing, m.RoomserverAPI, m.FederationAPI, enableMetrics,
|
||||||
)
|
)
|
||||||
mediaapi.AddPublicRoutes(routers, cm, cfg, m.UserAPI, m.Client, m.KeyRing)
|
mediaapi.AddPublicRoutes(routers, cm, cfg, m.UserAPI, m.Client, m.FedClient, m.KeyRing)
|
||||||
syncapi.AddPublicRoutes(processCtx, routers, cfg, cm, natsInstance, m.UserAPI, m.RoomserverAPI, caches, enableMetrics)
|
syncapi.AddPublicRoutes(processCtx, routers, cfg, cm, natsInstance, m.UserAPI, m.RoomserverAPI, caches, enableMetrics)
|
||||||
|
|
||||||
if m.RelayAPI != nil {
|
if m.RelayAPI != nil {
|
||||||
|
|
Loading…
Reference in a new issue