add test for linter

This commit is contained in:
Vasilije 2024-05-26 08:37:23 +02:00
parent 630588bd46
commit 59feaa3e4e
4 changed files with 60 additions and 68 deletions

View file

@ -0,0 +1,37 @@
""" This module contains the configuration for the graph database. """
import os
from functools import lru_cache
from pydantic_settings import BaseSettings, SettingsConfigDict
from cognee.base_config import get_base_config
from cognee.shared.data_models import DefaultGraphModel
base_config = get_base_config()
class GraphConfig(BaseSettings):
graph_filename: str = "cognee_graph.pkl"
graph_database_provider: str = "NETWORKX"
graph_topology: str = DefaultGraphModel
graph_database_url: str = ""
graph_database_username: str = ""
graph_database_password: str = ""
graph_database_port: int = ""
graph_file_path = os.path.join(base_config.database_directory_path,graph_filename)
model_config = SettingsConfigDict(env_file = ".env", extra = "allow")
def to_dict(self) -> dict:
return {
"graph_filename": self.graph_filename,
"graph_database_provider": self.graph_database_provider,
"graph_topology": self.graph_topology,
"graph_file_path": self.graph_file_path,
"graph_database_url": self.graph_database_url,
"graph_database_username": self.graph_database_username,
"graph_database_password": self.graph_database_password,
"graph_database_port": self.graph_database_port,
}
@lru_cache
def get_graph_config():
return GraphConfig()

View file

@ -1,18 +1,14 @@
"""Factory function to get the appropriate graph client based on the graph type."""
from cognee.config import Config
from cognee.shared.data_models import GraphDBType
from cognee.infrastructure import infrastructure_config
from .config import get_graph_config
from .graph_db_interface import GraphDBInterface
from .networkx.adapter import NetworkXAdapter
config = Config()
config.load()
config = get_graph_config()
async def get_graph_client(graph_type: GraphDBType, graph_file_name: str = None) -> GraphDBInterface :
"""Factory function to get the appropriate graph client based on the graph type."""
graph_file_path = f"{infrastructure_config.get_config('database_directory_path')}/{graph_file_name if graph_file_name else config.graph_filename}"
if graph_type == GraphDBType.NEO4J:
try:
@ -25,10 +21,20 @@ async def get_graph_client(graph_type: GraphDBType, graph_file_name: str = None)
)
except:
pass
graph_client = NetworkXAdapter(filename = graph_file_path)
elif graph_type == GraphDBType.FALKORDB:
try:
from .falkordb.adapter import FalcorDBAdapter
return FalcorDBAdapter(
graph_database_url = config.graph_database_url,
graph_database_username = config.graph_database_username,
graph_database_password = config.graph_database_password,
graph_database_port = config.graph_database_port
)
except:
pass
graph_client = NetworkXAdapter(filename = config.graph_file_path)
if (graph_client.graph is None):
await graph_client.load_graph_from_file()

View file

@ -56,12 +56,6 @@ class Neo4jAdapter(GraphDBInterface):
if "name" not in serialized_properties:
serialized_properties["name"] = node_id
# serialized_properties["created_at"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
# serialized_properties["updated_at"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
# properties = ", ".join(f"{property_name}: ${property_name}" for property_name in serialized_properties.keys())
query = f"""MERGE (node:`{node_id}` {{id: $node_id}})
ON CREATE SET node += $properties
RETURN ID(node) AS internal_id, node.id AS nodeId"""
@ -85,30 +79,6 @@ class Neo4jAdapter(GraphDBInterface):
node_properties = node_properties,
)
# serialized_properties = self.serialize_properties(node_properties)
# if "name" not in serialized_properties:
# serialized_properties["name"] = node_id
# nodes_data.append({
# "node_id": node_id,
# "properties": serialized_properties,
# })
# query = """UNWIND $nodes_data AS node_data
# MERGE (node:{id: node_data.node_id})
# ON CREATE SET node += node_data.properties
# RETURN ID(node) AS internal_id, node.id AS id"""
# params = {"nodes_data": nodes_data}
# result = await self.query(query, params)
# await self.close()
# return result
async def extract_node_description(self, node_id: str):
query = """MATCH (n)-[r]->(m)
WHERE n.id = $node_id
@ -138,7 +108,7 @@ class Neo4jAdapter(GraphDBInterface):
query = """MATCH (node) WHERE node.layer_id IS NOT NULL
RETURN node"""
return [result['node'] for result in (await self.query(query))]
return [result["node"] for result in (await self.query(query))]
async def extract_node(self, node_id: str):
query= """
@ -146,7 +116,7 @@ class Neo4jAdapter(GraphDBInterface):
RETURN node
"""
results = [node['node'] for node in (await self.query(query, dict(node_id = node_id)))]
results = [node["node"] for node in (await self.query(query, dict(node_id = node_id)))]
return results[0] if len(results) > 0 else None
@ -163,10 +133,12 @@ class Neo4jAdapter(GraphDBInterface):
from_node = from_node.replace(":", "_")
to_node = to_node.replace(":", "_")
query = f"""MATCH (from_node:`{from_node}` {{id: $from_node}}), (to_node:`{to_node}` {{id: $to_node}})
MERGE (from_node)-[r:`{relationship_name}`]->(to_node)
SET r += $properties
RETURN r"""
query = f"""MATCH (from_node:`{from_node}`
{{id: $from_node}}),
(to_node:`{to_node}` {{id: $to_node}})
MERGE (from_node)-[r:`{relationship_name}`]->(to_node)
SET r += $properties
RETURN r"""
params = {
"from_node": from_node,
@ -192,30 +164,6 @@ class Neo4jAdapter(GraphDBInterface):
edge_properties = edge_properties
)
# Filter out None values and do not serialize; Neo4j can handle complex types like arrays directly
# serialized_properties = self.serialize_properties(edge_properties)
# edges_data.append({
# "from_node": from_node,
# "to_node": to_node,
# "relationship_name": relationship_name,
# "properties": serialized_properties
# })
# query = """UNWIND $edges_data AS edge_data
# MATCH (from_node:{id: edge_data.from_node}), (to_node:{id: edge_data.to_node})
# MERGE (from_node)-[r:{edge_data.relationship_name}]->(to_node)
# ON CREATE SET r += edge_data.properties
# RETURN r"""
# params = {"edges_data": edges_data}
# result = await self.query(query, params)
# await self.close()
# return result
async def filter_nodes(self, search_criteria):
query = f"""MATCH (node)

View file

@ -206,6 +206,7 @@ class DefaultCognitiveLayer(BaseModel):
class GraphDBType(Enum):
NETWORKX = auto()
NEO4J = auto()
FALKORDB = auto()
# Models for representing different entities