diff --git a/cognee/modules/graph/utils/__init__.py b/cognee/modules/graph/utils/__init__.py index ebc648495..4c0b29d47 100644 --- a/cognee/modules/graph/utils/__init__.py +++ b/cognee/modules/graph/utils/__init__.py @@ -5,3 +5,4 @@ from .retrieve_existing_edges import retrieve_existing_edges from .convert_node_to_data_point import convert_node_to_data_point from .deduplicate_nodes_and_edges import deduplicate_nodes_and_edges from .resolve_edges_to_text import resolve_edges_to_text +from .get_entity_nodes_from_triplets import get_entity_nodes_from_triplets diff --git a/cognee/modules/graph/utils/get_entity_nodes_from_triplets.py b/cognee/modules/graph/utils/get_entity_nodes_from_triplets.py new file mode 100644 index 000000000..598a36854 --- /dev/null +++ b/cognee/modules/graph/utils/get_entity_nodes_from_triplets.py @@ -0,0 +1,13 @@ + +def get_entity_nodes_from_triplets(triplets): + entity_nodes = [] + seen_ids = set() + for triplet in triplets: + if hasattr(triplet, 'node1') and triplet.node1 and triplet.node1.id not in seen_ids: + entity_nodes.append({"id": str(triplet.node1.id)}) + seen_ids.add(triplet.node1.id) + if hasattr(triplet, 'node2') and triplet.node2 and triplet.node2.id not in seen_ids: + entity_nodes.append({"id": str(triplet.node2.id)}) + seen_ids.add(triplet.node2.id) + + return entity_nodes diff --git a/cognee/modules/retrieval/graph_completion_retriever.py b/cognee/modules/retrieval/graph_completion_retriever.py index ac7e45e3c..122cc943f 100644 --- a/cognee/modules/retrieval/graph_completion_retriever.py +++ b/cognee/modules/retrieval/graph_completion_retriever.py @@ -22,6 +22,7 @@ from cognee.modules.engine.models.node_set import NodeSet from cognee.infrastructure.databases.graph import get_graph_engine from cognee.context_global_variables import session_user from cognee.infrastructure.databases.cache.config import CacheConfig +from cognee.modules.graph.utils import get_entity_nodes_from_triplets logger = get_logger("GraphCompletionRetriever") @@ -139,15 +140,8 @@ class GraphCompletionRetriever(BaseGraphRetriever): return [] # context = await self.resolve_edges_to_text(triplets) - entity_nodes = [] - seen_ids = set() - for triplet in triplets: - if hasattr(triplet, 'node1') and triplet.node1 and triplet.node1.id not in seen_ids: - entity_nodes.append({"id": str(triplet.node1.id)}) - seen_ids.add(triplet.node1.id) - if hasattr(triplet, 'node2') and triplet.node2 and triplet.node2.id not in seen_ids: - entity_nodes.append({"id": str(triplet.node2.id)}) - seen_ids.add(triplet.node2.id) + + entity_nodes = get_entity_nodes_from_triplets(triplets) await update_node_access_timestamps(entity_nodes) return triplets