Skip to content

Commit

Permalink
Add vllm completions
Browse files Browse the repository at this point in the history
  • Loading branch information
aidando73 committed Jan 18, 2025
1 parent 3a9468c commit 09fc380
Showing 1 changed file with 35 additions and 1 deletion.
36 changes: 35 additions & 1 deletion llama_stack/providers/remote/inference/vllm/vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
process_completion_response,
process_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
Expand Down Expand Up @@ -92,7 +94,19 @@ async def completion(
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
raise NotImplementedError("Completion not implemented for vLLM")
model = await self.model_store.get_model(model_id)
request = CompletionRequest(
model=model.provider_resource_id,
content=content,
sampling_params=sampling_params,
response_format=response_format,
stream=stream,
logprobs=logprobs,
)
if stream:
return self._stream_completion(request)
else:
return await self._nonstream_completion(request)

async def chat_completion(
self,
Expand Down Expand Up @@ -154,6 +168,26 @@ async def _to_async_generator():
):
yield chunk

async def _nonstream_completion(
    self, request: CompletionRequest
) -> CompletionResponse:
    """Execute a non-streaming /completions request against vLLM.

    :param request: the already-translated CompletionRequest to forward.
    :returns: the provider-agnostic CompletionResponse built by
        ``process_completion_response``.
    """
    # Local import: the file-level import block is not fully visible in this
    # chunk, so we bring asyncio into scope here.
    import asyncio

    params = await self._get_params(request)
    # The OpenAI-compatible client used here is synchronous (note the plain
    # `for chunk in stream` iteration in _stream_completion). Calling it
    # inline would block the event loop for the whole HTTP round-trip, so
    # run it in a worker thread instead.
    r = await asyncio.to_thread(self.client.completions.create, **params)
    return process_completion_response(r, self.formatter)

async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
    """Execute a streaming /completions request against vLLM.

    Yields the provider-agnostic stream chunks produced by
    ``process_completion_stream_response``.

    :param request: the already-translated CompletionRequest to forward.
    """
    params = await self._get_params(request)

    # The OpenAI-compatible client hands back a synchronous iterator of raw
    # chunks; adapt it to an async generator so the shared stream-processing
    # helper can consume it.
    async def _sync_to_async():
        raw_stream = self.client.completions.create(**params)
        for raw_chunk in raw_stream:
            yield raw_chunk

    async for piece in process_completion_stream_response(
        _sync_to_async(), self.formatter
    ):
        yield piece

async def register_model(self, model: Model) -> Model:
model = await self.register_helper.register_model(model)
res = self.client.models.list()
Expand Down

0 comments on commit 09fc380

Please sign in to comment.