fix: update low level examples (#1331)

<!-- .github/pull_request_template.md -->

## Description
<!-- Provide a clear description of the changes in this PR -->

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
This commit is contained in:
Vasilije 2025-09-04 21:25:34 +02:00 committed by GitHub
commit ce130f358a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 209 additions and 68 deletions

View file

@ -1,124 +1,265 @@
import os
import uuid
import json
"""Cognee demo with simplified structure."""
from __future__ import annotations
import asyncio
import pathlib
import json
import logging
from collections import defaultdict
from pathlib import Path
from typing import Any, Iterable, List, Mapping
from cognee import config, prune, search, SearchType, visualize_graph
from cognee.low_level import setup, DataPoint
from cognee.pipelines import run_tasks, Task
from cognee.tasks.storage import add_data_points
from cognee.tasks.storage.index_graph_edges import index_graph_edges
from cognee.modules.users.methods import get_default_user
from cognee.modules.data.methods import load_or_create_datasets
class Person(DataPoint):
"""Represent a person."""
name: str
metadata: dict = {"index_fields": ["name"]}
class Department(DataPoint):
"""Represent a department."""
name: str
employees: list[Person]
metadata: dict = {"index_fields": ["name"]}
class CompanyType(DataPoint):
"""Represent a company type."""
name: str = "Company"
class Company(DataPoint):
"""Represent a company."""
name: str
departments: list[Department]
is_type: CompanyType
metadata: dict = {"index_fields": ["name"]}
def ingest_files():
companies_file_path = os.path.join(os.path.dirname(__file__), "../data/companies.json")
companies = json.loads(open(companies_file_path, "r").read())
ROOT = Path(__file__).resolve().parent
DATA_DIR = ROOT.parent / "data"
COGNEE_DIR = ROOT / ".cognee_system"
ARTIFACTS_DIR = ROOT / ".artifacts"
GRAPH_HTML = ARTIFACTS_DIR / "graph_visualization.html"
COMPANIES_JSON = DATA_DIR / "companies.json"
PEOPLE_JSON = DATA_DIR / "people.json"
people_file_path = os.path.join(os.path.dirname(__file__), "../data/people.json")
people = json.loads(open(people_file_path, "r").read())
people_data_points = {}
departments_data_points = {}
def load_json_file(path: Path) -> Any:
"""Load a JSON file."""
if not path.exists():
raise FileNotFoundError(f"Missing required file: {path}")
return json.loads(path.read_text(encoding="utf-8"))
def remove_duplicates_preserve_order(seq: Iterable[Any]) -> list[Any]:
"""Return list with duplicates removed while preserving order."""
seen = set()
out = []
for x in seq:
if x in seen:
continue
seen.add(x)
out.append(x)
return out
def collect_people(payloads: Iterable[Mapping[str, Any]]) -> list[Mapping[str, Any]]:
"""Collect people from payloads."""
people = [person for payload in payloads for person in payload.get("people", [])]
return people
def collect_companies(payloads: Iterable[Mapping[str, Any]]) -> list[Mapping[str, Any]]:
"""Collect companies from payloads."""
companies = [company for payload in payloads for company in payload.get("companies", [])]
return companies
def build_people_nodes(people: Iterable[Mapping[str, Any]]) -> dict:
"""Build person nodes keyed by name."""
nodes = {p["name"]: Person(name=p["name"]) for p in people if p.get("name")}
return nodes
def group_people_by_department(people: Iterable[Mapping[str, Any]]) -> dict:
"""Group person names by department."""
groups = defaultdict(list)
for person in people:
new_person = Person(name=person["name"])
people_data_points[person["name"]] = new_person
name = person.get("name")
if not name:
continue
dept = person.get("department", "Unknown")
groups[dept].append(name)
return groups
if person["department"] not in departments_data_points:
departments_data_points[person["department"]] = Department(
name=person["department"], employees=[new_person]
)
else:
departments_data_points[person["department"]].employees.append(new_person)
companies_data_points = {}
# Create a single CompanyType node, so we connect all companies to it.
companyType = CompanyType()
def collect_declared_departments(
groups: Mapping[str, list[str]], companies: Iterable[Mapping[str, Any]]
) -> set:
"""Collect department names referenced anywhere."""
names = set(groups)
for company in companies:
new_company = Company(name=company["name"], departments=[], is_type=companyType)
companies_data_points[company["name"]] = new_company
for department_name in company["departments"]:
if department_name not in departments_data_points:
departments_data_points[department_name] = Department(
name=department_name, employees=[]
)
new_company.departments.append(departments_data_points[department_name])
return companies_data_points.values()
for dept in company.get("departments", []):
names.add(dept)
return names
async def main():
cognee_directory_path = str(
pathlib.Path(os.path.join(pathlib.Path(__file__).parent, ".cognee_system")).resolve()
)
# Set up the Cognee system directory. Cognee will store system files and databases here.
config.system_root_directory(cognee_directory_path)
def build_department_nodes(dept_names: Iterable[str]) -> dict:
"""Build department nodes keyed by name."""
nodes = {name: Department(name=name, employees=[]) for name in dept_names}
return nodes
# Prune system metadata before running, only if we want "fresh" state.
def build_company_nodes(companies: Iterable[Mapping[str, Any]], company_type: CompanyType) -> dict:
"""Build company nodes keyed by name."""
nodes = {
c["name"]: Company(name=c["name"], departments=[], is_type=company_type)
for c in companies
if c.get("name")
}
return nodes
def iterate_company_department_pairs(companies: Iterable[Mapping[str, Any]]):
"""Yield (company_name, department_name) pairs."""
for company in companies:
comp_name = company.get("name")
if not comp_name:
continue
for dept in company.get("departments", []):
yield comp_name, dept
def attach_departments_to_companies(
companies: Iterable[Mapping[str, Any]],
dept_nodes: Mapping[str, Department],
company_nodes: Mapping[str, Company],
) -> None:
"""Attach department nodes to companies."""
for comp_name in company_nodes:
company_nodes[comp_name].departments = []
for comp_name, dept_name in iterate_company_department_pairs(companies):
dept = dept_nodes.get(dept_name)
company = company_nodes.get(comp_name)
if not dept or not company:
continue
company.departments.append(dept)
def attach_employees_to_departments(
groups: Mapping[str, list[str]],
people_nodes: Mapping[str, Person],
dept_nodes: Mapping[str, Department],
) -> None:
"""Attach employees to departments."""
for dept in dept_nodes.values():
dept.employees = []
for dept_name, names in groups.items():
unique_names = remove_duplicates_preserve_order(names)
target = dept_nodes.get(dept_name)
if not target:
continue
employees = [people_nodes[n] for n in unique_names if n in people_nodes]
target.employees = employees
def build_companies(payloads: Iterable[Mapping[str, Any]]) -> list[Company]:
"""Build company nodes from payloads."""
people = collect_people(payloads)
companies = collect_companies(payloads)
people_nodes = build_people_nodes(people)
groups = group_people_by_department(people)
dept_names = collect_declared_departments(groups, companies)
dept_nodes = build_department_nodes(dept_names)
company_type = CompanyType()
company_nodes = build_company_nodes(companies, company_type)
attach_departments_to_companies(companies, dept_nodes, company_nodes)
attach_employees_to_departments(groups, people_nodes, dept_nodes)
result = list(company_nodes.values())
return result
def load_default_payload() -> list[Mapping[str, Any]]:
"""Load the default payload from data files."""
companies = load_json_file(COMPANIES_JSON)
people = load_json_file(PEOPLE_JSON)
payload = [{"companies": companies, "people": people}]
return payload
def ingest_payloads(data: List[Any] | None) -> list[Company]:
"""Ingest payloads and build company nodes."""
if not data or data == [None]:
data = load_default_payload()
companies = build_companies(data)
return companies
async def execute_pipeline() -> None:
"""Execute Cognee pipeline."""
# Configure system paths
logging.info("Configuring Cognee directories at %s", COGNEE_DIR)
config.system_root_directory(str(COGNEE_DIR))
ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True)
# Reset state and initialize
await prune.prune_system(metadata=True)
await setup()
# Generate a random dataset_id
dataset_id = uuid.uuid4()
# Get user and dataset
user = await get_default_user()
datasets = await load_or_create_datasets(["demo_dataset"], [], user)
dataset_id = datasets[0].id
pipeline = run_tasks(
[
Task(ingest_files),
Task(add_data_points),
],
dataset_id,
None,
user,
"demo_pipeline",
)
# Build and run pipeline
tasks = [Task(ingest_payloads), Task(add_data_points)]
pipeline = run_tasks(tasks, dataset_id, None, user, "demo_pipeline")
async for status in pipeline:
print(status)
logging.info("Pipeline status: %s", status)
# Post-process: index graph edges and visualize
await index_graph_edges()
await visualize_graph(str(GRAPH_HTML))
# Or use our simple graph preview
graph_file_path = str(
os.path.join(os.path.dirname(__file__), ".artifacts/graph_visualization.html")
)
await visualize_graph(graph_file_path)
# Completion query that uses graph data to form context.
# Run query against graph
completion = await search(
query_text="Who works for GreenFuture Solutions?",
query_type=SearchType.GRAPH_COMPLETION,
)
print("Graph completion result is:")
print(completion)
result = completion
logging.info("Graph completion result: %s", result)
def configure_logging() -> None:
"""Configure logging."""
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s | %(levelname)s | %(message)s",
)
async def main() -> None:
"""Run main function."""
configure_logging()
try:
await execute_pipeline()
except Exception:
logging.exception("Run failed")
raise
if __name__ == "__main__":

View file

@ -73,7 +73,7 @@ def ingest_files(data: List[Any]):
new_company.departments.append(departments_data_points[department_name])
return companies_data_points.values()
return list(companies_data_points.values())
async def main():