ruff format

parent 0e0bf9a00d
commit c7b0da7aa6

6 changed files with 66 additions and 44 deletions
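This is a pure formatting commit: a diff like this is typically produced by running Ruff's formatter over the repository, which reflows any statement that overflows the configured line length (the limit itself lives in the project's Ruff config, which is not part of this diff). A minimal sketch of the dominant rewrite, using hypothetical function names rather than the adapters' actual methods:

from typing import Any, Dict, List, Optional, Tuple


# Before: the whole signature sits on one line and overflows the limit.
async def query_before(query: str, params: Optional[Dict[str, Any]] = None) -> List[Tuple[Any, ...]]:
    return []


# After: the formatter moves the parameters onto their own line and puts the
# closing parenthesis and return annotation on a new line.
async def query_after(
    query: str, params: Optional[Dict[str, Any]] = None
) -> List[Tuple[Any, ...]]:
    return []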
@@ -145,7 +145,7 @@ class KuzuAdapter(GraphDBInterface):
         except Exception as e:
             logger.error(f"Failed to initialize Kuzu database: {e}")
             raise e
-
+
     def _get_connection(self) -> Connection:
         """Get the connection to the Kuzu database."""
         if not self.connection:
@@ -173,7 +173,9 @@ class KuzuAdapter(GraphDBInterface):
         except FileNotFoundError:
             logger.warning(f"Kuzu S3 storage file not found: {self.db_path}")

-    async def query(self, query: str, params: Optional[Dict[str, Any]] = None) -> List[Tuple[Any, ...]]:
+    async def query(
+        self, query: str, params: Optional[Dict[str, Any]] = None
+    ) -> List[Tuple[Any, ...]]:
         """
         Execute a Kuzu query asynchronously with automatic reconnection.

@@ -221,7 +223,7 @@ class KuzuAdapter(GraphDBInterface):
                         val = val.as_py()
                     processed_rows.append(val)
                 rows.append(tuple(processed_rows))
-
+
             return rows
         except Exception as e:
             logger.error(f"Query execution failed: {str(e)}")
@@ -320,7 +322,9 @@ class KuzuAdapter(GraphDBInterface):
         result = await self.query(query_str, {"id": node_id})
         return result[0][0] if result else False

-    async def add_node(self, node: Union[DataPoint, str], properties: Optional[Dict[str, Any]] = None) -> None:
+    async def add_node(
+        self, node: Union[DataPoint, str], properties: Optional[Dict[str, Any]] = None
+    ) -> None:
         """
         Add a single node to the graph if it doesn't exist.

@@ -343,8 +347,9 @@ class KuzuAdapter(GraphDBInterface):
                 "type": str(node_properties.get("type", "")),
             }
             # Use the passed properties, excluding core fields
-            other_properties = {k: v for k, v in node_properties.items()
-                                if k not in ["id", "name", "type"]}
+            other_properties = {
+                k: v for k, v in node_properties.items() if k not in ["id", "name", "type"]
+            }
         else:
             # Handle DataPoint object
             node_properties = node.model_dump()
@@ -354,8 +359,9 @@ class KuzuAdapter(GraphDBInterface):
                 "type": str(node_properties.get("type", "")),
             }
             # Remove core fields from other properties
-            other_properties = {k: v for k, v in node_properties.items()
-                                if k not in ["id", "name", "type"]}
+            other_properties = {
+                k: v for k, v in node_properties.items() if k not in ["id", "name", "type"]
+            }

         core_properties["properties"] = json.dumps(other_properties, cls=JSONEncoder)

@@ -593,7 +599,9 @@ class KuzuAdapter(GraphDBInterface):
         )
         return result[0][0] if result else False

-    async def has_edges(self, edges: List[Tuple[str, str, str, Dict[str, Any]]]) -> List[Tuple[str, str, str, Dict[str, Any]]]:
+    async def has_edges(
+        self, edges: List[Tuple[str, str, str, Dict[str, Any]]]
+    ) -> List[Tuple[str, str, str, Dict[str, Any]]]:
         """
         Check if multiple edges exist in a batch operation.

@@ -793,7 +801,7 @@ class KuzuAdapter(GraphDBInterface):
                     relationship_name = row[1]
                     target_node = self._parse_node_properties(row[2])
                     # TODO: any edge properties we can add? Adding empty to avoid modifying query without reason
-                    edges.append((source_node, relationship_name, target_node, {})) # type: ignore # currently each node is a dict, while typing expects nodes to be strings
+                    edges.append((source_node, relationship_name, target_node, {}))  # type: ignore # currently each node is a dict, while typing expects nodes to be strings
             return edges
         except Exception as e:
             logger.error(f"Failed to get edges for node {node_id}: {e}")
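Several hunks in this file show a removed and an added line that read identically once the viewer strips whitespace (the edges.append pair above, the num_nodes/num_edges pair below). Assuming these are the formatter's inline-comment normalization, the only change is the spacing ahead of the hash:

x = 1 # one space before the inline comment (pre-format)
x = 1  # two spaces, the Black-compatible style ruff format enforces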
@@ -1362,8 +1370,8 @@ class KuzuAdapter(GraphDBInterface):
         try:
             # Get basic graph data
             nodes, edges = await self.get_model_independent_graph_data()
-            num_nodes = len(nodes[0]["nodes"]) if nodes else 0 # type: ignore # nodes is type string?
-            num_edges = len(edges[0]["elements"]) if edges else 0 # type: ignore # edges is type string?
+            num_nodes = len(nodes[0]["nodes"]) if nodes else 0  # type: ignore # nodes is type string?
+            num_edges = len(edges[0]["elements"]) if edges else 0  # type: ignore # edges is type string?

             # Calculate mandatory metrics
             mandatory_metrics = {
@@ -1571,10 +1579,10 @@ class KuzuAdapter(GraphDBInterface):

             # Reinitialize the database
             self._initialize_connection()
-
+
             if not self._get_connection():
                 raise RuntimeError("Failed to establish database connection")
-
+
             # Verify the database is empty
             result = self._get_connection().execute("MATCH (n:Node) RETURN COUNT(n)")
             if not isinstance(result, list):

@@ -73,13 +73,15 @@ class RemoteKuzuAdapter(KuzuAdapter):
                         status=response.status,
                         message=error_detail,
                     )
-                return await response.json() # type: ignore
+                return await response.json()  # type: ignore
         except aiohttp.ClientError as e:
             logger.error(f"API request failed: {str(e)}")
             logger.error(f"Request data: {data}")
             raise

-    async def query(self, query: str, params: Optional[dict[str, Any]] = None) -> List[Tuple[Any, ...]]:
+    async def query(
+        self, query: str, params: Optional[dict[str, Any]] = None
+    ) -> List[Tuple[Any, ...]]:
         """Execute a Kuzu query via the REST API."""
         try:
             # Initialize schema if needed

@@ -142,7 +142,9 @@ class Neo4jAdapter(GraphDBInterface):
         )
         return results[0]["node_exists"] if len(results) > 0 else False

-    async def add_node(self, node: Union[DataPoint, str], properties: Optional[Dict[str, Any]] = None) -> None:
+    async def add_node(
+        self, node: Union[DataPoint, str], properties: Optional[Dict[str, Any]] = None
+    ) -> None:
         """
         Add a new node to the database based on the provided DataPoint object or string ID.

@@ -407,7 +409,7 @@ class Neo4jAdapter(GraphDBInterface):
         )

         params = {
-            "from_node": str(source_id), # Adding str as callsites may still be passing UUID
+            "from_node": str(source_id),  # Adding str as callsites may still be passing UUID
             "to_node": str(target_id),
             "relationship_name": relationship_name,
             "properties": serialized_properties,
@@ -482,8 +484,8 @@ class Neo4jAdapter(GraphDBInterface):

         edge_params = [
             {
-                "from_node": str(edge[0]), # Adding str as callsites may still be passing UUID
-                "to_node": str(edge[1]), # Adding str as callsites may still be passing UUID
+                "from_node": str(edge[0]),  # Adding str as callsites may still be passing UUID
+                "to_node": str(edge[1]),  # Adding str as callsites may still be passing UUID
                 "relationship_name": edge[2],
                 "properties": self._flatten_edge_properties(
                     {

@@ -217,7 +217,9 @@ class ChromaDBAdapter(VectorDBInterface):
         collections = await self.get_collection_names()
         return collection_name in collections

-    async def create_collection(self, collection_name: str, payload_schema: Optional[Any] = None) -> None:
+    async def create_collection(
+        self, collection_name: str, payload_schema: Optional[Any] = None
+    ) -> None:
         """
         Create a new collection in ChromaDB if it does not already exist.

@@ -313,10 +315,7 @@ class ChromaDBAdapter(VectorDBInterface):
             [
                 IndexSchema(
                     id=data_point.id,
-                    text=getattr(
-                        data_point,
-                        data_point.metadata["index_fields"][0]
-                    ),
+                    text=getattr(data_point, data_point.metadata["index_fields"][0]),
                 )
                 for data_point in data_points
                 if data_point.metadata and len(data_point.metadata["index_fields"]) > 0

@@ -4,7 +4,18 @@ from uuid import UUID
 import lancedb
 from pydantic import BaseModel
 from lancedb.pydantic import LanceModel, Vector
-from typing import Generic, List, Optional, TypeVar, Union, get_args, get_origin, get_type_hints, Dict, Any
+from typing import (
+    Generic,
+    List,
+    Optional,
+    TypeVar,
+    Union,
+    get_args,
+    get_origin,
+    get_type_hints,
+    Dict,
+    Any,
+)

 from cognee.infrastructure.databases.exceptions import MissingQueryParameterError
 from cognee.infrastructure.engine import DataPoint
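Note that the exploded typing import above also gains a trailing comma after Any. That is the formatter's (Black-compatible) magic trailing comma: once a collection is split with a trailing comma, later formatter runs keep it one element per line even when it would fit again. A small sketch:

# The trailing comma forces one name per line, even though this import
# would easily fit on a single line:
from typing import (
    Any,
    Dict,
)

# Written without the trailing comma and within the line limit, the
# formatter leaves it collapsed: from typing import Any, Dict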
@@ -46,7 +57,7 @@ class LanceDBAdapter(VectorDBInterface):

     def __init__(
         self,
-        url: Optional[str], # TODO: consider if we want to make this required and/or api_key
+        url: Optional[str],  # TODO: consider if we want to make this required and/or api_key
         api_key: Optional[str],
         embedding_engine: EmbeddingEngine,
     ):
@@ -109,7 +120,9 @@ class LanceDBAdapter(VectorDBInterface):
         collection_names = await connection.table_names()
         return collection_name in collection_names

-    async def create_collection(self, collection_name: str, payload_schema: Optional[Any] = None) -> None:
+    async def create_collection(
+        self, collection_name: str, payload_schema: Optional[Any] = None
+    ) -> None:
         vector_size = self.embedding_engine.get_vector_size()

         class LanceDataPoint(LanceModel):
@@ -123,7 +136,9 @@ class LanceDBAdapter(VectorDBInterface):
             """

             id: UUID
-            vector: Vector[vector_size] # TODO: double check and consider raising this later in Pydantic
+            vector: Vector[
+                vector_size
+            ]  # TODO: double check and consider raising this later in Pydantic
             payload: Dict[str, Any]

         if not await self.has_collection(collection_name):
@@ -300,10 +315,7 @@ class LanceDBAdapter(VectorDBInterface):
             [
                 IndexSchema(
                     id=data_point.id,
-                    text=getattr(
-                        data_point,
-                        data_point.metadata["index_fields"][0]
-                    ),
+                    text=getattr(data_point, data_point.metadata["index_fields"][0]),
                 )
                 for data_point in data_points
                 if data_point.metadata and len(data_point.metadata.get("index_fields", [])) > 0
@@ -327,7 +339,7 @@ class LanceDBAdapter(VectorDBInterface):
     def get_data_point_schema(self, model_type: Optional[Any]) -> Any:
         if model_type is None:
             return DataPoint
-
+
         related_models_fields = []

         for field_name, field_config in model_type.model_fields.items():

@@ -123,7 +123,9 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
         stop=stop_after_attempt(5),
         wait=wait_exponential(multiplier=2, min=1, max=6),
     )
-    async def create_collection(self, collection_name: str, payload_schema: Optional[Any] = None) -> None:
+    async def create_collection(
+        self, collection_name: str, payload_schema: Optional[Any] = None
+    ) -> None:
         vector_size = self.embedding_engine.get_vector_size()

         async with self.VECTOR_DB_LOCK:
@@ -151,7 +153,9 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
             payload: Mapped[Dict[str, Any]] = mapped_column(JSON)
             vector: Mapped[List[float]] = mapped_column(self.Vector(vector_size))

-            def __init__(self, id: str, payload: Dict[str, Any], vector: List[float]) -> None:
+            def __init__(
+                self, id: str, payload: Dict[str, Any], vector: List[float]
+            ) -> None:
                 self.id = id
                 self.payload = payload
                 self.vector = vector
@@ -159,10 +163,9 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
         async with self.engine.begin() as connection:
             if len(Base.metadata.tables.keys()) > 0:
                 from sqlalchemy import Table

                 table: Table = PGVectorDataPoint.__table__  # type: ignore
-                await connection.run_sync(
-                    Base.metadata.create_all, tables=[table]
-                )
+                await connection.run_sync(Base.metadata.create_all, tables=[table])

     @retry(
         retry=retry_if_exception_type(DeadlockDetectedError),
@@ -351,11 +354,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):

         # Create and return ScoredResult objects
         return [
-            ScoredResult(
-                id=row["id"],
-                payload=row["payload"] or {},
-                score=row["score"]
-            )
+            ScoredResult(id=row["id"], payload=row["payload"] or {}, score=row["score"])
             for row in vector_list
         ]

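The formatter also works in the joining direction, which is where the commit's 44 deletions come from: manually wrapped calls that fit within the line limit, like the getattr(...) and ScoredResult(...) hunks above, get collapsed onto one line. A self-contained sketch, using a hypothetical stand-in for the adapters' DataPoint:

class DataPointStub:
    """Hypothetical stand-in for cognee's DataPoint, for illustration only."""

    name = "example"
    metadata = {"index_fields": ["name"]}


data_point = DataPointStub()

# Manually wrapped, carrying no trailing comma, and short enough to fit on
# one line, so the formatter joins it:
text = getattr(
    data_point,
    data_point.metadata["index_fields"][0]
)

# ...into the single-line form seen in the ChromaDB and LanceDB hunks:
text = getattr(data_point, data_point.metadata["index_fields"][0])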