"""
|
|
Vietnamese Embedding Integration for LightRAG
|
|
Model: AITeamVN/Vietnamese_Embedding
|
|
Base: BAAI/bge-m3
|
|
"""
|
|
|
|
import os
|
|
import numpy as np
|
|
import torch
|
|
from functools import lru_cache
|
|
|
|
import pipmaster as pm
|
|
|
|
# Install required packages
|
|
if not pm.is_installed("transformers"):
|
|
pm.install("transformers")
|
|
if not pm.is_installed("torch"):
|
|
pm.install("torch")
|
|
if not pm.is_installed("numpy"):
|
|
pm.install("numpy")
|
|
|
|
from transformers import AutoTokenizer, AutoModel
|
|
from tenacity import (
|
|
retry,
|
|
stop_after_attempt,
|
|
wait_exponential,
|
|
retry_if_exception_type,
|
|
)
|
|
from lightrag.utils import wrap_embedding_func_with_attrs, logger
|
|
from lightrag.exceptions import APIConnectionError, RateLimitError, APITimeoutError
|
|
|
|
# Disable tokenizers parallelism to avoid warnings
|
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
|
|
|
|
@lru_cache(maxsize=1)
def initialize_vietnamese_embedding_model(
    model_name: str = "AITeamVN/Vietnamese_Embedding",
    token: str | None = None,
):
    """
    Initialize the Vietnamese Embedding model with caching.

    Args:
        model_name: HuggingFace model identifier
        token: HuggingFace API token for model access

    Returns:
        Tuple of (model, tokenizer)
    """
    logger.info(f"Loading Vietnamese Embedding model: {model_name}")

    # Get token from environment if not provided
    if token is None:
        token = os.environ.get("HUGGINGFACE_API_KEY") or os.environ.get("HF_TOKEN")

    try:
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            token=token,
            trust_remote_code=True,
        )
        model = AutoModel.from_pretrained(
            model_name,
            token=token,
            trust_remote_code=True,
        )

        logger.info("Vietnamese Embedding model loaded successfully")
        return model, tokenizer

    except Exception as e:
        logger.error(f"Failed to load Vietnamese Embedding model: {e}")
        raise

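
# Note: @lru_cache(maxsize=1) keeps only the most recently loaded
# (model, tokenizer) pair resident in memory; calling the loader with a
# different model_name loads and caches the new pair, evicting the old one.
# To release the cached weights explicitly, use the standard functools hook:
#
#     initialize_vietnamese_embedding_model.cache_clear()

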
@wrap_embedding_func_with_attrs(embedding_dim=1024)
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type(
        (RateLimitError, APIConnectionError, APITimeoutError)
    ),
)
async def vietnamese_embed(
    texts: list[str],
    model_name: str = "AITeamVN/Vietnamese_Embedding",
    token: str | None = None,
) -> np.ndarray:
    """
    Generate embeddings for Vietnamese texts using the AITeamVN/Vietnamese_Embedding model.

    This model is based on BGE-M3 and fine-tuned on Vietnamese data with:
    - Maximum sequence length: 2048 tokens
    - Output dimensionality: 1024 dimensions
    - Similarity function: dot product

    Args:
        texts: List of texts to embed (in Vietnamese or other languages)
        model_name: HuggingFace model identifier (default: AITeamVN/Vietnamese_Embedding)
        token: HuggingFace API token for model access

    Returns:
        numpy array of embeddings with shape (len(texts), 1024)

    Raises:
        APIConnectionError: If there is a connection error
        RateLimitError: If the rate limit is exceeded
        APITimeoutError: If the request times out
    """
    # Get token from environment if not provided
    if token is None:
        token = os.environ.get("HUGGINGFACE_API_KEY") or os.environ.get("HF_TOKEN")

    # Initialize model and tokenizer (cached across calls)
    model, tokenizer = initialize_vietnamese_embedding_model(model_name, token)

    # Detect the appropriate device
    if torch.cuda.is_available():
        device = torch.device("cuda")
        logger.debug("Using CUDA device for embedding")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
        logger.debug("Using MPS device for embedding")
    else:
        device = torch.device("cpu")
        logger.debug("Using CPU device for embedding")

    # Move model to device and set evaluation mode
    model = model.to(device)
    model.eval()

    try:
        # Tokenize with max_length matching the model's training setup
        # (Vietnamese_Embedding was trained with max_length=2048)
        encoded_input = tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=2048,
            return_tensors="pt",
        ).to(device)

        # Generate embeddings
        with torch.no_grad():
            model_output = model(**encoded_input)
            # Mean-pool the token embeddings (see mean_pooling below)
            embeddings = mean_pooling(model_output, encoded_input["attention_mask"])
            # L2-normalize so that dot product equals cosine similarity
            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)

        # Convert to a numpy array (numpy has no bfloat16, so upcast first)
        if embeddings.dtype == torch.bfloat16:
            embeddings_np = embeddings.to(torch.float32).cpu().numpy()
        else:
            embeddings_np = embeddings.cpu().numpy()

        logger.debug(
            f"Generated embeddings for {len(texts)} texts, shape: {embeddings_np.shape}"
        )
        return embeddings_np

    except Exception as e:
        logger.error(f"Error generating Vietnamese embeddings: {e}")
        raise APIConnectionError(f"Vietnamese embedding generation failed: {e}") from e


def mean_pooling(model_output, attention_mask):
    """
    Perform mean pooling on token embeddings.

    Args:
        model_output: Model output containing token embeddings
        attention_mask: Attention mask to exclude padding tokens

    Returns:
        Pooled embeddings
    """
    # First element of the model output holds the token embeddings,
    # shape (batch, seq_len, hidden)
    token_embeddings = model_output[0]
    # Expand the mask to (batch, seq_len, hidden) so padding positions contribute zero
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    # Sum over real tokens and divide by their count (clamped to avoid division by zero)
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
        input_mask_expanded.sum(1), min=1e-9
    )

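
# Worked example (illustrative): for a sequence with attention_mask [1, 1, 0],
# the expanded mask zeroes out the padded third token, so the pooled vector is
# (v1 + v2) / 2, the mean of the two real token embeddings.

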
# Convenience function for easier integration
@wrap_embedding_func_with_attrs(embedding_dim=1024)
async def vietnamese_embedding_func(texts: list[str]) -> np.ndarray:
    """
    Convenience wrapper for Vietnamese embedding that reads the token from the environment.

    Set the HUGGINGFACE_API_KEY or HF_TOKEN environment variable to your HuggingFace token.

    Args:
        texts: List of texts to embed

    Returns:
        numpy array of embeddings with shape (len(texts), 1024)
    """
    return await vietnamese_embed(texts)
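

# --- Usage sketch ---
# Minimal demo of calling the embedding function directly. The LightRAG wiring
# in the comment below is an assumption about the standard LightRAG
# constructor (the working_dir path and llm_model_func are hypothetical
# placeholders); adapt it to your installed version.
#
#     from lightrag import LightRAG
#
#     rag = LightRAG(
#         working_dir="./rag_storage",            # hypothetical path
#         embedding_func=vietnamese_embedding_func,
#         llm_model_func=my_llm_func,             # hypothetical LLM callable
#     )
if __name__ == "__main__":
    import asyncio

    async def _demo():
        vectors = await vietnamese_embedding_func(
            ["Xin chào thế giới", "Hà Nội là thủ đô của Việt Nam"]
        )
        print(f"Embeddings shape: {vectors.shape}")  # expected: (2, 1024)

    asyncio.run(_demo())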