Skip to content

Commit

Permalink
v0.4.0 batch
Browse files Browse the repository at this point in the history
- Added exponential backoff/retry functionality for all methods.
- Migrated from certificate authentication to `client_credentials` authentication flow.
  • Loading branch information
thatmattlove committed May 21, 2024
1 parent 7e4bc94 commit 8cf695b
Show file tree
Hide file tree
Showing 15 changed files with 219 additions and 123 deletions.
108 changes: 49 additions & 59 deletions auth.go
Original file line number Diff line number Diff line change
@@ -1,25 +1,18 @@
package sfdc

import (
"crypto/x509"
"encoding/pem"
"fmt"
"net/url"
"time"

"github.com/go-resty/resty/v2"
"github.com/golang-jwt/jwt/v5"
"github.com/stellaraf/go-sfdc/internal/util"
"github.com/stellaraf/go-utils/encryption"
)

const GRANT_TYPE_JWT_BEARER string = "urn:ietf:params:oauth:grant-type:jwt-bearer"

type Auth struct {
InstanceURL *url.URL
privateKey string
clientSecret string
clientID string
username string
httpClient *resty.Client
authURL *url.URL
encryption bool
Expand All @@ -28,65 +21,57 @@ type Auth struct {
setAccessTokenCallback SetTokenCallback
}

func parsePrivateKey(key []byte) (parsed any, err error) {
parsed, err = x509.ParsePKCS8PrivateKey(key)
func (auth *Auth) IntrospectToken(token string) (*TokenIntrospection, error) {
data := map[string]string{
"client_id": auth.clientID,
"client_secret": auth.clientSecret,
"token_type_hint": "access_token",
"token": token,
}
req := auth.httpClient.R().
SetFormData(data).
SetResult(&TokenIntrospection{}).
SetError(&AuthErrorResponse{})
res, err := req.Post("/services/oauth2/introspect")
if err != nil {
parsed, err = x509.ParsePKCS1PrivateKey(key)
if err != nil {
parsed, err = x509.ParseECPrivateKey(key)
if err != nil {
return
}
}
return nil, err
}
if parsed == nil {
err = fmt.Errorf("failed to parse private key")
return
if res.IsError() {
err = getSFDCError(res.Error())
return nil, err
}
return
}

func (auth *Auth) GetNewToken() (token *Token, err error) {
expiresAt := time.Now()
expiresAt = expiresAt.Add(time.Second * 300)
// SFDC requires that the audience be a single string, not an array.
jwt.MarshalSingleStringAsArray = false
claims := &jwt.RegisteredClaims{
Issuer: auth.clientID,
Subject: auth.username,
Audience: jwt.ClaimStrings{auth.authURL.String()},
ExpiresAt: jwt.NewNumericDate(expiresAt),
}
initialToken := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
block, _ := pem.Decode([]byte(auth.privateKey))
if block == nil {
err = fmt.Errorf("failed to decode private key")
return
}
rsaKey, err := parsePrivateKey(block.Bytes)
if err != nil {
return
}
assertion, err := initialToken.SignedString(rsaKey)
if err != nil {
return
intro, ok := res.Result().(*TokenIntrospection)
if !ok {
detail := string(res.Body())
m := "failed to introspect access token"
if detail != "" {
m += fmt.Sprintf(" due to error: %s", detail)
}
err = fmt.Errorf(m)
return nil, err
}
return intro, nil
}

func (auth *Auth) GetNewToken() (*Token, error) {
req := auth.httpClient.R().
SetHeader("content-type", "application/x-www-form-urlencoded").
SetQueryParam("grant_type", GRANT_TYPE_JWT_BEARER).
SetQueryParam("assertion", assertion).
SetQueryParam("grant_type", "client_credentials").
SetQueryParam("client_id", auth.clientID).
SetQueryParam("client_secret", auth.clientSecret).
SetResult(&Token{}).
SetError(&AuthErrorResponse{})

res, err := req.Post("/services/oauth2/token")
if err != nil {
return
return nil, err
}

if res.IsError() {
err = getSFDCError(res.Error())
return
return nil, err
}

token, ok := res.Result().(*Token)
if !ok {
detail := string(res.Body())
Expand All @@ -95,10 +80,17 @@ func (auth *Auth) GetNewToken() (token *Token, err error) {
m += fmt.Sprintf(" due to error: %s", detail)
}
err = fmt.Errorf(m)
return
return nil, err
}
token.ExpiresAt = expiresAt
return

intro, err := auth.IntrospectToken(token.AccessToken)
if err != nil {
return nil, err
}

token.SetExpiry(intro.Exp)

return token, nil
}

func (auth *Auth) GetAccessToken() (token string, err error) {
Expand Down Expand Up @@ -128,7 +120,7 @@ func (auth *Auth) GetAccessToken() (token string, err error) {
}

func (auth *Auth) SetAccessToken(token *Token) (err error) {
exp := time.Until(token.ExpiresAt)
exp := time.Until(token.expiresAt)
if auth.encryption {
var encrypted string
encrypted, err = encryption.Encrypt(token.AccessToken, auth.encryptionPassphrase)
Expand All @@ -151,7 +143,7 @@ func (auth *Auth) CacheNewToken(token *Token) (err error) {
}

func NewAuth(
clientID, privateKey, username, authURL string,
clientID, clientSecret, authURL string,
encryption *string,
getAccessTokenCallback CachedTokenCallback,
setAccessTokenCallback SetTokenCallback,
Expand All @@ -172,13 +164,11 @@ func NewAuth(
}
httpClient.SetHeader("user-agent", "go-sfdc")
httpClient.SetBaseURL(fmt.Sprintf("%s://%s", parsedAuthURL.Scheme, parsedAuthURL.Host))
key := util.FormatPrivateKey(privateKey)
auth = &Auth{
InstanceURL: nil,
authURL: parsedAuthURL,
username: username,
clientID: clientID,
privateKey: key,
clientSecret: clientSecret,
encryption: doEncrypt,
encryptionPassphrase: passphrase,
getAccessTokenCallback: getAccessTokenCallback,
Expand Down
6 changes: 2 additions & 4 deletions auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@ func initAuth() (auth *sfdc.Auth, err error) {
}
auth, err = sfdc.NewAuth(
env.ClientID,
env.PrivateKey,
env.AuthUsername,
env.ClientSecret,
env.AuthURL,
encryptionPassphrase,
getAccessToken,
Expand All @@ -74,8 +73,7 @@ func Test_Auth(t *testing.T) {
require.NoError(t, err)
_, err = sfdc.NewAuth(
"invalid-client-key",
env.PrivateKey,
env.AuthUsername,
env.ClientSecret,
env.AuthURL,
nil,
getAccessToken,
Expand Down
54 changes: 41 additions & 13 deletions client.go
Original file line number Diff line number Diff line change
@@ -1,46 +1,74 @@
package sfdc

import (
"time"

"github.com/cenkalti/backoff/v4"
"github.com/go-resty/resty/v2"
)

const DefaultRetryDuration time.Duration = time.Second * 10

// Salesforce Client
type Client struct {
httpClient *resty.Client
auth *Auth
timeout time.Duration
backoff backoff.BackOff
}

func (client *Client) prepare() (err error) {
func (client *Client) prepare() error {
token, err := client.auth.GetAccessToken()
if err != nil {
return
return err
}
client.httpClient.SetAuthToken(token)
return
return nil
}

// do executes a given resty request method such as Get/Post. If a timeout/backoff is specified,
// the request will be executed and retried within that timeout period.
func (client *Client) Do(doer func(u string) (*resty.Response, error), url string) (*resty.Response, error) {
op := func() (*resty.Response, error) {
return doer(url)
}
if client.timeout == 0 {
return op()
}
return backoff.RetryWithData(op, client.backoff)
}

// WithRetry specifies a time period in which to retry all requests if a errors are returned.
func (client *Client) WithRetry(timeout time.Duration) *Client {
client.timeout = timeout
client.backoff = backoff.NewExponentialBackOff(backoff.WithMaxElapsedTime(timeout))
return client
}

// Create a go-sfdc client and performs initial authentication.
func New(
clientID, privateKey, username, authURL string,
clientID, clientSecret, authURL string,
encryption *string,
getAccessTokenCallback CachedTokenCallback,
setAccessTokenCallback SetTokenCallback,
) (client *Client, err error) {
getToken CachedTokenCallback,
setToken SetTokenCallback,
) (*Client, error) {

auth, err := NewAuth(
clientID, privateKey, username, authURL,
clientID, clientSecret, authURL,
encryption,
getAccessTokenCallback,
setAccessTokenCallback,
getToken,
setToken,
)
if err != nil {
return
return nil, err
}
httpClient := resty.New()
httpClient.SetBaseURL(auth.InstanceURL.String())
client = &Client{
client := &Client{
httpClient: httpClient,
auth: auth,
timeout: DefaultRetryDuration,
backoff: backoff.NewExponentialBackOff(backoff.WithMaxElapsedTime(DefaultRetryDuration)),
}
return
return client, nil
}
2 changes: 1 addition & 1 deletion client_methods.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func (client *Client) PostToCase(caseID string, content string, feedOptions *Fee
feedOptions.Body = content
feedOptions.Type = "TextPost"
req := client.httpClient.R().SetBody(feedOptions).SetResult(&RecordCreatedResponse{})
res, err := req.Post(path)
res, err := client.Do(req.Post, path)
if err != nil {
return
}
Expand Down
4 changes: 3 additions & 1 deletion client_methods_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ func Test_PostToCase(t *testing.T) {
Status: "New",
Subject: subject,
}
newCase, _ := Client.CreateCase(caseData)
newCase, err := Client.CreateCase(caseData)
require.NoError(t, err)
require.NotNil(t, newCase)
t.Run("post plain text update", func(t *testing.T) {
t.Parallel()
postResult, err := Client.PostToCase(newCase.ID, "go-sfdc test plain text comment", nil)
Expand Down
Loading

0 comments on commit 8cf695b

Please sign in to comment.