From c199b8efa7eae5ec02b3de9346f0125a036a4b6e Mon Sep 17 00:00:00 2001 From: emrgnt-cmplxty Date: Tue, 21 Jan 2025 21:58:48 -0800 Subject: [PATCH] up --- py/core/agent/rag.py | 4 +- py/core/main/api/v3/chunks_router.py | 1 + py/core/main/api/v3/collections_router.py | 1 + py/core/main/api/v3/conversations_router.py | 1 + py/core/main/api/v3/documents_router.py | 6 +- py/core/main/api/v3/graph_router.py | 1 + py/core/main/api/v3/indices_router.py | 2 +- py/core/main/api/v3/logs_router.py | 1 + py/core/main/api/v3/prompts_router.py | 2 + py/core/main/api/v3/retrieval_router.py | 2 + py/core/main/api/v3/system_router.py | 6 +- py/core/main/api/v3/users_router.py | 4 +- .../hatchet/ingestion_workflow.py | 6 +- py/core/main/services/graph_service.py | 1007 +---------------- .../3efc7b3b1b3d_add_total_tokens_count.py | 142 +++ py/r2r/compose.full.yaml | 2 +- 16 files changed, 167 insertions(+), 1021 deletions(-) create mode 100644 py/migrations/versions/3efc7b3b1b3d_add_total_tokens_count.py diff --git a/py/core/agent/rag.py b/py/core/agent/rag.py index 5b661c81c..ec9e1c298 100644 --- a/py/core/agent/rag.py +++ b/py/core/agent/rag.py @@ -1,4 +1,3 @@ -# rag_agent.py from typing import Any, Callable, Optional import tiktoken @@ -314,6 +313,8 @@ def __init__( rag_generation_config: GenerationConfig, local_search_method: Callable, content_method: Optional[Callable] = None, + max_tool_context_length: int = 10_000, + ): # Initialize base R2RAgent R2RAgent.__init__( @@ -331,6 +332,7 @@ def __init__( config=config, search_settings=search_settings, rag_generation_config=rag_generation_config, + max_tool_context_length=max_tool_context_length, local_search_method=local_search_method, content_method=content_method, ) diff --git a/py/core/main/api/v3/chunks_router.py b/py/core/main/api/v3/chunks_router.py index db3e3ef77..4d64488cf 100644 --- a/py/core/main/api/v3/chunks_router.py +++ b/py/core/main/api/v3/chunks_router.py @@ -40,6 +40,7 @@ def __init__( providers: R2RProviders, 
services: R2RServices, ): + logging.info("Initializing ChunksRouter") super().__init__(providers, services) def _setup_routes(self): diff --git a/py/core/main/api/v3/collections_router.py b/py/core/main/api/v3/collections_router.py index 726e5525e..ab6baf9d4 100644 --- a/py/core/main/api/v3/collections_router.py +++ b/py/core/main/api/v3/collections_router.py @@ -87,6 +87,7 @@ async def authorize_collection_action( class CollectionsRouter(BaseRouterV3): def __init__(self, providers: R2RProviders, services: R2RServices): + logging.info("Initializing CollectionsRouter") super().__init__(providers, services) def _setup_routes(self): diff --git a/py/core/main/api/v3/conversations_router.py b/py/core/main/api/v3/conversations_router.py index c24123354..0b4f57f59 100644 --- a/py/core/main/api/v3/conversations_router.py +++ b/py/core/main/api/v3/conversations_router.py @@ -29,6 +29,7 @@ def __init__( providers: R2RProviders, services: R2RServices, ): + logging.info("Initializing ConversationsRouter") super().__init__(providers, services) def _setup_routes(self): diff --git a/py/core/main/api/v3/documents_router.py b/py/core/main/api/v3/documents_router.py index 0c1b1f18c..cc4316587 100644 --- a/py/core/main/api/v3/documents_router.py +++ b/py/core/main/api/v3/documents_router.py @@ -81,6 +81,7 @@ def __init__( providers: R2RProviders, services: R2RServices, ): + logging.info("Initializing DocumentsRouter") super().__init__(providers, services) self._register_workflows() @@ -130,11 +131,6 @@ def _register_workflows(self): if self.providers.orchestration.config.provider != "simple" else "Document created and ingested successfully." ), - "update-files": ( - "Update file task queued successfully." - if self.providers.orchestration.config.provider != "simple" - else "Update task queued successfully." - ), "update-chunk": ( "Update chunk task queued successfully." 
if self.providers.orchestration.config.provider != "simple" diff --git a/py/core/main/api/v3/graph_router.py b/py/core/main/api/v3/graph_router.py index 072285ab8..c636dcf79 100644 --- a/py/core/main/api/v3/graph_router.py +++ b/py/core/main/api/v3/graph_router.py @@ -38,6 +38,7 @@ def __init__( providers: R2RProviders, services: R2RServices, ): + logging.info("Initializing GraphRouter") super().__init__(providers, services) self._register_workflows() diff --git a/py/core/main/api/v3/indices_router.py b/py/core/main/api/v3/indices_router.py index a2f4f2a25..fc2392432 100644 --- a/py/core/main/api/v3/indices_router.py +++ b/py/core/main/api/v3/indices_router.py @@ -11,7 +11,6 @@ from core.base import IndexConfig, R2RException from core.base.abstractions import VectorTableName from core.base.api.models import ( - GenericMessageResponse, WrappedGenericMessageResponse, WrappedListVectorIndicesResponse, ) @@ -28,6 +27,7 @@ def __init__( providers: R2RProviders, services: R2RServices, ): + logging.info("Initializing IndicesRouter") super().__init__(providers, services) def _setup_routes(self): diff --git a/py/core/main/api/v3/logs_router.py b/py/core/main/api/v3/logs_router.py index e8faf31b6..2eb5ea13e 100644 --- a/py/core/main/api/v3/logs_router.py +++ b/py/core/main/api/v3/logs_router.py @@ -18,6 +18,7 @@ def __init__( providers: R2RProviders, services: R2RServices, ): + logging.info("Initializing LogsRouter") super().__init__(providers, services) CURRENT_DIR = Path(__file__).resolve().parent TEMPLATES_DIR = CURRENT_DIR.parent / "templates" diff --git a/py/core/main/api/v3/prompts_router.py b/py/core/main/api/v3/prompts_router.py index d7c214ca5..1d7ff6efc 100644 --- a/py/core/main/api/v3/prompts_router.py +++ b/py/core/main/api/v3/prompts_router.py @@ -1,3 +1,4 @@ +import logging import textwrap from typing import Optional @@ -23,6 +24,7 @@ def __init__( providers: R2RProviders, services: R2RServices, ): + logging.info("Initializing PromptsRouter") 
super().__init__(providers, services) def _setup_routes(self): diff --git a/py/core/main/api/v3/retrieval_router.py b/py/core/main/api/v3/retrieval_router.py index fdcaca50c..558e27a43 100644 --- a/py/core/main/api/v3/retrieval_router.py +++ b/py/core/main/api/v3/retrieval_router.py @@ -1,3 +1,4 @@ +import logging import textwrap from typing import Any, Optional from uuid import UUID @@ -46,6 +47,7 @@ def __init__( providers: R2RProviders, services: R2RServices, ): + logging.info("Initializing RetrievalRouterV3") super().__init__(providers, services) def _register_workflows(self): diff --git a/py/core/main/api/v3/system_router.py b/py/core/main/api/v3/system_router.py index 5bf4513df..0ee604525 100644 --- a/py/core/main/api/v3/system_router.py +++ b/py/core/main/api/v3/system_router.py @@ -1,15 +1,14 @@ +import logging import textwrap from datetime import datetime, timezone -from typing import Optional import psutil -from fastapi import Depends, Query +from fastapi import Depends from core.base import R2RException from core.base.api.models import ( GenericMessageResponse, WrappedGenericMessageResponse, - WrappedLogsResponse, WrappedServerStatsResponse, WrappedSettingsResponse, ) @@ -24,6 +23,7 @@ def __init__( providers: R2RProviders, services: R2RServices, ): + logging.info("Initializing SystemRouter") super().__init__(providers, services) self.start_time = datetime.now(timezone.utc) diff --git a/py/core/main/api/v3/users_router.py b/py/core/main/api/v3/users_router.py index 4dd4c6c24..061deab76 100644 --- a/py/core/main/api/v3/users_router.py +++ b/py/core/main/api/v3/users_router.py @@ -1,3 +1,4 @@ +import logging import os import textwrap import urllib.parse @@ -5,7 +6,7 @@ from uuid import UUID import requests -from fastapi import Body, Depends, HTTPException, Path, Query, Request +from fastapi import Body, Depends, HTTPException, Path, Query from fastapi.background import BackgroundTasks from fastapi.responses import FileResponse from fastapi.security import 
OAuth2PasswordBearer, OAuth2PasswordRequestForm @@ -38,6 +39,7 @@ class UsersRouter(BaseRouterV3): def __init__(self, providers: R2RProviders, services: R2RServices): + logging.info("Initializing UsersRouter") super().__init__(providers, services) self.google_client_id = os.environ.get("GOOGLE_CLIENT_ID") self.google_client_secret = os.environ.get("GOOGLE_CLIENT_SECRET") diff --git a/py/core/main/orchestration/hatchet/ingestion_workflow.py b/py/core/main/orchestration/hatchet/ingestion_workflow.py index 0147aefc5..ef9763cbd 100644 --- a/py/core/main/orchestration/hatchet/ingestion_workflow.py +++ b/py/core/main/orchestration/hatchet/ingestion_workflow.py @@ -101,7 +101,7 @@ async def parse(self, context: Context) -> dict: ingestion_config = parsed_data["ingestion_config"] or {} extractions_generator = ( - await self.ingestion_service.parse_file( + self.ingestion_service.parse_file( document_info, ingestion_config ) ) @@ -147,7 +147,7 @@ async def parse(self, context: Context) -> dict: status=IngestionStatus.STORING, ) - storage_generator = await self.ingestion_service.store_embeddings( # type: ignore + storage_generator = self.ingestion_service.store_embeddings( # type: ignore embeddings ) @@ -413,7 +413,7 @@ async def embed(self, context: Context) -> dict: document_info, status=IngestionStatus.STORING ) - storage_generator = await self.ingestion_service.store_embeddings( + storage_generator = self.ingestion_service.store_embeddings( embeddings ) async for _ in storage_generator: diff --git a/py/core/main/services/graph_service.py b/py/core/main/services/graph_service.py index 584e5073d..0265d7a13 100644 --- a/py/core/main/services/graph_service.py +++ b/py/core/main/services/graph_service.py @@ -1432,1009 +1432,4 @@ async def deduplicate_document_entities( store_type=StoreType.DOCUMENTS, description=new_description, description_embedding=str(new_embedding), - ) - - -# import asyncio -# import json -# import logging -# import math -# import re -# import time -# 
import xml.etree.ElementTree as ET -# from typing import Any, AsyncGenerator, Optional -# from uuid import UUID - -# from core.base import ( -# DocumentChunk, -# KGExtraction, -# KGExtractionStatus, -# R2RDocumentProcessingError, -# ) -# from core.base.abstractions import ( -# Community, -# Entity, -# GenerationConfig, -# KGEnrichmentStatus, -# R2RException, -# Relationship, -# StoreType, -# ) -# from core.base.api.models import GraphResponse -# from core.telemetry.telemetry_decorator import telemetry_event - -# from ..abstractions import R2RProviders -# from ..config import R2RConfig -# from .base import Service - -# logger = logging.getLogger() - - -# MIN_VALID_KG_EXTRACTION_RESPONSE_LENGTH = 128 - - -# async def _collect_results(result_gen: AsyncGenerator) -> list[dict]: -# results = [] -# async for res in result_gen: -# results.append(res.json() if hasattr(res, "json") else res) -# return results - - -# # TODO - Fix naming convention to read `KGService` instead of `GraphService` -# # this will require a minor change in how services are registered. 
-# class GraphService(Service): -# def __init__( -# self, -# config: R2RConfig, -# providers: R2RProviders, -# ): -# super().__init__( -# config, -# providers, -# ) - -# @telemetry_event("create_entity") -# async def create_entity( -# self, -# name: str, -# description: str, -# parent_id: UUID, -# category: Optional[str] = None, -# metadata: Optional[dict] = None, -# ) -> Entity: -# description_embedding = str( -# await self.providers.embedding.async_get_embedding(description) -# ) - -# return await self.providers.database.graphs_handler.entities.create( -# name=name, -# parent_id=parent_id, -# store_type=StoreType.GRAPHS, -# category=category, -# description=description, -# description_embedding=description_embedding, -# metadata=metadata, -# ) - -# @telemetry_event("update_entity") -# async def update_entity( -# self, -# entity_id: UUID, -# name: Optional[str] = None, -# description: Optional[str] = None, -# category: Optional[str] = None, -# metadata: Optional[dict] = None, -# ) -> Entity: -# description_embedding = None -# if description is not None: -# description_embedding = str( -# await self.providers.embedding.async_get_embedding(description) -# ) - -# return await self.providers.database.graphs_handler.entities.update( -# entity_id=entity_id, -# store_type=StoreType.GRAPHS, -# name=name, -# description=description, -# description_embedding=description_embedding, -# category=category, -# metadata=metadata, -# ) - -# @telemetry_event("delete_entity") -# async def delete_entity( -# self, -# parent_id: UUID, -# entity_id: UUID, -# ): -# return await self.providers.database.graphs_handler.entities.delete( -# parent_id=parent_id, -# entity_ids=[entity_id], -# store_type=StoreType.GRAPHS, -# ) - -# @telemetry_event("get_entities") -# async def get_entities( -# self, -# parent_id: UUID, -# offset: int, -# limit: int, -# entity_ids: Optional[list[UUID]] = None, -# entity_names: Optional[list[str]] = None, -# include_embeddings: bool = False, -# ): -# return await 
self.providers.database.graphs_handler.get_entities( -# parent_id=parent_id, -# offset=offset, -# limit=limit, -# entity_ids=entity_ids, -# entity_names=entity_names, -# include_embeddings=include_embeddings, -# ) - -# @telemetry_event("create_relationship") -# async def create_relationship( -# self, -# subject: str, -# subject_id: UUID, -# predicate: str, -# object: str, -# object_id: UUID, -# parent_id: UUID, -# description: str | None = None, -# weight: float | None = 1.0, -# metadata: Optional[dict[str, Any] | str] = None, -# ) -> Relationship: -# description_embedding = None -# if description: -# description_embedding = str( -# await self.providers.embedding.async_get_embedding(description) -# ) - -# return ( -# await self.providers.database.graphs_handler.relationships.create( -# subject=subject, -# subject_id=subject_id, -# predicate=predicate, -# object=object, -# object_id=object_id, -# parent_id=parent_id, -# description=description, -# description_embedding=description_embedding, -# weight=weight, -# metadata=metadata, -# store_type=StoreType.GRAPHS, -# ) -# ) - -# @telemetry_event("delete_relationship") -# async def delete_relationship( -# self, -# parent_id: UUID, -# relationship_id: UUID, -# ): -# return ( -# await self.providers.database.graphs_handler.relationships.delete( -# parent_id=parent_id, -# relationship_ids=[relationship_id], -# store_type=StoreType.GRAPHS, -# ) -# ) - -# @telemetry_event("update_relationship") -# async def update_relationship( -# self, -# relationship_id: UUID, -# subject: Optional[str] = None, -# subject_id: Optional[UUID] = None, -# predicate: Optional[str] = None, -# object: Optional[str] = None, -# object_id: Optional[UUID] = None, -# description: Optional[str] = None, -# weight: Optional[float] = None, -# metadata: Optional[dict[str, Any] | str] = None, -# ) -> Relationship: -# description_embedding = None -# if description is not None: -# description_embedding = str( -# await 
self.providers.embedding.async_get_embedding(description) -# ) - -# return ( -# await self.providers.database.graphs_handler.relationships.update( -# relationship_id=relationship_id, -# subject=subject, -# subject_id=subject_id, -# predicate=predicate, -# object=object, -# object_id=object_id, -# description=description, -# description_embedding=description_embedding, -# weight=weight, -# metadata=metadata, -# store_type=StoreType.GRAPHS, -# ) -# ) - -# @telemetry_event("get_relationships") -# async def get_relationships( -# self, -# parent_id: UUID, -# offset: int, -# limit: int, -# relationship_ids: Optional[list[UUID]] = None, -# entity_names: Optional[list[str]] = None, -# ): -# return await self.providers.database.graphs_handler.relationships.get( -# parent_id=parent_id, -# store_type=StoreType.GRAPHS, -# offset=offset, -# limit=limit, -# relationship_ids=relationship_ids, -# entity_names=entity_names, -# ) - -# @telemetry_event("create_community") -# async def create_community( -# self, -# parent_id: UUID, -# name: str, -# summary: str, -# findings: Optional[list[str]], -# rating: Optional[float], -# rating_explanation: Optional[str], -# ) -> Community: -# description_embedding = str( -# await self.providers.embedding.async_get_embedding(summary) -# ) -# return await self.providers.database.graphs_handler.communities.create( -# parent_id=parent_id, -# store_type=StoreType.GRAPHS, -# name=name, -# summary=summary, -# description_embedding=description_embedding, -# findings=findings, -# rating=rating, -# rating_explanation=rating_explanation, -# ) - -# @telemetry_event("update_community") -# async def update_community( -# self, -# community_id: UUID, -# name: Optional[str], -# summary: Optional[str], -# findings: Optional[list[str]], -# rating: Optional[float], -# rating_explanation: Optional[str], -# ) -> Community: -# summary_embedding = None -# if summary is not None: -# summary_embedding = str( -# await self.providers.embedding.async_get_embedding(summary) 
-# ) - -# return await self.providers.database.graphs_handler.communities.update( -# community_id=community_id, -# store_type=StoreType.GRAPHS, -# name=name, -# summary=summary, -# summary_embedding=summary_embedding, -# findings=findings, -# rating=rating, -# rating_explanation=rating_explanation, -# ) - -# @telemetry_event("delete_community") -# async def delete_community( -# self, -# parent_id: UUID, -# community_id: UUID, -# ) -> None: -# await self.providers.database.graphs_handler.communities.delete( -# parent_id=parent_id, -# community_id=community_id, -# ) - -# @telemetry_event("list_communities") -# async def list_communities( -# self, -# collection_id: UUID, -# offset: int, -# limit: int, -# ): -# return await self.providers.database.graphs_handler.communities.get( -# parent_id=collection_id, -# store_type=StoreType.GRAPHS, -# offset=offset, -# limit=limit, -# ) - -# @telemetry_event("get_communities") -# async def get_communities( -# self, -# parent_id: UUID, -# offset: int, -# limit: int, -# community_ids: Optional[list[UUID]] = None, -# community_names: Optional[list[str]] = None, -# include_embeddings: bool = False, -# ): -# return await self.providers.database.graphs_handler.get_communities( -# parent_id=parent_id, -# offset=offset, -# limit=limit, -# community_ids=community_ids, -# include_embeddings=include_embeddings, -# ) - -# async def list_graphs( -# self, -# offset: int, -# limit: int, -# # user_ids: Optional[list[UUID]] = None, -# graph_ids: Optional[list[UUID]] = None, -# collection_id: Optional[UUID] = None, -# ) -> dict[str, list[GraphResponse] | int]: -# return await self.providers.database.graphs_handler.list_graphs( -# offset=offset, -# limit=limit, -# # filter_user_ids=user_ids, -# filter_graph_ids=graph_ids, -# filter_collection_id=collection_id, -# ) - -# @telemetry_event("update_graph") -# async def update_graph( -# self, -# collection_id: UUID, -# name: Optional[str] = None, -# description: Optional[str] = None, -# ) -> 
GraphResponse: -# return await self.providers.database.graphs_handler.update( -# collection_id=collection_id, -# name=name, -# description=description, -# ) - -# @telemetry_event("reset_graph_v3") -# async def reset_graph_v3(self, id: UUID) -> bool: -# await self.providers.database.graphs_handler.reset( -# parent_id=id, -# ) -# await self.providers.database.documents_handler.set_workflow_status( -# id=id, -# status_type="graph_cluster_status", -# status=KGEnrichmentStatus.PENDING, -# ) -# return True - -# @telemetry_event("get_document_ids_for_create_graph") -# async def get_document_ids_for_create_graph( -# self, -# collection_id: UUID, -# **kwargs, -# ): -# document_status_filter = [ -# KGExtractionStatus.PENDING, -# KGExtractionStatus.FAILED, -# ] - -# return await self.providers.database.documents_handler.get_document_ids_by_status( -# status_type="extraction_status", -# status=[str(ele) for ele in document_status_filter], -# collection_id=collection_id, -# ) - -# @telemetry_event("kg_entity_description") -# async def kg_entity_description( -# self, -# document_id: UUID, -# max_description_input_length: int, -# **kwargs, -# ): -# start_time = time.time() - -# logger.info( -# f"KGService: Running kg_entity_description for document {document_id}" -# ) - -# entity_count = ( -# await self.providers.database.graphs_handler.get_entity_count( -# document_id=document_id, -# distinct=True, -# entity_table_name="documents_entities", -# ) -# ) - -# logger.info( -# f"KGService: Found {entity_count} entities in document {document_id}" -# ) - -# # TODO - Do not hardcode the batch size, -# # make it a configurable parameter at runtime & server-side defaults - -# # process 256 entities at a time -# num_batches = math.ceil(entity_count / 256) -# logger.info( -# f"Calling `kg_entity_description` on document {document_id} with an entity count of {entity_count} and total batches of {num_batches}" -# ) -# all_results = [] -# for i in range(num_batches): -# logger.info( -# 
f"KGService: Running kg_entity_description for batch {i+1}/{num_batches} for document {document_id}" -# ) - -# node_descriptions = await self.pipes.graph_description_pipe.run( -# input=self.pipes.graph_description_pipe.Input( -# message={ -# "offset": i * 256, -# "limit": 256, -# "max_description_input_length": max_description_input_length, -# "document_id": document_id, -# "logger": logger, -# } -# ), -# state=None, -# ) - -# all_results.append(await _collect_results(node_descriptions)) - -# logger.info( -# f"KGService: Completed kg_entity_description for batch {i+1}/{num_batches} for document {document_id}" -# ) - -# await self.providers.database.documents_handler.set_workflow_status( -# id=document_id, -# status_type="extraction_status", -# status=KGExtractionStatus.SUCCESS, -# ) - -# logger.info( -# f"KGService: Completed kg_entity_description for document {document_id} in {time.time() - start_time:.2f} seconds", -# ) - -# return all_results - -# @telemetry_event("kg_clustering") -# async def kg_clustering( -# self, -# collection_id: UUID, -# # graph_id: UUID, -# generation_config: GenerationConfig, -# leiden_params: dict, -# **kwargs, -# ): -# logger.info( -# f"Running ClusteringPipe for collection {collection_id} with settings {leiden_params}" -# ) - -# clustering_result = await self.pipes.graph_clustering_pipe.run( -# input=self.pipes.graph_clustering_pipe.Input( -# message={ -# "collection_id": collection_id, -# "generation_config": generation_config, -# "leiden_params": leiden_params, -# "logger": logger, -# "clustering_mode": self.config.database.graph_creation_settings.clustering_mode, -# } -# ), -# state=None, -# ) -# return await _collect_results(clustering_result) - -# @telemetry_event("kg_community_summary") -# async def kg_community_summary( -# self, -# offset: int, -# limit: int, -# max_summary_input_length: int, -# generation_config: GenerationConfig, -# collection_id: UUID | None, -# # graph_id: UUID | None, -# **kwargs, -# ): -# summary_results 
= await self.pipes.graph_community_summary_pipe.run( -# input=self.pipes.graph_community_summary_pipe.Input( -# message={ -# "offset": offset, -# "limit": limit, -# "generation_config": generation_config, -# "max_summary_input_length": max_summary_input_length, -# "collection_id": collection_id, -# # "graph_id": graph_id, -# "logger": logger, -# } -# ), -# state=None, -# ) -# return await _collect_results(summary_results) - -# @telemetry_event("delete_graph_for_documents") -# async def delete_graph_for_documents( -# self, -# document_ids: list[UUID], -# **kwargs, -# ): -# # TODO: Implement this, as it needs some checks. -# raise NotImplementedError - -# @telemetry_event("delete_graph") -# async def delete_graph( -# self, -# collection_id: UUID, -# ): -# return await self.delete(collection_id=collection_id) - -# @telemetry_event("delete") -# async def delete( -# self, -# collection_id: UUID, -# **kwargs, -# ): -# return await self.providers.database.graphs_handler.delete( -# collection_id=collection_id, -# ) - -# async def kg_extraction( # type: ignore -# self, -# document_id: UUID, -# generation_config: GenerationConfig, -# max_knowledge_relationships: int, -# entity_types: list[str], -# relation_types: list[str], -# chunk_merge_count: int, -# filter_out_existing_chunks: bool = True, -# total_tasks: Optional[int] = None, -# *args: Any, -# **kwargs: Any, -# ) -> AsyncGenerator[KGExtraction | R2RDocumentProcessingError, None]: -# start_time = time.time() - -# logger.info( -# f"Graph Extraction: Processing document {document_id} for KG extraction", -# ) - -# # Then create the extractions from the results -# limit = 100 -# offset = 0 -# chunks = [] -# while True: -# chunk_req = await self.providers.database.chunks_handler.list_document_chunks( # FIXME: This was using the pagination defaults from before... We need to review if this is as intended. 
-# document_id=document_id, -# offset=offset, -# limit=limit, -# ) - -# chunks.extend( -# [ -# DocumentChunk( -# id=chunk["id"], -# document_id=chunk["document_id"], -# owner_id=chunk["owner_id"], -# collection_ids=chunk["collection_ids"], -# data=chunk["text"], -# metadata=chunk["metadata"], -# ) -# for chunk in chunk_req["results"] -# ] -# ) -# if len(chunk_req["results"]) < limit: -# break -# offset += limit - -# logger.info(f"Found {len(chunks)} chunks for document {document_id}") -# if len(chunks) == 0: -# logger.info(f"No chunks found for document {document_id}") -# raise R2RException( -# message="No chunks found for document", -# status_code=404, -# ) - -# if filter_out_existing_chunks: -# existing_chunk_ids = await self.providers.database.graphs_handler.get_existing_document_entity_chunk_ids( -# document_id=document_id -# ) -# chunks = [ -# chunk for chunk in chunks if chunk.id not in existing_chunk_ids -# ] -# logger.info( -# f"Filtered out {len(existing_chunk_ids)} existing chunks, remaining {len(chunks)} chunks for document {document_id}" -# ) - -# if len(chunks) == 0: -# logger.info(f"No extractions left for document {document_id}") -# return - -# logger.info( -# f"Graph Extraction: Obtained {len(chunks)} chunks to process, time from start: {time.time() - start_time:.2f} seconds", -# ) - -# # sort the extractions accroding to chunk_order field in metadata in ascending order -# chunks = sorted( -# chunks, -# key=lambda x: x.metadata.get("chunk_order", float("inf")), -# ) - -# # group these extractions into groups of chunk_merge_count -# grouped_chunks = [ -# chunks[i : i + chunk_merge_count] -# for i in range(0, len(chunks), chunk_merge_count) -# ] - -# logger.info( -# f"Graph Extraction: Extracting KG Relationships for document and created {len(grouped_chunks)} tasks, time from start: {time.time() - start_time:.2f} seconds", -# ) - -# tasks = [ -# asyncio.create_task( -# self._extract_kg( -# chunks=chunk_group, -# generation_config=generation_config, -# 
max_knowledge_relationships=max_knowledge_relationships, -# entity_types=entity_types, -# relation_types=relation_types, -# task_id=task_id, -# total_tasks=len(grouped_chunks), -# ) -# ) -# for task_id, chunk_group in enumerate(grouped_chunks) -# ] - -# completed_tasks = 0 -# total_tasks = len(tasks) - -# logger.info( -# f"Graph Extraction: Waiting for {total_tasks} KG extraction tasks to complete", -# ) - -# for completed_task in asyncio.as_completed(tasks): -# try: -# yield await completed_task -# completed_tasks += 1 -# if completed_tasks % 100 == 0: -# logger.info( -# f"Graph Extraction: Completed {completed_tasks}/{total_tasks} KG extraction tasks", -# ) -# except Exception as e: -# logger.error(f"Error in Extracting KG Relationships: {e}") -# yield R2RDocumentProcessingError( -# document_id=document_id, -# error_message=str(e), -# ) - -# logger.info( -# f"Graph Extraction: Completed {completed_tasks}/{total_tasks} KG extraction tasks, time from start: {time.time() - start_time:.2f} seconds", -# ) - -# async def _extract_kg( -# self, -# chunks: list[DocumentChunk], -# generation_config: GenerationConfig, -# max_knowledge_relationships: int, -# entity_types: list[str], -# relation_types: list[str], -# retries: int = 5, -# delay: int = 2, -# task_id: Optional[int] = None, -# total_tasks: Optional[int] = None, -# ) -> KGExtraction: -# """ -# Extracts NER relationships from a extraction with retries. 
-# """ - -# # combine all extractions into a single string -# combined_extraction: str = " ".join([chunk.data for chunk in chunks]) # type: ignore - -# response = await self.providers.database.documents_handler.get_documents_overview( # type: ignore -# offset=0, -# limit=1, -# filter_document_ids=[chunks[0].document_id], -# ) -# document_summary = ( -# response["results"][0].summary if response["results"] else None -# ) - -# messages = await self.providers.database.prompts_handler.get_message_payload( -# task_prompt_name=self.providers.database.config.graph_creation_settings.graphrag_relationships_extraction_few_shot, -# task_inputs={ -# "document_summary": document_summary, -# "input": combined_extraction, -# "max_knowledge_relationships": max_knowledge_relationships, -# "entity_types": "\n".join(entity_types), -# "relation_types": "\n".join(relation_types), -# }, -# ) - -# for attempt in range(retries): -# try: -# response = await self.providers.llm.aget_completion( -# messages, -# generation_config=generation_config, -# ) - -# kg_extraction = response.choices[0].message.content - -# if not kg_extraction: -# raise R2RException( -# "No knowledge graph extraction found in the response string, the selected LLM likely failed to format it's response correctly.", -# 400, -# ) - -# def sanitize_xml(response_str: str) -> str: -# """Attempts to sanitize the XML response string by""" -# # Strip any markdown -# response_str = re.sub(r"```xml|```", "", response_str) - -# # Remove any XML processing instructions or style tags -# response_str = re.sub(r"<\?.*?\?>", "", response_str) -# response_str = re.sub( -# r".*?", "", response_str -# ) - -# # Only replace & if it's not already part of an escape sequence -# response_str = re.sub( -# r"&(?!amp;|quot;|apos;|lt;|gt;)", "&", response_str -# ) - -# # Remove any root tags since we'll add them in parse_fn -# response_str = response_str.replace("", "").replace( -# "", "" -# ) - -# # Find and track all opening/closing tags -# 
opened_tags = [] -# for match in re.finditer( -# r"<(\w+)(?:\s+[^>]*)?>", response_str -# ): -# tag = match.group(1) -# if tag != "root": # Don't track root tag -# opened_tags.append(tag) - -# for match in re.finditer(r"", response_str): -# tag = match.group(1) -# if tag in opened_tags: -# opened_tags.remove(tag) - -# # Close any unclosed tags -# for tag in reversed(opened_tags): -# response_str += f"" - -# return response_str.strip() - -# async def parse_fn(response_str: str) -> Any: -# # Wrap the response in a root element to ensure it is valid XML -# cleaned_xml = sanitize_xml(response_str) -# wrapped_xml = f"{cleaned_xml}" - -# try: -# root = ET.fromstring(wrapped_xml) -# except ET.ParseError as e: -# raise R2RException( -# f"Failed to parse XML response: {e}. Response: {wrapped_xml}", -# 400, -# ) - -# entities = root.findall(".//entity") -# if ( -# len(kg_extraction) -# > MIN_VALID_KG_EXTRACTION_RESPONSE_LENGTH -# and len(entities) == 0 -# ): -# raise R2RException( -# f"No entities found in the response string, the selected LLM likely failed to format it's response correctly. 
{response_str}", -# 400, -# ) - -# entities_arr = [] -# for entity_elem in entities: -# entity_value = entity_elem.get("name") -# entity_category = entity_elem.find("type").text -# entity_description = entity_elem.find( -# "description" -# ).text - -# description_embedding = ( -# await self.providers.embedding.async_get_embedding( -# entity_description -# ) -# ) - -# entities_arr.append( -# Entity( -# category=entity_category, -# description=entity_description, -# name=entity_value, -# parent_id=chunks[0].document_id, -# chunk_ids=[chunk.id for chunk in chunks], -# description_embedding=description_embedding, -# attributes={}, -# ) -# ) - -# relations_arr = [] -# for rel_elem in root.findall(".//relationship"): -# if rel_elem is not None: -# source_elem = rel_elem.find("source") -# target_elem = rel_elem.find("target") -# type_elem = rel_elem.find("type") -# desc_elem = rel_elem.find("description") -# weight_elem = rel_elem.find("weight") - -# if all( -# [ -# elem is not None -# for elem in [ -# source_elem, -# target_elem, -# type_elem, -# desc_elem, -# weight_elem, -# ] -# ] -# ): -# assert source_elem is not None -# assert target_elem is not None -# assert type_elem is not None -# assert desc_elem is not None -# assert weight_elem is not None - -# subject = source_elem.text -# object = target_elem.text -# predicate = type_elem.text -# description = desc_elem.text -# weight = float(weight_elem.text) - -# relationship_embedding = await self.providers.embedding.async_get_embedding( -# description -# ) - -# relations_arr.append( -# Relationship( -# subject=subject, -# predicate=predicate, -# object=object, -# description=description, -# weight=weight, -# parent_id=chunks[0].document_id, -# chunk_ids=[ -# chunk.id for chunk in chunks -# ], -# attributes={}, -# description_embedding=relationship_embedding, -# ) -# ) - -# return entities_arr, relations_arr - -# entities, relationships = await parse_fn(kg_extraction) -# return KGExtraction( -# entities=entities, -# 
relationships=relationships, -# ) - -# except ( -# Exception, -# json.JSONDecodeError, -# KeyError, -# IndexError, -# R2RException, -# ) as e: -# if attempt < retries - 1: -# await asyncio.sleep(delay) -# else: -# logger.warning( -# f"Failed after retries with for chunk {chunks[0].id} of document {chunks[0].document_id}: {e}" -# ) - -# logger.info( -# f"Graph Extraction: Completed task number {task_id} of {total_tasks} for document {chunks[0].document_id}", -# ) - -# return KGExtraction( -# entities=[], -# relationships=[], -# ) - -# async def store_kg_extractions( -# self, -# kg_extractions: list[KGExtraction], -# ): -# """ -# Stores a batch of knowledge graph extractions in the graph database. -# """ - -# for extraction in kg_extractions: -# entities_id_map = {} -# for entity in extraction.entities: -# result = await self.providers.database.graphs_handler.entities.create( -# name=entity.name, -# parent_id=entity.parent_id, -# store_type=StoreType.DOCUMENTS, -# category=entity.category, -# description=entity.description, -# description_embedding=entity.description_embedding, -# chunk_ids=entity.chunk_ids, -# metadata=entity.metadata, -# ) -# entities_id_map[entity.name] = result.id - -# if extraction.relationships: -# for relationship in extraction.relationships: -# await self.providers.database.graphs_handler.relationships.create( -# subject=relationship.subject, -# subject_id=entities_id_map.get(relationship.subject), -# predicate=relationship.predicate, -# object=relationship.object, -# object_id=entities_id_map.get(relationship.object), -# parent_id=relationship.parent_id, -# description=relationship.description, -# description_embedding=relationship.description_embedding, -# weight=relationship.weight, -# metadata=relationship.metadata, -# store_type=StoreType.DOCUMENTS, -# ) - -# @telemetry_event("deduplicate_document_entities") -# async def deduplicate_document_entities( -# self, -# document_id: UUID, -# ): -# """ -# Deduplicate entities in a document. 
-# """ - -# merged_results = await self.providers.database.entities_handler.merge_duplicate_name_blocks( -# parent_id=document_id, -# store_type=StoreType.DOCUMENTS, -# ) - -# response = await self.providers.database.documents_handler.get_documents_overview( -# offset=0, -# limit=1, -# filter_document_ids=[document_id], -# ) -# document_summary = ( -# response["results"][0].summary if response["results"] else None -# ) - -# for original_entities, merged_entity in merged_results: -# # Generate new consolidated description using the LLM -# messages = await self.providers.database.prompts_handler.get_message_payload( -# task_prompt_name=self.providers.database.config.graph_creation_settings.graph_entity_description_prompt, -# task_inputs={ -# "document_summary": document_summary, -# "entity_info": f"{merged_entity.name}\n".join( -# [ -# desc -# for desc in { -# e.description for e in original_entities -# } -# if desc is not None -# ] -# ), -# "relationships_txt": "", -# }, -# ) - -# generation_config = ( -# self.config.database.graph_creation_settings.generation_config -# ) -# response = await self.providers.llm.aget_completion( -# messages, -# generation_config=generation_config, -# ) -# new_description = response.choices[0].message.content - -# # Generate new embedding for the consolidated description -# new_embedding = await self.providers.embedding.async_get_embedding( -# new_description -# ) - -# # Update the entity with new description and embedding -# await self.providers.database.graphs_handler.entities.update( -# entity_id=merged_entity.id, -# store_type=StoreType.DOCUMENTS, -# description=new_description, -# description_embedding=str(new_embedding), -# ) + ) \ No newline at end of file diff --git a/py/migrations/versions/3efc7b3b1b3d_add_total_tokens_count.py b/py/migrations/versions/3efc7b3b1b3d_add_total_tokens_count.py new file mode 100644 index 000000000..6ee5c237f --- /dev/null +++ b/py/migrations/versions/3efc7b3b1b3d_add_total_tokens_count.py @@ -0,0 
"""Add a ``total_tokens`` column to the ``documents`` table and backfill it.

Revision ID: 3efc7b3b1b3d
Revises: 7eb70560f406
Create Date: 2025-01-21 14:59:00.000000

"""

import logging
import math
import os

import sqlalchemy as sa
import tiktoken
from alembic import op
from sqlalchemy import text

# revision identifiers, used by Alembic.
# NOTE(review): the revision id must match this file's name prefix
# (3efc7b3b1b3d_add_total_tokens_count.py); the previous value
# "123456789abc" was a leftover placeholder.
revision = "3efc7b3b1b3d"
down_revision = "7eb70560f406"  # Make sure this matches your newest migration
branch_labels = None
depends_on = None

logger = logging.getLogger("alembic.runtime.migration")


def count_tokens_for_text(text: str, model: str = "gpt-3.5-turbo") -> int:
    """Count the number of tokens in ``text`` using tiktoken.

    Default model is "gpt-3.5-turbo"; adjust if you prefer a different
    model. Falls back to the ``cl100k_base`` encoding when ``model`` is
    not recognized by tiktoken.
    """
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        # Unknown model name -- fall back to a known encoding.
        encoding = tiktoken.get_encoding("cl100k_base")
    return len(encoding.encode(text))


def upgrade() -> None:
    """Add ``documents.total_tokens`` (if missing) and backfill its values.

    The backfill sums token counts over every chunk belonging to each
    document, processing documents in batches to bound memory usage.
    """
    connection = op.get_bind()

    # 1) Add the 'total_tokens' column if it does not already exist.
    #    Server default of 0 gives existing rows a well-defined value.
    #    (If you want the default to be NULL instead of 0, adjust as needed.)
    insp = sa.inspect(connection)
    col_names = [col["name"] for col in insp.get_columns("documents")]
    if "total_tokens" not in col_names:
        logger.info("Adding 'total_tokens' column to 'documents' table...")
        op.add_column(
            "documents",
            sa.Column(
                "total_tokens",
                sa.Integer(),
                nullable=False,
                server_default="0",
            ),
        )
    else:
        logger.info(
            "Column 'total_tokens' already exists in 'documents' table, "
            "skipping add-column step."
        )

    # 2) Fill in 'total_tokens' for each document by summing the tokens
    #    from all of its chunks, in batches to avoid loading everything
    #    at once.
    BATCH_SIZE = 500

    # a) Count how many documents we have.
    logger.info("Determining how many documents need updating...")
    total_docs = (
        connection.execute(text("SELECT COUNT(*) FROM documents")).scalar()
        or 0
    )
    logger.info("Total documents found: %s", total_docs)

    if total_docs == 0:
        logger.info("No documents found, nothing to update.")
        return

    # b) Iterate over documents in pages of size BATCH_SIZE.
    pages = math.ceil(total_docs / BATCH_SIZE)
    logger.info(
        "Updating total_tokens in %s batches of up to %s documents...",
        pages,
        BATCH_SIZE,
    )

    # Optionally choose the tiktoken model via environment variable.
    default_model = os.getenv("R2R_TOKCOUNT_MODEL", "gpt-3.5-turbo")

    # Hoist the (loop-invariant) query texts out of the loops.
    batch_docs_query = text(
        """
        SELECT id
        FROM documents
        ORDER BY id
        LIMIT :limit_val
        OFFSET :offset_val
        """
    )
    chunks_query = text(
        """
        SELECT data
        FROM chunks
        WHERE document_id = :doc_id
        """
    )
    update_query = text(
        """
        UPDATE documents
        SET total_tokens = :tokcount
        WHERE id = :doc_id
        """
    )

    offset = 0
    for page_idx in range(pages):
        logger.info(
            "Processing batch %s / %s (OFFSET=%s, LIMIT=%s)",
            page_idx + 1,
            pages,
            offset,
            BATCH_SIZE,
        )

        # c) Fetch the ids of the next batch of documents. Use .scalars():
        #    SQLAlchemy 2.0 Row objects no longer support row["id"] access.
        doc_ids = (
            connection.execute(
                batch_docs_query,
                {"limit_val": BATCH_SIZE, "offset_val": offset},
            )
            .scalars()
            .all()
        )

        # If no results, break early.
        if not doc_ids:
            break
        offset += BATCH_SIZE

        # d) For each document in this batch, sum up tokens from its chunks.
        for doc_id in doc_ids:
            chunk_texts = (
                connection.execute(chunks_query, {"doc_id": doc_id})
                .scalars()
                .all()
            )
            total_tokens = sum(
                count_tokens_for_text(chunk or "", model=default_model)
                for chunk in chunk_texts
            )

            # e) Update total_tokens for this document.
            connection.execute(
                update_query, {"tokcount": total_tokens, "doc_id": doc_id}
            )

        logger.info("Finished batch %s", page_idx + 1)

    logger.info("Done updating total_tokens.")


def downgrade() -> None:
    """Drop the ``total_tokens`` column added by this migration.

    Remove this body if you prefer to leave the column in place on
    downgrade.
    """
    logger.info(
        "Dropping column 'total_tokens' from 'documents' table (downgrade)."
    )
    op.drop_column("documents", "total_tokens")