diff --git a/graphiti_core/driver/driver.py b/graphiti_core/driver/driver.py index 17e3329a..35d01775 100644 --- a/graphiti_core/driver/driver.py +++ b/graphiti_core/driver/driver.py @@ -25,10 +25,26 @@ logger = logging.getLogger(__name__) class GraphDriverSession(ABC): + async def __aenter__(self): + return self + + @abstractmethod + async def __aexit__(self, exc_type, exc, tb): + # No cleanup needed for Falkor, but method must exist + pass + @abstractmethod async def run(self, query: str, **kwargs: Any) -> Any: raise NotImplementedError() + @abstractmethod + async def close(self): + raise NotImplementedError() + + @abstractmethod + async def execute_write(self, func, *args, **kwargs): + raise NotImplementedError() + class GraphDriver(ABC): provider: str @@ -42,40 +58,9 @@ class GraphDriver(ABC): raise NotImplementedError() @abstractmethod - def close(self) -> None: + def close(self): raise NotImplementedError() @abstractmethod def delete_all_indexes(self, database_: str = DEFAULT_DATABASE) -> Coroutine: raise NotImplementedError() - - -# class GraphDriver: -# _driver: GraphClient -# -# def __init__( -# self, -# uri: str, -# user: str, -# password: str, -# ): -# if uri.startswith('falkor'): -# # FalkorDB -# self._driver = FalkorClient(uri, user, password) -# self.provider = 'falkordb' -# else: -# # Neo4j -# self._driver = Neo4jClient(uri, user, password) -# self.provider = 'neo4j' -# -# def execute_query(self, cypher_query_, **kwargs: Any) -> Coroutine: -# return self._driver.execute_query(cypher_query_, **kwargs) -# -# async def close(self): -# return await self._driver.close() -# -# def delete_all_indexes(self, database_: str = DEFAULT_DATABASE) -> Coroutine: -# return self._driver.delete_all_indexes(database_) -# -# def session(self, database: str) -> GraphClientSession: -# return self._driver.session(database) diff --git a/graphiti_core/driver/falkordb_driver.py b/graphiti_core/driver/falkordb_driver.py index 6dbd1a0a..17369660 100644 --- a/graphiti_core/driver/falkordb_driver.py +++ b/graphiti_core/driver/falkordb_driver.py @@ -19,8 +19,8 @@ from collections.abc import Coroutine from datetime import datetime from typing import Any -from falkordb import Graph as FalkorGraph -from falkordb.asyncio import FalkorDB +from falkordb import Graph as FalkorGraph # type: ignore +from falkordb.asyncio import FalkorDB # type: ignore from graphiti_core.driver.driver import GraphDriver, GraphDriverSession from graphiti_core.helpers import DEFAULT_DATABASE @@ -28,7 +28,7 @@ from graphiti_core.helpers import DEFAULT_DATABASE logger = logging.getLogger(__name__) -class FalkorClientSession(GraphDriverSession): +class FalkorDriverSession(GraphDriverSession): def __init__(self, graph: FalkorGraph): self.graph = graph @@ -47,16 +47,16 @@ class FalkorClientSession(GraphDriverSession): # Directly await the provided async function with `self` as the transaction/session return await func(self, *args, **kwargs) - async def run(self, cypher_query_: str | list, **kwargs: Any) -> Any: + async def run(self, query: str | list, **kwargs: Any) -> Any: # FalkorDB does not support argument for Label Set, so it's converted into an array of queries - if isinstance(cypher_query_, list): - for cypher, params in cypher_query_: + if isinstance(query, list): + for cypher, params in query: params = convert_datetimes_to_strings(params) await self.graph.query(str(cypher), params) else: params = dict(kwargs) params = convert_datetimes_to_strings(params) - await self.graph.query(str(cypher_query_), params) + await self.graph.query(str(query), params) # Assuming `graph.query` is async (ideal); otherwise, wrap in executor return None @@ -79,7 +79,7 @@ class FalkorDriver(GraphDriver): url=uri, ) - def _get_graph(self, graph_name: str) -> FalkorGraph: + def _get_graph(self, graph_name: str | None) -> FalkorGraph: # FalkorDB requires a non-None database name for multi-tenant graphs; the default is "DEFAULT_DATABASE" if graph_name is None: graph_name = 'DEFAULT_DATABASE' @@ -106,13 +106,13 @@ class FalkorDriver(GraphDriver): header = [h[1].decode('utf-8') for h in result.header] return result.result_set, header, None - def session(self, database: str) -> GraphDriverSession: - return FalkorClientSession(self._get_graph(database)) + def session(self, database: str | None) -> GraphDriverSession: + return FalkorDriverSession(self._get_graph(database)) async def close(self) -> None: await self.client.connection.close() - def delete_all_indexes(self, database_: str = DEFAULT_DATABASE) -> Coroutine: + async def delete_all_indexes(self, database_: str = DEFAULT_DATABASE) -> Coroutine: return self.execute_query( 'CALL db.indexes() YIELD name DROP INDEX name', database_=database_, diff --git a/graphiti_core/driver/neo4j_driver.py b/graphiti_core/driver/neo4j_driver.py index 8b9058f0..9dd48fbf 100644 --- a/graphiti_core/driver/neo4j_driver.py +++ b/graphiti_core/driver/neo4j_driver.py @@ -33,13 +33,13 @@ class Neo4jDriver(GraphDriver): def __init__( self, uri: str, - user: str, - password: str, + user: str | None, + password: str | None, ): super().__init__() self.client = AsyncGraphDatabase.driver( uri=uri, - auth=(user, password), + auth=(user or '', password or ''), ) async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> Coroutine: diff --git a/graphiti_core/edges.py b/graphiti_core/edges.py index f2491e99..4a7fdaa1 100644 --- a/graphiti_core/edges.py +++ b/graphiti_core/edges.py @@ -345,8 +345,8 @@ class EntityEdge(Edge): async def get_by_node_uuid(cls, driver: GraphDriver, node_uuid: str): query: LiteralString = ( """ - MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity) - """ + MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity) + """ + ENTITY_EDGE_RETURN ) records, _, _ = await driver.execute_query( @@ -463,7 +463,7 @@ def get_episodic_edge_from_record(record: Any) -> EpisodicEdge: group_id=record['group_id'], source_node_uuid=record['source_node_uuid'], target_node_uuid=record['target_node_uuid'], - created_at=parse_db_date(record['created_at']), + created_at=parse_db_date(record['created_at']), # type: ignore ) @@ -476,7 +476,7 @@ def get_entity_edge_from_record(record: Any) -> EntityEdge: name=record['name'], group_id=record['group_id'], episodes=record['episodes'], - created_at=parse_db_date(record['created_at']), + created_at=parse_db_date(record['created_at']), # type: ignore expired_at=parse_db_date(record['expired_at']), valid_at=parse_db_date(record['valid_at']), invalid_at=parse_db_date(record['invalid_at']), @@ -504,7 +504,7 @@ def get_community_edge_from_record(record: Any): group_id=record['group_id'], source_node_uuid=record['source_node_uuid'], target_node_uuid=record['target_node_uuid'], - created_at=parse_db_date(record['created_at']), + created_at=parse_db_date(record['created_at']), # type: ignore ) diff --git a/graphiti_core/graph_queries.py b/graphiti_core/graph_queries.py index 885ef669..10e396c0 100644 --- a/graphiti_core/graph_queries.py +++ b/graphiti_core/graph_queries.py @@ -5,6 +5,8 @@ This module provides database-agnostic query generation for Neo4j and FalkorDB, supporting index creation, fulltext search, and bulk operations. """ +from typing import Any + from typing_extensions import LiteralString from graphiti_core.models.edges.edge_db_queries import ( @@ -84,7 +86,7 @@ def get_fulltext_indices(db_type: str = 'neo4j') -> list[LiteralString]: ] -def get_nodes_query(db_type: str = 'neo4j', name: str = None, query: str = None) -> str: +def get_nodes_query(db_type: str = 'neo4j', name: str = '', query: str | None = None) -> str: if db_type == 'falkordb': label = NEO4J_TO_FALKORDB_MAPPING[name] return f"CALL db.idx.fulltext.queryNodes('{label}', {query})" @@ -100,7 +102,7 @@ def get_vector_cosine_func_query(vec1, vec2, db_type: str = 'neo4j') -> str: return f'vector.similarity.cosine({vec1}, {vec2})' -def get_relationships_query(db_type: str = 'neo4j', name: str = None, query: str = None) -> str: +def get_relationships_query(name: str, db_type: str = 'neo4j') -> str: if db_type == 'falkordb': label = NEO4J_TO_FALKORDB_MAPPING[name] return f"CALL db.idx.fulltext.queryRelationships('{label}', $query)" @@ -108,7 +110,7 @@ def get_relationships_query(db_type: str = 'neo4j', name: str = None, query: str return f'CALL db.index.fulltext.queryRelationships("{name}", $query, {{limit: $limit}})' -def get_entity_node_save_bulk_query(nodes, db_type: str = 'neo4j') -> str: +def get_entity_node_save_bulk_query(nodes, db_type: str = 'neo4j') -> str | Any: if db_type == 'falkordb': queries = [] for node in nodes: diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index 5410283e..371c520f 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -95,13 +95,13 @@ class Graphiti: def __init__( self, uri: str, - user: str = None, - password: str = None, + user: str | None = None, + password: str | None = None, llm_client: LLMClient | None = None, embedder: EmbedderClient | None = None, cross_encoder: CrossEncoderClient | None = None, store_raw_episode_content: bool = True, - graph_driver: GraphDriver = None, + graph_driver: GraphDriver | None = None, ): """ Initialize a Graphiti instance. diff --git a/graphiti_core/helpers.py b/graphiti_core/helpers.py index c7d8cd40..de5020a9 100644 --- a/graphiti_core/helpers.py +++ b/graphiti_core/helpers.py @@ -27,7 +27,7 @@ from typing_extensions import LiteralString load_dotenv() -DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', None) +DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', 'neo4j') USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False)) SEMAPHORE_LIMIT = int(os.getenv('SEMAPHORE_LIMIT', 20)) MAX_REFLEXION_ITERATIONS = int(os.getenv('MAX_REFLEXION_ITERATIONS', 0)) diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py index 945a07b2..fd15499c 100644 --- a/graphiti_core/nodes.py +++ b/graphiti_core/nodes.py @@ -345,8 +345,8 @@ class EntityNode(Node): async def get_by_uuid(cls, driver: GraphDriver, uuid: str): query = ( """ - MATCH (n:Entity {uuid: $uuid}) - """ + MATCH (n:Entity {uuid: $uuid}) + """ + ENTITY_NODE_RETURN ) records, _, _ = await driver.execute_query( @@ -542,8 +542,8 @@ class CommunityNode(Node): def get_episodic_node_from_record(record: Any) -> EpisodicNode: return EpisodicNode( content=record['content'], - created_at=parse_db_date(record['created_at']).timestamp(), - valid_at=(parse_db_date(record['valid_at'])), + created_at=parse_db_date(record['created_at']), # type: ignore + valid_at=parse_db_date(record['valid_at']), # type: ignore uuid=record['uuid'], group_id=record['group_id'], source=EpisodeType.from_str(record['source']), @@ -559,7 +559,7 @@ def get_entity_node_from_record(record: Any) -> EntityNode: name=record['name'], group_id=record['group_id'], labels=record['labels'], - created_at=parse_db_date(record['created_at']), + created_at=parse_db_date(record['created_at']), # type: ignore summary=record['summary'], attributes=record['attributes'], ) @@ -580,7 +580,7 @@ def get_community_node_from_record(record: Any) -> CommunityNode: name=record['name'], group_id=record['group_id'], name_embedding=record['name_embedding'], - created_at=parse_db_date(record['created_at']), + created_at=parse_db_date(record['created_at']), # type: ignore summary=record['summary'], ) diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index 6ef2cd32..6505e8b9 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -167,7 +167,7 @@ async def edge_fulltext_search( filter_query, filter_params = edge_search_filter_query_constructor(search_filter) query = ( - get_relationships_query(driver.provider, 'edge_name_and_fact', '$query') + get_relationships_query('edge_name_and_fact', db_type=driver.provider) + """ YIELD relationship AS rel, score MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity) @@ -301,12 +301,12 @@ async def edge_bfs_search( query = ( """ - UNWIND $bfs_origin_node_uuids AS origin_uuid - MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity) - UNWIND relationships(path) AS rel - MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity) - WHERE r.uuid = rel.uuid - """ + UNWIND $bfs_origin_node_uuids AS origin_uuid + MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity) + UNWIND relationships(path) AS rel + MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity) + WHERE r.uuid = rel.uuid + """ + filter_query + """ RETURN DISTINCT @@ -455,10 +455,10 @@ async def node_bfs_search( query = ( """ - UNWIND $bfs_origin_node_uuids AS origin_uuid - MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity) - WHERE n.group_id = origin.group_id - """ + UNWIND $bfs_origin_node_uuids AS origin_uuid + MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity) + WHERE n.group_id = origin.group_id + """ + filter_query + ENTITY_NODE_RETURN + """ diff --git a/graphiti_core/utils/maintenance/community_operations.py b/graphiti_core/utils/maintenance/community_operations.py index cca3fe00..049444f0 100644 --- a/graphiti_core/utils/maintenance/community_operations.py +++ b/graphiti_core/utils/maintenance/community_operations.py @@ -232,7 +232,7 @@ async def determine_entity_community( driver: GraphDriver, entity: EntityNode ) -> tuple[CommunityNode | None, bool]: # Check if the node is already part of a community - records, _, _ = driver.execute_query( + records, _, _ = await driver.execute_query( """ MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity {uuid: $entity_uuid}) RETURN @@ -250,7 +250,7 @@ async def determine_entity_community( return get_community_node_from_record(records[0]), False # If the node has no community, add it to the mode community of surrounding entities - records, _, _ = driver.execute_query( + records, _, _ = await driver.execute_query( """ MATCH (c:Community)-[:HAS_MEMBER]->(m:Entity)-[:RELATES_TO]-(n:Entity {uuid: $entity_uuid}) RETURN diff --git a/pyproject.toml b/pyproject.toml index 925e8a0b..146f59d3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -85,3 +85,9 @@ ignore = ["E501"] quote-style = "single" indent-style = "space" docstring-code-format = true + +[mypy-falkordb] +ignore_missing_imports = true + +[mypy-falkordb.asyncio] +ignore_missing_imports = true