Implement MSC3916 (#3397)

Needs https://github.com/matrix-org/gomatrixserverlib/pull/437
This commit is contained in:
Till 2024-08-16 12:37:59 +02:00 committed by GitHub
parent 8c6cf51b8f
commit 7a4ef240fc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 364 additions and 45 deletions

View file

@ -15,23 +15,26 @@
package mediaapi
import (
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/mediaapi/routing"
"github.com/matrix-org/dendrite/mediaapi/storage"
"github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/fclient"
"github.com/sirupsen/logrus"
)
// AddPublicRoutes sets up and registers HTTP handlers for the MediaAPI component.
func AddPublicRoutes(
mediaRouter *mux.Router,
routers httputil.Routers,
cm *sqlutil.Connections,
cfg *config.Dendrite,
userAPI userapi.MediaUserAPI,
client *fclient.Client,
fedClient fclient.FederationClient,
keyRing gomatrixserverlib.JSONVerifier,
) {
mediaDB, err := storage.NewMediaAPIDatasource(cm, &cfg.MediaAPI.Database)
if err != nil {
@ -39,6 +42,6 @@ func AddPublicRoutes(
}
routing.Setup(
mediaRouter, cfg, mediaDB, userAPI, client,
routers, cfg, mediaDB, userAPI, client, fedClient, keyRing,
)
}

View file

@ -21,7 +21,9 @@ import (
"io"
"io/fs"
"mime"
"mime/multipart"
"net/http"
"net/textproto"
"net/url"
"os"
"path/filepath"
@ -31,6 +33,7 @@ import (
"sync"
"unicode"
"github.com/google/uuid"
"github.com/matrix-org/dendrite/mediaapi/fileutils"
"github.com/matrix-org/dendrite/mediaapi/storage"
"github.com/matrix-org/dendrite/mediaapi/thumbnailer"
@ -61,6 +64,9 @@ type downloadRequest struct {
ThumbnailSize types.ThumbnailSize
Logger *log.Entry
DownloadFilename string
multipartResponse bool // whether we need to return a multipart/mixed response (for requests coming in over federation)
fedClient fclient.FederationClient
origin spec.ServerName
}
// Taken from: https://github.com/matrix-org/synapse/blob/c3627d0f99ed5a23479305dc2bd0e71ca25ce2b1/synapse/media/_base.py#L53C1-L84
@ -111,11 +117,17 @@ func Download(
cfg *config.MediaAPI,
db storage.Database,
client *fclient.Client,
fedClient fclient.FederationClient,
activeRemoteRequests *types.ActiveRemoteRequests,
activeThumbnailGeneration *types.ActiveThumbnailGeneration,
isThumbnailRequest bool,
customFilename string,
federationRequest bool,
) {
// This happens if we call Download for a federation request
if federationRequest && origin == "" {
origin = cfg.Matrix.ServerName
}
dReq := &downloadRequest{
MediaMetadata: &types.MediaMetadata{
MediaID: mediaID,
@ -126,7 +138,10 @@ func Download(
"Origin": origin,
"MediaID": mediaID,
}),
DownloadFilename: customFilename,
DownloadFilename: customFilename,
multipartResponse: federationRequest,
origin: cfg.Matrix.ServerName,
fedClient: fedClient,
}
if dReq.IsThumbnailRequest {
@ -355,7 +370,7 @@ func (r *downloadRequest) respondFromLocalFile(
}).Trace("Responding with file")
responseFile = file
responseMetadata = r.MediaMetadata
if err := r.addDownloadFilenameToHeaders(w, responseMetadata); err != nil {
if err = r.addDownloadFilenameToHeaders(w, responseMetadata); err != nil {
return nil, err
}
}
@ -367,14 +382,61 @@ func (r *downloadRequest) respondFromLocalFile(
" plugin-types application/pdf;" +
" style-src 'unsafe-inline';" +
" object-src 'self';"
w.Header().Set("Content-Security-Policy", contentSecurityPolicy)
if _, err := io.Copy(w, responseFile); err != nil {
return nil, fmt.Errorf("io.Copy: %w", err)
if !r.multipartResponse {
w.Header().Set("Content-Security-Policy", contentSecurityPolicy)
if _, err = io.Copy(w, responseFile); err != nil {
return nil, fmt.Errorf("io.Copy: %w", err)
}
} else {
var written int64
written, err = multipartResponse(w, r, string(responseMetadata.ContentType), responseFile)
if err != nil {
return nil, err
}
responseMetadata.FileSizeBytes = types.FileSizeBytes(written)
}
return responseMetadata, nil
}
// multipartResponse streams responseFile to w as a two-part multipart/mixed
// body, which is the shape used for media requested over federation:
//
//  1. an application/json part containing an empty JSON object, and
//  2. the media itself, labelled with contentType.
//
// The response Content-Type header is rewritten to multipart/mixed with a
// freshly generated boundary, and any Content-Length header is removed because
// the multipart framing changes the body size.
//
// It returns the number of media bytes copied into the second part. If writing
// any part fails, or the terminating boundary cannot be written, a non-nil
// error is returned and the byte count should not be relied upon.
func multipartResponse(w http.ResponseWriter, r *downloadRequest, contentType string, responseFile io.Reader) (written int64, err error) {
	boundary := uuid.NewString()
	mw := multipart.NewWriter(w)
	// Configure the boundary before touching the response, so that a failure
	// here neither leaves a bogus Content-Type header behind nor writes a
	// stray closing boundary to the client.
	if err = mw.SetBoundary(boundary); err != nil {
		return 0, fmt.Errorf("failed to set multipart boundary: %w", err)
	}

	// Update the header to be multipart/mixed; boundary=$randomBoundary
	w.Header().Set("Content-Type", "multipart/mixed; boundary="+boundary)
	w.Header().Del("Content-Length") // let Go handle the content length

	defer func() {
		closeErr := mw.Close()
		if closeErr == nil {
			return
		}
		if err == nil {
			// Without the terminating boundary the body is truncated, so
			// report it to the caller rather than claiming success.
			err = fmt.Errorf("failed to close multipart writer: %w", closeErr)
			return
		}
		// An earlier error is already being returned; only log this one.
		if r != nil && r.Logger != nil {
			r.Logger.WithError(closeErr).Error("Failed to close multipart writer")
		}
	}()

	// JSON object part
	jsonWriter, err := mw.CreatePart(textproto.MIMEHeader{
		"Content-Type": {"application/json"},
	})
	if err != nil {
		return 0, fmt.Errorf("failed to create json writer: %w", err)
	}
	if _, err = jsonWriter.Write([]byte("{}")); err != nil {
		return 0, fmt.Errorf("failed to write to json writer: %w", err)
	}

	// media part
	mediaWriter, err := mw.CreatePart(textproto.MIMEHeader{
		"Content-Type": {contentType},
	})
	if err != nil {
		return 0, fmt.Errorf("failed to create media writer: %w", err)
	}

	written, err = io.Copy(mediaWriter, responseFile)
	return written, err
}
func (r *downloadRequest) addDownloadFilenameToHeaders(
w http.ResponseWriter,
responseMetadata *types.MediaMetadata,
@ -722,8 +784,7 @@ func (r *downloadRequest) fetchRemoteFileAndStoreMetadata(
return nil
}
func (r *downloadRequest) GetContentLengthAndReader(contentLengthHeader string, body *io.ReadCloser, maxFileSizeBytes config.FileSizeBytes) (int64, io.Reader, error) {
reader := *body
func (r *downloadRequest) GetContentLengthAndReader(contentLengthHeader string, reader io.ReadCloser, maxFileSizeBytes config.FileSizeBytes) (int64, io.Reader, error) {
var contentLength int64
if contentLengthHeader != "" {
@ -742,7 +803,7 @@ func (r *downloadRequest) GetContentLengthAndReader(contentLengthHeader string,
// We successfully parsed the Content-Length, so we'll return a limited
// reader that restricts us to reading only up to this size.
reader = io.NopCloser(io.LimitReader(*body, parsedLength))
reader = io.NopCloser(io.LimitReader(reader, parsedLength))
contentLength = parsedLength
} else {
// Content-Length header is missing. If we have a maximum file size
@ -751,7 +812,7 @@ func (r *downloadRequest) GetContentLengthAndReader(contentLengthHeader string,
// ultimately it will get rewritten later when the temp file is written
// to disk.
if maxFileSizeBytes > 0 {
reader = io.NopCloser(io.LimitReader(*body, int64(maxFileSizeBytes)))
reader = io.NopCloser(io.LimitReader(reader, int64(maxFileSizeBytes)))
}
contentLength = 0
}
@ -759,6 +820,11 @@ func (r *downloadRequest) GetContentLengthAndReader(contentLengthHeader string,
return contentLength, reader, nil
}
// mediaMeta contains information about a multipart media response. It is the
// value the first (application/json) part of such a response is decoded into;
// it currently declares no fields.
// TODO: extend once something is defined.
type mediaMeta struct{}
// nolint: gocyclo
func (r *downloadRequest) fetchRemoteFile(
ctx context.Context,
client *fclient.Client,
@ -767,19 +833,38 @@ func (r *downloadRequest) fetchRemoteFile(
) (types.Path, bool, error) {
r.Logger.Debug("Fetching remote file")
// create request for remote file
resp, err := client.CreateMediaDownloadRequest(ctx, r.MediaMetadata.Origin, string(r.MediaMetadata.MediaID))
// Attempt to download via authenticated media endpoint
isAuthed := true
resp, err := r.fedClient.DownloadMedia(ctx, r.origin, r.MediaMetadata.Origin, string(r.MediaMetadata.MediaID))
if err != nil || (resp != nil && resp.StatusCode != http.StatusOK) {
if resp != nil && resp.StatusCode == http.StatusNotFound {
return "", false, fmt.Errorf("File with media ID %q does not exist on %s", r.MediaMetadata.MediaID, r.MediaMetadata.Origin)
isAuthed = false
// try again on the unauthed endpoint
// create request for remote file
resp, err = client.CreateMediaDownloadRequest(ctx, r.MediaMetadata.Origin, string(r.MediaMetadata.MediaID))
if err != nil || (resp != nil && resp.StatusCode != http.StatusOK) {
if resp != nil && resp.StatusCode == http.StatusNotFound {
return "", false, fmt.Errorf("File with media ID %q does not exist on %s", r.MediaMetadata.MediaID, r.MediaMetadata.Origin)
}
return "", false, fmt.Errorf("file with media ID %q could not be downloaded from %s: %w", r.MediaMetadata.MediaID, r.MediaMetadata.Origin, err)
}
return "", false, fmt.Errorf("file with media ID %q could not be downloaded from %s", r.MediaMetadata.MediaID, r.MediaMetadata.Origin)
}
defer resp.Body.Close() // nolint: errcheck
// The reader returned here will be limited either by the Content-Length
// and/or the configured maximum media size.
contentLength, reader, parseErr := r.GetContentLengthAndReader(resp.Header.Get("Content-Length"), &resp.Body, maxFileSizeBytes)
// If this wasn't a multipart response, set the Content-Type now. Will be overwritten
// by the multipart Content-Type below.
r.MediaMetadata.ContentType = types.ContentType(resp.Header.Get("Content-Type"))
var contentLength int64
var reader io.Reader
var parseErr error
if isAuthed {
parseErr, contentLength, reader = parseMultipartResponse(r, resp, maxFileSizeBytes)
} else {
// The reader returned here will be limited either by the Content-Length
// and/or the configured maximum media size.
contentLength, reader, parseErr = r.GetContentLengthAndReader(resp.Header.Get("Content-Length"), resp.Body, maxFileSizeBytes)
}
if parseErr != nil {
return "", false, parseErr
}
@ -790,7 +875,6 @@ func (r *downloadRequest) fetchRemoteFile(
}
r.MediaMetadata.FileSizeBytes = types.FileSizeBytes(contentLength)
r.MediaMetadata.ContentType = types.ContentType(resp.Header.Get("Content-Type"))
dispositionHeader := resp.Header.Get("Content-Disposition")
if _, params, e := mime.ParseMediaType(dispositionHeader); e == nil {
@ -844,6 +928,50 @@ func (r *downloadRequest) fetchRemoteFile(
return types.Path(finalPath), duplicate, nil
}
// parseMultipartResponse extracts the media from a multipart/mixed download
// response, i.e. the format produced by multipartResponse.
//
// The body must consist of an application/json part followed by the media
// part. On return r.MediaMetadata.ContentType has been set to the Content-Type
// of the media part. The returned reader yields the media bytes and is limited
// by the media part's Content-Length header or, failing that, by
// maxFileSizeBytes (see GetContentLengthAndReader); the returned length is the
// one reported by GetContentLengthAndReader.
//
// A media part carrying a Location header (a redirect) is rejected, as that is
// not supported yet.
//
// Note that the error is the FIRST return value; the order is kept as-is
// because existing callers rely on it.
func parseMultipartResponse(r *downloadRequest, resp *http.Response, maxFileSizeBytes config.FileSizeBytes) (error, int64, io.Reader) {
	_, params, err := mime.ParseMediaType(resp.Header.Get("Content-Type"))
	if err != nil {
		return err, 0, nil
	}
	boundary := params["boundary"]
	if boundary == "" {
		return fmt.Errorf("no boundary header found on media %s from %s", r.MediaMetadata.MediaID, r.MediaMetadata.Origin), 0, nil
	}
	mr := multipart.NewReader(resp.Body, boundary)

	// Get the first, JSON, part
	jsonPart, err := mr.NextPart()
	if err != nil {
		return err, 0, nil
	}
	defer jsonPart.Close() // nolint: errcheck

	// Compare only the media type so that parameters such as
	// "; charset=utf-8" do not cause a valid JSON part to be rejected. The
	// parse error is deliberately ignored: an unparsable header yields an
	// empty media type, which fails the comparison below anyway.
	jsonType, _, _ := mime.ParseMediaType(jsonPart.Header.Get("Content-Type"))
	if jsonType != "application/json" {
		return fmt.Errorf("first part of the response must be application/json"), 0, nil
	}

	// Try to parse media meta information
	meta := mediaMeta{}
	if err = json.NewDecoder(jsonPart).Decode(&meta); err != nil {
		return err, 0, nil
	}

	// Get the actual media content. This part is intentionally not closed
	// here, since it backs the reader handed back to the caller.
	mediaPart, err := mr.NextPart()
	if err != nil {
		return err, 0, nil
	}

	if mediaPart.Header.Get("Location") != "" {
		return fmt.Errorf("Location header is not yet supported"), 0, nil
	}

	contentLength, reader, err := r.GetContentLengthAndReader(mediaPart.Header.Get("Content-Length"), mediaPart, maxFileSizeBytes)
	// For multipart requests, we need to get the Content-Type of the second part, which is the actual media
	r.MediaMetadata.ContentType = types.ContentType(mediaPart.Header.Get("Content-Type"))
	return err, contentLength, reader
}
// contentDispositionFor returns the Content-Disposition for a given
// content type.
func contentDispositionFor(contentType types.ContentType) string {

View file

@ -1,8 +1,13 @@
package routing
import (
"bytes"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/matrix-org/dendrite/mediaapi/types"
"github.com/stretchr/testify/assert"
)
@ -11,3 +16,28 @@ func Test_dispositionFor(t *testing.T) {
assert.Equal(t, "attachment", contentDispositionFor("image/svg"), "image/svg")
assert.Equal(t, "inline", contentDispositionFor("image/jpeg"), "image/jpg")
}
// Test_Multipart round-trips a media body through multipartResponse (served by
// a test HTTP server) and parseMultipartResponse, and checks that both the
// payload and the media Content-Type survive.
func Test_Multipart(t *testing.T) {
	r := &downloadRequest{
		MediaMetadata: &types.MediaMetadata{},
	}
	data := bytes.Buffer{}
	responseBody := "This media is plain text. Maybe somebody used it as a paste bin."
	data.WriteString(responseBody)

	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		_, err := multipartResponse(w, r, "text/plain", &data)
		assert.NoError(t, err)
	}))
	defer srv.Close()

	resp, err := srv.Client().Get(srv.URL)
	// Stop here on failure: resp is nil in that case and must not be touched.
	if !assert.NoError(t, err) {
		return
	}
	defer resp.Body.Close() // nolint: errcheck

	// contentLength is always 0, since there's no Content-Length header on the multipart part.
	err, _, reader := parseMultipartResponse(r, resp, 1000)
	// The reader is nil when parsing fails, so do not try to read from it.
	if !assert.NoError(t, err) {
		return
	}
	gotResponse, err := io.ReadAll(reader)
	assert.NoError(t, err)
	assert.Equal(t, responseBody, string(gotResponse))
	// The Content-Type of the media part must be carried over to the metadata.
	assert.Equal(t, types.ContentType("text/plain"), r.MediaMetadata.ContentType)
}

View file

@ -20,11 +20,13 @@ import (
"strings"
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/federationapi/routing"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/mediaapi/storage"
"github.com/matrix-org/dendrite/mediaapi/types"
"github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/fclient"
"github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/util"
@ -45,15 +47,19 @@ type configResponse struct {
// applied:
// nolint: gocyclo
func Setup(
publicAPIMux *mux.Router,
routers httputil.Routers,
cfg *config.Dendrite,
db storage.Database,
userAPI userapi.MediaUserAPI,
client *fclient.Client,
federationClient fclient.FederationClient,
keyRing gomatrixserverlib.JSONVerifier,
) {
rateLimits := httputil.NewRateLimits(&cfg.ClientAPI.RateLimiting)
v3mux := publicAPIMux.PathPrefix("/{apiversion:(?:r0|v1|v3)}/").Subrouter()
v3mux := routers.Media.PathPrefix("/{apiversion:(?:r0|v1|v3)}/").Subrouter()
v1mux := routers.Client.PathPrefix("/v1/media/").Subrouter()
v1fedMux := routers.Federation.PathPrefix("/v1/media/").Subrouter()
activeThumbnailGeneration := &types.ActiveThumbnailGeneration{
PathToResult: map[string]*types.ThumbnailGenerationResult{},
@ -90,33 +96,103 @@ func Setup(
MXCToResult: map[string]*types.RemoteRequestResult{},
}
downloadHandler := makeDownloadAPI("download", &cfg.MediaAPI, rateLimits, db, client, activeRemoteRequests, activeThumbnailGeneration)
downloadHandler := makeDownloadAPI("download_unauthed", &cfg.MediaAPI, rateLimits, db, client, federationClient, activeRemoteRequests, activeThumbnailGeneration, false)
v3mux.Handle("/download/{serverName}/{mediaId}", downloadHandler).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/download/{serverName}/{mediaId}/{downloadName}", downloadHandler).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/thumbnail/{serverName}/{mediaId}",
makeDownloadAPI("thumbnail", &cfg.MediaAPI, rateLimits, db, client, activeRemoteRequests, activeThumbnailGeneration),
makeDownloadAPI("thumbnail_unauthed", &cfg.MediaAPI, rateLimits, db, client, federationClient, activeRemoteRequests, activeThumbnailGeneration, false),
).Methods(http.MethodGet, http.MethodOptions)
// v1 client endpoints requiring auth
downloadHandlerAuthed := httputil.MakeHTTPAPI("download", userAPI, cfg.Global.Metrics.Enabled, makeDownloadAPI("download_authed_client", &cfg.MediaAPI, rateLimits, db, client, federationClient, activeRemoteRequests, activeThumbnailGeneration, false), httputil.WithAuth())
v1mux.Handle("/config", configHandler).Methods(http.MethodGet, http.MethodOptions)
v1mux.Handle("/download/{serverName}/{mediaId}", downloadHandlerAuthed).Methods(http.MethodGet, http.MethodOptions)
v1mux.Handle("/download/{serverName}/{mediaId}/{downloadName}", downloadHandlerAuthed).Methods(http.MethodGet, http.MethodOptions)
v1mux.Handle("/thumbnail/{serverName}/{mediaId}",
httputil.MakeHTTPAPI("thumbnail", userAPI, cfg.Global.Metrics.Enabled, makeDownloadAPI("thumbnail_authed_client", &cfg.MediaAPI, rateLimits, db, client, federationClient, activeRemoteRequests, activeThumbnailGeneration, false), httputil.WithAuth()),
).Methods(http.MethodGet, http.MethodOptions)
// same, but for federation
v1fedMux.Handle("/download/{mediaId}", routing.MakeFedHTTPAPI(cfg.Global.ServerName, cfg.Global.IsLocalServerName, keyRing,
makeDownloadAPI("download_authed_federation", &cfg.MediaAPI, rateLimits, db, client, federationClient, activeRemoteRequests, activeThumbnailGeneration, true),
)).Methods(http.MethodGet, http.MethodOptions)
v1fedMux.Handle("/thumbnail/{mediaId}", routing.MakeFedHTTPAPI(cfg.Global.ServerName, cfg.Global.IsLocalServerName, keyRing,
makeDownloadAPI("thumbnail_authed_federation", &cfg.MediaAPI, rateLimits, db, client, federationClient, activeRemoteRequests, activeThumbnailGeneration, true),
)).Methods(http.MethodGet, http.MethodOptions)
}
// thumbnailCounter counts media API requests for thumbnails. It is shared by
// all thumbnail handlers built by makeDownloadAPI, which curries the "type"
// label with the request type derived from the handler name.
// NOTE(review): the "code" label is presumably the HTTP status code filled in
// by the promhttp instrumentation — confirm.
var thumbnailCounter = promauto.NewCounterVec(
prometheus.CounterOpts{
Namespace: "dendrite",
Subsystem: "mediaapi",
Name: "thumbnail",
Help: "Total number of media_api requests for thumbnails",
},
[]string{"code", "type"},
)
// thumbnailSize is a histogram used for the response sizes of thumbnail
// requests; makeDownloadAPI curries its "type" label and passes it to
// promhttp.InstrumentHandlerResponseSize. The bucket bounds are presumably in
// bytes, going by the metric name — confirm.
var thumbnailSize = promauto.NewHistogramVec(
prometheus.HistogramOpts{
Namespace: "dendrite",
Subsystem: "mediaapi",
Name: "thumbnail_size_bytes",
Help: "Total size of media_api requests for thumbnails",
Buckets: []float64{50, 100, 200, 500, 900, 1500, 3000, 6000},
},
[]string{"code", "type"},
)
// downloadCounter counts media API requests for full downloads. It is shared
// by all non-thumbnail handlers built by makeDownloadAPI, which curries the
// "type" label with the request type derived from the handler name.
var downloadCounter = promauto.NewCounterVec(
	prometheus.CounterOpts{
		Namespace: "dendrite",
		Subsystem: "mediaapi",
		Name:      "download",
		// This is a counter of requests, not a size; the size is tracked by
		// the separate download_size_bytes histogram.
		Help: "Total number of media_api requests for full downloads",
	},
	[]string{"code", "type"},
)
// downloadSize is a histogram used for the response sizes of full download
// requests; makeDownloadAPI curries its "type" label and passes it to
// promhttp.InstrumentHandlerResponseSize. The bucket bounds are presumably in
// bytes, going by the metric name — confirm.
var downloadSize = promauto.NewHistogramVec(
prometheus.HistogramOpts{
Namespace: "dendrite",
Subsystem: "mediaapi",
Name: "download_size_bytes",
Help: "Total size of media_api requests for full downloads",
Buckets: []float64{1500, 3000, 6000, 10_000, 50_000, 100_000},
},
[]string{"code", "type"},
)
func makeDownloadAPI(
name string,
cfg *config.MediaAPI,
rateLimits *httputil.RateLimits,
db storage.Database,
client *fclient.Client,
fedClient fclient.FederationClient,
activeRemoteRequests *types.ActiveRemoteRequests,
activeThumbnailGeneration *types.ActiveThumbnailGeneration,
forFederation bool,
) http.HandlerFunc {
var counterVec *prometheus.CounterVec
var sizeVec *prometheus.HistogramVec
var requestType string
if cfg.Matrix.Metrics.Enabled {
counterVec = promauto.NewCounterVec(
prometheus.CounterOpts{
Name: name,
Help: "Total number of media_api requests for either thumbnails or full downloads",
},
[]string{"code"},
)
split := strings.Split(name, "_")
// The first part of the split is either "download" or "thumbnail"
name = split[0]
// The remainder of the split is the request type, e.g. "unauthed", "authed_client" or "authed_federation".
// It is used to curry the "type" label of the metrics.
requestType = strings.Join(split[1:], "_")
counterVec = thumbnailCounter
sizeVec = thumbnailSize
if name != "thumbnail" {
counterVec = downloadCounter
sizeVec = downloadSize
}
}
httpHandler := func(w http.ResponseWriter, req *http.Request) {
req = util.RequestWithLogging(req)
@ -164,16 +240,21 @@ func makeDownloadAPI(
cfg,
db,
client,
fedClient,
activeRemoteRequests,
activeThumbnailGeneration,
name == "thumbnail",
strings.HasPrefix(name, "thumbnail"),
vars["downloadName"],
forFederation,
)
}
var handlerFunc http.HandlerFunc
if counterVec != nil {
counterVec = counterVec.MustCurryWith(prometheus.Labels{"type": requestType})
sizeVec2 := sizeVec.MustCurryWith(prometheus.Labels{"type": requestType})
handlerFunc = promhttp.InstrumentHandlerCounter(counterVec, http.HandlerFunc(httpHandler))
handlerFunc = promhttp.InstrumentHandlerResponseSize(sizeVec2, handlerFunc).ServeHTTP
} else {
handlerFunc = http.HandlerFunc(httpHandler)
}