refactor: Return neptune analytics to get graph engine
This commit is contained in:
parent
a09d2d0b3c
commit
1fd4c4fa8b
2 changed files with 58 additions and 3 deletions
|
|
@ -122,6 +122,62 @@ def create_graph_engine(
|
|||
username=graph_database_username,
|
||||
password=graph_database_password,
|
||||
)
|
||||
elif graph_database_provider == "neptune":
|
||||
try:
|
||||
from langchain_aws import NeptuneAnalyticsGraph
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"langchain_aws is not installed. Please install it with 'pip install langchain_aws'"
|
||||
)
|
||||
|
||||
if not graph_database_url:
|
||||
raise EnvironmentError("Missing Neptune endpoint.")
|
||||
|
||||
from .neptune_driver.adapter import NeptuneGraphDB, NEPTUNE_ENDPOINT_URL
|
||||
|
||||
if not graph_database_url.startswith(NEPTUNE_ENDPOINT_URL):
|
||||
raise ValueError(
|
||||
f"Neptune endpoint must have the format {NEPTUNE_ENDPOINT_URL}<GRAPH_ID>"
|
||||
)
|
||||
|
||||
graph_identifier = graph_database_url.replace(NEPTUNE_ENDPOINT_URL, "")
|
||||
|
||||
return NeptuneGraphDB(
|
||||
graph_id=graph_identifier,
|
||||
)
|
||||
|
||||
elif graph_database_provider == "neptune_analytics":
|
||||
"""
|
||||
Creates a graph DB from config
|
||||
We want to use a hybrid (graph & vector) DB and we should update this
|
||||
to make a single instance of the hybrid configuration (with embedder)
|
||||
instead of creating the hybrid object twice.
|
||||
"""
|
||||
try:
|
||||
from langchain_aws import NeptuneAnalyticsGraph
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"langchain_aws is not installed. Please install it with 'pip install langchain_aws'"
|
||||
)
|
||||
|
||||
if not graph_database_url:
|
||||
raise EnvironmentError("Missing Neptune endpoint.")
|
||||
|
||||
from ..hybrid.neptune_analytics.NeptuneAnalyticsAdapter import (
|
||||
NeptuneAnalyticsAdapter,
|
||||
NEPTUNE_ANALYTICS_ENDPOINT_URL,
|
||||
)
|
||||
|
||||
if not graph_database_url.startswith(NEPTUNE_ANALYTICS_ENDPOINT_URL):
|
||||
raise ValueError(
|
||||
f"Neptune endpoint must have the format '{NEPTUNE_ANALYTICS_ENDPOINT_URL}<GRAPH_ID>'"
|
||||
)
|
||||
|
||||
graph_identifier = graph_database_url.replace(NEPTUNE_ANALYTICS_ENDPOINT_URL, "")
|
||||
|
||||
return NeptuneAnalyticsAdapter(
|
||||
graph_id=graph_identifier,
|
||||
)
|
||||
|
||||
from .networkx.adapter import NetworkXAdapter
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
||||
from cognee.infrastructure.llm.prompts import read_query_prompt
|
||||
from cognee.modules.search.types import SearchType
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.infrastructure.llm.LLMAdapter import LLMAdapter
|
||||
|
||||
logger = get_logger("SearchTypeSelector")
|
||||
|
||||
|
|
@ -22,10 +22,9 @@ async def select_search_type(
|
|||
"""
|
||||
default_search_type = SearchType.RAG_COMPLETION
|
||||
system_prompt = read_query_prompt(system_prompt_path)
|
||||
llm_client = get_llm_client()
|
||||
|
||||
try:
|
||||
response = await llm_client.acreate_structured_output(
|
||||
response = await LLMAdapter.acreate_structured_output(
|
||||
text_input=query,
|
||||
system_prompt=system_prompt,
|
||||
response_model=str,
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue