Add NodeSets
This commit is contained in:
parent
c385e7f189
commit
ec68a8cd2d
7 changed files with 28 additions and 15 deletions
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import Union, BinaryIO
|
from typing import Union, BinaryIO, List, Optional
|
||||||
from cognee.modules.users.models import User
|
from cognee.modules.users.models import User
|
||||||
from cognee.modules.users.methods import get_default_user
|
from cognee.modules.users.methods import get_default_user
|
||||||
from cognee.modules.pipelines import run_tasks, Task
|
from cognee.modules.pipelines import run_tasks, Task
|
||||||
|
|
@ -16,6 +16,7 @@ async def add(
|
||||||
data: Union[BinaryIO, list[BinaryIO], str, list[str]],
|
data: Union[BinaryIO, list[BinaryIO], str, list[str]],
|
||||||
dataset_name: str = "main_dataset",
|
dataset_name: str = "main_dataset",
|
||||||
user: User = None,
|
user: User = None,
|
||||||
|
NodeSet: Optional[List[str]] = None,
|
||||||
):
|
):
|
||||||
# Create tables for databases
|
# Create tables for databases
|
||||||
await create_relational_db_and_tables()
|
await create_relational_db_and_tables()
|
||||||
|
|
@ -36,7 +37,7 @@ async def add(
|
||||||
if user is None:
|
if user is None:
|
||||||
user = await get_default_user()
|
user = await get_default_user()
|
||||||
|
|
||||||
tasks = [Task(resolve_data_directories), Task(ingest_data, dataset_name, user)]
|
tasks = [Task(resolve_data_directories), Task(ingest_data, dataset_name, user, NodeSet)]
|
||||||
|
|
||||||
dataset_id = uuid5(NAMESPACE_OID, dataset_name)
|
dataset_id = uuid5(NAMESPACE_OID, dataset_name)
|
||||||
pipeline = run_tasks(
|
pipeline = run_tasks(
|
||||||
|
|
|
||||||
|
|
@ -25,6 +25,7 @@ from cognee.tasks.documents import (
|
||||||
from cognee.tasks.graph import extract_graph_from_data
|
from cognee.tasks.graph import extract_graph_from_data
|
||||||
from cognee.tasks.storage import add_data_points
|
from cognee.tasks.storage import add_data_points
|
||||||
from cognee.tasks.summarization import summarize_text
|
from cognee.tasks.summarization import summarize_text
|
||||||
|
from cognee.tasks.node_set import apply_node_set
|
||||||
from cognee.modules.chunking.TextChunker import TextChunker
|
from cognee.modules.chunking.TextChunker import TextChunker
|
||||||
|
|
||||||
logger = get_logger("cognify")
|
logger = get_logger("cognify")
|
||||||
|
|
@ -139,6 +140,7 @@ async def get_default_tasks( # TODO: Find out a better way to do this (Boris's
|
||||||
task_config={"batch_size": 10},
|
task_config={"batch_size": 10},
|
||||||
),
|
),
|
||||||
Task(add_data_points, task_config={"batch_size": 10}),
|
Task(add_data_points, task_config={"batch_size": 10}),
|
||||||
|
Task(apply_node_set, task_config={"batch_size": 10}), # Apply NodeSet values to DataPoints
|
||||||
]
|
]
|
||||||
|
|
||||||
return default_tasks
|
return default_tasks
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Optional, Any, Dict
|
from typing import Optional, Any, Dict, List
|
||||||
from uuid import UUID, uuid4
|
from uuid import UUID, uuid4
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
@ -27,6 +27,7 @@ class DataPoint(BaseModel):
|
||||||
topological_rank: Optional[int] = 0
|
topological_rank: Optional[int] = 0
|
||||||
metadata: Optional[MetaData] = {"index_fields": []}
|
metadata: Optional[MetaData] = {"index_fields": []}
|
||||||
type: str = Field(default_factory=lambda: DataPoint.__name__)
|
type: str = Field(default_factory=lambda: DataPoint.__name__)
|
||||||
|
NodeSet: Optional[List[str]] = None # List of nodes this data point is associated with
|
||||||
|
|
||||||
def __init__(self, **data):
|
def __init__(self, **data):
|
||||||
super().__init__(**data)
|
super().__init__(**data)
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,7 @@ class Data(Base):
|
||||||
owner_id = Column(UUID, index=True)
|
owner_id = Column(UUID, index=True)
|
||||||
content_hash = Column(String)
|
content_hash = Column(String)
|
||||||
external_metadata = Column(JSON)
|
external_metadata = Column(JSON)
|
||||||
|
node_set = Column(JSON, nullable=True) # Store NodeSet as JSON list of strings
|
||||||
token_count = Column(Integer)
|
token_count = Column(Integer)
|
||||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||||
updated_at = Column(DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc))
|
updated_at = Column(DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc))
|
||||||
|
|
@ -44,5 +45,6 @@ class Data(Base):
|
||||||
"rawDataLocation": self.raw_data_location,
|
"rawDataLocation": self.raw_data_location,
|
||||||
"createdAt": self.created_at.isoformat(),
|
"createdAt": self.created_at.isoformat(),
|
||||||
"updatedAt": self.updated_at.isoformat() if self.updated_at else None,
|
"updatedAt": self.updated_at.isoformat() if self.updated_at else None,
|
||||||
|
"nodeSet": self.node_set,
|
||||||
# "datasets": [dataset.to_json() for dataset in self.datasets]
|
# "datasets": [dataset.to_json() for dataset in self.datasets]
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from cognee.shared.logging_utils import get_logger
|
import logging
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
|
@ -6,7 +6,7 @@ import os
|
||||||
from cognee.infrastructure.files.storage import LocalStorage
|
from cognee.infrastructure.files.storage import LocalStorage
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger()
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def cognee_network_visualization(graph_data, destination_file_path: str = None):
|
async def cognee_network_visualization(graph_data, destination_file_path: str = None):
|
||||||
|
|
@ -20,8 +20,6 @@ async def cognee_network_visualization(graph_data, destination_file_path: str =
|
||||||
"EntityType": "#6510f4",
|
"EntityType": "#6510f4",
|
||||||
"DocumentChunk": "#801212",
|
"DocumentChunk": "#801212",
|
||||||
"TextSummary": "#1077f4",
|
"TextSummary": "#1077f4",
|
||||||
"TableRow": "#f47710",
|
|
||||||
"TableType": "#6510f4",
|
|
||||||
"default": "#D3D3D3",
|
"default": "#D3D3D3",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -184,8 +182,8 @@ async def cognee_network_visualization(graph_data, destination_file_path: str =
|
||||||
</html>
|
</html>
|
||||||
"""
|
"""
|
||||||
|
|
||||||
html_content = html_template.replace("{nodes}", json.dumps(nodes_list, default=str))
|
html_content = html_template.replace("{nodes}", json.dumps(nodes_list))
|
||||||
html_content = html_content.replace("{links}", json.dumps(links_list, default=str))
|
html_content = html_content.replace("{links}", json.dumps(links_list))
|
||||||
|
|
||||||
if not destination_file_path:
|
if not destination_file_path:
|
||||||
home_dir = os.path.expanduser("~")
|
home_dir = os.path.expanduser("~")
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import Any, List
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
import dlt
|
import dlt
|
||||||
import cognee.modules.ingestion as ingestion
|
import cognee.modules.ingestion as ingestion
|
||||||
|
|
@ -12,9 +12,10 @@ from .save_data_item_to_storage import save_data_item_to_storage
|
||||||
|
|
||||||
from typing import Union, BinaryIO
|
from typing import Union, BinaryIO
|
||||||
import inspect
|
import inspect
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
async def ingest_data(data: Any, dataset_name: str, user: User):
|
async def ingest_data(data: Any, dataset_name: str, user: User, NodeSet: Optional[List[str]] = None):
|
||||||
destination = get_dlt_destination()
|
destination = get_dlt_destination()
|
||||||
|
|
||||||
pipeline = dlt.pipeline(
|
pipeline = dlt.pipeline(
|
||||||
|
|
@ -43,9 +44,10 @@ async def ingest_data(data: Any, dataset_name: str, user: User):
|
||||||
"mime_type": file_metadata["mime_type"],
|
"mime_type": file_metadata["mime_type"],
|
||||||
"content_hash": file_metadata["content_hash"],
|
"content_hash": file_metadata["content_hash"],
|
||||||
"owner_id": str(user.id),
|
"owner_id": str(user.id),
|
||||||
|
"node_set": json.dumps(NodeSet) if NodeSet else None,
|
||||||
}
|
}
|
||||||
|
|
||||||
async def store_data_to_dataset(data: Any, dataset_name: str, user: User):
|
async def store_data_to_dataset(data: Any, dataset_name: str, user: User, NodeSet: Optional[List[str]] = None):
|
||||||
if not isinstance(data, list):
|
if not isinstance(data, list):
|
||||||
# Convert data to a list as we work with lists further down.
|
# Convert data to a list as we work with lists further down.
|
||||||
data = [data]
|
data = [data]
|
||||||
|
|
@ -81,6 +83,10 @@ async def ingest_data(data: Any, dataset_name: str, user: User):
|
||||||
await session.execute(select(Data).filter(Data.id == data_id))
|
await session.execute(select(Data).filter(Data.id == data_id))
|
||||||
).scalar_one_or_none()
|
).scalar_one_or_none()
|
||||||
|
|
||||||
|
ext_metadata = get_external_metadata_dict(data_item)
|
||||||
|
if NodeSet:
|
||||||
|
ext_metadata["node_set"] = NodeSet
|
||||||
|
|
||||||
if data_point is not None:
|
if data_point is not None:
|
||||||
data_point.name = file_metadata["name"]
|
data_point.name = file_metadata["name"]
|
||||||
data_point.raw_data_location = file_metadata["file_path"]
|
data_point.raw_data_location = file_metadata["file_path"]
|
||||||
|
|
@ -88,7 +94,8 @@ async def ingest_data(data: Any, dataset_name: str, user: User):
|
||||||
data_point.mime_type = file_metadata["mime_type"]
|
data_point.mime_type = file_metadata["mime_type"]
|
||||||
data_point.owner_id = user.id
|
data_point.owner_id = user.id
|
||||||
data_point.content_hash = file_metadata["content_hash"]
|
data_point.content_hash = file_metadata["content_hash"]
|
||||||
data_point.external_metadata = (get_external_metadata_dict(data_item),)
|
data_point.external_metadata = ext_metadata
|
||||||
|
data_point.node_set = json.dumps(NodeSet) if NodeSet else None
|
||||||
await session.merge(data_point)
|
await session.merge(data_point)
|
||||||
else:
|
else:
|
||||||
data_point = Data(
|
data_point = Data(
|
||||||
|
|
@ -99,7 +106,8 @@ async def ingest_data(data: Any, dataset_name: str, user: User):
|
||||||
mime_type=file_metadata["mime_type"],
|
mime_type=file_metadata["mime_type"],
|
||||||
owner_id=user.id,
|
owner_id=user.id,
|
||||||
content_hash=file_metadata["content_hash"],
|
content_hash=file_metadata["content_hash"],
|
||||||
external_metadata=get_external_metadata_dict(data_item),
|
external_metadata=ext_metadata,
|
||||||
|
node_set=json.dumps(NodeSet) if NodeSet else None,
|
||||||
token_count=-1,
|
token_count=-1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -124,7 +132,7 @@ async def ingest_data(data: Any, dataset_name: str, user: User):
|
||||||
|
|
||||||
db_engine = get_relational_engine()
|
db_engine = get_relational_engine()
|
||||||
|
|
||||||
file_paths = await store_data_to_dataset(data, dataset_name, user)
|
file_paths = await store_data_to_dataset(data, dataset_name, user, NodeSet)
|
||||||
|
|
||||||
# Note: DLT pipeline has its own event loop, therefore objects created in another event loop
|
# Note: DLT pipeline has its own event loop, therefore objects created in another event loop
|
||||||
# can't be used inside the pipeline
|
# can't be used inside the pipeline
|
||||||
|
|
|
||||||
1
cognee/tasks/node_set/__init__.py
Normal file
1
cognee/tasks/node_set/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
||||||
|
from .apply_node_set import apply_node_set
|
||||||
Loading…
Add table
Reference in a new issue