fixes to chunking logic and optimizations

This commit is contained in:
Vasilije 2023-10-28 23:12:42 +02:00
parent 259198d69d
commit 7a07be7d53
4 changed files with 670 additions and 631 deletions

View file

@ -140,10 +140,11 @@ After that, you can run the RAG test manager.
``` ```
python rag_test_manager.py \ python rag_test_manager.py \
--url "https://www.ibiblio.org/ebooks/London/Call%20of%20Wild.pdf" \ --file ".data" \
--test_set "example_data/test_set.json" \ --test_set "example_data/test_set.json" \
--user_id "666" \ --user_id "666" \
--metadata "example_data/metadata.json" --metadata "example_data/metadata.json" \
--retriever_type "single_document_context"
``` ```

1109
level_3/poetry.lock generated

File diff suppressed because it is too large Load diff

View file

@ -78,7 +78,6 @@ async def retrieve_latest_test_case(session, user_id, memory_id):
async def add_entity(session, entity):
    """Stage ``entity`` on the given session and report success.

    The entity is added inside ``session_scope``; no explicit commit is
    issued here because, per the original author's note, the scope context
    manager is expected to take care of it.

    Returns:
        The fixed confirmation string ``"Successfully added entity"``.
    """
    async with session_scope(session) as scoped_session:
        # Only stage the entity; committing is left to session_scope.
        scoped_session.add(entity)
        return "Successfully added entity"
@ -278,8 +277,8 @@ async def eval_test(
test_case = synthetic_test_set test_case = synthetic_test_set
else: else:
test_case = LLMTestCase( test_case = LLMTestCase(
query=query, input=query,
output=result_output, actual_output=result_output,
expected_output=expected_output, expected_output=expected_output,
context=context, context=context,
) )
@ -323,8 +322,22 @@ def count_files_in_data_folder(data_folder_path=".data"):
except Exception as e: except Exception as e:
print(f"An error occurred: {str(e)}") print(f"An error occurred: {str(e)}")
return -1 # Return -1 to indicate an error return -1 # Return -1 to indicate an error
# def data_format_route(data_string: str):
# @ai_classifier
# class FormatRoute(Enum):
# """Represents classifier for the data format"""
#
# PDF = "PDF"
# UNSTRUCTURED_WEB = "UNSTRUCTURED_WEB"
# GITHUB = "GITHUB"
# TEXT = "TEXT"
# CSV = "CSV"
# WIKIPEDIA = "WIKIPEDIA"
#
# return FormatRoute(data_string).name
def data_format_route(data_string: str): def data_format_route(data_string: str):
@ai_classifier
class FormatRoute(Enum): class FormatRoute(Enum):
"""Represents classifier for the data format""" """Represents classifier for the data format"""
@ -335,20 +348,59 @@ def data_format_route(data_string: str):
CSV = "CSV" CSV = "CSV"
WIKIPEDIA = "WIKIPEDIA" WIKIPEDIA = "WIKIPEDIA"
return FormatRoute(data_string).name # Convert the input string to lowercase for case-insensitive matching
data_string = data_string.lower()
# Mapping of keywords to categories
keyword_mapping = {
"pdf": FormatRoute.PDF,
"web": FormatRoute.UNSTRUCTURED_WEB,
"github": FormatRoute.GITHUB,
"text": FormatRoute.TEXT,
"csv": FormatRoute.CSV,
"wikipedia": FormatRoute.WIKIPEDIA
}
# Try to match keywords in the data string
for keyword, category in keyword_mapping.items():
if keyword in data_string:
return category.name
# Return a default category if no match is found
return FormatRoute.PDF.name
# def data_location_route(data_string: str):
# @ai_classifier
# class LocationRoute(Enum):
# """Represents classifier for the data location, if it is device, or database connections string or URL"""
#
# DEVICE = "file_path_starting_with_.data_or_containing_it"
# URL = "url starting with http or https"
# DATABASE = "database_name_like_postgres_or_mysql"
#
# return LocationRoute(data_string).name
def data_location_route(data_string: str):
    """Classify where a data source lives based on its textual reference.

    Matching is case-insensitive and uses simple keyword rules, checked from
    the most specific pattern to the least specific one:

    1. ``http://`` / ``https://`` prefix      -> ``"URL"``
    2. contains ``.data`` (local data folder)  -> ``"DEVICE"``
    3. contains ``postgres`` or ``mysql``      -> ``"DATABASE"``
    4. contains ``data`` anywhere              -> ``"DEVICE"``

    Args:
        data_string: Path, URL or database reference to classify.

    Returns:
        The name of the matching ``LocationRoute`` member, or the string
        ``"Unknown"`` when no rule matches.
    """

    class LocationRoute(Enum):
        """Represents classifier for the data location, if it is device, or database connection string or URL"""

        DEVICE = "DEVICE"
        URL = "URL"
        DATABASE = "DATABASE"

    # Normalize once so every rule below is case-insensitive.
    normalized = data_string.strip().lower()

    # The URL scheme is unambiguous, so test it first; otherwise any URL that
    # merely contains "data" (e.g. ".../data/report.pdf") is taken for a file.
    if normalized.startswith(("http://", "https://")):
        return LocationRoute.URL.name

    # An explicit ".data" folder reference is a local file, even when the
    # file name happens to mention a database engine.
    if ".data" in normalized:
        return LocationRoute.DEVICE.name

    # Database keywords are checked before the loose "data" substring because
    # the word "database" itself contains "data".
    if "postgres" in normalized or "mysql" in normalized:
        return LocationRoute.DATABASE.name

    # Loose fallback kept for backward compatibility with the old rule.
    if "data" in normalized:
        return LocationRoute.DEVICE.name

    # Return a default category if no match is found
    return "Unknown"
def dynamic_test_manager(context=None): def dynamic_test_manager(context=None):
from deepeval.dataset import create_evaluation_query_answer_pairs from deepeval.dataset import create_evaluation_query_answer_pairs
@ -373,7 +425,6 @@ async def start_test(
test_set=None, test_set=None,
user_id=None, user_id=None,
params=None, params=None,
job_id=None,
metadata=None, metadata=None,
generate_test_set=False, generate_test_set=False,
retriever_type: str = None, retriever_type: str = None,
@ -381,6 +432,7 @@ async def start_test(
"""retriever_type = "llm_context, single_document_context, multi_document_context, "cognitive_architecture""" "" """retriever_type = "llm_context, single_document_context, multi_document_context, "cognitive_architecture""" ""
async with session_scope(session=AsyncSessionLocal()) as session: async with session_scope(session=AsyncSessionLocal()) as session:
job_id = ""
job_id = await fetch_job_id(session, user_id=user_id, job_id=job_id) job_id = await fetch_job_id(session, user_id=user_id, job_id=job_id)
test_set_id = await fetch_test_set_id(session, user_id=user_id, content=str(test_set)) test_set_id = await fetch_test_set_id(session, user_id=user_id, content=str(test_set))
memory = await Memory.create_memory( memory = await Memory.create_memory(
@ -416,7 +468,7 @@ async def start_test(
) )
test_params = generate_param_variants(included_params=params) test_params = generate_param_variants(included_params=params)
print("Here are the test params", str(test_params)) logging.info("Here are the test params %s", str(test_params))
loader_settings = { loader_settings = {
"format": f"{data_format}", "format": f"{data_format}",
@ -522,12 +574,13 @@ async def start_test(
test_eval_pipeline = [] test_eval_pipeline = []
if retriever_type == "llm_context": if retriever_type == "llm_context":
for test_qa in test_set: for test_qa in test_set:
context = "" context = ""
logging.info("Loading and evaluating test set for LLM context") logging.info("Loading and evaluating test set for LLM context")
test_result = await run_eval(test_qa, context) test_result = await run_eval(test_qa, context)
test_eval_pipeline.append(test_result) test_eval_pipeline.append(test_result)
elif retriever_type == "single_document_context": elif retriever_type == "single_document_context":
if test_set: if test_set:
@ -556,7 +609,12 @@ async def start_test(
results = [] results = []
logging.info("Validating the retriever type")
logging.info("Retriever type: %s", retriever_type)
if retriever_type == "llm_context": if retriever_type == "llm_context":
logging.info("Retriever type: llm_context")
test_id, result = await run_test( test_id, result = await run_test(
test=None, test=None,
loader_settings=loader_settings, loader_settings=loader_settings,
@ -566,6 +624,7 @@ async def start_test(
results.append([result, "No params"]) results.append([result, "No params"])
elif retriever_type == "single_document_context": elif retriever_type == "single_document_context":
logging.info("Retriever type: single document context")
for param in test_params: for param in test_params:
logging.info("Running for chunk size %s", param["chunk_size"]) logging.info("Running for chunk size %s", param["chunk_size"])
test_id, result = await run_test( test_id, result = await run_test(
@ -636,55 +695,55 @@ async def main():
] ]
# "https://www.ibiblio.org/ebooks/London/Call%20of%20Wild.pdf" # "https://www.ibiblio.org/ebooks/London/Call%20of%20Wild.pdf"
# http://public-library.uk/ebooks/59/83.pdf # http://public-library.uk/ebooks/59/83.pdf
result = await start_test( # result = await start_test(
".data/3ZCCCW.pdf", # ".data/3ZCCCW.pdf",
test_set=test_set, # test_set=test_set,
user_id="677", # user_id="677",
params=["chunk_size", "search_type"], # params=["chunk_size", "search_type"],
metadata=metadata, # metadata=metadata,
retriever_type="single_document_context", # retriever_type="single_document_context",
) # )
#
# parser = argparse.ArgumentParser(description="Run tests against a document.") parser = argparse.ArgumentParser(description="Run tests against a document.")
# parser.add_argument("--url", required=True, help="URL of the document to test.") parser.add_argument("--file", required=True, help="URL or location of the document to test.")
# parser.add_argument("--test_set", required=True, help="Path to JSON file containing the test set.") parser.add_argument("--test_set", required=True, help="Path to JSON file containing the test set.")
# parser.add_argument("--user_id", required=True, help="User ID.") parser.add_argument("--user_id", required=True, help="User ID.")
# parser.add_argument("--params", help="Additional parameters in JSON format.") parser.add_argument("--params", help="Additional parameters in JSON format.")
# parser.add_argument("--metadata", required=True, help="Path to JSON file containing metadata.") parser.add_argument("--metadata", required=True, help="Path to JSON file containing metadata.")
# parser.add_argument("--generate_test_set", required=True, help="Make a test set.") # parser.add_argument("--generate_test_set", required=False, help="Make a test set.")
# parser.add_argument("--only_llm_context", required=True, help="Do a test only within the existing LLM context") parser.add_argument("--retriever_type", required=False, help="Do a test only within the existing LLM context")
# args = parser.parse_args() args = parser.parse_args()
#
# try: try:
# with open(args.test_set, "r") as file: with open(args.test_set, "r") as file:
# test_set = json.load(file) test_set = json.load(file)
# if not isinstance(test_set, list): # Expecting a list if not isinstance(test_set, list): # Expecting a list
# raise TypeError("Parsed test_set JSON is not a list.") raise TypeError("Parsed test_set JSON is not a list.")
# except Exception as e: except Exception as e:
# print(f"Error loading test_set: {str(e)}") print(f"Error loading test_set: {str(e)}")
# return return
#
# try: try:
# with open(args.metadata, "r") as file: with open(args.metadata, "r") as file:
# metadata = json.load(file) metadata = json.load(file)
# if not isinstance(metadata, dict): if not isinstance(metadata, dict):
# raise TypeError("Parsed metadata JSON is not a dictionary.") raise TypeError("Parsed metadata JSON is not a dictionary.")
# except Exception as e: except Exception as e:
# print(f"Error loading metadata: {str(e)}") print(f"Error loading metadata: {str(e)}")
# return return
#
# if args.params: if args.params:
# try: try:
# params = json.loads(args.params) params = json.loads(args.params)
# if not isinstance(params, dict): if not isinstance(params, dict):
# raise TypeError("Parsed params JSON is not a dictionary.") raise TypeError("Parsed params JSON is not a dictionary.")
# except json.JSONDecodeError as e: except json.JSONDecodeError as e:
# print(f"Error parsing params: {str(e)}") print(f"Error parsing params: {str(e)}")
# return return
# else: else:
# params = None params = None
# #clean up params here #clean up params here
# await start_test(args.url, test_set, args.user_id, params=None, metadata=metadata) await start_test(data=args.file, test_set=test_set, user_id= args.user_id, params= params, metadata =metadata, retriever_type=args.retriever_type)
if __name__ == "__main__": if __name__ == "__main__":

View file

@ -45,7 +45,7 @@ async def _document_loader( observation: str, loader_settings: dict):
# pages = documents.load_and_split() # pages = documents.load_and_split()
return documents return documents
elif document_format == "text": elif document_format == "TEXT":
pages = chunk_data(chunk_strategy= loader_strategy, source_data=observation, chunk_size=chunk_size, chunk_overlap=chunk_overlap) pages = chunk_data(chunk_strategy= loader_strategy, source_data=observation, chunk_size=chunk_size, chunk_overlap=chunk_overlap)
return pages return pages