feat: Enable dynamic types of response model
This commit is contained in:
parent
2f59e6ee08
commit
59cd31b916
1 changed files with 17 additions and 11 deletions
|
|
@ -1,12 +1,12 @@
|
|||
import asyncio
|
||||
from typing import Type
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from baml_py.baml_py import ClassBuilder
|
||||
|
||||
from cognee.infrastructure.llm.config import get_llm_config
|
||||
|
||||
from typing import List, Dict, Union, Optional, Literal
|
||||
from typing import List, Union, Optional, Literal
|
||||
from enum import Enum
|
||||
from baml_py import Image, Audio, Video, Pdf
|
||||
from datetime import datetime
|
||||
|
||||
from cognee.infrastructure.llm.structured_output_framework.baml.baml_client.type_builder import (
|
||||
|
|
@ -44,9 +44,7 @@ class SummarizedCode(BaseModel):
|
|||
workflow_description: Optional[str] = None
|
||||
|
||||
|
||||
def create_dynamic_baml_type(baml_model, pydantic_model):
|
||||
tb = TypeBuilder()
|
||||
|
||||
def create_dynamic_baml_type(tb, baml_model, pydantic_model):
|
||||
if pydantic_model is str:
|
||||
baml_model.add_property("text", tb.string())
|
||||
return tb
|
||||
|
|
@ -118,15 +116,23 @@ def create_dynamic_baml_type(baml_model, pydantic_model):
|
|||
from pydantic import BaseModel # local import
|
||||
|
||||
if isinstance(field_type, type) and issubclass(field_type, BaseModel):
|
||||
# Create nested class if it doesn't exist
|
||||
try:
|
||||
# Create nested class if it doesn't exist
|
||||
nested_class = tb.add_class(field_type.__name__)
|
||||
# Find dynamic types of nested class
|
||||
create_dynamic_baml_type(tb, nested_class, field_type)
|
||||
except ValueError:
|
||||
pass
|
||||
# Find dynamic types of nested class
|
||||
create_dynamic_baml_type(nested_class, field_type)
|
||||
# If nested class already exists get it
|
||||
nested_class = tb._tb.class_(field_type.__name__)
|
||||
|
||||
# Return nested class model
|
||||
return nested_class.type()
|
||||
if isinstance(nested_class, ClassBuilder):
|
||||
# Different nested_class objects have different syntax for type information
|
||||
# If nested class already exists type information can be found using the field method
|
||||
return nested_class.field()
|
||||
else:
|
||||
# If nested class was created type information can be found using type method
|
||||
return nested_class.type()
|
||||
|
||||
primitive_map = {
|
||||
str: tb.string(),
|
||||
|
|
@ -173,7 +179,7 @@ async def acreate_structured_output(
|
|||
config = get_llm_config()
|
||||
tb = TypeBuilder()
|
||||
|
||||
type_builder = create_dynamic_baml_type(tb.ResponseModel, SummarizedCode)
|
||||
type_builder = create_dynamic_baml_type(tb, tb.ResponseModel, SummarizedCode)
|
||||
|
||||
result = await b.AcreateStructuredOutput(
|
||||
text_input=text_input,
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue