Skip to content

Commit

Permalink
fix(credential-providers): supply backup credentials to fromTemporary…
Browse files Browse the repository at this point in the history
…Credentials
  • Loading branch information
kuhe committed Jan 17, 2025
1 parent d35e4ad commit 32fd0ed
Show file tree
Hide file tree
Showing 5 changed files with 176 additions and 71 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import type { AssumeRoleCommandInput, STSClient, STSClientConfig } from "@aws-sdk/nested-clients/sts";
import type {
AwsIdentityProperties,
CredentialProviderOptions,
RuntimeConfigAwsCredentialIdentityProvider,
} from "@aws-sdk/types";
import { CredentialsProviderError } from "@smithy/property-provider";
import { AwsCredentialIdentity, AwsCredentialIdentityProvider, Pluggable } from "@smithy/types";

export interface FromTemporaryCredentialsOptions extends CredentialProviderOptions {
params: Omit<AssumeRoleCommandInput, "RoleSessionName"> & { RoleSessionName?: string };
masterCredentials?: AwsCredentialIdentity | AwsCredentialIdentityProvider;
clientConfig?: STSClientConfig;
clientPlugins?: Pluggable<any, any>[];
mfaCodeProvider?: (mfaSerial: string) => Promise<string>;
}

export const fromTemporaryCredentials = (
options: FromTemporaryCredentialsOptions,
credentialDefaultProvider?: () => AwsCredentialIdentityProvider
): RuntimeConfigAwsCredentialIdentityProvider => {
let stsClient: STSClient;
return async (awsIdentityProperties: AwsIdentityProperties = {}): Promise<AwsCredentialIdentity> => {
options.logger?.debug("@aws-sdk/credential-providers - fromTemporaryCredentials (STS)");
const params = { ...options.params, RoleSessionName: options.params.RoleSessionName ?? "aws-sdk-js-" + Date.now() };
if (params?.SerialNumber) {
if (!options.mfaCodeProvider) {
throw new CredentialsProviderError(
`Temporary credential requires multi-factor authentication,` + ` but no MFA code callback was provided.`,
{
tryNextLink: false,
logger: options.logger,
}
);
}
params.TokenCode = await options.mfaCodeProvider(params?.SerialNumber);
}

const { AssumeRoleCommand, STSClient } = await import("./loadSts");

if (!stsClient) {
const defaultCredentialsOrError =
typeof credentialDefaultProvider === "function" ? credentialDefaultProvider() : undefined;

const { callerClientConfig } = awsIdentityProperties;
stsClient = new STSClient({
...options.clientConfig,
credentials:
options.masterCredentials ??
options.clientConfig?.credentials ??
callerClientConfig?.credentialDefaultProvider?.() ??
defaultCredentialsOrError,
});
}
if (options.clientPlugins) {
for (const plugin of options.clientPlugins) {
stsClient.middlewareStack.use(plugin);
}
}
const { Credentials } = await stsClient.send(new AssumeRoleCommand(params));
if (!Credentials || !Credentials.AccessKeyId || !Credentials.SecretAccessKey) {
throw new CredentialsProviderError(`Invalid response from STS.assumeRole call with role ${params.RoleArn}`, {
logger: options.logger,
});
}
return {
accessKeyId: Credentials.AccessKeyId,
secretAccessKey: Credentials.SecretAccessKey,
sessionToken: Credentials.SessionToken,
expiration: Credentials.Expiration,
// TODO(credentialScope): access normally when shape is updated.
credentialScope: (Credentials as any).CredentialScope,
};
};
};
98 changes: 81 additions & 17 deletions packages/credential-providers/src/fromTemporaryCredentials.spec.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import { AssumeRoleCommand, STSClient } from "@aws-sdk/nested-clients/sts";
import { beforeEach, describe, expect, test as it, vi } from "vitest";

import { fromTemporaryCredentials } from "./fromTemporaryCredentials";
import { fromTemporaryCredentials as fromTemporaryCredentialsNode } from "./fromTemporaryCredentials";
import { fromTemporaryCredentials } from "./fromTemporaryCredentials.browser";

const mockSend = vi.fn();
const mockUsePlugin = vi.fn();
Expand Down Expand Up @@ -55,7 +56,7 @@ describe("fromTemporaryCredentials", () => {
clientConfig: { region },
clientPlugins: [plugin],
};
const provider = fromTemporaryCredentials(options);
const provider = fromTemporaryCredentialsNode(options);
const credential = await provider();
expect(credential).toEqual({
accessKeyId: "ACCESS_KEY_ID",
Expand All @@ -77,7 +78,7 @@ describe("fromTemporaryCredentials", () => {

it("should create STS client if not supplied", async () => {
const plugin = { applyToStack: () => {} };
const provider = fromTemporaryCredentials({
const provider = fromTemporaryCredentialsNode({
params: {
RoleArn,
RoleSessionName,
Expand All @@ -93,19 +94,8 @@ describe("fromTemporaryCredentials", () => {
expect(mockUsePlugin).toHaveBeenNthCalledWith(1, plugin);
});

it("should resolve default credentials if master credential is not supplied", async () => {
const provider = fromTemporaryCredentials({
params: {
RoleArn,
RoleSessionName,
},
});
await provider();
expect(vi.mocked(STSClient as any)).toHaveBeenCalledWith({});
});

it("should create a role session name if none provided", async () => {
const provider = fromTemporaryCredentials({
const provider = fromTemporaryCredentialsNode({
params: { RoleArn },
});
await provider();
Expand All @@ -115,6 +105,80 @@ describe("fromTemporaryCredentials", () => {
});
});

describe("nested sts credential resolution order", () => {
const masterCredentials = vi.fn();
const clientConfigCredentials = vi.fn();
const callerClientCredentials = vi.fn();
const chainCredentials = vi.fn();

it("should use with 1st priority masterCredentials from the provider", async () => {
const provider = fromTemporaryCredentials(
{
params: { RoleArn },
masterCredentials: masterCredentials,
clientConfig: {
credentials: clientConfigCredentials,
},
},
chainCredentials
);
await provider({
callerClientConfig: {
region: async () => "us-west-2",
credentialDefaultProvider: callerClientCredentials,
},
});
expect(masterCredentials).toHaveBeenCalled();
});
it("should use with 2nd priority options.clientConfig.credentials", async () => {
const provider = fromTemporaryCredentials(
{
params: { RoleArn },
clientConfig: {
credentials: clientConfigCredentials,
},
},
chainCredentials
);
await provider({
callerClientConfig: {
region: async () => "us-west-2",
credentialDefaultProvider: callerClientCredentials,
},
});
expect(clientConfigCredentials).toHaveBeenCalled();
});
it("should use with 3rd priority caller client's credentialDefaultProvider", async () => {
const provider = fromTemporaryCredentials(
{
params: { RoleArn },
},
chainCredentials
);
await provider({
callerClientConfig: {
region: async () => "us-west-2",
credentialDefaultProvider: callerClientCredentials,
},
});
expect(callerClientCredentials).toHaveBeenCalled();
});
it("should use with 4th priority the node default provider chain (if in Node.js)", async () => {
const provider = fromTemporaryCredentials(
{
params: { RoleArn },
},
chainCredentials
);
await provider({
callerClientConfig: {
region: async () => "us-west-2",
},
});
expect(chainCredentials).toHaveBeenCalled();
});
});

it("should allow assume roles assuming roles assuming roles ad infinitum", async () => {
const roleArnOf = (id: string) => `arn:aws:iam::123456789:role/${id}`;
const idOf = (roleArn: string) => roleArn.split("/")?.[1] ?? "UNKNOWN";
Expand Down Expand Up @@ -176,7 +240,7 @@ describe("fromTemporaryCredentials", () => {
const SerialNumber = "SERIAL_NUMBER";
const mfaCode = "MFA_CODE";
const mfaCodeProvider = vi.fn().mockResolvedValue(mfaCode);
const provider = fromTemporaryCredentials({
const provider = fromTemporaryCredentialsNode({
params: { RoleArn, SerialNumber, RoleSessionName },
mfaCodeProvider,
});
Expand All @@ -197,7 +261,7 @@ describe("fromTemporaryCredentials", () => {
it("should reject the promise with a terminal error if a MFA serial presents but mfaCodeProvider is missing", async () => {
const SerialNumber = "SERIAL_NUMBER";
try {
await fromTemporaryCredentials({
await fromTemporaryCredentialsNode({
params: { RoleArn, SerialNumber, RoleSessionName },
})();
fail("this test must fail");
Expand Down
65 changes: 13 additions & 52 deletions packages/credential-providers/src/fromTemporaryCredentials.ts
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
import type { AssumeRoleCommandInput, STSClient, STSClientConfig } from "@aws-sdk/nested-clients/sts";
import type { CredentialProviderOptions } from "@aws-sdk/types";
import { CredentialsProviderError } from "@smithy/property-provider";
import { AwsCredentialIdentity, AwsCredentialIdentityProvider, Pluggable } from "@smithy/types";
import type { RuntimeConfigAwsCredentialIdentityProvider } from "@aws-sdk/types";

export interface FromTemporaryCredentialsOptions extends CredentialProviderOptions {
params: Omit<AssumeRoleCommandInput, "RoleSessionName"> & { RoleSessionName?: string };
masterCredentials?: AwsCredentialIdentity | AwsCredentialIdentityProvider;
clientConfig?: STSClientConfig;
clientPlugins?: Pluggable<any, any>[];
mfaCodeProvider?: (mfaSerial: string) => Promise<string>;
}
import { fromNodeProviderChain } from "./fromNodeProviderChain";
import type { FromTemporaryCredentialsOptions } from "./fromTemporaryCredentials.browser";
import { fromTemporaryCredentials as fromTemporaryCredentialsBase } from "./fromTemporaryCredentials.browser";

/**
* @public
*/
export { FromTemporaryCredentialsOptions };

/**
* Creates a credential provider function that retrieves temporary credentials from STS AssumeRole API.
Expand Down Expand Up @@ -53,45 +51,8 @@ export interface FromTemporaryCredentialsOptions extends CredentialProviderOptio
*
* @public
*/
export const fromTemporaryCredentials = (options: FromTemporaryCredentialsOptions): AwsCredentialIdentityProvider => {
let stsClient: STSClient;
return async (): Promise<AwsCredentialIdentity> => {
options.logger?.debug("@aws-sdk/credential-providers - fromTemporaryCredentials (STS)");
const params = { ...options.params, RoleSessionName: options.params.RoleSessionName ?? "aws-sdk-js-" + Date.now() };
if (params?.SerialNumber) {
if (!options.mfaCodeProvider) {
throw new CredentialsProviderError(
`Temporary credential requires multi-factor authentication,` + ` but no MFA code callback was provided.`,
{
tryNextLink: false,
logger: options.logger,
}
);
}
params.TokenCode = await options.mfaCodeProvider(params?.SerialNumber);
}

const { AssumeRoleCommand, STSClient } = await import("./loadSts");

if (!stsClient) stsClient = new STSClient({ ...options.clientConfig, credentials: options.masterCredentials });
if (options.clientPlugins) {
for (const plugin of options.clientPlugins) {
stsClient.middlewareStack.use(plugin);
}
}
const { Credentials } = await stsClient.send(new AssumeRoleCommand(params));
if (!Credentials || !Credentials.AccessKeyId || !Credentials.SecretAccessKey) {
throw new CredentialsProviderError(`Invalid response from STS.assumeRole call with role ${params.RoleArn}`, {
logger: options.logger,
});
}
return {
accessKeyId: Credentials.AccessKeyId,
secretAccessKey: Credentials.SecretAccessKey,
sessionToken: Credentials.SessionToken,
expiration: Credentials.Expiration,
// TODO(credentialScope): access normally when shape is updated.
credentialScope: (Credentials as any).CredentialScope,
};
};
export const fromTemporaryCredentials = (
options: FromTemporaryCredentialsOptions
): RuntimeConfigAwsCredentialIdentityProvider => {
return fromTemporaryCredentialsBase(options, fromNodeProviderChain);
};
2 changes: 1 addition & 1 deletion packages/credential-providers/src/index.browser.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@ export * from "./fromCognitoIdentity";
export * from "./fromCognitoIdentityPool";
export { fromHttp } from "@aws-sdk/credential-provider-http";
export type { FromHttpOptions, HttpProviderCredentials } from "@aws-sdk/credential-provider-http";
export * from "./fromTemporaryCredentials";
export * from "./fromTemporaryCredentials.browser";
export * from "./fromWebToken";
7 changes: 6 additions & 1 deletion packages/types/src/identity/AwsCredentialIdentity.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import type { AwsCredentialIdentity } from "@smithy/types";
import type { AwsCredentialIdentity, AwsCredentialIdentityProvider } from "@smithy/types";

import type { AwsSdkCredentialsFeatures } from "../feature-ids";

Expand All @@ -11,6 +11,11 @@ export interface AwsIdentityProperties {
callerClientConfig?: {
region(): Promise<string>;
profile?: string;
/**
* @internal
* @deprecated
*/
credentialDefaultProvider?: (input?: any) => AwsCredentialIdentityProvider;
};
}

Expand Down

0 comments on commit 32fd0ed

Please sign in to comment.