Add NodeSets

parent c385e7f189
commit ec68a8cd2d

7 changed files with 28 additions and 15 deletions
@@ -1,4 +1,4 @@
-from typing import Union, BinaryIO
+from typing import Union, BinaryIO, List, Optional
 from cognee.modules.users.models import User
 from cognee.modules.users.methods import get_default_user
 from cognee.modules.pipelines import run_tasks, Task
@@ -16,6 +16,7 @@ async def add(
     data: Union[BinaryIO, list[BinaryIO], str, list[str]],
     dataset_name: str = "main_dataset",
     user: User = None,
+    NodeSet: Optional[List[str]] = None,
 ):
     # Create tables for databases
     await create_relational_db_and_tables()
@@ -36,7 +37,7 @@ async def add(
     if user is None:
         user = await get_default_user()

-    tasks = [Task(resolve_data_directories), Task(ingest_data, dataset_name, user)]
+    tasks = [Task(resolve_data_directories), Task(ingest_data, dataset_name, user, NodeSet)]

     dataset_id = uuid5(NAMESPACE_OID, dataset_name)
     pipeline = run_tasks(
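Taken together, these hunks thread an optional NodeSet list from the public add() entry point into the ingestion task. A minimal usage sketch, assuming this add() is the coroutine exposed as cognee.add; the sample text and tag values are illustrative:

import asyncio

import cognee


async def main():
    # Tag everything ingested by this call with two node-set labels.
    await cognee.add(
        "Natural language processing is a subfield of computer science.",
        dataset_name="main_dataset",
        NodeSet=["nlp_notes", "demo_run"],
    )
    await cognee.cognify()  # default pipeline, which now ends with apply_node_set


asyncio.run(main())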
@@ -25,6 +25,7 @@ from cognee.tasks.documents import (
 from cognee.tasks.graph import extract_graph_from_data
 from cognee.tasks.storage import add_data_points
 from cognee.tasks.summarization import summarize_text
+from cognee.tasks.node_set import apply_node_set
 from cognee.modules.chunking.TextChunker import TextChunker

 logger = get_logger("cognify")
@@ -139,6 +140,7 @@ async def get_default_tasks(  # TODO: Find out a better way to do this (Boris's
             task_config={"batch_size": 10},
         ),
         Task(add_data_points, task_config={"batch_size": 10}),
+        Task(apply_node_set, task_config={"batch_size": 10}),  # Apply NodeSet values to DataPoints
     ]

     return default_tasks
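The body of apply_node_set is not among the hunks in this commit (only the package __init__.py at the bottom is added), so the sketch below is only a guess at its shape, inferred from the batch-style registration above: a coroutine that receives a batch of data points and returns them with node-set tags applied. Every name except apply_node_set and the NodeSet field is illustrative, including the DataPoint import path.

from typing import List, Optional

from cognee.infrastructure.engine import DataPoint  # import path assumed


async def apply_node_set(
    data_points: List[DataPoint],
    node_set: Optional[List[str]] = None,
) -> List[DataPoint]:
    # Hypothetical sketch: copy node-set tags onto each data point in the batch.
    for data_point in data_points:
        if node_set and not data_point.NodeSet:
            data_point.NodeSet = list(node_set)
    return data_points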
@@ -1,5 +1,5 @@
 from datetime import datetime, timezone
-from typing import Optional, Any, Dict
+from typing import Optional, Any, Dict, List
 from uuid import UUID, uuid4

 from pydantic import BaseModel, Field
@@ -27,6 +27,7 @@ class DataPoint(BaseModel):
     topological_rank: Optional[int] = 0
     metadata: Optional[MetaData] = {"index_fields": []}
     type: str = Field(default_factory=lambda: DataPoint.__name__)
+    NodeSet: Optional[List[str]] = None  # List of nodes this data point is associated with

     def __init__(self, **data):
         super().__init__(**data)
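With the new field in place, any DataPoint subclass can carry node-set tags. A minimal sketch, assuming DataPoint is importable from cognee.infrastructure.engine; the subclass and values are made up:

from cognee.infrastructure.engine import DataPoint  # import path assumed


class Person(DataPoint):  # hypothetical subclass for illustration
    name: str


alice = Person(name="Alice", NodeSet=["customers", "pilot_project"])
print(alice.NodeSet)  # ['customers', 'pilot_project']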
@@ -20,6 +20,7 @@ class Data(Base):
     owner_id = Column(UUID, index=True)
     content_hash = Column(String)
     external_metadata = Column(JSON)
+    node_set = Column(JSON, nullable=True)  # Store NodeSet as JSON list of strings
     token_count = Column(Integer)
     created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
     updated_at = Column(DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc))
@@ -44,5 +45,6 @@ class Data(Base):
             "rawDataLocation": self.raw_data_location,
             "createdAt": self.created_at.isoformat(),
             "updatedAt": self.updated_at.isoformat() if self.updated_at else None,
+            "nodeSet": self.node_set,
             # "datasets": [dataset.to_json() for dataset in self.datasets]
         }
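Note that ingest_data (further down) stores json.dumps(NodeSet), i.e. a JSON-encoded string, in this JSON column, so self.node_set comes back as a string rather than a list. A minimal sketch of decoding it on the read side; data_row stands for a loaded Data instance and the helper name is made up:

import json
from typing import List, Optional


def read_node_set(data_row) -> Optional[List[str]]:
    # node_set holds the string produced by json.dumps(NodeSet), or None.
    return json.loads(data_row.node_set) if data_row.node_set else None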
@@ -1,4 +1,4 @@
-from cognee.shared.logging_utils import get_logger
+import logging
 import networkx as nx
 import json
 import os
@@ -6,7 +6,7 @@ import os
 from cognee.infrastructure.files.storage import LocalStorage


-logger = get_logger()
+logger = logging.getLogger(__name__)


 async def cognee_network_visualization(graph_data, destination_file_path: str = None):
@@ -20,8 +20,6 @@ async def cognee_network_visualization(graph_data, destination_file_path: str = None):
         "EntityType": "#6510f4",
         "DocumentChunk": "#801212",
         "TextSummary": "#1077f4",
-        "TableRow": "#f47710",
-        "TableType": "#6510f4",
         "default": "#D3D3D3",
     }

@@ -184,8 +182,8 @@ async def cognee_network_visualization(graph_data, destination_file_path: str = None):
     </html>
     """

-    html_content = html_template.replace("{nodes}", json.dumps(nodes_list, default=str))
-    html_content = html_content.replace("{links}", json.dumps(links_list, default=str))
+    html_content = html_template.replace("{nodes}", json.dumps(nodes_list))
+    html_content = html_content.replace("{links}", json.dumps(links_list))

     if not destination_file_path:
         home_dir = os.path.expanduser("~")
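The last two hunks drop the default=str fallback when the node and link lists are serialized into the HTML template, so every value in nodes_list and links_list now has to be natively JSON-serializable (the fallback used to stringify values such as UUIDs and datetimes). A small illustration of the difference:

import json
from uuid import uuid4

node = {"id": uuid4(), "label": "DocumentChunk"}

json.dumps(node, default=str)  # works: the UUID is rendered through str()
json.dumps(node)               # raises TypeError: Object of type UUID is not JSON serializable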
@@ -1,4 +1,4 @@
-from typing import Any, List
+from typing import Any, List, Optional

 import dlt
 import cognee.modules.ingestion as ingestion
@@ -12,9 +12,10 @@ from .save_data_item_to_storage import save_data_item_to_storage

 from typing import Union, BinaryIO
 import inspect
+import json


-async def ingest_data(data: Any, dataset_name: str, user: User):
+async def ingest_data(data: Any, dataset_name: str, user: User, NodeSet: Optional[List[str]] = None):
     destination = get_dlt_destination()

     pipeline = dlt.pipeline(
@@ -43,9 +44,10 @@ async def ingest_data(data: Any, dataset_name: str, user: User):
                 "mime_type": file_metadata["mime_type"],
                 "content_hash": file_metadata["content_hash"],
                 "owner_id": str(user.id),
+                "node_set": json.dumps(NodeSet) if NodeSet else None,
             }

-    async def store_data_to_dataset(data: Any, dataset_name: str, user: User):
+    async def store_data_to_dataset(data: Any, dataset_name: str, user: User, NodeSet: Optional[List[str]] = None):
         if not isinstance(data, list):
             # Convert data to a list as we work with lists further down.
             data = [data]
@@ -81,6 +83,10 @@ async def ingest_data(data: Any, dataset_name: str, user: User):
                     await session.execute(select(Data).filter(Data.id == data_id))
                 ).scalar_one_or_none()

+                ext_metadata = get_external_metadata_dict(data_item)
+                if NodeSet:
+                    ext_metadata["node_set"] = NodeSet
+
                 if data_point is not None:
                     data_point.name = file_metadata["name"]
                     data_point.raw_data_location = file_metadata["file_path"]
@@ -88,7 +94,8 @@ async def ingest_data(data: Any, dataset_name: str, user: User):
                     data_point.mime_type = file_metadata["mime_type"]
                     data_point.owner_id = user.id
                     data_point.content_hash = file_metadata["content_hash"]
-                    data_point.external_metadata = (get_external_metadata_dict(data_item),)
+                    data_point.external_metadata = ext_metadata
+                    data_point.node_set = json.dumps(NodeSet) if NodeSet else None
                     await session.merge(data_point)
                 else:
                     data_point = Data(
@@ -99,7 +106,8 @@ async def ingest_data(data: Any, dataset_name: str, user: User):
                         mime_type=file_metadata["mime_type"],
                         owner_id=user.id,
                         content_hash=file_metadata["content_hash"],
-                        external_metadata=get_external_metadata_dict(data_item),
+                        external_metadata=ext_metadata,
+                        node_set=json.dumps(NodeSet) if NodeSet else None,
                         token_count=-1,
                     )

@@ -124,7 +132,7 @@ async def ingest_data(data: Any, dataset_name: str, user: User):

     db_engine = get_relational_engine()

-    file_paths = await store_data_to_dataset(data, dataset_name, user)
+    file_paths = await store_data_to_dataset(data, dataset_name, user, NodeSet)

     # Note: DLT pipeline has its own event loop, therefore objects created in another event loop
     # can't be used inside the pipeline
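In these hunks the external metadata is now computed once into ext_metadata (with the node set merged in when present), which also replaces the earlier assignment that wrapped the dict in a one-element tuple via a trailing comma, and the node set itself is persisted as a JSON string. A small illustration of the conditional encoding used above, with made-up values:

import json

NodeSet = ["customers", "pilot_project"]         # example tags passed down from add()
print(json.dumps(NodeSet) if NodeSet else None)  # ["customers", "pilot_project"] -- stored in Data.node_set

NodeSet = None
print(json.dumps(NodeSet) if NodeSet else None)  # None -- nothing is stored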
cognee/tasks/node_set/__init__.py (new file, 1 addition)

@@ -0,0 +1 @@
+from .apply_node_set import apply_node_set
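This one-line package initializer is what lets cognify import the task from the package path. The sibling module it re-exports from is not among the hunks in this commit, so its existence is assumed here:

# Enabled by the __init__.py added in this commit:
from cognee.tasks.node_set import apply_node_set

# Equivalent direct import, assuming the sibling module cognee/tasks/node_set/apply_node_set.py exists:
# from cognee.tasks.node_set.apply_node_set import apply_node_set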