From 628f192b8d05d81ab604dec25a3d2a4cbdbf7cf2 Mon Sep 17 00:00:00 2001 From: Leon Luithlen Date: Fri, 15 Nov 2024 10:27:06 +0100 Subject: [PATCH 01/21] Remove added_nodes and added_edges default dicts --- cognee/modules/graph/utils/get_graph_from_model.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/cognee/modules/graph/utils/get_graph_from_model.py b/cognee/modules/graph/utils/get_graph_from_model.py index 29137ddc7..810be7ce8 100644 --- a/cognee/modules/graph/utils/get_graph_from_model.py +++ b/cognee/modules/graph/utils/get_graph_from_model.py @@ -2,7 +2,13 @@ from datetime import datetime, timezone from cognee.infrastructure.engine import DataPoint from cognee.modules.storage.utils import copy_model -def get_graph_from_model(data_point: DataPoint, include_root = True, added_nodes = {}, added_edges = {}): +def get_graph_from_model(data_point: DataPoint, include_root = True, added_nodes = None, added_edges = None): + + if not added_nodes: + added_nodes = {} + if not added_edges: + added_edges = {} + nodes = [] edges = [] From 0ea011ccd7bec863822d27bad4d1f6df8ab5d1d7 Mon Sep 17 00:00:00 2001 From: Leon Luithlen Date: Fri, 15 Nov 2024 10:27:27 +0100 Subject: [PATCH 02/21] Adapt graph interfaces tests to debugged get_graph_from_model --- .../tests/unit/interfaces/graph/conftest.py | 18 +----- .../graph/get_graph_from_model_test.py | 59 +++++++++++++++---- .../get_model_instance_from_graph_test.py | 10 +++- cognee/tests/unit/interfaces/graph/util.py | 2 + 4 files changed, 60 insertions(+), 29 deletions(-) diff --git a/cognee/tests/unit/interfaces/graph/conftest.py b/cognee/tests/unit/interfaces/graph/conftest.py index 9a784bb53..45f977bd6 100644 --- a/cognee/tests/unit/interfaces/graph/conftest.py +++ b/cognee/tests/unit/interfaces/graph/conftest.py @@ -1,14 +1,9 @@ -from datetime import datetime, timezone from enum import Enum from typing import Optional import pytest from cognee.infrastructure.engine import DataPoint -from cognee.modules.graph.utils import ( - get_graph_from_model, - get_model_instance_from_graph, -) class CarTypeName(Enum): @@ -47,8 +42,8 @@ class Person(DataPoint): _metadata: dict = dict(index_fields=["name"]) -@pytest.fixture(scope="session") -def graph_outputs(): +@pytest.fixture(scope="function") +def boris(): boris = Person( id="boris", name="Boris", @@ -70,11 +65,4 @@ def graph_outputs(): "expires_on": "2025-11-06", }, ) - nodes, edges = get_graph_from_model(boris) - - car, person = nodes[0], nodes[1] - edge = edges[0] - - parsed_person = get_model_instance_from_graph(nodes, edges, "boris") - - return (car, person, edge, parsed_person) + return boris diff --git a/cognee/tests/unit/interfaces/graph/get_graph_from_model_test.py b/cognee/tests/unit/interfaces/graph/get_graph_from_model_test.py index 17dd69a0e..e56a2dff2 100644 --- a/cognee/tests/unit/interfaces/graph/get_graph_from_model_test.py +++ b/cognee/tests/unit/interfaces/graph/get_graph_from_model_test.py @@ -1,6 +1,19 @@ +from cognee.modules.graph.utils import get_graph_from_model from cognee.tests.unit.interfaces.graph.util import run_test_against_ground_truth -EDGE_GROUND_TRUTH = ( +CAR_SEDAN_EDGE = ( + "car1", + "sedan", + "is_type", + { + "source_node_id": "car1", + "target_node_id": "sedan", + "relationship_name": "is_type", + }, +) + + +BORIS_CAR_EDGE_GROUND_TRUTH = ( "boris", "car1", "owns_car", @@ -12,6 +25,8 @@ EDGE_GROUND_TRUTH = ( }, ) +CAR_TYPE_GROUND_TRUTH = {"id": "sedan"} + CAR_GROUND_TRUTH = { "id": "car1", "brand": "Toyota", @@ -33,22 +48,42 @@ PERSON_GROUND_TRUTH = { } -def test_extracted_person(graph_outputs): - (_, person, _, _) = graph_outputs - - run_test_against_ground_truth("person", person, PERSON_GROUND_TRUTH) +def test_extracted_car_type(boris): + nodes, _ = get_graph_from_model(boris) + assert len(nodes) == 3 + car_type = nodes[0] + run_test_against_ground_truth("car_type", car_type, CAR_TYPE_GROUND_TRUTH) -def test_extracted_car(graph_outputs): - (car, _, _, _) = graph_outputs +def test_extracted_car(boris): + nodes, _ = get_graph_from_model(boris) + assert len(nodes) == 3 + car = nodes[1] run_test_against_ground_truth("car", car, CAR_GROUND_TRUTH) -def test_extracted_edge(graph_outputs): - (_, _, edge, _) = graph_outputs +def test_extracted_person(boris): + nodes, _ = get_graph_from_model(boris) + assert len(nodes) == 3 + person = nodes[2] + run_test_against_ground_truth("person", person, PERSON_GROUND_TRUTH) + + +def test_extracted_car_sedan_edge(boris): + _, edges = get_graph_from_model(boris) + edge = edges[0] + + assert CAR_SEDAN_EDGE[:3] == edge[:3], f"{CAR_SEDAN_EDGE[:3] = } != {edge[:3] = }" + for key, ground_truth in CAR_SEDAN_EDGE[3].items(): + assert ground_truth == edge[3][key], f"{ground_truth = } != {edge[3][key] = }" + + +def test_extracted_boris_car_edge(boris): + _, edges = get_graph_from_model(boris) + edge = edges[1] assert ( - EDGE_GROUND_TRUTH[:3] == edge[:3] - ), f"{EDGE_GROUND_TRUTH[:3] = } != {edge[:3] = }" - for key, ground_truth in EDGE_GROUND_TRUTH[3].items(): + BORIS_CAR_EDGE_GROUND_TRUTH[:3] == edge[:3] + ), f"{BORIS_CAR_EDGE_GROUND_TRUTH[:3] = } != {edge[:3] = }" + for key, ground_truth in BORIS_CAR_EDGE_GROUND_TRUTH[3].items(): assert ground_truth == edge[3][key], f"{ground_truth = } != {edge[3][key] = }" diff --git a/cognee/tests/unit/interfaces/graph/get_model_instance_from_graph_test.py b/cognee/tests/unit/interfaces/graph/get_model_instance_from_graph_test.py index 98ba501bd..f1aa7736d 100644 --- a/cognee/tests/unit/interfaces/graph/get_model_instance_from_graph_test.py +++ b/cognee/tests/unit/interfaces/graph/get_model_instance_from_graph_test.py @@ -1,3 +1,7 @@ +from cognee.modules.graph.utils import ( + get_graph_from_model, + get_model_instance_from_graph, +) from cognee.tests.unit.interfaces.graph.util import run_test_against_ground_truth PARSED_PERSON_GROUND_TRUTH = { @@ -21,8 +25,10 @@ CAR_GROUND_TRUTH = { } -def test_parsed_person(graph_outputs): - (_, _, _, parsed_person) = graph_outputs +def test_parsed_person(boris): + nodes, edges = get_graph_from_model(boris) + parsed_person = get_model_instance_from_graph(nodes, edges, "boris") + run_test_against_ground_truth( "parsed_person", parsed_person, PARSED_PERSON_GROUND_TRUTH ) diff --git a/cognee/tests/unit/interfaces/graph/util.py b/cognee/tests/unit/interfaces/graph/util.py index 764eafa11..e5da0201c 100644 --- a/cognee/tests/unit/interfaces/graph/util.py +++ b/cognee/tests/unit/interfaces/graph/util.py @@ -21,6 +21,8 @@ def run_test_against_ground_truth( assert ( ground_truth2 == getattr(test_target_item, key)[key2] ), f"{test_target_item_name}/{key = }/{key2 = }: {ground_truth2 = } != {getattr(test_target_item, key)[key2] = }" + elif isinstance(ground_truth, list): + raise NotImplementedError("Currently not implemented for 'list'") else: assert ground_truth == getattr( test_target_item, key From 7be613e2fc01b0f77f227763280a954a67608f9a Mon Sep 17 00:00:00 2001 From: Leon Luithlen Date: Fri, 15 Nov 2024 11:57:26 +0100 Subject: [PATCH 03/21] WIP nested pydantic structures --- .../tests/unit/interfaces/graph/conftest.py | 105 +++++++++++++++++- 1 file changed, 103 insertions(+), 2 deletions(-) diff --git a/cognee/tests/unit/interfaces/graph/conftest.py b/cognee/tests/unit/interfaces/graph/conftest.py index 45f977bd6..83d64523d 100644 --- a/cognee/tests/unit/interfaces/graph/conftest.py +++ b/cognee/tests/unit/interfaces/graph/conftest.py @@ -1,3 +1,5 @@ +import random +import string from enum import Enum from typing import Optional @@ -36,8 +38,8 @@ class Car(DataPoint): class Person(DataPoint): id: str name: str - age: int - owns_car: list[Car] + age: Optional[int] + owns_car: Optional[list[Car]] driving_license: Optional[dict] _metadata: dict = dict(index_fields=["name"]) @@ -66,3 +68,102 @@ def boris(): }, ) return boris + + +class Organization(DataPoint): + id: str + name: str + members: Optional[list["SocietyPerson"]] + + +class SocietyPerson(DataPoint): + id: str + name: str + memberships: Optional[list[Organization]] + + +Organization.model_rebuild() +SocietyPerson.model_rebuild() + + +ORGANIZATION_NAMES = [ + "ChessClub", + "RowingClub", + "TheatreTroupe", + "PoliticalParty", + "Charity", + "FanClub", + "FilmClub", + "NeighborhoodGroup", + "LocalCouncil", + "Band", +] +PERSON_NAMES = ["Sarah", "Anna", "John", "Sam"] + + +def create_society_person_recursive(id, name, organization_names, max_depth, depth=0): + if depth < max_depth: + memberships = [ + create_organization_recursive( + org_name, org_name.lower(), PERSON_NAMES, max_depth, depth + 1 + ) + for org_name in organization_names + ] + else: + memberships = None + + id_suffix = "".join(random.choice(string.ascii_lowercase) for _ in range(10)) + return SocietyPerson( + id=f"{id}{depth}-{id_suffix}", name=f"{name}{depth}", memberships=memberships + ) + + +def create_organization_recursive(id, name, member_names, max_depth, depth=0): + if depth < max_depth: + members = [ + create_society_person_recursive( + member_name, + member_name.lower(), + ORGANIZATION_NAMES, + max_depth, + depth + 1, + ) + for member_name in member_names + ] + else: + members = None + + id_suffix = "".join(random.choice(string.ascii_lowercase) for _ in range(10)) + return Organization( + id=f"{id}{depth}-{id_suffix}", name=f"{id}{name}", members=members + ) + + +def count_society(obj): + if isinstance(obj, SocietyPerson): + if obj.memberships is not None: + organization_counts, society_person_counts = zip( + *[count_society(organization) for organization in obj.memberships] + ) + organization_count = sum(organization_counts) + society_person_count = sum(society_person_counts) + 1 + return (organization_count, society_person_count) + else: + return (0, 1) + if isinstance(obj, Organization): + if obj.members is not None: + organization_counts, society_person_counts = zip( + *[count_society(organization) for organization in obj.members] + ) + organization_count = sum(organization_counts) + 1 + society_person_count = sum(society_person_counts) + return (organization_count, society_person_count) + else: + return (1, 0) + else: + return (0, 0) + + +@pytest.fixture(scope="function") +def society(): + society = create_organization_recursive("society", "Society", PERSON_NAMES, 4) From 2c0fce32d33a0684545325ad7c56791d62e0861b Mon Sep 17 00:00:00 2001 From: Leon Luithlen Date: Fri, 15 Nov 2024 13:38:33 +0100 Subject: [PATCH 04/21] WIP get_graph_from_model --- .../graph/utils/get_graph_from_model.py | 86 ++++++++----------- 1 file changed, 34 insertions(+), 52 deletions(-) diff --git a/cognee/modules/graph/utils/get_graph_from_model.py b/cognee/modules/graph/utils/get_graph_from_model.py index 810be7ce8..7b05d2046 100644 --- a/cognee/modules/graph/utils/get_graph_from_model.py +++ b/cognee/modules/graph/utils/get_graph_from_model.py @@ -2,6 +2,37 @@ from datetime import datetime, timezone from cognee.infrastructure.engine import DataPoint from cognee.modules.storage.utils import copy_model +def add_nodes_and_edges(data_point, field_name, field_value, nodes, edges, added_nodes, added_edges): + + property_nodes, property_edges = get_graph_from_model(field_value, True, added_nodes, added_edges) + + for node in property_nodes: + if str(node.id) not in added_nodes: + nodes.append(node) + added_nodes[str(node.id)] = True + + for edge in property_edges: + edge_key = str(edge[0]) + str(edge[1]) + edge[2] + + if str(edge_key) not in added_edges: + edges.append(edge) + added_edges[str(edge_key)] = True + + for property_node in get_own_properties(property_nodes, property_edges): + edge_key = str(data_point.id) + str(property_node.id) + field_name + + if str(edge_key) not in added_edges: + edges.append((data_point.id, property_node.id, field_name, { + "source_node_id": data_point.id, + "target_node_id": property_node.id, + "relationship_name": field_name, + "updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"), + })) + added_edges[str(edge_key)] = True + + return(nodes, edges, added_nodes, added_edges) + + def get_graph_from_model(data_point: DataPoint, include_root = True, added_nodes = None, added_edges = None): if not added_nodes: @@ -22,65 +53,16 @@ def get_graph_from_model(data_point: DataPoint, include_root = True, added_nodes if isinstance(field_value, DataPoint): excluded_properties.add(field_name) - property_nodes, property_edges = get_graph_from_model(field_value, True, added_nodes, added_edges) + nodes, edges, added_nodes, added_edges = add_nodes_and_edges(data_point, field_name, field_value, nodes, edges, added_nodes, added_edges) - for node in property_nodes: - if str(node.id) not in added_nodes: - nodes.append(node) - added_nodes[str(node.id)] = True - - for edge in property_edges: - edge_key = str(edge[0]) + str(edge[1]) + edge[2] - - if str(edge_key) not in added_edges: - edges.append(edge) - added_edges[str(edge_key)] = True - - for property_node in get_own_properties(property_nodes, property_edges): - edge_key = str(data_point.id) + str(property_node.id) + field_name - - if str(edge_key) not in added_edges: - edges.append((data_point.id, property_node.id, field_name, { - "source_node_id": data_point.id, - "target_node_id": property_node.id, - "relationship_name": field_name, - "updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"), - })) - added_edges[str(edge_key)] = True - continue if isinstance(field_value, list) and len(field_value) > 0 and isinstance(field_value[0], DataPoint): excluded_properties.add(field_name) for item in field_value: - property_nodes, property_edges = get_graph_from_model(item, True, added_nodes, added_edges) + nodes, edges, added_nodes, added_edges = add_nodes_and_edges(data_point, field_name, item, nodes, edges, added_nodes, added_edges) + edges = [(*edge[:3],{**edge[3], "metadata": {"type": "list"}}) for edge in edges] - for node in property_nodes: - if str(node.id) not in added_nodes: - nodes.append(node) - added_nodes[str(node.id)] = True - - for edge in property_edges: - edge_key = str(edge[0]) + str(edge[1]) + edge[2] - - if str(edge_key) not in added_edges: - edges.append(edge) - added_edges[edge_key] = True - - for property_node in get_own_properties(property_nodes, property_edges): - edge_key = str(data_point.id) + str(property_node.id) + field_name - - if str(edge_key) not in added_edges: - edges.append((data_point.id, property_node.id, field_name, { - "source_node_id": data_point.id, - "target_node_id": property_node.id, - "relationship_name": field_name, - "updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"), - "metadata": { - "type": "list" - }, - })) - added_edges[edge_key] = True continue data_point_properties[field_name] = field_value From 05ea3575208c5df5a4d2f41b20d4ed7e9fcda578 Mon Sep 17 00:00:00 2001 From: Leon Luithlen Date: Fri, 15 Nov 2024 13:43:13 +0100 Subject: [PATCH 05/21] Refactor get_graph_from_model --- .../graph/utils/get_graph_from_model.py | 136 ++++++++++-------- 1 file changed, 79 insertions(+), 57 deletions(-) diff --git a/cognee/modules/graph/utils/get_graph_from_model.py b/cognee/modules/graph/utils/get_graph_from_model.py index 7b05d2046..d1e14c878 100644 --- a/cognee/modules/graph/utils/get_graph_from_model.py +++ b/cognee/modules/graph/utils/get_graph_from_model.py @@ -2,9 +2,70 @@ from datetime import datetime, timezone from cognee.infrastructure.engine import DataPoint from cognee.modules.storage.utils import copy_model -def add_nodes_and_edges(data_point, field_name, field_value, nodes, edges, added_nodes, added_edges): - property_nodes, property_edges = get_graph_from_model(field_value, True, added_nodes, added_edges) +def get_graph_from_model( + data_point: DataPoint, include_root=True, added_nodes=None, added_edges=None +): + + if not added_nodes: + added_nodes = {} + if not added_edges: + added_edges = {} + + nodes = [] + edges = [] + + data_point_properties = {} + excluded_properties = set() + + for field_name, field_value in data_point: + if field_name == "_metadata": + continue + elif isinstance(field_value, DataPoint): + excluded_properties.add(field_name) + nodes, edges, added_nodes, added_edges = add_nodes_and_edges( + data_point, field_name, field_value, nodes, edges, added_nodes, added_edges + ) + + elif ( + isinstance(field_value, list) + and len(field_value) > 0 + and isinstance(field_value[0], DataPoint) + ): + excluded_properties.add(field_name) + + for item in field_value: + nodes, edges, added_nodes, added_edges = add_nodes_and_edges( + data_point, field_name, item, nodes, edges, added_nodes, added_edges + ) + edges = [ + (*edge[:3], {**edge[3], "metadata": {"type": "list"}}) + for edge in edges + ] + else: + data_point_properties[field_name] = field_value + + SimpleDataPointModel = copy_model( + type(data_point), + include_fields={ + "_metadata": (dict, data_point._metadata), + }, + exclude_fields=excluded_properties, + ) + + if include_root: + nodes.append(SimpleDataPointModel(**data_point_properties)) + + return nodes, edges + + +def add_nodes_and_edges( + data_point, field_name, field_value, nodes, edges, added_nodes, added_edges +): + + property_nodes, property_edges = get_graph_from_model( + field_value, True, added_nodes, added_edges + ) for node in property_nodes: if str(node.id) not in added_nodes: @@ -22,63 +83,24 @@ def add_nodes_and_edges(data_point, field_name, field_value, nodes, edges, added edge_key = str(data_point.id) + str(property_node.id) + field_name if str(edge_key) not in added_edges: - edges.append((data_point.id, property_node.id, field_name, { - "source_node_id": data_point.id, - "target_node_id": property_node.id, - "relationship_name": field_name, - "updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"), - })) + edges.append( + ( + data_point.id, + property_node.id, + field_name, + { + "source_node_id": data_point.id, + "target_node_id": property_node.id, + "relationship_name": field_name, + "updated_at": datetime.now(timezone.utc).strftime( + "%Y-%m-%d %H:%M:%S" + ), + }, + ) + ) added_edges[str(edge_key)] = True - - return(nodes, edges, added_nodes, added_edges) - -def get_graph_from_model(data_point: DataPoint, include_root = True, added_nodes = None, added_edges = None): - - if not added_nodes: - added_nodes = {} - if not added_edges: - added_edges = {} - - nodes = [] - edges = [] - - data_point_properties = {} - excluded_properties = set() - - for field_name, field_value in data_point: - if field_name == "_metadata": - continue - - if isinstance(field_value, DataPoint): - excluded_properties.add(field_name) - - nodes, edges, added_nodes, added_edges = add_nodes_and_edges(data_point, field_name, field_value, nodes, edges, added_nodes, added_edges) - - - if isinstance(field_value, list) and len(field_value) > 0 and isinstance(field_value[0], DataPoint): - excluded_properties.add(field_name) - - for item in field_value: - nodes, edges, added_nodes, added_edges = add_nodes_and_edges(data_point, field_name, item, nodes, edges, added_nodes, added_edges) - edges = [(*edge[:3],{**edge[3], "metadata": {"type": "list"}}) for edge in edges] - - continue - - data_point_properties[field_name] = field_value - - SimpleDataPointModel = copy_model( - type(data_point), - include_fields = { - "_metadata": (dict, data_point._metadata), - }, - exclude_fields = excluded_properties, - ) - - if include_root: - nodes.append(SimpleDataPointModel(**data_point_properties)) - - return nodes, edges + return (nodes, edges, added_nodes, added_edges) def get_own_properties(property_nodes, property_edges): From a5860700a7255ecbf702411f136d8efe2651e5d2 Mon Sep 17 00:00:00 2001 From: Leon Luithlen Date: Fri, 15 Nov 2024 14:00:59 +0100 Subject: [PATCH 06/21] Remove include_root parameter --- cognee/modules/graph/utils/get_graph_from_model.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/cognee/modules/graph/utils/get_graph_from_model.py b/cognee/modules/graph/utils/get_graph_from_model.py index d1e14c878..def6e4840 100644 --- a/cognee/modules/graph/utils/get_graph_from_model.py +++ b/cognee/modules/graph/utils/get_graph_from_model.py @@ -4,7 +4,7 @@ from cognee.modules.storage.utils import copy_model def get_graph_from_model( - data_point: DataPoint, include_root=True, added_nodes=None, added_edges=None + data_point: DataPoint, added_nodes=None, added_edges=None ): if not added_nodes: @@ -53,8 +53,7 @@ def get_graph_from_model( exclude_fields=excluded_properties, ) - if include_root: - nodes.append(SimpleDataPointModel(**data_point_properties)) + nodes.append(SimpleDataPointModel(**data_point_properties)) return nodes, edges @@ -64,7 +63,7 @@ def add_nodes_and_edges( ): property_nodes, property_edges = get_graph_from_model( - field_value, True, added_nodes, added_edges + field_value, added_nodes, added_edges ) for node in property_nodes: From 3c8a52f4b00a91c7ddc3fa195b3fc818681d15da Mon Sep 17 00:00:00 2001 From: Leon Luithlen Date: Fri, 15 Nov 2024 14:47:36 +0100 Subject: [PATCH 07/21] Fix inconsistent state between nodes and added_nodes and edges and added_edges --- cognee/modules/graph/utils/get_graph_from_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cognee/modules/graph/utils/get_graph_from_model.py b/cognee/modules/graph/utils/get_graph_from_model.py index def6e4840..bd6480ba0 100644 --- a/cognee/modules/graph/utils/get_graph_from_model.py +++ b/cognee/modules/graph/utils/get_graph_from_model.py @@ -63,7 +63,7 @@ def add_nodes_and_edges( ): property_nodes, property_edges = get_graph_from_model( - field_value, added_nodes, added_edges + field_value, dict(added_nodes), dict(added_edges) ) for node in property_nodes: From afae70f3b59cee1d13eaab152162556fbdd22278 Mon Sep 17 00:00:00 2001 From: Leon Luithlen Date: Fri, 15 Nov 2024 15:10:42 +0100 Subject: [PATCH 08/21] Add get_graph_from_model_generative_test --- .../tests/unit/interfaces/graph/conftest.py | 101 ----------------- .../get_graph_from_model_generative_test.py | 28 +++++ cognee/tests/unit/interfaces/graph/util.py | 102 +++++++++++++++++- 3 files changed, 129 insertions(+), 102 deletions(-) create mode 100644 cognee/tests/unit/interfaces/graph/get_graph_from_model_generative_test.py diff --git a/cognee/tests/unit/interfaces/graph/conftest.py b/cognee/tests/unit/interfaces/graph/conftest.py index 83d64523d..8084fb2bf 100644 --- a/cognee/tests/unit/interfaces/graph/conftest.py +++ b/cognee/tests/unit/interfaces/graph/conftest.py @@ -1,5 +1,3 @@ -import random -import string from enum import Enum from typing import Optional @@ -68,102 +66,3 @@ def boris(): }, ) return boris - - -class Organization(DataPoint): - id: str - name: str - members: Optional[list["SocietyPerson"]] - - -class SocietyPerson(DataPoint): - id: str - name: str - memberships: Optional[list[Organization]] - - -Organization.model_rebuild() -SocietyPerson.model_rebuild() - - -ORGANIZATION_NAMES = [ - "ChessClub", - "RowingClub", - "TheatreTroupe", - "PoliticalParty", - "Charity", - "FanClub", - "FilmClub", - "NeighborhoodGroup", - "LocalCouncil", - "Band", -] -PERSON_NAMES = ["Sarah", "Anna", "John", "Sam"] - - -def create_society_person_recursive(id, name, organization_names, max_depth, depth=0): - if depth < max_depth: - memberships = [ - create_organization_recursive( - org_name, org_name.lower(), PERSON_NAMES, max_depth, depth + 1 - ) - for org_name in organization_names - ] - else: - memberships = None - - id_suffix = "".join(random.choice(string.ascii_lowercase) for _ in range(10)) - return SocietyPerson( - id=f"{id}{depth}-{id_suffix}", name=f"{name}{depth}", memberships=memberships - ) - - -def create_organization_recursive(id, name, member_names, max_depth, depth=0): - if depth < max_depth: - members = [ - create_society_person_recursive( - member_name, - member_name.lower(), - ORGANIZATION_NAMES, - max_depth, - depth + 1, - ) - for member_name in member_names - ] - else: - members = None - - id_suffix = "".join(random.choice(string.ascii_lowercase) for _ in range(10)) - return Organization( - id=f"{id}{depth}-{id_suffix}", name=f"{id}{name}", members=members - ) - - -def count_society(obj): - if isinstance(obj, SocietyPerson): - if obj.memberships is not None: - organization_counts, society_person_counts = zip( - *[count_society(organization) for organization in obj.memberships] - ) - organization_count = sum(organization_counts) - society_person_count = sum(society_person_counts) + 1 - return (organization_count, society_person_count) - else: - return (0, 1) - if isinstance(obj, Organization): - if obj.members is not None: - organization_counts, society_person_counts = zip( - *[count_society(organization) for organization in obj.members] - ) - organization_count = sum(organization_counts) + 1 - society_person_count = sum(society_person_counts) - return (organization_count, society_person_count) - else: - return (1, 0) - else: - return (0, 0) - - -@pytest.fixture(scope="function") -def society(): - society = create_organization_recursive("society", "Society", PERSON_NAMES, 4) diff --git a/cognee/tests/unit/interfaces/graph/get_graph_from_model_generative_test.py b/cognee/tests/unit/interfaces/graph/get_graph_from_model_generative_test.py new file mode 100644 index 000000000..2a5816ac3 --- /dev/null +++ b/cognee/tests/unit/interfaces/graph/get_graph_from_model_generative_test.py @@ -0,0 +1,28 @@ +import pytest + +from cognee.modules.graph.utils import get_graph_from_model +from cognee.tests.unit.interfaces.graph.util import ( + PERSON_NAMES, + count_society, + create_organization_recursive, +) + + +@pytest.mark.parametrize("recursive_depth", [1, 2, 3]) +def test_extracted_car_type(recursive_depth): + society = create_organization_recursive( + "society", "Society", PERSON_NAMES, recursive_depth + ) + + n_organizations, n_persons = count_society(society) + society_counts_total = n_organizations + n_persons + + nodes, edges = get_graph_from_model(society) + + assert ( + len(nodes) == society_counts_total + ), f"{society_counts_total = } != {len(nodes) = }, not all DataPoint instances were found" + + assert len(edges) == ( + len(nodes) - 1 + ), f"{(len(nodes) - 1) = } != {len(edges) = }, there have to be n_nodes - 1 edges, as each node has exactly one parent node, except for the root node" diff --git a/cognee/tests/unit/interfaces/graph/util.py b/cognee/tests/unit/interfaces/graph/util.py index e5da0201c..3bdd55fe7 100644 --- a/cognee/tests/unit/interfaces/graph/util.py +++ b/cognee/tests/unit/interfaces/graph/util.py @@ -1,5 +1,9 @@ +import random +import string from datetime import datetime, timezone -from typing import Any, Dict +from typing import Any, Dict, Optional + +from cognee.infrastructure.engine import DataPoint def run_test_against_ground_truth( @@ -30,3 +34,99 @@ def run_test_against_ground_truth( time_delta = datetime.now(timezone.utc) - getattr(test_target_item, "updated_at") assert time_delta.total_seconds() < 60, f"{ time_delta.total_seconds() = }" + + +class Organization(DataPoint): + id: str + name: str + members: Optional[list["SocietyPerson"]] + + +class SocietyPerson(DataPoint): + id: str + name: str + memberships: Optional[list[Organization]] + + +Organization.model_rebuild() +SocietyPerson.model_rebuild() + + +ORGANIZATION_NAMES = [ + "ChessClub", + "RowingClub", + "TheatreTroupe", + "PoliticalParty", + "Charity", + "FanClub", + "FilmClub", + "NeighborhoodGroup", + "LocalCouncil", + "Band", +] +PERSON_NAMES = ["Sarah", "Anna", "John", "Sam"] + + +def create_society_person_recursive(id, name, organization_names, max_depth, depth=0): + id_suffix = "".join(random.choice(string.ascii_lowercase) for _ in range(10)) + + if depth < max_depth: + memberships = [ + create_organization_recursive( + f"{org_name}-{depth}-{id_suffix}", + org_name.lower(), + PERSON_NAMES, + max_depth, + depth + 1, + ) + for org_name in organization_names + ] + else: + memberships = None + + return SocietyPerson(id=id, name=f"{name}{depth}", memberships=memberships) + + +def create_organization_recursive(id, name, member_names, max_depth, depth=0): + id_suffix = "".join(random.choice(string.ascii_lowercase) for _ in range(10)) + + if depth < max_depth: + members = [ + create_society_person_recursive( + f"{member_name}-{depth}-{id_suffix}", + member_name.lower(), + ORGANIZATION_NAMES, + max_depth, + depth + 1, + ) + for member_name in member_names + ] + else: + members = None + + return Organization(id=id, name=f"{name}{depth}", members=members) + + +def count_society(obj): + if isinstance(obj, SocietyPerson): + if obj.memberships is not None: + organization_counts, society_person_counts = zip( + *[count_society(organization) for organization in obj.memberships] + ) + organization_count = sum(organization_counts) + society_person_count = sum(society_person_counts) + 1 + return (organization_count, society_person_count) + else: + return (0, 1) + if isinstance(obj, Organization): + if obj.members is not None: + organization_counts, society_person_counts = zip( + *[count_society(organization) for organization in obj.members] + ) + organization_count = sum(organization_counts) + 1 + society_person_count = sum(society_person_counts) + return (organization_count, society_person_count) + else: + return (1, 0) + else: + raise Exception("Not allowed") From 5a464bfca78d0a9f6d7059a261a839e1d85ff020 Mon Sep 17 00:00:00 2001 From: Leon Luithlen Date: Fri, 15 Nov 2024 15:57:50 +0100 Subject: [PATCH 09/21] Refactor get_model_instance_from_graph --- .../utils/get_model_instance_from_graph.py | 31 ++++++++++++------- 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/cognee/modules/graph/utils/get_model_instance_from_graph.py b/cognee/modules/graph/utils/get_model_instance_from_graph.py index 82cdfa150..bdd0dface 100644 --- a/cognee/modules/graph/utils/get_model_instance_from_graph.py +++ b/cognee/modules/graph/utils/get_model_instance_from_graph.py @@ -2,28 +2,35 @@ from pydantic_core import PydanticUndefined from cognee.infrastructure.engine import DataPoint from cognee.modules.storage.utils import copy_model +def merge_dicts(dict1, dict2, agg_fn): + merged_dict = {} + for key, value in dict1.items(): + if key in dict2: + merged_dict[key] = agg_fn(value, dict2[key]) + else: + merged_dict[key] = value -def get_model_instance_from_graph(nodes: list[DataPoint], edges: list, entity_id: str): - node_map = {} + for key, value in dict2.items(): + if key not in merged_dict: + merged_dict[key] = value + return merged_dict - for node in nodes: - node_map[node.id] = node +def get_model_instance_from_graph(nodes: list[DataPoint], edges: list[tuple[str, str, str, dict[str, str]]], entity_id: str): + node_map = {node.id: node for node in nodes} - for edge in edges: - source_node = node_map[edge[0]] - target_node = node_map[edge[1]] - edge_label = edge[2] - edge_properties = edge[3] if len(edge) == 4 else {} + for source_node_id, target_node_id, edge_label, edge_properties in edges: + source_node = node_map[source_node_id] + target_node = node_map[target_node_id] edge_metadata = edge_properties.get("metadata", {}) edge_type = edge_metadata.get("type") if edge_type == "list": NewModel = copy_model(type(source_node), { edge_label: (list[type(target_node)], PydanticUndefined) }) - - node_map[edge[0]] = NewModel(**source_node.model_dump(), **{ edge_label: [target_node] }) + new_model_dict = merge_dicts(source_node.model_dump(), { edge_label: [target_node] }, lambda a, b: a + b) + node_map[source_node_id] = NewModel(**new_model_dict) else: NewModel = copy_model(type(source_node), { edge_label: (type(target_node), PydanticUndefined) }) - node_map[edge[0]] = NewModel(**source_node.model_dump(), **{ edge_label: target_node }) + node_map[target_node_id] = NewModel(**source_node.model_dump(), **{ edge_label: target_node }) return node_map[entity_id] From 370b59b39a70086b7a341e63340a8cc1cb23610e Mon Sep 17 00:00:00 2001 From: Leon Luithlen Date: Fri, 15 Nov 2024 15:58:03 +0100 Subject: [PATCH 10/21] Add get_graph_from_model_generative_test --- .../get_graph_from_model_generative_test.py | 2 +- ...del_instance_from_graph_generative_test.py | 24 +++++++++++++++++++ cognee/tests/unit/interfaces/graph/util.py | 19 +++++++++++++++ 3 files changed, 44 insertions(+), 1 deletion(-) create mode 100644 cognee/tests/unit/interfaces/graph/get_model_instance_from_graph_generative_test.py diff --git a/cognee/tests/unit/interfaces/graph/get_graph_from_model_generative_test.py b/cognee/tests/unit/interfaces/graph/get_graph_from_model_generative_test.py index 2a5816ac3..dee4f5042 100644 --- a/cognee/tests/unit/interfaces/graph/get_graph_from_model_generative_test.py +++ b/cognee/tests/unit/interfaces/graph/get_graph_from_model_generative_test.py @@ -9,7 +9,7 @@ from cognee.tests.unit.interfaces.graph.util import ( @pytest.mark.parametrize("recursive_depth", [1, 2, 3]) -def test_extracted_car_type(recursive_depth): +def test_society_nodes_and_edges(recursive_depth): society = create_organization_recursive( "society", "Society", PERSON_NAMES, recursive_depth ) diff --git a/cognee/tests/unit/interfaces/graph/get_model_instance_from_graph_generative_test.py b/cognee/tests/unit/interfaces/graph/get_model_instance_from_graph_generative_test.py new file mode 100644 index 000000000..10578216b --- /dev/null +++ b/cognee/tests/unit/interfaces/graph/get_model_instance_from_graph_generative_test.py @@ -0,0 +1,24 @@ +import pytest + +from cognee.modules.graph.utils import ( + get_graph_from_model, + get_model_instance_from_graph, +) +from cognee.tests.unit.interfaces.graph.util import ( + PERSON_NAMES, + create_organization_recursive, + show_first_difference, +) + + +@pytest.mark.parametrize("recursive_depth", [1, 2, 3]) +def test_society_nodes_and_edges(recursive_depth): + society = create_organization_recursive( + "society", "Society", PERSON_NAMES, recursive_depth + ) + nodes, edges = get_graph_from_model(society) + parsed_society = get_model_instance_from_graph(nodes, edges, "society") + + assert str(society) == (str(parsed_society)), show_first_difference( + str(society), str(parsed_society), "society", "parsed_society" + ) diff --git a/cognee/tests/unit/interfaces/graph/util.py b/cognee/tests/unit/interfaces/graph/util.py index 3bdd55fe7..4a60c94fa 100644 --- a/cognee/tests/unit/interfaces/graph/util.py +++ b/cognee/tests/unit/interfaces/graph/util.py @@ -130,3 +130,22 @@ def count_society(obj): return (1, 0) else: raise Exception("Not allowed") + + +def show_first_difference(str1, str2, str1_name, str2_name, context=30): + """Shows where two strings first diverge, with surrounding context.""" + for i, (c1, c2) in enumerate(zip(str1, str2)): + if c1 != c2: + start = max(0, i - context) + end1 = min(len(str1), i + context + 1) + end2 = min(len(str2), i + context + 1) + if i > 0: + return f"identical: '{str1[start:i-1]}' | {str1_name}: '{str1[i-1:end1]}'... != {str2_name}: '{str2[i-1:end2]}'..." + else: + return f"{str1_name} and {str2_name} have no overlap in characters" + if len(str1) > len(str2): + return f"{str2_name} is identical up to the {i}th character, missing afterwards '{str1[i:i+context]}'..." + if len(str2) > len(str1): + return f"{str1_name} is identical up to the {i}th character, missing afterwards '{str2[i:i+context]}'..." + else: + return f"{str1_name} and {str2_name} are identical." From f3f0bca9bdc679627681e2c0b0008fcc4a5dad7a Mon Sep 17 00:00:00 2001 From: Leon Luithlen Date: Fri, 15 Nov 2024 16:03:53 +0100 Subject: [PATCH 11/21] Revert making Person attributes optional --- cognee/tests/unit/interfaces/graph/conftest.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cognee/tests/unit/interfaces/graph/conftest.py b/cognee/tests/unit/interfaces/graph/conftest.py index 8084fb2bf..45f977bd6 100644 --- a/cognee/tests/unit/interfaces/graph/conftest.py +++ b/cognee/tests/unit/interfaces/graph/conftest.py @@ -36,8 +36,8 @@ class Car(DataPoint): class Person(DataPoint): id: str name: str - age: Optional[int] - owns_car: Optional[list[Car]] + age: int + owns_car: list[Car] driving_license: Optional[dict] _metadata: dict = dict(index_fields=["name"]) From a1f72727bcfa93c217c1a38400d6553ddc222ec6 Mon Sep 17 00:00:00 2001 From: Leon Luithlen Date: Fri, 15 Nov 2024 16:17:33 +0100 Subject: [PATCH 12/21] Revert model_rebuild order --- cognee/tests/unit/interfaces/graph/util.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/cognee/tests/unit/interfaces/graph/util.py b/cognee/tests/unit/interfaces/graph/util.py index 4a60c94fa..c06023cc2 100644 --- a/cognee/tests/unit/interfaces/graph/util.py +++ b/cognee/tests/unit/interfaces/graph/util.py @@ -47,9 +47,8 @@ class SocietyPerson(DataPoint): name: str memberships: Optional[list[Organization]] - -Organization.model_rebuild() SocietyPerson.model_rebuild() +Organization.model_rebuild() ORGANIZATION_NAMES = [ From 148eb4ed9bf44ceb56350bd890b8935aba16f949 Mon Sep 17 00:00:00 2001 From: Leon Luithlen Date: Fri, 15 Nov 2024 16:42:14 +0100 Subject: [PATCH 13/21] Add profile_graph_pydantic_conversion.py --- .../benchmark_function.py | 64 +++++++++++++++++++ .../profile_graph_pydantic_conversion.py | 37 +++++++++++ 2 files changed, 101 insertions(+) create mode 100644 profiling/graph_pydantic_conversion/benchmark_function.py create mode 100644 profiling/graph_pydantic_conversion/profile_graph_pydantic_conversion.py diff --git a/profiling/graph_pydantic_conversion/benchmark_function.py b/profiling/graph_pydantic_conversion/benchmark_function.py new file mode 100644 index 000000000..95c483584 --- /dev/null +++ b/profiling/graph_pydantic_conversion/benchmark_function.py @@ -0,0 +1,64 @@ +import time +import psutil +import tracemalloc +import statistics +from typing import Callable, Any, Dict + + +def benchmark_function(func: Callable, *args, num_runs: int = 5) -> Dict[str, Any]: + """ + Benchmark a function for memory usage and computational performance. + + Args: + func: Function to benchmark + *args: Arguments to pass to the function + num_runs: Number of times to run the benchmark + + Returns: + Dictionary containing benchmark metrics + """ + execution_times = [] + peak_memory_usages = [] + cpu_percentages = [] + + process = psutil.Process() + + for _ in range(num_runs): + # Start memory tracking + tracemalloc.start() + initial_memory = process.memory_info().rss + + # Measure execution time and CPU usage + start_time = time.perf_counter() + start_cpu_time = process.cpu_times() + + result = func(*args) + + end_cpu_time = process.cpu_times() + end_time = time.perf_counter() + + # Calculate metrics + execution_time = end_time - start_time + cpu_time = (end_cpu_time.user + end_cpu_time.system) - (start_cpu_time.user + start_cpu_time.system) + current, peak = tracemalloc.get_traced_memory() + final_memory = process.memory_info().rss + memory_used = final_memory - initial_memory + + # Store results + execution_times.append(execution_time) + peak_memory_usages.append(peak / 1024 / 1024) # Convert to MB + cpu_percentages.append((cpu_time / execution_time) * 100) + + tracemalloc.stop() + + analysis = { + "mean_execution_time": statistics.mean(execution_times), + "mean_peak_memory_mb": statistics.mean(peak_memory_usages), + "mean_cpu_percent": statistics.mean(cpu_percentages), + "num_runs": num_runs + } + + if num_runs > 1: + analysis["std_execution_time"] = statistics.stdev(execution_times) + + return analysis \ No newline at end of file diff --git a/profiling/graph_pydantic_conversion/profile_graph_pydantic_conversion.py b/profiling/graph_pydantic_conversion/profile_graph_pydantic_conversion.py new file mode 100644 index 000000000..75803b996 --- /dev/null +++ b/profiling/graph_pydantic_conversion/profile_graph_pydantic_conversion.py @@ -0,0 +1,37 @@ +import time +import argparse + +from benchmark_function import benchmark_function +from cognee.modules.graph.utils import get_graph_from_model + +from cognee.tests.unit.interfaces.graph.util import ( + PERSON_NAMES, + create_organization_recursive, +) + + + +# Example usage: +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Benchmark graph model with configurable recursive depth') + parser.add_argument('--recursive-depth', type=int, default=3, + help='Recursive depth for graph generation (default: 3)') + parser.add_argument('--runs', type=int, default=5, + help='Number of benchmark runs (default: 5)') + args = parser.parse_args() + + + society = create_organization_recursive( + "society", "Society", PERSON_NAMES, args.recursive_depth + ) + nodes, edges = get_graph_from_model(society) + + results = benchmark_function(get_graph_from_model, society, num_runs=args.runs) + print("\nBenchmark Results:") + print(f"N nodes: {len(nodes)}, N edges: {len(edges)}, Recursion depth: {args.recursive_depth}") + print(f"Mean Peak Memory: {results['mean_peak_memory_mb']:.2f} MB") + print(f"Mean CPU Usage: {results['mean_cpu_percent']:.2f}%") + print(f"Mean Execution Time: {results['mean_execution_time']:.4f} seconds") + + if 'std_execution_time' in results: + print(f"Execution Time Std: {results['std_execution_time']:.4f} seconds") From 5b420ebccc4eb3e15022e7b4904ac0164937d524 Mon Sep 17 00:00:00 2001 From: Leon Luithlen Date: Fri, 15 Nov 2024 16:44:30 +0100 Subject: [PATCH 14/21] Autoformat graph pydantic conversion code --- .../benchmark_function.py | 39 ++++++++++--------- .../profile_graph_pydantic_conversion.py | 32 ++++++++------- 2 files changed, 40 insertions(+), 31 deletions(-) diff --git a/profiling/graph_pydantic_conversion/benchmark_function.py b/profiling/graph_pydantic_conversion/benchmark_function.py index 95c483584..58990cc31 100644 --- a/profiling/graph_pydantic_conversion/benchmark_function.py +++ b/profiling/graph_pydantic_conversion/benchmark_function.py @@ -1,64 +1,67 @@ -import time -import psutil -import tracemalloc import statistics -from typing import Callable, Any, Dict +import time +import tracemalloc +from typing import Any, Callable, Dict + +import psutil def benchmark_function(func: Callable, *args, num_runs: int = 5) -> Dict[str, Any]: """ Benchmark a function for memory usage and computational performance. - + Args: func: Function to benchmark *args: Arguments to pass to the function num_runs: Number of times to run the benchmark - + Returns: Dictionary containing benchmark metrics """ execution_times = [] peak_memory_usages = [] cpu_percentages = [] - + process = psutil.Process() - + for _ in range(num_runs): # Start memory tracking tracemalloc.start() initial_memory = process.memory_info().rss - + # Measure execution time and CPU usage start_time = time.perf_counter() start_cpu_time = process.cpu_times() - + result = func(*args) - + end_cpu_time = process.cpu_times() end_time = time.perf_counter() - + # Calculate metrics execution_time = end_time - start_time - cpu_time = (end_cpu_time.user + end_cpu_time.system) - (start_cpu_time.user + start_cpu_time.system) + cpu_time = (end_cpu_time.user + end_cpu_time.system) - ( + start_cpu_time.user + start_cpu_time.system + ) current, peak = tracemalloc.get_traced_memory() final_memory = process.memory_info().rss memory_used = final_memory - initial_memory - + # Store results execution_times.append(execution_time) peak_memory_usages.append(peak / 1024 / 1024) # Convert to MB cpu_percentages.append((cpu_time / execution_time) * 100) - + tracemalloc.stop() - + analysis = { "mean_execution_time": statistics.mean(execution_times), "mean_peak_memory_mb": statistics.mean(peak_memory_usages), "mean_cpu_percent": statistics.mean(cpu_percentages), - "num_runs": num_runs + "num_runs": num_runs, } if num_runs > 1: analysis["std_execution_time"] = statistics.stdev(execution_times) - return analysis \ No newline at end of file + return analysis diff --git a/profiling/graph_pydantic_conversion/profile_graph_pydantic_conversion.py b/profiling/graph_pydantic_conversion/profile_graph_pydantic_conversion.py index 75803b996..664186c28 100644 --- a/profiling/graph_pydantic_conversion/profile_graph_pydantic_conversion.py +++ b/profiling/graph_pydantic_conversion/profile_graph_pydantic_conversion.py @@ -1,37 +1,43 @@ -import time import argparse +import time from benchmark_function import benchmark_function -from cognee.modules.graph.utils import get_graph_from_model +from cognee.modules.graph.utils import get_graph_from_model from cognee.tests.unit.interfaces.graph.util import ( PERSON_NAMES, create_organization_recursive, ) - - # Example usage: if __name__ == "__main__": - parser = argparse.ArgumentParser(description='Benchmark graph model with configurable recursive depth') - parser.add_argument('--recursive-depth', type=int, default=3, - help='Recursive depth for graph generation (default: 3)') - parser.add_argument('--runs', type=int, default=5, - help='Number of benchmark runs (default: 5)') + parser = argparse.ArgumentParser( + description="Benchmark graph model with configurable recursive depth" + ) + parser.add_argument( + "--recursive-depth", + type=int, + default=3, + help="Recursive depth for graph generation (default: 3)", + ) + parser.add_argument( + "--runs", type=int, default=5, help="Number of benchmark runs (default: 5)" + ) args = parser.parse_args() - society = create_organization_recursive( "society", "Society", PERSON_NAMES, args.recursive_depth ) nodes, edges = get_graph_from_model(society) - + results = benchmark_function(get_graph_from_model, society, num_runs=args.runs) print("\nBenchmark Results:") - print(f"N nodes: {len(nodes)}, N edges: {len(edges)}, Recursion depth: {args.recursive_depth}") + print( + f"N nodes: {len(nodes)}, N edges: {len(edges)}, Recursion depth: {args.recursive_depth}" + ) print(f"Mean Peak Memory: {results['mean_peak_memory_mb']:.2f} MB") print(f"Mean CPU Usage: {results['mean_cpu_percent']:.2f}%") print(f"Mean Execution Time: {results['mean_execution_time']:.4f} seconds") - if 'std_execution_time' in results: + if "std_execution_time" in results: print(f"Execution Time Std: {results['std_execution_time']:.4f} seconds") From a3342918d971d3c9a3d46dd923d0c0a5d166a4b4 Mon Sep 17 00:00:00 2001 From: Leon Luithlen Date: Fri, 15 Nov 2024 16:53:32 +0100 Subject: [PATCH 15/21] Apply cosmetic changes and autoformat --- .../graph/utils/get_graph_from_model.py | 18 ++++++---- .../utils/get_model_instance_from_graph.py | 34 +++++++++++++++---- cognee/tests/unit/interfaces/graph/util.py | 1 - 3 files changed, 39 insertions(+), 14 deletions(-) diff --git a/cognee/modules/graph/utils/get_graph_from_model.py b/cognee/modules/graph/utils/get_graph_from_model.py index bd6480ba0..770e63d05 100644 --- a/cognee/modules/graph/utils/get_graph_from_model.py +++ b/cognee/modules/graph/utils/get_graph_from_model.py @@ -1,11 +1,10 @@ from datetime import datetime, timezone + from cognee.infrastructure.engine import DataPoint from cognee.modules.storage.utils import copy_model -def get_graph_from_model( - data_point: DataPoint, added_nodes=None, added_edges=None -): +def get_graph_from_model(data_point: DataPoint, added_nodes=None, added_edges=None): if not added_nodes: added_nodes = {} @@ -24,7 +23,13 @@ def get_graph_from_model( elif isinstance(field_value, DataPoint): excluded_properties.add(field_name) nodes, edges, added_nodes, added_edges = add_nodes_and_edges( - data_point, field_name, field_value, nodes, edges, added_nodes, added_edges + data_point, + field_name, + field_value, + nodes, + edges, + added_nodes, + added_edges, ) elif ( @@ -35,12 +40,13 @@ def get_graph_from_model( excluded_properties.add(field_name) for item in field_value: + n_edges_before = len(edges) nodes, edges, added_nodes, added_edges = add_nodes_and_edges( data_point, field_name, item, nodes, edges, added_nodes, added_edges ) - edges = [ + edges = edges[:n_edges_before] + [ (*edge[:3], {**edge[3], "metadata": {"type": "list"}}) - for edge in edges + for edge in edges[n_edges_before:] ] else: data_point_properties[field_name] = field_value diff --git a/cognee/modules/graph/utils/get_model_instance_from_graph.py b/cognee/modules/graph/utils/get_model_instance_from_graph.py index bdd0dface..87146111c 100644 --- a/cognee/modules/graph/utils/get_model_instance_from_graph.py +++ b/cognee/modules/graph/utils/get_model_instance_from_graph.py @@ -1,8 +1,12 @@ +from typing import Callable + from pydantic_core import PydanticUndefined + from cognee.infrastructure.engine import DataPoint from cognee.modules.storage.utils import copy_model -def merge_dicts(dict1, dict2, agg_fn): + +def merge_dicts(dict1: dict, dict2: dict, agg_fn: Callable) -> dict: merged_dict = {} for key, value in dict1.items(): if key in dict2: @@ -15,22 +19,38 @@ def merge_dicts(dict1, dict2, agg_fn): merged_dict[key] = value return merged_dict -def get_model_instance_from_graph(nodes: list[DataPoint], edges: list[tuple[str, str, str, dict[str, str]]], entity_id: str): + +def get_model_instance_from_graph( + nodes: list[DataPoint], + edges: list[tuple[str, str, str, dict[str, str]]], + entity_id: str, +): node_map = {node.id: node for node in nodes} for source_node_id, target_node_id, edge_label, edge_properties in edges: source_node = node_map[source_node_id] target_node = node_map[target_node_id] edge_metadata = edge_properties.get("metadata", {}) - edge_type = edge_metadata.get("type") + edge_type = edge_metadata.get("type", "default") if edge_type == "list": - NewModel = copy_model(type(source_node), { edge_label: (list[type(target_node)], PydanticUndefined) }) - new_model_dict = merge_dicts(source_node.model_dump(), { edge_label: [target_node] }, lambda a, b: a + b) + NewModel = copy_model( + type(source_node), + {edge_label: (list[type(target_node)], PydanticUndefined)}, + ) + new_model_dict = merge_dicts( + source_node.model_dump(), + {edge_label: [target_node]}, + lambda a, b: a + b, + ) node_map[source_node_id] = NewModel(**new_model_dict) else: - NewModel = copy_model(type(source_node), { edge_label: (type(target_node), PydanticUndefined) }) + NewModel = copy_model( + type(source_node), {edge_label: (type(target_node), PydanticUndefined)} + ) - node_map[target_node_id] = NewModel(**source_node.model_dump(), **{ edge_label: target_node }) + node_map[target_node_id] = NewModel( + **source_node.model_dump(), **{edge_label: target_node} + ) return node_map[entity_id] diff --git a/cognee/tests/unit/interfaces/graph/util.py b/cognee/tests/unit/interfaces/graph/util.py index c06023cc2..c8909d40d 100644 --- a/cognee/tests/unit/interfaces/graph/util.py +++ b/cognee/tests/unit/interfaces/graph/util.py @@ -132,7 +132,6 @@ def count_society(obj): def show_first_difference(str1, str2, str1_name, str2_name, context=30): - """Shows where two strings first diverge, with surrounding context.""" for i, (c1, c2) in enumerate(zip(str1, str2)): if c1 != c2: start = max(0, i - context) From 8a2cf2075a148a2d9f8d8c9420187952b41897cb Mon Sep 17 00:00:00 2001 From: Leon Luithlen Date: Fri, 15 Nov 2024 17:57:03 +0100 Subject: [PATCH 16/21] Add model_rebuild --- cognee/modules/storage/utils/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cognee/modules/storage/utils/__init__.py b/cognee/modules/storage/utils/__init__.py index 7073e6470..a399e8a82 100644 --- a/cognee/modules/storage/utils/__init__.py +++ b/cognee/modules/storage/utils/__init__.py @@ -29,7 +29,9 @@ def copy_model(model: DataPoint, include_fields: dict = {}, exclude_fields: list **include_fields } - return create_model(model.__name__, **final_fields) + model = create_model(model.__name__, **final_fields) + model.model_rebuild() + return model def get_own_properties(data_point: DataPoint): properties = {} From 103eb13c7717092c79c7c805ea2a585efa4386e9 Mon Sep 17 00:00:00 2001 From: Leon Luithlen Date: Mon, 18 Nov 2024 11:23:22 +0100 Subject: [PATCH 17/21] Skip recursive pydantic tests for Python 3.9 and 3.10 --- .../get_graph_from_model_generative_test.py | 33 ++++++++++++------- ...del_instance_from_graph_generative_test.py | 25 +++++++++----- cognee/tests/unit/interfaces/graph/util.py | 1 + 3 files changed, 39 insertions(+), 20 deletions(-) diff --git a/cognee/tests/unit/interfaces/graph/get_graph_from_model_generative_test.py b/cognee/tests/unit/interfaces/graph/get_graph_from_model_generative_test.py index dee4f5042..dec751f89 100644 --- a/cognee/tests/unit/interfaces/graph/get_graph_from_model_generative_test.py +++ b/cognee/tests/unit/interfaces/graph/get_graph_from_model_generative_test.py @@ -1,3 +1,5 @@ +import warnings + import pytest from cognee.modules.graph.utils import get_graph_from_model @@ -10,19 +12,26 @@ from cognee.tests.unit.interfaces.graph.util import ( @pytest.mark.parametrize("recursive_depth", [1, 2, 3]) def test_society_nodes_and_edges(recursive_depth): - society = create_organization_recursive( - "society", "Society", PERSON_NAMES, recursive_depth - ) + import sys - n_organizations, n_persons = count_society(society) - society_counts_total = n_organizations + n_persons + if sys.version_info[0] == 3 and sys.version_info[1] >= 11: + society = create_organization_recursive( + "society", "Society", PERSON_NAMES, recursive_depth + ) - nodes, edges = get_graph_from_model(society) + n_organizations, n_persons = count_society(society) + society_counts_total = n_organizations + n_persons - assert ( - len(nodes) == society_counts_total - ), f"{society_counts_total = } != {len(nodes) = }, not all DataPoint instances were found" + nodes, edges = get_graph_from_model(society) - assert len(edges) == ( - len(nodes) - 1 - ), f"{(len(nodes) - 1) = } != {len(edges) = }, there have to be n_nodes - 1 edges, as each node has exactly one parent node, except for the root node" + assert ( + len(nodes) == society_counts_total + ), f"{society_counts_total = } != {len(nodes) = }, not all DataPoint instances were found" + + assert len(edges) == ( + len(nodes) - 1 + ), f"{(len(nodes) - 1) = } != {len(edges) = }, there have to be n_nodes - 1 edges, as each node has exactly one parent node, except for the root node" + else: + warnings.warn( + "The recursive pydantic data structure cannot be reconstructed from the graph because the 'inner' pydantic class is not defined. Hence this test is skipped. This problem is solved in Python 3.11" + ) diff --git a/cognee/tests/unit/interfaces/graph/get_model_instance_from_graph_generative_test.py b/cognee/tests/unit/interfaces/graph/get_model_instance_from_graph_generative_test.py index 10578216b..dd5e19469 100644 --- a/cognee/tests/unit/interfaces/graph/get_model_instance_from_graph_generative_test.py +++ b/cognee/tests/unit/interfaces/graph/get_model_instance_from_graph_generative_test.py @@ -1,3 +1,5 @@ +import warnings + import pytest from cognee.modules.graph.utils import ( @@ -13,12 +15,19 @@ from cognee.tests.unit.interfaces.graph.util import ( @pytest.mark.parametrize("recursive_depth", [1, 2, 3]) def test_society_nodes_and_edges(recursive_depth): - society = create_organization_recursive( - "society", "Society", PERSON_NAMES, recursive_depth - ) - nodes, edges = get_graph_from_model(society) - parsed_society = get_model_instance_from_graph(nodes, edges, "society") + import sys - assert str(society) == (str(parsed_society)), show_first_difference( - str(society), str(parsed_society), "society", "parsed_society" - ) + if sys.version_info[0] == 3 and sys.version_info[1] >= 11: + society = create_organization_recursive( + "society", "Society", PERSON_NAMES, recursive_depth + ) + nodes, edges = get_graph_from_model(society) + parsed_society = get_model_instance_from_graph(nodes, edges, "society") + + assert str(society) == (str(parsed_society)), show_first_difference( + str(society), str(parsed_society), "society", "parsed_society" + ) + else: + warnings.warn( + "The recursive pydantic data structure cannot be reconstructed from the graph because the 'inner' pydantic class is not defined. Hence this test is skipped. This problem is solved in Python 3.11" + ) diff --git a/cognee/tests/unit/interfaces/graph/util.py b/cognee/tests/unit/interfaces/graph/util.py index c8909d40d..a20bdb3e4 100644 --- a/cognee/tests/unit/interfaces/graph/util.py +++ b/cognee/tests/unit/interfaces/graph/util.py @@ -47,6 +47,7 @@ class SocietyPerson(DataPoint): name: str memberships: Optional[list[Organization]] + SocietyPerson.model_rebuild() Organization.model_rebuild() From 7a2fc617a875bcfe9a7f9394ef64d9b3c8785127 Mon Sep 17 00:00:00 2001 From: Leon Luithlen Date: Mon, 18 Nov 2024 14:00:14 +0100 Subject: [PATCH 18/21] Rename remaining 'query' keyword args in cognee.search to 'query_text' --- cognee/tests/test_code_generation.py | 2 +- docs/quickstart.md | 2 +- examples/python/dynamic_steps_example.py | 2 +- examples/python/simple_example.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/cognee/tests/test_code_generation.py b/cognee/tests/test_code_generation.py index aad59ace8..a21925585 100755 --- a/cognee/tests/test_code_generation.py +++ b/cognee/tests/test_code_generation.py @@ -26,7 +26,7 @@ async def main(): await render_graph(None, include_nodes = True, include_labels = True) - search_results = await cognee.search(SearchType.CHUNKS, query = "Student") + search_results = await cognee.search(SearchType.CHUNKS, query_text = "Student") assert len(search_results) != 0, "The search results list is empty." print("\n\nExtracted chunks are:\n") for result in search_results: diff --git a/docs/quickstart.md b/docs/quickstart.md index 828fb249d..35c46e4f9 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -59,7 +59,7 @@ await cognee.add(text) # Add a new piece of information await cognee.cognify() # Use LLMs and cognee to create knowledge -search_results = await cognee.search("INSIGHTS", {'query': 'Tell me about NLP'}) # Query cognee for the knowledge +search_results = await cognee.search("INSIGHTS", {'query_text': 'Tell me about NLP'}) # Query cognee for the knowledge for result_text in search_results: print(result_text) diff --git a/examples/python/dynamic_steps_example.py b/examples/python/dynamic_steps_example.py index 309aea82c..194450599 100644 --- a/examples/python/dynamic_steps_example.py +++ b/examples/python/dynamic_steps_example.py @@ -209,7 +209,7 @@ async def main(enable_steps): if enable_steps.get("search_insights"): search_results = await cognee.search( SearchType.INSIGHTS, - {'query': 'Which applicant has the most relevant experience in data science?'} + query_text='Which applicant has the most relevant experience in data science?' ) print("Search results:") for result_text in search_results: diff --git a/examples/python/simple_example.py b/examples/python/simple_example.py index 4e0e61834..707705056 100644 --- a/examples/python/simple_example.py +++ b/examples/python/simple_example.py @@ -27,7 +27,7 @@ async def main(): # Query cognee for insights on the added text search_results = await cognee.search( - SearchType.INSIGHTS, query='Tell me about NLP' + SearchType.INSIGHTS, query_text='Tell me about NLP' ) # Display search results From bb404f51f9c472c2fcbeaae1f7328d0685d2f772 Mon Sep 17 00:00:00 2001 From: Boris Date: Mon, 18 Nov 2024 17:50:15 +0100 Subject: [PATCH 19/21] Update docs/quickstart.md --- docs/quickstart.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/quickstart.md b/docs/quickstart.md index 35c46e4f9..0cdc2645c 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -59,7 +59,7 @@ await cognee.add(text) # Add a new piece of information await cognee.cognify() # Use LLMs and cognee to create knowledge -search_results = await cognee.search("INSIGHTS", {'query_text': 'Tell me about NLP'}) # Query cognee for the knowledge +search_results = await cognee.search(SearchType.INSIGHTS, query_text='Tell me about NLP') # Query cognee for the knowledge for result_text in search_results: print(result_text) From 722c7b081a53263c1f616266a4e1d6149496e52f Mon Sep 17 00:00:00 2001 From: Vasilije <8619304+Vasilije1990@users.noreply.github.com> Date: Mon, 18 Nov 2024 21:43:10 +0100 Subject: [PATCH 20/21] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 28d5858a0..41d0ac7cd 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ [![GitHub commits](https://badgen.net/github/commits/topoteretes/cognee)](https://GitHub.com/topoteretes/cognee/commit/) [![Github tag](https://badgen.net/github/tag/topoteretes/cognee)](https://github.com/topoteretes/cognee/tags/) [![Downloads](https://static.pepy.tech/badge/cognee)](https://pepy.tech/project/cognee) -[![GitHub license](https://badgen.net/github/license/topoteretes/cognee)](https://github.com/topoteretes/cognee/blob/master/LICENSE) + We build for developers who need a reliable, production-ready data layer for AI applications From b18f748c9ec4fe4e431a1fb786f25ca2b08e40da Mon Sep 17 00:00:00 2001 From: Leon Luithlen Date: Tue, 19 Nov 2024 10:56:21 +0100 Subject: [PATCH 21/21] Merge dicts directly --- .../utils/get_model_instance_from_graph.py | 25 ++++--------------- 1 file changed, 5 insertions(+), 20 deletions(-) diff --git a/cognee/modules/graph/utils/get_model_instance_from_graph.py b/cognee/modules/graph/utils/get_model_instance_from_graph.py index 87146111c..16658d743 100644 --- a/cognee/modules/graph/utils/get_model_instance_from_graph.py +++ b/cognee/modules/graph/utils/get_model_instance_from_graph.py @@ -6,20 +6,6 @@ from cognee.infrastructure.engine import DataPoint from cognee.modules.storage.utils import copy_model -def merge_dicts(dict1: dict, dict2: dict, agg_fn: Callable) -> dict: - merged_dict = {} - for key, value in dict1.items(): - if key in dict2: - merged_dict[key] = agg_fn(value, dict2[key]) - else: - merged_dict[key] = value - - for key, value in dict2.items(): - if key not in merged_dict: - merged_dict[key] = value - return merged_dict - - def get_model_instance_from_graph( nodes: list[DataPoint], edges: list[tuple[str, str, str, dict[str, str]]], @@ -38,12 +24,11 @@ def get_model_instance_from_graph( type(source_node), {edge_label: (list[type(target_node)], PydanticUndefined)}, ) - new_model_dict = merge_dicts( - source_node.model_dump(), - {edge_label: [target_node]}, - lambda a, b: a + b, - ) - node_map[source_node_id] = NewModel(**new_model_dict) + source_node_dict = source_node.model_dump() + source_node_edge_label_values = source_node_dict.get(edge_label, []) + source_node_dict[edge_label] = source_node_edge_label_values + [target_node] + + node_map[source_node_id] = NewModel(**source_node_dict) else: NewModel = copy_model( type(source_node), {edge_label: (type(target_node), PydanticUndefined)}