feat: enhance EmbeddingFunc with model_name support

Why this change is needed:
To support vector storage model isolation, we need to track which model is used for embeddings and generate unique identifiers for collections/tables.

How it solves it:
- Added model_name field to EmbeddingFunc
- Added get_model_identifier() method to generate a sanitized collection/table suffix
- Added unit tests to verify behavior

Impact:
Enables subsequent changes in storage backends to isolate data by model.

Testing:
Added tests/test_embedding_func.py; all tests pass.
This commit is contained in:
BukeLy 2025-11-19 02:11:39 +08:00
parent d16c7840ab
commit 5c10d3d58e
2 changed files with 50 additions and 0 deletions

View file

@ -370,6 +370,19 @@ class EmbeddingFunc:
send_dimensions: bool = (
False # Control whether to send embedding_dim to the function
)
# Optional name of the embedding model; get_model_identifier() substitutes
# "unknown" when this is None or empty.
model_name: str | None = None
def get_model_identifier(self) -> str:
    """Build the identifier used as a collection/table suffix.

    The model name is lower-cased and every character other than an ASCII
    letter, digit or underscore is replaced with an underscore (so
    "text-embedding-3-large" becomes "text_embedding_3_large"). The
    embedding dimension followed by "d" is then appended.

    Returns:
        str: "{model_name}_{dim}d", e.g. "text_embedding_3_large_3072d".
            When model_name is None or empty the result is "unknown_{dim}d".
    """
    # A falsy model_name (None or "") falls back to the placeholder.
    raw_name = self.model_name or "unknown"
    lowered = raw_name.lower()
    # Characters are replaced one-for-one, not dropped, so the length of the
    # name part is preserved.
    sanitized = re.sub(r"[^a-zA-Z0-9_]", "_", lowered)
    return f"{sanitized}_{self.embedding_dim}d"
async def __call__(self, *args, **kwargs) -> np.ndarray:
# Only inject embedding_dim when send_dimensions is True

View file

@ -0,0 +1,37 @@
import pytest
from lightrag.utils import EmbeddingFunc
def dummy_func(*args, **kwargs):
    """Placeholder callable: accepts any arguments and returns None."""
    return None
def test_embedding_func_with_model_name():
    """Hyphens in the model name become underscores and the dim is appended."""
    embedder = EmbeddingFunc(
        func=dummy_func,
        embedding_dim=1536,
        model_name="text-embedding-ada-002",
    )
    identifier = embedder.get_model_identifier()
    assert identifier == "text_embedding_ada_002_1536d"
def test_embedding_func_without_model_name():
    """Without a model name the identifier uses the "unknown" placeholder."""
    embedder = EmbeddingFunc(func=dummy_func, embedding_dim=768)
    identifier = embedder.get_model_identifier()
    assert identifier == "unknown_768d"
def test_model_name_sanitization():
    """Both "/" and "-" in the model name are replaced with underscores."""
    # The name below contains special characters ("/" and "-").
    embedder = EmbeddingFunc(
        func=dummy_func,
        embedding_dim=1024,
        model_name="models/text-embedding-004",
    )
    identifier = embedder.get_model_identifier()
    assert identifier == "models_text_embedding_004_1024d"
def test_model_name_with_uppercase():
    """Upper-case letters in the model name are lower-cased in the identifier."""
    embedder = EmbeddingFunc(
        func=dummy_func,
        embedding_dim=512,
        model_name="My-Model-V1",
    )
    identifier = embedder.get_model_identifier()
    assert identifier == "my_model_v1_512d"