fix: falkordb adapter errors

Boris Arzentar 2024-11-28 09:12:37 +01:00
parent 6403d15a76
commit 2408fd7a01
10 changed files with 101 additions and 78 deletions


@@ -1,5 +1,6 @@
import asyncio
# from datetime import datetime
+import json
from uuid import UUID
from textwrap import dedent

from falkordb import FalkorDB
@@ -53,28 +54,28 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
return f"'vecf32({value})'"
# if type(value) is datetime:
# return datetime.strptime(value, "%Y-%m-%dT%H:%M:%S.%f%z")
if type(value) is dict:
return f"'{json.dumps(value)}'"
return f"'{value}'"
return ",".join([f"{key}:{parse_value(value)}" for key, value in properties.items()])
-    async def create_data_point_query(self, data_point: DataPoint, vectorized_values: list = None):
+    async def create_data_point_query(self, data_point: DataPoint, vectorized_values: dict):
        node_label = type(data_point).__tablename__

-        embeddable_fields = data_point._metadata.get("index_fields", [])
+        property_names = DataPoint.get_embeddable_property_names(data_point)

        node_properties = await self.stringify_properties({
            **data_point.model_dump(),
            **({
-                embeddable_fields[index]: vectorized_values[index] \
-                    for index in range(len(embeddable_fields)) \
-            } if vectorized_values is not None else {}),
+                property_names[index]: (vectorized_values[index] if index in vectorized_values else None) \
+                    for index in range(len(property_names)) \
+            }),
        })
        return dedent(f"""
            MERGE (node:{node_label} {{id: '{str(data_point.id)}'}})
-            ON CREATE SET node += ({{{node_properties}}})
-            ON CREATE SET node.updated_at = timestamp()
-            ON MATCH SET node += ({{{node_properties}}})
-            ON MATCH SET node.updated_at = timestamp()
+            ON CREATE SET node += ({{{node_properties}}}), node.updated_at = timestamp()
+            ON MATCH SET node += ({{{node_properties}}}), node.updated_at = timestamp()
        """).strip()
async def create_edge_query(self, edge: tuple[str, str, str, dict]) -> str:
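
The consolidated upsert now emits a single ON CREATE SET and a single ON MATCH SET clause, each assigning the properties map and the updated_at timestamp together. A minimal sketch of the generated query with an illustrative label, id, and property string (`build_upsert_query` is a hypothetical helper, not part of the adapter):

```python
from textwrap import dedent

def build_upsert_query(node_label: str, node_id: str, node_properties: str) -> str:
    # One ON CREATE SET and one ON MATCH SET clause each set the property
    # map and the updated_at timestamp in a single step.
    return dedent(f"""
        MERGE (node:{node_label} {{id: '{node_id}'}})
        ON CREATE SET node += ({{{node_properties}}}), node.updated_at = timestamp()
        ON MATCH SET node += ({{{node_properties}}}), node.updated_at = timestamp()
    """).strip()

print(build_upsert_query("CodeFile", "example-id", "name:'example.py'"))
```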
@@ -98,31 +99,33 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
return collection_name in collections
    async def create_data_points(self, data_points: list[DataPoint]):
-        embeddable_values = [DataPoint.get_embeddable_properties(data_point) for data_point in data_points]
+        embeddable_values = []
+        vector_map = {}

-        vectorized_values = await self.embed_data(
-            sum(embeddable_values, [])
-        )
+        for data_point in data_points:
+            property_names = DataPoint.get_embeddable_property_names(data_point)
+            key = str(data_point.id)
+            vector_map[key] = {}

-        index = 0
-        positioned_vectorized_values = []
+            for property_name in property_names:
+                property_value = getattr(data_point, property_name, None)

-        for values in embeddable_values:
-            if len(values) > 0:
-                values_list = []
-                for i in range(len(values)):
-                    values_list.append(vectorized_values[index + i])
+                if property_value is not None:
+                    embeddable_values.append(property_value)
+                    vector_map[key][property_name] = len(embeddable_values) - 1
+                else:
+                    vector_map[key][property_name] = None

-                positioned_vectorized_values.append(values_list)
-                index += len(values)
-            else:
-                positioned_vectorized_values.append(None)
+        vectorized_values = await self.embed_data(embeddable_values)

        queries = [
            await self.create_data_point_query(
                data_point,
-                positioned_vectorized_values[index],
-            ) for index, data_point in enumerate(data_points)
+                [
+                    vectorized_values[vector_map[str(data_point.id)][property_name]] \
+                        for property_name in DataPoint.get_embeddable_property_names(data_point)
+                ],
+            ) for data_point in data_points
        ]

        for query in queries:
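
The reworked create_data_points flattens every non-null embeddable value into one list, embeds the whole batch in a single call, and keeps a per-data-point map from property name to batch index so vectors can be positioned back afterwards. A self-contained sketch of that indexing scheme (`embed_batch` stands in for the adapter's embed_data):

```python
import asyncio

async def embed_batch(values: list[str]) -> list[list[float]]:
    # Stand-in for the real embedding call: one vector per input value.
    return [[float(len(value))] for value in values]

async def vectorize(items: dict[str, dict[str, str | None]]) -> dict:
    flat_values, vector_map = [], {}
    for item_id, properties in items.items():
        vector_map[item_id] = {}
        for name, value in properties.items():
            if value is not None:
                flat_values.append(value)
                vector_map[item_id][name] = len(flat_values) - 1  # index into the batch
            else:
                vector_map[item_id][name] = None  # nothing to embed for this property

    vectors = await embed_batch(flat_values)  # a single embedding request for everything

    return {
        item_id: {name: (vectors[i] if i is not None else None) for name, i in props.items()}
        for item_id, props in vector_map.items()
    }

print(asyncio.run(vectorize({"a": {"text": "hello", "summary": None}})))
```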
@@ -182,18 +185,21 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
return [result["edge_exists"] for result in results]
-    async def retrieve(self, data_point_ids: list[str]):
-        return self.query(
+    async def retrieve(self, data_point_ids: list[UUID]):
+        result = self.query(
            f"MATCH (node) WHERE node.id IN $node_ids RETURN node",
            {
-                "node_ids": data_point_ids,
+                "node_ids": [str(data_point) for data_point in data_point_ids],
            },
        )
+        return result.result_set

-    async def extract_node(self, data_point_id: str):
-        return await self.retrieve([data_point_id])
+    async def extract_node(self, data_point_id: UUID):
+        result = await self.retrieve([data_point_id])
+        result = result[0][0] if len(result[0]) > 0 else None
+        return result.properties if result else None

-    async def extract_nodes(self, data_point_ids: list[str]):
+    async def extract_nodes(self, data_point_ids: list[UUID]):
        return await self.retrieve(data_point_ids)
async def get_connections(self, node_id: UUID) -> list:
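
retrieve now stringifies the UUIDs before binding them as query parameters and unwraps the FalkorDB result object. A minimal standalone version of the same round trip against falkordb-py (host, port, and graph name are illustrative):

```python
from uuid import UUID, uuid4

from falkordb import FalkorDB

db = FalkorDB(host = "localhost", port = 6379)  # assumes a local FalkorDB instance
graph = db.select_graph("example")

node_ids: list[UUID] = [uuid4()]
result = graph.query(
    "MATCH (node) WHERE node.id IN $node_ids RETURN node",
    # UUIDs are converted to strings to match how ids are stored on the nodes.
    {"node_ids": [str(node_id) for node_id in node_ids]},
)
rows = result.result_set  # list of result rows; each matched node exposes .properties
```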
@@ -296,11 +302,11 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
return (nodes, edges)
-    async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
+    async def delete_data_points(self, collection_name: str, data_point_ids: list[UUID]):
        return self.query(
            f"MATCH (node) WHERE node.id IN $node_ids DETACH DELETE node",
            {
-                "node_ids": data_point_ids,
+                "node_ids": [str(data_point) for data_point in data_point_ids],
            },
        )
@@ -324,4 +330,4 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
print(f"Error deleting graph: {e}")
    async def prune(self):
-        self.delete_graph()
+        await self.delete_graph()
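
The prune fix is a one-word change with real consequences: calling an async method without await only creates a coroutine object, so the graph was never actually deleted. A two-function illustration:

```python
import asyncio

async def delete_graph():
    print("graph deleted")

async def prune():
    # delete_graph() on its own returns an un-awaited coroutine and never
    # runs; awaiting it actually executes the deletion.
    await delete_graph()

asyncio.run(prune())
```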


@@ -1,9 +1,10 @@
import asyncio
import logging
from typing import List, Optional
import litellm
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
litellm.set_verbose = False
logger = logging.getLogger("LiteLLMEmbeddingEngine")
class LiteLLMEmbeddingEngine(EmbeddingEngine):
api_key: str
@@ -28,13 +29,17 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
    async def embed_text(self, text: List[str]) -> List[List[float]]:
        async def get_embedding(text_):
-            response = await litellm.aembedding(
-                self.model,
-                input = text_,
-                api_key = self.api_key,
-                api_base = self.endpoint,
-                api_version = self.api_version
-            )
+            try:
+                response = await litellm.aembedding(
+                    self.model,
+                    input = text_,
+                    api_key = self.api_key,
+                    api_base = self.endpoint,
+                    api_version = self.api_version
+                )
+            except litellm.exceptions.BadRequestError as error:
+                logger.error("Error embedding text: %s", str(error))
+                raise error

            return [data["embedding"] for data in response.data]
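With the try/except in place, a bad request is logged under the engine's logger and re-raised instead of surfacing as an opaque stack trace. A sketch of the surrounding usage, assuming chunks are fanned out concurrently (the `embed_chunks` driver here is hypothetical):

```python
import asyncio
import logging

import litellm

logger = logging.getLogger("LiteLLMEmbeddingEngine")

async def embed_chunks(model: str, api_key: str, chunks: list[list[str]]) -> list[list[float]]:
    async def embed_one(chunk: list[str]) -> list[list[float]]:
        try:
            response = await litellm.aembedding(model, input = chunk, api_key = api_key)
        except litellm.exceptions.BadRequestError as error:
            logger.error("Error embedding text: %s", str(error))  # log, then propagate
            raise error
        return [data["embedding"] for data in response.data]

    # Embed every chunk concurrently, then flatten into one list of vectors.
    results = await asyncio.gather(*[embed_one(chunk) for chunk in chunks])
    return [embedding for result in results for embedding in result]
```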


@@ -35,3 +35,7 @@ class DataPoint(BaseModel):
            return [getattr(data_point, field, None) for field in data_point._metadata["index_fields"]]

        return []

+    @classmethod
+    def get_embeddable_property_names(self, data_point):
+        return data_point._metadata["index_fields"] or []
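
The new classmethod exposes just the names of the embeddable fields, complementing get_embeddable_properties, which returns their values. A sketch with an illustrative plain-class stand-in for a DataPoint:

```python
from typing import Optional

class TextChunk:
    # Illustrative data point; index_fields marks which properties get embedded.
    _metadata = {"index_fields": ["text", "summary"]}

    def __init__(self, id: str, text: str, summary: Optional[str] = None):
        self.id, self.text, self.summary = id, text, summary

chunk = TextChunk(id = "1", text = "hello world")
names = TextChunk._metadata["index_fields"] or []           # ["text", "summary"]
values = [getattr(chunk, field, None) for field in names]   # ["hello world", None]
```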


@@ -118,17 +118,17 @@ async def get_graph_from_model(
        data_point_properties[field_name] = field_value

-    SimpleDataPointModel = copy_model(
-        type(data_point),
-        include_fields = {
-            "_metadata": (dict, data_point._metadata),
-            "__tablename__": data_point.__tablename__,
-        },
-        exclude_fields = excluded_properties,
-    )

    if include_root:
+        SimpleDataPointModel = copy_model(
+            type(data_point),
+            include_fields = {
+                "_metadata": (dict, data_point._metadata),
+                "__tablename__": data_point.__tablename__,
+            },
+            exclude_fields = excluded_properties,
+        )
        nodes.append(SimpleDataPointModel(**data_point_properties))
        added_nodes[str(data_point.id)] = True

    return nodes, edges
return nodes, edges


@@ -19,7 +19,6 @@ class CodeFile(DataPoint):
    }

class CodePart(DataPoint):
-    type: str
    # part_of: Optional[CodeFile]
    source_code: str
+    type: Optional[str] = "CodePart"


@@ -1,6 +1,5 @@
-import asyncio
import networkx as nx

-from typing import Dict, List
+from typing import AsyncGenerator, Dict, List
from tqdm.asyncio import tqdm
from cognee.infrastructure.engine import DataPoint
@@ -66,20 +65,25 @@ async def node_enrich_and_connect(
        if desc_id not in topological_order[:topological_rank + 1]:
            continue

+        desc = None

        if desc_id in data_points_map:
            desc = data_points_map[desc_id]
        else:
            node_data = await graph_engine.extract_node(desc_id)
-            desc = convert_node_to_data_point(node_data)
+            try:
+                desc = convert_node_to_data_point(node_data)
+            except Exception:
+                pass

-        new_connections.append(desc)
+        if desc is not None:
+            new_connections.append(desc)

    node.depends_directly_on = node.depends_directly_on or []
    node.depends_directly_on.extend(new_connections)
-async def enrich_dependency_graph(data_points: list[DataPoint]) -> list[DataPoint]:
+async def enrich_dependency_graph(data_points: list[DataPoint]) -> AsyncGenerator[list[DataPoint], None]:
"""Enriches the graph with topological ranks and 'depends_on' edges."""
nodes = []
edges = []
@@ -108,17 +112,18 @@ async def enrich_dependency_graph(data_points: list[DataPoint]) -> list[DataPoint]:
# data_points.append(node_enrich_and_connect(graph, topological_order, node))
data_points_map = {data_point.id: data_point for data_point in data_points}
-    data_points_futures = []
+    # data_points_futures = []

    for data_point in tqdm(data_points, desc = "Enriching dependency graph", unit = "data_point"):
        if data_point.id not in node_rank_map:
            continue

        if isinstance(data_point, CodeFile):
-            data_points_futures.append(node_enrich_and_connect(graph, topological_order, data_point, data_points_map))
+            # data_points_futures.append(node_enrich_and_connect(graph, topological_order, data_point, data_points_map))
+            await node_enrich_and_connect(graph, topological_order, data_point, data_points_map)

-        # yield data_point
+        yield data_point

-    await asyncio.gather(*data_points_futures)
+    # await asyncio.gather(*data_points_futures)

-    return data_points
+    # return data_points
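
enrich_dependency_graph now yields data points as they are enriched instead of returning the whole list at the end, which is what lets the pipeline's batch_size task config (used further below) group the stream into batches. A minimal sketch of that producer/consumer shape (names are illustrative):

```python
import asyncio
from typing import AsyncGenerator

async def enrich(items: list[str]) -> AsyncGenerator[str, None]:
    # Stream results one at a time instead of returning a list at the end.
    for item in items:
        await asyncio.sleep(0)  # stand-in for real async enrichment work
        yield item.upper()

async def consume_in_batches(batch_size: int = 2) -> None:
    batch = []
    async for item in enrich(["a", "b", "c"]):
        batch.append(item)
        if len(batch) == batch_size:
            print(batch)  # downstream task receives a full batch
            batch = []
    if batch:
        print(batch)  # flush the final partial batch

asyncio.run(consume_in_batches())
```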


@@ -1,3 +1,4 @@
+from typing import AsyncGenerator
from uuid import NAMESPACE_OID, uuid5
# from tqdm import tqdm
from cognee.infrastructure.engine import DataPoint
@@ -53,11 +54,12 @@ def _process_single_node(code_file: CodeFile) -> None:
_add_code_parts_nodes_and_edges(code_file, part_type, code_parts)
-async def expand_dependency_graph(data_points: list[DataPoint]) -> list[DataPoint]:
+async def expand_dependency_graph(data_points: list[DataPoint]) -> AsyncGenerator[list[DataPoint], None]:
"""Process Python file nodes, adding code part nodes and edges."""
# for data_point in tqdm(data_points, desc = "Expand dependency graph", unit = "data_point"):
for data_point in data_points:
if isinstance(data_point, CodeFile):
_process_single_node(data_point)
yield data_point
return data_points
# return data_points


@@ -1,4 +1,5 @@
import os
+from typing import AsyncGenerator
from uuid import NAMESPACE_OID, uuid5
import aiofiles
from tqdm.asyncio import tqdm
@@ -44,7 +45,7 @@ def get_edge(file_path: str, dependency: str, repo_path: str, relative_paths: bool
return (file_path, dependency, {"relation": "depends_directly_on"})
-async def get_repo_file_dependencies(repo_path: str) -> list[DataPoint]:
+async def get_repo_file_dependencies(repo_path: str) -> AsyncGenerator[list[DataPoint], None]:
"""Generate a dependency graph for Python files in the given repository path."""
py_files_dict = await get_py_files_dict(repo_path)
@@ -53,7 +54,8 @@ async def get_repo_file_dependencies(repo_path: str) -> list[DataPoint]:
path = repo_path,
)
-    data_points = [repo]
+    # data_points = [repo]
+    yield repo
# dependency_graph = nx.DiGraph()
@@ -66,7 +68,8 @@ async def get_repo_file_dependencies(repo_path: str) -> list[DataPoint]:
dependencies = await get_local_script_dependencies(os.path.join(repo_path, file_path), repo_path)
-        data_points.append(CodeFile(
+        # data_points.append()
+        yield CodeFile(
id = uuid5(NAMESPACE_OID, file_path),
source_code = source_code,
extracted_id = file_path,
@@ -78,10 +81,10 @@ async def get_repo_file_dependencies(repo_path: str) -> list[DataPoint]:
part_of = repo,
) for dependency in dependencies
] if len(dependencies) else None,
-        ))
+        )
# dependency_edges = [get_edge(file_path, dependency, repo_path) for dependency in dependencies]
# dependency_graph.add_edges_from(dependency_edges)
return data_points
# return data_points
# return dependency_graph
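
Worth noting: CodeFile ids come from uuid5(NAMESPACE_OID, file_path), so they are deterministic. Re-scanning the same repository yields the same node ids, which lets the adapter's MERGE-based upsert match existing nodes instead of duplicating them. For example:

```python
from uuid import NAMESPACE_OID, uuid5

# uuid5 is a pure function of its inputs: the same file path always maps to
# the same UUID, so repeated scans produce stable node ids.
first = uuid5(NAMESPACE_OID, "src/example.py")
second = uuid5(NAMESPACE_OID, "src/example.py")
assert first == second
print(first)
```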


@@ -8,9 +8,9 @@ from cognee.shared.utils import render_graph
logging.basicConfig(level = logging.DEBUG)
async def main():
-    data_directory_path = str(pathlib.Path(os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_library")).resolve())
+    data_directory_path = str(pathlib.Path(os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_falkordb")).resolve())
    cognee.config.data_root_directory(data_directory_path)
-    cognee_directory_path = str(pathlib.Path(os.path.join(pathlib.Path(__file__).parent, ".cognee_system/test_library")).resolve())
+    cognee_directory_path = str(pathlib.Path(os.path.join(pathlib.Path(__file__).parent, ".cognee_system/test_falkordb")).resolve())
cognee.config.system_root_directory(cognee_directory_path)
await cognee.prune.prune_data()


@@ -30,7 +30,6 @@ from evals.eval_utils import download_github_repo
from evals.eval_utils import delete_repo
async def generate_patch_with_cognee(instance):
await cognee.prune.prune_data()
await cognee.prune.prune_system()
@@ -44,10 +43,10 @@ async def generate_patch_with_cognee(instance):
    tasks = [
        Task(get_repo_file_dependencies),
-        Task(add_data_points),
-        Task(enrich_dependency_graph),
-        Task(expand_dependency_graph),
-        Task(add_data_points),
+        Task(add_data_points, task_config = { "batch_size": 50 }),
+        Task(enrich_dependency_graph, task_config = { "batch_size": 50 }),
+        Task(expand_dependency_graph, task_config = { "batch_size": 50 }),
+        Task(add_data_points, task_config = { "batch_size": 50 }),
        # Task(summarize_code, summarization_model = SummarizedContent),
    ]
@@ -58,7 +57,7 @@ async def generate_patch_with_cognee(instance):
print('Here we have the repo under the repo_path')
-    await render_graph()
+    await render_graph(None, include_labels = True, include_nodes = True)
problem_statement = instance['problem_statement']
instructions = read_query_prompt("patch_gen_instructions.txt")