feat: Enable dynamic types of response model
This commit is contained in:
parent
2f59e6ee08
commit
59cd31b916
1 changed files with 17 additions and 11 deletions
|
|
@ -1,12 +1,12 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import Type
|
from typing import Type
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
from baml_py.baml_py import ClassBuilder
|
||||||
|
|
||||||
from cognee.infrastructure.llm.config import get_llm_config
|
from cognee.infrastructure.llm.config import get_llm_config
|
||||||
|
|
||||||
from typing import List, Dict, Union, Optional, Literal
|
from typing import List, Union, Optional, Literal
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from baml_py import Image, Audio, Video, Pdf
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from cognee.infrastructure.llm.structured_output_framework.baml.baml_client.type_builder import (
|
from cognee.infrastructure.llm.structured_output_framework.baml.baml_client.type_builder import (
|
||||||
|
|
@ -44,9 +44,7 @@ class SummarizedCode(BaseModel):
|
||||||
workflow_description: Optional[str] = None
|
workflow_description: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
def create_dynamic_baml_type(baml_model, pydantic_model):
|
def create_dynamic_baml_type(tb, baml_model, pydantic_model):
|
||||||
tb = TypeBuilder()
|
|
||||||
|
|
||||||
if pydantic_model is str:
|
if pydantic_model is str:
|
||||||
baml_model.add_property("text", tb.string())
|
baml_model.add_property("text", tb.string())
|
||||||
return tb
|
return tb
|
||||||
|
|
@ -118,15 +116,23 @@ def create_dynamic_baml_type(baml_model, pydantic_model):
|
||||||
from pydantic import BaseModel # local import
|
from pydantic import BaseModel # local import
|
||||||
|
|
||||||
if isinstance(field_type, type) and issubclass(field_type, BaseModel):
|
if isinstance(field_type, type) and issubclass(field_type, BaseModel):
|
||||||
# Create nested class if it doesn't exist
|
|
||||||
try:
|
try:
|
||||||
|
# Create nested class if it doesn't exist
|
||||||
nested_class = tb.add_class(field_type.__name__)
|
nested_class = tb.add_class(field_type.__name__)
|
||||||
|
# Find dynamic types of nested class
|
||||||
|
create_dynamic_baml_type(tb, nested_class, field_type)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
# If nested class already exists get it
|
||||||
# Find dynamic types of nested class
|
nested_class = tb._tb.class_(field_type.__name__)
|
||||||
create_dynamic_baml_type(nested_class, field_type)
|
|
||||||
# Return nested class model
|
# Return nested class model
|
||||||
return nested_class.type()
|
if isinstance(nested_class, ClassBuilder):
|
||||||
|
# Different nested_class objects have different syntax for type information
|
||||||
|
# If nested class already exists type information can be found using the field method
|
||||||
|
return nested_class.field()
|
||||||
|
else:
|
||||||
|
# If nested class was created type information can be found using type method
|
||||||
|
return nested_class.type()
|
||||||
|
|
||||||
primitive_map = {
|
primitive_map = {
|
||||||
str: tb.string(),
|
str: tb.string(),
|
||||||
|
|
@ -173,7 +179,7 @@ async def acreate_structured_output(
|
||||||
config = get_llm_config()
|
config = get_llm_config()
|
||||||
tb = TypeBuilder()
|
tb = TypeBuilder()
|
||||||
|
|
||||||
type_builder = create_dynamic_baml_type(tb.ResponseModel, SummarizedCode)
|
type_builder = create_dynamic_baml_type(tb, tb.ResponseModel, SummarizedCode)
|
||||||
|
|
||||||
result = await b.AcreateStructuredOutput(
|
result = await b.AcreateStructuredOutput(
|
||||||
text_input=text_input,
|
text_input=text_input,
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue