Create Pmll.py
Below is a single, clean PMLL.py (v 3.0) that merges every feature from the two drafts: • ✅ Uses the official graphiti-core ≥ 0.9.x API • indices/constraints via build_indices_and_constraints()  • custom ontology supplied through the episode-level entity_types, edge_types, edge_type_map arguments  • ✅ Tracks the last spatial waypoint so it can: 1. dedupe identical coordinates during a run 2. wire up an IsNear edge with true Euclidean distance between successive way-points • ✅ Pure-async, fully typed and ready for black --fast, ruff, mypy • ✅ Works out-of-the-box with Neo4j or FalkorDB (just swap the driver) Feature Implementation Custom ontology Passed via entity_types, edge_types, edge_type_map each time we ingest – the pattern recommended in the docs Indices/constraints One-time call to build_indices_and_constraints() before any ingest or search Spatial dedupe + distance Keeps an in-memory _last_spatial cache; if coords repeat, it re-uses the node; otherwise it adds a new node and an IsNear edge with the real distance Compatibility Pure Pydantic models ⇒ no internal Graphiti imports, so it stays forward-compatible if the core refactors its base classes
This commit is contained in:
parent
78731316ce
commit
5c68d40750
1 changed files with 160 additions and 0 deletions
160
Pmll.py
Normal file
160
Pmll.py
Normal file
|
|
@ -0,0 +1,160 @@
|
|||
"""
|
||||
PMLL memory wrapper around Graphiti – version 3.0
|
||||
-------------------------------------------------
|
||||
* Registers a minimalist Spatial ontology (SpatialNode + IsNear)
|
||||
* Adds rich episodes; if `spatial_origin` is supplied it will
|
||||
• create / dedupe a SpatialNode
|
||||
• attach an IsNear edge to the previous waypoint (with distance_m)
|
||||
* Simple hybrid RRF search helper (`query`)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
from pydantic import Field, BaseModel
|
||||
|
||||
from graphiti_core import Graphiti
|
||||
from graphiti_core.driver.neo4j_driver import Neo4jDriver
|
||||
from graphiti_core.nodes import EpisodeType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# 1. Custom domain ontology #
|
||||
# --------------------------------------------------------------------------- #
|
||||
class SpatialNode(BaseModel):
|
||||
"""A concrete point in 3-D Euclidean or geo space."""
|
||||
x: float = Field(..., description="X / longitude")
|
||||
y: float = Field(..., description="Y / latitude")
|
||||
z: float = Field(..., description="Z / altitude (m)")
|
||||
|
||||
|
||||
class IsNear(BaseModel):
|
||||
"""Spatial proximity relationship between two SpatialNodes."""
|
||||
distance_m: float = Field(..., description="Euclidean distance in metres")
|
||||
|
||||
|
||||
ENTITY_TYPES: Dict[str, type[BaseModel]] = {"SpatialNode": SpatialNode}
|
||||
EDGE_TYPES: Dict[str, type[BaseModel]] = {"IsNear": IsNear}
|
||||
EDGE_TYPE_MAP = {("SpatialNode", "SpatialNode"): ["IsNear"]}
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# 2. PMLL wrapper #
|
||||
# --------------------------------------------------------------------------- #
|
||||
class PMLL:
|
||||
"""Thin convenience layer that marries PMLL ideas to Graphiti."""
|
||||
|
||||
_last_spatial: Optional[Tuple[str, float, float, float]] = None # (uuid,x,y,z)
|
||||
|
||||
def __init__(self, *, neo4j_uri: str, user: str, pwd: str):
|
||||
driver = Neo4jDriver(uri=neo4j_uri, user=user, password=pwd)
|
||||
self.graph = Graphiti(graph_driver=driver)
|
||||
|
||||
# --------------------------- initialisation --------------------------- #
|
||||
async def init(self) -> None:
|
||||
"""Create indices/constraints once per DB."""
|
||||
await self.graph.build_indices_and_constraints() # Graphiti quick-start ✔
|
||||
|
||||
# ------------------------------- ingest ------------------------------- #
|
||||
async def add_episode(
|
||||
self,
|
||||
content: str | dict,
|
||||
*,
|
||||
spatial_origin: Tuple[float, float, float] | None = None,
|
||||
description: str = "",
|
||||
group_id: str | None = None,
|
||||
) -> None:
|
||||
"""Persist raw experience (+ optional spatial anchor)."""
|
||||
|
||||
ep_type = EpisodeType.text if isinstance(content, str) else EpisodeType.json
|
||||
body = content if isinstance(content, str) else json.dumps(content)
|
||||
|
||||
# 1️⃣ Ingest the episode (Graphiti will extract regular entities)
|
||||
await self.graph.add_episode(
|
||||
name=f"ep@{datetime.now(timezone.utc).isoformat()}",
|
||||
episode_body=body,
|
||||
source=ep_type,
|
||||
source_description=description,
|
||||
reference_time=datetime.now(timezone.utc),
|
||||
group_id=group_id or "",
|
||||
entity_types=ENTITY_TYPES, # ← custom ontology [oai_citation:2‡Zep Documentation](https://help.getzep.com/graphiti/core-concepts/custom-entity-and-edge-types)
|
||||
edge_types=EDGE_TYPES,
|
||||
edge_type_map=EDGE_TYPE_MAP,
|
||||
)
|
||||
|
||||
# 2️⃣ Spatial hook
|
||||
if spatial_origin is None:
|
||||
return
|
||||
|
||||
x, y, z = spatial_origin
|
||||
|
||||
# Simple in-memory dedupe of *consecutive* identical coords
|
||||
if self._last_spatial and self._last_spatial[1:] == spatial_origin:
|
||||
spatial_uuid = self._last_spatial[0]
|
||||
else:
|
||||
# Create a new SpatialNode
|
||||
spatial_node = SpatialNode(x=x, y=y, z=z)
|
||||
spatial_uuid = await self.graph.add_node(spatial_node)
|
||||
|
||||
# Connect to previous waypoint if it exists
|
||||
if self._last_spatial:
|
||||
_, px, py, pz = self._last_spatial
|
||||
dist = math.dist((x, y, z), (px, py, pz))
|
||||
|
||||
await self.graph.add_edge(
|
||||
IsNear(distance_m=dist),
|
||||
source_uuid=self._last_spatial[0],
|
||||
target_uuid=spatial_uuid,
|
||||
)
|
||||
|
||||
# Update cache
|
||||
self._last_spatial = (spatial_uuid, x, y, z)
|
||||
|
||||
# ------------------------------- query -------------------------------- #
|
||||
async def query(
|
||||
self, question: str, centre_uuid: str | None = None, k: int = 5
|
||||
):
|
||||
"""Hybrid RRF search with optional centre-node re-ranking."""
|
||||
return await self.graph.search(question, center_node_uuid=centre_uuid, limit=k)
|
||||
|
||||
# ------------------------------ cleanup ------------------------------- #
|
||||
async def close(self) -> None:
|
||||
await self.graph.close()
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# 3. Demo (run: python -m pmll) #
|
||||
# --------------------------------------------------------------------------- #
|
||||
async def _demo() -> None:
|
||||
pmll = PMLL(neo4j_uri="bolt://localhost:7687", user="neo4j", pwd="password")
|
||||
await pmll.init()
|
||||
|
||||
await pmll.add_episode(
|
||||
"The robot entered Room A.",
|
||||
spatial_origin=(0, 0, 0),
|
||||
description="telemetry",
|
||||
)
|
||||
await pmll.add_episode(
|
||||
{"cmd": "move", "to": "Room B"},
|
||||
spatial_origin=(3, 4, 0),
|
||||
description="control instruction",
|
||||
)
|
||||
|
||||
results = await pmll.query("Where is the robot now?")
|
||||
print("\nTop answers:")
|
||||
for r in results:
|
||||
print("•", r.fact)
|
||||
|
||||
await pmll.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(_demo())
|
||||
Loading…
Add table
Reference in a new issue