diff --git a/cognee/api/v1/add/routers/get_add_router.py b/cognee/api/v1/add/routers/get_add_router.py index 66b165a38..1703d9931 100644 --- a/cognee/api/v1/add/routers/get_add_router.py +++ b/cognee/api/v1/add/routers/get_add_router.py @@ -25,6 +25,7 @@ def get_add_router() -> APIRouter: data: List[UploadFile] = File(default=None), datasetName: Optional[str] = Form(default=None), datasetId: Union[UUID, Literal[""], None] = Form(default=None, examples=[""]), + node_set: Optional[List[str]] = Form(default=[""], example=[""]), user: User = Depends(get_authenticated_user), ): """ @@ -41,6 +42,8 @@ def get_add_router() -> APIRouter: - Regular file uploads - **datasetName** (Optional[str]): Name of the dataset to add data to - **datasetId** (Optional[UUID]): UUID of an already existing dataset + - **node_set** Optional[list[str]]: List of node identifiers for graph organization and access control. + Used for grouping related data points in the knowledge graph. Either datasetName or datasetId must be provided. @@ -65,9 +68,7 @@ def get_add_router() -> APIRouter: send_telemetry( "Add API Endpoint Invoked", user.id, - additional_properties={ - "endpoint": "POST /v1/add", - }, + additional_properties={"endpoint": "POST /v1/add", "node_set": node_set}, ) from cognee.api.v1.add import add as cognee_add @@ -76,34 +77,13 @@ def get_add_router() -> APIRouter: raise ValueError("Either datasetId or datasetName must be provided.") try: - if ( - isinstance(data, str) - and data.startswith("http") - and (os.getenv("ALLOW_HTTP_REQUESTS", "true").lower() == "true") - ): - if "github" in data: - # Perform git clone if the URL is from GitHub - repo_name = data.split("/")[-1].replace(".git", "") - subprocess.run(["git", "clone", data, f".data/{repo_name}"], check=True) - # TODO: Update add call with dataset info - await cognee_add( - "data://.data/", - f"{repo_name}", - ) - else: - # Fetch and store the data from other types of URL using curl - response = requests.get(data) - response.raise_for_status() + add_run = await cognee_add( + data, datasetName, user=user, dataset_id=datasetId, node_set=node_set + ) - file_data = await response.content() - # TODO: Update add call with dataset info - return await cognee_add(file_data) - else: - add_run = await cognee_add(data, datasetName, user=user, dataset_id=datasetId) - - if isinstance(add_run, PipelineRunErrored): - return JSONResponse(status_code=420, content=add_run.model_dump(mode="json")) - return add_run.model_dump() + if isinstance(add_run, PipelineRunErrored): + return JSONResponse(status_code=420, content=add_run.model_dump(mode="json")) + return add_run.model_dump() except Exception as error: return JSONResponse(status_code=409, content={"error": str(error)}) diff --git a/cognee/api/v1/cognify/code_graph_pipeline.py b/cognee/api/v1/cognify/code_graph_pipeline.py index 19d194b26..fb3612857 100644 --- a/cognee/api/v1/cognify/code_graph_pipeline.py +++ b/cognee/api/v1/cognify/code_graph_pipeline.py @@ -1,6 +1,7 @@ import os import pathlib import asyncio +from typing import Optional from cognee.shared.logging_utils import get_logger, setup_logging from cognee.modules.observability.get_observe import get_observe @@ -28,7 +29,12 @@ logger = get_logger("code_graph_pipeline") @observe -async def run_code_graph_pipeline(repo_path, include_docs=False): +async def run_code_graph_pipeline( + repo_path, + include_docs=False, + excluded_paths: Optional[list[str]] = None, + supported_languages: Optional[list[str]] = None, +): import cognee from cognee.low_level import setup @@ -40,13 +46,12 @@ async def run_code_graph_pipeline(repo_path, include_docs=False): user = await get_default_user() detailed_extraction = True - # Multi-language support: allow passing supported_languages - supported_languages = None # defer to task defaults tasks = [ Task( get_repo_file_dependencies, detailed_extraction=detailed_extraction, supported_languages=supported_languages, + excluded_paths=excluded_paths, ), # Task(summarize_code, task_config={"batch_size": 500}), # This task takes a long time to complete Task(add_data_points, task_config={"batch_size": 30}), @@ -95,7 +100,7 @@ if __name__ == "__main__": async def main(): async for run_status in run_code_graph_pipeline("REPO_PATH"): - print(f"{run_status.pipeline_name}: {run_status.status}") + print(f"{run_status.pipeline_run_id}: {run_status.status}") file_path = os.path.join( pathlib.Path(__file__).parent, ".artifacts", "graph_visualization.html" diff --git a/cognee/api/v1/cognify/routers/get_cognify_router.py b/cognee/api/v1/cognify/routers/get_cognify_router.py index 6809f089a..d40345f8e 100644 --- a/cognee/api/v1/cognify/routers/get_cognify_router.py +++ b/cognee/api/v1/cognify/routers/get_cognify_router.py @@ -38,7 +38,7 @@ class CognifyPayloadDTO(InDTO): dataset_ids: Optional[List[UUID]] = Field(default=None, examples=[[]]) run_in_background: Optional[bool] = Field(default=False) custom_prompt: Optional[str] = Field( - default=None, description="Custom prompt for entity extraction and graph generation" + default="", description="Custom prompt for entity extraction and graph generation" ) diff --git a/cognee/api/v1/search/routers/get_search_router.py b/cognee/api/v1/search/routers/get_search_router.py index 0ceeb1abb..2cf087b25 100644 --- a/cognee/api/v1/search/routers/get_search_router.py +++ b/cognee/api/v1/search/routers/get_search_router.py @@ -1,9 +1,11 @@ from uuid import UUID +import pathlib from typing import Optional from datetime import datetime from pydantic import Field from fastapi import Depends, APIRouter from fastapi.responses import JSONResponse + from cognee.modules.search.types import SearchType from cognee.api.DTO import InDTO, OutDTO from cognee.modules.users.exceptions.exceptions import PermissionDeniedError @@ -20,7 +22,12 @@ class SearchPayloadDTO(InDTO): datasets: Optional[list[str]] = Field(default=None) dataset_ids: Optional[list[UUID]] = Field(default=None, examples=[[]]) query: str = Field(default="What is in the document?") + system_prompt: Optional[str] = Field( + default="Answer the question using the provided context. Be as brief as possible." + ) + node_name: Optional[list[str]] = Field(default=None, example=[]) top_k: Optional[int] = Field(default=10) + only_context: bool = Field(default=False) def get_search_router() -> APIRouter: @@ -79,7 +86,10 @@ def get_search_router() -> APIRouter: - **datasets** (Optional[List[str]]): List of dataset names to search within - **dataset_ids** (Optional[List[UUID]]): List of dataset UUIDs to search within - **query** (str): The search query string + - **system_prompt** Optional[str]: System prompt to be used for Completion type searches in Cognee + - **node_name** Optional[list[str]]: Filter results to specific node_sets defined in the add pipeline (for targeted search). - **top_k** (Optional[int]): Maximum number of results to return (default: 10) + - **only_context** bool: Set to true to only return context Cognee will be sending to LLM in Completion type searches. This will be returned instead of LLM calls for completion type searches. ## Response Returns a list of search results containing relevant nodes from the graph. @@ -102,7 +112,10 @@ def get_search_router() -> APIRouter: "datasets": payload.datasets, "dataset_ids": [str(dataset_id) for dataset_id in payload.dataset_ids or []], "query": payload.query, + "system_prompt": payload.system_prompt, + "node_name": payload.node_name, "top_k": payload.top_k, + "only_context": payload.only_context, }, ) @@ -115,7 +128,10 @@ def get_search_router() -> APIRouter: user=user, datasets=payload.datasets, dataset_ids=payload.dataset_ids, + system_prompt=payload.system_prompt, + node_name=payload.node_name, top_k=payload.top_k, + only_context=payload.only_context, ) return results diff --git a/cognee/api/v1/search/search.py b/cognee/api/v1/search/search.py index f37f8ba6d..49f7aee51 100644 --- a/cognee/api/v1/search/search.py +++ b/cognee/api/v1/search/search.py @@ -1,6 +1,7 @@ from uuid import UUID from typing import Union, Optional, List, Type +from cognee.modules.engine.models.node_set import NodeSet from cognee.modules.users.models import User from cognee.modules.search.types import SearchType from cognee.modules.users.methods import get_default_user @@ -16,11 +17,13 @@ async def search( datasets: Optional[Union[list[str], str]] = None, dataset_ids: Optional[Union[list[UUID], UUID]] = None, system_prompt_path: str = "answer_simple_question.txt", + system_prompt: Optional[str] = None, top_k: int = 10, - node_type: Optional[Type] = None, + node_type: Optional[Type] = NodeSet, node_name: Optional[List[str]] = None, save_interaction: bool = False, last_k: Optional[int] = None, + only_context: bool = False, ) -> list: """ Search and query the knowledge graph for insights, information, and connections. @@ -183,11 +186,13 @@ async def search( dataset_ids=dataset_ids if dataset_ids else datasets, user=user, system_prompt_path=system_prompt_path, + system_prompt=system_prompt, top_k=top_k, node_type=node_type, node_name=node_name, save_interaction=save_interaction, last_k=last_k, + only_context=only_context, ) return filtered_search_results diff --git a/cognee/modules/graph/cognee_graph/CogneeGraph.py b/cognee/modules/graph/cognee_graph/CogneeGraph.py index ed867ae24..924532ce0 100644 --- a/cognee/modules/graph/cognee_graph/CogneeGraph.py +++ b/cognee/modules/graph/cognee_graph/CogneeGraph.py @@ -76,7 +76,7 @@ class CogneeGraph(CogneeAbstractGraph): start_time = time.time() # Determine projection strategy - if node_type is not None and node_name is not None: + if node_type is not None and node_name not in [None, []]: nodes_data, edges_data = await adapter.get_nodeset_subgraph( node_type=node_type, node_name=node_name ) diff --git a/cognee/modules/retrieval/code_retriever.py b/cognee/modules/retrieval/code_retriever.py index 6e819d8a7..76b5e758c 100644 --- a/cognee/modules/retrieval/code_retriever.py +++ b/cognee/modules/retrieval/code_retriever.py @@ -94,7 +94,15 @@ class CodeRetriever(BaseRetriever): {"id": res.id, "score": res.score, "payload": res.payload} ) + existing_collection = [] for collection in self.classes_and_functions_collections: + if await vector_engine.has_collection(collection): + existing_collection.append(collection) + + if not existing_collection: + raise RuntimeError("No collection found for code retriever") + + for collection in existing_collection: logger.debug(f"Searching {collection} collection with general query") search_results_code = await vector_engine.search( collection, query, limit=self.top_k diff --git a/cognee/modules/retrieval/completion_retriever.py b/cognee/modules/retrieval/completion_retriever.py index 655a9010d..4d34dfdbe 100644 --- a/cognee/modules/retrieval/completion_retriever.py +++ b/cognee/modules/retrieval/completion_retriever.py @@ -23,12 +23,16 @@ class CompletionRetriever(BaseRetriever): self, user_prompt_path: str = "context_for_question.txt", system_prompt_path: str = "answer_simple_question.txt", + system_prompt: str = None, top_k: Optional[int] = 1, + only_context: bool = False, ): """Initialize retriever with optional custom prompt paths.""" self.user_prompt_path = user_prompt_path self.system_prompt_path = system_prompt_path self.top_k = top_k if top_k is not None else 1 + self.system_prompt = system_prompt + self.only_context = only_context async def get_context(self, query: str) -> str: """ @@ -88,6 +92,11 @@ class CompletionRetriever(BaseRetriever): context = await self.get_context(query) completion = await generate_completion( - query, context, self.user_prompt_path, self.system_prompt_path + query=query, + context=context, + user_prompt_path=self.user_prompt_path, + system_prompt_path=self.system_prompt_path, + system_prompt=self.system_prompt, + only_context=self.only_context, ) return [completion] diff --git a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py index d05e6b4fa..8bdf5f1a0 100644 --- a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py +++ b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py @@ -26,10 +26,12 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): self, user_prompt_path: str = "graph_context_for_question.txt", system_prompt_path: str = "answer_simple_question.txt", + system_prompt: Optional[str] = None, top_k: Optional[int] = 5, node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, save_interaction: bool = False, + only_context: bool = False, ): super().__init__( user_prompt_path=user_prompt_path, @@ -38,10 +40,15 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): node_type=node_type, node_name=node_name, save_interaction=save_interaction, + system_prompt=system_prompt, + only_context=only_context, ) async def get_completion( - self, query: str, context: Optional[Any] = None, context_extension_rounds=4 + self, + query: str, + context: Optional[Any] = None, + context_extension_rounds=4, ) -> List[str]: """ Extends the context for a given query by retrieving related triplets and generating new @@ -86,6 +93,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): context=context, user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, + system_prompt=self.system_prompt, ) triplets += await self.get_triplets(completion) @@ -112,6 +120,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): context=context, user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, + system_prompt=self.system_prompt, + only_context=self.only_context, ) if self.save_interaction and context and triplets and completion: @@ -119,4 +129,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): question=query, answer=completion, context=context, triplets=triplets ) - return [completion] + if self.only_context: + return [context] + else: + return [completion] diff --git a/cognee/modules/retrieval/graph_completion_cot_retriever.py b/cognee/modules/retrieval/graph_completion_cot_retriever.py index 032dccf9e..86ff8555b 100644 --- a/cognee/modules/retrieval/graph_completion_cot_retriever.py +++ b/cognee/modules/retrieval/graph_completion_cot_retriever.py @@ -32,14 +32,18 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): validation_system_prompt_path: str = "cot_validation_system_prompt.txt", followup_system_prompt_path: str = "cot_followup_system_prompt.txt", followup_user_prompt_path: str = "cot_followup_user_prompt.txt", + system_prompt: str = None, top_k: Optional[int] = 5, node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, save_interaction: bool = False, + only_context: bool = False, ): super().__init__( user_prompt_path=user_prompt_path, system_prompt_path=system_prompt_path, + system_prompt=system_prompt, + only_context=only_context, top_k=top_k, node_type=node_type, node_name=node_name, @@ -51,7 +55,10 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): self.followup_user_prompt_path = followup_user_prompt_path async def get_completion( - self, query: str, context: Optional[Any] = None, max_iter=4 + self, + query: str, + context: Optional[Any] = None, + max_iter=4, ) -> List[str]: """ Generate completion responses based on a user query and contextual information. @@ -92,6 +99,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): context=context, user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, + system_prompt=self.system_prompt, ) logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}") if round_idx < max_iter: @@ -128,4 +136,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): question=query, answer=completion, context=context, triplets=triplets ) - return [completion] + if self.only_context: + return [context] + else: + return [completion] diff --git a/cognee/modules/retrieval/graph_completion_retriever.py b/cognee/modules/retrieval/graph_completion_retriever.py index fb3cf4885..6a5193c56 100644 --- a/cognee/modules/retrieval/graph_completion_retriever.py +++ b/cognee/modules/retrieval/graph_completion_retriever.py @@ -36,15 +36,19 @@ class GraphCompletionRetriever(BaseRetriever): self, user_prompt_path: str = "graph_context_for_question.txt", system_prompt_path: str = "answer_simple_question.txt", + system_prompt: str = None, top_k: Optional[int] = 5, node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, save_interaction: bool = False, + only_context: bool = False, ): """Initialize retriever with prompt paths and search parameters.""" self.save_interaction = save_interaction self.user_prompt_path = user_prompt_path self.system_prompt_path = system_prompt_path + self.system_prompt = system_prompt + self.only_context = only_context self.top_k = top_k if top_k is not None else 5 self.node_type = node_type self.node_name = node_name @@ -151,7 +155,11 @@ class GraphCompletionRetriever(BaseRetriever): return context, triplets - async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: + async def get_completion( + self, + query: str, + context: Optional[Any] = None, + ) -> Any: """ Generates a completion using graph connections context based on a query. @@ -177,6 +185,8 @@ class GraphCompletionRetriever(BaseRetriever): context=context, user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, + system_prompt=self.system_prompt, + only_context=self.only_context, ) if self.save_interaction and context and triplets and completion: diff --git a/cognee/modules/retrieval/graph_summary_completion_retriever.py b/cognee/modules/retrieval/graph_summary_completion_retriever.py index d344ebd26..051f39b22 100644 --- a/cognee/modules/retrieval/graph_summary_completion_retriever.py +++ b/cognee/modules/retrieval/graph_summary_completion_retriever.py @@ -21,6 +21,7 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever): user_prompt_path: str = "graph_context_for_question.txt", system_prompt_path: str = "answer_simple_question.txt", summarize_prompt_path: str = "summarize_search_results.txt", + system_prompt: Optional[str] = None, top_k: Optional[int] = 5, node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, @@ -34,6 +35,7 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever): node_type=node_type, node_name=node_name, save_interaction=save_interaction, + system_prompt=system_prompt, ) self.summarize_prompt_path = summarize_prompt_path @@ -57,4 +59,4 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever): - str: A summary string representing the content of the retrieved edges. """ direct_text = await super().resolve_edges_to_text(retrieved_edges) - return await summarize_text(direct_text, self.summarize_prompt_path) + return await summarize_text(direct_text, self.summarize_prompt_path, self.system_prompt) diff --git a/cognee/modules/retrieval/summaries_retriever.py b/cognee/modules/retrieval/summaries_retriever.py index 56f414013..df35cdc51 100644 --- a/cognee/modules/retrieval/summaries_retriever.py +++ b/cognee/modules/retrieval/summaries_retriever.py @@ -62,7 +62,7 @@ class SummariesRetriever(BaseRetriever): logger.info(f"Returning {len(summary_payloads)} summary payloads") return summary_payloads - async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: + async def get_completion(self, query: str, context: Optional[Any] = None, **kwargs) -> Any: """ Generates a completion using summaries context. diff --git a/cognee/modules/retrieval/utils/completion.py b/cognee/modules/retrieval/utils/completion.py index ca0b30c18..81e636aad 100644 --- a/cognee/modules/retrieval/utils/completion.py +++ b/cognee/modules/retrieval/utils/completion.py @@ -1,3 +1,4 @@ +from typing import Optional from cognee.infrastructure.llm.LLMGateway import LLMGateway @@ -6,25 +7,35 @@ async def generate_completion( context: str, user_prompt_path: str, system_prompt_path: str, + system_prompt: Optional[str] = None, + only_context: bool = False, ) -> str: """Generates a completion using LLM with given context and prompts.""" args = {"question": query, "context": context} user_prompt = LLMGateway.render_prompt(user_prompt_path, args) - system_prompt = LLMGateway.read_query_prompt(system_prompt_path) - - return await LLMGateway.acreate_structured_output( - text_input=user_prompt, - system_prompt=system_prompt, - response_model=str, + system_prompt = ( + system_prompt if system_prompt else LLMGateway.read_query_prompt(system_prompt_path) ) + if only_context: + return context + else: + return await LLMGateway.acreate_structured_output( + text_input=user_prompt, + system_prompt=system_prompt, + response_model=str, + ) + async def summarize_text( text: str, - prompt_path: str = "summarize_search_results.txt", + system_prompt_path: str = "summarize_search_results.txt", + system_prompt: str = None, ) -> str: """Summarizes text using LLM with the specified prompt.""" - system_prompt = LLMGateway.read_query_prompt(prompt_path) + system_prompt = ( + system_prompt if system_prompt else LLMGateway.read_query_prompt(system_prompt_path) + ) return await LLMGateway.acreate_structured_output( text_input=text, diff --git a/cognee/modules/search/methods/search.py b/cognee/modules/search/methods/search.py index 6c0aa6a1d..fe5df0345 100644 --- a/cognee/modules/search/methods/search.py +++ b/cognee/modules/search/methods/search.py @@ -4,6 +4,7 @@ import asyncio from uuid import UUID from typing import Callable, List, Optional, Type, Union +from cognee.modules.engine.models.node_set import NodeSet from cognee.modules.retrieval.user_qa_feedback import UserQAFeedback from cognee.modules.search.exceptions import UnsupportedSearchTypeError from cognee.context_global_variables import set_database_global_context_variables @@ -38,11 +39,13 @@ async def search( dataset_ids: Union[list[UUID], None], user: User, system_prompt_path="answer_simple_question.txt", + system_prompt: Optional[str] = None, top_k: int = 10, - node_type: Optional[Type] = None, + node_type: Optional[Type] = NodeSet, node_name: Optional[List[str]] = None, save_interaction: Optional[bool] = False, last_k: Optional[int] = None, + only_context: bool = False, ): """ @@ -62,28 +65,34 @@ async def search( # Use search function filtered by permissions if access control is enabled if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true": return await authorized_search( - query_text=query_text, query_type=query_type, + query_text=query_text, user=user, dataset_ids=dataset_ids, system_prompt_path=system_prompt_path, + system_prompt=system_prompt, top_k=top_k, + node_type=node_type, + node_name=node_name, save_interaction=save_interaction, last_k=last_k, + only_context=only_context, ) query = await log_query(query_text, query_type.value, user.id) search_results = await specific_search( - query_type, - query_text, - user, + query_type=query_type, + query_text=query_text, + user=user, system_prompt_path=system_prompt_path, + system_prompt=system_prompt, top_k=top_k, node_type=node_type, node_name=node_name, save_interaction=save_interaction, last_k=last_k, + only_context=only_context, ) await log_result( @@ -99,21 +108,26 @@ async def search( async def specific_search( query_type: SearchType, - query: str, + query_text: str, user: User, - system_prompt_path="answer_simple_question.txt", + system_prompt_path: str = "answer_simple_question.txt", + system_prompt: Optional[str] = None, top_k: int = 10, - node_type: Optional[Type] = None, + node_type: Optional[Type] = NodeSet, node_name: Optional[List[str]] = None, save_interaction: Optional[bool] = False, last_k: Optional[int] = None, + only_context: bool = None, ) -> list: search_tasks: dict[SearchType, Callable] = { SearchType.SUMMARIES: SummariesRetriever(top_k=top_k).get_completion, SearchType.INSIGHTS: InsightsRetriever(top_k=top_k).get_completion, SearchType.CHUNKS: ChunksRetriever(top_k=top_k).get_completion, SearchType.RAG_COMPLETION: CompletionRetriever( - system_prompt_path=system_prompt_path, top_k=top_k + system_prompt_path=system_prompt_path, + top_k=top_k, + system_prompt=system_prompt, + only_context=only_context, ).get_completion, SearchType.GRAPH_COMPLETION: GraphCompletionRetriever( system_prompt_path=system_prompt_path, @@ -121,6 +135,8 @@ async def specific_search( node_type=node_type, node_name=node_name, save_interaction=save_interaction, + system_prompt=system_prompt, + only_context=only_context, ).get_completion, SearchType.GRAPH_COMPLETION_COT: GraphCompletionCotRetriever( system_prompt_path=system_prompt_path, @@ -128,6 +144,8 @@ async def specific_search( node_type=node_type, node_name=node_name, save_interaction=save_interaction, + system_prompt=system_prompt, + only_context=only_context, ).get_completion, SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION: GraphCompletionContextExtensionRetriever( system_prompt_path=system_prompt_path, @@ -135,6 +153,8 @@ async def specific_search( node_type=node_type, node_name=node_name, save_interaction=save_interaction, + system_prompt=system_prompt, + only_context=only_context, ).get_completion, SearchType.GRAPH_SUMMARY_COMPLETION: GraphSummaryCompletionRetriever( system_prompt_path=system_prompt_path, @@ -142,6 +162,7 @@ async def specific_search( node_type=node_type, node_name=node_name, save_interaction=save_interaction, + system_prompt=system_prompt, ).get_completion, SearchType.CODE: CodeRetriever(top_k=top_k).get_completion, SearchType.CYPHER: CypherSearchRetriever().get_completion, @@ -152,7 +173,7 @@ async def specific_search( # If the query type is FEELING_LUCKY, select the search type intelligently if query_type is SearchType.FEELING_LUCKY: - query_type = await select_search_type(query) + query_type = await select_search_type(query_text) search_task = search_tasks.get(query_type) @@ -161,7 +182,7 @@ async def specific_search( send_telemetry("cognee.search EXECUTION STARTED", user.id) - results = await search_task(query) + results = await search_task(query_text) send_telemetry("cognee.search EXECUTION COMPLETED", user.id) @@ -169,14 +190,18 @@ async def specific_search( async def authorized_search( - query_text: str, query_type: SearchType, - user: User = None, + query_text: str, + user: User, dataset_ids: Optional[list[UUID]] = None, system_prompt_path: str = "answer_simple_question.txt", + system_prompt: Optional[str] = None, top_k: int = 10, - save_interaction: bool = False, + node_type: Optional[Type] = NodeSet, + node_name: Optional[List[str]] = None, + save_interaction: Optional[bool] = False, last_k: Optional[int] = None, + only_context: bool = None, ) -> list: """ Verifies access for provided datasets or uses all datasets user has read access for and performs search per dataset. @@ -190,14 +215,18 @@ async def authorized_search( # Searches all provided datasets and handles setting up of appropriate database context based on permissions search_results = await specific_search_by_context( - search_datasets, - query_text, - query_type, - user, - system_prompt_path, - top_k, - save_interaction, + search_datasets=search_datasets, + query_type=query_type, + query_text=query_text, + user=user, + system_prompt_path=system_prompt_path, + system_prompt=system_prompt, + top_k=top_k, + node_type=node_type, + node_name=node_name, + save_interaction=save_interaction, last_k=last_k, + only_context=only_context, ) await log_result(query.id, json.dumps(search_results, cls=JSONEncoder), user.id) @@ -207,13 +236,17 @@ async def authorized_search( async def specific_search_by_context( search_datasets: list[Dataset], - query_text: str, query_type: SearchType, + query_text: str, user: User, - system_prompt_path: str, - top_k: int, - save_interaction: bool = False, + system_prompt_path: str = "answer_simple_question.txt", + system_prompt: Optional[str] = None, + top_k: int = 10, + node_type: Optional[Type] = NodeSet, + node_name: Optional[List[str]] = None, + save_interaction: Optional[bool] = False, last_k: Optional[int] = None, + only_context: bool = None, ): """ Searches all provided datasets and handles setting up of appropriate database context based on permissions. @@ -221,18 +254,33 @@ async def specific_search_by_context( """ async def _search_by_context( - dataset, user, query_type, query_text, system_prompt_path, top_k, last_k + dataset: Dataset, + query_type: SearchType, + query_text: str, + user: User, + system_prompt_path: str = "answer_simple_question.txt", + system_prompt: Optional[str] = None, + top_k: int = 10, + node_type: Optional[Type] = NodeSet, + node_name: Optional[List[str]] = None, + save_interaction: Optional[bool] = False, + last_k: Optional[int] = None, + only_context: bool = None, ): # Set database configuration in async context for each dataset user has access for await set_database_global_context_variables(dataset.id, dataset.owner_id) search_results = await specific_search( - query_type, - query_text, - user, + query_type=query_type, + query_text=query_text, + user=user, system_prompt_path=system_prompt_path, + system_prompt=system_prompt, top_k=top_k, + node_type=node_type, + node_name=node_name, save_interaction=save_interaction, last_k=last_k, + only_context=only_context, ) return { "search_result": search_results, @@ -245,7 +293,18 @@ async def specific_search_by_context( for dataset in search_datasets: tasks.append( _search_by_context( - dataset, user, query_type, query_text, system_prompt_path, top_k, last_k + dataset=dataset, + query_type=query_type, + query_text=query_text, + user=user, + system_prompt_path=system_prompt_path, + system_prompt=system_prompt, + top_k=top_k, + node_type=node_type, + node_name=node_name, + save_interaction=save_interaction, + last_k=last_k, + only_context=only_context, ) ) diff --git a/cognee/tasks/repo_processor/get_repo_file_dependencies.py b/cognee/tasks/repo_processor/get_repo_file_dependencies.py index 4ff79523f..06cc3bddb 100644 --- a/cognee/tasks/repo_processor/get_repo_file_dependencies.py +++ b/cognee/tasks/repo_processor/get_repo_file_dependencies.py @@ -1,24 +1,48 @@ import asyncio import math import os - -# from concurrent.futures import ProcessPoolExecutor -from typing import AsyncGenerator +from pathlib import Path +from typing import Set +from typing import AsyncGenerator, Optional, List from uuid import NAMESPACE_OID, uuid5 from cognee.infrastructure.engine import DataPoint from cognee.shared.CodeGraphEntities import CodeFile, Repository +# constant, declared only once +EXCLUDED_DIRS: Set[str] = { + ".venv", + "venv", + "env", + ".env", + "site-packages", + "node_modules", + "dist", + "build", + ".git", + "tests", + "test", +} -async def get_source_code_files(repo_path, language_config: dict[str, list[str]] | None = None): + +async def get_source_code_files( + repo_path, + language_config: dict[str, list[str]] | None = None, + excluded_paths: Optional[List[str]] = None, +): """ - Retrieve source code files from the specified repository path for multiple languages. + Retrieve Python source code files from the specified repository path. + + This function scans the given repository path for files that have the .py extension + while excluding test files and files within a virtual environment. It returns a list of + absolute paths to the source code files that are not empty. Parameters: ----------- - - repo_path: The file path to the repository to search for source files. - - language_config: dict mapping language names to file extensions, e.g., + - repo_path: Root path of the repository to search + - language_config: dict mapping language names to file extensions, e.g., {'python': ['.py'], 'javascript': ['.js', '.jsx'], ...} + - excluded_paths: Optional list of path fragments or glob patterns to exclude Returns: -------- @@ -54,28 +78,23 @@ async def get_source_code_files(repo_path, language_config: dict[str, list[str]] lang = _get_language_from_extension(file, language_config) if lang is None: continue - # Exclude tests and common build/venv directories - excluded_dirs = { - ".venv", - "venv", - "env", - ".env", - "site-packages", - "node_modules", - "dist", - "build", - ".git", - "tests", - "test", - } - root_parts = set(os.path.normpath(root).split(os.sep)) + # Exclude tests, common build/venv directories and files provided in exclude_paths + excluded_dirs = EXCLUDED_DIRS + excluded_paths = {Path(p).resolve() for p in (excluded_paths or [])} # full paths + + root_path = Path(root).resolve() + root_parts = set(root_path.parts) # same as before base_name, _ext = os.path.splitext(file) if ( base_name.startswith("test_") - or base_name.endswith("_test") # catches Go's *_test.go and similar + or base_name.endswith("_test") or ".test." in file or ".spec." in file - or (excluded_dirs & root_parts) + or (excluded_dirs & root_parts) # name match + or any( + root_path.is_relative_to(p) # full-path match + for p in excluded_paths + ) ): continue file_path = os.path.abspath(os.path.join(root, file)) @@ -115,7 +134,10 @@ def run_coroutine(coroutine_func, *args, **kwargs): async def get_repo_file_dependencies( - repo_path: str, detailed_extraction: bool = False, supported_languages: list = None + repo_path: str, + detailed_extraction: bool = False, + supported_languages: list = None, + excluded_paths: Optional[List[str]] = None, ) -> AsyncGenerator[DataPoint, None]: """ Generate a dependency graph for source files (multi-language) in the given repository path. @@ -150,6 +172,7 @@ async def get_repo_file_dependencies( "go": [".go"], "rust": [".rs"], "cpp": [".cpp", ".c", ".h", ".hpp"], + "c": [".c", ".h"], } if supported_languages is not None: language_config = { @@ -158,7 +181,9 @@ async def get_repo_file_dependencies( else: language_config = default_language_config - source_code_files = await get_source_code_files(repo_path, language_config=language_config) + source_code_files = await get_source_code_files( + repo_path, language_config=language_config, excluded_paths=excluded_paths + ) repo = Repository( id=uuid5(NAMESPACE_OID, repo_path), diff --git a/cognee/tests/unit/modules/search/search_methods_test.py b/cognee/tests/unit/modules/search/search_methods_test.py index 46995d087..3a6bdc51e 100644 --- a/cognee/tests/unit/modules/search/search_methods_test.py +++ b/cognee/tests/unit/modules/search/search_methods_test.py @@ -3,8 +3,8 @@ import uuid from unittest.mock import AsyncMock, MagicMock, patch import pytest -from pylint.checkers.utils import node_type +from cognee.modules.engine.models.node_set import NodeSet from cognee.modules.search.exceptions import UnsupportedSearchTypeError from cognee.modules.search.methods.search import search, specific_search from cognee.modules.search.types import SearchType @@ -58,15 +58,17 @@ async def test_search( # Verify mock_log_query.assert_called_once_with(query_text, query_type.value, mock_user.id) mock_specific_search.assert_called_once_with( - query_type, - query_text, - mock_user, + query_type=query_type, + query_text=query_text, + user=mock_user, system_prompt_path="answer_simple_question.txt", + system_prompt=None, top_k=10, - node_type=None, + node_type=NodeSet, node_name=None, save_interaction=False, last_k=None, + only_context=False, ) # Verify result logging @@ -201,7 +203,10 @@ async def test_specific_search_feeling_lucky( if retriever_name == "CompletionRetriever": mock_retriever_class.assert_called_once_with( - system_prompt_path="answer_simple_question.txt", top_k=top_k + system_prompt_path="answer_simple_question.txt", + top_k=top_k, + system_prompt=None, + only_context=None, ) else: mock_retriever_class.assert_called_once_with(top_k=top_k)