chore: use cot retriever only
This commit is contained in:
parent
cccf523eea
commit
70c0a98055
5 changed files with 18 additions and 69 deletions
|
|
@ -1,6 +1,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
from typing import List
|
||||
from uuid import NAMESPACE_OID, uuid5
|
||||
|
||||
from cognee.infrastructure.llm import LLMGateway
|
||||
|
|
|
|||
|
|
@ -1,21 +1,28 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from uuid import UUID
|
||||
from uuid import UUID, uuid5, NAMESPACE_OID
|
||||
|
||||
from cognee.infrastructure.llm import LLMGateway
|
||||
from cognee.infrastructure.llm.prompts.read_query_prompt import read_query_prompt
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from uuid import uuid5, NAMESPACE_OID
|
||||
|
||||
from .utils import filter_negative_feedback
|
||||
from .models import FeedbackEnrichment
|
||||
|
||||
|
||||
logger = get_logger("extract_feedback_interactions")
|
||||
|
||||
|
||||
def _filter_negative_feedback(feedback_nodes):
|
||||
"""Filter for negative sentiment feedback using precise sentiment classification."""
|
||||
return [
|
||||
(node_id, props)
|
||||
for node_id, props in feedback_nodes
|
||||
if (props.get("sentiment", "").casefold() == "negative" or props.get("score", 0) < 0)
|
||||
]
|
||||
|
||||
|
||||
def _get_normalized_id(node_id, props) -> str:
|
||||
"""Return Cognee node id preference: props.id → props.node_id → raw node_id."""
|
||||
return str(props.get("id") or props.get("node_id") or node_id)
|
||||
|
|
@ -179,7 +186,7 @@ async def extract_feedback_interactions(
|
|||
return []
|
||||
|
||||
feedback_nodes, interaction_nodes = _separate_feedback_and_interaction_nodes(graph_nodes)
|
||||
negative_feedback_nodes = filter_negative_feedback(feedback_nodes)
|
||||
negative_feedback_nodes = _filter_negative_feedback(feedback_nodes)
|
||||
if not negative_feedback_nodes:
|
||||
logger.info("No negative feedback found; returning empty list")
|
||||
return []
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import List, Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
from cognee.infrastructure.llm import LLMGateway
|
||||
|
|
@ -8,7 +8,7 @@ from cognee.infrastructure.llm.prompts.read_query_prompt import read_query_promp
|
|||
from cognee.modules.graph.utils import resolve_edges_to_text
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
from .utils import create_retriever
|
||||
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
|
||||
from .models import FeedbackEnrichment
|
||||
|
||||
|
||||
|
|
@ -91,11 +91,10 @@ async def _generate_improved_answer_for_single_interaction(
|
|||
|
||||
async def generate_improved_answers(
|
||||
enrichments: List[FeedbackEnrichment],
|
||||
retriever_name: str = "graph_completion_cot",
|
||||
top_k: int = 20,
|
||||
reaction_prompt_location: str = "feedback_reaction_prompt.txt",
|
||||
) -> List[FeedbackEnrichment]:
|
||||
"""Generate improved answers using configurable retriever and LLM."""
|
||||
"""Generate improved answers using CoT retriever and LLM."""
|
||||
if not enrichments:
|
||||
logger.info("No enrichments provided; returning empty list")
|
||||
return []
|
||||
|
|
@ -104,9 +103,9 @@ async def generate_improved_answers(
|
|||
logger.error("Input data validation failed; missing required fields")
|
||||
return []
|
||||
|
||||
retriever = create_retriever(
|
||||
retriever_name=retriever_name,
|
||||
retriever = GraphCompletionCotRetriever(
|
||||
top_k=top_k,
|
||||
save_interaction=False,
|
||||
user_prompt_path="graph_context_for_question.txt",
|
||||
system_prompt_path="answer_simple_question.txt",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,57 +0,0 @@
|
|||
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
||||
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
|
||||
from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
|
||||
GraphCompletionContextExtensionRetriever,
|
||||
)
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
|
||||
logger = get_logger("feedback_utils")
|
||||
|
||||
|
||||
def create_retriever(
|
||||
retriever_name: str = "graph_completion_cot",
|
||||
top_k: int = 20,
|
||||
user_prompt_path: str = "graph_context_for_question.txt",
|
||||
system_prompt_path: str = "answer_simple_question.txt",
|
||||
):
|
||||
"""Factory for retriever instances with configurable top_k and prompt paths."""
|
||||
if retriever_name == "graph_completion":
|
||||
return GraphCompletionRetriever(
|
||||
top_k=top_k,
|
||||
save_interaction=False,
|
||||
user_prompt_path=user_prompt_path,
|
||||
system_prompt_path=system_prompt_path,
|
||||
)
|
||||
if retriever_name == "graph_completion_cot":
|
||||
return GraphCompletionCotRetriever(
|
||||
top_k=top_k,
|
||||
save_interaction=False,
|
||||
user_prompt_path=user_prompt_path,
|
||||
system_prompt_path=system_prompt_path,
|
||||
)
|
||||
if retriever_name == "graph_completion_context_extension":
|
||||
return GraphCompletionContextExtensionRetriever(
|
||||
top_k=top_k,
|
||||
save_interaction=False,
|
||||
user_prompt_path=user_prompt_path,
|
||||
system_prompt_path=system_prompt_path,
|
||||
)
|
||||
logger.warning(
|
||||
"Unknown retriever, defaulting to graph_completion_cot", retriever=retriever_name
|
||||
)
|
||||
return GraphCompletionCotRetriever(
|
||||
top_k=top_k,
|
||||
save_interaction=False,
|
||||
user_prompt_path=user_prompt_path,
|
||||
system_prompt_path=system_prompt_path,
|
||||
)
|
||||
|
||||
|
||||
def filter_negative_feedback(feedback_nodes):
|
||||
"""Filter for negative sentiment feedback using precise sentiment classification."""
|
||||
return [
|
||||
(node_id, props)
|
||||
for node_id, props in feedback_nodes
|
||||
if (props.get("sentiment", "").casefold() == "negative" or props.get("score", 0) < 0)
|
||||
]
|
||||
|
|
@ -57,7 +57,7 @@ async def run_feedback_enrichment_memify(last_n: int = 5):
|
|||
# Instantiate tasks with their own kwargs
|
||||
extraction_tasks = [Task(extract_feedback_interactions, last_n=last_n)]
|
||||
enrichment_tasks = [
|
||||
Task(generate_improved_answers, retriever_name="graph_completion_cot", top_k=20),
|
||||
Task(generate_improved_answers, top_k=20),
|
||||
Task(create_enrichments),
|
||||
Task(extract_graph_from_data, graph_model=KnowledgeGraph, task_config={"batch_size": 10}),
|
||||
Task(add_data_points, task_config={"batch_size": 10}),
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue