fix: falkordb adapter errors

This commit is contained in:
Boris Arzentar 2024-11-28 09:12:37 +01:00
parent 6403d15a76
commit 2408fd7a01
10 changed files with 101 additions and 78 deletions

View file

@ -1,5 +1,6 @@
import asyncio import asyncio
# from datetime import datetime # from datetime import datetime
import json
from uuid import UUID from uuid import UUID
from textwrap import dedent from textwrap import dedent
from falkordb import FalkorDB from falkordb import FalkorDB
@ -53,28 +54,28 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
return f"'vecf32({value})'" return f"'vecf32({value})'"
# if type(value) is datetime: # if type(value) is datetime:
# return datetime.strptime(value, "%Y-%m-%dT%H:%M:%S.%f%z") # return datetime.strptime(value, "%Y-%m-%dT%H:%M:%S.%f%z")
if type(value) is dict:
return f"'{json.dumps(value)}'"
return f"'{value}'" return f"'{value}'"
return ",".join([f"{key}:{parse_value(value)}" for key, value in properties.items()]) return ",".join([f"{key}:{parse_value(value)}" for key, value in properties.items()])
async def create_data_point_query(self, data_point: DataPoint, vectorized_values: dict = None):
    """Build the Cypher MERGE statement that upserts ``data_point`` as a node.

    The node label is the data point's ``__tablename__``. Its properties are
    the model dump, with every embeddable property overwritten by the
    matching entry of ``vectorized_values`` (or ``None`` when there is none).

    Args:
        data_point: The data point to upsert.
        vectorized_values: Embeddings for the embeddable properties, in the
            order returned by ``DataPoint.get_embeddable_property_names``.
            Either a mapping keyed by position or a positional sequence is
            accepted; ``create_data_points`` passes a list.

    Returns:
        The query text; both the create and the match branch set the
        properties and refresh ``updated_at``.
    """
    node_label = type(data_point).__tablename__
    property_names = DataPoint.get_embeddable_property_names(data_point)

    def vector_at(position):
        # Mapping input: a missing position means "no embedding".
        if isinstance(vectorized_values, dict):
            return vectorized_values.get(position)
        # Sequence input: `position in sequence` would test *values*, not
        # indices, so bounds-check explicitly instead.
        if vectorized_values is not None and position < len(vectorized_values):
            return vectorized_values[position]
        return None

    node_properties = await self.stringify_properties({
        **data_point.model_dump(),
        **{
            property_name: vector_at(position)
            for position, property_name in enumerate(property_names)
        },
    })

    return dedent(f"""
        MERGE (node:{node_label} {{id: '{str(data_point.id)}'}})
        ON CREATE SET node += ({{{node_properties}}}), node.updated_at = timestamp()
        ON MATCH SET node += ({{{node_properties}}}), node.updated_at = timestamp()
    """).strip()
async def create_edge_query(self, edge: tuple[str, str, str, dict]) -> str: async def create_edge_query(self, edge: tuple[str, str, str, dict]) -> str:
@ -98,31 +99,33 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
return collection_name in collections return collection_name in collections
async def create_data_points(self, data_points: list[DataPoint]): async def create_data_points(self, data_points: list[DataPoint]):
embeddable_values = [DataPoint.get_embeddable_properties(data_point) for data_point in data_points] embeddable_values = []
vector_map = {}
vectorized_values = await self.embed_data( for data_point in data_points:
sum(embeddable_values, []) property_names = DataPoint.get_embeddable_property_names(data_point)
) key = str(data_point.id)
vector_map[key] = {}
index = 0 for property_name in property_names:
positioned_vectorized_values = [] property_value = getattr(data_point, property_name, None)
for values in embeddable_values: if property_value is not None:
if len(values) > 0: embeddable_values.append(property_value)
values_list = [] vector_map[key][property_name] = len(embeddable_values) - 1
for i in range(len(values)): else:
values_list.append(vectorized_values[index + i]) vector_map[key][property_name] = None
positioned_vectorized_values.append(values_list) vectorized_values = await self.embed_data(embeddable_values)
index += len(values)
else:
positioned_vectorized_values.append(None)
queries = [ queries = [
await self.create_data_point_query( await self.create_data_point_query(
data_point, data_point,
positioned_vectorized_values[index], [
) for index, data_point in enumerate(data_points) vectorized_values[vector_map[str(data_point.id)][property_name]] \
for property_name in DataPoint.get_embeddable_property_names(data_point)
],
) for data_point in data_points
] ]
for query in queries: for query in queries:
@ -182,18 +185,21 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
return [result["edge_exists"] for result in results] return [result["edge_exists"] for result in results]
async def retrieve(self, data_point_ids: list[str]): async def retrieve(self, data_point_ids: list[UUID]):
return self.query( result = self.query(
f"MATCH (node) WHERE node.id IN $node_ids RETURN node", f"MATCH (node) WHERE node.id IN $node_ids RETURN node",
{ {
"node_ids": data_point_ids, "node_ids": [str(data_point) for data_point in data_point_ids],
}, },
) )
return result.result_set
async def extract_node(self, data_point_id: UUID):
    """Return the properties of the node with the given id, or ``None``.

    Args:
        data_point_id: Id of the node to look up.

    Returns:
        The ``properties`` attribute of the first node in the first result
        row, or ``None`` when the lookup produced no rows, an empty row, or
        a falsy node.
    """
    rows = await self.retrieve([data_point_id])

    # No match yields no rows at all, so guard before indexing; the previous
    # `result[0]` raised IndexError here instead of reporting "not found".
    if not rows or not rows[0]:
        return None

    node = rows[0][0]
    return node.properties if node else None
async def extract_nodes(self, data_point_ids: list[str]): async def extract_nodes(self, data_point_ids: list[UUID]):
return await self.retrieve(data_point_ids) return await self.retrieve(data_point_ids)
async def get_connections(self, node_id: UUID) -> list: async def get_connections(self, node_id: UUID) -> list:
@ -296,11 +302,11 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
return (nodes, edges) return (nodes, edges)
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]): async def delete_data_points(self, collection_name: str, data_point_ids: list[UUID]):
return self.query( return self.query(
f"MATCH (node) WHERE node.id IN $node_ids DETACH DELETE node", f"MATCH (node) WHERE node.id IN $node_ids DETACH DELETE node",
{ {
"node_ids": data_point_ids, "node_ids": [str(data_point) for data_point in data_point_ids],
}, },
) )
@ -324,4 +330,4 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
print(f"Error deleting graph: {e}") print(f"Error deleting graph: {e}")
async def prune(self): async def prune(self):
self.delete_graph() await self.delete_graph()

View file

@ -1,9 +1,10 @@
import asyncio import logging
from typing import List, Optional from typing import List, Optional
import litellm import litellm
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
litellm.set_verbose = False litellm.set_verbose = False
logger = logging.getLogger("LiteLLMEmbeddingEngine")
class LiteLLMEmbeddingEngine(EmbeddingEngine): class LiteLLMEmbeddingEngine(EmbeddingEngine):
api_key: str api_key: str
@ -28,13 +29,17 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
async def embed_text(self, text: List[str]) -> List[List[float]]: async def embed_text(self, text: List[str]) -> List[List[float]]:
async def get_embedding(text_): async def get_embedding(text_):
response = await litellm.aembedding( try:
self.model, response = await litellm.aembedding(
input = text_, self.model,
api_key = self.api_key, input = text_,
api_base = self.endpoint, api_key = self.api_key,
api_version = self.api_version api_base = self.endpoint,
) api_version = self.api_version
)
except litellm.exceptions.BadRequestError as error:
logger.error("Error embedding text: %s", str(error))
raise error
return [data["embedding"] for data in response.data] return [data["embedding"] for data in response.data]

View file

@ -35,3 +35,7 @@ class DataPoint(BaseModel):
return [getattr(data_point, field, None) for field in data_point._metadata["index_fields"]] return [getattr(data_point, field, None) for field in data_point._metadata["index_fields"]]
return [] return []
@classmethod
def get_embeddable_property_names(cls, data_point):
    """Return the names of the properties of ``data_point`` to embed.

    The names are read from the ``index_fields`` entry of the data point's
    ``_metadata``. An empty list is returned when the metadata is missing
    or empty, or when it has no (or an empty) ``index_fields`` entry.
    """
    metadata = getattr(data_point, "_metadata", None) or {}
    # .get() rather than indexing: metadata without "index_fields" means
    # "nothing to embed", not an error.
    return metadata.get("index_fields") or []

View file

@ -118,17 +118,17 @@ async def get_graph_from_model(
data_point_properties[field_name] = field_value data_point_properties[field_name] = field_value
SimpleDataPointModel = copy_model(
type(data_point),
include_fields = {
"_metadata": (dict, data_point._metadata),
"__tablename__": data_point.__tablename__,
},
exclude_fields = excluded_properties,
)
if include_root: if include_root:
SimpleDataPointModel = copy_model(
type(data_point),
include_fields = {
"_metadata": (dict, data_point._metadata),
"__tablename__": data_point.__tablename__,
},
exclude_fields = excluded_properties,
)
nodes.append(SimpleDataPointModel(**data_point_properties)) nodes.append(SimpleDataPointModel(**data_point_properties))
added_nodes[str(data_point.id)] = True
return nodes, edges return nodes, edges

View file

@ -19,7 +19,6 @@ class CodeFile(DataPoint):
} }
class CodePart(DataPoint): class CodePart(DataPoint):
type: str
# part_of: Optional[CodeFile] # part_of: Optional[CodeFile]
source_code: str source_code: str
type: Optional[str] = "CodePart" type: Optional[str] = "CodePart"

View file

@ -1,6 +1,5 @@
import asyncio
import networkx as nx import networkx as nx
from typing import Dict, List from typing import AsyncGenerator, Dict, List
from tqdm.asyncio import tqdm from tqdm.asyncio import tqdm
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
@ -66,20 +65,25 @@ async def node_enrich_and_connect(
if desc_id not in topological_order[:topological_rank + 1]: if desc_id not in topological_order[:topological_rank + 1]:
continue continue
desc = None
if desc_id in data_points_map: if desc_id in data_points_map:
desc = data_points_map[desc_id] desc = data_points_map[desc_id]
else: else:
node_data = await graph_engine.extract_node(desc_id) node_data = await graph_engine.extract_node(desc_id)
desc = convert_node_to_data_point(node_data) try:
desc = convert_node_to_data_point(node_data)
except Exception:
pass
new_connections.append(desc) if desc is not None:
new_connections.append(desc)
node.depends_directly_on = node.depends_directly_on or [] node.depends_directly_on = node.depends_directly_on or []
node.depends_directly_on.extend(new_connections) node.depends_directly_on.extend(new_connections)
async def enrich_dependency_graph(data_points: list[DataPoint]) -> list[DataPoint]: async def enrich_dependency_graph(data_points: list[DataPoint]) -> AsyncGenerator[list[DataPoint], None]:
"""Enriches the graph with topological ranks and 'depends_on' edges.""" """Enriches the graph with topological ranks and 'depends_on' edges."""
nodes = [] nodes = []
edges = [] edges = []
@ -108,17 +112,18 @@ async def enrich_dependency_graph(data_points: list[DataPoint]) -> list[DataPoin
# data_points.append(node_enrich_and_connect(graph, topological_order, node)) # data_points.append(node_enrich_and_connect(graph, topological_order, node))
data_points_map = {data_point.id: data_point for data_point in data_points} data_points_map = {data_point.id: data_point for data_point in data_points}
data_points_futures = [] # data_points_futures = []
for data_point in tqdm(data_points, desc = "Enriching dependency graph", unit = "data_point"): for data_point in tqdm(data_points, desc = "Enriching dependency graph", unit = "data_point"):
if data_point.id not in node_rank_map: if data_point.id not in node_rank_map:
continue continue
if isinstance(data_point, CodeFile): if isinstance(data_point, CodeFile):
data_points_futures.append(node_enrich_and_connect(graph, topological_order, data_point, data_points_map)) # data_points_futures.append(node_enrich_and_connect(graph, topological_order, data_point, data_points_map))
await node_enrich_and_connect(graph, topological_order, data_point, data_points_map)
# yield data_point yield data_point
await asyncio.gather(*data_points_futures) # await asyncio.gather(*data_points_futures)
return data_points # return data_points

View file

@ -1,3 +1,4 @@
from typing import AsyncGenerator
from uuid import NAMESPACE_OID, uuid5 from uuid import NAMESPACE_OID, uuid5
# from tqdm import tqdm # from tqdm import tqdm
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
@ -53,11 +54,12 @@ def _process_single_node(code_file: CodeFile) -> None:
_add_code_parts_nodes_and_edges(code_file, part_type, code_parts) _add_code_parts_nodes_and_edges(code_file, part_type, code_parts)
async def expand_dependency_graph(data_points: list[DataPoint]) -> list[DataPoint]: async def expand_dependency_graph(data_points: list[DataPoint]) -> AsyncGenerator[list[DataPoint], None]:
"""Process Python file nodes, adding code part nodes and edges.""" """Process Python file nodes, adding code part nodes and edges."""
# for data_point in tqdm(data_points, desc = "Expand dependency graph", unit = "data_point"): # for data_point in tqdm(data_points, desc = "Expand dependency graph", unit = "data_point"):
for data_point in data_points: for data_point in data_points:
if isinstance(data_point, CodeFile): if isinstance(data_point, CodeFile):
_process_single_node(data_point) _process_single_node(data_point)
yield data_point
return data_points # return data_points

View file

@ -1,4 +1,5 @@
import os import os
from typing import AsyncGenerator
from uuid import NAMESPACE_OID, uuid5 from uuid import NAMESPACE_OID, uuid5
import aiofiles import aiofiles
from tqdm.asyncio import tqdm from tqdm.asyncio import tqdm
@ -44,7 +45,7 @@ def get_edge(file_path: str, dependency: str, repo_path: str, relative_paths: bo
return (file_path, dependency, {"relation": "depends_directly_on"}) return (file_path, dependency, {"relation": "depends_directly_on"})
async def get_repo_file_dependencies(repo_path: str) -> list[DataPoint]: async def get_repo_file_dependencies(repo_path: str) -> AsyncGenerator[list[DataPoint], None]:
"""Generate a dependency graph for Python files in the given repository path.""" """Generate a dependency graph for Python files in the given repository path."""
py_files_dict = await get_py_files_dict(repo_path) py_files_dict = await get_py_files_dict(repo_path)
@ -53,7 +54,8 @@ async def get_repo_file_dependencies(repo_path: str) -> list[DataPoint]:
path = repo_path, path = repo_path,
) )
data_points = [repo] # data_points = [repo]
yield repo
# dependency_graph = nx.DiGraph() # dependency_graph = nx.DiGraph()
@ -66,7 +68,8 @@ async def get_repo_file_dependencies(repo_path: str) -> list[DataPoint]:
dependencies = await get_local_script_dependencies(os.path.join(repo_path, file_path), repo_path) dependencies = await get_local_script_dependencies(os.path.join(repo_path, file_path), repo_path)
data_points.append(CodeFile( # data_points.append()
yield CodeFile(
id = uuid5(NAMESPACE_OID, file_path), id = uuid5(NAMESPACE_OID, file_path),
source_code = source_code, source_code = source_code,
extracted_id = file_path, extracted_id = file_path,
@ -78,10 +81,10 @@ async def get_repo_file_dependencies(repo_path: str) -> list[DataPoint]:
part_of = repo, part_of = repo,
) for dependency in dependencies ) for dependency in dependencies
] if len(dependencies) else None, ] if len(dependencies) else None,
)) )
# dependency_edges = [get_edge(file_path, dependency, repo_path) for dependency in dependencies] # dependency_edges = [get_edge(file_path, dependency, repo_path) for dependency in dependencies]
# dependency_graph.add_edges_from(dependency_edges) # dependency_graph.add_edges_from(dependency_edges)
return data_points # return data_points
# return dependency_graph # return dependency_graph

View file

@ -8,9 +8,9 @@ from cognee.shared.utils import render_graph
logging.basicConfig(level = logging.DEBUG) logging.basicConfig(level = logging.DEBUG)
async def main(): async def main():
data_directory_path = str(pathlib.Path(os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_library")).resolve()) data_directory_path = str(pathlib.Path(os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_falkordb")).resolve())
cognee.config.data_root_directory(data_directory_path) cognee.config.data_root_directory(data_directory_path)
cognee_directory_path = str(pathlib.Path(os.path.join(pathlib.Path(__file__).parent, ".cognee_system/test_library")).resolve()) cognee_directory_path = str(pathlib.Path(os.path.join(pathlib.Path(__file__).parent, ".cognee_system/test_falkordb")).resolve())
cognee.config.system_root_directory(cognee_directory_path) cognee.config.system_root_directory(cognee_directory_path)
await cognee.prune.prune_data() await cognee.prune.prune_data()

View file

@ -30,7 +30,6 @@ from evals.eval_utils import download_github_repo
from evals.eval_utils import delete_repo from evals.eval_utils import delete_repo
async def generate_patch_with_cognee(instance): async def generate_patch_with_cognee(instance):
await cognee.prune.prune_data() await cognee.prune.prune_data()
await cognee.prune.prune_system() await cognee.prune.prune_system()
@ -44,10 +43,10 @@ async def generate_patch_with_cognee(instance):
tasks = [ tasks = [
Task(get_repo_file_dependencies), Task(get_repo_file_dependencies),
Task(add_data_points), Task(add_data_points, task_config = { "batch_size": 50 }),
Task(enrich_dependency_graph), Task(enrich_dependency_graph, task_config = { "batch_size": 50 }),
Task(expand_dependency_graph), Task(expand_dependency_graph, task_config = { "batch_size": 50 }),
Task(add_data_points), Task(add_data_points, task_config = { "batch_size": 50 }),
# Task(summarize_code, summarization_model = SummarizedContent), # Task(summarize_code, summarization_model = SummarizedContent),
] ]
@ -58,7 +57,7 @@ async def generate_patch_with_cognee(instance):
print('Here we have the repo under the repo_path') print('Here we have the repo under the repo_path')
await render_graph() await render_graph(None, include_labels = True, include_nodes = True)
problem_statement = instance['problem_statement'] problem_statement = instance['problem_statement']
instructions = read_query_prompt("patch_gen_instructions.txt") instructions = read_query_prompt("patch_gen_instructions.txt")