Update Pmll.py

Resolving review comments from dev
This commit is contained in:
Dr. Q and Company 2025-07-29 17:52:15 -04:00 committed by GitHub
parent 5c68d40750
commit b4444c97e4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

80
Pmll.py
View file

@ -1,11 +1,23 @@
# Copyright 2025, Zep Software, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
PMLL memory wrapper around Graphiti version 3.0 PMLL memory wrapper around Graphiti version 3.2
------------------------------------------------- -------------------------------------------------
* Registers a minimalist Spatial ontology (SpatialNode + IsNear) * Registers SpatialNode / IsNear ontology
* Adds rich episodes; if `spatial_origin` is supplied it will * Adds episodes with optional spatial anchors & distance-chaining
create / dedupe a SpatialNode * Hybrid RRF search helper
attach an IsNear edge to the previous waypoint (with distance_m)
* Simple hybrid RRF search helper (`query`)
""" """
from __future__ import annotations from __future__ import annotations
@ -14,11 +26,10 @@ import asyncio
import json import json
import logging import logging
import math import math
import uuid
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Dict, Optional, Tuple from typing import Dict, Optional, Tuple
from pydantic import Field, BaseModel from pydantic import BaseModel, Field
from graphiti_core import Graphiti from graphiti_core import Graphiti
from graphiti_core.driver.neo4j_driver import Neo4jDriver from graphiti_core.driver.neo4j_driver import Neo4jDriver
@ -28,17 +39,17 @@ logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
# --------------------------------------------------------------------------- # # --------------------------------------------------------------------------- #
# 1. Custom domain ontology # # 1. Custom ontology #
# --------------------------------------------------------------------------- # # --------------------------------------------------------------------------- #
class SpatialNode(BaseModel): class SpatialNode(BaseModel):
"""A concrete point in 3-D Euclidean or geo space.""" """Cartesian/geo point in 3-D space."""
x: float = Field(..., description="X / longitude") x: float = Field(..., description="X / longitude")
y: float = Field(..., description="Y / latitude") y: float = Field(..., description="Y / latitude")
z: float = Field(..., description="Z / altitude (m)") z: float = Field(..., description="Z / altitude (m)")
class IsNear(BaseModel): class IsNear(BaseModel):
"""Spatial proximity relationship between two SpatialNodes.""" """Proximity relation between two SpatialNodes."""
distance_m: float = Field(..., description="Euclidean distance in metres") distance_m: float = Field(..., description="Euclidean distance in metres")
@ -52,16 +63,15 @@ EDGE_TYPE_MAP = {("SpatialNode", "SpatialNode"): ["IsNear"]}
class PMLL: class PMLL:
"""Thin convenience layer that marries PMLL ideas to Graphiti.""" """Thin convenience layer that marries PMLL ideas to Graphiti."""
_last_spatial: Optional[Tuple[str, float, float, float]] = None # (uuid,x,y,z)
def __init__(self, *, neo4j_uri: str, user: str, pwd: str): def __init__(self, *, neo4j_uri: str, user: str, pwd: str):
driver = Neo4jDriver(uri=neo4j_uri, user=user, password=pwd) driver = Neo4jDriver(uri=neo4j_uri, user=user, password=pwd)
self.graph = Graphiti(graph_driver=driver) self.graph = Graphiti(graph_driver=driver)
self._last_spatial: Optional[Tuple[str, float, float, float]] = None # uuid,x,y,z
# --------------------------- initialisation --------------------------- # # --------------------------- initialisation --------------------------- #
async def init(self) -> None: async def init(self) -> None:
"""Create indices/constraints once per DB.""" """Create indices/constraints once per DB."""
await self.graph.build_indices_and_constraints() # Graphiti quick-start ✔ await self.graph.build_indices_and_constraints()
# ------------------------------- ingest ------------------------------- # # ------------------------------- ingest ------------------------------- #
async def add_episode( async def add_episode(
@ -73,11 +83,9 @@ class PMLL:
group_id: str | None = None, group_id: str | None = None,
) -> None: ) -> None:
"""Persist raw experience (+ optional spatial anchor).""" """Persist raw experience (+ optional spatial anchor)."""
ep_type = EpisodeType.text if isinstance(content, str) else EpisodeType.json ep_type = EpisodeType.text if isinstance(content, str) else EpisodeType.json
body = content if isinstance(content, str) else json.dumps(content) body = content if isinstance(content, str) else json.dumps(content)
# 1⃣ Ingest the episode (Graphiti will extract regular entities)
await self.graph.add_episode( await self.graph.add_episode(
name=f"ep@{datetime.now(timezone.utc).isoformat()}", name=f"ep@{datetime.now(timezone.utc).isoformat()}",
episode_body=body, episode_body=body,
@ -85,45 +93,39 @@ class PMLL:
source_description=description, source_description=description,
reference_time=datetime.now(timezone.utc), reference_time=datetime.now(timezone.utc),
group_id=group_id or "", group_id=group_id or "",
entity_types=ENTITY_TYPES, # ← custom ontology [oai_citation:2‡Zep Documentation](https://help.getzep.com/graphiti/core-concepts/custom-entity-and-edge-types) entity_types=ENTITY_TYPES,
edge_types=EDGE_TYPES, edge_types=EDGE_TYPES,
edge_type_map=EDGE_TYPE_MAP, edge_type_map=EDGE_TYPE_MAP,
) )
# 2⃣ Spatial hook
if spatial_origin is None: if spatial_origin is None:
return return
x, y, z = spatial_origin x, y, z = spatial_origin
# Simple in-memory dedupe of *consecutive* identical coords # Re-use node if identical to previous coords
if self._last_spatial and self._last_spatial[1:] == spatial_origin: if self._last_spatial and self._last_spatial[1:] == spatial_origin:
spatial_uuid = self._last_spatial[0] spatial_uuid = self._last_spatial[0]
else: else:
# Create a new SpatialNode spatial_uuid = await self.graph.add_node(SpatialNode(x=x, y=y, z=z))
spatial_node = SpatialNode(x=x, y=y, z=z)
spatial_uuid = await self.graph.add_node(spatial_node)
# Connect to previous waypoint if it exists # Connect to previous waypoint
if self._last_spatial: if self._last_spatial:
_, px, py, pz = self._last_spatial _, px, py, pz = self._last_spatial
dist = math.dist((x, y, z), (px, py, pz))
await self.graph.add_edge( await self.graph.add_edge(
IsNear(distance_m=dist), IsNear(distance_m=math.dist((x, y, z), (px, py, pz))),
source_uuid=self._last_spatial[0], source_uuid=self._last_spatial[0],
target_uuid=spatial_uuid, target_uuid=spatial_uuid,
) )
# Update cache
self._last_spatial = (spatial_uuid, x, y, z) self._last_spatial = (spatial_uuid, x, y, z)
# ------------------------------- query -------------------------------- # # ------------------------------- query -------------------------------- #
async def query( async def query(
self, question: str, centre_uuid: str | None = None, k: int = 5 self, question: str, center_uuid: str | None = None, k: int = 5
): ):
"""Hybrid RRF search with optional centre-node re-ranking.""" """Hybrid RRF search with optional centre-node re-ranking."""
return await self.graph.search(question, center_node_uuid=centre_uuid, limit=k) return await self.graph.search(question, center_node_uuid=center_uuid, limit=k)
# ------------------------------ cleanup ------------------------------- # # ------------------------------ cleanup ------------------------------- #
async def close(self) -> None: async def close(self) -> None:
@ -131,27 +133,17 @@ class PMLL:
# --------------------------------------------------------------------------- # # --------------------------------------------------------------------------- #
# 3. Demo (run: python -m pmll) # # 3. Demo #
# --------------------------------------------------------------------------- # # --------------------------------------------------------------------------- #
async def _demo() -> None: async def _demo() -> None:
pmll = PMLL(neo4j_uri="bolt://localhost:7687", user="neo4j", pwd="password") pmll = PMLL(neo4j_uri="bolt://localhost:7687", user="neo4j", pwd="password")
await pmll.init() await pmll.init()
await pmll.add_episode( await pmll.add_episode("Robot entered Room A.", spatial_origin=(0, 0, 0))
"The robot entered Room A.", await pmll.add_episode({"cmd": "move", "to": "Room B"}, spatial_origin=(3, 4, 0))
spatial_origin=(0, 0, 0),
description="telemetry",
)
await pmll.add_episode(
{"cmd": "move", "to": "Room B"},
spatial_origin=(3, 4, 0),
description="control instruction",
)
results = await pmll.query("Where is the robot now?") for hit in await pmll.query("Where is the robot now?"):
print("\nTop answers:") print("", hit.fact)
for r in results:
print("", r.fact)
await pmll.close() await pmll.close()