OpenSearch updates (#906)
* updates
* add uuid filter functionality
* update
* updates
* bump-version
* update
* fix typo
* use async function
* update unit tests
* update delete
* update deletion
* async update
* update
* update
* update
* update

Parent 4dab259217, commit 3efe085a92
13 changed files with 479 additions and 191 deletions
@@ -17,16 +17,19 @@ limitations under the License.

import asyncio
import copy
import logging
import os
from abc import ABC, abstractmethod
from collections.abc import Coroutine
from datetime import datetime
from enum import Enum
from typing import Any

from dotenv import load_dotenv

from graphiti_core.embedder.client import EMBEDDING_DIM

try:
-    from opensearchpy import OpenSearch, helpers
+    from opensearchpy import AsyncOpenSearch, helpers

    _HAS_OPENSEARCH = True
except ImportError:
@@ -38,6 +41,13 @@ logger = logging.getLogger(__name__)

DEFAULT_SIZE = 10

+load_dotenv()
+
+ENTITY_INDEX_NAME = os.environ.get('ENTITY_INDEX_NAME', 'entities')
+EPISODE_INDEX_NAME = os.environ.get('EPISODE_INDEX_NAME', 'episodes')
+COMMUNITY_INDEX_NAME = os.environ.get('COMMUNITY_INDEX_NAME', 'communities')
+ENTITY_EDGE_INDEX_NAME = os.environ.get('ENTITY_EDGE_INDEX_NAME', 'entity_edges')


class GraphProvider(Enum):
    NEO4J = 'neo4j'
@@ -48,20 +58,19 @@ class GraphProvider(Enum):

aoss_indices = [
    {
-        'index_name': 'entities',
+        'index_name': ENTITY_INDEX_NAME,
        'body': {
            'settings': {'index': {'knn': True}},
            'mappings': {
                'properties': {
                    'uuid': {'type': 'keyword'},
                    'name': {'type': 'text'},
                    'summary': {'type': 'text'},
-                    'group_id': {'type': 'text'},
-                    'created_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"},
+                    'group_id': {'type': 'keyword'},
+                    'created_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
                    'name_embedding': {
                        'type': 'knn_vector',
-                        'dims': EMBEDDING_DIM,
-                        'index': True,
-                        'similarity': 'cosine',
+                        'dimension': EMBEDDING_DIM,
+                        'method': {
+                            'engine': 'faiss',
+                            'space_type': 'cosinesimil',
@@ -70,23 +79,23 @@ aoss_indices = [
                        },
                    },
                }
            }
        },
    },
    },
    {
-        'index_name': 'communities',
+        'index_name': COMMUNITY_INDEX_NAME,
        'body': {
            'mappings': {
                'properties': {
                    'uuid': {'type': 'keyword'},
                    'name': {'type': 'text'},
-                    'group_id': {'type': 'text'},
+                    'group_id': {'type': 'keyword'},
                }
            }
        },
    },
    {
-        'index_name': 'episodes',
+        'index_name': EPISODE_INDEX_NAME,
        'body': {
            'mappings': {
                'properties': {
@@ -94,31 +103,30 @@ aoss_indices = [
                    'content': {'type': 'text'},
                    'source': {'type': 'text'},
                    'source_description': {'type': 'text'},
-                    'group_id': {'type': 'text'},
-                    'created_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"},
-                    'valid_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"},
+                    'group_id': {'type': 'keyword'},
+                    'created_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
+                    'valid_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
                }
            }
        },
    },
    {
-        'index_name': 'entity_edges',
+        'index_name': ENTITY_EDGE_INDEX_NAME,
        'body': {
            'settings': {'index': {'knn': True}},
            'mappings': {
                'properties': {
                    'uuid': {'type': 'keyword'},
                    'name': {'type': 'text'},
                    'fact': {'type': 'text'},
-                    'group_id': {'type': 'text'},
-                    'created_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"},
-                    'valid_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"},
-                    'expired_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"},
-                    'invalid_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"},
+                    'group_id': {'type': 'keyword'},
+                    'created_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
+                    'valid_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
+                    'expired_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
+                    'invalid_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
                    'fact_embedding': {
                        'type': 'knn_vector',
-                        'dims': EMBEDDING_DIM,
-                        'index': True,
-                        'similarity': 'cosine',
+                        'dimension': EMBEDDING_DIM,
+                        'method': {
+                            'engine': 'faiss',
+                            'space_type': 'cosinesimil',

@@ -127,7 +135,7 @@ aoss_indices = [
                        },
                    },
                }
            }
        },
    },
    },
]
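For context, the mappings above are what the driver feeds to OpenSearch when it creates its indices. A minimal standalone sketch of creating one such kNN-enabled index and pointing an alias at it, under assumed values: the localhost endpoint, the `entities_v1` physical index name, and the 384-dim embedding size are placeholders, not values from this PR.

```python
import asyncio

from opensearchpy import AsyncOpenSearch

EMBEDDING_DIM = 384  # placeholder; graphiti takes this from its embedder client

entity_index_body = {
    'settings': {'index': {'knn': True}},
    'mappings': {
        'properties': {
            'uuid': {'type': 'keyword'},
            'name': {'type': 'text'},
            'group_id': {'type': 'keyword'},
            'created_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
            'name_embedding': {
                'type': 'knn_vector',
                'dimension': EMBEDDING_DIM,
                'method': {'engine': 'faiss', 'space_type': 'cosinesimil'},
            },
        }
    },
}


async def main() -> None:
    client = AsyncOpenSearch(hosts=[{'host': 'localhost', 'port': 9200}])  # placeholder endpoint
    try:
        # Mirror the alias-based flow used below: create a physical index once,
        # then expose it under a stable alias name.
        if not await client.indices.exists_alias(name='entities'):
            await client.indices.create(index='entities_v1', body=entity_index_body)
            await client.indices.put_alias(index='entities_v1', name='entities')
    finally:
        await client.close()


asyncio.run(main())
```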
@@ -163,7 +171,7 @@ class GraphDriver(ABC):
        ''  # Neo4j (default) syntax does not require a prefix for fulltext queries
    )
    _database: str
-    aoss_client: OpenSearch | None  # type: ignore
+    aoss_client: AsyncOpenSearch | None  # type: ignore

    @abstractmethod
    def execute_query(self, cypher_query_: str, **kwargs: Any) -> Coroutine:

@@ -205,7 +213,7 @@ class GraphDriver(ABC):
            alias_name = index['index_name']

            # If alias already exists, skip (idempotent behavior)
-            if client.indices.exists_alias(name=alias_name):
+            if await client.indices.exists_alias(name=alias_name):
                continue

            # Build a physical index name with timestamp
@@ -213,27 +221,67 @@ class GraphDriver(ABC):
            physical_index_name = f'{alias_name}_{ts_suffix}'

            # Create the index
-            client.indices.create(index=physical_index_name, body=index['body'])
+            await client.indices.create(index=physical_index_name, body=index['body'])

            # Point alias to it
-            client.indices.put_alias(index=physical_index_name, name=alias_name)
+            await client.indices.put_alias(index=physical_index_name, name=alias_name)

        # Allow some time for index creation
-        await asyncio.sleep(60)
+        await asyncio.sleep(1)

+    async def delete_aoss_indices(self):
+        client = self.aoss_client
+
+        if not client:
+            logger.warning('No OpenSearch client found')
+            return
+
+        for entry in aoss_indices:
+            alias_name = entry['index_name']
+
+            try:
+                # Resolve alias → indices
+                alias_info = await client.indices.get_alias(name=alias_name)
+                indices = list(alias_info.keys())
+
+                if not indices:
+                    logger.info(f"No indices found for alias '{alias_name}'")
+                    continue
+
+                for index in indices:
+                    if await client.indices.exists(index=index):
+                        await client.indices.delete(index=index)
+                        logger.info(f"Deleted index '{index}' (alias: {alias_name})")
+                    else:
+                        logger.warning(f"Index '{index}' not found for alias '{alias_name}'")
+
+            except Exception as e:
+                logger.error(f"Error deleting indices for alias '{alias_name}': {e}")
+
    async def clear_aoss_indices(self):
+        client = self.aoss_client
+
+        if not client:
+            logger.warning('No OpenSearch client found')
+            return
+
        for index in aoss_indices:
            index_name = index['index_name']
-            client = self.aoss_client
-
-            if not client:
-                logger.warning('No OpenSearch client found')
-                return
+            if await client.indices.exists(index=index_name):
+                try:
+                    # Delete all documents but keep the index
+                    response = await client.delete_by_query(
+                        index=index_name,
+                        body={'query': {'match_all': {}}},
+                    )
+                    logger.info(f"Cleared index '{index_name}': {response}")
+                except Exception as e:
+                    logger.error(f"Error clearing index '{index_name}': {e}")
+            else:
+                logger.warning(f"Index '{index_name}' does not exist")

-            if client.indices.exists(index=index_name):
-                client.indices.delete(index=index_name)
-
-    def save_to_aoss(self, name: str, data: list[dict]) -> int:
+    async def save_to_aoss(self, name: str, data: list[dict]) -> int:
        client = self.aoss_client
        if not client or not helpers:
            logger.warning('No OpenSearch client found')
@@ -243,16 +291,22 @@ class GraphDriver(ABC):
            if name.lower() == index['index_name']:
                to_index = []
                for d in data:
-                    item = {
-                        '_index': name,
-                        '_routing': d.get('group_id'),  # shard routing
-                    }
+                    doc = {}
                    for p in index['body']['mappings']['properties']:
                        if p in d:  # protect against missing fields
-                            item[p] = d[p]
+                            doc[p] = d[p]
+
+                    item = {
+                        '_index': name,
+                        '_id': d['uuid'],
+                        '_routing': d.get('group_id'),
+                        '_source': doc,
+                    }
                    to_index.append(item)

-                success, failed = helpers.bulk(client, to_index, stats_only=True)
+                success, failed = await helpers.async_bulk(
+                    client, to_index, stats_only=True, request_timeout=60
+                )

                return success if failed == 0 else success
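A minimal sketch of what `save_to_aoss` now hands to `helpers.async_bulk`: one action per document, with `_id` set to the entity UUID so re-saving the same UUID overwrites rather than duplicates, and `_routing` set to the group id. The client endpoint and the sample documents are placeholder assumptions.

```python
import asyncio

from opensearchpy import AsyncOpenSearch, helpers


async def index_entities(client: AsyncOpenSearch, data: list[dict]) -> int:
    actions = [
        {
            '_index': 'entities',
            '_id': d['uuid'],               # stable id -> idempotent upserts
            '_routing': d.get('group_id'),  # shard routing by group
            '_source': {k: v for k, v in d.items() if k in ('uuid', 'name', 'group_id')},
        }
        for d in data
    ]
    success, failed = await helpers.async_bulk(client, actions, stats_only=True, request_timeout=60)
    return success


async def main() -> None:
    client = AsyncOpenSearch(hosts=[{'host': 'localhost', 'port': 9200}])  # placeholder endpoint
    try:
        docs = [{'uuid': 'abc-123', 'name': 'Alice', 'group_id': 'g1'}]  # sample data
        print(await index_entities(client, docs))
    finally:
        await client.close()


asyncio.run(main())
```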
@@ -28,7 +28,13 @@ logger = logging.getLogger(__name__)

try:
    import boto3
-    from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection
+    from opensearchpy import (
+        AIOHttpConnection,
+        AsyncOpenSearch,
+        AWSV4SignerAuth,
+        Urllib3AWSV4SignerAuth,
+        Urllib3HttpConnection,
+    )

    _HAS_OPENSEARCH = True
except ImportError:

@@ -50,6 +56,9 @@ class Neo4jDriver(GraphDriver):
        database: str = 'neo4j',
        aoss_host: str | None = None,
        aoss_port: int | None = None,
+        aws_profile_name: str | None = None,
+        aws_region: str | None = None,
+        aws_service: str | None = None,
    ):
        super().__init__()
        self.client = AsyncGraphDatabase.driver(

@@ -61,15 +70,17 @@ class Neo4jDriver(GraphDriver):
        self.aoss_client = None
        if aoss_host and aoss_port and boto3 is not None:
            try:
-                session = boto3.Session()
-                self.aoss_client = OpenSearch(  # type: ignore
+                region = aws_region
+                service = aws_service
+                credentials = boto3.Session(profile_name=aws_profile_name).get_credentials()
+                auth = AWSV4SignerAuth(credentials, region or '', service or '')
+
+                self.aoss_client = AsyncOpenSearch(
                    hosts=[{'host': aoss_host, 'port': aoss_port}],
-                    http_auth=Urllib3AWSV4SignerAuth(  # type: ignore
-                        session.get_credentials(), session.region_name, 'aoss'
-                    ),
+                    auth=auth,
                    use_ssl=True,
                    verify_certs=True,
-                    connection_class=Urllib3HttpConnection,
+                    connection_class=AIOHttpConnection,
                    pool_maxsize=20,
                )  # type: ignore
            except Exception as e:
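A condensed sketch of the client setup the new constructor arguments enable, mirroring the construction introduced in this diff (boto3 credentials, `AWSV4SignerAuth`, `AIOHttpConnection`). The host, region, service, and port values are placeholder assumptions; in the driver they come from the new `aoss_host` / `aoss_port` / `aws_profile_name` / `aws_region` / `aws_service` keyword arguments.

```python
import boto3
from opensearchpy import AIOHttpConnection, AsyncOpenSearch, AWSV4SignerAuth

# Placeholder connection details.
host = 'my-collection.us-east-1.aoss.amazonaws.com'
region = 'us-east-1'
service = 'aoss'

# Sign requests with the credentials resolved from the default profile.
credentials = boto3.Session(profile_name=None).get_credentials()
auth = AWSV4SignerAuth(credentials, region, service)

client = AsyncOpenSearch(
    hosts=[{'host': host, 'port': 443}],
    auth=auth,
    use_ssl=True,
    verify_certs=True,
    connection_class=AIOHttpConnection,
    pool_maxsize=20,
)
```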
@@ -237,12 +237,12 @@ class NeptuneDriver(GraphDriver):
                'You must provide an AOSS endpoint to create an OpenSearch driver.'
            )
        if not client.indices.exists(index=index_name):
-            client.indices.create(index=index_name, body=index['body'])
+            await client.indices.create(index=index_name, body=index['body'])

        alias_name = index.get('alias_name', index_name)

        if not client.indices.exists_alias(name=alias_name, index=index_name):
-            client.indices.put_alias(index=index_name, name=alias_name)
+            await client.indices.put_alias(index=index_name, name=alias_name)

        # Sleep for 1 minute to let the index creation complete
        await asyncio.sleep(60)
@@ -25,7 +25,7 @@ from uuid import uuid4
from pydantic import BaseModel, Field
from typing_extensions import LiteralString

-from graphiti_core.driver.driver import GraphDriver, GraphProvider
+from graphiti_core.driver.driver import ENTITY_EDGE_INDEX_NAME, GraphDriver, GraphProvider
from graphiti_core.embedder import EmbedderClient
from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
from graphiti_core.helpers import parse_db_date

@@ -77,6 +77,13 @@ class Edge(BaseModel, ABC):
            uuid=self.uuid,
        )

+        if driver.aoss_client:
+            await driver.aoss_client.delete(
+                index=ENTITY_EDGE_INDEX_NAME,
+                id=self.uuid,
+                params={'routing': self.group_id},
+            )
+
        logger.debug(f'Deleted Edge: {self.uuid}')

    @classmethod

@@ -108,6 +115,12 @@ class Edge(BaseModel, ABC):
            uuids=uuids,
        )

+        if driver.aoss_client:
+            await driver.aoss_client.delete_by_query(
+                index=ENTITY_EDGE_INDEX_NAME,
+                body={'query': {'terms': {'uuid': uuids}}},
+            )
+
        logger.debug(f'Deleted Edges: {uuids}')

    def __hash__(self):
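The two deletion paths added above map onto two different OpenSearch APIs: a single-document `delete`, which needs the routing value used at index time (here the group id), and a `delete_by_query` with a `terms` clause for batches. A minimal sketch, with the client endpoint, index name, routing value, and ids as placeholder assumptions.

```python
import asyncio

from opensearchpy import AsyncOpenSearch


async def delete_edges(client: AsyncOpenSearch) -> None:
    # Single document: id + routing must match what was used when indexing.
    await client.delete(
        index='entity_edges',
        id='edge-uuid-1',
        params={'routing': 'group-1'},
    )

    # Batch: match any document whose uuid is in the list, across routing values.
    await client.delete_by_query(
        index='entity_edges',
        body={'query': {'terms': {'uuid': ['edge-uuid-2', 'edge-uuid-3']}}},
    )


async def main() -> None:
    client = AsyncOpenSearch(hosts=[{'host': 'localhost', 'port': 9200}])  # placeholder endpoint
    try:
        await delete_edges(client)
    finally:
        await client.close()


asyncio.run(main())
```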
@@ -256,13 +269,13 @@ class EntityEdge(Edge):
                RETURN [x IN split(e.fact_embedding, ",") | toFloat(x)] as fact_embedding
                """
        elif driver.aoss_client:
-            resp = driver.aoss_client.search(
+            resp = await driver.aoss_client.search(
                body={
                    'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}},
                    'size': 1,
                },
-                index='entity_edges',
-                routing=self.group_id,
+                index=ENTITY_EDGE_INDEX_NAME,
+                params={'routing': self.group_id},
            )

            if resp['hits']['hits']:

@@ -314,7 +327,7 @@ class EntityEdge(Edge):
        edge_data.update(self.attributes or {})

        if driver.aoss_client:
-            driver.save_to_aoss('entity_edges', [edge_data])  # pyright: ignore reportAttributeAccessIssue
+            await driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, [edge_data])  # pyright: ignore reportAttributeAccessIssue

        result = await driver.execute_query(
            get_entity_edge_save_query(driver.provider),
@@ -351,6 +364,35 @@ class EntityEdge(Edge):
            raise EdgeNotFoundError(uuid)
        return edges[0]

+    @classmethod
+    async def get_between_nodes(
+        cls, driver: GraphDriver, source_node_uuid: str, target_node_uuid: str
+    ):
+        match_query = """
+            MATCH (n:Entity {uuid: $source_node_uuid})-[e:RELATES_TO]->(m:Entity {uuid: $target_node_uuid})
+        """
+        if driver.provider == GraphProvider.KUZU:
+            match_query = """
+                MATCH (n:Entity {uuid: $source_node_uuid})
+                    -[:RELATES_TO]->(e:RelatesToNode_)
+                    -[:RELATES_TO]->(m:Entity {uuid: $target_node_uuid})
+            """
+
+        records, _, _ = await driver.execute_query(
+            match_query
+            + """
+            RETURN
+            """
+            + get_entity_edge_return_query(driver.provider),
+            source_node_uuid=source_node_uuid,
+            target_node_uuid=target_node_uuid,
+            routing_='r',
+        )
+
+        edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
+
+        return edges
+
    @classmethod
    async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
        if len(uuids) == 0:
@@ -60,9 +60,7 @@ from graphiti_core.search.search_config_recipes import (
from graphiti_core.search.search_filters import SearchFilters
from graphiti_core.search.search_utils import (
    RELEVANT_SCHEMA_LIMIT,
-    get_edge_invalidation_candidates,
    get_mentioned_nodes,
-    get_relevant_edges,
)
from graphiti_core.telemetry import capture_event
from graphiti_core.utils.bulk_utils import (

@@ -1037,10 +1035,28 @@ class Graphiti:

        updated_edge = resolve_edge_pointers([edge], uuid_map)[0]

-        related_edges = (await get_relevant_edges(self.driver, [updated_edge], SearchFilters()))[0]
+        valid_edges = await EntityEdge.get_between_nodes(
+            self.driver, edge.source_node_uuid, edge.target_node_uuid
+        )
+
+        related_edges = (
+            await search(
+                self.clients,
+                updated_edge.fact,
+                group_ids=[updated_edge.group_id],
+                config=EDGE_HYBRID_SEARCH_RRF,
+                search_filter=SearchFilters(edge_uuids=[edge.uuid for edge in valid_edges]),
+            )
+        ).edges
        existing_edges = (
-            await get_edge_invalidation_candidates(self.driver, [updated_edge], SearchFilters())
-        )[0]
+            await search(
+                self.clients,
+                updated_edge.fact,
+                group_ids=[updated_edge.group_id],
+                config=EDGE_HYBRID_SEARCH_RRF,
+                search_filter=SearchFilters(),
+            )
+        ).edges

        resolved_edge, invalidated_edges, _ = await resolve_extracted_edge(
            self.llm_client,
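The intent of the change above: candidate duplicates are now restricted to edges that already connect the same pair of nodes, by first fetching those edges and then passing their UUIDs into the hybrid search as a filter. A condensed sketch of that flow, assuming `driver`, `clients`, and `edge` come from an existing Graphiti setup and reusing only calls shown in this diff.

```python
from graphiti_core.edges import EntityEdge
from graphiti_core.search.search import search
from graphiti_core.search.search_config_recipes import EDGE_HYBRID_SEARCH_RRF
from graphiti_core.search.search_filters import SearchFilters


async def find_duplicate_candidates(clients, driver, edge):
    # 1. Only edges that already connect the same two nodes can be duplicates.
    valid_edges = await EntityEdge.get_between_nodes(
        driver, edge.source_node_uuid, edge.target_node_uuid
    )

    # 2. Run hybrid search over the fact text, restricted to those edge UUIDs
    #    via the new SearchFilters.edge_uuids field.
    results = await search(
        clients,
        edge.fact,
        group_ids=[edge.group_id],
        config=EDGE_HYBRID_SEARCH_RRF,
        search_filter=SearchFilters(edge_uuids=[e.uuid for e in valid_edges]),
    )
    return results.edges
```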
@@ -26,7 +26,14 @@ from uuid import uuid4
from pydantic import BaseModel, Field
from typing_extensions import LiteralString

-from graphiti_core.driver.driver import GraphDriver, GraphProvider
+from graphiti_core.driver.driver import (
+    COMMUNITY_INDEX_NAME,
+    ENTITY_EDGE_INDEX_NAME,
+    ENTITY_INDEX_NAME,
+    EPISODE_INDEX_NAME,
+    GraphDriver,
+    GraphProvider,
+)
from graphiti_core.embedder import EmbedderClient
from graphiti_core.errors import NodeNotFoundError
from graphiti_core.helpers import parse_db_date
@@ -94,13 +101,39 @@ class Node(BaseModel, ABC):
    async def delete(self, driver: GraphDriver):
        match driver.provider:
            case GraphProvider.NEO4J:
-                await driver.execute_query(
+                records, _, _ = await driver.execute_query(
                    """
-                    MATCH (n:Entity|Episodic|Community {uuid: $uuid})
+                    MATCH (n {uuid: $uuid})
+                    WHERE n:Entity OR n:Episodic OR n:Community
+                    OPTIONAL MATCH (n)-[r]-()
+                    WITH collect(r.uuid) AS edge_uuids, n
                    DETACH DELETE n
+                    RETURN edge_uuids
                    """,
                    uuid=self.uuid,
                )

+                edge_uuids: list[str] = records[0].get('edge_uuids', []) if records else []
+
+                if driver.aoss_client:
+                    # Delete the node from OpenSearch indices
+                    for index in (EPISODE_INDEX_NAME, ENTITY_INDEX_NAME, COMMUNITY_INDEX_NAME):
+                        await driver.aoss_client.delete(
+                            index=index,
+                            id=self.uuid,
+                            params={'routing': self.group_id},
+                        )
+
+                    # Bulk delete the detached edges
+                    if edge_uuids:
+                        actions = []
+                        for eid in edge_uuids:
+                            actions.append(
+                                {'delete': {'_index': ENTITY_EDGE_INDEX_NAME, '_id': eid}}
+                            )
+
+                        await driver.aoss_client.bulk(body=actions)
+
            case GraphProvider.KUZU:
                for label in ['Episodic', 'Community']:
                    await driver.execute_query(
@@ -162,6 +195,32 @@ class Node(BaseModel, ABC):
                    group_id=group_id,
                    batch_size=batch_size,
                )

+                if driver.aoss_client:
+                    await driver.aoss_client.delete_by_query(
+                        index=EPISODE_INDEX_NAME,
+                        body={'query': {'term': {'group_id': group_id}}},
+                        params={'routing': group_id},
+                    )
+
+                    await driver.aoss_client.delete_by_query(
+                        index=ENTITY_INDEX_NAME,
+                        body={'query': {'term': {'group_id': group_id}}},
+                        params={'routing': group_id},
+                    )
+
+                    await driver.aoss_client.delete_by_query(
+                        index=COMMUNITY_INDEX_NAME,
+                        body={'query': {'term': {'group_id': group_id}}},
+                        params={'routing': group_id},
+                    )
+
+                    await driver.aoss_client.delete_by_query(
+                        index=ENTITY_EDGE_INDEX_NAME,
+                        body={'query': {'term': {'group_id': group_id}}},
+                        params={'routing': group_id},
+                    )
+
            case GraphProvider.KUZU:
                for label in ['Episodic', 'Community']:
                    await driver.execute_query(
@@ -240,6 +299,23 @@ class Node(BaseModel, ABC):
                )
            case _:  # Neo4J, Neptune
                async with driver.session() as session:
+                    # Collect all edge UUIDs before deleting nodes
+                    result = await session.run(
+                        """
+                        MATCH (n:Entity|Episodic|Community)
+                        WHERE n.uuid IN $uuids
+                        MATCH (n)-[r]-()
+                        RETURN collect(r.uuid) AS edge_uuids
+                        """,
+                        uuids=uuids,
+                    )
+
+                    record = await result.single()
+                    edge_uuids: list[str] = (
+                        record['edge_uuids'] if record and record['edge_uuids'] else []
+                    )
+
+                    # Now delete the nodes in batches
                    await session.run(
                        """
                        MATCH (n:Entity|Episodic|Community)

@@ -253,6 +329,20 @@ class Node(BaseModel, ABC):
                        batch_size=batch_size,
                    )

+                    if driver.aoss_client:
+                        for index in (EPISODE_INDEX_NAME, ENTITY_INDEX_NAME, COMMUNITY_INDEX_NAME):
+                            await driver.aoss_client.delete_by_query(
+                                index=index,
+                                body={'query': {'terms': {'uuid': uuids}}},
+                            )
+
+                        if edge_uuids:
+                            actions = [
+                                {'delete': {'_index': ENTITY_EDGE_INDEX_NAME, '_id': eid}}
+                                for eid in edge_uuids
+                            ]
+                            await driver.aoss_client.bulk(body=actions)
+
    @classmethod
    async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...
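The node-deletion paths above remove the detached edges with the bulk API rather than one `delete` call per edge. Each bulk action is a single `{'delete': ...}` header with no source line. A minimal sketch, with the client endpoint, index name, and ids as placeholder assumptions.

```python
import asyncio

from opensearchpy import AsyncOpenSearch


async def bulk_delete_edges(client: AsyncOpenSearch, edge_uuids: list[str]) -> None:
    # One action header per document; delete actions carry no document body.
    actions = [{'delete': {'_index': 'entity_edges', '_id': eid}} for eid in edge_uuids]
    if actions:
        response = await client.bulk(body=actions)
        # `errors` is True if any individual action failed; `items` has per-action results.
        print(response['errors'], len(response['items']))


async def main() -> None:
    client = AsyncOpenSearch(hosts=[{'host': 'localhost', 'port': 9200}])  # placeholder endpoint
    try:
        await bulk_delete_edges(client, ['edge-uuid-1', 'edge-uuid-2'])
    finally:
        await client.close()


asyncio.run(main())
```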
@@ -286,7 +376,7 @@ class EpisodicNode(Node):
        }

        if driver.aoss_client:
-            driver.save_to_aoss(  # pyright: ignore reportAttributeAccessIssue
+            await driver.save_to_aoss(  # pyright: ignore reportAttributeAccessIssue
                'episodes',
                [episode_args],
            )

@@ -426,13 +516,13 @@ class EntityNode(Node):
                RETURN [x IN split(n.name_embedding, ",") | toFloat(x)] as name_embedding
                """
        elif driver.aoss_client:
-            resp = driver.aoss_client.search(
+            resp = await driver.aoss_client.search(
                body={
                    'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}},
                    'size': 1,
                },
-                index='entities',
-                routing=self.group_id,
+                index=ENTITY_INDEX_NAME,
+                params={'routing': self.group_id},
            )

            if resp['hits']['hits']:

@@ -479,7 +569,7 @@ class EntityNode(Node):
        labels = ':'.join(self.labels + ['Entity'])

        if driver.aoss_client:
-            driver.save_to_aoss('entities', [entity_data])  # pyright: ignore reportAttributeAccessIssue
+            await driver.save_to_aoss(ENTITY_INDEX_NAME, [entity_data])  # pyright: ignore reportAttributeAccessIssue

        result = await driver.execute_query(
            get_entity_node_save_query(driver.provider, labels, bool(driver.aoss_client)),

@@ -577,7 +667,7 @@ class CommunityNode(Node):

    async def save(self, driver: GraphDriver):
        if driver.provider == GraphProvider.NEPTUNE:
-            driver.save_to_aoss(  # pyright: ignore reportAttributeAccessIssue
+            await driver.save_to_aoss(  # pyright: ignore reportAttributeAccessIssue
                'communities',
                [{'name': self.name, 'uuid': self.uuid, 'group_id': self.group_id}],
            )
@@ -52,6 +52,7 @@ class SearchFilters(BaseModel):
    invalid_at: list[list[DateFilter]] | None = Field(default=None)
    created_at: list[list[DateFilter]] | None = Field(default=None)
    expired_at: list[list[DateFilter]] | None = Field(default=None)
+    edge_uuids: list[str] | None = Field(default=None)


def cypher_to_opensearch_operator(op: ComparisonOperator) -> str:

@@ -108,6 +109,10 @@ def edge_search_filter_query_constructor(
        filter_queries.append('e.name in $edge_types')
        filter_params['edge_types'] = edge_types

+    if filters.edge_uuids is not None:
+        filter_queries.append('e.uuid in $edge_uuids')
+        filter_params['edge_uuids'] = filters.edge_uuids
+
    if filters.node_labels is not None:
        if provider == GraphProvider.KUZU:
            node_label_filter = (

@@ -261,6 +266,9 @@ def build_aoss_edge_filters(group_ids: list[str], search_filters: SearchFilters)
    if search_filters.edge_types:
        filters.append({'terms': {'edge_types': search_filters.edge_types}})

+    if search_filters.edge_uuids:
+        filters.append({'terms': {'uuid': search_filters.edge_uuids}})
+
    for field in ['valid_at', 'invalid_at', 'created_at', 'expired_at']:
        ranges = getattr(search_filters, field)
        if ranges:
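The same `edge_uuids` filter is expressed twice: as a Cypher predicate with a bound parameter for the graph back ends, and as an OpenSearch `terms` clause on the keyword `uuid` field for AOSS. A small sketch of the two fragments side by side; the surrounding query assembly is simplified and the sample values are assumptions.

```python
edge_uuids = ['edge-uuid-1', 'edge-uuid-2']  # sample values

# Graph back ends: appended to the WHERE clause with a bound parameter.
cypher_fragment = 'e.uuid in $edge_uuids'
cypher_params = {'edge_uuids': edge_uuids}

# OpenSearch/AOSS: added to the bool filter list, matching uuid exactly (no scoring).
aoss_filter = {'terms': {'uuid': edge_uuids}}

bool_query = {
    'bool': {
        'filter': [aoss_filter],
        'must': [{'match': {'fact': {'query': 'some fact text', 'operator': 'or'}}}],
    }
}
print(cypher_fragment, cypher_params, bool_query)
```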
@@ -23,7 +23,13 @@ import numpy as np
from numpy._typing import NDArray
from typing_extensions import LiteralString

-from graphiti_core.driver.driver import GraphDriver, GraphProvider
+from graphiti_core.driver.driver import (
+    ENTITY_EDGE_INDEX_NAME,
+    ENTITY_INDEX_NAME,
+    EPISODE_INDEX_NAME,
+    GraphDriver,
+    GraphProvider,
+)
from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
from graphiti_core.graph_queries import (
    get_nodes_query,

@@ -209,11 +215,11 @@ async def edge_fulltext_search(
        # Match the edge ids and return the values
        query = (
            """
            UNWIND $ids as id
            MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
            WHERE e.group_id IN $group_ids
            AND id(e)=id
            """
            + filter_query
            + """
            AND id(e)=id
@@ -248,17 +254,21 @@ async def edge_fulltext_search(
    elif driver.aoss_client:
        route = group_ids[0] if group_ids else None
        filters = build_aoss_edge_filters(group_ids or [], search_filter)
-        res = driver.aoss_client.search(
-            index='entity_edges',
-            routing=route,
-            _source=['uuid'],
-            query={
-                'bool': {
-                    'filter': filters,
-                    'must': [{'match': {'fact': {'query': query, 'operator': 'or'}}}],
-                }
+        res = await driver.aoss_client.search(
+            index=ENTITY_EDGE_INDEX_NAME,
+            params={'routing': route},
+            body={
+                'size': limit,
+                '_source': ['uuid'],
+                'query': {
+                    'bool': {
+                        'filter': filters,
+                        'must': [{'match': {'fact': {'query': query, 'operator': 'or'}}}],
+                    }
+                },
            },
        )

        if res['hits']['total']['value'] > 0:
            input_uuids = {}
            for r in res['hits']['hits']:
@@ -344,8 +354,8 @@ async def edge_similarity_search(
    if driver.provider == GraphProvider.NEPTUNE:
        query = (
            """
            MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
            """
            + filter_query
            + """
            RETURN DISTINCT id(e) as id, e.fact_embedding as embedding
@@ -406,17 +416,22 @@ async def edge_similarity_search(
    elif driver.aoss_client:
        route = group_ids[0] if group_ids else None
        filters = build_aoss_edge_filters(group_ids or [], search_filter)
-        res = driver.aoss_client.search(
-            index='entity_edges',
-            routing=route,
-            _source=['uuid'],
-            knn={
-                'field': 'fact_embedding',
-                'query_vector': search_vector,
-                'k': limit,
-                'num_candidates': 1000,
+        res = await driver.aoss_client.search(
+            index=ENTITY_EDGE_INDEX_NAME,
+            params={'routing': route},
+            body={
+                'size': limit,
+                '_source': ['uuid'],
+                'query': {
+                    'knn': {
+                        'fact_embedding': {
+                            'vector': list(map(float, search_vector)),
+                            'k': limit,
+                            'filter': {'bool': {'filter': filters}},
+                        }
+                    }
+                },
            },
-            query={'bool': {'filter': filters}},
        )

        if res['hits']['total']['value'] > 0:

@@ -428,6 +443,7 @@ async def edge_similarity_search(
            entity_edges = await EntityEdge.get_by_uuids(driver, list(input_uuids.keys()))
            entity_edges.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
            return entity_edges
+        return []

    else:
        query = (
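The rewritten similarity search moves the vector into the request body as an OpenSearch `knn` query clause, with the metadata filters nested inside it so filtering happens during the approximate search rather than afterwards. A minimal sketch; the client endpoint, index name, group id, and vector are placeholder assumptions, and the vector length must match the mapping's dimension.

```python
import asyncio

from opensearchpy import AsyncOpenSearch


async def knn_edges(client: AsyncOpenSearch, vector: list[float], group_id: str, k: int = 10):
    body = {
        'size': k,
        '_source': ['uuid'],
        'query': {
            'knn': {
                'fact_embedding': {
                    'vector': vector,
                    'k': k,
                    # Filter evaluated inside the kNN search, not as a post-filter.
                    'filter': {'bool': {'filter': [{'term': {'group_id': group_id}}]}},
                }
            }
        },
    }
    res = await client.search(index='entity_edges', body=body, params={'routing': group_id})
    return [hit['_source']['uuid'] for hit in res['hits']['hits']]


async def main() -> None:
    client = AsyncOpenSearch(hosts=[{'host': 'localhost', 'port': 9200}])  # placeholder endpoint
    try:
        print(await knn_edges(client, [0.1] * 1024, 'group-1'))  # placeholder vector
    finally:
        await client.close()


asyncio.run(main())
```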
@@ -622,11 +638,11 @@ async def node_fulltext_search(
        # Match the edge ides and return the values
        query = (
            """
            UNWIND $ids as i
            MATCH (n:Entity)
            WHERE n.uuid=i.id
            RETURN
            """
            + get_entity_node_return_query(driver.provider)
            + """
            ORDER BY i.score DESC
@@ -646,25 +662,27 @@ async def node_fulltext_search(
    elif driver.aoss_client:
        route = group_ids[0] if group_ids else None
        filters = build_aoss_node_filters(group_ids or [], search_filter)
-        res = driver.aoss_client.search(
-            'entities',
-            routing=route,
-            _source=['uuid'],
-            query={
-                'bool': {
-                    'filter': filters,
-                    'must': [
-                        {
-                            'multi_match': {
-                                'query': query,
-                                'field': ['name', 'summary'],
-                                'operator': 'or',
-                            }
-                        }
-                    ],
-                }
-            },
-            limit=limit,
+        res = await driver.aoss_client.search(
+            index=ENTITY_INDEX_NAME,
+            params={'routing': route},
+            body={
+                '_source': ['uuid'],
+                'size': limit,
+                'query': {
+                    'bool': {
+                        'filter': filters,
+                        'must': [
+                            {
+                                'multi_match': {
+                                    'query': query,
+                                    'fields': ['name', 'summary'],
+                                    'operator': 'or',
+                                }
+                            }
+                        ],
+                    }
+                },
+            },
        )

        if res['hits']['total']['value'] > 0:
@@ -734,8 +752,8 @@ async def node_similarity_search(
    if driver.provider == GraphProvider.NEPTUNE:
        query = (
            """
            MATCH (n:Entity)
            """
            + filter_query
            + """
            RETURN DISTINCT id(n) as id, n.name_embedding as embedding

@@ -764,11 +782,11 @@ async def node_similarity_search(
        # Match the edge ides and return the values
        query = (
            """
            UNWIND $ids as i
            MATCH (n:Entity)
            WHERE id(n)=i.id
            RETURN
            """
            + get_entity_node_return_query(driver.provider)
            + """
            ORDER BY i.score DESC
@@ -789,17 +807,22 @@ async def node_similarity_search(
    elif driver.aoss_client:
        route = group_ids[0] if group_ids else None
        filters = build_aoss_node_filters(group_ids or [], search_filter)
-        res = driver.aoss_client.search(
-            index='entities',
-            routing=route,
-            _source=['uuid'],
-            knn={
-                'field': 'fact_embedding',
-                'query_vector': search_vector,
-                'k': limit,
-                'num_candidates': 1000,
+        res = await driver.aoss_client.search(
+            index=ENTITY_INDEX_NAME,
+            params={'routing': route},
+            body={
+                'size': limit,
+                '_source': ['uuid'],
+                'query': {
+                    'knn': {
+                        'name_embedding': {
+                            'vector': list(map(float, search_vector)),
+                            'k': limit,
+                            'filter': {'bool': {'filter': filters}},
+                        }
+                    }
+                },
            },
-            query={'bool': {'filter': filters}},
        )

        if res['hits']['total']['value'] > 0:

@@ -811,11 +834,12 @@ async def node_similarity_search(
            entity_nodes = await EntityNode.get_by_uuids(driver, list(input_uuids.keys()))
            entity_nodes.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
            return entity_nodes
+        return []
    else:
        query = (
            """
            MATCH (n:Entity)
            """
            + filter_query
            + """
            WITH n, """
@@ -988,11 +1012,12 @@ async def episode_fulltext_search(
        return []
    elif driver.aoss_client:
        route = group_ids[0] if group_ids else None
-        res = driver.aoss_client.search(
-            'episodes',
-            routing=route,
-            _source=['uuid'],
-            query={
+        res = await driver.aoss_client.search(
+            index=EPISODE_INDEX_NAME,
+            params={'routing': route},
+            body={
+                'size': limit,
+                '_source': ['uuid'],
                'bool': {
                    'filter': {'terms': group_ids},
                    'must': [

@@ -1004,9 +1029,8 @@ async def episode_fulltext_search(
                        }
                    }
                ],
-                }
+                },
            },
-            limit=limit,
        )

        if res['hits']['total']['value'] > 0:
@@ -1147,8 +1171,8 @@ async def community_similarity_search(
    if driver.provider == GraphProvider.NEPTUNE:
        query = (
            """
            MATCH (n:Community)
            """
            + group_filter_query
            + """
            RETURN DISTINCT id(n) as id, n.name_embedding as embedding

@@ -1207,8 +1231,8 @@ async def community_similarity_search(

        query = (
            """
            MATCH (c:Community)
            """
            + group_filter_query
            + """
            WITH c,
@@ -1350,9 +1374,9 @@ async def get_relevant_nodes(
        # FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver.
        query = (
            """
            UNWIND $nodes AS node
            MATCH (n:Entity {group_id: $group_id})
            """
            + filter_query
            + """
            WITH node, n, """

@@ -1397,9 +1421,9 @@ async def get_relevant_nodes(
    else:
        query = (
            """
            UNWIND $nodes AS node
            MATCH (n:Entity {group_id: $group_id})
            """
            + filter_query
            + """
            WITH node, n, """
@@ -1488,9 +1512,9 @@ async def get_relevant_edges(
    if driver.provider == GraphProvider.NEPTUNE:
        query = (
            """
            UNWIND $edges AS edge
            MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
            """
            + filter_query
            + """
            WITH e, edge

@@ -1560,9 +1584,9 @@ async def get_relevant_edges(

        query = (
            """
            UNWIND $edges AS edge
            MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
            """
            + filter_query
            + """
            WITH e, edge, n, m, """

@@ -1598,9 +1622,9 @@ async def get_relevant_edges(
    else:
        query = (
            """
            UNWIND $edges AS edge
            MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
            """
            + filter_query
            + """
            WITH e, edge, """
@@ -1673,10 +1697,10 @@ async def get_edge_invalidation_candidates(
    if driver.provider == GraphProvider.NEPTUNE:
        query = (
            """
            UNWIND $edges AS edge
            MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
            WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
            """
            + filter_query
            + """
            WITH e, edge

@@ -1746,10 +1770,10 @@ async def get_edge_invalidation_candidates(

        query = (
            """
            UNWIND $edges AS edge
            MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
            WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
            """
            + filter_query
            + """
            WITH edge, e, n, m, """

@@ -1785,10 +1809,10 @@ async def get_edge_invalidation_candidates(
    else:
        query = (
            """
            UNWIND $edges AS edge
            MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
            WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
            """
            + filter_query
            + """
            WITH edge, e, """
@@ -23,7 +23,14 @@ import numpy as np
from pydantic import BaseModel, Field
from typing_extensions import Any

-from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
+from graphiti_core.driver.driver import (
+    ENTITY_EDGE_INDEX_NAME,
+    ENTITY_INDEX_NAME,
+    EPISODE_INDEX_NAME,
+    GraphDriver,
+    GraphDriverSession,
+    GraphProvider,
+)
from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge, create_entity_edge_embeddings
from graphiti_core.embedder import EmbedderClient
from graphiti_core.graphiti_types import GraphitiClients

@@ -203,9 +210,9 @@ async def add_nodes_and_edges_bulk_tx(
    )

    if driver.aoss_client:
-        driver.save_to_aoss('episodes', episodes)
-        driver.save_to_aoss('entities', nodes)
-        driver.save_to_aoss('entity_edges', edges)
+        await driver.save_to_aoss(EPISODE_INDEX_NAME, episodes)
+        await driver.save_to_aoss(ENTITY_INDEX_NAME, nodes)
+        await driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, edges)


async def extract_nodes_and_edges_bulk(
@@ -36,8 +36,10 @@ from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
from graphiti_core.prompts import prompt_library
from graphiti_core.prompts.dedupe_edges import EdgeDuplicate
from graphiti_core.prompts.extract_edges import ExtractedEdges, MissingFacts
+from graphiti_core.search.search import search
+from graphiti_core.search.search_config import SearchResults
+from graphiti_core.search.search_config_recipes import EDGE_HYBRID_SEARCH_RRF
from graphiti_core.search.search_filters import SearchFilters
-from graphiti_core.search.search_utils import get_edge_invalidation_candidates, get_relevant_edges
from graphiti_core.utils.datetime_utils import ensure_utc, utc_now

logger = logging.getLogger(__name__)

@@ -258,12 +260,44 @@ async def resolve_extracted_edges(
    embedder = clients.embedder
    await create_entity_edge_embeddings(embedder, extracted_edges)

-    search_results = await semaphore_gather(
-        get_relevant_edges(driver, extracted_edges, SearchFilters()),
-        get_edge_invalidation_candidates(driver, extracted_edges, SearchFilters(), 0.2),
+    valid_edges_list: list[list[EntityEdge]] = await semaphore_gather(
+        *[
+            EntityEdge.get_between_nodes(driver, edge.source_node_uuid, edge.target_node_uuid)
+            for edge in extracted_edges
+        ]
    )

-    related_edges_lists, edge_invalidation_candidates = search_results
+    related_edges_results: list[SearchResults] = await semaphore_gather(
+        *[
+            search(
+                clients,
+                extracted_edge.fact,
+                group_ids=[extracted_edge.group_id],
+                config=EDGE_HYBRID_SEARCH_RRF,
+                search_filter=SearchFilters(edge_uuids=[edge.uuid for edge in valid_edges]),
+            )
+            for extracted_edge, valid_edges in zip(extracted_edges, valid_edges_list, strict=True)
+        ]
+    )
+
+    related_edges_lists: list[list[EntityEdge]] = [result.edges for result in related_edges_results]
+
+    edge_invalidation_candidate_results: list[SearchResults] = await semaphore_gather(
+        *[
+            search(
+                clients,
+                extracted_edge.fact,
+                group_ids=[extracted_edge.group_id],
+                config=EDGE_HYBRID_SEARCH_RRF,
+                search_filter=SearchFilters(),
+            )
+            for extracted_edge in extracted_edges
+        ]
+    )
+
+    edge_invalidation_candidates: list[list[EntityEdge]] = [
+        result.edges for result in edge_invalidation_candidate_results
+    ]

    logger.debug(
        f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_lists for e in edges_lst]}'
@@ -95,6 +95,8 @@ async def clear_data(driver: GraphDriver, group_ids: list[str] | None = None):

    async def delete_all(tx):
        await tx.run('MATCH (n) DETACH DELETE n')
+        if driver.aoss_client:
+            await driver.clear_aoss_indices()

    async def delete_group_ids(tx):
        labels = ['Entity', 'Episodic', 'Community']

@@ -151,9 +153,9 @@ async def retrieve_episodes(

    query: LiteralString = (
        """
        MATCH (e:Episodic)
        WHERE e.valid_at <= $reference_time
        """
        + query_filter
        + """
        RETURN
@@ -1,7 +1,7 @@
[project]
name = "graphiti-core"
description = "A temporal graph building library"
-version = "0.21.0pre1"
+version = "0.21.0pre2"
authors = [
    { name = "Paul Paliychuk", email = "paul@getzep.com" },
    { name = "Preston Rasmussen", email = "preston@getzep.com" },

uv.lock (generated)
@@ -783,7 +783,7 @@ wheels = [

[[package]]
name = "graphiti-core"
-version = "0.21.0rc1"
+version = "0.21.0rc2"
source = { editable = "." }
dependencies = [
    { name = "diskcache" },