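"""Evaluation harness for document retrieval settings.

Loads documents into a vector-store-backed memory, runs a question/answer test set
against different retrieval parameter combinations (chunk size, overlap, search type,
embeddings, ...), scores the answers with deepeval, and records operations and test
outputs in the database. Intended to be run as a CLI script; see main() below.
"""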
import argparse
import asyncio
import itertools
import json
import logging
import os
import random
import string
import tracemalloc
import uuid
from enum import Enum

import marvin
import openai
import segment.analytics as analytics
from deepeval.metrics.overall_score import OverallScoreMetric
from deepeval.run_test import run_test
from deepeval.test_case import LLMTestCase
from dotenv import load_dotenv
from sqlalchemy.future import select

from database.database import (
    AsyncSessionLocal,
    engine,
)  # Ensure you have the database engine defined somewhere
from database.database_crud import session_scope
from models.user import User
from models.memory import MemoryModel
from models.sessions import Session
from models.testset import TestSet
from models.testoutput import TestOutput
from models.metadatas import MetaDatas
from models.operation import Operation
from models.docs import DocsModel
from vectorstore_manager import Memory

load_dotenv()

logging.basicConfig(level=logging.INFO)

tracemalloc.start()

OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
openai.api_key = OPENAI_API_KEY
marvin.settings.openai.api_key = os.environ.get("OPENAI_API_KEY")

analytics.write_key = os.getenv("SEGMENT_KEY", "")
analytics.debug = True


def on_error(error, items):
    print("An error occurred:", error)


analytics.on_error = on_error
async def retrieve_latest_test_case(session, user_id, memory_id):
    try:
        # Build a select() statement and await session.execute() for async query execution
        result = await session.execute(
            select(TestSet.attributes_list)
            .filter_by(user_id=user_id, memory_id=memory_id)
            .order_by(TestSet.created_at.desc())
            .limit(1)
        )
        return result.scalar_one_or_none()  # Returns None if no row matched
    except Exception as e:
        logging.error(
            f"An error occurred while retrieving the latest test case: {str(e)}"
        )
        return None
def get_document_names(doc_input):
    """
    Get a list of document names.

    This function takes doc_input, which can be a folder path, a single document file path,
    a document name as a string, or an already-built list of document names.
    It returns a list of document names based on the doc_input.

    Args:
        doc_input (str or list): A folder path, a single document file path, a document name,
            or a list of document names.

    Returns:
        list: A list of document names.

    Example usage:
        - Folder path: get_document_names(".data")
        - Single document file path: get_document_names(".data/example.pdf")
        - Document name provided as a string: get_document_names("example.docx")
    """
    if isinstance(doc_input, list):
        return doc_input
    if os.path.isdir(doc_input):
        # doc_input is a folder
        folder_path = doc_input
        document_names = []
        for filename in os.listdir(folder_path):
            if os.path.isfile(os.path.join(folder_path, filename)):
                document_names.append(filename)
        return document_names
    elif os.path.isfile(doc_input):
        # doc_input is a single document file
        return [os.path.basename(doc_input)]
    elif isinstance(doc_input, str):
        # doc_input is a document name provided as a string
        return [doc_input]
    else:
        # doc_input is not valid
        return []
async def add_entity(session, entity):
    async with session_scope(session) as s:  # Use your async session_scope
        s.add(entity)  # No need to commit; session_scope takes care of it

    return "Successfully added entity"
async def update_entity(session, model, entity_id, new_value):
    async with session_scope(session) as s:
        # Retrieve the entity from the database
        entity = await s.get(model, entity_id)

        if entity:
            # Update the relevant column; 'updated_at' is refreshed automatically
            entity.operation_status = new_value
            return "Successfully updated entity"
        else:
            return "Entity not found"
async def retrieve_job_by_id(session, user_id, job_id):
    try:
        result = await session.execute(
            select(Session.id)
            .filter_by(user_id=user_id, id=job_id)
            .order_by(Session.created_at)
        )
        return result.scalar_one_or_none()
    except Exception as e:
        logging.error(f"An error occurred while retrieving the job: {str(e)}")
        return None
async def fetch_job_id(session, user_id=None, memory_id=None, job_id=None):
    try:
        result = await session.execute(
            select(Session.id)
            .filter_by(user_id=user_id, id=job_id)
            .order_by(Session.created_at)
            .limit(1)
        )
        return result.scalar_one_or_none()
    except Exception as e:
        logging.error(f"An error occurred: {str(e)}")
        return None
async def fetch_test_set_id(session, user_id, content):
    try:
        # Await the execution of the query and fetching of the result
        result = await session.execute(
            select(TestSet.id)
            .filter_by(user_id=user_id, content=content)
            .order_by(TestSet.created_at)
        )
        return (
            result.scalar_one_or_none()
        )  # scalar_one_or_none() is a non-blocking call
    except Exception as e:
        logging.error(f"An error occurred while retrieving the test set: {str(e)}")
        return None
# Adding "embeddings" to the parameter variants function
|
|
|
|
|
|
def generate_param_variants(
|
|
base_params=None, increments=None, ranges=None, included_params=None
|
|
):
|
|
"""Generate parameter variants for testing.
|
|
|
|
Args:
|
|
base_params (dict): Base parameter values.
|
|
increments (dict): Increment values for each parameter variant.
|
|
ranges (dict): Range (number of variants) to generate for each parameter.
|
|
included_params (list, optional): Parameters to include in the combinations.
|
|
If None, all parameters are included.
|
|
|
|
Returns:
|
|
list: A list of dictionaries containing parameter variants.
|
|
"""
|
|
|
|
# Default values
|
|
defaults = {
|
|
"chunk_size": 750,
|
|
"chunk_overlap": 20,
|
|
"similarity_score": 0.5,
|
|
"metadata_variation": 0,
|
|
"search_type": "hybrid",
|
|
"embeddings": "openai", # Default value added for 'embeddings'
|
|
}
|
|
|
|
# Update defaults with provided base parameters
|
|
params = {**defaults, **(base_params or {})}
|
|
|
|
default_increments = {
|
|
"chunk_size": 250,
|
|
"chunk_overlap": 10,
|
|
"similarity_score": 0.1,
|
|
"metadata_variation": 1,
|
|
}
|
|
|
|
# Update default increments with provided increments
|
|
increments = {**default_increments, **(increments or {})}
|
|
|
|
# Default ranges
|
|
default_ranges = {
|
|
"chunk_size": 2,
|
|
"chunk_overlap": 2,
|
|
"similarity_score": 2,
|
|
"metadata_variation": 2,
|
|
}
|
|
|
|
# Update default ranges with provided ranges
|
|
ranges = {**default_ranges, **(ranges or {})}
|
|
|
|
# Generate parameter variant ranges
|
|
param_ranges = {
|
|
key: [
|
|
params[key] + i * increments.get(key, 1) for i in range(ranges.get(key, 1))
|
|
]
|
|
for key in [
|
|
"chunk_size",
|
|
"chunk_overlap",
|
|
"similarity_score",
|
|
"metadata_variation",
|
|
]
|
|
}
|
|
|
|
# Add search_type and embeddings with possible values
|
|
param_ranges["search_type"] = [
|
|
"text",
|
|
"hybrid",
|
|
"bm25",
|
|
# "generate",
|
|
# "generate_grouped",
|
|
]
|
|
param_ranges["embeddings"] = [
|
|
"openai",
|
|
"cohere",
|
|
"huggingface",
|
|
] # Added 'embeddings' values
|
|
|
|
# Filter param_ranges based on included_params
|
|
if included_params is not None:
|
|
param_ranges = {
|
|
key: val for key, val in param_ranges.items() if key in included_params
|
|
}
|
|
|
|
# Generate all combinations of parameter variants
|
|
keys = param_ranges.keys()
|
|
values = param_ranges.values()
|
|
param_variants = [
|
|
dict(zip(keys, combination)) for combination in itertools.product(*values)
|
|
]
|
|
|
|
logging.info("Param combinations for testing", str(param_variants))
|
|
|
|
return param_variants
|
|
|
|
|
|
|
|
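# Illustrative only (not part of the original script): with the defaults above,
# restricting the grid to a single parameter keeps the combinations small. For example,
# generate_param_variants(included_params=["chunk_size"]) should yield
# [{'chunk_size': 750}, {'chunk_size': 1000}], i.e. the base value plus one 250-step increment.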
async def generate_chatgpt_output(query: str, context: str = None, api_key=None, model_name="gpt-3.5-turbo"):
|
|
"""
|
|
Generate a response from the OpenAI ChatGPT model.
|
|
|
|
Args:
|
|
query (str): The user's query or message.
|
|
context (str, optional): Additional context for the conversation. Defaults to an empty string.
|
|
api_key (str, optional): Your OpenAI API key. If not provided, the globally configured API key will be used.
|
|
model_name (str, optional): The name of the ChatGPT model to use. Defaults to "gpt-3.5-turbo".
|
|
|
|
Returns:
|
|
str: The response generated by the ChatGPT model.
|
|
|
|
Raises:
|
|
Exception: If an error occurs during the API call, an error message is returned for the caller to handle.
|
|
"""
|
|
if not context:
|
|
context = ""
|
|
|
|
messages = [
|
|
{"role": "system", "content": "You are a helpful assistant."},
|
|
{"role": "assistant", "content": context},
|
|
{"role": "user", "content": query},
|
|
]
|
|
|
|
try:
|
|
openai.api_key = api_key if api_key else openai.api_key # Use the provided API key or the one set globally
|
|
|
|
response = openai.ChatCompletion.create(
|
|
model=model_name,
|
|
messages=messages,
|
|
)
|
|
|
|
llm_output = response.choices[0].message.content
|
|
return llm_output
|
|
except Exception as e:
|
|
return f"An error occurred: {e}" # Return the error message for the caller to handle
|
|
|
|
|
|
|
|
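# Illustrative only: the helper is a coroutine, so it has to be awaited, e.g.
#   answer = asyncio.run(generate_chatgpt_output("What is this document about?", context="..."))
# A valid OPENAI_API_KEY must be available (via .env or the api_key argument) for the call to succeed.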
async def eval_test(
    query=None,
    output=None,
    expected_output=None,
    context=None,
    synthetic_test_set=False,
):
    logging.info("Generating chatgpt output")
    result_output = await generate_chatgpt_output(query, str(context))

    logging.info("Moving on")

    if synthetic_test_set:
        test_case = synthetic_test_set
    else:
        test_case = LLMTestCase(
            input=str(query),
            actual_output=str(result_output),
            expected_output=str(expected_output),
            context=[str(context)],
        )
    metric = OverallScoreMetric()

    # If you want to run the test
    test_result = run_test(test_case, metrics=[metric], raise_error=False)

    def test_result_to_dict(test_result):
        return {
            "success": test_result.success,
            "score": test_result.score,
            "metric_name": test_result.metric_name,
            "query": test_result.query,
            "output": test_result.output,
            "expected_output": test_result.expected_output,
            "metadata": test_result.metadata,
            "context": test_result.context,
        }

    test_result_dict = []
    for test in test_result:
        test_result_it = test_result_to_dict(test)
        test_result_dict.append(test_result_it)
    return test_result_dict
    # You can also inspect the test result class
    # print(test_result)
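# For reference: eval_test() returns a list with one dict per evaluated test case, e.g.
# [{"success": True, "score": 0.7, "metric_name": "...", "query": "...", "output": "...",
#   "expected_output": "...", "metadata": None, "context": ["..."]}].
# The exact score and metric name depend on the deepeval version in use.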
def count_files_in_data_folder(data_folder_path=".data"):
|
|
try:
|
|
# Get the list of files in the specified folder
|
|
files = os.listdir(data_folder_path)
|
|
|
|
# Count the number of files
|
|
file_count = len(files)
|
|
|
|
return file_count
|
|
except FileNotFoundError:
|
|
return 0 # Return 0 if the folder does not exist
|
|
except Exception as e:
|
|
print(f"An error occurred: {str(e)}")
|
|
return -1 # Return -1 to indicate an error
|
|
|
|
|
|
|
|
def data_format_route(data_string: str):
    class FormatRoute(Enum):
        """Represents classifier for the data format"""

        PDF = "PDF"
        UNSTRUCTURED_WEB = "UNSTRUCTURED_WEB"
        GITHUB = "GITHUB"
        TEXT = "TEXT"
        CSV = "CSV"
        WIKIPEDIA = "WIKIPEDIA"

    # Convert the input string to lowercase for case-insensitive matching
    data_string = data_string.lower()

    # Mapping of keywords to categories
    keyword_mapping = {
        "pdf": FormatRoute.PDF,
        "web": FormatRoute.UNSTRUCTURED_WEB,
        "github": FormatRoute.GITHUB,
        "text": FormatRoute.TEXT,
        "csv": FormatRoute.CSV,
        "wikipedia": FormatRoute.WIKIPEDIA,
    }

    # Try to match keywords in the data string
    for keyword, category in keyword_mapping.items():
        if keyword in data_string:
            return category.name

    # Return a default category if no match is found
    return FormatRoute.PDF.name
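# Illustrative only: data_format_route(".data/example.pdf") returns "PDF" because the
# keyword "pdf" occurs in the lowercased input; any input without a known keyword
# falls back to "PDF" as well.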
def data_location_route(data_string: str):
    class LocationRoute(Enum):
        """Represents classifier for the data location: local device, URL, or database connection string"""

        DEVICE = "DEVICE"
        URL = "URL"
        DATABASE = "DATABASE"

    # Convert the input string to lowercase for case-insensitive matching
    data_string = data_string.lower()

    # Check for specific patterns in the data string
    if data_string.startswith(".data") or "data" in data_string:
        return LocationRoute.DEVICE.name
    elif data_string.startswith("http://") or data_string.startswith("https://"):
        return LocationRoute.URL.name
    elif "postgres" in data_string or "mysql" in data_string:
        return LocationRoute.DATABASE.name

    # Return a default category if no match is found
    return "Unknown"
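# Illustrative only: data_location_route(".data/example.pdf") returns "DEVICE",
# "https://example.com/page" returns "URL", and a "postgresql://..." connection string
# returns "DATABASE". Note that any input containing "data" is classified as DEVICE first.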
def dynamic_test_manager(context=None):
    from deepeval.dataset import create_evaluation_query_answer_pairs

    # fetch random chunks from the document
    # feed them to the evaluation pipeline
    dataset = create_evaluation_query_answer_pairs(
        openai_api_key=os.environ.get("OPENAI_API_KEY"), context=context, n=10
    )

    return dataset
def generate_letter_uuid(length=8):
    """Generate a random string of uppercase letters with the specified length."""
    letters = string.ascii_uppercase  # A-Z
    return "".join(random.choice(letters) for _ in range(length))
async def start_test(
    data,
    test_set=None,
    user_id=None,
    params=None,
    param_ranges=None,
    param_increments=None,
    metadata=None,
    generate_test_set=False,
    retriever_type: str = None,
):
    """Run a retrieval test over the given documents.

    retriever_type is one of: "llm_context", "single_document_context",
    "multi_document_context", "cognitive_architecture".
    """

    async with session_scope(session=AsyncSessionLocal()) as session:
        job_id = await fetch_job_id(session, user_id=user_id, job_id=None)
        test_set_id = await fetch_test_set_id(session, user_id=user_id, content=str(test_set))
        memory = await Memory.create_memory(
            user_id, session, namespace="SEMANTICMEMORY"
        )
        await memory.add_memory_instance("ExampleMemory")
        existing_user = await Memory.check_existing_user(user_id, session)

        if test_set_id is None:
            test_set_id = str(uuid.uuid4())
            await add_entity(
                session, TestSet(id=test_set_id, user_id=user_id, content=str(test_set))
            )
            analytics.track(user_id, 'TestSet', {
                'id': test_set_id,
                'content': str(test_set)
            })

        data_format = data_format_route(data[0])  # Assume data_format_route is predefined
        logging.info("Data format is %s", data_format)
        data_location = data_location_route(data[0])  # Assume data_location_route is predefined
        logging.info("Data location is %s", data_location)

        if params:
            logging.info("Provided params are %s", str(params))
            test_params = generate_param_variants(
                included_params=params, increments=param_increments, ranges=param_ranges
            )
        else:
            test_params = generate_param_variants(included_params=["chunk_size"])

        logging.info("Here are the test params %s", str(test_params))

        loader_settings = {
            "format": data_format,
            "source": data_location,
            "path": data,
        }
        if job_id is None:
            job_id = str(uuid.uuid4())

        await add_entity(
            session,
            Operation(
                id=job_id,
                user_id=user_id,
                operation_params=str(test_params),
                number_of_files=count_files_in_data_folder(),
                operation_status="RUNNING",
                operation_type=retriever_type,
                test_set_id=test_set_id,
            ),
        )
        analytics.track(user_id, 'Operation', {
            'id': job_id,
            'operation_params': str(test_params),
            'number_of_files': count_files_in_data_folder(),
            'operation_status': "RUNNING",
            'operation_type': retriever_type,
            'test_set_id': test_set_id,
        })

        doc_names = get_document_names(data)
        for doc in doc_names:
            await add_entity(
                session,
                DocsModel(
                    id=str(uuid.uuid4()),
                    operation_id=job_id,
                    doc_name=doc,
                ),
            )
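        # Note: this nested helper intentionally shadows deepeval's run_test inside
        # start_test. For a given parameter combination it spins up a dedicated memory
        # namespace, loads the documents, runs the retrieval queries, and evaluates the
        # answers, returning (test_id, list_of_evaluation_results).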
        async def run_test(
            test, loader_settings, metadata, test_id=None, retriever_type=False
        ):
            if test_id is None:
                test_id = str(generate_letter_uuid()) + "_" + "SEMANTICMEMORY"
            await memory.manage_memory_attributes(existing_user)
            test_class = test_id + "_class"
            await memory.add_dynamic_memory_class(test_id.lower(), test_id)
            dynamic_memory_class = getattr(memory, test_class.lower(), None)
            methods_to_add = ["add_memories", "fetch_memories", "delete_memories"]

            if dynamic_memory_class is not None:
                for method_name in methods_to_add:
                    await memory.add_method_to_class(dynamic_memory_class, method_name)
                    print(f"Memory method {method_name} has been added")
            else:
                print(f"No attribute named {test_class.lower()} in memory.")

            print(f"Trying to access: {test_class.lower()}")
            print("Available memory classes:", await memory.list_memory_classes())
            if test:
                loader_settings.update(test)
            # If the search_type is 'none', fall back to 'hybrid'
            if loader_settings.get('search_type') == 'none':
                loader_settings['search_type'] = 'hybrid'

            test_class = test_id + "_class"
            dynamic_memory_class = getattr(memory, test_class.lower(), None)

            async def run_load_test_element(
                loader_settings=loader_settings,
                metadata=metadata,
                test_id=test_id,
                test_set=test_set,
            ):
                print(f"Trying to access: {test_class.lower()}")
                await memory.dynamic_method_call(
                    dynamic_memory_class,
                    "add_memories",
                    observation="Observation loaded",
                    params=metadata,
                    loader_settings=loader_settings,
                )
                return "Loaded test element"

            async def run_search_element(test_item, test_id, search_type="text"):
                retrieve_action = await memory.dynamic_method_call(
                    dynamic_memory_class,
                    "fetch_memories",
                    observation=str(test_item["question"]),
                    search_type=loader_settings.get('search_type'),
                )
                print(
                    "Here is the test result",
                    str(retrieve_action),
                )
                if loader_settings.get('search_type') == 'bm25':
                    return retrieve_action["data"]["Get"][test_id]
                else:
                    try:
                        return retrieve_action["data"]["Get"][test_id][0]["text"]
                    except (KeyError, IndexError, TypeError):
                        return retrieve_action["data"]["Get"][test_id]

            async def run_eval(test_item, search_result):
                logging.info("Initiated test set evaluation")
                test_eval = await eval_test(
                    query=str(test_item["question"]),
                    expected_output=str(test_item["answer"]),
                    context=str(search_result),
                )
                logging.info("Successfully evaluated test set")
                return test_eval

            async def run_generate_test_set(test_id):
                test_class = test_id + "_class"
                # await memory.add_dynamic_memory_class(test_id.lower(), test_id)
                dynamic_memory_class = getattr(memory, test_class.lower(), None)
                print(dynamic_memory_class)
                retrieve_action = await memory.dynamic_method_call(
                    dynamic_memory_class,
                    "fetch_memories",
                    observation="Generate a short summary of this document",
                    search_type="generative",
                )
                return dynamic_test_manager(retrieve_action)

            test_eval_pipeline = []
            if retriever_type == "llm_context":
                for test_qa in test_set:
                    context = ""
                    logging.info("Loading and evaluating test set for LLM context")
                    test_result = await run_eval(test_qa, context)
                    test_eval_pipeline.append(test_result)
            elif retriever_type == "single_document_context":
                if test_set:
                    logging.info(
                        "Loading and evaluating test set for a single document context"
                    )
                    await run_load_test_element(
                        loader_settings, metadata, test_id, test_set
                    )
                    for test_qa in test_set:
                        result = await run_search_element(test_qa, test_id)
                        test_result = await run_eval(test_qa, result)
                        test_result.append(test)
                        test_eval_pipeline.append(test_result)
                    await memory.dynamic_method_call(
                        dynamic_memory_class, "delete_memories", namespace=test_id
                    )

            if generate_test_set is True:
                synthetic_test_set = await run_generate_test_set(test_id)

            return test_id, test_eval_pipeline
        results = []

        logging.info("Validating the retriever type")
        logging.info("Retriever type: %s", retriever_type)

        if retriever_type == "llm_context":
            logging.info("Retriever type: llm_context")
            test_id, result = await run_test(
                test=None,
                loader_settings=loader_settings,
                metadata=metadata,
                retriever_type=retriever_type,
            )  # No params for this case
            results.append(result)

        elif retriever_type == "single_document_context":
            logging.info("Retriever type: single document context")
            for param in test_params:
                logging.info("Running for chunk size %s", param["chunk_size"])
                test_id, result = await run_test(
                    param, loader_settings, metadata, retriever_type=retriever_type
                )  # Add the params to the result
                # result.append(param)
                results.append(result)

        for b in results:
            logging.info("Loading %s", str(b))
            if retriever_type == "single_document_context":
                for result, chunk in b:
                    logging.info("Loading %s", str(result))
                    await add_entity(
                        session,
                        TestOutput(
                            id=test_id,
                            test_set_id=test_set_id,
                            operation_id=job_id,
                            set_id=str(uuid.uuid4()),
                            user_id=user_id,
                            test_results=result["success"],
                            test_score=str(result["score"]),
                            test_metric_name=result["metric_name"],
                            test_query=result["query"],
                            test_output=result["output"],
                            test_expected_output=str(result["expected_output"]),
                            test_context=result["context"][0],
                            test_params=str(chunk),  # Add params to the database table
                        ),
                    )
                    analytics.track(user_id, 'TestOutput', {
                        'test_set_id': test_set_id,
                        'operation_id': job_id,
                        'set_id': str(uuid.uuid4()),
                        'test_results': result["success"],
                        'test_score': str(result["score"]),
                        'test_metric_name': result["metric_name"],
                        'test_query': result["query"],
                        'test_output': result["output"],
                        'test_expected_output': str(result["expected_output"]),
                        'test_context': result["context"][0],
                        'test_params': str(chunk),
                    })
                    analytics.flush()
            else:
                chunk = "None"
                for result in b:
                    logging.info("Loading %s", str(result))
                    await add_entity(
                        session,
                        TestOutput(
                            id=test_id,
                            test_set_id=test_set_id,
                            operation_id=job_id,
                            set_id=str(uuid.uuid4()),
                            user_id=user_id,
                            test_results=result[0]["success"],
                            test_score=str(result[0]["score"]),
                            test_metric_name=result[0]["metric_name"],
                            test_query=result[0]["query"],
                            test_output=result[0]["output"],
                            test_expected_output=str(result[0]["expected_output"]),
                            test_context=result[0]["context"][0],
                            test_params=str(chunk),  # Add params to the database table
                        ),
                    )
                    analytics.track(user_id, 'TestOutput', {
                        'test_set_id': test_set_id,
                        'operation_id': job_id,
                        'set_id': str(uuid.uuid4()),
                        'test_results': result[0]["success"],
                        'test_score': str(result[0]["score"]),
                        'test_metric_name': result[0]["metric_name"],
                        'test_query': result[0]["query"],
                        'test_output': result[0]["output"],
                        'test_expected_output': str(result[0]["expected_output"]),
                        'test_context': result[0]["context"][0],
                        'test_params': str(chunk),
                    })
                    analytics.flush()

        await update_entity(session, Operation, job_id, "COMPLETED")

    return results
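# For reference: the --test_set JSON file is expected to be a list of question/answer
# pairs, e.g. [{"question": "What is the document about?", "answer": "..."}], since
# start_test reads test_item["question"] and test_item["answer"] for each entry.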
async def main():
    parser = argparse.ArgumentParser(description="Run tests against a document.")
    parser.add_argument("--file", nargs="+", required=True, help="List of file paths to test.")
    parser.add_argument("--test_set", required=True, help="Path to JSON file containing the test set.")
    parser.add_argument("--user_id", required=True, help="User ID.")
    parser.add_argument("--params", nargs="+", help="Parameter names to vary (e.g. chunk_size).")
    parser.add_argument("--param_ranges", required=False, help="Param ranges")
    parser.add_argument("--param_increments", required=False, help="Increment values, e.g. for chunk size")
    parser.add_argument("--metadata", required=True, help="Path to JSON file containing metadata.")
    # parser.add_argument("--generate_test_set", required=False, help="Make a test set.")
    parser.add_argument("--retriever_type", required=False, help="Do a test only within the existing LLM context")
    args = parser.parse_args()

    try:
        with open(args.test_set, "r") as file:
            test_set = json.load(file)
        if not isinstance(test_set, list):  # Expecting a list
            raise TypeError("Parsed test_set JSON is not a list.")
    except Exception as e:
        print(f"Error loading test_set: {str(e)}")
        return

    try:
        with open(args.metadata, "r") as file:
            metadata = json.load(file)
        if not isinstance(metadata, dict):
            raise TypeError("Parsed metadata JSON is not a dictionary.")
    except Exception as e:
        print(f"Error loading metadata: {str(e)}")
        return

    if args.params:
        params = args.params
        if not isinstance(params, list):
            raise TypeError("Parsed params is not a list.")
    else:
        params = None

    logging.info("Args datatype is %s", type(args.file))
    # clean up params here
    await start_test(
        data=args.file,
        test_set=test_set,
        user_id=args.user_id,
        params=params,
        param_ranges=args.param_ranges,
        param_increments=args.param_increments,
        metadata=metadata,
        retriever_type=args.retriever_type,
    )
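# Illustrative invocation (file names and IDs are placeholders, not part of the original script;
# the script name below is assumed, use the actual path of this file):
#   python test_runner.py --file .data/example.pdf --test_set test_set.json \
#       --user_id user123 --metadata metadata.json --params chunk_size \
#       --retriever_type single_document_context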
if __name__ == "__main__":
|
|
asyncio.run(main())
|
|
|