diff --git a/cognee/api/v1/add/add.py b/cognee/api/v1/add/add.py
index 85f0688ac..b0b2fc8c6 100644
--- a/cognee/api/v1/add/add.py
+++ b/cognee/api/v1/add/add.py
@@ -14,9 +14,12 @@ from cognee.tasks.ingestion import get_dlt_destination
 from cognee.modules.users.permissions.methods import give_permission_on_document
 from cognee.modules.users.models import User
 from cognee.modules.data.methods import create_dataset
+from cognee.infrastructure.databases.relational import create_db_and_tables as create_relational_db_and_tables
+from cognee.infrastructure.databases.vector import create_db_and_tables as create_vector_db_and_tables
 
 async def add(data: Union[BinaryIO, List[BinaryIO], str, List[str]], dataset_name: str = "main_dataset", user: User = None):
-    await create_db_and_tables()
+    await create_relational_db_and_tables()
+    await create_vector_db_and_tables()
 
     if isinstance(data, str):
         if "data://" in data:
diff --git a/cognee/api/v1/add/add_v2.py b/cognee/api/v1/add/add_v2.py
index 291ec5f4c..f32f470a9 100644
--- a/cognee/api/v1/add/add_v2.py
+++ b/cognee/api/v1/add/add_v2.py
@@ -3,10 +3,12 @@ from cognee.modules.users.models import User
 from cognee.modules.users.methods import get_default_user
 from cognee.modules.pipelines import run_tasks, Task
 from cognee.tasks.ingestion import save_data_to_storage, ingest_data
-from cognee.infrastructure.databases.relational import create_db_and_tables
+from cognee.infrastructure.databases.relational import create_db_and_tables as create_relational_db_and_tables
+from cognee.infrastructure.databases.vector import create_db_and_tables as create_vector_db_and_tables
 
 async def add(data: Union[BinaryIO, list[BinaryIO], str, list[str]], dataset_name: str = "main_dataset", user: User = None):
-    await create_db_and_tables()
+    await create_relational_db_and_tables()
+    await create_vector_db_and_tables()
 
     if user is None:
         user = await get_default_user()
diff --git a/cognee/infrastructure/databases/vector/__init__.py b/cognee/infrastructure/databases/vector/__init__.py
index 604170f1d..02d13bb9c 100644
--- a/cognee/infrastructure/databases/vector/__init__.py
+++ b/cognee/infrastructure/databases/vector/__init__.py
@@ -4,3 +4,4 @@ from .models.CollectionConfig import CollectionConfig
 from .vector_db_interface import VectorDBInterface
 from .config import get_vectordb_config
 from .get_vector_engine import get_vector_engine
+from .create_db_and_tables import create_db_and_tables
diff --git a/cognee/infrastructure/databases/vector/create_db_and_tables.py b/cognee/infrastructure/databases/vector/create_db_and_tables.py
new file mode 100644
index 000000000..21522db6b
--- /dev/null
+++ b/cognee/infrastructure/databases/vector/create_db_and_tables.py
@@ -0,0 +1,15 @@
+from ..relational.ModelBase import Base
+from .get_vector_engine import get_vector_engine, get_vectordb_config
+from sqlalchemy import text
+
+async def create_db_and_tables():
+    vector_config = get_vectordb_config()
+    vector_engine = get_vector_engine()
+
+    if vector_config.vector_engine_provider == "pgvector":
+        async with vector_engine.engine.begin() as connection:
+            # The vector extension must exist before any table with a Vector column is created.
+            await connection.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))
+            if len(Base.metadata.tables.keys()) > 0:
+                await connection.run_sync(Base.metadata.create_all)
+
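Taken together with the `add()` changes above, both stores are now initialized before any ingestion runs. A minimal standalone sketch of that startup sequence, assuming cognee is configured for the `pgvector` provider (with any other provider the vector-side call is a no-op):

```python
import asyncio

# Both initializers and their aliases come straight from the diff above
# (see cognee/api/v1/add/add.py).
from cognee.infrastructure.databases.relational import create_db_and_tables as create_relational_db_and_tables
from cognee.infrastructure.databases.vector import create_db_and_tables as create_vector_db_and_tables

async def bootstrap():
    await create_relational_db_and_tables()
    # For pgvector this runs CREATE EXTENSION IF NOT EXISTS vector and then
    # creates any tables already registered on the shared Base metadata.
    await create_vector_db_and_tables()

asyncio.run(bootstrap())
```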
diff --git a/cognee/infrastructure/databases/vector/create_vector_engine.py b/cognee/infrastructure/databases/vector/create_vector_engine.py
index 2399eac09..58e6dbf03 100644
--- a/cognee/infrastructure/databases/vector/create_vector_engine.py
+++ b/cognee/infrastructure/databases/vector/create_vector_engine.py
@@ -42,7 +42,8 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
 
         # Get name of vector database
         db_name = config["vector_db_name"]
 
-        connection_string = f"postgresql+asyncpg://{db_username}:{db_password}@{db_host}:{db_port}/{db_name}"
+        connection_string: str = f"postgresql+asyncpg://{db_username}:{db_password}@{db_host}:{db_port}/{db_name}"
+        return PGVectorAdapter(connection_string, config["vector_db_key"], embedding_engine)
diff --git a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py
index 8b79fb9d3..b0a44bb8b 100644
--- a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py
+++ b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py
@@ -1,27 +1,35 @@
-from typing import List, Optional, get_type_hints, Generic, TypeVar
-import asyncio
+from typing import List, Optional, get_type_hints, Any, Dict
+from sqlalchemy import text, select
+from sqlalchemy import JSON, Column, Table
+from sqlalchemy.dialects.postgresql import ARRAY
 from ..models.ScoredResult import ScoredResult
 from ..vector_db_interface import VectorDBInterface, DataPoint
+from sqlalchemy.orm import Mapped, mapped_column
 from ..embeddings.EmbeddingEngine import EmbeddingEngine
 from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
-
-from sqlalchemy.orm import DeclarativeBase, mapped_column
 from pgvector.sqlalchemy import Vector
 from ...relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter
+from ...relational.ModelBase import Base
+from datetime import datetime
 
-# Define the models
-class Base(DeclarativeBase):
-    pass
+# TODO: Find better location for function
+def serialize_datetime(data):
+    """Recursively convert datetime objects in dictionaries/lists to ISO format."""
+    if isinstance(data, dict):
+        return {key: serialize_datetime(value) for key, value in data.items()}
+    elif isinstance(data, list):
+        return [serialize_datetime(item) for item in data]
+    elif isinstance(data, datetime):
+        return data.isoformat()  # Convert datetime to ISO 8601 string
+    else:
+        return data
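+
+# Illustrative example of the helper (values are hypothetical, not from the PR):
+#   serialize_datetime({"created_at": datetime(2024, 1, 2, 3, 4, 5)})
+#   -> {"created_at": "2024-01-02T03:04:05"}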
 
 class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
-    async def create_vector_extension(self):
-        async with self.get_async_session() as session:
-            await session.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
-
-    def __init__(self, connection_string: str,
+    def __init__(self, connection_string: str,
                  api_key: Optional[str],
                  embedding_engine: EmbeddingEngine
     ):
@@ -29,121 +37,156 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
         self.embedding_engine = embedding_engine
         self.db_uri: str = connection_string
 
-        self.engine = create_async_engine(connection_string)
+        self.engine = create_async_engine(self.db_uri, echo=True)
         self.sessionmaker = async_sessionmaker(bind=self.engine, expire_on_commit=False)
-        self.create_vector_extension()
 
     async def embed_data(self, data: list[str]) -> list[list[float]]:
        return await self.embedding_engine.embed_text(data)
 
     async def has_collection(self, collection_name: str) -> bool:
         async with self.engine.begin() as connection:
-            collection_names = await connection.table_names()
-            return collection_name in collection_names
+            # TODO: Switch to using ORM instead of raw query
+            result = await connection.execute(
+                text("SELECT table_name FROM information_schema.tables WHERE table_schema = 'public';")
+            )
+            tables = result.fetchall()
+            for table in tables:
+                if collection_name == table[0]:
+                    return True
+            return False
 
     async def create_collection(self, collection_name: str, payload_schema = None):
         data_point_types = get_type_hints(DataPoint)
         vector_size = self.embedding_engine.get_vector_size()
 
-        class PGVectorDataPoint(Base):
-            id: Mapped[int] = mapped_column(data_point_types["id"], primary_key=True)
-            vector = mapped_column(Vector(vector_size))
-            payload: mapped_column(payload_schema)
-
         if not await self.has_collection(collection_name):
+
+            class PGVectorDataPoint(Base):
+                __tablename__ = collection_name
+                __table_args__ = {'extend_existing': True}
+                # PGVector requires one column to be the primary key
+                primary_key: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
+                id: Mapped[data_point_types["id"]]
+                payload = Column(JSON)
+                vector = Column(Vector(vector_size))
+
+                def __init__(self, id, payload, vector):
+                    self.id = id
+                    self.payload = payload
+                    self.vector = vector
+
             async with self.engine.begin() as connection:
-                return await connection.create_table(
-                    name = collection_name,
-                    schema = PGVectorDataPoint,
-                    exist_ok = True,
-                )
+                if len(Base.metadata.tables.keys()) > 0:
+                    await connection.run_sync(Base.metadata.create_all, tables=[PGVectorDataPoint.__table__])
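+                # extend_existing=True in __table_args__ is what lets this
+                # per-collection class be declared again on later calls
+                # without a duplicate-Table error.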
 
     async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
-        async with self.engine.begin() as connection:
+        async with self.get_async_session() as session:
             if not await self.has_collection(collection_name):
                 await self.create_collection(
-                    collection_name,
+                    collection_name = collection_name,
                     payload_schema = type(data_points[0].payload),
                 )
-            collection = await connection.open_table(collection_name)
 
             data_vectors = await self.embed_data(
                 [data_point.get_embeddable_data() for data_point in data_points]
             )
 
-            IdType = TypeVar("IdType")
-            PayloadSchema = TypeVar("PayloadSchema")
             vector_size = self.embedding_engine.get_vector_size()
 
-            class PGVectorDataPoint(Base, Generic[IdType, PayloadSchema]):
-                id: Mapped[int] = mapped_column(IdType, primary_key=True)
-                vector = mapped_column(Vector(vector_size))
-                payload: mapped_column(PayloadSchema)
+            class PGVectorDataPoint(Base):
+                __tablename__ = collection_name
+                __table_args__ = {'extend_existing': True}
+                # PGVector requires one column to be the primary key
+                primary_key: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
+                id: Mapped[type(data_points[0].id)]
+                payload = Column(JSON)
+                vector = Column(Vector(vector_size))
+
+                def __init__(self, id, payload, vector):
+                    self.id = id
+                    self.payload = payload
+                    self.vector = vector
 
             pgvector_data_points = [
-                PGVectorDataPoint[type(data_point.id), type(data_point.payload)](
+                PGVectorDataPoint(
                     id = data_point.id,
                     vector = data_vectors[data_index],
-                    payload = data_point.payload,
+                    payload = serialize_datetime(data_point.payload.dict()),
                 ) for (data_index, data_point) in enumerate(data_points)
             ]
 
-            await collection.add(pgvector_data_points)
+            session.add_all(pgvector_data_points)
+            await session.commit()
 
-    async def retrieve(self, collection_name: str, data_point_ids: list[str]):
-        async with self.engine.begin() as connection:
-            collection = await connection.open_table(collection_name)
+    async def retrieve(self, collection_name: str, data_point_ids: List[str]):
+        async with AsyncSession(self.engine) as session:
+            try:
+                # Construct the SQL query
+                # TODO: Switch to using ORM instead of raw query
+                if len(data_point_ids) == 1:
+                    query = text(f"SELECT * FROM {collection_name} WHERE id = :id")
+                    result = await session.execute(query, {"id": data_point_ids[0]})
+                else:
+                    query = text(f"SELECT * FROM {collection_name} WHERE id = ANY(:ids)")
+                    result = await session.execute(query, {"ids": data_point_ids})
 
-            if len(data_point_ids) == 1:
-                results = await collection.query().where(f"id = '{data_point_ids[0]}'").to_pandas()
-            else:
-                results = await collection.query().where(f"id IN {tuple(data_point_ids)}").to_pandas()
+                # Fetch all rows
+                rows = result.fetchall()
 
-            return [ScoredResult(
-                id = result["id"],
-                payload = result["payload"],
-                score = 0,
-            ) for result in results.to_dict("index").values()]
+                return [
+                    ScoredResult(
+                        id = row.id,
+                        payload = row.payload,
+                        score = 0,
+                    )
+                    for row in rows
+                ]
+            except Exception as e:
+                print(f"Error retrieving data: {e}")
+                return []
 
     async def search(
         self,
         collection_name: str,
-        query_text: str = None,
-        query_vector: List[float] = None,
+        query_text: Optional[str] = None,
+        query_vector: Optional[List[float]] = None,
         limit: int = 5,
         with_vector: bool = False,
-    ):
+    ) -> List[ScoredResult]:
+        # Validate inputs
         if query_text is None and query_vector is None:
             raise ValueError("One of query_text or query_vector must be provided!")
 
+        # Get the vector for query_text if provided
         if query_text and not query_vector:
             query_vector = (await self.embedding_engine.embed_text([query_text]))[0]
 
-        async with self.engine.begin() as connection:
-            collection = await connection.open_table(collection_name)
+        # Use async session to connect to the database
+        async with self.get_async_session() as session:
+            try:
+                PGVectorDataPoint = Table(collection_name, Base.metadata, autoload_with=self.engine)
 
-            results = await collection.vector_search(query_vector).limit(limit).to_pandas()
+                closest_items = await session.execute(
+                    select(
+                        PGVectorDataPoint,
+                        PGVectorDataPoint.c.vector.cosine_distance(query_vector).label('similarity'),
+                    )
+                    .order_by(PGVectorDataPoint.c.vector.cosine_distance(query_vector))
+                    .limit(limit)
+                )
 
-            result_values = list(results.to_dict("index").values())
-
-            min_value = 100
-            max_value = 0
-
-            for result in result_values:
-                value = float(result["_distance"])
-                if value > max_value:
-                    max_value = value
-                if value < min_value:
-                    min_value = value
-
-            normalized_values = [(result["_distance"] - min_value) / (max_value - min_value) for result in result_values]
-
-            return [ScoredResult(
-                id = str(result["id"]),
-                payload = result["payload"],
-                score = normalized_values[value_index],
-            ) for value_index, result in enumerate(result_values)]
+                vector_list = []
+                # Extract distances and find min/max for normalization
+                for vector in closest_items:
+                    # TODO: Add normalization of similarity score
+                    vector_list.append(vector)
+
+                # Create and return ScoredResult objects
+                return [
+                    ScoredResult(
+                        id = str(row.id),
+                        payload = row.payload,
+                        score = row.similarity,
+                    )
+                    for row in vector_list
+                ]
+            except Exception as e:
+                print(f"Error during search: {e}")
+                return []
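+
+        # Note: cosine_distance() is a distance, not a similarity: 0.0 means
+        # identical vectors and larger values mean farther apart, so the
+        # "similarity" score returned above is the raw distance for now.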
 
     async def batch_search(
         self,
@@ -152,23 +195,11 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
         limit: int = None,
         with_vectors: bool = False,
     ):
-        query_vectors = await self.embedding_engine.embed_text(query_texts)
-
-        return asyncio.gather(
-            *[self.search(
-                collection_name = collection_name,
-                query_vector = query_vector,
-                limit = limit,
-                with_vector = with_vectors,
-            ) for query_vector in query_vectors]
-        )
+        pass
 
     async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
-        async with self.engine.begin() as connection:
-            collection = await connection.open_table(collection_name)
-            results = await collection.delete(f"id IN {tuple(data_point_ids)}")
-            return results
+        pass
 
     async def prune(self):
         # Clean up the database if it was set up as temporary
-        self.delete_database()
+        await self.delete_database()
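For reviewers, a rough usage sketch of the rewritten `search()` path. The `documents` collection name and query string are hypothetical; the sketch assumes pgvector is the configured provider and that the collection was already populated via `create_data_points()`:

```python
import asyncio

from cognee.infrastructure.databases.vector import get_vector_engine

async def demo():
    # get_vector_engine() returns the PGVectorAdapter when pgvector is configured.
    engine = get_vector_engine()

    results = await engine.search(
        collection_name = "documents",  # hypothetical collection
        query_text = "example query",
        limit = 3,
    )

    for result in results:
        # Until the normalization TODO is resolved, score is the raw cosine
        # distance: lower values mean closer matches.
        print(result.id, result.score, result.payload)

asyncio.run(demo())
```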