feat: genarlizing getting entities from triplets

This commit is contained in:
chinu0609 2025-11-03 00:59:04 +05:30
parent f1afd1f0a2
commit 5080e8f8a5
3 changed files with 17 additions and 9 deletions

View file

@ -5,3 +5,4 @@ from .retrieve_existing_edges import retrieve_existing_edges
from .convert_node_to_data_point import convert_node_to_data_point
from .deduplicate_nodes_and_edges import deduplicate_nodes_and_edges
from .resolve_edges_to_text import resolve_edges_to_text
from .get_entity_nodes_from_triplets import get_entity_nodes_from_triplets

View file

@ -0,0 +1,13 @@
def get_entity_nodes_from_triplets(triplets):
entity_nodes = []
seen_ids = set()
for triplet in triplets:
if hasattr(triplet, 'node1') and triplet.node1 and triplet.node1.id not in seen_ids:
entity_nodes.append({"id": str(triplet.node1.id)})
seen_ids.add(triplet.node1.id)
if hasattr(triplet, 'node2') and triplet.node2 and triplet.node2.id not in seen_ids:
entity_nodes.append({"id": str(triplet.node2.id)})
seen_ids.add(triplet.node2.id)
return entity_nodes

View file

@ -22,6 +22,7 @@ from cognee.modules.engine.models.node_set import NodeSet
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.context_global_variables import session_user
from cognee.infrastructure.databases.cache.config import CacheConfig
from cognee.modules.graph.utils import get_entity_nodes_from_triplets
logger = get_logger("GraphCompletionRetriever")
@ -139,15 +140,8 @@ class GraphCompletionRetriever(BaseGraphRetriever):
return []
# context = await self.resolve_edges_to_text(triplets)
entity_nodes = []
seen_ids = set()
for triplet in triplets:
if hasattr(triplet, 'node1') and triplet.node1 and triplet.node1.id not in seen_ids:
entity_nodes.append({"id": str(triplet.node1.id)})
seen_ids.add(triplet.node1.id)
if hasattr(triplet, 'node2') and triplet.node2 and triplet.node2.id not in seen_ids:
entity_nodes.append({"id": str(triplet.node2.id)})
seen_ids.add(triplet.node2.id)
entity_nodes = get_entity_nodes_from_triplets(triplets)
await update_node_access_timestamps(entity_nodes)
return triplets