diff --git a/cognee/infrastructure/databases/graph/kuzu/adapter.py b/cognee/infrastructure/databases/graph/kuzu/adapter.py index e550889de..a3430d370 100644 --- a/cognee/infrastructure/databases/graph/kuzu/adapter.py +++ b/cognee/infrastructure/databases/graph/kuzu/adapter.py @@ -145,7 +145,7 @@ class KuzuAdapter(GraphDBInterface): except Exception as e: logger.error(f"Failed to initialize Kuzu database: {e}") raise e - + def _get_connection(self) -> Connection: """Get the connection to the Kuzu database.""" if not self.connection: @@ -173,7 +173,9 @@ class KuzuAdapter(GraphDBInterface): except FileNotFoundError: logger.warning(f"Kuzu S3 storage file not found: {self.db_path}") - async def query(self, query: str, params: Optional[Dict[str, Any]] = None) -> List[Tuple[Any, ...]]: + async def query( + self, query: str, params: Optional[Dict[str, Any]] = None + ) -> List[Tuple[Any, ...]]: """ Execute a Kuzu query asynchronously with automatic reconnection. @@ -221,7 +223,7 @@ class KuzuAdapter(GraphDBInterface): val = val.as_py() processed_rows.append(val) rows.append(tuple(processed_rows)) - + return rows except Exception as e: logger.error(f"Query execution failed: {str(e)}") @@ -320,7 +322,9 @@ class KuzuAdapter(GraphDBInterface): result = await self.query(query_str, {"id": node_id}) return result[0][0] if result else False - async def add_node(self, node: Union[DataPoint, str], properties: Optional[Dict[str, Any]] = None) -> None: + async def add_node( + self, node: Union[DataPoint, str], properties: Optional[Dict[str, Any]] = None + ) -> None: """ Add a single node to the graph if it doesn't exist. @@ -343,8 +347,9 @@ class KuzuAdapter(GraphDBInterface): "type": str(node_properties.get("type", "")), } # Use the passed properties, excluding core fields - other_properties = {k: v for k, v in node_properties.items() - if k not in ["id", "name", "type"]} + other_properties = { + k: v for k, v in node_properties.items() if k not in ["id", "name", "type"] + } else: # Handle DataPoint object node_properties = node.model_dump() @@ -354,8 +359,9 @@ class KuzuAdapter(GraphDBInterface): "type": str(node_properties.get("type", "")), } # Remove core fields from other properties - other_properties = {k: v for k, v in node_properties.items() - if k not in ["id", "name", "type"]} + other_properties = { + k: v for k, v in node_properties.items() if k not in ["id", "name", "type"] + } core_properties["properties"] = json.dumps(other_properties, cls=JSONEncoder) @@ -593,7 +599,9 @@ class KuzuAdapter(GraphDBInterface): ) return result[0][0] if result else False - async def has_edges(self, edges: List[Tuple[str, str, str, Dict[str, Any]]]) -> List[Tuple[str, str, str, Dict[str, Any]]]: + async def has_edges( + self, edges: List[Tuple[str, str, str, Dict[str, Any]]] + ) -> List[Tuple[str, str, str, Dict[str, Any]]]: """ Check if multiple edges exist in a batch operation. @@ -793,7 +801,7 @@ class KuzuAdapter(GraphDBInterface): relationship_name = row[1] target_node = self._parse_node_properties(row[2]) # TODO: any edge properties we can add? Adding empty to avoid modifying query without reason - edges.append((source_node, relationship_name, target_node, {})) # type: ignore # currently each node is a dict, wihle typing expects nodes to be strings + edges.append((source_node, relationship_name, target_node, {})) # type: ignore # currently each node is a dict, wihle typing expects nodes to be strings return edges except Exception as e: logger.error(f"Failed to get edges for node {node_id}: {e}") @@ -1362,8 +1370,8 @@ class KuzuAdapter(GraphDBInterface): try: # Get basic graph data nodes, edges = await self.get_model_independent_graph_data() - num_nodes = len(nodes[0]["nodes"]) if nodes else 0 # type: ignore # nodes is type string? - num_edges = len(edges[0]["elements"]) if edges else 0 # type: ignore # edges is type string? + num_nodes = len(nodes[0]["nodes"]) if nodes else 0 # type: ignore # nodes is type string? + num_edges = len(edges[0]["elements"]) if edges else 0 # type: ignore # edges is type string? # Calculate mandatory metrics mandatory_metrics = { @@ -1571,10 +1579,10 @@ class KuzuAdapter(GraphDBInterface): # Reinitialize the database self._initialize_connection() - + if not self._get_connection(): raise RuntimeError("Failed to establish database connection") - + # Verify the database is empty result = self._get_connection().execute("MATCH (n:Node) RETURN COUNT(n)") if not isinstance(result, list): diff --git a/cognee/infrastructure/databases/graph/kuzu/remote_kuzu_adapter.py b/cognee/infrastructure/databases/graph/kuzu/remote_kuzu_adapter.py index 60b19dc30..7dcb5e2a6 100644 --- a/cognee/infrastructure/databases/graph/kuzu/remote_kuzu_adapter.py +++ b/cognee/infrastructure/databases/graph/kuzu/remote_kuzu_adapter.py @@ -73,13 +73,15 @@ class RemoteKuzuAdapter(KuzuAdapter): status=response.status, message=error_detail, ) - return await response.json() # type: ignore + return await response.json() # type: ignore except aiohttp.ClientError as e: logger.error(f"API request failed: {str(e)}") logger.error(f"Request data: {data}") raise - async def query(self, query: str, params: Optional[dict[str, Any]] = None) -> List[Tuple[Any, ...]]: + async def query( + self, query: str, params: Optional[dict[str, Any]] = None + ) -> List[Tuple[Any, ...]]: """Execute a Kuzu query via the REST API.""" try: # Initialize schema if needed diff --git a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py index 9284ebbbb..58f859576 100644 --- a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +++ b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py @@ -142,7 +142,9 @@ class Neo4jAdapter(GraphDBInterface): ) return results[0]["node_exists"] if len(results) > 0 else False - async def add_node(self, node: Union[DataPoint, str], properties: Optional[Dict[str, Any]] = None) -> None: + async def add_node( + self, node: Union[DataPoint, str], properties: Optional[Dict[str, Any]] = None + ) -> None: """ Add a new node to the database based on the provided DataPoint object or string ID. @@ -407,7 +409,7 @@ class Neo4jAdapter(GraphDBInterface): ) params = { - "from_node": str(source_id), # Adding str as callsites may still be passing UUID + "from_node": str(source_id), # Adding str as callsites may still be passing UUID "to_node": str(target_id), "relationship_name": relationship_name, "properties": serialized_properties, @@ -482,8 +484,8 @@ class Neo4jAdapter(GraphDBInterface): edge_params = [ { - "from_node": str(edge[0]), # Adding str as callsites may still be passing UUID - "to_node": str(edge[1]), # Adding str as callsites may still be passing UUID + "from_node": str(edge[0]), # Adding str as callsites may still be passing UUID + "to_node": str(edge[1]), # Adding str as callsites may still be passing UUID "relationship_name": edge[2], "properties": self._flatten_edge_properties( { diff --git a/cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py b/cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py index e89ac3193..7ac30d53e 100644 --- a/cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py +++ b/cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py @@ -217,7 +217,9 @@ class ChromaDBAdapter(VectorDBInterface): collections = await self.get_collection_names() return collection_name in collections - async def create_collection(self, collection_name: str, payload_schema: Optional[Any] = None) -> None: + async def create_collection( + self, collection_name: str, payload_schema: Optional[Any] = None + ) -> None: """ Create a new collection in ChromaDB if it does not already exist. @@ -313,10 +315,7 @@ class ChromaDBAdapter(VectorDBInterface): [ IndexSchema( id=data_point.id, - text=getattr( - data_point, - data_point.metadata["index_fields"][0] - ), + text=getattr(data_point, data_point.metadata["index_fields"][0]), ) for data_point in data_points if data_point.metadata and len(data_point.metadata["index_fields"]) > 0 diff --git a/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py b/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py index c5f0bf0b6..4b9059bb9 100644 --- a/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +++ b/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py @@ -4,7 +4,18 @@ from uuid import UUID import lancedb from pydantic import BaseModel from lancedb.pydantic import LanceModel, Vector -from typing import Generic, List, Optional, TypeVar, Union, get_args, get_origin, get_type_hints, Dict, Any +from typing import ( + Generic, + List, + Optional, + TypeVar, + Union, + get_args, + get_origin, + get_type_hints, + Dict, + Any, +) from cognee.infrastructure.databases.exceptions import MissingQueryParameterError from cognee.infrastructure.engine import DataPoint @@ -46,7 +57,7 @@ class LanceDBAdapter(VectorDBInterface): def __init__( self, - url: Optional[str], # TODO: consider if we want to make this required and/or api_key + url: Optional[str], # TODO: consider if we want to make this required and/or api_key api_key: Optional[str], embedding_engine: EmbeddingEngine, ): @@ -109,7 +120,9 @@ class LanceDBAdapter(VectorDBInterface): collection_names = await connection.table_names() return collection_name in collection_names - async def create_collection(self, collection_name: str, payload_schema: Optional[Any] = None) -> None: + async def create_collection( + self, collection_name: str, payload_schema: Optional[Any] = None + ) -> None: vector_size = self.embedding_engine.get_vector_size() class LanceDataPoint(LanceModel): @@ -123,7 +136,9 @@ class LanceDBAdapter(VectorDBInterface): """ id: UUID - vector: Vector[vector_size] # TODO: double check and consider raising this later in Pydantic + vector: Vector[ + vector_size + ] # TODO: double check and consider raising this later in Pydantic payload: Dict[str, Any] if not await self.has_collection(collection_name): @@ -300,10 +315,7 @@ class LanceDBAdapter(VectorDBInterface): [ IndexSchema( id=data_point.id, - text=getattr( - data_point, - data_point.metadata["index_fields"][0] - ), + text=getattr(data_point, data_point.metadata["index_fields"][0]), ) for data_point in data_points if data_point.metadata and len(data_point.metadata.get("index_fields", [])) > 0 @@ -327,7 +339,7 @@ class LanceDBAdapter(VectorDBInterface): def get_data_point_schema(self, model_type: Optional[Any]) -> Any: if model_type is None: return DataPoint - + related_models_fields = [] for field_name, field_config in model_type.model_fields.items(): diff --git a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py index afe60dc64..03318992e 100644 --- a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py +++ b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py @@ -123,7 +123,9 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): stop=stop_after_attempt(5), wait=wait_exponential(multiplier=2, min=1, max=6), ) - async def create_collection(self, collection_name: str, payload_schema: Optional[Any] = None) -> None: + async def create_collection( + self, collection_name: str, payload_schema: Optional[Any] = None + ) -> None: vector_size = self.embedding_engine.get_vector_size() async with self.VECTOR_DB_LOCK: @@ -151,7 +153,9 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): payload: Mapped[Dict[str, Any]] = mapped_column(JSON) vector: Mapped[List[float]] = mapped_column(self.Vector(vector_size)) - def __init__(self, id: str, payload: Dict[str, Any], vector: List[float]) -> None: + def __init__( + self, id: str, payload: Dict[str, Any], vector: List[float] + ) -> None: self.id = id self.payload = payload self.vector = vector @@ -159,10 +163,9 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): async with self.engine.begin() as connection: if len(Base.metadata.tables.keys()) > 0: from sqlalchemy import Table + table: Table = PGVectorDataPoint.__table__ # type: ignore - await connection.run_sync( - Base.metadata.create_all, tables=[table] - ) + await connection.run_sync(Base.metadata.create_all, tables=[table]) @retry( retry=retry_if_exception_type(DeadlockDetectedError), @@ -351,11 +354,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): # Create and return ScoredResult objects return [ - ScoredResult( - id=row["id"], - payload=row["payload"] or {}, - score=row["score"] - ) + ScoredResult(id=row["id"], payload=row["payload"] or {}, score=row["score"]) for row in vector_list ]