diff --git a/env.example b/env.example
index 521021b8..84c1bbe7 100644
--- a/env.example
+++ b/env.example
@@ -194,9 +194,10 @@ LLM_BINDING_API_KEY=your_api_key
 ### Gemini example
 # LLM_BINDING=gemini
 # LLM_MODEL=gemini-flash-latest
-# LLM_BINDING_HOST=https://generativelanguage.googleapis.com
 # LLM_BINDING_API_KEY=your_gemini_api_key
-# GEMINI_LLM_MAX_OUTPUT_TOKENS=8192
+# LLM_BINDING_HOST=https://generativelanguage.googleapis.com
+GEMINI_LLM_THINKING_CONFIG='{"thinking_budget": 0, "include_thoughts": false}'
+# GEMINI_LLM_MAX_OUTPUT_TOKENS=9000
 # GEMINI_LLM_TEMPERATURE=0.7
 
 ### OpenAI Compatible API Specific Parameters
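Reviewer note (not part of the patch): GEMINI_LLM_THINKING_CONFIG is a JSON value, and the sketch below shows how such a value would typically map onto the google-genai thinking options, assuming the binding forwards it as the thinking_config entry of the generation config. The helper name parse_thinking_config is illustrative, not something this patch defines.

```python
# Hedged sketch: how a JSON env value such as
#   GEMINI_LLM_THINKING_CONFIG='{"thinking_budget": 0, "include_thoughts": false}'
# would typically be turned into google-genai thinking options.
# Assumes the google-genai SDK's types.ThinkingConfig / types.GenerateContentConfig.
import json
import os

from google.genai import types


def parse_thinking_config() -> types.ThinkingConfig | None:
    """Illustrative helper: parse GEMINI_LLM_THINKING_CONFIG into a ThinkingConfig."""
    raw = os.getenv("GEMINI_LLM_THINKING_CONFIG")
    if not raw:
        return None
    cfg = json.loads(raw)  # e.g. {"thinking_budget": 0, "include_thoughts": False}
    return types.ThinkingConfig(
        thinking_budget=cfg.get("thinking_budget"),
        include_thoughts=cfg.get("include_thoughts"),
    )


# thinking_budget=0 disables thinking on models that allow it; include_thoughts=True
# asks the API to return thought parts, which the binding can surface as <think> blocks.
config = types.GenerateContentConfig(thinking_config=parse_thinking_config())
```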
diff --git a/examples/lightrag_gemini_demo.py b/examples/lightrag_gemini_demo.py
deleted file mode 100644
index cd2bb579..00000000
--- a/examples/lightrag_gemini_demo.py
+++ /dev/null
@@ -1,105 +0,0 @@
-# pip install -q -U google-genai to use gemini as a client
-
-import os
-import numpy as np
-from google import genai
-from google.genai import types
-from dotenv import load_dotenv
-from lightrag.utils import EmbeddingFunc
-from lightrag import LightRAG, QueryParam
-from sentence_transformers import SentenceTransformer
-from lightrag.kg.shared_storage import initialize_pipeline_status
-
-import asyncio
-import nest_asyncio
-
-# Apply nest_asyncio to solve event loop issues
-nest_asyncio.apply()
-
-load_dotenv()
-gemini_api_key = os.getenv("GEMINI_API_KEY")
-
-WORKING_DIR = "./dickens"
-
-if os.path.exists(WORKING_DIR):
-    import shutil
-
-    shutil.rmtree(WORKING_DIR)
-
-os.mkdir(WORKING_DIR)
-
-
-async def llm_model_func(
-    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
-) -> str:
-    # 1. Initialize the GenAI Client with your Gemini API Key
-    client = genai.Client(api_key=gemini_api_key)
-
-    # 2. Combine prompts: system prompt, history, and user prompt
-    if history_messages is None:
-        history_messages = []
-
-    combined_prompt = ""
-    if system_prompt:
-        combined_prompt += f"{system_prompt}\n"
-
-    for msg in history_messages:
-        # Each msg is expected to be a dict: {"role": "...", "content": "..."}
-        combined_prompt += f"{msg['role']}: {msg['content']}\n"
-
-    # Finally, add the new user prompt
-    combined_prompt += f"user: {prompt}"
-
-    # 3. Call the Gemini model
-    response = client.models.generate_content(
-        model="gemini-1.5-flash",
-        contents=[combined_prompt],
-        config=types.GenerateContentConfig(max_output_tokens=500, temperature=0.1),
-    )
-
-    # 4. Return the response text
-    return response.text
-
-
-async def embedding_func(texts: list[str]) -> np.ndarray:
-    model = SentenceTransformer("all-MiniLM-L6-v2")
-    embeddings = model.encode(texts, convert_to_numpy=True)
-    return embeddings
-
-
-async def initialize_rag():
-    rag = LightRAG(
-        working_dir=WORKING_DIR,
-        llm_model_func=llm_model_func,
-        embedding_func=EmbeddingFunc(
-            embedding_dim=384,
-            max_token_size=8192,
-            func=embedding_func,
-        ),
-    )
-
-    await rag.initialize_storages()
-    await initialize_pipeline_status()
-
-    return rag
-
-
-def main():
-    # Initialize RAG instance
-    rag = asyncio.run(initialize_rag())
-    file_path = "story.txt"
-    with open(file_path, "r") as file:
-        text = file.read()
-
-    rag.insert(text)
-
-    response = rag.query(
-        query="What is the main theme of the story?",
-        param=QueryParam(mode="hybrid", top_k=5, response_type="single line"),
-    )
-
-    print(response)
-
-
-if __name__ == "__main__":
-    main()
diff --git a/examples/lightrag_gemini_demo_no_tiktoken.py b/examples/lightrag_gemini_demo_no_tiktoken.py
deleted file mode 100644
index 92c74201..00000000
--- a/examples/lightrag_gemini_demo_no_tiktoken.py
+++ /dev/null
@@ -1,230 +0,0 @@
-# pip install -q -U google-genai to use gemini as a client
-
-import os
-from typing import Optional
-import dataclasses
-from pathlib import Path
-import hashlib
-import numpy as np
-from google import genai
-from google.genai import types
-from dotenv import load_dotenv
-from lightrag.utils import EmbeddingFunc, Tokenizer
-from lightrag import LightRAG, QueryParam
-from sentence_transformers import SentenceTransformer
-from lightrag.kg.shared_storage import initialize_pipeline_status
-import sentencepiece as spm
-import requests
-
-import asyncio
-import nest_asyncio
-
-# Apply nest_asyncio to solve event loop issues
-nest_asyncio.apply()
-
-load_dotenv()
-gemini_api_key = os.getenv("GEMINI_API_KEY")
-
-WORKING_DIR = "./dickens"
-
-if os.path.exists(WORKING_DIR):
-    import shutil
-
-    shutil.rmtree(WORKING_DIR)
-
-os.mkdir(WORKING_DIR)
-
-
-class GemmaTokenizer(Tokenizer):
-    # adapted from google-cloud-aiplatform[tokenization]
-
-    @dataclasses.dataclass(frozen=True)
-    class _TokenizerConfig:
-        tokenizer_model_url: str
-        tokenizer_model_hash: str
-
-    _TOKENIZERS = {
-        "google/gemma2": _TokenizerConfig(
-            tokenizer_model_url="https://raw.githubusercontent.com/google/gemma_pytorch/33b652c465537c6158f9a472ea5700e5e770ad3f/tokenizer/tokenizer.model",
-            tokenizer_model_hash="61a7b147390c64585d6c3543dd6fc636906c9af3865a5548f27f31aee1d4c8e2",
-        ),
-        "google/gemma3": _TokenizerConfig(
-            tokenizer_model_url="https://raw.githubusercontent.com/google/gemma_pytorch/cb7c0152a369e43908e769eb09e1ce6043afe084/tokenizer/gemma3_cleaned_262144_v2.spiece.model",
-            tokenizer_model_hash="1299c11d7cf632ef3b4e11937501358ada021bbdf7c47638d13c0ee982f2e79c",
-        ),
-    }
-
-    def __init__(
-        self, model_name: str = "gemini-2.0-flash", tokenizer_dir: Optional[str] = None
-    ):
-        # https://github.com/google/gemma_pytorch/tree/main/tokenizer
-        if "1.5" in model_name or "1.0" in model_name:
-            # up to gemini 1.5 gemma2 is a comparable local tokenizer
-            # https://github.com/googleapis/python-aiplatform/blob/main/vertexai/tokenization/_tokenizer_loading.py
-            tokenizer_name = "google/gemma2"
-        else:
-            # for gemini > 2.0 gemma3 was used
-            tokenizer_name = "google/gemma3"
-
-        file_url = self._TOKENIZERS[tokenizer_name].tokenizer_model_url
-        tokenizer_model_name = file_url.rsplit("/", 1)[1]
-        expected_hash = self._TOKENIZERS[tokenizer_name].tokenizer_model_hash
-
-        tokenizer_dir = Path(tokenizer_dir)
-        if tokenizer_dir.is_dir():
-            file_path = tokenizer_dir / tokenizer_model_name
-            model_data = self._maybe_load_from_cache(
-                file_path=file_path, expected_hash=expected_hash
-            )
-        else:
-            model_data = None
-        if not model_data:
-            model_data = self._load_from_url(
-                file_url=file_url, expected_hash=expected_hash
-            )
-            self.save_tokenizer_to_cache(cache_path=file_path, model_data=model_data)
-
-        tokenizer = spm.SentencePieceProcessor()
-        tokenizer.LoadFromSerializedProto(model_data)
-        super().__init__(model_name=model_name, tokenizer=tokenizer)
-
-    def _is_valid_model(self, model_data: bytes, expected_hash: str) -> bool:
-        """Returns true if the content is valid by checking the hash."""
-        return hashlib.sha256(model_data).hexdigest() == expected_hash
-
-    def _maybe_load_from_cache(self, file_path: Path, expected_hash: str) -> bytes:
-        """Loads the model data from the cache path."""
-        if not file_path.is_file():
-            return
-        with open(file_path, "rb") as f:
-            content = f.read()
-        if self._is_valid_model(model_data=content, expected_hash=expected_hash):
-            return content
-
-        # Cached file corrupted.
-        self._maybe_remove_file(file_path)
-
-    def _load_from_url(self, file_url: str, expected_hash: str) -> bytes:
-        """Loads model bytes from the given file url."""
-        resp = requests.get(file_url)
-        resp.raise_for_status()
-        content = resp.content
-
-        if not self._is_valid_model(model_data=content, expected_hash=expected_hash):
-            actual_hash = hashlib.sha256(content).hexdigest()
-            raise ValueError(
-                f"Downloaded model file is corrupted."
-                f" Expected hash {expected_hash}. Got file hash {actual_hash}."
-            )
-        return content
-
-    @staticmethod
-    def save_tokenizer_to_cache(cache_path: Path, model_data: bytes) -> None:
-        """Saves the model data to the cache path."""
-        try:
-            if not cache_path.is_file():
-                cache_dir = cache_path.parent
-                cache_dir.mkdir(parents=True, exist_ok=True)
-                with open(cache_path, "wb") as f:
-                    f.write(model_data)
-        except OSError:
-            # Don't raise if we cannot write file.
-            pass
-
-    @staticmethod
-    def _maybe_remove_file(file_path: Path) -> None:
-        """Removes the file if exists."""
-        if not file_path.is_file():
-            return
-        try:
-            file_path.unlink()
-        except OSError:
-            # Don't raise if we cannot remove file.
-            pass
-
-    # def encode(self, content: str) -> list[int]:
-    #     return self.tokenizer.encode(content)
-
-    # def decode(self, tokens: list[int]) -> str:
-    #     return self.tokenizer.decode(tokens)
-
-
-async def llm_model_func(
-    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
-) -> str:
-    # 1. Initialize the GenAI Client with your Gemini API Key
-    client = genai.Client(api_key=gemini_api_key)
-
-    # 2. Combine prompts: system prompt, history, and user prompt
-    if history_messages is None:
-        history_messages = []
-
-    combined_prompt = ""
-    if system_prompt:
-        combined_prompt += f"{system_prompt}\n"
-
-    for msg in history_messages:
-        # Each msg is expected to be a dict: {"role": "...", "content": "..."}
-        combined_prompt += f"{msg['role']}: {msg['content']}\n"
-
-    # Finally, add the new user prompt
-    combined_prompt += f"user: {prompt}"
-
-    # 3. Call the Gemini model
-    response = client.models.generate_content(
-        model="gemini-1.5-flash",
-        contents=[combined_prompt],
-        config=types.GenerateContentConfig(max_output_tokens=500, temperature=0.1),
-    )
-
-    # 4. Return the response text
-    return response.text
-
-
-async def embedding_func(texts: list[str]) -> np.ndarray:
-    model = SentenceTransformer("all-MiniLM-L6-v2")
-    embeddings = model.encode(texts, convert_to_numpy=True)
-    return embeddings
-
-
-async def initialize_rag():
-    rag = LightRAG(
-        working_dir=WORKING_DIR,
-        # tiktoken_model_name="gpt-4o-mini",
-        tokenizer=GemmaTokenizer(
-            tokenizer_dir=(Path(WORKING_DIR) / "vertexai_tokenizer_model"),
-            model_name="gemini-2.0-flash",
-        ),
-        llm_model_func=llm_model_func,
-        embedding_func=EmbeddingFunc(
-            embedding_dim=384,
-            max_token_size=8192,
-            func=embedding_func,
-        ),
-    )
-
-    await rag.initialize_storages()
-    await initialize_pipeline_status()
-
-    return rag
-
-
-def main():
-    # Initialize RAG instance
-    rag = asyncio.run(initialize_rag())
-    file_path = "story.txt"
-    with open(file_path, "r") as file:
-        text = file.read()
-
-    rag.insert(text)
-
-    response = rag.query(
-        query="What is the main theme of the story?",
-        param=QueryParam(mode="hybrid", top_k=5, response_type="single line"),
-    )
-
-    print(response)
-
-
-if __name__ == "__main__":
-    main()
diff --git a/examples/lightrag_gemini_track_token_demo.py b/examples/lightrag_gemini_track_token_demo.py
deleted file mode 100644
index a72fc717..00000000
--- a/examples/lightrag_gemini_track_token_demo.py
+++ /dev/null
@@ -1,151 +0,0 @@
-# pip install -q -U google-genai to use gemini as a client
-
-import os
-import asyncio
-import numpy as np
-import nest_asyncio
-from google import genai
-from google.genai import types
-from dotenv import load_dotenv
-from lightrag.utils import EmbeddingFunc
-from lightrag import LightRAG, QueryParam
-from lightrag.kg.shared_storage import initialize_pipeline_status
-from lightrag.llm.siliconcloud import siliconcloud_embedding
-from lightrag.utils import setup_logger
-from lightrag.utils import TokenTracker
-
-setup_logger("lightrag", level="DEBUG")
-
-# Apply nest_asyncio to solve event loop issues
-nest_asyncio.apply()
-
-load_dotenv()
-gemini_api_key = os.getenv("GEMINI_API_KEY")
-siliconflow_api_key = os.getenv("SILICONFLOW_API_KEY")
-
-WORKING_DIR = "./dickens"
-
-if not os.path.exists(WORKING_DIR):
-    os.mkdir(WORKING_DIR)
-
-token_tracker = TokenTracker()
-
-
-async def llm_model_func(
-    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
-) -> str:
-    # 1. Initialize the GenAI Client with your Gemini API Key
-    client = genai.Client(api_key=gemini_api_key)
-
-    # 2. Combine prompts: system prompt, history, and user prompt
-    if history_messages is None:
-        history_messages = []
-
-    combined_prompt = ""
-    if system_prompt:
-        combined_prompt += f"{system_prompt}\n"
-
-    for msg in history_messages:
-        # Each msg is expected to be a dict: {"role": "...", "content": "..."}
-        combined_prompt += f"{msg['role']}: {msg['content']}\n"
-
-    # Finally, add the new user prompt
-    combined_prompt += f"user: {prompt}"
-
-    # 3. Call the Gemini model
-    response = client.models.generate_content(
-        model="gemini-2.0-flash",
-        contents=[combined_prompt],
-        config=types.GenerateContentConfig(
-            max_output_tokens=5000, temperature=0, top_k=10
-        ),
-    )
-
-    # 4. Get token counts with null safety
-    usage = getattr(response, "usage_metadata", None)
-    prompt_tokens = getattr(usage, "prompt_token_count", 0) or 0
-    completion_tokens = getattr(usage, "candidates_token_count", 0) or 0
-    total_tokens = getattr(usage, "total_token_count", 0) or (
-        prompt_tokens + completion_tokens
-    )
-
-    token_counts = {
-        "prompt_tokens": prompt_tokens,
-        "completion_tokens": completion_tokens,
-        "total_tokens": total_tokens,
-    }
-
-    token_tracker.add_usage(token_counts)
-
-    # 5. Return the response text
-    return response.text
-
-
-async def embedding_func(texts: list[str]) -> np.ndarray:
-    return await siliconcloud_embedding(
-        texts,
-        model="BAAI/bge-m3",
-        api_key=siliconflow_api_key,
-        max_token_size=512,
-    )
-
-
-async def initialize_rag():
-    rag = LightRAG(
-        working_dir=WORKING_DIR,
-        entity_extract_max_gleaning=1,
-        enable_llm_cache=True,
-        enable_llm_cache_for_entity_extract=True,
-        embedding_cache_config={"enabled": True, "similarity_threshold": 0.90},
-        llm_model_func=llm_model_func,
-        embedding_func=EmbeddingFunc(
-            embedding_dim=1024,
-            max_token_size=8192,
-            func=embedding_func,
-        ),
-    )
-
-    await rag.initialize_storages()
-    await initialize_pipeline_status()
-
-    return rag
-
-
-def main():
-    # Initialize RAG instance
-    rag = asyncio.run(initialize_rag())
-
-    with open("./book.txt", "r", encoding="utf-8") as f:
-        rag.insert(f.read())
-
-    # Context Manager Method
-    with token_tracker:
-        print(
-            rag.query(
-                "What are the top themes in this story?", param=QueryParam(mode="naive")
-            )
-        )
-
-        print(
-            rag.query(
-                "What are the top themes in this story?", param=QueryParam(mode="local")
-            )
-        )
-
-        print(
-            rag.query(
-                "What are the top themes in this story?",
-                param=QueryParam(mode="global"),
-            )
-        )
-
-        print(
-            rag.query(
-                "What are the top themes in this story?",
-                param=QueryParam(mode="hybrid"),
-            )
-        )
-
-
-if __name__ == "__main__":
-    main()
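Reviewer note: the three deleted standalone demos appear to be superseded by the built-in Gemini binding touched below (lightrag/llm/gemini.py plus the gemini LLM_BINDING wiring in the API server). A minimal sketch of calling that binding directly, based on the gemini_complete_if_cache signature introduced in this patch; the model name, generation_config keys, and thinking values are illustrative assumptions, and the API key is assumed to be configured via the environment.

```python
# Hedged sketch, assuming the gemini_complete_if_cache signature shown in this patch.
import asyncio

from lightrag.llm.gemini import gemini_complete_if_cache


async def main() -> None:
    answer = await gemini_complete_if_cache(
        model="gemini-flash-latest",  # illustrative model name (see env.example above)
        prompt="What is the main theme of the story?",
        system_prompt="You are a concise literary assistant.",
        enable_cot=True,  # surface thought parts wrapped in <think>...</think>
        generation_config={
            # Keys mirror the GeminiLLMOptions fields; treat as illustrative.
            "max_output_tokens": 9000,
            "thinking_config": {"thinking_budget": 1024, "include_thoughts": True},
        },
        timeout=180,  # seconds; converted to milliseconds internally
    )
    print(answer)


asyncio.run(main())
```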
diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py
index a66d5d3c..d8bea386 100644
--- a/lightrag/api/lightrag_server.py
+++ b/lightrag/api/lightrag_server.py
@@ -511,7 +511,9 @@ def create_app(args):
 
         return optimized_azure_openai_model_complete
 
-    def create_optimized_gemini_llm_func(config_cache: LLMConfigCache, args):
+    def create_optimized_gemini_llm_func(
+        config_cache: LLMConfigCache, args, llm_timeout: int
+    ):
         """Create optimized Gemini LLM function with cached configuration"""
 
         async def optimized_gemini_model_complete(
@@ -526,6 +528,8 @@ def create_app(args):
             if history_messages is None:
                 history_messages = []
 
+            # Use pre-processed configuration to avoid repeated parsing
+            kwargs["timeout"] = llm_timeout
             if (
                 config_cache.gemini_llm_options is not None
                 and "generation_config" not in kwargs
@@ -567,7 +571,7 @@ def create_app(args):
                 config_cache, args, llm_timeout
             )
         elif binding == "gemini":
-            return create_optimized_gemini_llm_func(config_cache, args)
+            return create_optimized_gemini_llm_func(config_cache, args, llm_timeout)
        else:  # openai and compatible
            # Use optimized function with pre-processed configuration
            return create_optimized_openai_llm_func(config_cache, args, llm_timeout)
diff --git a/lightrag/llm/binding_options.py b/lightrag/llm/binding_options.py
index 44ab5d2f..b3affb37 100644
--- a/lightrag/llm/binding_options.py
+++ b/lightrag/llm/binding_options.py
@@ -486,9 +486,9 @@ class GeminiLLMOptions(BindingOptions):
     presence_penalty: float = 0.0
     frequency_penalty: float = 0.0
     stop_sequences: List[str] = field(default_factory=list)
-    response_mime_type: str | None = None
+    seed: int | None = None
+    thinking_config: dict | None = None
     safety_settings: dict | None = None
-    system_instruction: str | None = None
 
     _help: ClassVar[dict[str, str]] = {
         "temperature": "Controls randomness (0.0-2.0, higher = more creative)",
@@ -499,9 +499,9 @@
         "presence_penalty": "Penalty for token presence (-2.0 to 2.0)",
         "frequency_penalty": "Penalty for token frequency (-2.0 to 2.0)",
         "stop_sequences": "Stop sequences (JSON array of strings, e.g., '[\"END\"]')",
-        "response_mime_type": "Desired MIME type for the response (e.g., application/json)",
+        "seed": "Random seed for reproducible generation (leave empty for random)",
+        "thinking_config": "Thinking configuration (JSON dict, e.g., '{\"thinking_budget\": 1024}' or '{\"include_thoughts\": true}')",
         "safety_settings": "JSON object with Gemini safety settings overrides",
-        "system_instruction": "Default system instruction applied to every request",
     }
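Reviewer note: the new seed and thinking_config fields follow the GEMINI_LLM_* environment convention seen in env.example (GEMINI_LLM_SEED is an inferred variable name; only GEMINI_LLM_THINKING_CONFIG appears there). A hedged sketch of the value shapes documented by the help strings above; the real parsing is done by the BindingOptions machinery, which is outside this hunk.

```python
# Hedged sketch: JSON-valued binding options as documented in the _help strings above.
# This only illustrates the expected shapes; GEMINI_LLM_SEED is an inferred name.
import json
import os

os.environ["GEMINI_LLM_SEED"] = "42"
os.environ["GEMINI_LLM_THINKING_CONFIG"] = '{"thinking_budget": 1024, "include_thoughts": true}'

seed = int(os.environ["GEMINI_LLM_SEED"])
thinking_config = json.loads(os.environ["GEMINI_LLM_THINKING_CONFIG"])

assert thinking_config == {"thinking_budget": 1024, "include_thoughts": True}
```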
diff --git a/lightrag/llm/gemini.py b/lightrag/llm/gemini.py
index b8c64b31..f06ec6b3 100644
--- a/lightrag/llm/gemini.py
+++ b/lightrag/llm/gemini.py
@@ -33,24 +33,33 @@ LOG = logging.getLogger(__name__)
 
 
 @lru_cache(maxsize=8)
-def _get_gemini_client(api_key: str, base_url: str | None) -> genai.Client:
+def _get_gemini_client(
+    api_key: str, base_url: str | None, timeout: int | None = None
+) -> genai.Client:
     """
     Create (or fetch cached) Gemini client.
 
     Args:
         api_key: Google Gemini API key.
         base_url: Optional custom API endpoint.
+        timeout: Optional request timeout in milliseconds.
 
     Returns:
         genai.Client: Configured Gemini client instance.
     """
     client_kwargs: dict[str, Any] = {"api_key": api_key}
 
-    if base_url and base_url != DEFAULT_GEMINI_ENDPOINT:
+    if base_url and base_url != DEFAULT_GEMINI_ENDPOINT or timeout is not None:
         try:
-            client_kwargs["http_options"] = types.HttpOptions(api_endpoint=base_url)
+            http_options_kwargs = {}
+            if base_url and base_url != DEFAULT_GEMINI_ENDPOINT:
+                http_options_kwargs["api_endpoint"] = base_url
+            if timeout is not None:
+                http_options_kwargs["timeout"] = timeout
+
+            client_kwargs["http_options"] = types.HttpOptions(**http_options_kwargs)
         except Exception as exc:  # pragma: no cover - defensive
-            LOG.warning("Failed to apply custom Gemini endpoint %s: %s", base_url, exc)
+            LOG.warning("Failed to apply custom Gemini http_options: %s", exc)
 
     try:
         return genai.Client(**client_kwargs)
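Reviewer note: because _get_gemini_client is wrapped in lru_cache, the cache key is now the (api_key, base_url, timeout) triple, and the timeout it receives is already in milliseconds. A small illustrative sketch; the function is private to lightrag.llm.gemini, so calling it directly is for demonstration only.

```python
# Hedged sketch of the caching/timeout behaviour of _get_gemini_client above.
from lightrag.llm.gemini import _get_gemini_client

timeout_seconds = 120
timeout_ms = timeout_seconds * 1000  # HttpOptions expects milliseconds

a = _get_gemini_client("my-api-key", None, timeout_ms)
b = _get_gemini_client("my-api-key", None, timeout_ms)
assert a is b  # same lru_cache entry: identical (api_key, base_url, timeout)

c = _get_gemini_client("my-api-key", None, 60_000)
assert c is not a  # a different timeout creates (and caches) a separate client
```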
@@ -114,24 +123,44 @@ def _format_history_messages(history_messages: list[dict[str, Any]] | None) -> s
     return "\n".join(history_lines)
 
 
-def _extract_response_text(response: Any) -> str:
-    if getattr(response, "text", None):
-        return response.text
+def _extract_response_text(
+    response: Any, extract_thoughts: bool = False
+) -> tuple[str, str]:
+    """
+    Extract text content from Gemini response, separating regular content from thoughts.
+
+    Args:
+        response: Gemini API response object
+        extract_thoughts: Whether to extract thought content separately
+
+    Returns:
+        Tuple of (regular_text, thought_text)
+    """
     candidates = getattr(response, "candidates", None)
     if not candidates:
-        return ""
+        return ("", "")
+
+    regular_parts: list[str] = []
+    thought_parts: list[str] = []
 
-    parts: list[str] = []
     for candidate in candidates:
         if not getattr(candidate, "content", None):
             continue
-        for part in getattr(candidate.content, "parts", []):
+        # Use 'or []' to handle None values from parts attribute
+        for part in getattr(candidate.content, "parts", None) or []:
             text = getattr(part, "text", None)
-            if text:
-                parts.append(text)
+            if not text:
+                continue
 
-    return "\n".join(parts)
+            # Check if this part is thought content using the 'thought' attribute
+            is_thought = getattr(part, "thought", False)
+
+            if is_thought and extract_thoughts:
+                thought_parts.append(text)
+            elif not is_thought:
+                regular_parts.append(text)
+
+    return ("\n".join(regular_parts), "\n".join(thought_parts))
 
 
 async def gemini_complete_if_cache(
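Reviewer note: a tiny illustration of the (regular_text, thought_text) contract above, using stand-in objects that only mimic the attributes the helper reads (real responses come from google-genai).

```python
# Illustration of _extract_response_text's (regular_text, thought_text) contract.
from types import SimpleNamespace

from lightrag.llm.gemini import _extract_response_text

response = SimpleNamespace(
    candidates=[
        SimpleNamespace(
            content=SimpleNamespace(
                parts=[
                    SimpleNamespace(text="Let me reason about this...", thought=True),
                    SimpleNamespace(text="The final answer is 42.", thought=False),
                ]
            )
        )
    ]
)

print(_extract_response_text(response, extract_thoughts=True))
# -> ("The final answer is 42.", "Let me reason about this...")
print(_extract_response_text(response))  # thoughts dropped when extract_thoughts=False
# -> ("The final answer is 42.", "")
```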
@@ -139,22 +168,58 @@
     model: str,
     prompt: str,
     system_prompt: str | None = None,
     history_messages: list[dict[str, Any]] | None = None,
-    *,
-    api_key: str | None = None,
+    enable_cot: bool = False,
     base_url: str | None = None,
-    generation_config: dict[str, Any] | None = None,
-    keyword_extraction: bool = False,
+    api_key: str | None = None,
     token_tracker: Any | None = None,
-    hashing_kv: Any | None = None,  # noqa: ARG001 - present for interface parity
     stream: bool | None = None,
-    enable_cot: bool = False,  # noqa: ARG001 - not supported by Gemini currently
-    timeout: float | None = None,  # noqa: ARG001 - handled by caller if needed
+    keyword_extraction: bool = False,
+    generation_config: dict[str, Any] | None = None,
+    timeout: int | None = None,
     **_: Any,
 ) -> str | AsyncIterator[str]:
+    """
+    Complete a prompt using Gemini's API with Chain of Thought (COT) support.
+
+    This function supports automatic integration of reasoning content from Gemini models
+    that provide Chain of Thought capabilities via the thinking_config API feature.
+
+    COT Integration:
+    - When enable_cot=True: Thought content is wrapped in <think>...</think> tags
+    - When enable_cot=False: Thought content is filtered out, only regular content returned
+    - Thought content is identified by the 'thought' attribute on response parts
+    - Requires thinking_config to be enabled in generation_config for API to return thoughts
+
+    Args:
+        model: The Gemini model to use.
+        prompt: The prompt to complete.
+        system_prompt: Optional system prompt to include.
+        history_messages: Optional list of previous messages in the conversation.
+        api_key: Optional Gemini API key. If None, uses environment variable.
+        base_url: Optional custom API endpoint.
+        generation_config: Optional generation configuration dict.
+        keyword_extraction: Whether to use JSON response format.
+        token_tracker: Optional token usage tracker for monitoring API usage.
+        stream: Whether to stream the response.
+        hashing_kv: Storage interface (for interface parity with other bindings).
+        enable_cot: Whether to include Chain of Thought content in the response.
+        timeout: Request timeout in seconds (will be converted to milliseconds for Gemini API).
+        **_: Additional keyword arguments (ignored).
+
+    Returns:
+        The completed text (with COT content if enable_cot=True) or an async iterator
+        of text chunks if streaming. COT content is wrapped in <think>...</think> tags.
+
+    Raises:
+        RuntimeError: If the response from Gemini is empty.
+        ValueError: If API key is not provided or configured.
+    """
     loop = asyncio.get_running_loop()
     key = _ensure_api_key(api_key)
-    client = _get_gemini_client(key, base_url)
+    # Convert timeout from seconds to milliseconds for Gemini API
+    timeout_ms = timeout * 1000 if timeout else None
+    client = _get_gemini_client(key, base_url, timeout_ms)
 
     history_block = _format_history_messages(history_messages)
 
     prompt_sections = []
@@ -184,6 +249,11 @@
     usage_container: dict[str, Any] = {}
 
     def _stream_model() -> None:
+        # COT state tracking for streaming
+        cot_active = False
+        cot_started = False
+        initial_content_seen = False
+
         try:
             stream_kwargs = dict(request_kwargs)
             stream_iterator = client.models.generate_content_stream(**stream_kwargs)
@@ -191,20 +261,61 @@
                 usage = getattr(chunk, "usage_metadata", None)
                 if usage is not None:
                     usage_container["usage"] = usage
-                text_piece = getattr(chunk, "text", None) or _extract_response_text(
-                    chunk
-                )
-                if text_piece:
-                    loop.call_soon_threadsafe(queue.put_nowait, text_piece)
+
+                # Extract both regular and thought content
+                regular_text, thought_text = _extract_response_text(
+                    chunk, extract_thoughts=True
+                )
+
+                if enable_cot:
+                    # Process regular content
+                    if regular_text:
+                        if not initial_content_seen:
+                            initial_content_seen = True
+
+                        # Close COT section if it was active
+                        if cot_active:
+                            loop.call_soon_threadsafe(queue.put_nowait, "</think>")
+                            cot_active = False
+
+                        # Send regular content
+                        loop.call_soon_threadsafe(queue.put_nowait, regular_text)
+
+                    # Process thought content
+                    if thought_text:
+                        if not initial_content_seen and not cot_started:
+                            # Start COT section
+                            loop.call_soon_threadsafe(queue.put_nowait, "<think>")
+                            cot_active = True
+                            cot_started = True
+
+                        # Send thought content if COT is active
+                        if cot_active:
+                            loop.call_soon_threadsafe(
+                                queue.put_nowait, thought_text
+                            )
+                else:
+                    # COT disabled - only send regular content
+                    if regular_text:
+                        loop.call_soon_threadsafe(queue.put_nowait, regular_text)
+
+            # Ensure COT is properly closed if still active
+            if cot_active:
+                loop.call_soon_threadsafe(queue.put_nowait, "</think>")
+
             loop.call_soon_threadsafe(queue.put_nowait, None)
         except Exception as exc:  # pragma: no cover - surface runtime issues
+            # Try to close COT tag before reporting error
+            if cot_active:
+                try:
+                    loop.call_soon_threadsafe(queue.put_nowait, "</think>")
+                except Exception:
+                    pass
             loop.call_soon_threadsafe(queue.put_nowait, exc)
 
     loop.run_in_executor(None, _stream_model)
 
     async def _async_stream() -> AsyncIterator[str]:
-        accumulated = ""
-        emitted = ""
         try:
             while True:
                 item = await queue.get()
@@ -217,16 +328,9 @@
                 if "\\u" in chunk_text:
                     chunk_text = safe_unicode_decode(chunk_text.encode("utf-8"))
 
-                accumulated += chunk_text
-                sanitized = remove_think_tags(accumulated)
-                if sanitized.startswith(emitted):
-                    delta = sanitized[len(emitted) :]
-                else:
-                    delta = sanitized
-                emitted = sanitized
-
-                if delta:
-                    yield delta
+                # Yield the chunk directly without filtering
+                # COT filtering is already handled in _stream_model()
+                yield chunk_text
         finally:
             usage = usage_container.get("usage")
             if token_tracker and usage:
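Reviewer note: with this change the stream emits pre-framed chunks instead of re-filtering accumulated text. When enable_cot is true, the sequence is a literal "<think>" chunk, the thought chunks, a "</think>" chunk once regular content starts (or when the stream ends), then the answer chunks. A hedged consumption sketch; the model name and thinking settings are illustrative, and remove_think_tags is assumed to live in lightrag.utils as used elsewhere in this module.

```python
# Hedged sketch of consuming the streaming protocol described above.
# With enable_cot=True the chunk sequence looks like:
#   "<think>", <thought chunks...>, "</think>", <answer chunks...>
# Downstream code that does not want the reasoning can strip the tagged block,
# e.g. with remove_think_tags on the accumulated text.
import asyncio

from lightrag.llm.gemini import gemini_complete_if_cache


async def stream_demo() -> None:
    chunks = await gemini_complete_if_cache(
        model="gemini-flash-latest",  # illustrative model name
        prompt="Summarize the story in one line.",
        enable_cot=True,
        stream=True,
        generation_config={"thinking_config": {"include_thoughts": True}},
    )
    async for piece in chunks:  # pieces already carry <think>...</think> framing
        print(piece, end="", flush=True)


asyncio.run(stream_demo())
```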
@@ -244,14 +348,33 @@
     response = await asyncio.to_thread(_call_model)
 
-    text = _extract_response_text(response)
-    if not text:
+    # Extract both regular text and thought text
+    regular_text, thought_text = _extract_response_text(response, extract_thoughts=True)
+
+    # Apply COT filtering logic based on enable_cot parameter
+    if enable_cot:
+        # Include thought content wrapped in <think> tags
+        if thought_text and thought_text.strip():
+            if not regular_text or regular_text.strip() == "":
+                # Only thought content available
+                final_text = f"<think>{thought_text}</think>"
+            else:
+                # Both content types present: prepend thought to regular content
+                final_text = f"<think>{thought_text}</think>{regular_text}"
+        else:
+            # No thought content, use regular content only
+            final_text = regular_text or ""
+    else:
+        # Filter out thought content, return only regular content
+        final_text = regular_text or ""
+
+    if not final_text:
         raise RuntimeError("Gemini response did not contain any text content.")
 
-    if "\\u" in text:
-        text = safe_unicode_decode(text.encode("utf-8"))
+    if "\\u" in final_text:
+        final_text = safe_unicode_decode(final_text.encode("utf-8"))
 
-    text = remove_think_tags(text)
+    final_text = remove_think_tags(final_text)
 
     usage = getattr(response, "usage_metadata", None)
     if token_tracker and usage:
         token_tracker.add_usage(
@@ -263,8 +386,8 @@
         }
     )
 
-    logger.debug("Gemini response length: %s", len(text))
-    return text
+    logger.debug("Gemini response length: %s", len(final_text))
+    return final_text
 
 
 async def gemini_model_complete(
diff --git a/lightrag/llm/openai.py b/lightrag/llm/openai.py
index fce33cac..3339ea3a 100644
--- a/lightrag/llm/openai.py
+++ b/lightrag/llm/openai.py
@@ -138,6 +138,9 @@ async def openai_complete_if_cache(
     base_url: str | None = None,
     api_key: str | None = None,
     token_tracker: Any | None = None,
+    keyword_extraction: bool = False,  # Will be removed from kwargs before passing to OpenAI
+    stream: bool | None = None,
+    timeout: int | None = None,
     **kwargs: Any,
 ) -> str:
     """Complete a prompt using OpenAI's API with caching support and Chain of Thought (COT) integration.
@@ -172,8 +175,9 @@ async def openai_complete_if_cache(
             - openai_client_configs: Dict of configuration options for the AsyncOpenAI client.
               These will be passed to the client constructor but will be overridden by
               explicit parameters (api_key, base_url).
-            - hashing_kv: Will be removed from kwargs before passing to OpenAI.
             - keyword_extraction: Will be removed from kwargs before passing to OpenAI.
+            - stream: Whether to stream the response. Default is False.
+            - timeout: Request timeout in seconds. Default is None.
 
     Returns:
         The completed text (with integrated COT content if available) or an async iterator