Merge remote-tracking branch 'upstream/dev' into delete-last-acessed

This commit is contained in:
chinu0609 2025-10-29 20:22:34 +05:30
commit 7d4804ff7b
31 changed files with 1211 additions and 101 deletions

View file

@ -358,6 +358,34 @@ jobs:
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
run: uv run python ./cognee/tests/tasks/entity_extraction/entity_extraction_test.py
test-feedback-enrichment:
name: Test Feedback Enrichment
runs-on: ubuntu-22.04
steps:
- name: Check out repository
uses: actions/checkout@v4
- name: Cognee Setup
uses: ./.github/actions/cognee_setup
with:
python-version: '3.11.x'
- name: Dependencies already installed
run: echo "Dependencies already installed in setup"
- name: Run Feedback Enrichment Test
env:
ENV: 'dev'
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
run: uv run python ./cognee/tests/test_feedback_enrichment.py
run_conversation_sessions_test:
name: Conversation sessions test
runs-on: ubuntu-latest
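For local runs, the new CI job reduces to executing the test module with the same environment; a minimal sketch, assuming the LLM_*/EMBEDDING_* secrets from the workflow above are already exported:

import os
import subprocess

os.environ.setdefault("ENV", "dev")  # the workflow also injects LLM_* and EMBEDDING_* secrets
subprocess.run(
    ["uv", "run", "python", "./cognee/tests/test_feedback_enrichment.py"],
    check=True,  # fail loudly, mirroring the CI step
)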

View file

@ -10,6 +10,7 @@ from cognee.modules.users.methods import get_authenticated_user
from cognee.shared.utils import send_telemetry
from cognee.modules.pipelines.models import PipelineRunErrored
from cognee.shared.logging_utils import get_logger
from cognee import __version__ as cognee_version
logger = get_logger()
@ -63,7 +64,11 @@ def get_add_router() -> APIRouter:
send_telemetry(
"Add API Endpoint Invoked",
user.id,
additional_properties={"endpoint": "POST /v1/add", "node_set": node_set},
additional_properties={
"endpoint": "POST /v1/add",
"node_set": node_set,
"cognee_version": cognee_version,
},
)
from cognee.api.v1.add import add as cognee_add
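The same one-line change repeats across the routers in the files below: every send_telemetry call gains a cognee_version entry next to its existing properties. Schematically (the endpoint value is illustrative):

from cognee import __version__ as cognee_version
from cognee.shared.utils import send_telemetry

def tag_invocation(user_id, endpoint: str) -> None:
    # every router in this commit now adds cognee_version alongside its endpoint properties
    send_telemetry(
        "API Endpoint Invoked",
        user_id,
        additional_properties={"endpoint": endpoint, "cognee_version": cognee_version},
    )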

View file

@ -29,7 +29,7 @@ from cognee.modules.pipelines.queues.pipeline_run_info_queues import (
)
from cognee.shared.logging_utils import get_logger
from cognee.shared.utils import send_telemetry
from cognee import __version__ as cognee_version
logger = get_logger("api.cognify")
@ -98,6 +98,7 @@ def get_cognify_router() -> APIRouter:
user.id,
additional_properties={
"endpoint": "POST /v1/cognify",
"cognee_version": cognee_version,
},
)

View file

@ -24,6 +24,7 @@ from cognee.modules.users.permissions.methods import (
from cognee.modules.graph.methods import get_formatted_graph_data
from cognee.modules.pipelines.models import PipelineRunStatus
from cognee.shared.utils import send_telemetry
from cognee import __version__ as cognee_version
logger = get_logger()
@ -100,6 +101,7 @@ def get_datasets_router() -> APIRouter:
user.id,
additional_properties={
"endpoint": "GET /v1/datasets",
"cognee_version": cognee_version,
},
)
@ -147,6 +149,7 @@ def get_datasets_router() -> APIRouter:
user.id,
additional_properties={
"endpoint": "POST /v1/datasets",
"cognee_version": cognee_version,
},
)
@ -201,6 +204,7 @@ def get_datasets_router() -> APIRouter:
additional_properties={
"endpoint": f"DELETE /v1/datasets/{str(dataset_id)}",
"dataset_id": str(dataset_id),
"cognee_version": cognee_version,
},
)
@ -246,6 +250,7 @@ def get_datasets_router() -> APIRouter:
"endpoint": f"DELETE /v1/datasets/{str(dataset_id)}/data/{str(data_id)}",
"dataset_id": str(dataset_id),
"data_id": str(data_id),
"cognee_version": cognee_version,
},
)
@ -327,6 +332,7 @@ def get_datasets_router() -> APIRouter:
additional_properties={
"endpoint": f"GET /v1/datasets/{str(dataset_id)}/data",
"dataset_id": str(dataset_id),
"cognee_version": cognee_version,
},
)
@ -387,6 +393,7 @@ def get_datasets_router() -> APIRouter:
additional_properties={
"endpoint": "GET /v1/datasets/status",
"datasets": [str(dataset_id) for dataset_id in datasets],
"cognee_version": cognee_version,
},
)
@ -433,6 +440,7 @@ def get_datasets_router() -> APIRouter:
"endpoint": f"GET /v1/datasets/{str(dataset_id)}/data/{str(data_id)}/raw",
"dataset_id": str(dataset_id),
"data_id": str(data_id),
"cognee_version": cognee_version,
},
)

View file

@ -6,6 +6,7 @@ from cognee.shared.logging_utils import get_logger
from cognee.modules.users.models import User
from cognee.modules.users.methods import get_authenticated_user
from cognee.shared.utils import send_telemetry
from cognee import __version__ as cognee_version
logger = get_logger()
@ -39,6 +40,7 @@ def get_delete_router() -> APIRouter:
"endpoint": "DELETE /v1/delete",
"dataset_id": str(dataset_id),
"data_id": str(data_id),
"cognee_version": cognee_version,
},
)

View file

@ -12,6 +12,7 @@ from cognee.modules.users.methods import get_authenticated_user
from cognee.shared.utils import send_telemetry
from cognee.modules.pipelines.models import PipelineRunErrored
from cognee.shared.logging_utils import get_logger
from cognee import __version__ as cognee_version
logger = get_logger()
@ -73,7 +74,7 @@ def get_memify_router() -> APIRouter:
send_telemetry(
"Memify API Endpoint Invoked",
user.id,
additional_properties={"endpoint": "POST /v1/memify"},
additional_properties={"endpoint": "POST /v1/memify", "cognee_version": cognee_version},
)
if not payload.dataset_id and not payload.dataset_name:

View file

@ -7,6 +7,7 @@ from fastapi.responses import JSONResponse
from cognee.modules.users.models import User
from cognee.modules.users.methods import get_authenticated_user
from cognee.shared.utils import send_telemetry
from cognee import __version__ as cognee_version
def get_permissions_router() -> APIRouter:
@ -48,6 +49,7 @@ def get_permissions_router() -> APIRouter:
"endpoint": f"POST /v1/permissions/datasets/{str(principal_id)}",
"dataset_ids": str(dataset_ids),
"principal_id": str(principal_id),
"cognee_version": cognee_version,
},
)
@ -89,6 +91,7 @@ def get_permissions_router() -> APIRouter:
additional_properties={
"endpoint": "POST /v1/permissions/roles",
"role_name": role_name,
"cognee_version": cognee_version,
},
)
@ -133,6 +136,7 @@ def get_permissions_router() -> APIRouter:
"endpoint": f"POST /v1/permissions/users/{str(user_id)}/roles",
"user_id": str(user_id),
"role_id": str(role_id),
"cognee_version": cognee_version,
},
)
@ -175,6 +179,7 @@ def get_permissions_router() -> APIRouter:
"endpoint": f"POST /v1/permissions/users/{str(user_id)}/tenants",
"user_id": str(user_id),
"tenant_id": str(tenant_id),
"cognee_version": cognee_version,
},
)
@ -209,6 +214,7 @@ def get_permissions_router() -> APIRouter:
additional_properties={
"endpoint": "POST /v1/permissions/tenants",
"tenant_name": tenant_name,
"cognee_version": cognee_version,
},
)

View file

@ -13,6 +13,7 @@ from cognee.modules.users.models import User
from cognee.modules.search.operations import get_history
from cognee.modules.users.methods import get_authenticated_user
from cognee.shared.utils import send_telemetry
from cognee import __version__ as cognee_version
# Note: Datasets sent by name will only map to datasets owned by the request sender
@ -61,9 +62,7 @@ def get_search_router() -> APIRouter:
send_telemetry(
"Search API Endpoint Invoked",
user.id,
additional_properties={
"endpoint": "GET /v1/search",
},
additional_properties={"endpoint": "GET /v1/search", "cognee_version": cognee_version},
)
try:
@ -118,6 +117,7 @@ def get_search_router() -> APIRouter:
"top_k": payload.top_k,
"only_context": payload.only_context,
"use_combined_context": payload.use_combined_context,
"cognee_version": cognee_version,
},
)

View file

@ -12,6 +12,7 @@ from cognee.modules.sync.methods import get_running_sync_operations_for_user, ge
from cognee.shared.utils import send_telemetry
from cognee.shared.logging_utils import get_logger
from cognee.api.v1.sync import SyncResponse
from cognee import __version__ as cognee_version
from cognee.context_global_variables import set_database_global_context_variables
logger = get_logger()
@ -99,6 +100,7 @@ def get_sync_router() -> APIRouter:
user.id,
additional_properties={
"endpoint": "POST /v1/sync",
"cognee_version": cognee_version,
"dataset_ids": [str(id) for id in request.dataset_ids]
if request.dataset_ids
else "*",
@ -205,6 +207,7 @@ def get_sync_router() -> APIRouter:
user.id,
additional_properties={
"endpoint": "GET /v1/sync/status",
"cognee_version": cognee_version,
},
)

View file

@ -503,7 +503,7 @@ def start_ui(
if start_mcp:
logger.info("Starting Cognee MCP server with Docker...")
try:
image = "cognee/cognee-mcp:feature-standalone-mcp" # TODO: change to "cognee/cognee-mcp:main" right before merging into main
image = "cognee/cognee-mcp:main"
subprocess.run(["docker", "pull", image], check=True)
import uuid
@ -538,9 +538,7 @@ def start_ui(
env_file = os.path.join(cwd, ".env")
docker_cmd.extend(["--env-file", env_file])
docker_cmd.append(
image
) # TODO: change to "cognee/cognee-mcp:main" right before merging into main
docker_cmd.append(image)
mcp_process = subprocess.Popen(
docker_cmd,

View file

@ -9,6 +9,7 @@ from cognee.shared.logging_utils import get_logger
from cognee.modules.users.models import User
from cognee.modules.users.methods import get_authenticated_user
from cognee.shared.utils import send_telemetry
from cognee import __version__ as cognee_version
from cognee.modules.pipelines.models.PipelineRunInfo import (
PipelineRunErrored,
)
@ -64,6 +65,7 @@ def get_update_router() -> APIRouter:
"dataset_id": str(dataset_id),
"data_id": str(data_id),
"node_set": str(node_set),
"cognee_version": cognee_version,
},
)

View file

@ -8,6 +8,7 @@ from cognee.modules.users.models import User
from cognee.context_global_variables import set_database_global_context_variables
from cognee.shared.utils import send_telemetry
from cognee import __version__ as cognee_version
logger = get_logger()
@ -46,6 +47,7 @@ def get_visualize_router() -> APIRouter:
additional_properties={
"endpoint": "GET /v1/visualize",
"dataset_id": str(dataset_id),
"cognee_version": cognee_version,
},
)

View file

@ -1366,9 +1366,15 @@ class KuzuAdapter(GraphDBInterface):
params[param_name] = values
where_clause = " AND ".join(where_clauses)
nodes_query = (
f"MATCH (n:Node) WHERE {where_clause} RETURN n.id, {{properties: n.properties}}"
)
nodes_query = f"""
MATCH (n:Node)
WHERE {where_clause}
RETURN n.id, {{
name: n.name,
type: n.type,
properties: n.properties
}}
"""
edges_query = f"""
MATCH (n1:Node)-[r:EDGE]->(n2:Node)
WHERE {where_clause.replace("n.", "n1.")} AND {where_clause.replace("n.", "n2.")}
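The practical effect of the rewritten nodes_query, assuming one value per RETURN item: each node row now carries name and type alongside the serialized properties, roughly:

# before: ("node-id", {"properties": "{...}"})
# after:  ("node-id", {"name": "Alice", "type": "Person", "properties": "{...}"})  # illustrative values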

View file

@ -0,0 +1,14 @@
A question was previously answered, but the answer received negative feedback.
Please reconsider and improve the response.
Question: {question}
Context originally used: {context}
Previous answer: {wrong_answer}
Feedback on that answer: {negative_feedback}
Task: Provide a better response. The new answer should be short and direct.
Then explain briefly why this answer is better.
Format your reply as:
Answer: <improved answer>
Explanation: <short explanation>
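This template (like the two prompt files below) is a plain str.format template; this commit renders it via read_query_prompt in _render_reaction_prompt further down. A sketch with illustrative values:

from cognee.infrastructure.llm.prompts.read_query_prompt import read_query_prompt

prompt_template = read_query_prompt("feedback_reaction_prompt.txt")
rendered = prompt_template.format(
    question="Who told Bob to bring the donuts?",
    context="Summary of the original interaction context...",
    wrong_answer="Alice asked Bob to bring donuts.",
    negative_feedback="The answer about Bob and donuts was wrong.",
)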

View file

@ -0,0 +1,13 @@
Write a concise, stand-alone paragraph that explains the correct answer to the question below.
The paragraph should read naturally on its own, providing all necessary context and reasoning
so the answer is clear and well-supported.
Question: {question}
Correct answer: {improved_answer}
Supporting context: {new_context}
Your paragraph should:
- Open with a full sentence that clearly states the correct answer
- Let the rest flow from that first sentence, explaining the answer based on the context
- Use simple, direct language that is easy to follow
- Prefer shorter sentences over long-winded explanations

View file

@ -0,0 +1,5 @@
Question: {question}
Context: {context}
Provide a one-paragraph, human-readable summary of this interaction context,
listing all the relevant facts and information in a simple and direct way.

View file

@ -2,6 +2,7 @@ import inspect
from cognee.shared.logging_utils import get_logger
from cognee.modules.users.models import User
from cognee.shared.utils import send_telemetry
from cognee import __version__ as cognee_version
from ..tasks.task import Task
@ -25,6 +26,7 @@ async def handle_task(
user_id=user.id,
additional_properties={
"task_name": running_task.executable.__name__,
"cognee_version": cognee_version,
},
)
@ -46,6 +48,7 @@ async def handle_task(
user_id=user.id,
additional_properties={
"task_name": running_task.executable.__name__,
"cognee_version": cognee_version,
},
)
except Exception as error:
@ -58,6 +61,7 @@ async def handle_task(
user_id=user.id,
additional_properties={
"task_name": running_task.executable.__name__,
"cognee_version": cognee_version,
},
)
raise error

View file

@ -4,6 +4,7 @@ from cognee.modules.settings import get_current_settings
from cognee.modules.users.models import User
from cognee.shared.logging_utils import get_logger
from cognee.shared.utils import send_telemetry
from cognee import __version__ as cognee_version
from .run_tasks_base import run_tasks_base
from ..tasks.task import Task
@ -26,6 +27,7 @@ async def run_tasks_with_telemetry(
user.id,
additional_properties={
"pipeline_name": str(pipeline_name),
"cognee_version": cognee_version,
}
| config,
)
@ -39,7 +41,9 @@ async def run_tasks_with_telemetry(
user.id,
additional_properties={
"pipeline_name": str(pipeline_name),
},
"cognee_version": cognee_version,
}
| config,
)
except Exception as error:
logger.error(
@ -53,6 +57,7 @@ async def run_tasks_with_telemetry(
user.id,
additional_properties={
"pipeline_name": str(pipeline_name),
"cognee_version": cognee_version,
}
| config,
)

View file

@ -1,10 +1,15 @@
import asyncio
import json
from typing import Optional, List, Type, Any
from pydantic import BaseModel
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
from cognee.shared.logging_utils import get_logger
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text
from cognee.modules.retrieval.utils.completion import (
generate_structured_completion,
summarize_text,
)
from cognee.modules.retrieval.utils.session_cache import (
save_conversation_history,
get_conversation_history,
@ -17,6 +22,20 @@ from cognee.infrastructure.databases.cache.config import CacheConfig
logger = get_logger()
def _as_answer_text(completion: Any) -> str:
"""Convert completion to human-readable text for validation and follow-up prompts."""
if isinstance(completion, str):
return completion
if isinstance(completion, BaseModel):
# Add notice that this is a structured response
json_str = completion.model_dump_json(indent=2)
return f"[Structured Response]\n{json_str}"
try:
return json.dumps(completion, indent=2)
except TypeError:
return str(completion)
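For reference, _as_answer_text normalizes any completion into readable text for the validation and follow-up prompts; illustrative behavior:

# _as_answer_text("plain answer") -> "plain answer"
# _as_answer_text(SomeModel(answer="x")) -> "[Structured Response]\n{ ...indented JSON... }"
# anything else falls back to json.dumps, then str()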
class GraphCompletionCotRetriever(GraphCompletionRetriever):
"""
Handles graph completion by generating responses based on a series of interactions with
@ -25,6 +44,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
questions based on reasoning. The public methods are:
- get_completion
- get_structured_completion
Instance variables include:
- validation_system_prompt_path
@ -61,6 +81,155 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
self.followup_system_prompt_path = followup_system_prompt_path
self.followup_user_prompt_path = followup_user_prompt_path
async def _run_cot_completion(
self,
query: str,
context: Optional[List[Edge]] = None,
conversation_history: str = "",
max_iter: int = 4,
response_model: Type = str,
) -> tuple[Any, str, List[Edge]]:
"""
Run chain-of-thought completion with optional structured output.
Parameters:
-----------
- query: User query
- context: Optional pre-fetched context edges
- conversation_history: Optional conversation history string
- max_iter: Maximum CoT iterations
- response_model: Type for structured output (str for plain text)
Returns:
--------
- completion_result: The generated completion (string or structured model)
- context_text: The resolved context text
- triplets: The list of triplets used
"""
followup_question = ""
triplets = []
completion = ""
for round_idx in range(max_iter + 1):
if round_idx == 0:
if context is None:
triplets = await self.get_context(query)
context_text = await self.resolve_edges_to_text(triplets)
else:
context_text = await self.resolve_edges_to_text(context)
else:
triplets += await self.get_context(followup_question)
context_text = await self.resolve_edges_to_text(list(set(triplets)))
completion = await generate_structured_completion(
query=query,
context=context_text,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
conversation_history=conversation_history if conversation_history else None,
response_model=response_model,
)
logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}")
if round_idx < max_iter:
answer_text = _as_answer_text(completion)
valid_args = {"query": query, "answer": answer_text, "context": context_text}
valid_user_prompt = render_prompt(
filename=self.validation_user_prompt_path, context=valid_args
)
valid_system_prompt = read_query_prompt(
prompt_file_name=self.validation_system_prompt_path
)
reasoning = await LLMGateway.acreate_structured_output(
text_input=valid_user_prompt,
system_prompt=valid_system_prompt,
response_model=str,
)
followup_args = {"query": query, "answer": answer_text, "reasoning": reasoning}
followup_prompt = render_prompt(
filename=self.followup_user_prompt_path, context=followup_args
)
followup_system = read_query_prompt(
prompt_file_name=self.followup_system_prompt_path
)
followup_question = await LLMGateway.acreate_structured_output(
text_input=followup_prompt, system_prompt=followup_system, response_model=str
)
logger.info(
f"Chain-of-thought: round {round_idx} - follow-up question: {followup_question}"
)
return completion, context_text, triplets
async def get_structured_completion(
self,
query: str,
context: Optional[List[Edge]] = None,
session_id: Optional[str] = None,
max_iter: int = 4,
response_model: Type = str,
) -> Any:
"""
Generate structured completion responses based on a user query and contextual information.
This method applies the same chain-of-thought logic as get_completion but returns
structured output using the provided response model.
Parameters:
-----------
- query (str): The user's query to be processed and answered.
- context (Optional[List[Edge]]): Optional context that may assist in answering the query.
If not provided, it will be fetched based on the query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
- max_iter: The maximum number of iterations to refine the answer and generate
follow-up questions. (default 4)
- response_model (Type): The Pydantic model type for structured output. (default str)
Returns:
--------
- Any: The generated structured completion based on the response model.
"""
# Check if session saving is enabled
cache_config = CacheConfig()
user = session_user.get()
user_id = getattr(user, "id", None)
session_save = user_id and cache_config.caching
# Load conversation history if enabled
conversation_history = ""
if session_save:
conversation_history = await get_conversation_history(session_id=session_id)
completion, context_text, triplets = await self._run_cot_completion(
query=query,
context=context,
conversation_history=conversation_history,
max_iter=max_iter,
response_model=response_model,
)
if self.save_interaction and context and triplets and completion:
await self.save_qa(
question=query, answer=str(completion), context=context_text, triplets=triplets
)
# Save to session cache if enabled
if session_save:
context_summary = await summarize_text(context_text)
await save_conversation_history(
query=query,
context_summary=context_summary,
answer=str(completion),
session_id=session_id,
)
return completion
async def get_completion(
self,
query: str,
@ -92,82 +261,12 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
- List[str]: A list containing the generated answer to the user's query.
"""
followup_question = ""
triplets = []
completion = ""
# Retrieve conversation history if session saving is enabled
cache_config = CacheConfig()
user = session_user.get()
user_id = getattr(user, "id", None)
session_save = user_id and cache_config.caching
conversation_history = ""
if session_save:
conversation_history = await get_conversation_history(session_id=session_id)
for round_idx in range(max_iter + 1):
if round_idx == 0:
if context is None:
triplets = await self.get_context(query)
context_text = await self.resolve_edges_to_text(triplets)
else:
context_text = await self.resolve_edges_to_text(context)
else:
triplets += await self.get_context(followup_question)
context_text = await self.resolve_edges_to_text(list(set(triplets)))
completion = await generate_completion(
query=query,
context=context_text,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
conversation_history=conversation_history if session_save else None,
)
logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}")
if round_idx < max_iter:
valid_args = {"query": query, "answer": completion, "context": context_text}
valid_user_prompt = render_prompt(
filename=self.validation_user_prompt_path, context=valid_args
)
valid_system_prompt = read_query_prompt(
prompt_file_name=self.validation_system_prompt_path
)
reasoning = await LLMGateway.acreate_structured_output(
text_input=valid_user_prompt,
system_prompt=valid_system_prompt,
response_model=str,
)
followup_args = {"query": query, "answer": completion, "reasoning": reasoning}
followup_prompt = render_prompt(
filename=self.followup_user_prompt_path, context=followup_args
)
followup_system = read_query_prompt(
prompt_file_name=self.followup_system_prompt_path
)
followup_question = await LLMGateway.acreate_structured_output(
text_input=followup_prompt, system_prompt=followup_system, response_model=str
)
logger.info(
f"Chain-of-thought: round {round_idx} - follow-up question: {followup_question}"
)
if self.save_interaction and context and triplets and completion:
await self.save_qa(
question=query, answer=completion, context=context_text, triplets=triplets
)
# Save to session cache
if session_save:
context_summary = await summarize_text(context_text)
await save_conversation_history(
query=query,
context_summary=context_summary,
answer=completion,
session_id=session_id,
)
completion = await self.get_structured_completion(
query=query,
context=context,
session_id=session_id,
max_iter=max_iter,
response_model=str,
)
return [completion]

View file

@ -1,17 +1,18 @@
from typing import Optional
from typing import Optional, Type, Any
from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt
async def generate_completion(
async def generate_structured_completion(
query: str,
context: str,
user_prompt_path: str,
system_prompt_path: str,
system_prompt: Optional[str] = None,
conversation_history: Optional[str] = None,
) -> str:
"""Generates a completion using LLM with given context and prompts."""
response_model: Type = str,
) -> Any:
"""Generates a structured completion using LLM with given context and prompts."""
args = {"question": query, "context": context}
user_prompt = render_prompt(user_prompt_path, args)
system_prompt = system_prompt if system_prompt else read_query_prompt(system_prompt_path)
@ -23,6 +24,26 @@ async def generate_completion(
return await LLMGateway.acreate_structured_output(
text_input=user_prompt,
system_prompt=system_prompt,
response_model=response_model,
)
async def generate_completion(
query: str,
context: str,
user_prompt_path: str,
system_prompt_path: str,
system_prompt: Optional[str] = None,
conversation_history: Optional[str] = None,
) -> str:
"""Generates a completion using LLM with given context and prompts."""
return await generate_structured_completion(
query=query,
context=context,
user_prompt_path=user_prompt_path,
system_prompt_path=system_prompt_path,
system_prompt=system_prompt,
conversation_history=conversation_history,
response_model=str,
)
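A hedged usage sketch of the new generate_structured_completion helper; the prompt paths are ones this commit uses elsewhere (in generate_improved_answers), and the response model is illustrative:

from pydantic import BaseModel
from cognee.modules.retrieval.utils.completion import generate_structured_completion

class Verdict(BaseModel):
    answer: str
    explanation: str

# inside an async function:
result = await generate_structured_completion(
    query="Who works at Figma?",
    context="Steve Rodger works for Figma.",
    user_prompt_path="graph_context_for_question.txt",
    system_prompt_path="answer_simple_question.txt",
    response_model=Verdict,
)
# result is a Verdict instance; with response_model=str it stays a plain string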

View file

@ -24,7 +24,7 @@ from cognee.modules.data.models import Dataset
from cognee.modules.data.methods.get_authorized_existing_datasets import (
get_authorized_existing_datasets,
)
from cognee import __version__ as cognee_version
from .get_search_type_tools import get_search_type_tools
from .no_access_control_search import no_access_control_search
from ..utils.prepare_search_result import prepare_search_result
@ -64,7 +64,11 @@ async def search(
Searching by dataset is only available in ENABLE_BACKEND_ACCESS_CONTROL mode
"""
query = await log_query(query_text, query_type.value, user.id)
send_telemetry("cognee.search EXECUTION STARTED", user.id)
send_telemetry(
"cognee.search EXECUTION STARTED",
user.id,
additional_properties={"cognee_version": cognee_version},
)
# Use search function filtered by permissions if access control is enabled
if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
@ -101,7 +105,11 @@ async def search(
)
]
send_telemetry("cognee.search EXECUTION COMPLETED", user.id)
send_telemetry(
"cognee.search EXECUTION COMPLETED",
user.id,
additional_properties={"cognee_version": cognee_version},
)
await log_result(
query.id,

View file

@ -8,7 +8,7 @@ import http.server
import socketserver
from threading import Thread
import pathlib
from uuid import uuid4
from uuid import uuid4, uuid5, NAMESPACE_OID
from cognee.base_config import get_base_config
from cognee.infrastructure.databases.graph import get_graph_engine
@ -51,6 +51,26 @@ def get_anonymous_id():
return anonymous_id
def _sanitize_nested_properties(obj, property_names: list[str]):
"""
Recursively replaces any property whose key matches one of `property_names`
(e.g., ['url', 'path']) in a nested dict or list with a uuid5 hash
of its string value. Returns a new sanitized copy.
"""
if isinstance(obj, dict):
new_obj = {}
for k, v in obj.items():
if k in property_names and isinstance(v, str):
new_obj[k] = str(uuid5(NAMESPACE_OID, v))
else:
new_obj[k] = _sanitize_nested_properties(v, property_names)
return new_obj
elif isinstance(obj, list):
return [_sanitize_nested_properties(item, property_names) for item in obj]
else:
return obj
def send_telemetry(event_name: str, user_id, additional_properties: dict = {}):
if os.getenv("TELEMETRY_DISABLED"):
return
@ -58,7 +78,9 @@ def send_telemetry(event_name: str, user_id, additional_properties: dict = {}):
env = os.getenv("ENV")
if env in ["test", "dev"]:
return
additional_properties = _sanitize_nested_properties(
obj=additional_properties, property_names=["url"]
)
current_time = datetime.now(timezone.utc)
payload = {
"anonymous_id": str(get_anonymous_id()),

View file

@ -0,0 +1,13 @@
from .extract_feedback_interactions import extract_feedback_interactions
from .generate_improved_answers import generate_improved_answers
from .create_enrichments import create_enrichments
from .link_enrichments_to_feedback import link_enrichments_to_feedback
from .models import FeedbackEnrichment
__all__ = [
"extract_feedback_interactions",
"generate_improved_answers",
"create_enrichments",
"link_enrichments_to_feedback",
"FeedbackEnrichment",
]

View file

@ -0,0 +1,84 @@
from __future__ import annotations
from typing import List
from uuid import NAMESPACE_OID, uuid5
from cognee.infrastructure.llm import LLMGateway
from cognee.infrastructure.llm.prompts.read_query_prompt import read_query_prompt
from cognee.shared.logging_utils import get_logger
from cognee.modules.engine.models import NodeSet
from .models import FeedbackEnrichment
logger = get_logger("create_enrichments")
def _validate_enrichments(enrichments: List[FeedbackEnrichment]) -> bool:
"""Validate that all enrichments contain required fields for completion."""
return all(
enrichment.question is not None
and enrichment.original_answer is not None
and enrichment.improved_answer is not None
and enrichment.new_context is not None
and enrichment.feedback_id is not None
and enrichment.interaction_id is not None
for enrichment in enrichments
)
async def _generate_enrichment_report(
question: str, improved_answer: str, new_context: str, report_prompt_location: str
) -> str:
"""Generate educational report using feedback report prompt."""
try:
prompt_template = read_query_prompt(report_prompt_location)
rendered_prompt = prompt_template.format(
question=question,
improved_answer=improved_answer,
new_context=new_context,
)
return await LLMGateway.acreate_structured_output(
text_input=rendered_prompt,
system_prompt="You are a helpful assistant that creates educational content.",
response_model=str,
)
except Exception as exc:
logger.warning("Failed to generate enrichment report", error=str(exc), question=question)
return f"Educational content for: {question} - {improved_answer}"
async def create_enrichments(
enrichments: List[FeedbackEnrichment],
report_prompt_location: str = "feedback_report_prompt.txt",
) -> List[FeedbackEnrichment]:
"""Fill text and belongs_to_set fields of existing FeedbackEnrichment DataPoints."""
if not enrichments:
logger.info("No enrichments provided; returning empty list")
return []
if not _validate_enrichments(enrichments):
logger.error("Input validation failed; missing required fields")
return []
logger.info("Completing enrichments", count=len(enrichments))
nodeset = NodeSet(id=uuid5(NAMESPACE_OID, name="FeedbackEnrichment"), name="FeedbackEnrichment")
completed_enrichments: List[FeedbackEnrichment] = []
for enrichment in enrichments:
report_text = await _generate_enrichment_report(
enrichment.question,
enrichment.improved_answer,
enrichment.new_context,
report_prompt_location,
)
enrichment.text = report_text
enrichment.belongs_to_set = [nodeset]
completed_enrichments.append(enrichment)
logger.info("Completed enrichments", successful=len(completed_enrichments))
return completed_enrichments

View file

@ -0,0 +1,230 @@
from __future__ import annotations
from typing import Any, Dict, List, Optional, Tuple
from uuid import UUID, uuid5, NAMESPACE_OID
from cognee.infrastructure.llm import LLMGateway
from cognee.infrastructure.llm.prompts.read_query_prompt import read_query_prompt
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.graph import get_graph_engine
from .models import FeedbackEnrichment
logger = get_logger("extract_feedback_interactions")
def _filter_negative_feedback(feedback_nodes):
"""Filter for negative sentiment feedback using precise sentiment classification."""
return [
(node_id, props)
for node_id, props in feedback_nodes
if (props.get("sentiment", "").casefold() == "negative" or props.get("score", 0) < 0)
]
def _get_normalized_id(node_id, props) -> str:
"""Return Cognee node id preference: props.id → props.node_id → raw node_id."""
return str(props.get("id") or props.get("node_id") or node_id)
async def _fetch_feedback_and_interaction_graph_data() -> Tuple[List, List]:
"""Fetch feedback and interaction nodes with edges from graph engine."""
try:
graph_engine = await get_graph_engine()
attribute_filters = [{"type": ["CogneeUserFeedback", "CogneeUserInteraction"]}]
return await graph_engine.get_filtered_graph_data(attribute_filters)
except Exception as exc: # noqa: BLE001
logger.error("Failed to fetch filtered graph data", error=str(exc))
return [], []
def _separate_feedback_and_interaction_nodes(graph_nodes: List) -> Tuple[List, List]:
"""Split nodes into feedback and interaction groups by type field."""
feedback_nodes = [
(_get_normalized_id(node_id, props), props)
for node_id, props in graph_nodes
if props.get("type") == "CogneeUserFeedback"
]
interaction_nodes = [
(_get_normalized_id(node_id, props), props)
for node_id, props in graph_nodes
if props.get("type") == "CogneeUserInteraction"
]
return feedback_nodes, interaction_nodes
def _match_feedback_nodes_to_interactions_by_edges(
feedback_nodes: List, interaction_nodes: List, graph_edges: List
) -> List[Tuple[Tuple, Tuple]]:
"""Match feedback to interactions using gives_feedback_to edges."""
interaction_by_id = {node_id: (node_id, props) for node_id, props in interaction_nodes}
feedback_by_id = {node_id: (node_id, props) for node_id, props in feedback_nodes}
feedback_edges = [
(source_id, target_id)
for source_id, target_id, rel, _ in graph_edges
if rel == "gives_feedback_to"
]
feedback_interaction_pairs: List[Tuple[Tuple, Tuple]] = []
for source_id, target_id in feedback_edges:
source_id_str, target_id_str = str(source_id), str(target_id)
feedback_node = feedback_by_id.get(source_id_str)
interaction_node = interaction_by_id.get(target_id_str)
if feedback_node and interaction_node:
feedback_interaction_pairs.append((feedback_node, interaction_node))
return feedback_interaction_pairs
def _sort_pairs_by_recency_and_limit(
feedback_interaction_pairs: List[Tuple[Tuple, Tuple]], last_n_limit: Optional[int]
) -> List[Tuple[Tuple, Tuple]]:
"""Sort by interaction created_at desc with updated_at fallback, then limit."""
def _recency_key(pair):
_, (_, interaction_props) = pair
created_at = interaction_props.get("created_at") or ""
updated_at = interaction_props.get("updated_at") or ""
return (created_at, updated_at)
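# Note: created_at/updated_at are compared as strings below; this assumes
# ISO-8601 timestamps, where lexicographic order matches chronological order.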
sorted_pairs = sorted(feedback_interaction_pairs, key=_recency_key, reverse=True)
return sorted_pairs[: last_n_limit or len(sorted_pairs)]
async def _generate_human_readable_context_summary(
question_text: str, raw_context_text: str
) -> str:
"""Generate a concise human-readable summary for given context."""
try:
prompt = read_query_prompt("feedback_user_context_prompt.txt")
rendered = prompt.format(question=question_text, context=raw_context_text)
return await LLMGateway.acreate_structured_output(
text_input=rendered, system_prompt="", response_model=str
)
except Exception as exc: # noqa: BLE001
logger.warning("Failed to summarize context", error=str(exc))
return raw_context_text or ""
def _has_required_feedback_fields(enrichment: FeedbackEnrichment) -> bool:
"""Validate required fields exist in the FeedbackEnrichment DataPoint."""
return (
enrichment.question is not None
and enrichment.original_answer is not None
and enrichment.context is not None
and enrichment.feedback_text is not None
and enrichment.feedback_id is not None
and enrichment.interaction_id is not None
)
async def _build_feedback_interaction_record(
feedback_node_id: str, feedback_props: Dict, interaction_node_id: str, interaction_props: Dict
) -> Optional[FeedbackEnrichment]:
"""Build a single FeedbackEnrichment DataPoint with context summary."""
try:
question_text = interaction_props.get("question")
original_answer_text = interaction_props.get("answer")
raw_context_text = interaction_props.get("context", "")
feedback_text = feedback_props.get("feedback") or feedback_props.get("text") or ""
context_summary_text = await _generate_human_readable_context_summary(
question_text or "", raw_context_text
)
enrichment = FeedbackEnrichment(
id=str(uuid5(NAMESPACE_OID, f"{question_text}_{interaction_node_id}")),
text="",
question=question_text,
original_answer=original_answer_text,
improved_answer="",
feedback_id=UUID(str(feedback_node_id)),
interaction_id=UUID(str(interaction_node_id)),
belongs_to_set=None,
context=context_summary_text,
feedback_text=feedback_text,
new_context="",
explanation="",
)
if _has_required_feedback_fields(enrichment):
return enrichment
else:
logger.warning("Skipping invalid feedback item", interaction=str(interaction_node_id))
return None
except Exception as exc: # noqa: BLE001
logger.error("Failed to process feedback pair", error=str(exc))
return None
async def _build_feedback_interaction_records(
matched_feedback_interaction_pairs: List[Tuple[Tuple, Tuple]],
) -> List[FeedbackEnrichment]:
"""Build all FeedbackEnrichment DataPoints from matched pairs."""
feedback_interaction_records: List[FeedbackEnrichment] = []
for (feedback_node_id, feedback_props), (
interaction_node_id,
interaction_props,
) in matched_feedback_interaction_pairs:
record = await _build_feedback_interaction_record(
feedback_node_id, feedback_props, interaction_node_id, interaction_props
)
if record:
feedback_interaction_records.append(record)
return feedback_interaction_records
async def extract_feedback_interactions(
data: Any, last_n: Optional[int] = None
) -> List[FeedbackEnrichment]:
"""Extract negative feedback-interaction pairs and create FeedbackEnrichment DataPoints."""
if not data or data == [{}]:
logger.info(
"No data passed to the extraction task (extraction task fetches data from graph directly)",
data=data,
)
graph_nodes, graph_edges = await _fetch_feedback_and_interaction_graph_data()
if not graph_nodes:
logger.warning("No graph nodes retrieved from database")
return []
feedback_nodes, interaction_nodes = _separate_feedback_and_interaction_nodes(graph_nodes)
logger.info(
"Retrieved nodes from graph",
total_nodes=len(graph_nodes),
feedback_nodes=len(feedback_nodes),
interaction_nodes=len(interaction_nodes),
)
negative_feedback_nodes = _filter_negative_feedback(feedback_nodes)
logger.info(
"Filtered feedback nodes",
total_feedback=len(feedback_nodes),
negative_feedback=len(negative_feedback_nodes),
)
if not negative_feedback_nodes:
logger.info("No negative feedback found; returning empty list")
return []
matched_feedback_interaction_pairs = _match_feedback_nodes_to_interactions_by_edges(
negative_feedback_nodes, interaction_nodes, graph_edges
)
if not matched_feedback_interaction_pairs:
logger.info("No feedback-to-interaction matches found; returning empty list")
return []
matched_feedback_interaction_pairs = _sort_pairs_by_recency_and_limit(
matched_feedback_interaction_pairs, last_n
)
feedback_interaction_records = await _build_feedback_interaction_records(
matched_feedback_interaction_pairs
)
logger.info("Extracted feedback pairs", count=len(feedback_interaction_records))
return feedback_interaction_records

View file

@ -0,0 +1,130 @@
from __future__ import annotations
from typing import List, Optional
from pydantic import BaseModel
from cognee.infrastructure.llm import LLMGateway
from cognee.infrastructure.llm.prompts.read_query_prompt import read_query_prompt
from cognee.modules.graph.utils import resolve_edges_to_text
from cognee.shared.logging_utils import get_logger
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
from .models import FeedbackEnrichment
class ImprovedAnswerResponse(BaseModel):
"""Response model for improved answer generation containing answer and explanation."""
answer: str
explanation: str
logger = get_logger("generate_improved_answers")
def _validate_input_data(enrichments: List[FeedbackEnrichment]) -> bool:
"""Validate that input contains required fields for all enrichments."""
return all(
enrichment.question is not None
and enrichment.original_answer is not None
and enrichment.context is not None
and enrichment.feedback_text is not None
and enrichment.feedback_id is not None
and enrichment.interaction_id is not None
for enrichment in enrichments
)
def _render_reaction_prompt(
question: str, context: str, wrong_answer: str, negative_feedback: str
) -> str:
"""Render the feedback reaction prompt with provided variables."""
prompt_template = read_query_prompt("feedback_reaction_prompt.txt")
return prompt_template.format(
question=question,
context=context,
wrong_answer=wrong_answer,
negative_feedback=negative_feedback,
)
async def _generate_improved_answer_for_single_interaction(
enrichment: FeedbackEnrichment, retriever, reaction_prompt_location: str
) -> Optional[FeedbackEnrichment]:
"""Generate improved answer for a single enrichment using structured retriever completion."""
try:
query_text = _render_reaction_prompt(
enrichment.question,
enrichment.context,
enrichment.original_answer,
enrichment.feedback_text,
)
retrieved_context = await retriever.get_context(query_text)
completion = await retriever.get_structured_completion(
query=query_text,
context=retrieved_context,
response_model=ImprovedAnswerResponse,
max_iter=4,
)
new_context_text = await retriever.resolve_edges_to_text(retrieved_context)
if completion:
enrichment.improved_answer = completion.answer
enrichment.new_context = new_context_text
enrichment.explanation = completion.explanation
return enrichment
else:
logger.warning(
"Failed to get structured completion from retriever", question=enrichment.question
)
return None
except Exception as exc: # noqa: BLE001
logger.error(
"Failed to generate improved answer",
error=str(exc),
question=enrichment.question,
)
return None
async def generate_improved_answers(
enrichments: List[FeedbackEnrichment],
top_k: int = 20,
reaction_prompt_location: str = "feedback_reaction_prompt.txt",
) -> List[FeedbackEnrichment]:
"""Generate improved answers using CoT retriever and LLM."""
if not enrichments:
logger.info("No enrichments provided; returning empty list")
return []
if not _validate_input_data(enrichments):
logger.error("Input data validation failed; missing required fields")
return []
retriever = GraphCompletionCotRetriever(
top_k=top_k,
save_interaction=False,
user_prompt_path="graph_context_for_question.txt",
system_prompt_path="answer_simple_question.txt",
)
improved_answers: List[FeedbackEnrichment] = []
for enrichment in enrichments:
result = await _generate_improved_answer_for_single_interaction(
enrichment, retriever, reaction_prompt_location
)
if result:
improved_answers.append(result)
else:
logger.warning(
"Failed to generate improved answer",
question=enrichment.question,
interaction_id=enrichment.interaction_id,
)
logger.info("Generated improved answers", count=len(improved_answers))
return improved_answers

View file

@ -0,0 +1,67 @@
from __future__ import annotations
from typing import List, Tuple
from uuid import UUID
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.tasks.storage import index_graph_edges
from cognee.shared.logging_utils import get_logger
from .models import FeedbackEnrichment
logger = get_logger("link_enrichments_to_feedback")
def _create_edge_tuple(
source_id: UUID, target_id: UUID, relationship_name: str
) -> Tuple[UUID, UUID, str, dict]:
"""Create an edge tuple with proper properties structure."""
return (
source_id,
target_id,
relationship_name,
{
"relationship_name": relationship_name,
"source_node_id": source_id,
"target_node_id": target_id,
"ontology_valid": False,
},
)
async def link_enrichments_to_feedback(
enrichments: List[FeedbackEnrichment],
) -> List[FeedbackEnrichment]:
"""Manually create edges from enrichments to original feedback/interaction nodes."""
if not enrichments:
logger.info("No enrichments provided; returning empty list")
return []
relationships = []
for enrichment in enrichments:
enrichment_id = enrichment.id
feedback_id = enrichment.feedback_id
interaction_id = enrichment.interaction_id
if enrichment_id and feedback_id:
enriches_feedback_edge = _create_edge_tuple(
enrichment_id, feedback_id, "enriches_feedback"
)
relationships.append(enriches_feedback_edge)
if enrichment_id and interaction_id:
improves_interaction_edge = _create_edge_tuple(
enrichment_id, interaction_id, "improves_interaction"
)
relationships.append(improves_interaction_edge)
if relationships:
graph_engine = await get_graph_engine()
await graph_engine.add_edges(relationships)
await index_graph_edges(relationships)
logger.info("Linking enrichments to feedback", edge_count=len(relationships))
logger.info("Linked enrichments", enrichment_count=len(enrichments))
return enrichments
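For reference, each edge produced by _create_edge_tuple follows the 4-tuple shape that add_edges consumes here; e.g. for one enrichment:

edge = _create_edge_tuple(enrichment.id, enrichment.feedback_id, "enriches_feedback")
# -> (enrichment.id, enrichment.feedback_id, "enriches_feedback",
#     {"relationship_name": "enriches_feedback",
#      "source_node_id": enrichment.id,
#      "target_node_id": enrichment.feedback_id,
#      "ontology_valid": False})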

View file

@ -0,0 +1,26 @@
from typing import List, Optional, Union
from uuid import UUID
from cognee.infrastructure.engine import DataPoint
from cognee.modules.engine.models import Entity, NodeSet
from cognee.tasks.temporal_graph.models import Event
class FeedbackEnrichment(DataPoint):
"""Minimal DataPoint for feedback enrichment that works with extract_graph_from_data."""
text: str
contains: Optional[List[Union[Entity, Event]]] = None
metadata: dict = {"index_fields": ["text"]}
question: str
original_answer: str
improved_answer: str
feedback_id: UUID
interaction_id: UUID
belongs_to_set: Optional[List[NodeSet]] = None
context: str = ""
feedback_text: str = ""
new_context: str = ""
explanation: str = ""

View file

@ -0,0 +1,174 @@
"""
End-to-end integration test for feedback enrichment feature.
Tests the complete feedback enrichment pipeline:
1. Add data and cognify
2. Run search with save_interaction=True to create CogneeUserInteraction nodes
3. Submit feedback to create CogneeUserFeedback nodes
4. Run memify with feedback enrichment tasks to create FeedbackEnrichment nodes
5. Verify all nodes and edges are properly created and linked in the graph
"""
import os
import pathlib
from collections import Counter
import cognee
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.modules.pipelines.tasks.task import Task
from cognee.modules.search.types import SearchType
from cognee.shared.data_models import KnowledgeGraph
from cognee.shared.logging_utils import get_logger
from cognee.tasks.feedback.create_enrichments import create_enrichments
from cognee.tasks.feedback.extract_feedback_interactions import (
extract_feedback_interactions,
)
from cognee.tasks.feedback.generate_improved_answers import generate_improved_answers
from cognee.tasks.feedback.link_enrichments_to_feedback import (
link_enrichments_to_feedback,
)
from cognee.tasks.graph import extract_graph_from_data
from cognee.tasks.storage import add_data_points
logger = get_logger()
async def main():
data_directory_path = str(
pathlib.Path(
os.path.join(
pathlib.Path(__file__).parent,
".data_storage/test_feedback_enrichment",
)
).resolve()
)
cognee_directory_path = str(
pathlib.Path(
os.path.join(
pathlib.Path(__file__).parent,
".cognee_system/test_feedback_enrichment",
)
).resolve()
)
cognee.config.data_root_directory(data_directory_path)
cognee.config.system_root_directory(cognee_directory_path)
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
dataset_name = "feedback_enrichment_test"
await cognee.add("Cognee turns documents into AI memory.", dataset_name)
await cognee.cognify([dataset_name])
question_text = "Say something."
result = await cognee.search(
query_type=SearchType.GRAPH_COMPLETION,
query_text=question_text,
save_interaction=True,
)
assert len(result) > 0, "Search should return non-empty results"
feedback_text = "This answer was completely useless, my feedback is definitely negative."
await cognee.search(
query_type=SearchType.FEEDBACK,
query_text=feedback_text,
last_k=1,
)
graph_engine = await get_graph_engine()
nodes_before, edges_before = await graph_engine.get_graph_data()
interaction_nodes_before = [
(node_id, props)
for node_id, props in nodes_before
if props.get("type") == "CogneeUserInteraction"
]
feedback_nodes_before = [
(node_id, props)
for node_id, props in nodes_before
if props.get("type") == "CogneeUserFeedback"
]
edge_types_before = Counter(edge[2] for edge in edges_before)
assert len(interaction_nodes_before) >= 1, (
f"Expected at least 1 CogneeUserInteraction node, found {len(interaction_nodes_before)}"
)
assert len(feedback_nodes_before) >= 1, (
f"Expected at least 1 CogneeUserFeedback node, found {len(feedback_nodes_before)}"
)
for node_id, props in feedback_nodes_before:
sentiment = props.get("sentiment", "")
score = props.get("score", 0)
feedback_text = props.get("feedback", "")
logger.info(
"Feedback node created",
feedback=feedback_text,
sentiment=sentiment,
score=score,
)
assert edge_types_before.get("gives_feedback_to", 0) >= 1, (
f"Expected at least 1 'gives_feedback_to' edge, found {edge_types_before.get('gives_feedback_to', 0)}"
)
extraction_tasks = [Task(extract_feedback_interactions, last_n=5)]
enrichment_tasks = [
Task(generate_improved_answers, top_k=20),
Task(create_enrichments),
Task(
extract_graph_from_data,
graph_model=KnowledgeGraph,
task_config={"batch_size": 10},
),
Task(add_data_points, task_config={"batch_size": 10}),
Task(link_enrichments_to_feedback),
]
await cognee.memify(
extraction_tasks=extraction_tasks,
enrichment_tasks=enrichment_tasks,
data=[{}],
dataset="feedback_enrichment_test_memify",
)
nodes_after, edges_after = await graph_engine.get_graph_data()
enrichment_nodes = [
(node_id, props)
for node_id, props in nodes_after
if props.get("type") == "FeedbackEnrichment"
]
assert len(enrichment_nodes) >= 1, (
f"Expected at least 1 FeedbackEnrichment node, found {len(enrichment_nodes)}"
)
for node_id, props in enrichment_nodes:
assert "text" in props, f"FeedbackEnrichment node {node_id} missing 'text' property"
enrichment_node_ids = {node_id for node_id, _ in enrichment_nodes}
edges_with_enrichments = [
edge
for edge in edges_after
if edge[0] in enrichment_node_ids or edge[1] in enrichment_node_ids
]
assert len(edges_with_enrichments) >= 1, (
f"Expected enrichment nodes to have at least 1 edge, found {len(edges_with_enrichments)}"
)
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
logger.info("All feedback enrichment tests passed successfully")
if __name__ == "__main__":
import asyncio
asyncio.run(main())

View file

@ -2,6 +2,7 @@ import os
import pytest
import pathlib
from typing import Optional, Union
from pydantic import BaseModel
import cognee
from cognee.low_level import setup, DataPoint
@ -10,6 +11,11 @@ from cognee.tasks.storage import add_data_points
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
class TestAnswer(BaseModel):
answer: str
explanation: str
class TestGraphCompletionCoTRetriever:
@pytest.mark.asyncio
async def test_graph_completion_cot_context_simple(self):
@ -168,3 +174,48 @@ class TestGraphCompletionCoTRetriever:
assert all(isinstance(item, str) and item.strip() for item in answer), (
"Answer must contain only non-empty strings"
)
@pytest.mark.asyncio
async def test_get_structured_completion(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_get_structured_completion"
)
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_get_structured_completion"
)
cognee.config.data_root_directory(data_directory_path)
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
await setup()
class Company(DataPoint):
name: str
class Person(DataPoint):
name: str
works_for: Company
company1 = Company(name="Figma")
person1 = Person(name="Steve Rodger", works_for=company1)
entities = [company1, person1]
await add_data_points(entities)
retriever = GraphCompletionCotRetriever()
# Test with string response model (default)
string_answer = await retriever.get_structured_completion("Who works at Figma?")
assert isinstance(string_answer, str), f"Expected str, got {type(string_answer).__name__}"
assert string_answer.strip(), "Answer should not be empty"
# Test with structured response model
structured_answer = await retriever.get_structured_completion(
"Who works at Figma?", response_model=TestAnswer
)
assert isinstance(structured_answer, TestAnswer), (
f"Expected TestAnswer, got {type(structured_answer).__name__}"
)
assert structured_answer.answer.strip(), "Answer field should not be empty"
assert structured_answer.explanation.strip(), "Explanation field should not be empty"

View file

@ -0,0 +1,82 @@
import asyncio
import cognee
from cognee.api.v1.search import SearchType
from cognee.modules.pipelines.tasks.task import Task
from cognee.tasks.graph import extract_graph_from_data
from cognee.tasks.storage import add_data_points
from cognee.shared.data_models import KnowledgeGraph
from cognee.tasks.feedback.extract_feedback_interactions import extract_feedback_interactions
from cognee.tasks.feedback.generate_improved_answers import generate_improved_answers
from cognee.tasks.feedback.create_enrichments import create_enrichments
from cognee.tasks.feedback.link_enrichments_to_feedback import link_enrichments_to_feedback
CONVERSATION = [
"Alice: Hey, Bob. Did you talk to Mallory?",
"Bob: Yeah, I just saw her before coming here.",
"Alice: Then she told you to bring my documents, right?",
"Bob: Uh… not exactly. She said you wanted me to bring you donuts. Which sounded kind of odd…",
"Alice: Ugh, shes so annoying. Thanks for the donuts anyway!",
]
async def initialize_conversation_and_graph(conversation):
"""Prune data/system, add conversation, cognify."""
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
await cognee.add(conversation)
await cognee.cognify()
async def run_question_and_submit_feedback(question_text: str) -> bool:
"""Ask question, submit feedback based on correctness, and return correctness flag."""
result = await cognee.search(
query_type=SearchType.GRAPH_COMPLETION,
query_text=question_text,
save_interaction=True,
)
answer_text = str(result).lower()
mentions_mallory = "mallory" in answer_text
feedback_text = (
"Great answers, very helpful!"
if mentions_mallory
else "The answer about Bob and donuts was wrong."
)
await cognee.search(
query_type=SearchType.FEEDBACK,
query_text=feedback_text,
last_k=1,
)
return mentions_mallory
async def run_feedback_enrichment_memify(last_n: int = 5):
"""Execute memify with extraction, answer improvement, enrichment creation, and graph processing tasks."""
# Instantiate tasks with their own kwargs
extraction_tasks = [Task(extract_feedback_interactions, last_n=last_n)]
enrichment_tasks = [
Task(generate_improved_answers, top_k=20),
Task(create_enrichments),
Task(extract_graph_from_data, graph_model=KnowledgeGraph, task_config={"batch_size": 10}),
Task(add_data_points, task_config={"batch_size": 10}),
Task(link_enrichments_to_feedback),
]
await cognee.memify(
extraction_tasks=extraction_tasks,
enrichment_tasks=enrichment_tasks,
data=[{}], # A placeholder to prevent fetching the entire graph
dataset="feedback_enrichment_minimal",
)
async def main():
await initialize_conversation_and_graph(CONVERSATION)
is_correct = await run_question_and_submit_feedback("Who told Bob to bring the donuts?")
if not is_correct:
await run_feedback_enrichment_memify(last_n=5)
if __name__ == "__main__":
asyncio.run(main())