fix: unwrap connections in PGVectorAdapter
This commit is contained in:
parent
40bb4bc37f
commit
bc17759c04
3 changed files with 14 additions and 25 deletions
|
|
@ -79,15 +79,10 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
async def create_data_points(
|
||||
self, collection_name: str, data_points: List[DataPoint]
|
||||
):
|
||||
async with self.get_async_session() as session:
|
||||
if not await self.has_collection(collection_name):
|
||||
await self.create_collection(
|
||||
collection_name=collection_name,
|
||||
payload_schema=type(data_points[0]),
|
||||
)
|
||||
|
||||
data_vectors = await self.embed_data(
|
||||
[data_point.get_embeddable_data() for data_point in data_points]
|
||||
if not await self.has_collection(collection_name):
|
||||
await self.create_collection(
|
||||
collection_name = collection_name,
|
||||
payload_schema = type(data_points[0]),
|
||||
)
|
||||
|
||||
data_vectors = await self.embed_data(
|
||||
|
|
@ -107,14 +102,10 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
payload = Column(JSON)
|
||||
vector = Column(Vector(vector_size))
|
||||
|
||||
pgvector_data_points = [
|
||||
PGVectorDataPoint(
|
||||
id=data_point.id,
|
||||
vector=data_vectors[data_index],
|
||||
payload=serialize_data(data_point.model_dump()),
|
||||
)
|
||||
for (data_index, data_point) in enumerate(data_points)
|
||||
]
|
||||
def __init__(self, id, payload, vector):
|
||||
self.id = id
|
||||
self.payload = payload
|
||||
self.vector = vector
|
||||
|
||||
pgvector_data_points = [
|
||||
PGVectorDataPoint(
|
||||
|
|
@ -136,7 +127,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
await self.create_data_points(f"{index_name}_{index_property_name}", [
|
||||
IndexSchema(
|
||||
id = data_point.id,
|
||||
text = getattr(data_point, data_point._metadata["index_fields"][0]),
|
||||
text = data_point.get_embeddable_data(),
|
||||
) for data_point in data_points
|
||||
])
|
||||
|
||||
|
|
@ -188,8 +179,6 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
# Get PGVectorDataPoint Table from database
|
||||
PGVectorDataPoint = await self.get_table(collection_name)
|
||||
|
||||
closest_items = []
|
||||
|
||||
# Use async session to connect to the database
|
||||
async with self.get_async_session() as session:
|
||||
# Find closest vectors to query_vector
|
||||
|
|
|
|||
|
|
@ -65,7 +65,7 @@ async def main():
|
|||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
|
||||
vector_engine = get_vector_engine()
|
||||
random_node = (await vector_engine.search("Entity_name", "AI"))[0]
|
||||
random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
|
||||
random_node_name = random_node.payload["text"]
|
||||
|
||||
search_results = await cognee.search(SearchType.INSIGHTS, query=random_node_name)
|
||||
|
|
|
|||
|
|
@ -67,10 +67,6 @@ anthropic = "^0.26.1"
|
|||
sentry-sdk = {extras = ["fastapi"], version = "^2.9.0"}
|
||||
fastapi-users = {version = "*", extras = ["sqlalchemy"]}
|
||||
alembic = "^1.13.3"
|
||||
asyncpg = "^0.29.0"
|
||||
pgvector = "^0.3.5"
|
||||
psycopg2 = {version = "^2.9.10", optional = true}
|
||||
falkordb = "^1.0.9"
|
||||
|
||||
[tool.poetry.extras]
|
||||
filesystem = ["s3fs", "botocore"]
|
||||
|
|
@ -81,6 +77,10 @@ neo4j = ["neo4j"]
|
|||
postgres = ["psycopg2", "pgvector", "asyncpg"]
|
||||
notebook = ["ipykernel", "overrides", "ipywidgets", "jupyterlab", "jupyterlab_widgets", "jupyterlab-server", "jupyterlab-git"]
|
||||
|
||||
[tool.poetry.group.postgres.dependencies]
|
||||
asyncpg = "^0.29.0"
|
||||
pgvector = "^0.3.5"
|
||||
psycopg2 = "^2.9.10"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
pytest = "^7.4.0"
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue