fix: unwrap connections in PGVectorAdapter
This commit is contained in:
parent
40bb4bc37f
commit
bc17759c04
3 changed files with 14 additions and 25 deletions
|
|
@ -79,15 +79,10 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
async def create_data_points(
|
async def create_data_points(
|
||||||
self, collection_name: str, data_points: List[DataPoint]
|
self, collection_name: str, data_points: List[DataPoint]
|
||||||
):
|
):
|
||||||
async with self.get_async_session() as session:
|
if not await self.has_collection(collection_name):
|
||||||
if not await self.has_collection(collection_name):
|
await self.create_collection(
|
||||||
await self.create_collection(
|
collection_name = collection_name,
|
||||||
collection_name=collection_name,
|
payload_schema = type(data_points[0]),
|
||||||
payload_schema=type(data_points[0]),
|
|
||||||
)
|
|
||||||
|
|
||||||
data_vectors = await self.embed_data(
|
|
||||||
[data_point.get_embeddable_data() for data_point in data_points]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
data_vectors = await self.embed_data(
|
data_vectors = await self.embed_data(
|
||||||
|
|
@ -107,14 +102,10 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
payload = Column(JSON)
|
payload = Column(JSON)
|
||||||
vector = Column(Vector(vector_size))
|
vector = Column(Vector(vector_size))
|
||||||
|
|
||||||
pgvector_data_points = [
|
def __init__(self, id, payload, vector):
|
||||||
PGVectorDataPoint(
|
self.id = id
|
||||||
id=data_point.id,
|
self.payload = payload
|
||||||
vector=data_vectors[data_index],
|
self.vector = vector
|
||||||
payload=serialize_data(data_point.model_dump()),
|
|
||||||
)
|
|
||||||
for (data_index, data_point) in enumerate(data_points)
|
|
||||||
]
|
|
||||||
|
|
||||||
pgvector_data_points = [
|
pgvector_data_points = [
|
||||||
PGVectorDataPoint(
|
PGVectorDataPoint(
|
||||||
|
|
@ -136,7 +127,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
await self.create_data_points(f"{index_name}_{index_property_name}", [
|
await self.create_data_points(f"{index_name}_{index_property_name}", [
|
||||||
IndexSchema(
|
IndexSchema(
|
||||||
id = data_point.id,
|
id = data_point.id,
|
||||||
text = getattr(data_point, data_point._metadata["index_fields"][0]),
|
text = data_point.get_embeddable_data(),
|
||||||
) for data_point in data_points
|
) for data_point in data_points
|
||||||
])
|
])
|
||||||
|
|
||||||
|
|
@ -188,8 +179,6 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
# Get PGVectorDataPoint Table from database
|
# Get PGVectorDataPoint Table from database
|
||||||
PGVectorDataPoint = await self.get_table(collection_name)
|
PGVectorDataPoint = await self.get_table(collection_name)
|
||||||
|
|
||||||
closest_items = []
|
|
||||||
|
|
||||||
# Use async session to connect to the database
|
# Use async session to connect to the database
|
||||||
async with self.get_async_session() as session:
|
async with self.get_async_session() as session:
|
||||||
# Find closest vectors to query_vector
|
# Find closest vectors to query_vector
|
||||||
|
|
|
||||||
|
|
@ -65,7 +65,7 @@ async def main():
|
||||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
|
|
||||||
vector_engine = get_vector_engine()
|
vector_engine = get_vector_engine()
|
||||||
random_node = (await vector_engine.search("Entity_name", "AI"))[0]
|
random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
|
||||||
random_node_name = random_node.payload["text"]
|
random_node_name = random_node.payload["text"]
|
||||||
|
|
||||||
search_results = await cognee.search(SearchType.INSIGHTS, query=random_node_name)
|
search_results = await cognee.search(SearchType.INSIGHTS, query=random_node_name)
|
||||||
|
|
|
||||||
|
|
@ -67,10 +67,6 @@ anthropic = "^0.26.1"
|
||||||
sentry-sdk = {extras = ["fastapi"], version = "^2.9.0"}
|
sentry-sdk = {extras = ["fastapi"], version = "^2.9.0"}
|
||||||
fastapi-users = {version = "*", extras = ["sqlalchemy"]}
|
fastapi-users = {version = "*", extras = ["sqlalchemy"]}
|
||||||
alembic = "^1.13.3"
|
alembic = "^1.13.3"
|
||||||
asyncpg = "^0.29.0"
|
|
||||||
pgvector = "^0.3.5"
|
|
||||||
psycopg2 = {version = "^2.9.10", optional = true}
|
|
||||||
falkordb = "^1.0.9"
|
|
||||||
|
|
||||||
[tool.poetry.extras]
|
[tool.poetry.extras]
|
||||||
filesystem = ["s3fs", "botocore"]
|
filesystem = ["s3fs", "botocore"]
|
||||||
|
|
@ -81,6 +77,10 @@ neo4j = ["neo4j"]
|
||||||
postgres = ["psycopg2", "pgvector", "asyncpg"]
|
postgres = ["psycopg2", "pgvector", "asyncpg"]
|
||||||
notebook = ["ipykernel", "overrides", "ipywidgets", "jupyterlab", "jupyterlab_widgets", "jupyterlab-server", "jupyterlab-git"]
|
notebook = ["ipykernel", "overrides", "ipywidgets", "jupyterlab", "jupyterlab_widgets", "jupyterlab-server", "jupyterlab-git"]
|
||||||
|
|
||||||
|
[tool.poetry.group.postgres.dependencies]
|
||||||
|
asyncpg = "^0.29.0"
|
||||||
|
pgvector = "^0.3.5"
|
||||||
|
psycopg2 = "^2.9.10"
|
||||||
|
|
||||||
[tool.poetry.group.dev.dependencies]
|
[tool.poetry.group.dev.dependencies]
|
||||||
pytest = "^7.4.0"
|
pytest = "^7.4.0"
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue