feat: dynamic multiple edges in datapoints (#1212)

<!-- .github/pull_request_template.md -->

## Description
<!-- Provide a clear description of the changes in this PR -->
- Improved list handling, removed `.index` logic from
`get_graph_from_model`, transitioned to fully datapoint-oriented
processing
- Streamlined datapoint iteration by introducing `_targets_generator`
with nested loops
- Generalized field processing to handle mixed lists: `[DataPoint,
(Edge, DataPoint), (Edge, [DataPoint])]`, allowing multiple edges to be
generated dynamically from a single field
- Small improvements and refactorings
- Added `test_get_graph_from_model_flexible_edges()`, a test covering
weighted edges and dynamic multiple edges
- Created `dynamic_multiple_edges_example.py` demonstrating dynamic
multiple edges

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
This commit is contained in:
lxobr 2025-08-07 14:50:45 +02:00 committed by GitHub
parent c8202c51a7
commit 6dbd8e85a1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 269 additions and 104 deletions

View file

@ -4,43 +4,50 @@ from cognee.infrastructure.engine import DataPoint, Edge
from cognee.modules.storage.utils import copy_model
def _extract_field_info(field_value: Any) -> Tuple[str, Any, Optional[Edge]]:
"""Extract field type, actual value, and edge metadata from a field value."""
# Handle tuple[Edge, DataPoint]
if (
isinstance(field_value, tuple)
and len(field_value) == 2
and isinstance(field_value[0], Edge)
and isinstance(field_value[1], DataPoint)
):
return "single_datapoint_with_edge", field_value[1], field_value[0]
# Handle tuple[Edge, list[DataPoint]]
if (
isinstance(field_value, tuple)
and len(field_value) == 2
and isinstance(field_value[0], Edge)
and isinstance(field_value[1], list)
and len(field_value[1]) > 0
and isinstance(field_value[1][0], DataPoint)
):
return "list_datapoint_with_edge", field_value[1], field_value[0]
def _extract_field_data(field_value: Any) -> List[Tuple[Optional[Edge], List[DataPoint]]]:
"""Extract edge metadata and datapoints from a field value."""
# Handle single DataPoint
if isinstance(field_value, DataPoint):
return "single_datapoint", field_value, None
return [(None, [field_value])]
# Handle list of DataPoints
# Handle list - could contain DataPoints, edge tuples, or mixed
if isinstance(field_value, list) and len(field_value) > 0:
result = []
for item in field_value:
# Handle tuple[Edge, DataPoint or list[DataPoint]]
if isinstance(item, tuple) and len(item) == 2 and isinstance(item[0], Edge):
edge, data_value = item
if isinstance(data_value, DataPoint):
result.append((edge, [data_value]))
elif (
isinstance(data_value, list)
and len(data_value) > 0
and isinstance(data_value[0], DataPoint)
):
result.append((edge, data_value))
# Handle single DataPoint in list
elif isinstance(item, DataPoint):
result.append((None, [item]))
return result
# Handle tuple[Edge, DataPoint or list[DataPoint]]
if (
isinstance(field_value, list)
and len(field_value) > 0
and isinstance(field_value[0], DataPoint)
isinstance(field_value, tuple)
and len(field_value) == 2
and isinstance(field_value[0], Edge)
):
return "list_datapoint", field_value, None
edge_metadata, data_value = field_value
if isinstance(data_value, DataPoint):
return [(edge_metadata, [data_value])]
elif (
isinstance(data_value, list)
and len(data_value) > 0
and isinstance(data_value[0], DataPoint)
):
return [(edge_metadata, data_value)]
# Regular property
return "property", field_value, None
# Regular property or empty list
return []
def _create_edge_properties(
@ -80,30 +87,49 @@ def _get_relationship_key(field_name: str, edge_metadata: Optional[Edge]) -> str
def _generate_property_key(data_point_id: str, relationship_key: str, target_id: str) -> str:
"""Generate a unique property key for visited_properties tracking."""
return f"{data_point_id}{relationship_key}{target_id}"
return f"{data_point_id}_{relationship_key}_{target_id}"
def _process_datapoint_field(
data_point_id: str,
field_name: str,
datapoints: List[DataPoint],
edge_metadata: Optional[Edge],
edge_datapoint_pairs: List[Tuple[Optional[Edge], List[DataPoint]]],
visited_properties: Dict[str, bool],
properties_to_visit: set,
excluded_properties: set,
) -> None:
"""Process a field containing DataPoint(s), handling both single and list cases."""
"""Process a field containing DataPoints, always working with lists."""
excluded_properties.add(field_name)
relationship_key = _get_relationship_key(field_name, edge_metadata)
for index, datapoint in enumerate(datapoints):
property_key = _generate_property_key(data_point_id, relationship_key, str(datapoint.id))
if property_key in visited_properties:
for edge_metadata, datapoints in edge_datapoint_pairs:
relationship_key = _get_relationship_key(field_name, edge_metadata)
for datapoint in datapoints:
property_key = _generate_property_key(
data_point_id, relationship_key, str(datapoint.id)
)
if property_key in visited_properties:
continue
# Always use field_name since we're working with lists
properties_to_visit.add(field_name)
def _targets_generator(
data_point: DataPoint,
properties_to_visit: set,
) -> Tuple[DataPoint, str, Optional[Edge]]:
"""Generator that yields (target_datapoint, field_name, edge_metadata) tuples."""
for field_name in properties_to_visit:
field_value = getattr(data_point, field_name)
edge_datapoint_pairs = _extract_field_data(field_value)
if not edge_datapoint_pairs:
continue
# For single datapoint, use field_name; for list, use field_name.index
field_identifier = field_name if len(datapoints) == 1 else f"{field_name}.{index}"
properties_to_visit.add(field_identifier)
for edge_metadata, datapoints in edge_datapoint_pairs:
for target_datapoint in datapoints:
yield target_datapoint, field_name, edge_metadata
async def get_graph_from_model(
@ -143,26 +169,17 @@ async def get_graph_from_model(
if field_name == "metadata":
continue
field_type, actual_value, edge_metadata = _extract_field_info(field_value)
edge_datapoint_pairs = _extract_field_data(field_value)
if field_type == "property":
if not edge_datapoint_pairs:
# Regular property
data_point_properties[field_name] = field_value
elif field_type in ["single_datapoint", "single_datapoint_with_edge"]:
else:
# DataPoint relationship
_process_datapoint_field(
data_point_id,
field_name,
[actual_value],
edge_metadata,
visited_properties,
properties_to_visit,
excluded_properties,
)
elif field_type in ["list_datapoint", "list_datapoint_with_edge"]:
_process_datapoint_field(
data_point_id,
field_name,
actual_value,
edge_metadata,
edge_datapoint_pairs,
visited_properties,
properties_to_visit,
excluded_properties,
@ -176,41 +193,15 @@ async def get_graph_from_model(
nodes.append(SimpleDataPointModel(**data_point_properties))
added_nodes[data_point_id] = True
# Process all relationships
for field_name_with_index in properties_to_visit:
# Parse field name and index
if "." in field_name_with_index:
field_name, index_str = field_name_with_index.split(".")
index = int(index_str)
else:
field_name, index = field_name_with_index, None
# Get field value and extract edge metadata
field_value = getattr(data_point, field_name)
edge_metadata = None
if (
isinstance(field_value, tuple)
and len(field_value) == 2
and isinstance(field_value[0], Edge)
):
edge_metadata, field_value = field_value
# Get specific datapoint - handle both single and list cases
if index is not None:
# List case: extract specific item by index
target_datapoint = field_value[index]
elif isinstance(field_value, list):
# Single datapoint case that was wrapped in a list
target_datapoint = field_value[0]
else:
# True single datapoint case
target_datapoint = field_value
# Process all relationships using generator
for target_datapoint, field_name, edge_metadata in _targets_generator(
data_point, properties_to_visit
):
relationship_name = _get_relationship_key(field_name, edge_metadata)
# Create edge if not already added
edge_key = f"{data_point_id}{target_datapoint.id}{field_name}"
edge_key = f"{data_point_id}_{target_datapoint.id}_{field_name}"
if edge_key not in added_edges:
relationship_name = _get_relationship_key(field_name, edge_metadata)
edge_properties = _create_edge_properties(
data_point.id, target_datapoint.id, relationship_name, edge_metadata
)
@ -218,23 +209,24 @@ async def get_graph_from_model(
added_edges[edge_key] = True
# Mark property as visited - CRITICAL for preventing infinite loops
relationship_key = _get_relationship_key(field_name, edge_metadata)
property_key = _generate_property_key(
data_point_id, relationship_key, str(target_datapoint.id)
data_point_id, relationship_name, str(target_datapoint.id)
)
visited_properties[property_key] = True
# Recursively process target node if not already processed
if str(target_datapoint.id) not in added_nodes:
child_nodes, child_edges = await get_graph_from_model(
target_datapoint,
include_root=True,
added_nodes=added_nodes,
added_edges=added_edges,
visited_properties=visited_properties,
)
nodes.extend(child_nodes)
edges.extend(child_edges)
if str(target_datapoint.id) in added_nodes:
continue
child_nodes, child_edges = await get_graph_from_model(
target_datapoint,
include_root=True,
added_nodes=added_nodes,
added_edges=added_edges,
visited_properties=visited_properties,
)
nodes.extend(child_nodes)
edges.extend(child_edges)
return nodes, edges

View file

@ -1,6 +1,6 @@
import pytest
from typing import List
from cognee.infrastructure.engine import DataPoint
from typing import List, Any
from cognee.infrastructure.engine import DataPoint, Edge
from cognee.modules.graph.utils import get_graph_from_model
@ -28,7 +28,20 @@ class Entity(DataPoint):
metadata: dict = {"index_fields": ["name"]}
class Company(DataPoint):
    # Test fixture model: a company node whose ``employees`` field carries
    # the relationships exercised by the flexible-edge test.
    name: str
    # Loosely typed on purpose so one list can mix bare DataPoints with
    # (Edge, DataPoint) and (Edge, [DataPoint, ...]) tuples.
    employees: List[Any] = None
    metadata: dict = {"index_fields": ["name"]}
class Employee(DataPoint):
    # Test fixture model: a leaf node used as the target of company edges.
    name: str
    role: str
    metadata: dict = {"index_fields": ["name"]}
DocumentChunk.model_rebuild()
Company.model_rebuild()
@pytest.mark.asyncio
@ -50,7 +63,7 @@ async def test_get_graph_from_model_simple_structure():
assert len(nodes) == 2, f"Expected 2 nodes, got {len(nodes)}"
assert len(edges) == 1, f"Expected 1 edges, got {len(edges)}"
edge_key = str(entity.id) + str(entitytype.id) + "is_type"
edge_key = f"{str(entity.id)}_{str(entitytype.id)}_is_type"
assert edge_key in added_edges, f"Edge {edge_key} not found"
@ -149,3 +162,48 @@ async def test_get_graph_from_model_no_contains():
assert len(nodes) == 2, f"Expected 2 nodes, got {len(nodes)}"
assert len(edges) == 1, f"Expected 1 edge, got {len(edges)}"
@pytest.mark.asyncio
async def test_get_graph_from_model_flexible_edges():
    """Extract a graph from one field that mixes plain, weighted and grouped edges."""
    # Five employees; each is attached to the company through one of the
    # entry shapes listed in ``employees`` below.
    manager = Employee(name="Manager", role="Manager")
    sales1 = Employee(name="Sales1", role="Sales")
    sales2 = Employee(name="Sales2", role="Sales")
    admin1 = Employee(name="Admin1", role="Admin")
    admin2 = Employee(name="Admin2", role="Admin")

    # Edge metadata for the three tuple-style entries.
    single_weight_edge = Edge(weight=0.9, relationship_type="manages")
    multi_weight_edge = Edge(
        weights={"performance": 0.8, "experience": 0.7}, relationship_type="employs"
    )
    group_edge = Edge(weights={"team_efficiency": 0.8}, relationship_type="employs")

    company = Company(
        name="Test Company",
        employees=[
            (single_weight_edge, manager),  # (Edge, DataPoint) with a single weight
            (multi_weight_edge, sales1),  # (Edge, DataPoint) with named weights
            sales2,  # bare DataPoint, no edge metadata
            (group_edge, [admin1, admin2]),  # (Edge, [DataPoint, ...]) group
        ],
    )

    # Tracking dicts are passed in explicitly so the call starts from empty state.
    added_nodes, added_edges, visited_properties = {}, {}, {}
    nodes, edges = await get_graph_from_model(company, added_nodes, added_edges, visited_properties)

    # The company itself plus its five employees.
    assert len(nodes) == 6, f"Expected 6 nodes, got {len(nodes)}"

    # One edge per employee: the four list entries produce five edges because
    # the grouped entry fans out to both admin1 and admin2.
    assert len(edges) == 5, f"Expected 5 edges, got {len(edges)}"

    # Every employee id must appear as the target (position 1) of some edge.
    staff = [manager, sales1, sales2, admin1, admin2]
    expected_target_ids = {str(member.id) for member in staff}
    actual_target_ids = {str(edge[1]) for edge in edges}
    assert expected_target_ids <= actual_target_ids, "Not all employees are connected"

View file

@ -0,0 +1,115 @@
import asyncio
from os import path
from typing import Any
from pydantic import SkipValidation
from cognee.api.v1.visualize.visualize import visualize_graph
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.engine.models.Edge import Edge
from cognee.tasks.storage import add_data_points
import cognee
class Employee(DataPoint):
    # Example node type: one person working at the company.
    name: str
    role: str
class Company(DataPoint):
    # Example node type: the employer that owns the ``employs`` relationships.
    name: str
    industry: str
    # Mixed list: employees with/without weights. Typed as SkipValidation[Any]
    # so bare Employee objects and (Edge, ...) tuples can share one list.
    employs: SkipValidation[Any]
async def main():
    """Build the Dunder Mifflin example graph, store it, and render it to HTML.

    The company's single ``employs`` field mixes three entry shapes: an
    ``(Edge, Employee)`` tuple, an ``(Edge, [Employee, ...])`` tuple, and a
    bare ``Employee``.
    """
    # Clear previously stored data and system metadata for a clean state.
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)

    # The people who end up connected to the company node.
    michael = Employee(name="Michael", role="Regional Manager")
    dwight = Employee(name="Dwight", role="Assistant to the Regional Manager")
    jim = Employee(name="Jim", role="Sales Representative")
    pam = Employee(name="Pam", role="Receptionist")
    kevin = Employee(name="Kevin", role="Accountant")
    angela = Employee(name="Angela", role="Senior Accountant")
    oscar = Employee(name="Oscar", role="Accountant")
    stanley = Employee(name="Stanley", role="Sales Representative")
    phyllis = Employee(name="Phyllis", role="Sales Representative")

    # Manager: one edge carrying a single authority weight.
    manager_link = (Edge(weight=0.9, relationship_type="manager"), michael)

    # Sales team: one edge per person, each with its own named weights.
    sales_metrics = [
        (dwight, {"sales_performance": 0.8, "loyalty": 0.9}),
        (jim, {"sales_performance": 0.7, "creativity": 0.8}),
        (phyllis, {"sales_performance": 0.6, "customer_service": 0.9}),
        (stanley, {"sales_performance": 0.5, "experience": 0.8}),
    ]
    sales_links = [
        (Edge(weights=metrics, relationship_type="sales"), seller)
        for seller, metrics in sales_metrics
    ]

    # Accounting: a single Edge object shared by a group of employees.
    accounting_link = (
        Edge(
            weights={"department_efficiency": 0.8, "team_cohesion": 0.9},
            relationship_type="accounting",
        ),
        [oscar, kevin, angela],
    )

    # Pam goes in bare: a simple relationship with no edge metadata.
    dunder_mifflin = Company(
        name="Dunder Mifflin Paper Company",
        industry="Paper Sales",
        employs=[manager_link, *sales_links, accounting_link, pam],
    )

    # Store every employee together with the company.
    all_data_points = [
        michael,
        dwight,
        jim,
        pam,
        kevin,
        angela,
        oscar,
        stanley,
        phyllis,
        dunder_mifflin,
    ]
    await add_data_points(all_data_points)

    # Write the visualization next to this script.
    script_dir = path.dirname(__file__)
    graph_visualization_path = path.join(script_dir, "dunder_mifflin_graph.html")
    await visualize_graph(graph_visualization_path)

    print("Dynamic multiple edges graph has been created and visualized!")
    print(f"Visualization saved to: {graph_visualization_path}")
    print("\nTechnical features demonstrated:")
    print("- Mixed list support: weighted and unweighted relationships in single field")
    print("- Single weight edges with relationship types")
    print("- Multiple weight edges with custom metrics")
    print("- Group relationships: single edge connecting multiple nodes")
    print("- Simple relationships without edge metadata")
    print("- Flexible edge extraction from heterogeneous data structures")
# Run the example only when this file is executed as a script, so importing
# it never prunes data or writes the visualization.
if __name__ == "__main__":
    asyncio.run(main())