refactor: Initial baml refactor commit
This commit is contained in:
parent
ce130f358a
commit
7c27546951
14 changed files with 335 additions and 10 deletions
|
|
@ -19,6 +19,17 @@ class LLMGateway:
|
||||||
def acreate_structured_output(
|
def acreate_structured_output(
|
||||||
text_input: str, system_prompt: str, response_model: Type[BaseModel]
|
text_input: str, system_prompt: str, response_model: Type[BaseModel]
|
||||||
) -> Coroutine:
|
) -> Coroutine:
|
||||||
|
llm_config = get_llm_config()
|
||||||
|
if llm_config.structured_output_framework.upper() == "BAML":
|
||||||
|
from cognee.infrastructure.llm.structured_output_framework.baml.baml_src.extraction import (
|
||||||
|
acreate_structured_output,
|
||||||
|
)
|
||||||
|
|
||||||
|
return acreate_structured_output(
|
||||||
|
content=text_input,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
response_model=response_model,
|
||||||
|
)
|
||||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.get_llm_client import (
|
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.get_llm_client import (
|
||||||
get_llm_client,
|
get_llm_client,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -39,7 +39,7 @@ with EnsureBamlPyImport(__version__) as e:
|
||||||
from . import config
|
from . import config
|
||||||
from .config import reset_baml_env_vars
|
from .config import reset_baml_env_vars
|
||||||
|
|
||||||
from .sync_client import b
|
from .async_client import b
|
||||||
|
|
||||||
|
|
||||||
# FOR LEGACY COMPATIBILITY, expose "partial_types" as an alias for "stream_types"
|
# FOR LEGACY COMPATIBILITY, expose "partial_types" as an alias for "stream_types"
|
||||||
|
|
|
||||||
|
|
@ -76,6 +76,25 @@ class BamlAsyncClient:
|
||||||
def parse_stream(self):
|
def parse_stream(self):
|
||||||
return self.__llm_stream_parser
|
return self.__llm_stream_parser
|
||||||
|
|
||||||
|
async def AcreateStructuredOutput(
|
||||||
|
self,
|
||||||
|
content: str,
|
||||||
|
system_prompt: str,
|
||||||
|
user_prompt: str,
|
||||||
|
baml_options: BamlCallOptions = {},
|
||||||
|
) -> types.DynamicModel:
|
||||||
|
result = await self.__options.merge_options(baml_options).call_function_async(
|
||||||
|
function_name="AcreateStructuredOutput",
|
||||||
|
args={
|
||||||
|
"content": content,
|
||||||
|
"system_prompt": system_prompt,
|
||||||
|
"user_prompt": user_prompt,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return typing.cast(
|
||||||
|
types.DynamicModel, result.cast_to(types, types, stream_types, False, __runtime__)
|
||||||
|
)
|
||||||
|
|
||||||
async def ExtractCategories(
|
async def ExtractCategories(
|
||||||
self,
|
self,
|
||||||
content: str,
|
content: str,
|
||||||
|
|
@ -184,6 +203,32 @@ class BamlStreamClient:
|
||||||
def __init__(self, options: DoNotUseDirectlyCallManager):
|
def __init__(self, options: DoNotUseDirectlyCallManager):
|
||||||
self.__options = options
|
self.__options = options
|
||||||
|
|
||||||
|
def AcreateStructuredOutput(
|
||||||
|
self,
|
||||||
|
content: str,
|
||||||
|
system_prompt: str,
|
||||||
|
user_prompt: str,
|
||||||
|
baml_options: BamlCallOptions = {},
|
||||||
|
) -> baml_py.BamlStream[stream_types.DynamicModel, types.DynamicModel]:
|
||||||
|
ctx, result = self.__options.merge_options(baml_options).create_async_stream(
|
||||||
|
function_name="AcreateStructuredOutput",
|
||||||
|
args={
|
||||||
|
"content": content,
|
||||||
|
"system_prompt": system_prompt,
|
||||||
|
"user_prompt": user_prompt,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return baml_py.BamlStream[stream_types.DynamicModel, types.DynamicModel](
|
||||||
|
result,
|
||||||
|
lambda x: typing.cast(
|
||||||
|
stream_types.DynamicModel, x.cast_to(types, types, stream_types, True, __runtime__)
|
||||||
|
),
|
||||||
|
lambda x: typing.cast(
|
||||||
|
types.DynamicModel, x.cast_to(types, types, stream_types, False, __runtime__)
|
||||||
|
),
|
||||||
|
ctx,
|
||||||
|
)
|
||||||
|
|
||||||
def ExtractCategories(
|
def ExtractCategories(
|
||||||
self,
|
self,
|
||||||
content: str,
|
content: str,
|
||||||
|
|
@ -334,6 +379,24 @@ class BamlHttpRequestClient:
|
||||||
def __init__(self, options: DoNotUseDirectlyCallManager):
|
def __init__(self, options: DoNotUseDirectlyCallManager):
|
||||||
self.__options = options
|
self.__options = options
|
||||||
|
|
||||||
|
async def AcreateStructuredOutput(
|
||||||
|
self,
|
||||||
|
content: str,
|
||||||
|
system_prompt: str,
|
||||||
|
user_prompt: str,
|
||||||
|
baml_options: BamlCallOptions = {},
|
||||||
|
) -> baml_py.baml_py.HTTPRequest:
|
||||||
|
result = await self.__options.merge_options(baml_options).create_http_request_async(
|
||||||
|
function_name="AcreateStructuredOutput",
|
||||||
|
args={
|
||||||
|
"content": content,
|
||||||
|
"system_prompt": system_prompt,
|
||||||
|
"user_prompt": user_prompt,
|
||||||
|
},
|
||||||
|
mode="request",
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
async def ExtractCategories(
|
async def ExtractCategories(
|
||||||
self,
|
self,
|
||||||
content: str,
|
content: str,
|
||||||
|
|
@ -435,6 +498,24 @@ class BamlHttpStreamRequestClient:
|
||||||
def __init__(self, options: DoNotUseDirectlyCallManager):
|
def __init__(self, options: DoNotUseDirectlyCallManager):
|
||||||
self.__options = options
|
self.__options = options
|
||||||
|
|
||||||
|
async def AcreateStructuredOutput(
|
||||||
|
self,
|
||||||
|
content: str,
|
||||||
|
system_prompt: str,
|
||||||
|
user_prompt: str,
|
||||||
|
baml_options: BamlCallOptions = {},
|
||||||
|
) -> baml_py.baml_py.HTTPRequest:
|
||||||
|
result = await self.__options.merge_options(baml_options).create_http_request_async(
|
||||||
|
function_name="AcreateStructuredOutput",
|
||||||
|
args={
|
||||||
|
"content": content,
|
||||||
|
"system_prompt": system_prompt,
|
||||||
|
"user_prompt": user_prompt,
|
||||||
|
},
|
||||||
|
mode="stream",
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
async def ExtractCategories(
|
async def ExtractCategories(
|
||||||
self,
|
self,
|
||||||
content: str,
|
content: str,
|
||||||
|
|
|
||||||
File diff suppressed because one or more lines are too long
|
|
@ -23,6 +23,16 @@ class LlmResponseParser:
|
||||||
def __init__(self, options: DoNotUseDirectlyCallManager):
|
def __init__(self, options: DoNotUseDirectlyCallManager):
|
||||||
self.__options = options
|
self.__options = options
|
||||||
|
|
||||||
|
def AcreateStructuredOutput(
|
||||||
|
self,
|
||||||
|
llm_response: str,
|
||||||
|
baml_options: BamlCallOptions = {},
|
||||||
|
) -> types.DynamicModel:
|
||||||
|
result = self.__options.merge_options(baml_options).parse_response(
|
||||||
|
function_name="AcreateStructuredOutput", llm_response=llm_response, mode="request"
|
||||||
|
)
|
||||||
|
return typing.cast(types.DynamicModel, result)
|
||||||
|
|
||||||
def ExtractCategories(
|
def ExtractCategories(
|
||||||
self,
|
self,
|
||||||
llm_response: str,
|
llm_response: str,
|
||||||
|
|
@ -80,6 +90,16 @@ class LlmStreamParser:
|
||||||
def __init__(self, options: DoNotUseDirectlyCallManager):
|
def __init__(self, options: DoNotUseDirectlyCallManager):
|
||||||
self.__options = options
|
self.__options = options
|
||||||
|
|
||||||
|
def AcreateStructuredOutput(
|
||||||
|
self,
|
||||||
|
llm_response: str,
|
||||||
|
baml_options: BamlCallOptions = {},
|
||||||
|
) -> stream_types.DynamicModel:
|
||||||
|
result = self.__options.merge_options(baml_options).parse_response(
|
||||||
|
function_name="AcreateStructuredOutput", llm_response=llm_response, mode="stream"
|
||||||
|
)
|
||||||
|
return typing.cast(stream_types.DynamicModel, result)
|
||||||
|
|
||||||
def ExtractCategories(
|
def ExtractCategories(
|
||||||
self,
|
self,
|
||||||
llm_response: str,
|
llm_response: str,
|
||||||
|
|
|
||||||
|
|
@ -27,7 +27,7 @@ class StreamState(BaseModel, typing.Generic[StreamStateValueT]):
|
||||||
|
|
||||||
|
|
||||||
# #########################################################################
|
# #########################################################################
|
||||||
# Generated classes (17)
|
# Generated classes (18)
|
||||||
# #########################################################################
|
# #########################################################################
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -50,6 +50,11 @@ class DynamicKnowledgeGraph(BaseModel):
|
||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra="allow")
|
||||||
|
|
||||||
|
|
||||||
|
class DynamicModel(BaseModel):
|
||||||
|
model_config = ConfigDict(extra="allow")
|
||||||
|
test: typing.Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class Edge(BaseModel):
|
class Edge(BaseModel):
|
||||||
# doc string for edge
|
# doc string for edge
|
||||||
# doc string for source_node_id
|
# doc string for source_node_id
|
||||||
|
|
|
||||||
|
|
@ -89,6 +89,25 @@ class BamlSyncClient:
|
||||||
def parse_stream(self):
|
def parse_stream(self):
|
||||||
return self.__llm_stream_parser
|
return self.__llm_stream_parser
|
||||||
|
|
||||||
|
def AcreateStructuredOutput(
|
||||||
|
self,
|
||||||
|
content: str,
|
||||||
|
system_prompt: str,
|
||||||
|
user_prompt: str,
|
||||||
|
baml_options: BamlCallOptions = {},
|
||||||
|
) -> types.DynamicModel:
|
||||||
|
result = self.__options.merge_options(baml_options).call_function_sync(
|
||||||
|
function_name="AcreateStructuredOutput",
|
||||||
|
args={
|
||||||
|
"content": content,
|
||||||
|
"system_prompt": system_prompt,
|
||||||
|
"user_prompt": user_prompt,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return typing.cast(
|
||||||
|
types.DynamicModel, result.cast_to(types, types, stream_types, False, __runtime__)
|
||||||
|
)
|
||||||
|
|
||||||
def ExtractCategories(
|
def ExtractCategories(
|
||||||
self,
|
self,
|
||||||
content: str,
|
content: str,
|
||||||
|
|
@ -197,6 +216,32 @@ class BamlStreamClient:
|
||||||
def __init__(self, options: DoNotUseDirectlyCallManager):
|
def __init__(self, options: DoNotUseDirectlyCallManager):
|
||||||
self.__options = options
|
self.__options = options
|
||||||
|
|
||||||
|
def AcreateStructuredOutput(
|
||||||
|
self,
|
||||||
|
content: str,
|
||||||
|
system_prompt: str,
|
||||||
|
user_prompt: str,
|
||||||
|
baml_options: BamlCallOptions = {},
|
||||||
|
) -> baml_py.BamlSyncStream[stream_types.DynamicModel, types.DynamicModel]:
|
||||||
|
ctx, result = self.__options.merge_options(baml_options).create_sync_stream(
|
||||||
|
function_name="AcreateStructuredOutput",
|
||||||
|
args={
|
||||||
|
"content": content,
|
||||||
|
"system_prompt": system_prompt,
|
||||||
|
"user_prompt": user_prompt,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return baml_py.BamlSyncStream[stream_types.DynamicModel, types.DynamicModel](
|
||||||
|
result,
|
||||||
|
lambda x: typing.cast(
|
||||||
|
stream_types.DynamicModel, x.cast_to(types, types, stream_types, True, __runtime__)
|
||||||
|
),
|
||||||
|
lambda x: typing.cast(
|
||||||
|
types.DynamicModel, x.cast_to(types, types, stream_types, False, __runtime__)
|
||||||
|
),
|
||||||
|
ctx,
|
||||||
|
)
|
||||||
|
|
||||||
def ExtractCategories(
|
def ExtractCategories(
|
||||||
self,
|
self,
|
||||||
content: str,
|
content: str,
|
||||||
|
|
@ -351,6 +396,24 @@ class BamlHttpRequestClient:
|
||||||
def __init__(self, options: DoNotUseDirectlyCallManager):
|
def __init__(self, options: DoNotUseDirectlyCallManager):
|
||||||
self.__options = options
|
self.__options = options
|
||||||
|
|
||||||
|
def AcreateStructuredOutput(
|
||||||
|
self,
|
||||||
|
content: str,
|
||||||
|
system_prompt: str,
|
||||||
|
user_prompt: str,
|
||||||
|
baml_options: BamlCallOptions = {},
|
||||||
|
) -> baml_py.baml_py.HTTPRequest:
|
||||||
|
result = self.__options.merge_options(baml_options).create_http_request_sync(
|
||||||
|
function_name="AcreateStructuredOutput",
|
||||||
|
args={
|
||||||
|
"content": content,
|
||||||
|
"system_prompt": system_prompt,
|
||||||
|
"user_prompt": user_prompt,
|
||||||
|
},
|
||||||
|
mode="request",
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
def ExtractCategories(
|
def ExtractCategories(
|
||||||
self,
|
self,
|
||||||
content: str,
|
content: str,
|
||||||
|
|
@ -452,6 +515,24 @@ class BamlHttpStreamRequestClient:
|
||||||
def __init__(self, options: DoNotUseDirectlyCallManager):
|
def __init__(self, options: DoNotUseDirectlyCallManager):
|
||||||
self.__options = options
|
self.__options = options
|
||||||
|
|
||||||
|
def AcreateStructuredOutput(
|
||||||
|
self,
|
||||||
|
content: str,
|
||||||
|
system_prompt: str,
|
||||||
|
user_prompt: str,
|
||||||
|
baml_options: BamlCallOptions = {},
|
||||||
|
) -> baml_py.baml_py.HTTPRequest:
|
||||||
|
result = self.__options.merge_options(baml_options).create_http_request_sync(
|
||||||
|
function_name="AcreateStructuredOutput",
|
||||||
|
args={
|
||||||
|
"content": content,
|
||||||
|
"system_prompt": system_prompt,
|
||||||
|
"user_prompt": user_prompt,
|
||||||
|
},
|
||||||
|
mode="stream",
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
def ExtractCategories(
|
def ExtractCategories(
|
||||||
self,
|
self,
|
||||||
content: str,
|
content: str,
|
||||||
|
|
|
||||||
|
|
@ -25,6 +25,7 @@ class TypeBuilder(type_builder.TypeBuilder):
|
||||||
"ContentLabel",
|
"ContentLabel",
|
||||||
"DefaultContentPrediction",
|
"DefaultContentPrediction",
|
||||||
"DynamicKnowledgeGraph",
|
"DynamicKnowledgeGraph",
|
||||||
|
"DynamicModel",
|
||||||
"Edge",
|
"Edge",
|
||||||
"ImageContent",
|
"ImageContent",
|
||||||
"KnowledgeGraph",
|
"KnowledgeGraph",
|
||||||
|
|
@ -49,7 +50,7 @@ class TypeBuilder(type_builder.TypeBuilder):
|
||||||
# #########################################################################
|
# #########################################################################
|
||||||
|
|
||||||
# #########################################################################
|
# #########################################################################
|
||||||
# Generated classes 17
|
# Generated classes 18
|
||||||
# #########################################################################
|
# #########################################################################
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|
@ -68,6 +69,10 @@ class TypeBuilder(type_builder.TypeBuilder):
|
||||||
def DynamicKnowledgeGraph(self) -> "DynamicKnowledgeGraphBuilder":
|
def DynamicKnowledgeGraph(self) -> "DynamicKnowledgeGraphBuilder":
|
||||||
return DynamicKnowledgeGraphBuilder(self)
|
return DynamicKnowledgeGraphBuilder(self)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def DynamicModel(self) -> "DynamicModelBuilder":
|
||||||
|
return DynamicModelBuilder(self)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def Edge(self) -> "EdgeViewer":
|
def Edge(self) -> "EdgeViewer":
|
||||||
return EdgeViewer(self)
|
return EdgeViewer(self)
|
||||||
|
|
@ -127,7 +132,7 @@ class TypeBuilder(type_builder.TypeBuilder):
|
||||||
|
|
||||||
|
|
||||||
# #########################################################################
|
# #########################################################################
|
||||||
# Generated classes 17
|
# Generated classes 18
|
||||||
# #########################################################################
|
# #########################################################################
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -305,6 +310,53 @@ class DynamicKnowledgeGraphProperties:
|
||||||
return self.__bldr.property(name)
|
return self.__bldr.property(name)
|
||||||
|
|
||||||
|
|
||||||
|
class DynamicModelAst:
|
||||||
|
def __init__(self, tb: type_builder.TypeBuilder):
|
||||||
|
_tb = tb._tb # type: ignore (we know how to use this private attribute)
|
||||||
|
self._bldr = _tb.class_("DynamicModel")
|
||||||
|
self._properties: typing.Set[str] = set(
|
||||||
|
[
|
||||||
|
"test",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self._props = DynamicModelProperties(self._bldr, self._properties)
|
||||||
|
|
||||||
|
def type(self) -> baml_py.FieldType:
|
||||||
|
return self._bldr.field()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def props(self) -> "DynamicModelProperties":
|
||||||
|
return self._props
|
||||||
|
|
||||||
|
|
||||||
|
class DynamicModelBuilder(DynamicModelAst):
|
||||||
|
def __init__(self, tb: type_builder.TypeBuilder):
|
||||||
|
super().__init__(tb)
|
||||||
|
|
||||||
|
def add_property(self, name: str, type: baml_py.FieldType) -> baml_py.ClassPropertyBuilder:
|
||||||
|
if name in self._properties:
|
||||||
|
raise ValueError(f"Property {name} already exists.")
|
||||||
|
return self._bldr.property(name).type(type)
|
||||||
|
|
||||||
|
def list_properties(self) -> typing.List[typing.Tuple[str, baml_py.ClassPropertyBuilder]]:
|
||||||
|
return [(name, self._bldr.property(name)) for name in self._properties]
|
||||||
|
|
||||||
|
|
||||||
|
class DynamicModelProperties:
|
||||||
|
def __init__(self, bldr: baml_py.ClassBuilder, properties: typing.Set[str]):
|
||||||
|
self.__bldr = bldr
|
||||||
|
self.__properties = properties # type: ignore (we know how to use this private attribute) # noqa: F821
|
||||||
|
|
||||||
|
def __getattr__(self, name: str) -> baml_py.ClassPropertyBuilder:
|
||||||
|
if name not in self.__properties:
|
||||||
|
raise AttributeError(f"Property {name} not found.")
|
||||||
|
return self.__bldr.property(name)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def test(self) -> baml_py.ClassPropertyBuilder:
|
||||||
|
return self.__bldr.property("test")
|
||||||
|
|
||||||
|
|
||||||
class EdgeAst:
|
class EdgeAst:
|
||||||
def __init__(self, tb: type_builder.TypeBuilder):
|
def __init__(self, tb: type_builder.TypeBuilder):
|
||||||
_tb = tb._tb # type: ignore (we know how to use this private attribute)
|
_tb = tb._tb # type: ignore (we know how to use this private attribute)
|
||||||
|
|
|
||||||
|
|
@ -23,6 +23,8 @@ type_map = {
|
||||||
"stream_types.DefaultContentPrediction": stream_types.DefaultContentPrediction,
|
"stream_types.DefaultContentPrediction": stream_types.DefaultContentPrediction,
|
||||||
"types.DynamicKnowledgeGraph": types.DynamicKnowledgeGraph,
|
"types.DynamicKnowledgeGraph": types.DynamicKnowledgeGraph,
|
||||||
"stream_types.DynamicKnowledgeGraph": stream_types.DynamicKnowledgeGraph,
|
"stream_types.DynamicKnowledgeGraph": stream_types.DynamicKnowledgeGraph,
|
||||||
|
"types.DynamicModel": types.DynamicModel,
|
||||||
|
"stream_types.DynamicModel": stream_types.DynamicModel,
|
||||||
"types.Edge": types.Edge,
|
"types.Edge": types.Edge,
|
||||||
"stream_types.Edge": stream_types.Edge,
|
"stream_types.Edge": stream_types.Edge,
|
||||||
"types.ImageContent": types.ImageContent,
|
"types.ImageContent": types.ImageContent,
|
||||||
|
|
|
||||||
|
|
@ -48,7 +48,7 @@ def all_succeeded(checks: typing.Dict[CheckName, Check]) -> bool:
|
||||||
# #########################################################################
|
# #########################################################################
|
||||||
|
|
||||||
# #########################################################################
|
# #########################################################################
|
||||||
# Generated classes (17)
|
# Generated classes (18)
|
||||||
# #########################################################################
|
# #########################################################################
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -79,6 +79,11 @@ class DynamicKnowledgeGraph(BaseModel):
|
||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra="allow")
|
||||||
|
|
||||||
|
|
||||||
|
class DynamicModel(BaseModel):
|
||||||
|
model_config = ConfigDict(extra="allow")
|
||||||
|
test: str
|
||||||
|
|
||||||
|
|
||||||
class Edge(BaseModel):
|
class Edge(BaseModel):
|
||||||
# doc string for edge
|
# doc string for edge
|
||||||
# doc string for source_node_id
|
# doc string for source_node_id
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,20 @@
|
||||||
|
class DynamicModel {
|
||||||
|
test string
|
||||||
|
@@dynamic
|
||||||
|
}
|
||||||
|
|
||||||
|
function AcreateStructuredOutput(
|
||||||
|
content: string,
|
||||||
|
system_prompt: string,
|
||||||
|
user_prompt: string,
|
||||||
|
) -> DynamicModel {
|
||||||
|
client OpenAI
|
||||||
|
|
||||||
|
prompt #"
|
||||||
|
{{ system_prompt }}
|
||||||
|
{{ ctx.output_format }}
|
||||||
|
{{ _.role('user') }}
|
||||||
|
{{ user_prompt }}
|
||||||
|
{{ content }}
|
||||||
|
"#
|
||||||
|
}
|
||||||
|
|
@ -1,2 +1,3 @@
|
||||||
from .knowledge_graph.extract_content_graph import extract_content_graph
|
from .knowledge_graph.extract_content_graph import extract_content_graph
|
||||||
from .extract_summary import extract_summary, extract_code_summary
|
from .extract_summary import extract_summary, extract_code_summary
|
||||||
|
from .acreate_structured_output import acreate_structured_output
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,46 @@
|
||||||
|
import os
|
||||||
|
import asyncio
|
||||||
|
from typing import Type
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
from cognee.shared.data_models import SummarizedCode
|
||||||
|
from cognee.infrastructure.llm.structured_output_framework.baml.baml_client.async_client import b
|
||||||
|
from cognee.infrastructure.llm.config import get_llm_config
|
||||||
|
|
||||||
|
|
||||||
|
logger = get_logger("extract_summary_baml")
|
||||||
|
|
||||||
|
|
||||||
|
async def acreate_structured_output(
|
||||||
|
content: str, system_prompt: str, user_prompt: str, response_model: Type[BaseModel]
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Extract summary using BAML framework.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: The content to summarize
|
||||||
|
response_model: The Pydantic model type for the response
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
BaseModel: The summarized content in the specified format
|
||||||
|
"""
|
||||||
|
config = get_llm_config()
|
||||||
|
|
||||||
|
# Use BAML's SummarizeContent function
|
||||||
|
result = await b.AcreateStructuredOutput(
|
||||||
|
content=content,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
user_prompt=user_prompt,
|
||||||
|
baml_options={"client_registry": config.baml_registry},
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
try:
|
||||||
|
loop.run_until_complete(acreate_structured_output("TEST", SummarizedCode))
|
||||||
|
finally:
|
||||||
|
loop.run_until_complete(loop.shutdown_asyncgens())
|
||||||
|
|
@ -6,13 +6,13 @@ generator target {
|
||||||
output_type "python/pydantic"
|
output_type "python/pydantic"
|
||||||
|
|
||||||
// Where the generated code will be saved (relative to baml_src/)
|
// Where the generated code will be saved (relative to baml_src/)
|
||||||
output_dir "../baml/"
|
output_dir "../"
|
||||||
|
|
||||||
// The version of the BAML package you have installed (e.g. same version as your baml-py or @boundaryml/baml).
|
// The version of the BAML package you have installed (e.g. same version as your baml-py or @boundaryml/baml).
|
||||||
// The BAML VSCode extension version should also match this version.
|
// The BAML VSCode extension version should also match this version.
|
||||||
version "0.201.0"
|
version "0.206.0"
|
||||||
|
|
||||||
// Valid values: "sync", "async"
|
// Valid values: "sync", "async"
|
||||||
// This controls what `b.FunctionName()` will be (sync or async).
|
// This controls what `b.FunctionName()` will be (sync or async).
|
||||||
default_client_mode sync
|
default_client_mode async
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue