diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/auth/scheme/AuthSchemeInterceptorSpec.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/auth/scheme/AuthSchemeInterceptorSpec.java index 6ab69cb24a92..cd1b16d6efa1 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/auth/scheme/AuthSchemeInterceptorSpec.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/auth/scheme/AuthSchemeInterceptorSpec.java @@ -300,14 +300,23 @@ private MethodSpec generateTrySelectAuthScheme() { } builder.addStatement("$T identity", namedIdentityFuture()); builder.addStatement("$T metric = getIdentityMetric(identityProvider)", durationSdkMetric()); + builder.addStatement("$T resolveIdentityRequest = identityRequestBuilder.build()" , ResolveIdentityRequest.class); + builder.beginControlFlow("if (metric == null)") - .addStatement("identity = identityProvider.resolveIdentity(identityRequestBuilder.build())") + .addStatement("identity = identityProvider.resolveIdentity(resolveIdentityRequest)") .nextControlFlow("else") .addStatement("identity = $T.reportDuration(" - + "() -> identityProvider.resolveIdentity(identityRequestBuilder.build()), metricCollector, metric)", + + "() -> identityProvider.resolveIdentity(resolveIdentityRequest), metricCollector, metric)", MetricUtils.class) .endControlFlow(); + if (endpointRulesSpecUtils.isS3()) { + builder.addStatement("executionAttributes.putAttribute($T.RESOLVE_IDENTITY_REQUEST, resolveIdentityRequest)", + SdkInternalExecutionAttribute.class); + builder.addStatement("executionAttributes.putAttribute($T.SELECTED_IDENTITY_PROVIDER, identityProvider)", + SdkInternalExecutionAttribute.class); + } + builder.addStatement("return new $T<>(identity, signer, authOption)", SelectedAuthScheme.class); return builder.build(); } diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/auth/scheme/query-auth-scheme-interceptor.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/auth/scheme/query-auth-scheme-interceptor.java index 48edb00b1855..ca5523c9fe13 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/auth/scheme/query-auth-scheme-interceptor.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/auth/scheme/query-auth-scheme-interceptor.java @@ -113,10 +113,11 @@ private SelectedAuthScheme trySelectAuthScheme(AuthSchem authOption.forEachIdentityProperty(identityRequestBuilder::putProperty); CompletableFuture identity; SdkMetric metric = getIdentityMetric(identityProvider); + ResolveIdentityRequest resolveIdentityRequest = identityRequestBuilder.build(); if (metric == null) { - identity = identityProvider.resolveIdentity(identityRequestBuilder.build()); + identity = identityProvider.resolveIdentity(resolveIdentityRequest); } else { - identity = MetricUtils.reportDuration(() -> identityProvider.resolveIdentity(identityRequestBuilder.build()), + identity = MetricUtils.reportDuration(() -> identityProvider.resolveIdentity(resolveIdentityRequest), metricCollector, metric); } return new SelectedAuthScheme<>(identity, signer, authOption); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/auth/scheme/query-endpoint-auth-params-with-allowlist-auth-scheme-interceptor.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/auth/scheme/query-endpoint-auth-params-with-allowlist-auth-scheme-interceptor.java index 942aa43d9aee..10234167d218 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/auth/scheme/query-endpoint-auth-params-with-allowlist-auth-scheme-interceptor.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/auth/scheme/query-endpoint-auth-params-with-allowlist-auth-scheme-interceptor.java @@ -130,10 +130,11 @@ private SelectedAuthScheme trySelectAuthScheme(AuthSchem authOption.forEachIdentityProperty(identityRequestBuilder::putProperty); CompletableFuture identity; SdkMetric metric = getIdentityMetric(identityProvider); + ResolveIdentityRequest resolveIdentityRequest = identityRequestBuilder.build(); if (metric == null) { - identity = identityProvider.resolveIdentity(identityRequestBuilder.build()); + identity = identityProvider.resolveIdentity(resolveIdentityRequest); } else { - identity = MetricUtils.reportDuration(() -> identityProvider.resolveIdentity(identityRequestBuilder.build()), + identity = MetricUtils.reportDuration(() -> identityProvider.resolveIdentity(resolveIdentityRequest), metricCollector, metric); } return new SelectedAuthScheme<>(identity, signer, authOption); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/auth/scheme/query-endpoint-auth-params-without-allowlist-auth-scheme-interceptor.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/auth/scheme/query-endpoint-auth-params-without-allowlist-auth-scheme-interceptor.java index 498681559a0a..79aca96892b6 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/auth/scheme/query-endpoint-auth-params-without-allowlist-auth-scheme-interceptor.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/auth/scheme/query-endpoint-auth-params-without-allowlist-auth-scheme-interceptor.java @@ -137,10 +137,11 @@ private SelectedAuthScheme trySelectAuthScheme(AuthSchem authOption.forEachIdentityProperty(identityRequestBuilder::putProperty); CompletableFuture identity; SdkMetric metric = getIdentityMetric(identityProvider); + ResolveIdentityRequest resolveIdentityRequest = identityRequestBuilder.build(); if (metric == null) { - identity = identityProvider.resolveIdentity(identityRequestBuilder.build()); + identity = identityProvider.resolveIdentity(resolveIdentityRequest); } else { - identity = MetricUtils.reportDuration(() -> identityProvider.resolveIdentity(identityRequestBuilder.build()), + identity = MetricUtils.reportDuration(() -> identityProvider.resolveIdentity(resolveIdentityRequest), metricCollector, metric); } return new SelectedAuthScheme<>(identity, signer, authOption); diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/interceptor/SdkInternalExecutionAttribute.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/interceptor/SdkInternalExecutionAttribute.java index 37ef66a5d717..0718a07f5268 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/interceptor/SdkInternalExecutionAttribute.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/interceptor/SdkInternalExecutionAttribute.java @@ -31,7 +31,10 @@ import software.amazon.awssdk.http.SdkHttpExecutionAttributes; import software.amazon.awssdk.http.auth.spi.scheme.AuthScheme; import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeProvider; +import software.amazon.awssdk.identity.spi.Identity; +import software.amazon.awssdk.identity.spi.IdentityProvider; import software.amazon.awssdk.identity.spi.IdentityProviders; +import software.amazon.awssdk.identity.spi.ResolveIdentityRequest; import software.amazon.awssdk.utils.AttributeMap; /** @@ -153,6 +156,18 @@ public final class SdkInternalExecutionAttribute extends SdkExecutionAttribute { public static final ExecutionAttribute> SELECTED_AUTH_SCHEME = new ExecutionAttribute<>("SelectedAuthScheme"); + /** + * The selected identity provider for a request. + */ + public static final ExecutionAttribute> SELECTED_IDENTITY_PROVIDER = + new ExecutionAttribute<>("SelectedIdentityProvider"); + + /** + * The resolve identity request used by the identity provider. + */ + public static final ExecutionAttribute RESOLVE_IDENTITY_REQUEST = + new ExecutionAttribute<>("ResolveIdentityRequest"); + /** * The supported compression algorithms for an operation, and whether the operation is streaming or not. */ diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AsyncRetryableStage2.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AsyncRetryableStage2.java index ff598d27832e..e0b91f64007f 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AsyncRetryableStage2.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AsyncRetryableStage2.java @@ -95,6 +95,7 @@ private void attemptExecute(CompletableFuture> future) { try { retryableStageHelper.startingAttempt(); retryableStageHelper.logSendingRequest(); + retryableStageHelper.resolveCredentialsIfS3ExpressRetry(context); responseFuture = requestPipeline.execute(retryableStageHelper.requestToSend(), context); // If the result future fails, go ahead and fail the response future. diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/RetryableStage2.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/RetryableStage2.java index d12c444afa73..f0f4b03a19d6 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/RetryableStage2.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/RetryableStage2.java @@ -53,6 +53,7 @@ public Response execute(SdkHttpFullRequest request, RequestExecutionCon while (true) { try { retryableStageHelper.startingAttempt(); + retryableStageHelper.resolveCredentialsIfS3ExpressRetry(context); Response response = executeRequest(retryableStageHelper, context); retryableStageHelper.recordAttemptSucceeded(); return response; diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/utils/RetryableStageHelper2.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/utils/RetryableStageHelper2.java index 13cbb601789b..fc888dd47049 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/utils/RetryableStageHelper2.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/utils/RetryableStageHelper2.java @@ -24,13 +24,16 @@ import java.util.ArrayList; import java.util.List; import java.util.Optional; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.core.Response; import software.amazon.awssdk.core.SdkStandardLogger; +import software.amazon.awssdk.core.SelectedAuthScheme; import software.amazon.awssdk.core.client.config.SdkClientOption; import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.exception.SdkException; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; import software.amazon.awssdk.core.internal.http.HttpClientDependencies; import software.amazon.awssdk.core.internal.http.RequestExecutionContext; import software.amazon.awssdk.core.internal.http.pipeline.stages.AsyncRetryableStage; @@ -41,6 +44,8 @@ import software.amazon.awssdk.core.retry.RetryPolicyContext; import software.amazon.awssdk.http.SdkHttpFullRequest; import software.amazon.awssdk.http.SdkHttpResponse; +import software.amazon.awssdk.identity.spi.IdentityProvider; +import software.amazon.awssdk.identity.spi.ResolveIdentityRequest; import software.amazon.awssdk.retries.AdaptiveRetryStrategy; import software.amazon.awssdk.retries.api.AcquireInitialTokenRequest; import software.amazon.awssdk.retries.api.AcquireInitialTokenResponse; @@ -236,6 +241,40 @@ public void setLastResponse(SdkHttpResponse lastResponse) { this.lastResponse = lastResponse; } + /** + * Re-resolve the credentials upon a retry, if S3Express request. + */ + public void resolveCredentialsIfS3ExpressRetry(RequestExecutionContext requestExecutionContext) { + if (isInitialAttempt()) { + return; + } + + IdentityProvider identityProvider = + requestExecutionContext.executionAttributes().getAttribute(SdkInternalExecutionAttribute.SELECTED_IDENTITY_PROVIDER); + + if (identityProvider == null || !isS3Express(identityProvider)) { + return; + } + + ResolveIdentityRequest resolveIdentityRequest = + requestExecutionContext.executionAttributes() + .getAttribute(SdkInternalExecutionAttribute.RESOLVE_IDENTITY_REQUEST); + + SelectedAuthScheme authScheme = + requestExecutionContext.executionAttributes().getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + + CompletableFuture newlyResolvedIdentity = identityProvider.resolveIdentity(resolveIdentityRequest); + SelectedAuthScheme updatedAuthScheme = new SelectedAuthScheme(newlyResolvedIdentity, authScheme.signer(), + authScheme.authSchemeOption()); + requestExecutionContext.executionAttributes().putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, + updatedAuthScheme); + } + + private boolean isS3Express(IdentityProvider identityProvider) { + String className = identityProvider.identityType().getSimpleName(); + return "S3ExpressSessionCredentials".equals(className); + } + /** * Returns true if this is the first attempt. */ diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/s3express/S3ExpressRetryResolveCredentialsTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/s3express/S3ExpressRetryResolveCredentialsTest.java new file mode 100644 index 000000000000..e9c27e351f26 --- /dev/null +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/s3express/S3ExpressRetryResolveCredentialsTest.java @@ -0,0 +1,190 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.internal.s3express; + +import static com.github.tomakehurst.wiremock.client.WireMock.aResponse; +import static com.github.tomakehurst.wiremock.client.WireMock.anyUrl; +import static com.github.tomakehurst.wiremock.client.WireMock.get; +import static com.github.tomakehurst.wiremock.client.WireMock.put; +import static com.github.tomakehurst.wiremock.client.WireMock.stubFor; +import static com.github.tomakehurst.wiremock.client.WireMock.urlMatching; +import static java.lang.Boolean.TRUE; +import static org.assertj.core.api.Assertions.assertThat; +import static software.amazon.awssdk.http.SdkHttpConfigurationOption.TRUST_ALL_CERTIFICATES; + +import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo; +import com.github.tomakehurst.wiremock.junit5.WireMockTest; +import java.net.URI; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.function.Function; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; +import software.amazon.awssdk.auth.credentials.AwsCredentials; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; +import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.core.interceptor.Context; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.ExecutionInterceptor; +import software.amazon.awssdk.core.sync.RequestBody; +import software.amazon.awssdk.http.SdkHttpRequest; +import software.amazon.awssdk.http.apache.ApacheHttpClient; +import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient; +import software.amazon.awssdk.identity.spi.AwsCredentialsIdentity; +import software.amazon.awssdk.identity.spi.ResolveIdentityRequest; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.S3AsyncClientBuilder; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.S3ClientBuilder; +import software.amazon.awssdk.utils.AttributeMap; +import software.amazon.awssdk.utils.http.SdkHttpUtils; + +@WireMockTest(httpsEnabled = true) +public class S3ExpressRetryResolveCredentialsTest { + + private static final Function WM_HTTPS_ENDPOINT = wm -> URI.create(wm.getHttpsBaseUrl()); + private static final PathStyleEnforcingInterceptor PATH_STYLE_INTERCEPTOR = new PathStyleEnforcingInterceptor(); + private static final String S3EXPRESS_BUCKET = "s3express-cache-1--use1-az1--x-s3"; + private static final String REGULAR_S3__BUCKET = "my-test-bucket"; + private static final int RETRYABLE_ERROR_STATUS_CODE = 429; + private static final int NON_RETRYABLE_ERROR_STATUS_CODE = 400; + + private S3Client s3; + private S3AsyncClient s3Async; + private TrackingCredentialsProvider trackingCredentialsProvider; + + private static final String CREATE_SESSION_RESPONSE = String.format( + "\n" + + "\n" + + "\n" + + "%s\n" + + "%s\n" + + "%s" + + "\n" + + "", "TheToken", "TheSecret", "TheAccessKey"); + + @BeforeEach + public void methodSetup(WireMockRuntimeInfo wm) { + AwsBasicCredentials credentials = AwsBasicCredentials.create("akid_client", "skid_client"); + AwsCredentialsProvider credentialsProvider = StaticCredentialsProvider.create(credentials); + trackingCredentialsProvider = new TrackingCredentialsProvider(credentialsProvider); + s3 = getS3ClientBuilder(wm).build(); + s3Async = getS3AsyncClientBuilder(wm).build(); + + stubFor(get(urlMatching("/.*session")).atPriority(1).willReturn(aResponse() + .withStatus(200) + .withBody(CREATE_SESSION_RESPONSE))); + } + + private static List testParams() { + return Arrays.asList( + Arguments.of(S3EXPRESS_BUCKET, RETRYABLE_ERROR_STATUS_CODE, 4), // + 3 retries + Arguments.of(S3EXPRESS_BUCKET, NON_RETRYABLE_ERROR_STATUS_CODE, 1), + Arguments.of(REGULAR_S3__BUCKET, RETRYABLE_ERROR_STATUS_CODE, 1), + Arguments.of(REGULAR_S3__BUCKET, NON_RETRYABLE_ERROR_STATUS_CODE, 1) + ); + } + + @ParameterizedTest + @MethodSource("testParams") + void syncClient_resolvesIdentityProperNumberOfTimes(String bucket, int statusCode, int resolveIdentityCount) { + stubFor(put(anyUrl()).willReturn(aResponse().withStatus(statusCode))); + try { + s3.putObject(r -> r.bucket(bucket).key("key"), RequestBody.fromString("tmp")); + } catch (Exception e) { + assertThat(trackingCredentialsProvider.resolveIdentityCount()).isEqualTo(resolveIdentityCount); + } + } + + @ParameterizedTest + @MethodSource("testParams") + void asyncClient_resolvesIdentityProperNumberOfTimes(String bucket, int statusCode, int resolveIdentityCount) { + stubFor(put(anyUrl()).willReturn(aResponse().withStatus(statusCode))); + try { + s3Async.putObject(r -> r.bucket(bucket).key("key"), AsyncRequestBody.fromString("tmp")).join(); + } catch (Exception e) { + assertThat(trackingCredentialsProvider.resolveIdentityCount()).isEqualTo(resolveIdentityCount); + } + } + + private S3ClientBuilder getS3ClientBuilder(WireMockRuntimeInfo wm) { + return S3Client.builder() + .region(Region.US_EAST_1) + .overrideConfiguration(c -> c.addExecutionInterceptor(PATH_STYLE_INTERCEPTOR)) + .credentialsProvider(trackingCredentialsProvider) + .endpointOverride(WM_HTTPS_ENDPOINT.apply(wm)) + .httpClient(ApacheHttpClient.builder() + .buildWithDefaults(AttributeMap.builder() + .put(TRUST_ALL_CERTIFICATES, TRUE) + .build())); + } + + private S3AsyncClientBuilder getS3AsyncClientBuilder(WireMockRuntimeInfo wm) { + return S3AsyncClient.builder() + .region(Region.US_EAST_1) + .overrideConfiguration(c -> c.addExecutionInterceptor(PATH_STYLE_INTERCEPTOR)) + .credentialsProvider(trackingCredentialsProvider) + .endpointOverride(WM_HTTPS_ENDPOINT.apply(wm)) + .httpClient(NettyNioAsyncHttpClient.builder() + .buildWithDefaults(AttributeMap.builder() + .put(TRUST_ALL_CERTIFICATES, TRUE) + .build())); + } + + private static final class PathStyleEnforcingInterceptor implements ExecutionInterceptor { + + @Override + public SdkHttpRequest modifyHttpRequest(Context.ModifyHttpRequest context, ExecutionAttributes executionAttributes) { + SdkHttpRequest sdkHttpRequest = context.httpRequest(); + String host = sdkHttpRequest.host(); + String bucket = host.substring(0, host.indexOf(".localhost")); + + return sdkHttpRequest.toBuilder().host("localhost") + .encodedPath(SdkHttpUtils.appendUri(bucket, sdkHttpRequest.encodedPath())) + .build(); + } + } + + private static final class TrackingCredentialsProvider implements AwsCredentialsProvider { + private final AwsCredentialsProvider delegate; + private int resolveIdentityCount; + + TrackingCredentialsProvider(AwsCredentialsProvider delegate) { + this.delegate = delegate; + } + + @Override + public AwsCredentials resolveCredentials() { + return delegate.resolveCredentials(); + } + + @Override + public CompletableFuture resolveIdentity(ResolveIdentityRequest resolveIdentityRequest) { + resolveIdentityCount++; + return delegate.resolveIdentity(resolveIdentityRequest); + } + + public int resolveIdentityCount() { + return resolveIdentityCount; + } + } +}