diff --git a/cognee/api/v1/cognify/code_graph_pipeline.py b/cognee/api/v1/cognify/code_graph_pipeline.py index d2baa3b1c..b6215b81b 100644 --- a/cognee/api/v1/cognify/code_graph_pipeline.py +++ b/cognee/api/v1/cognify/code_graph_pipeline.py @@ -1,8 +1,11 @@ +import os +import pathlib import asyncio from cognee.shared.logging_utils import get_logger from uuid import NAMESPACE_OID, uuid5 from cognee.api.v1.search import SearchType, search +from cognee.api.v1.visualize.visualize import visualize_graph from cognee.base_config import get_base_config from cognee.modules.cognify.config import get_cognify_config from cognee.modules.pipelines import run_tasks @@ -78,10 +81,13 @@ async def run_code_graph_pipeline(repo_path, include_docs=False): if __name__ == "__main__": async def main(): - async for data_points in run_code_graph_pipeline("YOUR_REPO_PATH"): - print(data_points) + async for run_status in run_code_graph_pipeline("REPO_PATH"): + print(f"{run_status.pipeline_name}: {run_status.status}") - await render_graph() + file_path = os.path.join( + pathlib.Path(__file__).parent, ".artifacts", "graph_visualization.html" + ) + await visualize_graph(file_path) search_results = await search( query_type=SearchType.CODE, @@ -89,6 +95,6 @@ if __name__ == "__main__": ) for file in search_results: - print(file.filename) + print(file["name"]) asyncio.run(main()) diff --git a/cognee/modules/retrieval/code_retriever.py b/cognee/modules/retrieval/code_retriever.py index 27c601b60..1d0cedb63 100644 --- a/cognee/modules/retrieval/code_retriever.py +++ b/cognee/modules/retrieval/code_retriever.py @@ -1,12 +1,9 @@ -from typing import Any, Optional, List, Dict +from typing import Any, Optional, List import asyncio import aiofiles from pydantic import BaseModel -from cognee.low_level import DataPoint -from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses from cognee.modules.retrieval.base_retriever import BaseRetriever -from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search from cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.databases.vector import get_vector_engine from cognee.infrastructure.llm.get_llm_client import get_llm_client @@ -98,11 +95,11 @@ class CodeRetriever(BaseRetriever): {"id": res.id, "score": res.score, "payload": res.payload} ) - file_ids = [str(item["id"]) for item in similar_filenames] - code_ids = [str(item["id"]) for item in similar_codepieces] - relevant_triplets = await asyncio.gather( - *[graph_engine.get_connections(node_id) for node_id in code_ids + file_ids] + *[ + graph_engine.get_connections(similar_piece["id"]) + for similar_piece in similar_filenames + similar_codepieces + ] ) paths = set()