diff --git a/.github/workflows/basic_tests.yml b/.github/workflows/basic_tests.yml index 13ea829da..456bc8c88 100644 --- a/.github/workflows/basic_tests.yml +++ b/.github/workflows/basic_tests.yml @@ -184,3 +184,5 @@ jobs: - name: Run Graph Tests run: poetry run python ./examples/python/code_graph_example.py --repo_path ./cognee/tasks/graph + + diff --git a/.github/workflows/graph_db_tests.yml b/.github/workflows/graph_db_tests.yml index 9dc60b11d..94ff639ba 100644 --- a/.github/workflows/graph_db_tests.yml +++ b/.github/workflows/graph_db_tests.yml @@ -50,6 +50,20 @@ jobs: EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} run: poetry run python ./cognee/tests/test_kuzu.py + - name: Run Weighted Edges Tests with Kuzu + env: + ENV: 'dev' + GRAPH_DATABASE_PROVIDER: "kuzu" + LLM_MODEL: ${{ secrets.LLM_MODEL }} + LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} + LLM_API_KEY: ${{ secrets.LLM_API_KEY }} + LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }} + EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }} + EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }} + EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} + EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} + run: poetry run pytest cognee/tests/unit/interfaces/graph/test_weighted_edges.py -v + run-neo4j-tests: name: Neo4j Tests runs-on: ubuntu-22.04 @@ -83,3 +97,20 @@ jobs: GRAPH_DATABASE_PASSWORD: ${{ secrets.NEO4J_API_KEY }} GRAPH_DATABASE_USERNAME: "neo4j" run: poetry run python ./cognee/tests/test_neo4j.py + + - name: Run Weighted Edges Tests with Neo4j + env: + ENV: 'dev' + GRAPH_DATABASE_PROVIDER: "neo4j" + GRAPH_DATABASE_URL: ${{ secrets.NEO4J_API_URL }} + GRAPH_DATABASE_PASSWORD: ${{ secrets.NEO4J_API_KEY }} + GRAPH_DATABASE_USERNAME: "neo4j" + LLM_MODEL: ${{ secrets.LLM_MODEL }} + LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} + LLM_API_KEY: ${{ secrets.LLM_API_KEY }} + LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }} + EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }} + EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }} + EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} + EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} + run: poetry run pytest cognee/tests/unit/interfaces/graph/test_weighted_edges.py -v diff --git a/.github/workflows/weighted_edges_tests.yml b/.github/workflows/weighted_edges_tests.yml new file mode 100644 index 000000000..d33bd1ceb --- /dev/null +++ b/.github/workflows/weighted_edges_tests.yml @@ -0,0 +1,166 @@ +name: Weighted Edges Tests + +on: + push: + branches: [ main, weighted_edges ] + paths: + - 'cognee/modules/graph/utils/get_graph_from_model.py' + - 'cognee/infrastructure/engine/models/Edge.py' + - 'cognee/tests/unit/interfaces/graph/test_weighted_edges.py' + - 'examples/python/weighted_edges_example.py' + - '.github/workflows/weighted_edges_tests.yml' + pull_request: + branches: [ main ] + paths: + - 'cognee/modules/graph/utils/get_graph_from_model.py' + - 'cognee/infrastructure/engine/models/Edge.py' + - 'cognee/tests/unit/interfaces/graph/test_weighted_edges.py' + - 'examples/python/weighted_edges_example.py' + - '.github/workflows/weighted_edges_tests.yml' + +env: + RUNTIME__LOG_LEVEL: ERROR + ENV: 'dev' + +jobs: + test-weighted-edges-functionality: + name: Test Weighted Edges Core Functionality + runs-on: ubuntu-22.04 + strategy: + matrix: + python-version: ['3.11', '3.12'] + env: + LLM_PROVIDER: openai + LLM_MODEL: gpt-4o-mini + LLM_API_KEY: ${{ secrets.LLM_API_KEY }} + + steps: + - name: Check out repository + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Cognee Setup + uses: ./.github/actions/cognee_setup + with: + python-version: ${{ matrix.python-version }} + + - name: Run Weighted Edges Unit Tests + run: | + poetry run pytest cognee/tests/unit/interfaces/graph/test_weighted_edges.py -v --tb=short + + - name: Run Standard Graph Tests (Regression) + run: | + poetry run pytest cognee/tests/unit/interfaces/graph/get_graph_from_model_unit_test.py -v --tb=short + + test-with-different-databases: + name: Test Weighted Edges with Different Graph Databases + runs-on: ubuntu-22.04 + strategy: + matrix: + database: ['kuzu', 'neo4j'] + include: + - database: kuzu + install_extra: "" + graph_db_provider: "kuzu" + - database: neo4j + install_extra: "-E neo4j" + graph_db_provider: "neo4j" + env: + LLM_PROVIDER: openai + LLM_MODEL: gpt-4o-mini + LLM_ENDPOINT: https://api.openai.com/v1/ + LLM_API_KEY: ${{ secrets.LLM_API_KEY }} + LLM_API_VERSION: "2024-02-01" + EMBEDDING_PROVIDER: openai + EMBEDDING_MODEL: text-embedding-3-small + EMBEDDING_ENDPOINT: https://api.openai.com/v1/ + EMBEDDING_API_KEY: ${{ secrets.LLM_API_KEY }} + EMBEDDING_API_VERSION: "2024-02-01" + steps: + - name: Check out repository + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Cognee Setup + uses: ./.github/actions/cognee_setup + with: + python-version: '3.11' + + - name: Install Database Dependencies + run: | + poetry install ${{ matrix.install_extra }} + + - name: Run Weighted Edges Tests + env: + GRAPH_DATABASE_PROVIDER: ${{ matrix.graph_db_provider }} + run: | + poetry run pytest cognee/tests/unit/interfaces/graph/test_weighted_edges.py -v --tb=short + + test-examples: + name: Test Weighted Edges Examples + runs-on: ubuntu-22.04 + env: + LLM_PROVIDER: openai + LLM_MODEL: gpt-4o-mini + LLM_ENDPOINT: https://api.openai.com/v1/ + LLM_API_KEY: ${{ secrets.LLM_API_KEY }} + LLM_API_VERSION: "2024-02-01" + EMBEDDING_PROVIDER: openai + EMBEDDING_MODEL: text-embedding-3-small + EMBEDDING_ENDPOINT: https://api.openai.com/v1/ + EMBEDDING_API_KEY: ${{ secrets.LLM_API_KEY }} + EMBEDDING_API_VERSION: "2024-02-01" + steps: + - name: Check out repository + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Cognee Setup + uses: ./.github/actions/cognee_setup + with: + python-version: '3.11' + + - name: Test Weighted Edges Example + run: | + poetry run python examples/python/weighted_edges_example.py + + - name: Verify Visualization File Created + run: | + if [ -f "examples/python/weighted_graph_visualization.html" ]; then + echo "✅ Visualization file created successfully" + ls -la examples/python/weighted_graph_visualization.html + else + echo "❌ Visualization file not found" + exit 1 + fi + + + + code-quality: + name: Code Quality for Weighted Edges + runs-on: ubuntu-22.04 + steps: + - name: Check out repository + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Cognee Setup + uses: ./.github/actions/cognee_setup + with: + python-version: '3.11' + + - name: Run Linting on Weighted Edges Files + uses: astral-sh/ruff-action@v2 + with: + args: "check cognee/modules/graph/utils/get_graph_from_model.py cognee/tests/unit/interfaces/graph/test_weighted_edges.py examples/python/weighted_edges_example.py" + + - name: Run Formatting Check on Weighted Edges Files + uses: astral-sh/ruff-action@v2 + with: + args: "format --check cognee/modules/graph/utils/get_graph_from_model.py cognee/tests/unit/interfaces/graph/test_weighted_edges.py examples/python/weighted_edges_example.py" + + \ No newline at end of file diff --git a/cognee/infrastructure/engine/__init__.py b/cognee/infrastructure/engine/__init__.py index cc084d142..e1328a655 100644 --- a/cognee/infrastructure/engine/__init__.py +++ b/cognee/infrastructure/engine/__init__.py @@ -1,2 +1,3 @@ from .models.DataPoint import DataPoint from .models.ExtendableDataPoint import ExtendableDataPoint +from .models.Edge import Edge diff --git a/cognee/infrastructure/engine/models/DataPoint.py b/cognee/infrastructure/engine/models/DataPoint.py index 3b5f7424a..812380eaa 100644 --- a/cognee/infrastructure/engine/models/DataPoint.py +++ b/cognee/infrastructure/engine/models/DataPoint.py @@ -1,6 +1,6 @@ import pickle from uuid import UUID, uuid4 -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, ConfigDict from datetime import datetime, timezone from typing_extensions import TypedDict from typing import Optional, Any, Dict, List @@ -34,6 +34,8 @@ class DataPoint(BaseModel): - from_dict """ + model_config = ConfigDict(arbitrary_types_allowed=True) + id: UUID = Field(default_factory=uuid4) created_at: int = Field( default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000) diff --git a/cognee/infrastructure/engine/models/Edge.py b/cognee/infrastructure/engine/models/Edge.py new file mode 100644 index 000000000..5ad9c84dd --- /dev/null +++ b/cognee/infrastructure/engine/models/Edge.py @@ -0,0 +1,26 @@ +from pydantic import BaseModel +from typing import Optional, Any, Dict + + +class Edge(BaseModel): + """ + Represents edge metadata for relationships between DataPoints. + + This class is used to define edge properties like weight when creating + relationships between DataPoints using tuple syntax: + + Example: + # Single weight (backward compatible) + has_items: (Edge(weight=0.5), list[Item]) + + # Multiple weights + has_items: (Edge(weights={"strength": 0.8, "confidence": 0.9, "importance": 0.7}), list[Item]) + + # Mixed usage + has_items: (Edge(weight=0.5, weights={"confidence": 0.9}), list[Item]) + """ + + weight: Optional[float] = None + weights: Optional[Dict[str, float]] = None + relationship_type: Optional[str] = None + properties: Optional[Dict[str, Any]] = None diff --git a/cognee/infrastructure/llm/generic_llm_api/adapter.py b/cognee/infrastructure/llm/generic_llm_api/adapter.py index 3d423da6e..c1cc40e38 100644 --- a/cognee/infrastructure/llm/generic_llm_api/adapter.py +++ b/cognee/infrastructure/llm/generic_llm_api/adapter.py @@ -1,5 +1,6 @@ """Adapter for Generic API LLM provider API""" +import logging import litellm import instructor from typing import Type diff --git a/cognee/modules/graph/utils/expand_with_nodes_and_edges.py b/cognee/modules/graph/utils/expand_with_nodes_and_edges.py index d3a0a8522..c1f55d4fc 100644 --- a/cognee/modules/graph/utils/expand_with_nodes_and_edges.py +++ b/cognee/modules/graph/utils/expand_with_nodes_and_edges.py @@ -281,7 +281,40 @@ def expand_with_nodes_and_edges( existing_edges_map: Optional[dict[str, bool]] = None, ): """ - Expand data chunks with nodes and edges, validating against ontology. + + - LLM generated docstring + Expand knowledge graphs with validated nodes and edges, integrating ontology information. + + This function processes document chunks and their associated knowledge graphs to create + a comprehensive graph structure with entity nodes, entity type nodes, and their relationships. + It validates entities against an ontology resolver and adds ontology-derived nodes and edges + to enhance the knowledge representation. + + Args: + data_chunks (list[DocumentChunk]): List of document chunks that contain the source data. + Each chunk should have metadata about what entities it contains. + chunk_graphs (list[KnowledgeGraph]): List of knowledge graphs corresponding to each + data chunk. Each graph contains nodes (entities) and edges (relationships) extracted + from the chunk content. + ontology_resolver (OntologyResolver, optional): Resolver for validating entities and + types against an ontology. If None, a default OntologyResolver is created. + Defaults to None. + existing_edges_map (dict[str, bool], optional): Mapping of existing edge keys to prevent + duplicate edge creation. Keys are formatted as "{source_id}_{target_id}_{relation}". + If None, an empty dictionary is created. Defaults to None. + + Returns: + tuple[list, list]: A tuple containing: + - graph_nodes (list): Combined list of data chunks and ontology nodes (EntityType and Entity objects) + - graph_edges (list): List of edge tuples in format (source_id, target_id, relationship_name, properties) + + Note: + - Entity nodes are created for each entity found in the knowledge graphs + - EntityType nodes are created for each unique entity type + - Ontology validation is performed to map entities to canonical ontology terms + - Duplicate nodes and edges are prevented using internal mapping and the existing_edges_map + - The function modifies data_chunks in-place by adding entities to their 'contains' attribute + """ if existing_edges_map is None: existing_edges_map = {} diff --git a/cognee/modules/graph/utils/get_graph_from_model.py b/cognee/modules/graph/utils/get_graph_from_model.py index a6d50f41c..6fcca4572 100644 --- a/cognee/modules/graph/utils/get_graph_from_model.py +++ b/cognee/modules/graph/utils/get_graph_from_model.py @@ -1,134 +1,256 @@ from datetime import datetime, timezone -from cognee.infrastructure.engine import DataPoint +from typing import Tuple, List, Any, Dict, Optional +from cognee.infrastructure.engine import DataPoint, Edge from cognee.modules.storage.utils import copy_model +def _extract_field_info(field_value: Any) -> Tuple[str, Any, Optional[Edge]]: + """Extract field type, actual value, and edge metadata from a field value.""" + + # Handle tuple[Edge, DataPoint] + if ( + isinstance(field_value, tuple) + and len(field_value) == 2 + and isinstance(field_value[0], Edge) + and isinstance(field_value[1], DataPoint) + ): + return "single_datapoint_with_edge", field_value[1], field_value[0] + + # Handle tuple[Edge, list[DataPoint]] + if ( + isinstance(field_value, tuple) + and len(field_value) == 2 + and isinstance(field_value[0], Edge) + and isinstance(field_value[1], list) + and len(field_value[1]) > 0 + and isinstance(field_value[1][0], DataPoint) + ): + return "list_datapoint_with_edge", field_value[1], field_value[0] + + # Handle single DataPoint + if isinstance(field_value, DataPoint): + return "single_datapoint", field_value, None + + # Handle list of DataPoints + if ( + isinstance(field_value, list) + and len(field_value) > 0 + and isinstance(field_value[0], DataPoint) + ): + return "list_datapoint", field_value, None + + # Regular property + return "property", field_value, None + + +def _create_edge_properties( + source_id: str, target_id: str, relationship_name: str, edge_metadata: Optional[Edge] +) -> Dict[str, Any]: + """Create edge properties dictionary with metadata if present.""" + properties = { + "source_node_id": source_id, + "target_node_id": target_id, + "relationship_name": relationship_name, + "updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"), + } + + if edge_metadata: + # Add edge metadata + edge_data = edge_metadata.model_dump(exclude_none=True) + properties.update(edge_data) + + # Add individual weights as separate fields for easier querying + if edge_metadata.weights is not None: + for weight_name, weight_value in edge_metadata.weights.items(): + properties[f"weight_{weight_name}"] = weight_value + + return properties + + +def _get_relationship_key(field_name: str, edge_metadata: Optional[Edge]) -> str: + """Extract relationship key from edge metadata or use field name as fallback.""" + if ( + edge_metadata + and hasattr(edge_metadata, "relationship_type") + and edge_metadata.relationship_type + ): + return edge_metadata.relationship_type + return field_name + + +def _generate_property_key(data_point_id: str, relationship_key: str, target_id: str) -> str: + """Generate a unique property key for visited_properties tracking.""" + return f"{data_point_id}{relationship_key}{target_id}" + + +def _process_datapoint_field( + data_point_id: str, + field_name: str, + datapoints: List[DataPoint], + edge_metadata: Optional[Edge], + visited_properties: Dict[str, bool], + properties_to_visit: set, + excluded_properties: set, +) -> None: + """Process a field containing DataPoint(s), handling both single and list cases.""" + excluded_properties.add(field_name) + relationship_key = _get_relationship_key(field_name, edge_metadata) + + for index, datapoint in enumerate(datapoints): + property_key = _generate_property_key(data_point_id, relationship_key, str(datapoint.id)) + if property_key in visited_properties: + continue + + # For single datapoint, use field_name; for list, use field_name.index + field_identifier = field_name if len(datapoints) == 1 else f"{field_name}.{index}" + properties_to_visit.add(field_identifier) + + async def get_graph_from_model( data_point: DataPoint, - added_nodes: dict, - added_edges: dict, - visited_properties: dict = None, - include_root=True, -): + added_nodes: Dict[str, bool], + added_edges: Dict[str, bool], + visited_properties: Optional[Dict[str, bool]] = None, + include_root: bool = True, +) -> Tuple[List[DataPoint], List[Tuple[str, str, str, Dict[str, Any]]]]: + """ + Extract graph representation from a DataPoint model. + + Args: + data_point: The DataPoint to extract graph from + added_nodes: Dictionary tracking already processed nodes + added_edges: Dictionary tracking already processed edges + visited_properties: Dictionary tracking visited properties to avoid cycles + include_root: Whether to include the root node in results + + Returns: + Tuple of (nodes, edges) extracted from the model + """ if str(data_point.id) in added_nodes: return [], [] nodes = [] edges = [] visited_properties = visited_properties or {} + data_point_id = str(data_point.id) - data_point_properties = { - "type": type(data_point).__name__, - } + data_point_properties = {"type": type(data_point).__name__} excluded_properties = set() properties_to_visit = set() + # Analyze all fields to categorize them as properties or relationships for field_name, field_value in data_point: if field_name == "metadata": continue - if isinstance(field_value, DataPoint): - excluded_properties.add(field_name) + field_type, actual_value, edge_metadata = _extract_field_info(field_value) - property_key = str(data_point.id) + field_name + str(field_value.id) + if field_type == "property": + data_point_properties[field_name] = field_value + elif field_type in ["single_datapoint", "single_datapoint_with_edge"]: + _process_datapoint_field( + data_point_id, + field_name, + [actual_value], + edge_metadata, + visited_properties, + properties_to_visit, + excluded_properties, + ) + elif field_type in ["list_datapoint", "list_datapoint_with_edge"]: + _process_datapoint_field( + data_point_id, + field_name, + actual_value, + edge_metadata, + visited_properties, + properties_to_visit, + excluded_properties, + ) - if property_key in visited_properties: - continue - - visited_properties[property_key] = True - properties_to_visit.add(field_name) - - continue - - if ( - isinstance(field_value, list) - and len(field_value) > 0 - and isinstance(field_value[0], DataPoint) - ): - excluded_properties.add(field_name) - - for index, item in enumerate(field_value): - property_key = str(data_point.id) + field_name + str(item.id) - - if property_key in visited_properties: - continue - - visited_properties[property_key] = True - properties_to_visit.add(f"{field_name}.{index}") - - continue - - data_point_properties[field_name] = field_value - - if include_root and str(data_point.id) not in added_nodes: + # Create node for current DataPoint if needed + if include_root and data_point_id not in added_nodes: SimpleDataPointModel = copy_model( - type(data_point), - exclude_fields=list(excluded_properties), + type(data_point), exclude_fields=list(excluded_properties) ) nodes.append(SimpleDataPointModel(**data_point_properties)) - added_nodes[str(data_point.id)] = True + added_nodes[data_point_id] = True - for field_name in properties_to_visit: - index = None - - if "." in field_name: - field_name, index = field_name.split(".") + # Process all relationships + for field_name_with_index in properties_to_visit: + # Parse field name and index + if "." in field_name_with_index: + field_name, index_str = field_name_with_index.split(".") + index = int(index_str) + else: + field_name, index = field_name_with_index, None + # Get field value and extract edge metadata field_value = getattr(data_point, field_name) + edge_metadata = None + if ( + isinstance(field_value, tuple) + and len(field_value) == 2 + and isinstance(field_value[0], Edge) + ): + edge_metadata, field_value = field_value + + # Get specific datapoint - handle both single and list cases if index is not None: - field_value = field_value[int(index)] + # List case: extract specific item by index + target_datapoint = field_value[index] + elif isinstance(field_value, list): + # Single datapoint case that was wrapped in a list + target_datapoint = field_value[0] + else: + # True single datapoint case + target_datapoint = field_value - edge_key = str(data_point.id) + str(field_value.id) + field_name - - if str(edge_key) not in added_edges: - edges.append( - ( - data_point.id, - field_value.id, - field_name, - { - "source_node_id": data_point.id, - "target_node_id": field_value.id, - "relationship_name": field_name, - "updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"), - }, - ) + # Create edge if not already added + edge_key = f"{data_point_id}{target_datapoint.id}{field_name}" + if edge_key not in added_edges: + relationship_name = _get_relationship_key(field_name, edge_metadata) + edge_properties = _create_edge_properties( + data_point.id, target_datapoint.id, relationship_name, edge_metadata ) - added_edges[str(edge_key)] = True + edges.append((data_point.id, target_datapoint.id, relationship_name, edge_properties)) + added_edges[edge_key] = True - if str(field_value.id) in added_nodes: - continue - - property_key = str(data_point.id) + field_name + str(field_value.id) + # Mark property as visited - CRITICAL for preventing infinite loops + relationship_key = _get_relationship_key(field_name, edge_metadata) + property_key = _generate_property_key( + data_point_id, relationship_key, str(target_datapoint.id) + ) visited_properties[property_key] = True - property_nodes, property_edges = await get_graph_from_model( - field_value, - include_root=True, - added_nodes=added_nodes, - added_edges=added_edges, - visited_properties=visited_properties, - ) - - for node in property_nodes: - nodes.append(node) - - for edge in property_edges: - edges.append(edge) + # Recursively process target node if not already processed + if str(target_datapoint.id) not in added_nodes: + child_nodes, child_edges = await get_graph_from_model( + target_datapoint, + include_root=True, + added_nodes=added_nodes, + added_edges=added_edges, + visited_properties=visited_properties, + ) + nodes.extend(child_nodes) + edges.extend(child_edges) return nodes, edges -def get_own_property_nodes(property_nodes, property_edges): - own_properties = [] +def get_own_property_nodes( + property_nodes: List[DataPoint], property_edges: List[Tuple[str, str, str, Dict[str, Any]]] +) -> List[DataPoint]: + """ + Filter nodes to return only those that are not destinations of any edges. - destination_nodes = [str(property_edge[1]) for property_edge in property_edges] + Args: + property_nodes: List of all nodes + property_edges: List of all edges - for node in property_nodes: - if str(node.id) in destination_nodes: - continue - - own_properties.append(node) - - return own_properties + Returns: + List of nodes that are not edge destinations + """ + destination_node_ids = {str(edge[1]) for edge in property_edges} + return [node for node in property_nodes if str(node.id) not in destination_node_ids] diff --git a/cognee/modules/graph/utils/retrieve_existing_edges.py b/cognee/modules/graph/utils/retrieve_existing_edges.py index 13c3490a8..20cb30a26 100644 --- a/cognee/modules/graph/utils/retrieve_existing_edges.py +++ b/cognee/modules/graph/utils/retrieve_existing_edges.py @@ -8,6 +8,39 @@ async def retrieve_existing_edges( data_chunks: list[DataPoint], chunk_graphs: list[KnowledgeGraph], ) -> dict[str, bool]: + """ + - LLM generated docstring + Retrieve existing edges from the graph database to prevent duplicate edge creation. + + This function checks which edges already exist in the graph database by querying + for various types of relationships including structural edges (exists_in, mentioned_in, is_a) + and content-derived edges from the knowledge graphs. It returns a mapping that can be + used to avoid creating duplicate edges during graph expansion. + + Args: + data_chunks (list[DataPoint]): List of data point objects that serve as containers + for the entities. Each data chunk represents a source document or data segment. + chunk_graphs (list[KnowledgeGraph]): List of knowledge graphs corresponding to each + data chunk. Each graph contains nodes (entities) and edges (relationships) that + were extracted from the chunk content. + graph_engine (GraphDBInterface): Interface to the graph database that will be queried + to check for existing edges. Must implement the has_edges() method. + + Returns: + dict[str, bool]: A mapping of edge keys to boolean values indicating existence. + Edge keys are formatted as concatenated strings: "{source_id}{target_id}{relationship_name}". + All values in the returned dictionary are True (indicating the edge exists). + + Note: + - The function generates several types of edges for checking: + * Type node edges: (chunk_id, type_node_id, "exists_in") + * Entity node edges: (chunk_id, entity_node_id, "mentioned_in") + * Type-entity edges: (entity_node_id, type_node_id, "is_a") + * Graph node edges: extracted from the knowledge graph relationships + - Uses generate_node_id() to ensure consistent node ID formatting + - Prevents processing the same node multiple times using a processed_nodes tracker + - The returned mapping can be used with expand_with_nodes_and_edges() to avoid duplicates + """ processed_nodes = {} type_node_edges = [] entity_node_edges = [] diff --git a/cognee/modules/storage/utils/__init__.py b/cognee/modules/storage/utils/__init__.py index 1008d159a..f240798f8 100644 --- a/cognee/modules/storage/utils/__init__.py +++ b/cognee/modules/storage/utils/__init__.py @@ -3,7 +3,7 @@ from uuid import UUID from decimal import Decimal from datetime import datetime from pydantic_core import PydanticUndefined -from pydantic import create_model +from pydantic import create_model, ConfigDict, BaseModel from cognee.infrastructure.engine import DataPoint @@ -29,9 +29,15 @@ def copy_model(model: DataPoint, include_fields: dict = {}, exclude_fields: list final_fields = {**fields, **include_fields} - model = create_model(model.__name__, **final_fields) - model.model_rebuild() - return model + # Create a base class with the same configuration as DataPoint + class ConfiguredBase(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + # Create the model inheriting from the configured base + new_model = create_model(model.__name__, __base__=ConfiguredBase, **final_fields) + + new_model.model_rebuild() + return new_model def get_own_properties(data_point: DataPoint): diff --git a/cognee/modules/visualization/cognee_network_visualization.py b/cognee/modules/visualization/cognee_network_visualization.py index 83a0f237e..0328c5c3e 100644 --- a/cognee/modules/visualization/cognee_network_visualization.py +++ b/cognee/modules/visualization/cognee_network_visualization.py @@ -53,7 +53,40 @@ async def cognee_network_visualization(graph_data, destination_file_path: str = target = str(target) G.add_edge(source, target) edge_labels[(source, target)] = relation - links_list.append({"source": source, "target": target, "relation": relation}) + + # Extract edge metadata including all weights + all_weights = {} + primary_weight = None + + if edge_info: + # Single weight (backward compatibility) + if "weight" in edge_info: + all_weights["default"] = edge_info["weight"] + primary_weight = edge_info["weight"] + + # Multiple weights + if "weights" in edge_info and isinstance(edge_info["weights"], dict): + all_weights.update(edge_info["weights"]) + # Use the first weight as primary for visual thickness if no default weight + if primary_weight is None and edge_info["weights"]: + primary_weight = next(iter(edge_info["weights"].values())) + + # Individual weight fields (weight_strength, weight_confidence, etc.) + for key, value in edge_info.items(): + if key.startswith("weight_") and isinstance(value, (int, float)): + weight_name = key[7:] # Remove "weight_" prefix + all_weights[weight_name] = value + + link_data = { + "source": source, + "target": target, + "relation": relation, + "weight": primary_weight, # Primary weight for backward compatibility + "all_weights": all_weights, # All weights for display + "relationship_type": edge_info.get("relationship_type") if edge_info else None, + "edge_info": edge_info if edge_info else {}, + } + links_list.append(link_data) html_template = """ @@ -66,13 +99,33 @@ async def cognee_network_visualization(graph_data, destination_file_path: str = svg { width: 100vw; height: 100vh; display: block; } .links line { stroke: rgba(255, 255, 255, 0.4); stroke-width: 2px; } + .links line.weighted { stroke: rgba(255, 215, 0, 0.7); } + .links line.multi-weighted { stroke: rgba(0, 255, 127, 0.8); } .nodes circle { stroke: white; stroke-width: 0.5px; filter: drop-shadow(0 0 5px rgba(255,255,255,0.3)); } .node-label { font-size: 5px; font-weight: bold; fill: white; text-anchor: middle; dominant-baseline: middle; font-family: 'Inter', sans-serif; pointer-events: none; } .edge-label { font-size: 3px; fill: rgba(255, 255, 255, 0.7); text-anchor: middle; dominant-baseline: middle; font-family: 'Inter', sans-serif; pointer-events: none; } + + .tooltip { + position: absolute; + text-align: left; + padding: 8px; + font-size: 12px; + background: rgba(0, 0, 0, 0.9); + color: white; + border: 1px solid rgba(255, 255, 255, 0.3); + border-radius: 4px; + pointer-events: none; + opacity: 0; + transition: opacity 0.2s; + z-index: 1000; + max-width: 300px; + word-wrap: break-word; + } +
- + diff --git a/cognee/tests/unit/interfaces/graph/test_weighted_edges.py b/cognee/tests/unit/interfaces/graph/test_weighted_edges.py new file mode 100644 index 000000000..0b83d2b8d --- /dev/null +++ b/cognee/tests/unit/interfaces/graph/test_weighted_edges.py @@ -0,0 +1,391 @@ +import pytest +from typing import List, Any +from pydantic import SkipValidation +from cognee.infrastructure.engine import DataPoint +from cognee.infrastructure.engine.models.Edge import Edge +from cognee.modules.graph.utils import get_graph_from_model + + +class Product(DataPoint): + name: str + description: str + metadata: dict = {"index_fields": ["name"]} + + +class Category(DataPoint): + name: str + description: str + products: List[Product] = [] + metadata: dict = {"index_fields": ["name"]} + + +class User(DataPoint): + name: str + email: str + # Weighted relationships + purchased_products: SkipValidation[Any] = None # (Edge, list[Product]) + favorite_categories: SkipValidation[Any] = None # (Edge, list[Category]) + follows: SkipValidation[Any] = None # (Edge, list["User"]) + metadata: dict = {"index_fields": ["name", "email"]} + + +class Company(DataPoint): + name: str + description: str + employees: SkipValidation[Any] = None # (Edge, list[User]) + partners: SkipValidation[Any] = None # (Edge, list["Company"]) + metadata: dict = {"index_fields": ["name"]} + + +@pytest.mark.asyncio +async def test_single_weight_edge(): + """Test get_graph_from_model with single weight edges (backward compatible)""" + + product1 = Product(name="Laptop", description="Gaming laptop") + product2 = Product(name="Mouse", description="Wireless mouse") + + user = User( + name="John Doe", + email="john@example.com", + purchased_products=(Edge(weight=0.8, relationship_type="purchased"), [product1, product2]), + ) + + added_nodes = {} + added_edges = {} + visited_properties = {} + + nodes, edges = await get_graph_from_model(user, added_nodes, added_edges, visited_properties) + + # Should have user + 2 products = 3 nodes + assert len(nodes) == 3, f"Expected 3 nodes, got {len(nodes)}" + # Should have 2 edges (user -> product1, user -> product2) + assert len(edges) == 2, f"Expected 2 edges, got {len(edges)}" + + # Check edge properties contain weight + for edge in edges: + source_id, target_id, relationship_name, edge_properties = edge + assert "weight" in edge_properties, "Edge should contain weight property" + assert edge_properties["weight"] == 0.8, ( + f"Expected weight 0.8, got {edge_properties['weight']}" + ) + assert edge_properties["relationship_name"] == "purchased" + + +@pytest.mark.asyncio +async def test_multiple_weights_edge(): + """Test get_graph_from_model with multiple weights on edges""" + + category1 = Category(name="Electronics", description="Electronic products") + category2 = Category(name="Gaming", description="Gaming products") + + user = User( + name="Alice Smith", + email="alice@example.com", + favorite_categories=( + Edge( + weights={ + "interest_level": 0.9, + "time_spent": 0.7, + "purchase_frequency": 0.8, + "expertise": 0.6, + }, + relationship_type="interested_in", + ), + [category1, category2], + ), + ) + + added_nodes = {} + added_edges = {} + visited_properties = {} + + nodes, edges = await get_graph_from_model(user, added_nodes, added_edges, visited_properties) + + # Should have user + 2 categories = 3 nodes + assert len(nodes) == 3, f"Expected 3 nodes, got {len(nodes)}" + # Should have 2 edges + assert len(edges) == 2, f"Expected 2 edges, got {len(edges)}" + + # Check edge properties contain multiple weights + for edge in edges: + source_id, target_id, relationship_name, edge_properties = edge + assert edge_properties["relationship_name"] == "interested_in" + + # Check individual weight fields + assert "weight_interest_level" in edge_properties + assert "weight_time_spent" in edge_properties + assert "weight_purchase_frequency" in edge_properties + assert "weight_expertise" in edge_properties + + assert edge_properties["weight_interest_level"] == 0.9 + assert edge_properties["weight_time_spent"] == 0.7 + assert edge_properties["weight_purchase_frequency"] == 0.8 + assert edge_properties["weight_expertise"] == 0.6 + + +@pytest.mark.asyncio +async def test_mixed_single_and_multiple_weights(): + """Test get_graph_from_model with both single weight and multiple weights on same edge""" + + product = Product(name="Smartphone", description="Latest smartphone") + + user = User( + name="Bob Wilson", + email="bob@example.com", + purchased_products=( + Edge( + weight=0.7, # Single weight (backward compatible) + weights={"satisfaction": 0.9, "value_for_money": 0.6}, # Multiple weights + relationship_type="owns", + ), + [product], + ), + ) + + added_nodes = {} + added_edges = {} + visited_properties = {} + + nodes, edges = await get_graph_from_model(user, added_nodes, added_edges, visited_properties) + + assert len(nodes) == 2, f"Expected 2 nodes, got {len(nodes)}" + assert len(edges) == 1, f"Expected 1 edge, got {len(edges)}" + + edge_properties = edges[0][3] + + # Should have both single weight and multiple weights + assert "weight" in edge_properties, "Should have backward compatible weight field" + assert edge_properties["weight"] == 0.7 + + assert "weight_satisfaction" in edge_properties + assert "weight_value_for_money" in edge_properties + assert edge_properties["weight_satisfaction"] == 0.9 + assert edge_properties["weight_value_for_money"] == 0.6 + + +@pytest.mark.asyncio +async def test_complex_weighted_relationships(): + """Test complex scenario with multiple entities and various weighted relationships""" + + # Create products and categories + product1 = Product(name="Gaming Chair", description="Ergonomic gaming chair") + product2 = Product(name="Mechanical Keyboard", description="RGB mechanical keyboard") + + category = Category(name="Gaming Accessories", description="Gaming accessories category") + category.products = [product1, product2] + + # Create users with different weighted relationships + user1 = User( + name="Gamer Pro", + email="gamerpro@example.com", + purchased_products=( + Edge( + weights={ + "satisfaction": 0.95, + "frequency_of_use": 0.9, + "recommendation_likelihood": 0.8, + }, + relationship_type="purchased", + ), + [product1, product2], + ), + favorite_categories=(Edge(weight=0.9, relationship_type="follows"), [category]), + ) + + user2 = User( + name="Casual User", + email="casual@example.com", + purchased_products=(Edge(weight=0.6, relationship_type="purchased"), [product1]), + ) + + # Create weighted user relationships + user1.follows = ( + Edge( + weights={ + "friendship_level": 0.7, + "shared_interests": 0.8, + "communication_frequency": 0.5, + }, + relationship_type="follows", + ), + [user2], + ) + + added_nodes = {} + added_edges = {} + visited_properties = {} + + # Process user1 (which should process all connected nodes) + nodes, edges = await get_graph_from_model(user1, added_nodes, added_edges, visited_properties) + + # Should have: user1, user2, 2 products, 1 category = 5 nodes + assert len(nodes) == 5, f"Expected 5 nodes, got {len(nodes)}" + + # Should have multiple edges with different weight configurations + assert len(edges) > 0, "Should have edges" + + # Verify that different edge types are created correctly + edge_types = set() + weighted_edges = 0 + multi_weighted_edges = 0 + + for edge in edges: + source_id, target_id, relationship_name, edge_properties = edge + edge_types.add(relationship_name) + + if "weight" in edge_properties: + weighted_edges += 1 + + # Count edges with multiple weights + multi_weight_fields = [k for k in edge_properties.keys() if k.startswith("weight_")] + if len(multi_weight_fields) > 1: + multi_weighted_edges += 1 + + assert "purchased" in edge_types + assert "follows" in edge_types + assert weighted_edges > 0, "Should have edges with weights" + assert multi_weighted_edges > 0, "Should have edges with multiple weights" + + +@pytest.mark.asyncio +async def test_company_hierarchy_with_weights(): + """Test hierarchical company structure with weighted relationships""" + + # Create users + ceo = User(name="CEO", email="ceo@company.com") + manager = User(name="Manager", email="manager@company.com") + developer = User(name="Developer", email="dev@company.com") + + # Create companies with weighted employee relationships + startup = Company( + name="Tech Startup", + description="Innovative tech startup", + employees=( + Edge( + weights={"seniority": 0.9, "performance": 0.8, "leadership": 0.95}, + relationship_type="employs", + ), + [ceo, manager, developer], + ), + ) + + corporation = Company(name="Big Corp", description="Large corporation") + + # Create partnership with weights + startup.partners = ( + Edge( + weights={"trust_level": 0.7, "business_value": 0.8, "strategic_importance": 0.6}, + relationship_type="partners_with", + ), + [corporation], + ) + + added_nodes = {} + added_edges = {} + visited_properties = {} + + nodes, edges = await get_graph_from_model(startup, added_nodes, added_edges, visited_properties) + + # Should have: startup, corporation, 3 users = 5 nodes + assert len(nodes) == 5, f"Expected 5 nodes, got {len(nodes)}" + + # Verify weighted relationships are properly stored + partnership_edges = [e for e in edges if e[2] == "partners_with"] + employee_edges = [e for e in edges if e[2] == "employs"] + + assert len(partnership_edges) == 1, "Should have one partnership edge" + assert len(employee_edges) == 3, "Should have three employee edges" + + # Check partnership edge weights + partnership_props = partnership_edges[0][3] + assert "weight_trust_level" in partnership_props + assert "weight_business_value" in partnership_props + assert "weight_strategic_importance" in partnership_props + + # Check employee edge weights + for edge in employee_edges: + props = edge[3] + assert "weight_seniority" in props + assert "weight_performance" in props + assert "weight_leadership" in props + + +@pytest.mark.asyncio +async def test_edge_metadata_preservation(): + """Test that all edge metadata is preserved correctly in weighted edges""" + + product = Product(name="Test Product", description="A test product") + + user = User( + name="Test User", + email="test@example.com", + purchased_products=( + Edge(weight=0.8, weights={"quality": 0.9, "price": 0.7}, relationship_type="purchased"), + [product], + ), + ) + + added_nodes = {} + added_edges = {} + visited_properties = {} + + nodes, edges = await get_graph_from_model(user, added_nodes, added_edges, visited_properties) + + assert len(edges) == 1, "Should have exactly one edge" + + edge_properties = edges[0][3] + + # Check all required metadata is present + assert "source_node_id" in edge_properties + assert "target_node_id" in edge_properties + assert "relationship_name" in edge_properties + assert "updated_at" in edge_properties + + # Check relationship type + assert edge_properties["relationship_name"] == "purchased" + + # Check weights are properly stored + assert "weight" in edge_properties + assert edge_properties["weight"] == 0.8 + + assert "weight_quality" in edge_properties + assert edge_properties["weight_quality"] == 0.9 + + assert "weight_price" in edge_properties + assert edge_properties["weight_price"] == 0.7 + + +@pytest.mark.asyncio +async def test_no_weights_edge(): + """Test that edges without weights still work correctly""" + + product = Product(name="Simple Product", description="No weights product") + + user = User( + name="Simple User", + email="simple@example.com", + purchased_products=(Edge(relationship_type="purchased"), [product]), + ) + + added_nodes = {} + added_edges = {} + visited_properties = {} + + nodes, edges = await get_graph_from_model(user, added_nodes, added_edges, visited_properties) + + assert len(nodes) == 2, f"Expected 2 nodes, got {len(nodes)}" + assert len(edges) == 1, f"Expected 1 edge, got {len(edges)}" + + edge_properties = edges[0][3] + + # Should have basic metadata but no weights + assert "source_node_id" in edge_properties + assert "target_node_id" in edge_properties + assert "relationship_name" in edge_properties + assert "updated_at" in edge_properties + assert edge_properties["relationship_name"] == "purchased" + + # Should not have weight fields + assert "weight" not in edge_properties + weight_fields = [k for k in edge_properties.keys() if k.startswith("weight_")] + assert len(weight_fields) == 0, f"Should have no weight fields, but found: {weight_fields}" diff --git a/examples/python/weighted_edges_example.py b/examples/python/weighted_edges_example.py new file mode 100644 index 000000000..7372b9d05 --- /dev/null +++ b/examples/python/weighted_edges_example.py @@ -0,0 +1,137 @@ +import asyncio +from os import path +from typing import Any +from pydantic import SkipValidation +from cognee.api.v1.visualize.visualize import visualize_graph +from cognee.infrastructure.engine import DataPoint +from cognee.infrastructure.engine.models.Edge import Edge +from cognee.tasks.storage import add_data_points +import cognee + + +class Clothes(DataPoint): + name: str + description: str + + +class Object(DataPoint): + name: str + description: str + has_clothes: list[Clothes] + + +class Person(DataPoint): + name: str + description: str + has_items: SkipValidation[Any] # (Edge, list[Clothes]) + has_objects: SkipValidation[Any] # (Edge, list[Object]) + knows: SkipValidation[Any] # (Edge, list["Person"]) + + +async def main(): + # Clear the database for a clean state + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + + # Create clothes items + item1 = Clothes(name="Shirt", description="A blue shirt") + item2 = Clothes(name="Pants", description="Black pants") + item3 = Clothes(name="Jacket", description="Leather jacket") + + # Create object with simple relationship to clothes + object1 = Object( + name="Closet", description="A wooden closet", has_clothes=[item1, item2, item3] + ) + + # Create people with various weighted relationships + person1 = Person( + name="John", + description="A software engineer", + # Single weight (backward compatible) + has_items=(Edge(weight=0.8, relationship_type="owns"), [item1, item2]), + # Simple relationship without weights + has_objects=(Edge(relationship_type="stores_in"), [object1]), + knows=[], + ) + + person2 = Person( + name="Alice", + description="A designer", + # Multiple weights on edge + has_items=( + Edge( + weights={ + "ownership": 0.9, + "frequency_of_use": 0.7, + "emotional_attachment": 0.8, + "monetary_value": 0.6, + }, + relationship_type="owns", + ), + [item3], + ), + has_objects=(Edge(relationship_type="uses"), [object1]), + knows=[], + ) + + person3 = Person( + name="Bob", + description="A friend", + # Mixed: single weight + multiple weights + has_items=( + Edge( + weight=0.5, # Default weight + weights={"trust_level": 0.9, "communication_frequency": 0.6}, + relationship_type="borrows", + ), + [item1], + ), + has_objects=[], + knows=[], + ) + + # Create relationships between people with multiple weights + person1.knows = ( + Edge( + weights={ + "friendship_strength": 0.9, + "trust_level": 0.8, + "years_known": 0.7, + "shared_interests": 0.6, + }, + relationship_type="friend", + ), + [person2, person3], + ) + + person2.knows = ( + Edge( + weights={"professional_collaboration": 0.8, "personal_friendship": 0.6}, + relationship_type="colleague", + ), + [person1], + ) + + all_data_points = [item1, item2, item3, object1, person1, person2, person3] + + # Add data points to the graph + await add_data_points(all_data_points) + + # Visualize the graph + graph_visualization_path = path.join( + path.dirname(__file__), "weighted_graph_visualization.html" + ) + await visualize_graph(graph_visualization_path) + + print("Graph with multiple weighted edges has been created and visualized!") + print(f"Visualization saved to: {graph_visualization_path}") + print("\nFeatures demonstrated:") + print("- Single weight edges (backward compatible)") + print("- Multiple weights on single edges") + print("- Mixed single + multiple weights") + print("- Hover over edges to see all weight information") + print("- Different visual styling for single vs. multiple weighted edges") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/python/weighted_graph_visualization.html b/examples/python/weighted_graph_visualization.html new file mode 100644 index 000000000..2e7f67e31 --- /dev/null +++ b/examples/python/weighted_graph_visualization.html @@ -0,0 +1,212 @@ + + + + + + + + + + +
+ + + + + + + + \ No newline at end of file