Move tasks
This commit is contained in:
parent
82ac9fc26a
commit
1087a7edda
11 changed files with 629 additions and 11 deletions
|
|
@ -27,6 +27,15 @@ from cognee.modules.users.methods import get_default_user
|
|||
from cognee.modules.users.permissions.methods import check_permissions_on_documents
|
||||
from cognee.modules.pipelines.operations.get_pipeline_status import get_pipeline_status
|
||||
from cognee.modules.pipelines.operations.log_pipeline_status import log_pipeline_status
|
||||
from cognee.tasks.chunk_extract_summary.chunk_extract_summary import chunk_extract_summary_task
|
||||
from cognee.tasks.chunk_naive_llm_classifier.chunk_naive_llm_classifier import chunk_naive_llm_classifier_task
|
||||
from cognee.tasks.chunk_remove_disconnected.chunk_remove_disconnected import chunk_remove_disconnected_task
|
||||
from cognee.tasks.chunk_to_graph_decomposition.chunk_to_graph_decomposition import chunk_to_graph_decomposition_task
|
||||
from cognee.tasks.chunk_to_vector_graphstore.chunk_to_vector_graphstore import chunk_to_vector_graphstore_task
|
||||
from cognee.tasks.chunk_update_check.chunk_update_check import chunk_update_check_task
|
||||
from cognee.tasks.graph_decomposition_to_graph_nodes.graph_decomposition_to_graph_nodes import \
|
||||
graph_decomposition_to_graph_nodes_task
|
||||
from cognee.tasks.source_documents_to_chunks.source_documents_to_chunks import source_documents_to_chunks
|
||||
|
||||
logger = logging.getLogger("cognify.v2")
|
||||
|
||||
|
|
@ -100,26 +109,26 @@ async def cognify(datasets: Union[str, list[str]] = None, user: User = None):
|
|||
root_node_id = "ROOT"
|
||||
|
||||
tasks = [
|
||||
Task(process_documents, parent_node_id = root_node_id), # Classify documents and save them as a nodes in graph db, extract text chunks based on the document type
|
||||
Task(establish_graph_topology, topology_model = KnowledgeGraph, task_config = { "batch_size": 10 }), # Set the graph topology for the document chunk data
|
||||
Task(expand_knowledge_graph, graph_model = KnowledgeGraph, collection_name = "entities"), # Generate knowledge graphs from the document chunks and attach it to chunk nodes
|
||||
Task(filter_affected_chunks, collection_name = "chunks"), # Find all affected chunks, so we don't process unchanged chunks
|
||||
Task(source_documents_to_chunks, parent_node_id = root_node_id), # Classify documents and save them as a nodes in graph db, extract text chunks based on the document type
|
||||
Task(chunk_to_graph_decomposition_task, topology_model = KnowledgeGraph, task_config = { "batch_size": 10 }), # Set the graph topology for the document chunk data
|
||||
Task(graph_decomposition_to_graph_nodes_task, graph_model = KnowledgeGraph, collection_name = "entities"), # Generate knowledge graphs from the document chunks and attach it to chunk nodes
|
||||
Task(chunk_update_check_task, collection_name = "chunks"), # Find all affected chunks, so we don't process unchanged chunks
|
||||
Task(
|
||||
save_data_chunks,
|
||||
chunk_to_vector_graphstore_task,
|
||||
collection_name = "chunks",
|
||||
), # Save the document chunks in vector db and as nodes in graph db (connected to the document node and between each other)
|
||||
run_tasks_parallel([
|
||||
Task(
|
||||
summarize_text_chunks,
|
||||
chunk_extract_summary_task,
|
||||
summarization_model = cognee_config.summarization_model,
|
||||
collection_name = "chunk_summaries",
|
||||
), # Summarize the document chunks
|
||||
Task(
|
||||
classify_text_chunks,
|
||||
chunk_naive_llm_classifier_task,
|
||||
classification_model = cognee_config.classification_model,
|
||||
),
|
||||
]),
|
||||
Task(remove_obsolete_chunks), # Remove the obsolete document chunks.
|
||||
Task(chunk_remove_disconnected_task), # Remove the obsolete document chunks.
|
||||
]
|
||||
|
||||
pipeline = run_tasks(tasks, documents)
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from cognee.infrastructure.databases.vector import get_vector_engine, DataPoint
|
|||
from cognee.modules.data.processing.chunk_types.DocumentChunk import DocumentChunk
|
||||
from ..data.extraction.extract_categories import extract_categories
|
||||
|
||||
async def classify_text_chunks(data_chunks: list[DocumentChunk], classification_model: Type[BaseModel]):
|
||||
async def chunk_naive_llm_classifier(data_chunks: list[DocumentChunk], classification_model: Type[BaseModel]):
|
||||
if len(data_chunks) == 0:
|
||||
return data_chunks
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from cognee.infrastructure.databases.graph import get_graph_engine
|
|||
from ...processing.chunk_types.DocumentChunk import DocumentChunk
|
||||
from .add_model_class_to_graph import add_model_class_to_graph
|
||||
|
||||
async def establish_graph_topology(data_chunks: list[DocumentChunk], topology_model: Type[BaseModel]):
|
||||
async def chunk_to_graph_decomposition(data_chunks: list[DocumentChunk], topology_model: Type[BaseModel]):
|
||||
if topology_model == KnowledgeGraph:
|
||||
return data_chunks
|
||||
|
||||
|
|
|
|||
38
cognee/tasks/chunk_extract_summary/chunk_extract_summary.py
Normal file
38
cognee/tasks/chunk_extract_summary/chunk_extract_summary.py
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
|
||||
import asyncio
|
||||
from typing import Type
|
||||
from pydantic import BaseModel
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine, DataPoint
|
||||
from cognee.modules.data.extraction.data_summary.models.TextSummary import TextSummary
|
||||
from cognee.modules.data.extraction.extract_summary import extract_summary
|
||||
from cognee.modules.data.processing.chunk_types.DocumentChunk import DocumentChunk
|
||||
|
||||
|
||||
async def chunk_extract_summary_task(data_chunks: list[DocumentChunk], summarization_model: Type[BaseModel], collection_name: str = "summaries"):
    """Produce an LLM summary for every chunk and persist it in the vector store.

    One summary is extracted per chunk (concurrently), then all summaries are
    written to `collection_name` as TextSummary points keyed by chunk id and
    embedded on their "text" field. The chunks themselves are returned
    unmodified so the task can sit in a pipeline.
    """
    if not data_chunks:
        return data_chunks

    # Fan out one summarization call per chunk and wait for all of them.
    chunk_summaries = await asyncio.gather(
        *[extract_summary(chunk.text, summarization_model) for chunk in data_chunks]
    )

    vector_engine = get_vector_engine()
    await vector_engine.create_collection(collection_name, payload_schema = TextSummary)

    summary_points = [
        DataPoint[TextSummary](
            id = str(chunk.chunk_id),
            payload = dict(
                chunk_id = str(chunk.chunk_id),
                document_id = str(chunk.document_id),
                text = summary.summary,
            ),
            embed_field = "text",
        )
        for chunk, summary in zip(data_chunks, chunk_summaries)
    ]

    await vector_engine.create_data_points(collection_name, summary_points)

    return data_chunks
|
||||
|
|
@ -0,0 +1,152 @@
|
|||
import asyncio
|
||||
from uuid import uuid5, NAMESPACE_OID
|
||||
from typing import Type
|
||||
from pydantic import BaseModel
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine, DataPoint
|
||||
from cognee.modules.data.extraction.extract_categories import extract_categories
|
||||
from cognee.modules.data.processing.chunk_types.DocumentChunk import DocumentChunk
|
||||
|
||||
|
||||
async def chunk_naive_llm_classifier_task(data_chunks: list[DocumentChunk], classification_model: Type[BaseModel]):
    """Classify every chunk with the LLM and store the labels as keywords.

    Each chunk's classification label type and its subclasses become Keyword
    points in the "classification" vector collection (deduplicated against
    points that already exist) and nodes in the graph, linked to the chunk via
    "is_media_type" / "is_classified_as" edges. Returns the chunks unchanged.
    """
    if len(data_chunks) == 0:
        return data_chunks

    chunk_classifications = await asyncio.gather(
        *[extract_categories(chunk.text, classification_model) for chunk in data_chunks],
    )

    # Deterministic ids of every label we might already have stored.
    # (Previously the type id was appended twice; once suffices — the list is
    # deduplicated with set() before retrieval anyway.)
    classification_data_points = []

    for chunk_classification in chunk_classifications:
        classification_data_points.append(uuid5(NAMESPACE_OID, chunk_classification.label.type))

        for classification_subclass in chunk_classification.label.subclass:
            classification_data_points.append(uuid5(NAMESPACE_OID, classification_subclass.value))

    vector_engine = get_vector_engine()

    class Keyword(BaseModel):
        uuid: str
        text: str
        chunk_id: str
        document_id: str

    collection_name = "classification"

    if await vector_engine.has_collection(collection_name):
        existing_data_points = await vector_engine.retrieve(
            collection_name,
            list(set(classification_data_points)),
        ) if len(classification_data_points) > 0 else []

        existing_points_map = {point.id: True for point in existing_data_points}
    else:
        existing_points_map = {}
        await vector_engine.create_collection(collection_name, payload_schema = Keyword)

    data_points = []
    nodes = []
    edges = []

    for (chunk_index, data_chunk) in enumerate(data_chunks):
        chunk_classification = chunk_classifications[chunk_index]
        classification_type_label = chunk_classification.label.type
        classification_type_id = uuid5(NAMESPACE_OID, classification_type_label)

        if classification_type_id not in existing_points_map:
            data_points.append(
                DataPoint[Keyword](
                    id = str(classification_type_id),
                    payload = Keyword.parse_obj({
                        "uuid": str(classification_type_id),
                        "text": classification_type_label,
                        "chunk_id": str(data_chunk.chunk_id),
                        "document_id": str(data_chunk.document_id),
                    }),
                    embed_field = "text",
                )
            )

            nodes.append((
                str(classification_type_id),
                dict(
                    id = str(classification_type_id),
                    name = classification_type_label,
                    type = classification_type_label,
                )
            ))
            existing_points_map[classification_type_id] = True

        # Chunk -> label type.
        edges.append((
            str(data_chunk.chunk_id),
            str(classification_type_id),
            "is_media_type",
            dict(
                relationship_name = "is_media_type",
                source_node_id = str(data_chunk.chunk_id),
                target_node_id = str(classification_type_id),
            ),
        ))

        for classification_subclass in chunk_classification.label.subclass:
            classification_subtype_label = classification_subclass.value
            classification_subtype_id = uuid5(NAMESPACE_OID, classification_subtype_label)

            if classification_subtype_id not in existing_points_map:
                data_points.append(
                    DataPoint[Keyword](
                        id = str(classification_subtype_id),
                        payload = Keyword.parse_obj({
                            "uuid": str(classification_subtype_id),
                            "text": classification_subtype_label,
                            "chunk_id": str(data_chunk.chunk_id),
                            "document_id": str(data_chunk.document_id),
                        }),
                        embed_field = "text",
                    )
                )

                nodes.append((
                    str(classification_subtype_id),
                    dict(
                        id = str(classification_subtype_id),
                        name = classification_subtype_label,
                        type = classification_subtype_label,
                    )
                ))
                # Subtype -> type. NOTE(review): properties now mirror the edge
                # tuple; they previously said relationship_name = "contains"
                # with source/target swapped, contradicting every other edge
                # built in this module.
                edges.append((
                    str(classification_subtype_id),
                    str(classification_type_id),
                    "is_subtype_of",
                    dict(
                        relationship_name = "is_subtype_of",
                        source_node_id = str(classification_subtype_id),
                        target_node_id = str(classification_type_id),
                    ),
                ))

                existing_points_map[classification_subtype_id] = True

            # Chunk -> subtype.
            edges.append((
                str(data_chunk.chunk_id),
                str(classification_subtype_id),
                "is_classified_as",
                dict(
                    relationship_name = "is_classified_as",
                    source_node_id = str(data_chunk.chunk_id),
                    target_node_id = str(classification_subtype_id),
                ),
            ))

    if len(nodes) > 0 or len(edges) > 0:
        await vector_engine.create_data_points(collection_name, data_points)

        graph_engine = await get_graph_engine()

        await graph_engine.add_nodes(nodes)
        await graph_engine.add_edges(edges)

    return data_chunks
|
||||
|
|
@ -0,0 +1,32 @@
|
|||
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.modules.data.processing.chunk_types.DocumentChunk import DocumentChunk
|
||||
|
||||
|
||||
# from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
|
||||
|
||||
async def chunk_remove_disconnected_task(data_chunks: list[DocumentChunk]) -> list[DocumentChunk]:
    """Prune obsolete chunk nodes and fully disconnected nodes from the graph.

    For each document referenced by `data_chunks`, every chunk node reachable
    via "has_chunk" that has no incoming "next_chunk" edge is deleted, after
    which any nodes the graph engine reports as disconnected are removed too.
    Returns the chunks unchanged.
    """
    graph_engine = await get_graph_engine()

    obsolete_chunk_ids = []

    # Walk each distinct document once.
    for document_id in {data_chunk.document_id for data_chunk in data_chunks}:
        chunk_ids = await graph_engine.get_successor_ids(document_id, edge_label = "has_chunk")

        for chunk_id in chunk_ids:
            predecessors = await graph_engine.get_predecessor_ids(chunk_id, edge_label = "next_chunk")

            # No incoming "next_chunk" edge means the chunk fell out of the chain.
            if not predecessors:
                obsolete_chunk_ids.append(chunk_id)

    if obsolete_chunk_ids:
        await graph_engine.delete_nodes(obsolete_chunk_ids)

    disconnected_nodes = await graph_engine.get_disconnected_nodes()

    if disconnected_nodes:
        await graph_engine.delete_nodes(disconnected_nodes)

    return data_chunks
|
||||
|
|
@ -7,7 +7,7 @@ from cognee.infrastructure.databases.graph import get_graph_engine
|
|||
from cognee.modules.data.extraction.knowledge_graph.add_model_class_to_graph import add_model_class_to_graph
|
||||
|
||||
|
||||
async def establish_graph_topology(data_chunks: list[DocumentChunk], topology_model: Type[BaseModel]):
|
||||
async def chunk_to_graph_decomposition_task(data_chunks: list[DocumentChunk], topology_model: Type[BaseModel]):
|
||||
if topology_model == KnowledgeGraph:
|
||||
return data_chunks
|
||||
|
||||
|
|
@ -0,0 +1,97 @@
|
|||
from cognee.infrastructure.databases.vector import DataPoint, get_vector_engine
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.modules.data.processing.chunk_types.DocumentChunk import DocumentChunk
|
||||
|
||||
async def chunk_to_vector_graphstore_task(data_chunks: list[DocumentChunk], collection_name: str):
    """Store chunks in the vector collection and as linked nodes in the graph.

    Previously stored versions of the same chunks are deleted from the vector
    store and unlinked in the graph first. Each chunk is then embedded on its
    "text" field, added as a graph node, connected to its document via
    "has_chunk", and chained to its predecessor via "next_chunk".
    Returns the chunks unchanged.
    """
    if not data_chunks:
        return data_chunks

    vector_engine = get_vector_engine()
    graph_engine = await get_graph_engine()

    # Drop stale copies of these chunks (or create the collection on first run).
    if await vector_engine.has_collection(collection_name):
        retrieved = await vector_engine.retrieve(
            collection_name,
            [str(chunk.chunk_id) for chunk in data_chunks],
        )
        stale_chunks = [DocumentChunk.parse_obj(point.payload) for point in retrieved]

        if stale_chunks:
            stale_ids = [chunk.chunk_id for chunk in stale_chunks]

            await vector_engine.delete_data_points(collection_name, [str(chunk_id) for chunk_id in stale_ids])

            await graph_engine.remove_connection_to_successors_of(stale_ids, "next_chunk")
            await graph_engine.remove_connection_to_predecessors_of(stale_ids, "has_chunk")
    else:
        await vector_engine.create_collection(collection_name, payload_schema = DocumentChunk)

    # Vector storage: one embedded point per chunk.
    await vector_engine.create_data_points(
        collection_name,
        [
            DataPoint[DocumentChunk](
                id = str(chunk.chunk_id),
                payload = chunk,
                embed_field = "text",
            ) for chunk in data_chunks
        ],
    )

    # Graph storage: chunk nodes plus document/ordering edges.
    chunk_nodes = []
    chunk_edges = []

    for chunk in data_chunks:
        chunk_id = str(chunk.chunk_id)
        document_id = str(chunk.document_id)

        chunk_nodes.append((
            chunk_id,
            dict(
                id = chunk_id,
                chunk_id = chunk_id,
                document_id = document_id,
                word_count = chunk.word_count,
                chunk_index = chunk.chunk_index,
                cut_type = chunk.cut_type,
                pages = chunk.pages,
            )
        ))

        chunk_edges.append((
            document_id,
            chunk_id,
            "has_chunk",
            dict(
                relationship_name = "has_chunk",
                source_node_id = document_id,
                target_node_id = chunk_id,
            ),
        ))

        previous_chunk_id = get_previous_chunk_id(data_chunks, chunk)

        if previous_chunk_id is not None:
            chunk_edges.append((
                str(previous_chunk_id),
                chunk_id,
                "next_chunk",
                dict(
                    relationship_name = "next_chunk",
                    source_node_id = str(previous_chunk_id),
                    target_node_id = chunk_id,
                ),
            ))

    await graph_engine.add_nodes(chunk_nodes)
    await graph_engine.add_edges(chunk_edges)

    return data_chunks
|
||||
|
||||
|
||||
def get_previous_chunk_id(document_chunks: list[DocumentChunk], current_chunk: DocumentChunk):
    """Return the id of the node that precedes `current_chunk` in its document.

    For the first chunk (chunk_index == 0) this is the document id itself, so
    the "next_chunk" chain starts at the document node. Otherwise it is the
    chunk_id of the sibling chunk (same document) whose chunk_index is one
    less. Returns None when no predecessor can be found.

    Note: the original `-> DocumentChunk` annotation was wrong — an id (or
    None), never a chunk object, is returned — so it has been removed.
    """
    if current_chunk.chunk_index == 0:
        return current_chunk.document_id

    for chunk in document_chunks:
        if str(chunk.document_id) == str(current_chunk.document_id) \
            and chunk.chunk_index == current_chunk.chunk_index - 1:
            return chunk.chunk_id

    return None
|
||||
26
cognee/tasks/chunk_update_check/chunk_update_check.py
Normal file
26
cognee/tasks/chunk_update_check/chunk_update_check.py
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
from cognee.modules.data.processing.chunk_types.DocumentChunk import DocumentChunk
|
||||
|
||||
|
||||
async def chunk_update_check_task(data_chunks: list[DocumentChunk], collection_name: str) -> list[DocumentChunk]:
    """Return only the chunks that are new or whose text has changed.

    Existing versions are looked up in `collection_name` by chunk id; a chunk
    is "affected" when it is absent there or its stored "text" payload differs
    from the current text.
    """
    vector_engine = get_vector_engine()

    if not await vector_engine.has_collection(collection_name):
        # If the collection doesn't exist, all data_chunks are new.
        return data_chunks

    existing_chunks = await vector_engine.retrieve(
        collection_name,
        [str(chunk.chunk_id) for chunk in data_chunks],
    )

    # Normalize ids to strings on both sides: points were stored under
    # str(chunk_id), while chunk.chunk_id itself may be a UUID — comparing
    # the raw value against string keys would mark every chunk as affected.
    existing_chunks_map = {str(chunk.id): chunk.payload for chunk in existing_chunks}

    affected_data_chunks = []

    for chunk in data_chunks:
        chunk_id = str(chunk.chunk_id)

        if chunk_id not in existing_chunks_map or \
            chunk.text != existing_chunks_map[chunk_id]["text"]:
            affected_data_chunks.append(chunk)

    return affected_data_chunks
|
||||
|
|
@ -0,0 +1,219 @@
|
|||
import json
|
||||
import asyncio
|
||||
from uuid import uuid5, NAMESPACE_OID
|
||||
from datetime import datetime, timezone
|
||||
from typing import Type
|
||||
from pydantic import BaseModel
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.infrastructure.databases.vector import DataPoint, get_vector_engine
|
||||
from cognee.modules.data.extraction.knowledge_graph.extract_content_graph import extract_content_graph
|
||||
from cognee.modules.data.processing.chunk_types.DocumentChunk import DocumentChunk
|
||||
|
||||
|
||||
class EntityNode(BaseModel):
    """Vector-store payload schema for an extracted entity or entity-type node."""

    uuid: str         # normalized node id (see generate_node_id)
    name: str         # normalized display name (see generate_name)
    type: str
    description: str
    created_at: datetime
    updated_at: datetime
||||
|
||||
async def graph_decomposition_to_graph_nodes_task(data_chunks: list[DocumentChunk], graph_model: Type[BaseModel], collection_name: str):
    """Extract a content graph from every chunk and merge it into storage.

    Entity and entity-type nodes are deduplicated against edges already in the
    graph, written to the graph store with links back to their source chunks
    ("contains_entity", "contains_entity_type", "is_entity_type"), and embedded
    as EntityNode points in `collection_name`. Returns the chunks unchanged.
    """
    chunk_graphs = await asyncio.gather(
        *[extract_content_graph(chunk.text, graph_model) for chunk in data_chunks]
    )

    vector_engine = get_vector_engine()
    graph_engine = await get_graph_engine()

    has_collection = await vector_engine.has_collection(collection_name)

    if not has_collection:
        await vector_engine.create_collection(collection_name, payload_schema = EntityNode)

    # First pass: enumerate every edge we might create so existence can be
    # checked against the graph in a single has_edges() call.
    processed_nodes = {}
    type_node_edges = []
    entity_node_edges = []
    type_entity_edges = []

    for (chunk_index, chunk) in enumerate(data_chunks):
        chunk_graph = chunk_graphs[chunk_index]
        for node in chunk_graph.nodes:
            type_node_id = generate_node_id(node.type)
            entity_node_id = generate_node_id(node.id)

            if type_node_id not in processed_nodes:
                type_node_edges.append((str(chunk.chunk_id), type_node_id, "contains_entity_type"))
                processed_nodes[type_node_id] = True

            if entity_node_id not in processed_nodes:
                entity_node_edges.append((str(chunk.chunk_id), entity_node_id, "contains_entity"))
                type_entity_edges.append((entity_node_id, type_node_id, "is_entity_type"))
                processed_nodes[entity_node_id] = True

    # Edges coming from the extracted graphs themselves. Collected across ALL
    # chunks — previously only the last chunk's edges were included in the
    # existence check, so earlier chunks' edges were re-added every run.
    graph_node_edges = [
        (edge.source_node_id, edge.target_node_id, edge.relationship_name)
        for chunk_graph in chunk_graphs
        for edge in chunk_graph.edges
    ]

    existing_edges = await graph_engine.has_edges([
        *type_node_edges,
        *entity_node_edges,
        *type_entity_edges,
        *graph_node_edges,
    ])

    existing_edges_map = {}
    existing_nodes_map = {}

    for edge in existing_edges:
        existing_edges_map[edge[0] + edge[1] + edge[2]] = True
        existing_nodes_map[edge[0]] = True

    graph_nodes = []
    graph_edges = []
    data_points = []

    # One timestamp for the whole batch (hoisted out of the loops).
    now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")

    for (chunk_index, chunk) in enumerate(data_chunks):
        graph = chunk_graphs[chunk_index]
        if graph is None:
            continue

        for node in graph.nodes:
            node_id = generate_node_id(node.id)
            node_name = generate_name(node.name)

            type_node_id = generate_node_id(node.type)
            type_node_name = generate_name(node.type)

            if node_id not in existing_nodes_map:
                node_data = dict(
                    uuid = node_id,
                    name = node_name,
                    # NOTE(review): mirrors original behavior (type = node_name);
                    # possibly type_node_name was intended — verify before changing.
                    type = node_name,
                    description = node.description,
                    created_at = now,
                    updated_at = now,
                )

                graph_nodes.append((
                    node_id,
                    dict(
                        **node_data,
                        properties = json.dumps(node.properties),
                    )
                ))

                data_points.append(DataPoint[EntityNode](
                    id = str(uuid5(NAMESPACE_OID, node_id)),
                    payload = node_data,
                    embed_field = "name",
                ))

                existing_nodes_map[node_id] = True

            edge_key = str(chunk.chunk_id) + node_id + "contains_entity"

            if edge_key not in existing_edges_map:
                graph_edges.append((
                    str(chunk.chunk_id),
                    node_id,
                    "contains_entity",
                    dict(
                        relationship_name = "contains_entity",
                        source_node_id = str(chunk.chunk_id),
                        target_node_id = node_id,
                    ),
                ))

                # Relationship between the entity and its type: "Jake is Person".
                # Fixed: properties now mirror the edge tuple; they previously
                # had source_node_id/target_node_id swapped.
                graph_edges.append((
                    node_id,
                    type_node_id,
                    "is_entity_type",
                    dict(
                        relationship_name = "is_entity_type",
                        source_node_id = node_id,
                        target_node_id = type_node_id,
                    ),
                ))

                existing_edges_map[edge_key] = True

            if type_node_id not in existing_nodes_map:
                type_node_data = dict(
                    uuid = type_node_id,
                    name = type_node_name,
                    type = type_node_id,
                    description = type_node_name,
                    created_at = now,
                    updated_at = now,
                )

                graph_nodes.append((type_node_id, dict(
                    **type_node_data,
                    properties = json.dumps(node.properties)
                )))

                data_points.append(DataPoint[EntityNode](
                    id = str(uuid5(NAMESPACE_OID, type_node_id)),
                    payload = type_node_data,
                    embed_field = "name",
                ))

                existing_nodes_map[type_node_id] = True

            edge_key = str(chunk.chunk_id) + type_node_id + "contains_entity_type"

            if edge_key not in existing_edges_map:
                graph_edges.append((
                    str(chunk.chunk_id),
                    type_node_id,
                    "contains_entity_type",
                    dict(
                        relationship_name = "contains_entity_type",
                        source_node_id = str(chunk.chunk_id),
                        target_node_id = type_node_id,
                    ),
                ))

                existing_edges_map[edge_key] = True

        # Add relationships that came from the extracted graphs.
        for edge in graph.edges:
            source_node_id = generate_node_id(edge.source_node_id)
            target_node_id = generate_node_id(edge.target_node_id)
            relationship_name = generate_name(edge.relationship_name)
            edge_key = source_node_id + target_node_id + relationship_name

            if edge_key not in existing_edges_map:
                graph_edges.append((
                    source_node_id,
                    target_node_id,
                    edge.relationship_name,
                    dict(
                        relationship_name = relationship_name,
                        source_node_id = source_node_id,
                        target_node_id = target_node_id,
                        properties = json.dumps(edge.properties),
                    ),
                ))
                existing_edges_map[edge_key] = True

    if len(data_points) > 0:
        await vector_engine.create_data_points(collection_name, data_points)

    if len(graph_nodes) > 0:
        await graph_engine.add_nodes(graph_nodes)

    if len(graph_edges) > 0:
        await graph_engine.add_edges(graph_edges)

    return data_chunks
|
||||
|
||||
|
||||
def generate_name(name: str) -> str:
    """Normalize a display name: lowercase, spaces to underscores, apostrophes dropped."""
    return name.lower().translate(str.maketrans({" ": "_", "'": None}))
|
||||
|
||||
def generate_node_id(node_id: str) -> str:
    """Normalize a node id: lowercase, spaces to underscores, apostrophes dropped."""
    return node_id.lower().translate(str.maketrans({" ": "_", "'": None}))
|
||||
|
|
@ -0,0 +1,45 @@
|
|||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.modules.data.processing.document_types.Document import Document
|
||||
|
||||
|
||||
async def source_documents_to_chunks(documents: list[Document], parent_node_id: str = None, user: str = None, user_permissions: str = None):
    """Register documents as graph nodes and yield their text chunks.

    Documents not yet present in the graph are added (annotated with `user`
    and `user_permissions`) and, when `parent_node_id` is given, linked to the
    parent via "has_document". Every document is then read in chunks of up to
    1024 units, and each chunk is yielded — this is an async generator.
    """
    graph_engine = await get_graph_engine()

    nodes = []
    edges = []

    # Ensure the parent node exists before pointing edges at it.
    if parent_node_id and await graph_engine.extract_node(parent_node_id) is None:
        nodes.append((parent_node_id, {}))

    document_nodes = await graph_engine.extract_nodes([str(document.id) for document in documents])

    for (document_index, document) in enumerate(documents):
        # NOTE(review): `document_index in document_nodes` tests membership of
        # the index among the returned values, not positional validity — kept
        # as-is; confirm whether `document_index < len(document_nodes)` was meant.
        document_node = document_nodes[document_index] if document_index in document_nodes else None

        if document_node is None:
            document_dict = document.to_dict()
            document_dict["user"] = user
            document_dict["user_permissions"] = user_permissions
            # Fix: store the enriched dict — the user fields were previously
            # discarded by re-calling document.to_dict() here.
            nodes.append((str(document.id), document_dict))

            if parent_node_id:
                edges.append((
                    parent_node_id,
                    str(document.id),
                    "has_document",
                    dict(
                        relationship_name = "has_document",
                        source_node_id = parent_node_id,
                        target_node_id = str(document.id),
                    ),
                ))

    if len(nodes) > 0:
        await graph_engine.add_nodes(nodes)
        await graph_engine.add_edges(edges)

    for document in documents:
        document_reader = document.get_reader()

        for document_chunk in document_reader.read(max_chunk_size = 1024):
            yield document_chunk
|
||||
Loading…
Add table
Reference in a new issue