Add nodesets datapoints (#755)

<!-- .github/pull_request_template.md -->

## Description
<!-- Provide a clear description of the changes in this PR -->

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
This commit is contained in:
lxobr 2025-04-17 17:10:42 +02:00 committed by GitHub
parent 40142b4789
commit b2a53b4124
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 81 additions and 3 deletions

View file

@@ -140,7 +140,7 @@ async def get_default_tasks( # TODO: Find out a better way to do this (Boris's
task_config={"batch_size": 10},
),
Task(add_data_points, task_config={"batch_size": 10}),
Task(apply_node_set, task_config={"batch_size": 10}), # Apply NodeSet values and create set nodes
# Task(apply_node_set, task_config={"batch_size": 10}), # Apply NodeSet values and create set nodes
]
return default_tasks

View file

@@ -27,7 +27,9 @@ class DataPoint(BaseModel):
topological_rank: Optional[int] = 0
metadata: Optional[MetaData] = {"index_fields": []}
type: str = Field(default_factory=lambda: DataPoint.__name__)
NodeSet: Optional[List[str]] = None # List of nodes this data point is associated with
belongs_to_set: Optional[List["DataPoint"]] = (
None # List of nodesets this data point belongs to
)
def __init__(self, **data):
super().__init__(**data)

View file

@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, List
from cognee.infrastructure.engine import DataPoint
from cognee.modules.chunking.Chunker import Chunker

View file

@@ -0,0 +1,8 @@
from cognee.infrastructure.engine import DataPoint
class NodeSet(DataPoint):
    """Named set grouping related data points.

    A ``NodeSet`` is itself stored as a data point; other data points
    reference it via their ``belongs_to_set`` list.
    """
    # Human-readable identifier of the set (e.g. "AI", "MedTech").
    name: str
    # Overrides DataPoint.metadata so that "name" is vector-indexed,
    # making node sets findable by name — see DataPoint's
    # default {"index_fields": []}.
    metadata: dict = {"index_fields": ["name"]}

View file

@@ -8,6 +8,11 @@ from cognee.modules.data.processing.document_types import (
TextDocument,
UnstructuredDocument,
)
from cognee.infrastructure.engine import DataPoint
from cognee.modules.engine.models.node_set import NodeSet
from cognee.modules.engine.utils.generate_node_id import generate_node_id
from typing import List, Optional
import uuid
EXTENSION_TO_DOCUMENT_CLASS = {
"pdf": PdfDocument, # Text documents
@@ -49,6 +54,29 @@ EXTENSION_TO_DOCUMENT_CLASS = {
}
def update_node_set(document):
    """Populate ``document.belongs_to_set`` from its external metadata.

    Parses ``document.external_metadata`` as JSON and, when the result is a
    dict with a list under the ``"node_set"`` key, attaches one ``NodeSet``
    data point per listed name (ids derived via ``generate_node_id``).
    Malformed metadata (non-JSON, non-string, non-dict, missing or non-list
    ``"node_set"``) leaves the document untouched — this is best-effort.
    """
    try:
        # TypeError covers a None / non-string external_metadata, which
        # json.loads would otherwise raise out of this best-effort helper.
        external_metadata = json.loads(document.external_metadata)
    except (json.JSONDecodeError, TypeError):
        return
    if not isinstance(external_metadata, dict):
        return
    node_set = external_metadata.get("node_set")
    if not isinstance(node_set, list):
        return
    document.belongs_to_set = [
        NodeSet(id=generate_node_id(f"NodeSet:{node_set_name}"), name=node_set_name)
        for node_set_name in node_set
    ]
async def classify_documents(data_documents: list[Data]) -> list[Document]:
"""
Classifies a list of data items into specific document types based on file extensions.
@@ -67,6 +95,7 @@ async def classify_documents(data_documents: list[Data]) -> list[Document]:
mime_type=data_item.mime_type,
external_metadata=json.dumps(data_item.external_metadata, indent=4),
)
update_node_set(document)
documents.append(document)
return documents

View file

@@ -40,6 +40,7 @@ async def extract_chunks_from_documents(
document_token_count = 0
for document_chunk in document.read(max_chunk_size=max_chunk_size, chunker_cls=chunker):
document_token_count += document_chunk.chunk_size
document_chunk.belongs_to_set = document.belongs_to_set
yield document_chunk
await update_document_token_count(document.id, document_token_count)

View file

@@ -0,0 +1,38 @@
import asyncio
import cognee
from cognee.shared.logging_utils import get_logger, ERROR
from cognee.api.v1.search import SearchType
# Sample documents used to demonstrate node-set tagging.
text_a = """
AI is revolutionizing financial services through intelligent fraud detection
and automated customer service platforms.
"""
text_b = """
Advances in AI are enabling smarter systems that learn and adapt over time.
"""
text_c = """
MedTech startups have seen significant growth in recent years, driven by innovation
in digital health and medical devices.
"""
# Node-set labels attached to each document at ingestion time; texts a and b
# share the "AI" set, so they can later be retrieved together by that label.
node_set_a = ["AI", "FinTech"]
node_set_b = ["AI"]
node_set_c = ["MedTech"]
async def main():
    """Run the node-set demo: reset state, ingest tagged texts, build the graph."""
    # Wipe previously stored data and system state (metadata included) so the
    # run starts from a clean slate.
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    # Ingest each sample text, tagging it with its node-set labels.
    await cognee.add(text_a, node_set=node_set_a)
    await cognee.add(text_b, node_set=node_set_b)
    await cognee.add(text_c, node_set=node_set_c)
    # Process the ingested data into the knowledge graph.
    await cognee.cognify()
if __name__ == "__main__":
    logger = get_logger(level=ERROR)
    # asyncio.run() creates, runs, and closes its own event loop; the previous
    # asyncio.new_event_loop() call produced an unused loop that was never
    # closed (a resource leak), so it is removed.
    asyncio.run(main())