feat: Weighted edges (#1068)

<!-- .github/pull_request_template.md -->

## Description
<!-- Provide a clear description of the changes in this PR -->

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.

---------

Co-authored-by: Igor Ilic <30923996+dexters1@users.noreply.github.com>
Co-authored-by: Igor Ilic <igorilic03@gmail.com>
This commit is contained in:
Vasilije 2025-07-14 21:26:25 +02:00 committed by GitHub
parent f68fd59b95
commit 4bcb893a54
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 1383 additions and 103 deletions

View file

@ -184,3 +184,5 @@ jobs:
- name: Run Graph Tests
run: poetry run python ./examples/python/code_graph_example.py --repo_path ./cognee/tasks/graph

View file

@ -50,6 +50,20 @@ jobs:
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
run: poetry run python ./cognee/tests/test_kuzu.py
- name: Run Weighted Edges Tests with Kuzu
env:
ENV: 'dev'
GRAPH_DATABASE_PROVIDER: "kuzu"
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
run: poetry run pytest cognee/tests/unit/interfaces/graph/test_weighted_edges.py -v
run-neo4j-tests:
name: Neo4j Tests
runs-on: ubuntu-22.04
@ -83,3 +97,20 @@ jobs:
GRAPH_DATABASE_PASSWORD: ${{ secrets.NEO4J_API_KEY }}
GRAPH_DATABASE_USERNAME: "neo4j"
run: poetry run python ./cognee/tests/test_neo4j.py
- name: Run Weighted Edges Tests with Neo4j
env:
ENV: 'dev'
GRAPH_DATABASE_PROVIDER: "neo4j"
GRAPH_DATABASE_URL: ${{ secrets.NEO4J_API_URL }}
GRAPH_DATABASE_PASSWORD: ${{ secrets.NEO4J_API_KEY }}
GRAPH_DATABASE_USERNAME: "neo4j"
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
run: poetry run pytest cognee/tests/unit/interfaces/graph/test_weighted_edges.py -v

View file

@ -0,0 +1,166 @@
name: Weighted Edges Tests
on:
push:
branches: [ main, weighted_edges ]
paths:
- 'cognee/modules/graph/utils/get_graph_from_model.py'
- 'cognee/infrastructure/engine/models/Edge.py'
- 'cognee/tests/unit/interfaces/graph/test_weighted_edges.py'
- 'examples/python/weighted_edges_example.py'
- '.github/workflows/weighted_edges_tests.yml'
pull_request:
branches: [ main ]
paths:
- 'cognee/modules/graph/utils/get_graph_from_model.py'
- 'cognee/infrastructure/engine/models/Edge.py'
- 'cognee/tests/unit/interfaces/graph/test_weighted_edges.py'
- 'examples/python/weighted_edges_example.py'
- '.github/workflows/weighted_edges_tests.yml'
env:
RUNTIME__LOG_LEVEL: ERROR
ENV: 'dev'
jobs:
test-weighted-edges-functionality:
name: Test Weighted Edges Core Functionality
runs-on: ubuntu-22.04
strategy:
matrix:
python-version: ['3.11', '3.12']
env:
LLM_PROVIDER: openai
LLM_MODEL: gpt-4o-mini
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
steps:
- name: Check out repository
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Cognee Setup
uses: ./.github/actions/cognee_setup
with:
python-version: ${{ matrix.python-version }}
- name: Run Weighted Edges Unit Tests
run: |
poetry run pytest cognee/tests/unit/interfaces/graph/test_weighted_edges.py -v --tb=short
- name: Run Standard Graph Tests (Regression)
run: |
poetry run pytest cognee/tests/unit/interfaces/graph/get_graph_from_model_unit_test.py -v --tb=short
test-with-different-databases:
name: Test Weighted Edges with Different Graph Databases
runs-on: ubuntu-22.04
strategy:
matrix:
database: ['kuzu', 'neo4j']
include:
- database: kuzu
install_extra: ""
graph_db_provider: "kuzu"
- database: neo4j
install_extra: "-E neo4j"
graph_db_provider: "neo4j"
env:
LLM_PROVIDER: openai
LLM_MODEL: gpt-4o-mini
LLM_ENDPOINT: https://api.openai.com/v1/
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_API_VERSION: "2024-02-01"
EMBEDDING_PROVIDER: openai
EMBEDDING_MODEL: text-embedding-3-small
EMBEDDING_ENDPOINT: https://api.openai.com/v1/
EMBEDDING_API_KEY: ${{ secrets.LLM_API_KEY }}
EMBEDDING_API_VERSION: "2024-02-01"
steps:
- name: Check out repository
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Cognee Setup
uses: ./.github/actions/cognee_setup
with:
python-version: '3.11'
- name: Install Database Dependencies
run: |
poetry install ${{ matrix.install_extra }}
- name: Run Weighted Edges Tests
env:
GRAPH_DATABASE_PROVIDER: ${{ matrix.graph_db_provider }}
run: |
poetry run pytest cognee/tests/unit/interfaces/graph/test_weighted_edges.py -v --tb=short
test-examples:
name: Test Weighted Edges Examples
runs-on: ubuntu-22.04
env:
LLM_PROVIDER: openai
LLM_MODEL: gpt-4o-mini
LLM_ENDPOINT: https://api.openai.com/v1/
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_API_VERSION: "2024-02-01"
EMBEDDING_PROVIDER: openai
EMBEDDING_MODEL: text-embedding-3-small
EMBEDDING_ENDPOINT: https://api.openai.com/v1/
EMBEDDING_API_KEY: ${{ secrets.LLM_API_KEY }}
EMBEDDING_API_VERSION: "2024-02-01"
steps:
- name: Check out repository
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Cognee Setup
uses: ./.github/actions/cognee_setup
with:
python-version: '3.11'
- name: Test Weighted Edges Example
run: |
poetry run python examples/python/weighted_edges_example.py
- name: Verify Visualization File Created
run: |
if [ -f "examples/python/weighted_graph_visualization.html" ]; then
echo "✅ Visualization file created successfully"
ls -la examples/python/weighted_graph_visualization.html
else
echo "❌ Visualization file not found"
exit 1
fi
code-quality:
name: Code Quality for Weighted Edges
runs-on: ubuntu-22.04
steps:
- name: Check out repository
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Cognee Setup
uses: ./.github/actions/cognee_setup
with:
python-version: '3.11'
- name: Run Linting on Weighted Edges Files
uses: astral-sh/ruff-action@v2
with:
args: "check cognee/modules/graph/utils/get_graph_from_model.py cognee/tests/unit/interfaces/graph/test_weighted_edges.py examples/python/weighted_edges_example.py"
- name: Run Formatting Check on Weighted Edges Files
uses: astral-sh/ruff-action@v2
with:
args: "format --check cognee/modules/graph/utils/get_graph_from_model.py cognee/tests/unit/interfaces/graph/test_weighted_edges.py examples/python/weighted_edges_example.py"

View file

@ -1,2 +1,3 @@
from .models.DataPoint import DataPoint
from .models.ExtendableDataPoint import ExtendableDataPoint
from .models.Edge import Edge

View file

@ -1,6 +1,6 @@
import pickle
from uuid import UUID, uuid4
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, ConfigDict
from datetime import datetime, timezone
from typing_extensions import TypedDict
from typing import Optional, Any, Dict, List
@ -34,6 +34,8 @@ class DataPoint(BaseModel):
- from_dict
"""
model_config = ConfigDict(arbitrary_types_allowed=True)
id: UUID = Field(default_factory=uuid4)
created_at: int = Field(
default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000)

View file

@ -0,0 +1,26 @@
from pydantic import BaseModel
from typing import Optional, Any, Dict
class Edge(BaseModel):
"""
Represents edge metadata for relationships between DataPoints.
This class is used to define edge properties like weight when creating
relationships between DataPoints using tuple syntax:
Example:
# Single weight (backward compatible)
has_items: (Edge(weight=0.5), list[Item])
# Multiple weights
has_items: (Edge(weights={"strength": 0.8, "confidence": 0.9, "importance": 0.7}), list[Item])
# Mixed usage
has_items: (Edge(weight=0.5, weights={"confidence": 0.9}), list[Item])
"""
weight: Optional[float] = None
weights: Optional[Dict[str, float]] = None
relationship_type: Optional[str] = None
properties: Optional[Dict[str, Any]] = None

View file

@ -1,5 +1,6 @@
"""Adapter for Generic API LLM provider API"""
import logging
import litellm
import instructor
from typing import Type

View file

@ -281,7 +281,40 @@ def expand_with_nodes_and_edges(
existing_edges_map: Optional[dict[str, bool]] = None,
):
"""
Expand data chunks with nodes and edges, validating against ontology.
- LLM generated docstring
Expand knowledge graphs with validated nodes and edges, integrating ontology information.
This function processes document chunks and their associated knowledge graphs to create
a comprehensive graph structure with entity nodes, entity type nodes, and their relationships.
It validates entities against an ontology resolver and adds ontology-derived nodes and edges
to enhance the knowledge representation.
Args:
data_chunks (list[DocumentChunk]): List of document chunks that contain the source data.
Each chunk should have metadata about what entities it contains.
chunk_graphs (list[KnowledgeGraph]): List of knowledge graphs corresponding to each
data chunk. Each graph contains nodes (entities) and edges (relationships) extracted
from the chunk content.
ontology_resolver (OntologyResolver, optional): Resolver for validating entities and
types against an ontology. If None, a default OntologyResolver is created.
Defaults to None.
existing_edges_map (dict[str, bool], optional): Mapping of existing edge keys to prevent
duplicate edge creation. Keys are formatted as "{source_id}_{target_id}_{relation}".
If None, an empty dictionary is created. Defaults to None.
Returns:
tuple[list, list]: A tuple containing:
- graph_nodes (list): Combined list of data chunks and ontology nodes (EntityType and Entity objects)
- graph_edges (list): List of edge tuples in format (source_id, target_id, relationship_name, properties)
Note:
- Entity nodes are created for each entity found in the knowledge graphs
- EntityType nodes are created for each unique entity type
- Ontology validation is performed to map entities to canonical ontology terms
- Duplicate nodes and edges are prevented using internal mapping and the existing_edges_map
- The function modifies data_chunks in-place by adding entities to their 'contains' attribute
"""
if existing_edges_map is None:
existing_edges_map = {}

View file

@ -1,134 +1,256 @@
from datetime import datetime, timezone
from cognee.infrastructure.engine import DataPoint
from typing import Tuple, List, Any, Dict, Optional
from cognee.infrastructure.engine import DataPoint, Edge
from cognee.modules.storage.utils import copy_model
def _extract_field_info(field_value: Any) -> Tuple[str, Any, Optional[Edge]]:
"""Extract field type, actual value, and edge metadata from a field value."""
# Handle tuple[Edge, DataPoint]
if (
isinstance(field_value, tuple)
and len(field_value) == 2
and isinstance(field_value[0], Edge)
and isinstance(field_value[1], DataPoint)
):
return "single_datapoint_with_edge", field_value[1], field_value[0]
# Handle tuple[Edge, list[DataPoint]]
if (
isinstance(field_value, tuple)
and len(field_value) == 2
and isinstance(field_value[0], Edge)
and isinstance(field_value[1], list)
and len(field_value[1]) > 0
and isinstance(field_value[1][0], DataPoint)
):
return "list_datapoint_with_edge", field_value[1], field_value[0]
# Handle single DataPoint
if isinstance(field_value, DataPoint):
return "single_datapoint", field_value, None
# Handle list of DataPoints
if (
isinstance(field_value, list)
and len(field_value) > 0
and isinstance(field_value[0], DataPoint)
):
return "list_datapoint", field_value, None
# Regular property
return "property", field_value, None
def _create_edge_properties(
source_id: str, target_id: str, relationship_name: str, edge_metadata: Optional[Edge]
) -> Dict[str, Any]:
"""Create edge properties dictionary with metadata if present."""
properties = {
"source_node_id": source_id,
"target_node_id": target_id,
"relationship_name": relationship_name,
"updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
}
if edge_metadata:
# Add edge metadata
edge_data = edge_metadata.model_dump(exclude_none=True)
properties.update(edge_data)
# Add individual weights as separate fields for easier querying
if edge_metadata.weights is not None:
for weight_name, weight_value in edge_metadata.weights.items():
properties[f"weight_{weight_name}"] = weight_value
return properties
def _get_relationship_key(field_name: str, edge_metadata: Optional[Edge]) -> str:
"""Extract relationship key from edge metadata or use field name as fallback."""
if (
edge_metadata
and hasattr(edge_metadata, "relationship_type")
and edge_metadata.relationship_type
):
return edge_metadata.relationship_type
return field_name
def _generate_property_key(data_point_id: str, relationship_key: str, target_id: str) -> str:
"""Generate a unique property key for visited_properties tracking."""
return f"{data_point_id}{relationship_key}{target_id}"
def _process_datapoint_field(
data_point_id: str,
field_name: str,
datapoints: List[DataPoint],
edge_metadata: Optional[Edge],
visited_properties: Dict[str, bool],
properties_to_visit: set,
excluded_properties: set,
) -> None:
"""Process a field containing DataPoint(s), handling both single and list cases."""
excluded_properties.add(field_name)
relationship_key = _get_relationship_key(field_name, edge_metadata)
for index, datapoint in enumerate(datapoints):
property_key = _generate_property_key(data_point_id, relationship_key, str(datapoint.id))
if property_key in visited_properties:
continue
# For single datapoint, use field_name; for list, use field_name.index
field_identifier = field_name if len(datapoints) == 1 else f"{field_name}.{index}"
properties_to_visit.add(field_identifier)
async def get_graph_from_model(
data_point: DataPoint,
added_nodes: dict,
added_edges: dict,
visited_properties: dict = None,
include_root=True,
):
added_nodes: Dict[str, bool],
added_edges: Dict[str, bool],
visited_properties: Optional[Dict[str, bool]] = None,
include_root: bool = True,
) -> Tuple[List[DataPoint], List[Tuple[str, str, str, Dict[str, Any]]]]:
"""
Extract graph representation from a DataPoint model.
Args:
data_point: The DataPoint to extract graph from
added_nodes: Dictionary tracking already processed nodes
added_edges: Dictionary tracking already processed edges
visited_properties: Dictionary tracking visited properties to avoid cycles
include_root: Whether to include the root node in results
Returns:
Tuple of (nodes, edges) extracted from the model
"""
if str(data_point.id) in added_nodes:
return [], []
nodes = []
edges = []
visited_properties = visited_properties or {}
data_point_id = str(data_point.id)
data_point_properties = {
"type": type(data_point).__name__,
}
data_point_properties = {"type": type(data_point).__name__}
excluded_properties = set()
properties_to_visit = set()
# Analyze all fields to categorize them as properties or relationships
for field_name, field_value in data_point:
if field_name == "metadata":
continue
if isinstance(field_value, DataPoint):
excluded_properties.add(field_name)
field_type, actual_value, edge_metadata = _extract_field_info(field_value)
property_key = str(data_point.id) + field_name + str(field_value.id)
if field_type == "property":
data_point_properties[field_name] = field_value
elif field_type in ["single_datapoint", "single_datapoint_with_edge"]:
_process_datapoint_field(
data_point_id,
field_name,
[actual_value],
edge_metadata,
visited_properties,
properties_to_visit,
excluded_properties,
)
elif field_type in ["list_datapoint", "list_datapoint_with_edge"]:
_process_datapoint_field(
data_point_id,
field_name,
actual_value,
edge_metadata,
visited_properties,
properties_to_visit,
excluded_properties,
)
if property_key in visited_properties:
continue
visited_properties[property_key] = True
properties_to_visit.add(field_name)
continue
if (
isinstance(field_value, list)
and len(field_value) > 0
and isinstance(field_value[0], DataPoint)
):
excluded_properties.add(field_name)
for index, item in enumerate(field_value):
property_key = str(data_point.id) + field_name + str(item.id)
if property_key in visited_properties:
continue
visited_properties[property_key] = True
properties_to_visit.add(f"{field_name}.{index}")
continue
data_point_properties[field_name] = field_value
if include_root and str(data_point.id) not in added_nodes:
# Create node for current DataPoint if needed
if include_root and data_point_id not in added_nodes:
SimpleDataPointModel = copy_model(
type(data_point),
exclude_fields=list(excluded_properties),
type(data_point), exclude_fields=list(excluded_properties)
)
nodes.append(SimpleDataPointModel(**data_point_properties))
added_nodes[str(data_point.id)] = True
added_nodes[data_point_id] = True
for field_name in properties_to_visit:
index = None
if "." in field_name:
field_name, index = field_name.split(".")
# Process all relationships
for field_name_with_index in properties_to_visit:
# Parse field name and index
if "." in field_name_with_index:
field_name, index_str = field_name_with_index.split(".")
index = int(index_str)
else:
field_name, index = field_name_with_index, None
# Get field value and extract edge metadata
field_value = getattr(data_point, field_name)
edge_metadata = None
if (
isinstance(field_value, tuple)
and len(field_value) == 2
and isinstance(field_value[0], Edge)
):
edge_metadata, field_value = field_value
# Get specific datapoint - handle both single and list cases
if index is not None:
field_value = field_value[int(index)]
# List case: extract specific item by index
target_datapoint = field_value[index]
elif isinstance(field_value, list):
# Single datapoint case that was wrapped in a list
target_datapoint = field_value[0]
else:
# True single datapoint case
target_datapoint = field_value
edge_key = str(data_point.id) + str(field_value.id) + field_name
if str(edge_key) not in added_edges:
edges.append(
(
data_point.id,
field_value.id,
field_name,
{
"source_node_id": data_point.id,
"target_node_id": field_value.id,
"relationship_name": field_name,
"updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
},
)
# Create edge if not already added
edge_key = f"{data_point_id}{target_datapoint.id}{field_name}"
if edge_key not in added_edges:
relationship_name = _get_relationship_key(field_name, edge_metadata)
edge_properties = _create_edge_properties(
data_point.id, target_datapoint.id, relationship_name, edge_metadata
)
added_edges[str(edge_key)] = True
edges.append((data_point.id, target_datapoint.id, relationship_name, edge_properties))
added_edges[edge_key] = True
if str(field_value.id) in added_nodes:
continue
property_key = str(data_point.id) + field_name + str(field_value.id)
# Mark property as visited - CRITICAL for preventing infinite loops
relationship_key = _get_relationship_key(field_name, edge_metadata)
property_key = _generate_property_key(
data_point_id, relationship_key, str(target_datapoint.id)
)
visited_properties[property_key] = True
property_nodes, property_edges = await get_graph_from_model(
field_value,
include_root=True,
added_nodes=added_nodes,
added_edges=added_edges,
visited_properties=visited_properties,
)
for node in property_nodes:
nodes.append(node)
for edge in property_edges:
edges.append(edge)
# Recursively process target node if not already processed
if str(target_datapoint.id) not in added_nodes:
child_nodes, child_edges = await get_graph_from_model(
target_datapoint,
include_root=True,
added_nodes=added_nodes,
added_edges=added_edges,
visited_properties=visited_properties,
)
nodes.extend(child_nodes)
edges.extend(child_edges)
return nodes, edges
def get_own_property_nodes(property_nodes, property_edges):
own_properties = []
def get_own_property_nodes(
property_nodes: List[DataPoint], property_edges: List[Tuple[str, str, str, Dict[str, Any]]]
) -> List[DataPoint]:
"""
Filter nodes to return only those that are not destinations of any edges.
destination_nodes = [str(property_edge[1]) for property_edge in property_edges]
Args:
property_nodes: List of all nodes
property_edges: List of all edges
for node in property_nodes:
if str(node.id) in destination_nodes:
continue
own_properties.append(node)
return own_properties
Returns:
List of nodes that are not edge destinations
"""
destination_node_ids = {str(edge[1]) for edge in property_edges}
return [node for node in property_nodes if str(node.id) not in destination_node_ids]

View file

@ -8,6 +8,39 @@ async def retrieve_existing_edges(
data_chunks: list[DataPoint],
chunk_graphs: list[KnowledgeGraph],
) -> dict[str, bool]:
"""
- LLM generated docstring
Retrieve existing edges from the graph database to prevent duplicate edge creation.
This function checks which edges already exist in the graph database by querying
for various types of relationships including structural edges (exists_in, mentioned_in, is_a)
and content-derived edges from the knowledge graphs. It returns a mapping that can be
used to avoid creating duplicate edges during graph expansion.
Args:
data_chunks (list[DataPoint]): List of data point objects that serve as containers
for the entities. Each data chunk represents a source document or data segment.
chunk_graphs (list[KnowledgeGraph]): List of knowledge graphs corresponding to each
data chunk. Each graph contains nodes (entities) and edges (relationships) that
were extracted from the chunk content.
graph_engine (GraphDBInterface): Interface to the graph database that will be queried
to check for existing edges. Must implement the has_edges() method.
Returns:
dict[str, bool]: A mapping of edge keys to boolean values indicating existence.
Edge keys are formatted as concatenated strings: "{source_id}{target_id}{relationship_name}".
All values in the returned dictionary are True (indicating the edge exists).
Note:
- The function generates several types of edges for checking:
* Type node edges: (chunk_id, type_node_id, "exists_in")
* Entity node edges: (chunk_id, entity_node_id, "mentioned_in")
* Type-entity edges: (entity_node_id, type_node_id, "is_a")
* Graph node edges: extracted from the knowledge graph relationships
- Uses generate_node_id() to ensure consistent node ID formatting
- Prevents processing the same node multiple times using a processed_nodes tracker
- The returned mapping can be used with expand_with_nodes_and_edges() to avoid duplicates
"""
processed_nodes = {}
type_node_edges = []
entity_node_edges = []

View file

@ -3,7 +3,7 @@ from uuid import UUID
from decimal import Decimal
from datetime import datetime
from pydantic_core import PydanticUndefined
from pydantic import create_model
from pydantic import create_model, ConfigDict, BaseModel
from cognee.infrastructure.engine import DataPoint
@ -29,9 +29,15 @@ def copy_model(model: DataPoint, include_fields: dict = {}, exclude_fields: list
final_fields = {**fields, **include_fields}
model = create_model(model.__name__, **final_fields)
model.model_rebuild()
return model
# Create a base class with the same configuration as DataPoint
class ConfiguredBase(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
# Create the model inheriting from the configured base
new_model = create_model(model.__name__, __base__=ConfiguredBase, **final_fields)
new_model.model_rebuild()
return new_model
def get_own_properties(data_point: DataPoint):

View file

@ -53,7 +53,40 @@ async def cognee_network_visualization(graph_data, destination_file_path: str =
target = str(target)
G.add_edge(source, target)
edge_labels[(source, target)] = relation
links_list.append({"source": source, "target": target, "relation": relation})
# Extract edge metadata including all weights
all_weights = {}
primary_weight = None
if edge_info:
# Single weight (backward compatibility)
if "weight" in edge_info:
all_weights["default"] = edge_info["weight"]
primary_weight = edge_info["weight"]
# Multiple weights
if "weights" in edge_info and isinstance(edge_info["weights"], dict):
all_weights.update(edge_info["weights"])
# Use the first weight as primary for visual thickness if no default weight
if primary_weight is None and edge_info["weights"]:
primary_weight = next(iter(edge_info["weights"].values()))
# Individual weight fields (weight_strength, weight_confidence, etc.)
for key, value in edge_info.items():
if key.startswith("weight_") and isinstance(value, (int, float)):
weight_name = key[7:] # Remove "weight_" prefix
all_weights[weight_name] = value
link_data = {
"source": source,
"target": target,
"relation": relation,
"weight": primary_weight, # Primary weight for backward compatibility
"all_weights": all_weights, # All weights for display
"relationship_type": edge_info.get("relationship_type") if edge_info else None,
"edge_info": edge_info if edge_info else {},
}
links_list.append(link_data)
html_template = """
<!DOCTYPE html>
@ -66,13 +99,33 @@ async def cognee_network_visualization(graph_data, destination_file_path: str =
svg { width: 100vw; height: 100vh; display: block; }
.links line { stroke: rgba(255, 255, 255, 0.4); stroke-width: 2px; }
.links line.weighted { stroke: rgba(255, 215, 0, 0.7); }
.links line.multi-weighted { stroke: rgba(0, 255, 127, 0.8); }
.nodes circle { stroke: white; stroke-width: 0.5px; filter: drop-shadow(0 0 5px rgba(255,255,255,0.3)); }
.node-label { font-size: 5px; font-weight: bold; fill: white; text-anchor: middle; dominant-baseline: middle; font-family: 'Inter', sans-serif; pointer-events: none; }
.edge-label { font-size: 3px; fill: rgba(255, 255, 255, 0.7); text-anchor: middle; dominant-baseline: middle; font-family: 'Inter', sans-serif; pointer-events: none; }
.tooltip {
position: absolute;
text-align: left;
padding: 8px;
font-size: 12px;
background: rgba(0, 0, 0, 0.9);
color: white;
border: 1px solid rgba(255, 255, 255, 0.3);
border-radius: 4px;
pointer-events: none;
opacity: 0;
transition: opacity 0.2s;
z-index: 1000;
max-width: 300px;
word-wrap: break-word;
}
</style>
</head>
<body>
<svg></svg>
<div class="tooltip" id="tooltip"></div>
<script>
var nodes = {nodes};
var links = {links};
@ -82,6 +135,7 @@ async def cognee_network_visualization(graph_data, destination_file_path: str =
height = window.innerHeight;
var container = svg.append("g");
var tooltip = d3.select("#tooltip");
var simulation = d3.forceSimulation(nodes)
.force("link", d3.forceLink(links).id(d => d.id).strength(0.1))
@ -95,7 +149,58 @@ async def cognee_network_visualization(graph_data, destination_file_path: str =
.selectAll("line")
.data(links)
.enter().append("line")
.attr("stroke-width", 2);
.attr("stroke-width", d => {
if (d.weight) return Math.max(2, d.weight * 5);
if (d.all_weights && Object.keys(d.all_weights).length > 0) {
var avgWeight = Object.values(d.all_weights).reduce((a, b) => a + b, 0) / Object.values(d.all_weights).length;
return Math.max(2, avgWeight * 5);
}
return 2;
})
.attr("class", d => {
if (d.all_weights && Object.keys(d.all_weights).length > 1) return "multi-weighted";
if (d.weight || (d.all_weights && Object.keys(d.all_weights).length > 0)) return "weighted";
return "";
})
.on("mouseover", function(d) {
// Create tooltip content for edge
var content = "<strong>Edge Information</strong><br/>";
content += "Relationship: " + d.relation + "<br/>";
// Show all weights
if (d.all_weights && Object.keys(d.all_weights).length > 0) {
content += "<strong>Weights:</strong><br/>";
Object.keys(d.all_weights).forEach(function(weightName) {
content += "&nbsp;&nbsp;" + weightName + ": " + d.all_weights[weightName] + "<br/>";
});
} else if (d.weight !== null && d.weight !== undefined) {
content += "Weight: " + d.weight + "<br/>";
}
if (d.relationship_type) {
content += "Type: " + d.relationship_type + "<br/>";
}
// Add other edge properties
if (d.edge_info) {
Object.keys(d.edge_info).forEach(function(key) {
if (key !== 'weight' && key !== 'weights' && key !== 'relationship_type' &&
key !== 'source_node_id' && key !== 'target_node_id' &&
key !== 'relationship_name' && key !== 'updated_at' &&
!key.startsWith('weight_')) {
content += key + ": " + d.edge_info[key] + "<br/>";
}
});
}
tooltip.html(content)
.style("left", (d3.event.pageX + 10) + "px")
.style("top", (d3.event.pageY - 10) + "px")
.style("opacity", 1);
})
.on("mouseout", function(d) {
tooltip.style("opacity", 0);
});
var edgeLabels = container.append("g")
.attr("class", "edge-labels")
@ -103,7 +208,19 @@ async def cognee_network_visualization(graph_data, destination_file_path: str =
.data(links)
.enter().append("text")
.attr("class", "edge-label")
.text(d => d.relation);
.text(d => {
var label = d.relation;
if (d.all_weights && Object.keys(d.all_weights).length > 1) {
// Show count of weights for multiple weights
label += " (" + Object.keys(d.all_weights).length + " weights)";
} else if (d.weight) {
label += " (" + d.weight + ")";
} else if (d.all_weights && Object.keys(d.all_weights).length === 1) {
var singleWeight = Object.values(d.all_weights)[0];
label += " (" + singleWeight + ")";
}
return label;
});
var nodeGroup = container.append("g")
.attr("class", "nodes")
@ -178,7 +295,7 @@ async def cognee_network_visualization(graph_data, destination_file_path: str =
</script>
<svg style="position: fixed; bottom: 10px; right: 10px; width: 150px; height: auto; z-index: 9999;" viewBox="0 0 158 44" fill="none" xmlns="http://www.w3.org/2000/svg">
<path fill-rule="evenodd" clip-rule="evenodd" d="M11.7496 4.92654C7.83308 4.92654 4.8585 7.94279 4.8585 11.3612V14.9304C4.8585 18.3488 7.83308 21.3651 11.7496 21.3651C13.6831 21.3651 15.0217 20.8121 16.9551 19.3543C18.0458 18.5499 19.5331 18.8013 20.3263 19.9072C21.1195 21.0132 20.8717 22.5213 19.781 23.3257C17.3518 25.0851 15.0217 26.2414 11.7 26.2414C5.35425 26.2414 0 21.2646 0 14.9304V11.3612C0 4.97681 5.35425 0.0502739 11.7 0.0502739C15.0217 0.0502739 17.3518 1.2065 19.781 2.96598C20.8717 3.77032 21.1195 5.27843 20.3263 6.38439C19.5331 7.49035 18.0458 7.69144 16.9551 6.93737C15.0217 5.52979 13.6831 4.92654 11.7496 4.92654ZM35.5463 4.92654C31.7289 4.92654 28.6552 8.04333 28.6552 11.8639V14.478C28.6552 18.2986 31.7289 21.4154 35.5463 21.4154C39.3141 21.4154 42.3878 18.2986 42.3878 14.478V11.8639C42.3878 8.04333 39.3141 4.92654 35.5463 4.92654ZM23.7967 11.8639C23.7967 5.32871 29.0518 0 35.5463 0C42.0408 0 47.2463 5.32871 47.2463 11.8639V14.478C47.2463 21.0132 42.0408 26.3419 35.5463 26.3419C29.0518 26.3419 23.7967 21.0635 23.7967 14.478V11.8639ZM63.3091 5.07736C59.4917 5.07736 56.418 8.19415 56.418 12.0147C56.418 15.8353 59.4917 18.9521 63.3091 18.9521C67.1265 18.9521 70.1506 15.8856 70.1506 12.0147C70.1506 8.14388 67.0769 5.07736 63.3091 5.07736ZM51.5595 11.9645C51.5595 5.42925 56.8146 0.150814 63.3091 0.150814C66.0854 0.150814 68.5642 1.10596 70.5968 2.71463L72.4311 0.904876C73.3731 -0.0502693 74.9099 -0.0502693 75.8519 0.904876C76.7938 1.86002 76.7938 3.41841 75.8519 4.37356L73.7201 6.53521C74.5629 8.19414 75.0587 10.0542 75.0587 12.0147C75.0587 18.4997 69.8532 23.8284 63.3587 23.8284C63.3091 23.8284 63.2099 23.8284 63.1603 23.8284H58.0044C57.1616 23.8284 56.4675 24.5322 56.4675 25.3868C56.4675 26.2414 57.1616 26.9452 58.0044 26.9452H64.6476H66.7794C68.5146 26.9452 70.3489 27.4479 71.7866 28.6041C73.2739 29.8106 74.2159 31.5701 74.4142 33.7317C74.7116 37.6026 72.0345 40.2166 69.8532 41.0713L63.8048 43.7859C62.5654 44.3389 61.1277 43.7859 60.6319 42.5291C60.0866 41.2723 60.6319 39.8648 61.8714 39.3118L68.0188 36.5972C68.0684 36.5972 68.118 36.5469 68.1675 36.5469C68.4154 36.4463 68.8616 36.1447 69.2087 35.6923C69.5061 35.2398 69.7044 34.7371 69.6548 34.1339C69.6053 33.229 69.2582 32.7263 68.8616 32.4247C68.4154 32.0728 67.7214 31.8214 66.8786 31.8214H58.2027C58.1531 31.8214 58.1531 31.8214 58.1035 31.8214H58.054C54.534 31.8214 51.6586 28.956 51.6586 25.3868C51.6586 23.0743 52.8485 21.0635 54.6828 19.9072C52.6997 17.7959 51.5595 15.031 51.5595 11.9645ZM90.8736 5.07736C87.0562 5.07736 83.9824 8.19415 83.9824 12.0147V23.9289C83.9824 25.2862 82.8917 26.3922 81.5532 26.3922C80.2146 26.3922 79.1239 25.2862 79.1239 23.9289V11.9645C79.1239 5.42925 84.379 0.150814 90.824 0.150814C97.2689 0.150814 102.524 5.42925 102.524 11.9645V23.8786C102.524 25.2359 101.433 26.3419 100.095 26.3419C98.7562 26.3419 97.6655 25.2359 97.6655 23.8786V11.9645C97.7647 8.14387 94.6414 5.07736 90.8736 5.07736ZM119.43 5.07736C115.513 5.07736 112.39 8.24441 112.39 12.065V14.5785C112.39 18.4494 115.513 21.5662 119.43 21.5662C120.768 21.5662 122.057 21.164 123.098 20.5105C124.238 19.8067 125.726 20.1586 126.42 21.3148C127.114 22.4711 126.767 23.9792 125.627 24.683C123.842 25.7889 121.71 26.4425 119.43 26.4425C112.885 26.4425 107.581 21.1137 107.581 14.5785V12.065C107.581 5.47952 112.935 0.201088 119.43 0.201088C125.032 0.201088 129.692 4.07194 130.931 9.3001L131.427 11.3612L121.115 15.584C119.876 16.0867 118.488 15.4834 117.942 14.2266C117.447 12.9699 118.041 11.5623 119.281 11.0596L125.478 8.54604C124.238 6.43466 122.008 5.07736 119.43 5.07736ZM146.003 5.07736C142.086 5.07736 138.963 8.24441 138.963 12.065V14.5785C138.963 18.4494 142.086 21.5662 146.003 21.5662C147.341 21.5662 148.63 21.164 149.671 20.5105C150.217 20.1586 150.663 19.8067 151.109 19.304C152.001 18.2986 153.538 18.2483 154.53 19.2034C155.521 20.1083 155.571 21.6667 154.629 22.6721C153.935 23.4262 153.092 24.13 152.2 24.683C150.415 25.7889 148.283 26.4425 146.003 26.4425C139.458 26.4425 134.154 21.1137 134.154 14.5785V12.065C134.154 5.47952 139.508 0.201088 146.003 0.201088C151.605 0.201088 156.265 4.07194 157.504 9.3001L158 11.3612L147.688 15.584C146.449 16.0867 145.061 15.4834 144.515 14.2266C144.019 12.9699 144.614 11.5623 145.854 11.0596L152.051 8.54604C150.762 6.43466 148.58 5.07736 146.003 5.07736Z" fill="white"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M11.7496 4.92654C7.83308 4.92654 4.8585 7.94279 4.8585 11.3612V14.9304C4.8585 18.3488 7.83308 21.3651 11.7496 21.3651C13.6831 21.3651 15.0217 20.8121 16.9551 19.3543C18.0458 18.5499 19.5331 18.8013 20.3263 19.9072C21.1195 21.0132 20.8717 22.5213 19.781 23.3257C17.3518 25.0851 15.0217 26.2414 11.7 26.2414C5.35425 26.2414 0 21.2646 0 14.9304V11.3612C0 4.97681 5.35425 0.0502739 11.7 0.0502739C15.0217 0.0502739 17.3518 1.2065 19.781 2.96598C20.8717 3.77032 21.1195 5.27843 20.3263 6.38439C19.5331 7.49035 18.0458 7.69144 16.9551 6.93737C15.0217 5.52979 13.6831 4.92654 11.7496 4.92654ZM35.5463 4.92654C31.7289 4.92654 28.6552 8.04333 28.6552 11.8639V14.478C28.6552 18.2986 31.7289 21.4154 35.5463 21.4154C39.3141 21.4154 42.3878 18.2986 42.3878 14.478V11.8639C42.3878 8.04333 39.3141 4.92654 35.5463 4.92654ZM23.7967 11.8639C23.7967 5.32871 29.0518 0 35.5463 0C42.0408 0 47.2463 5.32871 47.2463 11.8639V14.478C47.2463 21.0132 42.0408 26.3419 35.5463 26.3419C29.0518 26.3419 23.7967 21.0635 23.7967 14.478V11.8639ZM63.3091 5.07736C59.4917 5.07736 56.418 8.19415 56.418 12.0147C56.418 15.8353 59.4917 18.9521 63.3091 18.9521C67.1265 18.9521 70.1506 15.8856 70.1506 12.0147C70.1506 8.14388 67.0769 5.07736 63.3091 5.07736ZM51.5595 11.9645C51.5595 5.42925 56.8146 0.150814 63.3091 0.150814C66.0854 0.150814 68.5642 1.10596 70.5968 2.71463L72.4311 0.904876C73.3731 -0.0502693 74.9099 -0.0502693 75.8519 0.904876C76.7938 1.86002 76.7938 3.41841 75.8519 4.37356L73.7201 6.53521C74.5629 8.19414 75.0587 10.0542 75.0587 12.0147C75.0587 18.4997 69.8532 23.8284 63.3587 23.8284C63.3091 23.8284 63.2099 23.8284 63.1603 23.8284H58.0044C57.1616 23.8284 56.4675 24.5322 56.4675 25.3868C56.4675 26.2414 57.1616 26.9452 58.0044 26.9452H64.6476H66.7794C68.5146 26.9452 70.3489 27.4479 71.7866 28.6041C73.2739 29.8106 74.2159 31.5701 74.4142 33.7317C74.7116 37.6026 72.0345 40.2166 69.8532 41.0713L63.8048 43.7859C62.5654 44.3389 61.1277 43.7859 60.6319 42.5291C60.0866 41.2723 60.6319 39.8648 61.8714 39.3118L68.0188 36.5972C68.0684 36.5972 68.118 36.5469 68.1675 36.5469C68.4154 36.4463 68.8616 36.1447 69.2087 35.6923C69.5061 35.2398 69.7044 34.7371 69.6548 34.1339C69.6053 33.229 69.2582 32.7263 68.8616 32.4247C68.4154 32.0728 67.7214 31.8214 66.8786 31.8214H58.2027C58.1531 31.8214 58.1531 31.8214 58.1035 31.8214H58.054C54.534 31.8214 51.6586 28.956 51.6586 25.3868C51.6586 23.0743 52.8485 21.0635 54.6828 19.9072C52.6997 17.7959 51.5595 15.031 51.5595 11.9645ZM90.8736 5.07736C87.0562 5.07736 83.9824 8.19415 83.9824 12.0147V23.9289C83.9824 25.2862 82.8917 26.3922 81.5532 26.3922C80.2146 26.3922 79.1239 25.2862 79.1239 23.9289V11.9645C79.1239 5.42925 84.379 0.150814 90.824 0.150814C97.2689 0.150814 102.524 5.42925 102.524 11.9645V23.8786C102.524 25.2359 101.433 26.3419 100.095 26.3419C98.7562 26.3419 97.6655 25.2359 97.6655 23.8786V11.9645C97.7647 8.14387 94.6414 5.07736 90.8736 5.07736ZM119.43 5.07736C115.513 5.07736 112.39 8.24441 112.39 12.065V14.5785C112.39 18.4494 115.513 21.5662 119.43 21.5662C120.768 21.5662 122.057 21.164 123.098 20.5105C124.238 19.8067 125.726 20.1586 126.42 21.3148C127.114 22.4711 126.767 23.9792 125.627 24.683C123.842 25.7889 121.71 26.4425 119.43 26.4425C112.885 26.4425 107.581 21.1137 107.581 14.5785V12.065C107.581 5.47952 112.935 0.201088 119.43 0.201088C125.032 0.201088 129.692 4.07194 130.931 9.3001L131.427 11.3612L121.115 15.584C119.876 16.0867 118.488 15.4834 117.942 14.2266C117.447 12.9699 118.041 11.5623 119.281 11.0596L125.478 8.54604C124.238 6.43466 122.008 5.07736 119.43 5.07736ZM146.003 5.07736C142.086 5.07736 138.963 8.24441 138.963 12.065V14.5785C138.963 18.4494 142.086 21.5662 146.003 21.5662C147.341 21.5662 148.630 21.164 149.671 20.5105C150.217 20.1586 150.663 19.8067 151.109 19.304C152.001 18.2986 153.538 18.2483 154.53 19.2034C155.521 20.1083 155.571 21.6667 154.629 22.6721C153.935 23.4262 153.092 24.13 152.2 24.683C150.415 25.7889 148.283 26.4425 146.003 26.4425C139.458 26.4425 134.154 21.1137 134.154 14.5785V12.065C134.154 5.47952 139.508 0.201088 146.003 0.201088C151.605 0.201088 156.265 4.07194 157.504 9.3001L158 11.3612L147.688 15.584C146.449 16.0867 145.061 15.4834 144.515 14.2266C144.019 12.9699 144.614 11.5623 145.854 11.0596L152.051 8.54604C150.762 6.43466 148.58 5.07736 146.003 5.07736Z" fill="white"/>
</svg>
</body>
</html>

View file

@ -0,0 +1,391 @@
import pytest
from typing import List, Any
from pydantic import SkipValidation
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.engine.models.Edge import Edge
from cognee.modules.graph.utils import get_graph_from_model
class Product(DataPoint):
name: str
description: str
metadata: dict = {"index_fields": ["name"]}
class Category(DataPoint):
name: str
description: str
products: List[Product] = []
metadata: dict = {"index_fields": ["name"]}
class User(DataPoint):
name: str
email: str
# Weighted relationships
purchased_products: SkipValidation[Any] = None # (Edge, list[Product])
favorite_categories: SkipValidation[Any] = None # (Edge, list[Category])
follows: SkipValidation[Any] = None # (Edge, list["User"])
metadata: dict = {"index_fields": ["name", "email"]}
class Company(DataPoint):
name: str
description: str
employees: SkipValidation[Any] = None # (Edge, list[User])
partners: SkipValidation[Any] = None # (Edge, list["Company"])
metadata: dict = {"index_fields": ["name"]}
@pytest.mark.asyncio
async def test_single_weight_edge():
"""Test get_graph_from_model with single weight edges (backward compatible)"""
product1 = Product(name="Laptop", description="Gaming laptop")
product2 = Product(name="Mouse", description="Wireless mouse")
user = User(
name="John Doe",
email="john@example.com",
purchased_products=(Edge(weight=0.8, relationship_type="purchased"), [product1, product2]),
)
added_nodes = {}
added_edges = {}
visited_properties = {}
nodes, edges = await get_graph_from_model(user, added_nodes, added_edges, visited_properties)
# Should have user + 2 products = 3 nodes
assert len(nodes) == 3, f"Expected 3 nodes, got {len(nodes)}"
# Should have 2 edges (user -> product1, user -> product2)
assert len(edges) == 2, f"Expected 2 edges, got {len(edges)}"
# Check edge properties contain weight
for edge in edges:
source_id, target_id, relationship_name, edge_properties = edge
assert "weight" in edge_properties, "Edge should contain weight property"
assert edge_properties["weight"] == 0.8, (
f"Expected weight 0.8, got {edge_properties['weight']}"
)
assert edge_properties["relationship_name"] == "purchased"
@pytest.mark.asyncio
async def test_multiple_weights_edge():
"""Test get_graph_from_model with multiple weights on edges"""
category1 = Category(name="Electronics", description="Electronic products")
category2 = Category(name="Gaming", description="Gaming products")
user = User(
name="Alice Smith",
email="alice@example.com",
favorite_categories=(
Edge(
weights={
"interest_level": 0.9,
"time_spent": 0.7,
"purchase_frequency": 0.8,
"expertise": 0.6,
},
relationship_type="interested_in",
),
[category1, category2],
),
)
added_nodes = {}
added_edges = {}
visited_properties = {}
nodes, edges = await get_graph_from_model(user, added_nodes, added_edges, visited_properties)
# Should have user + 2 categories = 3 nodes
assert len(nodes) == 3, f"Expected 3 nodes, got {len(nodes)}"
# Should have 2 edges
assert len(edges) == 2, f"Expected 2 edges, got {len(edges)}"
# Check edge properties contain multiple weights
for edge in edges:
source_id, target_id, relationship_name, edge_properties = edge
assert edge_properties["relationship_name"] == "interested_in"
# Check individual weight fields
assert "weight_interest_level" in edge_properties
assert "weight_time_spent" in edge_properties
assert "weight_purchase_frequency" in edge_properties
assert "weight_expertise" in edge_properties
assert edge_properties["weight_interest_level"] == 0.9
assert edge_properties["weight_time_spent"] == 0.7
assert edge_properties["weight_purchase_frequency"] == 0.8
assert edge_properties["weight_expertise"] == 0.6
@pytest.mark.asyncio
async def test_mixed_single_and_multiple_weights():
"""Test get_graph_from_model with both single weight and multiple weights on same edge"""
product = Product(name="Smartphone", description="Latest smartphone")
user = User(
name="Bob Wilson",
email="bob@example.com",
purchased_products=(
Edge(
weight=0.7, # Single weight (backward compatible)
weights={"satisfaction": 0.9, "value_for_money": 0.6}, # Multiple weights
relationship_type="owns",
),
[product],
),
)
added_nodes = {}
added_edges = {}
visited_properties = {}
nodes, edges = await get_graph_from_model(user, added_nodes, added_edges, visited_properties)
assert len(nodes) == 2, f"Expected 2 nodes, got {len(nodes)}"
assert len(edges) == 1, f"Expected 1 edge, got {len(edges)}"
edge_properties = edges[0][3]
# Should have both single weight and multiple weights
assert "weight" in edge_properties, "Should have backward compatible weight field"
assert edge_properties["weight"] == 0.7
assert "weight_satisfaction" in edge_properties
assert "weight_value_for_money" in edge_properties
assert edge_properties["weight_satisfaction"] == 0.9
assert edge_properties["weight_value_for_money"] == 0.6
@pytest.mark.asyncio
async def test_complex_weighted_relationships():
"""Test complex scenario with multiple entities and various weighted relationships"""
# Create products and categories
product1 = Product(name="Gaming Chair", description="Ergonomic gaming chair")
product2 = Product(name="Mechanical Keyboard", description="RGB mechanical keyboard")
category = Category(name="Gaming Accessories", description="Gaming accessories category")
category.products = [product1, product2]
# Create users with different weighted relationships
user1 = User(
name="Gamer Pro",
email="gamerpro@example.com",
purchased_products=(
Edge(
weights={
"satisfaction": 0.95,
"frequency_of_use": 0.9,
"recommendation_likelihood": 0.8,
},
relationship_type="purchased",
),
[product1, product2],
),
favorite_categories=(Edge(weight=0.9, relationship_type="follows"), [category]),
)
user2 = User(
name="Casual User",
email="casual@example.com",
purchased_products=(Edge(weight=0.6, relationship_type="purchased"), [product1]),
)
# Create weighted user relationships
user1.follows = (
Edge(
weights={
"friendship_level": 0.7,
"shared_interests": 0.8,
"communication_frequency": 0.5,
},
relationship_type="follows",
),
[user2],
)
added_nodes = {}
added_edges = {}
visited_properties = {}
# Process user1 (which should process all connected nodes)
nodes, edges = await get_graph_from_model(user1, added_nodes, added_edges, visited_properties)
# Should have: user1, user2, 2 products, 1 category = 5 nodes
assert len(nodes) == 5, f"Expected 5 nodes, got {len(nodes)}"
# Should have multiple edges with different weight configurations
assert len(edges) > 0, "Should have edges"
# Verify that different edge types are created correctly
edge_types = set()
weighted_edges = 0
multi_weighted_edges = 0
for edge in edges:
source_id, target_id, relationship_name, edge_properties = edge
edge_types.add(relationship_name)
if "weight" in edge_properties:
weighted_edges += 1
# Count edges with multiple weights
multi_weight_fields = [k for k in edge_properties.keys() if k.startswith("weight_")]
if len(multi_weight_fields) > 1:
multi_weighted_edges += 1
assert "purchased" in edge_types
assert "follows" in edge_types
assert weighted_edges > 0, "Should have edges with weights"
assert multi_weighted_edges > 0, "Should have edges with multiple weights"
@pytest.mark.asyncio
async def test_company_hierarchy_with_weights():
"""Test hierarchical company structure with weighted relationships"""
# Create users
ceo = User(name="CEO", email="ceo@company.com")
manager = User(name="Manager", email="manager@company.com")
developer = User(name="Developer", email="dev@company.com")
# Create companies with weighted employee relationships
startup = Company(
name="Tech Startup",
description="Innovative tech startup",
employees=(
Edge(
weights={"seniority": 0.9, "performance": 0.8, "leadership": 0.95},
relationship_type="employs",
),
[ceo, manager, developer],
),
)
corporation = Company(name="Big Corp", description="Large corporation")
# Create partnership with weights
startup.partners = (
Edge(
weights={"trust_level": 0.7, "business_value": 0.8, "strategic_importance": 0.6},
relationship_type="partners_with",
),
[corporation],
)
added_nodes = {}
added_edges = {}
visited_properties = {}
nodes, edges = await get_graph_from_model(startup, added_nodes, added_edges, visited_properties)
# Should have: startup, corporation, 3 users = 5 nodes
assert len(nodes) == 5, f"Expected 5 nodes, got {len(nodes)}"
# Verify weighted relationships are properly stored
partnership_edges = [e for e in edges if e[2] == "partners_with"]
employee_edges = [e for e in edges if e[2] == "employs"]
assert len(partnership_edges) == 1, "Should have one partnership edge"
assert len(employee_edges) == 3, "Should have three employee edges"
# Check partnership edge weights
partnership_props = partnership_edges[0][3]
assert "weight_trust_level" in partnership_props
assert "weight_business_value" in partnership_props
assert "weight_strategic_importance" in partnership_props
# Check employee edge weights
for edge in employee_edges:
props = edge[3]
assert "weight_seniority" in props
assert "weight_performance" in props
assert "weight_leadership" in props
@pytest.mark.asyncio
async def test_edge_metadata_preservation():
"""Test that all edge metadata is preserved correctly in weighted edges"""
product = Product(name="Test Product", description="A test product")
user = User(
name="Test User",
email="test@example.com",
purchased_products=(
Edge(weight=0.8, weights={"quality": 0.9, "price": 0.7}, relationship_type="purchased"),
[product],
),
)
added_nodes = {}
added_edges = {}
visited_properties = {}
nodes, edges = await get_graph_from_model(user, added_nodes, added_edges, visited_properties)
assert len(edges) == 1, "Should have exactly one edge"
edge_properties = edges[0][3]
# Check all required metadata is present
assert "source_node_id" in edge_properties
assert "target_node_id" in edge_properties
assert "relationship_name" in edge_properties
assert "updated_at" in edge_properties
# Check relationship type
assert edge_properties["relationship_name"] == "purchased"
# Check weights are properly stored
assert "weight" in edge_properties
assert edge_properties["weight"] == 0.8
assert "weight_quality" in edge_properties
assert edge_properties["weight_quality"] == 0.9
assert "weight_price" in edge_properties
assert edge_properties["weight_price"] == 0.7
@pytest.mark.asyncio
async def test_no_weights_edge():
"""Test that edges without weights still work correctly"""
product = Product(name="Simple Product", description="No weights product")
user = User(
name="Simple User",
email="simple@example.com",
purchased_products=(Edge(relationship_type="purchased"), [product]),
)
added_nodes = {}
added_edges = {}
visited_properties = {}
nodes, edges = await get_graph_from_model(user, added_nodes, added_edges, visited_properties)
assert len(nodes) == 2, f"Expected 2 nodes, got {len(nodes)}"
assert len(edges) == 1, f"Expected 1 edge, got {len(edges)}"
edge_properties = edges[0][3]
# Should have basic metadata but no weights
assert "source_node_id" in edge_properties
assert "target_node_id" in edge_properties
assert "relationship_name" in edge_properties
assert "updated_at" in edge_properties
assert edge_properties["relationship_name"] == "purchased"
# Should not have weight fields
assert "weight" not in edge_properties
weight_fields = [k for k in edge_properties.keys() if k.startswith("weight_")]
assert len(weight_fields) == 0, f"Should have no weight fields, but found: {weight_fields}"

View file

@ -0,0 +1,137 @@
import asyncio
from os import path
from typing import Any
from pydantic import SkipValidation
from cognee.api.v1.visualize.visualize import visualize_graph
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.engine.models.Edge import Edge
from cognee.tasks.storage import add_data_points
import cognee
class Clothes(DataPoint):
name: str
description: str
class Object(DataPoint):
name: str
description: str
has_clothes: list[Clothes]
class Person(DataPoint):
name: str
description: str
has_items: SkipValidation[Any] # (Edge, list[Clothes])
has_objects: SkipValidation[Any] # (Edge, list[Object])
knows: SkipValidation[Any] # (Edge, list["Person"])
async def main():
# Clear the database for a clean state
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
# Create clothes items
item1 = Clothes(name="Shirt", description="A blue shirt")
item2 = Clothes(name="Pants", description="Black pants")
item3 = Clothes(name="Jacket", description="Leather jacket")
# Create object with simple relationship to clothes
object1 = Object(
name="Closet", description="A wooden closet", has_clothes=[item1, item2, item3]
)
# Create people with various weighted relationships
person1 = Person(
name="John",
description="A software engineer",
# Single weight (backward compatible)
has_items=(Edge(weight=0.8, relationship_type="owns"), [item1, item2]),
# Simple relationship without weights
has_objects=(Edge(relationship_type="stores_in"), [object1]),
knows=[],
)
person2 = Person(
name="Alice",
description="A designer",
# Multiple weights on edge
has_items=(
Edge(
weights={
"ownership": 0.9,
"frequency_of_use": 0.7,
"emotional_attachment": 0.8,
"monetary_value": 0.6,
},
relationship_type="owns",
),
[item3],
),
has_objects=(Edge(relationship_type="uses"), [object1]),
knows=[],
)
person3 = Person(
name="Bob",
description="A friend",
# Mixed: single weight + multiple weights
has_items=(
Edge(
weight=0.5, # Default weight
weights={"trust_level": 0.9, "communication_frequency": 0.6},
relationship_type="borrows",
),
[item1],
),
has_objects=[],
knows=[],
)
# Create relationships between people with multiple weights
person1.knows = (
Edge(
weights={
"friendship_strength": 0.9,
"trust_level": 0.8,
"years_known": 0.7,
"shared_interests": 0.6,
},
relationship_type="friend",
),
[person2, person3],
)
person2.knows = (
Edge(
weights={"professional_collaboration": 0.8, "personal_friendship": 0.6},
relationship_type="colleague",
),
[person1],
)
all_data_points = [item1, item2, item3, object1, person1, person2, person3]
# Add data points to the graph
await add_data_points(all_data_points)
# Visualize the graph
graph_visualization_path = path.join(
path.dirname(__file__), "weighted_graph_visualization.html"
)
await visualize_graph(graph_visualization_path)
print("Graph with multiple weighted edges has been created and visualized!")
print(f"Visualization saved to: {graph_visualization_path}")
print("\nFeatures demonstrated:")
print("- Single weight edges (backward compatible)")
print("- Multiple weights on single edges")
print("- Mixed single + multiple weights")
print("- Hover over edges to see all weight information")
print("- Different visual styling for single vs. multiple weighted edges")
if __name__ == "__main__":
asyncio.run(main())

File diff suppressed because one or more lines are too long