Added docs functionality

This commit is contained in:
Vasilije 2023-10-29 00:18:23 +02:00
parent 7a07be7d53
commit 12796a2ba1
4 changed files with 113 additions and 43 deletions

View file

@ -7,8 +7,8 @@ from fastapi import FastAPI
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from pydantic import BaseModel from pydantic import BaseModel
from level_3.database.database import AsyncSessionLocal from database.database import AsyncSessionLocal
from level_3.database.database_crud import session_scope from database.database_crud import session_scope
from vectorstore_manager import Memory from vectorstore_manager import Memory
from dotenv import load_dotenv from dotenv import load_dotenv
@ -202,6 +202,24 @@ for memory_type in memory_list:
memory_factory(memory_type) memory_factory(memory_type)
@app.post("/rag-test/rag_test_run", response_model=dict)
async def rag_test_run(
payload: Payload,
# files: List[UploadFile] = File(...),
):
try:
from rag_test_manager import start_test
logging.info(" Running RAG Test ")
decoded_payload = payload.payload
output = await start_test(data=decoded_payload['data'], test_set=decoded_payload['test_set'], user_id=decoded_payload['user_id'], params=decoded_payload['params'], metadata=decoded_payload['metadata'],
retriever_type=decoded_payload['retriever_type'])
return JSONResponse(content={"response": output}, status_code=200)
except Exception as e:
return JSONResponse(
content={"response": {"error": str(e)}}, status_code=503
)
# @app.get("/available-buffer-actions", response_model=dict) # @app.get("/available-buffer-actions", response_model=dict)
# async def available_buffer_actions( # async def available_buffer_actions(
# payload: Payload, # payload: Payload,

View file

@ -13,19 +13,19 @@ services:
# networks: # networks:
# - promethai_mem_backend # - promethai_mem_backend
# promethai_mem: promethai_mem:
# networks: networks:
# - promethai_mem_backend - promethai_mem_backend
# build: build:
# context: ./ context: ./
# volumes: volumes:
# - "./:/app" - "./:/app"
# environment: environment:
# - HOST=0.0.0.0 - HOST=0.0.0.0
# profiles: ["exclude-from-up"] profiles: ["exclude-from-up"]
# ports: ports:
# - 8000:8000 - 8000:8000
# - 443:443 - 443:443
postgres: postgres:
image: postgres image: postgres
@ -40,23 +40,23 @@ services:
ports: ports:
- "5432:5432" - "5432:5432"
superset: # superset:
platform: linux/amd64 # platform: linux/amd64
build: # build:
context: ./superset # context: ./superset
dockerfile: Dockerfile # dockerfile: Dockerfile
container_name: superset # container_name: superset
environment: # environment:
- ADMIN_USERNAME=admin # - ADMIN_USERNAME=admin
- ADMIN_EMAIL=vasilije@topoteretes.com # - ADMIN_EMAIL=vasilije@topoteretes.com
- ADMIN_PASSWORD=admin # - ADMIN_PASSWORD=admin
- POSTGRES_USER=bla # - POSTGRES_USER=bla
- POSTGRES_PASSWORD=bla # - POSTGRES_PASSWORD=bla
- POSTGRES_DB=bubu # - POSTGRES_DB=bubu
networks: # networks:
- promethai_mem_backend # - promethai_mem_backend
ports: # ports:
- '8088:8088' # - '8088:8088'
networks: networks:
promethai_mem_backend: promethai_mem_backend:

18
level_3/models/docs.py Normal file
View file

@ -0,0 +1,18 @@
from datetime import datetime
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
from sqlalchemy.orm import relationship
import os
import sys
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from database.database import Base
class DocsModel(Base):
__tablename__ = 'docs'
id = Column(String, primary_key=True)
operation_id = Column(String, ForeignKey('operations.id'), index=True)
doc_name = Column(String, nullable=True)
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, onupdate=datetime.utcnow)
operation = relationship("Operation", back_populates="docs")

View file

@ -73,7 +73,41 @@ async def retrieve_latest_test_case(session, user_id, memory_id):
f"An error occurred while retrieving the latest test case: {str(e)}" f"An error occurred while retrieving the latest test case: {str(e)}"
) )
return None return None
def get_document_names(doc_input):
"""
Get a list of document names.
This function takes doc_input, which can be a folder path, a single document file path, or a document name as a string.
It returns a list of document names based on the doc_input.
Args:
doc_input (str): The doc_input can be a folder path, a single document file path, or a document name as a string.
Returns:
list: A list of document names.
Example usage:
- Folder path: get_document_names(".data")
- Single document file path: get_document_names(".data/example.pdf")
- Document name provided as a string: get_document_names("example.docx")
"""
if os.path.isdir(doc_input):
# doc_input is a folder
folder_path = doc_input
document_names = []
for filename in os.listdir(folder_path):
if os.path.isfile(os.path.join(folder_path, filename)):
document_names.append(filename)
return document_names
elif os.path.isfile(doc_input):
# doc_input is a single document file
return [os.path.basename(doc_input)]
elif isinstance(doc_input, str):
# doc_input is a document name provided as a string
return [doc_input]
else:
# doc_input is not valid
return []
async def add_entity(session, entity): async def add_entity(session, entity):
async with session_scope(session) as s: # Use your async session_scope async with session_scope(session) as s: # Use your async session_scope
@ -369,17 +403,6 @@ def data_format_route(data_string: str):
# Return a default category if no match is found # Return a default category if no match is found
return FormatRoute.PDF.name return FormatRoute.PDF.name
# def data_location_route(data_string: str):
# @ai_classifier
# class LocationRoute(Enum):
# """Represents classifier for the data location, if it is device, or database connections string or URL"""
#
# DEVICE = "file_path_starting_with_.data_or_containing_it"
# URL = "url starting with http or https"
# DATABASE = "database_name_like_postgres_or_mysql"
#
# return LocationRoute(data_string).name
def data_location_route(data_string: str): def data_location_route(data_string: str):
class LocationRoute(Enum): class LocationRoute(Enum):
"""Represents classifier for the data location, if it is device, or database connection string or URL""" """Represents classifier for the data location, if it is device, or database connection string or URL"""
@ -489,6 +512,17 @@ async def start_test(
test_set_id=test_set_id, test_set_id=test_set_id,
), ),
) )
doc_names = get_document_names(data)
for doc in doc_names:
await add_entity(
session,
Docs(
id=str(uuid.uuid4()),
operation_id=job_id,
doc_name = doc
)
)
async def run_test( async def run_test(
test, loader_settings, metadata, test_id=None, retriever_type=False test, loader_settings, metadata, test_id=None, retriever_type=False