refactor: add edge_type_id generation in add_edge instead of graph projection
This commit is contained in:
parent
01a6382552
commit
099d78ccfc
2 changed files with 26 additions and 11 deletions
|
|
@ -45,6 +45,12 @@ class CogneeGraph(CogneeAbstractGraph):
|
|||
|
||||
def add_edge(self, edge: Edge) -> None:
|
||||
self.edges.append(edge)
|
||||
|
||||
edge_text = edge.attributes.get("edge_text") or edge.attributes.get("relationship_type")
|
||||
edge.attributes["edge_type_id"] = (
|
||||
generate_edge_id(edge_id=edge_text) if edge_text else None
|
||||
) # Update edge with generated edge_type_id
|
||||
|
||||
edge.node1.add_skeleton_edge(edge)
|
||||
edge.node2.add_skeleton_edge(edge)
|
||||
key = edge.get_distance_key()
|
||||
|
|
@ -206,10 +212,6 @@ class CogneeGraph(CogneeAbstractGraph):
|
|||
key: properties.get(key) for key in edge_properties_to_project
|
||||
}
|
||||
edge_attributes["relationship_type"] = relationship_type
|
||||
edge_text = properties.get("edge_text") or properties.get("relationship_name")
|
||||
edge_attributes["edge_type_id"] = (
|
||||
generate_edge_id(edge_id=edge_text) if edge_text else None
|
||||
)
|
||||
|
||||
edge = Edge(
|
||||
source_node,
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import pytest
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from cognee.modules.engine.utils.generate_edge_id import generate_edge_id
|
||||
from cognee.modules.graph.exceptions import EntityNotFoundError, EntityAlreadyExistsError
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge, Node
|
||||
|
|
@ -379,7 +380,7 @@ async def test_map_vector_distances_to_graph_edges_with_payload(setup_graph):
|
|||
graph.add_edge(edge)
|
||||
|
||||
edge_distances = [
|
||||
MockScoredResult("e1", 0.92, payload={"text": "CONNECTS_TO"}),
|
||||
MockScoredResult(generate_edge_id("CONNECTS_TO"), 0.92, payload={"text": "CONNECTS_TO"}),
|
||||
]
|
||||
|
||||
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
|
||||
|
|
@ -404,8 +405,9 @@ async def test_map_vector_distances_partial_edge_coverage(setup_graph):
|
|||
graph.add_edge(edge1)
|
||||
graph.add_edge(edge2)
|
||||
|
||||
edge_1_text = "CONNECTS_TO"
|
||||
edge_distances = [
|
||||
MockScoredResult("e1", 0.92, payload={"text": "CONNECTS_TO"}),
|
||||
MockScoredResult(generate_edge_id(edge_1_text), 0.92, payload={"text": edge_1_text}),
|
||||
]
|
||||
|
||||
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
|
||||
|
|
@ -431,8 +433,9 @@ async def test_map_vector_distances_edges_fallback_to_relationship_type(setup_gr
|
|||
)
|
||||
graph.add_edge(edge)
|
||||
|
||||
edge_text = "KNOWS"
|
||||
edge_distances = [
|
||||
MockScoredResult("e1", 0.85, payload={"text": "KNOWS"}),
|
||||
MockScoredResult(generate_edge_id(edge_text), 0.85, payload={"text": edge_text}),
|
||||
]
|
||||
|
||||
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
|
||||
|
|
@ -457,8 +460,9 @@ async def test_map_vector_distances_no_edge_matches(setup_graph):
|
|||
)
|
||||
graph.add_edge(edge)
|
||||
|
||||
edge_text = "SOME_OTHER_EDGE"
|
||||
edge_distances = [
|
||||
MockScoredResult("e1", 0.92, payload={"text": "SOME_OTHER_EDGE"}),
|
||||
MockScoredResult(generate_edge_id(edge_text), 0.92, payload={"text": edge_text}),
|
||||
]
|
||||
|
||||
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
|
||||
|
|
@ -511,9 +515,15 @@ async def test_map_vector_distances_to_graph_edges_multi_query(setup_graph):
|
|||
graph.add_edge(edge1)
|
||||
graph.add_edge(edge2)
|
||||
|
||||
edge_1_text = "A"
|
||||
edge_2_text = "B"
|
||||
edge_distances = [
|
||||
[MockScoredResult("e1", 0.1, payload={"text": "A"})], # query 0
|
||||
[MockScoredResult("e2", 0.2, payload={"text": "B"})], # query 1
|
||||
[
|
||||
MockScoredResult(generate_edge_id(edge_1_text), 0.1, payload={"text": edge_1_text})
|
||||
], # query 0
|
||||
[
|
||||
MockScoredResult(generate_edge_id(edge_2_text), 0.2, payload={"text": edge_2_text})
|
||||
], # query 1
|
||||
]
|
||||
|
||||
await graph.map_vector_distances_to_graph_edges(
|
||||
|
|
@ -541,8 +551,11 @@ async def test_map_vector_distances_to_graph_edges_preserves_unmapped_indices(se
|
|||
graph.add_edge(edge1)
|
||||
graph.add_edge(edge2)
|
||||
|
||||
edge_1_text = "A"
|
||||
edge_distances = [
|
||||
[MockScoredResult("e1", 0.1, payload={"text": "A"})], # query 0: only edge1 mapped
|
||||
[
|
||||
MockScoredResult(generate_edge_id(edge_1_text), 0.1, payload={"text": edge_1_text})
|
||||
], # query 0: only edge1 mapped
|
||||
[], # query 1: no edges mapped
|
||||
]
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue