cognee/cognee/tasks/graph/infer_data_ontology.py
2024-11-12 15:37:03 +01:00

178 lines
8 KiB
Python

""" This module contains the OntologyEngine class which is responsible for adding graph ontology from a JSON or CSV file. """
import csv
import json
import logging
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Union, Type
import aiofiles
import pandas as pd
from pydantic import BaseModel
from cognee.infrastructure.llm.prompts import read_query_prompt
from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.infrastructure.data.chunking.config import get_chunk_config
from cognee.infrastructure.data.chunking.get_chunking_engine import get_chunk_engine
from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
from cognee.infrastructure.files.utils.extract_text_from_file import extract_text_from_file
from cognee.infrastructure.files.utils.guess_file_type import guess_file_type, FileTypeException
from cognee.modules.data.extraction.knowledge_graph.add_model_class_to_graph import add_model_class_to_graph
from cognee.tasks.graph.models import NodeModel, GraphOntology
from cognee.shared.data_models import KnowledgeGraph
from cognee.modules.engine.utils import generate_node_id, generate_node_name
logger = logging.getLogger("task:infer_data_ontology")
async def extract_ontology(content: str, response_model: Type[BaseModel]):
llm_client = get_llm_client()
system_prompt = read_query_prompt("extract_ontology.txt")
ontology = await llm_client.acreate_structured_output(content, system_prompt, response_model)
return ontology
class OntologyEngine:
async def flatten_model(self, model: NodeModel, parent_id: Optional[str] = None) -> Dict[str, Any]:
"""Flatten the model to a dictionary."""
result = model.dict()
result["parent_id"] = parent_id
if model.default_relationship:
result.update({
"relationship_type": model.default_relationship.type,
"relationship_source": model.default_relationship.source,
"relationship_target": model.default_relationship.target
})
return result
async def recursive_flatten(self, items: Union[List[Dict[str, Any]], Dict[str, Any]], parent_id: Optional[str] = None) -> List[Dict[str, Any]]:
"""Recursively flatten the items. """
flat_list = []
if isinstance(items, list):
for item in items:
flat_list.extend(await self.recursive_flatten(item, parent_id))
elif isinstance(items, dict):
model = NodeModel.model_validate(items)
flat_list.append(await self.flatten_model(model, parent_id))
for child in model.children:
flat_list.extend(await self.recursive_flatten(child, model.node_id))
return flat_list
async def load_data(self, file_path: str) -> Union[List[Dict[str, Any]], Dict[str, Any]]:
"""Load data from a JSON or CSV file."""
try:
if file_path.endswith(".json"):
async with aiofiles.open(file_path, mode="r") as f:
data = await f.read()
return json.loads(data)
elif file_path.endswith(".csv"):
async with aiofiles.open(file_path, mode="r") as f:
content = await f.read()
reader = csv.DictReader(content.splitlines())
return list(reader)
else:
raise ValueError("Unsupported file format")
except Exception as e:
raise RuntimeError(f"Failed to load data from {file_path}: {e}")
async def add_graph_ontology(self, file_path: str = None, documents: list = None):
"""Add graph ontology from a JSON or CSV file or infer from documents content."""
if file_path is None:
initial_chunks_and_ids = []
chunk_config = get_chunk_config()
chunk_engine = get_chunk_engine()
chunk_strategy = chunk_config.chunk_strategy
for base_file in documents:
with open(base_file.raw_data_location, "rb") as file:
try:
file_type = guess_file_type(file)
text = extract_text_from_file(file, file_type)
subchunks, chunks_with_ids = chunk_engine.chunk_data(
chunk_strategy,
text,
chunk_config.chunk_size,
chunk_config.chunk_overlap,
)
if chunks_with_ids[0][0] == 1:
initial_chunks_and_ids.append({base_file.id: chunks_with_ids})
except FileTypeException:
logger.warning("File (%s) has an unknown file type. We are skipping it.", file["id"])
ontology = await extract_ontology(str(initial_chunks_and_ids), GraphOntology)
graph_client = await get_graph_engine()
await graph_client.add_nodes([(node.id, dict(
uuid = generate_node_id(node.id),
name = generate_node_name(node.name),
type = generate_node_id(node.id),
description = node.description,
updated_at = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
)) for node in ontology.nodes])
await graph_client.add_edges((
generate_node_id(edge.source_id),
generate_node_id(edge.target_id),
edge.relationship_type,
dict(
source_node_id = generate_node_id(edge.source_id),
target_node_id = generate_node_id(edge.target_id),
relationship_name = edge.relationship_type,
updated_at = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
),
) for edge in ontology.edges)
else:
dataset_level_information = documents[0][1]
# Extract the list of valid IDs from the explanations
valid_ids = {item["id"] for item in dataset_level_information}
try:
data = await self.load_data(file_path)
flt_ontology = await self.recursive_flatten(data)
df = pd.DataFrame(flt_ontology)
graph_client = await get_graph_engine()
for _, row in df.iterrows():
node_data = row.to_dict()
node_id = node_data.pop("node_id", None)
if node_id in valid_ids:
await graph_client.add_node(node_id, node_data)
if node_id not in valid_ids:
raise ValueError(f"Node ID {node_id} not found in the dataset")
if pd.notna(row.get("relationship_source")) and pd.notna(row.get("relationship_target")):
await graph_client.add_edge(
row["relationship_source"],
row["relationship_target"],
relationship_name=row["relationship_type"],
edge_properties = {
"source_node_id": row["relationship_source"],
"target_node_id": row["relationship_target"],
"relationship_name": row["relationship_type"],
"updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
},
)
return
except Exception as e:
raise RuntimeError(f"Failed to add graph ontology from {file_path}: {e}") from e
async def infer_data_ontology(documents, ontology_model = KnowledgeGraph, root_node_id = None):
if ontology_model == KnowledgeGraph:
ontology_engine = OntologyEngine()
root_node_id = await ontology_engine.add_graph_ontology(documents = documents)
else:
graph_engine = await get_graph_engine()
await add_model_class_to_graph(ontology_model, graph_engine)
yield (documents, root_node_id)