feat: Save search flag progress

This commit is contained in:
Igor Ilic 2025-08-28 10:52:08 +02:00
parent 66673af56d
commit ac87e62adb
6 changed files with 72 additions and 14 deletions

View file

@ -65,7 +65,14 @@ class CompletionRetriever(BaseRetriever):
logger.error("DocumentChunk_text collection not found") logger.error("DocumentChunk_text collection not found")
raise NoDataError("No data found in the system, please add data first.") from error raise NoDataError("No data found in the system, please add data first.") from error
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: async def get_completion(
self,
query: str,
context: Optional[Any] = None,
user_prompt: str = None,
system_prompt: str = None,
only_context: bool = False,
) -> Any:
""" """
Generates an LLM completion using the context. Generates an LLM completion using the context.
@ -88,6 +95,12 @@ class CompletionRetriever(BaseRetriever):
context = await self.get_context(query) context = await self.get_context(query)
completion = await generate_completion( completion = await generate_completion(
query, context, self.user_prompt_path, self.system_prompt_path query=query,
context=context,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
user_prompt=user_prompt,
system_prompt=system_prompt,
only_context=only_context,
) )
return [completion] return [completion]

View file

@ -41,7 +41,13 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
) )
async def get_completion( async def get_completion(
self, query: str, context: Optional[Any] = None, context_extension_rounds=4 self,
query: str,
context: Optional[Any] = None,
user_prompt: str = None,
system_prompt: str = None,
only_context: bool = False,
context_extension_rounds=4,
) -> List[str]: ) -> List[str]:
""" """
Extends the context for a given query by retrieving related triplets and generating new Extends the context for a given query by retrieving related triplets and generating new
@ -86,6 +92,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
context=context, context=context,
user_prompt_path=self.user_prompt_path, user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path, system_prompt_path=self.system_prompt_path,
user_prompt=user_prompt,
system_prompt=system_prompt,
) )
triplets += await self.get_triplets(completion) triplets += await self.get_triplets(completion)
@ -112,6 +120,9 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
context=context, context=context,
user_prompt_path=self.user_prompt_path, user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path, system_prompt_path=self.system_prompt_path,
user_prompt=user_prompt,
system_prompt=system_prompt,
only_context=only_context,
) )
if self.save_interaction and context and triplets and completion: if self.save_interaction and context and triplets and completion:

View file

@ -51,7 +51,13 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
self.followup_user_prompt_path = followup_user_prompt_path self.followup_user_prompt_path = followup_user_prompt_path
async def get_completion( async def get_completion(
self, query: str, context: Optional[Any] = None, max_iter=4 self,
query: str,
context: Optional[Any] = None,
user_prompt: str = None,
system_prompt: str = None,
only_context: bool = False,
max_iter=4,
) -> List[str]: ) -> List[str]:
""" """
Generate completion responses based on a user query and contextual information. Generate completion responses based on a user query and contextual information.
@ -92,6 +98,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
context=context, context=context,
user_prompt_path=self.user_prompt_path, user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path, system_prompt_path=self.system_prompt_path,
user_prompt=user_prompt,
system_prompt=system_prompt,
) )
logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}") logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}")
if round_idx < max_iter: if round_idx < max_iter:
@ -128,4 +136,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
question=query, answer=completion, context=context, triplets=triplets question=query, answer=completion, context=context, triplets=triplets
) )
return [completion] if only_context:
return [context]
else:
return [completion]

View file

@ -151,7 +151,14 @@ class GraphCompletionRetriever(BaseRetriever):
return context, triplets return context, triplets
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: async def get_completion(
self,
query: str,
context: Optional[Any] = None,
user_prompt: str = None,
system_prompt: str = None,
only_context: bool = False,
) -> Any:
""" """
Generates a completion using graph connections context based on a query. Generates a completion using graph connections context based on a query.
@ -177,6 +184,9 @@ class GraphCompletionRetriever(BaseRetriever):
context=context, context=context,
user_prompt_path=self.user_prompt_path, user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path, system_prompt_path=self.system_prompt_path,
user_prompt=user_prompt,
system_prompt=system_prompt,
only_context=only_context,
) )
if self.save_interaction and context and triplets and completion: if self.save_interaction and context and triplets and completion:

View file

@ -6,18 +6,26 @@ async def generate_completion(
context: str, context: str,
user_prompt_path: str, user_prompt_path: str,
system_prompt_path: str, system_prompt_path: str,
user_prompt: str = None,
system_prompt: str = None,
only_context: bool = False,
) -> str: ) -> str:
"""Generates a completion using LLM with given context and prompts.""" """Generates a completion using LLM with given context and prompts."""
args = {"question": query, "context": context} args = {"question": query, "context": context}
user_prompt = LLMGateway.render_prompt(user_prompt_path, args) user_prompt = LLMGateway.render_prompt(user_prompt if user_prompt else user_prompt_path, args)
system_prompt = LLMGateway.read_query_prompt(system_prompt_path) system_prompt = LLMGateway.read_query_prompt(
system_prompt if system_prompt else system_prompt_path
return await LLMGateway.acreate_structured_output(
text_input=user_prompt,
system_prompt=system_prompt,
response_model=str,
) )
if only_context:
return context
else:
return await LLMGateway.acreate_structured_output(
text_input=user_prompt,
system_prompt=system_prompt,
response_model=str,
)
async def summarize_text( async def summarize_text(
text: str, text: str,

View file

@ -101,11 +101,14 @@ async def specific_search(
query: str, query: str,
user: User, user: User,
system_prompt_path="answer_simple_question.txt", system_prompt_path="answer_simple_question.txt",
user_prompt: str = None,
system_prompt: str = None,
top_k: int = 10, top_k: int = 10,
node_type: Optional[Type] = None, node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None, node_name: Optional[List[str]] = None,
save_interaction: Optional[bool] = False, save_interaction: Optional[bool] = False,
last_k: Optional[int] = None, last_k: Optional[int] = None,
only_context: bool = None,
) -> list: ) -> list:
search_tasks: dict[SearchType, Callable] = { search_tasks: dict[SearchType, Callable] = {
SearchType.SUMMARIES: SummariesRetriever(top_k=top_k).get_completion, SearchType.SUMMARIES: SummariesRetriever(top_k=top_k).get_completion,
@ -159,7 +162,9 @@ async def specific_search(
send_telemetry("cognee.search EXECUTION STARTED", user.id) send_telemetry("cognee.search EXECUTION STARTED", user.id)
results = await search_task(query) results = await search_task(
query=query, system_prompt=system_prompt, user_prompt=user_prompt, only_context=only_context
)
send_telemetry("cognee.search EXECUTION COMPLETED", user.id) send_telemetry("cognee.search EXECUTION COMPLETED", user.id)