diff --git a/README.md b/README.md
index 0f19824..3211e6c 100644
--- a/README.md
+++ b/README.md
[Contributor avatar table; the HTML table markup was lost in extraction. The three hunks (@@ -321,6 +321,13 @@, @@ -342,6 +349,8 @@, @@ -349,8 +358,6 @@, all under "We are grateful for all the help we got from our contributors!") add cells for Cerrix (Francesco) and sugatoray after Mert Yuksekgonul, add a cell for sanowl after David Ruan, and remove two lines between San and huangzhii.]
diff --git a/textgrad/engine/__init__.py b/textgrad/engine/__init__.py
index 25b3ebd..a78dac3 100644
--- a/textgrad/engine/__init__.py
+++ b/textgrad/engine/__init__.py
@@ -41,6 +41,11 @@ def get_engine(engine_name: str, **kwargs) -> EngineLM:
     elif (("gpt-4" in engine_name) or ("gpt-3.5" in engine_name)):
         from .openai import ChatOpenAI
         return ChatOpenAI(model_string=engine_name, is_multimodal=_check_if_multimodal(engine_name), **kwargs)
+    # Bedrock includes most of these models, so check for it before the provider-specific branches
+    elif "bedrock" in engine_name:
+        from .bedrock import ChatBedrock
+        engine_name = engine_name.replace("bedrock-", "")
+        return ChatBedrock(model_string=engine_name, **kwargs)
     elif "claude" in engine_name:
         from .anthropic import ChatAnthropic
         return ChatAnthropic(model_string=engine_name, is_multimodal=_check_if_multimodal(engine_name), **kwargs)
diff --git a/textgrad/engine/bedrock.py b/textgrad/engine/bedrock.py
new file mode 100644
index 0000000..80b9d14
--- /dev/null
+++ b/textgrad/engine/bedrock.py
@@ -0,0 +1,157 @@
+try:
+    import boto3
+    from botocore.config import Config
+except ImportError:
+    raise ImportError("If you'd like to use Amazon Bedrock models, please install the boto3 package by running `pip install boto3`")
+
+import os
+import platformdirs
+from tenacity import (
+    retry,
+    stop_after_attempt,
+    wait_random_exponential,
+)
+from .base import EngineLM, CachedEngine
+
+
+class ChatBedrock(EngineLM, CachedEngine):
+    SYSTEM_PROMPT = "You are a helpful, creative, and smart assistant."
+
+    def __init__(
+        self,
+        model_string="anthropic.claude-3-sonnet-20240229-v1:0",
+        system_prompt=SYSTEM_PROMPT,
+        **kwargs
+    ):
+        # Whether the model accepts a separate system prompt via the Converse API.
+        self.system_prompt_supported = True
+        if "anthropic" in model_string:
+            self.system_prompt_supported = True
+        elif "meta" in model_string:
+            self.system_prompt_supported = True
+        elif "cohere" in model_string:
+            self.system_prompt_supported = True
+        elif "mistral" in model_string:
+            # Mistral instruct-style models do not take a separate system prompt.
+            self.system_prompt_supported = "instruct" not in model_string
+        elif "amazon" in model_string:
+            self.system_prompt_supported = False
+            if "premier" in model_string:
+                raise ValueError("amazon-titan-premier is not supported yet")
+        elif "ai21" in model_string:
+            raise ValueError("ai21 is not supported yet")
+
+        self.max_tokens = kwargs.get("max_tokens", None)
+        self.aws_region = kwargs.get("region", None)
+
+        # Support both AWS auth paths: default session credentials (e.g. an attached
+        # role) or an explicit access key and secret key from the environment.
+        if boto3._get_default_session().get_credentials() is not None:
+            if self.aws_region:
+                self.my_config = Config(region_name=self.aws_region)
+                self.client = boto3.client(service_name='bedrock-runtime', config=self.my_config)
+            else:
+                self.client = boto3.client(service_name='bedrock-runtime')
+        else:
+            access_key_id = os.getenv("AWS_ACCESS_KEY_ID", None)
+            secret_access_key = os.getenv("AWS_SECRET_ACCESS_KEY", None)
+            session_token = os.getenv("AWS_SESSION_TOKEN", None)
+            if self.aws_region is None:
+                self.aws_region = os.getenv("AWS_DEFAULT_REGION", None)
+            if self.aws_region is None:
+                raise ValueError("AWS region not specified. Please pass it as a get_engine parameter or set the AWS_DEFAULT_REGION environment variable. You can also attach an AWS role to this environment to use default session credentials.")
+            if access_key_id is None:
+                raise ValueError("AWS access key ID cannot be 'None'. You can also attach an AWS role to this environment to use default session credentials.")
+            if secret_access_key is None:
+                raise ValueError("AWS secret access key cannot be 'None'. You can also attach an AWS role to this environment to use default session credentials.")
+            session = boto3.Session(
+                aws_access_key_id=access_key_id,
+                aws_secret_access_key=secret_access_key,
+                aws_session_token=session_token
+            )
+            self.my_config = Config(region_name=self.aws_region)
+            self.client = session.client(service_name='bedrock-runtime', config=self.my_config)
+
+        root = platformdirs.user_cache_dir("textgrad")
+        cache_path = os.path.join(root, f"cache_bedrock_{model_string}.db")
+        super().__init__(cache_path=cache_path)
+
+        self.model_string = model_string
+        self.system_prompt = system_prompt
+
+        assert isinstance(self.system_prompt, str)
+
+    @retry(wait=wait_random_exponential(min=1, max=5), stop=stop_after_attempt(5))
+    def __call__(self, prompt, **kwargs):
+        return self.generate(prompt, **kwargs)
+
+    def generate_conversation(self, model_id="", system_prompts=[], messages=[], temperature=0.5, top_k=200, top_p=0.99, max_tokens=2048):
+        """
+        Sends messages to a model via the Bedrock Converse API.
+
+        Args:
+            model_id (str): The model ID to use.
+            system_prompts (list of dict): The system prompts for the model to use.
+            messages (list of dict): The messages to send to the model.
+            temperature (float): Sampling temperature.
+            top_k (int): Top-k sampling parameter (Anthropic models only).
+            top_p (float): Nucleus sampling parameter.
+            max_tokens (int): Maximum number of tokens to generate.
+
+        Returns:
+            response (dict): The conversation that the model generated.
+        """
+        # Base inference parameters to use.
+        inference_config = {"temperature": temperature, "topP": top_p, "maxTokens": self.max_tokens if self.max_tokens else max_tokens}
+        if "anthropic" in model_id:
+            # Additional, provider-specific inference parameters.
+            additional_model_fields = {"top_k": top_k}
+        else:
+            additional_model_fields = {}
+
+        # Send the message; pass the system prompts only when the model supports them.
+        converse_kwargs = {
+            "modelId": model_id,
+            "messages": messages,
+            "inferenceConfig": inference_config,
+            "additionalModelRequestFields": additional_model_fields,
+        }
+        if self.system_prompt_supported:
+            converse_kwargs["system"] = system_prompts
+        response = self.client.converse(**converse_kwargs)
+
+        return response
+
+    def generate(
+        self, prompt, system_prompt=None, temperature=0, max_tokens=2048, top_p=0.99
+    ):
+        sys_prompt_arg = system_prompt if system_prompt else self.system_prompt
+        sys_prompt_args = [{"text": sys_prompt_arg}]
+        # The cache key is the system prompt concatenated with the user prompt.
+        cache_or_none = self._check_cache(sys_prompt_arg + prompt)
+        if cache_or_none is not None:
+            return cache_or_none
+
+        if self.system_prompt_supported:
+            messages = [{
+                "role": "user",
+                "content": [{"text": prompt}]
+            }]
+        else:
+            # Models without system-prompt support get it prepended to the user turn.
+            messages = [{
+                "role": "user",
+                "content": [{"text": sys_prompt_arg + "\n\n" + prompt}]
+            }]
+
+        response = self.generate_conversation(self.model_string, system_prompts=sys_prompt_args, messages=messages, temperature=temperature, top_p=top_p, max_tokens=max_tokens)
+
+        response = response["output"]["message"]["content"][0]["text"]
+        self._save_cache(sys_prompt_arg + prompt, response)
+        return response
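
For reference, a minimal usage sketch of the routing added to get_engine above: any engine name containing "bedrock" is dispatched to ChatBedrock, with the "bedrock-" prefix stripped before the remainder is used as the Bedrock model ID. The region value here is an example; any Bedrock-enabled region works.

    from textgrad.engine import get_engine

    # "bedrock-" is stripped; the rest must be a valid Bedrock model ID.
    engine = get_engine("bedrock-anthropic.claude-3-sonnet-20240229-v1:0", region="us-east-1")
    # __call__ delegates to generate(), with retries via tenacity.
    print(engine("Say hello in one short sentence."))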
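A sketch of the explicit-credentials path in __init__, taken when boto3 finds no default session credentials. The key values below are placeholders, not real credentials; attaching an IAM role to the environment is the preferred alternative.

    import os

    # Placeholders: substitute real IAM credentials, or attach a role instead.
    os.environ["AWS_ACCESS_KEY_ID"] = "<access-key-id>"
    os.environ["AWS_SECRET_ACCESS_KEY"] = "<secret-access-key>"
    os.environ["AWS_DEFAULT_REGION"] = "us-east-1"  # used when no region kwarg is passed

    from textgrad.engine.bedrock import ChatBedrock
    engine = ChatBedrock(model_string="anthropic.claude-3-sonnet-20240229-v1:0")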
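Finally, the shape of the request that generate_conversation assembles for client.converse, shown as a standalone boto3 sketch; the model ID, region, and prompt are examples.

    import boto3

    client = boto3.client("bedrock-runtime", region_name="us-east-1")

    response = client.converse(
        modelId="anthropic.claude-3-sonnet-20240229-v1:0",
        messages=[{"role": "user", "content": [{"text": "What is 2 + 2?"}]}],
        system=[{"text": "You are a helpful, creative, and smart assistant."}],
        inferenceConfig={"temperature": 0.5, "topP": 0.99, "maxTokens": 2048},
        additionalModelRequestFields={"top_k": 200},  # Anthropic-only extra parameter
    )
    # The generated text sits at output -> message -> content[0] -> text.
    print(response["output"]["message"]["content"][0]["text"])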