fix: Fix issue with BAML client registry

Igor Ilic 2025-08-05 22:04:36 +02:00
parent dc988b28bd
commit 6a5df922c5
5 changed files with 12 additions and 48 deletions

View file

@@ -36,6 +36,7 @@ BAML_LLM_PROVIDER=openai
BAML_LLM_MODEL="gpt-4o-mini"
BAML_LLM_ENDPOINT=""
BAML_LLM_API_KEY="your_api_key"
+BAML_LLM_API_VERSION=""
################################################################################
# 🗄️ Relational database settings
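
The new key is read by the pydantic settings class changed further down in this commit. A minimal stand-in sketch of that mapping, using a demo class rather than cognee's actual LLMConfig (pydantic v2's pydantic-settings package assumed):

import os
from pydantic_settings import BaseSettings

# Field names match environment variable names case-insensitively by default,
# so BAML_LLM_API_VERSION fills baml_llm_api_version.
class DemoLLMConfig(BaseSettings):
    baml_llm_api_version: str = ""

os.environ["BAML_LLM_API_VERSION"] = "2024-06-01"  # hypothetical value
print(DemoLLMConfig().baml_llm_api_version)  # -> 2024-06-01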

View file

@@ -151,15 +151,17 @@ jobs:
runs-on: ubuntu-22.04
env:
STRUCTURED_OUTPUT_FRAMEWORK: "BAML"
-BAML_LLM_PROVIDER: openai-generic
+BAML_LLM_PROVIDER: azure-openai
BAML_LLM_MODEL: ${{ secrets.LLM_MODEL }}
BAML_LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
BAML_LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
+BAML_LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
LLM_PROVIDER: openai
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
EMBEDDING_PROVIDER: openai
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
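
For context, a small illustrative pre-flight check (not part of this commit) for the variables the job now passes to the azure-openai BAML client; whether each one is strictly required is an assumption here:

import os

# Variable names are the ones the workflow sets; assumed to be needed for the
# azure-openai provider (endpoint and API version on top of key and model).
REQUIRED_FOR_AZURE = [
    "BAML_LLM_MODEL",
    "BAML_LLM_ENDPOINT",
    "BAML_LLM_API_KEY",
    "BAML_LLM_API_VERSION",
]

if os.environ.get("BAML_LLM_PROVIDER") == "azure-openai":
    missing = [name for name in REQUIRED_FOR_AZURE if not os.environ.get(name)]
    if missing:
        raise RuntimeError(f"Missing variables for azure-openai: {missing}")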

View file

@@ -48,6 +48,7 @@ class LLMConfig(BaseSettings):
baml_llm_endpoint: str = ""
baml_llm_api_key: Optional[str] = None
baml_llm_temperature: float = 0.0
+baml_llm_api_version: str = ""
transcription_model: str = "whisper-1"
graph_prompt_path: str = "generate_graph_prompt.txt"
@@ -75,11 +76,12 @@ class LLMConfig(BaseSettings):
"model": self.baml_llm_model,
"temperature": self.baml_llm_temperature,
"api_key": self.baml_llm_api_key,
"endpoint": self.baml_llm_endpoint,
"base_url": self.baml_llm_endpoint,
"api_version": self.baml_llm_api_version,
},
)
# Sets the primary client
-self.baml_registry.set_primary(self.llm_provider)
+self.baml_registry.set_primary(self.baml_llm_provider)
@model_validator(mode="after")
def ensure_env_vars_for_ollama(self) -> "LLMConfig":
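
For reference, a minimal standalone sketch of the registry this hunk builds. The option keys, the provider, and the set_primary change come from the diff; the client name and the literal values are assumptions:

from baml_py import ClientRegistry

registry = ClientRegistry()
registry.add_llm_client(
    # The client name is assumed to equal the provider string, since
    # set_primary is now called with self.baml_llm_provider.
    name="azure-openai",
    provider="azure-openai",
    options={
        "model": "gpt-4o-mini",                          # example value
        "temperature": 0.0,
        "api_key": "your_api_key",                       # placeholder
        "base_url": "https://example.openai.azure.com",  # hypothetical endpoint
        "api_version": "2024-06-01",                     # hypothetical version
    },
)
registry.set_primary("azure-openai")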

View file

@@ -7,8 +7,6 @@ from cognee.shared.data_models import SummarizedCode
from cognee.infrastructure.llm.structured_output_framework.baml.baml_client.async_client import b
from cognee.infrastructure.llm.config import get_llm_config
config = get_llm_config()
logger = get_logger("extract_summary_baml")
@@ -39,22 +37,9 @@ async def extract_summary(content: str, response_model: Type[BaseModel]):
"""
config = get_llm_config()
-baml_registry = ClientRegistry()
-baml_registry.add_llm_client(
-name="def",
-provider="openai",
-options={
-"model": config.llm_model,
-"temperature": config.llm_temperature,
-"api_key": config.llm_api_key,
-},
-)
-baml_registry.set_primary("def")
# Use BAML's SummarizeContent function
summary_result = await b.SummarizeContent(
-content, baml_options={"client_registry": baml_registry}
+content, baml_options={"client_registry": config.baml_registry}
)
# Convert BAML result to the expected response model
@@ -92,19 +77,9 @@ async def extract_code_summary(content: str):
try:
config = get_llm_config()
-baml_registry = ClientRegistry()
-baml_registry.add_llm_client(
-name="def",
-provider="openai",
-options={
-"model": config.llm_model,
-"temperature": config.llm_temperature,
-"api_key": config.llm_api_key,
-},
-)
-baml_registry.set_primary("def")
-result = await b.SummarizeCode(content, baml_options={"client_registry": baml_registry})
+result = await b.SummarizeCode(
+content, baml_options={"client_registry": config.baml_registry}
+)
except Exception as e:
logger.error(
"Failed to extract code summary with BAML, falling back to mock summary", exc_info=e

View file

@@ -1,12 +1,9 @@
from baml_py import ClientRegistry
from typing import Type
from pydantic import BaseModel
from cognee.infrastructure.llm.config import get_llm_config
from cognee.shared.logging_utils import get_logger, setup_logging
from cognee.infrastructure.llm.structured_output_framework.baml.baml_client.async_client import b
config = get_llm_config()
async def extract_content_graph(
content: str, response_model: Type[BaseModel], mode: str = "simple"
@@ -16,19 +13,6 @@ async def extract_content_graph(
get_logger(level="INFO")
-baml_registry = ClientRegistry()
-baml_registry.add_llm_client(
-name="extract_content_client",
-provider=config.llm_provider,
-options={
-"model": config.llm_model,
-"temperature": config.llm_temperature,
-"api_key": config.llm_api_key,
-},
-)
-baml_registry.set_primary("extract_content_client")
# if response_model:
# # tb = TypeBuilder()
# # country = tb.union \
@@ -43,7 +27,7 @@ async def extract_content_graph(
# else:
graph = await b.ExtractContentGraphGeneric(
-content, mode=mode, baml_options={"client_registry": baml_registry}
+content, mode=mode, baml_options={"client_registry": config.baml_registry}
)
return graph
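
The graph extraction path follows the same pattern. A minimal sketch, assuming config.baml_registry is the shared registry assembled in LLMConfig:

from cognee.infrastructure.llm.config import get_llm_config
from cognee.infrastructure.llm.structured_output_framework.baml.baml_client.async_client import b

# Sketch only: the per-call ClientRegistry removed above is replaced by the
# registry the config object already holds.
async def build_graph(content: str, mode: str = "simple"):
    config = get_llm_config()
    return await b.ExtractContentGraphGeneric(
        content, mode=mode, baml_options={"client_registry": config.baml_registry}
    )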