feat: Adds edge centered payload and embedding structure during ingestion (#1853)
<!-- .github/pull_request_template.md -->
## Description
This pull request introduces edge‑centered payloads to the ingestion
process. Payloads are stored in the Triplet_text collection which is
compatible with the triplet_embedding memify pipeline.
Changes in This PR:
- Refactored custom edge handling, from now on they can be passed to the
add_data_points method so the ingestion is centralized and is happening
in one place.
- Added private methods to handle edge centered payload creation inside
the add_data_points.py
- Added unit tests to cover the new functionality
- Added integration tests
- Added e2e tests
Acceptance Criteria and Testing
Scenario 1:
-Set TRIPLET_EMBEDDING env var to True
-Run prune, add, cognify
-Verify the vector DB contains a non empty Triplet_text collection and
the number of triplets are matching with the number of edges in the
graph database
-Use the new triplet_completion search type and confirm it works
correctly.
Scenario 2:
- Set TRIPLET_EMBEDDING env var to False
-Run prune, add, cognify
-Verify the vector DB does not have the Triplet_text collection
-You should receive an error indicating that the Triplet_text is not
available
## Type of Change
<!-- Please check the relevant option -->
- [ ] Bug fix (non-breaking change that fixes an issue)
- [x] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to change)
- [ ] Documentation update
- [ ] Code refactoring
- [ ] Performance improvement
- [ ] Other (please specify):
## Screenshots/Videos (if applicable)
<!-- Add screenshots or videos to help explain your changes -->
## Pre-submission Checklist
<!-- Please check all boxes that apply before submitting your PR -->
- [x] **I have tested my changes thoroughly before submitting this PR**
- [x] **This PR contains minimal changes necessary to address the
issue/feature**
- [x] My code follows the project's coding standards and style
guidelines
- [x] I have added tests that prove my fix is effective or that my
feature works
- [x] I have added necessary documentation (if applicable)
- [x] All new and existing tests pass
- [x] I have searched existing PRs to ensure this change hasn't been
submitted already
- [x] I have linked any relevant issues in the description
- [x] My commits have clear and descriptive messages
## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit
* **New Features**
* Triplet embeddings supported—embeddings created from graph edges plus
connected node text
* Ability to supply custom edges when adding data points
* New configuration toggle to enable/disable triplet embedding
* **Tests**
* Added comprehensive unit and end-to-end tests for edge-centered
payloads and triplet embedding
* New CI job to run the edge-centered payload e2e test
* **Bug Fixes**
* Adjusted server start behavior to surface process output in parent
logs
<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
---------
Co-authored-by: Pavel Zorin <pazonec@yandex.ru>
This commit is contained in:
parent
49f7c5188c
commit
001fbe699e
9 changed files with 786 additions and 14 deletions
29
.github/workflows/e2e_tests.yml
vendored
29
.github/workflows/e2e_tests.yml
vendored
|
|
@ -412,6 +412,35 @@ jobs:
|
||||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||||
run: uv run python ./cognee/tests/test_feedback_enrichment.py
|
run: uv run python ./cognee/tests/test_feedback_enrichment.py
|
||||||
|
|
||||||
|
test-edge-centered-payload:
|
||||||
|
name: Test Cognify - Edge Centered Payload
|
||||||
|
runs-on: ubuntu-22.04
|
||||||
|
steps:
|
||||||
|
- name: Check out repository
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Cognee Setup
|
||||||
|
uses: ./.github/actions/cognee_setup
|
||||||
|
with:
|
||||||
|
python-version: '3.11.x'
|
||||||
|
|
||||||
|
- name: Dependencies already installed
|
||||||
|
run: echo "Dependencies already installed in setup"
|
||||||
|
|
||||||
|
- name: Run Edge Centered Payload Test
|
||||||
|
env:
|
||||||
|
ENV: 'dev'
|
||||||
|
TRIPLET_EMBEDDING: True
|
||||||
|
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||||
|
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||||
|
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||||
|
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||||
|
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||||
|
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||||
|
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||||
|
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||||
|
run: uv run python ./cognee/tests/test_edge_centered_payload.py
|
||||||
|
|
||||||
run_conversation_sessions_test_redis:
|
run_conversation_sessions_test_redis:
|
||||||
name: Conversation sessions test (Redis)
|
name: Conversation sessions test (Redis)
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ from pydantic import BaseModel
|
||||||
from typing import Union, Optional
|
from typing import Union, Optional
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
|
from cognee.modules.cognify.config import get_cognify_config
|
||||||
from cognee.modules.ontology.ontology_env_config import get_ontology_env_config
|
from cognee.modules.ontology.ontology_env_config import get_ontology_env_config
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
from cognee.shared.data_models import KnowledgeGraph
|
from cognee.shared.data_models import KnowledgeGraph
|
||||||
|
|
@ -272,6 +273,9 @@ async def get_default_tasks( # TODO: Find out a better way to do this (Boris's
|
||||||
if chunks_per_batch is None:
|
if chunks_per_batch is None:
|
||||||
chunks_per_batch = 100
|
chunks_per_batch = 100
|
||||||
|
|
||||||
|
cognify_config = get_cognify_config()
|
||||||
|
embed_triplets = cognify_config.triplet_embedding
|
||||||
|
|
||||||
default_tasks = [
|
default_tasks = [
|
||||||
Task(classify_documents),
|
Task(classify_documents),
|
||||||
Task(check_permissions_on_dataset, user=user, permissions=["write"]),
|
Task(check_permissions_on_dataset, user=user, permissions=["write"]),
|
||||||
|
|
@ -291,7 +295,11 @@ async def get_default_tasks( # TODO: Find out a better way to do this (Boris's
|
||||||
summarize_text,
|
summarize_text,
|
||||||
task_config={"batch_size": chunks_per_batch},
|
task_config={"batch_size": chunks_per_batch},
|
||||||
),
|
),
|
||||||
Task(add_data_points, task_config={"batch_size": chunks_per_batch}),
|
Task(
|
||||||
|
add_data_points,
|
||||||
|
embed_triplets=embed_triplets,
|
||||||
|
task_config={"batch_size": chunks_per_batch},
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
return default_tasks
|
return default_tasks
|
||||||
|
|
|
||||||
|
|
@ -8,12 +8,14 @@ import os
|
||||||
class CognifyConfig(BaseSettings):
|
class CognifyConfig(BaseSettings):
|
||||||
classification_model: object = DefaultContentPrediction
|
classification_model: object = DefaultContentPrediction
|
||||||
summarization_model: object = SummarizedContent
|
summarization_model: object = SummarizedContent
|
||||||
|
triplet_embedding: bool = False
|
||||||
model_config = SettingsConfigDict(env_file=".env", extra="allow")
|
model_config = SettingsConfigDict(env_file=".env", extra="allow")
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
def to_dict(self) -> dict:
|
||||||
return {
|
return {
|
||||||
"classification_model": self.classification_model,
|
"classification_model": self.classification_model,
|
||||||
"summarization_model": self.summarization_model,
|
"summarization_model": self.summarization_model,
|
||||||
|
"triplet_embedding": self.triplet_embedding,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,9 +2,7 @@ import asyncio
|
||||||
from typing import Type, List, Optional
|
from typing import Type, List, Optional
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
|
||||||
from cognee.modules.ontology.ontology_env_config import get_ontology_env_config
|
from cognee.modules.ontology.ontology_env_config import get_ontology_env_config
|
||||||
from cognee.tasks.storage import index_graph_edges
|
|
||||||
from cognee.tasks.storage.add_data_points import add_data_points
|
from cognee.tasks.storage.add_data_points import add_data_points
|
||||||
from cognee.modules.ontology.ontology_config import Config
|
from cognee.modules.ontology.ontology_config import Config
|
||||||
from cognee.modules.ontology.get_default_ontology_resolver import (
|
from cognee.modules.ontology.get_default_ontology_resolver import (
|
||||||
|
|
@ -25,6 +23,7 @@ from cognee.tasks.graph.exceptions import (
|
||||||
InvalidChunkGraphInputError,
|
InvalidChunkGraphInputError,
|
||||||
InvalidOntologyAdapterError,
|
InvalidOntologyAdapterError,
|
||||||
)
|
)
|
||||||
|
from cognee.modules.cognify.config import get_cognify_config
|
||||||
|
|
||||||
|
|
||||||
async def integrate_chunk_graphs(
|
async def integrate_chunk_graphs(
|
||||||
|
|
@ -67,8 +66,6 @@ async def integrate_chunk_graphs(
|
||||||
type(ontology_resolver).__name__ if ontology_resolver else "None"
|
type(ontology_resolver).__name__ if ontology_resolver else "None"
|
||||||
)
|
)
|
||||||
|
|
||||||
graph_engine = await get_graph_engine()
|
|
||||||
|
|
||||||
if graph_model is not KnowledgeGraph:
|
if graph_model is not KnowledgeGraph:
|
||||||
for chunk_index, chunk_graph in enumerate(chunk_graphs):
|
for chunk_index, chunk_graph in enumerate(chunk_graphs):
|
||||||
data_chunks[chunk_index].contains = chunk_graph
|
data_chunks[chunk_index].contains = chunk_graph
|
||||||
|
|
@ -84,12 +81,13 @@ async def integrate_chunk_graphs(
|
||||||
data_chunks, chunk_graphs, ontology_resolver, existing_edges_map
|
data_chunks, chunk_graphs, ontology_resolver, existing_edges_map
|
||||||
)
|
)
|
||||||
|
|
||||||
if len(graph_nodes) > 0:
|
cognify_config = get_cognify_config()
|
||||||
await add_data_points(graph_nodes)
|
embed_triplets = cognify_config.triplet_embedding
|
||||||
|
|
||||||
if len(graph_edges) > 0:
|
if len(graph_nodes) > 0:
|
||||||
await graph_engine.add_edges(graph_edges)
|
await add_data_points(
|
||||||
await index_graph_edges(graph_edges)
|
data_points=graph_nodes, custom_edges=graph_edges, embed_triplets=embed_triplets
|
||||||
|
)
|
||||||
|
|
||||||
return data_chunks
|
return data_chunks
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,16 +1,23 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import List
|
from typing import List, Dict, Optional
|
||||||
from cognee.infrastructure.engine import DataPoint
|
from cognee.infrastructure.engine import DataPoint
|
||||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
from cognee.modules.graph.utils import deduplicate_nodes_and_edges, get_graph_from_model
|
from cognee.modules.graph.utils import deduplicate_nodes_and_edges, get_graph_from_model
|
||||||
from .index_data_points import index_data_points
|
from .index_data_points import index_data_points
|
||||||
from .index_graph_edges import index_graph_edges
|
from .index_graph_edges import index_graph_edges
|
||||||
|
from cognee.modules.engine.models import Triplet
|
||||||
|
from cognee.shared.logging_utils import get_logger
|
||||||
from cognee.tasks.storage.exceptions import (
|
from cognee.tasks.storage.exceptions import (
|
||||||
InvalidDataPointsInAddDataPointsError,
|
InvalidDataPointsInAddDataPointsError,
|
||||||
)
|
)
|
||||||
|
from ...modules.engine.utils import generate_node_id
|
||||||
|
|
||||||
|
logger = get_logger("add_data_points")
|
||||||
|
|
||||||
|
|
||||||
async def add_data_points(data_points: List[DataPoint]) -> List[DataPoint]:
|
async def add_data_points(
|
||||||
|
data_points: List[DataPoint], custom_edges: Optional[List] = None, embed_triplets: bool = False
|
||||||
|
) -> List[DataPoint]:
|
||||||
"""
|
"""
|
||||||
Add a batch of data points to the graph database by extracting nodes and edges,
|
Add a batch of data points to the graph database by extracting nodes and edges,
|
||||||
deduplicating them, and indexing them for retrieval.
|
deduplicating them, and indexing them for retrieval.
|
||||||
|
|
@ -23,6 +30,10 @@ async def add_data_points(data_points: List[DataPoint]) -> List[DataPoint]:
|
||||||
Args:
|
Args:
|
||||||
data_points (List[DataPoint]):
|
data_points (List[DataPoint]):
|
||||||
A list of data points to process and insert into the graph.
|
A list of data points to process and insert into the graph.
|
||||||
|
custom_edges (List[tuple]): Custom edges between datapoints.
|
||||||
|
embed_triplets (bool):
|
||||||
|
If True, creates and indexes triplet embeddings from the graph structure.
|
||||||
|
Defaults to False.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[DataPoint]:
|
List[DataPoint]:
|
||||||
|
|
@ -34,6 +45,7 @@ async def add_data_points(data_points: List[DataPoint]) -> List[DataPoint]:
|
||||||
- Updates the node index via `index_data_points`.
|
- Updates the node index via `index_data_points`.
|
||||||
- Inserts nodes and edges into the graph engine.
|
- Inserts nodes and edges into the graph engine.
|
||||||
- Optionally updates the edge index via `index_graph_edges`.
|
- Optionally updates the edge index via `index_graph_edges`.
|
||||||
|
- Optionally creates and indexes triplet embeddings if embed_triplets is True.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not isinstance(data_points, list):
|
if not isinstance(data_points, list):
|
||||||
|
|
@ -74,4 +86,132 @@ async def add_data_points(data_points: List[DataPoint]) -> List[DataPoint]:
|
||||||
await graph_engine.add_edges(edges)
|
await graph_engine.add_edges(edges)
|
||||||
await index_graph_edges(edges)
|
await index_graph_edges(edges)
|
||||||
|
|
||||||
|
if isinstance(custom_edges, list) and custom_edges:
|
||||||
|
# This must be handled separately from datapoint edges, created a task in linear to dig deeper but (COG-3488)
|
||||||
|
await graph_engine.add_edges(custom_edges)
|
||||||
|
await index_graph_edges(custom_edges)
|
||||||
|
edges.extend(custom_edges)
|
||||||
|
|
||||||
|
if embed_triplets:
|
||||||
|
triplets = _create_triplets_from_graph(nodes, edges)
|
||||||
|
if triplets:
|
||||||
|
await index_data_points(triplets)
|
||||||
|
logger.info(f"Created and indexed {len(triplets)} triplets from graph structure")
|
||||||
|
|
||||||
return data_points
|
return data_points
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_embeddable_text_from_datapoint(data_point: DataPoint) -> str:
    """
    Build the embeddable text for a single DataPoint.

    Mirrors the field selection used by index_data_points: the point's
    metadata["index_fields"] names the attributes whose stringified values
    are joined (space separated) into one embeddable string.

    Parameters:
    -----------
    - data_point (DataPoint): The data point to read indexed fields from.

    Returns:
    --------
    - str: Space-joined values of the indexed fields, or an empty string when
      the point has no metadata, no index_fields, or only empty values.
    """
    if not data_point or not hasattr(data_point, "metadata"):
        return ""

    field_names = data_point.metadata.get("index_fields", [])
    if not field_names:
        return ""

    fragments = []
    for name in field_names:
        raw_value = getattr(data_point, name, None)
        if raw_value is None:
            continue
        cleaned = str(raw_value).strip()
        if cleaned:
            fragments.append(cleaned)

    # " ".join returns "" for an empty list, matching the documented fallback.
    return " ".join(fragments)
|
||||||
|
|
||||||
|
|
||||||
|
def _create_triplets_from_graph(nodes: List[DataPoint], edges: List[tuple]) -> List[Triplet]:
    """
    Create Triplet objects from graph nodes and edges.

    Processes graph edges together with their endpoint nodes to produce
    triplet datapoints carrying embeddable text, matching the format used by
    the triplet embeddings pipeline.

    Parameters:
    -----------
    - nodes (List[DataPoint]): Graph nodes extracted from data points.
    - edges (List[tuple]): Edge tuples shaped as
      (source_node_id, target_node_id, relationship_name, properties_dict).
      Note: all edges, including those from DocumentChunk.contains, are
      already extracted by get_graph_from_model and included in this list.

    Returns:
    --------
    - List[Triplet]: Deduplicated Triplet objects ready for indexing. Edges
      whose endpoints are missing from `nodes`, whose relationship name is
      None, or that yield no embeddable text are skipped (and the skip count
      is logged at debug level).
    """
    # Map node id -> node, keeping the first occurrence of each id.
    node_map: Dict[str, DataPoint] = {}
    for node in nodes:
        if hasattr(node, "id"):
            node_id = str(node.id)
            if node_id not in node_map:
                node_map[node_id] = node

    triplets = []
    skipped_count = 0
    seen_ids = set()

    for edge_tuple in edges:
        # Malformed edges (fewer than 4 members) cannot form a triplet.
        if len(edge_tuple) < 4:
            continue

        source_node_id, target_node_id, relationship_name, edge_properties = (
            edge_tuple[0],
            edge_tuple[1],
            edge_tuple[2],
            edge_tuple[3],
        )

        source_node = node_map.get(str(source_node_id))
        target_node = node_map.get(str(target_node_id))

        if not source_node or not target_node or relationship_name is None:
            skipped_count += 1
            continue

        source_node_text = _extract_embeddable_text_from_datapoint(source_node)
        target_node_text = _extract_embeddable_text_from_datapoint(target_node)

        # Prefer an explicit edge_text property; fall back to the relationship name.
        relationship_text = ""
        if isinstance(edge_properties, dict):
            edge_text = edge_properties.get("edge_text")
            if edge_text and isinstance(edge_text, str) and edge_text.strip():
                relationship_text = edge_text.strip()

        if not relationship_text and relationship_name:
            relationship_text = relationship_name

        # Nothing embeddable on the source side or the relationship: skip.
        if not source_node_text and not relationship_text and not relationship_name:
            skipped_count += 1
            continue

        # NOTE(review): separator spacing is asymmetric (" -› " vs "-›") —
        # kept byte-identical here; confirm the intended triplet text format.
        embeddable_text = f"{source_node_text} -› {relationship_text}-›{target_node_text}".strip()

        # Deterministic id so the same (source, relationship, target) always
        # maps to the same triplet, enabling dedup across batches.
        triplet_id = generate_node_id(str(source_node_id) + relationship_name + str(target_node_id))

        if triplet_id in seen_ids:
            continue
        seen_ids.add(triplet_id)

        triplets.append(
            Triplet(
                id=triplet_id,
                from_node_id=str(source_node_id),
                to_node_id=str(target_node_id),
                text=embeddable_text,
            )
        )

    # Fix: skipped_count was tracked but never surfaced anywhere.
    if skipped_count:
        logger.debug(
            "Skipped %d edge(s) while building triplets (missing endpoint nodes, "
            "missing relationship, or no embeddable text).",
            skipped_count,
        )

    return triplets
|
||||||
|
|
|
||||||
139
cognee/tests/integration/tasks/test_add_data_points.py
Normal file
139
cognee/tests/integration/tasks/test_add_data_points.py
Normal file
|
|
@ -0,0 +1,139 @@
|
||||||
|
import pathlib
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
|
||||||
|
import cognee
|
||||||
|
from cognee.low_level import setup
|
||||||
|
from cognee.infrastructure.engine import DataPoint
|
||||||
|
from cognee.tasks.storage.add_data_points import add_data_points
|
||||||
|
from cognee.tasks.storage.exceptions import InvalidDataPointsInAddDataPointsError
|
||||||
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
|
|
||||||
|
|
||||||
|
class Person(DataPoint):
|
||||||
|
name: str
|
||||||
|
age: int
|
||||||
|
metadata: dict = {"index_fields": ["name"]}
|
||||||
|
|
||||||
|
|
||||||
|
class Company(DataPoint):
|
||||||
|
name: str
|
||||||
|
industry: str
|
||||||
|
metadata: dict = {"index_fields": ["name", "industry"]}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
async def clean_test_environment():
    """Provision an isolated cognee storage area, yield to the test, then clean up best-effort."""
    repo_root = pathlib.Path(__file__).parent.parent.parent.parent

    cognee.config.system_root_directory(
        str(repo_root / ".cognee_system/test_add_data_points_integration")
    )
    cognee.config.data_root_directory(
        str(repo_root / ".data_storage/test_add_data_points_integration")
    )

    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    yield

    # Teardown is best effort: a failed cleanup must not fail the test itself.
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_add_data_points_comprehensive(clean_test_environment):
|
||||||
|
"""Comprehensive integration test for add_data_points functionality."""
|
||||||
|
|
||||||
|
person1 = Person(name="Alice", age=30)
|
||||||
|
person2 = Person(name="Bob", age=25)
|
||||||
|
result = await add_data_points([person1, person2])
|
||||||
|
|
||||||
|
assert result == [person1, person2]
|
||||||
|
assert len(result) == 2
|
||||||
|
|
||||||
|
graph_engine = await get_graph_engine()
|
||||||
|
nodes, edges = await graph_engine.get_graph_data()
|
||||||
|
assert len(nodes) >= 2
|
||||||
|
|
||||||
|
result_empty = await add_data_points([])
|
||||||
|
assert result_empty == []
|
||||||
|
|
||||||
|
person3 = Person(name="Charlie", age=35)
|
||||||
|
person4 = Person(name="Diana", age=32)
|
||||||
|
custom_edge = (str(person3.id), str(person4.id), "knows", {"edge_text": "friends with"})
|
||||||
|
|
||||||
|
result_custom = await add_data_points([person3, person4], custom_edges=[custom_edge])
|
||||||
|
assert len(result_custom) == 2
|
||||||
|
|
||||||
|
nodes, edges = await graph_engine.get_graph_data()
|
||||||
|
assert len(edges) == 1
|
||||||
|
assert len(nodes) == 4
|
||||||
|
|
||||||
|
class Employee(DataPoint):
|
||||||
|
name: str
|
||||||
|
works_at: Company
|
||||||
|
metadata: dict = {"index_fields": ["name"]}
|
||||||
|
|
||||||
|
company = Company(name="TechCorp", industry="Technology")
|
||||||
|
employee = Employee(name="Eve", works_at=company)
|
||||||
|
|
||||||
|
result_rel = await add_data_points([employee])
|
||||||
|
assert len(result_rel) == 1
|
||||||
|
|
||||||
|
nodes, edges = await graph_engine.get_graph_data()
|
||||||
|
assert len(nodes) == 6
|
||||||
|
assert len(edges) == 2
|
||||||
|
|
||||||
|
person5 = Person(name="Frank", age=40)
|
||||||
|
person6 = Person(name="Grace", age=38)
|
||||||
|
triplet_edge = (str(person5.id), str(person6.id), "married_to", {"edge_text": "is married to"})
|
||||||
|
|
||||||
|
result_triplet = await add_data_points(
|
||||||
|
[person5, person6], custom_edges=[triplet_edge], embed_triplets=True
|
||||||
|
)
|
||||||
|
assert len(result_triplet) == 2
|
||||||
|
|
||||||
|
nodes, edges = await graph_engine.get_graph_data()
|
||||||
|
assert len(nodes) == 8
|
||||||
|
assert len(edges) == 3
|
||||||
|
|
||||||
|
batch1 = [Person(name="Leo", age=25), Person(name="Mia", age=30)]
|
||||||
|
batch2 = [Person(name="Noah", age=35), Person(name="Olivia", age=40)]
|
||||||
|
|
||||||
|
result_batch1 = await add_data_points(batch1)
|
||||||
|
result_batch2 = await add_data_points(batch2)
|
||||||
|
|
||||||
|
assert len(result_batch1) == 2
|
||||||
|
assert len(result_batch2) == 2
|
||||||
|
|
||||||
|
nodes, edges = await graph_engine.get_graph_data()
|
||||||
|
assert len(nodes) == 12
|
||||||
|
assert len(edges) == 3
|
||||||
|
|
||||||
|
person7 = Person(name="Paul", age=33)
|
||||||
|
person8 = Person(name="Quinn", age=31)
|
||||||
|
edge1 = (str(person7.id), str(person8.id), "colleague_of", {"edge_text": "works with"})
|
||||||
|
edge2 = (str(person8.id), str(person7.id), "colleague_of", {"edge_text": "works with"})
|
||||||
|
|
||||||
|
result_bi = await add_data_points([person7, person8], custom_edges=[edge1, edge2])
|
||||||
|
assert len(result_bi) == 2
|
||||||
|
|
||||||
|
nodes, edges = await graph_engine.get_graph_data()
|
||||||
|
assert len(nodes) == 14
|
||||||
|
assert len(edges) == 5
|
||||||
|
|
||||||
|
person_invalid = Person(name="Invalid", age=50)
|
||||||
|
with pytest.raises(InvalidDataPointsInAddDataPointsError, match="must be a list"):
|
||||||
|
await add_data_points(person_invalid)
|
||||||
|
|
||||||
|
with pytest.raises(InvalidDataPointsInAddDataPointsError, match="must be a DataPoint"):
|
||||||
|
await add_data_points(["not", "datapoints"])
|
||||||
|
|
||||||
|
final_nodes, final_edges = await graph_engine.get_graph_data()
|
||||||
|
assert len(final_nodes) == 14
|
||||||
|
assert len(final_edges) == 5
|
||||||
|
|
@ -25,8 +25,6 @@ class TestCogneeServerStart(unittest.TestCase):
|
||||||
"--port",
|
"--port",
|
||||||
"8000",
|
"8000",
|
||||||
],
|
],
|
||||||
stdout=subprocess.PIPE,
|
|
||||||
stderr=subprocess.PIPE,
|
|
||||||
preexec_fn=os.setsid,
|
preexec_fn=os.setsid,
|
||||||
)
|
)
|
||||||
# Give the server some time to start
|
# Give the server some time to start
|
||||||
|
|
|
||||||
170
cognee/tests/test_edge_centered_payload.py
Normal file
170
cognee/tests/test_edge_centered_payload.py
Normal file
|
|
@ -0,0 +1,170 @@
|
||||||
|
"""
|
||||||
|
End-to-end integration test for edge-centered payload and triplet embeddings.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import pathlib
|
||||||
|
import cognee
|
||||||
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
|
from cognee.modules.search.types import SearchType
|
||||||
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
from cognee.modules.ontology.rdf_xml.RDFLibOntologyResolver import RDFLibOntologyResolver
|
||||||
|
from cognee.modules.ontology.ontology_config import Config
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
text_data = """
|
||||||
|
Apple is a technology company that produces the iPhone, iPad, and Mac computers.
|
||||||
|
The company is known for its innovative products and ecosystem integration.
|
||||||
|
|
||||||
|
Microsoft develops the Windows operating system and Office productivity suite.
|
||||||
|
They are also major players in cloud computing with Azure.
|
||||||
|
|
||||||
|
Google created the Android operating system and provides search engine services.
|
||||||
|
The company is a leader in artificial intelligence and machine learning.
|
||||||
|
"""
|
||||||
|
|
||||||
|
ontology_content = """<?xml version="1.0"?>
|
||||||
|
<rdf:RDF xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"
|
||||||
|
xmlns:owl="http://www.w3.org/2002/07/owl#"
|
||||||
|
xmlns:rdfs="http://www.w3.org/2000/01/rdf-schema#"
|
||||||
|
xmlns="http://example.org/tech#"
|
||||||
|
xml:base="http://example.org/tech">
|
||||||
|
|
||||||
|
<owl:Ontology rdf:about="http://example.org/tech"/>
|
||||||
|
|
||||||
|
<!-- Classes -->
|
||||||
|
<owl:Class rdf:ID="Company"/>
|
||||||
|
<owl:Class rdf:ID="TechnologyCompany"/>
|
||||||
|
<owl:Class rdf:ID="Product"/>
|
||||||
|
<owl:Class rdf:ID="Software"/>
|
||||||
|
<owl:Class rdf:ID="Hardware"/>
|
||||||
|
<owl:Class rdf:ID="Service"/>
|
||||||
|
|
||||||
|
<rdf:Description rdf:about="#TechnologyCompany">
|
||||||
|
<rdfs:subClassOf rdf:resource="#Company"/>
|
||||||
|
<rdfs:comment>A company operating in the technology sector.</rdfs:comment>
|
||||||
|
</rdf:Description>
|
||||||
|
|
||||||
|
<rdf:Description rdf:about="#Software">
|
||||||
|
<rdfs:subClassOf rdf:resource="#Product"/>
|
||||||
|
<rdfs:comment>Software products and applications.</rdfs:comment>
|
||||||
|
</rdf:Description>
|
||||||
|
|
||||||
|
<rdf:Description rdf:about="#Hardware">
|
||||||
|
<rdfs:subClassOf rdf:resource="#Product"/>
|
||||||
|
<rdfs:comment>Physical hardware products.</rdfs:comment>
|
||||||
|
</rdf:Description>
|
||||||
|
|
||||||
|
<!-- Individuals -->
|
||||||
|
<TechnologyCompany rdf:ID="apple">
|
||||||
|
<rdfs:label>Apple</rdfs:label>
|
||||||
|
</TechnologyCompany>
|
||||||
|
|
||||||
|
<TechnologyCompany rdf:ID="microsoft">
|
||||||
|
<rdfs:label>Microsoft</rdfs:label>
|
||||||
|
</TechnologyCompany>
|
||||||
|
|
||||||
|
<TechnologyCompany rdf:ID="google">
|
||||||
|
<rdfs:label>Google</rdfs:label>
|
||||||
|
</TechnologyCompany>
|
||||||
|
|
||||||
|
<Hardware rdf:ID="iphone">
|
||||||
|
<rdfs:label>iPhone</rdfs:label>
|
||||||
|
</Hardware>
|
||||||
|
|
||||||
|
<Software rdf:ID="windows">
|
||||||
|
<rdfs:label>Windows</rdfs:label>
|
||||||
|
</Software>
|
||||||
|
|
||||||
|
<Software rdf:ID="android">
|
||||||
|
<rdfs:label>Android</rdfs:label>
|
||||||
|
</Software>
|
||||||
|
|
||||||
|
</rdf:RDF>"""
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
data_directory_path = str(
|
||||||
|
pathlib.Path(
|
||||||
|
os.path.join(
|
||||||
|
pathlib.Path(__file__).parent,
|
||||||
|
".data_storage/test_edge_centered_payload",
|
||||||
|
)
|
||||||
|
).resolve()
|
||||||
|
)
|
||||||
|
cognee_directory_path = str(
|
||||||
|
pathlib.Path(
|
||||||
|
os.path.join(
|
||||||
|
pathlib.Path(__file__).parent,
|
||||||
|
".cognee_system/test_edge_centered_payload",
|
||||||
|
)
|
||||||
|
).resolve()
|
||||||
|
)
|
||||||
|
|
||||||
|
cognee.config.data_root_directory(data_directory_path)
|
||||||
|
cognee.config.system_root_directory(cognee_directory_path)
|
||||||
|
|
||||||
|
dataset_name = "tech_companies"
|
||||||
|
|
||||||
|
await cognee.prune.prune_data()
|
||||||
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
|
||||||
|
await cognee.add(data=text_data, dataset_name=dataset_name)
|
||||||
|
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
with tempfile.NamedTemporaryFile(mode="w", suffix=".owl", delete=False) as f:
|
||||||
|
f.write(ontology_content)
|
||||||
|
ontology_file_path = f.name
|
||||||
|
|
||||||
|
try:
|
||||||
|
logger.info(f"Loading ontology from: {ontology_file_path}")
|
||||||
|
config: Config = {
|
||||||
|
"ontology_config": {
|
||||||
|
"ontology_resolver": RDFLibOntologyResolver(ontology_file=ontology_file_path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
await cognee.cognify(datasets=[dataset_name], config=config)
|
||||||
|
graph_engine = await get_graph_engine()
|
||||||
|
nodes_phase2, edges_phase2 = await graph_engine.get_graph_data()
|
||||||
|
|
||||||
|
vector_engine = get_vector_engine()
|
||||||
|
triplets_phase2 = await vector_engine.search(
|
||||||
|
query_text="technology", limit=None, collection_name="Triplet_text"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(triplets_phase2) == len(edges_phase2), (
|
||||||
|
f"Triplet embeddings and number of edges do not match. Vector db contains {len(triplets_phase2)} edge triplets while graph db contains {len(edges_phase2)} edges."
|
||||||
|
)
|
||||||
|
|
||||||
|
search_results_phase2 = await cognee.search(
|
||||||
|
query_type=SearchType.TRIPLET_COMPLETION,
|
||||||
|
query_text="What products does Apple make?",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert search_results_phase2 is not None, (
|
||||||
|
"Search should return results for triplet embeddings in simple ontology use case."
|
||||||
|
)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
if os.path.exists(ontology_file_path):
|
||||||
|
os.unlink(ontology_file_path)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    import asyncio

    from cognee.shared.logging_utils import setup_logging

    setup_logging()

    # Run on an explicitly created loop so async generators can be drained
    # before the loop is closed.
    event_loop = asyncio.new_event_loop()
    asyncio.set_event_loop(event_loop)
    try:
        event_loop.run_until_complete(main())
    finally:
        event_loop.run_until_complete(event_loop.shutdown_asyncgens())
        event_loop.close()
|
||||||
288
cognee/tests/unit/tasks/storage/test_add_data_points.py
Normal file
288
cognee/tests/unit/tasks/storage/test_add_data_points.py
Normal file
|
|
@ -0,0 +1,288 @@
|
||||||
|
import pytest
from unittest.mock import AsyncMock, patch
import sys

from cognee.infrastructure.engine import DataPoint
from cognee.modules.engine.models import Triplet
from cognee.tasks.storage.add_data_points import (
    add_data_points,
    InvalidDataPointsInAddDataPointsError,
    _extract_embeddable_text_from_datapoint,
    _create_triplets_from_graph,
)

# Module object for cognee.tasks.storage.add_data_points, used as the
# patch.object target below so mocks replace the exact names that module
# looks up at call time (robust against "from x import y" rebinding).
adp_module = sys.modules["cognee.tasks.storage.add_data_points"]
|
||||||
|
|
||||||
|
|
||||||
|
class SimplePoint(DataPoint):
    """Minimal DataPoint fixture whose single "text" field is embeddable."""

    text: str
    # index_fields marks "text" as the field read by
    # _extract_embeddable_text_from_datapoint.
    metadata: dict = {"index_fields": ["text"]}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@pytest.mark.parametrize("bad_input", [None, ["not_datapoint"]])
async def test_add_data_points_validates_inputs(bad_input):
    # add_data_points must reject None and lists containing non-DataPoint
    # items with its dedicated validation error.
    with pytest.raises(InvalidDataPointsInAddDataPointsError):
        await add_data_points(bad_input)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@patch.object(adp_module, "index_graph_edges")
@patch.object(adp_module, "index_data_points")
@patch.object(adp_module, "get_graph_engine")
@patch.object(adp_module, "deduplicate_nodes_and_edges")
@patch.object(adp_module, "get_graph_from_model")
async def test_add_data_points_indexes_nodes_and_edges(
    mock_get_graph, mock_dedup, mock_get_engine, mock_index_nodes, mock_index_edges
):
    # NOTE: decorators apply bottom-up, so the innermost patch
    # (get_graph_from_model) arrives as the first mock argument.
    dp1 = SimplePoint(text="first")
    dp2 = SimplePoint(text="second")

    edge1 = (str(dp1.id), str(dp2.id), "related_to", {"edge_text": "connects"})
    custom_edges = [(str(dp2.id), str(dp1.id), "custom_edge", {})]

    # One get_graph_from_model call per data point: dp1 contributes edge1,
    # dp2 contributes no edges.
    mock_get_graph.side_effect = [([dp1], [edge1]), ([dp2], [])]
    # Identity dedup so the assertions below see the exact inputs.
    mock_dedup.side_effect = lambda n, e: (n, e)
    graph_engine = AsyncMock()
    mock_get_engine.return_value = graph_engine

    result = await add_data_points([dp1, dp2], custom_edges=custom_edges)

    assert result == [dp1, dp2]
    graph_engine.add_nodes.assert_awaited_once()
    mock_index_nodes.assert_awaited_once()
    # Graph-derived edges and caller-supplied custom edges are persisted in
    # two separate add_edges calls, in that order, and each is indexed.
    assert graph_engine.add_edges.await_count == 2
    assert edge1 in graph_engine.add_edges.await_args_list[0].args[0]
    assert graph_engine.add_edges.await_args_list[1].args[0] == custom_edges
    assert mock_index_edges.await_count == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@patch.object(adp_module, "index_graph_edges")
@patch.object(adp_module, "index_data_points")
@patch.object(adp_module, "get_graph_engine")
@patch.object(adp_module, "deduplicate_nodes_and_edges")
@patch.object(adp_module, "get_graph_from_model")
async def test_add_data_points_indexes_triplets_when_enabled(
    mock_get_graph, mock_dedup, mock_get_engine, mock_index_nodes, mock_index_edges
):
    # NOTE: decorators apply bottom-up; get_graph_from_model is the first
    # mock argument.
    dp1 = SimplePoint(text="source")
    dp2 = SimplePoint(text="target")

    edge1 = (str(dp1.id), str(dp2.id), "relates", {"edge_text": "describes"})

    # One get_graph_from_model call per data point; only dp1 yields an edge.
    mock_get_graph.side_effect = [([dp1], [edge1]), ([dp2], [])]
    # Identity dedup keeps inputs unchanged for the assertions below.
    mock_dedup.side_effect = lambda n, e: (n, e)
    graph_engine = AsyncMock()
    mock_get_engine.return_value = graph_engine

    await add_data_points([dp1, dp2], embed_triplets=True)

    # With embed_triplets=True, index_data_points is awaited twice: first
    # with the data points themselves, then with the derived Triplet payloads.
    assert mock_index_nodes.await_count == 2
    nodes_arg = mock_index_nodes.await_args_list[0].args[0]
    triplets_arg = mock_index_nodes.await_args_list[1].args[0]
    assert nodes_arg == [dp1, dp2]
    assert len(triplets_arg) == 1
    assert isinstance(triplets_arg[0], Triplet)
    mock_index_edges.assert_awaited_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@patch.object(adp_module, "index_graph_edges")
@patch.object(adp_module, "index_data_points")
@patch.object(adp_module, "get_graph_engine")
@patch.object(adp_module, "deduplicate_nodes_and_edges")
@patch.object(adp_module, "get_graph_from_model")
async def test_add_data_points_with_empty_list(
    mock_get_graph, mock_dedup, mock_get_engine, mock_index_nodes, mock_index_edges
):
    # An empty (but valid) list must be accepted, not raise.
    mock_dedup.side_effect = lambda n, e: (n, e)
    graph_engine = AsyncMock()
    mock_get_engine.return_value = graph_engine

    result = await add_data_points([])

    assert result == []
    # No data points means graph expansion never runs...
    mock_get_graph.assert_not_called()
    # ...but the engine is still asked to add an empty node batch.
    graph_engine.add_nodes.assert_awaited_once_with([])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@patch.object(adp_module, "index_graph_edges")
@patch.object(adp_module, "index_data_points")
@patch.object(adp_module, "get_graph_engine")
@patch.object(adp_module, "deduplicate_nodes_and_edges")
@patch.object(adp_module, "get_graph_from_model")
async def test_add_data_points_with_single_datapoint(
    mock_get_graph, mock_dedup, mock_get_engine, mock_index_nodes, mock_index_edges
):
    # Minimal happy path: one data point, no edges.
    dp = SimplePoint(text="single")
    mock_get_graph.side_effect = [([dp], [])]
    mock_dedup.side_effect = lambda n, e: (n, e)
    graph_engine = AsyncMock()
    mock_get_engine.return_value = graph_engine

    result = await add_data_points([dp])

    assert result == [dp]
    # Graph expansion runs exactly once (once per data point) and the
    # resulting nodes are indexed exactly once.
    mock_get_graph.assert_called_once()
    mock_index_nodes.assert_awaited_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_embeddable_text_from_datapoint():
    """The single indexed field's value is returned verbatim."""
    point = SimplePoint(text="hello world")
    assert _extract_embeddable_text_from_datapoint(point) == "hello world"
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_embeddable_text_with_multiple_fields():
    """Values of all indexed fields are joined with a single space."""

    class MultiField(DataPoint):
        title: str
        description: str
        metadata: dict = {"index_fields": ["title", "description"]}

    point = MultiField(title="Test", description="Description")
    extracted = _extract_embeddable_text_from_datapoint(point)
    assert extracted == "Test Description"
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_embeddable_text_with_no_index_fields():
    """An empty index_fields list yields an empty embeddable string."""

    class NoIndex(DataPoint):
        text: str
        metadata: dict = {"index_fields": []}

    point = NoIndex(text="ignored")
    assert _extract_embeddable_text_from_datapoint(point) == ""
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_triplets_from_graph():
    """A valid edge between two known nodes yields exactly one Triplet."""
    source = SimplePoint(text="source node")
    target = SimplePoint(text="target node")
    edge = (str(source.id), str(target.id), "connects_to", {"edge_text": "links"})

    result = _create_triplets_from_graph([source, target], [edge])

    assert len(result) == 1
    triplet = result[0]
    assert isinstance(triplet, Triplet)
    assert triplet.from_node_id == str(source.id)
    assert triplet.to_node_id == str(target.id)
    assert "source node" in triplet.text
    assert "target node" in triplet.text
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_embeddable_text_with_none_datapoint():
    """None instead of a DataPoint falls back to an empty string."""
    assert _extract_embeddable_text_from_datapoint(None) == ""
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_embeddable_text_without_metadata():
    """A data point whose metadata attribute is removed yields an empty string."""

    class NoMetadata(DataPoint):
        text: str

    point = NoMetadata(text="test")
    delattr(point, "metadata")
    assert _extract_embeddable_text_from_datapoint(point) == ""
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_embeddable_text_with_whitespace_only():
    """Whitespace-only field values are treated as empty."""

    class WhitespaceField(DataPoint):
        text: str
        metadata: dict = {"index_fields": ["text"]}

    point = WhitespaceField(text=" ")
    assert _extract_embeddable_text_from_datapoint(point) == ""
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_triplets_skips_short_edge_tuples():
    """Edge tuples with too few elements are silently skipped."""
    node = SimplePoint(text="node")
    malformed_edge = (str(node.id), str(node.id))

    result = _create_triplets_from_graph([node], [malformed_edge])

    assert len(result) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_triplets_skips_missing_source_node():
    """An edge whose source id matches no node produces no triplet."""
    target = SimplePoint(text="target")
    dangling_edge = ("missing_id", str(target.id), "relates", {})

    result = _create_triplets_from_graph([target], [dangling_edge])

    assert len(result) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_triplets_skips_missing_target_node():
    """An edge whose target id matches no node produces no triplet."""
    source = SimplePoint(text="source")
    dangling_edge = (str(source.id), "missing_id", "relates", {})

    result = _create_triplets_from_graph([source], [dangling_edge])

    assert len(result) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_triplets_skips_none_relationship():
    """An edge with a None relationship name produces no triplet."""
    source = SimplePoint(text="source")
    target = SimplePoint(text="target")
    nameless_edge = (str(source.id), str(target.id), None, {})

    result = _create_triplets_from_graph([source, target], [nameless_edge])

    assert len(result) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_triplets_uses_relationship_name_when_no_edge_text():
    """Without an edge_text property the relationship name is embedded instead."""
    source = SimplePoint(text="source")
    target = SimplePoint(text="target")
    edge = (str(source.id), str(target.id), "connects_to", {})

    result = _create_triplets_from_graph([source, target], [edge])

    assert len(result) == 1
    assert "connects_to" in result[0].text
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_triplets_prevents_duplicates():
    """The same edge listed twice produces only one Triplet."""
    source = SimplePoint(text="source")
    target = SimplePoint(text="target")
    edge = (str(source.id), str(target.id), "relates", {"edge_text": "links"})

    result = _create_triplets_from_graph([source, target], [edge, edge])

    assert len(result) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_triplets_skips_nodes_without_id():
    """Nodes lacking an id attribute cannot anchor a triplet."""

    class NodeNoId:
        pass

    valid_node = SimplePoint(text="valid")
    edge = (str(valid_node.id), "some_id", "relates", {})

    result = _create_triplets_from_graph([valid_node, NodeNoId()], [edge])

    assert len(result) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@patch.object(adp_module, "index_graph_edges")
@patch.object(adp_module, "index_data_points")
@patch.object(adp_module, "get_graph_engine")
@patch.object(adp_module, "deduplicate_nodes_and_edges")
@patch.object(adp_module, "get_graph_from_model")
async def test_add_data_points_with_empty_custom_edges(
    mock_get_graph, mock_dedup, mock_get_engine, mock_index_nodes, mock_index_edges
):
    # Passing an explicit empty custom_edges list must behave like omitting
    # it: only the single graph-derived add_edges call happens.
    dp = SimplePoint(text="test")
    mock_get_graph.side_effect = [([dp], [])]
    mock_dedup.side_effect = lambda n, e: (n, e)
    graph_engine = AsyncMock()
    mock_get_engine.return_value = graph_engine

    result = await add_data_points([dp], custom_edges=[])

    assert result == [dp]
    # No second add_edges call for the empty custom batch.
    assert graph_engine.add_edges.await_count == 1
|
||||||
Loading…
Add table
Reference in a new issue