baml fixes
This commit is contained in:
parent
0c46f7dc26
commit
f5ca55b248
2 changed files with 4 additions and 4 deletions
|
|
@ -36,13 +36,13 @@ async def extract_summary(content: str, response_model: Type[BaseModel]):
|
|||
config = get_llm_config()
|
||||
|
||||
# Use BAML's SummarizeContent function
|
||||
summary_result = await b.SummarizeContent(content, baml_options={"tb": config.baml_registry})
|
||||
summary_result = await b.SummarizeContent(content, baml_options={"client_registry": config.baml_registry})
|
||||
|
||||
# Convert BAML result to the expected response model
|
||||
if response_model is SummarizedCode:
|
||||
# If it's asking for SummarizedCode but we got SummarizedContent,
|
||||
# we need to use SummarizeCode instead
|
||||
code_result = await b.SummarizeCode(content, baml_options={"tb": config.baml_registry})
|
||||
code_result = await b.SummarizeCode(content, baml_options={"client_registry": config.baml_registry})
|
||||
return code_result
|
||||
else:
|
||||
# For other models, return the summary result
|
||||
|
|
@ -70,7 +70,7 @@ async def extract_code_summary(content: str):
|
|||
else:
|
||||
try:
|
||||
config = get_llm_config()
|
||||
result = await b.SummarizeCode(content, baml_options={"tb": config.baml_registry})
|
||||
result = await b.SummarizeCode(content, baml_options={"client_registry": config.baml_registry})
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to extract code summary with BAML, falling back to mock summary", exc_info=e
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ async def extract_content_graph(content: str, response_model: Type[BaseModel]):
|
|||
# tb.Node.add_property("country", country)
|
||||
|
||||
graph = await b.ExtractContentGraph(
|
||||
content, mode="simple", baml_options={"tb": config.baml_registry}
|
||||
content, mode="simple", baml_options={"client_registry": config.baml_registry}
|
||||
)
|
||||
|
||||
return graph
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue