Merge remote-tracking branch 'origin/main'
This commit is contained in:
commit
a2fa25fb60
15 changed files with 474 additions and 116 deletions
|
|
@ -4,7 +4,7 @@
|
|||
[](https://GitHub.com/topoteretes/cognee/commit/)
|
||||
[](https://github.com/topoteretes/cognee/tags/)
|
||||
[](https://pepy.tech/project/cognee)
|
||||
[](https://github.com/topoteretes/cognee/blob/master/LICENSE)
|
||||
|
||||
|
||||
|
||||
We build for developers who need a reliable, production-ready data layer for AI applications
|
||||
|
|
|
|||
|
|
@ -1,8 +1,16 @@
|
|||
from datetime import datetime, timezone
|
||||
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.modules.storage.utils import copy_model
|
||||
|
||||
def get_graph_from_model(data_point: DataPoint, include_root = True, added_nodes = {}, added_edges = {}):
|
||||
|
||||
def get_graph_from_model(data_point: DataPoint, added_nodes=None, added_edges=None):
|
||||
|
||||
if not added_nodes:
|
||||
added_nodes = {}
|
||||
if not added_edges:
|
||||
added_edges = {}
|
||||
|
||||
nodes = []
|
||||
edges = []
|
||||
|
||||
|
|
@ -12,87 +20,94 @@ def get_graph_from_model(data_point: DataPoint, include_root = True, added_nodes
|
|||
for field_name, field_value in data_point:
|
||||
if field_name == "_metadata":
|
||||
continue
|
||||
|
||||
if isinstance(field_value, DataPoint):
|
||||
elif isinstance(field_value, DataPoint):
|
||||
excluded_properties.add(field_name)
|
||||
nodes, edges, added_nodes, added_edges = add_nodes_and_edges(
|
||||
data_point,
|
||||
field_name,
|
||||
field_value,
|
||||
nodes,
|
||||
edges,
|
||||
added_nodes,
|
||||
added_edges,
|
||||
)
|
||||
|
||||
property_nodes, property_edges = get_graph_from_model(field_value, True, added_nodes, added_edges)
|
||||
|
||||
for node in property_nodes:
|
||||
if str(node.id) not in added_nodes:
|
||||
nodes.append(node)
|
||||
added_nodes[str(node.id)] = True
|
||||
|
||||
for edge in property_edges:
|
||||
edge_key = str(edge[0]) + str(edge[1]) + edge[2]
|
||||
|
||||
if str(edge_key) not in added_edges:
|
||||
edges.append(edge)
|
||||
added_edges[str(edge_key)] = True
|
||||
|
||||
for property_node in get_own_properties(property_nodes, property_edges):
|
||||
edge_key = str(data_point.id) + str(property_node.id) + field_name
|
||||
|
||||
if str(edge_key) not in added_edges:
|
||||
edges.append((data_point.id, property_node.id, field_name, {
|
||||
"source_node_id": data_point.id,
|
||||
"target_node_id": property_node.id,
|
||||
"relationship_name": field_name,
|
||||
"updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
|
||||
}))
|
||||
added_edges[str(edge_key)] = True
|
||||
continue
|
||||
|
||||
if isinstance(field_value, list) and len(field_value) > 0 and isinstance(field_value[0], DataPoint):
|
||||
elif (
|
||||
isinstance(field_value, list)
|
||||
and len(field_value) > 0
|
||||
and isinstance(field_value[0], DataPoint)
|
||||
):
|
||||
excluded_properties.add(field_name)
|
||||
|
||||
for item in field_value:
|
||||
property_nodes, property_edges = get_graph_from_model(item, True, added_nodes, added_edges)
|
||||
|
||||
for node in property_nodes:
|
||||
if str(node.id) not in added_nodes:
|
||||
nodes.append(node)
|
||||
added_nodes[str(node.id)] = True
|
||||
|
||||
for edge in property_edges:
|
||||
edge_key = str(edge[0]) + str(edge[1]) + edge[2]
|
||||
|
||||
if str(edge_key) not in added_edges:
|
||||
edges.append(edge)
|
||||
added_edges[edge_key] = True
|
||||
|
||||
for property_node in get_own_properties(property_nodes, property_edges):
|
||||
edge_key = str(data_point.id) + str(property_node.id) + field_name
|
||||
|
||||
if str(edge_key) not in added_edges:
|
||||
edges.append((data_point.id, property_node.id, field_name, {
|
||||
"source_node_id": data_point.id,
|
||||
"target_node_id": property_node.id,
|
||||
"relationship_name": field_name,
|
||||
"updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
|
||||
"metadata": {
|
||||
"type": "list"
|
||||
},
|
||||
}))
|
||||
added_edges[edge_key] = True
|
||||
continue
|
||||
|
||||
data_point_properties[field_name] = field_value
|
||||
n_edges_before = len(edges)
|
||||
nodes, edges, added_nodes, added_edges = add_nodes_and_edges(
|
||||
data_point, field_name, item, nodes, edges, added_nodes, added_edges
|
||||
)
|
||||
edges = edges[:n_edges_before] + [
|
||||
(*edge[:3], {**edge[3], "metadata": {"type": "list"}})
|
||||
for edge in edges[n_edges_before:]
|
||||
]
|
||||
else:
|
||||
data_point_properties[field_name] = field_value
|
||||
|
||||
SimpleDataPointModel = copy_model(
|
||||
type(data_point),
|
||||
include_fields = {
|
||||
include_fields={
|
||||
"_metadata": (dict, data_point._metadata),
|
||||
},
|
||||
exclude_fields = excluded_properties,
|
||||
exclude_fields=excluded_properties,
|
||||
)
|
||||
|
||||
if include_root:
|
||||
nodes.append(SimpleDataPointModel(**data_point_properties))
|
||||
nodes.append(SimpleDataPointModel(**data_point_properties))
|
||||
|
||||
return nodes, edges
|
||||
|
||||
|
||||
def add_nodes_and_edges(
|
||||
data_point, field_name, field_value, nodes, edges, added_nodes, added_edges
|
||||
):
|
||||
|
||||
property_nodes, property_edges = get_graph_from_model(
|
||||
field_value, dict(added_nodes), dict(added_edges)
|
||||
)
|
||||
|
||||
for node in property_nodes:
|
||||
if str(node.id) not in added_nodes:
|
||||
nodes.append(node)
|
||||
added_nodes[str(node.id)] = True
|
||||
|
||||
for edge in property_edges:
|
||||
edge_key = str(edge[0]) + str(edge[1]) + edge[2]
|
||||
|
||||
if str(edge_key) not in added_edges:
|
||||
edges.append(edge)
|
||||
added_edges[str(edge_key)] = True
|
||||
|
||||
for property_node in get_own_properties(property_nodes, property_edges):
|
||||
edge_key = str(data_point.id) + str(property_node.id) + field_name
|
||||
|
||||
if str(edge_key) not in added_edges:
|
||||
edges.append(
|
||||
(
|
||||
data_point.id,
|
||||
property_node.id,
|
||||
field_name,
|
||||
{
|
||||
"source_node_id": data_point.id,
|
||||
"target_node_id": property_node.id,
|
||||
"relationship_name": field_name,
|
||||
"updated_at": datetime.now(timezone.utc).strftime(
|
||||
"%Y-%m-%d %H:%M:%S"
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
added_edges[str(edge_key)] = True
|
||||
|
||||
return (nodes, edges, added_nodes, added_edges)
|
||||
|
||||
|
||||
def get_own_properties(property_nodes, property_edges):
|
||||
own_properties = []
|
||||
|
||||
|
|
|
|||
|
|
@ -1,29 +1,41 @@
|
|||
from typing import Callable
|
||||
|
||||
from pydantic_core import PydanticUndefined
|
||||
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.modules.storage.utils import copy_model
|
||||
|
||||
|
||||
def get_model_instance_from_graph(nodes: list[DataPoint], edges: list, entity_id: str):
|
||||
node_map = {}
|
||||
def get_model_instance_from_graph(
|
||||
nodes: list[DataPoint],
|
||||
edges: list[tuple[str, str, str, dict[str, str]]],
|
||||
entity_id: str,
|
||||
):
|
||||
node_map = {node.id: node for node in nodes}
|
||||
|
||||
for node in nodes:
|
||||
node_map[node.id] = node
|
||||
|
||||
for edge in edges:
|
||||
source_node = node_map[edge[0]]
|
||||
target_node = node_map[edge[1]]
|
||||
edge_label = edge[2]
|
||||
edge_properties = edge[3] if len(edge) == 4 else {}
|
||||
for source_node_id, target_node_id, edge_label, edge_properties in edges:
|
||||
source_node = node_map[source_node_id]
|
||||
target_node = node_map[target_node_id]
|
||||
edge_metadata = edge_properties.get("metadata", {})
|
||||
edge_type = edge_metadata.get("type")
|
||||
edge_type = edge_metadata.get("type", "default")
|
||||
|
||||
if edge_type == "list":
|
||||
NewModel = copy_model(type(source_node), { edge_label: (list[type(target_node)], PydanticUndefined) })
|
||||
NewModel = copy_model(
|
||||
type(source_node),
|
||||
{edge_label: (list[type(target_node)], PydanticUndefined)},
|
||||
)
|
||||
source_node_dict = source_node.model_dump()
|
||||
source_node_edge_label_values = source_node_dict.get(edge_label, [])
|
||||
source_node_dict[edge_label] = source_node_edge_label_values + [target_node]
|
||||
|
||||
node_map[edge[0]] = NewModel(**source_node.model_dump(), **{ edge_label: [target_node] })
|
||||
node_map[source_node_id] = NewModel(**source_node_dict)
|
||||
else:
|
||||
NewModel = copy_model(type(source_node), { edge_label: (type(target_node), PydanticUndefined) })
|
||||
NewModel = copy_model(
|
||||
type(source_node), {edge_label: (type(target_node), PydanticUndefined)}
|
||||
)
|
||||
|
||||
node_map[edge[0]] = NewModel(**source_node.model_dump(), **{ edge_label: target_node })
|
||||
node_map[target_node_id] = NewModel(
|
||||
**source_node.model_dump(), **{edge_label: target_node}
|
||||
)
|
||||
|
||||
return node_map[entity_id]
|
||||
|
|
|
|||
|
|
@ -29,7 +29,9 @@ def copy_model(model: DataPoint, include_fields: dict = {}, exclude_fields: list
|
|||
**include_fields
|
||||
}
|
||||
|
||||
return create_model(model.__name__, **final_fields)
|
||||
model = create_model(model.__name__, **final_fields)
|
||||
model.model_rebuild()
|
||||
return model
|
||||
|
||||
def get_own_properties(data_point: DataPoint):
|
||||
properties = {}
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ async def main():
|
|||
|
||||
await render_graph(None, include_nodes = True, include_labels = True)
|
||||
|
||||
search_results = await cognee.search(SearchType.CHUNKS, query = "Student")
|
||||
search_results = await cognee.search(SearchType.CHUNKS, query_text = "Student")
|
||||
assert len(search_results) != 0, "The search results list is empty."
|
||||
print("\n\nExtracted chunks are:\n")
|
||||
for result in search_results:
|
||||
|
|
|
|||
|
|
@ -1,14 +1,9 @@
|
|||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.modules.graph.utils import (
|
||||
get_graph_from_model,
|
||||
get_model_instance_from_graph,
|
||||
)
|
||||
|
||||
|
||||
class CarTypeName(Enum):
|
||||
|
|
@ -47,8 +42,8 @@ class Person(DataPoint):
|
|||
_metadata: dict = dict(index_fields=["name"])
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def graph_outputs():
|
||||
@pytest.fixture(scope="function")
|
||||
def boris():
|
||||
boris = Person(
|
||||
id="boris",
|
||||
name="Boris",
|
||||
|
|
@ -70,11 +65,4 @@ def graph_outputs():
|
|||
"expires_on": "2025-11-06",
|
||||
},
|
||||
)
|
||||
nodes, edges = get_graph_from_model(boris)
|
||||
|
||||
car, person = nodes[0], nodes[1]
|
||||
edge = edges[0]
|
||||
|
||||
parsed_person = get_model_instance_from_graph(nodes, edges, "boris")
|
||||
|
||||
return (car, person, edge, parsed_person)
|
||||
return boris
|
||||
|
|
|
|||
|
|
@ -0,0 +1,37 @@
|
|||
import warnings
|
||||
|
||||
import pytest
|
||||
|
||||
from cognee.modules.graph.utils import get_graph_from_model
|
||||
from cognee.tests.unit.interfaces.graph.util import (
|
||||
PERSON_NAMES,
|
||||
count_society,
|
||||
create_organization_recursive,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("recursive_depth", [1, 2, 3])
|
||||
def test_society_nodes_and_edges(recursive_depth):
|
||||
import sys
|
||||
|
||||
if sys.version_info[0] == 3 and sys.version_info[1] >= 11:
|
||||
society = create_organization_recursive(
|
||||
"society", "Society", PERSON_NAMES, recursive_depth
|
||||
)
|
||||
|
||||
n_organizations, n_persons = count_society(society)
|
||||
society_counts_total = n_organizations + n_persons
|
||||
|
||||
nodes, edges = get_graph_from_model(society)
|
||||
|
||||
assert (
|
||||
len(nodes) == society_counts_total
|
||||
), f"{society_counts_total = } != {len(nodes) = }, not all DataPoint instances were found"
|
||||
|
||||
assert len(edges) == (
|
||||
len(nodes) - 1
|
||||
), f"{(len(nodes) - 1) = } != {len(edges) = }, there have to be n_nodes - 1 edges, as each node has exactly one parent node, except for the root node"
|
||||
else:
|
||||
warnings.warn(
|
||||
"The recursive pydantic data structure cannot be reconstructed from the graph because the 'inner' pydantic class is not defined. Hence this test is skipped. This problem is solved in Python 3.11"
|
||||
)
|
||||
|
|
@ -1,6 +1,19 @@
|
|||
from cognee.modules.graph.utils import get_graph_from_model
|
||||
from cognee.tests.unit.interfaces.graph.util import run_test_against_ground_truth
|
||||
|
||||
EDGE_GROUND_TRUTH = (
|
||||
CAR_SEDAN_EDGE = (
|
||||
"car1",
|
||||
"sedan",
|
||||
"is_type",
|
||||
{
|
||||
"source_node_id": "car1",
|
||||
"target_node_id": "sedan",
|
||||
"relationship_name": "is_type",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
BORIS_CAR_EDGE_GROUND_TRUTH = (
|
||||
"boris",
|
||||
"car1",
|
||||
"owns_car",
|
||||
|
|
@ -12,6 +25,8 @@ EDGE_GROUND_TRUTH = (
|
|||
},
|
||||
)
|
||||
|
||||
CAR_TYPE_GROUND_TRUTH = {"id": "sedan"}
|
||||
|
||||
CAR_GROUND_TRUTH = {
|
||||
"id": "car1",
|
||||
"brand": "Toyota",
|
||||
|
|
@ -33,22 +48,42 @@ PERSON_GROUND_TRUTH = {
|
|||
}
|
||||
|
||||
|
||||
def test_extracted_person(graph_outputs):
|
||||
(_, person, _, _) = graph_outputs
|
||||
|
||||
run_test_against_ground_truth("person", person, PERSON_GROUND_TRUTH)
|
||||
def test_extracted_car_type(boris):
|
||||
nodes, _ = get_graph_from_model(boris)
|
||||
assert len(nodes) == 3
|
||||
car_type = nodes[0]
|
||||
run_test_against_ground_truth("car_type", car_type, CAR_TYPE_GROUND_TRUTH)
|
||||
|
||||
|
||||
def test_extracted_car(graph_outputs):
|
||||
(car, _, _, _) = graph_outputs
|
||||
def test_extracted_car(boris):
|
||||
nodes, _ = get_graph_from_model(boris)
|
||||
assert len(nodes) == 3
|
||||
car = nodes[1]
|
||||
run_test_against_ground_truth("car", car, CAR_GROUND_TRUTH)
|
||||
|
||||
|
||||
def test_extracted_edge(graph_outputs):
|
||||
(_, _, edge, _) = graph_outputs
|
||||
def test_extracted_person(boris):
|
||||
nodes, _ = get_graph_from_model(boris)
|
||||
assert len(nodes) == 3
|
||||
person = nodes[2]
|
||||
run_test_against_ground_truth("person", person, PERSON_GROUND_TRUTH)
|
||||
|
||||
|
||||
def test_extracted_car_sedan_edge(boris):
|
||||
_, edges = get_graph_from_model(boris)
|
||||
edge = edges[0]
|
||||
|
||||
assert CAR_SEDAN_EDGE[:3] == edge[:3], f"{CAR_SEDAN_EDGE[:3] = } != {edge[:3] = }"
|
||||
for key, ground_truth in CAR_SEDAN_EDGE[3].items():
|
||||
assert ground_truth == edge[3][key], f"{ground_truth = } != {edge[3][key] = }"
|
||||
|
||||
|
||||
def test_extracted_boris_car_edge(boris):
|
||||
_, edges = get_graph_from_model(boris)
|
||||
edge = edges[1]
|
||||
|
||||
assert (
|
||||
EDGE_GROUND_TRUTH[:3] == edge[:3]
|
||||
), f"{EDGE_GROUND_TRUTH[:3] = } != {edge[:3] = }"
|
||||
for key, ground_truth in EDGE_GROUND_TRUTH[3].items():
|
||||
BORIS_CAR_EDGE_GROUND_TRUTH[:3] == edge[:3]
|
||||
), f"{BORIS_CAR_EDGE_GROUND_TRUTH[:3] = } != {edge[:3] = }"
|
||||
for key, ground_truth in BORIS_CAR_EDGE_GROUND_TRUTH[3].items():
|
||||
assert ground_truth == edge[3][key], f"{ground_truth = } != {edge[3][key] = }"
|
||||
|
|
|
|||
|
|
@ -0,0 +1,33 @@
|
|||
import warnings
|
||||
|
||||
import pytest
|
||||
|
||||
from cognee.modules.graph.utils import (
|
||||
get_graph_from_model,
|
||||
get_model_instance_from_graph,
|
||||
)
|
||||
from cognee.tests.unit.interfaces.graph.util import (
|
||||
PERSON_NAMES,
|
||||
create_organization_recursive,
|
||||
show_first_difference,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("recursive_depth", [1, 2, 3])
|
||||
def test_society_nodes_and_edges(recursive_depth):
|
||||
import sys
|
||||
|
||||
if sys.version_info[0] == 3 and sys.version_info[1] >= 11:
|
||||
society = create_organization_recursive(
|
||||
"society", "Society", PERSON_NAMES, recursive_depth
|
||||
)
|
||||
nodes, edges = get_graph_from_model(society)
|
||||
parsed_society = get_model_instance_from_graph(nodes, edges, "society")
|
||||
|
||||
assert str(society) == (str(parsed_society)), show_first_difference(
|
||||
str(society), str(parsed_society), "society", "parsed_society"
|
||||
)
|
||||
else:
|
||||
warnings.warn(
|
||||
"The recursive pydantic data structure cannot be reconstructed from the graph because the 'inner' pydantic class is not defined. Hence this test is skipped. This problem is solved in Python 3.11"
|
||||
)
|
||||
|
|
@ -1,3 +1,7 @@
|
|||
from cognee.modules.graph.utils import (
|
||||
get_graph_from_model,
|
||||
get_model_instance_from_graph,
|
||||
)
|
||||
from cognee.tests.unit.interfaces.graph.util import run_test_against_ground_truth
|
||||
|
||||
PARSED_PERSON_GROUND_TRUTH = {
|
||||
|
|
@ -21,8 +25,10 @@ CAR_GROUND_TRUTH = {
|
|||
}
|
||||
|
||||
|
||||
def test_parsed_person(graph_outputs):
|
||||
(_, _, _, parsed_person) = graph_outputs
|
||||
def test_parsed_person(boris):
|
||||
nodes, edges = get_graph_from_model(boris)
|
||||
parsed_person = get_model_instance_from_graph(nodes, edges, "boris")
|
||||
|
||||
run_test_against_ground_truth(
|
||||
"parsed_person", parsed_person, PARSED_PERSON_GROUND_TRUTH
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,9 @@
|
|||
import random
|
||||
import string
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
|
||||
|
||||
def run_test_against_ground_truth(
|
||||
|
|
@ -21,6 +25,8 @@ def run_test_against_ground_truth(
|
|||
assert (
|
||||
ground_truth2 == getattr(test_target_item, key)[key2]
|
||||
), f"{test_target_item_name}/{key = }/{key2 = }: {ground_truth2 = } != {getattr(test_target_item, key)[key2] = }"
|
||||
elif isinstance(ground_truth, list):
|
||||
raise NotImplementedError("Currently not implemented for 'list'")
|
||||
else:
|
||||
assert ground_truth == getattr(
|
||||
test_target_item, key
|
||||
|
|
@ -28,3 +34,117 @@ def run_test_against_ground_truth(
|
|||
time_delta = datetime.now(timezone.utc) - getattr(test_target_item, "updated_at")
|
||||
|
||||
assert time_delta.total_seconds() < 60, f"{ time_delta.total_seconds() = }"
|
||||
|
||||
|
||||
class Organization(DataPoint):
|
||||
id: str
|
||||
name: str
|
||||
members: Optional[list["SocietyPerson"]]
|
||||
|
||||
|
||||
class SocietyPerson(DataPoint):
|
||||
id: str
|
||||
name: str
|
||||
memberships: Optional[list[Organization]]
|
||||
|
||||
|
||||
SocietyPerson.model_rebuild()
|
||||
Organization.model_rebuild()
|
||||
|
||||
|
||||
ORGANIZATION_NAMES = [
|
||||
"ChessClub",
|
||||
"RowingClub",
|
||||
"TheatreTroupe",
|
||||
"PoliticalParty",
|
||||
"Charity",
|
||||
"FanClub",
|
||||
"FilmClub",
|
||||
"NeighborhoodGroup",
|
||||
"LocalCouncil",
|
||||
"Band",
|
||||
]
|
||||
PERSON_NAMES = ["Sarah", "Anna", "John", "Sam"]
|
||||
|
||||
|
||||
def create_society_person_recursive(id, name, organization_names, max_depth, depth=0):
|
||||
id_suffix = "".join(random.choice(string.ascii_lowercase) for _ in range(10))
|
||||
|
||||
if depth < max_depth:
|
||||
memberships = [
|
||||
create_organization_recursive(
|
||||
f"{org_name}-{depth}-{id_suffix}",
|
||||
org_name.lower(),
|
||||
PERSON_NAMES,
|
||||
max_depth,
|
||||
depth + 1,
|
||||
)
|
||||
for org_name in organization_names
|
||||
]
|
||||
else:
|
||||
memberships = None
|
||||
|
||||
return SocietyPerson(id=id, name=f"{name}{depth}", memberships=memberships)
|
||||
|
||||
|
||||
def create_organization_recursive(id, name, member_names, max_depth, depth=0):
|
||||
id_suffix = "".join(random.choice(string.ascii_lowercase) for _ in range(10))
|
||||
|
||||
if depth < max_depth:
|
||||
members = [
|
||||
create_society_person_recursive(
|
||||
f"{member_name}-{depth}-{id_suffix}",
|
||||
member_name.lower(),
|
||||
ORGANIZATION_NAMES,
|
||||
max_depth,
|
||||
depth + 1,
|
||||
)
|
||||
for member_name in member_names
|
||||
]
|
||||
else:
|
||||
members = None
|
||||
|
||||
return Organization(id=id, name=f"{name}{depth}", members=members)
|
||||
|
||||
|
||||
def count_society(obj):
|
||||
if isinstance(obj, SocietyPerson):
|
||||
if obj.memberships is not None:
|
||||
organization_counts, society_person_counts = zip(
|
||||
*[count_society(organization) for organization in obj.memberships]
|
||||
)
|
||||
organization_count = sum(organization_counts)
|
||||
society_person_count = sum(society_person_counts) + 1
|
||||
return (organization_count, society_person_count)
|
||||
else:
|
||||
return (0, 1)
|
||||
if isinstance(obj, Organization):
|
||||
if obj.members is not None:
|
||||
organization_counts, society_person_counts = zip(
|
||||
*[count_society(organization) for organization in obj.members]
|
||||
)
|
||||
organization_count = sum(organization_counts) + 1
|
||||
society_person_count = sum(society_person_counts)
|
||||
return (organization_count, society_person_count)
|
||||
else:
|
||||
return (1, 0)
|
||||
else:
|
||||
raise Exception("Not allowed")
|
||||
|
||||
|
||||
def show_first_difference(str1, str2, str1_name, str2_name, context=30):
|
||||
for i, (c1, c2) in enumerate(zip(str1, str2)):
|
||||
if c1 != c2:
|
||||
start = max(0, i - context)
|
||||
end1 = min(len(str1), i + context + 1)
|
||||
end2 = min(len(str2), i + context + 1)
|
||||
if i > 0:
|
||||
return f"identical: '{str1[start:i-1]}' | {str1_name}: '{str1[i-1:end1]}'... != {str2_name}: '{str2[i-1:end2]}'..."
|
||||
else:
|
||||
return f"{str1_name} and {str2_name} have no overlap in characters"
|
||||
if len(str1) > len(str2):
|
||||
return f"{str2_name} is identical up to the {i}th character, missing afterwards '{str1[i:i+context]}'..."
|
||||
if len(str2) > len(str1):
|
||||
return f"{str1_name} is identical up to the {i}th character, missing afterwards '{str2[i:i+context]}'..."
|
||||
else:
|
||||
return f"{str1_name} and {str2_name} are identical."
|
||||
|
|
|
|||
|
|
@ -59,7 +59,7 @@ await cognee.add(text) # Add a new piece of information
|
|||
|
||||
await cognee.cognify() # Use LLMs and cognee to create knowledge
|
||||
|
||||
search_results = await cognee.search("INSIGHTS", {'query': 'Tell me about NLP'}) # Query cognee for the knowledge
|
||||
search_results = await cognee.search(SearchType.INSIGHTS, query_text='Tell me about NLP') # Query cognee for the knowledge
|
||||
|
||||
for result_text in search_results:
|
||||
print(result_text)
|
||||
|
|
|
|||
|
|
@ -209,7 +209,7 @@ async def main(enable_steps):
|
|||
if enable_steps.get("search_insights"):
|
||||
search_results = await cognee.search(
|
||||
SearchType.INSIGHTS,
|
||||
{'query': 'Which applicant has the most relevant experience in data science?'}
|
||||
query_text='Which applicant has the most relevant experience in data science?'
|
||||
)
|
||||
print("Search results:")
|
||||
for result_text in search_results:
|
||||
|
|
|
|||
67
profiling/graph_pydantic_conversion/benchmark_function.py
Normal file
67
profiling/graph_pydantic_conversion/benchmark_function.py
Normal file
|
|
@ -0,0 +1,67 @@
|
|||
import statistics
|
||||
import time
|
||||
import tracemalloc
|
||||
from typing import Any, Callable, Dict
|
||||
|
||||
import psutil
|
||||
|
||||
|
||||
def benchmark_function(func: Callable, *args, num_runs: int = 5) -> Dict[str, Any]:
|
||||
"""
|
||||
Benchmark a function for memory usage and computational performance.
|
||||
|
||||
Args:
|
||||
func: Function to benchmark
|
||||
*args: Arguments to pass to the function
|
||||
num_runs: Number of times to run the benchmark
|
||||
|
||||
Returns:
|
||||
Dictionary containing benchmark metrics
|
||||
"""
|
||||
execution_times = []
|
||||
peak_memory_usages = []
|
||||
cpu_percentages = []
|
||||
|
||||
process = psutil.Process()
|
||||
|
||||
for _ in range(num_runs):
|
||||
# Start memory tracking
|
||||
tracemalloc.start()
|
||||
initial_memory = process.memory_info().rss
|
||||
|
||||
# Measure execution time and CPU usage
|
||||
start_time = time.perf_counter()
|
||||
start_cpu_time = process.cpu_times()
|
||||
|
||||
result = func(*args)
|
||||
|
||||
end_cpu_time = process.cpu_times()
|
||||
end_time = time.perf_counter()
|
||||
|
||||
# Calculate metrics
|
||||
execution_time = end_time - start_time
|
||||
cpu_time = (end_cpu_time.user + end_cpu_time.system) - (
|
||||
start_cpu_time.user + start_cpu_time.system
|
||||
)
|
||||
current, peak = tracemalloc.get_traced_memory()
|
||||
final_memory = process.memory_info().rss
|
||||
memory_used = final_memory - initial_memory
|
||||
|
||||
# Store results
|
||||
execution_times.append(execution_time)
|
||||
peak_memory_usages.append(peak / 1024 / 1024) # Convert to MB
|
||||
cpu_percentages.append((cpu_time / execution_time) * 100)
|
||||
|
||||
tracemalloc.stop()
|
||||
|
||||
analysis = {
|
||||
"mean_execution_time": statistics.mean(execution_times),
|
||||
"mean_peak_memory_mb": statistics.mean(peak_memory_usages),
|
||||
"mean_cpu_percent": statistics.mean(cpu_percentages),
|
||||
"num_runs": num_runs,
|
||||
}
|
||||
|
||||
if num_runs > 1:
|
||||
analysis["std_execution_time"] = statistics.stdev(execution_times)
|
||||
|
||||
return analysis
|
||||
|
|
@ -0,0 +1,43 @@
|
|||
import argparse
|
||||
import time
|
||||
|
||||
from benchmark_function import benchmark_function
|
||||
|
||||
from cognee.modules.graph.utils import get_graph_from_model
|
||||
from cognee.tests.unit.interfaces.graph.util import (
|
||||
PERSON_NAMES,
|
||||
create_organization_recursive,
|
||||
)
|
||||
|
||||
# Example usage:
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Benchmark graph model with configurable recursive depth"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--recursive-depth",
|
||||
type=int,
|
||||
default=3,
|
||||
help="Recursive depth for graph generation (default: 3)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--runs", type=int, default=5, help="Number of benchmark runs (default: 5)"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
society = create_organization_recursive(
|
||||
"society", "Society", PERSON_NAMES, args.recursive_depth
|
||||
)
|
||||
nodes, edges = get_graph_from_model(society)
|
||||
|
||||
results = benchmark_function(get_graph_from_model, society, num_runs=args.runs)
|
||||
print("\nBenchmark Results:")
|
||||
print(
|
||||
f"N nodes: {len(nodes)}, N edges: {len(edges)}, Recursion depth: {args.recursive_depth}"
|
||||
)
|
||||
print(f"Mean Peak Memory: {results['mean_peak_memory_mb']:.2f} MB")
|
||||
print(f"Mean CPU Usage: {results['mean_cpu_percent']:.2f}%")
|
||||
print(f"Mean Execution Time: {results['mean_execution_time']:.4f} seconds")
|
||||
|
||||
if "std_execution_time" in results:
|
||||
print(f"Execution Time Std: {results['std_execution_time']:.4f} seconds")
|
||||
Loading…
Add table
Reference in a new issue