cognee/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py

204 lines
6.4 KiB
Python

import os
from typing import List
import pytest
import pathlib
import cognee
from cognee.low_level import setup
from cognee.tasks.storage import add_data_points
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.chunking.models import DocumentChunk
from cognee.modules.data.processing.document_types import TextDocument
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.modules.retrieval.completion_retriever import CompletionRetriever
from cognee.infrastructure.engine import DataPoint
from cognee.modules.data.processing.document_types import Document
from cognee.modules.engine.models import Entity
class DocumentChunkWithEntities(DataPoint):
text: str
chunk_size: int
chunk_index: int
cut_type: str
is_part_of: Document
contains: List[Entity] = None
metadata: dict = {"index_fields": ["text"]}
class TestRAGCompletionRetriever:
@pytest.mark.asyncio
async def test_rag_completion_context_simple(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_rag_completion_context_simple"
)
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_rag_completion_context_simple"
)
cognee.config.data_root_directory(data_directory_path)
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
await setup()
document = TextDocument(
name="Steve Rodger's career",
raw_data_location="somewhere",
external_metadata="",
mime_type="text/plain",
)
chunk1 = DocumentChunk(
text="Steve Rodger",
chunk_size=2,
chunk_index=0,
cut_type="sentence_end",
is_part_of=document,
contains=[],
)
chunk2 = DocumentChunk(
text="Mike Broski",
chunk_size=2,
chunk_index=1,
cut_type="sentence_end",
is_part_of=document,
contains=[],
)
chunk3 = DocumentChunk(
text="Christina Mayer",
chunk_size=2,
chunk_index=2,
cut_type="sentence_end",
is_part_of=document,
contains=[],
)
entities = [chunk1, chunk2, chunk3]
await add_data_points(entities)
retriever = CompletionRetriever()
context = await retriever.get_context("Mike")
assert context == "Mike Broski", "Failed to get Mike Broski"
@pytest.mark.asyncio
async def test_rag_completion_context_complex(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_rag_completion_context_complex"
)
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_rag_completion_context_complex"
)
cognee.config.data_root_directory(data_directory_path)
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
await setup()
document1 = TextDocument(
name="Employee List",
raw_data_location="somewhere",
external_metadata="",
mime_type="text/plain",
)
document2 = TextDocument(
name="Car List",
raw_data_location="somewhere",
external_metadata="",
mime_type="text/plain",
)
chunk1 = DocumentChunk(
text="Steve Rodger",
chunk_size=2,
chunk_index=0,
cut_type="sentence_end",
is_part_of=document1,
contains=[],
)
chunk2 = DocumentChunk(
text="Mike Broski",
chunk_size=2,
chunk_index=1,
cut_type="sentence_end",
is_part_of=document1,
contains=[],
)
chunk3 = DocumentChunk(
text="Christina Mayer",
chunk_size=2,
chunk_index=2,
cut_type="sentence_end",
is_part_of=document1,
contains=[],
)
chunk4 = DocumentChunk(
text="Range Rover",
chunk_size=2,
chunk_index=0,
cut_type="sentence_end",
is_part_of=document2,
contains=[],
)
chunk5 = DocumentChunk(
text="Hyundai",
chunk_size=2,
chunk_index=1,
cut_type="sentence_end",
is_part_of=document2,
contains=[],
)
chunk6 = DocumentChunk(
text="Chrysler",
chunk_size=2,
chunk_index=2,
cut_type="sentence_end",
is_part_of=document2,
contains=[],
)
entities = [chunk1, chunk2, chunk3, chunk4, chunk5, chunk6]
await add_data_points(entities)
# TODO: top_k doesn't affect the output, it should be fixed.
retriever = CompletionRetriever(top_k=20)
context = await retriever.get_context("Christina")
assert context[0:15] == "Christina Mayer", "Failed to get Christina Mayer"
@pytest.mark.asyncio
async def test_get_rag_completion_context_on_empty_graph(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent,
".cognee_system/test_get_rag_completion_context_on_empty_graph",
)
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent,
".data_storage/test_get_rag_completion_context_on_empty_graph",
)
cognee.config.data_root_directory(data_directory_path)
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
retriever = CompletionRetriever()
with pytest.raises(NoDataError):
await retriever.get_context("Christina Mayer")
vector_engine = get_vector_engine()
await vector_engine.create_collection(
"DocumentChunk_text", payload_schema=DocumentChunkWithEntities
)
context = await retriever.get_context("Christina Mayer")
assert context == "", "Returned context should be empty on an empty graph"