Skip to content

Commit

Permalink
chore(sdk): consolidate http calls in one function (#850)
Browse files Browse the repository at this point in the history
Signed-off-by: Volodymyr Kit <[email protected]>
  • Loading branch information
justakit authored Dec 9, 2024
1 parent 28e85e1 commit 285627d
Show file tree
Hide file tree
Showing 8 changed files with 100 additions and 138 deletions.
45 changes: 38 additions & 7 deletions pkg/internal/httprequest/httprequest.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"fmt"
"io"
"net/http"
"slices"
"time"

"github.com/trustbloc/wallet-sdk/pkg/api"
Expand Down Expand Up @@ -42,11 +43,28 @@ func New(httpClient httpClient, metricsLogger api.MetricsLogger) *Request {
func (r *Request) Do(method, endpointURL, contentType string, body io.Reader,
event, parentEvent string, errorResponseHandler func(statusCode int, responseBody []byte) error,
) ([]byte, error) {
req, err := http.NewRequestWithContext(context.Background(), method, endpointURL, body)
return r.DoContext(context.Background(), method, endpointURL, contentType,
nil, body, event, parentEvent, nil, errorResponseHandler)
}

var defaultAcceptableStatuses = []int{http.StatusOK}

// DoContext is the same as Do, but also accept context and headers.
func (r *Request) DoContext(ctx context.Context, method, endpointURL, contentType string,
additionalHeaders http.Header, body io.Reader, event, parentEvent string, acceptableStatuses []int,
errorResponseHandler func(statusCode int, responseBody []byte) error,
) ([]byte, error) {
req, err := http.NewRequestWithContext(ctx, method, endpointURL, body)
if err != nil {
return nil, err
}

for header, values := range additionalHeaders {
for _, value := range values {
req.Header.Add(header, value)
}
}

if contentType != "" {
req.Header.Add("Content-Type", contentType)
}
Expand Down Expand Up @@ -79,9 +97,14 @@ func (r *Request) Do(method, endpointURL, contentType string, body io.Reader,
return nil, err
}

if resp.StatusCode != http.StatusOK {
statuses := acceptableStatuses
if statuses == nil {
statuses = defaultAcceptableStatuses
}

if !slices.Contains(statuses, resp.StatusCode) {
if errorResponseHandler == nil {
errorResponseHandler = genericErrorResponseHandler
errorResponseHandler = genericErrorResponseHandler(statuses)
}

return nil, errorResponseHandler(resp.StatusCode, respBytes)
Expand All @@ -106,8 +129,16 @@ func (r *Request) DoAndParse(method, endpointURL, contentType string, body io.Re
return json.Unmarshal(respBytes, response)
}

func genericErrorResponseHandler(statusCode int, respBytes []byte) error {
return fmt.Errorf(
"expected status code %d but got status code %d with response body %s instead",
http.StatusOK, statusCode, respBytes)
func genericErrorResponseHandler(expectedStatusCodes []int) func(statusCode int, respBytes []byte) error {
return func(statusCode int, respBytes []byte) error {
if len(expectedStatusCodes) == 1 {
return fmt.Errorf(
"expected status code %d but got status code %d with response body %s instead",
expectedStatusCodes[0], statusCode, respBytes)
}

return fmt.Errorf(
"expected status codes %v but got status code %d with response body %s instead",
expectedStatusCodes, statusCode, respBytes)
}
}
47 changes: 15 additions & 32 deletions pkg/oauth2/clientregistration.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,18 @@ package oauth2

import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"

"github.com/trustbloc/wallet-sdk/pkg/internal/httprequest"
)

const (
newRegisterClientEventText = "Register client"
fetchRequestObjectEventText = "Fetch request object via an HTTP GET request to %s"
)

// RegisterClient registers a new client at the given registration endpoint.
Expand Down Expand Up @@ -55,39 +62,15 @@ func RegisterClient(registrationEndpoint string, clientMetadata *ClientMetadata,
}

func getRawResponse(requestBytes []byte, registrationEndpoint string, opts *opts) ([]byte, error) {
httpReq, err := http.NewRequest( //nolint: noctx // Timeout expected to be set in HTTP client already
http.MethodPost, registrationEndpoint, bytes.NewReader(requestBytes))
if err != nil {
return nil, err
}

httpReq.Header.Set("Content-Type", "application/json")

headers := http.Header{}
if opts.initialAccessBearerToken != "" {
httpReq.Header.Set("Authorization", "Bearer "+opts.initialAccessBearerToken)
}

resp, err := opts.httpClient.Do(httpReq)
if err != nil {
return nil, err
headers.Set("Authorization", "Bearer "+opts.initialAccessBearerToken)
}

defer func() {
errClose := resp.Body.Close()
if errClose != nil {
println(fmt.Sprintf("failed to close response body: %s", errClose.Error()))
}
}()

respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}

if resp.StatusCode != http.StatusCreated {
return nil, fmt.Errorf("server returned status code %d with body [%s]", resp.StatusCode,
string(respBody))
}
metricsEvent := fmt.Sprintf(fetchRequestObjectEventText, registrationEndpoint)

return respBody, nil
return httprequest.New(opts.httpClient, opts.metricsLogger).DoContext(context.TODO(),
http.MethodPost, registrationEndpoint, "application/json", headers,
bytes.NewReader(requestBytes), metricsEvent, newRegisterClientEventText,
[]int{http.StatusCreated}, nil)
}
2 changes: 1 addition & 1 deletion pkg/oauth2/clientregistration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func TestRegisterClient(t *testing.T) {
defer server.Close()

response, err := oauth2.RegisterClient(server.URL, nil)
require.EqualError(t, err, "server returned status code 500 with body []")
require.ErrorContains(t, err, "expected status code 201 but got status code 500 with response body instead")
require.Nil(t, response)
})
t.Run("Server returns empty body, resulting in a JSON unmarshal failure", func(t *testing.T) {
Expand Down
15 changes: 15 additions & 0 deletions pkg/oauth2/opts.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@ import (
"net/http"

"github.com/trustbloc/wallet-sdk/pkg/api"
"github.com/trustbloc/wallet-sdk/pkg/metricslogger/noop"
)

type opts struct {
initialAccessBearerToken string
httpClient *http.Client
metricsLogger api.MetricsLogger
}

// An Opt is a single option for a call to RegisterClient.
Expand All @@ -29,6 +31,15 @@ func WithHTTPClient(httpClient *http.Client) Opt {
}
}

// WithMetricsLogger is an option for a call to RegisterClient that allows a caller to specify their MetricsLogger.
// If used, then performance metrics events will be pushed to the given MetricsLogger implementation.
// If this option is not used, then metrics logging will be disabled.
func WithMetricsLogger(metricsLogger api.MetricsLogger) Opt {
return func(opts *opts) {
opts.metricsLogger = metricsLogger
}
}

func processOpts(options []Opt) *opts {
opts := mergeOpts(options)

Expand All @@ -48,5 +59,9 @@ func mergeOpts(options []Opt) *opts {
}
}

if resolveOpts.metricsLogger == nil {
resolveOpts.metricsLogger = noop.NewMetricsLogger()
}

return resolveOpts
}
73 changes: 11 additions & 62 deletions pkg/openid4ci/interaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"
Expand All @@ -33,6 +32,7 @@ import (
"github.com/trustbloc/wallet-sdk/pkg/common"
diderrors "github.com/trustbloc/wallet-sdk/pkg/did"
"github.com/trustbloc/wallet-sdk/pkg/did/wellknown"
"github.com/trustbloc/wallet-sdk/pkg/internal/httprequest"
metadatafetcher "github.com/trustbloc/wallet-sdk/pkg/internal/issuermetadata"
"github.com/trustbloc/wallet-sdk/pkg/models/issuer"
"github.com/trustbloc/wallet-sdk/pkg/walleterror"
Expand Down Expand Up @@ -380,19 +380,21 @@ func (i *interaction) getCredentialResponse(signer api.JWTSigner, nonce any,
oAuthHTTPClient := createOAuthHTTPClient(i.oAuth2Config, i.authToken, i.httpClient)

for index := range credentialTypes {
request, err := i.createCredentialRequestWithoutAccessToken(proofJWT, credentialFormats[index],
requestBody, err := i.createCredentialRequestBody(proofJWT, credentialFormats[index],
credentialTypes[index], credentialContexts[index])
if err != nil {
return nil, err
}

// The access token header will be injected automatically by the OAuth HTTP client, so there's no need to
// explicitly set it on the request object generated by the method call above.

fetchCredentialResponseEventText := fmt.Sprintf(fetchCredentialViaGETReqEventText, index+1,
len(credentialTypes), i.issuerMetadata.CredentialEndpoint)

responseBytes, err := i.getRawCredentialResponse(request, fetchCredentialResponseEventText, oAuthHTTPClient)
// The access token header will be injected automatically by the OAuth HTTP client, so there's no need to
// explicitly set it on the request object generated by the method call above.
responseBytes, err := httprequest.New(oAuthHTTPClient, i.metricsLogger).DoContext(context.TODO(),
http.MethodPost, i.issuerMetadata.CredentialEndpoint, "application/json", nil,
bytes.NewReader(requestBody), fetchCredentialResponseEventText, requestCredentialEventText,
[]int{http.StatusOK, http.StatusCreated}, processCredentialErrorResponse)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -461,12 +463,9 @@ func createOAuthHTTPClient(
return oAuthHTTPClient
}

// The returned *http.Request will not have the access token set on it. The caller must ensure that it's set
// before sending the request to the server.
func (i *interaction) createCredentialRequestWithoutAccessToken(proofJWT, credentialFormat string,
func (i *interaction) createCredentialRequestBody(proofJWT, credentialFormat string,
credentialTypes, credentialContext []string,
) (*http.Request, error) {

) ([]byte, error) {
var credentialContextToSend *[]string

if len(credentialContext) > 0 {
Expand All @@ -485,57 +484,7 @@ func (i *interaction) createCredentialRequestWithoutAccessToken(proofJWT, creden
},
}

credentialReqBytes, err := json.Marshal(credentialReq)
if err != nil {
return nil, err
}

request, err := http.NewRequest(http.MethodPost, //nolint: noctx
i.issuerMetadata.CredentialEndpoint, bytes.NewReader(credentialReqBytes))
if err != nil {
return nil, err
}

request.Header.Add("Content-Type", "application/json")

return request, nil
}

func (i *interaction) getRawCredentialResponse(credentialReq *http.Request, eventText string, httpClient *http.Client,
) ([]byte, error) {
timeStartHTTPRequest := time.Now()

response, err := httpClient.Do(credentialReq)
if err != nil {
return nil, err
}

err = i.metricsLogger.Log(&api.MetricsEvent{
Event: eventText,
ParentEvent: requestCredentialEventText,
Duration: time.Since(timeStartHTTPRequest),
})
if err != nil {
return nil, err
}

responseBytes, err := io.ReadAll(response.Body)
if err != nil {
return nil, err
}

if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusCreated {
return nil, processCredentialErrorResponse(response.StatusCode, responseBytes)
}

defer func() {
errClose := response.Body.Close()
if errClose != nil {
println(fmt.Sprintf("failed to close response body: %s", errClose.Error()))
}
}()

return responseBytes, nil
return json.Marshal(credentialReq)
}

func (i *interaction) getVCsFromCredentialResponses(
Expand Down
29 changes: 13 additions & 16 deletions pkg/openid4ci/issuerinitiatedinteraction.go
Original file line number Diff line number Diff line change
Expand Up @@ -443,19 +443,22 @@ func (i *IssuerInitiatedInteraction) getCredentialResponse(
credentialResponses := make([]CredentialResponse, len(i.credentialTypes))

for index := range i.credentialTypes {
request, err := i.interaction.createCredentialRequestWithoutAccessToken(proofJWT, i.credentialFormats[index],
requestBody, err := i.interaction.createCredentialRequestBody(proofJWT, i.credentialFormats[index],
i.credentialTypes[index], i.credentialContexts[index])
if err != nil {
return nil, err
}

request.Header.Add("Authorization", "Bearer "+tokenResponse.AccessToken)
headers := http.Header{}
headers.Add("Authorization", "Bearer "+tokenResponse.AccessToken)

fetchCredentialResponseEventText := fmt.Sprintf(fetchCredentialViaGETReqEventText, index+1,
len(i.credentialTypes), i.interaction.issuerMetadata.CredentialEndpoint)

responseBytes, err := i.interaction.getRawCredentialResponse(request, fetchCredentialResponseEventText,
i.interaction.httpClient)
responseBytes, err := httprequest.New(i.interaction.httpClient, i.interaction.metricsLogger).DoContext(context.TODO(),
http.MethodPost, i.interaction.issuerMetadata.CredentialEndpoint, "application/json", headers,
bytes.NewReader(requestBody), fetchCredentialResponseEventText, requestCredentialEventText,
[]int{http.StatusOK, http.StatusCreated}, processCredentialErrorResponse)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -505,22 +508,16 @@ func (i *IssuerInitiatedInteraction) getCredentialResponsesBatch(
return nil, err
}

request, err := http.NewRequestWithContext(context.Background(),
http.MethodPost,
i.interaction.issuerMetadata.BatchCredentialEndpoint,
bytes.NewReader(b),
)
if err != nil {
return nil, err
}

request.Header.Add("Content-Type", "application/json")
request.Header.Add("Authorization", "Bearer "+tokenResponse.AccessToken)
headers := http.Header{}
headers.Add("Authorization", "Bearer "+tokenResponse.AccessToken)

fetchCredentialResponseEventText := fmt.Sprintf(fetchCredentialViaGETReqEventText, numberOfCredentials,
numberOfCredentials, i.interaction.issuerMetadata.BatchCredentialEndpoint)

b, err = i.interaction.getRawCredentialResponse(request, fetchCredentialResponseEventText, i.interaction.httpClient)
b, err = httprequest.New(i.interaction.httpClient, i.interaction.metricsLogger).DoContext(context.TODO(),
http.MethodPost, i.interaction.issuerMetadata.BatchCredentialEndpoint, "application/json", headers,
bytes.NewReader(b), fetchCredentialResponseEventText, requestCredentialEventText,
[]int{http.StatusOK, http.StatusCreated}, processCredentialErrorResponse)
if err != nil {
return nil, err
}
Expand Down
Loading

0 comments on commit 285627d

Please sign in to comment.