diff --git a/cognee/tests/unit/modules/retrieval/chunks_retriever_test.py b/cognee/tests/unit/modules/retrieval/chunks_retriever_test.py index f763cafd6..44786f79d 100644 --- a/cognee/tests/unit/modules/retrieval/chunks_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/chunks_retriever_test.py @@ -1,7 +1,7 @@ import os import pytest import pathlib - +from typing import List import cognee from cognee.low_level import setup from cognee.tasks.storage import add_data_points @@ -10,6 +10,20 @@ from cognee.modules.chunking.models import DocumentChunk from cognee.modules.data.processing.document_types import TextDocument from cognee.modules.retrieval.exceptions.exceptions import NoDataError from cognee.modules.retrieval.chunks_retriever import ChunksRetriever +from cognee.infrastructure.engine import DataPoint +from cognee.modules.data.processing.document_types import Document +from cognee.modules.engine.models import Entity + + +class DocumentChunkWithEntities(DataPoint): + text: str + chunk_size: int + chunk_index: int + cut_type: str + is_part_of: Document + contains: List[Entity] = None + + metadata: dict = {"index_fields": ["text"]} class TestChunksRetriever: @@ -179,7 +193,9 @@ class TestChunksRetriever: await retriever.get_context("Christina Mayer") vector_engine = get_vector_engine() - await vector_engine.create_collection("DocumentChunk_text", payload_schema=DocumentChunk) + await vector_engine.create_collection( + "DocumentChunk_text", payload_schema=DocumentChunkWithEntities + ) context = await retriever.get_context("Christina Mayer") assert len(context) == 0, "Found chunks when none should exist" diff --git a/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py b/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py index 356aed4d3..252af8352 100644 --- a/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py @@ -1,7 +1,7 @@ import os +from typing import List import pytest import pathlib - import cognee from cognee.low_level import setup from cognee.tasks.storage import add_data_points @@ -10,6 +10,20 @@ from cognee.modules.chunking.models import DocumentChunk from cognee.modules.data.processing.document_types import TextDocument from cognee.modules.retrieval.exceptions.exceptions import NoDataError from cognee.modules.retrieval.completion_retriever import CompletionRetriever +from cognee.infrastructure.engine import DataPoint +from cognee.modules.data.processing.document_types import Document +from cognee.modules.engine.models import Entity + + +class DocumentChunkWithEntities(DataPoint): + text: str + chunk_size: int + chunk_index: int + cut_type: str + is_part_of: Document + contains: List[Entity] = None + + metadata: dict = {"index_fields": ["text"]} class TestRAGCompletionRetriever: @@ -182,7 +196,9 @@ class TestRAGCompletionRetriever: await retriever.get_context("Christina Mayer") vector_engine = get_vector_engine() - await vector_engine.create_collection("DocumentChunk_text", payload_schema=DocumentChunk) + await vector_engine.create_collection( + "DocumentChunk_text", payload_schema=DocumentChunkWithEntities + ) context = await retriever.get_context("Christina Mayer") assert context == "", "Returned context should be empty on an empty graph"