feat: Save search flag progress
This commit is contained in:
parent
66673af56d
commit
ac87e62adb
6 changed files with 72 additions and 14 deletions
|
|
@ -65,7 +65,14 @@ class CompletionRetriever(BaseRetriever):
|
||||||
logger.error("DocumentChunk_text collection not found")
|
logger.error("DocumentChunk_text collection not found")
|
||||||
raise NoDataError("No data found in the system, please add data first.") from error
|
raise NoDataError("No data found in the system, please add data first.") from error
|
||||||
|
|
||||||
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
async def get_completion(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
context: Optional[Any] = None,
|
||||||
|
user_prompt: str = None,
|
||||||
|
system_prompt: str = None,
|
||||||
|
only_context: bool = False,
|
||||||
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Generates an LLM completion using the context.
|
Generates an LLM completion using the context.
|
||||||
|
|
||||||
|
|
@ -88,6 +95,12 @@ class CompletionRetriever(BaseRetriever):
|
||||||
context = await self.get_context(query)
|
context = await self.get_context(query)
|
||||||
|
|
||||||
completion = await generate_completion(
|
completion = await generate_completion(
|
||||||
query, context, self.user_prompt_path, self.system_prompt_path
|
query=query,
|
||||||
|
context=context,
|
||||||
|
user_prompt_path=self.user_prompt_path,
|
||||||
|
system_prompt_path=self.system_prompt_path,
|
||||||
|
user_prompt=user_prompt,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
only_context=only_context,
|
||||||
)
|
)
|
||||||
return [completion]
|
return [completion]
|
||||||
|
|
|
||||||
|
|
@ -41,7 +41,13 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_completion(
|
async def get_completion(
|
||||||
self, query: str, context: Optional[Any] = None, context_extension_rounds=4
|
self,
|
||||||
|
query: str,
|
||||||
|
context: Optional[Any] = None,
|
||||||
|
user_prompt: str = None,
|
||||||
|
system_prompt: str = None,
|
||||||
|
only_context: bool = False,
|
||||||
|
context_extension_rounds=4,
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
"""
|
"""
|
||||||
Extends the context for a given query by retrieving related triplets and generating new
|
Extends the context for a given query by retrieving related triplets and generating new
|
||||||
|
|
@ -86,6 +92,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
||||||
context=context,
|
context=context,
|
||||||
user_prompt_path=self.user_prompt_path,
|
user_prompt_path=self.user_prompt_path,
|
||||||
system_prompt_path=self.system_prompt_path,
|
system_prompt_path=self.system_prompt_path,
|
||||||
|
user_prompt=user_prompt,
|
||||||
|
system_prompt=system_prompt,
|
||||||
)
|
)
|
||||||
|
|
||||||
triplets += await self.get_triplets(completion)
|
triplets += await self.get_triplets(completion)
|
||||||
|
|
@ -112,6 +120,9 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
||||||
context=context,
|
context=context,
|
||||||
user_prompt_path=self.user_prompt_path,
|
user_prompt_path=self.user_prompt_path,
|
||||||
system_prompt_path=self.system_prompt_path,
|
system_prompt_path=self.system_prompt_path,
|
||||||
|
user_prompt=user_prompt,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
only_context=only_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.save_interaction and context and triplets and completion:
|
if self.save_interaction and context and triplets and completion:
|
||||||
|
|
|
||||||
|
|
@ -51,7 +51,13 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
||||||
self.followup_user_prompt_path = followup_user_prompt_path
|
self.followup_user_prompt_path = followup_user_prompt_path
|
||||||
|
|
||||||
async def get_completion(
|
async def get_completion(
|
||||||
self, query: str, context: Optional[Any] = None, max_iter=4
|
self,
|
||||||
|
query: str,
|
||||||
|
context: Optional[Any] = None,
|
||||||
|
user_prompt: str = None,
|
||||||
|
system_prompt: str = None,
|
||||||
|
only_context: bool = False,
|
||||||
|
max_iter=4,
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
"""
|
"""
|
||||||
Generate completion responses based on a user query and contextual information.
|
Generate completion responses based on a user query and contextual information.
|
||||||
|
|
@ -92,6 +98,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
||||||
context=context,
|
context=context,
|
||||||
user_prompt_path=self.user_prompt_path,
|
user_prompt_path=self.user_prompt_path,
|
||||||
system_prompt_path=self.system_prompt_path,
|
system_prompt_path=self.system_prompt_path,
|
||||||
|
user_prompt=user_prompt,
|
||||||
|
system_prompt=system_prompt,
|
||||||
)
|
)
|
||||||
logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}")
|
logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}")
|
||||||
if round_idx < max_iter:
|
if round_idx < max_iter:
|
||||||
|
|
@ -128,4 +136,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
||||||
question=query, answer=completion, context=context, triplets=triplets
|
question=query, answer=completion, context=context, triplets=triplets
|
||||||
)
|
)
|
||||||
|
|
||||||
return [completion]
|
if only_context:
|
||||||
|
return [context]
|
||||||
|
else:
|
||||||
|
return [completion]
|
||||||
|
|
|
||||||
|
|
@ -151,7 +151,14 @@ class GraphCompletionRetriever(BaseRetriever):
|
||||||
|
|
||||||
return context, triplets
|
return context, triplets
|
||||||
|
|
||||||
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
async def get_completion(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
context: Optional[Any] = None,
|
||||||
|
user_prompt: str = None,
|
||||||
|
system_prompt: str = None,
|
||||||
|
only_context: bool = False,
|
||||||
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Generates a completion using graph connections context based on a query.
|
Generates a completion using graph connections context based on a query.
|
||||||
|
|
||||||
|
|
@ -177,6 +184,9 @@ class GraphCompletionRetriever(BaseRetriever):
|
||||||
context=context,
|
context=context,
|
||||||
user_prompt_path=self.user_prompt_path,
|
user_prompt_path=self.user_prompt_path,
|
||||||
system_prompt_path=self.system_prompt_path,
|
system_prompt_path=self.system_prompt_path,
|
||||||
|
user_prompt=user_prompt,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
only_context=only_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.save_interaction and context and triplets and completion:
|
if self.save_interaction and context and triplets and completion:
|
||||||
|
|
|
||||||
|
|
@ -6,18 +6,26 @@ async def generate_completion(
|
||||||
context: str,
|
context: str,
|
||||||
user_prompt_path: str,
|
user_prompt_path: str,
|
||||||
system_prompt_path: str,
|
system_prompt_path: str,
|
||||||
|
user_prompt: str = None,
|
||||||
|
system_prompt: str = None,
|
||||||
|
only_context: bool = False,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Generates a completion using LLM with given context and prompts."""
|
"""Generates a completion using LLM with given context and prompts."""
|
||||||
args = {"question": query, "context": context}
|
args = {"question": query, "context": context}
|
||||||
user_prompt = LLMGateway.render_prompt(user_prompt_path, args)
|
user_prompt = LLMGateway.render_prompt(user_prompt if user_prompt else user_prompt_path, args)
|
||||||
system_prompt = LLMGateway.read_query_prompt(system_prompt_path)
|
system_prompt = LLMGateway.read_query_prompt(
|
||||||
|
system_prompt if system_prompt else system_prompt_path
|
||||||
return await LLMGateway.acreate_structured_output(
|
|
||||||
text_input=user_prompt,
|
|
||||||
system_prompt=system_prompt,
|
|
||||||
response_model=str,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if only_context:
|
||||||
|
return context
|
||||||
|
else:
|
||||||
|
return await LLMGateway.acreate_structured_output(
|
||||||
|
text_input=user_prompt,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
response_model=str,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def summarize_text(
|
async def summarize_text(
|
||||||
text: str,
|
text: str,
|
||||||
|
|
|
||||||
|
|
@ -101,11 +101,14 @@ async def specific_search(
|
||||||
query: str,
|
query: str,
|
||||||
user: User,
|
user: User,
|
||||||
system_prompt_path="answer_simple_question.txt",
|
system_prompt_path="answer_simple_question.txt",
|
||||||
|
user_prompt: str = None,
|
||||||
|
system_prompt: str = None,
|
||||||
top_k: int = 10,
|
top_k: int = 10,
|
||||||
node_type: Optional[Type] = None,
|
node_type: Optional[Type] = None,
|
||||||
node_name: Optional[List[str]] = None,
|
node_name: Optional[List[str]] = None,
|
||||||
save_interaction: Optional[bool] = False,
|
save_interaction: Optional[bool] = False,
|
||||||
last_k: Optional[int] = None,
|
last_k: Optional[int] = None,
|
||||||
|
only_context: bool = None,
|
||||||
) -> list:
|
) -> list:
|
||||||
search_tasks: dict[SearchType, Callable] = {
|
search_tasks: dict[SearchType, Callable] = {
|
||||||
SearchType.SUMMARIES: SummariesRetriever(top_k=top_k).get_completion,
|
SearchType.SUMMARIES: SummariesRetriever(top_k=top_k).get_completion,
|
||||||
|
|
@ -159,7 +162,9 @@ async def specific_search(
|
||||||
|
|
||||||
send_telemetry("cognee.search EXECUTION STARTED", user.id)
|
send_telemetry("cognee.search EXECUTION STARTED", user.id)
|
||||||
|
|
||||||
results = await search_task(query)
|
results = await search_task(
|
||||||
|
query=query, system_prompt=system_prompt, user_prompt=user_prompt, only_context=only_context
|
||||||
|
)
|
||||||
|
|
||||||
send_telemetry("cognee.search EXECUTION COMPLETED", user.id)
|
send_telemetry("cognee.search EXECUTION COMPLETED", user.id)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue