diff --git a/cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/acreate_structured_output.py b/cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/acreate_structured_output.py index 68cb5e6de..a39196ff6 100644 --- a/cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/acreate_structured_output.py +++ b/cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/acreate_structured_output.py @@ -1,12 +1,12 @@ import asyncio from typing import Type from cognee.shared.logging_utils import get_logger +from baml_py.baml_py import ClassBuilder from cognee.infrastructure.llm.config import get_llm_config -from typing import List, Dict, Union, Optional, Literal +from typing import List, Union, Optional, Literal from enum import Enum -from baml_py import Image, Audio, Video, Pdf from datetime import datetime from cognee.infrastructure.llm.structured_output_framework.baml.baml_client.type_builder import ( @@ -44,9 +44,7 @@ class SummarizedCode(BaseModel): workflow_description: Optional[str] = None -def create_dynamic_baml_type(baml_model, pydantic_model): - tb = TypeBuilder() - +def create_dynamic_baml_type(tb, baml_model, pydantic_model): if pydantic_model is str: baml_model.add_property("text", tb.string()) return tb @@ -118,15 +116,23 @@ def create_dynamic_baml_type(baml_model, pydantic_model): from pydantic import BaseModel # local import if isinstance(field_type, type) and issubclass(field_type, BaseModel): - # Create nested class if it doesn't exist try: + # Create nested class if it doesn't exist nested_class = tb.add_class(field_type.__name__) + # Find dynamic types of nested class + create_dynamic_baml_type(tb, nested_class, field_type) except ValueError: - pass - # Find dynamic types of nested class - create_dynamic_baml_type(nested_class, field_type) + # If nested class already exists get it + nested_class = tb._tb.class_(field_type.__name__) + # Return nested class model - return nested_class.type() + if isinstance(nested_class, ClassBuilder): + # Different nested_class objects have different syntax for type information + # If nested class already exists type information can be found using the field method + return nested_class.field() + else: + # If nested class was created type information can be found using type method + return nested_class.type() primitive_map = { str: tb.string(), @@ -173,7 +179,7 @@ async def acreate_structured_output( config = get_llm_config() tb = TypeBuilder() - type_builder = create_dynamic_baml_type(tb.ResponseModel, SummarizedCode) + type_builder = create_dynamic_baml_type(tb, tb.ResponseModel, SummarizedCode) result = await b.AcreateStructuredOutput( text_input=text_input,