Feature/cog 539 implementing additional retriever approaches (#262)
* fix: refactor get_graph_from_model to return nodes and edges correctly
* fix: add missing params
* fix: remove complex zip usage
* fix: add edges to data_point properties
* fix: handle rate limit error coming from llm model
* fix: fixes lost edges and nodes in get_graph_from_model
* fix: fixes database pruning issue in pgvector
* fix: fixes database pruning issue in pgvector (#261)
* feat: adds code summary embeddings to vector DB
* fix: cognee_demo notebook pipeline is not saving summaries
* feat: implements first version of codegraph retriever
* chore: implements minor changes mostly to make the code production ready
* fix: turns off raising duplicated edges unit test as we have these in our current codegraph generation
* feat: implements unit tests for description to codepart search
* fix: fixes edge property inconsistent access in codepart retriever
* chore: implements more precise typing for get_attribute method for cogneegraph
* chore: adds spacing to tests and changes the cogneegraph getter names
---------
Co-authored-by: Boris Arzentar <borisarzentar@gmail.com>
This commit is contained in:
parent
5ffbebdd01
commit
6d85165189
6 changed files with 208 additions and 16 deletions
|
|
@ -40,7 +40,7 @@ class CogneeGraph(CogneeAbstractGraph):
|
||||||
edge.node1.add_skeleton_edge(edge)
|
edge.node1.add_skeleton_edge(edge)
|
||||||
edge.node2.add_skeleton_edge(edge)
|
edge.node2.add_skeleton_edge(edge)
|
||||||
else:
|
else:
|
||||||
raise EntityAlreadyExistsError(message=f"Edge {edge} already exists in the graph.")
|
print(f"Edge {edge} already exists in the graph.")
|
||||||
|
|
||||||
def get_node(self, node_id: str) -> Node:
|
def get_node(self, node_id: str) -> Node:
|
||||||
return self.nodes.get(node_id, None)
|
return self.nodes.get(node_id, None)
|
||||||
|
|
|
||||||
|
|
@ -65,6 +65,12 @@ class Node:
|
||||||
def get_attribute(self, key: str) -> Union[str, int, float]:
|
def get_attribute(self, key: str) -> Union[str, int, float]:
|
||||||
return self.attributes[key]
|
return self.attributes[key]
|
||||||
|
|
||||||
|
def get_skeleton_edges(self):
|
||||||
|
return self.skeleton_edges
|
||||||
|
|
||||||
|
def get_skeleton_neighbours(self):
|
||||||
|
return self.skeleton_neighbours
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"Node({self.id}, attributes={self.attributes})"
|
return f"Node({self.id}, attributes={self.attributes})"
|
||||||
|
|
||||||
|
|
@ -109,8 +115,14 @@ class Edge:
|
||||||
def add_attribute(self, key: str, value: Any) -> None:
|
def add_attribute(self, key: str, value: Any) -> None:
|
||||||
self.attributes[key] = value
|
self.attributes[key] = value
|
||||||
|
|
||||||
def get_attribute(self, key: str, value: Any) -> Union[str, int, float]:
|
def get_attribute(self, key: str) -> Optional[Union[str, int, float]]:
|
||||||
return self.attributes[key]
|
return self.attributes.get(key)
|
||||||
|
|
||||||
|
def get_source_node(self):
|
||||||
|
return self.node1
|
||||||
|
|
||||||
|
def get_destination_node(self):
|
||||||
|
return self.node2
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
direction = "->" if self.directed else "--"
|
direction = "->" if self.directed else "--"
|
||||||
|
|
|
||||||
116
cognee/modules/retrieval/description_to_codepart_search.py
Normal file
116
cognee/modules/retrieval/description_to_codepart_search.py
Normal file
|
|
@ -0,0 +1,116 @@
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from typing import Set, List
|
||||||
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
|
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
||||||
|
from cognee.modules.users.methods import get_default_user
|
||||||
|
from cognee.modules.users.models import User
|
||||||
|
from cognee.shared.utils import send_telemetry
|
||||||
|
|
||||||
|
|
||||||
|
async def code_description_to_code_part_search(query: str, user: User = None, top_k = 2) -> list:
    """
    Entry point for the description-to-code-part search.

    Falls back to the system default user when no user is supplied, then
    delegates the actual lookup to ``code_description_to_code_part``.

    Args:
        query (str): The search query describing the code parts.
        user (User): The user performing the search; the default user is
            looked up when this is None.
        top_k: Number of code summaries to match, forwarded unchanged.

    Returns:
        list: Whatever ``code_description_to_code_part`` returns.

    Raises:
        PermissionError: If no user was given and no default user exists.
    """
    # Only hit the user store when the caller did not supply a user.
    acting_user = user if user is not None else await get_default_user()

    if acting_user is None:
        raise PermissionError("No user found in the system. Please create a user.")

    return await code_description_to_code_part(query, acting_user, top_k)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
async def code_description_to_code_part(
    query: str,
    user: User,
    top_k: int
) -> list:
    """
    Maps a code description query to relevant code parts using a CodeGraph pipeline.

    The query is searched in the "code_summary_text" vector collection. For
    every hit that is also present in the projected graph, the edges of each
    neighbouring node are scanned and the destination node of every edge whose
    ``relationship_name`` is ``"contains"`` is collected.

    Args:
        query (str): The search query describing the code parts.
        user (User): The user performing the search. Only ``user.id`` is read
            here (for telemetry and logging).
        top_k (int): Number of code summaries to match (the number of
            returned code parts may differ from this).

    Returns:
        list: The de-duplicated destination nodes of the matching "contains"
            edges, in no particular order. An empty list is returned when the
            vector search yields no results.

    Raises:
        ValueError: If ``query`` is not a non-empty string or ``top_k`` is not
            a positive integer.
        RuntimeError: If the engines cannot be initialised or an unexpected
            error occurs during execution.
    """
    if not query or not isinstance(query, str):
        raise ValueError("The query must be a non-empty string.")
    # Check the type before comparing: a non-numeric top_k (e.g. "3" or None)
    # must surface as the documented ValueError, not as a TypeError from `<=`.
    if not isinstance(top_k, int) or top_k <= 0:
        raise ValueError("top_k must be a positive integer.")

    try:
        vector_engine = get_vector_engine()
        graph_engine = await get_graph_engine()
    except Exception as init_error:
        logging.error("Failed to initialize engines: %s", init_error, exc_info=True)
        raise RuntimeError("System initialization error. Please try again later.") from init_error

    send_telemetry("code_description_to_code_part_search EXECUTION STARTED", user.id)
    logging.info("Search initiated by user %s with query: '%s' and top_k: %d", user.id, query, top_k)

    try:
        results = await vector_engine.search(
            "code_summary_text", query_text=query, limit=top_k
        )
        if not results:
            logging.warning("No results found for query: '%s' by user: %s", query, user.id)
            return []

        # Project only the properties this traversal needs.
        memory_fragment = CogneeGraph()
        await memory_fragment.project_graph_from_db(
            graph_engine,
            node_properties_to_project=['id', 'type', 'text', 'source_code'],
            edge_properties_to_project=['relationship_name']
        )

        # A set de-duplicates code parts reachable from several summaries.
        code_pieces_to_return = set()

        for node in results:
            node_id = str(node.id)
            node_to_search_from = memory_fragment.get_node(node_id)

            if not node_to_search_from:
                # The vector hit has no counterpart in the projected graph.
                logging.debug("Node %s not found in memory fragment graph", node_id)
                continue

            for code_file in node_to_search_from.get_skeleton_neighbours():
                for code_file_edge in code_file.get_skeleton_edges():
                    if code_file_edge.get_attribute('relationship_name') == 'contains':
                        code_pieces_to_return.add(code_file_edge.get_destination_node())

        logging.info("Search completed for user: %s, query: '%s'. Found %d code pieces.",
                     user.id, query, len(code_pieces_to_return))

        return list(code_pieces_to_return)

    except Exception as exec_error:
        logging.error(
            "Error during code description to code part search for user: %s, query: '%s'. Error: %s",
            user.id, query, exec_error, exc_info=True
        )
        send_telemetry("code_description_to_code_part_search EXECUTION FAILED", user.id)
        raise RuntimeError("An error occurred while processing your request.") from exec_error
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    async def main():
        """Run one sample search as the default user and print the outcome."""
        sample_query = "I am looking for a class with blue eyes"
        requester = None  # None makes the search fall back to the default user

        try:
            code_parts = await code_description_to_code_part_search(sample_query, requester)
            print("Retrieved Code Parts:", code_parts)
        except Exception as e:
            # Manual smoke run: report the failure instead of a traceback.
            print(f"An error occurred: {e}")

    asyncio.run(main())
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -14,6 +14,7 @@ class TextSummary(DataPoint):
|
||||||
|
|
||||||
|
|
||||||
class CodeSummary(DataPoint):
|
class CodeSummary(DataPoint):
|
||||||
|
__tablename__ = "code_summary"
|
||||||
text: str
|
text: str
|
||||||
made_from: CodeFile
|
made_from: CodeFile
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -42,19 +42,6 @@ def test_add_edge_success(setup_graph):
|
||||||
assert edge in node2.skeleton_edges
|
assert edge in node2.skeleton_edges
|
||||||
|
|
||||||
|
|
||||||
def test_add_duplicate_edge(setup_graph):
|
|
||||||
"""Test adding a duplicate edge raises an exception."""
|
|
||||||
graph = setup_graph
|
|
||||||
node1 = Node("node1")
|
|
||||||
node2 = Node("node2")
|
|
||||||
graph.add_node(node1)
|
|
||||||
graph.add_node(node2)
|
|
||||||
edge = Edge(node1, node2)
|
|
||||||
graph.add_edge(edge)
|
|
||||||
with pytest.raises(EntityAlreadyExistsError, match="Edge .* already exists in the graph."):
|
|
||||||
graph.add_edge(edge)
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_node_success(setup_graph):
|
def test_get_node_success(setup_graph):
|
||||||
"""Test retrieving an existing node."""
|
"""Test retrieving an existing node."""
|
||||||
graph = setup_graph
|
graph = setup_graph
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,76 @@
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_code_description_to_code_part_no_results():
    """An empty vector-search result must produce an empty list."""
    from cognee.modules.retrieval.description_to_codepart_search import code_description_to_code_part

    requester = AsyncMock()
    requester.id = "user123"

    # Vector engine whose search finds nothing.
    empty_vector_engine = AsyncMock()
    empty_vector_engine.search.return_value = []

    with patch("cognee.modules.retrieval.description_to_codepart_search.get_vector_engine", return_value=empty_vector_engine):
        with patch("cognee.modules.retrieval.description_to_codepart_search.get_graph_engine", return_value=AsyncMock()):
            with patch("cognee.modules.retrieval.description_to_codepart_search.CogneeGraph", return_value=AsyncMock()):
                found = await code_description_to_code_part("search query", requester, 2)

    assert found == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_code_description_to_code_part_invalid_query():
    """An empty query string must be rejected with ValueError."""
    from cognee.modules.retrieval.description_to_codepart_search import code_description_to_code_part

    requester = AsyncMock()

    with pytest.raises(ValueError, match="The query must be a non-empty string."):
        await code_description_to_code_part("", requester, 2)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_code_description_to_code_part_invalid_top_k():
    """A top_k of zero must be rejected with ValueError."""
    from cognee.modules.retrieval.description_to_codepart_search import code_description_to_code_part

    requester = AsyncMock()

    with pytest.raises(ValueError, match="top_k must be a positive integer."):
        await code_description_to_code_part("search query", requester, 0)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_code_description_to_code_part_initialization_error():
    """A failure while creating the engines must surface as RuntimeError."""
    from cognee.modules.retrieval.description_to_codepart_search import code_description_to_code_part

    requester = AsyncMock()

    # get_vector_engine itself blows up, before any search is attempted.
    with patch("cognee.modules.retrieval.description_to_codepart_search.get_vector_engine", side_effect=Exception("Engine init failed")):
        with patch("cognee.modules.retrieval.description_to_codepart_search.get_graph_engine", return_value=AsyncMock()):
            with pytest.raises(RuntimeError, match="System initialization error. Please try again later."):
                await code_description_to_code_part("search query", requester, 2)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_code_description_to_code_part_execution_error():
    """An exception raised by the vector search must surface as RuntimeError."""
    from cognee.modules.retrieval.description_to_codepart_search import code_description_to_code_part

    requester = AsyncMock()
    requester.id = "user123"

    # Vector engine whose search call fails.
    failing_vector_engine = AsyncMock()
    failing_vector_engine.search.side_effect = Exception("Execution error")

    with patch("cognee.modules.retrieval.description_to_codepart_search.get_vector_engine", return_value=failing_vector_engine):
        with patch("cognee.modules.retrieval.description_to_codepart_search.get_graph_engine", return_value=AsyncMock()):
            with patch("cognee.modules.retrieval.description_to_codepart_search.CogneeGraph", return_value=AsyncMock()):
                with pytest.raises(RuntimeError, match="An error occurred while processing your request."):
                    await code_description_to_code_part("search query", requester, 2)
|
||||||
Loading…
Add table
Reference in a new issue