commit a4b3da862b (parent 68cc386456)
Author: Raphaël MANSUY, 2025-12-04 19:14:31 +08:00
6 changed files with 164 additions and 82 deletions

View file

@@ -26,6 +26,7 @@ from lightrag.utils import (
safe_unicode_decode,
logger,
)
+ from lightrag.types import GPTKeywordExtractionFormat
import numpy as np
@@ -46,6 +47,7 @@ async def azure_openai_complete_if_cache(
base_url: str | None = None,
api_key: str | None = None,
api_version: str | None = None,
+ keyword_extraction: bool = False,
**kwargs,
):
if enable_cot:
@@ -66,9 +68,12 @@ async def azure_openai_complete_if_cache(
)
kwargs.pop("hashing_kv", None)
kwargs.pop("keyword_extraction", None)
timeout = kwargs.pop("timeout", None)
+ # Handle keyword extraction mode
+ if keyword_extraction:
+ kwargs["response_format"] = GPTKeywordExtractionFormat
openai_async_client = AsyncAzureOpenAI(
azure_endpoint=base_url,
azure_deployment=deployment,
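For reference, the GPTKeywordExtractionFormat imported above is a small Pydantic schema; a sketch of its shape (field names assumed from LightRAG's keyword-extraction prompt, see lightrag/types.py for the actual definition):

```python
# Sketch of the structured-output model assumed by the diff above.
from pydantic import BaseModel

class GPTKeywordExtractionFormat(BaseModel):
    high_level_keywords: list[str]  # thematic, abstract keywords
    low_level_keywords: list[str]   # concrete entities and details
```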
@@ -85,7 +90,7 @@ async def azure_openai_complete_if_cache(
messages.append({"role": "user", "content": prompt})
if "response_format" in kwargs:
- response = await openai_async_client.beta.chat.completions.parse(
+ response = await openai_async_client.chat.completions.parse(
model=model, messages=messages, **kwargs
)
else:
@@ -108,21 +113,32 @@ async def azure_openai_complete_if_cache(
return inner()
else:
- content = response.choices[0].message.content
- if r"\u" in content:
- content = safe_unicode_decode(content.encode("utf-8"))
+ message = response.choices[0].message
+ # Handle parsed responses (structured output via response_format)
+ # When using chat.completions.parse(), the response is in message.parsed
+ if hasattr(message, "parsed") and message.parsed is not None:
+ # Serialize the parsed structured response to JSON
+ content = message.parsed.model_dump_json()
+ logger.debug("Using parsed structured response from API")
+ else:
+ # Handle regular content responses
+ content = message.content
+ if content and r"\u" in content:
+ content = safe_unicode_decode(content.encode("utf-8"))
return content
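The parse-and-fallback pattern above can be exercised in isolation. A minimal sketch (deployment name and prompt are illustrative; credentials are assumed to come from the AZURE_OPENAI_* environment variables; with openai>=2.x, chat.completions.parse() validates the response against the Pydantic class and exposes it on message.parsed):

```python
from openai import AsyncAzureOpenAI
from lightrag.types import GPTKeywordExtractionFormat

async def demo_parse() -> str:
    client = AsyncAzureOpenAI()  # endpoint/key/api_version read from env
    response = await client.chat.completions.parse(
        model="gpt-4o-mini",  # illustrative Azure deployment name
        messages=[{"role": "user", "content": "Extract keywords from: ..."}],
        response_format=GPTKeywordExtractionFormat,
    )
    message = response.choices[0].message
    # Prefer the validated object; fall back to raw text content.
    if message.parsed is not None:
        return message.parsed.model_dump_json()
    return message.content or ""
```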
async def azure_openai_complete(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
kwargs.pop("keyword_extraction", None)
result = await azure_openai_complete_if_cache(
os.getenv("LLM_MODEL", "gpt-4o-mini"),
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
+ keyword_extraction=keyword_extraction,
**kwargs,
)
return result
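A hypothetical call illustrating the new pass-through (LLM_MODEL and Azure credentials assumed to be configured in the environment):

```python
# keyword_extraction=True now flows into azure_openai_complete_if_cache,
# which sets response_format=GPTKeywordExtractionFormat before the API call.
keywords_json = await azure_openai_complete(
    "Identify the key topics in: graph-based retrieval-augmented generation",
    keyword_extraction=True,
)
```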

View file

@@ -47,7 +47,7 @@ try:
# Only enable Langfuse if both keys are configured
if langfuse_public_key and langfuse_secret_key:
- from langfuse.openai import AsyncOpenAI
+ from langfuse.openai import AsyncOpenAI  # type: ignore[import-untyped]
LANGFUSE_ENABLED = True
logger.info("Langfuse observability enabled for OpenAI client")
@@ -140,6 +140,7 @@ async def openai_complete_if_cache(
token_tracker: Any | None = None,
stream: bool | None = None,
timeout: int | None = None,
+ keyword_extraction: bool = False,
**kwargs: Any,
) -> str:
"""Complete a prompt using OpenAI's API with caching support and Chain of Thought (COT) integration.
@@ -171,12 +172,13 @@ async def openai_complete_if_cache(
enable_cot: Whether to enable Chain of Thought (COT) processing. Default is False.
stream: Whether to stream the response. Default is False.
timeout: Request timeout in seconds. Default is None.
+ keyword_extraction: Whether to enable keyword extraction mode. When True, triggers
+ special response formatting for keyword extraction. Default is False.
**kwargs: Additional keyword arguments to pass to the OpenAI API.
Special kwargs:
- openai_client_configs: Dict of configuration options for the AsyncOpenAI client.
These will be passed to the client constructor but will be overridden by
explicit parameters (api_key, base_url).
- - keyword_extraction: Will be removed from kwargs before passing to OpenAI.
Returns:
The completed text (with integrated COT content if available) or an async iterator
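An illustrative direct call under the new signature (model, prompt, and client options are hypothetical; explicit parameters such as api_key/base_url take precedence over openai_client_configs, as documented above):

```python
result = await openai_complete_if_cache(
    "gpt-4o-mini",
    "Extract keywords from: vector databases vs. knowledge graphs",
    system_prompt="You are a keyword extraction assistant.",
    keyword_extraction=True,                   # switches on structured output
    openai_client_configs={"max_retries": 2},  # client defaults, overridable
    timeout=30,
)
```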
@@ -197,11 +199,14 @@ async def openai_complete_if_cache(
# Remove special kwargs that shouldn't be passed to OpenAI
kwargs.pop("hashing_kv", None)
kwargs.pop("keyword_extraction", None)
# Extract client configuration options
client_configs = kwargs.pop("openai_client_configs", {})
+ # Handle keyword extraction mode
+ if keyword_extraction:
+ kwargs["response_format"] = GPTKeywordExtractionFormat
# Create the OpenAI client
openai_async_client = create_openai_async_client(
api_key=api_key,
@@ -236,7 +241,7 @@ async def openai_complete_if_cache(
try:
# Don't use async with context manager, use client directly
if "response_format" in kwargs:
- response = await openai_async_client.beta.chat.completions.parse(
+ response = await openai_async_client.chat.completions.parse(
model=model, messages=messages, **kwargs
)
else:
@@ -448,46 +453,57 @@ async def openai_complete_if_cache(
raise InvalidResponseError("Invalid response from OpenAI API")
message = response.choices[0].message
content = getattr(message, "content", None)
reasoning_content = getattr(message, "reasoning_content", "")
# Handle COT logic for non-streaming responses (only if enabled)
final_content = ""
# Handle parsed responses (structured output via response_format)
# When using chat.completions.parse(), the response is in message.parsed
if hasattr(message, "parsed") and message.parsed is not None:
# Serialize the parsed structured response to JSON
final_content = message.parsed.model_dump_json()
logger.debug("Using parsed structured response from API")
else:
# Handle regular content responses
content = getattr(message, "content", None)
reasoning_content = getattr(message, "reasoning_content", "")
if enable_cot:
# Check if we should include reasoning content
should_include_reasoning = False
if reasoning_content and reasoning_content.strip():
if not content or content.strip() == "":
# Case 1: Only reasoning content, should include COT
should_include_reasoning = True
final_content = (
content or ""
) # Use empty string if content is None
# Handle COT logic for non-streaming responses (only if enabled)
final_content = ""
if enable_cot:
# Check if we should include reasoning content
should_include_reasoning = False
if reasoning_content and reasoning_content.strip():
if not content or content.strip() == "":
# Case 1: Only reasoning content, should include COT
should_include_reasoning = True
final_content = (
content or ""
) # Use empty string if content is None
else:
# Case 3: Both content and reasoning_content present, ignore reasoning
should_include_reasoning = False
final_content = content
else:
# Case 3: Both content and reasoning_content present, ignore reasoning
should_include_reasoning = False
final_content = content
# No reasoning content, use regular content
final_content = content or ""
# Apply COT wrapping if needed
if should_include_reasoning:
if r"\u" in reasoning_content:
reasoning_content = safe_unicode_decode(
reasoning_content.encode("utf-8")
)
final_content = (
f"<think>{reasoning_content}</think>{final_content}"
)
else:
# No reasoning content, use regular content
# COT disabled, only use regular content
final_content = content or ""
# Apply COT wrapping if needed
if should_include_reasoning:
if r"\u" in reasoning_content:
reasoning_content = safe_unicode_decode(
reasoning_content.encode("utf-8")
)
final_content = f"<think>{reasoning_content}</think>{final_content}"
else:
# COT disabled, only use regular content
final_content = content or ""
# Validate final content
if not final_content or final_content.strip() == "":
logger.error("Received empty content from OpenAI API")
await openai_async_client.close() # Ensure client is closed
raise InvalidResponseError("Received empty content from OpenAI API")
# Validate final content
if not final_content or final_content.strip() == "":
logger.error("Received empty content from OpenAI API")
await openai_async_client.close() # Ensure client is closed
raise InvalidResponseError("Received empty content from OpenAI API")
# Apply Unicode decoding to final content if needed
if r"\u" in final_content:
@@ -521,15 +537,13 @@ async def openai_complete(
) -> Union[str, AsyncIterator[str]]:
if history_messages is None:
history_messages = []
- keyword_extraction = kwargs.pop("keyword_extraction", None)
- if keyword_extraction:
- kwargs["response_format"] = "json"
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
return await openai_complete_if_cache(
model_name,
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
+ keyword_extraction=keyword_extraction,
**kwargs,
)
@@ -544,15 +558,13 @@ async def gpt_4o_complete(
) -> str:
if history_messages is None:
history_messages = []
- keyword_extraction = kwargs.pop("keyword_extraction", None)
- if keyword_extraction:
- kwargs["response_format"] = GPTKeywordExtractionFormat
return await openai_complete_if_cache(
"gpt-4o",
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
enable_cot=enable_cot,
+ keyword_extraction=keyword_extraction,
**kwargs,
)
@@ -567,15 +579,13 @@ async def gpt_4o_mini_complete(
) -> str:
if history_messages is None:
history_messages = []
- keyword_extraction = kwargs.pop("keyword_extraction", None)
- if keyword_extraction:
- kwargs["response_format"] = GPTKeywordExtractionFormat
return await openai_complete_if_cache(
"gpt-4o-mini",
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
enable_cot=enable_cot,
+ keyword_extraction=keyword_extraction,
**kwargs,
)
@@ -590,20 +600,20 @@ async def nvidia_openai_complete(
) -> str:
if history_messages is None:
history_messages = []
kwargs.pop("keyword_extraction", None)
result = await openai_complete_if_cache(
"nvidia/llama-3.1-nemotron-70b-instruct", # context length 128k
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
enable_cot=enable_cot,
+ keyword_extraction=keyword_extraction,
base_url="https://integrate.api.nvidia.com/v1",
**kwargs,
)
return result
- @wrap_embedding_func_with_attrs(embedding_dim=1536)
+ @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=60),
@@ -618,6 +628,7 @@ async def openai_embed(
model: str = "text-embedding-3-small",
base_url: str | None = None,
api_key: str | None = None,
+ embedding_dim: int | None = None,
client_configs: dict[str, Any] | None = None,
token_tracker: Any | None = None,
) -> np.ndarray:
@@ -628,6 +639,12 @@ async def openai_embed(
model: The OpenAI embedding model to use.
base_url: Optional base URL for the OpenAI API.
api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
+ embedding_dim: Optional embedding dimension for dynamic dimension reduction.
+ **IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper.
+ Do NOT manually pass this parameter when calling the function directly.
+ The dimension is controlled by the @wrap_embedding_func_with_attrs decorator.
+ Manually passing a different value will trigger a warning and be ignored.
+ When provided (by EmbeddingFunc), it will be passed to the OpenAI API for dimension reduction.
client_configs: Additional configuration options for the AsyncOpenAI client.
These will override any default configurations but will be overridden by
explicit parameters (api_key, base_url).
@@ -647,9 +664,19 @@
)
async with openai_async_client:
- response = await openai_async_client.embeddings.create(
- model=model, input=texts, encoding_format="base64"
- )
+ # Prepare API call parameters
+ api_params = {
+ "model": model,
+ "input": texts,
+ "encoding_format": "base64",
+ }
+ # Add dimensions parameter only if embedding_dim is provided
+ if embedding_dim is not None:
+ api_params["dimensions"] = embedding_dim
+ # Make API call
+ response = await openai_async_client.embeddings.create(**api_params)
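A hedged sketch of the same pattern in isolation (model and dimension are illustrative; text-embedding-3-* models accept a dimensions argument for native reduction, while older models reject it, hence the conditional):

```python
import numpy as np
from openai import AsyncOpenAI

async def embed(texts: list[str], embedding_dim: int | None = None) -> np.ndarray:
    client = AsyncOpenAI()
    params: dict = {"model": "text-embedding-3-small", "input": texts}
    if embedding_dim is not None:
        params["dimensions"] = embedding_dim  # e.g. 512 instead of the default 1536
    response = await client.embeddings.create(**params)
    # Without encoding_format="base64", each item.embedding is a list of floats.
    return np.array([item.embedding for item in response.data])
```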
if token_tracker and hasattr(response, "usage"):
token_counts = {

View file

@@ -29,8 +29,8 @@ dependencies = [
"json_repair",
"nano-vectordb",
"networkx",
"numpy",
"pandas>=2.0.0,<2.3.0",
"numpy>=1.24.0,<2.0.0",
"pandas>=2.0.0,<2.4.0",
"pipmaster",
"pydantic",
"pypinyin",
@@ -42,6 +42,13 @@ dependencies = [
]
[project.optional-dependencies]
+ # Test framework dependencies (for CI/CD and testing)
+ pytest = [
+ "pytest>=8.4.2",
+ "pytest-asyncio>=1.2.0",
+ "pre-commit",
+ ]
api = [
# Core dependencies
"aiohttp",
@@ -50,9 +57,9 @@ api = [
"json_repair",
"nano-vectordb",
"networkx",
"numpy",
"openai>=1.0.0,<3.0.0",
"pandas>=2.0.0,<2.3.0",
"numpy>=1.24.0,<2.0.0",
"openai>=2.0.0,<3.0.0",
"pandas>=2.0.0,<2.4.0",
"pipmaster",
"pydantic",
"pypinyin",
@@ -79,30 +86,36 @@ api = [
"python-multipart",
"pytz",
"uvicorn",
"gunicorn",
# Document processing dependencies (required for API document upload functionality)
"openpyxl>=3.0.0,<4.0.0", # XLSX processing
"pycryptodome>=3.0.0,<4.0.0", # PDF encryption support
"pypdf>=6.1.0", # PDF processing
"python-docx>=0.8.11,<2.0.0", # DOCX processing
"python-pptx>=0.6.21,<2.0.0", # PPTX processing
]
+ # Advanced document processing engine (optional)
+ docling = [
+ # On macOS, PyTorch and frameworks that use Objective-C are not fork-safe
+ # and are not compatible with gunicorn's multi-worker mode
+ "docling>=2.0.0,<3.0.0; sys_platform != 'darwin'",
+ ]
# Offline deployment dependencies (layered design for flexibility)
offline-docs = [
# Document processing dependencies
"pypdf2>=3.0.0",
"python-docx>=0.8.11,<2.0.0",
"python-pptx>=0.6.21,<2.0.0",
"openpyxl>=3.0.0,<4.0.0",
]
offline-storage = [
# Storage backend dependencies
"redis>=5.0.0,<7.0.0",
"redis>=5.0.0,<8.0.0",
"neo4j>=5.0.0,<7.0.0",
"pymilvus>=2.6.2,<3.0.0",
"pymongo>=4.0.0,<5.0.0",
"asyncpg>=0.29.0,<1.0.0",
"qdrant-client>=1.7.0,<2.0.0",
"qdrant-client>=1.11.0,<2.0.0",
]
offline-llm = [
# LLM provider dependencies
"openai>=1.0.0,<3.0.0",
"openai>=2.0.0,<3.0.0",
"anthropic>=0.18.0,<1.0.0",
"ollama>=0.1.0,<1.0.0",
"zhipuai>=2.0.0,<3.0.0",
@@ -114,8 +127,24 @@ offline-llm = [
]
offline = [
- # Complete offline package (includes all offline dependencies)
- "lightrag-hku[offline-docs,offline-storage,offline-llm]",
+ # Complete offline package (includes api for document processing, plus storage and LLM)
+ "lightrag-hku[api,offline-storage,offline-llm]",
]
+ evaluation = [
+ # Test framework dependencies (for evaluation)
+ "pytest>=8.4.2",
+ "pytest-asyncio>=1.2.0",
+ "pre-commit",
+ # RAG evaluation dependencies (RAGAS framework)
+ "ragas>=0.3.7",
+ "datasets>=4.3.0",
+ "httpx>=0.28.1",
+ ]
+ observability = [
+ # LLM observability and tracing dependencies
+ "langfuse>=3.8.1",
+ ]
[project.scripts]
@@ -140,7 +169,15 @@ include-package-data = true
version = {attr = "lightrag.__version__"}
[tool.setuptools.package-data]
lightrag = ["api/webui/**/*"]
lightrag = ["api/webui/**/*", "api/static/**/*"]
+ [tool.pytest.ini_options]
+ asyncio_mode = "auto"
+ asyncio_default_fixture_loop_scope = "function"
+ testpaths = ["tests"]
+ python_files = ["test_*.py"]
+ python_classes = ["Test*"]
+ python_functions = ["test_*"]
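With asyncio_mode = "auto", coroutine tests under tests/ are collected and awaited without an explicit marker. A hypothetical example matched by the testpaths/python_files settings above:

```python
# tests/test_smoke.py (hypothetical)
import asyncio

async def test_event_loop_runs():  # no @pytest.mark.asyncio needed in auto mode
    await asyncio.sleep(0)
    assert True
```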
[tool.ruff]
target-version = "py310"

View file

@@ -14,6 +14,6 @@ google-api-core>=2.0.0,<3.0.0
google-genai>=1.0.0,<2.0.0
llama-index>=0.9.0,<1.0.0
ollama>=0.1.0,<1.0.0
- openai>=1.0.0,<3.0.0
+ openai>=2.0.0,<3.0.0
voyageai>=0.2.0,<1.0.0
zhipuai>=2.0.0,<3.0.0

View file

@@ -18,7 +18,7 @@ asyncpg>=0.29.0,<1.0.0
llama-index>=0.9.0,<1.0.0
neo4j>=5.0.0,<7.0.0
ollama>=0.1.0,<1.0.0
- openai>=1.0.0,<3.0.0
+ openai>=2.0.0,<3.0.0
openpyxl>=3.0.0,<4.0.0
pymilvus>=2.6.2,<3.0.0
pymongo>=4.0.0,<5.0.0

uv.lock generated
View file

@@ -2737,7 +2737,6 @@ requires-dist = [
{ name = "json-repair", marker = "extra == 'api'" },
{ name = "langfuse", marker = "extra == 'observability'", specifier = ">=3.8.1" },
{ name = "lightrag-hku", extras = ["api", "offline-llm", "offline-storage"], marker = "extra == 'offline'" },
{ name = "lightrag-hku", extras = ["pytest"], marker = "extra == 'evaluation'" },
{ name = "llama-index", marker = "extra == 'offline-llm'", specifier = ">=0.9.0,<1.0.0" },
{ name = "nano-vectordb" },
{ name = "nano-vectordb", marker = "extra == 'api'" },
@@ -2747,14 +2746,15 @@ requires-dist = [
{ name = "numpy", specifier = ">=1.24.0,<2.0.0" },
{ name = "numpy", marker = "extra == 'api'", specifier = ">=1.24.0,<2.0.0" },
{ name = "ollama", marker = "extra == 'offline-llm'", specifier = ">=0.1.0,<1.0.0" },
{ name = "openai", marker = "extra == 'api'", specifier = ">=1.0.0,<3.0.0" },
{ name = "openai", marker = "extra == 'offline-llm'", specifier = ">=1.0.0,<3.0.0" },
{ name = "openai", marker = "extra == 'api'", specifier = ">=2.0.0,<3.0.0" },
{ name = "openai", marker = "extra == 'offline-llm'", specifier = ">=2.0.0,<3.0.0" },
{ name = "openpyxl", marker = "extra == 'api'", specifier = ">=3.0.0,<4.0.0" },
{ name = "pandas", specifier = ">=2.0.0,<2.4.0" },
{ name = "pandas", marker = "extra == 'api'", specifier = ">=2.0.0,<2.4.0" },
{ name = "passlib", extras = ["bcrypt"], marker = "extra == 'api'" },
{ name = "pipmaster" },
{ name = "pipmaster", marker = "extra == 'api'" },
{ name = "pre-commit", marker = "extra == 'evaluation'" },
{ name = "pre-commit", marker = "extra == 'pytest'" },
{ name = "psutil", marker = "extra == 'api'" },
{ name = "pycryptodome", marker = "extra == 'api'", specifier = ">=3.0.0,<4.0.0" },
@@ -2766,7 +2766,9 @@ requires-dist = [
{ name = "pypdf", marker = "extra == 'api'", specifier = ">=6.1.0" },
{ name = "pypinyin" },
{ name = "pypinyin", marker = "extra == 'api'" },
{ name = "pytest", marker = "extra == 'evaluation'", specifier = ">=8.4.2" },
{ name = "pytest", marker = "extra == 'pytest'", specifier = ">=8.4.2" },
{ name = "pytest-asyncio", marker = "extra == 'evaluation'", specifier = ">=1.2.0" },
{ name = "pytest-asyncio", marker = "extra == 'pytest'", specifier = ">=1.2.0" },
{ name = "python-docx", marker = "extra == 'api'", specifier = ">=0.8.11,<2.0.0" },
{ name = "python-dotenv" },