refactor: Add baml client changes

This commit is contained in:
Igor Ilic 2025-09-08 12:55:33 +02:00
parent 7c27546951
commit 6231882c24
6 changed files with 344 additions and 129 deletions

View file

@ -10,7 +10,7 @@
# BAML files and re-generate this code using: baml-cli generate
# baml-cli is available with the baml package.
__version__ = "0.201.0"
__version__ = "0.206.1"
try:
from baml_py.safe_import import EnsureBamlPyImport

View file

@ -44,6 +44,7 @@ class BamlAsyncClient:
typing.Union[baml_py.baml_py.Collector, typing.List[baml_py.baml_py.Collector]]
] = None,
env: typing.Optional[typing.Dict[str, typing.Optional[str]]] = None,
on_tick: typing.Optional[typing.Callable[[str, baml_py.baml_py.FunctionLog], None]] = None,
) -> "BamlAsyncClient":
options: BamlCallOptions = {}
if tb is not None:
@ -54,6 +55,8 @@ class BamlAsyncClient:
options["collector"] = collector
if env is not None:
options["env"] = env
if on_tick is not None:
options["on_tick"] = on_tick
return BamlAsyncClient(self.__options.merge_options(options))
@property
@ -83,33 +86,52 @@ class BamlAsyncClient:
user_prompt: str,
baml_options: BamlCallOptions = {},
) -> types.DynamicModel:
result = await self.__options.merge_options(baml_options).call_function_async(
function_name="AcreateStructuredOutput",
args={
"content": content,
"system_prompt": system_prompt,
"user_prompt": user_prompt,
},
)
return typing.cast(
types.DynamicModel, result.cast_to(types, types, stream_types, False, __runtime__)
)
# Check if on_tick is provided
if "on_tick" in baml_options:
# Use streaming internally when on_tick is provided
stream = self.stream.AcreateStructuredOutput(
content=content,
system_prompt=system_prompt,
user_prompt=user_prompt,
baml_options=baml_options,
)
return await stream.get_final_response()
else:
# Original non-streaming code
result = await self.__options.merge_options(baml_options).call_function_async(
function_name="AcreateStructuredOutput",
args={
"content": content,
"system_prompt": system_prompt,
"user_prompt": user_prompt,
},
)
return typing.cast(
types.DynamicModel, result.cast_to(types, types, stream_types, False, __runtime__)
)
async def ExtractCategories(
self,
content: str,
baml_options: BamlCallOptions = {},
) -> types.DefaultContentPrediction:
result = await self.__options.merge_options(baml_options).call_function_async(
function_name="ExtractCategories",
args={
"content": content,
},
)
return typing.cast(
types.DefaultContentPrediction,
result.cast_to(types, types, stream_types, False, __runtime__),
)
# Check if on_tick is provided
if "on_tick" in baml_options:
# Use streaming internally when on_tick is provided
stream = self.stream.ExtractCategories(content=content, baml_options=baml_options)
return await stream.get_final_response()
else:
# Original non-streaming code
result = await self.__options.merge_options(baml_options).call_function_async(
function_name="ExtractCategories",
args={
"content": content,
},
)
return typing.cast(
types.DefaultContentPrediction,
result.cast_to(types, types, stream_types, False, __runtime__),
)
async def ExtractContentGraphGeneric(
self,
@ -126,17 +148,29 @@ class BamlAsyncClient:
custom_prompt_content: typing.Optional[str] = None,
baml_options: BamlCallOptions = {},
) -> types.KnowledgeGraph:
result = await self.__options.merge_options(baml_options).call_function_async(
function_name="ExtractContentGraphGeneric",
args={
"content": content,
"mode": mode,
"custom_prompt_content": custom_prompt_content,
},
)
return typing.cast(
types.KnowledgeGraph, result.cast_to(types, types, stream_types, False, __runtime__)
)
# Check if on_tick is provided
if "on_tick" in baml_options:
# Use streaming internally when on_tick is provided
stream = self.stream.ExtractContentGraphGeneric(
content=content,
mode=mode,
custom_prompt_content=custom_prompt_content,
baml_options=baml_options,
)
return await stream.get_final_response()
else:
# Original non-streaming code
result = await self.__options.merge_options(baml_options).call_function_async(
function_name="ExtractContentGraphGeneric",
args={
"content": content,
"mode": mode,
"custom_prompt_content": custom_prompt_content,
},
)
return typing.cast(
types.KnowledgeGraph, result.cast_to(types, types, stream_types, False, __runtime__)
)
async def ExtractDynamicContentGraph(
self,
@ -153,48 +187,75 @@ class BamlAsyncClient:
custom_prompt_content: typing.Optional[str] = None,
baml_options: BamlCallOptions = {},
) -> types.DynamicKnowledgeGraph:
result = await self.__options.merge_options(baml_options).call_function_async(
function_name="ExtractDynamicContentGraph",
args={
"content": content,
"mode": mode,
"custom_prompt_content": custom_prompt_content,
},
)
return typing.cast(
types.DynamicKnowledgeGraph,
result.cast_to(types, types, stream_types, False, __runtime__),
)
# Check if on_tick is provided
if "on_tick" in baml_options:
# Use streaming internally when on_tick is provided
stream = self.stream.ExtractDynamicContentGraph(
content=content,
mode=mode,
custom_prompt_content=custom_prompt_content,
baml_options=baml_options,
)
return await stream.get_final_response()
else:
# Original non-streaming code
result = await self.__options.merge_options(baml_options).call_function_async(
function_name="ExtractDynamicContentGraph",
args={
"content": content,
"mode": mode,
"custom_prompt_content": custom_prompt_content,
},
)
return typing.cast(
types.DynamicKnowledgeGraph,
result.cast_to(types, types, stream_types, False, __runtime__),
)
async def SummarizeCode(
self,
content: str,
baml_options: BamlCallOptions = {},
) -> types.SummarizedCode:
result = await self.__options.merge_options(baml_options).call_function_async(
function_name="SummarizeCode",
args={
"content": content,
},
)
return typing.cast(
types.SummarizedCode, result.cast_to(types, types, stream_types, False, __runtime__)
)
# Check if on_tick is provided
if "on_tick" in baml_options:
# Use streaming internally when on_tick is provided
stream = self.stream.SummarizeCode(content=content, baml_options=baml_options)
return await stream.get_final_response()
else:
# Original non-streaming code
result = await self.__options.merge_options(baml_options).call_function_async(
function_name="SummarizeCode",
args={
"content": content,
},
)
return typing.cast(
types.SummarizedCode, result.cast_to(types, types, stream_types, False, __runtime__)
)
async def SummarizeContent(
self,
content: str,
baml_options: BamlCallOptions = {},
) -> types.SummarizedContent:
result = await self.__options.merge_options(baml_options).call_function_async(
function_name="SummarizeContent",
args={
"content": content,
},
)
return typing.cast(
types.SummarizedContent, result.cast_to(types, types, stream_types, False, __runtime__)
)
# Check if on_tick is provided
if "on_tick" in baml_options:
# Use streaming internally when on_tick is provided
stream = self.stream.SummarizeContent(content=content, baml_options=baml_options)
return await stream.get_final_response()
else:
# Original non-streaming code
result = await self.__options.merge_options(baml_options).call_function_async(
function_name="SummarizeContent",
args={
"content": content,
},
)
return typing.cast(
types.SummarizedContent,
result.cast_to(types, types, stream_types, False, __runtime__),
)
class BamlStreamClient:

File diff suppressed because one or more lines are too long

View file

@ -30,6 +30,10 @@ class BamlCallOptions(typing.TypedDict, total=False):
collector: typing_extensions.NotRequired[
typing.Union[baml_py.baml_py.Collector, typing.List[baml_py.baml_py.Collector]]
]
abort_controller: typing_extensions.NotRequired[baml_py.baml_py.AbortController]
on_tick: typing_extensions.NotRequired[
typing.Callable[[str, baml_py.baml_py.FunctionLog], None]
]
class _ResolvedBamlOptions:
@ -37,6 +41,8 @@ class _ResolvedBamlOptions:
client_registry: typing.Optional[baml_py.baml_py.ClientRegistry]
collectors: typing.List[baml_py.baml_py.Collector]
env_vars: typing.Dict[str, str]
abort_controller: typing.Optional[baml_py.baml_py.AbortController]
on_tick: typing.Optional[typing.Callable[[], None]]
def __init__(
self,
@ -44,11 +50,15 @@ class _ResolvedBamlOptions:
client_registry: typing.Optional[baml_py.baml_py.ClientRegistry],
collectors: typing.List[baml_py.baml_py.Collector],
env_vars: typing.Dict[str, str],
abort_controller: typing.Optional[baml_py.baml_py.AbortController],
on_tick: typing.Optional[typing.Callable[[], None]],
):
self.tb = tb
self.client_registry = client_registry
self.collectors = collectors
self.env_vars = env_vars
self.abort_controller = abort_controller
self.on_tick = on_tick
class DoNotUseDirectlyCallManager:
@ -85,11 +95,27 @@ class DoNotUseDirectlyCallManager:
else:
env_vars.pop(k, None)
abort_controller = self.__baml_options.get("abort_controller")
on_tick = self.__baml_options.get("on_tick")
if on_tick is not None:
collector = baml_py.baml_py.Collector("on-tick-collector")
collectors_as_list.append(collector)
def on_tick_wrapper():
log = collector.last
if log is not None:
on_tick("Unknown", log)
else:
on_tick_wrapper = None
return _ResolvedBamlOptions(
baml_tb,
client_registry,
collectors_as_list,
env_vars,
abort_controller,
on_tick_wrapper,
)
def merge_options(self, options: BamlCallOptions) -> "DoNotUseDirectlyCallManager":
@ -99,6 +125,14 @@ class DoNotUseDirectlyCallManager:
self, *, function_name: str, args: typing.Dict[str, typing.Any]
) -> baml_py.baml_py.FunctionResult:
resolved_options = self.__resolve()
# Check if already aborted
if (
resolved_options.abort_controller is not None
and resolved_options.abort_controller.aborted
):
raise Exception("BamlAbortError: Operation was aborted")
return await __runtime__.call_function(
function_name,
args,
@ -112,12 +146,22 @@ class DoNotUseDirectlyCallManager:
resolved_options.collectors,
# env_vars
resolved_options.env_vars,
# abort_controller
resolved_options.abort_controller,
)
def call_function_sync(
self, *, function_name: str, args: typing.Dict[str, typing.Any]
) -> baml_py.baml_py.FunctionResult:
resolved_options = self.__resolve()
# Check if already aborted
if (
resolved_options.abort_controller is not None
and resolved_options.abort_controller.aborted
):
raise Exception("BamlAbortError: Operation was aborted")
ctx = __ctx__manager__.get()
return __runtime__.call_function_sync(
function_name,
@ -132,6 +176,8 @@ class DoNotUseDirectlyCallManager:
resolved_options.collectors,
# env_vars
resolved_options.env_vars,
# abort_controller
resolved_options.abort_controller,
)
def create_async_stream(
@ -158,6 +204,8 @@ class DoNotUseDirectlyCallManager:
resolved_options.collectors,
# env_vars
resolved_options.env_vars,
# on_tick
resolved_options.on_tick,
)
return ctx, result
@ -170,6 +218,10 @@ class DoNotUseDirectlyCallManager:
baml_py.baml_py.RuntimeContextManager, baml_py.baml_py.SyncFunctionResultStream
]:
resolved_options = self.__resolve()
if resolved_options.on_tick is not None:
raise ValueError(
"on_tick is not supported for sync streams. Please use async streams instead."
)
ctx = __ctx__manager__.get()
result = __runtime__.stream_function_sync(
function_name,
@ -187,6 +239,9 @@ class DoNotUseDirectlyCallManager:
resolved_options.collectors,
# env_vars
resolved_options.env_vars,
# on_tick
# always None! sync streams don't support on_tick
None,
)
return ctx, result
@ -264,3 +319,26 @@ class DoNotUseDirectlyCallManager:
# env_vars
resolved_options.env_vars,
)
def disassemble(function: typing.Callable) -> None:
import inspect
from . import b
if not callable(function):
print(f"disassemble: object {function} is not a Baml function")
return
is_client_method = False
for method_name, _ in inspect.getmembers(b, predicate=inspect.ismethod):
if method_name == function.__name__:
is_client_method = True
break
if not is_client_method:
print(f"disassemble: function {function.__name__} is not a Baml function")
return
print(f"----- function {function.__name__} -----")
__runtime__.disassemble(function.__name__)

View file

@ -57,6 +57,7 @@ class BamlSyncClient:
typing.Union[baml_py.baml_py.Collector, typing.List[baml_py.baml_py.Collector]]
] = None,
env: typing.Optional[typing.Dict[str, typing.Optional[str]]] = None,
on_tick: typing.Optional[typing.Callable[[str, baml_py.baml_py.FunctionLog], None]] = None,
) -> "BamlSyncClient":
options: BamlCallOptions = {}
if tb is not None:
@ -67,6 +68,8 @@ class BamlSyncClient:
options["collector"] = collector
if env is not None:
options["env"] = env
if on_tick is not None:
options["on_tick"] = on_tick
return BamlSyncClient(self.__options.merge_options(options))
@property
@ -96,33 +99,50 @@ class BamlSyncClient:
user_prompt: str,
baml_options: BamlCallOptions = {},
) -> types.DynamicModel:
result = self.__options.merge_options(baml_options).call_function_sync(
function_name="AcreateStructuredOutput",
args={
"content": content,
"system_prompt": system_prompt,
"user_prompt": user_prompt,
},
)
return typing.cast(
types.DynamicModel, result.cast_to(types, types, stream_types, False, __runtime__)
)
# Check if on_tick is provided
if "on_tick" in baml_options:
stream = self.stream.AcreateStructuredOutput(
content=content,
system_prompt=system_prompt,
user_prompt=user_prompt,
baml_options=baml_options,
)
return stream.get_final_response()
else:
# Original non-streaming code
result = self.__options.merge_options(baml_options).call_function_sync(
function_name="AcreateStructuredOutput",
args={
"content": content,
"system_prompt": system_prompt,
"user_prompt": user_prompt,
},
)
return typing.cast(
types.DynamicModel, result.cast_to(types, types, stream_types, False, __runtime__)
)
def ExtractCategories(
self,
content: str,
baml_options: BamlCallOptions = {},
) -> types.DefaultContentPrediction:
result = self.__options.merge_options(baml_options).call_function_sync(
function_name="ExtractCategories",
args={
"content": content,
},
)
return typing.cast(
types.DefaultContentPrediction,
result.cast_to(types, types, stream_types, False, __runtime__),
)
# Check if on_tick is provided
if "on_tick" in baml_options:
stream = self.stream.ExtractCategories(content=content, baml_options=baml_options)
return stream.get_final_response()
else:
# Original non-streaming code
result = self.__options.merge_options(baml_options).call_function_sync(
function_name="ExtractCategories",
args={
"content": content,
},
)
return typing.cast(
types.DefaultContentPrediction,
result.cast_to(types, types, stream_types, False, __runtime__),
)
def ExtractContentGraphGeneric(
self,
@ -139,17 +159,28 @@ class BamlSyncClient:
custom_prompt_content: typing.Optional[str] = None,
baml_options: BamlCallOptions = {},
) -> types.KnowledgeGraph:
result = self.__options.merge_options(baml_options).call_function_sync(
function_name="ExtractContentGraphGeneric",
args={
"content": content,
"mode": mode,
"custom_prompt_content": custom_prompt_content,
},
)
return typing.cast(
types.KnowledgeGraph, result.cast_to(types, types, stream_types, False, __runtime__)
)
# Check if on_tick is provided
if "on_tick" in baml_options:
stream = self.stream.ExtractContentGraphGeneric(
content=content,
mode=mode,
custom_prompt_content=custom_prompt_content,
baml_options=baml_options,
)
return stream.get_final_response()
else:
# Original non-streaming code
result = self.__options.merge_options(baml_options).call_function_sync(
function_name="ExtractContentGraphGeneric",
args={
"content": content,
"mode": mode,
"custom_prompt_content": custom_prompt_content,
},
)
return typing.cast(
types.KnowledgeGraph, result.cast_to(types, types, stream_types, False, __runtime__)
)
def ExtractDynamicContentGraph(
self,
@ -166,48 +197,72 @@ class BamlSyncClient:
custom_prompt_content: typing.Optional[str] = None,
baml_options: BamlCallOptions = {},
) -> types.DynamicKnowledgeGraph:
result = self.__options.merge_options(baml_options).call_function_sync(
function_name="ExtractDynamicContentGraph",
args={
"content": content,
"mode": mode,
"custom_prompt_content": custom_prompt_content,
},
)
return typing.cast(
types.DynamicKnowledgeGraph,
result.cast_to(types, types, stream_types, False, __runtime__),
)
# Check if on_tick is provided
if "on_tick" in baml_options:
stream = self.stream.ExtractDynamicContentGraph(
content=content,
mode=mode,
custom_prompt_content=custom_prompt_content,
baml_options=baml_options,
)
return stream.get_final_response()
else:
# Original non-streaming code
result = self.__options.merge_options(baml_options).call_function_sync(
function_name="ExtractDynamicContentGraph",
args={
"content": content,
"mode": mode,
"custom_prompt_content": custom_prompt_content,
},
)
return typing.cast(
types.DynamicKnowledgeGraph,
result.cast_to(types, types, stream_types, False, __runtime__),
)
def SummarizeCode(
self,
content: str,
baml_options: BamlCallOptions = {},
) -> types.SummarizedCode:
result = self.__options.merge_options(baml_options).call_function_sync(
function_name="SummarizeCode",
args={
"content": content,
},
)
return typing.cast(
types.SummarizedCode, result.cast_to(types, types, stream_types, False, __runtime__)
)
# Check if on_tick is provided
if "on_tick" in baml_options:
stream = self.stream.SummarizeCode(content=content, baml_options=baml_options)
return stream.get_final_response()
else:
# Original non-streaming code
result = self.__options.merge_options(baml_options).call_function_sync(
function_name="SummarizeCode",
args={
"content": content,
},
)
return typing.cast(
types.SummarizedCode, result.cast_to(types, types, stream_types, False, __runtime__)
)
def SummarizeContent(
self,
content: str,
baml_options: BamlCallOptions = {},
) -> types.SummarizedContent:
result = self.__options.merge_options(baml_options).call_function_sync(
function_name="SummarizeContent",
args={
"content": content,
},
)
return typing.cast(
types.SummarizedContent, result.cast_to(types, types, stream_types, False, __runtime__)
)
# Check if on_tick is provided
if "on_tick" in baml_options:
stream = self.stream.SummarizeContent(content=content, baml_options=baml_options)
return stream.get_final_response()
else:
# Original non-streaming code
result = self.__options.merge_options(baml_options).call_function_sync(
function_name="SummarizeContent",
args={
"content": content,
},
)
return typing.cast(
types.SummarizedContent,
result.cast_to(types, types, stream_types, False, __runtime__),
)
class BamlStreamClient:

View file

@ -13,6 +13,9 @@
import typing
from baml_py import type_builder
from baml_py import baml_py
# These are exports, not used here, hence the linter is disabled
from baml_py.baml_py import FieldType, EnumValueBuilder, EnumBuilder, ClassBuilder # noqa: F401 # pylint: disable=unused-import
from .globals import DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_RUNTIME
@ -296,7 +299,13 @@ class DynamicKnowledgeGraphBuilder(DynamicKnowledgeGraphAst):
return self._bldr.property(name).type(type)
def list_properties(self) -> typing.List[typing.Tuple[str, baml_py.ClassPropertyBuilder]]:
return [(name, self._bldr.property(name)) for name in self._properties]
return self._bldr.list_properties()
def remove_property(self, name: str) -> None:
self._bldr.remove_property(name)
def reset(self) -> None:
self._bldr.reset()
class DynamicKnowledgeGraphProperties:
@ -339,7 +348,13 @@ class DynamicModelBuilder(DynamicModelAst):
return self._bldr.property(name).type(type)
def list_properties(self) -> typing.List[typing.Tuple[str, baml_py.ClassPropertyBuilder]]:
return [(name, self._bldr.property(name)) for name in self._properties]
return self._bldr.list_properties()
def remove_property(self, name: str) -> None:
self._bldr.remove_property(name)
def reset(self) -> None:
self._bldr.reset()
class DynamicModelProperties:
@ -619,7 +634,13 @@ class NodeBuilder(NodeAst):
return self._bldr.property(name).type(type)
def list_properties(self) -> typing.List[typing.Tuple[str, baml_py.ClassPropertyBuilder]]:
return [(name, self._bldr.property(name)) for name in self._properties]
return self._bldr.list_properties()
def remove_property(self, name: str) -> None:
self._bldr.remove_property(name)
def reset(self) -> None:
self._bldr.reset()
class NodeProperties: