From c62dfdda9b6f8a1ab8230fe2e8964538db030915 Mon Sep 17 00:00:00 2001
From: Igor Ilic
Date: Fri, 11 Oct 2024 15:00:28 +0200
Subject: [PATCH] feat: Add PGVectorAdapter

Added PGVectorAdapter

Feature #COG-170
---
 .env.template                                      | 10 +-
 .../infrastructure/databases/vector/config.py      |  2 +
 .../databases/vector/create_vector_engine.py       | 18 +-
 .../vector/pgvector/PGVectorAdapter.py             | 170 ++++++++++++++++++
 .../databases/vector/pgvector/__init__.py          |  0
 5 files changed, 195 insertions(+), 5 deletions(-)
 create mode 100644 cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py
 create mode 100644 cognee/infrastructure/databases/vector/pgvector/__init__.py

diff --git a/.env.template b/.env.template
index ada92ad39..c6194a64e 100644
--- a/.env.template
+++ b/.env.template
@@ -14,17 +14,19 @@ GRAPH_DATABASE_USERNAME=
 GRAPH_DATABASE_PASSWORD=

 VECTOR_ENGINE_PROVIDER="qdrant" # "qdrant", "pgvector", "weaviate" or "lancedb"
-# Not needed if using "lancedb"
+# Not needed if using "lancedb" or "pgvector"
 VECTOR_DB_URL=
 VECTOR_DB_KEY=
+# Needed if using "pgvector"
+VECTOR_DB_NAME=

-# Database provider
+# Relational database provider
 DB_PROVIDER="sqlite" # "sqlite" or "postgres"

-# Database name
+# Relational database name
 DB_NAME=cognee_db

-# Postgres specific parameters (Only if Postgres is run)
+# Postgres-specific parameters (only needed if Postgres or PGVector is used)
 DB_HOST=127.0.0.1
 DB_PORT=5432
 DB_USERNAME=cognee

diff --git a/cognee/infrastructure/databases/vector/config.py b/cognee/infrastructure/databases/vector/config.py
index 8137a067c..5e403c92e 100644
--- a/cognee/infrastructure/databases/vector/config.py
+++ b/cognee/infrastructure/databases/vector/config.py
@@ -10,6 +10,7 @@ class VectorConfig(BaseSettings):
     )
     vector_db_key: str = ""
     vector_engine_provider: str = "lancedb"
+    vector_db_name: str = "cognee_vector_db"

     model_config = SettingsConfigDict(env_file = ".env", extra = "allow")

@@ -18,6 +19,7 @@ class VectorConfig(BaseSettings):
             "vector_db_url": self.vector_db_url,
             "vector_db_key": self.vector_db_key,
             "vector_db_provider": self.vector_engine_provider,
+            "vector_db_name": self.vector_db_name,
         }

 @lru_cache

diff --git a/cognee/infrastructure/databases/vector/create_vector_engine.py b/cognee/infrastructure/databases/vector/create_vector_engine.py
index e3571152f..578a601fc 100644
--- a/cognee/infrastructure/databases/vector/create_vector_engine.py
+++ b/cognee/infrastructure/databases/vector/create_vector_engine.py
@@ -1,9 +1,12 @@
 from typing import Dict
+# The relational settings live in the sibling relational package; this
+# package's own config module does not define get_relational_config.
+from ..relational.config import get_relational_config

 class VectorConfig(Dict):
     vector_db_url: str
     vector_db_key: str
     vector_db_provider: str
+    vector_db_name: str

 def create_vector_engine(config: VectorConfig, embedding_engine):
     if config["vector_db_provider"] == "weaviate":
@@ -27,7 +30,20 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
             embedding_engine = embedding_engine
         )
     elif config["vector_db_provider"] == "pgvector":
-        pass
+        # The pgvector package __init__ is empty, so import the module directly.
+        from .pgvector.PGVectorAdapter import PGVectorAdapter
+
+        # Reuse the relational (Postgres) credentials for the vector store.
+        relational_config = get_relational_config()
+        db_username = relational_config.db_username
+        db_password = relational_config.db_password
+        db_host = relational_config.db_host
+        db_port = relational_config.db_port
+
+        # The vector store gets its own database name.
+        db_name = config["vector_db_name"]
+
+        connection_string = f"postgresql+asyncpg://{db_username}:{db_password}@{db_host}:{db_port}/{db_name}"
+        return PGVectorAdapter(connection_string, config["vector_db_key"], embedding_engine)
     else:
         from .lancedb.LanceDBAdapter import LanceDBAdapter
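A note on the pgvector branch above: the DSN is assembled with a plain f-string, so credentials containing URL-reserved characters (an "@" or ":" in the password, say) produce an unparseable URL. Below is a minimal sketch of a safer builder, assuming nothing beyond the standard library; the helper name is illustrative and not part of cognee:

    from urllib.parse import quote_plus

    # Illustrative helper: percent-encode the credentials before embedding
    # them in the SQLAlchemy URL. Host, port and database name are passed
    # through unchanged, since they rarely contain reserved characters.
    def build_pg_dsn(user: str, password: str, host: str, port: int, database: str) -> str:
        return f"postgresql+asyncpg://{quote_plus(user)}:{quote_plus(password)}@{host}:{port}/{database}"

    # An "@" in the password would otherwise be read as the host separator.
    assert build_pg_dsn("cognee", "p@ss", "127.0.0.1", 5432, "cognee_vector_db") == \
        "postgresql+asyncpg://cognee:p%40ss@127.0.0.1:5432/cognee_vector_db"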
diff --git a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py
new file mode 100644
index 000000000..db558721d
--- /dev/null
+++ b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py
@@ -0,0 +1,170 @@
+import asyncio
+from typing import List, Optional, get_type_hints, Generic, TypeVar
+
+from sqlalchemy import JSON, text
+from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
+from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
+from pgvector.sqlalchemy import Vector
+
+from ..models.ScoredResult import ScoredResult
+from ..vector_db_interface import VectorDBInterface, DataPoint
+from ..embeddings.EmbeddingEngine import EmbeddingEngine
+# Assumed location of the shared relational adapter this class extends,
+# which provides get_async_session() and delete_database().
+from ...relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter
+
+
+# Declarative base for the dynamically created collection tables.
+class Base(DeclarativeBase):
+    pass
+
+
+class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
+    def __init__(
+        self,
+        connection_string: str,
+        api_key: Optional[str],
+        embedding_engine: EmbeddingEngine,
+    ):
+        self.api_key = api_key
+        self.embedding_engine = embedding_engine
+        self.db_uri: str = connection_string
+
+        self.engine = create_async_engine(connection_string)
+        self.sessionmaker = async_sessionmaker(bind = self.engine, expire_on_commit = False)
+
+    async def create_vector_extension(self):
+        # Create the pgvector extension in postgres. The async engine cannot
+        # be driven from the sync __init__, so this is awaited once at setup.
+        async with self.engine.begin() as connection:
+            await connection.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
+
+    async def embed_data(self, data: list[str]) -> list[list[float]]:
+        return await self.embedding_engine.embed_text(data)
+
+    async def has_collection(self, collection_name: str) -> bool:
+        # NOTE: table_names() is a LanceDB-style helper carried over from
+        # LanceDBAdapter; plain SQLAlchemy sessions do not provide it (see
+        # the SQLAlchemy sketch after this file's diff).
+        connection = await self.get_async_session()
+        collection_names = await connection.table_names()
+        return collection_name in collection_names
+
+    async def create_collection(self, collection_name: str, payload_schema = None):
+        data_point_types = get_type_hints(DataPoint)
+        vector_size = self.embedding_engine.get_vector_size()
+
+        class PGVectorDataPoint(Base):
+            # Each collection maps to its own table.
+            __tablename__ = collection_name
+            __table_args__ = {"extend_existing": True}
+            id: Mapped[data_point_types["id"]] = mapped_column(primary_key = True)
+            vector = mapped_column(Vector(vector_size))
+            # Payloads are persisted as JSON rather than as a typed column.
+            payload = mapped_column(JSON)
+
+        if not await self.has_collection(collection_name):
+            # NOTE: create_table() is LanceDB-style as well; with SQLAlchemy
+            # the table is created through Base.metadata.create_all.
+            connection = await self.get_async_session()
+            return await connection.create_table(
+                name = collection_name,
+                schema = PGVectorDataPoint,
+                exist_ok = True,
+            )
+
+    async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
+        connection = await self.get_async_session()
+
+        if not await self.has_collection(collection_name):
+            await self.create_collection(
+                collection_name,
+                payload_schema = type(data_points[0].payload),
+            )
+
+        # NOTE: open_table() and add() are LanceDB-style helpers; the
+        # SQLAlchemy counterpart is session.add_all() plus commit.
+        collection = await connection.open_table(collection_name)
+
+        data_vectors = await self.embed_data(
+            [data_point.get_embeddable_data() for data_point in data_points]
+        )
+
+        IdType = TypeVar("IdType")
+        PayloadSchema = TypeVar("PayloadSchema")
+        vector_size = self.embedding_engine.get_vector_size()
+
+        class PGVectorDataPoint(Base, Generic[IdType, PayloadSchema]):
+            __tablename__ = collection_name
+            __table_args__ = {"extend_existing": True}
+            id: Mapped[int] = mapped_column(primary_key = True)
+            vector = mapped_column(Vector(vector_size))
+            payload = mapped_column(JSON)
+
+        pgvector_data_points = [
+            PGVectorDataPoint[type(data_point.id), type(data_point.payload)](
+                id = data_point.id,
+                vector = data_vectors[data_index],
+                payload = data_point.payload,
+            ) for (data_index, data_point) in enumerate(data_points)
+        ]
+
+        await collection.add(pgvector_data_points)
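For reference, the LanceDB-style helpers used above (table_names(), create_table(), open_table(), add()) have no direct equivalent on a SQLAlchemy session. A minimal sketch of the plain SQLAlchemy 2.0 counterparts, assuming an async engine and a declarative model like the PGVectorDataPoint defined in create_collection; the function names are illustrative:

    from sqlalchemy import inspect

    async def table_exists(engine, table_name: str) -> bool:
        async with engine.begin() as connection:
            # Inspection is a sync-only API, so it runs through run_sync().
            names = await connection.run_sync(lambda conn: inspect(conn).get_table_names())
        return table_name in names

    async def create_table(engine, model) -> None:
        async with engine.begin() as connection:
            # Create only this model's table; a no-op if it already exists.
            await connection.run_sync(model.metadata.create_all, tables = [model.__table__])

    async def add_points(sessionmaker, points: list) -> None:
        async with sessionmaker() as session:
            session.add_all(points)   # stage the ORM objects
            await session.commit()    # persist the batch in one transaction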
+    async def retrieve(self, collection_name: str, data_point_ids: list[str]):
+        # NOTE: query().where().to_pandas() is LanceDB-style filtering; the
+        # SQLAlchemy form would be select(...).where(Model.id.in_(ids)).
+        connection = await self.get_async_session()
+        collection = await connection.open_table(collection_name)
+
+        if len(data_point_ids) == 1:
+            # A one-element Python tuple renders as ("x",), which is not
+            # valid SQL, so the single-id case uses an equality filter.
+            results = await collection.query().where(f"id = '{data_point_ids[0]}'").to_pandas()
+        else:
+            results = await collection.query().where(f"id IN {tuple(data_point_ids)}").to_pandas()
+
+        return [ScoredResult(
+            id = result["id"],
+            payload = result["payload"],
+            score = 0,
+        ) for result in results.to_dict("index").values()]
+
+    async def search(
+        self,
+        collection_name: str,
+        query_text: str = None,
+        query_vector: List[float] = None,
+        limit: int = 5,
+        with_vector: bool = False,
+    ):
+        if query_text is None and query_vector is None:
+            raise ValueError("One of query_text or query_vector must be provided!")
+
+        if query_text and not query_vector:
+            query_vector = (await self.embedding_engine.embed_text([query_text]))[0]
+
+        connection = await self.get_async_session()
+        collection = await connection.open_table(collection_name)
+
+        results = await collection.vector_search(query_vector).limit(limit).to_pandas()
+
+        result_values = list(results.to_dict("index").values())
+
+        # Track the observed distance range for min-max score normalization.
+        min_value = float("inf")
+        max_value = float("-inf")
+
+        for result in result_values:
+            value = float(result["_distance"])
+            if value > max_value:
+                max_value = value
+            if value < min_value:
+                min_value = value
+
+        # Guard against a zero range, which would divide by zero whenever
+        # all returned points sit at the same distance.
+        value_range = (max_value - min_value) or 1
+
+        normalized_values = [
+            (result["_distance"] - min_value) / value_range for result in result_values
+        ]
+
+        return [ScoredResult(
+            id = str(result["id"]),
+            payload = result["payload"],
+            score = normalized_values[value_index],
+        ) for value_index, result in enumerate(result_values)]
+
+    async def batch_search(
+        self,
+        collection_name: str,
+        query_texts: List[str],
+        limit: int = None,
+        with_vectors: bool = False,
+    ):
+        query_vectors = await self.embedding_engine.embed_text(query_texts)
+
+        # gather() must be awaited; otherwise callers get a bare awaitable
+        # instead of the list of per-query results.
+        return await asyncio.gather(
+            *[self.search(
+                collection_name = collection_name,
+                query_vector = query_vector,
+                limit = limit,
+                with_vector = with_vectors,
+            ) for query_vector in query_vectors]
+        )
+
+    async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
+        connection = await self.get_async_session()
+        collection = await connection.open_table(collection_name)
+        return await collection.delete(f"id IN {tuple(data_point_ids)}")
+
+    async def prune(self):
+        # Clean up the database if it was set up as temporary.
+        # delete_database() is inherited from SQLAlchemyAdapter and is async.
+        await self.delete_database()
diff --git a/cognee/infrastructure/databases/vector/pgvector/__init__.py b/cognee/infrastructure/databases/vector/pgvector/__init__.py
new file mode 100644
index 000000000..e69de29bb
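The vector_search() call in search() is likewise LanceDB-flavored. With pgvector, nearest-neighbour queries are normally expressed through the distance methods that pgvector.sqlalchemy.Vector attaches to the mapped column. A rough sketch against the illustrative PGVectorDataPoint model, not the adapter's final implementation:

    from sqlalchemy import select

    async def nearest_neighbours(sessionmaker, model, query_vector: list[float], limit: int = 5):
        # cosine_distance() compiles to pgvector's `<=>` operator; ordering
        # by it ascending returns the closest vectors first.
        distance = model.vector.cosine_distance(query_vector)
        statement = (
            select(model.id, model.payload, distance.label("_distance"))
            .order_by(distance)
            .limit(limit)
        )
        async with sessionmaker() as session:
            rows = (await session.execute(statement)).all()
        # Each row is an (id, payload, _distance) tuple, matching the fields
        # search() consumes when it normalizes scores into [0, 1].
        return rows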