diff --git a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py index b23bf8e00..50cc88605 100644 --- a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +++ b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py @@ -410,6 +410,38 @@ class Neo4jAdapter(GraphDBInterface): return await self.query(query, params) + def _flatten_edge_properties(self, properties: Dict[str, Any]) -> Dict[str, Any]: + """ + Flatten edge properties to handle nested dictionaries like weights. + + Neo4j doesn't support nested dictionaries as property values, so we need to + flatten the 'weights' dictionary into individual properties with prefixes. + + Args: + properties: Dictionary of edge properties that may contain nested dicts + + Returns: + Flattened properties dictionary suitable for Neo4j storage + """ + flattened = {} + + for key, value in properties.items(): + if key == "weights" and isinstance(value, dict): + # Flatten weights dictionary into individual properties + for weight_name, weight_value in value.items(): + flattened[f"weight_{weight_name}"] = weight_value + elif isinstance(value, dict): + # For other nested dictionaries, serialize as JSON string + flattened[f"{key}_json"] = json.dumps(value, cls=JSONEncoder) + elif isinstance(value, list): + # For lists, serialize as JSON string + flattened[f"{key}_json"] = json.dumps(value, cls=JSONEncoder) + else: + # Keep primitive types as-is + flattened[key] = value + + return flattened + @record_graph_changes @override_distributed(queued_add_edges) async def add_edges(self, edges: list[tuple[str, str, str, dict[str, Any]]]) -> None: @@ -448,11 +480,13 @@ class Neo4jAdapter(GraphDBInterface): "from_node": str(edge[0]), "to_node": str(edge[1]), "relationship_name": edge[2], - "properties": { - **(edge[3] if edge[3] else {}), - "source_node_id": str(edge[0]), - "target_node_id": str(edge[1]), - }, + "properties": self._flatten_edge_properties( + { + **(edge[3] if edge[3] else {}), + "source_node_id": str(edge[0]), + "target_node_id": str(edge[1]), + } + ), } for edge in edges ] diff --git a/examples/python/weighted_graph_visualization.html b/examples/python/weighted_graph_visualization.html index 2e7f67e31..89920a780 100644 --- a/examples/python/weighted_graph_visualization.html +++ b/examples/python/weighted_graph_visualization.html @@ -37,8 +37,8 @@