update driver (#583)

* update driver

* mypy updates

* mypy updates

* mypy updates

* Update graphiti_core/graph_queries.py

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>

* mypy updates

* mypy

* mypy updates

* mypy updates

* mypy updates

* mypy updates

---------

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
This commit is contained in:
Preston Rasmussen 2025-06-13 14:12:09 -04:00 committed by GitHub
parent 12b90633a4
commit 19fde653a6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 70 additions and 77 deletions

View file

@ -25,10 +25,26 @@ logger = logging.getLogger(__name__)
class GraphDriverSession(ABC):
async def __aenter__(self):
return self
@abstractmethod
async def __aexit__(self, exc_type, exc, tb):
# No cleanup needed for Falkor, but method must exist
pass
@abstractmethod
async def run(self, query: str, **kwargs: Any) -> Any:
raise NotImplementedError()
@abstractmethod
async def close(self):
raise NotImplementedError()
@abstractmethod
async def execute_write(self, func, *args, **kwargs):
raise NotImplementedError()
class GraphDriver(ABC):
provider: str
@ -42,40 +58,9 @@ class GraphDriver(ABC):
raise NotImplementedError()
@abstractmethod
def close(self) -> None:
def close(self):
raise NotImplementedError()
@abstractmethod
def delete_all_indexes(self, database_: str = DEFAULT_DATABASE) -> Coroutine:
raise NotImplementedError()
# class GraphDriver:
# _driver: GraphClient
#
# def __init__(
# self,
# uri: str,
# user: str,
# password: str,
# ):
# if uri.startswith('falkor'):
# # FalkorDB
# self._driver = FalkorClient(uri, user, password)
# self.provider = 'falkordb'
# else:
# # Neo4j
# self._driver = Neo4jClient(uri, user, password)
# self.provider = 'neo4j'
#
# def execute_query(self, cypher_query_, **kwargs: Any) -> Coroutine:
# return self._driver.execute_query(cypher_query_, **kwargs)
#
# async def close(self):
# return await self._driver.close()
#
# def delete_all_indexes(self, database_: str = DEFAULT_DATABASE) -> Coroutine:
# return self._driver.delete_all_indexes(database_)
#
# def session(self, database: str) -> GraphClientSession:
# return self._driver.session(database)

View file

@ -19,8 +19,8 @@ from collections.abc import Coroutine
from datetime import datetime
from typing import Any
from falkordb import Graph as FalkorGraph
from falkordb.asyncio import FalkorDB
from falkordb import Graph as FalkorGraph # type: ignore
from falkordb.asyncio import FalkorDB # type: ignore
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
from graphiti_core.helpers import DEFAULT_DATABASE
@ -28,7 +28,7 @@ from graphiti_core.helpers import DEFAULT_DATABASE
logger = logging.getLogger(__name__)
class FalkorClientSession(GraphDriverSession):
class FalkorDriverSession(GraphDriverSession):
def __init__(self, graph: FalkorGraph):
self.graph = graph
@ -47,16 +47,16 @@ class FalkorClientSession(GraphDriverSession):
# Directly await the provided async function with `self` as the transaction/session
return await func(self, *args, **kwargs)
async def run(self, cypher_query_: str | list, **kwargs: Any) -> Any:
async def run(self, query: str | list, **kwargs: Any) -> Any:
# FalkorDB does not support argument for Label Set, so it's converted into an array of queries
if isinstance(cypher_query_, list):
for cypher, params in cypher_query_:
if isinstance(query, list):
for cypher, params in query:
params = convert_datetimes_to_strings(params)
await self.graph.query(str(cypher), params)
else:
params = dict(kwargs)
params = convert_datetimes_to_strings(params)
await self.graph.query(str(cypher_query_), params)
await self.graph.query(str(query), params)
# Assuming `graph.query` is async (ideal); otherwise, wrap in executor
return None
@ -79,7 +79,7 @@ class FalkorDriver(GraphDriver):
url=uri,
)
def _get_graph(self, graph_name: str) -> FalkorGraph:
def _get_graph(self, graph_name: str | None) -> FalkorGraph:
# FalkorDB requires a non-None database name for multi-tenant graphs; the default is "DEFAULT_DATABASE"
if graph_name is None:
graph_name = 'DEFAULT_DATABASE'
@ -106,13 +106,13 @@ class FalkorDriver(GraphDriver):
header = [h[1].decode('utf-8') for h in result.header]
return result.result_set, header, None
def session(self, database: str) -> GraphDriverSession:
return FalkorClientSession(self._get_graph(database))
def session(self, database: str | None) -> GraphDriverSession:
return FalkorDriverSession(self._get_graph(database))
async def close(self) -> None:
await self.client.connection.close()
def delete_all_indexes(self, database_: str = DEFAULT_DATABASE) -> Coroutine:
async def delete_all_indexes(self, database_: str = DEFAULT_DATABASE) -> Coroutine:
return self.execute_query(
'CALL db.indexes() YIELD name DROP INDEX name',
database_=database_,

View file

@ -33,13 +33,13 @@ class Neo4jDriver(GraphDriver):
def __init__(
self,
uri: str,
user: str,
password: str,
user: str | None,
password: str | None,
):
super().__init__()
self.client = AsyncGraphDatabase.driver(
uri=uri,
auth=(user, password),
auth=(user or '', password or ''),
)
async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> Coroutine:

View file

@ -345,8 +345,8 @@ class EntityEdge(Edge):
async def get_by_node_uuid(cls, driver: GraphDriver, node_uuid: str):
query: LiteralString = (
"""
MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
"""
MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
"""
+ ENTITY_EDGE_RETURN
)
records, _, _ = await driver.execute_query(
@ -463,7 +463,7 @@ def get_episodic_edge_from_record(record: Any) -> EpisodicEdge:
group_id=record['group_id'],
source_node_uuid=record['source_node_uuid'],
target_node_uuid=record['target_node_uuid'],
created_at=parse_db_date(record['created_at']),
created_at=parse_db_date(record['created_at']), # type: ignore
)
@ -476,7 +476,7 @@ def get_entity_edge_from_record(record: Any) -> EntityEdge:
name=record['name'],
group_id=record['group_id'],
episodes=record['episodes'],
created_at=parse_db_date(record['created_at']),
created_at=parse_db_date(record['created_at']), # type: ignore
expired_at=parse_db_date(record['expired_at']),
valid_at=parse_db_date(record['valid_at']),
invalid_at=parse_db_date(record['invalid_at']),
@ -504,7 +504,7 @@ def get_community_edge_from_record(record: Any):
group_id=record['group_id'],
source_node_uuid=record['source_node_uuid'],
target_node_uuid=record['target_node_uuid'],
created_at=parse_db_date(record['created_at']),
created_at=parse_db_date(record['created_at']), # type: ignore
)

View file

@ -5,6 +5,8 @@ This module provides database-agnostic query generation for Neo4j and FalkorDB,
supporting index creation, fulltext search, and bulk operations.
"""
from typing import Any
from typing_extensions import LiteralString
from graphiti_core.models.edges.edge_db_queries import (
@ -84,7 +86,7 @@ def get_fulltext_indices(db_type: str = 'neo4j') -> list[LiteralString]:
]
def get_nodes_query(db_type: str = 'neo4j', name: str = None, query: str = None) -> str:
def get_nodes_query(db_type: str = 'neo4j', name: str = '', query: str | None = None) -> str:
if db_type == 'falkordb':
label = NEO4J_TO_FALKORDB_MAPPING[name]
return f"CALL db.idx.fulltext.queryNodes('{label}', {query})"
@ -100,7 +102,7 @@ def get_vector_cosine_func_query(vec1, vec2, db_type: str = 'neo4j') -> str:
return f'vector.similarity.cosine({vec1}, {vec2})'
def get_relationships_query(db_type: str = 'neo4j', name: str = None, query: str = None) -> str:
def get_relationships_query(name: str, db_type: str = 'neo4j') -> str:
if db_type == 'falkordb':
label = NEO4J_TO_FALKORDB_MAPPING[name]
return f"CALL db.idx.fulltext.queryRelationships('{label}', $query)"
@ -108,7 +110,7 @@ def get_relationships_query(db_type: str = 'neo4j', name: str = None, query: str
return f'CALL db.index.fulltext.queryRelationships("{name}", $query, {{limit: $limit}})'
def get_entity_node_save_bulk_query(nodes, db_type: str = 'neo4j') -> str:
def get_entity_node_save_bulk_query(nodes, db_type: str = 'neo4j') -> str | Any:
if db_type == 'falkordb':
queries = []
for node in nodes:

View file

@ -95,13 +95,13 @@ class Graphiti:
def __init__(
self,
uri: str,
user: str = None,
password: str = None,
user: str | None = None,
password: str | None = None,
llm_client: LLMClient | None = None,
embedder: EmbedderClient | None = None,
cross_encoder: CrossEncoderClient | None = None,
store_raw_episode_content: bool = True,
graph_driver: GraphDriver = None,
graph_driver: GraphDriver | None = None,
):
"""
Initialize a Graphiti instance.

View file

@ -27,7 +27,7 @@ from typing_extensions import LiteralString
load_dotenv()
DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', None)
DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', 'neo4j')
USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False))
SEMAPHORE_LIMIT = int(os.getenv('SEMAPHORE_LIMIT', 20))
MAX_REFLEXION_ITERATIONS = int(os.getenv('MAX_REFLEXION_ITERATIONS', 0))

View file

@ -345,8 +345,8 @@ class EntityNode(Node):
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
query = (
"""
MATCH (n:Entity {uuid: $uuid})
"""
MATCH (n:Entity {uuid: $uuid})
"""
+ ENTITY_NODE_RETURN
)
records, _, _ = await driver.execute_query(
@ -542,8 +542,8 @@ class CommunityNode(Node):
def get_episodic_node_from_record(record: Any) -> EpisodicNode:
return EpisodicNode(
content=record['content'],
created_at=parse_db_date(record['created_at']).timestamp(),
valid_at=(parse_db_date(record['valid_at'])),
created_at=parse_db_date(record['created_at']), # type: ignore
valid_at=parse_db_date(record['valid_at']), # type: ignore
uuid=record['uuid'],
group_id=record['group_id'],
source=EpisodeType.from_str(record['source']),
@ -559,7 +559,7 @@ def get_entity_node_from_record(record: Any) -> EntityNode:
name=record['name'],
group_id=record['group_id'],
labels=record['labels'],
created_at=parse_db_date(record['created_at']),
created_at=parse_db_date(record['created_at']), # type: ignore
summary=record['summary'],
attributes=record['attributes'],
)
@ -580,7 +580,7 @@ def get_community_node_from_record(record: Any) -> CommunityNode:
name=record['name'],
group_id=record['group_id'],
name_embedding=record['name_embedding'],
created_at=parse_db_date(record['created_at']),
created_at=parse_db_date(record['created_at']), # type: ignore
summary=record['summary'],
)

View file

@ -167,7 +167,7 @@ async def edge_fulltext_search(
filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
query = (
get_relationships_query(driver.provider, 'edge_name_and_fact', '$query')
get_relationships_query('edge_name_and_fact', db_type=driver.provider)
+ """
YIELD relationship AS rel, score
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
@ -301,12 +301,12 @@ async def edge_bfs_search(
query = (
"""
UNWIND $bfs_origin_node_uuids AS origin_uuid
MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
UNWIND relationships(path) AS rel
MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity)
WHERE r.uuid = rel.uuid
"""
UNWIND $bfs_origin_node_uuids AS origin_uuid
MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
UNWIND relationships(path) AS rel
MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity)
WHERE r.uuid = rel.uuid
"""
+ filter_query
+ """
RETURN DISTINCT
@ -455,10 +455,10 @@ async def node_bfs_search(
query = (
"""
UNWIND $bfs_origin_node_uuids AS origin_uuid
MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
WHERE n.group_id = origin.group_id
"""
UNWIND $bfs_origin_node_uuids AS origin_uuid
MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
WHERE n.group_id = origin.group_id
"""
+ filter_query
+ ENTITY_NODE_RETURN
+ """

View file

@ -232,7 +232,7 @@ async def determine_entity_community(
driver: GraphDriver, entity: EntityNode
) -> tuple[CommunityNode | None, bool]:
# Check if the node is already part of a community
records, _, _ = driver.execute_query(
records, _, _ = await driver.execute_query(
"""
MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity {uuid: $entity_uuid})
RETURN
@ -250,7 +250,7 @@ async def determine_entity_community(
return get_community_node_from_record(records[0]), False
# If the node has no community, add it to the mode community of surrounding entities
records, _, _ = driver.execute_query(
records, _, _ = await driver.execute_query(
"""
MATCH (c:Community)-[:HAS_MEMBER]->(m:Entity)-[:RELATES_TO]-(n:Entity {uuid: $entity_uuid})
RETURN

View file

@ -85,3 +85,9 @@ ignore = ["E501"]
quote-style = "single"
indent-style = "space"
docstring-code-format = true
[mypy-falkordb]
ignore_missing_imports = true
[mypy-falkordb.asyncio]
ignore_missing_imports = true