Fixed context, added episodic event analysis
This commit is contained in:
parent
6e01e9af79
commit
f89b82af04
2 changed files with 68 additions and 13 deletions
|
|
@ -531,14 +531,69 @@ class EpisodicBuffer(BaseMemory):
|
||||||
for memory in lookup_value_semantic["data"]["Get"]["SEMANTICMEMORY"]
|
for memory in lookup_value_semantic["data"]["Get"]["SEMANTICMEMORY"]
|
||||||
]
|
]
|
||||||
|
|
||||||
print("HERE IS THE LENGTH OF THE TASKS", str(tasks))
|
|
||||||
memory_scores = await asyncio.gather(*tasks)
|
memory_scores = await asyncio.gather(*tasks)
|
||||||
# Sort the memories based on their average scores
|
# Sort the memories based on their average scores
|
||||||
sorted_memories = sorted(memory_scores, key=lambda x: x["average_score"], reverse=True)[:5]
|
sorted_memories = sorted(memory_scores, key=lambda x: x["average_score"], reverse=True)[:5]
|
||||||
# Store the sorted memories in the context
|
# Store the sorted memories in the context
|
||||||
context.extend([item for item in sorted_memories])
|
context.extend([item for item in sorted_memories])
|
||||||
|
|
||||||
|
for item in context:
|
||||||
|
memory = item.get('memory', {})
|
||||||
|
text = memory.get('text', '')
|
||||||
|
|
||||||
|
prompt_sum= ChatPromptTemplate.from_template("""Based on this query: {query} Summarize the following text so it can be best used as a context summary for the user when running query: {text}"""
|
||||||
|
)
|
||||||
|
chain_sum = prompt_sum | self.llm
|
||||||
|
summary_context = await chain_sum.ainvoke({"query": output, "text": text})
|
||||||
|
item['memory']['text'] = summary_context
|
||||||
|
|
||||||
|
|
||||||
print("HERE IS THE CONTEXT", context)
|
print("HERE IS THE CONTEXT", context)
|
||||||
|
|
||||||
|
lookup_value_episodic = await self.fetch_memories(
|
||||||
|
observation=str(output), namespace="EPISODICMEMORY"
|
||||||
|
)
|
||||||
|
|
||||||
|
class Event(BaseModel):
|
||||||
|
"""Schema for an individual event."""
|
||||||
|
|
||||||
|
event_order: str = Field(
|
||||||
|
..., description="The order at which the task needs to be performed"
|
||||||
|
)
|
||||||
|
event_name: str = Field(
|
||||||
|
None, description="The task that needs to be performed"
|
||||||
|
)
|
||||||
|
operation: str = Field(None, description="The operation that was performed")
|
||||||
|
original_query: str = Field(
|
||||||
|
None, description="Original user query provided"
|
||||||
|
)
|
||||||
|
class EventList(BaseModel):
|
||||||
|
"""Schema for the record containing a list of events of the user chronologically."""
|
||||||
|
|
||||||
|
tasks: List[Event] = Field(..., description="List of tasks")
|
||||||
|
|
||||||
|
prompt_filter_chunk = f" Based on available memories {lookup_value_episodic} determine only the relevant list of steps and operations sequentially "
|
||||||
|
prompt_msgs = [
|
||||||
|
SystemMessage(
|
||||||
|
content="You are a world class algorithm for determining what happened in the past and ordering events chronologically."
|
||||||
|
),
|
||||||
|
HumanMessage(content="Analyze the following memories and provide the relevant response:"),
|
||||||
|
HumanMessagePromptTemplate.from_template("{input}"),
|
||||||
|
HumanMessage(content="Tips: Make sure to answer in the correct format"),
|
||||||
|
HumanMessage(
|
||||||
|
content="Tips: Only choose actions that are relevant to the user query and ignore others"
|
||||||
|
)
|
||||||
|
]
|
||||||
|
prompt_ = ChatPromptTemplate(messages=prompt_msgs)
|
||||||
|
chain = create_structured_output_chain(
|
||||||
|
EventList, self.llm, prompt_, verbose=True
|
||||||
|
)
|
||||||
|
from langchain.callbacks import get_openai_callback
|
||||||
|
|
||||||
|
with get_openai_callback() as cb:
|
||||||
|
episodic_context = await chain.arun(input=prompt_filter_chunk, verbose=True)
|
||||||
|
print(cb)
|
||||||
|
|
||||||
class BufferModulators(BaseModel):
|
class BufferModulators(BaseModel):
|
||||||
attention_modulators: Dict[str, float] = Field(... , description="Attention modulators")
|
attention_modulators: Dict[str, float] = Field(... , description="Attention modulators")
|
||||||
|
|
||||||
|
|
@ -579,17 +634,16 @@ class EpisodicBuffer(BaseMemory):
|
||||||
|
|
||||||
# we structure the data here to make it easier to work with
|
# we structure the data here to make it easier to work with
|
||||||
parser = PydanticOutputParser(pydantic_object=BufferRawContextList)
|
parser = PydanticOutputParser(pydantic_object=BufferRawContextList)
|
||||||
|
|
||||||
prompt = PromptTemplate(
|
prompt = PromptTemplate(
|
||||||
template="""Summarize and create semantic search queries and relevant
|
template="""Summarize and create semantic search queries and relevant
|
||||||
document summaries for the user query.\n
|
document summaries for the user query.\n
|
||||||
{format_instructions}\nOriginal query is:
|
{format_instructions}\nOriginal query is:
|
||||||
{query}\n Retrieved context is: {context}""",
|
{query}\n Retrieved document context is: {context}. Retrieved memory context is {memory_context}""",
|
||||||
input_variables=["query", "context"],
|
input_variables=["query", "context", "memory_context"],
|
||||||
partial_variables={"format_instructions": parser.get_format_instructions()},
|
partial_variables={"format_instructions": parser.get_format_instructions()},
|
||||||
)
|
)
|
||||||
|
|
||||||
_input = prompt.format_prompt(query=user_input, context=context)
|
_input = prompt.format_prompt(query=user_input, context=str(context), memory_context=str(episodic_context))
|
||||||
document_context_result = self.llm_base(_input.to_string())
|
document_context_result = self.llm_base(_input.to_string())
|
||||||
document_context_result_parsed = parser.parse(document_context_result)
|
document_context_result_parsed = parser.parse(document_context_result)
|
||||||
# print(document_context_result_parsed)
|
# print(document_context_result_parsed)
|
||||||
|
|
|
||||||
|
|
@ -104,10 +104,10 @@ class WeaviateVectorDB(VectorDB):
|
||||||
return client
|
return client
|
||||||
|
|
||||||
def _document_loader(self, observation: str, loader_settings: dict):
|
def _document_loader(self, observation: str, loader_settings: dict):
|
||||||
# Create an in-memory file-like object for the PDF content
|
# Check the format of the document
|
||||||
|
document_format = loader_settings.get("format", "text")
|
||||||
if loader_settings.get("format") == "PDF":
|
|
||||||
|
|
||||||
|
if document_format == "PDF":
|
||||||
if loader_settings.get("source") == "url":
|
if loader_settings.get("source") == "url":
|
||||||
pdf_response = requests.get(loader_settings["path"])
|
pdf_response = requests.get(loader_settings["path"])
|
||||||
pdf_stream = BytesIO(pdf_response.content)
|
pdf_stream = BytesIO(pdf_response.content)
|
||||||
|
|
@ -121,19 +121,20 @@ class WeaviateVectorDB(VectorDB):
|
||||||
# adapt this for different chunking strategies
|
# adapt this for different chunking strategies
|
||||||
pages = loader.load_and_split()
|
pages = loader.load_and_split()
|
||||||
return pages
|
return pages
|
||||||
|
elif loader_settings.get("source") == "file":
|
||||||
if loader_settings.get("source") == "file":
|
|
||||||
# Process the PDF using PyPDFLoader
|
# Process the PDF using PyPDFLoader
|
||||||
# might need adapting for different loaders + OCR
|
# might need adapting for different loaders + OCR
|
||||||
# need to test the path
|
# need to test the path
|
||||||
loader = PyPDFLoader(loader_settings["path"])
|
loader = PyPDFLoader(loader_settings["path"])
|
||||||
pages = loader.load_and_split()
|
pages = loader.load_and_split()
|
||||||
|
|
||||||
return pages
|
return pages
|
||||||
else:
|
|
||||||
# Process the text by just loading the base text
|
elif document_format == "text":
|
||||||
|
# Process the text directly
|
||||||
return observation
|
return observation
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported document format: {document_format}")
|
||||||
|
|
||||||
async def add_memories(
|
async def add_memories(
|
||||||
self, observation: str, loader_settings: dict = None, params: dict = None ,namespace:str=None
|
self, observation: str, loader_settings: dict = None, params: dict = None ,namespace:str=None
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue