Remove ascii_colors dependency and fix stream handling errors

• Remove ascii_colors.trace_exception calls
• Add SafeStreamHandler for closed streams
• Patch ascii_colors console handler
• Prevent ValueError on stream close
• Improve logging error handling

(cherry picked from commit 0fb2925c6a)
This commit is contained in:
yangdx 2025-11-19 21:38:17 +08:00 committed by Raphaël MANSUY
parent fd76e0f7ce
commit 322ff19f72
3 changed files with 305 additions and 275 deletions

View file

@ -8,7 +8,6 @@ import re
from enum import Enum
from fastapi.responses import StreamingResponse
import asyncio
from ascii_colors import trace_exception
from lightrag import LightRAG, QueryParam
from lightrag.utils import TiktokenTokenizer
from lightrag.api.utils_api import get_combined_auth_dependency
@ -335,118 +334,113 @@ class OllamaAPI:
)
async def stream_generator():
try:
first_chunk_time = None
first_chunk_time = None
last_chunk_time = time.time_ns()
total_response = ""
# Ensure response is an async generator
if isinstance(response, str):
# If it's a string, send in two parts
first_chunk_time = start_time
last_chunk_time = time.time_ns()
total_response = ""
total_response = response
# Ensure response is an async generator
if isinstance(response, str):
# If it's a string, send in two parts
first_chunk_time = start_time
last_chunk_time = time.time_ns()
total_response = response
data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"response": response,
"done": False,
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
data = {
completion_tokens = estimate_tokens(total_response)
total_time = last_chunk_time - start_time
prompt_eval_time = first_chunk_time - start_time
eval_time = last_chunk_time - first_chunk_time
data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"response": "",
"done": True,
"done_reason": "stop",
"context": [],
"total_duration": total_time,
"load_duration": 0,
"prompt_eval_count": prompt_tokens,
"prompt_eval_duration": prompt_eval_time,
"eval_count": completion_tokens,
"eval_duration": eval_time,
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
else:
try:
async for chunk in response:
if chunk:
if first_chunk_time is None:
first_chunk_time = time.time_ns()
last_chunk_time = time.time_ns()
total_response += chunk
data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"response": chunk,
"done": False,
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
except (asyncio.CancelledError, Exception) as e:
error_msg = str(e)
if isinstance(e, asyncio.CancelledError):
error_msg = "Stream was cancelled by server"
else:
error_msg = f"Provider error: {error_msg}"
logger.error(f"Stream error: {error_msg}")
# Send error message to client
error_data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"response": response,
"response": f"\n\nError: {error_msg}",
"error": f"\n\nError: {error_msg}",
"done": False,
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
yield f"{json.dumps(error_data, ensure_ascii=False)}\n"
completion_tokens = estimate_tokens(total_response)
total_time = last_chunk_time - start_time
prompt_eval_time = first_chunk_time - start_time
eval_time = last_chunk_time - first_chunk_time
data = {
# Send final message to close the stream
final_data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"response": "",
"done": True,
"done_reason": "stop",
"context": [],
"total_duration": total_time,
"load_duration": 0,
"prompt_eval_count": prompt_tokens,
"prompt_eval_duration": prompt_eval_time,
"eval_count": completion_tokens,
"eval_duration": eval_time,
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
else:
try:
async for chunk in response:
if chunk:
if first_chunk_time is None:
first_chunk_time = time.time_ns()
last_chunk_time = time.time_ns()
total_response += chunk
data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"response": chunk,
"done": False,
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
except (asyncio.CancelledError, Exception) as e:
error_msg = str(e)
if isinstance(e, asyncio.CancelledError):
error_msg = "Stream was cancelled by server"
else:
error_msg = f"Provider error: {error_msg}"
logger.error(f"Stream error: {error_msg}")
# Send error message to client
error_data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"response": f"\n\nError: {error_msg}",
"error": f"\n\nError: {error_msg}",
"done": False,
}
yield f"{json.dumps(error_data, ensure_ascii=False)}\n"
# Send final message to close the stream
final_data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"response": "",
"done": True,
}
yield f"{json.dumps(final_data, ensure_ascii=False)}\n"
return
if first_chunk_time is None:
first_chunk_time = start_time
completion_tokens = estimate_tokens(total_response)
total_time = last_chunk_time - start_time
prompt_eval_time = first_chunk_time - start_time
eval_time = last_chunk_time - first_chunk_time
data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"response": "",
"done": True,
"done_reason": "stop",
"context": [],
"total_duration": total_time,
"load_duration": 0,
"prompt_eval_count": prompt_tokens,
"prompt_eval_duration": prompt_eval_time,
"eval_count": completion_tokens,
"eval_duration": eval_time,
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
yield f"{json.dumps(final_data, ensure_ascii=False)}\n"
return
if first_chunk_time is None:
first_chunk_time = start_time
completion_tokens = estimate_tokens(total_response)
total_time = last_chunk_time - start_time
prompt_eval_time = first_chunk_time - start_time
eval_time = last_chunk_time - first_chunk_time
except Exception as e:
trace_exception(e)
raise
data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"response": "",
"done": True,
"done_reason": "stop",
"context": [],
"total_duration": total_time,
"load_duration": 0,
"prompt_eval_count": prompt_tokens,
"prompt_eval_duration": prompt_eval_time,
"eval_count": completion_tokens,
"eval_duration": eval_time,
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
return
return StreamingResponse(
stream_generator(),
@ -488,7 +482,6 @@ class OllamaAPI:
"eval_duration": eval_time,
}
except Exception as e:
trace_exception(e)
raise HTTPException(status_code=500, detail=str(e))
@self.router.post(
@ -558,36 +551,98 @@ class OllamaAPI:
)
async def stream_generator():
try:
first_chunk_time = None
first_chunk_time = None
last_chunk_time = time.time_ns()
total_response = ""
# Ensure response is an async generator
if isinstance(response, str):
# If it's a string, send in two parts
first_chunk_time = start_time
last_chunk_time = time.time_ns()
total_response = ""
total_response = response
# Ensure response is an async generator
if isinstance(response, str):
# If it's a string, send in two parts
first_chunk_time = start_time
last_chunk_time = time.time_ns()
total_response = response
data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"message": {
"role": "assistant",
"content": response,
"images": None,
},
"done": False,
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
data = {
completion_tokens = estimate_tokens(total_response)
total_time = last_chunk_time - start_time
prompt_eval_time = first_chunk_time - start_time
eval_time = last_chunk_time - first_chunk_time
data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"message": {
"role": "assistant",
"content": "",
"images": None,
},
"done_reason": "stop",
"done": True,
"total_duration": total_time,
"load_duration": 0,
"prompt_eval_count": prompt_tokens,
"prompt_eval_duration": prompt_eval_time,
"eval_count": completion_tokens,
"eval_duration": eval_time,
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
else:
try:
async for chunk in response:
if chunk:
if first_chunk_time is None:
first_chunk_time = time.time_ns()
last_chunk_time = time.time_ns()
total_response += chunk
data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"message": {
"role": "assistant",
"content": chunk,
"images": None,
},
"done": False,
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
except (asyncio.CancelledError, Exception) as e:
error_msg = str(e)
if isinstance(e, asyncio.CancelledError):
error_msg = "Stream was cancelled by server"
else:
error_msg = f"Provider error: {error_msg}"
logger.error(f"Stream error: {error_msg}")
# Send error message to client
error_data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"message": {
"role": "assistant",
"content": response,
"content": f"\n\nError: {error_msg}",
"images": None,
},
"error": f"\n\nError: {error_msg}",
"done": False,
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
yield f"{json.dumps(error_data, ensure_ascii=False)}\n"
completion_tokens = estimate_tokens(total_response)
total_time = last_chunk_time - start_time
prompt_eval_time = first_chunk_time - start_time
eval_time = last_chunk_time - first_chunk_time
data = {
# Send final message to close the stream
final_data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"message": {
@ -595,103 +650,36 @@ class OllamaAPI:
"content": "",
"images": None,
},
"done_reason": "stop",
"done": True,
"total_duration": total_time,
"load_duration": 0,
"prompt_eval_count": prompt_tokens,
"prompt_eval_duration": prompt_eval_time,
"eval_count": completion_tokens,
"eval_duration": eval_time,
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
else:
try:
async for chunk in response:
if chunk:
if first_chunk_time is None:
first_chunk_time = time.time_ns()
yield f"{json.dumps(final_data, ensure_ascii=False)}\n"
return
last_chunk_time = time.time_ns()
if first_chunk_time is None:
first_chunk_time = start_time
completion_tokens = estimate_tokens(total_response)
total_time = last_chunk_time - start_time
prompt_eval_time = first_chunk_time - start_time
eval_time = last_chunk_time - first_chunk_time
total_response += chunk
data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"message": {
"role": "assistant",
"content": chunk,
"images": None,
},
"done": False,
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
except (asyncio.CancelledError, Exception) as e:
error_msg = str(e)
if isinstance(e, asyncio.CancelledError):
error_msg = "Stream was cancelled by server"
else:
error_msg = f"Provider error: {error_msg}"
logger.error(f"Stream error: {error_msg}")
# Send error message to client
error_data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"message": {
"role": "assistant",
"content": f"\n\nError: {error_msg}",
"images": None,
},
"error": f"\n\nError: {error_msg}",
"done": False,
}
yield f"{json.dumps(error_data, ensure_ascii=False)}\n"
# Send final message to close the stream
final_data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"message": {
"role": "assistant",
"content": "",
"images": None,
},
"done": True,
}
yield f"{json.dumps(final_data, ensure_ascii=False)}\n"
return
if first_chunk_time is None:
first_chunk_time = start_time
completion_tokens = estimate_tokens(total_response)
total_time = last_chunk_time - start_time
prompt_eval_time = first_chunk_time - start_time
eval_time = last_chunk_time - first_chunk_time
data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"message": {
"role": "assistant",
"content": "",
"images": None,
},
"done_reason": "stop",
"done": True,
"total_duration": total_time,
"load_duration": 0,
"prompt_eval_count": prompt_tokens,
"prompt_eval_duration": prompt_eval_time,
"eval_count": completion_tokens,
"eval_duration": eval_time,
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
except Exception as e:
trace_exception(e)
raise
data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"message": {
"role": "assistant",
"content": "",
"images": None,
},
"done_reason": "stop",
"done": True,
"total_duration": total_time,
"load_duration": 0,
"prompt_eval_count": prompt_tokens,
"prompt_eval_duration": prompt_eval_time,
"eval_count": completion_tokens,
"eval_duration": eval_time,
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
return StreamingResponse(
stream_generator(),
@ -753,5 +741,4 @@ class OllamaAPI:
"eval_duration": eval_time,
}
except Exception as e:
trace_exception(e)
raise HTTPException(status_code=500, detail=str(e))

View file

@ -15,8 +15,6 @@ from lightrag.models.tenant import TenantContext
from lightrag.tenant_rag_manager import TenantRAGManager
from pydantic import BaseModel, Field, field_validator
from ascii_colors import trace_exception
router = APIRouter(tags=["query"])
@ -399,7 +397,6 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60, rag
else:
return QueryResponse(response=response_content, references=None)
except Exception as e:
trace_exception(e)
raise HTTPException(status_code=500, detail=str(e))
@router.post(
@ -650,7 +647,6 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60, rag
},
)
except Exception as e:
trace_exception(e)
raise HTTPException(status_code=500, detail=str(e))
@router.post(
@ -1061,7 +1057,6 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60, rag
data={},
)
except Exception as e:
trace_exception(e)
raise HTTPException(status_code=500, detail=str(e))
return router

View file

@ -1,6 +1,8 @@
from __future__ import annotations
import weakref
import sys
import asyncio
import html
import csv
@ -40,6 +42,35 @@ from lightrag.constants import (
SOURCE_IDS_LIMIT_METHOD_FIFO,
)
# Precompile regex pattern for JSON sanitization (module-level, compiled once)
_SURROGATE_PATTERN = re.compile(r"[\uD800-\uDFFF\uFFFE\uFFFF]")
class SafeStreamHandler(logging.StreamHandler):
"""StreamHandler that gracefully handles closed streams during shutdown.
This handler prevents "ValueError: I/O operation on closed file" errors
that can occur when pytest or other test frameworks close stdout/stderr
before Python's logging cleanup runs.
"""
def flush(self):
"""Flush the stream, ignoring errors if the stream is closed."""
try:
super().flush()
except (ValueError, OSError):
# Stream is closed or otherwise unavailable, silently ignore
pass
def close(self):
"""Close the handler, ignoring errors if the stream is already closed."""
try:
super().close()
except (ValueError, OSError):
# Stream is closed or otherwise unavailable, silently ignore
pass
# Initialize logger with basic configuration
logger = logging.getLogger("lightrag")
logger.propagate = False # prevent log message send to root logger
@ -47,7 +78,7 @@ logger.setLevel(logging.INFO)
# Add console handler if no handlers exist
if not logger.handlers:
console_handler = logging.StreamHandler()
console_handler = SafeStreamHandler()
console_handler.setLevel(logging.INFO)
formatter = logging.Formatter("%(levelname)s: %(message)s")
console_handler.setFormatter(formatter)
@ -56,6 +87,33 @@ if not logger.handlers:
# Set httpx logging level to WARNING
logging.getLogger("httpx").setLevel(logging.WARNING)
def _patch_ascii_colors_console_handler() -> None:
"""Prevent ascii_colors from printing flush errors during interpreter exit."""
try:
from ascii_colors import ConsoleHandler
except ImportError:
return
if getattr(ConsoleHandler, "_lightrag_patched", False):
return
original_handle_error = ConsoleHandler.handle_error
def _safe_handle_error(self, message: str) -> None: # type: ignore[override]
exc_type, _, _ = sys.exc_info()
if exc_type in (ValueError, OSError) and "close" in message.lower():
return
original_handle_error(self, message)
ConsoleHandler.handle_error = _safe_handle_error # type: ignore[assignment]
ConsoleHandler._lightrag_patched = True # type: ignore[attr-defined]
_patch_ascii_colors_console_handler()
# Global import for pypinyin with startup-time logging
try:
import pypinyin
@ -283,8 +341,8 @@ def setup_logger(
logger_instance.handlers = [] # Clear existing handlers
logger_instance.propagate = False
# Add console handler
console_handler = logging.StreamHandler()
# Add console handler with safe stream handling
console_handler = SafeStreamHandler()
console_handler.setFormatter(simple_formatter)
console_handler.setLevel(level)
logger_instance.addHandler(console_handler)
@ -350,9 +408,20 @@ class TaskState:
@dataclass
class EmbeddingFunc:
"""Embedding function wrapper with dimension validation
This class wraps an embedding function to ensure that the output embeddings have the correct dimension.
This class should not be wrapped multiple times.
Args:
embedding_dim: Expected dimension of the embeddings
func: The actual embedding function to wrap
max_token_size: Optional token limit for the embedding model
send_dimensions: Whether to inject embedding_dim as a keyword argument
"""
embedding_dim: int
func: callable
max_token_size: int | None = None # deprecated keep it for compatible only
max_token_size: int | None = None # Token limit for the embedding model
send_dimensions: bool = (
False # Control whether to send embedding_dim to the function
)
@ -376,7 +445,32 @@ class EmbeddingFunc:
# Inject embedding_dim from decorator
kwargs["embedding_dim"] = self.embedding_dim
return await self.func(*args, **kwargs)
# Call the actual embedding function
result = await self.func(*args, **kwargs)
# Validate embedding dimensions using total element count
total_elements = result.size # Total number of elements in the numpy array
expected_dim = self.embedding_dim
# Check if total elements can be evenly divided by embedding_dim
if total_elements % expected_dim != 0:
raise ValueError(
f"Embedding dimension mismatch detected: "
f"total elements ({total_elements}) cannot be evenly divided by "
f"expected dimension ({expected_dim}). "
)
# Optional: Verify vector count matches input text count
actual_vectors = total_elements // expected_dim
if args and isinstance(args[0], (list, tuple)):
expected_vectors = len(args[0])
if actual_vectors != expected_vectors:
raise ValueError(
f"Vector count mismatch: "
f"expected {expected_vectors} vectors but got {actual_vectors} vectors (from embedding result)."
)
return result
def compute_args_hash(*args: Any) -> str:
@ -930,70 +1024,24 @@ def load_json(file_name):
def _sanitize_string_for_json(text: str) -> str:
"""Remove characters that cannot be encoded in UTF-8 for JSON serialization.
This is a simpler sanitizer specifically for JSON that directly removes
problematic characters without attempting to encode first.
Uses regex for optimal performance with zero-copy optimization for clean strings.
Fast detection path for clean strings (99% of cases) with efficient removal for dirty strings.
Args:
text: String to sanitize
Returns:
Sanitized string safe for UTF-8 encoding in JSON
Original string if clean (zero-copy), sanitized string if dirty
"""
if not text:
return text
# Directly filter out problematic characters without pre-validation
sanitized = ""
for char in text:
code_point = ord(char)
# Skip surrogate characters (U+D800 to U+DFFF) - main cause of encoding errors
if 0xD800 <= code_point <= 0xDFFF:
continue
# Skip other non-characters in Unicode
elif code_point == 0xFFFE or code_point == 0xFFFF:
continue
else:
sanitized += char
# Fast path: Check if sanitization is needed using C-level regex search
if not _SURROGATE_PATTERN.search(text):
return text # Zero-copy for clean strings - most common case
return sanitized
def _sanitize_json_data(data: Any) -> Any:
"""Recursively sanitize all string values in data structure for safe UTF-8 encoding
DEPRECATED: This function creates a deep copy of the data which can be memory-intensive.
For new code, prefer using write_json with SanitizingJSONEncoder which sanitizes during
serialization without creating copies.
Handles all JSON-serializable types including:
- Dictionary keys and values
- Lists and tuples (preserves type)
- Nested structures
- Strings at any level
Args:
data: Data to sanitize (dict, list, tuple, str, or other types)
Returns:
Sanitized data with all strings cleaned of problematic characters
"""
if isinstance(data, dict):
# Sanitize both keys and values
return {
_sanitize_string_for_json(k)
if isinstance(k, str)
else k: _sanitize_json_data(v)
for k, v in data.items()
}
elif isinstance(data, (list, tuple)):
# Handle both lists and tuples, preserve original type
sanitized = [_sanitize_json_data(item) for item in data]
return type(data)(sanitized)
elif isinstance(data, str):
return _sanitize_string_for_json(data)
else:
# Numbers, booleans, None, etc. - return as-is
return data
# Slow path: Remove problematic characters using C-level regex substitution
return _SURROGATE_PATTERN.sub("", text)
class SanitizingJSONEncoder(json.JSONEncoder):