fix: unwrap connections in PGVectorAdapter

This commit is contained in:
Boris Arzentar 2024-11-11 17:29:13 +01:00 committed by Leon Luithlen
parent 40bb4bc37f
commit bc17759c04
3 changed files with 14 additions and 25 deletions

View file

@@ -79,15 +79,10 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
async def create_data_points(
self, collection_name: str, data_points: List[DataPoint]
):
async with self.get_async_session() as session:
if not await self.has_collection(collection_name):
await self.create_collection(
collection_name=collection_name,
payload_schema=type(data_points[0]),
)
data_vectors = await self.embed_data(
[data_point.get_embeddable_data() for data_point in data_points]
if not await self.has_collection(collection_name):
await self.create_collection(
collection_name = collection_name,
payload_schema = type(data_points[0]),
)
data_vectors = await self.embed_data(
@@ -107,14 +102,10 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
payload = Column(JSON)
vector = Column(Vector(vector_size))
pgvector_data_points = [
PGVectorDataPoint(
id=data_point.id,
vector=data_vectors[data_index],
payload=serialize_data(data_point.model_dump()),
)
for (data_index, data_point) in enumerate(data_points)
]
def __init__(self, id, payload, vector):
self.id = id
self.payload = payload
self.vector = vector
pgvector_data_points = [
PGVectorDataPoint(
@@ -136,7 +127,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
await self.create_data_points(f"{index_name}_{index_property_name}", [
IndexSchema(
id = data_point.id,
text = getattr(data_point, data_point._metadata["index_fields"][0]),
text = data_point.get_embeddable_data(),
) for data_point in data_points
])
@@ -188,8 +179,6 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
# Get PGVectorDataPoint Table from database
PGVectorDataPoint = await self.get_table(collection_name)
closest_items = []
# Use async session to connect to the database
async with self.get_async_session() as session:
# Find closest vectors to query_vector

View file

@@ -65,7 +65,7 @@ async def main():
from cognee.infrastructure.databases.vector import get_vector_engine
vector_engine = get_vector_engine()
random_node = (await vector_engine.search("Entity_name", "AI"))[0]
random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
random_node_name = random_node.payload["text"]
search_results = await cognee.search(SearchType.INSIGHTS, query=random_node_name)

View file

@@ -67,10 +67,6 @@ anthropic = "^0.26.1"
sentry-sdk = {extras = ["fastapi"], version = "^2.9.0"}
fastapi-users = {version = "*", extras = ["sqlalchemy"]}
alembic = "^1.13.3"
asyncpg = "^0.29.0"
pgvector = "^0.3.5"
psycopg2 = {version = "^2.9.10", optional = true}
falkordb = "^1.0.9"
[tool.poetry.extras]
filesystem = ["s3fs", "botocore"]
@@ -81,6 +77,10 @@ neo4j = ["neo4j"]
postgres = ["psycopg2", "pgvector", "asyncpg"]
notebook = ["ipykernel", "overrides", "ipywidgets", "jupyterlab", "jupyterlab_widgets", "jupyterlab-server", "jupyterlab-git"]
[tool.poetry.group.postgres.dependencies]
asyncpg = "^0.29.0"
pgvector = "^0.3.5"
psycopg2 = "^2.9.10"
[tool.poetry.group.dev.dependencies]
pytest = "^7.4.0"