From 12796a2ba10133fa5486207b7aff32531c34493f Mon Sep 17 00:00:00 2001 From: Vasilije <8619304+Vasilije1990@users.noreply.github.com> Date: Sun, 29 Oct 2023 00:18:23 +0200 Subject: [PATCH] Added docs functionality --- level_3/api.py | 22 ++++++++++++-- level_3/docker-compose.yml | 60 ++++++++++++++++++------------------- level_3/models/docs.py | 18 +++++++++++ level_3/rag_test_manager.py | 56 +++++++++++++++++++++++++++------- 4 files changed, 113 insertions(+), 43 deletions(-) create mode 100644 level_3/models/docs.py diff --git a/level_3/api.py b/level_3/api.py index 10c640d0b..953150d61 100644 --- a/level_3/api.py +++ b/level_3/api.py @@ -7,8 +7,8 @@ from fastapi import FastAPI from fastapi.responses import JSONResponse from pydantic import BaseModel -from level_3.database.database import AsyncSessionLocal -from level_3.database.database_crud import session_scope +from database.database import AsyncSessionLocal +from database.database_crud import session_scope from vectorstore_manager import Memory from dotenv import load_dotenv @@ -202,6 +202,24 @@ for memory_type in memory_list: memory_factory(memory_type) +@app.post("/rag-test/rag_test_run", response_model=dict) +async def rag_test_run( + payload: Payload, + # files: List[UploadFile] = File(...), +): + try: + from rag_test_manager import start_test + logging.info(" Running RAG Test ") + decoded_payload = payload.payload + output = await start_test(data=decoded_payload['data'], test_set=decoded_payload['test_set'], user_id=decoded_payload['user_id'], params=decoded_payload['params'], metadata=decoded_payload['metadata'], + retriever_type=decoded_payload['retriever_type']) + return JSONResponse(content={"response": output}, status_code=200) + except Exception as e: + return JSONResponse( + content={"response": {"error": str(e)}}, status_code=503 + ) + + # @app.get("/available-buffer-actions", response_model=dict) # async def available_buffer_actions( # payload: Payload, diff --git a/level_3/docker-compose.yml b/level_3/docker-compose.yml index 7444aca82..8e7e56698 100644 --- a/level_3/docker-compose.yml +++ b/level_3/docker-compose.yml @@ -13,19 +13,19 @@ services: # networks: # - promethai_mem_backend -# promethai_mem: -# networks: -# - promethai_mem_backend -# build: -# context: ./ -# volumes: -# - "./:/app" -# environment: -# - HOST=0.0.0.0 -# profiles: ["exclude-from-up"] -# ports: -# - 8000:8000 -# - 443:443 + promethai_mem: + networks: + - promethai_mem_backend + build: + context: ./ + volumes: + - "./:/app" + environment: + - HOST=0.0.0.0 + profiles: ["exclude-from-up"] + ports: + - 8000:8000 + - 443:443 postgres: image: postgres @@ -40,23 +40,23 @@ services: ports: - "5432:5432" - superset: - platform: linux/amd64 - build: - context: ./superset - dockerfile: Dockerfile - container_name: superset - environment: - - ADMIN_USERNAME=admin - - ADMIN_EMAIL=vasilije@topoteretes.com - - ADMIN_PASSWORD=admin - - POSTGRES_USER=bla - - POSTGRES_PASSWORD=bla - - POSTGRES_DB=bubu - networks: - - promethai_mem_backend - ports: - - '8088:8088' +# superset: +# platform: linux/amd64 +# build: +# context: ./superset +# dockerfile: Dockerfile +# container_name: superset +# environment: +# - ADMIN_USERNAME=admin +# - ADMIN_EMAIL=vasilije@topoteretes.com +# - ADMIN_PASSWORD=admin +# - POSTGRES_USER=bla +# - POSTGRES_PASSWORD=bla +# - POSTGRES_DB=bubu +# networks: +# - promethai_mem_backend +# ports: +# - '8088:8088' networks: promethai_mem_backend: diff --git a/level_3/models/docs.py b/level_3/models/docs.py new file mode 100644 index 000000000..95d98c485 --- /dev/null +++ b/level_3/models/docs.py @@ -0,0 +1,18 @@ + +from datetime import datetime +from sqlalchemy import Column, Integer, String, DateTime, ForeignKey +from sqlalchemy.orm import relationship +import os +import sys +sys.path.append(os.path.dirname(os.path.abspath(__file__))) +from database.database import Base +class DocsModel(Base): + __tablename__ = 'docs' + + id = Column(String, primary_key=True) + operation_id = Column(String, ForeignKey('operations.id'), index=True) + doc_name = Column(String, nullable=True) + created_at = Column(DateTime, default=datetime.utcnow) + updated_at = Column(DateTime, onupdate=datetime.utcnow) + + operation = relationship("Operation", back_populates="docs") diff --git a/level_3/rag_test_manager.py b/level_3/rag_test_manager.py index 9b87b73cf..63ce1f388 100644 --- a/level_3/rag_test_manager.py +++ b/level_3/rag_test_manager.py @@ -73,7 +73,41 @@ async def retrieve_latest_test_case(session, user_id, memory_id): f"An error occurred while retrieving the latest test case: {str(e)}" ) return None +def get_document_names(doc_input): + """ + Get a list of document names. + This function takes doc_input, which can be a folder path, a single document file path, or a document name as a string. + It returns a list of document names based on the doc_input. + + Args: + doc_input (str): The doc_input can be a folder path, a single document file path, or a document name as a string. + + Returns: + list: A list of document names. + + Example usage: + - Folder path: get_document_names(".data") + - Single document file path: get_document_names(".data/example.pdf") + - Document name provided as a string: get_document_names("example.docx") + """ + if os.path.isdir(doc_input): + # doc_input is a folder + folder_path = doc_input + document_names = [] + for filename in os.listdir(folder_path): + if os.path.isfile(os.path.join(folder_path, filename)): + document_names.append(filename) + return document_names + elif os.path.isfile(doc_input): + # doc_input is a single document file + return [os.path.basename(doc_input)] + elif isinstance(doc_input, str): + # doc_input is a document name provided as a string + return [doc_input] + else: + # doc_input is not valid + return [] async def add_entity(session, entity): async with session_scope(session) as s: # Use your async session_scope @@ -369,17 +403,6 @@ def data_format_route(data_string: str): # Return a default category if no match is found return FormatRoute.PDF.name - -# def data_location_route(data_string: str): -# @ai_classifier -# class LocationRoute(Enum): -# """Represents classifier for the data location, if it is device, or database connections string or URL""" -# -# DEVICE = "file_path_starting_with_.data_or_containing_it" -# URL = "url starting with http or https" -# DATABASE = "database_name_like_postgres_or_mysql" -# -# return LocationRoute(data_string).name def data_location_route(data_string: str): class LocationRoute(Enum): """Represents classifier for the data location, if it is device, or database connection string or URL""" @@ -489,6 +512,17 @@ async def start_test( test_set_id=test_set_id, ), ) + doc_names = get_document_names(data) + for doc in doc_names: + + await add_entity( + session, + Docs( + id=str(uuid.uuid4()), + operation_id=job_id, + doc_name = doc + ) + ) async def run_test( test, loader_settings, metadata, test_id=None, retriever_type=False