refactor: Return neptune analytics to get graph engine

Igor Ilic 2025-08-05 19:31:30 +02:00
parent a09d2d0b3c
commit 1fd4c4fa8b
2 changed files with 58 additions and 3 deletions

View file

@@ -122,6 +122,62 @@ def create_graph_engine(
            username=graph_database_username,
            password=graph_database_password,
        )

    elif graph_database_provider == "neptune":
        try:
            from langchain_aws import NeptuneAnalyticsGraph
        except ImportError:
            raise ImportError(
                "langchain_aws is not installed. Please install it with 'pip install langchain_aws'"
            )

        if not graph_database_url:
            raise EnvironmentError("Missing Neptune endpoint.")

        from .neptune_driver.adapter import NeptuneGraphDB, NEPTUNE_ENDPOINT_URL

        if not graph_database_url.startswith(NEPTUNE_ENDPOINT_URL):
            raise ValueError(
                f"Neptune endpoint must have the format {NEPTUNE_ENDPOINT_URL}<GRAPH_ID>"
            )

        graph_identifier = graph_database_url.replace(NEPTUNE_ENDPOINT_URL, "")

        return NeptuneGraphDB(
            graph_id=graph_identifier,
        )
elif graph_database_provider == "neptune_analytics":
"""
Creates a graph DB from config
We want to use a hybrid (graph & vector) DB and we should update this
to make a single instance of the hybrid configuration (with embedder)
instead of creating the hybrid object twice.
"""
        try:
            from langchain_aws import NeptuneAnalyticsGraph
        except ImportError:
            raise ImportError(
                "langchain_aws is not installed. Please install it with 'pip install langchain_aws'"
            )

        if not graph_database_url:
            raise EnvironmentError("Missing Neptune endpoint.")

        from ..hybrid.neptune_analytics.NeptuneAnalyticsAdapter import (
            NeptuneAnalyticsAdapter,
            NEPTUNE_ANALYTICS_ENDPOINT_URL,
        )

        if not graph_database_url.startswith(NEPTUNE_ANALYTICS_ENDPOINT_URL):
            raise ValueError(
                f"Neptune endpoint must have the format '{NEPTUNE_ANALYTICS_ENDPOINT_URL}<GRAPH_ID>'"
            )

        graph_identifier = graph_database_url.replace(NEPTUNE_ANALYTICS_ENDPOINT_URL, "")

        return NeptuneAnalyticsAdapter(
            graph_id=graph_identifier,
        )
    from .networkx.adapter import NetworkXAdapter
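
For orientation, a minimal usage sketch of the new branch. The provider string and keyword names are taken from the hunk above; the module path and the exact NEPTUNE_ANALYTICS_ENDPOINT_URL prefix are assumptions, since neither appears in the diff.

# Hedged sketch: the module path and the "neptune-graph://" prefix are assumed, not shown in the diff.
from cognee.infrastructure.databases.graph.get_graph_engine import create_graph_engine

graph_db = create_graph_engine(
    graph_database_provider="neptune_analytics",
    # Everything after the endpoint prefix is used as the Neptune Analytics graph id,
    # so this call would return NeptuneAnalyticsAdapter(graph_id="g-0123456789").
    graph_database_url="neptune-graph://g-0123456789",
    # Credentials are unused by this branch; any further parameters are omitted here.
    graph_database_username=None,
    graph_database_password=None,
)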

View file

@@ -1,7 +1,7 @@
from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.infrastructure.llm.prompts import read_query_prompt
from cognee.modules.search.types import SearchType
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.llm.LLMAdapter import LLMAdapter
logger = get_logger("SearchTypeSelector")
@@ -22,10 +22,9 @@ async def select_search_type(
"""
default_search_type = SearchType.RAG_COMPLETION
system_prompt = read_query_prompt(system_prompt_path)
llm_client = get_llm_client()
try:
response = await llm_client.acreate_structured_output(
response = await LLMAdapter.acreate_structured_output(
text_input=query,
system_prompt=system_prompt,
response_model=str,
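
The hunk ends mid-call, but for orientation, here is a hedged sketch of how the switched-over call might sit inside the surrounding try/except. Only the names visible above (query, system_prompt, default_search_type, logger, SearchType) come from the diff; the response parsing and the fallback handling are assumptions about the truncated code.

try:
    response = await LLMAdapter.acreate_structured_output(
        text_input=query,
        system_prompt=system_prompt,
        response_model=str,
    )
    # Assumed: map the model's free-text answer onto a SearchType member.
    selected_type = SearchType(response.upper())
except Exception as error:
    # Assumed fallback: log the failure and use the default RAG_COMPLETION search type.
    logger.warning("Search type selection failed (%s); using %s", error, default_search_type)
    selected_type = default_search_type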