diff --git a/graphiti_core/driver/driver.py b/graphiti_core/driver/driver.py index 7999bc25..74e80d10 100644 --- a/graphiti_core/driver/driver.py +++ b/graphiti_core/driver/driver.py @@ -23,6 +23,8 @@ from datetime import datetime from enum import Enum from typing import TYPE_CHECKING, Any +from graphiti_core.embedder.client import EMBEDDING_DIM + try: from opensearchpy import OpenSearch, helpers @@ -59,8 +61,8 @@ aoss_indices = [ 'group_id': {'type': 'text'}, 'created_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"}, 'name_embedding': { - 'type': 'dense_vector', - 'dims': 1024, + 'type': 'knn_vector', + 'dims': EMBEDDING_DIM, 'index': True, 'similarity': 'cosine', 'method': { @@ -116,13 +118,13 @@ aoss_indices = [ 'expired_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"}, 'invalid_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"}, 'fact_embedding': { - 'type': 'dense_vector', - 'dims': 1024, + 'type': 'knn_vector', + 'dims': EMBEDDING_DIM, 'index': True, 'similarity': 'cosine', 'method': { 'engine': 'faiss', - 'space_type': 'cosinesimil', + 'space_type': 'innerproduct', 'name': 'hnsw', 'parameters': {'ef_construction': 128, 'm': 16}, }, @@ -236,7 +238,7 @@ class GraphDriver(ABC): def save_to_aoss(self, name: str, data: list[dict]) -> int: client = self.aoss_client - if not client: + if not client or not helpers: logger.warning('No OpenSearch client found') return 0 diff --git a/graphiti_core/driver/falkordb_driver.py b/graphiti_core/driver/falkordb_driver.py index 95605cd2..7c619897 100644 --- a/graphiti_core/driver/falkordb_driver.py +++ b/graphiti_core/driver/falkordb_driver.py @@ -39,7 +39,6 @@ logger = logging.getLogger(__name__) class FalkorDriverSession(GraphDriverSession): provider = GraphProvider.FALKORDB - aoss_client: None = None def __init__(self, graph: FalkorGraph): self.graph = graph @@ -75,6 +74,7 @@ class FalkorDriverSession(GraphDriverSession): class FalkorDriver(GraphDriver): provider = GraphProvider.FALKORDB + aoss_client: None = None def __init__( self, diff --git a/graphiti_core/embedder/client.py b/graphiti_core/embedder/client.py index 9ffc0653..dd405572 100644 --- a/graphiti_core/embedder/client.py +++ b/graphiti_core/embedder/client.py @@ -14,12 +14,13 @@ See the License for the specific language governing permissions and limitations under the License. """ +import os from abc import ABC, abstractmethod from collections.abc import Iterable from pydantic import BaseModel, Field -EMBEDDING_DIM = 1024 +EMBEDDING_DIM = os.getenv('EMBEDDING_DIM', 1024) class EmbedderConfig(BaseModel):