diff --git a/cognee/modules/retrieval/EntityCompletionRetriever.py b/cognee/modules/retrieval/EntityCompletionRetriever.py index 6086977ce..14996f902 100644 --- a/cognee/modules/retrieval/EntityCompletionRetriever.py +++ b/cognee/modules/retrieval/EntityCompletionRetriever.py @@ -1,5 +1,5 @@ import asyncio -from typing import Any, Optional, List +from typing import Any, Optional, List, Type from cognee.shared.logging_utils import get_logger from cognee.infrastructure.entities.BaseEntityExtractor import BaseEntityExtractor @@ -85,8 +85,12 @@ class EntityCompletionRetriever(BaseRetriever): return None async def get_completion( - self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None - ) -> List[str]: + self, + query: str, + context: Optional[Any] = None, + session_id: Optional[str] = None, + response_model: Type = str, + ) -> List[Any]: """ Generate completion using provided context or fetch new context. @@ -102,6 +106,7 @@ class EntityCompletionRetriever(BaseRetriever): fetched if not provided. (default None) - session_id (Optional[str]): Optional session identifier for caching. If None, defaults to 'default_session'. (default None) + - response_model (Type): The Pydantic model type for structured output. (default str) Returns: -------- @@ -133,6 +138,7 @@ class EntityCompletionRetriever(BaseRetriever): user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, conversation_history=conversation_history, + response_model=response_model, ), ) else: @@ -141,6 +147,7 @@ class EntityCompletionRetriever(BaseRetriever): context=context, user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, + response_model=response_model, ) if session_save: diff --git a/cognee/modules/retrieval/base_graph_retriever.py b/cognee/modules/retrieval/base_graph_retriever.py index b0abc2991..b203309ba 100644 --- a/cognee/modules/retrieval/base_graph_retriever.py +++ b/cognee/modules/retrieval/base_graph_retriever.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Any, List, Optional, Type from abc import ABC, abstractmethod from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge @@ -14,7 +14,11 @@ class BaseGraphRetriever(ABC): @abstractmethod async def get_completion( - self, query: str, context: Optional[List[Edge]] = None, session_id: Optional[str] = None - ) -> str: + self, + query: str, + context: Optional[List[Edge]] = None, + session_id: Optional[str] = None, + response_model: Type = str, + ) -> List[Any]: """Generates a response using the query and optional context (triplets).""" pass diff --git a/cognee/modules/retrieval/base_retriever.py b/cognee/modules/retrieval/base_retriever.py index 1533dd44f..b88c741b8 100644 --- a/cognee/modules/retrieval/base_retriever.py +++ b/cognee/modules/retrieval/base_retriever.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, Optional +from typing import Any, Optional, Type, List class BaseRetriever(ABC): @@ -12,7 +12,11 @@ class BaseRetriever(ABC): @abstractmethod async def get_completion( - self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None - ) -> Any: + self, + query: str, + context: Optional[Any] = None, + session_id: Optional[str] = None, + response_model: Type = str, + ) -> List[Any]: """Generates a response using the query and optional context.""" pass diff --git a/cognee/modules/retrieval/completion_retriever.py b/cognee/modules/retrieval/completion_retriever.py index bb568924d..126ebcab8 
100644 --- a/cognee/modules/retrieval/completion_retriever.py +++ b/cognee/modules/retrieval/completion_retriever.py @@ -1,5 +1,5 @@ import asyncio -from typing import Any, Optional +from typing import Any, Optional, Type, List from cognee.shared.logging_utils import get_logger from cognee.infrastructure.databases.vector import get_vector_engine @@ -75,8 +75,12 @@ class CompletionRetriever(BaseRetriever): raise NoDataError("No data found in the system, please add data first.") from error async def get_completion( - self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None - ) -> str: + self, + query: str, + context: Optional[Any] = None, + session_id: Optional[str] = None, + response_model: Type = str, + ) -> List[Any]: """ Generates an LLM completion using the context. @@ -91,6 +95,7 @@ class CompletionRetriever(BaseRetriever): completion; if None, it retrieves the context for the query. (default None) - session_id (Optional[str]): Optional session identifier for caching. If None, defaults to 'default_session'. (default None) + - response_model (Type): The Pydantic model type for structured output. (default str) Returns: -------- @@ -118,6 +123,7 @@ class CompletionRetriever(BaseRetriever): system_prompt_path=self.system_prompt_path, system_prompt=self.system_prompt, conversation_history=conversation_history, + response_model=response_model, ), ) else: @@ -127,6 +133,7 @@ class CompletionRetriever(BaseRetriever): user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, system_prompt=self.system_prompt, + response_model=response_model, ) if session_save: @@ -137,4 +144,4 @@ class CompletionRetriever(BaseRetriever): session_id=session_id, ) - return completion + return [completion] diff --git a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py index 58b6b586f..b07d11fd2 100644 --- a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py +++ b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py @@ -1,5 +1,5 @@ import asyncio -from typing import Optional, List, Type +from typing import Optional, List, Type, Any from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge from cognee.shared.logging_utils import get_logger from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever @@ -56,7 +56,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): context: Optional[List[Edge]] = None, session_id: Optional[str] = None, context_extension_rounds=4, - ) -> List[str]: + response_model: Type = str, + ) -> List[Any]: """ Extends the context for a given query by retrieving related triplets and generating new completions based on them. @@ -76,6 +77,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): defaults to 'default_session'. (default None) - context_extension_rounds: The maximum number of rounds to extend the context with new triplets before halting. (default 4) + - response_model (Type): The Pydantic model type for structured output. 
(default str) Returns: -------- @@ -143,6 +145,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): system_prompt_path=self.system_prompt_path, system_prompt=self.system_prompt, conversation_history=conversation_history, + response_model=response_model, ), ) else: @@ -152,6 +155,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): context=context, user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, system_prompt=self.system_prompt, + response_model=response_model, ) if self.save_interaction and context_text and triplets and completion: diff --git a/cognee/modules/retrieval/graph_completion_cot_retriever.py b/cognee/modules/retrieval/graph_completion_cot_retriever.py index 299db6855..eb8f502cb 100644 --- a/cognee/modules/retrieval/graph_completion_cot_retriever.py +++ b/cognee/modules/retrieval/graph_completion_cot_retriever.py @@ -7,7 +7,7 @@ from cognee.shared.logging_utils import get_logger from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever from cognee.modules.retrieval.utils.completion import ( - generate_structured_completion, + generate_completion, summarize_text, ) from cognee.modules.retrieval.utils.session_cache import ( @@ -44,7 +44,6 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): questions based on reasoning. The public methods are: - get_completion - - get_structured_completion Instance variables include: - validation_system_prompt_path @@ -121,7 +120,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): triplets += await self.get_context(followup_question) context_text = await self.resolve_edges_to_text(list(set(triplets))) - completion = await generate_structured_completion( + completion = await generate_completion( query=query, context=context_text, user_prompt_path=self.user_prompt_path, @@ -165,24 +164,28 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): return completion, context_text, triplets - async def get_structured_completion( + async def get_completion( self, query: str, context: Optional[List[Edge]] = None, session_id: Optional[str] = None, max_iter: int = 4, response_model: Type = str, - ) -> Any: + ) -> List[Any]: """ - Generate structured completion responses based on a user query and contextual information. + Generate completion responses based on a user query and contextual information. - This method applies the same chain-of-thought logic as get_completion but returns + This method interacts with a language model client to retrieve a structured response, + using a series of iterations to refine the answers and generate follow-up questions + based on reasoning derived from previous outputs. It raises exceptions if the context + retrieval fails or if the model encounters issues in generating outputs. It returns structured output using the provided response model. Parameters: ----------- + - query (str): The user's query to be processed and answered. - context (Optional[List[Edge]]): Optional context that may assist in answering the query. If not provided, it will be fetched based on the query. (default None) - session_id (Optional[str]): Optional session identifier for caching. If None, defaults to 'default_session'. (default None) @@ -192,7 +195,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): Returns: -------- - - Any: The generated structured completion based on the response model.
+ + - List[Any]: A list containing the generated answer, either as a plain string or as an instance of response_model. """ # Check if session saving is enabled cache_config = CacheConfig() @@ -228,45 +232,4 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): session_id=session_id, ) - return completion - - async def get_completion( - self, - query: str, - context: Optional[List[Edge]] = None, - session_id: Optional[str] = None, - max_iter=4, - ) -> List[str]: - """ - Generate completion responses based on a user query and contextual information. - - This method interacts with a language model client to retrieve a structured response, - using a series of iterations to refine the answers and generate follow-up questions - based on reasoning derived from previous outputs. It raises exceptions if the context - retrieval fails or if the model encounters issues in generating outputs. - - Parameters: - ----------- - - - query (str): The user's query to be processed and answered. - - context (Optional[Any]): Optional context that may assist in answering the query. - If not provided, it will be fetched based on the query. (default None) - - session_id (Optional[str]): Optional session identifier for caching. If None, - defaults to 'default_session'. (default None) - - max_iter: The maximum number of iterations to refine the answer and generate - follow-up questions. (default 4) - - Returns: - -------- - - - List[str]: A list containing the generated answer to the user's query. - """ - completion = await self.get_structured_completion( - query=query, - context=context, - session_id=session_id, - max_iter=max_iter, - response_model=str, - ) - return [completion] + return [completion] diff --git a/cognee/modules/retrieval/graph_completion_retriever.py b/cognee/modules/retrieval/graph_completion_retriever.py index b7ab4edae..df77a11ac 100644 --- a/cognee/modules/retrieval/graph_completion_retriever.py +++ b/cognee/modules/retrieval/graph_completion_retriever.py @@ -146,7 +146,8 @@ class GraphCompletionRetriever(BaseGraphRetriever): query: str, context: Optional[List[Edge]] = None, session_id: Optional[str] = None, - ) -> List[str]: + response_model: Type = str, + ) -> List[Any]: """ Generates a completion using graph connections context based on a query. @@ -188,6 +189,7 @@ class GraphCompletionRetriever(BaseGraphRetriever): system_prompt_path=self.system_prompt_path, system_prompt=self.system_prompt, conversation_history=conversation_history, + response_model=response_model, ), ) else: @@ -197,6 +199,7 @@ class GraphCompletionRetriever(BaseGraphRetriever): user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, system_prompt=self.system_prompt, + response_model=response_model, ) if self.save_interaction and context and triplets and completion: diff --git a/cognee/modules/retrieval/temporal_retriever.py b/cognee/modules/retrieval/temporal_retriever.py index ec68d37bb..f3da02c15 100644 --- a/cognee/modules/retrieval/temporal_retriever.py +++ b/cognee/modules/retrieval/temporal_retriever.py @@ -146,8 +146,12 @@ class TemporalRetriever(GraphCompletionRetriever): return self.descriptions_to_string(top_k_events) async def get_completion( - self, query: str, context: Optional[str] = None, session_id: Optional[str] = None - ) -> List[str]: + self, + query: str, + context: Optional[str] = None, + session_id: Optional[str] = None, + response_model: Type = str, + ) -> List[Any]: """ Generates a response using the query and optional context.
@@ -159,6 +163,7 @@ class TemporalRetriever(GraphCompletionRetriever): retrieved based on the query. (default None) - session_id (Optional[str]): Optional session identifier for caching. If None, defaults to 'default_session'. (default None) + - response_model (Type): The Pydantic model type for structured output. (default str) Returns: -------- @@ -186,6 +191,7 @@ class TemporalRetriever(GraphCompletionRetriever): user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, conversation_history=conversation_history, + response_model=response_model, ), ) else: @@ -194,6 +200,7 @@ class TemporalRetriever(GraphCompletionRetriever): context=context, user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, + response_model=response_model, ) if session_save: diff --git a/cognee/modules/retrieval/utils/completion.py b/cognee/modules/retrieval/utils/completion.py index db7a10252..c90ce77f4 100644 --- a/cognee/modules/retrieval/utils/completion.py +++ b/cognee/modules/retrieval/utils/completion.py @@ -3,7 +3,7 @@ from cognee.infrastructure.llm.LLMGateway import LLMGateway from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt -async def generate_structured_completion( +async def generate_completion( query: str, context: str, user_prompt_path: str, @@ -12,7 +12,7 @@ async def generate_structured_completion( conversation_history: Optional[str] = None, response_model: Type = str, ) -> Any: - """Generates a structured completion using LLM with given context and prompts.""" + """Generates a completion using LLM with given context and prompts.""" args = {"question": query, "context": context} user_prompt = render_prompt(user_prompt_path, args) system_prompt = system_prompt if system_prompt else read_query_prompt(system_prompt_path) @@ -28,26 +28,6 @@ async def generate_structured_completion( ) -async def generate_completion( - query: str, - context: str, - user_prompt_path: str, - system_prompt_path: str, - system_prompt: Optional[str] = None, - conversation_history: Optional[str] = None, -) -> str: - """Generates a completion using LLM with given context and prompts.""" - return await generate_structured_completion( - query=query, - context=context, - user_prompt_path=user_prompt_path, - system_prompt_path=system_prompt_path, - system_prompt=system_prompt, - conversation_history=conversation_history, - response_model=str, - ) - - async def summarize_text( text: str, system_prompt_path: str = "summarize_search_results.txt", diff --git a/cognee/tasks/feedback/generate_improved_answers.py b/cognee/tasks/feedback/generate_improved_answers.py index e439cf9e5..d2b143d29 100644 --- a/cognee/tasks/feedback/generate_improved_answers.py +++ b/cognee/tasks/feedback/generate_improved_answers.py @@ -61,7 +61,7 @@ async def _generate_improved_answer_for_single_interaction( ) retrieved_context = await retriever.get_context(query_text) - completion = await retriever.get_structured_completion( + completion = await retriever.get_completion( query=query_text, context=retrieved_context, response_model=ImprovedAnswerResponse, @@ -70,9 +70,9 @@ async def _generate_improved_answer_for_single_interaction( new_context_text = await retriever.resolve_edges_to_text(retrieved_context) if completion: - enrichment.improved_answer = completion.answer + enrichment.improved_answer = completion[0].answer enrichment.new_context = new_context_text - enrichment.explanation = completion.explanation + enrichment.explanation = completion[0].explanation return enrichment 
else: logger.warning( diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py index 7fcfe0d6b..206cfaf84 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py @@ -2,7 +2,6 @@ import os import pytest import pathlib from typing import Optional, Union -from pydantic import BaseModel import cognee from cognee.low_level import setup, DataPoint @@ -11,11 +10,6 @@ from cognee.tasks.storage import add_data_points from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever -class TestAnswer(BaseModel): - answer: str - explanation: str - - class TestGraphCompletionCoTRetriever: @pytest.mark.asyncio async def test_graph_completion_cot_context_simple(self): @@ -174,48 +168,3 @@ class TestGraphCompletionCoTRetriever: assert all(isinstance(item, str) and item.strip() for item in answer), ( "Answer must contain only non-empty strings" ) - - @pytest.mark.asyncio - async def test_get_structured_completion(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".cognee_system/test_get_structured_completion" - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".data_storage/test_get_structured_completion" - ) - cognee.config.data_root_directory(data_directory_path) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() - - class Company(DataPoint): - name: str - - class Person(DataPoint): - name: str - works_for: Company - - company1 = Company(name="Figma") - person1 = Person(name="Steve Rodger", works_for=company1) - - entities = [company1, person1] - await add_data_points(entities) - - retriever = GraphCompletionCotRetriever() - - # Test with string response model (default) - string_answer = await retriever.get_structured_completion("Who works at Figma?") - assert isinstance(string_answer, str), f"Expected str, got {type(string_answer).__name__}" - assert string_answer.strip(), "Answer should not be empty" - - # Test with structured response model - structured_answer = await retriever.get_structured_completion( - "Who works at Figma?", response_model=TestAnswer - ) - assert isinstance(structured_answer, TestAnswer), ( - f"Expected TestAnswer, got {type(structured_answer).__name__}" - ) - assert structured_answer.answer.strip(), "Answer field should not be empty" - assert structured_answer.explanation.strip(), "Explanation field should not be empty" diff --git a/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py b/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py index 252af8352..9bfed68f3 100644 --- a/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py @@ -3,6 +3,7 @@ from typing import List import pytest import pathlib import cognee + from cognee.low_level import setup from cognee.tasks.storage import add_data_points from cognee.infrastructure.databases.vector import get_vector_engine diff --git a/cognee/tests/unit/modules/retrieval/structured_output_test.py b/cognee/tests/unit/modules/retrieval/structured_output_test.py new file mode 100644 index 000000000..4ad3019ff --- /dev/null +++ b/cognee/tests/unit/modules/retrieval/structured_output_test.py @@ -0,0 +1,204 @@ 
+import asyncio + +import pytest +import cognee +import pathlib +import os + +from pydantic import BaseModel +from cognee.low_level import setup, DataPoint +from cognee.tasks.storage import add_data_points +from cognee.modules.chunking.models import DocumentChunk +from cognee.modules.data.processing.document_types import TextDocument +from cognee.modules.engine.models import Entity, EntityType +from cognee.modules.retrieval.entity_extractors.DummyEntityExtractor import DummyEntityExtractor +from cognee.modules.retrieval.context_providers.DummyContextProvider import DummyContextProvider +from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever +from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever +from cognee.modules.retrieval.graph_completion_context_extension_retriever import ( + GraphCompletionContextExtensionRetriever, +) +from cognee.modules.retrieval.EntityCompletionRetriever import EntityCompletionRetriever +from cognee.modules.retrieval.temporal_retriever import TemporalRetriever +from cognee.modules.retrieval.completion_retriever import CompletionRetriever + + +class TestAnswer(BaseModel): + answer: str + explanation: str + + +def _assert_string_answer(answer: list[str]): + assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" + assert all(isinstance(item, str) for item in answer), "Items should be strings" + assert all(item.strip() for item in answer), "Items should not be empty" + + +def _assert_structured_answer(answer: list[TestAnswer]): + assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" + assert all(isinstance(x, TestAnswer) for x in answer), "Items should be TestAnswer" + assert all(x.answer.strip() for x in answer), "Answer text should not be empty" + assert all(x.explanation.strip() for x in answer), "Explanation should not be empty" + + +async def _test_get_structured_graph_completion_cot(): + retriever = GraphCompletionCotRetriever() + + # Test with string response model (default) + string_answer = await retriever.get_completion("Who works at Figma?") + _assert_string_answer(string_answer) + + # Test with structured response model + structured_answer = await retriever.get_completion( + "Who works at Figma?", response_model=TestAnswer + ) + _assert_structured_answer(structured_answer) + + +async def _test_get_structured_graph_completion(): + retriever = GraphCompletionRetriever() + + # Test with string response model (default) + string_answer = await retriever.get_completion("Who works at Figma?") + _assert_string_answer(string_answer) + + # Test with structured response model + structured_answer = await retriever.get_completion( + "Who works at Figma?", response_model=TestAnswer + ) + _assert_structured_answer(structured_answer) + + +async def _test_get_structured_graph_completion_temporal(): + retriever = TemporalRetriever() + + # Test with string response model (default) + string_answer = await retriever.get_completion("When did Steve start working at Figma?") + _assert_string_answer(string_answer) + + # Test with structured response model + structured_answer = await retriever.get_completion( + "When did Steve start working at Figma?", response_model=TestAnswer + ) + _assert_structured_answer(structured_answer) + + +async def _test_get_structured_graph_completion_rag(): + retriever = CompletionRetriever() + + # Test with string response model (default) + string_answer = await retriever.get_completion("Where does Steve work?")
+ _assert_string_answer(string_answer) + + # Test with structured response model + structured_answer = await retriever.get_completion( + "Where does Steve work?", response_model=TestAnswer + ) + _assert_structured_answer(structured_answer) + + +async def _test_get_structured_graph_completion_context_extension(): + retriever = GraphCompletionContextExtensionRetriever() + + # Test with string response model (default) + string_answer = await retriever.get_completion("Who works at Figma?") + _assert_string_answer(string_answer) + + # Test with structured response model + structured_answer = await retriever.get_completion( + "Who works at Figma?", response_model=TestAnswer + ) + _assert_structured_answer(structured_answer) + + +async def _test_get_structured_entity_completion(): + retriever = EntityCompletionRetriever(DummyEntityExtractor(), DummyContextProvider()) + + # Test with string response model (default) + string_answer = await retriever.get_completion("Who is Albert Einstein?") + _assert_string_answer(string_answer) + + # Test with structured response model + structured_answer = await retriever.get_completion( + "Who is Albert Einstein?", response_model=TestAnswer + ) + _assert_structured_answer(structured_answer) + + +class TestStructuredOutputCompletion: + @pytest.mark.asyncio + async def test_get_structured_completion(self): + system_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".cognee_system/test_get_structured_completion" + ) + cognee.config.system_root_directory(system_directory_path) + data_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".data_storage/test_get_structured_completion" + ) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + class Company(DataPoint): + name: str + + class Person(DataPoint): + name: str + works_for: Company + works_since: int + + company1 = Company(name="Figma") + person1 = Person(name="Steve Rodger", works_for=company1, works_since=2015) + + entities = [company1, person1] + await add_data_points(entities) + + document = TextDocument( + name="Steve Rodger's career", + raw_data_location="somewhere", + external_metadata="", + mime_type="text/plain", + ) + + chunk1 = DocumentChunk( + text="Steve Rodger", + chunk_size=2, + chunk_index=0, + cut_type="sentence_end", + is_part_of=document, + contains=[], + ) + chunk2 = DocumentChunk( + text="Mike Broski", + chunk_size=2, + chunk_index=1, + cut_type="sentence_end", + is_part_of=document, + contains=[], + ) + chunk3 = DocumentChunk( + text="Christina Mayer", + chunk_size=2, + chunk_index=2, + cut_type="sentence_end", + is_part_of=document, + contains=[], + ) + + entities = [chunk1, chunk2, chunk3] + await add_data_points(entities) + + entity_type = EntityType(name="Person", description="A human individual") + entity = Entity(name="Albert Einstein", is_a=entity_type, description="A famous physicist") + + entities = [entity] + await add_data_points(entities) + + await _test_get_structured_graph_completion_cot() + await _test_get_structured_graph_completion() + await _test_get_structured_graph_completion_temporal() + await _test_get_structured_graph_completion_rag() + await _test_get_structured_graph_completion_context_extension() + await _test_get_structured_entity_completion() diff --git a/cognee/tests/unit/modules/retrieval/summaries_retriever_test.py b/cognee/tests/unit/modules/retrieval/summaries_retriever_test.py index fc96081bf..5f4b93425 100644 --- 
a/cognee/tests/unit/modules/retrieval/summaries_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/summaries_retriever_test.py @@ -13,7 +13,7 @@ from cognee.modules.retrieval.exceptions.exceptions import NoDataError from cognee.modules.retrieval.summaries_retriever import SummariesRetriever -class TextSummariesRetriever: +class TestSummariesRetriever: @pytest.mark.asyncio async def test_chunk_context(self): system_directory_path = os.path.join( diff --git a/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py b/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py index a322cb237..c3c6a47f6 100644 --- a/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py @@ -1,4 +1,3 @@ -import asyncio from types import SimpleNamespace import pytest
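A minimal usage sketch of the API this patch converges on: every retriever's `get_completion` now accepts an optional `response_model` and returns a list, holding plain strings by default or instances of the given Pydantic model otherwise. The `CareerFact` model and the `main` wiring below are hypothetical illustrations, not part of the change, and they assume a cognee knowledge graph has already been built for the queried data.

```python
import asyncio

from pydantic import BaseModel

from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever


class CareerFact(BaseModel):
    # Hypothetical response model, mirroring the TestAnswer model in the new tests.
    answer: str
    explanation: str


async def main() -> None:
    # Assumes data has already been added and cognified.
    retriever = GraphCompletionRetriever()

    # Default behaviour: response_model=str, so the returned list holds plain strings.
    plain = await retriever.get_completion("Who works at Figma?")
    print(plain[0])

    # Structured output: the returned list holds CareerFact instances instead.
    structured = await retriever.get_completion(
        "Who works at Figma?", response_model=CareerFact
    )
    print(structured[0].answer, "-", structured[0].explanation)


if __name__ == "__main__":
    asyncio.run(main())
```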