Merge pull request #28 from topoteretes/add_async_elements
Improve classifier, add turn output data to json
This commit is contained in:
commit
4e5d7ae8e4
4 changed files with 127 additions and 39 deletions
|
|
@ -9,7 +9,7 @@ from database.database import Base
|
||||||
|
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from sqlalchemy import Column, String, DateTime, ForeignKey
|
from sqlalchemy import Column, String, DateTime, ForeignKey, JSON
|
||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import relationship
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
@ -27,7 +27,7 @@ class TestOutput(Base):
|
||||||
user_id = Column(String, ForeignKey('users.id'), index=True) # Added user_id field
|
user_id = Column(String, ForeignKey('users.id'), index=True) # Added user_id field
|
||||||
test_set_id = Column(String, ForeignKey('test_sets.id'), index=True)
|
test_set_id = Column(String, ForeignKey('test_sets.id'), index=True)
|
||||||
operation_id = Column(String, ForeignKey('operations.id'), index=True)
|
operation_id = Column(String, ForeignKey('operations.id'), index=True)
|
||||||
test_results = Column(String, nullable=True)
|
test_results = Column(JSON, nullable=True)
|
||||||
created_at = Column(DateTime, default=datetime.utcnow)
|
created_at = Column(DateTime, default=datetime.utcnow)
|
||||||
updated_at = Column(DateTime, onupdate=datetime.utcnow)
|
updated_at = Column(DateTime, onupdate=datetime.utcnow)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -216,7 +216,24 @@ async def eval_test(query=None, output=None, expected_output=None, context=None,
|
||||||
|
|
||||||
# If you want to run the test
|
# If you want to run the test
|
||||||
test_result = run_test(test_case, metrics=[metric], raise_error=False)
|
test_result = run_test(test_case, metrics=[metric], raise_error=False)
|
||||||
return test_result
|
|
||||||
|
def test_result_to_dict(test_result):
|
||||||
|
return {
|
||||||
|
"success": test_result.success,
|
||||||
|
"score": test_result.score,
|
||||||
|
"metric_name": test_result.metric_name,
|
||||||
|
"query": test_result.query,
|
||||||
|
"output": test_result.output,
|
||||||
|
"expected_output": test_result.expected_output,
|
||||||
|
"metadata": test_result.metadata,
|
||||||
|
"context": test_result.context
|
||||||
|
}
|
||||||
|
|
||||||
|
test_result_dict =[]
|
||||||
|
for test in test_result:
|
||||||
|
test_result_it = test_result_to_dict(test)
|
||||||
|
test_result_dict.append(test_result_it)
|
||||||
|
return test_result_dict
|
||||||
# You can also inspect the test result class
|
# You can also inspect the test result class
|
||||||
# print(test_result)
|
# print(test_result)
|
||||||
|
|
||||||
|
|
@ -241,7 +258,7 @@ def data_format_route( data_string: str):
|
||||||
def data_location_route(data_string: str):
|
def data_location_route(data_string: str):
|
||||||
@ai_classifier
|
@ai_classifier
|
||||||
class LocationRoute(Enum):
|
class LocationRoute(Enum):
|
||||||
"""Represents classifier for the data location"""
|
"""Represents classifier for the data location, if it is device, or database connections string or URL """
|
||||||
|
|
||||||
DEVICE = "file_path_starting_with_.data_or_containing_it"
|
DEVICE = "file_path_starting_with_.data_or_containing_it"
|
||||||
# URL = "url starting with http or https"
|
# URL = "url starting with http or https"
|
||||||
|
|
@ -273,6 +290,9 @@ async def start_test(data, test_set=None, user_id=None, params=None, job_id=None
|
||||||
|
|
||||||
job_id = await fetch_job_id(session, user_id=user_id, job_id=job_id)
|
job_id = await fetch_job_id(session, user_id=user_id, job_id=job_id)
|
||||||
test_set_id = await fetch_test_set_id(session, user_id=user_id, id=job_id)
|
test_set_id = await fetch_test_set_id(session, user_id=user_id, id=job_id)
|
||||||
|
memory = await Memory.create_memory(user_id, session, namespace="SEMANTICMEMORY")
|
||||||
|
await memory.add_memory_instance("ExampleMemory")
|
||||||
|
existing_user = await Memory.check_existing_user(user_id, session)
|
||||||
|
|
||||||
if job_id is None:
|
if job_id is None:
|
||||||
job_id = str(uuid.uuid4())
|
job_id = str(uuid.uuid4())
|
||||||
|
|
@ -303,9 +323,6 @@ async def start_test(data, test_set=None, user_id=None, params=None, job_id=None
|
||||||
|
|
||||||
if test_id is None:
|
if test_id is None:
|
||||||
test_id = str(generate_letter_uuid()) + "_" +"SEMANTICMEMORY"
|
test_id = str(generate_letter_uuid()) + "_" +"SEMANTICMEMORY"
|
||||||
memory = await Memory.create_memory(user_id, session, namespace="SEMANTICMEMORY")
|
|
||||||
await memory.add_memory_instance("ExampleMemory")
|
|
||||||
existing_user = await Memory.check_existing_user(user_id, session)
|
|
||||||
await memory.manage_memory_attributes(existing_user)
|
await memory.manage_memory_attributes(existing_user)
|
||||||
test_class = test_id + "_class"
|
test_class = test_id + "_class"
|
||||||
await memory.add_dynamic_memory_class(test_id.lower(), test_id)
|
await memory.add_dynamic_memory_class(test_id.lower(), test_id)
|
||||||
|
|
@ -378,16 +395,21 @@ async def start_test(data, test_set=None, user_id=None, params=None, job_id=None
|
||||||
await memory.dynamic_method_call(dynamic_memory_class, 'delete_memories',
|
await memory.dynamic_method_call(dynamic_memory_class, 'delete_memories',
|
||||||
namespace=test_id)
|
namespace=test_id)
|
||||||
|
|
||||||
return test_eval_pipeline
|
return test_id, test_eval_pipeline
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
if only_llm_context:
|
if only_llm_context:
|
||||||
result = await run_test(test= None, loader_settings=loader_settings, metadata=metadata, only_llm_context=only_llm_context)
|
test_id, result = await run_test(test=None, loader_settings=loader_settings, metadata=metadata,
|
||||||
|
only_llm_context=only_llm_context)
|
||||||
results.append(result)
|
results.append(result)
|
||||||
|
|
||||||
for param in test_params:
|
for param in test_params:
|
||||||
result = await run_test(param, loader_settings, metadata,only_llm_context=only_llm_context)
|
test_id, result = await run_test(param, loader_settings, metadata, only_llm_context=only_llm_context)
|
||||||
results.append(result)
|
results.append(result)
|
||||||
|
|
||||||
|
await add_entity(session, TestOutput(id=test_id, user_id=user_id, test_results=str(json.dumps(results))))
|
||||||
|
|
||||||
print(results)
|
print(results)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
@ -432,7 +454,7 @@ async def main():
|
||||||
]
|
]
|
||||||
# "https://www.ibiblio.org/ebooks/London/Call%20of%20Wild.pdf"
|
# "https://www.ibiblio.org/ebooks/London/Call%20of%20Wild.pdf"
|
||||||
#http://public-library.uk/ebooks/59/83.pdf
|
#http://public-library.uk/ebooks/59/83.pdf
|
||||||
result = await start_test(".data/3ZCCCW.pdf", test_set=test_set, user_id="676", params=None, metadata=metadata)
|
result = await start_test(".data/3ZCCCW.pdf", test_set=test_set, user_id="677", params=None, metadata=metadata)
|
||||||
#
|
#
|
||||||
# parser = argparse.ArgumentParser(description="Run tests against a document.")
|
# parser = argparse.ArgumentParser(description="Run tests against a document.")
|
||||||
# parser.add_argument("--url", required=True, help="URL of the document to test.")
|
# parser.add_argument("--url", required=True, help="URL of the document to test.")
|
||||||
|
|
|
||||||
|
|
@ -40,34 +40,28 @@ import json
|
||||||
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
|
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
|
||||||
marvin.settings.openai.api_key = os.environ.get("OPENAI_API_KEY")
|
marvin.settings.openai.api_key = os.environ.get("OPENAI_API_KEY")
|
||||||
|
|
||||||
LTM_MEMORY_ID_DEFAULT = "00000"
|
|
||||||
ST_MEMORY_ID_DEFAULT = "0000"
|
|
||||||
BUFFER_ID_DEFAULT = "0000"
|
|
||||||
|
|
||||||
|
|
||||||
class VectorDBFactory:
|
class VectorDBFactory:
|
||||||
|
def __init__(self):
|
||||||
|
self.db_map = {
|
||||||
|
"pinecone": PineconeVectorDB,
|
||||||
|
"weaviate": WeaviateVectorDB,
|
||||||
|
# Add more database types and their corresponding classes here
|
||||||
|
}
|
||||||
|
|
||||||
def create_vector_db(
|
def create_vector_db(
|
||||||
self,
|
self,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
index_name: str,
|
index_name: str,
|
||||||
memory_id: str,
|
memory_id: str,
|
||||||
ltm_memory_id: str = LTM_MEMORY_ID_DEFAULT,
|
db_type: str = "weaviate",
|
||||||
st_memory_id: str = ST_MEMORY_ID_DEFAULT,
|
|
||||||
buffer_id: str = BUFFER_ID_DEFAULT,
|
|
||||||
db_type: str = "pinecone",
|
|
||||||
namespace: str = None,
|
namespace: str = None,
|
||||||
embeddings = None,
|
embeddings=None,
|
||||||
):
|
):
|
||||||
db_map = {"pinecone": PineconeVectorDB, "weaviate": WeaviateVectorDB}
|
if db_type in self.db_map:
|
||||||
|
return self.db_map[db_type](
|
||||||
if db_type in db_map:
|
|
||||||
return db_map[db_type](
|
|
||||||
user_id,
|
user_id,
|
||||||
index_name,
|
index_name,
|
||||||
memory_id,
|
memory_id,
|
||||||
ltm_memory_id,
|
|
||||||
st_memory_id,
|
|
||||||
buffer_id,
|
|
||||||
namespace,
|
namespace,
|
||||||
embeddings
|
embeddings
|
||||||
)
|
)
|
||||||
|
|
@ -101,8 +95,61 @@ class BaseMemory:
|
||||||
)
|
)
|
||||||
|
|
||||||
def init_client(self, embeddings, namespace: str):
|
def init_client(self, embeddings, namespace: str):
|
||||||
|
return self.vector_db.init_client(embeddings, namespace)
|
||||||
|
|
||||||
return self.vector_db.init_weaviate_client(embeddings, namespace)
|
|
||||||
|
# class VectorDBFactory:
|
||||||
|
# def create_vector_db(
|
||||||
|
# self,
|
||||||
|
# user_id: str,
|
||||||
|
# index_name: str,
|
||||||
|
# memory_id: str,
|
||||||
|
# db_type: str = "pinecone",
|
||||||
|
# namespace: str = None,
|
||||||
|
# embeddings = None,
|
||||||
|
# ):
|
||||||
|
# db_map = {"pinecone": PineconeVectorDB, "weaviate": WeaviateVectorDB}
|
||||||
|
#
|
||||||
|
# if db_type in db_map:
|
||||||
|
# return db_map[db_type](
|
||||||
|
# user_id,
|
||||||
|
# index_name,
|
||||||
|
# memory_id,
|
||||||
|
# namespace,
|
||||||
|
# embeddings
|
||||||
|
# )
|
||||||
|
#
|
||||||
|
# raise ValueError(f"Unsupported database type: {db_type}")
|
||||||
|
#
|
||||||
|
# class BaseMemory:
|
||||||
|
# def __init__(
|
||||||
|
# self,
|
||||||
|
# user_id: str,
|
||||||
|
# memory_id: Optional[str],
|
||||||
|
# index_name: Optional[str],
|
||||||
|
# db_type: str,
|
||||||
|
# namespace: str,
|
||||||
|
# embeddings: Optional[None],
|
||||||
|
# ):
|
||||||
|
# self.user_id = user_id
|
||||||
|
# self.memory_id = memory_id
|
||||||
|
# self.index_name = index_name
|
||||||
|
# self.namespace = namespace
|
||||||
|
# self.embeddings = embeddings
|
||||||
|
# self.db_type = db_type
|
||||||
|
# factory = VectorDBFactory()
|
||||||
|
# self.vector_db = factory.create_vector_db(
|
||||||
|
# self.user_id,
|
||||||
|
# self.index_name,
|
||||||
|
# self.memory_id,
|
||||||
|
# db_type=self.db_type,
|
||||||
|
# namespace=self.namespace,
|
||||||
|
# embeddings=self.embeddings
|
||||||
|
# )
|
||||||
|
#
|
||||||
|
# def init_client(self, embeddings, namespace: str):
|
||||||
|
#
|
||||||
|
# return self.vector_db.init_weaviate_client(embeddings, namespace)
|
||||||
|
|
||||||
def create_field(self, field_type, **kwargs):
|
def create_field(self, field_type, **kwargs):
|
||||||
field_mapping = {
|
field_mapping = {
|
||||||
|
|
|
||||||
|
|
@ -32,9 +32,6 @@ class VectorDB:
|
||||||
user_id: str,
|
user_id: str,
|
||||||
index_name: str,
|
index_name: str,
|
||||||
memory_id: str,
|
memory_id: str,
|
||||||
ltm_memory_id: str = LTM_MEMORY_ID_DEFAULT,
|
|
||||||
st_memory_id: str = ST_MEMORY_ID_DEFAULT,
|
|
||||||
buffer_id: str = BUFFER_ID_DEFAULT,
|
|
||||||
namespace: str = None,
|
namespace: str = None,
|
||||||
embeddings = None,
|
embeddings = None,
|
||||||
):
|
):
|
||||||
|
|
@ -42,9 +39,6 @@ class VectorDB:
|
||||||
self.index_name = index_name
|
self.index_name = index_name
|
||||||
self.namespace = namespace
|
self.namespace = namespace
|
||||||
self.memory_id = memory_id
|
self.memory_id = memory_id
|
||||||
self.ltm_memory_id = ltm_memory_id
|
|
||||||
self.st_memory_id = st_memory_id
|
|
||||||
self.buffer_id = buffer_id
|
|
||||||
self.embeddings = embeddings
|
self.embeddings = embeddings
|
||||||
|
|
||||||
class PineconeVectorDB(VectorDB):
|
class PineconeVectorDB(VectorDB):
|
||||||
|
|
@ -81,7 +75,7 @@ class WeaviateVectorDB(VectorDB):
|
||||||
embedding=embeddings,
|
embedding=embeddings,
|
||||||
create_schema_if_missing=True,
|
create_schema_if_missing=True,
|
||||||
)
|
)
|
||||||
return retriever # If this is part of the initialization, call it here.
|
return retriever
|
||||||
|
|
||||||
def init_weaviate_client(self, namespace: str):
|
def init_weaviate_client(self, namespace: str):
|
||||||
# Weaviate client initialization logic
|
# Weaviate client initialization logic
|
||||||
|
|
@ -95,6 +89,34 @@ class WeaviateVectorDB(VectorDB):
|
||||||
)
|
)
|
||||||
return client
|
return client
|
||||||
|
|
||||||
|
from marshmallow import Schema, fields
|
||||||
|
|
||||||
|
def create_document_structure(observation, params, metadata_schema_class=None):
|
||||||
|
"""
|
||||||
|
Create and validate a document structure with optional custom fields.
|
||||||
|
|
||||||
|
:param observation: Content of the document.
|
||||||
|
:param params: Metadata information.
|
||||||
|
:param metadata_schema_class: Custom metadata schema class (optional).
|
||||||
|
:return: A list containing the validated document data.
|
||||||
|
"""
|
||||||
|
document_data = {
|
||||||
|
"metadata": params,
|
||||||
|
"page_content": observation
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_document_schema():
|
||||||
|
class DynamicDocumentSchema(Schema):
|
||||||
|
metadata = fields.Nested(metadata_schema_class, required=True)
|
||||||
|
page_content = fields.Str(required=True)
|
||||||
|
|
||||||
|
return DynamicDocumentSchema
|
||||||
|
|
||||||
|
# Validate and deserialize, defaulting to "1.0" if not provided
|
||||||
|
CurrentDocumentSchema = get_document_schema()
|
||||||
|
loaded_document = CurrentDocumentSchema().load(document_data)
|
||||||
|
return [loaded_document]
|
||||||
|
|
||||||
def _stuct(self, observation, params, metadata_schema_class =None):
|
def _stuct(self, observation, params, metadata_schema_class =None):
|
||||||
"""Utility function to create the document structure with optional custom fields."""
|
"""Utility function to create the document structure with optional custom fields."""
|
||||||
|
|
||||||
|
|
@ -267,9 +289,6 @@ class WeaviateVectorDB(VectorDB):
|
||||||
data_object={
|
data_object={
|
||||||
# "text": observation,
|
# "text": observation,
|
||||||
"user_id": str(self.user_id),
|
"user_id": str(self.user_id),
|
||||||
"memory_id": str(self.memory_id),
|
|
||||||
"ltm_memory_id": str(self.ltm_memory_id),
|
|
||||||
"st_memory_id": str(self.st_memory_id),
|
|
||||||
"buffer_id": str(self.buffer_id),
|
"buffer_id": str(self.buffer_id),
|
||||||
"version": params.get("version", None) or "",
|
"version": params.get("version", None) or "",
|
||||||
"agreement_id": params.get("agreement_id", None) or "",
|
"agreement_id": params.get("agreement_id", None) or "",
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue