Add NodeSets

vasilije 2025-03-30 11:40:13 +02:00
parent c385e7f189
commit ec68a8cd2d
7 changed files with 28 additions and 15 deletions

View file

@@ -1,4 +1,4 @@
-from typing import Union, BinaryIO
+from typing import Union, BinaryIO, List, Optional
 from cognee.modules.users.models import User
 from cognee.modules.users.methods import get_default_user
 from cognee.modules.pipelines import run_tasks, Task
@@ -16,6 +16,7 @@ async def add(
     data: Union[BinaryIO, list[BinaryIO], str, list[str]],
     dataset_name: str = "main_dataset",
     user: User = None,
+    NodeSet: Optional[List[str]] = None,
 ):
     # Create tables for databases
     await create_relational_db_and_tables()
@@ -36,7 +37,7 @@ async def add(
     if user is None:
         user = await get_default_user()
-    tasks = [Task(resolve_data_directories), Task(ingest_data, dataset_name, user)]
+    tasks = [Task(resolve_data_directories), Task(ingest_data, dataset_name, user, NodeSet)]
     dataset_id = uuid5(NAMESPACE_OID, dataset_name)
     pipeline = run_tasks(
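For context, a minimal usage sketch of the extended API, assuming the function above is exposed as cognee.add; the file path and NodeSet labels below are hypothetical:

import asyncio

import cognee


async def main():
    # Tag everything ingested by this call with NodeSet labels (hypothetical values).
    await cognee.add(
        "notes/meeting_notes.txt",
        dataset_name="main_dataset",
        NodeSet=["meetings", "q1_planning"],
    )


asyncio.run(main())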

View file

@@ -25,6 +25,7 @@ from cognee.tasks.documents import (
 from cognee.tasks.graph import extract_graph_from_data
 from cognee.tasks.storage import add_data_points
 from cognee.tasks.summarization import summarize_text
+from cognee.tasks.node_set import apply_node_set
 from cognee.modules.chunking.TextChunker import TextChunker
 logger = get_logger("cognify")
@@ -139,6 +140,7 @@ async def get_default_tasks(  # TODO: Find out a better way to do this (Boris's
             task_config={"batch_size": 10},
         ),
         Task(add_data_points, task_config={"batch_size": 10}),
+        Task(apply_node_set, task_config={"batch_size": 10}),  # Apply NodeSet values to DataPoints
     ]
     return default_tasks

View file

@@ -1,5 +1,5 @@
 from datetime import datetime, timezone
-from typing import Optional, Any, Dict
+from typing import Optional, Any, Dict, List
 from uuid import UUID, uuid4
 from pydantic import BaseModel, Field
@@ -27,6 +27,7 @@ class DataPoint(BaseModel):
     topological_rank: Optional[int] = 0
     metadata: Optional[MetaData] = {"index_fields": []}
     type: str = Field(default_factory=lambda: DataPoint.__name__)
+    NodeSet: Optional[List[str]] = None  # List of nodes this data point is associated with
     def __init__(self, **data):
         super().__init__(**data)
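Given the field added above, a data point carries its node set labels directly on the model. A small sketch; the Person subclass and the import path are assumptions for illustration:

from cognee.infrastructure.engine import DataPoint  # import path assumed


class Person(DataPoint):  # hypothetical subclass for illustration
    name: str


# Attach the node set labels this data point belongs to.
alice = Person(name="Alice", NodeSet=["employees", "engineering"])
print(alice.NodeSet)  # ['employees', 'engineering']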

View file

@@ -20,6 +20,7 @@ class Data(Base):
     owner_id = Column(UUID, index=True)
     content_hash = Column(String)
     external_metadata = Column(JSON)
+    node_set = Column(JSON, nullable=True)  # Store NodeSet as JSON list of strings
     token_count = Column(Integer)
     created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
     updated_at = Column(DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc))
@@ -44,5 +45,6 @@
             "rawDataLocation": self.raw_data_location,
             "createdAt": self.created_at.isoformat(),
             "updatedAt": self.updated_at.isoformat() if self.updated_at else None,
+            "nodeSet": self.node_set,
             # "datasets": [dataset.to_json() for dataset in self.datasets]
         }
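Note that ingest_data (further down in this commit) writes node_set as a JSON-encoded string, so this column, and the nodeSet key returned by to_json, holds the raw string rather than a Python list. A consumer would decode it roughly like this (sketch, assuming the column contains the string written at ingestion):

import json


def read_node_set(data_row) -> list:
    # node_set is stored as a JSON string such as '["meetings", "q1_planning"]'
    return json.loads(data_row.node_set) if data_row.node_set else []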

View file

@@ -1,4 +1,4 @@
-from cognee.shared.logging_utils import get_logger
+import logging
 import networkx as nx
 import json
 import os
@@ -6,7 +6,7 @@ import os
 from cognee.infrastructure.files.storage import LocalStorage
-logger = get_logger()
+logger = logging.getLogger(__name__)
 async def cognee_network_visualization(graph_data, destination_file_path: str = None):
@@ -20,8 +20,6 @@ async def cognee_network_visualization(graph_data, destination_file_path: str =
         "EntityType": "#6510f4",
         "DocumentChunk": "#801212",
         "TextSummary": "#1077f4",
-        "TableRow": "#f47710",
-        "TableType": "#6510f4",
         "default": "#D3D3D3",
     }
@@ -184,8 +182,8 @@ async def cognee_network_visualization(graph_data, destination_file_path: str =
     </html>
     """
-    html_content = html_template.replace("{nodes}", json.dumps(nodes_list, default=str))
-    html_content = html_content.replace("{links}", json.dumps(links_list, default=str))
+    html_content = html_template.replace("{nodes}", json.dumps(nodes_list))
+    html_content = html_content.replace("{links}", json.dumps(links_list))
     if not destination_file_path:
         home_dir = os.path.expanduser("~")

View file

@@ -1,4 +1,4 @@
-from typing import Any, List
+from typing import Any, List, Optional
 import dlt
 import cognee.modules.ingestion as ingestion
@@ -12,9 +12,10 @@ from .save_data_item_to_storage import save_data_item_to_storage
 from typing import Union, BinaryIO
 import inspect
+import json
-async def ingest_data(data: Any, dataset_name: str, user: User):
+async def ingest_data(data: Any, dataset_name: str, user: User, NodeSet: Optional[List[str]] = None):
     destination = get_dlt_destination()
     pipeline = dlt.pipeline(
@@ -43,9 +44,10 @@ async def ingest_data(data: Any, dataset_name: str, user: User):
                     "mime_type": file_metadata["mime_type"],
                     "content_hash": file_metadata["content_hash"],
                     "owner_id": str(user.id),
+                    "node_set": json.dumps(NodeSet) if NodeSet else None,
                 }
-    async def store_data_to_dataset(data: Any, dataset_name: str, user: User):
+    async def store_data_to_dataset(data: Any, dataset_name: str, user: User, NodeSet: Optional[List[str]] = None):
         if not isinstance(data, list):
             # Convert data to a list as we work with lists further down.
             data = [data]
@@ -81,6 +83,10 @@ async def ingest_data(data: Any, dataset_name: str, user: User):
                     await session.execute(select(Data).filter(Data.id == data_id))
                 ).scalar_one_or_none()
+                ext_metadata = get_external_metadata_dict(data_item)
+                if NodeSet:
+                    ext_metadata["node_set"] = NodeSet
                 if data_point is not None:
                     data_point.name = file_metadata["name"]
                     data_point.raw_data_location = file_metadata["file_path"]
@@ -88,7 +94,8 @@ async def ingest_data(data: Any, dataset_name: str, user: User):
                     data_point.mime_type = file_metadata["mime_type"]
                     data_point.owner_id = user.id
                     data_point.content_hash = file_metadata["content_hash"]
-                    data_point.external_metadata = (get_external_metadata_dict(data_item),)
+                    data_point.external_metadata = ext_metadata
+                    data_point.node_set = json.dumps(NodeSet) if NodeSet else None
                     await session.merge(data_point)
                 else:
                     data_point = Data(
@@ -99,7 +106,8 @@ async def ingest_data(data: Any, dataset_name: str, user: User):
                         mime_type=file_metadata["mime_type"],
                         owner_id=user.id,
                         content_hash=file_metadata["content_hash"],
-                        external_metadata=get_external_metadata_dict(data_item),
+                        external_metadata=ext_metadata,
+                        node_set=json.dumps(NodeSet) if NodeSet else None,
                         token_count=-1,
                     )
@@ -124,7 +132,7 @@ async def ingest_data(data: Any, dataset_name: str, user: User):
     db_engine = get_relational_engine()
-    file_paths = await store_data_to_dataset(data, dataset_name, user)
+    file_paths = await store_data_to_dataset(data, dataset_name, user, NodeSet)
     # Note: DLT pipeline has its own event loop, therefore objects created in another event loop
     # can't be used inside the pipeline
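For illustration, the serialization applied above, shown with hypothetical values (a sketch of the same expressions used in the diff):

import json

NodeSet = ["meetings", "q1_planning"]

# Value written to the node_set column and the dlt resource payload: a JSON string or None.
node_set_column = json.dumps(NodeSet) if NodeSet else None
# node_set_column == '["meetings", "q1_planning"]'

# The same labels are merged into external_metadata as a native list.
ext_metadata = {"origin": "example"}  # hypothetical existing metadata
if NodeSet:
    ext_metadata["node_set"] = NodeSet
# ext_metadata == {'origin': 'example', 'node_set': ['meetings', 'q1_planning']}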

View file

@@ -0,0 +1 @@
+from .apply_node_set import apply_node_set
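The apply_node_set implementation itself is not part of this diff (only the package import above is added). A minimal sketch of what such a task could look like, given the DataPoint.NodeSet field introduced in this commit; the signature and behavior are assumptions, not the actual implementation:

from typing import List, Optional

from cognee.infrastructure.engine import DataPoint  # import path assumed


async def apply_node_set(
    data_points: List[DataPoint], node_set: Optional[List[str]] = None
) -> List[DataPoint]:
    # Sketch only: copy the given NodeSet labels onto every data point in the batch
    # so downstream storage and retrieval steps can filter by node set membership.
    if node_set:
        for data_point in data_points:
            data_point.NodeSet = list(node_set)
    return data_points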