feat: Add baml dynamic typing
This commit is contained in:
parent
59cd31b916
commit
89b51a244d
19 changed files with 67 additions and 326 deletions
|
|
@ -9,12 +9,6 @@ class LLMGateway:
|
||||||
Class used as a namespace for LLM related functions, should not be instantiated, all methods are static.
|
Class used as a namespace for LLM related functions, should not be instantiated, all methods are static.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def render_prompt(filename: str, context: dict, base_directory: str = None):
|
|
||||||
from cognee.infrastructure.llm.prompts import render_prompt
|
|
||||||
|
|
||||||
return render_prompt(filename=filename, context=context, base_directory=base_directory)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def acreate_structured_output(
|
def acreate_structured_output(
|
||||||
text_input: str, system_prompt: str, response_model: Type[BaseModel]
|
text_input: str, system_prompt: str, response_model: Type[BaseModel]
|
||||||
|
|
@ -30,14 +24,15 @@ class LLMGateway:
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
response_model=response_model,
|
response_model=response_model,
|
||||||
)
|
)
|
||||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.get_llm_client import (
|
else:
|
||||||
get_llm_client,
|
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.get_llm_client import (
|
||||||
)
|
get_llm_client,
|
||||||
|
)
|
||||||
|
|
||||||
llm_client = get_llm_client()
|
llm_client = get_llm_client()
|
||||||
return llm_client.acreate_structured_output(
|
return llm_client.acreate_structured_output(
|
||||||
text_input=text_input, system_prompt=system_prompt, response_model=response_model
|
text_input=text_input, system_prompt=system_prompt, response_model=response_model
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_structured_output(
|
def create_structured_output(
|
||||||
|
|
@ -69,107 +64,3 @@ class LLMGateway:
|
||||||
|
|
||||||
llm_client = get_llm_client()
|
llm_client = get_llm_client()
|
||||||
return llm_client.transcribe_image(input=input)
|
return llm_client.transcribe_image(input=input)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def show_prompt(text_input: str, system_prompt: str) -> str:
|
|
||||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.get_llm_client import (
|
|
||||||
get_llm_client,
|
|
||||||
)
|
|
||||||
|
|
||||||
llm_client = get_llm_client()
|
|
||||||
return llm_client.show_prompt(text_input=text_input, system_prompt=system_prompt)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def read_query_prompt(prompt_file_name: str, base_directory: str = None):
|
|
||||||
from cognee.infrastructure.llm.prompts import (
|
|
||||||
read_query_prompt,
|
|
||||||
)
|
|
||||||
|
|
||||||
return read_query_prompt(prompt_file_name=prompt_file_name, base_directory=base_directory)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def extract_content_graph(
|
|
||||||
content: str,
|
|
||||||
response_model: Type[BaseModel],
|
|
||||||
mode: str = "simple",
|
|
||||||
custom_prompt: Optional[str] = None,
|
|
||||||
) -> Coroutine:
|
|
||||||
llm_config = get_llm_config()
|
|
||||||
if llm_config.structured_output_framework.upper() == "BAML":
|
|
||||||
from cognee.infrastructure.llm.structured_output_framework.baml.baml_src.extraction import (
|
|
||||||
extract_content_graph,
|
|
||||||
)
|
|
||||||
|
|
||||||
return extract_content_graph(
|
|
||||||
content=content,
|
|
||||||
response_model=response_model,
|
|
||||||
mode=mode,
|
|
||||||
custom_prompt=custom_prompt,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.extraction import (
|
|
||||||
extract_content_graph,
|
|
||||||
)
|
|
||||||
|
|
||||||
return extract_content_graph(
|
|
||||||
content=content, response_model=response_model, custom_prompt=custom_prompt
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def extract_categories(content: str, response_model: Type[BaseModel]) -> Coroutine:
|
|
||||||
# TODO: Add BAML version of category and extraction and update function
|
|
||||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.extraction import (
|
|
||||||
extract_categories,
|
|
||||||
)
|
|
||||||
|
|
||||||
return extract_categories(content=content, response_model=response_model)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def extract_code_summary(content: str) -> Coroutine:
|
|
||||||
llm_config = get_llm_config()
|
|
||||||
if llm_config.structured_output_framework.upper() == "BAML":
|
|
||||||
from cognee.infrastructure.llm.structured_output_framework.baml.baml_src.extraction import (
|
|
||||||
extract_code_summary,
|
|
||||||
)
|
|
||||||
|
|
||||||
return extract_code_summary(content=content)
|
|
||||||
else:
|
|
||||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.extraction import (
|
|
||||||
extract_code_summary,
|
|
||||||
)
|
|
||||||
|
|
||||||
return extract_code_summary(content=content)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def extract_summary(content: str, response_model: Type[BaseModel]) -> Coroutine:
|
|
||||||
# llm_config = get_llm_config()
|
|
||||||
# if llm_config.structured_output_framework.upper() == "BAML":
|
|
||||||
# from cognee.infrastructure.llm.structured_output_framework.baml.baml_src.extraction import (
|
|
||||||
# extract_summary,
|
|
||||||
# )
|
|
||||||
#
|
|
||||||
# return extract_summary(content=content, response_model=response_model)
|
|
||||||
# else:
|
|
||||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.extraction import (
|
|
||||||
extract_summary,
|
|
||||||
)
|
|
||||||
|
|
||||||
return extract_summary(content=content, response_model=response_model)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def extract_event_graph(content: str, response_model: Type[BaseModel]) -> Coroutine:
|
|
||||||
# TODO: Add BAML version of category and extraction and update function (consulted with Igor)
|
|
||||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.extraction import (
|
|
||||||
extract_event_graph,
|
|
||||||
)
|
|
||||||
|
|
||||||
return extract_event_graph(content=content, response_model=response_model)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def extract_event_entities(content: str, response_model: Type[BaseModel]) -> Coroutine:
|
|
||||||
# TODO: Add BAML version of category and extraction and update function (consulted with Igor)
|
|
||||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.extraction import (
|
|
||||||
extract_event_entities,
|
|
||||||
)
|
|
||||||
|
|
||||||
return extract_event_entities(content=content, response_model=response_model)
|
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,12 @@
|
||||||
from typing import Type
|
from typing import Type
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from cognee.infrastructure.llm.prompts import read_query_prompt
|
||||||
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
||||||
|
|
||||||
|
|
||||||
async def extract_categories(content: str, response_model: Type[BaseModel]):
|
async def extract_categories(content: str, response_model: Type[BaseModel]):
|
||||||
system_prompt = LLMGateway.read_query_prompt("classify_content.txt")
|
system_prompt = read_query_prompt("classify_content.txt")
|
||||||
|
|
||||||
llm_output = await LLMGateway.acreate_structured_output(content, system_prompt, response_model)
|
llm_output = await LLMGateway.acreate_structured_output(content, system_prompt, response_model)
|
||||||
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
import os
|
import os
|
||||||
from typing import List, Type
|
from typing import Type
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
from cognee.infrastructure.llm.prompts.render_prompt import render_prompt
|
||||||
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
||||||
from cognee.infrastructure.llm.config import (
|
from cognee.infrastructure.llm.config import (
|
||||||
get_llm_config,
|
get_llm_config,
|
||||||
|
|
@ -35,7 +36,7 @@ async def extract_event_entities(content: str, response_model: Type[BaseModel]):
|
||||||
else:
|
else:
|
||||||
base_directory = None
|
base_directory = None
|
||||||
|
|
||||||
system_prompt = LLMGateway.render_prompt(prompt_path, {}, base_directory=base_directory)
|
system_prompt = render_prompt(prompt_path, {}, base_directory=base_directory)
|
||||||
|
|
||||||
content_graph = await LLMGateway.acreate_structured_output(
|
content_graph = await LLMGateway.acreate_structured_output(
|
||||||
content, system_prompt, response_model
|
content, system_prompt, response_model
|
||||||
|
|
@ -2,7 +2,8 @@ from cognee.shared.logging_utils import get_logger
|
||||||
import os
|
import os
|
||||||
from typing import Type
|
from typing import Type
|
||||||
|
|
||||||
from instructor.exceptions import InstructorRetryException
|
from instructor.core import InstructorRetryException
|
||||||
|
from cognee.infrastructure.llm.prompts import read_query_prompt
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
||||||
|
|
@ -25,7 +26,7 @@ def get_mock_summarized_code():
|
||||||
|
|
||||||
|
|
||||||
async def extract_summary(content: str, response_model: Type[BaseModel]):
|
async def extract_summary(content: str, response_model: Type[BaseModel]):
|
||||||
system_prompt = LLMGateway.read_query_prompt("summarize_content.txt")
|
system_prompt = read_query_prompt("summarize_content.txt")
|
||||||
|
|
||||||
llm_output = await LLMGateway.acreate_structured_output(content, system_prompt, response_model)
|
llm_output = await LLMGateway.acreate_structured_output(content, system_prompt, response_model)
|
||||||
|
|
||||||
|
|
@ -2,6 +2,7 @@ import os
|
||||||
from typing import Type, Optional
|
from typing import Type, Optional
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from cognee.infrastructure.llm.prompts import render_prompt
|
||||||
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
||||||
from cognee.infrastructure.llm.config import (
|
from cognee.infrastructure.llm.config import (
|
||||||
get_llm_config,
|
get_llm_config,
|
||||||
|
|
@ -26,7 +27,7 @@ async def extract_content_graph(
|
||||||
else:
|
else:
|
||||||
base_directory = None
|
base_directory = None
|
||||||
|
|
||||||
system_prompt = LLMGateway.render_prompt(prompt_path, {}, base_directory=base_directory)
|
system_prompt = render_prompt(prompt_path, {}, base_directory=base_directory)
|
||||||
|
|
||||||
content_graph = await LLMGateway.acreate_structured_output(
|
content_graph = await LLMGateway.acreate_structured_output(
|
||||||
content, system_prompt, response_model
|
content, system_prompt, response_model
|
||||||
|
|
@ -1,8 +1,9 @@
|
||||||
import os
|
import os
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import Type
|
from typing import Type
|
||||||
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
|
||||||
|
|
||||||
|
from cognee.infrastructure.llm.prompts import render_prompt
|
||||||
|
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
||||||
from cognee.infrastructure.llm.config import (
|
from cognee.infrastructure.llm.config import (
|
||||||
get_llm_config,
|
get_llm_config,
|
||||||
)
|
)
|
||||||
|
|
@ -37,7 +38,7 @@ async def extract_event_graph(content: str, response_model: Type[BaseModel]):
|
||||||
else:
|
else:
|
||||||
base_directory = None
|
base_directory = None
|
||||||
|
|
||||||
system_prompt = LLMGateway.render_prompt(prompt_path, {}, base_directory=base_directory)
|
system_prompt = render_prompt(prompt_path, {}, base_directory=base_directory)
|
||||||
|
|
||||||
content_graph = await LLMGateway.acreate_structured_output(
|
content_graph = await LLMGateway.acreate_structured_output(
|
||||||
content, system_prompt, response_model
|
content, system_prompt, response_model
|
||||||
35
cognee/infrastructure/llm/prompts/show_prompt.py
Normal file
35
cognee/infrastructure/llm/prompts/show_prompt.py
Normal file
|
|
@ -0,0 +1,35 @@
|
||||||
|
from cognee.infrastructure.llm.exceptions import MissingSystemPromptPathError
|
||||||
|
from cognee.infrastructure.llm.prompts import read_query_prompt
|
||||||
|
|
||||||
|
|
||||||
|
def show_prompt(text_input: str, system_prompt: str) -> str:
|
||||||
|
"""
|
||||||
|
Format and display the prompt for a user query.
|
||||||
|
|
||||||
|
This method formats the prompt using the provided user input and system prompt,
|
||||||
|
returning a string representation. Raises MissingSystemPromptPathError if the system prompt is not
|
||||||
|
provided.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
-----------
|
||||||
|
|
||||||
|
- text_input (str): The input text provided by the user.
|
||||||
|
- system_prompt (str): The system's prompt to guide the model's response.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
--------
|
||||||
|
|
||||||
|
- str: A formatted string representing the user input and system prompt.
|
||||||
|
"""
|
||||||
|
if not text_input:
|
||||||
|
text_input = "No user input provided."
|
||||||
|
if not system_prompt:
|
||||||
|
raise MissingSystemPromptPathError()
|
||||||
|
system_prompt = read_query_prompt(system_prompt)
|
||||||
|
|
||||||
|
formatted_prompt = (
|
||||||
|
f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n"""
|
||||||
|
if system_prompt
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
return formatted_prompt
|
||||||
|
|
@ -1,3 +1 @@
|
||||||
from .knowledge_graph.extract_content_graph import extract_content_graph
|
|
||||||
from .extract_summary import extract_summary, extract_code_summary
|
|
||||||
from .acreate_structured_output import acreate_structured_output
|
from .acreate_structured_output import acreate_structured_output
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ from baml_py.baml_py import ClassBuilder
|
||||||
|
|
||||||
from cognee.infrastructure.llm.config import get_llm_config
|
from cognee.infrastructure.llm.config import get_llm_config
|
||||||
|
|
||||||
from typing import List, Union, Optional, Literal
|
from typing import Union, Literal
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
@ -16,32 +16,7 @@ from cognee.infrastructure.llm.structured_output_framework.baml.baml_client impo
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import get_origin, get_args
|
from typing import get_origin, get_args
|
||||||
|
|
||||||
logger = get_logger("extract_summary_baml")
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
class SummarizedFunction(BaseModel):
|
|
||||||
name: str
|
|
||||||
description: str
|
|
||||||
inputs: Optional[List[str]] = None
|
|
||||||
outputs: Optional[List[str]] = None
|
|
||||||
decorators: Optional[List[str]] = None
|
|
||||||
|
|
||||||
|
|
||||||
class SummarizedClass(BaseModel):
|
|
||||||
name: str
|
|
||||||
description: str
|
|
||||||
methods: Optional[List[SummarizedFunction]] = None
|
|
||||||
decorators: Optional[List[str]] = None
|
|
||||||
|
|
||||||
|
|
||||||
class SummarizedCode(BaseModel):
|
|
||||||
high_level_summary: str
|
|
||||||
key_features: List[str]
|
|
||||||
imports: List[str] = []
|
|
||||||
constants: List[str] = []
|
|
||||||
classes: List[SummarizedClass] = []
|
|
||||||
functions: List[SummarizedFunction] = []
|
|
||||||
workflow_description: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
def create_dynamic_baml_type(tb, baml_model, pydantic_model):
|
def create_dynamic_baml_type(tb, baml_model, pydantic_model):
|
||||||
|
|
@ -179,7 +154,7 @@ async def acreate_structured_output(
|
||||||
config = get_llm_config()
|
config = get_llm_config()
|
||||||
tb = TypeBuilder()
|
tb = TypeBuilder()
|
||||||
|
|
||||||
type_builder = create_dynamic_baml_type(tb, tb.ResponseModel, SummarizedCode)
|
type_builder = create_dynamic_baml_type(tb, tb.ResponseModel, response_model)
|
||||||
|
|
||||||
result = await b.AcreateStructuredOutput(
|
result = await b.AcreateStructuredOutput(
|
||||||
text_input=text_input,
|
text_input=text_input,
|
||||||
|
|
@ -187,13 +162,15 @@ async def acreate_structured_output(
|
||||||
baml_options={"client_registry": config.baml_registry, "tb": type_builder},
|
baml_options={"client_registry": config.baml_registry, "tb": type_builder},
|
||||||
)
|
)
|
||||||
|
|
||||||
return result
|
if response_model is str:
|
||||||
|
return result
|
||||||
|
return response_model.model_validate(result.dict())
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
loop = asyncio.new_event_loop()
|
loop = asyncio.new_event_loop()
|
||||||
asyncio.set_event_loop(loop)
|
asyncio.set_event_loop(loop)
|
||||||
try:
|
try:
|
||||||
loop.run_until_complete(acreate_structured_output("TEST", "THIS IS A TEST", SummarizedCode))
|
loop.run_until_complete(acreate_structured_output("TEST", "THIS IS A TEST", str))
|
||||||
finally:
|
finally:
|
||||||
loop.run_until_complete(loop.shutdown_asyncgens())
|
loop.run_until_complete(loop.shutdown_asyncgens())
|
||||||
|
|
|
||||||
|
|
@ -1,89 +0,0 @@
|
||||||
import os
|
|
||||||
from typing import Type
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from baml_py import ClientRegistry
|
|
||||||
from cognee.shared.logging_utils import get_logger
|
|
||||||
from cognee.shared.data_models import SummarizedCode
|
|
||||||
from cognee.infrastructure.llm.structured_output_framework.baml.baml_client.async_client import b
|
|
||||||
from cognee.infrastructure.llm.config import get_llm_config
|
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger("extract_summary_baml")
|
|
||||||
|
|
||||||
|
|
||||||
def get_mock_summarized_code():
|
|
||||||
"""Local mock function to avoid circular imports."""
|
|
||||||
return SummarizedCode(
|
|
||||||
high_level_summary="Mock code summary",
|
|
||||||
key_features=["Mock feature 1", "Mock feature 2"],
|
|
||||||
imports=["mock_import"],
|
|
||||||
constants=["MOCK_CONSTANT"],
|
|
||||||
classes=[],
|
|
||||||
functions=[],
|
|
||||||
workflow_description="Mock workflow description",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def extract_summary(content: str, response_model: Type[BaseModel]):
|
|
||||||
"""
|
|
||||||
Extract summary using BAML framework.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
content: The content to summarize
|
|
||||||
response_model: The Pydantic model type for the response
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
BaseModel: The summarized content in the specified format
|
|
||||||
"""
|
|
||||||
config = get_llm_config()
|
|
||||||
|
|
||||||
# Use BAML's SummarizeContent function
|
|
||||||
summary_result = await b.SummarizeContent(
|
|
||||||
content, baml_options={"client_registry": config.baml_registry}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Convert BAML result to the expected response model
|
|
||||||
if response_model is SummarizedCode:
|
|
||||||
# If it's asking for SummarizedCode but we got SummarizedContent,
|
|
||||||
# we need to use SummarizeCode instead
|
|
||||||
code_result = await b.SummarizeCode(
|
|
||||||
content, baml_options={"client_registry": config.baml_registry}
|
|
||||||
)
|
|
||||||
return code_result
|
|
||||||
else:
|
|
||||||
# For other models, return the summary result
|
|
||||||
return summary_result
|
|
||||||
|
|
||||||
|
|
||||||
async def extract_code_summary(content: str):
|
|
||||||
"""
|
|
||||||
Extract code summary using BAML framework with mocking support.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
content: The code content to summarize
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
SummarizedCode: The summarized code information
|
|
||||||
"""
|
|
||||||
enable_mocking = os.getenv("MOCK_CODE_SUMMARY", "false")
|
|
||||||
if isinstance(enable_mocking, bool):
|
|
||||||
enable_mocking = str(enable_mocking).lower()
|
|
||||||
enable_mocking = enable_mocking in ("true", "1", "yes")
|
|
||||||
|
|
||||||
if enable_mocking:
|
|
||||||
result = get_mock_summarized_code()
|
|
||||||
return result
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
config = get_llm_config()
|
|
||||||
|
|
||||||
result = await b.SummarizeCode(
|
|
||||||
content, baml_options={"client_registry": config.baml_registry}
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
"Failed to extract code summary with BAML, falling back to mock summary", exc_info=e
|
|
||||||
)
|
|
||||||
result = get_mock_summarized_code()
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
@ -1,44 +0,0 @@
|
||||||
from typing import Type, Optional
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from cognee.infrastructure.llm.config import get_llm_config
|
|
||||||
from cognee.shared.logging_utils import get_logger, setup_logging
|
|
||||||
from cognee.infrastructure.llm.structured_output_framework.baml.baml_client.async_client import b
|
|
||||||
|
|
||||||
|
|
||||||
async def extract_content_graph(
|
|
||||||
content: str,
|
|
||||||
response_model: Type[BaseModel],
|
|
||||||
mode: str = "simple",
|
|
||||||
custom_prompt: Optional[str] = None,
|
|
||||||
):
|
|
||||||
config = get_llm_config()
|
|
||||||
setup_logging()
|
|
||||||
|
|
||||||
get_logger(level="INFO")
|
|
||||||
|
|
||||||
# if response_model:
|
|
||||||
# # tb = TypeBuilder()
|
|
||||||
# # country = tb.union \
|
|
||||||
# # ([tb.literal_string("USA"), tb.literal_string("UK"), tb.literal_string("Germany"), tb.literal_string("other")])
|
|
||||||
# # tb.Node.add_property("country", country)
|
|
||||||
#
|
|
||||||
# graph = await b.ExtractDynamicContentGraph(
|
|
||||||
# content, mode=mode, baml_options={"client_registry": baml_registry}
|
|
||||||
# )
|
|
||||||
#
|
|
||||||
# return graph
|
|
||||||
|
|
||||||
# else:
|
|
||||||
if custom_prompt:
|
|
||||||
graph = await b.ExtractContentGraphGeneric(
|
|
||||||
content,
|
|
||||||
mode="custom",
|
|
||||||
custom_prompt_content=custom_prompt,
|
|
||||||
baml_options={"client_registry": config.baml_registry},
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
graph = await b.ExtractContentGraphGeneric(
|
|
||||||
content, mode=mode, baml_options={"client_registry": config.baml_registry}
|
|
||||||
)
|
|
||||||
|
|
||||||
return graph
|
|
||||||
|
|
@ -328,35 +328,3 @@ class OpenAIAdapter(LLMInterface):
|
||||||
max_completion_tokens=300,
|
max_completion_tokens=300,
|
||||||
max_retries=self.MAX_RETRIES,
|
max_retries=self.MAX_RETRIES,
|
||||||
)
|
)
|
||||||
|
|
||||||
def show_prompt(self, text_input: str, system_prompt: str) -> str:
|
|
||||||
"""
|
|
||||||
Format and display the prompt for a user query.
|
|
||||||
|
|
||||||
This method formats the prompt using the provided user input and system prompt,
|
|
||||||
returning a string representation. Raises MissingSystemPromptPathError if the system prompt is not
|
|
||||||
provided.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
-----------
|
|
||||||
|
|
||||||
- text_input (str): The input text provided by the user.
|
|
||||||
- system_prompt (str): The system's prompt to guide the model's response.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
--------
|
|
||||||
|
|
||||||
- str: A formatted string representing the user input and system prompt.
|
|
||||||
"""
|
|
||||||
if not text_input:
|
|
||||||
text_input = "No user input provided."
|
|
||||||
if not system_prompt:
|
|
||||||
raise MissingSystemPromptPathError()
|
|
||||||
system_prompt = LLMGateway.read_query_prompt(system_prompt)
|
|
||||||
|
|
||||||
formatted_prompt = (
|
|
||||||
f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n"""
|
|
||||||
if system_prompt
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
return formatted_prompt
|
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ from cognee.modules.graph.utils import (
|
||||||
retrieve_existing_edges,
|
retrieve_existing_edges,
|
||||||
)
|
)
|
||||||
from cognee.shared.data_models import KnowledgeGraph
|
from cognee.shared.data_models import KnowledgeGraph
|
||||||
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
from cognee.infrastructure.llm.extraction import extract_content_graph
|
||||||
from cognee.tasks.graph.exceptions import (
|
from cognee.tasks.graph.exceptions import (
|
||||||
InvalidGraphModelError,
|
InvalidGraphModelError,
|
||||||
InvalidDataChunksError,
|
InvalidDataChunksError,
|
||||||
|
|
@ -86,7 +86,7 @@ async def extract_graph_from_data(
|
||||||
|
|
||||||
chunk_graphs = await asyncio.gather(
|
chunk_graphs = await asyncio.gather(
|
||||||
*[
|
*[
|
||||||
LLMGateway.extract_content_graph(chunk.text, graph_model, custom_prompt=custom_prompt)
|
extract_content_graph(chunk.text, graph_model, custom_prompt=custom_prompt)
|
||||||
for chunk in data_chunks
|
for chunk in data_chunks
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ from pydantic import BaseModel
|
||||||
|
|
||||||
from cognee.tasks.summarization.exceptions import InvalidSummaryInputsError
|
from cognee.tasks.summarization.exceptions import InvalidSummaryInputsError
|
||||||
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
|
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
|
||||||
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
from cognee.infrastructure.llm.extraction import extract_summary
|
||||||
from cognee.modules.cognify.config import get_cognify_config
|
from cognee.modules.cognify.config import get_cognify_config
|
||||||
from cognee.tasks.summarization.models import TextSummary
|
from cognee.tasks.summarization.models import TextSummary
|
||||||
|
|
||||||
|
|
@ -50,7 +50,7 @@ async def summarize_text(
|
||||||
summarization_model = cognee_config.summarization_model
|
summarization_model = cognee_config.summarization_model
|
||||||
|
|
||||||
chunk_summaries = await asyncio.gather(
|
chunk_summaries = await asyncio.gather(
|
||||||
*[LLMGateway.extract_summary(chunk.text, summarization_model) for chunk in data_chunks]
|
*[extract_summary(chunk.text, summarization_model) for chunk in data_chunks]
|
||||||
)
|
)
|
||||||
|
|
||||||
summaries = [
|
summaries = [
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue