Merge pull request #30 from topoteretes/enable_cmd_runner

Enable cmd runner
This commit is contained in:
Vasilije 2023-10-29 00:19:13 +02:00 committed by GitHub
commit 6a0e5674e5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 772 additions and 663 deletions

View file

@ -140,10 +140,11 @@ After that, you can run the RAG test manager.
``` ```
python rag_test_manager.py \ python rag_test_manager.py \
--url "https://www.ibiblio.org/ebooks/London/Call%20of%20Wild.pdf" \ --file ".data" \
--test_set "example_data/test_set.json" \ --test_set "example_data/test_set.json" \
--user_id "666" \ --user_id "666" \
--metadata "example_data/metadata.json" --metadata "example_data/metadata.json" \
--retriever_type "single_document_context"
``` ```

View file

@ -7,8 +7,8 @@ from fastapi import FastAPI
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from pydantic import BaseModel from pydantic import BaseModel
from level_3.database.database import AsyncSessionLocal from database.database import AsyncSessionLocal
from level_3.database.database_crud import session_scope from database.database_crud import session_scope
from vectorstore_manager import Memory from vectorstore_manager import Memory
from dotenv import load_dotenv from dotenv import load_dotenv
@ -202,6 +202,24 @@ for memory_type in memory_list:
memory_factory(memory_type) memory_factory(memory_type)
@app.post("/rag-test/rag_test_run", response_model=dict)
async def rag_test_run(
    payload: Payload,
    # files: List[UploadFile] = File(...),
):
    """Kick off a RAG test run described by the request payload.

    The payload carries data location, test set, user id, params,
    metadata and retriever type; the run result is returned as JSON.
    On any failure a 503 with the error message is returned instead.
    """
    try:
        # Imported lazily so the endpoint module loads even if the
        # test-manager module has heavy/optional dependencies.
        from rag_test_manager import start_test

        logging.info(" Running RAG Test ")
        body = payload.payload
        result = await start_test(
            data=body["data"],
            test_set=body["test_set"],
            user_id=body["user_id"],
            params=body["params"],
            metadata=body["metadata"],
            retriever_type=body["retriever_type"],
        )
        return JSONResponse(content={"response": result}, status_code=200)
    except Exception as e:
        return JSONResponse(
            content={"response": {"error": str(e)}}, status_code=503
        )
# @app.get("/available-buffer-actions", response_model=dict) # @app.get("/available-buffer-actions", response_model=dict)
# async def available_buffer_actions( # async def available_buffer_actions(
# payload: Payload, # payload: Payload,

View file

@ -13,19 +13,19 @@ services:
# networks: # networks:
# - promethai_mem_backend # - promethai_mem_backend
# promethai_mem: promethai_mem:
# networks: networks:
# - promethai_mem_backend - promethai_mem_backend
# build: build:
# context: ./ context: ./
# volumes: volumes:
# - "./:/app" - "./:/app"
# environment: environment:
# - HOST=0.0.0.0 - HOST=0.0.0.0
# profiles: ["exclude-from-up"] profiles: ["exclude-from-up"]
# ports: ports:
# - 8000:8000 - 8000:8000
# - 443:443 - 443:443
postgres: postgres:
image: postgres image: postgres
@ -40,23 +40,23 @@ services:
ports: ports:
- "5432:5432" - "5432:5432"
superset: # superset:
platform: linux/amd64 # platform: linux/amd64
build: # build:
context: ./superset # context: ./superset
dockerfile: Dockerfile # dockerfile: Dockerfile
container_name: superset # container_name: superset
environment: # environment:
- ADMIN_USERNAME=admin # - ADMIN_USERNAME=admin
- ADMIN_EMAIL=vasilije@topoteretes.com # - ADMIN_EMAIL=vasilije@topoteretes.com
- ADMIN_PASSWORD=admin # - ADMIN_PASSWORD=admin
- POSTGRES_USER=bla # - POSTGRES_USER=bla
- POSTGRES_PASSWORD=bla # - POSTGRES_PASSWORD=bla
- POSTGRES_DB=bubu # - POSTGRES_DB=bubu
networks: # networks:
- promethai_mem_backend # - promethai_mem_backend
ports: # ports:
- '8088:8088' # - '8088:8088'
networks: networks:
promethai_mem_backend: promethai_mem_backend:

18
level_3/models/docs.py Normal file
View file

@ -0,0 +1,18 @@
from datetime import datetime
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
from sqlalchemy.orm import relationship
import os
import sys
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from database.database import Base
class DocsModel(Base):
    """SQLAlchemy model for a document processed during an operation run."""

    __tablename__ = 'docs'

    # Primary key; callers supply it (e.g. a UUID string) — no autoincrement.
    id = Column(String, primary_key=True)
    # Links the document to the operation (job) that ingested it.
    operation_id = Column(String, ForeignKey('operations.id'), index=True)
    # Human-readable file name of the document; optional.
    doc_name = Column(String, nullable=True)
    # Set once on insert.
    created_at = Column(DateTime, default=datetime.utcnow)
    # Refreshed on every UPDATE; NULL until the first update.
    updated_at = Column(DateTime, onupdate=datetime.utcnow)
    # Reverse side must be declared as Operation.docs — TODO confirm in the
    # Operation model (not visible here).
    operation = relationship("Operation", back_populates="docs")

1109
level_3/poetry.lock generated

File diff suppressed because it is too large Load diff

View file

@ -73,12 +73,45 @@ async def retrieve_latest_test_case(session, user_id, memory_id):
f"An error occurred while retrieving the latest test case: {str(e)}" f"An error occurred while retrieving the latest test case: {str(e)}"
) )
return None return None
def get_document_names(doc_input):
    """Resolve *doc_input* into a list of document file names.

    Args:
        doc_input (str): A folder path, a single document file path, or a
            bare document name (e.g. ``"example.docx"``).

    Returns:
        list: For a folder, the names of all regular files directly inside
        it; for a file path, its basename; for any other string, the string
        itself; for non-string input, an empty list.

    Example usage:
        - Folder path: get_document_names(".data")
        - Single document file path: get_document_names(".data/example.pdf")
        - Document name provided as a string: get_document_names("example.docx")
    """
    # Guard non-string input up front: os.path.isdir(None) (or an int)
    # raises TypeError, so the original trailing else-branch that was meant
    # to return [] for invalid input could never be reached.
    if not isinstance(doc_input, str):
        return []
    if os.path.isdir(doc_input):
        # doc_input is a folder: collect every regular file directly inside.
        return [
            filename
            for filename in os.listdir(doc_input)
            if os.path.isfile(os.path.join(doc_input, filename))
        ]
    if os.path.isfile(doc_input):
        # doc_input is a single document file.
        return [os.path.basename(doc_input)]
    # doc_input is a document name provided as a string.
    return [doc_input]
async def add_entity(session, entity): async def add_entity(session, entity):
async with session_scope(session) as s: # Use your async session_scope async with session_scope(session) as s: # Use your async session_scope
s.add(entity) # No need to commit; session_scope takes care of it s.add(entity) # No need to commit; session_scope takes care of it
s.commit()
return "Successfully added entity" return "Successfully added entity"
@ -278,8 +311,8 @@ async def eval_test(
test_case = synthetic_test_set test_case = synthetic_test_set
else: else:
test_case = LLMTestCase( test_case = LLMTestCase(
query=query, input=query,
output=result_output, actual_output=result_output,
expected_output=expected_output, expected_output=expected_output,
context=context, context=context,
) )
@ -323,8 +356,22 @@ def count_files_in_data_folder(data_folder_path=".data"):
except Exception as e: except Exception as e:
print(f"An error occurred: {str(e)}") print(f"An error occurred: {str(e)}")
return -1 # Return -1 to indicate an error return -1 # Return -1 to indicate an error
# def data_format_route(data_string: str):
# @ai_classifier
# class FormatRoute(Enum):
# """Represents classifier for the data format"""
#
# PDF = "PDF"
# UNSTRUCTURED_WEB = "UNSTRUCTURED_WEB"
# GITHUB = "GITHUB"
# TEXT = "TEXT"
# CSV = "CSV"
# WIKIPEDIA = "WIKIPEDIA"
#
# return FormatRoute(data_string).name
def data_format_route(data_string: str): def data_format_route(data_string: str):
@ai_classifier
class FormatRoute(Enum): class FormatRoute(Enum):
"""Represents classifier for the data format""" """Represents classifier for the data format"""
@ -335,20 +382,48 @@ def data_format_route(data_string: str):
CSV = "CSV" CSV = "CSV"
WIKIPEDIA = "WIKIPEDIA" WIKIPEDIA = "WIKIPEDIA"
return FormatRoute(data_string).name # Convert the input string to lowercase for case-insensitive matching
data_string = data_string.lower()
# Mapping of keywords to categories
keyword_mapping = {
"pdf": FormatRoute.PDF,
"web": FormatRoute.UNSTRUCTURED_WEB,
"github": FormatRoute.GITHUB,
"text": FormatRoute.TEXT,
"csv": FormatRoute.CSV,
"wikipedia": FormatRoute.WIKIPEDIA
}
# Try to match keywords in the data string
for keyword, category in keyword_mapping.items():
if keyword in data_string:
return category.name
# Return a default category if no match is found
return FormatRoute.PDF.name
def data_location_route(data_string: str): def data_location_route(data_string: str):
@ai_classifier
class LocationRoute(Enum): class LocationRoute(Enum):
"""Represents classifier for the data location, if it is device, or database connections string or URL""" """Represents classifier for the data location, if it is device, or database connection string or URL"""
DEVICE = "file_path_starting_with_.data_or_containing_it" DEVICE = "DEVICE"
# URL = "url starting with http or https" URL = "URL"
DATABASE = "database_name_like_postgres_or_mysql" DATABASE = "DATABASE"
return LocationRoute(data_string).name # Convert the input string to lowercase for case-insensitive matching
data_string = data_string.lower()
# Check for specific patterns in the data string
if data_string.startswith(".data") or "data" in data_string:
return LocationRoute.DEVICE.name
elif data_string.startswith("http://") or data_string.startswith("https://"):
return LocationRoute.URL.name
elif "postgres" in data_string or "mysql" in data_string:
return LocationRoute.DATABASE.name
# Return a default category if no match is found
return "Unknown"
def dynamic_test_manager(context=None): def dynamic_test_manager(context=None):
from deepeval.dataset import create_evaluation_query_answer_pairs from deepeval.dataset import create_evaluation_query_answer_pairs
@ -373,7 +448,6 @@ async def start_test(
test_set=None, test_set=None,
user_id=None, user_id=None,
params=None, params=None,
job_id=None,
metadata=None, metadata=None,
generate_test_set=False, generate_test_set=False,
retriever_type: str = None, retriever_type: str = None,
@ -381,6 +455,7 @@ async def start_test(
"""retriever_type = "llm_context, single_document_context, multi_document_context, "cognitive_architecture""" "" """retriever_type = "llm_context, single_document_context, multi_document_context, "cognitive_architecture""" ""
async with session_scope(session=AsyncSessionLocal()) as session: async with session_scope(session=AsyncSessionLocal()) as session:
job_id = ""
job_id = await fetch_job_id(session, user_id=user_id, job_id=job_id) job_id = await fetch_job_id(session, user_id=user_id, job_id=job_id)
test_set_id = await fetch_test_set_id(session, user_id=user_id, content=str(test_set)) test_set_id = await fetch_test_set_id(session, user_id=user_id, content=str(test_set))
memory = await Memory.create_memory( memory = await Memory.create_memory(
@ -416,7 +491,7 @@ async def start_test(
) )
test_params = generate_param_variants(included_params=params) test_params = generate_param_variants(included_params=params)
print("Here are the test params", str(test_params)) logging.info("Here are the test params %s", str(test_params))
loader_settings = { loader_settings = {
"format": f"{data_format}", "format": f"{data_format}",
@ -437,6 +512,17 @@ async def start_test(
test_set_id=test_set_id, test_set_id=test_set_id,
), ),
) )
doc_names = get_document_names(data)
for doc in doc_names:
await add_entity(
session,
Docs(
id=str(uuid.uuid4()),
operation_id=job_id,
doc_name = doc
)
)
async def run_test( async def run_test(
test, loader_settings, metadata, test_id=None, retriever_type=False test, loader_settings, metadata, test_id=None, retriever_type=False
@ -522,12 +608,13 @@ async def start_test(
test_eval_pipeline = [] test_eval_pipeline = []
if retriever_type == "llm_context": if retriever_type == "llm_context":
for test_qa in test_set: for test_qa in test_set:
context = "" context = ""
logging.info("Loading and evaluating test set for LLM context") logging.info("Loading and evaluating test set for LLM context")
test_result = await run_eval(test_qa, context) test_result = await run_eval(test_qa, context)
test_eval_pipeline.append(test_result) test_eval_pipeline.append(test_result)
elif retriever_type == "single_document_context": elif retriever_type == "single_document_context":
if test_set: if test_set:
@ -556,7 +643,12 @@ async def start_test(
results = [] results = []
logging.info("Validating the retriever type")
logging.info("Retriever type: %s", retriever_type)
if retriever_type == "llm_context": if retriever_type == "llm_context":
logging.info("Retriever type: llm_context")
test_id, result = await run_test( test_id, result = await run_test(
test=None, test=None,
loader_settings=loader_settings, loader_settings=loader_settings,
@ -566,6 +658,7 @@ async def start_test(
results.append([result, "No params"]) results.append([result, "No params"])
elif retriever_type == "single_document_context": elif retriever_type == "single_document_context":
logging.info("Retriever type: single document context")
for param in test_params: for param in test_params:
logging.info("Running for chunk size %s", param["chunk_size"]) logging.info("Running for chunk size %s", param["chunk_size"])
test_id, result = await run_test( test_id, result = await run_test(
@ -636,55 +729,55 @@ async def main():
] ]
# "https://www.ibiblio.org/ebooks/London/Call%20of%20Wild.pdf" # "https://www.ibiblio.org/ebooks/London/Call%20of%20Wild.pdf"
# http://public-library.uk/ebooks/59/83.pdf # http://public-library.uk/ebooks/59/83.pdf
result = await start_test( # result = await start_test(
".data/3ZCCCW.pdf", # ".data/3ZCCCW.pdf",
test_set=test_set, # test_set=test_set,
user_id="677", # user_id="677",
params=["chunk_size", "search_type"], # params=["chunk_size", "search_type"],
metadata=metadata, # metadata=metadata,
retriever_type="single_document_context", # retriever_type="single_document_context",
) # )
#
# parser = argparse.ArgumentParser(description="Run tests against a document.") parser = argparse.ArgumentParser(description="Run tests against a document.")
# parser.add_argument("--url", required=True, help="URL of the document to test.") parser.add_argument("--file", required=True, help="URL or location of the document to test.")
# parser.add_argument("--test_set", required=True, help="Path to JSON file containing the test set.") parser.add_argument("--test_set", required=True, help="Path to JSON file containing the test set.")
# parser.add_argument("--user_id", required=True, help="User ID.") parser.add_argument("--user_id", required=True, help="User ID.")
# parser.add_argument("--params", help="Additional parameters in JSON format.") parser.add_argument("--params", help="Additional parameters in JSON format.")
# parser.add_argument("--metadata", required=True, help="Path to JSON file containing metadata.") parser.add_argument("--metadata", required=True, help="Path to JSON file containing metadata.")
# parser.add_argument("--generate_test_set", required=True, help="Make a test set.") # parser.add_argument("--generate_test_set", required=False, help="Make a test set.")
# parser.add_argument("--only_llm_context", required=True, help="Do a test only within the existing LLM context") parser.add_argument("--retriever_type", required=False, help="Do a test only within the existing LLM context")
# args = parser.parse_args() args = parser.parse_args()
#
# try: try:
# with open(args.test_set, "r") as file: with open(args.test_set, "r") as file:
# test_set = json.load(file) test_set = json.load(file)
# if not isinstance(test_set, list): # Expecting a list if not isinstance(test_set, list): # Expecting a list
# raise TypeError("Parsed test_set JSON is not a list.") raise TypeError("Parsed test_set JSON is not a list.")
# except Exception as e: except Exception as e:
# print(f"Error loading test_set: {str(e)}") print(f"Error loading test_set: {str(e)}")
# return return
#
# try: try:
# with open(args.metadata, "r") as file: with open(args.metadata, "r") as file:
# metadata = json.load(file) metadata = json.load(file)
# if not isinstance(metadata, dict): if not isinstance(metadata, dict):
# raise TypeError("Parsed metadata JSON is not a dictionary.") raise TypeError("Parsed metadata JSON is not a dictionary.")
# except Exception as e: except Exception as e:
# print(f"Error loading metadata: {str(e)}") print(f"Error loading metadata: {str(e)}")
# return return
#
# if args.params: if args.params:
# try: try:
# params = json.loads(args.params) params = json.loads(args.params)
# if not isinstance(params, dict): if not isinstance(params, dict):
# raise TypeError("Parsed params JSON is not a dictionary.") raise TypeError("Parsed params JSON is not a dictionary.")
# except json.JSONDecodeError as e: except json.JSONDecodeError as e:
# print(f"Error parsing params: {str(e)}") print(f"Error parsing params: {str(e)}")
# return return
# else: else:
# params = None params = None
# #clean up params here #clean up params here
# await start_test(args.url, test_set, args.user_id, params=None, metadata=metadata) await start_test(data=args.file, test_set=test_set, user_id= args.user_id, params= params, metadata =metadata, retriever_type=args.retriever_type)
if __name__ == "__main__": if __name__ == "__main__":

View file

@ -45,7 +45,7 @@ async def _document_loader( observation: str, loader_settings: dict):
# pages = documents.load_and_split() # pages = documents.load_and_split()
return documents return documents
elif document_format == "text": elif document_format == "TEXT":
pages = chunk_data(chunk_strategy= loader_strategy, source_data=observation, chunk_size=chunk_size, chunk_overlap=chunk_overlap) pages = chunk_data(chunk_strategy= loader_strategy, source_data=observation, chunk_size=chunk_size, chunk_overlap=chunk_overlap)
return pages return pages