chore: use cot retriever only
This commit is contained in:
parent
cccf523eea
commit
70c0a98055
5 changed files with 18 additions and 69 deletions
|
|
@ -1,6 +1,6 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Dict, List, Optional
|
from typing import List
|
||||||
from uuid import NAMESPACE_OID, uuid5
|
from uuid import NAMESPACE_OID, uuid5
|
||||||
|
|
||||||
from cognee.infrastructure.llm import LLMGateway
|
from cognee.infrastructure.llm import LLMGateway
|
||||||
|
|
|
||||||
|
|
@ -1,21 +1,28 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
from uuid import UUID
|
from uuid import UUID, uuid5, NAMESPACE_OID
|
||||||
|
|
||||||
from cognee.infrastructure.llm import LLMGateway
|
from cognee.infrastructure.llm import LLMGateway
|
||||||
from cognee.infrastructure.llm.prompts.read_query_prompt import read_query_prompt
|
from cognee.infrastructure.llm.prompts.read_query_prompt import read_query_prompt
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
from uuid import uuid5, NAMESPACE_OID
|
|
||||||
|
|
||||||
from .utils import filter_negative_feedback
|
|
||||||
from .models import FeedbackEnrichment
|
from .models import FeedbackEnrichment
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger("extract_feedback_interactions")
|
logger = get_logger("extract_feedback_interactions")
|
||||||
|
|
||||||
|
|
||||||
|
def _filter_negative_feedback(feedback_nodes):
|
||||||
|
"""Filter for negative sentiment feedback using precise sentiment classification."""
|
||||||
|
return [
|
||||||
|
(node_id, props)
|
||||||
|
for node_id, props in feedback_nodes
|
||||||
|
if (props.get("sentiment", "").casefold() == "negative" or props.get("score", 0) < 0)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def _get_normalized_id(node_id, props) -> str:
|
def _get_normalized_id(node_id, props) -> str:
|
||||||
"""Return Cognee node id preference: props.id → props.node_id → raw node_id."""
|
"""Return Cognee node id preference: props.id → props.node_id → raw node_id."""
|
||||||
return str(props.get("id") or props.get("node_id") or node_id)
|
return str(props.get("id") or props.get("node_id") or node_id)
|
||||||
|
|
@ -179,7 +186,7 @@ async def extract_feedback_interactions(
|
||||||
return []
|
return []
|
||||||
|
|
||||||
feedback_nodes, interaction_nodes = _separate_feedback_and_interaction_nodes(graph_nodes)
|
feedback_nodes, interaction_nodes = _separate_feedback_and_interaction_nodes(graph_nodes)
|
||||||
negative_feedback_nodes = filter_negative_feedback(feedback_nodes)
|
negative_feedback_nodes = _filter_negative_feedback(feedback_nodes)
|
||||||
if not negative_feedback_nodes:
|
if not negative_feedback_nodes:
|
||||||
logger.info("No negative feedback found; returning empty list")
|
logger.info("No negative feedback found; returning empty list")
|
||||||
return []
|
return []
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import List, Optional
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from cognee.infrastructure.llm import LLMGateway
|
from cognee.infrastructure.llm import LLMGateway
|
||||||
|
|
@ -8,7 +8,7 @@ from cognee.infrastructure.llm.prompts.read_query_prompt import read_query_promp
|
||||||
from cognee.modules.graph.utils import resolve_edges_to_text
|
from cognee.modules.graph.utils import resolve_edges_to_text
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
|
||||||
from .utils import create_retriever
|
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
|
||||||
from .models import FeedbackEnrichment
|
from .models import FeedbackEnrichment
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -91,11 +91,10 @@ async def _generate_improved_answer_for_single_interaction(
|
||||||
|
|
||||||
async def generate_improved_answers(
|
async def generate_improved_answers(
|
||||||
enrichments: List[FeedbackEnrichment],
|
enrichments: List[FeedbackEnrichment],
|
||||||
retriever_name: str = "graph_completion_cot",
|
|
||||||
top_k: int = 20,
|
top_k: int = 20,
|
||||||
reaction_prompt_location: str = "feedback_reaction_prompt.txt",
|
reaction_prompt_location: str = "feedback_reaction_prompt.txt",
|
||||||
) -> List[FeedbackEnrichment]:
|
) -> List[FeedbackEnrichment]:
|
||||||
"""Generate improved answers using configurable retriever and LLM."""
|
"""Generate improved answers using CoT retriever and LLM."""
|
||||||
if not enrichments:
|
if not enrichments:
|
||||||
logger.info("No enrichments provided; returning empty list")
|
logger.info("No enrichments provided; returning empty list")
|
||||||
return []
|
return []
|
||||||
|
|
@ -104,9 +103,9 @@ async def generate_improved_answers(
|
||||||
logger.error("Input data validation failed; missing required fields")
|
logger.error("Input data validation failed; missing required fields")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
retriever = create_retriever(
|
retriever = GraphCompletionCotRetriever(
|
||||||
retriever_name=retriever_name,
|
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
|
save_interaction=False,
|
||||||
user_prompt_path="graph_context_for_question.txt",
|
user_prompt_path="graph_context_for_question.txt",
|
||||||
system_prompt_path="answer_simple_question.txt",
|
system_prompt_path="answer_simple_question.txt",
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,57 +0,0 @@
|
||||||
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
|
||||||
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
|
|
||||||
from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
|
|
||||||
GraphCompletionContextExtensionRetriever,
|
|
||||||
)
|
|
||||||
from cognee.shared.logging_utils import get_logger
|
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger("feedback_utils")
|
|
||||||
|
|
||||||
|
|
||||||
def create_retriever(
|
|
||||||
retriever_name: str = "graph_completion_cot",
|
|
||||||
top_k: int = 20,
|
|
||||||
user_prompt_path: str = "graph_context_for_question.txt",
|
|
||||||
system_prompt_path: str = "answer_simple_question.txt",
|
|
||||||
):
|
|
||||||
"""Factory for retriever instances with configurable top_k and prompt paths."""
|
|
||||||
if retriever_name == "graph_completion":
|
|
||||||
return GraphCompletionRetriever(
|
|
||||||
top_k=top_k,
|
|
||||||
save_interaction=False,
|
|
||||||
user_prompt_path=user_prompt_path,
|
|
||||||
system_prompt_path=system_prompt_path,
|
|
||||||
)
|
|
||||||
if retriever_name == "graph_completion_cot":
|
|
||||||
return GraphCompletionCotRetriever(
|
|
||||||
top_k=top_k,
|
|
||||||
save_interaction=False,
|
|
||||||
user_prompt_path=user_prompt_path,
|
|
||||||
system_prompt_path=system_prompt_path,
|
|
||||||
)
|
|
||||||
if retriever_name == "graph_completion_context_extension":
|
|
||||||
return GraphCompletionContextExtensionRetriever(
|
|
||||||
top_k=top_k,
|
|
||||||
save_interaction=False,
|
|
||||||
user_prompt_path=user_prompt_path,
|
|
||||||
system_prompt_path=system_prompt_path,
|
|
||||||
)
|
|
||||||
logger.warning(
|
|
||||||
"Unknown retriever, defaulting to graph_completion_cot", retriever=retriever_name
|
|
||||||
)
|
|
||||||
return GraphCompletionCotRetriever(
|
|
||||||
top_k=top_k,
|
|
||||||
save_interaction=False,
|
|
||||||
user_prompt_path=user_prompt_path,
|
|
||||||
system_prompt_path=system_prompt_path,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def filter_negative_feedback(feedback_nodes):
|
|
||||||
"""Filter for negative sentiment feedback using precise sentiment classification."""
|
|
||||||
return [
|
|
||||||
(node_id, props)
|
|
||||||
for node_id, props in feedback_nodes
|
|
||||||
if (props.get("sentiment", "").casefold() == "negative" or props.get("score", 0) < 0)
|
|
||||||
]
|
|
||||||
|
|
@ -57,7 +57,7 @@ async def run_feedback_enrichment_memify(last_n: int = 5):
|
||||||
# Instantiate tasks with their own kwargs
|
# Instantiate tasks with their own kwargs
|
||||||
extraction_tasks = [Task(extract_feedback_interactions, last_n=last_n)]
|
extraction_tasks = [Task(extract_feedback_interactions, last_n=last_n)]
|
||||||
enrichment_tasks = [
|
enrichment_tasks = [
|
||||||
Task(generate_improved_answers, retriever_name="graph_completion_cot", top_k=20),
|
Task(generate_improved_answers, top_k=20),
|
||||||
Task(create_enrichments),
|
Task(create_enrichments),
|
||||||
Task(extract_graph_from_data, graph_model=KnowledgeGraph, task_config={"batch_size": 10}),
|
Task(extract_graph_from_data, graph_model=KnowledgeGraph, task_config={"batch_size": 10}),
|
||||||
Task(add_data_points, task_config={"batch_size": 10}),
|
Task(add_data_points, task_config={"batch_size": 10}),
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue