cognee/cognee/infrastructure/databases/graph/get_graph_engine.py
Igor Ilic 343d990fcc
Merge main vol 4 (#1200)
<!-- .github/pull_request_template.md -->

## Description
<!-- Provide a clear description of the changes in this PR -->

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.

---------

Signed-off-by: Andrew Carbonetto <andrew.carbonetto@improving.com>
Signed-off-by: Andy Kwok <andy.kwok@improving.com>
Co-authored-by: Vasilije <8619304+Vasilije1990@users.noreply.github.com>
Co-authored-by: vasilije <vas.markovic@gmail.com>
Co-authored-by: Andrew Carbonetto <andrew.carbonetto@improving.com>
Co-authored-by: Andy Kwok <andy.kwok@improving.com>
2025-08-05 12:48:24 +02:00

199 lines
7.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Factory function to get the appropriate graph client based on the graph type."""
from functools import lru_cache
from .config import get_graph_context_config
from .graph_db_interface import GraphDBInterface
from .supported_databases import supported_databases
async def get_graph_engine() -> GraphDBInterface:
    """Return a graph client for the current async context, ready for use.

    The client is built by the synchronous, ``lru_cache``-decorated
    ``create_graph_engine`` factory. This coroutine then runs the awaitable
    setup steps that cannot live inside a cached synchronous function.

    Returns:
    --------
        The adapter returned by ``create_graph_engine`` for the configuration
        obtained from ``get_graph_context_config``.
    """
    # Resolve the graph configuration that applies to the current async context.
    context_config = get_graph_context_config()
    engine = create_graph_engine(**context_config)

    # Async functions can't be cached, so all awaitable setup for the different
    # graph types is handled below, after the engine has been created.

    # Run the adapter's own async initialization hook when it provides one.
    if hasattr(engine, "initialize"):
        await engine.initialize()

    # For NetworkX, load the graph from file when the client has no graph yet.
    provider_name = context_config["graph_database_provider"].lower()
    if provider_name == "networkx" and engine.graph is None:
        await engine.load_graph_from_file()

    return engine
def _require_langchain_aws():
    """Raise a descriptive ImportError when ``langchain_aws`` is unavailable."""
    try:
        # Imported only as an availability check; the name itself is not used.
        from langchain_aws import NeptuneAnalyticsGraph  # noqa: F401
    except ImportError as import_error:
        raise ImportError(
            "langchain_aws is not installed. Please install it with 'pip install langchain_aws'"
        ) from import_error


@lru_cache
def create_graph_engine(
    graph_database_provider,
    graph_file_path,
    graph_database_url="",
    graph_database_username="",
    graph_database_password="",
    graph_database_port="",
):
    """
    Create a graph engine based on the specified provider type.

    This factory function initializes and returns the appropriate graph client depending on
    the database provider specified. Providers registered in ``supported_databases`` are
    checked first, then the built-in providers; any provider that matches neither falls
    back to the NetworkX adapter. Because the function is wrapped in ``lru_cache``,
    repeated calls with the same arguments return the same adapter instance.

    Parameters:
    -----------

        - graph_database_provider: The type of graph database provider to use. Built-in
          values are neo4j, falkordb, kuzu, kuzu-remote, memgraph, neptune and
          neptune_analytics; a key of ``supported_databases`` is also accepted.
        - graph_file_path: The filesystem path to the graph file. Required for the kuzu
          provider and used as the file name for the NetworkX fallback.
        - graph_database_url: The URL for the graph database instance. Required for the
          neo4j, falkordb, kuzu-remote, memgraph, neptune and neptune_analytics providers.
        - graph_database_username: The username for authentication with the graph
          database. Optional; an empty value is passed as None to neo4j and memgraph.
        - graph_database_password: The password for authentication with the graph
          database. Optional; an empty value is passed as None to neo4j and memgraph.
        - graph_database_port: The port number for the graph database connection. Required
          for the falkordb provider.

    Returns:
    --------

        Returns an instance of the appropriate graph adapter depending on the provider type
        specified.

    Raises:
    -------

        - EnvironmentError: A parameter required by the selected provider is missing.
        - ImportError: A Neptune provider is selected but ``langchain_aws`` is not
          installed.
        - ValueError: A Neptune endpoint does not start with the expected endpoint prefix.
    """
    # Adapters registered in supported_databases take precedence over built-ins.
    if graph_database_provider in supported_databases:
        adapter = supported_databases[graph_database_provider]

        return adapter(
            graph_database_url=graph_database_url,
            graph_database_username=graph_database_username,
            graph_database_password=graph_database_password,
        )

    if graph_database_provider == "neo4j":
        if not graph_database_url:
            raise EnvironmentError("Missing required Neo4j URL.")

        from .neo4j_driver.adapter import Neo4jAdapter

        return Neo4jAdapter(
            graph_database_url=graph_database_url,
            # Empty credentials are normalised to None.
            graph_database_username=graph_database_username or None,
            graph_database_password=graph_database_password or None,
        )

    elif graph_database_provider == "falkordb":
        if not (graph_database_url and graph_database_port):
            raise EnvironmentError("Missing required FalkorDB credentials.")

        from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine
        from cognee.infrastructure.databases.hybrid.falkordb.FalkorDBAdapter import FalkorDBAdapter

        embedding_engine = get_embedding_engine()

        return FalkorDBAdapter(
            database_url=graph_database_url,
            database_port=graph_database_port,
            embedding_engine=embedding_engine,
        )

    elif graph_database_provider == "kuzu":
        if not graph_file_path:
            raise EnvironmentError("Missing required Kuzu database path.")

        from .kuzu.adapter import KuzuAdapter

        return KuzuAdapter(db_path=graph_file_path)

    elif graph_database_provider == "kuzu-remote":
        if not graph_database_url:
            raise EnvironmentError("Missing required Kuzu remote URL.")

        from .kuzu.remote_kuzu_adapter import RemoteKuzuAdapter

        return RemoteKuzuAdapter(
            api_url=graph_database_url,
            username=graph_database_username,
            password=graph_database_password,
        )

    elif graph_database_provider == "memgraph":
        if not graph_database_url:
            raise EnvironmentError("Missing required Memgraph URL.")

        from .memgraph.memgraph_adapter import MemgraphAdapter

        return MemgraphAdapter(
            graph_database_url=graph_database_url,
            # Empty credentials are normalised to None.
            graph_database_username=graph_database_username or None,
            graph_database_password=graph_database_password or None,
        )

    elif graph_database_provider == "neptune":
        _require_langchain_aws()

        if not graph_database_url:
            raise EnvironmentError("Missing Neptune endpoint.")

        from .neptune_driver.adapter import NeptuneGraphDB, NEPTUNE_ENDPOINT_URL

        if not graph_database_url.startswith(NEPTUNE_ENDPOINT_URL):
            raise ValueError(
                f"Neptune endpoint must have the format {NEPTUNE_ENDPOINT_URL}<GRAPH_ID>"
            )

        # Drop only the validated leading prefix; str.replace would also remove
        # any later occurrence of the endpoint string inside the identifier.
        graph_identifier = graph_database_url[len(NEPTUNE_ENDPOINT_URL):]

        return NeptuneGraphDB(
            graph_id=graph_identifier,
        )

    elif graph_database_provider == "neptune_analytics":
        # Creates a graph DB from config.
        # We want to use a hybrid (graph & vector) DB and we should update this
        # to make a single instance of the hybrid configuration (with embedder)
        # instead of creating the hybrid object twice.
        _require_langchain_aws()

        if not graph_database_url:
            raise EnvironmentError("Missing Neptune endpoint.")

        from ..hybrid.neptune_analytics.NeptuneAnalyticsAdapter import (
            NeptuneAnalyticsAdapter,
            NEPTUNE_ANALYTICS_ENDPOINT_URL,
        )

        if not graph_database_url.startswith(NEPTUNE_ANALYTICS_ENDPOINT_URL):
            raise ValueError(
                f"Neptune endpoint must have the format '{NEPTUNE_ANALYTICS_ENDPOINT_URL}<GRAPH_ID>'"
            )

        # Drop only the validated leading prefix (see the neptune branch).
        graph_identifier = graph_database_url[len(NEPTUNE_ANALYTICS_ENDPOINT_URL):]

        return NeptuneAnalyticsAdapter(
            graph_id=graph_identifier,
        )

    # No provider matched above: fall back to the file-backed NetworkX adapter.
    from .networkx.adapter import NetworkXAdapter

    graph_client = NetworkXAdapter(filename=graph_file_path)
    return graph_client