improve structure, readability
This commit is contained in:
parent
ae56b09434
commit
c924846b77
1 changed files with 212 additions and 79 deletions
|
|
@ -1,8 +1,14 @@
|
|||
import os
|
||||
import json
|
||||
"""Cognee demo with simplified structure."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import pathlib
|
||||
from typing import List, Any
|
||||
import json
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Any, Iterable, List, Mapping
|
||||
|
||||
from cognee import config, prune, search, SearchType, visualize_graph
|
||||
from cognee.low_level import setup, DataPoint
|
||||
from cognee.pipelines import run_tasks, Task
|
||||
|
|
@ -13,120 +19,247 @@ from cognee.modules.data.methods import load_or_create_datasets
|
|||
|
||||
|
||||
class Person(DataPoint):
|
||||
"""Represent a person."""
|
||||
|
||||
name: str
|
||||
metadata: dict = {"index_fields": ["name"]}
|
||||
|
||||
|
||||
class Department(DataPoint):
|
||||
"""Represent a department."""
|
||||
|
||||
name: str
|
||||
employees: list[Person]
|
||||
metadata: dict = {"index_fields": ["name"]}
|
||||
|
||||
|
||||
class CompanyType(DataPoint):
|
||||
"""Represent a company type."""
|
||||
|
||||
name: str = "Company"
|
||||
|
||||
|
||||
class Company(DataPoint):
|
||||
"""Represent a company."""
|
||||
|
||||
name: str
|
||||
departments: list[Department]
|
||||
is_type: CompanyType
|
||||
metadata: dict = {"index_fields": ["name"]}
|
||||
|
||||
|
||||
def ingest_files(data: List[Any]):
|
||||
if not data or data == [None]:
|
||||
companies_file_path = os.path.join(os.path.dirname(__file__), "../data/companies.json")
|
||||
companies = json.loads(open(companies_file_path, "r").read())
|
||||
ROOT = Path(__file__).resolve().parent
|
||||
DATA_DIR = ROOT.parent / "data"
|
||||
COGNEE_DIR = ROOT / ".cognee_system"
|
||||
ARTIFACTS_DIR = ROOT / ".artifacts"
|
||||
GRAPH_HTML = ARTIFACTS_DIR / "graph_visualization.html"
|
||||
COMPANIES_JSON = DATA_DIR / "companies.json"
|
||||
PEOPLE_JSON = DATA_DIR / "people.json"
|
||||
|
||||
people_file_path = os.path.join(os.path.dirname(__file__), "../data/people.json")
|
||||
people = json.loads(open(people_file_path, "r").read())
|
||||
|
||||
data = [{"companies": companies, "people": people}]
|
||||
def load_json_file(path: Path) -> Any:
|
||||
"""Load a JSON file."""
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"Missing required file: {path}")
|
||||
return json.loads(path.read_text(encoding="utf-8"))
|
||||
|
||||
people_data_points = {}
|
||||
departments_data_points = {}
|
||||
companies_data_points = {}
|
||||
|
||||
for data_item in data:
|
||||
people = data_item["people"]
|
||||
companies = data_item["companies"]
|
||||
def remove_duplicates_preserve_order(seq: Iterable[Any]) -> list[Any]:
|
||||
"""Return list with duplicates removed while preserving order."""
|
||||
seen = set()
|
||||
out = []
|
||||
for x in seq:
|
||||
if x in seen:
|
||||
continue
|
||||
seen.add(x)
|
||||
out.append(x)
|
||||
return out
|
||||
|
||||
|
||||
def collect_people(payloads: Iterable[Mapping[str, Any]]) -> list[Mapping[str, Any]]:
|
||||
"""Collect people from payloads."""
|
||||
people = [person for payload in payloads for person in payload.get("people", [])]
|
||||
return people
|
||||
|
||||
|
||||
def collect_companies(payloads: Iterable[Mapping[str, Any]]) -> list[Mapping[str, Any]]:
|
||||
"""Collect companies from payloads."""
|
||||
companies = [company for payload in payloads for company in payload.get("companies", [])]
|
||||
return companies
|
||||
|
||||
|
||||
def build_people_nodes(people: Iterable[Mapping[str, Any]]) -> dict:
|
||||
"""Build person nodes keyed by name."""
|
||||
nodes = {p["name"]: Person(name=p["name"]) for p in people if p.get("name")}
|
||||
return nodes
|
||||
|
||||
|
||||
def group_people_by_department(people: Iterable[Mapping[str, Any]]) -> dict:
|
||||
"""Group person names by department."""
|
||||
groups = defaultdict(list)
|
||||
for person in people:
|
||||
new_person = Person(name=person["name"])
|
||||
people_data_points[person["name"]] = new_person
|
||||
name = person.get("name")
|
||||
if not name:
|
||||
continue
|
||||
dept = person.get("department", "Unknown")
|
||||
groups[dept].append(name)
|
||||
return groups
|
||||
|
||||
if person["department"] not in departments_data_points:
|
||||
departments_data_points[person["department"]] = Department(
|
||||
name=person["department"], employees=[new_person]
|
||||
)
|
||||
else:
|
||||
departments_data_points[person["department"]].employees.append(new_person)
|
||||
|
||||
# Create a single CompanyType node, so we connect all companies to it.
|
||||
companyType = CompanyType()
|
||||
|
||||
def collect_declared_departments(
|
||||
groups: Mapping[str, list[str]], companies: Iterable[Mapping[str, Any]]
|
||||
) -> set:
|
||||
"""Collect department names referenced anywhere."""
|
||||
names = set(groups)
|
||||
for company in companies:
|
||||
new_company = Company(name=company["name"], departments=[], is_type=companyType)
|
||||
companies_data_points[company["name"]] = new_company
|
||||
|
||||
for department_name in company["departments"]:
|
||||
if department_name not in departments_data_points:
|
||||
departments_data_points[department_name] = Department(
|
||||
name=department_name, employees=[]
|
||||
)
|
||||
|
||||
new_company.departments.append(departments_data_points[department_name])
|
||||
|
||||
return list(companies_data_points.values())
|
||||
for dept in company.get("departments", []):
|
||||
names.add(dept)
|
||||
return names
|
||||
|
||||
|
||||
async def main():
|
||||
cognee_directory_path = str(
|
||||
pathlib.Path(os.path.join(pathlib.Path(__file__).parent, ".cognee_system")).resolve()
|
||||
)
|
||||
# Set up the Cognee system directory. Cognee will store system files and databases here.
|
||||
config.system_root_directory(cognee_directory_path)
|
||||
def build_department_nodes(dept_names: Iterable[str]) -> dict:
|
||||
"""Build department nodes keyed by name."""
|
||||
nodes = {name: Department(name=name, employees=[]) for name in dept_names}
|
||||
return nodes
|
||||
|
||||
# Prune system metadata before running, only if we want "fresh" state.
|
||||
|
||||
def build_company_nodes(companies: Iterable[Mapping[str, Any]], company_type: CompanyType) -> dict:
|
||||
"""Build company nodes keyed by name."""
|
||||
nodes = {
|
||||
c["name"]: Company(name=c["name"], departments=[], is_type=company_type)
|
||||
for c in companies
|
||||
if c.get("name")
|
||||
}
|
||||
return nodes
|
||||
|
||||
|
||||
def iterate_company_department_pairs(companies: Iterable[Mapping[str, Any]]):
|
||||
"""Yield (company_name, department_name) pairs."""
|
||||
for company in companies:
|
||||
comp_name = company.get("name")
|
||||
if not comp_name:
|
||||
continue
|
||||
for dept in company.get("departments", []):
|
||||
yield comp_name, dept
|
||||
|
||||
|
||||
def attach_departments_to_companies(
|
||||
companies: Iterable[Mapping[str, Any]],
|
||||
dept_nodes: Mapping[str, Department],
|
||||
company_nodes: Mapping[str, Company],
|
||||
) -> None:
|
||||
"""Attach department nodes to companies."""
|
||||
for comp_name in company_nodes:
|
||||
company_nodes[comp_name].departments = []
|
||||
for comp_name, dept_name in iterate_company_department_pairs(companies):
|
||||
dept = dept_nodes.get(dept_name)
|
||||
company = company_nodes.get(comp_name)
|
||||
if not dept or not company:
|
||||
continue
|
||||
company.departments.append(dept)
|
||||
|
||||
|
||||
def attach_employees_to_departments(
|
||||
groups: Mapping[str, list[str]],
|
||||
people_nodes: Mapping[str, Person],
|
||||
dept_nodes: Mapping[str, Department],
|
||||
) -> None:
|
||||
"""Attach employees to departments."""
|
||||
for dept in dept_nodes.values():
|
||||
dept.employees = []
|
||||
for dept_name, names in groups.items():
|
||||
unique_names = remove_duplicates_preserve_order(names)
|
||||
target = dept_nodes.get(dept_name)
|
||||
if not target:
|
||||
continue
|
||||
employees = [people_nodes[n] for n in unique_names if n in people_nodes]
|
||||
target.employees = employees
|
||||
|
||||
|
||||
def build_companies(payloads: Iterable[Mapping[str, Any]]) -> list[Company]:
|
||||
"""Build company nodes from payloads."""
|
||||
people = collect_people(payloads)
|
||||
companies = collect_companies(payloads)
|
||||
people_nodes = build_people_nodes(people)
|
||||
groups = group_people_by_department(people)
|
||||
dept_names = collect_declared_departments(groups, companies)
|
||||
dept_nodes = build_department_nodes(dept_names)
|
||||
company_type = CompanyType()
|
||||
company_nodes = build_company_nodes(companies, company_type)
|
||||
attach_departments_to_companies(companies, dept_nodes, company_nodes)
|
||||
attach_employees_to_departments(groups, people_nodes, dept_nodes)
|
||||
result = list(company_nodes.values())
|
||||
return result
|
||||
|
||||
|
||||
def load_default_payload() -> list[Mapping[str, Any]]:
|
||||
"""Load the default payload from data files."""
|
||||
companies = load_json_file(COMPANIES_JSON)
|
||||
people = load_json_file(PEOPLE_JSON)
|
||||
payload = [{"companies": companies, "people": people}]
|
||||
return payload
|
||||
|
||||
|
||||
def ingest_payloads(data: List[Any] | None) -> list[Company]:
|
||||
"""Ingest payloads and build company nodes."""
|
||||
if not data or data == [None]:
|
||||
data = load_default_payload()
|
||||
companies = build_companies(data)
|
||||
return companies
|
||||
|
||||
|
||||
async def execute_pipeline() -> None:
|
||||
"""Execute Cognee pipeline."""
|
||||
|
||||
# Configure system paths
|
||||
logging.info("Configuring Cognee directories at %s", COGNEE_DIR)
|
||||
config.system_root_directory(str(COGNEE_DIR))
|
||||
ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Reset state and initialize
|
||||
await prune.prune_system(metadata=True)
|
||||
|
||||
await setup()
|
||||
|
||||
# Get default user
|
||||
# Get user and dataset
|
||||
user = await get_default_user()
|
||||
|
||||
datasets = await load_or_create_datasets(["demo_dataset"], [], user)
|
||||
dataset_id = datasets[0].id
|
||||
|
||||
pipeline = run_tasks(
|
||||
[
|
||||
Task(ingest_files),
|
||||
Task(add_data_points),
|
||||
],
|
||||
datasets[0].id,
|
||||
None,
|
||||
user,
|
||||
"demo_pipeline",
|
||||
)
|
||||
|
||||
# Build and run pipeline
|
||||
tasks = [Task(ingest_payloads), Task(add_data_points)]
|
||||
pipeline = run_tasks(tasks, dataset_id, None, user, "demo_pipeline")
|
||||
async for status in pipeline:
|
||||
print(status)
|
||||
logging.info("Pipeline status: %s", status)
|
||||
|
||||
# Post-process: index graph edges and visualize
|
||||
await index_graph_edges()
|
||||
await visualize_graph(str(GRAPH_HTML))
|
||||
|
||||
# Or use our simple graph preview
|
||||
graph_file_path = str(
|
||||
os.path.join(os.path.dirname(__file__), ".artifacts/graph_visualization.html")
|
||||
)
|
||||
await visualize_graph(graph_file_path)
|
||||
|
||||
# Completion query that uses graph data to form context.
|
||||
# Run query against graph
|
||||
completion = await search(
|
||||
query_text="Who works for GreenFuture Solutions?",
|
||||
query_type=SearchType.GRAPH_COMPLETION,
|
||||
)
|
||||
print("Graph completion result is:")
|
||||
print(completion)
|
||||
result = completion
|
||||
logging.info("Graph completion result: %s", result)
|
||||
|
||||
|
||||
def configure_logging() -> None:
|
||||
"""Configure logging."""
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s | %(levelname)s | %(message)s",
|
||||
)
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
"""Run main function."""
|
||||
configure_logging()
|
||||
try:
|
||||
await execute_pipeline()
|
||||
except Exception:
|
||||
logging.exception("Run failed")
|
||||
raise
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue