diff --git a/lightrag/api/routers/ollama_api.py b/lightrag/api/routers/ollama_api.py index bf677c19..52f6f218 100644 --- a/lightrag/api/routers/ollama_api.py +++ b/lightrag/api/routers/ollama_api.py @@ -8,7 +8,6 @@ import re from enum import Enum from fastapi.responses import StreamingResponse import asyncio -from ascii_colors import trace_exception from lightrag import LightRAG, QueryParam from lightrag.utils import TiktokenTokenizer from lightrag.api.utils_api import get_combined_auth_dependency @@ -335,118 +334,113 @@ class OllamaAPI: ) async def stream_generator(): - try: - first_chunk_time = None + first_chunk_time = None + last_chunk_time = time.time_ns() + total_response = "" + + # Ensure response is an async generator + if isinstance(response, str): + # If it's a string, send in two parts + first_chunk_time = start_time last_chunk_time = time.time_ns() - total_response = "" + total_response = response - # Ensure response is an async generator - if isinstance(response, str): - # If it's a string, send in two parts - first_chunk_time = start_time - last_chunk_time = time.time_ns() - total_response = response + data = { + "model": self.ollama_server_infos.LIGHTRAG_MODEL, + "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "response": response, + "done": False, + } + yield f"{json.dumps(data, ensure_ascii=False)}\n" - data = { + completion_tokens = estimate_tokens(total_response) + total_time = last_chunk_time - start_time + prompt_eval_time = first_chunk_time - start_time + eval_time = last_chunk_time - first_chunk_time + + data = { + "model": self.ollama_server_infos.LIGHTRAG_MODEL, + "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "response": "", + "done": True, + "done_reason": "stop", + "context": [], + "total_duration": total_time, + "load_duration": 0, + "prompt_eval_count": prompt_tokens, + "prompt_eval_duration": prompt_eval_time, + "eval_count": completion_tokens, + "eval_duration": eval_time, + } + yield f"{json.dumps(data, ensure_ascii=False)}\n" + else: + try: + async for chunk in response: + if chunk: + if first_chunk_time is None: + first_chunk_time = time.time_ns() + + last_chunk_time = time.time_ns() + + total_response += chunk + data = { + "model": self.ollama_server_infos.LIGHTRAG_MODEL, + "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "response": chunk, + "done": False, + } + yield f"{json.dumps(data, ensure_ascii=False)}\n" + except (asyncio.CancelledError, Exception) as e: + error_msg = str(e) + if isinstance(e, asyncio.CancelledError): + error_msg = "Stream was cancelled by server" + else: + error_msg = f"Provider error: {error_msg}" + + logger.error(f"Stream error: {error_msg}") + + # Send error message to client + error_data = { "model": self.ollama_server_infos.LIGHTRAG_MODEL, "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, - "response": response, + "response": f"\n\nError: {error_msg}", + "error": f"\n\nError: {error_msg}", "done": False, } - yield f"{json.dumps(data, ensure_ascii=False)}\n" + yield f"{json.dumps(error_data, ensure_ascii=False)}\n" - completion_tokens = estimate_tokens(total_response) - total_time = last_chunk_time - start_time - prompt_eval_time = first_chunk_time - start_time - eval_time = last_chunk_time - first_chunk_time - - data = { + # Send final message to close the stream + final_data = { "model": self.ollama_server_infos.LIGHTRAG_MODEL, "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, "response": "", "done": True, - "done_reason": "stop", - "context": [], - "total_duration": total_time, - "load_duration": 0, - "prompt_eval_count": prompt_tokens, - "prompt_eval_duration": prompt_eval_time, - "eval_count": completion_tokens, - "eval_duration": eval_time, } - yield f"{json.dumps(data, ensure_ascii=False)}\n" - else: - try: - async for chunk in response: - if chunk: - if first_chunk_time is None: - first_chunk_time = time.time_ns() - - last_chunk_time = time.time_ns() - - total_response += chunk - data = { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, - "response": chunk, - "done": False, - } - yield f"{json.dumps(data, ensure_ascii=False)}\n" - except (asyncio.CancelledError, Exception) as e: - error_msg = str(e) - if isinstance(e, asyncio.CancelledError): - error_msg = "Stream was cancelled by server" - else: - error_msg = f"Provider error: {error_msg}" - - logger.error(f"Stream error: {error_msg}") - - # Send error message to client - error_data = { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, - "response": f"\n\nError: {error_msg}", - "error": f"\n\nError: {error_msg}", - "done": False, - } - yield f"{json.dumps(error_data, ensure_ascii=False)}\n" - - # Send final message to close the stream - final_data = { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, - "response": "", - "done": True, - } - yield f"{json.dumps(final_data, ensure_ascii=False)}\n" - return - if first_chunk_time is None: - first_chunk_time = start_time - completion_tokens = estimate_tokens(total_response) - total_time = last_chunk_time - start_time - prompt_eval_time = first_chunk_time - start_time - eval_time = last_chunk_time - first_chunk_time - - data = { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, - "response": "", - "done": True, - "done_reason": "stop", - "context": [], - "total_duration": total_time, - "load_duration": 0, - "prompt_eval_count": prompt_tokens, - "prompt_eval_duration": prompt_eval_time, - "eval_count": completion_tokens, - "eval_duration": eval_time, - } - yield f"{json.dumps(data, ensure_ascii=False)}\n" + yield f"{json.dumps(final_data, ensure_ascii=False)}\n" return + if first_chunk_time is None: + first_chunk_time = start_time + completion_tokens = estimate_tokens(total_response) + total_time = last_chunk_time - start_time + prompt_eval_time = first_chunk_time - start_time + eval_time = last_chunk_time - first_chunk_time - except Exception as e: - trace_exception(e) - raise + data = { + "model": self.ollama_server_infos.LIGHTRAG_MODEL, + "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "response": "", + "done": True, + "done_reason": "stop", + "context": [], + "total_duration": total_time, + "load_duration": 0, + "prompt_eval_count": prompt_tokens, + "prompt_eval_duration": prompt_eval_time, + "eval_count": completion_tokens, + "eval_duration": eval_time, + } + yield f"{json.dumps(data, ensure_ascii=False)}\n" + return return StreamingResponse( stream_generator(), @@ -488,7 +482,6 @@ class OllamaAPI: "eval_duration": eval_time, } except Exception as e: - trace_exception(e) raise HTTPException(status_code=500, detail=str(e)) @self.router.post( @@ -558,36 +551,98 @@ class OllamaAPI: ) async def stream_generator(): - try: - first_chunk_time = None + first_chunk_time = None + last_chunk_time = time.time_ns() + total_response = "" + + # Ensure response is an async generator + if isinstance(response, str): + # If it's a string, send in two parts + first_chunk_time = start_time last_chunk_time = time.time_ns() - total_response = "" + total_response = response - # Ensure response is an async generator - if isinstance(response, str): - # If it's a string, send in two parts - first_chunk_time = start_time - last_chunk_time = time.time_ns() - total_response = response + data = { + "model": self.ollama_server_infos.LIGHTRAG_MODEL, + "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "message": { + "role": "assistant", + "content": response, + "images": None, + }, + "done": False, + } + yield f"{json.dumps(data, ensure_ascii=False)}\n" - data = { + completion_tokens = estimate_tokens(total_response) + total_time = last_chunk_time - start_time + prompt_eval_time = first_chunk_time - start_time + eval_time = last_chunk_time - first_chunk_time + + data = { + "model": self.ollama_server_infos.LIGHTRAG_MODEL, + "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "message": { + "role": "assistant", + "content": "", + "images": None, + }, + "done_reason": "stop", + "done": True, + "total_duration": total_time, + "load_duration": 0, + "prompt_eval_count": prompt_tokens, + "prompt_eval_duration": prompt_eval_time, + "eval_count": completion_tokens, + "eval_duration": eval_time, + } + yield f"{json.dumps(data, ensure_ascii=False)}\n" + else: + try: + async for chunk in response: + if chunk: + if first_chunk_time is None: + first_chunk_time = time.time_ns() + + last_chunk_time = time.time_ns() + + total_response += chunk + data = { + "model": self.ollama_server_infos.LIGHTRAG_MODEL, + "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "message": { + "role": "assistant", + "content": chunk, + "images": None, + }, + "done": False, + } + yield f"{json.dumps(data, ensure_ascii=False)}\n" + except (asyncio.CancelledError, Exception) as e: + error_msg = str(e) + if isinstance(e, asyncio.CancelledError): + error_msg = "Stream was cancelled by server" + else: + error_msg = f"Provider error: {error_msg}" + + logger.error(f"Stream error: {error_msg}") + + # Send error message to client + error_data = { "model": self.ollama_server_infos.LIGHTRAG_MODEL, "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, "message": { "role": "assistant", - "content": response, + "content": f"\n\nError: {error_msg}", "images": None, }, + "error": f"\n\nError: {error_msg}", "done": False, } - yield f"{json.dumps(data, ensure_ascii=False)}\n" + yield f"{json.dumps(error_data, ensure_ascii=False)}\n" - completion_tokens = estimate_tokens(total_response) - total_time = last_chunk_time - start_time - prompt_eval_time = first_chunk_time - start_time - eval_time = last_chunk_time - first_chunk_time - - data = { + # Send final message to close the stream + final_data = { "model": self.ollama_server_infos.LIGHTRAG_MODEL, "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, "message": { @@ -595,103 +650,36 @@ class OllamaAPI: "content": "", "images": None, }, - "done_reason": "stop", "done": True, - "total_duration": total_time, - "load_duration": 0, - "prompt_eval_count": prompt_tokens, - "prompt_eval_duration": prompt_eval_time, - "eval_count": completion_tokens, - "eval_duration": eval_time, } - yield f"{json.dumps(data, ensure_ascii=False)}\n" - else: - try: - async for chunk in response: - if chunk: - if first_chunk_time is None: - first_chunk_time = time.time_ns() + yield f"{json.dumps(final_data, ensure_ascii=False)}\n" + return - last_chunk_time = time.time_ns() + if first_chunk_time is None: + first_chunk_time = start_time + completion_tokens = estimate_tokens(total_response) + total_time = last_chunk_time - start_time + prompt_eval_time = first_chunk_time - start_time + eval_time = last_chunk_time - first_chunk_time - total_response += chunk - data = { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, - "message": { - "role": "assistant", - "content": chunk, - "images": None, - }, - "done": False, - } - yield f"{json.dumps(data, ensure_ascii=False)}\n" - except (asyncio.CancelledError, Exception) as e: - error_msg = str(e) - if isinstance(e, asyncio.CancelledError): - error_msg = "Stream was cancelled by server" - else: - error_msg = f"Provider error: {error_msg}" - - logger.error(f"Stream error: {error_msg}") - - # Send error message to client - error_data = { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, - "message": { - "role": "assistant", - "content": f"\n\nError: {error_msg}", - "images": None, - }, - "error": f"\n\nError: {error_msg}", - "done": False, - } - yield f"{json.dumps(error_data, ensure_ascii=False)}\n" - - # Send final message to close the stream - final_data = { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, - "message": { - "role": "assistant", - "content": "", - "images": None, - }, - "done": True, - } - yield f"{json.dumps(final_data, ensure_ascii=False)}\n" - return - - if first_chunk_time is None: - first_chunk_time = start_time - completion_tokens = estimate_tokens(total_response) - total_time = last_chunk_time - start_time - prompt_eval_time = first_chunk_time - start_time - eval_time = last_chunk_time - first_chunk_time - - data = { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, - "message": { - "role": "assistant", - "content": "", - "images": None, - }, - "done_reason": "stop", - "done": True, - "total_duration": total_time, - "load_duration": 0, - "prompt_eval_count": prompt_tokens, - "prompt_eval_duration": prompt_eval_time, - "eval_count": completion_tokens, - "eval_duration": eval_time, - } - yield f"{json.dumps(data, ensure_ascii=False)}\n" - - except Exception as e: - trace_exception(e) - raise + data = { + "model": self.ollama_server_infos.LIGHTRAG_MODEL, + "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "message": { + "role": "assistant", + "content": "", + "images": None, + }, + "done_reason": "stop", + "done": True, + "total_duration": total_time, + "load_duration": 0, + "prompt_eval_count": prompt_tokens, + "prompt_eval_duration": prompt_eval_time, + "eval_count": completion_tokens, + "eval_duration": eval_time, + } + yield f"{json.dumps(data, ensure_ascii=False)}\n" return StreamingResponse( stream_generator(), @@ -753,5 +741,4 @@ class OllamaAPI: "eval_duration": eval_time, } except Exception as e: - trace_exception(e) raise HTTPException(status_code=500, detail=str(e)) diff --git a/lightrag/api/routers/query_routes.py b/lightrag/api/routers/query_routes.py index a5ebb995..6c892da2 100644 --- a/lightrag/api/routers/query_routes.py +++ b/lightrag/api/routers/query_routes.py @@ -15,8 +15,6 @@ from lightrag.models.tenant import TenantContext from lightrag.tenant_rag_manager import TenantRAGManager from pydantic import BaseModel, Field, field_validator -from ascii_colors import trace_exception - router = APIRouter(tags=["query"]) @@ -399,7 +397,6 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60, rag else: return QueryResponse(response=response_content, references=None) except Exception as e: - trace_exception(e) raise HTTPException(status_code=500, detail=str(e)) @router.post( @@ -650,7 +647,6 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60, rag }, ) except Exception as e: - trace_exception(e) raise HTTPException(status_code=500, detail=str(e)) @router.post( @@ -1061,7 +1057,6 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60, rag data={}, ) except Exception as e: - trace_exception(e) raise HTTPException(status_code=500, detail=str(e)) return router diff --git a/lightrag/utils.py b/lightrag/utils.py index da27926c..6a7237c0 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -1,6 +1,8 @@ from __future__ import annotations import weakref +import sys + import asyncio import html import csv @@ -40,6 +42,35 @@ from lightrag.constants import ( SOURCE_IDS_LIMIT_METHOD_FIFO, ) +# Precompile regex pattern for JSON sanitization (module-level, compiled once) +_SURROGATE_PATTERN = re.compile(r"[\uD800-\uDFFF\uFFFE\uFFFF]") + + +class SafeStreamHandler(logging.StreamHandler): + """StreamHandler that gracefully handles closed streams during shutdown. + + This handler prevents "ValueError: I/O operation on closed file" errors + that can occur when pytest or other test frameworks close stdout/stderr + before Python's logging cleanup runs. + """ + + def flush(self): + """Flush the stream, ignoring errors if the stream is closed.""" + try: + super().flush() + except (ValueError, OSError): + # Stream is closed or otherwise unavailable, silently ignore + pass + + def close(self): + """Close the handler, ignoring errors if the stream is already closed.""" + try: + super().close() + except (ValueError, OSError): + # Stream is closed or otherwise unavailable, silently ignore + pass + + # Initialize logger with basic configuration logger = logging.getLogger("lightrag") logger.propagate = False # prevent log message send to root logger @@ -47,7 +78,7 @@ logger.setLevel(logging.INFO) # Add console handler if no handlers exist if not logger.handlers: - console_handler = logging.StreamHandler() + console_handler = SafeStreamHandler() console_handler.setLevel(logging.INFO) formatter = logging.Formatter("%(levelname)s: %(message)s") console_handler.setFormatter(formatter) @@ -56,6 +87,33 @@ if not logger.handlers: # Set httpx logging level to WARNING logging.getLogger("httpx").setLevel(logging.WARNING) + +def _patch_ascii_colors_console_handler() -> None: + """Prevent ascii_colors from printing flush errors during interpreter exit.""" + + try: + from ascii_colors import ConsoleHandler + except ImportError: + return + + if getattr(ConsoleHandler, "_lightrag_patched", False): + return + + original_handle_error = ConsoleHandler.handle_error + + def _safe_handle_error(self, message: str) -> None: # type: ignore[override] + exc_type, _, _ = sys.exc_info() + if exc_type in (ValueError, OSError) and "close" in message.lower(): + return + original_handle_error(self, message) + + ConsoleHandler.handle_error = _safe_handle_error # type: ignore[assignment] + ConsoleHandler._lightrag_patched = True # type: ignore[attr-defined] + + +_patch_ascii_colors_console_handler() + + # Global import for pypinyin with startup-time logging try: import pypinyin @@ -283,8 +341,8 @@ def setup_logger( logger_instance.handlers = [] # Clear existing handlers logger_instance.propagate = False - # Add console handler - console_handler = logging.StreamHandler() + # Add console handler with safe stream handling + console_handler = SafeStreamHandler() console_handler.setFormatter(simple_formatter) console_handler.setLevel(level) logger_instance.addHandler(console_handler) @@ -350,9 +408,20 @@ class TaskState: @dataclass class EmbeddingFunc: + """Embedding function wrapper with dimension validation + This class wraps an embedding function to ensure that the output embeddings have the correct dimension. + This class should not be wrapped multiple times. + + Args: + embedding_dim: Expected dimension of the embeddings + func: The actual embedding function to wrap + max_token_size: Optional token limit for the embedding model + send_dimensions: Whether to inject embedding_dim as a keyword argument + """ + embedding_dim: int func: callable - max_token_size: int | None = None # deprecated keep it for compatible only + max_token_size: int | None = None # Token limit for the embedding model send_dimensions: bool = ( False # Control whether to send embedding_dim to the function ) @@ -376,7 +445,32 @@ class EmbeddingFunc: # Inject embedding_dim from decorator kwargs["embedding_dim"] = self.embedding_dim - return await self.func(*args, **kwargs) + # Call the actual embedding function + result = await self.func(*args, **kwargs) + + # Validate embedding dimensions using total element count + total_elements = result.size # Total number of elements in the numpy array + expected_dim = self.embedding_dim + + # Check if total elements can be evenly divided by embedding_dim + if total_elements % expected_dim != 0: + raise ValueError( + f"Embedding dimension mismatch detected: " + f"total elements ({total_elements}) cannot be evenly divided by " + f"expected dimension ({expected_dim}). " + ) + + # Optional: Verify vector count matches input text count + actual_vectors = total_elements // expected_dim + if args and isinstance(args[0], (list, tuple)): + expected_vectors = len(args[0]) + if actual_vectors != expected_vectors: + raise ValueError( + f"Vector count mismatch: " + f"expected {expected_vectors} vectors but got {actual_vectors} vectors (from embedding result)." + ) + + return result def compute_args_hash(*args: Any) -> str: @@ -930,70 +1024,24 @@ def load_json(file_name): def _sanitize_string_for_json(text: str) -> str: """Remove characters that cannot be encoded in UTF-8 for JSON serialization. - This is a simpler sanitizer specifically for JSON that directly removes - problematic characters without attempting to encode first. + Uses regex for optimal performance with zero-copy optimization for clean strings. + Fast detection path for clean strings (99% of cases) with efficient removal for dirty strings. Args: text: String to sanitize Returns: - Sanitized string safe for UTF-8 encoding in JSON + Original string if clean (zero-copy), sanitized string if dirty """ if not text: return text - # Directly filter out problematic characters without pre-validation - sanitized = "" - for char in text: - code_point = ord(char) - # Skip surrogate characters (U+D800 to U+DFFF) - main cause of encoding errors - if 0xD800 <= code_point <= 0xDFFF: - continue - # Skip other non-characters in Unicode - elif code_point == 0xFFFE or code_point == 0xFFFF: - continue - else: - sanitized += char + # Fast path: Check if sanitization is needed using C-level regex search + if not _SURROGATE_PATTERN.search(text): + return text # Zero-copy for clean strings - most common case - return sanitized - - -def _sanitize_json_data(data: Any) -> Any: - """Recursively sanitize all string values in data structure for safe UTF-8 encoding - - DEPRECATED: This function creates a deep copy of the data which can be memory-intensive. - For new code, prefer using write_json with SanitizingJSONEncoder which sanitizes during - serialization without creating copies. - - Handles all JSON-serializable types including: - - Dictionary keys and values - - Lists and tuples (preserves type) - - Nested structures - - Strings at any level - - Args: - data: Data to sanitize (dict, list, tuple, str, or other types) - - Returns: - Sanitized data with all strings cleaned of problematic characters - """ - if isinstance(data, dict): - # Sanitize both keys and values - return { - _sanitize_string_for_json(k) - if isinstance(k, str) - else k: _sanitize_json_data(v) - for k, v in data.items() - } - elif isinstance(data, (list, tuple)): - # Handle both lists and tuples, preserve original type - sanitized = [_sanitize_json_data(item) for item in data] - return type(data)(sanitized) - elif isinstance(data, str): - return _sanitize_string_for_json(data) - else: - # Numbers, booleans, None, etc. - return as-is - return data + # Slow path: Remove problematic characters using C-level regex substitution + return _SURROGATE_PATTERN.sub("", text) class SanitizingJSONEncoder(json.JSONEncoder):