diff --git a/lightrag/llm/gemini.py b/lightrag/llm/gemini.py
index f3991403..4cec3e71 100644
--- a/lightrag/llm/gemini.py
+++ b/lightrag/llm/gemini.py
@@ -114,28 +114,44 @@ def _format_history_messages(history_messages: list[dict[str, Any]] | None) -> s
     return "\n".join(history_lines)
 
 
-def _extract_response_text(response: Any) -> str:
+def _extract_response_text(
+    response: Any, extract_thoughts: bool = False
+) -> tuple[str, str]:
     """
-    Extract text content from Gemini response, avoiding warnings about non-text parts.
+    Extract text content from Gemini response, separating regular content from thoughts.
 
-    Always extracts text manually from parts to avoid triggering warnings when
-    non-text parts (like 'thought_signature') are present in the response.
+    Args:
+        response: Gemini API response object
+        extract_thoughts: Whether to extract thought content separately
+
+    Returns:
+        Tuple of (regular_text, thought_text)
     """
     candidates = getattr(response, "candidates", None)
     if not candidates:
-        return ""
+        return ("", "")
+
+    regular_parts: list[str] = []
+    thought_parts: list[str] = []
 
-    parts: list[str] = []
     for candidate in candidates:
         if not getattr(candidate, "content", None):
            continue
-        for part in getattr(candidate.content, "parts", []):
-            # Only extract text parts to avoid non-text content like thought_signature
+        # Use 'or []' to handle None values from parts attribute
+        for part in getattr(candidate.content, "parts", None) or []:
             text = getattr(part, "text", None)
-            if text:
-                parts.append(text)
+            if not text:
+                continue
 
-    return "\n".join(parts)
+            # Check if this part is thought content using the 'thought' attribute
+            is_thought = getattr(part, "thought", False)
+
+            if is_thought and extract_thoughts:
+                thought_parts.append(text)
+            elif not is_thought:
+                regular_parts.append(text)
+
+    return ("\n".join(regular_parts), "\n".join(thought_parts))
 
 
 async def gemini_complete_if_cache(
@@ -143,18 +159,51 @@ async def gemini_complete_if_cache(
     prompt: str,
     system_prompt: str | None = None,
     history_messages: list[dict[str, Any]] | None = None,
-    *,
-    api_key: str | None = None,
+    enable_cot: bool = False,
     base_url: str | None = None,
-    generation_config: dict[str, Any] | None = None,
-    keyword_extraction: bool = False,
+    api_key: str | None = None,
     token_tracker: Any | None = None,
-    hashing_kv: Any | None = None,  # noqa: ARG001 - present for interface parity
     stream: bool | None = None,
-    enable_cot: bool = False,  # noqa: ARG001 - not supported by Gemini currently
-    timeout: float | None = None,  # noqa: ARG001 - handled by caller if needed
+    keyword_extraction: bool = False,
+    generation_config: dict[str, Any] | None = None,
     **_: Any,
 ) -> str | AsyncIterator[str]:
+    """
+    Complete a prompt using Gemini's API with Chain of Thought (COT) support.
+
+    This function supports automatic integration of reasoning content from Gemini models
+    that provide Chain of Thought capabilities via the thinking_config API feature.
+
+    COT Integration:
+    - When enable_cot=True: Thought content is wrapped in <think>...</think> tags
+    - When enable_cot=False: Thought content is filtered out, only regular content returned
+    - Thought content is identified by the 'thought' attribute on response parts
+    - Requires thinking_config to be enabled in generation_config for API to return thoughts
+
+    Args:
+        model: The Gemini model to use.
+        prompt: The prompt to complete.
+        system_prompt: Optional system prompt to include.
+        history_messages: Optional list of previous messages in the conversation.
+        api_key: Optional Gemini API key. If None, uses environment variable.
+        base_url: Optional custom API endpoint.
+        generation_config: Optional generation configuration dict.
+        keyword_extraction: Whether to use JSON response format.
+        token_tracker: Optional token usage tracker for monitoring API usage.
+        hashing_kv: Storage interface (for interface parity with other bindings).
+        stream: Whether to stream the response.
+        enable_cot: Whether to include Chain of Thought content in the response.
+        timeout: Request timeout (handled by caller if needed).
+        **_: Additional keyword arguments (ignored).
+
+    Returns:
+        The completed text (with COT content if enable_cot=True) or an async iterator
+        of text chunks if streaming. COT content is wrapped in <think>...</think> tags.
+
+    Raises:
+        RuntimeError: If the response from Gemini is empty.
+        ValueError: If API key is not provided or configured.
+    """
     loop = asyncio.get_running_loop()
 
     key = _ensure_api_key(api_key)
@@ -188,6 +237,11 @@ async def gemini_complete_if_cache(
         usage_container: dict[str, Any] = {}
 
         def _stream_model() -> None:
+            # COT state tracking for streaming
+            cot_active = False
+            cot_started = False
+            initial_content_seen = False
+
             try:
                 stream_kwargs = dict(request_kwargs)
                 stream_iterator = client.models.generate_content_stream(**stream_kwargs)
@@ -195,19 +249,59 @@ async def gemini_complete_if_cache(
                     usage = getattr(chunk, "usage_metadata", None)
                     if usage is not None:
                         usage_container["usage"] = usage
-                    # Always use manual extraction to avoid warnings about non-text parts
-                    text_piece = _extract_response_text(chunk)
-                    if text_piece:
-                        loop.call_soon_threadsafe(queue.put_nowait, text_piece)
+
+                    # Extract both regular and thought content
+                    regular_text, thought_text = _extract_response_text(
+                        chunk, extract_thoughts=True
+                    )
+
+                    if enable_cot:
+                        # Process regular content
+                        if regular_text:
+                            if not initial_content_seen:
+                                initial_content_seen = True
+
+                            # Close COT section if it was active
+                            if cot_active:
+                                loop.call_soon_threadsafe(queue.put_nowait, "</think>")
+                                cot_active = False
+
+                            # Send regular content
+                            loop.call_soon_threadsafe(queue.put_nowait, regular_text)
+
+                        # Process thought content
+                        if thought_text:
+                            if not initial_content_seen and not cot_started:
+                                # Start COT section
+                                loop.call_soon_threadsafe(queue.put_nowait, "<think>")
+                                cot_active = True
+                                cot_started = True
+
+                            # Send thought content if COT is active
+                            if cot_active:
+                                loop.call_soon_threadsafe(queue.put_nowait, thought_text)
+                    else:
+                        # COT disabled - only send regular content
+                        if regular_text:
+                            loop.call_soon_threadsafe(queue.put_nowait, regular_text)
+
+                # Ensure COT is properly closed if still active
+                if cot_active:
+                    loop.call_soon_threadsafe(queue.put_nowait, "</think>")
+
                 loop.call_soon_threadsafe(queue.put_nowait, None)
             except Exception as exc:  # pragma: no cover - surface runtime issues
+                # Try to close COT tag before reporting error
+                if cot_active:
+                    try:
+                        loop.call_soon_threadsafe(queue.put_nowait, "</think>")
+                    except Exception:
+                        pass
                 loop.call_soon_threadsafe(queue.put_nowait, exc)
 
         loop.run_in_executor(None, _stream_model)
 
         async def _async_stream() -> AsyncIterator[str]:
-            accumulated = ""
-            emitted = ""
             try:
                 while True:
                     item = await queue.get()
@@ -220,16 +314,9 @@ async def gemini_complete_if_cache(
                     if "\\u" in chunk_text:
                         chunk_text = safe_unicode_decode(chunk_text.encode("utf-8"))
 
-                    accumulated += chunk_text
-                    sanitized = remove_think_tags(accumulated)
-                    if sanitized.startswith(emitted):
-                        delta = sanitized[len(emitted) :]
-                    else:
-                        delta = sanitized
-                    emitted = sanitized
-
-                    if delta:
-                        yield delta
+                    # Yield the chunk directly without filtering
+                    # COT filtering is already handled in _stream_model()
+                    yield chunk_text
             finally:
                 usage = usage_container.get("usage")
                 if token_tracker and usage:
@@ -247,14 +334,33 @@ async def gemini_complete_if_cache(
 
     response = await asyncio.to_thread(_call_model)
 
-    text = _extract_response_text(response)
-    if not text:
+    # Extract both regular text and thought text
+    regular_text, thought_text = _extract_response_text(response, extract_thoughts=True)
+
+    # Apply COT filtering logic based on enable_cot parameter
+    if enable_cot:
+        # Include thought content wrapped in <think> tags
+        if thought_text and thought_text.strip():
+            if not regular_text or regular_text.strip() == "":
+                # Only thought content available
+                final_text = f"<think>{thought_text}</think>"
+            else:
+                # Both content types present: prepend thought to regular content
+                final_text = f"<think>{thought_text}</think>{regular_text}"
+        else:
+            # No thought content, use regular content only
+            final_text = regular_text or ""
+    else:
+        # Filter out thought content, return only regular content
+        final_text = regular_text or ""
+
+    if not final_text:
         raise RuntimeError("Gemini response did not contain any text content.")
 
-    if "\\u" in text:
-        text = safe_unicode_decode(text.encode("utf-8"))
+    if "\\u" in final_text:
+        final_text = safe_unicode_decode(final_text.encode("utf-8"))
 
-    text = remove_think_tags(text)
+    final_text = remove_think_tags(final_text)
 
     usage = getattr(response, "usage_metadata", None)
     if token_tracker and usage:
@@ -266,8 +372,8 @@ async def gemini_complete_if_cache(
             }
         )
 
-    logger.debug("Gemini response length: %s", len(text))
-    return text
+    logger.debug("Gemini response length: %s", len(final_text))
+    return final_text
 
 
 async def gemini_model_complete(
diff --git a/lightrag/llm/openai.py b/lightrag/llm/openai.py
index 66c3bfe4..2cdbb72b 100644
--- a/lightrag/llm/openai.py
+++ b/lightrag/llm/openai.py
@@ -174,6 +174,7 @@ async def openai_complete_if_cache(
             explicit parameters (api_key, base_url).
         - hashing_kv: Will be removed from kwargs before passing to OpenAI.
         - keyword_extraction: Will be removed from kwargs before passing to OpenAI.
+        - stream: Whether to stream the response. Default is False.
 
     Returns:
         The completed text (with integrated COT content if available) or an async iterator
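Below the diff, two illustrative sketches (not part of the patch). First, a minimal stub-based check of the reworked `_extract_response_text` helper: the `SimpleNamespace` objects only mimic the `candidates` -> `content` -> `parts` attributes the helper reads, and the import path assumes the module layout shown in the diff.

```python
from types import SimpleNamespace

from lightrag.llm.gemini import _extract_response_text

# Stub response mimicking the attributes the helper reads: each part carries
# 'text' plus an optional 'thought' flag, as described in the diff.
fake_response = SimpleNamespace(
    candidates=[
        SimpleNamespace(
            content=SimpleNamespace(
                parts=[
                    SimpleNamespace(text="Reasoning about the question...", thought=True),
                    SimpleNamespace(text="The sky is blue due to Rayleigh scattering.", thought=False),
                ]
            )
        )
    ]
)

# With extract_thoughts=True the thought part is returned separately.
regular, thoughts = _extract_response_text(fake_response, extract_thoughts=True)
print(regular)   # "The sky is blue due to Rayleigh scattering."
print(thoughts)  # "Reasoning about the question..."

# With the default extract_thoughts=False the thought part is dropped.
regular_only, empty = _extract_response_text(fake_response)
assert empty == ""
```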
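Second, a hedged end-to-end sketch of the new `enable_cot` path. The model name is a placeholder, the call assumes `GEMINI_API_KEY` is configured in the environment, and the `thinking_config` keys shown are an assumption about how the binding forwards `generation_config` to the google-genai SDK (per the docstring, thinking must be enabled for the API to return thought parts).

```python
import asyncio

from lightrag.llm.gemini import gemini_complete_if_cache


async def main() -> None:
    # Assumed config shape: thinking_config must be enabled for Gemini to emit
    # thought parts; exact keys depend on the installed google-genai version.
    generation_config = {"thinking_config": {"include_thoughts": True}}

    answer = await gemini_complete_if_cache(
        "gemini-2.5-flash",  # placeholder model name
        "Why is the sky blue?",
        system_prompt="Answer briefly.",
        enable_cot=True,  # keep thought content, wrapped in <think>...</think>
        generation_config=generation_config,
    )
    print(answer)


asyncio.run(main())
```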