Refactor Jina embedding dimension by changing param to optional with default
This commit is contained in:
parent
d95efcb9ad
commit
01b07b2be5
1 changed files with 13 additions and 4 deletions
|
|
@ -1,4 +1,6 @@
|
||||||
import os
|
import os
|
||||||
|
from typing import Final
|
||||||
|
|
||||||
import pipmaster as pm # Pipmaster for dynamic library install
|
import pipmaster as pm # Pipmaster for dynamic library install
|
||||||
|
|
||||||
# install specific modules
|
# install specific modules
|
||||||
|
|
@ -19,6 +21,9 @@ from tenacity import (
|
||||||
from lightrag.utils import wrap_embedding_func_with_attrs, logger
|
from lightrag.utils import wrap_embedding_func_with_attrs, logger
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_JINA_EMBED_DIM: Final[int] = 2048
|
||||||
|
|
||||||
|
|
||||||
async def fetch_data(url, headers, data):
|
async def fetch_data(url, headers, data):
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
async with session.post(url, headers=headers, json=data) as response:
|
async with session.post(url, headers=headers, json=data) as response:
|
||||||
|
|
@ -58,7 +63,7 @@ async def fetch_data(url, headers, data):
|
||||||
return data_list
|
return data_list
|
||||||
|
|
||||||
|
|
||||||
@wrap_embedding_func_with_attrs(embedding_dim=2048)
|
@wrap_embedding_func_with_attrs(embedding_dim=DEFAULT_JINA_EMBED_DIM)
|
||||||
@retry(
|
@retry(
|
||||||
stop=stop_after_attempt(3),
|
stop=stop_after_attempt(3),
|
||||||
wait=wait_exponential(multiplier=1, min=4, max=60),
|
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||||
|
|
@ -69,7 +74,7 @@ async def fetch_data(url, headers, data):
|
||||||
)
|
)
|
||||||
async def jina_embed(
|
async def jina_embed(
|
||||||
texts: list[str],
|
texts: list[str],
|
||||||
embedding_dim: int = 2048,
|
embedding_dim: int | None = DEFAULT_JINA_EMBED_DIM,
|
||||||
late_chunking: bool = False,
|
late_chunking: bool = False,
|
||||||
base_url: str = None,
|
base_url: str = None,
|
||||||
api_key: str = None,
|
api_key: str = None,
|
||||||
|
|
@ -95,6 +100,10 @@ async def jina_embed(
|
||||||
aiohttp.ClientError: If there is a connection error with the Jina API.
|
aiohttp.ClientError: If there is a connection error with the Jina API.
|
||||||
aiohttp.ClientResponseError: If the Jina API returns an error response.
|
aiohttp.ClientResponseError: If the Jina API returns an error response.
|
||||||
"""
|
"""
|
||||||
|
resolved_embedding_dim = (
|
||||||
|
embedding_dim if embedding_dim is not None else DEFAULT_JINA_EMBED_DIM
|
||||||
|
)
|
||||||
|
|
||||||
if api_key:
|
if api_key:
|
||||||
os.environ["JINA_API_KEY"] = api_key
|
os.environ["JINA_API_KEY"] = api_key
|
||||||
|
|
||||||
|
|
@ -109,7 +118,7 @@ async def jina_embed(
|
||||||
data = {
|
data = {
|
||||||
"model": "jina-embeddings-v4",
|
"model": "jina-embeddings-v4",
|
||||||
"task": "text-matching",
|
"task": "text-matching",
|
||||||
"dimensions": embedding_dim,
|
"dimensions": resolved_embedding_dim,
|
||||||
"embedding_type": "base64",
|
"embedding_type": "base64",
|
||||||
"input": texts,
|
"input": texts,
|
||||||
}
|
}
|
||||||
|
|
@ -119,7 +128,7 @@ async def jina_embed(
|
||||||
data["late_chunking"] = late_chunking
|
data["late_chunking"] = late_chunking
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Jina embedding request: {len(texts)} texts, dimensions: {embedding_dim}"
|
f"Jina embedding request: {len(texts)} texts, dimensions: {resolved_embedding_dim}"
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue