diff --git a/cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/acreate_structured_output.py b/cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/acreate_structured_output.py index 16bf7c3d7..37bb7da1e 100644 --- a/cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/acreate_structured_output.py +++ b/cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/acreate_structured_output.py @@ -55,9 +55,9 @@ def create_dynamic_baml_type(tb, baml_model, pydantic_model): # ------------------------------------------------------------------ # 3. Dict / Map ------------------------------------------------------- # ------------------------------------------------------------------ - def _is_enum_subclass(tp) -> bool: + def _is_enum_subclass(key_type) -> bool: """Guarded issubclass – returns False when tp is not a class.""" - return isinstance(tp, type) and issubclass(tp, Enum) + return isinstance(key_type, type) and issubclass(key_type, Enum) if origin in (dict,): key_type, value_type = args or (str, object) @@ -71,13 +71,7 @@ def create_dynamic_baml_type(tb, baml_model, pydantic_model): ) # ------------------------------------------------------------------ - # 4. Literal ---------------------------------------------------------- - # ------------------------------------------------------------------ - if origin is Literal: - return tb.union(*(tb.literal(v) for v in args)) - - # ------------------------------------------------------------------ - # 5. Enum ------------------------------------------------------------- + # 4. Enum ------------------------------------------------------------- # ------------------------------------------------------------------ if _is_enum_subclass(field_type): enum_builder = tb.add_enum(field_type.__name__) @@ -86,7 +80,7 @@ def create_dynamic_baml_type(tb, baml_model, pydantic_model): return enum_builder.type() # ------------------------------------------------------------------ - # 6. Nested Pydantic model ------------------------------------------- + # 5. Nested Pydantic model ------------------------------------------- # ------------------------------------------------------------------ from pydantic import BaseModel # local import @@ -173,6 +167,15 @@ if __name__ == "__main__": loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) try: - loop.run_until_complete(acreate_structured_output("TEST", "THIS IS A TEST", str)) + from typing import Optional, Dict, Any, List, Literal + + # Models for representing different entities + class TestModel(BaseModel): + type: str + source: Optional[str] = None + target: Optional[str] = None + properties: Optional[Dict[str, List[str]]] = None + + loop.run_until_complete(acreate_structured_output("TEST", "THIS IS A TEST", TestModel)) finally: loop.run_until_complete(loop.shutdown_asyncgens())