Updated the code considerably to fix issues with context overloads
This commit is contained in:
parent
d9a2ee6646
commit
ca5e090526
3 changed files with 222 additions and 87 deletions
|
|
@ -2,8 +2,10 @@
|
||||||
import json
|
import json
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from typing import Dict, List, Union
|
from typing import Dict, List, Union, Any
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
import marvin
|
import marvin
|
||||||
import requests
|
import requests
|
||||||
from deep_translator import GoogleTranslator
|
from deep_translator import GoogleTranslator
|
||||||
|
|
@ -290,12 +292,20 @@ class WeaviateVectorDB(VectorDB):
|
||||||
)
|
)
|
||||||
|
|
||||||
async def fetch_memories(
|
async def fetch_memories(
|
||||||
self, observation: str, namespace: str, params: dict = None
|
self, observation: str, namespace: str, params: dict = None, n_of_observations =int(2)
|
||||||
):
|
):
|
||||||
# Fetch Weaviate memories here
|
|
||||||
"""
|
"""
|
||||||
Get documents from weaviate.
|
Get documents from weaviate.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
- observation (str): User query.
|
||||||
|
- namespace (str): Type of memory we access.
|
||||||
|
- params (dict, optional):
|
||||||
|
- n_of_observations (int, optional): For weaviate, equals to autocut, defaults to 1. Ranges from 1 to 3. Check weaviate docs for more info.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Describe the return type and what the function returns.
|
||||||
|
|
||||||
Args a json containing:
|
Args a json containing:
|
||||||
query (str): The query string.
|
query (str): The query string.
|
||||||
path (list): The path for filtering, e.g., ['year'].
|
path (list): The path for filtering, e.g., ['year'].
|
||||||
|
|
@ -304,6 +314,7 @@ class WeaviateVectorDB(VectorDB):
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
get_from_weaviate(query="some query", path=['year'], operator='Equal', valueText='2017*')
|
get_from_weaviate(query="some query", path=['year'], operator='Equal', valueText='2017*')
|
||||||
|
|
||||||
"""
|
"""
|
||||||
client = self.init_weaviate_client(self.namespace)
|
client = self.init_weaviate_client(self.namespace)
|
||||||
|
|
||||||
|
|
@ -349,6 +360,7 @@ class WeaviateVectorDB(VectorDB):
|
||||||
["id", "creationTimeUnix", "lastUpdateTimeUnix", "score",'distance']
|
["id", "creationTimeUnix", "lastUpdateTimeUnix", "score",'distance']
|
||||||
)
|
)
|
||||||
.with_where(params_user_id)
|
.with_where(params_user_id)
|
||||||
|
.with_limit(10)
|
||||||
.do()
|
.do()
|
||||||
)
|
)
|
||||||
return query_output
|
return query_output
|
||||||
|
|
@ -384,8 +396,9 @@ class WeaviateVectorDB(VectorDB):
|
||||||
query=observation,
|
query=observation,
|
||||||
fusion_type=HybridFusion.RELATIVE_SCORE
|
fusion_type=HybridFusion.RELATIVE_SCORE
|
||||||
)
|
)
|
||||||
.with_autocut(1)
|
.with_autocut(n_of_observations)
|
||||||
.with_where(params_user_id)
|
.with_where(params_user_id)
|
||||||
|
.with_limit(10)
|
||||||
.do()
|
.do()
|
||||||
)
|
)
|
||||||
return query_output
|
return query_output
|
||||||
|
|
@ -493,11 +506,13 @@ class BaseMemory:
|
||||||
observation: str,
|
observation: str,
|
||||||
params: Optional[str] = None,
|
params: Optional[str] = None,
|
||||||
namespace: Optional[str] = None,
|
namespace: Optional[str] = None,
|
||||||
|
n_of_observations: Optional[int] = 2,
|
||||||
):
|
):
|
||||||
if self.db_type == "weaviate":
|
if self.db_type == "weaviate":
|
||||||
return await self.vector_db.fetch_memories(
|
return await self.vector_db.fetch_memories(
|
||||||
observation=observation, params=params,
|
observation=observation, params=params,
|
||||||
namespace=namespace
|
namespace=namespace,
|
||||||
|
n_of_observations=n_of_observations
|
||||||
)
|
)
|
||||||
|
|
||||||
async def delete_memories(self, params: Optional[str] = None):
|
async def delete_memories(self, params: Optional[str] = None):
|
||||||
|
|
@ -559,6 +574,34 @@ class EpisodicBuffer(BaseMemory):
|
||||||
model_name="gpt-4-0613",
|
model_name="gpt-4-0613",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def _summarizer(self, text: str, document:str, max_tokens: int = 1200):
|
||||||
|
"""Summarize text using OpenAI API, to reduce amount of code for modulators contributing to context"""
|
||||||
|
class Summaries(BaseModel):
|
||||||
|
"""Schema for documentGroups"""
|
||||||
|
summary: str = Field(
|
||||||
|
...,
|
||||||
|
description="Summarized document")
|
||||||
|
class SummaryContextList(BaseModel):
|
||||||
|
"""Buffer raw context processed by the buffer"""
|
||||||
|
|
||||||
|
summaries: List[Summaries] = Field(..., description="List of summaries")
|
||||||
|
observation: str = Field(..., description="The original user query")
|
||||||
|
|
||||||
|
parser = PydanticOutputParser(pydantic_object=SummaryContextList)
|
||||||
|
prompt = PromptTemplate(
|
||||||
|
template=" \n{format_instructions}\nSummarize the observation briefly based on the user query, observation is: {query}\n. The document is: {document}",
|
||||||
|
input_variables=["query", "document"],
|
||||||
|
partial_variables={"format_instructions": parser.get_format_instructions()},
|
||||||
|
)
|
||||||
|
|
||||||
|
_input = prompt.format_prompt(query=text, document=document)
|
||||||
|
document_context_result = self.llm_base(_input.to_string())
|
||||||
|
document_context_result_parsed = parser.parse(document_context_result)
|
||||||
|
document_context_result_parsed = json.loads(document_context_result_parsed.json())
|
||||||
|
document_summary = document_context_result_parsed["summaries"][0]["summary"]
|
||||||
|
|
||||||
|
return document_summary
|
||||||
|
|
||||||
async def memory_route(self, text_time_diff: str):
|
async def memory_route(self, text_time_diff: str):
|
||||||
@ai_classifier
|
@ai_classifier
|
||||||
class MemoryRoute(Enum):
|
class MemoryRoute(Enum):
|
||||||
|
|
@ -575,8 +618,9 @@ class EpisodicBuffer(BaseMemory):
|
||||||
|
|
||||||
return namespace
|
return namespace
|
||||||
|
|
||||||
async def freshness(self, observation: str, namespace: str = None) -> list[str]:
|
async def freshness(self, observation: str, namespace: str = None, memory=None) -> list[str]:
|
||||||
"""Freshness - Score between 0 and 1 on how often was the information updated in episodic or semantic memory in the past"""
|
"""Freshness - Score between 0 and 1 on how often was the information updated in episodic or semantic memory in the past"""
|
||||||
|
logging.info("Starting with Freshness")
|
||||||
|
|
||||||
lookup_value = await self.fetch_memories(
|
lookup_value = await self.fetch_memories(
|
||||||
observation=observation, namespace=namespace
|
observation=observation, namespace=namespace
|
||||||
|
|
@ -589,13 +633,14 @@ class EpisodicBuffer(BaseMemory):
|
||||||
last_update_datetime = datetime.fromtimestamp(int(unix_t) / 1000)
|
last_update_datetime = datetime.fromtimestamp(int(unix_t) / 1000)
|
||||||
time_difference = datetime.now() - last_update_datetime
|
time_difference = datetime.now() - last_update_datetime
|
||||||
time_difference_text = humanize.naturaltime(time_difference)
|
time_difference_text = humanize.naturaltime(time_difference)
|
||||||
namespace = await self.memory_route(str(time_difference_text))
|
namespace_ = await self.memory_route(str(time_difference_text))
|
||||||
return [namespace.value, lookup_value]
|
return [namespace_.value, lookup_value]
|
||||||
|
|
||||||
async def frequency(self, observation: str, namespace: str) -> list[str]:
|
async def frequency(self, observation: str, namespace: str, memory) -> list[str]:
|
||||||
"""Frequency - Score between 0 and 1 on how often was the information processed in episodic memory in the past
|
"""Frequency - Score between 0 and 1 on how often was the information processed in episodic memory in the past
|
||||||
Counts the number of times a memory was accessed in the past and divides it by the total number of memories in the episodic memory
|
Counts the number of times a memory was accessed in the past and divides it by the total number of memories in the episodic memory
|
||||||
"""
|
"""
|
||||||
|
logging.info("Starting with Frequency")
|
||||||
weaviate_client = self.init_client(namespace=namespace)
|
weaviate_client = self.init_client(namespace=namespace)
|
||||||
|
|
||||||
result_output = await self.fetch_memories(
|
result_output = await self.fetch_memories(
|
||||||
|
|
@ -610,19 +655,22 @@ class EpisodicBuffer(BaseMemory):
|
||||||
"count"
|
"count"
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
return [str(frequency), result_output["data"]["Get"]["EPISODICMEMORY"][0]]
|
summary = await self._summarizer(text=observation, document=result_output["data"]["Get"]["EPISODICMEMORY"][0])
|
||||||
|
logging.info("Frequency summary is %s", str(summary))
|
||||||
|
return [str(frequency), summary]
|
||||||
|
|
||||||
async def repetition(self, observation: str, namespace: str) -> list[str]:
|
async def repetition(self, observation: str, namespace: str, memory) -> list[str]:
|
||||||
"""Repetition - Score between 0 and 1 based on how often and at what intervals a memory has been revisited.
|
"""Repetition - Score between 0 and 1 based on how often and at what intervals a memory has been revisited.
|
||||||
Accounts for the spacing effect, where memories accessed at increasing intervals are given higher scores.
|
Accounts for the spacing effect, where memories accessed at increasing intervals are given higher scores.
|
||||||
|
# TO DO -> add metadata column to make sure that the access is not equal to update, and run update vector function each time a memory is accessed
|
||||||
"""
|
"""
|
||||||
weaviate_client = self.init_client(namespace=namespace)
|
logging.info("Starting with Repetition")
|
||||||
|
|
||||||
result_output = await self.fetch_memories(
|
result_output = await self.fetch_memories(
|
||||||
observation=observation, params=None, namespace=namespace
|
observation=observation, params=None, namespace=namespace
|
||||||
)
|
)
|
||||||
|
|
||||||
access_times = result_output["data"]["Get"]["EPISODICMEMORY"][0]["_additional"]["accessTimes"]
|
access_times = result_output["data"]["Get"]["EPISODICMEMORY"][0]["_additional"]["lastUpdateTimeUnix"]
|
||||||
# Calculate repetition score based on access times
|
# Calculate repetition score based on access times
|
||||||
if not access_times or len(access_times) == 1:
|
if not access_times or len(access_times) == 1:
|
||||||
return ["0", result_output["data"]["Get"]["EPISODICMEMORY"][0]]
|
return ["0", result_output["data"]["Get"]["EPISODICMEMORY"][0]]
|
||||||
|
|
@ -633,13 +681,15 @@ class EpisodicBuffer(BaseMemory):
|
||||||
intervals = [access_times[i + 1] - access_times[i] for i in range(len(access_times) - 1)]
|
intervals = [access_times[i + 1] - access_times[i] for i in range(len(access_times) - 1)]
|
||||||
# A simple scoring mechanism: Longer intervals get higher scores, as they indicate spaced repetition
|
# A simple scoring mechanism: Longer intervals get higher scores, as they indicate spaced repetition
|
||||||
repetition_score = sum([1.0 / (interval + 1) for interval in intervals]) / len(intervals)
|
repetition_score = sum([1.0 / (interval + 1) for interval in intervals]) / len(intervals)
|
||||||
|
summary = await self._summarizer(text = observation, document=result_output["data"]["Get"]["EPISODICMEMORY"][0])
|
||||||
|
logging.info("Repetition is %s", str(repetition_score))
|
||||||
|
logging.info("Repetition summary is %s", str(summary))
|
||||||
|
return [str(repetition_score), summary]
|
||||||
|
|
||||||
return [str(repetition_score), result_output["data"]["Get"]["EPISODICMEMORY"][0]]
|
async def relevance(self, observation: str, namespace: str, memory) -> list[str]:
|
||||||
|
|
||||||
async def relevance(self, observation: str, namespace: str) -> list[str]:
|
|
||||||
"""
|
"""
|
||||||
Fetches the relevance score for a given observation from the episodic memory.
|
Fetches the fusion relevance score for a given observation from the episodic memory.
|
||||||
|
Learn more about fusion scores here on Weaviate docs: https://weaviate.io/blog/hybrid-search-fusion-algorithms
|
||||||
Parameters:
|
Parameters:
|
||||||
- observation: The user's query or observation.
|
- observation: The user's query or observation.
|
||||||
- namespace: The namespace for the data.
|
- namespace: The namespace for the data.
|
||||||
|
|
@ -647,40 +697,20 @@ class EpisodicBuffer(BaseMemory):
|
||||||
Returns:
|
Returns:
|
||||||
- The relevance score between 0 and 1.
|
- The relevance score between 0 and 1.
|
||||||
"""
|
"""
|
||||||
|
logging.info("Starting with Relevance")
|
||||||
|
score = memory["_additional"]["score"]
|
||||||
|
logging.info("Relevance is %s", str(score))
|
||||||
|
return [score, "fusion score"]
|
||||||
|
|
||||||
# Fetch the memory content based on the observation
|
async def saliency(self, observation: str, namespace=None, memory=None) -> list[str]:
|
||||||
result_output = await self.fetch_memories(
|
|
||||||
observation=observation, params=None, namespace=namespace
|
|
||||||
)
|
|
||||||
|
|
||||||
# Extract the relevance score from the memory content
|
|
||||||
score = result_output["data"]["Get"]["EPISODICMEMORY"][0]["_additional"]["score"]
|
|
||||||
|
|
||||||
return score
|
|
||||||
|
|
||||||
|
|
||||||
#each of the requests is numbered, and then the previous requests are retrieved . The request is classified based on past and current content as :
|
|
||||||
# 1. Very positive request
|
|
||||||
# 2. Positive request
|
|
||||||
# 3. Neutral request
|
|
||||||
# 4. Negative request
|
|
||||||
# 5. Very negative request
|
|
||||||
|
|
||||||
|
|
||||||
# After this, we update the weights of the request based on the classification of the request.
|
|
||||||
# After updating the weights, we update the buffer with the new weights. When new weights are calculated, we start from the updated values
|
|
||||||
# Which chunking strategy works best?
|
|
||||||
|
|
||||||
# Adding to the buffer - process the weights, and then use them as filters
|
|
||||||
|
|
||||||
async def saliency(self, observation: str, namespace=None) -> list[str]:
|
|
||||||
"""Determines saliency by scoring the set of retrieved documents against each other and trying to determine saliency
|
"""Determines saliency by scoring the set of retrieved documents against each other and trying to determine saliency
|
||||||
"""
|
"""
|
||||||
|
logging.info("Starting with Saliency")
|
||||||
class SaliencyRawList(BaseModel):
|
class SaliencyRawList(BaseModel):
|
||||||
"""Schema for documentGroups"""
|
"""Schema for documentGroups"""
|
||||||
original_document: str = Field(
|
summary: str = Field(
|
||||||
...,
|
...,
|
||||||
description="The original document retrieved from the database")
|
description="Summarized document")
|
||||||
saliency_score: str = Field(
|
saliency_score: str = Field(
|
||||||
None, description="The score between 0 and 1")
|
None, description="The score between 0 and 1")
|
||||||
class SailencyContextList(BaseModel):
|
class SailencyContextList(BaseModel):
|
||||||
|
|
@ -691,7 +721,7 @@ class EpisodicBuffer(BaseMemory):
|
||||||
|
|
||||||
parser = PydanticOutputParser(pydantic_object=SailencyContextList)
|
parser = PydanticOutputParser(pydantic_object=SailencyContextList)
|
||||||
prompt = PromptTemplate(
|
prompt = PromptTemplate(
|
||||||
template="Determine saliency of documents compared to the other documents retrieved \n{format_instructions}\nOriginal observation is: {query}\n",
|
template="Determine saliency of documents compared to the other documents retrieved \n{format_instructions}\nSummarize the observation briefly based on the user query, observation is: {query}\n",
|
||||||
input_variables=["query"],
|
input_variables=["query"],
|
||||||
partial_variables={"format_instructions": parser.get_format_instructions()},
|
partial_variables={"format_instructions": parser.get_format_instructions()},
|
||||||
)
|
)
|
||||||
|
|
@ -699,7 +729,14 @@ class EpisodicBuffer(BaseMemory):
|
||||||
_input = prompt.format_prompt(query=observation)
|
_input = prompt.format_prompt(query=observation)
|
||||||
document_context_result = self.llm_base(_input.to_string())
|
document_context_result = self.llm_base(_input.to_string())
|
||||||
document_context_result_parsed = parser.parse(document_context_result)
|
document_context_result_parsed = parser.parse(document_context_result)
|
||||||
return document_context_result_parsed.json()
|
document_context_result_parsed = json.loads(document_context_result_parsed.json())
|
||||||
|
saliency_score = document_context_result_parsed["docs"][0]["saliency_score"]
|
||||||
|
saliency_values = document_context_result_parsed["docs"][0]["summary"]
|
||||||
|
|
||||||
|
logging.info("Saliency is %s", str(saliency_score))
|
||||||
|
logging.info("Saliency summary is %s", str(saliency_values))
|
||||||
|
|
||||||
|
return [saliency_score, saliency_values]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -722,6 +759,7 @@ class EpisodicBuffer(BaseMemory):
|
||||||
attention_modulators: Dict[str, float],
|
attention_modulators: Dict[str, float],
|
||||||
observation: str,
|
observation: str,
|
||||||
namespace: Optional[str] = None,
|
namespace: Optional[str] = None,
|
||||||
|
memory: Optional[Dict[str, Any]] = None,
|
||||||
) -> Optional[List[Union[str, float]]]:
|
) -> Optional[List[Union[str, float]]]:
|
||||||
"""
|
"""
|
||||||
Handle the given modulator based on the observation and namespace.
|
Handle the given modulator based on the observation and namespace.
|
||||||
|
|
@ -737,22 +775,25 @@ class EpisodicBuffer(BaseMemory):
|
||||||
"""
|
"""
|
||||||
modulator_value = attention_modulators.get(modulator_name, 0.0)
|
modulator_value = attention_modulators.get(modulator_name, 0.0)
|
||||||
modulator_functions = {
|
modulator_functions = {
|
||||||
"freshness": lambda obs, ns: self.freshness(observation=obs, namespace=ns),
|
"freshness": lambda obs, ns, mem: self.freshness(observation=obs, namespace=ns, memory=mem),
|
||||||
"frequency": lambda obs, ns: self.frequency(observation=obs, namespace=ns),
|
"frequency": lambda obs, ns, mem: self.frequency(observation=obs, namespace=ns, memory=mem),
|
||||||
"relevance": lambda obs, ns: self.relevance(observation=obs, namespace=ns),
|
"relevance": lambda obs, ns, mem: self.relevance(observation=obs, namespace=ns, memory=mem),
|
||||||
"saliency": lambda obs, ns: self.saliency(observation=obs, namespace=ns),
|
"saliency": lambda obs, ns, mem: self.saliency(observation=obs, namespace=ns, memory=mem),
|
||||||
}
|
}
|
||||||
|
|
||||||
result_func = modulator_functions.get(modulator_name)
|
result_func = modulator_functions.get(modulator_name)
|
||||||
if not result_func:
|
if not result_func:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
result = await result_func(observation, namespace)
|
result = await result_func(observation, namespace, memory)
|
||||||
if not result:
|
if not result:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if float(modulator_value) >= float(result[0]):
|
logging.info("Modulator %s", modulator_name)
|
||||||
|
logging.info("Modulator value %s", modulator_value)
|
||||||
|
logging.info("Result %s", result[0])
|
||||||
|
if float(result[0]) >= float(modulator_value):
|
||||||
return result
|
return result
|
||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
|
|
@ -809,11 +850,11 @@ class EpisodicBuffer(BaseMemory):
|
||||||
# check if modulators exist, initialize the modulators if needed
|
# check if modulators exist, initialize the modulators if needed
|
||||||
if attention_modulators is None:
|
if attention_modulators is None:
|
||||||
# try:
|
# try:
|
||||||
print("Starting with attention mods")
|
logging.info("Starting with attention mods")
|
||||||
attention_modulators = await self.fetch_memories(observation="Attention modulators",
|
attention_modulators = await self.fetch_memories(observation="Attention modulators",
|
||||||
namespace="BUFFERMEMORY")
|
namespace="BUFFERMEMORY")
|
||||||
|
|
||||||
print("Attention modulators exist", str(attention_modulators))
|
logging.info("Attention modulators exist %s", str(attention_modulators))
|
||||||
lookup_value_episodic = await self.fetch_memories(
|
lookup_value_episodic = await self.fetch_memories(
|
||||||
observation=str(output), namespace="EPISODICMEMORY"
|
observation=str(output), namespace="EPISODICMEMORY"
|
||||||
)
|
)
|
||||||
|
|
@ -896,26 +937,52 @@ class EpisodicBuffer(BaseMemory):
|
||||||
lookup_value_semantic = await self.fetch_memories(
|
lookup_value_semantic = await self.fetch_memories(
|
||||||
observation=str(output), namespace="SEMANTICMEMORY"
|
observation=str(output), namespace="SEMANTICMEMORY"
|
||||||
)
|
)
|
||||||
|
print("This is the lookup value semantic", len(lookup_value_semantic))
|
||||||
context = []
|
context = []
|
||||||
for memory in lookup_value_semantic["data"]["Get"]["SEMANTICMEMORY"]:
|
memory_scores = []
|
||||||
# extract memory id, and pass it to fetch function as a parameter
|
|
||||||
|
async def compute_score_for_memory(memory, output, attention_modulators):
|
||||||
modulators = list(attention_modulators.keys())
|
modulators = list(attention_modulators.keys())
|
||||||
|
total_score = 0
|
||||||
|
num_scores = 0
|
||||||
|
individual_scores = {} # Store individual scores with their modulator names
|
||||||
|
|
||||||
for modulator in modulators:
|
for modulator in modulators:
|
||||||
result = await self.handle_modulator(
|
result = await self.handle_modulator(
|
||||||
modulator,
|
modulator_name=modulator,
|
||||||
attention_modulators,
|
attention_modulators=attention_modulators,
|
||||||
str(output),
|
observation=str(output),
|
||||||
namespace="EPISODICMEMORY",
|
namespace="EPISODICMEMORY",
|
||||||
|
memory=memory,
|
||||||
)
|
)
|
||||||
if result:
|
if result:
|
||||||
context.append(result)
|
score = float(result[0]) # Assuming the first value in result is the score
|
||||||
context.append(memory)
|
individual_scores[modulator] = score # Store the score with its modulator name
|
||||||
|
total_score += score
|
||||||
|
num_scores += 1
|
||||||
|
|
||||||
|
average_score = total_score / num_scores if num_scores else 0
|
||||||
|
return {
|
||||||
|
"memory": memory,
|
||||||
|
"average_score": average_score,
|
||||||
|
"individual_scores": individual_scores
|
||||||
|
}
|
||||||
|
|
||||||
|
tasks = [
|
||||||
|
compute_score_for_memory(memory=memory, output=output, attention_modulators=attention_modulators)
|
||||||
|
for memory in lookup_value_semantic["data"]["Get"]["SEMANTICMEMORY"]
|
||||||
|
]
|
||||||
|
|
||||||
|
print("HERE IS THE LENGTH OF THE TASKS", str(tasks))
|
||||||
|
memory_scores = await asyncio.gather(*tasks)
|
||||||
|
# Sort the memories based on their average scores
|
||||||
|
sorted_memories = sorted(memory_scores, key=lambda x: x["average_score"], reverse=True)[:5]
|
||||||
|
# Store the sorted memories in the context
|
||||||
|
context.extend([item for item in sorted_memories])
|
||||||
|
print("HERE IS THE CONTEXT", context)
|
||||||
|
|
||||||
class BufferModulators(BaseModel):
|
class BufferModulators(BaseModel):
|
||||||
frequency: str = Field(..., description="Frequency score of the document")
|
attention_modulators: Dict[str, float] = Field(... , description="Attention modulators")
|
||||||
saliency: str = Field(..., description="Saliency score of the document")
|
|
||||||
relevance: str = Field(..., description="Relevance score of the document")
|
|
||||||
|
|
||||||
class BufferRawContextTerms(BaseModel):
|
class BufferRawContextTerms(BaseModel):
|
||||||
"""Schema for documentGroups"""
|
"""Schema for documentGroups"""
|
||||||
|
|
@ -927,18 +994,29 @@ class EpisodicBuffer(BaseMemory):
|
||||||
document_content: str = Field(
|
document_content: str = Field(
|
||||||
None, description="Shortened original content of the document"
|
None, description="Shortened original content of the document"
|
||||||
)
|
)
|
||||||
document_relevance: str = Field(
|
|
||||||
None,
|
|
||||||
description="The relevance of the document for the task on the scale from 0 to 1",
|
|
||||||
)
|
|
||||||
attention_modulators_list: List[BufferModulators] = Field(
|
attention_modulators_list: List[BufferModulators] = Field(
|
||||||
..., description="List of modulators"
|
..., description="List of modulators"
|
||||||
)
|
)
|
||||||
|
average_modulator_score: str = Field(None, description="Average modulator score")
|
||||||
|
class StructuredEpisodicEvents(BaseModel):
|
||||||
|
"""Schema for documentGroups"""
|
||||||
|
|
||||||
|
event_order: str = Field(
|
||||||
|
...,
|
||||||
|
description="Order when event occured",
|
||||||
|
)
|
||||||
|
event_type: str = Field(
|
||||||
|
None, description="Type of the event"
|
||||||
|
)
|
||||||
|
event_context: List[BufferModulators] = Field(
|
||||||
|
..., description="Context of the event"
|
||||||
|
)
|
||||||
|
|
||||||
class BufferRawContextList(BaseModel):
|
class BufferRawContextList(BaseModel):
|
||||||
"""Buffer raw context processed by the buffer"""
|
"""Buffer raw context processed by the buffer"""
|
||||||
|
|
||||||
docs: List[BufferRawContextTerms] = Field(..., description="List of docs")
|
docs: List[BufferRawContextTerms] = Field(..., description="List of docs")
|
||||||
|
events: List[StructuredEpisodicEvents] = Field(..., description="List of events")
|
||||||
user_query: str = Field(..., description="The original user query")
|
user_query: str = Field(..., description="The original user query")
|
||||||
|
|
||||||
# we structure the data here to make it easier to work with
|
# we structure the data here to make it easier to work with
|
||||||
|
|
@ -956,6 +1034,7 @@ class EpisodicBuffer(BaseMemory):
|
||||||
_input = prompt.format_prompt(query=user_input, context=context)
|
_input = prompt.format_prompt(query=user_input, context=context)
|
||||||
document_context_result = self.llm_base(_input.to_string())
|
document_context_result = self.llm_base(_input.to_string())
|
||||||
document_context_result_parsed = parser.parse(document_context_result)
|
document_context_result_parsed = parser.parse(document_context_result)
|
||||||
|
# print(document_context_result_parsed)
|
||||||
return document_context_result_parsed
|
return document_context_result_parsed
|
||||||
|
|
||||||
async def get_task_list(
|
async def get_task_list(
|
||||||
|
|
@ -1373,7 +1452,7 @@ class Memory:
|
||||||
async def main():
|
async def main():
|
||||||
|
|
||||||
# if you want to run the script as a standalone script, do so with the examples below
|
# if you want to run the script as a standalone script, do so with the examples below
|
||||||
memory = Memory(user_id="123")
|
memory = Memory(user_id="TestUser")
|
||||||
await memory.async_init()
|
await memory.async_init()
|
||||||
params = {
|
params = {
|
||||||
"version": "1.0",
|
"version": "1.0",
|
||||||
|
|
@ -1396,9 +1475,10 @@ async def main():
|
||||||
# load_jack_london = await memory._add_semantic_memory(observation = "bla", loader_settings=loader_settings, params=params)
|
# load_jack_london = await memory._add_semantic_memory(observation = "bla", loader_settings=loader_settings, params=params)
|
||||||
# print(load_jack_london)
|
# print(load_jack_london)
|
||||||
|
|
||||||
modulator = {"relevance": 0.0, "saliency": 0.0, "frequency": 0.0}
|
modulator = {"relevance": 0.1, "frequency": 0.1}
|
||||||
|
# await memory._delete_episodic_memory()
|
||||||
run_main_buffer = await memory._run_main_buffer(
|
#
|
||||||
|
run_main_buffer = await memory._create_buffer_context(
|
||||||
user_input="I want to know how does Buck adapt to life in the wild and then have that info translated to german ",
|
user_input="I want to know how does Buck adapt to life in the wild and then have that info translated to german ",
|
||||||
params=params,
|
params=params,
|
||||||
attention_modulators=modulator,
|
attention_modulators=modulator,
|
||||||
|
|
|
||||||
0
level_2/modulators/modulators.py
Normal file
0
level_2/modulators/modulators.py
Normal file
File diff suppressed because one or more lines are too long
Loading…
Add table
Reference in a new issue