diff --git a/cognee/modules/graph/cognee_graph/CogneeGraph.py b/cognee/modules/graph/cognee_graph/CogneeGraph.py
index 4c0e265b0..dbfbc7bb7 100644
--- a/cognee/modules/graph/cognee_graph/CogneeGraph.py
+++ b/cognee/modules/graph/cognee_graph/CogneeGraph.py
@@ -40,7 +40,7 @@ class CogneeGraph(CogneeAbstractGraph):
             edge.node1.add_skeleton_edge(edge)
             edge.node2.add_skeleton_edge(edge)
         else:
-            raise EntityAlreadyExistsError(message=f"Edge {edge} already exists in the graph.")
+            print(f"Edge {edge} already exists in the graph.")
 
     def get_node(self, node_id: str) -> Node:
         return self.nodes.get(node_id, None)
diff --git a/cognee/modules/graph/cognee_graph/CogneeGraphElements.py b/cognee/modules/graph/cognee_graph/CogneeGraphElements.py
index 09d1e84cf..bab0f3bb6 100644
--- a/cognee/modules/graph/cognee_graph/CogneeGraphElements.py
+++ b/cognee/modules/graph/cognee_graph/CogneeGraphElements.py
@@ -65,6 +65,12 @@ class Node:
     def get_attribute(self, key: str) -> Union[str, int, float]:
         return self.attributes[key]
 
+    def get_skeleton_edges(self):
+        return self.skeleton_edges
+
+    def get_skeleton_neighbours(self):
+        return self.skeleton_neighbours
+
     def __repr__(self) -> str:
         return f"Node({self.id}, attributes={self.attributes})"
 
@@ -109,8 +115,14 @@ class Edge:
     def add_attribute(self, key: str, value: Any) -> None:
         self.attributes[key] = value
 
-    def get_attribute(self, key: str, value: Any) -> Union[str, int, float]:
-        return self.attributes[key]
+    def get_attribute(self, key: str) -> Optional[Union[str, int, float]]:
+        return self.attributes.get(key)
+
+    def get_source_node(self):
+        return self.node1
+
+    def get_destination_node(self):
+        return self.node2
 
     def __repr__(self) -> str:
         direction = "->" if self.directed else "--"
diff --git a/cognee/modules/retrieval/description_to_codepart_search.py b/cognee/modules/retrieval/description_to_codepart_search.py
new file mode 100644
index 000000000..e1da9a43f
--- /dev/null
+++ b/cognee/modules/retrieval/description_to_codepart_search.py
@@ -0,0 +1,133 @@
+import asyncio
+import logging
+
+from typing import List, Optional
+from cognee.infrastructure.databases.graph import get_graph_engine
+from cognee.infrastructure.databases.vector import get_vector_engine
+from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
+from cognee.modules.graph.cognee_graph.CogneeGraphElements import Node
+from cognee.modules.users.methods import get_default_user
+from cognee.modules.users.models import User
+from cognee.shared.utils import send_telemetry
+
+
+async def code_description_to_code_part_search(
+    query: str,
+    user: Optional[User] = None,
+    top_k: int = 2
+) -> list:
+    """
+    Searches for code parts matching a code description on behalf of a user.
+
+    Args:
+        query (str): The search query describing the code parts.
+        user (User, optional): The user performing the search; the default user is looked up when omitted.
+        top_k (int): Number of code summaries to match.
+
+    Returns:
+        list: The code part nodes returned by code_description_to_code_part.
+
+    Raises:
+        PermissionError: If no user is given and no default user is found.
+    """
+    if user is None:
+        user = await get_default_user()
+
+    if user is None:
+        raise PermissionError("No user found in the system. Please create a user.")
+
+    retrieved_codeparts = await code_description_to_code_part(query, user, top_k)
+    return retrieved_codeparts
+
+
+async def code_description_to_code_part(
+    query: str,
+    user: User,
+    top_k: int
+) -> List[Node]:
+    """
+    Maps a code description query to relevant code parts using a CodeGraph pipeline.
+
+    Args:
+        query (str): The search query describing the code parts.
+        user (User): The user performing the search.
+        top_k (int): Number of codegraph descriptions to match (the number of corresponding code parts can be higher).
+
+    Returns:
+        List[Node]: The unique code part nodes matching the query, in no particular order.
+
+    Raises:
+        ValueError: If arguments are invalid.
+        RuntimeError: If an unexpected error occurs during execution.
+    """
+    if not query or not isinstance(query, str):
+        raise ValueError("The query must be a non-empty string.")
+    # Check the type first: comparing a non-integer with 0 can raise TypeError instead of ValueError.
+    if not isinstance(top_k, int) or top_k <= 0:
+        raise ValueError("top_k must be a positive integer.")
+
+    try:
+        vector_engine = get_vector_engine()
+        graph_engine = await get_graph_engine()
+    except Exception as init_error:
+        logging.error("Failed to initialize engines: %s", init_error, exc_info=True)
+        raise RuntimeError("System initialization error. Please try again later.") from init_error
+
+    send_telemetry("code_description_to_code_part_search EXECUTION STARTED", user.id)
+    logging.info("Search initiated by user %s with query: '%s' and top_k: %d", user.id, query, top_k)
+
+    try:
+        results = await vector_engine.search(
+            "code_summary_text", query_text=query, limit=top_k
+        )
+        if not results:
+            logging.warning("No results found for query: '%s' by user: %s", query, user.id)
+            return []
+
+        memory_fragment = CogneeGraph()
+        await memory_fragment.project_graph_from_db(
+            graph_engine,
+            node_properties_to_project=['id', 'type', 'text', 'source_code'],
+            edge_properties_to_project=['relationship_name']
+        )
+
+        code_pieces_to_return = set()
+
+        for node in results:
+            node_id = str(node.id)
+            node_to_search_from = memory_fragment.get_node(node_id)
+
+            if not node_to_search_from:
+                logging.debug("Node %s not found in memory fragment graph", node_id)
+                continue
+
+            for code_file in node_to_search_from.get_skeleton_neighbours():
+                for code_file_edge in code_file.get_skeleton_edges():
+                    if code_file_edge.get_attribute('relationship_name') == 'contains':
+                        code_pieces_to_return.add(code_file_edge.get_destination_node())
+
+        logging.info("Search completed for user: %s, query: '%s'. Found %d code pieces.",
+                     user.id, query, len(code_pieces_to_return))
+
+        return list(code_pieces_to_return)
+
+    except Exception as exec_error:
+        logging.error(
+            "Error during code description to code part search for user: %s, query: '%s'. Error: %s",
+            user.id, query, exec_error, exc_info=True
+        )
+        send_telemetry("code_description_to_code_part_search EXECUTION FAILED", user.id)
+        raise RuntimeError("An error occurred while processing your request.") from exec_error
+
+
+if __name__ == "__main__":
+    async def main():
+        query = "I am looking for a class with blue eyes"
+        user = None
+        try:
+            results = await code_description_to_code_part_search(query, user)
+            print("Retrieved Code Parts:", results)
+        except Exception as e:
+            print(f"An error occurred: {e}")
+
+    asyncio.run(main())
diff --git a/cognee/tasks/summarization/models.py b/cognee/tasks/summarization/models.py
index af468fb9d..6fef4fb02 100644
--- a/cognee/tasks/summarization/models.py
+++ b/cognee/tasks/summarization/models.py
@@ -14,6 +14,7 @@ class TextSummary(DataPoint):
 
 
 class CodeSummary(DataPoint):
+    __tablename__ = "code_summary"
     text: str
     made_from: CodeFile
 
diff --git a/cognee/tests/unit/modules/graph/cognee_graph_test.py b/cognee/tests/unit/modules/graph/cognee_graph_test.py
index 6f6165202..6888648c3 100644
--- a/cognee/tests/unit/modules/graph/cognee_graph_test.py
+++ b/cognee/tests/unit/modules/graph/cognee_graph_test.py
@@ -42,19 +42,6 @@ def test_add_edge_success(setup_graph):
     assert edge in node2.skeleton_edges
 
 
-def test_add_duplicate_edge(setup_graph):
-    """Test adding a duplicate edge raises an exception."""
-    graph = setup_graph
-    node1 = Node("node1")
-    node2 = Node("node2")
-    graph.add_node(node1)
-    graph.add_node(node2)
-    edge = Edge(node1, node2)
-    graph.add_edge(edge)
-    with pytest.raises(EntityAlreadyExistsError, match="Edge .* already exists in the graph."):
-        graph.add_edge(edge)
-
-
 def test_get_node_success(setup_graph):
     """Test retrieving an existing node."""
     graph = setup_graph
diff --git a/cognee/tests/unit/modules/retriever/test_description_to_codepart_search.py b/cognee/tests/unit/modules/retriever/test_description_to_codepart_search.py
new file mode 100644
index 000000000..4c39883a9
--- /dev/null
+++ b/cognee/tests/unit/modules/retriever/test_description_to_codepart_search.py
@@ -0,0 +1,86 @@
+import pytest
+from unittest.mock import AsyncMock, patch
+
+
+@pytest.mark.asyncio
+async def test_code_description_to_code_part_no_results():
+    """Test that code_description_to_code_part handles no search results."""
+
+    mock_user = AsyncMock()
+    mock_user.id = "user123"
+    mock_vector_engine = AsyncMock()
+    mock_vector_engine.search.return_value = []
+
+    with patch("cognee.modules.retrieval.description_to_codepart_search.get_vector_engine", return_value=mock_vector_engine), \
+         patch("cognee.modules.retrieval.description_to_codepart_search.get_graph_engine", return_value=AsyncMock()), \
+         patch("cognee.modules.retrieval.description_to_codepart_search.CogneeGraph", return_value=AsyncMock()):
+
+        from cognee.modules.retrieval.description_to_codepart_search import code_description_to_code_part
+        result = await code_description_to_code_part("search query", mock_user, 2)
+
+    assert result == []
+
+
+@pytest.mark.asyncio
+async def test_code_description_to_code_part_invalid_query():
+    """Test that code_description_to_code_part raises ValueError for invalid query."""
+
+    mock_user = AsyncMock()
+
+    with pytest.raises(ValueError, match="The query must be a non-empty string."):
+        from cognee.modules.retrieval.description_to_codepart_search import code_description_to_code_part
+        await code_description_to_code_part("", mock_user, 2)
+
+
+@pytest.mark.asyncio
+async def test_code_description_to_code_part_invalid_top_k():
+    """Test that code_description_to_code_part raises ValueError for invalid top_k."""
+
+    mock_user = AsyncMock()
+
+    with pytest.raises(ValueError, match="top_k must be a positive integer."):
+        from cognee.modules.retrieval.description_to_codepart_search import code_description_to_code_part
+        await code_description_to_code_part("search query", mock_user, 0)
+
+
+@pytest.mark.asyncio
+async def test_code_description_to_code_part_non_integer_top_k():
+    """Test that code_description_to_code_part raises ValueError for a non-integer top_k."""
+
+    mock_user = AsyncMock()
+
+    with pytest.raises(ValueError, match="top_k must be a positive integer."):
+        from cognee.modules.retrieval.description_to_codepart_search import code_description_to_code_part
+        await code_description_to_code_part("search query", mock_user, "2")
+
+
+@pytest.mark.asyncio
+async def test_code_description_to_code_part_initialization_error():
+    """Test that code_description_to_code_part raises RuntimeError for engine initialization errors."""
+
+    mock_user = AsyncMock()
+
+    with patch("cognee.modules.retrieval.description_to_codepart_search.get_vector_engine", side_effect=Exception("Engine init failed")), \
+         patch("cognee.modules.retrieval.description_to_codepart_search.get_graph_engine", return_value=AsyncMock()):
+
+        from cognee.modules.retrieval.description_to_codepart_search import code_description_to_code_part
+        with pytest.raises(RuntimeError, match="System initialization error. Please try again later."):
+            await code_description_to_code_part("search query", mock_user, 2)
+
+
+@pytest.mark.asyncio
+async def test_code_description_to_code_part_execution_error():
+    """Test that code_description_to_code_part raises RuntimeError for execution errors."""
+
+    mock_user = AsyncMock()
+    mock_user.id = "user123"
+    mock_vector_engine = AsyncMock()
+    mock_vector_engine.search.side_effect = Exception("Execution error")
+
+    with patch("cognee.modules.retrieval.description_to_codepart_search.get_vector_engine", return_value=mock_vector_engine), \
+         patch("cognee.modules.retrieval.description_to_codepart_search.get_graph_engine", return_value=AsyncMock()), \
+         patch("cognee.modules.retrieval.description_to_codepart_search.CogneeGraph", return_value=AsyncMock()):
+
+        from cognee.modules.retrieval.description_to_codepart_search import code_description_to_code_part
+        with pytest.raises(RuntimeError, match="An error occurred while processing your request."):
+            await code_description_to_code_part("search query", mock_user, 2)