feat: enhance EmbeddingFunc with model_name support
Why this change is needed: To support vector storage model isolation, we need to track which model is used for embeddings and generate unique identifiers for collections/tables. How it solves it: - Added model_name field to EmbeddingFunc - Added get_model_identifier() method to generate sanitized suffix - Added unit tests to verify behavior Impact: Enables subsequent changes in storage backends to isolate data by model. Testing: Added tests/test_embedding_func.py passing.
This commit is contained in:
parent
d16c7840ab
commit
5c10d3d58e
2 changed files with 50 additions and 0 deletions
|
|
@ -370,6 +370,19 @@ class EmbeddingFunc:
|
|||
send_dimensions: bool = (
|
||||
False # Control whether to send embedding_dim to the function
|
||||
)
|
||||
model_name: str | None = None
|
||||
|
||||
def get_model_identifier(self) -> str:
|
||||
"""Generates model identifier for collection/table suffix.
|
||||
|
||||
Returns:
|
||||
str: Format "{model_name}_{dim}d", e.g. "text_embedding_3_large_3072d"
|
||||
If model_name is not specified, returns "unknown_{dim}d"
|
||||
"""
|
||||
model_part = self.model_name if self.model_name else "unknown"
|
||||
# Clean model name: remove special chars, convert to lower, replace - with _
|
||||
safe_model_name = re.sub(r'[^a-zA-Z0-9_]', '_', model_part.lower())
|
||||
return f"{safe_model_name}_{self.embedding_dim}d"
|
||||
|
||||
async def __call__(self, *args, **kwargs) -> np.ndarray:
|
||||
# Only inject embedding_dim when send_dimensions is True
|
||||
|
|
|
|||
37
tests/test_embedding_func.py
Normal file
37
tests/test_embedding_func.py
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
import pytest
|
||||
from lightrag.utils import EmbeddingFunc
|
||||
|
||||
def dummy_func(*args, **kwargs):
|
||||
pass
|
||||
|
||||
def test_embedding_func_with_model_name():
|
||||
func = EmbeddingFunc(
|
||||
embedding_dim=1536,
|
||||
func=dummy_func,
|
||||
model_name="text-embedding-ada-002"
|
||||
)
|
||||
assert func.get_model_identifier() == "text_embedding_ada_002_1536d"
|
||||
|
||||
def test_embedding_func_without_model_name():
|
||||
func = EmbeddingFunc(
|
||||
embedding_dim=768,
|
||||
func=dummy_func
|
||||
)
|
||||
assert func.get_model_identifier() == "unknown_768d"
|
||||
|
||||
def test_model_name_sanitization():
|
||||
func = EmbeddingFunc(
|
||||
embedding_dim=1024,
|
||||
func=dummy_func,
|
||||
model_name="models/text-embedding-004" # Contains special chars
|
||||
)
|
||||
assert func.get_model_identifier() == "models_text_embedding_004_1024d"
|
||||
|
||||
def test_model_name_with_uppercase():
|
||||
func = EmbeddingFunc(
|
||||
embedding_dim=512,
|
||||
func=dummy_func,
|
||||
model_name="My-Model-V1"
|
||||
)
|
||||
assert func.get_model_identifier() == "my_model_v1_512d"
|
||||
|
||||
Loading…
Add table
Reference in a new issue