refactor: Move creation of the BAML dynamic type to its own file
This commit is contained in:
parent
4bae611721
commit
f4a7945473
2 changed files with 141 additions and 125 deletions
|
|
@ -1,149 +1,43 @@
|
|||
import asyncio
|
||||
from typing import Type
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from baml_py.baml_py import ClassBuilder
|
||||
|
||||
from cognee.infrastructure.llm.config import get_llm_config
|
||||
|
||||
from typing import Union, Literal
|
||||
from enum import Enum
|
||||
from datetime import datetime
|
||||
|
||||
from cognee.infrastructure.llm.structured_output_framework.baml.baml_src.extraction.create_dynamic_baml_type import (
|
||||
create_dynamic_baml_type,
|
||||
)
|
||||
from cognee.infrastructure.llm.structured_output_framework.baml.baml_client.type_builder import (
|
||||
TypeBuilder,
|
||||
)
|
||||
from cognee.infrastructure.llm.structured_output_framework.baml.baml_client import b
|
||||
from pydantic import BaseModel
|
||||
from typing import get_origin, get_args
|
||||
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def create_dynamic_baml_type(tb, baml_model, pydantic_model):
    """
    Populate a BAML class builder with properties mirroring a Pydantic model.

    Every field of ``pydantic_model`` is translated into a BAML TypeBuilder
    type and added to ``baml_model``; a field's description is copied over
    when it has one. Nested Pydantic models and Enums are registered on ``tb``
    as they are encountered.

    Parameters:
    -----------

        - tb: BAML TypeBuilder used to create primitive, container, class and
          enum types.
        - baml_model: Class builder that receives one property per field.
        - pydantic_model: Pydantic model class to mirror, or the builtin ``str``,
          in which case a single ``text`` string property is added.

    Returns:
    --------

        - The same ``tb`` object that was passed in.

    Raises:
    -------

        - ValueError: If a field (at any nesting depth) uses a type that has no
          BAML mapping, or a dict key type other than ``str`` / an Enum.
    """
    if pydantic_model is str:
        baml_model.add_property("text", tb.string())
        return tb

    import types  # local import: only needed to recognise PEP 604 unions

    # `X | Y` annotations report `types.UnionType` as their origin instead of
    # `typing.Union`; the attribute is absent on interpreters older than 3.10.
    pep604_union = getattr(types, "UnionType", None)
    none_type = type(None)

    # Factories instead of prebuilt values: built once per call and only the
    # requested primitive is actually constructed.
    primitive_factories = {
        str: tb.string,
        int: tb.int,
        float: tb.float,
        bool: tb.bool,
        datetime: tb.string,  # BAML has no native datetime
    }

    def _is_enum_subclass(candidate) -> bool:
        """Guarded issubclass - returns False when candidate is not a class."""
        return isinstance(candidate, type) and issubclass(candidate, Enum)

    def map_type(field_type):
        """
        Convert a Python / Pydantic type -> BAML TypeBuilder representation.
        """
        origin = get_origin(field_type)  # e.g. list[...] -> list
        args = get_args(field_type)  # e.g. list[int] -> (int,)

        # --------------------------------------------------------------
        # 1. Optional / Union (typing.Union and PEP 604 `A | B`)
        # --------------------------------------------------------------
        if origin is Union or (pep604_union is not None and origin is pep604_union):
            members = [arg for arg in args if arg is not none_type]
            if len(members) == 1:
                mapped = map_type(members[0])
            else:
                # NOTE(review): members are passed as varargs, exactly as the
                # previous implementation did - confirm against the
                # TypeBuilder.union signature.
                mapped = tb.union(*(map_type(member) for member in members))

            # NoneType itself has no BAML mapping; its presence is expressed
            # by marking the remaining type as optional.
            if len(members) != len(args):
                return mapped.optional()
            return mapped

        # --------------------------------------------------------------
        # 2. List
        # --------------------------------------------------------------
        if origin is list:
            if len(args) != 1:
                # Unparameterised typing.List: no element type to map.
                raise ValueError(f"Unsupported type for BAML mapping: {field_type}")
            return map_type(args[0]).list()

        # --------------------------------------------------------------
        # 3. Dict / Map
        # --------------------------------------------------------------
        if origin is dict:
            key_type, value_type = args or (str, object)

            if key_type is not str and not _is_enum_subclass(key_type):
                raise ValueError("BAML maps only allow 'str' or Enum subclasses as keys")

            return tb.map(
                map_type(key_type),  # mostly tb.string() or enum
                map_type(value_type),
            )

        # --------------------------------------------------------------
        # 4. Enum
        # --------------------------------------------------------------
        if _is_enum_subclass(field_type):
            try:
                enum_builder = tb.add_enum(field_type.__name__)
            except ValueError:
                # The enum was registered by an earlier field; reuse it.
                # NOTE(review): assumes the low-level builder exposes
                # `enum(name)` mirroring the `class_(name)` lookup used for
                # nested models below - confirm against baml_py.
                return tb._tb.enum(field_type.__name__).field()

            for member in field_type:
                enum_builder.add_value(member.name)
            return enum_builder.type()

        # Parameterised generics that reach this point (tuple[...], set[...])
        # are not plain classes, so they skip straight to the error below.
        if origin is None and isinstance(field_type, type):
            # ----------------------------------------------------------
            # 5. Primitives
            # ----------------------------------------------------------
            factory = primitive_factories.get(field_type)
            if factory is not None:
                return factory()

            # ----------------------------------------------------------
            # 6. Nested Pydantic model
            # ----------------------------------------------------------
            from pydantic import BaseModel  # local import

            if issubclass(field_type, BaseModel):
                try:
                    # Create nested class if it doesn't exist
                    nested_class = tb.add_class(field_type.__name__)
                except ValueError:
                    # If nested class already exists get it
                    nested_class = tb._tb.class_(field_type.__name__)
                else:
                    # Only a freshly created class needs its fields filled in.
                    # Kept outside the `try` so that an unsupported nested
                    # field is reported instead of being mistaken for
                    # "class already exists".
                    create_dynamic_baml_type(tb, nested_class, field_type)

                # Different nested_class objects have different syntax for type
                # information: an existing class exposes it through field(),
                # a newly created one through type().
                if isinstance(nested_class, ClassBuilder):
                    return nested_class.field()
                return nested_class.type()

        raise ValueError(f"Unsupported type for BAML mapping: {field_type}")

    # Add fields
    for field_name, field_info in pydantic_model.model_fields.items():
        # Add property with type
        prop = baml_model.add_property(field_name, map_type(field_info.annotation))

        # Add description if available
        if field_info.description:
            prop.description(field_info.description)

    return tb
|
||||
|
||||
|
||||
async def acreate_structured_output(
|
||||
text_input: str, system_prompt: str, response_model: Type[BaseModel]
|
||||
):
|
||||
"""
|
||||
Extract summary using BAML framework.
|
||||
Generate a response from a user query.
|
||||
|
||||
Args:
|
||||
content: The content to summarize
|
||||
response_model: The Pydantic model type for the response
|
||||
This method asynchronously creates structured output by sending a request through BAML
|
||||
using the provided parameters to generate a completion based on the user input and
|
||||
system prompt.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- text_input (str): The input text provided by the user for generating a response.
|
||||
- system_prompt (str): The system's prompt to guide the model's response.
|
||||
- response_model (Type[BaseModel]): The expected model type for the response.
|
||||
|
||||
Returns:
|
||||
BaseModel: The summarized content in the specified format
|
||||
--------
|
||||
|
||||
- BaseModel: A structured output generated by the model, returned as an instance of
|
||||
BaseModel.
|
||||
"""
|
||||
config = get_llm_config()
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,122 @@
|
|||
from typing import Union
|
||||
from enum import Enum
|
||||
from datetime import datetime
|
||||
from typing import get_origin, get_args
|
||||
|
||||
from baml_py.baml_py import ClassBuilder
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def create_dynamic_baml_type(tb, baml_model, pydantic_model):
    """
    Populate a BAML class builder with properties mirroring a Pydantic model.

    Every field of ``pydantic_model`` is translated into a BAML TypeBuilder
    type and added to ``baml_model``; a field's description is copied over
    when it has one. Nested Pydantic models and Enums are registered on ``tb``
    as they are encountered.

    Parameters:
    -----------

        - tb: BAML TypeBuilder used to create primitive, container, class and
          enum types.
        - baml_model: Class builder that receives one property per field.
        - pydantic_model: Pydantic model class to mirror, or the builtin ``str``,
          in which case a single ``text`` string property is added.

    Returns:
    --------

        - The same ``tb`` object that was passed in.

    Raises:
    -------

        - ValueError: If a field (at any nesting depth) uses a type that has no
          BAML mapping, or a dict key type other than ``str`` / an Enum.
    """
    if pydantic_model is str:
        baml_model.add_property("text", tb.string())
        return tb

    import types  # local import: only needed to recognise PEP 604 unions

    # `X | Y` annotations report `types.UnionType` as their origin instead of
    # `typing.Union`; the attribute is absent on interpreters older than 3.10.
    pep604_union = getattr(types, "UnionType", None)
    none_type = type(None)

    # Factories instead of prebuilt values: built once per call and only the
    # requested primitive is actually constructed.
    primitive_factories = {
        str: tb.string,
        int: tb.int,
        float: tb.float,
        bool: tb.bool,
        datetime: tb.string,  # BAML has no native datetime
    }

    def _is_enum_subclass(candidate) -> bool:
        """Guarded issubclass - returns False when candidate is not a class."""
        return isinstance(candidate, type) and issubclass(candidate, Enum)

    def map_type(field_type):
        """
        Convert a Python / Pydantic type -> BAML TypeBuilder representation.
        """
        origin = get_origin(field_type)  # e.g. list[...] -> list
        args = get_args(field_type)  # e.g. list[int] -> (int,)

        # --------------------------------------------------------------
        # 1. Optional / Union (typing.Union and PEP 604 `A | B`)
        # --------------------------------------------------------------
        if origin is Union or (pep604_union is not None and origin is pep604_union):
            members = [arg for arg in args if arg is not none_type]
            if len(members) == 1:
                mapped = map_type(members[0])
            else:
                # NOTE(review): members are passed as varargs, exactly as the
                # previous implementation did - confirm against the
                # TypeBuilder.union signature.
                mapped = tb.union(*(map_type(member) for member in members))

            # NoneType itself has no BAML mapping; its presence is expressed
            # by marking the remaining type as optional.
            if len(members) != len(args):
                return mapped.optional()
            return mapped

        # --------------------------------------------------------------
        # 2. List
        # --------------------------------------------------------------
        if origin is list:
            if len(args) != 1:
                # Unparameterised typing.List: no element type to map.
                raise ValueError(f"Unsupported type for BAML mapping: {field_type}")
            return map_type(args[0]).list()

        # --------------------------------------------------------------
        # 3. Dict / Map
        # --------------------------------------------------------------
        if origin is dict:
            key_type, value_type = args or (str, object)

            if key_type is not str and not _is_enum_subclass(key_type):
                raise ValueError("BAML maps only allow 'str' or Enum subclasses as keys")

            return tb.map(
                map_type(key_type),  # mostly tb.string() or enum
                map_type(value_type),
            )

        # --------------------------------------------------------------
        # 4. Enum
        # --------------------------------------------------------------
        if _is_enum_subclass(field_type):
            try:
                enum_builder = tb.add_enum(field_type.__name__)
            except ValueError:
                # The enum was registered by an earlier field; reuse it.
                # NOTE(review): assumes the low-level builder exposes
                # `enum(name)` mirroring the `class_(name)` lookup used for
                # nested models below - confirm against baml_py.
                return tb._tb.enum(field_type.__name__).field()

            for member in field_type:
                enum_builder.add_value(member.name)
            return enum_builder.type()

        # Parameterised generics that reach this point (tuple[...], set[...])
        # are not plain classes, so they skip straight to the error below.
        if origin is None and isinstance(field_type, type):
            # ----------------------------------------------------------
            # 5. Primitives
            # ----------------------------------------------------------
            factory = primitive_factories.get(field_type)
            if factory is not None:
                return factory()

            # ----------------------------------------------------------
            # 6. Nested Pydantic model
            # ----------------------------------------------------------
            from pydantic import BaseModel  # local import

            if issubclass(field_type, BaseModel):
                try:
                    # Create nested class if it doesn't exist
                    nested_class = tb.add_class(field_type.__name__)
                except ValueError:
                    # If nested class already exists get it
                    nested_class = tb._tb.class_(field_type.__name__)
                else:
                    # Only a freshly created class needs its fields filled in.
                    # Kept outside the `try` so that an unsupported nested
                    # field is reported instead of being mistaken for
                    # "class already exists".
                    create_dynamic_baml_type(tb, nested_class, field_type)

                # Different nested_class objects have different syntax for type
                # information: an existing class exposes it through field(),
                # a newly created one through type().
                if isinstance(nested_class, ClassBuilder):
                    return nested_class.field()
                return nested_class.type()

        raise ValueError(f"Unsupported type for BAML mapping: {field_type}")

    # Add fields
    for field_name, field_info in pydantic_model.model_fields.items():
        # Add property with type
        prop = baml_model.add_property(field_name, map_type(field_info.annotation))

        # Add description if available
        if field_info.description:
            prop.description(field_info.description)

    return tb
|
||||
Loading…
Add table
Reference in a new issue