claude suggestions
This commit is contained in:
parent
13fc9cf1e4
commit
8e442d4634
4 changed files with 67 additions and 64 deletions
|
|
@ -23,8 +23,15 @@ from datetime import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
try:
|
||||||
from opensearchpy import OpenSearch, helpers
|
from opensearchpy import OpenSearch, helpers
|
||||||
|
|
||||||
|
_HAS_OPENSEARCH = True
|
||||||
|
except ImportError:
|
||||||
|
OpenSearch = None
|
||||||
|
helpers = None
|
||||||
|
_HAS_OPENSEARCH = False
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
DEFAULT_SIZE = 10
|
DEFAULT_SIZE = 10
|
||||||
|
|
@ -216,9 +223,6 @@ class GraphDriver(ABC):
|
||||||
if client.indices.exists(index=index_name):
|
if client.indices.exists(index=index_name):
|
||||||
client.indices.delete(index=index_name)
|
client.indices.delete(index=index_name)
|
||||||
|
|
||||||
def run_aoss_query(self, name: str, query_text: str, limit: int = 10) -> dict[str, Any]:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def save_to_aoss(self, name: str, data: list[dict]) -> int:
|
def save_to_aoss(self, name: str, data: list[dict]) -> int:
|
||||||
for index in aoss_indices:
|
for index in aoss_indices:
|
||||||
if name.lower() == index['index_name']:
|
if name.lower() == index['index_name']:
|
||||||
|
|
|
||||||
|
|
@ -58,7 +58,7 @@ class Neo4jDriver(GraphDriver):
|
||||||
self._database = database
|
self._database = database
|
||||||
|
|
||||||
self.aoss_client = None
|
self.aoss_client = None
|
||||||
if aoss_host and aoss_port:
|
if aoss_host and aoss_port and boto3 is not None:
|
||||||
try:
|
try:
|
||||||
session = boto3.Session()
|
session = boto3.Session()
|
||||||
self.aoss_client = OpenSearch(
|
self.aoss_client = OpenSearch(
|
||||||
|
|
|
||||||
|
|
@ -22,14 +22,13 @@ from typing import Any
|
||||||
|
|
||||||
import boto3
|
import boto3
|
||||||
from langchain_aws.graphs import NeptuneAnalyticsGraph, NeptuneGraph
|
from langchain_aws.graphs import NeptuneAnalyticsGraph, NeptuneGraph
|
||||||
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
|
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection
|
||||||
|
|
||||||
from graphiti_core.driver.driver import (
|
from graphiti_core.driver.driver import (
|
||||||
DEFAULT_SIZE,
|
DEFAULT_SIZE,
|
||||||
GraphDriver,
|
GraphDriver,
|
||||||
GraphDriverSession,
|
GraphDriverSession,
|
||||||
GraphProvider,
|
GraphProvider,
|
||||||
aoss_indices,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
|
||||||
|
|
@ -249,7 +249,7 @@ async def edge_fulltext_search(
|
||||||
filters = build_aoss_edge_filters(group_ids, search_filter)
|
filters = build_aoss_edge_filters(group_ids, search_filter)
|
||||||
res = driver.aoss_client.search(
|
res = driver.aoss_client.search(
|
||||||
index='entity_edges',
|
index='entity_edges',
|
||||||
routing=group_ids,
|
routing=group_ids[0],
|
||||||
_source=['uuid'],
|
_source=['uuid'],
|
||||||
query={
|
query={
|
||||||
'bool': {
|
'bool': {
|
||||||
|
|
@ -406,7 +406,7 @@ async def edge_similarity_search(
|
||||||
filters = build_aoss_edge_filters(group_ids, search_filter)
|
filters = build_aoss_edge_filters(group_ids, search_filter)
|
||||||
res = driver.aoss_client.search(
|
res = driver.aoss_client.search(
|
||||||
index='entity_edges',
|
index='entity_edges',
|
||||||
routing=group_ids,
|
routing=group_ids[0],
|
||||||
_source=['uuid'],
|
_source=['uuid'],
|
||||||
knn={
|
knn={
|
||||||
'field': 'fact_embedding',
|
'field': 'fact_embedding',
|
||||||
|
|
@ -645,7 +645,7 @@ async def node_fulltext_search(
|
||||||
filters = build_aoss_node_filters(group_ids, search_filter)
|
filters = build_aoss_node_filters(group_ids, search_filter)
|
||||||
res = driver.aoss_client.search(
|
res = driver.aoss_client.search(
|
||||||
'entities',
|
'entities',
|
||||||
routing=group_ids,
|
routing=group_ids[0],
|
||||||
_source=['uuid'],
|
_source=['uuid'],
|
||||||
query={
|
query={
|
||||||
'bool': {
|
'bool': {
|
||||||
|
|
@ -787,7 +787,7 @@ async def node_similarity_search(
|
||||||
filters = build_aoss_node_filters(group_ids, search_filter)
|
filters = build_aoss_node_filters(group_ids, search_filter)
|
||||||
res = driver.aoss_client.search(
|
res = driver.aoss_client.search(
|
||||||
index='entities',
|
index='entities',
|
||||||
routing=group_ids,
|
routing=group_ids[0],
|
||||||
_source=['uuid'],
|
_source=['uuid'],
|
||||||
knn={
|
knn={
|
||||||
'field': 'fact_embedding',
|
'field': 'fact_embedding',
|
||||||
|
|
@ -985,7 +985,7 @@ async def episode_fulltext_search(
|
||||||
elif driver.aoss_client:
|
elif driver.aoss_client:
|
||||||
res = driver.aoss_client.search(
|
res = driver.aoss_client.search(
|
||||||
'episodes',
|
'episodes',
|
||||||
routing=group_ids,
|
routing=group_ids[0],
|
||||||
_source=['uuid'],
|
_source=['uuid'],
|
||||||
query={
|
query={
|
||||||
'bool': {
|
'bool': {
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue