add more robust process cleanup to avoid dangling subprocesses

This commit is contained in:
Daulet Amirkhanov 2025-09-11 19:18:11 +01:00
parent 8df2ffd991
commit be73e2ee41
2 changed files with 31 additions and 23 deletions

View file

@ -7,7 +7,7 @@ import webbrowser
import zipfile import zipfile
import requests import requests
from pathlib import Path from pathlib import Path
from typing import Optional, Tuple from typing import Callable, Optional, Tuple
import tempfile import tempfile
import shutil import shutil
@ -326,6 +326,7 @@ def prompt_user_for_download() -> bool:
def start_ui( def start_ui(
pid_callback: Callable[[int], None],
host: str = "localhost", host: str = "localhost",
port: int = 3000, port: int = 3000,
open_browser: bool = True, open_browser: bool = True,
@ -346,6 +347,7 @@ def start_ui(
6. Optionally open the browser 6. Optionally open the browser
Args: Args:
pid_callback: Callback to notify with PID of each spawned process
host: Host to bind the frontend server to (default: localhost) host: Host to bind the frontend server to (default: localhost)
port: Port to run the frontend server on (default: 3000) port: Port to run the frontend server on (default: 3000)
open_browser: Whether to open the browser automatically (default: True) open_browser: Whether to open the browser automatically (default: True)
@ -397,6 +399,8 @@ def start_ui(
preexec_fn=os.setsid if hasattr(os, "setsid") else None, preexec_fn=os.setsid if hasattr(os, "setsid") else None,
) )
pid_callback(backend_process.pid)
# Give the backend a moment to start # Give the backend a moment to start
time.sleep(2) time.sleep(2)
@ -460,7 +464,7 @@ def start_ui(
logger.info("This may take a moment to compile and start...") logger.info("This may take a moment to compile and start...")
try: try:
# Use process group to ensure all child processes get terminated together # Create frontend in its own process group for clean termination
process = subprocess.Popen( process = subprocess.Popen(
["npm", "run", "dev"], ["npm", "run", "dev"],
cwd=frontend_path, cwd=frontend_path,
@ -468,11 +472,11 @@ def start_ui(
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, stderr=subprocess.PIPE,
text=True, text=True,
preexec_fn=os.setsid preexec_fn=os.setsid if hasattr(os, "setsid") else None,
if hasattr(os, "setsid")
else None, # Create new process group on Unix
) )
pid_callback(process.pid)
# Give it a moment to start up # Give it a moment to start up
time.sleep(3) time.sleep(3)

View file

@ -174,30 +174,23 @@ def main() -> int:
# Handle UI flag # Handle UI flag
if hasattr(args, "start_ui") and args.start_ui: if hasattr(args, "start_ui") and args.start_ui:
server_process = None spawned_pids = []
def signal_handler(signum, frame): def signal_handler(signum, frame):
"""Handle Ctrl+C and other termination signals""" """Handle Ctrl+C and other termination signals"""
nonlocal server_process nonlocal spawned_pids
fmt.echo("\nShutting down UI server...") fmt.echo("\nShutting down UI server...")
if server_process:
for pid in spawned_pids:
try: try:
# Try graceful termination first pgid = os.getpgid(pid)
server_process.terminate() os.killpg(pgid, signal.SIGTERM)
try: fmt.success(f"✓ Process group {pgid} (PID {pid}) terminated.")
server_process.wait(timeout=5) except (OSError, ProcessLookupError) as e:
fmt.success("UI server stopped gracefully.") fmt.warning(f"Could not terminate process {pid}: {e}")
except subprocess.TimeoutExpired:
# If graceful termination fails, force kill
fmt.echo("Force stopping UI server...")
server_process.kill()
server_process.wait()
fmt.success("UI server stopped.")
except Exception as e:
fmt.warning(f"Error stopping server: {e}")
sys.exit(0) sys.exit(0)
# Set up signal handlers
signal.signal(signal.SIGINT, signal_handler) # Ctrl+C signal.signal(signal.SIGINT, signal_handler) # Ctrl+C
signal.signal(signal.SIGTERM, signal_handler) # Termination request signal.signal(signal.SIGTERM, signal_handler) # Termination request
@ -206,8 +199,17 @@ def main() -> int:
fmt.echo("Starting cognee UI...") fmt.echo("Starting cognee UI...")
# Callback to capture PIDs of all spawned processes
def pid_callback(pid):
nonlocal spawned_pids
spawned_pids.append(pid)
server_process = start_ui( server_process = start_ui(
host="localhost", port=3000, open_browser=True, start_backend=True host="localhost",
port=3000,
open_browser=True,
start_backend=True,
pid_callback=pid_callback,
) )
if server_process: if server_process:
@ -229,10 +231,12 @@ def main() -> int:
return 0 return 0
else: else:
fmt.error("Failed to start UI server. Check the logs above for details.") fmt.error("Failed to start UI server. Check the logs above for details.")
signal_handler(signal.SIGTERM, None)
return 1 return 1
except Exception as ex: except Exception as ex:
fmt.error(f"Error starting UI: {str(ex)}") fmt.error(f"Error starting UI: {str(ex)}")
signal_handler(signal.SIGTERM, None)
if debug.is_debug_enabled(): if debug.is_debug_enabled():
raise ex raise ex
return 1 return 1