diff --git a/cognee/infrastructure/databases/vector/create_vector_engine.py b/cognee/infrastructure/databases/vector/create_vector_engine.py
index 578a601fc..2399eac09 100644
--- a/cognee/infrastructure/databases/vector/create_vector_engine.py
+++ b/cognee/infrastructure/databases/vector/create_vector_engine.py
@@ -1,6 +1,6 @@
 from typing import Dict
 
-from .config import get_relational_config
+from ..relational.config import get_relational_config
 
 class VectorConfig(Dict):
     vector_db_url: str
@@ -30,7 +30,7 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
             embedding_engine = embedding_engine
         )
     elif config["vector_db_provider"] == "pgvector":
-        from .pgvector import PGVectorAdapter
+        from .pgvector.PGVectorAdapter import PGVectorAdapter
 
         # Get configuration for postgres database
         relational_config = get_relational_config()
@@ -43,7 +43,10 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
         db_name = config["vector_db_name"]
 
         connection_string = f"postgresql+asyncpg://{db_username}:{db_password}@{db_host}:{db_port}/{db_name}"
 
-        return PGVectorAdapter(connection_string)
+        return PGVectorAdapter(connection_string,
+            config["vector_db_key"],
+            embedding_engine
+        )
     else:
         from .lancedb.LanceDBAdapter import LanceDBAdapter

diff --git a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py
index db558721d..8b79fb9d3 100644
--- a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py
+++ b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py
@@ -6,12 +6,23 @@
 from ..vector_db_interface import VectorDBInterface, DataPoint
 from ..embeddings.EmbeddingEngine import EmbeddingEngine
 from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
+from sqlalchemy import inspect
+from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
+from pgvector.sqlalchemy import Vector
+
+from ...relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter
+
 
 # Define the models
 class Base(DeclarativeBase):
     pass
 
 class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
+    async def create_vector_extension(self):
+        async with self.get_async_session() as session:
+            await session.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
+            await session.commit()  # commit so the DDL persists
+
     def __init__(self, connection_string: str,
                  api_key: Optional[str],
                  embedding_engine: EmbeddingEngine
@@ -22,19 +33,19 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
         self.engine = create_async_engine(connection_string)
         self.sessionmaker = async_sessionmaker(bind=self.engine, expire_on_commit=False)
 
-
-        # Create pgvector extension in postgres
-        with engine.begin() as connection:
-            connection.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
-
+        # The pgvector extension is created lazily in create_collection(),
+        # since __init__ cannot await create_vector_extension().
 
     async def embed_data(self, data: list[str]) -> list[list[float]]:
         return await self.embedding_engine.embed_text(data)
 
     async def has_collection(self, collection_name: str) -> bool:
-        connection = await self.get_async_session()
-        collection_names = await connection.table_names()
-        return collection_name in collection_names
+        async with self.engine.begin() as connection:
+            # AsyncConnection has no table_names(); inspect the sync connection.
+            table_names = await connection.run_sync(
+                lambda sync_connection: inspect(sync_connection).get_table_names()
+            )
+            return collection_name in table_names
 
     async def create_collection(self, collection_name: str, payload_schema = None):
         data_point_types = get_type_hints(DataPoint)
@@ -46,61 +57,62 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
             payload: mapped_column(payload_schema)
 
         if not await self.has_collection(collection_name):
-            connection = await self.get_async_session()
-            return await connection.create_table(
-                name = collection_name,
-                schema = PGVectorDataPoint,
-                exist_ok = True,
-            )
+            # Ensure the pgvector extension exists before the first table.
+            await self.create_vector_extension()
+            async with self.engine.begin() as connection:
+                return await connection.create_table(
+                    name = collection_name,
+                    schema = PGVectorDataPoint,
+                    exist_ok = True,
+                )
 
     async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
-        connection = await self.get_async_session()
+        async with self.engine.begin() as connection:
+            if not await self.has_collection(collection_name):
+                await self.create_collection(
+                    collection_name,
+                    payload_schema = type(data_points[0].payload),
+                )
 
-        if not await self.has_collection(collection_name):
-            await self.create_collection(
-                collection_name,
-                payload_schema = type(data_points[0].payload),
+            collection = await connection.open_table(collection_name)
+
+            data_vectors = await self.embed_data(
+                [data_point.get_embeddable_data() for data_point in data_points]
             )
 
-        collection = await connection.open_table(collection_name)
+            IdType = TypeVar("IdType")
+            PayloadSchema = TypeVar("PayloadSchema")
+            vector_size = self.embedding_engine.get_vector_size()
 
-        data_vectors = await self.embed_data(
-            [data_point.get_embeddable_data() for data_point in data_points]
-        )
+            class PGVectorDataPoint(Base, Generic[IdType, PayloadSchema]):
+                id: Mapped[int] = mapped_column(IdType, primary_key=True)
+                vector = mapped_column(Vector(vector_size))
+                payload: mapped_column(PayloadSchema)
 
-        IdType = TypeVar("IdType")
-        PayloadSchema = TypeVar("PayloadSchema")
-        vector_size = self.embedding_engine.get_vector_size()
+            pgvector_data_points = [
+                PGVectorDataPoint[type(data_point.id), type(data_point.payload)](
+                    id = data_point.id,
+                    vector = data_vectors[data_index],
+                    payload = data_point.payload,
+                ) for (data_index, data_point) in enumerate(data_points)
+            ]
 
-        class PGVectorDataPoint(Base, Generic[IdType, PayloadSchema]):
-            id: Mapped[int] = mapped_column(IdType, primary_key=True)
-            vector = mapped_column(Vector(vector_size))
-            payload: mapped_column(PayloadSchema)
-
-        pgvector_data_points = [
-            PGVectorDataPoint[type(data_point.id), type(data_point.payload)](
-                id = data_point.id,
-                vector = data_vectors[data_index],
-                payload = data_point.payload,
-            ) for (data_index, data_point) in enumerate(data_points)
-        ]
-
-        await collection.add(pgvector_data_points)
+            await collection.add(pgvector_data_points)
 
     async def retrieve(self, collection_name: str, data_point_ids: list[str]):
-        connection = await self.get_async_session()
-        collection = await connection.open_table(collection_name)
+        async with self.engine.begin() as connection:
+            collection = await connection.open_table(collection_name)
 
-        if len(data_point_ids) == 1:
-            results = await collection.query().where(f"id = '{data_point_ids[0]}'").to_pandas()
-        else:
-            results = await collection.query().where(f"id IN {tuple(data_point_ids)}").to_pandas()
+            if len(data_point_ids) == 1:
+                results = await collection.query().where(f"id = '{data_point_ids[0]}'").to_pandas()
+            else:
+                results = await collection.query().where(f"id IN {tuple(data_point_ids)}").to_pandas()
 
-        return [ScoredResult(
-            id = result["id"],
-            payload = result["payload"],
-            score = 0,
-        ) for result in results.to_dict("index").values()]
+            return [ScoredResult(
+                id = result["id"],
+                payload = result["payload"],
+                score = 0,
+            ) for result in results.to_dict("index").values()]
 
     async def search(
         self,
@@ -116,30 +128,30 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
         if query_text and not query_vector:
             query_vector = (await self.embedding_engine.embed_text([query_text]))[0]
 
-        connection = await self.get_async_session()
-        collection = await connection.open_table(collection_name)
+        async with self.engine.begin() as connection:
+            collection = await connection.open_table(collection_name)
 
-        results = await collection.vector_search(query_vector).limit(limit).to_pandas()
+            results = await collection.vector_search(query_vector).limit(limit).to_pandas()
 
-        result_values = list(results.to_dict("index").values())
+            result_values = list(results.to_dict("index").values())
 
-        min_value = 100
-        max_value = 0
+            min_value = 100
+            max_value = 0
 
-        for result in result_values:
-            value = float(result["_distance"])
-            if value > max_value:
-                max_value = value
-            if value < min_value:
-                min_value = value
+            for result in result_values:
+                value = float(result["_distance"])
+                if value > max_value:
+                    max_value = value
+                if value < min_value:
+                    min_value = value
 
-        normalized_values = [(result["_distance"] - min_value) / (max_value - min_value) for result in result_values]
+            normalized_values = [(result["_distance"] - min_value) / (max_value - min_value) for result in result_values]
 
-        return [ScoredResult(
-            id = str(result["id"]),
-            payload = result["payload"],
-            score = normalized_values[value_index],
-        ) for value_index, result in enumerate(result_values)]
+            return [ScoredResult(
+                id = str(result["id"]),
+                payload = result["payload"],
+                score = normalized_values[value_index],
+            ) for value_index, result in enumerate(result_values)]
 
     async def batch_search(
         self,
@@ -160,10 +172,10 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
         )
 
     async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
-        connection = await self.get_async_session()
-        collection = await connection.open_table(collection_name)
-        results = await collection.delete(f"id IN {tuple(data_point_ids)}")
-        return results
+        async with self.engine.begin() as connection:
+            collection = await connection.open_table(collection_name)
+            results = await collection.delete(f"id IN {tuple(data_point_ids)}")
+            return results
 
     async def prune(self):
         # Clean up the database if it was set up as temporary
diff --git a/cognee/infrastructure/databases/vector/pgvector/__init__.py b/cognee/infrastructure/databases/vector/pgvector/__init__.py
index e69de29bb..84dc89113 100644
--- a/cognee/infrastructure/databases/vector/pgvector/__init__.py
+++ b/cognee/infrastructure/databases/vector/pgvector/__init__.py
@@ -0,0 +1 @@
+from .PGVectorAdapter import PGVectorAdapter
\ No newline at end of file
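
Review note: the new PGVectorAdapter still calls LanceDB-style methods (open_table, create_table, vector_search, collection.add, collection.delete) on a SQLAlchemy AsyncConnection; none of these exist in SQLAlchemy, so create_data_points, retrieve, search, and delete_data_points will fail at runtime. Below is a minimal sketch of the equivalent operations, assuming SQLAlchemy 2.x with asyncpg and the pgvector package; it is not part of the diff above, and the table name, column types, and vector dimension are illustrative only.

    from sqlalchemy import String, inspect, select
    from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
    from pgvector.sqlalchemy import Vector

    class Base(DeclarativeBase):
        pass

    class DataPointRow(Base):
        __tablename__ = "data_points"  # illustrative collection name
        id: Mapped[str] = mapped_column(String, primary_key=True)
        payload: Mapped[str] = mapped_column(String)  # e.g. JSON-serialized payload
        vector = mapped_column(Vector(1536))  # dimension from the embedding engine

    engine = create_async_engine("postgresql+asyncpg://user:password@localhost:5432/db")
    session_maker = async_sessionmaker(engine, expire_on_commit=False)

    async def has_collection(collection_name: str) -> bool:
        # AsyncConnection exposes no inspection API directly; run the
        # inspector on the underlying sync connection via run_sync().
        async with engine.connect() as connection:
            table_names = await connection.run_sync(
                lambda sync_connection: inspect(sync_connection).get_table_names()
            )
            return collection_name in table_names

    async def search(query_vector: list[float], limit: int = 10):
        # pgvector's Vector type adds distance comparators to the column;
        # cosine_distance compiles to the <=> operator, so ordering by it
        # returns the nearest rows first.
        async with session_maker() as session:
            distance = DataPointRow.vector.cosine_distance(query_vector)
            results = await session.execute(
                select(DataPointRow, distance.label("_distance"))
                .order_by(distance)
                .limit(limit)
            )
            return results.all()

The score normalization in search() also deserves a second look: when all results are equidistant, max_value equals min_value and the division by (max_value - min_value) raises ZeroDivisionError.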