chore/prepare kuzu integration (#762)

* Prepare code

* Fix tests

* As -> AS, remove trailing spaces

* Enable more tests for FalkorDB

* Fix more cypher queries

* Return all created nodes and edges

* Add Neo4j service to unit tests workflow

- Introduced Neo4j as a service in the GitHub Actions workflow for unit tests.
- Configured Neo4j with appropriate ports, authentication, and health checks.
- Updated test steps to include waiting for Neo4j and running integration tests against it.
- Set environment variables for Neo4j connection in both non-integration and integration test steps.
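
The wait-then-test pattern described above can be sketched in Python. This is only an illustration under assumptions: the workflow itself uses shell loops in its CI steps, and `wait_until` and its `probe` argument are hypothetical names, not part of the repository.

```python
import time
from typing import Callable


def wait_until(probe: Callable[[], bool], timeout: float = 60.0, interval: float = 1.0) -> bool:
    """Poll `probe` until it returns True or `timeout` seconds elapse."""
    deadline = time.monotonic() + timeout
    while True:
        try:
            if probe():
                return True
        except OSError:
            # Connection refused and similar errors mean "not ready yet".
            pass
        if time.monotonic() >= deadline:
            return False
        time.sleep(interval)
```

A probe for this setup might try to open a TCP connection to `localhost:7687` and return True on success; the test steps would run only once `wait_until` returns True.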

* Update Neo4j authentication in unit tests workflow

- Changed Neo4j authentication password from 'test' to 'testpass' in the GitHub Actions workflow.
- Updated health check command to reflect the new password.
- Ensured consistency across all test steps that utilize Neo4j credentials.

* Fix health check

* Fix Neo4j integration tests in CI workflow

Remove reference to non-existent test_neo4j_driver.py file from test command.
Integration tests now run via parametrized tests using the drivers list.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* Add OPENAI_API_KEY to Neo4j integration tests

Neo4j integration tests require OpenAI API access for LLM functionality.
Add the secret environment variable to enable these tests to run properly.

* Fix Neo4j Cypher syntax error in BFS search queries

Replace parameter substitution in relationship pattern ranges (*1..$depth)
with direct string interpolation (*1..{bfs_max_depth}). Neo4j does not accept
parameters as the bounds of a variable-length pattern; they must be literal values.

Fixed in both node_bfs_search and edge_bfs_search functions.
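
Because the depth is now interpolated into the query text rather than bound as a parameter, it is worth validating before use. A minimal sketch, assuming the depth arrives as a Python int; `build_bfs_match` and the exact MATCH pattern are hypothetical and only illustrate the `*1..{bfs_max_depth}` form named above:

```python
def build_bfs_match(bfs_max_depth: int) -> str:
    """Return a MATCH clause whose variable-length range uses a literal bound."""
    # Only a validated positive int is interpolated, never free-form text.
    # bool is rejected explicitly because it is a subclass of int.
    if (
        isinstance(bfs_max_depth, bool)
        or not isinstance(bfs_max_depth, int)
        or bfs_max_depth < 1
    ):
        raise ValueError('bfs_max_depth must be a positive integer')
    # Illustrative pattern; the real queries live in the search utilities.
    return f'MATCH path = (origin)-[:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(n:Entity)'
```

All other values in the query can remain ordinary `$`-parameters; only the range bound needs to be a literal.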

* Fix variable name mismatch in edge_bfs_search query

Change relationship variable from 'r' to 'e' to match ENTITY_EDGE_RETURN
constant expectations. The ENTITY_EDGE_RETURN constant references variable
'e' for relationships, but the query was using 'r', causing "Variable e
not defined" errors.
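
The coupling between a shared RETURN fragment and the variable bound in the MATCH clause can be sketched as follows. The fragment here is abbreviated, and `build_edge_query` with its alias check is a hypothetical helper, not code from the repository:

```python
# Abbreviated stand-in for the shared fragment; it refers to the alias `e`.
ENTITY_EDGE_RETURN = """
    e.uuid AS uuid,
    e.name AS name,
    e.fact AS fact
"""


def build_edge_query(match_clause: str) -> str:
    """Append the shared RETURN fragment to a MATCH clause."""
    # The fragment only resolves if the relationship is bound to `e`.
    if '[e:' not in match_clause:
        raise ValueError("relationship must be bound to variable 'e'")
    return match_clause + '\nRETURN' + ENTITY_EDGE_RETURN
```

A clause such as `MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)` is rejected here before it reaches the database, which is the same mismatch the fix above addresses.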

* Isolate database tests in CI workflow

- FalkorDB tests: Add DISABLE_NEO4J=1 and remove Neo4j env vars
- Neo4j tests: Keep current setup without DISABLE_NEO4J flag

This ensures proper test isolation where each test suite only runs
against its intended database backend.
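
One way a test suite could honour such a flag is to build its list of backends from the environment. A minimal sketch, assuming the flag disables a backend whenever it is set to any value; `enabled_backends` and the `DISABLE_FALKORDB` counterpart are assumptions for illustration:

```python
import os
from typing import Mapping, Optional


def enabled_backends(env: Optional[Mapping[str, str]] = None) -> list[str]:
    """Return the backends the test suite should run against."""
    env = os.environ if env is None else env
    backends = []
    # A backend is skipped as soon as its DISABLE_* variable is present.
    if 'DISABLE_NEO4J' not in env:
        backends.append('neo4j')
    if 'DISABLE_FALKORDB' not in env:  # hypothetical counterpart flag
        backends.append('falkordb')
    return backends
```

Parametrized tests can then iterate over the returned list, so a step that sets `DISABLE_NEO4J=1` never opens a Neo4j connection.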

---------

Co-authored-by: Siddhartha Sahu <sid@kuzudb.com>
Co-authored-by: Claude <noreply@anthropic.com>
Daniel Chalef 2025-07-29 06:07:34 -07:00 committed by GitHub
parent 9ceeb54186
commit dcc9da3f68
25 changed files with 1339 additions and 1068 deletions

@ -20,6 +20,15 @@ jobs:
ports:
- 6379:6379
options: --health-cmd "redis-cli ping" --health-interval 10s --health-timeout 5s --health-retries 5
neo4j:
image: neo4j:5.26-community
ports:
- 7687:7687
- 7474:7474
env:
NEO4J_AUTH: neo4j/testpass
NEO4J_PLUGINS: '["apoc"]'
options: --health-cmd "cypher-shell -u neo4j -p testpass 'RETURN 1'" --health-interval 10s --health-timeout 5s --health-retries 10
steps:
- uses: actions/checkout@v4
- name: Set up Python
@ -37,15 +46,33 @@ jobs:
- name: Run non-integration tests
env:
PYTHONPATH: ${{ github.workspace }}
NEO4J_URI: bolt://localhost:7687
NEO4J_USER: neo4j
NEO4J_PASSWORD: testpass
run: |
uv run pytest -m "not integration"
- name: Wait for FalkorDB
run: |
timeout 60 bash -c 'until redis-cli -h localhost -p 6379 ping; do sleep 1; done'
- name: Wait for Neo4j
run: |
timeout 60 bash -c 'until wget -O /dev/null http://localhost:7474 >/dev/null 2>&1; do sleep 1; done'
- name: Run FalkorDB integration tests
env:
PYTHONPATH: ${{ github.workspace }}
FALKORDB_HOST: localhost
FALKORDB_PORT: 6379
DISABLE_NEO4J: 1
run: |
uv run pytest tests/driver/test_falkordb_driver.py
- name: Run Neo4j integration tests
env:
PYTHONPATH: ${{ github.workspace }}
NEO4J_URI: bolt://localhost:7687
NEO4J_USER: neo4j
NEO4J_PASSWORD: testpass
FALKORDB_HOST: localhost
FALKORDB_PORT: 6379
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
run: |
uv run pytest tests/test_*_int.py -k "neo4j"

@ -18,11 +18,17 @@ import copy
import logging
from abc import ABC, abstractmethod
from collections.abc import Coroutine
from enum import Enum
from typing import Any
logger = logging.getLogger(__name__)
class GraphProvider(Enum):
NEO4J = 'neo4j'
FALKORDB = 'falkordb'
class GraphDriverSession(ABC):
async def __aenter__(self):
return self
@ -46,7 +52,7 @@ class GraphDriverSession(ABC):
class GraphDriver(ABC):
provider: str
provider: GraphProvider
fulltext_syntax: str = (
'' # Neo4j (default) syntax does not require a prefix for fulltext queries
)
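
A note on the switch from `str` to `GraphProvider`: a plain `Enum` member does not compare equal to its string value, so callers that still pass `'falkordb'` would silently take the default branch. The sketch below re-declares the enum so it runs on its own; `default_group_id` mirrors the shape of `get_default_group_id` from this commit but is written here only for illustration.

```python
from enum import Enum


class GraphProvider(Enum):
    NEO4J = 'neo4j'
    FALKORDB = 'falkordb'


def default_group_id(provider: GraphProvider) -> str:
    """Pick a provider-specific default, falling back to the Neo4j behaviour."""
    return '_' if provider == GraphProvider.FALKORDB else ''


# Converting a legacy string to the enum is an explicit lookup by value.
legacy = GraphProvider('falkordb')
```

Converting at the boundary with `GraphProvider(value)` raises `ValueError` for unknown names, which surfaces typos that a string comparison would have hidden.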

@ -32,7 +32,7 @@ else:
'Install it with: pip install graphiti-core[falkordb]'
) from None
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
logger = logging.getLogger(__name__)
@ -71,7 +71,7 @@ class FalkorDriverSession(GraphDriverSession):
class FalkorDriver(GraphDriver):
provider: str = 'falkordb'
provider = GraphProvider.FALKORDB
def __init__(
self,
@ -119,7 +119,7 @@ class FalkorDriver(GraphDriver):
# check if index already exists
logger.info(f'Index already exists: {e}')
return None
logger.error(f'Error executing FalkorDB query: {e}')
logger.error(f'Error executing FalkorDB query: {e}\n{cypher_query_}\n{params}')
raise
# Convert the result header to a list of strings

@ -21,13 +21,13 @@ from typing import Any
from neo4j import AsyncGraphDatabase, EagerResult
from typing_extensions import LiteralString
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
logger = logging.getLogger(__name__)
class Neo4jDriver(GraphDriver):
provider: str = 'neo4j'
provider = GraphProvider.NEO4J
def __init__(self, uri: str, user: str | None, password: str | None, database: str = 'neo4j'):
super().__init__()
@ -45,7 +45,11 @@ class Neo4jDriver(GraphDriver):
params = {}
params.setdefault('database_', self._database)
result = await self.client.execute_query(cypher_query_, parameters_=params, **kwargs)
try:
result = await self.client.execute_query(cypher_query_, parameters_=params, **kwargs)
except Exception as e:
logger.error(f'Error executing Neo4j query: {e}\n{cypher_query_}\n{params}')
raise
return result
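
The log-and-re-raise pattern added to both drivers can be sketched independently of any database client. `run_logged` and the `execute` callable are hypothetical names used only for this illustration:

```python
import asyncio
import logging

logger = logging.getLogger(__name__)


async def run_logged(execute, query: str, params: dict):
    """Await `execute(query, params)`, logging the query text if it fails."""
    try:
        return await execute(query, params)
    except Exception as e:
        # Log the failing query and parameters, then let the caller handle it.
        logger.error(f'Error executing query: {e}\n{query}\n{params}')
        raise


async def _demo() -> None:
    async def failing(query, params):
        raise RuntimeError('syntax error')

    try:
        await run_logged(failing, 'RETURN $x', {'x': 1})
    except RuntimeError:
        pass  # the error was logged, then propagated to the caller


if __name__ == '__main__':
    asyncio.run(_demo())
```

Since the parameters are written to the log verbatim, large embeddings or sensitive values may end up there; truncating or redacting them is a possible refinement, not something this commit does.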

@ -29,29 +29,17 @@ from graphiti_core.embedder import EmbedderClient
from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
from graphiti_core.helpers import parse_db_date
from graphiti_core.models.edges.edge_db_queries import (
COMMUNITY_EDGE_SAVE,
ENTITY_EDGE_SAVE,
COMMUNITY_EDGE_RETURN,
ENTITY_EDGE_RETURN,
EPISODIC_EDGE_RETURN,
EPISODIC_EDGE_SAVE,
get_community_edge_save_query,
get_entity_edge_save_query,
)
from graphiti_core.nodes import Node
logger = logging.getLogger(__name__)
ENTITY_EDGE_RETURN: LiteralString = """
RETURN
e.uuid AS uuid,
startNode(e).uuid AS source_node_uuid,
endNode(e).uuid AS target_node_uuid,
e.created_at AS created_at,
e.name AS name,
e.group_id AS group_id,
e.fact AS fact,
e.episodes AS episodes,
e.expired_at AS expired_at,
e.valid_at AS valid_at,
e.invalid_at AS invalid_at,
properties(e) AS attributes"""
class Edge(BaseModel, ABC):
uuid: str = Field(default_factory=lambda: str(uuid4()))
@ -66,9 +54,9 @@ class Edge(BaseModel, ABC):
async def delete(self, driver: GraphDriver):
result = await driver.execute_query(
"""
MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->(m)
DELETE e
""",
MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->(m)
DELETE e
""",
uuid=self.uuid,
)
@ -107,14 +95,10 @@ class EpisodicEdge(Edge):
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
records, _, _ = await driver.execute_query(
"""
MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity)
RETURN
e.uuid As uuid,
e.group_id AS group_id,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
e.created_at AS created_at
""",
MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity)
RETURN
"""
+ EPISODIC_EDGE_RETURN,
uuid=uuid,
routing_='r',
)
@ -129,15 +113,11 @@ class EpisodicEdge(Edge):
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
records, _, _ = await driver.execute_query(
"""
MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
WHERE e.uuid IN $uuids
RETURN
e.uuid As uuid,
e.group_id AS group_id,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
e.created_at AS created_at
""",
MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
WHERE e.uuid IN $uuids
RETURN
"""
+ EPISODIC_EDGE_RETURN,
uuids=uuids,
routing_='r',
)
@ -161,19 +141,17 @@ class EpisodicEdge(Edge):
records, _, _ = await driver.execute_query(
"""
MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
WHERE e.group_id IN $group_ids
"""
MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
WHERE e.group_id IN $group_ids
"""
+ cursor_query
+ """
RETURN
e.uuid As uuid,
e.group_id AS group_id,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
e.created_at AS created_at
ORDER BY e.uuid DESC
"""
RETURN
"""
+ EPISODIC_EDGE_RETURN
+ """
ORDER BY e.uuid DESC
"""
+ limit_query,
group_ids=group_ids,
uuid=uuid_cursor,
@ -221,11 +199,14 @@ class EntityEdge(Edge):
return self.fact_embedding
async def load_fact_embedding(self, driver: GraphDriver):
query: LiteralString = """
records, _, _ = await driver.execute_query(
"""
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
RETURN e.fact_embedding AS fact_embedding
"""
records, _, _ = await driver.execute_query(query, uuid=self.uuid, routing_='r')
""",
uuid=self.uuid,
routing_='r',
)
if len(records) == 0:
raise EdgeNotFoundError(self.uuid)
@ -251,7 +232,7 @@ class EntityEdge(Edge):
edge_data.update(self.attributes or {})
result = await driver.execute_query(
ENTITY_EDGE_SAVE,
get_entity_edge_save_query(driver.provider),
edge_data=edge_data,
)
@ -263,8 +244,9 @@ class EntityEdge(Edge):
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
records, _, _ = await driver.execute_query(
"""
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
"""
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
RETURN
"""
+ ENTITY_EDGE_RETURN,
uuid=uuid,
routing_='r',
@ -283,9 +265,10 @@ class EntityEdge(Edge):
records, _, _ = await driver.execute_query(
"""
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
WHERE e.uuid IN $uuids
"""
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
WHERE e.uuid IN $uuids
RETURN
"""
+ ENTITY_EDGE_RETURN,
uuids=uuids,
routing_='r',
@ -314,22 +297,21 @@ class EntityEdge(Edge):
else ''
)
query: LiteralString = (
records, _, _ = await driver.execute_query(
"""
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
WHERE e.group_id IN $group_ids
"""
+ cursor_query
+ """
RETURN
"""
+ ENTITY_EDGE_RETURN
+ with_embeddings_query
+ """
ORDER BY e.uuid DESC
"""
+ limit_query
)
records, _, _ = await driver.execute_query(
query,
ORDER BY e.uuid DESC
"""
+ limit_query,
group_ids=group_ids,
uuid=uuid_cursor,
limit=limit,
@ -344,13 +326,15 @@ class EntityEdge(Edge):
@classmethod
async def get_by_node_uuid(cls, driver: GraphDriver, node_uuid: str):
query: LiteralString = (
records, _, _ = await driver.execute_query(
"""
MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
"""
+ ENTITY_EDGE_RETURN
MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
RETURN
"""
+ ENTITY_EDGE_RETURN,
node_uuid=node_uuid,
routing_='r',
)
records, _, _ = await driver.execute_query(query, node_uuid=node_uuid, routing_='r')
edges = [get_entity_edge_from_record(record) for record in records]
@ -360,7 +344,7 @@ class EntityEdge(Edge):
class CommunityEdge(Edge):
async def save(self, driver: GraphDriver):
result = await driver.execute_query(
COMMUNITY_EDGE_SAVE,
get_community_edge_save_query(driver.provider),
community_uuid=self.source_node_uuid,
entity_uuid=self.target_node_uuid,
uuid=self.uuid,
@ -376,14 +360,10 @@ class CommunityEdge(Edge):
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
records, _, _ = await driver.execute_query(
"""
MATCH (n:Community)-[e:HAS_MEMBER {uuid: $uuid}]->(m:Entity | Community)
RETURN
e.uuid As uuid,
e.group_id AS group_id,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
e.created_at AS created_at
""",
MATCH (n:Community)-[e:HAS_MEMBER {uuid: $uuid}]->(m)
RETURN
"""
+ COMMUNITY_EDGE_RETURN,
uuid=uuid,
routing_='r',
)
@ -396,15 +376,11 @@ class CommunityEdge(Edge):
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
records, _, _ = await driver.execute_query(
"""
MATCH (n:Community)-[e:HAS_MEMBER]->(m:Entity | Community)
WHERE e.uuid IN $uuids
RETURN
e.uuid As uuid,
e.group_id AS group_id,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
e.created_at AS created_at
""",
MATCH (n:Community)-[e:HAS_MEMBER]->(m)
WHERE e.uuid IN $uuids
RETURN
"""
+ COMMUNITY_EDGE_RETURN,
uuids=uuids,
routing_='r',
)
@ -426,19 +402,17 @@ class CommunityEdge(Edge):
records, _, _ = await driver.execute_query(
"""
MATCH (n:Community)-[e:HAS_MEMBER]->(m:Entity | Community)
WHERE e.group_id IN $group_ids
"""
MATCH (n:Community)-[e:HAS_MEMBER]->(m)
WHERE e.group_id IN $group_ids
"""
+ cursor_query
+ """
RETURN
e.uuid As uuid,
e.group_id AS group_id,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
e.created_at AS created_at
ORDER BY e.uuid DESC
"""
RETURN
"""
+ COMMUNITY_EDGE_RETURN
+ """
ORDER BY e.uuid DESC
"""
+ limit_query,
group_ids=group_ids,
uuid=uuid_cursor,

@ -5,16 +5,9 @@ This module provides database-agnostic query generation for Neo4j and FalkorDB,
supporting index creation, fulltext search, and bulk operations.
"""
from typing import Any
from typing_extensions import LiteralString
from graphiti_core.models.edges.edge_db_queries import (
ENTITY_EDGE_SAVE_BULK,
)
from graphiti_core.models.nodes.node_db_queries import (
ENTITY_NODE_SAVE_BULK,
)
from graphiti_core.driver.driver import GraphProvider
# Mapping from Neo4j fulltext index names to FalkorDB node labels
NEO4J_TO_FALKORDB_MAPPING = {
@ -25,8 +18,8 @@ NEO4J_TO_FALKORDB_MAPPING = {
}
def get_range_indices(db_type: str = 'neo4j') -> list[LiteralString]:
if db_type == 'falkordb':
def get_range_indices(provider: GraphProvider) -> list[LiteralString]:
if provider == GraphProvider.FALKORDB:
return [
# Entity node
'CREATE INDEX FOR (n:Entity) ON (n.uuid, n.group_id, n.name, n.created_at)',
@ -41,109 +34,70 @@ def get_range_indices(db_type: str = 'neo4j') -> list[LiteralString]:
# HAS_MEMBER edge
'CREATE INDEX FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
]
else:
return [
'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)',
'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)',
'CREATE INDEX community_uuid IF NOT EXISTS FOR (n:Community) ON (n.uuid)',
'CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)',
'CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)',
'CREATE INDEX has_member_uuid IF NOT EXISTS FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
'CREATE INDEX entity_group_id IF NOT EXISTS FOR (n:Entity) ON (n.group_id)',
'CREATE INDEX episode_group_id IF NOT EXISTS FOR (n:Episodic) ON (n.group_id)',
'CREATE INDEX relation_group_id IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.group_id)',
'CREATE INDEX mention_group_id IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.group_id)',
'CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)',
'CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)',
'CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)',
'CREATE INDEX valid_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.valid_at)',
'CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.name)',
'CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.created_at)',
'CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.expired_at)',
'CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.valid_at)',
'CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.invalid_at)',
]
return [
'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)',
'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)',
'CREATE INDEX community_uuid IF NOT EXISTS FOR (n:Community) ON (n.uuid)',
'CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)',
'CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)',
'CREATE INDEX has_member_uuid IF NOT EXISTS FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
'CREATE INDEX entity_group_id IF NOT EXISTS FOR (n:Entity) ON (n.group_id)',
'CREATE INDEX episode_group_id IF NOT EXISTS FOR (n:Episodic) ON (n.group_id)',
'CREATE INDEX relation_group_id IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.group_id)',
'CREATE INDEX mention_group_id IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.group_id)',
'CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)',
'CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)',
'CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)',
'CREATE INDEX valid_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.valid_at)',
'CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.name)',
'CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.created_at)',
'CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.expired_at)',
'CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.valid_at)',
'CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.invalid_at)',
]
def get_fulltext_indices(db_type: str = 'neo4j') -> list[LiteralString]:
if db_type == 'falkordb':
def get_fulltext_indices(provider: GraphProvider) -> list[LiteralString]:
if provider == GraphProvider.FALKORDB:
return [
"""CREATE FULLTEXT INDEX FOR (e:Episodic) ON (e.content, e.source, e.source_description, e.group_id)""",
"""CREATE FULLTEXT INDEX FOR (n:Entity) ON (n.name, n.summary, n.group_id)""",
"""CREATE FULLTEXT INDEX FOR (n:Community) ON (n.name, n.group_id)""",
"""CREATE FULLTEXT INDEX FOR ()-[e:RELATES_TO]-() ON (e.name, e.fact, e.group_id)""",
]
else:
return [
"""CREATE FULLTEXT INDEX episode_content IF NOT EXISTS
FOR (e:Episodic) ON EACH [e.content, e.source, e.source_description, e.group_id]""",
"""CREATE FULLTEXT INDEX node_name_and_summary IF NOT EXISTS
FOR (n:Entity) ON EACH [n.name, n.summary, n.group_id]""",
"""CREATE FULLTEXT INDEX community_name IF NOT EXISTS
FOR (n:Community) ON EACH [n.name, n.group_id]""",
"""CREATE FULLTEXT INDEX edge_name_and_fact IF NOT EXISTS
FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact, e.group_id]""",
]
return [
"""CREATE FULLTEXT INDEX episode_content IF NOT EXISTS
FOR (e:Episodic) ON EACH [e.content, e.source, e.source_description, e.group_id]""",
"""CREATE FULLTEXT INDEX node_name_and_summary IF NOT EXISTS
FOR (n:Entity) ON EACH [n.name, n.summary, n.group_id]""",
"""CREATE FULLTEXT INDEX community_name IF NOT EXISTS
FOR (n:Community) ON EACH [n.name, n.group_id]""",
"""CREATE FULLTEXT INDEX edge_name_and_fact IF NOT EXISTS
FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact, e.group_id]""",
]
def get_nodes_query(db_type: str = 'neo4j', name: str = '', query: str | None = None) -> str:
if db_type == 'falkordb':
def get_nodes_query(provider: GraphProvider, name: str = '', query: str | None = None) -> str:
if provider == GraphProvider.FALKORDB:
label = NEO4J_TO_FALKORDB_MAPPING[name]
return f"CALL db.idx.fulltext.queryNodes('{label}', {query})"
else:
return f'CALL db.index.fulltext.queryNodes("{name}", {query}, {{limit: $limit}})'
return f'CALL db.index.fulltext.queryNodes("{name}", {query}, {{limit: $limit}})'
def get_vector_cosine_func_query(vec1, vec2, db_type: str = 'neo4j') -> str:
if db_type == 'falkordb':
def get_vector_cosine_func_query(vec1, vec2, provider: GraphProvider) -> str:
if provider == GraphProvider.FALKORDB:
# FalkorDB uses a different syntax for regular cosine similarity and Neo4j uses normalized cosine similarity
return f'(2 - vec.cosineDistance({vec1}, vecf32({vec2})))/2'
else:
return f'vector.similarity.cosine({vec1}, {vec2})'
return f'vector.similarity.cosine({vec1}, {vec2})'
def get_relationships_query(name: str, db_type: str = 'neo4j') -> str:
if db_type == 'falkordb':
def get_relationships_query(name: str, provider: GraphProvider) -> str:
if provider == GraphProvider.FALKORDB:
label = NEO4J_TO_FALKORDB_MAPPING[name]
return f"CALL db.idx.fulltext.queryRelationships('{label}', $query)"
else:
return f'CALL db.index.fulltext.queryRelationships("{name}", $query, {{limit: $limit}})'
def get_entity_node_save_bulk_query(nodes, db_type: str = 'neo4j') -> str | Any:
if db_type == 'falkordb':
queries = []
for node in nodes:
for label in node['labels']:
queries.append(
(
f"""
UNWIND $nodes AS node
MERGE (n:Entity {{uuid: node.uuid}})
SET n:{label}
SET n = node
WITH n, node
SET n.name_embedding = vecf32(node.name_embedding)
RETURN n.uuid AS uuid
""",
{'nodes': [node]},
)
)
return queries
else:
return ENTITY_NODE_SAVE_BULK
def get_entity_edge_save_bulk_query(db_type: str = 'neo4j') -> str:
if db_type == 'falkordb':
return """
UNWIND $entity_edges AS edge
MATCH (source:Entity {uuid: edge.source_node_uuid})
MATCH (target:Entity {uuid: edge.target_node_uuid})
MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target)
SET r = {uuid: edge.uuid, name: edge.name, group_id: edge.group_id, fact: edge.fact, episodes: edge.episodes,
created_at: edge.created_at, expired_at: edge.expired_at, valid_at: edge.valid_at, invalid_at: edge.invalid_at, fact_embedding: vecf32(edge.fact_embedding)}
WITH r, edge
RETURN edge.uuid AS uuid"""
else:
return ENTITY_EDGE_SAVE_BULK
return f'CALL db.index.fulltext.queryRelationships("{name}", $query, {{limit: $limit}})'
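
The `(2 - vec.cosineDistance(...))/2` expression in the FalkorDB branch of `get_vector_cosine_func_query` can be checked numerically. The sketch assumes that cosine distance is `1 - cos` and that the normalized cosine similarity mentioned in the code comment is `(1 + cos) / 2`; both are assumptions stated here, not confirmed by the diff.

```python
import math


def cosine_similarity(a, b):
    """Plain cosine of the angle between two non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm


def falkordb_score(a, b):
    # Assumption: the database's cosine distance equals 1 - cos(a, b).
    cosine_distance = 1 - cosine_similarity(a, b)
    return (2 - cosine_distance) / 2


def normalized_cosine(a, b):
    # (1 + cos) / 2 maps the range [-1, 1] onto [0, 1].
    return (1 + cosine_similarity(a, b)) / 2
```

Under these assumptions the two expressions are algebraically identical, since `(2 - (1 - cos)) / 2 = (1 + cos) / 2`, so score thresholds can be shared across both providers.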

@ -26,7 +26,7 @@ from graphiti_core.cross_encoder.client import CrossEncoderClient
from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
from graphiti_core.driver.driver import GraphDriver
from graphiti_core.driver.neo4j_driver import Neo4jDriver
from graphiti_core.edges import EntityEdge, EpisodicEdge
from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge
from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.helpers import (
@ -93,8 +93,11 @@ load_dotenv()
class AddEpisodeResults(BaseModel):
episode: EpisodicNode
episodic_edges: list[EpisodicEdge]
nodes: list[EntityNode]
edges: list[EntityEdge]
communities: list[CommunityNode]
community_edges: list[CommunityEdge]
class Graphiti:
@ -520,9 +523,12 @@ class Graphiti:
self.driver, [episode], episodic_edges, hydrated_nodes, entity_edges, self.embedder
)
communities = []
community_edges = []
# Update any communities
if update_communities:
await semaphore_gather(
communities, community_edges = await semaphore_gather(
*[
update_community(self.driver, self.llm_client, self.embedder, node)
for node in nodes
@ -532,7 +538,14 @@ class Graphiti:
end = time()
logger.info(f'Completed add_episode in {(end - start) * 1000} ms')
return AddEpisodeResults(episode=episode, nodes=nodes, edges=entity_edges)
return AddEpisodeResults(
episode=episode,
episodic_edges=episodic_edges,
nodes=hydrated_nodes,
edges=entity_edges,
communities=communities,
community_edges=community_edges,
)
except Exception as e:
raise e
@ -817,7 +830,9 @@ class Graphiti:
except Exception as e:
raise e
async def build_communities(self, group_ids: list[str] | None = None) -> list[CommunityNode]:
async def build_communities(
self, group_ids: list[str] | None = None
) -> tuple[list[CommunityNode], list[CommunityEdge]]:
"""
Use a community clustering algorithm to find communities of nodes. Create community nodes summarising
the content of these communities.
@ -846,7 +861,7 @@ class Graphiti:
max_coroutines=self.max_coroutines,
)
return community_nodes
return community_nodes, community_edges
async def search(
self,
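
When several coroutines each return a pair of lists, a gather call yields one pair per coroutine, and the pairs usually need flattening before they fit fields such as `communities` and `community_edges`. The sketch below uses `asyncio.gather` and a stub in place of the project's own helpers; `update_community_stub`, its return shape, and `collect` are assumptions made for illustration.

```python
import asyncio


async def update_community_stub(node: str) -> tuple[list[str], list[str]]:
    # Hypothetical stand-in: (community nodes, community edges) for one entity.
    return [f'community-of-{node}'], [f'{node}-edge']


async def collect(nodes: list[str]) -> tuple[list[str], list[str]]:
    """Run one update per node and flatten the per-node pairs."""
    # gather preserves input order and returns one (nodes, edges) pair per call.
    results = await asyncio.gather(*[update_community_stub(n) for n in nodes])
    communities = [c for cs, _ in results for c in cs]
    community_edges = [e for _, es in results for e in es]
    return communities, community_edges
```

Unpacking the gathered list directly into two names only works when exactly two coroutines were scheduled, so the explicit flattening is the more general form.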

@ -28,6 +28,7 @@ from numpy._typing import NDArray
from pydantic import BaseModel
from typing_extensions import LiteralString
from graphiti_core.driver.driver import GraphProvider
from graphiti_core.errors import GroupIdValidationError
load_dotenv()
@ -52,12 +53,12 @@ def parse_db_date(neo_date: neo4j_time.DateTime | str | None) -> datetime | None
)
def get_default_group_id(db_type: str) -> str:
def get_default_group_id(provider: GraphProvider) -> str:
"""
This function differentiates the default group id based on the database type.
For most databases, the default group id is an empty string, while there are database types that require a specific default group id.
"""
if db_type == 'falkordb':
if provider == GraphProvider.FALKORDB:
return '_'
else:
return ''

@ -14,43 +14,117 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
from graphiti_core.driver.driver import GraphProvider
EPISODIC_EDGE_SAVE = """
MATCH (episode:Episodic {uuid: $episode_uuid})
MATCH (node:Entity {uuid: $entity_uuid})
MERGE (episode)-[r:MENTIONS {uuid: $uuid}]->(node)
SET r = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
RETURN r.uuid AS uuid"""
MATCH (episode:Episodic {uuid: $episode_uuid})
MATCH (node:Entity {uuid: $entity_uuid})
MERGE (episode)-[e:MENTIONS {uuid: $uuid}]->(node)
SET e = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
RETURN e.uuid AS uuid
"""
EPISODIC_EDGE_SAVE_BULK = """
UNWIND $episodic_edges AS edge
MATCH (episode:Episodic {uuid: edge.source_node_uuid})
MATCH (node:Entity {uuid: edge.target_node_uuid})
MERGE (episode)-[r:MENTIONS {uuid: edge.uuid}]->(node)
SET r = {uuid: edge.uuid, group_id: edge.group_id, created_at: edge.created_at}
RETURN r.uuid AS uuid
MATCH (episode:Episodic {uuid: edge.source_node_uuid})
MATCH (node:Entity {uuid: edge.target_node_uuid})
MERGE (episode)-[e:MENTIONS {uuid: edge.uuid}]->(node)
SET e = {uuid: edge.uuid, group_id: edge.group_id, created_at: edge.created_at}
RETURN e.uuid AS uuid
"""
ENTITY_EDGE_SAVE = """
MATCH (source:Entity {uuid: $edge_data.source_uuid})
MATCH (target:Entity {uuid: $edge_data.target_uuid})
MERGE (source)-[r:RELATES_TO {uuid: $edge_data.uuid}]->(target)
SET r = $edge_data
WITH r CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", $edge_data.fact_embedding)
RETURN r.uuid AS uuid"""
ENTITY_EDGE_SAVE_BULK = """
UNWIND $entity_edges AS edge
MATCH (source:Entity {uuid: edge.source_node_uuid})
MATCH (target:Entity {uuid: edge.target_node_uuid})
MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target)
SET r = edge
WITH r, edge CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", edge.fact_embedding)
RETURN edge.uuid AS uuid
EPISODIC_EDGE_RETURN = """
e.uuid AS uuid,
e.group_id AS group_id,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
e.created_at AS created_at
"""
COMMUNITY_EDGE_SAVE = """
MATCH (community:Community {uuid: $community_uuid})
MATCH (node:Entity | Community {uuid: $entity_uuid})
MERGE (community)-[r:HAS_MEMBER {uuid: $uuid}]->(node)
SET r = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
RETURN r.uuid AS uuid"""
def get_entity_edge_save_query(provider: GraphProvider) -> str:
if provider == GraphProvider.FALKORDB:
return """
MATCH (source:Entity {uuid: $edge_data.source_uuid})
MATCH (target:Entity {uuid: $edge_data.target_uuid})
MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
SET e = $edge_data
RETURN e.uuid AS uuid
"""
return """
MATCH (source:Entity {uuid: $edge_data.source_uuid})
MATCH (target:Entity {uuid: $edge_data.target_uuid})
MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
SET e = $edge_data
WITH e CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", $edge_data.fact_embedding)
RETURN e.uuid AS uuid
"""
def get_entity_edge_save_bulk_query(provider: GraphProvider) -> str:
if provider == GraphProvider.FALKORDB:
return """
UNWIND $entity_edges AS edge
MATCH (source:Entity {uuid: edge.source_node_uuid})
MATCH (target:Entity {uuid: edge.target_node_uuid})
MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target)
SET r = {uuid: edge.uuid, name: edge.name, group_id: edge.group_id, fact: edge.fact, episodes: edge.episodes,
created_at: edge.created_at, expired_at: edge.expired_at, valid_at: edge.valid_at, invalid_at: edge.invalid_at, fact_embedding: vecf32(edge.fact_embedding)}
WITH r, edge
RETURN edge.uuid AS uuid
"""
return """
UNWIND $entity_edges AS edge
MATCH (source:Entity {uuid: edge.source_node_uuid})
MATCH (target:Entity {uuid: edge.target_node_uuid})
MERGE (source)-[e:RELATES_TO {uuid: edge.uuid}]->(target)
SET e = edge
WITH e, edge CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", edge.fact_embedding)
RETURN edge.uuid AS uuid
"""
ENTITY_EDGE_RETURN = """
e.uuid AS uuid,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
e.group_id AS group_id,
e.name AS name,
e.fact AS fact,
e.episodes AS episodes,
e.created_at AS created_at,
e.expired_at AS expired_at,
e.valid_at AS valid_at,
e.invalid_at AS invalid_at,
properties(e) AS attributes
"""
def get_community_edge_save_query(provider: GraphProvider) -> str:
if provider == GraphProvider.FALKORDB:
return """
MATCH (community:Community {uuid: $community_uuid})
MATCH (node {uuid: $entity_uuid})
MERGE (community)-[e:HAS_MEMBER {uuid: $uuid}]->(node)
SET e = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
RETURN e.uuid AS uuid
"""
return """
MATCH (community:Community {uuid: $community_uuid})
MATCH (node:Entity | Community {uuid: $entity_uuid})
MERGE (community)-[e:HAS_MEMBER {uuid: $uuid}]->(node)
SET e = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
RETURN e.uuid AS uuid
"""
COMMUNITY_EDGE_RETURN = """
e.uuid AS uuid,
e.group_id AS group_id,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
e.created_at AS created_at
"""

@ -14,39 +14,120 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import Any
from graphiti_core.driver.driver import GraphProvider
EPISODIC_NODE_SAVE = """
MERGE (n:Episodic {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
RETURN n.uuid AS uuid"""
MERGE (n:Episodic {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
RETURN n.uuid AS uuid
"""
EPISODIC_NODE_SAVE_BULK = """
UNWIND $episodes AS episode
MERGE (n:Episodic {uuid: episode.uuid})
SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description,
source: episode.source, content: episode.content,
SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description,
source: episode.source, content: episode.content,
entity_edges: episode.entity_edges, created_at: episode.created_at, valid_at: episode.valid_at}
RETURN n.uuid AS uuid
"""
ENTITY_NODE_SAVE = """
MERGE (n:Entity {uuid: $entity_data.uuid})
SET n:$($labels)
SET n = $entity_data
WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $entity_data.name_embedding)
RETURN n.uuid AS uuid"""
ENTITY_NODE_SAVE_BULK = """
UNWIND $nodes AS node
MERGE (n:Entity {uuid: node.uuid})
SET n:$(node.labels)
SET n = node
WITH n, node CALL db.create.setNodeVectorProperty(n, "name_embedding", node.name_embedding)
RETURN n.uuid AS uuid
EPISODIC_NODE_RETURN = """
e.content AS content,
e.created_at AS created_at,
e.valid_at AS valid_at,
e.uuid AS uuid,
e.name AS name,
e.group_id AS group_id,
e.source_description AS source_description,
e.source AS source,
e.entity_edges AS entity_edges
"""
COMMUNITY_NODE_SAVE = """
def get_entity_node_save_query(provider: GraphProvider, labels: str) -> str:
if provider == GraphProvider.FALKORDB:
return f"""
MERGE (n:Entity {{uuid: $entity_data.uuid}})
SET n:{labels}
SET n = $entity_data
RETURN n.uuid AS uuid
"""
return f"""
MERGE (n:Entity {{uuid: $entity_data.uuid}})
SET n:{labels}
SET n = $entity_data
WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $entity_data.name_embedding)
RETURN n.uuid AS uuid
"""
def get_entity_node_save_bulk_query(provider: GraphProvider, nodes: list[dict]) -> str | Any:
if provider == GraphProvider.FALKORDB:
queries = []
for node in nodes:
for label in node['labels']:
queries.append(
(
f"""
UNWIND $nodes AS node
MERGE (n:Entity {{uuid: node.uuid}})
SET n:{label}
SET n = node
WITH n, node
SET n.name_embedding = vecf32(node.name_embedding)
RETURN n.uuid AS uuid
""",
{'nodes': [node]},
)
)
return queries
return """
UNWIND $nodes AS node
MERGE (n:Entity {uuid: node.uuid})
SET n:$(node.labels)
SET n = node
WITH n, node CALL db.create.setNodeVectorProperty(n, "name_embedding", node.name_embedding)
RETURN n.uuid AS uuid
"""
ENTITY_NODE_RETURN = """
n.uuid AS uuid,
n.name AS name,
n.group_id AS group_id,
n.created_at AS created_at,
n.summary AS summary,
labels(n) AS labels,
properties(n) AS attributes
"""
def get_community_node_save_query(provider: GraphProvider) -> str:
if provider == GraphProvider.FALKORDB:
return """
MERGE (n:Community {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at, name_embedding: $name_embedding}
RETURN n.uuid AS uuid
"""
return """
MERGE (n:Community {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
RETURN n.uuid AS uuid"""
RETURN n.uuid AS uuid
"""
COMMUNITY_NODE_RETURN = """
n.uuid AS uuid,
n.name AS name,
n.name_embedding AS name_embedding,
n.group_id AS group_id,
n.summary AS summary,
n.created_at AS created_at
"""
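
Note on the hunk above (illustrative sketch, not part of the commit): the module-level save constants are replaced by functions that choose a query per provider. The snippet below shows that pattern in isolation. The `GraphProvider` enum here is a local stand-in for the one imported from `graphiti_core.driver.driver`, and `build_entity_save_query` is a hypothetical name; the query text follows the Entity save query in the hunk.

```python
from enum import Enum


class GraphProvider(Enum):
    # Local stand-in (assumption) for graphiti_core.driver.driver.GraphProvider.
    NEO4J = 'neo4j'
    FALKORDB = 'falkordb'


def build_entity_save_query(provider: GraphProvider, labels: list[str]) -> str:
    """Return an Entity MERGE query with the labels written into the query text."""
    # The labels are joined with ':' and interpolated, mirroring
    # `':'.join(self.labels + ['Entity'])` in the diff. Only trusted label
    # names should be passed, since they become part of the query string.
    label_clause = ':'.join(labels + ['Entity'])
    query = (
        f'MERGE (n:Entity {{uuid: $entity_data.uuid}})\n'
        f'SET n:{label_clause}\n'
        'SET n = $entity_data\n'
    )
    if provider != GraphProvider.FALKORDB:
        # Only the non-FalkorDB branch in the hunk calls the vector property procedure.
        query += (
            'WITH n CALL db.create.setNodeVectorProperty('
            'n, "name_embedding", $entity_data.name_embedding)\n'
        )
    return query + 'RETURN n.uuid AS uuid'


falkor_query = build_entity_save_query(GraphProvider.FALKORDB, ['Person'])
neo4j_query = build_entity_save_query(GraphProvider.NEO4J, ['Person'])
```

Keeping the provider check inside one builder means call sites only pass `driver.provider` and never branch on the backend themselves.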

@ -25,29 +25,22 @@ from uuid import uuid4
from pydantic import BaseModel, Field
from typing_extensions import LiteralString
from graphiti_core.driver.driver import GraphDriver
from graphiti_core.driver.driver import GraphDriver, GraphProvider
from graphiti_core.embedder import EmbedderClient
from graphiti_core.errors import NodeNotFoundError
from graphiti_core.helpers import parse_db_date
from graphiti_core.models.nodes.node_db_queries import (
COMMUNITY_NODE_SAVE,
ENTITY_NODE_SAVE,
COMMUNITY_NODE_RETURN,
ENTITY_NODE_RETURN,
EPISODIC_NODE_RETURN,
EPISODIC_NODE_SAVE,
get_community_node_save_query,
get_entity_node_save_query,
)
from graphiti_core.utils.datetime_utils import utc_now
logger = logging.getLogger(__name__)
ENTITY_NODE_RETURN: LiteralString = """
RETURN
n.uuid As uuid,
n.name AS name,
n.group_id AS group_id,
n.created_at AS created_at,
n.summary AS summary,
labels(n) AS labels,
properties(n) AS attributes"""
class EpisodeType(Enum):
"""
@ -96,18 +89,26 @@ class Node(BaseModel, ABC):
async def save(self, driver: GraphDriver): ...
async def delete(self, driver: GraphDriver):
result = await driver.execute_query(
"""
MATCH (n:Entity|Episodic|Community {uuid: $uuid})
DETACH DELETE n
""",
uuid=self.uuid,
)
if driver.provider == GraphProvider.FALKORDB:
for label in ['Entity', 'Episodic', 'Community']:
await driver.execute_query(
f"""
MATCH (n:{label} {{uuid: $uuid}})
DETACH DELETE n
""",
uuid=self.uuid,
)
else:
await driver.execute_query(
"""
MATCH (n:Entity|Episodic|Community {uuid: $uuid})
DETACH DELETE n
""",
uuid=self.uuid,
)
logger.debug(f'Deleted Node: {self.uuid}')
return result
def __hash__(self):
return hash(self.uuid)
@ -118,15 +119,23 @@ class Node(BaseModel, ABC):
@classmethod
async def delete_by_group_id(cls, driver: GraphDriver, group_id: str):
await driver.execute_query(
"""
MATCH (n:Entity|Episodic|Community {group_id: $group_id})
DETACH DELETE n
""",
group_id=group_id,
)
return 'SUCCESS'
if driver.provider == GraphProvider.FALKORDB:
for label in ['Entity', 'Episodic', 'Community']:
await driver.execute_query(
f"""
MATCH (n:{label} {{group_id: $group_id}})
DETACH DELETE n
""",
group_id=group_id,
)
else:
await driver.execute_query(
"""
MATCH (n:Entity|Episodic|Community {group_id: $group_id})
DETACH DELETE n
""",
group_id=group_id,
)
@classmethod
async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...
@ -169,17 +178,10 @@ class EpisodicNode(Node):
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
records, _, _ = await driver.execute_query(
"""
MATCH (e:Episodic {uuid: $uuid})
RETURN e.content AS content,
e.created_at AS created_at,
e.valid_at AS valid_at,
e.uuid AS uuid,
e.name AS name,
e.group_id AS group_id,
e.source_description AS source_description,
e.source AS source,
e.entity_edges AS entity_edges
""",
MATCH (e:Episodic {uuid: $uuid})
RETURN
"""
+ EPISODIC_NODE_RETURN,
uuid=uuid,
routing_='r',
)
@ -195,18 +197,11 @@ class EpisodicNode(Node):
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
records, _, _ = await driver.execute_query(
"""
MATCH (e:Episodic) WHERE e.uuid IN $uuids
MATCH (e:Episodic)
WHERE e.uuid IN $uuids
RETURN DISTINCT
e.content AS content,
e.created_at AS created_at,
e.valid_at AS valid_at,
e.uuid AS uuid,
e.name AS name,
e.group_id AS group_id,
e.source_description AS source_description,
e.source AS source,
e.entity_edges AS entity_edges
""",
"""
+ EPISODIC_NODE_RETURN,
uuids=uuids,
routing_='r',
)
@ -228,22 +223,17 @@ class EpisodicNode(Node):
records, _, _ = await driver.execute_query(
"""
MATCH (e:Episodic) WHERE e.group_id IN $group_ids
"""
MATCH (e:Episodic)
WHERE e.group_id IN $group_ids
"""
+ cursor_query
+ """
RETURN DISTINCT
e.content AS content,
e.created_at AS created_at,
e.valid_at AS valid_at,
e.uuid AS uuid,
e.name AS name,
e.group_id AS group_id,
e.source_description AS source_description,
e.source AS source,
e.entity_edges AS entity_edges
ORDER BY e.uuid DESC
"""
"""
+ EPISODIC_NODE_RETURN
+ """
ORDER BY uuid DESC
"""
+ limit_query,
group_ids=group_ids,
uuid=uuid_cursor,
@ -259,18 +249,10 @@ class EpisodicNode(Node):
async def get_by_entity_node_uuid(cls, driver: GraphDriver, entity_node_uuid: str):
records, _, _ = await driver.execute_query(
"""
MATCH (e:Episodic)-[r:MENTIONS]->(n:Entity {uuid: $entity_node_uuid})
MATCH (e:Episodic)-[r:MENTIONS]->(n:Entity {uuid: $entity_node_uuid})
RETURN DISTINCT
e.content AS content,
e.created_at AS created_at,
e.valid_at AS valid_at,
e.uuid AS uuid,
e.name AS name,
e.group_id AS group_id,
e.source_description AS source_description,
e.source AS source,
e.entity_edges AS entity_edges
""",
"""
+ EPISODIC_NODE_RETURN,
entity_node_uuid=entity_node_uuid,
routing_='r',
)
@ -297,11 +279,14 @@ class EntityNode(Node):
return self.name_embedding
async def load_name_embedding(self, driver: GraphDriver):
query: LiteralString = """
records, _, _ = await driver.execute_query(
"""
MATCH (n:Entity {uuid: $uuid})
RETURN n.name_embedding AS name_embedding
"""
records, _, _ = await driver.execute_query(query, uuid=self.uuid, routing_='r')
""",
uuid=self.uuid,
routing_='r',
)
if len(records) == 0:
raise NodeNotFoundError(self.uuid)
@ -317,12 +302,12 @@ class EntityNode(Node):
'summary': self.summary,
'created_at': self.created_at,
}
entity_data.update(self.attributes or {})
labels = ':'.join(self.labels + ['Entity'])
result = await driver.execute_query(
ENTITY_NODE_SAVE,
labels=self.labels + ['Entity'],
get_entity_node_save_query(driver.provider, labels),
entity_data=entity_data,
)
@ -332,14 +317,12 @@ class EntityNode(Node):
@classmethod
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
query = (
"""
MATCH (n:Entity {uuid: $uuid})
"""
+ ENTITY_NODE_RETURN
)
records, _, _ = await driver.execute_query(
query,
"""
MATCH (n:Entity {uuid: $uuid})
RETURN
"""
+ ENTITY_NODE_RETURN,
uuid=uuid,
routing_='r',
)
@ -355,8 +338,10 @@ class EntityNode(Node):
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
records, _, _ = await driver.execute_query(
"""
MATCH (n:Entity) WHERE n.uuid IN $uuids
"""
MATCH (n:Entity)
WHERE n.uuid IN $uuids
RETURN
"""
+ ENTITY_NODE_RETURN,
uuids=uuids,
routing_='r',
@ -379,22 +364,26 @@ class EntityNode(Node):
limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
with_embeddings_query: LiteralString = (
""",
n.name_embedding AS name_embedding
"""
n.name_embedding AS name_embedding
"""
if with_embeddings
else ''
)
records, _, _ = await driver.execute_query(
"""
MATCH (n:Entity) WHERE n.group_id IN $group_ids
"""
MATCH (n:Entity)
WHERE n.group_id IN $group_ids
"""
+ cursor_query
+ """
RETURN
"""
+ ENTITY_NODE_RETURN
+ with_embeddings_query
+ """
ORDER BY n.uuid DESC
"""
ORDER BY n.uuid DESC
"""
+ limit_query,
group_ids=group_ids,
uuid=uuid_cursor,
@ -413,7 +402,7 @@ class CommunityNode(Node):
async def save(self, driver: GraphDriver):
result = await driver.execute_query(
COMMUNITY_NODE_SAVE,
get_community_node_save_query(driver.provider),
uuid=self.uuid,
name=self.name,
group_id=self.group_id,
@ -436,11 +425,14 @@ class CommunityNode(Node):
return self.name_embedding
async def load_name_embedding(self, driver: GraphDriver):
query: LiteralString = """
records, _, _ = await driver.execute_query(
"""
MATCH (c:Community {uuid: $uuid})
RETURN c.name_embedding AS name_embedding
"""
records, _, _ = await driver.execute_query(query, uuid=self.uuid, routing_='r')
""",
uuid=self.uuid,
routing_='r',
)
if len(records) == 0:
raise NodeNotFoundError(self.uuid)
@ -451,14 +443,10 @@ class CommunityNode(Node):
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
records, _, _ = await driver.execute_query(
"""
MATCH (n:Community {uuid: $uuid})
RETURN
n.uuid As uuid,
n.name AS name,
n.group_id AS group_id,
n.created_at AS created_at,
n.summary AS summary
""",
MATCH (n:Community {uuid: $uuid})
RETURN
"""
+ COMMUNITY_NODE_RETURN,
uuid=uuid,
routing_='r',
)
@ -474,14 +462,11 @@ class CommunityNode(Node):
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
records, _, _ = await driver.execute_query(
"""
MATCH (n:Community) WHERE n.uuid IN $uuids
RETURN
n.uuid As uuid,
n.name AS name,
n.group_id AS group_id,
n.created_at AS created_at,
n.summary AS summary
""",
MATCH (n:Community)
WHERE n.uuid IN $uuids
RETURN
"""
+ COMMUNITY_NODE_RETURN,
uuids=uuids,
routing_='r',
)
@ -503,18 +488,17 @@ class CommunityNode(Node):
records, _, _ = await driver.execute_query(
"""
MATCH (n:Community) WHERE n.group_id IN $group_ids
"""
MATCH (n:Community)
WHERE n.group_id IN $group_ids
"""
+ cursor_query
+ """
RETURN
n.uuid As uuid,
n.name AS name,
n.group_id AS group_id,
n.created_at AS created_at,
n.summary AS summary
ORDER BY n.uuid DESC
"""
RETURN
"""
+ COMMUNITY_NODE_RETURN
+ """
ORDER BY n.uuid DESC
"""
+ limit_query,
group_ids=group_ids,
uuid=uuid_cursor,
@ -586,6 +570,7 @@ def get_community_node_from_record(record: Any) -> CommunityNode:
async def create_entity_node_embeddings(embedder: EmbedderClient, nodes: list[EntityNode]):
if not nodes: # Handle empty list case
return
name_embeddings = await embedder.create_batch([node.name for node in nodes])
for node, name_embedding in zip(nodes, name_embeddings, strict=True):
node.name_embedding = name_embedding
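
Note on the hunks above (illustrative sketch, not part of the commit): the repeated column lists are replaced by shared fragments such as `EPISODIC_NODE_RETURN`, with each caller supplying its own `RETURN` or `RETURN DISTINCT` keyword and any trailing clauses. The snippet below reproduces that composition with a shortened fragment; `episodes_by_group_query` is a hypothetical helper name used only for this example.

```python
# Shortened stand-in for the EPISODIC_NODE_RETURN fragment in the diff.
# The fragment deliberately omits the RETURN keyword so callers can
# prefix it with RETURN or RETURN DISTINCT.
EPISODIC_NODE_RETURN = """
    e.uuid AS uuid,
    e.name AS name,
    e.group_id AS group_id
"""


def episodes_by_group_query(with_limit: bool) -> str:
    """Assemble a query from a MATCH prefix, the shared fragment, and a suffix."""
    limit_query = 'LIMIT $limit' if with_limit else ''
    return (
        """
        MATCH (e:Episodic)
        WHERE e.group_id IN $group_ids
        RETURN DISTINCT
        """
        + EPISODIC_NODE_RETURN
        + """
        ORDER BY uuid DESC
        """
        + limit_query
    )


query = episodes_by_group_query(with_limit=True)
unlimited_query = episodes_by_group_query(with_limit=False)
```

Because the fragment aliases every column, clauses after it can refer to the alias (`uuid`) rather than the node property, as the reworked `ORDER BY uuid DESC` line does.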

@ -72,7 +72,7 @@ def edge_search_filter_query_constructor(
if filters.edge_types is not None:
edge_types = filters.edge_types
edge_types_filter = '\nAND r.name in $edge_types'
edge_types_filter = '\nAND e.name in $edge_types'
filter_query += edge_types_filter
filter_params['edge_types'] = edge_types
@ -88,7 +88,7 @@ def edge_search_filter_query_constructor(
filter_params['valid_at_' + str(j)] = date_filter.date
and_filters = [
'(r.valid_at ' + date_filter.comparison_operator.value + f' $valid_at_{j})'
'(e.valid_at ' + date_filter.comparison_operator.value + f' $valid_at_{j})'
for j, date_filter in enumerate(or_list)
]
and_filter_query = ''
@ -113,7 +113,7 @@ def edge_search_filter_query_constructor(
filter_params['invalid_at_' + str(j)] = date_filter.date
and_filters = [
'(r.invalid_at ' + date_filter.comparison_operator.value + f' $invalid_at_{j})'
'(e.invalid_at ' + date_filter.comparison_operator.value + f' $invalid_at_{j})'
for j, date_filter in enumerate(or_list)
]
and_filter_query = ''
@ -138,7 +138,7 @@ def edge_search_filter_query_constructor(
filter_params['created_at_' + str(j)] = date_filter.date
and_filters = [
'(r.created_at ' + date_filter.comparison_operator.value + f' $created_at_{j})'
'(e.created_at ' + date_filter.comparison_operator.value + f' $created_at_{j})'
for j, date_filter in enumerate(or_list)
]
and_filter_query = ''
@ -163,7 +163,7 @@ def edge_search_filter_query_constructor(
filter_params['expired_at_' + str(j)] = date_filter.date
and_filters = [
'(r.expired_at ' + date_filter.comparison_operator.value + f' $expired_at_{j})'
'(e.expired_at ' + date_filter.comparison_operator.value + f' $expired_at_{j})'
for j, date_filter in enumerate(or_list)
]
and_filter_query = ''
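
Note on the hunks above (illustrative sketch, not part of the commit): every date filter now refers to the relationship through the alias `e` instead of `r`, so the alias used in the filter text has to match the alias bound in the enclosing `MATCH`. The snippet below shows one way such a clause and its parameters could be built. `DateFilter` and `build_and_clause` are hypothetical simplifications, and joining the parts with ` AND ` is an assumption, since the hunks are cut off before the join.

```python
from dataclasses import dataclass
from datetime import datetime, timezone


@dataclass
class DateFilter:
    # Hypothetical simplified filter: a comparison operator and a date.
    operator: str
    date: datetime


def build_and_clause(
    field: str, filters: list[DateFilter], alias: str = 'e'
) -> tuple[str, dict[str, datetime]]:
    """Build '(e.<field> <op> $<field>_<j>)' terms plus their parameter map."""
    params: dict[str, datetime] = {}
    parts: list[str] = []
    for j, date_filter in enumerate(filters):
        # Each term gets its own numbered parameter, as in the diff.
        params[f'{field}_{j}'] = date_filter.date
        parts.append(f'({alias}.{field} {date_filter.operator} ${field}_{j})')
    # Assumption: terms inside one group are combined with AND.
    return ' AND '.join(parts), params


start = datetime(2024, 1, 1, tzinfo=timezone.utc)
end = datetime(2025, 1, 1, tzinfo=timezone.utc)
clause, params = build_and_clause(
    'valid_at', [DateFilter('>=', start), DateFilter('<', end)]
)
```

Passing the alias as an argument keeps the filter text and the `MATCH` pattern from drifting apart if the alias is renamed again.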

@ -23,7 +23,7 @@ import numpy as np
from numpy._typing import NDArray
from typing_extensions import LiteralString
from graphiti_core.driver.driver import GraphDriver
from graphiti_core.driver.driver import GraphDriver, GraphProvider
from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
from graphiti_core.graph_queries import (
get_nodes_query,
@ -36,6 +36,8 @@ from graphiti_core.helpers import (
normalize_l2,
semaphore_gather,
)
from graphiti_core.models.edges.edge_db_queries import ENTITY_EDGE_RETURN
from graphiti_core.models.nodes.node_db_queries import COMMUNITY_NODE_RETURN, EPISODIC_NODE_RETURN
from graphiti_core.nodes import (
ENTITY_NODE_RETURN,
CommunityNode,
@ -100,20 +102,13 @@ async def get_mentioned_nodes(
) -> list[EntityNode]:
episode_uuids = [episode.uuid for episode in episodes]
query = """
MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids
RETURN DISTINCT
n.uuid As uuid,
n.group_id AS group_id,
n.name AS name,
n.created_at AS created_at,
n.summary AS summary,
labels(n) AS labels,
properties(n) AS attributes
"""
records, _, _ = await driver.execute_query(
query,
"""
MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity)
WHERE episode.uuid IN $uuids
RETURN DISTINCT
"""
+ ENTITY_NODE_RETURN,
uuids=episode_uuids,
routing_='r',
)
@ -128,18 +123,13 @@ async def get_communities_by_nodes(
) -> list[CommunityNode]:
node_uuids = [node.uuid for node in nodes]
query = """
MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity) WHERE n.uuid IN $uuids
RETURN DISTINCT
c.uuid As uuid,
c.group_id AS group_id,
c.name AS name,
c.created_at AS created_at,
c.summary AS summary
"""
records, _, _ = await driver.execute_query(
query,
"""
MATCH (n:Community)-[:HAS_MEMBER]->(m:Entity)
WHERE m.uuid IN $uuids
RETURN DISTINCT
"""
+ COMMUNITY_NODE_RETURN,
uuids=node_uuids,
routing_='r',
)
@ -164,38 +154,30 @@ async def edge_fulltext_search(
filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
query = (
get_relationships_query('edge_name_and_fact', db_type=driver.provider)
get_relationships_query('edge_name_and_fact', provider=driver.provider)
+ """
YIELD relationship AS rel, score
MATCH (n:Entity)-[r:RELATES_TO {uuid: rel.uuid}]->(m:Entity)
WHERE r.group_id IN $group_ids """
MATCH (n:Entity)-[e:RELATES_TO {uuid: rel.uuid}]->(m:Entity)
WHERE e.group_id IN $group_ids """
+ filter_query
+ """
WITH r, score, startNode(r) AS n, endNode(r) AS m
WITH e, score, n, m
RETURN
r.uuid AS uuid,
r.group_id AS group_id,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
r.created_at AS created_at,
r.name AS name,
r.fact AS fact,
r.episodes AS episodes,
r.expired_at AS expired_at,
r.valid_at AS valid_at,
r.invalid_at AS invalid_at,
properties(r) AS attributes
ORDER BY score DESC LIMIT $limit
"""
+ ENTITY_EDGE_RETURN
+ """
ORDER BY score DESC
LIMIT $limit
"""
)
records, _, _ = await driver.execute_query(
query,
params=filter_params,
query=fuzzy_query,
group_ids=group_ids,
limit=limit,
routing_='r',
**filter_params,
)
edges = [get_entity_edge_from_record(record) for record in records]
@ -219,58 +201,47 @@ async def edge_similarity_search(
filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
query_params.update(filter_params)
group_filter_query: LiteralString = 'WHERE r.group_id IS NOT NULL'
group_filter_query: LiteralString = 'WHERE e.group_id IS NOT NULL'
if group_ids is not None:
group_filter_query += '\nAND r.group_id IN $group_ids'
group_filter_query += '\nAND e.group_id IN $group_ids'
query_params['group_ids'] = group_ids
query_params['source_node_uuid'] = source_node_uuid
query_params['target_node_uuid'] = target_node_uuid
if source_node_uuid is not None:
group_filter_query += '\nAND (n.uuid IN [$source_uuid, $target_uuid])'
query_params['source_uuid'] = source_node_uuid
group_filter_query += '\nAND (n.uuid = $source_uuid)'
if target_node_uuid is not None:
group_filter_query += '\nAND (m.uuid IN [$source_uuid, $target_uuid])'
query_params['target_uuid'] = target_node_uuid
group_filter_query += '\nAND (m.uuid = $target_uuid)'
query = (
RUNTIME_QUERY
+ """
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
"""
+ group_filter_query
+ filter_query
+ """
WITH DISTINCT r, """
+ get_vector_cosine_func_query('r.fact_embedding', '$search_vector', driver.provider)
WITH DISTINCT e, n, m, """
+ get_vector_cosine_func_query('e.fact_embedding', '$search_vector', driver.provider)
+ """ AS score
WHERE score > $min_score
RETURN
r.uuid AS uuid,
r.group_id AS group_id,
startNode(r).uuid AS source_node_uuid,
endNode(r).uuid AS target_node_uuid,
r.created_at AS created_at,
r.name AS name,
r.fact AS fact,
r.episodes AS episodes,
r.expired_at AS expired_at,
r.valid_at AS valid_at,
r.invalid_at AS invalid_at,
properties(r) AS attributes
"""
+ ENTITY_EDGE_RETURN
+ """
ORDER BY score DESC
LIMIT $limit
"""
)
records, header, _ = await driver.execute_query(
records, _, _ = await driver.execute_query(
query,
params=query_params,
search_vector=search_vector,
source_uuid=source_node_uuid,
target_uuid=target_node_uuid,
group_ids=group_ids,
limit=limit,
min_score=min_score,
routing_='r',
**query_params,
)
edges = [get_entity_edge_from_record(record) for record in records]
@ -293,41 +264,31 @@ async def edge_bfs_search(
filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
query = (
f"""
UNWIND $bfs_origin_node_uuids AS origin_uuid
MATCH path = (origin:Entity|Episodic {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(:Entity)
UNWIND relationships(path) AS rel
MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
WHERE e.uuid = rel.uuid
AND e.group_id IN $group_ids
"""
UNWIND $bfs_origin_node_uuids AS origin_uuid
MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
UNWIND relationships(path) AS rel
MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity)
WHERE r.uuid = rel.uuid
AND r.group_id IN $group_ids
"""
+ filter_query
+ """
RETURN DISTINCT
r.uuid AS uuid,
r.group_id AS group_id,
startNode(r).uuid AS source_node_uuid,
endNode(r).uuid AS target_node_uuid,
r.created_at AS created_at,
r.name AS name,
r.fact AS fact,
r.episodes AS episodes,
r.expired_at AS expired_at,
r.valid_at AS valid_at,
r.invalid_at AS invalid_at,
properties(r) AS attributes
LIMIT $limit
+ """
RETURN DISTINCT
"""
+ ENTITY_EDGE_RETURN
+ """
LIMIT $limit
"""
)
records, _, _ = await driver.execute_query(
query,
params=filter_params,
bfs_origin_node_uuids=bfs_origin_node_uuids,
depth=bfs_max_depth,
group_ids=group_ids,
limit=limit,
routing_='r',
**filter_params,
)
edges = [get_entity_edge_from_record(record) for record in records]
@ -352,23 +313,27 @@ async def node_fulltext_search(
get_nodes_query(driver.provider, 'node_name_and_summary', '$query')
+ """
YIELD node AS n, score
WITH n, score
LIMIT $limit
WHERE n:Entity AND n.group_id IN $group_ids
WHERE n:Entity AND n.group_id IN $group_ids
WITH n, score
LIMIT $limit
"""
+ filter_query
+ """
RETURN
"""
+ ENTITY_NODE_RETURN
+ """
ORDER BY score DESC
"""
)
records, header, _ = await driver.execute_query(
records, _, _ = await driver.execute_query(
query,
params=filter_params,
query=fuzzy_query,
group_ids=group_ids,
limit=limit,
routing_='r',
**filter_params,
)
nodes = [get_entity_node_from_record(record) for record in records]
@ -406,22 +371,23 @@ async def node_similarity_search(
WITH n, """
+ get_vector_cosine_func_query('n.name_embedding', '$search_vector', driver.provider)
+ """ AS score
WHERE score > $min_score"""
WHERE score > $min_score
RETURN
"""
+ ENTITY_NODE_RETURN
+ """
ORDER BY score DESC
LIMIT $limit
"""
"""
)
records, header, _ = await driver.execute_query(
records, _, _ = await driver.execute_query(
query,
params=query_params,
search_vector=search_vector,
group_ids=group_ids,
limit=limit,
min_score=min_score,
routing_='r',
**query_params,
)
nodes = [get_entity_node_from_record(record) for record in records]
@ -444,26 +410,29 @@ async def node_bfs_search(
filter_query, filter_params = node_search_filter_query_constructor(search_filter)
query = (
f"""
UNWIND $bfs_origin_node_uuids AS origin_uuid
MATCH (origin:Entity|Episodic {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(n:Entity)
WHERE n.group_id = origin.group_id
AND origin.group_id IN $group_ids
"""
UNWIND $bfs_origin_node_uuids AS origin_uuid
MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
WHERE n.group_id = origin.group_id
AND origin.group_id IN $group_ids
"""
+ filter_query
+ """
RETURN
"""
+ ENTITY_NODE_RETURN
+ """
LIMIT $limit
"""
)
records, _, _ = await driver.execute_query(
query,
params=filter_params,
bfs_origin_node_uuids=bfs_origin_node_uuids,
depth=bfs_max_depth,
group_ids=group_ids,
limit=limit,
routing_='r',
**filter_params,
)
nodes = [get_entity_node_from_record(record) for record in records]
@ -489,16 +458,10 @@ async def episode_fulltext_search(
MATCH (e:Episodic)
WHERE e.uuid = episode.uuid
AND e.group_id IN $group_ids
RETURN
e.content AS content,
e.created_at AS created_at,
e.valid_at AS valid_at,
e.uuid AS uuid,
e.name AS name,
e.group_id AS group_id,
e.source_description AS source_description,
e.source AS source,
e.entity_edges AS entity_edges
RETURN
"""
+ EPISODIC_NODE_RETURN
+ """
ORDER BY score DESC
LIMIT $limit
"""
@ -530,15 +493,12 @@ async def community_fulltext_search(
query = (
get_nodes_query(driver.provider, 'community_name', '$query')
+ """
YIELD node AS comm, score
WHERE comm.group_id IN $group_ids
YIELD node AS n, score
WHERE n.group_id IN $group_ids
RETURN
comm.uuid AS uuid,
comm.group_id AS group_id,
comm.name AS name,
comm.created_at AS created_at,
comm.summary AS summary,
comm.name_embedding AS name_embedding
"""
+ COMMUNITY_NODE_RETURN
+ """
ORDER BY score DESC
LIMIT $limit
"""
@ -568,39 +528,37 @@ async def community_similarity_search(
group_filter_query: LiteralString = ''
if group_ids is not None:
group_filter_query += 'WHERE comm.group_id IN $group_ids'
group_filter_query += 'WHERE n.group_id IN $group_ids'
query_params['group_ids'] = group_ids
query = (
RUNTIME_QUERY
+ """
MATCH (comm:Community)
"""
MATCH (n:Community)
"""
+ group_filter_query
+ """
WITH comm, """
+ get_vector_cosine_func_query('comm.name_embedding', '$search_vector', driver.provider)
WITH n,
"""
+ get_vector_cosine_func_query('n.name_embedding', '$search_vector', driver.provider)
+ """ AS score
WHERE score > $min_score
RETURN
comm.uuid As uuid,
comm.group_id AS group_id,
comm.name AS name,
comm.created_at AS created_at,
comm.summary AS summary,
comm.name_embedding AS name_embedding
ORDER BY score DESC
LIMIT $limit
WHERE score > $min_score
RETURN
"""
+ COMMUNITY_NODE_RETURN
+ """
ORDER BY score DESC
LIMIT $limit
"""
)
records, _, _ = await driver.execute_query(
query,
search_vector=search_vector,
group_ids=group_ids,
limit=limit,
min_score=min_score,
routing_='r',
**query_params,
)
communities = [get_community_node_from_record(record) for record in records]
@ -719,8 +677,8 @@ async def get_relevant_nodes(
WHERE m.group_id = $group_id
WITH node, top_vector_nodes, vector_node_uuids, collect(m) AS fulltext_nodes
WITH node,
top_vector_nodes,
WITH node,
top_vector_nodes,
[m IN fulltext_nodes WHERE NOT m.uuid IN vector_node_uuids] AS filtered_fulltext_nodes
WITH node, top_vector_nodes + filtered_fulltext_nodes AS combined_nodes
@ -728,10 +686,10 @@ async def get_relevant_nodes(
UNWIND combined_nodes AS combined_node
WITH node, collect(DISTINCT combined_node) AS deduped_nodes
RETURN
RETURN
node.uuid AS search_node_uuid,
[x IN deduped_nodes | {
uuid: x.uuid,
uuid: x.uuid,
name: x.name,
name_embedding: x.name_embedding,
group_id: x.group_id,
@ -755,12 +713,12 @@ async def get_relevant_nodes(
results, _, _ = await driver.execute_query(
query,
params=query_params,
nodes=query_nodes,
group_id=group_id,
limit=limit,
min_score=min_score,
routing_='r',
**query_params,
)
relevant_nodes_dict: dict[str, list[EntityNode]] = {
@ -825,11 +783,11 @@ async def get_relevant_edges(
results, _, _ = await driver.execute_query(
query,
params=query_params,
edges=[edge.model_dump() for edge in edges],
limit=limit,
min_score=min_score,
routing_='r',
**query_params,
)
relevant_edges_dict: dict[str, list[EntityEdge]] = {
@ -895,11 +853,11 @@ async def get_edge_invalidation_candidates(
results, _, _ = await driver.execute_query(
query,
params=query_params,
edges=[edge.model_dump() for edge in edges],
limit=limit,
min_score=min_score,
routing_='r',
**query_params,
)
invalidation_edges_dict: dict[str, list[EntityEdge]] = {
result['search_edge_uuid']: [
@ -943,18 +901,17 @@ async def node_distance_reranker(
scores: dict[str, float] = {center_node_uuid: 0.0}
# Find the shortest path to center node
query = """
results, header, _ = await driver.execute_query(
"""
UNWIND $node_uuids AS node_uuid
MATCH (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-(n:Entity {uuid: node_uuid})
RETURN 1 AS score, node_uuid AS uuid
"""
results, header, _ = await driver.execute_query(
query,
""",
node_uuids=filtered_uuids,
center_uuid=center_node_uuid,
routing_='r',
)
if driver.provider == 'falkordb':
if driver.provider == GraphProvider.FALKORDB:
results = [dict(zip(header, row, strict=True)) for row in results]
for result in results:
@ -987,13 +944,12 @@ async def episode_mentions_reranker(
scores: dict[str, float] = {}
# Find the shortest path to center node
query = """
UNWIND $node_uuids AS node_uuid
results, _, _ = await driver.execute_query(
"""
UNWIND $node_uuids AS node_uuid
MATCH (episode:Episodic)-[r:MENTIONS]->(n:Entity {uuid: node_uuid})
RETURN count(*) AS score, n.uuid AS uuid
"""
results, _, _ = await driver.execute_query(
query,
""",
node_uuids=sorted_uuids,
routing_='r',
)
@ -1053,15 +1009,16 @@ def maximal_marginal_relevance(
async def get_embeddings_for_nodes(
driver: GraphDriver, nodes: list[EntityNode]
) -> dict[str, list[float]]:
query: LiteralString = """MATCH (n:Entity)
WHERE n.uuid IN $node_uuids
RETURN DISTINCT
n.uuid AS uuid,
n.name_embedding AS name_embedding
"""
results, _, _ = await driver.execute_query(
query, node_uuids=[node.uuid for node in nodes], routing_='r'
"""
MATCH (n:Entity)
WHERE n.uuid IN $node_uuids
RETURN DISTINCT
n.uuid AS uuid,
n.name_embedding AS name_embedding
""",
node_uuids=[node.uuid for node in nodes],
routing_='r',
)
embeddings_dict: dict[str, list[float]] = {}
@ -1077,15 +1034,14 @@ async def get_embeddings_for_nodes(
async def get_embeddings_for_communities(
driver: GraphDriver, communities: list[CommunityNode]
) -> dict[str, list[float]]:
query: LiteralString = """MATCH (c:Community)
WHERE c.uuid IN $community_uuids
RETURN DISTINCT
c.uuid AS uuid,
c.name_embedding AS name_embedding
"""
results, _, _ = await driver.execute_query(
query,
"""
MATCH (c:Community)
WHERE c.uuid IN $community_uuids
RETURN DISTINCT
c.uuid AS uuid,
c.name_embedding AS name_embedding
""",
community_uuids=[community.uuid for community in communities],
routing_='r',
)
@ -1103,15 +1059,14 @@ async def get_embeddings_for_communities(
async def get_embeddings_for_edges(
driver: GraphDriver, edges: list[EntityEdge]
) -> dict[str, list[float]]:
query: LiteralString = """MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
WHERE e.uuid IN $edge_uuids
RETURN DISTINCT
e.uuid AS uuid,
e.fact_embedding AS fact_embedding
"""
results, _, _ = await driver.execute_query(
query,
"""
MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
WHERE e.uuid IN $edge_uuids
RETURN DISTINCT
e.uuid AS uuid,
e.fact_embedding AS fact_embedding
""",
edge_uuids=[edge.uuid for edge in edges],
routing_='r',
)
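
Note on the `node_bfs_search` and `edge_bfs_search` hunks above (illustrative sketch, not part of the commit): the traversal depth is written into the pattern as a literal (`*1..{bfs_max_depth}`) rather than passed as a query parameter, as the commit message describes. Once a value is interpolated into query text it should be checked first. The snippet below is a minimal sketch of that idea; `bfs_match_clause` and its validation are assumptions added for the example and do not appear in the diff.

```python
def bfs_match_clause(max_depth: int) -> str:
    """Return a MATCH clause with the depth bound written as a literal."""
    # bool is a subclass of int, so it is rejected explicitly.
    if isinstance(max_depth, bool) or not isinstance(max_depth, int) or max_depth < 1:
        raise ValueError('max_depth must be a positive integer')
    return (
        'MATCH (origin:Entity|Episodic {uuid: origin_uuid})'
        f'-[:RELATES_TO|MENTIONS*1..{max_depth}]->(n:Entity)'
    )


clause = bfs_match_clause(3)
```

The origin UUIDs and group IDs can still travel as ordinary query parameters; only the range bound needs to be a literal.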

@ -25,17 +25,15 @@ from typing_extensions import Any
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge, create_entity_edge_embeddings
from graphiti_core.embedder import EmbedderClient
from graphiti_core.graph_queries import (
get_entity_edge_save_bulk_query,
get_entity_node_save_bulk_query,
)
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.helpers import normalize_l2, semaphore_gather
from graphiti_core.models.edges.edge_db_queries import (
EPISODIC_EDGE_SAVE_BULK,
get_entity_edge_save_bulk_query,
)
from graphiti_core.models.nodes.node_db_queries import (
EPISODIC_NODE_SAVE_BULK,
get_entity_node_save_bulk_query,
)
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
from graphiti_core.utils.maintenance.edge_operations import (
@ -158,7 +156,7 @@ async def add_nodes_and_edges_bulk_tx(
edges.append(edge_data)
await tx.run(EPISODIC_NODE_SAVE_BULK, episodes=episodes)
entity_node_save_bulk = get_entity_node_save_bulk_query(nodes, driver.provider)
entity_node_save_bulk = get_entity_node_save_bulk_query(driver.provider, nodes)
await tx.run(entity_node_save_bulk, nodes=nodes)
await tx.run(
EPISODIC_EDGE_SAVE_BULK, episodic_edges=[edge.model_dump() for edge in episodic_edges]

@ -34,7 +34,7 @@ async def get_community_clusters(
group_id_values, _, _ = await driver.execute_query(
"""
MATCH (n:Entity WHERE n.group_id IS NOT NULL)
RETURN
RETURN
collect(DISTINCT n.group_id) AS group_ids
""",
)
@ -233,10 +233,10 @@ async def determine_entity_community(
"""
MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity {uuid: $entity_uuid})
RETURN
c.uuid As uuid,
c.uuid AS uuid,
c.name AS name,
c.group_id AS group_id,
c.created_at AS created_at,
c.created_at AS created_at,
c.summary AS summary
""",
entity_uuid=entity.uuid,
@ -250,10 +250,10 @@ async def determine_entity_community(
"""
MATCH (c:Community)-[:HAS_MEMBER]->(m:Entity)-[:RELATES_TO]-(n:Entity {uuid: $entity_uuid})
RETURN
c.uuid As uuid,
c.uuid AS uuid,
c.name AS name,
c.group_id AS group_id,
c.created_at AS created_at,
c.created_at AS created_at,
c.summary AS summary
""",
entity_uuid=entity.uuid,
@ -286,11 +286,11 @@ async def determine_entity_community(
async def update_community(
driver: GraphDriver, llm_client: LLMClient, embedder: EmbedderClient, entity: EntityNode
):
) -> tuple[list[CommunityNode], list[CommunityEdge]]:
community, is_new = await determine_entity_community(driver, entity)
if community is None:
return
return [], []
new_summary = await summarize_pair(llm_client, (entity.summary, community.summary))
new_name = await generate_summary_description(llm_client, new_summary)
@ -298,10 +298,14 @@ async def update_community(
community.summary = new_summary
community.name = new_name
community_edges = []
if is_new:
community_edge = (build_community_edges([entity], community, utc_now()))[0]
await community_edge.save(driver)
community_edges.append(community_edge)
await community.generate_name_embedding(embedder)
await community.save(driver)
return [community], community_edges
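With this change `update_community` returns a pair of lists on every path, so callers can unpack the result without a `None` check. The following is a minimal sketch of that contract only. `Community`, `Edge`, and `update_community_sketch` are hypothetical stand-ins, not graphiti_core API, and the summarization and persistence steps are left out.

```python
import asyncio
from dataclasses import dataclass
from typing import Optional


@dataclass
class Community:  # hypothetical stand-in for CommunityNode
    name: str


@dataclass
class Edge:  # hypothetical stand-in for CommunityEdge
    source: str
    target: str


async def update_community_sketch(
    community: Optional[Community], entity_name: str, is_new: bool
) -> tuple[list[Community], list[Edge]]:
    # No matching community: return empty lists rather than None,
    # so callers can always unpack the result.
    if community is None:
        return [], []
    edges: list[Edge] = []
    if is_new:
        # Only a newly joined community produces a new membership edge.
        edges.append(Edge(source=community.name, target=entity_name))
    return [community], edges


async def main() -> None:
    nodes, edges = await update_community_sketch(None, 'Alice', is_new=False)
    print(len(nodes), len(edges))
    nodes, edges = await update_community_sketch(Community('A'), 'Alice', is_new=True)
    print(len(nodes), len(edges))


asyncio.run(main())
```

Returning empty lists keeps the two branches type-compatible, which is what the new `-> tuple[list[CommunityNode], list[CommunityEdge]]` annotation in the hunk requires.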

@@ -15,14 +15,15 @@ limitations under the License.
"""
import logging
from datetime import datetime, timezone
from datetime import datetime
from typing_extensions import LiteralString
from graphiti_core.driver.driver import GraphDriver
from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices
from graphiti_core.helpers import parse_db_date, semaphore_gather
from graphiti_core.nodes import EpisodeType, EpisodicNode
from graphiti_core.helpers import semaphore_gather
from graphiti_core.models.nodes.node_db_queries import EPISODIC_NODE_RETURN
from graphiti_core.nodes import EpisodeType, EpisodicNode, get_episodic_node_from_record
EPISODE_WINDOW_LEN = 3
@@ -33,8 +34,8 @@ async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bo
if delete_existing:
records, _, _ = await driver.execute_query(
"""
SHOW INDEXES YIELD name
""",
SHOW INDEXES YIELD name
""",
)
index_names = [record['name'] for record in records]
await semaphore_gather(
@@ -108,19 +109,16 @@ async def retrieve_episodes(
query: LiteralString = (
"""
MATCH (e:Episodic) WHERE e.valid_at <= $reference_time
"""
MATCH (e:Episodic)
WHERE e.valid_at <= $reference_time
"""
+ group_id_filter
+ source_filter
+ """
RETURN e.content AS content,
e.created_at AS created_at,
e.valid_at AS valid_at,
e.uuid AS uuid,
e.group_id AS group_id,
e.name AS name,
e.source_description AS source_description,
e.source AS source
RETURN
"""
+ EPISODIC_NODE_RETURN
+ """
ORDER BY e.valid_at DESC
LIMIT $num_episodes
"""
@@ -133,18 +131,5 @@ async def retrieve_episodes(
group_ids=group_ids,
)
episodes = [
EpisodicNode(
content=record['content'],
created_at=parse_db_date(record['created_at'])
or datetime.min.replace(tzinfo=timezone.utc),
valid_at=parse_db_date(record['valid_at']) or datetime.min.replace(tzinfo=timezone.utc),
uuid=record['uuid'],
group_id=record['group_id'],
source=EpisodeType.from_str(record['source']),
name=record['name'],
source_description=record['source_description'],
)
for record in result
]
episodes = [get_episodic_node_from_record(record) for record in result]
return list(reversed(episodes)) # Return in chronological order
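The hunk above replaces the inline `EpisodicNode(...)` construction with a single call to `get_episodic_node_from_record`, whose body is not shown in this diff. As a rough illustration of the pattern, here is a sketch of a record-mapping helper. `Episode` and `episode_from_record` are hypothetical names, and the rows are made-up data rather than real driver output.

```python
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Any, Mapping


@dataclass
class Episode:  # hypothetical, much smaller than EpisodicNode
    uuid: str
    name: str
    valid_at: datetime


def episode_from_record(record: Mapping[str, Any]) -> Episode:
    # Fall back to a timezone-aware minimum date when the field is missing,
    # as the removed inline code did.
    valid_at = record.get('valid_at') or datetime.min.replace(tzinfo=timezone.utc)
    return Episode(uuid=record['uuid'], name=record['name'], valid_at=valid_at)


# Made-up rows in the order the query returns them (newest first).
rows = [
    {'uuid': 'b', 'name': 'second', 'valid_at': datetime(2024, 1, 2, tzinfo=timezone.utc)},
    {'uuid': 'a', 'name': 'first', 'valid_at': datetime(2024, 1, 1, tzinfo=timezone.utc)},
]

# Reverse so the caller receives episodes in chronological order.
episodes = list(reversed([episode_from_record(row) for row in rows]))
print([e.name for e in episodes])
```

Keeping the mapping in one helper means the `RETURN` clause and the record parsing can change together, without editing every call site.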

@@ -3,9 +3,9 @@ name = "graphiti-core"
description = "A temporal graph building library"
version = "0.18.0"
authors = [
{ "name" = "Paul Paliychuk", "email" = "paul@getzep.com" },
{ "name" = "Preston Rasmussen", "email" = "preston@getzep.com" },
{ "name" = "Daniel Chalef", "email" = "daniel@getzep.com" },
{ name = "Paul Paliychuk", email = "paul@getzep.com" },
{ name = "Preston Rasmussen", email = "preston@getzep.com" },
{ name = "Daniel Chalef", email = "daniel@getzep.com" },
]
readme = "README.md"
license = "Apache-2.0"

@@ -21,6 +21,8 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from graphiti_core.driver.driver import GraphProvider
try:
from graphiti_core.driver.falkordb_driver import FalkorDriver, FalkorDriverSession
@@ -48,7 +50,7 @@ class TestFalkorDriver:
driver = FalkorDriver(
host='test-host', port='1234', username='test-user', password='test-pass'
)
assert driver.provider == 'falkordb'
assert driver.provider == GraphProvider.FALKORDB
mock_falkor_db.assert_called_once_with(
host='test-host', port='1234', username='test-user', password='test-pass'
)
@@ -59,14 +61,14 @@ class TestFalkorDriver:
with patch('graphiti_core.driver.falkordb_driver.FalkorDB') as mock_falkor_db_class:
mock_falkor_db = MagicMock()
driver = FalkorDriver(falkor_db=mock_falkor_db)
assert driver.provider == 'falkordb'
assert driver.provider == GraphProvider.FALKORDB
assert driver.client is mock_falkor_db
mock_falkor_db_class.assert_not_called()
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
def test_provider(self):
"""Test driver provider identification."""
assert self.driver.provider == 'falkordb'
assert self.driver.provider == GraphProvider.FALKORDB
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
def test_get_graph_with_name(self):
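The assertions in this file now compare `driver.provider` against `GraphProvider.FALKORDB` instead of the literal `'falkordb'`. How `GraphProvider` is defined is not shown in this diff. The sketch below uses two hypothetical enums, `Provider` and `StrProvider`, to show why comparing against the member is the safer form: it holds whether or not the enum mixes in `str`.

```python
from enum import Enum


class Provider(Enum):  # hypothetical plain enum, not the real GraphProvider
    NEO4J = 'neo4j'
    FALKORDB = 'falkordb'


class StrProvider(str, Enum):  # hypothetical str-mixin variant
    NEO4J = 'neo4j'
    FALKORDB = 'falkordb'


# A plain Enum member never equals its raw value.
print(Provider.FALKORDB == 'falkordb')  # False

# A str-mixin member does, because str equality applies.
print(StrProvider.FALKORDB == 'falkordb')  # True

# Comparing member to member works for both definitions.
print(Provider('falkordb') is Provider.FALKORDB)  # True
print(StrProvider('falkordb') is StrProvider.FALKORDB)  # True
```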

@@ -14,10 +14,68 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
import pytest
import os
import pytest
from dotenv import load_dotenv
from graphiti_core.driver.driver import GraphDriver
from graphiti_core.helpers import lucene_sanitize
load_dotenv()
HAS_NEO4J = False
HAS_FALKORDB = False
if os.getenv('DISABLE_NEO4J') is None:
try:
from graphiti_core.driver.neo4j_driver import Neo4jDriver
HAS_NEO4J = True
except ImportError:
pass
if os.getenv('DISABLE_FALKORDB') is None:
try:
from graphiti_core.driver.falkordb_driver import FalkorDriver
HAS_FALKORDB = True
except ImportError:
pass
NEO4J_URI = os.getenv('NEO4J_URI', 'bolt://localhost:7687')
NEO4J_USER = os.getenv('NEO4J_USER', 'neo4j')
NEO4J_PASSWORD = os.getenv('NEO4J_PASSWORD', 'test')
FALKORDB_HOST = os.getenv('FALKORDB_HOST', 'localhost')
FALKORDB_PORT = os.getenv('FALKORDB_PORT', '6379')
FALKORDB_USER = os.getenv('FALKORDB_USER', None)
FALKORDB_PASSWORD = os.getenv('FALKORDB_PASSWORD', None)
def get_driver(driver_name: str) -> GraphDriver:
if driver_name == 'neo4j':
return Neo4jDriver(
uri=NEO4J_URI,
user=NEO4J_USER,
password=NEO4J_PASSWORD,
)
elif driver_name == 'falkordb':
return FalkorDriver(
host=FALKORDB_HOST,
port=int(FALKORDB_PORT),
username=FALKORDB_USER,
password=FALKORDB_PASSWORD,
)
else:
raise ValueError(f'Driver {driver_name} not available')
drivers: list[str] = []
if HAS_NEO4J:
drivers.append('neo4j')
if HAS_FALKORDB:
drivers.append('falkordb')
def test_lucene_sanitize():
# Call the function with test data
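The helper module above builds the `drivers` list from two signals: a `DISABLE_<NAME>` environment variable and whether the driver module imports cleanly. Below is a self-contained sketch of the same gating idea. `discover_backends` and the candidate module names are hypothetical, with the standard-library `json` module standing in for an installed driver.

```python
import importlib
from typing import Mapping


def discover_backends(candidates: Mapping[str, str], env: Mapping[str, str]) -> list[str]:
    """Return the backends that are not disabled and whose module imports."""
    available: list[str] = []
    for name, module_path in candidates.items():
        # An explicit opt-out wins, whatever its value.
        if env.get(f'DISABLE_{name.upper()}') is not None:
            continue
        try:
            importlib.import_module(module_path)
        except ImportError:
            # Optional dependency not installed: skip silently.
            continue
        available.append(name)
    return available


# 'json' always imports; the second module name is deliberately bogus.
candidates = {'json': 'json', 'missing': 'no_such_backend_module_xyz'}

print(discover_backends(candidates, env={}))
print(discover_backends(candidates, env={'DISABLE_JSON': '1'}))
```

Passing the resulting list to `pytest.mark.parametrize`, as the tests in this commit do with `drivers`, means a backend that is disabled or missing simply produces no test cases.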

tests/test_edge_int.py (new file, 384 lines)

@@ -0,0 +1,384 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import logging
import sys
from datetime import datetime
from uuid import uuid4
import numpy as np
import pytest
from graphiti_core.driver.driver import GraphDriver
from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge
from graphiti_core.embedder.openai import OpenAIEmbedder
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
from tests.helpers_test import drivers, get_driver
pytestmark = pytest.mark.integration
pytest_plugins = ('pytest_asyncio',)
group_id = f'test_group_{str(uuid4())}'
def setup_logging():
# Create a logger
logger = logging.getLogger()
logger.setLevel(logging.INFO) # Set the logging level to INFO
# Create console handler and set level to INFO
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(logging.INFO)
# Create formatter
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# Add formatter to console handler
console_handler.setFormatter(formatter)
# Add console handler to logger
logger.addHandler(console_handler)
return logger
@pytest.mark.asyncio
@pytest.mark.parametrize(
'driver',
drivers,
ids=drivers,
)
async def test_episodic_edge(driver):
graph_driver = get_driver(driver)
embedder = OpenAIEmbedder()
now = datetime.now()
episode_node = EpisodicNode(
name='test_episode',
labels=[],
created_at=now,
valid_at=now,
source=EpisodeType.message,
source_description='conversation message',
content='Alice likes Bob',
entity_edges=[],
group_id=group_id,
)
node_count = await get_node_count(graph_driver, episode_node.uuid)
assert node_count == 0
await episode_node.save(graph_driver)
node_count = await get_node_count(graph_driver, episode_node.uuid)
assert node_count == 1
alice_node = EntityNode(
name='Alice',
labels=[],
created_at=now,
summary='Alice summary',
group_id=group_id,
)
await alice_node.generate_name_embedding(embedder)
node_count = await get_node_count(graph_driver, alice_node.uuid)
assert node_count == 0
await alice_node.save(graph_driver)
node_count = await get_node_count(graph_driver, alice_node.uuid)
assert node_count == 1
episodic_edge = EpisodicEdge(
source_node_uuid=episode_node.uuid,
target_node_uuid=alice_node.uuid,
created_at=now,
group_id=group_id,
)
edge_count = await get_edge_count(graph_driver, episodic_edge.uuid)
assert edge_count == 0
await episodic_edge.save(graph_driver)
edge_count = await get_edge_count(graph_driver, episodic_edge.uuid)
assert edge_count == 1
retrieved = await EpisodicEdge.get_by_uuid(graph_driver, episodic_edge.uuid)
assert retrieved.uuid == episodic_edge.uuid
assert retrieved.source_node_uuid == episode_node.uuid
assert retrieved.target_node_uuid == alice_node.uuid
assert retrieved.created_at == now
assert retrieved.group_id == group_id
retrieved = await EpisodicEdge.get_by_uuids(graph_driver, [episodic_edge.uuid])
assert len(retrieved) == 1
assert retrieved[0].uuid == episodic_edge.uuid
assert retrieved[0].source_node_uuid == episode_node.uuid
assert retrieved[0].target_node_uuid == alice_node.uuid
assert retrieved[0].created_at == now
assert retrieved[0].group_id == group_id
retrieved = await EpisodicEdge.get_by_group_ids(graph_driver, [group_id], limit=2)
assert len(retrieved) == 1
assert retrieved[0].uuid == episodic_edge.uuid
assert retrieved[0].source_node_uuid == episode_node.uuid
assert retrieved[0].target_node_uuid == alice_node.uuid
assert retrieved[0].created_at == now
assert retrieved[0].group_id == group_id
await episodic_edge.delete(graph_driver)
edge_count = await get_edge_count(graph_driver, episodic_edge.uuid)
assert edge_count == 0
await episode_node.delete(graph_driver)
node_count = await get_node_count(graph_driver, episode_node.uuid)
assert node_count == 0
await alice_node.delete(graph_driver)
node_count = await get_node_count(graph_driver, alice_node.uuid)
assert node_count == 0
await graph_driver.close()
@pytest.mark.asyncio
@pytest.mark.parametrize(
'driver',
drivers,
ids=drivers,
)
async def test_entity_edge(driver):
graph_driver = get_driver(driver)
embedder = OpenAIEmbedder()
now = datetime.now()
alice_node = EntityNode(
name='Alice',
labels=[],
created_at=now,
summary='Alice summary',
group_id=group_id,
)
await alice_node.generate_name_embedding(embedder)
node_count = await get_node_count(graph_driver, alice_node.uuid)
assert node_count == 0
await alice_node.save(graph_driver)
node_count = await get_node_count(graph_driver, alice_node.uuid)
assert node_count == 1
bob_node = EntityNode(
name='Bob', labels=[], created_at=now, summary='Bob summary', group_id=group_id
)
await bob_node.generate_name_embedding(embedder)
node_count = await get_node_count(graph_driver, bob_node.uuid)
assert node_count == 0
await bob_node.save(graph_driver)
node_count = await get_node_count(graph_driver, bob_node.uuid)
assert node_count == 1
entity_edge = EntityEdge(
source_node_uuid=alice_node.uuid,
target_node_uuid=bob_node.uuid,
created_at=now,
name='likes',
fact='Alice likes Bob',
episodes=[],
expired_at=now,
valid_at=now,
invalid_at=now,
group_id=group_id,
)
edge_embedding = await entity_edge.generate_embedding(embedder)
edge_count = await get_edge_count(graph_driver, entity_edge.uuid)
assert edge_count == 0
await entity_edge.save(graph_driver)
edge_count = await get_edge_count(graph_driver, entity_edge.uuid)
assert edge_count == 1
retrieved = await EntityEdge.get_by_uuid(graph_driver, entity_edge.uuid)
assert retrieved.uuid == entity_edge.uuid
assert retrieved.source_node_uuid == alice_node.uuid
assert retrieved.target_node_uuid == bob_node.uuid
assert retrieved.created_at == now
assert retrieved.group_id == group_id
retrieved = await EntityEdge.get_by_uuids(graph_driver, [entity_edge.uuid])
assert len(retrieved) == 1
assert retrieved[0].uuid == entity_edge.uuid
assert retrieved[0].source_node_uuid == alice_node.uuid
assert retrieved[0].target_node_uuid == bob_node.uuid
assert retrieved[0].created_at == now
assert retrieved[0].group_id == group_id
retrieved = await EntityEdge.get_by_group_ids(graph_driver, [group_id], limit=2)
assert len(retrieved) == 1
assert retrieved[0].uuid == entity_edge.uuid
assert retrieved[0].source_node_uuid == alice_node.uuid
assert retrieved[0].target_node_uuid == bob_node.uuid
assert retrieved[0].created_at == now
assert retrieved[0].group_id == group_id
retrieved = await EntityEdge.get_by_node_uuid(graph_driver, alice_node.uuid)
assert len(retrieved) == 1
assert retrieved[0].uuid == entity_edge.uuid
assert retrieved[0].source_node_uuid == alice_node.uuid
assert retrieved[0].target_node_uuid == bob_node.uuid
assert retrieved[0].created_at == now
assert retrieved[0].group_id == group_id
await entity_edge.load_fact_embedding(graph_driver)
assert np.allclose(entity_edge.fact_embedding, edge_embedding)
await entity_edge.delete(graph_driver)
edge_count = await get_edge_count(graph_driver, entity_edge.uuid)
assert edge_count == 0
await alice_node.delete(graph_driver)
node_count = await get_node_count(graph_driver, alice_node.uuid)
assert node_count == 0
await bob_node.delete(graph_driver)
node_count = await get_node_count(graph_driver, bob_node.uuid)
assert node_count == 0
await graph_driver.close()
@pytest.mark.asyncio
@pytest.mark.parametrize(
'driver',
drivers,
ids=drivers,
)
async def test_community_edge(driver):
graph_driver = get_driver(driver)
embedder = OpenAIEmbedder()
now = datetime.now()
community_node_1 = CommunityNode(
name='Community A',
group_id=group_id,
summary='Community A summary',
)
await community_node_1.generate_name_embedding(embedder)
node_count = await get_node_count(graph_driver, community_node_1.uuid)
assert node_count == 0
await community_node_1.save(graph_driver)
node_count = await get_node_count(graph_driver, community_node_1.uuid)
assert node_count == 1
community_node_2 = CommunityNode(
name='Community B',
group_id=group_id,
summary='Community B summary',
)
await community_node_2.generate_name_embedding(embedder)
node_count = await get_node_count(graph_driver, community_node_2.uuid)
assert node_count == 0
await community_node_2.save(graph_driver)
node_count = await get_node_count(graph_driver, community_node_2.uuid)
assert node_count == 1
alice_node = EntityNode(
name='Alice', labels=[], created_at=now, summary='Alice summary', group_id=group_id
)
await alice_node.generate_name_embedding(embedder)
node_count = await get_node_count(graph_driver, alice_node.uuid)
assert node_count == 0
await alice_node.save(graph_driver)
node_count = await get_node_count(graph_driver, alice_node.uuid)
assert node_count == 1
community_edge = CommunityEdge(
source_node_uuid=community_node_1.uuid,
target_node_uuid=community_node_2.uuid,
created_at=now,
group_id=group_id,
)
edge_count = await get_edge_count(graph_driver, community_edge.uuid)
assert edge_count == 0
await community_edge.save(graph_driver)
edge_count = await get_edge_count(graph_driver, community_edge.uuid)
assert edge_count == 1
retrieved = await CommunityEdge.get_by_uuid(graph_driver, community_edge.uuid)
assert retrieved.uuid == community_edge.uuid
assert retrieved.source_node_uuid == community_node_1.uuid
assert retrieved.target_node_uuid == community_node_2.uuid
assert retrieved.created_at == now
assert retrieved.group_id == group_id
retrieved = await CommunityEdge.get_by_uuids(graph_driver, [community_edge.uuid])
assert len(retrieved) == 1
assert retrieved[0].uuid == community_edge.uuid
assert retrieved[0].source_node_uuid == community_node_1.uuid
assert retrieved[0].target_node_uuid == community_node_2.uuid
assert retrieved[0].created_at == now
assert retrieved[0].group_id == group_id
retrieved = await CommunityEdge.get_by_group_ids(graph_driver, [group_id], limit=1)
assert len(retrieved) == 1
assert retrieved[0].uuid == community_edge.uuid
assert retrieved[0].source_node_uuid == community_node_1.uuid
assert retrieved[0].target_node_uuid == community_node_2.uuid
assert retrieved[0].created_at == now
assert retrieved[0].group_id == group_id
await community_edge.delete(graph_driver)
edge_count = await get_edge_count(graph_driver, community_edge.uuid)
assert edge_count == 0
await alice_node.delete(graph_driver)
node_count = await get_node_count(graph_driver, alice_node.uuid)
assert node_count == 0
await community_node_1.delete(graph_driver)
node_count = await get_node_count(graph_driver, community_node_1.uuid)
assert node_count == 0
await community_node_2.delete(graph_driver)
node_count = await get_node_count(graph_driver, community_node_2.uuid)
assert node_count == 0
await graph_driver.close()
async def get_node_count(driver: GraphDriver, uuid: str):
results, _, _ = await driver.execute_query(
"""
MATCH (n {uuid: $uuid})
RETURN COUNT(n) as count
""",
uuid=uuid,
)
return int(results[0]['count'])
async def get_edge_count(driver: GraphDriver, uuid: str):
results, _, _ = await driver.execute_query(
"""
MATCH (n)-[e {uuid: $uuid}]->(m)
RETURN COUNT(e) as count
UNION ALL
MATCH (n)-[e:RELATES_TO]->(m {uuid: $uuid})-[e2:RELATES_TO]->(m2)
RETURN COUNT(m) as count
""",
uuid=uuid,
)
return sum(int(result['count']) for result in results)
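`get_edge_count` joins two `MATCH` branches with `UNION ALL`, so the query yields one `count` row per branch and the helper adds them together. Here is a minimal sketch of that aggregation on plain dictionaries. `total_count` is a hypothetical name and the rows are made-up data, not real driver output.

```python
from typing import Any, Iterable, Mapping


def total_count(rows: Iterable[Mapping[str, Any]]) -> int:
    # Each UNION ALL branch contributes one row; sum them into a single total.
    return sum(int(row['count']) for row in rows)


# First branch matched one edge, second branch matched nothing.
rows = [{'count': 1}, {'count': 0}]
print(total_count(rows))
```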

@@ -14,26 +14,18 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
import os
from datetime import datetime, timezone
import pytest
from dotenv import load_dotenv
from pydantic import BaseModel, Field
from graphiti_core.graphiti import Graphiti
from graphiti_core.helpers import validate_excluded_entity_types
from tests.helpers_test import drivers, get_driver
pytestmark = pytest.mark.integration
pytest_plugins = ('pytest_asyncio',)
load_dotenv()
NEO4J_URI = os.getenv('NEO4J_URI')
NEO4J_USER = os.getenv('NEO4J_USER')
NEO4J_PASSWORD = os.getenv('NEO4J_PASSWORD')
# Test entity type definitions
class Person(BaseModel):
@@ -65,9 +57,14 @@ class Location(BaseModel):
@pytest.mark.asyncio
async def test_exclude_default_entity_type():
@pytest.mark.parametrize(
'driver',
drivers,
ids=drivers,
)
async def test_exclude_default_entity_type(driver):
"""Test excluding the default 'Entity' type while keeping custom types."""
graphiti = Graphiti(NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD)
graphiti = Graphiti(graph_driver=get_driver(driver))
try:
await graphiti.build_indices_and_constraints()
@@ -118,9 +115,14 @@ async def test_exclude_default_entity_type():
@pytest.mark.asyncio
async def test_exclude_specific_custom_types():
@pytest.mark.parametrize(
'driver',
drivers,
ids=drivers,
)
async def test_exclude_specific_custom_types(driver):
"""Test excluding specific custom entity types while keeping others."""
graphiti = Graphiti(NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD)
graphiti = Graphiti(graph_driver=get_driver(driver))
try:
await graphiti.build_indices_and_constraints()
@@ -177,9 +179,14 @@ async def test_exclude_specific_custom_types():
@pytest.mark.asyncio
async def test_exclude_all_types():
@pytest.mark.parametrize(
'driver',
drivers,
ids=drivers,
)
async def test_exclude_all_types(driver):
"""Test excluding all entity types (edge case)."""
graphiti = Graphiti(NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD)
graphiti = Graphiti(graph_driver=get_driver(driver))
try:
await graphiti.build_indices_and_constraints()
@@ -221,9 +228,14 @@ async def test_exclude_all_types():
@pytest.mark.asyncio
async def test_exclude_no_types():
@pytest.mark.parametrize(
'driver',
drivers,
ids=drivers,
)
async def test_exclude_no_types(driver):
"""Test normal behavior when no types are excluded (baseline test)."""
graphiti = Graphiti(NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD)
graphiti = Graphiti(graph_driver=get_driver(driver))
try:
await graphiti.build_indices_and_constraints()
@@ -299,9 +311,14 @@ def test_validation_invalid_excluded_types():
@pytest.mark.asyncio
async def test_excluded_types_parameter_validation_in_add_episode():
@pytest.mark.parametrize(
'driver',
drivers,
ids=drivers,
)
async def test_excluded_types_parameter_validation_in_add_episode(driver):
"""Test that add_episode validates excluded_entity_types parameter."""
graphiti = Graphiti(NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD)
graphiti = Graphiti(graph_driver=get_driver(driver))
try:
entity_types = {

@@ -1,164 +0,0 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import logging
import os
import sys
import unittest
from datetime import datetime, timezone
import pytest
from dotenv import load_dotenv
from graphiti_core.edges import EntityEdge, EpisodicEdge
from graphiti_core.graphiti import Graphiti
from graphiti_core.helpers import semaphore_gather
from graphiti_core.nodes import EntityNode, EpisodicNode
from graphiti_core.search.search_helpers import search_results_to_context_string
try:
from graphiti_core.driver.falkordb_driver import FalkorDriver
HAS_FALKORDB = True
except ImportError:
FalkorDriver = None
HAS_FALKORDB = False
pytestmark = pytest.mark.integration
pytest_plugins = ('pytest_asyncio',)
load_dotenv()
FALKORDB_HOST = os.getenv('FALKORDB_HOST', 'localhost')
FALKORDB_PORT = os.getenv('FALKORDB_PORT', '6379')
FALKORDB_USER = os.getenv('FALKORDB_USER', None)
FALKORDB_PASSWORD = os.getenv('FALKORDB_PASSWORD', None)
def setup_logging():
# Create a logger
logger = logging.getLogger()
logger.setLevel(logging.INFO) # Set the logging level to INFO
# Create console handler and set level to INFO
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(logging.INFO)
# Create formatter
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# Add formatter to console handler
console_handler.setFormatter(formatter)
# Add console handler to logger
logger.addHandler(console_handler)
return logger
@pytest.mark.asyncio
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
async def test_graphiti_falkordb_init():
logger = setup_logging()
falkor_driver = FalkorDriver(
host=FALKORDB_HOST, port=FALKORDB_PORT, username=FALKORDB_USER, password=FALKORDB_PASSWORD
)
graphiti = Graphiti(graph_driver=falkor_driver)
results = await graphiti.search_(query='Who is the user?')
pretty_results = search_results_to_context_string(results)
logger.info(pretty_results)
await graphiti.close()
@pytest.mark.asyncio
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
async def test_graph_falkordb_integration():
falkor_driver = FalkorDriver(
host=FALKORDB_HOST, port=FALKORDB_PORT, username=FALKORDB_USER, password=FALKORDB_PASSWORD
)
client = Graphiti(graph_driver=falkor_driver)
embedder = client.embedder
driver = client.driver
now = datetime.now(timezone.utc)
episode = EpisodicNode(
name='test_episode',
labels=[],
created_at=now,
valid_at=now,
source='message',
source_description='conversation message',
content='Alice likes Bob',
entity_edges=[],
)
alice_node = EntityNode(
name='Alice',
labels=[],
created_at=now,
summary='Alice summary',
)
bob_node = EntityNode(name='Bob', labels=[], created_at=now, summary='Bob summary')
episodic_edge_1 = EpisodicEdge(
source_node_uuid=episode.uuid, target_node_uuid=alice_node.uuid, created_at=now
)
episodic_edge_2 = EpisodicEdge(
source_node_uuid=episode.uuid, target_node_uuid=bob_node.uuid, created_at=now
)
entity_edge = EntityEdge(
source_node_uuid=alice_node.uuid,
target_node_uuid=bob_node.uuid,
created_at=now,
name='likes',
fact='Alice likes Bob',
episodes=[],
expired_at=now,
valid_at=now,
invalid_at=now,
)
await entity_edge.generate_embedding(embedder)
nodes = [episode, alice_node, bob_node]
edges = [episodic_edge_1, episodic_edge_2, entity_edge]
# test save
await semaphore_gather(*[node.save(driver) for node in nodes])
await semaphore_gather(*[edge.save(driver) for edge in edges])
# test get
assert await EpisodicNode.get_by_uuid(driver, episode.uuid) is not None
assert await EntityNode.get_by_uuid(driver, alice_node.uuid) is not None
assert await EpisodicEdge.get_by_uuid(driver, episodic_edge_1.uuid) is not None
assert await EntityEdge.get_by_uuid(driver, entity_edge.uuid) is not None
# test delete
await semaphore_gather(*[node.delete(driver) for node in nodes])
await semaphore_gather(*[edge.delete(driver) for edge in edges])
await client.close()

@@ -15,31 +15,19 @@ limitations under the License.
"""
import logging
import os
import sys
from datetime import datetime, timezone
import pytest
from dotenv import load_dotenv
from graphiti_core.edges import EntityEdge, EpisodicEdge
from graphiti_core.graphiti import Graphiti
from graphiti_core.helpers import semaphore_gather
from graphiti_core.nodes import EntityNode, EpisodicNode
from graphiti_core.search.search_filters import ComparisonOperator, DateFilter, SearchFilters
from graphiti_core.search.search_helpers import search_results_to_context_string
from graphiti_core.utils.datetime_utils import utc_now
from tests.helpers_test import drivers, get_driver
pytestmark = pytest.mark.integration
pytest_plugins = ('pytest_asyncio',)
load_dotenv()
NEO4J_URI = os.getenv('NEO4J_URI')
NEO4j_USER = os.getenv('NEO4J_USER')
NEO4j_PASSWORD = os.getenv('NEO4J_PASSWORD')
def setup_logging():
# Create a logger
@@ -63,9 +51,18 @@ def setup_logging():
@pytest.mark.asyncio
async def test_graphiti_init():
@pytest.mark.parametrize(
'driver',
drivers,
ids=drivers,
)
async def test_graphiti_init(driver):
logger = setup_logging()
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
driver = get_driver(driver)
graphiti = Graphiti(graph_driver=driver)
await graphiti.build_indices_and_constraints()
search_filter = SearchFilters(
created_at=[[DateFilter(date=utc_now(), comparison_operator=ComparisonOperator.less_than)]]
)
@@ -76,74 +73,6 @@ async def test_graphiti_init():
)
pretty_results = search_results_to_context_string(results)
logger.info(pretty_results)
await graphiti.close()
@pytest.mark.asyncio
async def test_graph_integration():
client = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
embedder = client.embedder
driver = client.driver
now = datetime.now(timezone.utc)
episode = EpisodicNode(
name='test_episode',
labels=[],
created_at=now,
valid_at=now,
source='message',
source_description='conversation message',
content='Alice likes Bob',
entity_edges=[],
)
alice_node = EntityNode(
name='Alice',
labels=[],
created_at=now,
summary='Alice summary',
)
bob_node = EntityNode(name='Bob', labels=[], created_at=now, summary='Bob summary')
episodic_edge_1 = EpisodicEdge(
source_node_uuid=episode.uuid, target_node_uuid=alice_node.uuid, created_at=now
)
episodic_edge_2 = EpisodicEdge(
source_node_uuid=episode.uuid, target_node_uuid=bob_node.uuid, created_at=now
)
entity_edge = EntityEdge(
source_node_uuid=alice_node.uuid,
target_node_uuid=bob_node.uuid,
created_at=now,
name='likes',
fact='Alice likes Bob',
episodes=[],
expired_at=now,
valid_at=now,
invalid_at=now,
)
await entity_edge.generate_embedding(embedder)
nodes = [episode, alice_node, bob_node]
edges = [episodic_edge_1, episodic_edge_2, entity_edge]
# test save
await semaphore_gather(*[node.save(driver) for node in nodes])
await semaphore_gather(*[edge.save(driver) for edge in edges])
# test get
assert await EpisodicNode.get_by_uuid(driver, episode.uuid) is not None
assert await EntityNode.get_by_uuid(driver, alice_node.uuid) is not None
assert await EpisodicEdge.get_by_uuid(driver, episodic_edge_1.uuid) is not None
assert await EntityEdge.get_by_uuid(driver, entity_edge.uuid) is not None
# test delete
await semaphore_gather(*[node.delete(driver) for node in nodes])
await semaphore_gather(*[edge.delete(driver) for edge in edges])

@@ -1,139 +0,0 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import unittest
from datetime import datetime, timezone
from uuid import uuid4
import pytest
from graphiti_core.nodes import (
CommunityNode,
EntityNode,
EpisodeType,
EpisodicNode,
)
FALKORDB_HOST = os.getenv('FALKORDB_HOST', 'localhost')
FALKORDB_PORT = os.getenv('FALKORDB_PORT', '6379')
FALKORDB_USER = os.getenv('FALKORDB_USER', None)
FALKORDB_PASSWORD = os.getenv('FALKORDB_PASSWORD', None)
try:
from graphiti_core.driver.falkordb_driver import FalkorDriver
HAS_FALKORDB = True
except ImportError:
FalkorDriver = None
HAS_FALKORDB = False
@pytest.fixture
def sample_entity_node():
return EntityNode(
uuid=str(uuid4()),
name='Test Entity',
group_id='test_group',
labels=['Entity'],
name_embedding=[0.5] * 1024,
summary='Entity Summary',
)
@pytest.fixture
def sample_episodic_node():
return EpisodicNode(
uuid=str(uuid4()),
name='Episode 1',
group_id='test_group',
source=EpisodeType.text,
source_description='Test source',
content='Some content here',
valid_at=datetime.now(timezone.utc),
)
@pytest.fixture
def sample_community_node():
return CommunityNode(
uuid=str(uuid4()),
name='Community A',
name_embedding=[0.5] * 1024,
group_id='test_group',
summary='Community summary',
)
@pytest.mark.asyncio
@pytest.mark.integration
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
async def test_entity_node_save_get_and_delete(sample_entity_node):
falkor_driver = FalkorDriver(
host=FALKORDB_HOST, port=FALKORDB_PORT, username=FALKORDB_USER, password=FALKORDB_PASSWORD
)
await sample_entity_node.save(falkor_driver)
retrieved = await EntityNode.get_by_uuid(falkor_driver, sample_entity_node.uuid)
assert retrieved.uuid == sample_entity_node.uuid
assert retrieved.name == 'Test Entity'
assert retrieved.group_id == 'test_group'
await sample_entity_node.delete(falkor_driver)
await falkor_driver.close()
@pytest.mark.asyncio
@pytest.mark.integration
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
async def test_community_node_save_get_and_delete(sample_community_node):
falkor_driver = FalkorDriver(
host=FALKORDB_HOST, port=FALKORDB_PORT, username=FALKORDB_USER, password=FALKORDB_PASSWORD
)
await sample_community_node.save(falkor_driver)
retrieved = await CommunityNode.get_by_uuid(falkor_driver, sample_community_node.uuid)
assert retrieved.uuid == sample_community_node.uuid
assert retrieved.name == 'Community A'
assert retrieved.group_id == 'test_group'
assert retrieved.summary == 'Community summary'
await sample_community_node.delete(falkor_driver)
await falkor_driver.close()
@pytest.mark.asyncio
@pytest.mark.integration
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
async def test_episodic_node_save_get_and_delete(sample_episodic_node):
falkor_driver = FalkorDriver(
host=FALKORDB_HOST, port=FALKORDB_PORT, username=FALKORDB_USER, password=FALKORDB_PASSWORD
)
await sample_episodic_node.save(falkor_driver)
retrieved = await EpisodicNode.get_by_uuid(falkor_driver, sample_episodic_node.uuid)
assert retrieved.uuid == sample_episodic_node.uuid
assert retrieved.name == 'Episode 1'
assert retrieved.group_id == 'test_group'
assert retrieved.source == EpisodeType.text
assert retrieved.source_description == 'Test source'
assert retrieved.content == 'Some content here'
await sample_episodic_node.delete(falkor_driver)
await falkor_driver.close()

@@ -14,23 +14,22 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
import os
from datetime import datetime, timezone
from datetime import datetime
from uuid import uuid4
import numpy as np
import pytest
from neo4j import AsyncGraphDatabase
from graphiti_core.driver.driver import GraphDriver
from graphiti_core.nodes import (
CommunityNode,
EntityNode,
EpisodeType,
EpisodicNode,
)
from tests.helpers_test import drivers, get_driver
NEO4J_URI = os.getenv('NEO4J_URI', 'bolt://localhost:7687')
NEO4J_USER = os.getenv('NEO4J_USER', 'neo4j')
NEO4J_PASSWORD = os.getenv('NEO4J_PASSWORD', 'test')
group_id = f'test_group_{str(uuid4())}'
@pytest.fixture
@@ -38,8 +37,8 @@ def sample_entity_node():
return EntityNode(
uuid=str(uuid4()),
name='Test Entity',
group_id='test_group',
labels=['Entity'],
group_id=group_id,
labels=[],
name_embedding=[0.5] * 1024,
summary='Entity Summary',
)
@@ -50,11 +49,11 @@ def sample_episodic_node():
return EpisodicNode(
uuid=str(uuid4()),
name='Episode 1',
group_id='test_group',
group_id=group_id,
source=EpisodeType.text,
source_description='Test source',
content='Some content here',
valid_at=datetime.now(timezone.utc),
valid_at=datetime.now(),
)
@@ -64,59 +63,181 @@ def sample_community_node():
uuid=str(uuid4()),
name='Community A',
name_embedding=[0.5] * 1024,
group_id='test_group',
group_id=group_id,
summary='Community summary',
)
@pytest.mark.asyncio
@pytest.mark.integration
async def test_entity_node_save_get_and_delete(sample_entity_node):
neo4j_driver = AsyncGraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))
await sample_entity_node.save(neo4j_driver)
retrieved = await EntityNode.get_by_uuid(neo4j_driver, sample_entity_node.uuid)
@pytest.mark.parametrize(
'driver',
drivers,
ids=drivers,
)
async def test_entity_node(sample_entity_node, driver):
driver = get_driver(driver)
uuid = sample_entity_node.uuid
# Create node
node_count = await get_node_count(driver, uuid)
assert node_count == 0
await sample_entity_node.save(driver)
node_count = await get_node_count(driver, uuid)
assert node_count == 1
retrieved = await EntityNode.get_by_uuid(driver, sample_entity_node.uuid)
assert retrieved.uuid == sample_entity_node.uuid
assert retrieved.name == 'Test Entity'
assert retrieved.group_id == 'test_group'
assert retrieved.group_id == group_id
await sample_entity_node.delete(neo4j_driver)
retrieved = await EntityNode.get_by_uuids(driver, [sample_entity_node.uuid])
assert retrieved[0].uuid == sample_entity_node.uuid
assert retrieved[0].name == 'Test Entity'
assert retrieved[0].group_id == group_id
await neo4j_driver.close()
retrieved = await EntityNode.get_by_group_ids(driver, [group_id], limit=2)
assert len(retrieved) == 1
assert retrieved[0].uuid == sample_entity_node.uuid
assert retrieved[0].name == 'Test Entity'
assert retrieved[0].group_id == group_id
await sample_entity_node.load_name_embedding(driver)
assert np.allclose(sample_entity_node.name_embedding, [0.5] * 1024)
# Delete node by uuid
await sample_entity_node.delete(driver)
node_count = await get_node_count(driver, uuid)
assert node_count == 0
# Delete node by group id
await sample_entity_node.save(driver)
node_count = await get_node_count(driver, uuid)
assert node_count == 1
await sample_entity_node.delete_by_group_id(driver, group_id)
node_count = await get_node_count(driver, uuid)
assert node_count == 0
await driver.close()
@pytest.mark.asyncio
@pytest.mark.integration
async def test_community_node_save_get_and_delete(sample_community_node):
neo4j_driver = AsyncGraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))
@pytest.mark.parametrize(
'driver',
drivers,
ids=drivers,
)
async def test_community_node(sample_community_node, driver):
driver = get_driver(driver)
uuid = sample_community_node.uuid
await sample_community_node.save(neo4j_driver)
# Create node
node_count = await get_node_count(driver, uuid)
assert node_count == 0
await sample_community_node.save(driver)
node_count = await get_node_count(driver, uuid)
assert node_count == 1
retrieved = await CommunityNode.get_by_uuid(neo4j_driver, sample_community_node.uuid)
retrieved = await CommunityNode.get_by_uuid(driver, sample_community_node.uuid)
assert retrieved.uuid == sample_community_node.uuid
assert retrieved.name == 'Community A'
assert retrieved.group_id == 'test_group'
assert retrieved.group_id == group_id
assert retrieved.summary == 'Community summary'
await sample_community_node.delete(neo4j_driver)
retrieved = await CommunityNode.get_by_uuids(driver, [sample_community_node.uuid])
assert retrieved[0].uuid == sample_community_node.uuid
assert retrieved[0].name == 'Community A'
assert retrieved[0].group_id == group_id
assert retrieved[0].summary == 'Community summary'
await neo4j_driver.close()
retrieved = await CommunityNode.get_by_group_ids(driver, [group_id], limit=2)
assert len(retrieved) == 1
assert retrieved[0].uuid == sample_community_node.uuid
assert retrieved[0].name == 'Community A'
assert retrieved[0].group_id == group_id
# Delete node by uuid
await sample_community_node.delete(driver)
node_count = await get_node_count(driver, uuid)
assert node_count == 0
# Delete node by group id
await sample_community_node.save(driver)
node_count = await get_node_count(driver, uuid)
assert node_count == 1
await sample_community_node.delete_by_group_id(driver, group_id)
node_count = await get_node_count(driver, uuid)
assert node_count == 0
await driver.close()
@pytest.mark.asyncio
@pytest.mark.integration
async def test_episodic_node_save_get_and_delete(sample_episodic_node):
neo4j_driver = AsyncGraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))
@pytest.mark.parametrize(
'driver',
drivers,
ids=drivers,
)
async def test_episodic_node(sample_episodic_node, driver):
driver = get_driver(driver)
uuid = sample_episodic_node.uuid
await sample_episodic_node.save(neo4j_driver)
# Create node
node_count = await get_node_count(driver, uuid)
assert node_count == 0
await sample_episodic_node.save(driver)
node_count = await get_node_count(driver, uuid)
assert node_count == 1
retrieved = await EpisodicNode.get_by_uuid(neo4j_driver, sample_episodic_node.uuid)
retrieved = await EpisodicNode.get_by_uuid(driver, sample_episodic_node.uuid)
assert retrieved.uuid == sample_episodic_node.uuid
assert retrieved.name == 'Episode 1'
assert retrieved.group_id == 'test_group'
assert retrieved.group_id == group_id
assert retrieved.source == EpisodeType.text
assert retrieved.source_description == 'Test source'
assert retrieved.content == 'Some content here'
assert retrieved.valid_at == sample_episodic_node.valid_at
await sample_episodic_node.delete(neo4j_driver)
retrieved = await EpisodicNode.get_by_uuids(driver, [sample_episodic_node.uuid])
assert retrieved[0].uuid == sample_episodic_node.uuid
assert retrieved[0].name == 'Episode 1'
assert retrieved[0].group_id == group_id
assert retrieved[0].source == EpisodeType.text
assert retrieved[0].source_description == 'Test source'
assert retrieved[0].content == 'Some content here'
assert retrieved[0].valid_at == sample_episodic_node.valid_at
await neo4j_driver.close()
retrieved = await EpisodicNode.get_by_group_ids(driver, [group_id], limit=2)
assert len(retrieved) == 1
assert retrieved[0].uuid == sample_episodic_node.uuid
assert retrieved[0].name == 'Episode 1'
assert retrieved[0].group_id == group_id
assert retrieved[0].source == EpisodeType.text
assert retrieved[0].source_description == 'Test source'
assert retrieved[0].content == 'Some content here'
assert retrieved[0].valid_at == sample_episodic_node.valid_at
# Delete node by uuid
await sample_episodic_node.delete(driver)
node_count = await get_node_count(driver, uuid)
assert node_count == 0
# Delete node by group id
await sample_episodic_node.save(driver)
node_count = await get_node_count(driver, uuid)
assert node_count == 1
await sample_episodic_node.delete_by_group_id(driver, group_id)
node_count = await get_node_count(driver, uuid)
assert node_count == 0
await driver.close()
async def get_node_count(driver: GraphDriver, uuid: str):
result, _, _ = await driver.execute_query(
"""
MATCH (n {uuid: $uuid})
RETURN COUNT(n) as count
""",
uuid=uuid,
)
return int(result[0]['count'])