Add support for Amazon Bedrock models #35

Open

wants to merge 23 commits into main

Changes from 19 commits
3 changes: 2 additions & 1 deletion requirements.txt
@@ -6,4 +6,5 @@ platformdirs>=3.11.0
datasets>=2.14.6
diskcache>=5.6.3
graphviz>=0.20.3
gdown>=5.2.0
gdown>=5.2.0
boto3>=1.34.133
Member
can we make this optional? i.e., without adding this to requirements.txt, can we just leave it to the import error handler (like we have for e.g. anthropic.py / openai.py etc.), saying people who want to use bedrock models should install boto3?

Author
For sure we can, I'm already handling the import error :) let me change that

7 changes: 6 additions & 1 deletion textgrad/engine/__init__.py
@@ -4,7 +4,7 @@
"opus": "claude-3-opus-20240229",
"haiku": "claude-3-haiku-20240307",
"sonnet": "claude-3-sonnet-20240229",
"together-llama-3-70b": "together-meta-llama/Llama-3-70b-chat-hf",
"together-llama-3-70b": "together-meta-llama/Llama-3-70b-chat-hf"
}

def get_engine(engine_name: str, **kwargs) -> EngineLM:
@@ -17,6 +17,11 @@ def get_engine(engine_name: str, **kwargs) -> EngineLM:
if (("gpt-4" in engine_name) or ("gpt-3.5" in engine_name)):
from .openai import ChatOpenAI
return ChatOpenAI(model_string=engine_name, **kwargs)
# Bedrock includes most of the models, so first check whether the request is for it
elif "bedrock" in engine_name:
from .bedrock import ChatBedrock
engine_name = engine_name.replace("bedrock-", "")
return ChatBedrock(model_string=engine_name, **kwargs)
elif "claude" in engine_name:
from .anthropic import ChatAnthropic
return ChatAnthropic(model_string=engine_name, **kwargs)
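For context, a minimal usage sketch of the routing above (the model ID below is the engine's default; any Bedrock model ID enabled in your account would work, and `region` is the optional keyword the new engine reads):

```python
from textgrad.engine import get_engine

# The "bedrock-" prefix routes to ChatBedrock; the prefix is stripped and the
# remainder is passed through to Bedrock as the model ID.
engine = get_engine("bedrock-anthropic.claude-3-sonnet-20240229-v1:0", region="us-east-1")
print(engine("Explain what a gradient is in one sentence."))
```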
157 changes: 157 additions & 0 deletions textgrad/engine/bedrock.py
@@ -0,0 +1,157 @@
try:
import boto3
from botocore.config import Config

except ImportError:
raise ImportError("If you'd like to use Amazon Bedrock models, please install the boto3 package by running `pip install boto3`")
Member
can we add the related environment variables that would be needed here? i.e., AWS_ACCESS_KEY_ID etc.?

Author
Lines 63/68 handle the variables needed. But if you prefer, I can move them here as well.
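For readers following the thread, a sketch of the environment variables the credential fallback below reads when boto3 has no default session credentials; the values here are placeholders:

```python
import os

# Read only when boto3._get_default_session().get_credentials() returns None
os.environ["AWS_ACCESS_KEY_ID"] = "<access-key-id>"
os.environ["AWS_SECRET_ACCESS_KEY"] = "<secret-access-key>"
os.environ["AWS_SESSION_TOKEN"] = "<session-token>"   # optional
os.environ["AWS_DEFAULT_REGION"] = "us-east-1"        # or pass region=... to get_engine
```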


import os
import platformdirs
from tenacity import (
retry,
stop_after_attempt,
wait_random_exponential,
)
from .base import EngineLM, CachedEngine


class ChatBedrock(EngineLM, CachedEngine):
SYSTEM_PROMPT = "You are a helpful, creative, and smart assistant."

def __init__(
self,
model_string="anthropic.claude-3-sonnet-20240229-v1:0",
system_prompt=SYSTEM_PROMPT,
**kwargs
):
self.system_prompt_supported = True
if "anthropic" in model_string:
self.system_prompt_supported = True
if "meta" in model_string:
self.system_prompt_supported = True
if "cohere" in model_string:
self.system_prompt_supported = True
if "mistral" in model_string:
if "instruct" in model_string:
self.system_prompt_supported = False
else:
self.system_prompt_supported = True
if "amazon" in model_string:
self.system_prompt_supported = False
if "premier" in model_string:
raise ValueError("amazon-titan-premier not supported yet")
if "ai21" in model_string:
self.system_prompt_supported = False
raise ValueError("ai21 not supported yet")

self.max_tokens = kwargs.get("max_tokens", None)
self.aws_region = kwargs.get("region", None)

# handle both AWS interaction options: with default credential or providing AWS ACCESS KEY and SECRET KEY
if boto3._get_default_session().get_credentials() is not None:
if self.aws_region:
self.my_config = Config(region_name = self.aws_region)
self.client = boto3.client(service_name='bedrock-runtime', config=self.my_config)
else:
self.client = boto3.client(service_name='bedrock-runtime')
else:
access_key_id = os.getenv("AWS_ACCESS_KEY_ID", None)
secret_access_key = os.getenv("AWS_SECRET_ACCESS_KEY", None)
session_token = os.getenv("AWS_SESSION_TOKEN", None)
if self.aws_region is None:
self.aws_region = os.getenv("AWS_DEFAULT_REGION", None)
if self.aws_region is None:
raise ValueError("AWS region not specified. Please add it in get_engine parameters or has AWS_DEFAULT_REGION var")
if access_key_id is None:
raise ValueError("AWS access key ID cannot be 'None'.")
if secret_access_key is None:
raise ValueError("AWS secret access key cannot be 'None'.")
session = boto3.Session(
aws_access_key_id=access_key_id,
aws_secret_access_key=secret_access_key,
aws_session_token=session_token
)
self.my_config = Config(region_name = self.aws_region)
self.client = session.client(service_name='bedrock-runtime', config=self.my_config)

root = platformdirs.user_cache_dir("textgrad")
cache_path = os.path.join(root, f"cache_bedrock_{model_string}.db")
super().__init__(cache_path=cache_path)

self.model_string = model_string
self.system_prompt = system_prompt

assert isinstance(self.system_prompt, str)

@retry(wait=wait_random_exponential(min=1, max=5), stop=stop_after_attempt(5))
def __call__(self, prompt, **kwargs):
return self.generate(prompt, **kwargs)

def generate_conversation(self, model_id="", system_prompts=[], messages=[], temperature=0.5, top_k=200, top_p=0.99, max_tokens=2048):
"""
Sends messages to a model.
Args:
bedrock_client: The Boto3 Bedrock runtime client.
model_id (str): The model ID to use.
system_prompts (JSON) : The system prompts for the model to use.
messages (JSON) : The messages to send to the model.

Returns:
response (JSON): The conversation that the model generated.

"""

# Base inference parameters to use.
inference_config = {"temperature": temperature, "topP": top_p, "maxTokens": self.max_tokens if self.max_tokens else max_tokens}
if("anthropic" in model_id):
# Additional inference parameters to use.
additional_model_fields = {"top_k": top_k}
else:
additional_model_fields = {}

# Send the message.
if self.system_prompt_supported:
response = self.client.converse(
modelId=model_id,
messages=messages,
system=system_prompts,
inferenceConfig=inference_config,
additionalModelRequestFields=additional_model_fields
)
else:
response = self.client.converse(
modelId=model_id,
messages=messages,
inferenceConfig=inference_config,
additionalModelRequestFields=additional_model_fields
)

return response

def generate(
Member
this is great! Do you want to add support for multiple inputs to this engine? Like here?

Author
Yes, I can do that! Can you provide a simple example to test the implementation?

Regarding images, only Anthropic models in Bedrock support vision, so I'll enable that only for them!

self, prompt, system_prompt=None, temperature=0, max_tokens=2048, top_p=0.99
):

sys_prompt_arg = system_prompt if system_prompt else self.system_prompt
sys_prompt_args = [{"text": sys_prompt_arg}]
cache_or_none = self._check_cache(sys_prompt_arg + prompt)
if cache_or_none is not None:
return cache_or_none

if self.system_prompt_supported:
messages = [{
"role": "user",
"content": [{"text": prompt}]
}]
else:
messages = [
{
"role": "user",
"content": [{"text": sys_prompt_arg + "\n\n" + prompt}]
}]

response = self.generate_conversation(self.model_string, system_prompts=sys_prompt_args, messages=messages, temperature=temperature, top_p=top_p, max_tokens=max_tokens)

response = response["output"]["message"]["content"][0]["text"]
self._save_cache(sys_prompt_arg + prompt, response)
return response
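A rough end-to-end sketch of exercising the new engine directly, assuming AWS credentials are already configured and the default Claude 3 Sonnet model is enabled in the account:

```python
from textgrad.engine.bedrock import ChatBedrock

# Default model_string is "anthropic.claude-3-sonnet-20240229-v1:0";
# region and max_tokens are read from **kwargs in __init__.
engine = ChatBedrock(region="us-east-1", max_tokens=512)

# generate() sends the system prompt via converse()'s `system` argument when the
# model supports it, otherwise it prepends the system prompt to the user prompt.
answer = engine.generate("Summarize this pull request in one sentence.", temperature=0.2)
print(answer)
```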