diff --git a/cognee/modules/retrieval/description_to_codepart_search.py b/cognee/modules/retrieval/description_to_codepart_search.py index ecd187907..fec17fb16 100644 --- a/cognee/modules/retrieval/description_to_codepart_search.py +++ b/cognee/modules/retrieval/description_to_codepart_search.py @@ -10,7 +10,7 @@ from cognee.modules.users.models import User from cognee.shared.utils import send_telemetry -async def code_description_to_code_part_search(query: str, user: User = None, top_k=2) -> list: +async def code_description_to_code_part_search(query: str, user: User = None, top_k=5) -> list: if user is None: user = await get_default_user() @@ -55,21 +55,23 @@ async def code_description_to_code_part(query: str, user: User, top_k: int) -> L ) try: - results = await vector_engine.search("code_summary_text", query_text=query, limit=top_k) - if not results: + code_summaries = await vector_engine.search( + "code_summary_text", query_text=query, limit=top_k + ) + if not code_summaries: logging.warning("No results found for query: '%s' by user: %s", query, user.id) return [] memory_fragment = CogneeGraph() await memory_fragment.project_graph_from_db( graph_engine, - node_properties_to_project=["id", "type", "text", "source_code"], + node_properties_to_project=["id", "type", "text", "source_code", "pydantic_type"], edge_properties_to_project=["relationship_name"], ) code_pieces_to_return = set() - for node in results: + for node in code_summaries: node_id = str(node.id) node_to_search_from = memory_fragment.get_node(node_id) @@ -78,9 +80,16 @@ async def code_description_to_code_part(query: str, user: User, top_k: int) -> L continue for code_file in node_to_search_from.get_skeleton_neighbours(): - for code_file_edge in code_file.get_skeleton_edges(): - if code_file_edge.get_attribute("relationship_name") == "contains": - code_pieces_to_return.add(code_file_edge.get_destination_node()) + if code_file.get_attribute("pydantic_type") == "SourceCodeChunk": + for code_file_edge in code_file.get_skeleton_edges(): + if code_file_edge.get_attribute("relationship_name") == "code_chunk_of": + code_pieces_to_return.add(code_file_edge.get_destination_node()) + elif code_file.get_attribute("pydantic_type") == "CodePart": + code_pieces_to_return.add(code_file) + elif code_file.get_attribute("pydantic_type") == "CodeFile": + for code_file_edge in code_file.get_skeleton_edges(): + if code_file_edge.get_attribute("relationship_name") == "contains": + code_pieces_to_return.add(code_file_edge.get_destination_node()) logging.info( "Search completed for user: %s, query: '%s'. Found %d code pieces.", @@ -89,7 +98,11 @@ async def code_description_to_code_part(query: str, user: User, top_k: int) -> L len(code_pieces_to_return), ) - return list(code_pieces_to_return) + context = "" + for code_piece in code_pieces_to_return: + context = context + code_piece.get_attribute("source_code") + + return context except Exception as exec_error: logging.error( diff --git a/cognee/shared/data_models.py b/cognee/shared/data_models.py index d23d2841c..a36a09010 100644 --- a/cognee/shared/data_models.py +++ b/cognee/shared/data_models.py @@ -231,6 +231,7 @@ class SummarizedContent(BaseModel): summary: str description: str + pydantic_type: str = "SummarizedContent" class SummarizedFunction(BaseModel): @@ -239,6 +240,7 @@ class SummarizedFunction(BaseModel): inputs: Optional[List[str]] = None outputs: Optional[List[str]] = None decorators: Optional[List[str]] = None + pydantic_type: str = "SummarizedFunction" class SummarizedClass(BaseModel): @@ -246,6 +248,7 @@ class SummarizedClass(BaseModel): description: str methods: Optional[List[SummarizedFunction]] = None decorators: Optional[List[str]] = None + pydantic_type: str = "SummarizedClass" class SummarizedCode(BaseModel): @@ -256,6 +259,7 @@ class SummarizedCode(BaseModel): classes: List[SummarizedClass] = [] functions: List[SummarizedFunction] = [] workflow_description: Optional[str] = None + pydantic_type: str = "SummarizedCode" class GraphDBType(Enum): diff --git a/evals/eval_swe_bench.py b/evals/eval_swe_bench.py index 20e005751..b5fcc616b 100644 --- a/evals/eval_swe_bench.py +++ b/evals/eval_swe_bench.py @@ -11,7 +11,9 @@ from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline from cognee.api.v1.search import SearchType from cognee.infrastructure.llm.get_llm_client import get_llm_client from cognee.infrastructure.llm.prompts import read_query_prompt -from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search +from cognee.modules.retrieval.description_to_codepart_search import ( + code_description_to_code_part_search, +) from cognee.shared.utils import render_graph from evals.eval_utils import download_github_repo, retrieved_edges_to_string @@ -32,26 +34,16 @@ def check_install_package(package_name): return False -async def generate_patch_with_cognee(instance, llm_client, search_type=SearchType.CHUNKS): - repo_path = download_github_repo(instance, "../RAW_GIT_REPOS") - - async for result in run_code_graph_pipeline(repo_path, include_docs=True): - print(result) - - print("Here we have the repo under the repo_path") - - await render_graph(None, include_labels=True, include_nodes=True) - +async def generate_patch_with_cognee(instance): + """repo_path = download_github_repo(instance, "../RAW_GIT_REPOS")""" problem_statement = instance["problem_statement"] instructions = read_query_prompt("patch_gen_kg_instructions.txt") - retrieved_edges = await brute_force_triplet_search( - problem_statement, - top_k=3, - collections=["code_summary_text"], - ) + repo_path = "/Users/laszlohajdu/Documents/GitHub/test/" + async for result in run_code_graph_pipeline(repo_path, include_docs=False): + print(result) - retrieved_edges_str = retrieved_edges_to_string(retrieved_edges) + retrieved_codeparts = await code_description_to_code_part_search(problem_statement) prompt = "\n".join( [ @@ -60,7 +52,7 @@ async def generate_patch_with_cognee(instance, llm_client, search_type=SearchTyp PATCH_EXAMPLE, "", "These are the retrieved edges:", - retrieved_edges_str, + retrieved_codeparts, ] ) @@ -86,8 +78,6 @@ async def generate_patch_without_cognee(instance, llm_client): async def get_preds(dataset, with_cognee=True): - llm_client = get_llm_client() - if with_cognee: model_name = "with_cognee" pred_func = generate_patch_with_cognee @@ -95,17 +85,18 @@ async def get_preds(dataset, with_cognee=True): model_name = "without_cognee" pred_func = generate_patch_without_cognee - futures = [(instance["instance_id"], pred_func(instance, llm_client)) for instance in dataset] - model_patches = await asyncio.gather(*[x[1] for x in futures]) + preds = [] - preds = [ - { - "instance_id": instance_id, - "model_patch": model_patch, - "model_name_or_path": model_name, - } - for (instance_id, _), model_patch in zip(futures, model_patches) - ] + for instance in dataset: + instance_id = instance["instance_id"] + model_patch = await pred_func(instance) # Sequentially await the async function + preds.append( + { + "instance_id": instance_id, + "model_patch": model_patch, + "model_name_or_path": model_name, + } + ) return preds @@ -135,6 +126,7 @@ async def main(): with open(predictions_path, "w") as file: json.dump(preds, file) + """ This part is for the evaluation subprocess.run( [ "python", @@ -152,6 +144,7 @@ async def main(): "test_run", ] ) + """ if __name__ == "__main__":