From 3ba4b03ed69679b0c1294f456b485baff6b71836 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Rapha=C3=ABl=20MANSUY?=
Date: Thu, 4 Dec 2025 19:14:29 +0800
Subject: [PATCH] cherry-pick 8c275553

---
 lightrag/llm/gemini.py | 190 +++++++++-------------------------------
 1 file changed, 42 insertions(+), 148 deletions(-)

diff --git a/lightrag/llm/gemini.py b/lightrag/llm/gemini.py
index 4cec3e71..f3991403 100644
--- a/lightrag/llm/gemini.py
+++ b/lightrag/llm/gemini.py
@@ -114,44 +114,28 @@ def _format_history_messages(history_messages: list[dict[str, Any]] | None) -> s
     return "\n".join(history_lines)
 
 
-def _extract_response_text(
-    response: Any, extract_thoughts: bool = False
-) -> tuple[str, str]:
+def _extract_response_text(response: Any) -> str:
     """
-    Extract text content from Gemini response, separating regular content from thoughts.
+    Extract text content from Gemini response, avoiding warnings about non-text parts.
 
-    Args:
-        response: Gemini API response object
-        extract_thoughts: Whether to extract thought content separately
-
-    Returns:
-        Tuple of (regular_text, thought_text)
+    Always extracts text manually from parts to avoid triggering warnings when
+    non-text parts (like 'thought_signature') are present in the response.
     """
     candidates = getattr(response, "candidates", None)
     if not candidates:
-        return ("", "")
-
-    regular_parts: list[str] = []
-    thought_parts: list[str] = []
+        return ""
 
+    parts: list[str] = []
     for candidate in candidates:
         if not getattr(candidate, "content", None):
             continue
-        # Use 'or []' to handle None values from parts attribute
-        for part in getattr(candidate.content, "parts", None) or []:
+        for part in getattr(candidate.content, "parts", None) or []:
+            # Only extract text parts to avoid non-text content like thought_signature
             text = getattr(part, "text", None)
-            if not text:
-                continue
+            if text:
+                parts.append(text)
 
-            # Check if this part is thought content using the 'thought' attribute
-            is_thought = getattr(part, "thought", False)
-
-            if is_thought and extract_thoughts:
-                thought_parts.append(text)
-            elif not is_thought:
-                regular_parts.append(text)
-
-    return ("\n".join(regular_parts), "\n".join(thought_parts))
+    return "\n".join(parts)
 
 
 async def gemini_complete_if_cache(
@@ -159,51 +143,18 @@ async def gemini_complete_if_cache(
     prompt: str,
     system_prompt: str | None = None,
     history_messages: list[dict[str, Any]] | None = None,
-    enable_cot: bool = False,
-    base_url: str | None = None,
+    *,
     api_key: str | None = None,
-    token_tracker: Any | None = None,
-    stream: bool | None = None,
-    keyword_extraction: bool = False,
+    base_url: str | None = None,
     generation_config: dict[str, Any] | None = None,
+    keyword_extraction: bool = False,
+    token_tracker: Any | None = None,
+    hashing_kv: Any | None = None,  # noqa: ARG001 - present for interface parity
+    stream: bool | None = None,
+    enable_cot: bool = False,  # noqa: ARG001 - not supported by Gemini currently
+    timeout: float | None = None,  # noqa: ARG001 - handled by caller if needed
     **_: Any,
 ) -> str | AsyncIterator[str]:
-    """
-    Complete a prompt using Gemini's API with Chain of Thought (COT) support.
-
-    This function supports automatic integration of reasoning content from Gemini models
-    that provide Chain of Thought capabilities via the thinking_config API feature.
-
-    COT Integration:
-    - When enable_cot=True: Thought content is wrapped in <think>...</think> tags
-    - When enable_cot=False: Thought content is filtered out, only regular content returned
-    - Thought content is identified by the 'thought' attribute on response parts
-    - Requires thinking_config to be enabled in generation_config for API to return thoughts
-
-    Args:
-        model: The Gemini model to use.
-        prompt: The prompt to complete.
-        system_prompt: Optional system prompt to include.
-        history_messages: Optional list of previous messages in the conversation.
-        api_key: Optional Gemini API key. If None, uses environment variable.
-        base_url: Optional custom API endpoint.
-        generation_config: Optional generation configuration dict.
-        keyword_extraction: Whether to use JSON response format.
-        token_tracker: Optional token usage tracker for monitoring API usage.
-        hashing_kv: Storage interface (for interface parity with other bindings).
-        stream: Whether to stream the response.
-        enable_cot: Whether to include Chain of Thought content in the response.
-        timeout: Request timeout (handled by caller if needed).
-        **_: Additional keyword arguments (ignored).
-
-    Returns:
-        The completed text (with COT content if enable_cot=True) or an async iterator
-        of text chunks if streaming. COT content is wrapped in <think>...</think> tags.
-
-    Raises:
-        RuntimeError: If the response from Gemini is empty.
-        ValueError: If API key is not provided or configured.
-    """
     loop = asyncio.get_running_loop()
     key = _ensure_api_key(api_key)
 
@@ -237,11 +188,6 @@ async def gemini_complete_if_cache(
     usage_container: dict[str, Any] = {}
 
     def _stream_model() -> None:
-        # COT state tracking for streaming
-        cot_active = False
-        cot_started = False
-        initial_content_seen = False
-
         try:
             stream_kwargs = dict(request_kwargs)
             stream_iterator = client.models.generate_content_stream(**stream_kwargs)
@@ -249,59 +195,19 @@
             for chunk in stream_iterator:
                 usage = getattr(chunk, "usage_metadata", None)
                 if usage is not None:
                     usage_container["usage"] = usage
-
-                # Extract both regular and thought content
-                regular_text, thought_text = _extract_response_text(
-                    chunk, extract_thoughts=True
-                )
-
-                if enable_cot:
-                    # Process regular content
-                    if regular_text:
-                        if not initial_content_seen:
-                            initial_content_seen = True
-
-                        # Close COT section if it was active
-                        if cot_active:
-                            loop.call_soon_threadsafe(queue.put_nowait, "</think>")
-                            cot_active = False
-
-                        # Send regular content
-                        loop.call_soon_threadsafe(queue.put_nowait, regular_text)
-
-                    # Process thought content
-                    if thought_text:
-                        if not initial_content_seen and not cot_started:
-                            # Start COT section
-                            loop.call_soon_threadsafe(queue.put_nowait, "<think>")
-                            cot_active = True
-                            cot_started = True
-
-                        # Send thought content if COT is active
-                        if cot_active:
-                            loop.call_soon_threadsafe(queue.put_nowait, thought_text)
-                else:
-                    # COT disabled - only send regular content
-                    if regular_text:
-                        loop.call_soon_threadsafe(queue.put_nowait, regular_text)
-
-            # Ensure COT is properly closed if still active
-            if cot_active:
-                loop.call_soon_threadsafe(queue.put_nowait, "</think>")
-
+                # Always use manual extraction to avoid warnings about non-text parts
+                text_piece = _extract_response_text(chunk)
+                if text_piece:
+                    loop.call_soon_threadsafe(queue.put_nowait, text_piece)
             loop.call_soon_threadsafe(queue.put_nowait, None)
         except Exception as exc:  # pragma: no cover - surface runtime issues
-            # Try to close COT tag before reporting error
-            if cot_active:
-                try:
-                    loop.call_soon_threadsafe(queue.put_nowait, "</think>")
-                except Exception:
-                    pass
            loop.call_soon_threadsafe(queue.put_nowait, exc)
 
     loop.run_in_executor(None, _stream_model)
 
     async def _async_stream() -> AsyncIterator[str]:
+        accumulated = ""
+        emitted = ""
         try:
             while True:
                 item = await queue.get()
@@ -314,9 +220,16 @@
                 if "\\u" in chunk_text:
                     chunk_text = safe_unicode_decode(chunk_text.encode("utf-8"))
 
-                # Yield the chunk directly without filtering
-                # COT filtering is already handled in _stream_model()
-                yield chunk_text
+                accumulated += chunk_text
+                sanitized = remove_think_tags(accumulated)
+                if sanitized.startswith(emitted):
+                    delta = sanitized[len(emitted) :]
+                else:
+                    delta = sanitized
+                emitted = sanitized
+
+                if delta:
+                    yield delta
         finally:
             usage = usage_container.get("usage")
             if token_tracker and usage:
@@ -334,33 +247,14 @@
 
     response = await asyncio.to_thread(_call_model)
 
-    # Extract both regular text and thought text
-    regular_text, thought_text = _extract_response_text(response, extract_thoughts=True)
-
-    # Apply COT filtering logic based on enable_cot parameter
-    if enable_cot:
-        # Include thought content wrapped in tags
-        if thought_text and thought_text.strip():
-            if not regular_text or regular_text.strip() == "":
-                # Only thought content available
-                final_text = f"<think>{thought_text}</think>"
-            else:
-                # Both content types present: prepend thought to regular content
-                final_text = f"<think>{thought_text}</think>{regular_text}"
-        else:
-            # No thought content, use regular content only
-            final_text = regular_text or ""
-    else:
-        # Filter out thought content, return only regular content
-        final_text = regular_text or ""
-
-    if not final_text:
+    text = _extract_response_text(response)
+    if not text:
         raise RuntimeError("Gemini response did not contain any text content.")
 
-    if "\\u" in final_text:
-        final_text = safe_unicode_decode(final_text.encode("utf-8"))
+    if "\\u" in text:
+        text = safe_unicode_decode(text.encode("utf-8"))
 
-    final_text = remove_think_tags(final_text)
+    text = remove_think_tags(text)
 
     usage = getattr(response, "usage_metadata", None)
     if token_tracker and usage:
@@ -372,8 +266,8 @@
             }
         )
 
-    logger.debug("Gemini response length: %s", len(final_text))
-    return final_text
+    logger.debug("Gemini response length: %s", len(text))
+    return text
 
 
 async def gemini_model_complete(
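
Note on the streaming change: _async_stream now re-runs remove_think_tags over the full accumulated text on every chunk and yields only the newly visible suffix, so think-tag content is filtered even when a tag straddles chunk boundaries. The sketch below is a minimal, self-contained illustration of that technique; the remove_think_tags stub and the filtered_deltas helper are hypothetical stand-ins for this note, not LightRAG's actual code.

import re

def remove_think_tags(text: str) -> str:
    # Hypothetical stub: drop closed <think>...</think> blocks, then any
    # still-open trailing <think> block.
    text = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL)
    return re.sub(r"<think>.*$", "", text, flags=re.DOTALL)

def filtered_deltas(chunks):
    # Mirrors the accumulated/emitted bookkeeping in _async_stream: keep all
    # raw text seen so far, sanitize it, and yield only what became visible
    # with this chunk.
    accumulated = ""
    emitted = ""
    for chunk in chunks:
        accumulated += chunk
        sanitized = remove_think_tags(accumulated)
        # Sanitized output normally only grows; if sanitization ever rewrote
        # earlier output, fall back to re-emitting the whole string.
        delta = sanitized[len(emitted):] if sanitized.startswith(emitted) else sanitized
        emitted = sanitized
        if delta:
            yield delta

chunks = ["Hello ", "<think>hidden reasoning", "</think> world"]
print("".join(filtered_deltas(chunks)))  # -> "Hello  world"

The startswith fallback trades minimal output for safety: if the sanitized prefix ever changes, the code re-sends the whole sanitized string, which may duplicate text on the consumer side but never drops any.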