Skip to content

Commit

Permalink
Get streams back, unsure of underlying intent
Browse files Browse the repository at this point in the history
  • Loading branch information
NolanTrem committed Jan 22, 2025
1 parent 9d7fb9e commit ab1a135
Showing 1 changed file with 30 additions and 25 deletions.
55 changes: 30 additions & 25 deletions py/core/main/services/retrieval_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,21 +464,14 @@ async def rag(
No pipeline classes necessary.
"""

#FIXME: Unsure of the intention here, but we need to review

# Convert any UUID filters to string
for f, val in list(search_settings.filters.items()):
if isinstance(val, UUID):
search_settings.filters[f] = str(val)

try:
if rag_generation_config.stream:
# For streaming, handle separately
return await self.stream_rag_response(
query,
rag_generation_config,
search_settings,
**kwargs,
)

# 1) Do the search
search_results_dict = await self.search(query, search_settings)
aggregated = AggregateSearchResult.from_dict(search_results_dict)
Expand All @@ -501,6 +494,14 @@ async def rag(
task_prompt_override=task_prompt_override,
)

if rag_generation_config.stream:
return await self.stream_rag_response(
messages=messages,
rag_generation_config=rag_generation_config,
aggregated_results=aggregated,
**kwargs
)

# 4) LLM completion
response = await self.providers.llm.aget_completion(
messages=messages, generation_config=rag_generation_config
Expand Down Expand Up @@ -573,26 +574,30 @@ def _build_rag_context(

async def stream_rag_response(
self,
query,
messages,
rag_generation_config,
search_settings,
*args,
aggregated_results,
**kwargs,
):
#FIXME: We need to yield aggregated_results as well
async def stream_response():
merged_kwargs = {
"input": to_async_generator([query]),
"state": None,
"search_settings": search_settings,
"rag_generation_config": rag_generation_config,
**kwargs,
}

async for chunk in await self.pipelines.streaming_rag_pipeline.run(
*args,
**merged_kwargs,
):
yield chunk
try:
async for chunk in self.providers.llm.aget_completion_stream(
messages=messages,
generation_config=rag_generation_config
):
yield chunk.choices[0].delta.content or ""
except Exception as e:
logger.error(f"Error in streaming RAG: {e}")
if "NoneType" in str(e):
raise HTTPException(
status_code=502,
detail="Server not reachable or returned an invalid response"
)
raise HTTPException(
status_code=500,
detail=f"Internal RAG Error - {str(e)}"
)

return stream_response()

Expand Down

0 comments on commit ab1a135

Please sign in to comment.