From 4f354ba534e0dbf78b45744a4beba53e84805031 Mon Sep 17 00:00:00 2001 From: Igor Ilic <30923996+dexters1@users.noreply.github.com> Date: Mon, 24 Feb 2025 20:35:40 +0100 Subject: [PATCH] fix: reuse PostgreSQL database connections (#574) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Description Fix PostgreSQL database connection problems ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin ## Summary by CodeRabbit - **Refactor** - Improved the system’s database connection process to enhance compatibility across multiple relational databases. The application now dynamically selects the optimal connection method—reusing established connections when possible—to ensure improved stability and performance without affecting the public interface. - Streamlined the creation of the embedding engine by removing it as a parameter and generating it internally. - Removed dependency on the embedding engine in the vector engine retrieval process. --- .../databases/vector/create_vector_engine.py | 5 ++++- .../databases/vector/get_vector_engine.py | 3 +-- .../databases/vector/pgvector/PGVectorAdapter.py | 14 ++++++++++++-- 3 files changed, 17 insertions(+), 5 deletions(-) diff --git a/cognee/infrastructure/databases/vector/create_vector_engine.py b/cognee/infrastructure/databases/vector/create_vector_engine.py index 34e08d156..f5d458061 100644 --- a/cognee/infrastructure/databases/vector/create_vector_engine.py +++ b/cognee/infrastructure/databases/vector/create_vector_engine.py @@ -1,14 +1,17 @@ +from .embeddings import get_embedding_engine + from functools import lru_cache @lru_cache def create_vector_engine( - embedding_engine, vector_db_url: str, vector_db_port: str, vector_db_key: str, vector_db_provider: str, ): + embedding_engine = get_embedding_engine() + if vector_db_provider == "weaviate": from .weaviate_db import WeaviateAdapter diff --git a/cognee/infrastructure/databases/vector/get_vector_engine.py b/cognee/infrastructure/databases/vector/get_vector_engine.py index 280a55eee..1523f334c 100644 --- a/cognee/infrastructure/databases/vector/get_vector_engine.py +++ b/cognee/infrastructure/databases/vector/get_vector_engine.py @@ -1,7 +1,6 @@ from .config import get_vectordb_config -from .embeddings import get_embedding_engine from .create_vector_engine import create_vector_engine def get_vector_engine(): - return create_vector_engine(get_embedding_engine(), **get_vectordb_config().to_dict()) + return create_vector_engine(**get_vectordb_config().to_dict()) diff --git a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py index 2426788cc..789ed2e84 100644 --- a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py +++ b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py @@ -10,6 +10,7 @@ from cognee.exceptions import InvalidValueError from cognee.infrastructure.databases.exceptions import EntityNotFoundError from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine.utils import parse_id +from cognee.infrastructure.databases.relational import get_relational_engine from ...relational.ModelBase import Base from ...relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter @@ -36,8 +37,17 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): self.api_key = api_key self.embedding_engine = embedding_engine self.db_uri: str = connection_string - self.engine = create_async_engine(self.db_uri) - self.sessionmaker = async_sessionmaker(bind=self.engine, expire_on_commit=False) + + relational_db = get_relational_engine() + + # If postgreSQL is used we must use the same engine and sessionmaker + if relational_db.engine.dialect.name == "postgresql": + self.engine = relational_db.engine + self.sessionmaker = relational_db.sessionmaker + else: + # If not create new instances of engine and sessionmaker + self.engine = create_async_engine(self.db_uri) + self.sessionmaker = async_sessionmaker(bind=self.engine, expire_on_commit=False) # Has to be imported at class level # Functions reading tables from database need to know what a Vector column type is