diff --git a/cognee/modules/graph/utils/expand_with_nodes_and_edges.py b/cognee/modules/graph/utils/expand_with_nodes_and_edges.py index 9f0b0bd12..d3a0a8522 100644 --- a/cognee/modules/graph/utils/expand_with_nodes_and_edges.py +++ b/cognee/modules/graph/utils/expand_with_nodes_and_edges.py @@ -22,16 +22,16 @@ def _create_edge_key(source_id: str, target_id: str, relationship_name: str) -> def _process_ontology_nodes( - ontology_nodes: list, - data_chunk: DocumentChunk, - added_nodes_map: dict, - added_ontology_nodes_map: dict + ontology_nodes: list, + data_chunk: DocumentChunk, + added_nodes_map: dict, + added_ontology_nodes_map: dict, ) -> None: """Process and store ontology nodes""" for ontology_node in ontology_nodes: ont_node_id = generate_node_id(ontology_node.name) ont_node_name = generate_node_name(ontology_node.name) - + if ontology_node.category == "classes": ont_node_key = _create_node_key(ont_node_id, "type") if ont_node_key not in added_nodes_map and ont_node_key not in added_ontology_nodes_map: @@ -41,7 +41,7 @@ def _process_ontology_nodes( description=ont_node_name, ontology_valid=True, ) - + elif ontology_node.category == "individuals": ont_node_key = _create_node_key(ont_node_id, "entity") if ont_node_key not in added_nodes_map and ont_node_key not in added_ontology_nodes_map: @@ -55,9 +55,7 @@ def _process_ontology_nodes( def _process_ontology_edges( - ontology_edges: list, - existing_edges_map: dict, - ontology_relationships: list + ontology_edges: list, existing_edges_map: dict, ontology_relationships: list ) -> None: """Process ontology edges and add them if new""" for source, relation, target in ontology_edges: @@ -65,7 +63,7 @@ def _process_ontology_edges( target_node_id = generate_node_id(target) relationship_name = generate_edge_name(relation) edge_key = _create_edge_key(source_node_id, target_node_id, relationship_name) - + if edge_key not in existing_edges_map: ontology_relationships.append( ( @@ -84,41 +82,43 @@ def _process_ontology_edges( def _create_type_node( - node_type: str, - ontology_resolver: OntologyResolver, - added_nodes_map: dict, + node_type: str, + ontology_resolver: OntologyResolver, + added_nodes_map: dict, added_ontology_nodes_map: dict, - name_mapping: dict, - key_mapping: dict, - data_chunk: DocumentChunk, - existing_edges_map: dict, - ontology_relationships: list + name_mapping: dict, + key_mapping: dict, + data_chunk: DocumentChunk, + existing_edges_map: dict, + ontology_relationships: list, ) -> EntityType: """Create or retrieve a type node with ontology validation""" node_id = generate_node_id(node_type) node_name = generate_node_name(node_type) type_node_key = _create_node_key(node_id, "type") - + if type_node_key in added_nodes_map or type_node_key in key_mapping: - return added_nodes_map.get(type_node_key) or added_nodes_map.get(key_mapping.get(type_node_key)) - + return added_nodes_map.get(type_node_key) or added_nodes_map.get( + key_mapping.get(type_node_key) + ) + # Get ontology validation ontology_nodes, ontology_edges, closest_class = ontology_resolver.get_subgraph( node_name=node_name, node_type="classes" ) - + ontology_validated = bool(closest_class) - + if ontology_validated: old_key = type_node_key node_id = generate_node_id(closest_class.name) type_node_key = _create_node_key(node_id, "type") new_node_name = generate_node_name(closest_class.name) - + name_mapping[node_name] = closest_class.name key_mapping[old_key] = type_node_key node_name = new_node_name - + type_node = EntityType( id=node_id, name=node_name, @@ -126,55 +126,57 @@ def _create_type_node( description=node_name, ontology_valid=ontology_validated, ) - + added_nodes_map[type_node_key] = type_node - + # Process ontology nodes and edges _process_ontology_nodes(ontology_nodes, data_chunk, added_nodes_map, added_ontology_nodes_map) _process_ontology_edges(ontology_edges, existing_edges_map, ontology_relationships) - + return type_node def _create_entity_node( - node_id: str, - node_name: str, - node_description: str, - type_node: EntityType, - ontology_resolver: OntologyResolver, - added_nodes_map: dict, + node_id: str, + node_name: str, + node_description: str, + type_node: EntityType, + ontology_resolver: OntologyResolver, + added_nodes_map: dict, added_ontology_nodes_map: dict, - name_mapping: dict, - key_mapping: dict, - data_chunk: DocumentChunk, - existing_edges_map: dict, - ontology_relationships: list + name_mapping: dict, + key_mapping: dict, + data_chunk: DocumentChunk, + existing_edges_map: dict, + ontology_relationships: list, ) -> Entity: """Create or retrieve an entity node with ontology validation""" generated_node_id = generate_node_id(node_id) generated_node_name = generate_node_name(node_name) entity_node_key = _create_node_key(generated_node_id, "entity") - + if entity_node_key in added_nodes_map or entity_node_key in key_mapping: - return added_nodes_map.get(entity_node_key) or added_nodes_map.get(key_mapping.get(entity_node_key)) - + return added_nodes_map.get(entity_node_key) or added_nodes_map.get( + key_mapping.get(entity_node_key) + ) + # Get ontology validation ontology_nodes, ontology_edges, start_ent_ont = ontology_resolver.get_subgraph( node_name=generated_node_name, node_type="individuals" ) - + ontology_validated = bool(start_ent_ont) - + if ontology_validated: old_key = entity_node_key generated_node_id = generate_node_id(start_ent_ont.name) entity_node_key = _create_node_key(generated_node_id, "entity") new_node_name = generate_node_name(start_ent_ont.name) - + name_mapping[generated_node_name] = start_ent_ont.name key_mapping[old_key] = entity_node_key generated_node_name = new_node_name - + entity_node = Entity( id=generated_node_id, name=generated_node_name, @@ -183,42 +185,58 @@ def _create_entity_node( ontology_valid=ontology_validated, belongs_to_set=data_chunk.belongs_to_set, ) - + added_nodes_map[entity_node_key] = entity_node - + # Process ontology nodes and edges _process_ontology_nodes(ontology_nodes, data_chunk, added_nodes_map, added_ontology_nodes_map) _process_ontology_edges(ontology_edges, existing_edges_map, ontology_relationships) - + return entity_node def _process_graph_nodes( - data_chunk: DocumentChunk, - graph: KnowledgeGraph, - ontology_resolver: OntologyResolver, - added_nodes_map: dict, + data_chunk: DocumentChunk, + graph: KnowledgeGraph, + ontology_resolver: OntologyResolver, + added_nodes_map: dict, added_ontology_nodes_map: dict, - name_mapping: dict, - key_mapping: dict, - existing_edges_map: dict, - ontology_relationships: list + name_mapping: dict, + key_mapping: dict, + existing_edges_map: dict, + ontology_relationships: list, ) -> None: """Process nodes in a knowledge graph""" for node in graph.nodes: # Create type node type_node = _create_type_node( - node.type, ontology_resolver, added_nodes_map, added_ontology_nodes_map, - name_mapping, key_mapping, data_chunk, existing_edges_map, ontology_relationships + node.type, + ontology_resolver, + added_nodes_map, + added_ontology_nodes_map, + name_mapping, + key_mapping, + data_chunk, + existing_edges_map, + ontology_relationships, ) - + # Create entity node entity_node = _create_entity_node( - node.id, node.name, node.description, type_node, ontology_resolver, - added_nodes_map, added_ontology_nodes_map, name_mapping, key_mapping, - data_chunk, existing_edges_map, ontology_relationships + node.id, + node.name, + node.description, + type_node, + ontology_resolver, + added_nodes_map, + added_ontology_nodes_map, + name_mapping, + key_mapping, + data_chunk, + existing_edges_map, + ontology_relationships, ) - + # Add entity to data chunk if data_chunk.contains is None: data_chunk.contains = [] @@ -226,22 +244,19 @@ def _process_graph_nodes( def _process_graph_edges( - graph: KnowledgeGraph, - name_mapping: dict, - existing_edges_map: dict, - relationships: list + graph: KnowledgeGraph, name_mapping: dict, existing_edges_map: dict, relationships: list ) -> None: """Process edges in a knowledge graph""" for edge in graph.edges: # Apply name mapping if exists source_id = name_mapping.get(edge.source_node_id, edge.source_node_id) target_id = name_mapping.get(edge.target_node_id, edge.target_node_id) - + source_node_id = generate_node_id(source_id) target_node_id = generate_node_id(target_id) relationship_name = generate_edge_name(edge.relationship_name) edge_key = _create_edge_key(source_node_id, target_node_id, relationship_name) - + if edge_key not in existing_edges_map: relationships.append( ( @@ -270,33 +285,40 @@ def expand_with_nodes_and_edges( """ if existing_edges_map is None: existing_edges_map = {} - + if ontology_resolver is None: ontology_resolver = OntologyResolver() - + added_nodes_map = {} added_ontology_nodes_map = {} relationships = [] ontology_relationships = [] name_mapping = {} key_mapping = {} - + # Process each chunk and its corresponding graph for data_chunk, graph in zip(data_chunks, chunk_graphs): if not graph: continue - + # Process nodes first _process_graph_nodes( - data_chunk, graph, ontology_resolver, added_nodes_map, added_ontology_nodes_map, - name_mapping, key_mapping, existing_edges_map, ontology_relationships + data_chunk, + graph, + ontology_resolver, + added_nodes_map, + added_ontology_nodes_map, + name_mapping, + key_mapping, + existing_edges_map, + ontology_relationships, ) - + # Then process edges _process_graph_edges(graph, name_mapping, existing_edges_map, relationships) - + # Return combined results graph_nodes = list(added_ontology_nodes_map.values()) graph_edges = relationships + ontology_relationships - + return graph_nodes, graph_edges