cognee/cognee/modules/cognify/graph/add_cognitive_layer_graphs.py
2024-06-01 18:43:04 +02:00

155 lines
5.1 KiB
Python

from datetime import datetime
from uuid import uuid4
from typing import List, Tuple, TypedDict
from pydantic import BaseModel
from cognee.infrastructure.databases.vector import DataPoint
# from cognee.utils import extract_pos_tags, extract_named_entities, extract_sentiment_vader
from cognee.infrastructure.databases.graph.config import get_graph_config
from cognee.infrastructure.databases.vector.config import get_vectordb_config
class GraphLike(TypedDict):
    """Duck-typed shape of a knowledge-graph payload: parallel lists of
    nodes and edges, as produced by the cognitive-layer extraction step."""
    nodes: List
    edges: List
async def add_cognitive_layer_graphs(
    graph_client,
    chunk_collection: str,
    chunk_id: str,
    layer_graphs: List[Tuple[str, GraphLike]],
):
    """Persist a batch of cognitive-layer graphs into the graph store and
    index their nodes in the vector store.

    For each ``(layer_id, layer_graph)`` pair this:
      1. coerces ``layer_graph`` into the configured pydantic graph model,
      2. creates a type node per distinct entity type (e.g. "Person") plus
         ``contains``/``is`` edges linking layer -> type -> entity,
      3. adds one node per entity carrying its chunk provenance,
      4. copies the graph's own edges verbatim,
      5. upserts everything via ``graph_client`` and writes one vector
         data point per node (embedding the node's name) into a
         per-layer collection.

    Parameters:
        graph_client: project graph-DB client exposing ``extract_node``,
            ``add_nodes`` and ``add_edges`` (all awaitable).
        chunk_collection: name of the vector collection the source chunk
            lives in (stored on each node for provenance).
        chunk_id: id of the source chunk (stored on each node).
        layer_graphs: list of ``(layer_id, graph)`` pairs; each graph is
            either an instance of the configured graph model or a dict
            parseable into it.

    Returns:
        None. Side effects only (graph + vector DB writes).
    """
    vectordb_config = get_vectordb_config()
    vector_client = vectordb_config.vector_engine

    graph_config = get_graph_config()
    graph_model = graph_config.graph_model

    for (layer_id, layer_graph) in layer_graphs:
        graph_nodes = []
        graph_edges = []

        # Accept plain dicts as well as model instances.
        if not isinstance(layer_graph, graph_model):
            layer_graph = graph_model.parse_obj(layer_graph)

        for node in layer_graph.nodes:
            node_id = generate_node_id(node.id)
            type_node_id = generate_node_id(node.type)

            # Only create the entity-type node (e.g. "Person") once.
            type_node = await graph_client.extract_node(type_node_id)

            if not type_node:
                node_name = node.type.lower().capitalize()
                # Single timestamp so created_at == updated_at on creation
                # (the original called datetime.now() twice per record).
                timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

                type_node = (
                    type_node_id,
                    dict(
                        id = type_node_id,
                        name = node_name,
                        type = node_name,
                        created_at = timestamp,
                        updated_at = timestamp,
                    ),
                )

                graph_nodes.append(type_node)

                # Add relationship between document and entity type: "Document contains Person"
                graph_edges.append((
                    layer_id,
                    type_node_id,
                    "contains",
                    dict(relationship_name = "contains"),
                ))

            # pos_tags = extract_pos_tags(node.description)
            # named_entities = extract_named_entities(node.description)
            # sentiment = extract_sentiment_vader(node.description)

            # Iterating a pydantic model yields (field_name, value) pairs;
            # keep any fields beyond the four core ones as extra node
            # properties. (Replaces the original unpacking that shadowed
            # the `id` and `type` builtins and left a debug print behind.)
            extra_properties = {
                field_name: value
                for field_name, value in node
                if field_name not in ("id", "type", "name", "description")
            }

            timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

            graph_nodes.append((
                node_id,
                dict(
                    id = node_id,
                    layer_id = layer_id,
                    chunk_id = chunk_id,
                    chunk_collection = chunk_collection,
                    name = node.name,
                    type = node.type.lower().capitalize(),
                    description = node.description,
                    # pos_tags = pos_tags,
                    # sentiment = sentiment,
                    # named_entities = named_entities,
                    created_at = timestamp,
                    updated_at = timestamp,
                    **extra_properties,
                )
            ))

            # Add relationship between entity type and entity itself: "Jake is Person"
            graph_edges.append((
                node_id,
                type_node_id,
                "is",
                dict(relationship_name = "is"),
            ))

            graph_edges.append((
                layer_id,
                node_id,
                "contains",
                dict(relationship_name = "contains"),
            ))

        # Add relationships that came from the graph itself.
        for edge in layer_graph.edges:
            graph_edges.append((
                generate_node_id(edge.source_node_id),
                generate_node_id(edge.target_node_id),
                edge.relationship_name,
                dict(relationship_name = edge.relationship_name),
            ))

        await graph_client.add_nodes(graph_nodes)
        await graph_client.add_edges(graph_edges)

        class References(BaseModel):
            node_id: str
            cognitive_layer: str

        class PayloadSchema(BaseModel):
            value: str
            references: References

        try:
            await vector_client.create_collection(layer_id, payload_schema = PayloadSchema)
        except Exception:
            # It's ok if the collection already exists.
            pass

        # One vector data point per graph node, embedding the node name.
        data_points = [
            DataPoint[PayloadSchema](
                id = str(uuid4()),
                payload = dict(
                    value = node_data["name"],
                    references = dict(
                        node_id = graph_node_id,
                        cognitive_layer = layer_id,
                    ),
                ),
                embed_field = "value"
            ) for (graph_node_id, node_data) in graph_nodes
        ]

        await vector_client.create_data_points(layer_id, data_points)
def generate_node_id(node_id: str) -> str:
    """Normalize an arbitrary node name into a canonical graph id:
    uppercase, spaces turned into underscores, apostrophes removed."""
    normalization_table = str.maketrans(" ", "_", "'")
    return node_id.upper().translate(normalization_table)