Merge branch 'main' into apply-dim-to-embedding-call

yangdx 2025-11-07 20:48:22 +08:00
commit d8a6355e41
8 changed files with 185 additions and 539 deletions

View file

@ -194,9 +194,10 @@ LLM_BINDING_API_KEY=your_api_key
### Gemini example
# LLM_BINDING=gemini
# LLM_MODEL=gemini-flash-latest
# LLM_BINDING_HOST=https://generativelanguage.googleapis.com
# LLM_BINDING_API_KEY=your_gemini_api_key
# GEMINI_LLM_MAX_OUTPUT_TOKENS=8192
# LLM_BINDING_HOST=https://generativelanguage.googleapis.com
GEMINI_LLM_THINKING_CONFIG='{"thinking_budget": 0, "include_thoughts": false}'
# GEMINI_LLM_MAX_OUTPUT_TOKENS=9000
# GEMINI_LLM_TEMPERATURE=0.7
### OpenAI Compatible API Specific Parameters
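
For reference, a minimal sketch of how the Gemini settings above (GEMINI_LLM_THINKING_CONFIG, GEMINI_LLM_MAX_OUTPUT_TOKENS, GEMINI_LLM_TEMPERATURE) map onto google-genai configuration objects. The exact wiring inside LightRAG may differ; this only illustrates the expected JSON shape of the thinking config.

import json
import os

from google.genai import types

# Parse the JSON-valued env var into a ThinkingConfig
# (fields shown above: thinking_budget, include_thoughts).
raw = os.getenv(
    "GEMINI_LLM_THINKING_CONFIG",
    '{"thinking_budget": 0, "include_thoughts": false}',
)
thinking = types.ThinkingConfig(**json.loads(raw))

# Fold the scalar settings into a GenerateContentConfig alongside it.
generation_config = types.GenerateContentConfig(
    max_output_tokens=int(os.getenv("GEMINI_LLM_MAX_OUTPUT_TOKENS", "9000")),
    temperature=float(os.getenv("GEMINI_LLM_TEMPERATURE", "0.7")),
    thinking_config=thinking,
)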

View file

@ -1,105 +0,0 @@
# pip install -q -U google-genai to use gemini as a client
import os
import numpy as np
from google import genai
from google.genai import types
from dotenv import load_dotenv
from lightrag.utils import EmbeddingFunc
from lightrag import LightRAG, QueryParam
from sentence_transformers import SentenceTransformer
from lightrag.kg.shared_storage import initialize_pipeline_status
import asyncio
import nest_asyncio
# Apply nest_asyncio to solve event loop issues
nest_asyncio.apply()
load_dotenv()
gemini_api_key = os.getenv("GEMINI_API_KEY")
WORKING_DIR = "./dickens"
if os.path.exists(WORKING_DIR):
import shutil
shutil.rmtree(WORKING_DIR)
os.mkdir(WORKING_DIR)
async def llm_model_func(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
# 1. Initialize the GenAI Client with your Gemini API Key
client = genai.Client(api_key=gemini_api_key)
# 2. Combine prompts: system prompt, history, and user prompt
if history_messages is None:
history_messages = []
combined_prompt = ""
if system_prompt:
combined_prompt += f"{system_prompt}\n"
for msg in history_messages:
# Each msg is expected to be a dict: {"role": "...", "content": "..."}
combined_prompt += f"{msg['role']}: {msg['content']}\n"
# Finally, add the new user prompt
combined_prompt += f"user: {prompt}"
# 3. Call the Gemini model
response = client.models.generate_content(
model="gemini-1.5-flash",
contents=[combined_prompt],
config=types.GenerateContentConfig(max_output_tokens=500, temperature=0.1),
)
# 4. Return the response text
return response.text
async def embedding_func(texts: list[str]) -> np.ndarray:
model = SentenceTransformer("all-MiniLM-L6-v2")
embeddings = model.encode(texts, convert_to_numpy=True)
return embeddings
async def initialize_rag():
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=384,
max_token_size=8192,
func=embedding_func,
),
)
await rag.initialize_storages()
await initialize_pipeline_status()
return rag
def main():
# Initialize RAG instance
rag = asyncio.run(initialize_rag())
file_path = "story.txt"
with open(file_path, "r") as file:
text = file.read()
rag.insert(text)
response = rag.query(
query="What is the main theme of the story?",
param=QueryParam(mode="hybrid", top_k=5, response_type="single line"),
)
print(response)
if __name__ == "__main__":
main()

View file

@ -1,230 +0,0 @@
# pip install -q -U google-genai to use gemini as a client
import os
from typing import Optional
import dataclasses
from pathlib import Path
import hashlib
import numpy as np
from google import genai
from google.genai import types
from dotenv import load_dotenv
from lightrag.utils import EmbeddingFunc, Tokenizer
from lightrag import LightRAG, QueryParam
from sentence_transformers import SentenceTransformer
from lightrag.kg.shared_storage import initialize_pipeline_status
import sentencepiece as spm
import requests
import asyncio
import nest_asyncio
# Apply nest_asyncio to solve event loop issues
nest_asyncio.apply()
load_dotenv()
gemini_api_key = os.getenv("GEMINI_API_KEY")
WORKING_DIR = "./dickens"
if os.path.exists(WORKING_DIR):
import shutil
shutil.rmtree(WORKING_DIR)
os.mkdir(WORKING_DIR)
class GemmaTokenizer(Tokenizer):
# adapted from google-cloud-aiplatform[tokenization]
@dataclasses.dataclass(frozen=True)
class _TokenizerConfig:
tokenizer_model_url: str
tokenizer_model_hash: str
_TOKENIZERS = {
"google/gemma2": _TokenizerConfig(
tokenizer_model_url="https://raw.githubusercontent.com/google/gemma_pytorch/33b652c465537c6158f9a472ea5700e5e770ad3f/tokenizer/tokenizer.model",
tokenizer_model_hash="61a7b147390c64585d6c3543dd6fc636906c9af3865a5548f27f31aee1d4c8e2",
),
"google/gemma3": _TokenizerConfig(
tokenizer_model_url="https://raw.githubusercontent.com/google/gemma_pytorch/cb7c0152a369e43908e769eb09e1ce6043afe084/tokenizer/gemma3_cleaned_262144_v2.spiece.model",
tokenizer_model_hash="1299c11d7cf632ef3b4e11937501358ada021bbdf7c47638d13c0ee982f2e79c",
),
}
def __init__(
self, model_name: str = "gemini-2.0-flash", tokenizer_dir: Optional[str] = None
):
# https://github.com/google/gemma_pytorch/tree/main/tokenizer
if "1.5" in model_name or "1.0" in model_name:
# up to gemini 1.5 gemma2 is a comparable local tokenizer
# https://github.com/googleapis/python-aiplatform/blob/main/vertexai/tokenization/_tokenizer_loading.py
tokenizer_name = "google/gemma2"
else:
# for gemini > 2.0 gemma3 was used
tokenizer_name = "google/gemma3"
file_url = self._TOKENIZERS[tokenizer_name].tokenizer_model_url
tokenizer_model_name = file_url.rsplit("/", 1)[1]
expected_hash = self._TOKENIZERS[tokenizer_name].tokenizer_model_hash
tokenizer_dir = Path(tokenizer_dir)
if tokenizer_dir.is_dir():
file_path = tokenizer_dir / tokenizer_model_name
model_data = self._maybe_load_from_cache(
file_path=file_path, expected_hash=expected_hash
)
else:
model_data = None
if not model_data:
model_data = self._load_from_url(
file_url=file_url, expected_hash=expected_hash
)
self.save_tokenizer_to_cache(cache_path=file_path, model_data=model_data)
tokenizer = spm.SentencePieceProcessor()
tokenizer.LoadFromSerializedProto(model_data)
super().__init__(model_name=model_name, tokenizer=tokenizer)
def _is_valid_model(self, model_data: bytes, expected_hash: str) -> bool:
"""Returns true if the content is valid by checking the hash."""
return hashlib.sha256(model_data).hexdigest() == expected_hash
def _maybe_load_from_cache(self, file_path: Path, expected_hash: str) -> bytes:
"""Loads the model data from the cache path."""
if not file_path.is_file():
return
with open(file_path, "rb") as f:
content = f.read()
if self._is_valid_model(model_data=content, expected_hash=expected_hash):
return content
# Cached file corrupted.
self._maybe_remove_file(file_path)
def _load_from_url(self, file_url: str, expected_hash: str) -> bytes:
"""Loads model bytes from the given file url."""
resp = requests.get(file_url)
resp.raise_for_status()
content = resp.content
if not self._is_valid_model(model_data=content, expected_hash=expected_hash):
actual_hash = hashlib.sha256(content).hexdigest()
raise ValueError(
f"Downloaded model file is corrupted."
f" Expected hash {expected_hash}. Got file hash {actual_hash}."
)
return content
@staticmethod
def save_tokenizer_to_cache(cache_path: Path, model_data: bytes) -> None:
"""Saves the model data to the cache path."""
try:
if not cache_path.is_file():
cache_dir = cache_path.parent
cache_dir.mkdir(parents=True, exist_ok=True)
with open(cache_path, "wb") as f:
f.write(model_data)
except OSError:
# Don't raise if we cannot write file.
pass
@staticmethod
def _maybe_remove_file(file_path: Path) -> None:
"""Removes the file if exists."""
if not file_path.is_file():
return
try:
file_path.unlink()
except OSError:
# Don't raise if we cannot remove file.
pass
# def encode(self, content: str) -> list[int]:
# return self.tokenizer.encode(content)
# def decode(self, tokens: list[int]) -> str:
# return self.tokenizer.decode(tokens)
async def llm_model_func(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
# 1. Initialize the GenAI Client with your Gemini API Key
client = genai.Client(api_key=gemini_api_key)
# 2. Combine prompts: system prompt, history, and user prompt
if history_messages is None:
history_messages = []
combined_prompt = ""
if system_prompt:
combined_prompt += f"{system_prompt}\n"
for msg in history_messages:
# Each msg is expected to be a dict: {"role": "...", "content": "..."}
combined_prompt += f"{msg['role']}: {msg['content']}\n"
# Finally, add the new user prompt
combined_prompt += f"user: {prompt}"
# 3. Call the Gemini model
response = client.models.generate_content(
model="gemini-1.5-flash",
contents=[combined_prompt],
config=types.GenerateContentConfig(max_output_tokens=500, temperature=0.1),
)
# 4. Return the response text
return response.text
async def embedding_func(texts: list[str]) -> np.ndarray:
model = SentenceTransformer("all-MiniLM-L6-v2")
embeddings = model.encode(texts, convert_to_numpy=True)
return embeddings
async def initialize_rag():
rag = LightRAG(
working_dir=WORKING_DIR,
# tiktoken_model_name="gpt-4o-mini",
tokenizer=GemmaTokenizer(
tokenizer_dir=(Path(WORKING_DIR) / "vertexai_tokenizer_model"),
model_name="gemini-2.0-flash",
),
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=384,
max_token_size=8192,
func=embedding_func,
),
)
await rag.initialize_storages()
await initialize_pipeline_status()
return rag
def main():
# Initialize RAG instance
rag = asyncio.run(initialize_rag())
file_path = "story.txt"
with open(file_path, "r") as file:
text = file.read()
rag.insert(text)
response = rag.query(
query="What is the main theme of the story?",
param=QueryParam(mode="hybrid", top_k=5, response_type="single line"),
)
print(response)
if __name__ == "__main__":
main()

View file

@ -1,151 +0,0 @@
# pip install -q -U google-genai to use gemini as a client
import os
import asyncio
import numpy as np
import nest_asyncio
from google import genai
from google.genai import types
from dotenv import load_dotenv
from lightrag.utils import EmbeddingFunc
from lightrag import LightRAG, QueryParam
from lightrag.kg.shared_storage import initialize_pipeline_status
from lightrag.llm.siliconcloud import siliconcloud_embedding
from lightrag.utils import setup_logger
from lightrag.utils import TokenTracker
setup_logger("lightrag", level="DEBUG")
# Apply nest_asyncio to solve event loop issues
nest_asyncio.apply()
load_dotenv()
gemini_api_key = os.getenv("GEMINI_API_KEY")
siliconflow_api_key = os.getenv("SILICONFLOW_API_KEY")
WORKING_DIR = "./dickens"
if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)
token_tracker = TokenTracker()
async def llm_model_func(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
# 1. Initialize the GenAI Client with your Gemini API Key
client = genai.Client(api_key=gemini_api_key)
# 2. Combine prompts: system prompt, history, and user prompt
if history_messages is None:
history_messages = []
combined_prompt = ""
if system_prompt:
combined_prompt += f"{system_prompt}\n"
for msg in history_messages:
# Each msg is expected to be a dict: {"role": "...", "content": "..."}
combined_prompt += f"{msg['role']}: {msg['content']}\n"
# Finally, add the new user prompt
combined_prompt += f"user: {prompt}"
# 3. Call the Gemini model
response = client.models.generate_content(
model="gemini-2.0-flash",
contents=[combined_prompt],
config=types.GenerateContentConfig(
max_output_tokens=5000, temperature=0, top_k=10
),
)
# 4. Get token counts with null safety
usage = getattr(response, "usage_metadata", None)
prompt_tokens = getattr(usage, "prompt_token_count", 0) or 0
completion_tokens = getattr(usage, "candidates_token_count", 0) or 0
total_tokens = getattr(usage, "total_token_count", 0) or (
prompt_tokens + completion_tokens
)
token_counts = {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": total_tokens,
}
token_tracker.add_usage(token_counts)
# 5. Return the response text
return response.text
async def embedding_func(texts: list[str]) -> np.ndarray:
return await siliconcloud_embedding(
texts,
model="BAAI/bge-m3",
api_key=siliconflow_api_key,
max_token_size=512,
)
async def initialize_rag():
rag = LightRAG(
working_dir=WORKING_DIR,
entity_extract_max_gleaning=1,
enable_llm_cache=True,
enable_llm_cache_for_entity_extract=True,
embedding_cache_config={"enabled": True, "similarity_threshold": 0.90},
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=1024,
max_token_size=8192,
func=embedding_func,
),
)
await rag.initialize_storages()
await initialize_pipeline_status()
return rag
def main():
# Initialize RAG instance
rag = asyncio.run(initialize_rag())
with open("./book.txt", "r", encoding="utf-8") as f:
rag.insert(f.read())
# Context Manager Method
with token_tracker:
print(
rag.query(
"What are the top themes in this story?", param=QueryParam(mode="naive")
)
)
print(
rag.query(
"What are the top themes in this story?", param=QueryParam(mode="local")
)
)
print(
rag.query(
"What are the top themes in this story?",
param=QueryParam(mode="global"),
)
)
print(
rag.query(
"What are the top themes in this story?",
param=QueryParam(mode="hybrid"),
)
)
if __name__ == "__main__":
main()

View file

@ -511,7 +511,9 @@ def create_app(args):
return optimized_azure_openai_model_complete
def create_optimized_gemini_llm_func(config_cache: LLMConfigCache, args):
def create_optimized_gemini_llm_func(
config_cache: LLMConfigCache, args, llm_timeout: int
):
"""Create optimized Gemini LLM function with cached configuration"""
async def optimized_gemini_model_complete(
@ -526,6 +528,8 @@ def create_app(args):
if history_messages is None:
history_messages = []
# Use pre-processed configuration to avoid repeated parsing
kwargs["timeout"] = llm_timeout
if (
config_cache.gemini_llm_options is not None
and "generation_config" not in kwargs
@ -567,7 +571,7 @@ def create_app(args):
config_cache, args, llm_timeout
)
elif binding == "gemini":
return create_optimized_gemini_llm_func(config_cache, args)
return create_optimized_gemini_llm_func(config_cache, args, llm_timeout)
else: # openai and compatible
# Use optimized function with pre-processed configuration
return create_optimized_openai_llm_func(config_cache, args, llm_timeout)
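
The change above follows a factory/closure pattern: the cached configuration and the server-level timeout are captured once when the binding function is created, so each request only injects them into kwargs instead of re-parsing settings. A simplified, hypothetical sketch of that pattern (names mirror the ones above, but the real implementation differs in detail):

def make_llm_func(complete_fn, config_cache, llm_timeout: int):
    """Hypothetical sketch: bake shared config and timeout into a per-binding closure."""

    async def llm_func(prompt, system_prompt=None, history_messages=None, **kwargs):
        if history_messages is None:
            history_messages = []
        # Captured once at creation time; no per-request config parsing.
        kwargs["timeout"] = llm_timeout
        if config_cache is not None and "generation_config" not in kwargs:
            kwargs["generation_config"] = config_cache
        return await complete_fn(
            prompt,
            system_prompt=system_prompt,
            history_messages=history_messages,
            **kwargs,
        )

    return llm_func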

View file

@ -486,9 +486,9 @@ class GeminiLLMOptions(BindingOptions):
presence_penalty: float = 0.0
frequency_penalty: float = 0.0
stop_sequences: List[str] = field(default_factory=list)
response_mime_type: str | None = None
seed: int | None = None
thinking_config: dict | None = None
safety_settings: dict | None = None
system_instruction: str | None = None
_help: ClassVar[dict[str, str]] = {
"temperature": "Controls randomness (0.0-2.0, higher = more creative)",
@ -499,9 +499,9 @@ class GeminiLLMOptions(BindingOptions):
"presence_penalty": "Penalty for token presence (-2.0 to 2.0)",
"frequency_penalty": "Penalty for token frequency (-2.0 to 2.0)",
"stop_sequences": "Stop sequences (JSON array of strings, e.g., '[\"END\"]')",
"response_mime_type": "Desired MIME type for the response (e.g., application/json)",
"seed": "Random seed for reproducible generation (leave empty for random)",
"thinking_config": "Thinking configuration (JSON dict, e.g., '{\"thinking_budget\": 1024}' or '{\"include_thoughts\": true}')",
"safety_settings": "JSON object with Gemini safety settings overrides",
"system_instruction": "Default system instruction applied to every request",
}

View file

@ -33,24 +33,33 @@ LOG = logging.getLogger(__name__)
@lru_cache(maxsize=8)
def _get_gemini_client(api_key: str, base_url: str | None) -> genai.Client:
def _get_gemini_client(
api_key: str, base_url: str | None, timeout: int | None = None
) -> genai.Client:
"""
Create (or fetch cached) Gemini client.
Args:
api_key: Google Gemini API key.
base_url: Optional custom API endpoint.
timeout: Optional request timeout in milliseconds.
Returns:
genai.Client: Configured Gemini client instance.
"""
client_kwargs: dict[str, Any] = {"api_key": api_key}
if base_url and base_url != DEFAULT_GEMINI_ENDPOINT:
if base_url and base_url != DEFAULT_GEMINI_ENDPOINT or timeout is not None:
try:
client_kwargs["http_options"] = types.HttpOptions(api_endpoint=base_url)
http_options_kwargs = {}
if base_url and base_url != DEFAULT_GEMINI_ENDPOINT:
http_options_kwargs["api_endpoint"] = base_url
if timeout is not None:
http_options_kwargs["timeout"] = timeout
client_kwargs["http_options"] = types.HttpOptions(**http_options_kwargs)
except Exception as exc: # pragma: no cover - defensive
LOG.warning("Failed to apply custom Gemini endpoint %s: %s", base_url, exc)
LOG.warning("Failed to apply custom Gemini http_options: %s", exc)
try:
return genai.Client(**client_kwargs)
@ -114,24 +123,44 @@ def _format_history_messages(history_messages: list[dict[str, Any]] | None) -> s
return "\n".join(history_lines)
def _extract_response_text(response: Any) -> str:
if getattr(response, "text", None):
return response.text
def _extract_response_text(
response: Any, extract_thoughts: bool = False
) -> tuple[str, str]:
"""
Extract text content from Gemini response, separating regular content from thoughts.
Args:
response: Gemini API response object
extract_thoughts: Whether to extract thought content separately
Returns:
Tuple of (regular_text, thought_text)
"""
candidates = getattr(response, "candidates", None)
if not candidates:
return ""
return ("", "")
regular_parts: list[str] = []
thought_parts: list[str] = []
parts: list[str] = []
for candidate in candidates:
if not getattr(candidate, "content", None):
continue
for part in getattr(candidate.content, "parts", []):
# Use 'or []' to handle None values from parts attribute
for part in getattr(candidate.content, "parts", None) or []:
text = getattr(part, "text", None)
if text:
parts.append(text)
if not text:
continue
return "\n".join(parts)
# Check if this part is thought content using the 'thought' attribute
is_thought = getattr(part, "thought", False)
if is_thought and extract_thoughts:
thought_parts.append(text)
elif not is_thought:
regular_parts.append(text)
return ("\n".join(regular_parts), "\n".join(thought_parts))
async def gemini_complete_if_cache(
@ -139,22 +168,58 @@ async def gemini_complete_if_cache(
prompt: str,
system_prompt: str | None = None,
history_messages: list[dict[str, Any]] | None = None,
*,
api_key: str | None = None,
enable_cot: bool = False,
base_url: str | None = None,
generation_config: dict[str, Any] | None = None,
keyword_extraction: bool = False,
api_key: str | None = None,
token_tracker: Any | None = None,
hashing_kv: Any | None = None, # noqa: ARG001 - present for interface parity
stream: bool | None = None,
enable_cot: bool = False, # noqa: ARG001 - not supported by Gemini currently
timeout: float | None = None, # noqa: ARG001 - handled by caller if needed
keyword_extraction: bool = False,
generation_config: dict[str, Any] | None = None,
timeout: int | None = None,
**_: Any,
) -> str | AsyncIterator[str]:
"""
Complete a prompt using Gemini's API with Chain of Thought (COT) support.
This function supports automatic integration of reasoning content from Gemini models
that provide Chain of Thought capabilities via the thinking_config API feature.
COT Integration:
- When enable_cot=True: Thought content is wrapped in <think>...</think> tags
- When enable_cot=False: Thought content is filtered out, only regular content returned
- Thought content is identified by the 'thought' attribute on response parts
- Requires thinking_config to be enabled in generation_config for the API to return thoughts
Args:
model: The Gemini model to use.
prompt: The prompt to complete.
system_prompt: Optional system prompt to include.
history_messages: Optional list of previous messages in the conversation.
api_key: Optional Gemini API key. If None, uses environment variable.
base_url: Optional custom API endpoint.
generation_config: Optional generation configuration dict.
keyword_extraction: Whether to use JSON response format.
token_tracker: Optional token usage tracker for monitoring API usage.
stream: Whether to stream the response.
hashing_kv: Storage interface (for interface parity with other bindings).
enable_cot: Whether to include Chain of Thought content in the response.
timeout: Request timeout in seconds (will be converted to milliseconds for Gemini API).
**_: Additional keyword arguments (ignored).
Returns:
The completed text (with COT content if enable_cot=True) or an async iterator
of text chunks if streaming. COT content is wrapped in <think>...</think> tags.
Raises:
RuntimeError: If the response from Gemini is empty.
ValueError: If API key is not provided or configured.
"""
loop = asyncio.get_running_loop()
key = _ensure_api_key(api_key)
client = _get_gemini_client(key, base_url)
# Convert timeout from seconds to milliseconds for Gemini API
timeout_ms = timeout * 1000 if timeout else None
client = _get_gemini_client(key, base_url, timeout_ms)
history_block = _format_history_messages(history_messages)
prompt_sections = []
@ -184,6 +249,11 @@ async def gemini_complete_if_cache(
usage_container: dict[str, Any] = {}
def _stream_model() -> None:
# COT state tracking for streaming
cot_active = False
cot_started = False
initial_content_seen = False
try:
stream_kwargs = dict(request_kwargs)
stream_iterator = client.models.generate_content_stream(**stream_kwargs)
@ -191,20 +261,61 @@ async def gemini_complete_if_cache(
usage = getattr(chunk, "usage_metadata", None)
if usage is not None:
usage_container["usage"] = usage
text_piece = getattr(chunk, "text", None) or _extract_response_text(
chunk
# Extract both regular and thought content
regular_text, thought_text = _extract_response_text(
chunk, extract_thoughts=True
)
if text_piece:
loop.call_soon_threadsafe(queue.put_nowait, text_piece)
if enable_cot:
# Process regular content
if regular_text:
if not initial_content_seen:
initial_content_seen = True
# Close COT section if it was active
if cot_active:
loop.call_soon_threadsafe(queue.put_nowait, "</think>")
cot_active = False
# Send regular content
loop.call_soon_threadsafe(queue.put_nowait, regular_text)
# Process thought content
if thought_text:
if not initial_content_seen and not cot_started:
# Start COT section
loop.call_soon_threadsafe(queue.put_nowait, "<think>")
cot_active = True
cot_started = True
# Send thought content if COT is active
if cot_active:
loop.call_soon_threadsafe(
queue.put_nowait, thought_text
)
else:
# COT disabled - only send regular content
if regular_text:
loop.call_soon_threadsafe(queue.put_nowait, regular_text)
# Ensure COT is properly closed if still active
if cot_active:
loop.call_soon_threadsafe(queue.put_nowait, "</think>")
loop.call_soon_threadsafe(queue.put_nowait, None)
except Exception as exc: # pragma: no cover - surface runtime issues
# Try to close COT tag before reporting error
if cot_active:
try:
loop.call_soon_threadsafe(queue.put_nowait, "</think>")
except Exception:
pass
loop.call_soon_threadsafe(queue.put_nowait, exc)
loop.run_in_executor(None, _stream_model)
async def _async_stream() -> AsyncIterator[str]:
accumulated = ""
emitted = ""
try:
while True:
item = await queue.get()
@ -217,16 +328,9 @@ async def gemini_complete_if_cache(
if "\\u" in chunk_text:
chunk_text = safe_unicode_decode(chunk_text.encode("utf-8"))
accumulated += chunk_text
sanitized = remove_think_tags(accumulated)
if sanitized.startswith(emitted):
delta = sanitized[len(emitted) :]
else:
delta = sanitized
emitted = sanitized
if delta:
yield delta
# Yield the chunk directly without filtering
# COT filtering is already handled in _stream_model()
yield chunk_text
finally:
usage = usage_container.get("usage")
if token_tracker and usage:
@ -244,14 +348,33 @@ async def gemini_complete_if_cache(
response = await asyncio.to_thread(_call_model)
text = _extract_response_text(response)
if not text:
# Extract both regular text and thought text
regular_text, thought_text = _extract_response_text(response, extract_thoughts=True)
# Apply COT filtering logic based on enable_cot parameter
if enable_cot:
# Include thought content wrapped in <think> tags
if thought_text and thought_text.strip():
if not regular_text or regular_text.strip() == "":
# Only thought content available
final_text = f"<think>{thought_text}</think>"
else:
# Both content types present: prepend thought to regular content
final_text = f"<think>{thought_text}</think>{regular_text}"
else:
# No thought content, use regular content only
final_text = regular_text or ""
else:
# Filter out thought content, return only regular content
final_text = regular_text or ""
if not final_text:
raise RuntimeError("Gemini response did not contain any text content.")
if "\\u" in text:
text = safe_unicode_decode(text.encode("utf-8"))
if "\\u" in final_text:
final_text = safe_unicode_decode(final_text.encode("utf-8"))
text = remove_think_tags(text)
final_text = remove_think_tags(final_text)
usage = getattr(response, "usage_metadata", None)
if token_tracker and usage:
@ -263,8 +386,8 @@ async def gemini_complete_if_cache(
}
)
logger.debug("Gemini response length: %s", len(text))
return text
logger.debug("Gemini response length: %s", len(final_text))
return final_text
async def gemini_model_complete(
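
Taken together, the new parameters let the binding be called directly in the same style as the other LightRAG bindings. A hedged usage sketch based on the signature above (the module path lightrag.llm.gemini and the nested thinking_config keys are assumptions; with stream=True the call returns an async iterator, and with enable_cot=True thought chunks arrive wrapped in <think>...</think> ahead of the regular text):

import asyncio

from lightrag.llm.gemini import gemini_complete_if_cache  # assumed module path

async def demo() -> None:
    text = await gemini_complete_if_cache(
        "gemini-2.0-flash",
        "Summarize the document in one sentence.",
        system_prompt="You are a concise assistant.",
        enable_cot=True,   # keep thought parts, wrapped in <think>...</think>
        timeout=30,        # seconds; converted to milliseconds for the client
        generation_config={
            # Assumed shape: forwarded into types.GenerateContentConfig.
            "thinking_config": {"thinking_budget": 1024, "include_thoughts": True},
        },
    )
    print(text)

asyncio.run(demo())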

View file

@ -138,6 +138,9 @@ async def openai_complete_if_cache(
base_url: str | None = None,
api_key: str | None = None,
token_tracker: Any | None = None,
keyword_extraction: bool = False, # Will be removed from kwargs before passing to OpenAI
stream: bool | None = None,
timeout: int | None = None,
**kwargs: Any,
) -> str:
"""Complete a prompt using OpenAI's API with caching support and Chain of Thought (COT) integration.
@ -172,8 +175,9 @@ async def openai_complete_if_cache(
- openai_client_configs: Dict of configuration options for the AsyncOpenAI client.
These will be passed to the client constructor but will be overridden by
explicit parameters (api_key, base_url).
- hashing_kv: Will be removed from kwargs before passing to OpenAI.
- keyword_extraction: Will be removed from kwargs before passing to OpenAI.
- stream: Whether to stream the response. Default is False.
- timeout: Request timeout in seconds. Default is None.
Returns:
The completed text (with integrated COT content if available) or an async iterator