feat: Save search flag progress
This commit is contained in:
parent
66673af56d
commit
ac87e62adb
6 changed files with 72 additions and 14 deletions
|
|
@ -65,7 +65,14 @@ class CompletionRetriever(BaseRetriever):
|
|||
logger.error("DocumentChunk_text collection not found")
|
||||
raise NoDataError("No data found in the system, please add data first.") from error
|
||||
|
||||
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
||||
async def get_completion(
|
||||
self,
|
||||
query: str,
|
||||
context: Optional[Any] = None,
|
||||
user_prompt: str = None,
|
||||
system_prompt: str = None,
|
||||
only_context: bool = False,
|
||||
) -> Any:
|
||||
"""
|
||||
Generates an LLM completion using the context.
|
||||
|
||||
|
|
@ -88,6 +95,12 @@ class CompletionRetriever(BaseRetriever):
|
|||
context = await self.get_context(query)
|
||||
|
||||
completion = await generate_completion(
|
||||
query, context, self.user_prompt_path, self.system_prompt_path
|
||||
query=query,
|
||||
context=context,
|
||||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
user_prompt=user_prompt,
|
||||
system_prompt=system_prompt,
|
||||
only_context=only_context,
|
||||
)
|
||||
return [completion]
|
||||
|
|
|
|||
|
|
@ -41,7 +41,13 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
)
|
||||
|
||||
async def get_completion(
|
||||
self, query: str, context: Optional[Any] = None, context_extension_rounds=4
|
||||
self,
|
||||
query: str,
|
||||
context: Optional[Any] = None,
|
||||
user_prompt: str = None,
|
||||
system_prompt: str = None,
|
||||
only_context: bool = False,
|
||||
context_extension_rounds=4,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Extends the context for a given query by retrieving related triplets and generating new
|
||||
|
|
@ -86,6 +92,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
context=context,
|
||||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
user_prompt=user_prompt,
|
||||
system_prompt=system_prompt,
|
||||
)
|
||||
|
||||
triplets += await self.get_triplets(completion)
|
||||
|
|
@ -112,6 +120,9 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
context=context,
|
||||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
user_prompt=user_prompt,
|
||||
system_prompt=system_prompt,
|
||||
only_context=only_context,
|
||||
)
|
||||
|
||||
if self.save_interaction and context and triplets and completion:
|
||||
|
|
|
|||
|
|
@ -51,7 +51,13 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|||
self.followup_user_prompt_path = followup_user_prompt_path
|
||||
|
||||
async def get_completion(
|
||||
self, query: str, context: Optional[Any] = None, max_iter=4
|
||||
self,
|
||||
query: str,
|
||||
context: Optional[Any] = None,
|
||||
user_prompt: str = None,
|
||||
system_prompt: str = None,
|
||||
only_context: bool = False,
|
||||
max_iter=4,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Generate completion responses based on a user query and contextual information.
|
||||
|
|
@ -92,6 +98,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|||
context=context,
|
||||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
user_prompt=user_prompt,
|
||||
system_prompt=system_prompt,
|
||||
)
|
||||
logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}")
|
||||
if round_idx < max_iter:
|
||||
|
|
@ -128,4 +136,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|||
question=query, answer=completion, context=context, triplets=triplets
|
||||
)
|
||||
|
||||
return [completion]
|
||||
if only_context:
|
||||
return [context]
|
||||
else:
|
||||
return [completion]
|
||||
|
|
|
|||
|
|
@ -151,7 +151,14 @@ class GraphCompletionRetriever(BaseRetriever):
|
|||
|
||||
return context, triplets
|
||||
|
||||
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
||||
async def get_completion(
|
||||
self,
|
||||
query: str,
|
||||
context: Optional[Any] = None,
|
||||
user_prompt: str = None,
|
||||
system_prompt: str = None,
|
||||
only_context: bool = False,
|
||||
) -> Any:
|
||||
"""
|
||||
Generates a completion using graph connections context based on a query.
|
||||
|
||||
|
|
@ -177,6 +184,9 @@ class GraphCompletionRetriever(BaseRetriever):
|
|||
context=context,
|
||||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
user_prompt=user_prompt,
|
||||
system_prompt=system_prompt,
|
||||
only_context=only_context,
|
||||
)
|
||||
|
||||
if self.save_interaction and context and triplets and completion:
|
||||
|
|
|
|||
|
|
@ -6,18 +6,26 @@ async def generate_completion(
|
|||
context: str,
|
||||
user_prompt_path: str,
|
||||
system_prompt_path: str,
|
||||
user_prompt: str = None,
|
||||
system_prompt: str = None,
|
||||
only_context: bool = False,
|
||||
) -> str:
|
||||
"""Generates a completion using LLM with given context and prompts."""
|
||||
args = {"question": query, "context": context}
|
||||
user_prompt = LLMGateway.render_prompt(user_prompt_path, args)
|
||||
system_prompt = LLMGateway.read_query_prompt(system_prompt_path)
|
||||
|
||||
return await LLMGateway.acreate_structured_output(
|
||||
text_input=user_prompt,
|
||||
system_prompt=system_prompt,
|
||||
response_model=str,
|
||||
user_prompt = LLMGateway.render_prompt(user_prompt if user_prompt else user_prompt_path, args)
|
||||
system_prompt = LLMGateway.read_query_prompt(
|
||||
system_prompt if system_prompt else system_prompt_path
|
||||
)
|
||||
|
||||
if only_context:
|
||||
return context
|
||||
else:
|
||||
return await LLMGateway.acreate_structured_output(
|
||||
text_input=user_prompt,
|
||||
system_prompt=system_prompt,
|
||||
response_model=str,
|
||||
)
|
||||
|
||||
|
||||
async def summarize_text(
|
||||
text: str,
|
||||
|
|
|
|||
|
|
@ -101,11 +101,14 @@ async def specific_search(
|
|||
query: str,
|
||||
user: User,
|
||||
system_prompt_path="answer_simple_question.txt",
|
||||
user_prompt: str = None,
|
||||
system_prompt: str = None,
|
||||
top_k: int = 10,
|
||||
node_type: Optional[Type] = None,
|
||||
node_name: Optional[List[str]] = None,
|
||||
save_interaction: Optional[bool] = False,
|
||||
last_k: Optional[int] = None,
|
||||
only_context: bool = None,
|
||||
) -> list:
|
||||
search_tasks: dict[SearchType, Callable] = {
|
||||
SearchType.SUMMARIES: SummariesRetriever(top_k=top_k).get_completion,
|
||||
|
|
@ -159,7 +162,9 @@ async def specific_search(
|
|||
|
||||
send_telemetry("cognee.search EXECUTION STARTED", user.id)
|
||||
|
||||
results = await search_task(query)
|
||||
results = await search_task(
|
||||
query=query, system_prompt=system_prompt, user_prompt=user_prompt, only_context=only_context
|
||||
)
|
||||
|
||||
send_telemetry("cognee.search EXECUTION COMPLETED", user.id)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue