diff --git a/cognee/api/v1/add/add.py b/cognee/api/v1/add/add.py
index b1c850965..8bfc37ac9 100644
--- a/cognee/api/v1/add/add.py
+++ b/cognee/api/v1/add/add.py
@@ -1,4 +1,4 @@
-from typing import Union, BinaryIO
+from typing import Union, BinaryIO, List, Optional
 from cognee.modules.users.models import User
 from cognee.modules.users.methods import get_default_user
 from cognee.modules.pipelines import run_tasks, Task
@@ -16,6 +16,7 @@ async def add(
     data: Union[BinaryIO, list[BinaryIO], str, list[str]],
     dataset_name: str = "main_dataset",
     user: User = None,
+    NodeSet: Optional[List[str]] = None,
 ):
     # Create tables for databases
     await create_relational_db_and_tables()
@@ -36,7 +37,7 @@ async def add(
     if user is None:
         user = await get_default_user()
 
-    tasks = [Task(resolve_data_directories), Task(ingest_data, dataset_name, user)]
+    tasks = [Task(resolve_data_directories), Task(ingest_data, dataset_name, user, NodeSet)]
 
     dataset_id = uuid5(NAMESPACE_OID, dataset_name)
     pipeline = run_tasks(
diff --git a/cognee/api/v1/cognify/cognify.py b/cognee/api/v1/cognify/cognify.py
index a5912fecd..e28b91f8f 100644
--- a/cognee/api/v1/cognify/cognify.py
+++ b/cognee/api/v1/cognify/cognify.py
@@ -25,6 +25,7 @@ from cognee.tasks.documents import (
 from cognee.tasks.graph import extract_graph_from_data
 from cognee.tasks.storage import add_data_points
 from cognee.tasks.summarization import summarize_text
+from cognee.tasks.node_set import apply_node_set
 from cognee.modules.chunking.TextChunker import TextChunker
 
 logger = get_logger("cognify")
@@ -139,6 +140,7 @@ async def get_default_tasks(  # TODO: Find out a better way to do this (Boris's
             task_config={"batch_size": 10},
         ),
         Task(add_data_points, task_config={"batch_size": 10}),
+        Task(apply_node_set, task_config={"batch_size": 10}),  # Apply NodeSet values to DataPoints
     ]
 
     return default_tasks
diff --git a/cognee/infrastructure/engine/models/DataPoint.py b/cognee/infrastructure/engine/models/DataPoint.py
index a315f95f1..3abadd51c 100644
--- a/cognee/infrastructure/engine/models/DataPoint.py
+++ b/cognee/infrastructure/engine/models/DataPoint.py
@@ -1,5 +1,5 @@
 from datetime import datetime, timezone
-from typing import Optional, Any, Dict
+from typing import Optional, Any, Dict, List
 from uuid import UUID, uuid4
 
 from pydantic import BaseModel, Field
@@ -27,6 +27,7 @@ class DataPoint(BaseModel):
     topological_rank: Optional[int] = 0
     metadata: Optional[MetaData] = {"index_fields": []}
     type: str = Field(default_factory=lambda: DataPoint.__name__)
+    NodeSet: Optional[List[str]] = None  # Names of the node sets this data point belongs to
 
     def __init__(self, **data):
         super().__init__(**data)
diff --git a/cognee/modules/data/models/Data.py b/cognee/modules/data/models/Data.py
index bbfdbed32..422e013df 100644
--- a/cognee/modules/data/models/Data.py
+++ b/cognee/modules/data/models/Data.py
@@ -20,6 +20,7 @@ class Data(Base):
     owner_id = Column(UUID, index=True)
     content_hash = Column(String)
     external_metadata = Column(JSON)
+    node_set = Column(JSON, nullable=True)  # Stores the NodeSet as a JSON list of strings
     token_count = Column(Integer)
     created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
     updated_at = Column(DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc))
@@ -44,5 +45,6 @@ class Data(Base):
             "rawDataLocation": self.raw_data_location,
             "createdAt": self.created_at.isoformat(),
             "updatedAt": self.updated_at.isoformat() if self.updated_at else None,
+            "nodeSet": self.node_set,
             # "datasets": [dataset.to_json() for dataset in self.datasets]
         }
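
For reference, a minimal end-to-end usage sketch of the new parameter. It assumes the package re-exports add and cognify at the top level (import cognee), as the upstream cognee repo does; the text, dataset name, and label values are illustrative.

import asyncio
import cognee

async def main():
    # The labels are persisted on the Data row (node_set JSON column) during
    # ingestion and later copied onto DataPoints by the apply_node_set task.
    await cognee.add(
        "Natural language processing is a subfield of computer science.",
        dataset_name="main_dataset",
        NodeSet=["nlp", "research"],  # illustrative label values
    )
    await cognee.cognify()

asyncio.run(main())
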
diff --git a/cognee/modules/visualization/cognee_network_visualization.py b/cognee/modules/visualization/cognee_network_visualization.py
index 86ad7ce49..cf4162044 100644
--- a/cognee/modules/visualization/cognee_network_visualization.py
+++ b/cognee/modules/visualization/cognee_network_visualization.py
@@ -1,4 +1,4 @@
-from cognee.shared.logging_utils import get_logger
+import logging
 import networkx as nx
 import json
 import os
@@ -6,7 +6,7 @@ import os
 
 from cognee.infrastructure.files.storage import LocalStorage
 
-logger = get_logger()
+logger = logging.getLogger(__name__)
 
 
 async def cognee_network_visualization(graph_data, destination_file_path: str = None):
@@ -20,8 +20,6 @@ async def cognee_network_visualization(graph_data, destination_file_path: str =
         "EntityType": "#6510f4",
         "DocumentChunk": "#801212",
         "TextSummary": "#1077f4",
-        "TableRow": "#f47710",
-        "TableType": "#6510f4",
         "default": "#D3D3D3",
     }
 
@@ -184,8 +182,8 @@ async def cognee_network_visualization(graph_data, destination_file_path: str =
     """
 
-    html_content = html_template.replace("{nodes}", json.dumps(nodes_list, default=str))
-    html_content = html_content.replace("{links}", json.dumps(links_list, default=str))
+    html_content = html_template.replace("{nodes}", json.dumps(nodes_list))
+    html_content = html_content.replace("{links}", json.dumps(links_list))
 
     if not destination_file_path:
         home_dir = os.path.expanduser("~")
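
The node_set column stores a JSON-encoded list of strings (or NULL), matching the json.dumps(NodeSet) if NodeSet else None convention used in ingest_data below, so consumers must decode it before use. A small sketch of that round trip; read_node_set is a hypothetical helper, not part of this patch.

import json
from typing import List, Optional

def read_node_set(data_row) -> Optional[List[str]]:
    # node_set holds json.dumps(NodeSet) when labels were supplied, else None.
    return json.loads(data_row.node_set) if data_row.node_set else None
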
diff --git a/cognee/tasks/ingestion/ingest_data.py b/cognee/tasks/ingestion/ingest_data.py
index 78475d106..e345ff408 100644
--- a/cognee/tasks/ingestion/ingest_data.py
+++ b/cognee/tasks/ingestion/ingest_data.py
@@ -1,4 +1,4 @@
-from typing import Any, List
+from typing import Any, List, Optional
 
 import dlt
 import cognee.modules.ingestion as ingestion
@@ -12,9 +12,10 @@ from .save_data_item_to_storage import save_data_item_to_storage
 
 from typing import Union, BinaryIO
 import inspect
+import json
 
 
-async def ingest_data(data: Any, dataset_name: str, user: User):
+async def ingest_data(data: Any, dataset_name: str, user: User, NodeSet: Optional[List[str]] = None):
     destination = get_dlt_destination()
 
     pipeline = dlt.pipeline(
@@ -43,9 +44,10 @@ async def ingest_data(data: Any, dataset_name: str, user: User):
                     "mime_type": file_metadata["mime_type"],
                     "content_hash": file_metadata["content_hash"],
                     "owner_id": str(user.id),
+                    "node_set": json.dumps(NodeSet) if NodeSet else None,
                 }
 
-    async def store_data_to_dataset(data: Any, dataset_name: str, user: User):
+    async def store_data_to_dataset(data: Any, dataset_name: str, user: User, NodeSet: Optional[List[str]] = None):
         if not isinstance(data, list):
             # Convert data to a list as we work with lists further down.
             data = [data]
@@ -81,6 +83,10 @@ async def ingest_data(data: Any, dataset_name: str, user: User):
                     await session.execute(select(Data).filter(Data.id == data_id))
                 ).scalar_one_or_none()
 
+                ext_metadata = get_external_metadata_dict(data_item)
+                if NodeSet:
+                    ext_metadata["node_set"] = NodeSet
+
                 if data_point is not None:
                     data_point.name = file_metadata["name"]
                     data_point.raw_data_location = file_metadata["file_path"]
@@ -88,7 +94,8 @@ async def ingest_data(data: Any, dataset_name: str, user: User):
                     data_point.mime_type = file_metadata["mime_type"]
                     data_point.owner_id = user.id
                     data_point.content_hash = file_metadata["content_hash"]
-                    data_point.external_metadata = (get_external_metadata_dict(data_item),)
+                    data_point.external_metadata = ext_metadata
+                    data_point.node_set = json.dumps(NodeSet) if NodeSet else None
                     await session.merge(data_point)
                 else:
                     data_point = Data(
@@ -99,7 +106,8 @@ async def ingest_data(data: Any, dataset_name: str, user: User):
                         mime_type=file_metadata["mime_type"],
                         owner_id=user.id,
                         content_hash=file_metadata["content_hash"],
-                        external_metadata=get_external_metadata_dict(data_item),
+                        external_metadata=ext_metadata,
+                        node_set=json.dumps(NodeSet) if NodeSet else None,
                         token_count=-1,
                     )
@@ -124,7 +132,7 @@ async def ingest_data(data: Any, dataset_name: str, user: User):
 
     db_engine = get_relational_engine()
 
-    file_paths = await store_data_to_dataset(data, dataset_name, user)
+    file_paths = await store_data_to_dataset(data, dataset_name, user, NodeSet)
 
     # Note: DLT pipeline has its own event loop, therefore objects created in another event loop
     # can't be used inside the pipeline
diff --git a/cognee/tasks/node_set/__init__.py b/cognee/tasks/node_set/__init__.py
new file mode 100644
index 000000000..e0582a8d5
--- /dev/null
+++ b/cognee/tasks/node_set/__init__.py
@@ -0,0 +1 @@
+from .apply_node_set import apply_node_set
\ No newline at end of file
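
Note that apply_node_set itself is not shown in this patch; only its import and its registration as a cognify task are. A minimal sketch of what such a task could look like, assuming it receives the DataPoint batch emitted by add_data_points and a label list; both the signature and the propagation rule are assumptions, not the patch's actual implementation.

from typing import List, Optional

from cognee.infrastructure.engine import DataPoint

async def apply_node_set(
    data_points: List[DataPoint], node_set: Optional[List[str]] = None
):
    # Assumed behavior: stamp each DataPoint that has no labels yet with the
    # NodeSet supplied at ingestion, so the labels travel into the graph store.
    if node_set:
        for data_point in data_points:
            if data_point.NodeSet is None:
                data_point.NodeSet = list(node_set)
    return data_points
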