chore: use cot retriever only

This commit is contained in:
lxobr 2025-10-21 01:39:35 +02:00
parent cccf523eea
commit 70c0a98055
5 changed files with 18 additions and 69 deletions

View file

@@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
from typing import Dict, List, Optional from typing import List
from uuid import NAMESPACE_OID, uuid5 from uuid import NAMESPACE_OID, uuid5
from cognee.infrastructure.llm import LLMGateway from cognee.infrastructure.llm import LLMGateway

View file

@@ -1,21 +1,28 @@
from __future__ import annotations from __future__ import annotations
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
from uuid import UUID from uuid import UUID, uuid5, NAMESPACE_OID
from cognee.infrastructure.llm import LLMGateway from cognee.infrastructure.llm import LLMGateway
from cognee.infrastructure.llm.prompts.read_query_prompt import read_query_prompt from cognee.infrastructure.llm.prompts.read_query_prompt import read_query_prompt
from cognee.shared.logging_utils import get_logger from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.databases.graph import get_graph_engine
from uuid import uuid5, NAMESPACE_OID
from .utils import filter_negative_feedback
from .models import FeedbackEnrichment from .models import FeedbackEnrichment
logger = get_logger("extract_feedback_interactions") logger = get_logger("extract_feedback_interactions")
def _filter_negative_feedback(feedback_nodes):
"""Filter for negative sentiment feedback using precise sentiment classification."""
return [
(node_id, props)
for node_id, props in feedback_nodes
if (props.get("sentiment", "").casefold() == "negative" or props.get("score", 0) < 0)
]
def _get_normalized_id(node_id, props) -> str: def _get_normalized_id(node_id, props) -> str:
"""Return Cognee node id preference: props.id → props.node_id → raw node_id.""" """Return Cognee node id preference: props.id → props.node_id → raw node_id."""
return str(props.get("id") or props.get("node_id") or node_id) return str(props.get("id") or props.get("node_id") or node_id)
@@ -179,7 +186,7 @@ async def extract_feedback_interactions(
return [] return []
feedback_nodes, interaction_nodes = _separate_feedback_and_interaction_nodes(graph_nodes) feedback_nodes, interaction_nodes = _separate_feedback_and_interaction_nodes(graph_nodes)
negative_feedback_nodes = filter_negative_feedback(feedback_nodes) negative_feedback_nodes = _filter_negative_feedback(feedback_nodes)
if not negative_feedback_nodes: if not negative_feedback_nodes:
logger.info("No negative feedback found; returning empty list") logger.info("No negative feedback found; returning empty list")
return [] return []

View file

@@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
from typing import Dict, List, Optional, Tuple from typing import List, Optional
from pydantic import BaseModel from pydantic import BaseModel
from cognee.infrastructure.llm import LLMGateway from cognee.infrastructure.llm import LLMGateway
@@ -8,7 +8,7 @@ from cognee.infrastructure.llm.prompts.read_query_prompt import read_query_promp
from cognee.modules.graph.utils import resolve_edges_to_text from cognee.modules.graph.utils import resolve_edges_to_text
from cognee.shared.logging_utils import get_logger from cognee.shared.logging_utils import get_logger
from .utils import create_retriever from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
from .models import FeedbackEnrichment from .models import FeedbackEnrichment
@@ -91,11 +91,10 @@ async def _generate_improved_answer_for_single_interaction(
async def generate_improved_answers( async def generate_improved_answers(
enrichments: List[FeedbackEnrichment], enrichments: List[FeedbackEnrichment],
retriever_name: str = "graph_completion_cot",
top_k: int = 20, top_k: int = 20,
reaction_prompt_location: str = "feedback_reaction_prompt.txt", reaction_prompt_location: str = "feedback_reaction_prompt.txt",
) -> List[FeedbackEnrichment]: ) -> List[FeedbackEnrichment]:
"""Generate improved answers using configurable retriever and LLM.""" """Generate improved answers using CoT retriever and LLM."""
if not enrichments: if not enrichments:
logger.info("No enrichments provided; returning empty list") logger.info("No enrichments provided; returning empty list")
return [] return []
@@ -104,9 +103,9 @@ async def generate_improved_answers(
logger.error("Input data validation failed; missing required fields") logger.error("Input data validation failed; missing required fields")
return [] return []
retriever = create_retriever( retriever = GraphCompletionCotRetriever(
retriever_name=retriever_name,
top_k=top_k, top_k=top_k,
save_interaction=False,
user_prompt_path="graph_context_for_question.txt", user_prompt_path="graph_context_for_question.txt",
system_prompt_path="answer_simple_question.txt", system_prompt_path="answer_simple_question.txt",
) )

View file

@@ -1,57 +0,0 @@
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
GraphCompletionContextExtensionRetriever,
)
from cognee.shared.logging_utils import get_logger
logger = get_logger("feedback_utils")
def create_retriever(
    retriever_name: str = "graph_completion_cot",
    top_k: int = 20,
    user_prompt_path: str = "graph_context_for_question.txt",
    system_prompt_path: str = "answer_simple_question.txt",
):
    """Factory for retriever instances with configurable top_k and prompt paths.

    Args:
        retriever_name: one of "graph_completion", "graph_completion_cot",
            "graph_completion_context_extension"; any other value falls back
            to the CoT retriever with a warning.
        top_k: number of results the retriever fetches.
        user_prompt_path: prompt template for the user message.
        system_prompt_path: prompt template for the system message.

    Returns:
        A retriever instance constructed with save_interaction=False.

    The original repeated the identical constructor call in three branches
    plus the fallback; a dispatch table keeps one construction site.
    """
    retriever_classes = {
        "graph_completion": GraphCompletionRetriever,
        "graph_completion_cot": GraphCompletionCotRetriever,
        "graph_completion_context_extension": GraphCompletionContextExtensionRetriever,
    }
    retriever_cls = retriever_classes.get(retriever_name)
    if retriever_cls is None:
        logger.warning(
            "Unknown retriever, defaulting to graph_completion_cot", retriever=retriever_name
        )
        retriever_cls = GraphCompletionCotRetriever
    return retriever_cls(
        top_k=top_k,
        save_interaction=False,
        user_prompt_path=user_prompt_path,
        system_prompt_path=system_prompt_path,
    )
def filter_negative_feedback(feedback_nodes):
    """Return the (node_id, props) pairs carrying negative feedback.

    A node is negative when its "sentiment" property equals "negative"
    (case-insensitive via casefold) or its "score" property is below zero.

    Args:
        feedback_nodes: iterable of (node_id, props) tuples where props is a dict.

    Returns:
        List of the (node_id, props) tuples that matched.

    Fix: the original raised AttributeError / TypeError when "sentiment" or
    "score" was stored as an explicit None, because dict.get returns None —
    not the supplied default — for a key present with a None value. Coalesce
    None to the neutral default before testing.
    """
    return [
        (node_id, props)
        for node_id, props in feedback_nodes
        if (
            (props.get("sentiment") or "").casefold() == "negative"
            or (props.get("score") or 0) < 0
        )
    ]

View file

@@ -57,7 +57,7 @@ async def run_feedback_enrichment_memify(last_n: int = 5):
# Instantiate tasks with their own kwargs # Instantiate tasks with their own kwargs
extraction_tasks = [Task(extract_feedback_interactions, last_n=last_n)] extraction_tasks = [Task(extract_feedback_interactions, last_n=last_n)]
enrichment_tasks = [ enrichment_tasks = [
Task(generate_improved_answers, retriever_name="graph_completion_cot", top_k=20), Task(generate_improved_answers, top_k=20),
Task(create_enrichments), Task(create_enrichments),
Task(extract_graph_from_data, graph_model=KnowledgeGraph, task_config={"batch_size": 10}), Task(extract_graph_from_data, graph_model=KnowledgeGraph, task_config={"batch_size": 10}),
Task(add_data_points, task_config={"batch_size": 10}), Task(add_data_points, task_config={"batch_size": 10}),