fix: sync core modules from upstream

This commit is contained in:
Raphaël MANSUY 2025-12-04 19:20:28 +08:00
parent f6f3ed93d3
commit d1262e999d
2 changed files with 287 additions and 99 deletions

View file

@ -8,6 +8,10 @@ import json_repair
from typing import Any, AsyncIterator, overload, Literal from typing import Any, AsyncIterator, overload, Literal
from collections import Counter, defaultdict from collections import Counter, defaultdict
from lightrag.exceptions import (
PipelineCancelledException,
ChunkTokenLimitExceededError,
)
from lightrag.utils import ( from lightrag.utils import (
logger, logger,
compute_mdhash_id, compute_mdhash_id,
@ -82,11 +86,11 @@ def _truncate_entity_identifier(
display_value = identifier[:limit] display_value = identifier[:limit]
preview = identifier[:20] # Show first 20 characters as preview preview = identifier[:20] # Show first 20 characters as preview
logger.warning( logger.warning(
"%s: %s exceeded %d characters (len: %d, preview: '%s...'", "%s: %s len %d > %d chars (Name: '%s...')",
chunk_key, chunk_key,
identifier_role, identifier_role,
limit,
len(identifier), len(identifier),
limit,
preview, preview,
) )
return display_value return display_value
@ -97,8 +101,8 @@ def chunking_by_token_size(
content: str, content: str,
split_by_character: str | None = None, split_by_character: str | None = None,
split_by_character_only: bool = False, split_by_character_only: bool = False,
overlap_token_size: int = 128, chunk_overlap_token_size: int = 100,
max_token_size: int = 1024, chunk_token_size: int = 1200,
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
tokens = tokenizer.encode(content) tokens = tokenizer.encode(content)
results: list[dict[str, Any]] = [] results: list[dict[str, Any]] = []
@ -108,19 +112,30 @@ def chunking_by_token_size(
if split_by_character_only: if split_by_character_only:
for chunk in raw_chunks: for chunk in raw_chunks:
_tokens = tokenizer.encode(chunk) _tokens = tokenizer.encode(chunk)
if len(_tokens) > chunk_token_size:
logger.warning(
"Chunk split_by_character exceeds token limit: len=%d limit=%d",
len(_tokens),
chunk_token_size,
)
raise ChunkTokenLimitExceededError(
chunk_tokens=len(_tokens),
chunk_token_limit=chunk_token_size,
chunk_preview=chunk[:120],
)
new_chunks.append((len(_tokens), chunk)) new_chunks.append((len(_tokens), chunk))
else: else:
for chunk in raw_chunks: for chunk in raw_chunks:
_tokens = tokenizer.encode(chunk) _tokens = tokenizer.encode(chunk)
if len(_tokens) > max_token_size: if len(_tokens) > chunk_token_size:
for start in range( for start in range(
0, len(_tokens), max_token_size - overlap_token_size 0, len(_tokens), chunk_token_size - chunk_overlap_token_size
): ):
chunk_content = tokenizer.decode( chunk_content = tokenizer.decode(
_tokens[start : start + max_token_size] _tokens[start : start + chunk_token_size]
) )
new_chunks.append( new_chunks.append(
(min(max_token_size, len(_tokens) - start), chunk_content) (min(chunk_token_size, len(_tokens) - start), chunk_content)
) )
else: else:
new_chunks.append((len(_tokens), chunk)) new_chunks.append((len(_tokens), chunk))
@ -134,12 +149,12 @@ def chunking_by_token_size(
) )
else: else:
for index, start in enumerate( for index, start in enumerate(
range(0, len(tokens), max_token_size - overlap_token_size) range(0, len(tokens), chunk_token_size - chunk_overlap_token_size)
): ):
chunk_content = tokenizer.decode(tokens[start : start + max_token_size]) chunk_content = tokenizer.decode(tokens[start : start + chunk_token_size])
results.append( results.append(
{ {
"tokens": min(max_token_size, len(tokens) - start), "tokens": min(chunk_token_size, len(tokens) - start),
"content": chunk_content.strip(), "content": chunk_content.strip(),
"chunk_order_index": index, "chunk_order_index": index,
} }
@ -344,6 +359,20 @@ async def _summarize_descriptions(
llm_response_cache=llm_response_cache, llm_response_cache=llm_response_cache,
cache_type="summary", cache_type="summary",
) )
# Check summary token length against embedding limit
embedding_token_limit = global_config.get("embedding_token_limit")
if embedding_token_limit is not None and summary:
tokenizer = global_config["tokenizer"]
summary_token_count = len(tokenizer.encode(summary))
threshold = int(embedding_token_limit * 0.9)
if summary_token_count > threshold:
logger.warning(
f"Summary tokens ({summary_token_count}) exceeds 90% of embedding limit "
f"({embedding_token_limit}) for {description_type}: {description_name}"
)
return summary return summary
@ -988,13 +1017,13 @@ async def _process_extraction_result(
relationship_data["src_id"], relationship_data["src_id"],
DEFAULT_ENTITY_NAME_MAX_LENGTH, DEFAULT_ENTITY_NAME_MAX_LENGTH,
chunk_key, chunk_key,
"Relationship source entity", "Relation entity",
) )
truncated_target = _truncate_entity_identifier( truncated_target = _truncate_entity_identifier(
relationship_data["tgt_id"], relationship_data["tgt_id"],
DEFAULT_ENTITY_NAME_MAX_LENGTH, DEFAULT_ENTITY_NAME_MAX_LENGTH,
chunk_key, chunk_key,
"Relationship target entity", "Relation entity",
) )
relationship_data["src_id"] = truncated_source relationship_data["src_id"] = truncated_source
relationship_data["tgt_id"] = truncated_target relationship_data["tgt_id"] = truncated_target
@ -1694,6 +1723,12 @@ async def _merge_nodes_then_upsert(
logger.error(f"Entity {entity_name} has no description") logger.error(f"Entity {entity_name} has no description")
raise ValueError(f"Entity {entity_name} has no description") raise ValueError(f"Entity {entity_name} has no description")
# Check for cancellation before LLM summary
if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock:
if pipeline_status.get("cancellation_requested", False):
raise PipelineCancelledException("User cancelled during entity summary")
# 8. Get summary description and LLM usage status # 8. Get summary description and LLM usage status
description, llm_was_used = await _handle_entity_relation_summary( description, llm_was_used = await _handle_entity_relation_summary(
"Entity", "Entity",
@ -2015,6 +2050,14 @@ async def _merge_edges_then_upsert(
logger.error(f"Relation {src_id}~{tgt_id} has no description") logger.error(f"Relation {src_id}~{tgt_id} has no description")
raise ValueError(f"Relation {src_id}~{tgt_id} has no description") raise ValueError(f"Relation {src_id}~{tgt_id} has no description")
# Check for cancellation before LLM summary
if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock:
if pipeline_status.get("cancellation_requested", False):
raise PipelineCancelledException(
"User cancelled during relation summary"
)
# 8. Get summary description and LLM usage status # 8. Get summary description and LLM usage status
description, llm_was_used = await _handle_entity_relation_summary( description, llm_was_used = await _handle_entity_relation_summary(
"Relation", "Relation",
@ -2396,6 +2439,12 @@ async def merge_nodes_and_edges(
file_path: File path for logging file_path: File path for logging
""" """
# Check for cancellation at the start of merge
if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock:
if pipeline_status.get("cancellation_requested", False):
raise PipelineCancelledException("User cancelled during merge phase")
# Collect all nodes and edges from all chunks # Collect all nodes and edges from all chunks
all_nodes = defaultdict(list) all_nodes = defaultdict(list)
all_edges = defaultdict(list) all_edges = defaultdict(list)
@ -2432,6 +2481,14 @@ async def merge_nodes_and_edges(
async def _locked_process_entity_name(entity_name, entities): async def _locked_process_entity_name(entity_name, entities):
async with semaphore: async with semaphore:
# Check for cancellation before processing entity
if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock:
if pipeline_status.get("cancellation_requested", False):
raise PipelineCancelledException(
"User cancelled during entity merge"
)
workspace = global_config.get("workspace", "") workspace = global_config.get("workspace", "")
namespace = f"{workspace}:GraphDB" if workspace else "GraphDB" namespace = f"{workspace}:GraphDB" if workspace else "GraphDB"
async with get_storage_keyed_lock( async with get_storage_keyed_lock(
@ -2454,9 +2511,7 @@ async def merge_nodes_and_edges(
return entity_data return entity_data
except Exception as e: except Exception as e:
error_msg = ( error_msg = f"Error processing entity `{entity_name}`: {e}"
f"Critical error in entity processing for `{entity_name}`: {e}"
)
logger.error(error_msg) logger.error(error_msg)
# Try to update pipeline status, but don't let status update failure affect main exception # Try to update pipeline status, but don't let status update failure affect main exception
@ -2492,36 +2547,32 @@ async def merge_nodes_and_edges(
entity_tasks, return_when=asyncio.FIRST_EXCEPTION entity_tasks, return_when=asyncio.FIRST_EXCEPTION
) )
# Check if any task raised an exception and ensure all exceptions are retrieved
first_exception = None first_exception = None
successful_results = [] processed_entities = []
for task in done: for task in done:
try: try:
exception = task.exception() result = task.result()
if exception is not None: except BaseException as e:
if first_exception is None:
first_exception = exception
else:
successful_results.append(task.result())
except Exception as e:
if first_exception is None: if first_exception is None:
first_exception = e first_exception = e
else:
processed_entities.append(result)
if pending:
for task in pending:
task.cancel()
pending_results = await asyncio.gather(*pending, return_exceptions=True)
for result in pending_results:
if isinstance(result, BaseException):
if first_exception is None:
first_exception = result
else:
processed_entities.append(result)
# If any task failed, cancel all pending tasks and raise the first exception
if first_exception is not None: if first_exception is not None:
# Cancel all pending tasks
for pending_task in pending:
pending_task.cancel()
# Wait for cancellation to complete
if pending:
await asyncio.wait(pending)
# Re-raise the first exception to notify the caller
raise first_exception raise first_exception
# If all tasks completed successfully, collect results
processed_entities = [task.result() for task in entity_tasks]
# ===== Phase 2: Process all relationships concurrently ===== # ===== Phase 2: Process all relationships concurrently =====
log_message = f"Phase 2: Processing {total_relations_count} relations from {doc_id} (async: {graph_max_async})" log_message = f"Phase 2: Processing {total_relations_count} relations from {doc_id} (async: {graph_max_async})"
logger.info(log_message) logger.info(log_message)
@ -2531,6 +2582,14 @@ async def merge_nodes_and_edges(
async def _locked_process_edges(edge_key, edges): async def _locked_process_edges(edge_key, edges):
async with semaphore: async with semaphore:
# Check for cancellation before processing edges
if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock:
if pipeline_status.get("cancellation_requested", False):
raise PipelineCancelledException(
"User cancelled during relation merge"
)
workspace = global_config.get("workspace", "") workspace = global_config.get("workspace", "")
namespace = f"{workspace}:GraphDB" if workspace else "GraphDB" namespace = f"{workspace}:GraphDB" if workspace else "GraphDB"
sorted_edge_key = sorted([edge_key[0], edge_key[1]]) sorted_edge_key = sorted([edge_key[0], edge_key[1]])
@ -2566,7 +2625,7 @@ async def merge_nodes_and_edges(
return edge_data, added_entities return edge_data, added_entities
except Exception as e: except Exception as e:
error_msg = f"Critical error in relationship processing for `{sorted_edge_key}`: {e}" error_msg = f"Error processing relation `{sorted_edge_key}`: {e}"
logger.error(error_msg) logger.error(error_msg)
# Try to update pipeline status, but don't let status update failure affect main exception # Try to update pipeline status, but don't let status update failure affect main exception
@ -2604,40 +2663,36 @@ async def merge_nodes_and_edges(
edge_tasks, return_when=asyncio.FIRST_EXCEPTION edge_tasks, return_when=asyncio.FIRST_EXCEPTION
) )
# Check if any task raised an exception and ensure all exceptions are retrieved
first_exception = None first_exception = None
successful_results = []
for task in done: for task in done:
try: try:
exception = task.exception() edge_data, added_entities = task.result()
if exception is not None: except BaseException as e:
if first_exception is None:
first_exception = exception
else:
successful_results.append(task.result())
except Exception as e:
if first_exception is None: if first_exception is None:
first_exception = e first_exception = e
else:
if edge_data is not None:
processed_edges.append(edge_data)
all_added_entities.extend(added_entities)
if pending:
for task in pending:
task.cancel()
pending_results = await asyncio.gather(*pending, return_exceptions=True)
for result in pending_results:
if isinstance(result, BaseException):
if first_exception is None:
first_exception = result
else:
edge_data, added_entities = result
if edge_data is not None:
processed_edges.append(edge_data)
all_added_entities.extend(added_entities)
# If any task failed, cancel all pending tasks and raise the first exception
if first_exception is not None: if first_exception is not None:
# Cancel all pending tasks
for pending_task in pending:
pending_task.cancel()
# Wait for cancellation to complete
if pending:
await asyncio.wait(pending)
# Re-raise the first exception to notify the caller
raise first_exception raise first_exception
# If all tasks completed successfully, collect results
for task in edge_tasks:
edge_data, added_entities = task.result()
if edge_data is not None:
processed_edges.append(edge_data)
all_added_entities.extend(added_entities)
# ===== Phase 3: Update full_entities and full_relations storage ===== # ===== Phase 3: Update full_entities and full_relations storage =====
if full_entities_storage and full_relations_storage and doc_id: if full_entities_storage and full_relations_storage and doc_id:
try: try:
@ -2718,6 +2773,14 @@ async def extract_entities(
llm_response_cache: BaseKVStorage | None = None, llm_response_cache: BaseKVStorage | None = None,
text_chunks_storage: BaseKVStorage | None = None, text_chunks_storage: BaseKVStorage | None = None,
) -> list: ) -> list:
# Check for cancellation at the start of entity extraction
if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock:
if pipeline_status.get("cancellation_requested", False):
raise PipelineCancelledException(
"User cancelled during entity extraction"
)
use_llm_func: callable = global_config["llm_model_func"] use_llm_func: callable = global_config["llm_model_func"]
entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"] entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
@ -2885,6 +2948,14 @@ async def extract_entities(
async def _process_with_semaphore(chunk): async def _process_with_semaphore(chunk):
async with semaphore: async with semaphore:
# Check for cancellation before processing chunk
if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock:
if pipeline_status.get("cancellation_requested", False):
raise PipelineCancelledException(
"User cancelled during chunk processing"
)
try: try:
return await _process_single_content(chunk) return await _process_single_content(chunk)
except Exception as e: except Exception as e:
@ -3382,10 +3453,10 @@ async def _perform_kg_search(
) )
query_embedding = None query_embedding = None
if query and (kg_chunk_pick_method == "VECTOR" or chunks_vdb): if query and (kg_chunk_pick_method == "VECTOR" or chunks_vdb):
embedding_func_config = text_chunks_db.embedding_func actual_embedding_func = text_chunks_db.embedding_func
if embedding_func_config and embedding_func_config.func: if actual_embedding_func:
try: try:
query_embedding = await embedding_func_config.func([query]) query_embedding = await actual_embedding_func([query])
query_embedding = query_embedding[ query_embedding = query_embedding[
0 0
] # Extract first embedding from batch result ] # Extract first embedding from batch result
@ -4293,25 +4364,21 @@ async def _find_related_text_unit_from_entities(
num_of_chunks = int(max_related_chunks * len(entities_with_chunks) / 2) num_of_chunks = int(max_related_chunks * len(entities_with_chunks) / 2)
# Get embedding function from the text chunks storage # Get embedding function from the text chunks storage
embedding_func_config = text_chunks_db.embedding_func actual_embedding_func = text_chunks_db.embedding_func
if not embedding_func_config: if not actual_embedding_func:
logger.warning("No embedding function found, falling back to WEIGHT method") logger.warning("No embedding function found, falling back to WEIGHT method")
kg_chunk_pick_method = "WEIGHT" kg_chunk_pick_method = "WEIGHT"
else: else:
try: try:
actual_embedding_func = embedding_func_config.func selected_chunk_ids = await pick_by_vector_similarity(
query=query,
selected_chunk_ids = None text_chunks_storage=text_chunks_db,
if actual_embedding_func: chunks_vdb=chunks_vdb,
selected_chunk_ids = await pick_by_vector_similarity( num_of_chunks=num_of_chunks,
query=query, entity_info=entities_with_chunks,
text_chunks_storage=text_chunks_db, embedding_func=actual_embedding_func,
chunks_vdb=chunks_vdb, query_embedding=query_embedding,
num_of_chunks=num_of_chunks, )
entity_info=entities_with_chunks,
embedding_func=actual_embedding_func,
query_embedding=query_embedding,
)
if selected_chunk_ids == []: if selected_chunk_ids == []:
kg_chunk_pick_method = "WEIGHT" kg_chunk_pick_method = "WEIGHT"
@ -4586,24 +4653,21 @@ async def _find_related_text_unit_from_relations(
num_of_chunks = int(max_related_chunks * len(relations_with_chunks) / 2) num_of_chunks = int(max_related_chunks * len(relations_with_chunks) / 2)
# Get embedding function from the text chunks storage # Get embedding function from the text chunks storage
embedding_func_config = text_chunks_db.embedding_func actual_embedding_func = text_chunks_db.embedding_func
if not embedding_func_config: if not actual_embedding_func:
logger.warning("No embedding function found, falling back to WEIGHT method") logger.warning("No embedding function found, falling back to WEIGHT method")
kg_chunk_pick_method = "WEIGHT" kg_chunk_pick_method = "WEIGHT"
else: else:
try: try:
actual_embedding_func = embedding_func_config.func selected_chunk_ids = await pick_by_vector_similarity(
query=query,
if actual_embedding_func: text_chunks_storage=text_chunks_db,
selected_chunk_ids = await pick_by_vector_similarity( chunks_vdb=chunks_vdb,
query=query, num_of_chunks=num_of_chunks,
text_chunks_storage=text_chunks_db, entity_info=relations_with_chunks,
chunks_vdb=chunks_vdb, embedding_func=actual_embedding_func,
num_of_chunks=num_of_chunks, query_embedding=query_embedding,
entity_info=relations_with_chunks, )
embedding_func=actual_embedding_func,
query_embedding=query_embedding,
)
if selected_chunk_ids == []: if selected_chunk_ids == []:
kg_chunk_pick_method = "WEIGHT" kg_chunk_pick_method = "WEIGHT"

View file

@ -1,6 +1,8 @@
from __future__ import annotations from __future__ import annotations
import weakref import weakref
import sys
import asyncio import asyncio
import html import html
import csv import csv
@ -40,6 +42,35 @@ from lightrag.constants import (
SOURCE_IDS_LIMIT_METHOD_FIFO, SOURCE_IDS_LIMIT_METHOD_FIFO,
) )
# Precompile regex pattern for JSON sanitization (module-level, compiled once)
_SURROGATE_PATTERN = re.compile(r"[\uD800-\uDFFF\uFFFE\uFFFF]")
class SafeStreamHandler(logging.StreamHandler):
    """StreamHandler that gracefully handles closed streams during shutdown.

    This handler prevents "ValueError: I/O operation on closed file" errors
    that can occur when pytest or other test frameworks close stdout/stderr
    before Python's logging cleanup runs.  Only ``ValueError`` and ``OSError``
    raised by ``flush()`` / ``close()`` are suppressed; any other exception
    still propagates.
    """

    def flush(self) -> None:
        """Flush the stream, ignoring errors if the stream is closed."""
        try:
            super().flush()
        except (ValueError, OSError):
            # Stream is closed or otherwise unavailable: there is nothing
            # left to flush, so this is deliberately best-effort.
            pass

    def close(self) -> None:
        """Close the handler, ignoring errors if the stream is already closed."""
        try:
            super().close()
        except (ValueError, OSError):
            # Stream is closed or otherwise unavailable: treat the handler
            # as already closed instead of failing during shutdown.
            pass
# Initialize logger with basic configuration # Initialize logger with basic configuration
logger = logging.getLogger("lightrag") logger = logging.getLogger("lightrag")
logger.propagate = False # prevent log message send to root logger logger.propagate = False # prevent log message send to root logger
@ -47,7 +78,7 @@ logger.setLevel(logging.INFO)
# Add console handler if no handlers exist # Add console handler if no handlers exist
if not logger.handlers: if not logger.handlers:
console_handler = logging.StreamHandler() console_handler = SafeStreamHandler()
console_handler.setLevel(logging.INFO) console_handler.setLevel(logging.INFO)
formatter = logging.Formatter("%(levelname)s: %(message)s") formatter = logging.Formatter("%(levelname)s: %(message)s")
console_handler.setFormatter(formatter) console_handler.setFormatter(formatter)
@ -56,8 +87,32 @@ if not logger.handlers:
# Set httpx logging level to WARNING # Set httpx logging level to WARNING
logging.getLogger("httpx").setLevel(logging.WARNING) logging.getLogger("httpx").setLevel(logging.WARNING)
# Precompile regex pattern for JSON sanitization (module-level, compiled once)
def _patch_ascii_colors_console_handler() -> None:
    """Prevent ascii_colors from printing flush errors during interpreter exit.

    Replaces ``ConsoleHandler.handle_error`` with a wrapper that drops the
    report when the exception currently being handled is a ``ValueError`` or
    ``OSError`` (including subclasses such as ``BrokenPipeError``) and the
    message mentions "close".  Every other error is forwarded to the original
    ``handle_error``.

    The function is a no-op when ``ascii_colors`` is not installed, and it
    patches the class at most once (tracked via ``_lightrag_patched``).
    """
    try:
        from ascii_colors import ConsoleHandler
    except ImportError:
        # Optional dependency is missing: nothing to patch.
        return

    if getattr(ConsoleHandler, "_lightrag_patched", False):
        # Already patched; wrapping again would stack wrappers.
        return

    original_handle_error = ConsoleHandler.handle_error

    def _safe_handle_error(self, message: str) -> None:  # type: ignore[override]
        exc_type, _, _ = sys.exc_info()
        # issubclass (not an exact-type match) so that OSError/ValueError
        # subclasses raised by closed streams are suppressed as well.
        closed_stream_error = exc_type is not None and issubclass(
            exc_type, (ValueError, OSError)
        )
        if closed_stream_error and "close" in message.lower():
            return
        original_handle_error(self, message)

    ConsoleHandler.handle_error = _safe_handle_error  # type: ignore[assignment]
    ConsoleHandler._lightrag_patched = True  # type: ignore[attr-defined]
_patch_ascii_colors_console_handler()
# Global import for pypinyin with startup-time logging # Global import for pypinyin with startup-time logging
try: try:
@ -286,8 +341,8 @@ def setup_logger(
logger_instance.handlers = [] # Clear existing handlers logger_instance.handlers = [] # Clear existing handlers
logger_instance.propagate = False logger_instance.propagate = False
# Add console handler # Add console handler with safe stream handling
console_handler = logging.StreamHandler() console_handler = SafeStreamHandler()
console_handler.setFormatter(simple_formatter) console_handler.setFormatter(simple_formatter)
console_handler.setLevel(level) console_handler.setLevel(level)
logger_instance.addHandler(console_handler) logger_instance.addHandler(console_handler)
@ -950,7 +1005,76 @@ def priority_limit_async_func_call(
def wrap_embedding_func_with_attrs(**kwargs): def wrap_embedding_func_with_attrs(**kwargs):
"""Wrap a function with attributes""" """Decorator to add embedding dimension and token limit attributes to embedding functions.
This decorator wraps an async embedding function and returns an EmbeddingFunc instance
that automatically handles dimension parameter injection and attribute management.
WARNING: DO NOT apply this decorator to wrapper functions that call other
decorated embedding functions. This will cause double decoration and parameter
injection conflicts.
Correct usage patterns:
1. Direct implementation (decorated):
```python
@wrap_embedding_func_with_attrs(embedding_dim=1536)
async def my_embed(texts, embedding_dim=None):
# Direct implementation
return embeddings
```
2. Wrapper calling decorated function (DO NOT decorate wrapper):
```python
# my_embed is already decorated above
async def my_wrapper(texts, **kwargs): # ❌ DO NOT decorate this!
# Must call .func to access unwrapped implementation
return await my_embed.func(texts, **kwargs)
```
3. Wrapper calling decorated function (properly decorated):
```python
@wrap_embedding_func_with_attrs(embedding_dim=1536)
async def my_wrapper(texts, **kwargs): # ✅ Can decorate if calling .func
# Calling .func avoids double decoration
return await my_embed.func(texts, **kwargs)
```
The decorated function becomes an EmbeddingFunc instance with:
- embedding_dim: The embedding dimension
- max_token_size: Maximum token limit (optional)
- func: The original unwrapped function (access via .func)
- __call__: Wrapper that injects embedding_dim parameter
Double decoration causes:
- Double injection of embedding_dim parameter
- Incorrect parameter passing to the underlying implementation
- Runtime errors due to parameter conflicts
Args:
embedding_dim: The dimension of embedding vectors
max_token_size: Maximum number of tokens (optional)
send_dimensions: Whether to inject embedding_dim as a keyword argument (optional)
Returns:
A decorator that wraps the function as an EmbeddingFunc instance
Example of correct wrapper implementation:
```python
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(...)
async def openai_embed(texts, ...):
# Base implementation
pass
@wrap_embedding_func_with_attrs(embedding_dim=1536) # Note: No @retry here!
async def azure_openai_embed(texts, ...):
# CRITICAL: Call .func to access unwrapped function
return await openai_embed.func(texts, ...) # ✅ Correct
# return await openai_embed(texts, ...) # ❌ Wrong - double decoration!
```
"""
def final_decro(func) -> EmbeddingFunc: def final_decro(func) -> EmbeddingFunc:
new_func = EmbeddingFunc(**kwargs, func=func) new_func = EmbeddingFunc(**kwargs, func=func)