fixes to chunking logic and optimizations

Vasilije 2023-10-28 23:12:42 +02:00
parent 259198d69d
commit 7a07be7d53
4 changed files with 670 additions and 631 deletions

View file

@@ -140,10 +140,11 @@ After that, you can run the RAG test manager.
```
python rag_test_manager.py \
--url "https://www.ibiblio.org/ebooks/London/Call%20of%20Wild.pdf" \
--file ".data" \
--test_set "example_data/test_set.json" \
--user_id "666" \
--metadata "example_data/metadata.json"
--metadata "example_data/metadata.json" \
--retriever_type "single_document_context"
```
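Per the `start_test` docstring later in this diff, valid `--retriever_type` values are `llm_context`, `single_document_context`, `multi_document_context`, and `cognitive_architecture`.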

level_3/poetry.lock generated (1109 lines changed)

File diff suppressed because it is too large

View file

@@ -78,7 +78,6 @@ async def retrieve_latest_test_case(session, user_id, memory_id):
async def add_entity(session, entity):
async with session_scope(session) as s: # Use your async session_scope
s.add(entity) # No need to commit; session_scope takes care of it
s.commit()
return "Successfully added entity"
@@ -278,8 +277,8 @@ async def eval_test(
test_case = synthetic_test_set
else:
test_case = LLMTestCase(
query=query,
output=result_output,
input=query,
actual_output=result_output,
expected_output=expected_output,
context=context,
)
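This change mirrors deepeval's renamed `LLMTestCase` fields (`query` → `input`, `output` → `actual_output`). A self-contained example, assuming a deepeval release of this era; note that newer releases expect `context` to be a list of strings:

```python
from deepeval.test_case import LLMTestCase

# Hypothetical QA pair drawn from the tested document (The Call of the Wild).
test_case = LLMTestCase(
    input="Where does Buck live at the start of the novel?",
    actual_output="Buck lives at Judge Miller's place in the Santa Clara Valley.",
    expected_output="At Judge Miller's estate in the Santa Clara Valley.",
    context=["Buck lived at a big house in the sun-kissed Santa Clara Valley."],
)
```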
@@ -323,8 +322,22 @@ def count_files_in_data_folder(data_folder_path=".data"):
except Exception as e:
print(f"An error occurred: {str(e)}")
return -1 # Return -1 to indicate an error
# def data_format_route(data_string: str):
# @ai_classifier
# class FormatRoute(Enum):
# """Represents classifier for the data format"""
#
# PDF = "PDF"
# UNSTRUCTURED_WEB = "UNSTRUCTURED_WEB"
# GITHUB = "GITHUB"
# TEXT = "TEXT"
# CSV = "CSV"
# WIKIPEDIA = "WIKIPEDIA"
#
# return FormatRoute(data_string).name
def data_format_route(data_string: str):
@ai_classifier
class FormatRoute(Enum):
"""Represents classifier for the data format"""
@@ -335,20 +348,59 @@ def data_format_route(data_string: str):
CSV = "CSV"
WIKIPEDIA = "WIKIPEDIA"
return FormatRoute(data_string).name
# Convert the input string to lowercase for case-insensitive matching
data_string = data_string.lower()
# Mapping of keywords to categories
keyword_mapping = {
"pdf": FormatRoute.PDF,
"web": FormatRoute.UNSTRUCTURED_WEB,
"github": FormatRoute.GITHUB,
"text": FormatRoute.TEXT,
"csv": FormatRoute.CSV,
"wikipedia": FormatRoute.WIKIPEDIA
}
# Try to match keywords in the data string
for keyword, category in keyword_mapping.items():
if keyword in data_string:
return category.name
# Return a default category if no match is found
return FormatRoute.PDF.name
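A quick check of the keyword routing with hypothetical inputs (the first matching keyword wins, in insertion order; unmatched strings fall back to `PDF`):

```python
assert data_format_route("example_data/file.pdf") == "PDF"
assert data_format_route("https://github.com/user/repo") == "GITHUB"
assert data_format_route("plain text file") == "TEXT"   # matches the "text" keyword
assert data_format_route("no keywords here") == "PDF"   # default fallback
```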
# def data_location_route(data_string: str):
# @ai_classifier
# class LocationRoute(Enum):
# """Represents classifier for the data location, if it is device, or database connections string or URL"""
#
# DEVICE = "file_path_starting_with_.data_or_containing_it"
# URL = "url starting with http or https"
# DATABASE = "database_name_like_postgres_or_mysql"
#
# return LocationRoute(data_string).name
def data_location_route(data_string: str):
@ai_classifier
class LocationRoute(Enum):
"""Represents classifier for the data location, if it is device, or database connections string or URL"""
"""Represents classifier for the data location, if it is device, or database connection string or URL"""
DEVICE = "file_path_starting_with_.data_or_containing_it"
# URL = "url starting with http or https"
DATABASE = "database_name_like_postgres_or_mysql"
DEVICE = "DEVICE"
URL = "URL"
DATABASE = "DATABASE"
return LocationRoute(data_string).name
# Convert the input string to lowercase for case-insensitive matching
data_string = data_string.lower()
# Check for specific patterns in the data string.
# URL and database checks run first so that a URL or connection string
# containing the substring "data" is not misrouted to DEVICE.
if data_string.startswith("http://") or data_string.startswith("https://"):
return LocationRoute.URL.name
elif "postgres" in data_string or "mysql" in data_string:
return LocationRoute.DATABASE.name
elif data_string.startswith(".data") or "data" in data_string:
return LocationRoute.DEVICE.name
# Return a default category if no match is found
return "Unknown"
def dynamic_test_manager(context=None):
from deepeval.dataset import create_evaluation_query_answer_pairs
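For reference, a sketch of calling that helper, assuming the deepeval 0.x-era signature (an API key, a seed context string, and a pair count); verify against the installed version:

```python
import os
from deepeval.dataset import create_evaluation_query_answer_pairs

# Assumed signature: generates n synthetic query/answer pairs from the context.
dataset = create_evaluation_query_answer_pairs(
    openai_api_key=os.environ["OPENAI_API_KEY"],
    context="Buck is a dog who is stolen from his home and sold as a sled dog.",
    n=3,
)
```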
@@ -373,7 +425,6 @@ async def start_test(
test_set=None,
user_id=None,
params=None,
job_id=None,
metadata=None,
generate_test_set=False,
retriever_type: str = None,
@@ -381,6 +432,7 @@ async def start_test(
"""retriever_type = "llm_context, single_document_context, multi_document_context, "cognitive_architecture""" ""
async with session_scope(session=AsyncSessionLocal()) as session:
job_id = ""
job_id = await fetch_job_id(session, user_id=user_id, job_id=job_id)
test_set_id = await fetch_test_set_id(session, user_id=user_id, content=str(test_set))
memory = await Memory.create_memory(
@@ -416,7 +468,7 @@ async def start_test(
)
test_params = generate_param_variants(included_params=params)
print("Here are the test params", str(test_params))
logging.info("Here are the test params %s", str(test_params))
loader_settings = {
"format": f"{data_format}",
@@ -522,12 +574,13 @@ async def start_test(
test_eval_pipeline = []
if retriever_type == "llm_context":
for test_qa in test_set:
context = ""
logging.info("Loading and evaluating test set for LLM context")
test_result = await run_eval(test_qa, context)
test_eval_pipeline.append(test_result)
elif retriever_type == "single_document_context":
if test_set:
@@ -556,7 +609,12 @@ async def start_test(
results = []
logging.info("Validating the retriever type")
logging.info("Retriever type: %s", retriever_type)
if retriever_type == "llm_context":
logging.info("Retriever type: llm_context")
test_id, result = await run_test(
test=None,
loader_settings=loader_settings,
@@ -566,6 +624,7 @@ async def start_test(
results.append([result, "No params"])
elif retriever_type == "single_document_context":
logging.info("Retriever type: single document context")
for param in test_params:
logging.info("Running for chunk size %s", param["chunk_size"])
test_id, result = await run_test(
@@ -636,55 +695,55 @@ async def main():
]
# "https://www.ibiblio.org/ebooks/London/Call%20of%20Wild.pdf"
# http://public-library.uk/ebooks/59/83.pdf
result = await start_test(
".data/3ZCCCW.pdf",
test_set=test_set,
user_id="677",
params=["chunk_size", "search_type"],
metadata=metadata,
retriever_type="single_document_context",
)
#
# parser = argparse.ArgumentParser(description="Run tests against a document.")
# parser.add_argument("--url", required=True, help="URL of the document to test.")
# parser.add_argument("--test_set", required=True, help="Path to JSON file containing the test set.")
# parser.add_argument("--user_id", required=True, help="User ID.")
# parser.add_argument("--params", help="Additional parameters in JSON format.")
# parser.add_argument("--metadata", required=True, help="Path to JSON file containing metadata.")
# parser.add_argument("--generate_test_set", required=True, help="Make a test set.")
# parser.add_argument("--only_llm_context", required=True, help="Do a test only within the existing LLM context")
# args = parser.parse_args()
#
# try:
# with open(args.test_set, "r") as file:
# test_set = json.load(file)
# if not isinstance(test_set, list): # Expecting a list
# raise TypeError("Parsed test_set JSON is not a list.")
# except Exception as e:
# print(f"Error loading test_set: {str(e)}")
# return
#
# try:
# with open(args.metadata, "r") as file:
# metadata = json.load(file)
# if not isinstance(metadata, dict):
# raise TypeError("Parsed metadata JSON is not a dictionary.")
# except Exception as e:
# print(f"Error loading metadata: {str(e)}")
# return
#
# if args.params:
# try:
# params = json.loads(args.params)
# if not isinstance(params, dict):
# raise TypeError("Parsed params JSON is not a dictionary.")
# except json.JSONDecodeError as e:
# print(f"Error parsing params: {str(e)}")
# return
# else:
# params = None
# #clean up params here
# await start_test(args.url, test_set, args.user_id, params=None, metadata=metadata)
# result = await start_test(
# ".data/3ZCCCW.pdf",
# test_set=test_set,
# user_id="677",
# params=["chunk_size", "search_type"],
# metadata=metadata,
# retriever_type="single_document_context",
# )
parser = argparse.ArgumentParser(description="Run tests against a document.")
parser.add_argument("--file", required=True, help="URL or location of the document to test.")
parser.add_argument("--test_set", required=True, help="Path to JSON file containing the test set.")
parser.add_argument("--user_id", required=True, help="User ID.")
parser.add_argument("--params", help="Additional parameters in JSON format.")
parser.add_argument("--metadata", required=True, help="Path to JSON file containing metadata.")
# parser.add_argument("--generate_test_set", required=False, help="Make a test set.")
parser.add_argument("--retriever_type", required=False, help="Do a test only within the existing LLM context")
args = parser.parse_args()
try:
with open(args.test_set, "r") as file:
test_set = json.load(file)
if not isinstance(test_set, list): # Expecting a list
raise TypeError("Parsed test_set JSON is not a list.")
except Exception as e:
print(f"Error loading test_set: {str(e)}")
return
try:
with open(args.metadata, "r") as file:
metadata = json.load(file)
if not isinstance(metadata, dict):
raise TypeError("Parsed metadata JSON is not a dictionary.")
except Exception as e:
print(f"Error loading metadata: {str(e)}")
return
if args.params:
try:
params = json.loads(args.params)
if not isinstance(params, dict):
raise TypeError("Parsed params JSON is not a dictionary.")
except json.JSONDecodeError as e:
print(f"Error parsing params: {str(e)}")
return
else:
params = None
# clean up params here
await start_test(data=args.file, test_set=test_set, user_id=args.user_id, params=params, metadata=metadata, retriever_type=args.retriever_type)
if __name__ == "__main__":
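Worth noting: the CLI path requires `--params` to parse to a JSON object, while the commented-out direct call above passed a list of parameter names. A minimal sketch of a value the argparse path accepts (hypothetical keys and values):

```python
import json

# Hypothetical --params payload: must be a JSON object, not a list.
params = json.loads('{"chunk_size": 500, "search_type": "similarity"}')
assert isinstance(params, dict)
```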

View file

@@ -45,7 +45,7 @@ async def _document_loader(observation: str, loader_settings: dict):
# pages = documents.load_and_split()
return documents
elif document_format == "text":
elif document_format == "TEXT":
pages = chunk_data(chunk_strategy=loader_strategy, source_data=observation, chunk_size=chunk_size, chunk_overlap=chunk_overlap)
return pages
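`chunk_data` itself is not shown in this diff even though the commit title mentions chunking fixes. A minimal sketch of a helper satisfying the call above, assuming a LangChain recursive splitter (everything beyond the keyword arguments visible in the diff is an assumption):

```python
from langchain.text_splitter import RecursiveCharacterTextSplitter

def chunk_data(chunk_strategy=None, source_data="", chunk_size=500, chunk_overlap=50):
    # Sketch: the real helper presumably dispatches on chunk_strategy;
    # only a default character-based split is shown here.
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    return splitter.create_documents([source_data])
```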