claude suggestions

This commit is contained in:
prestonrasmussen 2025-09-07 23:46:47 -04:00
parent 13fc9cf1e4
commit 8e442d4634
4 changed files with 67 additions and 64 deletions

View file

@ -23,8 +23,15 @@ from datetime import datetime
from enum import Enum from enum import Enum
from typing import Any from typing import Any
try:
from opensearchpy import OpenSearch, helpers from opensearchpy import OpenSearch, helpers
_HAS_OPENSEARCH = True
except ImportError:
OpenSearch = None
helpers = None
_HAS_OPENSEARCH = False
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
DEFAULT_SIZE = 10 DEFAULT_SIZE = 10
@ -216,9 +223,6 @@ class GraphDriver(ABC):
if client.indices.exists(index=index_name): if client.indices.exists(index=index_name):
client.indices.delete(index=index_name) client.indices.delete(index=index_name)
def run_aoss_query(self, name: str, query_text: str, limit: int = 10) -> dict[str, Any]:
pass
def save_to_aoss(self, name: str, data: list[dict]) -> int: def save_to_aoss(self, name: str, data: list[dict]) -> int:
for index in aoss_indices: for index in aoss_indices:
if name.lower() == index['index_name']: if name.lower() == index['index_name']:

View file

@ -58,7 +58,7 @@ class Neo4jDriver(GraphDriver):
self._database = database self._database = database
self.aoss_client = None self.aoss_client = None
if aoss_host and aoss_port: if aoss_host and aoss_port and boto3 is not None:
try: try:
session = boto3.Session() session = boto3.Session()
self.aoss_client = OpenSearch( self.aoss_client = OpenSearch(

View file

@ -22,14 +22,13 @@ from typing import Any
import boto3 import boto3
from langchain_aws.graphs import NeptuneAnalyticsGraph, NeptuneGraph from langchain_aws.graphs import NeptuneAnalyticsGraph, NeptuneGraph
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection
from graphiti_core.driver.driver import ( from graphiti_core.driver.driver import (
DEFAULT_SIZE, DEFAULT_SIZE,
GraphDriver, GraphDriver,
GraphDriverSession, GraphDriverSession,
GraphProvider, GraphProvider,
aoss_indices,
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View file

@ -249,7 +249,7 @@ async def edge_fulltext_search(
filters = build_aoss_edge_filters(group_ids, search_filter) filters = build_aoss_edge_filters(group_ids, search_filter)
res = driver.aoss_client.search( res = driver.aoss_client.search(
index='entity_edges', index='entity_edges',
routing=group_ids, routing=group_ids[0],
_source=['uuid'], _source=['uuid'],
query={ query={
'bool': { 'bool': {
@ -406,7 +406,7 @@ async def edge_similarity_search(
filters = build_aoss_edge_filters(group_ids, search_filter) filters = build_aoss_edge_filters(group_ids, search_filter)
res = driver.aoss_client.search( res = driver.aoss_client.search(
index='entity_edges', index='entity_edges',
routing=group_ids, routing=group_ids[0],
_source=['uuid'], _source=['uuid'],
knn={ knn={
'field': 'fact_embedding', 'field': 'fact_embedding',
@ -645,7 +645,7 @@ async def node_fulltext_search(
filters = build_aoss_node_filters(group_ids, search_filter) filters = build_aoss_node_filters(group_ids, search_filter)
res = driver.aoss_client.search( res = driver.aoss_client.search(
'entities', 'entities',
routing=group_ids, routing=group_ids[0],
_source=['uuid'], _source=['uuid'],
query={ query={
'bool': { 'bool': {
@ -787,7 +787,7 @@ async def node_similarity_search(
filters = build_aoss_node_filters(group_ids, search_filter) filters = build_aoss_node_filters(group_ids, search_filter)
res = driver.aoss_client.search( res = driver.aoss_client.search(
index='entities', index='entities',
routing=group_ids, routing=group_ids[0],
_source=['uuid'], _source=['uuid'],
knn={ knn={
'field': 'fact_embedding', 'field': 'fact_embedding',
@ -985,7 +985,7 @@ async def episode_fulltext_search(
elif driver.aoss_client: elif driver.aoss_client:
res = driver.aoss_client.search( res = driver.aoss_client.search(
'episodes', 'episodes',
routing=group_ids, routing=group_ids[0],
_source=['uuid'], _source=['uuid'],
query={ query={
'bool': { 'bool': {