Merge branch 'dev' into feature/cog-2746-time-graph-to-cognify

This commit is contained in:
hajdul88 2025-08-29 18:21:45 +02:00 committed by GitHub
commit 4e9c0810c2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 274 additions and 115 deletions

View file

@ -25,6 +25,7 @@ def get_add_router() -> APIRouter:
data: List[UploadFile] = File(default=None), data: List[UploadFile] = File(default=None),
datasetName: Optional[str] = Form(default=None), datasetName: Optional[str] = Form(default=None),
datasetId: Union[UUID, Literal[""], None] = Form(default=None, examples=[""]), datasetId: Union[UUID, Literal[""], None] = Form(default=None, examples=[""]),
node_set: Optional[List[str]] = Form(default=[""], example=[""]),
user: User = Depends(get_authenticated_user), user: User = Depends(get_authenticated_user),
): ):
""" """
@ -41,6 +42,8 @@ def get_add_router() -> APIRouter:
- Regular file uploads - Regular file uploads
- **datasetName** (Optional[str]): Name of the dataset to add data to - **datasetName** (Optional[str]): Name of the dataset to add data to
- **datasetId** (Optional[UUID]): UUID of an already existing dataset - **datasetId** (Optional[UUID]): UUID of an already existing dataset
- **node_set** (Optional[List[str]]): List of node identifiers for graph organization and access control.
Used for grouping related data points in the knowledge graph.
Either datasetName or datasetId must be provided. Either datasetName or datasetId must be provided.
@ -65,9 +68,7 @@ def get_add_router() -> APIRouter:
send_telemetry( send_telemetry(
"Add API Endpoint Invoked", "Add API Endpoint Invoked",
user.id, user.id,
additional_properties={ additional_properties={"endpoint": "POST /v1/add", "node_set": node_set},
"endpoint": "POST /v1/add",
},
) )
from cognee.api.v1.add import add as cognee_add from cognee.api.v1.add import add as cognee_add
@ -76,34 +77,13 @@ def get_add_router() -> APIRouter:
raise ValueError("Either datasetId or datasetName must be provided.") raise ValueError("Either datasetId or datasetName must be provided.")
try: try:
if ( add_run = await cognee_add(
isinstance(data, str) data, datasetName, user=user, dataset_id=datasetId, node_set=node_set
and data.startswith("http") )
and (os.getenv("ALLOW_HTTP_REQUESTS", "true").lower() == "true")
):
if "github" in data:
# Perform git clone if the URL is from GitHub
repo_name = data.split("/")[-1].replace(".git", "")
subprocess.run(["git", "clone", data, f".data/{repo_name}"], check=True)
# TODO: Update add call with dataset info
await cognee_add(
"data://.data/",
f"{repo_name}",
)
else:
# Fetch and store the data from other types of URL using curl
response = requests.get(data)
response.raise_for_status()
file_data = await response.content() if isinstance(add_run, PipelineRunErrored):
# TODO: Update add call with dataset info return JSONResponse(status_code=420, content=add_run.model_dump(mode="json"))
return await cognee_add(file_data) return add_run.model_dump()
else:
add_run = await cognee_add(data, datasetName, user=user, dataset_id=datasetId)
if isinstance(add_run, PipelineRunErrored):
return JSONResponse(status_code=420, content=add_run.model_dump(mode="json"))
return add_run.model_dump()
except Exception as error: except Exception as error:
return JSONResponse(status_code=409, content={"error": str(error)}) return JSONResponse(status_code=409, content={"error": str(error)})

View file

@ -1,6 +1,7 @@
import os import os
import pathlib import pathlib
import asyncio import asyncio
from typing import Optional
from cognee.shared.logging_utils import get_logger, setup_logging from cognee.shared.logging_utils import get_logger, setup_logging
from cognee.modules.observability.get_observe import get_observe from cognee.modules.observability.get_observe import get_observe
@ -28,7 +29,12 @@ logger = get_logger("code_graph_pipeline")
@observe @observe
async def run_code_graph_pipeline(repo_path, include_docs=False): async def run_code_graph_pipeline(
repo_path,
include_docs=False,
excluded_paths: Optional[list[str]] = None,
supported_languages: Optional[list[str]] = None,
):
import cognee import cognee
from cognee.low_level import setup from cognee.low_level import setup
@ -40,13 +46,12 @@ async def run_code_graph_pipeline(repo_path, include_docs=False):
user = await get_default_user() user = await get_default_user()
detailed_extraction = True detailed_extraction = True
# Multi-language support: allow passing supported_languages
supported_languages = None # defer to task defaults
tasks = [ tasks = [
Task( Task(
get_repo_file_dependencies, get_repo_file_dependencies,
detailed_extraction=detailed_extraction, detailed_extraction=detailed_extraction,
supported_languages=supported_languages, supported_languages=supported_languages,
excluded_paths=excluded_paths,
), ),
# Task(summarize_code, task_config={"batch_size": 500}), # This task takes a long time to complete # Task(summarize_code, task_config={"batch_size": 500}), # This task takes a long time to complete
Task(add_data_points, task_config={"batch_size": 30}), Task(add_data_points, task_config={"batch_size": 30}),
@ -95,7 +100,7 @@ if __name__ == "__main__":
async def main(): async def main():
async for run_status in run_code_graph_pipeline("REPO_PATH"): async for run_status in run_code_graph_pipeline("REPO_PATH"):
print(f"{run_status.pipeline_name}: {run_status.status}") print(f"{run_status.pipeline_run_id}: {run_status.status}")
file_path = os.path.join( file_path = os.path.join(
pathlib.Path(__file__).parent, ".artifacts", "graph_visualization.html" pathlib.Path(__file__).parent, ".artifacts", "graph_visualization.html"

View file

@ -38,7 +38,7 @@ class CognifyPayloadDTO(InDTO):
dataset_ids: Optional[List[UUID]] = Field(default=None, examples=[[]]) dataset_ids: Optional[List[UUID]] = Field(default=None, examples=[[]])
run_in_background: Optional[bool] = Field(default=False) run_in_background: Optional[bool] = Field(default=False)
custom_prompt: Optional[str] = Field( custom_prompt: Optional[str] = Field(
default=None, description="Custom prompt for entity extraction and graph generation" default="", description="Custom prompt for entity extraction and graph generation"
) )

View file

@ -1,9 +1,11 @@
from uuid import UUID from uuid import UUID
import pathlib
from typing import Optional from typing import Optional
from datetime import datetime from datetime import datetime
from pydantic import Field from pydantic import Field
from fastapi import Depends, APIRouter from fastapi import Depends, APIRouter
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from cognee.modules.search.types import SearchType from cognee.modules.search.types import SearchType
from cognee.api.DTO import InDTO, OutDTO from cognee.api.DTO import InDTO, OutDTO
from cognee.modules.users.exceptions.exceptions import PermissionDeniedError from cognee.modules.users.exceptions.exceptions import PermissionDeniedError
@ -20,7 +22,12 @@ class SearchPayloadDTO(InDTO):
datasets: Optional[list[str]] = Field(default=None) datasets: Optional[list[str]] = Field(default=None)
dataset_ids: Optional[list[UUID]] = Field(default=None, examples=[[]]) dataset_ids: Optional[list[UUID]] = Field(default=None, examples=[[]])
query: str = Field(default="What is in the document?") query: str = Field(default="What is in the document?")
system_prompt: Optional[str] = Field(
default="Answer the question using the provided context. Be as brief as possible."
)
node_name: Optional[list[str]] = Field(default=None, example=[])
top_k: Optional[int] = Field(default=10) top_k: Optional[int] = Field(default=10)
only_context: bool = Field(default=False)
def get_search_router() -> APIRouter: def get_search_router() -> APIRouter:
@ -79,7 +86,10 @@ def get_search_router() -> APIRouter:
- **datasets** (Optional[List[str]]): List of dataset names to search within - **datasets** (Optional[List[str]]): List of dataset names to search within
- **dataset_ids** (Optional[List[UUID]]): List of dataset UUIDs to search within - **dataset_ids** (Optional[List[UUID]]): List of dataset UUIDs to search within
- **query** (str): The search query string - **query** (str): The search query string
- **system_prompt** (Optional[str]): System prompt to be used for Completion type searches in Cognee
- **node_name** (Optional[List[str]]): Filter results to specific node_sets defined in the add pipeline (for targeted search).
- **top_k** (Optional[int]): Maximum number of results to return (default: 10) - **top_k** (Optional[int]): Maximum number of results to return (default: 10)
- **only_context** (bool): Set to true to return only the context Cognee would send to the LLM in Completion type searches, instead of making the LLM call.
## Response ## Response
Returns a list of search results containing relevant nodes from the graph. Returns a list of search results containing relevant nodes from the graph.
@ -102,7 +112,10 @@ def get_search_router() -> APIRouter:
"datasets": payload.datasets, "datasets": payload.datasets,
"dataset_ids": [str(dataset_id) for dataset_id in payload.dataset_ids or []], "dataset_ids": [str(dataset_id) for dataset_id in payload.dataset_ids or []],
"query": payload.query, "query": payload.query,
"system_prompt": payload.system_prompt,
"node_name": payload.node_name,
"top_k": payload.top_k, "top_k": payload.top_k,
"only_context": payload.only_context,
}, },
) )
@ -115,7 +128,10 @@ def get_search_router() -> APIRouter:
user=user, user=user,
datasets=payload.datasets, datasets=payload.datasets,
dataset_ids=payload.dataset_ids, dataset_ids=payload.dataset_ids,
system_prompt=payload.system_prompt,
node_name=payload.node_name,
top_k=payload.top_k, top_k=payload.top_k,
only_context=payload.only_context,
) )
return results return results

View file

@ -1,6 +1,7 @@
from uuid import UUID from uuid import UUID
from typing import Union, Optional, List, Type from typing import Union, Optional, List, Type
from cognee.modules.engine.models.node_set import NodeSet
from cognee.modules.users.models import User from cognee.modules.users.models import User
from cognee.modules.search.types import SearchType from cognee.modules.search.types import SearchType
from cognee.modules.users.methods import get_default_user from cognee.modules.users.methods import get_default_user
@ -16,11 +17,13 @@ async def search(
datasets: Optional[Union[list[str], str]] = None, datasets: Optional[Union[list[str], str]] = None,
dataset_ids: Optional[Union[list[UUID], UUID]] = None, dataset_ids: Optional[Union[list[UUID], UUID]] = None,
system_prompt_path: str = "answer_simple_question.txt", system_prompt_path: str = "answer_simple_question.txt",
system_prompt: Optional[str] = None,
top_k: int = 10, top_k: int = 10,
node_type: Optional[Type] = None, node_type: Optional[Type] = NodeSet,
node_name: Optional[List[str]] = None, node_name: Optional[List[str]] = None,
save_interaction: bool = False, save_interaction: bool = False,
last_k: Optional[int] = None, last_k: Optional[int] = None,
only_context: bool = False,
) -> list: ) -> list:
""" """
Search and query the knowledge graph for insights, information, and connections. Search and query the knowledge graph for insights, information, and connections.
@ -183,11 +186,13 @@ async def search(
dataset_ids=dataset_ids if dataset_ids else datasets, dataset_ids=dataset_ids if dataset_ids else datasets,
user=user, user=user,
system_prompt_path=system_prompt_path, system_prompt_path=system_prompt_path,
system_prompt=system_prompt,
top_k=top_k, top_k=top_k,
node_type=node_type, node_type=node_type,
node_name=node_name, node_name=node_name,
save_interaction=save_interaction, save_interaction=save_interaction,
last_k=last_k, last_k=last_k,
only_context=only_context,
) )
return filtered_search_results return filtered_search_results

View file

@ -76,7 +76,7 @@ class CogneeGraph(CogneeAbstractGraph):
start_time = time.time() start_time = time.time()
# Determine projection strategy # Determine projection strategy
if node_type is not None and node_name is not None: if node_type is not None and node_name not in [None, []]:
nodes_data, edges_data = await adapter.get_nodeset_subgraph( nodes_data, edges_data = await adapter.get_nodeset_subgraph(
node_type=node_type, node_name=node_name node_type=node_type, node_name=node_name
) )

View file

@ -94,7 +94,15 @@ class CodeRetriever(BaseRetriever):
{"id": res.id, "score": res.score, "payload": res.payload} {"id": res.id, "score": res.score, "payload": res.payload}
) )
existing_collection = []
for collection in self.classes_and_functions_collections: for collection in self.classes_and_functions_collections:
if await vector_engine.has_collection(collection):
existing_collection.append(collection)
if not existing_collection:
raise RuntimeError("No collection found for code retriever")
for collection in existing_collection:
logger.debug(f"Searching {collection} collection with general query") logger.debug(f"Searching {collection} collection with general query")
search_results_code = await vector_engine.search( search_results_code = await vector_engine.search(
collection, query, limit=self.top_k collection, query, limit=self.top_k

View file

@ -23,12 +23,16 @@ class CompletionRetriever(BaseRetriever):
self, self,
user_prompt_path: str = "context_for_question.txt", user_prompt_path: str = "context_for_question.txt",
system_prompt_path: str = "answer_simple_question.txt", system_prompt_path: str = "answer_simple_question.txt",
system_prompt: str = None,
top_k: Optional[int] = 1, top_k: Optional[int] = 1,
only_context: bool = False,
): ):
"""Initialize retriever with optional custom prompt paths.""" """Initialize retriever with optional custom prompt paths."""
self.user_prompt_path = user_prompt_path self.user_prompt_path = user_prompt_path
self.system_prompt_path = system_prompt_path self.system_prompt_path = system_prompt_path
self.top_k = top_k if top_k is not None else 1 self.top_k = top_k if top_k is not None else 1
self.system_prompt = system_prompt
self.only_context = only_context
async def get_context(self, query: str) -> str: async def get_context(self, query: str) -> str:
""" """
@ -88,6 +92,11 @@ class CompletionRetriever(BaseRetriever):
context = await self.get_context(query) context = await self.get_context(query)
completion = await generate_completion( completion = await generate_completion(
query, context, self.user_prompt_path, self.system_prompt_path query=query,
context=context,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
only_context=self.only_context,
) )
return [completion] return [completion]

View file

@ -26,10 +26,12 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
self, self,
user_prompt_path: str = "graph_context_for_question.txt", user_prompt_path: str = "graph_context_for_question.txt",
system_prompt_path: str = "answer_simple_question.txt", system_prompt_path: str = "answer_simple_question.txt",
system_prompt: Optional[str] = None,
top_k: Optional[int] = 5, top_k: Optional[int] = 5,
node_type: Optional[Type] = None, node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None, node_name: Optional[List[str]] = None,
save_interaction: bool = False, save_interaction: bool = False,
only_context: bool = False,
): ):
super().__init__( super().__init__(
user_prompt_path=user_prompt_path, user_prompt_path=user_prompt_path,
@ -38,10 +40,15 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
node_type=node_type, node_type=node_type,
node_name=node_name, node_name=node_name,
save_interaction=save_interaction, save_interaction=save_interaction,
system_prompt=system_prompt,
only_context=only_context,
) )
async def get_completion( async def get_completion(
self, query: str, context: Optional[Any] = None, context_extension_rounds=4 self,
query: str,
context: Optional[Any] = None,
context_extension_rounds=4,
) -> List[str]: ) -> List[str]:
""" """
Extends the context for a given query by retrieving related triplets and generating new Extends the context for a given query by retrieving related triplets and generating new
@ -86,6 +93,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
context=context, context=context,
user_prompt_path=self.user_prompt_path, user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path, system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
) )
triplets += await self.get_triplets(completion) triplets += await self.get_triplets(completion)
@ -112,6 +120,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
context=context, context=context,
user_prompt_path=self.user_prompt_path, user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path, system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
only_context=self.only_context,
) )
if self.save_interaction and context and triplets and completion: if self.save_interaction and context and triplets and completion:
@ -119,4 +129,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
question=query, answer=completion, context=context, triplets=triplets question=query, answer=completion, context=context, triplets=triplets
) )
return [completion] if self.only_context:
return [context]
else:
return [completion]

View file

@ -32,14 +32,18 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
validation_system_prompt_path: str = "cot_validation_system_prompt.txt", validation_system_prompt_path: str = "cot_validation_system_prompt.txt",
followup_system_prompt_path: str = "cot_followup_system_prompt.txt", followup_system_prompt_path: str = "cot_followup_system_prompt.txt",
followup_user_prompt_path: str = "cot_followup_user_prompt.txt", followup_user_prompt_path: str = "cot_followup_user_prompt.txt",
system_prompt: str = None,
top_k: Optional[int] = 5, top_k: Optional[int] = 5,
node_type: Optional[Type] = None, node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None, node_name: Optional[List[str]] = None,
save_interaction: bool = False, save_interaction: bool = False,
only_context: bool = False,
): ):
super().__init__( super().__init__(
user_prompt_path=user_prompt_path, user_prompt_path=user_prompt_path,
system_prompt_path=system_prompt_path, system_prompt_path=system_prompt_path,
system_prompt=system_prompt,
only_context=only_context,
top_k=top_k, top_k=top_k,
node_type=node_type, node_type=node_type,
node_name=node_name, node_name=node_name,
@ -51,7 +55,10 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
self.followup_user_prompt_path = followup_user_prompt_path self.followup_user_prompt_path = followup_user_prompt_path
async def get_completion( async def get_completion(
self, query: str, context: Optional[Any] = None, max_iter=4 self,
query: str,
context: Optional[Any] = None,
max_iter=4,
) -> List[str]: ) -> List[str]:
""" """
Generate completion responses based on a user query and contextual information. Generate completion responses based on a user query and contextual information.
@ -92,6 +99,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
context=context, context=context,
user_prompt_path=self.user_prompt_path, user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path, system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
) )
logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}") logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}")
if round_idx < max_iter: if round_idx < max_iter:
@ -128,4 +136,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
question=query, answer=completion, context=context, triplets=triplets question=query, answer=completion, context=context, triplets=triplets
) )
return [completion] if self.only_context:
return [context]
else:
return [completion]

View file

@ -36,15 +36,19 @@ class GraphCompletionRetriever(BaseRetriever):
self, self,
user_prompt_path: str = "graph_context_for_question.txt", user_prompt_path: str = "graph_context_for_question.txt",
system_prompt_path: str = "answer_simple_question.txt", system_prompt_path: str = "answer_simple_question.txt",
system_prompt: str = None,
top_k: Optional[int] = 5, top_k: Optional[int] = 5,
node_type: Optional[Type] = None, node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None, node_name: Optional[List[str]] = None,
save_interaction: bool = False, save_interaction: bool = False,
only_context: bool = False,
): ):
"""Initialize retriever with prompt paths and search parameters.""" """Initialize retriever with prompt paths and search parameters."""
self.save_interaction = save_interaction self.save_interaction = save_interaction
self.user_prompt_path = user_prompt_path self.user_prompt_path = user_prompt_path
self.system_prompt_path = system_prompt_path self.system_prompt_path = system_prompt_path
self.system_prompt = system_prompt
self.only_context = only_context
self.top_k = top_k if top_k is not None else 5 self.top_k = top_k if top_k is not None else 5
self.node_type = node_type self.node_type = node_type
self.node_name = node_name self.node_name = node_name
@ -151,7 +155,11 @@ class GraphCompletionRetriever(BaseRetriever):
return context, triplets return context, triplets
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: async def get_completion(
self,
query: str,
context: Optional[Any] = None,
) -> Any:
""" """
Generates a completion using graph connections context based on a query. Generates a completion using graph connections context based on a query.
@ -177,6 +185,8 @@ class GraphCompletionRetriever(BaseRetriever):
context=context, context=context,
user_prompt_path=self.user_prompt_path, user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path, system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
only_context=self.only_context,
) )
if self.save_interaction and context and triplets and completion: if self.save_interaction and context and triplets and completion:

View file

@ -21,6 +21,7 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
user_prompt_path: str = "graph_context_for_question.txt", user_prompt_path: str = "graph_context_for_question.txt",
system_prompt_path: str = "answer_simple_question.txt", system_prompt_path: str = "answer_simple_question.txt",
summarize_prompt_path: str = "summarize_search_results.txt", summarize_prompt_path: str = "summarize_search_results.txt",
system_prompt: Optional[str] = None,
top_k: Optional[int] = 5, top_k: Optional[int] = 5,
node_type: Optional[Type] = None, node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None, node_name: Optional[List[str]] = None,
@ -34,6 +35,7 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
node_type=node_type, node_type=node_type,
node_name=node_name, node_name=node_name,
save_interaction=save_interaction, save_interaction=save_interaction,
system_prompt=system_prompt,
) )
self.summarize_prompt_path = summarize_prompt_path self.summarize_prompt_path = summarize_prompt_path
@ -57,4 +59,4 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
- str: A summary string representing the content of the retrieved edges. - str: A summary string representing the content of the retrieved edges.
""" """
direct_text = await super().resolve_edges_to_text(retrieved_edges) direct_text = await super().resolve_edges_to_text(retrieved_edges)
return await summarize_text(direct_text, self.summarize_prompt_path) return await summarize_text(direct_text, self.summarize_prompt_path, self.system_prompt)

View file

@ -62,7 +62,7 @@ class SummariesRetriever(BaseRetriever):
logger.info(f"Returning {len(summary_payloads)} summary payloads") logger.info(f"Returning {len(summary_payloads)} summary payloads")
return summary_payloads return summary_payloads
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: async def get_completion(self, query: str, context: Optional[Any] = None, **kwargs) -> Any:
""" """
Generates a completion using summaries context. Generates a completion using summaries context.

View file

@ -1,3 +1,4 @@
from typing import Optional
from cognee.infrastructure.llm.LLMGateway import LLMGateway from cognee.infrastructure.llm.LLMGateway import LLMGateway
@ -6,25 +7,35 @@ async def generate_completion(
context: str, context: str,
user_prompt_path: str, user_prompt_path: str,
system_prompt_path: str, system_prompt_path: str,
system_prompt: Optional[str] = None,
only_context: bool = False,
) -> str: ) -> str:
"""Generates a completion using LLM with given context and prompts.""" """Generates a completion using LLM with given context and prompts."""
args = {"question": query, "context": context} args = {"question": query, "context": context}
user_prompt = LLMGateway.render_prompt(user_prompt_path, args) user_prompt = LLMGateway.render_prompt(user_prompt_path, args)
system_prompt = LLMGateway.read_query_prompt(system_prompt_path) system_prompt = (
system_prompt if system_prompt else LLMGateway.read_query_prompt(system_prompt_path)
return await LLMGateway.acreate_structured_output(
text_input=user_prompt,
system_prompt=system_prompt,
response_model=str,
) )
if only_context:
return context
else:
return await LLMGateway.acreate_structured_output(
text_input=user_prompt,
system_prompt=system_prompt,
response_model=str,
)
async def summarize_text( async def summarize_text(
text: str, text: str,
prompt_path: str = "summarize_search_results.txt", system_prompt_path: str = "summarize_search_results.txt",
system_prompt: str = None,
) -> str: ) -> str:
"""Summarizes text using LLM with the specified prompt.""" """Summarizes text using LLM with the specified prompt."""
system_prompt = LLMGateway.read_query_prompt(prompt_path) system_prompt = (
system_prompt if system_prompt else LLMGateway.read_query_prompt(system_prompt_path)
)
return await LLMGateway.acreate_structured_output( return await LLMGateway.acreate_structured_output(
text_input=text, text_input=text,

View file

@ -4,6 +4,7 @@ import asyncio
from uuid import UUID from uuid import UUID
from typing import Callable, List, Optional, Type, Union from typing import Callable, List, Optional, Type, Union
from cognee.modules.engine.models.node_set import NodeSet
from cognee.modules.retrieval.user_qa_feedback import UserQAFeedback from cognee.modules.retrieval.user_qa_feedback import UserQAFeedback
from cognee.modules.search.exceptions import UnsupportedSearchTypeError from cognee.modules.search.exceptions import UnsupportedSearchTypeError
from cognee.context_global_variables import set_database_global_context_variables from cognee.context_global_variables import set_database_global_context_variables
@ -38,11 +39,13 @@ async def search(
dataset_ids: Union[list[UUID], None], dataset_ids: Union[list[UUID], None],
user: User, user: User,
system_prompt_path="answer_simple_question.txt", system_prompt_path="answer_simple_question.txt",
system_prompt: Optional[str] = None,
top_k: int = 10, top_k: int = 10,
node_type: Optional[Type] = None, node_type: Optional[Type] = NodeSet,
node_name: Optional[List[str]] = None, node_name: Optional[List[str]] = None,
save_interaction: Optional[bool] = False, save_interaction: Optional[bool] = False,
last_k: Optional[int] = None, last_k: Optional[int] = None,
only_context: bool = False,
): ):
""" """
@ -62,28 +65,34 @@ async def search(
# Use search function filtered by permissions if access control is enabled # Use search function filtered by permissions if access control is enabled
if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true": if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
return await authorized_search( return await authorized_search(
query_text=query_text,
query_type=query_type, query_type=query_type,
query_text=query_text,
user=user, user=user,
dataset_ids=dataset_ids, dataset_ids=dataset_ids,
system_prompt_path=system_prompt_path, system_prompt_path=system_prompt_path,
system_prompt=system_prompt,
top_k=top_k, top_k=top_k,
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction, save_interaction=save_interaction,
last_k=last_k, last_k=last_k,
only_context=only_context,
) )
query = await log_query(query_text, query_type.value, user.id) query = await log_query(query_text, query_type.value, user.id)
search_results = await specific_search( search_results = await specific_search(
query_type, query_type=query_type,
query_text, query_text=query_text,
user, user=user,
system_prompt_path=system_prompt_path, system_prompt_path=system_prompt_path,
system_prompt=system_prompt,
top_k=top_k, top_k=top_k,
node_type=node_type, node_type=node_type,
node_name=node_name, node_name=node_name,
save_interaction=save_interaction, save_interaction=save_interaction,
last_k=last_k, last_k=last_k,
only_context=only_context,
) )
await log_result( await log_result(
@ -99,21 +108,26 @@ async def search(
async def specific_search( async def specific_search(
query_type: SearchType, query_type: SearchType,
query: str, query_text: str,
user: User, user: User,
system_prompt_path="answer_simple_question.txt", system_prompt_path: str = "answer_simple_question.txt",
system_prompt: Optional[str] = None,
top_k: int = 10, top_k: int = 10,
node_type: Optional[Type] = None, node_type: Optional[Type] = NodeSet,
node_name: Optional[List[str]] = None, node_name: Optional[List[str]] = None,
save_interaction: Optional[bool] = False, save_interaction: Optional[bool] = False,
last_k: Optional[int] = None, last_k: Optional[int] = None,
only_context: bool = None,
) -> list: ) -> list:
search_tasks: dict[SearchType, Callable] = { search_tasks: dict[SearchType, Callable] = {
SearchType.SUMMARIES: SummariesRetriever(top_k=top_k).get_completion, SearchType.SUMMARIES: SummariesRetriever(top_k=top_k).get_completion,
SearchType.INSIGHTS: InsightsRetriever(top_k=top_k).get_completion, SearchType.INSIGHTS: InsightsRetriever(top_k=top_k).get_completion,
SearchType.CHUNKS: ChunksRetriever(top_k=top_k).get_completion, SearchType.CHUNKS: ChunksRetriever(top_k=top_k).get_completion,
SearchType.RAG_COMPLETION: CompletionRetriever( SearchType.RAG_COMPLETION: CompletionRetriever(
system_prompt_path=system_prompt_path, top_k=top_k system_prompt_path=system_prompt_path,
top_k=top_k,
system_prompt=system_prompt,
only_context=only_context,
).get_completion, ).get_completion,
SearchType.GRAPH_COMPLETION: GraphCompletionRetriever( SearchType.GRAPH_COMPLETION: GraphCompletionRetriever(
system_prompt_path=system_prompt_path, system_prompt_path=system_prompt_path,
@ -121,6 +135,8 @@ async def specific_search(
node_type=node_type, node_type=node_type,
node_name=node_name, node_name=node_name,
save_interaction=save_interaction, save_interaction=save_interaction,
system_prompt=system_prompt,
only_context=only_context,
).get_completion, ).get_completion,
SearchType.GRAPH_COMPLETION_COT: GraphCompletionCotRetriever( SearchType.GRAPH_COMPLETION_COT: GraphCompletionCotRetriever(
system_prompt_path=system_prompt_path, system_prompt_path=system_prompt_path,
@ -128,6 +144,8 @@ async def specific_search(
node_type=node_type, node_type=node_type,
node_name=node_name, node_name=node_name,
save_interaction=save_interaction, save_interaction=save_interaction,
system_prompt=system_prompt,
only_context=only_context,
).get_completion, ).get_completion,
SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION: GraphCompletionContextExtensionRetriever( SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION: GraphCompletionContextExtensionRetriever(
system_prompt_path=system_prompt_path, system_prompt_path=system_prompt_path,
@ -135,6 +153,8 @@ async def specific_search(
node_type=node_type, node_type=node_type,
node_name=node_name, node_name=node_name,
save_interaction=save_interaction, save_interaction=save_interaction,
system_prompt=system_prompt,
only_context=only_context,
).get_completion, ).get_completion,
SearchType.GRAPH_SUMMARY_COMPLETION: GraphSummaryCompletionRetriever( SearchType.GRAPH_SUMMARY_COMPLETION: GraphSummaryCompletionRetriever(
system_prompt_path=system_prompt_path, system_prompt_path=system_prompt_path,
@ -142,6 +162,7 @@ async def specific_search(
node_type=node_type, node_type=node_type,
node_name=node_name, node_name=node_name,
save_interaction=save_interaction, save_interaction=save_interaction,
system_prompt=system_prompt,
).get_completion, ).get_completion,
SearchType.CODE: CodeRetriever(top_k=top_k).get_completion, SearchType.CODE: CodeRetriever(top_k=top_k).get_completion,
SearchType.CYPHER: CypherSearchRetriever().get_completion, SearchType.CYPHER: CypherSearchRetriever().get_completion,
@ -152,7 +173,7 @@ async def specific_search(
# If the query type is FEELING_LUCKY, select the search type intelligently # If the query type is FEELING_LUCKY, select the search type intelligently
if query_type is SearchType.FEELING_LUCKY: if query_type is SearchType.FEELING_LUCKY:
query_type = await select_search_type(query) query_type = await select_search_type(query_text)
search_task = search_tasks.get(query_type) search_task = search_tasks.get(query_type)
@ -161,7 +182,7 @@ async def specific_search(
send_telemetry("cognee.search EXECUTION STARTED", user.id) send_telemetry("cognee.search EXECUTION STARTED", user.id)
results = await search_task(query) results = await search_task(query_text)
send_telemetry("cognee.search EXECUTION COMPLETED", user.id) send_telemetry("cognee.search EXECUTION COMPLETED", user.id)
@ -169,14 +190,18 @@ async def specific_search(
async def authorized_search( async def authorized_search(
query_text: str,
query_type: SearchType, query_type: SearchType,
user: User = None, query_text: str,
user: User,
dataset_ids: Optional[list[UUID]] = None, dataset_ids: Optional[list[UUID]] = None,
system_prompt_path: str = "answer_simple_question.txt", system_prompt_path: str = "answer_simple_question.txt",
system_prompt: Optional[str] = None,
top_k: int = 10, top_k: int = 10,
save_interaction: bool = False, node_type: Optional[Type] = NodeSet,
node_name: Optional[List[str]] = None,
save_interaction: Optional[bool] = False,
last_k: Optional[int] = None, last_k: Optional[int] = None,
only_context: bool = None,
) -> list: ) -> list:
""" """
Verifies access for provided datasets or uses all datasets user has read access for and performs search per dataset. Verifies access for provided datasets or uses all datasets user has read access for and performs search per dataset.
@ -190,14 +215,18 @@ async def authorized_search(
# Searches all provided datasets and handles setting up of appropriate database context based on permissions # Searches all provided datasets and handles setting up of appropriate database context based on permissions
search_results = await specific_search_by_context( search_results = await specific_search_by_context(
search_datasets, search_datasets=search_datasets,
query_text, query_type=query_type,
query_type, query_text=query_text,
user, user=user,
system_prompt_path, system_prompt_path=system_prompt_path,
top_k, system_prompt=system_prompt,
save_interaction, top_k=top_k,
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
last_k=last_k, last_k=last_k,
only_context=only_context,
) )
await log_result(query.id, json.dumps(search_results, cls=JSONEncoder), user.id) await log_result(query.id, json.dumps(search_results, cls=JSONEncoder), user.id)
@ -207,13 +236,17 @@ async def authorized_search(
async def specific_search_by_context( async def specific_search_by_context(
search_datasets: list[Dataset], search_datasets: list[Dataset],
query_text: str,
query_type: SearchType, query_type: SearchType,
query_text: str,
user: User, user: User,
system_prompt_path: str, system_prompt_path: str = "answer_simple_question.txt",
top_k: int, system_prompt: Optional[str] = None,
save_interaction: bool = False, top_k: int = 10,
node_type: Optional[Type] = NodeSet,
node_name: Optional[List[str]] = None,
save_interaction: Optional[bool] = False,
last_k: Optional[int] = None, last_k: Optional[int] = None,
only_context: bool = None,
): ):
""" """
Searches all provided datasets and handles setting up of appropriate database context based on permissions. Searches all provided datasets and handles setting up of appropriate database context based on permissions.
@ -221,18 +254,33 @@ async def specific_search_by_context(
""" """
async def _search_by_context( async def _search_by_context(
dataset, user, query_type, query_text, system_prompt_path, top_k, last_k dataset: Dataset,
query_type: SearchType,
query_text: str,
user: User,
system_prompt_path: str = "answer_simple_question.txt",
system_prompt: Optional[str] = None,
top_k: int = 10,
node_type: Optional[Type] = NodeSet,
node_name: Optional[List[str]] = None,
save_interaction: Optional[bool] = False,
last_k: Optional[int] = None,
only_context: bool = None,
): ):
# Set database configuration in async context for each dataset user has access for # Set database configuration in async context for each dataset user has access for
await set_database_global_context_variables(dataset.id, dataset.owner_id) await set_database_global_context_variables(dataset.id, dataset.owner_id)
search_results = await specific_search( search_results = await specific_search(
query_type, query_type=query_type,
query_text, query_text=query_text,
user, user=user,
system_prompt_path=system_prompt_path, system_prompt_path=system_prompt_path,
system_prompt=system_prompt,
top_k=top_k, top_k=top_k,
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction, save_interaction=save_interaction,
last_k=last_k, last_k=last_k,
only_context=only_context,
) )
return { return {
"search_result": search_results, "search_result": search_results,
@ -245,7 +293,18 @@ async def specific_search_by_context(
for dataset in search_datasets: for dataset in search_datasets:
tasks.append( tasks.append(
_search_by_context( _search_by_context(
dataset, user, query_type, query_text, system_prompt_path, top_k, last_k dataset=dataset,
query_type=query_type,
query_text=query_text,
user=user,
system_prompt_path=system_prompt_path,
system_prompt=system_prompt,
top_k=top_k,
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
last_k=last_k,
only_context=only_context,
) )
) )

View file

@ -1,24 +1,48 @@
import asyncio import asyncio
import math import math
import os import os
from pathlib import Path
# from concurrent.futures import ProcessPoolExecutor from typing import Set
from typing import AsyncGenerator from typing import AsyncGenerator, Optional, List
from uuid import NAMESPACE_OID, uuid5 from uuid import NAMESPACE_OID, uuid5
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
from cognee.shared.CodeGraphEntities import CodeFile, Repository from cognee.shared.CodeGraphEntities import CodeFile, Repository
# constant, declared only once
EXCLUDED_DIRS: Set[str] = {
".venv",
"venv",
"env",
".env",
"site-packages",
"node_modules",
"dist",
"build",
".git",
"tests",
"test",
}
async def get_source_code_files(repo_path, language_config: dict[str, list[str]] | None = None):
async def get_source_code_files(
repo_path,
language_config: dict[str, list[str]] | None = None,
excluded_paths: Optional[List[str]] = None,
):
""" """
Retrieve source code files from the specified repository path for multiple languages. Retrieve Python source code files from the specified repository path.
This function scans the given repository path for files that have the .py extension
while excluding test files and files within a virtual environment. It returns a list of
absolute paths to the source code files that are not empty.
Parameters: Parameters:
----------- -----------
- repo_path: The file path to the repository to search for source files. - repo_path: Root path of the repository to search
- language_config: dict mapping language names to file extensions, e.g., - language_config: dict mapping language names to file extensions, e.g.,
{'python': ['.py'], 'javascript': ['.js', '.jsx'], ...} {'python': ['.py'], 'javascript': ['.js', '.jsx'], ...}
- excluded_paths: Optional list of path fragments or glob patterns to exclude
Returns: Returns:
-------- --------
@ -54,28 +78,23 @@ async def get_source_code_files(repo_path, language_config: dict[str, list[str]]
lang = _get_language_from_extension(file, language_config) lang = _get_language_from_extension(file, language_config)
if lang is None: if lang is None:
continue continue
# Exclude tests and common build/venv directories # Exclude tests, common build/venv directories and files provided in exclude_paths
excluded_dirs = { excluded_dirs = EXCLUDED_DIRS
".venv", excluded_paths = {Path(p).resolve() for p in (excluded_paths or [])} # full paths
"venv",
"env", root_path = Path(root).resolve()
".env", root_parts = set(root_path.parts) # same as before
"site-packages",
"node_modules",
"dist",
"build",
".git",
"tests",
"test",
}
root_parts = set(os.path.normpath(root).split(os.sep))
base_name, _ext = os.path.splitext(file) base_name, _ext = os.path.splitext(file)
if ( if (
base_name.startswith("test_") base_name.startswith("test_")
or base_name.endswith("_test") # catches Go's *_test.go and similar or base_name.endswith("_test")
or ".test." in file or ".test." in file
or ".spec." in file or ".spec." in file
or (excluded_dirs & root_parts) or (excluded_dirs & root_parts) # name match
or any(
root_path.is_relative_to(p) # full-path match
for p in excluded_paths
)
): ):
continue continue
file_path = os.path.abspath(os.path.join(root, file)) file_path = os.path.abspath(os.path.join(root, file))
@ -115,7 +134,10 @@ def run_coroutine(coroutine_func, *args, **kwargs):
async def get_repo_file_dependencies( async def get_repo_file_dependencies(
repo_path: str, detailed_extraction: bool = False, supported_languages: list = None repo_path: str,
detailed_extraction: bool = False,
supported_languages: list = None,
excluded_paths: Optional[List[str]] = None,
) -> AsyncGenerator[DataPoint, None]: ) -> AsyncGenerator[DataPoint, None]:
""" """
Generate a dependency graph for source files (multi-language) in the given repository path. Generate a dependency graph for source files (multi-language) in the given repository path.
@ -150,6 +172,7 @@ async def get_repo_file_dependencies(
"go": [".go"], "go": [".go"],
"rust": [".rs"], "rust": [".rs"],
"cpp": [".cpp", ".c", ".h", ".hpp"], "cpp": [".cpp", ".c", ".h", ".hpp"],
"c": [".c", ".h"],
} }
if supported_languages is not None: if supported_languages is not None:
language_config = { language_config = {
@ -158,7 +181,9 @@ async def get_repo_file_dependencies(
else: else:
language_config = default_language_config language_config = default_language_config
source_code_files = await get_source_code_files(repo_path, language_config=language_config) source_code_files = await get_source_code_files(
repo_path, language_config=language_config, excluded_paths=excluded_paths
)
repo = Repository( repo = Repository(
id=uuid5(NAMESPACE_OID, repo_path), id=uuid5(NAMESPACE_OID, repo_path),

View file

@ -3,8 +3,8 @@ import uuid
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
from pylint.checkers.utils import node_type
from cognee.modules.engine.models.node_set import NodeSet
from cognee.modules.search.exceptions import UnsupportedSearchTypeError from cognee.modules.search.exceptions import UnsupportedSearchTypeError
from cognee.modules.search.methods.search import search, specific_search from cognee.modules.search.methods.search import search, specific_search
from cognee.modules.search.types import SearchType from cognee.modules.search.types import SearchType
@ -58,15 +58,17 @@ async def test_search(
# Verify # Verify
mock_log_query.assert_called_once_with(query_text, query_type.value, mock_user.id) mock_log_query.assert_called_once_with(query_text, query_type.value, mock_user.id)
mock_specific_search.assert_called_once_with( mock_specific_search.assert_called_once_with(
query_type, query_type=query_type,
query_text, query_text=query_text,
mock_user, user=mock_user,
system_prompt_path="answer_simple_question.txt", system_prompt_path="answer_simple_question.txt",
system_prompt=None,
top_k=10, top_k=10,
node_type=None, node_type=NodeSet,
node_name=None, node_name=None,
save_interaction=False, save_interaction=False,
last_k=None, last_k=None,
only_context=False,
) )
# Verify result logging # Verify result logging
@ -201,7 +203,10 @@ async def test_specific_search_feeling_lucky(
if retriever_name == "CompletionRetriever": if retriever_name == "CompletionRetriever":
mock_retriever_class.assert_called_once_with( mock_retriever_class.assert_called_once_with(
system_prompt_path="answer_simple_question.txt", top_k=top_k system_prompt_path="answer_simple_question.txt",
top_k=top_k,
system_prompt=None,
only_context=None,
) )
else: else:
mock_retriever_class.assert_called_once_with(top_k=top_k) mock_retriever_class.assert_called_once_with(top_k=top_k)