feat: Add only_context and system prompt flags for search

This commit is contained in:
Igor Ilic 2025-08-28 13:43:37 +02:00
parent ac87e62adb
commit 2915698d60
10 changed files with 140 additions and 73 deletions

View file

@ -20,7 +20,9 @@ class SearchPayloadDTO(InDTO):
datasets: Optional[list[str]] = Field(default=None) datasets: Optional[list[str]] = Field(default=None)
dataset_ids: Optional[list[UUID]] = Field(default=None, examples=[[]]) dataset_ids: Optional[list[UUID]] = Field(default=None, examples=[[]])
query: str = Field(default="What is in the document?") query: str = Field(default="What is in the document?")
system_prompt: Optional[str] = Field(default=None)
top_k: Optional[int] = Field(default=10) top_k: Optional[int] = Field(default=10)
only_context: bool = Field(default=False)
def get_search_router() -> APIRouter: def get_search_router() -> APIRouter:
@ -102,7 +104,9 @@ def get_search_router() -> APIRouter:
"datasets": payload.datasets, "datasets": payload.datasets,
"dataset_ids": [str(dataset_id) for dataset_id in payload.dataset_ids or []], "dataset_ids": [str(dataset_id) for dataset_id in payload.dataset_ids or []],
"query": payload.query, "query": payload.query,
"system_prompt": payload.system_prompt,
"top_k": payload.top_k, "top_k": payload.top_k,
"only_context": payload.only_context,
}, },
) )
@ -115,7 +119,9 @@ def get_search_router() -> APIRouter:
user=user, user=user,
datasets=payload.datasets, datasets=payload.datasets,
dataset_ids=payload.dataset_ids, dataset_ids=payload.dataset_ids,
system_prompt=payload.system_prompt,
top_k=payload.top_k, top_k=payload.top_k,
only_context=payload.only_context,
) )
return results return results

View file

@ -16,11 +16,13 @@ async def search(
datasets: Optional[Union[list[str], str]] = None, datasets: Optional[Union[list[str], str]] = None,
dataset_ids: Optional[Union[list[UUID], UUID]] = None, dataset_ids: Optional[Union[list[UUID], UUID]] = None,
system_prompt_path: str = "answer_simple_question.txt", system_prompt_path: str = "answer_simple_question.txt",
system_prompt: Optional[str] = None,
top_k: int = 10, top_k: int = 10,
node_type: Optional[Type] = None, node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None, node_name: Optional[List[str]] = None,
save_interaction: bool = False, save_interaction: bool = False,
last_k: Optional[int] = None, last_k: Optional[int] = None,
only_context: bool = False,
) -> list: ) -> list:
""" """
Search and query the knowledge graph for insights, information, and connections. Search and query the knowledge graph for insights, information, and connections.
@ -183,11 +185,13 @@ async def search(
dataset_ids=dataset_ids if dataset_ids else datasets, dataset_ids=dataset_ids if dataset_ids else datasets,
user=user, user=user,
system_prompt_path=system_prompt_path, system_prompt_path=system_prompt_path,
system_prompt=system_prompt,
top_k=top_k, top_k=top_k,
node_type=node_type, node_type=node_type,
node_name=node_name, node_name=node_name,
save_interaction=save_interaction, save_interaction=save_interaction,
last_k=last_k, last_k=last_k,
only_context=only_context,
) )
return filtered_search_results return filtered_search_results

View file

@ -23,12 +23,16 @@ class CompletionRetriever(BaseRetriever):
self, self,
user_prompt_path: str = "context_for_question.txt", user_prompt_path: str = "context_for_question.txt",
system_prompt_path: str = "answer_simple_question.txt", system_prompt_path: str = "answer_simple_question.txt",
system_prompt: str = None,
top_k: Optional[int] = 1, top_k: Optional[int] = 1,
only_context: bool = False,
): ):
"""Initialize retriever with optional custom prompt paths.""" """Initialize retriever with optional custom prompt paths."""
self.user_prompt_path = user_prompt_path self.user_prompt_path = user_prompt_path
self.system_prompt_path = system_prompt_path self.system_prompt_path = system_prompt_path
self.top_k = top_k if top_k is not None else 1 self.top_k = top_k if top_k is not None else 1
self.system_prompt = system_prompt
self.only_context = only_context
async def get_context(self, query: str) -> str: async def get_context(self, query: str) -> str:
""" """
@ -65,14 +69,7 @@ class CompletionRetriever(BaseRetriever):
logger.error("DocumentChunk_text collection not found") logger.error("DocumentChunk_text collection not found")
raise NoDataError("No data found in the system, please add data first.") from error raise NoDataError("No data found in the system, please add data first.") from error
async def get_completion( async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
self,
query: str,
context: Optional[Any] = None,
user_prompt: str = None,
system_prompt: str = None,
only_context: bool = False,
) -> Any:
""" """
Generates an LLM completion using the context. Generates an LLM completion using the context.
@ -99,8 +96,7 @@ class CompletionRetriever(BaseRetriever):
context=context, context=context,
user_prompt_path=self.user_prompt_path, user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path, system_prompt_path=self.system_prompt_path,
user_prompt=user_prompt, system_prompt=self.system_prompt,
system_prompt=system_prompt, only_context=self.only_context,
only_context=only_context,
) )
return [completion] return [completion]

View file

@ -26,10 +26,12 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
self, self,
user_prompt_path: str = "graph_context_for_question.txt", user_prompt_path: str = "graph_context_for_question.txt",
system_prompt_path: str = "answer_simple_question.txt", system_prompt_path: str = "answer_simple_question.txt",
system_prompt: Optional[str] = None,
top_k: Optional[int] = 5, top_k: Optional[int] = 5,
node_type: Optional[Type] = None, node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None, node_name: Optional[List[str]] = None,
save_interaction: bool = False, save_interaction: bool = False,
only_context: bool = False,
): ):
super().__init__( super().__init__(
user_prompt_path=user_prompt_path, user_prompt_path=user_prompt_path,
@ -38,15 +40,14 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
node_type=node_type, node_type=node_type,
node_name=node_name, node_name=node_name,
save_interaction=save_interaction, save_interaction=save_interaction,
system_prompt=system_prompt,
only_context=only_context,
) )
async def get_completion( async def get_completion(
self, self,
query: str, query: str,
context: Optional[Any] = None, context: Optional[Any] = None,
user_prompt: str = None,
system_prompt: str = None,
only_context: bool = False,
context_extension_rounds=4, context_extension_rounds=4,
) -> List[str]: ) -> List[str]:
""" """
@ -92,8 +93,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
context=context, context=context,
user_prompt_path=self.user_prompt_path, user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path, system_prompt_path=self.system_prompt_path,
user_prompt=user_prompt, system_prompt=self.system_prompt,
system_prompt=system_prompt,
) )
triplets += await self.get_triplets(completion) triplets += await self.get_triplets(completion)
@ -120,9 +120,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
context=context, context=context,
user_prompt_path=self.user_prompt_path, user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path, system_prompt_path=self.system_prompt_path,
user_prompt=user_prompt, system_prompt=self.system_prompt,
system_prompt=system_prompt, only_context=self.only_context,
only_context=only_context,
) )
if self.save_interaction and context and triplets and completion: if self.save_interaction and context and triplets and completion:
@ -130,4 +129,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
question=query, answer=completion, context=context, triplets=triplets question=query, answer=completion, context=context, triplets=triplets
) )
return [completion] if self.only_context:
return [context]
else:
return [completion]

View file

@ -32,14 +32,18 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
validation_system_prompt_path: str = "cot_validation_system_prompt.txt", validation_system_prompt_path: str = "cot_validation_system_prompt.txt",
followup_system_prompt_path: str = "cot_followup_system_prompt.txt", followup_system_prompt_path: str = "cot_followup_system_prompt.txt",
followup_user_prompt_path: str = "cot_followup_user_prompt.txt", followup_user_prompt_path: str = "cot_followup_user_prompt.txt",
system_prompt: str = None,
top_k: Optional[int] = 5, top_k: Optional[int] = 5,
node_type: Optional[Type] = None, node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None, node_name: Optional[List[str]] = None,
save_interaction: bool = False, save_interaction: bool = False,
only_context: bool = False,
): ):
super().__init__( super().__init__(
user_prompt_path=user_prompt_path, user_prompt_path=user_prompt_path,
system_prompt_path=system_prompt_path, system_prompt_path=system_prompt_path,
system_prompt=system_prompt,
only_context=only_context,
top_k=top_k, top_k=top_k,
node_type=node_type, node_type=node_type,
node_name=node_name, node_name=node_name,
@ -54,9 +58,6 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
self, self,
query: str, query: str,
context: Optional[Any] = None, context: Optional[Any] = None,
user_prompt: str = None,
system_prompt: str = None,
only_context: bool = False,
max_iter=4, max_iter=4,
) -> List[str]: ) -> List[str]:
""" """
@ -98,8 +99,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
context=context, context=context,
user_prompt_path=self.user_prompt_path, user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path, system_prompt_path=self.system_prompt_path,
user_prompt=user_prompt, system_prompt=self.system_prompt,
system_prompt=system_prompt,
) )
logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}") logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}")
if round_idx < max_iter: if round_idx < max_iter:
@ -136,7 +136,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
question=query, answer=completion, context=context, triplets=triplets question=query, answer=completion, context=context, triplets=triplets
) )
if only_context: if self.only_context:
return [context] return [context]
else: else:
return [completion] return [completion]

View file

@ -36,15 +36,19 @@ class GraphCompletionRetriever(BaseRetriever):
self, self,
user_prompt_path: str = "graph_context_for_question.txt", user_prompt_path: str = "graph_context_for_question.txt",
system_prompt_path: str = "answer_simple_question.txt", system_prompt_path: str = "answer_simple_question.txt",
system_prompt: str = None,
top_k: Optional[int] = 5, top_k: Optional[int] = 5,
node_type: Optional[Type] = None, node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None, node_name: Optional[List[str]] = None,
save_interaction: bool = False, save_interaction: bool = False,
only_context: bool = False,
): ):
"""Initialize retriever with prompt paths and search parameters.""" """Initialize retriever with prompt paths and search parameters."""
self.save_interaction = save_interaction self.save_interaction = save_interaction
self.user_prompt_path = user_prompt_path self.user_prompt_path = user_prompt_path
self.system_prompt_path = system_prompt_path self.system_prompt_path = system_prompt_path
self.system_prompt = system_prompt
self.only_context = only_context
self.top_k = top_k if top_k is not None else 5 self.top_k = top_k if top_k is not None else 5
self.node_type = node_type self.node_type = node_type
self.node_name = node_name self.node_name = node_name
@ -155,9 +159,6 @@ class GraphCompletionRetriever(BaseRetriever):
self, self,
query: str, query: str,
context: Optional[Any] = None, context: Optional[Any] = None,
user_prompt: str = None,
system_prompt: str = None,
only_context: bool = False,
) -> Any: ) -> Any:
""" """
Generates a completion using graph connections context based on a query. Generates a completion using graph connections context based on a query.
@ -184,9 +185,8 @@ class GraphCompletionRetriever(BaseRetriever):
context=context, context=context,
user_prompt_path=self.user_prompt_path, user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path, system_prompt_path=self.system_prompt_path,
user_prompt=user_prompt, system_prompt=self.system_prompt,
system_prompt=system_prompt, only_context=self.only_context,
only_context=only_context,
) )
if self.save_interaction and context and triplets and completion: if self.save_interaction and context and triplets and completion:

View file

@ -21,6 +21,7 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
user_prompt_path: str = "graph_context_for_question.txt", user_prompt_path: str = "graph_context_for_question.txt",
system_prompt_path: str = "answer_simple_question.txt", system_prompt_path: str = "answer_simple_question.txt",
summarize_prompt_path: str = "summarize_search_results.txt", summarize_prompt_path: str = "summarize_search_results.txt",
system_prompt: Optional[str] = None,
top_k: Optional[int] = 5, top_k: Optional[int] = 5,
node_type: Optional[Type] = None, node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None, node_name: Optional[List[str]] = None,
@ -34,6 +35,7 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
node_type=node_type, node_type=node_type,
node_name=node_name, node_name=node_name,
save_interaction=save_interaction, save_interaction=save_interaction,
system_prompt=system_prompt,
) )
self.summarize_prompt_path = summarize_prompt_path self.summarize_prompt_path = summarize_prompt_path
@ -57,4 +59,4 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
- str: A summary string representing the content of the retrieved edges. - str: A summary string representing the content of the retrieved edges.
""" """
direct_text = await super().resolve_edges_to_text(retrieved_edges) direct_text = await super().resolve_edges_to_text(retrieved_edges)
return await summarize_text(direct_text, self.summarize_prompt_path) return await summarize_text(direct_text, self.summarize_prompt_path, self.system_prompt)

View file

@ -62,7 +62,7 @@ class SummariesRetriever(BaseRetriever):
logger.info(f"Returning {len(summary_payloads)} summary payloads") logger.info(f"Returning {len(summary_payloads)} summary payloads")
return summary_payloads return summary_payloads
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: async def get_completion(self, query: str, context: Optional[Any] = None, **kwargs) -> Any:
""" """
Generates a completion using summaries context. Generates a completion using summaries context.

View file

@ -1,3 +1,4 @@
from typing import Optional
from cognee.infrastructure.llm.LLMGateway import LLMGateway from cognee.infrastructure.llm.LLMGateway import LLMGateway
@ -6,15 +7,15 @@ async def generate_completion(
context: str, context: str,
user_prompt_path: str, user_prompt_path: str,
system_prompt_path: str, system_prompt_path: str,
user_prompt: str = None, user_prompt: Optional[str] = None,
system_prompt: str = None, system_prompt: Optional[str] = None,
only_context: bool = False, only_context: bool = False,
) -> str: ) -> str:
"""Generates a completion using LLM with given context and prompts.""" """Generates a completion using LLM with given context and prompts."""
args = {"question": query, "context": context} args = {"question": query, "context": context}
user_prompt = LLMGateway.render_prompt(user_prompt if user_prompt else user_prompt_path, args) user_prompt = user_prompt if user_prompt else LLMGateway.render_prompt(user_prompt_path, args)
system_prompt = LLMGateway.read_query_prompt( system_prompt = (
system_prompt if system_prompt else system_prompt_path system_prompt if system_prompt else LLMGateway.read_query_prompt(system_prompt_path)
) )
if only_context: if only_context:
@ -29,10 +30,13 @@ async def generate_completion(
async def summarize_text( async def summarize_text(
text: str, text: str,
prompt_path: str = "summarize_search_results.txt", system_prompt_path: str = "summarize_search_results.txt",
system_prompt: str = None,
) -> str: ) -> str:
"""Summarizes text using LLM with the specified prompt.""" """Summarizes text using LLM with the specified prompt."""
system_prompt = LLMGateway.read_query_prompt(prompt_path) system_prompt = (
system_prompt if system_prompt else LLMGateway.read_query_prompt(system_prompt_path)
)
return await LLMGateway.acreate_structured_output( return await LLMGateway.acreate_structured_output(
text_input=text, text_input=text,

View file

@ -37,11 +37,13 @@ async def search(
dataset_ids: Union[list[UUID], None], dataset_ids: Union[list[UUID], None],
user: User, user: User,
system_prompt_path="answer_simple_question.txt", system_prompt_path="answer_simple_question.txt",
system_prompt: Optional[str] = None,
top_k: int = 10, top_k: int = 10,
node_type: Optional[Type] = None, node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None, node_name: Optional[List[str]] = None,
save_interaction: Optional[bool] = False, save_interaction: Optional[bool] = False,
last_k: Optional[int] = None, last_k: Optional[int] = None,
only_context: bool = False,
): ):
""" """
@ -61,28 +63,34 @@ async def search(
# Use search function filtered by permissions if access control is enabled # Use search function filtered by permissions if access control is enabled
if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true": if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
return await authorized_search( return await authorized_search(
query_text=query_text,
query_type=query_type, query_type=query_type,
query_text=query_text,
user=user, user=user,
dataset_ids=dataset_ids, dataset_ids=dataset_ids,
system_prompt_path=system_prompt_path, system_prompt_path=system_prompt_path,
system_prompt=system_prompt,
top_k=top_k, top_k=top_k,
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction, save_interaction=save_interaction,
last_k=last_k, last_k=last_k,
only_context=only_context,
) )
query = await log_query(query_text, query_type.value, user.id) query = await log_query(query_text, query_type.value, user.id)
search_results = await specific_search( search_results = await specific_search(
query_type, query_type=query_type,
query_text, query_text=query_text,
user, user=user,
system_prompt_path=system_prompt_path, system_prompt_path=system_prompt_path,
system_prompt=system_prompt,
top_k=top_k, top_k=top_k,
node_type=node_type, node_type=node_type,
node_name=node_name, node_name=node_name,
save_interaction=save_interaction, save_interaction=save_interaction,
last_k=last_k, last_k=last_k,
only_context=only_context,
) )
await log_result( await log_result(
@ -98,11 +106,10 @@ async def search(
async def specific_search( async def specific_search(
query_type: SearchType, query_type: SearchType,
query: str, query_text: str,
user: User, user: User,
system_prompt_path="answer_simple_question.txt", system_prompt_path: str = "answer_simple_question.txt",
user_prompt: str = None, system_prompt: Optional[str] = None,
system_prompt: str = None,
top_k: int = 10, top_k: int = 10,
node_type: Optional[Type] = None, node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None, node_name: Optional[List[str]] = None,
@ -115,7 +122,10 @@ async def specific_search(
SearchType.INSIGHTS: InsightsRetriever(top_k=top_k).get_completion, SearchType.INSIGHTS: InsightsRetriever(top_k=top_k).get_completion,
SearchType.CHUNKS: ChunksRetriever(top_k=top_k).get_completion, SearchType.CHUNKS: ChunksRetriever(top_k=top_k).get_completion,
SearchType.RAG_COMPLETION: CompletionRetriever( SearchType.RAG_COMPLETION: CompletionRetriever(
system_prompt_path=system_prompt_path, top_k=top_k system_prompt_path=system_prompt_path,
top_k=top_k,
system_prompt=system_prompt,
only_context=only_context,
).get_completion, ).get_completion,
SearchType.GRAPH_COMPLETION: GraphCompletionRetriever( SearchType.GRAPH_COMPLETION: GraphCompletionRetriever(
system_prompt_path=system_prompt_path, system_prompt_path=system_prompt_path,
@ -123,6 +133,8 @@ async def specific_search(
node_type=node_type, node_type=node_type,
node_name=node_name, node_name=node_name,
save_interaction=save_interaction, save_interaction=save_interaction,
system_prompt=system_prompt,
only_context=only_context,
).get_completion, ).get_completion,
SearchType.GRAPH_COMPLETION_COT: GraphCompletionCotRetriever( SearchType.GRAPH_COMPLETION_COT: GraphCompletionCotRetriever(
system_prompt_path=system_prompt_path, system_prompt_path=system_prompt_path,
@ -130,6 +142,8 @@ async def specific_search(
node_type=node_type, node_type=node_type,
node_name=node_name, node_name=node_name,
save_interaction=save_interaction, save_interaction=save_interaction,
system_prompt=system_prompt,
only_context=only_context,
).get_completion, ).get_completion,
SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION: GraphCompletionContextExtensionRetriever( SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION: GraphCompletionContextExtensionRetriever(
system_prompt_path=system_prompt_path, system_prompt_path=system_prompt_path,
@ -137,6 +151,8 @@ async def specific_search(
node_type=node_type, node_type=node_type,
node_name=node_name, node_name=node_name,
save_interaction=save_interaction, save_interaction=save_interaction,
system_prompt=system_prompt,
only_context=only_context,
).get_completion, ).get_completion,
SearchType.GRAPH_SUMMARY_COMPLETION: GraphSummaryCompletionRetriever( SearchType.GRAPH_SUMMARY_COMPLETION: GraphSummaryCompletionRetriever(
system_prompt_path=system_prompt_path, system_prompt_path=system_prompt_path,
@ -144,6 +160,7 @@ async def specific_search(
node_type=node_type, node_type=node_type,
node_name=node_name, node_name=node_name,
save_interaction=save_interaction, save_interaction=save_interaction,
system_prompt=system_prompt,
).get_completion, ).get_completion,
SearchType.CODE: CodeRetriever(top_k=top_k).get_completion, SearchType.CODE: CodeRetriever(top_k=top_k).get_completion,
SearchType.CYPHER: CypherSearchRetriever().get_completion, SearchType.CYPHER: CypherSearchRetriever().get_completion,
@ -153,7 +170,7 @@ async def specific_search(
# If the query type is FEELING_LUCKY, select the search type intelligently # If the query type is FEELING_LUCKY, select the search type intelligently
if query_type is SearchType.FEELING_LUCKY: if query_type is SearchType.FEELING_LUCKY:
query_type = await select_search_type(query) query_type = await select_search_type(query_text)
search_task = search_tasks.get(query_type) search_task = search_tasks.get(query_type)
@ -162,9 +179,7 @@ async def specific_search(
send_telemetry("cognee.search EXECUTION STARTED", user.id) send_telemetry("cognee.search EXECUTION STARTED", user.id)
results = await search_task( results = await search_task(query=query_text)
query=query, system_prompt=system_prompt, user_prompt=user_prompt, only_context=only_context
)
send_telemetry("cognee.search EXECUTION COMPLETED", user.id) send_telemetry("cognee.search EXECUTION COMPLETED", user.id)
@ -172,14 +187,18 @@ async def specific_search(
async def authorized_search( async def authorized_search(
query_text: str,
query_type: SearchType, query_type: SearchType,
user: User = None, query_text: str,
user: User,
dataset_ids: Optional[list[UUID]] = None, dataset_ids: Optional[list[UUID]] = None,
system_prompt_path: str = "answer_simple_question.txt", system_prompt_path: str = "answer_simple_question.txt",
system_prompt: Optional[str] = None,
top_k: int = 10, top_k: int = 10,
save_interaction: bool = False, node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
save_interaction: Optional[bool] = False,
last_k: Optional[int] = None, last_k: Optional[int] = None,
only_context: bool = None,
) -> list: ) -> list:
""" """
Verifies access for provided datasets or uses all datasets user has read access for and performs search per dataset. Verifies access for provided datasets or uses all datasets user has read access for and performs search per dataset.
@ -193,14 +212,18 @@ async def authorized_search(
# Searches all provided datasets and handles setting up of appropriate database context based on permissions # Searches all provided datasets and handles setting up of appropriate database context based on permissions
search_results = await specific_search_by_context( search_results = await specific_search_by_context(
search_datasets, search_datasets=search_datasets,
query_text, query_type=query_type,
query_type, query_text=query_text,
user, user=user,
system_prompt_path, system_prompt_path=system_prompt_path,
top_k, system_prompt=system_prompt,
save_interaction, top_k=top_k,
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
last_k=last_k, last_k=last_k,
only_context=only_context,
) )
await log_result(query.id, json.dumps(search_results, cls=JSONEncoder), user.id) await log_result(query.id, json.dumps(search_results, cls=JSONEncoder), user.id)
@ -210,13 +233,17 @@ async def authorized_search(
async def specific_search_by_context( async def specific_search_by_context(
search_datasets: list[Dataset], search_datasets: list[Dataset],
query_text: str,
query_type: SearchType, query_type: SearchType,
query_text: str,
user: User, user: User,
system_prompt_path: str, system_prompt_path: str = "answer_simple_question.txt",
top_k: int, system_prompt: Optional[str] = None,
save_interaction: bool = False, top_k: int = 10,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
save_interaction: Optional[bool] = False,
last_k: Optional[int] = None, last_k: Optional[int] = None,
only_context: bool = None,
): ):
""" """
Searches all provided datasets and handles setting up of appropriate database context based on permissions. Searches all provided datasets and handles setting up of appropriate database context based on permissions.
@ -224,18 +251,33 @@ async def specific_search_by_context(
""" """
async def _search_by_context( async def _search_by_context(
dataset, user, query_type, query_text, system_prompt_path, top_k, last_k dataset: Dataset,
query_type: SearchType,
query_text: str,
user: User,
system_prompt_path: str = "answer_simple_question.txt",
system_prompt: Optional[str] = None,
top_k: int = 10,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
save_interaction: Optional[bool] = False,
last_k: Optional[int] = None,
only_context: bool = None,
): ):
# Set database configuration in async context for each dataset user has access for # Set database configuration in async context for each dataset user has access for
await set_database_global_context_variables(dataset.id, dataset.owner_id) await set_database_global_context_variables(dataset.id, dataset.owner_id)
search_results = await specific_search( search_results = await specific_search(
query_type, query_type=query_type,
query_text, query_text=query_text,
user, user=user,
system_prompt_path=system_prompt_path, system_prompt_path=system_prompt_path,
system_prompt=system_prompt,
top_k=top_k, top_k=top_k,
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction, save_interaction=save_interaction,
last_k=last_k, last_k=last_k,
only_context=only_context,
) )
return { return {
"search_result": search_results, "search_result": search_results,
@ -248,7 +290,18 @@ async def specific_search_by_context(
for dataset in search_datasets: for dataset in search_datasets:
tasks.append( tasks.append(
_search_by_context( _search_by_context(
dataset, user, query_type, query_text, system_prompt_path, top_k, last_k dataset=dataset,
query_type=query_type,
query_text=query_text,
user=user,
system_prompt_path=system_prompt_path,
system_prompt=system_prompt,
top_k=top_k,
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
last_k=last_k,
only_context=only_context,
) )
) )