# LightRAG/lightrag/llm/vietnamese_embed.py
"""
Vietnamese Embedding Integration for LightRAG
Model: AITeamVN/Vietnamese_Embedding
Base: BAAI/bge-m3
"""
import os
import numpy as np
import torch
from functools import lru_cache
import pipmaster as pm

# Install required packages at import time if they are missing
if not pm.is_installed("transformers"):
    pm.install("transformers")
if not pm.is_installed("torch"):
    pm.install("torch")
if not pm.is_installed("numpy"):
    pm.install("numpy")

from transformers import AutoTokenizer, AutoModel
from tenacity import (
    retry,
    stop_after_attempt,
    wait_exponential,
    retry_if_exception_type,
)
from lightrag.utils import wrap_embedding_func_with_attrs, logger
from lightrag.exceptions import APIConnectionError, RateLimitError, APITimeoutError

# Disable tokenizers parallelism to avoid fork-related warnings
os.environ["TOKENIZERS_PARALLELISM"] = "false"


@lru_cache(maxsize=1)
def initialize_vietnamese_embedding_model(
    model_name: str = "AITeamVN/Vietnamese_Embedding",
    token: str | None = None,
):
    """
    Initialize the Vietnamese Embedding model with caching.

    Args:
        model_name: HuggingFace model identifier
        token: HuggingFace API token for model access

    Returns:
        Tuple of (model, tokenizer)
    """
    logger.info(f"Loading Vietnamese Embedding model: {model_name}")

    # Get token from environment if not provided
    if token is None:
        token = os.environ.get("HUGGINGFACE_API_KEY") or os.environ.get("HF_TOKEN")

    try:
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            token=token,
            trust_remote_code=True,
        )
        model = AutoModel.from_pretrained(
            model_name,
            token=token,
            trust_remote_code=True,
        )
        logger.info("Vietnamese Embedding model loaded successfully")
        return model, tokenizer
    except Exception as e:
        logger.error(f"Failed to load Vietnamese Embedding model: {e}")
        raise
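
# Note: lru_cache keys on (model_name, token), so a different model name or
# token loads and caches a new (model, tokenizer) pair; maxsize=1 keeps only
# the most recently used pair in memory.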

@wrap_embedding_func_with_attrs(embedding_dim=1024)
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type(
        (RateLimitError, APIConnectionError, APITimeoutError)
    ),
)
async def vietnamese_embed(
    texts: list[str],
    model_name: str = "AITeamVN/Vietnamese_Embedding",
    token: str | None = None,
) -> np.ndarray:
    """
    Generate embeddings for Vietnamese texts using the
    AITeamVN/Vietnamese_Embedding model.

    This model is based on BGE-M3 and fine-tuned on Vietnamese data with:
    - Maximum sequence length: 2048 tokens
    - Output dimensionality: 1024 dimensions
    - Similarity function: dot product

    Args:
        texts: List of texts to embed (in Vietnamese or other languages)
        model_name: HuggingFace model identifier (default: AITeamVN/Vietnamese_Embedding)
        token: HuggingFace API token for model access

    Returns:
        numpy array of embeddings with shape (len(texts), 1024)

    Raises:
        APIConnectionError: If embedding generation fails. The retry policy
            also covers RateLimitError and APITimeoutError.
    """
    # Get token from environment if not provided
    if token is None:
        token = os.environ.get("HUGGINGFACE_API_KEY") or os.environ.get("HF_TOKEN")

    # Initialize (or fetch the cached) model and tokenizer
    model, tokenizer = initialize_vietnamese_embedding_model(model_name, token)

    # Detect the appropriate device
    if torch.cuda.is_available():
        device = torch.device("cuda")
        logger.debug("Using CUDA device for embedding")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
        logger.debug("Using MPS device for embedding")
    else:
        device = torch.device("cpu")
        logger.debug("Using CPU device for embedding")

    # Move model to the selected device and set evaluation mode
    model = model.to(device)
    model.eval()

    try:
        # Tokenize texts with max_length matching the model's training setup;
        # Vietnamese_Embedding was trained with max_length=2048
        encoded_input = tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=2048,
            return_tensors="pt",
        ).to(device)

        # Generate embeddings without tracking gradients
        with torch.no_grad():
            model_output = model(**encoded_input)
            # Mean-pool the token embeddings over non-padding positions
            embeddings = mean_pooling(model_output, encoded_input["attention_mask"])
            # L2-normalize so dot product similarity equals cosine similarity
            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)

        # Convert to a numpy array (numpy has no bfloat16, so upcast first)
        if embeddings.dtype == torch.bfloat16:
            embeddings_np = embeddings.to(torch.float32).cpu().numpy()
        else:
            embeddings_np = embeddings.cpu().numpy()

        logger.debug(
            f"Generated embeddings for {len(texts)} texts, shape: {embeddings_np.shape}"
        )
        return embeddings_np
    except Exception as e:
        logger.error(f"Error generating Vietnamese embeddings: {e}")
        raise APIConnectionError(f"Vietnamese embedding generation failed: {e}") from e


def mean_pooling(model_output, attention_mask):
    """
    Perform mean pooling on token embeddings.

    Args:
        model_output: Model output containing token embeddings
        attention_mask: Attention mask to exclude padding tokens

    Returns:
        Pooled embeddings
    """
    token_embeddings = model_output[0]  # First element contains the token embeddings
    input_mask_expanded = (
        attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    )
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
        input_mask_expanded.sum(1), min=1e-9
    )
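
# Shape walk-through for mean_pooling (illustrative, assuming the model's
# 1024-dimensional hidden size):
#   token_embeddings:    (batch, seq_len, 1024)
#   input_mask_expanded: (batch, seq_len, 1024), 1.0 at real tokens, 0.0 at padding
#   returned tensor:     (batch, 1024), the average over non-padding positions;
#   the clamp guards against division by zero for an all-padding row.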


# Convenience function for easier integration
@wrap_embedding_func_with_attrs(embedding_dim=1024)
async def vietnamese_embedding_func(texts: list[str]) -> np.ndarray:
    """
    Convenience wrapper for Vietnamese embedding that reads the token from the
    environment. Set the HUGGINGFACE_API_KEY or HF_TOKEN environment variable
    to your HuggingFace token.

    Args:
        texts: List of texts to embed

    Returns:
        numpy array of embeddings
    """
    return await vietnamese_embed(texts)
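

# A minimal smoke test (a sketch, not part of the library API): embeds two
# Vietnamese sentences and prints the result shape. The first run downloads
# the model weights from HuggingFace.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        vectors = await vietnamese_embedding_func(
            ["Hà Nội là thủ đô của Việt Nam.", "Phở là một món ăn truyền thống."]
        )
        print(vectors.shape)  # expected: (2, 1024)

    asyncio.run(_demo())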