fix(server): Resolve lambda closure bug in embedding_func

Fixes #2023. Resolves an issue where the embedding function would incorrectly fall back to the OpenAI provider if the server's configuration arguments were mutated after initialization. This was caused by a lambda function capturing a reference to the mutable 'args' object instead of capturing the configuration values at creation time.
This commit is contained in:
avchauzov 2025-08-30 14:38:04 +02:00
parent d9aa021682
commit 414d47d12a
2 changed files with 106 additions and 21 deletions

View file

@ -344,51 +344,58 @@ def create_app(args):
**kwargs,
)
embedding_binding = args.embedding_binding
embedding_model = args.embedding_model
embedding_host = args.embedding_binding_host
embedding_api_key = args.embedding_binding_api_key
embedding_dim_val = args.embedding_dim
ollama_options_val = OllamaEmbeddingOptions.options_dict(args)
embedding_func = EmbeddingFunc(
embedding_dim=args.embedding_dim,
func=lambda texts: (
lollms_embed(
texts,
embed_model=args.embedding_model,
host=args.embedding_binding_host,
api_key=args.embedding_binding_api_key,
embed_model=embedding_model,
host=embedding_host,
api_key=embedding_api_key,
)
if args.embedding_binding == "lollms"
if embedding_binding == "lollms"
else (
ollama_embed(
texts,
embed_model=args.embedding_model,
host=args.embedding_binding_host,
api_key=args.embedding_binding_api_key,
options=OllamaEmbeddingOptions.options_dict(args),
embed_model=embedding_model,
host=embedding_host,
api_key=embedding_api_key,
options=ollama_options_val,
)
if args.embedding_binding == "ollama"
if embedding_binding == "ollama"
else (
azure_openai_embed(
texts,
model=args.embedding_model, # no host is used for openai,
api_key=args.embedding_binding_api_key,
model=embedding_model, # no host is used for openai,
api_key=embedding_api_key,
)
if args.embedding_binding == "azure_openai"
if embedding_binding == "azure_openai"
else (
bedrock_embed(
texts,
model=args.embedding_model,
model=embedding_model,
)
if args.embedding_binding == "aws_bedrock"
if embedding_binding == "aws_bedrock"
else (
jina_embed(
texts,
dimensions=args.embedding_dim,
base_url=args.embedding_binding_host,
api_key=args.embedding_binding_api_key,
dimensions=embedding_dim_val,
base_url=embedding_host,
api_key=embedding_api_key,
)
if args.embedding_binding == "jina"
if embedding_binding == "jina"
else openai_embed(
texts,
model=args.embedding_model,
base_url=args.embedding_binding_host,
api_key=args.embedding_binding_api_key,
model=embedding_model,
base_url=embedding_host,
api_key=embedding_api_key,
)
)
)

View file

@ -0,0 +1,78 @@
"""
Tests the fix for the lambda closure bug in the API server's embedding function.
Issue: https://github.com/HKUDS/LightRAG/issues/2023
"""
import pytest
from unittest.mock import Mock, patch, AsyncMock
import numpy as np
# Functions to be patched
from lightrag.llm.ollama import ollama_embed
from lightrag.llm.openai import openai_embed
@pytest.fixture
def mock_args():
"""Provides a mock of the server's arguments object."""
args = Mock()
args.embedding_binding = "ollama"
args.embedding_model = "mxbai-embed-large:latest"
args.embedding_binding_host = "http://localhost:11434"
args.embedding_binding_api_key = None
args.embedding_dim = 1024
args.OllamaEmbeddingOptions.options_dict.return_value = {"num_ctx": 4096}
return args
@pytest.mark.asyncio
@patch("lightrag.llm.openai.openai_embed", new_callable=AsyncMock)
@patch("lightrag.llm.ollama.ollama_embed", new_callable=AsyncMock)
async def test_embedding_func_captures_values_correctly(
mock_ollama_embed, mock_openai_embed, mock_args
):
"""
Verifies that the embedding function correctly captures configuration
values at creation time and is not affected by later mutations of its source.
"""
# --- Setup Mocks ---
mock_ollama_embed.return_value = np.array([[0.1, 0.2, 0.3]])
mock_openai_embed.return_value = np.array([[0.4, 0.5, 0.6]])
# --- SIMULATE THE FIX: Capture values before creating the function ---
binding = mock_args.embedding_binding
model = mock_args.embedding_model
host = mock_args.embedding_binding_host
api_key = mock_args.embedding_binding_api_key
# CORRECTED: Use an async def instead of a lambda
async def fixed_func(texts):
if binding == "ollama":
return await ollama_embed(
texts, embed_model=model, host=host, api_key=api_key
)
else:
return await openai_embed(
texts, model=model, base_url=host, api_key=api_key
)
# --- VERIFICATION ---
# 1. First call: The function should use the initial "ollama" binding.
await fixed_func(["hello world"])
mock_ollama_embed.assert_awaited_once()
mock_openai_embed.assert_not_called()
# 2. CRITICAL STEP: Mutate the original args object AFTER the function is created.
mock_args.embedding_binding = "openai"
# 3. Reset mocks and call the function AGAIN.
mock_ollama_embed.reset_mock()
mock_openai_embed.reset_mock()
await fixed_func(["see you again"])
# 4. Final check: The function should STILL call ollama_embed.
mock_ollama_embed.assert_awaited_once()
mock_openai_embed.assert_not_called()