Add NodeSets

vasilije 2025-03-30 11:40:13 +02:00
parent c385e7f189
commit ec68a8cd2d
7 changed files with 28 additions and 15 deletions

View file

@@ -1,4 +1,4 @@
-from typing import Union, BinaryIO
+from typing import Union, BinaryIO, List, Optional
 from cognee.modules.users.models import User
 from cognee.modules.users.methods import get_default_user
 from cognee.modules.pipelines import run_tasks, Task
@@ -16,6 +16,7 @@ async def add(
     data: Union[BinaryIO, list[BinaryIO], str, list[str]],
     dataset_name: str = "main_dataset",
     user: User = None,
+    NodeSet: Optional[List[str]] = None,
 ):
     # Create tables for databases
     await create_relational_db_and_tables()
@@ -36,7 +37,7 @@ async def add(
     if user is None:
         user = await get_default_user()

-    tasks = [Task(resolve_data_directories), Task(ingest_data, dataset_name, user)]
+    tasks = [Task(resolve_data_directories), Task(ingest_data, dataset_name, user, NodeSet)]

     dataset_id = uuid5(NAMESPACE_OID, dataset_name)

     pipeline = run_tasks(

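For reference, a minimal sketch of how the extended add() signature might be called, assuming the function is exposed at package level as cognee.add; the dataset name and labels below are illustrative only:

import asyncio
import cognee

async def main():
    await cognee.add(
        "Graph databases store data as nodes and edges.",
        dataset_name="main_dataset",
        NodeSet=["graph_docs", "tutorials"],  # new optional parameter introduced in this commit
    )

asyncio.run(main())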
View file

@@ -25,6 +25,7 @@ from cognee.tasks.documents import (
 from cognee.tasks.graph import extract_graph_from_data
 from cognee.tasks.storage import add_data_points
 from cognee.tasks.summarization import summarize_text
+from cognee.tasks.node_set import apply_node_set
 from cognee.modules.chunking.TextChunker import TextChunker

 logger = get_logger("cognify")
@@ -139,6 +140,7 @@ async def get_default_tasks(  # TODO: Find out a better way to do this (Boris's
                task_config={"batch_size": 10},
            ),
            Task(add_data_points, task_config={"batch_size": 10}),
+           Task(apply_node_set, task_config={"batch_size": 10}),  # Apply NodeSet values to DataPoints
        ]

    return default_tasks

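With the task registered at the end of the default list, a run that exercises it would go through the usual pipeline entry point; a rough sketch, assuming the package-level cognify() builds the default tasks shown above:

import asyncio
import cognee

async def main():
    # Assumes data was already ingested (see the add() example above);
    # the default task list now ends with apply_node_set.
    await cognee.cognify()

asyncio.run(main())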
View file

@@ -1,5 +1,5 @@
 from datetime import datetime, timezone
-from typing import Optional, Any, Dict
+from typing import Optional, Any, Dict, List
 from uuid import UUID, uuid4
 from pydantic import BaseModel, Field
@@ -27,6 +27,7 @@ class DataPoint(BaseModel):
     topological_rank: Optional[int] = 0
     metadata: Optional[MetaData] = {"index_fields": []}
     type: str = Field(default_factory=lambda: DataPoint.__name__)
+    NodeSet: Optional[List[str]] = None  # Node set labels this data point belongs to

     def __init__(self, **data):
         super().__init__(**data)

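To illustrate the new field, a minimal sketch of a DataPoint subclass carrying NodeSet values; the subclass name is hypothetical and the import path is assumed:

from cognee.infrastructure.engine import DataPoint  # assumed import path for the model patched above

class ExampleNote(DataPoint):  # hypothetical subclass, for illustration only
    text: str

note = ExampleNote(text="Quarterly revenue summary", NodeSet=["finance", "q1_reports"])
print(note.NodeSet)  # ['finance', 'q1_reports']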
View file

@@ -20,6 +20,7 @@ class Data(Base):
     owner_id = Column(UUID, index=True)
     content_hash = Column(String)
     external_metadata = Column(JSON)
+    node_set = Column(JSON, nullable=True)  # Store NodeSet as JSON list of strings
     token_count = Column(Integer)
     created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
     updated_at = Column(DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc))
@@ -44,5 +45,6 @@ class Data(Base):
             "rawDataLocation": self.raw_data_location,
             "createdAt": self.created_at.isoformat(),
             "updatedAt": self.updated_at.isoformat() if self.updated_at else None,
+            "nodeSet": self.node_set,
             # "datasets": [dataset.to_json() for dataset in self.datasets]
         }

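Because the ingestion task below writes this column with json.dumps, the stored value is a JSON-encoded string rather than a native list; a rough round trip looks like this:

import json

stored = json.dumps(["finance", "q1_reports"])      # what ingest_data assigns to Data.node_set
recovered = json.loads(stored) if stored else None  # how a consumer can get the list back
print(recovered)  # ['finance', 'q1_reports']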
View file

@ -1,4 +1,4 @@
from cognee.shared.logging_utils import get_logger
import logging
import networkx as nx
import json
import os
@ -6,7 +6,7 @@ import os
from cognee.infrastructure.files.storage import LocalStorage
logger = get_logger()
logger = logging.getLogger(__name__)
async def cognee_network_visualization(graph_data, destination_file_path: str = None):
@ -20,8 +20,6 @@ async def cognee_network_visualization(graph_data, destination_file_path: str =
"EntityType": "#6510f4",
"DocumentChunk": "#801212",
"TextSummary": "#1077f4",
"TableRow": "#f47710",
"TableType": "#6510f4",
"default": "#D3D3D3",
}
@ -184,8 +182,8 @@ async def cognee_network_visualization(graph_data, destination_file_path: str =
</html>
"""
html_content = html_template.replace("{nodes}", json.dumps(nodes_list, default=str))
html_content = html_content.replace("{links}", json.dumps(links_list, default=str))
html_content = html_template.replace("{nodes}", json.dumps(nodes_list))
html_content = html_content.replace("{links}", json.dumps(links_list))
if not destination_file_path:
home_dir = os.path.expanduser("~")

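The last hunk above drops default=str from the json.dumps calls; whether that is safe depends on the node and link attributes, since json.dumps raises on values such as UUIDs unless a default converter is supplied. A quick standalone illustration (the sample node dict is made up):

import json
from uuid import uuid4

node = {"id": uuid4(), "label": "DocumentChunk"}

print(json.dumps(node, default=str))  # UUID is coerced to its string form

try:
    json.dumps(node)  # with no default, non-JSON-native types raise
except TypeError as error:
    print(error)  # "Object of type UUID is not JSON serializable"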
View file

@@ -1,4 +1,4 @@
-from typing import Any, List
+from typing import Any, List, Optional

 import dlt
 import cognee.modules.ingestion as ingestion
@@ -12,9 +12,10 @@ from .save_data_item_to_storage import save_data_item_to_storage
 from typing import Union, BinaryIO
 import inspect
+import json


-async def ingest_data(data: Any, dataset_name: str, user: User):
+async def ingest_data(data: Any, dataset_name: str, user: User, NodeSet: Optional[List[str]] = None):
     destination = get_dlt_destination()

     pipeline = dlt.pipeline(
@@ -43,9 +44,10 @@ async def ingest_data(data: Any, dataset_name: str, user: User):
                    "mime_type": file_metadata["mime_type"],
                    "content_hash": file_metadata["content_hash"],
                    "owner_id": str(user.id),
+                   "node_set": json.dumps(NodeSet) if NodeSet else None,
                }

-    async def store_data_to_dataset(data: Any, dataset_name: str, user: User):
+    async def store_data_to_dataset(data: Any, dataset_name: str, user: User, NodeSet: Optional[List[str]] = None):
         if not isinstance(data, list):
             # Convert data to a list as we work with lists further down.
             data = [data]
@@ -81,6 +83,10 @@ async def ingest_data(data: Any, dataset_name: str, user: User):
                    await session.execute(select(Data).filter(Data.id == data_id))
                ).scalar_one_or_none()

+               ext_metadata = get_external_metadata_dict(data_item)
+               if NodeSet:
+                   ext_metadata["node_set"] = NodeSet
+
                if data_point is not None:
                    data_point.name = file_metadata["name"]
                    data_point.raw_data_location = file_metadata["file_path"]
@@ -88,7 +94,8 @@ async def ingest_data(data: Any, dataset_name: str, user: User):
                    data_point.mime_type = file_metadata["mime_type"]
                    data_point.owner_id = user.id
                    data_point.content_hash = file_metadata["content_hash"]
-                   data_point.external_metadata = (get_external_metadata_dict(data_item),)
+                   data_point.external_metadata = ext_metadata
+                   data_point.node_set = json.dumps(NodeSet) if NodeSet else None
                    await session.merge(data_point)
                else:
                    data_point = Data(
@@ -99,7 +106,8 @@ async def ingest_data(data: Any, dataset_name: str, user: User):
                        mime_type=file_metadata["mime_type"],
                        owner_id=user.id,
                        content_hash=file_metadata["content_hash"],
-                       external_metadata=get_external_metadata_dict(data_item),
+                       external_metadata=ext_metadata,
+                       node_set=json.dumps(NodeSet) if NodeSet else None,
                        token_count=-1,
                    )

@@ -124,7 +132,7 @@ async def ingest_data(data: Any, dataset_name: str, user: User):
     db_engine = get_relational_engine()

-    file_paths = await store_data_to_dataset(data, dataset_name, user)
+    file_paths = await store_data_to_dataset(data, dataset_name, user, NodeSet)

     # Note: DLT pipeline has its own event loop, therefore objects created in another event loop
     # can't be used inside the pipeline

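Summarizing the data flow in the hunks above: the raw list goes into external_metadata while the relational column receives a JSON-encoded string. A stripped-down sketch of that shaping logic, where build_row_fields and the metadata stub are invented for illustration:

import json
from typing import List, Optional

def build_row_fields(NodeSet: Optional[List[str]]) -> dict:
    # Mirrors how ingest_data prepares the two NodeSet representations:
    # the Data.node_set column gets a JSON string, external_metadata keeps the plain list.
    ext_metadata = {"origin": "example"}  # stand-in for get_external_metadata_dict(data_item)
    if NodeSet:
        ext_metadata["node_set"] = NodeSet
    return {
        "node_set": json.dumps(NodeSet) if NodeSet else None,
        "external_metadata": ext_metadata,
    }

print(build_row_fields(["finance", "q1_reports"]))
print(build_row_fields(None))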
View file

@@ -0,0 +1 @@
+from .apply_node_set import apply_node_set
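The body of apply_node_set is not part of this diff (only the re-export above is). A minimal sketch of what such a task could look like given the new DataPoint.NodeSet field; the signature, propagation rule, and import path are assumptions, not the committed implementation:

from typing import List, Optional

from cognee.infrastructure.engine import DataPoint  # assumed import path

async def apply_node_set(data_points: List[DataPoint], node_set: Optional[List[str]] = None):
    # Hypothetical: stamp every data point in the batch with the given node set labels.
    if node_set:
        for data_point in data_points:
            data_point.NodeSet = list(node_set)
    return data_points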