fix: remove duplicate nodes and edges before saving; Fix FalkorDB vector index;
This commit is contained in:
parent
925346986e
commit
11acabdb6a
14 changed files with 118 additions and 120 deletions
|
|
@ -67,8 +67,9 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
|
||||||
node_properties = await self.stringify_properties({
|
node_properties = await self.stringify_properties({
|
||||||
**data_point.model_dump(),
|
**data_point.model_dump(),
|
||||||
**({
|
**({
|
||||||
property_names[index]: (vectorized_values[index] if index in vectorized_values else None) \
|
property_names[index]: (vectorized_values[index] \
|
||||||
for index in range(len(property_names)) \
|
if index < len(vectorized_values) else getattr(data_point, property_name, None)) \
|
||||||
|
for index, property_name in enumerate(property_names)
|
||||||
}),
|
}),
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
@ -111,8 +112,8 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
|
||||||
property_value = getattr(data_point, property_name, None)
|
property_value = getattr(data_point, property_name, None)
|
||||||
|
|
||||||
if property_value is not None:
|
if property_value is not None:
|
||||||
|
vector_map[key][property_name] = len(embeddable_values)
|
||||||
embeddable_values.append(property_value)
|
embeddable_values.append(property_value)
|
||||||
vector_map[key][property_name] = len(embeddable_values) - 1
|
|
||||||
else:
|
else:
|
||||||
vector_map[key][property_name] = None
|
vector_map[key][property_name] = None
|
||||||
|
|
||||||
|
|
@ -123,7 +124,9 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
|
||||||
data_point,
|
data_point,
|
||||||
[
|
[
|
||||||
vectorized_values[vector_map[str(data_point.id)][property_name]] \
|
vectorized_values[vector_map[str(data_point.id)][property_name]] \
|
||||||
for property_name in DataPoint.get_embeddable_property_names(data_point)
|
if vector_map[str(data_point.id)][property_name] is not None \
|
||||||
|
else None \
|
||||||
|
for property_name in DataPoint.get_embeddable_property_names(data_point)
|
||||||
],
|
],
|
||||||
) for data_point in data_points
|
) for data_point in data_points
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -3,3 +3,4 @@ from .get_graph_from_model import get_graph_from_model
|
||||||
from .get_model_instance_from_graph import get_model_instance_from_graph
|
from .get_model_instance_from_graph import get_model_instance_from_graph
|
||||||
from .retrieve_existing_edges import retrieve_existing_edges
|
from .retrieve_existing_edges import retrieve_existing_edges
|
||||||
from .convert_node_to_data_point import convert_node_to_data_point
|
from .convert_node_to_data_point import convert_node_to_data_point
|
||||||
|
from .deduplicate_nodes_and_edges import deduplicate_nodes_and_edges
|
||||||
|
|
|
||||||
19
cognee/modules/graph/utils/deduplicate_nodes_and_edges.py
Normal file
19
cognee/modules/graph/utils/deduplicate_nodes_and_edges.py
Normal file
|
|
@ -0,0 +1,19 @@
|
||||||
|
from cognee.infrastructure.engine import DataPoint
|
||||||
|
|
||||||
|
def deduplicate_nodes_and_edges(nodes: list[DataPoint], edges: list[dict]):
|
||||||
|
added_entities = {}
|
||||||
|
final_nodes = []
|
||||||
|
final_edges = []
|
||||||
|
|
||||||
|
for node in nodes:
|
||||||
|
if str(node.id) not in added_entities:
|
||||||
|
final_nodes.append(node)
|
||||||
|
added_entities[str(node.id)] = True
|
||||||
|
|
||||||
|
for edge in edges:
|
||||||
|
edge_key = str(edge[0]) + str(edge[2]) + str(edge[1])
|
||||||
|
if edge_key not in added_entities:
|
||||||
|
final_edges.append(edge)
|
||||||
|
added_entities[edge_key] = True
|
||||||
|
|
||||||
|
return final_nodes, final_edges
|
||||||
|
|
@ -8,18 +8,20 @@ async def get_graph_from_model(
|
||||||
include_root = True,
|
include_root = True,
|
||||||
added_nodes = None,
|
added_nodes = None,
|
||||||
added_edges = None,
|
added_edges = None,
|
||||||
|
visited_properties = None,
|
||||||
):
|
):
|
||||||
if data_point.id in added_nodes:
|
|
||||||
return [], []
|
|
||||||
|
|
||||||
nodes = []
|
nodes = []
|
||||||
edges = []
|
edges = []
|
||||||
added_nodes = added_nodes or {}
|
added_nodes = added_nodes or {}
|
||||||
added_edges = added_edges or {}
|
added_edges = added_edges or {}
|
||||||
|
visited_properties = visited_properties or {}
|
||||||
|
|
||||||
data_point_properties = {}
|
data_point_properties = {}
|
||||||
excluded_properties = set()
|
excluded_properties = set()
|
||||||
|
|
||||||
|
if include_root:
|
||||||
|
added_nodes[str(data_point.id)] = True
|
||||||
|
|
||||||
for field_name, field_value in data_point:
|
for field_name, field_value in data_point:
|
||||||
if field_name == "_metadata":
|
if field_name == "_metadata":
|
||||||
continue
|
continue
|
||||||
|
|
@ -30,7 +32,15 @@ async def get_graph_from_model(
|
||||||
|
|
||||||
if isinstance(field_value, DataPoint):
|
if isinstance(field_value, DataPoint):
|
||||||
excluded_properties.add(field_name)
|
excluded_properties.add(field_name)
|
||||||
nodes, edges, added_nodes, added_edges = add_nodes_and_edges(
|
|
||||||
|
property_key = f"{str(data_point.id)}{field_name}{str(field_value.id)}"
|
||||||
|
|
||||||
|
if property_key in visited_properties:
|
||||||
|
continue
|
||||||
|
|
||||||
|
visited_properties[property_key] = True
|
||||||
|
|
||||||
|
nodes, edges = await add_nodes_and_edges(
|
||||||
data_point,
|
data_point,
|
||||||
field_name,
|
field_name,
|
||||||
field_value,
|
field_value,
|
||||||
|
|
@ -38,77 +48,33 @@ async def get_graph_from_model(
|
||||||
edges,
|
edges,
|
||||||
added_nodes,
|
added_nodes,
|
||||||
added_edges,
|
added_edges,
|
||||||
|
visited_properties,
|
||||||
)
|
)
|
||||||
|
|
||||||
property_nodes, property_edges = await get_graph_from_model(
|
|
||||||
field_value,
|
|
||||||
True,
|
|
||||||
added_nodes,
|
|
||||||
added_edges,
|
|
||||||
)
|
|
||||||
|
|
||||||
for node in property_nodes:
|
|
||||||
if str(node.id) not in added_nodes:
|
|
||||||
nodes.append(node)
|
|
||||||
added_nodes[str(node.id)] = True
|
|
||||||
|
|
||||||
for edge in property_edges:
|
|
||||||
edge_key = str(edge[0]) + str(edge[1]) + edge[2]
|
|
||||||
|
|
||||||
if str(edge_key) not in added_edges:
|
|
||||||
edges.append(edge)
|
|
||||||
added_edges[str(edge_key)] = True
|
|
||||||
|
|
||||||
for property_node in get_own_properties(property_nodes, property_edges):
|
|
||||||
edge_key = str(data_point.id) + str(property_node.id) + field_name
|
|
||||||
|
|
||||||
if str(edge_key) not in added_edges:
|
|
||||||
edges.append((data_point.id, property_node.id, field_name, {
|
|
||||||
"source_node_id": data_point.id,
|
|
||||||
"target_node_id": property_node.id,
|
|
||||||
"relationship_name": field_name,
|
|
||||||
"updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
|
|
||||||
}))
|
|
||||||
added_edges[str(edge_key)] = True
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if isinstance(field_value, list) and len(field_value) > 0 and isinstance(field_value[0], DataPoint):
|
if isinstance(field_value, list) and len(field_value) > 0 and isinstance(field_value[0], DataPoint):
|
||||||
excluded_properties.add(field_name)
|
excluded_properties.add(field_name)
|
||||||
|
|
||||||
for item in field_value:
|
for field_value_item in field_value:
|
||||||
property_nodes, property_edges = await get_graph_from_model(
|
property_key = f"{str(data_point.id)}{field_name}{str(field_value_item.id)}"
|
||||||
item,
|
|
||||||
True,
|
if property_key in visited_properties:
|
||||||
|
continue
|
||||||
|
|
||||||
|
visited_properties[property_key] = True
|
||||||
|
|
||||||
|
nodes, edges = await add_nodes_and_edges(
|
||||||
|
data_point,
|
||||||
|
field_name,
|
||||||
|
field_value_item,
|
||||||
|
nodes,
|
||||||
|
edges,
|
||||||
added_nodes,
|
added_nodes,
|
||||||
added_edges,
|
added_edges,
|
||||||
|
visited_properties,
|
||||||
)
|
)
|
||||||
|
|
||||||
for node in property_nodes:
|
|
||||||
if str(node.id) not in added_nodes:
|
|
||||||
nodes.append(node)
|
|
||||||
added_nodes[str(node.id)] = True
|
|
||||||
|
|
||||||
for edge in property_edges:
|
|
||||||
edge_key = str(edge[0]) + str(edge[1]) + edge[2]
|
|
||||||
|
|
||||||
if str(edge_key) not in added_edges:
|
|
||||||
edges.append(edge)
|
|
||||||
added_edges[edge_key] = True
|
|
||||||
|
|
||||||
for property_node in get_own_properties(property_nodes, property_edges):
|
|
||||||
edge_key = str(data_point.id) + str(property_node.id) + field_name
|
|
||||||
|
|
||||||
if str(edge_key) not in added_edges:
|
|
||||||
edges.append((data_point.id, property_node.id, field_name, {
|
|
||||||
"source_node_id": data_point.id,
|
|
||||||
"target_node_id": property_node.id,
|
|
||||||
"relationship_name": field_name,
|
|
||||||
"updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
|
|
||||||
"metadata": {
|
|
||||||
"type": "list"
|
|
||||||
},
|
|
||||||
}))
|
|
||||||
added_edges[edge_key] = True
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
data_point_properties[field_name] = field_value
|
data_point_properties[field_name] = field_value
|
||||||
|
|
@ -128,12 +94,22 @@ async def get_graph_from_model(
|
||||||
return nodes, edges
|
return nodes, edges
|
||||||
|
|
||||||
|
|
||||||
def add_nodes_and_edges(
|
async def add_nodes_and_edges(
|
||||||
data_point, field_name, field_value, nodes, edges, added_nodes, added_edges
|
data_point,
|
||||||
|
field_name,
|
||||||
|
field_value,
|
||||||
|
nodes,
|
||||||
|
edges,
|
||||||
|
added_nodes,
|
||||||
|
added_edges,
|
||||||
|
visited_properties,
|
||||||
):
|
):
|
||||||
|
property_nodes, property_edges = await get_graph_from_model(
|
||||||
property_nodes, property_edges = get_graph_from_model(
|
field_value,
|
||||||
field_value, dict(added_nodes), dict(added_edges)
|
True,
|
||||||
|
added_nodes,
|
||||||
|
added_edges,
|
||||||
|
visited_properties,
|
||||||
)
|
)
|
||||||
|
|
||||||
for node in property_nodes:
|
for node in property_nodes:
|
||||||
|
|
@ -169,7 +145,7 @@ def add_nodes_and_edges(
|
||||||
)
|
)
|
||||||
added_edges[str(edge_key)] = True
|
added_edges[str(edge_key)] = True
|
||||||
|
|
||||||
return (nodes, edges, added_nodes, added_edges)
|
return (nodes, edges)
|
||||||
|
|
||||||
|
|
||||||
def get_own_properties(property_nodes, property_edges):
|
def get_own_properties(property_nodes, property_edges):
|
||||||
|
|
|
||||||
|
|
@ -2,10 +2,12 @@ from typing import List, Optional
|
||||||
from cognee.infrastructure.engine import DataPoint
|
from cognee.infrastructure.engine import DataPoint
|
||||||
|
|
||||||
class Repository(DataPoint):
|
class Repository(DataPoint):
|
||||||
|
__tablename__ = "Repository"
|
||||||
path: str
|
path: str
|
||||||
type: Optional[str] = "Repository"
|
type: Optional[str] = "Repository"
|
||||||
|
|
||||||
class CodeFile(DataPoint):
|
class CodeFile(DataPoint):
|
||||||
|
__tablename__ = "CodeFile"
|
||||||
extracted_id: str # actually file path
|
extracted_id: str # actually file path
|
||||||
type: Optional[str] = "CodeFile"
|
type: Optional[str] = "CodeFile"
|
||||||
source_code: Optional[str] = None
|
source_code: Optional[str] = None
|
||||||
|
|
@ -19,6 +21,7 @@ class CodeFile(DataPoint):
|
||||||
}
|
}
|
||||||
|
|
||||||
class CodePart(DataPoint):
|
class CodePart(DataPoint):
|
||||||
|
__tablename__ = "CodePart"
|
||||||
# part_of: Optional[CodeFile]
|
# part_of: Optional[CodeFile]
|
||||||
source_code: str
|
source_code: str
|
||||||
type: Optional[str] = "CodePart"
|
type: Optional[str] = "CodePart"
|
||||||
|
|
|
||||||
|
|
@ -57,13 +57,13 @@ async def get_repo_file_dependencies(repo_path: str) -> AsyncGenerator[list, Non
|
||||||
py_files_dict = await get_py_files_dict(repo_path)
|
py_files_dict = await get_py_files_dict(repo_path)
|
||||||
|
|
||||||
repo = Repository(
|
repo = Repository(
|
||||||
id=uuid5(NAMESPACE_OID, repo_path),
|
id = uuid5(NAMESPACE_OID, repo_path),
|
||||||
path=repo_path,
|
path = repo_path,
|
||||||
)
|
)
|
||||||
|
|
||||||
yield repo
|
yield repo
|
||||||
|
|
||||||
with ProcessPoolExecutor(max_workers=12) as executor:
|
with ProcessPoolExecutor(max_workers = 12) as executor:
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
tasks = [
|
tasks = [
|
||||||
|
|
@ -84,15 +84,16 @@ async def get_repo_file_dependencies(repo_path: str) -> AsyncGenerator[list, Non
|
||||||
source_code = metadata.get("source_code")
|
source_code = metadata.get("source_code")
|
||||||
|
|
||||||
yield CodeFile(
|
yield CodeFile(
|
||||||
id=uuid5(NAMESPACE_OID, file_path),
|
id = uuid5(NAMESPACE_OID, file_path),
|
||||||
source_code=source_code,
|
source_code = source_code,
|
||||||
extracted_id=file_path,
|
extracted_id = file_path,
|
||||||
part_of=repo,
|
part_of = repo,
|
||||||
depends_on=[
|
depends_on = [
|
||||||
CodeFile(
|
CodeFile(
|
||||||
id=uuid5(NAMESPACE_OID, dependency),
|
id = uuid5(NAMESPACE_OID, dependency),
|
||||||
extracted_id=dependency,
|
extracted_id = dependency,
|
||||||
part_of=repo,
|
part_of = repo,
|
||||||
|
source_code = py_files_dict.get(dependency, {}).get("source_code"),
|
||||||
) for dependency in dependencies
|
) for dependency in dependencies
|
||||||
] if dependencies else None,
|
] if dependencies else None,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
from cognee.infrastructure.engine import DataPoint
|
from cognee.infrastructure.engine import DataPoint
|
||||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
from cognee.modules.graph.utils import get_graph_from_model
|
from cognee.modules.graph.utils import deduplicate_nodes_and_edges, get_graph_from_model
|
||||||
from .index_data_points import index_data_points
|
from .index_data_points import index_data_points
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -17,9 +17,11 @@ async def add_data_points(data_points: list[DataPoint]):
|
||||||
nodes.extend(result_nodes)
|
nodes.extend(result_nodes)
|
||||||
edges.extend(result_edges)
|
edges.extend(result_edges)
|
||||||
|
|
||||||
|
nodes, edges = deduplicate_nodes_and_edges(nodes, edges)
|
||||||
|
|
||||||
graph_engine = await get_graph_engine()
|
graph_engine = await get_graph_engine()
|
||||||
|
|
||||||
await index_data_points(data_points)
|
await index_data_points(nodes)
|
||||||
|
|
||||||
await graph_engine.add_nodes(nodes)
|
await graph_engine.add_nodes(nodes)
|
||||||
await graph_engine.add_edges(edges)
|
await graph_engine.add_edges(edges)
|
||||||
|
|
|
||||||
|
|
@ -8,16 +8,7 @@ async def index_data_points(data_points: list[DataPoint]):
|
||||||
|
|
||||||
vector_engine = get_vector_engine()
|
vector_engine = get_vector_engine()
|
||||||
|
|
||||||
flat_data_points: list[DataPoint] = []
|
for data_point in data_points:
|
||||||
|
|
||||||
results = await asyncio.gather(*[
|
|
||||||
get_data_points_from_model(data_point) for data_point in data_points
|
|
||||||
])
|
|
||||||
|
|
||||||
for result in results:
|
|
||||||
flat_data_points.extend(result)
|
|
||||||
|
|
||||||
for data_point in flat_data_points:
|
|
||||||
data_point_type = type(data_point)
|
data_point_type = type(data_point)
|
||||||
|
|
||||||
for field_name in data_point._metadata["index_fields"]:
|
for field_name in data_point._metadata["index_fields"]:
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ from cognee.tests.unit.interfaces.graph.util import (
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("recursive_depth", [1, 2, 3])
|
@pytest.mark.parametrize("recursive_depth", [1, 2, 3])
|
||||||
def test_society_nodes_and_edges(recursive_depth):
|
async def test_society_nodes_and_edges(recursive_depth):
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
if sys.version_info[0] == 3 and sys.version_info[1] >= 11:
|
if sys.version_info[0] == 3 and sys.version_info[1] >= 11:
|
||||||
|
|
@ -22,7 +22,7 @@ def test_society_nodes_and_edges(recursive_depth):
|
||||||
n_organizations, n_persons = count_society(society)
|
n_organizations, n_persons = count_society(society)
|
||||||
society_counts_total = n_organizations + n_persons
|
society_counts_total = n_organizations + n_persons
|
||||||
|
|
||||||
nodes, edges = get_graph_from_model(society)
|
nodes, edges = await get_graph_from_model(society)
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
len(nodes) == society_counts_total
|
len(nodes) == society_counts_total
|
||||||
|
|
|
||||||
|
|
@ -48,29 +48,29 @@ PERSON_GROUND_TRUTH = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def test_extracted_car_type(boris):
|
async def test_extracted_car_type(boris):
|
||||||
nodes, _ = get_graph_from_model(boris)
|
nodes, _ = await get_graph_from_model(boris)
|
||||||
assert len(nodes) == 3
|
assert len(nodes) == 3
|
||||||
car_type = nodes[0]
|
car_type = nodes[0]
|
||||||
run_test_against_ground_truth("car_type", car_type, CAR_TYPE_GROUND_TRUTH)
|
run_test_against_ground_truth("car_type", car_type, CAR_TYPE_GROUND_TRUTH)
|
||||||
|
|
||||||
|
|
||||||
def test_extracted_car(boris):
|
async def test_extracted_car(boris):
|
||||||
nodes, _ = get_graph_from_model(boris)
|
nodes, _ = await get_graph_from_model(boris)
|
||||||
assert len(nodes) == 3
|
assert len(nodes) == 3
|
||||||
car = nodes[1]
|
car = nodes[1]
|
||||||
run_test_against_ground_truth("car", car, CAR_GROUND_TRUTH)
|
run_test_against_ground_truth("car", car, CAR_GROUND_TRUTH)
|
||||||
|
|
||||||
|
|
||||||
def test_extracted_person(boris):
|
async def test_extracted_person(boris):
|
||||||
nodes, _ = get_graph_from_model(boris)
|
nodes, _ = await get_graph_from_model(boris)
|
||||||
assert len(nodes) == 3
|
assert len(nodes) == 3
|
||||||
person = nodes[2]
|
person = nodes[2]
|
||||||
run_test_against_ground_truth("person", person, PERSON_GROUND_TRUTH)
|
run_test_against_ground_truth("person", person, PERSON_GROUND_TRUTH)
|
||||||
|
|
||||||
|
|
||||||
def test_extracted_car_sedan_edge(boris):
|
async def test_extracted_car_sedan_edge(boris):
|
||||||
_, edges = get_graph_from_model(boris)
|
_, edges = await get_graph_from_model(boris)
|
||||||
edge = edges[0]
|
edge = edges[0]
|
||||||
|
|
||||||
assert CAR_SEDAN_EDGE[:3] == edge[:3], f"{CAR_SEDAN_EDGE[:3] = } != {edge[:3] = }"
|
assert CAR_SEDAN_EDGE[:3] == edge[:3], f"{CAR_SEDAN_EDGE[:3] = } != {edge[:3] = }"
|
||||||
|
|
@ -78,8 +78,8 @@ def test_extracted_car_sedan_edge(boris):
|
||||||
assert ground_truth == edge[3][key], f"{ground_truth = } != {edge[3][key] = }"
|
assert ground_truth == edge[3][key], f"{ground_truth = } != {edge[3][key] = }"
|
||||||
|
|
||||||
|
|
||||||
def test_extracted_boris_car_edge(boris):
|
async def test_extracted_boris_car_edge(boris):
|
||||||
_, edges = get_graph_from_model(boris)
|
_, edges = await get_graph_from_model(boris)
|
||||||
edge = edges[1]
|
edge = edges[1]
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
|
|
|
||||||
|
|
@ -14,14 +14,14 @@ from cognee.tests.unit.interfaces.graph.util import (
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("recursive_depth", [1, 2, 3])
|
@pytest.mark.parametrize("recursive_depth", [1, 2, 3])
|
||||||
def test_society_nodes_and_edges(recursive_depth):
|
async def test_society_nodes_and_edges(recursive_depth):
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
if sys.version_info[0] == 3 and sys.version_info[1] >= 11:
|
if sys.version_info[0] == 3 and sys.version_info[1] >= 11:
|
||||||
society = create_organization_recursive(
|
society = create_organization_recursive(
|
||||||
"society", "Society", PERSON_NAMES, recursive_depth
|
"society", "Society", PERSON_NAMES, recursive_depth
|
||||||
)
|
)
|
||||||
nodes, edges = get_graph_from_model(society)
|
nodes, edges = await get_graph_from_model(society)
|
||||||
parsed_society = get_model_instance_from_graph(nodes, edges, "society")
|
parsed_society = get_model_instance_from_graph(nodes, edges, "society")
|
||||||
|
|
||||||
assert str(society) == (str(parsed_society)), show_first_difference(
|
assert str(society) == (str(parsed_society)), show_first_difference(
|
||||||
|
|
|
||||||
|
|
@ -25,8 +25,8 @@ CAR_GROUND_TRUTH = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def test_parsed_person(boris):
|
async def test_parsed_person(boris):
|
||||||
nodes, edges = get_graph_from_model(boris)
|
nodes, edges = await get_graph_from_model(boris)
|
||||||
parsed_person = get_model_instance_from_graph(nodes, edges, "boris")
|
parsed_person = get_model_instance_from_graph(nodes, edges, "boris")
|
||||||
|
|
||||||
run_test_against_ground_truth(
|
run_test_against_ground_truth(
|
||||||
|
|
|
||||||
|
|
@ -43,7 +43,6 @@ def check_install_package(package_name):
|
||||||
|
|
||||||
|
|
||||||
async def generate_patch_with_cognee(instance, llm_client, search_type=SearchType.CHUNKS):
|
async def generate_patch_with_cognee(instance, llm_client, search_type=SearchType.CHUNKS):
|
||||||
|
|
||||||
await cognee.prune.prune_data()
|
await cognee.prune.prune_data()
|
||||||
await cognee.prune.prune_system()
|
await cognee.prune.prune_system()
|
||||||
|
|
||||||
|
|
@ -57,7 +56,7 @@ async def generate_patch_with_cognee(instance, llm_client, search_type=SearchTyp
|
||||||
Task(enrich_dependency_graph, task_config = { "batch_size": 50 }),
|
Task(enrich_dependency_graph, task_config = { "batch_size": 50 }),
|
||||||
Task(expand_dependency_graph, task_config = { "batch_size": 50 }),
|
Task(expand_dependency_graph, task_config = { "batch_size": 50 }),
|
||||||
Task(add_data_points, task_config = { "batch_size": 50 }),
|
Task(add_data_points, task_config = { "batch_size": 50 }),
|
||||||
Task(summarize_code, summarization_model = SummarizedContent),
|
# Task(summarize_code, summarization_model = SummarizedContent),
|
||||||
]
|
]
|
||||||
|
|
||||||
pipeline = run_tasks(tasks, repo_path, "cognify_code_pipeline")
|
pipeline = run_tasks(tasks, repo_path, "cognify_code_pipeline")
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
import argparse
|
import argparse
|
||||||
import time
|
import asyncio
|
||||||
|
|
||||||
from benchmark_function import benchmark_function
|
from .benchmark_function import benchmark_function
|
||||||
|
|
||||||
from cognee.modules.graph.utils import get_graph_from_model
|
from cognee.modules.graph.utils import get_graph_from_model
|
||||||
from cognee.tests.unit.interfaces.graph.util import (
|
from cognee.tests.unit.interfaces.graph.util import (
|
||||||
|
|
@ -28,9 +28,12 @@ if __name__ == "__main__":
|
||||||
society = create_organization_recursive(
|
society = create_organization_recursive(
|
||||||
"society", "Society", PERSON_NAMES, args.recursive_depth
|
"society", "Society", PERSON_NAMES, args.recursive_depth
|
||||||
)
|
)
|
||||||
nodes, edges = get_graph_from_model(society)
|
nodes, edges = asyncio.run(get_graph_from_model(society))
|
||||||
|
|
||||||
results = benchmark_function(get_graph_from_model, society, num_runs=args.runs)
|
def get_graph_from_model_sync(model):
|
||||||
|
return asyncio.run(get_graph_from_model(model))
|
||||||
|
|
||||||
|
results = benchmark_function(get_graph_from_model_sync, society, num_runs=args.runs)
|
||||||
print("\nBenchmark Results:")
|
print("\nBenchmark Results:")
|
||||||
print(
|
print(
|
||||||
f"N nodes: {len(nodes)}, N edges: {len(edges)}, Recursion depth: {args.recursive_depth}"
|
f"N nodes: {len(nodes)}, N edges: {len(edges)}, Recursion depth: {args.recursive_depth}"
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue