From 5c10d3d58e0225401a9a115580d2fe9d45384c35 Mon Sep 17 00:00:00 2001 From: BukeLy Date: Wed, 19 Nov 2025 02:11:39 +0800 Subject: [PATCH] feat: enhance EmbeddingFunc with model_name support Why this change is needed: To support vector storage model isolation, we need to track which model is used for embeddings and generate unique identifiers for collections/tables. How it solves it: - Added a model_name field to EmbeddingFunc - Added a get_model_identifier() method to generate a sanitized suffix - Added unit tests to verify behavior Impact: Enables subsequent changes in storage backends to isolate data by model. Testing: Added tests/test_embedding_func.py; all tests pass. --- lightrag/utils.py | 13 +++++++++++++ tests/test_embedding_func.py | 37 ++++++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+) create mode 100644 tests/test_embedding_func.py diff --git a/lightrag/utils.py b/lightrag/utils.py index 8c9b7776..66104f1e 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -370,6 +370,19 @@ class EmbeddingFunc: send_dimensions: bool = ( False # Control whether to send embedding_dim to the function ) + model_name: str | None = None + + def get_model_identifier(self) -> str: + """Generate the model identifier used as a collection/table suffix. + + Returns: + str: Format "{model_name}_{dim}d" (model_name lowercased, with every character other than letters, digits and "_" replaced by "_"), e.g. 
"text_embedding_3_large_3072d" + If model_name is not specified, returns "unknown_{dim}d" + """ + model_part = self.model_name if self.model_name else "unknown" + # Sanitize model name: lowercase it, then replace every character other than letters, digits and _ with _ + safe_model_name = re.sub(r'[^a-zA-Z0-9_]', '_', model_part.lower()) + return f"{safe_model_name}_{self.embedding_dim}d" async def __call__(self, *args, **kwargs) -> np.ndarray: # Only inject embedding_dim when send_dimensions is True diff --git a/tests/test_embedding_func.py b/tests/test_embedding_func.py new file mode 100644 index 00000000..357e5808 --- /dev/null +++ b/tests/test_embedding_func.py @@ -0,0 +1,37 @@ +import pytest +from lightrag.utils import EmbeddingFunc + +def dummy_func(*args, **kwargs): + pass + +def test_embedding_func_with_model_name(): + func = EmbeddingFunc( + embedding_dim=1536, + func=dummy_func, + model_name="text-embedding-ada-002" + ) + assert func.get_model_identifier() == "text_embedding_ada_002_1536d" + +def test_embedding_func_without_model_name(): + func = EmbeddingFunc( + embedding_dim=768, + func=dummy_func + ) + assert func.get_model_identifier() == "unknown_768d" + +def test_model_name_sanitization(): + func = EmbeddingFunc( + embedding_dim=1024, + func=dummy_func, + model_name="models/text-embedding-004" # Contains "/" and "-", both replaced with "_" + ) + assert func.get_model_identifier() == "models_text_embedding_004_1024d" + +def test_model_name_with_uppercase(): + func = EmbeddingFunc( + embedding_dim=512, + func=dummy_func, + model_name="My-Model-V1" + ) + assert func.get_model_identifier() == "my_model_v1_512d" +