diff --git a/cognee/tests/integration/retrieval/test_graph_completion_retriever_context_extension.py b/cognee/tests/integration/retrieval/test_graph_completion_retriever_context_extension.py new file mode 100644 index 000000000..0d5413ff8 --- /dev/null +++ b/cognee/tests/integration/retrieval/test_graph_completion_retriever_context_extension.py @@ -0,0 +1,215 @@ +import os +import pytest +import pathlib +import pytest_asyncio +from typing import Optional, Union +import cognee + +from cognee.low_level import setup, DataPoint +from cognee.tasks.storage import add_data_points +from cognee.modules.graph.utils import resolve_edges_to_text +from cognee.modules.retrieval.graph_completion_context_extension_retriever import ( + GraphCompletionContextExtensionRetriever, +) + + +@pytest_asyncio.fixture +async def setup_test_environment_simple(): + """Set up a clean test environment with simple graph data.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str( + base_dir / ".cognee_system/test_graph_completion_extension_context_simple" + ) + data_directory_path = str( + base_dir / ".data_storage/test_graph_completion_extension_context_simple" + ) + + cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + class Company(DataPoint): + name: str + + class Person(DataPoint): + name: str + works_for: Company + + company1 = Company(name="Figma") + company2 = Company(name="Canva") + person1 = Person(name="Steve Rodger", works_for=company1) + person2 = Person(name="Ike Loma", works_for=company1) + person3 = Person(name="Jason Statham", works_for=company1) + person4 = Person(name="Mike Broski", works_for=company2) + person5 = Person(name="Christina Mayer", works_for=company2) + + entities = [company1, company2, person1, person2, person3, person4, person5] + + await add_data_points(entities) + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest_asyncio.fixture +async def setup_test_environment_complex(): + """Set up a clean test environment with complex graph data.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str( + base_dir / ".cognee_system/test_graph_completion_extension_context_complex" + ) + data_directory_path = str( + base_dir / ".data_storage/test_graph_completion_extension_context_complex" + ) + + cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + class Company(DataPoint): + name: str + metadata: dict = {"index_fields": ["name"]} + + class Car(DataPoint): + brand: str + model: str + year: int + + class Location(DataPoint): + country: str + city: str + + class Home(DataPoint): + location: Location + rooms: int + sqm: int + + class Person(DataPoint): + name: str + works_for: Company + owns: Optional[list[Union[Car, Home]]] = None + + company1 = Company(name="Figma") + company2 = Company(name="Canva") + + person1 = Person(name="Mike Rodger", works_for=company1) + person1.owns = [Car(brand="Toyota", model="Camry", year=2020)] + + person2 = Person(name="Ike Loma", works_for=company1) + person2.owns = [ + Car(brand="Tesla", model="Model S", year=2021), + Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4), + ] + + person3 = Person(name="Jason Statham", works_for=company1) + + person4 = Person(name="Mike Broski", works_for=company2) + person4.owns = [Car(brand="Ford", model="Mustang", year=1978)] + + person5 = Person(name="Christina Mayer", works_for=company2) + person5.owns = [Car(brand="Honda", model="Civic", year=2023)] + + entities = [company1, company2, person1, person2, person3, person4, person5] + + await add_data_points(entities) + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest_asyncio.fixture +async def setup_test_environment_empty(): + """Set up a clean test environment without graph data.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str( + base_dir / ".cognee_system/test_get_graph_completion_extension_context_on_empty_graph" + ) + data_directory_path = str( + base_dir / ".data_storage/test_get_graph_completion_extension_context_on_empty_graph" + ) + + cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest.mark.asyncio +async def test_graph_completion_extension_context_simple(setup_test_environment_simple): + """Integration test: verify GraphCompletionContextExtensionRetriever can retrieve context (simple).""" + retriever = GraphCompletionContextExtensionRetriever() + + context = await resolve_edges_to_text(await retriever.get_context("Who works at Canva?")) + + assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski" + assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer" + + answer = await retriever.get_completion("Who works at Canva?") + + assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" + assert all(isinstance(item, str) and item.strip() for item in answer), ( + "Answer must contain only non-empty strings" + ) + + +@pytest.mark.asyncio +async def test_graph_completion_extension_context_complex(setup_test_environment_complex): + """Integration test: verify GraphCompletionContextExtensionRetriever can retrieve context (complex).""" + retriever = GraphCompletionContextExtensionRetriever(top_k=20) + + context = await resolve_edges_to_text( + await retriever.get_context("Who works at Figma and drives Tesla?") + ) + + assert "Mike Rodger --[works_for]--> Figma" in context, "Failed to get Mike Rodger" + assert "Ike Loma --[works_for]--> Figma" in context, "Failed to get Ike Loma" + assert "Jason Statham --[works_for]--> Figma" in context, "Failed to get Jason Statham" + + answer = await retriever.get_completion("Who works at Figma?") + + assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" + assert all(isinstance(item, str) and item.strip() for item in answer), ( + "Answer must contain only non-empty strings" + ) + + +@pytest.mark.asyncio +async def test_get_graph_completion_extension_context_on_empty_graph(setup_test_environment_empty): + """Integration test: verify GraphCompletionContextExtensionRetriever handles empty graph correctly.""" + retriever = GraphCompletionContextExtensionRetriever() + + context = await retriever.get_context("Who works at Figma?") + assert context == [], "Context should be empty on an empty graph" + + answer = await retriever.get_completion("Who works at Figma?") + + assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" + assert all(isinstance(item, str) and item.strip() for item in answer), ( + "Answer must contain only non-empty strings" + ) diff --git a/cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py b/cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py new file mode 100644 index 000000000..f937df70d --- /dev/null +++ b/cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py @@ -0,0 +1,207 @@ +import os +import pytest +import pathlib +import pytest_asyncio +from typing import Optional, Union +import cognee + +from cognee.low_level import setup, DataPoint +from cognee.modules.graph.utils import resolve_edges_to_text +from cognee.tasks.storage import add_data_points +from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever + + +@pytest_asyncio.fixture +async def setup_test_environment_simple(): + """Set up a clean test environment with simple graph data.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str( + base_dir / ".cognee_system/test_graph_completion_cot_context_simple" + ) + data_directory_path = str(base_dir / ".data_storage/test_graph_completion_cot_context_simple") + + cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + class Company(DataPoint): + name: str + + class Person(DataPoint): + name: str + works_for: Company + + company1 = Company(name="Figma") + company2 = Company(name="Canva") + person1 = Person(name="Steve Rodger", works_for=company1) + person2 = Person(name="Ike Loma", works_for=company1) + person3 = Person(name="Jason Statham", works_for=company1) + person4 = Person(name="Mike Broski", works_for=company2) + person5 = Person(name="Christina Mayer", works_for=company2) + + entities = [company1, company2, person1, person2, person3, person4, person5] + + await add_data_points(entities) + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest_asyncio.fixture +async def setup_test_environment_complex(): + """Set up a clean test environment with complex graph data.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str( + base_dir / ".cognee_system/test_graph_completion_cot_context_complex" + ) + data_directory_path = str(base_dir / ".data_storage/test_graph_completion_cot_context_complex") + + cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + class Company(DataPoint): + name: str + metadata: dict = {"index_fields": ["name"]} + + class Car(DataPoint): + brand: str + model: str + year: int + + class Location(DataPoint): + country: str + city: str + + class Home(DataPoint): + location: Location + rooms: int + sqm: int + + class Person(DataPoint): + name: str + works_for: Company + owns: Optional[list[Union[Car, Home]]] = None + + company1 = Company(name="Figma") + company2 = Company(name="Canva") + + person1 = Person(name="Mike Rodger", works_for=company1) + person1.owns = [Car(brand="Toyota", model="Camry", year=2020)] + + person2 = Person(name="Ike Loma", works_for=company1) + person2.owns = [ + Car(brand="Tesla", model="Model S", year=2021), + Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4), + ] + + person3 = Person(name="Jason Statham", works_for=company1) + + person4 = Person(name="Mike Broski", works_for=company2) + person4.owns = [Car(brand="Ford", model="Mustang", year=1978)] + + person5 = Person(name="Christina Mayer", works_for=company2) + person5.owns = [Car(brand="Honda", model="Civic", year=2023)] + + entities = [company1, company2, person1, person2, person3, person4, person5] + + await add_data_points(entities) + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest_asyncio.fixture +async def setup_test_environment_empty(): + """Set up a clean test environment without graph data.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str( + base_dir / ".cognee_system/test_get_graph_completion_cot_context_on_empty_graph" + ) + data_directory_path = str( + base_dir / ".data_storage/test_get_graph_completion_cot_context_on_empty_graph" + ) + + cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest.mark.asyncio +async def test_graph_completion_cot_context_simple(setup_test_environment_simple): + """Integration test: verify GraphCompletionCotRetriever can retrieve context (simple).""" + retriever = GraphCompletionCotRetriever() + + context = await resolve_edges_to_text(await retriever.get_context("Who works at Canva?")) + + assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski" + assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer" + + answer = await retriever.get_completion("Who works at Canva?") + + assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" + assert all(isinstance(item, str) and item.strip() for item in answer), ( + "Answer must contain only non-empty strings" + ) + + +@pytest.mark.asyncio +async def test_graph_completion_cot_context_complex(setup_test_environment_complex): + """Integration test: verify GraphCompletionCotRetriever can retrieve context (complex).""" + retriever = GraphCompletionCotRetriever(top_k=20) + + context = await resolve_edges_to_text(await retriever.get_context("Who works at Figma?")) + + assert "Mike Rodger --[works_for]--> Figma" in context, "Failed to get Mike Rodger" + assert "Ike Loma --[works_for]--> Figma" in context, "Failed to get Ike Loma" + assert "Jason Statham --[works_for]--> Figma" in context, "Failed to get Jason Statham" + + answer = await retriever.get_completion("Who works at Figma?") + + assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" + assert all(isinstance(item, str) and item.strip() for item in answer), ( + "Answer must contain only non-empty strings" + ) + + +@pytest.mark.asyncio +async def test_get_graph_completion_cot_context_on_empty_graph(setup_test_environment_empty): + """Integration test: verify GraphCompletionCotRetriever handles empty graph correctly.""" + retriever = GraphCompletionCotRetriever() + + context = await retriever.get_context("Who works at Figma?") + assert context == [], "Context should be empty on an empty graph" + + answer = await retriever.get_completion("Who works at Figma?") + + assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" + assert all(isinstance(item, str) and item.strip() for item in answer), ( + "Answer must contain only non-empty strings" + )