fix the demo logic

This commit is contained in:
Vasilije 2023-08-25 11:09:13 +02:00
parent 59b1c54cb8
commit 32d3bd026a

View file

@ -59,6 +59,8 @@ from langchain.vectorstores import Weaviate
import weaviate
import uuid
import humanize
import pinecone
import weaviate
load_dotenv()
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
@ -88,8 +90,7 @@ class MyCustomAsyncHandler(AsyncCallbackHandler):
print("Hi! I just woke up. Your llm is ending")
import pinecone
import weaviate
# Assuming OpenAIEmbeddings and other necessary imports are available
@ -306,7 +307,7 @@ class WeaviateVectorDB(VectorDB):
)
else:
# Delete all objects
print("HERE IS THE USER ID", self.user_id)
return client.batch.delete_objects(
class_name=self.namespace,
where={
@ -386,9 +387,9 @@ class EpisodicMemory(BaseMemory):
super().__init__(user_id, memory_id, index_name, db_type, namespace="EPISODICMEMORY")
class EpisodicBuffer(EpisodicMemory):
class EpisodicBuffer(BaseMemory):
def __init__(self, user_id: str, memory_id: Optional[str], index_name: Optional[str], db_type: str = "weaviate"):
super().__init__(user_id, memory_id, index_name, db_type)
super().__init__(user_id, memory_id, index_name, db_type, namespace="BUFFERMEMORY")
self.st_memory_id = "blah"
self.llm = ChatOpenAI(
@ -480,10 +481,7 @@ class EpisodicBuffer(EpisodicMemory):
async def encoding(self, document: str, namespace: str = "EPISODICBUFFER", params: dict = None) -> list[str]:
"""Encoding for the buffer, stores raw data in the buffer
Note, this is not comp-sci encoding, but rather encoding in the sense of storing the content in the buffer"""
vector_db = VectorDB(user_id=self.user_id, memory_id=self.memory_id, st_memory_id=self.st_memory_id,
index_name=self.index_name, db_type=self.db_type, namespace=namespace)
query = await vector_db.add_memories(document, params=params)
query = await self.add_memories(document, params=params)
return query
async def available_operations(self) -> list[str]:
@ -502,7 +500,13 @@ class EpisodicBuffer(EpisodicMemory):
memory = Memory(user_id=self.user_id)
await memory.async_init()
await memory._delete_buffer_memory()
try:
# we delete all memories in the episodic buffer, so we can start fresh
await self.delete_memories()
except:
# in case there are no memories, we pass
pass
# we just filter the data here to make sure input is clean
prompt_filter = ChatPromptTemplate.from_template(
@ -535,7 +539,7 @@ class EpisodicBuffer(EpisodicMemory):
lookup_value_episodic = await memory._fetch_episodic_memory(observation=str(output))
lookup_value_semantic = await memory._fetch_semantic_memory(observation=str(output))
lookup_value_buffer = await self._fetch_memories(observation=str(output), namespace=self.namespace)
lookup_value_buffer = await self.fetch_memories(observation=str(output))
context.append(lookup_value_buffer)
context.append(lookup_value_semantic)
@ -685,7 +689,7 @@ class EpisodicBuffer(EpisodicMemory):
await self.encoding(str(result_tasks), self.namespace, params=params)
buffer_result = await self._fetch_memories(observation=str(output), namespace=self.namespace)
buffer_result = await self.fetch_memories(observation=str(output))
print("HERE IS THE RESULT TASKS", str(buffer_result))
@ -722,13 +726,11 @@ class EpisodicBuffer(EpisodicMemory):
result_parsing = parser.parse(output)
print("here is the parsing result", result_parsing)
memory = Memory(user_id=self.user_id)
await memory.async_init()
#
lookup_value = await memory._add_episodic_memory(observation=str(output), params=params)
# now we clean up buffer memory
await memory._delete_buffer_memory()
await self.delete_memories()
return lookup_value
@ -782,23 +784,10 @@ class Memory:
async def async_create_long_term_memory(self, user_id, memory_id, index_name, db_type):
# Perform asynchronous initialization steps if needed
return LongTermMemory(
user_id=user_id, memory_id=memory_id, index_name=index_name,
db_type=db_type
)
async def async_init(self):
# Asynchronous initialization of LongTermMemory and ShortTermMemory
self.long_term_memory = await self.async_create_long_term_memory(
user_id=self.user_id, memory_id=self.memory_id, index_name=self.index_name,
db_type=self.db_type
)
async def async_create_short_term_memory(self, user_id, memory_id, index_name, db_type):
# Perform asynchronous initialization steps if needed
return ShortTermMemory(
user_id=user_id, memory_id=memory_id, index_name=index_name, db_type=db_type
)
async def async_init(self):
# Asynchronous initialization of LongTermMemory and ShortTermMemory
self.long_term_memory = await self.async_create_long_term_memory(
@ -809,40 +798,48 @@ class Memory:
user_id=self.user_id, memory_id=self.memory_id, index_name=self.index_name,
db_type=self.db_type
)
async def async_create_short_term_memory(self, user_id, memory_id, index_name, db_type):
# Perform asynchronous initialization steps if needed
return ShortTermMemory(
user_id=self.user_id, memory_id=self.memory_id, index_name=self.index_name, db_type=self.db_type
)
# self.short_term_memory = await ShortTermMemory.async_init(
# user_id=self.user_id, memory_id=self.memory_id, index_name=self.index_name, db_type=self.db_type
# )
async def _add_semantic_memory(self, semantic_memory: str, params: dict = None):
return await self.long_term_memory.semantic_memory._add_memories(
return await self.long_term_memory.semantic_memory.add_memories(
semantic_memory=semantic_memory, params=params
)
async def _fetch_semantic_memory(self, observation, params):
return await self.long_term_memory.semantic_memory._fetch_memories(
return await self.long_term_memory.semantic_memory.fetch_memories(
observation=observation, params=params
)
async def _delete_semantic_memory(self, params: str = None):
return await self.long_term_memory.semantic_memory._delete_memories(
return await self.long_term_memory.semantic_memory.delete_memories(
params=params
)
async def _add_episodic_memory(self, observation: str, params: dict = None):
return await self.long_term_memory.episodic_memory._add_memories(
return await self.long_term_memory.episodic_memory.add_memories(
observation=observation, params=params
)
async def _fetch_episodic_memory(self, observation, params: str = None):
return await self.long_term_memory.episodic_memory._fetch_memories(
return await self.long_term_memory.episodic_memory.fetch_memories(
observation=observation, params=params
)
async def _delete_episodic_memory(self, params: str = None):
return await self.long_term_memory.episodic_memory._delete_memories(
return await self.long_term_memory.episodic_memory.delete_memories(
params=params
)
@ -850,18 +847,18 @@ class Memory:
return await self.short_term_memory.episodic_buffer.main_buffer(user_input=user_input, content=content, params=params)
async def _add_buffer_memory(self, user_input: str, namespace: str = None, params: dict = None):
return await self.short_term_memory.episodic_buffer._add_memories(observation=user_input, namespace=namespace,
return await self.short_term_memory.episodic_buffer.add_memories(observation=user_input,
params=params)
async def _fetch_buffer_memory(self, user_input: str, namespace: str = None):
return await self.short_term_memory.episodic_buffer._fetch_memories(observation=user_input, namespace=namespace)
return await self.short_term_memory.episodic_buffer.fetch_memories(observation=user_input)
async def _delete_buffer_memory(self, params: str = None):
return await self.short_term_memory.episodic_buffer._delete_memories(
return await self.short_term_memory.episodic_buffer.delete_memories(
params=params
)
async def _available_operations(self):
return await self.long_term_memory.episodic_buffer._available_operations()
return await self.long_term_memory.episodic_buffer.available_operations()
@ -886,7 +883,7 @@ async def main():
gg = await memory._run_buffer(user_input="i NEED TRANSLATION TO GERMAN ", content="i NEED TRANSLATION TO GERMAN ", params=params)
print(gg)
# gg = await memory._delete_buffer_memory()
# gg = await memory._fetch_buffer_memory(user_input="i TO GERMAN ")
# print(gg)
episodic = """{
@ -909,7 +906,9 @@ async def main():
]
}"""
#
# ggur = await memory._add_episodic_memory(observation = episodic, params=params)
# ggur = await memory._delete_buffer_memory()
# print(ggur)
# ggur = await memory._add_buffer_memory(user_input = episodic, params=params)
# print(ggur)
# fff = await memory._fetch_episodic_memory(observation = "healthy diet")