diff --git a/.github/workflows/basic_tests.yml b/.github/workflows/basic_tests.yml index 13ea829da..456bc8c88 100644 --- a/.github/workflows/basic_tests.yml +++ b/.github/workflows/basic_tests.yml @@ -184,3 +184,5 @@ jobs: - name: Run Graph Tests run: poetry run python ./examples/python/code_graph_example.py --repo_path ./cognee/tasks/graph + + diff --git a/.github/workflows/graph_db_tests.yml b/.github/workflows/graph_db_tests.yml index 9dc60b11d..94ff639ba 100644 --- a/.github/workflows/graph_db_tests.yml +++ b/.github/workflows/graph_db_tests.yml @@ -50,6 +50,20 @@ jobs: EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} run: poetry run python ./cognee/tests/test_kuzu.py + - name: Run Weighted Edges Tests with Kuzu + env: + ENV: 'dev' + GRAPH_DATABASE_PROVIDER: "kuzu" + LLM_MODEL: ${{ secrets.LLM_MODEL }} + LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} + LLM_API_KEY: ${{ secrets.LLM_API_KEY }} + LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }} + EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }} + EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }} + EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} + EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} + run: poetry run pytest cognee/tests/unit/interfaces/graph/test_weighted_edges.py -v + run-neo4j-tests: name: Neo4j Tests runs-on: ubuntu-22.04 @@ -83,3 +97,20 @@ jobs: GRAPH_DATABASE_PASSWORD: ${{ secrets.NEO4J_API_KEY }} GRAPH_DATABASE_USERNAME: "neo4j" run: poetry run python ./cognee/tests/test_neo4j.py + + - name: Run Weighted Edges Tests with Neo4j + env: + ENV: 'dev' + GRAPH_DATABASE_PROVIDER: "neo4j" + GRAPH_DATABASE_URL: ${{ secrets.NEO4J_API_URL }} + GRAPH_DATABASE_PASSWORD: ${{ secrets.NEO4J_API_KEY }} + GRAPH_DATABASE_USERNAME: "neo4j" + LLM_MODEL: ${{ secrets.LLM_MODEL }} + LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} + LLM_API_KEY: ${{ secrets.LLM_API_KEY }} + LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }} + EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }} + EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }} + EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} + EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} + run: poetry run pytest cognee/tests/unit/interfaces/graph/test_weighted_edges.py -v diff --git a/.github/workflows/weighted_edges_tests.yml b/.github/workflows/weighted_edges_tests.yml new file mode 100644 index 000000000..d33bd1ceb --- /dev/null +++ b/.github/workflows/weighted_edges_tests.yml @@ -0,0 +1,166 @@ +name: Weighted Edges Tests + +on: + push: + branches: [ main, weighted_edges ] + paths: + - 'cognee/modules/graph/utils/get_graph_from_model.py' + - 'cognee/infrastructure/engine/models/Edge.py' + - 'cognee/tests/unit/interfaces/graph/test_weighted_edges.py' + - 'examples/python/weighted_edges_example.py' + - '.github/workflows/weighted_edges_tests.yml' + pull_request: + branches: [ main ] + paths: + - 'cognee/modules/graph/utils/get_graph_from_model.py' + - 'cognee/infrastructure/engine/models/Edge.py' + - 'cognee/tests/unit/interfaces/graph/test_weighted_edges.py' + - 'examples/python/weighted_edges_example.py' + - '.github/workflows/weighted_edges_tests.yml' + +env: + RUNTIME__LOG_LEVEL: ERROR + ENV: 'dev' + +jobs: + test-weighted-edges-functionality: + name: Test Weighted Edges Core Functionality + runs-on: ubuntu-22.04 + strategy: + matrix: + python-version: ['3.11', '3.12'] + env: + LLM_PROVIDER: openai + LLM_MODEL: gpt-4o-mini + LLM_API_KEY: ${{ secrets.LLM_API_KEY }} + + steps: + - name: Check out repository + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: 
Cognee Setup + uses: ./.github/actions/cognee_setup + with: + python-version: ${{ matrix.python-version }} + + - name: Run Weighted Edges Unit Tests + run: | + poetry run pytest cognee/tests/unit/interfaces/graph/test_weighted_edges.py -v --tb=short + + - name: Run Standard Graph Tests (Regression) + run: | + poetry run pytest cognee/tests/unit/interfaces/graph/get_graph_from_model_unit_test.py -v --tb=short + + test-with-different-databases: + name: Test Weighted Edges with Different Graph Databases + runs-on: ubuntu-22.04 + strategy: + matrix: + database: ['kuzu', 'neo4j'] + include: + - database: kuzu + install_extra: "" + graph_db_provider: "kuzu" + - database: neo4j + install_extra: "-E neo4j" + graph_db_provider: "neo4j" + env: + LLM_PROVIDER: openai + LLM_MODEL: gpt-4o-mini + LLM_ENDPOINT: https://api.openai.com/v1/ + LLM_API_KEY: ${{ secrets.LLM_API_KEY }} + LLM_API_VERSION: "2024-02-01" + EMBEDDING_PROVIDER: openai + EMBEDDING_MODEL: text-embedding-3-small + EMBEDDING_ENDPOINT: https://api.openai.com/v1/ + EMBEDDING_API_KEY: ${{ secrets.LLM_API_KEY }} + EMBEDDING_API_VERSION: "2024-02-01" + steps: + - name: Check out repository + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Cognee Setup + uses: ./.github/actions/cognee_setup + with: + python-version: '3.11' + + - name: Install Database Dependencies + run: | + poetry install ${{ matrix.install_extra }} + + - name: Run Weighted Edges Tests + env: + GRAPH_DATABASE_PROVIDER: ${{ matrix.graph_db_provider }} + run: | + poetry run pytest cognee/tests/unit/interfaces/graph/test_weighted_edges.py -v --tb=short + + test-examples: + name: Test Weighted Edges Examples + runs-on: ubuntu-22.04 + env: + LLM_PROVIDER: openai + LLM_MODEL: gpt-4o-mini + LLM_ENDPOINT: https://api.openai.com/v1/ + LLM_API_KEY: ${{ secrets.LLM_API_KEY }} + LLM_API_VERSION: "2024-02-01" + EMBEDDING_PROVIDER: openai + EMBEDDING_MODEL: text-embedding-3-small + EMBEDDING_ENDPOINT: https://api.openai.com/v1/ + EMBEDDING_API_KEY: ${{ secrets.LLM_API_KEY }} + EMBEDDING_API_VERSION: "2024-02-01" + steps: + - name: Check out repository + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Cognee Setup + uses: ./.github/actions/cognee_setup + with: + python-version: '3.11' + + - name: Test Weighted Edges Example + run: | + poetry run python examples/python/weighted_edges_example.py + + - name: Verify Visualization File Created + run: | + if [ -f "examples/python/weighted_graph_visualization.html" ]; then + echo "✅ Visualization file created successfully" + ls -la examples/python/weighted_graph_visualization.html + else + echo "❌ Visualization file not found" + exit 1 + fi + + + + code-quality: + name: Code Quality for Weighted Edges + runs-on: ubuntu-22.04 + steps: + - name: Check out repository + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Cognee Setup + uses: ./.github/actions/cognee_setup + with: + python-version: '3.11' + + - name: Run Linting on Weighted Edges Files + uses: astral-sh/ruff-action@v2 + with: + args: "check cognee/modules/graph/utils/get_graph_from_model.py cognee/tests/unit/interfaces/graph/test_weighted_edges.py examples/python/weighted_edges_example.py" + + - name: Run Formatting Check on Weighted Edges Files + uses: astral-sh/ruff-action@v2 + with: + args: "format --check cognee/modules/graph/utils/get_graph_from_model.py cognee/tests/unit/interfaces/graph/test_weighted_edges.py examples/python/weighted_edges_example.py" + + \ No newline at end of file diff --git a/cognee/infrastructure/engine/__init__.py 
b/cognee/infrastructure/engine/__init__.py index cc084d142..e1328a655 100644 --- a/cognee/infrastructure/engine/__init__.py +++ b/cognee/infrastructure/engine/__init__.py @@ -1,2 +1,3 @@ from .models.DataPoint import DataPoint from .models.ExtendableDataPoint import ExtendableDataPoint +from .models.Edge import Edge diff --git a/cognee/infrastructure/engine/models/DataPoint.py b/cognee/infrastructure/engine/models/DataPoint.py index 3b5f7424a..812380eaa 100644 --- a/cognee/infrastructure/engine/models/DataPoint.py +++ b/cognee/infrastructure/engine/models/DataPoint.py @@ -1,6 +1,6 @@ import pickle from uuid import UUID, uuid4 -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, ConfigDict from datetime import datetime, timezone from typing_extensions import TypedDict from typing import Optional, Any, Dict, List @@ -34,6 +34,8 @@ class DataPoint(BaseModel): - from_dict """ + model_config = ConfigDict(arbitrary_types_allowed=True) + id: UUID = Field(default_factory=uuid4) created_at: int = Field( default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000) diff --git a/cognee/infrastructure/engine/models/Edge.py b/cognee/infrastructure/engine/models/Edge.py new file mode 100644 index 000000000..5ad9c84dd --- /dev/null +++ b/cognee/infrastructure/engine/models/Edge.py @@ -0,0 +1,26 @@ +from pydantic import BaseModel +from typing import Optional, Any, Dict + + +class Edge(BaseModel): + """ + Represents edge metadata for relationships between DataPoints. + + This class is used to define edge properties like weight when creating + relationships between DataPoints using tuple syntax: + + Example: + # Single weight (backward compatible) + has_items: (Edge(weight=0.5), list[Item]) + + # Multiple weights + has_items: (Edge(weights={"strength": 0.8, "confidence": 0.9, "importance": 0.7}), list[Item]) + + # Mixed usage + has_items: (Edge(weight=0.5, weights={"confidence": 0.9}), list[Item]) + """ + + weight: Optional[float] = None + weights: Optional[Dict[str, float]] = None + relationship_type: Optional[str] = None + properties: Optional[Dict[str, Any]] = None diff --git a/cognee/infrastructure/llm/generic_llm_api/adapter.py b/cognee/infrastructure/llm/generic_llm_api/adapter.py index 3d423da6e..c1cc40e38 100644 --- a/cognee/infrastructure/llm/generic_llm_api/adapter.py +++ b/cognee/infrastructure/llm/generic_llm_api/adapter.py @@ -1,5 +1,6 @@ """Adapter for Generic API LLM provider API""" +import logging import litellm import instructor from typing import Type diff --git a/cognee/modules/graph/utils/expand_with_nodes_and_edges.py b/cognee/modules/graph/utils/expand_with_nodes_and_edges.py index d3a0a8522..c1f55d4fc 100644 --- a/cognee/modules/graph/utils/expand_with_nodes_and_edges.py +++ b/cognee/modules/graph/utils/expand_with_nodes_and_edges.py @@ -281,7 +281,39 @@ def expand_with_nodes_and_edges( existing_edges_map: Optional[dict[str, bool]] = None, ): """ - Expand data chunks with nodes and edges, validating against ontology. + + Expand knowledge graphs with validated nodes and edges, integrating ontology information. + + This function processes document chunks and their associated knowledge graphs to create + a comprehensive graph structure with entity nodes, entity type nodes, and their relationships. + It validates entities against an ontology resolver and adds ontology-derived nodes and edges + to enhance the knowledge representation.
+ + Args: + data_chunks (list[DocumentChunk]): List of document chunks that contain the source data. + Each chunk should have metadata about what entities it contains. + chunk_graphs (list[KnowledgeGraph]): List of knowledge graphs corresponding to each + data chunk. Each graph contains nodes (entities) and edges (relationships) extracted + from the chunk content. + ontology_resolver (OntologyResolver, optional): Resolver for validating entities and + types against an ontology. If None, a default OntologyResolver is created. + Defaults to None. + existing_edges_map (dict[str, bool], optional): Mapping of existing edge keys to prevent + duplicate edge creation. Keys are formatted as "{source_id}_{target_id}_{relation}". + If None, an empty dictionary is created. Defaults to None. + + Returns: + tuple[list, list]: A tuple containing: + - graph_nodes (list): Combined list of data chunks and ontology nodes (EntityType and Entity objects) + - graph_edges (list): List of edge tuples in format (source_id, target_id, relationship_name, properties) + + Note: + - Entity nodes are created for each entity found in the knowledge graphs + - EntityType nodes are created for each unique entity type + - Ontology validation is performed to map entities to canonical ontology terms + - Duplicate nodes and edges are prevented using internal mapping and the existing_edges_map + - The function modifies data_chunks in-place by adding entities to their 'contains' attribute + """ if existing_edges_map is None: existing_edges_map = {} diff --git a/cognee/modules/graph/utils/get_graph_from_model.py b/cognee/modules/graph/utils/get_graph_from_model.py index a6d50f41c..6fcca4572 100644 --- a/cognee/modules/graph/utils/get_graph_from_model.py +++ b/cognee/modules/graph/utils/get_graph_from_model.py @@ -1,134 +1,256 @@ from datetime import datetime, timezone -from cognee.infrastructure.engine import DataPoint +from typing import Tuple, List, Any, Dict, Optional +from cognee.infrastructure.engine import DataPoint, Edge from cognee.modules.storage.utils import copy_model +def _extract_field_info(field_value: Any) -> Tuple[str, Any, Optional[Edge]]: + """Extract field type, actual value, and edge metadata from a field value.""" + + # Handle tuple[Edge, DataPoint] + if ( + isinstance(field_value, tuple) + and len(field_value) == 2 + and isinstance(field_value[0], Edge) + and isinstance(field_value[1], DataPoint) + ): + return "single_datapoint_with_edge", field_value[1], field_value[0] + + # Handle tuple[Edge, list[DataPoint]] + if ( + isinstance(field_value, tuple) + and len(field_value) == 2 + and isinstance(field_value[0], Edge) + and isinstance(field_value[1], list) + and len(field_value[1]) > 0 + and isinstance(field_value[1][0], DataPoint) + ): + return "list_datapoint_with_edge", field_value[1], field_value[0] + + # Handle single DataPoint + if isinstance(field_value, DataPoint): + return "single_datapoint", field_value, None + + # Handle list of DataPoints + if ( + isinstance(field_value, list) + and len(field_value) > 0 + and isinstance(field_value[0], DataPoint) + ): + return "list_datapoint", field_value, None + + # Regular property + return "property", field_value, None + + +def _create_edge_properties( + source_id: str, target_id: str, relationship_name: str, edge_metadata: Optional[Edge] +) -> Dict[str, Any]: + """Create edge properties dictionary with metadata if present.""" + properties = { + "source_node_id": source_id, + "target_node_id": target_id, + "relationship_name": relationship_name, + "updated_at": 
datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"), + } + + if edge_metadata: + # Add edge metadata + edge_data = edge_metadata.model_dump(exclude_none=True) + properties.update(edge_data) + + # Add individual weights as separate fields for easier querying + if edge_metadata.weights is not None: + for weight_name, weight_value in edge_metadata.weights.items(): + properties[f"weight_{weight_name}"] = weight_value + + return properties + + +def _get_relationship_key(field_name: str, edge_metadata: Optional[Edge]) -> str: + """Extract relationship key from edge metadata or use field name as fallback.""" + if ( + edge_metadata + and hasattr(edge_metadata, "relationship_type") + and edge_metadata.relationship_type + ): + return edge_metadata.relationship_type + return field_name + + +def _generate_property_key(data_point_id: str, relationship_key: str, target_id: str) -> str: + """Generate a unique property key for visited_properties tracking.""" + return f"{data_point_id}{relationship_key}{target_id}" + + +def _process_datapoint_field( + data_point_id: str, + field_name: str, + datapoints: List[DataPoint], + edge_metadata: Optional[Edge], + visited_properties: Dict[str, bool], + properties_to_visit: set, + excluded_properties: set, +) -> None: + """Process a field containing DataPoint(s), handling both single and list cases.""" + excluded_properties.add(field_name) + relationship_key = _get_relationship_key(field_name, edge_metadata) + + for index, datapoint in enumerate(datapoints): + property_key = _generate_property_key(data_point_id, relationship_key, str(datapoint.id)) + if property_key in visited_properties: + continue + + # For single datapoint, use field_name; for list, use field_name.index + field_identifier = field_name if len(datapoints) == 1 else f"{field_name}.{index}" + properties_to_visit.add(field_identifier) + + async def get_graph_from_model( data_point: DataPoint, - added_nodes: dict, - added_edges: dict, - visited_properties: dict = None, - include_root=True, -): + added_nodes: Dict[str, bool], + added_edges: Dict[str, bool], + visited_properties: Optional[Dict[str, bool]] = None, + include_root: bool = True, +) -> Tuple[List[DataPoint], List[Tuple[str, str, str, Dict[str, Any]]]]: + """ + Extract graph representation from a DataPoint model. 
+ + Args: + data_point: The DataPoint to extract graph from + added_nodes: Dictionary tracking already processed nodes + added_edges: Dictionary tracking already processed edges + visited_properties: Dictionary tracking visited properties to avoid cycles + include_root: Whether to include the root node in results + + Returns: + Tuple of (nodes, edges) extracted from the model + """ if str(data_point.id) in added_nodes: return [], [] nodes = [] edges = [] visited_properties = visited_properties or {} + data_point_id = str(data_point.id) - data_point_properties = { - "type": type(data_point).__name__, - } + data_point_properties = {"type": type(data_point).__name__} excluded_properties = set() properties_to_visit = set() + # Analyze all fields to categorize them as properties or relationships for field_name, field_value in data_point: if field_name == "metadata": continue - if isinstance(field_value, DataPoint): - excluded_properties.add(field_name) + field_type, actual_value, edge_metadata = _extract_field_info(field_value) - property_key = str(data_point.id) + field_name + str(field_value.id) + if field_type == "property": + data_point_properties[field_name] = field_value + elif field_type in ["single_datapoint", "single_datapoint_with_edge"]: + _process_datapoint_field( + data_point_id, + field_name, + [actual_value], + edge_metadata, + visited_properties, + properties_to_visit, + excluded_properties, + ) + elif field_type in ["list_datapoint", "list_datapoint_with_edge"]: + _process_datapoint_field( + data_point_id, + field_name, + actual_value, + edge_metadata, + visited_properties, + properties_to_visit, + excluded_properties, + ) - if property_key in visited_properties: - continue - - visited_properties[property_key] = True - properties_to_visit.add(field_name) - - continue - - if ( - isinstance(field_value, list) - and len(field_value) > 0 - and isinstance(field_value[0], DataPoint) - ): - excluded_properties.add(field_name) - - for index, item in enumerate(field_value): - property_key = str(data_point.id) + field_name + str(item.id) - - if property_key in visited_properties: - continue - - visited_properties[property_key] = True - properties_to_visit.add(f"{field_name}.{index}") - - continue - - data_point_properties[field_name] = field_value - - if include_root and str(data_point.id) not in added_nodes: + # Create node for current DataPoint if needed + if include_root and data_point_id not in added_nodes: SimpleDataPointModel = copy_model( - type(data_point), - exclude_fields=list(excluded_properties), + type(data_point), exclude_fields=list(excluded_properties) ) nodes.append(SimpleDataPointModel(**data_point_properties)) - added_nodes[str(data_point.id)] = True + added_nodes[data_point_id] = True - for field_name in properties_to_visit: - index = None - - if "." in field_name: - field_name, index = field_name.split(".") + # Process all relationships + for field_name_with_index in properties_to_visit: + # Parse field name and index + if "." 
in field_name_with_index: + field_name, index_str = field_name_with_index.split(".") + index = int(index_str) + else: + field_name, index = field_name_with_index, None + # Get field value and extract edge metadata field_value = getattr(data_point, field_name) + edge_metadata = None + if ( + isinstance(field_value, tuple) + and len(field_value) == 2 + and isinstance(field_value[0], Edge) + ): + edge_metadata, field_value = field_value + + # Get specific datapoint - handle both single and list cases if index is not None: - field_value = field_value[int(index)] + # List case: extract specific item by index + target_datapoint = field_value[index] + elif isinstance(field_value, list): + # Single datapoint case that was wrapped in a list + target_datapoint = field_value[0] + else: + # True single datapoint case + target_datapoint = field_value - edge_key = str(data_point.id) + str(field_value.id) + field_name - - if str(edge_key) not in added_edges: - edges.append( - ( - data_point.id, - field_value.id, - field_name, - { - "source_node_id": data_point.id, - "target_node_id": field_value.id, - "relationship_name": field_name, - "updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"), - }, - ) + # Create edge if not already added + edge_key = f"{data_point_id}{target_datapoint.id}{field_name}" + if edge_key not in added_edges: + relationship_name = _get_relationship_key(field_name, edge_metadata) + edge_properties = _create_edge_properties( + data_point.id, target_datapoint.id, relationship_name, edge_metadata ) - added_edges[str(edge_key)] = True + edges.append((data_point.id, target_datapoint.id, relationship_name, edge_properties)) + added_edges[edge_key] = True - if str(field_value.id) in added_nodes: - continue - - property_key = str(data_point.id) + field_name + str(field_value.id) + # Mark property as visited - CRITICAL for preventing infinite loops + relationship_key = _get_relationship_key(field_name, edge_metadata) + property_key = _generate_property_key( + data_point_id, relationship_key, str(target_datapoint.id) + ) visited_properties[property_key] = True - property_nodes, property_edges = await get_graph_from_model( - field_value, - include_root=True, - added_nodes=added_nodes, - added_edges=added_edges, - visited_properties=visited_properties, - ) - - for node in property_nodes: - nodes.append(node) - - for edge in property_edges: - edges.append(edge) + # Recursively process target node if not already processed + if str(target_datapoint.id) not in added_nodes: + child_nodes, child_edges = await get_graph_from_model( + target_datapoint, + include_root=True, + added_nodes=added_nodes, + added_edges=added_edges, + visited_properties=visited_properties, + ) + nodes.extend(child_nodes) + edges.extend(child_edges) return nodes, edges -def get_own_property_nodes(property_nodes, property_edges): - own_properties = [] +def get_own_property_nodes( + property_nodes: List[DataPoint], property_edges: List[Tuple[str, str, str, Dict[str, Any]]] +) -> List[DataPoint]: + """ + Filter nodes to return only those that are not destinations of any edges. 
- destination_nodes = [str(property_edge[1]) for property_edge in property_edges] + + Args: + property_nodes: List of all nodes + property_edges: List of all edges - for node in property_nodes: - if str(node.id) in destination_nodes: - continue - - own_properties.append(node) - - return own_properties + + Returns: + List of nodes that are not edge destinations + """ + destination_node_ids = {str(edge[1]) for edge in property_edges} + return [node for node in property_nodes if str(node.id) not in destination_node_ids] diff --git a/cognee/modules/graph/utils/retrieve_existing_edges.py b/cognee/modules/graph/utils/retrieve_existing_edges.py index 13c3490a8..20cb30a26 100644 --- a/cognee/modules/graph/utils/retrieve_existing_edges.py +++ b/cognee/modules/graph/utils/retrieve_existing_edges.py @@ -8,6 +8,38 @@ async def retrieve_existing_edges( data_chunks: list[DataPoint], chunk_graphs: list[KnowledgeGraph], ) -> dict[str, bool]: + """ + Retrieve existing edges from the graph database to prevent duplicate edge creation. + + This function checks which edges already exist in the graph database by querying + for various types of relationships including structural edges (exists_in, mentioned_in, is_a) + and content-derived edges from the knowledge graphs. It returns a mapping that can be + used to avoid creating duplicate edges during graph expansion. + + Args: + data_chunks (list[DataPoint]): List of data point objects that serve as containers + for the entities. Each data chunk represents a source document or data segment. + chunk_graphs (list[KnowledgeGraph]): List of knowledge graphs corresponding to each + data chunk. Each graph contains nodes (entities) and edges (relationships) that + were extracted from the chunk content. + graph_engine (GraphDBInterface): Interface to the graph database that will be queried + to check for existing edges. Must implement the has_edges() method. + + Returns: + dict[str, bool]: A mapping of edge keys to boolean values indicating existence. + Edge keys are formatted as concatenated strings: "{source_id}{target_id}{relationship_name}". + All values in the returned dictionary are True (indicating the edge exists).
+ + Note: + - The function generates several types of edges for checking: + * Type node edges: (chunk_id, type_node_id, "exists_in") + * Entity node edges: (chunk_id, entity_node_id, "mentioned_in") + * Type-entity edges: (entity_node_id, type_node_id, "is_a") + * Graph node edges: extracted from the knowledge graph relationships + - Uses generate_node_id() to ensure consistent node ID formatting + - Prevents processing the same node multiple times using a processed_nodes tracker + - The returned mapping can be used with expand_with_nodes_and_edges() to avoid duplicates + """ processed_nodes = {} type_node_edges = [] entity_node_edges = [] diff --git a/cognee/modules/storage/utils/__init__.py b/cognee/modules/storage/utils/__init__.py index 1008d159a..f240798f8 100644 --- a/cognee/modules/storage/utils/__init__.py +++ b/cognee/modules/storage/utils/__init__.py @@ -3,7 +3,7 @@ from uuid import UUID from decimal import Decimal from datetime import datetime from pydantic_core import PydanticUndefined -from pydantic import create_model +from pydantic import create_model, ConfigDict, BaseModel from cognee.infrastructure.engine import DataPoint @@ -29,9 +29,15 @@ def copy_model(model: DataPoint, include_fields: dict = {}, exclude_fields: list final_fields = {**fields, **include_fields} - model = create_model(model.__name__, **final_fields) - model.model_rebuild() - return model + # Create a base class with the same configuration as DataPoint + class ConfiguredBase(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + # Create the model inheriting from the configured base + new_model = create_model(model.__name__, __base__=ConfiguredBase, **final_fields) + + new_model.model_rebuild() + return new_model def get_own_properties(data_point: DataPoint): diff --git a/cognee/modules/visualization/cognee_network_visualization.py b/cognee/modules/visualization/cognee_network_visualization.py index 83a0f237e..0328c5c3e 100644 --- a/cognee/modules/visualization/cognee_network_visualization.py +++ b/cognee/modules/visualization/cognee_network_visualization.py @@ -53,7 +53,40 @@ async def cognee_network_visualization(graph_data, destination_file_path: str = target = str(target) G.add_edge(source, target) edge_labels[(source, target)] = relation - links_list.append({"source": source, "target": target, "relation": relation}) + + # Extract edge metadata including all weights + all_weights = {} + primary_weight = None + + if edge_info: + # Single weight (backward compatibility) + if "weight" in edge_info: + all_weights["default"] = edge_info["weight"] + primary_weight = edge_info["weight"] + + # Multiple weights + if "weights" in edge_info and isinstance(edge_info["weights"], dict): + all_weights.update(edge_info["weights"]) + # Use the first weight as primary for visual thickness if no default weight + if primary_weight is None and edge_info["weights"]: + primary_weight = next(iter(edge_info["weights"].values())) + + # Individual weight fields (weight_strength, weight_confidence, etc.) 
+ for key, value in edge_info.items(): + if key.startswith("weight_") and isinstance(value, (int, float)): + weight_name = key[7:] # Remove "weight_" prefix + all_weights[weight_name] = value + + link_data = { + "source": source, + "target": target, + "relation": relation, + "weight": primary_weight, # Primary weight for backward compatibility + "all_weights": all_weights, # All weights for display + "relationship_type": edge_info.get("relationship_type") if edge_info else None, + "edge_info": edge_info if edge_info else {}, + } + links_list.append(link_data) html_template = """ @@ -66,13 +99,33 @@ async def cognee_network_visualization(graph_data, destination_file_path: str = svg { width: 100vw; height: 100vh; display: block; } .links line { stroke: rgba(255, 255, 255, 0.4); stroke-width: 2px; } + .links line.weighted { stroke: rgba(255, 215, 0, 0.7); } + .links line.multi-weighted { stroke: rgba(0, 255, 127, 0.8); } .nodes circle { stroke: white; stroke-width: 0.5px; filter: drop-shadow(0 0 5px rgba(255,255,255,0.3)); } .node-label { font-size: 5px; font-weight: bold; fill: white; text-anchor: middle; dominant-baseline: middle; font-family: 'Inter', sans-serif; pointer-events: none; } .edge-label { font-size: 3px; fill: rgba(255, 255, 255, 0.7); text-anchor: middle; dominant-baseline: middle; font-family: 'Inter', sans-serif; pointer-events: none; } + + .tooltip { + position: absolute; + text-align: left; + padding: 8px; + font-size: 12px; + background: rgba(0, 0, 0, 0.9); + color: white; + border: 1px solid rgba(255, 255, 255, 0.3); + border-radius: 4px; + pointer-events: none; + opacity: 0; + transition: opacity 0.2s; + z-index: 1000; + max-width: 300px; + word-wrap: break-word; + }
+
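Reviewer note, not part of the patch: a minimal end-to-end sketch of the weighted-edge API introduced above. The Item and Basket models and their field names are hypothetical; only Edge, DataPoint, and get_graph_from_model come from this diff, and the sketch assumes the remaining DataPoint fields all carry defaults.

    import asyncio
    from typing import List, Optional, Tuple

    from cognee.infrastructure.engine import DataPoint, Edge
    from cognee.modules.graph.utils.get_graph_from_model import get_graph_from_model


    # Hypothetical models used only for illustration.
    class Item(DataPoint):
        name: str


    class Basket(DataPoint):
        name: str
        # The field value is an (Edge, list[DataPoint]) tuple, which
        # _extract_field_info classifies as "list_datapoint_with_edge".
        contains: Optional[Tuple[Edge, List[Item]]] = None


    async def main():
        apples = Item(name="apples")
        basket = Basket(
            name="groceries",
            contains=(Edge(weight=0.5, weights={"confidence": 0.9}), [apples]),
        )

        nodes, edges = await get_graph_from_model(basket, added_nodes={}, added_edges={})

        # Each edge is (source_id, target_id, relationship_name, properties).
        # _create_edge_properties copies weight/weights into the properties dict and
        # flattens each named weight into a weight_<name> field, so this prints
        # something like: contains 0.5 0.9
        for _, _, relationship_name, properties in edges:
            print(relationship_name, properties.get("weight"), properties.get("weight_confidence"))


    asyncio.run(main())

Setting Edge(relationship_type="...") on the tuple would replace the field name ("contains") as the stored relationship_name, per _get_relationship_key.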
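A second self-contained sketch mirrors the weight-selection rules that cognee_network_visualization applies when building links_list. The collect_edge_weights helper below is not part of the diff; it only restates the precedence: an explicit weight becomes the primary value, a weights dict is merged and supplies the primary only as a fallback, and flattened weight_<name> fields are folded back into all_weights.

    from typing import Any, Dict, Optional, Tuple


    def collect_edge_weights(edge_info: Dict[str, Any]) -> Tuple[Optional[float], Dict[str, float]]:
        """Standalone mirror of the weight extraction done for each visualization link."""
        all_weights: Dict[str, float] = {}
        primary_weight: Optional[float] = None

        # A single "weight" value is kept as the primary weight (backward compatible).
        if "weight" in edge_info:
            all_weights["default"] = edge_info["weight"]
            primary_weight = edge_info["weight"]

        # A "weights" dict is merged in; its first entry becomes primary only if none is set yet.
        if isinstance(edge_info.get("weights"), dict):
            all_weights.update(edge_info["weights"])
            if primary_weight is None and edge_info["weights"]:
                primary_weight = next(iter(edge_info["weights"].values()))

        # Flattened "weight_<name>" fields written by _create_edge_properties are collected too.
        for key, value in edge_info.items():
            if key.startswith("weight_") and isinstance(value, (int, float)):
                all_weights[key[len("weight_"):]] = value

        return primary_weight, all_weights


    print(collect_edge_weights({"weight": 0.5, "weight_confidence": 0.9}))
    # Expected: (0.5, {'default': 0.5, 'confidence': 0.9})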