Cerebras Inference Integration
henrytwo committed Oct 17, 2024
1 parent 209cd3d commit b77aac4
Showing 9 changed files with 912 additions and 7 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -14,3 +14,4 @@ Package.resolved
*.pte
*.ipynb_checkpoints*
.idea
venv
1 change: 1 addition & 0 deletions README.md
@@ -49,6 +49,7 @@ A Distribution is where APIs and Providers are assembled together to provide a c
| **API Provider Builder** | **Environments** | **Agents** | **Inference** | **Memory** | **Safety** | **Telemetry** |
| :----: | :----: | :----: | :----: | :----: | :----: | :----: |
| Meta Reference | Single Node | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| Cerebras | Hosted | :heavy_check_mark: | :heavy_check_mark: | | | |
| Fireworks | Hosted | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | |
| AWS Bedrock | Hosted | | :heavy_check_mark: | | :heavy_check_mark: | |
| Together | Hosted | :heavy_check_mark: | :heavy_check_mark: | | :heavy_check_mark: | |
10 changes: 10 additions & 0 deletions llama_stack/distribution/templates/local-cerebras-build.yaml
@@ -0,0 +1,10 @@
name: local-cerebras
distribution_spec:
  description: Like local, but use Cerebras for running LLM inference
  providers:
    inference: remote::cerebras
    memory: meta-reference
    safety: meta-reference
    agents: meta-reference
    telemetry: meta-reference
image_type: conda
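
As a hedged aside (not part of the commit): the template is plain YAML, so it can be sanity-checked with PyYAML; the path below assumes the repository root, and PyYAML itself is an assumed dev dependency.

# Sketch: load the build template and confirm the inference provider wiring.
import yaml

with open("llama_stack/distribution/templates/local-cerebras-build.yaml") as f:
    spec = yaml.safe_load(f)

assert spec["distribution_spec"]["providers"]["inference"] == "remote::cerebras"
print(spec["name"])  # -> local-cerebras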
20 changes: 20 additions & 0 deletions llama_stack/providers/adapters/inference/cerebras/__init__.py
@@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .cerebras import CerebrasInferenceAdapter
from .config import CerebrasImplConfig


async def get_adapter_impl(config: CerebrasImplConfig, _deps):
    assert isinstance(
        config, CerebrasImplConfig
    ), f"Unexpected config type: {type(config)}"

    impl = CerebrasInferenceAdapter(config)

    await impl.initialize()

    return impl
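
For illustration only, a hedged sketch of driving this factory directly; the API key value is hypothetical, and `_deps` is unused by this adapter, so an empty dict suffices.

# Sketch: construct the adapter via the factory above.
import asyncio

from llama_stack.providers.adapters.inference.cerebras import get_adapter_impl
from llama_stack.providers.adapters.inference.cerebras.config import CerebrasImplConfig


async def main() -> None:
    config = CerebrasImplConfig(api_key="csk-...")  # hypothetical key
    impl = await get_adapter_impl(config, {})
    await impl.shutdown()


asyncio.run(main())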
246 changes: 246 additions & 0 deletions llama_stack/providers/adapters/inference/cerebras/cerebras.py
@@ -0,0 +1,246 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
from enum import Enum  # needed explicitly by `_parse_tool_name` below
from typing import AsyncGenerator, List, Optional, Union

from cerebras.cloud.sdk import Cerebras
from cerebras.cloud.sdk.types.chat.completion_create_params import (
    Message as CerebrasMessage,
    MessageAssistantMessageRequestToolCallFunctionTyped,
    MessageAssistantMessageRequestToolCallTyped,
    MessageAssistantMessageRequestTyped,
    MessageSystemMessageRequestTyped,
    MessageToolMessageRequestTyped,
    MessageUserMessageRequestTyped,
    Tool,
    ToolFunctionTyped,
    ToolTyped,
)

from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message
from llama_models.llama3.api.tokenizer import Tokenizer

from llama_stack.apis.inference import *  # noqa: F403

from pydantic import BaseModel

from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
    get_sampling_options,
    process_chat_completion_response,
    process_chat_completion_stream_response,
)

from .config import CerebrasImplConfig


# Maps llama-stack model identifiers to the model names the Cerebras API expects.
CEREBRAS_SUPPORTED_MODELS = {
    "Llama3.1-8B-Instruct": "llama3.1-8b",
    "Llama3.1-70B-Instruct": "llama3.1-70b",
}


class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
    def __init__(self, config: CerebrasImplConfig) -> None:
        ModelRegistryHelper.__init__(
            self, stack_to_provider_models_map=CEREBRAS_SUPPORTED_MODELS
        )
        self.config = config
        self.formatter = ChatFormat(Tokenizer.get_instance())

        self.client = Cerebras(
            base_url=self.config.base_url, api_key=self.config.api_key
        )

    async def initialize(self) -> None:
        return

    async def shutdown(self) -> None:
        pass

    def completion(
        self,
        model: str,
        content: InterleavedTextMedia,
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        raise NotImplementedError()

    def chat_completion(
        self,
        model: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]:
        request = ChatCompletionRequest(
            model=model,
            messages=messages,
            sampling_params=sampling_params,
            tools=tools or [],
            tool_choice=tool_choice,
            tool_prompt_format=tool_prompt_format,
            stream=stream,
            logprobs=logprobs,
        )

        if stream:
            return self._stream_chat_completion(request, self.client)
        else:
            return self._nonstream_chat_completion(request, self.client)

    def _nonstream_chat_completion(
        self, request: ChatCompletionRequest, client: Cerebras
    ) -> ChatCompletionResponse:
        params = self._get_params(request)

        r = client.chat.completions.create(**params)
        return process_chat_completion_response(request, r, self.formatter)

    async def _stream_chat_completion(
        self, request: ChatCompletionRequest, client: Cerebras
    ) -> AsyncGenerator:
        params = self._get_params(request)

        # The Cerebras SDK client here is synchronous, so wrap its blocking
        # stream iterator in an async generator for the shared stream processor.
        async def _to_async_generator():
            s = client.chat.completions.create(**params)
            for chunk in s:
                yield chunk

        stream = _to_async_generator()
        async for chunk in process_chat_completion_stream_response(
            request, stream, self.formatter
        ):
            yield chunk

    def _get_params(self, request: ChatCompletionRequest) -> dict:
        if request.sampling_params and request.sampling_params.top_k:
            raise ValueError("`top_k` is not supported by Cerebras")

        return {
            "model": self.map_to_provider_model(request.model),
            "messages": self._construct_cerebras_messages(request),
            "tools": self._construct_cerebras_tools(request),
            "tool_choice": request.tool_choice.value if request.tool_choice else None,
            "stream": request.stream,
            "logprobs": request.logprobs is not None,
            # `top_logprobs` expects an int, not the LogProbConfig object itself.
            "top_logprobs": request.logprobs.top_k if request.logprobs else None,
            **get_sampling_options(request),
        }

    @staticmethod
    def _construct_cerebras_tools(request: ChatCompletionRequest) -> List[Tool]:
        tools = []

        for raw_tool in request.tools:
            tools.append(
                ToolTyped(
                    function=ToolFunctionTyped(
                        name=__class__._parse_tool_name(raw_tool.tool_name),
                        description=raw_tool.description,
                        parameters=(
                            {
                                k: v.model_dump() if isinstance(v, BaseModel) else v
                                for k, v in raw_tool.parameters.items()
                            }
                            if raw_tool.parameters
                            else None
                        ),
                    ),
                    # OpenAI-compatible tool specs are typed "function";
                    # "object" belongs to the JSON-schema `parameters` payload.
                    type="function",
                )
            )

        return tools

    @staticmethod
    def _construct_cerebras_messages(
        request: ChatCompletionRequest,
    ) -> List[CerebrasMessage]:
        messages = []

        for raw_message in request.messages:
            content = raw_message.content

            assert isinstance(
                content, str
            ), f"Message content must be of type `str` but got `{type(content)}`"

            if isinstance(raw_message, UserMessage):
                messages.append(
                    MessageUserMessageRequestTyped(
                        content=content,
                        role="user",
                    )
                )
            elif isinstance(raw_message, SystemMessage):
                messages.append(
                    MessageSystemMessageRequestTyped(
                        content=content,
                        role="system",
                    )
                )
            elif isinstance(raw_message, ToolResponseMessage):
                messages.append(
                    MessageToolMessageRequestTyped(
                        role="tool",
                        tool_call_id=raw_message.call_id,
                        name=__class__._parse_tool_name(raw_message.tool_name),
                        content=content,
                    )
                )
            elif isinstance(raw_message, CompletionMessage):
                messages.append(
                    MessageAssistantMessageRequestTyped(
                        role="assistant",
                        content=content,
                        tool_calls=__class__._construct_cerebras_tool_calls(
                            raw_message.tool_calls
                        ),
                    )
                )

        return messages

    @staticmethod
    def _construct_cerebras_tool_calls(
        raw_tool_calls: List[ToolCall],
    ) -> List[MessageAssistantMessageRequestToolCallTyped]:
        return [
            MessageAssistantMessageRequestToolCallTyped(
                id=tool_call.call_id,
                type="function",
                function=MessageAssistantMessageRequestToolCallFunctionTyped(
                    arguments=json.dumps(tool_call.arguments),
                    # Handle BuiltinTool by using the enum's value.
                    name=__class__._parse_tool_name(tool_call.tool_name),
                ),
            )
            for tool_call in raw_tool_calls
        ]

    @staticmethod
    def _parse_tool_name(raw_tool_name: Union[str, Enum]) -> str:
        return raw_tool_name if isinstance(raw_tool_name, str) else raw_tool_name.value

    async def embeddings(
        self,
        model: str,
        contents: List[InterleavedTextMedia],
    ) -> EmbeddingsResponse:
        raise NotImplementedError()
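
A hedged end-to-end sketch (not in the commit) of the non-streaming path as the adapter is written above: `chat_completion` with `stream=False` returns the processed response directly. Reading the API key from CEREBRAS_API_KEY is an assumption.

# Sketch: one-shot chat completion through the adapter.
import asyncio
import os

from llama_models.llama3.api.datatypes import UserMessage

from llama_stack.providers.adapters.inference.cerebras.cerebras import (
    CerebrasInferenceAdapter,
)
from llama_stack.providers.adapters.inference.cerebras.config import CerebrasImplConfig


async def main() -> None:
    config = CerebrasImplConfig(api_key=os.environ["CEREBRAS_API_KEY"])
    adapter = CerebrasInferenceAdapter(config)
    await adapter.initialize()

    # Stack-side model name; `_get_params` maps it to "llama3.1-8b".
    response = adapter.chat_completion(
        model="Llama3.1-8B-Instruct",
        messages=[UserMessage(content="Say hello in one sentence.")],
        stream=False,
    )
    print(response)


asyncio.run(main())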
23 changes: 23 additions & 0 deletions llama_stack/providers/adapters/inference/cerebras/config.py
@@ -0,0 +1,23 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os
from typing import Optional

from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field


@json_schema_type
class CerebrasImplConfig(BaseModel):
    base_url: str = Field(
        default=os.environ.get("CEREBRAS_BASE_URL", "https://api.cerebras.ai"),
        description="Base URL for the Cerebras API",
    )
    api_key: Optional[str] = Field(
        default=None,
        description="Cerebras API Key",
    )
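
One subtlety worth a hedged note: the `base_url` default is resolved from CEREBRAS_BASE_URL when the class body executes, i.e. at import time. A minimal sketch, assuming the variable is exported before the module is imported:

# Sketch: the default is captured at import time, not per instance.
import os

os.environ.setdefault("CEREBRAS_BASE_URL", "https://api.cerebras.ai")

from llama_stack.providers.adapters.inference.cerebras.config import CerebrasImplConfig

config = CerebrasImplConfig(api_key="csk-...")  # hypothetical key
print(config.base_url)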
11 changes: 11 additions & 0 deletions llama_stack/providers/registry/inference.py
@@ -51,6 +51,17 @@ def available_providers() -> List[ProviderSpec]:
            config_class="llama_stack.providers.adapters.inference.sample.SampleConfig",
        ),
    ),
    remote_provider_spec(
        api=Api.inference,
        adapter=AdapterSpec(
            adapter_type="cerebras",
            pip_packages=[
                "cerebras_cloud_sdk",
            ],
            module="llama_stack.providers.adapters.inference.cerebras",
            config_class="llama_stack.providers.adapters.inference.cerebras.CerebrasImplConfig",
        ),
    ),
    remote_provider_spec(
        api=Api.inference,
        adapter=AdapterSpec(
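
For reference, a minimal hedged sketch of the underlying SDK call the adapter ultimately issues, assuming `cerebras_cloud_sdk` (listed in `pip_packages` above) is installed and CEREBRAS_API_KEY is set:

# Sketch: the raw Cerebras call that CerebrasInferenceAdapter wraps.
import os

from cerebras.cloud.sdk import Cerebras

client = Cerebras(api_key=os.environ["CEREBRAS_API_KEY"])

response = client.chat.completions.create(
    model="llama3.1-8b",  # provider-side name for Llama3.1-8B-Instruct
    messages=[{"role": "user", "content": "Hello!"}],
    stream=False,
)
print(response)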