update driver (#583)
* update driver * mypy updates * mypy updates * mypy updates * Update graphiti_core/graph_queries.py Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com> * mypy updates * mypy * mypy updates * mypy updates * mypy updates * mypy updates --------- Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
This commit is contained in:
parent
12b90633a4
commit
19fde653a6
11 changed files with 70 additions and 77 deletions
|
|
@ -25,10 +25,26 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class GraphDriverSession(ABC):
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
@abstractmethod
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
# No cleanup needed for Falkor, but method must exist
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def run(self, query: str, **kwargs: Any) -> Any:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
async def close(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
async def execute_write(self, func, *args, **kwargs):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class GraphDriver(ABC):
|
||||
provider: str
|
||||
|
|
@ -42,40 +58,9 @@ class GraphDriver(ABC):
|
|||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def close(self) -> None:
|
||||
def close(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def delete_all_indexes(self, database_: str = DEFAULT_DATABASE) -> Coroutine:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
# class GraphDriver:
|
||||
# _driver: GraphClient
|
||||
#
|
||||
# def __init__(
|
||||
# self,
|
||||
# uri: str,
|
||||
# user: str,
|
||||
# password: str,
|
||||
# ):
|
||||
# if uri.startswith('falkor'):
|
||||
# # FalkorDB
|
||||
# self._driver = FalkorClient(uri, user, password)
|
||||
# self.provider = 'falkordb'
|
||||
# else:
|
||||
# # Neo4j
|
||||
# self._driver = Neo4jClient(uri, user, password)
|
||||
# self.provider = 'neo4j'
|
||||
#
|
||||
# def execute_query(self, cypher_query_, **kwargs: Any) -> Coroutine:
|
||||
# return self._driver.execute_query(cypher_query_, **kwargs)
|
||||
#
|
||||
# async def close(self):
|
||||
# return await self._driver.close()
|
||||
#
|
||||
# def delete_all_indexes(self, database_: str = DEFAULT_DATABASE) -> Coroutine:
|
||||
# return self._driver.delete_all_indexes(database_)
|
||||
#
|
||||
# def session(self, database: str) -> GraphClientSession:
|
||||
# return self._driver.session(database)
|
||||
|
|
|
|||
|
|
@ -19,8 +19,8 @@ from collections.abc import Coroutine
|
|||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from falkordb import Graph as FalkorGraph
|
||||
from falkordb.asyncio import FalkorDB
|
||||
from falkordb import Graph as FalkorGraph # type: ignore
|
||||
from falkordb.asyncio import FalkorDB # type: ignore
|
||||
|
||||
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
|
||||
from graphiti_core.helpers import DEFAULT_DATABASE
|
||||
|
|
@ -28,7 +28,7 @@ from graphiti_core.helpers import DEFAULT_DATABASE
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FalkorClientSession(GraphDriverSession):
|
||||
class FalkorDriverSession(GraphDriverSession):
|
||||
def __init__(self, graph: FalkorGraph):
|
||||
self.graph = graph
|
||||
|
||||
|
|
@ -47,16 +47,16 @@ class FalkorClientSession(GraphDriverSession):
|
|||
# Directly await the provided async function with `self` as the transaction/session
|
||||
return await func(self, *args, **kwargs)
|
||||
|
||||
async def run(self, cypher_query_: str | list, **kwargs: Any) -> Any:
|
||||
async def run(self, query: str | list, **kwargs: Any) -> Any:
|
||||
# FalkorDB does not support argument for Label Set, so it's converted into an array of queries
|
||||
if isinstance(cypher_query_, list):
|
||||
for cypher, params in cypher_query_:
|
||||
if isinstance(query, list):
|
||||
for cypher, params in query:
|
||||
params = convert_datetimes_to_strings(params)
|
||||
await self.graph.query(str(cypher), params)
|
||||
else:
|
||||
params = dict(kwargs)
|
||||
params = convert_datetimes_to_strings(params)
|
||||
await self.graph.query(str(cypher_query_), params)
|
||||
await self.graph.query(str(query), params)
|
||||
# Assuming `graph.query` is async (ideal); otherwise, wrap in executor
|
||||
return None
|
||||
|
||||
|
|
@ -79,7 +79,7 @@ class FalkorDriver(GraphDriver):
|
|||
url=uri,
|
||||
)
|
||||
|
||||
def _get_graph(self, graph_name: str) -> FalkorGraph:
|
||||
def _get_graph(self, graph_name: str | None) -> FalkorGraph:
|
||||
# FalkorDB requires a non-None database name for multi-tenant graphs; the default is "DEFAULT_DATABASE"
|
||||
if graph_name is None:
|
||||
graph_name = 'DEFAULT_DATABASE'
|
||||
|
|
@ -106,13 +106,13 @@ class FalkorDriver(GraphDriver):
|
|||
header = [h[1].decode('utf-8') for h in result.header]
|
||||
return result.result_set, header, None
|
||||
|
||||
def session(self, database: str) -> GraphDriverSession:
|
||||
return FalkorClientSession(self._get_graph(database))
|
||||
def session(self, database: str | None) -> GraphDriverSession:
|
||||
return FalkorDriverSession(self._get_graph(database))
|
||||
|
||||
async def close(self) -> None:
|
||||
await self.client.connection.close()
|
||||
|
||||
def delete_all_indexes(self, database_: str = DEFAULT_DATABASE) -> Coroutine:
|
||||
async def delete_all_indexes(self, database_: str = DEFAULT_DATABASE) -> Coroutine:
|
||||
return self.execute_query(
|
||||
'CALL db.indexes() YIELD name DROP INDEX name',
|
||||
database_=database_,
|
||||
|
|
|
|||
|
|
@ -33,13 +33,13 @@ class Neo4jDriver(GraphDriver):
|
|||
def __init__(
|
||||
self,
|
||||
uri: str,
|
||||
user: str,
|
||||
password: str,
|
||||
user: str | None,
|
||||
password: str | None,
|
||||
):
|
||||
super().__init__()
|
||||
self.client = AsyncGraphDatabase.driver(
|
||||
uri=uri,
|
||||
auth=(user, password),
|
||||
auth=(user or '', password or ''),
|
||||
)
|
||||
|
||||
async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> Coroutine:
|
||||
|
|
|
|||
|
|
@ -345,8 +345,8 @@ class EntityEdge(Edge):
|
|||
async def get_by_node_uuid(cls, driver: GraphDriver, node_uuid: str):
|
||||
query: LiteralString = (
|
||||
"""
|
||||
MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
|
||||
"""
|
||||
MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
|
||||
"""
|
||||
+ ENTITY_EDGE_RETURN
|
||||
)
|
||||
records, _, _ = await driver.execute_query(
|
||||
|
|
@ -463,7 +463,7 @@ def get_episodic_edge_from_record(record: Any) -> EpisodicEdge:
|
|||
group_id=record['group_id'],
|
||||
source_node_uuid=record['source_node_uuid'],
|
||||
target_node_uuid=record['target_node_uuid'],
|
||||
created_at=parse_db_date(record['created_at']),
|
||||
created_at=parse_db_date(record['created_at']), # type: ignore
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -476,7 +476,7 @@ def get_entity_edge_from_record(record: Any) -> EntityEdge:
|
|||
name=record['name'],
|
||||
group_id=record['group_id'],
|
||||
episodes=record['episodes'],
|
||||
created_at=parse_db_date(record['created_at']),
|
||||
created_at=parse_db_date(record['created_at']), # type: ignore
|
||||
expired_at=parse_db_date(record['expired_at']),
|
||||
valid_at=parse_db_date(record['valid_at']),
|
||||
invalid_at=parse_db_date(record['invalid_at']),
|
||||
|
|
@ -504,7 +504,7 @@ def get_community_edge_from_record(record: Any):
|
|||
group_id=record['group_id'],
|
||||
source_node_uuid=record['source_node_uuid'],
|
||||
target_node_uuid=record['target_node_uuid'],
|
||||
created_at=parse_db_date(record['created_at']),
|
||||
created_at=parse_db_date(record['created_at']), # type: ignore
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -5,6 +5,8 @@ This module provides database-agnostic query generation for Neo4j and FalkorDB,
|
|||
supporting index creation, fulltext search, and bulk operations.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from typing_extensions import LiteralString
|
||||
|
||||
from graphiti_core.models.edges.edge_db_queries import (
|
||||
|
|
@ -84,7 +86,7 @@ def get_fulltext_indices(db_type: str = 'neo4j') -> list[LiteralString]:
|
|||
]
|
||||
|
||||
|
||||
def get_nodes_query(db_type: str = 'neo4j', name: str = None, query: str = None) -> str:
|
||||
def get_nodes_query(db_type: str = 'neo4j', name: str = '', query: str | None = None) -> str:
|
||||
if db_type == 'falkordb':
|
||||
label = NEO4J_TO_FALKORDB_MAPPING[name]
|
||||
return f"CALL db.idx.fulltext.queryNodes('{label}', {query})"
|
||||
|
|
@ -100,7 +102,7 @@ def get_vector_cosine_func_query(vec1, vec2, db_type: str = 'neo4j') -> str:
|
|||
return f'vector.similarity.cosine({vec1}, {vec2})'
|
||||
|
||||
|
||||
def get_relationships_query(db_type: str = 'neo4j', name: str = None, query: str = None) -> str:
|
||||
def get_relationships_query(name: str, db_type: str = 'neo4j') -> str:
|
||||
if db_type == 'falkordb':
|
||||
label = NEO4J_TO_FALKORDB_MAPPING[name]
|
||||
return f"CALL db.idx.fulltext.queryRelationships('{label}', $query)"
|
||||
|
|
@ -108,7 +110,7 @@ def get_relationships_query(db_type: str = 'neo4j', name: str = None, query: str
|
|||
return f'CALL db.index.fulltext.queryRelationships("{name}", $query, {{limit: $limit}})'
|
||||
|
||||
|
||||
def get_entity_node_save_bulk_query(nodes, db_type: str = 'neo4j') -> str:
|
||||
def get_entity_node_save_bulk_query(nodes, db_type: str = 'neo4j') -> str | Any:
|
||||
if db_type == 'falkordb':
|
||||
queries = []
|
||||
for node in nodes:
|
||||
|
|
|
|||
|
|
@ -95,13 +95,13 @@ class Graphiti:
|
|||
def __init__(
|
||||
self,
|
||||
uri: str,
|
||||
user: str = None,
|
||||
password: str = None,
|
||||
user: str | None = None,
|
||||
password: str | None = None,
|
||||
llm_client: LLMClient | None = None,
|
||||
embedder: EmbedderClient | None = None,
|
||||
cross_encoder: CrossEncoderClient | None = None,
|
||||
store_raw_episode_content: bool = True,
|
||||
graph_driver: GraphDriver = None,
|
||||
graph_driver: GraphDriver | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize a Graphiti instance.
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ from typing_extensions import LiteralString
|
|||
|
||||
load_dotenv()
|
||||
|
||||
DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', None)
|
||||
DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', 'neo4j')
|
||||
USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False))
|
||||
SEMAPHORE_LIMIT = int(os.getenv('SEMAPHORE_LIMIT', 20))
|
||||
MAX_REFLEXION_ITERATIONS = int(os.getenv('MAX_REFLEXION_ITERATIONS', 0))
|
||||
|
|
|
|||
|
|
@ -345,8 +345,8 @@ class EntityNode(Node):
|
|||
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
||||
query = (
|
||||
"""
|
||||
MATCH (n:Entity {uuid: $uuid})
|
||||
"""
|
||||
MATCH (n:Entity {uuid: $uuid})
|
||||
"""
|
||||
+ ENTITY_NODE_RETURN
|
||||
)
|
||||
records, _, _ = await driver.execute_query(
|
||||
|
|
@ -542,8 +542,8 @@ class CommunityNode(Node):
|
|||
def get_episodic_node_from_record(record: Any) -> EpisodicNode:
|
||||
return EpisodicNode(
|
||||
content=record['content'],
|
||||
created_at=parse_db_date(record['created_at']).timestamp(),
|
||||
valid_at=(parse_db_date(record['valid_at'])),
|
||||
created_at=parse_db_date(record['created_at']), # type: ignore
|
||||
valid_at=parse_db_date(record['valid_at']), # type: ignore
|
||||
uuid=record['uuid'],
|
||||
group_id=record['group_id'],
|
||||
source=EpisodeType.from_str(record['source']),
|
||||
|
|
@ -559,7 +559,7 @@ def get_entity_node_from_record(record: Any) -> EntityNode:
|
|||
name=record['name'],
|
||||
group_id=record['group_id'],
|
||||
labels=record['labels'],
|
||||
created_at=parse_db_date(record['created_at']),
|
||||
created_at=parse_db_date(record['created_at']), # type: ignore
|
||||
summary=record['summary'],
|
||||
attributes=record['attributes'],
|
||||
)
|
||||
|
|
@ -580,7 +580,7 @@ def get_community_node_from_record(record: Any) -> CommunityNode:
|
|||
name=record['name'],
|
||||
group_id=record['group_id'],
|
||||
name_embedding=record['name_embedding'],
|
||||
created_at=parse_db_date(record['created_at']),
|
||||
created_at=parse_db_date(record['created_at']), # type: ignore
|
||||
summary=record['summary'],
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -167,7 +167,7 @@ async def edge_fulltext_search(
|
|||
filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
|
||||
|
||||
query = (
|
||||
get_relationships_query(driver.provider, 'edge_name_and_fact', '$query')
|
||||
get_relationships_query('edge_name_and_fact', db_type=driver.provider)
|
||||
+ """
|
||||
YIELD relationship AS rel, score
|
||||
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
|
||||
|
|
@ -301,12 +301,12 @@ async def edge_bfs_search(
|
|||
|
||||
query = (
|
||||
"""
|
||||
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
||||
MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
|
||||
UNWIND relationships(path) AS rel
|
||||
MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity)
|
||||
WHERE r.uuid = rel.uuid
|
||||
"""
|
||||
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
||||
MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
|
||||
UNWIND relationships(path) AS rel
|
||||
MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity)
|
||||
WHERE r.uuid = rel.uuid
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
RETURN DISTINCT
|
||||
|
|
@ -455,10 +455,10 @@ async def node_bfs_search(
|
|||
|
||||
query = (
|
||||
"""
|
||||
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
||||
MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
|
||||
WHERE n.group_id = origin.group_id
|
||||
"""
|
||||
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
||||
MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
|
||||
WHERE n.group_id = origin.group_id
|
||||
"""
|
||||
+ filter_query
|
||||
+ ENTITY_NODE_RETURN
|
||||
+ """
|
||||
|
|
|
|||
|
|
@ -232,7 +232,7 @@ async def determine_entity_community(
|
|||
driver: GraphDriver, entity: EntityNode
|
||||
) -> tuple[CommunityNode | None, bool]:
|
||||
# Check if the node is already part of a community
|
||||
records, _, _ = driver.execute_query(
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity {uuid: $entity_uuid})
|
||||
RETURN
|
||||
|
|
@ -250,7 +250,7 @@ async def determine_entity_community(
|
|||
return get_community_node_from_record(records[0]), False
|
||||
|
||||
# If the node has no community, add it to the mode community of surrounding entities
|
||||
records, _, _ = driver.execute_query(
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
MATCH (c:Community)-[:HAS_MEMBER]->(m:Entity)-[:RELATES_TO]-(n:Entity {uuid: $entity_uuid})
|
||||
RETURN
|
||||
|
|
|
|||
|
|
@ -85,3 +85,9 @@ ignore = ["E501"]
|
|||
quote-style = "single"
|
||||
indent-style = "space"
|
||||
docstring-code-format = true
|
||||
|
||||
[mypy-falkordb]
|
||||
ignore_missing_imports = true
|
||||
|
||||
[mypy-falkordb.asyncio]
|
||||
ignore_missing_imports = true
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue