From bc17759c04b66ab8fded9f6c75881b5976a49690 Mon Sep 17 00:00:00 2001 From: Boris Arzentar Date: Mon, 11 Nov 2024 17:29:13 +0100 Subject: [PATCH] fix: unwrap connections in PGVectorAdapter --- .../vector/pgvector/PGVectorAdapter.py | 29 ++++++------------- cognee/tests/test_pgvector.py | 2 +- pyproject.toml | 8 ++--- 3 files changed, 14 insertions(+), 25 deletions(-) diff --git a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py index 025d361bd..84a32e3e2 100644 --- a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py +++ b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py @@ -79,15 +79,10 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): async def create_data_points( self, collection_name: str, data_points: List[DataPoint] ): - async with self.get_async_session() as session: - if not await self.has_collection(collection_name): - await self.create_collection( - collection_name=collection_name, - payload_schema=type(data_points[0]), - ) - - data_vectors = await self.embed_data( - [data_point.get_embeddable_data() for data_point in data_points] + if not await self.has_collection(collection_name): + await self.create_collection( + collection_name = collection_name, + payload_schema = type(data_points[0]), ) data_vectors = await self.embed_data( @@ -107,14 +102,10 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): payload = Column(JSON) vector = Column(Vector(vector_size)) - pgvector_data_points = [ - PGVectorDataPoint( - id=data_point.id, - vector=data_vectors[data_index], - payload=serialize_data(data_point.model_dump()), - ) - for (data_index, data_point) in enumerate(data_points) - ] + def __init__(self, id, payload, vector): + self.id = id + self.payload = payload + self.vector = vector pgvector_data_points = [ PGVectorDataPoint( @@ -136,7 +127,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): await self.create_data_points(f"{index_name}_{index_property_name}", [ IndexSchema( id = data_point.id, - text = getattr(data_point, data_point._metadata["index_fields"][0]), + text = data_point.get_embeddable_data(), ) for data_point in data_points ]) @@ -188,8 +179,6 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): # Get PGVectorDataPoint Table from database PGVectorDataPoint = await self.get_table(collection_name) - closest_items = [] - # Use async session to connect to the database async with self.get_async_session() as session: # Find closest vectors to query_vector diff --git a/cognee/tests/test_pgvector.py b/cognee/tests/test_pgvector.py index cea7c8f72..ac4d08fbb 100644 --- a/cognee/tests/test_pgvector.py +++ b/cognee/tests/test_pgvector.py @@ -65,7 +65,7 @@ async def main(): from cognee.infrastructure.databases.vector import get_vector_engine vector_engine = get_vector_engine() - random_node = (await vector_engine.search("Entity_name", "AI"))[0] + random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0] random_node_name = random_node.payload["text"] search_results = await cognee.search(SearchType.INSIGHTS, query=random_node_name) diff --git a/pyproject.toml b/pyproject.toml index 28529b446..c7363d4a1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,10 +67,6 @@ anthropic = "^0.26.1" sentry-sdk = {extras = ["fastapi"], version = "^2.9.0"} fastapi-users = {version = "*", extras = ["sqlalchemy"]} alembic = "^1.13.3" -asyncpg = "^0.29.0" -pgvector = "^0.3.5" -psycopg2 = {version = "^2.9.10", optional = true} -falkordb = "^1.0.9" [tool.poetry.extras] filesystem = ["s3fs", "botocore"] @@ -81,6 +77,10 @@ neo4j = ["neo4j"] postgres = ["psycopg2", "pgvector", "asyncpg"] notebook = ["ipykernel", "overrides", "ipywidgets", "jupyterlab", "jupyterlab_widgets", "jupyterlab-server", "jupyterlab-git"] +[tool.poetry.group.postgres.dependencies] +asyncpg = "^0.29.0" +pgvector = "^0.3.5" +psycopg2 = "^2.9.10" [tool.poetry.group.dev.dependencies] pytest = "^7.4.0"