Compare commits

...
Sign in to create a new pull request.

2 commits

Author SHA1 Message Date
hajdul88
6825ea2789 adds context log 2025-05-07 15:26:29 +02:00
hajdul88
7630341c01 poc save DONT MERGE 2025-05-07 14:54:38 +02:00
3 changed files with 91 additions and 6 deletions

View file

@ -63,8 +63,7 @@ class OpenAIAdapter(LLMInterface):
messages=[
{
"role": "user",
"content": f"""Use the given format to
extract information from the following input: {text_input}. """,
"content": text_input,
},
{
"role": "system",
@ -91,8 +90,7 @@ class OpenAIAdapter(LLMInterface):
messages=[
{
"role": "user",
"content": f"""Use the given format to
extract information from the following input: {text_input}. """,
"content": text_input,
},
{
"role": "system",

View file

@ -1 +1,10 @@
Answer the question using the provided context. Be as brief as possible.
Please make sure that you answer the question with a correct type.
Rules:
• Minimize words — use as few as possible.
• Yes/no questions → Answer with only "yes" or "no"
• What/Who/Where questions → Answer with a single word or phrase (no full sentences).
• When questions → Give only the relevant date, time, or period.
• How/Why questions → Answer with the shortest possible phrase.
• No punctuation, just the answer, preferably in dry, concise lowercase

View file

@ -1,13 +1,28 @@
from typing import Any, Optional
from typing import Any, Optional, List
from collections import Counter
import string
import logging
from pydantic import BaseModel, Field
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
from cognee.modules.retrieval.utils.completion import generate_completion
from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS
from cognee.modules.search.models.Query import Query
t_logger = logging.getLogger(__name__)
t_logger.setLevel(logging.INFO)
class AnswerValidation(BaseModel):
    """Structured LLM verdict on whether an answer is grounded in its context."""

    # True only when the answer is fully backed by the retrieved context.
    is_valid: bool = Field(
        default=...,
        description="Indicates whether the answer to the question is fully supported by the context",
    )
    # Free-text explanation of what the context is missing.
    reasoning: str = Field(
        default="", description="Detailed reasoning of what is missing from the context"
    )
class GraphCompletionRetriever(BaseRetriever):
@ -84,6 +99,7 @@ class GraphCompletionRetriever(BaseRetriever):
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
"""Generates a completion using graph connections context."""
"""
if context is None:
context = await self.get_context(query)
@ -93,8 +109,70 @@ class GraphCompletionRetriever(BaseRetriever):
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
)
"""
completion = await self.get_chain_of_thought(query=query)
return [completion]
async def get_chain_of_thought(self, query, max_iter=4):
    """Answer `query` by iteratively expanding knowledge-graph context.

    Round 0 retrieves triplets for the query itself; each later round
    retrieves triplets for the previous round's follow-up question and
    merges them in. After generating an answer, an LLM critiques it
    against the context and a single follow-up question is derived to
    fetch whatever information is still missing.

    Parameters:
        query: The user's question.
        max_iter: Number of refinement rounds after the initial one.

    Returns:
        The answer produced in the final round.
    """
    llm_client = get_llm_client()
    followup_question = ""
    triplets = []
    answer = None
    for round_idx in range(max_iter + 1):
        if round_idx == 0:
            triplets = await self.get_triplets(query)
        else:
            triplets += await self.get_triplets(followup_question)
            # Deduplicate while preserving retrieval order; list(set(...))
            # would reorder the context text nondeterministically.
            triplets = list(dict.fromkeys(triplets))
        context = await self.resolve_edges_to_text(triplets)
        # Lazy %-style args so formatting is skipped when INFO is disabled.
        t_logger.info("Round %s - context: %s", round_idx, context)

        # Generate an answer from the current context.
        answer = await generate_completion(
            query=query,
            context=context,
            user_prompt_path=self.user_prompt_path,
            system_prompt_path=self.system_prompt_path,
        )
        t_logger.info("Round %s - answer: %s", round_idx, answer)

        # The last round's answer is returned as-is: skip the validation and
        # follow-up calls whose output would otherwise be discarded.
        if round_idx == max_iter:
            break

        # Ask an LLM to critique the answer strictly against the context.
        valid_user_prompt = (
            f"""\n\n--Question--\n{query}\n\n--ANSWER--\n{answer}\n\n--CONTEXT--\n{context}\n"""
        )
        valid_system_prompt = (
            "You are a helpful agent who is allowed to use only the provided question, "
            "answer and context. Explain what is missing from the context, or why the "
            "answer does not answer the question or is not correct, strictly based on "
            "the context."
        )
        reasoning = await llm_client.acreate_structured_output(
            text_input=valid_user_prompt,
            system_prompt=valid_system_prompt,
            response_model=str,
        )

        # Derive exactly one follow-up question to fill the identified gap.
        followup_system = (
            "You are a helpful assistant whose job is to ask exactly one clarifying follow-up question,"
            " to collect the missing piece of information needed to fully answer the user's original query."
            " Respond with the question only (no extra text, no punctuation beyond what's needed)."
        )
        followup_prompt = (
            "Based on the following, ask exactly one question that would directly resolve "
            "the gap identified in the validation reasoning and allow a valid answer. "
            "Think in a way that with the follow-up question you are exploring a knowledge "
            "graph which contains entities, entity types and document chunks\n\n"
            f"Query: {query}\n"
            f"Answer: {answer}\n"
            f"Reasoning:\n{reasoning}\n"
        )
        followup_question = await llm_client.acreate_structured_output(
            text_input=followup_prompt, system_prompt=followup_system, response_model=str
        )
        t_logger.info("Round %s - follow-up question: %s", round_idx, followup_question)

    return answer
def _top_n_words(self, text, stop_words=None, top_n=3, separator=", "):
"""Concatenates the top N frequent words in text."""
if stop_words is None: