Merge branch 'dev' into multi-tenancy

This commit is contained in:
Igor Ilic 2025-11-07 15:51:44 +01:00
commit 9fb7f2c4cf
15 changed files with 279 additions and 147 deletions

View file

@@ -1,5 +1,5 @@
import asyncio
from typing import Any, Optional, List
from typing import Any, Optional, List, Type
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.entities.BaseEntityExtractor import BaseEntityExtractor
@@ -85,8 +85,12 @@ class EntityCompletionRetriever(BaseRetriever):
return None
async def get_completion(
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
) -> List[str]:
self,
query: str,
context: Optional[Any] = None,
session_id: Optional[str] = None,
response_model: Type = str,
) -> List[Any]:
"""
Generate completion using provided context or fetch new context.
@@ -102,6 +106,7 @@ class EntityCompletionRetriever(BaseRetriever):
fetched if not provided. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
- response_model (Type): The Pydantic model type for structured output. (default str)
Returns:
--------
@@ -133,6 +138,7 @@ class EntityCompletionRetriever(BaseRetriever):
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
conversation_history=conversation_history,
response_model=response_model,
),
)
else:
@@ -141,6 +147,7 @@ class EntityCompletionRetriever(BaseRetriever):
context=context,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
response_model=response_model,
)
if session_save:

View file

@@ -1,4 +1,4 @@
from typing import List, Optional
from typing import Any, List, Optional, Type
from abc import ABC, abstractmethod
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
@@ -14,7 +14,11 @@ class BaseGraphRetriever(ABC):
@abstractmethod
async def get_completion(
self, query: str, context: Optional[List[Edge]] = None, session_id: Optional[str] = None
) -> str:
self,
query: str,
context: Optional[List[Edge]] = None,
session_id: Optional[str] = None,
response_model: Type = str,
) -> List[Any]:
"""Generates a response using the query and optional context (triplets)."""
pass

View file

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Optional
from typing import Any, Optional, Type, List
class BaseRetriever(ABC):
@@ -12,7 +12,11 @@ class BaseRetriever(ABC):
@abstractmethod
async def get_completion(
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
) -> Any:
self,
query: str,
context: Optional[Any] = None,
session_id: Optional[str] = None,
response_model: Type = str,
) -> List[Any]:
"""Generates a response using the query and optional context."""
pass

View file

@@ -1,5 +1,5 @@
import asyncio
from typing import Any, Optional
from typing import Any, Optional, Type, List
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.vector import get_vector_engine
@@ -75,8 +75,12 @@ class CompletionRetriever(BaseRetriever):
raise NoDataError("No data found in the system, please add data first.") from error
async def get_completion(
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
) -> str:
self,
query: str,
context: Optional[Any] = None,
session_id: Optional[str] = None,
response_model: Type = str,
) -> List[Any]:
"""
Generates an LLM completion using the context.
@@ -91,6 +95,7 @@ class CompletionRetriever(BaseRetriever):
completion; if None, it retrieves the context for the query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
- response_model (Type): The Pydantic model type for structured output. (default str)
Returns:
--------
@@ -118,6 +123,7 @@ class CompletionRetriever(BaseRetriever):
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
conversation_history=conversation_history,
response_model=response_model,
),
)
else:
@@ -127,6 +133,7 @@ class CompletionRetriever(BaseRetriever):
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
response_model=response_model,
)
if session_save:
@@ -137,4 +144,4 @@ class CompletionRetriever(BaseRetriever):
session_id=session_id,
)
return completion
return [completion]

View file

@@ -1,5 +1,5 @@
import asyncio
from typing import Optional, List, Type
from typing import Optional, List, Type, Any
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
from cognee.shared.logging_utils import get_logger
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
@@ -56,7 +56,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
context: Optional[List[Edge]] = None,
session_id: Optional[str] = None,
context_extension_rounds=4,
) -> List[str]:
response_model: Type = str,
) -> List[Any]:
"""
Extends the context for a given query by retrieving related triplets and generating new
completions based on them.
@@ -76,6 +77,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
defaults to 'default_session'. (default None)
- context_extension_rounds: The maximum number of rounds to extend the context with
new triplets before halting. (default 4)
- response_model (Type): The Pydantic model type for structured output. (default str)
Returns:
--------
@@ -143,6 +145,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
conversation_history=conversation_history,
response_model=response_model,
),
)
else:
@@ -152,6 +155,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
response_model=response_model,
)
if self.save_interaction and context_text and triplets and completion:

View file

@@ -7,7 +7,7 @@ from cognee.shared.logging_utils import get_logger
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.modules.retrieval.utils.completion import (
generate_structured_completion,
generate_completion,
summarize_text,
)
from cognee.modules.retrieval.utils.session_cache import (
@@ -44,7 +44,6 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
questions based on reasoning. The public methods are:
- get_completion
- get_structured_completion
Instance variables include:
- validation_system_prompt_path
@@ -121,7 +120,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
triplets += await self.get_context(followup_question)
context_text = await self.resolve_edges_to_text(list(set(triplets)))
completion = await generate_structured_completion(
completion = await generate_completion(
query=query,
context=context_text,
user_prompt_path=self.user_prompt_path,
@@ -165,24 +164,28 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
return completion, context_text, triplets
async def get_structured_completion(
async def get_completion(
self,
query: str,
context: Optional[List[Edge]] = None,
session_id: Optional[str] = None,
max_iter: int = 4,
max_iter=4,
response_model: Type = str,
) -> Any:
) -> List[Any]:
"""
Generate structured completion responses based on a user query and contextual information.
Generate completion responses based on a user query and contextual information.
This method applies the same chain-of-thought logic as get_completion but returns
This method interacts with a language model client to retrieve a structured response,
using a series of iterations to refine the answers and generate follow-up questions
based on reasoning derived from previous outputs. It raises exceptions if the context
retrieval fails or if the model encounters issues in generating outputs. It returns
structured output using the provided response model.
Parameters:
-----------
- query (str): The user's query to be processed and answered.
- context (Optional[List[Edge]]): Optional context that may assist in answering the query.
- context (Optional[Any]): Optional context that may assist in answering the query.
If not provided, it will be fetched based on the query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
@@ -192,7 +195,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
Returns:
--------
- Any: The generated structured completion based on the response model.
- List[str]: A list containing the generated answer to the user's query.
"""
# Check if session saving is enabled
cache_config = CacheConfig()
@@ -228,45 +232,4 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
session_id=session_id,
)
return completion
async def get_completion(
self,
query: str,
context: Optional[List[Edge]] = None,
session_id: Optional[str] = None,
max_iter=4,
) -> List[str]:
"""
Generate completion responses based on a user query and contextual information.
This method interacts with a language model client to retrieve a structured response,
using a series of iterations to refine the answers and generate follow-up questions
based on reasoning derived from previous outputs. It raises exceptions if the context
retrieval fails or if the model encounters issues in generating outputs.
Parameters:
-----------
- query (str): The user's query to be processed and answered.
- context (Optional[Any]): Optional context that may assist in answering the query.
If not provided, it will be fetched based on the query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
- max_iter: The maximum number of iterations to refine the answer and generate
follow-up questions. (default 4)
Returns:
--------
- List[str]: A list containing the generated answer to the user's query.
"""
completion = await self.get_structured_completion(
query=query,
context=context,
session_id=session_id,
max_iter=max_iter,
response_model=str,
)
return [completion]

View file

@@ -146,7 +146,8 @@ class GraphCompletionRetriever(BaseGraphRetriever):
query: str,
context: Optional[List[Edge]] = None,
session_id: Optional[str] = None,
) -> List[str]:
response_model: Type = str,
) -> List[Any]:
"""
Generates a completion using graph connections context based on a query.
@@ -188,6 +189,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
conversation_history=conversation_history,
response_model=response_model,
),
)
else:
@@ -197,6 +199,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
response_model=response_model,
)
if self.save_interaction and context and triplets and completion:

View file

@@ -146,8 +146,12 @@ class TemporalRetriever(GraphCompletionRetriever):
return self.descriptions_to_string(top_k_events)
async def get_completion(
self, query: str, context: Optional[str] = None, session_id: Optional[str] = None
) -> List[str]:
self,
query: str,
context: Optional[str] = None,
session_id: Optional[str] = None,
response_model: Type = str,
) -> List[Any]:
"""
Generates a response using the query and optional context.
@@ -159,6 +163,7 @@ class TemporalRetriever(GraphCompletionRetriever):
retrieved based on the query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
- response_model (Type): The Pydantic model type for structured output. (default str)
Returns:
--------
@@ -186,6 +191,7 @@ class TemporalRetriever(GraphCompletionRetriever):
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
conversation_history=conversation_history,
response_model=response_model,
),
)
else:
@@ -194,6 +200,7 @@ class TemporalRetriever(GraphCompletionRetriever):
context=context,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
response_model=response_model,
)
if session_save:

View file

@@ -3,7 +3,7 @@ from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt
async def generate_structured_completion(
async def generate_completion(
query: str,
context: str,
user_prompt_path: str,
@@ -12,7 +12,7 @@ async def generate_structured_completion(
conversation_history: Optional[str] = None,
response_model: Type = str,
) -> Any:
"""Generates a structured completion using LLM with given context and prompts."""
"""Generates a completion using LLM with given context and prompts."""
args = {"question": query, "context": context}
user_prompt = render_prompt(user_prompt_path, args)
system_prompt = system_prompt if system_prompt else read_query_prompt(system_prompt_path)
@@ -28,26 +28,6 @@ async def generate_structured_completion(
)
async def generate_completion(
query: str,
context: str,
user_prompt_path: str,
system_prompt_path: str,
system_prompt: Optional[str] = None,
conversation_history: Optional[str] = None,
) -> str:
"""Generates a completion using LLM with given context and prompts."""
return await generate_structured_completion(
query=query,
context=context,
user_prompt_path=user_prompt_path,
system_prompt_path=system_prompt_path,
system_prompt=system_prompt,
conversation_history=conversation_history,
response_model=str,
)
async def summarize_text(
text: str,
system_prompt_path: str = "summarize_search_results.txt",

View file

@@ -61,7 +61,7 @@ async def _generate_improved_answer_for_single_interaction(
)
retrieved_context = await retriever.get_context(query_text)
completion = await retriever.get_structured_completion(
completion = await retriever.get_completion(
query=query_text,
context=retrieved_context,
response_model=ImprovedAnswerResponse,
@@ -70,9 +70,9 @@ async def _generate_improved_answer_for_single_interaction(
new_context_text = await retriever.resolve_edges_to_text(retrieved_context)
if completion:
enrichment.improved_answer = completion.answer
enrichment.improved_answer = completion[0].answer
enrichment.new_context = new_context_text
enrichment.explanation = completion.explanation
enrichment.explanation = completion[0].explanation
return enrichment
else:
logger.warning(

View file

@@ -2,7 +2,6 @@ import os
import pytest
import pathlib
from typing import Optional, Union
from pydantic import BaseModel
import cognee
from cognee.low_level import setup, DataPoint
@@ -11,11 +10,6 @@ from cognee.tasks.storage import add_data_points
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
class TestAnswer(BaseModel):
answer: str
explanation: str
class TestGraphCompletionCoTRetriever:
@pytest.mark.asyncio
async def test_graph_completion_cot_context_simple(self):
@@ -174,48 +168,3 @@ class TestGraphCompletionCoTRetriever:
assert all(isinstance(item, str) and item.strip() for item in answer), (
"Answer must contain only non-empty strings"
)
@pytest.mark.asyncio
async def test_get_structured_completion(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_get_structured_completion"
)
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_get_structured_completion"
)
cognee.config.data_root_directory(data_directory_path)
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
await setup()
class Company(DataPoint):
name: str
class Person(DataPoint):
name: str
works_for: Company
company1 = Company(name="Figma")
person1 = Person(name="Steve Rodger", works_for=company1)
entities = [company1, person1]
await add_data_points(entities)
retriever = GraphCompletionCotRetriever()
# Test with string response model (default)
string_answer = await retriever.get_structured_completion("Who works at Figma?")
assert isinstance(string_answer, str), f"Expected str, got {type(string_answer).__name__}"
assert string_answer.strip(), "Answer should not be empty"
# Test with structured response model
structured_answer = await retriever.get_structured_completion(
"Who works at Figma?", response_model=TestAnswer
)
assert isinstance(structured_answer, TestAnswer), (
f"Expected TestAnswer, got {type(structured_answer).__name__}"
)
assert structured_answer.answer.strip(), "Answer field should not be empty"
assert structured_answer.explanation.strip(), "Explanation field should not be empty"

View file

@@ -3,6 +3,7 @@ from typing import List
import pytest
import pathlib
import cognee
from cognee.low_level import setup
from cognee.tasks.storage import add_data_points
from cognee.infrastructure.databases.vector import get_vector_engine

View file

@@ -0,0 +1,204 @@
import asyncio
import pytest
import cognee
import pathlib
import os
from pydantic import BaseModel
from cognee.low_level import setup, DataPoint
from cognee.tasks.storage import add_data_points
from cognee.modules.chunking.models import DocumentChunk
from cognee.modules.data.processing.document_types import TextDocument
from cognee.modules.engine.models import Entity, EntityType
from cognee.modules.retrieval.entity_extractors.DummyEntityExtractor import DummyEntityExtractor
from cognee.modules.retrieval.context_providers.DummyContextProvider import DummyContextProvider
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
GraphCompletionContextExtensionRetriever,
)
from cognee.modules.retrieval.EntityCompletionRetriever import EntityCompletionRetriever
from cognee.modules.retrieval.temporal_retriever import TemporalRetriever
from cognee.modules.retrieval.completion_retriever import CompletionRetriever
class TestAnswer(BaseModel):
answer: str
explanation: str
def _assert_string_answer(answer: list[str]):
assert isinstance(answer, list), f"Expected str, got {type(answer).__name__}"
assert all(isinstance(item, str) and item.strip() for item in answer), "Items should be strings"
assert all(item.strip() for item in answer), "Items should not be empty"
def _assert_structured_answer(answer: list[TestAnswer]):
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
assert all(isinstance(x, TestAnswer) for x in answer), "Items should be TestAnswer"
assert all(x.answer.strip() for x in answer), "Answer text should not be empty"
assert all(x.explanation.strip() for x in answer), "Explanation should not be empty"
async def _test_get_structured_graph_completion_cot():
retriever = GraphCompletionCotRetriever()
# Test with string response model (default)
string_answer = await retriever.get_completion("Who works at Figma?")
_assert_string_answer(string_answer)
# Test with structured response model
structured_answer = await retriever.get_completion(
"Who works at Figma?", response_model=TestAnswer
)
_assert_structured_answer(structured_answer)
async def _test_get_structured_graph_completion():
retriever = GraphCompletionRetriever()
# Test with string response model (default)
string_answer = await retriever.get_completion("Who works at Figma?")
_assert_string_answer(string_answer)
# Test with structured response model
structured_answer = await retriever.get_completion(
"Who works at Figma?", response_model=TestAnswer
)
_assert_structured_answer(structured_answer)
async def _test_get_structured_graph_completion_temporal():
retriever = TemporalRetriever()
# Test with string response model (default)
string_answer = await retriever.get_completion("When did Steve start working at Figma?")
_assert_string_answer(string_answer)
# Test with structured response model
structured_answer = await retriever.get_completion(
"When did Steve start working at Figma??", response_model=TestAnswer
)
_assert_structured_answer(structured_answer)
async def _test_get_structured_graph_completion_rag():
retriever = CompletionRetriever()
# Test with string response model (default)
string_answer = await retriever.get_completion("Where does Steve work?")
_assert_string_answer(string_answer)
# Test with structured response model
structured_answer = await retriever.get_completion(
"Where does Steve work?", response_model=TestAnswer
)
_assert_structured_answer(structured_answer)
async def _test_get_structured_graph_completion_context_extension():
retriever = GraphCompletionContextExtensionRetriever()
# Test with string response model (default)
string_answer = await retriever.get_completion("Who works at Figma?")
_assert_string_answer(string_answer)
# Test with structured response model
structured_answer = await retriever.get_completion(
"Who works at Figma?", response_model=TestAnswer
)
_assert_structured_answer(structured_answer)
async def _test_get_structured_entity_completion():
retriever = EntityCompletionRetriever(DummyEntityExtractor(), DummyContextProvider())
# Test with string response model (default)
string_answer = await retriever.get_completion("Who is Albert Einstein?")
_assert_string_answer(string_answer)
# Test with structured response model
structured_answer = await retriever.get_completion(
"Who is Albert Einstein?", response_model=TestAnswer
)
_assert_structured_answer(structured_answer)
class TestStructuredOutputCompletion:
@pytest.mark.asyncio
async def test_get_structured_completion(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_get_structured_completion"
)
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_get_structured_completion"
)
cognee.config.data_root_directory(data_directory_path)
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
await setup()
class Company(DataPoint):
name: str
class Person(DataPoint):
name: str
works_for: Company
works_since: int
company1 = Company(name="Figma")
person1 = Person(name="Steve Rodger", works_for=company1, works_since=2015)
entities = [company1, person1]
await add_data_points(entities)
document = TextDocument(
name="Steve Rodger's career",
raw_data_location="somewhere",
external_metadata="",
mime_type="text/plain",
)
chunk1 = DocumentChunk(
text="Steve Rodger",
chunk_size=2,
chunk_index=0,
cut_type="sentence_end",
is_part_of=document,
contains=[],
)
chunk2 = DocumentChunk(
text="Mike Broski",
chunk_size=2,
chunk_index=1,
cut_type="sentence_end",
is_part_of=document,
contains=[],
)
chunk3 = DocumentChunk(
text="Christina Mayer",
chunk_size=2,
chunk_index=2,
cut_type="sentence_end",
is_part_of=document,
contains=[],
)
entities = [chunk1, chunk2, chunk3]
await add_data_points(entities)
entity_type = EntityType(name="Person", description="A human individual")
entity = Entity(name="Albert Einstein", is_a=entity_type, description="A famous physicist")
entities = [entity]
await add_data_points(entities)
await _test_get_structured_graph_completion_cot()
await _test_get_structured_graph_completion()
await _test_get_structured_graph_completion_temporal()
await _test_get_structured_graph_completion_rag()
await _test_get_structured_graph_completion_context_extension()
await _test_get_structured_entity_completion()

View file

@@ -13,7 +13,7 @@ from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.modules.retrieval.summaries_retriever import SummariesRetriever
class TextSummariesRetriever:
class TestSummariesRetriever:
@pytest.mark.asyncio
async def test_chunk_context(self):
system_directory_path = os.path.join(

View file

@@ -1,4 +1,3 @@
import asyncio
from types import SimpleNamespace
import pytest