feat: Add only_context and system prompt flags for search

This commit is contained in:
Igor Ilic 2025-08-28 13:43:37 +02:00
parent ac87e62adb
commit 2915698d60
10 changed files with 140 additions and 73 deletions

View file

@ -20,7 +20,9 @@ class SearchPayloadDTO(InDTO):
datasets: Optional[list[str]] = Field(default=None)
dataset_ids: Optional[list[UUID]] = Field(default=None, examples=[[]])
query: str = Field(default="What is in the document?")
system_prompt: Optional[str] = Field(default=None)
top_k: Optional[int] = Field(default=10)
only_context: bool = Field(default=False)
def get_search_router() -> APIRouter:
@ -102,7 +104,9 @@ def get_search_router() -> APIRouter:
"datasets": payload.datasets,
"dataset_ids": [str(dataset_id) for dataset_id in payload.dataset_ids or []],
"query": payload.query,
"system_prompt": payload.system_prompt,
"top_k": payload.top_k,
"only_context": payload.only_context,
},
)
@ -115,7 +119,9 @@ def get_search_router() -> APIRouter:
user=user,
datasets=payload.datasets,
dataset_ids=payload.dataset_ids,
system_prompt=payload.system_prompt,
top_k=payload.top_k,
only_context=payload.only_context,
)
return results

View file

@ -16,11 +16,13 @@ async def search(
datasets: Optional[Union[list[str], str]] = None,
dataset_ids: Optional[Union[list[UUID], UUID]] = None,
system_prompt_path: str = "answer_simple_question.txt",
system_prompt: Optional[str] = None,
top_k: int = 10,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
save_interaction: bool = False,
last_k: Optional[int] = None,
only_context: bool = False,
) -> list:
"""
Search and query the knowledge graph for insights, information, and connections.
@ -183,11 +185,13 @@ async def search(
dataset_ids=dataset_ids if dataset_ids else datasets,
user=user,
system_prompt_path=system_prompt_path,
system_prompt=system_prompt,
top_k=top_k,
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
last_k=last_k,
only_context=only_context,
)
return filtered_search_results

View file

@ -23,12 +23,16 @@ class CompletionRetriever(BaseRetriever):
self,
user_prompt_path: str = "context_for_question.txt",
system_prompt_path: str = "answer_simple_question.txt",
system_prompt: str = None,
top_k: Optional[int] = 1,
only_context: bool = False,
):
"""Initialize retriever with optional custom prompt paths."""
self.user_prompt_path = user_prompt_path
self.system_prompt_path = system_prompt_path
self.top_k = top_k if top_k is not None else 1
self.system_prompt = system_prompt
self.only_context = only_context
async def get_context(self, query: str) -> str:
"""
@ -65,14 +69,7 @@ class CompletionRetriever(BaseRetriever):
logger.error("DocumentChunk_text collection not found")
raise NoDataError("No data found in the system, please add data first.") from error
async def get_completion(
self,
query: str,
context: Optional[Any] = None,
user_prompt: str = None,
system_prompt: str = None,
only_context: bool = False,
) -> Any:
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
"""
Generates an LLM completion using the context.
@ -99,8 +96,7 @@ class CompletionRetriever(BaseRetriever):
context=context,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
user_prompt=user_prompt,
system_prompt=system_prompt,
only_context=only_context,
system_prompt=self.system_prompt,
only_context=self.only_context,
)
return [completion]

View file

@ -26,10 +26,12 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
self,
user_prompt_path: str = "graph_context_for_question.txt",
system_prompt_path: str = "answer_simple_question.txt",
system_prompt: Optional[str] = None,
top_k: Optional[int] = 5,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
save_interaction: bool = False,
only_context: bool = False,
):
super().__init__(
user_prompt_path=user_prompt_path,
@ -38,15 +40,14 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
only_context=only_context,
)
async def get_completion(
self,
query: str,
context: Optional[Any] = None,
user_prompt: str = None,
system_prompt: str = None,
only_context: bool = False,
context_extension_rounds=4,
) -> List[str]:
"""
@ -92,8 +93,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
context=context,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
user_prompt=user_prompt,
system_prompt=system_prompt,
system_prompt=self.system_prompt,
)
triplets += await self.get_triplets(completion)
@ -120,9 +120,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
context=context,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
user_prompt=user_prompt,
system_prompt=system_prompt,
only_context=only_context,
system_prompt=self.system_prompt,
only_context=self.only_context,
)
if self.save_interaction and context and triplets and completion:
@ -130,4 +129,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
question=query, answer=completion, context=context, triplets=triplets
)
return [completion]
if self.only_context:
return [context]
else:
return [completion]

View file

@ -32,14 +32,18 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
validation_system_prompt_path: str = "cot_validation_system_prompt.txt",
followup_system_prompt_path: str = "cot_followup_system_prompt.txt",
followup_user_prompt_path: str = "cot_followup_user_prompt.txt",
system_prompt: str = None,
top_k: Optional[int] = 5,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
save_interaction: bool = False,
only_context: bool = False,
):
super().__init__(
user_prompt_path=user_prompt_path,
system_prompt_path=system_prompt_path,
system_prompt=system_prompt,
only_context=only_context,
top_k=top_k,
node_type=node_type,
node_name=node_name,
@ -54,9 +58,6 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
self,
query: str,
context: Optional[Any] = None,
user_prompt: str = None,
system_prompt: str = None,
only_context: bool = False,
max_iter=4,
) -> List[str]:
"""
@ -98,8 +99,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
context=context,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
user_prompt=user_prompt,
system_prompt=system_prompt,
system_prompt=self.system_prompt,
)
logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}")
if round_idx < max_iter:
@ -136,7 +136,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
question=query, answer=completion, context=context, triplets=triplets
)
if only_context:
if self.only_context:
return [context]
else:
return [completion]

View file

@ -36,15 +36,19 @@ class GraphCompletionRetriever(BaseRetriever):
self,
user_prompt_path: str = "graph_context_for_question.txt",
system_prompt_path: str = "answer_simple_question.txt",
system_prompt: str = None,
top_k: Optional[int] = 5,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
save_interaction: bool = False,
only_context: bool = False,
):
"""Initialize retriever with prompt paths and search parameters."""
self.save_interaction = save_interaction
self.user_prompt_path = user_prompt_path
self.system_prompt_path = system_prompt_path
self.system_prompt = system_prompt
self.only_context = only_context
self.top_k = top_k if top_k is not None else 5
self.node_type = node_type
self.node_name = node_name
@ -155,9 +159,6 @@ class GraphCompletionRetriever(BaseRetriever):
self,
query: str,
context: Optional[Any] = None,
user_prompt: str = None,
system_prompt: str = None,
only_context: bool = False,
) -> Any:
"""
Generates a completion using graph connections context based on a query.
@ -184,9 +185,8 @@ class GraphCompletionRetriever(BaseRetriever):
context=context,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
user_prompt=user_prompt,
system_prompt=system_prompt,
only_context=only_context,
system_prompt=self.system_prompt,
only_context=self.only_context,
)
if self.save_interaction and context and triplets and completion:

View file

@ -21,6 +21,7 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
user_prompt_path: str = "graph_context_for_question.txt",
system_prompt_path: str = "answer_simple_question.txt",
summarize_prompt_path: str = "summarize_search_results.txt",
system_prompt: Optional[str] = None,
top_k: Optional[int] = 5,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
@ -34,6 +35,7 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
)
self.summarize_prompt_path = summarize_prompt_path
@ -57,4 +59,4 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
- str: A summary string representing the content of the retrieved edges.
"""
direct_text = await super().resolve_edges_to_text(retrieved_edges)
return await summarize_text(direct_text, self.summarize_prompt_path)
return await summarize_text(direct_text, self.summarize_prompt_path, self.system_prompt)

View file

@ -62,7 +62,7 @@ class SummariesRetriever(BaseRetriever):
logger.info(f"Returning {len(summary_payloads)} summary payloads")
return summary_payloads
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
async def get_completion(self, query: str, context: Optional[Any] = None, **kwargs) -> Any:
"""
Generates a completion using summaries context.

View file

@ -1,3 +1,4 @@
from typing import Optional
from cognee.infrastructure.llm.LLMGateway import LLMGateway
@ -6,15 +7,15 @@ async def generate_completion(
context: str,
user_prompt_path: str,
system_prompt_path: str,
user_prompt: str = None,
system_prompt: str = None,
user_prompt: Optional[str] = None,
system_prompt: Optional[str] = None,
only_context: bool = False,
) -> str:
"""Generates a completion using LLM with given context and prompts."""
args = {"question": query, "context": context}
user_prompt = LLMGateway.render_prompt(user_prompt if user_prompt else user_prompt_path, args)
system_prompt = LLMGateway.read_query_prompt(
system_prompt if system_prompt else system_prompt_path
user_prompt = user_prompt if user_prompt else LLMGateway.render_prompt(user_prompt_path, args)
system_prompt = (
system_prompt if system_prompt else LLMGateway.read_query_prompt(system_prompt_path)
)
if only_context:
@ -29,10 +30,13 @@ async def generate_completion(
async def summarize_text(
text: str,
prompt_path: str = "summarize_search_results.txt",
system_prompt_path: str = "summarize_search_results.txt",
system_prompt: str = None,
) -> str:
"""Summarizes text using LLM with the specified prompt."""
system_prompt = LLMGateway.read_query_prompt(prompt_path)
system_prompt = (
system_prompt if system_prompt else LLMGateway.read_query_prompt(system_prompt_path)
)
return await LLMGateway.acreate_structured_output(
text_input=text,

View file

@ -37,11 +37,13 @@ async def search(
dataset_ids: Union[list[UUID], None],
user: User,
system_prompt_path="answer_simple_question.txt",
system_prompt: Optional[str] = None,
top_k: int = 10,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
save_interaction: Optional[bool] = False,
last_k: Optional[int] = None,
only_context: bool = False,
):
"""
@ -61,28 +63,34 @@ async def search(
# Use search function filtered by permissions if access control is enabled
if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
return await authorized_search(
query_text=query_text,
query_type=query_type,
query_text=query_text,
user=user,
dataset_ids=dataset_ids,
system_prompt_path=system_prompt_path,
system_prompt=system_prompt,
top_k=top_k,
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
last_k=last_k,
only_context=only_context,
)
query = await log_query(query_text, query_type.value, user.id)
search_results = await specific_search(
query_type,
query_text,
user,
query_type=query_type,
query_text=query_text,
user=user,
system_prompt_path=system_prompt_path,
system_prompt=system_prompt,
top_k=top_k,
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
last_k=last_k,
only_context=only_context,
)
await log_result(
@ -98,11 +106,10 @@ async def search(
async def specific_search(
query_type: SearchType,
query: str,
query_text: str,
user: User,
system_prompt_path="answer_simple_question.txt",
user_prompt: str = None,
system_prompt: str = None,
system_prompt_path: str = "answer_simple_question.txt",
system_prompt: Optional[str] = None,
top_k: int = 10,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
@ -115,7 +122,10 @@ async def specific_search(
SearchType.INSIGHTS: InsightsRetriever(top_k=top_k).get_completion,
SearchType.CHUNKS: ChunksRetriever(top_k=top_k).get_completion,
SearchType.RAG_COMPLETION: CompletionRetriever(
system_prompt_path=system_prompt_path, top_k=top_k
system_prompt_path=system_prompt_path,
top_k=top_k,
system_prompt=system_prompt,
only_context=only_context,
).get_completion,
SearchType.GRAPH_COMPLETION: GraphCompletionRetriever(
system_prompt_path=system_prompt_path,
@ -123,6 +133,8 @@ async def specific_search(
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
only_context=only_context,
).get_completion,
SearchType.GRAPH_COMPLETION_COT: GraphCompletionCotRetriever(
system_prompt_path=system_prompt_path,
@ -130,6 +142,8 @@ async def specific_search(
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
only_context=only_context,
).get_completion,
SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION: GraphCompletionContextExtensionRetriever(
system_prompt_path=system_prompt_path,
@ -137,6 +151,8 @@ async def specific_search(
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
only_context=only_context,
).get_completion,
SearchType.GRAPH_SUMMARY_COMPLETION: GraphSummaryCompletionRetriever(
system_prompt_path=system_prompt_path,
@ -144,6 +160,7 @@ async def specific_search(
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
).get_completion,
SearchType.CODE: CodeRetriever(top_k=top_k).get_completion,
SearchType.CYPHER: CypherSearchRetriever().get_completion,
@ -153,7 +170,7 @@ async def specific_search(
# If the query type is FEELING_LUCKY, select the search type intelligently
if query_type is SearchType.FEELING_LUCKY:
query_type = await select_search_type(query)
query_type = await select_search_type(query_text)
search_task = search_tasks.get(query_type)
@ -162,9 +179,7 @@ async def specific_search(
send_telemetry("cognee.search EXECUTION STARTED", user.id)
results = await search_task(
query=query, system_prompt=system_prompt, user_prompt=user_prompt, only_context=only_context
)
results = await search_task(query=query_text)
send_telemetry("cognee.search EXECUTION COMPLETED", user.id)
@ -172,14 +187,18 @@ async def specific_search(
async def authorized_search(
query_text: str,
query_type: SearchType,
user: User = None,
query_text: str,
user: User,
dataset_ids: Optional[list[UUID]] = None,
system_prompt_path: str = "answer_simple_question.txt",
system_prompt: Optional[str] = None,
top_k: int = 10,
save_interaction: bool = False,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
save_interaction: Optional[bool] = False,
last_k: Optional[int] = None,
only_context: bool = None,
) -> list:
"""
Verifies access for provided datasets or uses all datasets user has read access for and performs search per dataset.
@ -193,14 +212,18 @@ async def authorized_search(
# Searches all provided datasets and handles setting up of appropriate database context based on permissions
search_results = await specific_search_by_context(
search_datasets,
query_text,
query_type,
user,
system_prompt_path,
top_k,
save_interaction,
search_datasets=search_datasets,
query_type=query_type,
query_text=query_text,
user=user,
system_prompt_path=system_prompt_path,
system_prompt=system_prompt,
top_k=top_k,
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
last_k=last_k,
only_context=only_context,
)
await log_result(query.id, json.dumps(search_results, cls=JSONEncoder), user.id)
@ -210,13 +233,17 @@ async def authorized_search(
async def specific_search_by_context(
search_datasets: list[Dataset],
query_text: str,
query_type: SearchType,
query_text: str,
user: User,
system_prompt_path: str,
top_k: int,
save_interaction: bool = False,
system_prompt_path: str = "answer_simple_question.txt",
system_prompt: Optional[str] = None,
top_k: int = 10,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
save_interaction: Optional[bool] = False,
last_k: Optional[int] = None,
only_context: bool = None,
):
"""
Searches all provided datasets and handles setting up of appropriate database context based on permissions.
@ -224,18 +251,33 @@ async def specific_search_by_context(
"""
async def _search_by_context(
dataset, user, query_type, query_text, system_prompt_path, top_k, last_k
dataset: Dataset,
query_type: SearchType,
query_text: str,
user: User,
system_prompt_path: str = "answer_simple_question.txt",
system_prompt: Optional[str] = None,
top_k: int = 10,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
save_interaction: Optional[bool] = False,
last_k: Optional[int] = None,
only_context: bool = None,
):
# Set database configuration in async context for each dataset user has access for
await set_database_global_context_variables(dataset.id, dataset.owner_id)
search_results = await specific_search(
query_type,
query_text,
user,
query_type=query_type,
query_text=query_text,
user=user,
system_prompt_path=system_prompt_path,
system_prompt=system_prompt,
top_k=top_k,
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
last_k=last_k,
only_context=only_context,
)
return {
"search_result": search_results,
@ -248,7 +290,18 @@ async def specific_search_by_context(
for dataset in search_datasets:
tasks.append(
_search_by_context(
dataset, user, query_type, query_text, system_prompt_path, top_k, last_k
dataset=dataset,
query_type=query_type,
query_text=query_text,
user=user,
system_prompt_path=system_prompt_path,
system_prompt=system_prompt,
top_k=top_k,
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
last_k=last_k,
only_context=only_context,
)
)