Merge pull request #28 from topoteretes/add_async_elements

Improve classifier, add turn output data to json
This commit is contained in:
Vasilije 2023-10-24 12:02:51 +02:00 committed by GitHub
commit 4e5d7ae8e4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 127 additions and 39 deletions

View file

@ -9,7 +9,7 @@ from database.database import Base
from datetime import datetime
from sqlalchemy import Column, String, DateTime, ForeignKey
from sqlalchemy import Column, String, DateTime, ForeignKey, JSON
from sqlalchemy.orm import relationship
import os
import sys
@ -27,7 +27,7 @@ class TestOutput(Base):
user_id = Column(String, ForeignKey('users.id'), index=True) # Added user_id field
test_set_id = Column(String, ForeignKey('test_sets.id'), index=True)
operation_id = Column(String, ForeignKey('operations.id'), index=True)
test_results = Column(String, nullable=True)
test_results = Column(JSON, nullable=True)
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, onupdate=datetime.utcnow)

View file

@ -216,7 +216,24 @@ async def eval_test(query=None, output=None, expected_output=None, context=None,
# If you want to run the test
test_result = run_test(test_case, metrics=[metric], raise_error=False)
return test_result
def test_result_to_dict(test_result):
return {
"success": test_result.success,
"score": test_result.score,
"metric_name": test_result.metric_name,
"query": test_result.query,
"output": test_result.output,
"expected_output": test_result.expected_output,
"metadata": test_result.metadata,
"context": test_result.context
}
test_result_dict =[]
for test in test_result:
test_result_it = test_result_to_dict(test)
test_result_dict.append(test_result_it)
return test_result_dict
# You can also inspect the test result class
# print(test_result)
@ -241,7 +258,7 @@ def data_format_route( data_string: str):
def data_location_route(data_string: str):
@ai_classifier
class LocationRoute(Enum):
"""Represents classifier for the data location"""
"""Represents classifier for the data location, if it is device, or database connections string or URL """
DEVICE = "file_path_starting_with_.data_or_containing_it"
# URL = "url starting with http or https"
@ -273,6 +290,9 @@ async def start_test(data, test_set=None, user_id=None, params=None, job_id=None
job_id = await fetch_job_id(session, user_id=user_id, job_id=job_id)
test_set_id = await fetch_test_set_id(session, user_id=user_id, id=job_id)
memory = await Memory.create_memory(user_id, session, namespace="SEMANTICMEMORY")
await memory.add_memory_instance("ExampleMemory")
existing_user = await Memory.check_existing_user(user_id, session)
if job_id is None:
job_id = str(uuid.uuid4())
@ -303,9 +323,6 @@ async def start_test(data, test_set=None, user_id=None, params=None, job_id=None
if test_id is None:
test_id = str(generate_letter_uuid()) + "_" +"SEMANTICMEMORY"
memory = await Memory.create_memory(user_id, session, namespace="SEMANTICMEMORY")
await memory.add_memory_instance("ExampleMemory")
existing_user = await Memory.check_existing_user(user_id, session)
await memory.manage_memory_attributes(existing_user)
test_class = test_id + "_class"
await memory.add_dynamic_memory_class(test_id.lower(), test_id)
@ -378,16 +395,21 @@ async def start_test(data, test_set=None, user_id=None, params=None, job_id=None
await memory.dynamic_method_call(dynamic_memory_class, 'delete_memories',
namespace=test_id)
return test_eval_pipeline
return test_id, test_eval_pipeline
results = []
if only_llm_context:
result = await run_test(test= None, loader_settings=loader_settings, metadata=metadata, only_llm_context=only_llm_context)
test_id, result = await run_test(test=None, loader_settings=loader_settings, metadata=metadata,
only_llm_context=only_llm_context)
results.append(result)
for param in test_params:
result = await run_test(param, loader_settings, metadata,only_llm_context=only_llm_context)
test_id, result = await run_test(param, loader_settings, metadata, only_llm_context=only_llm_context)
results.append(result)
await add_entity(session, TestOutput(id=test_id, user_id=user_id, test_results=str(json.dumps(results))))
print(results)
return results
@ -432,7 +454,7 @@ async def main():
]
# "https://www.ibiblio.org/ebooks/London/Call%20of%20Wild.pdf"
#http://public-library.uk/ebooks/59/83.pdf
result = await start_test(".data/3ZCCCW.pdf", test_set=test_set, user_id="676", params=None, metadata=metadata)
result = await start_test(".data/3ZCCCW.pdf", test_set=test_set, user_id="677", params=None, metadata=metadata)
#
# parser = argparse.ArgumentParser(description="Run tests against a document.")
# parser.add_argument("--url", required=True, help="URL of the document to test.")

View file

@ -40,34 +40,28 @@ import json
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
marvin.settings.openai.api_key = os.environ.get("OPENAI_API_KEY")
LTM_MEMORY_ID_DEFAULT = "00000"
ST_MEMORY_ID_DEFAULT = "0000"
BUFFER_ID_DEFAULT = "0000"
class VectorDBFactory:
def __init__(self):
self.db_map = {
"pinecone": PineconeVectorDB,
"weaviate": WeaviateVectorDB,
# Add more database types and their corresponding classes here
}
def create_vector_db(
self,
user_id: str,
index_name: str,
memory_id: str,
ltm_memory_id: str = LTM_MEMORY_ID_DEFAULT,
st_memory_id: str = ST_MEMORY_ID_DEFAULT,
buffer_id: str = BUFFER_ID_DEFAULT,
db_type: str = "pinecone",
db_type: str = "weaviate",
namespace: str = None,
embeddings = None,
embeddings=None,
):
db_map = {"pinecone": PineconeVectorDB, "weaviate": WeaviateVectorDB}
if db_type in db_map:
return db_map[db_type](
if db_type in self.db_map:
return self.db_map[db_type](
user_id,
index_name,
memory_id,
ltm_memory_id,
st_memory_id,
buffer_id,
namespace,
embeddings
)
@ -101,8 +95,61 @@ class BaseMemory:
)
def init_client(self, embeddings, namespace: str):
return self.vector_db.init_client(embeddings, namespace)
return self.vector_db.init_weaviate_client(embeddings, namespace)
# class VectorDBFactory:
# def create_vector_db(
# self,
# user_id: str,
# index_name: str,
# memory_id: str,
# db_type: str = "pinecone",
# namespace: str = None,
# embeddings = None,
# ):
# db_map = {"pinecone": PineconeVectorDB, "weaviate": WeaviateVectorDB}
#
# if db_type in db_map:
# return db_map[db_type](
# user_id,
# index_name,
# memory_id,
# namespace,
# embeddings
# )
#
# raise ValueError(f"Unsupported database type: {db_type}")
#
# class BaseMemory:
# def __init__(
# self,
# user_id: str,
# memory_id: Optional[str],
# index_name: Optional[str],
# db_type: str,
# namespace: str,
# embeddings: Optional[None],
# ):
# self.user_id = user_id
# self.memory_id = memory_id
# self.index_name = index_name
# self.namespace = namespace
# self.embeddings = embeddings
# self.db_type = db_type
# factory = VectorDBFactory()
# self.vector_db = factory.create_vector_db(
# self.user_id,
# self.index_name,
# self.memory_id,
# db_type=self.db_type,
# namespace=self.namespace,
# embeddings=self.embeddings
# )
#
# def init_client(self, embeddings, namespace: str):
#
# return self.vector_db.init_weaviate_client(embeddings, namespace)
def create_field(self, field_type, **kwargs):
field_mapping = {

View file

@ -32,9 +32,6 @@ class VectorDB:
user_id: str,
index_name: str,
memory_id: str,
ltm_memory_id: str = LTM_MEMORY_ID_DEFAULT,
st_memory_id: str = ST_MEMORY_ID_DEFAULT,
buffer_id: str = BUFFER_ID_DEFAULT,
namespace: str = None,
embeddings = None,
):
@ -42,9 +39,6 @@ class VectorDB:
self.index_name = index_name
self.namespace = namespace
self.memory_id = memory_id
self.ltm_memory_id = ltm_memory_id
self.st_memory_id = st_memory_id
self.buffer_id = buffer_id
self.embeddings = embeddings
class PineconeVectorDB(VectorDB):
@ -81,7 +75,7 @@ class WeaviateVectorDB(VectorDB):
embedding=embeddings,
create_schema_if_missing=True,
)
return retriever # If this is part of the initialization, call it here.
return retriever
def init_weaviate_client(self, namespace: str):
# Weaviate client initialization logic
@ -95,6 +89,34 @@ class WeaviateVectorDB(VectorDB):
)
return client
from marshmallow import Schema, fields
def create_document_structure(observation, params, metadata_schema_class=None):
    """
    Build and schema-validate a single-document structure.

    :param observation: Content of the document (its page text).
    :param params: Metadata mapping attached to the document.
    :param metadata_schema_class: Optional marshmallow schema class used to
        validate the nested metadata payload.
    :return: One-element list containing the validated document data.
    """
    # Schema is declared inline so the nested metadata field can bind to the
    # caller-supplied metadata_schema_class.
    class DynamicDocumentSchema(Schema):
        metadata = fields.Nested(metadata_schema_class, required=True)
        page_content = fields.Str(required=True)

    payload = {
        "metadata": params,
        "page_content": observation,
    }
    # load() both validates and deserializes; marshmallow raises
    # ValidationError on schema mismatch.
    validated_document = DynamicDocumentSchema().load(payload)
    return [validated_document]
def _stuct(self, observation, params, metadata_schema_class =None):
"""Utility function to create the document structure with optional custom fields."""
@ -267,9 +289,6 @@ class WeaviateVectorDB(VectorDB):
data_object={
# "text": observation,
"user_id": str(self.user_id),
"memory_id": str(self.memory_id),
"ltm_memory_id": str(self.ltm_memory_id),
"st_memory_id": str(self.st_memory_id),
"buffer_id": str(self.buffer_id),
"version": params.get("version", None) or "",
"agreement_id": params.get("agreement_id", None) or "",