From 90e5401ee4cc7bb4f75a3431b56ac3a1eff337f4 Mon Sep 17 00:00:00 2001 From: Daiyi Peng Date: Thu, 9 Jan 2025 10:47:48 -0800 Subject: [PATCH] Extract `lf.llms.OpenAICompatible` for OpenAI-compatible LLMs. With this CL, we have a centralized file for adding common features/fixing bugs for OpenAI-compatible LLMs. As a result, adding a new OpenAI-compatible LLM is just to configure the endpoint and overriding a few methods. PiperOrigin-RevId: 713724557 --- langfun/core/llms/deepseek.py | 160 +------ langfun/core/llms/deepseek_test.py | 401 +--------------- langfun/core/llms/groq.py | 111 +---- langfun/core/llms/groq_test.py | 168 ++----- langfun/core/llms/llama_cpp.py | 71 +-- langfun/core/llms/llama_cpp_test.py | 36 +- langfun/core/llms/openai.py | 156 +------ langfun/core/llms/openai_compatible.py | 186 ++++++++ langfun/core/llms/openai_compatible_test.py | 480 ++++++++++++++++++++ langfun/core/llms/openai_test.py | 436 +----------------- 10 files changed, 770 insertions(+), 1435 deletions(-) create mode 100644 langfun/core/llms/openai_compatible.py create mode 100644 langfun/core/llms/openai_compatible_test.py diff --git a/langfun/core/llms/deepseek.py b/langfun/core/llms/deepseek.py index 040b1e6c..afd301ce 100644 --- a/langfun/core/llms/deepseek.py +++ b/langfun/core/llms/deepseek.py @@ -17,8 +17,7 @@ from typing import Annotated, Any import langfun.core as lf -from langfun.core import modalities as lf_modalities -from langfun.core.llms import rest +from langfun.core.llms import openai_compatible import pyglove as pg SUPPORTED_MODELS_AND_SETTINGS = { @@ -39,7 +38,7 @@ # DeepSeek API uses an API format compatible with OpenAI. # Reference: https://api-docs.deepseek.com/ @lf.use_init_args(['model']) -class DeepSeek(rest.REST): +class DeepSeek(openai_compatible.OpenAICompatible): """DeepSeek model.""" model: pg.typing.Annotated[ @@ -51,10 +50,6 @@ class DeepSeek(rest.REST): api_endpoint: str = 'https://api.deepseek.com/chat/completions' - multimodal: Annotated[bool, 'Whether this model has multimodal support.'] = ( - False - ) - api_key: Annotated[ str | None, ( @@ -63,25 +58,18 @@ class DeepSeek(rest.REST): ), ] = None - def _on_bound(self): - super()._on_bound() - self._api_key = None - - def _initialize(self): + @property + def headers(self) -> dict[str, Any]: api_key = self.api_key or os.environ.get('DEEPSEEK_API_KEY', None) if not api_key: raise ValueError( 'Please specify `api_key` during `__init__` or set environment ' 'variable `DEEPSEEK_API_KEY` with your DeepSeek API key.' ) - self._api_key = api_key - - @property - def headers(self) -> dict[str, Any]: - headers = { - 'Content-Type': 'application/json', - 'Authorization': f'Bearer {self._api_key}', - } + headers = super().headers + headers.update({ + 'Authorization': f'Bearer {api_key}', + }) return headers @property @@ -118,138 +106,6 @@ def estimate_cost( def dir(cls): return [k for k, v in SUPPORTED_MODELS_AND_SETTINGS.items() if v.in_service] - def _request_args(self, options: lf.LMSamplingOptions) -> dict[str, Any]: - # Reference: - # https://platform.openai.com/docs/api-reference/completions/create - # NOTE(daiyip): options.top_k is not applicable. 
- args = dict( - model=self.model, - n=options.n, - top_logprobs=options.top_logprobs, - ) - if options.logprobs: - args['logprobs'] = options.logprobs - - if options.temperature is not None: - args['temperature'] = options.temperature - if options.max_tokens is not None: - args['max_completion_tokens'] = options.max_tokens - if options.top_p is not None: - args['top_p'] = options.top_p - if options.stop: - args['stop'] = options.stop - if options.random_seed is not None: - args['seed'] = options.random_seed - return args - - def _content_from_message(self, message: lf.Message): - """Returns a OpenAI content object from a Langfun message.""" - - def _uri_from(chunk: lf.Modality) -> str: - if chunk.uri and chunk.uri.lower().startswith( - ('http:', 'https:', 'ftp:') - ): - return chunk.uri - return chunk.content_uri - - content = [] - for chunk in message.chunk(): - if isinstance(chunk, str): - item = dict(type='text', text=chunk) - elif isinstance(chunk, lf_modalities.Image) and self.multimodal: - item = dict(type='image_url', image_url=dict(url=_uri_from(chunk))) - else: - raise ValueError(f'Unsupported modality: {chunk!r}.') - content.append(item) - return content - - def request( - self, prompt: lf.Message, sampling_options: lf.LMSamplingOptions - ) -> dict[str, Any]: - """Returns the JSON input for a message.""" - request_args = self._request_args(sampling_options) - - # Users could use `metadata_json_schema` to pass additional - # request arguments. - json_schema = prompt.metadata.get('json_schema') - if json_schema is not None: - if not isinstance(json_schema, dict): - raise ValueError(f'`json_schema` must be a dict, got {json_schema!r}.') - if 'title' not in json_schema: - raise ValueError( - 'The root of `json_schema` must have a `title` field, ' - f'got {json_schema!r}.' - ) - request_args.update( - response_format=dict( - type='json_schema', - json_schema=dict( - schema=json_schema, - name=json_schema['title'], - strict=True, - ), - ) - ) - prompt.metadata.formatted_text = ( - prompt.text - + '\n\n [RESPONSE FORMAT (not part of prompt)]\n' - + pg.to_json_str(request_args['response_format'], json_indent=2) - ) - - # Prepare messages. - messages = [] - # Users could use `metadata_system_message` to pass system message. 
- system_message = prompt.metadata.get('system_message') - if system_message: - system_message = lf.SystemMessage.from_value(system_message) - messages.append( - dict( - role='system', content=self._content_from_message(system_message) - ) - ) - messages.append( - dict(role='user', content=self._content_from_message(prompt)) - ) - request = dict() - request.update(request_args) - request['messages'] = messages - return request - - def _parse_choice(self, choice: dict[str, Any]) -> lf.LMSample: - # Reference: - # https://platform.openai.com/docs/api-reference/chat/object - logprobs = None - choice_logprobs = choice.get('logprobs') - if choice_logprobs: - logprobs = [ - ( - t['token'], - t['logprob'], - [(tt['token'], tt['logprob']) for tt in t['top_logprobs']], - ) - for t in choice_logprobs['content'] - ] - return lf.LMSample( - choice['message']['content'], - score=0.0, - logprobs=logprobs, - ) - - def result(self, json: dict[str, Any]) -> lf.LMSamplingResult: - usage = json['usage'] - return lf.LMSamplingResult( - samples=[self._parse_choice(choice) for choice in json['choices']], - usage=lf.LMSamplingUsage( - prompt_tokens=usage['prompt_tokens'], - completion_tokens=usage['completion_tokens'], - total_tokens=usage['total_tokens'], - estimated_cost=self.estimate_cost( - num_input_tokens=usage['prompt_tokens'], - num_output_tokens=usage['completion_tokens'], - ), - ), - ) - class DeepSeekChat(DeepSeek): """DeepSeek Chat model. diff --git a/langfun/core/llms/deepseek_test.py b/langfun/core/llms/deepseek_test.py index 2db1f97e..3323d601 100644 --- a/langfun/core/llms/deepseek_test.py +++ b/langfun/core/llms/deepseek_test.py @@ -11,72 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-"""Tests for OpenAI models.""" - -from typing import Any import unittest -from unittest import mock - -import langfun.core as lf from langfun.core.llms import deepseek -import pyglove as pg -import requests - - -def mock_chat_completion_request(url: str, json: dict[str, Any], **kwargs): - del url, kwargs - messages = json['messages'] - if len(messages) > 1: - system_message = f' system={messages[0]["content"]}' - else: - system_message = '' - - if 'response_format' in json: - response_format = f' format={json["response_format"]["type"]}' - else: - response_format = '' - - choices = [] - for k in range(json['n']): - if json.get('logprobs'): - logprobs = dict( - content=[ - dict( - token='chosen_token', - logprob=0.5, - top_logprobs=[ - dict( - token=f'alternative_token_{i + 1}', - logprob=0.1 - ) for i in range(3) - ] - ) - ] - ) - else: - logprobs = None - - choices.append(dict( - message=dict( - content=( - f'Sample {k} for message.{system_message}{response_format}' - ) - ), - logprobs=logprobs, - )) - response = requests.Response() - response.status_code = 200 - response._content = pg.to_json_str( - dict( - choices=choices, - usage=lf.LMSamplingUsage( - prompt_tokens=100, - completion_tokens=100, - total_tokens=200, - ), - ) - ).encode() - return response class DeepSeekTest(unittest.TestCase): @@ -87,7 +23,14 @@ def test_dir(self): def test_key(self): with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'): - deepseek.DeepSeekChat()('hi') + _ = deepseek.DeepSeekChat().headers + self.assertEqual( + deepseek.DeepSeekChat(api_key='test_key').headers, + { + 'Content-Type': 'application/json', + 'Authorization': 'Bearer test_key', + } + ) def test_model_id(self): self.assertEqual( @@ -106,333 +49,13 @@ def test_max_concurrency(self): deepseek.DeepSeekChat(api_key='test_key').max_concurrency, 0 ) - def test_request_args(self): - self.assertEqual( - deepseek.DeepSeekChat(api_key='test_key')._request_args( - lf.LMSamplingOptions( - temperature=1.0, stop=['\n'], n=1, random_seed=123 - ) - ), - dict( - model='deepseek-chat', - top_logprobs=None, - n=1, - temperature=1.0, - stop=['\n'], - seed=123, - ), - ) - - def test_call_chat_completion(self): - with mock.patch('requests.Session.post') as mock_request: - mock_request.side_effect = mock_chat_completion_request - lm = deepseek.DeepSeek(model='deepseek-chat', api_key='test_key') - self.assertEqual( - lm('hello', sampling_options=lf.LMSamplingOptions(n=2)), - 'Sample 0 for message.', - ) - - def test_call_chat_completion_with_logprobs(self): - with mock.patch('requests.Session.post') as mock_request: - mock_request.side_effect = mock_chat_completion_request - lm = deepseek.DeepSeek(model='deepseek-chat', api_key='test_key') - results = lm.sample(['hello'], logprobs=True) - self.assertEqual(len(results), 1) - expected = lf.LMSamplingResult( - [ - lf.LMSample( - response=lf.AIMessage( - text='Sample 0 for message.', - metadata={ - 'score': 0.0, - 'logprobs': [( - 'chosen_token', - 0.5, - [ - ('alternative_token_1', 0.1), - ('alternative_token_2', 0.1), - ('alternative_token_3', 0.1), - ], - )], - 'is_cached': False, - 'usage': lf.LMSamplingUsage( - prompt_tokens=100, - completion_tokens=100, - total_tokens=200, - estimated_cost=4.2e-05, - ), - }, - tags=['lm-response'], - ), - logprobs=[( - 'chosen_token', - 0.5, - [ - ('alternative_token_1', 0.1), - ('alternative_token_2', 0.1), - ('alternative_token_3', 0.1), - ], - )], - ) - ], - usage=lf.LMSamplingUsage( - prompt_tokens=100, - completion_tokens=100, - total_tokens=200, - 
estimated_cost=4.2e-05, - ), - ) - self.assertTrue(pg.eq(results[0], expected)) - - def test_sample_chat_completion(self): - with mock.patch('requests.Session.post') as mock_request: - mock_request.side_effect = mock_chat_completion_request - deepseek.SUPPORTED_MODELS_AND_SETTINGS['deepseek-chat'].update({ - 'cost_per_1k_input_tokens': 1.0, - 'cost_per_1k_output_tokens': 1.0, - }) - lm = deepseek.DeepSeek(api_key='test_key', model='deepseek-chat') - results = lm.sample( - ['hello', 'bye'], sampling_options=lf.LMSamplingOptions(n=3) - ) - - self.assertEqual(len(results), 2) - print(results[0]) - self.assertEqual( - results[0], - lf.LMSamplingResult( - [ - lf.LMSample( - lf.AIMessage( - 'Sample 0 for message.', - score=0.0, - logprobs=None, - is_cached=False, - usage=lf.LMSamplingUsage( - prompt_tokens=33, - completion_tokens=33, - total_tokens=66, - estimated_cost=0.2 / 3, - ), - tags=[lf.Message.TAG_LM_RESPONSE], - ), - score=0.0, - logprobs=None, - ), - lf.LMSample( - lf.AIMessage( - 'Sample 1 for message.', - score=0.0, - logprobs=None, - is_cached=False, - usage=lf.LMSamplingUsage( - prompt_tokens=33, - completion_tokens=33, - total_tokens=66, - estimated_cost=0.2 / 3, - ), - tags=[lf.Message.TAG_LM_RESPONSE], - ), - score=0.0, - logprobs=None, - ), - lf.LMSample( - lf.AIMessage( - 'Sample 2 for message.', - score=0.0, - logprobs=None, - is_cached=False, - usage=lf.LMSamplingUsage( - prompt_tokens=33, - completion_tokens=33, - total_tokens=66, - estimated_cost=0.2 / 3, - ), - tags=[lf.Message.TAG_LM_RESPONSE], - ), - score=0.0, - logprobs=None, - ), - ], - usage=lf.LMSamplingUsage( - prompt_tokens=100, completion_tokens=100, total_tokens=200, - estimated_cost=0.2, - ), - ), - ) + def test_estimate_cost(self): self.assertEqual( - results[1], - lf.LMSamplingResult( - [ - lf.LMSample( - lf.AIMessage( - 'Sample 0 for message.', - score=0.0, - logprobs=None, - is_cached=False, - usage=lf.LMSamplingUsage( - prompt_tokens=33, - completion_tokens=33, - total_tokens=66, - estimated_cost=0.2 / 3, - ), - tags=[lf.Message.TAG_LM_RESPONSE], - ), - score=0.0, - logprobs=None, - ), - lf.LMSample( - lf.AIMessage( - 'Sample 1 for message.', - score=0.0, - logprobs=None, - is_cached=False, - usage=lf.LMSamplingUsage( - prompt_tokens=33, - completion_tokens=33, - total_tokens=66, - estimated_cost=0.2 / 3, - ), - tags=[lf.Message.TAG_LM_RESPONSE], - ), - score=0.0, - logprobs=None, - ), - lf.LMSample( - lf.AIMessage( - 'Sample 2 for message.', - score=0.0, - logprobs=None, - is_cached=False, - usage=lf.LMSamplingUsage( - prompt_tokens=33, - completion_tokens=33, - total_tokens=66, - estimated_cost=0.2 / 3, - ), - tags=[lf.Message.TAG_LM_RESPONSE], - ), - score=0.0, - logprobs=None, - ), - ], - usage=lf.LMSamplingUsage( - prompt_tokens=100, completion_tokens=100, total_tokens=200, - estimated_cost=0.2, - ), + deepseek.DeepSeekChat(api_key='test_key').estimate_cost( + num_input_tokens=100, num_output_tokens=100 ), + 4.2e-5 ) - def test_sample_with_contextual_options(self): - with mock.patch('requests.Session.post') as mock_request: - mock_request.side_effect = mock_chat_completion_request - lm = deepseek.DeepSeek(api_key='test_key', model='deepseek-chat') - with lf.use_settings(sampling_options=lf.LMSamplingOptions(n=2)): - results = lm.sample(['hello']) - - self.assertEqual(len(results), 1) - expected = lf.LMSamplingResult( - samples=[ - lf.LMSample( - response=lf.AIMessage( - text='Sample 0 for message.', - sender='AI', - metadata=pg.Dict( - score=0.0, - logprobs=None, - is_cached=False, - 
usage=lf.LMSamplingUsage( - prompt_tokens=50, - completion_tokens=50, - total_tokens=100, - num_requests=1, - estimated_cost=0.1, - ), - ), - tags=['lm-response'], - ), - score=0.0, - logprobs=None, - ), - lf.LMSample( - response=lf.AIMessage( - text='Sample 1 for message.', - sender='AI', - metadata=pg.Dict( - score=0.0, - logprobs=None, - is_cached=False, - usage=lf.LMSamplingUsage( - prompt_tokens=50, - completion_tokens=50, - total_tokens=100, - num_requests=1, - estimated_cost=0.1, - ), - ), - tags=['lm-response'], - ), - score=0.0, - logprobs=None, - ), - ], - usage=lf.LMSamplingUsage( - prompt_tokens=100, - completion_tokens=100, - total_tokens=200, - num_requests=1, - estimated_cost=0.2, - ), - is_cached=False, - ) - self.assertTrue(pg.eq(results[0], expected)) - - def test_call_with_system_message(self): - with mock.patch('requests.Session.post') as mock_request: - mock_request.side_effect = mock_chat_completion_request - lm = deepseek.DeepSeek(api_key='test_key', model='deepseek-chat') - self.assertEqual( - lm( - lf.UserMessage( - 'hello', - system_message='hi', - ), - sampling_options=lf.LMSamplingOptions(n=2) - ), - '''Sample 0 for message. system=[{'type': 'text', 'text': 'hi'}]''', - ) - - def test_call_with_json_schema(self): - with mock.patch('requests.Session.post') as mock_request: - mock_request.side_effect = mock_chat_completion_request - lm = deepseek.DeepSeek(api_key='test_key', model='deepseek-chat') - self.assertEqual( - lm( - lf.UserMessage( - 'hello', - json_schema={ - 'type': 'object', - 'properties': { - 'name': {'type': 'string'}, - }, - 'required': ['name'], - 'title': 'Person', - } - ), - sampling_options=lf.LMSamplingOptions(n=2) - ), - 'Sample 0 for message. format=json_schema', - ) - - # Test bad json schema. - with self.assertRaisesRegex(ValueError, '`json_schema` must be a dict'): - lm(lf.UserMessage('hello', json_schema='foo')) - - with self.assertRaisesRegex( - ValueError, 'The root of `json_schema` must have a `title` field' - ): - lm(lf.UserMessage('hello', json_schema={})) - - if __name__ == '__main__': unittest.main() diff --git a/langfun/core/llms/groq.py b/langfun/core/llms/groq.py index 6610bb1e..2146190c 100644 --- a/langfun/core/llms/groq.py +++ b/langfun/core/llms/groq.py @@ -17,8 +17,7 @@ from typing import Annotated, Any import langfun.core as lf -from langfun.core import modalities as lf_modalities -from langfun.core.llms import rest +from langfun.core.llms import openai_compatible import pyglove as pg @@ -95,7 +94,7 @@ @lf.use_init_args(['model']) -class Groq(rest.REST): +class Groq(openai_compatible.OpenAICompatible): """Groq LLMs through REST APIs (OpenAI compatible). See https://platform.openai.com/docs/api-reference/chat @@ -108,10 +107,6 @@ class Groq(rest.REST): 'The name of the model to use.', ] - multimodal: Annotated[bool, 'Whether this model has multimodal support.'] = ( - False - ) - api_key: Annotated[ str | None, ( @@ -122,25 +117,19 @@ class Groq(rest.REST): api_endpoint: str = 'https://api.groq.com/openai/v1/chat/completions' - def _on_bound(self): - super()._on_bound() - self._api_key = None - - def _initialize(self): + @property + def headers(self) -> dict[str, Any]: api_key = self.api_key or os.environ.get('GROQ_API_KEY', None) if not api_key: raise ValueError( 'Please specify `api_key` during `__init__` or set environment ' 'variable `GROQ_API_KEY` with your Groq API key.' 
) - self._api_key = api_key - - @property - def headers(self) -> dict[str, Any]: - return { - 'Authorization': f'Bearer {self._api_key}', - 'Content-Type': 'application/json', - } + headers = super().headers + headers.update({ + 'Authorization': f'Bearer {api_key}', + }) + return headers @property def model_id(self) -> str: @@ -170,90 +159,14 @@ def estimate_cost( + cost_per_1k_output_tokens * num_output_tokens ) / 1000 - def request( - self, - prompt: lf.Message, - sampling_options: lf.LMSamplingOptions - ) -> dict[str, Any]: - """Returns the JSON input for a message.""" - request = dict() - request.update(self._request_args(sampling_options)) - request.update( - dict( - messages=[ - dict(role='user', content=self._content_from_message(prompt)) - ] - ) - ) - return request - def _request_args(self, options: lf.LMSamplingOptions) -> dict[str, Any]: """Returns a dict as request arguments.""" # `logprobs` and `top_logprobs` flags are not supported on Groq yet. - args = dict( - model=self.model, - n=options.n, - stream=False, - ) - - if options.temperature is not None: - args['temperature'] = options.temperature - if options.max_tokens is not None: - args['max_tokens'] = options.max_tokens - if options.top_p is not None: - args['top_p'] = options.top_p - if options.stop: - args['stop'] = options.stop + args = super()._request_args(options) + args.pop('logprobs', None) + args.pop('top_logprobs', None) return args - def _content_from_message(self, prompt: lf.Message) -> list[dict[str, Any]]: - """Converts an message to Groq's content protocol (list of dicts).""" - # Refer: https://platform.openai.com/docs/api-reference/chat/create - content = [] - for chunk in prompt.chunk(): - if isinstance(chunk, str): - item = dict(type='text', text=chunk) - elif ( - self.multimodal - and isinstance(chunk, lf_modalities.Image) - and chunk.uri - ): - # NOTE(daiyip): Groq only support image URL. - item = dict(type='image_url', image_url=chunk.uri) - else: - raise ValueError(f'Unsupported modality object: {chunk!r}.') - content.append(item) - return content - - def result(self, json: dict[str, Any]) -> lf.LMSamplingResult: - samples = [ - lf.LMSample(self._message_from_choice(choice), score=0.0) - for choice in json['choices'] - ] - usage = json['usage'] - return lf.LMSamplingResult( - samples, - usage=lf.LMSamplingUsage( - prompt_tokens=usage['prompt_tokens'], - completion_tokens=usage['completion_tokens'], - total_tokens=usage['total_tokens'], - estimated_cost=self.estimate_cost( - num_input_tokens=usage['prompt_tokens'], - num_output_tokens=usage['completion_tokens'], - ), - ), - ) - - def _message_from_choice(self, choice: dict[str, Any]) -> lf.Message: - """Converts Groq's content protocol to message.""" - # Refer: https://platform.openai.com/docs/api-reference/chat/create - content = choice['message']['content'] - if isinstance(content, str): - return lf.AIMessage(content) - return lf.AIMessage.from_chunks( - [x['text'] for x in content if x['type'] == 'text'] - ) - class GroqLlama3_2_3B(Groq): # pylint: disable=invalid-name """Llama3.2-3B with 8K context window. diff --git a/langfun/core/llms/groq_test.py b/langfun/core/llms/groq_test.py index f42e2fff..3e93ff03 100644 --- a/langfun/core/llms/groq_test.py +++ b/langfun/core/llms/groq_test.py @@ -11,89 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-"""Tests for Groq models.""" - import os -from typing import Any import unittest -from unittest import mock -from langfun.core import modalities as lf_modalities +import langfun.core as lf from langfun.core.llms import groq -import pyglove as pg -import requests - - -def mock_requests_post(url: str, json: dict[str, Any], **kwargs): - del url, kwargs - - response = requests.Response() - response.status_code = 200 - response._content = pg.to_json_str({ - 'choices': [{ - 'message': { - 'content': [{ - 'type': 'text', - 'text': ( - f'hello with temperature={json.get("temperature")}, ' - f'top_p={json.get("top_p")}, ' - f'max_tokens={json.get("max_tokens")}, ' - f'stop={json.get("stop")}.' - ), - }], - } - }], - 'usage': { - 'prompt_tokens': 2, - 'completion_tokens': 1, - 'total_tokens': 3, - }, - }).encode() - return response - - -def mock_mm_requests_post(url: str, json: dict[str, Any], **kwargs): - del url, kwargs - v = json['messages'][0]['content'][0] - image = lf_modalities.Image.from_uri(v['image_url']) - - response = requests.Response() - response.status_code = 200 - response._content = pg.to_json_str({ - 'choices': [ - { - 'message': { - 'content': [{ - 'type': 'text', - 'text': image.uri, - }], - } - } - ], - 'usage': { - 'prompt_tokens': 2, - 'completion_tokens': 1, - 'total_tokens': 3, - }, - }).encode() - return response - - -def mock_requests_post_error(status_code, error_type, error_message): - def _mock_requests(url: str, json: dict[str, Any], **kwargs): - del url, json, kwargs - response = requests.Response() - response.status_code = status_code - response._content = pg.to_json_str( - { - 'error': { - 'type': error_type, - 'message': error_message, - } - } - ).encode() - return response - - return _mock_requests class AuthropicTest(unittest.TestCase): @@ -101,69 +22,42 @@ class AuthropicTest(unittest.TestCase): def test_basics(self): self.assertEqual(groq.GroqMistral_8x7B().model_id, 'mixtral-8x7b-32768') self.assertEqual(groq.GroqMistral_8x7B().max_concurrency, 16) + self.assertEqual(groq.GroqMistral_8x7B().estimate_cost(100, 100), 4.8e-5) + + def test_request_args(self): + args = groq.GroqMistral_8x7B()._request_args( + lf.LMSamplingOptions( + temperature=1.0, stop=['\n'], n=1, random_seed=123, + logprobs=True, top_logprobs=True + ) + ) + self.assertNotIn('logprobs', args) + self.assertNotIn('top_logprobs', args) def test_api_key(self): lm = groq.GroqMistral_8x7B() with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'): - lm('hi') - - with mock.patch('requests.Session.post') as mock_request: - mock_request.side_effect = mock_requests_post - - lm = groq.GroqMistral_8x7B(api_key='fake key') - self.assertRegex(lm('hi').text, 'hello.*') - - os.environ['GROQ_API_KEY'] = 'abc' - lm = groq.GroqMistral_8x7B() - self.assertRegex(lm('hi').text, 'hello.*') - del os.environ['GROQ_API_KEY'] - - def test_call(self): - with mock.patch('requests.Session.post') as mock_request: - mock_request.side_effect = mock_requests_post - lm = groq.GroqLlama3_70B(api_key='fake_key') - response = lm( - 'hello', - temperature=0.0, - max_tokens=1024, - top_k=0.1, - top_p=0.2, - stop=['\n'], - ) - self.assertEqual( - response.text, - ( - 'hello with temperature=0.0, top_p=0.2, ' - "max_tokens=1024, stop=['\\n']." 
- ), - ) - self.assertIsNotNone(response.usage) - self.assertIsNotNone(response.usage.prompt_tokens, 2) - self.assertIsNotNone(response.usage.completion_tokens, 1) - self.assertIsNotNone(response.usage.total_tokens, 3) + _ = lm.headers - def test_mm_call(self): - with mock.patch('requests.Session.post') as mock_mm_request: - mock_mm_request.side_effect = mock_mm_requests_post - lm = groq.GroqLlama3_70B(multimodal=True, api_key='fake_key') - response = lm(lf_modalities.Image.from_uri('https://fake/image.jpg')) - self.assertEqual(response.text, 'https://fake/image.jpg') + lm = groq.GroqMistral_8x7B(api_key='fake key') + self.assertEqual( + lm.headers, + { + 'Content-Type': 'application/json', + 'Authorization': 'Bearer fake key', + } + ) - def test_call_errors(self): - for status_code, error_type, error_message in [ - (429, 'rate_limit', 'Rate limit exceeded.'), - (503, 'service_unavailable', 'Service unavailable.'), - (500, 'bad_request', 'Bad request.'), - ]: - with mock.patch('requests.Session.post') as mock_mm_request: - mock_mm_request.side_effect = mock_requests_post_error( - status_code, error_type, error_message - ) - lm = groq.GroqLlama3_70B(api_key='fake_key') - with self.assertRaisesRegex( - Exception, f'{status_code}:.*{error_type}' - ): - lm('hello', max_attempts=1) + os.environ['GROQ_API_KEY'] = 'abc' + lm = groq.GroqMistral_8x7B() + self.assertEqual( + lm.headers, + { + 'Content-Type': 'application/json', + 'Authorization': 'Bearer abc', + } + ) + del os.environ['GROQ_API_KEY'] if __name__ == '__main__': diff --git a/langfun/core/llms/llama_cpp.py b/langfun/core/llms/llama_cpp.py index c6a7b7c2..3c45b559 100644 --- a/langfun/core/llms/llama_cpp.py +++ b/langfun/core/llms/llama_cpp.py @@ -13,72 +13,35 @@ # limitations under the License. """Language models from llama.cpp.""" -from typing import Any - -import langfun.core as lf -from langfun.core.llms import rest +from typing import Annotated +from langfun.core.llms import openai_compatible import pyglove as pg -class LlamaCppRemote(rest.REST): +@pg.use_init_args(['url', 'model']) +@pg.members([('api_endpoint', pg.typing.Str().freeze(''))]) +class LlamaCppRemote(openai_compatible.OpenAICompatible): """The remote LLaMA C++ model. The Remote LLaMA C++ models can be launched via https://github.com/ggerganov/llama.cpp/tree/master/examples/server """ + url: Annotated[ + str, + 'The URL of the LLaMA C++ server.', + ] + + model: Annotated[ + str, + 'The name of the model to use.', + ] = '' - @pg.explicit_method_override - def __init__(self, url: str, model: str | None = None, **kwargs): - super().__init__(api_endpoint=f'{url}/completion', model=model, **kwargs) + @property + def api_endpoint(self) -> str: + return self.url + '/completion' @property def model_id(self) -> str: """Returns a string to identify the model.""" return f'LLaMAC++({self.model or ""})' - def request( - self, prompt: lf.Message, sampling_options: lf.LMSamplingOptions - ) -> dict[str, Any]: - """Returns the JSON input for a message.""" - request = dict() - request.update(self._request_args(sampling_options)) - # NOTE(daiyip): multi-modal is current not supported. 
- request['prompt'] = prompt.text - return request - - def _request_args(self, options: lf.LMSamplingOptions) -> dict[str, Any]: - """Returns a dict as request arguments.""" - args = dict( - n_predict=options.max_tokens or 1024, - top_k=options.top_k or 50, - top_p=options.top_p or 0.95, - ) - if options.temperature is not None: - args['temperature'] = options.temperature - return args - - def result(self, json: dict[str, Any]) -> lf.LMSamplingResult: - return lf.LMSamplingResult( - [lf.LMSample(item['content'], score=0.0) for item in json['items']] - ) - - def _sample_single(self, prompt: lf.Message) -> lf.LMSamplingResult: - request = self.request(prompt, self.sampling_options) - - def _sample_one_example(request): - response = self._session.post( - self.api_endpoint, - json=request, - timeout=self.timeout, - ) - if response.status_code == 200: - return response.json() - else: - error_cls = self._error_cls_from_status(response.status_code) - raise error_cls(f'{response.status_code}: {response.content}') - - items = self._parallel_execute_with_currency_control( - _sample_one_example, - [request] * (self.sampling_options.n or 1), - ) - return self.result(dict(items=items)) diff --git a/langfun/core/llms/llama_cpp_test.py b/langfun/core/llms/llama_cpp_test.py index 9d3914d2..879db8f5 100644 --- a/langfun/core/llms/llama_cpp_test.py +++ b/langfun/core/llms/llama_cpp_test.py @@ -11,48 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for llama cpp models.""" - -import typing import unittest -from unittest import mock - from langfun.core.llms import llama_cpp -def mock_requests_post(url: str, json: typing.Dict[str, typing.Any], **kwargs): - del kwargs - - class TEMP: - @property - def status_code(self): - return 200 - - def json(self): - return {"content": json["prompt"] + "\n" + url} - - return TEMP() - - class LlamaCppRemoteTest(unittest.TestCase): """Tests for the LlamaCppRemote model.""" - def test_call_completion(self): - with mock.patch("requests.Session.post") as mock_request: - mock_request.side_effect = mock_requests_post - lm = llama_cpp.LlamaCppRemote("http://127.0.0.1:8080") - [result] = lm.sample(["hello"], n=2) - self.assertEqual( - len(result.samples), - 2 - ) - self.assertEqual( - str(result.samples[0].response), - "hello\nhttp://127.0.0.1:8080/completion", - ) - - def test_model_id(self): + def test_basics(self): lm = llama_cpp.LlamaCppRemote("http://127.0.0.1:8080") + self.assertEqual(lm.api_endpoint, "http://127.0.0.1:8080/completion") self.assertEqual(lm.model_id, "LLaMAC++()") lm = llama_cpp.LlamaCppRemote("xxx", model="x") self.assertEqual(lm.model_id, "LLaMAC++(x)") diff --git a/langfun/core/llms/openai.py b/langfun/core/llms/openai.py index aefe8377..f6171b82 100644 --- a/langfun/core/llms/openai.py +++ b/langfun/core/llms/openai.py @@ -17,8 +17,7 @@ from typing import Annotated, Any import langfun.core as lf -from langfun.core import modalities as lf_modalities -from langfun.core.llms import rest +from langfun.core.llms import openai_compatible import pyglove as pg @@ -299,7 +298,7 @@ @lf.use_init_args(['model']) -class OpenAI(rest.REST): +class OpenAI(openai_compatible.OpenAICompatible): """OpenAI model.""" model: pg.typing.Annotated[ @@ -311,11 +310,6 @@ class OpenAI(rest.REST): api_endpoint: str = 'https://api.openai.com/v1/chat/completions' - multimodal: Annotated[ - bool, - 'Whether this model has multimodal support.' 
- ] = False - api_key: Annotated[ str | None, ( @@ -363,10 +357,9 @@ def _initialize(self): @property def headers(self) -> dict[str, Any]: - headers = { - 'Content-Type': 'application/json', - 'Authorization': f'Bearer {self._api_key}', - } + assert self._api_initialized + headers = super().headers + headers['Authorization'] = f'Bearer {self._api_key}' if self._organization: headers['OpenAI-Organization'] = self._organization if self._project: @@ -411,141 +404,10 @@ def dir(cls): def _request_args( self, options: lf.LMSamplingOptions) -> dict[str, Any]: - # Reference: - # https://platform.openai.com/docs/api-reference/completions/create - # NOTE(daiyip): options.top_k is not applicable. - args = dict( - model=self.model, - n=options.n, - top_logprobs=options.top_logprobs, - ) - if options.logprobs: - # Reasoning models (o1 series) does not support `logprobs` by 2024/09/12. - if self.model.startswith('o1-'): - raise RuntimeError('`logprobs` is not supported on {self.model!r}.') - args['logprobs'] = options.logprobs - - if options.temperature is not None: - args['temperature'] = options.temperature - if options.max_tokens is not None: - args['max_completion_tokens'] = options.max_tokens - if options.top_p is not None: - args['top_p'] = options.top_p - if options.stop: - args['stop'] = options.stop - if options.random_seed is not None: - args['seed'] = options.random_seed - return args - - def _content_from_message(self, message: lf.Message): - """Returns a OpenAI content object from a Langfun message.""" - def _uri_from(chunk: lf.Modality) -> str: - if chunk.uri and chunk.uri.lower().startswith( - ('http:', 'https:', 'ftp:') - ): - return chunk.uri - return chunk.content_uri - - content = [] - for chunk in message.chunk(): - if isinstance(chunk, str): - item = dict(type='text', text=chunk) - elif isinstance(chunk, lf_modalities.Image) and self.multimodal: - item = dict(type='image_url', image_url=dict(url=_uri_from(chunk))) - else: - raise ValueError(f'Unsupported modality: {chunk!r}.') - content.append(item) - return content - - def request( - self, - prompt: lf.Message, - sampling_options: lf.LMSamplingOptions - ) -> dict[str, Any]: - """Returns the JSON input for a message.""" - request_args = self._request_args(sampling_options) - - # Users could use `metadata_json_schema` to pass additional - # request arguments. - json_schema = prompt.metadata.get('json_schema') - if json_schema is not None: - if not isinstance(json_schema, dict): - raise ValueError( - f'`json_schema` must be a dict, got {json_schema!r}.' - ) - if 'title' not in json_schema: - raise ValueError( - f'The root of `json_schema` must have a `title` field, ' - f'got {json_schema!r}.' - ) - request_args.update( - response_format=dict( - type='json_schema', - json_schema=dict( - schema=json_schema, - name=json_schema['title'], - strict=True, - ) - ) - ) - prompt.metadata.formatted_text = ( - prompt.text - + '\n\n [RESPONSE FORMAT (not part of prompt)]\n' - + pg.to_json_str(request_args['response_format'], json_indent=2) - ) - - # Prepare messages. - messages = [] - # Users could use `metadata_system_message` to pass system message. 
- system_message = prompt.metadata.get('system_message') - if system_message: - system_message = lf.SystemMessage.from_value(system_message) - messages.append( - dict(role='system', - content=self._content_from_message(system_message)) - ) - messages.append( - dict(role='user', content=self._content_from_message(prompt)) - ) - request = dict() - request.update(request_args) - request['messages'] = messages - return request - - def _parse_choice(self, choice: dict[str, Any]) -> lf.LMSample: - # Reference: - # https://platform.openai.com/docs/api-reference/chat/object - logprobs = None - choice_logprobs = choice.get('logprobs') - if choice_logprobs: - logprobs = [ - ( - t['token'], - t['logprob'], - [(tt['token'], tt['logprob']) for tt in t['top_logprobs']], - ) - for t in choice_logprobs['content'] - ] - return lf.LMSample( - choice['message']['content'], - score=0.0, - logprobs=logprobs, - ) - - def result(self, json: dict[str, Any]) -> lf.LMSamplingResult: - usage = json['usage'] - return lf.LMSamplingResult( - samples=[self._parse_choice(choice) for choice in json['choices']], - usage=lf.LMSamplingUsage( - prompt_tokens=usage['prompt_tokens'], - completion_tokens=usage['completion_tokens'], - total_tokens=usage['total_tokens'], - estimated_cost=self.estimate_cost( - num_input_tokens=usage['prompt_tokens'], - num_output_tokens=usage['completion_tokens'], - ) - ), - ) + # Reasoning models (o1 series) does not support `logprobs` by 2024/09/12. + if options.logprobs and self.model.startswith(('o1-', 'o3-')): + raise RuntimeError('`logprobs` is not supported on {self.model!r}.') + return super()._request_args(options) class GptO1(OpenAI): diff --git a/langfun/core/llms/openai_compatible.py b/langfun/core/llms/openai_compatible.py new file mode 100644 index 00000000..9f7ce81e --- /dev/null +++ b/langfun/core/llms/openai_compatible.py @@ -0,0 +1,186 @@ +# Copyright 2025 The Langfun Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Base for OpenAI compatible models (including OpenAI).""" + +from typing import Annotated, Any + +import langfun.core as lf +from langfun.core import modalities as lf_modalities +from langfun.core.llms import rest +import pyglove as pg + + +@lf.use_init_args(['model']) +class OpenAICompatible(rest.REST): + """Base for OpenAI compatible models.""" + + model: Annotated[ + str, 'The name of the model to use.', + ] + + multimodal: Annotated[ + bool, 'Whether this model has multimodal support.' + ] = False + + @property + def headers(self) -> dict[str, Any]: + return { + 'Content-Type': 'application/json' + } + + def _request_args( + self, options: lf.LMSamplingOptions) -> dict[str, Any]: + """Returns a dict as request arguments.""" + # Reference: + # https://platform.openai.com/docs/api-reference/completions/create + # NOTE(daiyip): options.top_k is not applicable. 
+ args = dict( + model=self.model, + n=options.n, + top_logprobs=options.top_logprobs, + ) + if options.logprobs: + args['logprobs'] = options.logprobs + + if options.temperature is not None: + args['temperature'] = options.temperature + if options.max_tokens is not None: + args['max_completion_tokens'] = options.max_tokens + if options.top_p is not None: + args['top_p'] = options.top_p + if options.stop: + args['stop'] = options.stop + if options.random_seed is not None: + args['seed'] = options.random_seed + return args + + def _content_from_message(self, message: lf.Message) -> list[dict[str, Any]]: + """Returns a OpenAI content object from a Langfun message.""" + def _uri_from(chunk: lf.Modality) -> str: + if chunk.uri and chunk.uri.lower().startswith( + ('http:', 'https:', 'ftp:') + ): + return chunk.uri + return chunk.content_uri + + content = [] + for chunk in message.chunk(): + if isinstance(chunk, str): + item = dict(type='text', text=chunk) + elif isinstance(chunk, lf_modalities.Image) and self.multimodal: + item = dict(type='image_url', image_url=dict(url=_uri_from(chunk))) + else: + raise ValueError(f'Unsupported modality: {chunk!r}.') + content.append(item) + return content + + def request( + self, + prompt: lf.Message, + sampling_options: lf.LMSamplingOptions + ) -> dict[str, Any]: + """Returns the JSON input for a message.""" + request_args = self._request_args(sampling_options) + + # Users could use `metadata_json_schema` to pass additional + # request arguments. + json_schema = prompt.metadata.get('json_schema') + if json_schema is not None: + if not isinstance(json_schema, dict): + raise ValueError( + f'`json_schema` must be a dict, got {json_schema!r}.' + ) + if 'title' not in json_schema: + raise ValueError( + f'The root of `json_schema` must have a `title` field, ' + f'got {json_schema!r}.' + ) + request_args.update( + response_format=dict( + type='json_schema', + json_schema=dict( + schema=json_schema, + name=json_schema['title'], + strict=True, + ) + ) + ) + prompt.metadata.formatted_text = ( + prompt.text + + '\n\n [RESPONSE FORMAT (not part of prompt)]\n' + + pg.to_json_str(request_args['response_format'], json_indent=2) + ) + + # Prepare messages. + messages = [] + # Users could use `metadata_system_message` to pass system message. 
+ system_message = prompt.metadata.get('system_message') + if system_message: + system_message = lf.SystemMessage.from_value(system_message) + messages.append( + dict(role='system', + content=self._content_from_message(system_message)) + ) + messages.append( + dict(role='user', content=self._content_from_message(prompt)) + ) + request = dict() + request.update(request_args) + request['messages'] = messages + return request + + def _parse_choice(self, choice: dict[str, Any]) -> lf.LMSample: + # Reference: + # https://platform.openai.com/docs/api-reference/chat/object + logprobs = None + choice_logprobs = choice.get('logprobs') + if choice_logprobs: + logprobs = [ + ( + t['token'], + t['logprob'], + [(tt['token'], tt['logprob']) for tt in t['top_logprobs']], + ) + for t in choice_logprobs['content'] + ] + return lf.LMSample( + choice['message']['content'], + score=0.0, + logprobs=logprobs, + ) + + def result(self, json: dict[str, Any]) -> lf.LMSamplingResult: + """Returns a LMSamplingResult from a JSON response.""" + usage = json['usage'] + return lf.LMSamplingResult( + samples=[self._parse_choice(choice) for choice in json['choices']], + usage=lf.LMSamplingUsage( + prompt_tokens=usage['prompt_tokens'], + completion_tokens=usage['completion_tokens'], + total_tokens=usage['total_tokens'], + estimated_cost=self.estimate_cost( + num_input_tokens=usage['prompt_tokens'], + num_output_tokens=usage['completion_tokens'], + ) + ), + ) + + def estimate_cost( + self, + num_input_tokens: int, + num_output_tokens: int + ) -> float | None: + """Estimate the cost based on usage.""" + del num_input_tokens, num_output_tokens + return None diff --git a/langfun/core/llms/openai_compatible_test.py b/langfun/core/llms/openai_compatible_test.py new file mode 100644 index 00000000..14e11897 --- /dev/null +++ b/langfun/core/llms/openai_compatible_test.py @@ -0,0 +1,480 @@ +# Copyright 2023 The Langfun Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Tests for OpenAI models.""" + +from typing import Any +import unittest +from unittest import mock + +import langfun.core as lf +from langfun.core import modalities as lf_modalities +from langfun.core.llms import openai_compatible +import pyglove as pg +import requests + + +def mock_chat_completion_request(url: str, json: dict[str, Any], **kwargs): + del url, kwargs + messages = json['messages'] + if len(messages) > 1: + system_message = f' system={messages[0]["content"]}' + else: + system_message = '' + + if 'response_format' in json: + response_format = f' format={json["response_format"]["type"]}' + else: + response_format = '' + + choices = [] + for k in range(json['n']): + if json.get('logprobs'): + logprobs = dict( + content=[ + dict( + token='chosen_token', + logprob=0.5, + top_logprobs=[ + dict( + token=f'alternative_token_{i + 1}', + logprob=0.1 + ) for i in range(3) + ] + ) + ] + ) + else: + logprobs = None + + choices.append(dict( + message=dict( + content=( + f'Sample {k} for message.{system_message}{response_format}' + ) + ), + logprobs=logprobs, + )) + response = requests.Response() + response.status_code = 200 + response._content = pg.to_json_str( + dict( + choices=choices, + usage=lf.LMSamplingUsage( + prompt_tokens=100, + completion_tokens=100, + total_tokens=200, + ), + ) + ).encode() + return response + + +def mock_chat_completion_request_vision( + url: str, json: dict[str, Any], **kwargs +): + del url, kwargs + choices = [] + urls = [ + c['image_url']['url'] + for c in json['messages'][0]['content'] if c['type'] == 'image_url' + ] + for k in range(json['n']): + choices.append(pg.Dict( + message=pg.Dict( + content=f'Sample {k} for message: {"".join(urls)}' + ), + logprobs=None, + )) + response = requests.Response() + response.status_code = 200 + response._content = pg.to_json_str( + dict( + choices=choices, + usage=lf.LMSamplingUsage( + prompt_tokens=100, + completion_tokens=100, + total_tokens=200, + ), + ) + ).encode() + return response + + +class OpenAIComptibleTest(unittest.TestCase): + """Tests for OpenAI compatible language model.""" + + def test_request_args(self): + self.assertEqual( + openai_compatible.OpenAICompatible( + api_endpoint='https://test-server', + model='test-model' + )._request_args( + lf.LMSamplingOptions( + temperature=1.0, stop=['\n'], n=1, random_seed=123 + ) + ), + dict( + model='test-model', + top_logprobs=None, + n=1, + temperature=1.0, + stop=['\n'], + seed=123, + ), + ) + + def test_call_chat_completion(self): + with mock.patch('requests.Session.post') as mock_request: + mock_request.side_effect = mock_chat_completion_request + lm = openai_compatible.OpenAICompatible( + api_endpoint='https://test-server', model='test-model', + ) + self.assertEqual( + lm('hello', sampling_options=lf.LMSamplingOptions(n=2)), + 'Sample 0 for message.', + ) + + def test_call_chat_completion_with_logprobs(self): + with mock.patch('requests.Session.post') as mock_request: + mock_request.side_effect = mock_chat_completion_request + lm = openai_compatible.OpenAICompatible( + api_endpoint='https://test-server', model='test-model', + ) + results = lm.sample(['hello'], logprobs=True) + self.assertEqual(len(results), 1) + self.assertEqual( + results[0], + lf.LMSamplingResult( + [ + lf.LMSample( + response=lf.AIMessage( + text='Sample 0 for message.', + metadata={ + 'score': 0.0, + 'logprobs': [( + 'chosen_token', + 0.5, + [ + ('alternative_token_1', 0.1), + ('alternative_token_2', 0.1), + ('alternative_token_3', 0.1), + ], + )], + 'is_cached': False, + 'usage': 
lf.LMSamplingUsage( + prompt_tokens=100, + completion_tokens=100, + total_tokens=200, + estimated_cost=None, + ), + }, + tags=['lm-response'], + ), + logprobs=[( + 'chosen_token', + 0.5, + [ + ('alternative_token_1', 0.1), + ('alternative_token_2', 0.1), + ('alternative_token_3', 0.1), + ], + )], + ) + ], + usage=lf.LMSamplingUsage( + prompt_tokens=100, + completion_tokens=100, + total_tokens=200, + estimated_cost=None, + ), + ), + ) + + def test_call_chat_completion_vision(self): + with mock.patch('requests.Session.post') as mock_request: + mock_request.side_effect = mock_chat_completion_request_vision + lm_1 = openai_compatible.OpenAICompatible( + api_endpoint='https://test-server', + model='test-model1', + multimodal=True + ) + lm_2 = openai_compatible.OpenAICompatible( + api_endpoint='https://test-server', + model='test-model2', + multimodal=True + ) + for lm in (lm_1, lm_2): + self.assertEqual( + lm( + lf.UserMessage( + 'hello <<[[image]]>>', + image=lf_modalities.Image.from_uri('https://fake/image') + ), + sampling_options=lf.LMSamplingOptions(n=2) + ), + 'Sample 0 for message: https://fake/image', + ) + lm_3 = openai_compatible.OpenAICompatible( + api_endpoint='https://test-server', model='test-model3' + ) + with self.assertRaisesRegex(ValueError, 'Unsupported modality'): + lm_3( + lf.UserMessage( + 'hello <<[[image]]>>', + image=lf_modalities.Image.from_uri('https://fake/image') + ), + ) + + def test_sample_chat_completion(self): + with mock.patch('requests.Session.post') as mock_request: + mock_request.side_effect = mock_chat_completion_request + lm = openai_compatible.OpenAICompatible( + api_endpoint='https://test-server', model='test-model' + ) + results = lm.sample( + ['hello', 'bye'], sampling_options=lf.LMSamplingOptions(n=3) + ) + + self.assertEqual(len(results), 2) + self.assertEqual( + results[0], + lf.LMSamplingResult( + [ + lf.LMSample( + lf.AIMessage( + 'Sample 0 for message.', + score=0.0, + logprobs=None, + is_cached=False, + usage=lf.LMSamplingUsage( + prompt_tokens=33, + completion_tokens=33, + total_tokens=66, + estimated_cost=None, + ), + tags=[lf.Message.TAG_LM_RESPONSE], + ), + score=0.0, + logprobs=None, + ), + lf.LMSample( + lf.AIMessage( + 'Sample 1 for message.', + score=0.0, + logprobs=None, + is_cached=False, + usage=lf.LMSamplingUsage( + prompt_tokens=33, + completion_tokens=33, + total_tokens=66, + estimated_cost=None, + ), + tags=[lf.Message.TAG_LM_RESPONSE], + ), + score=0.0, + logprobs=None, + ), + lf.LMSample( + lf.AIMessage( + 'Sample 2 for message.', + score=0.0, + logprobs=None, + is_cached=False, + usage=lf.LMSamplingUsage( + prompt_tokens=33, + completion_tokens=33, + total_tokens=66, + estimated_cost=None, + ), + tags=[lf.Message.TAG_LM_RESPONSE], + ), + score=0.0, + logprobs=None, + ), + ], + usage=lf.LMSamplingUsage( + prompt_tokens=100, completion_tokens=100, total_tokens=200, + estimated_cost=None, + ), + ), + ) + self.assertEqual( + results[1], + lf.LMSamplingResult( + [ + lf.LMSample( + lf.AIMessage( + 'Sample 0 for message.', + score=0.0, + logprobs=None, + is_cached=False, + usage=lf.LMSamplingUsage( + prompt_tokens=33, + completion_tokens=33, + total_tokens=66, + estimated_cost=None, + ), + tags=[lf.Message.TAG_LM_RESPONSE], + ), + score=0.0, + logprobs=None, + ), + lf.LMSample( + lf.AIMessage( + 'Sample 1 for message.', + score=0.0, + logprobs=None, + is_cached=False, + usage=lf.LMSamplingUsage( + prompt_tokens=33, + completion_tokens=33, + total_tokens=66, + estimated_cost=None, + ), + tags=[lf.Message.TAG_LM_RESPONSE], + ), + 
score=0.0, + logprobs=None, + ), + lf.LMSample( + lf.AIMessage( + 'Sample 2 for message.', + score=0.0, + logprobs=None, + is_cached=False, + usage=lf.LMSamplingUsage( + prompt_tokens=33, + completion_tokens=33, + total_tokens=66, + estimated_cost=None, + ), + tags=[lf.Message.TAG_LM_RESPONSE], + ), + score=0.0, + logprobs=None, + ), + ], + usage=lf.LMSamplingUsage( + prompt_tokens=100, completion_tokens=100, total_tokens=200, + estimated_cost=None, + ), + ), + ) + + def test_sample_with_contextual_options(self): + with mock.patch('requests.Session.post') as mock_request: + mock_request.side_effect = mock_chat_completion_request + lm = openai_compatible.OpenAICompatible( + api_endpoint='https://test-server', model='test-model' + ) + with lf.use_settings(sampling_options=lf.LMSamplingOptions(n=2)): + results = lm.sample(['hello']) + + self.assertEqual(len(results), 1) + self.assertEqual( + results[0], + lf.LMSamplingResult( + [ + lf.LMSample( + lf.AIMessage( + 'Sample 0 for message.', + score=0.0, + logprobs=None, + is_cached=False, + usage=lf.LMSamplingUsage( + prompt_tokens=50, + completion_tokens=50, + total_tokens=100, + ), + tags=[lf.Message.TAG_LM_RESPONSE], + ), + score=0.0, + logprobs=None, + ), + lf.LMSample( + lf.AIMessage( + 'Sample 1 for message.', + score=0.0, + logprobs=None, + is_cached=False, + usage=lf.LMSamplingUsage( + prompt_tokens=50, + completion_tokens=50, + total_tokens=100, + ), + tags=[lf.Message.TAG_LM_RESPONSE], + ), + score=0.0, + logprobs=None, + ), + ], + usage=lf.LMSamplingUsage( + prompt_tokens=100, completion_tokens=100, total_tokens=200 + ), + ) + ) + + def test_call_with_system_message(self): + with mock.patch('requests.Session.post') as mock_request: + mock_request.side_effect = mock_chat_completion_request + lm = openai_compatible.OpenAICompatible( + api_endpoint='https://test-server', model='test-model' + ) + self.assertEqual( + lm( + lf.UserMessage( + 'hello', + system_message='hi', + ), + sampling_options=lf.LMSamplingOptions(n=2) + ), + '''Sample 0 for message. system=[{'type': 'text', 'text': 'hi'}]''', + ) + + def test_call_with_json_schema(self): + with mock.patch('requests.Session.post') as mock_request: + mock_request.side_effect = mock_chat_completion_request + lm = openai_compatible.OpenAICompatible( + api_endpoint='https://test-server', model='test-model' + ) + self.assertEqual( + lm( + lf.UserMessage( + 'hello', + json_schema={ + 'type': 'object', + 'properties': { + 'name': {'type': 'string'}, + }, + 'required': ['name'], + 'title': 'Person', + } + ), + sampling_options=lf.LMSamplingOptions(n=2) + ), + 'Sample 0 for message. format=json_schema', + ) + + # Test bad json schema. + with self.assertRaisesRegex(ValueError, '`json_schema` must be a dict'): + lm(lf.UserMessage('hello', json_schema='foo')) + + with self.assertRaisesRegex( + ValueError, 'The root of `json_schema` must have a `title` field' + ): + lm(lf.UserMessage('hello', json_schema={})) + + +if __name__ == '__main__': + unittest.main() diff --git a/langfun/core/llms/openai_test.py b/langfun/core/llms/openai_test.py index 1dd36c45..67cfe7a3 100644 --- a/langfun/core/llms/openai_test.py +++ b/langfun/core/llms/openai_test.py @@ -13,102 +13,9 @@ # limitations under the License. 
"""Tests for OpenAI models.""" -from typing import Any import unittest -from unittest import mock - import langfun.core as lf -from langfun.core import modalities as lf_modalities from langfun.core.llms import openai -import pyglove as pg -import requests - - -def mock_chat_completion_request(url: str, json: dict[str, Any], **kwargs): - del url, kwargs - messages = json['messages'] - if len(messages) > 1: - system_message = f' system={messages[0]["content"]}' - else: - system_message = '' - - if 'response_format' in json: - response_format = f' format={json["response_format"]["type"]}' - else: - response_format = '' - - choices = [] - for k in range(json['n']): - if json.get('logprobs'): - logprobs = dict( - content=[ - dict( - token='chosen_token', - logprob=0.5, - top_logprobs=[ - dict( - token=f'alternative_token_{i + 1}', - logprob=0.1 - ) for i in range(3) - ] - ) - ] - ) - else: - logprobs = None - - choices.append(dict( - message=dict( - content=( - f'Sample {k} for message.{system_message}{response_format}' - ) - ), - logprobs=logprobs, - )) - response = requests.Response() - response.status_code = 200 - response._content = pg.to_json_str( - dict( - choices=choices, - usage=lf.LMSamplingUsage( - prompt_tokens=100, - completion_tokens=100, - total_tokens=200, - ), - ) - ).encode() - return response - - -def mock_chat_completion_request_vision( - url: str, json: dict[str, Any], **kwargs -): - del url, kwargs - choices = [] - urls = [ - c['image_url']['url'] - for c in json['messages'][0]['content'] if c['type'] == 'image_url' - ] - for k in range(json['n']): - choices.append(pg.Dict( - message=pg.Dict( - content=f'Sample {k} for message: {"".join(urls)}' - ), - logprobs=None, - )) - response = requests.Response() - response.status_code = 200 - response._content = pg.to_json_str( - dict( - choices=choices, - usage=lf.LMSamplingUsage( - prompt_tokens=100, - completion_tokens=100, - total_tokens=200, - ), - ) - ).encode() - return response class OpenAITest(unittest.TestCase): @@ -130,6 +37,15 @@ def test_resource_id(self): openai.Gpt35(api_key='test_key').resource_id, 'OpenAI(text-davinci-003)' ) + def test_headers(self): + self.assertEqual( + openai.Gpt35(api_key='test_key').headers, + { + 'Content-Type': 'application/json', + 'Authorization': 'Bearer test_key', + }, + ) + def test_max_concurrency(self): self.assertGreater(openai.Gpt35(api_key='test_key').max_concurrency, 0) @@ -156,340 +72,14 @@ def test_request_args(self): ) ) - def test_call_chat_completion(self): - with mock.patch('requests.Session.post') as mock_request: - mock_request.side_effect = mock_chat_completion_request - lm = openai.OpenAI( - model='gpt-4', - api_key='test_key', - organization='my_org', - project='my_project' - ) - self.assertEqual( - lm('hello', sampling_options=lf.LMSamplingOptions(n=2)), - 'Sample 0 for message.', - ) - - def test_call_chat_completion_with_logprobs(self): - with mock.patch('requests.Session.post') as mock_request: - mock_request.side_effect = mock_chat_completion_request - lm = openai.OpenAI( - model='gpt-4', - api_key='test_key', - organization='my_org', - project='my_project' - ) - results = lm.sample(['hello'], logprobs=True) - self.assertEqual(len(results), 1) - self.assertEqual( - results[0], - lf.LMSamplingResult( - [ - lf.LMSample( - response=lf.AIMessage( - text='Sample 0 for message.', - metadata={ - 'score': 0.0, - 'logprobs': [( - 'chosen_token', - 0.5, - [ - ('alternative_token_1', 0.1), - ('alternative_token_2', 0.1), - ('alternative_token_3', 0.1), - ], - )], - 
'is_cached': False, - 'usage': lf.LMSamplingUsage( - prompt_tokens=100, - completion_tokens=100, - total_tokens=200, - estimated_cost=0.009, - ), - }, - tags=['lm-response'], - ), - logprobs=[( - 'chosen_token', - 0.5, - [ - ('alternative_token_1', 0.1), - ('alternative_token_2', 0.1), - ('alternative_token_3', 0.1), - ], - )], - ) - ], - usage=lf.LMSamplingUsage( - prompt_tokens=100, - completion_tokens=100, - total_tokens=200, - estimated_cost=0.009, - ), - ), - ) - - def test_call_chat_completion_vision(self): - with mock.patch('requests.Session.post') as mock_request: - mock_request.side_effect = mock_chat_completion_request_vision - lm_1 = openai.Gpt4Turbo(api_key='test_key') - lm_2 = openai.Gpt4VisionPreview(api_key='test_key') - for lm in (lm_1, lm_2): - self.assertEqual( - lm( - lf.UserMessage( - 'hello <<[[image]]>>', - image=lf_modalities.Image.from_uri('https://fake/image') - ), - sampling_options=lf.LMSamplingOptions(n=2) - ), - 'Sample 0 for message: https://fake/image', - ) - lm_3 = openai.Gpt35Turbo(api_key='test_key') - with self.assertRaisesRegex(ValueError, 'Unsupported modality'): - lm_3( - lf.UserMessage( - 'hello <<[[image]]>>', - image=lf_modalities.Image.from_uri('https://fake/image') - ), - ) - - def test_sample_chat_completion(self): - with mock.patch('requests.Session.post') as mock_request: - mock_request.side_effect = mock_chat_completion_request - openai.SUPPORTED_MODELS_AND_SETTINGS['gpt-4'].update({ - 'cost_per_1k_input_tokens': 1.0, - 'cost_per_1k_output_tokens': 1.0, - }) - lm = openai.OpenAI(api_key='test_key', model='gpt-4') - results = lm.sample( - ['hello', 'bye'], sampling_options=lf.LMSamplingOptions(n=3) - ) - - self.assertEqual(len(results), 2) - print(results[0]) - self.assertEqual( - results[0], - lf.LMSamplingResult( - [ - lf.LMSample( - lf.AIMessage( - 'Sample 0 for message.', - score=0.0, - logprobs=None, - is_cached=False, - usage=lf.LMSamplingUsage( - prompt_tokens=33, - completion_tokens=33, - total_tokens=66, - estimated_cost=0.2 / 3, - ), - tags=[lf.Message.TAG_LM_RESPONSE], - ), - score=0.0, - logprobs=None, - ), - lf.LMSample( - lf.AIMessage( - 'Sample 1 for message.', - score=0.0, - logprobs=None, - is_cached=False, - usage=lf.LMSamplingUsage( - prompt_tokens=33, - completion_tokens=33, - total_tokens=66, - estimated_cost=0.2 / 3, - ), - tags=[lf.Message.TAG_LM_RESPONSE], - ), - score=0.0, - logprobs=None, - ), - lf.LMSample( - lf.AIMessage( - 'Sample 2 for message.', - score=0.0, - logprobs=None, - is_cached=False, - usage=lf.LMSamplingUsage( - prompt_tokens=33, - completion_tokens=33, - total_tokens=66, - estimated_cost=0.2 / 3, - ), - tags=[lf.Message.TAG_LM_RESPONSE], - ), - score=0.0, - logprobs=None, - ), - ], - usage=lf.LMSamplingUsage( - prompt_tokens=100, completion_tokens=100, total_tokens=200, - estimated_cost=0.2, - ), - ), - ) + def test_estimate_cost(self): self.assertEqual( - results[1], - lf.LMSamplingResult( - [ - lf.LMSample( - lf.AIMessage( - 'Sample 0 for message.', - score=0.0, - logprobs=None, - is_cached=False, - usage=lf.LMSamplingUsage( - prompt_tokens=33, - completion_tokens=33, - total_tokens=66, - estimated_cost=0.2 / 3, - ), - tags=[lf.Message.TAG_LM_RESPONSE], - ), - score=0.0, - logprobs=None, - ), - lf.LMSample( - lf.AIMessage( - 'Sample 1 for message.', - score=0.0, - logprobs=None, - is_cached=False, - usage=lf.LMSamplingUsage( - prompt_tokens=33, - completion_tokens=33, - total_tokens=66, - estimated_cost=0.2 / 3, - ), - tags=[lf.Message.TAG_LM_RESPONSE], - ), - score=0.0, - logprobs=None, - ), - 
lf.LMSample( - lf.AIMessage( - 'Sample 2 for message.', - score=0.0, - logprobs=None, - is_cached=False, - usage=lf.LMSamplingUsage( - prompt_tokens=33, - completion_tokens=33, - total_tokens=66, - estimated_cost=0.2 / 3, - ), - tags=[lf.Message.TAG_LM_RESPONSE], - ), - score=0.0, - logprobs=None, - ), - ], - usage=lf.LMSamplingUsage( - prompt_tokens=100, completion_tokens=100, total_tokens=200, - estimated_cost=0.2, - ), + openai.Gpt4(api_key='test_key').estimate_cost( + num_input_tokens=100, num_output_tokens=100 ), + 0.009 ) - def test_sample_with_contextual_options(self): - with mock.patch('requests.Session.post') as mock_request: - mock_request.side_effect = mock_chat_completion_request - lm = openai.OpenAI(api_key='test_key', model='text-davinci-003') - with lf.use_settings(sampling_options=lf.LMSamplingOptions(n=2)): - results = lm.sample(['hello']) - - self.assertEqual(len(results), 1) - self.assertEqual( - results[0], - lf.LMSamplingResult( - [ - lf.LMSample( - lf.AIMessage( - 'Sample 0 for message.', - score=0.0, - logprobs=None, - is_cached=False, - usage=lf.LMSamplingUsage( - prompt_tokens=50, - completion_tokens=50, - total_tokens=100, - ), - tags=[lf.Message.TAG_LM_RESPONSE], - ), - score=0.0, - logprobs=None, - ), - lf.LMSample( - lf.AIMessage( - 'Sample 1 for message.', - score=0.0, - logprobs=None, - is_cached=False, - usage=lf.LMSamplingUsage( - prompt_tokens=50, - completion_tokens=50, - total_tokens=100, - ), - tags=[lf.Message.TAG_LM_RESPONSE], - ), - score=0.0, - logprobs=None, - ), - ], - usage=lf.LMSamplingUsage( - prompt_tokens=100, completion_tokens=100, total_tokens=200 - ), - ) - ) - - def test_call_with_system_message(self): - with mock.patch('requests.Session.post') as mock_request: - mock_request.side_effect = mock_chat_completion_request - lm = openai.OpenAI(api_key='test_key', model='gpt-4') - self.assertEqual( - lm( - lf.UserMessage( - 'hello', - system_message='hi', - ), - sampling_options=lf.LMSamplingOptions(n=2) - ), - '''Sample 0 for message. system=[{'type': 'text', 'text': 'hi'}]''', - ) - - def test_call_with_json_schema(self): - with mock.patch('requests.Session.post') as mock_request: - mock_request.side_effect = mock_chat_completion_request - lm = openai.OpenAI(api_key='test_key', model='gpt-4') - self.assertEqual( - lm( - lf.UserMessage( - 'hello', - json_schema={ - 'type': 'object', - 'properties': { - 'name': {'type': 'string'}, - }, - 'required': ['name'], - 'title': 'Person', - } - ), - sampling_options=lf.LMSamplingOptions(n=2) - ), - 'Sample 0 for message. format=json_schema', - ) - - # Test bad json schema. - with self.assertRaisesRegex(ValueError, '`json_schema` must be a dict'): - lm(lf.UserMessage('hello', json_schema='foo')) - - with self.assertRaisesRegex( - ValueError, 'The root of `json_schema` must have a `title` field' - ): - lm(lf.UserMessage('hello', json_schema={})) - if __name__ == '__main__': unittest.main()
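
Example (not part of the commit): a minimal sketch of how a new OpenAI-compatible backend could be defined on top of this change, mirroring the DeepSeek and Groq classes in the diff above, i.e. configure the endpoint and override a few methods. The provider class `MyProvider`, its endpoint URL, and the `MY_PROVIDER_API_KEY` environment variable are hypothetical names used only for illustration.

# A hedged sketch of a new OpenAI-compatible LLM after this CL. Request
# construction, response parsing and logprob handling are inherited from
# `OpenAICompatible.request()` / `OpenAICompatible.result()`.
import os
from typing import Annotated, Any

import langfun.core as lf
from langfun.core.llms import openai_compatible


@lf.use_init_args(['model'])
class MyProvider(openai_compatible.OpenAICompatible):
  """A hypothetical OpenAI-compatible LLM (illustrative only)."""

  model: Annotated[str, 'The name of the model to use.'] = 'my-model-v1'

  # Only the endpoint needs to be configured; the base class handles the
  # OpenAI-style chat-completions payload.
  api_endpoint: str = 'https://api.example.com/v1/chat/completions'

  api_key: Annotated[
      str | None,
      'API key. Falls back to the MY_PROVIDER_API_KEY environment variable.',
  ] = None

  @property
  def headers(self) -> dict[str, Any]:
    # Same pattern as DeepSeek/Groq in the diff: extend the base headers
    # with an Authorization header resolved at request time.
    api_key = self.api_key or os.environ.get('MY_PROVIDER_API_KEY', None)
    if not api_key:
      raise ValueError(
          'Please specify `api_key` during `__init__` or set environment '
          'variable `MY_PROVIDER_API_KEY`.'
      )
    headers = super().headers
    headers.update({'Authorization': f'Bearer {api_key}'})
    return headers

  def _request_args(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
    # Drop request fields the provider does not support, as the Groq class
    # does above for `logprobs`/`top_logprobs`.
    args = super()._request_args(options)
    args.pop('top_logprobs', None)
    return args

With a subclass like this, sampling, JSON-schema response formats, and system-message handling all come from `OpenAICompatible` without further code, which is the point of centralizing the OpenAI-compatible logic in this CL.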