From a271355cf038c66ec31f0469ff2884138ab2cc8a Mon Sep 17 00:00:00 2001 From: Yifan Mai Date: Tue, 7 Jan 2025 15:41:50 -0800 Subject: [PATCH] Simplify credential management for Bedrock client (#3255) --- src/helm/clients/bedrock_client.py | 8 ++------ src/helm/clients/bedrock_utils.py | 27 +++++++++++---------------- 2 files changed, 13 insertions(+), 22 deletions(-) diff --git a/src/helm/clients/bedrock_client.py b/src/helm/clients/bedrock_client.py index 8bc444f3f41..a74c612709c 100644 --- a/src/helm/clients/bedrock_client.py +++ b/src/helm/clients/bedrock_client.py @@ -39,17 +39,15 @@ def __init__( cache_config: CacheConfig, tokenizer: Tokenizer, tokenizer_name: str, - bedrock_model_id: Optional[str] = None, assumed_role: Optional[str] = None, region: Optional[str] = None, ): super().__init__(cache_config=cache_config) self.tokenizer = tokenizer self.tokenizer_name = tokenizer_name - self.bedrock_model_id = bedrock_model_id self.bedrock_client = get_bedrock_client( assumed_role=assumed_role or os.environ.get("BEDROCK_ASSUME_ROLE", None), - region=region or os.environ.get("AWS_DEFAULT_REGION", None), + region=region, ) def make_request(self, request: Request) -> RequestResult: @@ -108,17 +106,15 @@ def __init__( cache_config: CacheConfig, tokenizer: Tokenizer, tokenizer_name: str, - bedrock_model_id: Optional[str] = None, assumed_role: Optional[str] = None, region: Optional[str] = None, ): super().__init__(cache_config=cache_config) self.tokenizer = tokenizer self.tokenizer_name = tokenizer_name - self.bedrock_model_id = bedrock_model_id self.bedrock_client = get_bedrock_client_v1( assumed_role=assumed_role or os.environ.get("BEDROCK_ASSUME_ROLE", None), - region=region or os.environ.get("AWS_DEFAULT_REGION", None), + region=region, ) def convert_request_to_raw_request(self, request: Request) -> Dict: diff --git a/src/helm/clients/bedrock_utils.py b/src/helm/clients/bedrock_utils.py index f2a3c95bcb5..5b90564a48e 100644 --- a/src/helm/clients/bedrock_utils.py +++ b/src/helm/clients/bedrock_utils.py @@ -1,7 +1,7 @@ """Helper utilities for working with Amazon Bedrock.""" import os -from typing import Optional, Dict +from typing import Optional from helm.common.hierarchical_logger import hlog from helm.common.optional_dependencies import handle_module_not_found_error @@ -74,37 +74,32 @@ def get_bedrock_client( def get_bedrock_client_v1( - assumed_role: Optional[str] = None, + region: Optional[str] = None, service_name: str = "bedrock-runtime", - region: Optional[str] = "us-east-1", + assumed_role: Optional[str] = None, read_timeout: int = 5000, connect_timeout: int = 5000, - retries: Dict = {"max_attempts": 10}, + max_attempts: int = 10, ): - if region is None: - target_region = os.environ.get("AWS_REGION", os.environ.get("AWS_DEFAULT_REGION")) - else: - target_region = region - - boto_config = Config(read_timeout=read_timeout, connect_timeout=connect_timeout, retries=retries) - - if target_region is None: - raise ValueError("region environment variable is not set.") + boto_config = Config( + read_timeout=read_timeout, connect_timeout=connect_timeout, retries={"max_attempts": max_attempts} + ) if assumed_role: - session = boto3.Session(region_name=target_region) + session = boto3.Session(region_name=region) # Assume role and get credentials sts = session.client("sts") creds = sts.assume_role(RoleArn=str(assumed_role), RoleSessionName="crfm-helm")["Credentials"] session = Session( aws_access_key_id=creds["AccessKeyId"], aws_secret_access_key=creds["SecretAccessKey"], + aws_session_token=creds["SessionToken"], ) return session.client( service_name=service_name, - region_name=target_region, + region_name=region, config=boto_config, ) # default to instance role to get the aws credentials or aws configured credentials - return boto3.client(service_name=service_name, region_name=target_region, config=boto_config) + return boto3.client(service_name=service_name, region_name=region, config=boto_config)