Update Pmll.py

Resolving review comments from dev
This commit is contained in:
Dr. Q and Company 2025-07-29 17:52:15 -04:00 committed by GitHub
parent 5c68d40750
commit b4444c97e4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

80
Pmll.py
View file

@ -1,11 +1,23 @@
# Copyright 2025, Zep Software, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
PMLL memory wrapper around Graphiti version 3.0
PMLL memory wrapper around Graphiti version 3.2
-------------------------------------------------
* Registers a minimalist Spatial ontology (SpatialNode + IsNear)
* Adds rich episodes; if `spatial_origin` is supplied it will
create / dedupe a SpatialNode
attach an IsNear edge to the previous waypoint (with distance_m)
* Simple hybrid RRF search helper (`query`)
* Registers SpatialNode / IsNear ontology
* Adds episodes with optional spatial anchors & distance-chaining
* Hybrid RRF search helper
"""
from __future__ import annotations
@ -14,11 +26,10 @@ import asyncio
import json
import logging
import math
import uuid
from datetime import datetime, timezone
from typing import Dict, Optional, Tuple
from pydantic import Field, BaseModel
from pydantic import BaseModel, Field
from graphiti_core import Graphiti
from graphiti_core.driver.neo4j_driver import Neo4jDriver
@ -28,17 +39,17 @@ logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# --------------------------------------------------------------------------- #
# 1. Custom domain ontology #
# 1. Custom ontology #
# --------------------------------------------------------------------------- #
class SpatialNode(BaseModel):
"""A concrete point in 3-D Euclidean or geo space."""
"""Cartesian/geo point in 3-D space."""
x: float = Field(..., description="X / longitude")
y: float = Field(..., description="Y / latitude")
z: float = Field(..., description="Z / altitude (m)")
class IsNear(BaseModel):
"""Spatial proximity relationship between two SpatialNodes."""
"""Proximity relation between two SpatialNodes."""
distance_m: float = Field(..., description="Euclidean distance in metres")
@ -52,16 +63,15 @@ EDGE_TYPE_MAP = {("SpatialNode", "SpatialNode"): ["IsNear"]}
class PMLL:
"""Thin convenience layer that marries PMLL ideas to Graphiti."""
_last_spatial: Optional[Tuple[str, float, float, float]] = None # (uuid,x,y,z)
def __init__(self, *, neo4j_uri: str, user: str, pwd: str):
driver = Neo4jDriver(uri=neo4j_uri, user=user, password=pwd)
self.graph = Graphiti(graph_driver=driver)
self._last_spatial: Optional[Tuple[str, float, float, float]] = None # uuid,x,y,z
# --------------------------- initialisation --------------------------- #
async def init(self) -> None:
"""Create indices/constraints once per DB."""
await self.graph.build_indices_and_constraints() # Graphiti quick-start ✔
await self.graph.build_indices_and_constraints()
# ------------------------------- ingest ------------------------------- #
async def add_episode(
@ -73,11 +83,9 @@ class PMLL:
group_id: str | None = None,
) -> None:
"""Persist raw experience (+ optional spatial anchor)."""
ep_type = EpisodeType.text if isinstance(content, str) else EpisodeType.json
body = content if isinstance(content, str) else json.dumps(content)
# 1⃣ Ingest the episode (Graphiti will extract regular entities)
await self.graph.add_episode(
name=f"ep@{datetime.now(timezone.utc).isoformat()}",
episode_body=body,
@ -85,45 +93,39 @@ class PMLL:
source_description=description,
reference_time=datetime.now(timezone.utc),
group_id=group_id or "",
entity_types=ENTITY_TYPES, # ← custom ontology [oai_citation:2‡Zep Documentation](https://help.getzep.com/graphiti/core-concepts/custom-entity-and-edge-types)
entity_types=ENTITY_TYPES,
edge_types=EDGE_TYPES,
edge_type_map=EDGE_TYPE_MAP,
)
# 2⃣ Spatial hook
if spatial_origin is None:
return
x, y, z = spatial_origin
# Simple in-memory dedupe of *consecutive* identical coords
# Re-use node if identical to previous coords
if self._last_spatial and self._last_spatial[1:] == spatial_origin:
spatial_uuid = self._last_spatial[0]
else:
# Create a new SpatialNode
spatial_node = SpatialNode(x=x, y=y, z=z)
spatial_uuid = await self.graph.add_node(spatial_node)
spatial_uuid = await self.graph.add_node(SpatialNode(x=x, y=y, z=z))
# Connect to previous waypoint if it exists
# Connect to previous waypoint
if self._last_spatial:
_, px, py, pz = self._last_spatial
dist = math.dist((x, y, z), (px, py, pz))
await self.graph.add_edge(
IsNear(distance_m=dist),
IsNear(distance_m=math.dist((x, y, z), (px, py, pz))),
source_uuid=self._last_spatial[0],
target_uuid=spatial_uuid,
)
# Update cache
self._last_spatial = (spatial_uuid, x, y, z)
# ------------------------------- query -------------------------------- #
async def query(
self, question: str, centre_uuid: str | None = None, k: int = 5
self, question: str, center_uuid: str | None = None, k: int = 5
):
"""Hybrid RRF search with optional centre-node re-ranking."""
return await self.graph.search(question, center_node_uuid=centre_uuid, limit=k)
return await self.graph.search(question, center_node_uuid=center_uuid, limit=k)
# ------------------------------ cleanup ------------------------------- #
async def close(self) -> None:
@ -131,27 +133,17 @@ class PMLL:
# --------------------------------------------------------------------------- #
# 3. Demo (run: python -m pmll) #
# 3. Demo #
# --------------------------------------------------------------------------- #
async def _demo() -> None:
pmll = PMLL(neo4j_uri="bolt://localhost:7687", user="neo4j", pwd="password")
await pmll.init()
await pmll.add_episode(
"The robot entered Room A.",
spatial_origin=(0, 0, 0),
description="telemetry",
)
await pmll.add_episode(
{"cmd": "move", "to": "Room B"},
spatial_origin=(3, 4, 0),
description="control instruction",
)
await pmll.add_episode("Robot entered Room A.", spatial_origin=(0, 0, 0))
await pmll.add_episode({"cmd": "move", "to": "Room B"}, spatial_origin=(3, 4, 0))
results = await pmll.query("Where is the robot now?")
print("\nTop answers:")
for r in results:
print("", r.fact)
for hit in await pmll.query("Where is the robot now?"):
print("", hit.fact)
await pmll.close()