feat: dynamic multiple edges in datapoints (#1212)
<!-- .github/pull_request_template.md -->

## Description

<!-- Provide a clear description of the changes in this PR -->

- Improved list handling: removed the `.index` logic from `get_graph_from_model` and transitioned to fully datapoint-oriented processing.
- Streamlined datapoint iteration by introducing the `_targets_generator` generator, which uses nested loops over (edge, datapoints) pairs.
- Generalized field processing to handle mixed lists such as `[DataPoint, (Edge, DataPoint), (Edge, [DataPoint])]`, allowing multiple edges to be generated dynamically from a single field.
- Small improvements and refactorings.
- Added the `test_get_graph_from_model_flexible_edges()` test, covering weighted edges and dynamic multiple edges.
- Created `dynamic_multiple_edges_example.py`, demonstrating dynamic multiple edges.

## DCO Affirmation

I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.
This commit is contained in:
parent
c8202c51a7
commit
6dbd8e85a1
3 changed files with 269 additions and 104 deletions
|
|
@ -4,43 +4,50 @@ from cognee.infrastructure.engine import DataPoint, Edge
|
|||
from cognee.modules.storage.utils import copy_model
|
||||
|
||||
|
||||
def _extract_field_info(field_value: Any) -> Tuple[str, Any, Optional[Edge]]:
|
||||
"""Extract field type, actual value, and edge metadata from a field value."""
|
||||
|
||||
# Handle tuple[Edge, DataPoint]
|
||||
if (
|
||||
isinstance(field_value, tuple)
|
||||
and len(field_value) == 2
|
||||
and isinstance(field_value[0], Edge)
|
||||
and isinstance(field_value[1], DataPoint)
|
||||
):
|
||||
return "single_datapoint_with_edge", field_value[1], field_value[0]
|
||||
|
||||
# Handle tuple[Edge, list[DataPoint]]
|
||||
if (
|
||||
isinstance(field_value, tuple)
|
||||
and len(field_value) == 2
|
||||
and isinstance(field_value[0], Edge)
|
||||
and isinstance(field_value[1], list)
|
||||
and len(field_value[1]) > 0
|
||||
and isinstance(field_value[1][0], DataPoint)
|
||||
):
|
||||
return "list_datapoint_with_edge", field_value[1], field_value[0]
|
||||
|
||||
def _edge_datapoint_pair(item: Any) -> Optional[Tuple[Optional[Edge], List[DataPoint]]]:
    """Normalize one relationship entry into an ``(edge, datapoints)`` pair.

    Returns:
        ``(None, [item])`` for a bare DataPoint, ``(edge, datapoints)`` for an
        ``(Edge, DataPoint)`` or ``(Edge, [DataPoint, ...])`` tuple, and ``None``
        for anything else (plain values, empty groups, malformed tuples).
    """
    if isinstance(item, DataPoint):
        return (None, [item])

    if isinstance(item, tuple) and len(item) == 2 and isinstance(item[0], Edge):
        edge, target = item
        if isinstance(target, DataPoint):
            return (edge, [target])
        # Only the first element is inspected, so a group is assumed to be homogeneous.
        if isinstance(target, list) and len(target) > 0 and isinstance(target[0], DataPoint):
            return (edge, target)

    return None


def _extract_field_data(field_value: Any) -> List[Tuple[Optional[Edge], List[DataPoint]]]:
    """Extract edge metadata and datapoints from a field value.

    Accepted shapes:
        * a single DataPoint
        * a tuple ``(Edge, DataPoint)`` or ``(Edge, [DataPoint, ...])``
        * a list mixing any of the above, e.g.
          ``[DataPoint, (Edge, DataPoint), (Edge, [DataPoint, ...])]``

    Args:
        field_value: Raw attribute value read from a DataPoint model.

    Returns:
        A list of ``(edge_metadata, datapoints)`` pairs, where ``edge_metadata``
        is ``None`` for relationships declared without an Edge. An empty list
        means the value holds no DataPoint relationship (a regular property,
        an empty list, or a list without any recognizable entry).
    """
    # Handle list - could contain DataPoints, edge tuples, or mixed.
    # Entries that are neither are silently skipped.
    if isinstance(field_value, list):
        result = []
        for item in field_value:
            pair = _edge_datapoint_pair(item)
            if pair is not None:
                result.append(pair)
        return result

    # Handle single DataPoint or tuple[Edge, DataPoint or list[DataPoint]]
    pair = _edge_datapoint_pair(field_value)
    if pair is not None:
        return [pair]

    # Regular property
    return []
|
||||
|
||||
|
||||
def _create_edge_properties(
|
||||
|
|
@ -80,30 +87,49 @@ def _get_relationship_key(field_name: str, edge_metadata: Optional[Edge]) -> str
|
|||
|
||||
def _generate_property_key(data_point_id: str, relationship_key: str, target_id: str) -> str:
|
||||
"""Generate a unique property key for visited_properties tracking."""
|
||||
return f"{data_point_id}{relationship_key}{target_id}"
|
||||
return f"{data_point_id}_{relationship_key}_{target_id}"
|
||||
|
||||
|
||||
def _process_datapoint_field(
    data_point_id: str,
    field_name: str,
    edge_datapoint_pairs: List[Tuple[Optional[Edge], List[DataPoint]]],
    visited_properties: Dict[str, bool],
    properties_to_visit: set,
    excluded_properties: set,
) -> None:
    """Process a field containing DataPoints, always working with lists.

    Args:
        data_point_id: Identifier of the data point that owns the field.
        field_name: Name of the relationship field on that data point.
        edge_datapoint_pairs: ``(edge_metadata, datapoints)`` pairs extracted
            from the field value.
        visited_properties: Property keys that have already been traversed.
        properties_to_visit: Mutated in place; receives ``field_name`` when at
            least one target of the field has not been visited yet.
        excluded_properties: Mutated in place; always receives ``field_name``.
    """
    excluded_properties.add(field_name)

    for edge_metadata, datapoints in edge_datapoint_pairs:
        # The relationship key depends on the edge, so it is computed per pair.
        relationship_key = _get_relationship_key(field_name, edge_metadata)

        for datapoint in datapoints:
            property_key = _generate_property_key(
                data_point_id, relationship_key, str(datapoint.id)
            )
            if property_key in visited_properties:
                continue

            # Always use field_name since we're working with lists
            properties_to_visit.add(field_name)
|
||||
|
||||
|
||||
def _targets_generator(
    data_point: DataPoint,
    properties_to_visit: set,
) -> Tuple[DataPoint, str, Optional[Edge]]:
    """Generator that yields (target_datapoint, field_name, edge_metadata) tuples.

    Args:
        data_point: Data point whose relationship fields are traversed.
        properties_to_visit: Names of the fields on ``data_point`` to read.
            The set is only iterated here, never modified.

    Yields:
        One ``(target_datapoint, field_name, edge_metadata)`` tuple per target
        found in each field; ``edge_metadata`` is ``None`` for relationships
        declared without an Edge. Fields with no DataPoint content yield nothing.

    NOTE(review): the return annotation describes the yielded tuple; an
    ``Iterator[...]`` annotation would be more precise but needs a typing
    import that is not visible from here.
    """
    for field_name in properties_to_visit:
        field_value = getattr(data_point, field_name)

        # An empty extraction simply produces no iterations below.
        for edge_metadata, datapoints in _extract_field_data(field_value):
            for target_datapoint in datapoints:
                yield target_datapoint, field_name, edge_metadata
|
||||
|
||||
|
||||
async def get_graph_from_model(
|
||||
|
|
@ -143,26 +169,17 @@ async def get_graph_from_model(
|
|||
if field_name == "metadata":
|
||||
continue
|
||||
|
||||
field_type, actual_value, edge_metadata = _extract_field_info(field_value)
|
||||
edge_datapoint_pairs = _extract_field_data(field_value)
|
||||
|
||||
if field_type == "property":
|
||||
if not edge_datapoint_pairs:
|
||||
# Regular property
|
||||
data_point_properties[field_name] = field_value
|
||||
elif field_type in ["single_datapoint", "single_datapoint_with_edge"]:
|
||||
else:
|
||||
# DataPoint relationship
|
||||
_process_datapoint_field(
|
||||
data_point_id,
|
||||
field_name,
|
||||
[actual_value],
|
||||
edge_metadata,
|
||||
visited_properties,
|
||||
properties_to_visit,
|
||||
excluded_properties,
|
||||
)
|
||||
elif field_type in ["list_datapoint", "list_datapoint_with_edge"]:
|
||||
_process_datapoint_field(
|
||||
data_point_id,
|
||||
field_name,
|
||||
actual_value,
|
||||
edge_metadata,
|
||||
edge_datapoint_pairs,
|
||||
visited_properties,
|
||||
properties_to_visit,
|
||||
excluded_properties,
|
||||
|
|
@ -176,41 +193,15 @@ async def get_graph_from_model(
|
|||
nodes.append(SimpleDataPointModel(**data_point_properties))
|
||||
added_nodes[data_point_id] = True
|
||||
|
||||
# Process all relationships
|
||||
for field_name_with_index in properties_to_visit:
|
||||
# Parse field name and index
|
||||
if "." in field_name_with_index:
|
||||
field_name, index_str = field_name_with_index.split(".")
|
||||
index = int(index_str)
|
||||
else:
|
||||
field_name, index = field_name_with_index, None
|
||||
|
||||
# Get field value and extract edge metadata
|
||||
field_value = getattr(data_point, field_name)
|
||||
edge_metadata = None
|
||||
|
||||
if (
|
||||
isinstance(field_value, tuple)
|
||||
and len(field_value) == 2
|
||||
and isinstance(field_value[0], Edge)
|
||||
):
|
||||
edge_metadata, field_value = field_value
|
||||
|
||||
# Get specific datapoint - handle both single and list cases
|
||||
if index is not None:
|
||||
# List case: extract specific item by index
|
||||
target_datapoint = field_value[index]
|
||||
elif isinstance(field_value, list):
|
||||
# Single datapoint case that was wrapped in a list
|
||||
target_datapoint = field_value[0]
|
||||
else:
|
||||
# True single datapoint case
|
||||
target_datapoint = field_value
|
||||
# Process all relationships using generator
|
||||
for target_datapoint, field_name, edge_metadata in _targets_generator(
|
||||
data_point, properties_to_visit
|
||||
):
|
||||
relationship_name = _get_relationship_key(field_name, edge_metadata)
|
||||
|
||||
# Create edge if not already added
|
||||
edge_key = f"{data_point_id}{target_datapoint.id}{field_name}"
|
||||
edge_key = f"{data_point_id}_{target_datapoint.id}_{field_name}"
|
||||
if edge_key not in added_edges:
|
||||
relationship_name = _get_relationship_key(field_name, edge_metadata)
|
||||
edge_properties = _create_edge_properties(
|
||||
data_point.id, target_datapoint.id, relationship_name, edge_metadata
|
||||
)
|
||||
|
|
@ -218,23 +209,24 @@ async def get_graph_from_model(
|
|||
added_edges[edge_key] = True
|
||||
|
||||
# Mark property as visited - CRITICAL for preventing infinite loops
|
||||
relationship_key = _get_relationship_key(field_name, edge_metadata)
|
||||
property_key = _generate_property_key(
|
||||
data_point_id, relationship_key, str(target_datapoint.id)
|
||||
data_point_id, relationship_name, str(target_datapoint.id)
|
||||
)
|
||||
visited_properties[property_key] = True
|
||||
|
||||
# Recursively process target node if not already processed
|
||||
if str(target_datapoint.id) not in added_nodes:
|
||||
child_nodes, child_edges = await get_graph_from_model(
|
||||
target_datapoint,
|
||||
include_root=True,
|
||||
added_nodes=added_nodes,
|
||||
added_edges=added_edges,
|
||||
visited_properties=visited_properties,
|
||||
)
|
||||
nodes.extend(child_nodes)
|
||||
edges.extend(child_edges)
|
||||
if str(target_datapoint.id) in added_nodes:
|
||||
continue
|
||||
|
||||
child_nodes, child_edges = await get_graph_from_model(
|
||||
target_datapoint,
|
||||
include_root=True,
|
||||
added_nodes=added_nodes,
|
||||
added_edges=added_edges,
|
||||
visited_properties=visited_properties,
|
||||
)
|
||||
nodes.extend(child_nodes)
|
||||
edges.extend(child_edges)
|
||||
|
||||
return nodes, edges
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import pytest
|
||||
from typing import List
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from typing import List, Any
|
||||
from cognee.infrastructure.engine import DataPoint, Edge
|
||||
|
||||
from cognee.modules.graph.utils import get_graph_from_model
|
||||
|
||||
|
|
@ -28,7 +28,20 @@ class Entity(DataPoint):
|
|||
metadata: dict = {"index_fields": ["name"]}
|
||||
|
||||
|
||||
class Company(DataPoint):
    # Test model whose ``employees`` field exercises the flexible edge system.

    # Company name; listed under "index_fields" in the metadata below.
    name: str
    # Mixed list: bare DataPoints, (Edge, DataPoint) and (Edge, [DataPoint, ...]) tuples.
    employees: List[Any] = None  # Allow flexible edge system with tuples
    metadata: dict = {"index_fields": ["name"]}
|
||||
|
||||
|
||||
class Employee(DataPoint):
    # Test model used as the relationship target of Company.employees.

    # Employee name; listed under "index_fields" in the metadata below.
    name: str
    # Free-form role label (e.g. "Manager", "Sales", "Admin" in the tests).
    role: str
    metadata: dict = {"index_fields": ["name"]}
|
||||
|
||||
|
||||
DocumentChunk.model_rebuild()
|
||||
Company.model_rebuild()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -50,7 +63,7 @@ async def test_get_graph_from_model_simple_structure():
|
|||
assert len(nodes) == 2, f"Expected 2 nodes, got {len(nodes)}"
|
||||
assert len(edges) == 1, f"Expected 1 edges, got {len(edges)}"
|
||||
|
||||
edge_key = str(entity.id) + str(entitytype.id) + "is_type"
|
||||
edge_key = f"{str(entity.id)}_{str(entitytype.id)}_is_type"
|
||||
assert edge_key in added_edges, f"Edge {edge_key} not found"
|
||||
|
||||
|
||||
|
|
@ -149,3 +162,48 @@ async def test_get_graph_from_model_no_contains():
|
|||
|
||||
assert len(nodes) == 2, f"Expected 2 nodes, got {len(nodes)}"
|
||||
assert len(edges) == 1, f"Expected 1 edge, got {len(edges)}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_graph_from_model_flexible_edges():
    """Check graph extraction for one field mixing weighted, plain and grouped edges."""
    # One employee per relationship flavour exercised below.
    manager, sales1, sales2, admin1, admin2 = (
        Employee(name=name, role=role)
        for name, role in [
            ("Manager", "Manager"),
            ("Sales1", "Sales"),
            ("Sales2", "Sales"),
            ("Admin1", "Admin"),
            ("Admin2", "Admin"),
        ]
    )

    # Single-weight edge to one employee.
    weighted_manager = (Edge(weight=0.9, relationship_type="manages"), manager)
    # Multi-weight edge to one employee.
    multi_weight_sales = (
        Edge(weights={"performance": 0.8, "experience": 0.7}, relationship_type="employs"),
        sales1,
    )
    # One edge definition shared by a group of employees.
    admin_group = (
        Edge(weights={"team_efficiency": 0.8}, relationship_type="employs"),
        [admin1, admin2],
    )

    # sales2 is attached as a bare DataPoint, i.e. without edge metadata.
    company = Company(
        name="Test Company",
        employees=[weighted_manager, multi_weight_sales, sales2, admin_group],
    )

    added_nodes, added_edges, visited_properties = {}, {}, {}

    nodes, edges = await get_graph_from_model(company, added_nodes, added_edges, visited_properties)

    # Company node plus the five employees.
    assert len(nodes) == 6, f"Expected 6 nodes, got {len(nodes)}"
    # One edge per employee: manager, sales1, sales2, admin1, admin2.
    assert len(edges) == 5, f"Expected 5 edges, got {len(edges)}"

    # Every employee id must show up as the target of some edge.
    connected_ids = {str(edge[1]) for edge in edges}
    staff = (manager, sales1, sales2, admin1, admin2)
    missing_ids = {str(emp.id) for emp in staff} - connected_ids
    assert not missing_ids, "Not all employees are connected"
|
||||
|
|
|
|||
115
examples/python/dynamic_multiple_edges_example.py
Normal file
115
examples/python/dynamic_multiple_edges_example.py
Normal file
|
|
@ -0,0 +1,115 @@
|
|||
import asyncio
|
||||
from os import path
|
||||
from typing import Any
|
||||
from pydantic import SkipValidation
|
||||
from cognee.api.v1.visualize.visualize import visualize_graph
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.infrastructure.engine.models.Edge import Edge
|
||||
from cognee.tasks.storage import add_data_points
|
||||
import cognee
|
||||
|
||||
|
||||
class Employee(DataPoint):
    # Example node type: a person employed by a Company.

    # Display name of the employee.
    name: str
    # Job title, e.g. "Sales Representative".
    role: str
|
||||
|
||||
|
||||
class Company(DataPoint):
    # Example node type whose ``employs`` field carries heterogeneous relationships.

    # Company name.
    name: str
    # Industry label.
    industry: str
    # Validation is skipped so the field can hold bare Employees alongside
    # (Edge, Employee) and (Edge, [Employee, ...]) tuples.
    employs: SkipValidation[Any]  # Mixed list: employees with/without weights
|
||||
|
||||
|
||||
async def main():
    """Build the Dunder Mifflin example graph, store it, and render it to HTML.

    Demonstrates a single ``employs`` field holding weighted edges, multi-weight
    edges, a grouped edge shared by several employees, and a plain relationship.
    """

    def sales_edge(**weights):
        # Every sales relationship shares the same type and differs only in weights.
        return Edge(weights=weights, relationship_type="sales")

    # Clear the database for a clean state
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)

    # Create employees
    michael = Employee(name="Michael", role="Regional Manager")
    dwight = Employee(name="Dwight", role="Assistant to the Regional Manager")
    jim = Employee(name="Jim", role="Sales Representative")
    pam = Employee(name="Pam", role="Receptionist")
    kevin = Employee(name="Kevin", role="Accountant")
    angela = Employee(name="Angela", role="Senior Accountant")
    oscar = Employee(name="Oscar", role="Accountant")
    stanley = Employee(name="Stanley", role="Sales Representative")
    phyllis = Employee(name="Phyllis", role="Sales Representative")

    # Manager with high authority weight
    manager_link = (Edge(weight=0.9, relationship_type="manager"), michael)

    # Sales team with performance weights
    sales_links = [
        (sales_edge(sales_performance=0.8, loyalty=0.9), dwight),
        (sales_edge(sales_performance=0.7, creativity=0.8), jim),
        (sales_edge(sales_performance=0.6, customer_service=0.9), phyllis),
        (sales_edge(sales_performance=0.5, experience=0.8), stanley),
    ]

    # Accounting department as a group: one edge definition, several targets
    accounting_link = (
        Edge(
            weights={"department_efficiency": 0.8, "team_cohesion": 0.9},
            relationship_type="accounting",
        ),
        [oscar, kevin, angela],
    )

    # Create Dunder Mifflin with mixed employee relationships;
    # pam is admin staff without weights (simple relationship).
    dunder_mifflin = Company(
        name="Dunder Mifflin Paper Company",
        industry="Paper Sales",
        employs=[manager_link, *sales_links, accounting_link, pam],
    )

    all_data_points = [
        michael,
        dwight,
        jim,
        pam,
        kevin,
        angela,
        oscar,
        stanley,
        phyllis,
        dunder_mifflin,
    ]

    # Add data points to the graph
    await add_data_points(all_data_points)

    # Visualize the graph next to this script
    example_dir = path.dirname(__file__)
    graph_visualization_path = path.join(example_dir, "dunder_mifflin_graph.html")
    await visualize_graph(graph_visualization_path)

    print("Dynamic multiple edges graph has been created and visualized!")
    print(f"Visualization saved to: {graph_visualization_path}")
    print("\nTechnical features demonstrated:")
    demonstrated_features = (
        "- Mixed list support: weighted and unweighted relationships in single field",
        "- Single weight edges with relationship types",
        "- Multiple weight edges with custom metrics",
        "- Group relationships: single edge connecting multiple nodes",
        "- Simple relationships without edge metadata",
        "- Flexible edge extraction from heterogeneous data structures",
    )
    for feature in demonstrated_features:
        print(feature)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
Loading…
Add table
Reference in a new issue