diff --git a/.github/workflows/e2e_tests.yml b/.github/workflows/e2e_tests.yml
index 3dea2548c..676699c2a 100644
--- a/.github/workflows/e2e_tests.yml
+++ b/.github/workflows/e2e_tests.yml
@@ -412,6 +412,35 @@ jobs:
         EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
       run: uv run python ./cognee/tests/test_feedback_enrichment.py
 
+  test-edge-centered-payload:
+    name: Test Cognify - Edge Centered Payload
+    runs-on: ubuntu-22.04
+    steps:
+      - name: Check out repository
+        uses: actions/checkout@v4
+
+      - name: Cognee Setup
+        uses: ./.github/actions/cognee_setup
+        with:
+          python-version: '3.11.x'
+
+      - name: Dependencies already installed
+        run: echo "Dependencies already installed in setup"
+
+      - name: Run Edge Centered Payload Test
+        env:
+          ENV: 'dev'
+          TRIPLET_EMBEDDING: True
+          LLM_MODEL: ${{ secrets.LLM_MODEL }}
+          LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
+          LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
+          LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
+          EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
+          EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
+          EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
+          EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
+        run: uv run python ./cognee/tests/test_edge_centered_payload.py
+
   run_conversation_sessions_test_redis:
     name: Conversation sessions test (Redis)
     runs-on: ubuntu-latest
diff --git a/cognee/api/v1/cognify/cognify.py b/cognee/api/v1/cognify/cognify.py
index 0fa345176..8a7c97050 100644
--- a/cognee/api/v1/cognify/cognify.py
+++ b/cognee/api/v1/cognify/cognify.py
@@ -3,6 +3,7 @@ from pydantic import BaseModel
 from typing import Union, Optional
 from uuid import UUID
 
+from cognee.modules.cognify.config import get_cognify_config
 from cognee.modules.ontology.ontology_env_config import get_ontology_env_config
 from cognee.shared.logging_utils import get_logger
 from cognee.shared.data_models import KnowledgeGraph
@@ -272,6 +273,9 @@ async def get_default_tasks(  # TODO: Find out a better way to do this (Boris's
     if chunks_per_batch is None:
         chunks_per_batch = 100
 
+    cognify_config = get_cognify_config()
+    embed_triplets = cognify_config.triplet_embedding
+
     default_tasks = [
         Task(classify_documents),
         Task(check_permissions_on_dataset, user=user, permissions=["write"]),
@@ -291,7 +295,11 @@ async def get_default_tasks(  # TODO: Find out a better way to do this (Boris's
             summarize_text,
             task_config={"batch_size": chunks_per_batch},
         ),
-        Task(add_data_points, task_config={"batch_size": chunks_per_batch}),
+        Task(
+            add_data_points,
+            embed_triplets=embed_triplets,
+            task_config={"batch_size": chunks_per_batch},
+        ),
     ]
 
     return default_tasks
diff --git a/cognee/modules/cognify/config.py b/cognee/modules/cognify/config.py
index 4ba0f4bd6..ec03225e8 100644
--- a/cognee/modules/cognify/config.py
+++ b/cognee/modules/cognify/config.py
@@ -8,12 +8,14 @@ import os
 class CognifyConfig(BaseSettings):
     classification_model: object = DefaultContentPrediction
     summarization_model: object = SummarizedContent
+    triplet_embedding: bool = False
 
     model_config = SettingsConfigDict(env_file=".env", extra="allow")
 
     def to_dict(self) -> dict:
         return {
             "classification_model": self.classification_model,
             "summarization_model": self.summarization_model,
+            "triplet_embedding": self.triplet_embedding,
         }
diff --git a/cognee/tasks/graph/extract_graph_from_data.py b/cognee/tasks/graph/extract_graph_from_data.py
index 49b51af2d..2d1eca17e 100644
--- a/cognee/tasks/graph/extract_graph_from_data.py
+++ b/cognee/tasks/graph/extract_graph_from_data.py
@@ -2,9 +2,7 @@ import asyncio
 from typing import Type, List, Optional
 from pydantic import BaseModel
 
-from cognee.infrastructure.databases.graph import get_graph_engine
 from cognee.modules.ontology.ontology_env_config import get_ontology_env_config
-from cognee.tasks.storage import index_graph_edges
 from cognee.tasks.storage.add_data_points import add_data_points
 from cognee.modules.ontology.ontology_config import Config
 from cognee.modules.ontology.get_default_ontology_resolver import (
@@ -25,6 +23,7 @@ from cognee.tasks.graph.exceptions import (
     InvalidChunkGraphInputError,
     InvalidOntologyAdapterError,
 )
+from cognee.modules.cognify.config import get_cognify_config
 
 
 async def integrate_chunk_graphs(
@@ -67,8 +66,6 @@ async def integrate_chunk_graphs(
         type(ontology_resolver).__name__ if ontology_resolver else "None"
     )
 
-    graph_engine = await get_graph_engine()
-
     if graph_model is not KnowledgeGraph:
         for chunk_index, chunk_graph in enumerate(chunk_graphs):
             data_chunks[chunk_index].contains = chunk_graph
@@ -84,12 +81,13 @@ async def integrate_chunk_graphs(
         data_chunks, chunk_graphs, ontology_resolver, existing_edges_map
     )
 
-    if len(graph_nodes) > 0:
-        await add_data_points(graph_nodes)
+    cognify_config = get_cognify_config()
+    embed_triplets = cognify_config.triplet_embedding
 
-    if len(graph_edges) > 0:
-        await graph_engine.add_edges(graph_edges)
-        await index_graph_edges(graph_edges)
+    if len(graph_nodes) > 0:
+        await add_data_points(
+            data_points=graph_nodes, custom_edges=graph_edges, embed_triplets=embed_triplets
+        )
 
     return data_chunks
diff --git a/cognee/tasks/storage/add_data_points.py b/cognee/tasks/storage/add_data_points.py
index ad1693e82..ea731fd27 100644
--- a/cognee/tasks/storage/add_data_points.py
+++ b/cognee/tasks/storage/add_data_points.py
@@ -1,16 +1,23 @@
 import asyncio
-from typing import List
+from typing import List, Dict, Optional
 
 from cognee.infrastructure.engine import DataPoint
 from cognee.infrastructure.databases.graph import get_graph_engine
 from cognee.modules.graph.utils import deduplicate_nodes_and_edges, get_graph_from_model
 from .index_data_points import index_data_points
 from .index_graph_edges import index_graph_edges
+from cognee.modules.engine.models import Triplet
+from cognee.shared.logging_utils import get_logger
 from cognee.tasks.storage.exceptions import (
     InvalidDataPointsInAddDataPointsError,
 )
+from ...modules.engine.utils import generate_node_id
+
+logger = get_logger("add_data_points")
 
 
-async def add_data_points(data_points: List[DataPoint]) -> List[DataPoint]:
+async def add_data_points(
+    data_points: List[DataPoint], custom_edges: Optional[List] = None, embed_triplets: bool = False
+) -> List[DataPoint]:
     """
     Add a batch of data points to the graph database by extracting nodes and edges,
     deduplicating them, and indexing them for retrieval.
@@ -23,6 +30,10 @@ async def add_data_points(
     Args:
         data_points (List[DataPoint]): A list of data points to process and insert into the graph.
+        custom_edges (List[tuple]): Custom edges between datapoints.
+        embed_triplets (bool):
+            If True, creates and indexes triplet embeddings from the graph structure.
+            Defaults to False.
 
     Returns:
         List[DataPoint]:
@@ -34,6 +45,7 @@ async def add_data_points(
     - Updates the node index via `index_data_points`.
     - Inserts nodes and edges into the graph engine.
     - Optionally updates the edge index via `index_graph_edges`.
+    - Optionally creates and indexes triplet embeddings if embed_triplets is True.
""" if not isinstance(data_points, list): @@ -74,4 +86,132 @@ async def add_data_points(data_points: List[DataPoint]) -> List[DataPoint]: await graph_engine.add_edges(edges) await index_graph_edges(edges) + if isinstance(custom_edges, list) and custom_edges: + # This must be handled separately from datapoint edges, created a task in linear to dig deeper but (COG-3488) + await graph_engine.add_edges(custom_edges) + await index_graph_edges(custom_edges) + edges.extend(custom_edges) + + if embed_triplets: + triplets = _create_triplets_from_graph(nodes, edges) + if triplets: + await index_data_points(triplets) + logger.info(f"Created and indexed {len(triplets)} triplets from graph structure") + return data_points + + +def _extract_embeddable_text_from_datapoint(data_point: DataPoint) -> str: + """ + Extract embeddable text from a DataPoint using its index_fields metadata. + Uses the same approach as index_data_points. + + Parameters: + ----------- + - data_point (DataPoint): The data point to extract text from. + + Returns: + -------- + - str: Concatenated string of all embeddable property values, or empty string if none found. + """ + if not data_point or not hasattr(data_point, "metadata"): + return "" + + index_fields = data_point.metadata.get("index_fields", []) + if not index_fields: + return "" + + embeddable_values = [] + for field_name in index_fields: + field_value = getattr(data_point, field_name, None) + if field_value is not None: + field_value = str(field_value).strip() + + if field_value: + embeddable_values.append(field_value) + + return " ".join(embeddable_values) if embeddable_values else "" + + +def _create_triplets_from_graph(nodes: List[DataPoint], edges: List[tuple]) -> List[Triplet]: + """ + Create Triplet objects from graph nodes and edges. + + This function processes graph edges and their corresponding nodes to create + triplet datapoints with embeddable text, similar to the triplet embeddings pipeline. + + Parameters: + ----------- + - nodes (List[DataPoint]): List of graph nodes extracted from data points + - edges (List[tuple]): List of edge tuples in format + (source_node_id, target_node_id, relationship_name, properties_dict) + Note: All edges including those from DocumentChunk.contains are already extracted + by get_graph_from_model and included in this list. 
+
+    Returns:
+    --------
+    - List[Triplet]: List of Triplet objects ready for indexing.
+    """
+    node_map: Dict[str, DataPoint] = {}
+    for node in nodes:
+        if hasattr(node, "id"):
+            node_id = str(node.id)
+            if node_id not in node_map:
+                node_map[node_id] = node
+
+    triplets = []
+    skipped_count = 0
+    seen_ids = set()
+
+    for edge_tuple in edges:
+        if len(edge_tuple) < 4:
+            continue
+
+        source_node_id, target_node_id, relationship_name, edge_properties = (
+            edge_tuple[0],
+            edge_tuple[1],
+            edge_tuple[2],
+            edge_tuple[3],
+        )
+
+        source_node = node_map.get(str(source_node_id))
+        target_node = node_map.get(str(target_node_id))
+
+        if not source_node or not target_node or relationship_name is None:
+            skipped_count += 1
+            continue
+
+        source_node_text = _extract_embeddable_text_from_datapoint(source_node)
+        target_node_text = _extract_embeddable_text_from_datapoint(target_node)
+
+        relationship_text = ""
+        if isinstance(edge_properties, dict):
+            edge_text = edge_properties.get("edge_text")
+            if edge_text and isinstance(edge_text, str) and edge_text.strip():
+                relationship_text = edge_text.strip()
+
+        if not relationship_text and relationship_name:
+            relationship_text = relationship_name
+
+        if not source_node_text and not relationship_text and not relationship_name:
+            skipped_count += 1
+            continue
+
+        embeddable_text = f"{source_node_text} -> {relationship_text} -> {target_node_text}".strip()
+
+        triplet_id = generate_node_id(str(source_node_id) + relationship_name + str(target_node_id))
+
+        if triplet_id in seen_ids:
+            continue
+        seen_ids.add(triplet_id)
+
+        triplets.append(
+            Triplet(
+                id=triplet_id,
+                from_node_id=str(source_node_id),
+                to_node_id=str(target_node_id),
+                text=embeddable_text,
+            )
+        )
+
+    if skipped_count:
+        logger.info(f"Skipped {skipped_count} edges while creating triplets")
+
+    return triplets
diff --git a/cognee/tests/integration/tasks/test_add_data_points.py b/cognee/tests/integration/tasks/test_add_data_points.py
new file mode 100644
index 000000000..7b6c9a683
--- /dev/null
+++ b/cognee/tests/integration/tasks/test_add_data_points.py
@@ -0,0 +1,139 @@
+import pathlib
+import pytest
+import pytest_asyncio
+
+import cognee
+from cognee.low_level import setup
+from cognee.infrastructure.engine import DataPoint
+from cognee.tasks.storage.add_data_points import add_data_points
+from cognee.tasks.storage.exceptions import InvalidDataPointsInAddDataPointsError
+from cognee.infrastructure.databases.graph import get_graph_engine
+
+
+class Person(DataPoint):
+    name: str
+    age: int
+    metadata: dict = {"index_fields": ["name"]}
+
+
+class Company(DataPoint):
+    name: str
+    industry: str
+    metadata: dict = {"index_fields": ["name", "industry"]}
+
+
+@pytest_asyncio.fixture
+async def clean_test_environment():
+    """Set up a clean test environment for add_data_points tests."""
+    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
+    system_directory_path = str(base_dir / ".cognee_system/test_add_data_points_integration")
+    data_directory_path = str(base_dir / ".data_storage/test_add_data_points_integration")
+
+    cognee.config.system_root_directory(system_directory_path)
+    cognee.config.data_root_directory(data_directory_path)
+
+    await cognee.prune.prune_data()
+    await cognee.prune.prune_system(metadata=True)
+    await setup()
+
+    yield
+
+    try:
+        await cognee.prune.prune_data()
+        await cognee.prune.prune_system(metadata=True)
+    except Exception:
+        pass
+
+
+@pytest.mark.asyncio
+async def test_add_data_points_comprehensive(clean_test_environment):
+    """Comprehensive integration test for add_data_points functionality."""
+    person1 = Person(name="Alice", age=30)
+    person2 = Person(name="Bob", age=25)
+    result = await add_data_points([person1, person2])
+
+    assert result == [person1, person2]
+    assert len(result) == 2
+
+    graph_engine = await get_graph_engine()
+    nodes, edges = await graph_engine.get_graph_data()
+    assert len(nodes) >= 2
+
+    result_empty = await add_data_points([])
+    assert result_empty == []
+
+    person3 = Person(name="Charlie", age=35)
+    person4 = Person(name="Diana", age=32)
+    custom_edge = (str(person3.id), str(person4.id), "knows", {"edge_text": "friends with"})
+
+    result_custom = await add_data_points([person3, person4], custom_edges=[custom_edge])
+    assert len(result_custom) == 2
+
+    nodes, edges = await graph_engine.get_graph_data()
+    assert len(edges) == 1
+    assert len(nodes) == 4
+
+    class Employee(DataPoint):
+        name: str
+        works_at: Company
+        metadata: dict = {"index_fields": ["name"]}
+
+    company = Company(name="TechCorp", industry="Technology")
+    employee = Employee(name="Eve", works_at=company)
+
+    result_rel = await add_data_points([employee])
+    assert len(result_rel) == 1
+
+    nodes, edges = await graph_engine.get_graph_data()
+    assert len(nodes) == 6
+    assert len(edges) == 2
+
+    person5 = Person(name="Frank", age=40)
+    person6 = Person(name="Grace", age=38)
+    triplet_edge = (str(person5.id), str(person6.id), "married_to", {"edge_text": "is married to"})
+
+    result_triplet = await add_data_points(
+        [person5, person6], custom_edges=[triplet_edge], embed_triplets=True
+    )
+    assert len(result_triplet) == 2
+
+    nodes, edges = await graph_engine.get_graph_data()
+    assert len(nodes) == 8
+    assert len(edges) == 3
+
+    batch1 = [Person(name="Leo", age=25), Person(name="Mia", age=30)]
+    batch2 = [Person(name="Noah", age=35), Person(name="Olivia", age=40)]
+
+    result_batch1 = await add_data_points(batch1)
+    result_batch2 = await add_data_points(batch2)
+
+    assert len(result_batch1) == 2
+    assert len(result_batch2) == 2
+
+    nodes, edges = await graph_engine.get_graph_data()
+    assert len(nodes) == 12
+    assert len(edges) == 3
+
+    person7 = Person(name="Paul", age=33)
+    person8 = Person(name="Quinn", age=31)
+    edge1 = (str(person7.id), str(person8.id), "colleague_of", {"edge_text": "works with"})
+    edge2 = (str(person8.id), str(person7.id), "colleague_of", {"edge_text": "works with"})
+
+    result_bi = await add_data_points([person7, person8], custom_edges=[edge1, edge2])
+    assert len(result_bi) == 2
+
+    nodes, edges = await graph_engine.get_graph_data()
+    assert len(nodes) == 14
+    assert len(edges) == 5
+
+    person_invalid = Person(name="Invalid", age=50)
+    with pytest.raises(InvalidDataPointsInAddDataPointsError, match="must be a list"):
+        await add_data_points(person_invalid)
+
+    with pytest.raises(InvalidDataPointsInAddDataPointsError, match="must be a DataPoint"):
+        await add_data_points(["not", "datapoints"])
+
+    final_nodes, final_edges = await graph_engine.get_graph_data()
+    assert len(final_nodes) == 14
+    assert len(final_edges) == 5
diff --git a/cognee/tests/test_cognee_server_start.py b/cognee/tests/test_cognee_server_start.py
index ddffe53a4..fece88240 100644
--- a/cognee/tests/test_cognee_server_start.py
+++ b/cognee/tests/test_cognee_server_start.py
@@ -25,8 +25,6 @@ class TestCogneeServerStart(unittest.TestCase):
                 "--port",
                 "8000",
             ],
-            stdout=subprocess.PIPE,
-            stderr=subprocess.PIPE,
             preexec_fn=os.setsid,
         )
         # Give the server some time to start
diff --git a/cognee/tests/test_edge_centered_payload.py b/cognee/tests/test_edge_centered_payload.py
new file mode 100644
index 000000000..3d76e93ff
--- /dev/null
+++ b/cognee/tests/test_edge_centered_payload.py
@@ -0,0 +1,170 @@
+"""
+End-to-end integration test for edge-centered payload and triplet embeddings.
+"""
+
+import os
+import pathlib
+
+import cognee
+from cognee.infrastructure.databases.graph import get_graph_engine
+from cognee.infrastructure.databases.vector import get_vector_engine
+from cognee.modules.search.types import SearchType
+from cognee.shared.logging_utils import get_logger
+from cognee.modules.ontology.rdf_xml.RDFLibOntologyResolver import RDFLibOntologyResolver
+from cognee.modules.ontology.ontology_config import Config
+
+logger = get_logger()
+
+text_data = """
+Apple is a technology company that produces the iPhone, iPad, and Mac computers.
+The company is known for its innovative products and ecosystem integration.
+
+Microsoft develops the Windows operating system and Office productivity suite.
+They are also major players in cloud computing with Azure.
+
+Google created the Android operating system and provides search engine services.
+The company is a leader in artificial intelligence and machine learning.
+"""
+
+# Minimal OWL ontology (RDF/XML) describing the companies and products above.
+ontology_content = """<?xml version="1.0"?>
+<rdf:RDF xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"
+         xmlns:rdfs="http://www.w3.org/2000/01/rdf-schema#"
+         xmlns:owl="http://www.w3.org/2002/07/owl#"
+         xml:base="http://example.org/tech">
+
+    <owl:Ontology rdf:about="http://example.org/tech"/>
+
+    <owl:Class rdf:about="http://example.org/tech#TechnologyCompany">
+        <rdfs:comment>A company operating in the technology sector.</rdfs:comment>
+    </owl:Class>
+
+    <owl:Class rdf:about="http://example.org/tech#Software">
+        <rdfs:comment>Software products and applications.</rdfs:comment>
+    </owl:Class>
+
+    <owl:Class rdf:about="http://example.org/tech#Hardware">
+        <rdfs:comment>Physical hardware products.</rdfs:comment>
+    </owl:Class>
+
+    <owl:NamedIndividual rdf:about="http://example.org/tech#Apple">
+        <rdf:type rdf:resource="http://example.org/tech#TechnologyCompany"/>
+        <rdfs:label>Apple</rdfs:label>
+    </owl:NamedIndividual>
+
+    <owl:NamedIndividual rdf:about="http://example.org/tech#Microsoft">
+        <rdf:type rdf:resource="http://example.org/tech#TechnologyCompany"/>
+        <rdfs:label>Microsoft</rdfs:label>
+    </owl:NamedIndividual>
+
+    <owl:NamedIndividual rdf:about="http://example.org/tech#Google">
+        <rdf:type rdf:resource="http://example.org/tech#TechnologyCompany"/>
+        <rdfs:label>Google</rdfs:label>
+    </owl:NamedIndividual>
+
+    <owl:NamedIndividual rdf:about="http://example.org/tech#iPhone">
+        <rdf:type rdf:resource="http://example.org/tech#Hardware"/>
+        <rdfs:label>iPhone</rdfs:label>
+    </owl:NamedIndividual>
+
+    <owl:NamedIndividual rdf:about="http://example.org/tech#Windows">
+        <rdf:type rdf:resource="http://example.org/tech#Software"/>
+        <rdfs:label>Windows</rdfs:label>
+    </owl:NamedIndividual>
+
+    <owl:NamedIndividual rdf:about="http://example.org/tech#Android">
+        <rdf:type rdf:resource="http://example.org/tech#Software"/>
+        <rdfs:label>Android</rdfs:label>
+    </owl:NamedIndividual>
+</rdf:RDF>
+"""
+
+
+async def main():
+    data_directory_path = str(
+        pathlib.Path(
+            os.path.join(
+                pathlib.Path(__file__).parent,
+                ".data_storage/test_edge_centered_payload",
+            )
+        ).resolve()
+    )
+    cognee_directory_path = str(
+        pathlib.Path(
+            os.path.join(
+                pathlib.Path(__file__).parent,
+                ".cognee_system/test_edge_centered_payload",
+            )
+        ).resolve()
+    )
+
+    cognee.config.data_root_directory(data_directory_path)
+    cognee.config.system_root_directory(cognee_directory_path)
+
+    dataset_name = "tech_companies"
+
+    await cognee.prune.prune_data()
+    await cognee.prune.prune_system(metadata=True)
+
+    await cognee.add(data=text_data, dataset_name=dataset_name)
+
+    import tempfile
+
+    with tempfile.NamedTemporaryFile(mode="w", suffix=".owl", delete=False) as f:
+        f.write(ontology_content)
+        ontology_file_path = f.name
+
+    try:
+        logger.info(f"Loading ontology from: {ontology_file_path}")
+        config: Config = {
+            "ontology_config": {
+                "ontology_resolver": RDFLibOntologyResolver(ontology_file=ontology_file_path)
+            }
+        }
+
+        await cognee.cognify(datasets=[dataset_name], config=config)
+
+        graph_engine = await get_graph_engine()
+        nodes_phase2, edges_phase2 = await graph_engine.get_graph_data()
+
+        vector_engine = get_vector_engine()
+        triplets_phase2 = await vector_engine.search(
+            query_text="technology", limit=None, collection_name="Triplet_text"
+        )
+
+        assert len(triplets_phase2) == len(edges_phase2), (
+            f"Triplet embeddings and number of edges do not match. Vector db contains "
+            f"{len(triplets_phase2)} edge triplets while graph db contains {len(edges_phase2)} edges."
+        )
+
+        search_results_phase2 = await cognee.search(
+            query_type=SearchType.TRIPLET_COMPLETION,
+            query_text="What products does Apple make?",
+        )
+
+        assert search_results_phase2 is not None, (
+            "Search should return results for triplet embeddings in simple ontology use case."
+        )
+
+    finally:
+        if os.path.exists(ontology_file_path):
+            os.unlink(ontology_file_path)
+
+
+if __name__ == "__main__":
+    import asyncio
+    from cognee.shared.logging_utils import setup_logging
+
+    setup_logging()
+
+    loop = asyncio.new_event_loop()
+    asyncio.set_event_loop(loop)
+    try:
+        loop.run_until_complete(main())
+    finally:
+        loop.run_until_complete(loop.shutdown_asyncgens())
+        loop.close()
diff --git a/cognee/tests/unit/tasks/storage/test_add_data_points.py b/cognee/tests/unit/tasks/storage/test_add_data_points.py
new file mode 100644
index 000000000..90d33158e
--- /dev/null
+++ b/cognee/tests/unit/tasks/storage/test_add_data_points.py
@@ -0,0 +1,288 @@
+import pytest
+from unittest.mock import AsyncMock, patch
+import sys
+
+from cognee.infrastructure.engine import DataPoint
+from cognee.modules.engine.models import Triplet
+from cognee.tasks.storage.add_data_points import (
+    add_data_points,
+    InvalidDataPointsInAddDataPointsError,
+    _extract_embeddable_text_from_datapoint,
+    _create_triplets_from_graph,
+)
+
+adp_module = sys.modules["cognee.tasks.storage.add_data_points"]
+
+
+class SimplePoint(DataPoint):
+    text: str
+    metadata: dict = {"index_fields": ["text"]}
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("bad_input", [None, ["not_datapoint"]])
+async def test_add_data_points_validates_inputs(bad_input):
+    with pytest.raises(InvalidDataPointsInAddDataPointsError):
+        await add_data_points(bad_input)
+
+
+@pytest.mark.asyncio
+@patch.object(adp_module, "index_graph_edges")
+@patch.object(adp_module, "index_data_points")
+@patch.object(adp_module, "get_graph_engine")
+@patch.object(adp_module, "deduplicate_nodes_and_edges")
+@patch.object(adp_module, "get_graph_from_model")
+async def test_add_data_points_indexes_nodes_and_edges(
+    mock_get_graph, mock_dedup, mock_get_engine, mock_index_nodes, mock_index_edges
+):
+    dp1 = SimplePoint(text="first")
+    dp2 = SimplePoint(text="second")
+
+    edge1 = (str(dp1.id), str(dp2.id), "related_to", {"edge_text": "connects"})
+    custom_edges = [(str(dp2.id), str(dp1.id), "custom_edge", {})]
+
+    mock_get_graph.side_effect = [([dp1], [edge1]), ([dp2], [])]
+    mock_dedup.side_effect = lambda n, e: (n, e)
+    graph_engine = AsyncMock()
+    mock_get_engine.return_value = graph_engine
+
+    result = await add_data_points([dp1, dp2], custom_edges=custom_edges)
+
+    assert result == [dp1, dp2]
+    graph_engine.add_nodes.assert_awaited_once()
+    mock_index_nodes.assert_awaited_once()
+    assert graph_engine.add_edges.await_count == 2
+    assert edge1 in graph_engine.add_edges.await_args_list[0].args[0]
+    assert graph_engine.add_edges.await_args_list[1].args[0] == custom_edges
+    assert mock_index_edges.await_count == 2
+
+
+@pytest.mark.asyncio
+@patch.object(adp_module, "index_graph_edges")
+@patch.object(adp_module, "index_data_points")
+@patch.object(adp_module, "get_graph_engine")
+@patch.object(adp_module, "deduplicate_nodes_and_edges")
+@patch.object(adp_module, "get_graph_from_model")
+async def test_add_data_points_indexes_triplets_when_enabled(
+    mock_get_graph, mock_dedup, mock_get_engine, mock_index_nodes, mock_index_edges
+):
+    dp1 = SimplePoint(text="source")
+    dp2 = SimplePoint(text="target")
+
+    edge1 = (str(dp1.id), str(dp2.id), "relates", {"edge_text": "describes"})
+
+    mock_get_graph.side_effect = [([dp1], [edge1]), ([dp2], [])]
+    mock_dedup.side_effect = lambda n, e: (n, e)
+    graph_engine = AsyncMock()
+    mock_get_engine.return_value = graph_engine
+
+    await add_data_points([dp1, dp2], embed_triplets=True)
+
+    assert mock_index_nodes.await_count == 2
+    nodes_arg = mock_index_nodes.await_args_list[0].args[0]
+    triplets_arg = mock_index_nodes.await_args_list[1].args[0]
+    assert nodes_arg == [dp1, dp2]
+    assert len(triplets_arg) == 1
+    assert isinstance(triplets_arg[0], Triplet)
+    mock_index_edges.assert_awaited_once()
+
+
+@pytest.mark.asyncio
+@patch.object(adp_module, "index_graph_edges")
+@patch.object(adp_module, "index_data_points")
+@patch.object(adp_module, "get_graph_engine")
+@patch.object(adp_module, "deduplicate_nodes_and_edges")
+@patch.object(adp_module, "get_graph_from_model")
+async def test_add_data_points_with_empty_list(
+    mock_get_graph, mock_dedup, mock_get_engine, mock_index_nodes, mock_index_edges
+):
+    mock_dedup.side_effect = lambda n, e: (n, e)
+    graph_engine = AsyncMock()
+    mock_get_engine.return_value = graph_engine
+
+    result = await add_data_points([])
+
+    assert result == []
+    mock_get_graph.assert_not_called()
+    graph_engine.add_nodes.assert_awaited_once_with([])
+
+
+@pytest.mark.asyncio
+@patch.object(adp_module, "index_graph_edges")
+@patch.object(adp_module, "index_data_points")
+@patch.object(adp_module, "get_graph_engine")
+@patch.object(adp_module, "deduplicate_nodes_and_edges")
+@patch.object(adp_module, "get_graph_from_model")
+async def test_add_data_points_with_single_datapoint(
+    mock_get_graph, mock_dedup, mock_get_engine, mock_index_nodes, mock_index_edges
+):
+    dp = SimplePoint(text="single")
+    mock_get_graph.side_effect = [([dp], [])]
+    mock_dedup.side_effect = lambda n, e: (n, e)
+    graph_engine = AsyncMock()
+    mock_get_engine.return_value = graph_engine
+
+    result = await add_data_points([dp])
+
+    assert result == [dp]
+    mock_get_graph.assert_called_once()
+    mock_index_nodes.assert_awaited_once()
+
+
+def test_extract_embeddable_text_from_datapoint():
+    dp = SimplePoint(text="hello world")
+    text = _extract_embeddable_text_from_datapoint(dp)
+    assert text == "hello world"
+
+
+def test_extract_embeddable_text_with_multiple_fields():
+    class MultiField(DataPoint):
+        title: str
+        description: str
+        metadata: dict = {"index_fields": ["title", "description"]}
+
+    dp = MultiField(title="Test", description="Description")
+    text = _extract_embeddable_text_from_datapoint(dp)
+    assert text == "Test Description"
+
+
+def test_extract_embeddable_text_with_no_index_fields():
+    class NoIndex(DataPoint):
+        text: str
+        metadata: dict = {"index_fields": []}
+
+    dp = NoIndex(text="ignored")
+    text = _extract_embeddable_text_from_datapoint(dp)
+    assert text == ""
+
+
+def test_create_triplets_from_graph():
+    dp1 = SimplePoint(text="source node")
+    dp2 = SimplePoint(text="target node")
+    edge = (str(dp1.id), str(dp2.id), "connects_to", {"edge_text": "links"})
+
+    triplets = _create_triplets_from_graph([dp1, dp2], [edge])
+
+    assert len(triplets) == 1
+    assert isinstance(triplets[0], Triplet)
+    assert triplets[0].from_node_id == str(dp1.id)
+    assert triplets[0].to_node_id == str(dp2.id)
+    assert "source node" in triplets[0].text
+    assert "target node" in triplets[0].text
+
+
+def test_extract_embeddable_text_with_none_datapoint():
+    text = _extract_embeddable_text_from_datapoint(None)
+    assert text == ""
+
+
+def test_extract_embeddable_text_without_metadata():
+    class NoMetadata(DataPoint):
+        text: str
+
+    dp = NoMetadata(text="test")
+    delattr(dp, "metadata")
+    text = _extract_embeddable_text_from_datapoint(dp)
+    assert text == ""
+
+
+def test_extract_embeddable_text_with_whitespace_only():
+    class WhitespaceField(DataPoint):
+        text: str
+        metadata: dict = {"index_fields": ["text"]}
+
+    dp = WhitespaceField(text="   ")
+    text = _extract_embeddable_text_from_datapoint(dp)
+    assert text == ""
+
+
+def test_create_triplets_skips_short_edge_tuples():
+    dp = SimplePoint(text="node")
+    incomplete_edge = (str(dp.id), str(dp.id))
+
+    triplets = _create_triplets_from_graph([dp], [incomplete_edge])
+
+    assert len(triplets) == 0
+
+
+def test_create_triplets_skips_missing_source_node():
+    dp1 = SimplePoint(text="target")
+    edge = ("missing_id", str(dp1.id), "relates", {})
+
+    triplets = _create_triplets_from_graph([dp1], [edge])
+
+    assert len(triplets) == 0
+
+
+def test_create_triplets_skips_missing_target_node():
+    dp1 = SimplePoint(text="source")
+    edge = (str(dp1.id), "missing_id", "relates", {})
+
+    triplets = _create_triplets_from_graph([dp1], [edge])
+
+    assert len(triplets) == 0
+
+
+def test_create_triplets_skips_none_relationship():
+    dp1 = SimplePoint(text="source")
+    dp2 = SimplePoint(text="target")
+    edge = (str(dp1.id), str(dp2.id), None, {})
+
+    triplets = _create_triplets_from_graph([dp1, dp2], [edge])
+
+    assert len(triplets) == 0
+
+
+def test_create_triplets_uses_relationship_name_when_no_edge_text():
+    dp1 = SimplePoint(text="source")
+    dp2 = SimplePoint(text="target")
+    edge = (str(dp1.id), str(dp2.id), "connects_to", {})
+
+    triplets = _create_triplets_from_graph([dp1, dp2], [edge])
+
+    assert len(triplets) == 1
+    assert "connects_to" in triplets[0].text
+
+
+def test_create_triplets_prevents_duplicates():
+    dp1 = SimplePoint(text="source")
+    dp2 = SimplePoint(text="target")
+    edge = (str(dp1.id), str(dp2.id), "relates", {"edge_text": "links"})
+
+    triplets = _create_triplets_from_graph([dp1, dp2], [edge, edge])
+
+    assert len(triplets) == 1
+
+
+def test_create_triplets_skips_nodes_without_id():
+    class NodeNoId:
+        pass
+
+    dp = SimplePoint(text="valid")
+    node_no_id = NodeNoId()
+    edge = (str(dp.id), "some_id", "relates", {})
+
+    triplets = _create_triplets_from_graph([dp, node_no_id], [edge])
+
+    assert len(triplets) == 0
+
+
+@pytest.mark.asyncio
+@patch.object(adp_module, "index_graph_edges")
+@patch.object(adp_module, "index_data_points")
+@patch.object(adp_module, "get_graph_engine")
+@patch.object(adp_module, "deduplicate_nodes_and_edges")
+@patch.object(adp_module, "get_graph_from_model")
+async def test_add_data_points_with_empty_custom_edges(
+    mock_get_graph, mock_dedup, mock_get_engine, mock_index_nodes, mock_index_edges
+):
+    dp = SimplePoint(text="test")
+    mock_get_graph.side_effect = [([dp], [])]
+    mock_dedup.side_effect = lambda n, e: (n, e)
+    graph_engine = AsyncMock()
+    mock_get_engine.return_value = graph_engine
+
+    result = await add_data_points([dp], custom_edges=[])
+
+    assert result == [dp]
+    assert graph_engine.add_edges.await_count == 1