diff --git a/cognee/api/v1/add/add.py b/cognee/api/v1/add/add.py index 65394f1ec..0f14683f9 100644 --- a/cognee/api/v1/add/add.py +++ b/cognee/api/v1/add/add.py @@ -41,6 +41,7 @@ async def add( extraction_rules: Optional[Dict[str, Any]] = None, tavily_config: Optional[BaseModel] = None, soup_crawler_config: Optional[BaseModel] = None, + data_per_batch: Optional[int] = 20, ): """ Add data to Cognee for knowledge graph processing. @@ -235,6 +236,7 @@ async def add( vector_db_config=vector_db_config, graph_db_config=graph_db_config, incremental_loading=incremental_loading, + data_per_batch=data_per_batch, ): pipeline_run_info = run_info diff --git a/cognee/api/v1/cognify/cognify.py b/cognee/api/v1/cognify/cognify.py index c3045f00a..1eb266765 100644 --- a/cognee/api/v1/cognify/cognify.py +++ b/cognee/api/v1/cognify/cognify.py @@ -51,6 +51,7 @@ async def cognify( incremental_loading: bool = True, custom_prompt: Optional[str] = None, temporal_cognify: bool = False, + data_per_batch: int = 20, ): """ Transform ingested data into a structured knowledge graph. @@ -228,6 +229,7 @@ async def cognify( graph_db_config=graph_db_config, incremental_loading=incremental_loading, pipeline_name="cognify_pipeline", + data_per_batch=data_per_batch, ) diff --git a/cognee/api/v1/search/search.py b/cognee/api/v1/search/search.py index 0a9e76e96..9f158e9d0 100644 --- a/cognee/api/v1/search/search.py +++ b/cognee/api/v1/search/search.py @@ -1,6 +1,7 @@ from uuid import UUID from typing import Union, Optional, List, Type +from cognee.infrastructure.databases.graph import get_graph_engine from cognee.modules.engine.models.node_set import NodeSet from cognee.modules.users.models import User from cognee.modules.search.types import SearchResult, SearchType, CombinedSearchResult @@ -8,6 +9,9 @@ from cognee.modules.users.methods import get_default_user from cognee.modules.search.methods import search as search_function from cognee.modules.data.methods import get_authorized_existing_datasets from cognee.modules.data.exceptions import DatasetNotFoundError +from cognee.shared.logging_utils import get_logger + +logger = get_logger() async def search( @@ -175,6 +179,13 @@ async def search( if not datasets: raise DatasetNotFoundError(message="No datasets found.") + graph_engine = await get_graph_engine() + is_empty = await graph_engine.is_empty() + + if is_empty: + logger.warning("Search attempt on an empty knowledge graph") + return [] + filtered_search_results = await search_function( query_text=query_text, query_type=query_type, diff --git a/cognee/infrastructure/databases/graph/graph_db_interface.py b/cognee/infrastructure/databases/graph/graph_db_interface.py index 15f5e3df3..29ecc22d1 100644 --- a/cognee/infrastructure/databases/graph/graph_db_interface.py +++ b/cognee/infrastructure/databases/graph/graph_db_interface.py @@ -39,6 +39,11 @@ class GraphDBInterface(ABC): - get_connections """ + @abstractmethod + async def is_empty(self) -> bool: + logger.warning("is_empty() is not implemented") + return True + @abstractmethod async def query(self, query: str, params: dict) -> List[Any]: """ diff --git a/cognee/infrastructure/databases/graph/kuzu/adapter.py b/cognee/infrastructure/databases/graph/kuzu/adapter.py index 9b154fd47..d8ffe4fdd 100644 --- a/cognee/infrastructure/databases/graph/kuzu/adapter.py +++ b/cognee/infrastructure/databases/graph/kuzu/adapter.py @@ -197,6 +197,15 @@ class KuzuAdapter(GraphDBInterface): except FileNotFoundError: logger.warning(f"Kuzu S3 storage file not found: {self.db_path}") + async def is_empty(self) -> bool: + query = """ + MATCH (n) + RETURN true + LIMIT 1; + """ + query_result = await self.query(query) + return len(query_result) == 0 + async def query(self, query: str, params: Optional[dict] = None) -> List[Tuple]: """ Execute a Kuzu query asynchronously with automatic reconnection. diff --git a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py index e1097451b..82c4311f2 100644 --- a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +++ b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py @@ -86,6 +86,15 @@ class Neo4jAdapter(GraphDBInterface): async with self.driver.session(database=self.graph_database_name) as session: yield session + async def is_empty(self) -> bool: + query = """ + RETURN EXISTS { + MATCH (n) + } AS node_exists; + """ + query_result = await self.query(query) + return not query_result[0]["node_exists"] + @deadlock_retry() async def query( self, diff --git a/cognee/modules/pipelines/operations/pipeline.py b/cognee/modules/pipelines/operations/pipeline.py index eec5874ad..40b52ed2d 100644 --- a/cognee/modules/pipelines/operations/pipeline.py +++ b/cognee/modules/pipelines/operations/pipeline.py @@ -36,6 +36,7 @@ async def run_pipeline( graph_db_config: Optional[dict] = None, incremental_loading: bool = False, context: Optional[Dict] = None, + data_per_batch: int = 20, ): validate_pipeline_tasks(tasks) await setup_and_check_environment(vector_db_config, graph_db_config) @@ -51,6 +52,7 @@ async def run_pipeline( pipeline_name=pipeline_name, incremental_loading=incremental_loading, context=context, + data_per_batch=data_per_batch, ): yield run_info @@ -63,6 +65,7 @@ async def run_pipeline_per_dataset( pipeline_name: str = "custom_pipeline", incremental_loading=False, context: Optional[Dict] = None, + data_per_batch: int = 20, ): # Will only be used if ENABLE_BACKEND_ACCESS_CONTROL is set to True await set_database_global_context_variables(dataset.id, dataset.owner_id) @@ -78,7 +81,7 @@ async def run_pipeline_per_dataset( return pipeline_run = run_tasks( - tasks, dataset, data, user, pipeline_name, context, incremental_loading + tasks, dataset, data, user, pipeline_name, context, incremental_loading, data_per_batch ) async for pipeline_run_info in pipeline_run: diff --git a/cognee/modules/pipelines/operations/run_tasks.py b/cognee/modules/pipelines/operations/run_tasks.py index 9a9cf3dcc..5a757aa5a 100644 --- a/cognee/modules/pipelines/operations/run_tasks.py +++ b/cognee/modules/pipelines/operations/run_tasks.py @@ -59,6 +59,7 @@ async def run_tasks( pipeline_name: str = "unknown_pipeline", context: Optional[Dict] = None, incremental_loading: bool = False, + data_per_batch: int = 20, ): if not user: user = await get_default_user() @@ -81,29 +82,34 @@ async def run_tasks( if incremental_loading: data = await resolve_data_directories(data) - # Create async tasks per data item that will run the pipeline for the data item - data_item_tasks = [ - asyncio.create_task( - run_tasks_data_item( - data_item, - dataset, - tasks, - pipeline_name, - pipeline_id, - pipeline_run_id, - { - **(context or {}), - "user": user, - "data": data_item, - "dataset": dataset, - }, - user, - incremental_loading, + # Create and gather batches of async tasks of data items that will run the pipeline for the data item + results = [] + for start in range(0, len(data), data_per_batch): + data_batch = data[start : start + data_per_batch] + + data_item_tasks = [ + asyncio.create_task( + run_tasks_data_item( + data_item, + dataset, + tasks, + pipeline_name, + pipeline_id, + pipeline_run_id, + { + **(context or {}), + "user": user, + "data": data_item, + "dataset": dataset, + }, + user, + incremental_loading, + ) ) - ) - for data_item in data - ] - results = await asyncio.gather(*data_item_tasks) + for data_item in data_batch + ] + + results.extend(await asyncio.gather(*data_item_tasks)) # Remove skipped data items from results results = [result for result in results if result] diff --git a/cognee/tests/test_kuzu.py b/cognee/tests/test_kuzu.py index 8749e42d0..fe9da6dcb 100644 --- a/cognee/tests/test_kuzu.py +++ b/cognee/tests/test_kuzu.py @@ -47,10 +47,26 @@ async def main(): pathlib.Path(__file__).parent, "test_data/Quantum_computers.txt" ) + from cognee.infrastructure.databases.graph import get_graph_engine + + graph_engine = await get_graph_engine() + + is_empty = await graph_engine.is_empty() + + assert is_empty, "Kuzu graph database is not empty" + await cognee.add([explanation_file_path_quantum], dataset_name) + is_empty = await graph_engine.is_empty() + + assert is_empty, "Kuzu graph database should be empty before cognify" + await cognee.cognify([dataset_name]) + is_empty = await graph_engine.is_empty() + + assert not is_empty, "Kuzu graph database should not be empty" + from cognee.infrastructure.databases.vector import get_vector_engine vector_engine = get_vector_engine() @@ -114,11 +130,10 @@ async def main(): assert not os.path.isdir(data_root_directory), "Local data files are not deleted" await cognee.prune.prune_system(metadata=True) - from cognee.infrastructure.databases.graph import get_graph_engine - graph_engine = await get_graph_engine() - nodes, edges = await graph_engine.get_graph_data() - assert len(nodes) == 0 and len(edges) == 0, "Kuzu graph database is not empty" + is_empty = await graph_engine.is_empty() + + assert is_empty, "Kuzu graph database is not empty" finally: # Ensure cleanup even if tests fail diff --git a/cognee/tests/test_neo4j.py b/cognee/tests/test_neo4j.py index c74b4ab65..925614e67 100644 --- a/cognee/tests/test_neo4j.py +++ b/cognee/tests/test_neo4j.py @@ -35,6 +35,14 @@ async def main(): explanation_file_path_nlp = os.path.join( pathlib.Path(__file__).parent, "test_data/Natural_language_processing.txt" ) + from cognee.infrastructure.databases.graph import get_graph_engine + + graph_engine = await get_graph_engine() + + is_empty = await graph_engine.is_empty() + + assert is_empty, "Graph has to be empty" + await cognee.add([explanation_file_path_nlp], dataset_name) explanation_file_path_quantum = os.path.join( @@ -42,9 +50,16 @@ async def main(): ) await cognee.add([explanation_file_path_quantum], dataset_name) + is_empty = await graph_engine.is_empty() + + assert is_empty, "Graph has to be empty before cognify" await cognee.cognify([dataset_name]) + is_empty = await graph_engine.is_empty() + + assert not is_empty, "Graph shouldn't be empty" + from cognee.infrastructure.databases.vector import get_vector_engine vector_engine = get_vector_engine() @@ -117,11 +132,8 @@ async def main(): assert not os.path.isdir(data_root_directory), "Local data files are not deleted" await cognee.prune.prune_system(metadata=True) - from cognee.infrastructure.databases.graph import get_graph_engine - - graph_engine = await get_graph_engine() - nodes, edges = await graph_engine.get_graph_data() - assert len(nodes) == 0 and len(edges) == 0, "Neo4j graph database is not empty" + is_empty = await graph_engine.is_empty() + assert is_empty, "Neo4j graph database is not empty" if __name__ == "__main__": diff --git a/cognee/tests/unit/api/test_search.py b/cognee/tests/unit/api/test_search.py new file mode 100644 index 000000000..54a4cc35f --- /dev/null +++ b/cognee/tests/unit/api/test_search.py @@ -0,0 +1,21 @@ +import pytest +import cognee + + +@pytest.mark.asyncio +async def test_empty_search_raises_SearchOnEmptyGraphError_on_empty_graph(): + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await cognee.add("Sample input") + result = await cognee.search("Sample query") + assert result == [] + + +@pytest.mark.asyncio +async def test_empty_search_doesnt_raise_SearchOnEmptyGraphError(): + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await cognee.add("Sample input") + await cognee.cognify() + result = await cognee.search("Sample query") + assert result != []