fix: remove duplicate nodes and edges before saving; fix FalkorDB vector index
This commit is contained in:
parent
925346986e
commit
11acabdb6a
14 changed files with 118 additions and 120 deletions
|
|
@ -67,8 +67,9 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
|
|||
node_properties = await self.stringify_properties({
|
||||
**data_point.model_dump(),
|
||||
**({
|
||||
property_names[index]: (vectorized_values[index] if index in vectorized_values else None) \
|
||||
for index in range(len(property_names)) \
|
||||
property_names[index]: (vectorized_values[index] \
|
||||
if index < len(vectorized_values) else getattr(data_point, property_name, None)) \
|
||||
for index, property_name in enumerate(property_names)
|
||||
}),
|
||||
})
|
||||
|
||||
|
|
@ -111,8 +112,8 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
|
|||
property_value = getattr(data_point, property_name, None)
|
||||
|
||||
if property_value is not None:
|
||||
vector_map[key][property_name] = len(embeddable_values)
|
||||
embeddable_values.append(property_value)
|
||||
vector_map[key][property_name] = len(embeddable_values) - 1
|
||||
else:
|
||||
vector_map[key][property_name] = None
|
||||
|
||||
|
|
@ -123,7 +124,9 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
|
|||
data_point,
|
||||
[
|
||||
vectorized_values[vector_map[str(data_point.id)][property_name]] \
|
||||
for property_name in DataPoint.get_embeddable_property_names(data_point)
|
||||
if vector_map[str(data_point.id)][property_name] is not None \
|
||||
else None \
|
||||
for property_name in DataPoint.get_embeddable_property_names(data_point)
|
||||
],
|
||||
) for data_point in data_points
|
||||
]
|
||||
|
|
|
|||
|
|
@ -3,3 +3,4 @@ from .get_graph_from_model import get_graph_from_model
|
|||
from .get_model_instance_from_graph import get_model_instance_from_graph
|
||||
from .retrieve_existing_edges import retrieve_existing_edges
|
||||
from .convert_node_to_data_point import convert_node_to_data_point
|
||||
from .deduplicate_nodes_and_edges import deduplicate_nodes_and_edges
|
||||
|
|
|
|||
19
cognee/modules/graph/utils/deduplicate_nodes_and_edges.py
Normal file
19
cognee/modules/graph/utils/deduplicate_nodes_and_edges.py
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
from cognee.infrastructure.engine import DataPoint
|
||||
|
||||
def deduplicate_nodes_and_edges(nodes: list[DataPoint], edges: list[tuple]):
    """Remove duplicate nodes and edges, preserving first-seen order.

    Nodes are deduplicated by their stringified ``id``. Edges are
    deduplicated by the concatenation of source id, relationship name and
    target id (``edge[0]``, ``edge[2]``, ``edge[1]``), so two edges between
    the same pair of nodes with *different* relationship names are both kept.

    Args:
        nodes: data points to deduplicate; each must expose an ``id``.
        edges: edge tuples shaped ``(source_id, target_id, relationship_name,
            properties)``; only the first three elements contribute to the
            dedup key. (The previous ``list[dict]`` annotation was wrong —
            elements are indexed positionally.)

    Returns:
        A ``(final_nodes, final_edges)`` pair with duplicates dropped.
    """
    # One shared key space for node ids and edge keys, mirroring the
    # original behavior of a single tracking mapping.
    seen_keys = set()
    final_nodes = []
    final_edges = []

    for node in nodes:
        node_key = str(node.id)
        if node_key not in seen_keys:
            final_nodes.append(node)
            seen_keys.add(node_key)

    for edge in edges:
        # Key order is source + relationship + target, as in the original.
        edge_key = str(edge[0]) + str(edge[2]) + str(edge[1])
        if edge_key not in seen_keys:
            final_edges.append(edge)
            seen_keys.add(edge_key)

    return final_nodes, final_edges
|
||||
|
|
@ -8,18 +8,20 @@ async def get_graph_from_model(
|
|||
include_root = True,
|
||||
added_nodes = None,
|
||||
added_edges = None,
|
||||
visited_properties = None,
|
||||
):
|
||||
if data_point.id in added_nodes:
|
||||
return [], []
|
||||
|
||||
nodes = []
|
||||
edges = []
|
||||
added_nodes = added_nodes or {}
|
||||
added_edges = added_edges or {}
|
||||
visited_properties = visited_properties or {}
|
||||
|
||||
data_point_properties = {}
|
||||
excluded_properties = set()
|
||||
|
||||
if include_root:
|
||||
added_nodes[str(data_point.id)] = True
|
||||
|
||||
for field_name, field_value in data_point:
|
||||
if field_name == "_metadata":
|
||||
continue
|
||||
|
|
@ -30,7 +32,15 @@ async def get_graph_from_model(
|
|||
|
||||
if isinstance(field_value, DataPoint):
|
||||
excluded_properties.add(field_name)
|
||||
nodes, edges, added_nodes, added_edges = add_nodes_and_edges(
|
||||
|
||||
property_key = f"{str(data_point.id)}{field_name}{str(field_value.id)}"
|
||||
|
||||
if property_key in visited_properties:
|
||||
continue
|
||||
|
||||
visited_properties[property_key] = True
|
||||
|
||||
nodes, edges = await add_nodes_and_edges(
|
||||
data_point,
|
||||
field_name,
|
||||
field_value,
|
||||
|
|
@ -38,77 +48,33 @@ async def get_graph_from_model(
|
|||
edges,
|
||||
added_nodes,
|
||||
added_edges,
|
||||
visited_properties,
|
||||
)
|
||||
|
||||
property_nodes, property_edges = await get_graph_from_model(
|
||||
field_value,
|
||||
True,
|
||||
added_nodes,
|
||||
added_edges,
|
||||
)
|
||||
|
||||
for node in property_nodes:
|
||||
if str(node.id) not in added_nodes:
|
||||
nodes.append(node)
|
||||
added_nodes[str(node.id)] = True
|
||||
|
||||
for edge in property_edges:
|
||||
edge_key = str(edge[0]) + str(edge[1]) + edge[2]
|
||||
|
||||
if str(edge_key) not in added_edges:
|
||||
edges.append(edge)
|
||||
added_edges[str(edge_key)] = True
|
||||
|
||||
for property_node in get_own_properties(property_nodes, property_edges):
|
||||
edge_key = str(data_point.id) + str(property_node.id) + field_name
|
||||
|
||||
if str(edge_key) not in added_edges:
|
||||
edges.append((data_point.id, property_node.id, field_name, {
|
||||
"source_node_id": data_point.id,
|
||||
"target_node_id": property_node.id,
|
||||
"relationship_name": field_name,
|
||||
"updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
|
||||
}))
|
||||
added_edges[str(edge_key)] = True
|
||||
continue
|
||||
|
||||
if isinstance(field_value, list) and len(field_value) > 0 and isinstance(field_value[0], DataPoint):
|
||||
excluded_properties.add(field_name)
|
||||
|
||||
for item in field_value:
|
||||
property_nodes, property_edges = await get_graph_from_model(
|
||||
item,
|
||||
True,
|
||||
for field_value_item in field_value:
|
||||
property_key = f"{str(data_point.id)}{field_name}{str(field_value_item.id)}"
|
||||
|
||||
if property_key in visited_properties:
|
||||
continue
|
||||
|
||||
visited_properties[property_key] = True
|
||||
|
||||
nodes, edges = await add_nodes_and_edges(
|
||||
data_point,
|
||||
field_name,
|
||||
field_value_item,
|
||||
nodes,
|
||||
edges,
|
||||
added_nodes,
|
||||
added_edges,
|
||||
visited_properties,
|
||||
)
|
||||
|
||||
for node in property_nodes:
|
||||
if str(node.id) not in added_nodes:
|
||||
nodes.append(node)
|
||||
added_nodes[str(node.id)] = True
|
||||
|
||||
for edge in property_edges:
|
||||
edge_key = str(edge[0]) + str(edge[1]) + edge[2]
|
||||
|
||||
if str(edge_key) not in added_edges:
|
||||
edges.append(edge)
|
||||
added_edges[edge_key] = True
|
||||
|
||||
for property_node in get_own_properties(property_nodes, property_edges):
|
||||
edge_key = str(data_point.id) + str(property_node.id) + field_name
|
||||
|
||||
if str(edge_key) not in added_edges:
|
||||
edges.append((data_point.id, property_node.id, field_name, {
|
||||
"source_node_id": data_point.id,
|
||||
"target_node_id": property_node.id,
|
||||
"relationship_name": field_name,
|
||||
"updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
|
||||
"metadata": {
|
||||
"type": "list"
|
||||
},
|
||||
}))
|
||||
added_edges[edge_key] = True
|
||||
continue
|
||||
|
||||
data_point_properties[field_name] = field_value
|
||||
|
|
@ -128,12 +94,22 @@ async def get_graph_from_model(
|
|||
return nodes, edges
|
||||
|
||||
|
||||
def add_nodes_and_edges(
|
||||
data_point, field_name, field_value, nodes, edges, added_nodes, added_edges
|
||||
async def add_nodes_and_edges(
|
||||
data_point,
|
||||
field_name,
|
||||
field_value,
|
||||
nodes,
|
||||
edges,
|
||||
added_nodes,
|
||||
added_edges,
|
||||
visited_properties,
|
||||
):
|
||||
|
||||
property_nodes, property_edges = get_graph_from_model(
|
||||
field_value, dict(added_nodes), dict(added_edges)
|
||||
property_nodes, property_edges = await get_graph_from_model(
|
||||
field_value,
|
||||
True,
|
||||
added_nodes,
|
||||
added_edges,
|
||||
visited_properties,
|
||||
)
|
||||
|
||||
for node in property_nodes:
|
||||
|
|
@ -169,7 +145,7 @@ def add_nodes_and_edges(
|
|||
)
|
||||
added_edges[str(edge_key)] = True
|
||||
|
||||
return (nodes, edges, added_nodes, added_edges)
|
||||
return (nodes, edges)
|
||||
|
||||
|
||||
def get_own_properties(property_nodes, property_edges):
|
||||
|
|
|
|||
|
|
@ -2,10 +2,12 @@ from typing import List, Optional
|
|||
from cognee.infrastructure.engine import DataPoint
|
||||
|
||||
class Repository(DataPoint):
|
||||
__tablename__ = "Repository"
|
||||
path: str
|
||||
type: Optional[str] = "Repository"
|
||||
|
||||
class CodeFile(DataPoint):
|
||||
__tablename__ = "CodeFile"
|
||||
extracted_id: str # actually file path
|
||||
type: Optional[str] = "CodeFile"
|
||||
source_code: Optional[str] = None
|
||||
|
|
@ -19,6 +21,7 @@ class CodeFile(DataPoint):
|
|||
}
|
||||
|
||||
class CodePart(DataPoint):
|
||||
__tablename__ = "CodePart"
|
||||
# part_of: Optional[CodeFile]
|
||||
source_code: str
|
||||
type: Optional[str] = "CodePart"
|
||||
|
|
|
|||
|
|
@ -57,13 +57,13 @@ async def get_repo_file_dependencies(repo_path: str) -> AsyncGenerator[list, Non
|
|||
py_files_dict = await get_py_files_dict(repo_path)
|
||||
|
||||
repo = Repository(
|
||||
id=uuid5(NAMESPACE_OID, repo_path),
|
||||
path=repo_path,
|
||||
id = uuid5(NAMESPACE_OID, repo_path),
|
||||
path = repo_path,
|
||||
)
|
||||
|
||||
yield repo
|
||||
|
||||
with ProcessPoolExecutor(max_workers=12) as executor:
|
||||
with ProcessPoolExecutor(max_workers = 12) as executor:
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
tasks = [
|
||||
|
|
@ -84,15 +84,16 @@ async def get_repo_file_dependencies(repo_path: str) -> AsyncGenerator[list, Non
|
|||
source_code = metadata.get("source_code")
|
||||
|
||||
yield CodeFile(
|
||||
id=uuid5(NAMESPACE_OID, file_path),
|
||||
source_code=source_code,
|
||||
extracted_id=file_path,
|
||||
part_of=repo,
|
||||
depends_on=[
|
||||
id = uuid5(NAMESPACE_OID, file_path),
|
||||
source_code = source_code,
|
||||
extracted_id = file_path,
|
||||
part_of = repo,
|
||||
depends_on = [
|
||||
CodeFile(
|
||||
id=uuid5(NAMESPACE_OID, dependency),
|
||||
extracted_id=dependency,
|
||||
part_of=repo,
|
||||
id = uuid5(NAMESPACE_OID, dependency),
|
||||
extracted_id = dependency,
|
||||
part_of = repo,
|
||||
source_code = py_files_dict.get(dependency, {}).get("source_code"),
|
||||
) for dependency in dependencies
|
||||
] if dependencies else None,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import asyncio
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.modules.graph.utils import get_graph_from_model
|
||||
from cognee.modules.graph.utils import deduplicate_nodes_and_edges, get_graph_from_model
|
||||
from .index_data_points import index_data_points
|
||||
|
||||
|
||||
|
|
@ -17,9 +17,11 @@ async def add_data_points(data_points: list[DataPoint]):
|
|||
nodes.extend(result_nodes)
|
||||
edges.extend(result_edges)
|
||||
|
||||
nodes, edges = deduplicate_nodes_and_edges(nodes, edges)
|
||||
|
||||
graph_engine = await get_graph_engine()
|
||||
|
||||
await index_data_points(data_points)
|
||||
await index_data_points(nodes)
|
||||
|
||||
await graph_engine.add_nodes(nodes)
|
||||
await graph_engine.add_edges(edges)
|
||||
|
|
|
|||
|
|
@ -8,16 +8,7 @@ async def index_data_points(data_points: list[DataPoint]):
|
|||
|
||||
vector_engine = get_vector_engine()
|
||||
|
||||
flat_data_points: list[DataPoint] = []
|
||||
|
||||
results = await asyncio.gather(*[
|
||||
get_data_points_from_model(data_point) for data_point in data_points
|
||||
])
|
||||
|
||||
for result in results:
|
||||
flat_data_points.extend(result)
|
||||
|
||||
for data_point in flat_data_points:
|
||||
for data_point in data_points:
|
||||
data_point_type = type(data_point)
|
||||
|
||||
for field_name in data_point._metadata["index_fields"]:
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from cognee.tests.unit.interfaces.graph.util import (
|
|||
|
||||
|
||||
@pytest.mark.parametrize("recursive_depth", [1, 2, 3])
|
||||
def test_society_nodes_and_edges(recursive_depth):
|
||||
async def test_society_nodes_and_edges(recursive_depth):
|
||||
import sys
|
||||
|
||||
if sys.version_info[0] == 3 and sys.version_info[1] >= 11:
|
||||
|
|
@ -22,7 +22,7 @@ def test_society_nodes_and_edges(recursive_depth):
|
|||
n_organizations, n_persons = count_society(society)
|
||||
society_counts_total = n_organizations + n_persons
|
||||
|
||||
nodes, edges = get_graph_from_model(society)
|
||||
nodes, edges = await get_graph_from_model(society)
|
||||
|
||||
assert (
|
||||
len(nodes) == society_counts_total
|
||||
|
|
|
|||
|
|
@ -48,29 +48,29 @@ PERSON_GROUND_TRUTH = {
|
|||
}
|
||||
|
||||
|
||||
def test_extracted_car_type(boris):
|
||||
nodes, _ = get_graph_from_model(boris)
|
||||
async def test_extracted_car_type(boris):
|
||||
nodes, _ = await get_graph_from_model(boris)
|
||||
assert len(nodes) == 3
|
||||
car_type = nodes[0]
|
||||
run_test_against_ground_truth("car_type", car_type, CAR_TYPE_GROUND_TRUTH)
|
||||
|
||||
|
||||
def test_extracted_car(boris):
|
||||
nodes, _ = get_graph_from_model(boris)
|
||||
async def test_extracted_car(boris):
|
||||
nodes, _ = await get_graph_from_model(boris)
|
||||
assert len(nodes) == 3
|
||||
car = nodes[1]
|
||||
run_test_against_ground_truth("car", car, CAR_GROUND_TRUTH)
|
||||
|
||||
|
||||
def test_extracted_person(boris):
|
||||
nodes, _ = get_graph_from_model(boris)
|
||||
async def test_extracted_person(boris):
|
||||
nodes, _ = await get_graph_from_model(boris)
|
||||
assert len(nodes) == 3
|
||||
person = nodes[2]
|
||||
run_test_against_ground_truth("person", person, PERSON_GROUND_TRUTH)
|
||||
|
||||
|
||||
def test_extracted_car_sedan_edge(boris):
|
||||
_, edges = get_graph_from_model(boris)
|
||||
async def test_extracted_car_sedan_edge(boris):
|
||||
_, edges = await get_graph_from_model(boris)
|
||||
edge = edges[0]
|
||||
|
||||
assert CAR_SEDAN_EDGE[:3] == edge[:3], f"{CAR_SEDAN_EDGE[:3] = } != {edge[:3] = }"
|
||||
|
|
@ -78,8 +78,8 @@ def test_extracted_car_sedan_edge(boris):
|
|||
assert ground_truth == edge[3][key], f"{ground_truth = } != {edge[3][key] = }"
|
||||
|
||||
|
||||
def test_extracted_boris_car_edge(boris):
|
||||
_, edges = get_graph_from_model(boris)
|
||||
async def test_extracted_boris_car_edge(boris):
|
||||
_, edges = await get_graph_from_model(boris)
|
||||
edge = edges[1]
|
||||
|
||||
assert (
|
||||
|
|
|
|||
|
|
@ -14,14 +14,14 @@ from cognee.tests.unit.interfaces.graph.util import (
|
|||
|
||||
|
||||
@pytest.mark.parametrize("recursive_depth", [1, 2, 3])
|
||||
def test_society_nodes_and_edges(recursive_depth):
|
||||
async def test_society_nodes_and_edges(recursive_depth):
|
||||
import sys
|
||||
|
||||
if sys.version_info[0] == 3 and sys.version_info[1] >= 11:
|
||||
society = create_organization_recursive(
|
||||
"society", "Society", PERSON_NAMES, recursive_depth
|
||||
)
|
||||
nodes, edges = get_graph_from_model(society)
|
||||
nodes, edges = await get_graph_from_model(society)
|
||||
parsed_society = get_model_instance_from_graph(nodes, edges, "society")
|
||||
|
||||
assert str(society) == (str(parsed_society)), show_first_difference(
|
||||
|
|
|
|||
|
|
@ -25,8 +25,8 @@ CAR_GROUND_TRUTH = {
|
|||
}
|
||||
|
||||
|
||||
def test_parsed_person(boris):
|
||||
nodes, edges = get_graph_from_model(boris)
|
||||
async def test_parsed_person(boris):
|
||||
nodes, edges = await get_graph_from_model(boris)
|
||||
parsed_person = get_model_instance_from_graph(nodes, edges, "boris")
|
||||
|
||||
run_test_against_ground_truth(
|
||||
|
|
|
|||
|
|
@ -43,7 +43,6 @@ def check_install_package(package_name):
|
|||
|
||||
|
||||
async def generate_patch_with_cognee(instance, llm_client, search_type=SearchType.CHUNKS):
|
||||
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system()
|
||||
|
||||
|
|
@ -57,11 +56,11 @@ async def generate_patch_with_cognee(instance, llm_client, search_type=SearchTyp
|
|||
Task(enrich_dependency_graph, task_config = { "batch_size": 50 }),
|
||||
Task(expand_dependency_graph, task_config = { "batch_size": 50 }),
|
||||
Task(add_data_points, task_config = { "batch_size": 50 }),
|
||||
Task(summarize_code, summarization_model = SummarizedContent),
|
||||
# Task(summarize_code, summarization_model = SummarizedContent),
|
||||
]
|
||||
|
||||
pipeline = run_tasks(tasks, repo_path, "cognify_code_pipeline")
|
||||
|
||||
|
||||
async for result in pipeline:
|
||||
print(result)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import argparse
|
||||
import time
|
||||
import asyncio
|
||||
|
||||
from benchmark_function import benchmark_function
|
||||
from .benchmark_function import benchmark_function
|
||||
|
||||
from cognee.modules.graph.utils import get_graph_from_model
|
||||
from cognee.tests.unit.interfaces.graph.util import (
|
||||
|
|
@ -28,9 +28,12 @@ if __name__ == "__main__":
|
|||
society = create_organization_recursive(
|
||||
"society", "Society", PERSON_NAMES, args.recursive_depth
|
||||
)
|
||||
nodes, edges = get_graph_from_model(society)
|
||||
nodes, edges = asyncio.run(get_graph_from_model(society))
|
||||
|
||||
results = benchmark_function(get_graph_from_model, society, num_runs=args.runs)
|
||||
def get_graph_from_model_sync(model):
|
||||
return asyncio.run(get_graph_from_model(model))
|
||||
|
||||
results = benchmark_function(get_graph_from_model_sync, society, num_runs=args.runs)
|
||||
print("\nBenchmark Results:")
|
||||
print(
|
||||
f"N nodes: {len(nodes)}, N edges: {len(edges)}, Recursion depth: {args.recursive_depth}"
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue