Improve classifier, add turn output data to json
This commit is contained in:
parent
cccc87b05c
commit
1e40ad37b8
3 changed files with 43 additions and 10 deletions
|
|
@ -9,7 +9,7 @@ from database.database import Base
|
||||||
|
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from sqlalchemy import Column, String, DateTime, ForeignKey
|
from sqlalchemy import Column, String, DateTime, ForeignKey, JSON
|
||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import relationship
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
@ -27,7 +27,7 @@ class TestOutput(Base):
|
||||||
user_id = Column(String, ForeignKey('users.id'), index=True) # Added user_id field
|
user_id = Column(String, ForeignKey('users.id'), index=True) # Added user_id field
|
||||||
test_set_id = Column(String, ForeignKey('test_sets.id'), index=True)
|
test_set_id = Column(String, ForeignKey('test_sets.id'), index=True)
|
||||||
operation_id = Column(String, ForeignKey('operations.id'), index=True)
|
operation_id = Column(String, ForeignKey('operations.id'), index=True)
|
||||||
test_results = Column(String, nullable=True)
|
test_results = Column(JSON, nullable=True)
|
||||||
created_at = Column(DateTime, default=datetime.utcnow)
|
created_at = Column(DateTime, default=datetime.utcnow)
|
||||||
updated_at = Column(DateTime, onupdate=datetime.utcnow)
|
updated_at = Column(DateTime, onupdate=datetime.utcnow)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -241,7 +241,7 @@ def data_format_route( data_string: str):
|
||||||
def data_location_route(data_string: str):
|
def data_location_route(data_string: str):
|
||||||
@ai_classifier
|
@ai_classifier
|
||||||
class LocationRoute(Enum):
|
class LocationRoute(Enum):
|
||||||
"""Represents classifier for the data location"""
|
"""Represents classifier for the data location, if it is device, or database connections string or URL """
|
||||||
|
|
||||||
DEVICE = "file_path_starting_with_.data_or_containing_it"
|
DEVICE = "file_path_starting_with_.data_or_containing_it"
|
||||||
# URL = "url starting with http or https"
|
# URL = "url starting with http or https"
|
||||||
|
|
@ -273,6 +273,9 @@ async def start_test(data, test_set=None, user_id=None, params=None, job_id=None
|
||||||
|
|
||||||
job_id = await fetch_job_id(session, user_id=user_id, job_id=job_id)
|
job_id = await fetch_job_id(session, user_id=user_id, job_id=job_id)
|
||||||
test_set_id = await fetch_test_set_id(session, user_id=user_id, id=job_id)
|
test_set_id = await fetch_test_set_id(session, user_id=user_id, id=job_id)
|
||||||
|
memory = await Memory.create_memory(user_id, session, namespace="SEMANTICMEMORY")
|
||||||
|
await memory.add_memory_instance("ExampleMemory")
|
||||||
|
existing_user = await Memory.check_existing_user(user_id, session)
|
||||||
|
|
||||||
if job_id is None:
|
if job_id is None:
|
||||||
job_id = str(uuid.uuid4())
|
job_id = str(uuid.uuid4())
|
||||||
|
|
@ -303,9 +306,6 @@ async def start_test(data, test_set=None, user_id=None, params=None, job_id=None
|
||||||
|
|
||||||
if test_id is None:
|
if test_id is None:
|
||||||
test_id = str(generate_letter_uuid()) + "_" +"SEMANTICMEMORY"
|
test_id = str(generate_letter_uuid()) + "_" +"SEMANTICMEMORY"
|
||||||
memory = await Memory.create_memory(user_id, session, namespace="SEMANTICMEMORY")
|
|
||||||
await memory.add_memory_instance("ExampleMemory")
|
|
||||||
existing_user = await Memory.check_existing_user(user_id, session)
|
|
||||||
await memory.manage_memory_attributes(existing_user)
|
await memory.manage_memory_attributes(existing_user)
|
||||||
test_class = test_id + "_class"
|
test_class = test_id + "_class"
|
||||||
await memory.add_dynamic_memory_class(test_id.lower(), test_id)
|
await memory.add_dynamic_memory_class(test_id.lower(), test_id)
|
||||||
|
|
@ -378,16 +378,21 @@ async def start_test(data, test_set=None, user_id=None, params=None, job_id=None
|
||||||
await memory.dynamic_method_call(dynamic_memory_class, 'delete_memories',
|
await memory.dynamic_method_call(dynamic_memory_class, 'delete_memories',
|
||||||
namespace=test_id)
|
namespace=test_id)
|
||||||
|
|
||||||
return test_eval_pipeline
|
return test_id, test_eval_pipeline
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
if only_llm_context:
|
if only_llm_context:
|
||||||
result = await run_test(test= None, loader_settings=loader_settings, metadata=metadata, only_llm_context=only_llm_context)
|
test_id, result = await run_test(test=None, loader_settings=loader_settings, metadata=metadata,
|
||||||
|
only_llm_context=only_llm_context)
|
||||||
results.append(result)
|
results.append(result)
|
||||||
|
|
||||||
for param in test_params:
|
for param in test_params:
|
||||||
result = await run_test(param, loader_settings, metadata,only_llm_context=only_llm_context)
|
test_id, result = await run_test(param, loader_settings, metadata, only_llm_context=only_llm_context)
|
||||||
results.append(result)
|
results.append(result)
|
||||||
|
|
||||||
|
await add_entity(session, TestOutput(id=test_id, user_id=user_id, test_results=str(json.dumps(results))))
|
||||||
|
|
||||||
print(results)
|
print(results)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
@ -432,7 +437,7 @@ async def main():
|
||||||
]
|
]
|
||||||
# "https://www.ibiblio.org/ebooks/London/Call%20of%20Wild.pdf"
|
# "https://www.ibiblio.org/ebooks/London/Call%20of%20Wild.pdf"
|
||||||
#http://public-library.uk/ebooks/59/83.pdf
|
#http://public-library.uk/ebooks/59/83.pdf
|
||||||
result = await start_test(".data/3ZCCCW.pdf", test_set=test_set, user_id="676", params=None, metadata=metadata)
|
result = await start_test(".data/3ZCCCW.pdf", test_set=test_set, user_id="677", params=None, metadata=metadata)
|
||||||
#
|
#
|
||||||
# parser = argparse.ArgumentParser(description="Run tests against a document.")
|
# parser = argparse.ArgumentParser(description="Run tests against a document.")
|
||||||
# parser.add_argument("--url", required=True, help="URL of the document to test.")
|
# parser.add_argument("--url", required=True, help="URL of the document to test.")
|
||||||
|
|
|
||||||
|
|
@ -95,6 +95,34 @@ class WeaviateVectorDB(VectorDB):
|
||||||
)
|
)
|
||||||
return client
|
return client
|
||||||
|
|
||||||
|
from marshmallow import Schema, fields
|
||||||
|
|
||||||
|
def create_document_structure(observation, params, metadata_schema_class=None):
|
||||||
|
"""
|
||||||
|
Create and validate a document structure with optional custom fields.
|
||||||
|
|
||||||
|
:param observation: Content of the document.
|
||||||
|
:param params: Metadata information.
|
||||||
|
:param metadata_schema_class: Custom metadata schema class (optional).
|
||||||
|
:return: A list containing the validated document data.
|
||||||
|
"""
|
||||||
|
document_data = {
|
||||||
|
"metadata": params,
|
||||||
|
"page_content": observation
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_document_schema():
|
||||||
|
class DynamicDocumentSchema(Schema):
|
||||||
|
metadata = fields.Nested(metadata_schema_class, required=True)
|
||||||
|
page_content = fields.Str(required=True)
|
||||||
|
|
||||||
|
return DynamicDocumentSchema
|
||||||
|
|
||||||
|
# Validate and deserialize, defaulting to "1.0" if not provided
|
||||||
|
CurrentDocumentSchema = get_document_schema()
|
||||||
|
loaded_document = CurrentDocumentSchema().load(document_data)
|
||||||
|
return [loaded_document]
|
||||||
|
|
||||||
def _stuct(self, observation, params, metadata_schema_class =None):
|
def _stuct(self, observation, params, metadata_schema_class =None):
|
||||||
"""Utility function to create the document structure with optional custom fields."""
|
"""Utility function to create the document structure with optional custom fields."""
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue