diff --git a/src/main.py b/src/main.py index df59263e..8f586e7a 100644 --- a/src/main.py +++ b/src/main.py @@ -14,6 +14,10 @@ import subprocess from functools import partial from starlette.applications import Starlette from starlette.routing import Route + +# Set multiprocessing start method to 'spawn' for CUDA compatibility +multiprocessing.set_start_method("spawn", force=True) + from utils.process_pool import process_pool import torch @@ -50,8 +54,6 @@ from api import ( settings, ) -# Set multiprocessing start method to 'spawn' for CUDA compatibility -multiprocessing.set_start_method("spawn", force=True) logger.info( "CUDA device information",