update driver (#583)

* update driver

* mypy updates

* mypy updates

* mypy updates

* Update graphiti_core/graph_queries.py

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>

* mypy updates

* mypy

* mypy updates

* mypy updates

* mypy updates

* mypy updates

---------

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
This commit is contained in:
Preston Rasmussen 2025-06-13 14:12:09 -04:00 committed by GitHub
parent 12b90633a4
commit 19fde653a6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 70 additions and 77 deletions

View file

@ -25,10 +25,26 @@ logger = logging.getLogger(__name__)
class GraphDriverSession(ABC): class GraphDriverSession(ABC):
async def __aenter__(self):
return self
@abstractmethod
async def __aexit__(self, exc_type, exc, tb):
# No cleanup needed for Falkor, but method must exist
pass
@abstractmethod @abstractmethod
async def run(self, query: str, **kwargs: Any) -> Any: async def run(self, query: str, **kwargs: Any) -> Any:
raise NotImplementedError() raise NotImplementedError()
@abstractmethod
async def close(self):
raise NotImplementedError()
@abstractmethod
async def execute_write(self, func, *args, **kwargs):
raise NotImplementedError()
class GraphDriver(ABC): class GraphDriver(ABC):
provider: str provider: str
@ -42,40 +58,9 @@ class GraphDriver(ABC):
raise NotImplementedError() raise NotImplementedError()
@abstractmethod @abstractmethod
def close(self) -> None: def close(self):
raise NotImplementedError() raise NotImplementedError()
@abstractmethod @abstractmethod
def delete_all_indexes(self, database_: str = DEFAULT_DATABASE) -> Coroutine: def delete_all_indexes(self, database_: str = DEFAULT_DATABASE) -> Coroutine:
raise NotImplementedError() raise NotImplementedError()
# class GraphDriver:
# _driver: GraphClient
#
# def __init__(
# self,
# uri: str,
# user: str,
# password: str,
# ):
# if uri.startswith('falkor'):
# # FalkorDB
# self._driver = FalkorClient(uri, user, password)
# self.provider = 'falkordb'
# else:
# # Neo4j
# self._driver = Neo4jClient(uri, user, password)
# self.provider = 'neo4j'
#
# def execute_query(self, cypher_query_, **kwargs: Any) -> Coroutine:
# return self._driver.execute_query(cypher_query_, **kwargs)
#
# async def close(self):
# return await self._driver.close()
#
# def delete_all_indexes(self, database_: str = DEFAULT_DATABASE) -> Coroutine:
# return self._driver.delete_all_indexes(database_)
#
# def session(self, database: str) -> GraphClientSession:
# return self._driver.session(database)

View file

@ -19,8 +19,8 @@ from collections.abc import Coroutine
from datetime import datetime from datetime import datetime
from typing import Any from typing import Any
from falkordb import Graph as FalkorGraph from falkordb import Graph as FalkorGraph # type: ignore
from falkordb.asyncio import FalkorDB from falkordb.asyncio import FalkorDB # type: ignore
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
from graphiti_core.helpers import DEFAULT_DATABASE from graphiti_core.helpers import DEFAULT_DATABASE
@ -28,7 +28,7 @@ from graphiti_core.helpers import DEFAULT_DATABASE
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class FalkorClientSession(GraphDriverSession): class FalkorDriverSession(GraphDriverSession):
def __init__(self, graph: FalkorGraph): def __init__(self, graph: FalkorGraph):
self.graph = graph self.graph = graph
@ -47,16 +47,16 @@ class FalkorClientSession(GraphDriverSession):
# Directly await the provided async function with `self` as the transaction/session # Directly await the provided async function with `self` as the transaction/session
return await func(self, *args, **kwargs) return await func(self, *args, **kwargs)
async def run(self, cypher_query_: str | list, **kwargs: Any) -> Any: async def run(self, query: str | list, **kwargs: Any) -> Any:
# FalkorDB does not support argument for Label Set, so it's converted into an array of queries # FalkorDB does not support argument for Label Set, so it's converted into an array of queries
if isinstance(cypher_query_, list): if isinstance(query, list):
for cypher, params in cypher_query_: for cypher, params in query:
params = convert_datetimes_to_strings(params) params = convert_datetimes_to_strings(params)
await self.graph.query(str(cypher), params) await self.graph.query(str(cypher), params)
else: else:
params = dict(kwargs) params = dict(kwargs)
params = convert_datetimes_to_strings(params) params = convert_datetimes_to_strings(params)
await self.graph.query(str(cypher_query_), params) await self.graph.query(str(query), params)
# Assuming `graph.query` is async (ideal); otherwise, wrap in executor # Assuming `graph.query` is async (ideal); otherwise, wrap in executor
return None return None
@ -79,7 +79,7 @@ class FalkorDriver(GraphDriver):
url=uri, url=uri,
) )
def _get_graph(self, graph_name: str) -> FalkorGraph: def _get_graph(self, graph_name: str | None) -> FalkorGraph:
# FalkorDB requires a non-None database name for multi-tenant graphs; the default is "DEFAULT_DATABASE" # FalkorDB requires a non-None database name for multi-tenant graphs; the default is "DEFAULT_DATABASE"
if graph_name is None: if graph_name is None:
graph_name = 'DEFAULT_DATABASE' graph_name = 'DEFAULT_DATABASE'
@ -106,13 +106,13 @@ class FalkorDriver(GraphDriver):
header = [h[1].decode('utf-8') for h in result.header] header = [h[1].decode('utf-8') for h in result.header]
return result.result_set, header, None return result.result_set, header, None
def session(self, database: str) -> GraphDriverSession: def session(self, database: str | None) -> GraphDriverSession:
return FalkorClientSession(self._get_graph(database)) return FalkorDriverSession(self._get_graph(database))
async def close(self) -> None: async def close(self) -> None:
await self.client.connection.close() await self.client.connection.close()
def delete_all_indexes(self, database_: str = DEFAULT_DATABASE) -> Coroutine: async def delete_all_indexes(self, database_: str = DEFAULT_DATABASE) -> Coroutine:
return self.execute_query( return self.execute_query(
'CALL db.indexes() YIELD name DROP INDEX name', 'CALL db.indexes() YIELD name DROP INDEX name',
database_=database_, database_=database_,

View file

@ -33,13 +33,13 @@ class Neo4jDriver(GraphDriver):
def __init__( def __init__(
self, self,
uri: str, uri: str,
user: str, user: str | None,
password: str, password: str | None,
): ):
super().__init__() super().__init__()
self.client = AsyncGraphDatabase.driver( self.client = AsyncGraphDatabase.driver(
uri=uri, uri=uri,
auth=(user, password), auth=(user or '', password or ''),
) )
async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> Coroutine: async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> Coroutine:

View file

@ -345,8 +345,8 @@ class EntityEdge(Edge):
async def get_by_node_uuid(cls, driver: GraphDriver, node_uuid: str): async def get_by_node_uuid(cls, driver: GraphDriver, node_uuid: str):
query: LiteralString = ( query: LiteralString = (
""" """
MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity) MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
""" """
+ ENTITY_EDGE_RETURN + ENTITY_EDGE_RETURN
) )
records, _, _ = await driver.execute_query( records, _, _ = await driver.execute_query(
@ -463,7 +463,7 @@ def get_episodic_edge_from_record(record: Any) -> EpisodicEdge:
group_id=record['group_id'], group_id=record['group_id'],
source_node_uuid=record['source_node_uuid'], source_node_uuid=record['source_node_uuid'],
target_node_uuid=record['target_node_uuid'], target_node_uuid=record['target_node_uuid'],
created_at=parse_db_date(record['created_at']), created_at=parse_db_date(record['created_at']), # type: ignore
) )
@ -476,7 +476,7 @@ def get_entity_edge_from_record(record: Any) -> EntityEdge:
name=record['name'], name=record['name'],
group_id=record['group_id'], group_id=record['group_id'],
episodes=record['episodes'], episodes=record['episodes'],
created_at=parse_db_date(record['created_at']), created_at=parse_db_date(record['created_at']), # type: ignore
expired_at=parse_db_date(record['expired_at']), expired_at=parse_db_date(record['expired_at']),
valid_at=parse_db_date(record['valid_at']), valid_at=parse_db_date(record['valid_at']),
invalid_at=parse_db_date(record['invalid_at']), invalid_at=parse_db_date(record['invalid_at']),
@ -504,7 +504,7 @@ def get_community_edge_from_record(record: Any):
group_id=record['group_id'], group_id=record['group_id'],
source_node_uuid=record['source_node_uuid'], source_node_uuid=record['source_node_uuid'],
target_node_uuid=record['target_node_uuid'], target_node_uuid=record['target_node_uuid'],
created_at=parse_db_date(record['created_at']), created_at=parse_db_date(record['created_at']), # type: ignore
) )

View file

@ -5,6 +5,8 @@ This module provides database-agnostic query generation for Neo4j and FalkorDB,
supporting index creation, fulltext search, and bulk operations. supporting index creation, fulltext search, and bulk operations.
""" """
from typing import Any
from typing_extensions import LiteralString from typing_extensions import LiteralString
from graphiti_core.models.edges.edge_db_queries import ( from graphiti_core.models.edges.edge_db_queries import (
@ -84,7 +86,7 @@ def get_fulltext_indices(db_type: str = 'neo4j') -> list[LiteralString]:
] ]
def get_nodes_query(db_type: str = 'neo4j', name: str = None, query: str = None) -> str: def get_nodes_query(db_type: str = 'neo4j', name: str = '', query: str | None = None) -> str:
if db_type == 'falkordb': if db_type == 'falkordb':
label = NEO4J_TO_FALKORDB_MAPPING[name] label = NEO4J_TO_FALKORDB_MAPPING[name]
return f"CALL db.idx.fulltext.queryNodes('{label}', {query})" return f"CALL db.idx.fulltext.queryNodes('{label}', {query})"
@ -100,7 +102,7 @@ def get_vector_cosine_func_query(vec1, vec2, db_type: str = 'neo4j') -> str:
return f'vector.similarity.cosine({vec1}, {vec2})' return f'vector.similarity.cosine({vec1}, {vec2})'
def get_relationships_query(db_type: str = 'neo4j', name: str = None, query: str = None) -> str: def get_relationships_query(name: str, db_type: str = 'neo4j') -> str:
if db_type == 'falkordb': if db_type == 'falkordb':
label = NEO4J_TO_FALKORDB_MAPPING[name] label = NEO4J_TO_FALKORDB_MAPPING[name]
return f"CALL db.idx.fulltext.queryRelationships('{label}', $query)" return f"CALL db.idx.fulltext.queryRelationships('{label}', $query)"
@ -108,7 +110,7 @@ def get_relationships_query(db_type: str = 'neo4j', name: str = None, query: str
return f'CALL db.index.fulltext.queryRelationships("{name}", $query, {{limit: $limit}})' return f'CALL db.index.fulltext.queryRelationships("{name}", $query, {{limit: $limit}})'
def get_entity_node_save_bulk_query(nodes, db_type: str = 'neo4j') -> str: def get_entity_node_save_bulk_query(nodes, db_type: str = 'neo4j') -> str | Any:
if db_type == 'falkordb': if db_type == 'falkordb':
queries = [] queries = []
for node in nodes: for node in nodes:

View file

@ -95,13 +95,13 @@ class Graphiti:
def __init__( def __init__(
self, self,
uri: str, uri: str,
user: str = None, user: str | None = None,
password: str = None, password: str | None = None,
llm_client: LLMClient | None = None, llm_client: LLMClient | None = None,
embedder: EmbedderClient | None = None, embedder: EmbedderClient | None = None,
cross_encoder: CrossEncoderClient | None = None, cross_encoder: CrossEncoderClient | None = None,
store_raw_episode_content: bool = True, store_raw_episode_content: bool = True,
graph_driver: GraphDriver = None, graph_driver: GraphDriver | None = None,
): ):
""" """
Initialize a Graphiti instance. Initialize a Graphiti instance.

View file

@ -27,7 +27,7 @@ from typing_extensions import LiteralString
load_dotenv() load_dotenv()
DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', None) DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', 'neo4j')
USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False)) USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False))
SEMAPHORE_LIMIT = int(os.getenv('SEMAPHORE_LIMIT', 20)) SEMAPHORE_LIMIT = int(os.getenv('SEMAPHORE_LIMIT', 20))
MAX_REFLEXION_ITERATIONS = int(os.getenv('MAX_REFLEXION_ITERATIONS', 0)) MAX_REFLEXION_ITERATIONS = int(os.getenv('MAX_REFLEXION_ITERATIONS', 0))

View file

@ -345,8 +345,8 @@ class EntityNode(Node):
async def get_by_uuid(cls, driver: GraphDriver, uuid: str): async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
query = ( query = (
""" """
MATCH (n:Entity {uuid: $uuid}) MATCH (n:Entity {uuid: $uuid})
""" """
+ ENTITY_NODE_RETURN + ENTITY_NODE_RETURN
) )
records, _, _ = await driver.execute_query( records, _, _ = await driver.execute_query(
@ -542,8 +542,8 @@ class CommunityNode(Node):
def get_episodic_node_from_record(record: Any) -> EpisodicNode: def get_episodic_node_from_record(record: Any) -> EpisodicNode:
return EpisodicNode( return EpisodicNode(
content=record['content'], content=record['content'],
created_at=parse_db_date(record['created_at']).timestamp(), created_at=parse_db_date(record['created_at']), # type: ignore
valid_at=(parse_db_date(record['valid_at'])), valid_at=parse_db_date(record['valid_at']), # type: ignore
uuid=record['uuid'], uuid=record['uuid'],
group_id=record['group_id'], group_id=record['group_id'],
source=EpisodeType.from_str(record['source']), source=EpisodeType.from_str(record['source']),
@ -559,7 +559,7 @@ def get_entity_node_from_record(record: Any) -> EntityNode:
name=record['name'], name=record['name'],
group_id=record['group_id'], group_id=record['group_id'],
labels=record['labels'], labels=record['labels'],
created_at=parse_db_date(record['created_at']), created_at=parse_db_date(record['created_at']), # type: ignore
summary=record['summary'], summary=record['summary'],
attributes=record['attributes'], attributes=record['attributes'],
) )
@ -580,7 +580,7 @@ def get_community_node_from_record(record: Any) -> CommunityNode:
name=record['name'], name=record['name'],
group_id=record['group_id'], group_id=record['group_id'],
name_embedding=record['name_embedding'], name_embedding=record['name_embedding'],
created_at=parse_db_date(record['created_at']), created_at=parse_db_date(record['created_at']), # type: ignore
summary=record['summary'], summary=record['summary'],
) )

View file

@ -167,7 +167,7 @@ async def edge_fulltext_search(
filter_query, filter_params = edge_search_filter_query_constructor(search_filter) filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
query = ( query = (
get_relationships_query(driver.provider, 'edge_name_and_fact', '$query') get_relationships_query('edge_name_and_fact', db_type=driver.provider)
+ """ + """
YIELD relationship AS rel, score YIELD relationship AS rel, score
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity) MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
@ -301,12 +301,12 @@ async def edge_bfs_search(
query = ( query = (
""" """
UNWIND $bfs_origin_node_uuids AS origin_uuid UNWIND $bfs_origin_node_uuids AS origin_uuid
MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity) MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
UNWIND relationships(path) AS rel UNWIND relationships(path) AS rel
MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity) MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity)
WHERE r.uuid = rel.uuid WHERE r.uuid = rel.uuid
""" """
+ filter_query + filter_query
+ """ + """
RETURN DISTINCT RETURN DISTINCT
@ -455,10 +455,10 @@ async def node_bfs_search(
query = ( query = (
""" """
UNWIND $bfs_origin_node_uuids AS origin_uuid UNWIND $bfs_origin_node_uuids AS origin_uuid
MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity) MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
WHERE n.group_id = origin.group_id WHERE n.group_id = origin.group_id
""" """
+ filter_query + filter_query
+ ENTITY_NODE_RETURN + ENTITY_NODE_RETURN
+ """ + """

View file

@ -232,7 +232,7 @@ async def determine_entity_community(
driver: GraphDriver, entity: EntityNode driver: GraphDriver, entity: EntityNode
) -> tuple[CommunityNode | None, bool]: ) -> tuple[CommunityNode | None, bool]:
# Check if the node is already part of a community # Check if the node is already part of a community
records, _, _ = driver.execute_query( records, _, _ = await driver.execute_query(
""" """
MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity {uuid: $entity_uuid}) MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity {uuid: $entity_uuid})
RETURN RETURN
@ -250,7 +250,7 @@ async def determine_entity_community(
return get_community_node_from_record(records[0]), False return get_community_node_from_record(records[0]), False
# If the node has no community, add it to the mode community of surrounding entities # If the node has no community, add it to the mode community of surrounding entities
records, _, _ = driver.execute_query( records, _, _ = await driver.execute_query(
""" """
MATCH (c:Community)-[:HAS_MEMBER]->(m:Entity)-[:RELATES_TO]-(n:Entity {uuid: $entity_uuid}) MATCH (c:Community)-[:HAS_MEMBER]->(m:Entity)-[:RELATES_TO]-(n:Entity {uuid: $entity_uuid})
RETURN RETURN

View file

@ -85,3 +85,9 @@ ignore = ["E501"]
quote-style = "single" quote-style = "single"
indent-style = "space" indent-style = "space"
docstring-code-format = true docstring-code-format = true
[mypy-falkordb]
ignore_missing_imports = true
[mypy-falkordb.asyncio]
ignore_missing_imports = true