Fixed translate script part
This commit is contained in:
parent
697c49a670
commit
2f85f6baff
1 changed files with 12 additions and 5 deletions
|
|
@ -655,7 +655,7 @@ class EpisodicBuffer:
|
|||
|
||||
#we just filter the data here
|
||||
prompt_filter = ChatPromptTemplate.from_template(
|
||||
"Filter and remove uneccessary information that is not relevant in the user query {query}")
|
||||
"Filter and remove uneccessary information that is not relevant in the user query, keep it as original as possbile: {query}")
|
||||
chain_filter = prompt_filter | self.llm
|
||||
output = await chain_filter.ainvoke({"query": user_input})
|
||||
|
||||
|
|
@ -699,6 +699,7 @@ class EpisodicBuffer:
|
|||
task_order: str = Field(..., description="The order at which the task needs to be performed")
|
||||
task_name: str = Field(None, description="The task that needs to be performed")
|
||||
operation: str = Field(None, description="The operation to be performed")
|
||||
original_query: str = Field(None, description="Original user query provided")
|
||||
|
||||
class TaskList(BaseModel):
|
||||
"""Schema for the record containing a list of tasks."""
|
||||
|
|
@ -777,15 +778,15 @@ class EpisodicBuffer:
|
|||
description="observation we want to translate"
|
||||
)
|
||||
|
||||
@tool("translate_to_en", args_schema=TranslateText, return_direct=True)
|
||||
def translate_to_en(observation, args_schema=TranslateText):
|
||||
@tool("translate_to_de", args_schema=TranslateText, return_direct=True)
|
||||
def translate_to_de(observation, args_schema=TranslateText):
|
||||
"""Translate to English"""
|
||||
out = GoogleTranslator(source='auto', target='en').translate(text=observation)
|
||||
out = GoogleTranslator(source='auto', target='de').translate(text=observation)
|
||||
return out
|
||||
|
||||
agent = initialize_agent(
|
||||
llm=self.llm,
|
||||
tools=[translate_to_en, convert_to_structured],
|
||||
tools=[translate_to_de, convert_to_structured],
|
||||
agent=AgentType.OPENAI_FUNCTIONS,
|
||||
|
||||
verbose=True,
|
||||
|
|
@ -797,6 +798,10 @@ class EpisodicBuffer:
|
|||
result_tasks.append(output)
|
||||
|
||||
|
||||
|
||||
print("HERE IS THE RESULT TASKS", str(result_tasks))
|
||||
|
||||
|
||||
await self.encoding(str(result_tasks), self.namespace, params=params)
|
||||
|
||||
|
||||
|
|
@ -827,6 +832,8 @@ class EpisodicBuffer:
|
|||
|
||||
_input = prompt.format_prompt(query=user_input, steps=str(tasks_list), buffer=buffer_result)
|
||||
|
||||
print("a few things to do like load episodic memory in a structured format")
|
||||
|
||||
return "a few things to do like load episodic memory in a structured format"
|
||||
|
||||
# output = self.llm(_input.to_string())
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue