feat: Group DataPoints into NodeSets (#680)
## Description

Group DataPoints into NodeSets. `cognee.add()` gains an optional `node_set: Optional[List[str]]` parameter that is threaded through `ingest_data` and `store_data_to_dataset`, persisted on the `Data` model (new `node_set` JSON column) and inside `external_metadata`. A new `NodeSet` DataPoint model is introduced; `classify_documents` resolves `node_set` entries from `external_metadata` into `NodeSet` instances on `document.belongs_to_set`, and `extract_chunks_from_documents` propagates `belongs_to_set` from each document to its chunks. A runnable example is added in `examples/python/simple_node_set_example.py`.

A minimal usage sketch (mirroring the example file added in this PR; texts and set names are illustrative):
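```python
import asyncio
import cognee

async def main():
    # Tag each document with one or more named node sets at ingestion time.
    await cognee.add("AI is transforming fraud detection.", node_set=["AI", "FinTech"])
    await cognee.add("MedTech startups are growing fast.", node_set=["MedTech"])
    # Build the knowledge graph; chunks inherit belongs_to_set from their documents.
    await cognee.cognify()

asyncio.run(main())
```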
## DCO Affirmation

I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.

Co-authored-by: lxobr <122801072+lxobr@users.noreply.github.com>
Co-authored-by: Boris <boris@topoteretes.com>
Co-authored-by: Boris Arzentar <borisarzentar@gmail.com>

This commit is contained in:
parent 8374e402a8
commit bb7eaa017b

14 changed files with 164 additions and 30 deletions
@@ -1,4 +1,4 @@
-from typing import Union, BinaryIO
+from typing import Union, BinaryIO, List, Optional
 from cognee.modules.users.models import User
 from cognee.modules.pipelines import Task
 from cognee.tasks.ingestion import ingest_data, resolve_data_directories
@@ -9,8 +9,9 @@ async def add(
     data: Union[BinaryIO, list[BinaryIO], str, list[str]],
     dataset_name: str = "main_dataset",
     user: User = None,
+    node_set: Optional[List[str]] = None,
 ):
-    tasks = [Task(resolve_data_directories), Task(ingest_data, dataset_name, user)]
+    tasks = [Task(resolve_data_directories), Task(ingest_data, dataset_name, user, node_set)]

     await cognee_pipeline(
         tasks=tasks, datasets=dataset_name, data=data, user=user, pipeline_name="add_pipeline"

@@ -1,12 +1,11 @@
-from cognee.shared.logging_utils import get_logger
-from typing import Dict, List, Optional, Any
-import os
 import json
 from uuid import UUID
+from typing import List, Optional
 from chromadb import AsyncHttpClient, Settings

 from cognee.exceptions import InvalidValueError
+from cognee.shared.logging_utils import get_logger
+from cognee.modules.storage.utils import get_own_properties
 from cognee.infrastructure.engine.utils import parse_id
 from cognee.infrastructure.engine import DataPoint
 from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult
@@ -134,7 +133,7 @@ class ChromaDBAdapter(VectorDBInterface):

         metadatas = []
         for data_point in data_points:
-            metadata = data_point.model_dump()
+            metadata = get_own_properties(data_point)
             metadatas.append(process_data_for_chroma(metadata))

         await collection.upsert(

@@ -312,6 +312,12 @@ class LanceDBAdapter(VectorDBInterface):
             models_list = get_args(field_config.annotation)
             if any(hasattr(model, "model_fields") for model in models_list):
                 related_models_fields.append(field_name)
+            elif models_list and any(get_args(model) is DataPoint for model in models_list):
+                related_models_fields.append(field_name)
+            elif models_list and any(
+                submodel is DataPoint for submodel in get_args(models_list[0])
+            ):
+                related_models_fields.append(field_name)

         elif get_origin(field_config.annotation) == Optional:
             model = get_args(field_config.annotation)

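Note: the new elif branches above catch fields annotated like the new DataPoint.belongs_to_set (Optional[List["DataPoint"]]), whose element type only appears one get_args() level down. A stdlib-only sketch of the shape being inspected (int stands in for DataPoint):

    from typing import List, Optional, get_args

    annotation = Optional[List[int]]    # stand-in for Optional[List["DataPoint"]]
    models_list = get_args(annotation)  # -> (List[int], <class 'NoneType'>)
    print(get_args(models_list[0]))     # -> (int,), the inner element type the new branches check
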
@@ -1,10 +1,9 @@
-from datetime import datetime, timezone
-from typing import Optional, Any, Dict
-from uuid import UUID, uuid4
-
-from pydantic import BaseModel, Field
-from typing_extensions import TypedDict
 import pickle
+from uuid import UUID, uuid4
+from pydantic import BaseModel, Field
+from datetime import datetime, timezone
+from typing_extensions import TypedDict
+from typing import Optional, Any, Dict, List


 # Define metadata type
@@ -27,6 +26,7 @@ class DataPoint(BaseModel):
     topological_rank: Optional[int] = 0
     metadata: Optional[MetaData] = {"index_fields": []}
     type: str = Field(default_factory=lambda: DataPoint.__name__)
+    belongs_to_set: Optional[List["DataPoint"]] = None

     def __init__(self, **data):
         super().__init__(**data)

@@ -20,6 +20,7 @@ class Data(Base):
     owner_id = Column(UUID, index=True)
     content_hash = Column(String)
     external_metadata = Column(JSON)
+    node_set = Column(JSON, nullable=True)  # Store NodeSet as JSON list of strings
     token_count = Column(Integer)
     created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
     updated_at = Column(DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc))
@@ -44,5 +45,6 @@ class Data(Base):
             "rawDataLocation": self.raw_data_location,
             "createdAt": self.created_at.isoformat(),
             "updatedAt": self.updated_at.isoformat() if self.updated_at else None,
+            "nodeSet": self.node_set,
             # "datasets": [dataset.to_json() for dataset in self.datasets]
         }

@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, List
 from cognee.infrastructure.engine import DataPoint
 from cognee.modules.chunking.Chunker import Chunker

@@ -2,3 +2,4 @@ from .Entity import Entity
 from .EntityType import EntityType
 from .TableRow import TableRow
 from .TableType import TableType
+from .node_set import NodeSet

cognee/modules/engine/models/node_set.py (new file, +8 lines)
@@ -0,0 +1,8 @@
+from cognee.infrastructure.engine import DataPoint
+
+
+class NodeSet(DataPoint):
+    """NodeSet data point."""
+
+    name: str
+    metadata: dict = {"index_fields": ["name"]}

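For reference, NodeSet instances are constructed the way update_node_set (in classify_documents, below) does it. A sketch, assuming a local cognee checkout with the helpers this diff imports:

    from cognee.modules.engine.models.node_set import NodeSet
    from cognee.modules.engine.utils.generate_node_id import generate_node_id

    name = "AI"
    node_set = NodeSet(id=generate_node_id(f"NodeSet:{name}"), name=name)
    print(node_set.name)  # AI
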
@@ -1,18 +1,17 @@
-from cognee.shared.logging_utils import get_logger
-import networkx as nx
-import json
 import os
+import json
+import networkx
+
+from cognee.shared.logging_utils import get_logger
 from cognee.infrastructure.files.storage import LocalStorage


 logger = get_logger()


 async def cognee_network_visualization(graph_data, destination_file_path: str = None):
     nodes_data, edges_data = graph_data

-    G = nx.DiGraph()
+    G = networkx.DiGraph()

     nodes_list = []
     color_map = {
@@ -184,8 +183,8 @@ async def cognee_network_visualization(graph_data, destination_file_path: str =
 </html>
     """

-    html_content = html_template.replace("{nodes}", json.dumps(nodes_list, default=str))
-    html_content = html_content.replace("{links}", json.dumps(links_list, default=str))
+    html_content = html_template.replace("{nodes}", json.dumps(nodes_list))
+    html_content = html_content.replace("{links}", json.dumps(links_list))

     if not destination_file_path:
         home_dir = os.path.expanduser("~")

cognee/notebooks/github_analysis_step_by_step.ipynb (new file, +37 lines)
@@ -0,0 +1,37 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "initial_id",
+   "metadata": {
+    "collapsed": true
+   },
+   "outputs": [],
+   "source": [
+    ""
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 2
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython2",
+   "version": "2.7.6"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}

@@ -8,6 +8,8 @@ from cognee.modules.data.processing.document_types import (
     TextDocument,
     UnstructuredDocument,
 )
+from cognee.modules.engine.models.node_set import NodeSet
+from cognee.modules.engine.utils.generate_node_id import generate_node_id

 EXTENSION_TO_DOCUMENT_CLASS = {
     "pdf": PdfDocument,  # Text documents
@@ -49,6 +51,29 @@ EXTENSION_TO_DOCUMENT_CLASS = {
 }


+def update_node_set(document):
+    """Extracts node_set from document's external_metadata."""
+    try:
+        external_metadata = json.loads(document.external_metadata)
+    except json.JSONDecodeError:
+        return
+
+    if not isinstance(external_metadata, dict):
+        return
+
+    if "node_set" not in external_metadata:
+        return
+
+    node_set = external_metadata["node_set"]
+    if not isinstance(node_set, list):
+        return
+
+    document.belongs_to_set = [
+        NodeSet(id=generate_node_id(f"NodeSet:{node_set_name}"), name=node_set_name)
+        for node_set_name in node_set
+    ]
+
+
 async def classify_documents(data_documents: list[Data]) -> list[Document]:
     """
     Classifies a list of data items into specific document types based on file extensions.
@@ -67,6 +92,7 @@ async def classify_documents(data_documents: list[Data]) -> list[Document]:
             mime_type=data_item.mime_type,
             external_metadata=json.dumps(data_item.external_metadata, indent=4),
         )
+        update_node_set(document)
         documents.append(document)

     return documents

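update_node_set only acts when external_metadata parses to a dict carrying a top-level "node_set" list; every other shape returns early and leaves belongs_to_set unset. A stdlib-only sketch of the accepted payload (values illustrative):

    import json

    external_metadata = json.dumps({"node_set": ["AI", "FinTech"]}, indent=4)

    parsed = json.loads(external_metadata)
    if isinstance(parsed, dict) and isinstance(parsed.get("node_set"), list):
        print(parsed["node_set"])  # ['AI', 'FinTech']
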
@@ -40,6 +40,7 @@ async def extract_chunks_from_documents(
         document_token_count = 0
         for document_chunk in document.read(max_chunk_size=max_chunk_size, chunker_cls=chunker):
             document_token_count += document_chunk.chunk_size
+            document_chunk.belongs_to_set = document.belongs_to_set
             yield document_chunk

         await update_document_token_count(document.id, document_token_count)

@@ -1,7 +1,8 @@
-from typing import Any, List

 import dlt
 import s3fs
+import json
+import inspect
+from typing import Union, BinaryIO, Any, List, Optional
 import cognee.modules.ingestion as ingestion
 from cognee.infrastructure.databases.relational import get_relational_engine
 from cognee.modules.data.methods import create_dataset, get_dataset_data, get_datasets_by_name
@@ -12,13 +13,13 @@ from cognee.modules.users.permissions.methods import give_permission_on_document
 from .get_dlt_destination import get_dlt_destination
 from .save_data_item_to_storage import save_data_item_to_storage

-from typing import Union, BinaryIO
-import inspect

 from cognee.api.v1.add.config import get_s3_config


-async def ingest_data(data: Any, dataset_name: str, user: User):
+async def ingest_data(
+    data: Any, dataset_name: str, user: User, node_set: Optional[List[str]] = None
+):
     destination = get_dlt_destination()

     if not user:
@@ -68,9 +69,12 @@ async def ingest_data(data: Any, dataset_name: str, user: User):
             "mime_type": file_metadata["mime_type"],
             "content_hash": file_metadata["content_hash"],
             "owner_id": str(user.id),
+            "node_set": json.dumps(node_set) if node_set else None,
         }

-    async def store_data_to_dataset(data: Any, dataset_name: str, user: User):
+    async def store_data_to_dataset(
+        data: Any, dataset_name: str, user: User, node_set: Optional[List[str]] = None
+    ):
         if not isinstance(data, list):
             # Convert data to a list as we work with lists further down.
             data = [data]
@@ -107,6 +111,10 @@ async def ingest_data(data: Any, dataset_name: str, user: User):
                 await session.execute(select(Data).filter(Data.id == data_id))
             ).scalar_one_or_none()

+            ext_metadata = get_external_metadata_dict(data_item)
+            if node_set:
+                ext_metadata["node_set"] = node_set
+
             if data_point is not None:
                 data_point.name = file_metadata["name"]
                 data_point.raw_data_location = file_metadata["file_path"]
@@ -114,7 +122,8 @@ async def ingest_data(data: Any, dataset_name: str, user: User):
                 data_point.mime_type = file_metadata["mime_type"]
                 data_point.owner_id = user.id
                 data_point.content_hash = file_metadata["content_hash"]
-                data_point.external_metadata = (get_external_metadata_dict(data_item),)
+                data_point.external_metadata = ext_metadata
+                data_point.node_set = json.dumps(node_set) if node_set else None
                 await session.merge(data_point)
             else:
                 data_point = Data(
@@ -125,7 +134,8 @@ async def ingest_data(data: Any, dataset_name: str, user: User):
                     mime_type=file_metadata["mime_type"],
                     owner_id=user.id,
                     content_hash=file_metadata["content_hash"],
-                    external_metadata=get_external_metadata_dict(data_item),
+                    external_metadata=ext_metadata,
+                    node_set=json.dumps(node_set) if node_set else None,
                     token_count=-1,
                 )
@@ -150,7 +160,7 @@ async def ingest_data(data: Any, dataset_name: str, user: User):

     db_engine = get_relational_engine()

-    file_paths = await store_data_to_dataset(data, dataset_name, user)
+    file_paths = await store_data_to_dataset(data, dataset_name, user, node_set)

     # Note: DLT pipeline has its own event loop, therefore objects created in another event loop
     # can't be used inside the pipeline

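The node set is now persisted twice by ingest_data: serialized into the new Data.node_set JSON column, and merged into the external_metadata dict that classify_documents reads back later in the pipeline. A stdlib-only sketch of the two stored shapes (the ext_metadata contents are illustrative):

    import json

    node_set = ["AI", "FinTech"]

    column_value = json.dumps(node_set) if node_set else None  # Data.node_set column
    ext_metadata = {"origin": "example"}                       # stand-in for get_external_metadata_dict(...)
    if node_set:
        ext_metadata["node_set"] = node_set                    # lands in Data.external_metadata

    print(column_value, ext_metadata)
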
examples/python/simple_node_set_example.py (new file, +44 lines)
@@ -0,0 +1,44 @@
+import os
+import asyncio
+import cognee
+from cognee.api.v1.visualize.visualize import visualize_graph
+from cognee.shared.logging_utils import get_logger, ERROR
+
+text_a = """
+AI is revolutionizing financial services through intelligent fraud detection
+and automated customer service platforms.
+"""
+
+text_b = """
+Advances in AI are enabling smarter systems that learn and adapt over time.
+"""
+
+text_c = """
+MedTech startups have seen significant growth in recent years, driven by innovation
+in digital health and medical devices.
+"""
+
+node_set_a = ["AI", "FinTech"]
+node_set_b = ["AI"]
+node_set_c = ["MedTech"]
+
+
+async def main():
+    await cognee.prune.prune_data()
+    await cognee.prune.prune_system(metadata=True)
+
+    await cognee.add(text_a, node_set=node_set_a)
+    await cognee.add(text_b, node_set=node_set_b)
+    await cognee.add(text_c, node_set=node_set_c)
+    await cognee.cognify()
+
+    visualization_path = os.path.join(
+        os.path.dirname(__file__), "./.artifacts/graph_visualization.html"
+    )
+    await visualize_graph(visualization_path)
+
+
+if __name__ == "__main__":
+    logger = get_logger(level=ERROR)
+    loop = asyncio.new_event_loop()
+    asyncio.run(main())