feat: Save baml progress
This commit is contained in:
parent
f7eb482ba5
commit
befb8ac237
12 changed files with 221 additions and 240 deletions
|
|
@ -26,7 +26,7 @@ class LLMGateway:
|
||||||
)
|
)
|
||||||
|
|
||||||
return acreate_structured_output(
|
return acreate_structured_output(
|
||||||
content=text_input,
|
text_input=text_input,
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
response_model=response_model,
|
response_model=response_model,
|
||||||
)
|
)
|
||||||
|
|
@ -142,19 +142,19 @@ class LLMGateway:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def extract_summary(content: str, response_model: Type[BaseModel]) -> Coroutine:
|
def extract_summary(content: str, response_model: Type[BaseModel]) -> Coroutine:
|
||||||
llm_config = get_llm_config()
|
# llm_config = get_llm_config()
|
||||||
if llm_config.structured_output_framework.upper() == "BAML":
|
# if llm_config.structured_output_framework.upper() == "BAML":
|
||||||
from cognee.infrastructure.llm.structured_output_framework.baml.baml_src.extraction import (
|
# from cognee.infrastructure.llm.structured_output_framework.baml.baml_src.extraction import (
|
||||||
extract_summary,
|
# extract_summary,
|
||||||
)
|
# )
|
||||||
|
#
|
||||||
|
# return extract_summary(content=content, response_model=response_model)
|
||||||
|
# else:
|
||||||
|
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.extraction import (
|
||||||
|
extract_summary,
|
||||||
|
)
|
||||||
|
|
||||||
return extract_summary(content=content, response_model=response_model)
|
return extract_summary(content=content, response_model=response_model)
|
||||||
else:
|
|
||||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.extraction import (
|
|
||||||
extract_summary,
|
|
||||||
)
|
|
||||||
|
|
||||||
return extract_summary(content=content, response_model=response_model)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def extract_event_graph(content: str, response_model: Type[BaseModel]) -> Coroutine:
|
def extract_event_graph(content: str, response_model: Type[BaseModel]) -> Coroutine:
|
||||||
|
|
|
||||||
|
|
@ -81,19 +81,15 @@ class BamlAsyncClient:
|
||||||
|
|
||||||
async def AcreateStructuredOutput(
|
async def AcreateStructuredOutput(
|
||||||
self,
|
self,
|
||||||
content: str,
|
text_input: str,
|
||||||
system_prompt: str,
|
system_prompt: str,
|
||||||
user_prompt: str,
|
|
||||||
baml_options: BamlCallOptions = {},
|
baml_options: BamlCallOptions = {},
|
||||||
) -> types.DynamicOutputModel:
|
) -> types.ResponseModel:
|
||||||
# Check if on_tick is provided
|
# Check if on_tick is provided
|
||||||
if "on_tick" in baml_options:
|
if "on_tick" in baml_options:
|
||||||
# Use streaming internally when on_tick is provided
|
# Use streaming internally when on_tick is provided
|
||||||
stream = self.stream.AcreateStructuredOutput(
|
stream = self.stream.AcreateStructuredOutput(
|
||||||
content=content,
|
text_input=text_input, system_prompt=system_prompt, baml_options=baml_options
|
||||||
system_prompt=system_prompt,
|
|
||||||
user_prompt=user_prompt,
|
|
||||||
baml_options=baml_options,
|
|
||||||
)
|
)
|
||||||
return await stream.get_final_response()
|
return await stream.get_final_response()
|
||||||
else:
|
else:
|
||||||
|
|
@ -101,14 +97,12 @@ class BamlAsyncClient:
|
||||||
result = await self.__options.merge_options(baml_options).call_function_async(
|
result = await self.__options.merge_options(baml_options).call_function_async(
|
||||||
function_name="AcreateStructuredOutput",
|
function_name="AcreateStructuredOutput",
|
||||||
args={
|
args={
|
||||||
"content": content,
|
"text_input": text_input,
|
||||||
"system_prompt": system_prompt,
|
"system_prompt": system_prompt,
|
||||||
"user_prompt": user_prompt,
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
return typing.cast(
|
return typing.cast(
|
||||||
types.DynamicOutputModel,
|
types.ResponseModel, result.cast_to(types, types, stream_types, False, __runtime__)
|
||||||
result.cast_to(types, types, stream_types, False, __runtime__),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
async def ExtractCategories(
|
async def ExtractCategories(
|
||||||
|
|
@ -267,27 +261,24 @@ class BamlStreamClient:
|
||||||
|
|
||||||
def AcreateStructuredOutput(
|
def AcreateStructuredOutput(
|
||||||
self,
|
self,
|
||||||
content: str,
|
text_input: str,
|
||||||
system_prompt: str,
|
system_prompt: str,
|
||||||
user_prompt: str,
|
|
||||||
baml_options: BamlCallOptions = {},
|
baml_options: BamlCallOptions = {},
|
||||||
) -> baml_py.BamlStream[stream_types.DynamicOutputModel, types.DynamicOutputModel]:
|
) -> baml_py.BamlStream[stream_types.ResponseModel, types.ResponseModel]:
|
||||||
ctx, result = self.__options.merge_options(baml_options).create_async_stream(
|
ctx, result = self.__options.merge_options(baml_options).create_async_stream(
|
||||||
function_name="AcreateStructuredOutput",
|
function_name="AcreateStructuredOutput",
|
||||||
args={
|
args={
|
||||||
"content": content,
|
"text_input": text_input,
|
||||||
"system_prompt": system_prompt,
|
"system_prompt": system_prompt,
|
||||||
"user_prompt": user_prompt,
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
return baml_py.BamlStream[stream_types.DynamicOutputModel, types.DynamicOutputModel](
|
return baml_py.BamlStream[stream_types.ResponseModel, types.ResponseModel](
|
||||||
result,
|
result,
|
||||||
lambda x: typing.cast(
|
lambda x: typing.cast(
|
||||||
stream_types.DynamicOutputModel,
|
stream_types.ResponseModel, x.cast_to(types, types, stream_types, True, __runtime__)
|
||||||
x.cast_to(types, types, stream_types, True, __runtime__),
|
|
||||||
),
|
),
|
||||||
lambda x: typing.cast(
|
lambda x: typing.cast(
|
||||||
types.DynamicOutputModel, x.cast_to(types, types, stream_types, False, __runtime__)
|
types.ResponseModel, x.cast_to(types, types, stream_types, False, __runtime__)
|
||||||
),
|
),
|
||||||
ctx,
|
ctx,
|
||||||
)
|
)
|
||||||
|
|
@ -444,17 +435,15 @@ class BamlHttpRequestClient:
|
||||||
|
|
||||||
async def AcreateStructuredOutput(
|
async def AcreateStructuredOutput(
|
||||||
self,
|
self,
|
||||||
content: str,
|
text_input: str,
|
||||||
system_prompt: str,
|
system_prompt: str,
|
||||||
user_prompt: str,
|
|
||||||
baml_options: BamlCallOptions = {},
|
baml_options: BamlCallOptions = {},
|
||||||
) -> baml_py.baml_py.HTTPRequest:
|
) -> baml_py.baml_py.HTTPRequest:
|
||||||
result = await self.__options.merge_options(baml_options).create_http_request_async(
|
result = await self.__options.merge_options(baml_options).create_http_request_async(
|
||||||
function_name="AcreateStructuredOutput",
|
function_name="AcreateStructuredOutput",
|
||||||
args={
|
args={
|
||||||
"content": content,
|
"text_input": text_input,
|
||||||
"system_prompt": system_prompt,
|
"system_prompt": system_prompt,
|
||||||
"user_prompt": user_prompt,
|
|
||||||
},
|
},
|
||||||
mode="request",
|
mode="request",
|
||||||
)
|
)
|
||||||
|
|
@ -563,17 +552,15 @@ class BamlHttpStreamRequestClient:
|
||||||
|
|
||||||
async def AcreateStructuredOutput(
|
async def AcreateStructuredOutput(
|
||||||
self,
|
self,
|
||||||
content: str,
|
text_input: str,
|
||||||
system_prompt: str,
|
system_prompt: str,
|
||||||
user_prompt: str,
|
|
||||||
baml_options: BamlCallOptions = {},
|
baml_options: BamlCallOptions = {},
|
||||||
) -> baml_py.baml_py.HTTPRequest:
|
) -> baml_py.baml_py.HTTPRequest:
|
||||||
result = await self.__options.merge_options(baml_options).create_http_request_async(
|
result = await self.__options.merge_options(baml_options).create_http_request_async(
|
||||||
function_name="AcreateStructuredOutput",
|
function_name="AcreateStructuredOutput",
|
||||||
args={
|
args={
|
||||||
"content": content,
|
"text_input": text_input,
|
||||||
"system_prompt": system_prompt,
|
"system_prompt": system_prompt,
|
||||||
"user_prompt": user_prompt,
|
|
||||||
},
|
},
|
||||||
mode="stream",
|
mode="stream",
|
||||||
)
|
)
|
||||||
|
|
|
||||||
File diff suppressed because one or more lines are too long
|
|
@ -27,11 +27,11 @@ class LlmResponseParser:
|
||||||
self,
|
self,
|
||||||
llm_response: str,
|
llm_response: str,
|
||||||
baml_options: BamlCallOptions = {},
|
baml_options: BamlCallOptions = {},
|
||||||
) -> types.DynamicOutputModel:
|
) -> types.ResponseModel:
|
||||||
result = self.__options.merge_options(baml_options).parse_response(
|
result = self.__options.merge_options(baml_options).parse_response(
|
||||||
function_name="AcreateStructuredOutput", llm_response=llm_response, mode="request"
|
function_name="AcreateStructuredOutput", llm_response=llm_response, mode="request"
|
||||||
)
|
)
|
||||||
return typing.cast(types.DynamicOutputModel, result)
|
return typing.cast(types.ResponseModel, result)
|
||||||
|
|
||||||
def ExtractCategories(
|
def ExtractCategories(
|
||||||
self,
|
self,
|
||||||
|
|
@ -94,11 +94,11 @@ class LlmStreamParser:
|
||||||
self,
|
self,
|
||||||
llm_response: str,
|
llm_response: str,
|
||||||
baml_options: BamlCallOptions = {},
|
baml_options: BamlCallOptions = {},
|
||||||
) -> stream_types.DynamicOutputModel:
|
) -> stream_types.ResponseModel:
|
||||||
result = self.__options.merge_options(baml_options).parse_response(
|
result = self.__options.merge_options(baml_options).parse_response(
|
||||||
function_name="AcreateStructuredOutput", llm_response=llm_response, mode="stream"
|
function_name="AcreateStructuredOutput", llm_response=llm_response, mode="stream"
|
||||||
)
|
)
|
||||||
return typing.cast(stream_types.DynamicOutputModel, result)
|
return typing.cast(stream_types.ResponseModel, result)
|
||||||
|
|
||||||
def ExtractCategories(
|
def ExtractCategories(
|
||||||
self,
|
self,
|
||||||
|
|
|
||||||
|
|
@ -27,7 +27,7 @@ class StreamState(BaseModel, typing.Generic[StreamStateValueT]):
|
||||||
|
|
||||||
|
|
||||||
# #########################################################################
|
# #########################################################################
|
||||||
# Generated classes (19)
|
# Generated classes (18)
|
||||||
# #########################################################################
|
# #########################################################################
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -46,20 +46,10 @@ class DefaultContentPrediction(BaseModel):
|
||||||
label: typing.Optional["ContentLabel"] = None
|
label: typing.Optional["ContentLabel"] = None
|
||||||
|
|
||||||
|
|
||||||
class DynamicInputModel(BaseModel):
|
|
||||||
model_config = ConfigDict(extra="allow")
|
|
||||||
test: typing.Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
class DynamicKnowledgeGraph(BaseModel):
|
class DynamicKnowledgeGraph(BaseModel):
|
||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra="allow")
|
||||||
|
|
||||||
|
|
||||||
class DynamicOutputModel(BaseModel):
|
|
||||||
model_config = ConfigDict(extra="allow")
|
|
||||||
test: typing.Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
class Edge(BaseModel):
|
class Edge(BaseModel):
|
||||||
# doc string for edge
|
# doc string for edge
|
||||||
# doc string for source_node_id
|
# doc string for source_node_id
|
||||||
|
|
@ -102,6 +92,10 @@ class ProceduralContent(BaseModel):
|
||||||
subclass: typing.List[str]
|
subclass: typing.List[str]
|
||||||
|
|
||||||
|
|
||||||
|
class ResponseModel(BaseModel):
|
||||||
|
model_config = ConfigDict(extra="allow")
|
||||||
|
|
||||||
|
|
||||||
class SummarizedClass(BaseModel):
|
class SummarizedClass(BaseModel):
|
||||||
name: typing.Optional[str] = None
|
name: typing.Optional[str] = None
|
||||||
description: typing.Optional[str] = None
|
description: typing.Optional[str] = None
|
||||||
|
|
|
||||||
|
|
@ -94,18 +94,14 @@ class BamlSyncClient:
|
||||||
|
|
||||||
def AcreateStructuredOutput(
|
def AcreateStructuredOutput(
|
||||||
self,
|
self,
|
||||||
content: str,
|
text_input: str,
|
||||||
system_prompt: str,
|
system_prompt: str,
|
||||||
user_prompt: str,
|
|
||||||
baml_options: BamlCallOptions = {},
|
baml_options: BamlCallOptions = {},
|
||||||
) -> types.DynamicOutputModel:
|
) -> types.ResponseModel:
|
||||||
# Check if on_tick is provided
|
# Check if on_tick is provided
|
||||||
if "on_tick" in baml_options:
|
if "on_tick" in baml_options:
|
||||||
stream = self.stream.AcreateStructuredOutput(
|
stream = self.stream.AcreateStructuredOutput(
|
||||||
content=content,
|
text_input=text_input, system_prompt=system_prompt, baml_options=baml_options
|
||||||
system_prompt=system_prompt,
|
|
||||||
user_prompt=user_prompt,
|
|
||||||
baml_options=baml_options,
|
|
||||||
)
|
)
|
||||||
return stream.get_final_response()
|
return stream.get_final_response()
|
||||||
else:
|
else:
|
||||||
|
|
@ -113,14 +109,12 @@ class BamlSyncClient:
|
||||||
result = self.__options.merge_options(baml_options).call_function_sync(
|
result = self.__options.merge_options(baml_options).call_function_sync(
|
||||||
function_name="AcreateStructuredOutput",
|
function_name="AcreateStructuredOutput",
|
||||||
args={
|
args={
|
||||||
"content": content,
|
"text_input": text_input,
|
||||||
"system_prompt": system_prompt,
|
"system_prompt": system_prompt,
|
||||||
"user_prompt": user_prompt,
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
return typing.cast(
|
return typing.cast(
|
||||||
types.DynamicOutputModel,
|
types.ResponseModel, result.cast_to(types, types, stream_types, False, __runtime__)
|
||||||
result.cast_to(types, types, stream_types, False, __runtime__),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def ExtractCategories(
|
def ExtractCategories(
|
||||||
|
|
@ -274,27 +268,24 @@ class BamlStreamClient:
|
||||||
|
|
||||||
def AcreateStructuredOutput(
|
def AcreateStructuredOutput(
|
||||||
self,
|
self,
|
||||||
content: str,
|
text_input: str,
|
||||||
system_prompt: str,
|
system_prompt: str,
|
||||||
user_prompt: str,
|
|
||||||
baml_options: BamlCallOptions = {},
|
baml_options: BamlCallOptions = {},
|
||||||
) -> baml_py.BamlSyncStream[stream_types.DynamicOutputModel, types.DynamicOutputModel]:
|
) -> baml_py.BamlSyncStream[stream_types.ResponseModel, types.ResponseModel]:
|
||||||
ctx, result = self.__options.merge_options(baml_options).create_sync_stream(
|
ctx, result = self.__options.merge_options(baml_options).create_sync_stream(
|
||||||
function_name="AcreateStructuredOutput",
|
function_name="AcreateStructuredOutput",
|
||||||
args={
|
args={
|
||||||
"content": content,
|
"text_input": text_input,
|
||||||
"system_prompt": system_prompt,
|
"system_prompt": system_prompt,
|
||||||
"user_prompt": user_prompt,
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
return baml_py.BamlSyncStream[stream_types.DynamicOutputModel, types.DynamicOutputModel](
|
return baml_py.BamlSyncStream[stream_types.ResponseModel, types.ResponseModel](
|
||||||
result,
|
result,
|
||||||
lambda x: typing.cast(
|
lambda x: typing.cast(
|
||||||
stream_types.DynamicOutputModel,
|
stream_types.ResponseModel, x.cast_to(types, types, stream_types, True, __runtime__)
|
||||||
x.cast_to(types, types, stream_types, True, __runtime__),
|
|
||||||
),
|
),
|
||||||
lambda x: typing.cast(
|
lambda x: typing.cast(
|
||||||
types.DynamicOutputModel, x.cast_to(types, types, stream_types, False, __runtime__)
|
types.ResponseModel, x.cast_to(types, types, stream_types, False, __runtime__)
|
||||||
),
|
),
|
||||||
ctx,
|
ctx,
|
||||||
)
|
)
|
||||||
|
|
@ -455,17 +446,15 @@ class BamlHttpRequestClient:
|
||||||
|
|
||||||
def AcreateStructuredOutput(
|
def AcreateStructuredOutput(
|
||||||
self,
|
self,
|
||||||
content: str,
|
text_input: str,
|
||||||
system_prompt: str,
|
system_prompt: str,
|
||||||
user_prompt: str,
|
|
||||||
baml_options: BamlCallOptions = {},
|
baml_options: BamlCallOptions = {},
|
||||||
) -> baml_py.baml_py.HTTPRequest:
|
) -> baml_py.baml_py.HTTPRequest:
|
||||||
result = self.__options.merge_options(baml_options).create_http_request_sync(
|
result = self.__options.merge_options(baml_options).create_http_request_sync(
|
||||||
function_name="AcreateStructuredOutput",
|
function_name="AcreateStructuredOutput",
|
||||||
args={
|
args={
|
||||||
"content": content,
|
"text_input": text_input,
|
||||||
"system_prompt": system_prompt,
|
"system_prompt": system_prompt,
|
||||||
"user_prompt": user_prompt,
|
|
||||||
},
|
},
|
||||||
mode="request",
|
mode="request",
|
||||||
)
|
)
|
||||||
|
|
@ -574,17 +563,15 @@ class BamlHttpStreamRequestClient:
|
||||||
|
|
||||||
def AcreateStructuredOutput(
|
def AcreateStructuredOutput(
|
||||||
self,
|
self,
|
||||||
content: str,
|
text_input: str,
|
||||||
system_prompt: str,
|
system_prompt: str,
|
||||||
user_prompt: str,
|
|
||||||
baml_options: BamlCallOptions = {},
|
baml_options: BamlCallOptions = {},
|
||||||
) -> baml_py.baml_py.HTTPRequest:
|
) -> baml_py.baml_py.HTTPRequest:
|
||||||
result = self.__options.merge_options(baml_options).create_http_request_sync(
|
result = self.__options.merge_options(baml_options).create_http_request_sync(
|
||||||
function_name="AcreateStructuredOutput",
|
function_name="AcreateStructuredOutput",
|
||||||
args={
|
args={
|
||||||
"content": content,
|
"text_input": text_input,
|
||||||
"system_prompt": system_prompt,
|
"system_prompt": system_prompt,
|
||||||
"user_prompt": user_prompt,
|
|
||||||
},
|
},
|
||||||
mode="stream",
|
mode="stream",
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -27,9 +27,7 @@ class TypeBuilder(type_builder.TypeBuilder):
|
||||||
"AudioContent",
|
"AudioContent",
|
||||||
"ContentLabel",
|
"ContentLabel",
|
||||||
"DefaultContentPrediction",
|
"DefaultContentPrediction",
|
||||||
"DynamicInputModel",
|
|
||||||
"DynamicKnowledgeGraph",
|
"DynamicKnowledgeGraph",
|
||||||
"DynamicOutputModel",
|
|
||||||
"Edge",
|
"Edge",
|
||||||
"ImageContent",
|
"ImageContent",
|
||||||
"KnowledgeGraph",
|
"KnowledgeGraph",
|
||||||
|
|
@ -37,6 +35,7 @@ class TypeBuilder(type_builder.TypeBuilder):
|
||||||
"MultimediaContent",
|
"MultimediaContent",
|
||||||
"Node",
|
"Node",
|
||||||
"ProceduralContent",
|
"ProceduralContent",
|
||||||
|
"ResponseModel",
|
||||||
"SummarizedClass",
|
"SummarizedClass",
|
||||||
"SummarizedCode",
|
"SummarizedCode",
|
||||||
"SummarizedContent",
|
"SummarizedContent",
|
||||||
|
|
@ -54,7 +53,7 @@ class TypeBuilder(type_builder.TypeBuilder):
|
||||||
# #########################################################################
|
# #########################################################################
|
||||||
|
|
||||||
# #########################################################################
|
# #########################################################################
|
||||||
# Generated classes 19
|
# Generated classes 18
|
||||||
# #########################################################################
|
# #########################################################################
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|
@ -69,18 +68,10 @@ class TypeBuilder(type_builder.TypeBuilder):
|
||||||
def DefaultContentPrediction(self) -> "DefaultContentPredictionViewer":
|
def DefaultContentPrediction(self) -> "DefaultContentPredictionViewer":
|
||||||
return DefaultContentPredictionViewer(self)
|
return DefaultContentPredictionViewer(self)
|
||||||
|
|
||||||
@property
|
|
||||||
def DynamicInputModel(self) -> "DynamicInputModelBuilder":
|
|
||||||
return DynamicInputModelBuilder(self)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def DynamicKnowledgeGraph(self) -> "DynamicKnowledgeGraphBuilder":
|
def DynamicKnowledgeGraph(self) -> "DynamicKnowledgeGraphBuilder":
|
||||||
return DynamicKnowledgeGraphBuilder(self)
|
return DynamicKnowledgeGraphBuilder(self)
|
||||||
|
|
||||||
@property
|
|
||||||
def DynamicOutputModel(self) -> "DynamicOutputModelBuilder":
|
|
||||||
return DynamicOutputModelBuilder(self)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def Edge(self) -> "EdgeViewer":
|
def Edge(self) -> "EdgeViewer":
|
||||||
return EdgeViewer(self)
|
return EdgeViewer(self)
|
||||||
|
|
@ -109,6 +100,10 @@ class TypeBuilder(type_builder.TypeBuilder):
|
||||||
def ProceduralContent(self) -> "ProceduralContentViewer":
|
def ProceduralContent(self) -> "ProceduralContentViewer":
|
||||||
return ProceduralContentViewer(self)
|
return ProceduralContentViewer(self)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def ResponseModel(self) -> "ResponseModelBuilder":
|
||||||
|
return ResponseModelBuilder(self)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def SummarizedClass(self) -> "SummarizedClassViewer":
|
def SummarizedClass(self) -> "SummarizedClassViewer":
|
||||||
return SummarizedClassViewer(self)
|
return SummarizedClassViewer(self)
|
||||||
|
|
@ -140,7 +135,7 @@ class TypeBuilder(type_builder.TypeBuilder):
|
||||||
|
|
||||||
|
|
||||||
# #########################################################################
|
# #########################################################################
|
||||||
# Generated classes 19
|
# Generated classes 18
|
||||||
# #########################################################################
|
# #########################################################################
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -279,59 +274,6 @@ class DefaultContentPredictionProperties:
|
||||||
return type_builder.ClassPropertyViewer(self.__bldr.property("label"))
|
return type_builder.ClassPropertyViewer(self.__bldr.property("label"))
|
||||||
|
|
||||||
|
|
||||||
class DynamicInputModelAst:
|
|
||||||
def __init__(self, tb: type_builder.TypeBuilder):
|
|
||||||
_tb = tb._tb # type: ignore (we know how to use this private attribute)
|
|
||||||
self._bldr = _tb.class_("DynamicInputModel")
|
|
||||||
self._properties: typing.Set[str] = set(
|
|
||||||
[
|
|
||||||
"test",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
self._props = DynamicInputModelProperties(self._bldr, self._properties)
|
|
||||||
|
|
||||||
def type(self) -> baml_py.FieldType:
|
|
||||||
return self._bldr.field()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def props(self) -> "DynamicInputModelProperties":
|
|
||||||
return self._props
|
|
||||||
|
|
||||||
|
|
||||||
class DynamicInputModelBuilder(DynamicInputModelAst):
|
|
||||||
def __init__(self, tb: type_builder.TypeBuilder):
|
|
||||||
super().__init__(tb)
|
|
||||||
|
|
||||||
def add_property(self, name: str, type: baml_py.FieldType) -> baml_py.ClassPropertyBuilder:
|
|
||||||
if name in self._properties:
|
|
||||||
raise ValueError(f"Property {name} already exists.")
|
|
||||||
return self._bldr.property(name).type(type)
|
|
||||||
|
|
||||||
def list_properties(self) -> typing.List[typing.Tuple[str, baml_py.ClassPropertyBuilder]]:
|
|
||||||
return self._bldr.list_properties()
|
|
||||||
|
|
||||||
def remove_property(self, name: str) -> None:
|
|
||||||
self._bldr.remove_property(name)
|
|
||||||
|
|
||||||
def reset(self) -> None:
|
|
||||||
self._bldr.reset()
|
|
||||||
|
|
||||||
|
|
||||||
class DynamicInputModelProperties:
|
|
||||||
def __init__(self, bldr: baml_py.ClassBuilder, properties: typing.Set[str]):
|
|
||||||
self.__bldr = bldr
|
|
||||||
self.__properties = properties # type: ignore (we know how to use this private attribute) # noqa: F821
|
|
||||||
|
|
||||||
def __getattr__(self, name: str) -> baml_py.ClassPropertyBuilder:
|
|
||||||
if name not in self.__properties:
|
|
||||||
raise AttributeError(f"Property {name} not found.")
|
|
||||||
return self.__bldr.property(name)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def test(self) -> baml_py.ClassPropertyBuilder:
|
|
||||||
return self.__bldr.property("test")
|
|
||||||
|
|
||||||
|
|
||||||
class DynamicKnowledgeGraphAst:
|
class DynamicKnowledgeGraphAst:
|
||||||
def __init__(self, tb: type_builder.TypeBuilder):
|
def __init__(self, tb: type_builder.TypeBuilder):
|
||||||
_tb = tb._tb # type: ignore (we know how to use this private attribute)
|
_tb = tb._tb # type: ignore (we know how to use this private attribute)
|
||||||
|
|
@ -377,59 +319,6 @@ class DynamicKnowledgeGraphProperties:
|
||||||
return self.__bldr.property(name)
|
return self.__bldr.property(name)
|
||||||
|
|
||||||
|
|
||||||
class DynamicOutputModelAst:
|
|
||||||
def __init__(self, tb: type_builder.TypeBuilder):
|
|
||||||
_tb = tb._tb # type: ignore (we know how to use this private attribute)
|
|
||||||
self._bldr = _tb.class_("DynamicOutputModel")
|
|
||||||
self._properties: typing.Set[str] = set(
|
|
||||||
[
|
|
||||||
"test",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
self._props = DynamicOutputModelProperties(self._bldr, self._properties)
|
|
||||||
|
|
||||||
def type(self) -> baml_py.FieldType:
|
|
||||||
return self._bldr.field()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def props(self) -> "DynamicOutputModelProperties":
|
|
||||||
return self._props
|
|
||||||
|
|
||||||
|
|
||||||
class DynamicOutputModelBuilder(DynamicOutputModelAst):
|
|
||||||
def __init__(self, tb: type_builder.TypeBuilder):
|
|
||||||
super().__init__(tb)
|
|
||||||
|
|
||||||
def add_property(self, name: str, type: baml_py.FieldType) -> baml_py.ClassPropertyBuilder:
|
|
||||||
if name in self._properties:
|
|
||||||
raise ValueError(f"Property {name} already exists.")
|
|
||||||
return self._bldr.property(name).type(type)
|
|
||||||
|
|
||||||
def list_properties(self) -> typing.List[typing.Tuple[str, baml_py.ClassPropertyBuilder]]:
|
|
||||||
return self._bldr.list_properties()
|
|
||||||
|
|
||||||
def remove_property(self, name: str) -> None:
|
|
||||||
self._bldr.remove_property(name)
|
|
||||||
|
|
||||||
def reset(self) -> None:
|
|
||||||
self._bldr.reset()
|
|
||||||
|
|
||||||
|
|
||||||
class DynamicOutputModelProperties:
|
|
||||||
def __init__(self, bldr: baml_py.ClassBuilder, properties: typing.Set[str]):
|
|
||||||
self.__bldr = bldr
|
|
||||||
self.__properties = properties # type: ignore (we know how to use this private attribute) # noqa: F821
|
|
||||||
|
|
||||||
def __getattr__(self, name: str) -> baml_py.ClassPropertyBuilder:
|
|
||||||
if name not in self.__properties:
|
|
||||||
raise AttributeError(f"Property {name} not found.")
|
|
||||||
return self.__bldr.property(name)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def test(self) -> baml_py.ClassPropertyBuilder:
|
|
||||||
return self.__bldr.property("test")
|
|
||||||
|
|
||||||
|
|
||||||
class EdgeAst:
|
class EdgeAst:
|
||||||
def __init__(self, tb: type_builder.TypeBuilder):
|
def __init__(self, tb: type_builder.TypeBuilder):
|
||||||
_tb = tb._tb # type: ignore (we know how to use this private attribute)
|
_tb = tb._tb # type: ignore (we know how to use this private attribute)
|
||||||
|
|
@ -773,6 +662,51 @@ class ProceduralContentProperties:
|
||||||
return type_builder.ClassPropertyViewer(self.__bldr.property("subclass"))
|
return type_builder.ClassPropertyViewer(self.__bldr.property("subclass"))
|
||||||
|
|
||||||
|
|
||||||
|
class ResponseModelAst:
|
||||||
|
def __init__(self, tb: type_builder.TypeBuilder):
|
||||||
|
_tb = tb._tb # type: ignore (we know how to use this private attribute)
|
||||||
|
self._bldr = _tb.class_("ResponseModel")
|
||||||
|
self._properties: typing.Set[str] = set([])
|
||||||
|
self._props = ResponseModelProperties(self._bldr, self._properties)
|
||||||
|
|
||||||
|
def type(self) -> baml_py.FieldType:
|
||||||
|
return self._bldr.field()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def props(self) -> "ResponseModelProperties":
|
||||||
|
return self._props
|
||||||
|
|
||||||
|
|
||||||
|
class ResponseModelBuilder(ResponseModelAst):
|
||||||
|
def __init__(self, tb: type_builder.TypeBuilder):
|
||||||
|
super().__init__(tb)
|
||||||
|
|
||||||
|
def add_property(self, name: str, type: baml_py.FieldType) -> baml_py.ClassPropertyBuilder:
|
||||||
|
if name in self._properties:
|
||||||
|
raise ValueError(f"Property {name} already exists.")
|
||||||
|
return self._bldr.property(name).type(type)
|
||||||
|
|
||||||
|
def list_properties(self) -> typing.List[typing.Tuple[str, baml_py.ClassPropertyBuilder]]:
|
||||||
|
return self._bldr.list_properties()
|
||||||
|
|
||||||
|
def remove_property(self, name: str) -> None:
|
||||||
|
self._bldr.remove_property(name)
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
self._bldr.reset()
|
||||||
|
|
||||||
|
|
||||||
|
class ResponseModelProperties:
|
||||||
|
def __init__(self, bldr: baml_py.ClassBuilder, properties: typing.Set[str]):
|
||||||
|
self.__bldr = bldr
|
||||||
|
self.__properties = properties # type: ignore (we know how to use this private attribute) # noqa: F821
|
||||||
|
|
||||||
|
def __getattr__(self, name: str) -> baml_py.ClassPropertyBuilder:
|
||||||
|
if name not in self.__properties:
|
||||||
|
raise AttributeError(f"Property {name} not found.")
|
||||||
|
return self.__bldr.property(name)
|
||||||
|
|
||||||
|
|
||||||
class SummarizedClassAst:
|
class SummarizedClassAst:
|
||||||
def __init__(self, tb: type_builder.TypeBuilder):
|
def __init__(self, tb: type_builder.TypeBuilder):
|
||||||
_tb = tb._tb # type: ignore (we know how to use this private attribute)
|
_tb = tb._tb # type: ignore (we know how to use this private attribute)
|
||||||
|
|
|
||||||
|
|
@ -21,12 +21,8 @@ type_map = {
|
||||||
"stream_types.ContentLabel": stream_types.ContentLabel,
|
"stream_types.ContentLabel": stream_types.ContentLabel,
|
||||||
"types.DefaultContentPrediction": types.DefaultContentPrediction,
|
"types.DefaultContentPrediction": types.DefaultContentPrediction,
|
||||||
"stream_types.DefaultContentPrediction": stream_types.DefaultContentPrediction,
|
"stream_types.DefaultContentPrediction": stream_types.DefaultContentPrediction,
|
||||||
"types.DynamicInputModel": types.DynamicInputModel,
|
|
||||||
"stream_types.DynamicInputModel": stream_types.DynamicInputModel,
|
|
||||||
"types.DynamicKnowledgeGraph": types.DynamicKnowledgeGraph,
|
"types.DynamicKnowledgeGraph": types.DynamicKnowledgeGraph,
|
||||||
"stream_types.DynamicKnowledgeGraph": stream_types.DynamicKnowledgeGraph,
|
"stream_types.DynamicKnowledgeGraph": stream_types.DynamicKnowledgeGraph,
|
||||||
"types.DynamicOutputModel": types.DynamicOutputModel,
|
|
||||||
"stream_types.DynamicOutputModel": stream_types.DynamicOutputModel,
|
|
||||||
"types.Edge": types.Edge,
|
"types.Edge": types.Edge,
|
||||||
"stream_types.Edge": stream_types.Edge,
|
"stream_types.Edge": stream_types.Edge,
|
||||||
"types.ImageContent": types.ImageContent,
|
"types.ImageContent": types.ImageContent,
|
||||||
|
|
@ -41,6 +37,8 @@ type_map = {
|
||||||
"stream_types.Node": stream_types.Node,
|
"stream_types.Node": stream_types.Node,
|
||||||
"types.ProceduralContent": types.ProceduralContent,
|
"types.ProceduralContent": types.ProceduralContent,
|
||||||
"stream_types.ProceduralContent": stream_types.ProceduralContent,
|
"stream_types.ProceduralContent": stream_types.ProceduralContent,
|
||||||
|
"types.ResponseModel": types.ResponseModel,
|
||||||
|
"stream_types.ResponseModel": stream_types.ResponseModel,
|
||||||
"types.SummarizedClass": types.SummarizedClass,
|
"types.SummarizedClass": types.SummarizedClass,
|
||||||
"stream_types.SummarizedClass": stream_types.SummarizedClass,
|
"stream_types.SummarizedClass": stream_types.SummarizedClass,
|
||||||
"types.SummarizedCode": types.SummarizedCode,
|
"types.SummarizedCode": types.SummarizedCode,
|
||||||
|
|
|
||||||
|
|
@ -48,7 +48,7 @@ def all_succeeded(checks: typing.Dict[CheckName, Check]) -> bool:
|
||||||
# #########################################################################
|
# #########################################################################
|
||||||
|
|
||||||
# #########################################################################
|
# #########################################################################
|
||||||
# Generated classes (19)
|
# Generated classes (18)
|
||||||
# #########################################################################
|
# #########################################################################
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -75,20 +75,10 @@ class DefaultContentPrediction(BaseModel):
|
||||||
label: "ContentLabel"
|
label: "ContentLabel"
|
||||||
|
|
||||||
|
|
||||||
class DynamicInputModel(BaseModel):
|
|
||||||
model_config = ConfigDict(extra="allow")
|
|
||||||
test: str
|
|
||||||
|
|
||||||
|
|
||||||
class DynamicKnowledgeGraph(BaseModel):
|
class DynamicKnowledgeGraph(BaseModel):
|
||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra="allow")
|
||||||
|
|
||||||
|
|
||||||
class DynamicOutputModel(BaseModel):
|
|
||||||
model_config = ConfigDict(extra="allow")
|
|
||||||
test: str
|
|
||||||
|
|
||||||
|
|
||||||
class Edge(BaseModel):
|
class Edge(BaseModel):
|
||||||
# doc string for edge
|
# doc string for edge
|
||||||
# doc string for source_node_id
|
# doc string for source_node_id
|
||||||
|
|
@ -131,6 +121,10 @@ class ProceduralContent(BaseModel):
|
||||||
subclass: typing.List[str]
|
subclass: typing.List[str]
|
||||||
|
|
||||||
|
|
||||||
|
class ResponseModel(BaseModel):
|
||||||
|
model_config = ConfigDict(extra="allow")
|
||||||
|
|
||||||
|
|
||||||
class SummarizedClass(BaseModel):
|
class SummarizedClass(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
description: str
|
description: str
|
||||||
|
|
|
||||||
|
|
@ -1,25 +1,17 @@
|
||||||
class DynamicInputModel {
|
class ResponseModel {
|
||||||
test string
|
|
||||||
@@dynamic
|
|
||||||
}
|
|
||||||
|
|
||||||
class DynamicOutputModel {
|
|
||||||
test string
|
|
||||||
@@dynamic
|
@@dynamic
|
||||||
}
|
}
|
||||||
|
|
||||||
function AcreateStructuredOutput(
|
function AcreateStructuredOutput(
|
||||||
content: string,
|
text_input: string,
|
||||||
system_prompt: string,
|
system_prompt: string,
|
||||||
user_prompt: string,
|
) -> ResponseModel {
|
||||||
) -> DynamicOutputModel {
|
|
||||||
client OpenAI
|
client OpenAI
|
||||||
|
|
||||||
prompt #"
|
prompt #"
|
||||||
{{ system_prompt }}
|
{{ system_prompt }}
|
||||||
{{ ctx.output_format }}
|
{{ ctx.output_format }}
|
||||||
{{ _.role('user') }}
|
{{ _.role('user') }}
|
||||||
{{ user_prompt }}
|
{{ text_input }}
|
||||||
{{ content }}
|
|
||||||
"#
|
"#
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,18 +1,113 @@
|
||||||
import os
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import Type
|
from typing import Type
|
||||||
from pydantic import BaseModel
|
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
from cognee.shared.data_models import SummarizedCode
|
from cognee.shared.data_models import SummarizedCode
|
||||||
from cognee.infrastructure.llm.structured_output_framework.baml.baml_client.async_client import b
|
|
||||||
from cognee.infrastructure.llm.config import get_llm_config
|
from cognee.infrastructure.llm.config import get_llm_config
|
||||||
|
|
||||||
|
from typing import List, Dict, Union, Optional, Literal
|
||||||
|
from enum import Enum
|
||||||
|
from baml_py import Image, Audio, Video, Pdf
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from cognee.infrastructure.llm.structured_output_framework.baml.baml_client.type_builder import (
|
||||||
|
TypeBuilder,
|
||||||
|
)
|
||||||
|
from cognee.infrastructure.llm.structured_output_framework.baml.baml_client import b
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger("extract_summary_baml")
|
logger = get_logger("extract_summary_baml")
|
||||||
|
|
||||||
|
|
||||||
|
def create_dynamic_baml_type(pydantic_model):
|
||||||
|
tb = TypeBuilder()
|
||||||
|
|
||||||
|
# if pydantic_model == str:
|
||||||
|
# b.ResponseModel.add_property("text", tb.string())
|
||||||
|
# return tb
|
||||||
|
#
|
||||||
|
# def map_type(field_type, field_info):
|
||||||
|
# # Handle Optional/Union types
|
||||||
|
# if getattr(field_type, "__origin__", None) == Union:
|
||||||
|
# # Extract types from Union
|
||||||
|
# types = field_type.__args__
|
||||||
|
# # Handle Optional (Union with NoneType)
|
||||||
|
# if type(None) in types:
|
||||||
|
# inner_type = next(t for t in types if t != type(None))
|
||||||
|
# return map_type(inner_type, field_info).optional()
|
||||||
|
# # Handle regular Union
|
||||||
|
# mapped_types = [map_type(t, field_info) for t in types]
|
||||||
|
# return tb.union(*mapped_types)
|
||||||
|
#
|
||||||
|
# # Handle Lists
|
||||||
|
# if getattr(field_type, "__origin__", None) == list:
|
||||||
|
# inner_type = field_type.__args__[0]
|
||||||
|
# return map_type(inner_type, field_info).list()
|
||||||
|
#
|
||||||
|
# # Handle Maps/Dictionaries
|
||||||
|
# if getattr(field_type, "__origin__", None) == dict:
|
||||||
|
# key_type, value_type = field_type.__args__
|
||||||
|
# # BAML only supports string or enum keys in maps
|
||||||
|
# if key_type not in [str, Enum]:
|
||||||
|
# raise ValueError("Map keys must be strings or enums in BAML")
|
||||||
|
# return tb.map(map_type(key_type, field_info), map_type(value_type, field_info))
|
||||||
|
#
|
||||||
|
# # Handle Literal types
|
||||||
|
# if getattr(field_type, "__origin__", None) == Literal:
|
||||||
|
# literal_values = field_type.__args__
|
||||||
|
# return tb.union(*[tb.literal(val) for val in literal_values])
|
||||||
|
#
|
||||||
|
# # Handle Enums
|
||||||
|
# if isinstance(field_type, type) and issubclass(field_type, Enum):
|
||||||
|
# enum_type = tb.add_enum(field_type.__name__)
|
||||||
|
# for member in field_type:
|
||||||
|
# enum_type.add_value(member.name)
|
||||||
|
# return enum_type.type()
|
||||||
|
#
|
||||||
|
# # Handle primitive and special types
|
||||||
|
# type_mapping = {
|
||||||
|
# str: tb.string(),
|
||||||
|
# int: tb.int(),
|
||||||
|
# float: tb.float(),
|
||||||
|
# bool: tb.bool(),
|
||||||
|
# Image: tb.image(),
|
||||||
|
# Audio: tb.audio(),
|
||||||
|
# Video: tb.video(),
|
||||||
|
# Pdf: tb.pdf(),
|
||||||
|
# # datetime is not natively supported in BAML, map to string
|
||||||
|
# datetime: tb.string(),
|
||||||
|
# }
|
||||||
|
#
|
||||||
|
# # Handle nested BaseModel classes
|
||||||
|
# if isinstance(field_type, type) and issubclass(field_type, BaseModel):
|
||||||
|
# nested_tb = create_dynamic_baml_type(field_type)
|
||||||
|
# # Get the last created class from the nested TypeBuilder
|
||||||
|
# return nested_tb.get_last_class().type()
|
||||||
|
#
|
||||||
|
# if field_type in type_mapping:
|
||||||
|
# return type_mapping[field_type]
|
||||||
|
#
|
||||||
|
# raise ValueError(f"Unsupported type: {field_type}")
|
||||||
|
#
|
||||||
|
# fields = pydantic_model.model_fields
|
||||||
|
#
|
||||||
|
# # Add fields
|
||||||
|
# for field_name, field_info in fields.items():
|
||||||
|
# field_type = field_info.annotation
|
||||||
|
# baml_type = map_type(field_type, field_info)
|
||||||
|
#
|
||||||
|
# # Add property with type
|
||||||
|
# prop = b.ResponseModel.add_property(field_name, baml_type)
|
||||||
|
#
|
||||||
|
# # Add description if available
|
||||||
|
# if field_info.description:
|
||||||
|
# prop.description(field_info.description)
|
||||||
|
|
||||||
|
return tb
|
||||||
|
|
||||||
|
|
||||||
async def acreate_structured_output(
|
async def acreate_structured_output(
|
||||||
content: str, system_prompt: str, user_prompt: str, response_model: Type[BaseModel]
|
text_input: str, system_prompt: str, response_model: Type[BaseModel]
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Extract summary using BAML framework.
|
Extract summary using BAML framework.
|
||||||
|
|
@ -26,12 +121,12 @@ async def acreate_structured_output(
|
||||||
"""
|
"""
|
||||||
config = get_llm_config()
|
config = get_llm_config()
|
||||||
|
|
||||||
# Use BAML's SummarizeContent function
|
type_builder = create_dynamic_baml_type(response_model)
|
||||||
|
|
||||||
result = await b.AcreateStructuredOutput(
|
result = await b.AcreateStructuredOutput(
|
||||||
content=content,
|
text_input=text_input,
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
user_prompt=user_prompt,
|
baml_options={"client_registry": config.baml_registry, "tb": type_builder},
|
||||||
baml_options={"client_registry": config.baml_registry},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ import litellm
|
||||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.get_llm_client import (
|
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.get_llm_client import (
|
||||||
get_llm_client,
|
get_llm_client,
|
||||||
)
|
)
|
||||||
|
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|
@ -76,8 +77,7 @@ async def test_llm_connection():
|
||||||
the connection attempt and re-raise the exception for further handling.
|
the connection attempt and re-raise the exception for further handling.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
llm_adapter = get_llm_client()
|
await LLMGateway.acreate_structured_output(
|
||||||
await llm_adapter.acreate_structured_output(
|
|
||||||
text_input="test",
|
text_input="test",
|
||||||
system_prompt='Respond to me with the following string: "test"',
|
system_prompt='Respond to me with the following string: "test"',
|
||||||
response_model=str,
|
response_model=str,
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue