Merge pull request #15 from topoteretes/added_backend_pg_and_dynamic_memory_classes

Updated and tested retry logic, still more to be done
This commit is contained in:
Vasilije 2023-10-08 22:17:43 +02:00 committed by GitHub
commit a1b322e7bd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 76 additions and 11 deletions

View file

@ -18,7 +18,6 @@ from typing import Optional, Dict, List, Union
import tracemalloc
tracemalloc.start()
import os
from datetime import datetime

View file

@ -217,15 +217,15 @@ async def eval_test(query=None, output=None, expected_output=None, context=None)
output=result_output,
expected_output=expected_output,
context=context,
)
metric = OverallScoreMetric()
# if you want to make sure that the test returns an error
assert_test(test_case, metrics=[metric])
# If you want to run the test
test_result = run_test(test_case, metrics=[metric])
test_result = run_test(test_case, metrics=[metric], raise_error=False)
return test_result
# You can also inspect the test result class
print(test_result)
# print(test_result)
@ -365,18 +365,28 @@ async def start_test(data, test_set=None, user_id=None, params=None, job_id=None
test_result_colletion =[]
test_result_collection =[]
for test in test_set:
retrieve_action = await memory.dynamic_method_call(dynamic_memory_class, 'fetch_memories',
observation=test["question"])
test_results = await eval_test( query=test["question"], expected_output=test["answer"], context= str(retrieve_action))
test_result_colletion.append(test_results)
test_result_collection.append(test_results)
print(test_results)
if dynamic_memory_class is not None:
memory.add_method_to_class(dynamic_memory_class, 'delete_memories')
else:
print(f"No attribute named {test_class.lower()} in memory.")
load_action = await memory.dynamic_method_call(dynamic_memory_class, 'delete_memories',
namespace ='some_observation', params=metadata,
loader_settings=loader_settings)
memory.delete_memories(namespace=test_id)
add_entity(session, TestOutput(id=test_id, user_id=user_id, content=str(test_result_colletion)))
print(test_result_collection)
add_entity(session, TestOutput(id=test_id, user_id=user_id, content=str(test_result_collection)))
async def main():

View file

@ -196,7 +196,9 @@ class WeaviateVectorDB(VectorDB):
return query_output
async def delete_memories(self, params: dict = None):
async def delete_memories(self, namespace:str, params: dict = None):
if namespace is None:
namespace = self.namespace
client = self.init_weaviate_client(self.namespace)
if params:
where_filter = {
@ -213,7 +215,7 @@ class WeaviateVectorDB(VectorDB):
# Delete all objects
print("HERE IS THE USER ID", self.user_id)
return client.batch.delete_objects(
class_name=self.namespace,
class_name=namespace,
where={
"path": ["user_id"],
"operator": "Equal",

View file

@ -39,15 +39,51 @@ class DynamicBaseMemory(BaseMemory):
self.associations = []
def add_method(self, method_name):
"""
Add a method to the memory class.
Args:
- method_name (str): The name of the method to be added.
Returns:
None
"""
self.methods.add(method_name)
def add_attribute(self, attribute_name):
"""
Add an attribute to the memory class.
Args:
- attribute_name (str): The name of the attribute to be added.
Returns:
None
"""
self.attributes.add(attribute_name)
def get_attribute(self, attribute_name):
"""
Check if the attribute is in the memory class.
Args:
- attribute_name (str): The name of the attribute to be checked.
Returns:
bool: True if attribute exists, False otherwise.
"""
return attribute_name in self.attributes
def add_association(self, associated_memory):
"""
Add an association to another memory class.
Args:
- associated_memory (MemoryClass): The memory class to be associated with.
Returns:
None
"""
if associated_memory not in self.associations:
self.associations.append(associated_memory)
# Optionally, establish a bidirectional association
@ -55,10 +91,28 @@ class DynamicBaseMemory(BaseMemory):
class Attribute:
def __init__(self, name):
"""
Initialize the Attribute class.
Args:
- name (str): The name of the attribute.
Attributes:
- name (str): Stores the name of the attribute.
"""
self.name = name
class Method:
def __init__(self, name):
"""
Initialize the Method class.
Args:
- name (str): The name of the method.
Attributes:
- name (str): Stores the name of the method.
"""
self.name = name
@ -176,7 +230,7 @@ class Memory:
else:
# Define default methods for a new user
methods_list = [
'async_create_long_term_memory', 'async_init', 'add_memories', "fetch_memories",
'async_create_long_term_memory', 'async_init', 'add_memories', "fetch_memories", "delete_memories",
'async_create_short_term_memory', '_create_buffer_context', '_get_task_list',
'_run_main_buffer', '_available_operations', '_provide_feedback'
]