feat: optimize repeated entity extraction (#1682)
## Description

- Added an `edge_text` field to edges that auto-fills from `relationship_type` if not provided.
- `contains` edges now store entity descriptions for better embedding.
- Updated and refactored indexing so that `edge_text` is embedded and exposed.
- Updated retrieval to use the new embeddings.
- Added a test verifying that `edge_text` exists in the graph with the correct format.

## Type of Change

- [ ] Bug fix (non-breaking change that fixes an issue)
- [x] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to change)
- [ ] Documentation update
- [x] Code refactoring
- [x] Performance improvement
- [ ] Other (please specify):

## Screenshots/Videos (if applicable)

## Pre-submission Checklist

- [x] **I have tested my changes thoroughly before submitting this PR**
- [x] **This PR contains minimal changes necessary to address the issue/feature**
- [x] My code follows the project's coding standards and style guidelines
- [x] I have added tests that prove my fix is effective or that my feature works
- [ ] I have added necessary documentation (if applicable)
- [ ] All new and existing tests pass
- [x] I have searched existing PRs to ensure this change hasn't been submitted already
- [ ] I have linked any relevant issues in the description
- [ ] My commits have clear and descriptive messages

## DCO Affirmation

I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.
parent 116579a8e1
commit 6223ecf05b
11 changed files with 236 additions and 162 deletions
@@ -1,4 +1,4 @@
-from pydantic import BaseModel
+from pydantic import BaseModel, field_validator
 from typing import Optional, Any, Dict

@@ -18,9 +18,21 @@ class Edge(BaseModel):

         # Mixed usage
         has_items: (Edge(weight=0.5, weights={"confidence": 0.9}), list[Item])

+        # With edge_text for rich embedding representation
+        contains: (Edge(relationship_type="contains", edge_text="relationship_name: contains; entity_description: Alice"), Entity)
     """

     weight: Optional[float] = None
     weights: Optional[Dict[str, float]] = None
     relationship_type: Optional[str] = None
     properties: Optional[Dict[str, Any]] = None
+    edge_text: Optional[str] = None
+
+    @field_validator("edge_text", mode="before")
+    @classmethod
+    def ensure_edge_text(cls, v, info):
+        """Auto-populate edge_text from relationship_type if not explicitly provided."""
+        if v is None and info.data.get("relationship_type"):
+            return info.data["relationship_type"]
+        return v
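For reviewers, a minimal usage sketch of the validator above (assuming pydantic v2 semantics as written; note that `mode="before"` field validators only run on values that are actually supplied, so the fallback fires when `edge_text` is passed as `None` rather than omitted entirely):

```python
from cognee.infrastructure.engine.models.Edge import Edge

# Explicit None: the "before" validator sees v=None and copies relationship_type.
auto = Edge(relationship_type="contains", edge_text=None)
assert auto.edge_text == "contains"

# Explicit edge_text: the supplied value wins and is what gets embedded later.
rich = Edge(
    relationship_type="contains",
    edge_text="relationship_name: contains; entity_name: Alice; entity_description: a person",
)
assert rich.edge_text.startswith("relationship_name: contains")
```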
@@ -1,6 +1,7 @@
 from typing import List, Union

 from cognee.infrastructure.engine import DataPoint
+from cognee.infrastructure.engine.models.Edge import Edge
 from cognee.modules.data.processing.document_types import Document
 from cognee.modules.engine.models import Entity
 from cognee.tasks.temporal_graph.models import Event

@@ -31,6 +32,6 @@ class DocumentChunk(DataPoint):
     chunk_index: int
     cut_type: str
     is_part_of: Document
-    contains: List[Union[Entity, Event]] = None
+    contains: List[Union[Entity, Event, tuple[Edge, Entity]]] = None

     metadata: dict = {"index_fields": ["text"]}
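The widened `contains` annotation accepts both bare entities and `(Edge, Entity)` pairs. A self-contained sketch of the shape (hypothetical stand-in classes, since constructing real `DocumentChunk`/`Entity` instances needs more context than the diff shows):

```python
from typing import List, Union

# Stand-ins for the cognee models referenced in the diff.
class Entity: ...
class Event: ...
class Edge: ...

ContainsItem = Union[Entity, Event, tuple[Edge, Entity]]

contains: List[ContainsItem] = [
    Entity(),            # plain entity, as before
    (Edge(), Entity()),  # new: entity paired with an Edge carrying edge_text
]
```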
@@ -171,8 +171,10 @@ class CogneeGraph(CogneeAbstractGraph):
         embedding_map = {result.payload["text"]: result.score for result in edge_distances}

         for edge in self.edges:
-            relationship_type = edge.attributes.get("relationship_type")
-            distance = embedding_map.get(relationship_type, None)
+            edge_key = edge.attributes.get("edge_text") or edge.attributes.get(
+                "relationship_type"
+            )
+            distance = embedding_map.get(edge_key, None)
            if distance is not None:
                edge.attributes["vector_distance"] = distance
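The scoring loop now keys the embedding-distance lookup by `edge_text` when present, falling back to `relationship_type`. A toy illustration of the fallback (made-up scores and texts, not real pipeline data):

```python
# Scores keyed by the exact text that was embedded (sample values).
embedding_map = {
    "relationship_name: contains; entity_name: Alice; entity_description: a person": 0.12,
    "works_at": 0.35,
}

enriched = {
    "relationship_type": "contains",
    "edge_text": "relationship_name: contains; entity_name: Alice; entity_description: a person",
}
plain = {"relationship_type": "works_at"}

for attributes in (enriched, plain):
    # Prefer the rich edge_text; fall back to the bare relationship type.
    edge_key = attributes.get("edge_text") or attributes.get("relationship_type")
    print(edge_key, "->", embedding_map.get(edge_key))
# The enriched edge resolves via its full text (0.12); the plain one via its name (0.35).
```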
@@ -1,5 +1,6 @@
 from typing import Optional

+from cognee.infrastructure.engine.models.Edge import Edge
 from cognee.modules.chunking.models import DocumentChunk
 from cognee.modules.engine.models import Entity, EntityType
 from cognee.modules.engine.utils import (

@@ -243,10 +244,26 @@ def _process_graph_nodes(
             ontology_relationships,
         )

         # Add entity to data chunk
         if data_chunk.contains is None:
             data_chunk.contains = []
-        data_chunk.contains.append(entity_node)
+
+        edge_text = "; ".join(
+            [
+                "relationship_name: contains",
+                f"entity_name: {entity_node.name}",
+                f"entity_description: {entity_node.description}",
+            ]
+        )
+
+        data_chunk.contains.append(
+            (
+                Edge(
+                    relationship_type="contains",
+                    edge_text=edge_text,
+                ),
+                entity_node,
+            )
+        )


 def _process_graph_edges(
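For a concrete feel of the `edge_text` format built above (sample entity values, not real pipeline output):

```python
entity_name = "Alice"
entity_description = "A person mentioned in the chunk"

# Same join as in _process_graph_nodes above.
edge_text = "; ".join(
    [
        "relationship_name: contains",
        f"entity_name: {entity_name}",
        f"entity_description: {entity_description}",
    ]
)
print(edge_text)
# relationship_name: contains; entity_name: Alice; entity_description: A person mentioned in the chunk
```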
@@ -1,71 +1,70 @@
+import string
 from typing import List
+from collections import Counter

 from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
+from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS
+
+
+def _get_top_n_frequent_words(
+    text: str, stop_words: set = None, top_n: int = 3, separator: str = ", "
+) -> str:
+    """Concatenates the top N frequent words in text."""
+    if stop_words is None:
+        stop_words = DEFAULT_STOP_WORDS
+
+    words = [word.lower().strip(string.punctuation) for word in text.split()]
+    words = [word for word in words if word and word not in stop_words]
+
+    top_words = [word for word, freq in Counter(words).most_common(top_n)]
+    return separator.join(top_words)
+
+
+def _create_title_from_text(text: str, first_n_words: int = 7, top_n_words: int = 3) -> str:
+    """Creates a title by combining first words with most frequent words from the text."""
+    first_words = text.split()[:first_n_words]
+    top_words = _get_top_n_frequent_words(text, top_n=top_n_words)
+    return f"{' '.join(first_words)}... [{top_words}]"
+
+
+def _extract_nodes_from_edges(retrieved_edges: List[Edge]) -> dict:
+    """Creates a dictionary of nodes with their names and content."""
+    nodes = {}
+
+    for edge in retrieved_edges:
+        for node in (edge.node1, edge.node2):
+            if node.id in nodes:
+                continue
+
+            text = node.attributes.get("text")
+            if text:
+                name = _create_title_from_text(text)
+                content = text
+            else:
+                name = node.attributes.get("name", "Unnamed Node")
+                content = node.attributes.get("description", name)
+
+            nodes[node.id] = {"node": node, "name": name, "content": content}
+
+    return nodes
+
+
 async def resolve_edges_to_text(retrieved_edges: List[Edge]) -> str:
-    """
-    Converts retrieved graph edges into a human-readable string format.
+    """Converts retrieved graph edges into a human-readable string format."""
+    nodes = _extract_nodes_from_edges(retrieved_edges)

-    Parameters:
-    -----------
-
-        - retrieved_edges (list): A list of edges retrieved from the graph.
-
-    Returns:
-    --------
-
-        - str: A formatted string representation of the nodes and their connections.
-    """
-
-    def _get_nodes(retrieved_edges: List[Edge]) -> dict:
-        def _get_title(text: str, first_n_words: int = 7, top_n_words: int = 3) -> str:
-            def _top_n_words(text, stop_words=None, top_n=3, separator=", "):
-                """Concatenates the top N frequent words in text."""
-                if stop_words is None:
-                    from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS
-
-                    stop_words = DEFAULT_STOP_WORDS
-
-                import string
-
-                words = [word.lower().strip(string.punctuation) for word in text.split()]
-
-                if stop_words:
-                    words = [word for word in words if word and word not in stop_words]
-
-                from collections import Counter
-
-                top_words = [word for word, freq in Counter(words).most_common(top_n)]
-
-                return separator.join(top_words)
-
-            """Creates a title, by combining first words with most frequent words from the text."""
-            first_words = text.split()[:first_n_words]
-            top_words = _top_n_words(text, top_n=first_n_words)
-            return f"{' '.join(first_words)}... [{top_words}]"
-
-        """Creates a dictionary of nodes with their names and content."""
-        nodes = {}
-        for edge in retrieved_edges:
-            for node in (edge.node1, edge.node2):
-                if node.id not in nodes:
-                    text = node.attributes.get("text")
-                    if text:
-                        name = _get_title(text)
-                        content = text
-                    else:
-                        name = node.attributes.get("name", "Unnamed Node")
-                        content = node.attributes.get("description", name)
-                    nodes[node.id] = {"node": node, "name": name, "content": content}
-        return nodes
-
-    nodes = _get_nodes(retrieved_edges)
     node_section = "\n".join(
         f"Node: {info['name']}\n__node_content_start__\n{info['content']}\n__node_content_end__\n"
         for info in nodes.values()
     )
-    connection_section = "\n".join(
-        f"{nodes[edge.node1.id]['name']} --[{edge.attributes['relationship_type']}]--> {nodes[edge.node2.id]['name']}"
-        for edge in retrieved_edges
-    )
+
+    connections = []
+    for edge in retrieved_edges:
+        source_name = nodes[edge.node1.id]["name"]
+        target_name = nodes[edge.node2.id]["name"]
+        edge_label = edge.attributes.get("edge_text") or edge.attributes.get("relationship_type")
+        connections.append(f"{source_name} --[{edge_label}]--> {target_name}")
+
+    connection_section = "\n".join(connections)
+
     return f"Nodes:\n{node_section}\n\nConnections:\n{connection_section}"
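One rendered connection line from the new loop, under the same `edge_text`-first fallback rule (sample node names and attributes, not real retrieval output):

```python
nodes = {
    "n1": {"name": "Doc chunk 1... [alice, person]"},
    "n2": {"name": "Alice"},
}

attributes = {
    "relationship_type": "contains",
    "edge_text": "relationship_name: contains; entity_name: Alice; entity_description: a person",
}

# Prefer the rich edge_text as the connection label.
edge_label = attributes.get("edge_text") or attributes.get("relationship_type")
print(f"{nodes['n1']['name']} --[{edge_label}]--> {nodes['n2']['name']}")
# Doc chunk 1... [alice, person] --[relationship_name: contains; ...]--> Alice
```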
@@ -71,7 +71,7 @@ async def get_memory_fragment(
     await memory_fragment.project_graph_from_db(
         graph_engine,
         node_properties_to_project=properties_to_project,
-        edge_properties_to_project=["relationship_name"],
+        edge_properties_to_project=["relationship_name", "edge_text"],
         node_type=node_type,
         node_name=node_name,
     )
@@ -8,47 +8,58 @@ logger = get_logger("index_data_points")


 async def index_data_points(data_points: list[DataPoint]):
-    created_indexes = {}
-    index_points = {}
+    """Index data points in the vector engine by creating embeddings for specified fields.
+
+    Process:
+    1. Groups data points into a nested dict: {type_name: {field_name: [points]}}
+    2. Creates vector indexes for each (type, field) combination on first encounter
+    3. Batches points per (type, field) and creates async indexing tasks
+    4. Executes all indexing tasks in parallel for efficient embedding generation
+
+    Args:
+        data_points: List of DataPoint objects to index. Each DataPoint's metadata must
+            contain an 'index_fields' list specifying which fields to embed.
+
+    Returns:
+        The original data_points list.
+    """
+    data_points_by_type = {}

     vector_engine = get_vector_engine()

     for data_point in data_points:
         data_point_type = type(data_point)
+        type_name = data_point_type.__name__

         for field_name in data_point.metadata["index_fields"]:
             if getattr(data_point, field_name, None) is None:
                 continue

-            index_name = f"{data_point_type.__name__}_{field_name}"
+            if type_name not in data_points_by_type:
+                data_points_by_type[type_name] = {}

-            if index_name not in created_indexes:
-                await vector_engine.create_vector_index(data_point_type.__name__, field_name)
-                created_indexes[index_name] = True
-
-            if index_name not in index_points:
-                index_points[index_name] = []
+            if field_name not in data_points_by_type[type_name]:
+                await vector_engine.create_vector_index(type_name, field_name)
+                data_points_by_type[type_name][field_name] = []

             indexed_data_point = data_point.model_copy()
             indexed_data_point.metadata["index_fields"] = [field_name]
-            index_points[index_name].append(indexed_data_point)
+            data_points_by_type[type_name][field_name].append(indexed_data_point)

-    tasks: list[asyncio.Task] = []
     batch_size = vector_engine.embedding_engine.get_batch_size()

-    for index_name_and_field, points in index_points.items():
-        first = index_name_and_field.index("_")
-        index_name = index_name_and_field[:first]
-        field_name = index_name_and_field[first + 1 :]
+    batches = (
+        (type_name, field_name, points[i : i + batch_size])
+        for type_name, fields in data_points_by_type.items()
+        for field_name, points in fields.items()
+        for i in range(0, len(points), batch_size)
+    )

-        # Create embedding requests per batch to run in parallel later
-        for i in range(0, len(points), batch_size):
-            batch = points[i : i + batch_size]
-            tasks.append(
-                asyncio.create_task(vector_engine.index_data_points(index_name, field_name, batch))
-            )
+    tasks = [
+        asyncio.create_task(vector_engine.index_data_points(type_name, field_name, batch_points))
+        for type_name, field_name, batch_points in batches
+    ]

     # Run all embedding requests in parallel
     await asyncio.gather(*tasks)

     return data_points
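A stand-alone walk-through of the new grouping/batching shape (toy strings in place of DataPoint objects):

```python
# Points are bucketed as {type_name: {field_name: [points]}}, then flattened
# into fixed-size batches; each batch becomes one parallel embedding task.
data_points_by_type = {
    "EdgeType": {"relationship_name": ["p1", "p2", "p3"]},
    "Entity": {"name": ["p4"], "description": ["p5"]},
}
batch_size = 2

batches = [
    (type_name, field_name, points[i : i + batch_size])
    for type_name, fields in data_points_by_type.items()
    for field_name, points in fields.items()
    for i in range(0, len(points), batch_size)
]
print(batches)
# [('EdgeType', 'relationship_name', ['p1', 'p2']),
#  ('EdgeType', 'relationship_name', ['p3']),
#  ('Entity', 'name', ['p4']),
#  ('Entity', 'description', ['p5'])]
```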
@@ -1,17 +1,44 @@
-import asyncio
-from collections import Counter
-from typing import Optional, Dict, Any, List, Tuple, Union
-
 from cognee.modules.engine.utils.generate_edge_id import generate_edge_id
 from cognee.shared.logging_utils import get_logger
+from collections import Counter
+from typing import Optional, Dict, Any, List, Tuple, Union
 from cognee.infrastructure.databases.vector import get_vector_engine
 from cognee.infrastructure.databases.graph import get_graph_engine
 from cognee.modules.graph.models.EdgeType import EdgeType
+from cognee.infrastructure.databases.graph.graph_db_interface import EdgeData
+from cognee.tasks.storage.index_data_points import index_data_points

 logger = get_logger()


+def _get_edge_text(item: dict) -> str:
+    """Extract edge text for embedding - prefers edge_text field with fallback."""
+    if "edge_text" in item:
+        return item["edge_text"]
+
+    if "relationship_name" in item:
+        return item["relationship_name"]
+
+    return ""
+
+
+def create_edge_type_datapoints(edges_data) -> list[EdgeType]:
+    """Transform raw edge data into EdgeType datapoints."""
+    edge_texts = [
+        _get_edge_text(item)
+        for edge in edges_data
+        for item in edge
+        if isinstance(item, dict) and "relationship_name" in item
+    ]
+
+    edge_types = Counter(edge_texts)
+
+    return [
+        EdgeType(id=generate_edge_id(edge_id=text), relationship_name=text, number_of_edges=count)
+        for text, count in edge_types.items()
+    ]
+
+
 async def index_graph_edges(
+    edges_data: Union[List[EdgeData], List[Tuple[str, str, str, Optional[Dict[str, Any]]]]] = None,
 ):

@@ -23,24 +50,17 @@ async def index_graph_edges(
     the `relationship_name` field.

     Steps:
-    1. Initialize the vector engine and graph engine.
-    2. Retrieve graph edge data and count relationship types (`relationship_name`).
-    3. Create vector indexes for `relationship_name` if they don't exist.
-    4. Transform the counted relationships into `EdgeType` objects.
-    5. Index the transformed data points in the vector engine.
+    1. Initialize the graph engine if needed and retrieve edge data.
+    2. Transform edge data into EdgeType datapoints.
+    3. Index the EdgeType datapoints using the standard indexing function.

     Raises:
-        RuntimeError: If initialization of the vector engine or graph engine fails.
+        RuntimeError: If initialization of the graph engine fails.

     Returns:
         None
     """
     try:
-        created_indexes = {}
-        index_points = {}
-
-        vector_engine = get_vector_engine()
-
+        if edges_data is None:
+            graph_engine = await get_graph_engine()
+            _, edges_data = await graph_engine.get_graph_data()

@@ -51,47 +71,7 @@ async def index_graph_edges(
         logger.error("Failed to initialize engines: %s", e)
         raise RuntimeError("Initialization error") from e

-    edge_types = Counter(
-        item.get("relationship_name")
-        for edge in edges_data
-        for item in edge
-        if isinstance(item, dict) and "relationship_name" in item
-    )
-
-    for text, count in edge_types.items():
-        edge = EdgeType(
-            id=generate_edge_id(edge_id=text), relationship_name=text, number_of_edges=count
-        )
-        data_point_type = type(edge)
-
-        for field_name in edge.metadata["index_fields"]:
-            index_name = f"{data_point_type.__name__}.{field_name}"
-
-            if index_name not in created_indexes:
-                await vector_engine.create_vector_index(data_point_type.__name__, field_name)
-                created_indexes[index_name] = True
-
-            if index_name not in index_points:
-                index_points[index_name] = []
-
-            indexed_data_point = edge.model_copy()
-            indexed_data_point.metadata["index_fields"] = [field_name]
-            index_points[index_name].append(indexed_data_point)
-
-    # Get maximum batch size for embedding model
-    batch_size = vector_engine.embedding_engine.get_batch_size()
-    tasks: list[asyncio.Task] = []
-
-    for index_name, indexable_points in index_points.items():
-        index_name, field_name = index_name.split(".")
-
-        # Create embedding tasks to run in parallel later
-        for start in range(0, len(indexable_points), batch_size):
-            batch = indexable_points[start : start + batch_size]
-
-            tasks.append(vector_engine.index_data_points(index_name, field_name, batch))
-
-    # Start all embedding tasks and wait for completion
-    await asyncio.gather(*tasks)
+    edge_type_datapoints = create_edge_type_datapoints(edges_data)
+    await index_data_points(edge_type_datapoints)

     return None
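How `create_edge_type_datapoints` counts distinct edge texts, shown with the helper inlined and sample `edges_data` rows (tuples of source, target, name, properties, mirroring the test fixtures below):

```python
from collections import Counter

def _get_edge_text(item: dict) -> str:
    """Same fallback as above: prefer edge_text, else relationship_name."""
    if "edge_text" in item:
        return item["edge_text"]
    if "relationship_name" in item:
        return item["relationship_name"]
    return ""

# Sample rows; only dict items carrying relationship_name are counted.
edges_data = [
    ("a", "b", "contains", {"relationship_name": "contains",
                            "edge_text": "relationship_name: contains; entity_name: Alice"}),
    ("a", "c", "contains", {"relationship_name": "contains",
                            "edge_text": "relationship_name: contains; entity_name: Bob"}),
    ("b", "c", "works_at", {"relationship_name": "works_at"}),
]

edge_texts = [
    _get_edge_text(item)
    for edge in edges_data
    for item in edge
    if isinstance(item, dict) and "relationship_name" in item
]
print(Counter(edge_texts))
# Each distinct text becomes one EdgeType datapoint with its occurrence count.
```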
@@ -52,6 +52,33 @@ async def test_edge_ingestion():

     edge_type_counts = Counter(edge_type[2] for edge_type in graph[1])

+    "Tests edge_text presence and format"
+    contains_edges = [edge for edge in graph[1] if edge[2] == "contains"]
+    assert len(contains_edges) > 0, "Expected at least one contains edge for edge_text verification"
+
+    edge_properties = contains_edges[0][3]
+    assert "edge_text" in edge_properties, "Expected edge_text in edge properties"
+
+    edge_text = edge_properties["edge_text"]
+    assert "relationship_name: contains" in edge_text, (
+        f"Expected 'relationship_name: contains' in edge_text, got: {edge_text}"
+    )
+    assert "entity_name:" in edge_text, f"Expected 'entity_name:' in edge_text, got: {edge_text}"
+    assert "entity_description:" in edge_text, (
+        f"Expected 'entity_description:' in edge_text, got: {edge_text}"
+    )
+
+    all_edge_texts = [
+        edge[3].get("edge_text", "") for edge in contains_edges if "edge_text" in edge[3]
+    ]
+    expected_entities = ["dave", "ana", "bob", "dexter", "apples", "cognee"]
+    found_entity = any(
+        any(entity in text.lower() for entity in expected_entities) for text in all_edge_texts
+    )
+    assert found_entity, (
+        f"Expected to find at least one entity name in edge_text: {all_edge_texts[:3]}"
+    )
+
     "Tests the presence of basic nested edges"
     for basic_nested_edge in basic_nested_edges:
         assert edge_type_counts.get(basic_nested_edge, 0) >= 1, (
@@ -0,0 +1,27 @@
+import pytest
+from unittest.mock import AsyncMock, patch, MagicMock
+from cognee.tasks.storage.index_data_points import index_data_points
+from cognee.infrastructure.engine import DataPoint
+
+
+class TestDataPoint(DataPoint):
+    name: str
+    metadata: dict = {"index_fields": ["name"]}
+
+
+@pytest.mark.asyncio
+async def test_index_data_points_calls_vector_engine():
+    """Test that index_data_points creates vector index and indexes data."""
+    data_points = [TestDataPoint(name="test1")]
+
+    mock_vector_engine = AsyncMock()
+    mock_vector_engine.embedding_engine.get_batch_size = MagicMock(return_value=100)
+
+    with patch.dict(
+        index_data_points.__globals__,
+        {"get_vector_engine": lambda: mock_vector_engine},
+    ):
+        await index_data_points(data_points)
+
+    assert mock_vector_engine.create_vector_index.await_count >= 1
+    assert mock_vector_engine.index_data_points.await_count >= 1
@@ -5,8 +5,7 @@ from cognee.tasks.storage.index_graph_edges import index_graph_edges

 @pytest.mark.asyncio
 async def test_index_graph_edges_success():
-    """Test that index_graph_edges uses the index datapoints and creates vector index."""
-    # Create the mocks for the graph and vector engines.
+    """Test that index_graph_edges retrieves edges and delegates to index_data_points."""
     mock_graph_engine = AsyncMock()
     mock_graph_engine.get_graph_data.return_value = (
         None,

@@ -15,26 +14,23 @@ async def test_index_graph_edges_success():
             [{"relationship_name": "rel2"}],
         ],
     )
-    mock_vector_engine = AsyncMock()
-    mock_vector_engine.embedding_engine.get_batch_size = MagicMock(return_value=100)
+    mock_index_data_points = AsyncMock()

-    # Patch the globals of the function so that when it does:
-    #   vector_engine = get_vector_engine()
-    #   graph_engine = await get_graph_engine()
-    # it uses the mocked versions.
     with patch.dict(
         index_graph_edges.__globals__,
         {
             "get_graph_engine": AsyncMock(return_value=mock_graph_engine),
-            "get_vector_engine": lambda: mock_vector_engine,
+            "index_data_points": mock_index_data_points,
         },
     ):
         await index_graph_edges()

-    # Assertions on the mock calls.
     mock_graph_engine.get_graph_data.assert_awaited_once()
-    assert mock_vector_engine.create_vector_index.await_count == 1
-    assert mock_vector_engine.index_data_points.await_count == 1
+    mock_index_data_points.assert_awaited_once()
+
+    call_args = mock_index_data_points.call_args[0][0]
+    assert len(call_args) == 2
+    assert all(hasattr(item, "relationship_name") for item in call_args)


 @pytest.mark.asyncio

@@ -42,20 +38,22 @@ async def test_index_graph_edges_no_relationships():
     """Test that index_graph_edges handles empty relationships correctly."""
     mock_graph_engine = AsyncMock()
     mock_graph_engine.get_graph_data.return_value = (None, [])
-    mock_vector_engine = AsyncMock()
+    mock_index_data_points = AsyncMock()

     with patch.dict(
         index_graph_edges.__globals__,
         {
             "get_graph_engine": AsyncMock(return_value=mock_graph_engine),
-            "get_vector_engine": lambda: mock_vector_engine,
+            "index_data_points": mock_index_data_points,
         },
     ):
         await index_graph_edges()

     mock_graph_engine.get_graph_data.assert_awaited_once()
-    mock_vector_engine.create_vector_index.assert_not_awaited()
-    mock_vector_engine.index_data_points.assert_not_awaited()
+    mock_index_data_points.assert_awaited_once()
+
+    call_args = mock_index_data_points.call_args[0][0]
+    assert len(call_args) == 0


 @pytest.mark.asyncio