refactor: Move postgres handling to database creation time
This commit is contained in:
parent
d2d0d0de4e
commit
ca2e63bd84
4 changed files with 14 additions and 26 deletions
|
|
@ -66,7 +66,12 @@ def create_vector_engine(
|
|||
f"postgresql+asyncpg://{db_username}:{db_password}@{db_host}:{db_port}/{db_name}"
|
||||
)
|
||||
|
||||
from .pgvector.PGVectorAdapter import PGVectorAdapter
|
||||
try:
|
||||
from .pgvector.PGVectorAdapter import PGVectorAdapter
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"PostgreSQL dependencies are not installed. Please install with 'pip install cognee[postgres]' or 'pip install cognee[postgres-binary]' to use PGVector functionality."
|
||||
)
|
||||
|
||||
return PGVectorAdapter(
|
||||
connection_string,
|
||||
|
|
|
|||
|
|
@ -7,20 +7,7 @@ from sqlalchemy import JSON, Column, Table, select, delete, MetaData
|
|||
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
|
||||
from sqlalchemy.exc import ProgrammingError
|
||||
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
|
||||
|
||||
try:
|
||||
from asyncpg import DeadlockDetectedError, DuplicateTableError, UniqueViolationError
|
||||
except ImportError:
|
||||
# PostgreSQL dependencies not installed, define dummy exceptions
|
||||
class DeadlockDetectedError(Exception):
|
||||
pass
|
||||
|
||||
class DuplicateTableError(Exception):
|
||||
pass
|
||||
|
||||
class UniqueViolationError(Exception):
|
||||
pass
|
||||
|
||||
from asyncpg import DeadlockDetectedError, DuplicateTableError, UniqueViolationError
|
||||
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
|
|
@ -82,14 +69,9 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
|
||||
# Has to be imported at class level
|
||||
# Functions reading tables from database need to know what a Vector column type is
|
||||
try:
|
||||
from pgvector.sqlalchemy import Vector
|
||||
from pgvector.sqlalchemy import Vector
|
||||
|
||||
self.Vector = Vector
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"PostgreSQL dependencies are not installed. Please install with 'pip install cognee[postgres]' or 'pip install cognee[postgres-binary]' to use PGVector functionality."
|
||||
)
|
||||
self.Vector = Vector
|
||||
|
||||
async def embed_data(self, data: list[str]) -> list[list[float]]:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -76,14 +76,14 @@ class LLMConfig(BaseSettings):
|
|||
def model_post_init(self, __context) -> None:
|
||||
"""Initialize the BAML registry after the model is created."""
|
||||
# Check if BAML is selected as structured output framework but not available
|
||||
if self.structured_output_framework == "baml" and ClientRegistry is None:
|
||||
if self.structured_output_framework.lower() == "baml" and ClientRegistry is None:
|
||||
raise ImportError(
|
||||
"BAML is selected as structured output framework but not available. "
|
||||
"Please install with 'pip install cognee[baml]' to use BAML extraction features."
|
||||
)
|
||||
elif self.structured_output_framework.lower() == "baml" and ClientRegistry is not None:
|
||||
self.baml_registry = ClientRegistry()
|
||||
|
||||
if ClientRegistry is not None:
|
||||
LLMConfig.baml_registry = ClientRegistry()
|
||||
raw_options = {
|
||||
"model": self.baml_llm_model,
|
||||
"temperature": self.baml_llm_temperature,
|
||||
|
|
|
|||
|
|
@ -53,7 +53,8 @@ async def acreate_structured_output(
|
|||
|
||||
# Transform BAML response to proper pydantic reponse model
|
||||
if response_model is str:
|
||||
return str(result)
|
||||
# Note: when a response model is set to string in python result is stored in text property in the BAML response model
|
||||
return str(result.text)
|
||||
return response_model.model_validate(result.dict())
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue