fix: Fix issue with BAML client registry
This commit is contained in:
parent
dc988b28bd
commit
6a5df922c5
5 changed files with 12 additions and 48 deletions
|
|
@ -36,6 +36,7 @@ BAML_LLM_PROVIDER=openai
|
|||
BAML_LLM_MODEL="gpt-4o-mini"
|
||||
BAML_LLM_ENDPOINT=""
|
||||
BAML_LLM_API_KEY="your_api_key"
|
||||
BAML_LLM_API_VERSION=""
|
||||
|
||||
################################################################################
|
||||
# 🗄️ Relational database settings
|
||||
|
|
|
|||
4
.github/workflows/basic_tests.yml
vendored
4
.github/workflows/basic_tests.yml
vendored
|
|
@ -151,15 +151,17 @@ jobs:
|
|||
runs-on: ubuntu-22.04
|
||||
env:
|
||||
STRUCTURED_OUTPUT_FRAMEWORK: "BAML"
|
||||
BAML_LLM_PROVIDER: openai-generic
|
||||
BAML_LLM_PROVIDER: azure-openai
|
||||
BAML_LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
BAML_LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
BAML_LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
BAML_LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||
|
||||
LLM_PROVIDER: openai
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||
|
||||
EMBEDDING_PROVIDER: openai
|
||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||
|
|
|
|||
|
|
@ -48,6 +48,7 @@ class LLMConfig(BaseSettings):
|
|||
baml_llm_endpoint: str = ""
|
||||
baml_llm_api_key: Optional[str] = None
|
||||
baml_llm_temperature: float = 0.0
|
||||
baml_llm_api_version: str = ""
|
||||
|
||||
transcription_model: str = "whisper-1"
|
||||
graph_prompt_path: str = "generate_graph_prompt.txt"
|
||||
|
|
@ -75,11 +76,12 @@ class LLMConfig(BaseSettings):
|
|||
"model": self.baml_llm_model,
|
||||
"temperature": self.baml_llm_temperature,
|
||||
"api_key": self.baml_llm_api_key,
|
||||
"endpoint": self.baml_llm_endpoint,
|
||||
"base_url": self.baml_llm_endpoint,
|
||||
"api_version": self.baml_llm_api_version,
|
||||
},
|
||||
)
|
||||
# Sets the primary client
|
||||
self.baml_registry.set_primary(self.llm_provider)
|
||||
self.baml_registry.set_primary(self.baml_llm_provider)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def ensure_env_vars_for_ollama(self) -> "LLMConfig":
|
||||
|
|
|
|||
|
|
@ -7,8 +7,6 @@ from cognee.shared.data_models import SummarizedCode
|
|||
from cognee.infrastructure.llm.structured_output_framework.baml.baml_client.async_client import b
|
||||
from cognee.infrastructure.llm.config import get_llm_config
|
||||
|
||||
config = get_llm_config()
|
||||
|
||||
|
||||
logger = get_logger("extract_summary_baml")
|
||||
|
||||
|
|
@ -39,22 +37,9 @@ async def extract_summary(content: str, response_model: Type[BaseModel]):
|
|||
"""
|
||||
config = get_llm_config()
|
||||
|
||||
baml_registry = ClientRegistry()
|
||||
|
||||
baml_registry.add_llm_client(
|
||||
name="def",
|
||||
provider="openai",
|
||||
options={
|
||||
"model": config.llm_model,
|
||||
"temperature": config.llm_temperature,
|
||||
"api_key": config.llm_api_key,
|
||||
},
|
||||
)
|
||||
baml_registry.set_primary("def")
|
||||
|
||||
# Use BAML's SummarizeContent function
|
||||
summary_result = await b.SummarizeContent(
|
||||
content, baml_options={"client_registry": baml_registry}
|
||||
content, baml_options={"client_registry": config.baml_registry}
|
||||
)
|
||||
|
||||
# Convert BAML result to the expected response model
|
||||
|
|
@ -92,19 +77,9 @@ async def extract_code_summary(content: str):
|
|||
try:
|
||||
config = get_llm_config()
|
||||
|
||||
baml_registry = ClientRegistry()
|
||||
|
||||
baml_registry.add_llm_client(
|
||||
name="def",
|
||||
provider="openai",
|
||||
options={
|
||||
"model": config.llm_model,
|
||||
"temperature": config.llm_temperature,
|
||||
"api_key": config.llm_api_key,
|
||||
},
|
||||
result = await b.SummarizeCode(
|
||||
content, baml_options={"client_registry": config.baml_registry}
|
||||
)
|
||||
baml_registry.set_primary("def")
|
||||
result = await b.SummarizeCode(content, baml_options={"client_registry": baml_registry})
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to extract code summary with BAML, falling back to mock summary", exc_info=e
|
||||
|
|
|
|||
|
|
@ -1,12 +1,9 @@
|
|||
from baml_py import ClientRegistry
|
||||
from typing import Type
|
||||
from pydantic import BaseModel
|
||||
from cognee.infrastructure.llm.config import get_llm_config
|
||||
from cognee.shared.logging_utils import get_logger, setup_logging
|
||||
from cognee.infrastructure.llm.structured_output_framework.baml.baml_client.async_client import b
|
||||
|
||||
config = get_llm_config()
|
||||
|
||||
|
||||
async def extract_content_graph(
|
||||
content: str, response_model: Type[BaseModel], mode: str = "simple"
|
||||
|
|
@ -16,19 +13,6 @@ async def extract_content_graph(
|
|||
|
||||
get_logger(level="INFO")
|
||||
|
||||
baml_registry = ClientRegistry()
|
||||
|
||||
baml_registry.add_llm_client(
|
||||
name="extract_content_client",
|
||||
provider=config.llm_provider,
|
||||
options={
|
||||
"model": config.llm_model,
|
||||
"temperature": config.llm_temperature,
|
||||
"api_key": config.llm_api_key,
|
||||
},
|
||||
)
|
||||
baml_registry.set_primary("extract_content_client")
|
||||
|
||||
# if response_model:
|
||||
# # tb = TypeBuilder()
|
||||
# # country = tb.union \
|
||||
|
|
@ -43,7 +27,7 @@ async def extract_content_graph(
|
|||
|
||||
# else:
|
||||
graph = await b.ExtractContentGraphGeneric(
|
||||
content, mode=mode, baml_options={"client_registry": baml_registry}
|
||||
content, mode=mode, baml_options={"client_registry": config.baml_registry}
|
||||
)
|
||||
|
||||
return graph
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue