refactor: add edge_type_id generation in add_edge instead of graph projection
This commit is contained in:
parent
01a6382552
commit
099d78ccfc
2 changed files with 26 additions and 11 deletions
|
|
@ -45,6 +45,12 @@ class CogneeGraph(CogneeAbstractGraph):
|
||||||
|
|
||||||
def add_edge(self, edge: Edge) -> None:
|
def add_edge(self, edge: Edge) -> None:
|
||||||
self.edges.append(edge)
|
self.edges.append(edge)
|
||||||
|
|
||||||
|
edge_text = edge.attributes.get("edge_text") or edge.attributes.get("relationship_type")
|
||||||
|
edge.attributes["edge_type_id"] = (
|
||||||
|
generate_edge_id(edge_id=edge_text) if edge_text else None
|
||||||
|
) # Update edge with generated edge_type_id
|
||||||
|
|
||||||
edge.node1.add_skeleton_edge(edge)
|
edge.node1.add_skeleton_edge(edge)
|
||||||
edge.node2.add_skeleton_edge(edge)
|
edge.node2.add_skeleton_edge(edge)
|
||||||
key = edge.get_distance_key()
|
key = edge.get_distance_key()
|
||||||
|
|
@ -206,10 +212,6 @@ class CogneeGraph(CogneeAbstractGraph):
|
||||||
key: properties.get(key) for key in edge_properties_to_project
|
key: properties.get(key) for key in edge_properties_to_project
|
||||||
}
|
}
|
||||||
edge_attributes["relationship_type"] = relationship_type
|
edge_attributes["relationship_type"] = relationship_type
|
||||||
edge_text = properties.get("edge_text") or properties.get("relationship_name")
|
|
||||||
edge_attributes["edge_type_id"] = (
|
|
||||||
generate_edge_id(edge_id=edge_text) if edge_text else None
|
|
||||||
)
|
|
||||||
|
|
||||||
edge = Edge(
|
edge = Edge(
|
||||||
source_node,
|
source_node,
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
import pytest
|
import pytest
|
||||||
from unittest.mock import AsyncMock
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
from cognee.modules.engine.utils.generate_edge_id import generate_edge_id
|
||||||
from cognee.modules.graph.exceptions import EntityNotFoundError, EntityAlreadyExistsError
|
from cognee.modules.graph.exceptions import EntityNotFoundError, EntityAlreadyExistsError
|
||||||
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
||||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge, Node
|
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge, Node
|
||||||
|
|
@ -379,7 +380,7 @@ async def test_map_vector_distances_to_graph_edges_with_payload(setup_graph):
|
||||||
graph.add_edge(edge)
|
graph.add_edge(edge)
|
||||||
|
|
||||||
edge_distances = [
|
edge_distances = [
|
||||||
MockScoredResult("e1", 0.92, payload={"text": "CONNECTS_TO"}),
|
MockScoredResult(generate_edge_id("CONNECTS_TO"), 0.92, payload={"text": "CONNECTS_TO"}),
|
||||||
]
|
]
|
||||||
|
|
||||||
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
|
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
|
||||||
|
|
@ -404,8 +405,9 @@ async def test_map_vector_distances_partial_edge_coverage(setup_graph):
|
||||||
graph.add_edge(edge1)
|
graph.add_edge(edge1)
|
||||||
graph.add_edge(edge2)
|
graph.add_edge(edge2)
|
||||||
|
|
||||||
|
edge_1_text = "CONNECTS_TO"
|
||||||
edge_distances = [
|
edge_distances = [
|
||||||
MockScoredResult("e1", 0.92, payload={"text": "CONNECTS_TO"}),
|
MockScoredResult(generate_edge_id(edge_1_text), 0.92, payload={"text": edge_1_text}),
|
||||||
]
|
]
|
||||||
|
|
||||||
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
|
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
|
||||||
|
|
@ -431,8 +433,9 @@ async def test_map_vector_distances_edges_fallback_to_relationship_type(setup_gr
|
||||||
)
|
)
|
||||||
graph.add_edge(edge)
|
graph.add_edge(edge)
|
||||||
|
|
||||||
|
edge_text = "KNOWS"
|
||||||
edge_distances = [
|
edge_distances = [
|
||||||
MockScoredResult("e1", 0.85, payload={"text": "KNOWS"}),
|
MockScoredResult(generate_edge_id(edge_text), 0.85, payload={"text": edge_text}),
|
||||||
]
|
]
|
||||||
|
|
||||||
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
|
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
|
||||||
|
|
@ -457,8 +460,9 @@ async def test_map_vector_distances_no_edge_matches(setup_graph):
|
||||||
)
|
)
|
||||||
graph.add_edge(edge)
|
graph.add_edge(edge)
|
||||||
|
|
||||||
|
edge_text = "SOME_OTHER_EDGE"
|
||||||
edge_distances = [
|
edge_distances = [
|
||||||
MockScoredResult("e1", 0.92, payload={"text": "SOME_OTHER_EDGE"}),
|
MockScoredResult(generate_edge_id(edge_text), 0.92, payload={"text": edge_text}),
|
||||||
]
|
]
|
||||||
|
|
||||||
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
|
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
|
||||||
|
|
@ -511,9 +515,15 @@ async def test_map_vector_distances_to_graph_edges_multi_query(setup_graph):
|
||||||
graph.add_edge(edge1)
|
graph.add_edge(edge1)
|
||||||
graph.add_edge(edge2)
|
graph.add_edge(edge2)
|
||||||
|
|
||||||
|
edge_1_text = "A"
|
||||||
|
edge_2_text = "B"
|
||||||
edge_distances = [
|
edge_distances = [
|
||||||
[MockScoredResult("e1", 0.1, payload={"text": "A"})], # query 0
|
[
|
||||||
[MockScoredResult("e2", 0.2, payload={"text": "B"})], # query 1
|
MockScoredResult(generate_edge_id(edge_1_text), 0.1, payload={"text": edge_1_text})
|
||||||
|
], # query 0
|
||||||
|
[
|
||||||
|
MockScoredResult(generate_edge_id(edge_2_text), 0.2, payload={"text": edge_2_text})
|
||||||
|
], # query 1
|
||||||
]
|
]
|
||||||
|
|
||||||
await graph.map_vector_distances_to_graph_edges(
|
await graph.map_vector_distances_to_graph_edges(
|
||||||
|
|
@ -541,8 +551,11 @@ async def test_map_vector_distances_to_graph_edges_preserves_unmapped_indices(se
|
||||||
graph.add_edge(edge1)
|
graph.add_edge(edge1)
|
||||||
graph.add_edge(edge2)
|
graph.add_edge(edge2)
|
||||||
|
|
||||||
|
edge_1_text = "A"
|
||||||
edge_distances = [
|
edge_distances = [
|
||||||
[MockScoredResult("e1", 0.1, payload={"text": "A"})], # query 0: only edge1 mapped
|
[
|
||||||
|
MockScoredResult(generate_edge_id(edge_1_text), 0.1, payload={"text": edge_1_text})
|
||||||
|
], # query 0: only edge1 mapped
|
||||||
[], # query 1: no edges mapped
|
[], # query 1: no edges mapped
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue