From f4a79454734efff7d5f79c8de684ec7d6d0d1d9d Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Mon, 22 Sep 2025 11:20:48 +0200 Subject: [PATCH] refactor: Move creation of baml dynamic type to own file --- .../extraction/acreate_structured_output.py | 144 +++--------------- .../extraction/create_dynamic_baml_type.py | 122 +++++++++++++++ 2 files changed, 141 insertions(+), 125 deletions(-) create mode 100644 cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/create_dynamic_baml_type.py diff --git a/cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/acreate_structured_output.py b/cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/acreate_structured_output.py index 37bb7da1e..8efcce23d 100644 --- a/cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/acreate_structured_output.py +++ b/cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/acreate_structured_output.py @@ -1,149 +1,43 @@ import asyncio from typing import Type from cognee.shared.logging_utils import get_logger -from baml_py.baml_py import ClassBuilder from cognee.infrastructure.llm.config import get_llm_config - -from typing import Union, Literal -from enum import Enum -from datetime import datetime - +from cognee.infrastructure.llm.structured_output_framework.baml.baml_src.extraction.create_dynamic_baml_type import ( + create_dynamic_baml_type, +) from cognee.infrastructure.llm.structured_output_framework.baml.baml_client.type_builder import ( TypeBuilder, ) from cognee.infrastructure.llm.structured_output_framework.baml.baml_client import b from pydantic import BaseModel -from typing import get_origin, get_args + logger = get_logger() -def create_dynamic_baml_type(tb, baml_model, pydantic_model): - if pydantic_model is str: - baml_model.add_property("text", tb.string()) - return tb - - def map_type(field_type, field_info): - """ - Convert a Python / Pydantic type -> BAML TypeBuilder representation. - """ - - origin = get_origin(field_type) # e.g. list[…] -> list - args = get_args(field_type) # e.g. list[int] -> (int,) - - # ------------------------------------------------------------------ - # 1. Optional / Union ------------------------------------------------ - # ------------------------------------------------------------------ - if origin is Union: - non_none_args = [t for t in args if t is not type(None)] - - # Optional[T] ⇢ exactly (T, NoneType) - if len(args) == 2 and len(non_none_args) == 1: - return map_type(non_none_args[0], field_info).optional() - - # Plain Union[A, B, …] - return tb.union(*(map_type(t, field_info) for t in args)) - - # ------------------------------------------------------------------ - # 2. List / Sequence ------------------------------------------------- - # ------------------------------------------------------------------ - if origin in (list,): - (inner_type,) = args # list has exactly one parameter - return map_type(inner_type, field_info).list() - - # ------------------------------------------------------------------ - # 3. Dict / Map ------------------------------------------------------- - # ------------------------------------------------------------------ - def _is_enum_subclass(key_type) -> bool: - """Guarded issubclass – returns False when tp is not a class.""" - return isinstance(key_type, type) and issubclass(key_type, Enum) - - if origin in (dict,): - key_type, value_type = args or (str, object) - - if key_type is not str and not _is_enum_subclass(key_type): - raise ValueError("BAML maps only allow 'str' or Enum subclasses as keys") - - return tb.map( - map_type(key_type, field_info), # mostly tb.string() or enum - map_type(value_type, field_info), - ) - - # ------------------------------------------------------------------ - # 4. Enum ------------------------------------------------------------- - # ------------------------------------------------------------------ - if _is_enum_subclass(field_type): - enum_builder = tb.add_enum(field_type.__name__) - for member in field_type: - enum_builder.add_value(member.name) - return enum_builder.type() - - # ------------------------------------------------------------------ - # 5. Nested Pydantic model ------------------------------------------- - # ------------------------------------------------------------------ - from pydantic import BaseModel # local import - - if isinstance(field_type, type) and issubclass(field_type, BaseModel): - try: - # Create nested class if it doesn't exist - nested_class = tb.add_class(field_type.__name__) - # Find dynamic types of nested class - create_dynamic_baml_type(tb, nested_class, field_type) - except ValueError: - # If nested class already exists get it - nested_class = tb._tb.class_(field_type.__name__) - - # Return nested class model - if isinstance(nested_class, ClassBuilder): - # Different nested_class objects have different syntax for type information - # If nested class already exists type information can be found using the field method - return nested_class.field() - else: - # If nested class was created type information can be found using type method - return nested_class.type() - - primitive_map = { - str: tb.string(), - int: tb.int(), - float: tb.float(), - bool: tb.bool(), - datetime: tb.string(), # BAML has no native datetime - } - if field_type in primitive_map: - return primitive_map[field_type] - - raise ValueError(f"Unsupported type for BAML mapping: {field_type}") - - fields = pydantic_model.model_fields - - # Add fields - for field_name, field_info in fields.items(): - field_type = field_info.annotation - baml_type = map_type(field_type, field_info) - - # Add property with type - prop = baml_model.add_property(field_name, baml_type) - - # Add description if available - if field_info.description: - prop.description(field_info.description) - - return tb - - async def acreate_structured_output( text_input: str, system_prompt: str, response_model: Type[BaseModel] ): """ - Extract summary using BAML framework. + Generate a response from a user query. - Args: - content: The content to summarize - response_model: The Pydantic model type for the response + This method asynchronously creates structured output by sending a request through BAML + using the provided parameters to generate a completion based on the user input and + system prompt. + + Parameters: + ----------- + + - text_input (str): The input text provided by the user for generating a response. + - system_prompt (str): The system's prompt to guide the model's response. + - response_model (Type[BaseModel]): The expected model type for the response. Returns: - BaseModel: The summarized content in the specified format + -------- + + - BaseModel: A structured output generated by the model, returned as an instance of + BaseModel. """ config = get_llm_config() diff --git a/cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/create_dynamic_baml_type.py b/cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/create_dynamic_baml_type.py new file mode 100644 index 000000000..4b77c15d5 --- /dev/null +++ b/cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/create_dynamic_baml_type.py @@ -0,0 +1,122 @@ +from typing import Union +from enum import Enum +from datetime import datetime +from typing import get_origin, get_args + +from baml_py.baml_py import ClassBuilder +from cognee.shared.logging_utils import get_logger + +logger = get_logger() + + +def create_dynamic_baml_type(tb, baml_model, pydantic_model): + if pydantic_model is str: + baml_model.add_property("text", tb.string()) + return tb + + def map_type(field_type, field_info): + """ + Convert a Python / Pydantic type -> BAML TypeBuilder representation. + """ + + origin = get_origin(field_type) # e.g. list[…] -> list + args = get_args(field_type) # e.g. list[int] -> (int,) + + # ------------------------------------------------------------------ + # 1. Optional / Union ------------------------------------------------ + # ------------------------------------------------------------------ + if origin is Union: + non_none_args = [t for t in args if t is not type(None)] + + # Optional[T] ⇢ exactly (T, NoneType) + if len(args) == 2 and len(non_none_args) == 1: + return map_type(non_none_args[0], field_info).optional() + + # Plain Union[A, B, …] + return tb.union(*(map_type(t, field_info) for t in args)) + + # ------------------------------------------------------------------ + # 2. List / Sequence ------------------------------------------------- + # ------------------------------------------------------------------ + if origin in (list,): + (inner_type,) = args # list has exactly one parameter + return map_type(inner_type, field_info).list() + + # ------------------------------------------------------------------ + # 3. Dict / Map ------------------------------------------------------- + # ------------------------------------------------------------------ + def _is_enum_subclass(key_type) -> bool: + """Guarded issubclass – returns False when tp is not a class.""" + return isinstance(key_type, type) and issubclass(key_type, Enum) + + if origin in (dict,): + key_type, value_type = args or (str, object) + + if key_type is not str and not _is_enum_subclass(key_type): + raise ValueError("BAML maps only allow 'str' or Enum subclasses as keys") + + return tb.map( + map_type(key_type, field_info), # mostly tb.string() or enum + map_type(value_type, field_info), + ) + + # ------------------------------------------------------------------ + # 4. Enum ------------------------------------------------------------- + # ------------------------------------------------------------------ + if _is_enum_subclass(field_type): + enum_builder = tb.add_enum(field_type.__name__) + for member in field_type: + enum_builder.add_value(member.name) + return enum_builder.type() + + # ------------------------------------------------------------------ + # 5. Nested Pydantic model ------------------------------------------- + # ------------------------------------------------------------------ + from pydantic import BaseModel # local import + + if isinstance(field_type, type) and issubclass(field_type, BaseModel): + try: + # Create nested class if it doesn't exist + nested_class = tb.add_class(field_type.__name__) + # Find dynamic types of nested class + create_dynamic_baml_type(tb, nested_class, field_type) + except ValueError: + # If nested class already exists get it + nested_class = tb._tb.class_(field_type.__name__) + + # Return nested class model + if isinstance(nested_class, ClassBuilder): + # Different nested_class objects have different syntax for type information + # If nested class already exists type information can be found using the field method + return nested_class.field() + else: + # If nested class was created type information can be found using type method + return nested_class.type() + + primitive_map = { + str: tb.string(), + int: tb.int(), + float: tb.float(), + bool: tb.bool(), + datetime: tb.string(), # BAML has no native datetime + } + if field_type in primitive_map: + return primitive_map[field_type] + + raise ValueError(f"Unsupported type for BAML mapping: {field_type}") + + fields = pydantic_model.model_fields + + # Add fields + for field_name, field_info in fields.items(): + field_type = field_info.annotation + baml_type = map_type(field_type, field_info) + + # Add property with type + prop = baml_model.add_property(field_name, baml_type) + + # Add description if available + if field_info.description: + prop.description(field_info.description) + + return tb