feat: Add PGVector support
Added the first working iteration of PGVector for cognee. Some important functionality is still missing, but the core is there. Some refactoring will also be necessary. Feature: #COG-170
This commit is contained in:
parent
268396abdc
commit
9fbf2d857f
6 changed files with 145 additions and 92 deletions
|
|
@ -14,9 +14,12 @@ from cognee.tasks.ingestion import get_dlt_destination
|
||||||
from cognee.modules.users.permissions.methods import give_permission_on_document
|
from cognee.modules.users.permissions.methods import give_permission_on_document
|
||||||
from cognee.modules.users.models import User
|
from cognee.modules.users.models import User
|
||||||
from cognee.modules.data.methods import create_dataset
|
from cognee.modules.data.methods import create_dataset
|
||||||
|
from cognee.infrastructure.databases.relational import create_db_and_tables as create_relational_db_and_tables
|
||||||
|
from cognee.infrastructure.databases.vector import create_db_and_tables as create_vector_db_and_tables
|
||||||
|
|
||||||
async def add(data: Union[BinaryIO, List[BinaryIO], str, List[str]], dataset_name: str = "main_dataset", user: User = None):
|
async def add(data: Union[BinaryIO, List[BinaryIO], str, List[str]], dataset_name: str = "main_dataset", user: User = None):
|
||||||
await create_db_and_tables()
|
await create_relational_db_and_tables()
|
||||||
|
await create_vector_db_and_tables()
|
||||||
|
|
||||||
if isinstance(data, str):
|
if isinstance(data, str):
|
||||||
if "data://" in data:
|
if "data://" in data:
|
||||||
|
|
|
||||||
|
|
@ -3,10 +3,12 @@ from cognee.modules.users.models import User
|
||||||
from cognee.modules.users.methods import get_default_user
|
from cognee.modules.users.methods import get_default_user
|
||||||
from cognee.modules.pipelines import run_tasks, Task
|
from cognee.modules.pipelines import run_tasks, Task
|
||||||
from cognee.tasks.ingestion import save_data_to_storage, ingest_data
|
from cognee.tasks.ingestion import save_data_to_storage, ingest_data
|
||||||
from cognee.infrastructure.databases.relational import create_db_and_tables
|
from cognee.infrastructure.databases.relational import create_db_and_tables as create_relational_db_and_tables
|
||||||
|
from cognee.infrastructure.databases.vector import create_db_and_tables as create_vector_db_and_tables
|
||||||
|
|
||||||
async def add(data: Union[BinaryIO, list[BinaryIO], str, list[str]], dataset_name: str = "main_dataset", user: User = None):
|
async def add(data: Union[BinaryIO, list[BinaryIO], str, list[str]], dataset_name: str = "main_dataset", user: User = None):
|
||||||
await create_db_and_tables()
|
await create_relational_db_and_tables()
|
||||||
|
await create_vector_db_and_tables()
|
||||||
|
|
||||||
if user is None:
|
if user is None:
|
||||||
user = await get_default_user()
|
user = await get_default_user()
|
||||||
|
|
|
||||||
|
|
@ -4,3 +4,4 @@ from .models.CollectionConfig import CollectionConfig
|
||||||
from .vector_db_interface import VectorDBInterface
|
from .vector_db_interface import VectorDBInterface
|
||||||
from .config import get_vectordb_config
|
from .config import get_vectordb_config
|
||||||
from .get_vector_engine import get_vector_engine
|
from .get_vector_engine import get_vector_engine
|
||||||
|
from .create_db_and_tables import create_db_and_tables
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,15 @@
|
||||||
|
from ..relational.ModelBase import Base
|
||||||
|
from .get_vector_engine import get_vector_engine, get_vectordb_config
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
async def create_db_and_tables():
|
||||||
|
vector_config = get_vectordb_config()
|
||||||
|
vector_engine = get_vector_engine()
|
||||||
|
|
||||||
|
if vector_config.vector_engine_provider == "pgvector":
|
||||||
|
async with vector_engine.engine.begin() as connection:
|
||||||
|
if len(Base.metadata.tables.keys()) > 0:
|
||||||
|
await connection.run_sync(Base.metadata.create_all)
|
||||||
|
await connection.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -42,7 +42,8 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
|
||||||
# Get name of vector database
|
# Get name of vector database
|
||||||
db_name = config["vector_db_name"]
|
db_name = config["vector_db_name"]
|
||||||
|
|
||||||
connection_string = f"postgresql+asyncpg://{db_username}:{db_password}@{db_host}:{db_port}/{db_name}"
|
connection_string: str = f"postgresql+asyncpg://{db_username}:{db_password}@{db_host}:{db_port}/{db_name}"
|
||||||
|
|
||||||
return PGVectorAdapter(connection_string,
|
return PGVectorAdapter(connection_string,
|
||||||
config["vector_db_key"],
|
config["vector_db_key"],
|
||||||
embedding_engine
|
embedding_engine
|
||||||
|
|
|
||||||
|
|
@ -1,27 +1,35 @@
|
||||||
from typing import List, Optional, get_type_hints, Generic, TypeVar
|
from typing import List, Optional, get_type_hints, Any, Dict
|
||||||
import asyncio
|
from sqlalchemy import text, select
|
||||||
|
from sqlalchemy import JSON, Column, Table
|
||||||
|
from sqlalchemy.dialects.postgresql import ARRAY
|
||||||
from ..models.ScoredResult import ScoredResult
|
from ..models.ScoredResult import ScoredResult
|
||||||
|
|
||||||
from ..vector_db_interface import VectorDBInterface, DataPoint
|
from ..vector_db_interface import VectorDBInterface, DataPoint
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
||||||
|
|
||||||
from sqlalchemy.orm import DeclarativeBase, mapped_column
|
|
||||||
from pgvector.sqlalchemy import Vector
|
from pgvector.sqlalchemy import Vector
|
||||||
|
|
||||||
from ...relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter
|
from ...relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter
|
||||||
|
from ...relational.ModelBase import Base
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
# Define the models
|
# TODO: Find better location for function
|
||||||
class Base(DeclarativeBase):
|
def serialize_datetime(data):
|
||||||
pass
|
"""Recursively convert datetime objects in dictionaries/lists to ISO format."""
|
||||||
|
if isinstance(data, dict):
|
||||||
|
return {key: serialize_datetime(value) for key, value in data.items()}
|
||||||
|
elif isinstance(data, list):
|
||||||
|
return [serialize_datetime(item) for item in data]
|
||||||
|
elif isinstance(data, datetime):
|
||||||
|
return data.isoformat() # Convert datetime to ISO 8601 string
|
||||||
|
else:
|
||||||
|
return data
|
||||||
|
|
||||||
class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
async def create_vector_extension(self):
|
|
||||||
async with self.get_async_session() as session:
|
|
||||||
await session.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
|
|
||||||
|
|
||||||
def __init__(self, connection_string: str,
|
def __init__(self, connection_string: str,
|
||||||
api_key: Optional[str],
|
api_key: Optional[str],
|
||||||
embedding_engine: EmbeddingEngine
|
embedding_engine: EmbeddingEngine
|
||||||
):
|
):
|
||||||
|
|
@ -29,121 +37,156 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
self.embedding_engine = embedding_engine
|
self.embedding_engine = embedding_engine
|
||||||
self.db_uri: str = connection_string
|
self.db_uri: str = connection_string
|
||||||
|
|
||||||
self.engine = create_async_engine(connection_string)
|
self.engine = create_async_engine(self.db_uri, echo=True)
|
||||||
self.sessionmaker = async_sessionmaker(bind=self.engine, expire_on_commit=False)
|
self.sessionmaker = async_sessionmaker(bind=self.engine, expire_on_commit=False)
|
||||||
self.create_vector_extension()
|
|
||||||
|
|
||||||
async def embed_data(self, data: list[str]) -> list[list[float]]:
|
async def embed_data(self, data: list[str]) -> list[list[float]]:
|
||||||
return await self.embedding_engine.embed_text(data)
|
return await self.embedding_engine.embed_text(data)
|
||||||
|
|
||||||
async def has_collection(self, collection_name: str) -> bool:
|
async def has_collection(self, collection_name: str) -> bool:
|
||||||
async with self.engine.begin() as connection:
|
async with self.engine.begin() as connection:
|
||||||
collection_names = await connection.table_names()
|
#TODO: Switch to using ORM instead of raw query
|
||||||
return collection_name in collection_names
|
result = await connection.execute(
|
||||||
|
text("SELECT table_name FROM information_schema.tables WHERE table_schema = 'public';")
|
||||||
|
)
|
||||||
|
tables = result.fetchall()
|
||||||
|
for table in tables:
|
||||||
|
if collection_name == table[0]:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
async def create_collection(self, collection_name: str, payload_schema = None):
|
async def create_collection(self, collection_name: str, payload_schema = None):
|
||||||
data_point_types = get_type_hints(DataPoint)
|
data_point_types = get_type_hints(DataPoint)
|
||||||
vector_size = self.embedding_engine.get_vector_size()
|
vector_size = self.embedding_engine.get_vector_size()
|
||||||
|
|
||||||
class PGVectorDataPoint(Base):
|
|
||||||
id: Mapped[int] = mapped_column(data_point_types["id"], primary_key=True)
|
|
||||||
vector = mapped_column(Vector(vector_size))
|
|
||||||
payload: mapped_column(payload_schema)
|
|
||||||
|
|
||||||
if not await self.has_collection(collection_name):
|
if not await self.has_collection(collection_name):
|
||||||
|
|
||||||
|
class PGVectorDataPoint(Base):
|
||||||
|
__tablename__ = collection_name
|
||||||
|
__table_args__ = {'extend_existing': True}
|
||||||
|
# PGVector requires one column to be the primary key
|
||||||
|
primary_key: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
|
||||||
|
id: Mapped[data_point_types["id"]]
|
||||||
|
payload = Column(JSON)
|
||||||
|
vector = Column(Vector(vector_size))
|
||||||
|
|
||||||
|
def __init__(self, id, payload, vector):
|
||||||
|
self.id = id
|
||||||
|
self.payload = payload
|
||||||
|
self.vector = vector
|
||||||
|
|
||||||
async with self.engine.begin() as connection:
|
async with self.engine.begin() as connection:
|
||||||
return await connection.create_table(
|
if len(Base.metadata.tables.keys()) > 0:
|
||||||
name = collection_name,
|
await connection.run_sync(Base.metadata.create_all, tables=[PGVectorDataPoint.__table__])
|
||||||
schema = PGVectorDataPoint,
|
|
||||||
exist_ok = True,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
|
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
|
||||||
async with self.engine.begin() as connection:
|
async with self.get_async_session() as session:
|
||||||
if not await self.has_collection(collection_name):
|
if not await self.has_collection(collection_name):
|
||||||
await self.create_collection(
|
await self.create_collection(
|
||||||
collection_name,
|
collection_name = collection_name,
|
||||||
payload_schema = type(data_points[0].payload),
|
payload_schema = type(data_points[0].payload),
|
||||||
)
|
)
|
||||||
|
|
||||||
collection = await connection.open_table(collection_name)
|
|
||||||
|
|
||||||
data_vectors = await self.embed_data(
|
data_vectors = await self.embed_data(
|
||||||
[data_point.get_embeddable_data() for data_point in data_points]
|
[data_point.get_embeddable_data() for data_point in data_points]
|
||||||
)
|
)
|
||||||
|
|
||||||
IdType = TypeVar("IdType")
|
|
||||||
PayloadSchema = TypeVar("PayloadSchema")
|
|
||||||
vector_size = self.embedding_engine.get_vector_size()
|
vector_size = self.embedding_engine.get_vector_size()
|
||||||
|
|
||||||
class PGVectorDataPoint(Base, Generic[IdType, PayloadSchema]):
|
class PGVectorDataPoint(Base):
|
||||||
id: Mapped[int] = mapped_column(IdType, primary_key=True)
|
__tablename__ = collection_name
|
||||||
vector = mapped_column(Vector(vector_size))
|
__table_args__ = {'extend_existing': True}
|
||||||
payload: mapped_column(PayloadSchema)
|
# PGVector requires one column to be the primary key
|
||||||
|
primary_key: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
|
||||||
|
id: Mapped[type(data_points[0].id)]
|
||||||
|
payload = Column(JSON)
|
||||||
|
vector = Column(Vector(vector_size))
|
||||||
|
|
||||||
|
def __init__(self, id, payload, vector):
|
||||||
|
self.id = id
|
||||||
|
self.payload = payload
|
||||||
|
self.vector = vector
|
||||||
|
|
||||||
pgvector_data_points = [
|
pgvector_data_points = [
|
||||||
PGVectorDataPoint[type(data_point.id), type(data_point.payload)](
|
PGVectorDataPoint(
|
||||||
id = data_point.id,
|
id = data_point.id,
|
||||||
vector = data_vectors[data_index],
|
vector = data_vectors[data_index],
|
||||||
payload = data_point.payload,
|
payload = serialize_datetime(data_point.payload.dict())
|
||||||
) for (data_index, data_point) in enumerate(data_points)
|
) for (data_index, data_point) in enumerate(data_points)
|
||||||
]
|
]
|
||||||
|
|
||||||
await collection.add(pgvector_data_points)
|
session.add_all(pgvector_data_points)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
async def retrieve(self, collection_name: str, data_point_ids: list[str]):
|
async def retrieve(self, collection_name: str, data_point_ids: List[str]):
|
||||||
async with self.engine.begin() as connection:
|
async with AsyncSession(self.engine) as session:
|
||||||
collection = await connection.open_table(collection_name)
|
try:
|
||||||
|
# Construct the SQL query
|
||||||
|
# TODO: Switch to using ORM instead of raw query
|
||||||
|
if len(data_point_ids) == 1:
|
||||||
|
query = text(f"SELECT * FROM {collection_name} WHERE id = :id")
|
||||||
|
result = await session.execute(query, {"id": data_point_ids[0]})
|
||||||
|
else:
|
||||||
|
query = text(f"SELECT * FROM {collection_name} WHERE id = ANY(:ids)")
|
||||||
|
result = await session.execute(query, {"ids": data_point_ids})
|
||||||
|
|
||||||
if len(data_point_ids) == 1:
|
# Fetch all rows
|
||||||
results = await collection.query().where(f"id = '{data_point_ids[0]}'").to_pandas()
|
rows = result.fetchall()
|
||||||
else:
|
|
||||||
results = await collection.query().where(f"id IN {tuple(data_point_ids)}").to_pandas()
|
|
||||||
|
|
||||||
return [ScoredResult(
|
return [
|
||||||
id = result["id"],
|
ScoredResult(
|
||||||
payload = result["payload"],
|
id=row["id"],
|
||||||
score = 0,
|
payload=row["payload"],
|
||||||
) for result in results.to_dict("index").values()]
|
score=0
|
||||||
|
)
|
||||||
|
for row in rows
|
||||||
|
]
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error retrieving data: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
async def search(
|
async def search(
|
||||||
self,
|
self,
|
||||||
collection_name: str,
|
collection_name: str,
|
||||||
query_text: str = None,
|
query_text: Optional[str] = None,
|
||||||
query_vector: List[float] = None,
|
query_vector: Optional[List[float]] = None,
|
||||||
limit: int = 5,
|
limit: int = 5,
|
||||||
with_vector: bool = False,
|
with_vector: bool = False,
|
||||||
):
|
) -> List[ScoredResult]:
|
||||||
|
# Validate inputs
|
||||||
if query_text is None and query_vector is None:
|
if query_text is None and query_vector is None:
|
||||||
raise ValueError("One of query_text or query_vector must be provided!")
|
raise ValueError("One of query_text or query_vector must be provided!")
|
||||||
|
|
||||||
|
# Get the vector for query_text if provided
|
||||||
if query_text and not query_vector:
|
if query_text and not query_vector:
|
||||||
query_vector = (await self.embedding_engine.embed_text([query_text]))[0]
|
query_vector = (await self.embedding_engine.embed_text([query_text]))[0]
|
||||||
|
|
||||||
async with self.engine.begin() as connection:
|
# Use async session to connect to the database
|
||||||
collection = await connection.open_table(collection_name)
|
async with self.get_async_session() as session:
|
||||||
|
try:
|
||||||
|
PGVectorDataPoint = Table(collection_name, Base.metadata, autoload_with=self.engine)
|
||||||
|
|
||||||
results = await collection.vector_search(query_vector).limit(limit).to_pandas()
|
closest_items = await session.execute(select(PGVectorDataPoint, PGVectorDataPoint.c.vector.cosine_distance(query_vector).label('similarity')).order_by(PGVectorDataPoint.c.vector.cosine_distance(query_vector)).limit(limit))
|
||||||
|
|
||||||
result_values = list(results.to_dict("index").values())
|
vector_list = []
|
||||||
|
# Extract distances and find min/max for normalization
|
||||||
|
for vector in closest_items:
|
||||||
|
#TODO: Add normalization of similarity score
|
||||||
|
vector_list.append(vector)
|
||||||
|
|
||||||
min_value = 100
|
# Create and return ScoredResult objects
|
||||||
max_value = 0
|
return [
|
||||||
|
ScoredResult(
|
||||||
|
id=str(row.id),
|
||||||
|
payload=row.payload,
|
||||||
|
score=row.similarity
|
||||||
|
)
|
||||||
|
for row in vector_list
|
||||||
|
]
|
||||||
|
|
||||||
for result in result_values:
|
except Exception as e:
|
||||||
value = float(result["_distance"])
|
print(f"Error during search: {e}")
|
||||||
if value > max_value:
|
return []
|
||||||
max_value = value
|
|
||||||
if value < min_value:
|
|
||||||
min_value = value
|
|
||||||
|
|
||||||
normalized_values = [(result["_distance"] - min_value) / (max_value - min_value) for result in result_values]
|
|
||||||
|
|
||||||
return [ScoredResult(
|
|
||||||
id = str(result["id"]),
|
|
||||||
payload = result["payload"],
|
|
||||||
score = normalized_values[value_index],
|
|
||||||
) for value_index, result in enumerate(result_values)]
|
|
||||||
|
|
||||||
async def batch_search(
|
async def batch_search(
|
||||||
self,
|
self,
|
||||||
|
|
@ -152,23 +195,11 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
limit: int = None,
|
limit: int = None,
|
||||||
with_vectors: bool = False,
|
with_vectors: bool = False,
|
||||||
):
|
):
|
||||||
query_vectors = await self.embedding_engine.embed_text(query_texts)
|
pass
|
||||||
|
|
||||||
return asyncio.gather(
|
|
||||||
*[self.search(
|
|
||||||
collection_name = collection_name,
|
|
||||||
query_vector = query_vector,
|
|
||||||
limit = limit,
|
|
||||||
with_vector = with_vectors,
|
|
||||||
) for query_vector in query_vectors]
|
|
||||||
)
|
|
||||||
|
|
||||||
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
|
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
|
||||||
async with self.engine.begin() as connection:
|
pass
|
||||||
collection = await connection.open_table(collection_name)
|
|
||||||
results = await collection.delete(f"id IN {tuple(data_point_ids)}")
|
|
||||||
return results
|
|
||||||
|
|
||||||
async def prune(self):
|
async def prune(self):
|
||||||
# Clean up the database if it was set up as temporary
|
# Clean up the database if it was set up as temporary
|
||||||
self.delete_database()
|
await self.delete_database()
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue