Skip to content

Commit

Permalink
JWT Token example refactored with the latest github.com/cristalhq/jwt…
Browse files Browse the repository at this point in the history
…/v5. (#351)

Co-authored-by: Brian Royer <[email protected]>
  • Loading branch information
shyce and Brian Royer authored Jan 21, 2024
1 parent 6cd1903 commit 72c3b87
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 45 deletions.
4 changes: 2 additions & 2 deletions _examples/jwt_token/jwt/token_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@ package jwt
import (
"time"

"github.com/cristalhq/jwt/v3"
"github.com/cristalhq/jwt/v5"
)

func BuildUserToken(secret string, userID string, expireAt int64) (string, error) {
key := []byte(secret)
signer, _ := jwt.NewSignerHS(jwt.HS256, key)
builder := jwt.NewBuilder(signer)
claims := &connectTokenClaims{
StandardClaims: jwt.StandardClaims{
RegisteredClaims: jwt.RegisteredClaims{
Subject: userID,
},
}
Expand Down
110 changes: 75 additions & 35 deletions _examples/jwt_token/jwt/token_verifier_jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
"sync"
"time"

"github.com/cristalhq/jwt/v3"
"github.com/cristalhq/jwt/v5"
)

type ConnectToken struct {
Expand Down Expand Up @@ -83,7 +83,7 @@ type connectTokenClaims struct {
Info json.RawMessage `json:"info,omitempty"`
Base64Info string `json:"b64info,omitempty"`
Channels []string `json:"channels,omitempty"`
jwt.StandardClaims
jwt.RegisteredClaims
}

type subscribeTokenClaims struct {
Expand All @@ -92,7 +92,7 @@ type subscribeTokenClaims struct {
Info json.RawMessage `json:"info,omitempty"`
Base64Info string `json:"b64info,omitempty"`
ExpireTokenOnly bool `json:"eto,omitempty"`
jwt.StandardClaims
jwt.RegisteredClaims
}

type algorithms struct {
Expand Down Expand Up @@ -148,60 +148,67 @@ func newAlgorithms(tokenHMACSecretKey string, pubKey *rsa.PublicKey) (*algorithm
return alg, nil
}

func (s *algorithms) verify(token *jwt.Token) error {
var verifier jwt.Verifier
func (verifier *TokenVerifier) verifySignature(token *jwt.Token) error {
verifier.mu.RLock()
defer verifier.mu.RUnlock()

var verifierFunc jwt.Verifier
switch token.Header().Algorithm {
case jwt.HS256:
verifier = s.HS256
verifierFunc = verifier.algorithms.HS256
case jwt.HS384:
verifier = s.HS384
verifierFunc = verifier.algorithms.HS384
case jwt.HS512:
verifier = s.HS512
verifierFunc = verifier.algorithms.HS512
case jwt.RS256:
verifier = s.RS256
verifierFunc = verifier.algorithms.RS256
case jwt.RS384:
verifier = s.RS384
verifierFunc = verifier.algorithms.RS384
case jwt.RS512:
verifier = s.RS512
verifierFunc = verifier.algorithms.RS512
default:
return fmt.Errorf("%w: %s", errUnsupportedAlgorithm, string(token.Header().Algorithm))
}
if verifier == nil {

if verifierFunc == nil {
return fmt.Errorf("%w: %s", errDisabledAlgorithm, string(token.Header().Algorithm))
}
return verifier.Verify(token.Payload(), token.Signature())
}

func (verifier *TokenVerifier) verifySignature(token *jwt.Token) error {
verifier.mu.RLock()
defer verifier.mu.RUnlock()
return verifier.algorithms.verify(token)
return verifierFunc.Verify(token)
}

func (verifier *TokenVerifier) VerifyConnectToken(t string) (ConnectToken, error) {
token, err := jwt.Parse([]byte(t))
token, err := jwt.ParseNoVerify([]byte(t))
if err != nil {
return ConnectToken{}, err
return ConnectToken{}, fmt.Errorf("error parsing connect token: %w", err)
}

err = verifier.verifySignature(token)
if err != nil {
return ConnectToken{}, err
return ConnectToken{}, fmt.Errorf("error verifying connect token signature: %w", err)
}

token, err = jwt.Parse([]byte(t), verifier.selectVerifier(token.Header().Algorithm))
if err != nil {
return ConnectToken{}, fmt.Errorf("error verifying connect token: %w", err)
}

claims := &connectTokenClaims{}
err = json.Unmarshal(token.RawClaims(), claims)
err = json.Unmarshal(token.Claims(), claims)
if err != nil {
return ConnectToken{}, err
return ConnectToken{}, fmt.Errorf("error unmarshalling connect token claims: %w", err)
}

now := time.Now()
if !claims.IsValidExpiresAt(now) || !claims.IsValidNotBefore(now) {
if !claims.IsValidExpiresAt(now) {
return ConnectToken{}, ErrTokenExpired
}
if !claims.IsValidNotBefore(now) {
return ConnectToken{}, errors.New("token not valid yet")
}

ct := ConnectToken{
UserID: claims.StandardClaims.Subject,
UserID: claims.RegisteredClaims.Subject,
Info: claims.Info,
Channels: claims.Channels,
}
Expand All @@ -211,15 +218,37 @@ func (verifier *TokenVerifier) VerifyConnectToken(t string) (ConnectToken, error
if claims.Base64Info != "" {
byteInfo, err := base64.StdEncoding.DecodeString(claims.Base64Info)
if err != nil {
return ConnectToken{}, err
return ConnectToken{}, fmt.Errorf("error decoding base64 info in connect token: %w", err)
}
ct.Info = byteInfo
}
return ct, nil
}

func (verifier *TokenVerifier) selectVerifier(alg jwt.Algorithm) jwt.Verifier {
verifier.mu.RLock()
defer verifier.mu.RUnlock()

switch alg {
case jwt.HS256:
return verifier.algorithms.HS256
case jwt.HS384:
return verifier.algorithms.HS384
case jwt.HS512:
return verifier.algorithms.HS512
case jwt.RS256:
return verifier.algorithms.RS256
case jwt.RS384:
return verifier.algorithms.RS384
case jwt.RS512:
return verifier.algorithms.RS512
default:
return nil
}
}

func (verifier *TokenVerifier) VerifySubscribeToken(t string) (SubscribeToken, error) {
token, err := jwt.Parse([]byte(t))
token, err := jwt.ParseNoVerify([]byte(t))
if err != nil {
return SubscribeToken{}, err
}
Expand All @@ -229,33 +258,44 @@ func (verifier *TokenVerifier) VerifySubscribeToken(t string) (SubscribeToken, e
return SubscribeToken{}, err
}

token, err = jwt.Parse([]byte(t), verifier.selectVerifier(token.Header().Algorithm))
if err != nil {
return SubscribeToken{}, fmt.Errorf("error verifying subscribe token: %w", err)
}

claims := &subscribeTokenClaims{}
err = json.Unmarshal(token.RawClaims(), claims)
err = json.Unmarshal(token.Claims(), claims)
if err != nil {
return SubscribeToken{}, err
}

now := time.Now()
if !claims.IsValidExpiresAt(now) || !claims.IsValidNotBefore(now) {
if !claims.IsValidExpiresAt(now) {
return SubscribeToken{}, ErrTokenExpired
}
if !claims.IsValidNotBefore(now) {
return SubscribeToken{}, errors.New("token not valid yet")
}

st := SubscribeToken{
Client: claims.Client,
Info: claims.Info,
Channel: claims.Channel,
ExpireAt: claims.ExpiresAt.Unix(),
ExpireTokenOnly: claims.ExpireTokenOnly,
}
if claims.ExpiresAt != nil {
st.ExpireAt = claims.ExpiresAt.Unix()
}
if claims.Base64Info != "" {

// Decode the Info field if it's present
if len(claims.Info) > 0 {
st.Info = claims.Info
} else if claims.Base64Info != "" {
// If Info is not present, but Base64Info is, decode it
byteInfo, err := base64.StdEncoding.DecodeString(claims.Base64Info)
if err != nil {
return SubscribeToken{}, err
return SubscribeToken{}, fmt.Errorf("error decoding base64 info in subscribe token: %w", err)
}
st.Info = byteInfo
}

return st, nil
}

Expand Down
16 changes: 8 additions & 8 deletions _examples/jwt_token/jwt/token_verifier_jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"testing"
"time"

"github.com/cristalhq/jwt/v3"
"github.com/cristalhq/jwt/v5"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -48,7 +48,7 @@ func getConnToken(user string, exp int64, rsaPrivateKey *rsa.PrivateKey) string
builder := getTokenBuilder(rsaPrivateKey)
claims := &connectTokenClaims{
Base64Info: "e30=",
StandardClaims: jwt.StandardClaims{
RegisteredClaims: jwt.RegisteredClaims{
Subject: user,
},
}
Expand All @@ -59,16 +59,16 @@ func getConnToken(user string, exp int64, rsaPrivateKey *rsa.PrivateKey) string
if err != nil {
panic(err)
}
return string(token.Raw())
return string(token.Bytes())
}

func getSubscribeToken(channel string, client string, exp int64, rsaPrivateKey *rsa.PrivateKey) string {
builder := getTokenBuilder(rsaPrivateKey)
claims := &subscribeTokenClaims{
Base64Info: "e30=",
Channel: channel,
Client: client,
StandardClaims: jwt.StandardClaims{},
Base64Info: "e30=",
Channel: channel,
Client: client,
RegisteredClaims: jwt.RegisteredClaims{},
}
if exp > 0 {
claims.ExpiresAt = jwt.NewNumericDate(time.Unix(exp, 0))
Expand All @@ -77,7 +77,7 @@ func getSubscribeToken(channel string, client string, exp int64, rsaPrivateKey *
if err != nil {
panic(err)
}
return string(token.Raw())
return string(token.Bytes())
}

func Test_tokenVerifierJWT_Signer(t *testing.T) {
Expand Down

0 comments on commit 72c3b87

Please sign in to comment.