ruff format

Daulet Amirkhanov 2025-09-04 18:13:06 +01:00
parent 0e0bf9a00d
commit c7b0da7aa6
6 changed files with 66 additions and 44 deletions
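Everything below is mechanical output of Ruff's formatter, with no behavioral changes: signatures that overflow the line limit are wrapped onto their own indented lines, short multi-line calls are collapsed back to one line, the long `typing` import is exploded into a parenthesized list, and inline comments get the standard two spaces before `#`. A minimal before/after sketch of the dominant change, assuming Ruff's default 88-character line length (the repository may configure its own):

from typing import Any, Dict, List, Optional, Tuple

# Before: the signature exceeds the line-length limit.
async def query(self, query: str, params: Optional[Dict[str, Any]] = None) -> List[Tuple[Any, ...]]: ...

# After `ruff format`: parameters wrap to an indented line,
# and the return type follows the closing parenthesis.
async def query(
    self, query: str, params: Optional[Dict[str, Any]] = None
) -> List[Tuple[Any, ...]]: ...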

View file

@@ -173,7 +173,9 @@ class KuzuAdapter(GraphDBInterface):
         except FileNotFoundError:
             logger.warning(f"Kuzu S3 storage file not found: {self.db_path}")

-    async def query(self, query: str, params: Optional[Dict[str, Any]] = None) -> List[Tuple[Any, ...]]:
+    async def query(
+        self, query: str, params: Optional[Dict[str, Any]] = None
+    ) -> List[Tuple[Any, ...]]:
         """
         Execute a Kuzu query asynchronously with automatic reconnection.
@@ -320,7 +322,9 @@ class KuzuAdapter(GraphDBInterface):
         result = await self.query(query_str, {"id": node_id})
         return result[0][0] if result else False

-    async def add_node(self, node: Union[DataPoint, str], properties: Optional[Dict[str, Any]] = None) -> None:
+    async def add_node(
+        self, node: Union[DataPoint, str], properties: Optional[Dict[str, Any]] = None
+    ) -> None:
         """
         Add a single node to the graph if it doesn't exist.
@@ -343,8 +347,9 @@ class KuzuAdapter(GraphDBInterface):
                 "type": str(node_properties.get("type", "")),
             }
             # Use the passed properties, excluding core fields
-            other_properties = {k: v for k, v in node_properties.items()
-                                if k not in ["id", "name", "type"]}
+            other_properties = {
+                k: v for k, v in node_properties.items() if k not in ["id", "name", "type"]
+            }
         else:
             # Handle DataPoint object
             node_properties = node.model_dump()
@@ -354,8 +359,9 @@ class KuzuAdapter(GraphDBInterface):
                 "type": str(node_properties.get("type", "")),
             }
             # Remove core fields from other properties
-            other_properties = {k: v for k, v in node_properties.items()
-                                if k not in ["id", "name", "type"]}
+            other_properties = {
+                k: v for k, v in node_properties.items() if k not in ["id", "name", "type"]
+            }

         core_properties["properties"] = json.dumps(other_properties, cls=JSONEncoder)
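A side note on the two hunks above: the Kuzu adapter keeps `id`, `name`, and `type` as dedicated columns and packs every remaining property into a single JSON field. A minimal sketch of that split, substituting the standard `json` module for cognee's `JSONEncoder` (which the call site suggests also handles non-JSON-native values such as UUIDs):

import json

node_properties = {"id": "n1", "name": "Alice", "type": "Person", "age": 42}

# Core fields become dedicated columns on the node table.
core_properties = {
    "id": str(node_properties["id"]),
    "name": str(node_properties.get("name", "")),
    "type": str(node_properties.get("type", "")),
}
# Everything else travels inside one serialized JSON column.
other_properties = {
    k: v for k, v in node_properties.items() if k not in ["id", "name", "type"]
}
core_properties["properties"] = json.dumps(other_properties)  # -> '{"age": 42}'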
@ -593,7 +599,9 @@ class KuzuAdapter(GraphDBInterface):
) )
return result[0][0] if result else False return result[0][0] if result else False
async def has_edges(self, edges: List[Tuple[str, str, str, Dict[str, Any]]]) -> List[Tuple[str, str, str, Dict[str, Any]]]: async def has_edges(
self, edges: List[Tuple[str, str, str, Dict[str, Any]]]
) -> List[Tuple[str, str, str, Dict[str, Any]]]:
""" """
Check if multiple edges exist in a batch operation. Check if multiple edges exist in a batch operation.
@@ -793,7 +801,7 @@ class KuzuAdapter(GraphDBInterface):
                     relationship_name = row[1]
                     target_node = self._parse_node_properties(row[2])
                     # TODO: any edge properties we can add? Adding empty to avoid modifying query without reason
-                    edges.append((source_node, relationship_name, target_node, {})) # type: ignore # currently each node is a dict, while typing expects nodes to be strings
+                    edges.append((source_node, relationship_name, target_node, {}))  # type: ignore # currently each node is a dict, while typing expects nodes to be strings
             return edges
         except Exception as e:
             logger.error(f"Failed to get edges for node {node_id}: {e}")
@ -1362,8 +1370,8 @@ class KuzuAdapter(GraphDBInterface):
try: try:
# Get basic graph data # Get basic graph data
nodes, edges = await self.get_model_independent_graph_data() nodes, edges = await self.get_model_independent_graph_data()
num_nodes = len(nodes[0]["nodes"]) if nodes else 0 # type: ignore # nodes is type string? num_nodes = len(nodes[0]["nodes"]) if nodes else 0 # type: ignore # nodes is type string?
num_edges = len(edges[0]["elements"]) if edges else 0 # type: ignore # edges is type string? num_edges = len(edges[0]["elements"]) if edges else 0 # type: ignore # edges is type string?
# Calculate mandatory metrics # Calculate mandatory metrics
mandatory_metrics = { mandatory_metrics = {

View file

@@ -73,13 +73,15 @@ class RemoteKuzuAdapter(KuzuAdapter):
                         status=response.status,
                         message=error_detail,
                     )
-                return await response.json() # type: ignore
+                return await response.json()  # type: ignore
         except aiohttp.ClientError as e:
             logger.error(f"API request failed: {str(e)}")
             logger.error(f"Request data: {data}")
             raise

-    async def query(self, query: str, params: Optional[dict[str, Any]] = None) -> List[Tuple[Any, ...]]:
+    async def query(
+        self, query: str, params: Optional[dict[str, Any]] = None
+    ) -> List[Tuple[Any, ...]]:
         """Execute a Kuzu query via the REST API."""
         try:
             # Initialize schema if needed
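The method reformatted here (only partially visible) forwards query text to a remote Kuzu server over HTTP via aiohttp. A minimal sketch of that request/response pattern; the `/query` path and payload keys are assumptions, since the hunk does not show how the request is built:

import aiohttp

async def post_query(base_url: str, query: str, params: dict | None = None):
    # Hypothetical endpoint and payload shape, for illustration only.
    data = {"query": query, "params": params or {}}
    async with aiohttp.ClientSession() as session:
        async with session.post(f"{base_url}/query", json=data) as response:
            response.raise_for_status()  # surfaces non-2xx responses as aiohttp.ClientError
            return await response.json()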

View file

@ -142,7 +142,9 @@ class Neo4jAdapter(GraphDBInterface):
) )
return results[0]["node_exists"] if len(results) > 0 else False return results[0]["node_exists"] if len(results) > 0 else False
async def add_node(self, node: Union[DataPoint, str], properties: Optional[Dict[str, Any]] = None) -> None: async def add_node(
self, node: Union[DataPoint, str], properties: Optional[Dict[str, Any]] = None
) -> None:
""" """
Add a new node to the database based on the provided DataPoint object or string ID. Add a new node to the database based on the provided DataPoint object or string ID.
@@ -407,7 +409,7 @@ class Neo4jAdapter(GraphDBInterface):
         )

         params = {
-            "from_node": str(source_id), # Adding str as callsites may still be passing UUID
+            "from_node": str(source_id),  # Adding str as callsites may still be passing UUID
             "to_node": str(target_id),
             "relationship_name": relationship_name,
             "properties": serialized_properties,
@@ -482,8 +484,8 @@ class Neo4jAdapter(GraphDBInterface):
         edge_params = [
             {
-                "from_node": str(edge[0]), # Adding str as callsites may still be passing UUID
-                "to_node": str(edge[1]), # Adding str as callsites may still be passing UUID
+                "from_node": str(edge[0]),  # Adding str as callsites may still be passing UUID
+                "to_node": str(edge[1]),  # Adding str as callsites may still be passing UUID
                 "relationship_name": edge[2],
                 "properties": self._flatten_edge_properties(
                     {

View file

@@ -217,7 +217,9 @@ class ChromaDBAdapter(VectorDBInterface):
         collections = await self.get_collection_names()
         return collection_name in collections

-    async def create_collection(self, collection_name: str, payload_schema: Optional[Any] = None) -> None:
+    async def create_collection(
+        self, collection_name: str, payload_schema: Optional[Any] = None
+    ) -> None:
         """
         Create a new collection in ChromaDB if it does not already exist.
@@ -313,10 +315,7 @@ class ChromaDBAdapter(VectorDBInterface):
             [
                 IndexSchema(
                     id=data_point.id,
-                    text=getattr(
-                        data_point,
-                        data_point.metadata["index_fields"][0]
-                    ),
+                    text=getattr(data_point, data_point.metadata["index_fields"][0]),
                 )
                 for data_point in data_points
                 if data_point.metadata and len(data_point.metadata["index_fields"]) > 0

View file

@@ -4,7 +4,18 @@ from uuid import UUID
 import lancedb
 from pydantic import BaseModel
 from lancedb.pydantic import LanceModel, Vector
-from typing import Generic, List, Optional, TypeVar, Union, get_args, get_origin, get_type_hints, Dict, Any
+from typing import (
+    Generic,
+    List,
+    Optional,
+    TypeVar,
+    Union,
+    get_args,
+    get_origin,
+    get_type_hints,
+    Dict,
+    Any,
+)

 from cognee.infrastructure.databases.exceptions import MissingQueryParameterError
 from cognee.infrastructure.engine import DataPoint
@@ -46,7 +57,7 @@ class LanceDBAdapter(VectorDBInterface):
     def __init__(
         self,
-        url: Optional[str], # TODO: consider if we want to make this required and/or api_key
+        url: Optional[str],  # TODO: consider if we want to make this required and/or api_key
         api_key: Optional[str],
         embedding_engine: EmbeddingEngine,
     ):
@@ -109,7 +120,9 @@ class LanceDBAdapter(VectorDBInterface):
         collection_names = await connection.table_names()
         return collection_name in collection_names

-    async def create_collection(self, collection_name: str, payload_schema: Optional[Any] = None) -> None:
+    async def create_collection(
+        self, collection_name: str, payload_schema: Optional[Any] = None
+    ) -> None:
         vector_size = self.embedding_engine.get_vector_size()

         class LanceDataPoint(LanceModel):
@@ -123,7 +136,9 @@ class LanceDBAdapter(VectorDBInterface):
             """

             id: UUID
-            vector: Vector[vector_size] # TODO: double check and consider raising this later in Pydantic
+            vector: Vector[
+                vector_size
+            ]  # TODO: double check and consider raising this later in Pydantic
             payload: Dict[str, Any]

         if not await self.has_collection(collection_name):
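Worth noting in this hunk: `LanceDataPoint` is declared inside `create_collection` because the vector dimension is only known at runtime, once the embedding engine reports it. A minimal sketch of the same pattern, reusing the subscripted `Vector[...]` form exactly as it appears in the hunk and an arbitrary example dimension of 768:

from uuid import UUID
from typing import Any, Dict
from lancedb.pydantic import LanceModel, Vector

def make_point_model(vector_size: int):
    # The class is created per call so the vector can be sized at runtime.
    class LanceDataPoint(LanceModel):
        id: UUID
        vector: Vector[vector_size]
        payload: Dict[str, Any]

    return LanceDataPoint

PointModel = make_point_model(768)  # 768 is an arbitrary example dimension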
@@ -300,10 +315,7 @@ class LanceDBAdapter(VectorDBInterface):
             [
                 IndexSchema(
                     id=data_point.id,
-                    text=getattr(
-                        data_point,
-                        data_point.metadata["index_fields"][0]
-                    ),
+                    text=getattr(data_point, data_point.metadata["index_fields"][0]),
                 )
                 for data_point in data_points
                 if data_point.metadata and len(data_point.metadata.get("index_fields", [])) > 0

View file

@@ -123,7 +123,9 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
         stop=stop_after_attempt(5),
         wait=wait_exponential(multiplier=2, min=1, max=6),
     )
-    async def create_collection(self, collection_name: str, payload_schema: Optional[Any] = None) -> None:
+    async def create_collection(
+        self, collection_name: str, payload_schema: Optional[Any] = None
+    ) -> None:
         vector_size = self.embedding_engine.get_vector_size()

         async with self.VECTOR_DB_LOCK:
@@ -151,7 +153,9 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
                 payload: Mapped[Dict[str, Any]] = mapped_column(JSON)
                 vector: Mapped[List[float]] = mapped_column(self.Vector(vector_size))

-                def __init__(self, id: str, payload: Dict[str, Any], vector: List[float]) -> None:
+                def __init__(
+                    self, id: str, payload: Dict[str, Any], vector: List[float]
+                ) -> None:
                     self.id = id
                     self.payload = payload
                     self.vector = vector
@@ -159,10 +163,9 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
             async with self.engine.begin() as connection:
                 if len(Base.metadata.tables.keys()) > 0:
                     from sqlalchemy import Table

-                    table: Table = PGVectorDataPoint.__table__ # type: ignore
-                    await connection.run_sync(
-                        Base.metadata.create_all, tables=[table]
-                    )
+                    table: Table = PGVectorDataPoint.__table__  # type: ignore
+                    await connection.run_sync(Base.metadata.create_all, tables=[table])

     @retry(
         retry=retry_if_exception_type(DeadlockDetectedError),
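Context for the decorators around these hunks: the `@retry` calls visible in this file come from tenacity, with `stop_after_attempt(5)` plus `wait_exponential` backoff on one method and `retry_if_exception_type(DeadlockDetectedError)` on another. A self-contained sketch combining the pieces the diff shows; the exception class and decorated function are stand-ins:

from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential

class DeadlockDetectedError(Exception):
    """Stand-in for the adapter's deadlock exception."""

@retry(
    retry=retry_if_exception_type(DeadlockDetectedError),
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=2, min=1, max=6),
)
async def write_batch(rows: list) -> None:
    ...  # placeholder body; retried with exponential backoff on deadlock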
@@ -351,11 +354,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
         # Create and return ScoredResult objects
         return [
-            ScoredResult(
-                id=row["id"],
-                payload=row["payload"] or {},
-                score=row["score"]
-            )
+            ScoredResult(id=row["id"], payload=row["payload"] or {}, score=row["score"])
             for row in vector_list
         ]