Remove ascii_colors dependency and fix stream handling errors

• Remove ascii_colors.trace_exception calls
• Add SafeStreamHandler for closed streams
• Patch ascii_colors console handler
• Prevent ValueError on stream close
• Improve logging error handling
This commit is contained in:
yangdx 2025-11-19 21:38:17 +08:00
parent f72f435cef
commit 0fb2925c6a
3 changed files with 258 additions and 221 deletions

View file

@ -8,7 +8,6 @@ import re
from enum import Enum from enum import Enum
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
import asyncio import asyncio
from ascii_colors import trace_exception
from lightrag import LightRAG, QueryParam from lightrag import LightRAG, QueryParam
from lightrag.utils import TiktokenTokenizer from lightrag.utils import TiktokenTokenizer
from lightrag.api.utils_api import get_combined_auth_dependency from lightrag.api.utils_api import get_combined_auth_dependency
@ -309,7 +308,6 @@ class OllamaAPI:
) )
async def stream_generator(): async def stream_generator():
try:
first_chunk_time = None first_chunk_time = None
last_chunk_time = time.time_ns() last_chunk_time = time.time_ns()
total_response = "" total_response = ""
@ -418,10 +416,6 @@ class OllamaAPI:
yield f"{json.dumps(data, ensure_ascii=False)}\n" yield f"{json.dumps(data, ensure_ascii=False)}\n"
return return
except Exception as e:
trace_exception(e)
raise
return StreamingResponse( return StreamingResponse(
stream_generator(), stream_generator(),
media_type="application/x-ndjson", media_type="application/x-ndjson",
@ -462,7 +456,6 @@ class OllamaAPI:
"eval_duration": eval_time, "eval_duration": eval_time,
} }
except Exception as e: except Exception as e:
trace_exception(e)
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@self.router.post( @self.router.post(
@ -535,7 +528,6 @@ class OllamaAPI:
) )
async def stream_generator(): async def stream_generator():
try:
first_chunk_time = None first_chunk_time = None
last_chunk_time = time.time_ns() last_chunk_time = time.time_ns()
total_response = "" total_response = ""
@ -666,10 +658,6 @@ class OllamaAPI:
} }
yield f"{json.dumps(data, ensure_ascii=False)}\n" yield f"{json.dumps(data, ensure_ascii=False)}\n"
except Exception as e:
trace_exception(e)
raise
return StreamingResponse( return StreamingResponse(
stream_generator(), stream_generator(),
media_type="application/x-ndjson", media_type="application/x-ndjson",
@ -730,5 +718,4 @@ class OllamaAPI:
"eval_duration": eval_time, "eval_duration": eval_time,
} }
except Exception as e: except Exception as e:
trace_exception(e)
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))

View file

@ -11,8 +11,6 @@ from lightrag.base import QueryParam
from lightrag.api.utils_api import get_combined_auth_dependency from lightrag.api.utils_api import get_combined_auth_dependency
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
from ascii_colors import trace_exception
router = APIRouter(tags=["query"]) router = APIRouter(tags=["query"])
@ -453,7 +451,6 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
else: else:
return QueryResponse(response=response_content, references=None) return QueryResponse(response=response_content, references=None)
except Exception as e: except Exception as e:
trace_exception(e)
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@router.post( @router.post(
@ -739,7 +736,6 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
}, },
) )
except Exception as e: except Exception as e:
trace_exception(e)
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@router.post( @router.post(
@ -1156,7 +1152,6 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
data={}, data={},
) )
except Exception as e: except Exception as e:
trace_exception(e)
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
return router return router

View file

@ -1,6 +1,8 @@
from __future__ import annotations from __future__ import annotations
import weakref import weakref
import sys
import asyncio import asyncio
import html import html
import csv import csv
@ -40,6 +42,35 @@ from lightrag.constants import (
SOURCE_IDS_LIMIT_METHOD_FIFO, SOURCE_IDS_LIMIT_METHOD_FIFO,
) )
# Precompile regex pattern for JSON sanitization (module-level, compiled once)
_SURROGATE_PATTERN = re.compile(r"[\uD800-\uDFFF\uFFFE\uFFFF]")
class SafeStreamHandler(logging.StreamHandler):
    """StreamHandler variant that tolerates an already-closed stream.

    During interpreter shutdown — or when pytest and similar test frameworks
    close stdout/stderr before Python's logging cleanup runs — flushing or
    closing the underlying stream raises "ValueError: I/O operation on closed
    file". This subclass swallows those errors so shutdown stays quiet.
    """

    def flush(self):
        """Flush the underlying stream, swallowing closed-stream errors."""
        try:
            logging.StreamHandler.flush(self)
        except (ValueError, OSError):
            # Stream already closed or unavailable — nothing left to flush.
            pass

    def close(self):
        """Close the handler, swallowing errors if the stream is already gone."""
        try:
            logging.StreamHandler.close(self)
        except (ValueError, OSError):
            # Stream already closed — treat the close as a success.
            pass
# Initialize logger with basic configuration # Initialize logger with basic configuration
logger = logging.getLogger("lightrag") logger = logging.getLogger("lightrag")
logger.propagate = False # prevent log message send to root logger logger.propagate = False # prevent log message send to root logger
@ -47,7 +78,7 @@ logger.setLevel(logging.INFO)
# Add console handler if no handlers exist # Add console handler if no handlers exist
if not logger.handlers: if not logger.handlers:
console_handler = logging.StreamHandler() console_handler = SafeStreamHandler()
console_handler.setLevel(logging.INFO) console_handler.setLevel(logging.INFO)
formatter = logging.Formatter("%(levelname)s: %(message)s") formatter = logging.Formatter("%(levelname)s: %(message)s")
console_handler.setFormatter(formatter) console_handler.setFormatter(formatter)
@ -56,8 +87,32 @@ if not logger.handlers:
# Set httpx logging level to WARNING # Set httpx logging level to WARNING
logging.getLogger("httpx").setLevel(logging.WARNING) logging.getLogger("httpx").setLevel(logging.WARNING)
# Precompile regex pattern for JSON sanitization (module-level, compiled once)
_SURROGATE_PATTERN = re.compile(r"[\uD800-\uDFFF\uFFFE\uFFFF]") def _patch_ascii_colors_console_handler() -> None:
"""Prevent ascii_colors from printing flush errors during interpreter exit."""
try:
from ascii_colors import ConsoleHandler
except ImportError:
return
if getattr(ConsoleHandler, "_lightrag_patched", False):
return
original_handle_error = ConsoleHandler.handle_error
def _safe_handle_error(self, message: str) -> None: # type: ignore[override]
exc_type, _, _ = sys.exc_info()
if exc_type in (ValueError, OSError) and "close" in message.lower():
return
original_handle_error(self, message)
ConsoleHandler.handle_error = _safe_handle_error # type: ignore[assignment]
ConsoleHandler._lightrag_patched = True # type: ignore[attr-defined]
_patch_ascii_colors_console_handler()
# Global import for pypinyin with startup-time logging # Global import for pypinyin with startup-time logging
try: try:
@ -286,8 +341,8 @@ def setup_logger(
logger_instance.handlers = [] # Clear existing handlers logger_instance.handlers = [] # Clear existing handlers
logger_instance.propagate = False logger_instance.propagate = False
# Add console handler # Add console handler with safe stream handling
console_handler = logging.StreamHandler() console_handler = SafeStreamHandler()
console_handler.setFormatter(simple_formatter) console_handler.setFormatter(simple_formatter)
console_handler.setLevel(level) console_handler.setLevel(level)
logger_instance.addHandler(console_handler) logger_instance.addHandler(console_handler)