diff --git a/cognee/infrastructure/databases/graph/config.py b/cognee/infrastructure/databases/graph/config.py index e69de29bb..7b48e65c4 100644 --- a/cognee/infrastructure/databases/graph/config.py +++ b/cognee/infrastructure/databases/graph/config.py @@ -0,0 +1,37 @@ +""" This module contains the configuration for the graph database. """ +import os +from functools import lru_cache +from pydantic_settings import BaseSettings, SettingsConfigDict +from cognee.base_config import get_base_config +from cognee.shared.data_models import DefaultGraphModel + +base_config = get_base_config() + +class GraphConfig(BaseSettings): + graph_filename: str = "cognee_graph.pkl" + graph_database_provider: str = "NETWORKX" + graph_topology: str = DefaultGraphModel + graph_database_url: str = "" + graph_database_username: str = "" + graph_database_password: str = "" + graph_database_port: int = "" + graph_file_path = os.path.join(base_config.database_directory_path,graph_filename) + + model_config = SettingsConfigDict(env_file = ".env", extra = "allow") + + def to_dict(self) -> dict: + return { + "graph_filename": self.graph_filename, + "graph_database_provider": self.graph_database_provider, + "graph_topology": self.graph_topology, + "graph_file_path": self.graph_file_path, + "graph_database_url": self.graph_database_url, + "graph_database_username": self.graph_database_username, + "graph_database_password": self.graph_database_password, + "graph_database_port": self.graph_database_port, + + } + +@lru_cache +def get_graph_config(): + return GraphConfig() diff --git a/cognee/infrastructure/databases/graph/get_graph_client.py b/cognee/infrastructure/databases/graph/get_graph_client.py index 1a906927e..84165c61f 100644 --- a/cognee/infrastructure/databases/graph/get_graph_client.py +++ b/cognee/infrastructure/databases/graph/get_graph_client.py @@ -1,18 +1,14 @@ """Factory function to get the appropriate graph client based on the graph type.""" -from cognee.config import Config from cognee.shared.data_models import GraphDBType -from cognee.infrastructure import infrastructure_config +from .config import get_graph_config from .graph_db_interface import GraphDBInterface from .networkx.adapter import NetworkXAdapter - -config = Config() -config.load() +config = get_graph_config() async def get_graph_client(graph_type: GraphDBType, graph_file_name: str = None) -> GraphDBInterface : """Factory function to get the appropriate graph client based on the graph type.""" - graph_file_path = f"{infrastructure_config.get_config('database_directory_path')}/{graph_file_name if graph_file_name else config.graph_filename}" if graph_type == GraphDBType.NEO4J: try: @@ -25,10 +21,20 @@ async def get_graph_client(graph_type: GraphDBType, graph_file_name: str = None) ) except: pass - - graph_client = NetworkXAdapter(filename = graph_file_path) + elif graph_type == GraphDBType.FALKORDB: + try: + from .falkordb.adapter import FalcorDBAdapter + return FalcorDBAdapter( + graph_database_url = config.graph_database_url, + graph_database_username = config.graph_database_username, + graph_database_password = config.graph_database_password, + graph_database_port = config.graph_database_port + ) + except: + pass + graph_client = NetworkXAdapter(filename = config.graph_file_path) if (graph_client.graph is None): await graph_client.load_graph_from_file() diff --git a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py index a2f621399..f687d29fc 100644 --- a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +++ b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py @@ -56,12 +56,6 @@ class Neo4jAdapter(GraphDBInterface): if "name" not in serialized_properties: serialized_properties["name"] = node_id - - # serialized_properties["created_at"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - # serialized_properties["updated_at"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - - # properties = ", ".join(f"{property_name}: ${property_name}" for property_name in serialized_properties.keys()) - query = f"""MERGE (node:`{node_id}` {{id: $node_id}}) ON CREATE SET node += $properties RETURN ID(node) AS internal_id, node.id AS nodeId""" @@ -85,30 +79,6 @@ class Neo4jAdapter(GraphDBInterface): node_properties = node_properties, ) - - # serialized_properties = self.serialize_properties(node_properties) - - # if "name" not in serialized_properties: - # serialized_properties["name"] = node_id - - # nodes_data.append({ - # "node_id": node_id, - # "properties": serialized_properties, - # }) - - # query = """UNWIND $nodes_data AS node_data - # MERGE (node:{id: node_data.node_id}) - # ON CREATE SET node += node_data.properties - # RETURN ID(node) AS internal_id, node.id AS id""" - - # params = {"nodes_data": nodes_data} - - # result = await self.query(query, params) - - # await self.close() - - # return result - async def extract_node_description(self, node_id: str): query = """MATCH (n)-[r]->(m) WHERE n.id = $node_id @@ -138,7 +108,7 @@ class Neo4jAdapter(GraphDBInterface): query = """MATCH (node) WHERE node.layer_id IS NOT NULL RETURN node""" - return [result['node'] for result in (await self.query(query))] + return [result["node"] for result in (await self.query(query))] async def extract_node(self, node_id: str): query= """ @@ -146,7 +116,7 @@ class Neo4jAdapter(GraphDBInterface): RETURN node """ - results = [node['node'] for node in (await self.query(query, dict(node_id = node_id)))] + results = [node["node"] for node in (await self.query(query, dict(node_id = node_id)))] return results[0] if len(results) > 0 else None @@ -163,10 +133,12 @@ class Neo4jAdapter(GraphDBInterface): from_node = from_node.replace(":", "_") to_node = to_node.replace(":", "_") - query = f"""MATCH (from_node:`{from_node}` {{id: $from_node}}), (to_node:`{to_node}` {{id: $to_node}}) - MERGE (from_node)-[r:`{relationship_name}`]->(to_node) - SET r += $properties - RETURN r""" + query = f"""MATCH (from_node:`{from_node}` + {{id: $from_node}}), + (to_node:`{to_node}` {{id: $to_node}}) + MERGE (from_node)-[r:`{relationship_name}`]->(to_node) + SET r += $properties + RETURN r""" params = { "from_node": from_node, @@ -192,30 +164,6 @@ class Neo4jAdapter(GraphDBInterface): edge_properties = edge_properties ) - # Filter out None values and do not serialize; Neo4j can handle complex types like arrays directly - # serialized_properties = self.serialize_properties(edge_properties) - - # edges_data.append({ - # "from_node": from_node, - # "to_node": to_node, - # "relationship_name": relationship_name, - # "properties": serialized_properties - # }) - - # query = """UNWIND $edges_data AS edge_data - # MATCH (from_node:{id: edge_data.from_node}), (to_node:{id: edge_data.to_node}) - # MERGE (from_node)-[r:{edge_data.relationship_name}]->(to_node) - # ON CREATE SET r += edge_data.properties - # RETURN r""" - - # params = {"edges_data": edges_data} - - # result = await self.query(query, params) - - # await self.close() - - # return result - async def filter_nodes(self, search_criteria): query = f"""MATCH (node) diff --git a/cognee/shared/data_models.py b/cognee/shared/data_models.py index 51b0124ba..7e228bf8b 100644 --- a/cognee/shared/data_models.py +++ b/cognee/shared/data_models.py @@ -206,6 +206,7 @@ class DefaultCognitiveLayer(BaseModel): class GraphDBType(Enum): NETWORKX = auto() NEO4J = auto() + FALKORDB = auto() # Models for representing different entities