Merge branch 'dev' into configurable-path-exclusion-code-graph

This commit is contained in:
Vasilije 2025-08-29 17:32:31 +02:00 committed by GitHub
commit 4ee807579b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 167 additions and 58 deletions

View file

@ -38,7 +38,7 @@ class CognifyPayloadDTO(InDTO):
dataset_ids: Optional[List[UUID]] = Field(default=None, examples=[[]])
run_in_background: Optional[bool] = Field(default=False)
custom_prompt: Optional[str] = Field(
default=None, description="Custom prompt for entity extraction and graph generation"
default="", description="Custom prompt for entity extraction and graph generation"
)

View file

@ -1,9 +1,11 @@
from uuid import UUID
import pathlib
from typing import Optional
from datetime import datetime
from pydantic import Field
from fastapi import Depends, APIRouter
from fastapi.responses import JSONResponse
from cognee.modules.search.types import SearchType
from cognee.api.DTO import InDTO, OutDTO
from cognee.modules.users.exceptions.exceptions import PermissionDeniedError
@ -20,8 +22,12 @@ class SearchPayloadDTO(InDTO):
datasets: Optional[list[str]] = Field(default=None)
dataset_ids: Optional[list[UUID]] = Field(default=None, examples=[[]])
query: str = Field(default="What is in the document?")
system_prompt: Optional[str] = Field(
default="Answer the question using the provided context. Be as brief as possible."
)
node_name: Optional[list[str]] = Field(default=None, example=[])
top_k: Optional[int] = Field(default=10)
only_context: bool = Field(default=False)
def get_search_router() -> APIRouter:
@ -80,8 +86,10 @@ def get_search_router() -> APIRouter:
- **datasets** (Optional[List[str]]): List of dataset names to search within
- **dataset_ids** (Optional[List[UUID]]): List of dataset UUIDs to search within
- **query** (str): The search query string
- **system_prompt** Optional[str]: System prompt to be used for Completion type searches in Cognee
- **node_name** Optional[list[str]]: Filter results to specific node_sets defined in the add pipeline (for targeted search).
- **top_k** (Optional[int]): Maximum number of results to return (default: 10)
- **only_context** bool: Set to true to only return context Cognee will be sending to LLM in Completion type searches. This will be returned instead of LLM calls for completion type searches.
## Response
Returns a list of search results containing relevant nodes from the graph.
@ -104,8 +112,10 @@ def get_search_router() -> APIRouter:
"datasets": payload.datasets,
"dataset_ids": [str(dataset_id) for dataset_id in payload.dataset_ids or []],
"query": payload.query,
"system_prompt": payload.system_prompt,
"node_name": payload.node_name,
"top_k": payload.top_k,
"only_context": payload.only_context,
},
)
@ -118,8 +128,10 @@ def get_search_router() -> APIRouter:
user=user,
datasets=payload.datasets,
dataset_ids=payload.dataset_ids,
system_prompt=payload.system_prompt,
node_name=payload.node_name,
top_k=payload.top_k,
only_context=payload.only_context,
)
return results

View file

@ -17,11 +17,13 @@ async def search(
datasets: Optional[Union[list[str], str]] = None,
dataset_ids: Optional[Union[list[UUID], UUID]] = None,
system_prompt_path: str = "answer_simple_question.txt",
system_prompt: Optional[str] = None,
top_k: int = 10,
node_type: Optional[Type] = NodeSet,
node_name: Optional[List[str]] = None,
save_interaction: bool = False,
last_k: Optional[int] = None,
only_context: bool = False,
) -> list:
"""
Search and query the knowledge graph for insights, information, and connections.
@ -184,11 +186,13 @@ async def search(
dataset_ids=dataset_ids if dataset_ids else datasets,
user=user,
system_prompt_path=system_prompt_path,
system_prompt=system_prompt,
top_k=top_k,
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
last_k=last_k,
only_context=only_context,
)
return filtered_search_results

View file

@ -23,12 +23,16 @@ class CompletionRetriever(BaseRetriever):
self,
user_prompt_path: str = "context_for_question.txt",
system_prompt_path: str = "answer_simple_question.txt",
system_prompt: str = None,
top_k: Optional[int] = 1,
only_context: bool = False,
):
"""Initialize retriever with optional custom prompt paths."""
self.user_prompt_path = user_prompt_path
self.system_prompt_path = system_prompt_path
self.top_k = top_k if top_k is not None else 1
self.system_prompt = system_prompt
self.only_context = only_context
async def get_context(self, query: str) -> str:
"""
@ -88,6 +92,11 @@ class CompletionRetriever(BaseRetriever):
context = await self.get_context(query)
completion = await generate_completion(
query, context, self.user_prompt_path, self.system_prompt_path
query=query,
context=context,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
only_context=self.only_context,
)
return [completion]

View file

@ -26,10 +26,12 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
self,
user_prompt_path: str = "graph_context_for_question.txt",
system_prompt_path: str = "answer_simple_question.txt",
system_prompt: Optional[str] = None,
top_k: Optional[int] = 5,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
save_interaction: bool = False,
only_context: bool = False,
):
super().__init__(
user_prompt_path=user_prompt_path,
@ -38,10 +40,15 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
only_context=only_context,
)
async def get_completion(
self, query: str, context: Optional[Any] = None, context_extension_rounds=4
self,
query: str,
context: Optional[Any] = None,
context_extension_rounds=4,
) -> List[str]:
"""
Extends the context for a given query by retrieving related triplets and generating new
@ -86,6 +93,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
context=context,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
)
triplets += await self.get_triplets(completion)
@ -112,6 +120,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
context=context,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
only_context=self.only_context,
)
if self.save_interaction and context and triplets and completion:
@ -119,4 +129,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
question=query, answer=completion, context=context, triplets=triplets
)
return [completion]
if self.only_context:
return [context]
else:
return [completion]

View file

@ -32,14 +32,18 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
validation_system_prompt_path: str = "cot_validation_system_prompt.txt",
followup_system_prompt_path: str = "cot_followup_system_prompt.txt",
followup_user_prompt_path: str = "cot_followup_user_prompt.txt",
system_prompt: str = None,
top_k: Optional[int] = 5,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
save_interaction: bool = False,
only_context: bool = False,
):
super().__init__(
user_prompt_path=user_prompt_path,
system_prompt_path=system_prompt_path,
system_prompt=system_prompt,
only_context=only_context,
top_k=top_k,
node_type=node_type,
node_name=node_name,
@ -51,7 +55,10 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
self.followup_user_prompt_path = followup_user_prompt_path
async def get_completion(
self, query: str, context: Optional[Any] = None, max_iter=4
self,
query: str,
context: Optional[Any] = None,
max_iter=4,
) -> List[str]:
"""
Generate completion responses based on a user query and contextual information.
@ -92,6 +99,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
context=context,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
)
logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}")
if round_idx < max_iter:
@ -128,4 +136,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
question=query, answer=completion, context=context, triplets=triplets
)
return [completion]
if self.only_context:
return [context]
else:
return [completion]

View file

@ -36,15 +36,19 @@ class GraphCompletionRetriever(BaseRetriever):
self,
user_prompt_path: str = "graph_context_for_question.txt",
system_prompt_path: str = "answer_simple_question.txt",
system_prompt: str = None,
top_k: Optional[int] = 5,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
save_interaction: bool = False,
only_context: bool = False,
):
"""Initialize retriever with prompt paths and search parameters."""
self.save_interaction = save_interaction
self.user_prompt_path = user_prompt_path
self.system_prompt_path = system_prompt_path
self.system_prompt = system_prompt
self.only_context = only_context
self.top_k = top_k if top_k is not None else 5
self.node_type = node_type
self.node_name = node_name
@ -151,7 +155,11 @@ class GraphCompletionRetriever(BaseRetriever):
return context, triplets
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
async def get_completion(
self,
query: str,
context: Optional[Any] = None,
) -> Any:
"""
Generates a completion using graph connections context based on a query.
@ -177,6 +185,8 @@ class GraphCompletionRetriever(BaseRetriever):
context=context,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
only_context=self.only_context,
)
if self.save_interaction and context and triplets and completion:

View file

@ -21,6 +21,7 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
user_prompt_path: str = "graph_context_for_question.txt",
system_prompt_path: str = "answer_simple_question.txt",
summarize_prompt_path: str = "summarize_search_results.txt",
system_prompt: Optional[str] = None,
top_k: Optional[int] = 5,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
@ -34,6 +35,7 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
)
self.summarize_prompt_path = summarize_prompt_path
@ -57,4 +59,4 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
- str: A summary string representing the content of the retrieved edges.
"""
direct_text = await super().resolve_edges_to_text(retrieved_edges)
return await summarize_text(direct_text, self.summarize_prompt_path)
return await summarize_text(direct_text, self.summarize_prompt_path, self.system_prompt)

View file

@ -62,7 +62,7 @@ class SummariesRetriever(BaseRetriever):
logger.info(f"Returning {len(summary_payloads)} summary payloads")
return summary_payloads
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
async def get_completion(self, query: str, context: Optional[Any] = None, **kwargs) -> Any:
"""
Generates a completion using summaries context.

View file

@ -1,3 +1,4 @@
from typing import Optional
from cognee.infrastructure.llm.LLMGateway import LLMGateway
@ -6,25 +7,35 @@ async def generate_completion(
context: str,
user_prompt_path: str,
system_prompt_path: str,
system_prompt: Optional[str] = None,
only_context: bool = False,
) -> str:
"""Generates a completion using LLM with given context and prompts."""
args = {"question": query, "context": context}
user_prompt = LLMGateway.render_prompt(user_prompt_path, args)
system_prompt = LLMGateway.read_query_prompt(system_prompt_path)
return await LLMGateway.acreate_structured_output(
text_input=user_prompt,
system_prompt=system_prompt,
response_model=str,
system_prompt = (
system_prompt if system_prompt else LLMGateway.read_query_prompt(system_prompt_path)
)
if only_context:
return context
else:
return await LLMGateway.acreate_structured_output(
text_input=user_prompt,
system_prompt=system_prompt,
response_model=str,
)
async def summarize_text(
text: str,
prompt_path: str = "summarize_search_results.txt",
system_prompt_path: str = "summarize_search_results.txt",
system_prompt: str = None,
) -> str:
"""Summarizes text using LLM with the specified prompt."""
system_prompt = LLMGateway.read_query_prompt(prompt_path)
system_prompt = (
system_prompt if system_prompt else LLMGateway.read_query_prompt(system_prompt_path)
)
return await LLMGateway.acreate_structured_output(
text_input=text,

View file

@ -38,11 +38,13 @@ async def search(
dataset_ids: Union[list[UUID], None],
user: User,
system_prompt_path="answer_simple_question.txt",
system_prompt: Optional[str] = None,
top_k: int = 10,
node_type: Optional[Type] = NodeSet,
node_name: Optional[List[str]] = None,
save_interaction: Optional[bool] = False,
last_k: Optional[int] = None,
only_context: bool = False,
):
"""
@ -62,30 +64,34 @@ async def search(
# Use search function filtered by permissions if access control is enabled
if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
return await authorized_search(
query_text=query_text,
query_type=query_type,
query_text=query_text,
user=user,
dataset_ids=dataset_ids,
system_prompt_path=system_prompt_path,
system_prompt=system_prompt,
top_k=top_k,
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
last_k=last_k,
only_context=only_context,
)
query = await log_query(query_text, query_type.value, user.id)
search_results = await specific_search(
query_type,
query_text,
user,
query_type=query_type,
query_text=query_text,
user=user,
system_prompt_path=system_prompt_path,
system_prompt=system_prompt,
top_k=top_k,
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
last_k=last_k,
only_context=only_context,
)
await log_result(
@ -101,21 +107,26 @@ async def search(
async def specific_search(
query_type: SearchType,
query: str,
query_text: str,
user: User,
system_prompt_path="answer_simple_question.txt",
system_prompt_path: str = "answer_simple_question.txt",
system_prompt: Optional[str] = None,
top_k: int = 10,
node_type: Optional[Type] = NodeSet,
node_name: Optional[List[str]] = None,
save_interaction: Optional[bool] = False,
last_k: Optional[int] = None,
only_context: bool = None,
) -> list:
search_tasks: dict[SearchType, Callable] = {
SearchType.SUMMARIES: SummariesRetriever(top_k=top_k).get_completion,
SearchType.INSIGHTS: InsightsRetriever(top_k=top_k).get_completion,
SearchType.CHUNKS: ChunksRetriever(top_k=top_k).get_completion,
SearchType.RAG_COMPLETION: CompletionRetriever(
system_prompt_path=system_prompt_path, top_k=top_k
system_prompt_path=system_prompt_path,
top_k=top_k,
system_prompt=system_prompt,
only_context=only_context,
).get_completion,
SearchType.GRAPH_COMPLETION: GraphCompletionRetriever(
system_prompt_path=system_prompt_path,
@ -123,6 +134,8 @@ async def specific_search(
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
only_context=only_context,
).get_completion,
SearchType.GRAPH_COMPLETION_COT: GraphCompletionCotRetriever(
system_prompt_path=system_prompt_path,
@ -130,6 +143,8 @@ async def specific_search(
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
only_context=only_context,
).get_completion,
SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION: GraphCompletionContextExtensionRetriever(
system_prompt_path=system_prompt_path,
@ -137,6 +152,8 @@ async def specific_search(
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
only_context=only_context,
).get_completion,
SearchType.GRAPH_SUMMARY_COMPLETION: GraphSummaryCompletionRetriever(
system_prompt_path=system_prompt_path,
@ -144,6 +161,7 @@ async def specific_search(
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
).get_completion,
SearchType.CODE: CodeRetriever(top_k=top_k).get_completion,
SearchType.CYPHER: CypherSearchRetriever().get_completion,
@ -153,7 +171,7 @@ async def specific_search(
# If the query type is FEELING_LUCKY, select the search type intelligently
if query_type is SearchType.FEELING_LUCKY:
query_type = await select_search_type(query)
query_type = await select_search_type(query_text)
search_task = search_tasks.get(query_type)
@ -162,7 +180,7 @@ async def specific_search(
send_telemetry("cognee.search EXECUTION STARTED", user.id)
results = await search_task(query)
results = await search_task(query_text)
send_telemetry("cognee.search EXECUTION COMPLETED", user.id)
@ -170,16 +188,18 @@ async def specific_search(
async def authorized_search(
query_text: str,
query_type: SearchType,
user: User = None,
query_text: str,
user: User,
dataset_ids: Optional[list[UUID]] = None,
system_prompt_path: str = "answer_simple_question.txt",
system_prompt: Optional[str] = None,
top_k: int = 10,
node_type: Optional[Type] = NodeSet,
node_name: Optional[List[str]] = None,
save_interaction: bool = False,
save_interaction: Optional[bool] = False,
last_k: Optional[int] = None,
only_context: bool = None,
) -> list:
"""
Verifies access for provided datasets or uses all datasets user has read access for and performs search per dataset.
@ -193,16 +213,18 @@ async def authorized_search(
# Searches all provided datasets and handles setting up of appropriate database context based on permissions
search_results = await specific_search_by_context(
search_datasets,
query_text,
query_type,
user,
system_prompt_path,
top_k,
search_datasets=search_datasets,
query_type=query_type,
query_text=query_text,
user=user,
system_prompt_path=system_prompt_path,
system_prompt=system_prompt,
top_k=top_k,
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
last_k=last_k,
only_context=only_context,
)
await log_result(query.id, json.dumps(search_results, cls=JSONEncoder), user.id)
@ -212,15 +234,17 @@ async def authorized_search(
async def specific_search_by_context(
search_datasets: list[Dataset],
query_text: str,
query_type: SearchType,
query_text: str,
user: User,
system_prompt_path: str,
top_k: int,
system_prompt_path: str = "answer_simple_question.txt",
system_prompt: Optional[str] = None,
top_k: int = 10,
node_type: Optional[Type] = NodeSet,
node_name: Optional[List[str]] = None,
save_interaction: bool = False,
save_interaction: Optional[bool] = False,
last_k: Optional[int] = None,
only_context: bool = None,
):
"""
Searches all provided datasets and handles setting up of appropriate database context based on permissions.
@ -228,28 +252,33 @@ async def specific_search_by_context(
"""
async def _search_by_context(
dataset,
user,
query_type,
query_text,
system_prompt_path,
top_k,
dataset: Dataset,
query_type: SearchType,
query_text: str,
user: User,
system_prompt_path: str = "answer_simple_question.txt",
system_prompt: Optional[str] = None,
top_k: int = 10,
node_type: Optional[Type] = NodeSet,
node_name: Optional[List[str]] = None,
save_interaction: Optional[bool] = False,
last_k: Optional[int] = None,
only_context: bool = None,
):
# Set database configuration in async context for each dataset user has access for
await set_database_global_context_variables(dataset.id, dataset.owner_id)
search_results = await specific_search(
query_type,
query_text,
user,
query_type=query_type,
query_text=query_text,
user=user,
system_prompt_path=system_prompt_path,
system_prompt=system_prompt,
top_k=top_k,
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
last_k=last_k,
only_context=only_context,
)
return {
"search_result": search_results,
@ -262,15 +291,18 @@ async def specific_search_by_context(
for dataset in search_datasets:
tasks.append(
_search_by_context(
dataset,
user,
query_type,
query_text,
system_prompt_path,
top_k,
dataset=dataset,
query_type=query_type,
query_text=query_text,
user=user,
system_prompt_path=system_prompt_path,
system_prompt=system_prompt,
top_k=top_k,
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
last_k=last_k,
only_context=only_context,
)
)

View file

@ -58,15 +58,17 @@ async def test_search(
# Verify
mock_log_query.assert_called_once_with(query_text, query_type.value, mock_user.id)
mock_specific_search.assert_called_once_with(
query_type,
query_text,
mock_user,
query_type=query_type,
query_text=query_text,
user=mock_user,
system_prompt_path="answer_simple_question.txt",
system_prompt=None,
top_k=10,
node_type=NodeSet,
node_name=None,
save_interaction=False,
last_k=None,
only_context=False,
)
# Verify result logging
@ -201,7 +203,10 @@ async def test_specific_search_feeling_lucky(
if retriever_name == "CompletionRetriever":
mock_retriever_class.assert_called_once_with(
system_prompt_path="answer_simple_question.txt", top_k=top_k
system_prompt_path="answer_simple_question.txt",
top_k=top_k,
system_prompt=None,
only_context=None,
)
else:
mock_retriever_class.assert_called_once_with(top_k=top_k)