fix: falkordb adapter errors

This commit is contained in:
parent 6403d15a76
commit 2408fd7a01

10 changed files with 101 additions and 78 deletions
@@ -1,5 +1,6 @@
import asyncio
# from datetime import datetime
import json
from uuid import UUID
from textwrap import dedent
from falkordb import FalkorDB

@@ -53,28 +54,28 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
return f"'vecf32({value})'"
# if type(value) is datetime:
# return datetime.strptime(value, "%Y-%m-%dT%H:%M:%S.%f%z")
if type(value) is dict:
return f"'{json.dumps(value)}'"
return f"'{value}'"

return ",".join([f"{key}:{parse_value(value)}" for key, value in properties.items()])

async def create_data_point_query(self, data_point: DataPoint, vectorized_values: list = None):
async def create_data_point_query(self, data_point: DataPoint, vectorized_values: dict):
node_label = type(data_point).__tablename__
embeddable_fields = data_point._metadata.get("index_fields", [])
property_names = DataPoint.get_embeddable_property_names(data_point)

node_properties = await self.stringify_properties({
**data_point.model_dump(),
**({
embeddable_fields[index]: vectorized_values[index] \
for index in range(len(embeddable_fields)) \
} if vectorized_values is not None else {}),
property_names[index]: (vectorized_values[index] if index in vectorized_values else None) \
for index in range(len(property_names)) \
}),
})

return dedent(f"""
MERGE (node:{node_label} {{id: '{str(data_point.id)}'}})
ON CREATE SET node += ({{{node_properties}}})
ON CREATE SET node.updated_at = timestamp()
ON MATCH SET node += ({{{node_properties}}})
ON MATCH SET node.updated_at = timestamp()
ON CREATE SET node += ({{{node_properties}}}), node.updated_at = timestamp()
ON MATCH SET node += ({{{node_properties}}}), node.updated_at = timestamp()
""").strip()

async def create_edge_query(self, edge: tuple[str, str, str, dict]) -> str:

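Note: the duplicated ON CREATE SET / ON MATCH SET lines are collapsed into one clause per branch, and the vector values are now resolved per property name rather than per index field. A minimal sketch of the query shape, assuming precomputed `node_label`, `node_id` and `node_properties` strings (the helper `build_merge_query` is illustrative, not the adapter's API):

```python
from textwrap import dedent

def build_merge_query(node_label: str, node_id: str, node_properties: str) -> str:
    # One ON CREATE SET / ON MATCH SET pair applies the properties and the
    # updated_at timestamp together, instead of two separate SET clauses.
    return dedent(f"""
        MERGE (node:{node_label} {{id: '{node_id}'}})
        ON CREATE SET node += ({{{node_properties}}}), node.updated_at = timestamp()
        ON MATCH SET node += ({{{node_properties}}}), node.updated_at = timestamp()
    """).strip()

print(build_merge_query("CodeFile", "123", "name:'example.py'"))
```
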
@@ -98,31 +99,33 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
return collection_name in collections

async def create_data_points(self, data_points: list[DataPoint]):
embeddable_values = [DataPoint.get_embeddable_properties(data_point) for data_point in data_points]
embeddable_values = []
vector_map = {}

vectorized_values = await self.embed_data(
sum(embeddable_values, [])
)
for data_point in data_points:
property_names = DataPoint.get_embeddable_property_names(data_point)
key = str(data_point.id)
vector_map[key] = {}

index = 0
positioned_vectorized_values = []
for property_name in property_names:
property_value = getattr(data_point, property_name, None)

for values in embeddable_values:
if len(values) > 0:
values_list = []
for i in range(len(values)):
values_list.append(vectorized_values[index + i])
if property_value is not None:
embeddable_values.append(property_value)
vector_map[key][property_name] = len(embeddable_values) - 1
else:
vector_map[key][property_name] = None

positioned_vectorized_values.append(values_list)
index += len(values)
else:
positioned_vectorized_values.append(None)
vectorized_values = await self.embed_data(embeddable_values)

queries = [
await self.create_data_point_query(
data_point,
positioned_vectorized_values[index],
) for index, data_point in enumerate(data_points)
[
vectorized_values[vector_map[str(data_point.id)][property_name]] \
for property_name in DataPoint.get_embeddable_property_names(data_point)
],
) for data_point in data_points
]

for query in queries:

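Note: the reworked create_data_points records, in a vector_map keyed by data point id and property name, where each embeddable value sits in one flat list, embeds that list in a single call, and looks each vector back up when building the per-node queries. A rough sketch of that bookkeeping, with a stand-in embed_data instead of the adapter's embedding engine:

```python
import asyncio

async def embed_data(values: list) -> list:
    # Stand-in for the adapter's single batched embedding call.
    return [[float(len(str(value)))] for value in values]

async def build_vector_map(data_points: list[dict]) -> tuple[list, dict]:
    embeddable_values = []
    vector_map = {}
    for data_point in data_points:
        key = data_point["id"]
        vector_map[key] = {}
        for property_name in data_point["index_fields"]:
            property_value = data_point.get(property_name)
            if property_value is not None:
                embeddable_values.append(property_value)
                vector_map[key][property_name] = len(embeddable_values) - 1
            else:
                vector_map[key][property_name] = None
    # One embedding call for every collected value; positions are resolved
    # through vector_map when the per-node queries are built.
    vectorized_values = await embed_data(embeddable_values)
    return vectorized_values, vector_map

vectors, vector_map = asyncio.run(build_vector_map([
    {"id": "a", "index_fields": ["text"], "text": "hello"},
    {"id": "b", "index_fields": ["text"]},
]))
print(vector_map)  # {'a': {'text': 0}, 'b': {'text': None}}
```
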
@@ -182,18 +185,21 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):

return [result["edge_exists"] for result in results]

async def retrieve(self, data_point_ids: list[str]):
return self.query(
async def retrieve(self, data_point_ids: list[UUID]):
result = self.query(
f"MATCH (node) WHERE node.id IN $node_ids RETURN node",
{
"node_ids": data_point_ids,
"node_ids": [str(data_point) for data_point in data_point_ids],
},
)
return result.result_set

async def extract_node(self, data_point_id: str):
return await self.retrieve([data_point_id])
async def extract_node(self, data_point_id: UUID):
result = await self.retrieve([data_point_id])
result = result[0][0] if len(result[0]) > 0 else None
return result.properties if result else None

async def extract_nodes(self, data_point_ids: list[str]):
async def extract_nodes(self, data_point_ids: list[UUID]):
return await self.retrieve(data_point_ids)

async def get_connections(self, node_id: UUID) -> list:

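Note: retrieve and extract_node now take UUID ids and stringify them for the Cypher parameter, then unwrap the matched node's properties. A small illustration of the parameter conversion (the query text comes from the diff, the rest is illustrative):

```python
from uuid import NAMESPACE_OID, uuid5

data_point_ids = [uuid5(NAMESPACE_OID, "example.py")]

query = "MATCH (node) WHERE node.id IN $node_ids RETURN node"
# FalkorDB receives plain strings, so the UUIDs are converted up front.
params = {"node_ids": [str(data_point_id) for data_point_id in data_point_ids]}

print(query)
print(params)
```
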
@@ -296,11 +302,11 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):

return (nodes, edges)

async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
async def delete_data_points(self, collection_name: str, data_point_ids: list[UUID]):
return self.query(
f"MATCH (node) WHERE node.id IN $node_ids DETACH DELETE node",
{
"node_ids": data_point_ids,
"node_ids": [str(data_point) for data_point in data_point_ids],
},
)

@@ -324,4 +330,4 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
print(f"Error deleting graph: {e}")

async def prune(self):
self.delete_graph()
await self.delete_graph()

@@ -1,9 +1,10 @@
import asyncio
import logging
from typing import List, Optional
import litellm
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine

litellm.set_verbose = False
logger = logging.getLogger("LiteLLMEmbeddingEngine")

class LiteLLMEmbeddingEngine(EmbeddingEngine):
api_key: str

@@ -28,13 +29,17 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):

async def embed_text(self, text: List[str]) -> List[List[float]]:
async def get_embedding(text_):
response = await litellm.aembedding(
self.model,
input = text_,
api_key = self.api_key,
api_base = self.endpoint,
api_version = self.api_version
)
try:
response = await litellm.aembedding(
self.model,
input = text_,
api_key = self.api_key,
api_base = self.endpoint,
api_version = self.api_version
)
except litellm.exceptions.BadRequestError as error:
logger.error("Error embedding text: %s", str(error))
raise error

return [data["embedding"] for data in response.data]

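Note: embed_text now wraps the embedding call, logging BadRequestError before re-raising. A compact sketch of the same pattern, with a dummy coroutine standing in for litellm.aembedding:

```python
import asyncio
import logging

logger = logging.getLogger("EmbeddingSketch")

class BadRequestError(Exception):
    # Stand-in for litellm.exceptions.BadRequestError.
    pass

async def fake_aembedding(texts: list[str]) -> dict:
    # Stand-in for litellm.aembedding; returns one dummy vector per input.
    return {"data": [{"embedding": [0.0, 0.0]} for _ in texts]}

async def embed_text(texts: list[str]) -> list[list[float]]:
    try:
        response = await fake_aembedding(texts)
    except BadRequestError as error:
        # Log the failure, then surface it to the caller unchanged.
        logger.error("Error embedding text: %s", str(error))
        raise error
    return [item["embedding"] for item in response["data"]]

print(asyncio.run(embed_text(["hello", "world"])))
```
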
@@ -35,3 +35,7 @@ class DataPoint(BaseModel):
return [getattr(data_point, field, None) for field in data_point._metadata["index_fields"]]

return []

@classmethod
def get_embeddable_property_names(self, data_point):
return data_point._metadata["index_fields"] or []

@@ -118,17 +118,17 @@ async def get_graph_from_model(

data_point_properties[field_name] = field_value

SimpleDataPointModel = copy_model(
type(data_point),
include_fields = {
"_metadata": (dict, data_point._metadata),
"__tablename__": data_point.__tablename__,
},
exclude_fields = excluded_properties,
)

if include_root:
SimpleDataPointModel = copy_model(
type(data_point),
include_fields = {
"_metadata": (dict, data_point._metadata),
"__tablename__": data_point.__tablename__,
},
exclude_fields = excluded_properties,
)
nodes.append(SimpleDataPointModel(**data_point_properties))
added_nodes[str(data_point.id)] = True

return nodes, edges

@@ -19,7 +19,6 @@ class CodeFile(DataPoint):
}

class CodePart(DataPoint):
type: str
# part_of: Optional[CodeFile]
source_code: str
type: Optional[str] = "CodePart"

@@ -1,6 +1,5 @@
import asyncio
import networkx as nx
from typing import Dict, List
from typing import AsyncGenerator, Dict, List
from tqdm.asyncio import tqdm

from cognee.infrastructure.engine import DataPoint

@@ -66,20 +65,25 @@ async def node_enrich_and_connect(
if desc_id not in topological_order[:topological_rank + 1]:
continue

desc = None

if desc_id in data_points_map:
desc = data_points_map[desc_id]
else:
node_data = await graph_engine.extract_node(desc_id)
desc = convert_node_to_data_point(node_data)
try:
desc = convert_node_to_data_point(node_data)
except Exception:
pass

new_connections.append(desc)
if desc is not None:
new_connections.append(desc)

node.depends_directly_on = node.depends_directly_on or []
node.depends_directly_on.extend(new_connections)


async def enrich_dependency_graph(data_points: list[DataPoint]) -> list[DataPoint]:
async def enrich_dependency_graph(data_points: list[DataPoint]) -> AsyncGenerator[list[DataPoint], None]:
"""Enriches the graph with topological ranks and 'depends_on' edges."""
nodes = []
edges = []

@@ -108,17 +112,18 @@ async def enrich_dependency_graph(data_points: list[DataPoint]) -> list[DataPoin
# data_points.append(node_enrich_and_connect(graph, topological_order, node))

data_points_map = {data_point.id: data_point for data_point in data_points}
data_points_futures = []
# data_points_futures = []

for data_point in tqdm(data_points, desc = "Enriching dependency graph", unit = "data_point"):
if data_point.id not in node_rank_map:
continue

if isinstance(data_point, CodeFile):
data_points_futures.append(node_enrich_and_connect(graph, topological_order, data_point, data_points_map))
# data_points_futures.append(node_enrich_and_connect(graph, topological_order, data_point, data_points_map))
await node_enrich_and_connect(graph, topological_order, data_point, data_points_map)

# yield data_point
yield data_point

await asyncio.gather(*data_points_futures)
# await asyncio.gather(*data_points_futures)

return data_points
# return data_points

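Note: enrich_dependency_graph changes from returning a list to yielding each data point as it is enriched, which is what lets the pipeline hand it a batch_size (see the task list further down). A minimal sketch of that shape, with plain strings standing in for data points:

```python
import asyncio
from typing import AsyncGenerator

async def enrich(data_points: list[str]) -> AsyncGenerator[str, None]:
    for data_point in data_points:
        # Per-item enrichment would happen here before the item is yielded.
        yield data_point

async def main():
    async for data_point in enrich(["a", "b", "c"]):
        print(data_point)

asyncio.run(main())
```
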
@@ -1,3 +1,4 @@
from typing import AsyncGenerator
from uuid import NAMESPACE_OID, uuid5
# from tqdm import tqdm
from cognee.infrastructure.engine import DataPoint

@@ -53,11 +54,12 @@ def _process_single_node(code_file: CodeFile) -> None:
_add_code_parts_nodes_and_edges(code_file, part_type, code_parts)


async def expand_dependency_graph(data_points: list[DataPoint]) -> list[DataPoint]:
async def expand_dependency_graph(data_points: list[DataPoint]) -> AsyncGenerator[list[DataPoint], None]:
"""Process Python file nodes, adding code part nodes and edges."""
# for data_point in tqdm(data_points, desc = "Expand dependency graph", unit = "data_point"):
for data_point in data_points:
if isinstance(data_point, CodeFile):
_process_single_node(data_point)
yield data_point

return data_points
# return data_points

@@ -1,4 +1,5 @@
import os
from typing import AsyncGenerator
from uuid import NAMESPACE_OID, uuid5
import aiofiles
from tqdm.asyncio import tqdm

@@ -44,7 +45,7 @@ def get_edge(file_path: str, dependency: str, repo_path: str, relative_paths: bo
return (file_path, dependency, {"relation": "depends_directly_on"})


async def get_repo_file_dependencies(repo_path: str) -> list[DataPoint]:
async def get_repo_file_dependencies(repo_path: str) -> AsyncGenerator[list[DataPoint], None]:
"""Generate a dependency graph for Python files in the given repository path."""
py_files_dict = await get_py_files_dict(repo_path)

@@ -53,7 +54,8 @@ async def get_repo_file_dependencies(repo_path: str) -> list[DataPoint]:
path = repo_path,
)

data_points = [repo]
# data_points = [repo]
yield repo

# dependency_graph = nx.DiGraph()

@@ -66,7 +68,8 @@ async def get_repo_file_dependencies(repo_path: str) -> list[DataPoint]:

dependencies = await get_local_script_dependencies(os.path.join(repo_path, file_path), repo_path)

data_points.append(CodeFile(
# data_points.append()
yield CodeFile(
id = uuid5(NAMESPACE_OID, file_path),
source_code = source_code,
extracted_id = file_path,

@@ -78,10 +81,10 @@ async def get_repo_file_dependencies(repo_path: str) -> list[DataPoint]:
part_of = repo,
) for dependency in dependencies
] if len(dependencies) else None,
))
)
# dependency_edges = [get_edge(file_path, dependency, repo_path) for dependency in dependencies]

# dependency_graph.add_edges_from(dependency_edges)

return data_points
# return data_points
# return dependency_graph

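Note: get_repo_file_dependencies follows the same pattern: the repository node is yielded first, then each CodeFile as its dependencies are resolved, instead of accumulating a data_points list. A simplified sketch with plain dicts standing in for the Repository and CodeFile models:

```python
import asyncio
from typing import AsyncGenerator

async def get_repo_file_dependencies(repo_path: str) -> AsyncGenerator[dict, None]:
    # The repository node comes first so downstream tasks can link files to it.
    repo = {"type": "Repository", "path": repo_path}
    yield repo

    for file_path in ["a.py", "b.py"]:  # stand-in for the discovered Python files
        yield {"type": "CodeFile", "path": file_path, "part_of": repo}

async def main():
    async for data_point in get_repo_file_dependencies("/tmp/repo"):
        print(data_point["type"], data_point["path"])

asyncio.run(main())
```
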
@@ -8,9 +8,9 @@ from cognee.shared.utils import render_graph
logging.basicConfig(level = logging.DEBUG)

async def main():
data_directory_path = str(pathlib.Path(os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_library")).resolve())
data_directory_path = str(pathlib.Path(os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_falkordb")).resolve())
cognee.config.data_root_directory(data_directory_path)
cognee_directory_path = str(pathlib.Path(os.path.join(pathlib.Path(__file__).parent, ".cognee_system/test_library")).resolve())
cognee_directory_path = str(pathlib.Path(os.path.join(pathlib.Path(__file__).parent, ".cognee_system/test_falkordb")).resolve())
cognee.config.system_root_directory(cognee_directory_path)

await cognee.prune.prune_data()

@@ -30,7 +30,6 @@ from evals.eval_utils import download_github_repo
from evals.eval_utils import delete_repo

async def generate_patch_with_cognee(instance):

await cognee.prune.prune_data()
await cognee.prune.prune_system()

@@ -44,10 +43,10 @@ async def generate_patch_with_cognee(instance):

tasks = [
Task(get_repo_file_dependencies),
Task(add_data_points),
Task(enrich_dependency_graph),
Task(expand_dependency_graph),
Task(add_data_points),
Task(add_data_points, task_config = { "batch_size": 50 }),
Task(enrich_dependency_graph, task_config = { "batch_size": 50 }),
Task(expand_dependency_graph, task_config = { "batch_size": 50 }),
Task(add_data_points, task_config = { "batch_size": 50 }),
# Task(summarize_code, summarization_model = SummarizedContent),
]

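Note: the task_config batch_size of 50 suggests each generator's output is consumed in chunks before the next task runs. An illustrative batching helper (not cognee's Task implementation, just the idea under that assumption):

```python
import asyncio
from typing import AsyncGenerator, AsyncIterator

async def batch(source: AsyncIterator, batch_size: int = 50) -> AsyncGenerator[list, None]:
    # Collect items from an async generator into lists of `batch_size`.
    current = []
    async for item in source:
        current.append(item)
        if len(current) == batch_size:
            yield current
            current = []
    if current:
        yield current

async def numbers() -> AsyncGenerator[int, None]:
    for number in range(120):
        yield number

async def main():
    async for chunk in batch(numbers(), batch_size = 50):
        print(len(chunk))  # 50, 50, 20

asyncio.run(main())
```
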
@@ -58,7 +57,7 @@ async def generate_patch_with_cognee(instance):

print('Here we have the repo under the repo_path')

await render_graph()
await render_graph(None, include_labels = True, include_nodes = True)

problem_statement = instance['problem_statement']
instructions = read_query_prompt("patch_gen_instructions.txt")