diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 67227006..b3ed6d80 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -237,6 +237,7 @@ def create_app(args): # Create working directory if it doesn't exist Path(args.working_dir).mkdir(parents=True, exist_ok=True) + if args.llm_binding == "lollms" or args.embedding_binding == "lollms": from lightrag.llm.lollms import lollms_model_complete, lollms_embed if args.llm_binding == "ollama" or args.embedding_binding == "ollama": @@ -253,8 +254,6 @@ def create_app(args): from lightrag.llm.binding_options import OpenAILLMOptions if args.llm_binding == "aws_bedrock" or args.embedding_binding == "aws_bedrock": from lightrag.llm.bedrock import bedrock_complete_if_cache, bedrock_embed - if args.embedding_binding == "ollama": - from lightrag.llm.binding_options import OllamaEmbeddingOptions if args.embedding_binding == "jina": from lightrag.llm.jina import jina_embed @@ -344,63 +343,86 @@ def create_app(args): **kwargs, ) - embedding_binding = args.embedding_binding - embedding_model = args.embedding_model - embedding_host = args.embedding_binding_host - embedding_api_key = args.embedding_binding_api_key - embedding_dim_val = args.embedding_dim - ollama_options_val = OllamaEmbeddingOptions.options_dict(args) + def create_embedding_function(binding, model, host, api_key, dimensions, args): + """ + Create embedding function with args object for dynamic option generation. + This approach completely avoids closure issues by capturing configuration + values as function parameters rather than through variable references. + The args object is used only for dynamic option generation when needed. + + Args: + binding: The embedding provider binding (lollms, ollama, etc.) + model: The embedding model name + host: The host URL for the embedding service + api_key: API key for authentication + dimensions: Embedding dimensions + args: Arguments object for dynamic option generation (only used when needed) + + Returns: + Async function that performs embedding based on the specified provider + """ + + async def embedding_function(texts): + """Embedding function with captured configuration parameters""" + if binding == "lollms": + return await lollms_embed( + texts, + embed_model=model, + host=host, + api_key=api_key, + ) + elif binding == "ollama": + # Only import and generate ollama_options when actually needed + from lightrag.llm.binding_options import OllamaEmbeddingOptions + + ollama_options = OllamaEmbeddingOptions.options_dict(args) + return await ollama_embed( + texts, + embed_model=model, + host=host, + api_key=api_key, + options=ollama_options, + ) + elif binding == "azure_openai": + return await azure_openai_embed( + texts, + model=model, + api_key=api_key, + ) + elif binding == "aws_bedrock": + return await bedrock_embed( + texts, + model=model, + ) + elif binding == "jina": + return await jina_embed( + texts, + dimensions=dimensions, + base_url=host, + api_key=api_key, + ) + else: + # Default to OpenAI-compatible embedding + return await openai_embed( + texts, + model=model, + base_url=host, + api_key=api_key, + ) + + return embedding_function + + # Create embedding function with current configuration embedding_func = EmbeddingFunc( embedding_dim=args.embedding_dim, - func=lambda texts: ( - lollms_embed( - texts, - embed_model=embedding_model, - host=embedding_host, - api_key=embedding_api_key, - ) - if embedding_binding == "lollms" - else ( - ollama_embed( - texts, - embed_model=embedding_model, - host=embedding_host, - api_key=embedding_api_key, - options=ollama_options_val, - ) - if embedding_binding == "ollama" - else ( - azure_openai_embed( - texts, - model=embedding_model, # no host is used for openai, - api_key=embedding_api_key, - ) - if embedding_binding == "azure_openai" - else ( - bedrock_embed( - texts, - model=embedding_model, - ) - if embedding_binding == "aws_bedrock" - else ( - jina_embed( - texts, - dimensions=embedding_dim_val, - base_url=embedding_host, - api_key=embedding_api_key, - ) - if embedding_binding == "jina" - else openai_embed( - texts, - model=embedding_model, - base_url=embedding_host, - api_key=embedding_api_key, - ) - ) - ) - ) - ) + func=create_embedding_function( + binding=args.embedding_binding, + model=args.embedding_model, + host=args.embedding_binding_host, + api_key=args.embedding_binding_api_key, + dimensions=args.embedding_dim, + args=args, # Pass args object for dynamic option generation ), ) diff --git a/tests/test_server_embedding_logic.py b/tests/test_server_embedding_logic.py deleted file mode 100644 index ef5ac804..00000000 --- a/tests/test_server_embedding_logic.py +++ /dev/null @@ -1,78 +0,0 @@ -""" -Tests the fix for the lambda closure bug in the API server's embedding function. - -Issue: https://github.com/HKUDS/LightRAG/issues/2023 -""" - -import pytest -from unittest.mock import Mock, patch, AsyncMock -import numpy as np - -# Functions to be patched -from lightrag.llm.ollama import ollama_embed -from lightrag.llm.openai import openai_embed - - -@pytest.fixture -def mock_args(): - """Provides a mock of the server's arguments object.""" - args = Mock() - args.embedding_binding = "ollama" - args.embedding_model = "mxbai-embed-large:latest" - args.embedding_binding_host = "http://localhost:11434" - args.embedding_binding_api_key = None - args.embedding_dim = 1024 - args.OllamaEmbeddingOptions.options_dict.return_value = {"num_ctx": 4096} - return args - - -@pytest.mark.asyncio -@patch("lightrag.llm.openai.openai_embed", new_callable=AsyncMock) -@patch("lightrag.llm.ollama.ollama_embed", new_callable=AsyncMock) -async def test_embedding_func_captures_values_correctly( - mock_ollama_embed, mock_openai_embed, mock_args -): - """ - Verifies that the embedding function correctly captures configuration - values at creation time and is not affected by later mutations of its source. - """ - # --- Setup Mocks --- - mock_ollama_embed.return_value = np.array([[0.1, 0.2, 0.3]]) - mock_openai_embed.return_value = np.array([[0.4, 0.5, 0.6]]) - - # --- SIMULATE THE FIX: Capture values before creating the function --- - binding = mock_args.embedding_binding - model = mock_args.embedding_model - host = mock_args.embedding_binding_host - api_key = mock_args.embedding_binding_api_key - - # CORRECTED: Use an async def instead of a lambda - async def fixed_func(texts): - if binding == "ollama": - return await ollama_embed( - texts, embed_model=model, host=host, api_key=api_key - ) - else: - return await openai_embed( - texts, model=model, base_url=host, api_key=api_key - ) - - # --- VERIFICATION --- - - # 1. First call: The function should use the initial "ollama" binding. - await fixed_func(["hello world"]) - mock_ollama_embed.assert_awaited_once() - mock_openai_embed.assert_not_called() - - # 2. CRITICAL STEP: Mutate the original args object AFTER the function is created. - mock_args.embedding_binding = "openai" - - # 3. Reset mocks and call the function AGAIN. - mock_ollama_embed.reset_mock() - mock_openai_embed.reset_mock() - - await fixed_func(["see you again"]) - - # 4. Final check: The function should STILL call ollama_embed. - mock_ollama_embed.assert_awaited_once() - mock_openai_embed.assert_not_called()