feat (token_tracking): added tracking token to both query and insert endpoints --and consequently pipeline
This commit is contained in:
parent
cd664de057
commit
3a2d3ddb9f
11 changed files with 588 additions and 104 deletions
|
|
@ -397,6 +397,15 @@ def parse_args() -> argparse.Namespace:
|
|||
"EMBEDDING_BATCH_NUM", DEFAULT_EMBEDDING_BATCH_NUM, int
|
||||
)
|
||||
|
||||
# Token tracking configuration
|
||||
parser.add_argument(
|
||||
"--enable-token-tracking",
|
||||
action="store_true",
|
||||
default=get_env_value("ENABLE_TOKEN_TRACKING", False, bool),
|
||||
help="Enable token usage tracking for LLM calls (default: from env or False)",
|
||||
)
|
||||
args.enable_token_tracking = get_env_value("ENABLE_TOKEN_TRACKING", False, bool)
|
||||
|
||||
ollama_server_infos.LIGHTRAG_NAME = args.simulated_model_name
|
||||
ollama_server_infos.LIGHTRAG_TAG = args.simulated_model_tag
|
||||
|
||||
|
|
|
|||
|
|
@ -301,7 +301,10 @@ def create_app(args):
|
|||
Path(args.working_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def create_optimized_openai_llm_func(
|
||||
config_cache: LLMConfigCache, args, llm_timeout: int
|
||||
config_cache: LLMConfigCache,
|
||||
args,
|
||||
llm_timeout: int,
|
||||
enable_token_tracking=False,
|
||||
):
|
||||
"""Create optimized OpenAI LLM function with pre-processed configuration"""
|
||||
|
||||
|
|
@ -332,13 +335,19 @@ def create_app(args):
|
|||
history_messages=history_messages,
|
||||
base_url=args.llm_binding_host,
|
||||
api_key=args.llm_binding_api_key,
|
||||
token_tracker=getattr(app.state, "token_tracker", None)
|
||||
if enable_token_tracking and hasattr(app.state, "token_tracker")
|
||||
else None,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return optimized_openai_alike_model_complete
|
||||
|
||||
def create_optimized_azure_openai_llm_func(
|
||||
config_cache: LLMConfigCache, args, llm_timeout: int
|
||||
config_cache: LLMConfigCache,
|
||||
args,
|
||||
llm_timeout: int,
|
||||
enable_token_tracking=False,
|
||||
):
|
||||
"""Create optimized Azure OpenAI LLM function with pre-processed configuration"""
|
||||
|
||||
|
|
@ -359,8 +368,8 @@ def create_app(args):
|
|||
|
||||
# Use pre-processed configuration to avoid repeated parsing
|
||||
kwargs["timeout"] = llm_timeout
|
||||
if config_cache.openai_llm_options:
|
||||
kwargs.update(config_cache.openai_llm_options)
|
||||
if config_cache.azure_openai_llm_options:
|
||||
kwargs.update(config_cache.azure_openai_llm_options)
|
||||
|
||||
return await azure_openai_complete_if_cache(
|
||||
args.llm_model,
|
||||
|
|
@ -370,12 +379,15 @@ def create_app(args):
|
|||
base_url=args.llm_binding_host,
|
||||
api_key=os.getenv("AZURE_OPENAI_API_KEY", args.llm_binding_api_key),
|
||||
api_version=os.getenv("AZURE_OPENAI_API_VERSION", "2024-08-01-preview"),
|
||||
token_tracker=getattr(app.state, "token_tracker", None)
|
||||
if enable_token_tracking and hasattr(app.state, "token_tracker")
|
||||
else None,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return optimized_azure_openai_model_complete
|
||||
|
||||
def create_llm_model_func(binding: str):
|
||||
def create_llm_model_func(binding: str, enable_token_tracking=False):
|
||||
"""
|
||||
Create LLM model function based on binding type.
|
||||
Uses optimized functions for OpenAI bindings and lazy import for others.
|
||||
|
|
@ -384,21 +396,42 @@ def create_app(args):
|
|||
if binding == "lollms":
|
||||
from lightrag.llm.lollms import lollms_model_complete
|
||||
|
||||
return lollms_model_complete
|
||||
async def lollms_model_complete_with_tracker(*args, **kwargs):
|
||||
# Add token tracker if enabled
|
||||
if enable_token_tracking and hasattr(app.state, "token_tracker"):
|
||||
kwargs["token_tracker"] = app.state.token_tracker
|
||||
return await lollms_model_complete(*args, **kwargs)
|
||||
|
||||
return lollms_model_complete_with_tracker
|
||||
elif binding == "ollama":
|
||||
from lightrag.llm.ollama import ollama_model_complete
|
||||
|
||||
return ollama_model_complete
|
||||
async def ollama_model_complete_with_tracker(*args, **kwargs):
|
||||
# Add token tracker if enabled
|
||||
if enable_token_tracking and hasattr(app.state, "token_tracker"):
|
||||
kwargs["token_tracker"] = app.state.token_tracker
|
||||
return await ollama_model_complete(*args, **kwargs)
|
||||
|
||||
return ollama_model_complete_with_tracker
|
||||
elif binding == "aws_bedrock":
|
||||
return bedrock_model_complete # Already defined locally
|
||||
|
||||
async def bedrock_model_complete_with_tracker(*args, **kwargs):
|
||||
# Add token tracker if enabled
|
||||
if enable_token_tracking and hasattr(app.state, "token_tracker"):
|
||||
kwargs["token_tracker"] = app.state.token_tracker
|
||||
return await bedrock_model_complete(*args, **kwargs)
|
||||
|
||||
return bedrock_model_complete_with_tracker
|
||||
elif binding == "azure_openai":
|
||||
# Use optimized function with pre-processed configuration
|
||||
return create_optimized_azure_openai_llm_func(
|
||||
config_cache, args, llm_timeout
|
||||
config_cache, args, llm_timeout, enable_token_tracking
|
||||
)
|
||||
else: # openai and compatible
|
||||
# Use optimized function with pre-processed configuration
|
||||
return create_optimized_openai_llm_func(config_cache, args, llm_timeout)
|
||||
return create_optimized_openai_llm_func(
|
||||
config_cache, args, llm_timeout, enable_token_tracking
|
||||
)
|
||||
except ImportError as e:
|
||||
raise Exception(f"Failed to import {binding} LLM binding: {e}")
|
||||
|
||||
|
|
@ -422,7 +455,14 @@ def create_app(args):
|
|||
return {}
|
||||
|
||||
def create_optimized_embedding_function(
|
||||
config_cache: LLMConfigCache, binding, model, host, api_key, dimensions, args
|
||||
config_cache: LLMConfigCache,
|
||||
binding,
|
||||
model,
|
||||
host,
|
||||
api_key,
|
||||
dimensions,
|
||||
args,
|
||||
enable_token_tracking=False,
|
||||
):
|
||||
"""
|
||||
Create optimized embedding function with pre-processed configuration for applicable bindings.
|
||||
|
|
@ -435,7 +475,13 @@ def create_app(args):
|
|||
from lightrag.llm.lollms import lollms_embed
|
||||
|
||||
return await lollms_embed(
|
||||
texts, embed_model=model, host=host, api_key=api_key
|
||||
texts,
|
||||
embed_model=model,
|
||||
host=host,
|
||||
api_key=api_key,
|
||||
token_tracker=getattr(app.state, "token_tracker", None)
|
||||
if enable_token_tracking and hasattr(app.state, "token_tracker")
|
||||
else None,
|
||||
)
|
||||
elif binding == "ollama":
|
||||
from lightrag.llm.ollama import ollama_embed
|
||||
|
|
@ -455,26 +501,54 @@ def create_app(args):
|
|||
host=host,
|
||||
api_key=api_key,
|
||||
options=ollama_options,
|
||||
token_tracker=getattr(app.state, "token_tracker", None)
|
||||
if enable_token_tracking and hasattr(app.state, "token_tracker")
|
||||
else None,
|
||||
)
|
||||
elif binding == "azure_openai":
|
||||
from lightrag.llm.azure_openai import azure_openai_embed
|
||||
|
||||
return await azure_openai_embed(texts, model=model, api_key=api_key)
|
||||
return await azure_openai_embed(
|
||||
texts,
|
||||
model=model,
|
||||
api_key=api_key,
|
||||
token_tracker=getattr(app.state, "token_tracker", None)
|
||||
if enable_token_tracking and hasattr(app.state, "token_tracker")
|
||||
else None,
|
||||
)
|
||||
elif binding == "aws_bedrock":
|
||||
from lightrag.llm.bedrock import bedrock_embed
|
||||
|
||||
return await bedrock_embed(texts, model=model)
|
||||
return await bedrock_embed(
|
||||
texts,
|
||||
model=model,
|
||||
token_tracker=getattr(app.state, "token_tracker", None)
|
||||
if enable_token_tracking and hasattr(app.state, "token_tracker")
|
||||
else None,
|
||||
)
|
||||
elif binding == "jina":
|
||||
from lightrag.llm.jina import jina_embed
|
||||
|
||||
return await jina_embed(
|
||||
texts, dimensions=dimensions, base_url=host, api_key=api_key
|
||||
texts,
|
||||
dimensions=dimensions,
|
||||
base_url=host,
|
||||
api_key=api_key,
|
||||
token_tracker=getattr(app.state, "token_tracker", None)
|
||||
if enable_token_tracking and hasattr(app.state, "token_tracker")
|
||||
else None,
|
||||
)
|
||||
else: # openai and compatible
|
||||
from lightrag.llm.openai import openai_embed
|
||||
|
||||
return await openai_embed(
|
||||
texts, model=model, base_url=host, api_key=api_key
|
||||
texts,
|
||||
model=model,
|
||||
base_url=host,
|
||||
api_key=api_key,
|
||||
token_tracker=getattr(app.state, "token_tracker", None)
|
||||
if enable_token_tracking and hasattr(app.state, "token_tracker")
|
||||
else None,
|
||||
)
|
||||
except ImportError as e:
|
||||
raise Exception(f"Failed to import {binding} embedding: {e}")
|
||||
|
|
@ -594,7 +668,9 @@ def create_app(args):
|
|||
rag = LightRAG(
|
||||
working_dir=args.working_dir,
|
||||
workspace=args.workspace,
|
||||
llm_model_func=create_llm_model_func(args.llm_binding),
|
||||
llm_model_func=create_llm_model_func(
|
||||
args.llm_binding, args.enable_token_tracking
|
||||
),
|
||||
llm_model_name=args.llm_model,
|
||||
llm_model_max_async=args.max_async,
|
||||
summary_max_tokens=args.summary_max_tokens,
|
||||
|
|
@ -604,7 +680,16 @@ def create_app(args):
|
|||
llm_model_kwargs=create_llm_model_kwargs(
|
||||
args.llm_binding, args, llm_timeout
|
||||
),
|
||||
embedding_func=embedding_func,
|
||||
embedding_func=create_optimized_embedding_function(
|
||||
config_cache,
|
||||
args.embedding_binding,
|
||||
args.embedding_model,
|
||||
args.embedding_binding_host,
|
||||
args.embedding_binding_api_key,
|
||||
args.embedding_dim,
|
||||
args,
|
||||
args.enable_token_tracking,
|
||||
),
|
||||
default_llm_timeout=llm_timeout,
|
||||
default_embedding_timeout=embedding_timeout,
|
||||
kv_storage=args.kv_storage,
|
||||
|
|
@ -629,6 +714,17 @@ def create_app(args):
|
|||
logger.error(f"Failed to initialize LightRAG: {e}")
|
||||
raise
|
||||
|
||||
# Initialize token tracking if enabled
|
||||
token_tracker = None
|
||||
if args.enable_token_tracking:
|
||||
from lightrag.utils import TokenTracker
|
||||
|
||||
token_tracker = TokenTracker()
|
||||
logger.info("Token tracking enabled")
|
||||
|
||||
# Add token tracker to the app state for use in endpoints
|
||||
app.state.token_tracker = token_tracker
|
||||
|
||||
# Add routes
|
||||
app.include_router(
|
||||
create_document_routes(
|
||||
|
|
@ -637,7 +733,7 @@ def create_app(args):
|
|||
api_key,
|
||||
)
|
||||
)
|
||||
app.include_router(create_query_routes(rag, api_key, args.top_k))
|
||||
app.include_router(create_query_routes(rag, api_key, args.top_k, token_tracker))
|
||||
app.include_router(create_graph_routes(rag, api_key))
|
||||
|
||||
# Add Ollama API routes
|
||||
|
|
|
|||
|
|
@ -243,6 +243,7 @@ class InsertResponse(BaseModel):
|
|||
status: Status of the operation (success, duplicated, partial_success, failure)
|
||||
message: Detailed message describing the operation result
|
||||
track_id: Tracking ID for monitoring processing status
|
||||
token_usage: Token usage statistics for the insertion (prompt_tokens, completion_tokens, total_tokens, call_count)
|
||||
"""
|
||||
|
||||
status: Literal["success", "duplicated", "partial_success", "failure"] = Field(
|
||||
|
|
@ -250,6 +251,10 @@ class InsertResponse(BaseModel):
|
|||
)
|
||||
message: str = Field(description="Message describing the operation result")
|
||||
track_id: str = Field(description="Tracking ID for monitoring processing status")
|
||||
token_usage: Optional[Dict[str, int]] = Field(
|
||||
default=None,
|
||||
description="Token usage statistics for the insertion (prompt_tokens, completion_tokens, total_tokens, call_count)",
|
||||
)
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
|
|
@ -257,10 +262,17 @@ class InsertResponse(BaseModel):
|
|||
"status": "success",
|
||||
"message": "File 'document.pdf' uploaded successfully. Processing will continue in background.",
|
||||
"track_id": "upload_20250729_170612_abc123",
|
||||
"token_usage": {
|
||||
"prompt_tokens": 1250,
|
||||
"completion_tokens": 450,
|
||||
"total_tokens": 1700,
|
||||
"call_count": 3,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
class ClearDocumentsResponse(BaseModel):
|
||||
"""Response model for document clearing operation
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import json
|
|||
import logging
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response
|
||||
from lightrag.base import QueryParam
|
||||
from lightrag.types import MetadataFilter
|
||||
from lightrag.api.utils_api import get_combined_auth_dependency
|
||||
|
|
@ -157,6 +157,10 @@ class QueryResponse(BaseModel):
|
|||
default=None,
|
||||
description="Reference list (Disabled when include_references=False, /query/data always includes references.)",
|
||||
)
|
||||
token_usage: Optional[Dict[str, int]] = Field(
|
||||
default=None,
|
||||
description="Token usage statistics for the query (prompt_tokens, completion_tokens, total_tokens)",
|
||||
)
|
||||
|
||||
|
||||
class QueryDataResponse(BaseModel):
|
||||
|
|
@ -168,6 +172,10 @@ class QueryDataResponse(BaseModel):
|
|||
metadata: Dict[str, Any] = Field(
|
||||
description="Query metadata including mode, keywords, and processing information"
|
||||
)
|
||||
token_usage: Optional[Dict[str, int]] = Field(
|
||||
default=None,
|
||||
description="Token usage statistics for the query (prompt_tokens, completion_tokens, total_tokens)",
|
||||
)
|
||||
|
||||
|
||||
class StreamChunkResponse(BaseModel):
|
||||
|
|
@ -183,9 +191,18 @@ class StreamChunkResponse(BaseModel):
|
|||
error: Optional[str] = Field(
|
||||
default=None, description="Error message if processing fails"
|
||||
)
|
||||
token_usage: Optional[Dict[str, int]] = Field(
|
||||
default=None,
|
||||
description="Token usage statistics for the entire query (only in final chunk)",
|
||||
)
|
||||
|
||||
|
||||
def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
|
||||
def create_query_routes(
|
||||
rag,
|
||||
api_key: Optional[str] = None,
|
||||
top_k: int = 60,
|
||||
token_tracker: Optional[Any] = None,
|
||||
):
|
||||
combined_auth = get_combined_auth_dependency(api_key)
|
||||
|
||||
@router.post(
|
||||
|
|
@ -220,7 +237,7 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
|
|||
},
|
||||
"examples": {
|
||||
"with_references": {
|
||||
"summary": "Response with references",
|
||||
"summary": "Response with references and token usage",
|
||||
"description": "Example response when include_references=True",
|
||||
"value": {
|
||||
"response": "Artificial Intelligence (AI) is a branch of computer science that aims to create intelligent machines capable of performing tasks that typically require human intelligence, such as learning, reasoning, and problem-solving.",
|
||||
|
|
@ -234,13 +251,25 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
|
|||
"file_path": "/documents/machine_learning.txt",
|
||||
},
|
||||
],
|
||||
"token_usage": {
|
||||
"prompt_tokens": 245,
|
||||
"completion_tokens": 87,
|
||||
"total_tokens": 332,
|
||||
"call_count": 1,
|
||||
},
|
||||
},
|
||||
},
|
||||
"without_references": {
|
||||
"summary": "Response without references",
|
||||
"summary": "Response without references but with token usage",
|
||||
"description": "Example response when include_references=False",
|
||||
"value": {
|
||||
"response": "Artificial Intelligence (AI) is a branch of computer science that aims to create intelligent machines capable of performing tasks that typically require human intelligence, such as learning, reasoning, and problem-solving."
|
||||
"response": "Artificial Intelligence (AI) is a branch of computer science that aims to create intelligent machines capable of performing tasks that typically require human intelligence, such as learning, reasoning, and problem-solving.",
|
||||
"token_usage": {
|
||||
"prompt_tokens": 245,
|
||||
"completion_tokens": 87,
|
||||
"total_tokens": 332,
|
||||
"call_count": 1,
|
||||
},
|
||||
},
|
||||
},
|
||||
"different_modes": {
|
||||
|
|
@ -358,6 +387,10 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
|
|||
- 500: Internal processing error (e.g., LLM service unavailable)
|
||||
"""
|
||||
try:
|
||||
# Reset token tracker at start of query if available
|
||||
if token_tracker:
|
||||
token_tracker.reset()
|
||||
|
||||
param = request.to_query_params(
|
||||
False
|
||||
) # Ensure stream=False for non-streaming endpoint
|
||||
|
|
@ -376,11 +409,22 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
|
|||
if not response_content:
|
||||
response_content = "No relevant context found for the query."
|
||||
|
||||
# Get token usage if available
|
||||
token_usage = None
|
||||
if token_tracker:
|
||||
token_usage = token_tracker.get_usage()
|
||||
|
||||
# Return response with or without references based on request
|
||||
if request.include_references:
|
||||
return QueryResponse(response=response_content, references=references)
|
||||
return QueryResponse(
|
||||
response=response_content,
|
||||
references=references,
|
||||
token_usage=token_usage,
|
||||
)
|
||||
else:
|
||||
return QueryResponse(response=response_content, references=None)
|
||||
return QueryResponse(
|
||||
response=response_content, references=None, token_usage=token_usage
|
||||
)
|
||||
except Exception as e:
|
||||
trace_exception(e)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
|
@ -577,6 +621,10 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
|
|||
Use streaming mode for real-time interfaces and non-streaming for batch processing.
|
||||
"""
|
||||
try:
|
||||
# Reset token tracker at start of query if available
|
||||
if token_tracker:
|
||||
token_tracker.reset()
|
||||
|
||||
# Use the stream parameter from the request, defaulting to True if not specified
|
||||
stream_mode = request.stream if request.stream is not None else True
|
||||
param = request.to_query_params(stream_mode)
|
||||
|
|
@ -605,6 +653,10 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
|
|||
except Exception as e:
|
||||
logging.error(f"Streaming error: {str(e)}")
|
||||
yield f"{json.dumps({'error': str(e)})}\n"
|
||||
|
||||
# Add final token usage chunk if streaming and token tracker is available
|
||||
if token_tracker and llm_response.get("is_streaming"):
|
||||
yield f"{json.dumps({'token_usage': token_tracker.get_usage()})}\n"
|
||||
else:
|
||||
# Non-streaming mode: send complete response in one message
|
||||
response_content = llm_response.get("content", "")
|
||||
|
|
@ -616,6 +668,10 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
|
|||
if request.include_references:
|
||||
complete_response["references"] = references
|
||||
|
||||
# Add token usage if available
|
||||
if token_tracker:
|
||||
complete_response["token_usage"] = token_tracker.get_usage()
|
||||
|
||||
yield f"{json.dumps(complete_response)}\n"
|
||||
|
||||
return StreamingResponse(
|
||||
|
|
@ -1022,18 +1078,31 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
|
|||
as structured data analysis typically requires source attribution.
|
||||
"""
|
||||
try:
|
||||
# Reset token tracker at start of query if available
|
||||
if token_tracker:
|
||||
token_tracker.reset()
|
||||
|
||||
param = request.to_query_params(False) # No streaming for data endpoint
|
||||
response = await rag.aquery_data(request.query, param=param)
|
||||
|
||||
# Get token usage if available
|
||||
token_usage = None
|
||||
if token_tracker:
|
||||
token_usage = token_tracker.get_usage()
|
||||
|
||||
# aquery_data returns the new format with status, message, data, and metadata
|
||||
if isinstance(response, dict):
|
||||
return QueryDataResponse(**response)
|
||||
response_dict = dict(response)
|
||||
response_dict["token_usage"] = token_usage
|
||||
return QueryDataResponse(**response_dict)
|
||||
else:
|
||||
# Handle unexpected response format
|
||||
return QueryDataResponse(
|
||||
status="failure",
|
||||
message="Invalid response type",
|
||||
message="Unexpected response format",
|
||||
data={},
|
||||
metadata={},
|
||||
token_usage=None,
|
||||
)
|
||||
except Exception as e:
|
||||
trace_exception(e)
|
||||
|
|
|
|||
|
|
@ -870,7 +870,8 @@ class LightRAG:
|
|||
ids: str | list[str] | None = None,
|
||||
file_paths: str | list[str] | None = None,
|
||||
track_id: str | None = None,
|
||||
) -> str:
|
||||
token_tracker: TokenTracker | None = None,
|
||||
) -> tuple[str, dict]:
|
||||
"""Sync Insert documents with checkpoint support
|
||||
|
||||
Args:
|
||||
|
|
@ -884,10 +885,10 @@ class LightRAG:
|
|||
track_id: tracking ID for monitoring processing status, if not provided, will be generated
|
||||
|
||||
Returns:
|
||||
str: tracking ID for monitoring processing status
|
||||
tuple[str, dict]: (tracking ID for monitoring processing status, token usage statistics)
|
||||
"""
|
||||
loop = always_get_an_event_loop()
|
||||
return loop.run_until_complete(
|
||||
result = loop.run_until_complete(
|
||||
self.ainsert(
|
||||
input,
|
||||
split_by_character,
|
||||
|
|
@ -895,8 +896,10 @@ class LightRAG:
|
|||
ids,
|
||||
file_paths,
|
||||
track_id,
|
||||
token_tracker,
|
||||
)
|
||||
)
|
||||
return result
|
||||
|
||||
async def ainsert(
|
||||
self,
|
||||
|
|
@ -906,7 +909,8 @@ class LightRAG:
|
|||
ids: str | list[str] | None = None,
|
||||
file_paths: str | list[str] | None = None,
|
||||
track_id: str | None = None,
|
||||
) -> str:
|
||||
token_tracker: TokenTracker | None = None,
|
||||
) -> tuple[str, dict]:
|
||||
"""Async Insert documents with checkpoint support
|
||||
|
||||
Args:
|
||||
|
|
@ -930,9 +934,15 @@ class LightRAG:
|
|||
await self.apipeline_process_enqueue_documents(
|
||||
split_by_character,
|
||||
split_by_character_only,
|
||||
token_tracker,
|
||||
)
|
||||
|
||||
return track_id
|
||||
return track_id, token_tracker.get_usage() if token_tracker else {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0,
|
||||
"call_count": 0,
|
||||
}
|
||||
|
||||
# TODO: deprecated, use insert instead
|
||||
def insert_custom_chunks(
|
||||
|
|
@ -1107,7 +1117,9 @@ class LightRAG:
|
|||
"file_path"
|
||||
], # Store file path in document status
|
||||
"track_id": track_id, # Store track_id in document status
|
||||
"metadata": metadata[i] if isinstance(metadata, list) and i < len(metadata) else metadata, # added provided custom metadata per document
|
||||
"metadata": metadata[i]
|
||||
if isinstance(metadata, list) and i < len(metadata)
|
||||
else metadata, # added provided custom metadata per document
|
||||
}
|
||||
for i, (id_, content_data) in enumerate(contents.items())
|
||||
}
|
||||
|
|
@ -1363,6 +1375,7 @@ class LightRAG:
|
|||
self,
|
||||
split_by_character: str | None = None,
|
||||
split_by_character_only: bool = False,
|
||||
token_tracker: TokenTracker | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Process pending documents by splitting them into chunks, processing
|
||||
|
|
@ -1484,9 +1497,12 @@ class LightRAG:
|
|||
pipeline_status: dict,
|
||||
pipeline_status_lock: asyncio.Lock,
|
||||
semaphore: asyncio.Semaphore,
|
||||
token_tracker: TokenTracker | None = None,
|
||||
) -> None:
|
||||
"""Process single document"""
|
||||
doc_metadata = getattr(status_doc, "metadata", None)
|
||||
doc_metadata = getattr(status_doc, "metadata", None)
|
||||
if doc_metadata is None:
|
||||
doc_metadata = {}
|
||||
file_extraction_stage_ok = False
|
||||
async with semaphore:
|
||||
nonlocal processed_count
|
||||
|
|
@ -1498,7 +1514,6 @@ class LightRAG:
|
|||
# Get file path from status document
|
||||
file_path = getattr(
|
||||
status_doc, "file_path", "unknown_source"
|
||||
|
||||
)
|
||||
|
||||
async with pipeline_status_lock:
|
||||
|
|
@ -1561,7 +1576,9 @@ class LightRAG:
|
|||
|
||||
# Process document in two stages
|
||||
# Stage 1: Process text chunks and docs (parallel execution)
|
||||
doc_metadata["processing_start_time"] = processing_start_time
|
||||
doc_metadata["processing_start_time"] = (
|
||||
processing_start_time
|
||||
)
|
||||
|
||||
doc_status_task = asyncio.create_task(
|
||||
self.doc_status.upsert(
|
||||
|
|
@ -1610,6 +1627,7 @@ class LightRAG:
|
|||
doc_metadata,
|
||||
pipeline_status,
|
||||
pipeline_status_lock,
|
||||
token_tracker,
|
||||
)
|
||||
)
|
||||
await entity_relation_task
|
||||
|
|
@ -1643,7 +1661,9 @@ class LightRAG:
|
|||
processing_end_time = int(time.time())
|
||||
|
||||
# Update document status to failed
|
||||
doc_metadata["processing_start_time"] = processing_start_time
|
||||
doc_metadata["processing_start_time"] = (
|
||||
processing_start_time
|
||||
)
|
||||
doc_metadata["processing_end_time"] = processing_end_time
|
||||
await self.doc_status.upsert(
|
||||
{
|
||||
|
|
@ -1691,7 +1711,9 @@ class LightRAG:
|
|||
doc_metadata["processing_start_time"] = (
|
||||
processing_start_time
|
||||
)
|
||||
doc_metadata["processing_end_time"] = processing_end_time
|
||||
doc_metadata["processing_end_time"] = (
|
||||
processing_end_time
|
||||
)
|
||||
|
||||
await self.doc_status.upsert(
|
||||
{
|
||||
|
|
@ -1748,7 +1770,9 @@ class LightRAG:
|
|||
doc_metadata["processing_start_time"] = (
|
||||
processing_start_time
|
||||
)
|
||||
doc_metadata["processing_end_time"] = processing_end_time
|
||||
doc_metadata["processing_end_time"] = (
|
||||
processing_end_time
|
||||
)
|
||||
|
||||
await self.doc_status.upsert(
|
||||
{
|
||||
|
|
@ -1778,6 +1802,7 @@ class LightRAG:
|
|||
pipeline_status,
|
||||
pipeline_status_lock,
|
||||
semaphore,
|
||||
token_tracker,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -1827,6 +1852,7 @@ class LightRAG:
|
|||
metadata: dict | None,
|
||||
pipeline_status=None,
|
||||
pipeline_status_lock=None,
|
||||
token_tracker: TokenTracker | None = None,
|
||||
) -> list:
|
||||
try:
|
||||
chunk_results = await extract_entities(
|
||||
|
|
@ -1837,6 +1863,7 @@ class LightRAG:
|
|||
pipeline_status_lock=pipeline_status_lock,
|
||||
llm_response_cache=self.llm_response_cache,
|
||||
text_chunks_storage=self.text_chunks,
|
||||
token_tracker=token_tracker,
|
||||
)
|
||||
return chunk_results
|
||||
except Exception as e:
|
||||
|
|
@ -2061,14 +2088,18 @@ class LightRAG:
|
|||
self,
|
||||
query: str,
|
||||
param: QueryParam = QueryParam(),
|
||||
token_tracker: TokenTracker | None = None,
|
||||
system_prompt: str | None = None,
|
||||
) -> str | Iterator[str]:
|
||||
) -> Any:
|
||||
"""
|
||||
Perform a sync query.
|
||||
User query interface (backward compatibility wrapper).
|
||||
|
||||
Delegates to aquery() for asynchronous execution and returns the result.
|
||||
|
||||
Args:
|
||||
query (str): The query to be executed.
|
||||
param (QueryParam): Configuration parameters for query execution.
|
||||
token_tracker (TokenTracker | None): Optional token tracker for monitoring usage.
|
||||
prompt (Optional[str]): Custom prompts for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"].
|
||||
|
||||
Returns:
|
||||
|
|
@ -2076,14 +2107,17 @@ class LightRAG:
|
|||
"""
|
||||
loop = always_get_an_event_loop()
|
||||
|
||||
return loop.run_until_complete(self.aquery(query, param, system_prompt)) # type: ignore
|
||||
return loop.run_until_complete(
|
||||
self.aquery(query, param, token_tracker, system_prompt)
|
||||
) # type: ignore
|
||||
|
||||
async def aquery(
|
||||
self,
|
||||
query: str,
|
||||
param: QueryParam = QueryParam(),
|
||||
token_tracker: TokenTracker | None = None,
|
||||
system_prompt: str | None = None,
|
||||
) -> str | AsyncIterator[str]:
|
||||
) -> Any:
|
||||
"""
|
||||
Perform a async query (backward compatibility wrapper).
|
||||
|
||||
|
|
@ -2102,7 +2136,7 @@ class LightRAG:
|
|||
- Streaming: Returns AsyncIterator[str]
|
||||
"""
|
||||
# Call the new aquery_llm function to get complete results
|
||||
result = await self.aquery_llm(query, param, system_prompt)
|
||||
result = await self.aquery_llm(query, param, system_prompt, token_tracker)
|
||||
|
||||
# Extract and return only the LLM response for backward compatibility
|
||||
llm_response = result.get("llm_response", {})
|
||||
|
|
@ -2137,6 +2171,7 @@ class LightRAG:
|
|||
self,
|
||||
query: str,
|
||||
param: QueryParam = QueryParam(),
|
||||
token_tracker: TokenTracker | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Asynchronous data retrieval API: returns structured retrieval results without LLM generation.
|
||||
|
|
@ -2330,6 +2365,7 @@ class LightRAG:
|
|||
query: str,
|
||||
param: QueryParam = QueryParam(),
|
||||
system_prompt: str | None = None,
|
||||
token_tracker: TokenTracker | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Asynchronous complete query API: returns structured retrieval results with LLM generation.
|
||||
|
|
@ -2364,6 +2400,7 @@ class LightRAG:
|
|||
hashing_kv=self.llm_response_cache,
|
||||
system_prompt=system_prompt,
|
||||
chunks_vdb=self.chunks_vdb,
|
||||
token_tracker=token_tracker,
|
||||
)
|
||||
elif param.mode == "naive":
|
||||
query_result = await naive_query(
|
||||
|
|
@ -2373,6 +2410,7 @@ class LightRAG:
|
|||
global_config,
|
||||
hashing_kv=self.llm_response_cache,
|
||||
system_prompt=system_prompt,
|
||||
token_tracker=token_tracker,
|
||||
)
|
||||
elif param.mode == "bypass":
|
||||
# Bypass mode: directly use LLM without knowledge retrieval
|
||||
|
|
|
|||
|
|
@ -46,6 +46,7 @@ async def azure_openai_complete_if_cache(
|
|||
base_url: str | None = None,
|
||||
api_key: str | None = None,
|
||||
api_version: str | None = None,
|
||||
token_tracker: Any | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
if enable_cot:
|
||||
|
|
@ -94,28 +95,73 @@ async def azure_openai_complete_if_cache(
|
|||
)
|
||||
|
||||
if hasattr(response, "__aiter__"):
|
||||
final_chunk_usage = None
|
||||
accumulated_response = ""
|
||||
|
||||
async def inner():
|
||||
async for chunk in response:
|
||||
if len(chunk.choices) == 0:
|
||||
continue
|
||||
content = chunk.choices[0].delta.content
|
||||
if content is None:
|
||||
continue
|
||||
if r"\u" in content:
|
||||
content = safe_unicode_decode(content.encode("utf-8"))
|
||||
yield content
|
||||
nonlocal final_chunk_usage, accumulated_response
|
||||
try:
|
||||
async for chunk in response:
|
||||
if len(chunk.choices) == 0:
|
||||
continue
|
||||
content = chunk.choices[0].delta.content
|
||||
if content is None:
|
||||
continue
|
||||
accumulated_response += content
|
||||
if r"\u" in content:
|
||||
content = safe_unicode_decode(content.encode("utf-8"))
|
||||
yield content
|
||||
|
||||
# Check for usage in the last chunk
|
||||
if hasattr(chunk, "usage") and chunk.usage is not None:
|
||||
final_chunk_usage = chunk.usage
|
||||
except Exception as e:
|
||||
logger.error(f"Error in Azure OpenAI stream response: {str(e)}")
|
||||
raise
|
||||
finally:
|
||||
# After streaming is complete, track token usage
|
||||
if token_tracker and final_chunk_usage:
|
||||
# Use actual usage from the API
|
||||
token_counts = {
|
||||
"prompt_tokens": getattr(final_chunk_usage, "prompt_tokens", 0),
|
||||
"completion_tokens": getattr(
|
||||
final_chunk_usage, "completion_tokens", 0
|
||||
),
|
||||
"total_tokens": getattr(final_chunk_usage, "total_tokens", 0),
|
||||
}
|
||||
token_tracker.add_usage(token_counts)
|
||||
logger.debug(f"Azure OpenAI streaming token usage: {token_counts}")
|
||||
elif token_tracker:
|
||||
logger.debug(
|
||||
"No usage information available in Azure OpenAI streaming response"
|
||||
)
|
||||
|
||||
return inner()
|
||||
else:
|
||||
content = response.choices[0].message.content
|
||||
if r"\u" in content:
|
||||
content = safe_unicode_decode(content.encode("utf-8"))
|
||||
|
||||
# Track token usage for non-streaming response
|
||||
if token_tracker and hasattr(response, "usage"):
|
||||
token_counts = {
|
||||
"prompt_tokens": getattr(response.usage, "prompt_tokens", 0),
|
||||
"completion_tokens": getattr(response.usage, "completion_tokens", 0),
|
||||
"total_tokens": getattr(response.usage, "total_tokens", 0),
|
||||
}
|
||||
token_tracker.add_usage(token_counts)
|
||||
logger.debug(f"Azure OpenAI non-streaming token usage: {token_counts}")
|
||||
|
||||
return content
|
||||
|
||||
|
||||
async def azure_openai_complete(
|
||||
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
|
||||
prompt,
|
||||
system_prompt=None,
|
||||
history_messages=[],
|
||||
keyword_extraction=False,
|
||||
token_tracker=None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
kwargs.pop("keyword_extraction", None)
|
||||
result = await azure_openai_complete_if_cache(
|
||||
|
|
@ -123,6 +169,7 @@ async def azure_openai_complete(
|
|||
prompt,
|
||||
system_prompt=system_prompt,
|
||||
history_messages=history_messages,
|
||||
token_tracker=token_tracker,
|
||||
**kwargs,
|
||||
)
|
||||
return result
|
||||
|
|
@ -142,6 +189,7 @@ async def azure_openai_embed(
|
|||
base_url: str | None = None,
|
||||
api_key: str | None = None,
|
||||
api_version: str | None = None,
|
||||
token_tracker: Any | None = None,
|
||||
) -> np.ndarray:
|
||||
deployment = (
|
||||
os.getenv("AZURE_EMBEDDING_DEPLOYMENT")
|
||||
|
|
@ -174,4 +222,14 @@ async def azure_openai_embed(
|
|||
response = await openai_async_client.embeddings.create(
|
||||
model=model, input=texts, encoding_format="float"
|
||||
)
|
||||
|
||||
# Track token usage for embeddings if token tracker is provided
|
||||
if token_tracker and hasattr(response, "usage"):
|
||||
token_counts = {
|
||||
"prompt_tokens": getattr(response.usage, "prompt_tokens", 0),
|
||||
"total_tokens": getattr(response.usage, "total_tokens", 0),
|
||||
}
|
||||
token_tracker.add_usage(token_counts)
|
||||
logger.debug(f"Azure OpenAI embedding token usage: {token_counts}")
|
||||
|
||||
return np.array([dp.embedding for dp in response.data])
|
||||
|
|
|
|||
|
|
@ -48,6 +48,7 @@ async def bedrock_complete_if_cache(
|
|||
aws_access_key_id=None,
|
||||
aws_secret_access_key=None,
|
||||
aws_session_token=None,
|
||||
token_tracker=None,
|
||||
**kwargs,
|
||||
) -> Union[str, AsyncIterator[str]]:
|
||||
if enable_cot:
|
||||
|
|
@ -155,6 +156,18 @@ async def bedrock_complete_if_cache(
|
|||
yield text
|
||||
# Handle other event types that might indicate stream end
|
||||
elif "messageStop" in event:
|
||||
# Track token usage for streaming if token tracker is provided
|
||||
if token_tracker and "usage" in event:
|
||||
usage = event["usage"]
|
||||
token_counts = {
|
||||
"prompt_tokens": usage.get("inputTokens", 0),
|
||||
"completion_tokens": usage.get("outputTokens", 0),
|
||||
"total_tokens": usage.get("totalTokens", 0),
|
||||
}
|
||||
token_tracker.add_usage(token_counts)
|
||||
logging.debug(
|
||||
f"Bedrock streaming token usage: {token_counts}"
|
||||
)
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
|
|
@ -228,6 +241,17 @@ async def bedrock_complete_if_cache(
|
|||
if not content or content.strip() == "":
|
||||
raise BedrockError("Received empty content from Bedrock API")
|
||||
|
||||
# Track token usage for non-streaming if token tracker is provided
|
||||
if token_tracker and "usage" in response:
|
||||
usage = response["usage"]
|
||||
token_counts = {
|
||||
"prompt_tokens": usage.get("inputTokens", 0),
|
||||
"completion_tokens": usage.get("outputTokens", 0),
|
||||
"total_tokens": usage.get("totalTokens", 0),
|
||||
}
|
||||
token_tracker.add_usage(token_counts)
|
||||
logging.debug(f"Bedrock non-streaming token usage: {token_counts}")
|
||||
|
||||
return content
|
||||
|
||||
except Exception as e:
|
||||
|
|
@ -239,7 +263,12 @@ async def bedrock_complete_if_cache(
|
|||
|
||||
# Generic Bedrock completion function
|
||||
async def bedrock_complete(
|
||||
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
|
||||
prompt,
|
||||
system_prompt=None,
|
||||
history_messages=[],
|
||||
keyword_extraction=False,
|
||||
token_tracker=None,
|
||||
**kwargs,
|
||||
) -> Union[str, AsyncIterator[str]]:
|
||||
kwargs.pop("keyword_extraction", None)
|
||||
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
|
||||
|
|
@ -248,6 +277,7 @@ async def bedrock_complete(
|
|||
prompt,
|
||||
system_prompt=system_prompt,
|
||||
history_messages=history_messages,
|
||||
token_tracker=token_tracker,
|
||||
**kwargs,
|
||||
)
|
||||
return result
|
||||
|
|
@ -265,6 +295,7 @@ async def bedrock_embed(
|
|||
aws_access_key_id=None,
|
||||
aws_secret_access_key=None,
|
||||
aws_session_token=None,
|
||||
token_tracker=None,
|
||||
) -> np.ndarray:
|
||||
# Respect existing env; only set if a non-empty value is available
|
||||
access_key = os.environ.get("AWS_ACCESS_KEY_ID") or aws_access_key_id
|
||||
|
|
|
|||
|
|
@ -108,6 +108,7 @@ async def lollms_model_complete(
|
|||
history_messages=[],
|
||||
enable_cot: bool = False,
|
||||
keyword_extraction=False,
|
||||
token_tracker=None,
|
||||
**kwargs,
|
||||
) -> Union[str, AsyncIterator[str]]:
|
||||
"""Complete function for lollms model generation."""
|
||||
|
|
@ -135,7 +136,11 @@ async def lollms_model_complete(
|
|||
|
||||
|
||||
async def lollms_embed(
|
||||
texts: List[str], embed_model=None, base_url="http://localhost:9600", **kwargs
|
||||
texts: List[str],
|
||||
embed_model=None,
|
||||
base_url="http://localhost:9600",
|
||||
token_tracker=None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Generate embeddings for a list of texts using lollms server.
|
||||
|
|
@ -144,6 +149,7 @@ async def lollms_embed(
|
|||
texts: List of strings to embed
|
||||
embed_model: Model name (not used directly as lollms uses configured vectorizer)
|
||||
base_url: URL of the lollms server
|
||||
token_tracker: Optional token usage tracker for monitoring API usage
|
||||
**kwargs: Additional arguments passed to the request
|
||||
|
||||
Returns:
|
||||
|
|
|
|||
|
|
@ -39,6 +39,7 @@ async def _ollama_model_if_cache(
|
|||
system_prompt=None,
|
||||
history_messages=[],
|
||||
enable_cot: bool = False,
|
||||
token_tracker=None,
|
||||
**kwargs,
|
||||
) -> Union[str, AsyncIterator[str]]:
|
||||
if enable_cot:
|
||||
|
|
@ -74,13 +75,47 @@ async def _ollama_model_if_cache(
|
|||
"""cannot cache stream response and process reasoning"""
|
||||
|
||||
async def inner():
|
||||
accumulated_response = ""
|
||||
try:
|
||||
async for chunk in response:
|
||||
yield chunk["message"]["content"]
|
||||
chunk_content = chunk["message"]["content"]
|
||||
accumulated_response += chunk_content
|
||||
yield chunk_content
|
||||
except Exception as e:
|
||||
logger.error(f"Error in stream response: {str(e)}")
|
||||
raise
|
||||
finally:
|
||||
# Track token usage for streaming if token tracker is provided
|
||||
if token_tracker:
|
||||
# Estimate prompt tokens: roughly 4 characters per token for English text
|
||||
prompt_text = ""
|
||||
if system_prompt:
|
||||
prompt_text += system_prompt + " "
|
||||
prompt_text += (
|
||||
" ".join(
|
||||
[msg.get("content", "") for msg in history_messages]
|
||||
)
|
||||
+ " "
|
||||
)
|
||||
prompt_text += prompt
|
||||
prompt_tokens = len(prompt_text) // 4 + (
|
||||
1 if len(prompt_text) % 4 else 0
|
||||
)
|
||||
|
||||
# Estimate completion tokens from accumulated response
|
||||
completion_tokens = len(accumulated_response) // 4 + (
|
||||
1 if len(accumulated_response) % 4 else 0
|
||||
)
|
||||
total_tokens = prompt_tokens + completion_tokens
|
||||
|
||||
token_tracker.add_usage(
|
||||
{
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"total_tokens": total_tokens,
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
await ollama_client._client.aclose()
|
||||
logger.debug("Successfully closed Ollama client for streaming")
|
||||
|
|
@ -91,6 +126,35 @@ async def _ollama_model_if_cache(
|
|||
else:
|
||||
model_response = response["message"]["content"]
|
||||
|
||||
# Track token usage if token tracker is provided
|
||||
# Note: Ollama doesn't provide token usage in chat responses, so we estimate
|
||||
if token_tracker:
|
||||
# Estimate prompt tokens: roughly 4 characters per token for English text
|
||||
prompt_text = ""
|
||||
if system_prompt:
|
||||
prompt_text += system_prompt + " "
|
||||
prompt_text += (
|
||||
" ".join([msg.get("content", "") for msg in history_messages]) + " "
|
||||
)
|
||||
prompt_text += prompt
|
||||
prompt_tokens = len(prompt_text) // 4 + (
|
||||
1 if len(prompt_text) % 4 else 0
|
||||
)
|
||||
|
||||
# Estimate completion tokens from response
|
||||
completion_tokens = len(model_response) // 4 + (
|
||||
1 if len(model_response) % 4 else 0
|
||||
)
|
||||
total_tokens = prompt_tokens + completion_tokens
|
||||
|
||||
token_tracker.add_usage(
|
||||
{
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"total_tokens": total_tokens,
|
||||
}
|
||||
)
|
||||
|
||||
"""
|
||||
If the model also wraps its thoughts in a specific tag,
|
||||
this information is not needed for the final
|
||||
|
|
@ -126,6 +190,7 @@ async def ollama_model_complete(
|
|||
history_messages=[],
|
||||
enable_cot: bool = False,
|
||||
keyword_extraction=False,
|
||||
token_tracker=None,
|
||||
**kwargs,
|
||||
) -> Union[str, AsyncIterator[str]]:
|
||||
keyword_extraction = kwargs.pop("keyword_extraction", None)
|
||||
|
|
@ -138,11 +203,14 @@ async def ollama_model_complete(
|
|||
system_prompt=system_prompt,
|
||||
history_messages=history_messages,
|
||||
enable_cot=enable_cot,
|
||||
token_tracker=token_tracker,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray:
|
||||
async def ollama_embed(
|
||||
texts: list[str], embed_model, token_tracker=None, **kwargs
|
||||
) -> np.ndarray:
|
||||
api_key = kwargs.pop("api_key", None)
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
|
|
@ -160,6 +228,21 @@ async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray:
|
|||
data = await ollama_client.embed(
|
||||
model=embed_model, input=texts, options=options
|
||||
)
|
||||
|
||||
# Track token usage if token tracker is provided
|
||||
# Note: Ollama doesn't provide token usage in embedding responses, so we estimate
|
||||
if token_tracker:
|
||||
# Estimate tokens: roughly 4 characters per token for English text
|
||||
total_chars = sum(len(text) for text in texts)
|
||||
estimated_tokens = total_chars // 4 + (1 if total_chars % 4 else 0)
|
||||
token_tracker.add_usage(
|
||||
{
|
||||
"prompt_tokens": estimated_tokens,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": estimated_tokens,
|
||||
}
|
||||
)
|
||||
|
||||
return np.array(data["embeddings"])
|
||||
except Exception as e:
|
||||
logger.error(f"Error in ollama_embed: {str(e)}")
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from .utils import (
|
|||
logger,
|
||||
compute_mdhash_id,
|
||||
Tokenizer,
|
||||
TokenTracker,
|
||||
is_float_regex,
|
||||
sanitize_and_normalize_extracted_text,
|
||||
pack_user_ass_to_openai_messages,
|
||||
|
|
@ -126,6 +127,7 @@ async def _handle_entity_relation_summary(
|
|||
seperator: str,
|
||||
global_config: dict,
|
||||
llm_response_cache: BaseKVStorage | None = None,
|
||||
token_tracker: TokenTracker | None = None,
|
||||
) -> tuple[str, bool]:
|
||||
"""Handle entity relation description summary using map-reduce approach.
|
||||
|
||||
|
|
@ -188,6 +190,7 @@ async def _handle_entity_relation_summary(
|
|||
current_list,
|
||||
global_config,
|
||||
llm_response_cache,
|
||||
token_tracker,
|
||||
)
|
||||
return final_summary, True # LLM was used for final summarization
|
||||
|
||||
|
|
@ -243,6 +246,7 @@ async def _handle_entity_relation_summary(
|
|||
chunk,
|
||||
global_config,
|
||||
llm_response_cache,
|
||||
token_tracker,
|
||||
)
|
||||
new_summaries.append(summary)
|
||||
llm_was_used = True # Mark that LLM was used in reduce phase
|
||||
|
|
@ -257,6 +261,7 @@ async def _summarize_descriptions(
|
|||
description_list: list[str],
|
||||
global_config: dict,
|
||||
llm_response_cache: BaseKVStorage | None = None,
|
||||
token_tracker: TokenTracker | None = None,
|
||||
) -> str:
|
||||
"""Helper function to summarize a list of descriptions using LLM.
|
||||
|
||||
|
|
@ -312,9 +317,10 @@ async def _summarize_descriptions(
|
|||
# Use LLM function with cache (higher priority for summary generation)
|
||||
summary, _ = await use_llm_func_with_cache(
|
||||
use_prompt,
|
||||
use_llm_func,
|
||||
llm_response_cache=llm_response_cache,
|
||||
use_llm_func=use_llm_func,
|
||||
hashing_kv=llm_response_cache,
|
||||
cache_type="summary",
|
||||
token_tracker=token_tracker,
|
||||
)
|
||||
return summary
|
||||
|
||||
|
|
@ -405,7 +411,7 @@ async def _handle_single_relationship_extraction(
|
|||
): # treat "relationship" and "relation" interchangeable
|
||||
if len(record_attributes) > 1 and "relation" in record_attributes[0]:
|
||||
logger.warning(
|
||||
f"{chunk_key}: LLM output format error; found {len(record_attributes)}/5 fields on REALTION `{record_attributes[1]}`~`{record_attributes[2] if len(record_attributes) >2 else 'N/A'}`"
|
||||
f"{chunk_key}: LLM output format error; found {len(record_attributes)}/5 fields on REALTION `{record_attributes[1]}`~`{record_attributes[2] if len(record_attributes) > 2 else 'N/A'}`"
|
||||
)
|
||||
logger.debug(record_attributes)
|
||||
return None
|
||||
|
|
@ -463,7 +469,6 @@ async def _handle_single_relationship_extraction(
|
|||
file_path=file_path,
|
||||
timestamp=timestamp,
|
||||
metadata=metadata,
|
||||
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
|
|
@ -2037,6 +2042,7 @@ async def extract_entities(
|
|||
pipeline_status_lock=None,
|
||||
llm_response_cache: BaseKVStorage | None = None,
|
||||
text_chunks_storage: BaseKVStorage | None = None,
|
||||
token_tracker: TokenTracker | None = None,
|
||||
) -> list:
|
||||
use_llm_func: callable = global_config["llm_model_func"]
|
||||
entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
|
||||
|
|
@ -2150,12 +2156,13 @@ async def extract_entities(
|
|||
|
||||
final_result, timestamp = await use_llm_func_with_cache(
|
||||
entity_extraction_user_prompt,
|
||||
use_llm_func,
|
||||
use_llm_func=use_llm_func,
|
||||
system_prompt=entity_extraction_system_prompt,
|
||||
llm_response_cache=llm_response_cache,
|
||||
hashing_kv=llm_response_cache,
|
||||
cache_type="extract",
|
||||
chunk_id=chunk_key,
|
||||
cache_keys_collector=cache_keys_collector,
|
||||
token_tracker=token_tracker,
|
||||
)
|
||||
|
||||
history = pack_user_ass_to_openai_messages(
|
||||
|
|
@ -2177,16 +2184,16 @@ async def extract_entities(
|
|||
if entity_extract_max_gleaning > 0:
|
||||
glean_result, timestamp = await use_llm_func_with_cache(
|
||||
entity_continue_extraction_user_prompt,
|
||||
use_llm_func,
|
||||
use_llm_func=use_llm_func,
|
||||
system_prompt=entity_extraction_system_prompt,
|
||||
llm_response_cache=llm_response_cache,
|
||||
history_messages=history,
|
||||
cache_type="extract",
|
||||
chunk_id=chunk_key,
|
||||
cache_keys_collector=cache_keys_collector,
|
||||
token_tracker=token_tracker,
|
||||
)
|
||||
|
||||
|
||||
# Process gleaning result separately with file path and metadata
|
||||
glean_nodes, glean_edges = await _process_extraction_result(
|
||||
glean_result,
|
||||
|
|
@ -2300,7 +2307,7 @@ async def extract_entities(
|
|||
await asyncio.wait(pending)
|
||||
|
||||
# Add progress prefix to the exception message
|
||||
progress_prefix = f"C[{processed_chunks+1}/{total_chunks}]"
|
||||
progress_prefix = f"C[{processed_chunks + 1}/{total_chunks}]"
|
||||
|
||||
# Re-raise the original exception with a prefix
|
||||
prefixed_exception = create_prefixed_exception(first_exception, progress_prefix)
|
||||
|
|
@ -2324,6 +2331,7 @@ async def kg_query(
|
|||
system_prompt: str | None = None,
|
||||
chunks_vdb: BaseVectorStorage = None,
|
||||
return_raw_data: Literal[True] = False,
|
||||
token_tracker: TokenTracker | None = None,
|
||||
) -> dict[str, Any]: ...
|
||||
|
||||
|
||||
|
|
@ -2341,6 +2349,7 @@ async def kg_query(
|
|||
chunks_vdb: BaseVectorStorage = None,
|
||||
metadata_filters: list | None = None,
|
||||
return_raw_data: Literal[False] = False,
|
||||
token_tracker: TokenTracker | None = None,
|
||||
) -> str | AsyncIterator[str]: ...
|
||||
|
||||
|
||||
|
|
@ -2355,6 +2364,7 @@ async def kg_query(
|
|||
hashing_kv: BaseKVStorage | None = None,
|
||||
system_prompt: str | None = None,
|
||||
chunks_vdb: BaseVectorStorage = None,
|
||||
token_tracker: TokenTracker | None = None,
|
||||
) -> QueryResult:
|
||||
"""
|
||||
Execute knowledge graph query and return unified QueryResult object.
|
||||
|
|
@ -2422,7 +2432,7 @@ async def kg_query(
|
|||
return QueryResult(content=cached_response)
|
||||
|
||||
hl_keywords, ll_keywords = await get_keywords_from_query(
|
||||
query, query_param, global_config, hashing_kv
|
||||
query, query_param, global_config, hashing_kv, token_tracker
|
||||
)
|
||||
|
||||
logger.debug(f"High-level keywords: {hl_keywords}")
|
||||
|
|
@ -2526,6 +2536,7 @@ async def kg_query(
|
|||
history_messages=query_param.conversation_history,
|
||||
enable_cot=True,
|
||||
stream=query_param.stream,
|
||||
token_tracker=token_tracker,
|
||||
)
|
||||
|
||||
if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"):
|
||||
|
|
@ -2583,6 +2594,7 @@ async def get_keywords_from_query(
|
|||
query_param: QueryParam,
|
||||
global_config: dict[str, str],
|
||||
hashing_kv: BaseKVStorage | None = None,
|
||||
token_tracker: TokenTracker | None = None,
|
||||
) -> tuple[list[str], list[str]]:
|
||||
"""
|
||||
Retrieves high-level and low-level keywords for RAG operations.
|
||||
|
|
@ -2605,7 +2617,7 @@ async def get_keywords_from_query(
|
|||
|
||||
# Extract keywords using extract_keywords_only function which already supports conversation history
|
||||
hl_keywords, ll_keywords = await extract_keywords_only(
|
||||
query, query_param, global_config, hashing_kv
|
||||
query, query_param, global_config, hashing_kv, token_tracker
|
||||
)
|
||||
return hl_keywords, ll_keywords
|
||||
|
||||
|
|
@ -2615,6 +2627,7 @@ async def extract_keywords_only(
|
|||
param: QueryParam,
|
||||
global_config: dict[str, str],
|
||||
hashing_kv: BaseKVStorage | None = None,
|
||||
token_tracker: TokenTracker | None = None,
|
||||
) -> tuple[list[str], list[str]]:
|
||||
"""
|
||||
Extract high-level and low-level keywords from the given 'text' using the LLM.
|
||||
|
|
@ -2668,7 +2681,9 @@ async def extract_keywords_only(
|
|||
# Apply higher priority (5) to query relation LLM function
|
||||
use_model_func = partial(use_model_func, _priority=5)
|
||||
|
||||
result = await use_model_func(kw_prompt, keyword_extraction=True)
|
||||
result = await use_model_func(
|
||||
kw_prompt, keyword_extraction=True, token_tracker=token_tracker
|
||||
)
|
||||
|
||||
# 5. Parse out JSON from the LLM response
|
||||
result = remove_think_tags(result)
|
||||
|
|
@ -2746,7 +2761,10 @@ async def _get_vector_context(
|
|||
cosine_threshold = chunks_vdb.cosine_better_than_threshold
|
||||
|
||||
results = await chunks_vdb.query(
|
||||
query, top_k=search_top_k, query_embedding=query_embedding, metadata_filter=query_param.metadata_filter
|
||||
query,
|
||||
top_k=search_top_k,
|
||||
query_embedding=query_embedding,
|
||||
metadata_filter=query_param.metadata_filter,
|
||||
)
|
||||
if not results:
|
||||
logger.info(
|
||||
|
|
@ -2763,7 +2781,7 @@ async def _get_vector_context(
|
|||
"file_path": result.get("file_path", "unknown_source"),
|
||||
"source_type": "vector", # Mark the source type
|
||||
"chunk_id": result.get("id"), # Add chunk_id for deduplication
|
||||
"metadata": result.get("metadata")
|
||||
"metadata": result.get("metadata"),
|
||||
}
|
||||
valid_chunks.append(chunk_with_metadata)
|
||||
|
||||
|
|
@ -3529,8 +3547,9 @@ async def _get_node_data(
|
|||
f"Query nodes: {query} (top_k:{query_param.top_k}, cosine:{entities_vdb.cosine_better_than_threshold})"
|
||||
)
|
||||
|
||||
|
||||
results = await entities_vdb.query(query, top_k=query_param.top_k, metadata_filter=query_param.metadata_filter)
|
||||
results = await entities_vdb.query(
|
||||
query, top_k=query_param.top_k, metadata_filter=query_param.metadata_filter
|
||||
)
|
||||
|
||||
if not len(results):
|
||||
return [], []
|
||||
|
|
@ -3538,7 +3557,6 @@ async def _get_node_data(
|
|||
# Extract all entity IDs from your results list
|
||||
node_ids = [r["entity_name"] for r in results]
|
||||
|
||||
|
||||
# Extract all entity IDs from your results list
|
||||
node_ids = [r["entity_name"] for r in results]
|
||||
|
||||
|
|
@ -3810,7 +3828,9 @@ async def _get_edge_data(
|
|||
f"Query edges: {keywords} (top_k:{query_param.top_k}, cosine:{relationships_vdb.cosine_better_than_threshold})"
|
||||
)
|
||||
|
||||
results = await relationships_vdb.query(keywords, top_k=query_param.top_k, metadata_filter=query_param.metadata_filter)
|
||||
results = await relationships_vdb.query(
|
||||
keywords, top_k=query_param.top_k, metadata_filter=query_param.metadata_filter
|
||||
)
|
||||
|
||||
if not len(results):
|
||||
return [], []
|
||||
|
|
@ -4104,6 +4124,7 @@ async def naive_query(
|
|||
hashing_kv: BaseKVStorage | None = None,
|
||||
system_prompt: str | None = None,
|
||||
return_raw_data: Literal[True] = True,
|
||||
token_tracker: TokenTracker | None = None,
|
||||
) -> dict[str, Any]: ...
|
||||
|
||||
|
||||
|
|
@ -4116,6 +4137,7 @@ async def naive_query(
|
|||
hashing_kv: BaseKVStorage | None = None,
|
||||
system_prompt: str | None = None,
|
||||
return_raw_data: Literal[False] = False,
|
||||
token_tracker: TokenTracker | None = None,
|
||||
) -> str | AsyncIterator[str]: ...
|
||||
|
||||
|
||||
|
|
@ -4126,6 +4148,7 @@ async def naive_query(
|
|||
global_config: dict[str, str],
|
||||
hashing_kv: BaseKVStorage | None = None,
|
||||
system_prompt: str | None = None,
|
||||
token_tracker: TokenTracker | None = None,
|
||||
) -> QueryResult:
|
||||
"""
|
||||
Execute naive query and return unified QueryResult object.
|
||||
|
|
@ -4321,6 +4344,7 @@ async def naive_query(
|
|||
history_messages=query_param.conversation_history,
|
||||
enable_cot=True,
|
||||
stream=query_param.stream,
|
||||
token_tracker=token_tracker,
|
||||
)
|
||||
|
||||
if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"):
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import weakref
|
|||
import asyncio
|
||||
import html
|
||||
import csv
|
||||
import contextvars
|
||||
import json
|
||||
import logging
|
||||
import logging.handlers
|
||||
|
|
@ -507,6 +508,7 @@ def priority_limit_async_func_call(
|
|||
task_id,
|
||||
args,
|
||||
kwargs,
|
||||
ctx,
|
||||
) = await asyncio.wait_for(queue.get(), timeout=1.0)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
|
|
@ -536,11 +538,15 @@ def priority_limit_async_func_call(
|
|||
try:
|
||||
# Execute function with timeout protection
|
||||
if max_execution_timeout is not None:
|
||||
# Run the function in the captured context
|
||||
task = ctx.run(lambda: asyncio.create_task(func(*args, **kwargs)))
|
||||
result = await asyncio.wait_for(
|
||||
func(*args, **kwargs), timeout=max_execution_timeout
|
||||
task, timeout=max_execution_timeout
|
||||
)
|
||||
else:
|
||||
result = await func(*args, **kwargs)
|
||||
# Run the function in the captured context
|
||||
task = ctx.run(lambda: asyncio.create_task(func(*args, **kwargs)))
|
||||
result = await task
|
||||
|
||||
# Set result if future is still valid
|
||||
if not task_state.future.done():
|
||||
|
|
@ -791,6 +797,9 @@ def priority_limit_async_func_call(
|
|||
future=future, start_time=asyncio.get_event_loop().time()
|
||||
)
|
||||
|
||||
# Capture current context
|
||||
ctx = contextvars.copy_context()
|
||||
|
||||
try:
|
||||
# Register task state
|
||||
async with task_states_lock:
|
||||
|
|
@ -809,13 +818,13 @@ def priority_limit_async_func_call(
|
|||
if _queue_timeout is not None:
|
||||
await asyncio.wait_for(
|
||||
queue.put(
|
||||
(_priority, current_count, task_id, args, kwargs)
|
||||
(_priority, current_count, task_id, args, kwargs, ctx)
|
||||
),
|
||||
timeout=_queue_timeout,
|
||||
)
|
||||
else:
|
||||
await queue.put(
|
||||
(_priority, current_count, task_id, args, kwargs)
|
||||
(_priority, current_count, task_id, args, kwargs, ctx)
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
raise QueueFullError(
|
||||
|
|
@ -1472,8 +1481,7 @@ async def aexport_data(
|
|||
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported file format: {file_format}. "
|
||||
f"Choose from: csv, excel, md, txt"
|
||||
f"Unsupported file format: {file_format}. Choose from: csv, excel, md, txt"
|
||||
)
|
||||
if file_format is not None:
|
||||
print(f"Data exported to: {output_path} with format: {file_format}")
|
||||
|
|
@ -1601,6 +1609,8 @@ async def use_llm_func_with_cache(
|
|||
cache_type: str = "extract",
|
||||
chunk_id: str | None = None,
|
||||
cache_keys_collector: list = None,
|
||||
hashing_kv: "BaseKVStorage | None" = None,
|
||||
token_tracker=None,
|
||||
) -> tuple[str, int]:
|
||||
"""Call LLM function with cache support and text sanitization
|
||||
|
||||
|
|
@ -1685,7 +1695,10 @@ async def use_llm_func_with_cache(
|
|||
kwargs["max_tokens"] = max_tokens
|
||||
|
||||
res: str = await use_llm_func(
|
||||
safe_user_prompt, system_prompt=safe_system_prompt, **kwargs
|
||||
safe_user_prompt,
|
||||
system_prompt=safe_system_prompt,
|
||||
token_tracker=token_tracker,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
res = remove_think_tags(res)
|
||||
|
|
@ -1720,7 +1733,10 @@ async def use_llm_func_with_cache(
|
|||
|
||||
try:
|
||||
res = await use_llm_func(
|
||||
safe_user_prompt, system_prompt=safe_system_prompt, **kwargs
|
||||
safe_user_prompt,
|
||||
system_prompt=safe_system_prompt,
|
||||
token_tracker=token_tracker,
|
||||
**kwargs,
|
||||
)
|
||||
except Exception as e:
|
||||
# Add [LLM func] prefix to error message
|
||||
|
|
@ -2216,52 +2232,74 @@ async def pick_by_vector_similarity(
|
|||
return all_chunk_ids[:num_of_chunks]
|
||||
|
||||
|
||||
from contextvars import ContextVar
|
||||
|
||||
|
||||
class TokenTracker:
|
||||
"""Track token usage for LLM calls."""
|
||||
"""Track token usage for LLM calls using ContextVars for concurrency support."""
|
||||
|
||||
_usage_var: ContextVar[dict] = ContextVar("token_usage", default=None)
|
||||
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
# No instance state needed as we use ContextVar
|
||||
pass
|
||||
|
||||
def __enter__(self):
|
||||
self.reset()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
print(self)
|
||||
# Optional: Log usage on exit if needed
|
||||
pass
|
||||
|
||||
def reset(self):
|
||||
self.prompt_tokens = 0
|
||||
self.completion_tokens = 0
|
||||
self.total_tokens = 0
|
||||
self.call_count = 0
|
||||
"""Initialize/Reset token usage for the current context."""
|
||||
self._usage_var.set(
|
||||
{
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0,
|
||||
"call_count": 0,
|
||||
}
|
||||
)
|
||||
|
||||
def add_usage(self, token_counts):
|
||||
def _get_current_usage(self) -> dict:
|
||||
"""Get the usage dict for the current context, initializing if necessary."""
|
||||
usage = self._usage_var.get()
|
||||
if usage is None:
|
||||
usage = {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0,
|
||||
"call_count": 0,
|
||||
}
|
||||
self._usage_var.set(usage)
|
||||
return usage
|
||||
|
||||
def add_usage(self, token_counts: dict):
|
||||
"""Add token usage from one LLM call.
|
||||
|
||||
Args:
|
||||
token_counts: A dictionary containing prompt_tokens, completion_tokens, total_tokens
|
||||
"""
|
||||
self.prompt_tokens += token_counts.get("prompt_tokens", 0)
|
||||
self.completion_tokens += token_counts.get("completion_tokens", 0)
|
||||
usage = self._get_current_usage()
|
||||
|
||||
usage["prompt_tokens"] += token_counts.get("prompt_tokens", 0)
|
||||
usage["completion_tokens"] += token_counts.get("completion_tokens", 0)
|
||||
|
||||
# If total_tokens is provided, use it directly; otherwise calculate the sum
|
||||
if "total_tokens" in token_counts:
|
||||
self.total_tokens += token_counts["total_tokens"]
|
||||
usage["total_tokens"] += token_counts["total_tokens"]
|
||||
else:
|
||||
self.total_tokens += token_counts.get(
|
||||
usage["total_tokens"] += token_counts.get(
|
||||
"prompt_tokens", 0
|
||||
) + token_counts.get("completion_tokens", 0)
|
||||
|
||||
self.call_count += 1
|
||||
usage["call_count"] += 1
|
||||
|
||||
def get_usage(self):
|
||||
"""Get current usage statistics."""
|
||||
return {
|
||||
"prompt_tokens": self.prompt_tokens,
|
||||
"completion_tokens": self.completion_tokens,
|
||||
"total_tokens": self.total_tokens,
|
||||
"call_count": self.call_count,
|
||||
}
|
||||
return self._get_current_usage().copy()
|
||||
|
||||
def __str__(self):
|
||||
usage = self.get_usage()
|
||||
|
|
@ -2273,6 +2311,26 @@ class TokenTracker:
|
|||
)
|
||||
|
||||
|
||||
def estimate_embedding_tokens(texts: list[str], tokenizer: Tokenizer) -> int:
|
||||
"""Estimate tokens for embedding operations based on text length.
|
||||
|
||||
Most embedding APIs don't return token counts, so we estimate based on
|
||||
the tokenizer encoding. This provides a reasonable approximation for tracking.
|
||||
|
||||
Args:
|
||||
texts: List of text strings to be embedded
|
||||
tokenizer: Tokenizer instance for encoding
|
||||
|
||||
Returns:
|
||||
Estimated total token count for all texts
|
||||
"""
|
||||
total = 0
|
||||
for text in texts:
|
||||
if text: # Skip empty strings
|
||||
total += len(tokenizer.encode(text))
|
||||
return total
|
||||
|
||||
|
||||
async def apply_rerank_if_enabled(
|
||||
query: str,
|
||||
retrieved_docs: list[dict],
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue