add test for linter

This commit is contained in:
Vasilije 2024-05-26 08:37:23 +02:00
parent 630588bd46
commit 59feaa3e4e
4 changed files with 60 additions and 68 deletions

View file

@ -0,0 +1,37 @@
""" This module contains the configuration for the graph database. """
import os
from functools import lru_cache
from pydantic_settings import BaseSettings, SettingsConfigDict
from cognee.base_config import get_base_config
from cognee.shared.data_models import DefaultGraphModel
base_config = get_base_config()
class GraphConfig(BaseSettings):
graph_filename: str = "cognee_graph.pkl"
graph_database_provider: str = "NETWORKX"
graph_topology: str = DefaultGraphModel
graph_database_url: str = ""
graph_database_username: str = ""
graph_database_password: str = ""
graph_database_port: int = ""
graph_file_path = os.path.join(base_config.database_directory_path,graph_filename)
model_config = SettingsConfigDict(env_file = ".env", extra = "allow")
def to_dict(self) -> dict:
return {
"graph_filename": self.graph_filename,
"graph_database_provider": self.graph_database_provider,
"graph_topology": self.graph_topology,
"graph_file_path": self.graph_file_path,
"graph_database_url": self.graph_database_url,
"graph_database_username": self.graph_database_username,
"graph_database_password": self.graph_database_password,
"graph_database_port": self.graph_database_port,
}
@lru_cache
def get_graph_config():
return GraphConfig()

View file

@ -1,18 +1,14 @@
"""Factory function to get the appropriate graph client based on the graph type.""" """Factory function to get the appropriate graph client based on the graph type."""
from cognee.config import Config
from cognee.shared.data_models import GraphDBType from cognee.shared.data_models import GraphDBType
from cognee.infrastructure import infrastructure_config from .config import get_graph_config
from .graph_db_interface import GraphDBInterface from .graph_db_interface import GraphDBInterface
from .networkx.adapter import NetworkXAdapter from .networkx.adapter import NetworkXAdapter
config = get_graph_config()
config = Config()
config.load()
async def get_graph_client(graph_type: GraphDBType, graph_file_name: str = None) -> GraphDBInterface : async def get_graph_client(graph_type: GraphDBType, graph_file_name: str = None) -> GraphDBInterface :
"""Factory function to get the appropriate graph client based on the graph type.""" """Factory function to get the appropriate graph client based on the graph type."""
graph_file_path = f"{infrastructure_config.get_config('database_directory_path')}/{graph_file_name if graph_file_name else config.graph_filename}"
if graph_type == GraphDBType.NEO4J: if graph_type == GraphDBType.NEO4J:
try: try:
@ -25,10 +21,20 @@ async def get_graph_client(graph_type: GraphDBType, graph_file_name: str = None)
) )
except: except:
pass pass
graph_client = NetworkXAdapter(filename = graph_file_path)
elif graph_type == GraphDBType.FALKORDB:
try:
from .falkordb.adapter import FalcorDBAdapter
return FalcorDBAdapter(
graph_database_url = config.graph_database_url,
graph_database_username = config.graph_database_username,
graph_database_password = config.graph_database_password,
graph_database_port = config.graph_database_port
)
except:
pass
graph_client = NetworkXAdapter(filename = config.graph_file_path)
if (graph_client.graph is None): if (graph_client.graph is None):
await graph_client.load_graph_from_file() await graph_client.load_graph_from_file()

View file

@ -56,12 +56,6 @@ class Neo4jAdapter(GraphDBInterface):
if "name" not in serialized_properties: if "name" not in serialized_properties:
serialized_properties["name"] = node_id serialized_properties["name"] = node_id
# serialized_properties["created_at"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
# serialized_properties["updated_at"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
# properties = ", ".join(f"{property_name}: ${property_name}" for property_name in serialized_properties.keys())
query = f"""MERGE (node:`{node_id}` {{id: $node_id}}) query = f"""MERGE (node:`{node_id}` {{id: $node_id}})
ON CREATE SET node += $properties ON CREATE SET node += $properties
RETURN ID(node) AS internal_id, node.id AS nodeId""" RETURN ID(node) AS internal_id, node.id AS nodeId"""
@ -85,30 +79,6 @@ class Neo4jAdapter(GraphDBInterface):
node_properties = node_properties, node_properties = node_properties,
) )
# serialized_properties = self.serialize_properties(node_properties)
# if "name" not in serialized_properties:
# serialized_properties["name"] = node_id
# nodes_data.append({
# "node_id": node_id,
# "properties": serialized_properties,
# })
# query = """UNWIND $nodes_data AS node_data
# MERGE (node:{id: node_data.node_id})
# ON CREATE SET node += node_data.properties
# RETURN ID(node) AS internal_id, node.id AS id"""
# params = {"nodes_data": nodes_data}
# result = await self.query(query, params)
# await self.close()
# return result
async def extract_node_description(self, node_id: str): async def extract_node_description(self, node_id: str):
query = """MATCH (n)-[r]->(m) query = """MATCH (n)-[r]->(m)
WHERE n.id = $node_id WHERE n.id = $node_id
@ -138,7 +108,7 @@ class Neo4jAdapter(GraphDBInterface):
query = """MATCH (node) WHERE node.layer_id IS NOT NULL query = """MATCH (node) WHERE node.layer_id IS NOT NULL
RETURN node""" RETURN node"""
return [result['node'] for result in (await self.query(query))] return [result["node"] for result in (await self.query(query))]
async def extract_node(self, node_id: str): async def extract_node(self, node_id: str):
query= """ query= """
@ -146,7 +116,7 @@ class Neo4jAdapter(GraphDBInterface):
RETURN node RETURN node
""" """
results = [node['node'] for node in (await self.query(query, dict(node_id = node_id)))] results = [node["node"] for node in (await self.query(query, dict(node_id = node_id)))]
return results[0] if len(results) > 0 else None return results[0] if len(results) > 0 else None
@ -163,10 +133,12 @@ class Neo4jAdapter(GraphDBInterface):
from_node = from_node.replace(":", "_") from_node = from_node.replace(":", "_")
to_node = to_node.replace(":", "_") to_node = to_node.replace(":", "_")
query = f"""MATCH (from_node:`{from_node}` {{id: $from_node}}), (to_node:`{to_node}` {{id: $to_node}}) query = f"""MATCH (from_node:`{from_node}`
MERGE (from_node)-[r:`{relationship_name}`]->(to_node) {{id: $from_node}}),
SET r += $properties (to_node:`{to_node}` {{id: $to_node}})
RETURN r""" MERGE (from_node)-[r:`{relationship_name}`]->(to_node)
SET r += $properties
RETURN r"""
params = { params = {
"from_node": from_node, "from_node": from_node,
@ -192,30 +164,6 @@ class Neo4jAdapter(GraphDBInterface):
edge_properties = edge_properties edge_properties = edge_properties
) )
# Filter out None values and do not serialize; Neo4j can handle complex types like arrays directly
# serialized_properties = self.serialize_properties(edge_properties)
# edges_data.append({
# "from_node": from_node,
# "to_node": to_node,
# "relationship_name": relationship_name,
# "properties": serialized_properties
# })
# query = """UNWIND $edges_data AS edge_data
# MATCH (from_node:{id: edge_data.from_node}), (to_node:{id: edge_data.to_node})
# MERGE (from_node)-[r:{edge_data.relationship_name}]->(to_node)
# ON CREATE SET r += edge_data.properties
# RETURN r"""
# params = {"edges_data": edges_data}
# result = await self.query(query, params)
# await self.close()
# return result
async def filter_nodes(self, search_criteria): async def filter_nodes(self, search_criteria):
query = f"""MATCH (node) query = f"""MATCH (node)

View file

@ -206,6 +206,7 @@ class DefaultCognitiveLayer(BaseModel):
class GraphDBType(Enum): class GraphDBType(Enum):
NETWORKX = auto() NETWORKX = auto()
NEO4J = auto() NEO4J = auto()
FALKORDB = auto()
# Models for representing different entities # Models for representing different entities