mypy: fix PGVectorAdapter mypy errors

This commit is contained in:
Daulet Amirkhanov 2025-09-03 19:29:30 +01:00
parent eebca89855
commit deaf3debbf
2 changed files with 47 additions and 38 deletions

View file

@ -1,9 +1,9 @@
import asyncio
from typing import List, Optional, get_type_hints
from typing import List, Optional, get_type_hints, Dict, Any
from sqlalchemy.inspection import inspect
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy import JSON, Column, Table, select, delete, MetaData
from sqlalchemy import JSON, Table, select, delete, MetaData
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
from sqlalchemy.exc import ProgrammingError
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
@ -12,6 +12,7 @@ from asyncpg import DeadlockDetectedError, DuplicateTableError, UniqueViolationE
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.engine.models.DataPoint import MetaData as DataPointMetaData
from cognee.infrastructure.engine.utils import parse_id
from cognee.infrastructure.databases.relational import get_relational_engine
@ -42,7 +43,7 @@ class IndexSchema(DataPoint):
text: str
metadata: dict = {"index_fields": ["text"]}
metadata: DataPointMetaData = {"index_fields": ["text"], "type": "IndexSchema"}
class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
@ -122,8 +123,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=2, min=1, max=6),
)
async def create_collection(self, collection_name: str, payload_schema=None):
data_point_types = get_type_hints(DataPoint)
async def create_collection(self, collection_name: str, payload_schema: Optional[Any] = None) -> None:
vector_size = self.embedding_engine.get_vector_size()
async with self.VECTOR_DB_LOCK:
@ -147,19 +147,21 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
__tablename__ = collection_name
__table_args__ = {"extend_existing": True}
# PGVector requires one column to be the primary key
id: Mapped[data_point_types["id"]] = mapped_column(primary_key=True)
payload = Column(JSON)
vector = Column(self.Vector(vector_size))
id: Mapped[str] = mapped_column(primary_key=True)
payload: Mapped[Dict[str, Any]] = mapped_column(JSON)
vector: Mapped[List[float]] = mapped_column(self.Vector(vector_size))
def __init__(self, id, payload, vector):
def __init__(self, id: str, payload: Dict[str, Any], vector: List[float]) -> None:
self.id = id
self.payload = payload
self.vector = vector
async with self.engine.begin() as connection:
if len(Base.metadata.tables.keys()) > 0:
from sqlalchemy import Table
table: Table = PGVectorDataPoint.__table__ # type: ignore
await connection.run_sync(
Base.metadata.create_all, tables=[PGVectorDataPoint.__table__]
Base.metadata.create_all, tables=[table]
)
@retry(
@ -167,9 +169,8 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=2, min=1, max=6),
)
@override_distributed(queued_add_data_points)
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
data_point_types = get_type_hints(DataPoint)
@override_distributed(queued_add_data_points) # type: ignore
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]) -> None:
if not await self.has_collection(collection_name):
await self.create_collection(
collection_name=collection_name,
@ -196,11 +197,11 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
__tablename__ = collection_name
__table_args__ = {"extend_existing": True}
# PGVector requires one column to be the primary key
id: Mapped[data_point_types["id"]] = mapped_column(primary_key=True)
payload = Column(JSON)
vector = Column(self.Vector(vector_size))
id: Mapped[str] = mapped_column(primary_key=True)
payload: Mapped[Dict[str, Any]] = mapped_column(JSON)
vector: Mapped[List[float]] = mapped_column(self.Vector(vector_size))
def __init__(self, id, payload, vector):
def __init__(self, id: str, payload: Dict[str, Any], vector: List[float]) -> None:
self.id = id
self.payload = payload
self.vector = vector
@ -225,13 +226,13 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
# else:
pgvector_data_points.append(
PGVectorDataPoint(
id=data_point.id,
id=str(data_point.id),
vector=data_vectors[data_index],
payload=serialize_data(data_point.model_dump()),
)
)
def to_dict(obj):
def to_dict(obj: Any) -> Dict[str, Any]:
return {
column.key: getattr(obj, column.key)
for column in inspect(obj).mapper.column_attrs
@ -245,12 +246,12 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
await session.execute(insert_statement)
await session.commit()
async def create_vector_index(self, index_name: str, index_property_name: str):
async def create_vector_index(self, index_name: str, index_property_name: str) -> None:
await self.create_collection(f"{index_name}_{index_property_name}")
async def index_data_points(
self, index_name: str, index_property_name: str, data_points: list[DataPoint]
):
self, index_name: str, index_property_name: str, data_points: List[DataPoint]
) -> None:
await self.create_data_points(
f"{index_name}_{index_property_name}",
[
@ -262,11 +263,12 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
],
)
async def get_table(self, collection_name: str) -> Table:
async def get_table(self, table_name: str, schema_name: Optional[str] = None) -> Table:
"""
Dynamically loads a table using the given collection name
with an async engine.
Dynamically loads a table using the given table name
with an async engine. Schema parameter is ignored for vector collections.
"""
collection_name = table_name
async with self.engine.begin() as connection:
# Create a MetaData instance to load table information
metadata = MetaData()
@ -279,15 +281,15 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
f"Collection '{collection_name}' not found!",
)
async def retrieve(self, collection_name: str, data_point_ids: List[str]):
async def retrieve(self, collection_name: str, data_point_ids: List[str]) -> List[ScoredResult]:
# Get PGVectorDataPoint Table from database
PGVectorDataPoint = await self.get_table(collection_name)
async with self.get_async_session() as session:
results = await session.execute(
query_result = await session.execute(
select(PGVectorDataPoint).where(PGVectorDataPoint.c.id.in_(data_point_ids))
)
results = results.all()
results = query_result.all()
return [
ScoredResult(id=parse_id(result.id), payload=result.payload, score=0)
@ -312,7 +314,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
PGVectorDataPoint = await self.get_table(collection_name)
# NOTE: This needs to be initialized in case search doesn't return a value
closest_items = []
closest_items: List[ScoredResult] = []
# Use async session to connect to the database
async with self.get_async_session() as session:
@ -325,12 +327,12 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
query = query.limit(limit)
# Find closest vectors to query_vector
closest_items = await session.execute(query)
query_results = await session.execute(query)
vector_list = []
# Extract distances and find min/max for normalization
for vector in closest_items.all():
for vector in query_results.all():
vector_list.append(
{
"id": parse_id(str(vector.id)),
@ -349,7 +351,11 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
# Create and return ScoredResult objects
return [
ScoredResult(id=row.get("id"), payload=row.get("payload"), score=row.get("score"))
ScoredResult(
id=row["id"],
payload=row["payload"] or {},
score=row["score"]
)
for row in vector_list
]
@ -357,9 +363,9 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
self,
collection_name: str,
query_texts: List[str],
limit: int = None,
limit: Optional[int] = None,
with_vectors: bool = False,
):
) -> List[List[ScoredResult]]:
query_vectors = await self.embedding_engine.embed_text(query_texts)
return await asyncio.gather(
@ -367,14 +373,14 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
self.search(
collection_name=collection_name,
query_vector=query_vector,
limit=limit,
limit=limit or 15,
with_vector=with_vectors,
)
for query_vector in query_vectors
]
)
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
async def delete_data_points(self, collection_name: str, data_point_ids: List[str]) -> Any:
async with self.get_async_session() as session:
# Get PGVectorDataPoint Table from database
PGVectorDataPoint = await self.get_table(collection_name)
@ -384,6 +390,6 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
await session.commit()
return results
async def prune(self):
async def prune(self) -> None:
# Clean up the database if it was set up as temporary
await self.delete_database()

View file

@ -34,7 +34,10 @@ ignore_missing_imports=true
[mypy-lancedb.*]
ignore_missing_imports=true
[mypy-psycopg2.*]
[mypy-asyncpg.*]
ignore_missing_imports=true
[mypy-pgvector.*]
ignore_missing_imports=true
[mypy-docs.*]