feat: Save search flag progress

This commit is contained in:
Igor Ilic 2025-08-28 10:52:08 +02:00
parent 66673af56d
commit ac87e62adb
6 changed files with 72 additions and 14 deletions

View file

@@ -65,7 +65,14 @@ class CompletionRetriever(BaseRetriever):
logger.error("DocumentChunk_text collection not found")
raise NoDataError("No data found in the system, please add data first.") from error
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
async def get_completion(
self,
query: str,
context: Optional[Any] = None,
user_prompt: str = None,
system_prompt: str = None,
only_context: bool = False,
) -> Any:
"""
Generates an LLM completion using the context.
@@ -88,6 +95,12 @@ class CompletionRetriever(BaseRetriever):
context = await self.get_context(query)
completion = await generate_completion(
query, context, self.user_prompt_path, self.system_prompt_path
query=query,
context=context,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
user_prompt=user_prompt,
system_prompt=system_prompt,
only_context=only_context,
)
return [completion]

View file

@@ -41,7 +41,13 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
)
async def get_completion(
self, query: str, context: Optional[Any] = None, context_extension_rounds=4
self,
query: str,
context: Optional[Any] = None,
user_prompt: str = None,
system_prompt: str = None,
only_context: bool = False,
context_extension_rounds=4,
) -> List[str]:
"""
Extends the context for a given query by retrieving related triplets and generating new
@@ -86,6 +92,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
context=context,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
user_prompt=user_prompt,
system_prompt=system_prompt,
)
triplets += await self.get_triplets(completion)
@@ -112,6 +120,9 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
context=context,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
user_prompt=user_prompt,
system_prompt=system_prompt,
only_context=only_context,
)
if self.save_interaction and context and triplets and completion:

View file

@@ -51,7 +51,13 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
self.followup_user_prompt_path = followup_user_prompt_path
async def get_completion(
self, query: str, context: Optional[Any] = None, max_iter=4
self,
query: str,
context: Optional[Any] = None,
user_prompt: str = None,
system_prompt: str = None,
only_context: bool = False,
max_iter=4,
) -> List[str]:
"""
Generate completion responses based on a user query and contextual information.
@@ -92,6 +98,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
context=context,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
user_prompt=user_prompt,
system_prompt=system_prompt,
)
logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}")
if round_idx < max_iter:
@@ -128,4 +136,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
question=query, answer=completion, context=context, triplets=triplets
)
return [completion]
if only_context:
return [context]
else:
return [completion]

View file

@@ -151,7 +151,14 @@ class GraphCompletionRetriever(BaseRetriever):
return context, triplets
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
async def get_completion(
self,
query: str,
context: Optional[Any] = None,
user_prompt: str = None,
system_prompt: str = None,
only_context: bool = False,
) -> Any:
"""
Generates a completion using graph connections context based on a query.
@@ -177,6 +184,9 @@ class GraphCompletionRetriever(BaseRetriever):
context=context,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
user_prompt=user_prompt,
system_prompt=system_prompt,
only_context=only_context,
)
if self.save_interaction and context and triplets and completion:

View file

@@ -6,18 +6,26 @@ async def generate_completion(
context: str,
user_prompt_path: str,
system_prompt_path: str,
user_prompt: str = None,
system_prompt: str = None,
only_context: bool = False,
) -> str:
"""Generates a completion using LLM with given context and prompts."""
args = {"question": query, "context": context}
user_prompt = LLMGateway.render_prompt(user_prompt_path, args)
system_prompt = LLMGateway.read_query_prompt(system_prompt_path)
return await LLMGateway.acreate_structured_output(
text_input=user_prompt,
system_prompt=system_prompt,
response_model=str,
user_prompt = LLMGateway.render_prompt(user_prompt if user_prompt else user_prompt_path, args)
system_prompt = LLMGateway.read_query_prompt(
system_prompt if system_prompt else system_prompt_path
)
if only_context:
return context
else:
return await LLMGateway.acreate_structured_output(
text_input=user_prompt,
system_prompt=system_prompt,
response_model=str,
)
async def summarize_text(
text: str,

View file

@@ -101,11 +101,14 @@ async def specific_search(
query: str,
user: User,
system_prompt_path="answer_simple_question.txt",
user_prompt: str = None,
system_prompt: str = None,
top_k: int = 10,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
save_interaction: Optional[bool] = False,
last_k: Optional[int] = None,
only_context: bool = None,
) -> list:
search_tasks: dict[SearchType, Callable] = {
SearchType.SUMMARIES: SummariesRetriever(top_k=top_k).get_completion,
@@ -159,7 +162,9 @@ async def specific_search(
send_telemetry("cognee.search EXECUTION STARTED", user.id)
results = await search_task(query)
results = await search_task(
query=query, system_prompt=system_prompt, user_prompt=user_prompt, only_context=only_context
)
send_telemetry("cognee.search EXECUTION COMPLETED", user.id)