From 677312190467c7b8d047ec0fa8544094889eea27 Mon Sep 17 00:00:00 2001 From: Boris Arzentar Date: Mon, 28 Jul 2025 23:19:36 +0200 Subject: [PATCH 1/8] fix: datasets status without datasets parameter --- cognee/api/v1/datasets/routers/get_datasets_router.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cognee/api/v1/datasets/routers/get_datasets_router.py b/cognee/api/v1/datasets/routers/get_datasets_router.py index 335306ba5..4de6feca1 100644 --- a/cognee/api/v1/datasets/routers/get_datasets_router.py +++ b/cognee/api/v1/datasets/routers/get_datasets_router.py @@ -353,7 +353,7 @@ def get_datasets_router() -> APIRouter: @router.get("/status", response_model=dict[str, PipelineRunStatus]) async def get_dataset_status( - datasets: Annotated[List[UUID], Query(alias="dataset")] = None, + datasets: Annotated[List[UUID], Query(alias="dataset")] = [], user: User = Depends(get_authenticated_user), ): """ From 9793cd56ad23064c8302f386f3961bc8dbfe8b24 Mon Sep 17 00:00:00 2001 From: Boris Arzentar Date: Mon, 28 Jul 2025 23:20:21 +0200 Subject: [PATCH 2/8] version: 0.2.2.dev0 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 33e45d88c..b2d3e8e30 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "cognee" -version = "0.2.1" +version = "0.2.2.dev0" description = "Cognee - is a library for enriching LLM context with a semantic layer for better understanding and reasoning." authors = [ { name = "Vasilije Markovic" }, From 961fa5ec45d835783af46d86b512d41f71b5d42e Mon Sep 17 00:00:00 2001 From: Boris Arzentar Date: Mon, 28 Jul 2025 23:23:35 +0200 Subject: [PATCH 3/8] chore: update uv.lock file --- uv.lock | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/uv.lock b/uv.lock index 039f96ca8..6eb4a6aca 100644 --- a/uv.lock +++ b/uv.lock @@ -857,7 +857,7 @@ wheels = [ [[package]] name = "cognee" -version = "0.2.1" +version = "0.2.2.dev0" source = { editable = "." } dependencies = [ { name = "aiofiles" }, From f78af0cec307510f3bc330604a37b7e248569871 Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Tue, 29 Jul 2025 12:35:38 +0200 Subject: [PATCH 4/8] feature: solve edge embedding duplicates in edge collection + retriever optimization (#1151) ## Description feature: solve edge embedding duplicates in edge collection + retriever optimization ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin. 
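The core of the deduplication is a deterministic edge id: the relationship name is normalized and hashed into a UUIDv5, so the same edge type always resolves to the same embedding record. A minimal illustration of that idea, mirroring the `generate_edge_id` helper introduced in the diff below (the assertion is just a hypothetical usage check):

```python
from uuid import NAMESPACE_OID, uuid5


def generate_edge_id(edge_id: str):
    # Normalize the relationship name so formatting variants collapse to one key,
    # then derive a deterministic UUID from it (same input -> same id on every run).
    return uuid5(NAMESPACE_OID, edge_id.lower().replace(" ", "_").replace("'", ""))


# Differently formatted spellings of the same relationship resolve to one id,
# so re-indexing the edge collection no longer produces duplicate embeddings.
assert generate_edge_id("works at") == generate_edge_id("Works At")
```

The retriever optimization in the same patch replaces the hand-rolled bounded min-heap in `calculate_top_triplet_importances` with `heapq.nsmallest` over a per-edge distance score, selecting the same k lowest-distance triplets with less bookkeeping.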
--------- Co-authored-by: Vasilije <8619304+Vasilije1990@users.noreply.github.com> --- .../modules/engine/utils/generate_edge_id.py | 5 ++++ .../modules/graph/cognee_graph/CogneeGraph.py | 27 +++++++------------ cognee/tasks/storage/index_graph_edges.py | 5 +++- 3 files changed, 18 insertions(+), 19 deletions(-) create mode 100644 cognee/modules/engine/utils/generate_edge_id.py diff --git a/cognee/modules/engine/utils/generate_edge_id.py b/cognee/modules/engine/utils/generate_edge_id.py new file mode 100644 index 000000000..00645284b --- /dev/null +++ b/cognee/modules/engine/utils/generate_edge_id.py @@ -0,0 +1,5 @@ +from uuid import NAMESPACE_OID, uuid5 + + +def generate_edge_id(edge_id: str) -> str: + return uuid5(NAMESPACE_OID, edge_id.lower().replace(" ", "_").replace("'", "")) diff --git a/cognee/modules/graph/cognee_graph/CogneeGraph.py b/cognee/modules/graph/cognee_graph/CogneeGraph.py index 4e3a2d15a..ca1984dfe 100644 --- a/cognee/modules/graph/cognee_graph/CogneeGraph.py +++ b/cognee/modules/graph/cognee_graph/CogneeGraph.py @@ -170,28 +170,19 @@ class CogneeGraph(CogneeAbstractGraph): for edge in self.edges: relationship_type = edge.attributes.get("relationship_type") - if relationship_type and relationship_type in embedding_map: - edge.attributes["vector_distance"] = embedding_map[relationship_type] + distance = embedding_map.get(relationship_type, None) + if distance is not None: + edge.attributes["vector_distance"] = distance except Exception as ex: logger.error(f"Error mapping vector distances to edges: {str(ex)}") raise ex async def calculate_top_triplet_importances(self, k: int) -> List: - min_heap = [] + def score(edge): + n1 = edge.node1.attributes.get("vector_distance", 1) + n2 = edge.node2.attributes.get("vector_distance", 1) + e = edge.attributes.get("vector_distance", 1) + return n1 + n2 + e - for i, edge in enumerate(self.edges): - source_node = self.get_node(edge.node1.id) - target_node = self.get_node(edge.node2.id) - - source_distance = source_node.attributes.get("vector_distance", 1) if source_node else 1 - target_distance = target_node.attributes.get("vector_distance", 1) if target_node else 1 - edge_distance = edge.attributes.get("vector_distance", 1) - - total_distance = source_distance + target_distance + edge_distance - - heapq.heappush(min_heap, (-total_distance, i, edge)) - if len(min_heap) > k: - heapq.heappop(min_heap) - - return [edge for _, _, edge in sorted(min_heap)] + return heapq.nsmallest(k, self.edges, key=score) diff --git a/cognee/tasks/storage/index_graph_edges.py b/cognee/tasks/storage/index_graph_edges.py index 54f72804b..2233ab99f 100644 --- a/cognee/tasks/storage/index_graph_edges.py +++ b/cognee/tasks/storage/index_graph_edges.py @@ -1,3 +1,4 @@ +from cognee.modules.engine.utils.generate_edge_id import generate_edge_id from cognee.shared.logging_utils import get_logger, ERROR from collections import Counter @@ -49,7 +50,9 @@ async def index_graph_edges(batch_size: int = 1024): ) for text, count in edge_types.items(): - edge = EdgeType(relationship_name=text, number_of_edges=count) + edge = EdgeType( + id=generate_edge_id(edge_id=text), relationship_name=text, number_of_edges=count + ) data_point_type = type(edge) for field_name in edge.metadata["index_fields"]: From 14ba3e8829634bc856ed17c9c9e50fb290d2cf5d Mon Sep 17 00:00:00 2001 From: Igor Ilic <30923996+dexters1@users.noreply.github.com> Date: Tue, 29 Jul 2025 16:39:31 +0200 Subject: [PATCH 5/8] feat: Enable async execution of data items for incremental loading (#1092) ## Description 
Attempt at making incremental loading run async ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin. --- cognee/api/v1/add/add.py | 2 + cognee/api/v1/add/routers/get_add_router.py | 3 + cognee/api/v1/cognify/code_graph_pipeline.py | 4 +- cognee/api/v1/cognify/cognify.py | 8 + .../v1/cognify/routers/get_cognify_router.py | 9 +- cognee/modules/data/models/Data.py | 7 +- .../processing/document_types/PdfDocument.py | 15 +- .../modules/pipelines/exceptions/__init__.py | 1 + .../pipelines/exceptions/exceptions.py | 12 + .../pipelines/models/DataItemStatus.py | 5 + .../pipelines/models/PipelineRunInfo.py | 6 + cognee/modules/pipelines/models/__init__.py | 1 + .../modules/pipelines/operations/pipeline.py | 7 +- .../modules/pipelines/operations/run_tasks.py | 246 ++++++++++++++++-- .../extract_chunks_from_documents.py | 20 +- cognee/tasks/ingestion/ingest_data.py | 3 +- .../ingestion/resolve_data_directories.py | 3 + .../get_repo_file_dependencies.py | 3 + cognee/tests/test_deduplication.py | 4 +- 19 files changed, 309 insertions(+), 50 deletions(-) create mode 100644 cognee/modules/pipelines/exceptions/__init__.py create mode 100644 cognee/modules/pipelines/exceptions/exceptions.py create mode 100644 cognee/modules/pipelines/models/DataItemStatus.py diff --git a/cognee/api/v1/add/add.py b/cognee/api/v1/add/add.py index 4f51729a3..3e4aaae49 100644 --- a/cognee/api/v1/add/add.py +++ b/cognee/api/v1/add/add.py @@ -15,6 +15,7 @@ async def add( vector_db_config: dict = None, graph_db_config: dict = None, dataset_id: Optional[UUID] = None, + incremental_loading: bool = True, ): """ Add data to Cognee for knowledge graph processing. @@ -153,6 +154,7 @@ async def add( pipeline_name="add_pipeline", vector_db_config=vector_db_config, graph_db_config=graph_db_config, + incremental_loading=incremental_loading, ): pipeline_run_info = run_info diff --git a/cognee/api/v1/add/routers/get_add_router.py b/cognee/api/v1/add/routers/get_add_router.py index 4519af728..66b165a38 100644 --- a/cognee/api/v1/add/routers/get_add_router.py +++ b/cognee/api/v1/add/routers/get_add_router.py @@ -11,6 +11,7 @@ from typing import List, Optional, Union, Literal from cognee.modules.users.models import User from cognee.modules.users.methods import get_authenticated_user from cognee.shared.utils import send_telemetry +from cognee.modules.pipelines.models import PipelineRunErrored from cognee.shared.logging_utils import get_logger logger = get_logger() @@ -100,6 +101,8 @@ def get_add_router() -> APIRouter: else: add_run = await cognee_add(data, datasetName, user=user, dataset_id=datasetId) + if isinstance(add_run, PipelineRunErrored): + return JSONResponse(status_code=420, content=add_run.model_dump(mode="json")) return add_run.model_dump() except Exception as error: return JSONResponse(status_code=409, content={"error": str(error)}) diff --git a/cognee/api/v1/cognify/code_graph_pipeline.py b/cognee/api/v1/cognify/code_graph_pipeline.py index 00a0d3dc9..0da286c4b 100644 --- a/cognee/api/v1/cognify/code_graph_pipeline.py +++ b/cognee/api/v1/cognify/code_graph_pipeline.py @@ -79,7 +79,9 @@ async def run_code_graph_pipeline(repo_path, include_docs=False): async for run_status in non_code_pipeline_run: yield run_status - async for run_status in run_tasks(tasks, dataset.id, repo_path, user, "cognify_code_pipeline"): + async for run_status in run_tasks( + tasks, dataset.id, repo_path, user, "cognify_code_pipeline", incremental_loading=False 
+ ): yield run_status diff --git a/cognee/api/v1/cognify/cognify.py b/cognee/api/v1/cognify/cognify.py index 7c7821460..c6508f3a7 100644 --- a/cognee/api/v1/cognify/cognify.py +++ b/cognee/api/v1/cognify/cognify.py @@ -39,6 +39,7 @@ async def cognify( vector_db_config: dict = None, graph_db_config: dict = None, run_in_background: bool = False, + incremental_loading: bool = True, ): """ Transform ingested data into a structured knowledge graph. @@ -194,6 +195,7 @@ async def cognify( datasets=datasets, vector_db_config=vector_db_config, graph_db_config=graph_db_config, + incremental_loading=incremental_loading, ) else: return await run_cognify_blocking( @@ -202,6 +204,7 @@ async def cognify( datasets=datasets, vector_db_config=vector_db_config, graph_db_config=graph_db_config, + incremental_loading=incremental_loading, ) @@ -211,6 +214,7 @@ async def run_cognify_blocking( datasets, graph_db_config: dict = None, vector_db_config: dict = False, + incremental_loading: bool = True, ): total_run_info = {} @@ -221,6 +225,7 @@ async def run_cognify_blocking( pipeline_name="cognify_pipeline", graph_db_config=graph_db_config, vector_db_config=vector_db_config, + incremental_loading=incremental_loading, ): if run_info.dataset_id: total_run_info[run_info.dataset_id] = run_info @@ -236,6 +241,7 @@ async def run_cognify_as_background_process( datasets, graph_db_config: dict = None, vector_db_config: dict = False, + incremental_loading: bool = True, ): # Convert dataset to list if it's a string if isinstance(datasets, str): @@ -246,6 +252,7 @@ async def run_cognify_as_background_process( async def handle_rest_of_the_run(pipeline_list): # Execute all provided pipelines one by one to avoid database write conflicts + # TODO: Convert to async gather task instead of for loop when Queue mechanism for database is created for pipeline in pipeline_list: while True: try: @@ -270,6 +277,7 @@ async def run_cognify_as_background_process( pipeline_name="cognify_pipeline", graph_db_config=graph_db_config, vector_db_config=vector_db_config, + incremental_loading=incremental_loading, ) # Save dataset Pipeline run started info diff --git a/cognee/api/v1/cognify/routers/get_cognify_router.py b/cognee/api/v1/cognify/routers/get_cognify_router.py index ecfceec52..b63238966 100644 --- a/cognee/api/v1/cognify/routers/get_cognify_router.py +++ b/cognee/api/v1/cognify/routers/get_cognify_router.py @@ -16,7 +16,11 @@ from cognee.modules.graph.methods import get_formatted_graph_data from cognee.modules.users.get_user_manager import get_user_manager_context from cognee.infrastructure.databases.relational import get_relational_engine from cognee.modules.users.authentication.default.default_jwt_strategy import DefaultJWTStrategy -from cognee.modules.pipelines.models.PipelineRunInfo import PipelineRunCompleted, PipelineRunInfo +from cognee.modules.pipelines.models.PipelineRunInfo import ( + PipelineRunCompleted, + PipelineRunInfo, + PipelineRunErrored, +) from cognee.modules.pipelines.queues.pipeline_run_info_queues import ( get_from_queue, initialize_queue, @@ -105,6 +109,9 @@ def get_cognify_router() -> APIRouter: datasets, user, run_in_background=payload.run_in_background ) + # If any cognify run errored return JSONResponse with proper error status code + if any(isinstance(v, PipelineRunErrored) for v in cognify_run.values()): + return JSONResponse(status_code=420, content=cognify_run) return cognify_run except Exception as error: return JSONResponse(status_code=409, content={"error": str(error)}) diff --git 
a/cognee/modules/data/models/Data.py b/cognee/modules/data/models/Data.py index c22cc338e..dc918c2ed 100644 --- a/cognee/modules/data/models/Data.py +++ b/cognee/modules/data/models/Data.py @@ -1,6 +1,7 @@ from datetime import datetime, timezone from uuid import uuid4 from sqlalchemy import UUID, Column, DateTime, String, JSON, Integer +from sqlalchemy.ext.mutable import MutableDict from sqlalchemy.orm import relationship from cognee.infrastructure.databases.relational import Base @@ -21,7 +22,11 @@ class Data(Base): tenant_id = Column(UUID, index=True, nullable=True) content_hash = Column(String) external_metadata = Column(JSON) - node_set = Column(JSON, nullable=True) # Store NodeSet as JSON list of strings + # Store NodeSet as JSON list of strings + node_set = Column(JSON, nullable=True) + # MutableDict allows SQLAlchemy to notice key-value pair changes, without it changing a value for a key + # wouldn't be noticed when commiting a database session + pipeline_status = Column(MutableDict.as_mutable(JSON)) token_count = Column(Integer) data_size = Column(Integer, nullable=True) # File size in bytes created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)) diff --git a/cognee/modules/data/processing/document_types/PdfDocument.py b/cognee/modules/data/processing/document_types/PdfDocument.py index e92868c2e..dc90899eb 100644 --- a/cognee/modules/data/processing/document_types/PdfDocument.py +++ b/cognee/modules/data/processing/document_types/PdfDocument.py @@ -5,7 +5,6 @@ from cognee.modules.chunking.Chunker import Chunker from cognee.infrastructure.files.utils.open_data_file import open_data_file from .Document import Document -from .exceptions.exceptions import PyPdfInternalError logger = get_logger("PDFDocument") @@ -17,18 +16,12 @@ class PdfDocument(Document): async with open_data_file(self.raw_data_location, mode="rb") as stream: logger.info(f"Reading PDF: {self.raw_data_location}") - try: - file = PdfReader(stream, strict=False) - except Exception: - raise PyPdfInternalError() + file = PdfReader(stream, strict=False) async def get_text(): - try: - for page in file.pages: - page_text = page.extract_text() - yield page_text - except Exception: - raise PyPdfInternalError() + for page in file.pages: + page_text = page.extract_text() + yield page_text chunker = chunker_cls(self, get_text=get_text, max_chunk_size=max_chunk_size) diff --git a/cognee/modules/pipelines/exceptions/__init__.py b/cognee/modules/pipelines/exceptions/__init__.py new file mode 100644 index 000000000..f4e296be3 --- /dev/null +++ b/cognee/modules/pipelines/exceptions/__init__.py @@ -0,0 +1 @@ +from .exceptions import PipelineRunFailedError diff --git a/cognee/modules/pipelines/exceptions/exceptions.py b/cognee/modules/pipelines/exceptions/exceptions.py new file mode 100644 index 000000000..0a4863075 --- /dev/null +++ b/cognee/modules/pipelines/exceptions/exceptions.py @@ -0,0 +1,12 @@ +from cognee.exceptions import CogneeApiError +from fastapi import status + + +class PipelineRunFailedError(CogneeApiError): + def __init__( + self, + message: str = "Pipeline run failed.", + name: str = "PipelineRunFailedError", + status_code: int = status.HTTP_422_UNPROCESSABLE_ENTITY, + ): + super().__init__(message, name, status_code) diff --git a/cognee/modules/pipelines/models/DataItemStatus.py b/cognee/modules/pipelines/models/DataItemStatus.py new file mode 100644 index 000000000..c9be26255 --- /dev/null +++ b/cognee/modules/pipelines/models/DataItemStatus.py @@ -0,0 +1,5 @@ +import enum + + +class 
DataItemStatus(str, enum.Enum): + DATA_ITEM_PROCESSING_COMPLETED = "DATA_ITEM_PROCESSING_COMPLETED" diff --git a/cognee/modules/pipelines/models/PipelineRunInfo.py b/cognee/modules/pipelines/models/PipelineRunInfo.py index d910f4fc8..5f5a91c34 100644 --- a/cognee/modules/pipelines/models/PipelineRunInfo.py +++ b/cognee/modules/pipelines/models/PipelineRunInfo.py @@ -9,6 +9,7 @@ class PipelineRunInfo(BaseModel): dataset_id: UUID dataset_name: str payload: Optional[Any] = None + data_ingestion_info: Optional[list] = None model_config = { "arbitrary_types_allowed": True, @@ -30,6 +31,11 @@ class PipelineRunCompleted(PipelineRunInfo): pass +class PipelineRunAlreadyCompleted(PipelineRunInfo): + status: str = "PipelineRunAlreadyCompleted" + pass + + class PipelineRunErrored(PipelineRunInfo): status: str = "PipelineRunErrored" pass diff --git a/cognee/modules/pipelines/models/__init__.py b/cognee/modules/pipelines/models/__init__.py index f109d7196..ed81f1398 100644 --- a/cognee/modules/pipelines/models/__init__.py +++ b/cognee/modules/pipelines/models/__init__.py @@ -6,3 +6,4 @@ from .PipelineRunInfo import ( PipelineRunCompleted, PipelineRunErrored, ) +from .DataItemStatus import DataItemStatus diff --git a/cognee/modules/pipelines/operations/pipeline.py b/cognee/modules/pipelines/operations/pipeline.py index e58c15254..b08f8e3bb 100644 --- a/cognee/modules/pipelines/operations/pipeline.py +++ b/cognee/modules/pipelines/operations/pipeline.py @@ -52,6 +52,7 @@ async def cognee_pipeline( pipeline_name: str = "custom_pipeline", vector_db_config: dict = None, graph_db_config: dict = None, + incremental_loading: bool = True, ): # Note: These context variables allow different value assignment for databases in Cognee # per async task, thread, process and etc. 
@@ -106,6 +107,7 @@ async def cognee_pipeline( data=data, pipeline_name=pipeline_name, context={"dataset": dataset}, + incremental_loading=incremental_loading, ): yield run_info @@ -117,6 +119,7 @@ async def run_pipeline( data=None, pipeline_name: str = "custom_pipeline", context: dict = None, + incremental_loading=True, ): check_dataset_name(dataset.name) @@ -184,7 +187,9 @@ async def run_pipeline( if not isinstance(task, Task): raise ValueError(f"Task {task} is not an instance of Task") - pipeline_run = run_tasks(tasks, dataset_id, data, user, pipeline_name, context) + pipeline_run = run_tasks( + tasks, dataset_id, data, user, pipeline_name, context, incremental_loading + ) async for pipeline_run_info in pipeline_run: yield pipeline_run_info diff --git a/cognee/modules/pipelines/operations/run_tasks.py b/cognee/modules/pipelines/operations/run_tasks.py index 926d433fe..1f503f7d2 100644 --- a/cognee/modules/pipelines/operations/run_tasks.py +++ b/cognee/modules/pipelines/operations/run_tasks.py @@ -1,21 +1,31 @@ import os + +import asyncio from uuid import UUID from typing import Any from functools import wraps +from sqlalchemy import select +import cognee.modules.ingestion as ingestion from cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.databases.relational import get_relational_engine from cognee.modules.pipelines.operations.run_tasks_distributed import run_tasks_distributed from cognee.modules.users.models import User +from cognee.modules.data.models import Data +from cognee.infrastructure.files.utils.open_data_file import open_data_file from cognee.shared.logging_utils import get_logger from cognee.modules.users.methods import get_default_user from cognee.modules.pipelines.utils import generate_pipeline_id +from cognee.modules.pipelines.exceptions import PipelineRunFailedError +from cognee.tasks.ingestion import save_data_item_to_storage, resolve_data_directories from cognee.modules.pipelines.models.PipelineRunInfo import ( PipelineRunCompleted, PipelineRunErrored, PipelineRunStarted, PipelineRunYield, + PipelineRunAlreadyCompleted, ) +from cognee.modules.pipelines.models.DataItemStatus import DataItemStatus from cognee.modules.pipelines.operations import ( log_pipeline_run_start, @@ -56,34 +66,116 @@ async def run_tasks( user: User = None, pipeline_name: str = "unknown_pipeline", context: dict = None, + incremental_loading: bool = True, ): - if not user: - user = await get_default_user() + async def _run_tasks_data_item_incremental( + data_item, + dataset, + tasks, + pipeline_name, + pipeline_id, + pipeline_run_id, + context, + user, + ): + db_engine = get_relational_engine() + # If incremental_loading of data is set to True don't process documents already processed by pipeline + # If data is being added to Cognee for the first time calculate the id of the data + if not isinstance(data_item, Data): + file_path = await save_data_item_to_storage(data_item) + # Ingest data and add metadata + async with open_data_file(file_path) as file: + classified_data = ingestion.classify(file) + # data_id is the hash of file contents + owner id to avoid duplicate data + data_id = ingestion.identify(classified_data, user) + else: + # If data was already processed by Cognee get data id + data_id = data_item.id - # Get Dataset object - db_engine = get_relational_engine() - async with db_engine.get_async_session() as session: - from cognee.modules.data.models import Dataset + # Check pipeline status, if Data already processed for pipeline before skip current 
processing + async with db_engine.get_async_session() as session: + data_point = ( + await session.execute(select(Data).filter(Data.id == data_id)) + ).scalar_one_or_none() + if data_point: + if ( + data_point.pipeline_status.get(pipeline_name, {}).get(str(dataset.id)) + == DataItemStatus.DATA_ITEM_PROCESSING_COMPLETED + ): + yield { + "run_info": PipelineRunAlreadyCompleted( + pipeline_run_id=pipeline_run_id, + dataset_id=dataset.id, + dataset_name=dataset.name, + ), + "data_id": data_id, + } + return - dataset = await session.get(Dataset, dataset_id) + try: + # Process data based on data_item and list of tasks + async for result in run_tasks_with_telemetry( + tasks=tasks, + data=[data_item], + user=user, + pipeline_name=pipeline_id, + context=context, + ): + yield PipelineRunYield( + pipeline_run_id=pipeline_run_id, + dataset_id=dataset.id, + dataset_name=dataset.name, + payload=result, + ) - pipeline_id = generate_pipeline_id(user.id, dataset.id, pipeline_name) + # Update pipeline status for Data element + async with db_engine.get_async_session() as session: + data_point = ( + await session.execute(select(Data).filter(Data.id == data_id)) + ).scalar_one_or_none() + data_point.pipeline_status[pipeline_name] = { + str(dataset.id): DataItemStatus.DATA_ITEM_PROCESSING_COMPLETED + } + await session.merge(data_point) + await session.commit() - pipeline_run = await log_pipeline_run_start(pipeline_id, pipeline_name, dataset_id, data) + yield { + "run_info": PipelineRunCompleted( + pipeline_run_id=pipeline_run_id, + dataset_id=dataset.id, + dataset_name=dataset.name, + ), + "data_id": data_id, + } - pipeline_run_id = pipeline_run.pipeline_run_id + except Exception as error: + # Temporarily swallow error and try to process rest of documents first, then re-raise error at end of data ingestion pipeline + logger.error( + f"Exception caught while processing data: {error}.\n Data processing failed for data item: {data_item}." + ) + yield { + "run_info": PipelineRunErrored( + pipeline_run_id=pipeline_run_id, + payload=repr(error), + dataset_id=dataset.id, + dataset_name=dataset.name, + ), + "data_id": data_id, + } - yield PipelineRunStarted( - pipeline_run_id=pipeline_run_id, - dataset_id=dataset.id, - dataset_name=dataset.name, - payload=data, - ) - - try: + async def _run_tasks_data_item_regular( + data_item, + dataset, + tasks, + pipeline_id, + pipeline_run_id, + context, + user, + ): + # Process data based on data_item and list of tasks async for result in run_tasks_with_telemetry( tasks=tasks, - data=data, + data=[data_item], user=user, pipeline_name=pipeline_id, context=context, @@ -95,6 +187,112 @@ async def run_tasks( payload=result, ) + yield { + "run_info": PipelineRunCompleted( + pipeline_run_id=pipeline_run_id, + dataset_id=dataset.id, + dataset_name=dataset.name, + ) + } + + async def _run_tasks_data_item( + data_item, + dataset, + tasks, + pipeline_name, + pipeline_id, + pipeline_run_id, + context, + user, + incremental_loading, + ): + # Go through async generator and return data item processing result. 
Result can be PipelineRunAlreadyCompleted when data item is skipped, + # PipelineRunCompleted when processing was successful and PipelineRunErrored if there were issues + result = None + if incremental_loading: + async for result in _run_tasks_data_item_incremental( + data_item=data_item, + dataset=dataset, + tasks=tasks, + pipeline_name=pipeline_name, + pipeline_id=pipeline_id, + pipeline_run_id=pipeline_run_id, + context=context, + user=user, + ): + pass + else: + async for result in _run_tasks_data_item_regular( + data_item=data_item, + dataset=dataset, + tasks=tasks, + pipeline_id=pipeline_id, + pipeline_run_id=pipeline_run_id, + context=context, + user=user, + ): + pass + + return result + + if not user: + user = await get_default_user() + + # Get Dataset object + db_engine = get_relational_engine() + async with db_engine.get_async_session() as session: + from cognee.modules.data.models import Dataset + + dataset = await session.get(Dataset, dataset_id) + + pipeline_id = generate_pipeline_id(user.id, dataset.id, pipeline_name) + pipeline_run = await log_pipeline_run_start(pipeline_id, pipeline_name, dataset_id, data) + pipeline_run_id = pipeline_run.pipeline_run_id + + yield PipelineRunStarted( + pipeline_run_id=pipeline_run_id, + dataset_id=dataset.id, + dataset_name=dataset.name, + payload=data, + ) + + try: + if not isinstance(data, list): + data = [data] + + if incremental_loading: + data = await resolve_data_directories(data) + + # Create async tasks per data item that will run the pipeline for the data item + data_item_tasks = [ + asyncio.create_task( + _run_tasks_data_item( + data_item, + dataset, + tasks, + pipeline_name, + pipeline_id, + pipeline_run_id, + context, + user, + incremental_loading, + ) + ) + for data_item in data + ] + results = await asyncio.gather(*data_item_tasks) + # Remove skipped data items from results + results = [result for result in results if result] + + # If any data item could not be processed propagate error + errored_results = [ + result for result in results if isinstance(result["run_info"], PipelineRunErrored) + ] + if errored_results: + raise PipelineRunFailedError( + message="Pipeline run failed. Data item could not be processed." 
+ ) + await log_pipeline_run_complete( pipeline_run_id, pipeline_id, pipeline_name, dataset_id, data ) @@ -103,6 +301,7 @@ async def run_tasks( pipeline_run_id=pipeline_run_id, dataset_id=dataset.id, dataset_name=dataset.name, + data_ingestion_info=results, ) graph_engine = await get_graph_engine() @@ -120,9 +319,14 @@ async def run_tasks( yield PipelineRunErrored( pipeline_run_id=pipeline_run_id, - payload=error, + payload=repr(error), dataset_id=dataset.id, dataset_name=dataset.name, + data_ingestion_info=locals().get( + "results" + ), # Returns results if they exist or returns None ) - raise error + # In case of error during incremental loading of data just let the user know the pipeline Errored, don't raise error + if not isinstance(error, PipelineRunFailedError): + raise error diff --git a/cognee/tasks/documents/extract_chunks_from_documents.py b/cognee/tasks/documents/extract_chunks_from_documents.py index 1d1870d98..216185495 100644 --- a/cognee/tasks/documents/extract_chunks_from_documents.py +++ b/cognee/tasks/documents/extract_chunks_from_documents.py @@ -8,7 +8,6 @@ from cognee.modules.data.models import Data from cognee.infrastructure.databases.relational import get_relational_engine from cognee.modules.chunking.TextChunker import TextChunker from cognee.modules.chunking.Chunker import Chunker -from cognee.modules.data.processing.document_types.exceptions.exceptions import PyPdfInternalError async def update_document_token_count(document_id: UUID, token_count: int) -> None: @@ -40,15 +39,14 @@ async def extract_chunks_from_documents( """ for document in documents: document_token_count = 0 - try: - async for document_chunk in document.read( - max_chunk_size=max_chunk_size, chunker_cls=chunker - ): - document_token_count += document_chunk.chunk_size - document_chunk.belongs_to_set = document.belongs_to_set - yield document_chunk - await update_document_token_count(document.id, document_token_count) - except PyPdfInternalError: - pass + async for document_chunk in document.read( + max_chunk_size=max_chunk_size, chunker_cls=chunker + ): + document_token_count += document_chunk.chunk_size + document_chunk.belongs_to_set = document.belongs_to_set + yield document_chunk + + await update_document_token_count(document.id, document_token_count) + # todo rita diff --git a/cognee/tasks/ingestion/ingest_data.py b/cognee/tasks/ingestion/ingest_data.py index 846c183d4..429e04e5d 100644 --- a/cognee/tasks/ingestion/ingest_data.py +++ b/cognee/tasks/ingestion/ingest_data.py @@ -5,12 +5,12 @@ from uuid import UUID from typing import Union, BinaryIO, Any, List, Optional import cognee.modules.ingestion as ingestion -from cognee.infrastructure.files.utils.open_data_file import open_data_file from cognee.infrastructure.databases.relational import get_relational_engine from cognee.modules.data.models import Data from cognee.modules.users.models import User from cognee.modules.users.methods import get_default_user from cognee.modules.users.permissions.methods import get_specific_user_permission_datasets +from cognee.infrastructure.files.utils.open_data_file import open_data_file from cognee.modules.data.methods import ( get_authorized_existing_datasets, get_dataset_data, @@ -134,6 +134,7 @@ async def ingest_data( node_set=json.dumps(node_set) if node_set else None, data_size=file_metadata["file_size"], tenant_id=user.tenant_id if user.tenant_id else None, + pipeline_status={}, token_count=-1, ) diff --git a/cognee/tasks/ingestion/resolve_data_directories.py 
b/cognee/tasks/ingestion/resolve_data_directories.py index dfabcea0b..0f2f2a85f 100644 --- a/cognee/tasks/ingestion/resolve_data_directories.py +++ b/cognee/tasks/ingestion/resolve_data_directories.py @@ -40,6 +40,9 @@ async def resolve_data_directories( if include_subdirectories: base_path = item if item.endswith("/") else item + "/" s3_keys = fs.glob(base_path + "**") + # If path is not directory attempt to add item directly + if not s3_keys: + s3_keys = fs.ls(item) else: s3_keys = fs.ls(item) # Filter out keys that represent directories using fs.isdir diff --git a/cognee/tasks/repo_processor/get_repo_file_dependencies.py b/cognee/tasks/repo_processor/get_repo_file_dependencies.py index 232850936..b0cdb4402 100644 --- a/cognee/tasks/repo_processor/get_repo_file_dependencies.py +++ b/cognee/tasks/repo_processor/get_repo_file_dependencies.py @@ -103,6 +103,9 @@ async def get_repo_file_dependencies( extraction of dependencies (default is False). (default False) """ + if isinstance(repo_path, list) and len(repo_path) == 1: + repo_path = repo_path[0] + if not os.path.exists(repo_path): raise FileNotFoundError(f"Repository path {repo_path} does not exist.") diff --git a/cognee/tests/test_deduplication.py b/cognee/tests/test_deduplication.py index c449719c7..bef813317 100644 --- a/cognee/tests/test_deduplication.py +++ b/cognee/tests/test_deduplication.py @@ -26,8 +26,8 @@ async def test_deduplication(): explanation_file_path2 = os.path.join( pathlib.Path(__file__).parent, "test_data/Natural_language_processing_copy.txt" ) - await cognee.add([explanation_file_path], dataset_name) - await cognee.add([explanation_file_path2], dataset_name2) + await cognee.add([explanation_file_path], dataset_name, incremental_loading=False) + await cognee.add([explanation_file_path2], dataset_name2, incremental_loading=False) result = await relational_engine.get_all_data_from_table("data") assert len(result) == 1, "More than one data entity was found." From 5b6e946c436c12e010f37f8f65ba795535dadf31 Mon Sep 17 00:00:00 2001 From: Igor Ilic <30923996+dexters1@users.noreply.github.com> Date: Fri, 1 Aug 2025 15:12:04 +0200 Subject: [PATCH 6/8] fix: Add async lock for dynamic vector table creation (#1175) ## Description Add async lock for dynamic table creation ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin. 
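The race this fixes is two coroutines both seeing an index as missing and both issuing a create call for the same table. A minimal sketch of the pattern, assuming the `vector_engine.create_vector_index` interface shown in the diff below (an illustration, not the exact module):

```python
import asyncio

# A single lock shared by all coroutines in the process, so only one of them
# is inside the "check then create" critical section at a time.
vector_index_lock = asyncio.Lock()
created_indexes: dict = {}


async def ensure_vector_index(vector_engine, type_name: str, field_name: str) -> None:
    index_name = f"{type_name}_{field_name}"
    async with vector_index_lock:
        # Re-check inside the lock: another coroutine may have created the
        # index while this one was waiting to acquire the lock.
        if index_name not in created_indexes:
            await vector_engine.create_vector_index(type_name, field_name)
            created_indexes[index_name] = True
```

Note that an `asyncio.Lock` only serializes coroutines within one event loop, so concurrent worker processes would still need a database-side guard.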
--- cognee/tasks/storage/index_data_points.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/cognee/tasks/storage/index_data_points.py b/cognee/tasks/storage/index_data_points.py index 2813d9c54..452e7f2ac 100644 --- a/cognee/tasks/storage/index_data_points.py +++ b/cognee/tasks/storage/index_data_points.py @@ -1,3 +1,5 @@ +import asyncio + from cognee.shared.logging_utils import get_logger from cognee.infrastructure.databases.exceptions.EmbeddingException import EmbeddingException @@ -6,6 +8,9 @@ from cognee.infrastructure.engine import DataPoint logger = get_logger("index_data_points") +# A single lock shared by all coroutines +vector_index_lock = asyncio.Lock() + async def index_data_points(data_points: list[DataPoint]): created_indexes = {} @@ -22,9 +27,11 @@ async def index_data_points(data_points: list[DataPoint]): index_name = f"{data_point_type.__name__}_{field_name}" - if index_name not in created_indexes: - await vector_engine.create_vector_index(data_point_type.__name__, field_name) - created_indexes[index_name] = True + # Add async lock to make sure two different coroutines won't create a table at the same time + async with vector_index_lock: + if index_name not in created_indexes: + await vector_engine.create_vector_index(data_point_type.__name__, field_name) + created_indexes[index_name] = True if index_name not in index_points: index_points[index_name] = [] From 9faa47fc5a2d380a8d0c4e148cef55a26017d4a7 Mon Sep 17 00:00:00 2001 From: Igor Ilic <30923996+dexters1@users.noreply.github.com> Date: Fri, 1 Aug 2025 16:37:53 +0200 Subject: [PATCH 7/8] feat: add default tokenizer in case hugging face is not available (#1177) ## Description Add default tokenizer for custom models not available on HuggingFace ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin. 
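The fallback path is: try to build a `HuggingFaceTokenizer` for the configured model, and if that fails, construct a `TikTokenTokenizer` without a model, which now defaults to the generic `cl100k_base` encoding. A condensed sketch of the adapter-side default, using only documented `tiktoken` calls (an illustration of the change, not the full adapter):

```python
from typing import Optional

import tiktoken


def load_encoding(model: Optional[str] = None):
    """Prefer the model-specific encoding; otherwise use a generic default."""
    if model:
        # Works for model names tiktoken knows about; raises KeyError otherwise.
        return tiktoken.encoding_for_model(model)
    # Default used when no model is provided (mirrors the patched adapter).
    return tiktoken.get_encoding("cl100k_base")


token_ids = load_encoding(None).encode("hello world")  # token ids from the fallback encoding
```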
--- .../vector/embeddings/LiteLLMEmbeddingEngine.py | 7 ++++++- .../infrastructure/llm/tokenizer/TikToken/adapter.py | 10 +++++++--- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py b/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py index b51d397ed..54f319be3 100644 --- a/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +++ b/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py @@ -177,7 +177,12 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine): elif "mistral" in self.provider.lower(): tokenizer = MistralTokenizer(model=model, max_tokens=self.max_tokens) else: - tokenizer = HuggingFaceTokenizer(model=self.model, max_tokens=self.max_tokens) + try: + tokenizer = HuggingFaceTokenizer(model=self.model, max_tokens=self.max_tokens) + except Exception as e: + logger.warning(f"Could not get tokenizer from HuggingFace due to: {e}") + logger.info("Switching to TikToken default tokenizer.") + tokenizer = TikTokenTokenizer(model=None, max_tokens=self.max_tokens) logger.debug(f"Tokenizer loaded for model: {self.model}") return tokenizer diff --git a/cognee/infrastructure/llm/tokenizer/TikToken/adapter.py b/cognee/infrastructure/llm/tokenizer/TikToken/adapter.py index 881ffaba7..8806112c3 100644 --- a/cognee/infrastructure/llm/tokenizer/TikToken/adapter.py +++ b/cognee/infrastructure/llm/tokenizer/TikToken/adapter.py @@ -1,4 +1,4 @@ -from typing import List, Any +from typing import List, Any, Optional import tiktoken from ..tokenizer_interface import TokenizerInterface @@ -12,13 +12,17 @@ class TikTokenTokenizer(TokenizerInterface): def __init__( self, - model: str, + model: Optional[str] = None, max_tokens: int = 8191, ): self.model = model self.max_tokens = max_tokens # Initialize TikToken for GPT based on model - self.tokenizer = tiktoken.encoding_for_model(self.model) + if model: + self.tokenizer = tiktoken.encoding_for_model(self.model) + else: + # Use default if model not provided + self.tokenizer = tiktoken.get_encoding("cl100k_base") def extract_tokens(self, text: str) -> List[Any]: """ From fc7a91d99178c0b166e7d6c4e7a29e08040cceca Mon Sep 17 00:00:00 2001 From: EricXiao <7250816+EricXiao95@users.noreply.github.com> Date: Sat, 2 Aug 2025 22:30:08 +0800 Subject: [PATCH 8/8] feature: implement FEELING_LUCKY search type (#1178) ## Description This PR implements the 'FEELING_LUCKY' search type, which intelligently routes user queries to the most appropriate search retriever, addressing [#1162](https://github.com/topoteretes/cognee/issues/1162). - implement new search type FEELING_LUCKY - Add the select_search_type function to analyze queries and choose the proper search type - Integrate with an LLM for intelligent search type determination - Add logging for the search type selection process - Support fallback to RAG_COMPLETION when the LLM selection fails - Add tests for the new search type ## How it works When a user selects the 'FEELING_LUCKY' search type, the system first sends their natural language query to an LLM-based classifier. This classifier analyzes the query's intent (e.g., is it asking for a relationship, a summary, or a factual answer?) and selects the optimal SearchType, such as 'INSIGHTS' or 'GRAPH_COMPLETION'. The main search function then proceeds using this dynamically selected type. If the classification process fails, it gracefully falls back to the default 'RAG_COMPLETION' type. 
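A condensed sketch of that routing step, trimmed from the `select_search_type` implementation added in the diff below (logging omitted for brevity):

```python
from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.infrastructure.llm.prompts import read_query_prompt
from cognee.modules.search.types import SearchType


async def route_query(query: str) -> SearchType:
    """Ask the LLM classifier for a SearchType name; fall back to RAG_COMPLETION."""
    system_prompt = read_query_prompt("search_type_selector_prompt.txt")
    try:
        response = await get_llm_client().acreate_structured_output(
            text_input=query, system_prompt=system_prompt, response_model=str
        )
        if response.upper() in SearchType.__members__:
            return SearchType(response.upper())
    except Exception:
        pass  # classification failed; use the safe default below
    return SearchType.RAG_COMPLETION
```

`specific_search` then rewrites a `FEELING_LUCKY` query type to the returned value before the retriever lookup, so the individual retrievers stay unchanged.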
## Testing Tests can be run with: ```bash python -m pytest cognee/tests/unit/modules/search/search_methods_test.py -k "feeling_lucky" -v ``` ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin. Signed-off-by: EricXiao --- cognee/api/v1/search/search.py | 10 ++ .../prompts/search_type_selector_prompt.txt | 130 ++++++++++++++++++ cognee/modules/search/methods/search.py | 6 +- cognee/modules/search/operations/__init__.py | 1 + .../search/operations/select_search_type.py | 43 ++++++ cognee/modules/search/types/SearchType.py | 1 + .../modules/search/search_methods_test.py | 55 ++++++++ 7 files changed, 245 insertions(+), 1 deletion(-) create mode 100644 cognee/infrastructure/llm/prompts/search_type_selector_prompt.txt create mode 100644 cognee/modules/search/operations/select_search_type.py diff --git a/cognee/api/v1/search/search.py b/cognee/api/v1/search/search.py index eb245f545..66ce48cc2 100644 --- a/cognee/api/v1/search/search.py +++ b/cognee/api/v1/search/search.py @@ -71,6 +71,12 @@ async def search( Best for: Advanced users, specific graph traversals, debugging. Returns: Raw graph query results. + **FEELING_LUCKY**: + Intelligently selects and runs the most appropriate search type. + Best for: General-purpose queries or when you're unsure which search type is best. + Returns: The results from the automatically selected search type. + + Args: query_text: Your question or search query in natural language. Examples: @@ -119,6 +125,9 @@ async def search( **CODE**: [List of structured code information with context] + **FEELING_LUCKY**: + [List of results in the format of the search type that is automatically selected] + @@ -130,6 +139,7 @@ async def search( - **CHUNKS**: Fastest, pure vector similarity search without LLM - **SUMMARIES**: Fast, returns pre-computed summaries - **CODE**: Medium speed, specialized for code understanding + - **FEELING_LUCKY**: Variable speed, uses LLM + search type selection intelligently - **top_k**: Start with 10, increase for comprehensive analysis (max 100) - **datasets**: Specify datasets to improve speed and relevance diff --git a/cognee/infrastructure/llm/prompts/search_type_selector_prompt.txt b/cognee/infrastructure/llm/prompts/search_type_selector_prompt.txt new file mode 100644 index 000000000..7ed2e72fc --- /dev/null +++ b/cognee/infrastructure/llm/prompts/search_type_selector_prompt.txt @@ -0,0 +1,130 @@ +You are an expert query analyzer for a **GraphRAG system**. Your primary goal is to analyze a user's query and select the single most appropriate `SearchType` tool to answer it. + +Here are the available `SearchType` tools and their specific functions: + +- **`SUMMARIES`**: The `SUMMARIES` search type retrieves summarized information from the knowledge graph. + + **Best for:** + + - Getting concise overviews of topics + - Summarizing large amounts of information + - Quick understanding of complex subjects + +* **`INSIGHTS`**: The `INSIGHTS` search type discovers connections and relationships between entities in the knowledge graph. + + **Best for:** + + - Discovering how entities are connected + - Understanding relationships between concepts + - Exploring the structure of your knowledge graph + +* **`CHUNKS`**: The `CHUNKS` search type retrieves specific facts and information chunks from the knowledge graph. 
+ + **Best for:** + + - Finding specific facts + - Getting direct answers to questions + - Retrieving precise information + +* **`RAG_COMPLETION`**: Use for direct factual questions that can likely be answered by retrieving a specific text passage from a document. It does not use the graph's relationship structure. + + **Best for:** + + - Getting detailed explanations or comprehensive answers + - Combining multiple pieces of information + - Getting a single, coherent answer that is generated from relevant text passages + +* **`GRAPH_COMPLETION`**: The `GRAPH_COMPLETION` search type leverages the graph structure to provide more contextually aware completions. + + **Best for:** + + - Complex queries requiring graph traversal + - Questions that benefit from understanding relationships + - Queries where context from connected entities matters + +* **`GRAPH_SUMMARY_COMPLETION`**: The `GRAPH_SUMMARY_COMPLETION` search type combines graph traversal with summarization to provide concise but comprehensive answers. + + **Best for:** + + - Getting summarized information that requires understanding relationships + - Complex topics that need concise explanations + - Queries that benefit from both graph structure and summarization + +* **`GRAPH_COMPLETION_COT`**: The `GRAPH_COMPLETION_COT` search type combines graph traversal with chain of thought to provide answers to complex multi hop questions. + + **Best for:** + + - Multi-hop questions that require following several linked concepts or entities + - Tracing relational paths in a knowledge graph while also getting clear step-by-step reasoning + - Summarizing completx linkages into a concise, human-readable answer once all hops have been explored + +* **`GRAPH_COMPLETION_CONTEXT_EXTENSION`**: The `GRAPH_COMPLETION_CONTEXT_EXTENSION` search type combines graph traversal with multi-round context extension. + + **Best for:** + + - Iterative, multi-hop queries where intermediate facts aren’t all present upfront + - Complex linkages that benefit from multi-round “search → extend context → reason” loops to uncover deep connections. + - Sparse or evolving graphs that require on-the-fly expansion—issuing follow-up searches to discover missing nodes or properties + +* **`CODE`**: The `CODE` search type is specialized for retrieving and understanding code-related information from the knowledge graph. + + **Best for:** + + - Code-related queries + - Programming examples and patterns + - Technical documentation searches + +* **`CYPHER`**: The `CYPHER` search type allows user to execute raw Cypher queries directly against your graph database. + + **Best for:** + + - Executing precise graph queries with full control + - Leveraging Cypher features and functions + - Getting raw data directly from the graph database + +* **`NATURAL_LANGUAGE`**: The `NATURAL_LANGUAGE` search type translates a natural language question into a precise Cypher query that is executed directly against the graph database. + + **Best for:** + + - Getting precise, structured answers from the graph using natural language. + - Performing advanced graph operations like filtering and aggregating data using natural language. + - Asking precise, database-style questions without needing to write Cypher. + +**Examples:** + +Query: "Summarize the key findings from these research papers" +Response: `SUMMARIES` + +Query: "What is the relationship between the methodologies used in these papers?" +Response: `INSIGHTS` + +Query: "When was Einstein born?" 
+Response: `CHUNKS` + +Query: "Explain Einstein's contributions to physics" +Response: `RAG_COMPLETION` + +Query: "Provide a comprehensive analysis of how these papers contribute to the field" +Response: `GRAPH_COMPLETION` + +Query: "Explain the overall architecture of this codebase" +Response: `GRAPH_SUMMARY_COMPLETION` + +Query: "Who was the father of the person who invented the lightbulb" +Response: `GRAPH_COMPLETION_COT` + +Query: "What county was XY born in" +Response: `GRAPH_COMPLETION_CONTEXT_EXTENSION` + +Query: "How to implement authentication in this codebase" +Response: `CODE` + +Query: "MATCH (n) RETURN labels(n) as types, n.name as name LIMIT 10" +Response: `CYPHER` + +Query: "Get all nodes connected to John" +Response: `NATURAL_LANGUAGE` + + + +Your response MUST be a single word, consisting of only the chosen `SearchType` name. Do not provide any explanation. \ No newline at end of file diff --git a/cognee/modules/search/methods/search.py b/cognee/modules/search/methods/search.py index 1eff23c4a..365920019 100644 --- a/cognee/modules/search/methods/search.py +++ b/cognee/modules/search/methods/search.py @@ -27,7 +27,7 @@ from cognee.modules.users.models import User from cognee.modules.data.models import Dataset from cognee.shared.utils import send_telemetry from cognee.modules.users.permissions.methods import get_specific_user_permission_datasets -from cognee.modules.search.operations import log_query, log_result +from cognee.modules.search.operations import log_query, log_result, select_search_type async def search( @@ -129,6 +129,10 @@ async def specific_search( SearchType.NATURAL_LANGUAGE: NaturalLanguageRetriever().get_completion, } + # If the query type is FEELING_LUCKY, select the search type intelligently + if query_type is SearchType.FEELING_LUCKY: + query_type = await select_search_type(query) + search_task = search_tasks.get(query_type) if search_task is None: diff --git a/cognee/modules/search/operations/__init__.py b/cognee/modules/search/operations/__init__.py index 41d2a4e4a..b2f9567fb 100644 --- a/cognee/modules/search/operations/__init__.py +++ b/cognee/modules/search/operations/__init__.py @@ -1,3 +1,4 @@ from .log_query import log_query from .log_result import log_result from .get_history import get_history +from .select_search_type import select_search_type diff --git a/cognee/modules/search/operations/select_search_type.py b/cognee/modules/search/operations/select_search_type.py new file mode 100644 index 000000000..d08074d0d --- /dev/null +++ b/cognee/modules/search/operations/select_search_type.py @@ -0,0 +1,43 @@ +from cognee.infrastructure.llm.get_llm_client import get_llm_client +from cognee.infrastructure.llm.prompts import read_query_prompt +from cognee.modules.search.types import SearchType +from cognee.shared.logging_utils import get_logger + +logger = get_logger("SearchTypeSelector") + + +async def select_search_type( + query: str, + system_prompt_path: str = "search_type_selector_prompt.txt", +) -> SearchType: + """ + Analyzes the query and Selects the best search type. + + Args: + query: The query to analyze. + system_prompt_path: The path to the system prompt. + + Returns: + The best search type given by the LLM. 
+ """ + default_search_type = SearchType.RAG_COMPLETION + system_prompt = read_query_prompt(system_prompt_path) + llm_client = get_llm_client() + + try: + response = await llm_client.acreate_structured_output( + text_input=query, + system_prompt=system_prompt, + response_model=str, + ) + + if response.upper() in SearchType.__members__: + logger.info(f"Selected lucky search type: {response.upper()}") + return SearchType(response.upper()) + + # If the response is not a valid search type, return the default search type + logger.info(f"LLM gives an invalid search type: {response.upper()}") + return default_search_type + except Exception as e: + logger.error(f"Failed to select search type intelligently from LLM: {str(e)}") + return default_search_type diff --git a/cognee/modules/search/types/SearchType.py b/cognee/modules/search/types/SearchType.py index 1c672f0f0..8248117e7 100644 --- a/cognee/modules/search/types/SearchType.py +++ b/cognee/modules/search/types/SearchType.py @@ -13,3 +13,4 @@ class SearchType(Enum): NATURAL_LANGUAGE = "NATURAL_LANGUAGE" GRAPH_COMPLETION_COT = "GRAPH_COMPLETION_COT" GRAPH_COMPLETION_CONTEXT_EXTENSION = "GRAPH_COMPLETION_CONTEXT_EXTENSION" + FEELING_LUCKY = "FEELING_LUCKY" diff --git a/cognee/tests/unit/modules/search/search_methods_test.py b/cognee/tests/unit/modules/search/search_methods_test.py index bec362144..14712f6d2 100644 --- a/cognee/tests/unit/modules/search/search_methods_test.py +++ b/cognee/tests/unit/modules/search/search_methods_test.py @@ -155,6 +155,61 @@ async def test_specific_search_chunks(mock_send_telemetry, mock_chunks_retriever assert results[0]["content"] == "Chunk result" +@pytest.mark.asyncio +@pytest.mark.parametrize( + "selected_type, retriever_name, expected_content, top_k", + [ + (SearchType.RAG_COMPLETION, "CompletionRetriever", "RAG result from lucky search", 10), + (SearchType.CHUNKS, "ChunksRetriever", "Chunk result from lucky search", 5), + (SearchType.SUMMARIES, "SummariesRetriever", "Summary from lucky search", 15), + (SearchType.INSIGHTS, "InsightsRetriever", "Insight result from lucky search", 20), + ], +) +@patch.object(search_module, "select_search_type") +@patch.object(search_module, "send_telemetry") +async def test_specific_search_feeling_lucky( + mock_send_telemetry, + mock_select_search_type, + selected_type, + retriever_name, + expected_content, + top_k, + mock_user, +): + with patch.object(search_module, retriever_name) as mock_retriever_class: + # Setup + query = f"test query for {retriever_name}" + query_type = SearchType.FEELING_LUCKY + + # Mock the intelligent search type selection + mock_select_search_type.return_value = selected_type + + # Mock the retriever + mock_retriever_instance = MagicMock() + mock_retriever_instance.get_completion = AsyncMock( + return_value=[{"content": expected_content}] + ) + mock_retriever_class.return_value = mock_retriever_instance + + # Execute + results = await specific_search(query_type, query, mock_user, top_k=top_k) + + # Verify + mock_select_search_type.assert_called_once_with(query) + + if retriever_name == "CompletionRetriever": + mock_retriever_class.assert_called_once_with( + system_prompt_path="answer_simple_question.txt", top_k=top_k + ) + else: + mock_retriever_class.assert_called_once_with(top_k=top_k) + + mock_retriever_instance.get_completion.assert_called_once_with(query) + mock_send_telemetry.assert_called() + assert len(results) == 1 + assert results[0]["content"] == expected_content + + @pytest.mark.asyncio async def 
test_specific_search_invalid_type(mock_user): # Setup