cognee/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py
Boris 64b8aac86f
feat: code graph swe integration
Co-authored-by: hajdul88 <52442977+hajdul88@users.noreply.github.com>
Co-authored-by: hande-k <handekafkas7@gmail.com>
Co-authored-by: Igor Ilic <igorilic03@gmail.com>
Co-authored-by: Vasilije <8619304+Vasilije1990@users.noreply.github.com>
Co-authored-by: Igor Ilic <30923996+dexters1@users.noreply.github.com>
2024-11-27 09:32:29 +01:00

308 lines
11 KiB
Python

import asyncio
from uuid import UUID
from pgvector.sqlalchemy import Vector
from typing import List, Optional, get_type_hints
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy import JSON, Column, Table, select, delete
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
from cognee.infrastructure.engine import DataPoint
from .serialize_data import serialize_data
from ..models.ScoredResult import ScoredResult
from ..vector_db_interface import VectorDBInterface
from ..utils import normalize_distances
from ..embeddings.EmbeddingEngine import EmbeddingEngine
from ...relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter
from ...relational.ModelBase import Base
class IndexSchema(DataPoint):
text: str
_metadata: dict = {
"index_fields": ["text"]
}
def singleton(class_):
# Note: Using this singleton as a decorator to a class removes
# the option to use class methods for that class
instances = {}
def getinstance(*args, **kwargs):
if class_ not in instances:
instances[class_] = class_(*args, **kwargs)
return instances[class_]
return getinstance
@singleton
class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
def __init__(
self,
connection_string: str,
api_key: Optional[str],
embedding_engine: EmbeddingEngine,
):
self.api_key = api_key
self.embedding_engine = embedding_engine
self.db_uri: str = connection_string
self.engine = create_async_engine(self.db_uri)
self.sessionmaker = async_sessionmaker(bind=self.engine, expire_on_commit=False)
async def embed_data(self, data: list[str]) -> list[list[float]]:
return await self.embedding_engine.embed_text(data)
async def has_collection(self, collection_name: str) -> bool:
async with self.engine.begin() as connection:
# Load the schema information into the MetaData object
await connection.run_sync(Base.metadata.reflect)
if collection_name in Base.metadata.tables:
return True
else:
return False
async def create_collection(self, collection_name: str, payload_schema=None):
data_point_types = get_type_hints(DataPoint)
vector_size = self.embedding_engine.get_vector_size()
if not await self.has_collection(collection_name):
class PGVectorDataPoint(Base):
__tablename__ = collection_name
__table_args__ = {"extend_existing": True}
# PGVector requires one column to be the primary key
primary_key: Mapped[int] = mapped_column(
primary_key=True, autoincrement=True
)
id: Mapped[data_point_types["id"]]
payload = Column(JSON)
vector = Column(Vector(vector_size))
def __init__(self, id, payload, vector):
self.id = id
self.payload = payload
self.vector = vector
async with self.engine.begin() as connection:
if len(Base.metadata.tables.keys()) > 0:
await connection.run_sync(
Base.metadata.create_all, tables=[PGVectorDataPoint.__table__]
)
async def create_data_points(
self, collection_name: str, data_points: List[DataPoint]
):
if not await self.has_collection(collection_name):
await self.create_collection(
collection_name = collection_name,
payload_schema = type(data_points[0]),
)
data_vectors = await self.embed_data(
[data_point.get_embeddable_data() for data_point in data_points]
)
vector_size = self.embedding_engine.get_vector_size()
class PGVectorDataPoint(Base):
__tablename__ = collection_name
__table_args__ = {"extend_existing": True}
# PGVector requires one column to be the primary key
primary_key: Mapped[int] = mapped_column(
primary_key=True, autoincrement=True
)
id: Mapped[type(data_points[0].id)]
payload = Column(JSON)
vector = Column(Vector(vector_size))
def __init__(self, id, payload, vector):
self.id = id
self.payload = payload
self.vector = vector
pgvector_data_points = [
PGVectorDataPoint(
id = data_point.id,
vector = data_vectors[data_index],
payload = serialize_data(data_point.model_dump()),
)
for (data_index, data_point) in enumerate(data_points)
]
async with self.get_async_session() as session:
session.add_all(pgvector_data_points)
await session.commit()
async def create_vector_index(self, index_name: str, index_property_name: str):
await self.create_collection(f"{index_name}_{index_property_name}")
async def index_data_points(self, index_name: str, index_property_name: str, data_points: list[DataPoint]):
await self.create_data_points(f"{index_name}_{index_property_name}", [
IndexSchema(
id = data_point.id,
text = data_point.get_embeddable_data(),
) for data_point in data_points
])
async def get_table(self, collection_name: str) -> Table:
"""
Dynamically loads a table using the given collection name
with an async engine.
"""
async with self.engine.begin() as connection:
# Load the schema information into the MetaData object
await connection.run_sync(Base.metadata.reflect)
if collection_name in Base.metadata.tables:
return Base.metadata.tables[collection_name]
else:
raise ValueError(f"Table '{collection_name}' not found.")
async def retrieve(self, collection_name: str, data_point_ids: List[str]):
# Get PGVectorDataPoint Table from database
PGVectorDataPoint = await self.get_table(collection_name)
async with self.get_async_session() as session:
results = await session.execute(
select(PGVectorDataPoint).where(PGVectorDataPoint.c.id.in_(data_point_ids))
)
results = results.all()
return [
ScoredResult(
id = UUID(result.id),
payload = result.payload,
score = 0
) for result in results
]
async def get_distances_of_collection(
self,
collection_name: str,
query_text: str = None,
query_vector: List[float] = None,
with_vector: bool = False
)-> List[ScoredResult]:
if query_text is None and query_vector is None:
raise ValueError("One of query_text or query_vector must be provided!")
if query_text and not query_vector:
query_vector = (await self.embedding_engine.embed_text([query_text]))[0]
# Get PGVectorDataPoint Table from database
PGVectorDataPoint = await self.get_table(collection_name)
closest_items = []
# Use async session to connect to the database
async with self.get_async_session() as session:
# Find closest vectors to query_vector
closest_items = await session.execute(
select(
PGVectorDataPoint,
PGVectorDataPoint.c.vector.cosine_distance(query_vector).label(
"similarity"
),
)
.order_by("similarity")
)
vector_list = []
# Extract distances and find min/max for normalization
for vector in closest_items:
# TODO: Add normalization of similarity score
vector_list.append(vector)
# Create and return ScoredResult objects
return [
ScoredResult(
id = UUID(str(row.id)),
payload = row.payload,
score = row.similarity
) for row in vector_list
]
async def search(
self,
collection_name: str,
query_text: Optional[str] = None,
query_vector: Optional[List[float]] = None,
limit: int = 5,
with_vector: bool = False,
) -> List[ScoredResult]:
if query_text is None and query_vector is None:
raise ValueError("One of query_text or query_vector must be provided!")
if query_text and not query_vector:
query_vector = (await self.embedding_engine.embed_text([query_text]))[0]
# Get PGVectorDataPoint Table from database
PGVectorDataPoint = await self.get_table(collection_name)
closest_items = []
# Use async session to connect to the database
async with self.get_async_session() as session:
# Find closest vectors to query_vector
closest_items = await session.execute(
select(
PGVectorDataPoint,
PGVectorDataPoint.c.vector.cosine_distance(query_vector).label(
"similarity"
),
)
.order_by("similarity")
.limit(limit)
)
vector_list = []
# Extract distances and find min/max for normalization
for vector in closest_items:
# TODO: Add normalization of similarity score
vector_list.append(vector)
# Create and return ScoredResult objects
return [
ScoredResult(
id = UUID(str(row.id)),
payload = row.payload,
score = row.similarity
) for row in vector_list
]
async def batch_search(
self,
collection_name: str,
query_texts: List[str],
limit: int = None,
with_vectors: bool = False,
):
query_vectors = await self.embedding_engine.embed_text(query_texts)
return await asyncio.gather(
*[
self.search(
collection_name=collection_name,
query_vector=query_vector,
limit=limit,
with_vector=with_vectors,
)
for query_vector in query_vectors
]
)
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
async with self.get_async_session() as session:
# Get PGVectorDataPoint Table from database
PGVectorDataPoint = await self.get_table(collection_name)
results = await session.execute(
delete(PGVectorDataPoint).where(PGVectorDataPoint.c.id.in_(data_point_ids))
)
await session.commit()
return results
async def prune(self):
# Clean up the database if it was set up as temporary
await self.delete_database()