Add support for authenticated federation media requests

This commit is contained in:
Till Faelligen 2024-06-18 12:36:35 +02:00
parent ccd337e1d4
commit 29ee5401ee
No known key found for this signature in database
GPG key ID: ACCDC9606D472758
5 changed files with 183 additions and 17 deletions

View file

@ -16,6 +16,7 @@ package routing
import (
"context"
"encoding/json"
"fmt"
"net/http"
"sync"
@ -678,6 +679,53 @@ func MakeFedAPI(
return httputil.MakeExternalAPI(metricsName, h)
}
// MakeFedAPI makes an http.Handler that checks matrix federation authentication.
func MakeFedAPIHTML(
serverName spec.ServerName,
isLocalServerName func(spec.ServerName) bool,
keyRing gomatrixserverlib.JSONVerifier,
f func(http.ResponseWriter, *http.Request),
) http.Handler {
h := func(w http.ResponseWriter, req *http.Request) {
fedReq, errResp := fclient.VerifyHTTPRequest(
req, time.Now(), serverName, isLocalServerName, keyRing,
)
enc := json.NewEncoder(w)
logger := util.GetLogger(req.Context())
if fedReq == nil {
logger.Debugf("VerifyUserFromRequest %s -> HTTP %d", req.RemoteAddr, errResp.Code)
w.WriteHeader(errResp.Code)
if err := enc.Encode(errResp); err != nil {
logger.WithError(err).Error("failed to encode JSON response")
}
return
}
// add the user to Sentry, if enabled
hub := sentry.GetHubFromContext(req.Context())
if hub != nil {
// clone the hub, so we don't send garbage events with e.g. mismatching rooms/event_ids
hub = hub.Clone()
hub.Scope().SetTag("origin", string(fedReq.Origin()))
hub.Scope().SetTag("uri", fedReq.RequestURI())
}
defer func() {
if r := recover(); r != nil {
if hub != nil {
hub.CaptureException(fmt.Errorf("%s panicked", req.URL.Path))
}
// re-panic to return the 500
panic(r)
}
}()
f(w, req)
}
return http.HandlerFunc(h)
}
type FederationWakeups struct {
FsAPI *fedInternal.FederationInternalAPI
origins sync.Map

View file

@ -21,6 +21,7 @@ import (
"github.com/matrix-org/dendrite/mediaapi/storage"
"github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/fclient"
"github.com/sirupsen/logrus"
)
@ -32,6 +33,7 @@ func AddPublicRoutes(
cfg *config.Dendrite,
userAPI userapi.MediaUserAPI,
client *fclient.Client,
keyRing gomatrixserverlib.JSONVerifier,
) {
mediaDB, err := storage.NewMediaAPIDatasource(cm, &cfg.MediaAPI.Database)
if err != nil {
@ -39,6 +41,6 @@ func AddPublicRoutes(
}
routing.Setup(
routers, cfg, mediaDB, userAPI, client,
routers, cfg, mediaDB, userAPI, client, keyRing,
)
}

View file

@ -21,7 +21,9 @@ import (
"io"
"io/fs"
"mime"
"mime/multipart"
"net/http"
"net/textproto"
"net/url"
"os"
"path/filepath"
@ -31,6 +33,7 @@ import (
"sync"
"unicode"
"github.com/google/uuid"
"github.com/matrix-org/dendrite/mediaapi/fileutils"
"github.com/matrix-org/dendrite/mediaapi/storage"
"github.com/matrix-org/dendrite/mediaapi/thumbnailer"
@ -61,6 +64,7 @@ type downloadRequest struct {
ThumbnailSize types.ThumbnailSize
Logger *log.Entry
DownloadFilename string
forFederation bool // whether we need to return a multipart/mixed response
}
// Taken from: https://github.com/matrix-org/synapse/blob/c3627d0f99ed5a23479305dc2bd0e71ca25ce2b1/synapse/media/_base.py#L53C1-L84
@ -115,7 +119,12 @@ func Download(
activeThumbnailGeneration *types.ActiveThumbnailGeneration,
isThumbnailRequest bool,
customFilename string,
forFederation bool,
) {
// This happens if we call Download for a federation request
if forFederation && origin == "" {
origin = cfg.Matrix.ServerName
}
dReq := &downloadRequest{
MediaMetadata: &types.MediaMetadata{
MediaID: mediaID,
@ -127,6 +136,7 @@ func Download(
"MediaID": mediaID,
}),
DownloadFilename: customFilename,
forFederation: forFederation,
}
if dReq.IsThumbnailRequest {
@ -369,8 +379,49 @@ func (r *downloadRequest) respondFromLocalFile(
" object-src 'self';"
w.Header().Set("Content-Security-Policy", contentSecurityPolicy)
if _, err := io.Copy(w, responseFile); err != nil {
return nil, fmt.Errorf("io.Copy: %w", err)
if !r.forFederation {
if _, err := io.Copy(w, responseFile); err != nil {
return nil, fmt.Errorf("io.Copy: %w", err)
}
} else {
// Update the header to be multipart/mixed; boundary=$randomBoundary
boundary := uuid.NewString()
w.Header().Set("Content-Type", "multipart/mixed; boundary="+boundary)
w.Header().Del("Content-Length") // let Go handle the content length
w.Header().Del("Content-Security-Policy") // S-S request, so does not really matter?
mw := multipart.NewWriter(w)
defer func() {
if err = mw.Close(); err != nil {
r.Logger.WithError(err).Error("Failed to close multipart writer")
}
}()
if err = mw.SetBoundary(boundary); err != nil {
return nil, fmt.Errorf("failed to set multipart boundary: %w", err)
}
// JSON object part
jsonWriter, err := mw.CreatePart(textproto.MIMEHeader{
"Content-Type": {"application/json"},
})
if err != nil {
return nil, fmt.Errorf("failed to create json writer: %w", err)
}
if _, err = jsonWriter.Write([]byte("{}")); err != nil {
return nil, fmt.Errorf("failed to write to json writer: %w", err)
}
// media part
mediaWriter, err := mw.CreatePart(textproto.MIMEHeader{
"Content-Type": {string(responseMetadata.ContentType)},
})
if err != nil {
return nil, fmt.Errorf("failed to create media writer: %w", err)
}
if _, err = io.Copy(mediaWriter, responseFile); err != nil {
return nil, fmt.Errorf("failed to write to media writer: %w", err)
}
}
return responseMetadata, nil
}

View file

@ -20,11 +20,13 @@ import (
"strings"
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/federationapi/routing"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/mediaapi/storage"
"github.com/matrix-org/dendrite/mediaapi/types"
"github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/fclient"
"github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/util"
@ -50,11 +52,13 @@ func Setup(
db storage.Database,
userAPI userapi.MediaUserAPI,
client *fclient.Client,
keyRing gomatrixserverlib.JSONVerifier,
) {
rateLimits := httputil.NewRateLimits(&cfg.ClientAPI.RateLimiting)
v3mux := routers.Media.PathPrefix("/{apiversion:(?:r0|v1|v3)}/").Subrouter()
v1mux := routers.Client.PathPrefix("/v1/media/").Subrouter()
v1fedMux := routers.Federation.PathPrefix("/v1/media/").Subrouter()
activeThumbnailGeneration := &types.ActiveThumbnailGeneration{
PathToResult: map[string]*types.ThumbnailGenerationResult{},
@ -91,24 +95,75 @@ func Setup(
MXCToResult: map[string]*types.RemoteRequestResult{},
}
downloadHandler := makeDownloadAPI("download", &cfg.MediaAPI, rateLimits, db, client, activeRemoteRequests, activeThumbnailGeneration)
downloadHandler := makeDownloadAPI("download_unauthed", &cfg.MediaAPI, rateLimits, db, client, activeRemoteRequests, activeThumbnailGeneration, false)
v3mux.Handle("/download/{serverName}/{mediaId}", downloadHandler).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/download/{serverName}/{mediaId}/{downloadName}", downloadHandler).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/thumbnail/{serverName}/{mediaId}",
makeDownloadAPI("thumbnail", &cfg.MediaAPI, rateLimits, db, client, activeRemoteRequests, activeThumbnailGeneration),
makeDownloadAPI("thumbnail_unauthed", &cfg.MediaAPI, rateLimits, db, client, activeRemoteRequests, activeThumbnailGeneration, false),
).Methods(http.MethodGet, http.MethodOptions)
// v1 client endpoints requiring auth
downloadHandlerAuthed := httputil.MakeHTMLAPI("download_authed_client", userAPI, cfg.Global.Metrics.Enabled, downloadHandler, httputil.WithAuth())
v1mux.Handle("/config", configHandler).Methods(http.MethodGet, http.MethodOptions)
v1mux.Handle("/download/{serverName}/{mediaId}", httputil.MakeHTMLAPI("download", userAPI, cfg.Global.Metrics.Enabled, downloadHandler, httputil.WithAuth())).Methods(http.MethodGet, http.MethodOptions)
v1mux.Handle("/download/{serverName}/{mediaId}/{downloadName}", httputil.MakeHTMLAPI("download", userAPI, cfg.Global.Metrics.Enabled, downloadHandler, httputil.WithAuth())).Methods(http.MethodGet, http.MethodOptions)
v1mux.Handle("/download/{serverName}/{mediaId}", downloadHandlerAuthed).Methods(http.MethodGet, http.MethodOptions)
v1mux.Handle("/download/{serverName}/{mediaId}/{downloadName}", downloadHandlerAuthed).Methods(http.MethodGet, http.MethodOptions)
v1mux.Handle("/thumbnail/{serverName}/{mediaId}",
httputil.MakeHTMLAPI("thumbnail", userAPI, cfg.Global.Metrics.Enabled, makeDownloadAPI("thumbnail", &cfg.MediaAPI, rateLimits, db, client, activeRemoteRequests, activeThumbnailGeneration), httputil.WithAuth()),
httputil.MakeHTMLAPI("thumbnail", userAPI, cfg.Global.Metrics.Enabled, makeDownloadAPI("thumbnail_authed_client", &cfg.MediaAPI, rateLimits, db, client, activeRemoteRequests, activeThumbnailGeneration, false), httputil.WithAuth()),
).Methods(http.MethodGet, http.MethodOptions)
// same, but for federation
v1fedMux.Handle("/download/{mediaId}", routing.MakeFedAPIHTML(cfg.Global.ServerName, cfg.Global.IsLocalServerName, keyRing,
makeDownloadAPI("download_authed_federation", &cfg.MediaAPI, rateLimits, db, client, activeRemoteRequests, activeThumbnailGeneration, true),
)).Methods(http.MethodGet, http.MethodOptions)
v1fedMux.Handle("/thumbnail/{mediaId}", routing.MakeFedAPIHTML(cfg.Global.ServerName, cfg.Global.IsLocalServerName, keyRing,
makeDownloadAPI("thumbnail_authed_federation", &cfg.MediaAPI, rateLimits, db, client, activeRemoteRequests, activeThumbnailGeneration, true),
)).Methods(http.MethodGet, http.MethodOptions)
}
var thumbnailCounter = promauto.NewCounterVec(
prometheus.CounterOpts{
Namespace: "dendrite",
Subsystem: "mediaapi",
Name: "thumbnail",
Help: "Total number of media_api requests for thumbnails",
},
[]string{"code", "type"},
)
var thumbnailSize = promauto.NewHistogramVec(
prometheus.HistogramOpts{
Namespace: "dendrite",
Subsystem: "mediaapi",
Name: "thumbnail_size_bytes",
Help: "Total number of media_api requests for thumbnails",
Buckets: []float64{50, 100, 200, 500, 900, 1500, 3000, 6000},
},
[]string{"code", "type"},
)
var downloadCounter = promauto.NewCounterVec(
prometheus.CounterOpts{
Namespace: "dendrite",
Subsystem: "mediaapi",
Name: "download",
Help: "Total number of media_api requests for full downloads",
},
[]string{"code", "type"},
)
var downloadSize = promauto.NewHistogramVec(
prometheus.HistogramOpts{
Namespace: "dendrite",
Subsystem: "mediaapi",
Name: "download_size_bytes",
Help: "Total number of media_api requests for full downloads",
Buckets: []float64{200, 500, 900, 1500, 3000, 6000, 10_000, 50_000, 100_000},
},
[]string{"code", "type"},
)
func makeDownloadAPI(
name string,
cfg *config.MediaAPI,
@ -117,16 +172,22 @@ func makeDownloadAPI(
client *fclient.Client,
activeRemoteRequests *types.ActiveRemoteRequests,
activeThumbnailGeneration *types.ActiveThumbnailGeneration,
forFederation bool,
) http.HandlerFunc {
var counterVec *prometheus.CounterVec
var sizeVec *prometheus.HistogramVec
var requestType string
if cfg.Matrix.Metrics.Enabled {
counterVec = promauto.NewCounterVec(
prometheus.CounterOpts{
Name: name,
Help: "Total number of media_api requests for either thumbnails or full downloads",
},
[]string{"code"},
)
split := strings.Split(name, "_")
name = split[0]
requestType = strings.Join(split[1:], "_")
counterVec = thumbnailCounter
sizeVec = thumbnailSize
if name != "thumbnail" {
counterVec = downloadCounter
sizeVec = downloadSize
}
}
httpHandler := func(w http.ResponseWriter, req *http.Request) {
req = util.RequestWithLogging(req)
@ -176,14 +237,18 @@ func makeDownloadAPI(
client,
activeRemoteRequests,
activeThumbnailGeneration,
name == "thumbnail",
strings.HasPrefix(name, "thumbnail"),
vars["downloadName"],
forFederation,
)
}
var handlerFunc http.HandlerFunc
if counterVec != nil {
counterVec = counterVec.MustCurryWith(prometheus.Labels{"type": requestType})
sizeVec2 := sizeVec.MustCurryWith(prometheus.Labels{"type": requestType})
handlerFunc = promhttp.InstrumentHandlerCounter(counterVec, http.HandlerFunc(httpHandler))
handlerFunc = promhttp.InstrumentHandlerResponseSize(sizeVec2, handlerFunc).ServeHTTP
} else {
handlerFunc = http.HandlerFunc(httpHandler)
}

View file

@ -78,7 +78,7 @@ func (m *Monolith) AddAllPublicRoutes(
federationapi.AddPublicRoutes(
processCtx, routers, cfg, natsInstance, m.UserAPI, m.FedClient, m.KeyRing, m.RoomserverAPI, m.FederationAPI, enableMetrics,
)
mediaapi.AddPublicRoutes(routers, cm, cfg, m.UserAPI, m.Client)
mediaapi.AddPublicRoutes(routers, cm, cfg, m.UserAPI, m.Client, m.KeyRing)
syncapi.AddPublicRoutes(processCtx, routers, cfg, cm, natsInstance, m.UserAPI, m.RoomserverAPI, caches, enableMetrics)
if m.RelayAPI != nil {