This commit is contained in:
Raphaël MANSUY 2025-12-04 19:14:25 +08:00
parent 56b8806256
commit e11e30be0e

View file

@ -1,4 +1,6 @@
import os import os
from typing import Final
import pipmaster as pm # Pipmaster for dynamic library install import pipmaster as pm # Pipmaster for dynamic library install
# install specific modules # install specific modules
@ -19,6 +21,9 @@ from tenacity import (
from lightrag.utils import wrap_embedding_func_with_attrs, logger from lightrag.utils import wrap_embedding_func_with_attrs, logger
DEFAULT_JINA_EMBED_DIM: Final[int] = 2048
async def fetch_data(url, headers, data): async def fetch_data(url, headers, data):
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.post(url, headers=headers, json=data) as response: async with session.post(url, headers=headers, json=data) as response:
@ -58,7 +63,7 @@ async def fetch_data(url, headers, data):
return data_list return data_list
@wrap_embedding_func_with_attrs(embedding_dim=2048) @wrap_embedding_func_with_attrs(embedding_dim=DEFAULT_JINA_EMBED_DIM)
@retry( @retry(
stop=stop_after_attempt(3), stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=60), wait=wait_exponential(multiplier=1, min=4, max=60),
@ -69,7 +74,7 @@ async def fetch_data(url, headers, data):
) )
async def jina_embed( async def jina_embed(
texts: list[str], texts: list[str],
dimensions: int = 2048, embedding_dim: int | None = DEFAULT_JINA_EMBED_DIM,
late_chunking: bool = False, late_chunking: bool = False,
base_url: str = None, base_url: str = None,
api_key: str = None, api_key: str = None,
@ -78,7 +83,12 @@ async def jina_embed(
Args: Args:
texts: List of texts to embed. texts: List of texts to embed.
dimensions: The embedding dimensions (default: 2048 for jina-embeddings-v4). embedding_dim: The embedding dimensions (default: 2048 for jina-embeddings-v4).
**IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper.
Do NOT manually pass this parameter when calling the function directly.
The dimension is controlled by the @wrap_embedding_func_with_attrs decorator.
Manually passing a different value will trigger a warning and be ignored.
When provided (by EmbeddingFunc), it will be passed to the Jina API for dimension reduction.
late_chunking: Whether to use late chunking. late_chunking: Whether to use late chunking.
base_url: Optional base URL for the Jina API. base_url: Optional base URL for the Jina API.
api_key: Optional Jina API key. If None, uses the JINA_API_KEY environment variable. api_key: Optional Jina API key. If None, uses the JINA_API_KEY environment variable.
@ -90,6 +100,10 @@ async def jina_embed(
aiohttp.ClientError: If there is a connection error with the Jina API. aiohttp.ClientError: If there is a connection error with the Jina API.
aiohttp.ClientResponseError: If the Jina API returns an error response. aiohttp.ClientResponseError: If the Jina API returns an error response.
""" """
resolved_embedding_dim = (
embedding_dim if embedding_dim is not None else DEFAULT_JINA_EMBED_DIM
)
if api_key: if api_key:
os.environ["JINA_API_KEY"] = api_key os.environ["JINA_API_KEY"] = api_key
@ -104,7 +118,7 @@ async def jina_embed(
data = { data = {
"model": "jina-embeddings-v4", "model": "jina-embeddings-v4",
"task": "text-matching", "task": "text-matching",
"dimensions": dimensions, "dimensions": resolved_embedding_dim,
"embedding_type": "base64", "embedding_type": "base64",
"input": texts, "input": texts,
} }
@ -114,7 +128,7 @@ async def jina_embed(
data["late_chunking"] = late_chunking data["late_chunking"] = late_chunking
logger.debug( logger.debug(
f"Jina embedding request: {len(texts)} texts, dimensions: {dimensions}" f"Jina embedding request: {len(texts)} texts, dimensions: {resolved_embedding_dim}"
) )
try: try: