edits
This commit is contained in:
parent
ef8bd41e1a
commit
13fc9cf1e4
9 changed files with 114 additions and 89 deletions
|
|
@ -217,15 +217,7 @@ class GraphDriver(ABC):
|
|||
client.indices.delete(index=index_name)
|
||||
|
||||
def run_aoss_query(self, name: str, query_text: str, limit: int = 10) -> dict[str, Any]:
|
||||
for index in aoss_indices:
|
||||
if name.lower() == index['index_name']:
|
||||
index['query']['query']['multi_match']['query'] = query_text
|
||||
query = {'size': limit, 'query': index['query']}
|
||||
resp = self.aoss_client.search(body=query['query'], index=index['index_name'])
|
||||
return resp
|
||||
return {}
|
||||
|
||||
from opensearchpy import helpers
|
||||
pass
|
||||
|
||||
def save_to_aoss(self, name: str, data: list[dict]) -> int:
|
||||
for index in aoss_indices:
|
||||
|
|
|
|||
|
|
@ -18,15 +18,25 @@ import logging
|
|||
from collections.abc import Coroutine
|
||||
from typing import Any
|
||||
|
||||
import boto3
|
||||
from neo4j import AsyncGraphDatabase, EagerResult
|
||||
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection
|
||||
from typing_extensions import LiteralString
|
||||
|
||||
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
import boto3
|
||||
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection
|
||||
|
||||
_HAS_OPENSEARCH = True
|
||||
except ImportError:
|
||||
boto3 = None
|
||||
OpenSearch = None
|
||||
Urllib3AWSV4SignerAuth = None
|
||||
Urllib3HttpConnection = None
|
||||
_HAS_OPENSEARCH = False
|
||||
|
||||
|
||||
class Neo4jDriver(GraphDriver):
|
||||
provider = GraphProvider.NEO4J
|
||||
|
|
@ -49,17 +59,21 @@ class Neo4jDriver(GraphDriver):
|
|||
|
||||
self.aoss_client = None
|
||||
if aoss_host and aoss_port:
|
||||
session = boto3.Session()
|
||||
self.aoss_client = OpenSearch(
|
||||
hosts=[{'host': aoss_host, 'port': aoss_port}],
|
||||
http_auth=Urllib3AWSV4SignerAuth(
|
||||
session.get_credentials(), session.region_name, 'aoss'
|
||||
),
|
||||
use_ssl=True,
|
||||
verify_certs=True,
|
||||
connection_class=Urllib3HttpConnection,
|
||||
pool_maxsize=20,
|
||||
)
|
||||
try:
|
||||
session = boto3.Session()
|
||||
self.aoss_client = OpenSearch(
|
||||
hosts=[{'host': aoss_host, 'port': aoss_port}],
|
||||
http_auth=Urllib3AWSV4SignerAuth(
|
||||
session.get_credentials(), session.region_name, 'aoss'
|
||||
),
|
||||
use_ssl=True,
|
||||
verify_certs=True,
|
||||
connection_class=Urllib3HttpConnection,
|
||||
pool_maxsize=20,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f'Failed to initialize OpenSearch client: {e}')
|
||||
self.aoss_client = None
|
||||
|
||||
async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> EagerResult:
|
||||
# Check if database_ is provided in kwargs.
|
||||
|
|
@ -86,7 +100,7 @@ class Neo4jDriver(GraphDriver):
|
|||
|
||||
def delete_all_indexes(self) -> Coroutine[Any, Any, EagerResult]:
|
||||
if self.aoss_client:
|
||||
self.delete_all_indexes_impl()
|
||||
return self.delete_aoss_indices()
|
||||
return self.client.execute_query(
|
||||
'CALL db.indexes() YIELD name DROP INDEX name',
|
||||
)
|
||||
|
|
|
|||
|
|
@ -307,7 +307,7 @@ class EntityEdge(Edge):
|
|||
if driver.provider == GraphProvider.KUZU:
|
||||
edge_data['attributes'] = json.dumps(self.attributes)
|
||||
result = await driver.execute_query(
|
||||
get_entity_edge_save_query(driver.provider),
|
||||
get_entity_edge_save_query(driver.provider, has_aoss=bool(driver.aoss_client)),
|
||||
**edge_data,
|
||||
)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -482,7 +482,7 @@ class EntityNode(Node):
|
|||
driver.save_to_aoss('entities', [entity_data]) # pyright: ignore reportAttributeAccessIssue
|
||||
|
||||
result = await driver.execute_query(
|
||||
get_entity_node_save_query(driver.provider, labels),
|
||||
get_entity_node_save_query(driver.provider, labels, bool(driver.aoss_client)),
|
||||
entity_data=entity_data,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -237,7 +237,7 @@ def edge_search_filter_query_constructor(
|
|||
|
||||
|
||||
def build_aoss_node_filters(group_ids: list[str], search_filters: SearchFilters) -> list[dict]:
|
||||
filters = [{'term': {'group_id': group_ids}}]
|
||||
filters = [{'terms': {'group_id': group_ids}}]
|
||||
|
||||
if search_filters.node_labels:
|
||||
filters.append({'terms': {'node_labels': search_filters.node_labels}})
|
||||
|
|
@ -246,7 +246,7 @@ def build_aoss_node_filters(group_ids: list[str], search_filters: SearchFilters)
|
|||
|
||||
|
||||
def build_aoss_edge_filters(group_ids: list[str], search_filters: SearchFilters) -> list[dict]:
|
||||
filters = [{'term': {'group_id': group_ids}}]
|
||||
filters = [{'terms': {'group_id': group_ids}}]
|
||||
|
||||
if search_filters.edge_types:
|
||||
filters.append({'terms': {'edge_types': search_filters.edge_types}})
|
||||
|
|
|
|||
|
|
@ -209,11 +209,11 @@ async def edge_fulltext_search(
|
|||
# Match the edge ids and return the values
|
||||
query = (
|
||||
"""
|
||||
UNWIND $ids as id
|
||||
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
||||
WHERE e.group_id IN $group_ids
|
||||
AND id(e)=id
|
||||
"""
|
||||
UNWIND $ids as id
|
||||
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
||||
WHERE e.group_id IN $group_ids
|
||||
AND id(e)=id
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
AND id(e)=id
|
||||
|
|
@ -265,7 +265,8 @@ async def edge_fulltext_search(
|
|||
|
||||
# Get edges
|
||||
entity_edges = await EntityEdge.get_by_uuids(driver, list(input_uuids.keys()))
|
||||
return entity_edges.sort(key=lambda e: input_uuids.get(e, 0), reverse=True)
|
||||
entity_edges.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
|
||||
return entity_edges
|
||||
else:
|
||||
return []
|
||||
else:
|
||||
|
|
@ -342,8 +343,8 @@ async def edge_similarity_search(
|
|||
if driver.provider == GraphProvider.NEPTUNE:
|
||||
query = (
|
||||
"""
|
||||
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
||||
"""
|
||||
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
RETURN DISTINCT id(e) as id, e.fact_embedding as embedding
|
||||
|
|
@ -423,7 +424,8 @@ async def edge_similarity_search(
|
|||
|
||||
# Get edges
|
||||
entity_edges = await EntityEdge.get_by_uuids(driver, list(input_uuids.keys()))
|
||||
return entity_edges.sort(key=lambda e: input_uuids.get(e, 0), reverse=True)
|
||||
entity_edges.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
|
||||
return entity_edges
|
||||
|
||||
else:
|
||||
query = (
|
||||
|
|
@ -618,11 +620,11 @@ async def node_fulltext_search(
|
|||
# Match the edge ides and return the values
|
||||
query = (
|
||||
"""
|
||||
UNWIND $ids as i
|
||||
MATCH (n:Entity)
|
||||
WHERE n.uuid=i.id
|
||||
RETURN
|
||||
"""
|
||||
UNWIND $ids as i
|
||||
MATCH (n:Entity)
|
||||
WHERE n.uuid=i.id
|
||||
RETURN
|
||||
"""
|
||||
+ get_entity_node_return_query(driver.provider)
|
||||
+ """
|
||||
ORDER BY i.score DESC
|
||||
|
|
@ -669,7 +671,8 @@ async def node_fulltext_search(
|
|||
|
||||
# Get nodes
|
||||
entities = await EntityNode.get_by_uuids(driver, list(input_uuids.keys()))
|
||||
return entities.sort(key=lambda e: input_uuids.get(e, 0), reverse=True)
|
||||
entities.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
|
||||
return entities
|
||||
else:
|
||||
return []
|
||||
else:
|
||||
|
|
@ -728,8 +731,8 @@ async def node_similarity_search(
|
|||
if driver.provider == GraphProvider.NEPTUNE:
|
||||
query = (
|
||||
"""
|
||||
MATCH (n:Entity)
|
||||
"""
|
||||
MATCH (n:Entity)
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
|
||||
|
|
@ -758,11 +761,11 @@ async def node_similarity_search(
|
|||
# Match the edge ides and return the values
|
||||
query = (
|
||||
"""
|
||||
UNWIND $ids as i
|
||||
MATCH (n:Entity)
|
||||
WHERE id(n)=i.id
|
||||
RETURN
|
||||
"""
|
||||
UNWIND $ids as i
|
||||
MATCH (n:Entity)
|
||||
WHERE id(n)=i.id
|
||||
RETURN
|
||||
"""
|
||||
+ get_entity_node_return_query(driver.provider)
|
||||
+ """
|
||||
ORDER BY i.score DESC
|
||||
|
|
@ -801,13 +804,14 @@ async def node_similarity_search(
|
|||
input_uuids[r['_source']['uuid']] = r['_score']
|
||||
|
||||
# Get edges
|
||||
entity_edges = await EntityEdge.get_by_uuids(driver, list(input_uuids.keys()))
|
||||
return entity_edges.sort(key=lambda e: input_uuids.get(e, 0), reverse=True)
|
||||
entity_nodes = await EntityNode.get_by_uuids(driver, list(input_uuids.keys()))
|
||||
entity_nodes.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
|
||||
return entity_nodes
|
||||
else:
|
||||
query = (
|
||||
"""
|
||||
MATCH (n:Entity)
|
||||
"""
|
||||
MATCH (n:Entity)
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
WITH n, """
|
||||
|
|
@ -985,7 +989,7 @@ async def episode_fulltext_search(
|
|||
_source=['uuid'],
|
||||
query={
|
||||
'bool': {
|
||||
'filter': [{'term': {'group_id': group_ids}}],
|
||||
'filter': [{'terms': {'group_id': group_ids}}],
|
||||
'must': [
|
||||
{
|
||||
'multi_match': {
|
||||
|
|
@ -1161,8 +1165,8 @@ async def community_similarity_search(
|
|||
if driver.provider == GraphProvider.NEPTUNE:
|
||||
query = (
|
||||
"""
|
||||
MATCH (n:Community)
|
||||
"""
|
||||
MATCH (n:Community)
|
||||
"""
|
||||
+ group_filter_query
|
||||
+ """
|
||||
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
|
||||
|
|
@ -1221,8 +1225,8 @@ async def community_similarity_search(
|
|||
|
||||
query = (
|
||||
"""
|
||||
MATCH (c:Community)
|
||||
"""
|
||||
MATCH (c:Community)
|
||||
"""
|
||||
+ group_filter_query
|
||||
+ """
|
||||
WITH c,
|
||||
|
|
@ -1364,9 +1368,9 @@ async def get_relevant_nodes(
|
|||
# FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver.
|
||||
query = (
|
||||
"""
|
||||
UNWIND $nodes AS node
|
||||
MATCH (n:Entity {group_id: $group_id})
|
||||
"""
|
||||
UNWIND $nodes AS node
|
||||
MATCH (n:Entity {group_id: $group_id})
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
WITH node, n, """
|
||||
|
|
@ -1411,9 +1415,9 @@ async def get_relevant_nodes(
|
|||
else:
|
||||
query = (
|
||||
"""
|
||||
UNWIND $nodes AS node
|
||||
MATCH (n:Entity {group_id: $group_id})
|
||||
"""
|
||||
UNWIND $nodes AS node
|
||||
MATCH (n:Entity {group_id: $group_id})
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
WITH node, n, """
|
||||
|
|
@ -1502,9 +1506,9 @@ async def get_relevant_edges(
|
|||
if driver.provider == GraphProvider.NEPTUNE:
|
||||
query = (
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
WITH e, edge
|
||||
|
|
@ -1574,9 +1578,9 @@ async def get_relevant_edges(
|
|||
|
||||
query = (
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
WITH e, edge, n, m, """
|
||||
|
|
@ -1612,9 +1616,9 @@ async def get_relevant_edges(
|
|||
else:
|
||||
query = (
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
WITH e, edge, """
|
||||
|
|
@ -1687,10 +1691,10 @@ async def get_edge_invalidation_candidates(
|
|||
if driver.provider == GraphProvider.NEPTUNE:
|
||||
query = (
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
||||
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
||||
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
WITH e, edge
|
||||
|
|
@ -1760,10 +1764,10 @@ async def get_edge_invalidation_candidates(
|
|||
|
||||
query = (
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
|
||||
WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
|
||||
WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
WITH edge, e, n, m, """
|
||||
|
|
@ -1799,10 +1803,10 @@ async def get_edge_invalidation_candidates(
|
|||
else:
|
||||
query = (
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
||||
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
||||
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
WITH edge, e, """
|
||||
|
|
|
|||
|
|
@ -187,12 +187,20 @@ async def add_nodes_and_edges_bulk_tx(
|
|||
await tx.run(episodic_edge_query, **edge.model_dump())
|
||||
else:
|
||||
await tx.run(get_episode_node_save_bulk_query(driver.provider), episodes=episodes)
|
||||
await tx.run(get_entity_node_save_bulk_query(driver.provider, nodes), nodes=nodes)
|
||||
await tx.run(
|
||||
get_entity_node_save_bulk_query(driver.provider, nodes),
|
||||
nodes=nodes,
|
||||
has_aoss=bool(driver.aoss_client),
|
||||
)
|
||||
await tx.run(
|
||||
get_episodic_edge_save_bulk_query(driver.provider),
|
||||
episodic_edges=[edge.model_dump() for edge in episodic_edges],
|
||||
)
|
||||
await tx.run(get_entity_edge_save_bulk_query(driver.provider), entity_edges=edges)
|
||||
await tx.run(
|
||||
get_entity_edge_save_bulk_query(driver.provider),
|
||||
entity_edges=edges,
|
||||
has_aoss=bool(driver.aoss_client),
|
||||
)
|
||||
|
||||
if driver.aoss_client:
|
||||
driver.save_to_aoss('episodes', episodes)
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ google-genai = ["google-genai>=1.8.0"]
|
|||
kuzu = ["kuzu>=0.11.2"]
|
||||
falkordb = ["falkordb>=1.1.2,<2.0.0"]
|
||||
voyageai = ["voyageai>=0.2.3"]
|
||||
neo4j-opensearch = ["boto3>=1.39.16", "opensearch-py>=3.0.0"]
|
||||
sentence-transformers = ["sentence-transformers>=3.2.1"]
|
||||
neptune = ["langchain-aws>=0.2.29", "opensearch-py>=3.0.0", "boto3>=1.39.16"]
|
||||
dev = [
|
||||
|
|
|
|||
10
uv.lock
generated
10
uv.lock
generated
|
|
@ -783,7 +783,7 @@ wheels = [
|
|||
|
||||
[[package]]
|
||||
name = "graphiti-core"
|
||||
version = "0.21.0"
|
||||
version = "0.21.0rc1"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "diskcache" },
|
||||
|
|
@ -835,6 +835,10 @@ groq = [
|
|||
kuzu = [
|
||||
{ name = "kuzu" },
|
||||
]
|
||||
neo4j-opensearch = [
|
||||
{ name = "boto3" },
|
||||
{ name = "opensearch-py" },
|
||||
]
|
||||
neptune = [
|
||||
{ name = "boto3" },
|
||||
{ name = "langchain-aws" },
|
||||
|
|
@ -851,6 +855,7 @@ voyageai = [
|
|||
requires-dist = [
|
||||
{ name = "anthropic", marker = "extra == 'anthropic'", specifier = ">=0.49.0" },
|
||||
{ name = "anthropic", marker = "extra == 'dev'", specifier = ">=0.49.0" },
|
||||
{ name = "boto3", marker = "extra == 'neo4j-opensearch'", specifier = ">=1.39.16" },
|
||||
{ name = "boto3", marker = "extra == 'neptune'", specifier = ">=1.39.16" },
|
||||
{ name = "diskcache", specifier = ">=5.6.3" },
|
||||
{ name = "diskcache-stubs", marker = "extra == 'dev'", specifier = ">=5.6.3.6.20240818" },
|
||||
|
|
@ -872,6 +877,7 @@ requires-dist = [
|
|||
{ name = "neo4j", specifier = ">=5.26.0" },
|
||||
{ name = "numpy", specifier = ">=1.0.0" },
|
||||
{ name = "openai", specifier = ">=1.91.0" },
|
||||
{ name = "opensearch-py", marker = "extra == 'neo4j-opensearch'", specifier = ">=3.0.0" },
|
||||
{ name = "opensearch-py", marker = "extra == 'neptune'", specifier = ">=3.0.0" },
|
||||
{ name = "posthog", specifier = ">=3.0.0" },
|
||||
{ name = "pydantic", specifier = ">=2.11.5" },
|
||||
|
|
@ -888,7 +894,7 @@ requires-dist = [
|
|||
{ name = "voyageai", marker = "extra == 'dev'", specifier = ">=0.2.3" },
|
||||
{ name = "voyageai", marker = "extra == 'voyageai'", specifier = ">=0.2.3" },
|
||||
]
|
||||
provides-extras = ["anthropic", "groq", "google-genai", "kuzu", "falkordb", "voyageai", "sentence-transformers", "neptune", "dev"]
|
||||
provides-extras = ["anthropic", "groq", "google-genai", "kuzu", "falkordb", "voyageai", "neo4j-opensearch", "sentence-transformers", "neptune", "dev"]
|
||||
|
||||
[[package]]
|
||||
name = "groq"
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue