diff --git a/distributions/dependencies.json b/distributions/dependencies.json index bd363ea40c..225fb45ff3 100644 --- a/distributions/dependencies.json +++ b/distributions/dependencies.json @@ -405,5 +405,37 @@ "uvicorn", "sentence-transformers --no-deps", "torch --index-url https://download.pytorch.org/whl/cpu" - ] + ], + "centml": [ + "aiosqlite", + "autoevals", + "blobfile", + "chardet", + "chromadb-client", + "datasets", + "faiss-cpu", + "fastapi", + "fire", + "httpx", + "matplotlib", + "nltk", + "numpy", + "openai", + "opentelemetry-exporter-otlp-proto-http", + "opentelemetry-sdk", + "pandas", + "pillow", + "psycopg2-binary", + "pypdf", + "redis", + "requests", + "scikit-learn", + "scipy", + "sentencepiece", + "tqdm", + "transformers", + "uvicorn", + "sentence-transformers --no-deps", + "torch --index-url https://download.pytorch.org/whl/cpu" + ] } diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index 55924a1e93..e5dd209c10 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -195,4 +195,16 @@ def available_providers() -> List[ProviderSpec]: config_class="llama_stack.providers.remote.inference.nvidia.NVIDIAConfig", ), ), + remote_provider_spec( + api=Api.inference, + adapter=AdapterSpec( + adapter_type="centml", + pip_packages=[ + "openai", + ], + module="llama_stack.providers.remote.inference.centml", + config_class="llama_stack.providers.remote.inference.centml.CentMLImplConfig", + provider_data_validator="llama_stack.providers.remote.inference.centml.CentMLProviderDataValidator", + ), + ), ] diff --git a/llama_stack/providers/remote/inference/centml/__init__.py b/llama_stack/providers/remote/inference/centml/__init__.py new file mode 100644 index 0000000000..4bfc27b9ec --- /dev/null +++ b/llama_stack/providers/remote/inference/centml/__init__.py @@ -0,0 +1,33 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from pydantic import BaseModel + +from .config import CentMLImplConfig + + +class CentMLProviderDataValidator(BaseModel): + centml_api_key: str + + +async def get_adapter_impl(config: CentMLImplConfig, _deps): + """ + Factory function to construct and initialize the CentML adapter. + + :param config: Instance of CentMLImplConfig, containing `url`, `api_key`, etc. + :param _deps: Additional dependencies provided by llama-stack (unused here). + """ + from .centml import CentMLInferenceAdapter + + # Ensure the provided config is indeed a CentMLImplConfig + assert isinstance(config, CentMLImplConfig), ( + f"Unexpected config type: {type(config)}" + ) + + # Instantiate and initialize the adapter + adapter = CentMLInferenceAdapter(config) + await adapter.initialize() + return adapter diff --git a/llama_stack/providers/remote/inference/centml/centml.py b/llama_stack/providers/remote/inference/centml/centml.py new file mode 100644 index 0000000000..aacc738045 --- /dev/null +++ b/llama_stack/providers/remote/inference/centml/centml.py @@ -0,0 +1,319 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
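+
+# Rough usage sketch (illustrative only): once a stack with this provider is
+# running, requests reach the adapter through the standard llama-stack
+# inference API, e.g. via the llama-stack client:
+#
+#     client = LlamaStackClient(base_url="http://localhost:5001")
+#     response = client.inference.chat_completion(
+#         model_id="meta-llama/Llama-3.3-70B-Instruct",
+#         messages=[{"role": "user", "content": "Hello!"}],
+#     )
+#
+# The adapter itself only translates such requests onto CentML's
+# OpenAI-compatible endpoint; routing and model resolution are handled by
+# llama-stack.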
+ +from typing import AsyncGenerator, List, Optional, Union + +from openai import OpenAI + +from llama_models.datatypes import CoreModelId +from llama_models.llama3.api.chat_format import ChatFormat +from llama_models.llama3.api.tokenizer import Tokenizer + +from llama_stack.apis.common.content_types import InterleavedContent +from llama_stack.apis.inference import ( + ChatCompletionRequest, + ChatCompletionResponse, + CompletionRequest, + EmbeddingsResponse, + Inference, + LogProbConfig, + Message, + ResponseFormat, + ResponseFormatType, + SamplingParams, + ToolChoice, + ToolDefinition, + ToolPromptFormat, +) +from llama_stack.distribution.request_headers import NeedsRequestProviderData +from llama_stack.providers.utils.inference.model_registry import ( + build_model_alias, + ModelRegistryHelper, +) +from llama_stack.providers.utils.inference.openai_compat import ( + convert_message_to_openai_dict, + get_sampling_options, + process_chat_completion_response, + process_chat_completion_stream_response, + process_completion_response, + process_completion_stream_response, +) +from llama_stack.providers.utils.inference.prompt_adapter import ( + completion_request_to_prompt, + content_has_media, + interleaved_content_as_str, + request_has_media, +) + +from .config import CentMLImplConfig + +# Example model aliases that map from CentML’s +# published model identifiers to llama-stack's `CoreModelId`. +MODEL_ALIASES = [ + build_model_alias( + "meta-llama/Llama-3.3-70B-Instruct", + CoreModelId.llama3_3_70b_instruct.value, + ), + build_model_alias( + "meta-llama/Llama-3.1-405B-Instruct-FP8", + CoreModelId.llama3_1_405b_instruct.value, + ), +] + + +class CentMLInferenceAdapter( + ModelRegistryHelper, Inference, NeedsRequestProviderData +): + """ + Adapter to use CentML's serverless inference endpoints, + which adhere to the OpenAI chat/completions API spec, + inside llama-stack. + """ + + def __init__(self, config: CentMLImplConfig) -> None: + super().__init__(MODEL_ALIASES) + self.config = config + self.formatter = ChatFormat(Tokenizer.get_instance()) + + async def initialize(self) -> None: + pass + + async def shutdown(self) -> None: + pass + + def _get_api_key(self) -> str: + """ + Obtain the CentML API key either from the adapter config + or from the dynamic provider data in request headers. + """ + if self.config.api_key is not None: + return self.config.api_key.get_secret_value() + else: + provider_data = self.get_request_provider_data() + if provider_data is None or not provider_data.centml_api_key: + raise ValueError( + 'Pass CentML API Key in the header X-LlamaStack-ProviderData as { "centml_api_key": "" }' + ) + return provider_data.centml_api_key + + def _get_client(self) -> OpenAI: + """ + Creates an OpenAI-compatible client pointing to CentML's base URL, + using the user's CentML API key. + """ + api_key = self._get_api_key() + return OpenAI(api_key=api_key, base_url=self.config.url) + + # + # COMPLETION (non-chat) + # + + async def completion( + self, + model_id: str, + content: InterleavedContent, + sampling_params: Optional[SamplingParams] = SamplingParams(), + response_format: Optional[ResponseFormat] = None, + stream: Optional[bool] = False, + logprobs: Optional[LogProbConfig] = None, + ) -> AsyncGenerator: + """ + For "completion" style requests (non-chat). 
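+
+        When ``stream`` is False this returns a single completion response;
+        when ``stream`` is True it returns an async generator of response
+        chunks read from CentML's OpenAI-compatible completions endpoint.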
+ """ + model = await self.model_store.get_model(model_id) + request = CompletionRequest( + model=model.provider_resource_id, + content=content, + sampling_params=sampling_params, + response_format=response_format, + stream=stream, + logprobs=logprobs, + ) + if stream: + return self._stream_completion(request) + else: + return await self._nonstream_completion(request) + + async def _nonstream_completion( + self, request: CompletionRequest + ) -> ChatCompletionResponse: + params = await self._get_params(request) + # Using the older "completions" route for non-chat + response = self._get_client().completions.create(**params) + return process_completion_response(response, self.formatter) + + async def _stream_completion( + self, request: CompletionRequest + ) -> AsyncGenerator: + params = await self._get_params(request) + + async def _to_async_generator(): + stream = self._get_client().completions.create(**params) + for chunk in stream: + yield chunk + + stream = _to_async_generator() + async for chunk in process_completion_stream_response( + stream, self.formatter + ): + yield chunk + + # + # CHAT COMPLETION + # + + async def chat_completion( + self, + model_id: str, + messages: List[Message], + sampling_params: Optional[SamplingParams] = SamplingParams(), + tools: Optional[List[ToolDefinition]] = None, + tool_choice: Optional[ToolChoice] = ToolChoice.auto, + tool_prompt_format: Optional[ToolPromptFormat] = None, + response_format: Optional[ResponseFormat] = None, + stream: Optional[bool] = False, + logprobs: Optional[LogProbConfig] = None, + ) -> AsyncGenerator: + """ + For "chat completion" style requests. + """ + model = await self.model_store.get_model(model_id) + request = ChatCompletionRequest( + model=model.provider_resource_id, + messages=messages, + sampling_params=sampling_params, + tools=tools or [], + tool_choice=tool_choice, + tool_prompt_format=tool_prompt_format, + response_format=response_format, + stream=stream, + logprobs=logprobs, + ) + if stream: + return self._stream_chat_completion(request) + else: + return await self._nonstream_chat_completion(request) + + async def _nonstream_chat_completion( + self, request: ChatCompletionRequest + ) -> ChatCompletionResponse: + params = await self._get_params(request) + + # For chat requests, if "messages" is in params -> .chat.completions + if "messages" in params: + response = self._get_client().chat.completions.create(**params) + else: + # fallback if we ended up only with "prompt" + response = self._get_client().completions.create(**params) + + return process_chat_completion_response(response, self.formatter) + + async def _stream_chat_completion( + self, request: ChatCompletionRequest + ) -> AsyncGenerator: + params = await self._get_params(request) + + async def _to_async_generator(): + if "messages" in params: + stream = self._get_client().chat.completions.create(**params) + else: + stream = self._get_client().completions.create(**params) + for chunk in stream: + yield chunk + + stream = _to_async_generator() + async for chunk in process_chat_completion_stream_response( + stream, self.formatter + ): + yield chunk + + # + # HELPER METHODS + # + + async def _get_params( + self, request: Union[ChatCompletionRequest, CompletionRequest] + ) -> dict: + """ + Build the 'params' dict that the OpenAI (CentML) client expects. + For chat requests, we always prefer "messages" so that it calls + the chat endpoint properly. 
+ """ + input_dict = {} + media_present = request_has_media(request) + + if isinstance(request, ChatCompletionRequest): + # For chat requests, always build "messages" from the user messages + input_dict["messages"] = [ + await convert_message_to_openai_dict(m) + for m in request.messages + ] + + else: + # Non-chat (CompletionRequest) + assert not media_present, ( + "CentML does not support media for completions" + ) + input_dict["prompt"] = await completion_request_to_prompt( + request, self.formatter + ) + + return { + "model": request.model, + **input_dict, + "stream": request.stream, + **self._build_options( + request.sampling_params, request.response_format + ), + } + + def _build_options( + self, + sampling_params: Optional[SamplingParams], + fmt: Optional[ResponseFormat], + ) -> dict: + """ + Build temperature, max_tokens, top_p, etc., plus any response format data. + """ + options = get_sampling_options(sampling_params) + options.setdefault("max_tokens", 512) + + if fmt: + if fmt.type == ResponseFormatType.json_schema.value: + options["response_format"] = { + "type": "json_object", + "schema": fmt.json_schema, + } + elif fmt.type == ResponseFormatType.grammar.value: + raise NotImplementedError( + "Grammar response format not supported yet" + ) + else: + raise ValueError(f"Unknown response format {fmt.type}") + + return options + + # + # EMBEDDINGS + # + + async def embeddings( + self, + model_id: str, + contents: List[InterleavedContent], + ) -> EmbeddingsResponse: + model = await self.model_store.get_model(model_id) + # CentML does not support media + assert all(not content_has_media(c) for c in contents), ( + "CentML does not support media for embeddings" + ) + + resp = self._get_client().embeddings.create( + model=model.provider_resource_id, + input=[interleaved_content_as_str(c) for c in contents], + ) + embeddings = [item.embedding for item in resp.data] + return EmbeddingsResponse(embeddings=embeddings) diff --git a/llama_stack/providers/remote/inference/centml/config.py b/llama_stack/providers/remote/inference/centml/config.py new file mode 100644 index 0000000000..bc9711bdbd --- /dev/null +++ b/llama_stack/providers/remote/inference/centml/config.py @@ -0,0 +1,29 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from typing import Any, Dict, Optional + +from llama_models.schema_utils import json_schema_type +from pydantic import BaseModel, Field, SecretStr + + +@json_schema_type +class CentMLImplConfig(BaseModel): + url: str = Field( + default="https://api.centml.com/openai/v1", + description="The CentML API server URL", + ) + api_key: Optional[SecretStr] = Field( + default=None, + description="The CentML API Key", + ) + + @classmethod + def sample_run_config(cls, **kwargs) -> Dict[str, Any]: + return { + "url": "https://api.centml.com/openai/v1", + "api_key": "${env.CENTML_API_KEY}", + } diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py index b6653b65d7..afb63ac789 100644 --- a/llama_stack/providers/tests/inference/fixtures.py +++ b/llama_stack/providers/tests/inference/fixtures.py @@ -18,6 +18,7 @@ from llama_stack.providers.inline.inference.vllm import VLLMConfig from llama_stack.providers.remote.inference.bedrock import BedrockConfig +from llama_stack.providers.remote.inference.centml import CentMLImplConfig from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig from llama_stack.providers.remote.inference.groq import GroqConfig @@ -231,6 +232,25 @@ def inference_tgi() -> ProviderFixture: ) +@pytest.fixture(scope="session") +def inference_centml() -> ProviderFixture: + api_key = os.getenv("CENTML_API_KEY") + if not api_key: + pytest.skip("Missing CENTML_API_KEY in environment; skipping CentML tests") + + return ProviderFixture( + providers=[ + Provider( + provider_id="centml", + provider_type="remote::centml", + config=CentMLImplConfig(api_key=api_key).model_dump(), + ) + ], + provider_data=dict(centml_api_key=api_key), + ) + + + @pytest.fixture(scope="session") def inference_sentence_transformers() -> ProviderFixture: return ProviderFixture( @@ -282,6 +302,7 @@ def model_id(inference_model) -> str: "cerebras", "nvidia", "tgi", + "centml", ] diff --git a/llama_stack/templates/centml/__init__.py b/llama_stack/templates/centml/__init__.py new file mode 100644 index 0000000000..b56599bfc9 --- /dev/null +++ b/llama_stack/templates/centml/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
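+
+# Re-export the template factory so it is importable as
+# llama_stack.templates.centml.get_distribution_template, following the same
+# pattern as the other bundled templates.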
+ +from .centml import get_distribution_template # noqa: F401 diff --git a/llama_stack/templates/centml/build.yaml b/llama_stack/templates/centml/build.yaml new file mode 100644 index 0000000000..489b9f8fdf --- /dev/null +++ b/llama_stack/templates/centml/build.yaml @@ -0,0 +1,32 @@ +version: '2' +name: centml +distribution_spec: + description: Use CentML for running LLM inference + providers: + inference: + - remote::centml + memory: + - inline::faiss + - remote::chromadb + - remote::pgvector + safety: + - inline::llama-guard + agents: + - inline::meta-reference + telemetry: + - inline::meta-reference + eval: + - inline::meta-reference + datasetio: + - remote::huggingface + - inline::localfs + scoring: + - inline::basic + - inline::llm-as-judge + - inline::braintrust + tool_runtime: + - remote::brave-search + - remote::tavily-search + - inline::code-interpreter + - inline::memory-runtime +image_type: conda diff --git a/llama_stack/templates/centml/centml.py b/llama_stack/templates/centml/centml.py new file mode 100644 index 0000000000..0f8c13b7af --- /dev/null +++ b/llama_stack/templates/centml/centml.py @@ -0,0 +1,121 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from pathlib import Path + +from llama_models.sku_list import all_registered_models +from llama_stack.apis.models.models import ModelType +from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput +from llama_stack.providers.inline.inference.sentence_transformers import ( + SentenceTransformersInferenceConfig, +) +from llama_stack.providers.inline.memory.faiss.config import FaissImplConfig +from llama_stack.providers.remote.inference.centml.config import ( + CentMLImplConfig, +) + +# If your CentML adapter has a MODEL_ALIASES constant with known model mappings: +from llama_stack.providers.remote.inference.centml.centml import MODEL_ALIASES + +from llama_stack.templates.template import ( + DistributionTemplate, + RunConfigSettings, +) + + +def get_distribution_template() -> DistributionTemplate: + """ + Returns a distribution template for running Llama Stack with CentML inference. 
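+
+    The template wires CentML as the inference provider (with
+    sentence-transformers for embeddings), uses Faiss for memory in the
+    generated run config, and exposes the models from MODEL_ALIASES as the
+    default model list.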
+ """ + providers = { + "inference": ["remote::centml"], + "memory": ["inline::faiss", "remote::chromadb", "remote::pgvector"], + "safety": ["inline::llama-guard"], + "agents": ["inline::meta-reference"], + "telemetry": ["inline::meta-reference"], + "eval": ["inline::meta-reference"], + "datasetio": ["remote::huggingface", "inline::localfs"], + "scoring": [ + "inline::basic", + "inline::llm-as-judge", + "inline::braintrust", + ], + } + name = "centml" + + # Primary inference provider: CentML + inference_provider = Provider( + provider_id="centml", + provider_type="remote::centml", + config=CentMLImplConfig.sample_run_config(), + ) + + # Memory provider: Faiss + memory_provider = Provider( + provider_id="faiss", + provider_type="inline::faiss", + config=FaissImplConfig.sample_run_config(f"distributions/{name}"), + ) + + # Embedding provider: SentenceTransformers + embedding_provider = Provider( + provider_id="sentence-transformers", + provider_type="inline::sentence-transformers", + config=SentenceTransformersInferenceConfig.sample_run_config(), + ) + + # Map Llama Models to provider IDs if needed + core_model_to_hf_repo = { + m.descriptor(): m.huggingface_repo for m in all_registered_models() + } + default_models = [ + ModelInput( + model_id=core_model_to_hf_repo[m.llama_model], + provider_model_id=m.provider_model_id, + provider_id="centml", + ) + for m in MODEL_ALIASES + ] + + # Example embedding model + embedding_model = ModelInput( + model_id="all-MiniLM-L6-v2", + provider_id="sentence-transformers", + model_type=ModelType.embedding, + metadata={"embedding_dimension": 384}, + ) + + return DistributionTemplate( + name=name, + distro_type="self_hosted", + description="Use CentML for running LLM inference", + docker_image=None, + template_path=Path(__file__).parent / "doc_template.md", + providers=providers, + default_models=default_models, + run_configs={ + "run.yaml": RunConfigSettings( + provider_overrides={ + "inference": [inference_provider, embedding_provider], + "memory": [memory_provider], + }, + default_models=default_models + [embedding_model], + default_shields=[ + ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B") + ], + ), + }, + run_config_env_vars={ + "LLAMASTACK_PORT": ( + "5001", + "Port for the Llama Stack distribution server", + ), + "CENTML_API_KEY": ( + "", + "CentML API Key", + ), + }, + ) diff --git a/llama_stack/templates/centml/doc_template.md b/llama_stack/templates/centml/doc_template.md new file mode 100644 index 0000000000..fded0b4e12 --- /dev/null +++ b/llama_stack/templates/centml/doc_template.md @@ -0,0 +1,66 @@ +--- +orphan: true +--- +# CentML Distribution + +```{toctree} +:maxdepth: 2 +:hidden: + +self +``` + +The `llamastack/distribution-{{ name }}` distribution consists of the following provider configurations. + +{{ providers_table }} + +{% if run_config_env_vars %} +### Environment Variables + +The following environment variables can be configured: + +{% for var, (default_value, description) in run_config_env_vars.items() %} +- `{{ var }}`: {{ description }} (default: `{{ default_value }}`) +{% endfor %} +{% endif %} + +{% if default_models %} +### Models + +The following models are available by default: + +{% for model in default_models %} +- `{{ model.model_id }}` +{% endfor %} +{% endif %} + +### Prerequisite: API Keys + +Make sure you have a valid **CentML API Key**. Sign up or access your credentials at [CentML.com](https://centml.com/). 
+ +## Running Llama Stack with CentML + +You can do this via Conda (build code) or Docker which has a pre-built image. + +### Via Docker + +This method allows you to get started quickly without having to build the distribution code. + +```bash +LLAMA_STACK_PORT=5001 +docker run \ + -it \ + -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ + llamastack/distribution-{{ name }} \ + --port $LLAMA_STACK_PORT \ + --env CENTML_API_KEY=$CENTML_API_KEY +``` + +### Via Conda + +```bash +llama stack build --template {{ name }} --image-type conda +llama stack run ./run.yaml \ + --port $LLAMA_STACK_PORT \ + --env CENTML_API_KEY=$CENTML_API_KEY +``` \ No newline at end of file diff --git a/llama_stack/templates/centml/run.yaml b/llama_stack/templates/centml/run.yaml new file mode 100644 index 0000000000..414dd9065f --- /dev/null +++ b/llama_stack/templates/centml/run.yaml @@ -0,0 +1,129 @@ +version: '2' +image_name: centml +conda_env: centml +apis: + - agents + - datasetio + - eval + - inference + - memory + - safety + - scoring + - telemetry + - tool_runtime +providers: + inference: + - provider_id: centml + provider_type: remote::centml + config: + url: https://api.centml.com/openai/v1 + api_key: "${env.CENTML_API_KEY}" + - provider_id: sentence-transformers + provider_type: inline::sentence-transformers + config: {} + + memory: + - provider_id: faiss + provider_type: inline::faiss + config: + kvstore: + type: sqlite + namespace: null + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/centml}/faiss_store.db + + safety: + - provider_id: llama-guard + provider_type: inline::llama-guard + config: {} + + agents: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + persistence_store: + type: sqlite + namespace: null + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/centml}/agents_store.db + + telemetry: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + service_name: ${env.OTEL_SERVICE_NAME:llama-stack} + sinks: ${env.TELEMETRY_SINKS:console,sqlite} + sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/centml}/trace_store.db + + eval: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: {} + + datasetio: + - provider_id: huggingface + provider_type: remote::huggingface + config: {} + - provider_id: localfs + provider_type: inline::localfs + config: {} + + scoring: + - provider_id: basic + provider_type: inline::basic + config: {} + - provider_id: llm-as-judge + provider_type: inline::llm-as-judge + config: {} + - provider_id: braintrust + provider_type: inline::braintrust + config: + openai_api_key: ${env.OPENAI_API_KEY:} + + tool_runtime: + - provider_id: brave-search + provider_type: remote::brave-search + config: + api_key: ${env.BRAVE_SEARCH_API_KEY:} + max_results: 3 + - provider_id: tavily-search + provider_type: remote::tavily-search + config: + api_key: ${env.TAVILY_SEARCH_API_KEY:} + max_results: 3 + - provider_id: code-interpreter + provider_type: inline::code-interpreter + config: {} + - provider_id: memory-runtime + provider_type: inline::memory-runtime + config: {} + +metadata_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/centml}/registry.db + +models: +- metadata: {} + model_id: meta-llama/Llama-3.3-70B-Instruct + provider_id: centml + provider_model_id: meta-llama/Llama-3.3-70B-Instruct + model_type: llm + +- metadata: {} + model_id: meta-llama/Llama-3.1-405B-Instruct-FP8 + provider_id: centml + provider_model_id: meta-llama/Llama-3.1-405B-Instruct-FP8 + 
model_type: llm + +shields: + - shield_id: meta-llama/Llama-Guard-3-8B + +memory_banks: [] +datasets: [] +scoring_fns: [] +eval_tasks: [] +tool_groups: + - toolgroup_id: builtin::websearch + provider_id: tavily-search + - toolgroup_id: builtin::memory + provider_id: memory-runtime + - toolgroup_id: builtin::code_interpreter + provider_id: code-interpreter diff --git a/pyproject.toml b/pyproject.toml index 638dd9c54f..6e14777f9b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,5 @@ [build-system] requires = ["setuptools>=61.0"] build-backend = "setuptools.build_meta" +[tool.ruff] +line-length = 80