Fix lambda closure bug in embedding function configuration
• Replace lambda with proper async function • Capture config values at creation time • Avoid closure variable reference issues • Add factory function for embeddings • Remove test file for closure bug
This commit is contained in:
parent
414d47d12a
commit
332202c111
2 changed files with 78 additions and 134 deletions
|
|
@ -237,6 +237,7 @@ def create_app(args):
|
||||||
|
|
||||||
# Create working directory if it doesn't exist
|
# Create working directory if it doesn't exist
|
||||||
Path(args.working_dir).mkdir(parents=True, exist_ok=True)
|
Path(args.working_dir).mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
if args.llm_binding == "lollms" or args.embedding_binding == "lollms":
|
if args.llm_binding == "lollms" or args.embedding_binding == "lollms":
|
||||||
from lightrag.llm.lollms import lollms_model_complete, lollms_embed
|
from lightrag.llm.lollms import lollms_model_complete, lollms_embed
|
||||||
if args.llm_binding == "ollama" or args.embedding_binding == "ollama":
|
if args.llm_binding == "ollama" or args.embedding_binding == "ollama":
|
||||||
|
|
@ -253,8 +254,6 @@ def create_app(args):
|
||||||
from lightrag.llm.binding_options import OpenAILLMOptions
|
from lightrag.llm.binding_options import OpenAILLMOptions
|
||||||
if args.llm_binding == "aws_bedrock" or args.embedding_binding == "aws_bedrock":
|
if args.llm_binding == "aws_bedrock" or args.embedding_binding == "aws_bedrock":
|
||||||
from lightrag.llm.bedrock import bedrock_complete_if_cache, bedrock_embed
|
from lightrag.llm.bedrock import bedrock_complete_if_cache, bedrock_embed
|
||||||
if args.embedding_binding == "ollama":
|
|
||||||
from lightrag.llm.binding_options import OllamaEmbeddingOptions
|
|
||||||
if args.embedding_binding == "jina":
|
if args.embedding_binding == "jina":
|
||||||
from lightrag.llm.jina import jina_embed
|
from lightrag.llm.jina import jina_embed
|
||||||
|
|
||||||
|
|
@ -344,63 +343,86 @@ def create_app(args):
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
embedding_binding = args.embedding_binding
|
def create_embedding_function(binding, model, host, api_key, dimensions, args):
|
||||||
embedding_model = args.embedding_model
|
"""
|
||||||
embedding_host = args.embedding_binding_host
|
Create embedding function with args object for dynamic option generation.
|
||||||
embedding_api_key = args.embedding_binding_api_key
|
|
||||||
embedding_dim_val = args.embedding_dim
|
|
||||||
ollama_options_val = OllamaEmbeddingOptions.options_dict(args)
|
|
||||||
|
|
||||||
|
This approach completely avoids closure issues by capturing configuration
|
||||||
|
values as function parameters rather than through variable references.
|
||||||
|
The args object is used only for dynamic option generation when needed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
binding: The embedding provider binding (lollms, ollama, etc.)
|
||||||
|
model: The embedding model name
|
||||||
|
host: The host URL for the embedding service
|
||||||
|
api_key: API key for authentication
|
||||||
|
dimensions: Embedding dimensions
|
||||||
|
args: Arguments object for dynamic option generation (only used when needed)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Async function that performs embedding based on the specified provider
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def embedding_function(texts):
|
||||||
|
"""Embedding function with captured configuration parameters"""
|
||||||
|
if binding == "lollms":
|
||||||
|
return await lollms_embed(
|
||||||
|
texts,
|
||||||
|
embed_model=model,
|
||||||
|
host=host,
|
||||||
|
api_key=api_key,
|
||||||
|
)
|
||||||
|
elif binding == "ollama":
|
||||||
|
# Only import and generate ollama_options when actually needed
|
||||||
|
from lightrag.llm.binding_options import OllamaEmbeddingOptions
|
||||||
|
|
||||||
|
ollama_options = OllamaEmbeddingOptions.options_dict(args)
|
||||||
|
return await ollama_embed(
|
||||||
|
texts,
|
||||||
|
embed_model=model,
|
||||||
|
host=host,
|
||||||
|
api_key=api_key,
|
||||||
|
options=ollama_options,
|
||||||
|
)
|
||||||
|
elif binding == "azure_openai":
|
||||||
|
return await azure_openai_embed(
|
||||||
|
texts,
|
||||||
|
model=model,
|
||||||
|
api_key=api_key,
|
||||||
|
)
|
||||||
|
elif binding == "aws_bedrock":
|
||||||
|
return await bedrock_embed(
|
||||||
|
texts,
|
||||||
|
model=model,
|
||||||
|
)
|
||||||
|
elif binding == "jina":
|
||||||
|
return await jina_embed(
|
||||||
|
texts,
|
||||||
|
dimensions=dimensions,
|
||||||
|
base_url=host,
|
||||||
|
api_key=api_key,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Default to OpenAI-compatible embedding
|
||||||
|
return await openai_embed(
|
||||||
|
texts,
|
||||||
|
model=model,
|
||||||
|
base_url=host,
|
||||||
|
api_key=api_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
return embedding_function
|
||||||
|
|
||||||
|
# Create embedding function with current configuration
|
||||||
embedding_func = EmbeddingFunc(
|
embedding_func = EmbeddingFunc(
|
||||||
embedding_dim=args.embedding_dim,
|
embedding_dim=args.embedding_dim,
|
||||||
func=lambda texts: (
|
func=create_embedding_function(
|
||||||
lollms_embed(
|
binding=args.embedding_binding,
|
||||||
texts,
|
model=args.embedding_model,
|
||||||
embed_model=embedding_model,
|
host=args.embedding_binding_host,
|
||||||
host=embedding_host,
|
api_key=args.embedding_binding_api_key,
|
||||||
api_key=embedding_api_key,
|
dimensions=args.embedding_dim,
|
||||||
)
|
args=args, # Pass args object for dynamic option generation
|
||||||
if embedding_binding == "lollms"
|
|
||||||
else (
|
|
||||||
ollama_embed(
|
|
||||||
texts,
|
|
||||||
embed_model=embedding_model,
|
|
||||||
host=embedding_host,
|
|
||||||
api_key=embedding_api_key,
|
|
||||||
options=ollama_options_val,
|
|
||||||
)
|
|
||||||
if embedding_binding == "ollama"
|
|
||||||
else (
|
|
||||||
azure_openai_embed(
|
|
||||||
texts,
|
|
||||||
model=embedding_model, # no host is used for openai,
|
|
||||||
api_key=embedding_api_key,
|
|
||||||
)
|
|
||||||
if embedding_binding == "azure_openai"
|
|
||||||
else (
|
|
||||||
bedrock_embed(
|
|
||||||
texts,
|
|
||||||
model=embedding_model,
|
|
||||||
)
|
|
||||||
if embedding_binding == "aws_bedrock"
|
|
||||||
else (
|
|
||||||
jina_embed(
|
|
||||||
texts,
|
|
||||||
dimensions=embedding_dim_val,
|
|
||||||
base_url=embedding_host,
|
|
||||||
api_key=embedding_api_key,
|
|
||||||
)
|
|
||||||
if embedding_binding == "jina"
|
|
||||||
else openai_embed(
|
|
||||||
texts,
|
|
||||||
model=embedding_model,
|
|
||||||
base_url=embedding_host,
|
|
||||||
api_key=embedding_api_key,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,78 +0,0 @@
|
||||||
"""
|
|
||||||
Tests the fix for the lambda closure bug in the API server's embedding function.
|
|
||||||
|
|
||||||
Issue: https://github.com/HKUDS/LightRAG/issues/2023
|
|
||||||
"""
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from unittest.mock import Mock, patch, AsyncMock
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
# Functions to be patched
|
|
||||||
from lightrag.llm.ollama import ollama_embed
|
|
||||||
from lightrag.llm.openai import openai_embed
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_args():
|
|
||||||
"""Provides a mock of the server's arguments object."""
|
|
||||||
args = Mock()
|
|
||||||
args.embedding_binding = "ollama"
|
|
||||||
args.embedding_model = "mxbai-embed-large:latest"
|
|
||||||
args.embedding_binding_host = "http://localhost:11434"
|
|
||||||
args.embedding_binding_api_key = None
|
|
||||||
args.embedding_dim = 1024
|
|
||||||
args.OllamaEmbeddingOptions.options_dict.return_value = {"num_ctx": 4096}
|
|
||||||
return args
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@patch("lightrag.llm.openai.openai_embed", new_callable=AsyncMock)
|
|
||||||
@patch("lightrag.llm.ollama.ollama_embed", new_callable=AsyncMock)
|
|
||||||
async def test_embedding_func_captures_values_correctly(
|
|
||||||
mock_ollama_embed, mock_openai_embed, mock_args
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Verifies that the embedding function correctly captures configuration
|
|
||||||
values at creation time and is not affected by later mutations of its source.
|
|
||||||
"""
|
|
||||||
# --- Setup Mocks ---
|
|
||||||
mock_ollama_embed.return_value = np.array([[0.1, 0.2, 0.3]])
|
|
||||||
mock_openai_embed.return_value = np.array([[0.4, 0.5, 0.6]])
|
|
||||||
|
|
||||||
# --- SIMULATE THE FIX: Capture values before creating the function ---
|
|
||||||
binding = mock_args.embedding_binding
|
|
||||||
model = mock_args.embedding_model
|
|
||||||
host = mock_args.embedding_binding_host
|
|
||||||
api_key = mock_args.embedding_binding_api_key
|
|
||||||
|
|
||||||
# CORRECTED: Use an async def instead of a lambda
|
|
||||||
async def fixed_func(texts):
|
|
||||||
if binding == "ollama":
|
|
||||||
return await ollama_embed(
|
|
||||||
texts, embed_model=model, host=host, api_key=api_key
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return await openai_embed(
|
|
||||||
texts, model=model, base_url=host, api_key=api_key
|
|
||||||
)
|
|
||||||
|
|
||||||
# --- VERIFICATION ---
|
|
||||||
|
|
||||||
# 1. First call: The function should use the initial "ollama" binding.
|
|
||||||
await fixed_func(["hello world"])
|
|
||||||
mock_ollama_embed.assert_awaited_once()
|
|
||||||
mock_openai_embed.assert_not_called()
|
|
||||||
|
|
||||||
# 2. CRITICAL STEP: Mutate the original args object AFTER the function is created.
|
|
||||||
mock_args.embedding_binding = "openai"
|
|
||||||
|
|
||||||
# 3. Reset mocks and call the function AGAIN.
|
|
||||||
mock_ollama_embed.reset_mock()
|
|
||||||
mock_openai_embed.reset_mock()
|
|
||||||
|
|
||||||
await fixed_func(["see you again"])
|
|
||||||
|
|
||||||
# 4. Final check: The function should STILL call ollama_embed.
|
|
||||||
mock_ollama_embed.assert_awaited_once()
|
|
||||||
mock_openai_embed.assert_not_called()
|
|
||||||
Loading…
Add table
Reference in a new issue