fix: Fix issue with BAML client registry
This commit is contained in:
parent
dc988b28bd
commit
6a5df922c5
5 changed files with 12 additions and 48 deletions
|
|
@ -36,6 +36,7 @@ BAML_LLM_PROVIDER=openai
|
||||||
BAML_LLM_MODEL="gpt-4o-mini"
|
BAML_LLM_MODEL="gpt-4o-mini"
|
||||||
BAML_LLM_ENDPOINT=""
|
BAML_LLM_ENDPOINT=""
|
||||||
BAML_LLM_API_KEY="your_api_key"
|
BAML_LLM_API_KEY="your_api_key"
|
||||||
|
BAML_LLM_API_VERSION=""
|
||||||
|
|
||||||
################################################################################
|
################################################################################
|
||||||
# 🗄️ Relational database settings
|
# 🗄️ Relational database settings
|
||||||
|
|
|
||||||
4
.github/workflows/basic_tests.yml
vendored
4
.github/workflows/basic_tests.yml
vendored
|
|
@ -151,15 +151,17 @@ jobs:
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
env:
|
env:
|
||||||
STRUCTURED_OUTPUT_FRAMEWORK: "BAML"
|
STRUCTURED_OUTPUT_FRAMEWORK: "BAML"
|
||||||
BAML_LLM_PROVIDER: openai-generic
|
BAML_LLM_PROVIDER: azure-openai
|
||||||
BAML_LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
BAML_LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||||
BAML_LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
BAML_LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||||
BAML_LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
BAML_LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||||
|
BAML_LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||||
|
|
||||||
LLM_PROVIDER: openai
|
LLM_PROVIDER: openai
|
||||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||||
|
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||||
|
|
||||||
EMBEDDING_PROVIDER: openai
|
EMBEDDING_PROVIDER: openai
|
||||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||||
|
|
|
||||||
|
|
@ -48,6 +48,7 @@ class LLMConfig(BaseSettings):
|
||||||
baml_llm_endpoint: str = ""
|
baml_llm_endpoint: str = ""
|
||||||
baml_llm_api_key: Optional[str] = None
|
baml_llm_api_key: Optional[str] = None
|
||||||
baml_llm_temperature: float = 0.0
|
baml_llm_temperature: float = 0.0
|
||||||
|
baml_llm_api_version: str = ""
|
||||||
|
|
||||||
transcription_model: str = "whisper-1"
|
transcription_model: str = "whisper-1"
|
||||||
graph_prompt_path: str = "generate_graph_prompt.txt"
|
graph_prompt_path: str = "generate_graph_prompt.txt"
|
||||||
|
|
@ -75,11 +76,12 @@ class LLMConfig(BaseSettings):
|
||||||
"model": self.baml_llm_model,
|
"model": self.baml_llm_model,
|
||||||
"temperature": self.baml_llm_temperature,
|
"temperature": self.baml_llm_temperature,
|
||||||
"api_key": self.baml_llm_api_key,
|
"api_key": self.baml_llm_api_key,
|
||||||
"endpoint": self.baml_llm_endpoint,
|
"base_url": self.baml_llm_endpoint,
|
||||||
|
"api_version": self.baml_llm_api_version,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
# Sets the primary client
|
# Sets the primary client
|
||||||
self.baml_registry.set_primary(self.llm_provider)
|
self.baml_registry.set_primary(self.baml_llm_provider)
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def ensure_env_vars_for_ollama(self) -> "LLMConfig":
|
def ensure_env_vars_for_ollama(self) -> "LLMConfig":
|
||||||
|
|
|
||||||
|
|
@ -7,8 +7,6 @@ from cognee.shared.data_models import SummarizedCode
|
||||||
from cognee.infrastructure.llm.structured_output_framework.baml.baml_client.async_client import b
|
from cognee.infrastructure.llm.structured_output_framework.baml.baml_client.async_client import b
|
||||||
from cognee.infrastructure.llm.config import get_llm_config
|
from cognee.infrastructure.llm.config import get_llm_config
|
||||||
|
|
||||||
config = get_llm_config()
|
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger("extract_summary_baml")
|
logger = get_logger("extract_summary_baml")
|
||||||
|
|
||||||
|
|
@ -39,22 +37,9 @@ async def extract_summary(content: str, response_model: Type[BaseModel]):
|
||||||
"""
|
"""
|
||||||
config = get_llm_config()
|
config = get_llm_config()
|
||||||
|
|
||||||
baml_registry = ClientRegistry()
|
|
||||||
|
|
||||||
baml_registry.add_llm_client(
|
|
||||||
name="def",
|
|
||||||
provider="openai",
|
|
||||||
options={
|
|
||||||
"model": config.llm_model,
|
|
||||||
"temperature": config.llm_temperature,
|
|
||||||
"api_key": config.llm_api_key,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
baml_registry.set_primary("def")
|
|
||||||
|
|
||||||
# Use BAML's SummarizeContent function
|
# Use BAML's SummarizeContent function
|
||||||
summary_result = await b.SummarizeContent(
|
summary_result = await b.SummarizeContent(
|
||||||
content, baml_options={"client_registry": baml_registry}
|
content, baml_options={"client_registry": config.baml_registry}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Convert BAML result to the expected response model
|
# Convert BAML result to the expected response model
|
||||||
|
|
@ -92,19 +77,9 @@ async def extract_code_summary(content: str):
|
||||||
try:
|
try:
|
||||||
config = get_llm_config()
|
config = get_llm_config()
|
||||||
|
|
||||||
baml_registry = ClientRegistry()
|
result = await b.SummarizeCode(
|
||||||
|
content, baml_options={"client_registry": config.baml_registry}
|
||||||
baml_registry.add_llm_client(
|
|
||||||
name="def",
|
|
||||||
provider="openai",
|
|
||||||
options={
|
|
||||||
"model": config.llm_model,
|
|
||||||
"temperature": config.llm_temperature,
|
|
||||||
"api_key": config.llm_api_key,
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
baml_registry.set_primary("def")
|
|
||||||
result = await b.SummarizeCode(content, baml_options={"client_registry": baml_registry})
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
"Failed to extract code summary with BAML, falling back to mock summary", exc_info=e
|
"Failed to extract code summary with BAML, falling back to mock summary", exc_info=e
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,9 @@
|
||||||
from baml_py import ClientRegistry
|
|
||||||
from typing import Type
|
from typing import Type
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from cognee.infrastructure.llm.config import get_llm_config
|
from cognee.infrastructure.llm.config import get_llm_config
|
||||||
from cognee.shared.logging_utils import get_logger, setup_logging
|
from cognee.shared.logging_utils import get_logger, setup_logging
|
||||||
from cognee.infrastructure.llm.structured_output_framework.baml.baml_client.async_client import b
|
from cognee.infrastructure.llm.structured_output_framework.baml.baml_client.async_client import b
|
||||||
|
|
||||||
config = get_llm_config()
|
|
||||||
|
|
||||||
|
|
||||||
async def extract_content_graph(
|
async def extract_content_graph(
|
||||||
content: str, response_model: Type[BaseModel], mode: str = "simple"
|
content: str, response_model: Type[BaseModel], mode: str = "simple"
|
||||||
|
|
@ -16,19 +13,6 @@ async def extract_content_graph(
|
||||||
|
|
||||||
get_logger(level="INFO")
|
get_logger(level="INFO")
|
||||||
|
|
||||||
baml_registry = ClientRegistry()
|
|
||||||
|
|
||||||
baml_registry.add_llm_client(
|
|
||||||
name="extract_content_client",
|
|
||||||
provider=config.llm_provider,
|
|
||||||
options={
|
|
||||||
"model": config.llm_model,
|
|
||||||
"temperature": config.llm_temperature,
|
|
||||||
"api_key": config.llm_api_key,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
baml_registry.set_primary("extract_content_client")
|
|
||||||
|
|
||||||
# if response_model:
|
# if response_model:
|
||||||
# # tb = TypeBuilder()
|
# # tb = TypeBuilder()
|
||||||
# # country = tb.union \
|
# # country = tb.union \
|
||||||
|
|
@ -43,7 +27,7 @@ async def extract_content_graph(
|
||||||
|
|
||||||
# else:
|
# else:
|
||||||
graph = await b.ExtractContentGraphGeneric(
|
graph = await b.ExtractContentGraphGeneric(
|
||||||
content, mode=mode, baml_options={"client_registry": baml_registry}
|
content, mode=mode, baml_options={"client_registry": config.baml_registry}
|
||||||
)
|
)
|
||||||
|
|
||||||
return graph
|
return graph
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue