add test for linter
parent 630588bd46
commit 59feaa3e4e
4 changed files with 60 additions and 68 deletions
@@ -0,0 +1,37 @@
""" This module contains the configuration for the graph database. """
import os

from functools import lru_cache

from pydantic_settings import BaseSettings, SettingsConfigDict

from cognee.base_config import get_base_config
from cognee.shared.data_models import DefaultGraphModel

base_config = get_base_config()


class GraphConfig(BaseSettings):
    graph_filename: str = "cognee_graph.pkl"
    graph_database_provider: str = "NETWORKX"
    graph_topology: str = DefaultGraphModel
    graph_database_url: str = ""
    graph_database_username: str = ""
    graph_database_password: str = ""
    graph_database_port: int = ""
    graph_file_path = os.path.join(base_config.database_directory_path, graph_filename)

    model_config = SettingsConfigDict(env_file = ".env", extra = "allow")

    def to_dict(self) -> dict:
        return {
            "graph_filename": self.graph_filename,
            "graph_database_provider": self.graph_database_provider,
            "graph_topology": self.graph_topology,
            "graph_file_path": self.graph_file_path,
            "graph_database_url": self.graph_database_url,
            "graph_database_username": self.graph_database_username,
            "graph_database_password": self.graph_database_password,
            "graph_database_port": self.graph_database_port,
        }

@lru_cache
def get_graph_config():
    return GraphConfig()
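The new config module relies on the pydantic-settings pattern: field values can be overridden via environment variables or the .env file named in model_config, and get_graph_config() hands every caller the same cached instance. A minimal, self-contained sketch of that pattern under those assumptions (the class and function names below are illustrative, not cognee's):

# Self-contained illustration of the BaseSettings + lru_cache pattern used above.
import os
from functools import lru_cache
from pydantic_settings import BaseSettings, SettingsConfigDict

class DemoGraphConfig(BaseSettings):
    graph_filename: str = "cognee_graph.pkl"
    graph_database_url: str = ""

    model_config = SettingsConfigDict(env_file = ".env", extra = "allow")

@lru_cache
def get_demo_config():
    return DemoGraphConfig()

os.environ["GRAPH_DATABASE_URL"] = "bolt://localhost:7687"  # overrides the field by name (case-insensitive)
print(get_demo_config().graph_database_url)                 # "bolt://localhost:7687"
print(get_demo_config() is get_demo_config())               # True: lru_cache returns a single shared instance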
@@ -1,18 +1,14 @@
"""Factory function to get the appropriate graph client based on the graph type."""

from cognee.config import Config
from cognee.shared.data_models import GraphDBType
from cognee.infrastructure import infrastructure_config
from .config import get_graph_config
from .graph_db_interface import GraphDBInterface
from .networkx.adapter import NetworkXAdapter

config = Config()
config.load()
config = get_graph_config()


async def get_graph_client(graph_type: GraphDBType, graph_file_name: str = None) -> GraphDBInterface :
    """Factory function to get the appropriate graph client based on the graph type."""
    graph_file_path = f"{infrastructure_config.get_config('database_directory_path')}/{graph_file_name if graph_file_name else config.graph_filename}"

    if graph_type == GraphDBType.NEO4J:
        try:
@@ -25,10 +21,20 @@ async def get_graph_client(graph_type: GraphDBType, graph_file_name: str = None)
            )
        except:
            pass

        graph_client = NetworkXAdapter(filename = graph_file_path)

    elif graph_type == GraphDBType.FALKORDB:
        try:
            from .falkordb.adapter import FalcorDBAdapter

            return FalcorDBAdapter(
                graph_database_url = config.graph_database_url,
                graph_database_username = config.graph_database_username,
                graph_database_password = config.graph_database_password,
                graph_database_port = config.graph_database_port
            )
        except:
            pass

    graph_client = NetworkXAdapter(filename = config.graph_file_path)

    if (graph_client.graph is None):
        await graph_client.load_graph_from_file()
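With the new config object wired in, the factory is awaited with a GraphDBType member; because the except/pass blocks are bare, a failed Neo4j or FalkorDB import or construction silently falls through to the file-backed NetworkX adapter. A usage sketch under those assumptions (the import path for get_graph_client and the asyncio wrapper are illustrative, not shown in this diff):

# Illustrative call site for the factory above; only get_graph_client and
# GraphDBType come from this commit, the rest is scaffolding for the example.
import asyncio
from cognee.shared.data_models import GraphDBType
from cognee.infrastructure.databases.graph import get_graph_client  # assumed import path

async def main():
    graph_client = await get_graph_client(GraphDBType.FALKORDB)
    # If the FalkorDB branch raises, execution falls through to
    # NetworkXAdapter(filename = config.graph_file_path) and loads the local graph file.
    return graph_client

asyncio.run(main())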
@@ -56,12 +56,6 @@ class Neo4jAdapter(GraphDBInterface):

        if "name" not in serialized_properties:
            serialized_properties["name"] = node_id

        # serialized_properties["created_at"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        # serialized_properties["updated_at"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

        # properties = ", ".join(f"{property_name}: ${property_name}" for property_name in serialized_properties.keys())

        query = f"""MERGE (node:`{node_id}` {{id: $node_id}})
                ON CREATE SET node += $properties
                RETURN ID(node) AS internal_id, node.id AS nodeId"""
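Only the node label is interpolated into the Cypher string; the id and properties stay as driver-side parameters. A sketch of how the f-string above expands for an illustrative node (the params dict is an assumption, since the hunk cuts off before it):

# Illustrative expansion of the MERGE f-string above; node_id and the params
# dict are example values, not taken from the commit.
node_id = "document_1"
serialized_properties = {"name": node_id}

query = f"""MERGE (node:`{node_id}` {{id: $node_id}})
        ON CREATE SET node += $properties
        RETURN ID(node) AS internal_id, node.id AS nodeId"""

params = {"node_id": node_id, "properties": serialized_properties}
# query now reads: MERGE (node:`document_1` {id: $node_id}) ...
# $node_id and $properties are resolved by the driver, so property values are
# never spliced into the Cypher text.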
@@ -85,30 +79,6 @@ class Neo4jAdapter(GraphDBInterface):
            node_properties = node_properties,
        )

        # serialized_properties = self.serialize_properties(node_properties)

        # if "name" not in serialized_properties:
        #     serialized_properties["name"] = node_id

        # nodes_data.append({
        #     "node_id": node_id,
        #     "properties": serialized_properties,
        # })

        # query = """UNWIND $nodes_data AS node_data
        #     MERGE (node:{id: node_data.node_id})
        #     ON CREATE SET node += node_data.properties
        #     RETURN ID(node) AS internal_id, node.id AS id"""

        # params = {"nodes_data": nodes_data}

        # result = await self.query(query, params)

        # await self.close()

        # return result

    async def extract_node_description(self, node_id: str):
        query = """MATCH (n)-[r]->(m)
            WHERE n.id = $node_id
@@ -138,7 +108,7 @@ class Neo4jAdapter(GraphDBInterface):
        query = """MATCH (node) WHERE node.layer_id IS NOT NULL
            RETURN node"""

        return [result['node'] for result in (await self.query(query))]
        return [result["node"] for result in (await self.query(query))]

    async def extract_node(self, node_id: str):
        query= """
@@ -146,7 +116,7 @@ class Neo4jAdapter(GraphDBInterface):
            RETURN node
        """

        results = [node['node'] for node in (await self.query(query, dict(node_id = node_id)))]
        results = [node["node"] for node in (await self.query(query, dict(node_id = node_id)))]

        return results[0] if len(results) > 0 else None
@@ -163,10 +133,12 @@ class Neo4jAdapter(GraphDBInterface):
        from_node = from_node.replace(":", "_")
        to_node = to_node.replace(":", "_")

        query = f"""MATCH (from_node:`{from_node}` {{id: $from_node}}), (to_node:`{to_node}` {{id: $to_node}})
        MERGE (from_node)-[r:`{relationship_name}`]->(to_node)
        SET r += $properties
        RETURN r"""
        query = f"""MATCH (from_node:`{from_node}`
            {{id: $from_node}}),
            (to_node:`{to_node}` {{id: $to_node}})
        MERGE (from_node)-[r:`{relationship_name}`]->(to_node)
        SET r += $properties
        RETURN r"""

        params = {
            "from_node": from_node,
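The rewrapped MATCH/MERGE keeps the same placeholders ($from_node, $to_node, $properties), so the parameter dict it needs is unchanged; the hunk only shows its first key, and the sketch below fills in the rest as an assumption:

# Assumed shape of the parameter dict for the wrapped query above; only
# "from_node" is visible in the hunk, the other keys mirror the $-placeholders.
from_node, to_node = "document_1", "chunk_7"   # illustrative ids
edge_properties = {"weight": 1.0}              # illustrative edge properties

params = {
    "from_node": from_node,
    "to_node": to_node,
    "properties": edge_properties,
}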
@@ -192,30 +164,6 @@ class Neo4jAdapter(GraphDBInterface):
            edge_properties = edge_properties
        )

        # Filter out None values and do not serialize; Neo4j can handle complex types like arrays directly
        # serialized_properties = self.serialize_properties(edge_properties)

        # edges_data.append({
        #     "from_node": from_node,
        #     "to_node": to_node,
        #     "relationship_name": relationship_name,
        #     "properties": serialized_properties
        # })

        # query = """UNWIND $edges_data AS edge_data
        #     MATCH (from_node:{id: edge_data.from_node}), (to_node:{id: edge_data.to_node})
        #     MERGE (from_node)-[r:{edge_data.relationship_name}]->(to_node)
        #     ON CREATE SET r += edge_data.properties
        #     RETURN r"""

        # params = {"edges_data": edges_data}

        # result = await self.query(query, params)

        # await self.close()

        # return result


    async def filter_nodes(self, search_criteria):
        query = f"""MATCH (node)
@@ -206,6 +206,7 @@ class DefaultCognitiveLayer(BaseModel):
class GraphDBType(Enum):
    NETWORKX = auto()
    NEO4J = auto()
    FALKORDB = auto()


# Models for representing different entities
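Because the members use auto(), the new FALKORDB value simply slots in after the existing ones; a small self-contained sketch of what that means for lookups (standard enum behaviour, not specific to this commit):

# auto() numbers members in declaration order, so FALKORDB becomes the third value.
from enum import Enum, auto

class GraphDBType(Enum):        # mirrors the class above for a self-contained example
    NETWORKX = auto()
    NEO4J = auto()
    FALKORDB = auto()

assert GraphDBType.FALKORDB.value == 3
assert GraphDBType["FALKORDB"] is GraphDBType.FALKORDB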