feat: Add baml dynamic typing

This commit is contained in:
Igor Ilic 2025-09-09 13:12:59 +02:00
parent 59cd31b916
commit 89b51a244d
19 changed files with 67 additions and 326 deletions

View file

@ -9,12 +9,6 @@ class LLMGateway:
Class used as a namespace for LLM related functions, should not be instantiated, all methods are static.
"""
@staticmethod
def render_prompt(filename: str, context: dict, base_directory: str = None):
from cognee.infrastructure.llm.prompts import render_prompt
return render_prompt(filename=filename, context=context, base_directory=base_directory)
@staticmethod
def acreate_structured_output(
text_input: str, system_prompt: str, response_model: Type[BaseModel]
@ -30,14 +24,15 @@ class LLMGateway:
system_prompt=system_prompt,
response_model=response_model,
)
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.get_llm_client import (
get_llm_client,
)
else:
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.get_llm_client import (
get_llm_client,
)
llm_client = get_llm_client()
return llm_client.acreate_structured_output(
text_input=text_input, system_prompt=system_prompt, response_model=response_model
)
llm_client = get_llm_client()
return llm_client.acreate_structured_output(
text_input=text_input, system_prompt=system_prompt, response_model=response_model
)
@staticmethod
def create_structured_output(
@ -69,107 +64,3 @@ class LLMGateway:
llm_client = get_llm_client()
return llm_client.transcribe_image(input=input)
@staticmethod
def show_prompt(text_input: str, system_prompt: str) -> str:
    """Render the fully formatted prompt for *text_input* under *system_prompt*.

    Delegates to the configured litellm/instructor LLM client.
    """
    from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.get_llm_client import get_llm_client

    client = get_llm_client()
    return client.show_prompt(text_input=text_input, system_prompt=system_prompt)
@staticmethod
def read_query_prompt(prompt_file_name: str, base_directory: str = None):
    """Load and return the contents of a prompt file, optionally rooted at *base_directory*."""
    from cognee.infrastructure.llm.prompts import read_query_prompt as _read_query_prompt

    return _read_query_prompt(prompt_file_name=prompt_file_name, base_directory=base_directory)
@staticmethod
def extract_content_graph(
    content: str,
    response_model: Type[BaseModel],
    mode: str = "simple",
    custom_prompt: Optional[str] = None,
) -> Coroutine:
    """Return a content-graph extraction coroutine, dispatching on the configured framework.

    The BAML backend receives the extraction *mode*; the litellm/instructor
    backend does not take a mode argument.
    """
    framework = get_llm_config().structured_output_framework.upper()

    if framework == "BAML":
        from cognee.infrastructure.llm.structured_output_framework.baml.baml_src.extraction import (
            extract_content_graph,
        )

        return extract_content_graph(
            content=content,
            response_model=response_model,
            mode=mode,
            custom_prompt=custom_prompt,
        )

    from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.extraction import (
        extract_content_graph,
    )

    return extract_content_graph(
        content=content, response_model=response_model, custom_prompt=custom_prompt
    )
@staticmethod
def extract_categories(content: str, response_model: Type[BaseModel]) -> Coroutine:
    """Classify *content* into categories shaped by *response_model*.

    TODO: Add BAML version of category and extraction and update function.
    """
    from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.extraction import (
        extract_categories as _extract_categories,
    )

    return _extract_categories(content=content, response_model=response_model)
@staticmethod
def extract_code_summary(content: str) -> Coroutine:
    """Summarize source code, dispatching on the configured structured-output framework."""
    use_baml = get_llm_config().structured_output_framework.upper() == "BAML"

    # Both backends expose the same callable name, so only the import differs.
    if use_baml:
        from cognee.infrastructure.llm.structured_output_framework.baml.baml_src.extraction import (
            extract_code_summary,
        )
    else:
        from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.extraction import (
            extract_code_summary,
        )

    return extract_code_summary(content=content)
@staticmethod
def extract_summary(content: str, response_model: Type[BaseModel]) -> Coroutine:
    """Summarize *content* into *response_model* via the litellm/instructor backend.

    TODO: re-enable the BAML dispatch (framework check via get_llm_config and
    the BAML extract_summary path) once that backend is ready; the previous
    sketch was left commented out here.
    """
    # Local import keeps the heavy extraction stack off the module import path.
    from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.extraction import (
        extract_summary,
    )

    return extract_summary(content=content, response_model=response_model)
@staticmethod
def extract_event_graph(content: str, response_model: Type[BaseModel]) -> Coroutine:
    """Extract an event graph from *content* shaped by *response_model*.

    TODO: Add BAML version of category and extraction and update function
    (consulted with Igor).
    """
    from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.extraction import (
        extract_event_graph as _extract_event_graph,
    )

    return _extract_event_graph(content=content, response_model=response_model)
@staticmethod
def extract_event_entities(content: str, response_model: Type[BaseModel]) -> Coroutine:
    """Extract event entities from *content* shaped by *response_model*.

    TODO: Add BAML version of category and extraction and update function
    (consulted with Igor).
    """
    from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.extraction import (
        extract_event_entities as _extract_event_entities,
    )

    return _extract_event_entities(content=content, response_model=response_model)

View file

@ -1,11 +1,12 @@
from typing import Type
from pydantic import BaseModel
from cognee.infrastructure.llm.prompts import read_query_prompt
from cognee.infrastructure.llm.LLMGateway import LLMGateway
async def extract_categories(content: str, response_model: Type[BaseModel]):
system_prompt = LLMGateway.read_query_prompt("classify_content.txt")
system_prompt = read_query_prompt("classify_content.txt")
llm_output = await LLMGateway.acreate_structured_output(content, system_prompt, response_model)

View file

@ -1,6 +1,7 @@
import os
from typing import List, Type
from typing import Type
from pydantic import BaseModel
from cognee.infrastructure.llm.prompts.render_prompt import render_prompt
from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.infrastructure.llm.config import (
get_llm_config,
@ -35,7 +36,7 @@ async def extract_event_entities(content: str, response_model: Type[BaseModel]):
else:
base_directory = None
system_prompt = LLMGateway.render_prompt(prompt_path, {}, base_directory=base_directory)
system_prompt = render_prompt(prompt_path, {}, base_directory=base_directory)
content_graph = await LLMGateway.acreate_structured_output(
content, system_prompt, response_model

View file

@ -2,7 +2,8 @@ from cognee.shared.logging_utils import get_logger
import os
from typing import Type
from instructor.exceptions import InstructorRetryException
from instructor.core import InstructorRetryException
from cognee.infrastructure.llm.prompts import read_query_prompt
from pydantic import BaseModel
from cognee.infrastructure.llm.LLMGateway import LLMGateway
@ -25,7 +26,7 @@ def get_mock_summarized_code():
async def extract_summary(content: str, response_model: Type[BaseModel]):
system_prompt = LLMGateway.read_query_prompt("summarize_content.txt")
system_prompt = read_query_prompt("summarize_content.txt")
llm_output = await LLMGateway.acreate_structured_output(content, system_prompt, response_model)

View file

@ -2,6 +2,7 @@ import os
from typing import Type, Optional
from pydantic import BaseModel
from cognee.infrastructure.llm.prompts import render_prompt
from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.infrastructure.llm.config import (
get_llm_config,
@ -26,7 +27,7 @@ async def extract_content_graph(
else:
base_directory = None
system_prompt = LLMGateway.render_prompt(prompt_path, {}, base_directory=base_directory)
system_prompt = render_prompt(prompt_path, {}, base_directory=base_directory)
content_graph = await LLMGateway.acreate_structured_output(
content, system_prompt, response_model

View file

@ -1,8 +1,9 @@
import os
from pydantic import BaseModel
from typing import Type
from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.infrastructure.llm.prompts import render_prompt
from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.infrastructure.llm.config import (
get_llm_config,
)
@ -37,7 +38,7 @@ async def extract_event_graph(content: str, response_model: Type[BaseModel]):
else:
base_directory = None
system_prompt = LLMGateway.render_prompt(prompt_path, {}, base_directory=base_directory)
system_prompt = render_prompt(prompt_path, {}, base_directory=base_directory)
content_graph = await LLMGateway.acreate_structured_output(
content, system_prompt, response_model

View file

@ -0,0 +1,35 @@
from cognee.infrastructure.llm.exceptions import MissingSystemPromptPathError
from cognee.infrastructure.llm.prompts import read_query_prompt
def show_prompt(text_input: str, system_prompt: str) -> str:
    """
    Format and display the prompt for a user query.

    This method formats the prompt using the provided user input and system prompt,
    returning a string representation. Raises MissingSystemPromptPathError if the system prompt is not
    provided.

    Parameters:
    -----------

    - text_input (str): The input text provided by the user.
    - system_prompt (str): The system's prompt to guide the model's response.

    Returns:
    --------

    - str: A formatted string representing the user input and system prompt.
    """
    if not text_input:
        text_input = "No user input provided."
    if not system_prompt:
        raise MissingSystemPromptPathError()

    system_prompt = read_query_prompt(system_prompt)

    # Fix: format unconditionally. The missing-prompt case already raised
    # above, and the previous `... if system_prompt else None` expression
    # could silently return None (violating the `-> str` annotation) when the
    # resolved prompt file was empty.
    return f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n"""

View file

@ -1,3 +1 @@
from .knowledge_graph.extract_content_graph import extract_content_graph
from .extract_summary import extract_summary, extract_code_summary
from .acreate_structured_output import acreate_structured_output

View file

@ -5,7 +5,7 @@ from baml_py.baml_py import ClassBuilder
from cognee.infrastructure.llm.config import get_llm_config
from typing import List, Union, Optional, Literal
from typing import Union, Literal
from enum import Enum
from datetime import datetime
@ -16,32 +16,7 @@ from cognee.infrastructure.llm.structured_output_framework.baml.baml_client impo
from pydantic import BaseModel
from typing import get_origin, get_args
logger = get_logger("extract_summary_baml")
class SummarizedFunction(BaseModel):
    """Structured summary of a single function or method."""

    name: str
    description: str
    # Parameter descriptions; None when not extracted.
    inputs: Optional[List[str]] = None
    # Return-value descriptions; None when not extracted.
    outputs: Optional[List[str]] = None
    # Decorator names applied to the function, if any.
    decorators: Optional[List[str]] = None
class SummarizedClass(BaseModel):
name: str
description: str
methods: Optional[List[SummarizedFunction]] = None
decorators: Optional[List[str]] = None
class SummarizedCode(BaseModel):
high_level_summary: str
key_features: List[str]
imports: List[str] = []
constants: List[str] = []
classes: List[SummarizedClass] = []
functions: List[SummarizedFunction] = []
workflow_description: Optional[str] = None
logger = get_logger()
def create_dynamic_baml_type(tb, baml_model, pydantic_model):
@ -179,7 +154,7 @@ async def acreate_structured_output(
config = get_llm_config()
tb = TypeBuilder()
type_builder = create_dynamic_baml_type(tb, tb.ResponseModel, SummarizedCode)
type_builder = create_dynamic_baml_type(tb, tb.ResponseModel, response_model)
result = await b.AcreateStructuredOutput(
text_input=text_input,
@ -187,13 +162,15 @@ async def acreate_structured_output(
baml_options={"client_registry": config.baml_registry, "tb": type_builder},
)
return result
if response_model is str:
return result
return response_model.model_validate(result.dict())
if __name__ == "__main__":
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(acreate_structured_output("TEST", "THIS IS A TEST", SummarizedCode))
loop.run_until_complete(acreate_structured_output("TEST", "THIS IS A TEST", str))
finally:
loop.run_until_complete(loop.shutdown_asyncgens())

View file

@ -1,89 +0,0 @@
import os
from typing import Type
from pydantic import BaseModel
from baml_py import ClientRegistry
from cognee.shared.logging_utils import get_logger
from cognee.shared.data_models import SummarizedCode
from cognee.infrastructure.llm.structured_output_framework.baml.baml_client.async_client import b
from cognee.infrastructure.llm.config import get_llm_config
logger = get_logger("extract_summary_baml")
def get_mock_summarized_code():
    """Build a deterministic SummarizedCode stub (kept local to avoid circular imports)."""
    mock_fields = {
        "high_level_summary": "Mock code summary",
        "key_features": ["Mock feature 1", "Mock feature 2"],
        "imports": ["mock_import"],
        "constants": ["MOCK_CONSTANT"],
        "classes": [],
        "functions": [],
        "workflow_description": "Mock workflow description",
    }
    return SummarizedCode(**mock_fields)
async def extract_summary(content: str, response_model: Type[BaseModel]):
    """
    Extract summary using BAML framework.

    Args:
        content: The content to summarize
        response_model: The Pydantic model type for the response

    Returns:
        BaseModel: The summarized content in the specified format
    """
    config = get_llm_config()
    baml_options = {"client_registry": config.baml_registry}

    # Fix: choose the correct BAML function up front. The previous version
    # always awaited b.SummarizeContent and then, for SummarizedCode, issued a
    # second b.SummarizeCode call and discarded the first result — one wasted
    # LLM round-trip per code summary.
    if response_model is SummarizedCode:
        return await b.SummarizeCode(content, baml_options=baml_options)

    # For other models, return the generic content summary result.
    return await b.SummarizeContent(content, baml_options=baml_options)
async def extract_code_summary(content: str):
    """
    Extract code summary using BAML framework with mocking support.

    Args:
        content: The code content to summarize

    Returns:
        SummarizedCode: The summarized code information
    """
    # Fix: os.getenv always returns a string (or the string default), so the
    # old `isinstance(enable_mocking, bool)` lowercasing branch was dead code
    # and the membership test was accidentally case-sensitive ("True" was
    # treated as disabled). Normalize case once, up front.
    enable_mocking = os.getenv("MOCK_CODE_SUMMARY", "false").lower() in ("true", "1", "yes")

    if enable_mocking:
        return get_mock_summarized_code()

    try:
        config = get_llm_config()
        result = await b.SummarizeCode(
            content, baml_options={"client_registry": config.baml_registry}
        )
    except Exception as e:
        # Best-effort: degrade to the mock summary rather than failing the pipeline.
        logger.error(
            "Failed to extract code summary with BAML, falling back to mock summary", exc_info=e
        )
        result = get_mock_summarized_code()

    return result

View file

@ -1,44 +0,0 @@
from typing import Type, Optional
from pydantic import BaseModel
from cognee.infrastructure.llm.config import get_llm_config
from cognee.shared.logging_utils import get_logger, setup_logging
from cognee.infrastructure.llm.structured_output_framework.baml.baml_client.async_client import b
async def extract_content_graph(
    content: str,
    response_model: Type[BaseModel],
    mode: str = "simple",
    custom_prompt: Optional[str] = None,
):
    """Extract a knowledge graph from *content* via the BAML async client.

    Parameters:
        content: Text to extract a graph from.
        response_model: Currently unused — dynamic typing is not wired in yet
            (see TODO below); the raw BAML graph result is returned as-is.
        mode: Extraction mode passed to BAML when no custom prompt is given.
        custom_prompt: Optional prompt override; forces mode="custom".

    Returns:
        The graph object produced by b.ExtractContentGraphGeneric.
    """
    config = get_llm_config()
    # NOTE(review): configuring logging on every call is a global side effect —
    # presumably leftover from debugging; confirm before keeping.
    setup_logging()
    get_logger(level="INFO")
    # TODO: route response_model through a BAML TypeBuilder and call
    # b.ExtractDynamicContentGraph (a sketch of this was previously here).
    if custom_prompt:
        graph = await b.ExtractContentGraphGeneric(
            content,
            mode="custom",
            custom_prompt_content=custom_prompt,
            baml_options={"client_registry": config.baml_registry},
        )
    else:
        graph = await b.ExtractContentGraphGeneric(
            content, mode=mode, baml_options={"client_registry": config.baml_registry}
        )
    return graph

View file

@ -328,35 +328,3 @@ class OpenAIAdapter(LLMInterface):
max_completion_tokens=300,
max_retries=self.MAX_RETRIES,
)
def show_prompt(self, text_input: str, system_prompt: str) -> str:
    """
    Format and display the prompt for a user query.

    This method formats the prompt using the provided user input and system prompt,
    returning a string representation. Raises MissingSystemPromptPathError if the system prompt is not
    provided.

    Parameters:
    -----------

    - text_input (str): The input text provided by the user.
    - system_prompt (str): The system's prompt to guide the model's response.

    Returns:
    --------

    - str: A formatted string representing the user input and system prompt.
    """
    user_text = text_input if text_input else "No user input provided."
    if not system_prompt:
        raise MissingSystemPromptPathError()

    resolved_prompt = LLMGateway.read_query_prompt(system_prompt)

    # NOTE(review): mirrors the original behavior — an empty resolved prompt
    # yields None despite the `-> str` annotation; confirm callers tolerate it.
    if not resolved_prompt:
        return None
    return f"""System Prompt:\n{resolved_prompt}\n\nUser Input:\n{user_text}\n"""

View file

@ -11,7 +11,7 @@ from cognee.modules.graph.utils import (
retrieve_existing_edges,
)
from cognee.shared.data_models import KnowledgeGraph
from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.infrastructure.llm.extraction import extract_content_graph
from cognee.tasks.graph.exceptions import (
InvalidGraphModelError,
InvalidDataChunksError,
@ -86,7 +86,7 @@ async def extract_graph_from_data(
chunk_graphs = await asyncio.gather(
*[
LLMGateway.extract_content_graph(chunk.text, graph_model, custom_prompt=custom_prompt)
extract_content_graph(chunk.text, graph_model, custom_prompt=custom_prompt)
for chunk in data_chunks
]
)

View file

@ -5,7 +5,7 @@ from pydantic import BaseModel
from cognee.tasks.summarization.exceptions import InvalidSummaryInputsError
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.infrastructure.llm.extraction import extract_summary
from cognee.modules.cognify.config import get_cognify_config
from cognee.tasks.summarization.models import TextSummary
@ -50,7 +50,7 @@ async def summarize_text(
summarization_model = cognee_config.summarization_model
chunk_summaries = await asyncio.gather(
*[LLMGateway.extract_summary(chunk.text, summarization_model) for chunk in data_chunks]
*[extract_summary(chunk.text, summarization_model) for chunk in data_chunks]
)
summaries = [