diff --git a/cognee/infrastructure/databases/graph/get_graph_engine.py b/cognee/infrastructure/databases/graph/get_graph_engine.py index c85ad1a9c..cb438f9e3 100644 --- a/cognee/infrastructure/databases/graph/get_graph_engine.py +++ b/cognee/infrastructure/databases/graph/get_graph_engine.py @@ -122,6 +122,62 @@ def create_graph_engine( username=graph_database_username, password=graph_database_password, ) + elif graph_database_provider == "neptune": + try: + from langchain_aws import NeptuneAnalyticsGraph + except ImportError: + raise ImportError( + "langchain_aws is not installed. Please install it with 'pip install langchain_aws'" + ) + + if not graph_database_url: + raise EnvironmentError("Missing Neptune endpoint.") + + from .neptune_driver.adapter import NeptuneGraphDB, NEPTUNE_ENDPOINT_URL + + if not graph_database_url.startswith(NEPTUNE_ENDPOINT_URL): + raise ValueError( + f"Neptune endpoint must have the format {NEPTUNE_ENDPOINT_URL}" + ) + + graph_identifier = graph_database_url.replace(NEPTUNE_ENDPOINT_URL, "") + + return NeptuneGraphDB( + graph_id=graph_identifier, + ) + + elif graph_database_provider == "neptune_analytics": + """ + Creates a graph DB from config + We want to use a hybrid (graph & vector) DB and we should update this + to make a single instance of the hybrid configuration (with embedder) + instead of creating the hybrid object twice. + """ + try: + from langchain_aws import NeptuneAnalyticsGraph + except ImportError: + raise ImportError( + "langchain_aws is not installed. Please install it with 'pip install langchain_aws'" + ) + + if not graph_database_url: + raise EnvironmentError("Missing Neptune endpoint.") + + from ..hybrid.neptune_analytics.NeptuneAnalyticsAdapter import ( + NeptuneAnalyticsAdapter, + NEPTUNE_ANALYTICS_ENDPOINT_URL, + ) + + if not graph_database_url.startswith(NEPTUNE_ANALYTICS_ENDPOINT_URL): + raise ValueError( + f"Neptune endpoint must have the format '{NEPTUNE_ANALYTICS_ENDPOINT_URL}'" + ) + + graph_identifier = graph_database_url.replace(NEPTUNE_ANALYTICS_ENDPOINT_URL, "") + + return NeptuneAnalyticsAdapter( + graph_id=graph_identifier, + ) from .networkx.adapter import NetworkXAdapter diff --git a/cognee/modules/search/operations/select_search_type.py b/cognee/modules/search/operations/select_search_type.py index d08074d0d..a3013fec1 100644 --- a/cognee/modules/search/operations/select_search_type.py +++ b/cognee/modules/search/operations/select_search_type.py @@ -1,7 +1,7 @@ -from cognee.infrastructure.llm.get_llm_client import get_llm_client from cognee.infrastructure.llm.prompts import read_query_prompt from cognee.modules.search.types import SearchType from cognee.shared.logging_utils import get_logger +from cognee.infrastructure.llm.LLMAdapter import LLMAdapter logger = get_logger("SearchTypeSelector") @@ -22,10 +22,9 @@ async def select_search_type( """ default_search_type = SearchType.RAG_COMPLETION system_prompt = read_query_prompt(system_prompt_path) - llm_client = get_llm_client() try: - response = await llm_client.acreate_structured_output( + response = await LLMAdapter.acreate_structured_output( text_input=query, system_prompt=system_prompt, response_model=str,