From 6a5df922c5a6558ab3c7d2ac72f01851642cfbc0 Mon Sep 17 00:00:00 2001
From: Igor Ilic
Date: Tue, 5 Aug 2025 22:04:36 +0200
Subject: [PATCH] fix: Fix issue with BAML client registry

---
 .env.template                                |  1 +
 .github/workflows/basic_tests.yml            |  4 ++-
 cognee/infrastructure/llm/config.py          |  6 ++--
 .../baml_src/extraction/extract_summary.py   | 31 ++-----------------
 .../knowledge_graph/extract_content_graph.py | 18 +----------
 5 files changed, 12 insertions(+), 48 deletions(-)

diff --git a/.env.template b/.env.template
index 0beba7d1c..f9e9bdf6e 100644
--- a/.env.template
+++ b/.env.template
@@ -36,6 +36,7 @@ BAML_LLM_PROVIDER=openai
 BAML_LLM_MODEL="gpt-4o-mini"
 BAML_LLM_ENDPOINT=""
 BAML_LLM_API_KEY="your_api_key"
+BAML_LLM_API_VERSION=""

 ################################################################################
 # 🗄️ Relational database settings
diff --git a/.github/workflows/basic_tests.yml b/.github/workflows/basic_tests.yml
index 099044d71..28c708ddb 100644
--- a/.github/workflows/basic_tests.yml
+++ b/.github/workflows/basic_tests.yml
@@ -151,15 +151,17 @@ jobs:
     runs-on: ubuntu-22.04
     env:
       STRUCTURED_OUTPUT_FRAMEWORK: "BAML"
-      BAML_LLM_PROVIDER: openai-generic
+      BAML_LLM_PROVIDER: azure-openai
       BAML_LLM_MODEL: ${{ secrets.LLM_MODEL }}
       BAML_LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
       BAML_LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
+      BAML_LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}

       LLM_PROVIDER: openai
       LLM_MODEL: ${{ secrets.LLM_MODEL }}
       LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
       LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
+      LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}

       EMBEDDING_PROVIDER: openai
       EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
diff --git a/cognee/infrastructure/llm/config.py b/cognee/infrastructure/llm/config.py
index 54153ed38..f31aada33 100644
--- a/cognee/infrastructure/llm/config.py
+++ b/cognee/infrastructure/llm/config.py
@@ -48,6 +48,7 @@ class LLMConfig(BaseSettings):
     baml_llm_endpoint: str = ""
     baml_llm_api_key: Optional[str] = None
     baml_llm_temperature: float = 0.0
+    baml_llm_api_version: str = ""

     transcription_model: str = "whisper-1"
     graph_prompt_path: str = "generate_graph_prompt.txt"
@@ -75,11 +76,12 @@
                 "model": self.baml_llm_model,
                 "temperature": self.baml_llm_temperature,
                 "api_key": self.baml_llm_api_key,
-                "endpoint": self.baml_llm_endpoint,
+                "base_url": self.baml_llm_endpoint,
+                "api_version": self.baml_llm_api_version,
             },
         )
         # Sets the primary client
-        self.baml_registry.set_primary(self.llm_provider)
+        self.baml_registry.set_primary(self.baml_llm_provider)

     @model_validator(mode="after")
     def ensure_env_vars_for_ollama(self) -> "LLMConfig":
diff --git a/cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/extract_summary.py b/cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/extract_summary.py
index efedb0f02..697a52a45 100644
--- a/cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/extract_summary.py
+++ b/cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/extract_summary.py
@@ -7,8 +7,6 @@
 from cognee.shared.data_models import SummarizedCode
 from cognee.infrastructure.llm.structured_output_framework.baml.baml_client.async_client import b
 from cognee.infrastructure.llm.config import get_llm_config

-config = get_llm_config()
-
 logger = get_logger("extract_summary_baml")

@@ -39,22 +37,9 @@ async def extract_summary(content: str, response_model: Type[BaseModel]):
     """
     config = get_llm_config()

-    baml_registry = ClientRegistry()
-
-    baml_registry.add_llm_client(
-        name="def",
-        provider="openai",
-        options={
-            "model": config.llm_model,
-            "temperature": config.llm_temperature,
-            "api_key": config.llm_api_key,
-        },
-    )
-    baml_registry.set_primary("def")
-
     # Use BAML's SummarizeContent function
     summary_result = await b.SummarizeContent(
-        content, baml_options={"client_registry": baml_registry}
+        content, baml_options={"client_registry": config.baml_registry}
     )

     # Convert BAML result to the expected response model
@@ -92,19 +77,9 @@ async def extract_code_summary(content: str):
     try:
         config = get_llm_config()

-        baml_registry = ClientRegistry()
-
-        baml_registry.add_llm_client(
-            name="def",
-            provider="openai",
-            options={
-                "model": config.llm_model,
-                "temperature": config.llm_temperature,
-                "api_key": config.llm_api_key,
-            },
+        result = await b.SummarizeCode(
+            content, baml_options={"client_registry": config.baml_registry}
         )
-        baml_registry.set_primary("def")
-        result = await b.SummarizeCode(content, baml_options={"client_registry": baml_registry})
     except Exception as e:
         logger.error(
             "Failed to extract code summary with BAML, falling back to mock summary", exc_info=e
diff --git a/cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/knowledge_graph/extract_content_graph.py b/cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/knowledge_graph/extract_content_graph.py
index 89f4ef8b7..d98112434 100644
--- a/cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/knowledge_graph/extract_content_graph.py
+++ b/cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/knowledge_graph/extract_content_graph.py
@@ -1,12 +1,9 @@
-from baml_py import ClientRegistry
 from typing import Type
 from pydantic import BaseModel
 from cognee.infrastructure.llm.config import get_llm_config
 from cognee.shared.logging_utils import get_logger, setup_logging
 from cognee.infrastructure.llm.structured_output_framework.baml.baml_client.async_client import b

-config = get_llm_config()
-

 async def extract_content_graph(
     content: str, response_model: Type[BaseModel], mode: str = "simple"
@@ -16,19 +13,6 @@

     get_logger(level="INFO")

-    baml_registry = ClientRegistry()
-
-    baml_registry.add_llm_client(
-        name="extract_content_client",
-        provider=config.llm_provider,
-        options={
-            "model": config.llm_model,
-            "temperature": config.llm_temperature,
-            "api_key": config.llm_api_key,
-        },
-    )
-    baml_registry.set_primary("extract_content_client")
-
     # if response_model:
     # #     tb = TypeBuilder()
     # #     country = tb.union \
@@ -43,7 +27,7 @@
     # else:

     graph = await b.ExtractContentGraphGeneric(
-        content, mode=mode, baml_options={"client_registry": baml_registry}
+        content, mode=mode, baml_options={"client_registry": config.baml_registry}
    )

     return graph
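
Reviewer sketch: to make the new wiring concrete, below is a minimal, self-contained Python sketch of the shared-registry pattern this commit lands on. Only ClientRegistry, add_llm_client, set_primary, and the "client_registry" key in baml_options are taken from the diff; the helper name build_baml_registry and all default values are illustrative assumptions, not code from the repository.

from baml_py import ClientRegistry


def build_baml_registry(
    provider: str = "azure-openai",  # mirrors BAML_LLM_PROVIDER
    model: str = "gpt-4o-mini",  # mirrors BAML_LLM_MODEL
    api_key: str = "your_api_key",  # mirrors BAML_LLM_API_KEY
    endpoint: str = "https://example.openai.azure.com",  # hypothetical BAML_LLM_ENDPOINT
    api_version: str = "2024-02-01",  # hypothetical BAML_LLM_API_VERSION
) -> ClientRegistry:
    # Build the registry once, the way LLMConfig now does, instead of
    # re-creating a throwaway ClientRegistry inside every extraction function.
    registry = ClientRegistry()
    registry.add_llm_client(
        # set_primary(self.baml_llm_provider) in config.py implies the
        # client is registered under the provider name.
        name=provider,
        provider=provider,
        options={
            "model": model,
            "temperature": 0.0,
            "api_key": api_key,
            # The patch replaces the "endpoint" key with "base_url" and adds
            # "api_version", the options the azure-openai provider reads.
            "base_url": endpoint,
            "api_version": api_version,
        },
    )
    registry.set_primary(provider)
    return registry


# Call sites then share the one registry, mirroring the patched files:
#     await b.SummarizeContent(content, baml_options={"client_registry": registry})
#     await b.ExtractContentGraphGeneric(content, mode=mode, baml_options={"client_registry": registry})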