Add Missing Node and edge CRUD (#51)

* add CRUD operations and fix search limit bugs

* format

* update tests

* å

* update tests to double limit call

* add default field

* format

* import correct field
This commit is contained in:
Preston Rasmussen 2024-08-27 16:18:01 -04:00 committed by GitHub
parent 3f3fb60a55
commit 06d8d9359f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 251 additions and 39 deletions

View file

@ -23,6 +23,7 @@ from uuid import uuid4
from neo4j import AsyncDriver from neo4j import AsyncDriver
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from graphiti_core.helpers import parse_db_date
from graphiti_core.llm_client.config import EMBEDDING_DIM from graphiti_core.llm_client.config import EMBEDDING_DIM
from graphiti_core.nodes import Node from graphiti_core.nodes import Node
@ -38,6 +39,9 @@ class Edge(BaseModel, ABC):
@abstractmethod @abstractmethod
async def save(self, driver: AsyncDriver): ... async def save(self, driver: AsyncDriver): ...
@abstractmethod
async def delete(self, driver: AsyncDriver): ...
def __hash__(self): def __hash__(self):
return hash(self.uuid) return hash(self.uuid)
@ -46,6 +50,9 @@ class Edge(BaseModel, ABC):
return self.uuid == other.uuid return self.uuid == other.uuid
return False return False
@classmethod
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str): ...
class EpisodicEdge(Edge): class EpisodicEdge(Edge):
async def save(self, driver: AsyncDriver): async def save(self, driver: AsyncDriver):
@ -66,9 +73,48 @@ class EpisodicEdge(Edge):
return result return result
async def delete(self, driver: AsyncDriver):
result = await driver.execute_query(
"""
MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity)
DELETE e
""",
uuid=self.uuid,
)
# TODO: Neo4j doesn't support variables for edge types and labels. logger.info(f'Deleted Edge: {self.uuid}')
# Right now we have all edge nodes as type RELATES_TO
return result
@classmethod
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
records, _, _ = await driver.execute_query(
"""
MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity)
RETURN
e.uuid As uuid,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
e.created_at AS created_at
""",
uuid=uuid,
)
edges: list[EpisodicEdge] = []
for record in records:
edges.append(
EpisodicEdge(
uuid=record['uuid'],
source_node_uuid=record['source_node_uuid'],
target_node_uuid=record['target_node_uuid'],
created_at=record['created_at'].to_native(),
)
)
logger.info(f'Found Edge: {uuid}')
return edges[0]
class EntityEdge(Edge): class EntityEdge(Edge):
@ -97,7 +143,7 @@ class EntityEdge(Edge):
self.fact_embedding = embedding[:EMBEDDING_DIM] self.fact_embedding = embedding[:EMBEDDING_DIM]
end = time() end = time()
logger.info(f'embedded {text} in {end-start} ms') logger.info(f'embedded {text} in {end - start} ms')
return embedding return embedding
@ -127,3 +173,60 @@ class EntityEdge(Edge):
logger.info(f'Saved edge to neo4j: {self.uuid}') logger.info(f'Saved edge to neo4j: {self.uuid}')
return result return result
async def delete(self, driver: AsyncDriver):
result = await driver.execute_query(
"""
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
DELETE e
""",
uuid=self.uuid,
)
logger.info(f'Deleted Edge: {self.uuid}')
return result
@classmethod
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
records, _, _ = await driver.execute_query(
"""
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
RETURN
e.uuid AS uuid,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
e.created_at AS created_at,
e.name AS name,
e.fact AS fact,
e.fact_embedding AS fact_embedding,
e.episodes AS episodes,
e.expired_at AS expired_at,
e.valid_at AS valid_at,
e.invalid_at AS invalid_at
""",
uuid=uuid,
)
edges: list[EntityEdge] = []
for record in records:
edges.append(
EntityEdge(
uuid=record['uuid'],
source_node_uuid=record['source_node_uuid'],
target_node_uuid=record['target_node_uuid'],
fact=record['fact'],
name=record['name'],
episodes=record['episodes'],
fact_embedding=record['fact_embedding'],
created_at=record['created_at'].to_native(),
expired_at=parse_db_date(record['expired_at']),
valid_at=parse_db_date(record['valid_at']),
invalid_at=parse_db_date(record['invalid_at']),
)
)
logger.info(f'Found Edge: {uuid}')
return edges[0]

7
graphiti_core/helpers.py Normal file
View file

@ -0,0 +1,7 @@
from datetime import datetime
from neo4j import time as neo4j_time
def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
return neo_date.to_native() if neo_date else None

View file

@ -75,6 +75,9 @@ class Node(BaseModel, ABC):
@abstractmethod @abstractmethod
async def save(self, driver: AsyncDriver): ... async def save(self, driver: AsyncDriver): ...
@abstractmethod
async def delete(self, driver: AsyncDriver): ...
def __hash__(self): def __hash__(self):
return hash(self.uuid) return hash(self.uuid)
@ -83,6 +86,9 @@ class Node(BaseModel, ABC):
return self.uuid == other.uuid return self.uuid == other.uuid
return False return False
@classmethod
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str): ...
class EpisodicNode(Node): class EpisodicNode(Node):
source: EpisodeType = Field(description='source type') source: EpisodeType = Field(description='source type')
@ -111,13 +117,58 @@ class EpisodicNode(Node):
created_at=self.created_at, created_at=self.created_at,
valid_at=self.valid_at, valid_at=self.valid_at,
source=self.source.value, source=self.source.value,
_database='neo4j',
) )
logger.info(f'Saved Node to neo4j: {self.uuid}') logger.info(f'Saved Node to neo4j: {self.uuid}')
return result return result
async def delete(self, driver: AsyncDriver):
result = await driver.execute_query(
"""
MATCH (n:Episodic {uuid: $uuid})
DETACH DELETE n
""",
uuid=self.uuid,
)
logger.info(f'Deleted Node: {self.uuid}')
return result
@classmethod
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
records, _, _ = await driver.execute_query(
"""
MATCH (e:Episodic {uuid: $uuid})
RETURN e.content as content,
e.created_at as created_at,
e.valid_at as valid_at,
e.uuid as uuid,
e.name as name,
e.source_description as source_description,
e.source as source
""",
uuid=uuid,
)
episodes = [
EpisodicNode(
content=record['content'],
created_at=record['created_at'].to_native().timestamp(),
valid_at=(record['valid_at'].to_native()),
uuid=record['uuid'],
source=EpisodeType.from_str(record['source']),
name=record['name'],
source_description=record['source_description'],
)
for record in records
]
logger.info(f'Found Node: {uuid}')
return episodes[0]
class EntityNode(Node): class EntityNode(Node):
name_embedding: list[float] | None = Field(default=None, description='embedding of the name') name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
@ -153,3 +204,47 @@ class EntityNode(Node):
logger.info(f'Saved Node to neo4j: {self.uuid}') logger.info(f'Saved Node to neo4j: {self.uuid}')
return result return result
async def delete(self, driver: AsyncDriver):
result = await driver.execute_query(
"""
MATCH (n:Entity {uuid: $uuid})
DETACH DELETE n
""",
uuid=self.uuid,
)
logger.info(f'Deleted Node: {self.uuid}')
return result
@classmethod
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
records, _, _ = await driver.execute_query(
"""
MATCH (n:Entity {uuid: $uuid})
RETURN
n.uuid As uuid,
n.name AS name,
n.created_at AS created_at,
n.summary AS summary
""",
uuid=uuid,
)
nodes: list[EntityNode] = []
for record in records:
nodes.append(
EntityNode(
uuid=record['uuid'],
name=record['name'],
labels=['Entity'],
created_at=record['created_at'].to_native(),
summary=record['summary'],
)
)
logger.info(f'Found Node: {uuid}')
return nodes[0]

View file

@ -20,7 +20,7 @@ from enum import Enum
from time import time from time import time
from neo4j import AsyncDriver from neo4j import AsyncDriver
from pydantic import BaseModel from pydantic import BaseModel, Field
from graphiti_core.edges import EntityEdge from graphiti_core.edges import EntityEdge
from graphiti_core.llm_client.config import EMBEDDING_DIM from graphiti_core.llm_client.config import EMBEDDING_DIM
@ -49,8 +49,8 @@ class Reranker(Enum):
class SearchConfig(BaseModel): class SearchConfig(BaseModel):
num_edges: int = 10 num_edges: int = Field(default=10)
num_nodes: int = 10 num_nodes: int = Field(default=10)
num_episodes: int = EPISODE_WINDOW_LEN num_episodes: int = EPISODE_WINDOW_LEN
search_methods: list[SearchMethod] search_methods: list[SearchMethod]
reranker: Reranker | None reranker: Reranker | None
@ -63,12 +63,12 @@ class SearchResults(BaseModel):
async def hybrid_search( async def hybrid_search(
driver: AsyncDriver, driver: AsyncDriver,
embedder, embedder,
query: str, query: str,
timestamp: datetime, timestamp: datetime,
config: SearchConfig, config: SearchConfig,
center_node_uuid: str | None = None, center_node_uuid: str | None = None,
) -> SearchResults: ) -> SearchResults:
start = time() start = time()
@ -79,11 +79,11 @@ async def hybrid_search(
search_results = [] search_results = []
if config.num_episodes > 0: if config.num_episodes > 0:
episodes.extend(await retrieve_episodes(driver, timestamp)) episodes.extend(await retrieve_episodes(driver, timestamp, config.num_episodes))
nodes.extend(await get_mentioned_nodes(driver, episodes)) nodes.extend(await get_mentioned_nodes(driver, episodes))
if SearchMethod.bm25 in config.search_methods: if SearchMethod.bm25 in config.search_methods:
text_search = await edge_fulltext_search(query, driver) text_search = await edge_fulltext_search(query, driver, 2 * config.num_edges)
search_results.append(text_search) search_results.append(text_search)
if SearchMethod.cosine_similarity in config.search_methods: if SearchMethod.cosine_similarity in config.search_methods:
@ -94,7 +94,9 @@ async def hybrid_search(
.embedding[:EMBEDDING_DIM] .embedding[:EMBEDDING_DIM]
) )
similarity_search = await edge_similarity_search(search_vector, driver) similarity_search = await edge_similarity_search(
search_vector, driver, 2 * config.num_edges
)
search_results.append(similarity_search) search_results.append(similarity_search)
if len(search_results) > 1 and config.reranker is None: if len(search_results) > 1 and config.reranker is None:

View file

@ -3,13 +3,12 @@ import logging
import re import re
import typing import typing
from collections import defaultdict from collections import defaultdict
from datetime import datetime
from time import time from time import time
from neo4j import AsyncDriver from neo4j import AsyncDriver
from neo4j import time as neo4j_time
from graphiti_core.edges import EntityEdge from graphiti_core.edges import EntityEdge
from graphiti_core.helpers import parse_db_date
from graphiti_core.nodes import EntityNode, EpisodicNode from graphiti_core.nodes import EntityNode, EpisodicNode
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -17,10 +16,6 @@ logger = logging.getLogger(__name__)
RELEVANT_SCHEMA_LIMIT = 3 RELEVANT_SCHEMA_LIMIT = 3
def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
return neo_date.to_native() if neo_date else None
async def get_mentioned_nodes(driver: AsyncDriver, episodes: list[EpisodicNode]): async def get_mentioned_nodes(driver: AsyncDriver, episodes: list[EpisodicNode]):
episode_uuids = [episode.uuid for episode in episodes] episode_uuids = [episode.uuid for episode in episodes]
records, _, _ = await driver.execute_query( records, _, _ = await driver.execute_query(
@ -106,7 +101,7 @@ async def edge_similarity_search(
# vector similarity search over embedded facts # vector similarity search over embedded facts
records, _, _ = await driver.execute_query( records, _, _ = await driver.execute_query(
""" """
CALL db.index.vector.queryRelationships("fact_embedding", 5, $search_vector) CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
YIELD relationship AS r, score YIELD relationship AS r, score
MATCH (n)-[r:RELATES_TO]->(m) MATCH (n)-[r:RELATES_TO]->(m)
RETURN RETURN
@ -121,7 +116,7 @@ async def edge_similarity_search(
r.expired_at AS expired_at, r.expired_at AS expired_at,
r.valid_at AS valid_at, r.valid_at AS valid_at,
r.invalid_at AS invalid_at r.invalid_at AS invalid_at
ORDER BY score DESC LIMIT $limit ORDER BY score DESC
""", """,
search_vector=search_vector, search_vector=search_vector,
limit=limit, limit=limit,
@ -316,8 +311,11 @@ async def hybrid_node_search(
relevant_node_uuids = set() relevant_node_uuids = set()
results = await asyncio.gather( results = await asyncio.gather(
*[entity_fulltext_search(q, driver, limit or RELEVANT_SCHEMA_LIMIT) for q in queries], *[entity_fulltext_search(q, driver, 2 * (limit or RELEVANT_SCHEMA_LIMIT)) for q in queries],
*[entity_similarity_search(e, driver, limit or RELEVANT_SCHEMA_LIMIT) for e in embeddings], *[
entity_similarity_search(e, driver, 2 * (limit or RELEVANT_SCHEMA_LIMIT))
for e in embeddings
],
) )
for result in results: for result in results:

View file

@ -22,8 +22,6 @@ from datetime import datetime
import pytest import pytest
from dotenv import load_dotenv from dotenv import load_dotenv
from neo4j import AsyncGraphDatabase
from openai import OpenAI
from graphiti_core.edges import EntityEdge, EpisodicEdge from graphiti_core.edges import EntityEdge, EpisodicEdge
from graphiti_core.graphiti import Graphiti from graphiti_core.graphiti import Graphiti
@ -74,7 +72,7 @@ def format_context(facts):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_graphiti_init(): async def test_graphiti_init():
logger = setup_logging() logger = setup_logging()
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD, None) graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
edges = await graphiti.search('Freakenomics guest') edges = await graphiti.search('Freakenomics guest')
@ -92,11 +90,9 @@ async def test_graphiti_init():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_graph_integration(): async def test_graph_integration():
driver = AsyncGraphDatabase.driver( client = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
NEO4J_URI, embedder = client.llm_client.get_embedder()
auth=(NEO4j_USER, NEO4j_PASSWORD), driver = client.driver
)
embedder = OpenAI().embeddings
now = datetime.now() now = datetime.now()
episode = EpisodicNode( episode = EpisodicNode(
@ -139,10 +135,21 @@ async def test_graph_integration():
invalid_at=now, invalid_at=now,
) )
entity_edge.generate_embedding(embedder) await entity_edge.generate_embedding(embedder)
nodes = [episode, alice_node, bob_node] nodes = [episode, alice_node, bob_node]
edges = [episodic_edge_1, episodic_edge_2, entity_edge] edges = [episodic_edge_1, episodic_edge_2, entity_edge]
# test save
await asyncio.gather(*[node.save(driver) for node in nodes]) await asyncio.gather(*[node.save(driver) for node in nodes])
await asyncio.gather(*[edge.save(driver) for edge in edges]) await asyncio.gather(*[edge.save(driver) for edge in edges])
# test get
assert await EpisodicNode.get_by_uuid(driver, episode.uuid) is not None
assert await EntityNode.get_by_uuid(driver, alice_node.uuid) is not None
assert await EpisodicEdge.get_by_uuid(driver, episodic_edge_1.uuid) is not None
assert await EntityEdge.get_by_uuid(driver, entity_edge.uuid) is not None
# test delete
await asyncio.gather(*[node.delete(driver) for node in nodes])
await asyncio.gather(*[edge.delete(driver) for edge in edges])

View file

@ -113,8 +113,8 @@ async def test_hybrid_node_search_with_limit():
assert mock_fulltext_search.call_count == 1 assert mock_fulltext_search.call_count == 1
assert mock_similarity_search.call_count == 1 assert mock_similarity_search.call_count == 1
# Verify that the limit was passed to the search functions # Verify that the limit was passed to the search functions
mock_fulltext_search.assert_called_with('Test', mock_driver, 1) mock_fulltext_search.assert_called_with('Test', mock_driver, 2)
mock_similarity_search.assert_called_with([0.1, 0.2, 0.3], mock_driver, 1) mock_similarity_search.assert_called_with([0.1, 0.2, 0.3], mock_driver, 2)
@pytest.mark.asyncio @pytest.mark.asyncio
@ -148,5 +148,5 @@ async def test_hybrid_node_search_with_limit_and_duplicates():
assert set(node.name for node in results) == {'Alice', 'Bob', 'Charlie'} assert set(node.name for node in results) == {'Alice', 'Bob', 'Charlie'}
assert mock_fulltext_search.call_count == 1 assert mock_fulltext_search.call_count == 1
assert mock_similarity_search.call_count == 1 assert mock_similarity_search.call_count == 1
mock_fulltext_search.assert_called_with('Test', mock_driver, 2) mock_fulltext_search.assert_called_with('Test', mock_driver, 4)
mock_similarity_search.assert_called_with([0.1, 0.2, 0.3], mock_driver, 2) mock_similarity_search.assert_called_with([0.1, 0.2, 0.3], mock_driver, 4)