Merge remote-tracking branch 'origin/dev' into feature/cog-3014-refactor-delete-feature
This commit is contained in:
commit
b2c632cc8f
11 changed files with 127 additions and 32 deletions
|
|
@ -41,6 +41,7 @@ async def add(
|
|||
extraction_rules: Optional[Dict[str, Any]] = None,
|
||||
tavily_config: Optional[BaseModel] = None,
|
||||
soup_crawler_config: Optional[BaseModel] = None,
|
||||
data_per_batch: Optional[int] = 20,
|
||||
):
|
||||
"""
|
||||
Add data to Cognee for knowledge graph processing.
|
||||
|
|
@ -235,6 +236,7 @@ async def add(
|
|||
vector_db_config=vector_db_config,
|
||||
graph_db_config=graph_db_config,
|
||||
incremental_loading=incremental_loading,
|
||||
data_per_batch=data_per_batch,
|
||||
):
|
||||
pipeline_run_info = run_info
|
||||
|
||||
|
|
|
|||
|
|
@ -51,6 +51,7 @@ async def cognify(
|
|||
incremental_loading: bool = True,
|
||||
custom_prompt: Optional[str] = None,
|
||||
temporal_cognify: bool = False,
|
||||
data_per_batch: int = 20,
|
||||
):
|
||||
"""
|
||||
Transform ingested data into a structured knowledge graph.
|
||||
|
|
@ -228,6 +229,7 @@ async def cognify(
|
|||
graph_db_config=graph_db_config,
|
||||
incremental_loading=incremental_loading,
|
||||
pipeline_name="cognify_pipeline",
|
||||
data_per_batch=data_per_batch,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from uuid import UUID
|
||||
from typing import Union, Optional, List, Type
|
||||
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.modules.engine.models.node_set import NodeSet
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.search.types import SearchResult, SearchType, CombinedSearchResult
|
||||
|
|
@ -8,6 +9,9 @@ from cognee.modules.users.methods import get_default_user
|
|||
from cognee.modules.search.methods import search as search_function
|
||||
from cognee.modules.data.methods import get_authorized_existing_datasets
|
||||
from cognee.modules.data.exceptions import DatasetNotFoundError
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
async def search(
|
||||
|
|
@ -175,6 +179,13 @@ async def search(
|
|||
if not datasets:
|
||||
raise DatasetNotFoundError(message="No datasets found.")
|
||||
|
||||
graph_engine = await get_graph_engine()
|
||||
is_empty = await graph_engine.is_empty()
|
||||
|
||||
if is_empty:
|
||||
logger.warning("Search attempt on an empty knowledge graph")
|
||||
return []
|
||||
|
||||
filtered_search_results = await search_function(
|
||||
query_text=query_text,
|
||||
query_type=query_type,
|
||||
|
|
|
|||
|
|
@ -39,6 +39,11 @@ class GraphDBInterface(ABC):
|
|||
- get_connections
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def is_empty(self) -> bool:
|
||||
logger.warning("is_empty() is not implemented")
|
||||
return True
|
||||
|
||||
@abstractmethod
|
||||
async def query(self, query: str, params: dict) -> List[Any]:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -197,6 +197,15 @@ class KuzuAdapter(GraphDBInterface):
|
|||
except FileNotFoundError:
|
||||
logger.warning(f"Kuzu S3 storage file not found: {self.db_path}")
|
||||
|
||||
async def is_empty(self) -> bool:
|
||||
query = """
|
||||
MATCH (n)
|
||||
RETURN true
|
||||
LIMIT 1;
|
||||
"""
|
||||
query_result = await self.query(query)
|
||||
return len(query_result) == 0
|
||||
|
||||
async def query(self, query: str, params: Optional[dict] = None) -> List[Tuple]:
|
||||
"""
|
||||
Execute a Kuzu query asynchronously with automatic reconnection.
|
||||
|
|
|
|||
|
|
@ -86,6 +86,15 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
async with self.driver.session(database=self.graph_database_name) as session:
|
||||
yield session
|
||||
|
||||
async def is_empty(self) -> bool:
|
||||
query = """
|
||||
RETURN EXISTS {
|
||||
MATCH (n)
|
||||
} AS node_exists;
|
||||
"""
|
||||
query_result = await self.query(query)
|
||||
return not query_result[0]["node_exists"]
|
||||
|
||||
@deadlock_retry()
|
||||
async def query(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -36,6 +36,7 @@ async def run_pipeline(
|
|||
graph_db_config: Optional[dict] = None,
|
||||
incremental_loading: bool = False,
|
||||
context: Optional[Dict] = None,
|
||||
data_per_batch: int = 20,
|
||||
):
|
||||
validate_pipeline_tasks(tasks)
|
||||
await setup_and_check_environment(vector_db_config, graph_db_config)
|
||||
|
|
@ -51,6 +52,7 @@ async def run_pipeline(
|
|||
pipeline_name=pipeline_name,
|
||||
incremental_loading=incremental_loading,
|
||||
context=context,
|
||||
data_per_batch=data_per_batch,
|
||||
):
|
||||
yield run_info
|
||||
|
||||
|
|
@ -63,6 +65,7 @@ async def run_pipeline_per_dataset(
|
|||
pipeline_name: str = "custom_pipeline",
|
||||
incremental_loading=False,
|
||||
context: Optional[Dict] = None,
|
||||
data_per_batch: int = 20,
|
||||
):
|
||||
# Will only be used if ENABLE_BACKEND_ACCESS_CONTROL is set to True
|
||||
await set_database_global_context_variables(dataset.id, dataset.owner_id)
|
||||
|
|
@ -78,7 +81,7 @@ async def run_pipeline_per_dataset(
|
|||
return
|
||||
|
||||
pipeline_run = run_tasks(
|
||||
tasks, dataset, data, user, pipeline_name, context, incremental_loading
|
||||
tasks, dataset, data, user, pipeline_name, context, incremental_loading, data_per_batch
|
||||
)
|
||||
|
||||
async for pipeline_run_info in pipeline_run:
|
||||
|
|
|
|||
|
|
@ -59,6 +59,7 @@ async def run_tasks(
|
|||
pipeline_name: str = "unknown_pipeline",
|
||||
context: Optional[Dict] = None,
|
||||
incremental_loading: bool = False,
|
||||
data_per_batch: int = 20,
|
||||
):
|
||||
if not user:
|
||||
user = await get_default_user()
|
||||
|
|
@ -81,29 +82,34 @@ async def run_tasks(
|
|||
if incremental_loading:
|
||||
data = await resolve_data_directories(data)
|
||||
|
||||
# Create async tasks per data item that will run the pipeline for the data item
|
||||
data_item_tasks = [
|
||||
asyncio.create_task(
|
||||
run_tasks_data_item(
|
||||
data_item,
|
||||
dataset,
|
||||
tasks,
|
||||
pipeline_name,
|
||||
pipeline_id,
|
||||
pipeline_run_id,
|
||||
{
|
||||
**(context or {}),
|
||||
"user": user,
|
||||
"data": data_item,
|
||||
"dataset": dataset,
|
||||
},
|
||||
user,
|
||||
incremental_loading,
|
||||
# Create and gather batches of async tasks of data items that will run the pipeline for the data item
|
||||
results = []
|
||||
for start in range(0, len(data), data_per_batch):
|
||||
data_batch = data[start : start + data_per_batch]
|
||||
|
||||
data_item_tasks = [
|
||||
asyncio.create_task(
|
||||
run_tasks_data_item(
|
||||
data_item,
|
||||
dataset,
|
||||
tasks,
|
||||
pipeline_name,
|
||||
pipeline_id,
|
||||
pipeline_run_id,
|
||||
{
|
||||
**(context or {}),
|
||||
"user": user,
|
||||
"data": data_item,
|
||||
"dataset": dataset,
|
||||
},
|
||||
user,
|
||||
incremental_loading,
|
||||
)
|
||||
)
|
||||
)
|
||||
for data_item in data
|
||||
]
|
||||
results = await asyncio.gather(*data_item_tasks)
|
||||
for data_item in data_batch
|
||||
]
|
||||
|
||||
results.extend(await asyncio.gather(*data_item_tasks))
|
||||
|
||||
# Remove skipped data items from results
|
||||
results = [result for result in results if result]
|
||||
|
|
|
|||
|
|
@ -47,10 +47,26 @@ async def main():
|
|||
pathlib.Path(__file__).parent, "test_data/Quantum_computers.txt"
|
||||
)
|
||||
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
|
||||
graph_engine = await get_graph_engine()
|
||||
|
||||
is_empty = await graph_engine.is_empty()
|
||||
|
||||
assert is_empty, "Kuzu graph database is not empty"
|
||||
|
||||
await cognee.add([explanation_file_path_quantum], dataset_name)
|
||||
|
||||
is_empty = await graph_engine.is_empty()
|
||||
|
||||
assert is_empty, "Kuzu graph database should be empty before cognify"
|
||||
|
||||
await cognee.cognify([dataset_name])
|
||||
|
||||
is_empty = await graph_engine.is_empty()
|
||||
|
||||
assert not is_empty, "Kuzu graph database should not be empty"
|
||||
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
|
||||
vector_engine = get_vector_engine()
|
||||
|
|
@ -114,11 +130,10 @@ async def main():
|
|||
assert not os.path.isdir(data_root_directory), "Local data files are not deleted"
|
||||
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
|
||||
graph_engine = await get_graph_engine()
|
||||
nodes, edges = await graph_engine.get_graph_data()
|
||||
assert len(nodes) == 0 and len(edges) == 0, "Kuzu graph database is not empty"
|
||||
is_empty = await graph_engine.is_empty()
|
||||
|
||||
assert is_empty, "Kuzu graph database is not empty"
|
||||
|
||||
finally:
|
||||
# Ensure cleanup even if tests fail
|
||||
|
|
|
|||
|
|
@ -35,6 +35,14 @@ async def main():
|
|||
explanation_file_path_nlp = os.path.join(
|
||||
pathlib.Path(__file__).parent, "test_data/Natural_language_processing.txt"
|
||||
)
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
|
||||
graph_engine = await get_graph_engine()
|
||||
|
||||
is_empty = await graph_engine.is_empty()
|
||||
|
||||
assert is_empty, "Graph has to be empty"
|
||||
|
||||
await cognee.add([explanation_file_path_nlp], dataset_name)
|
||||
|
||||
explanation_file_path_quantum = os.path.join(
|
||||
|
|
@ -42,9 +50,16 @@ async def main():
|
|||
)
|
||||
|
||||
await cognee.add([explanation_file_path_quantum], dataset_name)
|
||||
is_empty = await graph_engine.is_empty()
|
||||
|
||||
assert is_empty, "Graph has to be empty before cognify"
|
||||
|
||||
await cognee.cognify([dataset_name])
|
||||
|
||||
is_empty = await graph_engine.is_empty()
|
||||
|
||||
assert not is_empty, "Graph shouldn't be empty"
|
||||
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
|
||||
vector_engine = get_vector_engine()
|
||||
|
|
@ -117,11 +132,8 @@ async def main():
|
|||
assert not os.path.isdir(data_root_directory), "Local data files are not deleted"
|
||||
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
|
||||
graph_engine = await get_graph_engine()
|
||||
nodes, edges = await graph_engine.get_graph_data()
|
||||
assert len(nodes) == 0 and len(edges) == 0, "Neo4j graph database is not empty"
|
||||
is_empty = await graph_engine.is_empty()
|
||||
assert is_empty, "Neo4j graph database is not empty"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
21
cognee/tests/unit/api/test_search.py
Normal file
21
cognee/tests/unit/api/test_search.py
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
import pytest
|
||||
import cognee
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_search_raises_SearchOnEmptyGraphError_on_empty_graph():
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
await cognee.add("Sample input")
|
||||
result = await cognee.search("Sample query")
|
||||
assert result == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_search_doesnt_raise_SearchOnEmptyGraphError():
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
await cognee.add("Sample input")
|
||||
await cognee.cognify()
|
||||
result = await cognee.search("Sample query")
|
||||
assert result != []
|
||||
Loading…
Add table
Reference in a new issue