Merge remote-tracking branch 'origin/main'

Boris Arzentar 2024-11-23 14:07:44 +01:00
commit a2fa25fb60
15 changed files with 474 additions and 116 deletions

View file

@@ -4,7 +4,7 @@
[![GitHub commits](https://badgen.net/github/commits/topoteretes/cognee)](https://GitHub.com/topoteretes/cognee/commit/)
[![Github tag](https://badgen.net/github/tag/topoteretes/cognee)](https://github.com/topoteretes/cognee/tags/)
[![Downloads](https://static.pepy.tech/badge/cognee)](https://pepy.tech/project/cognee)
[![GitHub license](https://badgen.net/github/license/topoteretes/cognee)](https://github.com/topoteretes/cognee/blob/master/LICENSE)

We build for developers who need a reliable, production-ready data layer for AI applications

View file

@@ -1,8 +1,16 @@
 from datetime import datetime, timezone

 from cognee.infrastructure.engine import DataPoint
 from cognee.modules.storage.utils import copy_model


-def get_graph_from_model(data_point: DataPoint, include_root = True, added_nodes = {}, added_edges = {}):
+def get_graph_from_model(data_point: DataPoint, added_nodes=None, added_edges=None):
+    if not added_nodes:
+        added_nodes = {}
+    if not added_edges:
+        added_edges = {}
+
     nodes = []
     edges = []
@@ -12,87 +20,94 @@ def get_graph_from_model(data_point: DataPoint, include_root = True, added_nodes
     for field_name, field_value in data_point:
         if field_name == "_metadata":
             continue
-
-        if isinstance(field_value, DataPoint):
+        elif isinstance(field_value, DataPoint):
             excluded_properties.add(field_name)
-
-            property_nodes, property_edges = get_graph_from_model(field_value, True, added_nodes, added_edges)
-
-            for node in property_nodes:
-                if str(node.id) not in added_nodes:
-                    nodes.append(node)
-                    added_nodes[str(node.id)] = True
-
-            for edge in property_edges:
-                edge_key = str(edge[0]) + str(edge[1]) + edge[2]
-
-                if str(edge_key) not in added_edges:
-                    edges.append(edge)
-                    added_edges[str(edge_key)] = True
-
-            for property_node in get_own_properties(property_nodes, property_edges):
-                edge_key = str(data_point.id) + str(property_node.id) + field_name
-
-                if str(edge_key) not in added_edges:
-                    edges.append((data_point.id, property_node.id, field_name, {
-                        "source_node_id": data_point.id,
-                        "target_node_id": property_node.id,
-                        "relationship_name": field_name,
-                        "updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
-                    }))
-                    added_edges[str(edge_key)] = True
-
-            continue
-
-        if isinstance(field_value, list) and len(field_value) > 0 and isinstance(field_value[0], DataPoint):
+            nodes, edges, added_nodes, added_edges = add_nodes_and_edges(
+                data_point,
+                field_name,
+                field_value,
+                nodes,
+                edges,
+                added_nodes,
+                added_edges,
+            )
+        elif (
+            isinstance(field_value, list)
+            and len(field_value) > 0
+            and isinstance(field_value[0], DataPoint)
+        ):
             excluded_properties.add(field_name)

             for item in field_value:
-                property_nodes, property_edges = get_graph_from_model(item, True, added_nodes, added_edges)
-
-                for node in property_nodes:
-                    if str(node.id) not in added_nodes:
-                        nodes.append(node)
-                        added_nodes[str(node.id)] = True
-
-                for edge in property_edges:
-                    edge_key = str(edge[0]) + str(edge[1]) + edge[2]
-
-                    if str(edge_key) not in added_edges:
-                        edges.append(edge)
-                        added_edges[edge_key] = True
-
-                for property_node in get_own_properties(property_nodes, property_edges):
-                    edge_key = str(data_point.id) + str(property_node.id) + field_name
-
-                    if str(edge_key) not in added_edges:
-                        edges.append((data_point.id, property_node.id, field_name, {
-                            "source_node_id": data_point.id,
-                            "target_node_id": property_node.id,
-                            "relationship_name": field_name,
-                            "updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
-                            "metadata": {
-                                "type": "list"
-                            },
-                        }))
-                        added_edges[edge_key] = True
-
-            continue
-
-        data_point_properties[field_name] = field_value
+                n_edges_before = len(edges)
+                nodes, edges, added_nodes, added_edges = add_nodes_and_edges(
+                    data_point, field_name, item, nodes, edges, added_nodes, added_edges
+                )
+                edges = edges[:n_edges_before] + [
+                    (*edge[:3], {**edge[3], "metadata": {"type": "list"}})
+                    for edge in edges[n_edges_before:]
+                ]
+        else:
+            data_point_properties[field_name] = field_value

     SimpleDataPointModel = copy_model(
         type(data_point),
-        include_fields = {
+        include_fields={
             "_metadata": (dict, data_point._metadata),
         },
-        exclude_fields = excluded_properties,
+        exclude_fields=excluded_properties,
     )

-    if include_root:
-        nodes.append(SimpleDataPointModel(**data_point_properties))
+    nodes.append(SimpleDataPointModel(**data_point_properties))

     return nodes, edges
+
+
+def add_nodes_and_edges(
+    data_point, field_name, field_value, nodes, edges, added_nodes, added_edges
+):
+    property_nodes, property_edges = get_graph_from_model(
+        field_value, dict(added_nodes), dict(added_edges)
+    )
+
+    for node in property_nodes:
+        if str(node.id) not in added_nodes:
+            nodes.append(node)
+            added_nodes[str(node.id)] = True
+
+    for edge in property_edges:
+        edge_key = str(edge[0]) + str(edge[1]) + edge[2]
+
+        if str(edge_key) not in added_edges:
+            edges.append(edge)
+            added_edges[str(edge_key)] = True
+
+    for property_node in get_own_properties(property_nodes, property_edges):
+        edge_key = str(data_point.id) + str(property_node.id) + field_name
+
+        if str(edge_key) not in added_edges:
+            edges.append(
+                (
+                    data_point.id,
+                    property_node.id,
+                    field_name,
+                    {
+                        "source_node_id": data_point.id,
+                        "target_node_id": property_node.id,
+                        "relationship_name": field_name,
+                        "updated_at": datetime.now(timezone.utc).strftime(
+                            "%Y-%m-%d %H:%M:%S"
+                        ),
+                    },
+                )
+            )
+            added_edges[str(edge_key)] = True
+
+    return (nodes, edges, added_nodes, added_edges)


 def get_own_properties(property_nodes, property_edges):
     own_properties = []
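Editor's note: the headline change in this file replaces the mutable default arguments (added_nodes = {}, added_edges = {}) with None plus explicit initialization, and the new add_nodes_and_edges helper recurses on copies (dict(added_nodes)). Python evaluates default values once, at definition time, so a dict default silently shares state across calls. A minimal standalone sketch of that pitfall (illustrative only, not cognee code):

def remember_bad(value, seen={}):  # one dict, created at definition time, reused forever
    seen[value] = True
    return sorted(seen)

def remember_good(value, seen=None):  # fresh dict per call unless one is passed in
    if seen is None:
        seen = {}
    seen[value] = True
    return sorted(seen)

print(remember_bad("a"))   # ['a']
print(remember_bad("b"))   # ['a', 'b']  <- state leaked from the first call
print(remember_good("a"))  # ['a']
print(remember_good("b"))  # ['b']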

View file

@@ -1,29 +1,41 @@
+from typing import Callable
 from pydantic_core import PydanticUndefined

 from cognee.infrastructure.engine import DataPoint
 from cognee.modules.storage.utils import copy_model


-def get_model_instance_from_graph(nodes: list[DataPoint], edges: list, entity_id: str):
-    node_map = {}
-
-    for node in nodes:
-        node_map[node.id] = node
-
-    for edge in edges:
-        source_node = node_map[edge[0]]
-        target_node = node_map[edge[1]]
-        edge_label = edge[2]
-        edge_properties = edge[3] if len(edge) == 4 else {}
+def get_model_instance_from_graph(
+    nodes: list[DataPoint],
+    edges: list[tuple[str, str, str, dict[str, str]]],
+    entity_id: str,
+):
+    node_map = {node.id: node for node in nodes}
+
+    for source_node_id, target_node_id, edge_label, edge_properties in edges:
+        source_node = node_map[source_node_id]
+        target_node = node_map[target_node_id]
         edge_metadata = edge_properties.get("metadata", {})
-        edge_type = edge_metadata.get("type")
+        edge_type = edge_metadata.get("type", "default")

         if edge_type == "list":
-            NewModel = copy_model(type(source_node), { edge_label: (list[type(target_node)], PydanticUndefined) })
-            node_map[edge[0]] = NewModel(**source_node.model_dump(), **{ edge_label: [target_node] })
+            NewModel = copy_model(
+                type(source_node),
+                {edge_label: (list[type(target_node)], PydanticUndefined)},
+            )
+            source_node_dict = source_node.model_dump()
+            source_node_edge_label_values = source_node_dict.get(edge_label, [])
+            source_node_dict[edge_label] = source_node_edge_label_values + [target_node]
+
+            node_map[source_node_id] = NewModel(**source_node_dict)
         else:
-            NewModel = copy_model(type(source_node), { edge_label: (type(target_node), PydanticUndefined) })
-            node_map[edge[0]] = NewModel(**source_node.model_dump(), **{ edge_label: target_node })
+            NewModel = copy_model(
+                type(source_node), {edge_label: (type(target_node), PydanticUndefined)}
+            )
+            node_map[target_node_id] = NewModel(
+                **source_node.model_dump(), **{edge_label: target_node}
+            )

     return node_map[entity_id]
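Editor's note: the rewrite unpacks every edge as a fixed 4-tuple (the old code also tolerated 3-tuples via len(edge) == 4) and reads a "metadata": {"type": "list"} entry as the marker for one-to-many fields. A minimal sketch of the edge shape the function now expects, using IDs from the tests in this same commit:

# An edge in the 4-tuple shape the rewritten function unpacks:
# (source_node_id, target_node_id, relationship_name, properties).
edge = (
    "boris",
    "car1",
    "owns_car",
    {
        "source_node_id": "boris",
        "target_node_id": "car1",
        "relationship_name": "owns_car",
        "metadata": {"type": "list"},  # marks owns_car as a one-to-many field
    },
)
source_node_id, target_node_id, edge_label, edge_properties = edge
edge_type = edge_properties.get("metadata", {}).get("type", "default")
assert edge_type == "list"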

View file

@@ -29,7 +29,9 @@ def copy_model(model: DataPoint, include_fields: dict = {}, exclude_fields: list
         **include_fields
     }

-    return create_model(model.__name__, **final_fields)
+    model = create_model(model.__name__, **final_fields)
+    model.model_rebuild()
+
+    return model


 def get_own_properties(data_point: DataPoint):
     properties = {}
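Editor's note: copy_model builds models dynamically with create_model, so fields that refer to not-yet-defined classes stay unresolved until model_rebuild() is called; the same call appears below for the mutually recursive Organization/SocietyPerson test models. A standalone pydantic v2 sketch of the forward-reference problem (illustrative, not cognee code):

from typing import Optional
from pydantic import BaseModel

class Organization(BaseModel):
    name: str
    members: Optional[list["Member"]] = None  # "Member" does not exist yet

class Member(BaseModel):
    name: str

Organization.model_rebuild()  # resolves the "Member" forward reference
org = Organization(name="ChessClub", members=[Member(name="Sarah")])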

View file

@@ -26,7 +26,7 @@ async def main():
     await render_graph(None, include_nodes = True, include_labels = True)

-    search_results = await cognee.search(SearchType.CHUNKS, query = "Student")
+    search_results = await cognee.search(SearchType.CHUNKS, query_text = "Student")
     assert len(search_results) != 0, "The search results list is empty."

     print("\n\nExtracted chunks are:\n")
     for result in search_results:

View file

@@ -1,14 +1,9 @@
-from datetime import datetime, timezone
 from enum import Enum
 from typing import Optional

 import pytest

 from cognee.infrastructure.engine import DataPoint
-from cognee.modules.graph.utils import (
-    get_graph_from_model,
-    get_model_instance_from_graph,
-)


 class CarTypeName(Enum):
@@ -47,8 +42,8 @@ class Person(DataPoint):
     _metadata: dict = dict(index_fields=["name"])


-@pytest.fixture(scope="session")
-def graph_outputs():
+@pytest.fixture(scope="function")
+def boris():
     boris = Person(
         id="boris",
         name="Boris",
@@ -70,11 +65,4 @@
             "expires_on": "2025-11-06",
         },
     )

-    nodes, edges = get_graph_from_model(boris)
-    car, person = nodes[0], nodes[1]
-    edge = edges[0]
-    parsed_person = get_model_instance_from_graph(nodes, edges, "boris")
-
-    return (car, person, edge, parsed_person)
+    return boris
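Editor's note: narrowing the fixture from scope="session" to scope="function" means every test receives a freshly built Person, so a test that mutates it cannot affect its neighbors. A minimal sketch of the difference (illustrative, not the cognee tests):

import pytest

@pytest.fixture(scope="function")
def counter():
    return {"value": 0}  # rebuilt for every test

def test_one(counter):
    counter["value"] += 1
    assert counter["value"] == 1

def test_two(counter):
    # still 0 here; with scope="session" the mutation from test_one would leak in
    assert counter["value"] == 0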

View file

@@ -0,0 +1,37 @@
import warnings

import pytest

from cognee.modules.graph.utils import get_graph_from_model
from cognee.tests.unit.interfaces.graph.util import (
    PERSON_NAMES,
    count_society,
    create_organization_recursive,
)


@pytest.mark.parametrize("recursive_depth", [1, 2, 3])
def test_society_nodes_and_edges(recursive_depth):
    import sys

    if sys.version_info[0] == 3 and sys.version_info[1] >= 11:
        society = create_organization_recursive(
            "society", "Society", PERSON_NAMES, recursive_depth
        )
        n_organizations, n_persons = count_society(society)
        society_counts_total = n_organizations + n_persons

        nodes, edges = get_graph_from_model(society)

        assert (
            len(nodes) == society_counts_total
        ), f"{society_counts_total = } != {len(nodes) = }, not all DataPoint instances were found"
        assert len(edges) == (
            len(nodes) - 1
        ), f"{(len(nodes) - 1) = } != {len(edges) = }, there have to be n_nodes - 1 edges, as each node has exactly one parent node, except for the root node"
    else:
        warnings.warn(
            "The recursive pydantic data structure cannot be reconstructed from the graph because the 'inner' pydantic class is not defined. Hence this test is skipped. This problem is solved in Python 3.11"
        )
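Editor's note: for recursive_depth = 1 the expected counts can be checked by hand. The root Organization gets one SocietyPerson per entry in PERSON_NAMES (four names), and those persons recurse no further:

# Worked example for recursive_depth = 1:
n_organizations = 1                    # the root "society" only
n_persons = len(PERSON_NAMES)          # 4 leaf SocietyPerson nodes
n_nodes = n_organizations + n_persons  # 5 DataPoint instances in total
n_edges = n_nodes - 1                  # 4: every non-root node has exactly one parent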

View file

@@ -1,6 +1,19 @@
+from cognee.modules.graph.utils import get_graph_from_model
 from cognee.tests.unit.interfaces.graph.util import run_test_against_ground_truth

-EDGE_GROUND_TRUTH = (
+CAR_SEDAN_EDGE = (
+    "car1",
+    "sedan",
+    "is_type",
+    {
+        "source_node_id": "car1",
+        "target_node_id": "sedan",
+        "relationship_name": "is_type",
+    },
+)
+
+BORIS_CAR_EDGE_GROUND_TRUTH = (
     "boris",
     "car1",
     "owns_car",
@@ -12,6 +25,8 @@ EDGE_GROUND_TRUTH = (
     },
 )

+CAR_TYPE_GROUND_TRUTH = {"id": "sedan"}
+
 CAR_GROUND_TRUTH = {
     "id": "car1",
     "brand": "Toyota",
@@ -33,22 +48,42 @@ PERSON_GROUND_TRUTH = {
 }


-def test_extracted_person(graph_outputs):
-    (_, person, _, _) = graph_outputs
+def test_extracted_car_type(boris):
+    nodes, _ = get_graph_from_model(boris)
+    assert len(nodes) == 3
+    car_type = nodes[0]

-    run_test_against_ground_truth("person", person, PERSON_GROUND_TRUTH)
+    run_test_against_ground_truth("car_type", car_type, CAR_TYPE_GROUND_TRUTH)


-def test_extracted_car(graph_outputs):
-    (car, _, _, _) = graph_outputs
+def test_extracted_car(boris):
+    nodes, _ = get_graph_from_model(boris)
+    assert len(nodes) == 3
+    car = nodes[1]

     run_test_against_ground_truth("car", car, CAR_GROUND_TRUTH)


-def test_extracted_edge(graph_outputs):
-    (_, _, edge, _) = graph_outputs
+def test_extracted_person(boris):
+    nodes, _ = get_graph_from_model(boris)
+    assert len(nodes) == 3
+    person = nodes[2]
+
+    run_test_against_ground_truth("person", person, PERSON_GROUND_TRUTH)
+
+
+def test_extracted_car_sedan_edge(boris):
+    _, edges = get_graph_from_model(boris)
+    edge = edges[0]
+    assert CAR_SEDAN_EDGE[:3] == edge[:3], f"{CAR_SEDAN_EDGE[:3] = } != {edge[:3] = }"
+    for key, ground_truth in CAR_SEDAN_EDGE[3].items():
+        assert ground_truth == edge[3][key], f"{ground_truth = } != {edge[3][key] = }"
+
+
+def test_extracted_boris_car_edge(boris):
+    _, edges = get_graph_from_model(boris)
+    edge = edges[1]
     assert (
-        EDGE_GROUND_TRUTH[:3] == edge[:3]
-    ), f"{EDGE_GROUND_TRUTH[:3] = } != {edge[:3] = }"
-    for key, ground_truth in EDGE_GROUND_TRUTH[3].items():
+        BORIS_CAR_EDGE_GROUND_TRUTH[:3] == edge[:3]
+    ), f"{BORIS_CAR_EDGE_GROUND_TRUTH[:3] = } != {edge[:3] = }"
+    for key, ground_truth in BORIS_CAR_EDGE_GROUND_TRUTH[3].items():
         assert ground_truth == edge[3][key], f"{ground_truth = } != {edge[3][key] = }"

View file

@@ -0,0 +1,33 @@
import warnings

import pytest

from cognee.modules.graph.utils import (
    get_graph_from_model,
    get_model_instance_from_graph,
)
from cognee.tests.unit.interfaces.graph.util import (
    PERSON_NAMES,
    create_organization_recursive,
    show_first_difference,
)


@pytest.mark.parametrize("recursive_depth", [1, 2, 3])
def test_society_nodes_and_edges(recursive_depth):
    import sys

    if sys.version_info[0] == 3 and sys.version_info[1] >= 11:
        society = create_organization_recursive(
            "society", "Society", PERSON_NAMES, recursive_depth
        )
        nodes, edges = get_graph_from_model(society)
        parsed_society = get_model_instance_from_graph(nodes, edges, "society")

        assert str(society) == (str(parsed_society)), show_first_difference(
            str(society), str(parsed_society), "society", "parsed_society"
        )
    else:
        warnings.warn(
            "The recursive pydantic data structure cannot be reconstructed from the graph because the 'inner' pydantic class is not defined. Hence this test is skipped. This problem is solved in Python 3.11"
        )

View file

@@ -1,3 +1,7 @@
+from cognee.modules.graph.utils import (
+    get_graph_from_model,
+    get_model_instance_from_graph,
+)
 from cognee.tests.unit.interfaces.graph.util import run_test_against_ground_truth

 PARSED_PERSON_GROUND_TRUTH = {
@@ -21,8 +25,10 @@ CAR_GROUND_TRUTH = {
 }


-def test_parsed_person(graph_outputs):
-    (_, _, _, parsed_person) = graph_outputs
+def test_parsed_person(boris):
+    nodes, edges = get_graph_from_model(boris)
+    parsed_person = get_model_instance_from_graph(nodes, edges, "boris")

     run_test_against_ground_truth(
         "parsed_person", parsed_person, PARSED_PERSON_GROUND_TRUTH
     )

View file

@@ -1,5 +1,9 @@
+import random
+import string
 from datetime import datetime, timezone
-from typing import Any, Dict
+from typing import Any, Dict, Optional
+
+from cognee.infrastructure.engine import DataPoint


 def run_test_against_ground_truth(
@@ -21,6 +25,8 @@ def run_test_against_ground_truth(
             assert (
                 ground_truth2 == getattr(test_target_item, key)[key2]
             ), f"{test_target_item_name}/{key = }/{key2 = }: {ground_truth2 = } != {getattr(test_target_item, key)[key2] = }"
+        elif isinstance(ground_truth, list):
+            raise NotImplementedError("Currently not implemented for 'list'")
         else:
             assert ground_truth == getattr(
                 test_target_item, key
@@ -28,3 +34,117 @@ def run_test_against_ground_truth(
     time_delta = datetime.now(timezone.utc) - getattr(test_target_item, "updated_at")
     assert time_delta.total_seconds() < 60, f"{ time_delta.total_seconds() = }"
+
+
+class Organization(DataPoint):
+    id: str
+    name: str
+    members: Optional[list["SocietyPerson"]]
+
+
+class SocietyPerson(DataPoint):
+    id: str
+    name: str
+    memberships: Optional[list[Organization]]
+
+
+SocietyPerson.model_rebuild()
+Organization.model_rebuild()
+
+
+ORGANIZATION_NAMES = [
+    "ChessClub",
+    "RowingClub",
+    "TheatreTroupe",
+    "PoliticalParty",
+    "Charity",
+    "FanClub",
+    "FilmClub",
+    "NeighborhoodGroup",
+    "LocalCouncil",
+    "Band",
+]
+PERSON_NAMES = ["Sarah", "Anna", "John", "Sam"]
+
+
+def create_society_person_recursive(id, name, organization_names, max_depth, depth=0):
+    id_suffix = "".join(random.choice(string.ascii_lowercase) for _ in range(10))
+    if depth < max_depth:
+        memberships = [
+            create_organization_recursive(
+                f"{org_name}-{depth}-{id_suffix}",
+                org_name.lower(),
+                PERSON_NAMES,
+                max_depth,
+                depth + 1,
+            )
+            for org_name in organization_names
+        ]
+    else:
+        memberships = None
+    return SocietyPerson(id=id, name=f"{name}{depth}", memberships=memberships)
+
+
+def create_organization_recursive(id, name, member_names, max_depth, depth=0):
+    id_suffix = "".join(random.choice(string.ascii_lowercase) for _ in range(10))
+    if depth < max_depth:
+        members = [
+            create_society_person_recursive(
+                f"{member_name}-{depth}-{id_suffix}",
+                member_name.lower(),
+                ORGANIZATION_NAMES,
+                max_depth,
+                depth + 1,
+            )
+            for member_name in member_names
+        ]
+    else:
+        members = None
+    return Organization(id=id, name=f"{name}{depth}", members=members)
+
+
+def count_society(obj):
+    if isinstance(obj, SocietyPerson):
+        if obj.memberships is not None:
+            organization_counts, society_person_counts = zip(
+                *[count_society(organization) for organization in obj.memberships]
+            )
+            organization_count = sum(organization_counts)
+            society_person_count = sum(society_person_counts) + 1
+            return (organization_count, society_person_count)
+        else:
+            return (0, 1)
+    if isinstance(obj, Organization):
+        if obj.members is not None:
+            organization_counts, society_person_counts = zip(
+                *[count_society(organization) for organization in obj.members]
+            )
+            organization_count = sum(organization_counts) + 1
+            society_person_count = sum(society_person_counts)
+            return (organization_count, society_person_count)
+        else:
+            return (1, 0)
+    else:
+        raise Exception("Not allowed")
+
+
+def show_first_difference(str1, str2, str1_name, str2_name, context=30):
+    for i, (c1, c2) in enumerate(zip(str1, str2)):
+        if c1 != c2:
+            start = max(0, i - context)
+            end1 = min(len(str1), i + context + 1)
+            end2 = min(len(str2), i + context + 1)
+            if i > 0:
+                return f"identical: '{str1[start:i-1]}' | {str1_name}: '{str1[i-1:end1]}'... != {str2_name}: '{str2[i-1:end2]}'..."
+            else:
+                return f"{str1_name} and {str2_name} have no overlap in characters"
+    if len(str1) > len(str2):
+        return f"{str2_name} is identical up to the {i}th character, missing afterwards '{str1[i:i+context]}'..."
+    if len(str2) > len(str1):
+        return f"{str1_name} is identical up to the {i}th character, missing afterwards '{str2[i:i+context]}'..."
+    else:
+        return f"{str1_name} and {str2_name} are identical."

View file

@@ -59,7 +59,7 @@ await cognee.add(text) # Add a new piece of information
 await cognee.cognify() # Use LLMs and cognee to create knowledge

-search_results = await cognee.search("INSIGHTS", {'query': 'Tell me about NLP'}) # Query cognee for the knowledge
+search_results = await cognee.search(SearchType.INSIGHTS, query_text='Tell me about NLP') # Query cognee for the knowledge

 for result_text in search_results:
     print(result_text)
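Editor's note: the updated snippet passes the SearchType enum instead of the bare string "INSIGHTS", so it presumably relies on an import along these lines; the exact path is an assumption about the surrounding README, not part of this diff:

from cognee.api.v1.search import SearchType  # assumed import; path may differ between cognee versions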

View file

@@ -209,7 +209,7 @@ async def main(enable_steps):
     if enable_steps.get("search_insights"):
         search_results = await cognee.search(
             SearchType.INSIGHTS,
-            {'query': 'Which applicant has the most relevant experience in data science?'}
+            query_text='Which applicant has the most relevant experience in data science?'
         )
         print("Search results:")
         for result_text in search_results:

View file

@@ -0,0 +1,67 @@
import statistics
import time
import tracemalloc
from typing import Any, Callable, Dict

import psutil


def benchmark_function(func: Callable, *args, num_runs: int = 5) -> Dict[str, Any]:
    """
    Benchmark a function for memory usage and computational performance.

    Args:
        func: Function to benchmark
        *args: Arguments to pass to the function
        num_runs: Number of times to run the benchmark

    Returns:
        Dictionary containing benchmark metrics
    """
    execution_times = []
    peak_memory_usages = []
    cpu_percentages = []

    process = psutil.Process()

    for _ in range(num_runs):
        # Start memory tracking
        tracemalloc.start()
        initial_memory = process.memory_info().rss

        # Measure execution time and CPU usage
        start_time = time.perf_counter()
        start_cpu_time = process.cpu_times()

        result = func(*args)

        end_cpu_time = process.cpu_times()
        end_time = time.perf_counter()

        # Calculate metrics
        execution_time = end_time - start_time
        cpu_time = (end_cpu_time.user + end_cpu_time.system) - (
            start_cpu_time.user + start_cpu_time.system
        )
        current, peak = tracemalloc.get_traced_memory()
        final_memory = process.memory_info().rss
        memory_used = final_memory - initial_memory

        # Store results
        execution_times.append(execution_time)
        peak_memory_usages.append(peak / 1024 / 1024)  # Convert to MB
        cpu_percentages.append((cpu_time / execution_time) * 100)

        tracemalloc.stop()

    analysis = {
        "mean_execution_time": statistics.mean(execution_times),
        "mean_peak_memory_mb": statistics.mean(peak_memory_usages),
        "mean_cpu_percent": statistics.mean(cpu_percentages),
        "num_runs": num_runs,
    }

    if num_runs > 1:
        analysis["std_execution_time"] = statistics.stdev(execution_times)

    return analysis
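Editor's note: a minimal usage sketch of benchmark_function with an arbitrary stand-in workload (the real caller is the benchmark script below):

def workload(n):
    return sum(i * i for i in range(n))

stats = benchmark_function(workload, 1_000_000, num_runs=3)
print(f"{stats['mean_execution_time']:.4f} s mean over {stats['num_runs']} runs")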

View file

@@ -0,0 +1,43 @@
import argparse
import time

from benchmark_function import benchmark_function

from cognee.modules.graph.utils import get_graph_from_model
from cognee.tests.unit.interfaces.graph.util import (
    PERSON_NAMES,
    create_organization_recursive,
)

# Example usage:
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Benchmark graph model with configurable recursive depth"
    )
    parser.add_argument(
        "--recursive-depth",
        type=int,
        default=3,
        help="Recursive depth for graph generation (default: 3)",
    )
    parser.add_argument(
        "--runs", type=int, default=5, help="Number of benchmark runs (default: 5)"
    )
    args = parser.parse_args()

    society = create_organization_recursive(
        "society", "Society", PERSON_NAMES, args.recursive_depth
    )
    nodes, edges = get_graph_from_model(society)

    results = benchmark_function(get_graph_from_model, society, num_runs=args.runs)
    print("\nBenchmark Results:")
    print(
        f"N nodes: {len(nodes)}, N edges: {len(edges)}, Recursion depth: {args.recursive_depth}"
    )
    print(f"Mean Peak Memory: {results['mean_peak_memory_mb']:.2f} MB")
    print(f"Mean CPU Usage: {results['mean_cpu_percent']:.2f}%")
    print(f"Mean Execution Time: {results['mean_execution_time']:.4f} seconds")

    if "std_execution_time" in results:
        print(f"Execution Time Std: {results['std_execution_time']:.4f} seconds")