Merge pull request #28 from topoteretes/add_async_elements

Improve classifier, add turn output data to json
This commit is contained in:
Vasilije 2023-10-24 12:02:51 +02:00 committed by GitHub
commit 4e5d7ae8e4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 127 additions and 39 deletions

View file

@ -9,7 +9,7 @@ from database.database import Base
from datetime import datetime from datetime import datetime
from sqlalchemy import Column, String, DateTime, ForeignKey from sqlalchemy import Column, String, DateTime, ForeignKey, JSON
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
import os import os
import sys import sys
@ -27,7 +27,7 @@ class TestOutput(Base):
user_id = Column(String, ForeignKey('users.id'), index=True) # Added user_id field user_id = Column(String, ForeignKey('users.id'), index=True) # Added user_id field
test_set_id = Column(String, ForeignKey('test_sets.id'), index=True) test_set_id = Column(String, ForeignKey('test_sets.id'), index=True)
operation_id = Column(String, ForeignKey('operations.id'), index=True) operation_id = Column(String, ForeignKey('operations.id'), index=True)
test_results = Column(String, nullable=True) test_results = Column(JSON, nullable=True)
created_at = Column(DateTime, default=datetime.utcnow) created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, onupdate=datetime.utcnow) updated_at = Column(DateTime, onupdate=datetime.utcnow)

View file

@ -216,7 +216,24 @@ async def eval_test(query=None, output=None, expected_output=None, context=None,
# If you want to run the test # If you want to run the test
test_result = run_test(test_case, metrics=[metric], raise_error=False) test_result = run_test(test_case, metrics=[metric], raise_error=False)
return test_result
def test_result_to_dict(test_result):
return {
"success": test_result.success,
"score": test_result.score,
"metric_name": test_result.metric_name,
"query": test_result.query,
"output": test_result.output,
"expected_output": test_result.expected_output,
"metadata": test_result.metadata,
"context": test_result.context
}
test_result_dict =[]
for test in test_result:
test_result_it = test_result_to_dict(test)
test_result_dict.append(test_result_it)
return test_result_dict
# You can also inspect the test result class # You can also inspect the test result class
# print(test_result) # print(test_result)
@ -241,7 +258,7 @@ def data_format_route( data_string: str):
def data_location_route(data_string: str): def data_location_route(data_string: str):
@ai_classifier @ai_classifier
class LocationRoute(Enum): class LocationRoute(Enum):
"""Represents classifier for the data location""" """Represents classifier for the data location, if it is device, or database connections string or URL """
DEVICE = "file_path_starting_with_.data_or_containing_it" DEVICE = "file_path_starting_with_.data_or_containing_it"
# URL = "url starting with http or https" # URL = "url starting with http or https"
@ -273,6 +290,9 @@ async def start_test(data, test_set=None, user_id=None, params=None, job_id=None
job_id = await fetch_job_id(session, user_id=user_id, job_id=job_id) job_id = await fetch_job_id(session, user_id=user_id, job_id=job_id)
test_set_id = await fetch_test_set_id(session, user_id=user_id, id=job_id) test_set_id = await fetch_test_set_id(session, user_id=user_id, id=job_id)
memory = await Memory.create_memory(user_id, session, namespace="SEMANTICMEMORY")
await memory.add_memory_instance("ExampleMemory")
existing_user = await Memory.check_existing_user(user_id, session)
if job_id is None: if job_id is None:
job_id = str(uuid.uuid4()) job_id = str(uuid.uuid4())
@ -303,9 +323,6 @@ async def start_test(data, test_set=None, user_id=None, params=None, job_id=None
if test_id is None: if test_id is None:
test_id = str(generate_letter_uuid()) + "_" +"SEMANTICMEMORY" test_id = str(generate_letter_uuid()) + "_" +"SEMANTICMEMORY"
memory = await Memory.create_memory(user_id, session, namespace="SEMANTICMEMORY")
await memory.add_memory_instance("ExampleMemory")
existing_user = await Memory.check_existing_user(user_id, session)
await memory.manage_memory_attributes(existing_user) await memory.manage_memory_attributes(existing_user)
test_class = test_id + "_class" test_class = test_id + "_class"
await memory.add_dynamic_memory_class(test_id.lower(), test_id) await memory.add_dynamic_memory_class(test_id.lower(), test_id)
@ -378,16 +395,21 @@ async def start_test(data, test_set=None, user_id=None, params=None, job_id=None
await memory.dynamic_method_call(dynamic_memory_class, 'delete_memories', await memory.dynamic_method_call(dynamic_memory_class, 'delete_memories',
namespace=test_id) namespace=test_id)
return test_eval_pipeline return test_id, test_eval_pipeline
results = [] results = []
if only_llm_context: if only_llm_context:
result = await run_test(test= None, loader_settings=loader_settings, metadata=metadata, only_llm_context=only_llm_context) test_id, result = await run_test(test=None, loader_settings=loader_settings, metadata=metadata,
only_llm_context=only_llm_context)
results.append(result) results.append(result)
for param in test_params: for param in test_params:
result = await run_test(param, loader_settings, metadata,only_llm_context=only_llm_context) test_id, result = await run_test(param, loader_settings, metadata, only_llm_context=only_llm_context)
results.append(result) results.append(result)
await add_entity(session, TestOutput(id=test_id, user_id=user_id, test_results=str(json.dumps(results))))
print(results) print(results)
return results return results
@ -432,7 +454,7 @@ async def main():
] ]
# "https://www.ibiblio.org/ebooks/London/Call%20of%20Wild.pdf" # "https://www.ibiblio.org/ebooks/London/Call%20of%20Wild.pdf"
#http://public-library.uk/ebooks/59/83.pdf #http://public-library.uk/ebooks/59/83.pdf
result = await start_test(".data/3ZCCCW.pdf", test_set=test_set, user_id="676", params=None, metadata=metadata) result = await start_test(".data/3ZCCCW.pdf", test_set=test_set, user_id="677", params=None, metadata=metadata)
# #
# parser = argparse.ArgumentParser(description="Run tests against a document.") # parser = argparse.ArgumentParser(description="Run tests against a document.")
# parser.add_argument("--url", required=True, help="URL of the document to test.") # parser.add_argument("--url", required=True, help="URL of the document to test.")

View file

@ -40,34 +40,28 @@ import json
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "") OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
marvin.settings.openai.api_key = os.environ.get("OPENAI_API_KEY") marvin.settings.openai.api_key = os.environ.get("OPENAI_API_KEY")
LTM_MEMORY_ID_DEFAULT = "00000"
ST_MEMORY_ID_DEFAULT = "0000"
BUFFER_ID_DEFAULT = "0000"
class VectorDBFactory: class VectorDBFactory:
def __init__(self):
self.db_map = {
"pinecone": PineconeVectorDB,
"weaviate": WeaviateVectorDB,
# Add more database types and their corresponding classes here
}
def create_vector_db( def create_vector_db(
self, self,
user_id: str, user_id: str,
index_name: str, index_name: str,
memory_id: str, memory_id: str,
ltm_memory_id: str = LTM_MEMORY_ID_DEFAULT, db_type: str = "weaviate",
st_memory_id: str = ST_MEMORY_ID_DEFAULT,
buffer_id: str = BUFFER_ID_DEFAULT,
db_type: str = "pinecone",
namespace: str = None, namespace: str = None,
embeddings = None, embeddings=None,
): ):
db_map = {"pinecone": PineconeVectorDB, "weaviate": WeaviateVectorDB} if db_type in self.db_map:
return self.db_map[db_type](
if db_type in db_map:
return db_map[db_type](
user_id, user_id,
index_name, index_name,
memory_id, memory_id,
ltm_memory_id,
st_memory_id,
buffer_id,
namespace, namespace,
embeddings embeddings
) )
@ -101,8 +95,61 @@ class BaseMemory:
) )
def init_client(self, embeddings, namespace: str): def init_client(self, embeddings, namespace: str):
return self.vector_db.init_client(embeddings, namespace)
return self.vector_db.init_weaviate_client(embeddings, namespace)
# class VectorDBFactory:
# def create_vector_db(
# self,
# user_id: str,
# index_name: str,
# memory_id: str,
# db_type: str = "pinecone",
# namespace: str = None,
# embeddings = None,
# ):
# db_map = {"pinecone": PineconeVectorDB, "weaviate": WeaviateVectorDB}
#
# if db_type in db_map:
# return db_map[db_type](
# user_id,
# index_name,
# memory_id,
# namespace,
# embeddings
# )
#
# raise ValueError(f"Unsupported database type: {db_type}")
#
# class BaseMemory:
# def __init__(
# self,
# user_id: str,
# memory_id: Optional[str],
# index_name: Optional[str],
# db_type: str,
# namespace: str,
# embeddings: Optional[None],
# ):
# self.user_id = user_id
# self.memory_id = memory_id
# self.index_name = index_name
# self.namespace = namespace
# self.embeddings = embeddings
# self.db_type = db_type
# factory = VectorDBFactory()
# self.vector_db = factory.create_vector_db(
# self.user_id,
# self.index_name,
# self.memory_id,
# db_type=self.db_type,
# namespace=self.namespace,
# embeddings=self.embeddings
# )
#
# def init_client(self, embeddings, namespace: str):
#
# return self.vector_db.init_weaviate_client(embeddings, namespace)
def create_field(self, field_type, **kwargs): def create_field(self, field_type, **kwargs):
field_mapping = { field_mapping = {

View file

@ -32,9 +32,6 @@ class VectorDB:
user_id: str, user_id: str,
index_name: str, index_name: str,
memory_id: str, memory_id: str,
ltm_memory_id: str = LTM_MEMORY_ID_DEFAULT,
st_memory_id: str = ST_MEMORY_ID_DEFAULT,
buffer_id: str = BUFFER_ID_DEFAULT,
namespace: str = None, namespace: str = None,
embeddings = None, embeddings = None,
): ):
@ -42,9 +39,6 @@ class VectorDB:
self.index_name = index_name self.index_name = index_name
self.namespace = namespace self.namespace = namespace
self.memory_id = memory_id self.memory_id = memory_id
self.ltm_memory_id = ltm_memory_id
self.st_memory_id = st_memory_id
self.buffer_id = buffer_id
self.embeddings = embeddings self.embeddings = embeddings
class PineconeVectorDB(VectorDB): class PineconeVectorDB(VectorDB):
@ -81,7 +75,7 @@ class WeaviateVectorDB(VectorDB):
embedding=embeddings, embedding=embeddings,
create_schema_if_missing=True, create_schema_if_missing=True,
) )
return retriever # If this is part of the initialization, call it here. return retriever
def init_weaviate_client(self, namespace: str): def init_weaviate_client(self, namespace: str):
# Weaviate client initialization logic # Weaviate client initialization logic
@ -95,6 +89,34 @@ class WeaviateVectorDB(VectorDB):
) )
return client return client
from marshmallow import Schema, fields
def create_document_structure(observation, params, metadata_schema_class=None):
"""
Create and validate a document structure with optional custom fields.
:param observation: Content of the document.
:param params: Metadata information.
:param metadata_schema_class: Custom metadata schema class (optional).
:return: A list containing the validated document data.
"""
document_data = {
"metadata": params,
"page_content": observation
}
def get_document_schema():
class DynamicDocumentSchema(Schema):
metadata = fields.Nested(metadata_schema_class, required=True)
page_content = fields.Str(required=True)
return DynamicDocumentSchema
# Validate and deserialize, defaulting to "1.0" if not provided
CurrentDocumentSchema = get_document_schema()
loaded_document = CurrentDocumentSchema().load(document_data)
return [loaded_document]
def _stuct(self, observation, params, metadata_schema_class =None): def _stuct(self, observation, params, metadata_schema_class =None):
"""Utility function to create the document structure with optional custom fields.""" """Utility function to create the document structure with optional custom fields."""
@ -267,9 +289,6 @@ class WeaviateVectorDB(VectorDB):
data_object={ data_object={
# "text": observation, # "text": observation,
"user_id": str(self.user_id), "user_id": str(self.user_id),
"memory_id": str(self.memory_id),
"ltm_memory_id": str(self.ltm_memory_id),
"st_memory_id": str(self.st_memory_id),
"buffer_id": str(self.buffer_id), "buffer_id": str(self.buffer_id),
"version": params.get("version", None) or "", "version": params.get("version", None) or "",
"agreement_id": params.get("agreement_id", None) or "", "agreement_id": params.get("agreement_id", None) or "",