Skip to content

Commit

Permalink
Add support for MSC3916
Browse files Browse the repository at this point in the history
  • Loading branch information
turt2live committed Nov 25, 2023
1 parent b2f3c30 commit 16628b1
Show file tree
Hide file tree
Showing 15 changed files with 685 additions and 17 deletions.
4 changes: 4 additions & 0 deletions api/_apimeta/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ type UserInfo struct {
IsShared bool
}

type ServerInfo struct {
ServerName string
}

func GetRequestUserAdminStatus(r *http.Request, rctx rcontext.RequestContext, user UserInfo) (bool, bool) {
isGlobalAdmin := util.IsGlobalAdmin(user.UserId) || user.IsShared
isLocalAdmin, err := matrix.IsUserAdmin(rctx, r.Host, user.AccessToken, r.RemoteAddr)
Expand Down
2 changes: 1 addition & 1 deletion api/_auth_cache/auth_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
"github.com/turt2live/matrix-media-repo/matrix"
)

var tokenCache = cache.New(0*time.Second, 30*time.Second)
var tokenCache = cache.New(cache.NoExpiration, 30*time.Second)
var rwLock = &sync.RWMutex{}
var regexCache = make(map[string]*regexp.Regexp)

Expand Down
30 changes: 30 additions & 0 deletions api/_routers/97-require-server-auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package _routers

import (
"net/http"

"github.com/turt2live/matrix-media-repo/api/_apimeta"
"github.com/turt2live/matrix-media-repo/api/_responses"
"github.com/turt2live/matrix-media-repo/common"
"github.com/turt2live/matrix-media-repo/common/rcontext"
"github.com/turt2live/matrix-media-repo/matrix"
)

type GeneratorWithServerFn = func(r *http.Request, ctx rcontext.RequestContext, server _apimeta.ServerInfo) interface{}

func RequireServerAuth(generator GeneratorWithServerFn) GeneratorFn {
return func(r *http.Request, ctx rcontext.RequestContext) interface{} {
serverName, err := matrix.ValidateXMatrixAuth(r, true)
if err != nil {
ctx.Log.Debug("Error with X-Matrix auth: ", err)
return &_responses.ErrorResponse{
Code: common.ErrCodeForbidden,
Message: "no auth provided (required)",
InternalCode: common.ErrCodeMissingToken,
}
}
return generator(r, ctx, _apimeta.ServerInfo{
ServerName: serverName,
})
}
}
32 changes: 18 additions & 14 deletions api/_routers/98-use-rcontext.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,20 +88,24 @@ func (c *RContextRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
beforeParseDownload:
log.Infof("Replying with result: %T %+v", res, res)
if downloadRes, isDownload := res.(*_responses.DownloadResponse); isDownload {
ranges, err := http_range.ParseRange(r.Header.Get("Range"), downloadRes.SizeBytes, rctx.Config.Downloads.DefaultRangeChunkSizeBytes)
if errors.Is(err, http_range.ErrInvalid) {
proposedStatusCode = http.StatusRequestedRangeNotSatisfiable
res = _responses.BadRequest("invalid range header")
goto beforeParseDownload // reprocess `res`
} else if errors.Is(err, http_range.ErrNoOverlap) {
proposedStatusCode = http.StatusRequestedRangeNotSatisfiable
res = _responses.BadRequest("out of range")
goto beforeParseDownload // reprocess `res`
}
if len(ranges) > 1 {
proposedStatusCode = http.StatusRequestedRangeNotSatisfiable
res = _responses.BadRequest("only 1 range is supported")
goto beforeParseDownload // reprocess `res`
var ranges []http_range.Range
var err error
if downloadRes.SizeBytes > 0 {
ranges, err = http_range.ParseRange(r.Header.Get("Range"), downloadRes.SizeBytes, rctx.Config.Downloads.DefaultRangeChunkSizeBytes)
if errors.Is(err, http_range.ErrInvalid) {
proposedStatusCode = http.StatusRequestedRangeNotSatisfiable
res = _responses.BadRequest("invalid range header")
goto beforeParseDownload // reprocess `res`
} else if errors.Is(err, http_range.ErrNoOverlap) {
proposedStatusCode = http.StatusRequestedRangeNotSatisfiable
res = _responses.BadRequest("out of range")
goto beforeParseDownload // reprocess `res`
}
if len(ranges) > 1 {
proposedStatusCode = http.StatusRequestedRangeNotSatisfiable
res = _responses.BadRequest("only 1 range is supported")
goto beforeParseDownload // reprocess `res`
}
}

contentType = "application/octet-stream"
Expand Down
3 changes: 3 additions & 0 deletions api/custom/federation.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ func GetFederationInfo(r *http.Request, rctx rcontext.RequestContext, user _apim

versionUrl := url + "/_matrix/federation/v1/version"
versionResponse, err := matrix.FederatedGet(versionUrl, hostname, rctx)
if versionResponse != nil {
defer versionResponse.Body.Close()
}
if err != nil {
rctx.Log.Error(err)
sentry.CaptureException(err)
Expand Down
18 changes: 16 additions & 2 deletions api/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (

const PrefixMedia = "/_matrix/media"
const PrefixClient = "/_matrix/client"
const PrefixFederation = "/_matrix/federation"

func buildRoutes() http.Handler {
counter := &_routers.RequestCounter{}
Expand All @@ -36,13 +37,25 @@ func buildRoutes() http.Handler {
register([]string{"GET"}, PrefixMedia, "download/:server/:mediaId/:filename", mxSpecV3Transition, router, downloadRoute)
register([]string{"GET"}, PrefixMedia, "download/:server/:mediaId", mxSpecV3Transition, router, downloadRoute)
register([]string{"GET"}, PrefixMedia, "thumbnail/:server/:mediaId", mxSpecV3Transition, router, makeRoute(_routers.OptionalAccessToken(r0.ThumbnailMedia), "thumbnail", counter))
register([]string{"GET"}, PrefixMedia, "preview_url", mxSpecV3TransitionCS, router, makeRoute(_routers.RequireAccessToken(r0.PreviewUrl), "url_preview", counter))
previewUrlRoute := makeRoute(_routers.RequireAccessToken(r0.PreviewUrl), "url_preview", counter)
register([]string{"GET"}, PrefixMedia, "preview_url", mxSpecV3TransitionCS, router, previewUrlRoute)
register([]string{"GET"}, PrefixMedia, "identicon/*seed", mxR0, router, makeRoute(_routers.OptionalAccessToken(r0.Identicon), "identicon", counter))
register([]string{"GET"}, PrefixMedia, "config", mxSpecV3TransitionCS, router, makeRoute(_routers.RequireAccessToken(r0.PublicConfig), "config", counter))
configRoute := makeRoute(_routers.RequireAccessToken(r0.PublicConfig), "config", counter)
register([]string{"GET"}, PrefixMedia, "config", mxSpecV3TransitionCS, router, configRoute)
register([]string{"POST"}, PrefixClient, "logout", mxSpecV3TransitionCS, router, makeRoute(_routers.RequireAccessToken(r0.Logout), "logout", counter))
register([]string{"POST"}, PrefixClient, "logout/all", mxSpecV3TransitionCS, router, makeRoute(_routers.RequireAccessToken(r0.LogoutAll), "logout_all", counter))
register([]string{"POST"}, PrefixMedia, "create", mxV1, router, makeRoute(_routers.RequireAccessToken(v1.CreateMedia), "create", counter))

// MSC3916 - Authentication & endpoint API separation
register([]string{"GET"}, PrefixClient, "media/preview_url", msc3916, router, previewUrlRoute)
register([]string{"GET"}, PrefixClient, "media/config", msc3916, router, configRoute)
authedDownloadRoute := makeRoute(_routers.RequireAccessToken(unstable.ClientDownloadMedia), "download", counter)
register([]string{"GET"}, PrefixClient, "media/download/:server/:mediaId/:filename", msc3916, router, authedDownloadRoute)
register([]string{"GET"}, PrefixClient, "media/download/:server/:mediaId", msc3916, router, authedDownloadRoute)
register([]string{"GET"}, PrefixClient, "media/thumbnail/:server/:mediaId", msc3916, router, makeRoute(_routers.RequireAccessToken(r0.ThumbnailMedia), "thumbnail", counter))
register([]string{"GET"}, PrefixFederation, "media/download/:server/:mediaId", msc3916, router, makeRoute(_routers.RequireServerAuth(unstable.FederationDownloadMedia), "download", counter))
register([]string{"GET"}, PrefixFederation, "media/thumbnail/:server/:mediaId", msc3916, router, makeRoute(_routers.RequireServerAuth(unstable.FederationThumbnailMedia), "thumbnail", counter))

// Custom features
register([]string{"GET"}, PrefixMedia, "local_copy/:server/:mediaId", mxUnstable, router, makeRoute(_routers.RequireAccessToken(unstable.LocalCopy), "local_copy", counter))
register([]string{"GET"}, PrefixMedia, "info/:server/:mediaId", mxUnstable, router, makeRoute(_routers.RequireAccessToken(unstable.MediaInfo), "info", counter))
Expand Down Expand Up @@ -129,6 +142,7 @@ var (
//mxAllSpec matrixVersions = []string{"r0", "v1", "v3", "unstable", "unstable/io.t2bot.media" /* and MSC routes */}
mxUnstable matrixVersions = []string{"unstable", "unstable/io.t2bot.media"}
msc4034 matrixVersions = []string{"unstable/org.matrix.msc4034"}
msc3916 matrixVersions = []string{"unstable/org.matrix.msc3916"}
mxSpecV3Transition matrixVersions = []string{"r0", "v1", "v3"}
mxSpecV3TransitionCS matrixVersions = []string{"r0", "v3"}
mxR0 matrixVersions = []string{"r0"}
Expand Down
37 changes: 37 additions & 0 deletions api/unstable/msc3916_download.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package unstable

import (
"bytes"
"net/http"

"github.com/turt2live/matrix-media-repo/api/_apimeta"
"github.com/turt2live/matrix-media-repo/api/_responses"
"github.com/turt2live/matrix-media-repo/api/r0"
"github.com/turt2live/matrix-media-repo/common/rcontext"
"github.com/turt2live/matrix-media-repo/util/readers"
)

func ClientDownloadMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} {
r.URL.Query().Set("allow_remote", "true")
return r0.DownloadMedia(r, rctx, user)
}

func FederationDownloadMedia(r *http.Request, rctx rcontext.RequestContext, server _apimeta.ServerInfo) interface{} {
r.URL.Query().Set("allow_remote", "false")

res := r0.DownloadMedia(r, rctx, _apimeta.UserInfo{})
if dl, ok := res.(*_responses.DownloadResponse); ok {
return &_responses.DownloadResponse{
ContentType: "multipart/mixed",
Filename: "",
SizeBytes: 0,
Data: readers.NewMultipartReader(
&readers.MultipartPart{ContentType: "application/json", Reader: readers.MakeCloser(bytes.NewReader([]byte("{}")))},
&readers.MultipartPart{ContentType: dl.ContentType, FileName: dl.Filename, Reader: dl.Data},
),
TargetDisposition: "attachment",
}
} else {
return res
}
}
32 changes: 32 additions & 0 deletions api/unstable/msc3916_thumbnail.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package unstable

import (
"bytes"
"net/http"

"github.com/turt2live/matrix-media-repo/api/_apimeta"
"github.com/turt2live/matrix-media-repo/api/_responses"
"github.com/turt2live/matrix-media-repo/api/r0"
"github.com/turt2live/matrix-media-repo/common/rcontext"
"github.com/turt2live/matrix-media-repo/util/readers"
)

func FederationThumbnailMedia(r *http.Request, rctx rcontext.RequestContext, server _apimeta.ServerInfo) interface{} {
r.URL.Query().Set("allow_remote", "false")

res := r0.ThumbnailMedia(r, rctx, _apimeta.UserInfo{})
if dl, ok := res.(*_responses.DownloadResponse); ok {
return &_responses.DownloadResponse{
ContentType: "multipart/mixed",
Filename: "",
SizeBytes: 0,
Data: readers.NewMultipartReader(
&readers.MultipartPart{ContentType: "application/json", Reader: readers.MakeCloser(bytes.NewReader([]byte("{}")))},
&readers.MultipartPart{ContentType: dl.ContentType, FileName: dl.Filename, Reader: dl.Data},
),
TargetDisposition: "attachment",
}
} else {
return res
}
}
150 changes: 150 additions & 0 deletions matrix/requests_signing.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
package matrix

import (
"crypto/ed25519"
"encoding/json"
"errors"
"fmt"
"sync"
"time"

"github.com/patrickmn/go-cache"
"github.com/sirupsen/logrus"
"github.com/t2bot/go-typed-singleflight"
"github.com/turt2live/matrix-media-repo/common/rcontext"
"github.com/turt2live/matrix-media-repo/database"
"github.com/turt2live/matrix-media-repo/util"
)

type signingKey struct {
Key string `json:"key"`
}

type serverKeyResult struct {
ServerName string `json:"server_name"`
ValidUntilTs int64 `json:"valid_until_ts"`
VerifyKeys map[string]signingKey `json:"verify_keys"` // unpadded base64
OldVerifyKeys map[string]signingKey `json:"old_verify_keys"` // unpadded base64
Signatures map[string]map[string]string `json:"signatures"` // unpadded base64; <name, <keyId, sig>>
}

type ServerSigningKeys map[string]ed25519.PublicKey

var signingKeySf = new(typedsf.Group[*ServerSigningKeys])
var signingKeyCache = cache.New(cache.NoExpiration, 30*time.Second)
var signingKeyRWLock = new(sync.RWMutex)

func querySigningKeyCache(serverName string) *ServerSigningKeys {
if val, ok := signingKeyCache.Get(serverName); ok {
return val.(*ServerSigningKeys)
}
return nil
}

func QuerySigningKeys(serverName string) (*ServerSigningKeys, error) {
signingKeyRWLock.RLock()
keys := querySigningKeyCache(serverName)
signingKeyRWLock.RUnlock()
if keys != nil {
return keys, nil
}

keys, err, _ := signingKeySf.Do(serverName, func() (*ServerSigningKeys, error) {
ctx := rcontext.Initial().LogWithFields(logrus.Fields{
"keysForServer": serverName,
})

signingKeyRWLock.Lock()
defer signingKeyRWLock.Unlock()

// check cache once more, just in case the locks overlapped
cachedKeys := querySigningKeyCache(serverName)
if keys != nil {
return cachedKeys, nil
}

// now we can try to get the keys from the source
url, hostname, err := GetServerApiUrl(serverName)
if err != nil {
return nil, err
}

keysUrl := url + "/_matrix/key/v2/server"
keysResponse, err := FederatedGet(keysUrl, hostname, ctx)
if keysResponse != nil {
defer keysResponse.Body.Close()
}
if err != nil {
return nil, err
}

decoder := json.NewDecoder(keysResponse.Body)
raw := database.AnonymousJson{}
if err = decoder.Decode(&raw); err != nil {
return nil, err
}
keyInfo := new(serverKeyResult)
if err = raw.ApplyTo(keyInfo); err != nil {
return nil, err
}

// Check validity before we go much further
if keyInfo.ServerName != serverName {
return nil, fmt.Errorf("got keys for '%s' but expected '%s'", keyInfo.ServerName, serverName)
}
if keyInfo.ValidUntilTs <= util.NowMillis() {
return nil, errors.New("returned server keys are expired")
}
cacheUntil := time.Until(time.UnixMilli(keyInfo.ValidUntilTs)) / 2
if cacheUntil <= (6 * time.Second) {
return nil, errors.New("returned server keys would expire too quickly")
}

// Convert to something useful
serverKeys := make(ServerSigningKeys)
for keyId, keyObj := range keyInfo.VerifyKeys {
b, err := util.DecodeUnpaddedBase64String(keyObj.Key)
if err != nil {
return nil, errors.Join(fmt.Errorf("bad base64 for key ID '%s' for '%s'", keyId, serverName), err)
}

serverKeys[keyId] = b
}

// Check signatures
if len(keyInfo.Signatures) == 0 || len(keyInfo.Signatures[serverName]) == 0 {
return nil, fmt.Errorf("missing signatures from '%s'", serverName)
}
delete(raw, "signatures")
canonical, err := util.EncodeCanonicalJson(raw)
if err != nil {
return nil, err
}
for domain, sig := range keyInfo.Signatures {
if domain != serverName {
return nil, fmt.Errorf("unexpected signature from '%s' (expected '%s')", domain, serverName)
}

for keyId, b64 := range sig {
signatureBytes, err := util.DecodeUnpaddedBase64String(b64)
if err != nil {
return nil, errors.Join(fmt.Errorf("bad base64 signature for key ID '%s' for '%s'", keyId, serverName), err)
}

key, ok := serverKeys[keyId]
if !ok {
return nil, fmt.Errorf("unknown key ID '%s' for signature from '%s'", keyId, serverName)
}

if !ed25519.Verify(key, canonical, signatureBytes) {
return nil, fmt.Errorf("invalid signature '%s' from key ID '%s' for '%s'", b64, keyId, serverName)
}
}
}

// Cache & return (unlock was deferred)
signingKeyCache.Set(serverName, &serverKeys, cacheUntil)
return &serverKeys, nil
})
return keys, err
}
Loading

0 comments on commit 16628b1

Please sign in to comment.