diff --git a/cognee/modules/graph/utils/get_graph_from_model.py b/cognee/modules/graph/utils/get_graph_from_model.py index 6fcca4572..5497207a4 100644 --- a/cognee/modules/graph/utils/get_graph_from_model.py +++ b/cognee/modules/graph/utils/get_graph_from_model.py @@ -4,43 +4,50 @@ from cognee.infrastructure.engine import DataPoint, Edge from cognee.modules.storage.utils import copy_model -def _extract_field_info(field_value: Any) -> Tuple[str, Any, Optional[Edge]]: - """Extract field type, actual value, and edge metadata from a field value.""" - - # Handle tuple[Edge, DataPoint] - if ( - isinstance(field_value, tuple) - and len(field_value) == 2 - and isinstance(field_value[0], Edge) - and isinstance(field_value[1], DataPoint) - ): - return "single_datapoint_with_edge", field_value[1], field_value[0] - - # Handle tuple[Edge, list[DataPoint]] - if ( - isinstance(field_value, tuple) - and len(field_value) == 2 - and isinstance(field_value[0], Edge) - and isinstance(field_value[1], list) - and len(field_value[1]) > 0 - and isinstance(field_value[1][0], DataPoint) - ): - return "list_datapoint_with_edge", field_value[1], field_value[0] - +def _extract_field_data(field_value: Any) -> List[Tuple[Optional[Edge], List[DataPoint]]]: + """Extract edge metadata and datapoints from a field value.""" # Handle single DataPoint if isinstance(field_value, DataPoint): - return "single_datapoint", field_value, None + return [(None, [field_value])] - # Handle list of DataPoints + # Handle list - could contain DataPoints, edge tuples, or mixed + if isinstance(field_value, list) and len(field_value) > 0: + result = [] + for item in field_value: + # Handle tuple[Edge, DataPoint or list[DataPoint]] + if isinstance(item, tuple) and len(item) == 2 and isinstance(item[0], Edge): + edge, data_value = item + if isinstance(data_value, DataPoint): + result.append((edge, [data_value])) + elif ( + isinstance(data_value, list) + and len(data_value) > 0 + and isinstance(data_value[0], DataPoint) + ): + result.append((edge, data_value)) + # Handle single DataPoint in list + elif isinstance(item, DataPoint): + result.append((None, [item])) + return result + + # Handle tuple[Edge, DataPoint or list[DataPoint]] if ( - isinstance(field_value, list) - and len(field_value) > 0 - and isinstance(field_value[0], DataPoint) + isinstance(field_value, tuple) + and len(field_value) == 2 + and isinstance(field_value[0], Edge) ): - return "list_datapoint", field_value, None + edge_metadata, data_value = field_value + if isinstance(data_value, DataPoint): + return [(edge_metadata, [data_value])] + elif ( + isinstance(data_value, list) + and len(data_value) > 0 + and isinstance(data_value[0], DataPoint) + ): + return [(edge_metadata, data_value)] - # Regular property - return "property", field_value, None + # Regular property or empty list + return [] def _create_edge_properties( @@ -80,30 +87,49 @@ def _get_relationship_key(field_name: str, edge_metadata: Optional[Edge]) -> str def _generate_property_key(data_point_id: str, relationship_key: str, target_id: str) -> str: """Generate a unique property key for visited_properties tracking.""" - return f"{data_point_id}{relationship_key}{target_id}" + return f"{data_point_id}_{relationship_key}_{target_id}" def _process_datapoint_field( data_point_id: str, field_name: str, - datapoints: List[DataPoint], - edge_metadata: Optional[Edge], + edge_datapoint_pairs: List[Tuple[Optional[Edge], List[DataPoint]]], visited_properties: Dict[str, bool], properties_to_visit: set, excluded_properties: set, ) -> None: - """Process a field containing DataPoint(s), handling both single and list cases.""" + """Process a field containing DataPoints, always working with lists.""" excluded_properties.add(field_name) - relationship_key = _get_relationship_key(field_name, edge_metadata) - for index, datapoint in enumerate(datapoints): - property_key = _generate_property_key(data_point_id, relationship_key, str(datapoint.id)) - if property_key in visited_properties: + for edge_metadata, datapoints in edge_datapoint_pairs: + relationship_key = _get_relationship_key(field_name, edge_metadata) + + for datapoint in datapoints: + property_key = _generate_property_key( + data_point_id, relationship_key, str(datapoint.id) + ) + if property_key in visited_properties: + continue + + # Always use field_name since we're working with lists + properties_to_visit.add(field_name) + + +def _targets_generator( + data_point: DataPoint, + properties_to_visit: set, +) -> Tuple[DataPoint, str, Optional[Edge]]: + """Generator that yields (target_datapoint, field_name, edge_metadata) tuples.""" + for field_name in properties_to_visit: + field_value = getattr(data_point, field_name) + edge_datapoint_pairs = _extract_field_data(field_value) + + if not edge_datapoint_pairs: continue - # For single datapoint, use field_name; for list, use field_name.index - field_identifier = field_name if len(datapoints) == 1 else f"{field_name}.{index}" - properties_to_visit.add(field_identifier) + for edge_metadata, datapoints in edge_datapoint_pairs: + for target_datapoint in datapoints: + yield target_datapoint, field_name, edge_metadata async def get_graph_from_model( @@ -143,26 +169,17 @@ async def get_graph_from_model( if field_name == "metadata": continue - field_type, actual_value, edge_metadata = _extract_field_info(field_value) + edge_datapoint_pairs = _extract_field_data(field_value) - if field_type == "property": + if not edge_datapoint_pairs: + # Regular property data_point_properties[field_name] = field_value - elif field_type in ["single_datapoint", "single_datapoint_with_edge"]: + else: + # DataPoint relationship _process_datapoint_field( data_point_id, field_name, - [actual_value], - edge_metadata, - visited_properties, - properties_to_visit, - excluded_properties, - ) - elif field_type in ["list_datapoint", "list_datapoint_with_edge"]: - _process_datapoint_field( - data_point_id, - field_name, - actual_value, - edge_metadata, + edge_datapoint_pairs, visited_properties, properties_to_visit, excluded_properties, @@ -176,41 +193,15 @@ async def get_graph_from_model( nodes.append(SimpleDataPointModel(**data_point_properties)) added_nodes[data_point_id] = True - # Process all relationships - for field_name_with_index in properties_to_visit: - # Parse field name and index - if "." in field_name_with_index: - field_name, index_str = field_name_with_index.split(".") - index = int(index_str) - else: - field_name, index = field_name_with_index, None - - # Get field value and extract edge metadata - field_value = getattr(data_point, field_name) - edge_metadata = None - - if ( - isinstance(field_value, tuple) - and len(field_value) == 2 - and isinstance(field_value[0], Edge) - ): - edge_metadata, field_value = field_value - - # Get specific datapoint - handle both single and list cases - if index is not None: - # List case: extract specific item by index - target_datapoint = field_value[index] - elif isinstance(field_value, list): - # Single datapoint case that was wrapped in a list - target_datapoint = field_value[0] - else: - # True single datapoint case - target_datapoint = field_value + # Process all relationships using generator + for target_datapoint, field_name, edge_metadata in _targets_generator( + data_point, properties_to_visit + ): + relationship_name = _get_relationship_key(field_name, edge_metadata) # Create edge if not already added - edge_key = f"{data_point_id}{target_datapoint.id}{field_name}" + edge_key = f"{data_point_id}_{target_datapoint.id}_{field_name}" if edge_key not in added_edges: - relationship_name = _get_relationship_key(field_name, edge_metadata) edge_properties = _create_edge_properties( data_point.id, target_datapoint.id, relationship_name, edge_metadata ) @@ -218,23 +209,24 @@ async def get_graph_from_model( added_edges[edge_key] = True # Mark property as visited - CRITICAL for preventing infinite loops - relationship_key = _get_relationship_key(field_name, edge_metadata) property_key = _generate_property_key( - data_point_id, relationship_key, str(target_datapoint.id) + data_point_id, relationship_name, str(target_datapoint.id) ) visited_properties[property_key] = True # Recursively process target node if not already processed - if str(target_datapoint.id) not in added_nodes: - child_nodes, child_edges = await get_graph_from_model( - target_datapoint, - include_root=True, - added_nodes=added_nodes, - added_edges=added_edges, - visited_properties=visited_properties, - ) - nodes.extend(child_nodes) - edges.extend(child_edges) + if str(target_datapoint.id) in added_nodes: + continue + + child_nodes, child_edges = await get_graph_from_model( + target_datapoint, + include_root=True, + added_nodes=added_nodes, + added_edges=added_edges, + visited_properties=visited_properties, + ) + nodes.extend(child_nodes) + edges.extend(child_edges) return nodes, edges diff --git a/cognee/tests/unit/interfaces/graph/get_graph_from_model_unit_test.py b/cognee/tests/unit/interfaces/graph/get_graph_from_model_unit_test.py index 99bb66ccf..d8cfa4782 100644 --- a/cognee/tests/unit/interfaces/graph/get_graph_from_model_unit_test.py +++ b/cognee/tests/unit/interfaces/graph/get_graph_from_model_unit_test.py @@ -1,6 +1,6 @@ import pytest -from typing import List -from cognee.infrastructure.engine import DataPoint +from typing import List, Any +from cognee.infrastructure.engine import DataPoint, Edge from cognee.modules.graph.utils import get_graph_from_model @@ -28,7 +28,20 @@ class Entity(DataPoint): metadata: dict = {"index_fields": ["name"]} +class Company(DataPoint): + name: str + employees: List[Any] = None # Allow flexible edge system with tuples + metadata: dict = {"index_fields": ["name"]} + + +class Employee(DataPoint): + name: str + role: str + metadata: dict = {"index_fields": ["name"]} + + DocumentChunk.model_rebuild() +Company.model_rebuild() @pytest.mark.asyncio @@ -50,7 +63,7 @@ async def test_get_graph_from_model_simple_structure(): assert len(nodes) == 2, f"Expected 2 nodes, got {len(nodes)}" assert len(edges) == 1, f"Expected 1 edges, got {len(edges)}" - edge_key = str(entity.id) + str(entitytype.id) + "is_type" + edge_key = f"{str(entity.id)}_{str(entitytype.id)}_is_type" assert edge_key in added_edges, f"Edge {edge_key} not found" @@ -149,3 +162,48 @@ async def test_get_graph_from_model_no_contains(): assert len(nodes) == 2, f"Expected 2 nodes, got {len(nodes)}" assert len(edges) == 1, f"Expected 1 edge, got {len(edges)}" + + +@pytest.mark.asyncio +async def test_get_graph_from_model_flexible_edges(): + """Tests the new flexible edge system with mixed relationships""" + # Create employees + manager = Employee(name="Manager", role="Manager") + sales1 = Employee(name="Sales1", role="Sales") + sales2 = Employee(name="Sales2", role="Sales") + admin1 = Employee(name="Admin1", role="Admin") + admin2 = Employee(name="Admin2", role="Admin") + + # Create company with mixed employee relationships + company = Company( + name="Test Company", + employees=[ + # Weighted relationship + (Edge(weight=0.9, relationship_type="manages"), manager), + # Multiple weights relationship + ( + Edge(weights={"performance": 0.8, "experience": 0.7}, relationship_type="employs"), + sales1, + ), + # Simple relationship + sales2, + # Group relationship + (Edge(weights={"team_efficiency": 0.8}, relationship_type="employs"), [admin1, admin2]), + ], + ) + + added_nodes = {} + added_edges = {} + visited_properties = {} + + nodes, edges = await get_graph_from_model(company, added_nodes, added_edges, visited_properties) + + # Should have 6 nodes: company + 5 employees + assert len(nodes) == 6, f"Expected 6 nodes, got {len(nodes)}" + # Should have 5 edges: 4 employee relationships + assert len(edges) == 5, f"Expected 5 edges, got {len(edges)}" + + # Verify all employees are connected + employee_ids = {str(emp.id) for emp in [manager, sales1, sales2, admin1, admin2]} + edge_target_ids = {str(edge[1]) for edge in edges} + assert employee_ids.issubset(edge_target_ids), "Not all employees are connected" diff --git a/examples/python/dynamic_multiple_edges_example.py b/examples/python/dynamic_multiple_edges_example.py new file mode 100644 index 000000000..4ae538122 --- /dev/null +++ b/examples/python/dynamic_multiple_edges_example.py @@ -0,0 +1,115 @@ +import asyncio +from os import path +from typing import Any +from pydantic import SkipValidation +from cognee.api.v1.visualize.visualize import visualize_graph +from cognee.infrastructure.engine import DataPoint +from cognee.infrastructure.engine.models.Edge import Edge +from cognee.tasks.storage import add_data_points +import cognee + + +class Employee(DataPoint): + name: str + role: str + + +class Company(DataPoint): + name: str + industry: str + employs: SkipValidation[Any] # Mixed list: employees with/without weights + + +async def main(): + # Clear the database for a clean state + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + + # Create employees + michael = Employee(name="Michael", role="Regional Manager") + dwight = Employee(name="Dwight", role="Assistant to the Regional Manager") + jim = Employee(name="Jim", role="Sales Representative") + pam = Employee(name="Pam", role="Receptionist") + kevin = Employee(name="Kevin", role="Accountant") + angela = Employee(name="Angela", role="Senior Accountant") + oscar = Employee(name="Oscar", role="Accountant") + stanley = Employee(name="Stanley", role="Sales Representative") + phyllis = Employee(name="Phyllis", role="Sales Representative") + + # Create Dunder Mifflin with mixed employee relationships + dunder_mifflin = Company( + name="Dunder Mifflin Paper Company", + industry="Paper Sales", + employs=[ + # Manager with high authority weight + (Edge(weight=0.9, relationship_type="manager"), michael), + # Sales team with performance weights + ( + Edge(weights={"sales_performance": 0.8, "loyalty": 0.9}, relationship_type="sales"), + dwight, + ), + ( + Edge( + weights={"sales_performance": 0.7, "creativity": 0.8}, relationship_type="sales" + ), + jim, + ), + ( + Edge( + weights={"sales_performance": 0.6, "customer_service": 0.9}, + relationship_type="sales", + ), + phyllis, + ), + ( + Edge( + weights={"sales_performance": 0.5, "experience": 0.8}, relationship_type="sales" + ), + stanley, + ), + # Accounting department as a group + ( + Edge( + weights={"department_efficiency": 0.8, "team_cohesion": 0.9}, + relationship_type="accounting", + ), + [oscar, kevin, angela], + ), + # Admin staff without weights (simple relationships) + pam, + ], + ) + + all_data_points = [ + michael, + dwight, + jim, + pam, + kevin, + angela, + oscar, + stanley, + phyllis, + dunder_mifflin, + ] + + # Add data points to the graph + await add_data_points(all_data_points) + + # Visualize the graph + graph_visualization_path = path.join(path.dirname(__file__), "dunder_mifflin_graph.html") + await visualize_graph(graph_visualization_path) + + print("Dynamic multiple edges graph has been created and visualized!") + print(f"Visualization saved to: {graph_visualization_path}") + print("\nTechnical features demonstrated:") + print("- Mixed list support: weighted and unweighted relationships in single field") + print("- Single weight edges with relationship types") + print("- Multiple weight edges with custom metrics") + print("- Group relationships: single edge connecting multiple nodes") + print("- Simple relationships without edge metadata") + print("- Flexible edge extraction from heterogeneous data structures") + + +if __name__ == "__main__": + asyncio.run(main())