feat: Group DataPoints into NodeSets (#680)

## Description
This PR adds the ability to group ingested data into named **NodeSets**. `cognee.add` gains an optional `node_set: Optional[List[str]]` parameter; the tag list is passed through `ingest_data`, persisted on the `Data` record (a new JSON `node_set` column plus a copy inside `external_metadata`), and picked up again during cognify: `classify_documents` turns the stored names into `NodeSet` data points attached to each document via the new `DataPoint.belongs_to_set` field, and `extract_chunks_from_documents` propagates the same sets to every chunk. The LanceDB and ChromaDB adapters are updated to handle the new relational field, and a small example script tags three texts with different node sets and visualizes the resulting graph.
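A minimal usage sketch, mirroring the example script added in this PR (texts and tag names are illustrative):

```python
import asyncio

import cognee


async def main():
    # Start from a clean state, as in the example script.
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)

    # Tag the ingested text with one or more NodeSet names.
    await cognee.add("AI is revolutionizing fraud detection.", node_set=["AI", "FinTech"])
    await cognee.cognify()


asyncio.run(main())
```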

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.

---------

Co-authored-by: lxobr <122801072+lxobr@users.noreply.github.com>
Co-authored-by: Boris <boris@topoteretes.com>
Co-authored-by: Boris Arzentar <borisarzentar@gmail.com>
Vasilije, 2025-04-19 20:21:04 +02:00, committed by GitHub
parent 8374e402a8
commit bb7eaa017b
14 changed files with 164 additions and 30 deletions

View file

@@ -1,4 +1,4 @@
from typing import Union, BinaryIO
from typing import Union, BinaryIO, List, Optional
from cognee.modules.users.models import User
from cognee.modules.pipelines import Task
from cognee.tasks.ingestion import ingest_data, resolve_data_directories
@@ -9,8 +9,9 @@ async def add(
data: Union[BinaryIO, list[BinaryIO], str, list[str]],
dataset_name: str = "main_dataset",
user: User = None,
node_set: Optional[List[str]] = None,
):
tasks = [Task(resolve_data_directories), Task(ingest_data, dataset_name, user)]
tasks = [Task(resolve_data_directories), Task(ingest_data, dataset_name, user, node_set)]
await cognee_pipeline(
tasks=tasks, datasets=dataset_name, data=data, user=user, pipeline_name="add_pipeline"

View file

@@ -1,12 +1,11 @@
from cognee.shared.logging_utils import get_logger
from typing import Dict, List, Optional, Any
import os
import json
from uuid import UUID
from typing import List, Optional
from chromadb import AsyncHttpClient, Settings
from cognee.exceptions import InvalidValueError
from cognee.shared.logging_utils import get_logger
from cognee.modules.storage.utils import get_own_properties
from cognee.infrastructure.engine.utils import parse_id
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult
@@ -134,7 +133,7 @@ class ChromaDBAdapter(VectorDBInterface):
metadatas = []
for data_point in data_points:
metadata = data_point.model_dump()
metadata = get_own_properties(data_point)
metadatas.append(process_data_for_chroma(metadata))
await collection.upsert(

View file

@@ -312,6 +312,12 @@ class LanceDBAdapter(VectorDBInterface):
models_list = get_args(field_config.annotation)
if any(hasattr(model, "model_fields") for model in models_list):
related_models_fields.append(field_name)
elif models_list and any(get_args(model) is DataPoint for model in models_list):
related_models_fields.append(field_name)
elif models_list and any(
submodel is DataPoint for submodel in get_args(models_list[0])
):
related_models_fields.append(field_name)
elif get_origin(field_config.annotation) == Optional:
model = get_args(field_config.annotation)
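For context, the added branch covers annotations shaped like `Optional[List["DataPoint"]]`, which is how the new `belongs_to_set` field is declared. A small sketch of the typing introspection involved, assuming the forward reference has already been resolved to the class:

```python
from typing import List, Optional, get_args

from cognee.infrastructure.engine import DataPoint

annotation = Optional[List[DataPoint]]  # resolved form of Optional[List["DataPoint"]]
models_list = get_args(annotation)      # (List[DataPoint], NoneType)
inner_args = get_args(models_list[0])   # (DataPoint,)

# This is the check the new elif performs to treat belongs_to_set as a related-model field.
assert any(submodel is DataPoint for submodel in inner_args)
```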

View file

@@ -1,10 +1,9 @@
from datetime import datetime, timezone
from typing import Optional, Any, Dict
from uuid import UUID, uuid4
from pydantic import BaseModel, Field
from typing_extensions import TypedDict
import pickle
from uuid import UUID, uuid4
from pydantic import BaseModel, Field
from datetime import datetime, timezone
from typing_extensions import TypedDict
from typing import Optional, Any, Dict, List
# Define metadata type
@@ -27,6 +26,7 @@ class DataPoint(BaseModel):
topological_rank: Optional[int] = 0
metadata: Optional[MetaData] = {"index_fields": []}
type: str = Field(default_factory=lambda: DataPoint.__name__)
belongs_to_set: Optional[List["DataPoint"]] = None
def __init__(self, **data):
super().__init__(**data)
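Since `belongs_to_set` lives on the base model, any `DataPoint` subclass can now be tagged. A short sketch with a hypothetical `Note` model (not part of this PR); ids are passed explicitly for clarity:

```python
from uuid import uuid4

from cognee.infrastructure.engine import DataPoint
from cognee.modules.engine.models import NodeSet


class Note(DataPoint):
    """Hypothetical DataPoint subclass, for illustration only."""

    text: str
    metadata: dict = {"index_fields": ["text"]}


ai_set = NodeSet(id=uuid4(), name="AI")
note = Note(id=uuid4(), text="AI is transforming fraud detection.", belongs_to_set=[ai_set])
```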

View file

@@ -20,6 +20,7 @@ class Data(Base):
owner_id = Column(UUID, index=True)
content_hash = Column(String)
external_metadata = Column(JSON)
node_set = Column(JSON, nullable=True) # Store NodeSet as JSON list of strings
token_count = Column(Integer)
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
updated_at = Column(DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc))
@@ -44,5 +45,6 @@ class Data(Base):
"rawDataLocation": self.raw_data_location,
"createdAt": self.created_at.isoformat(),
"updatedAt": self.updated_at.isoformat() if self.updated_at else None,
"nodeSet": self.node_set,
# "datasets": [dataset.to_json() for dataset in self.datasets]
}
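The new column stores the raw tag list as a JSON string (written by `ingest_data` further down), so consumers decode it before use. A minimal sketch of the round trip:

```python
import json

# What ingest_data writes to Data.node_set when node_set=["AI", "FinTech"] is passed ...
stored = json.dumps(["AI", "FinTech"])

# ... and how a consumer recovers the tag names.
tags = json.loads(stored) if stored else []
assert tags == ["AI", "FinTech"]
```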

View file

@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, List
from cognee.infrastructure.engine import DataPoint
from cognee.modules.chunking.Chunker import Chunker

View file

@@ -2,3 +2,4 @@ from .Entity import Entity
from .EntityType import EntityType
from .TableRow import TableRow
from .TableType import TableType
from .node_set import NodeSet

View file

@@ -0,0 +1,8 @@
from cognee.infrastructure.engine import DataPoint
class NodeSet(DataPoint):
"""NodeSet data point."""
name: str
metadata: dict = {"index_fields": ["name"]}

View file

@@ -1,18 +1,17 @@
from cognee.shared.logging_utils import get_logger
import networkx as nx
import json
import os
import json
import networkx
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.files.storage import LocalStorage
logger = get_logger()
async def cognee_network_visualization(graph_data, destination_file_path: str = None):
nodes_data, edges_data = graph_data
G = nx.DiGraph()
G = networkx.DiGraph()
nodes_list = []
color_map = {
@@ -184,8 +183,8 @@ async def cognee_network_visualization(graph_data, destination_file_path: str =
</html>
"""
html_content = html_template.replace("{nodes}", json.dumps(nodes_list, default=str))
html_content = html_content.replace("{links}", json.dumps(links_list, default=str))
html_content = html_template.replace("{nodes}", json.dumps(nodes_list))
html_content = html_content.replace("{links}", json.dumps(links_list))
if not destination_file_path:
home_dir = os.path.expanduser("~")

View file

@@ -0,0 +1,37 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "initial_id",
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
""
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View file

@@ -8,6 +8,8 @@ from cognee.modules.data.processing.document_types import (
TextDocument,
UnstructuredDocument,
)
from cognee.modules.engine.models.node_set import NodeSet
from cognee.modules.engine.utils.generate_node_id import generate_node_id
EXTENSION_TO_DOCUMENT_CLASS = {
"pdf": PdfDocument, # Text documents
@@ -49,6 +51,29 @@ EXTENSION_TO_DOCUMENT_CLASS = {
}
def update_node_set(document):
"""Extracts node_set from document's external_metadata."""
try:
external_metadata = json.loads(document.external_metadata)
except json.JSONDecodeError:
return
if not isinstance(external_metadata, dict):
return
if "node_set" not in external_metadata:
return
node_set = external_metadata["node_set"]
if not isinstance(node_set, list):
return
document.belongs_to_set = [
NodeSet(id=generate_node_id(f"NodeSet:{node_set_name}"), name=node_set_name)
for node_set_name in node_set
]
async def classify_documents(data_documents: list[Data]) -> list[Document]:
"""
Classifies a list of data items into specific document types based on file extensions.
@@ -67,6 +92,7 @@ async def classify_documents(data_documents: list[Data]) -> list[Document]:
mime_type=data_item.mime_type,
external_metadata=json.dumps(data_item.external_metadata, indent=4),
)
update_node_set(document)
documents.append(document)
return documents
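For reference, `update_node_set` only attaches sets when the serialized `external_metadata` is a JSON object containing a `node_set` list; anything else is silently ignored. A sketch of the payload shape it expects (values illustrative):

```python
import json

# The shape ingest_data produces when node_set is provided.
external_metadata = json.dumps({"node_set": ["AI", "FinTech"]}, indent=4)

payload = json.loads(external_metadata)
assert isinstance(payload, dict) and isinstance(payload["node_set"], list)
```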

View file

@@ -40,6 +40,7 @@ async def extract_chunks_from_documents(
document_token_count = 0
for document_chunk in document.read(max_chunk_size=max_chunk_size, chunker_cls=chunker):
document_token_count += document_chunk.chunk_size
document_chunk.belongs_to_set = document.belongs_to_set
yield document_chunk
await update_document_token_count(document.id, document_token_count)

View file

@@ -1,7 +1,8 @@
from typing import Any, List
import dlt
import s3fs
import json
import inspect
from typing import Union, BinaryIO, Any, List, Optional
import cognee.modules.ingestion as ingestion
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.modules.data.methods import create_dataset, get_dataset_data, get_datasets_by_name
@@ -12,13 +13,13 @@ from cognee.modules.users.permissions.methods import give_permission_on_document
from .get_dlt_destination import get_dlt_destination
from .save_data_item_to_storage import save_data_item_to_storage
from typing import Union, BinaryIO
import inspect
from cognee.api.v1.add.config import get_s3_config
async def ingest_data(data: Any, dataset_name: str, user: User):
async def ingest_data(
data: Any, dataset_name: str, user: User, node_set: Optional[List[str]] = None
):
destination = get_dlt_destination()
if not user:
@@ -68,9 +69,12 @@ async def ingest_data(data: Any, dataset_name: str, user: User):
"mime_type": file_metadata["mime_type"],
"content_hash": file_metadata["content_hash"],
"owner_id": str(user.id),
"node_set": json.dumps(node_set) if node_set else None,
}
async def store_data_to_dataset(data: Any, dataset_name: str, user: User):
async def store_data_to_dataset(
data: Any, dataset_name: str, user: User, node_set: Optional[List[str]] = None
):
if not isinstance(data, list):
# Convert data to a list as we work with lists further down.
data = [data]
@@ -107,6 +111,10 @@ async def ingest_data(data: Any, dataset_name: str, user: User):
await session.execute(select(Data).filter(Data.id == data_id))
).scalar_one_or_none()
ext_metadata = get_external_metadata_dict(data_item)
if node_set:
ext_metadata["node_set"] = node_set
if data_point is not None:
data_point.name = file_metadata["name"]
data_point.raw_data_location = file_metadata["file_path"]
@@ -114,7 +122,8 @@ async def ingest_data(data: Any, dataset_name: str, user: User):
data_point.mime_type = file_metadata["mime_type"]
data_point.owner_id = user.id
data_point.content_hash = file_metadata["content_hash"]
data_point.external_metadata = (get_external_metadata_dict(data_item),)
data_point.external_metadata = ext_metadata
data_point.node_set = json.dumps(node_set) if node_set else None
await session.merge(data_point)
else:
data_point = Data(
@@ -125,7 +134,8 @@ async def ingest_data(data: Any, dataset_name: str, user: User):
mime_type=file_metadata["mime_type"],
owner_id=user.id,
content_hash=file_metadata["content_hash"],
external_metadata=get_external_metadata_dict(data_item),
external_metadata=ext_metadata,
node_set=json.dumps(node_set) if node_set else None,
token_count=-1,
)
@@ -150,7 +160,7 @@ async def ingest_data(data: Any, dataset_name: str, user: User):
db_engine = get_relational_engine()
file_paths = await store_data_to_dataset(data, dataset_name, user)
file_paths = await store_data_to_dataset(data, dataset_name, user, node_set)
# Note: DLT pipeline has its own event loop, therefore objects created in another event loop
# can't be used inside the pipeline
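Putting the ingestion changes together: when `node_set` is supplied, the record handed to the dlt pipeline and the `Data` row both carry the JSON-encoded tag list, and the same list is injected into `external_metadata` so cognify can recover it later. An illustrative record shape (keys as in the hunk above, values are placeholders):

```python
import json

node_set = ["AI", "FinTech"]
record = {
    "mime_type": "text/plain",      # placeholder
    "content_hash": "<sha256>",     # placeholder
    "owner_id": "<user-id>",        # placeholder
    "node_set": json.dumps(node_set) if node_set else None,
}
```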

View file

@@ -0,0 +1,44 @@
import os
import asyncio
import cognee
from cognee.api.v1.visualize.visualize import visualize_graph
from cognee.shared.logging_utils import get_logger, ERROR
text_a = """
AI is revolutionizing financial services through intelligent fraud detection
and automated customer service platforms.
"""
text_b = """
Advances in AI are enabling smarter systems that learn and adapt over time.
"""
text_c = """
MedTech startups have seen significant growth in recent years, driven by innovation
in digital health and medical devices.
"""
node_set_a = ["AI", "FinTech"]
node_set_b = ["AI"]
node_set_c = ["MedTech"]
async def main():
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
await cognee.add(text_a, node_set=node_set_a)
await cognee.add(text_b, node_set=node_set_b)
await cognee.add(text_c, node_set=node_set_c)
await cognee.cognify()
visualization_path = os.path.join(
os.path.dirname(__file__), "./.artifacts/graph_visualization.html"
)
await visualize_graph(visualization_path)
if __name__ == "__main__":
logger = get_logger(level=ERROR)
loop = asyncio.new_event_loop()
asyncio.run(main())