Add support for getting authenticated federation media requests

This commit is contained in:
Till Faelligen 2024-06-18 15:15:56 +02:00
parent 29ee5401ee
commit 89d64c6dca
No known key found for this signature in database
GPG key ID: ACCDC9606D472758
4 changed files with 69 additions and 19 deletions

View file

@ -33,6 +33,7 @@ func AddPublicRoutes(
cfg *config.Dendrite, cfg *config.Dendrite,
userAPI userapi.MediaUserAPI, userAPI userapi.MediaUserAPI,
client *fclient.Client, client *fclient.Client,
fedClient fclient.FederationClient,
keyRing gomatrixserverlib.JSONVerifier, keyRing gomatrixserverlib.JSONVerifier,
) { ) {
mediaDB, err := storage.NewMediaAPIDatasource(cm, &cfg.MediaAPI.Database) mediaDB, err := storage.NewMediaAPIDatasource(cm, &cfg.MediaAPI.Database)
@ -41,6 +42,6 @@ func AddPublicRoutes(
} }
routing.Setup( routing.Setup(
routers, cfg, mediaDB, userAPI, client, keyRing, routers, cfg, mediaDB, userAPI, client, fedClient, keyRing,
) )
} }

View file

@ -65,6 +65,8 @@ type downloadRequest struct {
Logger *log.Entry Logger *log.Entry
DownloadFilename string DownloadFilename string
forFederation bool // whether we need to return a multipart/mixed response forFederation bool // whether we need to return a multipart/mixed response
fedClient fclient.FederationClient
origin spec.ServerName
} }
// Taken from: https://github.com/matrix-org/synapse/blob/c3627d0f99ed5a23479305dc2bd0e71ca25ce2b1/synapse/media/_base.py#L53C1-L84 // Taken from: https://github.com/matrix-org/synapse/blob/c3627d0f99ed5a23479305dc2bd0e71ca25ce2b1/synapse/media/_base.py#L53C1-L84
@ -115,6 +117,7 @@ func Download(
cfg *config.MediaAPI, cfg *config.MediaAPI,
db storage.Database, db storage.Database,
client *fclient.Client, client *fclient.Client,
fedClient fclient.FederationClient,
activeRemoteRequests *types.ActiveRemoteRequests, activeRemoteRequests *types.ActiveRemoteRequests,
activeThumbnailGeneration *types.ActiveThumbnailGeneration, activeThumbnailGeneration *types.ActiveThumbnailGeneration,
isThumbnailRequest bool, isThumbnailRequest bool,
@ -137,6 +140,8 @@ func Download(
}), }),
DownloadFilename: customFilename, DownloadFilename: customFilename,
forFederation: forFederation, forFederation: forFederation,
origin: cfg.Matrix.ServerName,
fedClient: fedClient,
} }
if dReq.IsThumbnailRequest { if dReq.IsThumbnailRequest {
@ -773,8 +778,7 @@ func (r *downloadRequest) fetchRemoteFileAndStoreMetadata(
return nil return nil
} }
func (r *downloadRequest) GetContentLengthAndReader(contentLengthHeader string, body *io.ReadCloser, maxFileSizeBytes config.FileSizeBytes) (int64, io.Reader, error) { func (r *downloadRequest) GetContentLengthAndReader(contentLengthHeader string, reader io.ReadCloser, maxFileSizeBytes config.FileSizeBytes) (int64, io.Reader, error) {
reader := *body
var contentLength int64 var contentLength int64
if contentLengthHeader != "" { if contentLengthHeader != "" {
@ -793,7 +797,7 @@ func (r *downloadRequest) GetContentLengthAndReader(contentLengthHeader string,
// We successfully parsed the Content-Length, so we'll return a limited // We successfully parsed the Content-Length, so we'll return a limited
// reader that restricts us to reading only up to this size. // reader that restricts us to reading only up to this size.
reader = io.NopCloser(io.LimitReader(*body, parsedLength)) reader = io.NopCloser(io.LimitReader(reader, parsedLength))
contentLength = parsedLength contentLength = parsedLength
} else { } else {
// Content-Length header is missing. If we have a maximum file size // Content-Length header is missing. If we have a maximum file size
@ -802,7 +806,7 @@ func (r *downloadRequest) GetContentLengthAndReader(contentLengthHeader string,
// ultimately it will get rewritten later when the temp file is written // ultimately it will get rewritten later when the temp file is written
// to disk. // to disk.
if maxFileSizeBytes > 0 { if maxFileSizeBytes > 0 {
reader = io.NopCloser(io.LimitReader(*body, int64(maxFileSizeBytes))) reader = io.NopCloser(io.LimitReader(reader, int64(maxFileSizeBytes)))
} }
contentLength = 0 contentLength = 0
} }
@ -818,19 +822,61 @@ func (r *downloadRequest) fetchRemoteFile(
) (types.Path, bool, error) { ) (types.Path, bool, error) {
r.Logger.Debug("Fetching remote file") r.Logger.Debug("Fetching remote file")
// Attempt to download via authenticated media endpoint
isMultiPart := true
resp, err := r.fedClient.DownloadMedia(ctx, r.origin, r.MediaMetadata.Origin, string(r.MediaMetadata.MediaID))
if err != nil || (resp != nil && resp.StatusCode != http.StatusOK) {
isMultiPart = false
// try again on the unauthed endpoint
// create request for remote file // create request for remote file
resp, err := client.CreateMediaDownloadRequest(ctx, r.MediaMetadata.Origin, string(r.MediaMetadata.MediaID)) resp, err = client.CreateMediaDownloadRequest(ctx, r.MediaMetadata.Origin, string(r.MediaMetadata.MediaID))
if err != nil || (resp != nil && resp.StatusCode != http.StatusOK) { if err != nil || (resp != nil && resp.StatusCode != http.StatusOK) {
if resp != nil && resp.StatusCode == http.StatusNotFound { if resp != nil && resp.StatusCode == http.StatusNotFound {
return "", false, fmt.Errorf("File with media ID %q does not exist on %s", r.MediaMetadata.MediaID, r.MediaMetadata.Origin) return "", false, fmt.Errorf("File with media ID %q does not exist on %s", r.MediaMetadata.MediaID, r.MediaMetadata.Origin)
} }
return "", false, fmt.Errorf("file with media ID %q could not be downloaded from %s", r.MediaMetadata.MediaID, r.MediaMetadata.Origin) return "", false, fmt.Errorf("file with media ID %q could not be downloaded from %s", r.MediaMetadata.MediaID, r.MediaMetadata.Origin)
} }
}
defer resp.Body.Close() // nolint: errcheck defer resp.Body.Close() // nolint: errcheck
var contentLength int64
var reader io.Reader
var parseErr error
if isMultiPart {
r.Logger.Info("Downloaded file using authenticated endpoint")
_, params, err := mime.ParseMediaType(resp.Header.Get("Content-Type"))
if err != nil {
panic(err)
}
if params["boundary"] == "" {
return "", false, fmt.Errorf("no boundary header found on %s", r.MediaMetadata.Origin)
}
mr := multipart.NewReader(resp.Body, params["boundary"])
first := true
for {
p, err := mr.NextPart()
if err == io.EOF {
break
}
if err != nil {
return "", false, err
}
if !first {
readCloser := io.NopCloser(p)
contentLength, reader, parseErr = r.GetContentLengthAndReader(p.Header.Get("Content-Length"), readCloser, maxFileSizeBytes)
break
}
first = false
}
} else {
// The reader returned here will be limited either by the Content-Length // The reader returned here will be limited either by the Content-Length
// and/or the configured maximum media size. // and/or the configured maximum media size.
contentLength, reader, parseErr := r.GetContentLengthAndReader(resp.Header.Get("Content-Length"), &resp.Body, maxFileSizeBytes) contentLength, reader, parseErr = r.GetContentLengthAndReader(resp.Header.Get("Content-Length"), resp.Body, maxFileSizeBytes)
}
if parseErr != nil { if parseErr != nil {
return "", false, parseErr return "", false, parseErr
} }

View file

@ -52,6 +52,7 @@ func Setup(
db storage.Database, db storage.Database,
userAPI userapi.MediaUserAPI, userAPI userapi.MediaUserAPI,
client *fclient.Client, client *fclient.Client,
federationClient fclient.FederationClient,
keyRing gomatrixserverlib.JSONVerifier, keyRing gomatrixserverlib.JSONVerifier,
) { ) {
rateLimits := httputil.NewRateLimits(&cfg.ClientAPI.RateLimiting) rateLimits := httputil.NewRateLimits(&cfg.ClientAPI.RateLimiting)
@ -95,12 +96,12 @@ func Setup(
MXCToResult: map[string]*types.RemoteRequestResult{}, MXCToResult: map[string]*types.RemoteRequestResult{},
} }
downloadHandler := makeDownloadAPI("download_unauthed", &cfg.MediaAPI, rateLimits, db, client, activeRemoteRequests, activeThumbnailGeneration, false) downloadHandler := makeDownloadAPI("download_unauthed", &cfg.MediaAPI, rateLimits, db, client, federationClient, activeRemoteRequests, activeThumbnailGeneration, false)
v3mux.Handle("/download/{serverName}/{mediaId}", downloadHandler).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/download/{serverName}/{mediaId}", downloadHandler).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/download/{serverName}/{mediaId}/{downloadName}", downloadHandler).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/download/{serverName}/{mediaId}/{downloadName}", downloadHandler).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/thumbnail/{serverName}/{mediaId}", v3mux.Handle("/thumbnail/{serverName}/{mediaId}",
makeDownloadAPI("thumbnail_unauthed", &cfg.MediaAPI, rateLimits, db, client, activeRemoteRequests, activeThumbnailGeneration, false), makeDownloadAPI("thumbnail_unauthed", &cfg.MediaAPI, rateLimits, db, client, federationClient, activeRemoteRequests, activeThumbnailGeneration, false),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
// v1 client endpoints requiring auth // v1 client endpoints requiring auth
@ -110,15 +111,15 @@ func Setup(
v1mux.Handle("/download/{serverName}/{mediaId}/{downloadName}", downloadHandlerAuthed).Methods(http.MethodGet, http.MethodOptions) v1mux.Handle("/download/{serverName}/{mediaId}/{downloadName}", downloadHandlerAuthed).Methods(http.MethodGet, http.MethodOptions)
v1mux.Handle("/thumbnail/{serverName}/{mediaId}", v1mux.Handle("/thumbnail/{serverName}/{mediaId}",
httputil.MakeHTMLAPI("thumbnail", userAPI, cfg.Global.Metrics.Enabled, makeDownloadAPI("thumbnail_authed_client", &cfg.MediaAPI, rateLimits, db, client, activeRemoteRequests, activeThumbnailGeneration, false), httputil.WithAuth()), httputil.MakeHTMLAPI("thumbnail", userAPI, cfg.Global.Metrics.Enabled, makeDownloadAPI("thumbnail_authed_client", &cfg.MediaAPI, rateLimits, db, client, federationClient, activeRemoteRequests, activeThumbnailGeneration, false), httputil.WithAuth()),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
// same, but for federation // same, but for federation
v1fedMux.Handle("/download/{mediaId}", routing.MakeFedAPIHTML(cfg.Global.ServerName, cfg.Global.IsLocalServerName, keyRing, v1fedMux.Handle("/download/{mediaId}", routing.MakeFedAPIHTML(cfg.Global.ServerName, cfg.Global.IsLocalServerName, keyRing,
makeDownloadAPI("download_authed_federation", &cfg.MediaAPI, rateLimits, db, client, activeRemoteRequests, activeThumbnailGeneration, true), makeDownloadAPI("download_authed_federation", &cfg.MediaAPI, rateLimits, db, client, federationClient, activeRemoteRequests, activeThumbnailGeneration, true),
)).Methods(http.MethodGet, http.MethodOptions) )).Methods(http.MethodGet, http.MethodOptions)
v1fedMux.Handle("/thumbnail/{mediaId}", routing.MakeFedAPIHTML(cfg.Global.ServerName, cfg.Global.IsLocalServerName, keyRing, v1fedMux.Handle("/thumbnail/{mediaId}", routing.MakeFedAPIHTML(cfg.Global.ServerName, cfg.Global.IsLocalServerName, keyRing,
makeDownloadAPI("thumbnail_authed_federation", &cfg.MediaAPI, rateLimits, db, client, activeRemoteRequests, activeThumbnailGeneration, true), makeDownloadAPI("thumbnail_authed_federation", &cfg.MediaAPI, rateLimits, db, client, federationClient, activeRemoteRequests, activeThumbnailGeneration, true),
)).Methods(http.MethodGet, http.MethodOptions) )).Methods(http.MethodGet, http.MethodOptions)
} }
@ -170,6 +171,7 @@ func makeDownloadAPI(
rateLimits *httputil.RateLimits, rateLimits *httputil.RateLimits,
db storage.Database, db storage.Database,
client *fclient.Client, client *fclient.Client,
fedClient fclient.FederationClient,
activeRemoteRequests *types.ActiveRemoteRequests, activeRemoteRequests *types.ActiveRemoteRequests,
activeThumbnailGeneration *types.ActiveThumbnailGeneration, activeThumbnailGeneration *types.ActiveThumbnailGeneration,
forFederation bool, forFederation bool,
@ -235,6 +237,7 @@ func makeDownloadAPI(
cfg, cfg,
db, db,
client, client,
fedClient,
activeRemoteRequests, activeRemoteRequests,
activeThumbnailGeneration, activeThumbnailGeneration,
strings.HasPrefix(name, "thumbnail"), strings.HasPrefix(name, "thumbnail"),

View file

@ -78,7 +78,7 @@ func (m *Monolith) AddAllPublicRoutes(
federationapi.AddPublicRoutes( federationapi.AddPublicRoutes(
processCtx, routers, cfg, natsInstance, m.UserAPI, m.FedClient, m.KeyRing, m.RoomserverAPI, m.FederationAPI, enableMetrics, processCtx, routers, cfg, natsInstance, m.UserAPI, m.FedClient, m.KeyRing, m.RoomserverAPI, m.FederationAPI, enableMetrics,
) )
mediaapi.AddPublicRoutes(routers, cm, cfg, m.UserAPI, m.Client, m.KeyRing) mediaapi.AddPublicRoutes(routers, cm, cfg, m.UserAPI, m.Client, m.FedClient, m.KeyRing)
syncapi.AddPublicRoutes(processCtx, routers, cfg, cm, natsInstance, m.UserAPI, m.RoomserverAPI, caches, enableMetrics) syncapi.AddPublicRoutes(processCtx, routers, cfg, cm, natsInstance, m.UserAPI, m.RoomserverAPI, caches, enableMetrics)
if m.RelayAPI != nil { if m.RelayAPI != nil {