improve structure, readability

Hande 2025-09-04 16:16:28 +02:00
parent ae56b09434
commit c924846b77


@@ -1,8 +1,14 @@
-import os
-import json
+"""Cognee demo with simplified structure."""
+
+from __future__ import annotations
+
 import asyncio
-import pathlib
-from typing import List, Any
+import json
+import logging
+from collections import defaultdict
+from pathlib import Path
+from typing import Any, Iterable, List, Mapping
+
 from cognee import config, prune, search, SearchType, visualize_graph
 from cognee.low_level import setup, DataPoint
 from cognee.pipelines import run_tasks, Task
@@ -13,120 +19,247 @@ from cognee.modules.data.methods import load_or_create_datasets
 
 
 class Person(DataPoint):
+    """Represent a person."""
     name: str
     metadata: dict = {"index_fields": ["name"]}
 
 
 class Department(DataPoint):
+    """Represent a department."""
     name: str
     employees: list[Person]
     metadata: dict = {"index_fields": ["name"]}
 
 
 class CompanyType(DataPoint):
+    """Represent a company type."""
    name: str = "Company"
 
 
 class Company(DataPoint):
+    """Represent a company."""
     name: str
     departments: list[Department]
     is_type: CompanyType
     metadata: dict = {"index_fields": ["name"]}
 
 
-def ingest_files(data: List[Any]):
-    if not data or data == [None]:
-        companies_file_path = os.path.join(os.path.dirname(__file__), "../data/companies.json")
-        companies = json.loads(open(companies_file_path, "r").read())
-        people_file_path = os.path.join(os.path.dirname(__file__), "../data/people.json")
-        people = json.loads(open(people_file_path, "r").read())
-        data = [{"companies": companies, "people": people}]
+ROOT = Path(__file__).resolve().parent
+DATA_DIR = ROOT.parent / "data"
+COGNEE_DIR = ROOT / ".cognee_system"
+ARTIFACTS_DIR = ROOT / ".artifacts"
+GRAPH_HTML = ARTIFACTS_DIR / "graph_visualization.html"
+COMPANIES_JSON = DATA_DIR / "companies.json"
+PEOPLE_JSON = DATA_DIR / "people.json"
+
+
+def load_json_file(path: Path) -> Any:
+    """Load a JSON file."""
+    if not path.exists():
+        raise FileNotFoundError(f"Missing required file: {path}")
+    return json.loads(path.read_text(encoding="utf-8"))
-    people_data_points = {}
-    departments_data_points = {}
-    companies_data_points = {}
-
-    for data_item in data:
-        people = data_item["people"]
-        companies = data_item["companies"]
+def remove_duplicates_preserve_order(seq: Iterable[Any]) -> list[Any]:
+    """Return list with duplicates removed while preserving order."""
+    seen = set()
+    out = []
+    for x in seq:
+        if x in seen:
+            continue
+        seen.add(x)
+        out.append(x)
+    return out
+
+
+def collect_people(payloads: Iterable[Mapping[str, Any]]) -> list[Mapping[str, Any]]:
+    """Collect people from payloads."""
+    people = [person for payload in payloads for person in payload.get("people", [])]
+    return people
+
+
+def collect_companies(payloads: Iterable[Mapping[str, Any]]) -> list[Mapping[str, Any]]:
+    """Collect companies from payloads."""
+    companies = [company for payload in payloads for company in payload.get("companies", [])]
+    return companies
+
+
+def build_people_nodes(people: Iterable[Mapping[str, Any]]) -> dict:
+    """Build person nodes keyed by name."""
+    nodes = {p["name"]: Person(name=p["name"]) for p in people if p.get("name")}
+    return nodes
-        for person in people:
-            new_person = Person(name=person["name"])
-            people_data_points[person["name"]] = new_person
-
-            if person["department"] not in departments_data_points:
-                departments_data_points[person["department"]] = Department(
-                    name=person["department"], employees=[new_person]
-                )
-            else:
-                departments_data_points[person["department"]].employees.append(new_person)
-
-        # Create a single CompanyType node, so we connect all companies to it.
-        companyType = CompanyType()
+def group_people_by_department(people: Iterable[Mapping[str, Any]]) -> dict:
+    """Group person names by department."""
+    groups = defaultdict(list)
+    for person in people:
+        name = person.get("name")
+        if not name:
+            continue
+        dept = person.get("department", "Unknown")
+        groups[dept].append(name)
+    return groups
-        for company in companies:
-            new_company = Company(name=company["name"], departments=[], is_type=companyType)
-            companies_data_points[company["name"]] = new_company
-
-            for department_name in company["departments"]:
-                if department_name not in departments_data_points:
-                    departments_data_points[department_name] = Department(
-                        name=department_name, employees=[]
-                    )
-                new_company.departments.append(departments_data_points[department_name])
-
-    return list(companies_data_points.values())
+def collect_declared_departments(
+    groups: Mapping[str, list[str]], companies: Iterable[Mapping[str, Any]]
+) -> set:
+    """Collect department names referenced anywhere."""
+    names = set(groups)
+    for company in companies:
+        for dept in company.get("departments", []):
+            names.add(dept)
+    return names
-async def main():
-    cognee_directory_path = str(
-        pathlib.Path(os.path.join(pathlib.Path(__file__).parent, ".cognee_system")).resolve()
-    )
-    # Set up the Cognee system directory. Cognee will store system files and databases here.
-    config.system_root_directory(cognee_directory_path)
-
-    # Prune system metadata before running, only if we want "fresh" state.
+def build_department_nodes(dept_names: Iterable[str]) -> dict:
+    """Build department nodes keyed by name."""
+    nodes = {name: Department(name=name, employees=[]) for name in dept_names}
+    return nodes
+
+
+def build_company_nodes(companies: Iterable[Mapping[str, Any]], company_type: CompanyType) -> dict:
+    """Build company nodes keyed by name."""
+    nodes = {
+        c["name"]: Company(name=c["name"], departments=[], is_type=company_type)
+        for c in companies
+        if c.get("name")
+    }
+    return nodes
+
+
+def iterate_company_department_pairs(companies: Iterable[Mapping[str, Any]]):
+    """Yield (company_name, department_name) pairs."""
+    for company in companies:
+        comp_name = company.get("name")
+        if not comp_name:
+            continue
+        for dept in company.get("departments", []):
+            yield comp_name, dept
+
+
+def attach_departments_to_companies(
+    companies: Iterable[Mapping[str, Any]],
+    dept_nodes: Mapping[str, Department],
+    company_nodes: Mapping[str, Company],
+) -> None:
+    """Attach department nodes to companies."""
+    for comp_name in company_nodes:
+        company_nodes[comp_name].departments = []
+    for comp_name, dept_name in iterate_company_department_pairs(companies):
+        dept = dept_nodes.get(dept_name)
+        company = company_nodes.get(comp_name)
+        if not dept or not company:
+            continue
+        company.departments.append(dept)
+
+
+def attach_employees_to_departments(
+    groups: Mapping[str, list[str]],
+    people_nodes: Mapping[str, Person],
+    dept_nodes: Mapping[str, Department],
+) -> None:
+    """Attach employees to departments."""
+    for dept in dept_nodes.values():
+        dept.employees = []
+    for dept_name, names in groups.items():
+        unique_names = remove_duplicates_preserve_order(names)
+        target = dept_nodes.get(dept_name)
+        if not target:
+            continue
+        employees = [people_nodes[n] for n in unique_names if n in people_nodes]
+        target.employees = employees
+
+
+def build_companies(payloads: Iterable[Mapping[str, Any]]) -> list[Company]:
+    """Build company nodes from payloads."""
+    people = collect_people(payloads)
+    companies = collect_companies(payloads)
+    people_nodes = build_people_nodes(people)
+    groups = group_people_by_department(people)
+    dept_names = collect_declared_departments(groups, companies)
+    dept_nodes = build_department_nodes(dept_names)
+    company_type = CompanyType()
+    company_nodes = build_company_nodes(companies, company_type)
+    attach_departments_to_companies(companies, dept_nodes, company_nodes)
+    attach_employees_to_departments(groups, people_nodes, dept_nodes)
+    result = list(company_nodes.values())
+    return result
+
+
+def load_default_payload() -> list[Mapping[str, Any]]:
+    """Load the default payload from data files."""
+    companies = load_json_file(COMPANIES_JSON)
+    people = load_json_file(PEOPLE_JSON)
+    payload = [{"companies": companies, "people": people}]
+    return payload
+
+
+def ingest_payloads(data: List[Any] | None) -> list[Company]:
+    """Ingest payloads and build company nodes."""
+    if not data or data == [None]:
+        data = load_default_payload()
+    companies = build_companies(data)
+    return companies
+
+
+async def execute_pipeline() -> None:
+    """Execute Cognee pipeline."""
+    # Configure system paths
+    logging.info("Configuring Cognee directories at %s", COGNEE_DIR)
+    config.system_root_directory(str(COGNEE_DIR))
+    ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True)
+
+    # Reset state and initialize
     await prune.prune_system(metadata=True)
     await setup()
 
-    # Get default user
+    # Get user and dataset
     user = await get_default_user()
     datasets = await load_or_create_datasets(["demo_dataset"], [], user)
+    dataset_id = datasets[0].id
 
-    pipeline = run_tasks(
-        [
-            Task(ingest_files),
-            Task(add_data_points),
-        ],
-        datasets[0].id,
-        None,
-        user,
-        "demo_pipeline",
-    )
+    # Build and run pipeline
+    tasks = [Task(ingest_payloads), Task(add_data_points)]
+    pipeline = run_tasks(tasks, dataset_id, None, user, "demo_pipeline")
 
     async for status in pipeline:
-        print(status)
+        logging.info("Pipeline status: %s", status)
 
+    # Post-process: index graph edges and visualize
     await index_graph_edges()
+    await visualize_graph(str(GRAPH_HTML))
 
-    # Or use our simple graph preview
-    graph_file_path = str(
-        os.path.join(os.path.dirname(__file__), ".artifacts/graph_visualization.html")
-    )
-    await visualize_graph(graph_file_path)
 
-    # Completion query that uses graph data to form context.
+    # Run query against graph
     completion = await search(
         query_text="Who works for GreenFuture Solutions?",
         query_type=SearchType.GRAPH_COMPLETION,
     )
-    print("Graph completion result is:")
-    print(completion)
+    result = completion
+    logging.info("Graph completion result: %s", result)
+
+
+def configure_logging() -> None:
+    """Configure logging."""
+    logging.basicConfig(
+        level=logging.INFO,
+        format="%(asctime)s | %(levelname)s | %(message)s",
+    )
+
+
+async def main() -> None:
+    """Run main function."""
+    configure_logging()
+    try:
+        await execute_pipeline()
+    except Exception:
+        logging.exception("Run failed")
+        raise
+
+
 if __name__ == "__main__":
     asyncio.run(main())
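
For reviewers who want to sanity-check the refactored helpers without running the full pipeline, here is a minimal sketch. It assumes the new module is importable as "demo" (the file name is not shown in this commit, so it is a placeholder) and that the DataPoint subclasses can be instantiated without a configured Cognee backend. The payload shape mirrors load_default_payload().

# Hypothetical smoke test; "demo" is a placeholder module name.
from demo import build_companies

payload = [
    {
        "companies": [{"name": "GreenFuture Solutions", "departments": ["R&D"]}],
        "people": [{"name": "Ada", "department": "R&D"}],
    }
]

companies = build_companies(payload)
assert companies[0].name == "GreenFuture Solutions"
assert companies[0].departments[0].employees[0].name == "Ada"

Passing None (or [None]) to ingest_payloads instead falls back to data/companies.json and data/people.json, matching the behavior of the old ingest_files.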