diff --git a/env.example b/env.example
index 521021b8..84c1bbe7 100644
--- a/env.example
+++ b/env.example
@@ -194,9 +194,10 @@ LLM_BINDING_API_KEY=your_api_key
### Gemini example
# LLM_BINDING=gemini
# LLM_MODEL=gemini-flash-latest
-# LLM_BINDING_HOST=https://generativelanguage.googleapis.com
# LLM_BINDING_API_KEY=your_gemini_api_key
-# GEMINI_LLM_MAX_OUTPUT_TOKENS=8192
+# LLM_BINDING_HOST=https://generativelanguage.googleapis.com
+GEMINI_LLM_THINKING_CONFIG='{"thinking_budget": 0, "include_thoughts": false}'
+# GEMINI_LLM_MAX_OUTPUT_TOKENS=9000
# GEMINI_LLM_TEMPERATURE=0.7
### OpenAI Compatible API Specific Parameters
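For reference, a minimal sketch (not part of this patch) of how the `GEMINI_LLM_THINKING_CONFIG` JSON above could map onto the google-genai thinking options; `types.ThinkingConfig` and `types.GenerateContentConfig` come from the `google-genai` package, and the exact wiring inside the Gemini binding may differ:

```python
import json

from google.genai import types

# Value from the env entry above; a thinking_budget of 0 disables thinking output.
raw = '{"thinking_budget": 0, "include_thoughts": false}'
cfg = json.loads(raw)

# Forward the parsed options into a generation config, mirroring the other GEMINI_LLM_* knobs.
generation_config = types.GenerateContentConfig(
    thinking_config=types.ThinkingConfig(
        thinking_budget=cfg.get("thinking_budget"),
        include_thoughts=cfg.get("include_thoughts"),
    ),
    max_output_tokens=9000,
    temperature=0.7,
)
```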
diff --git a/examples/lightrag_gemini_demo.py b/examples/lightrag_gemini_demo.py
deleted file mode 100644
index cd2bb579..00000000
--- a/examples/lightrag_gemini_demo.py
+++ /dev/null
@@ -1,105 +0,0 @@
-# pip install -q -U google-genai to use gemini as a client
-
-import os
-import numpy as np
-from google import genai
-from google.genai import types
-from dotenv import load_dotenv
-from lightrag.utils import EmbeddingFunc
-from lightrag import LightRAG, QueryParam
-from sentence_transformers import SentenceTransformer
-from lightrag.kg.shared_storage import initialize_pipeline_status
-
-import asyncio
-import nest_asyncio
-
-# Apply nest_asyncio to solve event loop issues
-nest_asyncio.apply()
-
-load_dotenv()
-gemini_api_key = os.getenv("GEMINI_API_KEY")
-
-WORKING_DIR = "./dickens"
-
-if os.path.exists(WORKING_DIR):
- import shutil
-
- shutil.rmtree(WORKING_DIR)
-
-os.mkdir(WORKING_DIR)
-
-
-async def llm_model_func(
- prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
-) -> str:
- # 1. Initialize the GenAI Client with your Gemini API Key
- client = genai.Client(api_key=gemini_api_key)
-
- # 2. Combine prompts: system prompt, history, and user prompt
- if history_messages is None:
- history_messages = []
-
- combined_prompt = ""
- if system_prompt:
- combined_prompt += f"{system_prompt}\n"
-
- for msg in history_messages:
- # Each msg is expected to be a dict: {"role": "...", "content": "..."}
- combined_prompt += f"{msg['role']}: {msg['content']}\n"
-
- # Finally, add the new user prompt
- combined_prompt += f"user: {prompt}"
-
- # 3. Call the Gemini model
- response = client.models.generate_content(
- model="gemini-1.5-flash",
- contents=[combined_prompt],
- config=types.GenerateContentConfig(max_output_tokens=500, temperature=0.1),
- )
-
- # 4. Return the response text
- return response.text
-
-
-async def embedding_func(texts: list[str]) -> np.ndarray:
- model = SentenceTransformer("all-MiniLM-L6-v2")
- embeddings = model.encode(texts, convert_to_numpy=True)
- return embeddings
-
-
-async def initialize_rag():
- rag = LightRAG(
- working_dir=WORKING_DIR,
- llm_model_func=llm_model_func,
- embedding_func=EmbeddingFunc(
- embedding_dim=384,
- max_token_size=8192,
- func=embedding_func,
- ),
- )
-
- await rag.initialize_storages()
- await initialize_pipeline_status()
-
- return rag
-
-
-def main():
- # Initialize RAG instance
- rag = asyncio.run(initialize_rag())
- file_path = "story.txt"
- with open(file_path, "r") as file:
- text = file.read()
-
- rag.insert(text)
-
- response = rag.query(
- query="What is the main theme of the story?",
- param=QueryParam(mode="hybrid", top_k=5, response_type="single line"),
- )
-
- print(response)
-
-
-if __name__ == "__main__":
- main()
diff --git a/examples/lightrag_gemini_demo_no_tiktoken.py b/examples/lightrag_gemini_demo_no_tiktoken.py
deleted file mode 100644
index 92c74201..00000000
--- a/examples/lightrag_gemini_demo_no_tiktoken.py
+++ /dev/null
@@ -1,230 +0,0 @@
-# pip install -q -U google-genai to use gemini as a client
-
-import os
-from typing import Optional
-import dataclasses
-from pathlib import Path
-import hashlib
-import numpy as np
-from google import genai
-from google.genai import types
-from dotenv import load_dotenv
-from lightrag.utils import EmbeddingFunc, Tokenizer
-from lightrag import LightRAG, QueryParam
-from sentence_transformers import SentenceTransformer
-from lightrag.kg.shared_storage import initialize_pipeline_status
-import sentencepiece as spm
-import requests
-
-import asyncio
-import nest_asyncio
-
-# Apply nest_asyncio to solve event loop issues
-nest_asyncio.apply()
-
-load_dotenv()
-gemini_api_key = os.getenv("GEMINI_API_KEY")
-
-WORKING_DIR = "./dickens"
-
-if os.path.exists(WORKING_DIR):
- import shutil
-
- shutil.rmtree(WORKING_DIR)
-
-os.mkdir(WORKING_DIR)
-
-
-class GemmaTokenizer(Tokenizer):
- # adapted from google-cloud-aiplatform[tokenization]
-
- @dataclasses.dataclass(frozen=True)
- class _TokenizerConfig:
- tokenizer_model_url: str
- tokenizer_model_hash: str
-
- _TOKENIZERS = {
- "google/gemma2": _TokenizerConfig(
- tokenizer_model_url="https://raw.githubusercontent.com/google/gemma_pytorch/33b652c465537c6158f9a472ea5700e5e770ad3f/tokenizer/tokenizer.model",
- tokenizer_model_hash="61a7b147390c64585d6c3543dd6fc636906c9af3865a5548f27f31aee1d4c8e2",
- ),
- "google/gemma3": _TokenizerConfig(
- tokenizer_model_url="https://raw.githubusercontent.com/google/gemma_pytorch/cb7c0152a369e43908e769eb09e1ce6043afe084/tokenizer/gemma3_cleaned_262144_v2.spiece.model",
- tokenizer_model_hash="1299c11d7cf632ef3b4e11937501358ada021bbdf7c47638d13c0ee982f2e79c",
- ),
- }
-
- def __init__(
- self, model_name: str = "gemini-2.0-flash", tokenizer_dir: Optional[str] = None
- ):
- # https://github.com/google/gemma_pytorch/tree/main/tokenizer
- if "1.5" in model_name or "1.0" in model_name:
- # up to gemini 1.5 gemma2 is a comparable local tokenizer
- # https://github.com/googleapis/python-aiplatform/blob/main/vertexai/tokenization/_tokenizer_loading.py
- tokenizer_name = "google/gemma2"
- else:
- # for gemini > 2.0 gemma3 was used
- tokenizer_name = "google/gemma3"
-
- file_url = self._TOKENIZERS[tokenizer_name].tokenizer_model_url
- tokenizer_model_name = file_url.rsplit("/", 1)[1]
- expected_hash = self._TOKENIZERS[tokenizer_name].tokenizer_model_hash
-
- tokenizer_dir = Path(tokenizer_dir)
- if tokenizer_dir.is_dir():
- file_path = tokenizer_dir / tokenizer_model_name
- model_data = self._maybe_load_from_cache(
- file_path=file_path, expected_hash=expected_hash
- )
- else:
- model_data = None
- if not model_data:
- model_data = self._load_from_url(
- file_url=file_url, expected_hash=expected_hash
- )
- self.save_tokenizer_to_cache(cache_path=file_path, model_data=model_data)
-
- tokenizer = spm.SentencePieceProcessor()
- tokenizer.LoadFromSerializedProto(model_data)
- super().__init__(model_name=model_name, tokenizer=tokenizer)
-
- def _is_valid_model(self, model_data: bytes, expected_hash: str) -> bool:
- """Returns true if the content is valid by checking the hash."""
- return hashlib.sha256(model_data).hexdigest() == expected_hash
-
- def _maybe_load_from_cache(self, file_path: Path, expected_hash: str) -> bytes:
- """Loads the model data from the cache path."""
- if not file_path.is_file():
- return
- with open(file_path, "rb") as f:
- content = f.read()
- if self._is_valid_model(model_data=content, expected_hash=expected_hash):
- return content
-
- # Cached file corrupted.
- self._maybe_remove_file(file_path)
-
- def _load_from_url(self, file_url: str, expected_hash: str) -> bytes:
- """Loads model bytes from the given file url."""
- resp = requests.get(file_url)
- resp.raise_for_status()
- content = resp.content
-
- if not self._is_valid_model(model_data=content, expected_hash=expected_hash):
- actual_hash = hashlib.sha256(content).hexdigest()
- raise ValueError(
- f"Downloaded model file is corrupted."
- f" Expected hash {expected_hash}. Got file hash {actual_hash}."
- )
- return content
-
- @staticmethod
- def save_tokenizer_to_cache(cache_path: Path, model_data: bytes) -> None:
- """Saves the model data to the cache path."""
- try:
- if not cache_path.is_file():
- cache_dir = cache_path.parent
- cache_dir.mkdir(parents=True, exist_ok=True)
- with open(cache_path, "wb") as f:
- f.write(model_data)
- except OSError:
- # Don't raise if we cannot write file.
- pass
-
- @staticmethod
- def _maybe_remove_file(file_path: Path) -> None:
- """Removes the file if exists."""
- if not file_path.is_file():
- return
- try:
- file_path.unlink()
- except OSError:
- # Don't raise if we cannot remove file.
- pass
-
- # def encode(self, content: str) -> list[int]:
- # return self.tokenizer.encode(content)
-
- # def decode(self, tokens: list[int]) -> str:
- # return self.tokenizer.decode(tokens)
-
-
-async def llm_model_func(
- prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
-) -> str:
- # 1. Initialize the GenAI Client with your Gemini API Key
- client = genai.Client(api_key=gemini_api_key)
-
- # 2. Combine prompts: system prompt, history, and user prompt
- if history_messages is None:
- history_messages = []
-
- combined_prompt = ""
- if system_prompt:
- combined_prompt += f"{system_prompt}\n"
-
- for msg in history_messages:
- # Each msg is expected to be a dict: {"role": "...", "content": "..."}
- combined_prompt += f"{msg['role']}: {msg['content']}\n"
-
- # Finally, add the new user prompt
- combined_prompt += f"user: {prompt}"
-
- # 3. Call the Gemini model
- response = client.models.generate_content(
- model="gemini-1.5-flash",
- contents=[combined_prompt],
- config=types.GenerateContentConfig(max_output_tokens=500, temperature=0.1),
- )
-
- # 4. Return the response text
- return response.text
-
-
-async def embedding_func(texts: list[str]) -> np.ndarray:
- model = SentenceTransformer("all-MiniLM-L6-v2")
- embeddings = model.encode(texts, convert_to_numpy=True)
- return embeddings
-
-
-async def initialize_rag():
- rag = LightRAG(
- working_dir=WORKING_DIR,
- # tiktoken_model_name="gpt-4o-mini",
- tokenizer=GemmaTokenizer(
- tokenizer_dir=(Path(WORKING_DIR) / "vertexai_tokenizer_model"),
- model_name="gemini-2.0-flash",
- ),
- llm_model_func=llm_model_func,
- embedding_func=EmbeddingFunc(
- embedding_dim=384,
- max_token_size=8192,
- func=embedding_func,
- ),
- )
-
- await rag.initialize_storages()
- await initialize_pipeline_status()
-
- return rag
-
-
-def main():
- # Initialize RAG instance
- rag = asyncio.run(initialize_rag())
- file_path = "story.txt"
- with open(file_path, "r") as file:
- text = file.read()
-
- rag.insert(text)
-
- response = rag.query(
- query="What is the main theme of the story?",
- param=QueryParam(mode="hybrid", top_k=5, response_type="single line"),
- )
-
- print(response)
-
-
-if __name__ == "__main__":
- main()
diff --git a/examples/lightrag_gemini_track_token_demo.py b/examples/lightrag_gemini_track_token_demo.py
deleted file mode 100644
index a72fc717..00000000
--- a/examples/lightrag_gemini_track_token_demo.py
+++ /dev/null
@@ -1,151 +0,0 @@
-# pip install -q -U google-genai to use gemini as a client
-
-import os
-import asyncio
-import numpy as np
-import nest_asyncio
-from google import genai
-from google.genai import types
-from dotenv import load_dotenv
-from lightrag.utils import EmbeddingFunc
-from lightrag import LightRAG, QueryParam
-from lightrag.kg.shared_storage import initialize_pipeline_status
-from lightrag.llm.siliconcloud import siliconcloud_embedding
-from lightrag.utils import setup_logger
-from lightrag.utils import TokenTracker
-
-setup_logger("lightrag", level="DEBUG")
-
-# Apply nest_asyncio to solve event loop issues
-nest_asyncio.apply()
-
-load_dotenv()
-gemini_api_key = os.getenv("GEMINI_API_KEY")
-siliconflow_api_key = os.getenv("SILICONFLOW_API_KEY")
-
-WORKING_DIR = "./dickens"
-
-if not os.path.exists(WORKING_DIR):
- os.mkdir(WORKING_DIR)
-
-token_tracker = TokenTracker()
-
-
-async def llm_model_func(
- prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
-) -> str:
- # 1. Initialize the GenAI Client with your Gemini API Key
- client = genai.Client(api_key=gemini_api_key)
-
- # 2. Combine prompts: system prompt, history, and user prompt
- if history_messages is None:
- history_messages = []
-
- combined_prompt = ""
- if system_prompt:
- combined_prompt += f"{system_prompt}\n"
-
- for msg in history_messages:
- # Each msg is expected to be a dict: {"role": "...", "content": "..."}
- combined_prompt += f"{msg['role']}: {msg['content']}\n"
-
- # Finally, add the new user prompt
- combined_prompt += f"user: {prompt}"
-
- # 3. Call the Gemini model
- response = client.models.generate_content(
- model="gemini-2.0-flash",
- contents=[combined_prompt],
- config=types.GenerateContentConfig(
- max_output_tokens=5000, temperature=0, top_k=10
- ),
- )
-
- # 4. Get token counts with null safety
- usage = getattr(response, "usage_metadata", None)
- prompt_tokens = getattr(usage, "prompt_token_count", 0) or 0
- completion_tokens = getattr(usage, "candidates_token_count", 0) or 0
- total_tokens = getattr(usage, "total_token_count", 0) or (
- prompt_tokens + completion_tokens
- )
-
- token_counts = {
- "prompt_tokens": prompt_tokens,
- "completion_tokens": completion_tokens,
- "total_tokens": total_tokens,
- }
-
- token_tracker.add_usage(token_counts)
-
- # 5. Return the response text
- return response.text
-
-
-async def embedding_func(texts: list[str]) -> np.ndarray:
- return await siliconcloud_embedding(
- texts,
- model="BAAI/bge-m3",
- api_key=siliconflow_api_key,
- max_token_size=512,
- )
-
-
-async def initialize_rag():
- rag = LightRAG(
- working_dir=WORKING_DIR,
- entity_extract_max_gleaning=1,
- enable_llm_cache=True,
- enable_llm_cache_for_entity_extract=True,
- embedding_cache_config={"enabled": True, "similarity_threshold": 0.90},
- llm_model_func=llm_model_func,
- embedding_func=EmbeddingFunc(
- embedding_dim=1024,
- max_token_size=8192,
- func=embedding_func,
- ),
- )
-
- await rag.initialize_storages()
- await initialize_pipeline_status()
-
- return rag
-
-
-def main():
- # Initialize RAG instance
- rag = asyncio.run(initialize_rag())
-
- with open("./book.txt", "r", encoding="utf-8") as f:
- rag.insert(f.read())
-
- # Context Manager Method
- with token_tracker:
- print(
- rag.query(
- "What are the top themes in this story?", param=QueryParam(mode="naive")
- )
- )
-
- print(
- rag.query(
- "What are the top themes in this story?", param=QueryParam(mode="local")
- )
- )
-
- print(
- rag.query(
- "What are the top themes in this story?",
- param=QueryParam(mode="global"),
- )
- )
-
- print(
- rag.query(
- "What are the top themes in this story?",
- param=QueryParam(mode="hybrid"),
- )
- )
-
-
-if __name__ == "__main__":
- main()
diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py
index a66d5d3c..d8bea386 100644
--- a/lightrag/api/lightrag_server.py
+++ b/lightrag/api/lightrag_server.py
@@ -511,7 +511,9 @@ def create_app(args):
return optimized_azure_openai_model_complete
- def create_optimized_gemini_llm_func(config_cache: LLMConfigCache, args):
+ def create_optimized_gemini_llm_func(
+ config_cache: LLMConfigCache, args, llm_timeout: int
+ ):
"""Create optimized Gemini LLM function with cached configuration"""
async def optimized_gemini_model_complete(
@@ -526,6 +528,8 @@ def create_app(args):
if history_messages is None:
history_messages = []
+            # Apply the configured LLM timeout; generation config below uses pre-parsed values
+ kwargs["timeout"] = llm_timeout
if (
config_cache.gemini_llm_options is not None
and "generation_config" not in kwargs
@@ -567,7 +571,7 @@ def create_app(args):
config_cache, args, llm_timeout
)
elif binding == "gemini":
- return create_optimized_gemini_llm_func(config_cache, args)
+ return create_optimized_gemini_llm_func(config_cache, args, llm_timeout)
else: # openai and compatible
# Use optimized function with pre-processed configuration
return create_optimized_openai_llm_func(config_cache, args, llm_timeout)
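A hypothetical sketch of the timeout flow introduced here: the server-side wrapper injects the configured timeout (in seconds) into `kwargs`, and the Gemini binding converts it to milliseconds when building the client (see the `gemini.py` changes below). The model name and API key are placeholders:

```python
from lightrag.llm.gemini import gemini_complete_if_cache

async def demo_timeout_flow() -> str:
    # create_optimized_gemini_llm_func sets kwargs["timeout"] = llm_timeout (seconds)
    kwargs = {"timeout": 180}
    # Inside the binding this becomes types.HttpOptions(timeout=180 * 1000) milliseconds.
    return await gemini_complete_if_cache(
        "gemini-flash-latest",
        "Say hello.",
        api_key="your_gemini_api_key",
        **kwargs,
    )
```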
diff --git a/lightrag/llm/binding_options.py b/lightrag/llm/binding_options.py
index 44ab5d2f..b3affb37 100644
--- a/lightrag/llm/binding_options.py
+++ b/lightrag/llm/binding_options.py
@@ -486,9 +486,9 @@ class GeminiLLMOptions(BindingOptions):
presence_penalty: float = 0.0
frequency_penalty: float = 0.0
stop_sequences: List[str] = field(default_factory=list)
- response_mime_type: str | None = None
+ seed: int | None = None
+ thinking_config: dict | None = None
safety_settings: dict | None = None
- system_instruction: str | None = None
_help: ClassVar[dict[str, str]] = {
"temperature": "Controls randomness (0.0-2.0, higher = more creative)",
@@ -499,9 +499,9 @@ class GeminiLLMOptions(BindingOptions):
"presence_penalty": "Penalty for token presence (-2.0 to 2.0)",
"frequency_penalty": "Penalty for token frequency (-2.0 to 2.0)",
"stop_sequences": "Stop sequences (JSON array of strings, e.g., '[\"END\"]')",
- "response_mime_type": "Desired MIME type for the response (e.g., application/json)",
+ "seed": "Random seed for reproducible generation (leave empty for random)",
+ "thinking_config": "Thinking configuration (JSON dict, e.g., '{\"thinking_budget\": 1024}' or '{\"include_thoughts\": true}')",
"safety_settings": "JSON object with Gemini safety settings overrides",
- "system_instruction": "Default system instruction applied to every request",
}
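A hedged example of the two new options; `GeminiLLMOptions` is a dataclass-style options container, so the direct construction below is for illustration only (in the server these fields are normally populated from the `GEMINI_LLM_*` environment variables shown in `env.example`):

```python
from lightrag.llm.binding_options import GeminiLLMOptions

# seed makes sampling reproducible; thinking_config is passed through to Gemini's
# thinking feature (budget in tokens, include_thoughts to return thought parts).
opts = GeminiLLMOptions(
    seed=42,
    thinking_config={"thinking_budget": 1024, "include_thoughts": True},
)
```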
diff --git a/lightrag/llm/gemini.py b/lightrag/llm/gemini.py
index b8c64b31..f06ec6b3 100644
--- a/lightrag/llm/gemini.py
+++ b/lightrag/llm/gemini.py
@@ -33,24 +33,33 @@ LOG = logging.getLogger(__name__)
@lru_cache(maxsize=8)
-def _get_gemini_client(api_key: str, base_url: str | None) -> genai.Client:
+def _get_gemini_client(
+ api_key: str, base_url: str | None, timeout: int | None = None
+) -> genai.Client:
"""
Create (or fetch cached) Gemini client.
Args:
api_key: Google Gemini API key.
base_url: Optional custom API endpoint.
+ timeout: Optional request timeout in milliseconds.
Returns:
genai.Client: Configured Gemini client instance.
"""
client_kwargs: dict[str, Any] = {"api_key": api_key}
- if base_url and base_url != DEFAULT_GEMINI_ENDPOINT:
+    if (base_url and base_url != DEFAULT_GEMINI_ENDPOINT) or timeout is not None:
try:
- client_kwargs["http_options"] = types.HttpOptions(api_endpoint=base_url)
+ http_options_kwargs = {}
+ if base_url and base_url != DEFAULT_GEMINI_ENDPOINT:
+ http_options_kwargs["api_endpoint"] = base_url
+ if timeout is not None:
+ http_options_kwargs["timeout"] = timeout
+
+ client_kwargs["http_options"] = types.HttpOptions(**http_options_kwargs)
except Exception as exc: # pragma: no cover - defensive
- LOG.warning("Failed to apply custom Gemini endpoint %s: %s", base_url, exc)
+ LOG.warning("Failed to apply custom Gemini http_options: %s", exc)
try:
return genai.Client(**client_kwargs)
@@ -114,24 +123,44 @@ def _format_history_messages(history_messages: list[dict[str, Any]] | None) -> s
return "\n".join(history_lines)
-def _extract_response_text(response: Any) -> str:
- if getattr(response, "text", None):
- return response.text
+def _extract_response_text(
+ response: Any, extract_thoughts: bool = False
+) -> tuple[str, str]:
+ """
+ Extract text content from Gemini response, separating regular content from thoughts.
+ Args:
+ response: Gemini API response object
+ extract_thoughts: Whether to extract thought content separately
+
+ Returns:
+ Tuple of (regular_text, thought_text)
+ """
candidates = getattr(response, "candidates", None)
if not candidates:
- return ""
+ return ("", "")
+
+ regular_parts: list[str] = []
+ thought_parts: list[str] = []
- parts: list[str] = []
for candidate in candidates:
if not getattr(candidate, "content", None):
continue
- for part in getattr(candidate.content, "parts", []):
+ # Use 'or []' to handle None values from parts attribute
+ for part in getattr(candidate.content, "parts", None) or []:
text = getattr(part, "text", None)
- if text:
- parts.append(text)
+ if not text:
+ continue
- return "\n".join(parts)
+ # Check if this part is thought content using the 'thought' attribute
+ is_thought = getattr(part, "thought", False)
+
+ if is_thought and extract_thoughts:
+ thought_parts.append(text)
+ elif not is_thought:
+ regular_parts.append(text)
+
+ return ("\n".join(regular_parts), "\n".join(thought_parts))
async def gemini_complete_if_cache(
@@ -139,22 +168,58 @@ async def gemini_complete_if_cache(
prompt: str,
system_prompt: str | None = None,
history_messages: list[dict[str, Any]] | None = None,
- *,
- api_key: str | None = None,
+ enable_cot: bool = False,
base_url: str | None = None,
- generation_config: dict[str, Any] | None = None,
- keyword_extraction: bool = False,
+ api_key: str | None = None,
token_tracker: Any | None = None,
- hashing_kv: Any | None = None, # noqa: ARG001 - present for interface parity
stream: bool | None = None,
- enable_cot: bool = False, # noqa: ARG001 - not supported by Gemini currently
- timeout: float | None = None, # noqa: ARG001 - handled by caller if needed
+ keyword_extraction: bool = False,
+ generation_config: dict[str, Any] | None = None,
+ timeout: int | None = None,
**_: Any,
) -> str | AsyncIterator[str]:
+ """
+ Complete a prompt using Gemini's API with Chain of Thought (COT) support.
+
+ This function supports automatic integration of reasoning content from Gemini models
+ that provide Chain of Thought capabilities via the thinking_config API feature.
+
+ COT Integration:
+    - When enable_cot=True: Thought content is wrapped in <think>...</think> tags
+ - When enable_cot=False: Thought content is filtered out, only regular content returned
+ - Thought content is identified by the 'thought' attribute on response parts
+ - Requires thinking_config to be enabled in generation_config for API to return thoughts
+
+ Args:
+ model: The Gemini model to use.
+ prompt: The prompt to complete.
+ system_prompt: Optional system prompt to include.
+ history_messages: Optional list of previous messages in the conversation.
+ api_key: Optional Gemini API key. If None, uses environment variable.
+ base_url: Optional custom API endpoint.
+ generation_config: Optional generation configuration dict.
+ keyword_extraction: Whether to use JSON response format.
+ token_tracker: Optional token usage tracker for monitoring API usage.
+ stream: Whether to stream the response.
+        hashing_kv: Accepted (and ignored) via the catch-all **_ for interface parity with other bindings.
+ enable_cot: Whether to include Chain of Thought content in the response.
+ timeout: Request timeout in seconds (will be converted to milliseconds for Gemini API).
+ **_: Additional keyword arguments (ignored).
+
+ Returns:
+ The completed text (with COT content if enable_cot=True) or an async iterator
+        of text chunks if streaming. COT content is wrapped in <think>...</think> tags.
+
+ Raises:
+ RuntimeError: If the response from Gemini is empty.
+ ValueError: If API key is not provided or configured.
+ """
loop = asyncio.get_running_loop()
key = _ensure_api_key(api_key)
- client = _get_gemini_client(key, base_url)
+ # Convert timeout from seconds to milliseconds for Gemini API
+ timeout_ms = timeout * 1000 if timeout else None
+ client = _get_gemini_client(key, base_url, timeout_ms)
history_block = _format_history_messages(history_messages)
prompt_sections = []
@@ -184,6 +249,11 @@ async def gemini_complete_if_cache(
usage_container: dict[str, Any] = {}
def _stream_model() -> None:
+ # COT state tracking for streaming
+ cot_active = False
+ cot_started = False
+ initial_content_seen = False
+
try:
stream_kwargs = dict(request_kwargs)
stream_iterator = client.models.generate_content_stream(**stream_kwargs)
@@ -191,20 +261,61 @@ async def gemini_complete_if_cache(
usage = getattr(chunk, "usage_metadata", None)
if usage is not None:
usage_container["usage"] = usage
- text_piece = getattr(chunk, "text", None) or _extract_response_text(
- chunk
+
+ # Extract both regular and thought content
+ regular_text, thought_text = _extract_response_text(
+ chunk, extract_thoughts=True
)
- if text_piece:
- loop.call_soon_threadsafe(queue.put_nowait, text_piece)
+
+ if enable_cot:
+ # Process regular content
+ if regular_text:
+ if not initial_content_seen:
+ initial_content_seen = True
+
+ # Close COT section if it was active
+ if cot_active:
+                            loop.call_soon_threadsafe(queue.put_nowait, "</think>")
+ cot_active = False
+
+ # Send regular content
+ loop.call_soon_threadsafe(queue.put_nowait, regular_text)
+
+ # Process thought content
+ if thought_text:
+ if not initial_content_seen and not cot_started:
+ # Start COT section
+                            loop.call_soon_threadsafe(queue.put_nowait, "<think>")
+ cot_active = True
+ cot_started = True
+
+ # Send thought content if COT is active
+ if cot_active:
+ loop.call_soon_threadsafe(
+ queue.put_nowait, thought_text
+ )
+ else:
+ # COT disabled - only send regular content
+ if regular_text:
+ loop.call_soon_threadsafe(queue.put_nowait, regular_text)
+
+ # Ensure COT is properly closed if still active
+ if cot_active:
+                loop.call_soon_threadsafe(queue.put_nowait, "</think>")
+
loop.call_soon_threadsafe(queue.put_nowait, None)
except Exception as exc: # pragma: no cover - surface runtime issues
+ # Try to close COT tag before reporting error
+ if cot_active:
+ try:
+                    loop.call_soon_threadsafe(queue.put_nowait, "</think>")
+ except Exception:
+ pass
loop.call_soon_threadsafe(queue.put_nowait, exc)
loop.run_in_executor(None, _stream_model)
async def _async_stream() -> AsyncIterator[str]:
- accumulated = ""
- emitted = ""
try:
while True:
item = await queue.get()
@@ -217,16 +328,9 @@ async def gemini_complete_if_cache(
if "\\u" in chunk_text:
chunk_text = safe_unicode_decode(chunk_text.encode("utf-8"))
- accumulated += chunk_text
- sanitized = remove_think_tags(accumulated)
- if sanitized.startswith(emitted):
- delta = sanitized[len(emitted) :]
- else:
- delta = sanitized
- emitted = sanitized
-
- if delta:
- yield delta
+ # Yield the chunk directly without filtering
+ # COT filtering is already handled in _stream_model()
+ yield chunk_text
finally:
usage = usage_container.get("usage")
if token_tracker and usage:
@@ -244,14 +348,33 @@ async def gemini_complete_if_cache(
response = await asyncio.to_thread(_call_model)
- text = _extract_response_text(response)
- if not text:
+ # Extract both regular text and thought text
+ regular_text, thought_text = _extract_response_text(response, extract_thoughts=True)
+
+ # Apply COT filtering logic based on enable_cot parameter
+ if enable_cot:
+        # Include thought content wrapped in <think> tags
+ if thought_text and thought_text.strip():
+ if not regular_text or regular_text.strip() == "":
+ # Only thought content available
+                final_text = f"<think>{thought_text}</think>"
+ else:
+ # Both content types present: prepend thought to regular content
+                final_text = f"<think>{thought_text}</think>{regular_text}"
+ else:
+ # No thought content, use regular content only
+ final_text = regular_text or ""
+ else:
+ # Filter out thought content, return only regular content
+ final_text = regular_text or ""
+
+ if not final_text:
raise RuntimeError("Gemini response did not contain any text content.")
- if "\\u" in text:
- text = safe_unicode_decode(text.encode("utf-8"))
+ if "\\u" in final_text:
+ final_text = safe_unicode_decode(final_text.encode("utf-8"))
- text = remove_think_tags(text)
+ final_text = remove_think_tags(final_text)
usage = getattr(response, "usage_metadata", None)
if token_tracker and usage:
@@ -263,8 +386,8 @@ async def gemini_complete_if_cache(
}
)
- logger.debug("Gemini response length: %s", len(text))
- return text
+ logger.debug("Gemini response length: %s", len(final_text))
+ return final_text
async def gemini_model_complete(
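A hypothetical usage sketch of the COT behaviour added above: with `include_thoughts` enabled in the thinking config and `enable_cot=True`, thought parts are returned wrapped in `<think>...</think>` tags ahead of the regular answer, while `enable_cot=False` drops them; the model name, API key, and the exact shape accepted by `generation_config` are assumptions here:

```python
import asyncio

from lightrag.llm.gemini import gemini_complete_if_cache

async def cot_demo() -> None:
    text = await gemini_complete_if_cache(
        "gemini-flash-latest",
        "Briefly explain why the sky is blue.",
        api_key="your_gemini_api_key",
        enable_cot=True,  # keep thought parts, wrapped in <think>...</think>
        generation_config={
            "thinking_config": {"thinking_budget": 1024, "include_thoughts": True}
        },
    )
    print(text)

asyncio.run(cot_demo())
```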
diff --git a/lightrag/llm/openai.py b/lightrag/llm/openai.py
index fce33cac..3339ea3a 100644
--- a/lightrag/llm/openai.py
+++ b/lightrag/llm/openai.py
@@ -138,6 +138,9 @@ async def openai_complete_if_cache(
base_url: str | None = None,
api_key: str | None = None,
token_tracker: Any | None = None,
+    keyword_extraction: bool = False,  # Accepted for interface parity; not forwarded to OpenAI
+ stream: bool | None = None,
+ timeout: int | None = None,
**kwargs: Any,
) -> str:
"""Complete a prompt using OpenAI's API with caching support and Chain of Thought (COT) integration.
@@ -172,8 +175,9 @@ async def openai_complete_if_cache(
- openai_client_configs: Dict of configuration options for the AsyncOpenAI client.
These will be passed to the client constructor but will be overridden by
explicit parameters (api_key, base_url).
- - hashing_kv: Will be removed from kwargs before passing to OpenAI.
- keyword_extraction: Will be removed from kwargs before passing to OpenAI.
+ - stream: Whether to stream the response. Default is False.
+ - timeout: Request timeout in seconds. Default is None.
Returns:
The completed text (with integrated COT content if available) or an async iterator