diff --git a/cognee/infrastructure/databases/vector/create_vector_engine.py b/cognee/infrastructure/databases/vector/create_vector_engine.py index 34e08d156..f5d458061 100644 --- a/cognee/infrastructure/databases/vector/create_vector_engine.py +++ b/cognee/infrastructure/databases/vector/create_vector_engine.py @@ -1,14 +1,17 @@ +from .embeddings import get_embedding_engine + from functools import lru_cache @lru_cache def create_vector_engine( - embedding_engine, vector_db_url: str, vector_db_port: str, vector_db_key: str, vector_db_provider: str, ): + embedding_engine = get_embedding_engine() + if vector_db_provider == "weaviate": from .weaviate_db import WeaviateAdapter diff --git a/cognee/infrastructure/databases/vector/get_vector_engine.py b/cognee/infrastructure/databases/vector/get_vector_engine.py index 280a55eee..1523f334c 100644 --- a/cognee/infrastructure/databases/vector/get_vector_engine.py +++ b/cognee/infrastructure/databases/vector/get_vector_engine.py @@ -1,7 +1,6 @@ from .config import get_vectordb_config -from .embeddings import get_embedding_engine from .create_vector_engine import create_vector_engine def get_vector_engine(): - return create_vector_engine(get_embedding_engine(), **get_vectordb_config().to_dict()) + return create_vector_engine(**get_vectordb_config().to_dict()) diff --git a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py index 2426788cc..789ed2e84 100644 --- a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py +++ b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py @@ -10,6 +10,7 @@ from cognee.exceptions import InvalidValueError from cognee.infrastructure.databases.exceptions import EntityNotFoundError from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine.utils import parse_id +from cognee.infrastructure.databases.relational import get_relational_engine from ...relational.ModelBase import Base from ...relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter @@ -36,8 +37,17 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): self.api_key = api_key self.embedding_engine = embedding_engine self.db_uri: str = connection_string - self.engine = create_async_engine(self.db_uri) - self.sessionmaker = async_sessionmaker(bind=self.engine, expire_on_commit=False) + + relational_db = get_relational_engine() + + # If postgreSQL is used we must use the same engine and sessionmaker + if relational_db.engine.dialect.name == "postgresql": + self.engine = relational_db.engine + self.sessionmaker = relational_db.sessionmaker + else: + # If not create new instances of engine and sessionmaker + self.engine = create_async_engine(self.db_uri) + self.sessionmaker = async_sessionmaker(bind=self.engine, expire_on_commit=False) # Has to be imported at class level # Functions reading tables from database need to know what a Vector column type is