feat: Add baml dynamic typing
This commit is contained in:
parent
59cd31b916
commit
89b51a244d
19 changed files with 67 additions and 326 deletions
|
|
@ -9,12 +9,6 @@ class LLMGateway:
|
|||
Class used as a namespace for LLM related functions, should not be instantiated, all methods are static.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def render_prompt(filename: str, context: dict, base_directory: str = None):
|
||||
from cognee.infrastructure.llm.prompts import render_prompt
|
||||
|
||||
return render_prompt(filename=filename, context=context, base_directory=base_directory)
|
||||
|
||||
@staticmethod
|
||||
def acreate_structured_output(
|
||||
text_input: str, system_prompt: str, response_model: Type[BaseModel]
|
||||
|
|
@ -30,14 +24,15 @@ class LLMGateway:
|
|||
system_prompt=system_prompt,
|
||||
response_model=response_model,
|
||||
)
|
||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.get_llm_client import (
|
||||
get_llm_client,
|
||||
)
|
||||
else:
|
||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.get_llm_client import (
|
||||
get_llm_client,
|
||||
)
|
||||
|
||||
llm_client = get_llm_client()
|
||||
return llm_client.acreate_structured_output(
|
||||
text_input=text_input, system_prompt=system_prompt, response_model=response_model
|
||||
)
|
||||
llm_client = get_llm_client()
|
||||
return llm_client.acreate_structured_output(
|
||||
text_input=text_input, system_prompt=system_prompt, response_model=response_model
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create_structured_output(
|
||||
|
|
@ -69,107 +64,3 @@ class LLMGateway:
|
|||
|
||||
llm_client = get_llm_client()
|
||||
return llm_client.transcribe_image(input=input)
|
||||
|
||||
@staticmethod
|
||||
def show_prompt(text_input: str, system_prompt: str) -> str:
|
||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.get_llm_client import (
|
||||
get_llm_client,
|
||||
)
|
||||
|
||||
llm_client = get_llm_client()
|
||||
return llm_client.show_prompt(text_input=text_input, system_prompt=system_prompt)
|
||||
|
||||
@staticmethod
|
||||
def read_query_prompt(prompt_file_name: str, base_directory: str = None):
|
||||
from cognee.infrastructure.llm.prompts import (
|
||||
read_query_prompt,
|
||||
)
|
||||
|
||||
return read_query_prompt(prompt_file_name=prompt_file_name, base_directory=base_directory)
|
||||
|
||||
@staticmethod
|
||||
def extract_content_graph(
|
||||
content: str,
|
||||
response_model: Type[BaseModel],
|
||||
mode: str = "simple",
|
||||
custom_prompt: Optional[str] = None,
|
||||
) -> Coroutine:
|
||||
llm_config = get_llm_config()
|
||||
if llm_config.structured_output_framework.upper() == "BAML":
|
||||
from cognee.infrastructure.llm.structured_output_framework.baml.baml_src.extraction import (
|
||||
extract_content_graph,
|
||||
)
|
||||
|
||||
return extract_content_graph(
|
||||
content=content,
|
||||
response_model=response_model,
|
||||
mode=mode,
|
||||
custom_prompt=custom_prompt,
|
||||
)
|
||||
else:
|
||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.extraction import (
|
||||
extract_content_graph,
|
||||
)
|
||||
|
||||
return extract_content_graph(
|
||||
content=content, response_model=response_model, custom_prompt=custom_prompt
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def extract_categories(content: str, response_model: Type[BaseModel]) -> Coroutine:
|
||||
# TODO: Add BAML version of category and extraction and update function
|
||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.extraction import (
|
||||
extract_categories,
|
||||
)
|
||||
|
||||
return extract_categories(content=content, response_model=response_model)
|
||||
|
||||
@staticmethod
|
||||
def extract_code_summary(content: str) -> Coroutine:
|
||||
llm_config = get_llm_config()
|
||||
if llm_config.structured_output_framework.upper() == "BAML":
|
||||
from cognee.infrastructure.llm.structured_output_framework.baml.baml_src.extraction import (
|
||||
extract_code_summary,
|
||||
)
|
||||
|
||||
return extract_code_summary(content=content)
|
||||
else:
|
||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.extraction import (
|
||||
extract_code_summary,
|
||||
)
|
||||
|
||||
return extract_code_summary(content=content)
|
||||
|
||||
@staticmethod
|
||||
def extract_summary(content: str, response_model: Type[BaseModel]) -> Coroutine:
|
||||
# llm_config = get_llm_config()
|
||||
# if llm_config.structured_output_framework.upper() == "BAML":
|
||||
# from cognee.infrastructure.llm.structured_output_framework.baml.baml_src.extraction import (
|
||||
# extract_summary,
|
||||
# )
|
||||
#
|
||||
# return extract_summary(content=content, response_model=response_model)
|
||||
# else:
|
||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.extraction import (
|
||||
extract_summary,
|
||||
)
|
||||
|
||||
return extract_summary(content=content, response_model=response_model)
|
||||
|
||||
@staticmethod
|
||||
def extract_event_graph(content: str, response_model: Type[BaseModel]) -> Coroutine:
|
||||
# TODO: Add BAML version of category and extraction and update function (consulted with Igor)
|
||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.extraction import (
|
||||
extract_event_graph,
|
||||
)
|
||||
|
||||
return extract_event_graph(content=content, response_model=response_model)
|
||||
|
||||
@staticmethod
|
||||
def extract_event_entities(content: str, response_model: Type[BaseModel]) -> Coroutine:
|
||||
# TODO: Add BAML version of category and extraction and update function (consulted with Igor)
|
||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.extraction import (
|
||||
extract_event_entities,
|
||||
)
|
||||
|
||||
return extract_event_entities(content=content, response_model=response_model)
|
||||
|
|
|
|||
|
|
@ -1,11 +1,12 @@
|
|||
from typing import Type
|
||||
from pydantic import BaseModel
|
||||
|
||||
from cognee.infrastructure.llm.prompts import read_query_prompt
|
||||
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
||||
|
||||
|
||||
async def extract_categories(content: str, response_model: Type[BaseModel]):
|
||||
system_prompt = LLMGateway.read_query_prompt("classify_content.txt")
|
||||
system_prompt = read_query_prompt("classify_content.txt")
|
||||
|
||||
llm_output = await LLMGateway.acreate_structured_output(content, system_prompt, response_model)
|
||||
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
import os
|
||||
from typing import List, Type
|
||||
from typing import Type
|
||||
from pydantic import BaseModel
|
||||
from cognee.infrastructure.llm.prompts.render_prompt import render_prompt
|
||||
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
||||
from cognee.infrastructure.llm.config import (
|
||||
get_llm_config,
|
||||
|
|
@ -35,7 +36,7 @@ async def extract_event_entities(content: str, response_model: Type[BaseModel]):
|
|||
else:
|
||||
base_directory = None
|
||||
|
||||
system_prompt = LLMGateway.render_prompt(prompt_path, {}, base_directory=base_directory)
|
||||
system_prompt = render_prompt(prompt_path, {}, base_directory=base_directory)
|
||||
|
||||
content_graph = await LLMGateway.acreate_structured_output(
|
||||
content, system_prompt, response_model
|
||||
|
|
@ -2,7 +2,8 @@ from cognee.shared.logging_utils import get_logger
|
|||
import os
|
||||
from typing import Type
|
||||
|
||||
from instructor.exceptions import InstructorRetryException
|
||||
from instructor.core import InstructorRetryException
|
||||
from cognee.infrastructure.llm.prompts import read_query_prompt
|
||||
from pydantic import BaseModel
|
||||
|
||||
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
||||
|
|
@ -25,7 +26,7 @@ def get_mock_summarized_code():
|
|||
|
||||
|
||||
async def extract_summary(content: str, response_model: Type[BaseModel]):
|
||||
system_prompt = LLMGateway.read_query_prompt("summarize_content.txt")
|
||||
system_prompt = read_query_prompt("summarize_content.txt")
|
||||
|
||||
llm_output = await LLMGateway.acreate_structured_output(content, system_prompt, response_model)
|
||||
|
||||
|
|
@ -2,6 +2,7 @@ import os
|
|||
from typing import Type, Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
from cognee.infrastructure.llm.prompts import render_prompt
|
||||
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
||||
from cognee.infrastructure.llm.config import (
|
||||
get_llm_config,
|
||||
|
|
@ -26,7 +27,7 @@ async def extract_content_graph(
|
|||
else:
|
||||
base_directory = None
|
||||
|
||||
system_prompt = LLMGateway.render_prompt(prompt_path, {}, base_directory=base_directory)
|
||||
system_prompt = render_prompt(prompt_path, {}, base_directory=base_directory)
|
||||
|
||||
content_graph = await LLMGateway.acreate_structured_output(
|
||||
content, system_prompt, response_model
|
||||
|
|
@ -1,8 +1,9 @@
|
|||
import os
|
||||
from pydantic import BaseModel
|
||||
from typing import Type
|
||||
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
||||
|
||||
from cognee.infrastructure.llm.prompts import render_prompt
|
||||
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
||||
from cognee.infrastructure.llm.config import (
|
||||
get_llm_config,
|
||||
)
|
||||
|
|
@ -37,7 +38,7 @@ async def extract_event_graph(content: str, response_model: Type[BaseModel]):
|
|||
else:
|
||||
base_directory = None
|
||||
|
||||
system_prompt = LLMGateway.render_prompt(prompt_path, {}, base_directory=base_directory)
|
||||
system_prompt = render_prompt(prompt_path, {}, base_directory=base_directory)
|
||||
|
||||
content_graph = await LLMGateway.acreate_structured_output(
|
||||
content, system_prompt, response_model
|
||||
35
cognee/infrastructure/llm/prompts/show_prompt.py
Normal file
35
cognee/infrastructure/llm/prompts/show_prompt.py
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
from cognee.infrastructure.llm.exceptions import MissingSystemPromptPathError
|
||||
from cognee.infrastructure.llm.prompts import read_query_prompt
|
||||
|
||||
|
||||
def show_prompt(text_input: str, system_prompt: str) -> str:
|
||||
"""
|
||||
Format and display the prompt for a user query.
|
||||
|
||||
This method formats the prompt using the provided user input and system prompt,
|
||||
returning a string representation. Raises MissingSystemPromptPathError if the system prompt is not
|
||||
provided.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- text_input (str): The input text provided by the user.
|
||||
- system_prompt (str): The system's prompt to guide the model's response.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- str: A formatted string representing the user input and system prompt.
|
||||
"""
|
||||
if not text_input:
|
||||
text_input = "No user input provided."
|
||||
if not system_prompt:
|
||||
raise MissingSystemPromptPathError()
|
||||
system_prompt = read_query_prompt(system_prompt)
|
||||
|
||||
formatted_prompt = (
|
||||
f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n"""
|
||||
if system_prompt
|
||||
else None
|
||||
)
|
||||
return formatted_prompt
|
||||
|
|
@ -1,3 +1 @@
|
|||
from .knowledge_graph.extract_content_graph import extract_content_graph
|
||||
from .extract_summary import extract_summary, extract_code_summary
|
||||
from .acreate_structured_output import acreate_structured_output
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from baml_py.baml_py import ClassBuilder
|
|||
|
||||
from cognee.infrastructure.llm.config import get_llm_config
|
||||
|
||||
from typing import List, Union, Optional, Literal
|
||||
from typing import Union, Literal
|
||||
from enum import Enum
|
||||
from datetime import datetime
|
||||
|
||||
|
|
@ -16,32 +16,7 @@ from cognee.infrastructure.llm.structured_output_framework.baml.baml_client impo
|
|||
from pydantic import BaseModel
|
||||
from typing import get_origin, get_args
|
||||
|
||||
logger = get_logger("extract_summary_baml")
|
||||
|
||||
|
||||
class SummarizedFunction(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
inputs: Optional[List[str]] = None
|
||||
outputs: Optional[List[str]] = None
|
||||
decorators: Optional[List[str]] = None
|
||||
|
||||
|
||||
class SummarizedClass(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
methods: Optional[List[SummarizedFunction]] = None
|
||||
decorators: Optional[List[str]] = None
|
||||
|
||||
|
||||
class SummarizedCode(BaseModel):
|
||||
high_level_summary: str
|
||||
key_features: List[str]
|
||||
imports: List[str] = []
|
||||
constants: List[str] = []
|
||||
classes: List[SummarizedClass] = []
|
||||
functions: List[SummarizedFunction] = []
|
||||
workflow_description: Optional[str] = None
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def create_dynamic_baml_type(tb, baml_model, pydantic_model):
|
||||
|
|
@ -179,7 +154,7 @@ async def acreate_structured_output(
|
|||
config = get_llm_config()
|
||||
tb = TypeBuilder()
|
||||
|
||||
type_builder = create_dynamic_baml_type(tb, tb.ResponseModel, SummarizedCode)
|
||||
type_builder = create_dynamic_baml_type(tb, tb.ResponseModel, response_model)
|
||||
|
||||
result = await b.AcreateStructuredOutput(
|
||||
text_input=text_input,
|
||||
|
|
@ -187,13 +162,15 @@ async def acreate_structured_output(
|
|||
baml_options={"client_registry": config.baml_registry, "tb": type_builder},
|
||||
)
|
||||
|
||||
return result
|
||||
if response_model is str:
|
||||
return result
|
||||
return response_model.model_validate(result.dict())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(acreate_structured_output("TEST", "THIS IS A TEST", SummarizedCode))
|
||||
loop.run_until_complete(acreate_structured_output("TEST", "THIS IS A TEST", str))
|
||||
finally:
|
||||
loop.run_until_complete(loop.shutdown_asyncgens())
|
||||
|
|
|
|||
|
|
@ -1,89 +0,0 @@
|
|||
import os
|
||||
from typing import Type
|
||||
from pydantic import BaseModel
|
||||
from baml_py import ClientRegistry
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.shared.data_models import SummarizedCode
|
||||
from cognee.infrastructure.llm.structured_output_framework.baml.baml_client.async_client import b
|
||||
from cognee.infrastructure.llm.config import get_llm_config
|
||||
|
||||
|
||||
logger = get_logger("extract_summary_baml")
|
||||
|
||||
|
||||
def get_mock_summarized_code():
|
||||
"""Local mock function to avoid circular imports."""
|
||||
return SummarizedCode(
|
||||
high_level_summary="Mock code summary",
|
||||
key_features=["Mock feature 1", "Mock feature 2"],
|
||||
imports=["mock_import"],
|
||||
constants=["MOCK_CONSTANT"],
|
||||
classes=[],
|
||||
functions=[],
|
||||
workflow_description="Mock workflow description",
|
||||
)
|
||||
|
||||
|
||||
async def extract_summary(content: str, response_model: Type[BaseModel]):
|
||||
"""
|
||||
Extract summary using BAML framework.
|
||||
|
||||
Args:
|
||||
content: The content to summarize
|
||||
response_model: The Pydantic model type for the response
|
||||
|
||||
Returns:
|
||||
BaseModel: The summarized content in the specified format
|
||||
"""
|
||||
config = get_llm_config()
|
||||
|
||||
# Use BAML's SummarizeContent function
|
||||
summary_result = await b.SummarizeContent(
|
||||
content, baml_options={"client_registry": config.baml_registry}
|
||||
)
|
||||
|
||||
# Convert BAML result to the expected response model
|
||||
if response_model is SummarizedCode:
|
||||
# If it's asking for SummarizedCode but we got SummarizedContent,
|
||||
# we need to use SummarizeCode instead
|
||||
code_result = await b.SummarizeCode(
|
||||
content, baml_options={"client_registry": config.baml_registry}
|
||||
)
|
||||
return code_result
|
||||
else:
|
||||
# For other models, return the summary result
|
||||
return summary_result
|
||||
|
||||
|
||||
async def extract_code_summary(content: str):
|
||||
"""
|
||||
Extract code summary using BAML framework with mocking support.
|
||||
|
||||
Args:
|
||||
content: The code content to summarize
|
||||
|
||||
Returns:
|
||||
SummarizedCode: The summarized code information
|
||||
"""
|
||||
enable_mocking = os.getenv("MOCK_CODE_SUMMARY", "false")
|
||||
if isinstance(enable_mocking, bool):
|
||||
enable_mocking = str(enable_mocking).lower()
|
||||
enable_mocking = enable_mocking in ("true", "1", "yes")
|
||||
|
||||
if enable_mocking:
|
||||
result = get_mock_summarized_code()
|
||||
return result
|
||||
else:
|
||||
try:
|
||||
config = get_llm_config()
|
||||
|
||||
result = await b.SummarizeCode(
|
||||
content, baml_options={"client_registry": config.baml_registry}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to extract code summary with BAML, falling back to mock summary", exc_info=e
|
||||
)
|
||||
result = get_mock_summarized_code()
|
||||
|
||||
return result
|
||||
|
|
@ -1,44 +0,0 @@
|
|||
from typing import Type, Optional
|
||||
from pydantic import BaseModel
|
||||
from cognee.infrastructure.llm.config import get_llm_config
|
||||
from cognee.shared.logging_utils import get_logger, setup_logging
|
||||
from cognee.infrastructure.llm.structured_output_framework.baml.baml_client.async_client import b
|
||||
|
||||
|
||||
async def extract_content_graph(
|
||||
content: str,
|
||||
response_model: Type[BaseModel],
|
||||
mode: str = "simple",
|
||||
custom_prompt: Optional[str] = None,
|
||||
):
|
||||
config = get_llm_config()
|
||||
setup_logging()
|
||||
|
||||
get_logger(level="INFO")
|
||||
|
||||
# if response_model:
|
||||
# # tb = TypeBuilder()
|
||||
# # country = tb.union \
|
||||
# # ([tb.literal_string("USA"), tb.literal_string("UK"), tb.literal_string("Germany"), tb.literal_string("other")])
|
||||
# # tb.Node.add_property("country", country)
|
||||
#
|
||||
# graph = await b.ExtractDynamicContentGraph(
|
||||
# content, mode=mode, baml_options={"client_registry": baml_registry}
|
||||
# )
|
||||
#
|
||||
# return graph
|
||||
|
||||
# else:
|
||||
if custom_prompt:
|
||||
graph = await b.ExtractContentGraphGeneric(
|
||||
content,
|
||||
mode="custom",
|
||||
custom_prompt_content=custom_prompt,
|
||||
baml_options={"client_registry": config.baml_registry},
|
||||
)
|
||||
else:
|
||||
graph = await b.ExtractContentGraphGeneric(
|
||||
content, mode=mode, baml_options={"client_registry": config.baml_registry}
|
||||
)
|
||||
|
||||
return graph
|
||||
|
|
@ -328,35 +328,3 @@ class OpenAIAdapter(LLMInterface):
|
|||
max_completion_tokens=300,
|
||||
max_retries=self.MAX_RETRIES,
|
||||
)
|
||||
|
||||
def show_prompt(self, text_input: str, system_prompt: str) -> str:
|
||||
"""
|
||||
Format and display the prompt for a user query.
|
||||
|
||||
This method formats the prompt using the provided user input and system prompt,
|
||||
returning a string representation. Raises MissingSystemPromptPathError if the system prompt is not
|
||||
provided.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- text_input (str): The input text provided by the user.
|
||||
- system_prompt (str): The system's prompt to guide the model's response.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- str: A formatted string representing the user input and system prompt.
|
||||
"""
|
||||
if not text_input:
|
||||
text_input = "No user input provided."
|
||||
if not system_prompt:
|
||||
raise MissingSystemPromptPathError()
|
||||
system_prompt = LLMGateway.read_query_prompt(system_prompt)
|
||||
|
||||
formatted_prompt = (
|
||||
f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n"""
|
||||
if system_prompt
|
||||
else None
|
||||
)
|
||||
return formatted_prompt
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from cognee.modules.graph.utils import (
|
|||
retrieve_existing_edges,
|
||||
)
|
||||
from cognee.shared.data_models import KnowledgeGraph
|
||||
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
||||
from cognee.infrastructure.llm.extraction import extract_content_graph
|
||||
from cognee.tasks.graph.exceptions import (
|
||||
InvalidGraphModelError,
|
||||
InvalidDataChunksError,
|
||||
|
|
@ -86,7 +86,7 @@ async def extract_graph_from_data(
|
|||
|
||||
chunk_graphs = await asyncio.gather(
|
||||
*[
|
||||
LLMGateway.extract_content_graph(chunk.text, graph_model, custom_prompt=custom_prompt)
|
||||
extract_content_graph(chunk.text, graph_model, custom_prompt=custom_prompt)
|
||||
for chunk in data_chunks
|
||||
]
|
||||
)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from pydantic import BaseModel
|
|||
|
||||
from cognee.tasks.summarization.exceptions import InvalidSummaryInputsError
|
||||
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
|
||||
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
||||
from cognee.infrastructure.llm.extraction import extract_summary
|
||||
from cognee.modules.cognify.config import get_cognify_config
|
||||
from cognee.tasks.summarization.models import TextSummary
|
||||
|
||||
|
|
@ -50,7 +50,7 @@ async def summarize_text(
|
|||
summarization_model = cognee_config.summarization_model
|
||||
|
||||
chunk_summaries = await asyncio.gather(
|
||||
*[LLMGateway.extract_summary(chunk.text, summarization_model) for chunk in data_chunks]
|
||||
*[extract_summary(chunk.text, summarization_model) for chunk in data_chunks]
|
||||
)
|
||||
|
||||
summaries = [
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue