fix: remove duplicate nodes and edges before saving; fix FalkorDB vector index

This commit is contained in:
Boris Arzentar 2024-12-02 10:10:18 +01:00
parent 925346986e
commit 11acabdb6a
14 changed files with 118 additions and 120 deletions

View file

@ -67,8 +67,9 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
node_properties = await self.stringify_properties({
**data_point.model_dump(),
**({
property_names[index]: (vectorized_values[index] if index in vectorized_values else None) \
for index in range(len(property_names)) \
property_names[index]: (vectorized_values[index] \
if index < len(vectorized_values) else getattr(data_point, property_name, None)) \
for index, property_name in enumerate(property_names)
}),
})
@ -111,8 +112,8 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
property_value = getattr(data_point, property_name, None)
if property_value is not None:
vector_map[key][property_name] = len(embeddable_values)
embeddable_values.append(property_value)
vector_map[key][property_name] = len(embeddable_values) - 1
else:
vector_map[key][property_name] = None
@ -123,7 +124,9 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
data_point,
[
vectorized_values[vector_map[str(data_point.id)][property_name]] \
for property_name in DataPoint.get_embeddable_property_names(data_point)
if vector_map[str(data_point.id)][property_name] is not None \
else None \
for property_name in DataPoint.get_embeddable_property_names(data_point)
],
) for data_point in data_points
]

View file

@ -3,3 +3,4 @@ from .get_graph_from_model import get_graph_from_model
from .get_model_instance_from_graph import get_model_instance_from_graph
from .retrieve_existing_edges import retrieve_existing_edges
from .convert_node_to_data_point import convert_node_to_data_point
from .deduplicate_nodes_and_edges import deduplicate_nodes_and_edges

View file

@ -0,0 +1,19 @@
from cognee.infrastructure.engine import DataPoint
def deduplicate_nodes_and_edges(nodes: list[DataPoint], edges: list[dict]):
    """Remove duplicate nodes and edges, preserving first-seen order.

    Parameters:
        nodes: data points; a node is a duplicate when its ``str(node.id)``
            was already seen.
        edges: edge records indexed positionally; an edge is a duplicate when
            the triple ``(edge[0], edge[1], edge[2])`` — presumably
            (source id, target id, relationship name), TODO confirm against
            callers — was already seen.

    Returns:
        A ``(final_nodes, final_edges)`` tuple containing the first
        occurrence of each distinct node and edge, in input order.
    """
    seen_node_ids = set()
    seen_edge_keys = set()
    final_nodes = []
    final_edges = []

    for node in nodes:
        node_key = str(node.id)
        if node_key not in seen_node_ids:
            seen_node_ids.add(node_key)
            final_nodes.append(node)

    for edge in edges:
        # Key on a tuple rather than concatenated strings: concatenation has
        # ambiguous boundaries ("ab" + "c" == "a" + "bc") and could silently
        # drop distinct edges.
        edge_key = (str(edge[0]), str(edge[1]), str(edge[2]))
        if edge_key not in seen_edge_keys:
            seen_edge_keys.add(edge_key)
            final_edges.append(edge)

    return final_nodes, final_edges

View file

@ -8,18 +8,20 @@ async def get_graph_from_model(
include_root = True,
added_nodes = None,
added_edges = None,
visited_properties = None,
):
if data_point.id in added_nodes:
return [], []
nodes = []
edges = []
added_nodes = added_nodes or {}
added_edges = added_edges or {}
visited_properties = visited_properties or {}
data_point_properties = {}
excluded_properties = set()
if include_root:
added_nodes[str(data_point.id)] = True
for field_name, field_value in data_point:
if field_name == "_metadata":
continue
@ -30,7 +32,15 @@ async def get_graph_from_model(
if isinstance(field_value, DataPoint):
excluded_properties.add(field_name)
nodes, edges, added_nodes, added_edges = add_nodes_and_edges(
property_key = f"{str(data_point.id)}{field_name}{str(field_value.id)}"
if property_key in visited_properties:
continue
visited_properties[property_key] = True
nodes, edges = await add_nodes_and_edges(
data_point,
field_name,
field_value,
@ -38,77 +48,33 @@ async def get_graph_from_model(
edges,
added_nodes,
added_edges,
visited_properties,
)
property_nodes, property_edges = await get_graph_from_model(
field_value,
True,
added_nodes,
added_edges,
)
for node in property_nodes:
if str(node.id) not in added_nodes:
nodes.append(node)
added_nodes[str(node.id)] = True
for edge in property_edges:
edge_key = str(edge[0]) + str(edge[1]) + edge[2]
if str(edge_key) not in added_edges:
edges.append(edge)
added_edges[str(edge_key)] = True
for property_node in get_own_properties(property_nodes, property_edges):
edge_key = str(data_point.id) + str(property_node.id) + field_name
if str(edge_key) not in added_edges:
edges.append((data_point.id, property_node.id, field_name, {
"source_node_id": data_point.id,
"target_node_id": property_node.id,
"relationship_name": field_name,
"updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
}))
added_edges[str(edge_key)] = True
continue
if isinstance(field_value, list) and len(field_value) > 0 and isinstance(field_value[0], DataPoint):
excluded_properties.add(field_name)
for item in field_value:
property_nodes, property_edges = await get_graph_from_model(
item,
True,
for field_value_item in field_value:
property_key = f"{str(data_point.id)}{field_name}{str(field_value_item.id)}"
if property_key in visited_properties:
continue
visited_properties[property_key] = True
nodes, edges = await add_nodes_and_edges(
data_point,
field_name,
field_value_item,
nodes,
edges,
added_nodes,
added_edges,
visited_properties,
)
for node in property_nodes:
if str(node.id) not in added_nodes:
nodes.append(node)
added_nodes[str(node.id)] = True
for edge in property_edges:
edge_key = str(edge[0]) + str(edge[1]) + edge[2]
if str(edge_key) not in added_edges:
edges.append(edge)
added_edges[edge_key] = True
for property_node in get_own_properties(property_nodes, property_edges):
edge_key = str(data_point.id) + str(property_node.id) + field_name
if str(edge_key) not in added_edges:
edges.append((data_point.id, property_node.id, field_name, {
"source_node_id": data_point.id,
"target_node_id": property_node.id,
"relationship_name": field_name,
"updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
"metadata": {
"type": "list"
},
}))
added_edges[edge_key] = True
continue
data_point_properties[field_name] = field_value
@ -128,12 +94,22 @@ async def get_graph_from_model(
return nodes, edges
def add_nodes_and_edges(
data_point, field_name, field_value, nodes, edges, added_nodes, added_edges
async def add_nodes_and_edges(
data_point,
field_name,
field_value,
nodes,
edges,
added_nodes,
added_edges,
visited_properties,
):
property_nodes, property_edges = get_graph_from_model(
field_value, dict(added_nodes), dict(added_edges)
property_nodes, property_edges = await get_graph_from_model(
field_value,
True,
added_nodes,
added_edges,
visited_properties,
)
for node in property_nodes:
@ -169,7 +145,7 @@ def add_nodes_and_edges(
)
added_edges[str(edge_key)] = True
return (nodes, edges, added_nodes, added_edges)
return (nodes, edges)
def get_own_properties(property_nodes, property_edges):

View file

@ -2,10 +2,12 @@ from typing import List, Optional
from cognee.infrastructure.engine import DataPoint
class Repository(DataPoint):
__tablename__ = "Repository"
path: str
type: Optional[str] = "Repository"
class CodeFile(DataPoint):
__tablename__ = "CodeFile"
extracted_id: str # actually file path
type: Optional[str] = "CodeFile"
source_code: Optional[str] = None
@ -19,6 +21,7 @@ class CodeFile(DataPoint):
}
class CodePart(DataPoint):
__tablename__ = "CodePart"
# part_of: Optional[CodeFile]
source_code: str
type: Optional[str] = "CodePart"

View file

@ -57,13 +57,13 @@ async def get_repo_file_dependencies(repo_path: str) -> AsyncGenerator[list, Non
py_files_dict = await get_py_files_dict(repo_path)
repo = Repository(
id=uuid5(NAMESPACE_OID, repo_path),
path=repo_path,
id = uuid5(NAMESPACE_OID, repo_path),
path = repo_path,
)
yield repo
with ProcessPoolExecutor(max_workers=12) as executor:
with ProcessPoolExecutor(max_workers = 12) as executor:
loop = asyncio.get_event_loop()
tasks = [
@ -84,15 +84,16 @@ async def get_repo_file_dependencies(repo_path: str) -> AsyncGenerator[list, Non
source_code = metadata.get("source_code")
yield CodeFile(
id=uuid5(NAMESPACE_OID, file_path),
source_code=source_code,
extracted_id=file_path,
part_of=repo,
depends_on=[
id = uuid5(NAMESPACE_OID, file_path),
source_code = source_code,
extracted_id = file_path,
part_of = repo,
depends_on = [
CodeFile(
id=uuid5(NAMESPACE_OID, dependency),
extracted_id=dependency,
part_of=repo,
id = uuid5(NAMESPACE_OID, dependency),
extracted_id = dependency,
part_of = repo,
source_code = py_files_dict.get(dependency, {}).get("source_code"),
) for dependency in dependencies
] if dependencies else None,
)

View file

@ -1,7 +1,7 @@
import asyncio
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.modules.graph.utils import get_graph_from_model
from cognee.modules.graph.utils import deduplicate_nodes_and_edges, get_graph_from_model
from .index_data_points import index_data_points
@ -17,9 +17,11 @@ async def add_data_points(data_points: list[DataPoint]):
nodes.extend(result_nodes)
edges.extend(result_edges)
nodes, edges = deduplicate_nodes_and_edges(nodes, edges)
graph_engine = await get_graph_engine()
await index_data_points(data_points)
await index_data_points(nodes)
await graph_engine.add_nodes(nodes)
await graph_engine.add_edges(edges)

View file

@ -8,16 +8,7 @@ async def index_data_points(data_points: list[DataPoint]):
vector_engine = get_vector_engine()
flat_data_points: list[DataPoint] = []
results = await asyncio.gather(*[
get_data_points_from_model(data_point) for data_point in data_points
])
for result in results:
flat_data_points.extend(result)
for data_point in flat_data_points:
for data_point in data_points:
data_point_type = type(data_point)
for field_name in data_point._metadata["index_fields"]:

View file

@ -11,7 +11,7 @@ from cognee.tests.unit.interfaces.graph.util import (
@pytest.mark.parametrize("recursive_depth", [1, 2, 3])
def test_society_nodes_and_edges(recursive_depth):
async def test_society_nodes_and_edges(recursive_depth):
import sys
if sys.version_info[0] == 3 and sys.version_info[1] >= 11:
@ -22,7 +22,7 @@ def test_society_nodes_and_edges(recursive_depth):
n_organizations, n_persons = count_society(society)
society_counts_total = n_organizations + n_persons
nodes, edges = get_graph_from_model(society)
nodes, edges = await get_graph_from_model(society)
assert (
len(nodes) == society_counts_total

View file

@ -48,29 +48,29 @@ PERSON_GROUND_TRUTH = {
}
def test_extracted_car_type(boris):
nodes, _ = get_graph_from_model(boris)
async def test_extracted_car_type(boris):
nodes, _ = await get_graph_from_model(boris)
assert len(nodes) == 3
car_type = nodes[0]
run_test_against_ground_truth("car_type", car_type, CAR_TYPE_GROUND_TRUTH)
def test_extracted_car(boris):
nodes, _ = get_graph_from_model(boris)
async def test_extracted_car(boris):
nodes, _ = await get_graph_from_model(boris)
assert len(nodes) == 3
car = nodes[1]
run_test_against_ground_truth("car", car, CAR_GROUND_TRUTH)
def test_extracted_person(boris):
nodes, _ = get_graph_from_model(boris)
async def test_extracted_person(boris):
nodes, _ = await get_graph_from_model(boris)
assert len(nodes) == 3
person = nodes[2]
run_test_against_ground_truth("person", person, PERSON_GROUND_TRUTH)
def test_extracted_car_sedan_edge(boris):
_, edges = get_graph_from_model(boris)
async def test_extracted_car_sedan_edge(boris):
_, edges = await get_graph_from_model(boris)
edge = edges[0]
assert CAR_SEDAN_EDGE[:3] == edge[:3], f"{CAR_SEDAN_EDGE[:3] = } != {edge[:3] = }"
@ -78,8 +78,8 @@ def test_extracted_car_sedan_edge(boris):
assert ground_truth == edge[3][key], f"{ground_truth = } != {edge[3][key] = }"
def test_extracted_boris_car_edge(boris):
_, edges = get_graph_from_model(boris)
async def test_extracted_boris_car_edge(boris):
_, edges = await get_graph_from_model(boris)
edge = edges[1]
assert (

View file

@ -14,14 +14,14 @@ from cognee.tests.unit.interfaces.graph.util import (
@pytest.mark.parametrize("recursive_depth", [1, 2, 3])
def test_society_nodes_and_edges(recursive_depth):
async def test_society_nodes_and_edges(recursive_depth):
import sys
if sys.version_info[0] == 3 and sys.version_info[1] >= 11:
society = create_organization_recursive(
"society", "Society", PERSON_NAMES, recursive_depth
)
nodes, edges = get_graph_from_model(society)
nodes, edges = await get_graph_from_model(society)
parsed_society = get_model_instance_from_graph(nodes, edges, "society")
assert str(society) == (str(parsed_society)), show_first_difference(

View file

@ -25,8 +25,8 @@ CAR_GROUND_TRUTH = {
}
def test_parsed_person(boris):
nodes, edges = get_graph_from_model(boris)
async def test_parsed_person(boris):
nodes, edges = await get_graph_from_model(boris)
parsed_person = get_model_instance_from_graph(nodes, edges, "boris")
run_test_against_ground_truth(

View file

@ -43,7 +43,6 @@ def check_install_package(package_name):
async def generate_patch_with_cognee(instance, llm_client, search_type=SearchType.CHUNKS):
await cognee.prune.prune_data()
await cognee.prune.prune_system()
@ -57,11 +56,11 @@ async def generate_patch_with_cognee(instance, llm_client, search_type=SearchTyp
Task(enrich_dependency_graph, task_config = { "batch_size": 50 }),
Task(expand_dependency_graph, task_config = { "batch_size": 50 }),
Task(add_data_points, task_config = { "batch_size": 50 }),
Task(summarize_code, summarization_model = SummarizedContent),
# Task(summarize_code, summarization_model = SummarizedContent),
]
pipeline = run_tasks(tasks, repo_path, "cognify_code_pipeline")
async for result in pipeline:
print(result)

View file

@ -1,7 +1,7 @@
import argparse
import time
import asyncio
from benchmark_function import benchmark_function
from .benchmark_function import benchmark_function
from cognee.modules.graph.utils import get_graph_from_model
from cognee.tests.unit.interfaces.graph.util import (
@ -28,9 +28,12 @@ if __name__ == "__main__":
society = create_organization_recursive(
"society", "Society", PERSON_NAMES, args.recursive_depth
)
nodes, edges = get_graph_from_model(society)
nodes, edges = asyncio.run(get_graph_from_model(society))
results = benchmark_function(get_graph_from_model, society, num_runs=args.runs)
def get_graph_from_model_sync(model):
return asyncio.run(get_graph_from_model(model))
results = benchmark_function(get_graph_from_model_sync, society, num_runs=args.runs)
print("\nBenchmark Results:")
print(
f"N nodes: {len(nodes)}, N edges: {len(edges)}, Recursion depth: {args.recursive_depth}"