ruff format

Author: Daulet Amirkhanov, 2025-09-04 18:13:06 +01:00
Parent: 0e0bf9a00d
Commit: c7b0da7aa6
6 changed files with 66 additions and 44 deletions
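
The hunks below are typical "ruff format" output: statements that overflow the configured line length are wrapped inside their existing parentheses or brackets, constructs that now fit are collapsed back onto one line, and inline comments get two spaces before the "#". A minimal before/after sketch of the signature case (the class name and body are placeholders):

from typing import Any, Dict, List, Optional, Tuple

class Example:
    # Before: the whole signature sat on one line, past the length limit.
    # async def query(self, query: str, params: Optional[Dict[str, Any]] = None) -> List[Tuple[Any, ...]]: ...

    # After ruff format: parameters move onto an indented continuation line.
    async def query(
        self, query: str, params: Optional[Dict[str, Any]] = None
    ) -> List[Tuple[Any, ...]]:
        return []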


@@ -145,7 +145,7 @@ class KuzuAdapter(GraphDBInterface):
except Exception as e:
logger.error(f"Failed to initialize Kuzu database: {e}")
raise e
def _get_connection(self) -> Connection:
"""Get the connection to the Kuzu database."""
if not self.connection:
@@ -173,7 +173,9 @@ class KuzuAdapter(GraphDBInterface):
except FileNotFoundError:
logger.warning(f"Kuzu S3 storage file not found: {self.db_path}")
-    async def query(self, query: str, params: Optional[Dict[str, Any]] = None) -> List[Tuple[Any, ...]]:
+    async def query(
+        self, query: str, params: Optional[Dict[str, Any]] = None
+    ) -> List[Tuple[Any, ...]]:
"""
Execute a Kuzu query asynchronously with automatic reconnection.
@@ -221,7 +223,7 @@ class KuzuAdapter(GraphDBInterface):
val = val.as_py()
processed_rows.append(val)
rows.append(tuple(processed_rows))
return rows
except Exception as e:
logger.error(f"Query execution failed: {str(e)}")
@@ -320,7 +322,9 @@ class KuzuAdapter(GraphDBInterface):
result = await self.query(query_str, {"id": node_id})
return result[0][0] if result else False
-    async def add_node(self, node: Union[DataPoint, str], properties: Optional[Dict[str, Any]] = None) -> None:
+    async def add_node(
+        self, node: Union[DataPoint, str], properties: Optional[Dict[str, Any]] = None
+    ) -> None:
"""
Add a single node to the graph if it doesn't exist.
@@ -343,8 +347,9 @@ class KuzuAdapter(GraphDBInterface):
"type": str(node_properties.get("type", "")),
}
# Use the passed properties, excluding core fields
-            other_properties = {k: v for k, v in node_properties.items()
-                                if k not in ["id", "name", "type"]}
+            other_properties = {
+                k: v for k, v in node_properties.items() if k not in ["id", "name", "type"]
+            }
else:
# Handle DataPoint object
node_properties = node.model_dump()
@@ -354,8 +359,9 @@ class KuzuAdapter(GraphDBInterface):
"type": str(node_properties.get("type", "")),
}
# Remove core fields from other properties
-            other_properties = {k: v for k, v in node_properties.items()
-                                if k not in ["id", "name", "type"]}
+            other_properties = {
+                k: v for k, v in node_properties.items() if k not in ["id", "name", "type"]
+            }
core_properties["properties"] = json.dumps(other_properties, cls=JSONEncoder)
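
Both branches of add_node reduce to the pattern the formatter is re-wrapping here: keep the fixed columns, JSON-encode everything else into a single "properties" field. A standalone sketch of that split, using the stdlib encoder in place of the project's JSONEncoder:

import json

node_properties = {"id": "n1", "name": "Example", "type": "Entity", "weight": 0.5}

core_properties = {
    "id": str(node_properties["id"]),
    "name": str(node_properties.get("name", "")),
    "type": str(node_properties.get("type", "")),
}
# Anything that is not a core field travels in one JSON column.
other_properties = {
    k: v for k, v in node_properties.items() if k not in ["id", "name", "type"]
}
core_properties["properties"] = json.dumps(other_properties)
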
@@ -593,7 +599,9 @@ class KuzuAdapter(GraphDBInterface):
)
return result[0][0] if result else False
-    async def has_edges(self, edges: List[Tuple[str, str, str, Dict[str, Any]]]) -> List[Tuple[str, str, str, Dict[str, Any]]]:
+    async def has_edges(
+        self, edges: List[Tuple[str, str, str, Dict[str, Any]]]
+    ) -> List[Tuple[str, str, str, Dict[str, Any]]]:
"""
Check if multiple edges exist in a batch operation.
@@ -793,7 +801,7 @@ class KuzuAdapter(GraphDBInterface):
relationship_name = row[1]
target_node = self._parse_node_properties(row[2])
# TODO: any edge properties we can add? Adding empty to avoid modifying query without reason
-                edges.append((source_node, relationship_name, target_node, {})) # type: ignore # currently each node is a dict, while typing expects nodes to be strings
+                edges.append((source_node, relationship_name, target_node, {}))  # type: ignore # currently each node is a dict, while typing expects nodes to be strings
return edges
except Exception as e:
logger.error(f"Failed to get edges for node {node_id}: {e}")
@@ -1362,8 +1370,8 @@ class KuzuAdapter(GraphDBInterface):
try:
# Get basic graph data
nodes, edges = await self.get_model_independent_graph_data()
-            num_nodes = len(nodes[0]["nodes"]) if nodes else 0 # type: ignore # nodes is type string?
-            num_edges = len(edges[0]["elements"]) if edges else 0 # type: ignore # edges is type string?
+            num_nodes = len(nodes[0]["nodes"]) if nodes else 0  # type: ignore # nodes is type string?
+            num_edges = len(edges[0]["elements"]) if edges else 0  # type: ignore # edges is type string?
# Calculate mandatory metrics
mandatory_metrics = {
@@ -1571,10 +1579,10 @@ class KuzuAdapter(GraphDBInterface):
# Reinitialize the database
self._initialize_connection()
if not self._get_connection():
raise RuntimeError("Failed to establish database connection")
# Verify the database is empty
result = self._get_connection().execute("MATCH (n:Node) RETURN COUNT(n)")
if not isinstance(result, list):


@@ -73,13 +73,15 @@ class RemoteKuzuAdapter(KuzuAdapter):
status=response.status,
message=error_detail,
)
-                return await response.json() # type: ignore
+                return await response.json()  # type: ignore
except aiohttp.ClientError as e:
logger.error(f"API request failed: {str(e)}")
logger.error(f"Request data: {data}")
raise
-    async def query(self, query: str, params: Optional[dict[str, Any]] = None) -> List[Tuple[Any, ...]]:
+    async def query(
+        self, query: str, params: Optional[dict[str, Any]] = None
+    ) -> List[Tuple[Any, ...]]:
"""Execute a Kuzu query via the REST API."""
try:
# Initialize schema if needed
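
For context, the request handling around this hunk follows the usual aiohttp shape: raise with the response detail on HTTP errors, log and re-raise on client errors. A self-contained sketch; the RuntimeError and print stand in for the adapter's own error type and logger:

import aiohttp

async def post_query(session: aiohttp.ClientSession, url: str, data: dict) -> dict:
    try:
        async with session.post(url, json=data) as response:
            if response.status >= 400:
                error_detail = await response.text()
                raise RuntimeError(f"API error {response.status}: {error_detail}")
            return await response.json()
    except aiohttp.ClientError as e:
        # Mirror the adapter: log the failure and the request data, then re-raise.
        print(f"API request failed: {e}; request data: {data}")
        raise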


@@ -142,7 +142,9 @@ class Neo4jAdapter(GraphDBInterface):
)
return results[0]["node_exists"] if len(results) > 0 else False
-    async def add_node(self, node: Union[DataPoint, str], properties: Optional[Dict[str, Any]] = None) -> None:
+    async def add_node(
+        self, node: Union[DataPoint, str], properties: Optional[Dict[str, Any]] = None
+    ) -> None:
"""
Add a new node to the database based on the provided DataPoint object or string ID.
@@ -407,7 +409,7 @@ class Neo4jAdapter(GraphDBInterface):
)
        params = {
-            "from_node": str(source_id), # Adding str as callsites may still be passing UUID
+            "from_node": str(source_id),  # Adding str as callsites may still be passing UUID
"to_node": str(target_id),
"relationship_name": relationship_name,
"properties": serialized_properties,
@@ -482,8 +484,8 @@ class Neo4jAdapter(GraphDBInterface):
        edge_params = [
            {
-                "from_node": str(edge[0]), # Adding str as callsites may still be passing UUID
-                "to_node": str(edge[1]), # Adding str as callsites may still be passing UUID
+                "from_node": str(edge[0]),  # Adding str as callsites may still be passing UUID
+                "to_node": str(edge[1]),  # Adding str as callsites may still be passing UUID
"relationship_name": edge[2],
"properties": self._flatten_edge_properties(
{

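The "Adding str as callsites may still be passing UUID" comments describe a defensive normalization: parameters are coerced to strings so that UUID objects and already-stringified IDs behave identically inside the query. A minimal illustration:

import uuid

source_id = uuid.uuid4()  # a callsite may pass a UUID object...
target_id = str(uuid.uuid4())  # ...or an already-stringified ID

params = {
    "from_node": str(source_id),  # str() normalizes the UUID case
    "to_node": str(target_id),  # str() on a str is a no-op
}
assert all(isinstance(v, str) for v in params.values())
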

@@ -217,7 +217,9 @@ class ChromaDBAdapter(VectorDBInterface):
collections = await self.get_collection_names()
return collection_name in collections
-    async def create_collection(self, collection_name: str, payload_schema: Optional[Any] = None) -> None:
+    async def create_collection(
+        self, collection_name: str, payload_schema: Optional[Any] = None
+    ) -> None:
"""
Create a new collection in ChromaDB if it does not already exist.
@@ -313,10 +315,7 @@ class ChromaDBAdapter(VectorDBInterface):
[
IndexSchema(
id=data_point.id,
-                    text=getattr(
-                        data_point,
-                        data_point.metadata["index_fields"][0]
-                    ),
+                    text=getattr(data_point, data_point.metadata["index_fields"][0]),
)
for data_point in data_points
if data_point.metadata and len(data_point.metadata["index_fields"]) > 0

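The collapsed getattr call resolves the first configured index field of each data point dynamically. A small sketch with a stand-in class:

class DataPointStub:
    """Stand-in for a data point with one indexable text field."""

    def __init__(self, text: str) -> None:
        self.metadata = {"index_fields": ["text"]}
        self.text = text

data_point = DataPointStub("hello")
field_name = data_point.metadata["index_fields"][0]  # -> "text"
value = getattr(data_point, field_name)  # -> "hello"
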

@@ -4,7 +4,18 @@ from uuid import UUID
import lancedb
from pydantic import BaseModel
from lancedb.pydantic import LanceModel, Vector
-from typing import Generic, List, Optional, TypeVar, Union, get_args, get_origin, get_type_hints, Dict, Any
+from typing import (
+    Generic,
+    List,
+    Optional,
+    TypeVar,
+    Union,
+    get_args,
+    get_origin,
+    get_type_hints,
+    Dict,
+    Any,
+)
from cognee.infrastructure.databases.exceptions import MissingQueryParameterError
from cognee.infrastructure.engine import DataPoint
@@ -46,7 +57,7 @@ class LanceDBAdapter(VectorDBInterface):
def __init__(
self,
-        url: Optional[str], # TODO: consider if we want to make this required and/or api_key
+        url: Optional[str],  # TODO: consider if we want to make this required and/or api_key
api_key: Optional[str],
embedding_engine: EmbeddingEngine,
):
@@ -109,7 +120,9 @@ class LanceDBAdapter(VectorDBInterface):
collection_names = await connection.table_names()
return collection_name in collection_names
-    async def create_collection(self, collection_name: str, payload_schema: Optional[Any] = None) -> None:
+    async def create_collection(
+        self, collection_name: str, payload_schema: Optional[Any] = None
+    ) -> None:
vector_size = self.embedding_engine.get_vector_size()
class LanceDataPoint(LanceModel):
@@ -123,7 +136,9 @@ class LanceDBAdapter(VectorDBInterface):
"""
id: UUID
-            vector: Vector[vector_size] # TODO: double check and consider raising this later in Pydantic
+            vector: Vector[
+                vector_size
+            ]  # TODO: double check and consider raising this later in Pydantic
payload: Dict[str, Any]
if not await self.has_collection(collection_name):
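
LanceDataPoint is declared inside create_collection because the vector dimension is only known once the embedding engine is configured. A sketch of the same pattern, following the adapter's own Vector[vector_size] annotation (the factory function is illustrative, not part of the adapter):

from typing import Any, Dict
from uuid import UUID

from lancedb.pydantic import LanceModel, Vector

def make_lance_point_model(vector_size: int) -> type:
    # Defined at call time: the annotation bakes the dimension into the model.
    class LanceDataPoint(LanceModel):
        id: UUID
        vector: Vector[vector_size]
        payload: Dict[str, Any]

    return LanceDataPoint
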
@@ -300,10 +315,7 @@ class LanceDBAdapter(VectorDBInterface):
[
IndexSchema(
id=data_point.id,
-                    text=getattr(
-                        data_point,
-                        data_point.metadata["index_fields"][0]
-                    ),
+                    text=getattr(data_point, data_point.metadata["index_fields"][0]),
)
for data_point in data_points
if data_point.metadata and len(data_point.metadata.get("index_fields", [])) > 0
@@ -327,7 +339,7 @@ class LanceDBAdapter(VectorDBInterface):
def get_data_point_schema(self, model_type: Optional[Any]) -> Any:
if model_type is None:
return DataPoint
related_models_fields = []
for field_name, field_config in model_type.model_fields.items():


@@ -123,7 +123,9 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=2, min=1, max=6),
)
-    async def create_collection(self, collection_name: str, payload_schema: Optional[Any] = None) -> None:
+    async def create_collection(
+        self, collection_name: str, payload_schema: Optional[Any] = None
+    ) -> None:
vector_size = self.embedding_engine.get_vector_size()
async with self.VECTOR_DB_LOCK:
@@ -151,7 +153,9 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
payload: Mapped[Dict[str, Any]] = mapped_column(JSON)
vector: Mapped[List[float]] = mapped_column(self.Vector(vector_size))
-            def __init__(self, id: str, payload: Dict[str, Any], vector: List[float]) -> None:
+            def __init__(
+                self, id: str, payload: Dict[str, Any], vector: List[float]
+            ) -> None:
self.id = id
self.payload = payload
self.vector = vector
@@ -159,10 +163,9 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
async with self.engine.begin() as connection:
if len(Base.metadata.tables.keys()) > 0:
from sqlalchemy import Table
table: Table = PGVectorDataPoint.__table__ # type: ignore
-                await connection.run_sync(
-                    Base.metadata.create_all, tables=[table]
-                )
+                await connection.run_sync(Base.metadata.create_all, tables=[table])
@retry(
retry=retry_if_exception_type(DeadlockDetectedError),
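
The collapsed run_sync call above is the standard SQLAlchemy 2.0 async pattern for creating a single table. A minimal sketch; the DSN, table name, and columns are placeholders:

import asyncio

from sqlalchemy import JSON, Table
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

class Base(DeclarativeBase):
    pass

class PGVectorDataPoint(Base):
    __tablename__ = "example_collection"  # placeholder table name
    id: Mapped[str] = mapped_column(primary_key=True)
    payload: Mapped[dict] = mapped_column(JSON)

async def main() -> None:
    engine = create_async_engine("postgresql+asyncpg://user:pass@localhost/db")  # placeholder DSN
    table: Table = PGVectorDataPoint.__table__  # type: ignore[assignment]
    async with engine.begin() as connection:
        # tables=[table] limits create_all to just this one table.
        await connection.run_sync(Base.metadata.create_all, tables=[table])

asyncio.run(main())
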
@@ -351,11 +354,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
# Create and return ScoredResult objects
return [
-            ScoredResult(
-                id=row["id"],
-                payload=row["payload"] or {},
-                score=row["score"]
-            )
+            ScoredResult(id=row["id"], payload=row["payload"] or {}, score=row["score"])
for row in vector_list
]