mypy: fix PGVectorAdapter mypy errors
This commit is contained in:
parent
eebca89855
commit
deaf3debbf
2 changed files with 47 additions and 38 deletions
|
|
@@ -1,9 +1,9 @@
|
|||
import asyncio
|
||||
from typing import List, Optional, get_type_hints
|
||||
from typing import List, Optional, get_type_hints, Dict, Any
|
||||
from sqlalchemy.inspection import inspect
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
from sqlalchemy import JSON, Column, Table, select, delete, MetaData
|
||||
from sqlalchemy import JSON, Table, select, delete, MetaData
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
|
||||
from sqlalchemy.exc import ProgrammingError
|
||||
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
|
||||
|
|
@@ -12,6 +12,7 @@ from asyncpg import DeadlockDetectedError, DuplicateTableError, UniqueViolationE
|
|||
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.infrastructure.engine.models.DataPoint import MetaData as DataPointMetaData
|
||||
from cognee.infrastructure.engine.utils import parse_id
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
|
||||
|
|
@@ -42,7 +43,7 @@ class IndexSchema(DataPoint):
|
|||
|
||||
text: str
|
||||
|
||||
metadata: dict = {"index_fields": ["text"]}
|
||||
metadata: DataPointMetaData = {"index_fields": ["text"], "type": "IndexSchema"}
|
||||
|
||||
|
||||
class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||
|
|
@@ -122,8 +123,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
stop=stop_after_attempt(5),
|
||||
wait=wait_exponential(multiplier=2, min=1, max=6),
|
||||
)
|
||||
async def create_collection(self, collection_name: str, payload_schema=None):
|
||||
data_point_types = get_type_hints(DataPoint)
|
||||
async def create_collection(self, collection_name: str, payload_schema: Optional[Any] = None) -> None:
|
||||
vector_size = self.embedding_engine.get_vector_size()
|
||||
|
||||
async with self.VECTOR_DB_LOCK:
|
||||
|
|
@@ -147,19 +147,21 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
__tablename__ = collection_name
|
||||
__table_args__ = {"extend_existing": True}
|
||||
# PGVector requires one column to be the primary key
|
||||
id: Mapped[data_point_types["id"]] = mapped_column(primary_key=True)
|
||||
payload = Column(JSON)
|
||||
vector = Column(self.Vector(vector_size))
|
||||
id: Mapped[str] = mapped_column(primary_key=True)
|
||||
payload: Mapped[Dict[str, Any]] = mapped_column(JSON)
|
||||
vector: Mapped[List[float]] = mapped_column(self.Vector(vector_size))
|
||||
|
||||
def __init__(self, id, payload, vector):
|
||||
def __init__(self, id: str, payload: Dict[str, Any], vector: List[float]) -> None:
|
||||
self.id = id
|
||||
self.payload = payload
|
||||
self.vector = vector
|
||||
|
||||
async with self.engine.begin() as connection:
|
||||
if len(Base.metadata.tables.keys()) > 0:
|
||||
from sqlalchemy import Table
|
||||
table: Table = PGVectorDataPoint.__table__ # type: ignore
|
||||
await connection.run_sync(
|
||||
Base.metadata.create_all, tables=[PGVectorDataPoint.__table__]
|
||||
Base.metadata.create_all, tables=[table]
|
||||
)
|
||||
|
||||
@retry(
|
||||
|
|
@@ -167,9 +169,8 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=2, min=1, max=6),
|
||||
)
|
||||
@override_distributed(queued_add_data_points)
|
||||
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
|
||||
data_point_types = get_type_hints(DataPoint)
|
||||
@override_distributed(queued_add_data_points) # type: ignore
|
||||
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]) -> None:
|
||||
if not await self.has_collection(collection_name):
|
||||
await self.create_collection(
|
||||
collection_name=collection_name,
|
||||
|
|
@@ -196,11 +197,11 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
__tablename__ = collection_name
|
||||
__table_args__ = {"extend_existing": True}
|
||||
# PGVector requires one column to be the primary key
|
||||
id: Mapped[data_point_types["id"]] = mapped_column(primary_key=True)
|
||||
payload = Column(JSON)
|
||||
vector = Column(self.Vector(vector_size))
|
||||
id: Mapped[str] = mapped_column(primary_key=True)
|
||||
payload: Mapped[Dict[str, Any]] = mapped_column(JSON)
|
||||
vector: Mapped[List[float]] = mapped_column(self.Vector(vector_size))
|
||||
|
||||
def __init__(self, id, payload, vector):
|
||||
def __init__(self, id: str, payload: Dict[str, Any], vector: List[float]) -> None:
|
||||
self.id = id
|
||||
self.payload = payload
|
||||
self.vector = vector
|
||||
|
|
@@ -225,13 +226,13 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
# else:
|
||||
pgvector_data_points.append(
|
||||
PGVectorDataPoint(
|
||||
id=data_point.id,
|
||||
id=str(data_point.id),
|
||||
vector=data_vectors[data_index],
|
||||
payload=serialize_data(data_point.model_dump()),
|
||||
)
|
||||
)
|
||||
|
||||
def to_dict(obj):
|
||||
def to_dict(obj: Any) -> Dict[str, Any]:
|
||||
return {
|
||||
column.key: getattr(obj, column.key)
|
||||
for column in inspect(obj).mapper.column_attrs
|
||||
|
|
@@ -245,12 +246,12 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
await session.execute(insert_statement)
|
||||
await session.commit()
|
||||
|
||||
async def create_vector_index(self, index_name: str, index_property_name: str):
|
||||
async def create_vector_index(self, index_name: str, index_property_name: str) -> None:
|
||||
await self.create_collection(f"{index_name}_{index_property_name}")
|
||||
|
||||
async def index_data_points(
|
||||
self, index_name: str, index_property_name: str, data_points: list[DataPoint]
|
||||
):
|
||||
self, index_name: str, index_property_name: str, data_points: List[DataPoint]
|
||||
) -> None:
|
||||
await self.create_data_points(
|
||||
f"{index_name}_{index_property_name}",
|
||||
[
|
||||
|
|
@@ -262,11 +263,12 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
],
|
||||
)
|
||||
|
||||
async def get_table(self, collection_name: str) -> Table:
|
||||
async def get_table(self, table_name: str, schema_name: Optional[str] = None) -> Table:
|
||||
"""
|
||||
Dynamically loads a table using the given collection name
|
||||
with an async engine.
|
||||
Dynamically loads a table using the given table name
|
||||
with an async engine. Schema parameter is ignored for vector collections.
|
||||
"""
|
||||
collection_name = table_name
|
||||
async with self.engine.begin() as connection:
|
||||
# Create a MetaData instance to load table information
|
||||
metadata = MetaData()
|
||||
|
|
@@ -279,15 +281,15 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
f"Collection '{collection_name}' not found!",
|
||||
)
|
||||
|
||||
async def retrieve(self, collection_name: str, data_point_ids: List[str]):
|
||||
async def retrieve(self, collection_name: str, data_point_ids: List[str]) -> List[ScoredResult]:
|
||||
# Get PGVectorDataPoint Table from database
|
||||
PGVectorDataPoint = await self.get_table(collection_name)
|
||||
|
||||
async with self.get_async_session() as session:
|
||||
results = await session.execute(
|
||||
query_result = await session.execute(
|
||||
select(PGVectorDataPoint).where(PGVectorDataPoint.c.id.in_(data_point_ids))
|
||||
)
|
||||
results = results.all()
|
||||
results = query_result.all()
|
||||
|
||||
return [
|
||||
ScoredResult(id=parse_id(result.id), payload=result.payload, score=0)
|
||||
|
|
@@ -312,7 +314,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
PGVectorDataPoint = await self.get_table(collection_name)
|
||||
|
||||
# NOTE: This needs to be initialized in case search doesn't return a value
|
||||
closest_items = []
|
||||
closest_items: List[ScoredResult] = []
|
||||
|
||||
# Use async session to connect to the database
|
||||
async with self.get_async_session() as session:
|
||||
|
|
@@ -325,12 +327,12 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
query = query.limit(limit)
|
||||
|
||||
# Find closest vectors to query_vector
|
||||
closest_items = await session.execute(query)
|
||||
query_results = await session.execute(query)
|
||||
|
||||
vector_list = []
|
||||
|
||||
# Extract distances and find min/max for normalization
|
||||
for vector in closest_items.all():
|
||||
for vector in query_results.all():
|
||||
vector_list.append(
|
||||
{
|
||||
"id": parse_id(str(vector.id)),
|
||||
|
|
@@ -349,7 +351,11 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
|
||||
# Create and return ScoredResult objects
|
||||
return [
|
||||
ScoredResult(id=row.get("id"), payload=row.get("payload"), score=row.get("score"))
|
||||
ScoredResult(
|
||||
id=row["id"],
|
||||
payload=row["payload"] or {},
|
||||
score=row["score"]
|
||||
)
|
||||
for row in vector_list
|
||||
]
|
||||
|
||||
|
|
@@ -357,9 +363,9 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
self,
|
||||
collection_name: str,
|
||||
query_texts: List[str],
|
||||
limit: int = None,
|
||||
limit: Optional[int] = None,
|
||||
with_vectors: bool = False,
|
||||
):
|
||||
) -> List[List[ScoredResult]]:
|
||||
query_vectors = await self.embedding_engine.embed_text(query_texts)
|
||||
|
||||
return await asyncio.gather(
|
||||
|
|
@@ -367,14 +373,14 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
self.search(
|
||||
collection_name=collection_name,
|
||||
query_vector=query_vector,
|
||||
limit=limit,
|
||||
limit=limit or 15,
|
||||
with_vector=with_vectors,
|
||||
)
|
||||
for query_vector in query_vectors
|
||||
]
|
||||
)
|
||||
|
||||
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
|
||||
async def delete_data_points(self, collection_name: str, data_point_ids: List[str]) -> Any:
|
||||
async with self.get_async_session() as session:
|
||||
# Get PGVectorDataPoint Table from database
|
||||
PGVectorDataPoint = await self.get_table(collection_name)
|
||||
|
|
@@ -384,6 +390,6 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
await session.commit()
|
||||
return results
|
||||
|
||||
async def prune(self):
|
||||
async def prune(self) -> None:
|
||||
# Clean up the database if it was set up as temporary
|
||||
await self.delete_database()
|
||||
|
|
|
|||
5
mypy.ini
5
mypy.ini
|
|
@@ -34,7 +34,10 @@ ignore_missing_imports=true
|
|||
[mypy-lancedb.*]
|
||||
ignore_missing_imports=true
|
||||
|
||||
[mypy-psycopg2.*]
|
||||
[mypy-asyncpg.*]
|
||||
ignore_missing_imports=true
|
||||
|
||||
[mypy-pgvector.*]
|
||||
ignore_missing_imports=true
|
||||
|
||||
[mypy-docs.*]
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue