* Update node_db_queries.py * Update node_db_queries.py * graph-per-graphid * fix-groupid-usage * ruff-fix * rev-driver-changes * rm-un-changes * fix lint --------- Co-authored-by: Naseem Ali <34807727+Naseem77@users.noreply.github.com>
124 lines
3.7 KiB
Python
124 lines
3.7 KiB
Python
"""
|
|
Copyright 2024, Zep Software, Inc.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
"""
|
|
|
|
import copy
|
|
import logging
|
|
import os
|
|
from abc import ABC, abstractmethod
|
|
from collections.abc import Coroutine
|
|
from enum import Enum
|
|
from typing import Any
|
|
|
|
from dotenv import load_dotenv
|
|
|
|
from graphiti_core.driver.graph_operations.graph_operations import GraphOperationsInterface
|
|
from graphiti_core.driver.search_interface.search_interface import SearchInterface
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
DEFAULT_SIZE = 10
|
|
|
|
load_dotenv()
|
|
|
|
ENTITY_INDEX_NAME = os.environ.get('ENTITY_INDEX_NAME', 'entities')
|
|
EPISODE_INDEX_NAME = os.environ.get('EPISODE_INDEX_NAME', 'episodes')
|
|
COMMUNITY_INDEX_NAME = os.environ.get('COMMUNITY_INDEX_NAME', 'communities')
|
|
ENTITY_EDGE_INDEX_NAME = os.environ.get('ENTITY_EDGE_INDEX_NAME', 'entity_edges')
|
|
|
|
|
|
class GraphProvider(Enum):
|
|
NEO4J = 'neo4j'
|
|
FALKORDB = 'falkordb'
|
|
KUZU = 'kuzu'
|
|
NEPTUNE = 'neptune'
|
|
|
|
|
|
class GraphDriverSession(ABC):
|
|
provider: GraphProvider
|
|
|
|
async def __aenter__(self):
|
|
return self
|
|
|
|
@abstractmethod
|
|
async def __aexit__(self, exc_type, exc, tb):
|
|
# No cleanup needed for Falkor, but method must exist
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def run(self, query: str, **kwargs: Any) -> Any:
|
|
raise NotImplementedError()
|
|
|
|
@abstractmethod
|
|
async def close(self):
|
|
raise NotImplementedError()
|
|
|
|
@abstractmethod
|
|
async def execute_write(self, func, *args, **kwargs):
|
|
raise NotImplementedError()
|
|
|
|
|
|
class GraphDriver(ABC):
|
|
provider: GraphProvider
|
|
fulltext_syntax: str = (
|
|
'' # Neo4j (default) syntax does not require a prefix for fulltext queries
|
|
)
|
|
_database: str
|
|
default_group_id: str = ''
|
|
search_interface: SearchInterface | None = None
|
|
graph_operations_interface: GraphOperationsInterface | None = None
|
|
|
|
@abstractmethod
|
|
def execute_query(self, cypher_query_: str, **kwargs: Any) -> Coroutine:
|
|
raise NotImplementedError()
|
|
|
|
@abstractmethod
|
|
def session(self, database: str | None = None) -> GraphDriverSession:
|
|
raise NotImplementedError()
|
|
|
|
@abstractmethod
|
|
def close(self):
|
|
raise NotImplementedError()
|
|
|
|
@abstractmethod
|
|
def delete_all_indexes(self) -> Coroutine:
|
|
raise NotImplementedError()
|
|
|
|
def with_database(self, database: str) -> 'GraphDriver':
|
|
"""
|
|
Returns a shallow copy of this driver with a different default database.
|
|
Reuses the same connection (e.g. FalkorDB, Neo4j).
|
|
"""
|
|
cloned = copy.copy(self)
|
|
cloned._database = database
|
|
|
|
return cloned
|
|
|
|
@abstractmethod
|
|
async def build_indices_and_constraints(self, delete_existing: bool = False):
|
|
raise NotImplementedError()
|
|
|
|
def clone(self, database: str) -> 'GraphDriver':
|
|
"""Clone the driver with a different database or graph name."""
|
|
return self
|
|
|
|
def build_fulltext_query(
|
|
self, query: str, group_ids: list[str] | None = None, max_query_length: int = 128
|
|
) -> str:
|
|
"""
|
|
Specific fulltext query builder for database providers.
|
|
Only implemented by providers that need custom fulltext query building.
|
|
"""
|
|
raise NotImplementedError(f'build_fulltext_query not implemented for {self.provider}')
|