Skip to content

Commit

Permalink
[AUTO-BACKPORT release-0.38.0] 10211: chore: fix license check (#10215)
Browse files Browse the repository at this point in the history
Co-authored-by: Bradley Laney <[email protected]>
  • Loading branch information
github-actions[bot] and stoksc authored Nov 19, 2024
1 parent 332cefc commit 9619dcf
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 59 deletions.
17 changes: 3 additions & 14 deletions master/internal/api_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,12 @@ import (
"github.com/determined-ai/determined/proto/pkg/apiv1"
)

var errAccessTokenRequiresEE = status.Error(
codes.FailedPrecondition,
"users cannot log in with an access token without a valid Enterprise Edition license set up.",
)

// PostAccessToken takes user id and optional lifespan, description and creates an
// access token for the given user.
func (a *apiServer) PostAccessToken(
ctx context.Context, req *apiv1.PostAccessTokenRequest,
) (*apiv1.PostAccessTokenResponse, error) {
if !license.IsEE() {
return nil, errAccessTokenRequiresEE
}
license.RequireLicense("access tokens")

curUser, _, err := grpcutil.GetUser(ctx)
if err != nil {
Expand Down Expand Up @@ -85,9 +78,7 @@ func (a *apiServer) PostAccessToken(
func (a *apiServer) GetAccessTokens(
ctx context.Context, req *apiv1.GetAccessTokensRequest,
) (*apiv1.GetAccessTokensResponse, error) {
if !license.IsEE() {
return nil, errAccessTokenRequiresEE
}
license.RequireLicense("access tokens")

curUser, _, err := grpcutil.GetUser(ctx)
if err != nil {
Expand Down Expand Up @@ -192,9 +183,7 @@ func (a *apiServer) GetAccessTokens(
func (a *apiServer) PatchAccessToken(
ctx context.Context, req *apiv1.PatchAccessTokenRequest,
) (*apiv1.PatchAccessTokenResponse, error) {
if !license.IsEE() {
return nil, errAccessTokenRequiresEE
}
license.RequireLicense("access tokens")

curUser, _, err := grpcutil.GetUser(ctx)
if err != nil {
Expand Down
47 changes: 2 additions & 45 deletions master/internal/license/license.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,5 @@
package license

import (
"crypto/x509"
"encoding/base64"
"encoding/pem"
"fmt"

"github.com/golang-jwt/jwt/v4"
)

const (
licenseRequiredMsg = "An enterprise license is required to use this feature"
errCheckingLicense = "error when validating license"
Expand All @@ -20,42 +11,8 @@ var licenseKey string
// publicKey stores the public key used to verify licenses. Defaults to empty.
var publicKey string

// decodedLicense contains the body of a decoded licenseKey.
type decodedLicense struct {
jwt.RegisteredClaims

LicenseVersion string `json:"licenseVersion"`
}

// RequireLicense panics if no licenseKey or an invalid licenseKey is used.
func RequireLicense(resource string) {
if publicKey == "" || licenseKey == "" {
// TODO: get better messaging for this
panic(fmt.Sprintf("%s: %s", licenseRequiredMsg, resource))
}
var claims decodedLicense
_, err := jwt.ParseWithClaims(licenseKey, &claims, func(token *jwt.Token) (interface{}, error) {
pemData, err := base64.StdEncoding.DecodeString(publicKey)
if err != nil {
return nil, err
}
blk, _ := pem.Decode(pemData)
if blk == nil {
return nil, fmt.Errorf("error decoding pem")
}
key, err := x509.ParsePKIXPublicKey(blk.Bytes)
if err != nil {
return nil, fmt.Errorf("error parsing public key: %w", err)
}
return key, nil
})
if err != nil {
panic(fmt.Sprintf("%s: %s", errCheckingLicense, err.Error()))
}
if claims.LicenseVersion != "1" {
panic("Specified licenseKey version is incompatible")
}
}
// RequireLicense is a no-op.
func RequireLicense(resource string) {}

// IsEE returns true if a license is detected.
func IsEE() bool {
Expand Down

0 comments on commit 9619dcf

Please sign in to comment.