changes based on PR comments

This commit is contained in:
Andrej Milicevic 2025-11-07 10:00:17 +01:00
parent 72ba8d0dcb
commit 4ab53c9d64
4 changed files with 21 additions and 15 deletions

View file

@ -1,4 +1,4 @@
from typing import List, Optional from typing import Any, List, Optional, Type
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
@ -14,7 +14,11 @@ class BaseGraphRetriever(ABC):
@abstractmethod @abstractmethod
async def get_completion( async def get_completion(
self, query: str, context: Optional[List[Edge]] = None, session_id: Optional[str] = None self,
) -> str: query: str,
context: Optional[List[Edge]] = None,
session_id: Optional[str] = None,
response_model: Type = str,
) -> List[Any]:
"""Generates a response using the query and optional context (triplets).""" """Generates a response using the query and optional context (triplets)."""
pass pass

View file

@ -1,5 +1,5 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Optional from typing import Any, Optional, Type, List
class BaseRetriever(ABC): class BaseRetriever(ABC):
@ -12,7 +12,11 @@ class BaseRetriever(ABC):
@abstractmethod @abstractmethod
async def get_completion( async def get_completion(
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None self,
) -> Any: query: str,
context: Optional[Any] = None,
session_id: Optional[str] = None,
response_model: Type = str,
) -> List[Any]:
"""Generates a response using the query and optional context.""" """Generates a response using the query and optional context."""
pass pass

View file

@ -196,11 +196,9 @@ class TestStructuredOutputCompletion:
entities = [entity] entities = [entity]
await add_data_points(entities) await add_data_points(entities)
await asyncio.gather( await _test_get_structured_graph_completion_cot()
_test_get_structured_graph_completion_cot(), await _test_get_structured_graph_completion()
_test_get_structured_graph_completion(), await _test_get_structured_graph_completion_temporal()
_test_get_structured_graph_completion_temporal(), await _test_get_structured_graph_completion_rag()
_test_get_structured_graph_completion_rag(), await _test_get_structured_graph_completion_context_extension()
_test_get_structured_graph_completion_context_extension(), await _test_get_structured_entity_completion()
_test_get_structured_entity_completion(),
)

View file

@ -13,7 +13,7 @@ from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.modules.retrieval.summaries_retriever import SummariesRetriever from cognee.modules.retrieval.summaries_retriever import SummariesRetriever
class TextSummariesRetriever: class TestSummariesRetriever:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_chunk_context(self): async def test_chunk_context(self):
system_directory_path = os.path.join( system_directory_path = os.path.join(