feat: Enable dynamic types of response model

This commit is contained in:
Igor Ilic 2025-09-09 12:45:36 +02:00
parent 2f59e6ee08
commit 59cd31b916

View file

@ -1,12 +1,12 @@
import asyncio
from typing import Type
from cognee.shared.logging_utils import get_logger
from baml_py.baml_py import ClassBuilder
from cognee.infrastructure.llm.config import get_llm_config
from typing import List, Dict, Union, Optional, Literal
from typing import List, Union, Optional, Literal
from enum import Enum
from baml_py import Image, Audio, Video, Pdf
from datetime import datetime
from cognee.infrastructure.llm.structured_output_framework.baml.baml_client.type_builder import (
@ -44,9 +44,7 @@ class SummarizedCode(BaseModel):
workflow_description: Optional[str] = None
def create_dynamic_baml_type(baml_model, pydantic_model):
tb = TypeBuilder()
def create_dynamic_baml_type(tb, baml_model, pydantic_model):
if pydantic_model is str:
baml_model.add_property("text", tb.string())
return tb
@ -118,15 +116,23 @@ def create_dynamic_baml_type(baml_model, pydantic_model):
from pydantic import BaseModel # local import
if isinstance(field_type, type) and issubclass(field_type, BaseModel):
# Create nested class if it doesn't exist
try:
# Create nested class if it doesn't exist
nested_class = tb.add_class(field_type.__name__)
# Find dynamic types of nested class
create_dynamic_baml_type(tb, nested_class, field_type)
except ValueError:
pass
# Find dynamic types of nested class
create_dynamic_baml_type(nested_class, field_type)
# If nested class already exists get it
nested_class = tb._tb.class_(field_type.__name__)
# Return nested class model
return nested_class.type()
if isinstance(nested_class, ClassBuilder):
# Different nested_class objects have different syntax for type information
# If nested class already exists type information can be found using the field method
return nested_class.field()
else:
# If nested class was created type information can be found using type method
return nested_class.type()
primitive_map = {
str: tb.string(),
@ -173,7 +179,7 @@ async def acreate_structured_output(
config = get_llm_config()
tb = TypeBuilder()
type_builder = create_dynamic_baml_type(tb.ResponseModel, SummarizedCode)
type_builder = create_dynamic_baml_type(tb, tb.ResponseModel, SummarizedCode)
result = await b.AcreateStructuredOutput(
text_input=text_input,