From 09fc3800b9238c45e0e60b9cc65004643a2df970 Mon Sep 17 00:00:00 2001
From: Aidan Do
Date: Sat, 18 Jan 2025 19:35:17 +1100
Subject: [PATCH] Add vllm completions

---
 .../providers/remote/inference/vllm/vllm.py | 36 ++++++++++++++++++-
 1 file changed, 35 insertions(+), 1 deletion(-)

diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index 1dbb4ecfa3..0cf16f0133 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -41,6 +41,8 @@
     get_sampling_options,
     process_chat_completion_response,
     process_chat_completion_stream_response,
+    process_completion_response,
+    process_completion_stream_response,
 )
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
@@ -92,7 +94,19 @@ async def completion(
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
-        raise NotImplementedError("Completion not implemented for vLLM")
+        model = await self.model_store.get_model(model_id)
+        request = CompletionRequest(
+            model=model.provider_resource_id,
+            content=content,
+            sampling_params=sampling_params,
+            response_format=response_format,
+            stream=stream,
+            logprobs=logprobs,
+        )
+        if stream:
+            return self._stream_completion(request)
+        else:
+            return await self._nonstream_completion(request)
 
     async def chat_completion(
         self,
@@ -154,6 +168,26 @@ async def _to_async_generator():
         ):
             yield chunk
 
+    async def _nonstream_completion(
+        self, request: CompletionRequest
+    ) -> CompletionResponse:
+        params = await self._get_params(request)
+        r = self.client.completions.create(**params)
+        return process_completion_response(r, self.formatter)
+
+    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
+        params = await self._get_params(request)
+
+        # Wrapper for async generator, similar to _stream_chat_completion
+        async def _to_async_generator():
+            stream = self.client.completions.create(**params)
+            for chunk in stream:
+                yield chunk
+
+        stream = _to_async_generator()
+        async for chunk in process_completion_stream_response(stream, self.formatter):
+            yield chunk
+
     async def register_model(self, model: Model) -> Model:
         model = await self.register_helper.register_model(model)
         res = self.client.models.list()
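
Note on the streaming path: _stream_completion wraps the OpenAI-compatible client's blocking stream in a local async generator, the same _to_async_generator trick already used for chat completions. Below is a minimal, self-contained sketch of that pattern, not part of the patch; fake_completions_create is a hypothetical stand-in for self.client.completions.create(stream=True), and the chunk values are made up.

import asyncio
from typing import Any, Iterator


def fake_completions_create(**params: Any) -> Iterator[str]:
    # Hypothetical stand-in for the OpenAI-compatible completions.create(stream=True):
    # a *blocking* iterator that yields completion chunks one at a time.
    for chunk in ["The", " capital", " of", " France", " is", " Paris."]:
        yield chunk


async def stream_completion(params: dict) -> None:
    # Same shape as _stream_completion in the patch: iterate the synchronous
    # stream inside an async generator so callers can consume it with `async for`.
    async def _to_async_generator():
        stream = fake_completions_create(**params)
        for chunk in stream:
            yield chunk

    async for chunk in _to_async_generator():
        print(chunk, end="", flush=True)
    print()


asyncio.run(
    stream_completion({"model": "my-model", "prompt": "The capital of France is", "stream": True})
)

One trade-off of this pattern is that the wrapped iterator still runs on the event loop thread, so waiting for each chunk blocks the loop; the existing chat-completion streaming path carries the same trade-off.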