Fix lambda closure bug in embedding function configuration

• Replace lambda with proper async function
• Capture config values at creation time
• Avoid closure variable reference issues
• Add factory function for embeddings
• Remove test file for closure bug
This commit is contained in:
yangdx 2025-08-30 23:43:34 +08:00
parent 414d47d12a
commit 332202c111
2 changed files with 78 additions and 134 deletions

View file

@ -237,6 +237,7 @@ def create_app(args):
# Create working directory if it doesn't exist
Path(args.working_dir).mkdir(parents=True, exist_ok=True)
if args.llm_binding == "lollms" or args.embedding_binding == "lollms":
from lightrag.llm.lollms import lollms_model_complete, lollms_embed
if args.llm_binding == "ollama" or args.embedding_binding == "ollama":
@ -253,8 +254,6 @@ def create_app(args):
from lightrag.llm.binding_options import OpenAILLMOptions
if args.llm_binding == "aws_bedrock" or args.embedding_binding == "aws_bedrock":
from lightrag.llm.bedrock import bedrock_complete_if_cache, bedrock_embed
if args.embedding_binding == "ollama":
from lightrag.llm.binding_options import OllamaEmbeddingOptions
if args.embedding_binding == "jina":
from lightrag.llm.jina import jina_embed
@ -344,63 +343,86 @@ def create_app(args):
**kwargs,
)
embedding_binding = args.embedding_binding
embedding_model = args.embedding_model
embedding_host = args.embedding_binding_host
embedding_api_key = args.embedding_binding_api_key
embedding_dim_val = args.embedding_dim
ollama_options_val = OllamaEmbeddingOptions.options_dict(args)
def create_embedding_function(binding, model, host, api_key, dimensions, args):
    """
    Build an async embedding callable with all configuration bound as parameters.

    Binding the values as function arguments (instead of referencing outer
    variables from a closure) means later mutation of the server's config
    objects cannot change which provider, model, or host is used.

    Args:
        binding: Embedding provider identifier (lollms, ollama, azure_openai,
            aws_bedrock, jina, or anything else for the OpenAI-compatible default).
        model: Embedding model name.
        host: Base URL of the embedding service.
        api_key: API key for authentication.
        dimensions: Embedding vector dimensionality (used by the jina branch).
        args: Full arguments object, consulted only for per-call dynamic
            option generation (currently just the ollama binding).

    Returns:
        An async function ``f(texts)`` that embeds ``texts`` via the selected provider.
    """

    async def _embed_lollms(texts):
        return await lollms_embed(
            texts,
            embed_model=model,
            host=host,
            api_key=api_key,
        )

    async def _embed_ollama(texts):
        # Import lazily so the options class is only required when the
        # ollama binding is actually exercised; options are regenerated
        # from `args` on every call, matching the original behavior.
        from lightrag.llm.binding_options import OllamaEmbeddingOptions

        return await ollama_embed(
            texts,
            embed_model=model,
            host=host,
            api_key=api_key,
            options=OllamaEmbeddingOptions.options_dict(args),
        )

    async def _embed_azure_openai(texts):
        return await azure_openai_embed(
            texts,
            model=model,
            api_key=api_key,
        )

    async def _embed_bedrock(texts):
        return await bedrock_embed(
            texts,
            model=model,
        )

    async def _embed_jina(texts):
        return await jina_embed(
            texts,
            dimensions=dimensions,
            base_url=host,
            api_key=api_key,
        )

    async def _embed_openai_default(texts):
        # Fallback: any unrecognized binding is treated as OpenAI-compatible.
        return await openai_embed(
            texts,
            model=model,
            base_url=host,
            api_key=api_key,
        )

    _dispatch = {
        "lollms": _embed_lollms,
        "ollama": _embed_ollama,
        "azure_openai": _embed_azure_openai,
        "aws_bedrock": _embed_bedrock,
        "jina": _embed_jina,
    }

    async def embedding_function(texts):
        """Embed `texts` using the provider configuration captured at creation time."""
        handler = _dispatch.get(binding, _embed_openai_default)
        return await handler(texts)

    return embedding_function
# Create the embedding function with configuration captured at creation time.
# The values are passed as explicit parameters (not referenced through a
# closure), so mutating `args` afterwards cannot silently redirect embedding
# calls to a different provider, model, or host.
embedding_func = EmbeddingFunc(
    embedding_dim=args.embedding_dim,
    func=create_embedding_function(
        binding=args.embedding_binding,
        model=args.embedding_model,
        host=args.embedding_binding_host,
        api_key=args.embedding_binding_api_key,
        dimensions=args.embedding_dim,
        args=args,  # only used for per-call dynamic option generation (ollama)
    ),
)

View file

@ -1,78 +0,0 @@
"""
Tests the fix for the lambda closure bug in the API server's embedding function.
Issue: https://github.com/HKUDS/LightRAG/issues/2023
"""
import pytest
from unittest.mock import Mock, patch, AsyncMock
import numpy as np
# Functions to be patched
from lightrag.llm.ollama import ollama_embed
from lightrag.llm.openai import openai_embed
@pytest.fixture
def mock_args():
    """Return a mock of the server's parsed-argument namespace, preset for ollama."""
    # Dotted keys are routed through Mock's configure_mock(), which supports
    # configuring nested attributes and return values in one call.
    return Mock(
        embedding_binding="ollama",
        embedding_model="mxbai-embed-large:latest",
        embedding_binding_host="http://localhost:11434",
        embedding_binding_api_key=None,
        embedding_dim=1024,
        **{"OllamaEmbeddingOptions.options_dict.return_value": {"num_ctx": 4096}},
    )
@pytest.mark.asyncio
# FIX: patch the names in THIS module's namespace, not in the lightrag
# modules. `fixed_func` below calls the `ollama_embed` / `openai_embed`
# names imported into this test module at the top of the file; patching
# `lightrag.llm.ollama.ollama_embed` leaves those local bindings pointing
# at the real functions, so the mocks would never be awaited and the first
# `assert_awaited_once()` would fail (the classic "where to patch" mistake —
# see unittest.mock docs).
@patch(f"{__name__}.openai_embed", new_callable=AsyncMock)
@patch(f"{__name__}.ollama_embed", new_callable=AsyncMock)
async def test_embedding_func_captures_values_correctly(
    mock_ollama_embed, mock_openai_embed, mock_args
):
    """
    Verifies that the embedding function correctly captures configuration
    values at creation time and is not affected by later mutations of its source.
    """
    # --- Setup Mocks ---
    mock_ollama_embed.return_value = np.array([[0.1, 0.2, 0.3]])
    mock_openai_embed.return_value = np.array([[0.4, 0.5, 0.6]])

    # --- SIMULATE THE FIX: Capture values before creating the function ---
    binding = mock_args.embedding_binding
    model = mock_args.embedding_model
    host = mock_args.embedding_binding_host
    api_key = mock_args.embedding_binding_api_key

    # CORRECTED: Use an async def instead of a lambda. The provider names are
    # looked up in module globals at call time, which is what the patches above
    # replace.
    async def fixed_func(texts):
        if binding == "ollama":
            return await ollama_embed(
                texts, embed_model=model, host=host, api_key=api_key
            )
        else:
            return await openai_embed(
                texts, model=model, base_url=host, api_key=api_key
            )

    # --- VERIFICATION ---
    # 1. First call: The function should use the initial "ollama" binding.
    await fixed_func(["hello world"])
    mock_ollama_embed.assert_awaited_once()
    mock_openai_embed.assert_not_called()

    # 2. CRITICAL STEP: Mutate the original args object AFTER the function is created.
    mock_args.embedding_binding = "openai"

    # 3. Reset mocks and call the function AGAIN.
    mock_ollama_embed.reset_mock()
    mock_openai_embed.reset_mock()
    await fixed_func(["see you again"])

    # 4. Final check: The function should STILL call ollama_embed, proving the
    #    binding was captured at creation time and is immune to the mutation.
    mock_ollama_embed.assert_awaited_once()
    mock_openai_embed.assert_not_called()