From 2408fd7a01765b9e466748d0bd2a83b185469c4a Mon Sep 17 00:00:00 2001
From: Boris Arzentar
Date: Thu, 28 Nov 2024 09:12:37 +0100
Subject: [PATCH] fix: falkordb adapter errors

---
 .../hybrid/falkordb/FalkorDBAdapter.py        | 76 ++++++++++---------
 .../embeddings/LiteLLMEmbeddingEngine.py      | 21 +++--
 .../infrastructure/engine/models/DataPoint.py |  4 +
 .../graph/utils/get_graph_from_model.py       | 18 ++---
 cognee/shared/CodeGraphEntities.py            |  1 -
 .../repo_processor/enrich_dependency_graph.py | 25 +++---
 .../repo_processor/expand_dependency_graph.py |  6 +-
 .../get_repo_file_dependencies.py             | 13 ++--
 cognee/tests/test_falkordb.py                 |  4 +-
 evals/eval_swe_bench.py                       | 11 ++-
 10 files changed, 101 insertions(+), 78 deletions(-)

diff --git a/cognee/infrastructure/databases/hybrid/falkordb/FalkorDBAdapter.py b/cognee/infrastructure/databases/hybrid/falkordb/FalkorDBAdapter.py
index bd6a2bc2d..32a9853c2 100644
--- a/cognee/infrastructure/databases/hybrid/falkordb/FalkorDBAdapter.py
+++ b/cognee/infrastructure/databases/hybrid/falkordb/FalkorDBAdapter.py
@@ -1,5 +1,6 @@
 import asyncio
 # from datetime import datetime
+import json
 from uuid import UUID
 from textwrap import dedent
 from falkordb import FalkorDB
@@ -53,28 +54,28 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
             return f"'vecf32({value})'"
         # if type(value) is datetime:
         #     return datetime.strptime(value, "%Y-%m-%dT%H:%M:%S.%f%z")
+        if type(value) is dict:
+            return f"'{json.dumps(value)}'"
         return f"'{value}'"
 
     return ",".join([f"{key}:{parse_value(value)}" for key, value in properties.items()])
 
-    async def create_data_point_query(self, data_point: DataPoint, vectorized_values: list = None):
+    async def create_data_point_query(self, data_point: DataPoint, vectorized_values: list):
         node_label = type(data_point).__tablename__
-        embeddable_fields = data_point._metadata.get("index_fields", [])
+        property_names = DataPoint.get_embeddable_property_names(data_point)
 
         node_properties = await self.stringify_properties({
             **data_point.model_dump(),
             **({
-                embeddable_fields[index]: vectorized_values[index] \
-                for index in range(len(embeddable_fields)) \
-            } if vectorized_values is not None else {}),
+                property_names[index]: vectorized_values[index] \
+                for index in range(len(property_names)) if vectorized_values[index] is not None \
+            }),
         })
 
         return dedent(f"""
             MERGE (node:{node_label} {{id: '{str(data_point.id)}'}})
-            ON CREATE SET node += ({{{node_properties}}})
-            ON CREATE SET node.updated_at = timestamp()
-            ON MATCH SET node += ({{{node_properties}}})
-            ON MATCH SET node.updated_at = timestamp()
+            ON CREATE SET node += ({{{node_properties}}}), node.updated_at = timestamp()
+            ON MATCH SET node += ({{{node_properties}}}), node.updated_at = timestamp()
         """).strip()
 
     async def create_edge_query(self, edge: tuple[str, str, str, dict]) -> str:
@@ -98,31 +99,33 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
         return collection_name in collections
 
     async def create_data_points(self, data_points: list[DataPoint]):
-        embeddable_values = [DataPoint.get_embeddable_properties(data_point) for data_point in data_points]
+        embeddable_values = []
+        vector_map = {}
 
-        vectorized_values = await self.embed_data(
-            sum(embeddable_values, [])
-        )
+        for data_point in data_points:
+            property_names = DataPoint.get_embeddable_property_names(data_point)
+            key = str(data_point.id)
+            vector_map[key] = {}
 
-        index = 0
-        positioned_vectorized_values = []
+            for property_name in property_names:
+                property_value = getattr(data_point, property_name, None)
 
-        for values in embeddable_values:
-            if len(values) > 0:
-                values_list = []
-                for i in range(len(values)):
-                    values_list.append(vectorized_values[index + i])
+                if property_value is not None:
+                    embeddable_values.append(property_value)
+                    vector_map[key][property_name] = len(embeddable_values) - 1
+                else:
+                    vector_map[key][property_name] = None
 
-                positioned_vectorized_values.append(values_list)
-                index += len(values)
-            else:
-                positioned_vectorized_values.append(None)
+        vectorized_values = await self.embed_data(embeddable_values)
 
         queries = [
             await self.create_data_point_query(
                 data_point,
-                positioned_vectorized_values[index],
-            ) for index, data_point in enumerate(data_points)
+                [
+                    vectorized_values[vector_map[str(data_point.id)][property_name]] if vector_map[str(data_point.id)][property_name] is not None else None \
+                    for property_name in DataPoint.get_embeddable_property_names(data_point)
+                ],
+            ) for data_point in data_points
         ]
 
         for query in queries:
@@ -182,18 +185,21 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
 
         return [result["edge_exists"] for result in results]
 
-    async def retrieve(self, data_point_ids: list[str]):
-        return self.query(
+    async def retrieve(self, data_point_ids: list[UUID]):
+        result = self.query(
             f"MATCH (node) WHERE node.id IN $node_ids RETURN node",
             {
-                "node_ids": data_point_ids,
+                "node_ids": [str(data_point) for data_point in data_point_ids],
             },
         )
+        return result.result_set
 
-    async def extract_node(self, data_point_id: str):
-        return await self.retrieve([data_point_id])
+    async def extract_node(self, data_point_id: UUID):
+        result = await self.retrieve([data_point_id])
+        result = result[0][0] if len(result) > 0 and len(result[0]) > 0 else None
+        return result.properties if result else None
 
-    async def extract_nodes(self, data_point_ids: list[str]):
+    async def extract_nodes(self, data_point_ids: list[UUID]):
         return await self.retrieve(data_point_ids)
 
     async def get_connections(self, node_id: UUID) -> list:
@@ -296,11 +302,11 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
 
         return (nodes, edges)
 
-    async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
+    async def delete_data_points(self, collection_name: str, data_point_ids: list[UUID]):
         return self.query(
             f"MATCH (node) WHERE node.id IN $node_ids DETACH DELETE node",
             {
-                "node_ids": data_point_ids,
+                "node_ids": [str(data_point) for data_point in data_point_ids],
             },
         )
 
@@ -324,4 +330,4 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
             print(f"Error deleting graph: {e}")
 
     async def prune(self):
-        self.delete_graph()
+        await self.delete_graph()
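Note on the embedding flow above: create_data_points now flattens every non-None embeddable property into a single embed_data batch and records, in vector_map, which position in that batch belongs to which (node id, property name) pair; create_data_point_query then receives the per-node vectors in property order, with None standing in for properties that had no value. A minimal, self-contained sketch of that bookkeeping (DataPointStub and fake_embed are illustrative stand-ins, not cognee APIs):

    class DataPointStub:
        def __init__(self, id, text = None, name = None):
            self.id, self.text, self.name = id, text, name

    def fake_embed(values):
        # stands in for self.embed_data: one "vector" per input value
        return [[float(len(value))] for value in values]

    data_points = [DataPointStub(1, text = "hello"), DataPointStub(2, name = "world")]
    property_names = ["text", "name"]

    embeddable_values, vector_map = [], {}
    for data_point in data_points:
        vector_map[str(data_point.id)] = {}
        for property_name in property_names:
            value = getattr(data_point, property_name, None)
            if value is not None:
                embeddable_values.append(value)
                vector_map[str(data_point.id)][property_name] = len(embeddable_values) - 1
            else:
                vector_map[str(data_point.id)][property_name] = None

    vectors = fake_embed(embeddable_values)  # single batched call

    for data_point in data_points:
        # route each vector back to its (node, property) slot, None where absent
        print([
            vectors[vector_map[str(data_point.id)][property_name]]
            if vector_map[str(data_point.id)][property_name] is not None else None
            for property_name in property_names
        ])

The single batched call is the point of the rewrite: one embedding request per create_data_points call instead of one per property.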
diff --git a/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py b/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py
index ab1274fb8..de30640e5 100644
--- a/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py
+++ b/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py
@@ -1,9 +1,10 @@
-import asyncio
+import logging
 from typing import List, Optional
 import litellm
 from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
 
 litellm.set_verbose = False
+logger = logging.getLogger("LiteLLMEmbeddingEngine")
 
 class LiteLLMEmbeddingEngine(EmbeddingEngine):
     api_key: str
@@ -28,13 +29,17 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
 
     async def embed_text(self, text: List[str]) -> List[List[float]]:
         async def get_embedding(text_):
-            response = await litellm.aembedding(
-                self.model,
-                input = text_,
-                api_key = self.api_key,
-                api_base = self.endpoint,
-                api_version = self.api_version
-            )
+            try:
+                response = await litellm.aembedding(
+                    self.model,
+                    input = text_,
+                    api_key = self.api_key,
+                    api_base = self.endpoint,
+                    api_version = self.api_version
+                )
+            except litellm.exceptions.BadRequestError as error:
+                logger.error("Error embedding text: %s", str(error))
+                raise
 
             return [data["embedding"] for data in response.data]
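The try/except above only logs and re-raises BadRequestError, so malformed inputs fail fast; transient failures (rate limits, timeouts) are not retried here. If retries were wanted, a wrapper along these lines could sit on top of embed_text (a sketch assuming only that embed_text raises on failure; embed_with_retries is not part of this patch):

    import asyncio
    import logging

    logger = logging.getLogger("LiteLLMEmbeddingEngine")

    async def embed_with_retries(engine, texts, attempts = 3, base_delay = 1.0):
        # retry embed_text with exponential backoff, re-raising the last error
        for attempt in range(attempts):
            try:
                return await engine.embed_text(texts)
            except Exception as error:
                if attempt == attempts - 1:
                    raise
                delay = base_delay * (2 ** attempt)
                logger.warning("Embedding failed (%s), retrying in %.1fs", error, delay)
                await asyncio.sleep(delay)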
diff --git a/cognee/infrastructure/engine/models/DataPoint.py b/cognee/infrastructure/engine/models/DataPoint.py
index 5d6b1d513..b76971f34 100644
--- a/cognee/infrastructure/engine/models/DataPoint.py
+++ b/cognee/infrastructure/engine/models/DataPoint.py
@@ -35,3 +35,7 @@ class DataPoint(BaseModel):
             return [getattr(data_point, field, None) for field in data_point._metadata["index_fields"]]
 
         return []
+
+    @classmethod
+    def get_embeddable_property_names(cls, data_point):
+        return data_point._metadata["index_fields"] or []
diff --git a/cognee/modules/graph/utils/get_graph_from_model.py b/cognee/modules/graph/utils/get_graph_from_model.py
index ec1da85e3..7bd300df1 100644
--- a/cognee/modules/graph/utils/get_graph_from_model.py
+++ b/cognee/modules/graph/utils/get_graph_from_model.py
@@ -118,17 +118,17 @@ async def get_graph_from_model(
 
         data_point_properties[field_name] = field_value
 
-    SimpleDataPointModel = copy_model(
-        type(data_point),
-        include_fields = {
-            "_metadata": (dict, data_point._metadata),
-            "__tablename__": data_point.__tablename__,
-        },
-        exclude_fields = excluded_properties,
-    )
-
     if include_root:
+        SimpleDataPointModel = copy_model(
+            type(data_point),
+            include_fields = {
+                "_metadata": (dict, data_point._metadata),
+                "__tablename__": data_point.__tablename__,
+            },
+            exclude_fields = excluded_properties,
+        )
         nodes.append(SimpleDataPointModel(**data_point_properties))
+        added_nodes[str(data_point.id)] = True
 
     return nodes, edges
diff --git a/cognee/shared/CodeGraphEntities.py b/cognee/shared/CodeGraphEntities.py
index 4811106e5..d709b8d3a 100644
--- a/cognee/shared/CodeGraphEntities.py
+++ b/cognee/shared/CodeGraphEntities.py
@@ -19,7 +19,6 @@ class CodeFile(DataPoint):
     }
 
 class CodePart(DataPoint):
-    type: str
     # part_of: Optional[CodeFile]
     source_code: str
     type: Optional[str] = "CodePart"
diff --git a/cognee/tasks/repo_processor/enrich_dependency_graph.py b/cognee/tasks/repo_processor/enrich_dependency_graph.py
index ba222ef3f..03db7b0bb 100644
--- a/cognee/tasks/repo_processor/enrich_dependency_graph.py
+++ b/cognee/tasks/repo_processor/enrich_dependency_graph.py
@@ -1,6 +1,5 @@
-import asyncio
 import networkx as nx
-from typing import Dict, List
+from typing import AsyncGenerator, Dict, List
 from tqdm.asyncio import tqdm
 from cognee.infrastructure.engine import DataPoint
 
@@ -66,20 +65,25 @@ async def node_enrich_and_connect(
         if desc_id not in topological_order[:topological_rank + 1]:
             continue
 
+        desc = None
+
         if desc_id in data_points_map:
             desc = data_points_map[desc_id]
         else:
             node_data = await graph_engine.extract_node(desc_id)
-            desc = convert_node_to_data_point(node_data)
+            try:
+                desc = convert_node_to_data_point(node_data)
+            except Exception:
+                pass
 
-        new_connections.append(desc)
+        if desc is not None:
+            new_connections.append(desc)
 
     node.depends_directly_on = node.depends_directly_on or []
     node.depends_directly_on.extend(new_connections)
 
 
-async def enrich_dependency_graph(data_points: list[DataPoint]) -> list[DataPoint]:
+async def enrich_dependency_graph(data_points: list[DataPoint]) -> AsyncGenerator[DataPoint, None]:
     """Enriches the graph with topological ranks and 'depends_on' edges."""
     nodes = []
     edges = []
@@ -108,17 +112,18 @@ async def enrich_dependency_graph(data_points: list[DataPoint]) -> list[DataPoin
     #     data_points.append(node_enrich_and_connect(graph, topological_order, node))
 
     data_points_map = {data_point.id: data_point for data_point in data_points}
-    data_points_futures = []
+    # data_points_futures = []
 
     for data_point in tqdm(data_points, desc = "Enriching dependency graph", unit = "data_point"):
        if data_point.id not in node_rank_map:
            continue

        if isinstance(data_point, CodeFile):
-            data_points_futures.append(node_enrich_and_connect(graph, topological_order, data_point, data_points_map))
+            # data_points_futures.append(node_enrich_and_connect(graph, topological_order, data_point, data_points_map))
+            await node_enrich_and_connect(graph, topological_order, data_point, data_points_map)
 
-        # yield data_point
+        yield data_point
 
-    await asyncio.gather(*data_points_futures)
+    # await asyncio.gather(*data_points_futures)
 
-    return data_points
+    # return data_points
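For context on node_enrich_and_connect: it only links a CodeFile to descendants that appear at or before its own position in a topological order of the dependency graph, which is what the topological_order[:topological_rank + 1] slice enforces. A toy illustration of the rank computation with networkx (the graph contents are made up):

    import networkx as nx

    # a.py depends on b.py, which depends on c.py
    graph = nx.DiGraph([("a.py", "b.py"), ("b.py", "c.py")])

    topological_order = list(nx.topological_sort(graph))
    print(topological_order)  # ['a.py', 'b.py', 'c.py']

    # rank of a node = its index in the topological order
    rank = {node: index for index, node in enumerate(topological_order)}
    print(rank["b.py"])  # 1 -- only 'a.py' and 'b.py' fall within order[:2]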
diff --git a/cognee/tasks/repo_processor/expand_dependency_graph.py b/cognee/tasks/repo_processor/expand_dependency_graph.py
index 722bfa5c6..43a451bd6 100644
--- a/cognee/tasks/repo_processor/expand_dependency_graph.py
+++ b/cognee/tasks/repo_processor/expand_dependency_graph.py
@@ -1,3 +1,4 @@
+from typing import AsyncGenerator
 from uuid import NAMESPACE_OID, uuid5
 # from tqdm import tqdm
 from cognee.infrastructure.engine import DataPoint
@@ -53,11 +54,12 @@ def _process_single_node(code_file: CodeFile) -> None:
         _add_code_parts_nodes_and_edges(code_file, part_type, code_parts)
 
 
-async def expand_dependency_graph(data_points: list[DataPoint]) -> list[DataPoint]:
+async def expand_dependency_graph(data_points: list[DataPoint]) -> AsyncGenerator[DataPoint, None]:
     """Process Python file nodes, adding code part nodes and edges."""
     # for data_point in tqdm(data_points, desc = "Expand dependency graph", unit = "data_point"):
     for data_point in data_points:
         if isinstance(data_point, CodeFile):
             _process_single_node(data_point)
+        yield data_point
 
-    return data_points
+    # return data_points
diff --git a/cognee/tasks/repo_processor/get_repo_file_dependencies.py b/cognee/tasks/repo_processor/get_repo_file_dependencies.py
index 58f3857a9..9ac4e9f2e 100644
--- a/cognee/tasks/repo_processor/get_repo_file_dependencies.py
+++ b/cognee/tasks/repo_processor/get_repo_file_dependencies.py
@@ -1,4 +1,5 @@
 import os
+from typing import AsyncGenerator
 from uuid import NAMESPACE_OID, uuid5
 import aiofiles
 from tqdm.asyncio import tqdm
@@ -44,7 +45,7 @@ def get_edge(file_path: str, dependency: str, repo_path: str, relative_paths: bo
     return (file_path, dependency, {"relation": "depends_directly_on"})
 
 
-async def get_repo_file_dependencies(repo_path: str) -> list[DataPoint]:
+async def get_repo_file_dependencies(repo_path: str) -> AsyncGenerator[DataPoint, None]:
     """Generate a dependency graph for Python files in the given repository path."""
     py_files_dict = await get_py_files_dict(repo_path)
 
@@ -53,7 +54,8 @@ async def get_repo_file_dependencies(repo_path: str) -> list[DataPoint]:
         path = repo_path,
     )
 
-    data_points = [repo]
+    # data_points = [repo]
+    yield repo
 
     # dependency_graph = nx.DiGraph()
 
@@ -66,7 +68,8 @@ async def get_repo_file_dependencies(repo_path: str) -> list[DataPoint]:
 
         dependencies = await get_local_script_dependencies(os.path.join(repo_path, file_path), repo_path)
 
-        data_points.append(CodeFile(
+        # data_points.append()
+        yield CodeFile(
             id = uuid5(NAMESPACE_OID, file_path),
             source_code = source_code,
             extracted_id = file_path,
@@ -78,10 +81,10 @@ async def get_repo_file_dependencies(repo_path: str) -> list[DataPoint]:
             part_of = repo,
                 ) for dependency in dependencies
             ] if len(dependencies) else None,
-        ))
+        )
 
     #     dependency_edges = [get_edge(file_path, dependency, repo_path) for dependency in dependencies]
     #     dependency_graph.add_edges_from(dependency_edges)
 
-    return data_points
+    # return data_points
 
     # return dependency_graph
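One property of the generator above worth calling out: CodeFile ids come from uuid5(NAMESPACE_OID, file_path), which is deterministic, so re-running the pipeline on the same repository produces the same node ids and the adapter's MERGE queries update existing nodes instead of duplicating them. A quick demonstration:

    from uuid import NAMESPACE_OID, uuid5

    # the same path always maps to the same UUID
    print(uuid5(NAMESPACE_OID, "cognee/tasks/repo_processor/get_repo_file_dependencies.py"))
    print(uuid5(NAMESPACE_OID, "cognee/tasks/repo_processor/get_repo_file_dependencies.py"))  # identical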
diff --git a/cognee/tests/test_falkordb.py b/cognee/tests/test_falkordb.py
index 36c029cf7..25fe81a75 100755
--- a/cognee/tests/test_falkordb.py
+++ b/cognee/tests/test_falkordb.py
@@ -8,9 +8,9 @@ from cognee.shared.utils import render_graph
 logging.basicConfig(level = logging.DEBUG)
 
 async def main():
-    data_directory_path = str(pathlib.Path(os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_library")).resolve())
+    data_directory_path = str(pathlib.Path(os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_falkordb")).resolve())
     cognee.config.data_root_directory(data_directory_path)
-    cognee_directory_path = str(pathlib.Path(os.path.join(pathlib.Path(__file__).parent, ".cognee_system/test_library")).resolve())
+    cognee_directory_path = str(pathlib.Path(os.path.join(pathlib.Path(__file__).parent, ".cognee_system/test_falkordb")).resolve())
     cognee.config.system_root_directory(cognee_directory_path)
 
     await cognee.prune.prune_data()
diff --git a/evals/eval_swe_bench.py b/evals/eval_swe_bench.py
index 0a4806e3f..1dd0e58ab 100644
--- a/evals/eval_swe_bench.py
+++ b/evals/eval_swe_bench.py
@@ -30,7 +30,6 @@ from evals.eval_utils import download_github_repo
 from evals.eval_utils import delete_repo
 
 async def generate_patch_with_cognee(instance):
-
     await cognee.prune.prune_data()
     await cognee.prune.prune_system()
 
@@ -44,10 +43,10 @@ async def generate_patch_with_cognee(instance):
 
     tasks = [
         Task(get_repo_file_dependencies),
-        Task(add_data_points),
-        Task(enrich_dependency_graph),
-        Task(expand_dependency_graph),
-        Task(add_data_points),
+        Task(add_data_points, task_config = { "batch_size": 50 }),
+        Task(enrich_dependency_graph, task_config = { "batch_size": 50 }),
+        Task(expand_dependency_graph, task_config = { "batch_size": 50 }),
+        Task(add_data_points, task_config = { "batch_size": 50 }),
         # Task(summarize_code, summarization_model = SummarizedContent),
     ]
 
@@ -58,7 +57,7 @@ async def generate_patch_with_cognee(instance):
 
     print('Here we have the repo under the repo_path')
 
-    await render_graph()
+    await render_graph(None, include_labels = True, include_nodes = True)
 
     problem_statement = instance['problem_statement']
     instructions = read_query_prompt("patch_gen_instructions.txt")
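The batch_size task_config added above pairs with the async-generator conversions earlier in this patch: each task now streams DataPoints, and the pipeline can collect them into fixed-size batches before invoking the next task. A rough sketch of what such a runner might look like (illustrative only; cognee's actual Task runner internals are not shown in this patch):

    import asyncio

    async def run_batched(source, task, batch_size):
        # collect batch_size items from one async generator,
        # hand each batch to the next task, stream its results onward
        batch = []
        async for item in source:
            batch.append(item)
            if len(batch) == batch_size:
                async for result in task(batch):
                    yield result
                batch = []
        if batch:  # flush the final partial batch
            async for result in task(batch):
                yield result

    async def produce():
        for item in range(5):
            yield item

    async def double(items):
        for item in items:
            yield item * 2

    async def main():
        async for result in run_batched(produce(), double, batch_size = 2):
            print(result)  # 0, 2, 4, 6, 8

    asyncio.run(main())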