update
This commit is contained in:
parent
e0066ff235
commit
fd1c360e8c
7 changed files with 60 additions and 80 deletions
|
|
@ -25,7 +25,6 @@ from pydantic import BaseModel, Field
|
||||||
from transcript_parser import parse_podcast_messages
|
from transcript_parser import parse_podcast_messages
|
||||||
|
|
||||||
from graphiti_core import Graphiti
|
from graphiti_core import Graphiti
|
||||||
from graphiti_core.driver.neo4j_driver import Neo4jDriver
|
|
||||||
from graphiti_core.nodes import EpisodeType
|
from graphiti_core.nodes import EpisodeType
|
||||||
from graphiti_core.utils.bulk_utils import RawEpisode
|
from graphiti_core.utils.bulk_utils import RawEpisode
|
||||||
from graphiti_core.utils.maintenance.graph_data_operations import clear_data
|
from graphiti_core.utils.maintenance.graph_data_operations import clear_data
|
||||||
|
|
@ -35,8 +34,6 @@ load_dotenv()
|
||||||
neo4j_uri = os.environ.get('NEO4J_URI') or 'bolt://localhost:7687'
|
neo4j_uri = os.environ.get('NEO4J_URI') or 'bolt://localhost:7687'
|
||||||
neo4j_user = os.environ.get('NEO4J_USER') or 'neo4j'
|
neo4j_user = os.environ.get('NEO4J_USER') or 'neo4j'
|
||||||
neo4j_password = os.environ.get('NEO4J_PASSWORD') or 'password'
|
neo4j_password = os.environ.get('NEO4J_PASSWORD') or 'password'
|
||||||
aoss_host = os.environ.get('AOSS_HOST') or None
|
|
||||||
aoss_port = os.environ.get('AOSS_PORT') or None
|
|
||||||
|
|
||||||
|
|
||||||
def setup_logging():
|
def setup_logging():
|
||||||
|
|
@ -80,25 +77,12 @@ class IsPresidentOf(BaseModel):
|
||||||
|
|
||||||
async def main(use_bulk: bool = False):
|
async def main(use_bulk: bool = False):
|
||||||
setup_logging()
|
setup_logging()
|
||||||
graph_driver = Neo4jDriver(
|
client = Graphiti(
|
||||||
neo4j_uri,
|
neo4j_uri,
|
||||||
neo4j_user,
|
neo4j_user,
|
||||||
neo4j_password,
|
neo4j_password,
|
||||||
aoss_host=aoss_host,
|
|
||||||
aoss_port=int(aoss_port),
|
|
||||||
aws_profile_name='zep-development',
|
|
||||||
aws_region='us-west-2',
|
|
||||||
aws_service='es',
|
|
||||||
)
|
)
|
||||||
# client = Graphiti(
|
|
||||||
# neo4j_uri,
|
|
||||||
# neo4j_user,
|
|
||||||
# neo4j_password,
|
|
||||||
# )
|
|
||||||
client = Graphiti(graph_driver=graph_driver)
|
|
||||||
await clear_data(client.driver)
|
await clear_data(client.driver)
|
||||||
await client.driver.delete_aoss_indices()
|
|
||||||
await client.driver.create_aoss_indices()
|
|
||||||
await client.build_indices_and_constraints()
|
await client.build_indices_and_constraints()
|
||||||
messages = parse_podcast_messages()
|
messages = parse_podcast_messages()
|
||||||
group_id = str(uuid4())
|
group_id = str(uuid4())
|
||||||
|
|
|
||||||
|
|
@ -274,10 +274,7 @@ class GraphDriver(ABC):
|
||||||
response = await client.delete_by_query(
|
response = await client.delete_by_query(
|
||||||
index=index_name,
|
index=index_name,
|
||||||
body={'query': {'match_all': {}}},
|
body={'query': {'match_all': {}}},
|
||||||
refresh=True,
|
slices='auto',
|
||||||
conflicts='proceed',
|
|
||||||
wait_for_completion=True,
|
|
||||||
slices='auto', # improves coverage/concurrency
|
|
||||||
)
|
)
|
||||||
logger.info(f"Cleared index '{index_name}': {response}")
|
logger.info(f"Cleared index '{index_name}': {response}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -287,7 +284,7 @@ class GraphDriver(ABC):
|
||||||
|
|
||||||
async def save_to_aoss(self, name: str, data: list[dict]) -> int:
|
async def save_to_aoss(self, name: str, data: list[dict]) -> int:
|
||||||
client = self.aoss_client
|
client = self.aoss_client
|
||||||
if not client:
|
if not client or not helpers:
|
||||||
logger.warning('No OpenSearch client found')
|
logger.warning('No OpenSearch client found')
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -32,7 +32,6 @@ try:
|
||||||
AIOHttpConnection,
|
AIOHttpConnection,
|
||||||
AsyncOpenSearch,
|
AsyncOpenSearch,
|
||||||
AWSV4SignerAuth,
|
AWSV4SignerAuth,
|
||||||
RequestsHttpConnection,
|
|
||||||
Urllib3AWSV4SignerAuth,
|
Urllib3AWSV4SignerAuth,
|
||||||
Urllib3HttpConnection,
|
Urllib3HttpConnection,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -237,12 +237,12 @@ class NeptuneDriver(GraphDriver):
|
||||||
'You must provide an AOSS endpoint to create an OpenSearch driver.'
|
'You must provide an AOSS endpoint to create an OpenSearch driver.'
|
||||||
)
|
)
|
||||||
if not client.indices.exists(index=index_name):
|
if not client.indices.exists(index=index_name):
|
||||||
client.indices.create(index=index_name, body=index['body'])
|
await client.indices.create(index=index_name, body=index['body'])
|
||||||
|
|
||||||
alias_name = index.get('alias_name', index_name)
|
alias_name = index.get('alias_name', index_name)
|
||||||
|
|
||||||
if not client.indices.exists_alias(name=alias_name, index=index_name):
|
if not client.indices.exists_alias(name=alias_name, index=index_name):
|
||||||
client.indices.put_alias(index=index_name, name=alias_name)
|
await client.indices.put_alias(index=index_name, name=alias_name)
|
||||||
|
|
||||||
# Sleep for 1 minute to let the index creation complete
|
# Sleep for 1 minute to let the index creation complete
|
||||||
await asyncio.sleep(60)
|
await asyncio.sleep(60)
|
||||||
|
|
|
||||||
|
|
@ -103,8 +103,9 @@ class Node(BaseModel, ABC):
|
||||||
case GraphProvider.NEO4J:
|
case GraphProvider.NEO4J:
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MATCH (n {uuid: $uuid})-[r]-()
|
MATCH (n {uuid: $uuid})
|
||||||
WHERE n:Entity OR n:Episodic OR n:Community
|
WHERE n:Entity OR n:Episodic OR n:Community
|
||||||
|
OPTIONAL MATCH (n)-[r]-()
|
||||||
WITH collect(r.uuid) AS edge_uuids, n
|
WITH collect(r.uuid) AS edge_uuids, n
|
||||||
DETACH DELETE n
|
DETACH DELETE n
|
||||||
RETURN edge_uuids
|
RETURN edge_uuids
|
||||||
|
|
|
||||||
|
|
@ -808,7 +808,7 @@ async def node_similarity_search(
|
||||||
filters = build_aoss_node_filters(group_ids or [], search_filter)
|
filters = build_aoss_node_filters(group_ids or [], search_filter)
|
||||||
res = await driver.aoss_client.search(
|
res = await driver.aoss_client.search(
|
||||||
index=ENTITY_INDEX_NAME,
|
index=ENTITY_INDEX_NAME,
|
||||||
routing=route,
|
params={'routing': route},
|
||||||
_source=['uuid'],
|
_source=['uuid'],
|
||||||
size=limit,
|
size=limit,
|
||||||
body={
|
body={
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,6 @@ limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from time import time
|
from time import time
|
||||||
from xml.dom.minidom import Entity
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing_extensions import LiteralString
|
from typing_extensions import LiteralString
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue