chore: removes integration tests that pretended to be unit tests
This commit is contained in:
parent
49f7c5188c
commit
7961e96710
7 changed files with 0 additions and 1339 deletions
|
|
@ -1,201 +0,0 @@
|
|||
import os
|
||||
import pytest
|
||||
import pathlib
|
||||
from typing import List
|
||||
import cognee
|
||||
from cognee.low_level import setup
|
||||
from cognee.tasks.storage import add_data_points
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
from cognee.modules.chunking.models import DocumentChunk
|
||||
from cognee.modules.data.processing.document_types import TextDocument
|
||||
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
|
||||
from cognee.modules.retrieval.chunks_retriever import ChunksRetriever
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.modules.data.processing.document_types import Document
|
||||
from cognee.modules.engine.models import Entity
|
||||
|
||||
|
||||
# DataPoint subclass declaring chunk-style fields. In this module it is passed
# as the payload_schema when the empty-graph test creates the
# "DocumentChunk_text" collection.
class DocumentChunkWithEntities(DataPoint):
    text: str
    chunk_size: int
    chunk_index: int
    cut_type: str
    is_part_of: Document
    contains: List[Entity] = None  # falls back to None when not supplied

    # "text" is the only field listed under index_fields.
    metadata: dict = {"index_fields": ["text"]}
|
||||
|
||||
|
||||
class TestChunksRetriever:
    """Exercise ChunksRetriever.get_context against per-test cognee storage directories."""

    @pytest.mark.asyncio
    async def test_chunk_context_simple(self):
        """Store three chunks of one document; querying "Mike" must return "Mike Broski" first."""
        here = pathlib.Path(__file__).parent
        cognee.config.system_root_directory(
            os.path.join(here, ".cognee_system/test_chunk_context_simple")
        )
        cognee.config.data_root_directory(
            os.path.join(here, ".data_storage/test_chunk_context_simple")
        )

        # Prune data and system state (metadata included) before running setup().
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
        await setup()

        career_doc = TextDocument(
            name="Steve Rodger's career",
            raw_data_location="somewhere",
            external_metadata="",
            mime_type="text/plain",
        )

        # chunk_index follows the position of each text in this tuple.
        chunk_texts = ("Steve Rodger", "Mike Broski", "Christina Mayer")
        stored_chunks = [
            DocumentChunk(
                text=chunk_text,
                chunk_size=2,
                chunk_index=position,
                cut_type="sentence_end",
                is_part_of=career_doc,
                contains=[],
            )
            for position, chunk_text in enumerate(chunk_texts)
        ]
        await add_data_points(stored_chunks)

        results = await ChunksRetriever().get_context("Mike")

        best_match = results[0]
        assert best_match["text"] == "Mike Broski", "Failed to get Mike Broski"

    @pytest.mark.asyncio
    async def test_chunk_context_complex(self):
        """Store chunks from two documents; querying "Christina" must return "Christina Mayer" first."""
        here = pathlib.Path(__file__).parent
        cognee.config.system_root_directory(
            os.path.join(here, ".cognee_system/test_chunk_context_complex")
        )
        cognee.config.data_root_directory(
            os.path.join(here, ".data_storage/test_chunk_context_complex")
        )

        # Prune data and system state (metadata included) before running setup().
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
        await setup()

        employee_doc = TextDocument(
            name="Employee List",
            raw_data_location="somewhere",
            external_metadata="",
            mime_type="text/plain",
        )
        car_doc = TextDocument(
            name="Car List",
            raw_data_location="somewhere",
            external_metadata="",
            mime_type="text/plain",
        )

        # Each document contributes three chunks, indexed 0..2 within that document.
        chunk_plan = (
            (employee_doc, ("Steve Rodger", "Mike Broski", "Christina Mayer")),
            (car_doc, ("Range Rover", "Hyundai", "Chrysler")),
        )
        stored_chunks = []
        for parent_doc, chunk_texts in chunk_plan:
            for position, chunk_text in enumerate(chunk_texts):
                stored_chunks.append(
                    DocumentChunk(
                        text=chunk_text,
                        chunk_size=2,
                        chunk_index=position,
                        cut_type="sentence_end",
                        is_part_of=parent_doc,
                        contains=[],
                    )
                )
        await add_data_points(stored_chunks)

        retriever = ChunksRetriever(top_k=20)
        results = await retriever.get_context("Christina")

        best_match = results[0]
        assert best_match["text"] == "Christina Mayer", "Failed to get Christina Mayer"

    @pytest.mark.asyncio
    async def test_chunk_context_on_empty_graph(self):
        """Expect NoDataError before the collection exists and an empty result once it is created."""
        here = pathlib.Path(__file__).parent
        cognee.config.system_root_directory(
            os.path.join(here, ".cognee_system/test_chunk_context_on_empty_graph")
        )
        cognee.config.data_root_directory(
            os.path.join(here, ".data_storage/test_chunk_context_on_empty_graph")
        )

        # Note: unlike the other tests in this class, setup() is not called here.
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)

        retriever = ChunksRetriever()

        # Before the collection is created the lookup is expected to raise.
        with pytest.raises(NoDataError):
            await retriever.get_context("Christina Mayer")

        engine = get_vector_engine()
        await engine.create_collection(
            "DocumentChunk_text", payload_schema=DocumentChunkWithEntities
        )

        # With the (empty) collection in place the same query must yield nothing.
        results = await retriever.get_context("Christina Mayer")
        assert len(results) == 0, "Found chunks when none should exist"
|
||||
|
|
@ -1,177 +0,0 @@
|
|||
import os
|
||||
import pytest
|
||||
import pathlib
|
||||
from typing import Optional, Union
|
||||
|
||||
import cognee
|
||||
from cognee.low_level import setup, DataPoint
|
||||
from cognee.tasks.storage import add_data_points
|
||||
from cognee.modules.graph.utils import resolve_edges_to_text
|
||||
from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
|
||||
GraphCompletionContextExtensionRetriever,
|
||||
)
|
||||
|
||||
|
||||
class TestGraphCompletionWithContextExtensionRetriever:
    """Exercise GraphCompletionContextExtensionRetriever against per-test cognee storage."""

    @staticmethod
    def _assert_answer_shape(answer):
        # Shared check: the completion is a list whose items are all non-blank strings.
        assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
        assert all(isinstance(item, str) and item.strip() for item in answer), (
            "Answer must contain only non-empty strings"
        )

    @pytest.mark.asyncio
    async def test_graph_completion_extension_context_simple(self):
        """Store two companies and five employees; the Canva query must surface both Canva edges."""
        here = pathlib.Path(__file__).parent
        cognee.config.system_root_directory(
            os.path.join(here, ".cognee_system/test_graph_completion_extension_context_simple")
        )
        cognee.config.data_root_directory(
            os.path.join(here, ".data_storage/test_graph_completion_extension_context_simple")
        )

        # Prune data and system state (metadata included) before running setup().
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
        await setup()

        class Company(DataPoint):
            name: str

        class Person(DataPoint):
            name: str
            works_for: Company

        figma = Company(name="Figma")
        canva = Company(name="Canva")
        staff = [
            Person(name="Steve Rodger", works_for=figma),
            Person(name="Ike Loma", works_for=figma),
            Person(name="Jason Statham", works_for=figma),
            Person(name="Mike Broski", works_for=canva),
            Person(name="Christina Mayer", works_for=canva),
        ]
        await add_data_points([figma, canva, *staff])

        retriever = GraphCompletionContextExtensionRetriever()

        retrieved = await retriever.get_context("Who works at Canva?")
        context = await resolve_edges_to_text(retrieved)

        expected_edges = (
            ("Mike Broski --[works_for]--> Canva", "Failed to get Mike Broski"),
            ("Christina Mayer --[works_for]--> Canva", "Failed to get Christina Mayer"),
        )
        for edge_text, failure_message in expected_edges:
            assert edge_text in context, failure_message

        answer = await retriever.get_completion("Who works at Canva?")
        self._assert_answer_shape(answer)

    @pytest.mark.asyncio
    async def test_graph_completion_extension_context_complex(self):
        """Store people who also own cars/homes; all three Figma edges must appear in the context."""
        here = pathlib.Path(__file__).parent
        cognee.config.system_root_directory(
            os.path.join(here, ".cognee_system/test_graph_completion_extension_context_complex")
        )
        cognee.config.data_root_directory(
            os.path.join(here, ".data_storage/test_graph_completion_extension_context_complex")
        )

        # Prune data and system state (metadata included) before running setup().
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
        await setup()

        class Company(DataPoint):
            name: str
            metadata: dict = {"index_fields": ["name"]}

        class Car(DataPoint):
            brand: str
            model: str
            year: int

        class Location(DataPoint):
            country: str
            city: str

        class Home(DataPoint):
            location: Location
            rooms: int
            sqm: int

        class Person(DataPoint):
            name: str
            works_for: Company
            owns: Optional[list[Union[Car, Home]]] = None

        figma = Company(name="Figma")
        canva = Company(name="Canva")

        # Possessions are attached by assignment after construction, as in the fixture's design.
        mike_rodger = Person(name="Mike Rodger", works_for=figma)
        mike_rodger.owns = [Car(brand="Toyota", model="Camry", year=2020)]

        ike_loma = Person(name="Ike Loma", works_for=figma)
        ike_loma.owns = [
            Car(brand="Tesla", model="Model S", year=2021),
            Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4),
        ]

        jason_statham = Person(name="Jason Statham", works_for=figma)

        mike_broski = Person(name="Mike Broski", works_for=canva)
        mike_broski.owns = [Car(brand="Ford", model="Mustang", year=1978)]

        christina_mayer = Person(name="Christina Mayer", works_for=canva)
        christina_mayer.owns = [Car(brand="Honda", model="Civic", year=2023)]

        await add_data_points(
            [figma, canva, mike_rodger, ike_loma, jason_statham, mike_broski, christina_mayer]
        )

        retriever = GraphCompletionContextExtensionRetriever(top_k=20)

        retrieved = await retriever.get_context("Who works at Figma and drives Tesla?")
        context = await resolve_edges_to_text(retrieved)

        print(context)

        expected_edges = (
            ("Mike Rodger --[works_for]--> Figma", "Failed to get Mike Rodger"),
            ("Ike Loma --[works_for]--> Figma", "Failed to get Ike Loma"),
            ("Jason Statham --[works_for]--> Figma", "Failed to get Jason Statham"),
        )
        for edge_text, failure_message in expected_edges:
            assert edge_text in context, failure_message

        answer = await retriever.get_completion("Who works at Figma?")
        self._assert_answer_shape(answer)

    @pytest.mark.asyncio
    async def test_get_graph_completion_extension_context_on_empty_graph(self):
        """With nothing stored, get_context must return [] and get_completion must return a list."""
        here = pathlib.Path(__file__).parent
        cognee.config.system_root_directory(
            os.path.join(
                here,
                ".cognee_system/test_get_graph_completion_extension_context_on_empty_graph",
            )
        )
        cognee.config.data_root_directory(
            os.path.join(
                here,
                ".data_storage/test_get_graph_completion_extension_context_on_empty_graph",
            )
        )

        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)

        # The retriever is constructed before setup() runs, matching the original ordering.
        retriever = GraphCompletionContextExtensionRetriever()

        await setup()

        context = await retriever.get_context("Who works at Figma?")
        assert context == [], "Context should be empty on an empty graph"

        answer = await retriever.get_completion("Who works at Figma?")
        self._assert_answer_shape(answer)
|
||||
|
|
@ -1,170 +0,0 @@
|
|||
import os
|
||||
import pytest
|
||||
import pathlib
|
||||
from typing import Optional, Union
|
||||
|
||||
import cognee
|
||||
from cognee.low_level import setup, DataPoint
|
||||
from cognee.modules.graph.utils import resolve_edges_to_text
|
||||
from cognee.tasks.storage import add_data_points
|
||||
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
|
||||
|
||||
|
||||
class TestGraphCompletionCoTRetriever:
    """Exercise GraphCompletionCotRetriever against per-test cognee storage directories."""

    @staticmethod
    def _assert_answer_shape(answer):
        # Shared check: the completion is a list whose items are all non-blank strings.
        assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
        assert all(isinstance(item, str) and item.strip() for item in answer), (
            "Answer must contain only non-empty strings"
        )

    @pytest.mark.asyncio
    async def test_graph_completion_cot_context_simple(self):
        """Store two companies and five employees; the Canva query must surface both Canva edges."""
        here = pathlib.Path(__file__).parent
        cognee.config.system_root_directory(
            os.path.join(here, ".cognee_system/test_graph_completion_cot_context_simple")
        )
        cognee.config.data_root_directory(
            os.path.join(here, ".data_storage/test_graph_completion_cot_context_simple")
        )

        # Prune data and system state (metadata included) before running setup().
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
        await setup()

        class Company(DataPoint):
            name: str

        class Person(DataPoint):
            name: str
            works_for: Company

        figma = Company(name="Figma")
        canva = Company(name="Canva")
        staff = [
            Person(name="Steve Rodger", works_for=figma),
            Person(name="Ike Loma", works_for=figma),
            Person(name="Jason Statham", works_for=figma),
            Person(name="Mike Broski", works_for=canva),
            Person(name="Christina Mayer", works_for=canva),
        ]
        await add_data_points([figma, canva, *staff])

        retriever = GraphCompletionCotRetriever()

        retrieved = await retriever.get_context("Who works at Canva?")
        context = await resolve_edges_to_text(retrieved)

        expected_edges = (
            ("Mike Broski --[works_for]--> Canva", "Failed to get Mike Broski"),
            ("Christina Mayer --[works_for]--> Canva", "Failed to get Christina Mayer"),
        )
        for edge_text, failure_message in expected_edges:
            assert edge_text in context, failure_message

        answer = await retriever.get_completion("Who works at Canva?")
        self._assert_answer_shape(answer)

    @pytest.mark.asyncio
    async def test_graph_completion_cot_context_complex(self):
        """Store people who also own cars/homes; all three Figma edges must appear in the context."""
        here = pathlib.Path(__file__).parent
        cognee.config.system_root_directory(
            os.path.join(here, ".cognee_system/test_graph_completion_cot_context_complex")
        )
        cognee.config.data_root_directory(
            os.path.join(here, ".data_storage/test_graph_completion_cot_context_complex")
        )

        # Prune data and system state (metadata included) before running setup().
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
        await setup()

        class Company(DataPoint):
            name: str
            metadata: dict = {"index_fields": ["name"]}

        class Car(DataPoint):
            brand: str
            model: str
            year: int

        class Location(DataPoint):
            country: str
            city: str

        class Home(DataPoint):
            location: Location
            rooms: int
            sqm: int

        class Person(DataPoint):
            name: str
            works_for: Company
            owns: Optional[list[Union[Car, Home]]] = None

        figma = Company(name="Figma")
        canva = Company(name="Canva")

        # Possessions are attached by assignment after construction, as in the fixture's design.
        mike_rodger = Person(name="Mike Rodger", works_for=figma)
        mike_rodger.owns = [Car(brand="Toyota", model="Camry", year=2020)]

        ike_loma = Person(name="Ike Loma", works_for=figma)
        ike_loma.owns = [
            Car(brand="Tesla", model="Model S", year=2021),
            Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4),
        ]

        jason_statham = Person(name="Jason Statham", works_for=figma)

        mike_broski = Person(name="Mike Broski", works_for=canva)
        mike_broski.owns = [Car(brand="Ford", model="Mustang", year=1978)]

        christina_mayer = Person(name="Christina Mayer", works_for=canva)
        christina_mayer.owns = [Car(brand="Honda", model="Civic", year=2023)]

        await add_data_points(
            [figma, canva, mike_rodger, ike_loma, jason_statham, mike_broski, christina_mayer]
        )

        retriever = GraphCompletionCotRetriever(top_k=20)

        retrieved = await retriever.get_context("Who works at Figma?")
        context = await resolve_edges_to_text(retrieved)

        print(context)

        expected_edges = (
            ("Mike Rodger --[works_for]--> Figma", "Failed to get Mike Rodger"),
            ("Ike Loma --[works_for]--> Figma", "Failed to get Ike Loma"),
            ("Jason Statham --[works_for]--> Figma", "Failed to get Jason Statham"),
        )
        for edge_text, failure_message in expected_edges:
            assert edge_text in context, failure_message

        answer = await retriever.get_completion("Who works at Figma?")
        self._assert_answer_shape(answer)

    @pytest.mark.asyncio
    async def test_get_graph_completion_cot_context_on_empty_graph(self):
        """With nothing stored, get_context must return [] and get_completion must return a list."""
        here = pathlib.Path(__file__).parent
        cognee.config.system_root_directory(
            os.path.join(
                here,
                ".cognee_system/test_get_graph_completion_cot_context_on_empty_graph",
            )
        )
        cognee.config.data_root_directory(
            os.path.join(
                here,
                ".data_storage/test_get_graph_completion_cot_context_on_empty_graph",
            )
        )

        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)

        # The retriever is constructed before setup() runs, matching the original ordering.
        retriever = GraphCompletionCotRetriever()

        await setup()

        context = await retriever.get_context("Who works at Figma?")
        assert context == [], "Context should be empty on an empty graph"

        answer = await retriever.get_completion("Who works at Figma?")
        self._assert_answer_shape(answer)
|
||||
|
|
@ -1,223 +0,0 @@
|
|||
import os
|
||||
import pytest
|
||||
import pathlib
|
||||
from typing import Optional, Union
|
||||
|
||||
import cognee
|
||||
from cognee.low_level import setup, DataPoint
|
||||
from cognee.modules.graph.utils import resolve_edges_to_text
|
||||
from cognee.tasks.storage import add_data_points
|
||||
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
||||
|
||||
|
||||
class TestGraphCompletionRetriever:
    """Exercise GraphCompletionRetriever.get_context against per-test cognee storage."""

    @pytest.mark.asyncio
    async def test_graph_completion_context_simple(self):
        """Store described companies and people; check node headers, content blocks and edges."""
        here = pathlib.Path(__file__).parent
        cognee.config.system_root_directory(
            os.path.join(here, ".cognee_system/test_graph_completion_context_simple")
        )
        cognee.config.data_root_directory(
            os.path.join(here, ".data_storage/test_graph_completion_context_simple")
        )

        # Prune data and system state (metadata included) before running setup().
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
        await setup()

        class Company(DataPoint):
            name: str
            description: str

        class Person(DataPoint):
            name: str
            description: str
            works_for: Company

        figma = Company(name="Figma", description="Figma is a company")
        canva = Company(name="Canva", description="Canvas is a company")

        # (name, description, employer) in the order the people are created and stored.
        roster = (
            ("Steve Rodger", "This is description about Steve Rodger", figma),
            ("Ike Loma", "This is description about Ike Loma", figma),
            ("Jason Statham", "This is description about Jason Statham", figma),
            ("Mike Broski", "This is description about Mike Broski", canva),
            ("Christina Mayer", "This is description about Christina Mayer", canva),
        )
        staff = [
            Person(name=full_name, description=blurb, works_for=employer)
            for full_name, blurb, employer in roster
        ]
        await add_data_points([figma, canva, *staff])

        retriever = GraphCompletionRetriever()

        retrieved = await retriever.get_context("Who works at Canva?")
        context = await resolve_edges_to_text(retrieved)

        # Every (needle, message) pair below is checked in order with `needle in context`.
        expectations = (
            # Top-level sections
            ("Nodes:", "Missing 'Nodes:' section in context"),
            ("Connections:", "Missing 'Connections:' section in context"),
            # Node headers
            ("Node: Steve Rodger", "Missing node header for Steve Rodger"),
            ("Node: Figma", "Missing node header for Figma"),
            ("Node: Ike Loma", "Missing node header for Ike Loma"),
            ("Node: Jason Statham", "Missing node header for Jason Statham"),
            ("Node: Mike Broski", "Missing node header for Mike Broski"),
            ("Node: Canva", "Missing node header for Canva"),
            ("Node: Christina Mayer", "Missing node header for Christina Mayer"),
            # Node contents
            (
                "__node_content_start__\nThis is description about Steve Rodger\n__node_content_end__",
                "Description block for Steve Rodger altered",
            ),
            (
                "__node_content_start__\nFigma is a company\n__node_content_end__",
                "Description block for Figma altered",
            ),
            (
                "__node_content_start__\nThis is description about Ike Loma\n__node_content_end__",
                "Description block for Ike Loma altered",
            ),
            (
                "__node_content_start__\nThis is description about Jason Statham\n__node_content_end__",
                "Description block for Jason Statham altered",
            ),
            (
                "__node_content_start__\nThis is description about Mike Broski\n__node_content_end__",
                "Description block for Mike Broski altered",
            ),
            (
                "__node_content_start__\nCanvas is a company\n__node_content_end__",
                "Description block for Canva altered",
            ),
            (
                "__node_content_start__\nThis is description about Christina Mayer\n__node_content_end__",
                "Description block for Christina Mayer altered",
            ),
            # Connections
            (
                "Steve Rodger --[works_for]--> Figma",
                "Connection Steve Rodger→Figma missing or changed",
            ),
            (
                "Ike Loma --[works_for]--> Figma",
                "Connection Ike Loma→Figma missing or changed",
            ),
            (
                "Jason Statham --[works_for]--> Figma",
                "Connection Jason Statham→Figma missing or changed",
            ),
            (
                "Mike Broski --[works_for]--> Canva",
                "Connection Mike Broski→Canva missing or changed",
            ),
            (
                "Christina Mayer --[works_for]--> Canva",
                "Connection Christina Mayer→Canva missing or changed",
            ),
        )
        for needle, failure_message in expectations:
            assert needle in context, failure_message

    @pytest.mark.asyncio
    async def test_graph_completion_context_complex(self):
        """Store people who also own cars/homes; all three Figma edges must appear in the context."""
        here = pathlib.Path(__file__).parent
        cognee.config.system_root_directory(
            os.path.join(here, ".cognee_system/test_graph_completion_context_complex")
        )
        cognee.config.data_root_directory(
            os.path.join(here, ".data_storage/test_graph_completion_context_complex")
        )

        # Prune data and system state (metadata included) before running setup().
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
        await setup()

        class Company(DataPoint):
            name: str
            metadata: dict = {"index_fields": ["name"]}

        class Car(DataPoint):
            brand: str
            model: str
            year: int

        class Location(DataPoint):
            country: str
            city: str

        class Home(DataPoint):
            location: Location
            rooms: int
            sqm: int

        class Person(DataPoint):
            name: str
            works_for: Company
            owns: Optional[list[Union[Car, Home]]] = None

        figma = Company(name="Figma")
        canva = Company(name="Canva")

        # Possessions are attached by assignment after construction, as in the fixture's design.
        mike_rodger = Person(name="Mike Rodger", works_for=figma)
        mike_rodger.owns = [Car(brand="Toyota", model="Camry", year=2020)]

        ike_loma = Person(name="Ike Loma", works_for=figma)
        ike_loma.owns = [
            Car(brand="Tesla", model="Model S", year=2021),
            Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4),
        ]

        jason_statham = Person(name="Jason Statham", works_for=figma)

        mike_broski = Person(name="Mike Broski", works_for=canva)
        mike_broski.owns = [Car(brand="Ford", model="Mustang", year=1978)]

        christina_mayer = Person(name="Christina Mayer", works_for=canva)
        christina_mayer.owns = [Car(brand="Honda", model="Civic", year=2023)]

        await add_data_points(
            [figma, canva, mike_rodger, ike_loma, jason_statham, mike_broski, christina_mayer]
        )

        retriever = GraphCompletionRetriever(top_k=20)

        retrieved = await retriever.get_context("Who works at Figma?")
        context = await resolve_edges_to_text(retrieved)

        print(context)

        expected_edges = (
            ("Mike Rodger --[works_for]--> Figma", "Failed to get Mike Rodger"),
            ("Ike Loma --[works_for]--> Figma", "Failed to get Ike Loma"),
            ("Jason Statham --[works_for]--> Figma", "Failed to get Jason Statham"),
        )
        for edge_text, failure_message in expected_edges:
            assert edge_text in context, failure_message

    @pytest.mark.asyncio
    async def test_get_graph_completion_context_on_empty_graph(self):
        """With nothing stored, get_context must return an empty list."""
        here = pathlib.Path(__file__).parent
        cognee.config.system_root_directory(
            os.path.join(
                here,
                ".cognee_system/test_get_graph_completion_context_on_empty_graph",
            )
        )
        cognee.config.data_root_directory(
            os.path.join(
                here,
                ".data_storage/test_get_graph_completion_context_on_empty_graph",
            )
        )

        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)

        # The retriever is constructed before setup() runs, matching the original ordering.
        retriever = GraphCompletionRetriever()

        await setup()

        context = await retriever.get_context("Who works at Figma?")
        assert context == [], "Context should be empty on an empty graph"
|
||||
|
|
@ -1,205 +0,0 @@
|
|||
import os
|
||||
from typing import List
|
||||
import pytest
|
||||
import pathlib
|
||||
import cognee
|
||||
|
||||
from cognee.low_level import setup
|
||||
from cognee.tasks.storage import add_data_points
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
from cognee.modules.chunking.models import DocumentChunk
|
||||
from cognee.modules.data.processing.document_types import TextDocument
|
||||
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
|
||||
from cognee.modules.retrieval.completion_retriever import CompletionRetriever
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.modules.data.processing.document_types import Document
|
||||
from cognee.modules.engine.models import Entity
|
||||
|
||||
|
||||
# DataPoint subclass declaring chunk-style fields: the text, its size and index,
# the cut type, the parent document and an optional list of contained entities.
class DocumentChunkWithEntities(DataPoint):
    text: str
    chunk_size: int
    chunk_index: int
    cut_type: str
    is_part_of: Document
    contains: List[Entity] = None  # falls back to None when not supplied

    # "text" is the only field listed under index_fields.
    metadata: dict = {"index_fields": ["text"]}
|
||||
|
||||
|
||||
class TestRAGCompletionRetriever:
|
||||
@pytest.mark.asyncio
async def test_rag_completion_context_simple(self):
    """Store three chunks of one document; the context for "Mike" must equal "Mike Broski"."""
    here = pathlib.Path(__file__).parent
    cognee.config.system_root_directory(
        os.path.join(here, ".cognee_system/test_rag_completion_context_simple")
    )
    cognee.config.data_root_directory(
        os.path.join(here, ".data_storage/test_rag_completion_context_simple")
    )

    # Prune data and system state (metadata included) before running setup().
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    career_doc = TextDocument(
        name="Steve Rodger's career",
        raw_data_location="somewhere",
        external_metadata="",
        mime_type="text/plain",
    )

    # chunk_index follows the position of each text in this tuple.
    chunk_texts = ("Steve Rodger", "Mike Broski", "Christina Mayer")
    stored_chunks = [
        DocumentChunk(
            text=chunk_text,
            chunk_size=2,
            chunk_index=position,
            cut_type="sentence_end",
            is_part_of=career_doc,
            contains=[],
        )
        for position, chunk_text in enumerate(chunk_texts)
    ]
    await add_data_points(stored_chunks)

    retrieved_context = await CompletionRetriever().get_context("Mike")

    # Here the context is compared as a whole against a plain string.
    assert retrieved_context == "Mike Broski", "Failed to get Mike Broski"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rag_completion_context_complex(self):
|
||||
system_directory_path = os.path.join(
|
||||
pathlib.Path(__file__).parent, ".cognee_system/test_rag_completion_context_complex"
|
||||
)
|
||||
cognee.config.system_root_directory(system_directory_path)
|
||||
data_directory_path = os.path.join(
|
||||
pathlib.Path(__file__).parent, ".data_storage/test_rag_completion_context_complex"
|
||||
)
|
||||
cognee.config.data_root_directory(data_directory_path)
|
||||
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
await setup()
|
||||
|
||||
document1 = TextDocument(
|
||||
name="Employee List",
|
||||
raw_data_location="somewhere",
|
||||
external_metadata="",
|
||||
mime_type="text/plain",
|
||||
)
|
||||
|
||||
document2 = TextDocument(
|
||||
name="Car List",
|
||||
raw_data_location="somewhere",
|
||||
external_metadata="",
|
||||
mime_type="text/plain",
|
||||
)
|
||||
|
||||
chunk1 = DocumentChunk(
|
||||
text="Steve Rodger",
|
||||
chunk_size=2,
|
||||
chunk_index=0,
|
||||
cut_type="sentence_end",
|
||||
is_part_of=document1,
|
||||
contains=[],
|
||||
)
|
||||
chunk2 = DocumentChunk(
|
||||
text="Mike Broski",
|
||||
chunk_size=2,
|
||||
chunk_index=1,
|
||||
cut_type="sentence_end",
|
||||
is_part_of=document1,
|
||||
contains=[],
|
||||
)
|
||||
chunk3 = DocumentChunk(
|
||||
text="Christina Mayer",
|
||||
chunk_size=2,
|
||||
chunk_index=2,
|
||||
cut_type="sentence_end",
|
||||
is_part_of=document1,
|
||||
contains=[],
|
||||
)
|
||||
|
||||
chunk4 = DocumentChunk(
|
||||
text="Range Rover",
|
||||
chunk_size=2,
|
||||
chunk_index=0,
|
||||
cut_type="sentence_end",
|
||||
is_part_of=document2,
|
||||
contains=[],
|
||||
)
|
||||
chunk5 = DocumentChunk(
|
||||
text="Hyundai",
|
||||
chunk_size=2,
|
||||
chunk_index=1,
|
||||
cut_type="sentence_end",
|
||||
is_part_of=document2,
|
||||
contains=[],
|
||||
)
|
||||
chunk6 = DocumentChunk(
|
||||
text="Chrysler",
|
||||
chunk_size=2,
|
||||
chunk_index=2,
|
||||
cut_type="sentence_end",
|
||||
is_part_of=document2,
|
||||
contains=[],
|
||||
)
|
||||
|
||||
entities = [chunk1, chunk2, chunk3, chunk4, chunk5, chunk6]
|
||||
|
||||
await add_data_points(entities)
|
||||
|
||||
# TODO: top_k doesn't affect the output, it should be fixed.
|
||||
retriever = CompletionRetriever(top_k=20)
|
||||
|
||||
context = await retriever.get_context("Christina")
|
||||
|
||||
assert context[0:15] == "Christina Mayer", "Failed to get Christina Mayer"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_rag_completion_context_on_empty_graph(self):
|
||||
system_directory_path = os.path.join(
|
||||
pathlib.Path(__file__).parent,
|
||||
".cognee_system/test_get_rag_completion_context_on_empty_graph",
|
||||
)
|
||||
cognee.config.system_root_directory(system_directory_path)
|
||||
data_directory_path = os.path.join(
|
||||
pathlib.Path(__file__).parent,
|
||||
".data_storage/test_get_rag_completion_context_on_empty_graph",
|
||||
)
|
||||
cognee.config.data_root_directory(data_directory_path)
|
||||
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
|
||||
retriever = CompletionRetriever()
|
||||
|
||||
with pytest.raises(NoDataError):
|
||||
await retriever.get_context("Christina Mayer")
|
||||
|
||||
vector_engine = get_vector_engine()
|
||||
await vector_engine.create_collection(
|
||||
"DocumentChunk_text", payload_schema=DocumentChunkWithEntities
|
||||
)
|
||||
|
||||
context = await retriever.get_context("Christina Mayer")
|
||||
assert context == "", "Returned context should be empty on an empty graph"
|
||||
|
|
@ -1,204 +0,0 @@
|
|||
import asyncio
|
||||
|
||||
import pytest
|
||||
import cognee
|
||||
import pathlib
|
||||
import os
|
||||
|
||||
from pydantic import BaseModel
|
||||
from cognee.low_level import setup, DataPoint
|
||||
from cognee.tasks.storage import add_data_points
|
||||
from cognee.modules.chunking.models import DocumentChunk
|
||||
from cognee.modules.data.processing.document_types import TextDocument
|
||||
from cognee.modules.engine.models import Entity, EntityType
|
||||
from cognee.modules.retrieval.entity_extractors.DummyEntityExtractor import DummyEntityExtractor
|
||||
from cognee.modules.retrieval.context_providers.DummyContextProvider import DummyContextProvider
|
||||
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
|
||||
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
||||
from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
|
||||
GraphCompletionContextExtensionRetriever,
|
||||
)
|
||||
from cognee.modules.retrieval.EntityCompletionRetriever import EntityCompletionRetriever
|
||||
from cognee.modules.retrieval.temporal_retriever import TemporalRetriever
|
||||
from cognee.modules.retrieval.completion_retriever import CompletionRetriever
|
||||
|
||||
|
||||
class TestAnswer(BaseModel):
    """Response model passed as ``response_model`` to the retrievers under test."""

    # Both fields are required strings; the assertion helper in this module
    # additionally expects them to be non-blank.
    answer: str
    explanation: str
|
||||
|
||||
|
||||
def _assert_string_answer(answer: list[str]):
|
||||
assert isinstance(answer, list), f"Expected str, got {type(answer).__name__}"
|
||||
assert all(isinstance(item, str) and item.strip() for item in answer), "Items should be strings"
|
||||
assert all(item.strip() for item in answer), "Items should not be empty"
|
||||
|
||||
|
||||
def _assert_structured_answer(answer: list[TestAnswer]):
    """Assert that *answer* is a list of fully populated ``TestAnswer`` items.

    The checks run as three separate passes (type, answer text, explanation),
    so a type failure anywhere in the list is reported before any text check.

    Raises:
        AssertionError: If *answer* is not a list, an item is not a
            ``TestAnswer``, or an item's ``answer``/``explanation`` is blank.
    """
    assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"

    for entry in answer:
        assert isinstance(entry, TestAnswer), "Items should be TestAnswer"

    for entry in answer:
        assert entry.answer.strip(), "Answer text should not be empty"

    for entry in answer:
        assert entry.explanation.strip(), "Explanation should not be empty"
|
||||
|
||||
|
||||
async def _test_get_structured_graph_completion_cot():
    """Check ``GraphCompletionCotRetriever.get_completion`` in both output modes."""
    cot_retriever = GraphCompletionCotRetriever()
    question = "Who works at Figma?"

    # No response model passed: a list of non-blank strings is expected.
    plain_result = await cot_retriever.get_completion(question)
    _assert_string_answer(plain_result)

    # TestAnswer passed as response model: a list of TestAnswer is expected.
    typed_result = await cot_retriever.get_completion(question, response_model=TestAnswer)
    _assert_structured_answer(typed_result)
|
||||
|
||||
|
||||
async def _test_get_structured_graph_completion():
    """Check ``GraphCompletionRetriever.get_completion`` in both output modes."""
    graph_retriever = GraphCompletionRetriever()
    question = "Who works at Figma?"

    # No response model passed: a list of non-blank strings is expected.
    plain_result = await graph_retriever.get_completion(question)
    _assert_string_answer(plain_result)

    # TestAnswer passed as response model: a list of TestAnswer is expected.
    typed_result = await graph_retriever.get_completion(question, response_model=TestAnswer)
    _assert_structured_answer(typed_result)
|
||||
|
||||
|
||||
async def _test_get_structured_graph_completion_temporal():
    """Check ``TemporalRetriever.get_completion`` in both output modes.

    The same question is used for both calls so the string and structured
    modes are exercised on identical input (the structured call previously
    carried a stray doubled question mark).
    """
    retriever = TemporalRetriever()
    question = "When did Steve start working at Figma?"

    # Test with string response model (default)
    string_answer = await retriever.get_completion(question)
    _assert_string_answer(string_answer)

    # Test with structured response model
    structured_answer = await retriever.get_completion(question, response_model=TestAnswer)
    _assert_structured_answer(structured_answer)
|
||||
|
||||
|
||||
async def _test_get_structured_graph_completion_rag():
    """Check ``CompletionRetriever.get_completion`` in both output modes."""
    rag_retriever = CompletionRetriever()
    question = "Where does Steve work?"

    # No response model passed: a list of non-blank strings is expected.
    plain_result = await rag_retriever.get_completion(question)
    _assert_string_answer(plain_result)

    # TestAnswer passed as response model: a list of TestAnswer is expected.
    typed_result = await rag_retriever.get_completion(question, response_model=TestAnswer)
    _assert_structured_answer(typed_result)
|
||||
|
||||
|
||||
async def _test_get_structured_graph_completion_context_extension():
    """Check ``GraphCompletionContextExtensionRetriever.get_completion`` in both modes."""
    extension_retriever = GraphCompletionContextExtensionRetriever()
    question = "Who works at Figma?"

    # No response model passed: a list of non-blank strings is expected.
    plain_result = await extension_retriever.get_completion(question)
    _assert_string_answer(plain_result)

    # TestAnswer passed as response model: a list of TestAnswer is expected.
    typed_result = await extension_retriever.get_completion(
        question, response_model=TestAnswer
    )
    _assert_structured_answer(typed_result)
|
||||
|
||||
|
||||
async def _test_get_structured_entity_completion():
    """Check ``EntityCompletionRetriever.get_completion`` in both output modes.

    The retriever is built from the dummy entity extractor and the dummy
    context provider imported by this module.
    """
    extractor = DummyEntityExtractor()
    context_provider = DummyContextProvider()
    entity_retriever = EntityCompletionRetriever(extractor, context_provider)
    question = "Who is Albert Einstein?"

    # No response model passed: a list of non-blank strings is expected.
    plain_result = await entity_retriever.get_completion(question)
    _assert_string_answer(plain_result)

    # TestAnswer passed as response model: a list of TestAnswer is expected.
    typed_result = await entity_retriever.get_completion(question, response_model=TestAnswer)
    _assert_structured_answer(typed_result)
|
||||
|
||||
|
||||
class TestStructuredOutputCompletion:
    """Runs every retriever's string/structured completion check on shared data."""

    @pytest.mark.asyncio
    async def test_get_structured_completion(self):
        """Seed graph, chunk and entity data once, then run each retriever scenario."""
        here = pathlib.Path(__file__).parent
        cognee.config.system_root_directory(
            os.path.join(here, ".cognee_system/test_get_structured_completion")
        )
        cognee.config.data_root_directory(
            os.path.join(here, ".data_storage/test_get_structured_completion")
        )

        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
        await setup()

        # Ad-hoc data point types used for the "who works where" questions.
        class Company(DataPoint):
            name: str

        class Person(DataPoint):
            name: str
            works_for: Company
            works_since: int

        figma = Company(name="Figma")
        steve = Person(name="Steve Rodger", works_for=figma, works_since=2015)
        await add_data_points([figma, steve])

        # Three chunks of a single text document.
        career_document = TextDocument(
            name="Steve Rodger's career",
            raw_data_location="somewhere",
            external_metadata="",
            mime_type="text/plain",
        )
        chunk_texts = ("Steve Rodger", "Mike Broski", "Christina Mayer")
        document_chunks = [
            DocumentChunk(
                text=chunk_text,
                chunk_size=2,
                chunk_index=position,
                cut_type="sentence_end",
                is_part_of=career_document,
                contains=[],
            )
            for position, chunk_text in enumerate(chunk_texts)
        ]
        await add_data_points(document_chunks)

        # A standalone entity, stored in its own call.
        person_type = EntityType(name="Person", description="A human individual")
        einstein = Entity(
            name="Albert Einstein", is_a=person_type, description="A famous physicist"
        )
        await add_data_points([einstein])

        # Scenarios run one after another, in this fixed order.
        scenarios = (
            _test_get_structured_graph_completion_cot,
            _test_get_structured_graph_completion,
            _test_get_structured_graph_completion_temporal,
            _test_get_structured_graph_completion_rag,
            _test_get_structured_graph_completion_context_extension,
            _test_get_structured_entity_completion,
        )
        for scenario in scenarios:
            await scenario()
|
||||
|
|
@ -1,159 +0,0 @@
|
|||
import os
|
||||
import pytest
|
||||
import pathlib
|
||||
|
||||
import cognee
|
||||
from cognee.low_level import setup
|
||||
from cognee.tasks.storage import add_data_points
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
from cognee.modules.chunking.models import DocumentChunk
|
||||
from cognee.tasks.summarization.models import TextSummary
|
||||
from cognee.modules.data.processing.document_types import TextDocument
|
||||
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
|
||||
from cognee.modules.retrieval.summaries_retriever import SummariesRetriever
|
||||
|
||||
|
||||
class TestSummariesRetriever:
    """Tests for ``SummariesRetriever.get_context`` against real storage.

    Each test configures cognee with system/data directories named after the
    test (located next to this file) and prunes existing state before running.
    """

    @staticmethod
    async def _reset_storage(test_name: str) -> None:
        """Configure per-test storage roots, then prune data and system state.

        ``setup()`` is intentionally not called here: the empty-graph test
        runs without it, so callers that need it invoke it themselves.

        Args:
            test_name: Directory name used under ``.cognee_system`` and
                ``.data_storage``.
        """
        base_dir = pathlib.Path(__file__).parent
        cognee.config.system_root_directory(
            os.path.join(base_dir, f".cognee_system/{test_name}")
        )
        cognee.config.data_root_directory(
            os.path.join(base_dir, f".data_storage/{test_name}")
        )

        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)

    @staticmethod
    def _make_document(name):
        """Return a plain-text ``TextDocument`` fixture with the given name."""
        return TextDocument(
            name=name,
            raw_data_location="somewhere",
            external_metadata="",
            mime_type="text/plain",
        )

    @staticmethod
    def _make_summary(document, chunk_index, chunk_text, summary_text):
        """Return a ``TextSummary`` made from a freshly built ``DocumentChunk``.

        Args:
            document: Document the underlying chunk is part of.
            chunk_index: Index of the chunk within that document.
            chunk_text: Text of the underlying chunk.
            summary_text: Text of the summary itself.
        """
        chunk = DocumentChunk(
            text=chunk_text,
            chunk_size=2,
            chunk_index=chunk_index,
            cut_type="sentence_end",
            is_part_of=document,
            contains=[],
        )
        return TextSummary(text=summary_text, made_from=chunk)

    @pytest.mark.asyncio
    async def test_chunk_context(self):
        """Querying "Christina" is expected to rank the "C.M." summary first."""
        await self._reset_storage("test_chunk_context")
        await setup()

        employee_list = self._make_document("Employee List")
        car_list = self._make_document("Car List")

        # (document, chunk_index, chunk text, summary text)
        summary_specs = [
            (employee_list, 0, "Steve Rodger", "S.R."),
            (employee_list, 1, "Mike Broski", "M.B."),
            (employee_list, 2, "Christina Mayer", "C.M."),
            (car_list, 0, "Range Rover", "R.R."),
            (car_list, 1, "Hyundai", "H.Y."),
            (car_list, 2, "Chrysler", "C.H."),
        ]
        summaries = [self._make_summary(*spec) for spec in summary_specs]

        # Only the summaries are passed in; each one references its chunk.
        await add_data_points(summaries)

        retriever = SummariesRetriever(top_k=20)

        context = await retriever.get_context("Christina")

        assert context[0]["text"] == "C.M.", "Failed to get Christina Mayer"

    @pytest.mark.asyncio
    async def test_chunk_context_on_empty_graph(self):
        """Check both empty-storage cases: missing collection and empty collection."""
        await self._reset_storage("test_chunk_context_on_empty_graph")

        retriever = SummariesRetriever()

        # No collection has been created yet, so retrieval is expected to raise.
        with pytest.raises(NoDataError):
            await retriever.get_context("Christina Mayer")

        vector_engine = get_vector_engine()
        await vector_engine.create_collection("TextSummary_text", payload_schema=TextSummary)

        # The collection now exists but holds nothing.
        context = await retriever.get_context("Christina Mayer")
        assert context == [], "Returned context should be empty on an empty graph"
|
||||
Loading…
Add table
Reference in a new issue