Update Pmll.py
Resolving review comments from dev
This commit is contained in:
parent
5c68d40750
commit
b4444c97e4
1 changed files with 36 additions and 44 deletions
80
Pmll.py
80
Pmll.py
|
|
@ -1,11 +1,23 @@
|
||||||
|
# Copyright 2025, Zep Software, Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
PMLL memory wrapper around Graphiti – version 3.0
|
PMLL memory wrapper around Graphiti – version 3.2
|
||||||
-------------------------------------------------
|
-------------------------------------------------
|
||||||
* Registers a minimalist Spatial ontology (SpatialNode + IsNear)
|
* Registers SpatialNode / IsNear ontology
|
||||||
* Adds rich episodes; if `spatial_origin` is supplied it will
|
* Adds episodes with optional spatial anchors & distance-chaining
|
||||||
• create / dedupe a SpatialNode
|
* Hybrid RRF search helper
|
||||||
• attach an IsNear edge to the previous waypoint (with distance_m)
|
|
||||||
* Simple hybrid RRF search helper (`query`)
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
@ -14,11 +26,10 @@ import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import uuid
|
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Dict, Optional, Tuple
|
from typing import Dict, Optional, Tuple
|
||||||
|
|
||||||
from pydantic import Field, BaseModel
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from graphiti_core import Graphiti
|
from graphiti_core import Graphiti
|
||||||
from graphiti_core.driver.neo4j_driver import Neo4jDriver
|
from graphiti_core.driver.neo4j_driver import Neo4jDriver
|
||||||
|
|
@ -28,17 +39,17 @@ logger = logging.getLogger(__name__)
|
||||||
logger.setLevel(logging.INFO)
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
# --------------------------------------------------------------------------- #
|
# --------------------------------------------------------------------------- #
|
||||||
# 1. Custom domain ontology #
|
# 1. Custom ontology #
|
||||||
# --------------------------------------------------------------------------- #
|
# --------------------------------------------------------------------------- #
|
||||||
class SpatialNode(BaseModel):
|
class SpatialNode(BaseModel):
|
||||||
"""A concrete point in 3-D Euclidean or geo space."""
|
"""Cartesian/geo point in 3-D space."""
|
||||||
x: float = Field(..., description="X / longitude")
|
x: float = Field(..., description="X / longitude")
|
||||||
y: float = Field(..., description="Y / latitude")
|
y: float = Field(..., description="Y / latitude")
|
||||||
z: float = Field(..., description="Z / altitude (m)")
|
z: float = Field(..., description="Z / altitude (m)")
|
||||||
|
|
||||||
|
|
||||||
class IsNear(BaseModel):
|
class IsNear(BaseModel):
|
||||||
"""Spatial proximity relationship between two SpatialNodes."""
|
"""Proximity relation between two SpatialNodes."""
|
||||||
distance_m: float = Field(..., description="Euclidean distance in metres")
|
distance_m: float = Field(..., description="Euclidean distance in metres")
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -52,16 +63,15 @@ EDGE_TYPE_MAP = {("SpatialNode", "SpatialNode"): ["IsNear"]}
|
||||||
class PMLL:
|
class PMLL:
|
||||||
"""Thin convenience layer that marries PMLL ideas to Graphiti."""
|
"""Thin convenience layer that marries PMLL ideas to Graphiti."""
|
||||||
|
|
||||||
_last_spatial: Optional[Tuple[str, float, float, float]] = None # (uuid,x,y,z)
|
|
||||||
|
|
||||||
def __init__(self, *, neo4j_uri: str, user: str, pwd: str):
|
def __init__(self, *, neo4j_uri: str, user: str, pwd: str):
|
||||||
driver = Neo4jDriver(uri=neo4j_uri, user=user, password=pwd)
|
driver = Neo4jDriver(uri=neo4j_uri, user=user, password=pwd)
|
||||||
self.graph = Graphiti(graph_driver=driver)
|
self.graph = Graphiti(graph_driver=driver)
|
||||||
|
self._last_spatial: Optional[Tuple[str, float, float, float]] = None # uuid,x,y,z
|
||||||
|
|
||||||
# --------------------------- initialisation --------------------------- #
|
# --------------------------- initialisation --------------------------- #
|
||||||
async def init(self) -> None:
|
async def init(self) -> None:
|
||||||
"""Create indices/constraints once per DB."""
|
"""Create indices/constraints once per DB."""
|
||||||
await self.graph.build_indices_and_constraints() # Graphiti quick-start ✔
|
await self.graph.build_indices_and_constraints()
|
||||||
|
|
||||||
# ------------------------------- ingest ------------------------------- #
|
# ------------------------------- ingest ------------------------------- #
|
||||||
async def add_episode(
|
async def add_episode(
|
||||||
|
|
@ -73,11 +83,9 @@ class PMLL:
|
||||||
group_id: str | None = None,
|
group_id: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Persist raw experience (+ optional spatial anchor)."""
|
"""Persist raw experience (+ optional spatial anchor)."""
|
||||||
|
|
||||||
ep_type = EpisodeType.text if isinstance(content, str) else EpisodeType.json
|
ep_type = EpisodeType.text if isinstance(content, str) else EpisodeType.json
|
||||||
body = content if isinstance(content, str) else json.dumps(content)
|
body = content if isinstance(content, str) else json.dumps(content)
|
||||||
|
|
||||||
# 1️⃣ Ingest the episode (Graphiti will extract regular entities)
|
|
||||||
await self.graph.add_episode(
|
await self.graph.add_episode(
|
||||||
name=f"ep@{datetime.now(timezone.utc).isoformat()}",
|
name=f"ep@{datetime.now(timezone.utc).isoformat()}",
|
||||||
episode_body=body,
|
episode_body=body,
|
||||||
|
|
@ -85,45 +93,39 @@ class PMLL:
|
||||||
source_description=description,
|
source_description=description,
|
||||||
reference_time=datetime.now(timezone.utc),
|
reference_time=datetime.now(timezone.utc),
|
||||||
group_id=group_id or "",
|
group_id=group_id or "",
|
||||||
entity_types=ENTITY_TYPES, # ← custom ontology [oai_citation:2‡Zep Documentation](https://help.getzep.com/graphiti/core-concepts/custom-entity-and-edge-types)
|
entity_types=ENTITY_TYPES,
|
||||||
edge_types=EDGE_TYPES,
|
edge_types=EDGE_TYPES,
|
||||||
edge_type_map=EDGE_TYPE_MAP,
|
edge_type_map=EDGE_TYPE_MAP,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 2️⃣ Spatial hook
|
|
||||||
if spatial_origin is None:
|
if spatial_origin is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
x, y, z = spatial_origin
|
x, y, z = spatial_origin
|
||||||
|
|
||||||
# Simple in-memory dedupe of *consecutive* identical coords
|
# Re-use node if identical to previous coords
|
||||||
if self._last_spatial and self._last_spatial[1:] == spatial_origin:
|
if self._last_spatial and self._last_spatial[1:] == spatial_origin:
|
||||||
spatial_uuid = self._last_spatial[0]
|
spatial_uuid = self._last_spatial[0]
|
||||||
else:
|
else:
|
||||||
# Create a new SpatialNode
|
spatial_uuid = await self.graph.add_node(SpatialNode(x=x, y=y, z=z))
|
||||||
spatial_node = SpatialNode(x=x, y=y, z=z)
|
|
||||||
spatial_uuid = await self.graph.add_node(spatial_node)
|
|
||||||
|
|
||||||
# Connect to previous waypoint if it exists
|
# Connect to previous waypoint
|
||||||
if self._last_spatial:
|
if self._last_spatial:
|
||||||
_, px, py, pz = self._last_spatial
|
_, px, py, pz = self._last_spatial
|
||||||
dist = math.dist((x, y, z), (px, py, pz))
|
|
||||||
|
|
||||||
await self.graph.add_edge(
|
await self.graph.add_edge(
|
||||||
IsNear(distance_m=dist),
|
IsNear(distance_m=math.dist((x, y, z), (px, py, pz))),
|
||||||
source_uuid=self._last_spatial[0],
|
source_uuid=self._last_spatial[0],
|
||||||
target_uuid=spatial_uuid,
|
target_uuid=spatial_uuid,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update cache
|
|
||||||
self._last_spatial = (spatial_uuid, x, y, z)
|
self._last_spatial = (spatial_uuid, x, y, z)
|
||||||
|
|
||||||
# ------------------------------- query -------------------------------- #
|
# ------------------------------- query -------------------------------- #
|
||||||
async def query(
|
async def query(
|
||||||
self, question: str, centre_uuid: str | None = None, k: int = 5
|
self, question: str, center_uuid: str | None = None, k: int = 5
|
||||||
):
|
):
|
||||||
"""Hybrid RRF search with optional centre-node re-ranking."""
|
"""Hybrid RRF search with optional centre-node re-ranking."""
|
||||||
return await self.graph.search(question, center_node_uuid=centre_uuid, limit=k)
|
return await self.graph.search(question, center_node_uuid=center_uuid, limit=k)
|
||||||
|
|
||||||
# ------------------------------ cleanup ------------------------------- #
|
# ------------------------------ cleanup ------------------------------- #
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
|
|
@ -131,27 +133,17 @@ class PMLL:
|
||||||
|
|
||||||
|
|
||||||
# --------------------------------------------------------------------------- #
|
# --------------------------------------------------------------------------- #
|
||||||
# 3. Demo (run: python -m pmll) #
|
# 3. Demo #
|
||||||
# --------------------------------------------------------------------------- #
|
# --------------------------------------------------------------------------- #
|
||||||
async def _demo() -> None:
|
async def _demo() -> None:
|
||||||
pmll = PMLL(neo4j_uri="bolt://localhost:7687", user="neo4j", pwd="password")
|
pmll = PMLL(neo4j_uri="bolt://localhost:7687", user="neo4j", pwd="password")
|
||||||
await pmll.init()
|
await pmll.init()
|
||||||
|
|
||||||
await pmll.add_episode(
|
await pmll.add_episode("Robot entered Room A.", spatial_origin=(0, 0, 0))
|
||||||
"The robot entered Room A.",
|
await pmll.add_episode({"cmd": "move", "to": "Room B"}, spatial_origin=(3, 4, 0))
|
||||||
spatial_origin=(0, 0, 0),
|
|
||||||
description="telemetry",
|
|
||||||
)
|
|
||||||
await pmll.add_episode(
|
|
||||||
{"cmd": "move", "to": "Room B"},
|
|
||||||
spatial_origin=(3, 4, 0),
|
|
||||||
description="control instruction",
|
|
||||||
)
|
|
||||||
|
|
||||||
results = await pmll.query("Where is the robot now?")
|
for hit in await pmll.query("Where is the robot now?"):
|
||||||
print("\nTop answers:")
|
print("→", hit.fact)
|
||||||
for r in results:
|
|
||||||
print("•", r.fact)
|
|
||||||
|
|
||||||
await pmll.close()
|
await pmll.close()
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue