Add support for authenticated federation media requests

This commit is contained in:
Till Faelligen 2024-06-18 12:36:35 +02:00
parent ccd337e1d4
commit 29ee5401ee
No known key found for this signature in database
GPG key ID: ACCDC9606D472758
5 changed files with 183 additions and 17 deletions

View file

@ -16,6 +16,7 @@ package routing
import ( import (
"context" "context"
"encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"sync" "sync"
@ -678,6 +679,53 @@ func MakeFedAPI(
return httputil.MakeExternalAPI(metricsName, h) return httputil.MakeExternalAPI(metricsName, h)
} }
// MakeFedAPI makes an http.Handler that checks matrix federation authentication.
func MakeFedAPIHTML(
serverName spec.ServerName,
isLocalServerName func(spec.ServerName) bool,
keyRing gomatrixserverlib.JSONVerifier,
f func(http.ResponseWriter, *http.Request),
) http.Handler {
h := func(w http.ResponseWriter, req *http.Request) {
fedReq, errResp := fclient.VerifyHTTPRequest(
req, time.Now(), serverName, isLocalServerName, keyRing,
)
enc := json.NewEncoder(w)
logger := util.GetLogger(req.Context())
if fedReq == nil {
logger.Debugf("VerifyUserFromRequest %s -> HTTP %d", req.RemoteAddr, errResp.Code)
w.WriteHeader(errResp.Code)
if err := enc.Encode(errResp); err != nil {
logger.WithError(err).Error("failed to encode JSON response")
}
return
}
// add the user to Sentry, if enabled
hub := sentry.GetHubFromContext(req.Context())
if hub != nil {
// clone the hub, so we don't send garbage events with e.g. mismatching rooms/event_ids
hub = hub.Clone()
hub.Scope().SetTag("origin", string(fedReq.Origin()))
hub.Scope().SetTag("uri", fedReq.RequestURI())
}
defer func() {
if r := recover(); r != nil {
if hub != nil {
hub.CaptureException(fmt.Errorf("%s panicked", req.URL.Path))
}
// re-panic to return the 500
panic(r)
}
}()
f(w, req)
}
return http.HandlerFunc(h)
}
type FederationWakeups struct { type FederationWakeups struct {
FsAPI *fedInternal.FederationInternalAPI FsAPI *fedInternal.FederationInternalAPI
origins sync.Map origins sync.Map

View file

@ -21,6 +21,7 @@ import (
"github.com/matrix-org/dendrite/mediaapi/storage" "github.com/matrix-org/dendrite/mediaapi/storage"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/gomatrixserverlib/fclient"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
@ -32,6 +33,7 @@ func AddPublicRoutes(
cfg *config.Dendrite, cfg *config.Dendrite,
userAPI userapi.MediaUserAPI, userAPI userapi.MediaUserAPI,
client *fclient.Client, client *fclient.Client,
keyRing gomatrixserverlib.JSONVerifier,
) { ) {
mediaDB, err := storage.NewMediaAPIDatasource(cm, &cfg.MediaAPI.Database) mediaDB, err := storage.NewMediaAPIDatasource(cm, &cfg.MediaAPI.Database)
if err != nil { if err != nil {
@ -39,6 +41,6 @@ func AddPublicRoutes(
} }
routing.Setup( routing.Setup(
routers, cfg, mediaDB, userAPI, client, routers, cfg, mediaDB, userAPI, client, keyRing,
) )
} }

View file

@ -21,7 +21,9 @@ import (
"io" "io"
"io/fs" "io/fs"
"mime" "mime"
"mime/multipart"
"net/http" "net/http"
"net/textproto"
"net/url" "net/url"
"os" "os"
"path/filepath" "path/filepath"
@ -31,6 +33,7 @@ import (
"sync" "sync"
"unicode" "unicode"
"github.com/google/uuid"
"github.com/matrix-org/dendrite/mediaapi/fileutils" "github.com/matrix-org/dendrite/mediaapi/fileutils"
"github.com/matrix-org/dendrite/mediaapi/storage" "github.com/matrix-org/dendrite/mediaapi/storage"
"github.com/matrix-org/dendrite/mediaapi/thumbnailer" "github.com/matrix-org/dendrite/mediaapi/thumbnailer"
@ -61,6 +64,7 @@ type downloadRequest struct {
ThumbnailSize types.ThumbnailSize ThumbnailSize types.ThumbnailSize
Logger *log.Entry Logger *log.Entry
DownloadFilename string DownloadFilename string
forFederation bool // whether we need to return a multipart/mixed response
} }
// Taken from: https://github.com/matrix-org/synapse/blob/c3627d0f99ed5a23479305dc2bd0e71ca25ce2b1/synapse/media/_base.py#L53C1-L84 // Taken from: https://github.com/matrix-org/synapse/blob/c3627d0f99ed5a23479305dc2bd0e71ca25ce2b1/synapse/media/_base.py#L53C1-L84
@ -115,7 +119,12 @@ func Download(
activeThumbnailGeneration *types.ActiveThumbnailGeneration, activeThumbnailGeneration *types.ActiveThumbnailGeneration,
isThumbnailRequest bool, isThumbnailRequest bool,
customFilename string, customFilename string,
forFederation bool,
) { ) {
// This happens if we call Download for a federation request
if forFederation && origin == "" {
origin = cfg.Matrix.ServerName
}
dReq := &downloadRequest{ dReq := &downloadRequest{
MediaMetadata: &types.MediaMetadata{ MediaMetadata: &types.MediaMetadata{
MediaID: mediaID, MediaID: mediaID,
@ -127,6 +136,7 @@ func Download(
"MediaID": mediaID, "MediaID": mediaID,
}), }),
DownloadFilename: customFilename, DownloadFilename: customFilename,
forFederation: forFederation,
} }
if dReq.IsThumbnailRequest { if dReq.IsThumbnailRequest {
@ -369,8 +379,49 @@ func (r *downloadRequest) respondFromLocalFile(
" object-src 'self';" " object-src 'self';"
w.Header().Set("Content-Security-Policy", contentSecurityPolicy) w.Header().Set("Content-Security-Policy", contentSecurityPolicy)
if _, err := io.Copy(w, responseFile); err != nil { if !r.forFederation {
return nil, fmt.Errorf("io.Copy: %w", err) if _, err := io.Copy(w, responseFile); err != nil {
return nil, fmt.Errorf("io.Copy: %w", err)
}
} else {
// Update the header to be multipart/mixed; boundary=$randomBoundary
boundary := uuid.NewString()
w.Header().Set("Content-Type", "multipart/mixed; boundary="+boundary)
w.Header().Del("Content-Length") // let Go handle the content length
w.Header().Del("Content-Security-Policy") // S-S request, so does not really matter?
mw := multipart.NewWriter(w)
defer func() {
if err = mw.Close(); err != nil {
r.Logger.WithError(err).Error("Failed to close multipart writer")
}
}()
if err = mw.SetBoundary(boundary); err != nil {
return nil, fmt.Errorf("failed to set multipart boundary: %w", err)
}
// JSON object part
jsonWriter, err := mw.CreatePart(textproto.MIMEHeader{
"Content-Type": {"application/json"},
})
if err != nil {
return nil, fmt.Errorf("failed to create json writer: %w", err)
}
if _, err = jsonWriter.Write([]byte("{}")); err != nil {
return nil, fmt.Errorf("failed to write to json writer: %w", err)
}
// media part
mediaWriter, err := mw.CreatePart(textproto.MIMEHeader{
"Content-Type": {string(responseMetadata.ContentType)},
})
if err != nil {
return nil, fmt.Errorf("failed to create media writer: %w", err)
}
if _, err = io.Copy(mediaWriter, responseFile); err != nil {
return nil, fmt.Errorf("failed to write to media writer: %w", err)
}
} }
return responseMetadata, nil return responseMetadata, nil
} }

View file

@ -20,11 +20,13 @@ import (
"strings" "strings"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/matrix-org/dendrite/federationapi/routing"
"github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/mediaapi/storage" "github.com/matrix-org/dendrite/mediaapi/storage"
"github.com/matrix-org/dendrite/mediaapi/types" "github.com/matrix-org/dendrite/mediaapi/types"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/gomatrixserverlib/fclient"
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/util" "github.com/matrix-org/util"
@ -50,11 +52,13 @@ func Setup(
db storage.Database, db storage.Database,
userAPI userapi.MediaUserAPI, userAPI userapi.MediaUserAPI,
client *fclient.Client, client *fclient.Client,
keyRing gomatrixserverlib.JSONVerifier,
) { ) {
rateLimits := httputil.NewRateLimits(&cfg.ClientAPI.RateLimiting) rateLimits := httputil.NewRateLimits(&cfg.ClientAPI.RateLimiting)
v3mux := routers.Media.PathPrefix("/{apiversion:(?:r0|v1|v3)}/").Subrouter() v3mux := routers.Media.PathPrefix("/{apiversion:(?:r0|v1|v3)}/").Subrouter()
v1mux := routers.Client.PathPrefix("/v1/media/").Subrouter() v1mux := routers.Client.PathPrefix("/v1/media/").Subrouter()
v1fedMux := routers.Federation.PathPrefix("/v1/media/").Subrouter()
activeThumbnailGeneration := &types.ActiveThumbnailGeneration{ activeThumbnailGeneration := &types.ActiveThumbnailGeneration{
PathToResult: map[string]*types.ThumbnailGenerationResult{}, PathToResult: map[string]*types.ThumbnailGenerationResult{},
@ -91,24 +95,75 @@ func Setup(
MXCToResult: map[string]*types.RemoteRequestResult{}, MXCToResult: map[string]*types.RemoteRequestResult{},
} }
downloadHandler := makeDownloadAPI("download", &cfg.MediaAPI, rateLimits, db, client, activeRemoteRequests, activeThumbnailGeneration) downloadHandler := makeDownloadAPI("download_unauthed", &cfg.MediaAPI, rateLimits, db, client, activeRemoteRequests, activeThumbnailGeneration, false)
v3mux.Handle("/download/{serverName}/{mediaId}", downloadHandler).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/download/{serverName}/{mediaId}", downloadHandler).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/download/{serverName}/{mediaId}/{downloadName}", downloadHandler).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/download/{serverName}/{mediaId}/{downloadName}", downloadHandler).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/thumbnail/{serverName}/{mediaId}", v3mux.Handle("/thumbnail/{serverName}/{mediaId}",
makeDownloadAPI("thumbnail", &cfg.MediaAPI, rateLimits, db, client, activeRemoteRequests, activeThumbnailGeneration), makeDownloadAPI("thumbnail_unauthed", &cfg.MediaAPI, rateLimits, db, client, activeRemoteRequests, activeThumbnailGeneration, false),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
// v1 client endpoints requiring auth // v1 client endpoints requiring auth
downloadHandlerAuthed := httputil.MakeHTMLAPI("download_authed_client", userAPI, cfg.Global.Metrics.Enabled, downloadHandler, httputil.WithAuth())
v1mux.Handle("/config", configHandler).Methods(http.MethodGet, http.MethodOptions) v1mux.Handle("/config", configHandler).Methods(http.MethodGet, http.MethodOptions)
v1mux.Handle("/download/{serverName}/{mediaId}", httputil.MakeHTMLAPI("download", userAPI, cfg.Global.Metrics.Enabled, downloadHandler, httputil.WithAuth())).Methods(http.MethodGet, http.MethodOptions) v1mux.Handle("/download/{serverName}/{mediaId}", downloadHandlerAuthed).Methods(http.MethodGet, http.MethodOptions)
v1mux.Handle("/download/{serverName}/{mediaId}/{downloadName}", httputil.MakeHTMLAPI("download", userAPI, cfg.Global.Metrics.Enabled, downloadHandler, httputil.WithAuth())).Methods(http.MethodGet, http.MethodOptions) v1mux.Handle("/download/{serverName}/{mediaId}/{downloadName}", downloadHandlerAuthed).Methods(http.MethodGet, http.MethodOptions)
v1mux.Handle("/thumbnail/{serverName}/{mediaId}", v1mux.Handle("/thumbnail/{serverName}/{mediaId}",
httputil.MakeHTMLAPI("thumbnail", userAPI, cfg.Global.Metrics.Enabled, makeDownloadAPI("thumbnail", &cfg.MediaAPI, rateLimits, db, client, activeRemoteRequests, activeThumbnailGeneration), httputil.WithAuth()), httputil.MakeHTMLAPI("thumbnail", userAPI, cfg.Global.Metrics.Enabled, makeDownloadAPI("thumbnail_authed_client", &cfg.MediaAPI, rateLimits, db, client, activeRemoteRequests, activeThumbnailGeneration, false), httputil.WithAuth()),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
// same, but for federation
v1fedMux.Handle("/download/{mediaId}", routing.MakeFedAPIHTML(cfg.Global.ServerName, cfg.Global.IsLocalServerName, keyRing,
makeDownloadAPI("download_authed_federation", &cfg.MediaAPI, rateLimits, db, client, activeRemoteRequests, activeThumbnailGeneration, true),
)).Methods(http.MethodGet, http.MethodOptions)
v1fedMux.Handle("/thumbnail/{mediaId}", routing.MakeFedAPIHTML(cfg.Global.ServerName, cfg.Global.IsLocalServerName, keyRing,
makeDownloadAPI("thumbnail_authed_federation", &cfg.MediaAPI, rateLimits, db, client, activeRemoteRequests, activeThumbnailGeneration, true),
)).Methods(http.MethodGet, http.MethodOptions)
} }
var thumbnailCounter = promauto.NewCounterVec(
prometheus.CounterOpts{
Namespace: "dendrite",
Subsystem: "mediaapi",
Name: "thumbnail",
Help: "Total number of media_api requests for thumbnails",
},
[]string{"code", "type"},
)
var thumbnailSize = promauto.NewHistogramVec(
prometheus.HistogramOpts{
Namespace: "dendrite",
Subsystem: "mediaapi",
Name: "thumbnail_size_bytes",
Help: "Total number of media_api requests for thumbnails",
Buckets: []float64{50, 100, 200, 500, 900, 1500, 3000, 6000},
},
[]string{"code", "type"},
)
var downloadCounter = promauto.NewCounterVec(
prometheus.CounterOpts{
Namespace: "dendrite",
Subsystem: "mediaapi",
Name: "download",
Help: "Total number of media_api requests for full downloads",
},
[]string{"code", "type"},
)
var downloadSize = promauto.NewHistogramVec(
prometheus.HistogramOpts{
Namespace: "dendrite",
Subsystem: "mediaapi",
Name: "download_size_bytes",
Help: "Total number of media_api requests for full downloads",
Buckets: []float64{200, 500, 900, 1500, 3000, 6000, 10_000, 50_000, 100_000},
},
[]string{"code", "type"},
)
func makeDownloadAPI( func makeDownloadAPI(
name string, name string,
cfg *config.MediaAPI, cfg *config.MediaAPI,
@ -117,16 +172,22 @@ func makeDownloadAPI(
client *fclient.Client, client *fclient.Client,
activeRemoteRequests *types.ActiveRemoteRequests, activeRemoteRequests *types.ActiveRemoteRequests,
activeThumbnailGeneration *types.ActiveThumbnailGeneration, activeThumbnailGeneration *types.ActiveThumbnailGeneration,
forFederation bool,
) http.HandlerFunc { ) http.HandlerFunc {
var counterVec *prometheus.CounterVec var counterVec *prometheus.CounterVec
var sizeVec *prometheus.HistogramVec
var requestType string
if cfg.Matrix.Metrics.Enabled { if cfg.Matrix.Metrics.Enabled {
counterVec = promauto.NewCounterVec( split := strings.Split(name, "_")
prometheus.CounterOpts{ name = split[0]
Name: name, requestType = strings.Join(split[1:], "_")
Help: "Total number of media_api requests for either thumbnails or full downloads",
}, counterVec = thumbnailCounter
[]string{"code"}, sizeVec = thumbnailSize
) if name != "thumbnail" {
counterVec = downloadCounter
sizeVec = downloadSize
}
} }
httpHandler := func(w http.ResponseWriter, req *http.Request) { httpHandler := func(w http.ResponseWriter, req *http.Request) {
req = util.RequestWithLogging(req) req = util.RequestWithLogging(req)
@ -176,14 +237,18 @@ func makeDownloadAPI(
client, client,
activeRemoteRequests, activeRemoteRequests,
activeThumbnailGeneration, activeThumbnailGeneration,
name == "thumbnail", strings.HasPrefix(name, "thumbnail"),
vars["downloadName"], vars["downloadName"],
forFederation,
) )
} }
var handlerFunc http.HandlerFunc var handlerFunc http.HandlerFunc
if counterVec != nil { if counterVec != nil {
counterVec = counterVec.MustCurryWith(prometheus.Labels{"type": requestType})
sizeVec2 := sizeVec.MustCurryWith(prometheus.Labels{"type": requestType})
handlerFunc = promhttp.InstrumentHandlerCounter(counterVec, http.HandlerFunc(httpHandler)) handlerFunc = promhttp.InstrumentHandlerCounter(counterVec, http.HandlerFunc(httpHandler))
handlerFunc = promhttp.InstrumentHandlerResponseSize(sizeVec2, handlerFunc).ServeHTTP
} else { } else {
handlerFunc = http.HandlerFunc(httpHandler) handlerFunc = http.HandlerFunc(httpHandler)
} }

View file

@ -78,7 +78,7 @@ func (m *Monolith) AddAllPublicRoutes(
federationapi.AddPublicRoutes( federationapi.AddPublicRoutes(
processCtx, routers, cfg, natsInstance, m.UserAPI, m.FedClient, m.KeyRing, m.RoomserverAPI, m.FederationAPI, enableMetrics, processCtx, routers, cfg, natsInstance, m.UserAPI, m.FedClient, m.KeyRing, m.RoomserverAPI, m.FederationAPI, enableMetrics,
) )
mediaapi.AddPublicRoutes(routers, cm, cfg, m.UserAPI, m.Client) mediaapi.AddPublicRoutes(routers, cm, cfg, m.UserAPI, m.Client, m.KeyRing)
syncapi.AddPublicRoutes(processCtx, routers, cfg, cm, natsInstance, m.UserAPI, m.RoomserverAPI, caches, enableMetrics) syncapi.AddPublicRoutes(processCtx, routers, cfg, cm, natsInstance, m.UserAPI, m.RoomserverAPI, caches, enableMetrics)
if m.RelayAPI != nil { if m.RelayAPI != nil {