Skip to content

Commit

Permalink
Simplify credential management for Bedrock client (#3255)
Browse files Browse the repository at this point in the history
  • Loading branch information
yifanmai authored Jan 7, 2025
1 parent bd4d6de commit a271355
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 22 deletions.
8 changes: 2 additions & 6 deletions src/helm/clients/bedrock_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,15 @@ def __init__(
cache_config: CacheConfig,
tokenizer: Tokenizer,
tokenizer_name: str,
bedrock_model_id: Optional[str] = None,
assumed_role: Optional[str] = None,
region: Optional[str] = None,
):
super().__init__(cache_config=cache_config)
self.tokenizer = tokenizer
self.tokenizer_name = tokenizer_name
self.bedrock_model_id = bedrock_model_id
self.bedrock_client = get_bedrock_client(
assumed_role=assumed_role or os.environ.get("BEDROCK_ASSUME_ROLE", None),
region=region or os.environ.get("AWS_DEFAULT_REGION", None),
region=region,
)

def make_request(self, request: Request) -> RequestResult:
Expand Down Expand Up @@ -108,17 +106,15 @@ def __init__(
cache_config: CacheConfig,
tokenizer: Tokenizer,
tokenizer_name: str,
bedrock_model_id: Optional[str] = None,
assumed_role: Optional[str] = None,
region: Optional[str] = None,
):
super().__init__(cache_config=cache_config)
self.tokenizer = tokenizer
self.tokenizer_name = tokenizer_name
self.bedrock_model_id = bedrock_model_id
self.bedrock_client = get_bedrock_client_v1(
assumed_role=assumed_role or os.environ.get("BEDROCK_ASSUME_ROLE", None),
region=region or os.environ.get("AWS_DEFAULT_REGION", None),
region=region,
)

def convert_request_to_raw_request(self, request: Request) -> Dict:
Expand Down
27 changes: 11 additions & 16 deletions src/helm/clients/bedrock_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Helper utilities for working with Amazon Bedrock."""

import os
from typing import Optional, Dict
from typing import Optional

from helm.common.hierarchical_logger import hlog
from helm.common.optional_dependencies import handle_module_not_found_error
Expand Down Expand Up @@ -74,37 +74,32 @@ def get_bedrock_client(


def get_bedrock_client_v1(
assumed_role: Optional[str] = None,
region: Optional[str] = None,
service_name: str = "bedrock-runtime",
region: Optional[str] = "us-east-1",
assumed_role: Optional[str] = None,
read_timeout: int = 5000,
connect_timeout: int = 5000,
retries: Dict = {"max_attempts": 10},
max_attempts: int = 10,
):
if region is None:
target_region = os.environ.get("AWS_REGION", os.environ.get("AWS_DEFAULT_REGION"))
else:
target_region = region

boto_config = Config(read_timeout=read_timeout, connect_timeout=connect_timeout, retries=retries)

if target_region is None:
raise ValueError("region environment variable is not set.")
boto_config = Config(
read_timeout=read_timeout, connect_timeout=connect_timeout, retries={"max_attempts": max_attempts}
)

if assumed_role:
session = boto3.Session(region_name=target_region)
session = boto3.Session(region_name=region)
# Assume role and get credentials
sts = session.client("sts")
creds = sts.assume_role(RoleArn=str(assumed_role), RoleSessionName="crfm-helm")["Credentials"]
session = Session(
aws_access_key_id=creds["AccessKeyId"],
aws_secret_access_key=creds["SecretAccessKey"],
aws_session_token=creds["SessionToken"],
)
return session.client(
service_name=service_name,
region_name=target_region,
region_name=region,
config=boto_config,
)

# default to instance role to get the aws credentials or aws configured credentials
return boto3.client(service_name=service_name, region_name=target_region, config=boto_config)
return boto3.client(service_name=service_name, region_name=region, config=boto_config)

0 comments on commit a271355

Please sign in to comment.