feat: add basic retriever for SWE-bench

This commit is contained in:
hajdul88 2025-01-09 19:54:58 +01:00
parent 56cc223302
commit 9604d95ba5
3 changed files with 49 additions and 39 deletions

View file

@ -10,7 +10,7 @@ from cognee.modules.users.models import User
from cognee.shared.utils import send_telemetry from cognee.shared.utils import send_telemetry
async def code_description_to_code_part_search(query: str, user: User = None, top_k=2) -> list: async def code_description_to_code_part_search(query: str, user: User = None, top_k=5) -> list:
if user is None: if user is None:
user = await get_default_user() user = await get_default_user()
@ -55,21 +55,23 @@ async def code_description_to_code_part(query: str, user: User, top_k: int) -> L
) )
try: try:
results = await vector_engine.search("code_summary_text", query_text=query, limit=top_k) code_summaries = await vector_engine.search(
if not results: "code_summary_text", query_text=query, limit=top_k
)
if not code_summaries:
logging.warning("No results found for query: '%s' by user: %s", query, user.id) logging.warning("No results found for query: '%s' by user: %s", query, user.id)
return [] return []
memory_fragment = CogneeGraph() memory_fragment = CogneeGraph()
await memory_fragment.project_graph_from_db( await memory_fragment.project_graph_from_db(
graph_engine, graph_engine,
node_properties_to_project=["id", "type", "text", "source_code"], node_properties_to_project=["id", "type", "text", "source_code", "pydantic_type"],
edge_properties_to_project=["relationship_name"], edge_properties_to_project=["relationship_name"],
) )
code_pieces_to_return = set() code_pieces_to_return = set()
for node in results: for node in code_summaries:
node_id = str(node.id) node_id = str(node.id)
node_to_search_from = memory_fragment.get_node(node_id) node_to_search_from = memory_fragment.get_node(node_id)
@ -78,9 +80,16 @@ async def code_description_to_code_part(query: str, user: User, top_k: int) -> L
continue continue
for code_file in node_to_search_from.get_skeleton_neighbours(): for code_file in node_to_search_from.get_skeleton_neighbours():
for code_file_edge in code_file.get_skeleton_edges(): if code_file.get_attribute("pydantic_type") == "SourceCodeChunk":
if code_file_edge.get_attribute("relationship_name") == "contains": for code_file_edge in code_file.get_skeleton_edges():
code_pieces_to_return.add(code_file_edge.get_destination_node()) if code_file_edge.get_attribute("relationship_name") == "code_chunk_of":
code_pieces_to_return.add(code_file_edge.get_destination_node())
elif code_file.get_attribute("pydantic_type") == "CodePart":
code_pieces_to_return.add(code_file)
elif code_file.get_attribute("pydantic_type") == "CodeFile":
for code_file_edge in code_file.get_skeleton_edges():
if code_file_edge.get_attribute("relationship_name") == "contains":
code_pieces_to_return.add(code_file_edge.get_destination_node())
logging.info( logging.info(
"Search completed for user: %s, query: '%s'. Found %d code pieces.", "Search completed for user: %s, query: '%s'. Found %d code pieces.",
@ -89,7 +98,11 @@ async def code_description_to_code_part(query: str, user: User, top_k: int) -> L
len(code_pieces_to_return), len(code_pieces_to_return),
) )
return list(code_pieces_to_return) context = ""
for code_piece in code_pieces_to_return:
context = context + code_piece.get_attribute("source_code")
return context
except Exception as exec_error: except Exception as exec_error:
logging.error( logging.error(

View file

@ -231,6 +231,7 @@ class SummarizedContent(BaseModel):
summary: str summary: str
description: str description: str
pydantic_type: str = "SummarizedContent"
class SummarizedFunction(BaseModel): class SummarizedFunction(BaseModel):
@ -239,6 +240,7 @@ class SummarizedFunction(BaseModel):
inputs: Optional[List[str]] = None inputs: Optional[List[str]] = None
outputs: Optional[List[str]] = None outputs: Optional[List[str]] = None
decorators: Optional[List[str]] = None decorators: Optional[List[str]] = None
pydantic_type: str = "SummarizedFunction"
class SummarizedClass(BaseModel): class SummarizedClass(BaseModel):
@ -246,6 +248,7 @@ class SummarizedClass(BaseModel):
description: str description: str
methods: Optional[List[SummarizedFunction]] = None methods: Optional[List[SummarizedFunction]] = None
decorators: Optional[List[str]] = None decorators: Optional[List[str]] = None
pydantic_type: str = "SummarizedClass"
class SummarizedCode(BaseModel): class SummarizedCode(BaseModel):
@ -256,6 +259,7 @@ class SummarizedCode(BaseModel):
classes: List[SummarizedClass] = [] classes: List[SummarizedClass] = []
functions: List[SummarizedFunction] = [] functions: List[SummarizedFunction] = []
workflow_description: Optional[str] = None workflow_description: Optional[str] = None
pydantic_type: str = "SummarizedCode"
class GraphDBType(Enum): class GraphDBType(Enum):

View file

@ -11,7 +11,9 @@ from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline
from cognee.api.v1.search import SearchType from cognee.api.v1.search import SearchType
from cognee.infrastructure.llm.get_llm_client import get_llm_client from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.infrastructure.llm.prompts import read_query_prompt from cognee.infrastructure.llm.prompts import read_query_prompt
from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search from cognee.modules.retrieval.description_to_codepart_search import (
code_description_to_code_part_search,
)
from cognee.shared.utils import render_graph from cognee.shared.utils import render_graph
from evals.eval_utils import download_github_repo, retrieved_edges_to_string from evals.eval_utils import download_github_repo, retrieved_edges_to_string
@ -32,26 +34,16 @@ def check_install_package(package_name):
return False return False
async def generate_patch_with_cognee(instance, llm_client, search_type=SearchType.CHUNKS): async def generate_patch_with_cognee(instance):
repo_path = download_github_repo(instance, "../RAW_GIT_REPOS") """repo_path = download_github_repo(instance, "../RAW_GIT_REPOS")"""
async for result in run_code_graph_pipeline(repo_path, include_docs=True):
print(result)
print("Here we have the repo under the repo_path")
await render_graph(None, include_labels=True, include_nodes=True)
problem_statement = instance["problem_statement"] problem_statement = instance["problem_statement"]
instructions = read_query_prompt("patch_gen_kg_instructions.txt") instructions = read_query_prompt("patch_gen_kg_instructions.txt")
retrieved_edges = await brute_force_triplet_search( repo_path = "/Users/laszlohajdu/Documents/GitHub/test/"
problem_statement, async for result in run_code_graph_pipeline(repo_path, include_docs=False):
top_k=3, print(result)
collections=["code_summary_text"],
)
retrieved_edges_str = retrieved_edges_to_string(retrieved_edges) retrieved_codeparts = await code_description_to_code_part_search(problem_statement)
prompt = "\n".join( prompt = "\n".join(
[ [
@ -60,7 +52,7 @@ async def generate_patch_with_cognee(instance, llm_client, search_type=SearchTyp
PATCH_EXAMPLE, PATCH_EXAMPLE,
"</patch>", "</patch>",
"These are the retrieved edges:", "These are the retrieved edges:",
retrieved_edges_str, retrieved_codeparts,
] ]
) )
@ -86,8 +78,6 @@ async def generate_patch_without_cognee(instance, llm_client):
async def get_preds(dataset, with_cognee=True): async def get_preds(dataset, with_cognee=True):
llm_client = get_llm_client()
if with_cognee: if with_cognee:
model_name = "with_cognee" model_name = "with_cognee"
pred_func = generate_patch_with_cognee pred_func = generate_patch_with_cognee
@ -95,17 +85,18 @@ async def get_preds(dataset, with_cognee=True):
model_name = "without_cognee" model_name = "without_cognee"
pred_func = generate_patch_without_cognee pred_func = generate_patch_without_cognee
futures = [(instance["instance_id"], pred_func(instance, llm_client)) for instance in dataset] preds = []
model_patches = await asyncio.gather(*[x[1] for x in futures])
preds = [ for instance in dataset:
{ instance_id = instance["instance_id"]
"instance_id": instance_id, model_patch = await pred_func(instance) # Sequentially await the async function
"model_patch": model_patch, preds.append(
"model_name_or_path": model_name, {
} "instance_id": instance_id,
for (instance_id, _), model_patch in zip(futures, model_patches) "model_patch": model_patch,
] "model_name_or_path": model_name,
}
)
return preds return preds
@ -135,6 +126,7 @@ async def main():
with open(predictions_path, "w") as file: with open(predictions_path, "w") as file:
json.dump(preds, file) json.dump(preds, file)
""" This part is for the evaluation
subprocess.run( subprocess.run(
[ [
"python", "python",
@ -152,6 +144,7 @@ async def main():
"test_run", "test_run",
] ]
) )
"""
if __name__ == "__main__": if __name__ == "__main__":