Feature/cog 539 implementing additional retriever approaches (#262)
* fix: refactor get_graph_from_model to return nodes and edges correctly * fix: add missing params * fix: remove complex zip usage * fix: add edges to data_point properties * fix: handle rate limit error coming from llm model * fix: fixes lost edges and nodes in get_graph_from_model * fix: fixes database pruning issue in pgvector * fix: fixes database pruning issue in pgvector (#261) * feat: adds code summary embeddings to vector DB * fix: cognee_demo notebook pipeline is not saving summaries * feat: implements first version of codegraph retriever * chore: implements minor changes mostly to make the code production ready * fix: turns off raising duplicated edges unit test as we have these in our current codegraph generation * feat: implements unit tests for description to codepart search * fix: fixes edge property inconsistent access in codepart retriever * chore: implements more precise typing for get_attribute method for cogneegraph * chore: adds spacing to tests and changes the cogneegraph getter names --------- Co-authored-by: Boris Arzentar <borisarzentar@gmail.com>
This commit is contained in:
parent
5ffbebdd01
commit
6d85165189
6 changed files with 208 additions and 16 deletions
|
|
@ -40,7 +40,7 @@ class CogneeGraph(CogneeAbstractGraph):
|
|||
edge.node1.add_skeleton_edge(edge)
|
||||
edge.node2.add_skeleton_edge(edge)
|
||||
else:
|
||||
raise EntityAlreadyExistsError(message=f"Edge {edge} already exists in the graph.")
|
||||
print(f"Edge {edge} already exists in the graph.")
|
||||
|
||||
def get_node(self, node_id: str) -> Node:
|
||||
return self.nodes.get(node_id, None)
|
||||
|
|
|
|||
|
|
@ -65,6 +65,12 @@ class Node:
|
|||
def get_attribute(self, key: str) -> Union[str, int, float]:
|
||||
return self.attributes[key]
|
||||
|
||||
def get_skeleton_edges(self):
|
||||
return self.skeleton_edges
|
||||
|
||||
def get_skeleton_neighbours(self):
|
||||
return self.skeleton_neighbours
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Node({self.id}, attributes={self.attributes})"
|
||||
|
||||
|
|
@ -109,8 +115,14 @@ class Edge:
|
|||
def add_attribute(self, key: str, value: Any) -> None:
|
||||
self.attributes[key] = value
|
||||
|
||||
def get_attribute(self, key: str, value: Any) -> Union[str, int, float]:
|
||||
return self.attributes[key]
|
||||
def get_attribute(self, key: str) -> Optional[Union[str, int, float]]:
|
||||
return self.attributes.get(key)
|
||||
|
||||
def get_source_node(self):
|
||||
return self.node1
|
||||
|
||||
def get_destination_node(self):
|
||||
return self.node2
|
||||
|
||||
def __repr__(self) -> str:
|
||||
direction = "->" if self.directed else "--"
|
||||
|
|
|
|||
116
cognee/modules/retrieval/description_to_codepart_search.py
Normal file
116
cognee/modules/retrieval/description_to_codepart_search.py
Normal file
|
|
@ -0,0 +1,116 @@
|
|||
import asyncio
|
||||
import logging
|
||||
|
||||
from typing import Set, List
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.shared.utils import send_telemetry
|
||||
|
||||
|
||||
async def code_description_to_code_part_search(query: str, user: User = None, top_k = 2) -> list:
|
||||
if user is None:
|
||||
user = await get_default_user()
|
||||
|
||||
if user is None:
|
||||
raise PermissionError("No user found in the system. Please create a user.")
|
||||
|
||||
retrieved_codeparts = await code_description_to_code_part(query, user, top_k)
|
||||
return retrieved_codeparts
|
||||
|
||||
|
||||
|
||||
async def code_description_to_code_part(
|
||||
query: str,
|
||||
user: User,
|
||||
top_k: int
|
||||
) -> List[str]:
|
||||
"""
|
||||
Maps a code description query to relevant code parts using a CodeGraph pipeline.
|
||||
|
||||
Args:
|
||||
query (str): The search query describing the code parts.
|
||||
user (User): The user performing the search.
|
||||
top_k (int): Number of codegraph descriptions to match ( num of corresponding codeparts will be higher)
|
||||
|
||||
Returns:
|
||||
Set[str]: A set of unique code parts matching the query.
|
||||
|
||||
Raises:
|
||||
ValueError: If arguments are invalid.
|
||||
RuntimeError: If an unexpected error occurs during execution.
|
||||
"""
|
||||
if not query or not isinstance(query, str):
|
||||
raise ValueError("The query must be a non-empty string.")
|
||||
if top_k <= 0 or not isinstance(top_k, int):
|
||||
raise ValueError("top_k must be a positive integer.")
|
||||
|
||||
try:
|
||||
vector_engine = get_vector_engine()
|
||||
graph_engine = await get_graph_engine()
|
||||
except Exception as init_error:
|
||||
logging.error("Failed to initialize engines: %s", init_error, exc_info=True)
|
||||
raise RuntimeError("System initialization error. Please try again later.") from init_error
|
||||
|
||||
send_telemetry("code_description_to_code_part_search EXECUTION STARTED", user.id)
|
||||
logging.info("Search initiated by user %s with query: '%s' and top_k: %d", user.id, query, top_k)
|
||||
|
||||
try:
|
||||
results = await vector_engine.search(
|
||||
"code_summary_text", query_text=query, limit=top_k
|
||||
)
|
||||
if not results:
|
||||
logging.warning("No results found for query: '%s' by user: %s", query, user.id)
|
||||
return []
|
||||
|
||||
memory_fragment = CogneeGraph()
|
||||
await memory_fragment.project_graph_from_db(
|
||||
graph_engine,
|
||||
node_properties_to_project=['id', 'type', 'text', 'source_code'],
|
||||
edge_properties_to_project=['relationship_name']
|
||||
)
|
||||
|
||||
code_pieces_to_return = set()
|
||||
|
||||
for node in results:
|
||||
node_id = str(node.id)
|
||||
node_to_search_from = memory_fragment.get_node(node_id)
|
||||
|
||||
if not node_to_search_from:
|
||||
logging.debug("Node %s not found in memory fragment graph", node_id)
|
||||
continue
|
||||
|
||||
for code_file in node_to_search_from.get_skeleton_neighbours():
|
||||
for code_file_edge in code_file.get_skeleton_edges():
|
||||
if code_file_edge.get_attribute('relationship_name') == 'contains':
|
||||
code_pieces_to_return.add(code_file_edge.get_destination_node())
|
||||
|
||||
logging.info("Search completed for user: %s, query: '%s'. Found %d code pieces.",
|
||||
user.id, query, len(code_pieces_to_return))
|
||||
|
||||
return list(code_pieces_to_return)
|
||||
|
||||
except Exception as exec_error:
|
||||
logging.error(
|
||||
"Error during code description to code part search for user: %s, query: '%s'. Error: %s",
|
||||
user.id, query, exec_error, exc_info=True
|
||||
)
|
||||
send_telemetry("code_description_to_code_part_search EXECUTION FAILED", user.id)
|
||||
raise RuntimeError("An error occurred while processing your request.") from exec_error
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
async def main():
|
||||
query = "I am looking for a class with blue eyes"
|
||||
user = None
|
||||
try:
|
||||
results = await code_description_to_code_part_search(query, user)
|
||||
print("Retrieved Code Parts:", results)
|
||||
except Exception as e:
|
||||
print(f"An error occurred: {e}")
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
|
||||
|
|
@ -14,6 +14,7 @@ class TextSummary(DataPoint):
|
|||
|
||||
|
||||
class CodeSummary(DataPoint):
|
||||
__tablename__ = "code_summary"
|
||||
text: str
|
||||
made_from: CodeFile
|
||||
|
||||
|
|
|
|||
|
|
@ -42,19 +42,6 @@ def test_add_edge_success(setup_graph):
|
|||
assert edge in node2.skeleton_edges
|
||||
|
||||
|
||||
def test_add_duplicate_edge(setup_graph):
|
||||
"""Test adding a duplicate edge raises an exception."""
|
||||
graph = setup_graph
|
||||
node1 = Node("node1")
|
||||
node2 = Node("node2")
|
||||
graph.add_node(node1)
|
||||
graph.add_node(node2)
|
||||
edge = Edge(node1, node2)
|
||||
graph.add_edge(edge)
|
||||
with pytest.raises(EntityAlreadyExistsError, match="Edge .* already exists in the graph."):
|
||||
graph.add_edge(edge)
|
||||
|
||||
|
||||
def test_get_node_success(setup_graph):
|
||||
"""Test retrieving an existing node."""
|
||||
graph = setup_graph
|
||||
|
|
|
|||
|
|
@ -0,0 +1,76 @@
|
|||
import pytest
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_code_description_to_code_part_no_results():
|
||||
"""Test that code_description_to_code_part handles no search results."""
|
||||
|
||||
mock_user = AsyncMock()
|
||||
mock_user.id = "user123"
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_vector_engine.search.return_value = []
|
||||
|
||||
with patch("cognee.modules.retrieval.description_to_codepart_search.get_vector_engine", return_value=mock_vector_engine), \
|
||||
patch("cognee.modules.retrieval.description_to_codepart_search.get_graph_engine", return_value=AsyncMock()), \
|
||||
patch("cognee.modules.retrieval.description_to_codepart_search.CogneeGraph", return_value=AsyncMock()):
|
||||
|
||||
from cognee.modules.retrieval.description_to_codepart_search import code_description_to_code_part
|
||||
result = await code_description_to_code_part("search query", mock_user, 2)
|
||||
|
||||
assert result == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_code_description_to_code_part_invalid_query():
|
||||
"""Test that code_description_to_code_part raises ValueError for invalid query."""
|
||||
|
||||
mock_user = AsyncMock()
|
||||
|
||||
with pytest.raises(ValueError, match="The query must be a non-empty string."):
|
||||
from cognee.modules.retrieval.description_to_codepart_search import code_description_to_code_part
|
||||
await code_description_to_code_part("", mock_user, 2)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_code_description_to_code_part_invalid_top_k():
|
||||
"""Test that code_description_to_code_part raises ValueError for invalid top_k."""
|
||||
|
||||
mock_user = AsyncMock()
|
||||
|
||||
with pytest.raises(ValueError, match="top_k must be a positive integer."):
|
||||
from cognee.modules.retrieval.description_to_codepart_search import code_description_to_code_part
|
||||
await code_description_to_code_part("search query", mock_user, 0)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_code_description_to_code_part_initialization_error():
|
||||
"""Test that code_description_to_code_part raises RuntimeError for engine initialization errors."""
|
||||
|
||||
mock_user = AsyncMock()
|
||||
|
||||
with patch("cognee.modules.retrieval.description_to_codepart_search.get_vector_engine", side_effect=Exception("Engine init failed")), \
|
||||
patch("cognee.modules.retrieval.description_to_codepart_search.get_graph_engine", return_value=AsyncMock()):
|
||||
|
||||
from cognee.modules.retrieval.description_to_codepart_search import code_description_to_code_part
|
||||
with pytest.raises(RuntimeError, match="System initialization error. Please try again later."):
|
||||
await code_description_to_code_part("search query", mock_user, 2)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_code_description_to_code_part_execution_error():
|
||||
"""Test that code_description_to_code_part raises RuntimeError for execution errors."""
|
||||
|
||||
mock_user = AsyncMock()
|
||||
mock_user.id = "user123"
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_vector_engine.search.side_effect = Exception("Execution error")
|
||||
|
||||
with patch("cognee.modules.retrieval.description_to_codepart_search.get_vector_engine", return_value=mock_vector_engine), \
|
||||
patch("cognee.modules.retrieval.description_to_codepart_search.get_graph_engine", return_value=AsyncMock()), \
|
||||
patch("cognee.modules.retrieval.description_to_codepart_search.CogneeGraph", return_value=AsyncMock()):
|
||||
|
||||
from cognee.modules.retrieval.description_to_codepart_search import code_description_to_code_part
|
||||
with pytest.raises(RuntimeError, match="An error occurred while processing your request."):
|
||||
await code_description_to_code_part("search query", mock_user, 2)
|
||||
Loading…
Add table
Reference in a new issue