import json import logging import os from enum import Enum from typing import Dict, Any import uvicorn from fastapi import FastAPI, BackgroundTasks from fastapi.responses import JSONResponse from pydantic import BaseModel from database.database import AsyncSessionLocal from database.database_crud import session_scope from vectorstore_manager import Memory from dotenv import load_dotenv from rag_test_manager import start_test # Set up logging logging.basicConfig( level=logging.INFO, # Set the logging level (e.g., DEBUG, INFO, WARNING, ERROR, CRITICAL) format="%(asctime)s [%(levelname)s] %(message)s", # Set the log message format ) logger = logging.getLogger(__name__) load_dotenv() OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "") app = FastAPI(debug=True) from auth.cognito.JWTBearer import JWTBearer from auth.auth import jwks auth = JWTBearer(jwks) from fastapi import Depends class ImageResponse(BaseModel): success: bool message: str @app.get( "/", ) async def root(): """ Root endpoint that returns a welcome message. """ return {"message": "Hello, World, I am alive!"} @app.get("/health") def health_check(): """ Health check endpoint that returns the server status. """ return {"status": "OK"} class Payload(BaseModel): payload: Dict[str, Any] def memory_factory(memory_type): load_dotenv() class Payload(BaseModel): payload: Dict[str, Any] @app.post("/{memory_type}/add-memory", response_model=dict) async def add_memory( payload: Payload, # files: List[UploadFile] = File(...), ): try: logging.info(" Adding to Memory ") decoded_payload = payload.payload async with session_scope(session=AsyncSessionLocal()) as session: memory = await Memory.create_memory( decoded_payload["user_id"], session, namespace="SEMANTICMEMORY" ) # Adding a memory instance await memory.add_memory_instance(decoded_payload["memory_object"]) # Managing memory attributes existing_user = await Memory.check_existing_user( decoded_payload["user_id"], session ) await memory.manage_memory_attributes(existing_user) await memory.add_dynamic_memory_class( decoded_payload["memory_object"], decoded_payload["memory_object"].upper(), ) memory_class = decoded_payload["memory_object"] + "_class" dynamic_memory_class = getattr(memory, memory_class.lower(), None) await memory.add_method_to_class(dynamic_memory_class, "add_memories") # await memory.add_method_to_class(memory.semanticmemory_class, 'fetch_memories') output = await memory.dynamic_method_call( dynamic_memory_class, "add_memories", observation="some_observation", params=decoded_payload["params"], loader_settings=decoded_payload["loader_settings"], ) return JSONResponse(content={"response": output}, status_code=200) except Exception as e: return JSONResponse( content={"response": {"error": str(e)}}, status_code=503 ) @app.post("/{memory_type}/fetch-memory", response_model=dict) async def fetch_memory( payload: Payload, # files: List[UploadFile] = File(...), ): try: logging.info(" Adding to Memory ") decoded_payload = payload.payload async with session_scope(session=AsyncSessionLocal()) as session: memory = await Memory.create_memory( decoded_payload["user_id"], session, namespace="SEMANTICMEMORY" ) # Adding a memory instance await memory.add_memory_instance(decoded_payload["memory_object"]) # Managing memory attributes existing_user = await Memory.check_existing_user( decoded_payload["user_id"], session ) await memory.manage_memory_attributes(existing_user) await memory.add_dynamic_memory_class( decoded_payload["memory_object"], decoded_payload["memory_object"].upper(), ) memory_class = decoded_payload["memory_object"] + "_class" dynamic_memory_class = getattr(memory, memory_class.lower(), None) await memory.add_method_to_class(dynamic_memory_class, "add_memories") # await memory.add_method_to_class(memory.semanticmemory_class, 'fetch_memories') output = await memory.dynamic_method_call( dynamic_memory_class, "fetch_memories", observation=decoded_payload["observation"], ) return JSONResponse(content={"response": output}, status_code=200) except Exception as e: return JSONResponse( content={"response": {"error": str(e)}}, status_code=503 ) @app.post("/{memory_type}/delete-memory", response_model=dict) async def delete_memory( payload: Payload, # files: List[UploadFile] = File(...), ): try: logging.info(" Adding to Memory ") decoded_payload = payload.payload async with session_scope(session=AsyncSessionLocal()) as session: memory = await Memory.create_memory( decoded_payload["user_id"], session, namespace="SEMANTICMEMORY" ) # Adding a memory instance await memory.add_memory_instance(decoded_payload["memory_object"]) # Managing memory attributes existing_user = await Memory.check_existing_user( decoded_payload["user_id"], session ) await memory.manage_memory_attributes(existing_user) await memory.add_dynamic_memory_class( decoded_payload["memory_object"], decoded_payload["memory_object"].upper(), ) memory_class = decoded_payload["memory_object"] + "_class" dynamic_memory_class = getattr(memory, memory_class.lower(), None) await memory.add_method_to_class( dynamic_memory_class, "delete_memories" ) # await memory.add_method_to_class(memory.semanticmemory_class, 'fetch_memories') output = await memory.dynamic_method_call( dynamic_memory_class, "delete_memories", namespace=decoded_payload["memory_object"].upper(), ) return JSONResponse(content={"response": output}, status_code=200) except Exception as e: return JSONResponse( content={"response": {"error": str(e)}}, status_code=503 ) memory_list = ["episodic", "buffer", "semantic"] for memory_type in memory_list: memory_factory(memory_type) class TestSetType(Enum): SAMPLE = "sample" MANUAL = "manual" def get_test_set(test_set_type, folder_path="example_data", payload=None): if test_set_type == TestSetType.SAMPLE: file_path = os.path.join(folder_path, "test_set.json") if os.path.isfile(file_path): with open(file_path, "r") as file: return json.load(file) elif test_set_type == TestSetType.MANUAL: # Check if the manual test set is provided in the payload if payload and "manual_test_set" in payload: return payload["manual_test_set"] else: # Attempt to load the manual test set from a file pass return None class MetadataType(Enum): SAMPLE = "sample" MANUAL = "manual" def get_metadata(metadata_type, folder_path="example_data", payload=None): if metadata_type == MetadataType.SAMPLE: file_path = os.path.join(folder_path, "metadata.json") if os.path.isfile(file_path): with open(file_path, "r") as file: return json.load(file) elif metadata_type == MetadataType.MANUAL: # Check if the manual metadata is provided in the payload if payload and "manual_metadata" in payload: return payload["manual_metadata"] else: pass return None @app.post("/rag-test/rag_test_run", response_model=dict) async def rag_test_run( payload: Payload, background_tasks: BackgroundTasks, ): try: logging.info("Starting RAG Test") decoded_payload = payload.payload test_set_type = TestSetType(decoded_payload['test_set']) metadata_type = MetadataType(decoded_payload['metadata']) metadata = get_metadata(metadata_type, payload=decoded_payload) if metadata is None: return JSONResponse(content={"response": "Invalid metadata value"}, status_code=400) test_set = get_test_set(test_set_type, payload=decoded_payload) if test_set is None: return JSONResponse(content={"response": "Invalid test_set value"}, status_code=400) async def run_start_test(data, test_set, user_id, params, metadata, retriever_type): result = await start_test(data = data, test_set = test_set, user_id =user_id, params =params, metadata =metadata, retriever_type=retriever_type) logging.info("Retriever DATA type", type(decoded_payload['data'])) background_tasks.add_task( run_start_test, decoded_payload['data'], test_set, decoded_payload['user_id'], decoded_payload['params'], metadata, decoded_payload['retriever_type'] ) logging.info("Retriever type", decoded_payload['retriever_type']) return JSONResponse(content={"response": "Task has been started"}, status_code=200) except Exception as e: return JSONResponse( content={"response": {"error": str(e)}}, status_code=503 ) # @app.get("/rag-test/{task_id}") # async def check_task_status(task_id: int): # task_status = task_status_db.get(task_id, "not_found") # # if task_status == "not_found": # return {"status": "Task not found"} # # return {"status": task_status} # @app.get("/available-buffer-actions", response_model=dict) # async def available_buffer_actions( # payload: Payload, # # files: List[UploadFile] = File(...), # ): # try: # decoded_payload = payload.payload # # Memory_ = Memory(user_id=decoded_payload["user_id"]) # # await Memory_.async_init() # # # memory_class = getattr(Memory_, f"_delete_{memory_type}_memory", None) # output = await Memory_._available_operations() # return JSONResponse(content={"response": output}, status_code=200) # # except Exception as e: # return JSONResponse(content={"response": {"error": str(e)}}, status_code=503) # @app.post("/run-buffer", response_model=dict) # async def run_buffer( # payload: Payload, # # files: List[UploadFile] = File(...), # ): # try: # decoded_payload = payload.payload # # Memory_ = Memory(user_id=decoded_payload["user_id"]) # # await Memory_.async_init() # # # memory_class = getattr(Memory_, f"_delete_{memory_type}_memory", None) # output = await Memory_._run_main_buffer( # user_input=decoded_payload["prompt"], params=decoded_payload["params"], attention_modulators=decoded_payload["attention_modulators"] # ) # return JSONResponse(content={"response": output}, status_code=200) # # except Exception as e: # return JSONResponse(content={"response": {"error": str(e)}}, status_code=503) # # # @app.post("/buffer/create-context", response_model=dict) # async def create_context( # payload: Payload, # # files: List[UploadFile] = File(...), # ): # try: # decoded_payload = payload.payload # # Memory_ = Memory(user_id=decoded_payload["user_id"]) # # await Memory_.async_init() # # # memory_class = getattr(Memory_, f"_delete_{memory_type}_memory", None) # output = await Memory_._create_buffer_context( # user_input=decoded_payload["prompt"], params=decoded_payload["params"], attention_modulators=decoded_payload["attention_modulators"] # ) # return JSONResponse(content={"response": output}, status_code=200) # # except Exception as e: # return JSONResponse(content={"response": {"error": str(e)}}, status_code=503) # # # @app.post("/buffer/get-tasks", response_model=dict) # async def create_context( # payload: Payload, # # files: List[UploadFile] = File(...), # ): # try: # decoded_payload = payload.payload # # Memory_ = Memory(user_id=decoded_payload["user_id"]) # # await Memory_.async_init() # # # memory_class = getattr(Memory_, f"_delete_{memory_type}_memory", None) # output = await Memory_._get_task_list( # user_input=decoded_payload["prompt"], params=decoded_payload["params"], attention_modulators=decoded_payload["attention_modulators"] # ) # return JSONResponse(content={"response": output}, status_code=200) # # except Exception as e: # return JSONResponse(content={"response": {"error": str(e)}}, status_code=503) # # # @app.post("/buffer/provide-feedback", response_model=dict) # async def provide_feedback( # payload: Payload, # # files: List[UploadFile] = File(...), # ): # try: # decoded_payload = payload.payload # # Memory_ = Memory(user_id=decoded_payload["user_id"]) # # await Memory_.async_init() # # # memory_class = getattr(Memory_, f"_delete_{memory_type}_memory", None) # if decoded_payload["total_score"] is None: # # output = await Memory_._provide_feedback( # user_input=decoded_payload["prompt"], params=decoded_payload["params"], attention_modulators=None, total_score=decoded_payload["total_score"] # ) # return JSONResponse(content={"response": output}, status_code=200) # else: # output = await Memory_._provide_feedback( # user_input=decoded_payload["prompt"], params=decoded_payload["params"], attention_modulators=decoded_payload["attention_modulators"], total_score=None # ) # return JSONResponse(content={"response": output}, status_code=200) # # # except Exception as e: # return JSONResponse(content={"response": {"error": str(e)}}, status_code=503) def start_api_server(host: str = "0.0.0.0", port: int = 8000): """ Start the API server using uvicorn. Parameters: host (str): The host for the server. port (int): The port for the server. """ try: logger.info(f"Starting server at {host}:{port}") uvicorn.run(app, host=host, port=port) except Exception as e: logger.exception(f"Failed to start server: {e}") # Here you could add any cleanup code or error recovery code. if __name__ == "__main__": start_api_server()