diff --git a/.gitignore b/.gitignore
index d0a5f00563..cfbf971bcd 100644
--- a/.gitignore
+++ b/.gitignore
@@ -14,3 +14,4 @@ Package.resolved
 *.pte
 *.ipynb_checkpoints*
 .idea
+venv
diff --git a/README.md b/README.md
index 238475840c..99b83d5f8c 100644
--- a/README.md
+++ b/README.md
@@ -49,6 +49,7 @@ A Distribution is where APIs and Providers are assembled together to provide a c
 | **API Provider Builder** | **Environments** | **Agents** | **Inference** | **Memory** | **Safety** | **Telemetry** |
 | :----: | :----: | :----: | :----: | :----: | :----: | :----: |
 | Meta Reference | Single Node | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
+| Cerebras | Hosted | :heavy_check_mark: | :heavy_check_mark: | | | |
 | Fireworks | Hosted | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | |
 | AWS Bedrock | Hosted | | :heavy_check_mark: | | :heavy_check_mark: | |
 | Together | Hosted | :heavy_check_mark: | :heavy_check_mark: | | :heavy_check_mark: | |
diff --git a/llama_stack/distribution/templates/local-cerebras-build.yaml b/llama_stack/distribution/templates/local-cerebras-build.yaml
new file mode 100644
index 0000000000..75150132bf
--- /dev/null
+++ b/llama_stack/distribution/templates/local-cerebras-build.yaml
@@ -0,0 +1,10 @@
+name: local-cerebras
+distribution_spec:
+  description: Like local, but use Cerebras for running LLM inference
+  providers:
+    inference: remote::cerebras
+    memory: meta-reference
+    safety: meta-reference
+    agents: meta-reference
+    telemetry: meta-reference
+image_type: conda
\ No newline at end of file
diff --git a/llama_stack/providers/adapters/inference/cerebras/__init__.py b/llama_stack/providers/adapters/inference/cerebras/__init__.py
new file mode 100644
index 0000000000..67b6bcb822
--- /dev/null
+++ b/llama_stack/providers/adapters/inference/cerebras/__init__.py
@@ -0,0 +1,20 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .cerebras import CerebrasInferenceAdapter
+from .config import CerebrasImplConfig
+
+
+async def get_adapter_impl(config: CerebrasImplConfig, _deps):
+    assert isinstance(
+        config, CerebrasImplConfig
+    ), f"Unexpected config type: {type(config)}"
+
+    impl = CerebrasInferenceAdapter(config)
+
+    await impl.initialize()
+
+    return impl
diff --git a/llama_stack/providers/adapters/inference/cerebras/cerebras.py b/llama_stack/providers/adapters/inference/cerebras/cerebras.py
new file mode 100644
index 0000000000..33ade81451
--- /dev/null
+++ b/llama_stack/providers/adapters/inference/cerebras/cerebras.py
@@ -0,0 +1,246 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import json
+
+from typing import AsyncGenerator
+
+from cerebras.cloud.sdk import Cerebras
+from cerebras.cloud.sdk.types.chat.completion_create_params import (
+    Message as CerebrasMessage,
+    MessageAssistantMessageRequestToolCallFunctionTyped,
+    MessageAssistantMessageRequestToolCallTyped,
+    MessageAssistantMessageRequestTyped,
+    MessageSystemMessageRequestTyped,
+    MessageToolMessageRequestTyped,
+    MessageUserMessageRequestTyped,
+    Tool,
+    ToolFunctionTyped,
+    ToolTyped,
+)
+
+from llama_models.llama3.api.chat_format import ChatFormat
+
+from llama_models.llama3.api.datatypes import Message
+from llama_models.llama3.api.tokenizer import Tokenizer
+
+from llama_stack.apis.inference import *  # noqa: F403
+
+from pydantic import BaseModel
+
+from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
+from llama_stack.providers.utils.inference.openai_compat import (
+    get_sampling_options,
+    process_chat_completion_response,
+    process_chat_completion_stream_response,
+)
+
+from .config import CerebrasImplConfig
+
+
+CEREBRAS_SUPPORTED_MODELS = {
+    "Llama3.1-8B-Instruct": "llama3.1-8b",
+    "Llama3.1-70B-Instruct": "llama3.1-70b",
+}
+
+
+class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
+    def __init__(self, config: CerebrasImplConfig) -> None:
+        ModelRegistryHelper.__init__(
+            self, stack_to_provider_models_map=CEREBRAS_SUPPORTED_MODELS
+        )
+        self.config = config
+        self.formatter = ChatFormat(Tokenizer.get_instance())
+
+        self.client = Cerebras(
+            base_url=self.config.base_url, api_key=self.config.api_key
+        )
+
+    async def initialize(self) -> None:
+        return
+
+    async def shutdown(self) -> None:
+        pass
+
+    def completion(
+        self,
+        model: str,
+        content: InterleavedTextMedia,
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> AsyncGenerator:
+        raise NotImplementedError()
+
+    def chat_completion(
+        self,
+        model: str,
+        messages: List[Message],
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        tools: Optional[List[ToolDefinition]] = None,
+        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
+        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]:
+        request = ChatCompletionRequest(
+            model=model,
+            messages=messages,
+            sampling_params=sampling_params,
+            tools=tools or [],
+            tool_choice=tool_choice,
+            tool_prompt_format=tool_prompt_format,
+            stream=stream,
+            logprobs=logprobs,
+        )
+
+        if stream:
+            return self._stream_chat_completion(request, self.client)
+        else:
+            return self._nonstream_chat_completion(request, self.client)
+
+    def _nonstream_chat_completion(
+        self, request: ChatCompletionRequest, client: Cerebras
+    ) -> ChatCompletionResponse:
+        params = self._get_params(request)
+
+        r = client.chat.completions.create(**params)
+        return process_chat_completion_response(request, r, self.formatter)
+
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest, client: Cerebras
+    ) -> AsyncGenerator:
+        params = self._get_params(request)
+
+        print(params)
+
+        async def _to_async_generator():
+            s = client.chat.completions.create(**params)
+            for chunk in s:
+                yield chunk
+
+        stream = _to_async_generator()
+        async for chunk in process_chat_completion_stream_response(
+            request, stream, self.formatter
+        ):
+            yield chunk
+
+    def _get_params(self, request: ChatCompletionRequest) -> dict:
+        if request.sampling_params and request.sampling_params.top_k:
+            raise ValueError("`top_k` not supported by Cerebras")
+
+        return {
+            "model": self.map_to_provider_model(request.model),
+            "messages": self._construct_cerebras_messages(request),
+            "tools": self._construct_cerebras_tools(request),
+            "tool_choice": request.tool_choice.value if request.tool_choice else None,
+            "stream": request.stream,
+            "logprobs": request.logprobs is not None,
+            "top_logprobs": request.logprobs,
+            **get_sampling_options(request),
+        }
+
+    @staticmethod
+    def _construct_cerebras_tools(request: ChatCompletionRequest) -> List[Tool]:
+        tools = []
+
+        for raw_tool in request.tools:
+            tools.append(
+                ToolTyped(
+                    function=ToolFunctionTyped(
+                        name=__class__._parse_tool_name(raw_tool.tool_name),
+                        description=raw_tool.description,
+                        parameters=(
+                            {
+                                k: v.model_dump() if isinstance(v, BaseModel) else v
+                                for k, v in raw_tool.parameters.items()
+                            }
+                            if raw_tool.parameters
+                            else None
+                        ),
+                    ),
+                    type="object",
+                )
+            )
+
+        return tools
+
+    @staticmethod
+    def _construct_cerebras_messages(
+        request: ChatCompletionRequest,
+    ) -> List[CerebrasMessage]:
+        messages = []
+
+        for raw_message in request.messages:
+            content = raw_message.content
+
+            assert isinstance(
+                content, str
+            ), f"Message content must be of type `str` but got `{type(content)}`"
+
+            if isinstance(raw_message, UserMessage):
+                messages.append(
+                    MessageUserMessageRequestTyped(
+                        content=content,
+                        role="user",
+                    )
+                )
+            elif isinstance(raw_message, SystemMessage):
+                messages.append(
+                    MessageSystemMessageRequestTyped(
+                        content=content,
+                        role="system",
+                    )
+                )
+            elif isinstance(raw_message, ToolResponseMessage):
+                messages.append(
+                    MessageToolMessageRequestTyped(
+                        role="tool",
+                        tool_call_id=raw_message.call_id,
+                        name=__class__._parse_tool_name(raw_message.tool_name),
+                        content=content,
+                    )
+                )
+            elif isinstance(raw_message, CompletionMessage):
+                messages.append(
+                    MessageAssistantMessageRequestTyped(
+                        role="assistant",
+                        content=content,
+                        tool_calls=__class__._construct_cerebras_tool_calls(
+                            raw_message.tool_calls
+                        ),
+                    )
+                )
+
+        return messages
+
+    @staticmethod
+    def _construct_cerebras_tool_calls(
+        raw_tool_calls: List[ToolCall],
+    ) -> List[MessageAssistantMessageRequestToolCallTyped]:
+        return [
+            MessageAssistantMessageRequestToolCallTyped(
+                id=tool_call.call_id,
+                type="function",
+                function=MessageAssistantMessageRequestToolCallFunctionTyped(
+                    arguments=json.dumps(tool_call.arguments),
+                    # Handle BuiltinTool by using its enum value as the name.
+                    name=__class__._parse_tool_name(tool_call.tool_name),
+                ),
+            )
+            for tool_call in raw_tool_calls
+        ]
+
+    @staticmethod
+    def _parse_tool_name(raw_tool_name: Union[str, Enum]) -> str:
+        return raw_tool_name if isinstance(raw_tool_name, str) else raw_tool_name.value
+
+    async def embeddings(
+        self,
+        model: str,
+        contents: List[InterleavedTextMedia],
+    ) -> EmbeddingsResponse:
+        raise NotImplementedError()
diff --git a/llama_stack/providers/adapters/inference/cerebras/config.py b/llama_stack/providers/adapters/inference/cerebras/config.py
new file mode 100644
index 0000000000..6cb79211f2
--- /dev/null
+++ b/llama_stack/providers/adapters/inference/cerebras/config.py
@@ -0,0 +1,23 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import os
+from typing import Optional
+
+from llama_models.schema_utils import json_schema_type
+from pydantic import BaseModel, Field
+
+
+@json_schema_type
+class CerebrasImplConfig(BaseModel):
+    base_url: str = Field(
+        default=os.environ.get("CEREBRAS_BASE_URL", "https://api.cerebras.ai"),
+        description="Base URL for the Cerebras API",
+    )
+    api_key: Optional[str] = Field(
+        default=None,
+        description="Cerebras API Key",
+    )
diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py
index 686fc273b8..fed57154ea 100644
--- a/llama_stack/providers/registry/inference.py
+++ b/llama_stack/providers/registry/inference.py
@@ -51,6 +51,17 @@ def available_providers() -> List[ProviderSpec]:
                 config_class="llama_stack.providers.adapters.inference.sample.SampleConfig",
             ),
         ),
+        remote_provider_spec(
+            api=Api.inference,
+            adapter=AdapterSpec(
+                adapter_type="cerebras",
+                pip_packages=[
+                    "cerebras_cloud_sdk",
+                ],
+                module="llama_stack.providers.adapters.inference.cerebras",
+                config_class="llama_stack.providers.adapters.inference.cerebras.CerebrasImplConfig",
+            ),
+        ),
         remote_provider_spec(
             api=Api.inference,
             adapter=AdapterSpec(
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index 118880b29b..606cf1efb5 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -4,6 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import json
+
 from typing import AsyncGenerator, Optional
 
 from llama_models.llama3.api.chat_format import ChatFormat
@@ -44,9 +46,50 @@ def get_sampling_options(request: ChatCompletionRequest) -> dict:
 
 def text_from_choice(choice) -> str:
     if hasattr(choice, "delta") and choice.delta:
-        return choice.delta.content
+        return choice.delta.content or ""
+
+    if hasattr(choice, "message") and choice.message:
+        return choice.message.content or ""
+
+    if hasattr(choice, "text"):
+        return choice.text or ""
+
+    return ""
+
+
+def tool_calls_from_choice(choice) -> List[ToolCall]:
+    tool_calls = []
 
-    return choice.text
+    if choice.message and choice.message.tool_calls:
+        for tool_call in choice.message.tool_calls:
+            tool_calls.append(
+                ToolCall(
+                    call_id=tool_call.id,
+                    tool_name=tool_call.function.name,
+                    arguments=json.loads(tool_call.function.arguments),
+                )
+            )
+
+    return tool_calls
+
+
+def tool_call_deltas_from_choice(choice) -> List[ToolCallDelta]:
+    tool_call_deltas = []
+
+    if choice.delta and choice.delta.tool_calls:
+        for tool_call in choice.delta.tool_calls:
+            tool_call_deltas.append(
+                ToolCallDelta(
+                    content=ToolCall(
+                        call_id=tool_call.id,
+                        tool_name=tool_call.function.name,
+                        arguments=json.loads(tool_call.function.arguments),
+                    ),
+                    parse_status=ToolCallParseStatus.in_progress,
+                )
+            )
+
+    return tool_call_deltas
 
 
 def process_chat_completion_response(
@@ -58,7 +101,7 @@ def process_chat_completion_response(
     stop_reason = None
 
     if reason := choice.finish_reason:
-        if reason in ["stop", "eos"]:
+        if reason in ["stop", "eos", "tool_calls"]:
             stop_reason = StopReason.end_of_turn
         elif reason == "eom":
             stop_reason = StopReason.end_of_message
@@ -71,6 +114,10 @@ def process_chat_completion_response(
     completion_message = formatter.decode_assistant_message_from_content(
         text_from_choice(choice), stop_reason
     )
+
+    # According to the OpenAI spec, tool calls are embedded as a field in the response object.
+    completion_message.tool_calls += tool_calls_from_choice(choice)
+
     return ChatCompletionResponse(
         completion_message=completion_message,
         logprobs=None,
     )
@@ -98,7 +145,12 @@ async def process_chat_completion_stream_response(
         finish_reason = choice.finish_reason
 
         if finish_reason:
-            if stop_reason is None and finish_reason in ["stop", "eos", "eos_token"]:
+            if stop_reason is None and finish_reason in [
+                "stop",
+                "eos",
+                "eos_token",
+                "tool_calls",
+            ]:
                 stop_reason = StopReason.end_of_turn
             elif stop_reason is None and finish_reason == "length":
                 stop_reason = StopReason.out_of_tokens
@@ -129,8 +181,18 @@ async def process_chat_completion_stream_response(
             text = ""
             continue
 
-        if ipython:
-            buffer += text
+        buffer += text
+        if tool_call_deltas := tool_call_deltas_from_choice(choice):
+            for delta in tool_call_deltas:
+                yield ChatCompletionResponseStreamChunk(
+                    event=ChatCompletionResponseEvent(
+                        event_type=ChatCompletionResponseEventType.progress,
+                        delta=delta,
+                        stop_reason=stop_reason,
+                    )
+                )
+
+        elif ipython:
             delta = ToolCallDelta(
                 content=text,
                 parse_status=ToolCallParseStatus.in_progress,
@@ -144,7 +206,6 @@ async def process_chat_completion_stream_response(
                 )
             )
         else:
-            buffer += text
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
                     event_type=ChatCompletionResponseEventType.progress,
@@ -155,6 +216,7 @@ async def process_chat_completion_stream_response(
 
     # parse tool calls and report errors
     message = formatter.decode_assistant_message_from_content(buffer, stop_reason)
+
    parsed_tool_calls = len(message.tool_calls) > 0
     if ipython and not parsed_tool_calls:
         yield ChatCompletionResponseStreamChunk(
diff --git a/tests/test_cerebras_inference.py b/tests/test_cerebras_inference.py
new file mode 100644
index 0000000000..7ee916a67c
--- /dev/null
+++ b/tests/test_cerebras_inference.py
@@ -0,0 +1,531 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import unittest
+from unittest import mock
+
+from cerebras.cloud.sdk.types.chat.completion_create_response import (
+    ChatChunkResponse,
+    ChatCompletion,
+)
+from llama_models.llama3.api.datatypes import (
+    BuiltinTool,
+    CompletionMessage,
+    StopReason,
+    ToolCall,
+    ToolChoice,
+    ToolDefinition,
+    ToolParamDefinition,
+    ToolResponseMessage,
+    UserMessage,
+)
+from llama_stack.apis.inference.inference import (
+    ChatCompletionRequest,
+    ChatCompletionResponseEventType,
+)
+from llama_stack.providers.adapters.inference.cerebras import get_adapter_impl
+from llama_stack.providers.adapters.inference.cerebras.config import CerebrasImplConfig
+
+
+class CerebrasInferenceTests(unittest.IsolatedAsyncioTestCase):
+
+    async def asyncSetUp(self):
+        cerebras_config = CerebrasImplConfig(api_key="foobar")
+
+        # setup Cerebras
+        self.api = await get_adapter_impl(cerebras_config, {})
+        await self.api.initialize()
+
+        self.custom_tool_defn = ToolDefinition(
+            tool_name="get_boiling_point",
+            description="Get the boiling point of a imaginary liquids (eg. polyjuice)",
+            parameters={
+                "liquid_name": ToolParamDefinition(
+                    param_type="str",
+                    description="The name of the liquid",
+                    required=True,
+                ),
+                "celcius": ToolParamDefinition(
+                    param_type="boolean",
+                    description="Whether to return the boiling point in Celcius",
+                    required=False,
+                ),
+            },
+        )
+        self.valid_supported_model = "Llama3.1-70B-Instruct"
+
+    async def asyncTearDown(self):
+        await self.api.shutdown()
+
+    async def test_text(self):
+        with mock.patch.object(
+            self.api.client.chat.completions, "create"
+        ) as mock_completion:
+            mock_completion.return_value = ChatCompletion(
+                **{
+                    "id": "chatcmpl-1f0a31de-a615-4ef5-a355-c72b840710d0",
+                    "choices": [
+                        {
+                            "finish_reason": "stop",
+                            "index": 0,
+                            "message": {
+                                "content": "The capital of France is Paris.",
+                                "role": "assistant",
+                            },
+                        }
+                    ],
+                    "created": 1729026294,
+                    "model": "llama3.1-70b",
+                    "system_fingerprint": "fp_97b75e13af",
+                    "object": "chat.completion",
+                    "usage": {
+                        "prompt_tokens": 17,
+                        "completion_tokens": 8,
+                        "total_tokens": 25,
+                    },
+                    "time_info": {
+                        "queue_time": 2.702e-05,
+                        "prompt_time": 0.0013021605714285715,
+                        "completion_time": 0.0039899714285714285,
+                        "total_time": 0.01815319061279297,
+                        "created": 1729026294,
+                    },
+                }
+            )
+
+            request = ChatCompletionRequest(
+                model=self.valid_supported_model,
+                messages=[
+                    UserMessage(
+                        content="What is the capital of France?",
+                    ),
+                ],
+                stream=False,
+            )
+            response = self.api.chat_completion(
+                request.model,
+                request.messages,
+                request.sampling_params,
+                request.tools,
+                request.tool_choice,
+                request.tool_prompt_format,
+                request.stream,
+                request.logprobs,
+            )
+
+            result = response.completion_message.content
+            self.assertTrue("Paris" in result, result)
+
+    async def test_text_streaming(self):
+        events = [
+            {
+                "id": "chatcmpl-f908bda3-eaa6-4148-ada8-689631b1e7c7",
+                "choices": [{"delta": {"role": "assistant"}, "index": 0}],
+                "created": 1729094696,
+                "model": "llama3.1-70b",
+                "system_fingerprint": "fp_97b75e13af",
+                "object": "chat.completion.chunk",
+            },
+            {
+                "id": "chatcmpl-f908bda3-eaa6-4148-ada8-689631b1e7c7",
+                "choices": [{"delta": {"content": "The"}, "index": 0}],
+                "created": 1729094696,
+                "model": "llama3.1-70b",
+                "system_fingerprint": "fp_97b75e13af",
+                "object": "chat.completion.chunk",
+            },
+            {
+                "id": "chatcmpl-f908bda3-eaa6-4148-ada8-689631b1e7c7",
+                "choices": [{"delta": {"content": " capital"}, "index": 0}],
+                "created": 1729094696,
+                "model": "llama3.1-70b",
+                "system_fingerprint": "fp_97b75e13af",
+                "object": "chat.completion.chunk",
+            },
+            {
+                "id": "chatcmpl-f908bda3-eaa6-4148-ada8-689631b1e7c7",
+                "choices": [{"delta": {"content": " of"}, "index": 0}],
+                "created": 1729094696,
+                "model": "llama3.1-70b",
+                "system_fingerprint": "fp_97b75e13af",
+                "object": "chat.completion.chunk",
+            },
+            {
+                "id": "chatcmpl-f908bda3-eaa6-4148-ada8-689631b1e7c7",
+                "choices": [{"delta": {"content": " France"}, "index": 0}],
+                "created": 1729094696,
+                "model": "llama3.1-70b",
+                "system_fingerprint": "fp_97b75e13af",
+                "object": "chat.completion.chunk",
+            },
+            {
+                "id": "chatcmpl-f908bda3-eaa6-4148-ada8-689631b1e7c7",
+                "choices": [{"delta": {"content": " is"}, "index": 0}],
+                "created": 1729094696,
+                "model": "llama3.1-70b",
+                "system_fingerprint": "fp_97b75e13af",
+                "object": "chat.completion.chunk",
+            },
+            {
+                "id": "chatcmpl-f908bda3-eaa6-4148-ada8-689631b1e7c7",
+                "choices": [{"delta": {"content": " Paris"}, "index": 0}],
+                "created": 1729094696,
+                "model": "llama3.1-70b",
+                "system_fingerprint": "fp_97b75e13af",
+                "object": "chat.completion.chunk",
+            },
+            {
+                "id": "chatcmpl-f908bda3-eaa6-4148-ada8-689631b1e7c7",
+                "choices": [{"delta": {"content": "."}, "index": 0}],
+                "created": 1729094696,
+                "model": "llama3.1-70b",
+                "system_fingerprint": "fp_97b75e13af",
+                "object": "chat.completion.chunk",
+            },
+            {
+                "id": "chatcmpl-f908bda3-eaa6-4148-ada8-689631b1e7c7",
+                "choices": [{"delta": {}, "finish_reason": "stop", "index": 0}],
+                "created": 1729094696,
+                "model": "llama3.1-70b",
+                "system_fingerprint": "fp_97b75e13af",
+                "object": "chat.completion.chunk",
+                "usage": {
+                    "prompt_tokens": 17,
+                    "completion_tokens": 8,
+                    "total_tokens": 25,
+                },
+                "time_info": {
+                    "queue_time": 2.568e-05,
+                    "prompt_time": 0.001300800857142857,
+                    "completion_time": 0.0039863531428571426,
+                    "total_time": 0.020612478256225586,
+                    "created": 1729094696,
+                },
+            },
+        ]
+
+        with mock.patch.object(
+            self.api.client.chat.completions, "create"
+        ) as mock_completion_stream:
+            mock_completion_stream.return_value = [
+                ChatChunkResponse(**event) for event in events
+            ]
+
+            request = ChatCompletionRequest(
+                model=self.valid_supported_model,
+                messages=[
+                    UserMessage(
+                        content="What is the capital of France?",
+                    ),
+                ],
+                stream=True,
+            )
+            iterator = self.api.chat_completion(
+                request.model,
+                request.messages,
+                request.sampling_params,
+                request.tools,
+                request.tool_choice,
+                request.tool_prompt_format,
+                request.stream,
+                request.logprobs,
+            )
+
+            events = []
+            async for chunk in iterator:
+                events.append(chunk.event)
+                # print(f"{chunk.event.event_type:<40} | {str(chunk.event.stop_reason):<26} | {chunk.event.delta} ")
+
+            self.assertEqual(
+                events[0].event_type, ChatCompletionResponseEventType.start
+            )
+            self.assertEqual(
+                events[-1].event_type, ChatCompletionResponseEventType.complete
+            )
+
+            response = ""
+            for e in events[1:-1]:
+                response += e.delta
+
+            self.assertTrue("Paris" in response, response)
+
+    async def test_custom_tool_call(self):
+        with mock.patch.object(
+            self.api.client.chat.completions, "create"
+        ) as mock_completion:
+            mock_completion.return_value = ChatCompletion(
+                **{
+                    "id": "chatcmpl-f673ee3b-2598-4cd4-8952-35f183564441",
+                    "choices": [
+                        {
+                            "finish_reason": "tool_calls",
+                            "index": 0,
+                            "message": {
+                                "tool_calls": [
+                                    {
+                                        "id": "253ac93a8",
+                                        "type": "function",
+                                        "function": {
+                                            "name": "get_boiling_point",
+                                            "arguments": '{"liquid_name": "polyjuice", "celcius": "True"}',
+                                        },
+                                    }
+                                ],
+                                "role": "assistant",
+                            },
+                        }
+                    ],
+                    "created": 1729095512,
+                    "model": "llama3.1-70b",
+                    "system_fingerprint": "fp_97b75e13af",
+                    "object": "chat.completion",
+                    "usage": {
+                        "prompt_tokens": 193,
+                        "completion_tokens": 14,
+                        "total_tokens": 207,
+                    },
+                    "time_info": {
+                        "queue_time": 2.651e-05,
+                        "prompt_time": 0.008480784000000002,
+                        "completion_time": 0.024744930000000002,
+                        "total_time": 0.03530120849609375,
+                        "created": 1729095512,
+                    },
+                }
+            )
+
+            request = ChatCompletionRequest(
+                tool_choice=ToolChoice.required,
+                model=self.valid_supported_model,
+                messages=[
+                    UserMessage(
+                        content="Use provided function to find the boiling point of polyjuice?",
+                    ),
+                ],
+                stream=False,
+                tools=[self.custom_tool_defn],
+            )
+            response = self.api.chat_completion(
+                request.model,
+                request.messages,
+                request.sampling_params,
+                request.tools,
+                request.tool_choice,
+                request.tool_prompt_format,
+                request.stream,
+                request.logprobs,
+            )
+
+            completion_message = response.completion_message
+
+            self.assertEqual(completion_message.content, "")
+
+            self.assertEqual(completion_message.stop_reason, StopReason.end_of_turn)
+
+            self.assertEqual(
+                len(completion_message.tool_calls), 1, completion_message.tool_calls
+            )
+            self.assertEqual(
+                completion_message.tool_calls[0].tool_name, "get_boiling_point"
+            )
+
+            args = completion_message.tool_calls[0].arguments
+            self.assertTrue(isinstance(args, dict))
+            self.assertTrue(args["liquid_name"], "polyjuice")
+
+    async def test_tool_call_streaming(self):
+        events = [
+            {
+                "id": "chatcmpl-1e573d82-bd76-496b-aa01-18faed024a1d",
+                "choices": [{"delta": {"role": "assistant"}, "index": 0}],
+                "created": 1729101621,
+                "model": "llama3.1-70b",
+                "system_fingerprint": "fp_97b75e13af",
+                "object": "chat.completion.chunk",
+            },
+            {
+                "id": "chatcmpl-1e573d82-bd76-496b-aa01-18faed024a1d",
+                "choices": [
+                    {
+                        "delta": {
+                            "tool_calls": [
+                                {
+                                    "index": 0,
+                                    "id": "df0a5e087",
+                                    "type": "function",
+                                    "function": {
+                                        "name": "brave_search",
+                                        "arguments": '{"query": "current US President"}',
+                                    },
+                                }
+                            ]
+                        },
+                        "index": 0,
+                    }
+                ],
+                "created": 1729101621,
+                "model": "llama3.1-70b",
+                "system_fingerprint": "fp_97b75e13af",
+                "object": "chat.completion.chunk",
+            },
+            {
+                "id": "chatcmpl-1e573d82-bd76-496b-aa01-18faed024a1d",
+                "choices": [{"delta": {}, "finish_reason": "tool_calls", "index": 0}],
+                "created": 1729101621,
+                "model": "llama3.1-70b",
+                "system_fingerprint": "fp_97b75e13af",
+                "object": "chat.completion.chunk",
+                "usage": {
+                    "prompt_tokens": 193,
+                    "completion_tokens": 14,
+                    "total_tokens": 207,
+                },
+                "time_info": {
+                    "queue_time": 0.00001926,
+                    "prompt_time": 0.008447145307692307,
+                    "completion_time": 0.024725255692307692,
+                    "total_time": 0.04976296424865723,
+                    "created": 1729101621,
+                },
+            },
+        ]
+        with mock.patch.object(
+            self.api.client.chat.completions, "create"
+        ) as mock_completion_stream:
+            mock_completion_stream.return_value = [
+                ChatChunkResponse(**event) for event in events
+            ]
+
+            request = ChatCompletionRequest(
+                model=self.valid_supported_model,
+                tool_choice=ToolChoice.required,
+                messages=[
+                    UserMessage(
+                        content="Who is the current US President?",
+                    ),
+                ],
+                stream=True,
+                tools=[ToolDefinition(tool_name=BuiltinTool.brave_search)],
+            )
+            iterator = self.api.chat_completion(
+                request.model,
+                request.messages,
+                request.sampling_params,
+                request.tools,
+                request.tool_choice,
+                request.tool_prompt_format,
+                request.stream,
+                request.logprobs,
+            )
+
+            events = []
+            async for chunk in iterator:
+                # print(f"{chunk.event.event_type:<40} | {str(chunk.event.stop_reason):<26} | {chunk.event.delta} ")
+                events.append(chunk.event)
+
+            self.assertEqual(
+                events[0].event_type, ChatCompletionResponseEventType.start
+            )
+            # last event is of type "complete"
+            self.assertEqual(
+                events[-1].event_type, ChatCompletionResponseEventType.complete
+            )
+            # last but one event should be eom with tool call
+            self.assertEqual(
+                events[-2].event_type, ChatCompletionResponseEventType.progress
+            )
+            self.assertEqual(events[-1].stop_reason, StopReason.end_of_turn)
+            self.assertEqual(
+                events[-2].delta.content.tool_name, BuiltinTool.brave_search
+            )
+
+    async def test_multi_turn_non_streaming(self):
+        with mock.patch.object(
+            self.api.client.chat.completions, "create"
+        ) as mock_completion:
+            mock_completion.return_value = ChatCompletion(
+                **{
+                    "id": "chatcmpl-1f0a31de-a615-4ef5-a355-c72b840710d0",
+                    "choices": [
+                        {
+                            "finish_reason": "stop",
+                            "index": 0,
+                            "message": {
+                                "content": "The 44th president of the United States was Barack Obama.",
+                                "role": "assistant",
+                            },
+                        }
+                    ],
+                    "created": 1729026294,
+                    "model": "llama3.1-70b",
+                    "system_fingerprint": "fp_97b75e13af",
+                    "object": "chat.completion",
+                    "usage": {
+                        "prompt_tokens": 17,
+                        "completion_tokens": 8,
+                        "total_tokens": 25,
+                    },
+                    "time_info": {
+                        "queue_time": 2.702e-05,
+                        "prompt_time": 0.0013021605714285715,
+                        "completion_time": 0.0039899714285714285,
+                        "total_time": 0.01815319061279297,
+                        "created": 1729026294,
+                    },
+                }
+            )
+
+            request = ChatCompletionRequest(
+                model=self.valid_supported_model,
+                messages=[
+                    UserMessage(
+                        content="Search the web and tell me who the "
+                        "44th president of the United States was",
+                    ),
+                    CompletionMessage(
+                        content="",
+                        stop_reason=StopReason.end_of_turn,
+                        tool_calls=[
+                            ToolCall(
+                                call_id="1",
+                                tool_name=BuiltinTool.brave_search,
+                                arguments={
+                                    "query": "44th president of the United States"
+                                },
+                            )
+                        ],
+                    ),
+                    ToolResponseMessage(
+                        call_id="1",
+                        tool_name=BuiltinTool.brave_search,
+                        content="Barack Obama",
+                    ),
+                ],
+                stream=False,
+                tools=[ToolDefinition(tool_name=BuiltinTool.brave_search)],
+            )
+            response = self.api.chat_completion(
+                request.model,
+                request.messages,
+                request.sampling_params,
+                request.tools,
+                request.tool_choice,
+                request.tool_prompt_format,
+                request.stream,
+                request.logprobs,
+            )
+
+            completion_message = response.completion_message
+
+            self.assertTrue(
+                completion_message.stop_reason
+                in {
+                    StopReason.end_of_turn,
+                    StopReason.end_of_message,
+                }
+            )
+
+            self.assertTrue("obama" in completion_message.content.lower())