changes based on PR comments

This commit is contained in:
Andrej Milicevic 2025-11-07 10:00:17 +01:00
parent 72ba8d0dcb
commit 4ab53c9d64
4 changed files with 21 additions and 15 deletions

View file

@ -1,4 +1,4 @@
from typing import List, Optional
from typing import Any, List, Optional, Type
from abc import ABC, abstractmethod
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
@ -14,7 +14,11 @@ class BaseGraphRetriever(ABC):
@abstractmethod
async def get_completion(
self, query: str, context: Optional[List[Edge]] = None, session_id: Optional[str] = None
) -> str:
self,
query: str,
context: Optional[List[Edge]] = None,
session_id: Optional[str] = None,
response_model: Type = str,
) -> List[Any]:
"""Generates a response using the query and optional context (triplets)."""
pass

View file

@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Optional
from typing import Any, Optional, Type, List
class BaseRetriever(ABC):
@ -12,7 +12,11 @@ class BaseRetriever(ABC):
@abstractmethod
async def get_completion(
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
) -> Any:
self,
query: str,
context: Optional[Any] = None,
session_id: Optional[str] = None,
response_model: Type = str,
) -> List[Any]:
"""Generates a response using the query and optional context."""
pass

View file

@ -196,11 +196,9 @@ class TestStructuredOutputCompletion:
entities = [entity]
await add_data_points(entities)
await asyncio.gather(
_test_get_structured_graph_completion_cot(),
_test_get_structured_graph_completion(),
_test_get_structured_graph_completion_temporal(),
_test_get_structured_graph_completion_rag(),
_test_get_structured_graph_completion_context_extension(),
_test_get_structured_entity_completion(),
)
await _test_get_structured_graph_completion_cot()
await _test_get_structured_graph_completion()
await _test_get_structured_graph_completion_temporal()
await _test_get_structured_graph_completion_rag()
await _test_get_structured_graph_completion_context_extension()
await _test_get_structured_entity_completion()

View file

@ -13,7 +13,7 @@ from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.modules.retrieval.summaries_retriever import SummariesRetriever
class TextSummariesRetriever:
class TestSummariesRetriever:
@pytest.mark.asyncio
async def test_chunk_context(self):
system_directory_path = os.path.join(