diff --git a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py
index 4dfd9792f..afe60dc64 100644
--- a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py
+++ b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py
@@ -1,9 +1,9 @@
 import asyncio
-from typing import List, Optional, get_type_hints
+from typing import List, Optional, get_type_hints, Dict, Any
 from sqlalchemy.inspection import inspect
 from sqlalchemy.orm import Mapped, mapped_column
 from sqlalchemy.dialects.postgresql import insert
-from sqlalchemy import JSON, Column, Table, select, delete, MetaData
+from sqlalchemy import JSON, Table, select, delete, MetaData
 from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
 from sqlalchemy.exc import ProgrammingError
 from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
@@ -12,6 +12,7 @@ from asyncpg import DeadlockDetectedError, DuplicateTableError, UniqueViolationE
 
 from cognee.shared.logging_utils import get_logger
 from cognee.infrastructure.engine import DataPoint
+from cognee.infrastructure.engine.models.DataPoint import MetaData as DataPointMetaData
 from cognee.infrastructure.engine.utils import parse_id
 from cognee.infrastructure.databases.relational import get_relational_engine
 
@@ -42,7 +43,7 @@ class IndexSchema(DataPoint):
     text: str
 
-    metadata: dict = {"index_fields": ["text"]}
+    metadata: DataPointMetaData = {"index_fields": ["text"], "type": "IndexSchema"}
 
 
 class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
@@ -122,8 +123,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
         stop=stop_after_attempt(5),
         wait=wait_exponential(multiplier=2, min=1, max=6),
     )
-    async def create_collection(self, collection_name: str, payload_schema=None):
-        data_point_types = get_type_hints(DataPoint)
+    async def create_collection(self, collection_name: str, payload_schema: Optional[Any] = None) -> None:
         vector_size = self.embedding_engine.get_vector_size()
 
         async with self.VECTOR_DB_LOCK:
@@ -147,19 +147,21 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
                 __tablename__ = collection_name
                 __table_args__ = {"extend_existing": True}
                 # PGVector requires one column to be the primary key
-                id: Mapped[data_point_types["id"]] = mapped_column(primary_key=True)
-                payload = Column(JSON)
-                vector = Column(self.Vector(vector_size))
+                id: Mapped[str] = mapped_column(primary_key=True)
+                payload: Mapped[Dict[str, Any]] = mapped_column(JSON)
+                vector: Mapped[List[float]] = mapped_column(self.Vector(vector_size))
 
-                def __init__(self, id, payload, vector):
+                def __init__(self, id: str, payload: Dict[str, Any], vector: List[float]) -> None:
                     self.id = id
                     self.payload = payload
                     self.vector = vector
 
             async with self.engine.begin() as connection:
                 if len(Base.metadata.tables.keys()) > 0:
+                    from sqlalchemy import Table
+                    table: Table = PGVectorDataPoint.__table__  # type: ignore
                     await connection.run_sync(
-                        Base.metadata.create_all, tables=[PGVectorDataPoint.__table__]
+                        Base.metadata.create_all, tables=[table]
                     )
 
     @retry(
@@ -167,9 +169,8 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
         stop=stop_after_attempt(3),
         wait=wait_exponential(multiplier=2, min=1, max=6),
     )
-    @override_distributed(queued_add_data_points)
-    async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
-        data_point_types = get_type_hints(DataPoint)
+    @override_distributed(queued_add_data_points)  # type: ignore
+    async def create_data_points(self, collection_name: str, data_points: List[DataPoint]) -> None:
         if not await self.has_collection(collection_name):
             await self.create_collection(
                 collection_name=collection_name,
@@ -196,11 +197,11 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
             __tablename__ = collection_name
             __table_args__ = {"extend_existing": True}
             # PGVector requires one column to be the primary key
-            id: Mapped[data_point_types["id"]] = mapped_column(primary_key=True)
-            payload = Column(JSON)
-            vector = Column(self.Vector(vector_size))
+            id: Mapped[str] = mapped_column(primary_key=True)
+            payload: Mapped[Dict[str, Any]] = mapped_column(JSON)
+            vector: Mapped[List[float]] = mapped_column(self.Vector(vector_size))
 
-            def __init__(self, id, payload, vector):
+            def __init__(self, id: str, payload: Dict[str, Any], vector: List[float]) -> None:
                 self.id = id
                 self.payload = payload
                 self.vector = vector
@@ -225,13 +226,13 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
             # else:
             pgvector_data_points.append(
                 PGVectorDataPoint(
-                    id=data_point.id,
+                    id=str(data_point.id),
                     vector=data_vectors[data_index],
                     payload=serialize_data(data_point.model_dump()),
                 )
             )
 
-        def to_dict(obj):
+        def to_dict(obj: Any) -> Dict[str, Any]:
            return {
                column.key: getattr(obj, column.key)
                for column in inspect(obj).mapper.column_attrs
@@ -245,12 +246,12 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
            await session.execute(insert_statement)
            await session.commit()
 
-    async def create_vector_index(self, index_name: str, index_property_name: str):
+    async def create_vector_index(self, index_name: str, index_property_name: str) -> None:
         await self.create_collection(f"{index_name}_{index_property_name}")
 
     async def index_data_points(
-        self, index_name: str, index_property_name: str, data_points: list[DataPoint]
-    ):
+        self, index_name: str, index_property_name: str, data_points: List[DataPoint]
+    ) -> None:
         await self.create_data_points(
             f"{index_name}_{index_property_name}",
             [
@@ -262,11 +263,12 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
             ],
         )
 
-    async def get_table(self, collection_name: str) -> Table:
+    async def get_table(self, table_name: str, schema_name: Optional[str] = None) -> Table:
         """
-        Dynamically loads a table using the given collection name
-        with an async engine.
+        Dynamically loads a table using the given table name
+        with an async engine. The schema parameter is ignored for vector collections.
         """
+        collection_name = table_name
         async with self.engine.begin() as connection:
             # Create a MetaData instance to load table information
             metadata = MetaData()
@@ -279,15 +281,15 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
                 f"Collection '{collection_name}' not found!",
             )
 
-    async def retrieve(self, collection_name: str, data_point_ids: List[str]):
+    async def retrieve(self, collection_name: str, data_point_ids: List[str]) -> List[ScoredResult]:
         # Get PGVectorDataPoint Table from database
         PGVectorDataPoint = await self.get_table(collection_name)
 
         async with self.get_async_session() as session:
-            results = await session.execute(
+            query_result = await session.execute(
                 select(PGVectorDataPoint).where(PGVectorDataPoint.c.id.in_(data_point_ids))
             )
-            results = results.all()
+            results = query_result.all()
 
         return [
             ScoredResult(id=parse_id(result.id), payload=result.payload, score=0)
@@ -312,7 +314,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
         PGVectorDataPoint = await self.get_table(collection_name)
 
         # NOTE: This needs to be initialized in case search doesn't return a value
-        closest_items = []
+        closest_items: List[ScoredResult] = []
 
         # Use async session to connect to the database
         async with self.get_async_session() as session:
@@ -325,12 +327,12 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
                 query = query.limit(limit)
 
             # Find closest vectors to query_vector
-            closest_items = await session.execute(query)
+            query_results = await session.execute(query)
 
             vector_list = []
 
             # Extract distances and find min/max for normalization
-            for vector in closest_items.all():
+            for vector in query_results.all():
                 vector_list.append(
                     {
                         "id": parse_id(str(vector.id)),
@@ -349,7 +351,11 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
 
         # Create and return ScoredResult objects
         return [
-            ScoredResult(id=row.get("id"), payload=row.get("payload"), score=row.get("score"))
+            ScoredResult(
+                id=row["id"],
+                payload=row["payload"] or {},
+                score=row["score"]
+            )
             for row in vector_list
         ]
 
@@ -357,9 +363,9 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
         self,
         collection_name: str,
         query_texts: List[str],
-        limit: int = None,
+        limit: Optional[int] = None,
         with_vectors: bool = False,
-    ):
+    ) -> List[List[ScoredResult]]:
         query_vectors = await self.embedding_engine.embed_text(query_texts)
 
         return await asyncio.gather(
@@ -367,14 +373,14 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
                 self.search(
                     collection_name=collection_name,
                     query_vector=query_vector,
-                    limit=limit,
+                    limit=limit or 15,
                     with_vector=with_vectors,
                 )
                 for query_vector in query_vectors
             ]
         )
 
-    async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
+    async def delete_data_points(self, collection_name: str, data_point_ids: List[str]) -> Any:
         async with self.get_async_session() as session:
             # Get PGVectorDataPoint Table from database
             PGVectorDataPoint = await self.get_table(collection_name)
@@ -384,6 +390,6 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
             await session.commit()
             return results
 
-    async def prune(self):
+    async def prune(self) -> None:
         # Clean up the database if it was set up as temporary
         await self.delete_database()
diff --git a/mypy.ini b/mypy.ini
index 3af746437..7a6018e1b 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -34,7 +34,10 @@ ignore_missing_imports=true
 [mypy-lancedb.*]
 ignore_missing_imports=true
 
-[mypy-psycopg2.*]
+[mypy-asyncpg.*]
+ignore_missing_imports=true
+
+[mypy-pgvector.*]
 ignore_missing_imports=true
 
 [mypy-docs.*]
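
A minimal usage sketch of the adapter surface retyped above. This is illustrative only and not part of the patch: the `demo` helper and the pre-configured `adapter` instance are assumptions (adapter construction is outside this diff), and it assumes `IndexSchema` can be built from `text` alone, with `DataPoint` supplying the id. Only `IndexSchema`, `create_collection`, `create_data_points`, `batch_search`, and `ScoredResult`'s `id`/`score` fields come from the patched module.

```python
# Sketch only: exercises the methods whose signatures this diff tightens.
# `adapter` is assumed to be an already-configured PGVectorAdapter.
from cognee.infrastructure.databases.vector.pgvector.PGVectorAdapter import (
    IndexSchema,
    PGVectorAdapter,
)


async def demo(adapter: PGVectorAdapter) -> None:
    # create_collection(...) -> None no longer inspects DataPoint type
    # hints; the primary-key column is always Mapped[str] after this change.
    await adapter.create_collection("documents_text")

    # Ids are stringified on insert (id=str(data_point.id)), so UUID ids
    # round-trip through the string primary key.
    point = IndexSchema(text="pgvector stores embeddings inside Postgres")
    await adapter.create_data_points("documents_text", [point])

    # batch_search(...) -> List[List[ScoredResult]] returns one inner list
    # per query text; a missing limit now falls back to 15 via `limit or 15`.
    results = await adapter.batch_search(
        collection_name="documents_text",
        query_texts=["where are embeddings stored?"],
    )
    for scored in results[0]:
        print(scored.id, scored.score)
```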