async update
This commit is contained in:
parent
42bbd93c38
commit
e0066ff235
7 changed files with 117 additions and 89 deletions
|
|
@ -25,6 +25,7 @@ from pydantic import BaseModel, Field
|
||||||
from transcript_parser import parse_podcast_messages
|
from transcript_parser import parse_podcast_messages
|
||||||
|
|
||||||
from graphiti_core import Graphiti
|
from graphiti_core import Graphiti
|
||||||
|
from graphiti_core.driver.neo4j_driver import Neo4jDriver
|
||||||
from graphiti_core.nodes import EpisodeType
|
from graphiti_core.nodes import EpisodeType
|
||||||
from graphiti_core.utils.bulk_utils import RawEpisode
|
from graphiti_core.utils.bulk_utils import RawEpisode
|
||||||
from graphiti_core.utils.maintenance.graph_data_operations import clear_data
|
from graphiti_core.utils.maintenance.graph_data_operations import clear_data
|
||||||
|
|
@ -34,6 +35,8 @@ load_dotenv()
|
||||||
neo4j_uri = os.environ.get('NEO4J_URI') or 'bolt://localhost:7687'
|
neo4j_uri = os.environ.get('NEO4J_URI') or 'bolt://localhost:7687'
|
||||||
neo4j_user = os.environ.get('NEO4J_USER') or 'neo4j'
|
neo4j_user = os.environ.get('NEO4J_USER') or 'neo4j'
|
||||||
neo4j_password = os.environ.get('NEO4J_PASSWORD') or 'password'
|
neo4j_password = os.environ.get('NEO4J_PASSWORD') or 'password'
|
||||||
|
aoss_host = os.environ.get('AOSS_HOST') or None
|
||||||
|
aoss_port = os.environ.get('AOSS_PORT') or None
|
||||||
|
|
||||||
|
|
||||||
def setup_logging():
|
def setup_logging():
|
||||||
|
|
@ -77,12 +80,25 @@ class IsPresidentOf(BaseModel):
|
||||||
|
|
||||||
async def main(use_bulk: bool = False):
|
async def main(use_bulk: bool = False):
|
||||||
setup_logging()
|
setup_logging()
|
||||||
client = Graphiti(
|
graph_driver = Neo4jDriver(
|
||||||
neo4j_uri,
|
neo4j_uri,
|
||||||
neo4j_user,
|
neo4j_user,
|
||||||
neo4j_password,
|
neo4j_password,
|
||||||
|
aoss_host=aoss_host,
|
||||||
|
aoss_port=int(aoss_port),
|
||||||
|
aws_profile_name='zep-development',
|
||||||
|
aws_region='us-west-2',
|
||||||
|
aws_service='es',
|
||||||
)
|
)
|
||||||
|
# client = Graphiti(
|
||||||
|
# neo4j_uri,
|
||||||
|
# neo4j_user,
|
||||||
|
# neo4j_password,
|
||||||
|
# )
|
||||||
|
client = Graphiti(graph_driver=graph_driver)
|
||||||
await clear_data(client.driver)
|
await clear_data(client.driver)
|
||||||
|
await client.driver.delete_aoss_indices()
|
||||||
|
await client.driver.create_aoss_indices()
|
||||||
await client.build_indices_and_constraints()
|
await client.build_indices_and_constraints()
|
||||||
messages = parse_podcast_messages()
|
messages = parse_podcast_messages()
|
||||||
group_id = str(uuid4())
|
group_id = str(uuid4())
|
||||||
|
|
|
||||||
|
|
@ -24,10 +24,12 @@ from datetime import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
from graphiti_core.embedder.client import EMBEDDING_DIM
|
from graphiti_core.embedder.client import EMBEDDING_DIM
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from opensearchpy import OpenSearch, helpers
|
from opensearchpy import AsyncOpenSearch, helpers
|
||||||
|
|
||||||
_HAS_OPENSEARCH = True
|
_HAS_OPENSEARCH = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
|
@ -39,6 +41,8 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
DEFAULT_SIZE = 10
|
DEFAULT_SIZE = 10
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
ENTITY_INDEX_NAME = os.environ.get('ENTITY_INDEX_NAME', 'entities')
|
ENTITY_INDEX_NAME = os.environ.get('ENTITY_INDEX_NAME', 'entities')
|
||||||
EPISODE_INDEX_NAME = os.environ.get('EPISODE_INDEX_NAME', 'episodes')
|
EPISODE_INDEX_NAME = os.environ.get('EPISODE_INDEX_NAME', 'episodes')
|
||||||
COMMUNITY_INDEX_NAME = os.environ.get('COMMUNITY_INDEX_NAME', 'communities')
|
COMMUNITY_INDEX_NAME = os.environ.get('COMMUNITY_INDEX_NAME', 'communities')
|
||||||
|
|
@ -62,7 +66,7 @@ aoss_indices = [
|
||||||
'uuid': {'type': 'keyword'},
|
'uuid': {'type': 'keyword'},
|
||||||
'name': {'type': 'text'},
|
'name': {'type': 'text'},
|
||||||
'summary': {'type': 'text'},
|
'summary': {'type': 'text'},
|
||||||
'group_id': {'type': 'text'},
|
'group_id': {'type': 'keyword'},
|
||||||
'created_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
|
'created_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
|
||||||
'name_embedding': {
|
'name_embedding': {
|
||||||
'type': 'knn_vector',
|
'type': 'knn_vector',
|
||||||
|
|
@ -85,7 +89,7 @@ aoss_indices = [
|
||||||
'properties': {
|
'properties': {
|
||||||
'uuid': {'type': 'keyword'},
|
'uuid': {'type': 'keyword'},
|
||||||
'name': {'type': 'text'},
|
'name': {'type': 'text'},
|
||||||
'group_id': {'type': 'text'},
|
'group_id': {'type': 'keyword'},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
@ -99,7 +103,7 @@ aoss_indices = [
|
||||||
'content': {'type': 'text'},
|
'content': {'type': 'text'},
|
||||||
'source': {'type': 'text'},
|
'source': {'type': 'text'},
|
||||||
'source_description': {'type': 'text'},
|
'source_description': {'type': 'text'},
|
||||||
'group_id': {'type': 'text'},
|
'group_id': {'type': 'keyword'},
|
||||||
'created_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
|
'created_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
|
||||||
'valid_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
|
'valid_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
|
||||||
}
|
}
|
||||||
|
|
@ -115,7 +119,7 @@ aoss_indices = [
|
||||||
'uuid': {'type': 'keyword'},
|
'uuid': {'type': 'keyword'},
|
||||||
'name': {'type': 'text'},
|
'name': {'type': 'text'},
|
||||||
'fact': {'type': 'text'},
|
'fact': {'type': 'text'},
|
||||||
'group_id': {'type': 'text'},
|
'group_id': {'type': 'keyword'},
|
||||||
'created_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
|
'created_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
|
||||||
'valid_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
|
'valid_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
|
||||||
'expired_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
|
'expired_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
|
||||||
|
|
@ -167,7 +171,7 @@ class GraphDriver(ABC):
|
||||||
'' # Neo4j (default) syntax does not require a prefix for fulltext queries
|
'' # Neo4j (default) syntax does not require a prefix for fulltext queries
|
||||||
)
|
)
|
||||||
_database: str
|
_database: str
|
||||||
aoss_client: OpenSearch | None # type: ignore
|
aoss_client: AsyncOpenSearch | None # type: ignore
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def execute_query(self, cypher_query_: str, **kwargs: Any) -> Coroutine:
|
def execute_query(self, cypher_query_: str, **kwargs: Any) -> Coroutine:
|
||||||
|
|
@ -209,7 +213,7 @@ class GraphDriver(ABC):
|
||||||
alias_name = index['index_name']
|
alias_name = index['index_name']
|
||||||
|
|
||||||
# If alias already exists, skip (idempotent behavior)
|
# If alias already exists, skip (idempotent behavior)
|
||||||
if client.indices.exists_alias(name=alias_name):
|
if await client.indices.exists_alias(name=alias_name):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Build a physical index name with timestamp
|
# Build a physical index name with timestamp
|
||||||
|
|
@ -217,10 +221,10 @@ class GraphDriver(ABC):
|
||||||
physical_index_name = f'{alias_name}_{ts_suffix}'
|
physical_index_name = f'{alias_name}_{ts_suffix}'
|
||||||
|
|
||||||
# Create the index
|
# Create the index
|
||||||
client.indices.create(index=physical_index_name, body=index['body'])
|
await client.indices.create(index=physical_index_name, body=index['body'])
|
||||||
|
|
||||||
# Point alias to it
|
# Point alias to it
|
||||||
client.indices.put_alias(index=physical_index_name, name=alias_name)
|
await client.indices.put_alias(index=physical_index_name, name=alias_name)
|
||||||
|
|
||||||
# Allow some time for index creation
|
# Allow some time for index creation
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
|
|
@ -237,7 +241,7 @@ class GraphDriver(ABC):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Resolve alias → indices
|
# Resolve alias → indices
|
||||||
alias_info = client.indices.get_alias(name=alias_name)
|
alias_info = await client.indices.get_alias(name=alias_name)
|
||||||
indices = list(alias_info.keys())
|
indices = list(alias_info.keys())
|
||||||
|
|
||||||
if not indices:
|
if not indices:
|
||||||
|
|
@ -245,8 +249,8 @@ class GraphDriver(ABC):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for index in indices:
|
for index in indices:
|
||||||
if client.indices.exists(index=index):
|
if await client.indices.exists(index=index):
|
||||||
client.indices.delete(index=index)
|
await client.indices.delete(index=index)
|
||||||
logger.info(f"Deleted index '{index}' (alias: {alias_name})")
|
logger.info(f"Deleted index '{index}' (alias: {alias_name})")
|
||||||
else:
|
else:
|
||||||
logger.warning(f"Index '{index}' not found for alias '{alias_name}'")
|
logger.warning(f"Index '{index}' not found for alias '{alias_name}'")
|
||||||
|
|
@ -264,14 +268,16 @@ class GraphDriver(ABC):
|
||||||
for index in aoss_indices:
|
for index in aoss_indices:
|
||||||
index_name = index['index_name']
|
index_name = index['index_name']
|
||||||
|
|
||||||
if client.indices.exists(index=index_name):
|
if await client.indices.exists(index=index_name):
|
||||||
try:
|
try:
|
||||||
# Delete all documents but keep the index
|
# Delete all documents but keep the index
|
||||||
response = client.delete_by_query(
|
response = await client.delete_by_query(
|
||||||
index=index_name,
|
index=index_name,
|
||||||
body={'query': {'match_all': {}}},
|
body={'query': {'match_all': {}}},
|
||||||
refresh=True,
|
refresh=True,
|
||||||
conflicts='proceed',
|
conflicts='proceed',
|
||||||
|
wait_for_completion=True,
|
||||||
|
slices='auto', # improves coverage/concurrency
|
||||||
)
|
)
|
||||||
logger.info(f"Cleared index '{index_name}': {response}")
|
logger.info(f"Cleared index '{index_name}': {response}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -281,7 +287,7 @@ class GraphDriver(ABC):
|
||||||
|
|
||||||
async def save_to_aoss(self, name: str, data: list[dict]) -> int:
|
async def save_to_aoss(self, name: str, data: list[dict]) -> int:
|
||||||
client = self.aoss_client
|
client = self.aoss_client
|
||||||
if not client or not helpers:
|
if not client:
|
||||||
logger.warning('No OpenSearch client found')
|
logger.warning('No OpenSearch client found')
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
@ -289,16 +295,20 @@ class GraphDriver(ABC):
|
||||||
if name.lower() == index['index_name']:
|
if name.lower() == index['index_name']:
|
||||||
to_index = []
|
to_index = []
|
||||||
for d in data:
|
for d in data:
|
||||||
item = {
|
doc = {}
|
||||||
'_index': name,
|
|
||||||
'_routing': d.get('group_id'), # shard routing
|
|
||||||
}
|
|
||||||
for p in index['body']['mappings']['properties']:
|
for p in index['body']['mappings']['properties']:
|
||||||
if p in d: # protect against missing fields
|
if p in d: # protect against missing fields
|
||||||
item[p] = d[p]
|
doc[p] = d[p]
|
||||||
|
|
||||||
|
item = {
|
||||||
|
'_index': name,
|
||||||
|
'_id': d['uuid'],
|
||||||
|
'_routing': d.get('group_id'),
|
||||||
|
'_source': doc,
|
||||||
|
}
|
||||||
to_index.append(item)
|
to_index.append(item)
|
||||||
|
|
||||||
success, failed = helpers.bulk(
|
success, failed = await helpers.async_bulk(
|
||||||
client, to_index, stats_only=True, request_timeout=60
|
client, to_index, stats_only=True, request_timeout=60
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -29,8 +29,9 @@ logger = logging.getLogger(__name__)
|
||||||
try:
|
try:
|
||||||
import boto3
|
import boto3
|
||||||
from opensearchpy import (
|
from opensearchpy import (
|
||||||
|
AIOHttpConnection,
|
||||||
|
AsyncOpenSearch,
|
||||||
AWSV4SignerAuth,
|
AWSV4SignerAuth,
|
||||||
OpenSearch,
|
|
||||||
RequestsHttpConnection,
|
RequestsHttpConnection,
|
||||||
Urllib3AWSV4SignerAuth,
|
Urllib3AWSV4SignerAuth,
|
||||||
Urllib3HttpConnection,
|
Urllib3HttpConnection,
|
||||||
|
|
@ -75,12 +76,12 @@ class Neo4jDriver(GraphDriver):
|
||||||
credentials = boto3.Session(profile_name=aws_profile_name).get_credentials()
|
credentials = boto3.Session(profile_name=aws_profile_name).get_credentials()
|
||||||
auth = AWSV4SignerAuth(credentials, region or '', service or '')
|
auth = AWSV4SignerAuth(credentials, region or '', service or '')
|
||||||
|
|
||||||
self.aoss_client = OpenSearch(
|
self.aoss_client = AsyncOpenSearch(
|
||||||
hosts=[{'host': aoss_host, 'port': aoss_port}],
|
hosts=[{'host': aoss_host, 'port': aoss_port}],
|
||||||
http_auth=auth,
|
auth=auth,
|
||||||
use_ssl=True,
|
use_ssl=True,
|
||||||
verify_certs=True,
|
verify_certs=True,
|
||||||
connection_class=RequestsHttpConnection,
|
connection_class=AIOHttpConnection,
|
||||||
pool_maxsize=20,
|
pool_maxsize=20,
|
||||||
) # type: ignore
|
) # type: ignore
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
||||||
|
|
@ -267,7 +267,7 @@ class EntityEdge(Edge):
|
||||||
RETURN [x IN split(e.fact_embedding, ",") | toFloat(x)] as fact_embedding
|
RETURN [x IN split(e.fact_embedding, ",") | toFloat(x)] as fact_embedding
|
||||||
"""
|
"""
|
||||||
elif driver.aoss_client:
|
elif driver.aoss_client:
|
||||||
resp = driver.aoss_client.search(
|
resp = await driver.aoss_client.search(
|
||||||
body={
|
body={
|
||||||
'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}},
|
'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}},
|
||||||
'size': 1,
|
'size': 1,
|
||||||
|
|
|
||||||
|
|
@ -513,7 +513,7 @@ class EntityNode(Node):
|
||||||
RETURN [x IN split(n.name_embedding, ",") | toFloat(x)] as name_embedding
|
RETURN [x IN split(n.name_embedding, ",") | toFloat(x)] as name_embedding
|
||||||
"""
|
"""
|
||||||
elif driver.aoss_client:
|
elif driver.aoss_client:
|
||||||
resp = driver.aoss_client.search(
|
resp = await driver.aoss_client.search(
|
||||||
body={
|
body={
|
||||||
'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}},
|
'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}},
|
||||||
'size': 1,
|
'size': 1,
|
||||||
|
|
|
||||||
|
|
@ -215,11 +215,11 @@ async def edge_fulltext_search(
|
||||||
# Match the edge ids and return the values
|
# Match the edge ids and return the values
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
UNWIND $ids as id
|
UNWIND $ids as id
|
||||||
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
||||||
WHERE e.group_id IN $group_ids
|
WHERE e.group_id IN $group_ids
|
||||||
AND id(e)=id
|
AND id(e)=id
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
AND id(e)=id
|
AND id(e)=id
|
||||||
|
|
@ -254,7 +254,7 @@ async def edge_fulltext_search(
|
||||||
elif driver.aoss_client:
|
elif driver.aoss_client:
|
||||||
route = group_ids[0] if group_ids else None
|
route = group_ids[0] if group_ids else None
|
||||||
filters = build_aoss_edge_filters(group_ids or [], search_filter)
|
filters = build_aoss_edge_filters(group_ids or [], search_filter)
|
||||||
res = driver.aoss_client.search(
|
res = await driver.aoss_client.search(
|
||||||
index=ENTITY_EDGE_INDEX_NAME,
|
index=ENTITY_EDGE_INDEX_NAME,
|
||||||
routing=route,
|
routing=route,
|
||||||
_source=['uuid'],
|
_source=['uuid'],
|
||||||
|
|
@ -353,8 +353,8 @@ async def edge_similarity_search(
|
||||||
if driver.provider == GraphProvider.NEPTUNE:
|
if driver.provider == GraphProvider.NEPTUNE:
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
RETURN DISTINCT id(e) as id, e.fact_embedding as embedding
|
RETURN DISTINCT id(e) as id, e.fact_embedding as embedding
|
||||||
|
|
@ -415,7 +415,7 @@ async def edge_similarity_search(
|
||||||
elif driver.aoss_client:
|
elif driver.aoss_client:
|
||||||
route = group_ids[0] if group_ids else None
|
route = group_ids[0] if group_ids else None
|
||||||
filters = build_aoss_edge_filters(group_ids or [], search_filter)
|
filters = build_aoss_edge_filters(group_ids or [], search_filter)
|
||||||
res = driver.aoss_client.search(
|
res = await driver.aoss_client.search(
|
||||||
index=ENTITY_EDGE_INDEX_NAME,
|
index=ENTITY_EDGE_INDEX_NAME,
|
||||||
routing=route,
|
routing=route,
|
||||||
_source=['uuid'],
|
_source=['uuid'],
|
||||||
|
|
@ -637,11 +637,11 @@ async def node_fulltext_search(
|
||||||
# Match the edge ides and return the values
|
# Match the edge ides and return the values
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
UNWIND $ids as i
|
UNWIND $ids as i
|
||||||
MATCH (n:Entity)
|
MATCH (n:Entity)
|
||||||
WHERE n.uuid=i.id
|
WHERE n.uuid=i.id
|
||||||
RETURN
|
RETURN
|
||||||
"""
|
"""
|
||||||
+ get_entity_node_return_query(driver.provider)
|
+ get_entity_node_return_query(driver.provider)
|
||||||
+ """
|
+ """
|
||||||
ORDER BY i.score DESC
|
ORDER BY i.score DESC
|
||||||
|
|
@ -661,7 +661,7 @@ async def node_fulltext_search(
|
||||||
elif driver.aoss_client:
|
elif driver.aoss_client:
|
||||||
route = group_ids[0] if group_ids else None
|
route = group_ids[0] if group_ids else None
|
||||||
filters = build_aoss_node_filters(group_ids or [], search_filter)
|
filters = build_aoss_node_filters(group_ids or [], search_filter)
|
||||||
res = driver.aoss_client.search(
|
res = await driver.aoss_client.search(
|
||||||
index=ENTITY_INDEX_NAME,
|
index=ENTITY_INDEX_NAME,
|
||||||
routing=route,
|
routing=route,
|
||||||
_source=['uuid'],
|
_source=['uuid'],
|
||||||
|
|
@ -674,7 +674,7 @@ async def node_fulltext_search(
|
||||||
{
|
{
|
||||||
'multi_match': {
|
'multi_match': {
|
||||||
'query': query,
|
'query': query,
|
||||||
'fields': ['name', 'summary'], # ✅ fixed key
|
'fields': ['name', 'summary'],
|
||||||
'operator': 'or',
|
'operator': 'or',
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -751,8 +751,8 @@ async def node_similarity_search(
|
||||||
if driver.provider == GraphProvider.NEPTUNE:
|
if driver.provider == GraphProvider.NEPTUNE:
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
MATCH (n:Entity)
|
MATCH (n:Entity)
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
|
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
|
||||||
|
|
@ -781,11 +781,11 @@ async def node_similarity_search(
|
||||||
# Match the edge ides and return the values
|
# Match the edge ides and return the values
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
UNWIND $ids as i
|
UNWIND $ids as i
|
||||||
MATCH (n:Entity)
|
MATCH (n:Entity)
|
||||||
WHERE id(n)=i.id
|
WHERE id(n)=i.id
|
||||||
RETURN
|
RETURN
|
||||||
"""
|
"""
|
||||||
+ get_entity_node_return_query(driver.provider)
|
+ get_entity_node_return_query(driver.provider)
|
||||||
+ """
|
+ """
|
||||||
ORDER BY i.score DESC
|
ORDER BY i.score DESC
|
||||||
|
|
@ -806,7 +806,7 @@ async def node_similarity_search(
|
||||||
elif driver.aoss_client:
|
elif driver.aoss_client:
|
||||||
route = group_ids[0] if group_ids else None
|
route = group_ids[0] if group_ids else None
|
||||||
filters = build_aoss_node_filters(group_ids or [], search_filter)
|
filters = build_aoss_node_filters(group_ids or [], search_filter)
|
||||||
res = driver.aoss_client.search(
|
res = await driver.aoss_client.search(
|
||||||
index=ENTITY_INDEX_NAME,
|
index=ENTITY_INDEX_NAME,
|
||||||
routing=route,
|
routing=route,
|
||||||
_source=['uuid'],
|
_source=['uuid'],
|
||||||
|
|
@ -837,8 +837,8 @@ async def node_similarity_search(
|
||||||
else:
|
else:
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
MATCH (n:Entity)
|
MATCH (n:Entity)
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
WITH n, """
|
WITH n, """
|
||||||
|
|
@ -1011,7 +1011,7 @@ async def episode_fulltext_search(
|
||||||
return []
|
return []
|
||||||
elif driver.aoss_client:
|
elif driver.aoss_client:
|
||||||
route = group_ids[0] if group_ids else None
|
route = group_ids[0] if group_ids else None
|
||||||
res = driver.aoss_client.search(
|
res = await driver.aoss_client.search(
|
||||||
EPISODE_INDEX_NAME,
|
EPISODE_INDEX_NAME,
|
||||||
routing=route,
|
routing=route,
|
||||||
_source=['uuid'],
|
_source=['uuid'],
|
||||||
|
|
@ -1170,8 +1170,8 @@ async def community_similarity_search(
|
||||||
if driver.provider == GraphProvider.NEPTUNE:
|
if driver.provider == GraphProvider.NEPTUNE:
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
MATCH (n:Community)
|
MATCH (n:Community)
|
||||||
"""
|
"""
|
||||||
+ group_filter_query
|
+ group_filter_query
|
||||||
+ """
|
+ """
|
||||||
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
|
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
|
||||||
|
|
@ -1230,8 +1230,8 @@ async def community_similarity_search(
|
||||||
|
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
MATCH (c:Community)
|
MATCH (c:Community)
|
||||||
"""
|
"""
|
||||||
+ group_filter_query
|
+ group_filter_query
|
||||||
+ """
|
+ """
|
||||||
WITH c,
|
WITH c,
|
||||||
|
|
@ -1373,9 +1373,9 @@ async def get_relevant_nodes(
|
||||||
# FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver.
|
# FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver.
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
UNWIND $nodes AS node
|
UNWIND $nodes AS node
|
||||||
MATCH (n:Entity {group_id: $group_id})
|
MATCH (n:Entity {group_id: $group_id})
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
WITH node, n, """
|
WITH node, n, """
|
||||||
|
|
@ -1420,9 +1420,9 @@ async def get_relevant_nodes(
|
||||||
else:
|
else:
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
UNWIND $nodes AS node
|
UNWIND $nodes AS node
|
||||||
MATCH (n:Entity {group_id: $group_id})
|
MATCH (n:Entity {group_id: $group_id})
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
WITH node, n, """
|
WITH node, n, """
|
||||||
|
|
@ -1511,9 +1511,9 @@ async def get_relevant_edges(
|
||||||
if driver.provider == GraphProvider.NEPTUNE:
|
if driver.provider == GraphProvider.NEPTUNE:
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
UNWIND $edges AS edge
|
UNWIND $edges AS edge
|
||||||
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
WITH e, edge
|
WITH e, edge
|
||||||
|
|
@ -1583,9 +1583,9 @@ async def get_relevant_edges(
|
||||||
|
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
UNWIND $edges AS edge
|
UNWIND $edges AS edge
|
||||||
MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
|
MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
WITH e, edge, n, m, """
|
WITH e, edge, n, m, """
|
||||||
|
|
@ -1621,9 +1621,9 @@ async def get_relevant_edges(
|
||||||
else:
|
else:
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
UNWIND $edges AS edge
|
UNWIND $edges AS edge
|
||||||
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
WITH e, edge, """
|
WITH e, edge, """
|
||||||
|
|
@ -1696,10 +1696,10 @@ async def get_edge_invalidation_candidates(
|
||||||
if driver.provider == GraphProvider.NEPTUNE:
|
if driver.provider == GraphProvider.NEPTUNE:
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
UNWIND $edges AS edge
|
UNWIND $edges AS edge
|
||||||
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
||||||
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
WITH e, edge
|
WITH e, edge
|
||||||
|
|
@ -1769,10 +1769,10 @@ async def get_edge_invalidation_candidates(
|
||||||
|
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
UNWIND $edges AS edge
|
UNWIND $edges AS edge
|
||||||
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
|
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
|
||||||
WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
|
WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
WITH edge, e, n, m, """
|
WITH edge, e, n, m, """
|
||||||
|
|
@ -1808,10 +1808,10 @@ async def get_edge_invalidation_candidates(
|
||||||
else:
|
else:
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
UNWIND $edges AS edge
|
UNWIND $edges AS edge
|
||||||
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
||||||
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
WITH edge, e, """
|
WITH edge, e, """
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from time import time
|
from time import time
|
||||||
|
from xml.dom.minidom import Entity
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing_extensions import LiteralString
|
from typing_extensions import LiteralString
|
||||||
|
|
@ -260,7 +261,7 @@ async def resolve_extracted_edges(
|
||||||
embedder = clients.embedder
|
embedder = clients.embedder
|
||||||
await create_entity_edge_embeddings(embedder, extracted_edges)
|
await create_entity_edge_embeddings(embedder, extracted_edges)
|
||||||
|
|
||||||
valid_uuids_list: list[list[str]] = await semaphore_gather(
|
valid_edges_list: list[list[EntityEdge]] = await semaphore_gather(
|
||||||
*[
|
*[
|
||||||
EntityEdge.get_between_nodes(driver, edge.source_node_uuid, edge.target_node_uuid)
|
EntityEdge.get_between_nodes(driver, edge.source_node_uuid, edge.target_node_uuid)
|
||||||
for edge in extracted_edges
|
for edge in extracted_edges
|
||||||
|
|
@ -274,9 +275,9 @@ async def resolve_extracted_edges(
|
||||||
extracted_edge.fact,
|
extracted_edge.fact,
|
||||||
group_ids=[extracted_edge.group_id],
|
group_ids=[extracted_edge.group_id],
|
||||||
config=EDGE_HYBRID_SEARCH_RRF,
|
config=EDGE_HYBRID_SEARCH_RRF,
|
||||||
search_filter=SearchFilters(edge_uuids=valid_uuids),
|
search_filter=SearchFilters(edge_uuids=[edge.uuid for edge in valid_edges]),
|
||||||
)
|
)
|
||||||
for extracted_edge, valid_uuids in zip(extracted_edges, valid_uuids_list, strict=True)
|
for extracted_edge, valid_edges in zip(extracted_edges, valid_edges_list, strict=True)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue