fix: sync core modules from upstream
Parent: f6f3ed93d3
Commit: d1262e999d
2 changed files with 287 additions and 99 deletions

(first changed file)
@@ -8,6 +8,10 @@ import json_repair
 from typing import Any, AsyncIterator, overload, Literal
 from collections import Counter, defaultdict
 
+from lightrag.exceptions import (
+    PipelineCancelledException,
+    ChunkTokenLimitExceededError,
+)
 from lightrag.utils import (
     logger,
     compute_mdhash_id,
@@ -82,11 +86,11 @@ def _truncate_entity_identifier(
     display_value = identifier[:limit]
     preview = identifier[:20]  # Show first 20 characters as preview
     logger.warning(
-        "%s: %s exceeded %d characters (len: %d, preview: '%s...'",
+        "%s: %s len %d > %d chars (Name: '%s...')",
         chunk_key,
         identifier_role,
-        limit,
         len(identifier),
+        limit,
         preview,
     )
     return display_value
@@ -97,8 +101,8 @@ def chunking_by_token_size(
     content: str,
     split_by_character: str | None = None,
     split_by_character_only: bool = False,
-    overlap_token_size: int = 128,
-    max_token_size: int = 1024,
+    chunk_overlap_token_size: int = 100,
+    chunk_token_size: int = 1200,
 ) -> list[dict[str, Any]]:
     tokens = tokenizer.encode(content)
     results: list[dict[str, Any]] = []
@@ -108,19 +112,30 @@ def chunking_by_token_size(
     if split_by_character_only:
         for chunk in raw_chunks:
             _tokens = tokenizer.encode(chunk)
+            if len(_tokens) > chunk_token_size:
+                logger.warning(
+                    "Chunk split_by_character exceeds token limit: len=%d limit=%d",
+                    len(_tokens),
+                    chunk_token_size,
+                )
+                raise ChunkTokenLimitExceededError(
+                    chunk_tokens=len(_tokens),
+                    chunk_token_limit=chunk_token_size,
+                    chunk_preview=chunk[:120],
+                )
             new_chunks.append((len(_tokens), chunk))
     else:
         for chunk in raw_chunks:
             _tokens = tokenizer.encode(chunk)
-            if len(_tokens) > max_token_size:
+            if len(_tokens) > chunk_token_size:
                 for start in range(
-                    0, len(_tokens), max_token_size - overlap_token_size
+                    0, len(_tokens), chunk_token_size - chunk_overlap_token_size
                 ):
                     chunk_content = tokenizer.decode(
-                        _tokens[start : start + max_token_size]
+                        _tokens[start : start + chunk_token_size]
                    )
                     new_chunks.append(
-                        (min(max_token_size, len(_tokens) - start), chunk_content)
+                        (min(chunk_token_size, len(_tokens) - start), chunk_content)
                     )
             else:
                 new_chunks.append((len(_tokens), chunk))
@@ -134,12 +149,12 @@ def chunking_by_token_size(
             )
     else:
         for index, start in enumerate(
-            range(0, len(tokens), max_token_size - overlap_token_size)
+            range(0, len(tokens), chunk_token_size - chunk_overlap_token_size)
         ):
-            chunk_content = tokenizer.decode(tokens[start : start + max_token_size])
+            chunk_content = tokenizer.decode(tokens[start : start + chunk_token_size])
             results.append(
                 {
-                    "tokens": min(max_token_size, len(tokens) - start),
+                    "tokens": min(chunk_token_size, len(tokens) - start),
                     "content": chunk_content.strip(),
                     "chunk_order_index": index,
                 }
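
Note: the renamed parameters drive a sliding window over the token stream: each chunk holds up to chunk_token_size tokens and the window advances by chunk_token_size - chunk_overlap_token_size. A minimal standalone sketch of that stepping (hypothetical helper; a list of string tokens stands in for the real tokenizer):

    # Hypothetical sketch of the sliding-window split used in the hunks above.
    def window_chunks(
        tokens: list[str],
        chunk_token_size: int = 1200,
        chunk_overlap_token_size: int = 100,
    ) -> list[tuple[int, str]]:
        step = chunk_token_size - chunk_overlap_token_size  # window stride
        chunks = []
        for start in range(0, len(tokens), step):
            window = tokens[start : start + chunk_token_size]
            # The token count is capped by what remains at the tail of the stream.
            chunks.append((min(chunk_token_size, len(tokens) - start), " ".join(window)))
        return chunks

    # 10 tokens, window 4, overlap 1 -> windows start at 0, 3, 6, 9.
    assert [n for n, _ in window_chunks([str(i) for i in range(10)], 4, 1)] == [4, 4, 4, 1]
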
@@ -344,6 +359,20 @@ async def _summarize_descriptions(
         llm_response_cache=llm_response_cache,
         cache_type="summary",
     )
 
+    # Check summary token length against embedding limit
+    embedding_token_limit = global_config.get("embedding_token_limit")
+    if embedding_token_limit is not None and summary:
+        tokenizer = global_config["tokenizer"]
+        summary_token_count = len(tokenizer.encode(summary))
+        threshold = int(embedding_token_limit * 0.9)
+
+        if summary_token_count > threshold:
+            logger.warning(
+                f"Summary tokens ({summary_token_count}) exceeds 90% of embedding limit "
+                f"({embedding_token_limit}) for {description_type}: {description_name}"
+            )
+
     return summary
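
For scale (hypothetical numbers): with embedding_token_limit = 8192 the threshold is int(8192 * 0.9) = 7372 tokens, so a 7500-token summary logs the warning while a 7000-token one passes silently.
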
@@ -988,13 +1017,13 @@ async def _process_extraction_result(
         relationship_data["src_id"],
         DEFAULT_ENTITY_NAME_MAX_LENGTH,
         chunk_key,
-        "Relationship source entity",
+        "Relation entity",
     )
     truncated_target = _truncate_entity_identifier(
         relationship_data["tgt_id"],
         DEFAULT_ENTITY_NAME_MAX_LENGTH,
         chunk_key,
-        "Relationship target entity",
+        "Relation entity",
     )
     relationship_data["src_id"] = truncated_source
     relationship_data["tgt_id"] = truncated_target
@@ -1694,6 +1723,12 @@ async def _merge_nodes_then_upsert(
         logger.error(f"Entity {entity_name} has no description")
         raise ValueError(f"Entity {entity_name} has no description")
 
+    # Check for cancellation before LLM summary
+    if pipeline_status is not None and pipeline_status_lock is not None:
+        async with pipeline_status_lock:
+            if pipeline_status.get("cancellation_requested", False):
+                raise PipelineCancelledException("User cancelled during entity summary")
+
     # 8. Get summary description an LLM usage status
     description, llm_was_used = await _handle_entity_relation_summary(
         "Entity",
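
This guard recurs at every phase boundary in the commit: read the shared status dict under the async lock and abort if the cancel flag is set. The pattern in isolation (the helper shape is hypothetical; the flag name and exception come from the diff):

    import asyncio

    class PipelineCancelledException(Exception):
        """Raised when a user-requested cancellation is observed."""

    async def raise_if_cancelled(
        pipeline_status: dict | None,
        pipeline_status_lock: asyncio.Lock | None,
        where: str,
    ) -> None:
        # Both objects are optional: callers without pipeline tracking skip the check.
        if pipeline_status is None or pipeline_status_lock is None:
            return
        # The lock serializes this read against writers flipping the flag.
        async with pipeline_status_lock:
            if pipeline_status.get("cancellation_requested", False):
                raise PipelineCancelledException(f"User cancelled during {where}")
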
@@ -2015,6 +2050,14 @@ async def _merge_edges_then_upsert(
         logger.error(f"Relation {src_id}~{tgt_id} has no description")
         raise ValueError(f"Relation {src_id}~{tgt_id} has no description")
 
+    # Check for cancellation before LLM summary
+    if pipeline_status is not None and pipeline_status_lock is not None:
+        async with pipeline_status_lock:
+            if pipeline_status.get("cancellation_requested", False):
+                raise PipelineCancelledException(
+                    "User cancelled during relation summary"
+                )
+
     # 8. Get summary description an LLM usage status
     description, llm_was_used = await _handle_entity_relation_summary(
         "Relation",
@@ -2396,6 +2439,12 @@ async def merge_nodes_and_edges(
         file_path: File path for logging
     """
 
+    # Check for cancellation at the start of merge
+    if pipeline_status is not None and pipeline_status_lock is not None:
+        async with pipeline_status_lock:
+            if pipeline_status.get("cancellation_requested", False):
+                raise PipelineCancelledException("User cancelled during merge phase")
+
     # Collect all nodes and edges from all chunks
     all_nodes = defaultdict(list)
     all_edges = defaultdict(list)
@@ -2432,6 +2481,14 @@ async def merge_nodes_and_edges(
 
     async def _locked_process_entity_name(entity_name, entities):
         async with semaphore:
+            # Check for cancellation before processing entity
+            if pipeline_status is not None and pipeline_status_lock is not None:
+                async with pipeline_status_lock:
+                    if pipeline_status.get("cancellation_requested", False):
+                        raise PipelineCancelledException(
+                            "User cancelled during entity merge"
+                        )
+
             workspace = global_config.get("workspace", "")
             namespace = f"{workspace}:GraphDB" if workspace else "GraphDB"
             async with get_storage_keyed_lock(
@@ -2454,9 +2511,7 @@ async def merge_nodes_and_edges(
                 return entity_data
 
             except Exception as e:
-                error_msg = (
-                    f"Critical error in entity processing for `{entity_name}`: {e}"
-                )
+                error_msg = f"Error processing entity `{entity_name}`: {e}"
                 logger.error(error_msg)
 
                 # Try to update pipeline status, but don't let status update failure affect main exception
@@ -2492,36 +2547,32 @@ async def merge_nodes_and_edges(
             entity_tasks, return_when=asyncio.FIRST_EXCEPTION
         )
 
        # Check if any task raised an exception and ensure all exceptions are retrieved
         first_exception = None
-        successful_results = []
+        processed_entities = []
 
         for task in done:
             try:
-                exception = task.exception()
-                if exception is not None:
-                    if first_exception is None:
-                        first_exception = exception
-                else:
-                    successful_results.append(task.result())
-            except Exception as e:
+                result = task.result()
+            except BaseException as e:
                 if first_exception is None:
                     first_exception = e
+            else:
+                processed_entities.append(result)
+
+        if pending:
+            for task in pending:
+                task.cancel()
+            pending_results = await asyncio.gather(*pending, return_exceptions=True)
+            for result in pending_results:
+                if isinstance(result, BaseException):
+                    if first_exception is None:
+                        first_exception = result
+                else:
+                    processed_entities.append(result)
 
         # If any task failed, cancel all pending tasks and raise the first exception
         if first_exception is not None:
-            # Cancel all pending tasks
-            for pending_task in pending:
-                pending_task.cancel()
-            # Wait for cancellation to complete
-            if pending:
-                await asyncio.wait(pending)
             # Re-raise the first exception to notify the caller
             raise first_exception
 
-        # If all tasks completed successfully, collect results
-        processed_entities = [task.result() for task in entity_tasks]
-
         # ===== Phase 2: Process all relationships concurrently =====
         log_message = f"Phase 2: Processing {total_relations_count} relations from {doc_id} (async: {graph_max_async})"
         logger.info(log_message)
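
The rewrite above replaces "collect results only if nothing failed" with a single pass that retrieves every task's outcome, so no exception is left unretrieved (which would otherwise produce "Task exception was never retrieved" noise). A compressed sketch of the same fail-fast pattern (the task body is hypothetical):

    import asyncio

    async def work(i: int) -> int:
        await asyncio.sleep(i / 10)
        if i == 2:
            raise RuntimeError("boom")
        return i

    async def run_all() -> list[int]:
        tasks = [asyncio.create_task(work(i)) for i in range(5)]
        done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION)

        first_exception, results = None, []
        for task in done:
            try:
                results.append(task.result())  # .result() re-raises a task's exception
            except BaseException as e:
                if first_exception is None:  # keep only the first failure
                    first_exception = e

        # Cancel stragglers, then drain them so their outcomes are also retrieved.
        for task in pending:
            task.cancel()
        for outcome in await asyncio.gather(*pending, return_exceptions=True):
            if isinstance(outcome, BaseException):
                if first_exception is None:
                    first_exception = outcome
            else:
                results.append(outcome)

        if first_exception is not None:
            raise first_exception
        return results

    # asyncio.run(run_all()) raises RuntimeError("boom") after cancelling the rest.
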
@@ -2531,6 +2582,14 @@ async def merge_nodes_and_edges(
 
     async def _locked_process_edges(edge_key, edges):
         async with semaphore:
+            # Check for cancellation before processing edges
+            if pipeline_status is not None and pipeline_status_lock is not None:
+                async with pipeline_status_lock:
+                    if pipeline_status.get("cancellation_requested", False):
+                        raise PipelineCancelledException(
+                            "User cancelled during relation merge"
+                        )
+
             workspace = global_config.get("workspace", "")
             namespace = f"{workspace}:GraphDB" if workspace else "GraphDB"
             sorted_edge_key = sorted([edge_key[0], edge_key[1]])
@@ -2566,7 +2625,7 @@ async def merge_nodes_and_edges(
                 return edge_data, added_entities
 
             except Exception as e:
-                error_msg = f"Critical error in relationship processing for `{sorted_edge_key}`: {e}"
+                error_msg = f"Error processing relation `{sorted_edge_key}`: {e}"
                 logger.error(error_msg)
 
                 # Try to update pipeline status, but don't let status update failure affect main exception
@@ -2604,40 +2663,36 @@ async def merge_nodes_and_edges(
             edge_tasks, return_when=asyncio.FIRST_EXCEPTION
         )
 
         # Check if any task raised an exception and ensure all exceptions are retrieved
         first_exception = None
-        successful_results = []
 
         for task in done:
             try:
-                exception = task.exception()
-                if exception is not None:
-                    if first_exception is None:
-                        first_exception = exception
-                else:
-                    successful_results.append(task.result())
-            except Exception as e:
+                edge_data, added_entities = task.result()
+            except BaseException as e:
                 if first_exception is None:
                     first_exception = e
+            else:
+                if edge_data is not None:
+                    processed_edges.append(edge_data)
+                all_added_entities.extend(added_entities)
 
+        if pending:
+            for task in pending:
+                task.cancel()
+            pending_results = await asyncio.gather(*pending, return_exceptions=True)
+            for result in pending_results:
+                if isinstance(result, BaseException):
+                    if first_exception is None:
+                        first_exception = result
+                else:
+                    edge_data, added_entities = result
+                    if edge_data is not None:
+                        processed_edges.append(edge_data)
+                    all_added_entities.extend(added_entities)
 
         # If any task failed, cancel all pending tasks and raise the first exception
         if first_exception is not None:
-            # Cancel all pending tasks
-            for pending_task in pending:
-                pending_task.cancel()
-            # Wait for cancellation to complete
-            if pending:
-                await asyncio.wait(pending)
             # Re-raise the first exception to notify the caller
             raise first_exception
 
-        # If all tasks completed successfully, collect results
-        for task in edge_tasks:
-            edge_data, added_entities = task.result()
-            if edge_data is not None:
-                processed_edges.append(edge_data)
-                all_added_entities.extend(added_entities)
-
         # ===== Phase 3: Update full_entities and full_relations storage =====
         if full_entities_storage and full_relations_storage and doc_id:
             try:
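
The relation branch mirrors the entity branch rewritten earlier in this diff; the only differences are that each task yields an (edge_data, added_entities) tuple and that None edge_data entries are skipped.
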
@@ -2718,6 +2773,14 @@ async def extract_entities(
     llm_response_cache: BaseKVStorage | None = None,
     text_chunks_storage: BaseKVStorage | None = None,
 ) -> list:
+    # Check for cancellation at the start of entity extraction
+    if pipeline_status is not None and pipeline_status_lock is not None:
+        async with pipeline_status_lock:
+            if pipeline_status.get("cancellation_requested", False):
+                raise PipelineCancelledException(
+                    "User cancelled during entity extraction"
+                )
+
     use_llm_func: callable = global_config["llm_model_func"]
     entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
@@ -2885,6 +2948,14 @@ async def extract_entities(
 
     async def _process_with_semaphore(chunk):
         async with semaphore:
+            # Check for cancellation before processing chunk
+            if pipeline_status is not None and pipeline_status_lock is not None:
+                async with pipeline_status_lock:
+                    if pipeline_status.get("cancellation_requested", False):
+                        raise PipelineCancelledException(
+                            "User cancelled during chunk processing"
+                        )
+
             try:
                 return await _process_single_content(chunk)
             except Exception as e:
@@ -3382,10 +3453,10 @@ async def _perform_kg_search(
     )
     query_embedding = None
     if query and (kg_chunk_pick_method == "VECTOR" or chunks_vdb):
-        embedding_func_config = text_chunks_db.embedding_func
-        if embedding_func_config and embedding_func_config.func:
+        actual_embedding_func = text_chunks_db.embedding_func
+        if actual_embedding_func:
             try:
-                query_embedding = await embedding_func_config.func([query])
+                query_embedding = await actual_embedding_func([query])
                 query_embedding = query_embedding[
                     0
                 ]  # Extract first embedding from batch result
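
This change works because text_chunks_db.embedding_func is itself awaitable-callable, so call sites no longer need to reach into .func. A minimal sketch of that wrapper shape, inferred from the wrap_embedding_func_with_attrs docstring later in this diff (the exact fields shown are assumptions):

    from dataclasses import dataclass
    from typing import Any, Awaitable, Callable

    @dataclass
    class EmbeddingFunc:
        embedding_dim: int
        func: Callable[..., Awaitable[Any]]  # the unwrapped implementation
        max_token_size: int | None = None

        async def __call__(self, texts: list[str], **kwargs: Any) -> Any:
            # The instance is callable, so `await embedding_func([query])`
            # works without unwrapping .func first.
            return await self.func(texts, **kwargs)
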
@@ -4293,25 +4364,21 @@ async def _find_related_text_unit_from_entities(
     num_of_chunks = int(max_related_chunks * len(entities_with_chunks) / 2)
 
     # Get embedding function from global config
-    embedding_func_config = text_chunks_db.embedding_func
-    if not embedding_func_config:
+    actual_embedding_func = text_chunks_db.embedding_func
+    if not actual_embedding_func:
         logger.warning("No embedding function found, falling back to WEIGHT method")
         kg_chunk_pick_method = "WEIGHT"
     else:
         try:
-            actual_embedding_func = embedding_func_config.func
-
-            selected_chunk_ids = None
-            if actual_embedding_func:
-                selected_chunk_ids = await pick_by_vector_similarity(
-                    query=query,
-                    text_chunks_storage=text_chunks_db,
-                    chunks_vdb=chunks_vdb,
-                    num_of_chunks=num_of_chunks,
-                    entity_info=entities_with_chunks,
-                    embedding_func=actual_embedding_func,
-                    query_embedding=query_embedding,
-                )
+            selected_chunk_ids = await pick_by_vector_similarity(
+                query=query,
+                text_chunks_storage=text_chunks_db,
+                chunks_vdb=chunks_vdb,
+                num_of_chunks=num_of_chunks,
+                entity_info=entities_with_chunks,
+                embedding_func=actual_embedding_func,
+                query_embedding=query_embedding,
+            )
 
             if selected_chunk_ids == []:
                 kg_chunk_pick_method = "WEIGHT"
@@ -4586,24 +4653,21 @@ async def _find_related_text_unit_from_relations(
     num_of_chunks = int(max_related_chunks * len(relations_with_chunks) / 2)
 
     # Get embedding function from global config
-    embedding_func_config = text_chunks_db.embedding_func
-    if not embedding_func_config:
+    actual_embedding_func = text_chunks_db.embedding_func
+    if not actual_embedding_func:
         logger.warning("No embedding function found, falling back to WEIGHT method")
         kg_chunk_pick_method = "WEIGHT"
     else:
         try:
-            actual_embedding_func = embedding_func_config.func
-
-            if actual_embedding_func:
-                selected_chunk_ids = await pick_by_vector_similarity(
-                    query=query,
-                    text_chunks_storage=text_chunks_db,
-                    chunks_vdb=chunks_vdb,
-                    num_of_chunks=num_of_chunks,
-                    entity_info=relations_with_chunks,
-                    embedding_func=actual_embedding_func,
-                    query_embedding=query_embedding,
-                )
+            selected_chunk_ids = await pick_by_vector_similarity(
+                query=query,
+                text_chunks_storage=text_chunks_db,
+                chunks_vdb=chunks_vdb,
+                num_of_chunks=num_of_chunks,
+                entity_info=relations_with_chunks,
+                embedding_func=actual_embedding_func,
+                query_embedding=query_embedding,
+            )
 
             if selected_chunk_ids == []:
                 kg_chunk_pick_method = "WEIGHT"
(second changed file)

@@ -1,6 +1,8 @@
 from __future__ import annotations
+import weakref
+import sys
 
 import asyncio
 import html
 import csv
@@ -40,6 +42,35 @@ from lightrag.constants import (
     SOURCE_IDS_LIMIT_METHOD_FIFO,
 )
 
+# Precompile regex pattern for JSON sanitization (module-level, compiled once)
+_SURROGATE_PATTERN = re.compile(r"[\uD800-\uDFFF\uFFFE\uFFFF]")
+
+
+class SafeStreamHandler(logging.StreamHandler):
+    """StreamHandler that gracefully handles closed streams during shutdown.
+
+    This handler prevents "ValueError: I/O operation on closed file" errors
+    that can occur when pytest or other test frameworks close stdout/stderr
+    before Python's logging cleanup runs.
+    """
+
+    def flush(self):
+        """Flush the stream, ignoring errors if the stream is closed."""
+        try:
+            super().flush()
+        except (ValueError, OSError):
+            # Stream is closed or otherwise unavailable, silently ignore
+            pass
+
+    def close(self):
+        """Close the handler, ignoring errors if the stream is already closed."""
+        try:
+            super().close()
+        except (ValueError, OSError):
+            # Stream is closed or otherwise unavailable, silently ignore
+            pass
+
+
 # Initialize logger with basic configuration
 logger = logging.getLogger("lightrag")
 logger.propagate = False  # prevent log message send to root logger
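
A short usage sketch of the handler added above (the logger name is hypothetical; closing sys.stderr simulates pytest teardown):

    import logging
    import sys

    handler = SafeStreamHandler(sys.stderr)  # the class defined in the hunk above
    handler.setFormatter(logging.Formatter("%(levelname)s: %(message)s"))

    log = logging.getLogger("demo")
    log.addHandler(handler)
    log.error("normal message")  # behaves exactly like logging.StreamHandler

    sys.stderr.close()   # simulate a test framework closing the stream early
    logging.shutdown()   # flush()/close() now swallow the ValueError/OSError
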
@@ -47,7 +78,7 @@ logger.setLevel(logging.INFO)
 
 # Add console handler if no handlers exist
 if not logger.handlers:
-    console_handler = logging.StreamHandler()
+    console_handler = SafeStreamHandler()
     console_handler.setLevel(logging.INFO)
     formatter = logging.Formatter("%(levelname)s: %(message)s")
     console_handler.setFormatter(formatter)
@@ -56,8 +87,32 @@ if not logger.handlers:
 # Set httpx logging level to WARNING
 logging.getLogger("httpx").setLevel(logging.WARNING)
 
-# Precompile regex pattern for JSON sanitization (module-level, compiled once)
-_SURROGATE_PATTERN = re.compile(r"[\uD800-\uDFFF\uFFFE\uFFFF]")
 
+def _patch_ascii_colors_console_handler() -> None:
+    """Prevent ascii_colors from printing flush errors during interpreter exit."""
+
+    try:
+        from ascii_colors import ConsoleHandler
+    except ImportError:
+        return
+
+    if getattr(ConsoleHandler, "_lightrag_patched", False):
+        return
+
+    original_handle_error = ConsoleHandler.handle_error
+
+    def _safe_handle_error(self, message: str) -> None:  # type: ignore[override]
+        exc_type, _, _ = sys.exc_info()
+        if exc_type in (ValueError, OSError) and "close" in message.lower():
+            return
+        original_handle_error(self, message)
+
+    ConsoleHandler.handle_error = _safe_handle_error  # type: ignore[assignment]
+    ConsoleHandler._lightrag_patched = True  # type: ignore[attr-defined]
+
+
+_patch_ascii_colors_console_handler()
+
+
 # Global import for pypinyin with startup-time logging
 try:
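
The patch above follows a reusable shape: bail out on ImportError, guard against double application with a marker attribute, close over the original method, and delegate anything it does not recognize. The same shape in miniature (hypothetical target class):

    class Target:
        def handle_error(self, message: str) -> None:
            print("original:", message)

    def patch_target() -> None:
        if getattr(Target, "_patched", False):  # idempotence guard
            return
        original = Target.handle_error  # captured by the closure below

        def safe_handle_error(self, message: str) -> None:
            if "ignore me" in message:  # filter the one noisy case
                return
            original(self, message)    # everything else reaches the original

        Target.handle_error = safe_handle_error
        Target._patched = True

    patch_target()
    patch_target()  # no-op: the guard attribute is already set
    Target().handle_error("ignore me")     # silently dropped
    Target().handle_error("real problem")  # prints "original: real problem"
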
@@ -286,8 +341,8 @@ def setup_logger(
     logger_instance.handlers = []  # Clear existing handlers
     logger_instance.propagate = False
 
-    # Add console handler
-    console_handler = logging.StreamHandler()
+    # Add console handler with safe stream handling
+    console_handler = SafeStreamHandler()
     console_handler.setFormatter(simple_formatter)
     console_handler.setLevel(level)
     logger_instance.addHandler(console_handler)
@@ -950,7 +1005,76 @@ def priority_limit_async_func_call(
 
 
 def wrap_embedding_func_with_attrs(**kwargs):
-    """Wrap a function with attributes"""
+    """Decorator to add embedding dimension and token limit attributes to embedding functions.
+
+    This decorator wraps an async embedding function and returns an EmbeddingFunc instance
+    that automatically handles dimension parameter injection and attribute management.
+
+    WARNING: DO NOT apply this decorator to wrapper functions that call other
+    decorated embedding functions. This will cause double decoration and parameter
+    injection conflicts.
+
+    Correct usage patterns:
+
+    1. Direct implementation (decorated):
+        ```python
+        @wrap_embedding_func_with_attrs(embedding_dim=1536)
+        async def my_embed(texts, embedding_dim=None):
+            # Direct implementation
+            return embeddings
+        ```
+
+    2. Wrapper calling decorated function (DO NOT decorate wrapper):
+        ```python
+        # my_embed is already decorated above
+
+        async def my_wrapper(texts, **kwargs):  # ❌ DO NOT decorate this!
+            # Must call .func to access unwrapped implementation
+            return await my_embed.func(texts, **kwargs)
+        ```
+
+    3. Wrapper calling decorated function (properly decorated):
+        ```python
+        @wrap_embedding_func_with_attrs(embedding_dim=1536)
+        async def my_wrapper(texts, **kwargs):  # ✅ Can decorate if calling .func
+            # Calling .func avoids double decoration
+            return await my_embed.func(texts, **kwargs)
+        ```
+
+    The decorated function becomes an EmbeddingFunc instance with:
+    - embedding_dim: The embedding dimension
+    - max_token_size: Maximum token limit (optional)
+    - func: The original unwrapped function (access via .func)
+    - __call__: Wrapper that injects embedding_dim parameter
+
+    Double decoration causes:
+    - Double injection of embedding_dim parameter
+    - Incorrect parameter passing to the underlying implementation
+    - Runtime errors due to parameter conflicts
+
+    Args:
+        embedding_dim: The dimension of embedding vectors
+        max_token_size: Maximum number of tokens (optional)
+        send_dimensions: Whether to inject embedding_dim as a keyword argument (optional)
+
+    Returns:
+        A decorator that wraps the function as an EmbeddingFunc instance
+
+    Example of correct wrapper implementation:
+        ```python
+        @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
+        @retry(...)
+        async def openai_embed(texts, ...):
+            # Base implementation
+            pass
+
+        @wrap_embedding_func_with_attrs(embedding_dim=1536)  # Note: No @retry here!
+        async def azure_openai_embed(texts, ...):
+            # CRITICAL: Call .func to access unwrapped function
+            return await openai_embed.func(texts, ...)  # ✅ Correct
+            # return await openai_embed(texts, ...)  # ❌ Wrong - double decoration!
+        ```
+    """
 
     def final_decro(func) -> EmbeddingFunc:
         new_func = EmbeddingFunc(**kwargs, func=func)