Merge pull request #28 from topoteretes/add_async_elements
Improve classifier, add turn output data to json
Commit 4e5d7ae8e4
4 changed files with 127 additions and 39 deletions
@@ -9,7 +9,7 @@ from database.database import Base
 from datetime import datetime
-from sqlalchemy import Column, String, DateTime, ForeignKey
+from sqlalchemy import Column, String, DateTime, ForeignKey, JSON
 from sqlalchemy.orm import relationship
 import os
 import sys
@@ -27,7 +27,7 @@ class TestOutput(Base):
     user_id = Column(String, ForeignKey('users.id'), index=True)  # Added user_id field
     test_set_id = Column(String, ForeignKey('test_sets.id'), index=True)
     operation_id = Column(String, ForeignKey('operations.id'), index=True)
-    test_results = Column(String, nullable=True)
+    test_results = Column(JSON, nullable=True)
     created_at = Column(DateTime, default=datetime.utcnow)
     updated_at = Column(DateTime, onupdate=datetime.utcnow)
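Note on the column change: with `test_results` as a SQLAlchemy JSON column, plain Python lists and dicts can be assigned and read back directly. A minimal read-back sketch (the query shape is illustrative; `TestOutput` is the model above, and the defensive `json.loads` covers the `str(json.dumps(...))` double encoding that `start_test` still performs later in this diff):

import json
from sqlalchemy.orm import Session

def load_test_results(session: Session, user_id: str):
    # Illustrative query; a JSON column normally returns Python objects directly.
    row = session.query(TestOutput).filter_by(user_id=user_id).first()
    results = row.test_results
    if isinstance(results, str):  # payload stored via str(json.dumps(...)) is still a string
        results = json.loads(results)
    return results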
@@ -216,7 +216,24 @@ async def eval_test(query=None, output=None, expected_output=None, context=None,

     # If you want to run the test
     test_result = run_test(test_case, metrics=[metric], raise_error=False)
-    return test_result
+
+    def test_result_to_dict(test_result):
+        return {
+            "success": test_result.success,
+            "score": test_result.score,
+            "metric_name": test_result.metric_name,
+            "query": test_result.query,
+            "output": test_result.output,
+            "expected_output": test_result.expected_output,
+            "metadata": test_result.metadata,
+            "context": test_result.context
+        }
+
+    test_result_dict =[]
+    for test in test_result:
+        test_result_it = test_result_to_dict(test)
+        test_result_dict.append(test_result_it)
+    return test_result_dict
     # You can also inspect the test result class
     # print(test_result)

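Since `eval_test` now returns a list of plain dicts instead of raw result objects, callers can serialize a run without custom encoders. A quick sketch (argument values are placeholders):

import asyncio
import json

async def demo():
    # eval_test is the coroutine above; each element is a dict of plain str/float/bool values
    results = await eval_test(query="q", output="o", expected_output="e", context="c")
    print(json.dumps(results, indent=2))

asyncio.run(demo())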
@@ -241,7 +258,7 @@ def data_format_route( data_string: str):
 def data_location_route(data_string: str):
     @ai_classifier
     class LocationRoute(Enum):
-        """Represents classifier for the data location"""
+        """Represents classifier for the data location, if it is device, or database connections string or URL """

         DEVICE = "file_path_starting_with_.data_or_containing_it"
         # URL = "url starting with http or https"
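With marvin's `@ai_classifier`, the enum docstring serves as the classification prompt, so the expanded wording (device path vs. database connection string vs. URL) gives the model more to discriminate on. A hedged sketch of the intended routing (assuming `data_location_route` returns the classified `LocationRoute` member; the exact call shape depends on the marvin version in use):

route = data_location_route(".data/3ZCCCW.pdf")
# Expected: LocationRoute.DEVICE, since the path starts with ".data"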
@@ -273,6 +290,9 @@ async def start_test(data, test_set=None, user_id=None, params=None, job_id=None

     job_id = await fetch_job_id(session, user_id=user_id, job_id=job_id)
     test_set_id = await fetch_test_set_id(session, user_id=user_id, id=job_id)
+    memory = await Memory.create_memory(user_id, session, namespace="SEMANTICMEMORY")
+    await memory.add_memory_instance("ExampleMemory")
+    existing_user = await Memory.check_existing_user(user_id, session)

     if job_id is None:
         job_id = str(uuid.uuid4())
@@ -303,9 +323,6 @@ async def start_test(data, test_set=None, user_id=None, params=None, job_id=None

         if test_id is None:
             test_id = str(generate_letter_uuid()) + "_" +"SEMANTICMEMORY"
-        memory = await Memory.create_memory(user_id, session, namespace="SEMANTICMEMORY")
-        await memory.add_memory_instance("ExampleMemory")
-        existing_user = await Memory.check_existing_user(user_id, session)
         await memory.manage_memory_attributes(existing_user)
         test_class = test_id + "_class"
         await memory.add_dynamic_memory_class(test_id.lower(), test_id)
@ -378,16 +395,21 @@ async def start_test(data, test_set=None, user_id=None, params=None, job_id=None
|
|||
await memory.dynamic_method_call(dynamic_memory_class, 'delete_memories',
|
||||
namespace=test_id)
|
||||
|
||||
return test_eval_pipeline
|
||||
return test_id, test_eval_pipeline
|
||||
|
||||
results = []
|
||||
|
||||
if only_llm_context:
|
||||
result = await run_test(test= None, loader_settings=loader_settings, metadata=metadata, only_llm_context=only_llm_context)
|
||||
test_id, result = await run_test(test=None, loader_settings=loader_settings, metadata=metadata,
|
||||
only_llm_context=only_llm_context)
|
||||
results.append(result)
|
||||
|
||||
for param in test_params:
|
||||
result = await run_test(param, loader_settings, metadata,only_llm_context=only_llm_context)
|
||||
test_id, result = await run_test(param, loader_settings, metadata, only_llm_context=only_llm_context)
|
||||
results.append(result)
|
||||
|
||||
await add_entity(session, TestOutput(id=test_id, user_id=user_id, test_results=str(json.dumps(results))))
|
||||
|
||||
print(results)
|
||||
|
||||
return results
|
||||
|
|
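`run_test` now returns a `(test_id, pipeline)` tuple so the persisted `TestOutput` row can be keyed to the last test id. One observation: because `test_results` became a JSON column in this same commit, the `str(json.dumps(results))` wrapper stores a doubly encoded string; a plain assignment would keep the payload structured (hypothetical variant, not what the commit does):

await add_entity(session, TestOutput(id=test_id, user_id=user_id, test_results=results))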
@@ -432,7 +454,7 @@ async def main():
     ]
     # "https://www.ibiblio.org/ebooks/London/Call%20of%20Wild.pdf"
     #http://public-library.uk/ebooks/59/83.pdf
-    result = await start_test(".data/3ZCCCW.pdf", test_set=test_set, user_id="676", params=None, metadata=metadata)
+    result = await start_test(".data/3ZCCCW.pdf", test_set=test_set, user_id="677", params=None, metadata=metadata)
     #
     # parser = argparse.ArgumentParser(description="Run tests against a document.")
     # parser.add_argument("--url", required=True, help="URL of the document to test.")
@@ -40,34 +40,28 @@ import json
 OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
 marvin.settings.openai.api_key = os.environ.get("OPENAI_API_KEY")

 LTM_MEMORY_ID_DEFAULT = "00000"
 ST_MEMORY_ID_DEFAULT = "0000"
 BUFFER_ID_DEFAULT = "0000"


 class VectorDBFactory:
+    def __init__(self):
+        self.db_map = {
+            "pinecone": PineconeVectorDB,
+            "weaviate": WeaviateVectorDB,
+            # Add more database types and their corresponding classes here
+        }
+
     def create_vector_db(
         self,
         user_id: str,
         index_name: str,
         memory_id: str,
         ltm_memory_id: str = LTM_MEMORY_ID_DEFAULT,
         st_memory_id: str = ST_MEMORY_ID_DEFAULT,
         buffer_id: str = BUFFER_ID_DEFAULT,
-        db_type: str = "pinecone",
+        db_type: str = "weaviate",
         namespace: str = None,
-        embeddings = None,
+        embeddings=None,
     ):
-        db_map = {"pinecone": PineconeVectorDB, "weaviate": WeaviateVectorDB}
-
-        if db_type in db_map:
-            return db_map[db_type](
+        if db_type in self.db_map:
+            return self.db_map[db_type](
                 user_id,
                 index_name,
                 memory_id,
                 ltm_memory_id,
                 st_memory_id,
                 buffer_id,
                 namespace,
                 embeddings
             )
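Moving the backend table into `__init__` makes `VectorDBFactory` extensible without touching `create_vector_db`. A usage sketch (the extra backend registration and `QdrantVectorDB` are hypothetical; argument values are illustrative):

factory = VectorDBFactory()
factory.db_map["qdrant"] = QdrantVectorDB  # hypothetical backend class, registered at runtime

db = factory.create_vector_db(
    user_id="677",
    index_name="example-index",
    memory_id="mem-1",
    db_type="weaviate",  # the new default
    namespace="SEMANTICMEMORY",
)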
@@ -101,8 +95,61 @@ class BaseMemory:
         )

     def init_client(self, embeddings, namespace: str):
-        return self.vector_db.init_client(embeddings, namespace)
+
+        return self.vector_db.init_weaviate_client(embeddings, namespace)
+
+# class VectorDBFactory:
+#     def create_vector_db(
+#         self,
+#         user_id: str,
+#         index_name: str,
+#         memory_id: str,
+#         db_type: str = "pinecone",
+#         namespace: str = None,
+#         embeddings = None,
+#     ):
+#         db_map = {"pinecone": PineconeVectorDB, "weaviate": WeaviateVectorDB}
+#
+#         if db_type in db_map:
+#             return db_map[db_type](
+#                 user_id,
+#                 index_name,
+#                 memory_id,
+#                 namespace,
+#                 embeddings
+#             )
+#
+#         raise ValueError(f"Unsupported database type: {db_type}")
+#
+# class BaseMemory:
+#     def __init__(
+#         self,
+#         user_id: str,
+#         memory_id: Optional[str],
+#         index_name: Optional[str],
+#         db_type: str,
+#         namespace: str,
+#         embeddings: Optional[None],
+#     ):
+#         self.user_id = user_id
+#         self.memory_id = memory_id
+#         self.index_name = index_name
+#         self.namespace = namespace
+#         self.embeddings = embeddings
+#         self.db_type = db_type
+#         factory = VectorDBFactory()
+#         self.vector_db = factory.create_vector_db(
+#             self.user_id,
+#             self.index_name,
+#             self.memory_id,
+#             db_type=self.db_type,
+#             namespace=self.namespace,
+#             embeddings=self.embeddings
+#         )
+#
+#     def init_client(self, embeddings, namespace: str):
+#
+#         return self.vector_db.init_weaviate_client(embeddings, namespace)

     def create_field(self, field_type, **kwargs):
         field_mapping = {
@@ -32,9 +32,6 @@ class VectorDB:
         user_id: str,
         index_name: str,
         memory_id: str,
-        ltm_memory_id: str = LTM_MEMORY_ID_DEFAULT,
-        st_memory_id: str = ST_MEMORY_ID_DEFAULT,
-        buffer_id: str = BUFFER_ID_DEFAULT,
         namespace: str = None,
         embeddings = None,
     ):
@@ -42,9 +39,6 @@ class VectorDB:
         self.index_name = index_name
         self.namespace = namespace
         self.memory_id = memory_id
-        self.ltm_memory_id = ltm_memory_id
-        self.st_memory_id = st_memory_id
-        self.buffer_id = buffer_id
         self.embeddings = embeddings

 class PineconeVectorDB(VectorDB):
@@ -81,7 +75,7 @@ class WeaviateVectorDB(VectorDB):
             embedding=embeddings,
             create_schema_if_missing=True,
         )
-        return retriever  # If this is part of the initialization, call it here.
+        return retriever

     def init_weaviate_client(self, namespace: str):
         # Weaviate client initialization logic
@@ -95,6 +89,34 @@ class WeaviateVectorDB(VectorDB):
         )
         return client

+    from marshmallow import Schema, fields
+
+    def create_document_structure(observation, params, metadata_schema_class=None):
+        """
+        Create and validate a document structure with optional custom fields.
+
+        :param observation: Content of the document.
+        :param params: Metadata information.
+        :param metadata_schema_class: Custom metadata schema class (optional).
+        :return: A list containing the validated document data.
+        """
+        document_data = {
+            "metadata": params,
+            "page_content": observation
+        }
+
+        def get_document_schema():
+            class DynamicDocumentSchema(Schema):
+                metadata = fields.Nested(metadata_schema_class, required=True)
+                page_content = fields.Str(required=True)
+
+            return DynamicDocumentSchema
+
+        # Validate and deserialize, defaulting to "1.0" if not provided
+        CurrentDocumentSchema = get_document_schema()
+        loaded_document = CurrentDocumentSchema().load(document_data)
+        return [loaded_document]
+
     def _stuct(self, observation, params, metadata_schema_class =None):
         """Utility function to create the document structure with optional custom fields."""
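`create_document_structure` builds its marshmallow schema on the fly so callers can plug in their own metadata validation. A hedged sketch of a caller-supplied schema (`MetadataSchema` is illustrative, not from this repo; the diff leaves the function's exact placement ambiguous, so it is called here as a plain function):

from marshmallow import Schema, fields

class MetadataSchema(Schema):
    # Hypothetical metadata shape; declare whatever keys params carries.
    version = fields.Str(required=False)
    agreement_id = fields.Str(required=False)

docs = create_document_structure(
    observation="The quick brown fox.",
    params={"version": "1.0", "agreement_id": "A-1"},
    metadata_schema_class=MetadataSchema,
)
# docs == [{"metadata": {...validated params...}, "page_content": "The quick brown fox."}]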
@@ -267,9 +289,6 @@ class WeaviateVectorDB(VectorDB):
             data_object={
                 # "text": observation,
                 "user_id": str(self.user_id),
                 "memory_id": str(self.memory_id),
-                "ltm_memory_id": str(self.ltm_memory_id),
-                "st_memory_id": str(self.st_memory_id),
-                "buffer_id": str(self.buffer_id),
                 "version": params.get("version", None) or "",
                 "agreement_id": params.get("agreement_id", None) or "",