Compare commits
2 commits
main
...
cot_poc_sa
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6825ea2789 | ||
|
|
7630341c01 |
3 changed files with 91 additions and 6 deletions
|
|
@ -63,8 +63,7 @@ class OpenAIAdapter(LLMInterface):
|
|||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""Use the given format to
|
||||
extract information from the following input: {text_input}. """,
|
||||
"content": text_input,
|
||||
},
|
||||
{
|
||||
"role": "system",
|
||||
|
|
@ -91,8 +90,7 @@ class OpenAIAdapter(LLMInterface):
|
|||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""Use the given format to
|
||||
extract information from the following input: {text_input}. """,
|
||||
"content": text_input,
|
||||
},
|
||||
{
|
||||
"role": "system",
|
||||
|
|
|
|||
|
|
@ -1 +1,10 @@
|
|||
Answer the question using the provided context. Be as brief as possible.
|
||||
Answer the question using the provided context. Be as brief as possible.
|
||||
Please make sure that your answer matches the type of answer the question asks for.
|
||||
|
||||
Rules:
|
||||
• Minimize words – Use as few as possible
|
||||
• Yes/no questions → Answer with only "yes" or "no"
|
||||
• What/Who/Where questions → Answer with a single word or phrase (no full sentences).
|
||||
• When questions → Give only the relevant date, time, or period.
|
||||
• How/Why questions → Answer with the shortest possible phrase.
|
||||
• No punctuation, just the answer, preferably in dry, concise lowercase
|
||||
|
|
|
|||
|
|
@ -1,13 +1,28 @@
|
|||
from typing import Any, Optional
|
||||
from typing import Any, Optional, List
|
||||
from collections import Counter
|
||||
import string
|
||||
import logging
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
||||
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
|
||||
from cognee.modules.retrieval.base_retriever import BaseRetriever
|
||||
from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
|
||||
from cognee.modules.retrieval.utils.completion import generate_completion
|
||||
from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS
|
||||
from cognee.modules.search.models.Query import Query
|
||||
|
||||
t_logger = logging.getLogger(__name__)
|
||||
t_logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
class AnswerValidation(BaseModel):
    """Structured verdict on an answer: a validity flag plus free-text reasoning."""

    # Required flag; the description string is part of the schema shown to the LLM.
    is_valid: bool = Field(
        default=...,
        description="Indicates whether the answer to the question is fully supported by the context",
    )
    # Optional explanation; defaults to an empty string when omitted.
    reasoning: str = Field(
        default="",
        description="Detailed reasoning of what is missing from the context",
    )
|
||||
|
||||
|
||||
class GraphCompletionRetriever(BaseRetriever):
|
||||
|
|
@ -84,6 +99,7 @@ class GraphCompletionRetriever(BaseRetriever):
|
|||
|
||||
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
||||
"""Generates a completion using graph connections context."""
|
||||
"""
|
||||
if context is None:
|
||||
context = await self.get_context(query)
|
||||
|
||||
|
|
@ -93,8 +109,70 @@ class GraphCompletionRetriever(BaseRetriever):
|
|||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
)
|
||||
|
||||
"""
|
||||
|
||||
completion = await self.get_chain_of_thought(query=query)
|
||||
return [completion]
|
||||
|
||||
async def get_chain_of_thought(self, query: str, max_iter: int = 4):
    """Answer ``query`` by iteratively widening the retrieved graph context.

    Every round retrieves triplets, resolves them to text and generates an
    answer from that context. Between rounds the LLM is asked to (1) explain
    what the context is missing or why the answer falls short and (2) turn
    that critique into exactly one follow-up question, which drives the next
    retrieval. Triplets from all rounds are accumulated.

    Args:
        query: The user's original question.
        max_iter: Number of follow-up rounds to run after the initial round.
            Negative values are treated as zero so an answer is always
            produced.

    Returns:
        The answer generated in the final round.
    """
    # NOTE(review): the critique is requested as free text (response_model=str),
    # so the loop cannot stop early on a "valid" verdict; it always runs all
    # rounds. Confirm whether a structured verdict / early exit is intended.
    llm_client = get_llm_client()
    last_round = max(max_iter, 0)  # guarantees at least one round
    followup_question = ""
    triplets = []

    # The prompts below do not change between rounds, so build them once.
    valid_system_prompt = "You are a helpful agent who are allowed to use only the provided question answer and context. I want to you find reasoning what is missing from the context or why the answer is not answering the question or not correct strictly based on the context"

    followup_system = (
        "You are a helpful assistant whose job is to ask exactly one clarifying follow-up question,"
        " to collect the missing piece of information needed to fully answer the user’s original query."
        " Respond with the question only (no extra text, no punctuation beyond what’s needed)."
    )

    for round_idx in range(last_round + 1):
        if round_idx == 0:
            triplets = await self.get_triplets(query)
        else:
            new_triplets = await self.get_triplets(followup_question)
            # Order-preserving de-duplication: keeps the context text stable
            # between runs (set() ordering is arbitrary) and stops the
            # accumulated list from growing with repeats.
            triplets = list(dict.fromkeys([*triplets, *new_triplets]))
        context = await self.resolve_edges_to_text(triplets)
        t_logger.info("Round %s - context: %s", round_idx, context)

        # Generate answer
        answer = await generate_completion(
            query=query,
            context=context,
            user_prompt_path=self.user_prompt_path,
            system_prompt_path=self.system_prompt_path,
        )
        t_logger.info("Round %s - answer: %s", round_idx, answer)

        if round_idx == last_round:
            # No further retrieval will happen, so the critique and follow-up
            # calls would be discarded; skip the two extra LLM round-trips.
            break

        # Ask the LLM what the context lacks / why the answer falls short
        valid_user_prompt = (
            f"""\n\n--Question--\n{query}\n\n--ANSWER--\n{answer}\n\n--CONTEXT--\n{context}\n"""
        )
        reasoning = await llm_client.acreate_structured_output(
            text_input=valid_user_prompt,
            system_prompt=valid_system_prompt,
            response_model=str,
        )

        # Ask follow-up question to fill gaps
        followup_prompt = (
            "Based on the following, ask exactly one question that would directly resolve the gap identified in the validation reasoning and allow a valid answer."
            # Leading space: adjacent literals are concatenated verbatim.
            " Think in a way that with the followup question you are exploring a knowledge graph which contains entities, entity types and document chunks\n\n"
            f"Query: {query}\n"
            f"Answer: {answer}\n"
            f"Reasoning:\n{reasoning}\n"
        )
        followup_question = await llm_client.acreate_structured_output(
            text_input=followup_prompt, system_prompt=followup_system, response_model=str
        )
        t_logger.info("Round %s - follow-up question: %s", round_idx, followup_question)

    # Answer from the last executed round
    return answer
|
||||
|
||||
def _top_n_words(self, text, stop_words=None, top_n=3, separator=", "):
|
||||
"""Concatenates the top N frequent words in text."""
|
||||
if stop_words is None:
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue